diff --git a/.bazelrc b/.bazelrc index 7d5bf267b31..7ca8b50fb7a 100644 --- a/.bazelrc +++ b/.bazelrc @@ -50,7 +50,6 @@ # Feature and Third party library support options: # xla: Build TF with XLA # tpu: Build TF with TPU support -# using_cuda: CUDA is available to build system. # cuda: Build with full cuda support. # rocm: Build with AMD GPU support (rocm). # mkl: Enable full mkl support. @@ -92,6 +91,7 @@ # release_cpu_macos: Toolchain and CUDA options for MacOS CPU builds. # release_gpu_linux: Toolchain and CUDA options for Linux GPU builds. # release_gpu_linux_cuda_10_1: Toolchain and CUDA options for CUDA 10.1 Linux GPU builds. +# release_gpu_linux_cuda_11_2: Toolchain and CUDA options for CUDA 11.2 Linux GPU builds. # release_cpu_windows: Toolchain and CUDA options for Windows CPU builds. # release_gpu_windows: Toolchain and CUDA options for Windows GPU builds. @@ -190,23 +190,22 @@ build:mkl_aarch64 -c opt # This config refers to building with CUDA available. It does not necessarily # mean that we build CUDA op kernels. -build:using_cuda --define=using_cuda=true -build:using_cuda --action_env TF_NEED_CUDA=1 -build:using_cuda --crosstool_top=@local_config_cuda//crosstool:toolchain +build:cuda_base --@local_config_cuda//:enable_cuda +build:cuda_base --action_env TF_NEED_CUDA=1 +build:cuda_base --crosstool_top=@local_config_cuda//crosstool:toolchain # Enable the mlir generated GPU kernels only for cuda builds. build --define=tensorflow_enable_mlir_generated_gpu_kernels=0 # This is a more specific option, so it takes precedence over the line above for cuda builds. -build:using_cuda --define=tensorflow_enable_mlir_generated_gpu_kernels=1 +build:cuda_base --define=tensorflow_enable_mlir_generated_gpu_kernels=1 # This config refers to building CUDA op kernels with nvcc. -build:cuda --config=using_cuda -build:cuda --define=using_cuda_nvcc=true +build:cuda --config=cuda_base +build:cuda --@local_config_cuda//:cuda_compiler=nvcc # This config refers to building CUDA op kernels with clang. -build:cuda_clang --config=using_cuda -build:cuda_clang --define=using_cuda_clang=true -build:cuda_clang --define=using_clang=true +build:cuda_clang --config=cuda_base +build:cuda_clang --@local_config_cuda//:cuda_compiler=clang build:cuda_clang --action_env TF_CUDA_CLANG=1 # dbg config, as a shorthand for '--config=opt -c dbg' @@ -430,6 +429,9 @@ build:rbe_cpu_linux --host_platform="@ubuntu16.04-manylinux2010-py3_config_platf build:rbe_cpu_linux --platforms="@ubuntu16.04-manylinux2010-py3_config_platform//:platform" build:rbe_linux_cuda_base --config=rbe_linux +build:rbe_linux_cuda_base --@local_config_cuda//:enable_cuda +build:rbe_linux_cuda_base --@local_config_cuda//:cuda_compiler=nvcc +# TODO(csigg): those probably don't do anything because cuda_config is remote. build:rbe_linux_cuda_base --repo_env=TF_NEED_TENSORRT=1 build:rbe_linux_cuda_base --repo_env=TF_CUDA_VERSION=10 build:rbe_linux_cuda_base --repo_env=TF_CUDNN_VERSION=7 @@ -438,7 +440,6 @@ build:rbe_linux_cuda_base --repo_env=TF_NEED_CUDA=1 test:rbe_linux_cuda_base --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" build:rbe_linux_cuda10.1_nvcc_base --config=rbe_linux_cuda_base -build:rbe_linux_cuda10.1_nvcc_base --define=using_cuda_nvcc=true build:rbe_linux_cuda10.1_nvcc_base --host_crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain" build:rbe_linux_cuda10.1_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain" build:rbe_linux_cuda10.1_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64" @@ -455,7 +456,6 @@ build:rbe_linux_cuda10.1_nvcc_py3.7 --config=rbe_linux_cuda10.1_nvcc_base --repo build:rbe_linux_cuda10.1_nvcc_py3.8 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8" build:rbe_linux_cuda11.0_nvcc_base --config=rbe_linux_cuda_base -build:rbe_linux_cuda11.0_nvcc_base --define=using_cuda_nvcc=true build:rbe_linux_cuda11.0_nvcc_base --host_crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain" build:rbe_linux_cuda11.0_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain" build:rbe_linux_cuda11.0_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain-linux-x86_64" @@ -623,6 +623,13 @@ build:release_gpu_linux_cuda_10_1 --action_env CUDA_TOOLKIT_PATH="/usr/local/cud build:release_gpu_linux_cuda_10_1 --action_env=TF_CUDA_VERSION="10" build:release_gpu_linux_cuda_10_1 --action_env=TF_CUDNN_VERSION="7" + +build:release_gpu_linux_cuda_11_2 --config=release_gpu_linux +build:release_gpu_linux_cuda_11_2 --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-11.2" +build:release_gpu_linux_cuda_11_2 --crosstool_top=@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11.2:toolchain +build:release_gpu_linux_cuda_11_2 --action_env=TF_CUDA_VERSION="11.2" +build:release_gpu_linux_cuda_11_2 --action_env=TF_CUDNN_VERSION="8.1" + # Address sanitizer # CC=clang bazel build --config asan build:asan --strip=never diff --git a/RELEASE.md b/RELEASE.md index 19b64a21043..8a85d170565 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -119,6 +119,10 @@ `tf.config.experimental.get_memory_usage` in favor of this new function. * Extended `tf.config.experimental.enable_tensor_float_32_execution` to control Tensor-Float-32 evaluation in RNNs. + * Added a 'experimental_payloads' field to tf.errors.OpError and + its subclasses to support more detailed error reporting. + This is inspired from Abseil Status payloads: + https://github.com/abseil/abseil-cpp/blob/master/absl/status/status.h * `tf.summary`: * New `tf.summary.graph` allows manual write of TensorFlow graph @@ -140,7 +144,11 @@ `max_batch_size`. Previously, we issued a warning when the value of `rewriter_config_template` is not None. We issued an error when the value of `is_dynamic_op` is not True. We didn't use the value for - `max_batch_size` for building TensorRT engines. + `max_batch_size` for building TensorRT engines. Add parameters + `use_dynamic_shape` to enable dynamic shape support. The default is to + disable dynamic shape support. Add `dynamic_shape_profile_strategy` + for selecting a dynamic shape profile strategy. The default is profile + strategy is `Range`. * Issue a warning when function get_tensorrt_rewriter_config is used. * TF XLA diff --git a/tensorflow/BUILD b/tensorflow/BUILD index b5416cfee46..7fe37e881b0 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -197,7 +197,19 @@ config_setting( config_setting( name = "windows", - values = {"cpu": "x64_windows"}, + # Internal builds query the target OS. + # copybara:uncomment flag_values = {"//tools/cpp:cc_target_os": "windows"}, + # OSS builds query the CPU type. + values = {"cpu": "x64_windows"}, # copybara:comment + visibility = ["//visibility:public"], +) + +config_setting( + name = "msvc_cl_debug", + values = { + "compiler": "msvc-cl", + "compilation_mode": "dbg", + }, visibility = ["//visibility:public"], ) @@ -402,11 +414,12 @@ config_setting( # Crosses between platforms and file system libraries not supported on those # platforms due to limitations in nested select() statements. -config_setting( +selects.config_setting_group( name = "with_cuda_support_windows_override", - define_values = {"using_cuda_nvcc": "true"}, - values = {"cpu": "x64_windows"}, - visibility = ["//visibility:public"], + match_all = [ + ":using_cuda_nvcc", + ":windows", + ], ) config_setting( @@ -470,9 +483,46 @@ selects.config_setting_group( ], ) -config_setting( +# Config setting that is satisfied when TensorFlow is being built with CUDA +# support through e.g. `--config=cuda` (or `--config=cuda_clang` in OSS). +alias( + name = "is_cuda_enabled", + actual = if_oss( + "@local_config_cuda//:is_cuda_enabled", + "@local_config_cuda//cuda:using_clang", + ), +) + +# Config setting that is satisfied when CUDA device code should be compiled +# with clang. It does not imply that CUDA support has been enabled. +alias( + name = "is_cuda_compiler_clang", + actual = if_oss( + "@local_config_cuda//:is_cuda_compiler_clang", + "@local_config_cuda//cuda:TRUE", + ), +) + +# Config setting that is satisfied when CUDA device code should be compiled +# with nvcc. It does not imply that CUDA support has been enabled. +alias( + name = "is_cuda_compiler_nvcc", + actual = if_oss( + "@local_config_cuda//:is_cuda_compiler_nvcc", + "@local_config_cuda//cuda:FALSE", + ), +) + +# Config setting whether TensorFlow is built with CUDA support using clang. +alias( name = "using_cuda_clang", - define_values = {"using_cuda_clang": "true"}, + actual = "@local_config_cuda//cuda:using_clang", +) + +# Config setting whether TensorFlow is built with CUDA support using nvcc. +alias( + name = "using_cuda_nvcc", + actual = "@local_config_cuda//cuda:using_nvcc", ) # Config setting to use in select()s to distinguish open source build from @@ -493,12 +543,12 @@ bool_setting( visibility = ["//visibility:private"], ) -config_setting( +selects.config_setting_group( name = "using_cuda_clang_with_dynamic_build", - define_values = { - "using_cuda_clang": "true", - "framework_shared_object": "true", - }, + match_all = [ + ":using_cuda_clang", + ":framework_shared_object", + ], ) selects.config_setting_group( @@ -519,19 +569,6 @@ config_setting( visibility = ["//visibility:public"], ) -config_setting( - name = "using_cuda_nvcc", - define_values = {"using_cuda_nvcc": "true"}, -) - -config_setting( - name = "using_cuda_nvcc_with_dynamic_build", - define_values = { - "using_cuda_nvcc": "true", - "framework_shared_object": "true", - }, -) - selects.config_setting_group( name = "build_oss_using_cuda_nvcc", match_all = [ @@ -559,15 +596,6 @@ config_setting( visibility = ["//visibility:public"], ) -# This flag is defined for select statements that match both -# on 'windows' and 'api_version_2'. In this case, bazel requires -# having a flag which is a superset of these two. -config_setting( - name = "windows_and_api_version_2", - define_values = {"tf_api_version": "2"}, - values = {"cpu": "x64_windows"}, -) - # This flag enables experimental MLIR support. config_setting( name = "with_mlir_support", @@ -934,6 +962,11 @@ tf_cc_shared_object( "-z defs", "-Wl,--version-script,$(location //tensorflow:tf_version_script.lds)", ], + }) + select({ + "//tensorflow:msvc_cl_debug": [ + "/DEBUG:FASTLINK", + ], + "//conditions:default": [], }), per_os_targets = True, soversion = VERSION, @@ -952,7 +985,6 @@ tf_cc_shared_object( "//tensorflow/cc:cc_ops", "//tensorflow/cc:client_session", "//tensorflow/cc:scope", - "//tensorflow/cc/profiler", "//tensorflow/core:tensorflow", ], ) diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 70e4fe56cae..9063049a871 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -283,7 +283,6 @@ tf_cuda_cc_test( "//tensorflow/c/experimental/gradients:not_differentiable", "//tensorflow/c/experimental/gradients/tape:tape_context", "//tensorflow/c/experimental/ops", - "//tensorflow/cc/profiler", "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -909,7 +908,6 @@ tf_cuda_cc_test( ":c_api_experimental", ":c_api_test_util", "//tensorflow/c:c_test_util", - "//tensorflow/cc/profiler", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", @@ -935,7 +933,6 @@ tf_cuda_cc_test( "//tensorflow/c:c_api", "//tensorflow/c:c_test_util", "//tensorflow/c:tf_status_helper", - "//tensorflow/cc/profiler", "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -974,7 +971,6 @@ tf_cc_test( ":custom_device_testutil", "//tensorflow/c:c_api", "//tensorflow/c:c_test_util", - "//tensorflow/cc/profiler", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", diff --git a/tensorflow/c/eager/c_api_experimental_test.cc b/tensorflow/c/eager/c_api_experimental_test.cc index 4fe83b5116d..c1949ae826f 100644 --- a/tensorflow/c/eager/c_api_experimental_test.cc +++ b/tensorflow/c/eager/c_api_experimental_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_test_util.h" -#include "tensorflow/cc/profiler/profiler.h" #include "tensorflow/core/lib/monitoring/collection_registry.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" diff --git a/tensorflow/c/tf_status_helper.cc b/tensorflow/c/tf_status_helper.cc index e0097e88019..7abd28b25a4 100644 --- a/tensorflow/c/tf_status_helper.cc +++ b/tensorflow/c/tf_status_helper.cc @@ -79,6 +79,7 @@ void Set_TF_Status_from_Status(TF_Status* tf_status, const Status& status) { assert(0); break; } + tf_status->status.ReplaceAllPayloads(status.GetAllPayloads()); } Status StatusFromTF_Status(const TF_Status* tf_status) { diff --git a/tensorflow/c/tf_status_helper_test.cc b/tensorflow/c/tf_status_helper_test.cc index 60780d74b21..0bd9d1e4e3c 100644 --- a/tensorflow/c/tf_status_helper_test.cc +++ b/tensorflow/c/tf_status_helper_test.cc @@ -24,6 +24,8 @@ namespace { TEST(StatusHelper, TestStatusHelper) { TF_Status* s = TF_NewStatus(); Status cc_status(errors::InvalidArgument("some error")); + cc_status.SetPayload("key1", "value1"); + cc_status.SetPayload("key2", "value2"); Set_TF_Status_from_Status(s, cc_status); ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s)); ASSERT_EQ(std::string("some error"), TF_Message(s)); @@ -32,6 +34,9 @@ TEST(StatusHelper, TestStatusHelper) { ASSERT_FALSE(another_cc_status.ok()); ASSERT_EQ(std::string("some error"), another_cc_status.error_message()); ASSERT_EQ(error::INVALID_ARGUMENT, another_cc_status.code()); + // Ensure the payloads are not lost during conversions + ASSERT_EQ(cc_status.GetPayload("key1"), another_cc_status.GetPayload("key1")); + ASSERT_EQ(cc_status.GetPayload("key2"), another_cc_status.GetPayload("key2")); TF_DeleteStatus(s); } diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index 8f7e447d322..61399f33292 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -6,7 +6,6 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//tensorflow:tensorflow.bzl", "cc_library_with_android_deps", - "tf_cc_binary", "tf_cc_test", "tf_copts", "transitive_hdrs", @@ -748,36 +747,6 @@ tf_gen_op_wrappers_cc( ], ) -tf_cc_binary( - name = "tutorials_example_trainer", - srcs = ["tutorials/example_trainer.cc"], - copts = tf_copts(), - linkopts = select({ - "//tensorflow:windows": [], - "//tensorflow:macos": [ - "-lm", - "-lpthread", - ], - "//tensorflow:ios": [ - "-lm", - "-lpthread", - ], - "//conditions:default": [ - "-lm", - "-lpthread", - "-lrt", - ], - }), - deps = [ - ":cc_ops", - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:tensorflow", - ], -) - cc_library( name = "queue_runner", srcs = ["training/queue_runner.cc"], @@ -856,7 +825,6 @@ transitive_hdrs( ":queue_runner", ":remote_fused_graph_ops", ":scope", - "//tensorflow/cc/profiler", "//tensorflow/cc/saved_model:constants", "//tensorflow/cc/saved_model:loader", "//tensorflow/cc/saved_model:reader", diff --git a/tensorflow/cc/profiler/BUILD b/tensorflow/cc/profiler/BUILD deleted file mode 100644 index 43240506f8c..00000000000 --- a/tensorflow/cc/profiler/BUILD +++ /dev/null @@ -1,40 +0,0 @@ -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") -load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -tf_cuda_cc_test( - name = "profiler_test", - srcs = ["profiler_test.cc"], - tags = [ - "no_gpu", # b/77649654 - "no_rocm", # stream level tracing not supported on ROCm - ], - deps = [ - ":profiler", - "//tensorflow/cc:cc_ops", - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:tensorflow", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) - -cc_library( - name = "profiler", - srcs = ["profiler.cc"], - hdrs = ["profiler.h"], - deps = [ - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/profiler:protos_all_cc", - "//tensorflow/core/profiler:tfprof_options", - "//tensorflow/core/profiler/internal:tfprof_stats", - ], -) diff --git a/tensorflow/cc/profiler/profiler.cc b/tensorflow/cc/profiler/profiler.cc deleted file mode 100644 index 3e55bac73e6..00000000000 --- a/tensorflow/cc/profiler/profiler.cc +++ /dev/null @@ -1,57 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/cc/profiler/profiler.h" - -namespace tensorflow { -namespace tfprof { - -Profiler::Profiler(const GraphDef& graph) { - std::unique_ptr graph_ptr(new GraphDef()); - *graph_ptr = graph; - stats_.reset(new TFStats(std::move(graph_ptr), nullptr, nullptr, nullptr)); -} - -void Profiler::AddStep(int64 step, const RunMetadata& run_meta) { - std::unique_ptr run_meta_ptr(new RunMetadata()); - *run_meta_ptr = run_meta; - stats_->AddRunMeta(step, std::move(run_meta_ptr)); -} - -GraphNodeProto Profiler::ProfileGraph(const Options& options) { - stats_->BuildView(kCmds[1]); - return stats_->ShowGraphNode(kCmds[1], options); -} - -GraphNodeProto Profiler::ProfileNameScope(const Options& options) { - stats_->BuildView(kCmds[0]); - return stats_->ShowGraphNode(kCmds[0], options); -} - -MultiGraphNodeProto Profiler::ProfileOperations(const Options& options) { - stats_->BuildView(kCmds[3]); - return stats_->ShowMultiGraphNode(kCmds[3], options); -} - -Status Profiler::SerializeToString(string* content) { - if (!content) { - return Status(error::Code::INVALID_ARGUMENT, - "Cannot use null string pointer for SerializeToString."); - } - stats_->SerializeToString(content); - return Status::OK(); -} - -} // namespace tfprof -} // namespace tensorflow diff --git a/tensorflow/cc/profiler/profiler.h b/tensorflow/cc/profiler/profiler.h deleted file mode 100644 index dc60fd5fb37..00000000000 --- a/tensorflow/cc/profiler/profiler.h +++ /dev/null @@ -1,100 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CC_PROFILER_PROFILER_H_ -#define TENSORFLOW_CC_PROFILER_PROFILER_H_ - -#include -#include - -#include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/profiler/internal/tfprof_stats.h" -#include "tensorflow/core/profiler/tfprof_options.h" -#include "tensorflow/core/profiler/tfprof_output.pb.h" - -namespace tensorflow { -namespace tfprof { - -/// @addtogroup core -/// @{ - -/// A `Profiler` object lets the caller profile the execution of a graph. -/// -/// Example: -/// // First build a graph and run tracing. -/// Scope root = Scope::NewRootScope(); -/// auto a = Placeholder(root, DT_INT32); -/// auto c = Add(root, a, {41}); -/// -/// ClientSession session(root); -/// std::vector outputs; -/// RunOptions run_options; -/// run_options.set_trace_level(RunOptions::FULL_TRACE); -/// RunMetadata run_meta; -/// Status s = session.Run(run_options, { {a, {1}} }, {c}, &outputs, -/// &run_meta); -/// if (!s.ok()) { ... } -/// -/// // Then create profiler to do profiling. -/// GraphDef graph; -/// root.ToGraphDef(&graph); -/// Profiler profiler(graph); -/// profiler.AddStep(0, run_meta); -/// Options opts = ... // TODO(xpan): Support option building API. -/// MultiGraphNodeProto r = profiler.ProfileOperations(opts); -/// -class Profiler { - public: - /// `graph` is the model's GraphDef. - explicit Profiler(const GraphDef& graph); - - /// Adds tracing information `run_meta` to profiler. A `run_meta` is - /// generated by a TensorFlow session run call. `step` is the key - /// to the `run_meta`. When calling ProfileXXX methods, caller can specify - /// `step` in `options` to selectively profile the corresponding `run_meta`. - /// Multiple different `run_meta` can be keyed by the same `step` in order - /// to group them together. - void AddStep(int64 step, const RunMetadata& run_meta); - - /// Profiles the model by organizing nodes in graph structure. - /// Each node is an op and the nodes are connected by the op inputs/outputs. - GraphNodeProto ProfileGraph(const Options& options); - - /// Profiles the model by organizing nodes in name scope structure. - /// Each node is an op, and nodes are organized by the ops' name - /// scope, similar to a file system tree. - /// E.g. /foo is the root of operation /foo/matmul_1 and foo/conv_2. - GraphNodeProto ProfileNameScope(const Options& options); - - /// Profiles the model by organizing nodes by operation types. - /// Each node is an operation type (e.g. Conv2D or MatMul), containing all - /// ops belonging to that type in the model. - MultiGraphNodeProto ProfileOperations(const Options& options); - - /// Serialize the profile content (ProfileProto) into a binary string, - /// User can write the string to file for offline analysis by - /// tfprof command-line tools or graphical user interface. - Status SerializeToString(string* content); - - private: - std::unique_ptr stats_; -}; -/// @} - -} // namespace tfprof -} // namespace tensorflow - -#endif // TENSORFLOW_CC_PROFILER_PROFILER_H_ diff --git a/tensorflow/cc/profiler/profiler_test.cc b/tensorflow/cc/profiler/profiler_test.cc deleted file mode 100644 index 280cd74827f..00000000000 --- a/tensorflow/cc/profiler/profiler_test.cc +++ /dev/null @@ -1,177 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/platform/test.h" - -#include "tensorflow/cc/ops/standard_ops.h" -#include "tensorflow/cc/profiler/profiler.h" -#include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/graph/default_device.h" -#include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/public/session.h" - -namespace tensorflow { -namespace tfprof { - -class ProfilerTest : public ::testing::Test { - protected: - ProfilerTest() {} -}; - -GraphDef CreateGraphDef() { - Scope root = Scope::NewRootScope(); - - auto a = ops::Const(root, {{3, 2}, {-1, 0}}); - - auto x = ops::Const(root.WithOpName("x"), {{1.f}, {1.f}}); - - auto y = ops::MatMul(root.WithOpName("y"), a, x); - - auto y2 = ops::Square(root, y); - - auto y2_sum = ops::Sum(root, y2, 0); - - auto y_norm = ops::Sqrt(root, y2_sum); - - auto y_div = ops::Div(root.WithOpName("y_normalized"), y, y_norm); - - GraphDef def; - TF_CHECK_OK(root.ToGraphDef(&def)); - - return def; -} - -Options Default() { - Options opts(1000, /* max_depth */ - 0, /* min_bytes */ - 0, /* min_peak_bytes */ - 0, /* min_residual_bytes */ - 0, /* min_output_bytes */ - 0, /* min_micros */ - 0, /* min_accelerator_micros */ - 0, /* min_cpu_micros */ - 0, /* min_params */ - 0, /* min_float_ops */ - 0, /* min_occurrence */ - 0, /* step */ - "name", /* order_by */ - {".*"}, /* account_type_regexes */ - {".*"}, /* start_name_regexes */ - {}, /* trim_name_regexes */ - {".*"}, {}, /* hide_name_regexes */ - false, /* account_displayed_op_only */ - {"micros"}, /* select */ - {"none"}, /* output_type */ - {}); - return opts; -} - -template -const T* ExtractNode(const T& pb, const string& name) { - if (pb.name() == name) { - return &pb; - } - for (const T& c : pb.children()) { - const T* ret = ExtractNode(c, name); - if (ret) return ret; - } - return nullptr; -} - -TEST_F(ProfilerTest, Basics) { - SessionOptions options; - options.config.set_allow_soft_placement(true); - std::unique_ptr session(NewSession(options)); - GraphDef def = CreateGraphDef(); - if (options.target.empty()) { - graph::SetDefaultDevice("/gpu:0", &def); - } - - TF_CHECK_OK(session->Create(def)); - - Tensor x(DT_FLOAT, TensorShape({2, 1})); - auto x_flat = x.flat(); - x_flat.setRandom(); - Eigen::Tensor inv_norm = - x_flat.square().sum().sqrt().inverse(); - x_flat = x_flat * inv_norm(); - - std::vector outputs; - RunOptions run_options; - run_options.set_trace_level(RunOptions::FULL_TRACE); - RunMetadata run_metadata; - outputs.clear(); - - Profiler profiler(def); - for (int i = 0; i < 2; ++i) { - TF_CHECK_OK(session->Run(run_options, {{"x", x}}, {"y:0", "y_normalized:0"}, - {}, &outputs, &run_metadata)); - profiler.AddStep(i, run_metadata); - CHECK_EQ(size_t{2}, outputs.size()); - } - - std::vector resp; - TF_CHECK_OK(session->ListDevices(&resp)); - bool has_gpu = false; - for (const auto& dev : resp) { - if (dev.device_type() == "GPU") { - has_gpu = true; - } - } - - GraphNodeProto ret = profiler.ProfileNameScope(Default()); - const GraphNodeProto* matmul = ExtractNode(ret, "y"); - EXPECT_TRUE(matmul); - EXPECT_GT(matmul->exec_micros(), 0); - if (has_gpu) { - EXPECT_GT(matmul->accelerator_exec_micros(), 0); - } else { - EXPECT_EQ(matmul->accelerator_exec_micros(), 0); - } - const GraphNodeProto* square = ExtractNode(ret, "Square"); - EXPECT_TRUE(square); - EXPECT_GT(square->exec_micros(), 0); - if (has_gpu) { - EXPECT_GT(square->accelerator_exec_micros(), 0); - } else { - EXPECT_EQ(square->accelerator_exec_micros(), 0); - } - - Options opts2 = Default(); - opts2.output_type = "timeline"; - string timeline_file = io::JoinPath(testing::TmpDir(), "timeline"); - opts2.output_options["outfile"] = timeline_file; - GraphNodeProto ret2 = profiler.ProfileGraph(opts2); - string s; - TF_CHECK_OK(ReadFileToString(Env::Default(), timeline_file + "_0", &s)); - EXPECT_TRUE(s.find("Square") != s.npos); - - MultiGraphNodeProto ret3 = profiler.ProfileOperations(Default()); - const MultiGraphNodeProto* matmul2 = ExtractNode(ret3, "MatMul"); - EXPECT_TRUE(matmul2); - EXPECT_GT(matmul2->exec_micros(), 0); - if (has_gpu) { - EXPECT_GT(matmul2->accelerator_exec_micros(), 0); - } else { - EXPECT_EQ(matmul2->accelerator_exec_micros(), 0); - } - - TF_CHECK_OK(session->Close()); -} - -} // namespace tfprof -} // namespace tensorflow diff --git a/tensorflow/cc/tutorials/example_trainer.cc b/tensorflow/cc/tutorials/example_trainer.cc deleted file mode 100644 index 789662f84d0..00000000000 --- a/tensorflow/cc/tutorials/example_trainer.cc +++ /dev/null @@ -1,234 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include - -#include "tensorflow/cc/ops/standard_ops.h" -#include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/graph/default_device.h" -#include "tensorflow/core/graph/graph_def_builder.h" -#include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" -#include "tensorflow/core/platform/init_main.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/public/session.h" - -using tensorflow::string; -using tensorflow::int32; - -namespace tensorflow { -namespace example { - -struct Options { - int num_concurrent_sessions = 1; // The number of concurrent sessions - int num_concurrent_steps = 10; // The number of concurrent steps - int num_iterations = 100; // Each step repeats this many times - bool use_gpu = false; // Whether to use gpu in the training -}; - -// A = [3 2; -1 0]; x = rand(2, 1); -// We want to compute the largest eigenvalue for A. -// repeat x = y / y.norm(); y = A * x; end -GraphDef CreateGraphDef() { - // TODO(jeff,opensource): This should really be a more interesting - // computation. Maybe turn this into an mnist model instead? - Scope root = Scope::NewRootScope(); - using namespace ::tensorflow::ops; // NOLINT(build/namespaces) - - // A = [3 2; -1 0]. Using Const means the result will be a - // float tensor even though the initializer has integers. - auto a = Const(root, {{3, 2}, {-1, 0}}); - - // x = [1.0; 1.0] - auto x = Const(root.WithOpName("x"), {{1.f}, {1.f}}); - - // y = A * x - auto y = MatMul(root.WithOpName("y"), a, x); - - // y2 = y.^2 - auto y2 = Square(root, y); - - // y2_sum = sum(y2). Note that you can pass constants directly as - // inputs. Sum() will automatically create a Const node to hold the - // 0 value. - auto y2_sum = Sum(root, y2, 0); - - // y_norm = sqrt(y2_sum) - auto y_norm = Sqrt(root, y2_sum); - - // y_normalized = y ./ y_norm - Div(root.WithOpName("y_normalized"), y, y_norm); - - GraphDef def; - TF_CHECK_OK(root.ToGraphDef(&def)); - - return def; -} - -string DebugString(const Tensor& x, const Tensor& y) { - CHECK_EQ(x.NumElements(), 2); - CHECK_EQ(y.NumElements(), 2); - auto x_flat = x.flat(); - auto y_flat = y.flat(); - // Compute an estimate of the eigenvalue via - // (x' A x) / (x' x) = (x' y) / (x' x) - // and exploit the fact that x' x = 1 by assumption - Eigen::Tensor lambda = (x_flat * y_flat).sum(); - return strings::Printf("lambda = %8.6f x = [%8.6f %8.6f] y = [%8.6f %8.6f]", - lambda(), x_flat(0), x_flat(1), y_flat(0), y_flat(1)); -} - -void ConcurrentSteps(const Options* opts, int session_index) { - // Creates a session. - SessionOptions options; - std::unique_ptr session(NewSession(options)); - GraphDef def = CreateGraphDef(); - if (options.target.empty()) { - graph::SetDefaultDevice(opts->use_gpu ? "/device:GPU:0" : "/cpu:0", &def); - } - - TF_CHECK_OK(session->Create(def)); - - // Spawn M threads for M concurrent steps. - const int M = opts->num_concurrent_steps; - std::unique_ptr step_threads( - new thread::ThreadPool(Env::Default(), "trainer", M)); - - for (int step = 0; step < M; ++step) { - step_threads->Schedule([&session, opts, session_index, step]() { - // Randomly initialize the input. - Tensor x(DT_FLOAT, TensorShape({2, 1})); - auto x_flat = x.flat(); - x_flat.setRandom(); - Eigen::Tensor inv_norm = - x_flat.square().sum().sqrt().inverse(); - x_flat = x_flat * inv_norm(); - - // Iterations. - std::vector outputs; - for (int iter = 0; iter < opts->num_iterations; ++iter) { - outputs.clear(); - TF_CHECK_OK( - session->Run({{"x", x}}, {"y:0", "y_normalized:0"}, {}, &outputs)); - CHECK_EQ(size_t{2}, outputs.size()); - - const Tensor& y = outputs[0]; - const Tensor& y_norm = outputs[1]; - // Print out lambda, x, and y. - std::printf("%06d/%06d %s\n", session_index, step, - DebugString(x, y).c_str()); - // Copies y_normalized to x. - x = y_norm; - } - }); - } - - // Delete the threadpool, thus waiting for all threads to complete. - step_threads.reset(nullptr); - TF_CHECK_OK(session->Close()); -} - -void ConcurrentSessions(const Options& opts) { - // Spawn N threads for N concurrent sessions. - const int N = opts.num_concurrent_sessions; - - // At the moment our Session implementation only allows - // one concurrently computing Session on GPU. - CHECK_EQ(1, N) << "Currently can only have one concurrent session."; - - thread::ThreadPool session_threads(Env::Default(), "trainer", N); - for (int i = 0; i < N; ++i) { - session_threads.Schedule(std::bind(&ConcurrentSteps, &opts, i)); - } -} - -} // end namespace example -} // end namespace tensorflow - -namespace { - -bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, - int32* dst) { - if (absl::ConsumePrefix(&arg, flag) && absl::ConsumePrefix(&arg, "=")) { - char extra; - return (sscanf(arg.data(), "%d%c", dst, &extra) == 1); - } - - return false; -} - -bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, - bool* dst) { - if (absl::ConsumePrefix(&arg, flag)) { - if (arg.empty()) { - *dst = true; - return true; - } - - if (arg == "=true") { - *dst = true; - return true; - } else if (arg == "=false") { - *dst = false; - return true; - } - } - - return false; -} - -} // namespace - -int main(int argc, char* argv[]) { - tensorflow::example::Options opts; - std::vector unknown_flags; - for (int i = 1; i < argc; ++i) { - if (string(argv[i]) == "--") { - while (i < argc) { - unknown_flags.push_back(argv[i]); - ++i; - } - break; - } - - if (ParseInt32Flag(argv[i], "--num_concurrent_sessions", - &opts.num_concurrent_sessions) || - ParseInt32Flag(argv[i], "--num_concurrent_steps", - &opts.num_concurrent_steps) || - ParseInt32Flag(argv[i], "--num_iterations", &opts.num_iterations) || - ParseBoolFlag(argv[i], "--use_gpu", &opts.use_gpu)) { - continue; - } - - fprintf(stderr, "Unknown flag: %s\n", argv[i]); - return -1; - } - - // Passthrough any unknown flags. - int dst = 1; // Skip argv[0] - for (char* f : unknown_flags) { - argv[dst++] = f; - } - argv[dst++] = nullptr; - argc = static_cast(unknown_flags.size() + 1); - tensorflow::port::InitMain(argv[0], &argc, &argv); - tensorflow::example::ConcurrentSessions(opts); -} diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index f4004488286..cfd0cdef07c 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -297,6 +297,7 @@ cc_library( # Public visibility is needed for external TF/XLA backends. visibility = ["//visibility:public"], deps = XLA_DEVICE_DEPS + [":xla_compilation_cache"], + alwayslink = 1, ) cc_library( @@ -342,6 +343,7 @@ cc_library( "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/base", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) @@ -355,7 +357,9 @@ cc_library( "//tensorflow/compiler/xla:parse_flags_from_env", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "//tensorflow/core/protobuf:for_core_protos_cc", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc index 15bd5340503..9112b8d021b 100644 --- a/tensorflow/compiler/jit/flags.cc +++ b/tensorflow/compiler/jit/flags.cc @@ -292,6 +292,37 @@ MlirCommonFlags* GetMlirCommonFlags() { return mlir_flags; } +ConfigProto::Experimental::MlirBridgeRollout GetMlirBridgeRolloutState( + absl::optional config_proto) { + // TF1 graphs that do not override Sessions's ConfigProto and TF2 graphs + // can enable/disable the graph via tf_mlir_enable_mlir_bridge. + auto tf_mlir_enable_mlir_bridge = + GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge; + if (tf_mlir_enable_mlir_bridge != + ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED) { + return tf_mlir_enable_mlir_bridge; + } + + // If a ConfigProto was not passed in, we can assume the caller is + // checking if TF2 graph should have the bridge enabled / disabled. In that + // case, we have already checked tf_mlir_enable_mlir_bridge so it is safe to + // return UNSPECIFIED here. + if (!config_proto.has_value()) { + return ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED; + } + + // TF1 graphs that do override Session's ConfigProto and set + // ConfigProto's enable_mlir_bridge or mlir_bridge_rollout fields will not + // update tf_mlir_enable_mlir_bridge so check their values. + + // ConfigProto's enable_mlir_bridge defaults to false so only respect it + // when it is true. + if (config_proto.value().experimental().enable_mlir_bridge()) { + return ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED; + } + return config_proto.value().experimental().mlir_bridge_rollout(); +} + void AppendMarkForCompilationPassFlags(std::vector* flag_list) { absl::call_once(flags_init, &AllocateAndParseFlags); AppendMarkForCompilationPassFlagsInternal(flag_list); diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h index b54dcf942c7..ef4d89b2b56 100644 --- a/tensorflow/compiler/jit/flags.h +++ b/tensorflow/compiler/jit/flags.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/types/optional.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/util/command_line_flags.h" @@ -156,6 +157,11 @@ GetIntroduceFloatingPointJitterPassFlags(); MlirCommonFlags* GetMlirCommonFlags(); +// Returns the effective MLIR bridge rollout state based on the flags and the +// optional configuration. +ConfigProto::Experimental::MlirBridgeRollout GetMlirBridgeRolloutState( + absl::optional config_proto); + // Appends the flag definitions associated with // MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`. // diff --git a/tensorflow/compiler/mlir/g3doc/includes/tf_passes.md b/tensorflow/compiler/mlir/g3doc/includes/tf_passes.md index 4fcd0d3a244..6f5441fc0a7 100644 --- a/tensorflow/compiler/mlir/g3doc/includes/tf_passes.md +++ b/tensorflow/compiler/mlir/g3doc/includes/tf_passes.md @@ -621,3 +621,6 @@ func @tf_tpu_rewrite(%arg0: tensor<8xi32>) -> tensor<8xi32> { return %1 : tensor<8xi32> } ``` +### `-tf-verify-for-export`: Verify module is suitable for export back to TF Graph +Verifies whether all functions in module are of single tf_executor.graph and +each tf_executor.island in tf_executor.graph only has a single op. diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td index c9db345a425..9a42d9574c8 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td @@ -708,4 +708,49 @@ def HLOClient_BroadcastSelectOp : HLOClient_Op< }]; } +//===----------------------------------------------------------------------===// +// Helper ops +//===----------------------------------------------------------------------===// + +def HLOClient_MinimumBroadcastShapesOp : + HLOClient_Op<"minimum_broadcast_shapes", [NoSideEffect]> { + string summary = "Minimizes the rank of two or more shapes to be broadcasted"; + + string description = [{ + Given two or more 1D tensors representing shapes, returns one 1D tensor for + each operand, where operand `i` corresponds to output `i`. + + The returned tensors have the property that they specify a shape which is a + reshape of the corresponding input shape, and the broadcasted output shape + (using shape::BroadcastOp) of the returned shapes is a reshape of the + broadcasted output shape of the input shapes. Among all possibilities with + this property, the one is chosen which minimizes the rank of each returned + shape. + + The general idea of this op is that it can be used for ops which have a + broadcasting semantic to operate on shapes with a possibly smaller rank + while preserving equivalence of the computed values. After computing the + result of the op using reshaped operands, the result can be reshaped to the + result that would have been originally computed. + + Here is an example with two input shapes: + + ```mlir + chlo.minimum_broadcast_shapes [1, 2, 3, 1, 2, 1], + [1, 1, 1, 2, 3] -> [6, 2, 1], [2, 3] + ``` + + The broadcasted output shape of the operands is [1, 2, 3, 1, 2, 3], the + broadcasted output shape of the outputs is [6, 2, 3]. These two shapes are + reshapes of each other, and also each output is a reshape of the + corresponding input. + }]; + + let arguments = (ins Variadic<1DTensorOf<[Index]>>:$shapes); + let results = (outs Variadic<1DTensorOf<[Index]>>:$results); + + let assemblyFormat = "$shapes attr-dict `:` type($shapes) `->` type($results)"; + +} + #endif // CHLO_OPS diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h index f1763c35a0c..21e9c9f07dd 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h @@ -28,6 +28,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Types.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index cca165eeb51..519873d977c 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -888,6 +888,11 @@ def HLO_ConcatenateOp : HLO_Op<"concatenate", let hasCanonicalizer = 1; let hasFolder = 1; + let extraClassDeclaration = [{ + static bool isCompatibleReturnTypes(ArrayRef l, ArrayRef r) { + return succeeded(mlir::verifyCompatibleShapes(l, r)); + } + }]; } def HLO_CollectivePermuteOp: HLO_Op<"collective_permute", diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc index 57ae27174d2..3a1caa9e794 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc @@ -337,6 +337,26 @@ static LogicalResult Verify(ConstantLikeOp op) { return success(); } +//===----------------------------------------------------------------------===// +// MinimumBroadcastShapesOp +//===----------------------------------------------------------------------===// +static LogicalResult Verify(MinimumBroadcastShapesOp op) { + // Check that the number of operands matches the number of outputs. + unsigned result_shapes_count = op.results().size(); + unsigned operand_shapes_count = op.shapes().size(); + if (operand_shapes_count != result_shapes_count) { + return op.emitOpError() + << "number of operand shapes (" << operand_shapes_count + << ") does not match number of result shapes (" + << result_shapes_count << ")"; + } + if (operand_shapes_count < 2) { + return op.emitOpError() << "number of operand shapes (" + << operand_shapes_count << ") should be >= 2"; + } + return success(); +} + LogicalResult ConstantLikeOp::inferReturnTypeComponents( MLIRContext* context, Optional location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc index f9ee38b1f16..5daf6078239 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc @@ -51,6 +51,7 @@ struct ChloLegalizeToHloPass conversionTarget.addLegalDialect< MhloDialect, mlir::StandardOpsDialect, mlir::tensor::TensorDialect, mlir::shape::ShapeDialect, mlir::scf::SCFDialect>(); + conversionTarget.addLegalOp(); if (broadcast_only_) { chlo::PopulateChloBroadcastingPatterns(&getContext(), diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index b562e3b5a2a..abaa53f99d9 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -424,7 +424,7 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion { buffer_args.push_back(InsertAlloc(loc, result, &rewriter)); } auto new_op = rewriter.create(loc, llvm::None, buffer_args, - op.getAttrs()); + op->getAttrs()); // Copy over the operations inside the region. rewriter.inlineRegionBefore(op.body(), new_op.body(), new_op.body().end()); @@ -607,8 +607,9 @@ struct HloLegalizeToLhlo populateHLOToLHLOConversionPattern(&context, &converter, &patterns); populateFuncOpTypeConversionPattern(patterns, &context, converter); populateCallOpTypeConversionPattern(patterns, &context, converter); - populateBranchOpInterfaceAndReturnOpTypeConversionPattern( - patterns, &context, converter); + populateBranchOpInterfaceTypeConversionPattern(patterns, &context, + converter); + populateReturnOpTypeConversionPattern(patterns, &context, converter); populateEliminateBufferizeMaterializationsPatterns(&context, converter, patterns); diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index 2a523e62778..a2b54f227ba 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -517,14 +517,16 @@ class HloDynamicBroadcastInDimConverter auto shape_type = shape.getType().cast(); int64_t result_rank = shape_type.getDimSize(0); + auto result_type = op.getType().dyn_cast(); + if (!result_type) return failure(); + SmallVector dyn_dims; Location loc = op.getLoc(); for (int i = 0; i < result_rank; ++i) { + if (!result_type.isDynamicDim(i)) continue; Value index = rewriter.create(loc, i); dyn_dims.push_back(rewriter.create(loc, shape, index)); } - auto result_type = op.getType().dyn_cast(); - if (!result_type) return failure(); int64_t nloops = result_type.getRank(); Value init = rewriter.create( @@ -1146,8 +1148,7 @@ SmallVector GetDotOpInitTensorDynSizes(OpBuilder& b, Location loc, return dyn_shape; } -template +template class DotOpOnTensorsConversion : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -1157,28 +1158,13 @@ class DotOpOnTensorsConversion : public OpConversionPattern { if (!VerifyHloOpBufferOrTensorSemantics(op)) { return failure(); } + if (GetDotOperationType(op) != op_type) return failure(); mhlo::DotOp::Adaptor adaptor(args); - auto lhs_el_type = - adaptor.lhs().getType().cast().getElementType(); - auto rhs_el_type = - adaptor.lhs().getType().cast().getElementType(); - if (lhs_el_type != rhs_el_type || !lhs_el_type.isa() || - lhs_el_type.getIntOrFloatBitWidth() != input_bit_width) { - return failure(); - } - + Location loc = op.getLoc(); auto output_type = op.getType().cast(); auto output_el_type = output_type.getElementType(); - if (!output_el_type.isa() || - output_el_type.getIntOrFloatBitWidth() != output_bit_width) { - return failure(); - } - - if (GetDotOperationType(op) != op_type) return failure(); - - Location loc = op.getLoc(); auto zero_attr = rewriter.getZeroAttr(output_el_type); Value zero = rewriter.create(loc, zero_attr); SmallVector dyn_shape = GetDotOpInitTensorDynSizes( @@ -1205,8 +1191,6 @@ SmallVector GetDotGeneralOpInitTensorDynSizes( return dyn_shape; } -template class DotGeneralOpOnTensorsConversion : public OpConversionPattern { public: @@ -1245,23 +1229,10 @@ class DotGeneralOpOnTensorsConversion } mhlo::DotGeneralOp::Adaptor adaptor(args); - auto lhs_el_type = - adaptor.lhs().getType().cast().getElementType(); - auto rhs_el_type = - adaptor.lhs().getType().cast().getElementType(); - if (lhs_el_type != rhs_el_type || !lhs_el_type.isa() || - lhs_el_type.getIntOrFloatBitWidth() != input_bit_width) { - return failure(); - } - - auto output_type = op.getType().cast(); - auto output_el_type = output_type.getElementType(); - if (!output_el_type.isa() || - output_el_type.getIntOrFloatBitWidth() != output_bit_width) { - return failure(); - } Location loc = op.getLoc(); + auto output_type = op.getType().cast(); + auto output_el_type = output_type.getElementType(); SmallVector dyn_shape = GetDotGeneralOpInitTensorDynSizes( rewriter, loc, adaptor.lhs(), adaptor.rhs(), output_type); auto zero_attr = rewriter.getZeroAttr(output_el_type); @@ -1269,7 +1240,7 @@ class DotGeneralOpOnTensorsConversion auto init_tensor = GetInitTensor(rewriter, loc, output_type, dyn_shape); Value zero_tensor = rewriter.create(loc, init_tensor, zero).getResult(0); - Operation* linalg_op = rewriter.create( + Operation* linalg_op = rewriter.create( loc, /*resultTensorTypes=*/TypeRange{op.getType()}, /*inputs=*/ValueRange{adaptor.lhs(), adaptor.rhs()}, /*outputBuffers=*/ValueRange{zero_tensor}); @@ -1476,9 +1447,14 @@ struct NormalConvOpOnTensorsConversion // The output shape is N spatial_dims F. SmallVector dyn_sizes; - for (int64_t i = 0, e = rank - 1; i < e; ++i) { - if (!result_type.isDynamicDim(i)) continue; - dyn_sizes.push_back(rewriter.create(loc, input, i)); + if (result_type.isDynamicDim(0)) { + dyn_sizes.push_back(rewriter.create(loc, input, 0)); + } + for (int64_t i = 1, e = rank - 1; i < e; ++i) { + if (result_type.isDynamicDim(i)) { + return rewriter.notifyMatchFailure( + op, "expected output spatial dims to be static shapes"); + } } if (result_type.isDynamicDim(rank - 1)) { dyn_sizes.push_back(rewriter.create(loc, filter, rank - 1)); @@ -1702,49 +1678,12 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, ReverseConverter, SliceConverter, TransposeConverter, - DotOpOnTensorsConversion, - DotOpOnTensorsConversion, - DotOpOnTensorsConversion, - DotOpOnTensorsConversion, - DotOpOnTensorsConversion, - DotOpOnTensorsConversion, - DotOpOnTensorsConversion, - DotOpOnTensorsConversion, - DotOpOnTensorsConversion, - DotOpOnTensorsConversion, - DotOpOnTensorsConversion, - DotOpOnTensorsConversion, - DotGeneralOpOnTensorsConversion, - DotGeneralOpOnTensorsConversion, - DotGeneralOpOnTensorsConversion, - DotGeneralOpOnTensorsConversion, + DotOpOnTensorsConversion, + DotGeneralOpOnTensorsConversion, NormalConvOpOnTensorsConversion, ReduceOnTensorsConversion, PadOpOnTensorsConversion>(context); diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc index 4be9b650cbb..9325d7d27c5 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc @@ -126,7 +126,7 @@ struct ElementwiseOpConversion : public OpRewritePattern { Type flatResultTy = RankedTensorType::get({ShapedType::kDynamicSize}, resultElementTy); Value flatResult = - rewriter.create(loc, flatResultTy, flatOperands, op.getAttrs()); + rewriter.create(loc, flatResultTy, flatOperands, op->getAttrs()); // Restore original shape. rewriter.replaceOpWithNewOp(op, op.getType(), @@ -192,7 +192,7 @@ struct ConvertUnrankedScalarDynamicBroadcastBinaryOp rhs_is_scalar ? rhs : reshaped}; Value computed = rewriter.create( loc, TypeRange{RankedTensorType::get({-1}, result_element_type)}, - new_operands, op.getAttrs()); + new_operands, op->getAttrs()); // Reshape the result back into an unranked tensor. rewriter.replaceOpWithNewOp(op, result_type, @@ -223,9 +223,8 @@ struct ConvertUnrankedDynamicBroadcastOpHelper { } static Value createBroadcastToKnownRank(OpBuilder &builder, ChloOpTy op, - Value value, int targeted_rank) { + Value shape, int targeted_rank) { auto loc = op.getLoc(); - Value shape = builder.create(loc, value); SmallVector ranked_shape(targeted_rank, 1); auto unknown_rank_extent_tensor_type = RankedTensorType::get( {RankedTensorType::kDynamicSize}, builder.getIndexType()); @@ -246,6 +245,7 @@ struct ConvertUnrankedDynamicBroadcastOpHelper { static void createRankSpecializedBroadcastAndOp(OpBuilder &if_builder, ChloOpTy op, ValueRange operands, + ValueRange operand_shapes, int targeted_rank) { auto loc = op.getLoc(); SmallVector reshaped_operands; @@ -253,10 +253,12 @@ struct ConvertUnrankedDynamicBroadcastOpHelper { auto dynamic_dimensions = llvm::SmallVector( targeted_rank, RankedTensorType::kDynamicSize); - for (Value operand : operands) { + for (auto it : llvm::zip(operands, operand_shapes)) { + Value operand, shape; + std::tie(operand, shape) = it; // Handle shape broadcasting and inference. Value extended_operand_casted = - createBroadcastToKnownRank(if_builder, op, operand, targeted_rank); + createBroadcastToKnownRank(if_builder, op, shape, targeted_rank); // 1. Reshape operands to the given rank (with the same number of // elements) @@ -278,7 +280,7 @@ struct ConvertUnrankedDynamicBroadcastOpHelper { auto result_type = RankedTensorType::get(dynamic_dimensions, result_element_type); Value result = if_builder.create( - loc, ArrayRef{result_type}, reshaped_operands, op.getAttrs()); + loc, ArrayRef{result_type}, reshaped_operands, op->getAttrs()); Value reshaped_result = if_builder.create( loc, UnrankedTensorType::get(result_element_type), result); if_builder.create(loc, reshaped_result); @@ -290,13 +292,37 @@ struct ConvertUnrankedDynamicBroadcastOpHelper { ValueRange operands) { auto loc = op.getLoc(); - // Find the larger rank of the operands. + // Get the minimum broadcast shapes of the operands. + SmallVector shapes; + shapes.reserve(operands.size()); auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize}, rewriter.getIndexType()); - Value greater_rank; for (Value operand : operands) { Value shape = rewriter.create(loc, extent_tensor_type, operand); + shapes.push_back(shape); + } + auto broadcast_shape = rewriter.create( + loc, extent_tensor_type, shapes, nullptr); + SmallVector result_types(shapes.size(), extent_tensor_type); + auto reduced_shapes = + rewriter + .create(loc, result_types, shapes) + .results(); + SmallVector reshaped_operands; + reshaped_operands.reserve(operands.size()); + for (auto it : llvm::zip(operands, reduced_shapes)) { + Value operand; + Value reduced_shape; + std::tie(operand, reduced_shape) = it; + auto reshaped_operand = rewriter.create( + loc, operand.getType(), operand, reduced_shape); + reshaped_operands.push_back(reshaped_operand); + } + + // Find the largest rank of the operands. + Value greater_rank; + for (Value shape : reduced_shapes) { Value rank = rewriter.create(loc, rewriter.getIndexType(), shape); if (!greater_rank) { @@ -314,17 +340,19 @@ struct ConvertUnrankedDynamicBroadcastOpHelper { scf::IfOp if_op = createIfOpForRankSpecializedBroadcastAndOp( rewriter, op, greater_rank, 1); OpBuilder if_builder = if_op.getThenBodyBuilder(rewriter.getListener()); - createRankSpecializedBroadcastAndOp(if_builder, op, operands, 1); + createRankSpecializedBroadcastAndOp(if_builder, op, reshaped_operands, + reduced_shapes, 1); // Put each subsequent rank specialization inside the else statement of the // previous one. OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener()); - constexpr int kMaxRankSpecialization = 6; + constexpr int kMaxRankSpecialization = 5; for (int i = 2; i < kMaxRankSpecialization; i++) { auto inner_if = createIfOpForRankSpecializedBroadcastAndOp( else_builder, op, greater_rank, i); if_builder = inner_if.getThenBodyBuilder(rewriter.getListener()); - createRankSpecializedBroadcastAndOp(if_builder, op, operands, i); + createRankSpecializedBroadcastAndOp(if_builder, op, reshaped_operands, + reduced_shapes, i); else_builder.create(loc, inner_if.getResult(0)); else_builder = inner_if.getElseBodyBuilder(rewriter.getListener()); } @@ -336,12 +364,15 @@ struct ConvertUnrankedDynamicBroadcastOpHelper { kMaxRankSpecialization), "Input for dynamic binary op lowering was of a rank greater than " + std::to_string(kMaxRankSpecialization)); - // Add the rank 6 specialization to the innermost else block. - createRankSpecializedBroadcastAndOp(else_builder, op, operands, - kMaxRankSpecialization); + // Add the rank 5 specialization to the innermost else block. + createRankSpecializedBroadcastAndOp(else_builder, op, reshaped_operands, + reduced_shapes, kMaxRankSpecialization); - // Return the result of the outermost if statement. - return if_op.getResult(0); + // Return the reshaped result of the outermost if statement. + auto result = if_op.getResult(0); + auto reshaped_result = rewriter.create( + loc, result.getType(), result, broadcast_shape); + return reshaped_result; } }; @@ -386,7 +417,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs); Value if_lhs_scalar_result = if_lhs_scalar_builder.create( loc, ArrayRef{result_type}, ArrayRef{reshaped_lhs, rhs}, - op.getAttrs()); + op->getAttrs()); Value extended_if_lhs_scalar_result = extendToBroadcastShape(if_lhs_scalar_builder, loc, if_lhs_scalar_result, shape_of_lhs, shape_of_rhs); @@ -409,7 +440,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp loc, RankedTensorType::get({}, rhs_type.getElementType()), rhs); Value if_rhs_scalar_result = if_rhs_scalar_builder.create( loc, ArrayRef{result_type}, ArrayRef{lhs, reshaped_rhs}, - op.getAttrs()); + op->getAttrs()); Value extended_if_rhs_scalar_result = extendToBroadcastShape(if_rhs_scalar_builder, loc, if_rhs_scalar_result, shape_of_lhs, shape_of_rhs); @@ -497,16 +528,17 @@ struct ConvertUnrankedDynamicBroadcastSelectOp struct TransformUnrankedHloPass : public PassWrapper { void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry.insert(); } void runOnFunction() override { // Setup conversion target. MLIRContext &ctx = getContext(); ConversionTarget target(ctx); - target.addLegalDialect(); + target.addLegalDialect(); target.addLegalOp(); #define ADD_LEGAL_MHLO(op) AddLegalOpOnRankedTensor(&target) #define ADD_LEGAL_CHLO(op) AddLegalOpOnRankedTensor(&target) diff --git a/tensorflow/compiler/mlir/hlo/tests/chlo_ops.mlir b/tensorflow/compiler/mlir/hlo/tests/chlo_ops.mlir new file mode 100644 index 00000000000..a4d5f79218b --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/tests/chlo_ops.mlir @@ -0,0 +1,26 @@ +// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file | mlir-hlo-opt | FileCheck %s + +// CHECK-LABEL: func @minimum_broadcast_shapes +func @minimum_broadcast_shapes(%lhs: tensor, %rhs: tensor) + -> (tensor, tensor) { + %0, %1 = chlo.minimum_broadcast_shapes %lhs, %rhs : + tensor, tensor -> tensor, tensor + return %0, %1 : tensor, tensor +} + +// ----- + +func @minimum_broadcast_shapes_mismatch_operand_and_result_count(%lhs: tensor, %rhs: tensor) { + // expected-error @+1{{number of operand shapes (2) does not match number of result shapes (1)}} + %0 = chlo.minimum_broadcast_shapes %lhs, %rhs : + tensor, tensor -> tensor + return +} + +// ----- + +func @minimum_broadcast_shapes_one_operand(%arg: tensor) { + // expected-error @+1{{number of operand shapes (1) should be >= 2}} + %0 = chlo.minimum_broadcast_shapes %arg : tensor -> tensor + return +} diff --git a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir index fdaf9f824ec..af423110db3 100644 --- a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir @@ -954,6 +954,28 @@ func @dynamic_broadcast_in_dim(%shape: tensor<1xindex>) -> tensor { // ----- +// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> ()> +// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> + +// CHECK-LABEL: func @dynamic_broadcast_in_dim( +// CHECK-SAME: [[SHAPE:%.*]]: tensor<2xindex> +func @dynamic_broadcast_in_dim(%shape: tensor<2xindex>) -> tensor { + %cst = mhlo.constant dense<0x7F800000> : tensor + %result = "mhlo.dynamic_broadcast_in_dim"(%cst, %shape) { + broadcast_dimensions = dense<> : tensor<0xi64> + } : (tensor, tensor<2xindex>) -> tensor + return %result : tensor +} +// CHECK: [[CST:%.*]] = constant +// CHECK: [[INIT:%.*]] = linalg.init_tensor +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] +// CHECK-SAME: ins([[CST]] : tensor) outs([[INIT]] : tensor) +// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 + +// ----- + func @dot_matmul(%arg0: tensor<2x3xf32>, %arg1: tensor<3x?xf32>) -> tensor<2x?xf32> { %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<2x3xf32>, @@ -982,7 +1004,7 @@ func @dot_matmul_i8_i8_i32(%arg0: tensor<2x3xi8>, // CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]] // CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]] // CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]] -// CHECK: linalg.matmul_i8_i8_i32 +// CHECK: linalg.matmul // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi8>, tensor<3x?xi8>) // CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>) @@ -1000,7 +1022,7 @@ func @dot_matmul_i16_i16_i32(%arg0: tensor<2x3xi16>, // CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]] // CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]] // CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]] -// CHECK: linalg.matmul_i16_i16_i32 +// CHECK: linalg.matmul // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi16>, tensor<3x?xi16>) // CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>) @@ -1018,7 +1040,7 @@ func @dot_matmul_i32_i32_i32(%arg0: tensor<2x3xi32>, // CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]] // CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]] // CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]] -// CHECK: linalg.matmul_i32_i32_i32 +// CHECK: linalg.matmul // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi32>, tensor<3x?xi32>) // CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>) @@ -1109,7 +1131,7 @@ func @dot_general_batch_matmul_i8_i8_i32(%arg0: tensor, // CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]] // CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]] // CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]] -// CHECK: linalg.batch_matmul_i8_i8_i32 +// CHECK: linalg.batch_matmul // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor, tensor) // CHECK-SAME: outs(%[[FILL]] : tensor) @@ -1138,7 +1160,7 @@ func @dot_general_batch_matmul_i16_i16_i32(%arg0: tensor, // CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]] // CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]] // CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]] -// CHECK: linalg.batch_matmul_i16_i16_i32 +// CHECK: linalg.batch_matmul // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor, tensor) // CHECK-SAME: outs(%[[FILL]] : tensor) @@ -1444,8 +1466,8 @@ func @pad_tensor(%arg0: tensor<12x4xf32>, %arg1: tensor) -> tensor<18x12xf3 // ----- -func @linalg.conv_1d_input_nwc_filter_wcf(%arg0: tensor, %arg1: tensor) - -> tensor { +func @linalg.conv_1d_input_nwc_filter_wcf(%arg0: tensor, %arg1: tensor<2x?x?xf32>) + -> tensor { %0 = "mhlo.convolution"(%arg0, %arg1) { batch_group_count = 1 : i64, dimension_numbers = { @@ -1463,31 +1485,29 @@ func @linalg.conv_1d_input_nwc_filter_wcf(%arg0: tensor, %arg1: tenso padding = dense<[[0], [0]]> : tensor<2x1xi64>, rhs_dilation = dense<1> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64> - } : (tensor, tensor) -> tensor - return %0 : tensor + } : (tensor, tensor<2x?x?xf32>) -> tensor + return %0 : tensor } // CHECK-LABEL: func @linalg.conv_1d_input_nwc_filter_wcf // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] // CHECK: %[[C0:.+]] = constant 0 : index -// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor -// CHECK: %[[C1:.+]] = constant 1 : index -// CHECK: %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor +// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor // CHECK: %[[C2:.+]] = constant 2 : index -// CHECK: %[[DIM2:.+]] = dim %[[ARG1]], %[[C2]] : tensor -// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]]] +// CHECK: %[[DIM2:.+]] = dim %[[ARG1]], %[[C2]] : tensor<2x?x?xf32> +// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 7, %[[DIM2]]] // CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32 // CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]]) // CHECK: linalg.conv_1d_input_nwc_filter_wcf // CHECK-SAME: {dilations = dense<1> : tensor<1xi64> // CHECK-SAME: strides = dense<1> : tensor<1xi64>} -// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor, tensor) -// CHECK-SAME: outs(%[[FILL]] : tensor) -> tensor +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor, tensor<2x?x?xf32>) +// CHECK-SAME: outs(%[[FILL]] : tensor) -> tensor // ----- -func @conv_2d_input_nhwc_filter_hwcf(%arg0: tensor, %arg1: tensor) - -> tensor { +func @conv_2d_input_nhwc_filter_hwcf(%arg0: tensor, %arg1: tensor<3x2x?x?xf32>) + -> tensor { %0 = "mhlo.convolution"(%arg0, %arg1) { batch_group_count = 1 : i64, dimension_numbers = { @@ -1505,33 +1525,29 @@ func @conv_2d_input_nhwc_filter_hwcf(%arg0: tensor, %arg1: tensor : tensor<2x2xi64>, rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64> - } : (tensor, tensor) -> tensor - return %0 : tensor + } : (tensor, tensor<3x2x?x?xf32>) -> tensor + return %0 : tensor } // CHECK-LABEL: func @conv_2d_input_nhwc_filter_hwcf // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] // CHECK: %[[C0:.+]] = constant 0 : index -// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor -// CHECK: %[[C1:.+]] = constant 1 : index -// CHECK: %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor -// CHECK: %[[C2:.+]] = constant 2 : index -// CHECK: %[[DIM2:.+]] = dim %[[ARG0]], %[[C2]] : tensor +// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor // CHECK: %[[C3:.+]] = constant 3 : index -// CHECK: %[[DIM3:.+]] = dim %[[ARG1]], %[[C3]] : tensor -// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM3]]] +// CHECK: %[[DIM3:.+]] = dim %[[ARG1]], %[[C3]] : tensor<3x2x?x?xf32> +// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 2, 3, %[[DIM3]]] // CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32 // CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]]) // CHECK: linalg.conv_2d_input_nhwc_filter_hwcf // CHECK-SAME: {dilations = dense<1> : tensor<2xi64> // CHECK-SAME: strides = dense<1> : tensor<2xi64>} -// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor, tensor) -// CHECK-SAME: outs(%[[FILL]] : tensor) -> tensor +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor, tensor<3x2x?x?xf32>) +// CHECK-SAME: outs(%[[FILL]] : tensor) -> tensor // ----- -func @conv_3d_input_ndhwc_filter_dhwcf(%arg0: tensor, %arg1: tensor) - -> tensor { +func @conv_3d_input_ndhwc_filter_dhwcf(%arg0: tensor, %arg1: tensor<2x2x2x?x?xf32>) + -> tensor { %0 = "mhlo.convolution"(%arg0, %arg1) { batch_group_count = 1 : i64, dimension_numbers = { @@ -1549,30 +1565,24 @@ func @conv_3d_input_ndhwc_filter_dhwcf(%arg0: tensor, %arg1: tens padding = dense<[[0, 0, 0], [0, 0, 0]]> : tensor<2x3xi64>, rhs_dilation = dense<1> : tensor<3xi64>, window_strides = dense<1> : tensor<3xi64> - } : (tensor, tensor) -> tensor - return %0 : tensor + } : (tensor, tensor<2x2x2x?x?xf32>) -> tensor + return %0 : tensor } // CHECK-LABEL: func @conv_3d_input_ndhwc_filter_dhwcf // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] // CHECK: %[[C0:.+]] = constant 0 : index -// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor -// CHECK: %[[C1:.+]] = constant 1 : index -// CHECK: %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor -// CHECK: %[[C2:.+]] = constant 2 : index -// CHECK: %[[DIM2:.+]] = dim %[[ARG0]], %[[C2]] : tensor -// CHECK: %[[C3:.+]] = constant 3 : index -// CHECK: %[[DIM3:.+]] = dim %[[ARG0]], %[[C3]] : tensor +// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor // CHECK: %[[C4:.+]] = constant 4 : index -// CHECK: %[[DIM4:.+]] = dim %[[ARG1]], %[[C4]] : tensor -// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM3]], %[[DIM4]]] +// CHECK: %[[DIM4:.+]] = dim %[[ARG1]], %[[C4]] : tensor<2x2x2x?x?xf32> +// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 7, 7, 7, %[[DIM4]]] // CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32 // CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]]) // CHECK: linalg.conv_3d_input_ndhwc_filter_dhwcf // CHECK-SAME: {dilations = dense<1> : tensor<3xi64> // CHECK-SAME: strides = dense<1> : tensor<3xi64>} -// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor, tensor) -// CHECK-SAME: outs(%[[FILL]] : tensor) -> tensor +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor, tensor<2x2x2x?x?xf32>) +// CHECK-SAME: outs(%[[FILL]] : tensor) -> tensor // ----- diff --git a/tensorflow/compiler/mlir/hlo/tests/hlo-transform-unranked.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-transform-unranked.mlir index 6ae07cea130..b3abe717c89 100644 --- a/tensorflow/compiler/mlir/hlo/tests/hlo-transform-unranked.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/hlo-transform-unranked.mlir @@ -199,20 +199,24 @@ func @addUnrankedUnranked( // CHECK-NEXT: %[[RESHAPED_SAME_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLATTENED_RESULT]], %[[ANY_SHAPE]]) : (tensor, tensor) -> tensor<*xf32> // CHECK-NEXT: scf.yield %[[RESHAPED_SAME_RESULT]] : tensor<*xf32> // CHECK-NEXT: } else { -// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[LHS_SHAPE]] : tensor -> index -// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[RHS_SHAPE]] : tensor -> index +// CHECK-NEXT: %[[RESULT_SHAPE:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor, tensor -> tensor +// CHECK-NEXT: %[[MINIMUM_SHAPES:.*]]:2 = chlo.minimum_broadcast_shapes %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor, tensor -> tensor, tensor +// CHECK-NEXT: %[[MINIMUM_RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[MINIMUM_SHAPES]]#0) : (tensor<*xf32>, tensor) -> tensor<*xf32> +// CHECK-NEXT: %[[MINIMUM_RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[MINIMUM_SHAPES]]#1) : (tensor<*xf32>, tensor) -> tensor<*xf32> +// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#0 : tensor -> index +// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#1 : tensor -> index // CHECK-NEXT: %[[LHS_RANK_GREATER:.*]] = cmpi sgt, %[[LHS_RANK]], %[[RHS_RANK]] : index // CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index // Handle rank 1 specialization // CHECK-NEXT: %[[GREATEST_RANK_IS_1:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C1]] : index -// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) { +// CHECK-NEXT: %[[RESULT_RANK_SPECIALIZATION:.*]] = scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) { // CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1] -// CHECK-NEXT: %[[BROADCASTED_LHS_1:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor, tensor<1xindex> -> tensor +// CHECK-NEXT: %[[BROADCASTED_LHS_1:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_1]] : tensor, tensor<1xindex> -> tensor // CHECK-NEXT: %[[CASTED_LHS_1:.*]] = tensor.cast %[[BROADCASTED_LHS_1]] : tensor to tensor<1xindex> -// CHECK-NEXT: %[[RESHAPED_LHS_1:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor -// CHECK-NEXT: %[[BROADCASTED_RHS_1:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor, tensor<1xindex> -> tensor +// CHECK-NEXT: %[[RESHAPED_LHS_1:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor +// CHECK-NEXT: %[[BROADCASTED_RHS_1:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_1]] : tensor, tensor<1xindex> -> tensor // CHECK-NEXT: %[[CASTED_RHS_1:.*]] = tensor.cast %[[BROADCASTED_RHS_1]] : tensor to tensor<1xindex> -// CHECK-NEXT: %[[RESHAPED_RHS_1:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor +// CHECK-NEXT: %[[RESHAPED_RHS_1:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor // CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_1]], %[[RESHAPED_RHS_1]] : (tensor, tensor) -> tensor // CHECK-NEXT: %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1]] : tensor to tensor<*xf32> // CHECK-NEXT: scf.yield %[[RESULT_1]] : tensor<*xf32> @@ -222,12 +226,12 @@ func @addUnrankedUnranked( // Handle rank 2 specialization // CHECK-NEXT: %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) { // CHECK-NEXT: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1] -// CHECK-NEXT: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor, tensor<2xindex> -> tensor +// CHECK-NEXT: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_2]] : tensor, tensor<2xindex> -> tensor // CHECK-NEXT: %[[CASTED_LHS_2:.*]] = tensor.cast %[[BROADCASTED_LHS_2]] : tensor to tensor<2xindex> -// CHECK-NEXT: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor -// CHECK-NEXT: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor, tensor<2xindex> -> tensor +// CHECK-NEXT: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor +// CHECK-NEXT: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_2]] : tensor, tensor<2xindex> -> tensor // CHECK-NEXT: %[[CASTED_RHS_2:.*]] = tensor.cast %[[BROADCASTED_RHS_2]] : tensor to tensor<2xindex> -// CHECK-NEXT: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor +// CHECK-NEXT: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor // CHECK-NEXT: %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor, tensor) -> tensor // CHECK-NEXT: %[[RESULT_2:.*]] = tensor.cast %[[RESULT_RANK_2]] : tensor to tensor<*xf32> // CHECK-NEXT: scf.yield %[[RESULT_2]] : tensor<*xf32> @@ -237,12 +241,12 @@ func @addUnrankedUnranked( // Handle rank 3 specialization // CHECK-NEXT: %[[VAL_34:.*]] = scf.if %[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) { // CHECK-NEXT: %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1] -// CHECK-NEXT: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor, tensor<3xindex> -> tensor +// CHECK-NEXT: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_3]] : tensor, tensor<3xindex> -> tensor // CHECK-NEXT: %[[CASTED_LHS_3:.*]] = tensor.cast %[[BROADCASTED_LHS_3]] : tensor to tensor<3xindex> -// CHECK-NEXT: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor -// CHECK-NEXT: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor, tensor<3xindex> -> tensor +// CHECK-NEXT: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor +// CHECK-NEXT: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_3]] : tensor, tensor<3xindex> -> tensor // CHECK-NEXT: %[[CASTED_RHS_3:.*]] = tensor.cast %[[BROADCASTED_RHS_3]] : tensor to tensor<3xindex> -// CHECK-NEXT: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor +// CHECK-NEXT: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor // CHECK-NEXT: %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor, tensor) -> tensor // CHECK-NEXT: %[[RESULT_3:.*]] = tensor.cast %[[RESULT_RANK_3]] : tensor to tensor<*xf32> // CHECK-NEXT: scf.yield %[[RESULT_3]] : tensor<*xf32> @@ -252,47 +256,30 @@ func @addUnrankedUnranked( // Handle rank 4 specialization // CHECK-NEXT: %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) { // CHECK-NEXT: %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1] -// CHECK-NEXT: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor, tensor<4xindex> -> tensor +// CHECK-NEXT: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_4]] : tensor, tensor<4xindex> -> tensor // CHECK-NEXT: %[[CASTED_LHS_4:.*]] = tensor.cast %[[BROADCASTED_LHS_4]] : tensor to tensor<4xindex> -// CHECK-NEXT: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor -// CHECK-NEXT: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor, tensor<4xindex> -> tensor +// CHECK-NEXT: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor +// CHECK-NEXT: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_4]] : tensor, tensor<4xindex> -> tensor // CHECK-NEXT: %[[CASTED_RHS_4:.*]] = tensor.cast %[[BROADCASTED_RHS_4]] : tensor to tensor<4xindex> -// CHECK-NEXT: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor +// CHECK-NEXT: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor // CHECK-NEXT: %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor, tensor) -> tensor // CHECK-NEXT: %[[RESULT_4:.*]] = tensor.cast %[[RESULT_RANK_4]] : tensor to tensor<*xf32> // CHECK-NEXT: scf.yield %[[RESULT_4]] : tensor<*xf32> // CHECK-NEXT: } else { // CHECK-NEXT: %[[C5:.*]] = constant 5 : index // CHECK-NEXT: %[[GREATEST_RANK_IS_5:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C5]] : index +// CHECK-NEXT: assert %[[GREATEST_RANK_IS_5]] // Handle rank 5 specialization -// CHECK-NEXT: %[[VAL_50:.*]] = scf.if %[[GREATEST_RANK_IS_5]] -> (tensor<*xf32>) { -// CHECK-NEXT: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1] -// CHECK-NEXT: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor, tensor<5xindex> -> tensor -// CHECK-NEXT: %[[CASTED_LHS_5:.*]] = tensor.cast %[[BROADCASTED_LHS_5]] : tensor to tensor<5xindex> -// CHECK-NEXT: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor -// CHECK-NEXT: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor, tensor<5xindex> -> tensor -// CHECK-NEXT: %[[CASTED_RHS_5:.*]] = tensor.cast %[[BROADCASTED_RHS_5]] : tensor to tensor<5xindex> -// CHECK-NEXT: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor -// CHECK-NEXT: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor, tensor) -> tensor -// CHECK-NEXT: %[[RESULT_5:.*]] = tensor.cast %[[RESULT_RANK_5]] : tensor to tensor<*xf32> -// CHECK-NEXT: scf.yield %[[RESULT_5]] : tensor<*xf32> -// CHECK-NEXT: } else { -// CHECK-NEXT: %[[C6:.*]] = constant 6 : index -// CHECK-NEXT: %[[GREATEST_RANK_IS_6:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C6]] : index -// CHECK-NEXT: assert %[[GREATEST_RANK_IS_6]] -// Handle rank 6 specialization -// CHECK-NEXT: %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1] -// CHECK-NEXT: %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor, tensor<6xindex> -> tensor -// CHECK-NEXT: %[[CASTED_LHS_6:.*]] = tensor.cast %[[BROADCASTED_LHS_6]] : tensor to tensor<6xindex> -// CHECK-NEXT: %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor -// CHECK-NEXT: %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor, tensor<6xindex> -> tensor -// CHECK-NEXT: %[[CASTED_RHS_6:.*]] = tensor.cast %[[BROADCASTED_RHS_6]] : tensor to tensor<6xindex> -// CHECK-NEXT: %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor -// CHECK-NEXT: %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor, tensor) -> tensor -// CHECK-NEXT: %[[RESULT_6:.*]] = tensor.cast %[[RESULT_RANK_6]] : tensor to tensor<*xf32> -// CHECK-NEXT: scf.yield %[[RESULT_6]] : tensor<*xf32> -// CHECK-NEXT: } -// CHECK-NEXT: scf.yield %[[VAL_65:.*]] : tensor<*xf32> +// CHECK-NEXT: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1] +// CHECK-NEXT: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_5]] : tensor, tensor<5xindex> -> tensor +// CHECK-NEXT: %[[CASTED_LHS_5:.*]] = tensor.cast %[[BROADCASTED_LHS_5]] : tensor to tensor<5xindex> +// CHECK-NEXT: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor +// CHECK-NEXT: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_5]] : tensor, tensor<5xindex> -> tensor +// CHECK-NEXT: %[[CASTED_RHS_5:.*]] = tensor.cast %[[BROADCASTED_RHS_5]] : tensor to tensor<5xindex> +// CHECK-NEXT: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor +// CHECK-NEXT: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[RESULT_5:.*]] = tensor.cast %[[RESULT_RANK_5]] : tensor to tensor<*xf32> +// CHECK-NEXT: scf.yield %[[RESULT_5]] : tensor<*xf32> // CHECK-NEXT: } // CHECK-NEXT: scf.yield %[[VAL_66:.*]] : tensor<*xf32> // CHECK-NEXT: } @@ -300,7 +287,8 @@ func @addUnrankedUnranked( // CHECK-NEXT: } // CHECK-NEXT: scf.yield %[[VAL_68:.*]] : tensor<*xf32> // CHECK-NEXT: } -// CHECK-NEXT: scf.yield %[[VAL_69:.*]] : tensor<*xf32> +// CHECK-NEXT: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RESULT_RANK_SPECIALIZATION]], %[[RESULT_SHAPE]]) : (tensor<*xf32>, tensor) -> tensor<*xf32> +// CHECK-NEXT: scf.yield %[[RESHAPED_RESULT]] : tensor<*xf32> // CHECK-NEXT: } // CHECK-NEXT: scf.yield %[[VAL_70:.*]] : tensor<*xf32> // CHECK-NEXT: } @@ -325,13 +313,18 @@ func @selectUnrankedUnrankedUnranked( // CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>, // CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> { // CHECK-NEXT: %[[PRED_SHAPE:.*]] = shape.shape_of %[[PRED]] : tensor<*xi1> -> tensor -// CHECK-NEXT: %[[PRED_RANK:.*]] = shape.rank %[[PRED_SHAPE]] : tensor -> index // CHECK-NEXT: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor -// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[LHS_SHAPE]] : tensor -> index +// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor +// CHECK-NEXT: %[[RESULT_SHAPE:.*]] = shape.broadcast %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor, tensor, tensor -> tensor +// CHECK-NEXT: %[[MINIMUM_SHAPES:.*]]:3 = chlo.minimum_broadcast_shapes %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor, tensor, tensor -> tensor, tensor, tensor +// CHECK-NEXT: %[[MINIMUM_RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[MINIMUM_SHAPES]]#0) : (tensor<*xi1>, tensor) -> tensor<*xi1> +// CHECK-NEXT: %[[MINIMUM_RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[MINIMUM_SHAPES]]#1) : (tensor<*xf32>, tensor) -> tensor<*xf32> +// CHECK-NEXT: %[[MINIMUM_RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[MINIMUM_SHAPES]]#2) : (tensor<*xf32>, tensor) -> tensor<*xf32> +// CHECK-NEXT: %[[PRED_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#0 : tensor -> index +// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#1 : tensor -> index // CHECK-NEXT: %[[GREATER_RANK_CMP:.*]] = cmpi sgt, %[[PRED_RANK]], %[[LHS_RANK]] : index // CHECK-NEXT: %[[GREATER_RANK:.*]] = select %[[GREATER_RANK_CMP]], %[[PRED_RANK]], %[[LHS_RANK]] : index -// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor -// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[RHS_SHAPE]] : tensor -> index +// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#2 : tensor -> index // CHECK-NEXT: %[[GREATEST_RANK_CMP:.*]] = cmpi sgt, %[[GREATER_RANK]], %[[RHS_RANK]] : index // CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[GREATEST_RANK_CMP]], %[[GREATER_RANK]], %[[RHS_RANK]] : index // CHECK-NEXT: %c1 = constant 1 : index @@ -339,15 +332,15 @@ func @selectUnrankedUnrankedUnranked( // Handle rank 1 specialization // CHECK-NEXT: scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) { // CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1] : tensor<1xindex> -// CHECK-NEXT: %[[BROADCASTED_PRED:.*]] = shape.broadcast %[[PRED_SHAPE]], %[[CONST_SHAPE_1]] : tensor, tensor<1xindex> -> tensor +// CHECK-NEXT: %[[BROADCASTED_PRED:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_1]] : tensor, tensor<1xindex> -> tensor // CHECK-NEXT: %[[CASTED_PRED:.*]] = tensor.cast %[[BROADCASTED_PRED]] : tensor to tensor<1xindex> -// CHECK-NEXT: %[[RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[CASTED_PRED]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor -// CHECK-NEXT: %[[BROADCASTED_LHS:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor, tensor<1xindex> -> tensor +// CHECK-NEXT: %[[RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_PRED]], %[[CASTED_PRED]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor +// CHECK-NEXT: %[[BROADCASTED_LHS:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_1]] : tensor, tensor<1xindex> -> tensor // CHECK-NEXT: %[[CASTED_LHS:.*]] = tensor.cast %[[BROADCASTED_LHS]] : tensor to tensor<1xindex> -// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor -// CHECK-NEXT: %[[BROADCASTED_RHS:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor, tensor<1xindex> -> tensor +// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor +// CHECK-NEXT: %[[BROADCASTED_RHS:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#2, %[[CONST_SHAPE_1]] : tensor, tensor<1xindex> -> tensor // CHECK-NEXT: %[[CASTED_RHS:.*]] = tensor.cast %[[BROADCASTED_RHS]] : tensor to tensor<1xindex> -// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor +// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor // CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_select %[[RESHAPED_PRED]], %[[RESHAPED_LHS]], %[[RESHAPED_RHS]] : (tensor, tensor, tensor) -> tensor // CHECK-NEXT: %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1:.*]] : tensor to tensor<*xf32> // CHECK-NEXT: scf.yield %[[RESULT_1]] : tensor<*xf32> @@ -357,4 +350,3 @@ func @selectUnrankedUnrankedUnranked( // CHECK: chlo.broadcast_select {{.*}} : (tensor, tensor, tensor) -> tensor // CHECK: chlo.broadcast_select {{.*}} : (tensor, tensor, tensor) -> tensor // CHECK: chlo.broadcast_select {{.*}} : (tensor, tensor, tensor) -> tensor -// CHECK: chlo.broadcast_select {{.*}} : (tensor, tensor, tensor) -> tensor diff --git a/tensorflow/compiler/mlir/hlo/tests/ops.mlir b/tensorflow/compiler/mlir/hlo/tests/ops.mlir index 93c4a7693a3..358b760ed9f 100644 --- a/tensorflow/compiler/mlir/hlo/tests/ops.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/ops.mlir @@ -368,6 +368,16 @@ func @concat_1D(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<3xi32> { // ----- +// CHECK-LABEL: @concat_1D +// Verifies that an error is not thrown if the inferred type is compatible with +// the result type. +func @concat_1D(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<3xi32> { + %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<*xi32>) -> tensor<3xi32> + return %0 : tensor<3xi32> +} + +// ----- + func @concat_1D_type_error(%arg0: tensor<1xi32>, %arg1: tensor<2xf32>) -> tensor<3xi32> { // expected-error@+1 {{'mhlo.concatenate' op requires the same element type for all operands and results}} %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xf32>) -> tensor<3xi32> @@ -384,14 +394,6 @@ func @concat_1D_unranked(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor< // ----- -func @concat_1D_unranked_error(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<3xi32> { - // expected-error@+1 {{op inferred type(s) 'tensor<*xi32>' are incompatible with return type(s) of operation 'tensor<3xi32>'}} - %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<*xi32>) -> tensor<3xi32> - return %0 : tensor<3xi32> -} - -// ----- - func @concat_1D_error(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<4xi32> { // expected-error@+1 {{op inferred type(s) 'tensor<3xi32>' are incompatible with return type(s) of operation 'tensor<4xi32>'}} %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<4xi32> diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 20ff771405e..5e9a42c282b 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -627,6 +627,7 @@ tf_native_cc_binary( srcs = [ "converter_gen.cc", ], + compatible_with = get_compatible_with_cloud(), deps = [ "@llvm-project//llvm:Support", "@llvm-project//llvm:TableGen", diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index 1103038675a..318121bc6d1 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -1192,11 +1192,11 @@ void AddRegionsForTflWhileOp(mlir::ModuleOp module) { auto cond = symbol_table.lookup( while_op->getAttr("cond").cast().getValue()); AddCallOpInWhileOpRegion(while_op.cond(), cond); - while_op.removeAttr("cond"); + while_op->removeAttr("cond"); auto body = symbol_table.lookup( while_op->getAttr("body").cast().getValue()); AddCallOpInWhileOpRegion(while_op.body(), body); - while_op.removeAttr("body"); + while_op->removeAttr("body"); }); } } // namespace diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 35cd5e6575d..04dc41d61ba 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -1011,9 +1011,7 @@ def TFL_BatchMatMulOp : TFL_Op<"batch_matmul", [ TFL_OperandHasAtleastRank<0, 2>, TFL_OperandHasAtleastRank<1, 2>, PredOpTrait<"x and output must have same element type", - TFL_TCresVTEtIsSameAsOp<0, 0>>, - PredOpTrait<"y and output must have same element type", - TFL_TCresVTEtIsSameAsOp<0, 1>>]> { + TFL_TCresVTEtIsSameAsOp<0, 0>>]> { let summary = "Batch Matrix Multiply Operator"; @@ -2774,7 +2772,7 @@ def TFL_ReverseV2Op: TFL_Op<"reverse_v2", [ def TFL_SelectOp : TFL_Op<"select", [ NoSideEffect, SameOperandsAndResultsScale, - PredOpTrait<"operands have same element type", TCopVTEtIsSameAs<1, 2>>, + PredOpTrait<"operands have same element type", TFL_TCopVTEtAreSameAt<1, 2>>, PredOpTrait<"operands and result have same element type", TFL_TCresVTEtIsSameAsOp<0, 1>>]> { let summary = "Select operator"; @@ -2810,8 +2808,9 @@ def TFL_SelectOp : TFL_Op<"select", [ def TFL_SelectV2Op : TFL_Op<"select_v2", [ ResultsBroadcastableShape, NoSideEffect, + SameOperandsAndResultsScale, TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1, 2], 4>, - PredOpTrait<"operands have same element type", TCopVTEtIsSameAs<1, 2>>, + PredOpTrait<"operands have same element type", TFL_TCopVTEtAreSameAt<1, 2>>, PredOpTrait<"operands and result have same element type", TFL_TCresVTEtIsSameAsOp<0, 1>>]> { let summary = "SelectV2 operator"; diff --git a/tensorflow/compiler/mlir/lite/tests/dilated-conv.mlir b/tensorflow/compiler/mlir/lite/tests/dilated-conv.mlir index 77e4846aefa..354f8faaf97 100644 --- a/tensorflow/compiler/mlir/lite/tests/dilated-conv.mlir +++ b/tensorflow/compiler/mlir/lite/tests/dilated-conv.mlir @@ -14,15 +14,16 @@ func @testDilatedConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32> // CHECK-NEXT: return [[RESULT]] : tensor<1x120x120x8xf32> } -func @testDilatedConvWithNonConstantPadAndCrops(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x120x120x8xf32> { +func @testDilatedConvWithNonConstantPadAndCrops(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>) -> tensor<1x120x120x8xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> - %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32> - %1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32> - %2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x120x120x8xf32> + %cst_1 = constant dense<0> : tensor<2x2xi32> + %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x64x64x3xf32> + %1 = "tf.Conv2D"(%0, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x64x64x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x60x60x8xf32> + %2 = "tf.BatchToSpaceND"(%1, %cst, %cst_1) : (tensor<4x60x60x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x120x120x8xf32> return %2 : tensor<1x120x120x8xf32> // CHECK-LABEL: testDilatedConvWithNonConstantPadAndCrops - // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>) + // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>) // CHECK-NEXT: [[RESULT:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x120x120x8xf32> // CHECK-NEXT: return [[RESULT]] : tensor<1x120x120x8xf32> } @@ -73,37 +74,39 @@ func @testDilatedDepthWiseConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5 // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32> } -func @testDilatedConvWithPad(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>, %arg3: tensor<8xf32>) -> tensor<1x128x128x8xf32> { +func @testDilatedConvWithPad(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>, %arg2: tensor<8xf32>) -> tensor<1x128x128x8xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> %cst_0 = constant dense<4> : tensor<2x2xi32> %cst_1 = constant dense<0> : tensor<2x2xi32> + %cst_2 = constant dense<0> : tensor<4x2xi32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32> - %1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32> - %2 = "tf.Pad"(%1, %arg1) : (tensor<4x64x64x8xf32>, tensor<2x2xi32>) -> tensor<4x64x64x8xf32> + %1 = "tf.Conv2D"(%0, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32> + %2 = "tf.Pad"(%1, %cst_2) : (tensor<4x64x64x8xf32>, tensor<4x2xi32>) -> tensor<4x64x64x8xf32> %3 = "tf.BatchToSpaceND"(%2, %cst, %cst_1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> - %4 = "tf.BiasAdd"(%3, %arg3) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32> + %4 = "tf.BiasAdd"(%3, %arg2) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32> return %4 : tensor<1x128x128x8xf32> // CHECK-LABEL: testDilatedConvWithPad - // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>) + // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>) // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[CONV]], [[BIAS]]) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32> // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32> } -func @testDilatedDepthWiseConvWithPad(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>, %arg3: tensor<8xf32>) -> tensor<1x128x128x8xf32> { +func @testDilatedDepthWiseConvWithPad(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>, %arg2: tensor<8xf32>) -> tensor<1x128x128x8xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> %cst_0 = constant dense<4> : tensor<2x2xi32> %cst_1 = constant dense<0> : tensor<2x2xi32> + %cst_2 = constant dense<0> : tensor<4x2xi32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32> - %1 = "tf.DepthwiseConv2dNative"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32> - %2 = "tf.Pad"(%1, %arg1) : (tensor<4x64x64x8xf32>, tensor<2x2xi32>) -> tensor<4x64x64x8xf32> + %1 = "tf.DepthwiseConv2dNative"(%0, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32> + %2 = "tf.Pad"(%1, %cst_2) : (tensor<4x64x64x8xf32>, tensor<4x2xi32>) -> tensor<4x64x64x8xf32> %3 = "tf.BatchToSpaceND"(%2, %cst, %cst_1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> - %4 = "tf.BiasAdd"(%3, %arg3) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32> + %4 = "tf.BiasAdd"(%3, %arg2) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32> return %4 : tensor<1x128x128x8xf32> // CHECK-LABEL: testDilatedDepthWiseConvWithPad - // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>) + // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>) // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[CONV]], [[BIAS]]) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32> // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32> @@ -235,22 +238,23 @@ func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, % // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32> } -func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> { +func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<5x5x1x1xf32>, %arg2: tensor<128xf32>) -> tensor<1x128x128xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> %cst_0 = "tf.Const"() { value = dense<3> : tensor } : () -> tensor %cst_1 = constant dense<4> : tensor<2x2xi32> %cst_2 = constant dense<0> : tensor<2x2xi32> + %cst_3 = constant dense<0> : tensor<3x2xi32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32> %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor) -> tensor<4x68x68x1xf32> - %2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32> + %2 = "tf.Conv2D"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32> %3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32> - %4 = "tf.Pad"(%3, %arg1) : (tensor<4x64x64xf32>, tensor<2x2xi32>) -> tensor<4x64x64xf32> + %4 = "tf.Pad"(%3, %cst_3) : (tensor<4x64x64xf32>, tensor<3x2xi32>) -> tensor<4x64x64xf32> %5 = "tf.BatchToSpaceND"(%4, %cst, %cst_2) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32> - %6 = "tf.BiasAdd"(%5, %arg3) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32> + %6 = "tf.BiasAdd"(%5, %arg2) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32> return %6 : tensor<1x128x128xf32> // CHECK-LABEL: testDilatedConvWithExpandSqueeze3 - // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>) + // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>) // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor} : () -> tensor // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor) -> tensor<1x128x128x1xf32> // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> @@ -259,22 +263,23 @@ func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: ten // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32> } -func @testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> { +func @testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<5x5x1x1xf32>, %arg2: tensor<128xf32>) -> tensor<1x128x128xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> %cst_0 = "tf.Const"() { value = dense<3> : tensor } : () -> tensor %cst_1 = constant dense<4> : tensor<2x2xi32> %cst_2 = constant dense<0> : tensor<2x2xi32> + %cst_3 = constant dense<0> : tensor<3x2xi32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32> %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor) -> tensor<4x68x68x1xf32> - %2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32> + %2 = "tf.DepthwiseConv2dNative"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32> %3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32> - %4 = "tf.Pad"(%3, %arg1) : (tensor<4x64x64xf32>, tensor<2x2xi32>) -> tensor<4x64x64xf32> + %4 = "tf.Pad"(%3, %cst_3) : (tensor<4x64x64xf32>, tensor<3x2xi32>) -> tensor<4x64x64xf32> %5 = "tf.BatchToSpaceND"(%4, %cst, %cst_2) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32> - %6 = "tf.BiasAdd"(%5, %arg3) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32> + %6 = "tf.BiasAdd"(%5, %arg2) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32> return %6 : tensor<1x128x128xf32> // CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze3 - // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>) + // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>) // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor} : () -> tensor // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor) -> tensor<1x128x128x1xf32> // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> @@ -481,3 +486,27 @@ func @testDilatedConv1DWithMixedPostiveAndNegativeAxis(%arg0: tensor<1x128x3xf32 // CHECK-NEXT: [[RESULT:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [-3]} : (tensor<1x1x128x8xf32>) -> tensor<1x128x8xf32> // CHECK-NEXT: return [[RESULT]] : tensor<1x128x8xf32> } + +func @testPaddedDilatedConv(%arg0 : tensor<2x1920x64xf32>) -> tensor<2x1920x128xf32> { + %0 = "tf.Const"() {value = dense<[[0, 0], [2, 0], [0, 0]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> + %1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.Const"() {value = dense<2> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<0> : tensor<1x2xi32>} : () -> tensor<1x2xi32> + %4 = "tf.Const"() {value = dense<0.0> : tensor<3x1x64x128xf32>} : () -> tensor<3x1x64x128xf32> + %5 = "tf.SpaceToBatchND"(%arg0, %1, %3) {device = ""} : (tensor<2x1920x64xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<4x960x64xf32> + %6 = "tf.ExpandDims"(%5, %2) {device = ""} : (tensor<4x960x64xf32>, tensor) -> tensor<4x960x1x64xf32> + %7 = "tf.Conv2D"(%6, %4) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<4x960x1x64xf32>, tensor<3x1x64x128xf32>) -> tensor<4x958x1x128xf32> + %8 = "tf.Squeeze"(%7) {device = "", squeeze_dims = [2]} : (tensor<4x958x1x128xf32>) -> tensor<4x958x128xf32> + %9 = "tf.Pad"(%8, %0) {device = ""} : (tensor<4x958x128xf32>, tensor<3x2xi32>) -> tensor<4x960x128xf32> + %10 = "tf.BatchToSpaceND"(%9, %1, %3) {device = ""} : (tensor<4x960x128xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<2x1920x128xf32> + return %10 : tensor<2x1920x128xf32> + + // CHECK-LABEL: testPaddedDilatedConv + // CHECK-SAME: ([[INPUT:%.*]]: tensor<2x1920x64xf32>) + // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<2> : tensor} : () -> tensor + // CHECK-NEXT: [[FILTER:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x1x64x128xf32>} : () -> tensor<3x1x64x128xf32> + // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) {device = ""} : (tensor<2x1920x64xf32>, tensor) -> tensor<2x1920x1x64xf32> + // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {data_format = "NHWC", device = "", dilations = [1, 2, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<2x1920x1x64xf32>, tensor<3x1x64x128xf32>) -> tensor<2x1920x1x128xf32> + // CHECK-NEXT: [[RESULT:%.*]] = "tf.Squeeze"([[CONV]]) {device = "", squeeze_dims = [2]} : (tensor<2x1920x1x128xf32>) -> tensor<2x1920x128xf32> + // CHECK-NEXT: return [[RESULT]] : tensor<2x1920x128xf32> +} diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir index 7473c5d403a..e36b5abbc9a 100644 --- a/tensorflow/compiler/mlir/lite/tests/ops.mlir +++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir @@ -1387,6 +1387,15 @@ func @testBatchMatmulQuant(%arg0 : tensor<1x4x384x32x!quant.uniform>, tensor<1x4x384x32x!quant.uniform>) -> tensor<1x4x384x384x!quant.uniform> return %0 : tensor<1x4x384x384x!quant.uniform> } + +// ----- + +func @testBatchMatmulHybridQuant(%arg0 : tensor<1x4x384x32xf32>, %arg1 : tensor<1x4x384x32x!quant.uniform>) -> tensor<1x4x384x384xf32> { + // CHECK: "tfl.batch_matmul"(%arg0, %arg1) + %0 = "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = true} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32x!quant.uniform>) -> tensor<1x4x384x384xf32> + return %0 : tensor<1x4x384x384xf32> +} + // ----- func @testConcat(%arg0: tensor<1x2xi32>, %arg1: tensor<1x2xi32>) -> tensor<2x2xi32> { diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index f1d698df0c4..50768cbc982 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -564,6 +564,78 @@ func @FuseFullyConnectedReshapeAddConstWithActivation(%arg0: tensor<40x37xf32>, // FOLD: return %[[fc]] } +// CHECK-LABEL: @FuseFullyConnectedReshapeAdd2DConst +func @FuseFullyConnectedReshapeAdd2DConst(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<1x40x4x10xf32> { + %cst = constant unit + %cst2 = constant dense<2.0> : tensor<4x10xf32> + %shape = constant dense<[1, 40, 4, 10]> : tensor<4xi32> + + %0 = "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> (tensor<40x40xf32>) + %1 = "tfl.reshape"(%0, %shape) : (tensor<40x40xf32>, tensor<4xi32>) -> tensor<1x40x4x10xf32> + %2 = "tfl.add"(%1, %cst2) {fused_activation_function = "NONE"} : (tensor<1x40x4x10xf32>, tensor<4x10xf32>) -> tensor<1x40x4x10xf32> + + return %2 : tensor<1x40x4x10xf32> + + // CHECK: %[[cst:.*]] = constant dense<2.000000e+00> : tensor<40xf32> + // CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} + // CHECK: %[[rs:.*]] = "tfl.reshape"(%[[fc]] + // CHECK: return %[[rs]] +} + +// CHECK-LABEL: @FuseFullyConnectedReshapeAdd2DConstWithActivation +func @FuseFullyConnectedReshapeAdd2DConstWithActivation(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<1x40x4x10xf32> { + %cst = constant unit + %cst2 = constant dense<2.0> : tensor<4x10xf32> + %shape = constant dense<[1, 40, 4, 10]> : tensor<4xi32> + + %0 = "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> (tensor<40x40xf32>) + %1 = "tfl.reshape"(%0, %shape) : (tensor<40x40xf32>, tensor<4xi32>) -> tensor<1x40x4x10xf32> + %2 = "tfl.add"(%1, %cst2) {fused_activation_function = "RELU6"} : (tensor<1x40x4x10xf32>, tensor<4x10xf32>) -> tensor<1x40x4x10xf32> + + return %2 : tensor<1x40x4x10xf32> + + // CHECK: %[[cst:.*]] = constant dense<2.000000e+00> : tensor<40xf32> + // CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]]) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"} + // CHECK: %[[rs:.*]] = "tfl.reshape"(%[[fc]] + // CHECK: return %[[rs]] +} + +// CHECK-LABEL: @FuseFullyConnectedReshapeAdd2DConstWithExistingBias +func @FuseFullyConnectedReshapeAdd2DConstWithExistingBias(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<1x40x4x10xf32> { + %cst = constant dense<3.0> : tensor<40xf32> + %cst2 = constant dense<2.0> : tensor<4x10xf32> + %shape = constant dense<[1, 40, 4, 10]> : tensor<4xi32> + + %0 = "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, tensor<40xf32>) -> (tensor<40x40xf32>) + %1 = "tfl.reshape"(%0, %shape) : (tensor<40x40xf32>, tensor<4xi32>) -> tensor<1x40x4x10xf32> + %2 = "tfl.add"(%1, %cst2) {fused_activation_function = "NONE"} : (tensor<1x40x4x10xf32>, tensor<4x10xf32>) -> tensor<1x40x4x10xf32> + + return %2 : tensor<1x40x4x10xf32> + + // CHECK: %[[cst:.*]] = constant dense<5.000000e+00> : tensor<40xf32> + // CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]]) + // CHECK: %[[rs:.*]] = "tfl.reshape"(%[[fc]] + // CHECK: return %[[rs]] +} + +// CHECK-LABEL: @NotFuseFullyConnectedReshapeAdd2DConstIfLastDimIsNotNumElementsOfRhs +func @NotFuseFullyConnectedReshapeAdd2DConstIfLastDimIsNotNumElementsOfRhs(%arg0: tensor<40x37xf32>, %arg1: tensor<20x37xf32>) -> tensor<1x20x4x10xf32> { + %cst = constant unit + %cst2 = constant dense<2.0> : tensor<4x10xf32> + %shape = constant dense<[1, 20, 4, 10]> : tensor<4xi32> + + %0 = "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<20x37xf32>, none) -> (tensor<40x20xf32>) + %1 = "tfl.reshape"(%0, %shape) : (tensor<40x20xf32>, tensor<4xi32>) -> tensor<1x20x4x10xf32> + %2 = "tfl.add"(%1, %cst2) {fused_activation_function = "NONE"} : (tensor<1x20x4x10xf32>, tensor<4x10xf32>) -> tensor<1x20x4x10xf32> + + return %2 : tensor<1x20x4x10xf32> + + // CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1 + // CHECK: %[[rs:.*]] = "tfl.reshape"(%[[fc]] + // CHECK: %[[add:.*]] = "tfl.add"(%[[rs]] + // CHECK: return %[[add]] +} + // CHECK-LABEL: @NotReorderReshapeAddIfNotBroadcastableAfter func @NotReorderReshapeAddIfNotBroadcastableAfter(%arg0: tensor<40x10x4xf32>) -> tensor<40x40xf32> { %cst = constant dense<2.0> : tensor<40xf32> @@ -642,6 +714,19 @@ func @NotReorderReshapeAddIfHighDim(%arg0: tensor<1x1x1x1x30x96xf32>) -> tensor< // CHECK: return %[[rs2]] } +// CHECK-LABEL: @NotReorderReshapeAdd2DConstIfInputIsNotDefinedByFullyConnected +func @NotReorderReshapeAdd2DConstIfInputIsNotDefinedByFullyConnected(%arg0: tensor<8x15xf32>) -> tensor<1x8x3x5xf32> { + %cst = constant dense<2.0> : tensor<3x5xf32> + %shape = constant dense<[1, 8, 3, 5]> : tensor<4xi32> + %1 = "tfl.reshape"(%arg0, %shape) : (tensor<8x15xf32>, tensor<4xi32>) -> tensor<1x8x3x5xf32> + %2 = "tfl.add"(%1, %cst) {fused_activation_function = "NONE"} : (tensor<1x8x3x5xf32>, tensor<3x5xf32>) -> tensor<1x8x3x5xf32> + return %2 : tensor<1x8x3x5xf32> + + // CHECK: %[[rs:.*]] = "tfl.reshape"(%arg0 + // CHECK: %[[add:.*]] = "tfl.add"(%[[rs]] + // CHECK: return %[[add]] +} + // CHECK-LABEL: @ReorderElementwiseValueOpAndMoveOp func @ReorderElementwiseValueOpAndMoveOp(%arg0: tensor<40x40x1xf32>) -> tensor<40x40xf32> { %shape = constant dense<[40, 40]> : tensor<2xi32> @@ -1679,8 +1764,17 @@ func @ReorderReshapex2Add(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<1x2x3x4xf32> // CHECK: return %[[VAL_1]] } -// CHECK-LABEL: ConvertSliceToIdentity -func @ConvertSliceToIdentity(%arg0: tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32> { +// CHECK-LABEL: ConvertSliceToIdentityI32 +func @ConvertSliceToIdentityI32(%arg0: tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32> { + %begin = constant dense<0> : tensor<4xi32> + %shape = constant dense<[2,3,4,5]> : tensor<4xi32> + %0 = "tfl.slice"(%arg0, %begin, %shape) : (tensor<2x3x4x5xf32>, tensor<4xi32>, tensor<4xi32>) -> tensor<2x3x4x5xf32> + return %0 : tensor<2x3x4x5xf32> + // CHECK: return %arg0 +} + +// CHECK-LABEL: ConvertSliceToIdentityI64 +func @ConvertSliceToIdentityI64(%arg0: tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32> { %begin = constant dense<0> : tensor<4xi64> %shape = constant dense<[2,3,4,5]> : tensor<4xi64> %0 = "tfl.slice"(%arg0, %begin, %shape) : (tensor<2x3x4x5xf32>, tensor<4xi64>, tensor<4xi64>) -> tensor<2x3x4x5xf32> @@ -1688,6 +1782,24 @@ func @ConvertSliceToIdentity(%arg0: tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32> // CHECK: return %arg0 } +// CHECK-LABEL: ConvertSliceToIdentityStaticDimWithShapeWithNeg1 +func @ConvertSliceToIdentityStaticDimWithShapeWithNeg1(%arg0: tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32> { + %begin = constant dense<0> : tensor<4xi32> + %shape = constant dense<[-1, 3, -1, 5]> : tensor<4xi32> + %0 = "tfl.slice"(%arg0, %begin, %shape) : (tensor<2x3x4x5xf32>, tensor<4xi32>, tensor<4xi32>) -> tensor<2x3x4x5xf32> + return %0 : tensor<2x3x4x5xf32> + // CHECK: return %arg0 +} + +// CHECK-LABEL: ConvertSliceToIdentityDynamicDimAndShapeWithNeg1 +func @ConvertSliceToIdentityDynamicDimAndShapeWithNeg1(%arg0: tensor) -> tensor { + %begin = constant dense<0> : tensor<4xi32> + %shape = constant dense<[-1, 3, -1, 5]> : tensor<4xi32> + %0 = "tfl.slice"(%arg0, %begin, %shape) : (tensor, tensor<4xi32>, tensor<4xi32>) -> tensor + return %0 : tensor + // CHECK: return %arg0 +} + // CHECK-LABEL: DontConvertSliceToIdentity func @DontConvertSliceToIdentity(%arg0: tensor<2x3x4x5xf32>) -> (tensor<2x3x4x4xf32>, tensor<1x2x3x4xf32>) { %begin0 = constant dense<0> : tensor<4xi64> @@ -1706,6 +1818,28 @@ func @DontConvertSliceToIdentity(%arg0: tensor<2x3x4x5xf32>) -> (tensor<2x3x4x4x // CHECK: return %[[SLICE_0]], %[[SLICE_1]] : tensor<2x3x4x4xf32>, tensor<1x2x3x4xf32> } +// CHECK-LABEL: DontConvertSliceToIdentityNonConstShape +func @DontConvertSliceToIdentityNonConstShape(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { + %begin = constant dense<0> : tensor<1xi32> + %0 = "tfl.slice"(%arg0, %begin, %arg1) : (tensor, tensor<1xi32>, tensor<1xi32>) -> tensor + return %0 : tensor + // CHECK: %[[BEGIN:.*]] = constant dense<0> : tensor<1xi32> + // CHECK: %[[SLICE:.*]] = "tfl.slice"(%arg0, %[[BEGIN]], %arg1) : (tensor, tensor<1xi32>, tensor<1xi32>) -> tensor + // CHECK: return %[[SLICE]] : tensor +} + +// CHECK-LABEL: DontConvertSliceToIdentityDynamicDimButEqualShape +func @DontConvertSliceToIdentityDynamicDimButEqualShape(%arg0: tensor) -> tensor { + %begin = constant dense<0> : tensor<1xi32> + %shape = constant dense<2> : tensor<1xi32> + %0 = "tfl.slice"(%arg0, %begin, %shape) : (tensor, tensor<1xi32>, tensor<1xi32>) -> tensor + return %0 : tensor + // CHECK: %[[BEGIN:.*]] = constant dense<0> : tensor<1xi32> + // CHECK: %[[SHAPE:.*]] = constant dense<2> : tensor<1xi32> + // CHECK: %[[SLICE:.*]] = "tfl.slice"(%arg0, %[[BEGIN]], %[[SHAPE]]) : (tensor, tensor<1xi32>, tensor<1xi32>) -> tensor + // CHECK: return %[[SLICE]] : tensor +} + // CHECK-LABEL: @FuseAddWithFullyConnectedWithBias func @FuseAddWithFullyConnectedWithBias(%arg: tensor<2x512xf32>) -> tensor<2x1024xf32> { %cst_add = constant dense<2.0> : tensor<512xf32> diff --git a/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h b/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h index ce83da99561..2cd53b49067 100644 --- a/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h +++ b/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h @@ -22,11 +22,13 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/utils/validators.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -82,35 +84,77 @@ template LogicalResult ConvertTFDilatedConvOp::matchAndRewrite( Conv2dOpTy op, PatternRewriter& rewriter) const { if (!op.getResult().hasOneUse()) { - return failure(); + return rewriter.notifyMatchFailure( + op, "result for current op has more than 1 use"); } // Make sure Conv2D has 'VALID' padding. if (op->template getAttrOfType("padding").getValue() != "VALID") { - return failure(); + return rewriter.notifyMatchFailure(op, + "Conv2D op doesn't have valid padding"); } // Make sure dilations are all ones if set. const ArrayAttr& dilations = op->template getAttrOfType("dilations"); if (dilations && !TFIntListIsAllOnes(dilations)) { - return failure(); + return rewriter.notifyMatchFailure(op, "dilations should be all 1"); } - if (!TFTypeIsFloat32Tensor(op.input()) || !TFDataFormatIsNHWC(op)) - return failure(); + if (!TFTypeIsFloat32Tensor(op.input()) || !TFDataFormatIsNHWC(op)) { + return rewriter.notifyMatchFailure( + op, "op's input is not float or the data format isn't NHWC"); + } // Allow dynamic width and height dimensions only. auto result_ty = op.getResult().getType().template cast(); if (!result_ty.hasRank() || result_ty.getRank() != 4 || - result_ty.isDynamicDim(0) || result_ty.isDynamicDim(3)) - return failure(); + result_ty.isDynamicDim(0) || result_ty.isDynamicDim(3)) { + return rewriter.notifyMatchFailure( + op, "only dynamic width and height dimensions are allowed"); + } // Check if the ConvOp is preceded by a `Expand` op and succeeded by a // `Squeeze` op. Operation* prev_op = op.getOperation()->getPrevNode(); - if (!prev_op) return failure(); + if (!prev_op || prev_op->getNumResults() != 1) { + return rewriter.notifyMatchFailure( + op, "op doesn't have a preceding node that has a single result"); + } + if (!prev_op->hasOneUse() || *(prev_op->getResult(0).user_begin()) != op) { + return rewriter.notifyMatchFailure( + op, "op's input isn't produced by previous operation"); + } - Operation* next_op = op.getOperation()->getNextNode(); - if (!next_op) return failure(); + auto tryGetNextNode = + [&rewriter](Operation* current) -> std::pair { + // Check the current operation has a single result. + if (current->getNumResults() != 1) { + return { + rewriter.notifyMatchFailure(current, "op doesn't have single result"), + nullptr}; + } + // Check the current operation has a next node. + Operation* next_op = current->getNextNode(); + if (!next_op) { + return {rewriter.notifyMatchFailure(current, "op doesn't have next node"), + nullptr}; + } + // Check the current operation's result is used by its successor node. + if (!current->hasOneUse() || + *(current->getResult(0).user_begin()) != next_op) { + return { + rewriter.notifyMatchFailure( + current, "op's result isn't directly consumed by the next op"), + nullptr}; + } + return {LogicalResult::success(), next_op}; + }; + + std::pair maybeNextNode = + tryGetNextNode(op.getOperation()); + if (failed(maybeNextNode.first)) { + return maybeNextNode.first; + } + Operation* next_op = maybeNextNode.second; TF::ExpandDimsOp expand_op; TF::SqueezeOp squeeze_op; @@ -119,15 +163,19 @@ LogicalResult ConvertTFDilatedConvOp::matchAndRewrite( if (llvm::isa(prev_op)) { if (!llvm::isa(next_op)) { // Expand/Squeeze op must come in pair. - return failure(); + return rewriter.notifyMatchFailure( + op, "ExpandDimsOp and SqueezeOp should come in pair"); } expand_op = llvm::cast(prev_op); squeeze_op = llvm::cast(next_op); - if (!expand_op.getResult().hasOneUse() || - !squeeze_op.getResult().hasOneUse()) { - return failure(); + if (!expand_op.getResult().hasOneUse()) { + return rewriter.notifyMatchFailure( + expand_op, "result for current op has more than 1 use"); + } + if (!squeeze_op.getResult().hasOneUse()) { + return rewriter.notifyMatchFailure( + squeeze_op, "result for current op has more than 1 use"); } - // Make sure that the axis in `expand_op` is constant. if (auto const_op = llvm::dyn_cast(expand_op.dim().getDefiningOp())) { @@ -141,12 +189,14 @@ LogicalResult ConvertTFDilatedConvOp::matchAndRewrite( expand_axis += 4; } } else { - return failure(); + return rewriter.notifyMatchFailure( + expand_op, "ExpandDimsOp doesn't have a constant axis"); } // Make sure that the `squeeze_dims` is equal to `expand_axis`. auto squeeze_dims = squeeze_op.squeeze_dims(); if (squeeze_dims.size() != 1) { - return failure(); + return rewriter.notifyMatchFailure( + squeeze_op, "squeeze dims should have exactly 1 dimension specified"); } int64_t squeeze_axis = squeeze_dims[0].cast().getInt(); if (squeeze_axis < 0) { @@ -154,36 +204,62 @@ LogicalResult ConvertTFDilatedConvOp::matchAndRewrite( squeeze_axis += 4; } if (squeeze_axis != expand_axis) { - return failure(); + return rewriter.notifyMatchFailure( + op, "squeeze axis and expand axis doesn't match"); } // Update previous/next op pointer. - prev_op = prev_op->getPrevNode(); - if (!prev_op) return failure(); - next_op = next_op->getNextNode(); - if (!next_op) return failure(); + Operation* tmp = prev_op->getPrevNode(); + if (!tmp || tmp->getNumResults() != 1) { + return rewriter.notifyMatchFailure( + prev_op, "op doesn't have a preceding node that has a single result"); + } + if (!tmp->hasOneUse() || *(tmp->getResult(0).user_begin()) != prev_op) { + return rewriter.notifyMatchFailure( + prev_op, "op's input isn't defined by its previous node"); + } + prev_op = tmp; + std::pair maybeNextNode = + tryGetNextNode(next_op); + if (failed(maybeNextNode.first)) { + return maybeNextNode.first; + } + next_op = maybeNextNode.second; } // SpaceToBatchND op. - if (!llvm::isa(prev_op)) return failure(); + if (!llvm::isa(prev_op)) { + return rewriter.notifyMatchFailure(prev_op, + "op should be a SpaceToBatchND op"); + } // TODO(b/149936532): Check `padding` input, currently ignored. TF::SpaceToBatchNDOp stb_op = llvm::cast(prev_op); if (!stb_op.getResult().hasOneUse()) { - return failure(); + return rewriter.notifyMatchFailure( + stb_op, "result for current op has more than 1 use"); } // Pad op. TF::PadOp pad_op; - // TODO(b/149936532): Currently we just ignore the PadOp. However note that - // in real scenarios this may not always be correct: user can put a PadOp here - // with non-trivial consequences. + ElementsAttr pad_attr; if (llvm::isa(next_op)) { pad_op = llvm::cast(next_op); if (!pad_op.getResult().hasOneUse()) { - return failure(); + return rewriter.notifyMatchFailure( + pad_op, "result for current op has more than 1 use"); + } + std::pair maybeNextNode = + tryGetNextNode(next_op); + if (failed(maybeNextNode.first)) { + return maybeNextNode.first; + } + next_op = maybeNextNode.second; + if (!matchPattern(pad_op.paddings(), m_Constant(&pad_attr))) { + // If the padding value isn't constant, we can't determine the padding + // scheme for Conv2D below, in this case just reject the pattern. + return rewriter.notifyMatchFailure( + pad_op, "PadOp's padding value isn't constant"); } - next_op = next_op->getNextNode(); - if (!next_op) return failure(); } // BatchToSpaceND + BiasAdd. @@ -194,33 +270,53 @@ LogicalResult ConvertTFDilatedConvOp::matchAndRewrite( // Must be BiasAdd + BatchToSpaceND. biasadd_op = llvm::cast(next_op); if (!biasadd_op.getResult().hasOneUse()) { - return failure(); + return rewriter.notifyMatchFailure( + biasadd_op, "result for current op has more than 1 use"); } - next_op = next_op->getNextNode(); - if (!next_op || !llvm::isa(next_op)) return failure(); + std::pair maybeNextNode = + tryGetNextNode(next_op); + if (failed(maybeNextNode.first)) { + return maybeNextNode.first; + } + if (!llvm::isa(maybeNextNode.second)) { + return rewriter.notifyMatchFailure( + next_op, "op's next node isn't BatchToSpaceND op"); + } + next_op = maybeNextNode.second; bts_op = llvm::cast(next_op); } else if (llvm::isa(next_op)) { // BatchToSpaceND + (optional) BiasAdd. bts_op = llvm::cast(next_op); - next_op = next_op->getNextNode(); - if (next_op && llvm::isa(next_op)) { + Operation* tmp = next_op->getNextNode(); + if (tmp && llvm::isa(tmp)) { if (!bts_op.getResult().hasOneUse()) { - return failure(); + return rewriter.notifyMatchFailure( + bts_op, "result for current op has more than 1 use"); } + if (!next_op->hasOneUse() || + *(next_op->getResult(0).user_begin()) != tmp) { + return rewriter.notifyMatchFailure( + next_op, "op's result isn't directly consumed by the next op"); + } + next_op = tmp; biasadd_op = llvm::cast(next_op); final_op_is_bts = false; } } else { - return failure(); + return rewriter.notifyMatchFailure( + next_op, "next op is neither BiasAdd nor BatchToSpaceND"); } llvm::Optional dilations_attr = ExtractDilationsAttrFromBlockShape( stb_op.block_shape(), bts_op.block_shape(), expand_axis, rewriter); - if (!dilations_attr.hasValue()) return failure(); + if (!dilations_attr.hasValue()) { + return rewriter.notifyMatchFailure(op, "failed to extract dilation rate"); + } if (expand_op) { if (stb_op.input().getType().dyn_cast() == nullptr) { - return failure(); + return rewriter.notifyMatchFailure( + stb_op, "SpaceToBatchND op's input should have RankedTensorType"); } } @@ -255,16 +351,33 @@ LogicalResult ConvertTFDilatedConvOp::matchAndRewrite( auto stb_paddings = stb_op.paddings(); auto bts_crops = bts_op.crops(); ElementsAttr stb_paddings_attr, bts_crops_attr; - if (matchPattern(stb_paddings, m_Constant(&stb_paddings_attr)) && - matchPattern(bts_crops, m_Constant(&bts_crops_attr))) { - if (stb_paddings_attr.getNumElements() != bts_crops_attr.getNumElements()) - return failure(); - // padding - crop. - auto paddings = stb_paddings_attr.getValues(); - auto crops = bts_crops_attr.getValues(); - for (auto it1 = paddings.begin(), it2 = crops.begin(); - it1 != paddings.end() && it2 != crops.end(); it1++, it2++) { - if ((*it1).getInt() != (*it2).getInt()) { + if (!matchPattern(stb_paddings, m_Constant(&stb_paddings_attr)) || + !matchPattern(bts_crops, m_Constant(&bts_crops_attr))) { + return rewriter.notifyMatchFailure( + op, + "either SpaceToBatchND or BatchToSpaceND " + "doesn't have constant padding/crops value"); + } + if (stb_paddings_attr.getType() != bts_crops_attr.getType()) { + return rewriter.notifyMatchFailure( + stb_op, + "SpaceToBatchND op's padding doesn't have same shape/type with " + "BatchToSpaceND op's crops"); + } + int64_t m = stb_paddings_attr.getType().getDimSize(0); + // padding - crop. + for (uint64_t i = 0; i < m; ++i) { + for (uint64_t j = 0; j < 2; ++j) { + // `crops` tensor has shape [M, 2], crops[i] = [crop_start, crop_end] + // specifies the amount to crop from input dimension i + 1. If the input + // of `BatchToSpaceND` has been padded explicitly, then we need to + // take into account the additional padding when determining the padding + // scheme for `Conv2D`. + int64_t addtional_pad = + pad_attr ? pad_attr.getValue({i + 1, j}).getInt() : 0; + if (stb_paddings_attr.getValue({i, j}).getInt() + + addtional_pad != + bts_crops_attr.getValue({i, j}).getInt()) { op->setAttr("padding", rewriter.getStringAttr("SAME")); break; } @@ -316,7 +429,11 @@ LogicalResult ConvertTFDilatedConvOp::matchAndRewrite( } if (final_op_is_bts) { - bts_op.getResult().replaceAllUsesWith(bts_op.input()); + if (bts_op.input().getDefiningOp()) { + bts_op.getResult().replaceAllUsesWith(pad_op.input()); + } else { + bts_op.getResult().replaceAllUsesWith(bts_op.input()); + } } stb_op.getResult().dropAllUses(); diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc index 42494e54564..35fb7c84ecb 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc @@ -816,7 +816,7 @@ struct ConvertIdentity : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { Value input = operands[0]; rewriter.replaceOpWithNewOp(op, input.getType(), operands, - op.getAttrs()); + op->getAttrs()); return success(); } }; @@ -948,7 +948,7 @@ struct ConvertWhile : public OpConversionPattern { // Create a new while op with new operands and updated result types. auto converted = rewriter.create(op.getLoc(), result_types, - operands, op.getAttrs()); + operands, op->getAttrs()); converted.removeAttr("T"); (void)UpdateFunctionTypes(rewriter, converted, tensor_list_args); @@ -972,7 +972,7 @@ struct ConvertWhileRegion : public OpConversionPattern { // Create a new while op with new operands and updated result types. auto converted = rewriter.create( - op.getLoc(), result_types, operands, op.getAttrs()); + op.getLoc(), result_types, operands, op->getAttrs()); // Inline the regions from the old while into the new one, and apply // signature conversion to inlined region. diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc index 35b159f4022..d8b9145994a 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc @@ -210,6 +210,50 @@ bool CanOptimizeIdentityGatherNdOrScatterNdOp(Value params, return true; } +// Returns true if we can eliminate the SliceOp. When the values of `begin` are +// all 0s and `size[i]` is equal to either -1 or `input.shape[i]` +// for each dim i, the output tensor is identical to `input`. +bool CanOptimizeIdentitySliceOp(Value input, Attribute begin, Attribute size) { + // Checks if `begin` and `size` are i32 or i64. + auto begin_attr = begin.dyn_cast(); + auto size_attr = size.dyn_cast(); + if (!begin_attr || !size_attr) { + return false; + } + + auto begin_elem_ty = begin_attr.getType().getElementType(); + if (!begin_elem_ty.isInteger(32) && !begin_elem_ty.isInteger(64)) { + return false; + } + auto size_elem_ty = size_attr.getType().getElementType(); + if (!size_elem_ty.isInteger(32) && !size_elem_ty.isInteger(64)) { + return false; + } + + // Checks if `input` is ranked and its rank is equal to number of elements in + // `begin` and `size`. + auto input_ty = input.getType().cast(); + if (!input_ty.hasRank()) { + return false; + } + + int64_t rank = input_ty.getRank(); + if (rank != begin_attr.getNumElements() || + rank != size_attr.getNumElements()) { + return false; + } + + // Checks if `begin` is all 0s, and `size[i]` is equal to either -1 or + // `input.shape[i]`. + for (uint64_t i = 0; i < rank; ++i) { + if (begin_attr.getValue({i}).getSExtValue() != 0) return false; + int64_t si = size_attr.getValue({i}).getSExtValue(); + if (si != -1 && si != input_ty.getDimSize(i)) return false; + } + + return true; +} + // Expand Attribute 'a' to 4D with all 1s except 1 dimension. // Which dimension depends on 'is_depthwise' is true or false. ElementsAttr ExpandTo4DForConvImpl(Attribute a, bool is_depthwise) { diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index 8cdf9b61308..00d0bc0049c 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -367,6 +367,19 @@ def OperandsBroadcastToOutputType : Constraint>; +def Flatten : NativeCodeCall< + "$0.cast()" + ".reshape(RankedTensorType::get({$0.getType().cast().getNumElements()}, " + "$0.getType().cast().getElementType()))">; + +def IsLastDimEqualToNumElements : Constraint().getRank() >= 1 && " + "$0.getType().cast().getDimSize($0.getType().cast().getRank() - 1) == " + "$1.getType().cast().getNumElements()">>; + +def IsDefinedByFullyConnectedOp : Constraint() != nullptr">>; + // Pattern for skipping Tile if it is mainly for broadcasting and the // Op is already supporting broadcasting. multiclass FuseTileBroadcastIntoFollowingBinary { @@ -446,6 +459,29 @@ foreach BinaryOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp] in { (IsTailOfShape $lhs, $rhs), (IsTailOfShape $input1, $input2), (IsTailOfShape $input2, $input1)]>; + + // Move binary op before reshape: + // binary(reshape(lhs), rhs) => reshape(binary(lhs, flatten(rhs))) + // This is valid only when the last dimension of lhs is equal to the + // number of elements in constant rhs. + // Therefore, after transformation broadcast of binary op is always + // applied to the last dimension of $input. + def MoveBinaryOpFlattenConstBeforeReshape#BinaryOp : Pat< + (BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)), + (ConstantOp:$rhs ElementsAttr:$rhs_attr), $act_fn), + (TFL_ReshapeOp (BinaryOp $input, (ConstantOp (Flatten $rhs_attr)), + $act_fn), + $shape), + [(AnyStaticShapeTensor $input), + (IsTailOfShape $rhs, $lhs), + (IsLastDimEqualToNumElements $input, $rhs), + (HasOneUse $lhs), + // Restrict operands to have at most rank 4 because TFLite binary + // kernel supports up to 4D broadcast. + (HasRankAtMost<4> $input), + (HasRankAtMost<4> $lhs), + (HasRankAtMost<4> $rhs), + (IsDefinedByFullyConnectedOp $input)]>; } foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp, @@ -491,6 +527,28 @@ foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp, (IsTailOfShape $lhs, $rhs), (IsTailOfShape $input1, $input2), (IsTailOfShape $input2, $input1)]>; + + // Move binary op before reshape: + // binary(reshape(lhs), rhs) => reshape(binary(lhs, flatten(rhs))) + // This is valid only when the last dimension of lhs is equal to the + // number of elements in constant rhs. + // Therefore, after transformation broadcast of binary op is always + // applied to the last dimension of $input. + def MoveBinaryOpFlattenConstBeforeReshape#BinaryOp : Pat< + (BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)), + (ConstantOp:$rhs ElementsAttr:$rhs_attr)), + (TFL_ReshapeOp (BinaryOp $input, (ConstantOp (Flatten $rhs_attr))), + $shape), + [(AnyStaticShapeTensor $input), + (IsTailOfShape $rhs, $lhs), + (IsLastDimEqualToNumElements $input, $rhs), + (HasOneUse $lhs), + // Restrict operands to have at most rank 4 because TFLite binary + // kernel supports up to 4D broadcast. + (HasRankAtMost<4> $input), + (HasRankAtMost<4> $lhs), + (HasRankAtMost<4> $rhs), + (IsDefinedByFullyConnectedOp $input)]>; } // Reorder the element-wise value operations and the element move operations, @@ -760,17 +818,11 @@ foreach ArgMinMaxOp = [TFL_ArgMinOp, TFL_ArgMaxOp] in { (AxesIsLastDimension $axes, $logits)]>; } -class AllElementsAreI64 : Constraint() && " - "$0.cast().getType().cast().getElementType().isInteger(64) && " - "std::all_of($0.cast().getValues().begin(), " - "$0.cast().getValues().end(), " - "[](int64_t v){ return v == " #val# ";}))">>; +def CanOptimizeIdentitySliceOp : Constraint>; -// Remove Slice ops slicing the whole input tensor, effectively no-op +// Remove Slice ops slicing the whole input tensor, effectively no-op. def OptimizeSliceOp : Pat< - (TFL_SliceOp:$output $input, (ConstantOp $begin), $shape), + (TFL_SliceOp:$output $input, (ConstantOp $begin), (ConstantOp $size)), (replaceWithValue $input), - [(AllElementsAreI64<"0"> $begin), - (IsTailOfShape $input, $output), - (IsTailOfShape $output, $input)]>; + [(CanOptimizeIdentitySliceOp $input, $begin, $size)]>; diff --git a/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc b/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc index f1362326ec5..a3060f39dfc 100644 --- a/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc +++ b/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc @@ -254,7 +254,7 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) { new_types.push_back(extra_operand.getType()); auto new_while_op = OpBuilder(while_op).create( - while_op.getLoc(), new_types, operands, while_op.getAttrs()); + while_op.getLoc(), new_types, operands, while_op->getAttrs()); new_while_op.cond().takeBody(while_op.cond()); new_while_op.body().takeBody(while_op.body()); while_op.replaceAllUsesWith( diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc index 28e3b0799f2..0b9aa400734 100644 --- a/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc @@ -62,8 +62,7 @@ FuncOp createLstmCompositeFunc(mlir::Builder* builder, bool ln, bool cifg) { auto func_type = builder->getFunctionType(input_types, output_type); auto func = - FuncOp::create(mlir::NameLoc::get(builder->getIdentifier("fused_func"), - builder->getContext()), + FuncOp::create(mlir::NameLoc::get(builder->getIdentifier("fused_func")), "fused_func", func_type, {}); func.addEntryBlock(); diff --git a/tensorflow/compiler/mlir/lite/utils/perception_ops_utils_test.cc b/tensorflow/compiler/mlir/lite/utils/perception_ops_utils_test.cc index 523545f656f..34e8cbb3cc2 100644 --- a/tensorflow/compiler/mlir/lite/utils/perception_ops_utils_test.cc +++ b/tensorflow/compiler/mlir/lite/utils/perception_ops_utils_test.cc @@ -38,8 +38,7 @@ FuncOp createMaxUnpoolingFunc( const SmallVector& output_types) { auto func_type = builder->getFunctionType(input_types, output_types); auto func = - FuncOp::create(mlir::NameLoc::get(builder->getIdentifier("fused_func"), - builder->getContext()), + FuncOp::create(mlir::NameLoc::get(builder->getIdentifier("fused_func")), "fused_func", func_type, {}); func.addEntryBlock(); diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index a4dd8f114c2..abcca678a3d 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -163,11 +163,11 @@ gentbl( tf_ops_category_list = [ { "name": "ops_a_m", - "include": "tf.[A-M].*$$", + "include": "tf.[A-M].*$", }, { "name": "ops_n_z", - "include": "tf.[N-Z].*$$", + "include": "tf.[N-Z].*$", }, ] @@ -177,11 +177,11 @@ tf_ops_category_list = [ compatible_with = get_compatible_with_cloud(), tbl_outs = [ ( - "-gen-op-decls -op-include-regex='" + target["include"] + "'", + "-gen-op-decls -op-include-regex=" + target["include"], "ir/tf_" + target["name"] + ".h.inc", ), ( - "-gen-op-defs -op-include-regex='" + target["include"] + "'", + "-gen-op-defs -op-include-regex=" + target["include"], "ir/tf_" + target["name"] + ".cc.inc", ), ], @@ -198,11 +198,11 @@ gentbl( compatible_with = get_compatible_with_cloud(), tbl_outs = [ ( - "-gen-op-decls -op-exclude-regex='" + "|".join([target["include"] for target in tf_ops_category_list]) + "' ", + "-gen-op-decls -op-exclude-regex=" + "|".join([target["include"] for target in tf_ops_category_list]), "ir/tf_remaining_ops.h.inc", ), ( - "-gen-op-defs -op-exclude-regex='" + "|".join([target["include"] for target in tf_ops_category_list]) + "' ", + "-gen-op-defs -op-exclude-regex=" + "|".join([target["include"] for target in tf_ops_category_list]), "ir/tf_remaining_ops.cc.inc", ), ], @@ -964,6 +964,7 @@ cc_library( "transforms/tpu_space_to_depth_pass.cc", "transforms/tpu_update_embedding_enqueue_op_inputs.cc", "transforms/tpu_variable_runtime_reformatting.cc", + "transforms/verify_suitable_for_graph_export_pass.cc", "translate/breakup-islands.cc", "translate/tf_executor_to_functional.cc", "translate/tf_functional_to_executor.cc", @@ -1010,6 +1011,7 @@ cc_library( ":translate_utils", ":unroll_batch_matmul_pass", ":verification_utils", + ":verify_suitable_for_graph_export", ":visitor_util", ":xla_sharding_util", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", @@ -1151,6 +1153,7 @@ cc_library( ":tf_saved_model_passes", ":translate_utils", ":upgrade_graph", + ":verify_suitable_for_graph_export", "//tensorflow/cc/saved_model:bundle_v2", "//tensorflow/cc/saved_model:constants", "//tensorflow/cc/saved_model:loader_lite", @@ -2106,3 +2109,14 @@ cc_library( "@llvm-project//mlir:Support", ], ) + +cc_library( + name = "verify_suitable_for_graph_export", + srcs = ["utils/verify_suitable_for_graph_export.cc"], + hdrs = ["utils/verify_suitable_for_graph_export.h"], + deps = [ + ":tensorflow", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc index de7861c19b1..8e3c7cf6642 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc @@ -226,8 +226,7 @@ bool OpIsKnownToHaveNoSideEffect(Operation* op) { if (isa(op)) return true; // For op's in the Tensorflow dialect, query the dialect. - if (op->getName().getDialect() == - TF::TensorFlowDialect::getDialectNamespace()) + if (isa_and_nonnull(op->getDialect())) return !TensorFlowDialect::CanHaveSideEffects(op); // Otherwise, conservatively assume that there can be side effects. diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc index 739b24e159a..68d1ae2bc83 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc @@ -413,7 +413,7 @@ void Print(ReplicateOp op, OpAsmPrinter* p) { // Skip derived `operand_segment_sizes` attribute as custom print format of // operands holds enough information to calculate these variadic operand list // lengths. - p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/ArrayRef{ + p->printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/ArrayRef{ kOperandSegmentSizesAttr}); p->printRegion(op.body(), /*printEntryBlockArgs=*/false); } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc index ea86df310e1..4b76ce0f9d4 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc @@ -213,7 +213,7 @@ LogicalResult Verify(GraphOp graph) { void Print(GraphOp graph, OpAsmPrinter &p) { p << graph.getOperationName(); p.printRegion(graph.getOperation()->getRegion(0)); - p.printOptionalAttrDict(graph.getAttrs()); + p.printOptionalAttrDict(graph->getAttrs()); } ParseResult ParseGraphOp(OpAsmParser &parser, OperationState &result) { @@ -321,7 +321,7 @@ void Print(IslandOp op, OpAsmPrinter &p) { // Check if we can print the short "wraps" form: that is if the island // contains a single operation and the result of this operation are perfectly // forwarded to the yield. - if (op.getAttrs().empty() && op.WrapsSingleOp()) { + if (op->getAttrs().empty() && op.WrapsSingleOp()) { Operation &wrapped_op = op.GetBody().front(); YieldOp yield_op = op.GetYield(); // The "wraps" syntax only encodes a single location. @@ -335,7 +335,7 @@ void Print(IslandOp op, OpAsmPrinter &p) { } } p.printRegion(op.getOperation()->getRegion(0)); - p.printOptionalAttrDict(op.getAttrs()); + p.printOptionalAttrDict(op->getAttrs()); } ParseResult ParseIslandOp(OpAsmParser &parser, OperationState &result) { @@ -449,7 +449,7 @@ void Print(SwitchOp switch_op, OpAsmPrinter &p) { } else { p << switch_op.getType(0); } - p.printOptionalAttrDict(switch_op.getAttrs()); + p.printOptionalAttrDict(switch_op->getAttrs()); } } // anonymous namespace @@ -525,7 +525,7 @@ void Print(SwitchNOp switchn, OpAsmPrinter &p) { p << ")"; } p << " : " << switchn.getType(0); - p.printOptionalAttrDict(switchn.getAttrs(), {"num_outs"}); + p.printOptionalAttrDict(switchn->getAttrs(), {"num_outs"}); } ParseResult ParseSwitchNOp(OpAsmParser &parser, OperationState &result) { @@ -655,7 +655,7 @@ void Print(MergeOp merge, OpAsmPrinter &p) { p << output_type; } - p.printOptionalAttrDict(merge.getAttrs()); + p.printOptionalAttrDict(merge->getAttrs()); } ParseResult ParseMergeOp(OpAsmParser &parser, OperationState &result) { @@ -723,7 +723,7 @@ void Print(EnterOp enter, OpAsmPrinter &p) { p << enter.getType(0); } - p.printOptionalAttrDict(enter.getAttrs(), + p.printOptionalAttrDict(enter->getAttrs(), {"frame_name", "parallel_iterations", "is_constant"}); } @@ -843,7 +843,7 @@ void Print(ExitOp exit, OpAsmPrinter &p) { p << exit.getOperationName() << ' '; p.printOperands(exit.getOperands()); p << " : " << exit.getType(0); - p.printOptionalAttrDict(exit.getAttrs()); + p.printOptionalAttrDict(exit->getAttrs()); } ParseResult ParseExitOp(OpAsmParser &parser, OperationState &result) { @@ -887,7 +887,7 @@ void Print(LoopCondOp loop_cond, OpAsmPrinter &p) { p << " : " << loop_cond.input().getType(); } - p.printOptionalAttrDict(loop_cond.getAttrs()); + p.printOptionalAttrDict(loop_cond->getAttrs()); } ParseResult ParseLoopCondOp(OpAsmParser &parser, OperationState &result) { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 007840a4598..a7b28a83a86 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -3887,6 +3887,8 @@ Returns the input tensor otherwise. ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + + let hasFolder = 1; } def TF_EqualOp : TF_Op<"Equal", [Commutative, NoSideEffect]> { @@ -8563,6 +8565,28 @@ the result here is consistent with a truncating divide. E.g. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_ModelDatasetOp : TF_Op<"ModelDataset", [NoSideEffect]> { + let summary = "Identity transformation that models performance."; + + let description = [{ +Identity transformation that models performance. + }]; + + let arguments = (ins + Arg:$input_dataset, + + DefaultValuedAttr:$algorithm, + DefaultValuedAttr:$cpu_budget, + DefaultValuedAttr:$ram_budget, + Confined]>:$output_types, + Confined]>:$output_shapes + ); + + let results = (outs + TF_VariantTensor:$handle + ); +} + def TF_MulOp : TF_Op<"Mul", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary, TF_SameOperandsAndResultElementTypeResolveRef]>, WithBroadcastableBinOpBuilder { let summary = "Returns x * y element-wise."; @@ -9217,6 +9241,31 @@ def TF_OnesLikeOp : TF_Op<"OnesLike", [Idempotent, NoSideEffect, SameOperandsAnd TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_OptimizeDatasetV2Op : TF_Op<"OptimizeDatasetV2", [NoSideEffect]> { + let summary = [{ +Creates a dataset by applying related optimizations to `input_dataset`. + }]; + + let description = [{ +Creates a dataset by applying related optimizations to `input_dataset`. + }]; + + let arguments = (ins + Arg:$input_dataset, + Arg:$optimizations_enabled, + Arg:$optimizations_disabled, + Arg:$optimizations_default, + + Confined]>:$output_types, + Confined]>:$output_shapes, + DefaultValuedAttr:$optimization_configs + ); + + let results = (outs + TF_VariantTensor:$handle + ); +} + def TF_OptionalGetValueOp : TF_Op<"OptionalGetValue", [NoSideEffect]> { let summary = [{ Returns the value stored in an Optional variant or raises an error if none exists. @@ -9305,6 +9354,8 @@ This is the opposite of `unpack`. return Verify(*this); }]; + let hasCanonicalizer = 1; + let hasFolder = 1; } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc index a7589e83fc3..281d3da545e 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc @@ -888,7 +888,7 @@ class CaseOrIfRegionEliminatePassThrough // Create new case/if region op. auto new_op = rewriter.create( - op.getLoc(), new_result_types, op.getOperand(), op.getAttrs(), + op.getLoc(), new_result_types, op.getOperand(), op->getAttrs(), op.getNumRegions()); int next_index = 0; @@ -2092,6 +2092,19 @@ static LogicalResult Verify(EmptyTensorListOp op) { return success(); } +//===----------------------------------------------------------------------===// +// EnsureShapeOp +//===----------------------------------------------------------------------===// + +OpFoldResult EnsureShapeOp::fold(llvm::ArrayRef) { + ShapedType type = input().getType().dyn_cast(); + if (!type || !type.hasRank()) return {}; + // If shape attribute equals input operand's type's shape, fold it to input. + if (type.getShape() == shape()) return input(); + // Else retain to enable failing dynamically. + return {}; +} + //===----------------------------------------------------------------------===// // EqualOp //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc index 72204daba92..6aa3e20da86 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc @@ -302,6 +302,49 @@ OpFoldResult PackOp::fold(ArrayRef operands) { return slice_op.input(); } +// Convert Pack to Reshape when there is only one operand to be packed. +// For example, +// +// %0 = tf.Pack(%input) {axis = 0} // %input : tensor<2x3xf32> +// +// can be canonicalized to +// +// %shape = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi64>} +// %0 = tf.Reshape(%input, %shape) +struct ConvertPackToReshape : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(PackOp pack_op, + PatternRewriter &rewriter) const override { + // Check if there is only one operand to be packed. + if (pack_op.N() != 1) { + return failure(); + } + + // Check if input and output are static. + auto input_ty = pack_op.getOperand(0).getType().cast(); + auto output_ty = pack_op.output().getType().cast(); + if (!input_ty.hasStaticShape() || !output_ty.hasStaticShape()) { + return failure(); + } + + // Create constant shape for reshape. + auto type = + RankedTensorType::get(output_ty.getRank(), rewriter.getIntegerType(64)); + auto shape_attr = DenseIntElementsAttr::get(type, output_ty.getShape()); + auto shape = rewriter.create(pack_op.getLoc(), shape_attr); + + rewriter.replaceOpWithNewOp(pack_op, output_ty, + pack_op.getOperand(0), shape); + return success(); + } +}; + +void PackOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // PadOp //===----------------------------------------------------------------------===// @@ -2587,7 +2630,7 @@ class ConvertFusedBatchNorm : public OpRewritePattern { TF::FusedBatchNormV3Op::getOperationName(), tf_fused_batch_norm_op.getOperands(), new_result_types, - tf_fused_batch_norm_op.getAttrs()); + tf_fused_batch_norm_op->getAttrs()); Operation *tf_fused_batch_norm_op_v3 = rewriter.createOperation(new_state); rewriter.replaceOp(tf_fused_batch_norm_op, @@ -3029,9 +3072,9 @@ struct WhileRegionEliminatePassThrough } // Create the new while operation. - auto new_while_op = - rewriter.create(while_op.getLoc(), new_result_types, - new_while_operands, while_op.getAttrs()); + auto new_while_op = rewriter.create( + while_op.getLoc(), new_result_types, new_while_operands, + while_op->getAttrs()); // Move region bodies to the new while. rewriter.inlineRegionBefore(while_op.cond(), new_while_op.cond(), diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir index 002ee651c7e..fbf1529fa96 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir @@ -1643,3 +1643,39 @@ func @testFoldStridedSliceShapeWithEmptySlice(%arg0: tensor) -> (te // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> // CHECK: return %[[CST]] } + +// CHECK-LABEL: testFoldEnsureShapeOp +func @testFoldEnsureShapeOp(%arg0: tensor<10x20xf32>) -> (tensor<10x20xf32>, tensor<20x10xf32>) { + %0 = "tf.EnsureShape"(%arg0) {shape = #tf.shape<10x20>} : (tensor<10x20xf32>) -> tensor<10x20xf32> + // Failing case which should not be folded. + // CHECK: %[[NF:.*]] = "tf.EnsureShape"(%arg0) {shape = #tf.shape<20x10>} + %1 = "tf.EnsureShape"(%arg0) {shape = #tf.shape<20x10>} : (tensor<10x20xf32>) -> tensor<20x10xf32> + // CHECK: return %arg0, %[[NF]] + return %0, %1: tensor<10x20xf32>, tensor<20x10xf32> +} + +// CHECK-LABEL: testConvertPackToReshapeAxis0 +func @testConvertPackToReshapeAxis0(%arg0: tensor<2x3xf32>) -> tensor<1x2x3xf32> { + %0 = "tf.Pack"(%arg0) {axis = 0 : i64} : (tensor<2x3xf32>) -> tensor<1x2x3xf32> + return %0 : tensor<1x2x3xf32> + // CHECK: %[[SHAPE:.*]] = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi64>} : () -> tensor<3xi64> + // CHECK: %[[RESHAPE:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) : (tensor<2x3xf32>, tensor<3xi64>) -> tensor<1x2x3xf32> + // CHECK: return %[[RESHAPE]] : tensor<1x2x3xf32> +} + +// CHECK-LABEL: testConvertPackToReshapeAxis1 +func @testConvertPackToReshapeAxis1(%arg0: tensor<2x3xf32>) -> tensor<2x1x3xf32> { + %0 = "tf.Pack"(%arg0) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2x1x3xf32> + return %0 : tensor<2x1x3xf32> + // CHECK: %[[SHAPE:.*]] = "tf.Const"() {value = dense<[2, 1, 3]> : tensor<3xi64>} : () -> tensor<3xi64> + // CHECK: %[[RESHAPE:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) : (tensor<2x3xf32>, tensor<3xi64>) -> tensor<2x1x3xf32> + // CHECK: return %[[RESHAPE]] : tensor<2x1x3xf32> +} + +// CHECK-LABEL: testDontConvertPackToReshapeDynamicShape +func @testDontConvertPackToReshapeDynamicShape(%arg0: tensor<2x?xf32>) -> tensor<1x2x?xf32> { + %0 = "tf.Pack"(%arg0) {axis = 0 : i64} : (tensor<2x?xf32>) -> tensor<1x2x?xf32> + return %0 : tensor<1x2x?xf32> + // CHECK: %[[PACK:.*]] = "tf.Pack"(%arg0) {axis = 0 : i64} : (tensor<2x?xf32>) -> tensor<1x2x?xf32> + // CHECK: return %[[PACK]] : tensor<1x2x?xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/drop_while_shape_invariant.mlir b/tensorflow/compiler/mlir/tensorflow/tests/drop_while_shape_invariant.mlir index b20776c7c8a..f4dc5663641 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/drop_while_shape_invariant.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/drop_while_shape_invariant.mlir @@ -1,10 +1,29 @@ // RUN: tf-opt %s -tf-drop-while-shape-invariant | FileCheck %s +// RUN: tf-opt %s -tf-drop-while-shape-invariant-in-device-cluster | FileCheck -check-prefix=IN-CLUSTER %s -// CHECK-LABEL: while_shape_invariant + +func @while_cond(%arg0: tensor<*xf32>) -> tensor { + %0 = "tf.Const"() {value = dense : tensor} : () -> tensor + return %0 : tensor +} + +func @while_body(%arg0: tensor<*xf32>) -> (tensor<*xf32>) { + %0 = "tf.SomeOp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// Test that -tf-drop-while-shape-invariant-in-device-cluster pass does not drop +// the shape_invariant attribute from While/WhileRegion ops outside the device +// cluster, while the other pass drops them. + +// CHECK-LABEL: while_shape_invariant_outside_cluster // CHECK-NOT: shape_invariant -func @while_shape_invariant(%arg0: tensor<4xf32>) -> (tensor<*xf32>, tensor<*xf32>) { +// IN-CLUSTER-LABEL: while_shape_invariant_outside_cluster +func @while_shape_invariant_outside_cluster(%arg0: tensor<4xf32>) -> (tensor<*xf32>, tensor<*xf32>) { + // IN-CLUSTER: shape_invariant %0 = "tf.While"(%arg0) {cond = @while_cond, body = @while_body, is_stateless = false, shape_invariant} : (tensor<4xf32>) -> (tensor<*xf32>) + // IN-CLUSTER: shape_invariant %1 = "tf.WhileRegion"(%arg0) ( { ^cond(%carg0: tensor<*xf32>): %2 = "tf.Const"() {value = dense : tensor} : () -> tensor @@ -18,12 +37,28 @@ func @while_shape_invariant(%arg0: tensor<4xf32>) -> (tensor<*xf32>, tensor<*xf3 return %0, %1 : tensor<*xf32>, tensor<*xf32> } -func @while_cond(%arg0: tensor<*xf32>) -> tensor { - %0 = "tf.Const"() {value = dense : tensor} : () -> tensor - return %0 : tensor -} +// Test that both passes drop the shape_invariant attribute from +// While/WhileRegion ops within a cluster. -func @while_body(%arg0: tensor<*xf32>) -> (tensor<*xf32>) { - %0 = "tf.SomeOp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> - return %0 : tensor<*xf32> +// CHECK-LABEL: while_shape_invariant_within_cluster +// CHECK-NOT: shape_invariant +// IN-CLUSTER-LABEL: while_shape_invariant_within_cluster +// IN-CLUSTER-NOT: shape_invariant +func @while_shape_invariant_within_cluster(%arg0: tensor<4xf32>) { + "tf_device.cluster"() ( { + %0 = "tf.While"(%arg0) {cond = @while_cond, body = @while_body, is_stateless = false, shape_invariant} : (tensor<4xf32>) -> (tensor<*xf32>) + + %1 = "tf.WhileRegion"(%arg0) ( { + ^cond(%carg0: tensor<*xf32>): + %2 = "tf.Const"() {value = dense : tensor} : () -> tensor + "tf.Yield"(%2) : (tensor) -> () + }, { + ^body(%barg0: tensor<*xf32>): + %2 = "tf.SomeOp"(%barg0) : (tensor<*xf32>) -> tensor<*xf32> + "tf.Yield"(%2) : (tensor<*xf32>) -> () + }) {is_stateless = false, shape_invariant} : (tensor<4xf32>) -> (tensor<*xf32>) + tf_device.return + }) {} : () -> () + + return } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-regions.mlir b/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-regions.mlir index 69dc04ea303..7d4df4217c0 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-regions.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-regions.mlir @@ -213,11 +213,11 @@ func @testWhileResult(tensor<*xf32>) -> (tensor<*xf32>) { } : (tensor<*xf32>) -> (tensor<*xf32>) // CHECK: [[Result0:%.*]] = "tf.WhileRegion" - // CHECK: [[ResultCast0:%.*]] = "tf.Cast" - // CHECK: [[Result1:%.*]] = call @testWhileCond([[ResultCast0]]) + // CHECK: ^bb0(%[[CARG0:.*]]: tensor<4xf32> + // CHECK: [[Result1:%.*]] = call @testWhileCond(%[[CARG0]]) // CHECK: "tf.Yield"([[Result1]]) - // CHECK: [[ResultCast1:%.*]] = "tf.Cast" - // CHECK: [[Result2:%.*]] = call @testWhileBody([[ResultCast1]]) + // CHECK: ^bb0(%[[BARG0:.*]]: tensor<4xf32> + // CHECK: [[Result2:%.*]] = call @testWhileBody(%[[BARG0]]) // CHECK: "tf.Yield"([[Result2]]) // CHECK: return [[Result0]] return %1 : tensor<*xf32> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/batch_use_same_function/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/batch_use_same_function/BUILD new file mode 100644 index 00000000000..c925849e87f --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/batch_use_same_function/BUILD @@ -0,0 +1,34 @@ +load("//tensorflow:tensorflow.bzl", "filegroup") +load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") + +licenses(["notice"]) + +glob_lit_tests( + data = [ + ":debug_info_files", + ":test_utilities", + ], + driver = "@llvm-project//mlir:run_lit.sh", + test_file_exts = ["pbtxt"], +) + +# Bundle together all of the test utilities that are used by tests. +filegroup( + name = "test_utilities", + testonly = True, + data = [ + "//tensorflow/compiler/mlir:tf-mlir-translate", + "@llvm-project//llvm:FileCheck", + "@llvm-project//llvm:not", + ], +) + +# Bundle together all the debug info files that are used by the tests. +filegroup( + name = "debug_info_files", + srcs = glob( + [ + "**/*.debug", + ], + ), +) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/batch_use_same_function/saved_model.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/batch_use_same_function/saved_model.pbtxt new file mode 100644 index 00000000000..44b559f3ba3 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/batch_use_same_function/saved_model.pbtxt @@ -0,0 +1,137 @@ +#RUN: tf-mlir-translate --savedmodel-signaturedefs-to-mlir-lite -tf-savedmodel-tags=serve,tpu %p | FileCheck %s + +# Test importing a saved model with 2 signatures that are using a same +# BatchFunction Op, which references to a same inference_func from graph_def +# library. The result should be that both signatures uses the same +# BatchFunction Op (the shared_name is the same) and the same copy of +# inference_func. + +# CHECK: func @predict0 +# CHECK: f = @inference_func[[post_fix:[^,]*]] +# CHECK-SAME: shared_name = "batch" + +# CHECK: func @predict1 +# CHECK: f = @inference_func[[post_fix]], +# CHECK-SAME: shared_name = "batch" + +meta_graphs: { + meta_info_def: { + tags: ["serve", "tpu"] + } + graph_def: { + node: { + name: "input0" + op: "Placeholder" + attr: { + key: "dtype" + value: { + type: DT_STRING + } + } + } + node: { + name: "input1" + op: "Placeholder" + attr: { + key: "dtype" + value: { + type: DT_INT32 + } + } + } + node: { + name: "batch_func" + op: "BatchFunction" + input: ["input1"] + attr: { + key: "Tcaptured" + value: { + list: { + type: [] + } + } + } + attr: { + key: "Tin" + value: { + list: { + type: [DT_INT32] + } + } + } + attr: { + key: "Tout" + value: { + list: { + type: [DT_FLOAT, DT_FLOAT] + } + } + } + attr: { + key: "f" + value: { + func: { + name: "inference_func" + } + } + } + attr: { + key: "shared_name" + value: { + s: "batch" + } + } + } + library: { + function { + signature { + name: "inference_func" + input_arg { + name: "arg0" + type: DT_FLOAT + } + } + ret { + key: "retval0" + value: "arg0" + } + } + } + } + signature_def: { + key: "predict0" + value: { + inputs: { + key: "inputs" + value: { + name: "input0" + dtype: DT_STRING + } + } + outputs: { + key: "outputs" + value: { + name: "batch_func:0" + } + } + } + } + signature_def: { + key: "predict1" + value: { + inputs: { + key: "tf_example_input" + value: { + name: "input0" + dtype: DT_STRING + } + } + outputs: { + key: "outputs" + value: { + name: "batch_func:1" + } + } + } + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/empty-input-shapes.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/empty-input-shapes.pbtxt new file mode 100644 index 00000000000..61ccf82af77 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/empty-input-shapes.pbtxt @@ -0,0 +1,39 @@ +# RUN: tf-mlir-translate -graphdef-to-mlir %s + +node { + name: "input" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_BOOL + } + } +} + +node { + name: "func0" + op: "func_name" + input: "input" +} + +library { + function { + signature { + name: "func_name" + input_arg { + name: "arg0" + type: DT_BOOL + } + } + ret { + key: "retval0" + value: "arg0" + } + attr: { + key: "_input_shapes" + value: { + } + } + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/partial-device-name.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/partial-device-name.pbtxt new file mode 100644 index 00000000000..51a7fe35dfb --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/partial-device-name.pbtxt @@ -0,0 +1,76 @@ +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=input0,input1 -tf-input-data-types=DT_INT32,DT_INT32 -tf-input-shapes=10:10 -tf-output-arrays=Sub -o - | FileCheck %s + +node { + name: "Add" + op: "Add" + input: "input0" + input: "input1" + # If device type or id doesn't exist, assign a default one (device:CPU:0). + device: "/job:localhost/replica:0/task:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "Mul" + op: "Mul" + input: "Add" + input: "Add" + # Empty device name should be kept untouched. + device: "" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "Sub" + op: "Sub" + input: "Add" + input: "Mul" + # Device name is not modified if complete + device: "/job:localhost/replica:0/task:0/device:CPU:1" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "input0" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } +} +node { + name: "input1" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } +} +versions { + producer: 27 +} + +# CHECK-LABEL: func @main +# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<10xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<10xi32>) -> tensor<*xi32> +# CHECK-SAME: control_outputs = "" +# CHECK-SAME: inputs = "input0,input1" +# CHECK-SAME: outputs = "Sub" +# CHECK: %[[add:.*]], %[[add_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]]) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} +# CHECK: %[[mul:.*]], %[[mul_control:.*]] = tf_executor.island wraps "tf.Mul"(%[[add]], %[[add]]) {device = ""} +# CHECK: %[[sub:.*]], %[[sub_control:.*]] = tf_executor.island wraps "tf.Sub"(%[[add]], %[[mul]]) {device = "/job:localhost/replica:0/task:0/device:CPU:1"} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_move_transposes_end.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_move_transposes_end.mlir index a391e5f4459..283a4a8d29c 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_move_transposes_end.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_move_transposes_end.mlir @@ -126,3 +126,20 @@ func @fold_into_fused_batch_norm(%arg0: tensor<1x64x112x112xf32>, %arg1: tensor< return %2#0 : tensor<1x112x112x64xf32> } + +// CHECK-LABEL: func @fold_into_pad_with_extra_uses +func @fold_into_pad_with_extra_uses(%arg0: tensor<1x2x4x4x3xf32>) -> (tensor<1x2x3x4x4xf32>, tensor<1x2x3x6x6xf32>) { + + // CHECK: %[[PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 1, 4, 2, 3]> : tensor<5xi32>} + // CHECK: %[[TRANSPOSE_OP:[0-9]*]] = "tf.Transpose"(%arg0, %[[PERM]]) + // CHECK: %[[PADDING:[0-9]*]] = "tf.Const"() {value = dense<{{\[\[}}0, 0], [0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<5x2xi32>} + // CHECK: %[[PAD_OP:[0-9]*]] = "tf.Pad"(%arg0, %[[PADDING]]) + // CHECK: %[[DUP_TRANSPOSE_OP:[0-9]*]] = "tf.Transpose"(%[[PAD_OP]], %[[PERM]]) + // CHECK: return %[[TRANSPOSE_OP]], %[[DUP_TRANSPOSE_OP]] + + %0 = "tf.Const"() {value = dense<[0, 1, 4, 2, 3]> : tensor<5xi32>} : () -> tensor<5xi32> + %1 = "tf.Const"() {value = dense<[[0, 0], [0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<5x2xi32>} : () -> tensor<5x2xi32> + %2 = "tf.Transpose"(%arg0, %0) : (tensor<1x2x4x4x3xf32>, tensor<5xi32>) -> tensor<1x2x3x4x4xf32> + %3 = "tf.Pad"(%2, %1) : (tensor<1x2x3x4x4xf32>, tensor<5x2xi32>) -> tensor<1x2x3x6x6xf32> + return %2, %3 : tensor<1x2x3x4x4xf32>, tensor<1x2x3x6x6xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir index 120eee547ba..9f9fa73302d 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir @@ -1910,6 +1910,21 @@ func @convert_gather_nd(%arg0: tensor<98x128xf32>, %arg1: tensor<4x64xi32>) -> t return %0 : tensor<4x64x128xf32> } +// CHECK-LABEL: func @convert_gather_transpose( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<128x256xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<4x1xi32>) -> tensor<4x128xf32> { +// CHECK: %[[VAL_2:.*]] = "tf.Const"{{.*}}value = dense<[1, 0]> : tensor<2xi64> +// CHECK: %[[VAL_3:.*]] = "tf.Transpose"(%[[VAL_0]], %[[VAL_2]]) : {{.*}} -> tensor<256x128xf32> +// CHECK: %[[VAL_4:.*]] = "tf.GatherNd"(%[[VAL_3]], %[[VAL_1]]) : {{.*}} -> tensor<4x128xf32> +// CHECK: return %[[VAL_4]] +// CHECK: } +// Test the case when start_index_map isn't an iota what requires a transpose to +// convert it to tf.GatherNd. +func @convert_gather_transpose(%arg0: tensor<128x256xf32>, %arg1: tensor<4x1xi32>) -> tensor<4x128xf32> { + %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<1> : tensor<1xi64>, index_vector_dim = 1 : i64, offset_dims = dense<1> : tensor<1xi64>, start_index_map = dense<1> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[128, 1]> : tensor<2xi64>} : (tensor<128x256xf32>, tensor<4x1xi32>) -> tensor<4x128xf32> + return %0 : tensor<4x128xf32> +} + // CHECK-LABEL: func @convert_dynamic_slice( // CHECK-SAME: %[[VAL_0:.*]]: tensor<7x3xf32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor, diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/export_main_to_flib.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/export_main_to_flib.mlir new file mode 100644 index 00000000000..03fc9783d78 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/export_main_to_flib.mlir @@ -0,0 +1,20 @@ +// RUN: tf-mlir-translate -mlir-to-graphdef -tf-export-entry-func-to-flib %s -o - 2>&1 | FileCheck %s + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 458 : i32}} { + func @main() { + tf_executor.graph { + %0:2 = tf_executor.island wraps "tf.Const"() {device = "TPU:0", name = "const", dtype = "tfdtype$DT_INT32", value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32> + tf_executor.fetch + } + return + } +} + +// CHECK-NOT: node + +// CHECK: library +// CHECK-NEXT: function +// CHECK-NEXT: signature +// CHECK-NEXT: name: "main" +// CHECK: node_def +// CHECK: op: "Const" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/invalid_input.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/invalid_input.mlir index 883fbe647b9..1ee2de2c937 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/invalid_input.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/invalid_input.mlir @@ -9,7 +9,7 @@ func @main() { return } -// CHECK: Functions must be of a single Graph with single op Islands: only single block functions are supported. +// CHECK: functions must be of a single Graph with single op Islands: only single block functions are supported // ----- @@ -19,7 +19,7 @@ func @main() { return } -// CHECK: Functions must be of a single Graph with single op Islands: first op in function is not a tf_executor.graph. +// CHECK: functions must be of a single Graph with single op Islands: first op in function is not a tf_executor.graph // ----- @@ -33,7 +33,7 @@ func @main() { return } -// CHECK: Functions must be of a single Graph with single op Islands: function does not only contain a single tf_executor.graph. +// CHECK: functions must be of a single Graph with single op Islands: function does not only contain a single tf_executor.graph // ----- @@ -47,7 +47,7 @@ func @main() { return } -// CHECK: Functions must be of a single Graph with single op Islands: tf_executor.island must perfectly wrap a single op. +// CHECK: functions must be of a single Graph with single op Islands: tf_executor.island must perfectly wrap a single op // ----- @@ -63,7 +63,7 @@ func @main() { return } -// CHECK: Functions must be of a single Graph with single op Islands: tf_executor.island must perfectly wrap a single op. +// CHECK: functions must be of a single Graph with single op Islands: tf_executor.island must perfectly wrap a single op // ----- @@ -78,7 +78,7 @@ func @main() { return } -// CHECK: Functions must be of a single Graph with single op Islands: tf_executor.island must perfectly wrap a single op. +// CHECK: functions must be of a single Graph with single op Islands: tf_executor.island must perfectly wrap a single op // ----- @@ -93,4 +93,4 @@ func @main(%arg0: tensor, %arg1: tensor) { return } -// CHECK: Functions must be of a single Graph with single op Islands: tf_executor.island must perfectly wrap a single op. +// CHECK: functions must be of a single Graph with single op Islands: tf_executor.island must perfectly wrap a single op diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir index a5ac3a9f00a..151bbe0829c 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir @@ -470,8 +470,9 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr // CHECK: tf.TensorListSetItem{{.*}}: (tensor>>, tensor, tensor<2x2xf32>) -> tensor>> %6 = "tf.TensorListSetItem"(%3, %4, %5) {device = ""} : (tensor>>, tensor, tensor<2x2xf32>)-> tensor<*x!tf.variant> %7 = "tf.Const"() {device = "", value = dense<-1> : tensor} : () -> tensor + %8 = "tf.StopGradient"(%6) : (tensor<*x!tf.variant>) -> tensor<*x!tf.variant> // CHECK: tf.TensorListStack{{.*}}: (tensor>>, tensor) -> tensor - %8 = "tf.TensorListStack"(%6, %7) {device = "", num_elements = -1 : i64} : (tensor<*x!tf.variant>, tensor) -> tensor<*xf32> + %9 = "tf.TensorListStack"(%8, %7) {device = "", num_elements = -1 : i64} : (tensor<*x!tf.variant>, tensor) -> tensor<*xf32> tf_executor.yield } tf_executor.fetch diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/control_flow_duplicate_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/control_flow_duplicate_v1.py index 67563fad5d5..ab786ac8300 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/control_flow_duplicate_v1.py +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/control_flow_duplicate_v1.py @@ -27,19 +27,16 @@ from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1 # CHECK: func {{.*}} tf_saved_model.exported_names = ["key_1"] # CHECK: "tf.If" -# CHECK-SAME: else_branch = @[[else_1:"key_1/[a-zA-Z_0-9]+"]] -# CHECK-SAME: then_branch = @[[then_1:"key_1/[a-zA-Z_0-9]+"]] +# CHECK-SAME: else_branch = @[[else:[a-zA-Z_0-9]+]] +# CHECK-SAME: then_branch = @[[then:[a-zA-Z_0-9]+]] # CHECK: func {{.*}} tf_saved_model.exported_names = ["key_2"] # CHECK: "tf.If" -# CHECK-SAME: else_branch = @[[else_2:"key_2/[a-zA-Z_0-9]+"]] -# CHECK-SAME: then_branch = @[[then_2:"key_2/[a-zA-Z_0-9]+"]] +# CHECK-SAME: else_branch = @[[else]] +# CHECK-SAME: then_branch = @[[then]] -# CHECK: func private @[[else_1]]( -# CHECK: func private @[[then_1]]( - -# CHECK: func private @[[else_2]]( -# CHECK: func private @[[then_2]]( +# CHECK: func private @[[else]]( +# CHECK: func private @[[then]]( def Test(): diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc index 56ecc355646..d299cb456fa 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc @@ -39,21 +39,8 @@ void EnableLogging(PassManager *pm) { } // namespace namespace TFTPU { + namespace { -void AddGraphExportLoweringPasses(OpPassManager &pm) { - auto add_pass = [&](std::unique_ptr pass) { - pm.addNestedPass(std::move(pass)); - pm.addPass(CreateBreakUpIslandsPass()); - }; - - add_pass(CreateFunctionalToExecutorDialectConversionPass()); - add_pass(TFDevice::CreateReplicateToIslandPass()); - add_pass(TFDevice::CreateParallelExecuteToIslandsPass()); - add_pass(TFDevice::CreateLaunchToDeviceAttributePass()); - pm.addNestedPass(CreateTPUDevicePropagationPass()); - pm.addPass(createSymbolDCEPass()); -} - tensorflow::Status RunTPUBridge( ModuleOp module, bool enable_logging, llvm::function_ref pipeline_builder) { @@ -68,7 +55,7 @@ tensorflow::Status RunTPUBridge( pipeline_builder(bridge); // Add set of passes to lower back to graph (from tf_executor). - AddGraphExportLoweringPasses(bridge); + TF::AddGraphExportLoweringPasses(bridge); // Run the bridge on the module, in case of failure, the `diag_handler` // converts MLIR errors emitted to the MLIRContext into a tensorflow::Status. @@ -110,13 +97,19 @@ void CreateTPUBridgePipeline(OpPassManager &pm) { func_pm.addPass(CreateTPUHostComputationExpansionPass()); func_pm.addPass(CreateTPUUpdateEmbeddingEnqueueOpInputsPass()); } - // Run another shape inference pass because resource decomposition might have - // created new partial types. - pm.addPass(TF::CreateTFShapeInferencePass()); + // Note that the region-based control-flow produced here still contains // function call ops which get inlined by the subsequent inliner pass. pm.addPass(TF::CreateTFFunctionalControlFlowToRegions()); pm.addPass(mlir::createInlinerPass()); + pm.addNestedPass( + TF::CreateDropWhileShapeInvariantInDeviceClusterPass()); + // Run another shape inference pass because resource decomposition might have + // created new partial types. Also, after dropping `shape_invariant` attribute + // from While/WhileRegion ops within cluster would lead to more precise + // shapes. + pm.addPass(TF::CreateTFShapeInferencePass()); + pm.addNestedPass(createCanonicalizerPass()); pm.addPass(CreateTPUClusterCleanupAttributesPass()); pm.addPass(TFDevice::CreateResourceOpLiftingPass()); pm.addNestedPass(createCSEPass()); @@ -166,6 +159,21 @@ tensorflow::Status TPUBridgeV1Compat(ModuleOp module, bool enable_logging) { namespace TF { +void AddGraphExportLoweringPasses(OpPassManager &pm) { + auto add_pass = [&](std::unique_ptr pass) { + pm.addNestedPass(std::move(pass)); + pm.addPass(CreateBreakUpIslandsPass()); + }; + + add_pass(CreateFunctionalToExecutorDialectConversionPass()); + add_pass(TFDevice::CreateReplicateToIslandPass()); + add_pass(TFDevice::CreateParallelExecuteToIslandsPass()); + add_pass(TFDevice::CreateLaunchToDeviceAttributePass()); + pm.addPass(TFTPU::CreateTPUDevicePropagationPass()); + pm.addPass(createSymbolDCEPass()); + pm.addPass(CreateVerifySuitableForExportPass()); +} + tensorflow::Status RunBridgeWithStandardPipeline(ModuleOp module, bool enable_logging, bool enable_inliner) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge_pass.cc index 080aea2521e..c05581fd202 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge_pass.cc @@ -20,20 +20,32 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +namespace mlir { +namespace TFTPU { +extern void AddGraphExportLoweringPasses(OpPassManager &pm); +} // namespace TFTPU +} // namespace mlir + namespace { -// Registers an existing pipeline builder function. +// Registers a pipeline builder function for TF TPU bridge. mlir::PassPipelineRegistration<> tpu_pipeline( "tf-tpu-bridge", "Run all the passes involved in transforming the graph before execution so " "that it is suitable for targeting TPUs.", mlir::TFTPU::CreateTPUBridgePipeline); -// Registers an existing pipeline builder function. +// Registers a pipeline builder function for TF TPU V1 bridge. mlir::PassPipelineRegistration<> tpu_pipeline_v1( "tf-tpu-bridge-v1", "Run all the passes involved in transforming a TensorFlow V1 graph before " "execution so that it is suitable for targeting TPUs.", mlir::TFTPU::CreateTPUBridgePipelineV1); +// Registers a pipeline builder function for TF Graph export. +mlir::PassPipelineRegistration<> tpu_export( + "tf-graph-export", + "Run passes to prepare for exporting module back to TF Graph.", + mlir::TF::AddGraphExportLoweringPasses); + } // anonymous namespace diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc index 1705da3eab8..897ffe91117 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc @@ -112,7 +112,7 @@ void OutlineCluster(tf_device::ClusterOp cluster_op, SymbolTable* symbol_table, builder->setInsertionPoint(cluster_op); auto cluster_func_op = builder->create( cluster_op.getLoc(), outlined_func.getType().getResults(), - live_ins.getArrayRef(), cluster_op.getAttrs()); + live_ins.getArrayRef(), cluster_op->getAttrs()); cluster_op.replaceAllUsesWith(cluster_func_op); cluster_op.erase(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/drop_while_shape_invariant.cc b/tensorflow/compiler/mlir/tensorflow/transforms/drop_while_shape_invariant.cc index cd55dbcf6ed..0a433f4e030 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/drop_while_shape_invariant.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/drop_while_shape_invariant.cc @@ -15,6 +15,7 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" namespace mlir { @@ -32,22 +33,50 @@ class DropWhileShapeInvariantPass void runOnFunction() override; }; +// Drop `shape_invariant` attribute from tf.While and tf.WhileRegion op only +// inside device cluster. This would allow shape inference pass to further +// refine operand/result shapes of these ops. This is only safe to do when +// compiling to XLA. +class DropWhileShapeInvariantInDeviceClusterPass + : public PassWrapper { + void runOnFunction() override; +}; + +void DropWhileShapeInvariantAttr(Operation* op) { + if (llvm::isa(op)) + op->removeAttr(kShapeInvariantAttr); +} void DropWhileShapeInvariantPass::runOnFunction() { - getFunction().walk([](Operation* op) { - if (llvm::isa(op)) - op->removeAttr(kShapeInvariantAttr); + getFunction().walk([](Operation* op) { DropWhileShapeInvariantAttr(op); }); +} + +void DropWhileShapeInvariantInDeviceClusterPass::runOnFunction() { + getFunction().walk([](tf_device::ClusterOp cluster) { + cluster.walk([](Operation* op) { DropWhileShapeInvariantAttr(op); }); }); } -static PassRegistration pass( +static PassRegistration drop_shape_invariant_pass( "tf-drop-while-shape-invariant", "Drop `shape_invariant` attrbute from While/WhileRegion ops."); +static PassRegistration + drop_shape_invariant_in_cluster_pass( + "tf-drop-while-shape-invariant-in-device-cluster", + "Drop `shape_invariant` attrbute from While/WhileRegion ops inside " + "device cluster."); + } // namespace std::unique_ptr> CreateDropWhileShapeInvariantPass() { return std::make_unique(); } +std::unique_ptr> +CreateDropWhileShapeInvariantInDeviceClusterPass() { + return std::make_unique(); +} + } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc index 12da23466ca..a5cb293eca9 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc @@ -63,7 +63,7 @@ YieldOp CreateCall(Operation* op, FuncOp func, Region& caller_region, Block* entry = builder.createBlock(&caller_region); if (use_region_args) { - entry->addArguments(args.getType()); + entry->addArguments(func.getType().getInputs()); args = entry->getArguments(); } llvm::SmallVector casted_args; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/fused_kernel_matcher.cc b/tensorflow/compiler/mlir/tensorflow/transforms/fused_kernel_matcher.cc index c9ea0be67b1..22e25e5eb0d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/fused_kernel_matcher.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/fused_kernel_matcher.cc @@ -156,7 +156,7 @@ class FuseContractionWithBiasAdd : public OpRewritePattern { // The fused contraction has the same attributes as the original // contraction, with two additions: the list of ops which have been fused // together; epsilon (only with FusedBatchNorm). - std::vector attrs = contraction.getAttrs(); + std::vector attrs = contraction->getAttrs(); ArrayAttr fused_ops_attr = ArrayAttr::get(context, fused_ops); attrs.push_back( NamedAttribute(Identifier::get("fused_ops", context), fused_ops_attr)); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc b/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc index baa45f65c3b..2c4817cbe3d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc @@ -96,7 +96,7 @@ struct ReluToFusedBatchNorm : public OpRewritePattern { state.addOperands(batch_norm.getOperands()); if (side_input) state.operands.push_back(side_input); state.addTypes(batch_norm.getResultTypes()); - state.addAttributes(batch_norm.getAttrs()); + state.addAttributes(batch_norm->getAttrs()); Operation *op = rewriter.createOperation(state); rewriter.replaceOp(batch_norm, op->getResults()); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc b/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc index f715890db93..135d2ae7c2b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc @@ -387,7 +387,7 @@ void MoveTransposeAfter(Operation* op, SmallVector* work_list, // Maybe add Transpose nodes for layout dependent results // (or reuse existing transposes). OpBuilder builder(op); - builder.setInsertionPoint(op); + builder.setInsertionPointAfter(op); for (unsigned idx : layout_dependent_results) { OpResult result = op->getResult(idx); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc index 5bda6b51a3f..255e7d9dffd 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc @@ -1198,18 +1198,22 @@ class ConvertGatherOp : public OpConversionPattern { return failure(); } - // Verify that start_index_map and collapsed_slice_dims are both an iota - // with the same number of elements as the last dimension of start_indices. + // Verify that start_index_map and collapsed_slice_dims contains the same + // values. auto start_index_map = gather_op.dimension_numbers().start_index_map(); auto collapsed_slice_dims = gather_op.dimension_numbers().collapsed_slice_dims(); - if (!IsIotaAttr(start_index_map, start_indices_type.getShape().back()) || - !IsIotaAttr(collapsed_slice_dims, - start_indices_type.getShape().back())) { - // TODO(tberghammer): Transform start_indices to support non-standard - // start_index_maps. + if (start_index_map.getNumElements() != + collapsed_slice_dims.getNumElements()) { return rewriter.notifyMatchFailure( - gather_op, "unsupported start index map and/or collapsed slice dims"); + gather_op, + "different size for start index map and collapsed slice dims"); + } + for (auto c : collapsed_slice_dims) { + if (llvm::count(start_index_map, c) == 0) { + return rewriter.notifyMatchFailure( + gather_op, "collapsed slice dim isn't present in start index map"); + } } // Verify that slice_sizes is 1 for the indexed dimensions and the full @@ -1217,7 +1221,7 @@ class ConvertGatherOp : public OpConversionPattern { auto slice_sizes = gather_op.slice_sizes(); int64_t index = 0; for (int64_t s : slice_sizes.getValues()) { - if (index < start_indices_type.getShape().back()) { + if (llvm::count(start_index_map, index)) { if (s != 1) { return rewriter.notifyMatchFailure(gather_op, "unsupported slice sizes"); @@ -1242,6 +1246,25 @@ class ConvertGatherOp : public OpConversionPattern { ++offset; } + // Transpose the operand to handle non-iota start index map. + llvm::SmallVector transpose_dimensions; + llvm::SmallVector transpose_shape; + for (auto s : start_index_map) { + transpose_dimensions.push_back(s.getZExtValue()); + transpose_shape.push_back(operand_type.getShape()[s.getZExtValue()]); + } + for (int64_t i = 0, e = operand_type.getRank(); i < e; ++i) { + if (llvm::count(start_index_map, i) == 0) { + transpose_dimensions.push_back(i); + transpose_shape.push_back(operand_type.getShape()[i]); + } + } + operand_type = + RankedTensorType::get(transpose_shape, operand_type.getElementType()); + operand = rewriter.create( + gather_op.getLoc(), operand_type, operand, + rewriter.getI64TensorAttr(transpose_dimensions)); + rewriter.replaceOpWithNewOp(gather_op, result_type, operand, start_indices); return success(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc index 471da7c3bb6..b5251ca2083 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc @@ -67,7 +67,6 @@ LogicalResult LiftVariablesFromSession( ModuleOp module, Session* session, const SmallSet& resource_names) { OpBuilder builder(module.getBodyRegion()); - MLIRContext* context = module.getContext(); if (!session) return module.emitOpError() << "no session provided"; @@ -137,7 +136,7 @@ LogicalResult LiftVariablesFromSession( ElementsAttr tensor_attr = tensor_attr_or.ValueOrDie(); builder.create( - NameLoc::get(builder.getIdentifier(name.str()), context), + NameLoc::get(builder.getIdentifier(name.str())), builder.getStringAttr(name), tensor_attr, TypeAttr::get(tensor_attr.getType()), builder.getUnitAttr()); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc index 5c64f754694..be31338032d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc @@ -68,7 +68,7 @@ class ResourceAnalyzer { public: explicit ResourceAnalyzer(ModuleOp module) { for (auto func : module.getOps()) { - (void)AnalyzeFunc(func); + (void)AnalyzeRegion(func.getRegion()); } } @@ -82,18 +82,18 @@ class ResourceAnalyzer { } private: - // Analyze the specified func for resource mutating operations, namely + // Analyze the specified region for resource mutating operations, namely // TF::AssignVariableOp, if so, set the resource associated as "potentially - // written". Do this recursively across the chain of funcs via call or control - // flow ops. + // written". Do this recursively across the chain of regions via call or + // control flow ops. // TODO(ashwinm): Move to iterative traversal. - LogicalResult AnalyzeFunc(FuncOp func) { + LogicalResult AnalyzeRegion(Region& region) { // Avoid infinite recursion. - if (!discovered_.insert(func).second) { + if (!discovered_.insert(®ion).second) { return success(); } - func.walk([&](Operation* op) { + region.walk([&](Operation* op) { if (isa(op)) { return; } @@ -103,23 +103,40 @@ class ResourceAnalyzer { } if (auto call = dyn_cast(op)) { if (auto func = dyn_cast(call.resolveCallable())) { - PropagatePotentiallyWrittenUpFromCallee(func, call.getArgOperands()); + PropagatePotentiallyWrittenUpFromCallee(func.getRegion(), + call.getArgOperands()); } return; } if (auto if_op = dyn_cast(op)) { for (auto callee : {if_op.then_function(), if_op.else_function()}) { - PropagatePotentiallyWrittenUpFromCallee(callee, if_op.input()); + PropagatePotentiallyWrittenUpFromCallee(callee.getRegion(), + if_op.input()); } return; } + if (auto if_op = dyn_cast(op)) { + PropagatePotentiallyWrittenUpFromCallee(if_op.then_branch(), + if_op.getODSOperands(1)); + PropagatePotentiallyWrittenUpFromCallee(if_op.else_branch(), + if_op.getODSOperands(1)); + return; + } if (auto while_op = dyn_cast(op)) { for (auto callee : {while_op.cond_function(), while_op.body_function()}) { - PropagatePotentiallyWrittenUpFromCallee(callee, while_op.input()); + PropagatePotentiallyWrittenUpFromCallee(callee.getRegion(), + while_op.input()); } return; } + if (auto while_op = dyn_cast(op)) { + PropagatePotentiallyWrittenUpFromCallee(while_op.cond(), + while_op.input()); + PropagatePotentiallyWrittenUpFromCallee(while_op.body(), + while_op.input()); + return; + } // For all other ops, we assume it mutates all resources it uses, so // this errs on the side of being conservative. We should improve // this by using either a property or a trait that clearly @@ -144,14 +161,14 @@ class ResourceAnalyzer { }); } - // Given a FuncOp associated with the callee and operands from the + // Given a Region associated with the callee and operands from the // corresponding callOp, propagate the potentially written decision to the - // callOp's operands, if the corresponding func's arguments are potentially + // callOp's operands, if the corresponding region's arguments are potentially // written resources. void PropagatePotentiallyWrittenUpFromCallee( - FuncOp func, Operation::operand_range propagate_to) { - (void)AnalyzeFunc(func); - for (auto t : llvm::zip(func.getArguments(), propagate_to)) { + Region& region, Operation::operand_range propagate_to) { + (void)AnalyzeRegion(region); + for (auto t : llvm::zip(region.getArguments(), propagate_to)) { if (!IsResource(std::get<0>(t))) { continue; } @@ -172,8 +189,8 @@ class ResourceAnalyzer { // Value: Information we know about that Value. // Note that these Value's are in general in different functions. DenseMap resource_infos_; - // The set of func's we already discovered. - DenseSet discovered_; + // The set of regions we already discovered. + DenseSet discovered_; }; bool IsImmutable(GlobalTensorOp global_tensor, @@ -231,7 +248,7 @@ void MarkGlobalTensorsImmutable( auto global_tensor = kv.first; const auto& global_tensor_uses = kv.second; if (IsImmutable(global_tensor, global_tensor_uses, resource_analyzer)) { - global_tensor.removeAttr("is_mutable"); + global_tensor->removeAttr("is_mutable"); } } } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index d80f8af6917..31438b9f6cd 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -43,6 +43,11 @@ namespace TF { // ops. std::unique_ptr> CreateDropWhileShapeInvariantPass(); +// Creates a pass that drops `shape_invariant` attribute from While/WhileRegion +// ops within device cluster. +std::unique_ptr> +CreateDropWhileShapeInvariantInDeviceClusterPass(); + // Transforms functional control flow operations in the TensorFlow dialect to // MLIR Control Flow Graph (CFG) form. std::unique_ptr> CreateTFFunctionalControlFlowToCFG(); @@ -203,6 +208,15 @@ std::unique_ptr> CreateCrossHostTransferPass(); // will replicate the tf.Const op once for each device. std::unique_ptr> CreateConstantOpDeviceAssignmentPass(); +// Populates the supplied passmanager with the passes required to export +// to TensorFlow Graph. +void AddGraphExportLoweringPasses(OpPassManager& pm); + +// Returns pass that verifies whether all functions in module are of single +// tf_executor.graph and each tf_executor.island in tf_executor.graph only has a +// single op. +std::unique_ptr> CreateVerifySuitableForExportPass(); + } // namespace TF namespace tf_executor { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc index e158fb7117f..db5f56d6ada 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc @@ -931,7 +931,7 @@ LogicalResult HandleWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) { while_op.getLoc(), body.getType().getResults(), FilterRange(while_op.getOperands(), resource_arg_uses), - while_op.getAttrs()); + while_op->getAttrs()); // Prepare for AddLoadsStoresOutsideControlFlowOp(). llvm::SmallDenseMap> arg_data_type_and_updated_output_index; @@ -1035,7 +1035,7 @@ LogicalResult HandleCaseOrIfOp(CaseOrIfOp op, ArrayRef branches) { FuncOp first_func = branches.front(); auto new_op = builder.create(op.getLoc(), first_func.getType().getResults(), - new_operands, op.getAttrs()); + new_operands, op->getAttrs()); // Prepare for AddLoadsStoresOutsideControlFlowOp() llvm::SmallDenseMap> arg_data_type_and_updated_output_index; @@ -1179,7 +1179,7 @@ void UpdatePartitionedCallOpWithNewCallee( FilterRange(call_op.args(), lifting_info.use_info); auto new_call = builder.create( call_op.getLoc(), lifting_info.lifted_callee.getType().getResults(), - new_operands, call_op.getAttrs()); + new_operands, call_op->getAttrs()); new_call->setAttr( "f", builder.getSymbolRefAttr(lifting_info.lifted_callee.getName())); AddLoadsStoresOutsideControlFlowOp( diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index 7660192f3da..fc53d17a93a 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -172,121 +172,124 @@ RankedTensorType DropFirstDimension(Type type) { bool CanInferTensorListElementType(Value tensorlist, Value initial_element_shape, RankedTensorType* potential_element_type) { + DCOMMENT("CanInferTensorListElementType " << tensorlist << " with initial " + << initial_element_shape); // Verifies if the new element type has static shape and matches the potential // type passed from caller. Updates the potential_element_type, if not defined // yet. auto verify_and_update_potential_element_type = [&](RankedTensorType new_element_type) -> bool { + DCOMMENT("\t\tConsidering " << new_element_type << " with old " + << *potential_element_type); if (!new_element_type || !new_element_type.hasStaticShape()) return false; if (!*potential_element_type) { + DCOMMENT("\t\tUpdating potential_element_type " << new_element_type); *potential_element_type = new_element_type; return true; } return *potential_element_type == new_element_type; }; - // TensorLists are semantically immutable. For example, TensorListSetItem - // takes a TensorList as input and produces a TensorList as output. So to - // traverse modifications to TensorList and verify that all elements written - // to it have the same shape, we need to follow use-def chain of ops that - // (conceptually) modify it i.e., ops that take an input TensorList and - // produce an output TensorList. - for (auto& use : tensorlist.getUses()) { - if (auto push = llvm::dyn_cast(use.getOwner())) { - auto element_type = push.tensor().getType().dyn_cast(); - if (!verify_and_update_potential_element_type(element_type)) return false; - if (!CanInferTensorListElementType(push.output_handle(), - initial_element_shape, - potential_element_type)) - return false; - continue; - } - if (auto scatter = llvm::dyn_cast( - use.getOwner())) { - // For scatter op we can get the element shape by dropping the first - // dimension of the input tensor. - RankedTensorType element_type = - DropFirstDimension(scatter.tensor().getType()); - if (!verify_and_update_potential_element_type(element_type)) return false; - if (!CanInferTensorListElementType(scatter.output_handle(), - initial_element_shape, - potential_element_type)) - return false; - continue; - } - if (auto set_item = llvm::dyn_cast(use.getOwner())) { - auto element_type = - set_item.item().getType().dyn_cast(); - if (!verify_and_update_potential_element_type(element_type)) return false; - if (!CanInferTensorListElementType(set_item.output_handle(), - initial_element_shape, - potential_element_type)) - return false; - continue; - } - if (auto pop = llvm::dyn_cast(use.getOwner())) { - if (!CanInferTensorListElementType(pop.output_handle(), - initial_element_shape, - potential_element_type)) - return false; - continue; - } - if (auto resize = llvm::dyn_cast(use.getOwner())) { - if (!CanInferTensorListElementType(resize.output_handle(), - initial_element_shape, - potential_element_type)) - return false; - continue; - } - // WhileRegionOp can explicitly capture TensorList value to be used inside - // its regions. So we check the uses of corresponding block argument in each - // region and the use of TensorList returned using YieldOp. - if (auto while_region = llvm::dyn_cast(use.getOwner())) { - for (auto branch : while_region.getRegions()) { - if (!CanInferTensorListElementType( - branch->getArgument(use.getOperandNumber()), - initial_element_shape, potential_element_type)) - return false; - } - continue; - } - if (auto yield = llvm::dyn_cast(use.getOwner())) { - Operation* parent = yield->getParentOp(); - if (!CanInferTensorListElementType( - parent->getResult(use.getOperandNumber()), initial_element_shape, - potential_element_type)) - return false; - continue; - } - // Refining the tensor list element type might change the output of - // TensorListElementShape which is expected to be the originally assigned - // shape to TensorList init ops. So replace it with the original element - // shape value. - if (auto tl_element_shape = - dyn_cast(use.getOwner())) { - // If element types match, we can do a direct replacement. - if (getElementTypeOrSelf(tl_element_shape.getResult()) == - getElementTypeOrSelf(initial_element_shape.getType())) { - tl_element_shape.replaceAllUsesWith(initial_element_shape); - } else { - OpBuilder b(use.getOwner()); - auto cast_op = b.create( - use.getOwner()->getLoc(), tl_element_shape.getResult().getType(), - initial_element_shape, - /*truncate=*/b.getBoolAttr(false)); - tl_element_shape.replaceAllUsesWith(cast_op.getResult()); - } - continue; - } - // Ignore ops that just consume a TensorList and do not output another - // TensorList. - if (isa(use.getOwner())) - continue; + std::stack worklist; + worklist.emplace(tensorlist); - // For any other unknown users of the TensorList, we are conservative and - // stop element shape inference. - return false; + while (!worklist.empty()) { + tensorlist = worklist.top(); + worklist.pop(); + + // TensorLists are semantically immutable. For example, TensorListSetItem + // takes a TensorList as input and produces a TensorList as output. So to + // traverse modifications to TensorList and verify that all elements written + // to it have the same shape, we need to follow use-def chain of ops that + // (conceptually) modify it i.e., ops that take an input TensorList and + // produce an output TensorList. + for (auto& use : tensorlist.getUses()) { + if (auto push = llvm::dyn_cast(use.getOwner())) { + auto element_type = + push.tensor().getType().dyn_cast(); + if (!verify_and_update_potential_element_type(element_type)) + return false; + worklist.emplace(push.output_handle()); + continue; + } + if (auto scatter = llvm::dyn_cast( + use.getOwner())) { + // For scatter op we can get the element shape by dropping the first + // dimension of the input tensor. + RankedTensorType element_type = + DropFirstDimension(scatter.tensor().getType()); + if (!verify_and_update_potential_element_type(element_type)) + return false; + worklist.emplace(scatter.output_handle()); + continue; + } + if (auto set_item = llvm::dyn_cast(use.getOwner())) { + auto element_type = + set_item.item().getType().dyn_cast(); + DCOMMENT("\tTensorListSetItemOp " << element_type); + if (!verify_and_update_potential_element_type(element_type)) + return false; + worklist.emplace(set_item.output_handle()); + continue; + } + if (auto pop = llvm::dyn_cast(use.getOwner())) { + worklist.emplace(pop.output_handle()); + continue; + } + if (auto resize = llvm::dyn_cast(use.getOwner())) { + worklist.emplace(resize.output_handle()); + continue; + } + // WhileRegionOp can explicitly capture TensorList value to be used inside + // its regions. So we check the uses of corresponding block argument in + // each region and the use of TensorList returned using YieldOp. + if (auto while_region = llvm::dyn_cast(use.getOwner())) { + DCOMMENT("\tTL WhileRegion"); + for (auto branch : while_region.getRegions()) + worklist.emplace(branch->getArgument(use.getOperandNumber())); + continue; + } + if (auto yield = llvm::dyn_cast(use.getOwner())) { + Operation* parent = yield->getParentOp(); + worklist.emplace(parent->getResult(use.getOperandNumber())); + continue; + } + // TODO(jpienaar): This can be generalized. + if (isa(use.getOwner())) { + worklist.emplace(use.getOwner()->getResult(use.getOperandNumber())); + continue; + } + // Refining the tensor list element type might change the output of + // TensorListElementShape which is expected to be the originally assigned + // shape to TensorList init ops. So replace it with the original element + // shape value. + if (auto tl_element_shape = + dyn_cast(use.getOwner())) { + // If element types match, we can do a direct replacement. + if (getElementTypeOrSelf(tl_element_shape.getResult()) == + getElementTypeOrSelf(initial_element_shape.getType())) { + tl_element_shape.replaceAllUsesWith(initial_element_shape); + } else { + OpBuilder b(use.getOwner()); + auto cast_op = b.create( + use.getOwner()->getLoc(), tl_element_shape.getResult().getType(), + initial_element_shape, + /*truncate=*/b.getBoolAttr(false)); + tl_element_shape.replaceAllUsesWith(cast_op.getResult()); + } + continue; + } + // Ignore ops that just consume a TensorList and do not output another + // TensorList. + if (isa(use.getOwner())) + continue; + + // For any other unknown users of the TensorList, we are conservative and + // stop element shape inference. + DCOMMENT("TensorListType infer, unknown op " << *use.getOwner()); + return false; + } } return true; } @@ -778,8 +781,12 @@ bool ShapeInference::InferShapeForTensorListInitOps(Operation* op) { if (!element_type || !element_type.hasStaticShape()) return false; } if (!CanInferTensorListElementType(handle, initial_element_shape, - &element_type)) + &element_type)) { + DCOMMENT("InferShapeForListInitOps " << op << " could not infer"); return false; + } + DCOMMENT("InferShapeForListInitOps " << op << " could be inferred " + << element_type); if (!element_type || !element_type.hasStaticShape()) return false; auto variant_type = VariantType::get(element_type, op->getContext()); auto tensor_type = RankedTensorType::get({}, variant_type); @@ -1049,7 +1056,8 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) { // The shape function of these ops sometimes does not propagate subtypes // (handle shapes) for resource and variant types. We use a simple passthrough // to make sure they are preserved in the output. - if (isa(op)) { + if (isa( + op)) { return RefineTypeForPassThroughOperands(op, op->getOperands(), op->getResults()); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc index 33f9d344f0f..8fd0674c300 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc @@ -204,7 +204,7 @@ LogicalResult HandleWhileOp( } auto new_while = builder.create(while_op.getLoc(), body.getType().getInputs(), - new_while_operands, while_op.getAttrs()); + new_while_operands, while_op->getAttrs()); for (int64_t i = 0; i < while_op.getNumResults(); ++i) { if (!getElementTypeOrSelf(while_op.getOperand(i).getType()) .isa()) { @@ -257,7 +257,7 @@ LogicalResult HandleIfOp( } auto new_if = OpBuilder(if_op).create( if_op.getLoc(), then_func.getType().getResults(), new_if_operands, - if_op.getAttrs()); + if_op->getAttrs()); for (auto result : if_op.getResults()) { if (!getElementTypeOrSelf(result.getType()).isa()) { continue; @@ -306,7 +306,7 @@ LogicalResult HandlePartitionedCallOp( OpBuilder builder(call); auto new_call = builder.create( call.getLoc(), info.decomposed_callee.getType().getResults(), - new_operands, call.getAttrs()); + new_operands, call->getAttrs()); new_call->setAttr( "f", builder.getSymbolRefAttr( const_cast(info.decomposed_callee).getName())); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc index 52a9dd3bb05..b983efe9beb 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc @@ -625,7 +625,7 @@ LogicalResult HandleWhileOp(TF::WhileOp while_op, ModuleOp module, OpBuilder builder(while_op); auto new_while = builder.create(while_op.getLoc(), body.getType().getInputs(), - operands, while_op.getAttrs()); + operands, while_op->getAttrs()); for (int64_t i = 0; i < while_op.getNumOperands(); ++i) { if (ta_arg_buffer_type(i)) { while_op.getResult(i).replaceAllUsesWith(while_op.getOperand(i)); @@ -692,7 +692,7 @@ LogicalResult HandleIfOp(TF::IfOp if_op, ModuleOp module, OpBuilder builder(if_op); auto new_if = builder.create(if_op.getLoc(), then_branch.getType().getResults(), - operands, if_op.getAttrs()); + operands, if_op->getAttrs()); auto ret_forwards_input = [](FuncOp f, int64_t ret_ind) -> int64_t { auto retval = f.front().getTerminator()->getOperand(ret_ind); auto arg = retval.dyn_cast(); @@ -751,7 +751,7 @@ LogicalResult HandlePartitionedCallOp( OpBuilder builder(call); auto new_call = builder.create( call.getLoc(), info.decomposed_callee.getType().getResults(), - new_operands, call.getAttrs()); + new_operands, call->getAttrs()); new_call->setAttr( "f", builder.getSymbolRefAttr( const_cast(info.decomposed_callee).getName())); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc index d2a465926d1..cd414698ee5 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc @@ -208,7 +208,7 @@ LogicalResult HandleWhileOp( } auto new_while = builder.create(while_op.getLoc(), body.getType().getInputs(), - new_while_operands, while_op.getAttrs()); + new_while_operands, while_op->getAttrs()); for (const auto& entry : output_buffer_to_size) { (*buffer_to_size)[new_while.getResult(std::get<0>(entry))] = { new_while.getResult(std::get<1>(entry)), std::get<2>(entry)}; @@ -268,7 +268,7 @@ LogicalResult HandleCaseOrIfOp( FuncOp first_branch = branches.front(); auto new_op = OpBuilder(op).create( op.getLoc(), first_branch.getType().getResults(), new_operands, - op.getAttrs()); + op->getAttrs()); for (const auto& entry : output_buffer_to_size) { (*buffer_to_size)[new_op.getResult(std::get<0>(entry))] = { new_op.getResult(std::get<1>(entry)), std::get<2>(entry)}; @@ -329,7 +329,7 @@ LogicalResult HandleWhileRegionOp( } auto new_while = builder.create( while_op.getLoc(), body_region.front().getTerminator()->getOperandTypes(), - new_while_operands, while_op.getAttrs()); + new_while_operands, while_op->getAttrs()); new_while.body().takeBody(body_region); new_while.cond().takeBody(cond_region); for (const auto& entry : output_buffer_to_size) { @@ -369,7 +369,7 @@ LogicalResult HandleIfRegionOp( // Recreate the op. auto new_op = OpBuilder(if_op).create( if_op.getLoc(), then_branch.front().getTerminator()->getOperandTypes(), - if_op.getOperand(), if_op.getAttrs()); + if_op.getOperand(), if_op->getAttrs()); for (const auto& entry : output_buffer_to_size) { (*buffer_to_size)[new_op.getResult(std::get<0>(entry))] = { new_op.getResult(std::get<1>(entry)), std::get<2>(entry)}; @@ -415,7 +415,7 @@ LogicalResult HandleCaseRegionOp( auto new_op = OpBuilder(case_op).create( case_op.getLoc(), first_branch->front().getTerminator()->getOperandTypes(), - case_op.getOperand(), case_op.getAttrs(), case_op.getNumRegions()); + case_op.getOperand(), case_op->getAttrs(), case_op.getNumRegions()); for (const auto& entry : output_buffer_to_size) { (*buffer_to_size)[new_op.getResult(std::get<0>(entry))] = { new_op.getResult(std::get<1>(entry)), std::get<2>(entry)}; @@ -456,7 +456,7 @@ LogicalResult HandlePartitionedCallOp( OpBuilder builder(call); auto new_call = builder.create( call.getLoc(), info.decomposed_callee.getType().getResults(), - new_operands, call.getAttrs()); + new_operands, call->getAttrs()); new_call->setAttr( "f", builder.getSymbolRefAttr( const_cast(info.decomposed_callee).getName())); @@ -818,7 +818,7 @@ LogicalResult DecomposeTensorListOpsInternal( decomposed_partitioned_call_callees) { for (auto& op : llvm::make_early_inc_range(block->getOperations())) { // TODO(yuanzx): Add a pass to remove identities in device computation. - if (llvm::isa(&op)) { + if (llvm::isa(&op)) { op.replaceAllUsesWith(op.getOperands()); op.erase(); } else if (auto list = llvm::dyn_cast(&op)) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td b/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td index e5b95e3d03c..a9efa870e6c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td @@ -761,3 +761,12 @@ func @cluster_oplist(%arg0: tensor, %arg1: tensor) -> tensor { ]; } +def VerifySuitableForExportPass : Pass<"tf-verify-for-export", "ModuleOp"> { + let summary = "Verify module is suitable for export back to TF Graph"; + let description = [{ + Verifies whether all functions in module are of single tf_executor.graph and + each tf_executor.island in tf_executor.graph only has a single op. + }]; + + let constructor = "TF::CreateVerifySuitableForExportPass()"; +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc index d2879f57a25..0529e472f42 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc @@ -582,7 +582,7 @@ LogicalResult FormClustersInBlock( cluster->setAttrs( cluster_metadata->second.getDictionary(cluster.getContext())); // Exclude `num_replicas` as cluster should be replicated if necessary. - cluster.removeAttr(kNumReplicasAttr); + cluster->removeAttr(kNumReplicasAttr); } return success(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_device_propagation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_device_propagation.cc index 06176bb0972..e8d1c7c3dbb 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_device_propagation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_device_propagation.cc @@ -225,24 +225,26 @@ void PropagateDevicesToResults( } struct TPUDevicePropagation - : public PassWrapper { - void runOnFunction() override; + : public PassWrapper> { + void runOnOperation() override; }; -void TPUDevicePropagation::runOnFunction() { - FuncOp func = getFunction(); - if (!IsSupportedGraph(func)) return; +void TPUDevicePropagation::runOnOperation() { + ModuleOp m = getOperation(); + m.walk([&](FuncOp func) { + if (!IsSupportedGraph(func)) return; - llvm::DenseMap value_to_device; - PropagateDevicesFromArguments(func, value_to_device); - auto graph = llvm::cast(func.front().front()); - PropagateDevicesInGraph(graph, value_to_device); - PropagateDevicesToResults(func, graph.GetFetch(), value_to_device); + llvm::DenseMap value_to_device; + PropagateDevicesFromArguments(func, value_to_device); + auto graph = llvm::cast(func.front().front()); + PropagateDevicesInGraph(graph, value_to_device); + PropagateDevicesToResults(func, graph.GetFetch(), value_to_device); + }); } } // namespace -std::unique_ptr> CreateTPUDevicePropagationPass() { +std::unique_ptr> CreateTPUDevicePropagationPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc index af21ee4e4ab..b1f16a0f444 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc @@ -326,7 +326,7 @@ tf_device::ClusterOp UpdateClusterResults( auto new_cluster = builder->create( cluster.getLoc(), new_cluster_result_types, - /*operands=*/llvm::ArrayRef{}, cluster.getAttrs()); + /*operands=*/llvm::ArrayRef{}, cluster->getAttrs()); new_cluster.body().takeBody(cluster.body()); auto operand_not_in_cluster = [&](OpOperand& operand) { @@ -400,7 +400,7 @@ void RemoveClusterAliasedOutputs(OpBuilder* builder, builder->setInsertionPoint(cluster); auto new_cluster = builder->create( cluster.getLoc(), new_cluster_result_types, - /*operands=*/llvm::ArrayRef{}, cluster.getAttrs()); + /*operands=*/llvm::ArrayRef{}, cluster->getAttrs()); new_cluster.body().takeBody(cluster.body()); new_cluster.GetBody().getTerminator()->setOperands(new_cluster_results); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_reorder_replicate_and_partitioned_inputs.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_reorder_replicate_and_partitioned_inputs.cc index c84de533cc8..84d180347ff 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_reorder_replicate_and_partitioned_inputs.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_reorder_replicate_and_partitioned_inputs.cc @@ -94,13 +94,13 @@ LogicalResult ReorderReplicateAndPartitionedInputs( for (const auto& operands_per_replica : operands_per_replica_per_core) { auto replicate_op = builder.create( replicated_input.getLoc(), replicated_input.getType(), - operands_per_replica, replicated_input.getAttrs()); + operands_per_replica, replicated_input->getAttrs()); operands_per_core.push_back(replicate_op); } auto pi = builder.create( first_partitioned_input.getLoc(), replicated_input.getType(), - operands_per_core, first_partitioned_input.getAttrs()); + operands_per_core, first_partitioned_input->getAttrs()); replicated_input.replaceAllUsesWith(pi.output()); return success(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/verify_suitable_for_graph_export_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/verify_suitable_for_graph_export_pass.cc new file mode 100644 index 00000000000..add53a5c664 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/verify_suitable_for_graph_export_pass.cc @@ -0,0 +1,43 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/verify_suitable_for_graph_export.h" + +namespace mlir { +namespace TF { +namespace { + +class VerifySuitableForExportPass + : public VerifySuitableForExportPassBase { + public: + void runOnOperation() override { + if (failed(tensorflow::VerifyExportSuitable(getOperation()))) + return signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr> CreateVerifySuitableForExportPass() { + return std::make_unique(); +} + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc index 36db48a3668..07e6254a29b 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" @@ -46,8 +47,10 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/verify_suitable_for_graph_export.h" #include "tensorflow/compiler/mlir/utils/name_utils.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/graph.pb.h" @@ -74,9 +77,6 @@ using stream_executor::port::StatusOr; namespace { -constexpr char kInvalidExecutorGraphMsg[] = - "Functions must be of a single Graph with single op Islands: "; - constexpr char kDeviceAttr[] = "tf.device"; constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id"; @@ -91,52 +91,6 @@ class LegalizedOpOrValLocNameMapper : public OpOrArgLocNameMapper { } }; -// Checks functions in module are of single tf_executor.graph and each -// tf_executor.island in tf_executor.graph only has a single op. -Status HasSingleGraphSingleOpIslandsFunctions(mlir::ModuleOp module) { - Status status = Status::OK(); - module.walk([&](mlir::FuncOp function) { - if (!llvm::hasSingleElement(function)) { - status = errors::FailedPrecondition( - kInvalidExecutorGraphMsg, - "only single block functions are supported."); - return mlir::WalkResult::interrupt(); - } - - auto block = function.front().without_terminator(); - auto graph = llvm::dyn_cast(block.begin()); - if (!graph) { - status = errors::FailedPrecondition( - kInvalidExecutorGraphMsg, - "first op in function is not a tf_executor.graph."); - return mlir::WalkResult::interrupt(); - } - - if (!hasSingleElement(block)) { - status = errors::FailedPrecondition( - kInvalidExecutorGraphMsg, - "function does not only contain a single tf_executor.graph."); - return mlir::WalkResult::interrupt(); - } - - for (Operation& op : graph.GetBody()) { - auto island = llvm::dyn_cast(op); - if (!island) continue; - - if (!island.WrapsSingleOp()) { - status = errors::FailedPrecondition( - kInvalidExecutorGraphMsg, - "tf_executor.island must perfectly wrap a single op."); - return mlir::WalkResult::interrupt(); - } - } - - return mlir::WalkResult::advance(); - }); - - return status; -} - // Finds first inner op if `op` is a tf_executor.island. Otherwise `op` is // returned. Operation* GetIslandInnerOpOrSelf(mlir::Operation* op) { @@ -159,10 +113,10 @@ class Exporter { // Converts a given FuncOp to a FunctionDef and adds it to the function // definition library - static Status ConvertLibFunction(const GraphExportConfig& configs, - const Dialect* tf_dialect, - mlir::FuncOp function, - FunctionDefLibrary* flib); + static Status ConvertLibFunction( + const GraphExportConfig& configs, const Dialect* tf_dialect, + mlir::FuncOp function, FunctionDefLibrary* flib, + llvm::SmallDenseSet& visited_functions); // Converts the given FuncOp to a Graph. The arguments and returns of // function are added to the graph with special op names kArgOp and kRetOp. // Later on, this graph can be converted a function definition and added to @@ -170,6 +124,7 @@ class Exporter { static StatusOr> Convert( const GraphExportConfig& configs, const Dialect* tf_dialect, mlir::FuncOp function, FunctionDefLibrary* flib, + llvm::SmallDenseSet& visited_functions, absl::flat_hash_set* control_ret_nodes); private: @@ -451,6 +406,7 @@ Status Exporter::GetControlRetNodes( StatusOr> Exporter::Convert( const GraphExportConfig& configs, const Dialect* tf_dialect, mlir::FuncOp function, FunctionDefLibrary* flib, + llvm::SmallDenseSet& visited_functions, absl::flat_hash_set* control_ret_nodes) { mlir::Block& block = function.front(); @@ -550,7 +506,8 @@ StatusOr> Exporter::Convert( function->getParentOfType().lookupSymbol( name); if (func != nullptr) { - TF_RETURN_IF_ERROR(ConvertLibFunction(configs, tf_dialect, func, flib)); + TF_RETURN_IF_ERROR(ConvertLibFunction(configs, tf_dialect, func, flib, + visited_functions)); TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(*flib)); } return Status::OK(); @@ -610,22 +567,21 @@ StatusOr> Exporter::Convert( return graph; } -Status Exporter::ConvertLibFunction(const GraphExportConfig& configs, - const Dialect* tf_dialect, - mlir::FuncOp function, - FunctionDefLibrary* flib) { - // First look for the function in the current function library. If found, - // nothing needs to be done. - OpRegistry empty_registry; - FunctionLibraryDefinition flib_def(&empty_registry, *flib); +Status Exporter::ConvertLibFunction( + const GraphExportConfig& configs, const Dialect* tf_dialect, + mlir::FuncOp function, FunctionDefLibrary* flib, + llvm::SmallDenseSet& visited_functions) { + // Return early if the function has already been exported. + bool is_new_function = visited_functions.insert(function).second; + if (!is_new_function) return Status::OK(); + auto function_name = function.getName().str(); - if (flib_def.Find(function_name)) return Status::OK(); // TODO(fengliuai): use a small flib_def to reduce overhead absl::flat_hash_set control_ret_nodes; TF_ASSIGN_OR_RETURN(auto sub_graph, Exporter::Convert(configs, tf_dialect, function, flib, - &control_ret_nodes)); + visited_functions, &control_ret_nodes)); const auto control_ret = [&](const Node* n) -> absl::optional { return control_ret_nodes.contains(n) ? absl::make_optional(n->name()) @@ -652,8 +608,8 @@ Status Exporter::ConvertLibFunction(const GraphExportConfig& configs, auto grad_func = function->getParentOfType().lookupSymbol( attr.getValue()); - TF_RETURN_IF_ERROR( - ConvertLibFunction(configs, tf_dialect, grad_func, flib)); + TF_RETURN_IF_ERROR(ConvertLibFunction(configs, tf_dialect, grad_func, flib, + visited_functions)); GradientDef grad; grad.set_function_name(function_name); grad.set_gradient_func(grad_func.getName().str()); @@ -709,26 +665,32 @@ Status Exporter::Convert(mlir::ModuleOp module, mlir::Identifier::get("main", module.getContext()); absl::optional entry_func; FunctionDefLibrary flib; + llvm::SmallDenseSet visited_functions; auto tf_dialect = module.getContext()->getLoadedDialect("tf"); for (auto function : module.getOps()) { if (function.isExternal()) return errors::FailedPrecondition("External functions not supported"); - if (function.getName() == entry_func_id) { + if (function.getName() == entry_func_id && + !configs.export_entry_func_to_flib) { entry_func.emplace(function); } else { - TF_RETURN_IF_ERROR( - ConvertLibFunction(configs, tf_dialect, function, &flib)); + TF_RETURN_IF_ERROR(ConvertLibFunction(configs, tf_dialect, function, + &flib, visited_functions)); } } - if (!entry_func.has_value()) - return errors::FailedPrecondition("entry function `main` must be present"); + if (!configs.export_entry_func_to_flib) { + if (!entry_func.has_value()) + return errors::FailedPrecondition( + "entry function `main` must be present"); + + // Updates the graph and the function library definition. + TF_ASSIGN_OR_RETURN( + *graph, Exporter::Convert(configs, tf_dialect, entry_func.value(), + &flib, visited_functions, control_ret_nodes)); + } - // Updates the graph and the function library definition. - TF_ASSIGN_OR_RETURN( - *graph, Exporter::Convert(configs, tf_dialect, entry_func.value(), &flib, - control_ret_nodes)); for (auto& func_def : flib.function()) { TF_RETURN_IF_ERROR(flib_def->AddFunctionDef(func_def)); } @@ -744,8 +706,10 @@ Status ConvertMlirToGraph(mlir::ModuleOp module, std::unique_ptr* graph, FunctionLibraryDefinition* flib_def, absl::flat_hash_set* control_ret_nodes) { - TF_RETURN_IF_ERROR(HasSingleGraphSingleOpIslandsFunctions(module)); - return Exporter::Convert(module, configs, graph, flib_def, control_ret_nodes); + mlir::StatusScopedDiagnosticHandler sh(module.getContext()); + if (failed(VerifyExportSuitable(module))) return sh.ConsumeStatus(); + return sh.Combine( + Exporter::Convert(module, configs, graph, flib_def, control_ret_nodes)); } Status ConvertMlirToGraph(mlir::ModuleOp module, @@ -761,8 +725,16 @@ StatusOr> ConvertMlirToGraphdef( mlir::ModuleOp module, const GraphExportConfig& configs) { FunctionLibraryDefinition flib_def(OpRegistry::Global(), FunctionDefLibrary()); - auto graph = absl::make_unique(flib_def); + std::unique_ptr graph; TF_RETURN_IF_ERROR(ConvertMlirToGraph(module, configs, &graph, &flib_def)); + + // If the entry function is exported to flib, then no graph is constructed. + // Construct one in that case. + if (configs.export_entry_func_to_flib) { + graph = std::make_unique(OpRegistry::Global()); + } + TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(flib_def.ToProto())); + auto graphdef = absl::make_unique(); graph->ToGraphDef(graphdef.get()); if (!configs.export_library) graphdef->clear_library(); @@ -784,8 +756,9 @@ stream_executor::port::Status ConvertMlirFunctionToFunctionLibraryDef( FunctionDef* function_def) { Dialect* tf_dialect = func.getContext()->getLoadedDialect("tf"); FunctionDefLibrary flib; - TF_RETURN_IF_ERROR( - Exporter::ConvertLibFunction(configs, tf_dialect, func, &flib)); + llvm::SmallDenseSet visited_functions; + TF_RETURN_IF_ERROR(Exporter::ConvertLibFunction(configs, tf_dialect, func, + &flib, visited_functions)); for (auto& func_def : flib.function()) { if (func_def.signature().name() == func.getName()) { *function_def = func_def; diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 2d0654bada2..e50438b4cac 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -113,6 +113,7 @@ limitations under the License. #include "tensorflow/core/protobuf/saver.pb.h" #include "tensorflow/core/protobuf/struct.pb.h" #include "tensorflow/core/protobuf/trackable_object_graph.pb.h" +#include "tensorflow/core/util/device_name_utils.h" #include "tensorflow/stream_executor/lib/statusor.h" static inline absl::string_view StringRefToView(llvm::StringRef ref) { @@ -1291,17 +1292,23 @@ Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) { if (name_and_value.first == "_input_shapes") { auto& list = name_and_value.second.list(); auto& signature = func_def->signature(); - if (list.shape_size() != signature.input_arg_size()) { + // Some models have "_input_shapes" attribute, but with its value empty + if (list.shape_size() > 0 && + list.shape_size() != signature.input_arg_size()) { return errors::FailedPrecondition( "Number of input arguments must be equal to the length of " "_input_shapes attribute in function '", StringRefToView(func_name), "'."); } - for (int i = 0; i < list.shape_size(); i++) { + for (int i = 0; i < signature.input_arg_size(); i++) { auto& input_arg = signature.input_arg(i); auto& array_info = specs.inputs[input_arg.name()]; array_info.imported_dtype = input_arg.type(); - array_info.shape = list.shape(i); + // set to unranked for empty "_input_shapes" attribute + if (list.shape_size() > 0) + array_info.shape = list.shape(i); + else + array_info.shape.set_unknown_rank(true); } } } @@ -1661,7 +1668,7 @@ mlir::Location ImporterBase::GetLocation(const Node& node) { // If there are no locations in the stack trace, fall back to just a // NameLoc with no child. - if (locations.empty()) return mlir::NameLoc::get(name_loc_id, context_); + if (locations.empty()) return mlir::NameLoc::get(name_loc_id); // Use the front FileLineColLoc to generate a NameLoc. mlir::Location node_name_loc = @@ -1984,8 +1991,28 @@ Status ImporterBase::ConvertNode(const Node& node) { } const auto& node_def = node.def(); + // NodeDef can contain partial TF device names. In such cases, canonicalize + // it. Note that in current TF, placer will place full device name to each + // node. + DeviceNameUtils::ParsedName parsed_name; + if (!DeviceNameUtils::ParseFullName(node_def.device(), &parsed_name)) { + return errors::InvalidArgument( + "Op ", op_name, " has invalid device name: ", node_def.device()); + } + // Keep the parsed name untouched if the device name is empty. + if (!node_def.device().empty()) { + if (!parsed_name.has_type) { + parsed_name.type = "CPU"; + parsed_name.has_type = true; + } + if (!parsed_name.has_id) { + parsed_name.id = 0; + parsed_name.has_id = true; + } + } result.attributes.push_back(builder_.getNamedAttr( - "device", builder_.getStringAttr(std::string(node_def.device())))); + "device", builder_.getStringAttr( + DeviceNameUtils::ParsedNameToString(parsed_name)))); // Map user function calls to LegacyCall ops and add the user function name // as an attribute. @@ -2165,7 +2192,8 @@ class GraphDefImporter : public ImporterBase { mlir::MLIRContext* context, const Graph& graph, const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs, - llvm::StringRef func_name); + llvm::StringRef func_name, + std::unordered_map& tf_name_to_mlir_name); private: explicit GraphDefImporter( @@ -2206,11 +2234,11 @@ class GraphDefImporter : public ImporterBase { StatusOr GraphDefImporter::Convert( mlir::MLIRContext* context, const Graph& graph, const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def, - const GraphImportConfig& specs, llvm::StringRef func_name) { + const GraphImportConfig& specs, llvm::StringRef func_name, + std::unordered_map& tf_name_to_mlir_name) { LoadImporterDialects(*context); mlir::OwningModuleRef module = mlir::ModuleOp::create(mlir::UnknownLoc::get(context)); - std::unordered_map tf_name_to_mlir_name; NameUniquifier function_name_uniquifier(flib_def); GraphDefImporter importer(flib_def, debug_info, specs, module.get(), @@ -3499,12 +3527,14 @@ class SavedModelSignatureDefImporterLite { const std::string& name, const std::vector>& inputs, const std::vector>& outputs, - const std::vector control_outputs); + const std::vector control_outputs, + std::unordered_map& tf_name_to_mlir_name); // Moves the functions in `sub_module` to `module_` and skips the duplicate // functions. - Status MoveConvertedFunctionsToModule(absl::string_view name, - mlir::ModuleOp sub_module); + Status MoveConvertedFunctionsToModule( + absl::string_view name, mlir::ModuleOp sub_module, + const std::unordered_map& tf_name_to_mlir_name); GraphImportConfig::InputArrays ParseInputArrays( llvm::ArrayRef> inputs); @@ -3545,15 +3575,25 @@ SavedModelSignatureDefImporterLite::ConvertAssets() { } Status SavedModelSignatureDefImporterLite::MoveConvertedFunctionsToModule( - absl::string_view name, mlir::ModuleOp sub_module) { + absl::string_view name, mlir::ModuleOp sub_module, + const std::unordered_map& tf_name_to_mlir_name) { mlir::Builder builder(sub_module.getContext()); mlir::SymbolTable sub_module_symbol_table(sub_module); + // Functions originally from graphdef library might have a different name + // after conversion, we build the set of the converted names + absl::flat_hash_set original_func_mlir_names; + for (const auto& kv : tf_name_to_mlir_name) + original_func_mlir_names.insert(kv.second); + // Prefix private functions with the unique signature name, so that it cannot // collide with private functions used in the other signatures. for (auto func : sub_module.getOps()) { if (mlir::tf_saved_model::IsExported(func)) continue; + // Skip the original functions from graphdef library + if (original_func_mlir_names.count(func.sym_name().str())) continue; + std::string new_sym_name = absl::StrCat(name, "/", func.sym_name().str()); if (mlir::failed(sub_module_symbol_table.replaceAllSymbolUses( func, new_sym_name, sub_module))) @@ -3567,7 +3607,7 @@ Status SavedModelSignatureDefImporterLite::MoveConvertedFunctionsToModule( // Copy all functions used by this signature to the final MLIR module. for (auto func : sub_module.getOps()) { - DCHECK(symbol_table_.lookup(func.sym_name()) == nullptr); + // The insert here is a NO-OP if the function already exists. symbol_table_.insert(func.clone()); } @@ -3586,13 +3626,15 @@ Status SavedModelSignatureDefImporterLite::ConvertInitializer( inputs.push_back({asset.tensor_name, tensor_info}); } - TF_ASSIGN_OR_RETURN(auto sub_module, ConvertGraph(target_node_name, inputs, - {}, {target_node_name})); + std::unordered_map tf_name_to_mlir_name; + TF_ASSIGN_OR_RETURN(auto sub_module, + ConvertGraph(target_node_name, inputs, {}, + {target_node_name}, tf_name_to_mlir_name)); mlir::SymbolTable sub_symbol_table(*sub_module); auto init_func_op = sub_symbol_table.lookup(target_node_name); - init_func_op.removeAttr("tf.entry_function"); + init_func_op->removeAttr("tf.entry_function"); mlir::OpBuilder builder(module_->getBodyRegion()); @@ -3612,7 +3654,8 @@ Status SavedModelSignatureDefImporterLite::ConvertInitializer( "__tf_saved_model_session_initializer_", target_node_name)})); // Move the converted functions to top level MLIR module. - return MoveConvertedFunctionsToModule(target_node_name, *sub_module); + return MoveConvertedFunctionsToModule(target_node_name, *sub_module, + tf_name_to_mlir_name); } StatusOr @@ -3620,7 +3663,8 @@ SavedModelSignatureDefImporterLite::ConvertGraph( const std::string& name, const std::vector>& inputs, const std::vector>& outputs, - const std::vector control_outputs) { + const std::vector control_outputs, + std::unordered_map& tf_name_to_mlir_name) { VLOG(1) << "Importing Signature: " << name; GraphImportConfig specs; @@ -3634,7 +3678,7 @@ SavedModelSignatureDefImporterLite::ConvertGraph( // Convert sub-graph to MLIR module. return GraphDefImporter::Convert(module_->getContext(), *subgraph, input_.debug_info(), subgraph->flib_def(), - specs, name); + specs, name, tf_name_to_mlir_name); } Status SavedModelSignatureDefImporterLite::ConvertSignature( @@ -3654,9 +3698,12 @@ Status SavedModelSignatureDefImporterLite::ConvertSignature( return lhs.first.size() < rhs.first.size() || lhs.first > rhs.first; }); + std::unordered_map tf_name_to_mlir_name; + // Convert sub-graph to MLIR module. - TF_ASSIGN_OR_RETURN(auto sub_module, - ConvertGraph(sig_def_key, inputs, outputs, {})); + TF_ASSIGN_OR_RETURN( + auto sub_module, + ConvertGraph(sig_def_key, inputs, outputs, {}, tf_name_to_mlir_name)); mlir::OpBuilder builder(sub_module->getBodyRegion()); // Find the FuncOp which corresponds to current SignatureDef. @@ -3682,7 +3729,8 @@ Status SavedModelSignatureDefImporterLite::ConvertSignature( } // Move the converted functions to top level MLIR module. - return MoveConvertedFunctionsToModule(sig_def_key, *sub_module); + return MoveConvertedFunctionsToModule(sig_def_key, *sub_module, + tf_name_to_mlir_name); } GraphImportConfig::InputArrays @@ -3820,7 +3868,7 @@ class SavedModelSignatureDefImporter { (*module)->setAttr("tf_saved_model.under_construction", builder.getUnitAttr()); TF_RETURN_IF_ERROR(LiftVariables(bundle, *module)); - module->removeAttr("tf_saved_model.under_construction"); + (*module)->removeAttr("tf_saved_model.under_construction"); return module; } @@ -3895,8 +3943,9 @@ StatusOr ConvertGraphToMlir( const_cast(&flib_def), specs.restrict_functionalization_to_tpu_nodes)); } + std::unordered_map tf_name_to_mlir_name; return GraphDefImporter::Convert(context, graph, debug_info, flib_def, specs, - /*func_name=*/"main"); + /*func_name=*/"main", tf_name_to_mlir_name); } stream_executor::port::StatusOr ConvertFunctionToMlir( @@ -3908,9 +3957,10 @@ stream_executor::port::StatusOr ConvertFunctionToMlir( specs.graph_as_function = true; for (const auto* control_ret_node : fbody->control_ret_nodes) specs.control_outputs.push_back(control_ret_node->name()); - return GraphDefImporter::Convert(context, *fbody->graph, dummy_debug_info, - flib_def, specs, - fbody->fdef.signature().name()); + std::unordered_map tf_name_to_mlir_name; + return GraphDefImporter::Convert( + context, *fbody->graph, dummy_debug_info, flib_def, specs, + fbody->fdef.signature().name(), tf_name_to_mlir_name); } StatusOr ConvertSavedModelToMlir( diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h index 7c5790e05f9..94753454b8c 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h @@ -75,6 +75,9 @@ struct GraphExportConfig { bool export_library = true; // Whether to export debug original node name in the GraphDef. bool export_debug_info = true; + // Whether to export the entry function to function library instead of the + // graph. + bool export_entry_func_to_flib = false; }; // Parses the command line flag strings to the specification of nodes in diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.cc index 249ed2767c0..65dd83929fa 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.cc @@ -21,6 +21,7 @@ limitations under the License. using llvm::cl::opt; +// Import options. // NOLINTNEXTLINE opt input_arrays( "tf-input-arrays", llvm::cl::desc("Input tensor names, separated by ','"), @@ -115,3 +116,11 @@ opt enable_shape_inference( "tf-enable-shape-inference-on-import", llvm::cl::desc("Enable shape inference on import (temporary)"), llvm::cl::init(false)); + +// Export options. +// NOLINTNEXTLINE +opt export_entry_func_to_flib( + "tf-export-entry-func-to-flib", + llvm::cl::desc( + "Export entry function to function library instead of graph"), + llvm::cl::init(false)); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h index d235400b2e7..b1fc4d9aa04 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h @@ -26,6 +26,7 @@ limitations under the License. // Please see the implementation file for documentation of these options. +// Import options. extern llvm::cl::opt input_arrays; extern llvm::cl::opt input_dtypes; extern llvm::cl::opt input_shapes; @@ -42,4 +43,7 @@ extern llvm::cl::opt upgrade_legacy; // TODO(jpienaar): Temporary flag, flip default and remove. extern llvm::cl::opt enable_shape_inference; +// Export options. +extern llvm::cl::opt export_entry_func_to_flib; + #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_TF_MLIR_TRANSLATE_CL_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc index a45fa3342c9..361ec3639b4 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc @@ -75,6 +75,7 @@ static LogicalResult MlirToGraphdefTranslateFunction( // TODO(fengliuai): Add exporter flags. tensorflow::GraphExportConfig confs; + confs.export_entry_func_to_flib = export_entry_func_to_flib; StatusOr> graphdef_or( tensorflow::ConvertMlirToGraphdef(module, confs)); if (!graphdef_or.status().ok()) { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc index 9f09096d4ce..8221b1bba19 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc @@ -100,15 +100,15 @@ struct WritableFileRawStream : public llvm::raw_ostream { struct CrashReproducerStream : public mlir::PassManager::ReproducerStream { CrashReproducerStream(llvm::StringRef name, - std::unique_ptr file) + std::unique_ptr file) : name(name), ostream(std::move(file)) {} llvm::StringRef description() override { return name; } - raw_ostream& os() override { return ostream; } + raw_ostream& os() override { return *ostream; } private: std::string name; - WritableFileRawStream ostream; + std::unique_ptr ostream; }; } // namespace @@ -225,25 +225,32 @@ void SetCrashReproducer(mlir::PassManager& pm, llvm::StringRef dir_path) { } } - auto* env = tensorflow::Env::Default(); - auto status = env->RecursivelyCreateDir(path); - if (!status.ok()) { - LOG(WARNING) << "cannot create directory '" + path + - "': " + status.error_message(); - return; - } + if (path != "-") { + auto* env = tensorflow::Env::Default(); + auto status = env->RecursivelyCreateDir(path); + if (!status.ok()) { + LOG(WARNING) << "cannot create directory '" + path + + "': " + status.error_message(); + return; + } - path += "/mlir_reproducer_"; + path += "/mlir_reproducer_"; - if (!tensorflow::Env::Default()->CreateUniqueFileName(&path, ".mlir")) { - LOG(WARNING) - << "cannot create unique filename, won't enable MLIR crash reproducer."; - return; + if (!tensorflow::Env::Default()->CreateUniqueFileName(&path, ".mlir")) { + LOG(WARNING) << "cannot create unique filename, won't enable MLIR crash " + "reproducer."; + return; + } } mlir::PassManager::ReproducerStreamFactory factory = [path](std::string& error) -> std::unique_ptr { + // Use the stderr stream. + if (path == "-") + return std::make_unique( + "(stderr)", std::make_unique()); + // Try to open the file and generate a raw_ostream. std::unique_ptr file; Status status = tensorflow::Env::Default()->NewWritableFile(path, &file); @@ -252,7 +259,8 @@ void SetCrashReproducer(mlir::PassManager& pm, llvm::StringRef dir_path) { "': ", status.error_message()); return nullptr; } - return std::make_unique(path, std::move(file)); + return std::make_unique( + path, std::make_unique(std::move(file))); }; pm.enableCrashReproducerGeneration(factory, /*genLocalReproducer=*/false); } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.cc index 5c1c595dce0..d1b2304a40c 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.cc @@ -177,7 +177,7 @@ LogicalResult InferReturnTypeComponentsForTFOp( OpResultAsShapeFn op_result_as_shape_fn, ResultElementTypeFn result_element_type_fn, SmallVectorImpl& inferred_return_shapes) { - assert(op->getName().getDialect() == + assert(op->getName().getDialectNamespace() == TensorFlowDialect::getDialectNamespace()); auto op_name_or = diff --git a/tensorflow/compiler/mlir/tensorflow/utils/verify_suitable_for_graph_export.cc b/tensorflow/compiler/mlir/tensorflow/utils/verify_suitable_for_graph_export.cc new file mode 100644 index 00000000000..e8e61517e86 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/verify_suitable_for_graph_export.cc @@ -0,0 +1,68 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/utils/verify_suitable_for_graph_export.h" + +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" + +namespace tensorflow { +namespace { + +constexpr char kInvalidExecutorGraphMsg[] = + "functions must be of a single Graph with single op Islands: "; + +} // namespace + +mlir::LogicalResult VerifyExportSuitable(mlir::ModuleOp module) { + mlir::WalkResult result = module.walk([&](mlir::FuncOp function) { + if (!llvm::hasSingleElement(function)) { + function.emitError(kInvalidExecutorGraphMsg) + << "only single block functions are supported"; + return mlir::WalkResult::interrupt(); + } + + auto block = function.front().without_terminator(); + auto graph = llvm::dyn_cast(block.begin()); + if (!graph) { + block.begin()->emitError(kInvalidExecutorGraphMsg) + << "first op in function is not a tf_executor.graph"; + return mlir::WalkResult::interrupt(); + } + + if (!hasSingleElement(block)) { + function.emitError(kInvalidExecutorGraphMsg) + << "function does not only contain a single tf_executor.graph"; + return mlir::WalkResult::interrupt(); + } + + for (mlir::Operation& op : graph.GetBody()) { + auto island = llvm::dyn_cast(op); + if (!island) continue; + + if (!island.WrapsSingleOp()) { + island.emitError(kInvalidExecutorGraphMsg) + << "tf_executor.island must perfectly wrap a single op"; + return mlir::WalkResult::interrupt(); + } + } + + return mlir::WalkResult::advance(); + }); + + return mlir::failure(result.wasInterrupted()); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/verify_suitable_for_graph_export.h b/tensorflow/compiler/mlir/tensorflow/utils/verify_suitable_for_graph_export.h new file mode 100644 index 00000000000..d2cc8cc04d4 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/verify_suitable_for_graph_export.h @@ -0,0 +1,30 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VERIFY_SUITABLE_FOR_GRAPH_EXPORT_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VERIFY_SUITABLE_FOR_GRAPH_EXPORT_H_ + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project + +namespace tensorflow { + +// Returns whether all functions in module are of single tf_executor.graph and +// each tf_executor.island in tf_executor.graph only has a single op. +mlir::LogicalResult VerifyExportSuitable(mlir::ModuleOp module); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VERIFY_SUITABLE_FOR_GRAPH_EXPORT_H_ diff --git a/tensorflow/compiler/mlir/tfr/define_op_template.py b/tensorflow/compiler/mlir/tfr/define_op_template.py index c0db2981d2d..8ce653b1513 100644 --- a/tensorflow/compiler/mlir/tfr/define_op_template.py +++ b/tensorflow/compiler/mlir/tfr/define_op_template.py @@ -22,10 +22,10 @@ from __future__ import print_function import os import sys +from absl import app from tensorflow.compiler.mlir.tfr.python.composite import Composite from tensorflow.compiler.mlir.tfr.python.op_reg_gen import gen_register_op from tensorflow.compiler.mlir.tfr.python.tfr_gen import tfr_gen_from_module -from tensorflow.python.platform import app from tensorflow.python.platform import flags FLAGS = flags.FLAGS diff --git a/tensorflow/compiler/mlir/tfr/examples/customization/ops_defs.py b/tensorflow/compiler/mlir/tfr/examples/customization/ops_defs.py index 8920ad09399..2dbd79b690f 100644 --- a/tensorflow/compiler/mlir/tfr/examples/customization/ops_defs.py +++ b/tensorflow/compiler/mlir/tfr/examples/customization/ops_defs.py @@ -22,13 +22,13 @@ from __future__ import print_function import os import sys +from absl import app from tensorflow.compiler.mlir.tfr.python import composite from tensorflow.compiler.mlir.tfr.python.op_reg_gen import gen_register_op from tensorflow.compiler.mlir.tfr.python.tfr_gen import tfr_gen_from_module from tensorflow.python.framework import dtypes from tensorflow.python.ops import gen_array_ops as array_ops -from tensorflow.python.platform import app from tensorflow.python.platform import flags diff --git a/tensorflow/compiler/mlir/tfr/examples/mnist/ops_defs.py b/tensorflow/compiler/mlir/tfr/examples/mnist/ops_defs.py index 0cf4678892e..998fcd24323 100644 --- a/tensorflow/compiler/mlir/tfr/examples/mnist/ops_defs.py +++ b/tensorflow/compiler/mlir/tfr/examples/mnist/ops_defs.py @@ -22,6 +22,7 @@ from __future__ import print_function import os import sys +from absl import app import tensorflow as tf @@ -31,7 +32,6 @@ from tensorflow.compiler.mlir.tfr.python.tfr_gen import tfr_gen_from_module from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import math_ops -from tensorflow.python.platform import app from tensorflow.python.platform import flags Composite = composite.Composite diff --git a/tensorflow/compiler/mlir/tfr/examples/pad/ops_defs.py b/tensorflow/compiler/mlir/tfr/examples/pad/ops_defs.py index 4b072a58f08..dbc823f5b51 100644 --- a/tensorflow/compiler/mlir/tfr/examples/pad/ops_defs.py +++ b/tensorflow/compiler/mlir/tfr/examples/pad/ops_defs.py @@ -23,13 +23,13 @@ from __future__ import print_function import os import sys +from absl import app import tensorflow as tf from tensorflow.compiler.mlir.tfr.python import composite from tensorflow.compiler.mlir.tfr.python.op_reg_gen import gen_register_op from tensorflow.compiler.mlir.tfr.python.tfr_gen import tfr_gen_from_module from tensorflow.python.ops import gen_array_ops -from tensorflow.python.platform import app from tensorflow.python.platform import flags Composite = composite.Composite diff --git a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td index 9d1e7fb8513..6971edc298f 100644 --- a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td +++ b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td @@ -441,7 +441,7 @@ def TFR_TFRFuncOp : TFR_Op<"func", [HasParent<"ModuleOp">, // non-derived ones. llvm::StringSet<> getDefinedAttributeNames() { llvm::StringSet<> all_attrs; - for (auto& attr : getAttrs()) { + for (auto& attr : (*this)->getAttrs()) { all_attrs.insert(attr.first.strref()); } for (const auto& operand : llvm::enumerate(getType().getInputs())) { diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc index 9a12fa5dd67..e01ac784a77 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc @@ -118,22 +118,6 @@ LogicalResult Verify(TFAllocOp op) { } } -//===----------------------------------------------------------------------===// -// MinimumBroadcastShapesOp -//===----------------------------------------------------------------------===// -template <> -LogicalResult Verify(MinimumBroadcastShapesOp op) { - // Check that the number of operands matches the number of outputs. - unsigned result_shapes_count = op.results().size(); - unsigned operand_shapes_count = op.shapes().size(); - if (operand_shapes_count != result_shapes_count) { - return op.emitOpError() - << "number of operand shapes " << operand_shapes_count - << " does not match number of result shapes " << result_shapes_count; - } - return success(); -} - } // namespace tf_framework } // namespace kernel_gen } // namespace mlir diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td index 2b0bd688f13..67a4c753329 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td @@ -168,47 +168,4 @@ def TFFramework_ReportErrorOp : TFFramework_Op<"report_error", []> { let assemblyFormat = "$ctx `,` $error_code `,` $msg attr-dict"; } -//===----------------------------------------------------------------------===// -// MinimumBroadcastShapesOp -//===----------------------------------------------------------------------===// -def TFFramework_MinimumBroadcastShapesOp : - TFFramework_Op<"minimum_broadcast_shapes", [NoSideEffect]> { - let summary = "Minimizes the rank of two or more shapes to be broadcasted"; - let description = [{ - Given two or more 1D tensors representing shapes, returns one 1D tensor for - each operand, where operand `i` corresponds to output `i`. - - The returned tensors have the property that they specify a shape which is a - reshape of the corresponding input shape, and the broadcasted output shape - (using shape::BroadcastOp) of the returned shapes is a reshape of the - broadcasted output shape of the input shapes. Among all possibilities with - this property, the one is chosen which minimizes the rank of each returned - shape. - - The general idea of this op is that it can be used for ops which have a - broadcasting semantic to operate on shapes with a possibly smaller rank - while preserving equivalence of the computed values. After computing the - result of the op using reshaped operands, the result can be reshaped to the - result that would have been originally computed. - - Here is an example with two input shapes: - - ```mlir - tf_framework.minimum_broadcast_shapes [1, 2, 3, 1, 2, 1], - [1, 1, 1, 2, 3] -> [6, 2, 1], [2, 3] - ``` - - The broadcasted output shape of the operands is [1, 2, 3, 1, 2, 3], the - broadcasted output shape of the outputs is [6, 2, 3]. These two shapes are - reshapes of each other, and also each output is a reshape of the - corresponding input. - }]; - - let arguments = (ins Variadic<1DTensorOf<[Index]>>:$shapes); - let results = (outs Variadic<1DTensorOf<[Index]>>:$results); - - let assemblyFormat = "$shapes attr-dict `:` type($shapes) `->` type($results)"; - -} - #endif // TF_FRAMEWORK_OPS diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/bufferize.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/bufferize.mlir index ffeb74df6a2..7fb6d78b35e 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/bufferize.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/bufferize.mlir @@ -108,3 +108,95 @@ func @const_splat() -> tensor<3xf32> { %result = constant dense<4.0> : tensor<3xf32> return %result : tensor<3xf32> } + +// CHECK-LABEL: @minimum_broadcast_shapes +// CHECK-SAME: (%[[LHS:.*]]: memref, %[[RHS:.*]]: memref) +func @minimum_broadcast_shapes(%lhs: tensor, %rhs: tensor) -> (tensor, tensor) { + // CHECK-NEXT: %[[C0:.*]] = constant 0 : index + // CHECK-NEXT: %[[RANK_LHS:.*]] = dim %[[LHS]], %[[C0]] : memref + // CHECK-NEXT: %[[TRUE:.*]] = constant true + // CHECK-NEXT: %[[C0_0:.*]] = constant 0 : index + // CHECK-NEXT: %[[C1:.*]] = constant 1 : index + // CHECK-NEXT: %[[FOR_0:.*]]:2 = scf.for %[[IV:.*]] = %[[C0_0]] to %[[RANK_LHS]] step %[[C1]] iter_args(%[[ALL_ONES:.*]] = %[[TRUE]], %[[ONE_COUNT:.*]] = %[[C0_0]]) -> (i1, index) { + // CHECK-NEXT: %[[SIZE:.*]] = load %[[LHS]][%[[IV]]] : memref + // CHECK-NEXT: %[[IS_ONE:.*]] = cmpi eq, %[[SIZE]], %[[C1]] : index + // CHECK-NEXT: %[[NEXT_ALL_ONES:.*]] = and %[[ALL_ONES]], %[[IS_ONE]] : i1 + // CHECK-NEXT: %[[ONE_COUNT_PLUS_ONE:.*]] = addi %[[ONE_COUNT]], %[[C1]] : index + // CHECK-NEXT: %[[NEXT_ONE_COUNT:.*]] = select %[[NEXT_ALL_ONES]], %[[ONE_COUNT_PLUS_ONE]], %[[ONE_COUNT]] : index + // CHECK-NEXT: scf.yield %[[NEXT_ALL_ONES]], %[[NEXT_ONE_COUNT]] : i1, index + // CHECK-NEXT: } + // CHECK-NEXT: %[[REDUCED_RANK_LHS:.*]] = subi %[[RANK_LHS]], %[[FOR_0]]#1 : index + // CHECK-NEXT: %[[RANK_RHS:.*]] = dim %[[RHS]], %[[C0]] : memref + // CHECK: %[[REDUCED_RANK_RHS:.*]] = subi %[[RANK_RHS]], %[[LEADING_ONES:.*]]#1 : index + // CHECK-NEXT: %[[IS_GREATER_RANK:.*]] = cmpi ugt, %[[REDUCED_RANK_RHS]], %[[REDUCED_RANK_LHS]] : index + // CHECK-NEXT: %[[MAX_RANK:.*]] = select %[[IS_GREATER_RANK]], %[[REDUCED_RANK_RHS]], %[[REDUCED_RANK_LHS]] : index + // CHECK-NEXT: %[[C1_1:.*]] = constant 1 : index + // CHECK-NEXT: %[[RESULT_LHS:.*]] = alloca(%[[REDUCED_RANK_LHS]]) : memref + // CHECK-NEXT: scf.for %[[IV:.*]] = %[[C0]] to %[[REDUCED_RANK_LHS]] step %[[C1_1]] { + // CHECK-NEXT: store %[[C1_1]], %[[RESULT_LHS]][%[[IV]]] : memref + // CHECK-NEXT: } + // CHECK-NEXT: %[[RESULT_RHS:.*]] = alloca(%[[REDUCED_RANK_RHS]]) : memref + // CHECK: %[[C2:.*]] = constant 2 : index + // CHECK-NEXT: %[[UPPER_BOUND:.*]] = addi %[[MAX_RANK]], %[[C2]] : index + // CHECK-NEXT: %[[MAIN_FOR:.*]]:2 = scf.for %[[IV:.*]] = %[[C1_1]] to %[[UPPER_BOUND]] step %[[C1_1]] iter_args(%[[RUNNING_PRODUCT:.*]] = %[[C1_1]], %[[OFFSET:.*]] = %[[C0]]) -> (index, index) { + // CHECK-NEXT: %[[FALSE:.*]] = constant false + // CHECK-NEXT: %[[MINUS_ONE:.*]] = constant -1 : index + // CHECK-NEXT: %[[IS_OUT_OF_BOUNDS:.*]] = cmpi ult, %[[REDUCED_RANK_LHS]], %[[IV]] : index + // CHECK-NEXT: %[[DIMENSION:.*]] = subi %[[RANK_LHS]], %[[IV]] : index + // CHECK-NEXT: %[[RESULT_DIMENSION:.*]] = subi %[[DIMENSION]], %[[FOR_0]]#1 : index + // CHECK-NEXT: %[[CURRENT_SIZE:.*]] = scf.if %[[IS_OUT_OF_BOUNDS]] -> (index) { + // CHECK-NEXT: scf.yield %[[MINUS_ONE]] : index + // CHECK-NEXT: } else { + // CHECK-NEXT: %[[SIZE:.*]] = load %[[LHS]][%[[DIMENSION]]] : memref + // CHECK-NEXT: scf.yield %[[SIZE]] : index + // CHECK-NEXT: } + // CHECK-NEXT: %[[IS_INITIALIZED:.*]] = cmpi ne, %[[MINUS_ONE]], %[[MINUS_ONE]] : index + // CHECK-NEXT: %[[SAME_SIZE:.*]] = select %[[IS_INITIALIZED]], %[[MINUS_ONE]], %[[CURRENT_SIZE]] : index + // CHECK-NEXT: %[[IS_DIFFERENT_SIZE:.*]] = cmpi ne, %[[CURRENT_SIZE]], %[[SAME_SIZE]] : index + // CHECK-NEXT: %[[NEW_SAME_SIZE:.*]] = select %[[IS_DIFFERENT_SIZE]], %[[CURRENT_SIZE]], %[[SAME_SIZE]] : index + // CHECK-NEXT: %[[DIFFERENT_SIZES:.*]] = or %[[FALSE]], %[[IS_DIFFERENT_SIZE]] : i1 + // CHECK-NEXT: %[[IS_ONE_OUT_OF_BOUNDS:.*]] = cmpi eq, %[[RESULT_DIMENSION]], %[[MINUS_ONE]] : index + // CHECK-NEXT: %[[JUST_OUT_OF_BOUNDS:.*]] = or %[[FALSE]], %[[IS_ONE_OUT_OF_BOUNDS]] : i1 + // CHECK: %[[IS_INITIALIZED:.*]] = cmpi ne, %[[NEW_SAME_SIZE]], %[[MINUS_ONE]] : index + // CHECK-NEXT: %[[SAME_SIZE:.*]] = select %[[IS_INITIALIZED]], %[[NEW_SAME_SIZE]], %[[CURRENT_SIZE_1:.*]] : index + // CHECK-NEXT: %[[IS_DIFFERENT_SIZE:.*]] = cmpi ne, %[[CURRENT_SIZE_1]], %[[SAME_SIZE]] : index + // CHECK-NEXT: %[[FINAL_SAME_SIZE:.*]] = select %[[IS_DIFFERENT_SIZE]], %[[CURRENT_SIZE_1]], %[[SAME_SIZE]] : index + // CHECK: %[[FINAL_DIFFERENT_SIZES:.*]] = or %[[DIFFERENT_SIZES]], %[[IS_DIFFERENT_SIZE:.*]] : i1 + // CHECK: %[[FINAL_JUST_OUT_OF_BOUNDS:.*]] = or %[[JUST_OUT_OF_BOUNDS]], %[[IS_ONE_OUT_OF_BOUNDS:.*]] : i1 + // CHECK-NEXT: %[[STOP_COMBINING_DIMENSIONS:.*]] = or %[[FINAL_DIFFERENT_SIZES]], %[[FINAL_JUST_OUT_OF_BOUNDS]] : i1 + // CHECK-NEXT: %[[IF_STOP_COMBINING_DIMENSIONS:.*]]:2 = scf.if %[[STOP_COMBINING_DIMENSIONS]] -> (index, index) { + // CHECK-NEXT: %[[IS_RUNNING_PRODUCT_NOT_ONE:.*]] = cmpi ne, %[[RUNNING_PRODUCT]], %[[C1_1]] : index + // CHECK-NEXT: %[[NEW_OFFSET_1:.*]] = scf.if %[[IS_RUNNING_PRODUCT_NOT_ONE]] -> (index) { + // CHECK-NEXT: %[[NEW_OFFSET_0:.*]] = addi %[[OFFSET]], %[[C1_1]] : index + // CHECK-NEXT: %[[WAS_IN_BOUNDS:.*]] = cmpi sge, %[[RESULT_DIMENSION]], %[[MINUS_ONE]] : index + // CHECK-NEXT: scf.if %[[WAS_IN_BOUNDS]] { + // CHECK-NEXT: %[[CURRENT_DIMENSION:.*]] = subi %[[REDUCED_RANK_LHS]], %[[NEW_OFFSET_0]] : index + // CHECK-NEXT: store %[[RUNNING_PRODUCT]], %[[RESULT_LHS]][%[[CURRENT_DIMENSION]]] : memref + // CHECK-NEXT: } + // CHECK: scf.yield %[[NEW_OFFSET_0]] : index + // CHECK-NEXT: } else { + // CHECK-NEXT: scf.yield %[[OFFSET]] : index + // CHECK-NEXT: } + // CHECK-NEXT: %[[IF_DIFFERENT_SIZES:.*]]:2 = scf.if %[[FINAL_DIFFERENT_SIZES]] -> (index, index) { + // CHECK-NEXT: %[[NEW_OFFSET_2:.*]] = addi %[[NEW_OFFSET_1]], %[[C1_1]] : index + // CHECK-NEXT: %[[IS_IN_BOUNDS:.*]] = cmpi sge, %[[RESULT_DIMENSION]], %[[C0]] : index + // CHECK-NEXT: scf.if %[[IS_IN_BOUNDS]] { + // CHECK-NEXT: %[[CURRENT_DIMENSION:.*]] = subi %[[REDUCED_RANK_LHS]], %[[NEW_OFFSET_2]] : index + // CHECK-NEXT: store %[[CURRENT_SIZE]], %[[RESULT_LHS]][%[[CURRENT_DIMENSION]]] : memref + // CHECK-NEXT: } + // CHECK: scf.yield %[[C1_1]], %[[NEW_OFFSET_2]] : index, index + // CHECK-NEXT: } else { + // CHECK-NEXT: scf.yield %[[FINAL_SAME_SIZE]], %[[NEW_OFFSET_1]] : index, index + // CHECK-NEXT: } + // CHECK-NEXT: scf.yield %[[IF_DIFFERENT_SIZES]]#0, %[[IF_DIFFERENT_SIZES]]#1 : index, index + // CHECK-NEXT: } else { + // CHECK-NEXT: %[[NEW_RUNNING_PRODUCT:.*]] = muli %[[RUNNING_PRODUCT]], %[[FINAL_SAME_SIZE]] : index + // CHECK-NEXT: scf.yield %[[NEW_RUNNING_PRODUCT]], %[[OFFSET]] : index, index + // CHECK-NEXT: } + // CHECK-NEXT: scf.yield %[[IF_STOP_COMBINING_DIMENSIONS]]#0, %[[IF_STOP_COMBINING_DIMENSIONS]]#1 : index, index + // CHECK-NEXT: } + // CHECK: return %[[SUBVIEW_LHS:.*]], %[[SUBVIEW_RHS:.*]] : memref, memref + %0, %1 = chlo.minimum_broadcast_shapes %lhs, %rhs : + tensor, tensor -> tensor, tensor + return %0, %1 : tensor, tensor +} diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/invalid.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/invalid.mlir index d1e70821975..1d3d5e485fb 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/invalid.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/invalid.mlir @@ -5,12 +5,3 @@ func @alloc_raw(%ctx: !tf_framework.op_kernel_context, %size : index) { %buf = tf_framework.alloc(%ctx, %size) : memref return } - -// ----- - -func @minimum_broadcast_shapes(%lhs: tensor, %rhs: tensor) { - // expected-error @+1{{number of operand shapes 2 does not match number of result shapes 1}} - %0 = tf_framework.minimum_broadcast_shapes %lhs, %rhs : - tensor, tensor -> tensor - return -} diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/ops.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/ops.mlir index 6e67195ed6b..7ba69dcfac0 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/ops.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/ops.mlir @@ -46,11 +46,3 @@ func @null_context() { tf_framework.null_context : !tf_framework.op_kernel_context return } - -// CHECK-LABEL: func @minimum_broadcast_shapes -func @minimum_broadcast_shapes(%lhs: tensor, %rhs: tensor) - -> (tensor, tensor) { - %0, %1 = tf_framework.minimum_broadcast_shapes %lhs, %rhs : - tensor, tensor -> tensor, tensor - return %0, %1 : tensor, tensor -} diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD index cc02b3f082b..a6341f00736 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD @@ -37,6 +37,7 @@ cc_library( hdrs = ["rewriters.h"], compatible_with = get_compatible_with_cloud(), deps = [ + "//tensorflow/compiler/mlir/hlo", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc index 8ff10c9d38e..44d2a989f22 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc @@ -22,7 +22,9 @@ limitations under the License. #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h" namespace mlir { @@ -86,6 +88,289 @@ class BufferizeDimOp : public OpConversionPattern { } }; +class BufferizeAndConvertMinimumBroadcastShapesOp + : public OpConversionPattern { + public: + using OpConversionPattern< + chlo::MinimumBroadcastShapesOp>::OpConversionPattern; + + LogicalResult matchAndRewrite( + chlo::MinimumBroadcastShapesOp broadcast_shapes_op, + ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + chlo::MinimumBroadcastShapesOp::Adaptor adaptor(operands); + auto loc = broadcast_shapes_op.getLoc(); + ImplicitLocOpBuilder lb(loc, rewriter); + Value zero = lb.create(0); + SmallVector shapes = adaptor.shapes(); + size_t k = shapes.size(); + SmallVector ranks; + ranks.reserve(k); + SmallVector real_ranks; + real_ranks.reserve(k); + SmallVector leading_ones; + leading_ones.reserve(k); + + // Determine the "real" rank of each operand shape by counting leading 1's. + for (size_t i = 0; i < k; ++i) { + Value rank = lb.create(loc, shapes[i], zero); + ranks.push_back(rank); + leading_ones.push_back(CountLeadingOnes(lb, shapes[i], rank)); + Value real_rank = lb.create(rank, leading_ones[i]); + real_ranks.push_back(real_rank); + } + + // Determine the maximum real rank of the operands. + Value max_rank = real_ranks[0]; + for (size_t i = 1; i < k; ++i) { + Value rank_is_greater = + lb.create(CmpIPredicate::ugt, real_ranks[i], max_rank); + max_rank = lb.create(rank_is_greater, real_ranks[i], max_rank); + } + + // Allocate buffers for the return values and initialize them with 1's. + SmallVector result_shapes; + result_shapes.reserve(k); + auto result_type = + MemRefType::get({ShapedType::kDynamicSize}, lb.getIndexType()); + Value one = lb.create(1); + for (size_t i = 0; i < k; ++i) { + // We assume the buffer will be small, so we allocate it on the stack. + // TODO(b/181654096): Replace AllocaOp with AllocOp. + auto result = lb.create(result_type, real_ranks[i]); + lb.create(zero, real_ranks[i], one, llvm::None, + [&one, &result](OpBuilder &b, Location l, Value idx, + ValueRange /*vr*/) { + b.create(l, one, result, idx); + b.create(l, llvm::None); + }); + result_shapes.push_back(result); + } + + // Iterate through the dimensions and determine which adjacent dimensions + // can be combined. Keep a running product of the dimensions that can be + // combined as iteration variable (initialized to 1), and the current + // dimension offset in the result shapes. We iterate through the shapes + // backward, because the broadcasting semantics mean that the last + // dimensions of each shape (the least significant ones) are matched + // together. + Value running_product = one; + Value current_dimension_offset = zero; + Value two = lb.create(2); + Value max_rank_plus_two = lb.create(loc, max_rank, two); + + // Iterate from 1 to max_rank + 1 (inclusive). This iteration variable is + // used as an offset from the end of each shape vector. We iterate until + // max_rank + 1 to handle the case that we have a running_product > 1 left + // when we have processed all dimensions of the largest shape. + lb.create( + one, max_rank_plus_two, one, + ValueRange{running_product, current_dimension_offset}, + [&](OpBuilder &b, Location l, Value v, ValueRange vr) { + Value constant_false = + b.create(l, b.getI1Type(), b.getBoolAttr(false)); + Value just_out_of_bounds = constant_false; + Value different_sizes = constant_false; + Value minus_one = b.create(l, -1); + + // Initialize 'same_size' to a size that we don't expect to see. + Value same_size = minus_one; + // 'result_dimensions' stores the current dimension with an offset of + // 'leading_ones' to make it easier to check whether we are in-bounds + // with respect to the "real" shape with leading 1's removed. + SmallVector result_dimensions; + SmallVector sizes; + result_dimensions.reserve(k); + sizes.reserve(k); + + // This loop checks whether we have at least two shapes with different + // sizes at the current dimension, and whether we just ran out of + // bounds in at least one shape. + for (size_t i = 0; i < k; ++i) { + // Determine the size of the dimension. If the dimension is out of + // bounds, we choose the value 'same_size', because then the shape + // should not affect the check anymore whether there are two shapes + // with different sizes at the current dimension. + Value is_out_of_bounds = + b.create(l, CmpIPredicate::ult, real_ranks[i], v); + Value dimension = b.create(l, ranks[i], v); + Value result_dimension = + b.create(l, dimension, leading_ones[i]); + result_dimensions.push_back(result_dimension); + Value current_size = + b.create( + l, TypeRange{b.getIndexType()}, is_out_of_bounds, + [&](OpBuilder &b, Location l) { + b.create(l, same_size); + }, + [&](OpBuilder &b, Location l) { + // Using IfOp instead of SelectOp makes sure that we + // don't try to load if the dimension is out of bounds. + Value size = b.create(l, shapes[i], dimension); + b.create(l, size); + }) + .getResult(0); + sizes.push_back(current_size); + Value is_initialized = + b.create(l, CmpIPredicate::ne, same_size, minus_one); + same_size = + b.create(l, is_initialized, same_size, current_size); + Value is_different_size = + b.create(l, CmpIPredicate::ne, current_size, same_size); + same_size = b.create(l, is_different_size, current_size, + same_size); + different_sizes = + b.create(l, different_sizes, is_different_size); + Value is_one_out_of_bounds = b.create( + l, CmpIPredicate::eq, result_dimension, minus_one); + just_out_of_bounds = + b.create(l, just_out_of_bounds, is_one_out_of_bounds); + } + Value running_product = vr.front(); + Value current_dimension_offset = vr.back(); + + // We need to stop combining dimensions if we just ran out of bounds + // in one shape, or there are at least two shapes with different sizes + // at the current dimension. + Value stop_combining_dimensions = + b.create(l, different_sizes, just_out_of_bounds); + auto if_stop_combining_dimensions = b.create( + l, TypeRange{b.getIndexType(), b.getIndexType()}, + stop_combining_dimensions, + [&](OpBuilder &b, Location l) { + // If the running product is not 1, add one dimension of size + // 'running_product' to each shape that is still indexed + // in-bounds or has just gone out of bounds. + Value running_product_not_one = b.create( + l, CmpIPredicate::ne, running_product, one); + Value new_dimension_offset = + b.create( + l, TypeRange{b.getIndexType()}, + running_product_not_one, + [&](OpBuilder &b, Location l) { + Value new_dimension_offset = b.create( + l, current_dimension_offset, one); + for (size_t i = 0; i < k; ++i) { + Value was_in_bounds = b.create( + l, CmpIPredicate::sge, result_dimensions[i], + minus_one); + b.create( + l, was_in_bounds, + [&](OpBuilder &b, Location l) { + Value output_dimension = b.create( + l, real_ranks[i], new_dimension_offset); + b.create(l, running_product, + result_shapes[i], + output_dimension); + b.create(l, llvm::None); + }); + } + b.create(l, new_dimension_offset); + }, + [&](OpBuilder &b, Location l) { + b.create(l, current_dimension_offset); + }) + .getResult(0); + + // If there are at least two different sizes, copy the dimension + // size from the input to the output shapes for all shapes that + // are still indexed in-bounds. + auto if_different_sizes = b.create( + l, TypeRange{b.getIndexType(), b.getIndexType()}, + different_sizes, + [&](OpBuilder &b, Location l) { + Value dimension_offset = + b.create(l, new_dimension_offset, one); + for (size_t i = 0; i < k; ++i) { + Value is_in_bounds = b.create( + l, CmpIPredicate::sge, result_dimensions[i], zero); + b.create( + l, is_in_bounds, [&](OpBuilder &b, Location l) { + Value output_dimension = b.create( + l, real_ranks[i], dimension_offset); + b.create(l, sizes[i], result_shapes[i], + output_dimension); + b.create(l, llvm::None); + }); + } + b.create(l, + ValueRange{one, dimension_offset}); + }, + [&](OpBuilder &b, Location l) { + b.create( + l, ValueRange{same_size, new_dimension_offset}); + }); + b.create(l, if_different_sizes.getResults()); + }, + [&](OpBuilder &b, Location l) { + Value new_running_product = + b.create(l, running_product, same_size); + b.create(l, ValueRange{new_running_product, + current_dimension_offset}); + }); + b.create(l, if_stop_combining_dimensions.getResults()); + }); + for (size_t i = 0; i < k; ++i) { + result_shapes[i] = + RemoveLeadingOnesFrom1DMemref(lb, result_shapes[i], real_ranks[i]); + } + rewriter.replaceOp(broadcast_shapes_op, result_shapes); + return success(); + } + + private: + Value CountLeadingOnes(ImplicitLocOpBuilder &lb, Value extent_memref, + Value rank) const { + // Count leading 1's. Use two iteration variables for that: one with a + // boolean flag for whether every size so far was 1, one with the number of + // leading 1's. + Value constant_true = + lb.create(lb.getI1Type(), lb.getBoolAttr(true)); + Value zero = lb.create(0); + Value one = lb.create(1); + auto leading_ones_loop = lb.create( + zero, rank, one, ValueRange{constant_true, zero}, + [&](OpBuilder &b, Location l, Value idx, ValueRange vr) { + auto size = b.create(l, extent_memref, idx); + auto is_equal_to_one = + b.create(l, CmpIPredicate::eq, size, one); + auto all_ones = b.create(l, vr.front(), is_equal_to_one); + auto increased_value = b.create(l, vr.back(), one); + auto number_of_leading_ones = + b.create(l, all_ones, increased_value, vr.back()); + b.create(l, + ValueRange{all_ones, number_of_leading_ones}); + }); + return leading_ones_loop.results()[1]; + } + + Value RemoveLeadingOnesFrom1DMemref(ImplicitLocOpBuilder &lb, + Value extent_memref, Value rank) const { + Value leading_ones = CountLeadingOnes(lb, extent_memref, rank); + Value new_rank = lb.create(rank, leading_ones); + auto result_type = + MemRefType::get({ShapedType::kDynamicSize}, lb.getIndexType()); + // Ideally we would use SubView here to return a MemRef with 'leading_ones' + // as offset, but several things related to MemRef with offsets are + // currently broken, so instead we just allocate another buffer of the + // desired size and copy the elements over. We assume the buffer will be + // small, so we allocate it on the stack. + // TODO(b/181654096): Replace AllocaOp with AllocOp. + Value result = lb.create(result_type, new_rank); + Value zero = lb.create(0); + Value one = lb.create(1); + lb.create( + zero, new_rank, one, llvm::None, + [&](OpBuilder &b, Location l, Value idx, ValueRange /*vr*/) { + Value idx_with_offset = b.create(l, idx, leading_ones); + auto size = b.create(l, extent_memref, idx_with_offset); + b.create(l, size, result, idx); + b.create(l, llvm::None); + }); + return result; + } +}; + class BufferizeRankOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -102,8 +387,10 @@ class BufferizeRankOp : public OpConversionPattern { void populateExtraStdBufferizePattern(MLIRContext *context, BufferizeTypeConverter *converter, OwningRewritePatternList *patterns) { - patterns->insert( - *converter, context); + patterns + ->insert( + *converter, context); } } // namespace transforms diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc index 22e20668694..d542c104119 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc @@ -42,6 +42,7 @@ limitations under the License. #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Transforms/Bufferize.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" @@ -124,8 +125,9 @@ struct HloBufferizePass : public HloBufferizePassBase { &context, &converter, &patterns, /*insert_copy=*/false); populateFuncOpTypeConversionPattern(patterns, &context, converter); populateCallOpTypeConversionPattern(patterns, &context, converter); - populateBranchOpInterfaceAndReturnOpTypeConversionPattern( - patterns, &context, converter); + populateBranchOpInterfaceTypeConversionPattern(patterns, &context, + converter); + populateReturnOpTypeConversionPattern(patterns, &context, converter); // Configure legality and structural patterns. populateBufferizeMaterializationLegality(target); @@ -172,7 +174,8 @@ struct FinalBufferizePass : public FinalBufferizePassBase { target.addIllegalDialect(); target.addIllegalOp(); BufferizeTypeConverter converter; auto typesAreLegal = [&converter](Operation* op) { diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_memref_prints.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_memref_prints.cc index 4b859628ced..26722336d2c 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_memref_prints.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_memref_prints.cc @@ -73,12 +73,12 @@ void EmitPrint(Operation* op, Liveness& liveness, OpBuilder* b) { element_type = b->getI64Type(); memref_type = MemRefType::get(memref_type.getShape(), element_type, memref_type.getAffineMaps(), - memref_type.getMemorySpace()); + memref_type.getMemorySpaceAsInt()); memref = b->create(loc, memref, memref_type); } auto unranked_type = - UnrankedMemRefType::get(element_type, memref_type.getMemorySpace()); + UnrankedMemRefType::get(element_type, memref_type.getMemorySpaceAsInt()); Value unranked_memref = b->create(loc, memref, unranked_type); if (element_type.isF32()) { diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc index 042a8d8db0e..a761b07b7cd 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc @@ -356,7 +356,7 @@ class NullMemRefOpConverter : public ConvertOpToLLVMPattern { loc, getVoidPtrType(), sizes.front(), llvm::None); // Populate underlying ranked descriptor. - unsigned address_space = result_type.getMemorySpace(); + unsigned address_space = result_type.getMemorySpaceAsInt(); Type elem_type = result_type.getElementType(); Type llvm_elem_type = typeConverter->convertType(elem_type); Type elem_ptr_ptr_type = LLVM::LLVMPointerType::get( diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc index c3b3d99c782..b487814caa2 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc @@ -2117,7 +2117,7 @@ void LegalizeTF::runOnFunction() { } // anonymous namespace -// Creates an instance of the TensorFlow Lite dialect LegalizeTF pass. +// Creates an instance of the TensorFlow dialect LegalizeTF pass. std::unique_ptr> createLegalizeTFPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index 63c07a3fe49..6fdd7571d9f 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -124,7 +124,6 @@ cc_library( ":type_to_shape", ":xla_legalize_tf_passes_inc_gen", ":xla_legalize_tf_with_tf2xla", - ":xla_passes", "//tensorflow/compiler/mlir/hlo", "//tensorflow/compiler/mlir/hlo:chlo_legalize_to_hlo", "//tensorflow/compiler/mlir/hlo:convert_op_folder", diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc index 55e3277afda..e42d0a99cdc 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc @@ -86,12 +86,11 @@ mlir::Location GenerateInstructionLocation(HloInstruction* instruction, mlir::OpBuilder* func_builder) { const std::string& op_name = instruction->metadata().op_name(); if (op_name.empty()) { - return mlir::NameLoc::get(func_builder->getIdentifier(instruction->name()), - func_builder->getContext()); + return mlir::NameLoc::get(func_builder->getIdentifier(instruction->name())); } - mlir::Location op_name_loc = mlir::NameLoc::get( - func_builder->getIdentifier(op_name), func_builder->getContext()); + mlir::Location op_name_loc = + mlir::NameLoc::get(func_builder->getIdentifier(op_name)); const std::string& source_file = instruction->metadata().source_file(); if (source_file.empty()) { return op_name_loc; diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index 482fc379bf0..9658bdbc9a3 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -4625,7 +4625,7 @@ class ConvertInfeedDequeueTupleOp /*layout=*/layout); // TODO(b/171212005): Reenable layout. - data_and_token.removeAttr("layout"); + data_and_token->removeAttr("layout"); if (op._XlaSharding().hasValue()) { // _XlaSharding attribute in TF is a serialized string of the OpSharding diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc index bae0196fc63..722aad59c2b 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc @@ -130,6 +130,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h index 4373060e3a3..d5b51b81870 100644 --- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h +++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h @@ -227,8 +227,7 @@ class LhloDialectEmitter : public xla::ConstDfsHloVisitorWithDefault { // Return an MLIR location for an HLO instruction. Location getLocation(const xla::HloInstruction* inst) { - return NameLoc::get(builder_.getIdentifier(inst->name()), - builder_.getContext()); + return NameLoc::get(builder_.getIdentifier(inst->name())); } // This map provides access to MLIR buffers for each HLO buffer allocation. diff --git a/tensorflow/compiler/mlir/xla/transforms/passes.h b/tensorflow/compiler/mlir/xla/transforms/passes.h index 77ba879e000..f56ef27a2ab 100644 --- a/tensorflow/compiler/mlir/xla/transforms/passes.h +++ b/tensorflow/compiler/mlir/xla/transforms/passes.h @@ -51,7 +51,7 @@ std::unique_ptr> createLegalizeTfWithTf2XlaPass( /// Replaces types that do not exist in MHLO with equivalent types that do /// exist. -std::unique_ptr> CreateLegalizeTfTypesPass(); +std::unique_ptr> CreateLegalizeTfTypesPass(); /// Adds the TF to XLA via TF2XLA rewrite patterns to the pattern list. void PopulateLegalizeTfWithTf2XlaPatterns(llvm::StringRef device_type, diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index 64ac229d905..f68eee7e813 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -580,6 +580,7 @@ cc_library( copts = tf_copts(), deps = [ "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", "//tensorflow/core:framework", "//tensorflow/core:graph", "//tensorflow/core:lib_proto_parsing", diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc index e7e5cef8b86..e2fa7f30873 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc @@ -481,22 +481,27 @@ Status CreateTRTNode(const ConversionParams& params, NodeDef trt_node; NameAttrList function; function.set_name(StrCat(info.engine_name, "_native_segment")); - Status status = - node_builder.Attr("input_shapes", input_shape_protos) - .Attr("static_engine", - info.engine_type == EngineInfo::EngineType::TRTStatic) - .Attr("segment_func", function) - .Attr("serialized_segment", segment_string) - .Attr("calibration_data", "") - .Attr("max_cached_engines_count", info.maximum_cached_engines) - .Attr("workspace_size_bytes", info.max_workspace_size_bytes) - .Attr("max_batch_size", max_batch_size) - .Attr("precision_mode", prec_string) - .Attr("use_calibration", info.use_calibration) - .Attr("_use_implicit_batch", params.use_implicit_batch) - .Attr("_allow_build_at_runtime", info.allow_build_at_runtime) - .Attr("OutT", out_types) - .Finalize(&trt_node); + node_builder.Attr("input_shapes", input_shape_protos) + .Attr("static_engine", + info.engine_type == EngineInfo::EngineType::TRTStatic) + .Attr("segment_func", function) + .Attr("serialized_segment", segment_string) + .Attr("calibration_data", "") + .Attr("max_cached_engines_count", info.maximum_cached_engines) + .Attr("workspace_size_bytes", info.max_workspace_size_bytes) + .Attr("max_batch_size", max_batch_size) + .Attr("precision_mode", prec_string) + .Attr("use_calibration", info.use_calibration) + .Attr("_use_implicit_batch", params.use_implicit_batch) + .Attr("_allow_build_at_runtime", info.allow_build_at_runtime) + .Attr("OutT", out_types); + + if (!params.use_implicit_batch) { + node_builder.Attr("profile_strategy", + ProfileStrategyToName(params.profile_strategy)); + } + + Status status = node_builder.Finalize(&trt_node); if (!status.ok()) { LOG(ERROR) << "Node construction failed with" << status; return status; diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h index d3897e864fa..43a551e01bc 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/grappler_item.h" @@ -46,6 +47,7 @@ struct ConversionParams { int max_cached_engines = 1; bool use_calibration = true; bool use_implicit_batch = true; + ProfileStrategy profile_strategy = ProfileStrategy::kRange; bool allow_build_at_runtime = true; }; diff --git a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc index 12fea3ade40..324ba0cf682 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc @@ -76,6 +76,10 @@ Status TRTOptimizationPass::Init( if (params.count("use_implicit_batch")) { use_implicit_batch_ = params.at("use_implicit_batch").b(); } + if (params.count("profile_strategy")) { + TF_RETURN_IF_ERROR(ProfileStrategyFromName( + params.at("profile_strategy").s(), &profile_strategy_)); + } return Status::OK(); } @@ -242,6 +246,7 @@ Status TRTOptimizationPass::Optimize(grappler::Cluster* cluster, cp.max_cached_engines = max_cached_batches_; cp.use_calibration = use_calibration_; cp.use_implicit_batch = use_implicit_batch_; + cp.profile_strategy = profile_strategy_; cp.allow_build_at_runtime = allow_build_at_runtime_; auto status = ConvertAfterShapes(cp); VLOG(1) << "Returning from " << name_; diff --git a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h index e0aaa5500ab..fd984e5772c 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h +++ b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h @@ -42,6 +42,7 @@ class TRTOptimizationPass : public grappler::CustomGraphOptimizer { max_workspace_size_bytes_(256LL << 20), use_calibration_(true), use_implicit_batch_(true), + profile_strategy_(ProfileStrategy::kRange), allow_build_at_runtime_(true) { VLOG(1) << "Constructing " << name_; } @@ -75,6 +76,7 @@ class TRTOptimizationPass : public grappler::CustomGraphOptimizer { int64_t max_workspace_size_bytes_; bool use_calibration_; bool use_implicit_batch_; + ProfileStrategy profile_strategy_; bool allow_build_at_runtime_; }; diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.cc b/tensorflow/compiler/tf2tensorrt/convert/utils.cc index a4bd8d5afd9..34cbaf9a15e 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/utils.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/utils.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "absl/strings/ascii.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/errors.h" @@ -257,6 +258,35 @@ int GetNumberOfEngineInputs(const nvinfer1::ICudaEngine* engine) { return n_input / n_profiles; } +string ProfileStrategyToName(const ProfileStrategy strategy) { + switch (strategy) { + case ProfileStrategy::kRange: + return "Range"; + case ProfileStrategy::kOptimal: + return "Optimal"; + case ProfileStrategy::kRangeOptimal: + return "Range+Optimal"; + case ProfileStrategy::kImplicitBatchModeCompatible: + return "ImplicitBatchModeCompatible"; + } + return "Unknown"; +} + +Status ProfileStrategyFromName(const string& name, ProfileStrategy* strategy) { + if (name == "range") { + *strategy = ProfileStrategy::kRange; + } else if (name == "optimal") { + *strategy = ProfileStrategy::kOptimal; + } else if (name == "range+optimal") { + *strategy = ProfileStrategy::kRangeOptimal; + } else if (name == "implicitbatchmodecompatible") { + *strategy = ProfileStrategy::kImplicitBatchModeCompatible; + } else { + return errors::InvalidArgument("Invalid profile strategy: ", name); + } + return Status::OK(); +} + #endif absl::string_view GetDeviceName(const Node* node) { diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.h b/tensorflow/compiler/tf2tensorrt/convert/utils.h index 29b32b6514d..f9cb293a3db 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/utils.h +++ b/tensorflow/compiler/tf2tensorrt/convert/utils.h @@ -200,6 +200,17 @@ absl::optional MergeIfCompatible( absl::optional MergeIfCompatible( const DeviceNameUtils::ParsedName& a, absl::string_view b); +// Optimization profile generation strategies. +enum class ProfileStrategy { + kRange, + kOptimal, + kRangeOptimal, + kImplicitBatchModeCompatible, +}; + +string ProfileStrategyToName(const ProfileStrategy strategy); +Status ProfileStrategyFromName(const string& name, ProfileStrategy* strategy); + #endif // GOOGLE_CUDA && GOOGLE_TENSORRT } // namespace tensorrt diff --git a/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc index a2a41f5a03c..22eebdcf884 100644 --- a/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc +++ b/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc @@ -56,7 +56,8 @@ REGISTER_OP("TRTEngineOp") .Attr("cached_engine_batches: list(int) >= 0 = []") .Attr("fixed_input_size: bool = true") .Attr("output_shapes: list(shape) = []") - .Attr("static_engine: bool = true"); + .Attr("static_engine: bool = true") + .Attr("profile_strategy: string = ''"); } // namespace tensorflow #endif // GOOGLE_CUDA && GOOGLE_TENSORRT diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.cc index ca25a5840f5..9b3bf6b5acc 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.cc @@ -109,6 +109,10 @@ void TrtShapeOptimizationProfile::InitProfiles( VLOG(1) << "Creating profiles with ImplicitBatchModeCompatible strategy"; ImplicitBatchModeCompatibleStrategy(); break; + // Treat all other strategies the same as kOptimal for now. Implementing + // those is outlined in the dynamic shape support implementation plan. + case ProfileStrategy::kRange: + case ProfileStrategy::kRangeOptimal: case ProfileStrategy::kOptimal: VLOG(1) << "Creating profiles with Optimal strategy"; OptimalStrategy(); diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h b/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h index 71d7d8b1667..854b94dbdd7 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h @@ -115,12 +115,6 @@ struct OptimizationProfileConfig { } }; -// Optimization profile generation strategies. -enum class ProfileStrategy { - kImplicitBatchModeCompatible, - kOptimal, -}; - // Manages Optimization profiles during TRT Engine construction. // // An optimization profile describes a range of dimensions for each TRT network diff --git a/tensorflow/compiler/tf2xla/kernels/xla_self_adjoint_eig_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_self_adjoint_eig_op.cc index 233ac8e7b45..59ee5ef11e0 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_self_adjoint_eig_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_self_adjoint_eig_op.cc @@ -57,10 +57,8 @@ class SelfAdjointEigV2Op : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("XlaSelfAdjointEig").TypeConstraint("T", kFloatTypes), - XlaSelfAdjointEigOp); -REGISTER_XLA_OP(Name("SelfAdjointEigV2").TypeConstraint("T", kFloatTypes), - SelfAdjointEigV2Op); +REGISTER_XLA_OP(Name("XlaSelfAdjointEig"), XlaSelfAdjointEigOp); +REGISTER_XLA_OP(Name("SelfAdjointEigV2"), SelfAdjointEigV2Op); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index 217bb19f952..ea07ac3d5d2 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -25,6 +25,11 @@ limitations under the License. #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/lib/core/errors.h" +// Note: Most of the operators defined in this module are used by the jax2tf +// converter (see go/jax2tf for details) and are used in SavedModel produced +// by jax2tf. Hence, we need to maintain backwards compatibility for these +// operators. Please reach out to the JAX team if you want to make changes. + namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index 1176490adc7..2ba2246b376 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -18,9 +18,10 @@ It is sometimes useful to be able to build HLO programs directly from TensorFlow. This file provides Tensorflow operators that mirror the semantics of HLO operators as closely as possible. -Note: There is no promise of backward or forward compatibility for operators -defined in this module. This is primarily because the underlying HLO operators -do not promise backward or forward compatibility. +Note: Most of the operators defined in this module are used by the jax2tf +converter (see go/jax2tf for details) and are used in SavedModel produced +by jax2tf. Hence, we need to maintain backwards compatibility for these +operators. Please reach out to the JAX team if you want to make changes. """ from __future__ import absolute_import diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index 5dc792d3288..1587dcc73f9 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -457,6 +457,7 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/core:lib", @@ -466,18 +467,15 @@ cc_library( xla_test( name = "self_adjoint_eig_test", srcs = ["self_adjoint_eig_test.cc"], - disabled_backends = [ - "cpu", - "gpu", - ], real_hardware_only = True, - shard_count = 10, tags = ["optonly"], deps = [ ":arithmetic", ":constants", + ":math", ":matrix", ":self_adjoint_eig", + "//tensorflow/compiler/xla:array", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:literal", diff --git a/tensorflow/compiler/xla/client/lib/logdet.cc b/tensorflow/compiler/xla/client/lib/logdet.cc index d2cdc230065..69940ce0b87 100644 --- a/tensorflow/compiler/xla/client/lib/logdet.cc +++ b/tensorflow/compiler/xla/client/lib/logdet.cc @@ -34,13 +34,18 @@ limitations under the License. namespace xla { -// log(det(A)) = sum(log(vecdiag(QR(A).r))), since R is triangular and Q is -// orthonormal -XlaOp LogDet(XlaOp a) { - return a.builder()->ReportErrorOrReturn([&]() -> StatusOr { +SignAndLogDet SLogDet(XlaOp a) { + StatusOr result = [&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape a_shape, a.builder()->GetShape(a)); auto qr = Qr(a); + int64 m = ShapeUtil::GetDimension(a_shape, -2); + int64 n = ShapeUtil::GetDimension(a_shape, -1); + if (m != n) { + return InvalidArgument( + "Arguments to logdet must be (batched) square matrices, got: %s", + a_shape.ToString()); + } // Get the sign and logarithm of the determinant based on the values along // the diagonal of R and the number of zeros in taus. auto log_abs_det = Einsum(Log(Abs(qr.q_and_r)), "...aa->..."); @@ -49,14 +54,27 @@ XlaOp LogDet(XlaOp a) { One(a.builder(), a_shape.element_type()), CreateScalarMultiplyComputation(a_shape.element_type(), a.builder()), {a_shape.rank() - 2}); + auto sliced_taus = SliceInMinorDims(qr.taus, {0}, {n - 1}); auto sign_taus = Reduce( - Select(Eq(qr.taus, ZerosLike(qr.taus)), FullLike(qr.taus, -1), - FullLike(qr.taus, 1)), + Select(Ne(sliced_taus, ZerosLike(sliced_taus)), + FullLike(sliced_taus, -1), FullLike(sliced_taus, 1)), One(a.builder(), a_shape.element_type()), CreateScalarMultiplyComputation(a_shape.element_type(), a.builder()), {a_shape.rank() - 2}); - return sign_diag * log_abs_det * sign_taus; - }); + return SignAndLogDet{sign_diag * sign_taus, log_abs_det}; + }(); + if (!result.ok()) { + XlaOp error = a.builder()->ReportError(result.status()); + return SignAndLogDet{error, error}; + } + return result.ValueOrDie(); +} + +XlaOp LogDet(XlaOp a) { + SignAndLogDet slogdet = SLogDet(a); + return Select( + Ge(slogdet.sign, ZerosLike(slogdet.sign)), slogdet.logdet, + FullLike(slogdet.logdet, std::numeric_limits::quiet_NaN())); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/logdet.h b/tensorflow/compiler/xla/client/lib/logdet.h index 96e598a6475..83d9365b9f7 100644 --- a/tensorflow/compiler/xla/client/lib/logdet.h +++ b/tensorflow/compiler/xla/client/lib/logdet.h @@ -20,8 +20,16 @@ limitations under the License. namespace xla { -// For matrix a with shape [..., n, n], return log(det(a)) with shape[...]. -// Only hermitian positive definite matrices are supported. +// Computes the sign and logarithm of the absolute value of the determinant +// of a batch of square matrices with shape [..., n, n]. +struct SignAndLogDet { + XlaOp sign; // Either 1, 0, or -1, depending on the determinant's sign. + XlaOp logdet; // log(abs(det(a)). +}; +SignAndLogDet SLogDet(XlaOp a); + +// For a batch of matrices with shape [..., n, n], return log(det(a)). +// Returns NaN if a matrix has a negative determinant. XlaOp LogDet(XlaOp a); } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/logdet_test.cc b/tensorflow/compiler/xla/client/lib/logdet_test.cc index 588b6af74c9..a023edc9aec 100644 --- a/tensorflow/compiler/xla/client/lib/logdet_test.cc +++ b/tensorflow/compiler/xla/client/lib/logdet_test.cc @@ -41,14 +41,17 @@ XLA_TEST_F(LogDetTest, Simple) { {10, 63, 166, 310}, }); - float expected = 14.1601f; - xla::XlaOp a; auto a_data = CreateR2Parameter(a_vals, 0, "a", &builder, &a); - xla::LogDet(a); - - ComputeAndCompareR0(&builder, expected, {a_data.get()}, - xla::ErrorSpec(1e-4)); + xla::SignAndLogDet slogdet = xla::SLogDet(a); + xla::XlaOp logdet = xla::LogDet(a); + xla::Tuple(&builder, {slogdet.sign, slogdet.logdet, logdet}); + xla::Literal expected = xla::LiteralUtil::MakeTupleOwned( + xla::LiteralUtil::CreateR0(1.f), + xla::LiteralUtil::CreateR0(14.1601f), + xla::LiteralUtil::CreateR0(14.1601f)); + ComputeAndCompareLiteral(&builder, expected, {a_data.get()}, + xla::ErrorSpec(1e-4)); } XLA_TEST_F(LogDetTest, SimpleTriangle) { @@ -61,14 +64,18 @@ XLA_TEST_F(LogDetTest, SimpleTriangle) { {4, 6, 8, 320}, }); - float expected = 15.9131355f; - xla::XlaOp a; auto a_data = CreateR2Parameter(a_vals, 0, "a", &builder, &a); - xla::LogDet(a); + xla::SignAndLogDet slogdet = xla::SLogDet(a); + xla::XlaOp logdet = xla::LogDet(a); + xla::Tuple(&builder, {slogdet.sign, slogdet.logdet, logdet}); + xla::Literal expected = xla::LiteralUtil::MakeTupleOwned( + xla::LiteralUtil::CreateR0(1.f), + xla::LiteralUtil::CreateR0(15.9131355f), + xla::LiteralUtil::CreateR0(15.9131355f)); - ComputeAndCompareR0(&builder, expected, {a_data.get()}, - xla::ErrorSpec(1e-4)); + ComputeAndCompareLiteral(&builder, expected, {a_data.get()}, + xla::ErrorSpec(1e-4)); } XLA_TEST_F(LogDetTest, SimpleBatched) { @@ -87,16 +94,68 @@ XLA_TEST_F(LogDetTest, SimpleBatched) { {8, 82, 456, 106}, {12, 48, 106, 62}, }, + {{2, 2, 3, 4}, {4, 5, 6, 7}, {7, 8, 9, 8}, {10, 11, 12, 13}}, + {{0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}}, }); - std::vector expected = {14.1601, 14.3092}; - xla::XlaOp a; auto a_data = CreateR3Parameter(a_vals, 0, "a", &builder, &a); - xla::LogDet(a); + xla::SignAndLogDet slogdet = xla::SLogDet(a); + xla::XlaOp logdet = xla::LogDet(a); + xla::Tuple(&builder, {slogdet.sign, slogdet.logdet, logdet}); + xla::Literal expected = xla::LiteralUtil::MakeTupleOwned( + xla::LiteralUtil::CreateR1({1.f, 1.f, -1.f, 0.f}), + xla::LiteralUtil::CreateR1( + {14.1601f, 14.3092f, 2.4849f, + -std::numeric_limits::infinity()}), + xla::LiteralUtil::CreateR1( + {14.1601f, 14.3092f, std::numeric_limits::quiet_NaN(), + -std::numeric_limits::infinity()})); - ComputeAndCompareR1(&builder, expected, {a_data.get()}, - xla::ErrorSpec(1e-4)); + ComputeAndCompareLiteral(&builder, expected, {a_data.get()}, + xla::ErrorSpec(1e-4)); +} + +XLA_TEST_F(LogDetTest, LogdetOfLargerMatricesBatched) { + xla::XlaBuilder builder(TestName()); + + xla::Array a_vals = { + {{7.2393, 1.1413, 4.1883, -4.8272, 3.2831, -0.0568, -2.4776}, + {0.4347, 3.4095, 1.6259, -4.7100, 1.5942, 1.4217, -2.8009}, + {3.6964, 0.4882, 6.5276, -1.2128, 1.3851, 0.7417, -3.8515}, + {-3.7986, -5.1188, -1.9410, 14.0205, -5.4515, 3.1831, 5.1488}, + {1.5621, 3.0426, 1.4819, -4.5938, 10.1397, 4.9312, -2.8351}, + {-1.5436, -0.0287, -0.1139, 4.4499, 2.5894, 6.1216, 2.7201}, + {-3.7241, -2.7670, -3.8162, 4.5961, -1.7251, -0.4190, 8.6562}}, + + {{3.3789, -2.3607, -1.2471, 2.1503, 0.6062, -0.6057, 1.7748}, + {-1.8670, 11.0947, 0.1229, 0.0599, 3.1714, -4.7941, -4.5442}, + {-0.6905, -0.0829, 5.2156, 2.9528, 2.6200, 6.1638, 1.8652}, + {3.0521, 2.2174, 0.7444, 10.7268, 0.6443, -2.7732, 1.6840}, + {1.8479, 3.0821, 4.5671, 2.9254, 6.1338, 5.2066, 2.3662}, + {-0.0360, -5.5341, 5.9687, -0.3297, 2.1174, 13.0016, 4.0118}, + {0.4380, -4.6683, 3.1548, 0.0924, 0.7176, 6.4679, 6.1819}}, + + {{10.0487, 4.0350, -0.8471, -1.2887, -0.8172, -3.3698, 1.3191}, + {4.8678, 4.6081, 0.8419, -0.2454, -3.2599, -1.2386, 2.4070}, + {1.4877, 0.8362, 2.6077, 1.1782, -0.1116, 1.7130, -1.1883}, + {-0.9245, -0.7435, -0.9456, 2.5936, 1.9887, -0.1324, -0.1453}, + {0.2918, -0.5301, -0.8775, 1.0478, 8.9262, 2.4731, -0.4393}, + {-3.5759, -1.5619, 2.4410, 1.3046, 4.2678, 7.3587, -4.0935}, + {-1.1187, 0.9150, -1.8253, 0.0390, -2.5684, -4.0778, 4.1447}}}; + + xla::XlaOp a; + auto a_data = CreateParameter(a_vals, 0, "a", &builder, &a); + xla::SignAndLogDet slogdet = xla::SLogDet(a); + xla::XlaOp logdet = xla::LogDet(a); + xla::Tuple(&builder, {slogdet.sign, slogdet.logdet, logdet}); + xla::Literal expected = xla::LiteralUtil::MakeTupleOwned( + xla::LiteralUtil::CreateR1({1.f, 1.f, 1.f}), + xla::LiteralUtil::CreateR1({8.93788053, 6.77846303, 7.4852403}), + xla::LiteralUtil::CreateR1({8.93788053, 6.77846303, 7.4852403})); + + ComputeAndCompareLiteral(&builder, expected, {a_data.get()}, + xla::ErrorSpec(1e-4)); } } // namespace diff --git a/tensorflow/compiler/xla/client/lib/matrix.cc b/tensorflow/compiler/xla/client/lib/matrix.cc index dbb73602801..34ba9e1f80f 100644 --- a/tensorflow/compiler/xla/client/lib/matrix.cc +++ b/tensorflow/compiler/xla/client/lib/matrix.cc @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -235,6 +236,36 @@ XlaOp UpperTriangle(XlaOp x) { return Triangle(x, false); } XlaOp LowerTriangle(XlaOp x) { return Triangle(x, true); } +XlaOp Symmetrize(XlaOp x, bool lower) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + if (shape.rank() < 2) { + return InvalidArgument( + "Argument to symmetrize must have >= 2 dimensions, got %s", + shape.ToString()); + } + const int64 m = ShapeUtil::GetDimension(shape, -2); + const int64 n = ShapeUtil::GetDimension(shape, -1); + if (m != n) { + return InvalidArgument( + "The two most minor dimensions of the argument to symmetrize must be " + "equal size, got %s", + shape.ToString()); + } + auto mask = lower ? TriangleMask(x, 0) : Not(TriangleMask(x, -1)); + if (primitive_util::IsComplexType(shape.element_type())) { + auto re = Select(mask, Real(x), TransposeInMinorDims(Real(x))); + auto im_mask = lower ? TriangleMask(x, -1) : Not(TriangleMask(x, 0)); + auto im = Select(im_mask, Imag(x), ZerosLike(Imag(x))); + im = Select(mask, im, -TransposeInMinorDims(im)); + return Complex(re, im); + } else { + return Select(mask, x, TransposeInMinorDims(x)); + } + }); +} + namespace { absl::optional, 3>> EinsumDiagonalLabels( absl::Span config) { diff --git a/tensorflow/compiler/xla/client/lib/matrix.h b/tensorflow/compiler/xla/client/lib/matrix.h index 1a9f72dedf2..42b21f243f3 100644 --- a/tensorflow/compiler/xla/client/lib/matrix.h +++ b/tensorflow/compiler/xla/client/lib/matrix.h @@ -50,8 +50,8 @@ XlaOp GetMatrixDiagonalViaGather(XlaOp x, int k = 0); // Places diag along the kth diagonal of target. XlaOp SetMatrixDiagonal(XlaOp matrix, XlaOp diag, int k = 0); -// Returns a lower-triangular mask, i.e., true below the `diagonal`-th diagonal -// and false above that diagonal. +// Returns a lower-triangular mask, i.e., true below and including the +// `diagonal`-th diagonal and false above that diagonal. XlaOp TriangleMask(XlaOp x, int diagonal); // Get the upper or lower triangle part of the last two dimensions @@ -63,6 +63,13 @@ XlaOp UpperTriangle(XlaOp x); // Get the lower triangle part of the last two dimensions XlaOp LowerTriangle(XlaOp x); +// If x is an array of shape [..., n, n], symmetrizes the matrix by replacing +// the upper triangle with the transpose of the lower triangle (if lower is +// True, vice-versa otherwise). If the type of `x` is complex, makes the matrix +// Hermitian by taking the conjugate of the complex part and setting the +// complex diagonal to zero. +XlaOp Symmetrize(XlaOp x, bool lower); + // Multiplies slices of two tensors in batches. // Multiplies all slices of `Tensor` `x` and `y` (each slice can be diff --git a/tensorflow/compiler/xla/client/lib/matrix_test.cc b/tensorflow/compiler/xla/client/lib/matrix_test.cc index 628447c289e..85e074b82cf 100644 --- a/tensorflow/compiler/xla/client/lib/matrix_test.cc +++ b/tensorflow/compiler/xla/client/lib/matrix_test.cc @@ -73,6 +73,54 @@ XLA_TEST_F(MatrixTest, Triangle) { ComputeAndCompareR3(&builder, expected, {a_data.get()}); } +XLA_TEST_F(MatrixTest, Symmetrize) { + for (bool lower : {false, true}) { + XlaBuilder builder(TestName()); + float nan = std::numeric_limits::quiet_NaN(); + Array input = { + {1, nan, nan}, + {2, 3, nan}, + {4, 5, 6}, + }; + + XlaOp a; + auto a_data = CreateParameter(input, 0, "a", &builder, &a); + Symmetrize(lower ? a : TransposeInMinorDims(a), /*lower=*/lower); + + Array expected = { + {1, 2, 4}, + {2, 3, 5}, + {4, 5, 6}, + }; + + ComputeAndCompare(&builder, expected, {a_data.get()}); + } +} + +XLA_TEST_F(MatrixTest, SymmetrizeComplex) { + for (bool lower : {false, true}) { + XlaBuilder builder(TestName()); + float nan = std::numeric_limits::quiet_NaN(); + Array input = { + {complex64{1, nan}, nan, nan}, + {complex64{2, 7}, complex64{3, nan}, nan}, + {complex64{4, 8}, complex64{5, 9}, complex64{6, nan}}, + }; + + XlaOp a; + auto a_data = CreateParameter(input, 0, "a", &builder, &a); + Symmetrize(lower ? a : Conj(TransposeInMinorDims(a)), /*lower=*/lower); + + Array expected = { + {1, complex64{2, -7}, complex64{4, -8}}, + {complex64{2, 7}, 3, complex64{5, -9}}, + {complex64{4, 8}, complex64{5, 9}, 6}, + }; + + ComputeAndCompare(&builder, expected, {a_data.get()}); + } +} + template void MatrixTest::TestMatrixDiagonal() { XlaBuilder builder("SetMatrixDiagonal"); diff --git a/tensorflow/compiler/xla/client/lib/self_adjoint_eig.cc b/tensorflow/compiler/xla/client/lib/self_adjoint_eig.cc index 58905e4ca6f..d77c44de5f6 100644 --- a/tensorflow/compiler/xla/client/lib/self_adjoint_eig.cc +++ b/tensorflow/compiler/xla/client/lib/self_adjoint_eig.cc @@ -27,444 +27,65 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { -namespace { - -// Jacobi rotation (also known as Givens rotation): -// G = [[ c, s], -// [-s, c]] -// matmul(G_T, G) = I -struct JacobiRotation { - XlaOp c; // cosine. - XlaOp s; // sine. -}; - -// JacobiUpdate holds the intermediate orthogonal matrix, Jacobi-rotated matrix. -struct JacobiUpdate { - XlaOp v; - XlaOp w; -}; - -struct FrobeniusNorms { - XlaOp off_diagonal_norm; - XlaOp total_norm; -}; - -// Given an n-by-n symmetric A and integers p and q that satisfy 0 <= p < q < n, -// it computes a rotation matrix G = [[c, s], [-s, c]], such that -// G_T * A[[p, q], [p, q]] * G -// is diagonalized. -// -// def sym_schur2x2(A, p, q): -// if np.abs(A[p, q]) > 1e-6: -// tau = (A[q, q] - A[p, p]) / (2 * A[p, q]) -// if tau >= 0: -// t = 1.0 / (tau + np.sqrt(1 + tau ** 2)) -// else: -// t = -1.0 / (-tau + np.sqrt(1 + tau ** 2)) -// c = 1.0 / np.sqrt(1.0 + t ** 2) -// s = t * c -// else: -// c = 1.0 -// s = 0.0 -// return c, s -StatusOr SymmetricShurDecomposition2x2(XlaOp a, XlaOp p, - XlaOp q, XlaOp tol) { - XlaBuilder* builder = a.builder(); - TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); - - auto zero = ScalarLike(a, 0.0); - auto one = ScalarLike(a, 1.0); - auto two = ScalarLike(a, 2.0); - - auto pqs = DynamicSliceInMinorDims(a, {p, q}, {1, 1}); - - auto ps = DynamicSliceInMinorDims(a, {p, p}, {1, 1}); - auto qs = DynamicSliceInMinorDims(a, {q, q}, {1, 1}); - - auto tau = (qs - ps) / (pqs * two); - auto t_pos = one / (tau + Sqrt(one + Square(tau))); - auto t_neg = -one / (-tau + Sqrt(one + Square(tau))); - auto t = Select(Ge(tau, zero), t_pos, t_neg); - - auto c_temp = Rsqrt(one + Square(t)); - auto s_temp = t * c_temp; - - auto c = Select(Ge(Abs(pqs), tol), c_temp, ZerosLike(c_temp) + one); - auto s = Select(Ge(Abs(pqs), tol), s_temp, ZerosLike(s_temp)); - // Renormalize c and s to compensate for low precision arithmetic, this step - // is redundant if high precision float is used, like float64. - auto rnorm = Rsqrt(Square(c) + Square(s)); - - JacobiRotation schur; - - schur.c = c * rnorm; - schur.s = s * rnorm; - - return schur; -} - -StatusOr Update(JacobiUpdate jacobi_update, XlaOp p, XlaOp q, - XlaOp tol, int64 n) { - XlaBuilder* builder = jacobi_update.w.builder(); - TF_ASSIGN_OR_RETURN(JacobiRotation schur, SymmetricShurDecomposition2x2( - jacobi_update.w, p, q, tol)); - - TF_ASSIGN_OR_RETURN(Shape w_shape, builder->GetShape(jacobi_update.w)); - const std::vector batch_dims(w_shape.dimensions().begin(), - w_shape.dimensions().end() - 2); - const int64 num_dims = w_shape.rank(); - - auto zero = ScalarLike(p, 0); - - XlaOp c = schur.c; - XlaOp s = schur.s; - - auto slice_p = DynamicSliceInMinorDims(jacobi_update.w, {p, zero}, {1, n}); - auto slice_q = DynamicSliceInMinorDims(jacobi_update.w, {q, zero}, {1, n}); - - auto slice_p_new = c * slice_p - s * slice_q; - auto slice_q_new = s * slice_p + c * slice_q; - - jacobi_update.w = - DynamicUpdateSliceInMinorDims(jacobi_update.w, slice_p_new, {p, zero}); - jacobi_update.w = - DynamicUpdateSliceInMinorDims(jacobi_update.w, slice_q_new, {q, zero}); - - slice_p = DynamicSliceInMinorDims(jacobi_update.w, {zero, p}, {n, 1}); - slice_q = DynamicSliceInMinorDims(jacobi_update.w, {zero, q}, {n, 1}); - - slice_p_new = c * slice_p - s * slice_q; - slice_q_new = s * slice_p + c * slice_q; - - jacobi_update.w = - DynamicUpdateSliceInMinorDims(jacobi_update.w, slice_p_new, {zero, p}); - jacobi_update.w = - DynamicUpdateSliceInMinorDims(jacobi_update.w, slice_q_new, {zero, q}); - - // Zero out a_{pq} explicitly. - std::vector pq_dims(batch_dims.begin(), batch_dims.end()); - pq_dims.push_back(1); - pq_dims.push_back(1); - auto pq_zero = ScalarLike(jacobi_update.w, 0.0); - auto pq_zeros = Broadcast(pq_zero, pq_dims); - jacobi_update.w = - DynamicUpdateSliceInMinorDims(jacobi_update.w, pq_zeros, {p, q}); - jacobi_update.w = - DynamicUpdateSliceInMinorDims(jacobi_update.w, pq_zeros, {q, p}); - - slice_p = DynamicSliceInMinorDims(jacobi_update.v, {zero, p}, {n, 1}); - slice_q = DynamicSliceInMinorDims(jacobi_update.v, {zero, q}, {n, 1}); - - std::vector broadcast_dims(batch_dims.size()); - std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0); - broadcast_dims.push_back(num_dims - 1); - - // Renormalize the p-th and q-th columns. This step is redundant if high - // precision floats are used, like 64-bit float. But for 32-bit float, it - // becomes necessary. This step will not increase the overall complexity. - slice_p_new = c * slice_p - s * slice_q; - slice_p_new = Mul( - slice_p_new, - Rsqrt(Reduce(Square(slice_p_new), pq_zero, - CreateScalarAddComputation(w_shape.element_type(), builder), - {num_dims - 2})), - broadcast_dims); - slice_q_new = s * slice_p + c * slice_q; - slice_q_new = Mul( - slice_q_new, - Rsqrt(Reduce(Square(slice_q_new), pq_zero, - CreateScalarAddComputation(w_shape.element_type(), builder), - {num_dims - 2})), - broadcast_dims); - - jacobi_update.v = - DynamicUpdateSliceInMinorDims(jacobi_update.v, slice_p_new, {zero, p}); - jacobi_update.v = - DynamicUpdateSliceInMinorDims(jacobi_update.v, slice_q_new, {zero, q}); - - return jacobi_update; -} - -StatusOr ComputeFrobeniusNorms(XlaOp w) { - XlaBuilder* builder = w.builder(); - TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(w)); - const int64 num_dims = shape.rank(); - auto frobenius_norm = - Sqrt(Reduce(Square(w), ScalarLike(w, 0.0), - CreateScalarAddComputation(shape.element_type(), builder), - {num_dims - 2, num_dims - 1})); - auto diag = GetMatrixDiagonal(w); - auto diag_square = - Reduce(Square(diag), ScalarLike(w, 0.0), - CreateScalarAddComputation(shape.element_type(), builder), - {num_dims - 2}); - - FrobeniusNorms frobenius_norms; - - frobenius_norms.off_diagonal_norm = - Sqrt(Max(Square(frobenius_norm) - diag_square, ScalarLike(w, 0.0))); - frobenius_norms.total_norm = frobenius_norm; - - return frobenius_norms; -} - -StatusOr> WhileLoopFn( - absl::Span initial_values, // - int matrix_dimension, // - int max_sweep_updates, // - PrimitiveType index_type, // - absl::string_view name, // - XlaBuilder* builder) { - auto while_cond_fn = [&](absl::Span values, - XlaBuilder* cond_builder) -> StatusOr { - auto k = values[0]; - auto max_sweeps = ScalarLike(k, max_sweep_updates); - auto sweep_update_cond = Gt(max_sweeps, k); - - TF_ASSIGN_OR_RETURN(auto norms, ComputeFrobeniusNorms(values[2])); - auto tol = norms.total_norm * values[3]; - auto tol_cond = ReduceAll(Lt(tol, norms.off_diagonal_norm), - xla::ConstantR0(cond_builder, false), - CreateScalarOrComputation(PRED, cond_builder)); - - return And(sweep_update_cond, tol_cond); - }; - - auto while_body_fn = - [&](absl::Span values, - XlaBuilder* body_builder) -> StatusOr> { - auto while_cond_fn_inner = - [&](absl::Span values_inner, - XlaBuilder* inner_cond_builder) -> StatusOr { - auto p = values_inner[0]; - return Lt(p, ScalarLike(p, matrix_dimension - 1)); - }; - - auto while_body_fn_inner = - [&](absl::Span values_inner, - XlaBuilder* inner_body_builder) -> StatusOr> { - auto while_cond_fn_innermost = - [&](absl::Span values_innermost, - XlaBuilder* innermost_cond_builder) -> StatusOr { - auto q = values_innermost[1]; - return Lt(q, ScalarLike(q, matrix_dimension)); - }; - auto while_body_fn_innermost = - [&](absl::Span values_innermost, - XlaBuilder* innermost_body_builder) - -> StatusOr> { - auto p = values_innermost[0]; - auto q = values_innermost[1]; - - JacobiUpdate jacobi_update; - jacobi_update.v = values_innermost[2]; - jacobi_update.w = values_innermost[3]; - - auto tol = values_innermost[4]; - - TF_ASSIGN_OR_RETURN(jacobi_update, - Update(jacobi_update, p, q, tol, matrix_dimension)); - - std::vector updated_values_innermost; - updated_values_innermost.reserve(values_innermost.size()); - - updated_values_innermost.push_back(p); - updated_values_innermost.push_back(q + ScalarLike(q, 1)); - updated_values_innermost.push_back(jacobi_update.v); - updated_values_innermost.push_back(jacobi_update.w); - updated_values_innermost.push_back(tol); - - return updated_values_innermost; - }; - - std::vector values_innermost(5); - auto p = values_inner[0]; - auto q = p + ScalarLike(p, 1); - values_innermost[0] = p; // index p. - values_innermost[1] = q; // index q. - values_innermost[2] = values_inner[1]; // v. - values_innermost[3] = values_inner[2]; // w. - values_innermost[4] = values_inner[3]; // tol. - TF_ASSIGN_OR_RETURN( - values_innermost, - WhileLoopHelper(while_cond_fn_innermost, while_body_fn_innermost, - values_innermost, absl::StrCat(name, "-Innermost"), - inner_body_builder)); - - std::vector updated_values_inner; - updated_values_inner.reserve(values_inner.size()); - - updated_values_inner.push_back(p + ScalarLike(p, 1)); - updated_values_inner.push_back(values_innermost[2]); - updated_values_inner.push_back(values_innermost[3]); - updated_values_inner.push_back(values_innermost[4]); - return updated_values_inner; - }; - // Indexes. - XlaOp k = values[0]; - - std::vector values_inner(4); - values_inner[0] = ScalarLike(k, 0); // index p. - values_inner[1] = values[1]; // v. - values_inner[2] = values[2]; // w. - values_inner[3] = values[3]; // tol. - TF_ASSIGN_OR_RETURN( - values_inner, - WhileLoopHelper(while_cond_fn_inner, while_body_fn_inner, values_inner, - absl::StrCat(name, "-Inner"), body_builder)); - - std::vector updated_values; - updated_values.reserve(values_inner.size()); - - updated_values.push_back(k + ScalarLike(k, 1)); - updated_values.push_back(values_inner[1]); - updated_values.push_back(values_inner[2]); - updated_values.push_back(values_inner[3]); - - return updated_values; - }; - std::vector values; - TF_ASSIGN_OR_RETURN(values, WhileLoopHelper(while_cond_fn, while_body_fn, - initial_values, name, builder)); - - return values; -} - -StatusOr SortByEigenvalues(SelfAdjointEigResult result) { - XlaBuilder* builder = result.v.builder(); - TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(result.v)); - const int64 num_dims = shape.rank(); - auto dimensions = shape.dimensions(); - - std::vector broadcast_dims(num_dims - 1); - std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0); - broadcast_dims[num_dims - 2] = num_dims - 1; - result.w = BroadcastInDim(result.w, dimensions, broadcast_dims); - - XlaOp sort_result = - Sort({result.w, result.v}, - CreateScalarLtComputation( - {shape.element_type(), shape.element_type()}, builder), - num_dims - 1); - result.w = GetMatrixDiagonal(GetTupleElement(sort_result, 0)); - result.v = GetTupleElement(sort_result, 1); - return result; -} - -} // namespace - -// This is the cyclic Jacobi iteration. Please note that the eigenvalues are -// possibly not ordered. -// -// def jacobi(A): -// n, _ = A.shape -// V = np.eye(n) -// frobenius_norm = np.linalg.norm(A) -// diag_norm = np.linalg.norm(np.diag(A)) -// off_diag_norm = np.sqrt( -// frobenius_norm - diag_norm) * np.sqrt(frobenius_norm + diag_norm) -// while off_diag_norm > 1e-6 * frobenius_norm: -// for p in range(n - 1): -// for q in range(p + 1, n): -// c, s = sym_schur2x2(A, p, q) -// A[[p, q], :] = np.matmul(np.array([[c, -s], [s, c]]), -// A[[p, q], :]) -// A[:, [p, q]] = np.matmul(A[:, [p, q]], -// np.array([[c, s], [-s, c]])) -// V[:, [p, q]] = np.matmul(V[:, [p, q]], -// np.array([[c, s], [-s, c]])) -// frobenius_norm = np.linalg.norm(A) -// diag_norm = np.linalg.norm(np.diag(A)) -// off_diag_norm = np.sqrt( -// frobenius_norm - diag_norm) * np.sqrt( -// frobenius_norm + diag_norm) -// -// return A, V -// -// TODO(kuny): Implement parallel order Jacobi. -// SelfAdjointEigResult SelfAdjointEig(XlaOp a, bool lower, int64 max_iter, - float epsilon) { + float tol) { XlaBuilder* builder = a.builder(); - auto return_error = [&](const Status& status) { - SelfAdjointEigResult result; - result.v = builder->ReportError(status); - result.w = builder->ReportError(status); - return result; - }; - auto shape_with_status = builder->GetShape(a); - if (!shape_with_status.ok()) { - return return_error(shape_with_status.status()); - } - Shape a_shape = shape_with_status.ValueOrDie(); - const int64 num_dims = a_shape.rank(); - if (num_dims < 2) { - return return_error(InvalidArgument( - "Arguments to Eigen decomposition must have rank >= 2: got shape %s.", - a_shape.ToString())); - } - PrimitiveType type = a_shape.element_type(); - if (!primitive_util::IsFloatingPointType(type)) { - return return_error(InvalidArgument( - "Type of the input matrix must be float: got %s.", a_shape.ToString())); - } + XlaOp result = builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + const int64 num_dims = a_shape.rank(); + if (num_dims < 2) { + return InvalidArgument( + "Arguments to Eigen decomposition must have rank >= 2: got shape %s.", + a_shape.ToString()); + } + PrimitiveType type = a_shape.element_type(); + if (!primitive_util::IsFloatingPointType(type) && + !primitive_util::IsComplexType(type)) { + return InvalidArgument( + "Type of the input matrix must be floating point " + "or complex: got %s.", + a_shape.ToString()); + } - const int64 m = ShapeUtil::GetDimension(a_shape, -2); - const int64 n = ShapeUtil::GetDimension(a_shape, -1); + const int64 m = ShapeUtil::GetDimension(a_shape, -2); + const int64 n = ShapeUtil::GetDimension(a_shape, -1); - if (m != n) { - return return_error(InvalidArgument( - "Arguments to Eigen decomposition must be square matrices: got shape " - "(%d, %d).", - m, n)); - } + if (m != n) { + return InvalidArgument( + "Arguments to symmetric eigendecomposition must be square matrices: " + "got shape (%d, %d).", + m, n); + } - const int64 num_batch_dims = num_dims - 2; - std::vector batch_dims(num_batch_dims); - for (int i = 0; i < num_batch_dims; ++i) { - batch_dims[i] = ShapeUtil::GetDimension(a_shape, i); - } + const int num_batch_dims = a_shape.dimensions().size() - 2; + const std::vector batch_dims( + a_shape.dimensions().begin(), + a_shape.dimensions().begin() + num_batch_dims); - auto tol = ScalarLike(a, epsilon); - - auto v_init = Broadcast(IdentityMatrix(builder, type, m, m), batch_dims); - auto w_init = Triangle(a, lower); - w_init = w_init + TransposeInMinorDims(w_init) - w_init * v_init; - - auto output_with_status = WhileLoopFn( - { - Zero(builder, S32), // k - v_init, // v - w_init, // w - tol, // - }, // - n, // - max_iter, // - S32, // - "CyclicJacobi", // - builder); - if (!output_with_status.ok()) { - return return_error(output_with_status.status()); - } - - auto output = output_with_status.ValueOrDie(); - - SelfAdjointEigResult result; - result.v = output[1]; - result.w = GetMatrixDiagonal(output[2]); - - auto result_or = SortByEigenvalues(result); - if (!result_or.ok()) { - return return_error(result_or.status()); - } - return result_or.ValueOrDie(); + PrimitiveType eigvals_type = + primitive_util::IsComplexType(type) + ? primitive_util::ComplexComponentType(type) + : type; + std::vector eigvals_dims = batch_dims; + eigvals_dims.push_back(m); + Shape eigh_shape = ShapeUtil::MakeTupleShape( + {a_shape, ShapeUtil::MakeShape(eigvals_type, eigvals_dims)}); + // TODO(phawkins): upgrade Eigh decomposition to a first-class HLO operator. + std::string opaque = + absl::StrFormat("%d,%d,%f", lower ? 1 : 0, max_iter, tol); + return CustomCall(a.builder(), "Eigh", {a}, eigh_shape, opaque); + }); + return SelfAdjointEigResult{GetTupleElement(result, 0), + GetTupleElement(result, 1)}; } } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/self_adjoint_eig.h b/tensorflow/compiler/xla/client/lib/self_adjoint_eig.h index 954c979ee29..d4ec5663a19 100644 --- a/tensorflow/compiler/xla/client/lib/self_adjoint_eig.h +++ b/tensorflow/compiler/xla/client/lib/self_adjoint_eig.h @@ -33,7 +33,7 @@ struct SelfAdjointEigResult { }; SelfAdjointEigResult SelfAdjointEig(XlaOp a, bool lower = true, - int64 max_iter = 15, float epsilon = 1e-6); + int64 max_iter = 15, float tol = 1e-7); } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc b/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc index ba26701cb7c..99f259130bb 100644 --- a/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc +++ b/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc @@ -15,10 +15,12 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h" +#include "tensorflow/compiler/xla/array.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" @@ -100,48 +102,35 @@ class SelfAdjointEigTest : public ClientLibraryTestBase { return result; } - XlaOp ComputeMatmulVWVt(SelfAdjointEigResult result, XlaBuilder* builder) { - Shape shape = builder->GetShape(result.v).ValueOrDie(); - absl::Span out_dims = shape.dimensions(); - std::vector broadcast_dims(shape.rank() - 1); - std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0); - - broadcast_dims[shape.rank() - 2] = shape.rank() - 1; - auto vw = Mul(result.v, BroadcastInDim(result.w, out_dims, broadcast_dims)); - return BatchDot(vw, TransposeInMinorDims(result.v), - PrecisionConfig::HIGHEST); - } - - XlaOp GetAverageAbsoluteError(XlaOp m1, XlaOp m2, XlaBuilder* builder) { - Shape shape = builder->GetShape(m1).ValueOrDie(); - int64 size = 1; - for (auto d : shape.dimensions()) { - size *= d; - } - return ReduceAll(Abs(m1 - m2), ConstantR0WithType(builder, F32, 0), - CreateScalarAddComputation(F32, builder)) / - ConstantR0WithType(builder, F32, size); - } - - Array2D GenerateRandomSymmetricMatrix(int size) { - Array2D result{size, size, 0.0}; - // TODO(b/128001705): This seed should not be needed but makes the test - // avoid inputs which trigger numerical instability. - result.FillRandom(10 /* stddev */, 2 /* mean */, 12346 /* seed */); - for (int i = 0; i < size; ++i) { - for (int j = 0; j < i; ++j) { - result({j, i}) = result({i, j}); - } - } - return result; - } - Array3D batch_3d_4x4_; Array2D matrix2d_8x8_; Array2D low_rank_4x4_; Array2D wrong_type_4x4_; }; +XlaOp GetAverageAbsoluteError(XlaOp m1, XlaOp m2, XlaBuilder* builder) { + Shape shape = builder->GetShape(m1).ValueOrDie(); + int64 size = ShapeUtil::ElementsIn(shape); + return ReduceAll(Abs(m1 - m2), ConstantR0WithType(builder, F32, 0), + CreateScalarAddComputation(F32, builder)) / + ConstantR0WithType(builder, F32, std::max(1, size)); +} + +XlaOp ComputeMatmulVWVt(SelfAdjointEigResult result, XlaBuilder* builder) { + Shape shape = builder->GetShape(result.v).ValueOrDie(); + absl::Span out_dims = shape.dimensions(); + std::vector broadcast_dims(shape.rank() - 1); + std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0); + + broadcast_dims[shape.rank() - 2] = shape.rank() - 1; + auto vw = + Mul(result.v, + BroadcastInDim(ConvertElementType(result.w, shape.element_type()), + out_dims, broadcast_dims)); + return BatchDot(vw, MaybeConjugate(TransposeInMinorDims(result.v), true), + PrecisionConfig::HIGHEST); +} + XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_2x4x4) { XlaBuilder builder(TestName()); @@ -154,6 +143,22 @@ XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_2x4x4) { ErrorSpec(1e-3, 1e-3)); } +XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_3x3_Complex) { + XlaBuilder builder(TestName()); + Array input = { + {1, complex64{2, -7}, complex64{4, -8}}, + {complex64{2, 7}, 3, complex64{5, -9}}, + {complex64{4, 8}, complex64{5, 9}, 6}, + }; + XlaOp a; + auto a_data = CreateParameter(input, 0, "a", &builder, &a); + auto result = SelfAdjointEig(a); + ComputeMatmulVWVt(result, &builder); + + ComputeAndCompare(&builder, input, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_Lower_2x4x4) { XlaBuilder builder(TestName()); @@ -247,69 +252,43 @@ XLA_TEST_F(SelfAdjointEigTest, Wrong_Type_Int) { EXPECT_FALSE(result.w.valid()); } -XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_8x8) { +Array2D GenerateRandomSymmetricMatrix(int size) { + Array2D result{size, size, 0.0}; + // TODO(b/128001705): This seed should not be needed but makes the test + // avoid inputs which trigger numerical instability. + result.FillRandom(10 /* stddev */, 2 /* mean */, 12346 /* seed */); + for (int i = 0; i < size; ++i) { + for (int j = 0; j < i; ++j) { + result({j, i}) = result({i, j}); + } + } + return result; +} + +using EighTestCase = int64; +class RandomEighTest : public ClientLibraryTestBase, + public ::testing::WithParamInterface {}; + +XLA_TEST_P(RandomEighTest, Random) { XlaBuilder builder(TestName()); - int size = 8; + int64 size = GetParam(); Array2D a_val = GenerateRandomSymmetricMatrix(size); XlaOp a; auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); auto result = SelfAdjointEig(a); GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); - ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, - ErrorSpec(1e-3, 1e-3)); + // TODO(phawkins): this would be better expressed as <= 6e-3. + ComputeAndCompareR0(&builder, 3e-3, {a_data.get()}, + ErrorSpec(3e-3, 0)); } -XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_16x16) { - XlaBuilder builder(TestName()); - int size = 16; - Array2D a_val = GenerateRandomSymmetricMatrix(size); - XlaOp a; - auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); - auto result = SelfAdjointEig(a); - GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); - - ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, - ErrorSpec(1e-3, 1e-3)); -} - -XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_32x32) { - XlaBuilder builder(TestName()); - int size = 32; - Array2D a_val = GenerateRandomSymmetricMatrix(size); - XlaOp a; - auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); - auto result = SelfAdjointEig(a); - GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); - - ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, - ErrorSpec(1e-3, 1e-3)); -} - -XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_256x256) { - XlaBuilder builder(TestName()); - int size = 256; - Array2D a_val = GenerateRandomSymmetricMatrix(size); - XlaOp a; - auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); - auto result = SelfAdjointEig(a); - GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); - - ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, - ErrorSpec(1e-3, 1e-3)); -} - -XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_512x512) { - XlaBuilder builder(TestName()); - int size = 512; - Array2D a_val = GenerateRandomSymmetricMatrix(size); - XlaOp a; - auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); - auto result = SelfAdjointEig(a); - GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); - - ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, - ErrorSpec(1e-3, 1e-3)); -} +INSTANTIATE_TEST_SUITE_P( + RandomEighTestInstantiation, RandomEighTest, + ::testing::Values(0, 1, 2, 3, 8, 16, 32, 256, 512), + [](const ::testing::TestParamInfo& info) { + const int64 size = info.param; + return absl::StrCat(size); + }); } // namespace xla diff --git a/tensorflow/compiler/xla/pjrt/BUILD b/tensorflow/compiler/xla/pjrt/BUILD index a36be414ce1..7cbfe505897 100644 --- a/tensorflow/compiler/xla/pjrt/BUILD +++ b/tensorflow/compiler/xla/pjrt/BUILD @@ -145,6 +145,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_cost_analysis", "//tensorflow/core:lib", + "@com_google_absl//absl/base", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:optional", diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_client.cc index 839d56ab687..436fdff79f0 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.cc @@ -15,8 +15,23 @@ limitations under the License. #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" +#include "absl/base/casts.h" + namespace xla { PjRtBuffer::ExternalReference::~ExternalReference() = default; +StatusOr PjRtClient::UnsafeBufferPointer(PjRtBuffer* buffer) { + if (buffer->on_device_shape().IsTuple()) { + return Unimplemented( + "unsafe_buffer_pointer is not implemented for tuple buffers."); + } + + TF_ASSIGN_OR_RETURN( + std::unique_ptr external_reference_hold, + buffer->AcquireExternalReference()); + const void* ptr = external_reference_hold->OpaqueDeviceMemoryDataPointer(); + return absl::bit_cast(ptr); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.h b/tensorflow/compiler/xla/pjrt/pjrt_client.h index c755efe8866..7c9eab95743 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.h @@ -245,6 +245,10 @@ class PjRtClient { void* device_ptr, const Shape& shape, PjRtDevice* device, std::function on_delete_callback) = 0; + // Returns platform-dependent address for the given buffer that is often but + // not guaranteed to be the physical/device address. + virtual StatusOr UnsafeBufferPointer(PjRtBuffer* buffer); + // Asynchronously makes a vector of PjRtBuffers that can be used to receive // cross host transfers using `client` on `device'. `shapes` must be the exact // shapes, with identical layouts, corresponding to the buffers that will be diff --git a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc index 73a3b93a056..af8dd2d8d97 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc @@ -187,19 +187,6 @@ class CpuAllocator : public tensorflow::Allocator { } }; -static int DefaultThreadPoolSize() { - // Google's CI system exposes an environment variable NPROC that describes - // a CPU reservation for tests. - // TODO(phawkins): expose a better thought-out set of knobs to control - // parallelism. - const char* nproc_str = std::getenv("NPROC"); - int nproc = 0; - if (nproc_str && absl::SimpleAtoi(nproc_str, &nproc)) { - return std::max(0, nproc); - } - return tensorflow::port::MaxParallelism(); -} - PjRtStreamExecutorClient::PjRtStreamExecutorClient( std::string platform_name, LocalClient* client, std::vector> devices, int task_id, diff --git a/tensorflow/compiler/xla/pjrt/utils.cc b/tensorflow/compiler/xla/pjrt/utils.cc index a919c846c11..940f98e13e2 100644 --- a/tensorflow/compiler/xla/pjrt/utils.cc +++ b/tensorflow/compiler/xla/pjrt/utils.cc @@ -260,4 +260,17 @@ StatusOr> GetParametersThatMustBeDonated( return parameters_to_donate; } +int DefaultThreadPoolSize() { + // Google's CI system exposes an environment variable NPROC that describes + // a CPU reservation for tests. + // TODO(phawkins): expose a better thought-out set of knobs to control + // parallelism. + const char* nproc_str = std::getenv("NPROC"); + int nproc = 0; + if (nproc_str && absl::SimpleAtoi(nproc_str, &nproc)) { + return std::max(0, nproc); + } + return tensorflow::port::MaxParallelism(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/pjrt/utils.h b/tensorflow/compiler/xla/pjrt/utils.h index ff4b5ba283f..b84599fd2a0 100644 --- a/tensorflow/compiler/xla/pjrt/utils.h +++ b/tensorflow/compiler/xla/pjrt/utils.h @@ -52,6 +52,9 @@ Status DetermineArgumentLayoutsFromCompileOptions( StatusOr> GetParametersThatMustBeDonated( const HloModule& module, bool tuple_inputs); +// Return max parallelism level. +int DefaultThreadPoolSize(); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_PJRT_UTILS_H_ diff --git a/tensorflow/compiler/xla/python/jax_jit.cc b/tensorflow/compiler/xla/python/jax_jit.cc index 19e17e8feda..8e02cdbf18c 100644 --- a/tensorflow/compiler/xla/python/jax_jit.cc +++ b/tensorflow/compiler/xla/python/jax_jit.cc @@ -110,12 +110,55 @@ void CallSignature::DecRef() const { namespace { -thread_local bool disable_jit; -void SetDisableJit(bool disable_jit_) { disable_jit = disable_jit_; } -bool GetDisableJit() { return disable_jit; } +// These 2 constants are protected by the GIL. +ABSL_CONST_INIT bool disable_jit_flag = false; +ABSL_CONST_INIT bool enable_x64_flag = false; + +ABSL_CONST_INIT thread_local absl::optional disable_jit_thread_local = + absl::nullopt; +ABSL_CONST_INIT thread_local absl::optional jax_enable_x64_thread_local = + absl::nullopt; + +// The x64 mode is controlled by: +// - a global flag value, associated to --jax_enable_x64 +// - possibly a thread-local value, which initially is absl::nullopt and which +// will default to the flag value as long as it's not set. +void SetEnableX64Flag(bool jax_enable_x64) { enable_x64_flag = jax_enable_x64; } +bool GetEnableX64Flag() { return enable_x64_flag; } +void SetEnableX64ThreadLocal(absl::optional jax_enable_x64) { + jax_enable_x64_thread_local = jax_enable_x64; +} +absl::optional GetEnableX64ThreadLocal() { + return jax_enable_x64_thread_local; +} + +void SetDisableJitFlag(bool disable_jit) { disable_jit_flag = disable_jit; } +bool GetDisableJitFlag() { return disable_jit_flag; } +void SetDisableJitThreadLocal(absl::optional disable_jit) { + disable_jit_thread_local = disable_jit; +} +absl::optional GetDisableJitThreadLocal() { + return disable_jit_thread_local; +} + +bool JitIsDisabled() { + if (disable_jit_thread_local != absl::nullopt) { + return disable_jit_thread_local.value(); + } else { + return disable_jit_flag; + } +} } // namespace +bool GetEnableX64() { + if (jax_enable_x64_thread_local != absl::nullopt) { + return jax_enable_x64_thread_local.value(); + } else { + return enable_x64_flag; + } +} + std::string CallSignature::DebugString() const { std::vector static_args_str; static_args_str.reserve(static_args.size()); @@ -841,9 +884,7 @@ struct CacheEntry { class CompiledFunction { public: CompiledFunction(py::function fun, py::function cache_miss, - py::function get_device, py::function get_jax_enable_x64, - py::function get_jax_disable_jit, - std::vector static_argnums); + py::function get_device, std::vector static_argnums); ~CompiledFunction(); // This function will: @@ -869,8 +910,6 @@ class CompiledFunction { CacheEntry* AddCacheEntry(const py::args& args, const py::kwargs& kwargs, const CallSignature& signature, py::object out_and_fastpath_data); - bool JitIsDisabled() { return GetDisableJit() || jax_disable_jit_.value(); } - bool always_fallback_to_python_ = false; const py::function fun_; // The Python function to jit. @@ -883,22 +922,12 @@ class CompiledFunction { // We need a `unique_ptr` here to ensure value pointer stability. absl::flat_hash_map> executables_; - // As top-level functions are decorated with `jax.jit`, when - // `CompiledFunction` is being instantiated from Python, the clients are not - // yet available (done after GoogleInit). They will be during the first call - // to `Call`. // A function taking no arguments and returning the default device and whether // jax.jit has been committed to it. - const py::function get_jax_enable_x64_; - const py::function get_jax_disable_jit_; const py::function get_device_; // The writing of the following is protected by the mutex. absl::Mutex mu_; - // The value of the Python flag. The value will be computed only during the - // first object call, because GoogleInit must have been executed. - absl::optional jax_enable_x64_ = absl::nullopt; - absl::optional jax_disable_jit_ = absl::nullopt; // The logic if the following: // - if `device` or `backend` are not specified to `jax.jit`, we will use @@ -915,14 +944,10 @@ class CompiledFunction { CompiledFunction::CompiledFunction(py::function fun, py::function cache_miss, py::function get_device, - py::function get_jax_enable_x64, - py::function get_jax_disable_jit, std::vector static_argnums) : fun_(std::move(fun)), cache_miss_(std::move(cache_miss)), static_argnums_(std::move(static_argnums)), - get_jax_enable_x64_(get_jax_enable_x64), - get_jax_disable_jit_(get_jax_disable_jit), get_device_(std::move(get_device)) { std::sort(static_argnums_.begin(), static_argnums_.end()); } @@ -1115,8 +1140,6 @@ py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) { absl::MutexLock lock1(&mu_); py::gil_scoped_acquire gil_aquire; - jax_enable_x64_ = py::cast(get_jax_enable_x64_()); - jax_disable_jit_ = py::cast(get_jax_disable_jit_()); if (!default_device_) { py::object device_and_is_committed = get_device_(); try { @@ -1147,10 +1170,12 @@ py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) { return py::cast(cache_miss_(*args, **kwargs))[0]; } + bool jax_enable_x64 = GetEnableX64(); + arguments.signature.jax_enable_x64 = jax_enable_x64; // The C++ jit do not support Tracers arguments inputs yet. The Python-based // jit function will be called if any of the dynamic arguments is unsupported. - if (!ConvertArgsToBuffers(jax_enable_x64_.value(), *default_pyclient_, - default_device_, is_committed_, arguments) + if (!ConvertArgsToBuffers(jax_enable_x64, *default_pyclient_, default_device_, + is_committed_, arguments) .ok()) { return py::cast(cache_miss_(*args, **kwargs))[0]; } @@ -1215,16 +1240,27 @@ void BuildJaxjitSubmodule(pybind11::module& m) { cfun.def_property_readonly("__signature__", &CompiledFunction::PythonSignature); - jitlib.def("set_disable_jit", &SetDisableJit); - jitlib.def("get_disable_jit", &GetDisableJit); + jitlib.def("set_disable_jit_cpp_flag", &SetDisableJitFlag); + jitlib.def("get_disable_jit_cpp_flag", &GetDisableJitFlag); + jitlib.def("set_disable_jit_thread_local", &SetDisableJitThreadLocal); + jitlib.def("get_disable_jit_thread_local", &GetDisableJitThreadLocal); + jitlib.def("jit_is_disabled", &JitIsDisabled); + // TODO(jblespiau): Remove from the Python code and remove this + jitlib.def("set_disable_jit", &SetDisableJitThreadLocal); + jitlib.def("get_disable_jit", &GetDisableJitThreadLocal); + + jitlib.def("set_enable_x64_cpp_flag", &SetEnableX64Flag); + jitlib.def("get_enable_x64_cpp_flag", &GetEnableX64Flag); + jitlib.def("set_enable_x64_thread_local", &SetEnableX64ThreadLocal); + jitlib.def("get_enable_x64_thread_local", &GetEnableX64ThreadLocal); + jitlib.def("get_enable_x64", &GetEnableX64); + jitlib.def( "jit", [](py::function fun, py::function cache_miss, py::function get_device, - py::function get_jax_enable_x64, py::function get_jax_disable_jit, std::vector static_argnums) -> std::unique_ptr { return std::make_unique( std::move(fun), std::move(cache_miss), std::move(get_device), - std::move(get_jax_enable_x64), std::move(get_jax_disable_jit), std::move(static_argnums)); }); diff --git a/tensorflow/compiler/xla/python/jax_jit.h b/tensorflow/compiler/xla/python/jax_jit.h index af493ab6333..7370ce119e9 100644 --- a/tensorflow/compiler/xla/python/jax_jit.h +++ b/tensorflow/compiler/xla/python/jax_jit.h @@ -27,6 +27,10 @@ limitations under the License. namespace jax { +// Returns the value for jax_enable_x64 (defined by a thread-local value if +// defined, defaulting to the value of the flag otherwise). +bool GetEnableX64(); + // Describes the abstract shape and dtype of an argument. struct ArgSignature { ArgSignature(xla::PrimitiveType dtype, absl::Span shape, @@ -86,6 +90,7 @@ struct CallSignature { // arguments (sorted by keyword name). std::vector dynamic_args_signatures; xla::PjRtDevice* device; + bool jax_enable_x64; bool operator==(const CallSignature& other) const; bool operator!=(const CallSignature& other) const { diff --git a/tensorflow/compiler/xla/python/pmap_lib.cc b/tensorflow/compiler/xla/python/pmap_lib.cc index ee8cafd951a..a9492a51c7a 100644 --- a/tensorflow/compiler/xla/python/pmap_lib.cc +++ b/tensorflow/compiler/xla/python/pmap_lib.cc @@ -74,11 +74,10 @@ struct PmapCacheEntry { class PmapFunction { public: PmapFunction(py::function fun, py::function cache_miss, - py::function get_jax_enable_x64, std::vector static_argnums) + std::vector static_argnums) : fun_(std::move(fun)), cache_miss_(std::move(cache_miss)), - static_argnums_(std::move(static_argnums)), - get_jax_enable_x64_(get_jax_enable_x64) { + static_argnums_(std::move(static_argnums)) { std::sort(static_argnums_.begin(), static_argnums_.end()); } @@ -124,9 +123,6 @@ class PmapFunction { absl::flat_hash_map> executables_; - const py::function get_jax_enable_x64_; - absl::optional jax_enable_x64_ = absl::nullopt; - // A vector of size `num_outputs`, specifying the sharding of each output std::vector sharding_specs_; }; @@ -198,10 +194,6 @@ py::object PmapFunction::Call(py::args args, py::kwargs kwargs) { if (always_fallback_to_python_) { return py::cast(cache_miss_(*args, **kwargs))[0]; } - // Delayed values are retrieved on the first call to `Call`. - if (jax_enable_x64_ == absl::nullopt) { - jax_enable_x64_ = py::cast(get_jax_enable_x64_()); - } ParsedArgumentsAsBuffers arguments; if (!ParseArguments(args, kwargs, static_argnums_, arguments).ok()) { @@ -210,7 +202,7 @@ py::object PmapFunction::Call(py::args args, py::kwargs kwargs) { // Get dynamic argument signatures. for (py::handle arg : arguments.flat_dynamic_args) { - auto signature_or_error = ArgSignatureOfValue(arg, jax_enable_x64_.value()); + auto signature_or_error = ArgSignatureOfValue(arg, GetEnableX64()); if (!signature_or_error.ok()) { return py::cast(cache_miss_(*args, **kwargs))[0]; } @@ -367,11 +359,9 @@ void BuildPmapSubmodule(pybind11::module& m) { pmap_lib.def( "pmap", [](py::function fun, py::function cache_miss, - py::function get_jax_enable_x64, std::vector static_argnums) -> std::unique_ptr { return std::make_unique( - std::move(fun), std::move(cache_miss), - std::move(get_jax_enable_x64), std::move(static_argnums)); + std::move(fun), std::move(cache_miss), std::move(static_argnums)); }); } diff --git a/tensorflow/compiler/xla/python/py_buffer.cc b/tensorflow/compiler/xla/python/py_buffer.cc index 2f943420c4d..f34193a8f6d 100644 --- a/tensorflow/compiler/xla/python/py_buffer.cc +++ b/tensorflow/compiler/xla/python/py_buffer.cc @@ -149,19 +149,8 @@ StatusOr PyBuffer::AsNumPyArray(py::handle this_obj) { return array; } -// TODO(zhangqiaorjc): Delete UnsafeBufferPointer. StatusOr PyBuffer::UnsafeBufferPointer() const { - if (buffer_->on_device_shape().IsTuple()) { - return Unimplemented( - "unsafe_buffer_pointer is not implemented for tuple " - "buffers."); - } - - TF_ASSIGN_OR_RETURN( - std::unique_ptr external_reference_hold, - buffer_->AcquireExternalReference()); - const void* ptr = external_reference_hold->OpaqueDeviceMemoryDataPointer(); - return absl::bit_cast(ptr); + return client_->pjrt_client()->UnsafeBufferPointer(buffer_.get()); } StatusOr PyBuffer::CudaArrayInterface() const { diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 01ca81d2882..dbe6579e5c4 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -50,7 +50,7 @@ profiler = _xla.profiler # Just an internal arbitrary increasing number to help with backward-compatible # changes. -_version = 3 +_version = 5 xla_platform_names = { 'cpu': 'Host', diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index bd4bece0c40..950b0559b81 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -556,6 +556,19 @@ def TestFactory(xla_backend, cloud_tpu=False): arr = arr.to_py() self.assertEqual(dtype, type(arr[0])) + def testUnsafeBufferPointer(self): + if not isinstance(self.backend, xla_client.Client): + self.skipTest("TPU Driver doesn't support UnsafeBufferPointer().") + arg0 = np.array([]) + arg1 = np.array([[0., 1., 2.]], np.float32) + arg2 = np.array([[3., 4., 5.]], bfloat16) + arg0_buffer = self.backend.buffer_from_pyval(arg0) + arg1_buffer = self.backend.buffer_from_pyval(arg1) + arg2_buffer = self.backend.buffer_from_pyval(arg2) + self.assertGreaterEqual(arg0_buffer.unsafe_buffer_pointer(), 0) + self.assertGreaterEqual(arg1_buffer.unsafe_buffer_pointer(), 0) + self.assertGreaterEqual(arg2_buffer.unsafe_buffer_pointer(), 0) + tests.append(BufferTest) class SingleOpTest(ComputationTest): diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 7f8cdc994de..7d1ed07b9c1 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1971,6 +1971,30 @@ cc_library( ], ) +cc_library( + name = "eigh_expander", + srcs = ["eigh_expander.cc"], + hdrs = ["eigh_expander.h"], + deps = [ + ":op_expander_pass", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/lib:comparators", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:loops", + "//tensorflow/compiler/xla/client/lib:math", + "//tensorflow/compiler/xla/client/lib:matrix", + "//tensorflow/compiler/xla/client/lib:slicing", + "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + ], +) + cc_library( name = "convolution_4d_expander", srcs = ["convolution_4d_expander.cc"], diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index c3d8df85b6c..9622e176e68 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -345,11 +345,8 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, return true; } -namespace { - -// Returns whether we should avoid changing the precision of inst regardless of -// the producers and users. -bool ShouldKeepPrecisionUnchanged(const HloInstruction* inst) { +bool BFloat16Propagation::ShouldKeepPrecisionUnchanged( + const HloInstruction* inst) { if (inst->opcode() == HloOpcode::kFusion && inst->fusion_kind() == HloInstruction::FusionKind::kCustom) { return ShouldKeepPrecisionUnchanged( @@ -358,14 +355,12 @@ bool ShouldKeepPrecisionUnchanged(const HloInstruction* inst) { // Do not change precision for side-effecting instructions, control flow, and // bitcast-convert, because this pass might break the interfaces or // assumptions for them. - return inst->opcode() == HloOpcode::kCustomCall || // - inst->opcode() == HloOpcode::kCall || // - inst->opcode() == HloOpcode::kBitcastConvert || // + return inst->opcode() == HloOpcode::kCustomCall || + inst->opcode() == HloOpcode::kCall || + inst->opcode() == HloOpcode::kBitcastConvert || inst->HasSideEffectNoRecurse(); } -} // namespace - void BFloat16Propagation::DetermineInstructionPrecision(HloInstruction* hlo, bool skip_parameters) { // We handle any fusion computation, while body/condition or conditional diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.h b/tensorflow/compiler/xla/service/bfloat16_propagation.h index 200599efab2..168649a10bd 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.h +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.h @@ -72,6 +72,10 @@ class BFloat16Propagation : public HloModulePass { // (precision reductions were added). StatusOr Run(HloModule* module) override; + // Returns whether we should avoid changing the precision of inst regardless + // of the producers and users. + virtual bool ShouldKeepPrecisionUnchanged(const HloInstruction* inst); + private: // *************************** // Function called and state produced by the forward analysis pass (from diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index 2f6bc8e2e9e..623b8262178 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -77,7 +77,6 @@ class AotCompilationOptions { virtual int64 replica_count() const { return 0; } virtual int64 num_cores() const { return 0; } - virtual bool broadcast_replicated_params() const { return false; } virtual bool use_spmd_partitioning() const { return false; } virtual bool deduplicate_hlo() const { return false; } diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 32258433974..60b7a66bf19 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -186,6 +186,7 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:call_inliner", "//tensorflow/compiler/xla/service:cholesky_expander", + "//tensorflow/compiler/xla/service:eigh_expander", "//tensorflow/compiler/xla/service:qr_expander", "//tensorflow/compiler/xla/service:conditional_simplifier", "//tensorflow/compiler/xla/service:convolution_group_converter", diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 92b7a3b70be..907c61aed71 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -86,6 +86,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/dump.h" #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" #include "tensorflow/compiler/xla/service/dynamic_padder.h" +#include "tensorflow/compiler/xla/service/eigh_expander.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/gather_expander.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" @@ -305,6 +306,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); + pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); diff --git a/tensorflow/compiler/xla/service/eigh_expander.cc b/tensorflow/compiler/xla/service/eigh_expander.cc new file mode 100644 index 00000000000..d3db58588d6 --- /dev/null +++ b/tensorflow/compiler/xla/service/eigh_expander.cc @@ -0,0 +1,655 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/eigh_expander.h" + +#include +#include + +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/comparators.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" + +// Parallel two-sided Jacobi symmetric eigendecomposition. +// +// The implementation follows the approach described in: +// Brent, Richard P., and Franklin T. Luk. "The solution of singular-value and +// symmetric eigenvalue problems on multiprocessor arrays." SIAM Journal on +// Scientific and Statistical Computing 6.1 (1985): 69-84. +// +// Where the Brent/Luk paper uses "processors", we use "vector elements". +namespace xla { + +namespace { + +// A 2x2 symmetric Eigendecomposition of a matrix A. +// If +// G = [[ c, s], +// [-s, c]] +// matmul(G_T, G) = I +// and +// G @ [[rt1, 0 ], @ G.T = A +// [ 0, rt2]] +struct Eigh2x2 { + // Eigenvalues + XlaOp rt1; + XlaOp rt2; + // First row of Eigenvector matrix. + XlaOp c; // cosine. + XlaOp s; // sine. +}; + +// sqrt(x**2 + y**2), calculated avoiding overflow. +XlaOp Hypot(XlaOp x, XlaOp y) { + x = Abs(x); + y = Abs(y); + auto xy_min = Min(x, y); + auto xy_max = Max(x, y); + auto out = xy_max * Sqrt(ScalarLike(x, 1) + Square(xy_min / xy_max)); + return Select(Eq(xy_min, xy_max), xy_min * ScalarLike(xy_min, std::sqrt(2.)), + out); +} + +// Given an n-by-n symmetric A and integers p and q that satisfy 0 <= p < q < n, +// a Jacobi rotation computes a rotation matrix G = [[c, s], [-s, c]], such that +// G_T * A[[p, q], [p, q]] * G +// is diagonalized. We do this by computing a 2x2 eigendecomposition. +// +// In this parallel Jacobi algorithm, we simultaneously compute Jacobi rotations +// for all of the matrix diagonal elements at the same time. The matrix diagonal +// elements correspond to different rows and columns of the original matrix and +// their rotations do not interfere and hence can be computed in parallel. +// +// The algorithm is based on slaev2/claev2 from LAPACK, modified to allow for +// vectorization. +// In addition, slaev2 always returns the largest eigenvalue as rt1, which has +// the effect of swapping eigenvalues around in the Jacob algorithm. This does +// not converge when used in a parallel Jacobi algorithm, so we modify the +// algorithm to maintain the following symmetry property: +// slaev2(a, b, c) has the opposite Eigenvalue order from slaev2(c, b, a) + +// def symmetric_eigendecomposition_2x2(a, b, c): +// # Input matrix [[a, b], [b, c]]. +// ac_sum = a + c +// ac_diff = a - c +// two_b = 2*b +// +// rt = hypot(ac_diff, two_b) +// +// which_max_abs = np.abs(a) > np.abs(c) +// ac_max = np.where(which_max_abs, a, c) +// ac_min = np.where(which_max_abs, c, a) +// rt1 = np.float32(0.5)*(ac_sum + np.where(ac_sum < 0, -rt, rt)) +// rt2 = np.where(ac_sum != 0, (ac_max / rt1)*ac_min - (b/rt1)*b, +// -np.float32(0.5)*rt) +// +// +// # Modification: don't sort the Eigenvalues. +// rt1, rt2 = (np.where(which_max_abs, rt1, rt2), +// np.where(which_max_abs, rt2, rt1)) +// +// # Compute eigenvectors +// cs = ac_diff + np.where(ac_diff >= 0, rt, -rt) +// +// ct = -two_b / cs +// tn = -cs / two_b +// +// cosine = np.where(two_b != 0, np.float32(1) / np.sqrt(1 + tn*tn), +// np.float32(1)) +// sine = np.where(two_b != 0, tn * cosine, np.float32(0)) +// +// tmp = 1 / np.sqrt(1 + ct*ct) +// cosine = np.where(np.abs(cs) > np.abs(two_b), ct*tmp, cosine) +// sine = np.where(np.abs(cs) > np.abs(two_b), tmp, sine) +// same_sign = (ac_sum >= 0) == (ac_diff >= 0) +// # Modification: use Eigenvalues corresponding to the Eigenvectors above. +// same_sign = (same_sign == which_max_abs) +// cosine, sine = (np.where(same_sign, -sine, cosine), +// np.where(same_sign, cosine, sine)) +// return rt1, rt2, cosine, sine +StatusOr HermitianEigenDecomposition2x2(XlaOp w_tl, XlaOp w_tr, + XlaOp w_br) { + TF_ASSIGN_OR_RETURN(Shape w_tl_shape, w_tl.builder()->GetShape(w_tl)); + bool is_complex = primitive_util::IsComplexType(w_tl_shape.element_type()); + + auto a = GetMatrixDiagonal(Real(w_tl)); + auto b = GetMatrixDiagonal(w_tr); + auto abs_b = Abs(b); + + XlaOp w; + if (is_complex) { + w = Select(Eq(abs_b, ZerosLike(abs_b)), FullLike(b, 1), + Conj(b) / Complex(abs_b, ZerosLike(abs_b))); + b = abs_b; + } + + auto c = GetMatrixDiagonal(Real(w_br)); + auto zero = ScalarLike(a, 0.0); + auto half = ScalarLike(a, 0.5); + auto neg_half = ScalarLike(a, -0.5); + auto one = ScalarLike(a, 1.0); + auto two = ScalarLike(a, 2.0); + + auto ac_sum = a + c; + auto ac_diff = a - c; + auto two_b = two * b; + auto rt = Hypot(ac_diff, two_b); + + // Compute eigenvalues + auto which_max_abs = Gt(Abs(a), Abs(c)); + auto ac_max = Select(which_max_abs, a, c); + auto ac_min = Select(which_max_abs, c, a); + auto rt1 = half * (ac_sum + Select(Lt(ac_sum, zero), -rt, rt)); + auto rt2 = Select(Ne(ac_sum, zero), (ac_max / rt1) * ac_min - (b / rt1) * b, + neg_half * rt); + std::tie(rt1, rt2) = std::make_tuple(Select(which_max_abs, rt1, rt2), + Select(which_max_abs, rt2, rt1)); + + // Compute eigenvectors + auto cs = ac_diff + Select(Ge(ac_diff, zero), rt, -rt); + auto ct = -two_b / cs; + auto tn = -cs / two_b; + + auto cosine = Select(Ne(two_b, zero), Rsqrt(one + Square(tn)), one); + auto sine = Select(Ne(two_b, zero), tn * cosine, zero); + + auto tmp = Rsqrt(one + Square(ct)); + auto abs_cs_larger = Gt(Abs(cs), Abs(two_b)); + cosine = Select(abs_cs_larger, ct * tmp, cosine); + sine = Select(abs_cs_larger, tmp, sine); + auto same_sign = Eq(Ge(ac_sum, zero), Ge(ac_diff, zero)); + same_sign = Eq(same_sign, which_max_abs); + std::tie(cosine, sine) = std::make_tuple(Select(same_sign, -sine, cosine), + Select(same_sign, cosine, sine)); + + // Negate 'sine' because we are returning the first row of the rotation matrix + // not the first eigenvector. + if (is_complex) { + rt1 = Complex(rt1, ZerosLike(rt1)); + rt2 = Complex(rt2, ZerosLike(rt2)); + cosine = Complex(cosine, ZerosLike(cosine)); + sine = Complex(sine, ZerosLike(sine)) * w; + } + return Eigh2x2{rt1, rt2, cosine, -sine}; +} + +// tl, tr, bl, br = ( +// tl * c[:, None] - bl * s[:, None], +// tr * c[:, None] - br * s[:, None], +// tl * s[:, None] + bl * c[:, None], +// tr * s[:, None] + br * c[:, None], +// ) +void ApplyJacobiRotationOverRows(Eigh2x2 rotation, XlaOp& tl, XlaOp& tr, + XlaOp& bl, XlaOp& br) { + Shape shape = tl.builder()->GetShape(tl).ValueOrDie(); + std::vector broadcast_dims(shape.dimensions().size() - 1); + absl::c_iota(broadcast_dims, 0); + auto c = BroadcastInDim(rotation.c, shape.dimensions(), broadcast_dims); + auto s = BroadcastInDim(rotation.s, shape.dimensions(), broadcast_dims); + + auto s_conj = MaybeConjugate(s, true); + std::tie(tl, tr, bl, br) = + std::make_tuple(tl * c - bl * s_conj, tr * c - br * s_conj, + tl * s + bl * c, tr * s + br * c); +} + +// tl, tr, bl, br = ( +// tl * c[None, :] - tr * s[None, :], +// tl * s[None, :] + tr * c[None, :], +// bl * c[None, :] - br * s[None, :], +// bl * s[None, :] + br * c[None, :], +// ) +void ApplyJacobiRotationOverCols(Eigh2x2 rotation, XlaOp& tl, XlaOp& tr, + XlaOp& bl, XlaOp& br) { + Shape shape = tl.builder()->GetShape(tl).ValueOrDie(); + std::vector broadcast_dims(shape.dimensions().size() - 1); + absl::c_iota(broadcast_dims, 0); + broadcast_dims.back() = shape.dimensions().size() - 1; + auto c = BroadcastInDim(rotation.c, shape.dimensions(), broadcast_dims); + auto s = BroadcastInDim(rotation.s, shape.dimensions(), broadcast_dims); + + auto s_conj = MaybeConjugate(s, true); + std::tie(tl, tr, bl, br) = + std::make_tuple(tl * c - tr * s, tl * s_conj + tr * c, bl * c - br * s, + bl * s_conj + br * c); +} + +// def permute_rows_in_col(top, bottom): +// top_out = np.zeros_like(l) +// top_out[0] = top[0] +// top_out[1] = bottom[0] +// top_out[2:] = top[1:-1] +// bottom_out = np.zeros_like(r) +// bottom_out[:-1] = bottom[1:] +// bottom_out[-1] = top[-1] +// return top_out, bottom_out +void PermuteRowsInColumn(XlaOp& top, XlaOp& bottom) { + XlaBuilder* builder = top.builder(); + Shape shape = builder->GetShape(top).ValueOrDie(); + int64 k = ShapeUtil::GetDimension(shape, -1); + if (k <= 1) { + return; + } + int ndim = shape.dimensions_size(); + std::tie(top, bottom) = + std::make_tuple(ConcatInDim(builder, + {SliceInMinorDims(top, {0, 0}, {1, k}), + SliceInMinorDims(bottom, {0, 0}, {1, k}), + SliceInMinorDims(top, {1, 0}, {k - 1, k})}, + ndim - 2), + ConcatInDim(builder, + {SliceInMinorDims(bottom, {1, 0}, {k, k}), + SliceInMinorDims(top, {k - 1, 0}, {k, k})}, + ndim - 2)); +} + +void PermuteColumnsInRow(XlaOp& left, XlaOp& right) { + XlaBuilder* builder = left.builder(); + Shape shape = builder->GetShape(left).ValueOrDie(); + int64 k = ShapeUtil::GetDimension(shape, -1); + if (k <= 1) { + return; + } + int ndim = shape.dimensions_size(); + std::tie(left, right) = + std::make_tuple(ConcatInDim(builder, + {SliceInMinorDims(left, {0}, {1}), + SliceInMinorDims(right, {0}, {1}), + SliceInMinorDims(left, {1}, {k - 1})}, + ndim - 1), + ConcatInDim(builder, + {SliceInMinorDims(right, {1}, {k}), + SliceInMinorDims(left, {k - 1}, {k})}, + ndim - 1)); +} + +// Performs one round of parallel Jacobi rotations; n-1 rounds make a sweep. +// After each rotation, we permute the rows and columns of the quadrants of the +// matrix. The effect of the permutations is that all pairs of rows end up +// on the diagonal of the quadrants after n-1 rounds. The permutations are an +// implicit way of computing a tournament for n players such that each player +// plays every other player exactly once in n - 1 rounds. See the Brent/Luk +// paper for more details. +Status ApplyRotations(int64 n, XlaOp& w_tl, XlaOp& w_tr, XlaOp& w_bl, + XlaOp& w_br, XlaOp& v_tl, XlaOp& v_tr, XlaOp& v_bl, + XlaOp& v_br) { + TF_ASSIGN_OR_RETURN(Eigh2x2 rotation, + HermitianEigenDecomposition2x2(w_tl, w_tr, w_br)); + + ApplyJacobiRotationOverRows(rotation, w_tl, w_tr, w_bl, w_br); + ApplyJacobiRotationOverCols(rotation, w_tl, w_tr, w_bl, w_br); + w_tl = SetMatrixDiagonal(w_tl, rotation.rt1); + w_tr = SetMatrixDiagonal(w_tr, ZerosLike(rotation.rt1)); + w_bl = SetMatrixDiagonal(w_bl, ZerosLike(rotation.rt1)); + w_br = SetMatrixDiagonal(w_br, rotation.rt2); + + PermuteColumnsInRow(w_tl, w_tr); + PermuteColumnsInRow(w_bl, w_br); + PermuteRowsInColumn(w_tl, w_bl); + PermuteRowsInColumn(w_tr, w_br); + + // Apply the rotations to the eigenvector matrix. + // TODO(phawkins): we could omit this if we aren't interested in computing the + // eigenvectors. + ApplyJacobiRotationOverRows(rotation, v_tl, v_tr, v_bl, v_br); + PermuteRowsInColumn(v_tl, v_bl); + PermuteRowsInColumn(v_tr, v_br); + return Status::OK(); +} + +struct FrobeniusNorms { + XlaOp off_diagonal_norm; + XlaOp total_norm; +}; + +StatusOr ComputeFrobeniusNorms(XlaOp w_tl, XlaOp w_tr, + XlaOp w_bl, XlaOp w_br) { + XlaBuilder* builder = w_tl.builder(); + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(w_tl)); + const int64 num_dims = shape.rank(); + auto square_norm = [](XlaOp x) -> XlaOp { + return Real(x * MaybeConjugate(x, true)); + }; + PrimitiveType norm_type = + primitive_util::IsComplexType(shape.element_type()) + ? primitive_util::ComplexComponentType(shape.element_type()) + : shape.element_type(); + auto zero = ScalarLike(Real(w_tl), 0.0); + auto frobenius_norm = + Sqrt(Reduce(square_norm(w_tl) + square_norm(w_tr) + square_norm(w_bl) + + square_norm(w_br), + zero, CreateScalarAddComputation(norm_type, builder), + {num_dims - 2, num_dims - 1})); + auto diag_square = Reduce( + Square(GetMatrixDiagonal(Real(w_tl))) + + Square(GetMatrixDiagonal(Real(w_br))), + zero, CreateScalarAddComputation(norm_type, builder), {num_dims - 2}); + + FrobeniusNorms frobenius_norms; + + frobenius_norms.off_diagonal_norm = + Sqrt(Max(Square(frobenius_norm) - diag_square, zero)); + frobenius_norms.total_norm = frobenius_norm; + + return frobenius_norms; +} + +StatusOr> Sweeps(absl::Span initial_values, + int64 n, int max_iters, + PrimitiveType index_type, + XlaBuilder* builder) { + auto while_cond_fn = [&](absl::Span values, + XlaBuilder* cond_builder) -> StatusOr { + auto iter_cond = Lt(values[0], ScalarLike(values[0], max_iters)); + + XlaOp w_tl, w_tr, w_bl, w_br; + std::tie(w_tl, w_tr, w_bl, w_br) = + std::make_tuple(values[2], values[3], values[4], values[5]); + TF_ASSIGN_OR_RETURN(auto norms, + ComputeFrobeniusNorms(w_tl, w_tr, w_bl, w_br)); + auto tol = norms.total_norm * values[1]; + auto tol_cond = ReduceAll(Lt(tol, norms.off_diagonal_norm), + xla::ConstantR0(cond_builder, false), + CreateScalarOrComputation(PRED, cond_builder)); + + return And(iter_cond, tol_cond); + }; + + auto while_body_fn = + [&](absl::Span values, + XlaBuilder* body_builder) -> StatusOr> { + std::vector sweep_values(values.begin() + 1, values.end()); + TF_ASSIGN_OR_RETURN( + sweep_values, + ForEachIndex( + n - 1, S32, + [&](XlaOp iter, absl::Span values, + XlaBuilder* builder) -> StatusOr> { + XlaOp tol, w_tl, w_tr, w_bl, w_br, v_tl, v_tr, v_bl, v_br; + std::tie(tol, w_tl, w_tr, w_bl, w_br, v_tl, v_tr, v_bl, v_br) = + std::make_tuple(values[0], values[1], values[2], values[3], + values[4], values[5], values[6], values[7], + values[8]); + TF_RETURN_IF_ERROR(ApplyRotations(n, w_tl, w_tr, w_bl, w_br, v_tl, + v_tr, v_bl, v_br)); + return std::vector{tol, w_tl, w_tr, w_bl, w_br, + v_tl, v_tr, v_bl, v_br}; + }, + sweep_values, "ApplyRotations", body_builder)); + std::vector output(values.size()); + output[0] = values[0] + ScalarLike(values[0], 1); + std::copy(sweep_values.begin(), sweep_values.end(), output.begin() + 1); + return output; + }; + return WhileLoopHelper(while_cond_fn, while_body_fn, initial_values, + "EighJacobiSweeps", builder); +} + +StatusOr> SortByEigenvalues(XlaOp v, XlaOp w) { + XlaBuilder* builder = v.builder(); + TF_ASSIGN_OR_RETURN(Shape v_shape, builder->GetShape(v)); + TF_ASSIGN_OR_RETURN(Shape w_shape, builder->GetShape(w)); + const int64 num_dims = v_shape.rank(); + auto dimensions = v_shape.dimensions(); + + std::vector broadcast_dims(num_dims - 1); + std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0); + broadcast_dims[num_dims - 2] = num_dims - 1; + w = BroadcastInDim(w, dimensions, broadcast_dims); + + XlaOp sort_result = + Sort({w, v}, + CreateScalarLtComputation( + {w_shape.element_type(), v_shape.element_type()}, builder), + num_dims - 1); + w = GetMatrixDiagonal(GetTupleElement(sort_result, 0)); + v = GetTupleElement(sort_result, 1); + return std::make_pair(v, w); +} + +} // namespace + +// This is the cyclic Jacobi iteration. +// +// def jacobi(A): +// n, _ = A.shape +// tl = A[:n // 2, :n // 2] +// bl = A[n // 2:, :n // 2] +// tr = A[:n // 2, n // 2:] +// br = A[n // 2:, n // 2:] +// v_tl = np.eye(n // 2, dtype=A.dtype) +// v_tr = np.zeros((n // 2, n // 2), A.dtype) +// v_bl = np.zeros((n // 2, n // 2), A.dtype) +// v_br = np.eye(n // 2, dtype=A.dtype) +// frobenius_norm = np.sqrt(np.sum(np.square(tl) + np.square(tr) + +// np.square(bl) + np.square(br))) +// diag_norm = np.sqrt(np.sum(np.square(np.diag(tl)) + +// np.square(np.diag(br)))) +// off_diag_norm = np.sqrt(frobenius_norm - diag_norm) * np.sqrt( +// frobenius_norm + diag_norm) +// while off_diag_norm > 1e-6 * frobenius_norm: +// for i in range(n - 1): +// c, s = sym_schur2x2(tl, tr, br) +// tl, tr, bl, br = ( +// tl * c[:, None] - bl * s[:, None], +// tr * c[:, None] - br * s[:, None], +// tl * s[:, None] + bl * c[:, None], +// tr * s[:, None] + br * c[:, None], +// ) +// tl, tr, bl, br = ( +// tl * c[None, :] - tr * s[None, :], +// tl * s[None, :] + tr * c[None, :], +// bl * c[None, :] - br * s[None, :], +// bl * s[None, :] + br * c[None, :], +// ) +// tl, bl = permute_rows_in_col(tl, bl) +// tr, br = permute_rows_in_col(tr, br) +// tl, tr = permute_cols_in_row(tl, tr) +// bl, br = permute_cols_in_row(bl, br) +// v_tl, v_tr, v_bl, v_br = ( +// v_tl * c[:, None] - v_bl * s[:, None], +// v_tr * c[:, None] - v_br * s[:, None], +// v_tl * s[:, None] + v_bl * c[:, None], +// v_tr * s[:, None] + v_br * c[:, None], +// ) +// v_tl, v_bl = permute_rovs_in_col(v_tl, v_bl) +// v_tr, v_br = permute_rovs_in_col(v_tr, v_br) +// +// frobenius_norm = np.sqrt(np.sum(np.square(tl) + np.square(tr) + +// np.square(bl) + np.square(br))) +// diag_norm = np.sqrt(np.sum(np.square(np.diag(tl)) + +// np.square(np.diag(br)))) +// off_diag_norm = np.sqrt(frobenius_norm - diag_norm) * np.sqrt( +// frobenius_norm + diag_norm) +// return A, V +XlaOp EighExpander::BuildEigh(XlaOp a, bool lower, int64 max_iter, float tol) { + XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + const int64 num_dims = a_shape.rank(); + if (num_dims < 2) { + return InvalidArgument( + "Arguments to Eigen decomposition must have rank >= 2: got shape %s.", + a_shape.ToString()); + } + PrimitiveType type = a_shape.element_type(); + if (!primitive_util::IsFloatingPointType(type) && + !primitive_util::IsComplexType(type)) { + return InvalidArgument( + "Type of the input matrix must be floating point " + "or complex: got %s.", + a_shape.ToString()); + } + + const int64 m = ShapeUtil::GetDimension(a_shape, -2); + const int64 n = ShapeUtil::GetDimension(a_shape, -1); + + if (m != n) { + return InvalidArgument( + "Arguments to symmetric eigendecomposition must be square matrices: " + "got shape (%d, %d).", + m, n); + } + + const int64 num_batch_dims = num_dims - 2; + std::vector batch_dims(num_batch_dims); + for (int i = 0; i < num_batch_dims; ++i) { + batch_dims[i] = ShapeUtil::GetDimension(a_shape, i); + } + + if (m <= 1) { + return Tuple(builder, {FullLike(a, 1), GetMatrixDiagonal(Real(a))}); + } + + a = Symmetrize(a, lower); + + const int64 k = CeilOfRatio(n, int64{2}); + // tl = A[:n // 2, :n // 2] + // bl = A[n // 2:, :n // 2] + // tr = A[:n // 2, n // 2:] + // br = A[n // 2:, n // 2:] + auto tl = SliceInMinorDims(a, {0, 0}, {k, k}); + auto bl = SliceInMinorDims(a, {k, 0}, {n, k}); + auto tr = SliceInMinorDims(a, {0, k}, {k, n}); + auto br = SliceInMinorDims(a, {k, k}, {n, n}); + if (n % 2) { + auto zero = Zero(builder, type); + tr = PadInDim(tr, zero, num_dims - 1, /*pad_lo=*/0, /*pad_hi=*/1); + bl = PadInDim(bl, zero, num_dims - 2, /*pad_lo=*/0, /*pad_hi=*/1); + PaddingConfig config = MakeNoPaddingConfig(num_dims); + config.mutable_dimensions(num_dims - 2)->set_edge_padding_high(1); + config.mutable_dimensions(num_dims - 1)->set_edge_padding_high(1); + br = Pad(br, zero, config); + } + // v_tl = np.eye(n // 2, dtype=A.dtype) + // v_tr = np.zeros((n // 2, n // 2), A.dtype) + // v_bl = np.zeros((n // 2, n // 2), A.dtype) + // v_br = np.eye(n // 2, dtype=A.dtype) + auto v_tl = Broadcast(IdentityMatrix(builder, type, k, k), batch_dims); + auto v_br = v_tl; + auto v_tr = ZerosLike(v_tl); + auto v_bl = v_tr; + + TF_ASSIGN_OR_RETURN(auto output, Sweeps( + { + Zero(builder, S32), + ScalarLike(Real(a), tol), + tl, + tr, + bl, + br, + v_tl, + v_tr, + v_bl, + v_br, + }, + k * 2, max_iter, S32, builder)); + + std::tie(tl, tr, bl, br) = + std::make_tuple(output[2], output[3], output[4], output[5]); + std::tie(v_tl, v_tr, v_bl, v_br) = + std::make_tuple(output[6], output[7], output[8], output[9]); + + auto w = ConcatInDim( + builder, {GetMatrixDiagonal(Real(tl)), GetMatrixDiagonal(Real(br))}, + num_dims - 2); + auto v = ConcatInDim(builder, + {ConcatInDim(builder, {v_tl, v_tr}, num_dims - 1), + ConcatInDim(builder, {v_bl, v_br}, num_dims - 1)}, + num_dims - 2); + if (n % 2) { + w = SliceInMinorDims(w, {0}, {n}); + v = SliceInMinorDims(v, {0, 0}, {n, n}); + } + v = MaybeConjugate(TransposeInMinorDims(v), true); + + TF_ASSIGN_OR_RETURN(std::tie(v, w), SortByEigenvalues(v, w)); + return Tuple(builder, {v, w}); + }); +} + +static const char* kEighCustomCallName = "Eigh"; + +bool EighExpander::InstructionMatchesPattern(HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kCustomCall && + instruction->custom_call_target() == kEighCustomCallName; +} + +StatusOr EighExpander::ExpandInstruction( + HloInstruction* instruction) { + const string name = + absl::StrFormat("xla.%s_%s", instruction->custom_call_target(), + instruction->operand(0)->shape().ToString()); + + HloModule* module = instruction->parent()->parent(); + + HloComputation*& computation = + computation_cache_.emplace(name, nullptr).first->second; + if (!computation) { + // Builds a new expansion. + // + // TODO(b/62327888): We do something unusual here: we build the computation + // using the XlaBuilder API, which is nominally an XLA client API. We do + // this because the external APIs for building complicated computations + // (XlaBuilder) are much more ergonomic than the internal ones. As it turns + // out, XlaBuilder isn't really a client API—what it does is build a + // HloModuleProto protocol buffer, that we can then deserialize and clone + // into our HloModule. Ideally we would avoid the protocol buffer step; + // that is left as an exercise for future work. + XlaBuilder builder(name); + TF_RET_CHECK(instruction->operand_count() == 1); + XlaOp a = Parameter(&builder, 0, instruction->operand(0)->shape(), "a"); + + std::vector config_strs = + absl::StrSplit(instruction->raw_backend_config_string(), ','); + int lower; + int64 max_iter; + float tol; + if (config_strs.size() != 3 || !absl::SimpleAtoi(config_strs[0], &lower) || + !absl::SimpleAtoi(config_strs[1], &max_iter) || + !absl::SimpleAtof(config_strs[2], &tol)) { + return Internal("Unable to parse arguments to Eigh custom call, got: %s", + instruction->raw_backend_config_string()); + } + XlaOp result = BuildEigh(a, lower, max_iter, tol); + TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build(result)); + + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, + xla_computation.GetProgramShape()); + HloModuleConfig config(program_shape); + TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto( + xla_computation.proto(), config)); + HloCloneContext context(module); + computation = + module->DeepCloneComputation(new_module->entry_computation(), &context); + } + + return instruction->parent()->AddInstruction(HloInstruction::CreateCall( + instruction->shape(), instruction->operands(), computation)); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/eigh_expander.h b/tensorflow/compiler/xla/service/eigh_expander.h new file mode 100644 index 00000000000..ec282e78dbb --- /dev/null +++ b/tensorflow/compiler/xla/service/eigh_expander.h @@ -0,0 +1,44 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_EIGH_EXPANDER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_EIGH_EXPANDER_H_ + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/service/op_expander_pass.h" + +namespace xla { + +class EighExpander : public OpExpanderPass { + public: + absl::string_view name() const override { return "eigh_expander"; } + + protected: + bool InstructionMatchesPattern(HloInstruction* instruction) override; + + StatusOr ExpandInstruction( + HloInstruction* instruction) override; + + virtual XlaOp BuildEigh(XlaOp a, bool lower, int64 max_iter, float tol); + + private: + // Mapping from op signatures to existing computations. + absl::flat_hash_map computation_cache_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_EIGH_EXPANDER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index b580128221a..744ef8e7a9f 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -1290,6 +1290,7 @@ cc_library( "//tensorflow/compiler/xla/service:dump", "//tensorflow/compiler/xla/service:dynamic_index_splitter", "//tensorflow/compiler/xla/service:dynamic_padder", + "//tensorflow/compiler/xla/service:eigh_expander", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", "//tensorflow/compiler/xla/service:gather_expander", diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 1a265e25690..49594124cbf 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -52,6 +52,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/dump.h" #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" #include "tensorflow/compiler/xla/service/dynamic_padder.h" +#include "tensorflow/compiler/xla/service/eigh_expander.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/gather_expander.h" #include "tensorflow/compiler/xla/service/gpu/alias_passthrough_params.h" @@ -191,8 +192,10 @@ Status GpuCompiler::OptimizeHloModule( pipeline.AddPass(); pipeline.AddPass(); - // TODO(phawkins): replace QR decompositions with calls to cuSOLVER. + // TODO(phawkins): replace QR and Eigh decompositions with calls to + // cuSOLVER. pipeline.AddPass(); + pipeline.AddPass(); pipeline.AddPass(); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc index 0368d1c679a..5640d44750f 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc @@ -747,12 +747,19 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm( absl::Milliseconds(profile_result.elapsed_time_in_ms())); } } - const auto& best_result = absl::c_min_element( - profile_results, - [&](const AutotuneResult& lhs, const AutotuneResult& rhs) { - return tensorflow::proto_utils::FromDurationProto(lhs.run_time()) < - tensorflow::proto_utils::FromDurationProto(rhs.run_time()); - }); + auto best_result = profile_results.begin(); + if (!RequireCudnnDeterminism() && !instr->parent() + ->parent() + ->config() + .debug_options() + .xla_gpu_deterministic_ops()) { + best_result = absl::c_min_element( + profile_results, + [&](const AutotuneResult& lhs, const AutotuneResult& rhs) { + return tensorflow::proto_utils::FromDurationProto(lhs.run_time()) < + tensorflow::proto_utils::FromDurationProto(rhs.run_time()); + }); + } if (best_result != profile_results.end()) { return *best_result; diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index 1167f6df140..8d02df045d6 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -133,11 +133,12 @@ class HloModuleConfig { } int64 num_partitions() const { return num_partitions_; } - void set_broadcast_replicated_params(bool broadcast_replicated_params) { - broadcast_replicated_params_ = broadcast_replicated_params; + const std::vector param_requires_broadcast_via_collectives() const { + return param_requires_broadcast_via_collectives_; } - bool broadcast_replicated_params() const { - return broadcast_replicated_params_; + void set_param_requires_broadcast_via_collectives( + const std::vector require_broadcast) { + param_requires_broadcast_via_collectives_ = std::move(require_broadcast); } void set_use_spmd_partitioning(bool use_spmd_partitioning) { @@ -256,8 +257,8 @@ class HloModuleConfig { // The number of partitions (model parallelism) to compile this binary for. int64 num_partitions_ = 1; - // Whether to use XLA collectives to broadcast params to all replicas. - bool broadcast_replicated_params_ = false; + // Whether to broadcast args across all replicas. One entry per arg. + std::vector param_requires_broadcast_via_collectives_; // Whether to use SPMD (true) or MPMD (false) when num_partitions_ > 0 and XLA // needs to partition the module. diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index c6cec33df73..77e54800f7c 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -382,18 +382,44 @@ Status ShapeVerifier::HandleReplicaId(HloInstruction* hlo) { namespace { -Status CheckDuplicatedSourceOrTarget(HloInstruction* hlo) { +Status CheckDuplicatedSourceOrTarget(HloInstruction* hlo, + CollectiveOpGroupMode group_mode) { // A source or target cannot appear twice in the collective-permute's - // source-target pairs. + // source-target pairs. Also, based on the group formation mode, check if the + // source and target IDs are within expected range. + + // Note: for collective-permute, only kCrossReplica and kCrossPartition modes + // are valid. + const HloModuleConfig& config = hlo->GetModule()->config(); + const int64 limit = group_mode == CollectiveOpGroupMode::kCrossReplica + ? config.replica_count() + : config.num_partitions(); + absl::flat_hash_set seen_sources; absl::flat_hash_set seen_targets; for (const auto& p : hlo->source_target_pairs()) { + TF_RET_CHECK(p.first >= 0) + << "Source " << p.first + << " in the instruction's source-target pair must be >= 0 : " + << hlo->ToString(); + TF_RET_CHECK(limit == 1 || p.first < limit) + << "Source " << p.first + << " in the instruction's source-target pair must be < " << limit + << " : " << hlo->ToString(); if (!seen_sources.insert(p.first).second) { return InternalError( "Source %d appears more than once in instruction's source-target " "pairs: %s", p.first, hlo->ToString()); } + TF_RET_CHECK(p.second >= 0) + << "Target " << p.second + << " in the instruction's source-target pair must be >= 0 : " + << hlo->ToString(); + TF_RET_CHECK(limit == 1 || p.second < limit) + << "Target " << p.second + << " in the instruction's source-target pair must be < " << limit + << " : " << hlo->ToString(); if (!seen_targets.insert(p.second).second) { return InternalError( "Target %d appears more than once in instruction's source-target " @@ -407,13 +433,21 @@ Status CheckDuplicatedSourceOrTarget(HloInstruction* hlo) { } // namespace Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) { - TF_RETURN_IF_ERROR(CheckDuplicatedSourceOrTarget(hlo)); + TF_ASSIGN_OR_RETURN( + CollectiveOpGroupMode group_mode, + GetCollectiveOpGroupMode(hlo->channel_id().has_value(), + /*use_global_device_ids=*/absl::nullopt)); + TF_RETURN_IF_ERROR(CheckDuplicatedSourceOrTarget(hlo, group_mode)); return CheckShape(hlo, ShapeInference::InferCollectivePermuteShape( hlo->operand(0)->shape())); } Status ShapeVerifier::HandleCollectivePermuteStart(HloInstruction* hlo) { - TF_RETURN_IF_ERROR(CheckDuplicatedSourceOrTarget(hlo)); + TF_ASSIGN_OR_RETURN( + CollectiveOpGroupMode group_mode, + GetCollectiveOpGroupMode(hlo->channel_id().has_value(), + /*use_global_device_ids=*/absl::nullopt)); + TF_RETURN_IF_ERROR(CheckDuplicatedSourceOrTarget(hlo, group_mode)); return CheckShape( hlo, ShapeUtil::MakeTupleShape( {hlo->operand(0)->shape(), hlo->operand(0)->shape(), diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index 2ddeeb3731d..e7f556ab35a 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -1124,6 +1124,82 @@ TEST_F(HloVerifierTest, CollectivePermuteSameTargetTwice) { HasSubstr("Target 2 appears more than once")); } +TEST_F(HloVerifierTest, CollectivePermuteCrossReplicaSourceOOR) { + const char* const kModuleStr = R"( + HloModule test + ENTRY entry { + p0 = f32[128] parameter(0) + ROOT permute = f32[128] collective-permute(p0), + source_target_pairs={{5,2}, {1,2}, {2,0}} + } + )"; + HloModuleConfig config; + config.set_replica_count(3); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr, config)); + const std::string error_message = + verifier().Run(module.get()).status().error_message(); + EXPECT_THAT(error_message, HasSubstr("Source 5")); + EXPECT_THAT(error_message, HasSubstr("must be < 3")); +} + +TEST_F(HloVerifierTest, CollectivePermuteCrossReplicaTargetOOR) { + const char* const kModuleStr = R"( + HloModule test + ENTRY entry { + p0 = f32[128] parameter(0) + ROOT permute = f32[128] collective-permute(p0), + source_target_pairs={{0,1}, {1,2}, {2,7}} + } + )"; + HloModuleConfig config; + config.set_replica_count(3); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr, config)); + const std::string error_message = + verifier().Run(module.get()).status().error_message(); + EXPECT_THAT(error_message, HasSubstr("Target 7")); + EXPECT_THAT(error_message, HasSubstr("must be < 3")); +} + +TEST_F(HloVerifierTest, CollectivePermuteCrossPartitionSourceOOR) { + const char* const kModuleStr = R"( + HloModule test + ENTRY entry { + p0 = f32[128] parameter(0) + ROOT permute = f32[128] collective-permute(p0), + source_target_pairs={{5,2}, {1,2}, {2,0}}, channel_id=1 + } + )"; + HloModuleConfig config; + config.set_num_partitions(3); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr, config)); + const std::string error_message = + verifier().Run(module.get()).status().error_message(); + EXPECT_THAT(error_message, HasSubstr("Source 5")); + EXPECT_THAT(error_message, HasSubstr("must be < 3")); +} + +TEST_F(HloVerifierTest, CollectivePermuteCrossPartitionTargetOOR) { + const char* const kModuleStr = R"( + HloModule test + ENTRY entry { + p0 = f32[128] parameter(0) + ROOT permute = f32[128] collective-permute(p0), + source_target_pairs={{0,2}, {1,7}, {2,0}}, channel_id=1 + } + )"; + HloModuleConfig config; + config.set_num_partitions(3); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr, config)); + const std::string error_message = + verifier().Run(module.get()).status().error_message(); + EXPECT_THAT(error_message, HasSubstr("Target 7")); + EXPECT_THAT(error_message, HasSubstr("must be < 3")); +} + TEST_F(HloVerifierTest, FusionShapeVerifier) { const char* const kModuleStr = R"( HloModule test diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index c134b7ba6a6..7d4a2a674e0 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -40,6 +40,7 @@ cc_library( "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:custom_call_target_registry", "//tensorflow/compiler/xla/service:dynamic_index_splitter", + "//tensorflow/compiler/xla/service:eigh_expander", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", "//tensorflow/compiler/xla/service:hlo", diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index 8b0a046ffa9..9d7f79baf6e 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" +#include "tensorflow/compiler/xla/service/eigh_expander.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_constant_folding.h" #include "tensorflow/compiler/xla/service/hlo_cse.h" @@ -84,6 +85,7 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); + pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass( diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc index 22404cd1d54..3f547fe3308 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc @@ -215,7 +215,9 @@ class LowerToROCDLPass { ::mlir::OwningRewritePatternList patterns; ::mlir::populateGpuRewritePatterns(m.getContext(), patterns); - ::mlir::applyPatternsAndFoldGreedily(m, std::move(patterns)); + if (failed(mlir::applyPatternsAndFoldGreedily(m, std::move(patterns)))) { + signalPassFailure(); + } } ::mlir::OwningRewritePatternList patterns; diff --git a/tensorflow/compiler/xla/service/space_to_batch_converter.cc b/tensorflow/compiler/xla/service/space_to_batch_converter.cc index 9140e279bec..491835d8a0c 100644 --- a/tensorflow/compiler/xla/service/space_to_batch_converter.cc +++ b/tensorflow/compiler/xla/service/space_to_batch_converter.cc @@ -75,6 +75,13 @@ class ConvolutionVisitor { kernel_spatial_dim_size, input_dim_size; }; + // Structure to keep a tab of dimensions of interest in a given shape. + struct DimensionMap { + int64 batch; + int64 space; + int64 feature; + }; + // Return a struct containing various necessary information pieces for // performing space-to-batch on a convolution. ConvDetails GetConvolutionDetails(HloInstruction* convolution, @@ -227,10 +234,9 @@ class ConvolutionVisitor { // instructions. absl::flat_hash_map old_to_new_instrs_; - // Map from instruction to dimensions of the shape (first is batch, second is - // space). This is with respect to the old instruction. - absl::flat_hash_map> - instr_to_dim_map_; + // Map from instruction to dimensions of the shape. This is with respect to + // the old instruction. + absl::flat_hash_map instr_to_dim_map_; // Map from space-to-batch'ed instruction to its permute dims. absl::flat_hash_map> @@ -720,8 +726,10 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer, pivot_operand = old_producer; VLOG(2) << "Elementwise op: pivot " << old_producer->ToString(); } else { - if (instr_to_dim_map_[pivot_operand] != - instr_to_dim_map_[old_producer]) { + if (instr_to_dim_map_[pivot_operand].batch != + instr_to_dim_map_[old_producer].batch || + instr_to_dim_map_[pivot_operand].space != + instr_to_dim_map_[old_producer].space) { VLOG(2) << "Elementwise op: checking for shape equivalence " << consumer->ToString() << " failed due to changed batch space ordering "; @@ -745,7 +753,7 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer, new_instr->shape().dimensions(j)) { if (!((consumer->IsElementwiseBinary() || consumer->opcode() == HloOpcode::kSelect) && - j == instr_to_dim_map_[pivot_operand].second)) { + j == instr_to_dim_map_[pivot_operand].space)) { VLOG(2) << "Elementwise op: checking for shape equivalence " << consumer->ToString() << " failed due to changed shape sizes "; @@ -759,23 +767,58 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer, } if (consumer->opcode() == HloOpcode::kConvolution) { + // Lambda that checks basic sanity of dimension propagation on convolutions. + // This includes: the split dimension from the previous convolution should + // remain the same. No feature/batch dimension should be turned into a + // spatial dimension. + auto are_conv_dims_compatible = + [&](const ConvolutionDimensionNumbers dim_numbers, DimensionMap dim_map, + bool check_lhs) { + if (check_lhs) { + if (dim_numbers.input_spatial_dimensions( + get_chosen_spatial_dim(consumer)) != dim_map.space) { + return false; + } + for (int i = 0; i < dim_numbers.input_spatial_dimensions().size(); + ++i) { + if (dim_numbers.input_spatial_dimensions(i) == dim_map.batch || + dim_numbers.input_spatial_dimensions(i) == dim_map.feature) { + return false; + } + } + } else { + if (dim_numbers.kernel_spatial_dimensions( + get_chosen_spatial_dim(consumer)) != dim_map.space) { + return false; + } + for (int i = 0; i < dim_numbers.kernel_spatial_dimensions().size(); + ++i) { + if (dim_numbers.kernel_spatial_dimensions(i) == dim_map.batch || + dim_numbers.kernel_spatial_dimensions(i) == dim_map.feature) { + return false; + } + } + } + return true; + }; + VLOG(1) << "Checking if conv is supported for propagation " << consumer->ToString(); if (IsConvSuitableForSpaceToBatch(consumer)) { + // Activations must have been space-to-batched to enable propagation. if (!old_to_new_instrs_.contains(consumer->mutable_operand(0))) { return false; } auto dim_map_val_op_0 = instr_to_dim_map_[consumer->mutable_operand(0)]; - // Make sure that the space dimension is the same across the producer - // and consumer. - if (consumer->convolution_dimension_numbers().input_spatial_dimensions( - get_chosen_spatial_dim(consumer)) != dim_map_val_op_0.second) { + + if (!are_conv_dims_compatible(consumer->convolution_dimension_numbers(), + dim_map_val_op_0, /*check_lhs*/ true)) { return false; } // Make sure that the batch dimension is the same across the producer // and consumer. if (consumer->convolution_dimension_numbers().input_batch_dimension() != - dim_map_val_op_0.first) { + dim_map_val_op_0.batch) { return false; } @@ -834,8 +877,8 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer, kernel->shape().dimensions(consumer->convolution_dimension_numbers() .kernel_input_feature_dimension()); auto dim_map_val_op_0 = instr_to_dim_map_[activations]; - const int64 old_batch_dim = dim_map_val_op_0.first; - const int64 old_space_dim = dim_map_val_op_0.second; + const int64 old_batch_dim = dim_map_val_op_0.batch; + const int64 old_space_dim = dim_map_val_op_0.space; auto first_operand = old_to_new_instrs_[activations]; auto permute_dims_first_operand = instr_to_dim_permute_map_[first_operand]; @@ -857,6 +900,11 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer, return false; } + if (!are_conv_dims_compatible(consumer->convolution_dimension_numbers(), + dim_map_val_op_0, /*check_lhs*/ true)) { + return false; + } + // If kernel have not been propagated through, we can do // space-to-batch on them provided kernel has been propagated. VLOG(2) @@ -869,7 +917,7 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer, const int64 lhs_batch = activations->shape().dimensions( consumer->convolution_dimension_numbers().input_feature_dimension()); auto dim_map_val_op_1 = instr_to_dim_map_[consumer->mutable_operand(1)]; - const int64 old_batch_dim = dim_map_val_op_1.first; + const int64 old_batch_dim = dim_map_val_op_1.batch; auto second_operand = old_to_new_instrs_[kernel]; auto permute_dims_second_operand = instr_to_dim_permute_map_[second_operand]; @@ -885,6 +933,11 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer, return false; } + if (!are_conv_dims_compatible(consumer->convolution_dimension_numbers(), + dim_map_val_op_1, /*check_lhs*/ false)) { + return false; + } + // If activations have not been propagated through, we can do // space-to-batch on them provided kernel has been propagated. VLOG(2) << "Backprop filter conv ready for propagation: kernel ready, " @@ -902,14 +955,14 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer, instr_to_dim_permute_map_[second_operand]; const int64 new_batch_dim_operand_0 = - DimLookUp(permute_dims_first_operand, dim_map_val_op_0.first); + DimLookUp(permute_dims_first_operand, dim_map_val_op_0.batch); const int64 new_space_dim_operand_0 = - DimLookUp(permute_dims_first_operand, dim_map_val_op_0.second); + DimLookUp(permute_dims_first_operand, dim_map_val_op_0.space); const int64 new_batch_dim_operand_1 = - DimLookUp(permute_dims_second_operand, dim_map_val_op_1.first); + DimLookUp(permute_dims_second_operand, dim_map_val_op_1.batch); const int64 new_space_dim_operand_1 = - DimLookUp(permute_dims_second_operand, dim_map_val_op_1.second); + DimLookUp(permute_dims_second_operand, dim_map_val_op_1.space); if (first_operand->shape().dimensions(new_batch_dim_operand_0) != second_operand->shape().dimensions(new_batch_dim_operand_1)) { @@ -926,6 +979,16 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer, return false; } + if (!are_conv_dims_compatible(consumer->convolution_dimension_numbers(), + dim_map_val_op_0, /*check_lhs*/ true)) { + return false; + } + + if (!are_conv_dims_compatible(consumer->convolution_dimension_numbers(), + dim_map_val_op_1, /*check_lhs*/ false)) { + return false; + } + VLOG(2) << "Backprop filter conv ready for propagation"; return true; @@ -976,8 +1039,8 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer, return false; } - const int64 old_batch_dim = dim_map_val_op_0.first; - const int64 old_space_dim = dim_map_val_op_0.second; + const int64 old_batch_dim = dim_map_val_op_0.batch; + const int64 old_space_dim = dim_map_val_op_0.space; const int64 new_batch_dim = DimLookUp(permute_dims_first_operand, old_batch_dim); @@ -1017,8 +1080,8 @@ void ConvolutionVisitor::PropagateOnBroadcast(HloInstruction* consumer, auto permute_dims = instr_to_dim_permute_map_[new_producer]; auto dim_map_val = instr_to_dim_map_[producer]; - const int64 old_batch_dim = dim_map_val.first; - const int64 old_space_dim = dim_map_val.second; + const int64 old_batch_dim = dim_map_val.batch; + const int64 old_space_dim = dim_map_val.space; auto orig_broadcast_dims = consumer->dimensions(); @@ -1119,7 +1182,7 @@ bool ConvolutionVisitor::IsBroadcastPropagatable(HloInstruction* broadcast, CHECK(instr_to_dim_map_.contains(old_other_op)); auto result = instr_to_dim_map_[old_other_op]; - const int64 space_dim = result.second; + const int64 space_dim = result.space; auto broadcast_dims = broadcast->dimensions(); return !absl::c_linear_search(broadcast_dims, space_dim); } @@ -1148,8 +1211,8 @@ bool ConvolutionVisitor::SupportedOpForPropagation(HloInstruction* consumer, auto reduce_dims = consumer->dimensions(); auto result = instr_to_dim_map_[consumer->mutable_operand(0)]; - const int64 batch_dim = result.first; - const int64 space_dim = result.second; + const int64 batch_dim = result.batch; + const int64 space_dim = result.space; VLOG(1) << "Checking if reduce is supported batch_dim " << batch_dim << " space_dim " << space_dim << " reduce " << consumer->ToString(); @@ -1173,8 +1236,8 @@ bool ConvolutionVisitor::SupportedOpForPropagation(HloInstruction* consumer, } // Disallow windowing on on the batch dim auto result = instr_to_dim_map_[first_operand]; - const int64 old_batch_dim = result.first; - const int64 old_space_dim = result.second; + const int64 old_batch_dim = result.batch; + const int64 old_space_dim = result.space; if (window.dimensions(old_batch_dim).size() != 1) { return false; } @@ -1270,8 +1333,8 @@ StatusOr ConvolutionVisitor::Propagate(HloInstruction* consumer, } else if (old_to_new_instrs_.contains(consumer->mutable_operand(i))) { HloInstruction* operand_to_use = nullptr; auto result = instr_to_dim_map_[producer]; - const int64 old_batch_dim = result.first; - const int64 old_space_dim = result.second; + const int64 old_batch_dim = result.batch; + const int64 old_space_dim = result.space; const int64 old_batch_size = producer->shape().dimensions(old_batch_dim); HloInstruction* new_instr = @@ -1379,8 +1442,8 @@ StatusOr ConvolutionVisitor::Propagate(HloInstruction* consumer, auto first_operand = old_to_new_instrs_[consumer->mutable_operand(0)]; auto dim_map_val = instr_to_dim_map_[consumer->mutable_operand(0)]; - const int64 old_batch_dim = dim_map_val.first; - const int64 old_space_dim = dim_map_val.second; + const int64 old_batch_dim = dim_map_val.batch; + const int64 old_space_dim = dim_map_val.space; auto permute_dims = instr_to_dim_permute_map_[first_operand]; const int64 new_batch_dim = DimLookUp(permute_dims, old_batch_dim); const int64 new_space_dim = DimLookUp(permute_dims, old_space_dim); @@ -1418,8 +1481,8 @@ StatusOr ConvolutionVisitor::Propagate(HloInstruction* consumer, auto init_val = is_select_and_scatter ? consumer->mutable_operand(2) : consumer->mutable_operand(1); auto dim_map_val = instr_to_dim_map_[consumer->mutable_operand(0)]; - const int64 old_batch_dim = dim_map_val.first; - const int64 old_space_dim = dim_map_val.second; + const int64 old_batch_dim = dim_map_val.batch; + const int64 old_space_dim = dim_map_val.space; auto permute_dims = instr_to_dim_permute_map_[first_operand]; const int64 new_batch_dim = DimLookUp(permute_dims, old_batch_dim); const int64 new_space_dim = DimLookUp(permute_dims, old_space_dim); @@ -1727,8 +1790,8 @@ StatusOr ConvolutionVisitor::BatchToSpace( } auto result = instr_to_dim_map_[old_instr]; - const int64 old_batch_dim = result.first; - const int64 old_space_dim = result.second; + const int64 old_batch_dim = result.batch; + const int64 old_space_dim = result.space; const int64 old_batch_size = old_instr->shape().dimensions(old_batch_dim); CHECK(old_to_new_instrs_.contains(old_instr)); @@ -2047,9 +2110,10 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) { VLOG(1) << "Space-to-batched convolution " << new_conv->ToString(); instr_to_dim_map_[convolution] = - std::make_pair(original_conv_dims.output_batch_dimension(), - original_conv_dims.output_spatial_dimensions( - get_chosen_spatial_dim(convolution))); + DimensionMap{original_conv_dims.output_batch_dimension(), + original_conv_dims.output_spatial_dimensions( + get_chosen_spatial_dim(convolution)), + original_conv_dims.output_feature_dimension()}; instr_to_dim_permute_map_[new_conv] = std::vector(transpose_dims); @@ -2147,8 +2211,8 @@ StatusOr ConvolutionVisitor::PropagateOnConstant( MakeTransposeHlo(consumer, reversed_transpose_dims)); auto dim_map = instr_to_dim_map_[producer]; - const int64 old_batch_dim = dim_map.first; - const int64 old_space_dim = dim_map.second; + const int64 old_batch_dim = dim_map.batch; + const int64 old_space_dim = dim_map.space; const int64 new_batch_dim = DimLookUp(prod_transpose_dims, old_batch_dim); const int64 new_space_dim = DimLookUp(prod_transpose_dims, old_space_dim); @@ -2200,16 +2264,23 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv( bool activations_locally_space_to_batched = false; bool kernel_locally_space_to_batched = false; std::vector permute_dims_kernel, permute_dims; + + if (old_to_new_instrs_.contains(activations_old)) { + activations_new = old_to_new_instrs_[activations_old]; + permute_dims = instr_to_dim_permute_map_[activations_new]; + } + + if (old_to_new_instrs_.contains(kernel_old)) { + kernel_new = old_to_new_instrs_[kernel_old]; + permute_dims_kernel = instr_to_dim_permute_map_[kernel_new]; + } + // If activations were no space-to-batched, we space-to-batch them below. if (!old_to_new_instrs_.contains(activations_old)) { kernel_new = old_to_new_instrs_[kernel_old]; permute_dims_kernel = instr_to_dim_permute_map_[kernel_new]; VLOG(1) << "Space-to-batching activations to enable space-to-depth"; - const int64 prev_feature_dim = original_conv_dims.input_feature_dimension(); - const int64 prev_batch_dim = original_conv_dims.input_batch_dimension(); - instr_to_dim_map_[activations_old] = - std::make_pair(prev_feature_dim, prev_batch_dim); const int64 new_kernel_space_dim = DimLookUp(permute_dims_kernel, kernel_space_dim); @@ -2228,13 +2299,13 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv( /*high_padding=*/pad_size, /*low_padding=*/0, needed_spatial_size, kNumSplits, /*is_backprop=*/true)); - old_to_new_instrs_[activations_old] = retval.first; + activations_new = retval.first; std::vector reversed_transpose_dims(retval.second.size()); for (int64 i = 0; i < retval.second.size(); ++i) { reversed_transpose_dims[i] = ReverseDimLookUp(retval.second, i); } - instr_to_dim_permute_map_[retval.first] = reversed_transpose_dims; + permute_dims = reversed_transpose_dims; VLOG(3) << "New Activations " << retval.first->ToString(); @@ -2244,15 +2315,6 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv( permute_dims = instr_to_dim_permute_map_[activations_new]; VLOG(1) << "Space-to-batching kernel to enable space-to-depth"; - const int64 prev_feature_dim = - original_conv_dims.kernel_input_feature_dimension(); - const int64 prev_output_feature_dim = - original_conv_dims.kernel_output_feature_dimension(); - // TODO(b/168316428): The instr_to_dim_map_ is set incorrectly here, but it - // doesn't matter since it is never used. Investigate further to see if just - // not setting it works. - instr_to_dim_map_[kernel_old] = - std::make_pair(prev_feature_dim, prev_output_feature_dim); const int64 new_space_dim = DimLookUp(permute_dims, old_space_dim); const int64 new_split_dim_size = @@ -2273,29 +2335,25 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv( needed_spatial_size, kNumSplits, /*is_backprop=*/true, /*is_rhs=*/true)); - old_to_new_instrs_[kernel_old] = retval.first; + kernel_new = retval.first; std::vector reversed_transpose_dims(retval.second.size()); for (int64 i = 0; i < retval.second.size(); ++i) { reversed_transpose_dims[i] = ReverseDimLookUp(retval.second, i); } - instr_to_dim_permute_map_[retval.first] = reversed_transpose_dims; + permute_dims_kernel = reversed_transpose_dims; VLOG(3) << "New kernel " << retval.first->ToString(); kernel_locally_space_to_batched = true; } - CHECK(old_to_new_instrs_.contains(activations_old)); - CHECK(old_to_new_instrs_.contains(kernel_old)); - activations_new = old_to_new_instrs_[activations_old]; - kernel_new = old_to_new_instrs_[kernel_old]; + CHECK_NE(activations_new, nullptr); + CHECK_NE(kernel_new, nullptr); + const int64 new_spatial_dimension = activations_new->shape().dimensions_size(); - permute_dims = instr_to_dim_permute_map_[activations_new]; - permute_dims_kernel = instr_to_dim_permute_map_[kernel_new]; - auto permuted_conv_dims_numbers = original_conv_dims; // Note the inversion here : batch and feature are inverted in backprop @@ -2573,9 +2631,10 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv( VLOG(1) << "Space-to-featured convolution " << new_conv->ToString(); instr_to_dim_map_[convolution] = - std::make_pair(original_conv_dims.output_batch_dimension(), - original_conv_dims.output_spatial_dimensions( - get_chosen_spatial_dim(convolution))); + DimensionMap{original_conv_dims.output_batch_dimension(), + original_conv_dims.output_spatial_dimensions( + get_chosen_spatial_dim(convolution)), + original_conv_dims.output_feature_dimension()}; std::vector trans_dims(convolution->shape().dimensions_size()); absl::c_iota(trans_dims, 0); @@ -2618,7 +2677,7 @@ bool ConvolutionVisitor::IsSpaceToBatchedSpaceSizeSuitable( auto old_producer = instr->mutable_operand(0); auto dim_map_val_op = instr_to_dim_map_[old_producer]; - const int64 old_space_dim = dim_map_val_op.second; + const int64 old_space_dim = dim_map_val_op.space; auto first_operand = old_to_new_instrs_[old_producer]; auto permute_dims_first_operand = instr_to_dim_permute_map_[first_operand]; const int64 new_space_dim = @@ -2864,9 +2923,10 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution( old_to_new_instrs_[original_conv] = new_conv; instr_to_dim_map_[original_conv] = - std::make_pair(dim_numbers.output_batch_dimension(), - dim_numbers.output_spatial_dimensions( - get_chosen_spatial_dim(original_conv))); + DimensionMap{dim_numbers.output_batch_dimension(), + dim_numbers.output_spatial_dimensions( + get_chosen_spatial_dim(convolution)), + dim_numbers.output_feature_dimension()}; instr_to_dim_permute_map_[new_conv] = std::vector(transpose_dims); if (non_propagatable_instrs_.count(convolution) > 0) { diff --git a/tensorflow/compiler/xla/service/spmd/canonicalize_all_gather_for_cse.cc b/tensorflow/compiler/xla/service/spmd/canonicalize_all_gather_for_cse.cc index 6b4092ce0a2..8290b44fa84 100644 --- a/tensorflow/compiler/xla/service/spmd/canonicalize_all_gather_for_cse.cc +++ b/tensorflow/compiler/xla/service/spmd/canonicalize_all_gather_for_cse.cc @@ -71,9 +71,19 @@ StatusOr CanonicalizeAllGatherForCSE::RunOnComputation( // adding the dimension the all-gather is operating on then perform the // canonicalization. if (real_data != ag->operand(0)) { - std::vector new_dimensions(real_data->shape().dimensions().begin(), - real_data->shape().dimensions().end()); - new_dimensions[0] *= all_gather_participants; + std::vector new_dimensions; + new_dimensions.reserve(real_data->shape().dimensions_size() + 1); + new_dimensions.push_back(1); + new_dimensions.insert(new_dimensions.end(), + real_data->shape().dimensions().begin(), + real_data->shape().dimensions().end()); + // Adding specialized all-gather dimension. + HloInstruction* ag_input = + comp->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(real_data->shape().element_type(), + new_dimensions), + real_data)); + new_dimensions[0] = all_gather_participants; absl::optional new_channel_id = ag->channel_id() ? absl::make_optional(this->NextChannelId()) : absl::nullopt; @@ -81,7 +91,7 @@ StatusOr CanonicalizeAllGatherForCSE::RunOnComputation( comp->AddInstruction(HloInstruction::CreateAllGather( ShapeUtil::MakeShape(real_data->shape().element_type(), new_dimensions), - real_data, /*all_gather_dimension=*/0, ag->replica_groups(), + ag_input, /*all_gather_dimension=*/0, ag->replica_groups(), ag->constrain_layout(), new_channel_id, ag->use_global_device_ids())); HloInstruction* new_formatting = comp->AddInstruction( diff --git a/tensorflow/compiler/xla/service/spmd/canonicalize_all_gather_for_cse_test.cc b/tensorflow/compiler/xla/service/spmd/canonicalize_all_gather_for_cse_test.cc index b9378d700ec..3b6b15c57d9 100644 --- a/tensorflow/compiler/xla/service/spmd/canonicalize_all_gather_for_cse_test.cc +++ b/tensorflow/compiler/xla/service/spmd/canonicalize_all_gather_for_cse_test.cc @@ -85,8 +85,9 @@ ENTRY entry { auto module = module_status.ConsumeValueOrDie(); const HloInstruction* const reshape = module->entry_computation()->root_instruction(); - EXPECT_THAT(reshape, - AllOf(op::Reshape(op::AllGather(_)), op::Shape("s32[2,8,1,1]"))); + EXPECT_THAT(reshape, AllOf(op::Reshape(op::AllGather( + AllOf(op::Reshape(_), op::Shape("s32[1,8]")))), + op::Shape("s32[2,8,1,1]"))); } TEST_F(AllGatherCanonicalizeTest, MultipleDegenerateReshapes2) { @@ -105,8 +106,9 @@ ENTRY entry { auto module = module_status.ConsumeValueOrDie(); const HloInstruction* const reshape = module->entry_computation()->root_instruction(); - EXPECT_THAT(reshape, - AllOf(op::Reshape(op::AllGather(_)), op::Shape("s32[2,8,1,1]"))); + EXPECT_THAT(reshape, AllOf(op::Reshape(op::AllGather( + AllOf(op::Reshape(_), op::Shape("s32[1,8]")))), + op::Shape("s32[2,8,1,1]"))); } TEST_F(AllGatherCanonicalizeTest, MultipleDegenerateReshapesNoDim0) { diff --git a/tensorflow/compiler/xla/service/spmd/dot_handler.cc b/tensorflow/compiler/xla/service/spmd/dot_handler.cc index 26d77450385..12530f39167 100644 --- a/tensorflow/compiler/xla/service/spmd/dot_handler.cc +++ b/tensorflow/compiler/xla/service/spmd/dot_handler.cc @@ -452,6 +452,11 @@ absl::optional GetWindowedEinsumConfiguration( return absl::nullopt; } +// We use a recursive approach where sets of matching dimensions are recognized +// one at a time. The base shapes and shardings can be changed during the +// recursion as we group devices together. So refer to the passed in shapes and +// shardings for inputs and output, and do not use shape inference. + StatusOr PartitionBaseCase( PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping, @@ -2159,6 +2164,17 @@ StatusOr PartitionDotGroupOnContracting( HloSharding inner_output_sharding = HloSharding::Replicate(); HloSharding outer_output_tmp_sharding = HloSharding::Replicate(); + Shape inner_output_base_shape = output_base_shape; + std::vector output_slice_dims; + auto get_non_slice_dims = [&] { + std::vector non_group_dims; + for (int64 i = 0; i < output_base_shape.rank(); ++i) { + if (!absl::c_linear_search(output_slice_dims, i)) { + non_group_dims.push_back(i); + } + } + return non_group_dims; + }; if (output_sharding.ReplicateOnLastTileDim() && output_sharding.tile_assignment().dimensions().back() % group_count == 0) { @@ -2173,38 +2189,69 @@ StatusOr PartitionDotGroupOnContracting( outer_output_tmp_sharding = UngroupSharding(grouped); inner_output_sharding = std::move(grouped.sharding); } else { - std::vector group_dims; if (auto found_dims = FindMatchingPartitionedDimsForGrouping( output_sharding, lhs_grouped.device_groups)) { - group_dims = std::move(*found_dims); + output_slice_dims = std::move(*found_dims); } else if (output_lhs_non_contracting_partitions == group_count || output_rhs_non_contracting_partitions == group_count || output_batch_partitions == group_count) { if (output_lhs_non_contracting_partitions == group_count) { for (const auto& dim : dims_mapping.lhs_non_contracting_dims) { - group_dims.push_back(dim.output); + output_slice_dims.push_back(dim.output); } } else if (output_rhs_non_contracting_partitions == group_count) { for (const auto& dim : dims_mapping.rhs_non_contracting_dims) { - group_dims.push_back(dim.output); + output_slice_dims.push_back(dim.output); } } else { for (const auto& dim : dims_mapping.batch_dims) { - group_dims.push_back(dim.output); + output_slice_dims.push_back(dim.output); } } } - if (!group_dims.empty()) { + if (!output_slice_dims.empty()) { auto grouped = AlignGroupsWith( - GroupShardingOnDims(output_sharding, group_dims), lhs_grouped); + GroupShardingOnDims(output_sharding, output_slice_dims), lhs_grouped); inner_output_sharding = grouped.sharding; - outer_output_tmp_sharding = + // Since the recursive callee will use inner_creator to create + // reduce-scatter in-place, the output shape it sees is also sliced so + // inner_output_base_shape adjusts that expectation. + inner_output_base_shape = MakePartitionedShape( + output_base_shape, hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( - UngroupSharding(grouped), group_dims); + output_sharding, get_non_slice_dims())); + outer_output_tmp_sharding = UngroupSharding(grouped); } } - auto inner_state = CreatePerGroupPartitioningState( - lhs.state(), lhs_grouped.device_groups, b); + auto inner_creator = + [&](HloInstruction* l, HloInstruction* r, SpmdBuilder* b, + const Window& conv_window) -> StatusOr { + TF_ASSIGN_OR_RETURN(auto inner_dot, + create_sharded_dot(l, r, b, conv_window)); + auto ar = lhs.state().partitioner->AllReduceAlongShardingDims( + b, inner_dot, lhs_sharding, lhs.state().next_channel_id, lhs_dims, + lhs.state().collective_ops_creator, + MakeBinaryAdd(output_base_shape.element_type(), module)); + if (output_slice_dims.empty()) { + return ar; + } + // Use resharding to slice the output. Use a temporary reshard cache since + // we are faking with replicated sharding. + PartitionedHlo::PartitioningState new_state = lhs.state(); + new_state.b = b; + new_state.partition_id = + lhs.state().collective_ops_creator.create_partition_id(b); + PartitionedHlo::ReshardCache tmp_cache; + new_state.reshard_cache = &tmp_cache; + ar->set_sharding(HloSharding::Replicate()); + return PartitionedHlo(ar, ar->shape(), new_state) + .Reshard(hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + output_sharding, get_non_slice_dims())) + .hlo(); + }; + PartitionedHlo::PartitioningState inner_state = + CreatePerGroupPartitioningState(lhs.state(), lhs_grouped.device_groups, + b); TF_ASSIGN_OR_RETURN( auto dot, PartitionDot( @@ -2214,20 +2261,17 @@ StatusOr PartitionDotGroupOnContracting( PartitionedHlo(rhs.hlo(), GetPerGroupBaseShape(rhs_grouped, rhs.base_shape()), inner_state), - output_base_shape, inner_output_sharding, dims_mapping, - num_partitions / group_count, create_sharded_dot, conv_window, module, + inner_output_base_shape, inner_output_sharding, dims_mapping, + num_partitions / group_count, inner_creator, conv_window, module, original_hlo, options, b, windowed_dot_general_loops)); if (!dot) { return nullptr; } - auto ar = lhs.state().partitioner->AllReduceAlongShardingDims( - b, dot, lhs_sharding, lhs.state().next_channel_id, lhs_dims, - lhs.state().collective_ops_creator, - MakeBinaryAdd(output_base_shape.element_type(), module)); - ar->set_sharding(outer_output_tmp_sharding); - return PartitionedHlo(ar, output_base_shape, lhs.state()) - .Reshard(output_sharding) - .hlo(); + dot->set_sharding(outer_output_tmp_sharding); + auto d = PartitionedHlo(dot, output_base_shape, lhs.state()) + .Reshard(output_sharding) + .hlo(); + return d; } DotConvDimsMapping ConvertDimsMappingWithFeatureGroupCount( diff --git a/tensorflow/compiler/xla/service/spmd/schedule_aware_all_gather_cse.cc b/tensorflow/compiler/xla/service/spmd/schedule_aware_all_gather_cse.cc index 71f52d06271..ead5a376e25 100644 --- a/tensorflow/compiler/xla/service/spmd/schedule_aware_all_gather_cse.cc +++ b/tensorflow/compiler/xla/service/spmd/schedule_aware_all_gather_cse.cc @@ -25,6 +25,30 @@ limitations under the License. namespace xla { namespace { +// Returns if an instructions adds only degenerate dimensions to the shape of +// the input, like going from [X,Y] to [1,X,Y,1]. +bool IsAddingOnlyDegenerateDimensions(const HloInstruction* inst) { + if (inst->opcode() != HloOpcode::kBitcast && + inst->opcode() != HloOpcode::kReshape) { + return false; + } + const Shape& in_shape = inst->operand(0)->shape(); + const Shape& out_shape = inst->shape(); + return ShapeUtil::ElementsIn(in_shape) == ShapeUtil::ElementsIn(out_shape) && + ShapeUtil::DimensionsUnmodifiedByReshape(in_shape, out_shape).size() == + in_shape.rank(); +} + +// Passthrough reshapes or bitcasts adding only degenerate hdimensions to some +// shape. +const HloInstruction* PassthroughDegenerateAddingReshapes( + const HloInstruction* inst) { + while (IsAddingOnlyDegenerateDimensions(inst)) { + inst = inst->operand(0); + } + return inst; +} + HloCollectiveInstruction* MayConsiderAsAllGather(HloInstruction* hlo, bool for_replicas) { auto coll = DynCast(hlo); @@ -85,16 +109,23 @@ StatusOr RunOnComputation(HloComputation* comp, bool for_replicas, if (!ag) { continue; } - - auto& earlier_ags = operand_to_ag[ag->operand(0)]; + auto& earlier_ags = + operand_to_ag[PassthroughDegenerateAddingReshapes(ag->operand(0))]; bool found = false; int64 ag_height = height[ag]; for (auto& eag : earlier_ags) { + if (!ShapeUtil::Equal(eag->shape(), ag->shape())) { + continue; + } + HloInstruction* ag_operand = ag->mutable_operand(0); + TF_RETURN_IF_ERROR(ag->ReplaceOperandWith(0, eag->mutable_operand(0))); if (!eag->IdenticalIgnoringChannelIdValues(*ag)) { + TF_RETURN_IF_ERROR(ag->ReplaceOperandWith(0, ag_operand)); continue; } found = true; if (lowest_user_height(eag) > ag_height + distance_threshold) { + TF_RETURN_IF_ERROR(ag->ReplaceOperandWith(0, ag_operand)); eag = ag; continue; } diff --git a/tensorflow/compiler/xla/service/spmd/schedule_aware_all_gather_cse_test.cc b/tensorflow/compiler/xla/service/spmd/schedule_aware_all_gather_cse_test.cc index da3d3a562a6..dca956e3e7f 100644 --- a/tensorflow/compiler/xla/service/spmd/schedule_aware_all_gather_cse_test.cc +++ b/tensorflow/compiler/xla/service/spmd/schedule_aware_all_gather_cse_test.cc @@ -63,6 +63,52 @@ ENTRY entry { EXPECT_EQ(tuple->operand(0), tuple->operand(1)); } +TEST_F(AllGatherCseTest, SimpleCseReshapeLookthrough) { + absl::string_view hlo_string = R"( +HloModule module + +ENTRY entry { + param0 = s32[8]{0} parameter(0) + rshp = s32[1,8]{1,0} reshape(param0) + rshp2 = s32[1,8]{1,0} reshape(param0) + ag1 = s32[2,8]{1,0} all-gather(rshp), replica_groups={{0,1}}, dimensions={0}, + channel_id=0, use_global_device_ids=true + ag2 = s32[2,8]{1,0} all-gather(rshp2), replica_groups={{0,1}}, dimensions={0}, + channel_id=1, use_global_device_ids=true + ROOT tuple = (s32[2,8]{1,0}, s32[2,8]{1,0}) tuple(ag1, ag2) +})"; + auto module_status = RunPass(hlo_string); + EXPECT_TRUE(module_status.status().ok()); + auto module = module_status.ConsumeValueOrDie(); + HloInstruction* tuple = module->entry_computation()->root_instruction(); + EXPECT_EQ(tuple->opcode(), HloOpcode::kTuple); + EXPECT_EQ(tuple->operand_count(), 2); + EXPECT_EQ(tuple->operand(0), tuple->operand(1)); +} + +TEST_F(AllGatherCseTest, SimpleNoCseInvalidReshapes) { + absl::string_view hlo_string = R"( +HloModule module + +ENTRY entry { + param0 = s32[8]{0} parameter(0) + rshp = s32[2,4]{1,0} reshape(param0) + rshp2 = s32[2,4]{1,0} reshape(param0) + ag1 = s32[4,4]{1,0} all-gather(rshp), replica_groups={{0,1}}, dimensions={0}, + channel_id=0, use_global_device_ids=true + ag2 = s32[4,4]{1,0} all-gather(rshp2), replica_groups={{0,1}}, dimensions={0}, + channel_id=1, use_global_device_ids=true + ROOT tuple = (s32[4,4]{1,0}, s32[4,4]{1,0}) tuple(ag1, ag2) +})"; + auto module_status = RunPass(hlo_string); + EXPECT_TRUE(module_status.status().ok()); + auto module = module_status.ConsumeValueOrDie(); + HloInstruction* tuple = module->entry_computation()->root_instruction(); + EXPECT_EQ(tuple->opcode(), HloOpcode::kTuple); + EXPECT_EQ(tuple->operand_count(), 2); + EXPECT_NE(tuple->operand(0), tuple->operand(1)); +} + TEST_F(AllGatherCseTest, SimpleCseDifferentDim) { absl::string_view hlo_string = R"( HloModule module @@ -84,6 +130,29 @@ ENTRY entry { EXPECT_EQ(tuple->operand(0), tuple->operand(1)); } +TEST_F(AllGatherCseTest, SimpleCseDifferentDimReshapeLookthrough) { + absl::string_view hlo_string = R"( +HloModule module + +ENTRY entry { + param0 = s32[8]{0} parameter(0) + rshp = s32[1,8]{1,0} reshape(param0) + rshp2 = s32[1,8]{1,0} reshape(param0) + ag1 = s32[1,16]{1,0} all-gather(rshp), replica_groups={{0,1}}, dimensions={1}, + channel_id=0, use_global_device_ids=true + ag2 = s32[1,16]{1,0} all-gather(rshp2), replica_groups={{0,1}}, + dimensions={1}, channel_id=1, use_global_device_ids=true + ROOT tuple = (s32[1,16]{1,0}, s32[2,8,1,1]{3,2,1,0}) tuple(ag1, ag2) +})"; + auto module_status = RunPass(hlo_string); + EXPECT_TRUE(module_status.status().ok()); + auto module = module_status.ConsumeValueOrDie(); + HloInstruction* tuple = module->entry_computation()->root_instruction(); + EXPECT_EQ(tuple->opcode(), HloOpcode::kTuple); + EXPECT_EQ(tuple->operand_count(), 2); + EXPECT_EQ(tuple->operand(0), tuple->operand(1)); +} + TEST_F(AllGatherCseTest, NoCseGlobalDevice) { absl::string_view hlo_string = R"( HloModule module diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc index 052c2e63139..fb4009cd27a 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc @@ -1975,8 +1975,9 @@ Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) { auto operand = GetPartitionedHlo(hlo->operand(0)); // The output shape is the source and the operand shape is the target to get // the aligned sharding for the operand. - auto desired_operand_sharding = hlo_sharding_util::ReshapeSharding( - hlo->shape(), hlo->operand(0)->shape(), hlo->sharding()); + absl::optional desired_operand_sharding = + hlo_sharding_util::ReshapeSharding(hlo->shape(), hlo->operand(0)->shape(), + hlo->sharding()); if (desired_operand_sharding.has_value()) { auto operand_hlo = operand.Reshard(*desired_operand_sharding).hlo(); SetPartitionedHlo(hlo, [&] { @@ -1985,6 +1986,21 @@ Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) { }); return Status::OK(); } + absl::optional desired_output_sharding = + hlo_sharding_util::ReshapeSharding(hlo->operand(0)->shape(), hlo->shape(), + operand.sharding()); + if (desired_output_sharding.has_value()) { + auto reshape = b_.AddInstruction(hlo->CloneWithNewOperands( + MakePartitionedShape(hlo->shape(), *desired_output_sharding), + {operand.hlo()})); + reshape->set_sharding(*desired_output_sharding); + SetPartitionedHlo(hlo, [&] { + return PartitionedHlo(reshape, hlo->shape(), MakePartitioningState()) + .Reshard(sharding) + .hlo(); + }); + return Status::OK(); + } // Check if operand sharding and sharding are both tiled or partial replicate. // If both of them are partial replicate, check num_replications are the same. diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc index c380a16f816..6daf3946ae9 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc @@ -2925,6 +2925,49 @@ ENTRY entry { EXPECT_THAT(root, AllOf(op::Reshape(param0), op::Shape("f32[19,38,4,81]"))); } +TEST_F(SpmdPartitioningTest, ReshapeWithReshard) { + absl::string_view hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[38,38,324] parameter(0), sharding={devices=[2,1,1]0,1} + ROOT %reshape = f32[38,38,4,81] reshape(%param0), + sharding={devices=[1,2,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto input_reshard = + op::Reshape(op::Transpose(op::AllToAll(op::Reshape(op::Parameter(0))))); + EXPECT_THAT(root, + AllOf(op::Reshape(input_reshard), op::Shape("f32[38,19,4,81]"))); +} + +TEST_F(SpmdPartitioningTest, ReshapeWithReshard2) { + absl::string_view hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[38,38,324] parameter(0), sharding={devices=[2,1,1]0,1} + ROOT %reshape = f32[38,38,2,162] reshape(%param0), + sharding={devices=[1,1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto local_reshape = + AllOf(op::Reshape(op::Parameter(0)), op::Shape("f32[19,38,2,162]")); + EXPECT_THAT(root, AllOf(op::Shape("f32[38,38,2,81]"), + op::Reshape(op::Transpose( + op::AllToAll(op::Reshape(local_reshape)))))); +} + TEST_F(SpmdPartitioningTest, PartialReplicateShardableReshape) { absl::string_view hlo_string = R"( HloModule module @@ -2949,35 +2992,6 @@ ENTRY entry { EXPECT_THAT(root, AllOf(op::Reshape(param0), op::Shape("f32[19,38,4,81]"))); } -TEST_F(SpmdPartitioningTest, NonShardableReshape) { - absl::string_view hlo_string = R"( -HloModule module - -ENTRY entry { - %param0 = f32[38,38,324] parameter(0) - %param0.copy = f32[38,38,324] copy(%param0), sharding={devices=[1,1,2]0,1} - ROOT %transpose = f32[38,38,4,81] reshape(%param0.copy), - sharding={devices=[1,1,1,2]0,1} -})"; - - TF_ASSERT_OK_AND_ASSIGN(auto module, - PartitionComputation(hlo_string, /*num_devices=*/2)); - VLOG(1) << module->ToString(); - - auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT( - root, - AllOf(op::DynamicSlice( - AllOf(op::Pad( - AllOf(op::Reshape(AllOf(op::AllReduce(), - op::Shape("f32[38,38,324]"))), - op::Shape("f32[38,38,4,81]")), - op::Constant()), - op::Shape("f32[38,38,4,82]")), - op::Constant(), op::Constant(), op::Constant(), op::Reshape()), - op::Shape("f32[38,38,4,41]"))); -} - TEST_F(SpmdPartitioningTest, ReshapeMergeDimsWithHaloExchange) { absl::string_view hlo_string = R"( HloModule module @@ -7319,6 +7333,34 @@ ENTRY entry { EXPECT_THAT(dot_op, op::Dot(op1, op2)); } +TEST_F(SpmdPartitioningTest, PartitionDotGroupOnBatchContractingReshard) { + absl::string_view hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,32,24,4096] parameter(0), + sharding={devices=[2,1,1,2]0,1,2,3} + %rhs = f32[32,4096,1024] parameter(1), + sharding={devices=[2,2,1]0,1,2,3} + ROOT %dot = f32[32,32,24,1024] dot(%lhs, %rhs), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={3}, rhs_contracting_dims={1}, + sharding={devices=[1,2,1,2]0,1,2,3} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto dot = AllOf(op::Shape("f32[16,32,24,1024]"), + op::Dot(op::Parameter(0), op::Parameter(1))); + auto reduce_scatter = AllOf(op::Shape("f32[16,32,24,512]"), + op::DynamicSlice(op::AllReduce(dot), _, _, _, _)); + EXPECT_THAT(root, AllOf(op::Reshape(op::Transpose( + op::AllToAll(op::Reshape(reduce_scatter)))), + op::Shape("f32[32,16,24,512]"))); +} + } // namespace } // namespace spmd } // namespace xla diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 1929b90a44a..b7cf13e084a 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -381,7 +381,7 @@ message ExecutionOptions { // works on TPU. bool deduplicate_hlo = 12; - reserved 13; // Was broadcast_replicated_parameters_via_collectives = 13; + reserved 13; // Was broadcast_replicated_parameters_via_collectives } message GetDeviceHandlesRequest { diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index a69345d544a..ac32b043cf9 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -68,7 +68,6 @@ load( "if_ios", "if_libtpu", "if_mobile", - "if_not_windows", "tf_cc_test", "tf_cc_test_mkl", "tf_cc_tests", @@ -657,9 +656,7 @@ cc_library( "//tensorflow/core/kernels/linalg:linalg", "//tensorflow/core/kernels/image:image", "//tensorflow/core/kernels/sparse:kernels", - ] + if_not_windows([ - "//tensorflow/core/kernels/neon:neon_depthwise_conv_op", - ]) + if_mkl([ + ] + if_mkl([ "//tensorflow/core/kernels/mkl:mkl_aggregate_ops", "//tensorflow/core/kernels/mkl:mkl_concat_op", "//tensorflow/core/kernels/mkl:mkl_dequantize_op", diff --git a/tensorflow/core/common_runtime/eager/eager_executor.cc b/tensorflow/core/common_runtime/eager/eager_executor.cc index 015b914337d..2d06390b810 100644 --- a/tensorflow/core/common_runtime/eager/eager_executor.cc +++ b/tensorflow/core/common_runtime/eager/eager_executor.cc @@ -116,10 +116,6 @@ Status EagerExecutor::SyncExecute(EagerNode* node) { // Inline execution in sync mode. s = node->Run(); tensorflow::mutex_lock l(node_queue_mutex_); - if (!s.ok()) { - status_ = s; - ok_ = false; - } NotifyWaiters(id); return s; } diff --git a/tensorflow/core/common_runtime/gpu/gpu_cudamallocasync_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_cudamallocasync_allocator.h index 8848ecd3ac5..b31154d7aca 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_cudamallocasync_allocator.h +++ b/tensorflow/core/common_runtime/gpu/gpu_cudamallocasync_allocator.h @@ -106,8 +106,8 @@ class GpuCudaMallocAsyncAllocator : public Allocator { // Stats. // Structures mutable after construction mutable mutex lock_; - std::unique_ptr stats_ PT_GUARDED_BY(lock_); - absl::flat_hash_map size_map_ GUARDED_BY(lock_); + std::unique_ptr stats_ TF_PT_GUARDED_BY(lock_); + absl::flat_hash_map size_map_ TF_GUARDED_BY(lock_); }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator.h b/tensorflow/core/common_runtime/mkl_cpu_allocator.h index 5794f49dd17..27d4f04c68a 100644 --- a/tensorflow/core/common_runtime/mkl_cpu_allocator.h +++ b/tensorflow/core/common_runtime/mkl_cpu_allocator.h @@ -22,6 +22,7 @@ limitations under the License. #ifdef INTEL_MKL #include + #include "tensorflow/core/common_runtime/bfc_allocator.h" #include "tensorflow/core/common_runtime/pool_allocator.h" #include "tensorflow/core/lib/strings/numbers.h" @@ -29,10 +30,6 @@ limitations under the License. #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/numa.h" -#ifndef INTEL_MKL_DNN_ONLY -#include "i_malloc.h" -#endif - #ifdef _WIN32 typedef unsigned int uint; #endif @@ -186,14 +183,6 @@ class MklCPUAllocator : public Allocator { new MklSmallSizeAllocator(sub_allocator_, max_mem_bytes, kName); large_size_allocator_ = new BFCAllocator(sub_allocator_, max_mem_bytes, kAllowGrowth, kName); -#ifndef INTEL_MKL_DNN_ONLY - // For redirecting all allocations from MKL to this allocator - // From: http://software.intel.com/en-us/node/528565 - i_malloc = MallocHook; - i_calloc = CallocHook; - i_realloc = ReallocHook; - i_free = FreeHook; -#endif return Status::OK(); } diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc index 7e7bfafea80..747ae231bcd 100644 --- a/tensorflow/core/framework/dataset.cc +++ b/tensorflow/core/framework/dataset.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/variant_encode_decode.h" #include "tensorflow/core/framework/variant_op_registry.h" #include "tensorflow/core/graph/graph_def_builder.h" @@ -499,6 +500,14 @@ Status DatasetBase::DatasetGraphDefBuilder::AddDatasetOrTensor( return s; } } + if (t.dtype() == DT_RESOURCE && ctx->serialize_data_tensors()) { + Status s = AddResourceHelper(ctx, t, output); + if (!errors::IsUnimplemented(s)) { + // Fall through to AddTensor if AsGraphDef is not implemented for this + // resource. + return s; + } + } return AddTensor(t, output); } @@ -524,6 +533,15 @@ Status DatasetBase::DatasetGraphDefBuilder::AddDatasetOrTensorHelper( return Status::OK(); } +Status DatasetBase::DatasetGraphDefBuilder::AddResourceHelper( + SerializationContext* ctx, const Tensor& t, Node** output) { + const ResourceHandle& handle = t.flat()(0); + ResourceBase* resource; + TF_RETURN_IF_ERROR(ctx->resource_mgr()->Lookup(handle, &resource)); + core::ScopedUnref unref(resource); + return resource->AsGraphDef(*builder(), output); +} + DatasetBaseIterator::DatasetBaseIterator(const BaseParams& params) : params_(params) { params_.dataset->Ref(); diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index 9770bf2f8cc..adc9f9ae1a8 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -570,6 +570,9 @@ class SerializationContext { // seeds. This param does not affect datasets that use fixed seeds; fixed // seeds will always be preserved. bool preserve_random_seeds = true; + + // A resource manager for looking up resources during serialization. + ResourceMgr* resource_mgr; }; explicit SerializationContext(Params params) : params_(params) {} @@ -588,6 +591,8 @@ class SerializationContext { bool preserve_random_seeds() const { return params_.preserve_random_seeds; } + ResourceMgr* resource_mgr() const { return params_.resource_mgr; } + private: Params params_; @@ -922,6 +927,8 @@ class DatasetBase : public core::RefCounted { private: Status AddDatasetOrTensorHelper(SerializationContext* ctx, const Tensor& val, Node** output); + Status AddResourceHelper(SerializationContext* ctx, const Tensor& val, + Node** output); }; // Serializes the dataset into a `GraphDef`, which has two uses: diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h index 01688795af9..1d67d327f33 100644 --- a/tensorflow/core/framework/resource_mgr.h +++ b/tensorflow/core/framework/resource_mgr.h @@ -31,10 +31,12 @@ limitations under the License. #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/type_index.h" #include "tensorflow/core/framework/variant_tensor_data.h" +#include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" @@ -83,6 +85,13 @@ class ResourceBase : public core::RefCounted { // Returns memory used by this resource. virtual int64 MemoryUsed() const { return 0; } + + // Writes a representation of this resource into `builder`, so that executing + // `*out` will recreate this resource. + virtual Status AsGraphDef(GraphDefBuilder& builder, Node** out) const { + return errors::Unimplemented("AsGraphDef not implemented for resource ", + DebugString()); + } }; // Container used for per-step resources. diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index ea2901d05ac..e0eee7360dd 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -50,6 +50,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/tensor_coding.h" #include "tensorflow/core/protobuf/error_codes.pb.h" #include "tensorflow/core/util/device_name_utils.h" @@ -199,6 +200,21 @@ bool NodeIsOnCpu(const NodeDef& node) { absl::StrContains(device, DEVICE_CPU); } +// True if all regular (non-control) inputs reference the same node or if there +// are no non-control inputs +bool AllRegularInputsEqual(const NodeDef& node) { + if (!HasRegularInputs(node)) return true; + for (int i = 1; i < node.input_size(); ++i) { + if (IsControlInput(node.input(i))) { + break; + } + if (node.input(0) != node.input(i)) { + return false; + } + } + return true; +} + // Graph optimizer context extension specific to ArithmeticOptimizer. struct ArithmeticOptimizerContext { explicit ArithmeticOptimizerContext(SetVector* nodes_to_simplify) @@ -2673,6 +2689,180 @@ class ReplaceMulWithBroadcastByTile : public ArithmeticOptimizerStage { } }; +// Replace a sequence of Pack nodes with identical inputs with Tile +// For example, given a Tensor X with shape (I,J,K) +// Let P(x, n) = Pack([x, x], axis=n) +// +// P(P(X, 2), 1) +// = Tile(Reshape(Tile(Reshape(x, +// [I, J, 1, K]), [1, 1, 2, 1]), +// [I, 1, J, 2, K]), [1, 2, 1, 1, 1])) +// = Tile(Reshape(x, +// [I, 1, J, 1, K]), [1, 2, 1, 2, 1]) +// = Reshape(Tile(x, [1, 2, 2]), [I, 2, J, 2, K]) +// +// The outermost reshape is often redundant and can be removed in another pass +class ReplacePackWithTileReshape : public ArithmeticOptimizerStage { + public: + explicit ReplacePackWithTileReshape(const GraphOptimizerContext& ctx, + const ArithmeticOptimizerContext& ctx_ext) + : ArithmeticOptimizerStage("ReplacePackWithTileReshape", ctx, ctx_ext) {} + ~ReplacePackWithTileReshape() override = default; + + bool IsSupported(const NodeDef* node) const override { + return IsPack(*node) && NumNonControlInputs(*node) > 1 && + !IsInPreserveSet(*node); + } + + Status TrySimplify(NodeDef* node, string* simplified_node_name) override { + // 1. traverse the chain of Pack ops to get the original input + NodeDef* input = node; + std::vector chain; + while (IsPack(*input) && NumNonControlInputs(*node) > 1 && + !IsInPreserveSet(*input)) { + // Only pack operations with all identical inputs are supported + if (!AllRegularInputsEqual(*input)) { + break; + } + chain.push_back(input); + TF_RETURN_IF_ERROR(GetInputNode(input->input(0), &input)); + } + + // Must be at least two Pack operations to consider for replacement + if (chain.empty()) { + return Status::OK(); + } + + // Avoid optimizing the same node twice + const NodeScopeAndName node_scope_and_name = + ParseNodeScopeAndName(node->name()); + const string new_const_name = + OptimizedNodeName(node_scope_and_name, "Multiples"); + const string new_tile_name = OptimizedNodeName(node_scope_and_name, "Tile"); + const string new_shape_name = + OptimizedNodeName(node_scope_and_name, "Shape"); + const string new_reshape_name = + OptimizedNodeName(node_scope_and_name, "Reshape"); + if (ctx().node_map->NodeExists(new_const_name) || + ctx().node_map->NodeExists(new_tile_name) || + ctx().node_map->NodeExists(new_shape_name) || + ctx().node_map->NodeExists(new_reshape_name)) { + return Status::OK(); + } + + // 2. Calculate the multiples and shape tensor using the chain + const OpInfo::TensorProperties* input_props; + TF_RETURN_IF_ERROR(GetTensorProperties(input->name(), &input_props)); + const TensorShapeProto& input_shape = input_props->shape(); + if (!PartialTensorShape(input_shape).IsFullyDefined()) { + return Status::OK(); + } + Tensor multiples(DT_INT32, TensorShape({input_shape.dim_size()})); + TF_RETURN_IF_ERROR(CalculateMultiplesFromChain(chain, &multiples)); + + const OpInfo::TensorProperties* output_props; + TF_RETURN_IF_ERROR(GetTensorProperties(node->name(), &output_props)); + const TensorShapeProto& output_shape = output_props->shape(); + if (!PartialTensorShape(output_shape).IsFullyDefined()) { + return Status::OK(); + } + Tensor output_shape_tensor(DT_INT32, + TensorShape({output_shape.dim_size()})); + for (int i = 0; i < output_shape.dim_size(); ++i) { + output_shape_tensor.flat()(i) = output_shape.dim(i).size(); + } + + // 3. Create constant node with correct multiples value + NodeDef* new_const_node = AddEmptyNode(new_const_name); + TF_RETURN_IF_ERROR(ConstantFolding::CreateNodeDef( + new_const_node->name(), TensorValue(&multiples), new_const_node)); + new_const_node->set_device(node->device()); + MaybeAddControlInput(input->name(), new_const_node, ctx().optimized_graph, + ctx().node_map); + AddToOptimizationQueue(new_const_node); + + // 4. Replace the Pack node with Tile(Const(N), input); + DataType dtype = GetDataTypeFromAttr(*node, "T"); + NodeDef* new_tile_node = AddEmptyNode(new_tile_name); + new_tile_node->set_op("Tile"); + new_tile_node->set_device(node->device()); + SetDataTypeToAttr(dtype, "T", new_tile_node); + SetDataTypeToAttr(DT_INT32, "Tmultiples", new_tile_node); + new_tile_node->add_input(input->name()); + ctx().node_map->AddOutput(input->name(), new_tile_node->name()); + new_tile_node->add_input(new_const_node->name()); + ctx().node_map->AddOutput(new_const_node->name(), new_tile_node->name()); + + // Tile inherits all control dependencies from the original pack chain + ForwardControlDependencies(new_tile_node, chain); + AddToOptimizationQueue(new_tile_node); + + // 5. Add a new Reshape node to preserve the existing shape + NodeDef* new_shape_node = AddEmptyNode(new_shape_name); + TF_RETURN_IF_ERROR(ConstantFolding::CreateNodeDef( + new_shape_node->name(), TensorValue(&output_shape_tensor), + new_shape_node)); + new_shape_node->set_device(node->device()); + MaybeAddControlInput(input->name(), new_shape_node, ctx().optimized_graph, + ctx().node_map); + AddToOptimizationQueue(new_shape_node); + + NodeDef* new_reshape_node = AddEmptyNode(new_reshape_name); + new_reshape_node->set_op("Reshape"); + new_reshape_node->set_device(node->device()); + SetDataTypeToAttr(dtype, "T", new_reshape_node); + SetDataTypeToAttr(DT_INT32, "Tshape", new_reshape_node); + new_reshape_node->add_input(new_tile_node->name()); + ctx().node_map->AddOutput(new_tile_node->name(), new_reshape_node->name()); + new_reshape_node->add_input(new_shape_node->name()); + ctx().node_map->AddOutput(new_shape_node->name(), new_reshape_node->name()); + + *simplified_node_name = new_reshape_node->name(); + + return Status::OK(); + } + + protected: + Status CalculateMultiplesFromChain(const std::vector& chain, + Tensor* multiples) { + // Keep track of how the multiples correspond to each shape dimension. + // For example, given Stack([x, x], axis=1) with rank(x) = 3, we start with + // multiples=[1, 1, 1] , dims=[0, 1, 2] + // After processing the stack op + // multiples=[1, 2, 1] , dims=[0, 1, 1, 2] + std::vector dims(multiples->NumElements()); + std::iota(dims.begin(), dims.end(), 0); + + for (int i = 0; i < multiples->NumElements(); ++i) { + multiples->flat()(i) = 1; + } + + for (auto it = chain.rbegin(); it != chain.rend(); ++it) { + AttrSlice attrs(**it); + int64 axis, n; + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "axis", &axis)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "N", &n)); + + if (axis >= dims.size()) { + // We don't handle the case where Pack is performed on the last axis, + // e.g. Pack([x, x], axis=3) where rank(x) == 3 + return Status(error::OUT_OF_RANGE, "axis value out of range of dims"); + } + + int64 m = multiples->flat()(dims[axis]) * n; + if (TF_PREDICT_FALSE(m > INT_MAX)) { + return Status(error::OUT_OF_RANGE, "int32 overflow"); + } + multiples->flat()(dims[axis]) = static_cast(m); + + // Copy index from immediate right of inserted axis + dims.insert(dims.begin() + axis, dims[axis]); + } + + return Status::OK(); + } +}; + // Simplify aggregation (e.g. AddN) nodes: // // 1. Discard aggregate nodes with a single input and no control dependencies. @@ -3917,6 +4107,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { pipeline.AddStage(ctx, ctx_ext); if (options_.remove_negation) pipeline.AddStage(ctx, ctx_ext); + if (options_.replace_pack_with_tile_reshape) + pipeline.AddStage(ctx, ctx_ext); if (options_.replace_mul_with_square) pipeline.AddStage(ctx, ctx_ext); if (options_.replace_mul_with_tile) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h index 373e5004d8d..45d1993661f 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h @@ -81,13 +81,14 @@ class ArithmeticOptimizer : public GraphOptimizer { bool reorder_redundant_reshape_around_unary = true; bool replace_mul_with_tile = true; bool replace_mul_with_square = true; - bool simplify_aggregation = true; + bool replace_pack_with_tile_reshape = true; bool convert_pow = true; bool convert_log1p = true; bool convert_log_softmax = true; bool convert_expm1 = true; bool unary_ops_composition = true; bool remove_stack_slice_same_axis = true; + bool simplify_aggregation = true; bool simplify_embedding_lookup = true; bool remove_cast_into_segment_reduction = true; diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index a0d971365d8..a42c81ebc97 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -324,6 +324,195 @@ TEST_F(ArithmeticOptimizerTest, ReplaceMulWithSquare) { test::ExpectTensorNear(tensors[0], tensors_expected[0], 1e-6); } +TEST_F(ArithmeticOptimizerTest, ReplacePackWithTileReshape) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output a = ops::Placeholder(s.WithOpName("a"), DT_FLOAT, + ops::Placeholder::Shape({3, 5, 7, 11})); + // Stack creates Pack nodes + Output b = ops::Stack(s.WithOpName("b"), {a, a}, ops::Stack::Axis(3)); + Output c = ops::Stack(s.WithOpName("c"), {b, b}, ops::Stack::Axis(2)); + Output o = ops::Identity(s.WithOpName("output"), c); + + GrapplerItem item; + item.fetch = {"output"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + auto a_t = GenerateRandomTensor(TensorShape({3, 5, 7, 11})); + auto expected = EvaluateNodes(item.graph, item.fetch, {{"a", a_t}}); + ASSERT_EQ(expected.size(), 1); + + GraphDef g; + ArithmeticOptimizer optimizer; + EnableOnlyReplacePackWithTileReshape(&optimizer); + OptimizeAndPrune(&optimizer, &item, &g); + + EXPECT_EQ(g.node_size(), 6); + EXPECT_EQ(CountOpNodes(g, "Pack"), 0); + EXPECT_EQ(CountOpNodes(g, "Tile"), 1); + EXPECT_EQ(CountOpNodes(g, "Const"), 2); + EXPECT_EQ(CountOpNodes(g, "Reshape"), 1); + + NodeMap node_map(&g); + const string p = "ArithmeticOptimizer/ReplacePackWithTileReshape"; + const NodeDef* t_node = node_map.GetNode(absl::StrCat(p, "_", "Tile_c")); + const NodeDef* c_node = node_map.GetNode(absl::StrCat(p, "_", "Multiples_c")); + const NodeDef* s_node = node_map.GetNode(absl::StrCat(p, "_", "Shape_c")); + const NodeDef* r_node = node_map.GetNode(absl::StrCat(p, "_", "Reshape_c")); + const NodeDef* a_node = node_map.GetNode("a"); + ASSERT_NE(t_node, nullptr); + ASSERT_NE(c_node, nullptr); + ASSERT_NE(s_node, nullptr); + ASSERT_NE(r_node, nullptr); + ASSERT_NE(a_node, nullptr); + + EXPECT_EQ(c_node->op(), "Const"); + EXPECT_EQ(s_node->op(), "Const"); + + // Check Reshape properties + ASSERT_EQ(r_node->input_size(), 2); + EXPECT_EQ(r_node->op(), "Reshape"); + EXPECT_EQ(r_node->input(0), t_node->name()); + EXPECT_EQ(r_node->input(1), s_node->name()); + + // Check Tile properties + ASSERT_EQ(t_node->input_size(), 2); + EXPECT_EQ(t_node->op(), "Tile"); + EXPECT_EQ(t_node->input(0), a_node->name()); + EXPECT_EQ(t_node->input(1), c_node->name()); + EXPECT_EQ(t_node->attr().at("T").type(), DT_FLOAT); + EXPECT_EQ(t_node->attr().at("Tmultiples").type(), + c_node->attr().at("dtype").type()); + + auto result = EvaluateNodes(g, item.fetch, {{"a", a_t}}); + ASSERT_EQ(result.size(), 1); + test::ExpectTensorNear(result[0], expected[0], 1e-6); +} + +TEST_F(ArithmeticOptimizerTest, ReplacePackWithTileReshapeControlDeps) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output a = ops::Placeholder(s.WithOpName("a"), DT_FLOAT, + ops::Placeholder::Shape({3, 5, 7, 11})); + + Output x = ops::Identity(s.WithOpName("x"), a); + Output y = ops::Identity(s.WithOpName("y"), a); + + Output b = ops::Stack(s.WithOpName("b").WithControlDependencies(x), {a, a}, + ops::Stack::Axis(3)); + Output c = ops::Stack(s.WithOpName("c").WithControlDependencies(y), {b, b}, + ops::Stack::Axis(2)); + Output o = ops::Identity(s.WithOpName("output"), c); + + GrapplerItem item; + item.fetch = {"output"}; + item.keep_ops = {"x", "y"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + auto a_t = GenerateRandomTensor(TensorShape({3, 5, 7, 11})); + auto expected = EvaluateNodes(item.graph, item.fetch, {{"a", a_t}}); + ASSERT_EQ(expected.size(), 1); + + GraphDef g; + ArithmeticOptimizer optimizer; + EnableOnlyReplacePackWithTileReshape(&optimizer); + OptimizeAndPrune(&optimizer, &item, &g); + + EXPECT_EQ(g.node_size(), 8); + EXPECT_EQ(CountOpNodes(g, "Pack"), 0); + EXPECT_EQ(CountOpNodes(g, "Tile"), 1); + EXPECT_EQ(CountOpNodes(g, "Const"), 2); + EXPECT_EQ(CountOpNodes(g, "Reshape"), 1); + EXPECT_EQ(CountOpNodes(g, "Identity"), 3); + + NodeMap node_map(&g); + const string p = "ArithmeticOptimizer/ReplacePackWithTileReshape"; + const NodeDef* t_node = node_map.GetNode(absl::StrCat(p, "_", "Tile_c")); + const NodeDef* c_node = node_map.GetNode(absl::StrCat(p, "_", "Multiples_c")); + const NodeDef* s_node = node_map.GetNode(absl::StrCat(p, "_", "Shape_c")); + const NodeDef* a_node = node_map.GetNode("a"); + ASSERT_NE(t_node, nullptr); + ASSERT_NE(c_node, nullptr); + ASSERT_NE(s_node, nullptr); + ASSERT_NE(a_node, nullptr); + + ASSERT_EQ(t_node->input_size(), 4); + EXPECT_EQ(t_node->op(), "Tile"); + EXPECT_EQ(t_node->input(0), a_node->name()); + EXPECT_EQ(t_node->input(1), c_node->name()); + EXPECT_EQ(t_node->input(2), "^y"); + EXPECT_EQ(t_node->input(3), "^x"); + + ASSERT_EQ(c_node->input_size(), 1); + EXPECT_EQ(c_node->input(0), "^a"); + + ASSERT_EQ(s_node->input_size(), 1); + ASSERT_EQ(s_node->input(0), "^a"); + + auto result = EvaluateNodes(g, item.fetch, {{"a", a_t}}); + ASSERT_EQ(result.size(), 1); + test::ExpectTensorNear(result[0], expected[0], 1e-6); +} + +TEST_F(ArithmeticOptimizerTest, ReplacePackWithTileRemoveReshape) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output a = ops::Placeholder(s.WithOpName("a"), DT_FLOAT, + ops::Placeholder::Shape({3, 5, 7, 11})); + // Stack creates Pack nodes + Output b = ops::Stack(s.WithOpName("b"), {a, a}, ops::Stack::Axis(3)); + Output c = ops::Stack(s.WithOpName("c"), {b, b}, ops::Stack::Axis(2)); + Output r = + ops::Reshape(s.WithOpName("r"), c, ops::Const(s, {3, 10, 14, 11}, {4})); + Output o = ops::Identity(s.WithOpName("output"), r); + + GrapplerItem item; + item.fetch = {"output"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + auto a_t = GenerateRandomTensor(TensorShape({3, 5, 7, 11})); + auto expected = EvaluateNodes(item.graph, item.fetch, {{"a", a_t}}); + ASSERT_EQ(expected.size(), 1); + + GraphDef g; + ArithmeticOptimizer optimizer; + EnableOnlyReplacePackWithTileReshape(&optimizer); + OptimizeAndPrune(&optimizer, &item, &g); + + EXPECT_EQ(g.node_size(), 8); + EXPECT_EQ(CountOpNodes(g, "Pack"), 0); + EXPECT_EQ(CountOpNodes(g, "Tile"), 1); + EXPECT_EQ(CountOpNodes(g, "Const"), 3); + EXPECT_EQ(CountOpNodes(g, "Reshape"), 2); + + EnableOnlyRemoveRedundantReshape(&optimizer); + OptimizeAndPrune(&optimizer, &item, &g); + + EXPECT_EQ(g.node_size(), 6); + EXPECT_EQ(CountOpNodes(g, "Pack"), 0); + EXPECT_EQ(CountOpNodes(g, "Tile"), 1); + EXPECT_EQ(CountOpNodes(g, "Const"), 2); + EXPECT_EQ(CountOpNodes(g, "Reshape"), 1); + + auto result = EvaluateNodes(g, item.fetch, {{"a", a_t}}); + ASSERT_EQ(result.size(), 1); + test::ExpectTensorNear(result[0], expected[0], 1e-6); +} + +TEST_F(ArithmeticOptimizerTest, ReplacePackWithTileReshapeOutOfRange) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output a = ops::Placeholder(s.WithOpName("a"), DT_FLOAT, + ops::Placeholder::Shape({3, 5, 7, 11})); + // Stack creates Pack nodes + Output b = ops::Stack(s.WithOpName("b"), {a, a}, ops::Stack::Axis(4)); + Output o = ops::Identity(s.WithOpName("output"), b); + + GrapplerItem item; + item.fetch = {"output"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + GraphDef g; + ArithmeticOptimizer optimizer; + EnableOnlyReplacePackWithTileReshape(&optimizer); + OptimizeAndPrune(&optimizer, &item, &g); + + VerifyGraphsMatch(item.graph, g, __LINE__); +} + TEST_F(ArithmeticOptimizerTest, RemoveInvolutionAdjacentNodes) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h index f2600cf79ad..a7a6642fbfd 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h @@ -169,6 +169,11 @@ class ArithmeticOptimizerTest : public GrapplerTest { optimizer->options_.replace_mul_with_square = true; } + void EnableOnlyReplacePackWithTileReshape(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.replace_pack_with_tile_reshape = true; + } + void EnableOnlyHoistCWiseUnaryChains(ArithmeticOptimizer* optimizer) { DisableAllStages(optimizer); optimizer->options_.hoist_cwise_unary_chains = true; diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc b/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc index 9b41df4cab5..92ad70761a8 100644 --- a/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc +++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc @@ -1986,7 +1986,7 @@ Status AutoMixedPrecision::Optimize(Cluster* cluster, const GrapplerItem& item, return errors::InvalidArgument("cluster == nullptr"); } -#if !defined(INTEL_MKL) || !defined(ENABLE_INTEL_MKL_BFLOAT16) +#if !defined(INTEL_MKL) if (mode_ == AutoMixedPrecisionMode::MKL) { return errors::Unimplemented( "The auto_mixed_precision_mkl optimizer cannot be used since " diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc b/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc index 90c8bc82b70..3c0def0169f 100644 --- a/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc +++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc @@ -13,8 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM || \ - (INTEL_MKL && defined(ENABLE_INTEL_MKL_BFLOAT16)) +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM || INTEL_MKL #include "tensorflow/core/grappler/optimizers/auto_mixed_precision.h" @@ -1178,7 +1177,6 @@ TEST_F(AutoMixedPrecisionTest, TanhOp) { #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #if INTEL_MKL -#ifdef ENABLE_INTEL_MKL_BFLOAT16 class AutoMixedPrecisionMklTest : public GrapplerTest { protected: @@ -1354,12 +1352,10 @@ TEST_F(AutoMixedPrecisionMklTest, TensorListSetGet) { } } -#endif // ENABLE_INTEL_MKL_BFLOAT16 #endif // INTEL_MKL } // namespace } // namespace grappler } // namespace tensorflow -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM || (INTEL_MKL && - // defined(ENABLE_INTEL_MKL_BFLOAT16)) +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM || INTEL_MKL diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD index c5e25feb14e..623cae3302a 100644 --- a/tensorflow/core/grappler/optimizers/data/BUILD +++ b/tensorflow/core/grappler/optimizers/data/BUILD @@ -36,6 +36,7 @@ cc_library( ":reorder_data_discarding_ops", ":shuffle_and_repeat_fusion", ":slack", + ":use_private_thread_pool", ], ) @@ -954,6 +955,41 @@ tf_cc_test( ], ) +cc_library( + name = "use_private_thread_pool", + srcs = ["use_private_thread_pool.cc"], + hdrs = ["use_private_thread_pool.h"], + deps = [ + ":graph_utils", + ":optimizer_base", + "//tensorflow/core/grappler:mutable_graph_view", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:op_types", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/clusters:cluster", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + "//tensorflow/core:lib_internal", + ] + tf_protos_all(), + alwayslink = 1, +) + +tf_cc_test( + name = "use_private_thread_pool_test", + srcs = ["use_private_thread_pool_test.cc"], + deps = [ + ":graph_test_utils", + ":graph_utils", + ":use_private_thread_pool", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/grappler:grappler_item", + ], +) + cc_library( name = "vectorization_utils", srcs = ["vectorization_utils.cc"], diff --git a/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc index b952a56bf25..6a12fb184da 100644 --- a/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc @@ -35,9 +35,10 @@ using ConfigMap = std::map; // tf.data optimizations, in the order we want to perform them. -constexpr std::array kTFDataOptimizations = { +constexpr std::array kTFDataOptimizations = { "noop_elimination", "disable_intra_op_parallelism", + "use_private_thread_pool", "shuffle_and_repeat_fusion", "map_fusion", "filter_fusion", diff --git a/tensorflow/core/grappler/optimizers/data/use_private_thread_pool.cc b/tensorflow/core/grappler/optimizers/data/use_private_thread_pool.cc new file mode 100644 index 00000000000..8eb3b505e7e --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/use_private_thread_pool.cc @@ -0,0 +1,117 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/use_private_thread_pool.h" + +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/grappler/clusters/cluster.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/mutable_graph_view.h" +#include "tensorflow/core/grappler/op_types.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { +namespace grappler { +namespace { + +constexpr char kPrivateThreadPoolDataset[] = "PrivateThreadPoolDataset"; +constexpr char kModelDataset[] = "ModelDataset"; + +} // namespace + +Status UsePrivateThreadPool::OptimizeAndCollectStats(Cluster* cluster, + const GrapplerItem& item, + GraphDef* output, + OptimizationStats* stats) { + *output = item.graph; + MutableGraphView graph(output); + + // If the GrapplerItem is derived from a FunctionDef, we don't optimize it, + if (graph_utils::IsItemDerivedFromFunctionDef(item, graph)) + return Status::OK(); + + if (item.fetch.size() != 1) { + return errors::InvalidArgument( + "Expected only one fetch node but there were ", item.fetch.size(), ": ", + absl::StrJoin(item.fetch, ", ")); + } + + for (const NodeDef& node : item.graph.node()) { + if (node.op() == kPrivateThreadPoolDataset) { + // If private thread pool is set by the user, we keep the user setting + // instead of rewriting it. + return Status::OK(); + } + } + + NodeDef* sink_node = graph.GetNode(item.fetch.at(0)); + NodeDef* last_node = graph_utils::GetInputNode(*sink_node, graph); + // If the pipeline is autotuned (ModelDataset exists as the last dataset in + // the pipeline), we insert PrivateThreadPoolDataset before ModelDataset. + // If the pipeline is not autotuned (ModelDataset doesn't exist), we insert + // PrivateThreadPoolDataset as the last dataset in the pipeline. + // + // In general, if exists, ModelDataset should be the last dataset in the + // pipeline. + if (last_node->op() == kModelDataset) { + last_node = graph_utils::GetInputNode(*last_node, graph); + } + + // Add a const node with value 0 to indicate it is not set by users. + NodeDef* num_threads_value = + graph_utils::AddScalarConstNode(int64{0}, &graph); + + NodeDef insert_node; + graph_utils::SetUniqueGraphNodeName("private_thread_pool", graph.graph(), + &insert_node); + insert_node.set_op(kPrivateThreadPoolDataset); + + // `input_dataset` input + *insert_node.mutable_input()->Add() = last_node->name(); + // `num_threads` input + *insert_node.mutable_input()->Add() = num_threads_value->name(); + + // Set `output_types` and `output_shapes` attributes by copying the relevant + // attrs from the input node. If we fail to set the attributes, we abort the + // rewrite. + for (auto attr : {"output_shapes", "output_types"}) { + if (last_node->attr().find(attr) != last_node->attr().end()) { + graph_utils::CopyAttribute(attr, *last_node, &insert_node); + } else { + return Status::OK(); + } + } + + auto* added_node = graph.AddNode(std::move(insert_node)); + TF_RETURN_IF_ERROR( + graph.UpdateFanouts(last_node->name(), added_node->name())); + + stats->num_changes++; + return Status::OK(); +} + +void UsePrivateThreadPool::Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimize_output, + double result) { + // no-op +} + +REGISTER_GRAPH_OPTIMIZER_AS(UsePrivateThreadPool, "use_private_thread_pool"); + +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/use_private_thread_pool.h b/tensorflow/core/grappler/optimizers/data/use_private_thread_pool.h new file mode 100644 index 00000000000..1cafa1f5308 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/use_private_thread_pool.h @@ -0,0 +1,50 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_USE_PRIVATE_THREAD_POOL_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_USE_PRIVATE_THREAD_POOL_H_ + +#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h" + +namespace tensorflow { +namespace grappler { + +// This optimization creates private thread pool for the input pipeline. +class UsePrivateThreadPool : public TFDataOptimizerBase { + public: + UsePrivateThreadPool() = default; + ~UsePrivateThreadPool() override = default; + + string name() const override { return "use_private_thread_pool"; }; + + bool UsesFunctionLibrary() const override { return false; } + + Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + return Status::OK(); + } + + Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item, + GraphDef* output, + OptimizationStats* stats) override; + + void Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimize_output, double result) override; +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_USE_PRIVATE_THREAD_POOL_H_ diff --git a/tensorflow/core/grappler/optimizers/data/use_private_thread_pool_test.cc b/tensorflow/core/grappler/optimizers/data/use_private_thread_pool_test.cc new file mode 100644 index 00000000000..ad1761e9849 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/use_private_thread_pool_test.cc @@ -0,0 +1,199 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/use_private_thread_pool.h" + +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +using test::function::NDef; + +// If the user manually sets private thread pool, we don't insert the op. +class ThreadPoolOpAlreadySetTest : public ::testing::TestWithParam {}; + +TEST_P(ThreadPoolOpAlreadySetTest, PrivateThreadPool) { + const int64 num_of_threads = GetParam(); + + GrapplerItem item; + MutableGraphView graph(&item.graph); + + NodeDef *start_val = graph_utils::AddScalarConstNode(0, &graph); + NodeDef *stop_val = graph_utils::AddScalarConstNode(10, &graph); + NodeDef *step_val = graph_utils::AddScalarConstNode(1, &graph); + std::vector range_inputs(3); + range_inputs[0] = start_val->name(); + range_inputs[1] = stop_val->name(); + range_inputs[2] = step_val->name(); + std::vector> range_attrs; + NodeDef *range_node = graph_utils::AddNode("range", "RangeDataset", + range_inputs, range_attrs, &graph); + NodeDef *num_of_threads_val = + graph_utils::AddScalarConstNode(num_of_threads, &graph); + std::vector private_threads_inputs(2); + private_threads_inputs[0] = range_node->name(); + private_threads_inputs[1] = num_of_threads_val->name(); + std::vector> private_threads_attrs; + NodeDef *private_threads_node = graph_utils::AddNode( + "private_thread_pool", "PrivateThreadPoolDataset", private_threads_inputs, + private_threads_attrs, &graph); + std::vector sink_inputs(1); + sink_inputs[0] = private_threads_node->name(); + std::vector> sink_attrs; + NodeDef *sink_node = + graph_utils::AddNode("Sink", "Identity", sink_inputs, sink_attrs, &graph); + item.fetch.push_back(sink_node->name()); + + EXPECT_TRUE( + graph_utils::ContainsNodeWithOp("PrivateThreadPoolDataset", item.graph)); + EXPECT_EQ(item.graph.node_size(), 7); + EXPECT_EQ(num_of_threads_val->attr().at("value").tensor().int64_val(0), + num_of_threads); + + UsePrivateThreadPool optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + EXPECT_EQ(output.node_size(), 7); + EXPECT_TRUE( + graph_utils::ContainsNodeWithOp("PrivateThreadPoolDataset", output)); + NodeDef new_private_threads_node = output.node( + graph_utils::FindGraphNodeWithOp("PrivateThreadPoolDataset", output)); + NodeDef new_num_of_threads_val = + output.node(graph_utils::FindGraphNodeWithName( + new_private_threads_node.input(1), output)); + EXPECT_EQ(new_num_of_threads_val.attr().at("value").tensor().int64_val(0), + num_of_threads); +} + +INSTANTIATE_TEST_SUITE_P(Test, ThreadPoolOpAlreadySetTest, + ::testing::Values(1, 2, 4)); + +// Test the case if the user hasn't set private thread pool. +// +// If we can not find the sink node or sink node op is "_Retval", we don't apply +// the optimization; otherwise, we insert the op to use private thread pool. +class ThreadPoolOpNotSetTest : public ::testing::TestWithParam {}; + +TEST_P(ThreadPoolOpNotSetTest, PrivateThreadPool) { + const string op = GetParam(); + GrapplerItem item; + + item.graph = test::function::GDef( + {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}), + NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}), + NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}), + NDef("range", "RangeDataset", {"start", "stop", "step"}, + {{"output_shapes", gtl::ArraySlice{}}, + {"output_types", gtl::ArraySlice{}}}), + NDef("Sink", op, {"range"}, {})}); + EXPECT_FALSE( + graph_utils::ContainsNodeWithOp("PrivateThreadPoolDataset", item.graph)); + EXPECT_EQ(item.graph.node_size(), 5); + item.fetch.push_back("Sink_fake"); + + UsePrivateThreadPool optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + EXPECT_FALSE( + graph_utils::ContainsNodeWithOp("PrivateThreadPoolDataset", output)); + EXPECT_EQ(output.node_size(), 5); + + item.fetch[0] = "Sink"; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + if (op == "_Retval") { + EXPECT_FALSE( + graph_utils::ContainsNodeWithOp("PrivateThreadPoolDataset", output)); + EXPECT_EQ(output.node_size(), 5); + return; + } + + EXPECT_EQ(output.node_size(), 7); + EXPECT_TRUE( + graph_utils::ContainsNodeWithOp("PrivateThreadPoolDataset", output)); + NodeDef sink_node = + output.node(graph_utils::FindGraphNodeWithName("Sink", output)); + EXPECT_EQ(sink_node.input_size(), 1); + NodeDef private_threads_node = output.node( + graph_utils::FindGraphNodeWithName(sink_node.input(0), output)); + EXPECT_EQ(private_threads_node.op(), "PrivateThreadPoolDataset"); + EXPECT_EQ(private_threads_node.input_size(), 2); + NodeDef range_node = output.node(graph_utils::FindGraphNodeWithName( + private_threads_node.input(0), output)); + EXPECT_EQ(range_node.name(), "range"); + NodeDef num_of_threads_val = output.node(graph_utils::FindGraphNodeWithName( + private_threads_node.input(1), output)); + EXPECT_EQ(num_of_threads_val.attr().at("value").tensor().int64_val(0), 0); +} + +INSTANTIATE_TEST_SUITE_P(Test, ThreadPoolOpNotSetTest, + ::testing::Values("Identity", "_Retval")); + +// Test the autotune case with ModelDataset in the pipeline. We will insert +// PrivateThreadPoolDataset before ModelDataset. +TEST(AutotuneWithModelTest, PrivateThreadPool) { + GrapplerItem item; + + item.graph = test::function::GDef( + {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}), + NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}), + NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}), + NDef("range", "RangeDataset", {"start", "stop", "step"}, + {{"output_shapes", gtl::ArraySlice{}}, + {"output_types", gtl::ArraySlice{}}}), + NDef("model", "ModelDataset", {"range"}, {}), + NDef("Sink", "Identity", {"model"}, {})}); + EXPECT_FALSE( + graph_utils::ContainsNodeWithOp("PrivateThreadPoolDataset", item.graph)); + EXPECT_EQ(item.graph.node_size(), 6); + item.fetch.push_back("Sink"); + + UsePrivateThreadPool optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + + EXPECT_EQ(output.node_size(), 8); + EXPECT_TRUE( + graph_utils::ContainsNodeWithOp("PrivateThreadPoolDataset", output)); + NodeDef sink_node = + output.node(graph_utils::FindGraphNodeWithName("Sink", output)); + EXPECT_EQ(sink_node.input_size(), 1); + NodeDef model_node = output.node( + graph_utils::FindGraphNodeWithName(sink_node.input(0), output)); + EXPECT_EQ(model_node.op(), "ModelDataset"); + EXPECT_EQ(model_node.input_size(), 1); + NodeDef private_threads_node = output.node( + graph_utils::FindGraphNodeWithName(model_node.input(0), output)); + EXPECT_EQ(private_threads_node.op(), "PrivateThreadPoolDataset"); + EXPECT_EQ(private_threads_node.input_size(), 2); + NodeDef range_node = output.node(graph_utils::FindGraphNodeWithName( + private_threads_node.input(0), output)); + EXPECT_EQ(range_node.name(), "range"); + NodeDef num_of_threads_val = output.node(graph_utils::FindGraphNodeWithName( + private_threads_node.input(1), output)); + EXPECT_EQ(num_of_threads_val.attr().at("value").tensor().int64_val(0), 0); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc index 76f082ae7d5..2719150d70b 100644 --- a/tensorflow/core/grappler/optimizers/remapper.cc +++ b/tensorflow/core/grappler/optimizers/remapper.cc @@ -230,15 +230,9 @@ bool IsCpuCompatibleDataType(const NodeDef* contraction, const string& type_attr = "T") { DataType dtype = GetDataTypeFromAttr(*contraction, type_attr); #if defined(INTEL_MKL) -#if defined(ENABLE_INTEL_MKL_BFLOAT16) if (IsConv2D(*contraction) || IsDepthwiseConv2dNative(*contraction) || IsMatMul(*contraction)) { return dtype == DT_FLOAT || dtype == DT_BFLOAT16; -#else - if (IsConv2D(*contraction) || IsDepthwiseConv2dNative(*contraction) || - IsMatMul(*contraction)) { - return dtype == DT_FLOAT; -#endif // ENABLE_INTEL_MKL_BFLOAT16 #else if (IsConv2D(*contraction)) { return dtype == DT_FLOAT || dtype == DT_DOUBLE; @@ -677,14 +671,9 @@ bool FindContractionWithBiasAddAndAdd(const RemapperContext& ctx, const auto* node_def = node_view.node(); if (!IsAddN(*node_def) && !IsAddWithNoBroadcast(ctx, *node_def)) return false; -#ifdef ENABLE_INTEL_MKL_BFLOAT16 // MKL AddN ops only support float and bfloat16 data types. if (!HasDataType(node_def, DT_FLOAT) && !HasDataType(node_def, DT_BFLOAT16)) return false; -#else - // MKL AddN ops only support float data type. - if (!HasDataType(node_def, DT_FLOAT)) return false; -#endif // ENABLE_INTEL_MKL_BFLOAT16 ContractionWithBiasAdd base; matched->port_id = 0; @@ -730,14 +719,9 @@ bool FindContractionWithBiasAndAddActivation( // Currently, Contraction + Bias + Add + Tanh pattern is not supported if (IsTanh(*node_def)) return false; -#ifdef ENABLE_INTEL_MKL_BFLOAT16 // MKL activation op only supports float and bfloat16 data types. if (!HasDataType(node_def, DT_FLOAT) && !HasDataType(node_def, DT_BFLOAT16)) return false; -#else - // MKL activation op only supports float data type. - if (!HasDataType(node_def, DT_FLOAT)) return false; -#endif // ENABLE_INTEL_MKL_BFLOAT16 // And input to activation must match ContractionWithBiasAddAndAdd pattern. if (node_view->NumRegularFanins() < 1) return false; @@ -843,7 +827,7 @@ bool FindFusedBatchNormEx(const RemapperContext& ctx, int node_index, const auto* fused_batch_norm_node_def = fused_batch_norm.node(); if (!IsFusedBatchNorm(*fused_batch_norm_node_def)) return false; -#ifndef ENABLE_MKLDNN_V1 +#ifndef INTEL_MKL // We fuse FusedBatchNorm on GPU or MKL CPU. if (!NodeIsOnGpu(fused_batch_norm_node_def)) return false; #else @@ -851,7 +835,7 @@ bool FindFusedBatchNormEx(const RemapperContext& ctx, int node_index, #endif DataType t_dtype = GetDataTypeFromAttr(*fused_batch_norm_node_def, "T"); -#ifndef ENABLE_MKLDNN_V1 +#ifndef INTEL_MKL if (t_dtype != DT_FLOAT && t_dtype != DT_HALF) return false; #else if (t_dtype != DT_FLOAT && t_dtype != DT_BFLOAT16) return false; @@ -919,7 +903,7 @@ bool FindFusedBatchNormEx(const RemapperContext& ctx, int node_index, if (IsAdd(*relu_fanin_0_node_def)) { // Currently no CPU implementation for "FusedBatchNorm + SideInput + // "" -#ifdef ENABLE_MKLDNN_V1 +#ifdef INTEL_MKL return false; #endif @@ -1022,7 +1006,7 @@ void CopyFusedBatchNormAttributes(const NodeDef& fused_batch_norm, if (fused_batch_norm.op() != "FusedBatchNorm") { SetAttrValue(src_attr.at("U"), &(*attr)["U"]); } else { -#ifndef ENABLE_MKLDNN_V1 +#ifndef INTEL_MKL SetAttrValue(src_attr.at("T"), &(*attr)["U"]); #else SetAttrValue(DT_FLOAT, &(*attr)["U"]); diff --git a/tensorflow/core/grappler/optimizers/remapper_test.cc b/tensorflow/core/grappler/optimizers/remapper_test.cc index 784fcaa9963..c3373481e08 100644 --- a/tensorflow/core/grappler/optimizers/remapper_test.cc +++ b/tensorflow/core/grappler/optimizers/remapper_test.cc @@ -461,10 +461,10 @@ class RemapperFuseMatMulWithBiasTest : public RemapperTest { TEST_F(RemapperFuseMatMulWithBiasTest, F32) { RunTest(); } TEST_F(RemapperFuseMatMulWithBiasTest, Bf16) { -#if !defined(INTEL_MKL) || !defined(ENABLE_INTEL_MKL_BFLOAT16) +#if !defined(INTEL_MKL) GTEST_SKIP() << "Intel MKL with bfloat16 support is not enabled, skipping " "FuseMatMulWithBias with bfloat16."; -#endif // !defined(INTEL_MKL) || !defined(ENABLE_INTEL_MKL_BFLOAT16) +#endif // !defined(INTEL_MKL) RunTest(); // NOLINT } @@ -742,10 +742,10 @@ TEST_F(RemapperFuseMatMulWithBiasAndActivationTest, F32) { } TEST_F(RemapperFuseMatMulWithBiasAndActivationTest, Bf16) { -#if !defined(INTEL_MKL) || !defined(ENABLE_INTEL_MKL_BFLOAT16) +#if !defined(INTEL_MKL) GTEST_SKIP() << "Intel MKL with bfloat16 support is not enabled, skipping " "FuseMatMulWithBiasAndActivation with bfloat16."; -#endif // !defined(INTEL_MKL) || !defined(ENABLE_INTEL_MKL_BFLOAT16) +#endif // !defined(INTEL_MKL) RunTest(); // NOLINT } diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index c908774cede..4af4ce2f969 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -548,10 +548,14 @@ tf_cc_test( ], ) -cc_library( +tf_kernel_library( name = "reshape_util", srcs = ["reshape_util.cc"], hdrs = ["reshape_util.h"], + gpu_srcs = [ + "reshape_util_gpu.cu.cc", + "reshape_util.h", + ], deps = [ "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -5971,6 +5975,17 @@ filegroup( "xent_op.h", ] + [ "//tensorflow/core/kernels/boosted_trees/quantiles:weighted_quantiles_hdrs", + "//tensorflow/core/kernels/data:batch_dataset_op.h", + "//tensorflow/core/kernels/data:iterator_ops.h", + "//tensorflow/core/kernels/data:map_dataset_op.h", + "//tensorflow/core/kernels/data:dataset_utils.h", + "//tensorflow/core/kernels/data:unbounded_thread_pool.h", + "//tensorflow/core/kernels/data:model_dataset_op.h", + "//tensorflow/core/kernels/data:optimize_dataset_op.h", + "//tensorflow/core/kernels/data:name_utils.h", + "//tensorflow/core/kernels/data:optional_ops.h", + "//tensorflow/core/kernels/data:stats_utils.h", + "//tensorflow/core/kernels/data:captured_function.h", "//tensorflow/core/kernels/image:adjust_contrast_op.h", "//tensorflow/core/kernels/image:adjust_hue_op.h", "//tensorflow/core/kernels/image:adjust_saturation_op.h", @@ -6281,6 +6296,17 @@ filegroup( "xent_op.cc", ] + [ "//tensorflow/core/kernels/boosted_trees:quantile_ops.cc", + "//tensorflow/core/kernels/data:batch_dataset_op.cc", + "//tensorflow/core/kernels/data:iterator_ops.cc", + "//tensorflow/core/kernels/data:map_dataset_op.cc", + "//tensorflow/core/kernels/data:model_dataset_op.cc", + "//tensorflow/core/kernels/data:optimize_dataset_op.cc", + "//tensorflow/core/kernels/data:dataset_utils.cc", + "//tensorflow/core/kernels/data:unbounded_thread_pool.cc", + "//tensorflow/core/kernels/data:stats_utils.cc", + "//tensorflow/core/kernels/data:name_utils.cc", + "//tensorflow/core/kernels/data:optional_ops.cc", + "//tensorflow/core/kernels/data:captured_function.cc", "//tensorflow/core/kernels/image:adjust_contrast_op.cc", "//tensorflow/core/kernels/image:adjust_hue_op.cc", "//tensorflow/core/kernels/image:adjust_saturation_op.cc", @@ -6296,9 +6322,9 @@ filegroup( "//tensorflow/core/kernels/image:resize_bilinear_op.cc", "//tensorflow/core/kernels/image:resize_nearest_neighbor_op.cc", "//tensorflow/core/kernels/image:sample_distorted_bounding_box_op.cc", - "//tensorflow/core/kernels/linalg:linalg_ops_common.cc", "//tensorflow/core/kernels/linalg:cholesky_op.cc", "//tensorflow/core/kernels/linalg:determinant_op.cc", + "//tensorflow/core/kernels/linalg:linalg_ops_common.cc", "//tensorflow/core/kernels/linalg:matrix_diag_op.cc", "//tensorflow/core/kernels/linalg:matrix_inverse_op.cc", "//tensorflow/core/kernels/linalg:matrix_set_diag_op.cc", diff --git a/tensorflow/core/kernels/count_ops.cc b/tensorflow/core/kernels/count_ops.cc index 087deef0812..d6ab68c2c70 100644 --- a/tensorflow/core/kernels/count_ops.cc +++ b/tensorflow/core/kernels/count_ops.cc @@ -192,6 +192,10 @@ class SparseCount : public OpKernel { "; values shape: ", values.shape().DebugString())); } + OP_REQUIRES(context, shape.NumElements() != 0, + errors::InvalidArgument( + "The shape argument requires at least one element.")); + bool is_1d = shape.NumElements() == 1; int num_batches = is_1d ? 1 : shape.flat()(0); int num_values = values.NumElements(); @@ -212,6 +216,14 @@ class SparseCount : public OpKernel { for (int idx = 0; idx < num_values; ++idx) { int batch = is_1d ? 0 : indices_values(idx, 0); + if (batch >= num_batches) { + OP_REQUIRES(context, batch < num_batches, + errors::InvalidArgument( + "Indices value along the first dimension must be ", + "lower than the first index of the shape.", "Got ", + batch, " as batch and ", num_batches, + " as the first dimension of the shape.")); + } const auto& value = values_values(idx); if (value >= 0 && (maxlength_ <= 0 || value < maxlength_)) { if (binary_output_) { diff --git a/tensorflow/core/kernels/cudnn_rnn_ops.cc b/tensorflow/core/kernels/cudnn_rnn_ops.cc index bd62f39e8ae..8330aa81cb3 100644 --- a/tensorflow/core/kernels/cudnn_rnn_ops.cc +++ b/tensorflow/core/kernels/cudnn_rnn_ops.cc @@ -826,14 +826,34 @@ Status DoForward(OpKernelContext* context, const RnnDescriptor& rnn_desc, } Stream* stream = context->op_device_context()->stream(); + + Tensor seq_lengths_tensor; + DeviceMemory seq_lengths_ptr; + if (sequence_lengths != nullptr) { + TF_RETURN_IF_ERROR(context->allocate_temp( + DT_INT32, {static_cast(seq_lengths.size())}, + &seq_lengths_tensor)); + seq_lengths_ptr = AsDeviceMemory(&seq_lengths_tensor); + if (!stream + ->ThenMemcpy(&seq_lengths_ptr, seq_lengths.data(), + seq_lengths.size() * sizeof(int)) + .ok()) { + return errors::InvalidArgument( + "Failed to copy memory from host to " + "device for sequence_lengths in " + "CudnnRNNV3"); + } + } + bool launch_success = stream - ->ThenRnnForward(rnn_desc, *input_desc, input_data, *h_state_desc, - input_h_data, *c_state_desc, input_c_data, - params_data, *output_desc, &output_data, - *h_state_desc, &output_h_data, *c_state_desc, - &output_c_data, is_training, reserve_space_allocator, - workspace_allocator, output_profile_result) + ->ThenRnnForward(rnn_desc, *input_desc, input_data, seq_lengths_ptr, + *h_state_desc, input_h_data, *c_state_desc, + input_c_data, params_data, *output_desc, + &output_data, *h_state_desc, &output_h_data, + *c_state_desc, &output_c_data, is_training, + reserve_space_allocator, workspace_allocator, + output_profile_result) .ok(); return launch_success ? Status::OK() @@ -905,17 +925,36 @@ Status DoBackward( // Creates a memory callback for the workspace. The memory lives to the end // of this kernel calls. Stream* stream = context->op_device_context()->stream(); + + Tensor seq_lengths_tensor; + DeviceMemory seq_lengths_ptr; + if (sequence_lengths != nullptr) { + TF_RETURN_IF_ERROR(context->allocate_temp( + DT_INT32, {static_cast(seq_lengths.size())}, + &seq_lengths_tensor)); + seq_lengths_ptr = AsDeviceMemory(&seq_lengths_tensor); + if (!stream + ->ThenMemcpy(&seq_lengths_ptr, seq_lengths.data(), + seq_lengths.size() * sizeof(int)) + .ok()) { + return errors::InvalidArgument( + "Failed to copy memory from host to " + "device for sequence_lengths in " + "CudnnRNNBackwardOpV3"); + } + } + bool launch_success = stream ->ThenRnnBackward( - rnn_desc, *input_desc, input_data, *h_state_desc, input_h_data, - *c_state_desc, input_c_data, params_data, *output_desc, - output_data, *h_state_desc, output_h_data, *c_state_desc, - output_c_data, output_backprop_data, output_h_backprop_data, - output_c_backprop_data, &input_backprop_data, - &input_h_backprop_data, &input_c_backprop_data, - ¶ms_backprop_data, &reserve_space_uint8, workspace_allocator, - output_profile_result) + rnn_desc, *input_desc, input_data, seq_lengths_ptr, *h_state_desc, + input_h_data, *c_state_desc, input_c_data, params_data, + *output_desc, output_data, *h_state_desc, output_h_data, + *c_state_desc, output_c_data, output_backprop_data, + output_h_backprop_data, output_c_backprop_data, + &input_backprop_data, &input_h_backprop_data, + &input_c_backprop_data, ¶ms_backprop_data, + &reserve_space_uint8, workspace_allocator, output_profile_result) .ok(); return launch_success ? Status::OK() diff --git a/tensorflow/core/kernels/cwise_op_rsqrt.cc b/tensorflow/core/kernels/cwise_op_rsqrt.cc index 4741b7eb438..996263a3a52 100644 --- a/tensorflow/core/kernels/cwise_op_rsqrt.cc +++ b/tensorflow/core/kernels/cwise_op_rsqrt.cc @@ -16,8 +16,14 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { + +#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \ + !defined(MLIR_GENERATED_EXPERIMENTAL_KERNELS_ENABLED) REGISTER5(UnaryOp, CPU, "Rsqrt", functor::rsqrt, float, Eigen::half, double, complex64, complex128); +#else +REGISTER2(UnaryOp, CPU, "Rsqrt", functor::rsqrt, complex64, complex128); +#endif #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index 448a5def807..91880a53d8f 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -25,6 +25,32 @@ package( licenses = ["notice"], # Apache 2.0 ) +# Export a few files for use on Android. +exports_files([ + "batch_dataset_op.cc", + "batch_dataset_op.h", + "captured_function.cc", + "captured_function.h", + "dataset_utils.cc", + "dataset_utils.h", + "iterator_ops.cc", + "iterator_ops.h", + "map_dataset_op.cc", + "map_dataset_op.h", + "model_dataset_op.cc", + "model_dataset_op.h", + "name_utils.cc", + "name_utils.h", + "optimize_dataset_op.cc", + "optimize_dataset_op.h", + "optional_ops.cc", + "optional_ops.h", + "stats_utils.cc", + "stats_utils.h", + "unbounded_thread_pool.cc", + "unbounded_thread_pool.h", +]) + tf_kernel_library( name = "batch_dataset_op", srcs = ["batch_dataset_op.cc"], diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc index b8aef6c56e6..157fab74221 100644 --- a/tensorflow/core/kernels/data/captured_function.cc +++ b/tensorflow/core/kernels/data/captured_function.cc @@ -570,13 +570,7 @@ Status CapturedFunction::AddToGraph( other_arguments_types->reserve(captured_inputs_.size()); for (const Tensor& t : captured_inputs_) { Node* node; - DatasetBase* input; - Status s = GetDatasetFromVariantTensor(t, &input); - if (s.ok()) { - TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node)); - } else { - TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); - } + TF_RETURN_IF_ERROR(b->AddDatasetOrTensor(ctx, t, &node)); other_arguments->emplace_back(node); other_arguments_types->emplace_back(t.dtype()); } diff --git a/tensorflow/core/kernels/data/dataset_ops.cc b/tensorflow/core/kernels/data/dataset_ops.cc index 597e2587e66..77daeef145f 100644 --- a/tensorflow/core/kernels/data/dataset_ops.cc +++ b/tensorflow/core/kernels/data/dataset_ops.cc @@ -82,6 +82,7 @@ void DatasetToGraphOp::Compute(OpKernelContext* ctx) { DatasetBase* dataset; OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset)); SerializationContext::Params params; + params.resource_mgr = ctx->resource_manager(); params.external_state_policy = external_state_policy_; GraphDef graph_def; diff --git a/tensorflow/core/kernels/data/experimental/data_service_ops.cc b/tensorflow/core/kernels/data/experimental/data_service_ops.cc index 4d993d9462f..cf5e3edf009 100644 --- a/tensorflow/core/kernels/data/experimental/data_service_ops.cc +++ b/tensorflow/core/kernels/data/experimental/data_service_ops.cc @@ -52,6 +52,7 @@ void RegisterDatasetOp::Compute(OpKernelContext* ctx) { errors::InvalidArgument(kProtocol, " must be non-empty.")); SerializationContext::Params params; + params.resource_mgr = ctx->resource_manager(); params.external_state_policy = external_state_policy_; SerializationContext serialization_ctx(params); GraphDef graph_def; diff --git a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc index ebe6871182d..0afadb540db 100644 --- a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc @@ -218,11 +218,14 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { if (num_parallel_calls_->value == model::kAutotune) { num_parallel_calls_->value = ctx->runner_threadpool_size(); } + cancellation_manager_ = absl::make_unique(); TF_RETURN_IF_ERROR(RegisterCancellationCallback( ctx->cancellation_manager(), [this]() { CancelThreads(/*wait=*/false); }, &deregister_fn_)); - TF_RETURN_IF_ERROR( - dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_)); + IteratorContext::Params params(ctx); + params.cancellation_manager = cancellation_manager_.get(); + TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator( + IteratorContext(params), this, prefix(), &input_impl_)); return dataset()->captured_func_->Instantiate( ctx, &instantiated_captured_func_); } @@ -477,6 +480,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { } void CancelThreads(bool wait) TF_LOCKS_EXCLUDED(mu_) { + cancellation_manager_->StartCancel(); mutex_lock l(*mu_); cancelled_ = true; cond_var_->notify_all(); @@ -658,6 +662,9 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { // Identifies the maximum number of parallel calls. const std::shared_ptr num_parallel_calls_; + // Controls cancellation of `input_impl_`. Must be ordered before + // `input_impl_` so that `input_impl_` is destroyed first. + std::unique_ptr cancellation_manager_; // Counts the number of outstanding calls for this batch. int64 num_calls_ TF_GUARDED_BY(*mu_) = 0; // Counts the total number of calls. diff --git a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc index 01dfc044b70..33ce77566ff 100644 --- a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc @@ -279,13 +279,19 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { workers_(dataset()->num_threads()), worker_thread_states_(dataset()->num_threads()) {} - ~Iterator() override { CancelThreads(); } + ~Iterator() override { + CancelThreads(); + if (deregister_fn_) deregister_fn_(); + } + // TODO(jsimsa): Register cancellation callback once the implementation is + // refactored not to hold mu_ while calling `GetNext` on the input. Status Initialize(IteratorContext* ctx) override { - // TODO(jsimsa): Register cancellation callback once the implementation is - // refactored not to hold mu_ while calling `GetNext` on the input. - TF_RETURN_IF_ERROR( - dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_)); + cancellation_manager_ = absl::make_unique(); + IteratorContext::Params params(ctx); + params.cancellation_manager = cancellation_manager_.get(); + TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator( + IteratorContext(params), this, prefix(), &input_impl_)); return dataset()->captured_func_->Instantiate( ctx, &instantiated_captured_func_); } @@ -647,6 +653,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { }; void CancelThreads() TF_LOCKS_EXCLUDED(mu_) { + cancellation_manager_->StartCancel(); mutex_lock l(mu_); cancelled_ = true; for (auto& worker : workers_) { @@ -1123,6 +1130,10 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { // `ckpt_mu_` in either shared or exclusive modes. mutex ckpt_mu_; + // Controls cancellation of `input_impl_`. Must be ordered before + // `input_impl_` so that `input_impl_` is destroyed first. + std::unique_ptr cancellation_manager_; + // The iterator producing elements which are converted to datasets by // the dataset()->captured_func_ then interleaved together. // input_impl_ is reset when we have exhausted its input. @@ -1154,6 +1165,9 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { // threads have exited before any other members are deallocated. // TODO(b/65178177): Avoid allocating additional threads. std::vector> worker_threads_ TF_GUARDED_BY(mu_); + + // Method for deregistering the cancellation callback. + std::function deregister_fn_; }; const DatasetBase* const input_; diff --git a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc index 2c7a9a77268..6d2b558329a 100644 --- a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc @@ -20,6 +20,8 @@ limitations under the License. #include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/platform/stringprintf.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/util/work_sharder.h" @@ -278,7 +280,11 @@ class MaxIntraOpParallelismDatasetOp : public UnaryDatasetOpKernel { int64 max_intra_op_parallelism) : DatasetBase(DatasetContext(ctx)), input_(input), - max_intra_op_parallelism_(max_intra_op_parallelism) { + max_intra_op_parallelism_(max_intra_op_parallelism), + traceme_metadata_( + {{"parallelism", + strings::Printf("%lld", static_cast( + max_intra_op_parallelism_))}}) { input_->Ref(); } @@ -370,12 +376,17 @@ class MaxIntraOpParallelismDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } + TraceMeMetadata GetTraceMeMetadata() const override { + return dataset()->traceme_metadata_; + } + private: std::unique_ptr input_impl_; }; const DatasetBase* const input_; const int64 max_intra_op_parallelism_; + const TraceMeMetadata traceme_metadata_; }; }; @@ -389,8 +400,8 @@ class PrivateThreadPoolDatasetOp : public UnaryDatasetOpKernel { int64 num_threads = 0; OP_REQUIRES_OK( ctx, ParseScalarArgument(ctx, "num_threads", &num_threads)); - OP_REQUIRES(ctx, num_threads >= 1, - errors::InvalidArgument("`num_threads` must be >= 1")); + OP_REQUIRES(ctx, num_threads >= 0, + errors::InvalidArgument("`num_threads` must be >= 0")); *output = new Dataset(ctx, input, num_threads); } @@ -400,9 +411,12 @@ class PrivateThreadPoolDatasetOp : public UnaryDatasetOpKernel { Dataset(OpKernelContext* ctx, const DatasetBase* input, int num_threads) : DatasetBase(DatasetContext(ctx)), input_(input), - num_threads_(num_threads) { + num_threads_(num_threads == 0 ? port::MaxParallelism() : num_threads), + traceme_metadata_( + {{"num_threads", strings::Printf("%lld", static_cast( + num_threads_))}}) { thread_pool_ = absl::make_unique( - ctx->env(), ThreadOptions{}, "data_private_threadpool", num_threads, + ctx->env(), ThreadOptions{}, "data_private_threadpool", num_threads_, /*low_latency_hint=*/false); input_->Ref(); } @@ -496,12 +510,17 @@ class PrivateThreadPoolDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } + TraceMeMetadata GetTraceMeMetadata() const override { + return dataset()->traceme_metadata_; + } + private: std::unique_ptr input_impl_; }; const DatasetBase* const input_; const int64 num_threads_; + const TraceMeMetadata traceme_metadata_; std::unique_ptr thread_pool_; }; }; diff --git a/tensorflow/core/kernels/data/parallel_batch_dataset_op.cc b/tensorflow/core/kernels/data/parallel_batch_dataset_op.cc index 426087d4dc2..66971307abf 100644 --- a/tensorflow/core/kernels/data/parallel_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_batch_dataset_op.cc @@ -196,10 +196,14 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase { if (num_parallel_calls_->value == model::kAutotune) { num_parallel_calls_->value = ctx->runner_threadpool_size(); } + cancellation_manager_ = absl::make_unique(); TF_RETURN_IF_ERROR(RegisterCancellationCallback( ctx->cancellation_manager(), [this]() { CancelThreads(/*wait=*/false); }, &deregister_fn_)); - return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_); + IteratorContext::Params params(ctx); + params.cancellation_manager = cancellation_manager_.get(); + return dataset()->input_->MakeIterator(IteratorContext(params), this, + prefix(), &input_impl_); } Status GetNextInternal(IteratorContext* ctx, @@ -367,7 +371,6 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase { if (batch_elements->empty()) { CallCompleted(ctx, result); - DCHECK(end_of_input); return; } @@ -388,6 +391,7 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase { } void CancelThreads(bool wait) TF_LOCKS_EXCLUDED(mu_) { + cancellation_manager_->StartCancel(); mutex_lock l(*mu_); cancelled_ = true; cond_var_->notify_all(); @@ -548,6 +552,9 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase { const std::shared_ptr num_parallel_calls_; const bool deterministic_; + // Controls cancellation of `input_impl_`. Must be ordered before + // `input_impl_` so that `input_impl_` is destroyed first. + std::unique_ptr cancellation_manager_; // Counts the number of outstanding calls for this batch. int64 num_calls_ TF_GUARDED_BY(*mu_) = 0; std::unique_ptr input_impl_; diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc index 20787fa959f..fffded0bf25 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc @@ -319,9 +319,10 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { ~ParallelInterleaveIterator() override { CancelThreads(/*wait=*/true); - if (deregister_fn_) deregister_fn_(); } + // TODO(jsimsa): Register cancellation callback once the implementation is + // refactored not to hold mu_ while calling `GetNext` on the input. Status Initialize(IteratorContext* ctx) override { mutex_lock l(*mu_); // Note that if `ctx->thread_pool()` is non-null, then instead of creating @@ -344,11 +345,12 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { if (num_parallel_calls_->value == model::kAutotune) { num_parallel_calls_->value = dataset()->cycle_length_; } - // TODO(jsimsa): Register cancellation callback once the implementation is - // refactored not to hold mu_ while calling `GetNext` on the input. ctx_ = std::make_unique(*ctx); - TF_RETURN_IF_ERROR( - dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_)); + cancellation_manager_ = absl::make_unique(); + IteratorContext::Params params(ctx); + params.cancellation_manager = cancellation_manager_.get(); + TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator( + IteratorContext(params), this, prefix(), &input_impl_)); return dataset()->captured_func_->Instantiate( ctx, &instantiated_captured_func_); } @@ -563,6 +565,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { // cancelled. Optionally, the method waits until all threads finish // executing. void CancelThreads(bool wait) TF_LOCKS_EXCLUDED(mu_) { + cancellation_manager_->StartCancel(); mutex_lock l(*mu_); cancelled_ = true; // Wake up all threads so that they can exit. This will also wake up any @@ -1503,6 +1506,10 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { // Determines whether outputs can be produced in deterministic order. const bool deterministic_; + // Controls cancellation of `input_impl_`. Must be ordered before + // `input_impl_` so that `input_impl_` is destroyed first. + std::unique_ptr cancellation_manager_; + // Iterator for input elements. std::unique_ptr input_impl_ TF_GUARDED_BY(mu_); @@ -1550,9 +1557,6 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { // Identifies whether background threads should be cancelled. bool cancelled_ TF_GUARDED_BY(mu_) = false; - - // Method for deregistering the cancellation callback. - std::function deregister_fn_; }; const DatasetBase* const input_; diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc index b2887f00cee..629a70d49ec 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc @@ -223,11 +223,11 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { num_parallel_calls_->value = ctx->runner_threadpool_size(); } cancellation_manager_ = absl::make_unique(); - IteratorContext::Params params(ctx); - params.cancellation_manager = cancellation_manager_.get(); TF_RETURN_IF_ERROR(RegisterCancellationCallback( ctx->cancellation_manager(), [this]() { CancelThreads(/*wait=*/false); }, &deregister_fn_)); + IteratorContext::Params params(ctx); + params.cancellation_manager = cancellation_manager_.get(); TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator( IteratorContext(params), this, prefix(), &input_impl_)); return dataset()->captured_func_->Instantiate( @@ -644,9 +644,8 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { const bool autotune_; // Counts the number of outstanding calls. int64 num_calls_ TF_GUARDED_BY(*mu_) = 0; - // Controls cancellation of `input_impl_`. - // Must be ordered before `input_impl_` so that `input_impl_` is destroyed - // first. + // Controls cancellation of `input_impl_`. Must be ordered before + // `input_impl_` so that `input_impl_` is destroyed first. std::unique_ptr cancellation_manager_; std::unique_ptr instantiated_captured_func_; // Must be ordered after `cancellation_manager_` so that `input_impl_` is diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc index 55ed5a887c1..d2ac18bb3e8 100644 --- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc +++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc @@ -157,10 +157,14 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { if (buffer_size_->value == model::kAutotune) { buffer_size_->value = buffer_size_min_; } + cancellation_manager_ = absl::make_unique(); TF_RETURN_IF_ERROR(RegisterCancellationCallback( ctx->cancellation_manager(), [this]() { CancelThreads(); }, &deregister_fn_)); - return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_); + IteratorContext::Params params(ctx); + params.cancellation_manager = cancellation_manager_.get(); + return dataset()->input_->MakeIterator(IteratorContext(params), this, + prefix(), &input_impl_); } Status GetNextInternal(IteratorContext* ctx, @@ -360,6 +364,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { } void CancelThreads() TF_LOCKS_EXCLUDED(mu_) { + cancellation_manager_->StartCancel(); mutex_lock l(*mu_); cancelled_ = true; cond_var_->notify_all(); @@ -558,6 +563,9 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { // accessing the input iterator. We keep this separate from `mu_` to allow // prefetching to run in parallel with GetNext calls. mutex input_mu_ TF_ACQUIRED_BEFORE(*mu_); + // Controls cancellation of `input_impl_`. Must be ordered before + // `input_impl_` so that `input_impl_` is destroyed first. + std::unique_ptr cancellation_manager_; std::unique_ptr input_impl_ TF_GUARDED_BY(input_mu_); const std::shared_ptr cond_var_; const int64 buffer_size_min_; diff --git a/tensorflow/core/kernels/data/serialization_utils.cc b/tensorflow/core/kernels/data/serialization_utils.cc index 628d6952c6d..20833b74f31 100644 --- a/tensorflow/core/kernels/data/serialization_utils.cc +++ b/tensorflow/core/kernels/data/serialization_utils.cc @@ -57,6 +57,7 @@ Status AsGraphDefMinimal(OpKernelContext* ctx, const DatasetBase* input, std::vector>* input_list, GraphDef* result, string* dataset_node) { SerializationContext::Params params; + params.resource_mgr = ctx->resource_manager(); params.input_list = input_list; params.external_state_policy = SerializationContext::ExternalStatePolicy::kIgnore; diff --git a/tensorflow/core/kernels/lookup_table_op.cc b/tensorflow/core/kernels/lookup_table_op.cc index 525f7940351..b246d4a8e76 100644 --- a/tensorflow/core/kernels/lookup_table_op.cc +++ b/tensorflow/core/kernels/lookup_table_op.cc @@ -26,10 +26,16 @@ limitations under the License. #include "tensorflow/core/kernels/initializable_lookup_table.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/random.h" namespace tensorflow { namespace lookup { +std::string UniqueNodeName(const std::string& base) { + static std::atomic counter(0); + return strings::StrCat(base, "/", counter.fetch_add(1), "/", random::New64()); +} + // Lookup table that wraps an unordered_map, where the key and value data type // is specified. Each individual value must be a scalar. If vector values are // required, use MutableHashTableOfTensors. diff --git a/tensorflow/core/kernels/lookup_table_op.h b/tensorflow/core/kernels/lookup_table_op.h index ad1afdf8242..1a4b4b3c6f5 100644 --- a/tensorflow/core/kernels/lookup_table_op.h +++ b/tensorflow/core/kernels/lookup_table_op.h @@ -162,6 +162,9 @@ inline const ResourceHandle& SubtleMustCopyIfIntegral( return value; } +// Returns a unique node name starting with "base". +std::string UniqueNodeName(const std::string& base); + // Lookup table that wraps an flat_hash_map, where the key and value data type // is specified. // @@ -181,6 +184,54 @@ class HashTable : public InitializableLookupTable { public: HashTable(OpKernelContext* ctx, OpKernel* kernel) {} + Status AsGraphDef(GraphDefBuilder& builder, Node** out) const override { + // We set use_node_name_sharing with a unique node name so that the resource + // can outlive the HashTableV2 kernel. This means that the lifetime of the + // HashTable resource will be tied to the lifetime of the resource manager + // it is created in. + // TODO(b/181695913): Provide a mechanism for deleting this resource + // earlier when appropriate. + Node* hash_table_node = ops::SourceOp( + "HashTableV2", builder.opts() + .WithName(UniqueNodeName("HashTableFromGraphDef")) + .WithAttr("key_dtype", key_dtype()) + .WithAttr("value_dtype", value_dtype()) + .WithAttr("use_node_name_sharing", true)); + if (table_.empty()) { + *out = hash_table_node; + return Status::OK(); + } + int size = table_.size(); + Tensor keys(key_dtype(), TensorShape({size})); + Tensor values(value_dtype(), TensorShape({size})); + + auto keys_data = keys.flat(); + auto values_data = values.flat(); + int64 i = 0; + for (auto it = table_.begin(); it != table_.end(); ++it, ++i) { + keys_data(i) = it->first; + values_data(i) = it->second; + } + Node* keys_node = ops::SourceOp( + "Const", + builder.opts().WithAttr("dtype", key_dtype()).WithAttr("value", keys)); + Node* values_node = + ops::SourceOp("Const", builder.opts() + .WithAttr("dtype", value_dtype()) + .WithAttr("value", values)); + auto opts = builder.opts() + .WithAttr("Tin", key_dtype()) + .WithAttr("Tout", value_dtype()); + if (opts.HaveError()) return errors::Internal("Invalid builder opts"); + NodeBuilder node_builder(opts.GetNameForOp("LookupTableImportV2"), + "LookupTableImportV2", opts.op_registry()); + node_builder.Input(hash_table_node).Input(keys_node).Input(values_node); + Node* initialize_table = opts.FinalizeBuilder(&node_builder); + *out = ops::UnaryOp("Identity", hash_table_node, + builder.opts().WithControlInput(initialize_table)); + return Status::OK(); + } + size_t size() const override { if (!is_initialized()) return 0; diff --git a/tensorflow/core/kernels/mkl/mkl_aggregate_ops.cc b/tensorflow/core/kernels/mkl/mkl_aggregate_ops.cc index a6e1589f9b5..6bc87ddebab 100644 --- a/tensorflow/core/kernels/mkl/mkl_aggregate_ops.cc +++ b/tensorflow/core/kernels/mkl/mkl_aggregate_ops.cc @@ -175,7 +175,8 @@ class MklAddNOp : public OpKernel { } std::shared_ptr fwd_cpu_stream; - fwd_cpu_stream.reset(CreateStream(ctx, cpu_engine)); + MklDnnThreadPool eigen_tp(ctx); + fwd_cpu_stream.reset(CreateStream(&eigen_tp, cpu_engine)); // Create memory descriptor for MKL-DNN. // If all input in Tensorflow format, create block memory descriptor, diff --git a/tensorflow/core/kernels/mkl/mkl_avgpooling_op.cc b/tensorflow/core/kernels/mkl/mkl_avgpooling_op.cc index 1ce77f14b23..e255c23e053 100644 --- a/tensorflow/core/kernels/mkl/mkl_avgpooling_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_avgpooling_op.cc @@ -129,7 +129,8 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase { T* dst_data = output_tensor->flat().data(); std::shared_ptr fwd_cpu_stream; - fwd_cpu_stream.reset(CreateStream(context, pooling_fwd->GetEngine())); + MklDnnThreadPool eigen_tp(context); + fwd_cpu_stream.reset(CreateStream(&eigen_tp, pooling_fwd->GetEngine())); // Execute pooling op. pooling_fwd->Execute(src_data, dst_data, nullptr, fwd_cpu_stream); @@ -250,7 +251,8 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase { MklPoolingBwdPrimitiveFactory::Get(bwdParams); std::shared_ptr bwd_cpu_stream; - bwd_cpu_stream.reset(CreateStream(context, pooling_bwd->GetEngine())); + MklDnnThreadPool eigen_tp(context); + bwd_cpu_stream.reset(CreateStream(&eigen_tp, pooling_bwd->GetEngine())); Tensor* output_tensor = nullptr; this->AllocateOutputTensor(context, *(pooling_bwd->GetPoolingBwdPd()), orig_input_dims_mkl_order, diff --git a/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc b/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc index ec3f526bd9d..939282dc587 100644 --- a/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc @@ -144,7 +144,8 @@ class BatchMatMulMkl : public OpKernel { *params, false /* value for do_not_cache */); // Execute matmul primitive. std::shared_ptr cpu_stream; - cpu_stream.reset(CreateStream(ctx, matmul_prim->GetEngine())); + MklDnnThreadPool eigen_tp(ctx); + cpu_stream.reset(CreateStream(&eigen_tp, matmul_prim->GetEngine())); matmul_prim->Execute(lhs.flat().data(), rhs.flat().data(), out->flat().data(), cpu_stream); } diff --git a/tensorflow/core/kernels/mkl/mkl_concat_op.cc b/tensorflow/core/kernels/mkl/mkl_concat_op.cc index 3a2861e9407..82208c2f64c 100644 --- a/tensorflow/core/kernels/mkl/mkl_concat_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_concat_op.cc @@ -732,7 +732,8 @@ class MklConcatOp : public OpKernel { DCHECK(dst_tensor != nullptr) << "Output tensor pointer is NULL"; std::shared_ptr fwd_cpu_stream; - fwd_cpu_stream.reset(CreateStream(context, cpu_engine)); + MklDnnThreadPool eigen_tp(context); + fwd_cpu_stream.reset(CreateStream(&eigen_tp, cpu_engine)); if (dnn_shape_dst.IsMklTensor()) dst_md = dnn_shape_dst.GetMklLayout(); @@ -769,7 +770,9 @@ class MklConcatOp : public OpKernel { dst_md = dnn_shape_dst.IsMklTensor() ? dnn_shape_dst.GetMklLayout() : dst_md; std::shared_ptr fwd_cpu_stream; - fwd_cpu_stream.reset(CreateStream(context, concat_fwd->GetEngine())); + MklDnnThreadPool eigen_tp(context); + fwd_cpu_stream.reset( + CreateStream(&eigen_tp, concat_fwd->GetEngine())); dst.SetUsrMem(dst_md, dst_tensor); dst.SetUsrMemDataHandle(dst_tensor, fwd_cpu_stream); // Execute concat diff --git a/tensorflow/core/kernels/mkl/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl/mkl_conv_grad_filter_ops.cc index ec17b7891f8..bba1d167c82 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/mkl/mkl_conv_grad_filter_ops.cc @@ -585,7 +585,9 @@ class MklConvCustomBackpropFilterOp // Execute convolution backward filter. std::shared_ptr bwd_cpu_stream; - bwd_cpu_stream.reset(CreateStream(context, conv_bwd_filter->GetEngine())); + MklDnnThreadPool eigen_tp(context); + bwd_cpu_stream.reset( + CreateStream(&eigen_tp, conv_bwd_filter->GetEngine())); if (bias_enabled) { T* diff_bias_data = static_cast(const_cast(diff_bias_tensor->flat().data())); diff --git a/tensorflow/core/kernels/mkl/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl/mkl_conv_grad_input_ops.cc index 55f1eed7e76..97b7ac1d960 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/mkl/mkl_conv_grad_input_ops.cc @@ -482,7 +482,9 @@ class MklConvCustomBackpropInputOp } std::shared_ptr bwd_cpu_stream; - bwd_cpu_stream.reset(CreateStream(context, conv_bwd_input->GetEngine())); + MklDnnThreadPool eigen_tp(context); + bwd_cpu_stream.reset( + CreateStream(&eigen_tp, conv_bwd_input->GetEngine())); // Execute conv bwd input primitive. conv_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data, bwd_cpu_stream); diff --git a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc index f76659955a2..f6d31a84e61 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc @@ -724,7 +724,8 @@ class MklConvOp : public OpKernel { // Execute convolution std::shared_ptr fwd_cpu_stream; - fwd_cpu_stream.reset(CreateStream(context, conv_fwd->GetEngine())); + MklDnnThreadPool eigen_tp(context); + fwd_cpu_stream.reset(CreateStream(&eigen_tp, conv_fwd->GetEngine())); if (fuse_biasadd_) { const Tensor& bias_tensor = MklGetInput(context, kInputIndex_Bias); Tbias* bias_data = diff --git a/tensorflow/core/kernels/mkl/mkl_conv_ops_test.cc b/tensorflow/core/kernels/mkl/mkl_conv_ops_test.cc index ac0809d6fa5..e1b71ebd34b 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_ops_test.cc +++ b/tensorflow/core/kernels/mkl/mkl_conv_ops_test.cc @@ -26,10 +26,7 @@ limitations under the License. #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/public/session.h" - -#if defined(INTEL_MKL_DNN_ONLY) #include "tensorflow/core/util/mkl_util.h" -#endif // TODO(ezhulenev): Add numerical tests that will compare results of default // (aka Eigen) convolutions with MKL convolutions. @@ -123,7 +120,6 @@ static Graph* DefaultConv2D(const Conv2DDimensions& dims) { return graph; } -#if defined(INTEL_MKL_DNN_ONLY) static Graph* MklConv2D(const Conv2DDimensions& dims) { auto* graph = new Graph(OpRegistry::Global()); @@ -150,7 +146,6 @@ static Graph* MklConv2D(const Conv2DDimensions& dims) { return graph; } -#endif static Graph* DefaultConv2DBwdInput(const Conv2DDimensions& dims) { auto* graph = new Graph(OpRegistry::Global()); @@ -179,7 +174,6 @@ static Graph* DefaultConv2DBwdInput(const Conv2DDimensions& dims) { return graph; } -#if defined(INTEL_MKL_DNN_ONLY) static Graph* MklConv2DBwdInput(const Conv2DDimensions& dims) { auto* graph = new Graph(OpRegistry::Global()); @@ -213,7 +207,6 @@ static Graph* MklConv2DBwdInput(const Conv2DDimensions& dims) { return graph; } -#endif static Graph* DefaultConv2DBwdFilter(const Conv2DDimensions& dims) { auto* graph = new Graph(OpRegistry::Global()); @@ -243,7 +236,6 @@ static Graph* DefaultConv2DBwdFilter(const Conv2DDimensions& dims) { return graph; } -#if defined(INTEL_MKL_DNN_ONLY) static Graph* MklConv2DBwdFilter(const Conv2DDimensions& dims) { Graph* graph = new Graph(OpRegistry::Global()); @@ -278,7 +270,6 @@ static Graph* MklConv2DBwdFilter(const Conv2DDimensions& dims) { return graph; } -#endif // Macro arguments names: --------------------------------------------------- // // N: batch size @@ -297,74 +288,65 @@ static Graph* MklConv2DBwdFilter(const Conv2DDimensions& dims) { // Flops computation in these benchmarks are the same as in // eigen_benchmark_cpu_test.cc. -#define BM_Conv2DT(kind, N, H, W, C, FC, FH, FW, type, LABEL) \ - static void BM_NAME(Conv2D_##kind, type, N, H, W, C, FC, FH, \ - FW)(int iters) { \ - testing::SetLabel(LABEL); \ - \ - int64 num_computed_elements = (N) * (H) * (W) * (FC); \ - int64 flops_per_iter = num_computed_elements * ((C) * (FH) * (FW)); \ - testing::ItemsProcessed(static_cast(iters) * flops_per_iter); \ - \ - Conv2DDimensions dims(N, H, W, C, FC, FW, FH); \ - test::Benchmark(#type, BM_CONCAT(kind, Conv2D)(dims)).Run(iters); \ - } \ +#define BM_Conv2DT(kind, N, H, W, C, FC, FH, FW, type, LABEL) \ + static void BM_NAME(Conv2D_##kind, type, N, H, W, C, FC, FH, \ + FW)(::testing::benchmark::State & state) { \ + testing::SetLabel(LABEL); \ + \ + int64 num_computed_elements = (N) * (H) * (W) * (FC); \ + int64 flops_per_iter = num_computed_elements * ((C) * (FH) * (FW)); \ + \ + Conv2DDimensions dims(N, H, W, C, FC, FW, FH); \ + test::Benchmark(#type, BM_CONCAT(kind, Conv2D)(dims), \ + /*old_benchmark_api*/ false) \ + .Run(state); \ + testing::ItemsProcessed(state.iterations() * flops_per_iter); \ + } \ BENCHMARK(BM_NAME(Conv2D_##kind, type, N, H, W, C, FC, FH, FW)) -#if defined(INTEL_MKL_DNN_ONLY) #define BM_Conv2D(N, H, W, C, FC, FH, FW, type, LABEL) \ BM_Conv2DT(Default, N, H, W, C, FC, FH, FW, type, LABEL); \ BM_Conv2DT(Mkl, N, H, W, C, FC, FH, FW, type, LABEL); -#else -#define BM_Conv2D(N, H, W, C, FC, FH, FW, type, LABEL) \ - BM_Conv2DT(Default, N, H, W, C, FC, FH, FW, type, LABEL); -#endif -#define BM_Conv2DBwdInputT(kind, N, H, W, C, FC, FH, FW, type, LABEL) \ - static void BM_NAME(Conv2DBwdInput_##kind, type, N, H, W, C, FC, FH, \ - FW)(int iters) { \ - testing::SetLabel(LABEL); \ - \ - int64 num_computed_elements = (N) * (H) * (W) * (C); \ - int64 flops_per_iter = num_computed_elements * ((C) * (FH) * (FW)); \ - testing::ItemsProcessed(static_cast(iters) * flops_per_iter); \ - \ - Conv2DDimensions dims(N, H, W, C, FC, FW, FH); \ - test::Benchmark(#type, BM_CONCAT(kind, Conv2DBwdInput)(dims)).Run(iters); \ - } \ +#define BM_Conv2DBwdInputT(kind, N, H, W, C, FC, FH, FW, type, LABEL) \ + static void BM_NAME(Conv2DBwdInput_##kind, type, N, H, W, C, FC, FH, \ + FW)(::testing::benchmark::State & state) { \ + testing::SetLabel(LABEL); \ + \ + int64 num_computed_elements = (N) * (H) * (W) * (C); \ + int64 flops_per_iter = num_computed_elements * ((C) * (FH) * (FW)); \ + \ + Conv2DDimensions dims(N, H, W, C, FC, FW, FH); \ + test::Benchmark(#type, BM_CONCAT(kind, Conv2DBwdInput)(dims), \ + /*old_benchmark_api*/ false) \ + .Run(state); \ + testing::ItemsProcessed(state.iterations() * flops_per_iter); \ + } \ BENCHMARK(BM_NAME(Conv2DBwdInput_##kind, type, N, H, W, C, FC, FH, FW)) -#if defined(INTEL_MKL_DNN_ONLY) #define BM_Conv2DBwdInput(N, H, W, C, FC, FH, FW, type, LABEL) \ BM_Conv2DBwdInputT(Default, N, H, W, C, FC, FH, FW, type, LABEL); \ BM_Conv2DBwdInputT(Mkl, N, H, W, C, FC, FH, FW, type, LABEL); -#else -#define BM_Conv2DBwdInput(N, H, W, C, FC, FH, FW, type, LABEL) \ - BM_Conv2DBwdInputT(Default, N, H, W, C, FC, FH, FW, type, LABEL); -#endif -#define BM_Conv2DBwdFilterT(kind, N, H, W, C, FC, FH, FW, type, LABEL) \ - static void BM_NAME(Conv2DBwdFilter_##kind, type, N, H, W, C, FC, FH, \ - FW)(int iters) { \ - testing::SetLabel(LABEL); \ - \ - int64 num_computed_elements = (FH) * (FW) * (C) * (FC); \ - int64 flops_per_iter = num_computed_elements * ((N) * (H) * (W)); \ - testing::ItemsProcessed(static_cast(iters) * flops_per_iter); \ - \ - Conv2DDimensions dims(N, H, W, C, FC, FW, FH); \ - test::Benchmark(#type, BM_CONCAT(kind, Conv2DBwdFilter)(dims)).Run(iters); \ - } \ +#define BM_Conv2DBwdFilterT(kind, N, H, W, C, FC, FH, FW, type, LABEL) \ + static void BM_NAME(Conv2DBwdFilter_##kind, type, N, H, W, C, FC, FH, \ + FW)(::testing::benchmark::State & state) { \ + testing::SetLabel(LABEL); \ + \ + int64 num_computed_elements = (FH) * (FW) * (C) * (FC); \ + int64 flops_per_iter = num_computed_elements * ((N) * (H) * (W)); \ + \ + Conv2DDimensions dims(N, H, W, C, FC, FW, FH); \ + test::Benchmark(#type, BM_CONCAT(kind, Conv2DBwdFilter)(dims), \ + /*old_benchmark_api*/ false) \ + .Run(state); \ + testing::ItemsProcessed(state.iterations() * flops_per_iter); \ + } \ BENCHMARK(BM_NAME(Conv2DBwdFilter_##kind, type, N, H, W, C, FC, FH, FW)) -#if defined(INTEL_MKL_DNN_ONLY) #define BM_Conv2DBwdFilter(N, H, W, C, FC, FH, FW, type, LABEL) \ BM_Conv2DBwdFilterT(Default, N, H, W, C, FC, FH, FW, type, LABEL); \ BM_Conv2DBwdFilterT(Mkl, N, H, W, C, FC, FH, FW, type, LABEL); -#else -#define BM_Conv2DBwdFilter(N, H, W, C, FC, FH, FW, type, LABEL) \ - BM_Conv2DBwdFilterT(Default, N, H, W, C, FC, FH, FW, type, LABEL); -#endif // ImageNet Convolutions ---------------------------------------------------- // diff --git a/tensorflow/core/kernels/mkl/mkl_dequantize_op.cc b/tensorflow/core/kernels/mkl/mkl_dequantize_op.cc index 5557d4ecdeb..fcc38986ecc 100644 --- a/tensorflow/core/kernels/mkl/mkl_dequantize_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_dequantize_op.cc @@ -75,7 +75,8 @@ class MklDequantizeOp : public OpKernel { MklDnnData dst(&cpu_engine); std::shared_ptr reorder_stream; - reorder_stream.reset(CreateStream(ctx, cpu_engine)); + MklDnnThreadPool eigen_tp(ctx); + reorder_stream.reset(CreateStream(&eigen_tp, cpu_engine)); // If input is in MKL layout, then simply grab input layout; otherwise, // construct input TF layout. For TF layout, although input shape diff --git a/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op.cc b/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op.cc index b769e8a93d8..e498651697d 100644 --- a/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op.cc @@ -866,7 +866,8 @@ class MklFusedBatchNormOp : public OpKernel { // Execute std::shared_ptr fwd_cpu_stream; - fwd_cpu_stream.reset(CreateStream(context, bn_fwd->GetEngine())); + MklDnnThreadPool eigen_tp(context); + fwd_cpu_stream.reset(CreateStream(&eigen_tp, bn_fwd->GetEngine())); bn_fwd->Execute(src_data, weights_op_data, dst_data, mean_op_data, variance_op_data, fwd_cpu_stream, ws_data); float adjust_factor = 1.0; @@ -1272,7 +1273,8 @@ class MklFusedBatchNormGradOp : public OpKernel { // Execute std::shared_ptr bwd_cpu_stream; - bwd_cpu_stream.reset(CreateStream(context, bn_bwd->GetEngine())); + MklDnnThreadPool eigen_tp(context); + bwd_cpu_stream.reset(CreateStream(&eigen_tp, bn_bwd->GetEngine())); bn_bwd->Execute(src_data, mean_data, variance_data, diff_dst_data, weights_data, diff_src_data, diff_weights_data, res_space_data, bwd_cpu_stream); diff --git a/tensorflow/core/kernels/mkl/mkl_fused_ops_test.cc b/tensorflow/core/kernels/mkl/mkl_fused_ops_test.cc index cf800eec6cc..af2b61d4c86 100644 --- a/tensorflow/core/kernels/mkl/mkl_fused_ops_test.cc +++ b/tensorflow/core/kernels/mkl/mkl_fused_ops_test.cc @@ -888,11 +888,7 @@ TYPED_TEST_P(FusedPadConvOpTest, PaddingConvTestNchw) { this->Run("NCHW"); } REGISTER_TYPED_TEST_SUITE_P(FusedPadConvOpTest, PaddingConvTest, PaddingConvTestNchw); -#ifdef ENABLE_INTEL_MKL_BFLOAT16 using FusedPadConvDataTypes = ::testing::Types; -#else -using FusedPadConvDataTypes = ::testing::Types; -#endif INSTANTIATE_TYPED_TEST_SUITE_P(Test, FusedPadConvOpTest, FusedPadConvDataTypes); class FilterCacheTest : public OpsTestBase { diff --git a/tensorflow/core/kernels/mkl/mkl_lrn_op.cc b/tensorflow/core/kernels/mkl/mkl_lrn_op.cc index c315385ddae..456977c7e6f 100644 --- a/tensorflow/core/kernels/mkl/mkl_lrn_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_lrn_op.cc @@ -167,7 +167,8 @@ class MklLRNOp : public OpKernel { src_dnn_data.CheckReorderToOpMem(lrn_prim_desc.src_desc(), cpu_engine_); std::vector net; - fwd_stream_.reset(CreateStream(context, cpu_engine_)); + MklDnnThreadPool eigen_tp(context); + fwd_stream_.reset(CreateStream(&eigen_tp, cpu_engine_)); net.push_back(lrn_forward(lrn_prim_desc)); std::vector> net_args; net_args.push_back({{MKLDNN_ARG_SRC, src_dnn_data.GetOpMem()}, diff --git a/tensorflow/core/kernels/mkl/mkl_matmul_op.cc b/tensorflow/core/kernels/mkl/mkl_matmul_op.cc index e0aa9944583..2e11f9242b4 100644 --- a/tensorflow/core/kernels/mkl/mkl_matmul_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_matmul_op.cc @@ -156,18 +156,15 @@ class MklMatMulOp : public OpKernel { char char_transb = transb ? 'T' : 'N'; VLOG(2) << "MKL DNN SGEMM called"; #ifdef ENABLE_MKLDNN_THREADPOOL - auto eigen_tp = - MklDnnThreadPoolWrapper::GetInstance().CreateThreadPoolPtr(ctx); - + MklDnnThreadPool eigen_tp(ctx); dnnl_sgemm_tp(char_transa, char_transb, m, n, k, alpha, a, lda, b, ldb, - beta, c, ldc, eigen_tp); + beta, c, ldc, &eigen_tp); #else dnnl_sgemm(char_transa, char_transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); #endif // ENABLE_MKLDNN_THREADPOOL } -#ifdef ENABLE_INTEL_MKL_BFLOAT16 void MklBlasGemm(OpKernelContext* ctx, bool transa, bool transb, const int m, const int n, const int k, const bfloat16* a, const int lda, const bfloat16* b, const int ldb, bfloat16* c, @@ -181,7 +178,6 @@ class MklMatMulOp : public OpKernel { dnnl_gemm(ftrans[index_transa], ftrans[index_transb], m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, ctx); } -#endif // ENABLE_INTEL_MKL_BFLOAT16 }; #define REGISTER_CPU(T) \ @@ -196,9 +192,7 @@ class MklMatMulOp : public OpKernel { // TODO(inteltf) Consider template specialization when adding/removing // additional types TF_CALL_float(REGISTER_CPU); -#ifdef ENABLE_INTEL_MKL_BFLOAT16 TF_CALL_bfloat16(REGISTER_CPU); -#endif // ENABLE_INTEL_MKL_BFLOAT16 #endif // ENABLE_MKL } // namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc b/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc index 0acd94cdc6e..de334f3c8d2 100644 --- a/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc +++ b/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc @@ -249,7 +249,8 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase { } } std::shared_ptr cpu_stream; - cpu_stream.reset(CreateStream(ctx, matmul_prim->GetEngine())); + MklDnnThreadPool eigen_tp(ctx); + cpu_stream.reset(CreateStream(&eigen_tp, matmul_prim->GetEngine())); // Execute fused matmul op. matmul_prim->Execute(src_data, weight_data, bias_data, dst_data, cpu_stream); diff --git a/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h b/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h index 19175391091..dc915eeb606 100644 --- a/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h +++ b/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h @@ -712,7 +712,8 @@ void dnnl_gemm(char transa, char transb, int64_t m, int64_t n, int64_t k, // Execute matmul primitive. std::shared_ptr cpu_stream; - cpu_stream.reset(CreateStream(ctx, matmul_prim->GetEngine())); + MklDnnThreadPool eigen_tp(ctx); + cpu_stream.reset(CreateStream(&eigen_tp, matmul_prim->GetEngine())); matmul_prim->Execute(a, b, c, cpu_stream); } diff --git a/tensorflow/core/kernels/mkl/mkl_maxpooling_op.cc b/tensorflow/core/kernels/mkl/mkl_maxpooling_op.cc index 2acca5805a2..b16a2a50976 100644 --- a/tensorflow/core/kernels/mkl/mkl_maxpooling_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_maxpooling_op.cc @@ -153,7 +153,8 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase { T* dst_data = output_tensor->flat().data(); std::shared_ptr fwd_cpu_stream; - fwd_cpu_stream.reset(CreateStream(context, pooling_fwd->GetEngine())); + MklDnnThreadPool eigen_tp(context); + fwd_cpu_stream.reset(CreateStream(&eigen_tp, pooling_fwd->GetEngine())); if (int8_forward_inference) { // Execute pooling op @@ -304,7 +305,8 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase { MklPoolingBwdPrimitiveFactory::Get(bwdParams); std::shared_ptr bwd_cpu_stream; - bwd_cpu_stream.reset(CreateStream(context, pooling_bwd->GetEngine())); + MklDnnThreadPool eigen_tp(context); + bwd_cpu_stream.reset(CreateStream(&eigen_tp, pooling_bwd->GetEngine())); // Allocate output tensor and memory primitive. Tensor* output_tensor = nullptr; this->AllocateOutputTensor(context, *(pooling_bwd->GetPoolingBwdPd()), diff --git a/tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc b/tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc index b30264f27dc..5bfc2661e86 100644 --- a/tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc @@ -289,7 +289,8 @@ class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase { } std::shared_ptr cpu_stream; - cpu_stream.reset(CreateStream(context, matmul_fwd->GetEngine())); + MklDnnThreadPool eigen_tp(context); + cpu_stream.reset(CreateStream(&eigen_tp, matmul_fwd->GetEngine())); // Execute inner-product Tbias* bias_data = this->GetBiasHandle( context, matmul_fwd_pd, bias_tensor, weight_tensor, cpu_stream); diff --git a/tensorflow/core/kernels/mkl/mkl_quantize_op.cc b/tensorflow/core/kernels/mkl/mkl_quantize_op.cc index 57ae4132e34..a96d1a59be0 100644 --- a/tensorflow/core/kernels/mkl/mkl_quantize_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_quantize_op.cc @@ -474,7 +474,8 @@ class MklQuantizeV2Op : public OpKernel { MklReorderWithScalePrimitiveFactory::Get(src.GetUsrMem(), dst.GetUsrMem(), fwdParams); std::shared_ptr cpu_stream; - cpu_stream.reset(CreateStream(ctx, reorder_prim->GetEngine())); + MklDnnThreadPool eigen_tp(ctx); + cpu_stream.reset(CreateStream(&eigen_tp, reorder_prim->GetEngine())); reorder_prim->Execute(src.GetUsrMemDataHandle(), dst.GetUsrMemDataHandle(), cpu_stream); diff --git a/tensorflow/core/kernels/mkl/mkl_relu_op.cc b/tensorflow/core/kernels/mkl/mkl_relu_op.cc index 299b64b8560..ffbc6697742 100644 --- a/tensorflow/core/kernels/mkl/mkl_relu_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_relu_op.cc @@ -479,7 +479,8 @@ class MklReluOpBase : public OpKernel { MklEltwiseFwdPrimitiveFactory::Get(fwdParams); auto eltwise_fwd_pd = eltwise_fwd->GetEltwiseFwdPd(); std::shared_ptr fwd_cpu_stream; - fwd_cpu_stream.reset(CreateStream(context, eltwise_fwd->GetEngine())); + MklDnnThreadPool eigen_tp(context); + fwd_cpu_stream.reset(CreateStream(&eigen_tp, eltwise_fwd->GetEngine())); // Check if src needs to be reordered bool is_src_reordered = false; const T* src_data = src_tensor.flat().data(); @@ -685,7 +686,8 @@ class MklReluGradOpBase : public OpKernel { auto eltwise_bwd_pd = eltwise_bwd->GetEltwiseBwdPd(); std::shared_ptr bwd_cpu_stream; - bwd_cpu_stream.reset(CreateStream(context, eltwise_bwd->GetEngine())); + MklDnnThreadPool eigen_tp(context); + bwd_cpu_stream.reset(CreateStream(&eigen_tp, eltwise_bwd->GetEngine())); // check whether need reorder for src / diff_dst const T* src_data = src_tensor.flat().data(); if (src_md != eltwise_bwd_pd->src_desc()) { diff --git a/tensorflow/core/kernels/mkl/mkl_relu_op_test.cc b/tensorflow/core/kernels/mkl/mkl_relu_op_test.cc index ac23a56d620..30d27c4b9f6 100644 --- a/tensorflow/core/kernels/mkl/mkl_relu_op_test.cc +++ b/tensorflow/core/kernels/mkl/mkl_relu_op_test.cc @@ -100,14 +100,17 @@ static Graph* Activation(const string& op_name, const string& kind, return graph; } -#define BM_Activation(op, kind, A, B, C, D, type) \ - static void BM_##op##_##kind##_##type##_##A##_##B##_##C##_##D(int iters) { \ - int64 num_computed_elements = (A) * (B) * (C) * (D); \ - int64 flops_per_iter = num_computed_elements; \ - testing::ItemsProcessed(static_cast(iters) * flops_per_iter); \ - \ - test::Benchmark(#type, Activation(#op, #kind, {A, B, C, D})).Run(iters); \ - } \ +#define BM_Activation(op, kind, A, B, C, D, type) \ + static void BM_##op##_##kind##_##type##_##A##_##B##_##C##_##D( \ + ::testing::benchmark::State& state) { \ + int64 num_computed_elements = (A) * (B) * (C) * (D); \ + int64 flops_per_iter = num_computed_elements; \ + \ + test::Benchmark(#type, Activation(#op, #kind, {A, B, C, D}), \ + /*old_benchmark_api*/ false) \ + .Run(state); \ + state.SetItemsProcessed(state.iterations() * flops_per_iter); \ + } \ BENCHMARK(BM_##op##_##kind##_##type##_##A##_##B##_##C##_##D) #define BM(op, A, B, C, D, type) \ diff --git a/tensorflow/core/kernels/mkl/mkl_requantize_per_channel_op.cc b/tensorflow/core/kernels/mkl/mkl_requantize_per_channel_op.cc index 40570a467ea..c0f9845cd4b 100644 --- a/tensorflow/core/kernels/mkl/mkl_requantize_per_channel_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_requantize_per_channel_op.cc @@ -129,7 +129,8 @@ class MklRequantizePerChannelOp : public OpKernel { ReorderPd(cpu_engine_, input_mem_prim->get_desc(), cpu_engine_, output_mem_prim->get_desc(), reorder_attr); std::shared_ptr reorder_stream; - reorder_stream.reset(CreateStream(ctx, cpu_engine_)); + MklDnnThreadPool eigen_tp(ctx); + reorder_stream.reset(CreateStream(&eigen_tp, cpu_engine_)); std::unordered_map reorder_args = { {MKLDNN_ARG_FROM, *input_mem_prim}, {MKLDNN_ARG_TO, *output_mem_prim}}; diff --git a/tensorflow/core/kernels/mkl/mkl_slice_op.cc b/tensorflow/core/kernels/mkl/mkl_slice_op.cc index 6e277f27b4b..a956cf66d40 100644 --- a/tensorflow/core/kernels/mkl/mkl_slice_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_slice_op.cc @@ -453,7 +453,8 @@ class MklSliceOp : public OpKernel { MklSlicePrimitiveFactory::Get(sliceParams); // Execute slice reorder. std::shared_ptr slice_stream; - slice_stream.reset(CreateStream(context, reorder_prim->GetEngine())); + MklDnnThreadPool eigen_tp(context); + slice_stream.reset(CreateStream(&eigen_tp, reorder_prim->GetEngine())); reorder_prim->Execute(sliceParams, slice_stream); } catch (mkldnn::error& e) { string error_msg = "Status: " + std::to_string(e.status) + diff --git a/tensorflow/core/kernels/mkl/mkl_softmax_op.cc b/tensorflow/core/kernels/mkl/mkl_softmax_op.cc index 71837e9e91d..f436f0feec8 100644 --- a/tensorflow/core/kernels/mkl/mkl_softmax_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_softmax_op.cc @@ -298,7 +298,8 @@ class MklSoftmaxOp : public OpKernel { const T* src_data = src_tensor.flat().data(); T* dst_data = reinterpret_cast(output_tensor->flat().data()); std::shared_ptr fwd_cpu_stream; - fwd_cpu_stream.reset(CreateStream(context, softmax_fwd->GetEngine())); + MklDnnThreadPool eigen_tp(context); + fwd_cpu_stream.reset(CreateStream(&eigen_tp, softmax_fwd->GetEngine())); softmax_fwd->Execute(src_data, dst_data, fwd_cpu_stream); } catch (mkldnn::error& e) { string error_msg = "Status: " + std::to_string(e.status) + diff --git a/tensorflow/core/kernels/mkl/mkl_tmp_bf16_ops.cc b/tensorflow/core/kernels/mkl/mkl_tmp_bf16_ops.cc index 2e07dfc08be..e60375bf0f5 100644 --- a/tensorflow/core/kernels/mkl/mkl_tmp_bf16_ops.cc +++ b/tensorflow/core/kernels/mkl/mkl_tmp_bf16_ops.cc @@ -19,8 +19,6 @@ limitations under the License. #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/kernels/no_op.h" -#ifdef ENABLE_INTEL_MKL_BFLOAT16 - namespace tensorflow { // This file contains temporary registrations for some of the Eigen CPU backend @@ -62,5 +60,3 @@ TF_CALL_bfloat16(REGISTER_CPU); #undef REGISTER_CPU } // namespace tensorflow - -#endif // ENABLE_INTEL_MKL_BFLOAT16 diff --git a/tensorflow/core/kernels/mkl/mkl_transpose_op.cc b/tensorflow/core/kernels/mkl/mkl_transpose_op.cc index 72cc760c0de..f2e41a72f28 100644 --- a/tensorflow/core/kernels/mkl/mkl_transpose_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_transpose_op.cc @@ -19,10 +19,6 @@ limitations under the License. #define EIGEN_USE_THREADS -#if !defined(INTEL_MKL_DNN_ONLY) -#include "mkl_trans.h" -#endif - #include "mkldnn.hpp" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/kernels/transpose_functor.h" @@ -49,62 +45,6 @@ namespace tensorflow { // REQUIRES: perm is a permutation. namespace { -#if !defined(INTEL_MKL_DNN_ONLY) -template -Status MKLTranspose2D(const char trans, const Tensor& in, Tensor* out); - -// Documentation here: https://software.intel.com/en-us/node/520863 -// Parameters: (ordering:row-major, operation:transpose, num_rows, num_cols, -// alpha (for scaling), array, dist_bet_adjacent_cols/rows -// (source), array, dist_bet_adjacent_cols/rows (dest)) - -#define INSTANTIATE(T, PREFIX) \ - template <> \ - Status MKLTranspose2D(const char trans, const Tensor& in, Tensor* out) { \ - mkl_##PREFIX##omatcopy('R', trans, in.dim_size(0), in.dim_size(1), 1, \ - in.flat().data(), in.dim_size(1), \ - out->flat().data(), in.dim_size(0)); \ - return Status::OK(); \ - } - -INSTANTIATE(float, s) -INSTANTIATE(double, d) - -#undef INSTANTIATE - -template <> -Status MKLTranspose2D(const char trans, const Tensor& in, - Tensor* out) { - const MKL_Complex8 alpha = {1.0f, 0.0f}; - mkl_comatcopy( - 'R', trans, in.dim_size(0), in.dim_size(1), alpha, - reinterpret_cast(in.flat().data()), - in.dim_size(1), - reinterpret_cast( - const_cast(out->flat().data())), - in.dim_size(0)); - return Status::OK(); -} - -template <> -Status MKLTranspose2D(const char trans, const Tensor& in, - Tensor* out) { - const MKL_Complex16 alpha = {1.0, 0.0}; - mkl_zomatcopy( - 'R', trans, in.dim_size(0), in.dim_size(1), alpha, - reinterpret_cast(in.flat().data()), - in.dim_size(1), - reinterpret_cast( - const_cast(out->flat().data())), - in.dim_size(0)); - return Status::OK(); -} - -static const char kMKLTranspose = 'T'; -static const char kMKLConjugateTranspose = 'C'; - -#endif // if !defined(INTEL_MKL_DNN_ONLY) - // MKL-DNN based Transpose implementation template Status MKLTransposeND(OpKernelContext* ctx, const Tensor& in, Tensor* out, @@ -144,7 +84,8 @@ Status MKLTransposeND(OpKernelContext* context, const Tensor& in_tensor, std::vector net; auto* prim = FindOrCreateReorder(in.GetUsrMem(), out.GetUsrMem()); - transpose_stream.reset(CreateStream(context, prim->GetEngine())); + MklDnnThreadPool eigen_tp(context); + transpose_stream.reset(CreateStream(&eigen_tp, prim->GetEngine())); in.SetUsrMemDataHandle(&in_tensor, transpose_stream); out.SetUsrMemDataHandle(out_tensor, transpose_stream); net.push_back(*(prim->GetPrimitive())); @@ -167,26 +108,6 @@ Status MKLTransposeND(OpKernelContext* context, const Tensor& in_tensor, Status MklTransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in, gtl::ArraySlice perm, Tensor* out) { -#if !defined(INTEL_MKL_DNN_ONLY) - if (in.dims() == 2) { - if (perm[0] == 0 && perm[1] == 1) { - return Status::OK(); - } - switch (in.dtype()) { - case DT_FLOAT: - return MKLTranspose2D(kMKLTranspose, in, out); - case DT_DOUBLE: - return MKLTranspose2D(kMKLTranspose, in, out); - case DT_COMPLEX64: - return MKLTranspose2D(kMKLTranspose, in, out); - case DT_COMPLEX128: - return MKLTranspose2D(kMKLTranspose, in, out); - default: - break; - } - } -#endif - // MKL-DNN has limit on the maximum number of dimensions in a tensor. // Fallback to Eigen for not supported cases. if (in.dims() <= MKLDNN_MAX_NDIMS) { @@ -213,27 +134,6 @@ Status MklConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in, gtl::ArraySlice perm, Tensor* out) { -#if !defined(INTEL_MKL_DNN_ONLY) - if (in.dims() == 2 && perm[0] == 1 && perm[1] == 0) { - // TODO(rmlarsen): By setting lda and ldb, we could use the MKL kernels - // for any transpose that can be reduced to swapping the last two - // dimensions in a rank-3 tensor. We can even run each outer dimension in - // a separate thread. - switch (in.dtype()) { - case DT_FLOAT: - return MKLTranspose2D(kMKLTranspose, in, out); - case DT_DOUBLE: - return MKLTranspose2D(kMKLTranspose, in, out); - case DT_COMPLEX64: - return MKLTranspose2D(kMKLConjugateTranspose, in, out); - case DT_COMPLEX128: - return MKLTranspose2D(kMKLConjugateTranspose, in, out); - default: - break; - } - } -#endif - // MKL-DNN has limit on the maximum number of dimensions in a tensor. // Fallback to Eigen for not supported cases. if (in.dims() <= MKLDNN_MAX_NDIMS) { diff --git a/tensorflow/core/kernels/mlir_generated/BUILD b/tensorflow/core/kernels/mlir_generated/BUILD index 5deadbfd5a7..5d6e09b23a4 100644 --- a/tensorflow/core/kernels/mlir_generated/BUILD +++ b/tensorflow/core/kernels/mlir_generated/BUILD @@ -234,6 +234,7 @@ tf_kernel_library( name = "cpu_cwise_unary_op", srcs = [ "cpu_op_abs.cc", + "cpu_op_rsqrt.cc", "cpu_op_sqrt.cc", ], tags = ["manual"], @@ -245,6 +246,7 @@ tf_kernel_library( deps = [ ":base_cpu_op", ":cpu_abs_kernels", + ":cpu_rsqrt_kernels", ":cpu_sqrt_kernels", "//third_party/eigen3", ], @@ -1067,6 +1069,18 @@ cpu_kernel_library( unroll_factors = "4", ) +cpu_kernel_library( + name = "cpu_rsqrt_lib", + op = "rsqrt", + tile_size = "256", + types = [ + "f16", + "f32", + "f64", + ], + unroll_factors = "4", +) + cpu_kernel_library( name = "cpu_sqrt_lib", op = "sqrt", diff --git a/tensorflow/core/kernels/mlir_generated/base_binary_ops_test.h b/tensorflow/core/kernels/mlir_generated/base_binary_ops_test.h index a6b0535d577..d4ed84b9deb 100644 --- a/tensorflow/core/kernels/mlir_generated/base_binary_ops_test.h +++ b/tensorflow/core/kernels/mlir_generated/base_binary_ops_test.h @@ -309,6 +309,41 @@ class BinaryOpsTestBase : public OpsTestBase { expected_output, config); } + template + void TestBroadcastingRank6(const std::string& op_name, + const absl::InlinedVector& lhs_input, + const absl::InlinedVector& rhs_input, + BaselineOutT (*baseline_callback)(BaselineT, + BaselineT), + const test::OpsTestConfig& config) { + // Prepare inputs. + TensorShape lhs_shape{1, 2, 3, 1, 2, 1}; + TensorShape rhs_shape{1, 1, 1, 2, 3}; + auto repeated_lhs_input = + test::RepeatInputToMatchShape(lhs_input, lhs_shape.num_elements()); + auto repeated_rhs_input = + test::RepeatInputToMatchShape(rhs_input, rhs_shape.num_elements()); + + // Compute expected results. + TensorShape expected_shape{1, 2, 3, 1, 2, 3}; + std::vector lhs_indices = {0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, + 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, + 8, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11}; + std::vector rhs_indices = { + 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, + 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, + }; + auto expected_output = + ComputeExpectedOutput( + lhs_indices, repeated_lhs_input, rhs_indices, repeated_rhs_input, + baseline_callback); + + RunAndExpectResult(op_name, lhs_shape, repeated_lhs_input, + rhs_shape, repeated_rhs_input, expected_shape, + expected_output, config); + } + template void TestEmptyShapeBroadcasting(const std::string& op_name, @@ -392,6 +427,11 @@ class BinaryOpsTestBase : public OpsTestBase { #op_name, lhs_input, rhs_input, baseline_callback, config); \ } \ \ + TEST_F(BinaryOpsTest, op_name##BroadcastingRank6##test_name) { \ + TestBroadcastingRank6( \ + #op_name, lhs_input, rhs_input, baseline_callback, config); \ + } \ + \ TEST_F(BinaryOpsTest, op_name##EmptyShapeBroadcasting##test_name) { \ TestEmptyShapeBroadcasting( \ #op_name, lhs_input, rhs_input, config); \ diff --git a/tensorflow/core/kernels/mlir_generated/base_unary_ops_test.h b/tensorflow/core/kernels/mlir_generated/base_unary_ops_test.h index fa996eb5c38..a7c334e9eb8 100644 --- a/tensorflow/core/kernels/mlir_generated/base_unary_ops_test.h +++ b/tensorflow/core/kernels/mlir_generated/base_unary_ops_test.h @@ -80,23 +80,55 @@ class UnaryOpsTestBase : public OpsTestBase { } template - void Test(const std::string& op_name, const TensorShape& shape, - const absl::InlinedVector& input, - BaselineOutT (*baseline_callback)(BaselineT), - const test::OpsTestConfig& config) { + typename BaselineCallback> + void TestImpl(const std::string& op_name, const TensorShape& shape, + const absl::InlinedVector& input, + const BaselineCallback& baseline_callback, + const test::OpsTestConfig& config) { // Prepare inputs and compute expected results. CHECK(input.size() <= shape.num_elements()); auto repeated_input = test::RepeatInputToMatchShape(input, shape.num_elements()); absl::InlinedVector expected_output = - ComputeExpectedOutput( - repeated_input, baseline_callback); + ComputeExpectedOutput(repeated_input, + baseline_callback); RunAndExpectResult(op_name, shape, repeated_input, expected_output, config); } + template + void Test(const std::string& op_name, const TensorShape& shape, + const absl::InlinedVector& input, + const BaselineCallback& baseline_callback, + const test::OpsTestConfig& config) { + TestImpl(op_name, shape, input, baseline_callback, + config); + } + + // Allow deduction of overloaded function with const ref input. + template + void Test(const std::string& op_name, const TensorShape& shape, + const absl::InlinedVector& input, + BaselineOutT (*baseline_callback)(const BaselineT&), + const test::OpsTestConfig& config) { + TestImpl(op_name, shape, input, baseline_callback, + config); + } + + // Allow deduction of overloaded function with value input. + template + void Test(const std::string& op_name, const TensorShape& shape, + const absl::InlinedVector& input, + BaselineOutT (*baseline_callback)(BaselineT), + const test::OpsTestConfig& config) { + TestImpl(op_name, shape, input, baseline_callback, + config); + } + template void TestEmptyShape(const std::string& op_name, const test::OpsTestConfig& config) { @@ -112,10 +144,10 @@ class UnaryOpsTestBase : public OpsTestBase { constexpr static double kRelativeTolerance = 0.001; template + typename BaselineCallback> absl::InlinedVector ComputeExpectedOutput( absl::InlinedVector input, - BaselineOutT (*baseline_callback)(BaselineT)) { + const BaselineCallback& baseline_callback) { absl::InlinedVector expected_output; for (int i = 0; i < input.size(); i++) { auto arg = static_cast(input[i]); diff --git a/tensorflow/core/kernels/mlir_generated/cpu_op_rsqrt.cc b/tensorflow/core/kernels/mlir_generated/cpu_op_rsqrt.cc new file mode 100644 index 00000000000..ed20fa2bbd3 --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/cpu_op_rsqrt.cc @@ -0,0 +1,25 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/kernels/mlir_generated/base_cpu_op.h" + +namespace tensorflow { + +GENERATE_AND_REGISTER_UNARY_CPU_KERNEL(Rsqrt, DT_HALF); +GENERATE_AND_REGISTER_UNARY_CPU_KERNEL(Rsqrt, DT_FLOAT); +GENERATE_AND_REGISTER_UNARY_CPU_KERNEL(Rsqrt, DT_DOUBLE); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/mlir_generated/cpu_unary_ops_test.cc b/tensorflow/core/kernels/mlir_generated/cpu_unary_ops_test.cc index 6d57c2491a0..1f856bfff8f 100644 --- a/tensorflow/core/kernels/mlir_generated/cpu_unary_ops_test.cc +++ b/tensorflow/core/kernels/mlir_generated/cpu_unary_ops_test.cc @@ -67,23 +67,20 @@ GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES( Abs, DT_INT64, DT_INT64, test::NearZeroAndExtremeInput(), std::abs, test::OpsTestConfig().NoBufferReuse().ExpectStrictlyEqual()) +/// Test `tf.Rsqrt`. +GENERATE_DEFAULT_TEST(Rsqrt, DT_HALF, DT_HALF, Eigen::numext::rsqrt, + test::OpsTestConfig().NoBufferReuse()) +GENERATE_DEFAULT_TEST(Rsqrt, DT_FLOAT, DT_FLOAT, Eigen::numext::rsqrt, + test::OpsTestConfig().NoBufferReuse()) +GENERATE_DEFAULT_TEST(Rsqrt, DT_DOUBLE, DT_DOUBLE, Eigen::numext::rsqrt, + test::OpsTestConfig().NoBufferReuse()) + /// Test `tf.Sqrt`. - -// Forwards to Eigen, necessary since Eigen passes by `const T&` but existing -// Test class expects passing by value. Eigen::numext::sqrt works properly for -// Eigen::half and Eigen::bfloat16, which do not have a std::sqrt -// implementation. -template -T baseline_sqrt(T x) { - using Eigen::numext::sqrt; - return sqrt(x); -} - -GENERATE_DEFAULT_TEST(Sqrt, DT_HALF, DT_HALF, baseline_sqrt, +GENERATE_DEFAULT_TEST(Sqrt, DT_HALF, DT_HALF, Eigen::numext::sqrt, test::OpsTestConfig().NoBufferReuse()) -GENERATE_DEFAULT_TEST(Sqrt, DT_FLOAT, DT_FLOAT, baseline_sqrt, +GENERATE_DEFAULT_TEST(Sqrt, DT_FLOAT, DT_FLOAT, Eigen::numext::sqrt, test::OpsTestConfig().NoBufferReuse()) -GENERATE_DEFAULT_TEST(Sqrt, DT_DOUBLE, DT_DOUBLE, baseline_sqrt, +GENERATE_DEFAULT_TEST(Sqrt, DT_DOUBLE, DT_DOUBLE, Eigen::numext::sqrt, test::OpsTestConfig().NoBufferReuse()) } // namespace diff --git a/tensorflow/core/kernels/neon/BUILD b/tensorflow/core/kernels/neon/BUILD deleted file mode 100644 index 668ac4b2f64..00000000000 --- a/tensorflow/core/kernels/neon/BUILD +++ /dev/null @@ -1,27 +0,0 @@ -load("//tensorflow:tensorflow.bzl", "tf_kernel_library") - -# Description: -# Kernel implementations using Neon intrinsics. -# -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -tf_kernel_library( - name = "neon_depthwise_conv_op", - hdrs = [ - "depthwiseconv_float.h", - "types.h", - ], - features = ["-parse_headers"], # included gemmlowp headers are not self-contained - prefix = "neon_depthwise_conv_op", - deps = [ - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core/framework:bounds_check", - "//tensorflow/core/kernels:ops_util", - "@gemmlowp", - ], -) diff --git a/tensorflow/core/kernels/neon/depthwiseconv_float.h b/tensorflow/core/kernels/neon/depthwiseconv_float.h deleted file mode 100644 index 1593f2173d2..00000000000 --- a/tensorflow/core/kernels/neon/depthwiseconv_float.h +++ /dev/null @@ -1,723 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CORE_KERNELS_NEON_DEPTHWISECONV_FLOAT_H_ -#define TENSORFLOW_CORE_KERNELS_NEON_DEPTHWISECONV_FLOAT_H_ - -#include "public/gemmlowp.h" -#include "tensorflow/core/kernels/neon/types.h" - -#if defined(__ARM_NEON__) || defined(__ARM_NEON) -#define USE_NEON -#include -#endif - -namespace tensorflow { -namespace neon { - -// Implementation of float DepthwiseConv - -template -struct FloatDepthwiseConvKernel {}; - -#ifdef USE_NEON - -template <> -struct FloatDepthwiseConvKernel { - static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const float* input_ptr, int input_ptr_increment, - const float* filter_ptr, float* acc_buffer_ptr) { - // Load the filters - float32x4_t filter[2]; - for (int i = 0; i < 2; i++) { - filter[i] = vld1q_f32(filter_ptr + 4 * i); - } - int outp = 0; - // Handle 2 output pixels at a time. - for (; outp <= num_output_pixels - 2; outp += 2) { - // Load the inputs - float32x4_t input[4]; - for (int i = 0; i < 4; i++) { - input[i] = vld1q_f32(input_ptr + 4 * i); - } - input_ptr += 16; - // Load the accumulators from acc_buffer - float32x4_t acc[4]; - for (int i = 0; i < 4; i++) { - acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); - } - // Multiply-accumulate - acc[0] = vmlaq_f32(acc[0], input[0], filter[0]); - acc[1] = vmlaq_f32(acc[1], input[1], filter[1]); - acc[2] = vmlaq_f32(acc[2], input[2], filter[0]); - acc[3] = vmlaq_f32(acc[3], input[3], filter[1]); - // Store the accumulators back to acc_buffer - for (int i = 0; i < 4; i++) { - vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); - } - acc_buffer_ptr += 16; - } - // Handle one output pixel at a time. - for (; outp < num_output_pixels; outp++) { - // Load the inputs - float32x4_t input[2]; - for (int i = 0; i < 2; i++) { - input[i] = vld1q_f32(input_ptr + 4 * i); - } - input_ptr += 8; - // Load the accumulators from acc_buffer - float32x4_t acc[2]; - for (int i = 0; i < 2; i++) { - acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); - } - // Multiply-accumulate - for (int i = 0; i < 2; i++) { - acc[i] = vmlaq_f32(acc[i], input[i], filter[i]); - } - // Store the accumulators back to acc_buffer - for (int i = 0; i < 2; i++) { - vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); - } - acc_buffer_ptr += 8; - } - } -}; - -template <> -struct FloatDepthwiseConvKernel { - static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const float* input_ptr, int input_ptr_increment, - const float* filter_ptr, float* acc_buffer_ptr) { - const float32x2_t filters = vld1_f32(filter_ptr); - const float32x4_t filters_dup2 = vcombine_f32(filters, filters); - int outp = 0; - // Handle 8 output pixels at a time. - for (; outp <= num_output_pixels - 8; outp += 8) { - // Load the inputs - float32x4_t input[4]; - for (int i = 0; i < 4; i++) { - input[i] = vld1q_f32(input_ptr + 4 * i); - } - input_ptr += 16; - // Load the accumulators from acc_buffer - float32x4_t acc[4]; - for (int i = 0; i < 4; i++) { - acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); - } - // Multiply-accumulate - for (int i = 0; i < 4; i++) { - acc[i] = vmlaq_f32(acc[i], input[i], filters_dup2); - } - // Store the accumulators back to acc_buffer - for (int i = 0; i < 4; i++) { - vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); - } - acc_buffer_ptr += 16; - } - // Handle 4 output pixels at a time. - for (; outp <= num_output_pixels - 4; outp += 4) { - // Load the inputs - float32x4_t input[2]; - for (int i = 0; i < 2; i++) { - input[i] = vld1q_f32(input_ptr + 4 * i); - } - input_ptr += 8; - // Load the accumulators from acc_buffer - float32x4_t acc[2]; - for (int i = 0; i < 2; i++) { - acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); - } - // Multiply-accumulate - for (int i = 0; i < 2; i++) { - acc[i] = vmlaq_f32(acc[i], input[i], filters_dup2); - } - // Store the accumulators back to acc_buffer - for (int i = 0; i < 2; i++) { - vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); - } - acc_buffer_ptr += 8; - } - // Handle 2 output pixels at a time. - for (; outp <= num_output_pixels - 2; outp += 2) { - // Load the inputs - const float32x4_t input = vld1q_f32(input_ptr); - input_ptr += 4; - // Load the accumulators from acc_buffer - float32x4_t acc = vld1q_f32(acc_buffer_ptr); - // Multiply-accumulate - acc = vmlaq_f32(acc, input, filters_dup2); - // Store the accumulators back to acc_buffer - vst1q_f32(acc_buffer_ptr, acc); - acc_buffer_ptr += 4; - } - // Handle 1 output pixel at a time - for (; outp < num_output_pixels; outp++) { - // Load the inputs - const float32x2_t input = vld1_f32(input_ptr); - input_ptr += 2; - // Load the accumulators from acc_buffer - float32x2_t acc = vld1_f32(acc_buffer_ptr); - // Multiply-accumulate - acc = vmla_f32(acc, input, filters); - // Store the accumulators back to acc_buffer - vst1_f32(acc_buffer_ptr, acc); - acc_buffer_ptr += 2; - } - } -}; - -template <> -struct FloatDepthwiseConvKernel { - static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const float* input_ptr, int input_ptr_increment, - const float* filter_ptr, float* acc_buffer_ptr) { - // Handle one output pixel at a time. - for (int outp = 0; outp < num_output_pixels; outp++) { - const float* local_filter_ptr = filter_ptr; - const float* local_input_ptr = input_ptr; - int ic = 0; - // Handle 16 input channels at a time. - for (; ic <= input_depth - 16; ic += 16) { - // Load the filters - float32x4_t filter[4]; - for (int i = 0; i < 4; i++) { - filter[i] = vld1q_f32(local_filter_ptr + 4 * i); - } - local_filter_ptr += 16; - // Load the inputs - float32x4_t input[4]; - for (int i = 0; i < 4; i++) { - input[i] = vld1q_f32(local_input_ptr + 4 * i); - } - local_input_ptr += 16; - // Load the accumulators from acc_buffer - float32x4_t acc[4]; - for (int i = 0; i < 4; i++) { - acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); - } - // Multiply-accumulate - for (int i = 0; i < 4; i++) { - acc[i] = vmlaq_f32(acc[i], input[i], filter[i]); - } - // Store the accumulators back to acc_buffer - for (int i = 0; i < 4; i++) { - vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); - } - acc_buffer_ptr += 16; - } - // Handle 4 input channels at a time. - for (; ic <= input_depth - 4; ic += 4) { - // Load the filters - float32x4_t filter; - filter = vld1q_f32(local_filter_ptr); - local_filter_ptr += 4; - // Load the inputs - float32x4_t input; - input = vld1q_f32(local_input_ptr); - local_input_ptr += 4; - // Load the accumulators from acc_buffer - float32x4_t acc; - acc = vld1q_f32(acc_buffer_ptr); - // Multiply-accumulate - acc = vmlaq_f32(acc, input, filter); - // Store the accumulators back to acc_buffer - vst1q_f32(acc_buffer_ptr, acc); - acc_buffer_ptr += 4; - } - // Handle one input channel at a time. - for (; ic < input_depth; ic++) { - const float input_val = *local_input_ptr++; - const float filter_val = *local_filter_ptr++; - *acc_buffer_ptr++ += filter_val * input_val; - } - input_ptr += input_ptr_increment; - } - } -}; - -template <> -struct FloatDepthwiseConvKernel { - static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const float* input_ptr, int input_ptr_increment, - const float* filter_ptr, float* acc_buffer_ptr) { - // Handle one output pixel at a time. - for (int outp = 0; outp < num_output_pixels; outp++) { - const float* local_filter_ptr = filter_ptr; - const float* local_input_ptr = input_ptr; - int ic = 0; - // Handle 2 input channels at a time. - for (; ic <= input_depth - 2; ic += 2) { - // Load the filters - float32x4_t filter[4]; - for (int i = 0; i < 4; i++) { - filter[i] = vld1q_f32(local_filter_ptr + 4 * i); - } - local_filter_ptr += 16; - // Load the inputs - const float32x2_t input = vld1_f32(local_input_ptr); - local_input_ptr += 2; - // Load the accumulators from acc_buffer - float32x4_t acc[4]; - for (int i = 0; i < 4; i++) { - acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); - } - // Multiply-accumulate - acc[0] = vmlaq_lane_f32(acc[0], filter[0], input, 0); - acc[1] = vmlaq_lane_f32(acc[1], filter[1], input, 0); - acc[2] = vmlaq_lane_f32(acc[2], filter[2], input, 1); - acc[3] = vmlaq_lane_f32(acc[3], filter[3], input, 1); - // Store the accumulators back to acc_buffer - for (int i = 0; i < 4; i++) { - vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); - } - acc_buffer_ptr += 16; - } - // Handle one input channel at a time. - for (; ic < input_depth; ic++) { - // Load the filters - float32x4_t filter[2]; - for (int i = 0; i < 2; i++) { - filter[i] = vld1q_f32(local_filter_ptr + 4 * i); - } - local_filter_ptr += 8; - // Load the inputs - const float input_val = *local_input_ptr++; - // Load the accumulators from acc_buffer - float32x4_t acc[2]; - for (int i = 0; i < 2; i++) { - acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); - } - // Multiply-accumulate - for (int i = 0; i < 2; i++) { - acc[i] = vmlaq_n_f32(acc[i], filter[i], input_val); - } - // Store the accumulators back to acc_buffer - for (int i = 0; i < 2; i++) { - vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); - } - acc_buffer_ptr += 8; - } - input_ptr += input_ptr_increment; - } - } -}; - -template <> -struct FloatDepthwiseConvKernel { - static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const float* input_ptr, int input_ptr_increment, - const float* filter_ptr, float* acc_buffer_ptr) { - // Handle one output pixel at a time. - for (int outp = 0; outp < num_output_pixels; outp++) { - const float* local_filter_ptr = filter_ptr; - const float* local_input_ptr = input_ptr; - int ic = 0; - // Handle 8 input channels at a time. - for (; ic <= input_depth - 8; ic += 8) { - // Load the filters - float32x4_t filter[4]; - for (int i = 0; i < 4; i++) { - filter[i] = vld1q_f32(local_filter_ptr + 4 * i); - } - local_filter_ptr += 16; - // Load the inputs - float32x4x2_t input_dup2[2]; - for (int i = 0; i < 2; i++) { - const float32x4_t input = vld1q_f32(local_input_ptr + 4 * i); - input_dup2[i] = vzipq_f32(input, input); - } - local_input_ptr += 8; - // Load the accumulators from acc_buffer - float32x4_t acc[4]; - for (int i = 0; i < 4; i++) { - acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); - } - // Multiply-accumulate - acc[0] = vmlaq_f32(acc[0], filter[0], input_dup2[0].val[0]); - acc[1] = vmlaq_f32(acc[1], filter[1], input_dup2[0].val[1]); - acc[2] = vmlaq_f32(acc[2], filter[2], input_dup2[1].val[0]); - acc[3] = vmlaq_f32(acc[3], filter[3], input_dup2[1].val[1]); - // Store the accumulators back to acc_buffer - for (int i = 0; i < 4; i++) { - vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); - } - acc_buffer_ptr += 16; - } - // Handle 4 input channels at a time. - for (; ic <= input_depth - 4; ic += 4) { - // Load the filters - float32x2_t filter[4]; - for (int i = 0; i < 4; i++) { - filter[i] = vld1_f32(local_filter_ptr + 2 * i); - } - local_filter_ptr += 8; - // Load the inputs - const float32x4_t input = vld1q_f32(local_input_ptr); - local_input_ptr += 4; - // Load the accumulators from acc_buffer - float32x2_t acc[4]; - for (int i = 0; i < 4; i++) { - acc[i] = vld1_f32(acc_buffer_ptr + 2 * i); - } - // Multiply-accumulate - acc[0] = vmla_lane_f32(acc[0], filter[0], vget_low_f32(input), 0); - acc[1] = vmla_lane_f32(acc[1], filter[1], vget_low_f32(input), 1); - acc[2] = vmla_lane_f32(acc[2], filter[2], vget_high_f32(input), 0); - acc[3] = vmla_lane_f32(acc[3], filter[3], vget_high_f32(input), 1); - // Store the accumulators back to acc_buffer - for (int i = 0; i < 4; i++) { - vst1_f32(acc_buffer_ptr + 2 * i, acc[i]); - } - acc_buffer_ptr += 8; - } - // Handle 2 input channels at a time. - for (; ic <= input_depth - 2; ic += 2) { - // Load the filters - const float32x4_t filter = vld1q_f32(local_filter_ptr); - local_filter_ptr += 4; - // Load the inputs - const float32x2_t input = vld1_f32(local_input_ptr); - local_input_ptr += 2; - // Load the accumulators from acc_buffer - float32x2_t acc[2]; - for (int i = 0; i < 2; i++) { - acc[i] = vld1_f32(acc_buffer_ptr + 2 * i); - } - // Multiply-accumulate - acc[0] = vmla_lane_f32(acc[0], vget_low_f32(filter), input, 0); - acc[1] = vmla_lane_f32(acc[1], vget_high_f32(filter), input, 1); - // Store the accumulators back to acc_buffer - for (int i = 0; i < 2; i++) { - vst1_f32(acc_buffer_ptr + 2 * i, acc[i]); - } - acc_buffer_ptr += 4; - } - // Handle one input channel at a time. - for (; ic < input_depth; ic++) { - // Load the inputs - const float input_val = *local_input_ptr++; - // Multiply-accumulate - for (int i = 0; i < 2; i++) { - acc_buffer_ptr[i] += local_filter_ptr[i] * input_val; - } - local_filter_ptr += 2; - acc_buffer_ptr += 2; - } - input_ptr += input_ptr_increment; - } - } -}; -#endif - -// Accumulates the effect of one row of the filter, on a segment of one row -// of the output, accessing the corresponding one row of the input. -template -void FloatDepthwiseConvAccumRow(int stride, int input_depth, int input_width, - const float* input_data, int pad_width, - int depth_multiplier, int filter_width, - const float* filter_data, - int out_x_buffer_start, int out_x_buffer_end, - int output_depth, float* acc_buffer) { -#ifdef GEMMLOWP_PROFILING - gemmlowp::ScopedProfilingLabel label(__PRETTY_FUNCTION__); -#endif - // Sanity check parameters. This is important in particular to ensure - // that we keep the number of template instantiations minimal, so we don't - // increase binary size unnecessarily. - static_assert(kFixedDepthMultiplier || !kFixedInputDepth, ""); - static_assert(kFixedInputDepth || kAllowStrided, ""); - DCHECK(stride == 1 || kAllowStrided); - if (kFixedInputDepth) { - DCHECK_EQ(input_depth, kFixedInputDepth); - } - if (kFixedDepthMultiplier) { - DCHECK_EQ(depth_multiplier, kFixedDepthMultiplier); - } - DCHECK_EQ(output_depth, input_depth * depth_multiplier); - const int input_ptr_increment = stride * input_depth; - const float* filter_base_ptr = filter_data; - for (int filter_x = 0; filter_x < filter_width; ++filter_x) { - // For the current (filter_x, filter_y) point in the filter, - // compute the boundaries of the corresponding output row segment. - int out_x_loop_start_unclamped = 0; - int out_x_loop_end_unclamped = 0; - if (kAllowStrided) { - if (stride == 2) { - out_x_loop_start_unclamped = (pad_width - filter_x + 1) / 2; - out_x_loop_end_unclamped = (pad_width + input_width - filter_x + 1) / 2; - } else if (stride == 4) { - out_x_loop_start_unclamped = (pad_width - filter_x + 3) / 4; - out_x_loop_end_unclamped = (pad_width + input_width - filter_x + 3) / 4; - } else { - out_x_loop_start_unclamped = - (pad_width - filter_x + stride - 1) / stride; - out_x_loop_end_unclamped = - (pad_width + input_width - filter_x + stride - 1) / stride; - } - } else { - out_x_loop_start_unclamped = pad_width - filter_x; - out_x_loop_end_unclamped = pad_width + input_width - filter_x; - } - // The kernel will have to iterate on the segment of the - // output row that starts at out_x_loop_start and out_x_loop_end. - const int out_x_loop_start = - std::max(out_x_buffer_start, out_x_loop_start_unclamped); - const int out_x_loop_end = - std::min(out_x_buffer_end, out_x_loop_end_unclamped); - - float* acc_buffer_ptr = - acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth; - const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x; - const float* input_ptr = input_data + in_x_origin * input_depth; - const int num_output_pixels = out_x_loop_end - out_x_loop_start; - FloatDepthwiseConvKernel::Run(num_output_pixels, - input_depth, - depth_multiplier, - input_ptr, - input_ptr_increment, - filter_base_ptr, - acc_buffer_ptr); - filter_base_ptr += output_depth; - } -} - -// generic fallback of FloatDepthwiseConvAccumRow, portable, non-templatized. -inline void FloatDepthwiseConvAccumRowGeneric( - int stride, int input_depth, int input_width, const float* input_data, - int pad_width, int depth_multiplier, int filter_width, - const float* filter_data, int out_x_buffer_start, int out_x_buffer_end, - int output_depth, float* acc_buffer) { - gemmlowp::ScopedProfilingLabel label("DepthwiseConvAccumRowGeneric (slow)"); - - VLOG(1) << "DepthwiseConv2d using slow path with " - << "stride = " << stride << ", " - << "input_depth = " << input_depth << ", " - << "depth_multiplier = " << depth_multiplier << "."; - - const float* filter_base_ptr = filter_data; - for (int filter_x = 0; filter_x < filter_width; ++filter_x) { - const int out_x_loop_start = std::max( - out_x_buffer_start, (pad_width - filter_x + stride - 1) / stride); - const int out_x_loop_end = - std::min(out_x_buffer_end, - (pad_width + input_width - filter_x + stride - 1) / stride); - - float* acc_buffer_ptr = - acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth; - const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x; - const float* input_ptr = input_data + in_x_origin * input_depth; - const int input_ptr_increment = (stride - 1) * input_depth; - for (int out_x = out_x_loop_start; out_x < out_x_loop_end; out_x++) { - const float* filter_ptr = filter_base_ptr; - for (int ic = 0; ic < input_depth; ++ic) { - const float input_val = *input_ptr++; - for (int m = 0; m < depth_multiplier; m++) { - const float filter_val = *filter_ptr++; - *acc_buffer_ptr++ += filter_val * input_val; - } - } - input_ptr += input_ptr_increment; - } - filter_base_ptr += output_depth; - } -} - -// Initializes the accumulator buffer with bias values. -inline void DepthwiseConvInitAccBuffer(int num_output_pixels, int output_depth, - const float* bias_data, - float* acc_buffer) { - // TODO(benoitjacob): This might need optimized specializations - // for small output_depth values, if that ever becomes an important - // case (like it was for some quantized DepthwiseConv cases). - for (int i = 0; i < num_output_pixels; i++) { - memcpy(acc_buffer + i * output_depth, bias_data, - sizeof(acc_buffer[0]) * output_depth); - } -} - -template -void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, - const float* filter_data, const Dims<4>& filter_dims, - const float* bias_data, const Dims<4>& bias_dims, int stride, - int pad_width, int pad_height, int depth_multiplier, - float* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("DepthwiseConv"); - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int input_depth = ArraySize(input_dims, 0); - const int filter_height = ArraySize(filter_dims, 2); - const int filter_width = ArraySize(filter_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - DCHECK(output_depth == input_depth * depth_multiplier); - - static const int kAccBufferMaxSize = 1024; - float acc_buffer[kAccBufferMaxSize]; - DCHECK_GE(kAccBufferMaxSize, output_depth) - << "Too small kAccBufferMaxSize for this model!"; - const int kOutputPixelsInAccBuffer = kAccBufferMaxSize / output_depth; - const int kAccBufferActualSize = kOutputPixelsInAccBuffer * output_depth; - DCHECK_LE(kOutputPixelsInAccBuffer * output_depth, kAccBufferActualSize); - DCHECK_LE(kAccBufferActualSize, kAccBufferMaxSize); - DCHECK_GE(kOutputPixelsInAccBuffer, 1); - - // row_accum_func will point to the core accumulation function to be used - // for this DepthwiseConv op. - auto* row_accum_func = FloatDepthwiseConvAccumRowGeneric; - - const int kMaxFixedDepthMultiplier = 8; - int fixed_depth_multiplier = 0; - if (depth_multiplier <= kMaxFixedDepthMultiplier) { - fixed_depth_multiplier = depth_multiplier; - } - // kMaxUnrolling is the max number of output values that we aim to handle - // in one unrolled iteration of the inner loop. For practical performance - // reasons, it is limited by the number of available registers. We could - // fine-tune it depending on the architecture, but that's not worth doing - // since this whole code is not very optimized to begin with. The - // present value reflects what's realistic on ARM 32bit NEON with 16 128-bit - // vector registers. - const int kMaxUnrolling = 8; - int fixed_input_depth = 0; - if (fixed_depth_multiplier && - input_depth * fixed_depth_multiplier <= kMaxUnrolling) { - fixed_input_depth = input_depth; - } -#define TF_NEON_USE_DEPTHWISECONV_KERNEL(ALLOW_STRIDED, FIXED_INPUT_DEPTH, \ - FIXED_DEPTH_MULTIPLIER) \ - if ((stride == 1 || ALLOW_STRIDED) && \ - fixed_input_depth == FIXED_INPUT_DEPTH && \ - fixed_depth_multiplier == FIXED_DEPTH_MULTIPLIER) { \ - row_accum_func = \ - FloatDepthwiseConvAccumRow; \ - } - -#ifdef USE_NEON - TF_NEON_USE_DEPTHWISECONV_KERNEL(true, 0, 1) - TF_NEON_USE_DEPTHWISECONV_KERNEL(true, 0, 8) - TF_NEON_USE_DEPTHWISECONV_KERNEL(true, 0, 2) - TF_NEON_USE_DEPTHWISECONV_KERNEL(false, 8, 1) - TF_NEON_USE_DEPTHWISECONV_KERNEL(false, 2, 1) -#endif // USE_NEON - -#undef TF_NEON_USE_DEPTHWISECONV_KERNEL - - // Now that we have determined row_accum_func, we can start work. - float* output_ptr = output_data; - for (int b = 0; b < batches; ++b) { - for (int out_y = 0; out_y < output_height; ++out_y) { - const int in_y_origin = (out_y * stride) - pad_height; - const int filter_y_start = std::max(0, -in_y_origin); - const int filter_y_end = - std::min(filter_height, input_height - in_y_origin); - for (int out_x_buffer_start = 0; out_x_buffer_start < output_width; - out_x_buffer_start += kOutputPixelsInAccBuffer) { - const int out_x_buffer_end = std::min( - output_width, out_x_buffer_start + kOutputPixelsInAccBuffer); - // We call a 'pixel' a group of activation that share all but the - // 'depth'/'channel' coordinate. num_output_pixels is the number of - // output pixels that we will accumulate in this loop iteration. - const int num_output_pixels = out_x_buffer_end - out_x_buffer_start; - // Initialize our local accumulator with the bias values, so we don't - // have to add them later. - DepthwiseConvInitAccBuffer(num_output_pixels, output_depth, bias_data, - acc_buffer); - // Accumulation loop. Most of the time should be spent in here. - for (int filter_y = filter_y_start; filter_y < filter_y_end; - ++filter_y) { - const int in_y = in_y_origin + filter_y; - row_accum_func(stride, input_depth, input_width, - input_data + in_y * input_dims.strides[2] + - b * input_dims.strides[3], - pad_width, depth_multiplier, filter_width, - filter_data + filter_y * filter_dims.strides[2], - out_x_buffer_start, out_x_buffer_end, output_depth, - acc_buffer); - } - // Finished accumulating. Now store to destination. - const int num_output_values = output_depth * num_output_pixels; - int i = 0; -// TODO(benoitjacob) optimized code goes here -#ifdef USE_NEON - // Handle 16 values at a time - for (; i <= num_output_values - 16; i += 16) { - float32x4_t acc[4]; - for (int k = 0; k < 4; k++) { - acc[k] = vld1q_f32(acc_buffer + i + 4 * k); - } - if (Ac == FusedActivationFunctionType::kRelu) { - for (int k = 0; k < 4; k++) { - acc[k] = vmaxq_f32(vdupq_n_f32(0.f), acc[k]); - } - } else if (Ac == FusedActivationFunctionType::kRelu6) { - for (int k = 0; k < 4; k++) { - acc[k] = vmaxq_f32(vdupq_n_f32(0.f), - vminq_f32(vdupq_n_f32(6.f), acc[k])); - } - } else if (Ac == FusedActivationFunctionType::kRelu1) { - for (int k = 0; k < 4; k++) { - acc[k] = vmaxq_f32(vdupq_n_f32(-1.f), - vminq_f32(vdupq_n_f32(1.f), acc[k])); - } - } - for (int k = 0; k < 4; k++) { - vst1q_f32(output_ptr + 4 * k, acc[k]); - } - output_ptr += 16; - } - // Handle 4 values at a time - for (; i <= num_output_values - 4; i += 4) { - float32x4_t acc = vld1q_f32(acc_buffer + i); - if (Ac == FusedActivationFunctionType::kRelu) { - acc = vmaxq_f32(vdupq_n_f32(0.f), acc); - } else if (Ac == FusedActivationFunctionType::kRelu6) { - acc = vmaxq_f32(vdupq_n_f32(0.f), vminq_f32(vdupq_n_f32(6.f), acc)); - } else if (Ac == FusedActivationFunctionType::kRelu1) { - acc = - vmaxq_f32(vdupq_n_f32(-1.f), vminq_f32(vdupq_n_f32(1.f), acc)); - } - vst1q_f32(output_ptr, acc); - output_ptr += 4; - } -#endif - // Handle leftover values, one by one. This is very slow. - for (; i < num_output_values; i++) { - float acc = acc_buffer[i]; - if (Ac == FusedActivationFunctionType::kRelu) { - acc = std::max(0.f, acc); - } else if (Ac == FusedActivationFunctionType::kRelu6) { - acc = std::max(0.f, std::min(6.f, acc)); - } else if (Ac == FusedActivationFunctionType::kRelu1) { - acc = std::max(-1.f, std::min(1.f, acc)); - } - *output_ptr++ = acc; - } - } - } - } -} - -} // end namespace neon -} // end namespace tensorflow - -#endif // TENSORFLOW_CORE_KERNELS_NEON_DEPTHWISECONV_FLOAT_H_ diff --git a/tensorflow/core/kernels/neon/neon_depthwise_conv_op.cc b/tensorflow/core/kernels/neon/neon_depthwise_conv_op.cc deleted file mode 100644 index 8e853f2338b..00000000000 --- a/tensorflow/core/kernels/neon/neon_depthwise_conv_op.cc +++ /dev/null @@ -1,204 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include - -#define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK -#include "public/gemmlowp.h" -#include "tensorflow/core/framework/bounds_check.h" -#include "tensorflow/core/framework/kernel_shape_util.h" -#include "tensorflow/core/framework/numeric_op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/neon/depthwiseconv_float.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/mem.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/padding.h" - -namespace tensorflow { - -// A version of tensorflow/core/kernels/depthwise_conv_op.cc that -// uses the neon intrinsics. -class NeonDepthwiseConv2dNativeOp : public BinaryOp { - public: - explicit NeonDepthwiseConv2dNativeOp(OpKernelConstruction* context) - : BinaryOp(context) { - OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_)); - OP_REQUIRES(context, strides_.size() == 4, - errors::InvalidArgument("Sliding window strides field must " - "specify 4 dimensions")); - OP_REQUIRES(context, strides_[1] == strides_[2], - errors::InvalidArgument( - "Current implementation only supports equal length " - "strides in the row and column dimensions.")); - OP_REQUIRES( - context, (strides_[0] == 1 && strides_[3] == 1), - errors::InvalidArgument("Current implementation does not yet support " - "strides in the batch and depth dimensions.")); - OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); - } - - void Compute(OpKernelContext* context) override { - const Tensor& input = context->input(0); - const Tensor& filter = context->input(1); - - // For 2D convolution, there should be 4 dimensions. - OP_REQUIRES(context, input.dims() == 4, - errors::InvalidArgument("input must be 4-dimensional", - input.shape().DebugString())); - OP_REQUIRES(context, filter.dims() == 4, - errors::InvalidArgument("filter must be 4-dimensional: ", - filter.shape().DebugString())); - - const int32 in_depth = input.dim_size(3); - OP_REQUIRES(context, in_depth == filter.dim_size(2), - errors::InvalidArgument( - "input and filter must have the same depth: ", in_depth, - " vs ", filter.dim_size(2))); - const int32 batch = input.dim_size(0); - const int32 input_rows = input.dim_size(1); - const int32 input_cols = input.dim_size(2); - - const int32 filter_rows = filter.dim_size(0); - const int32 filter_cols = filter.dim_size(1); - const int32 depth_multiplier = filter.dim_size(3); - - const int32 out_depth = in_depth * depth_multiplier; - - const int32 stride = strides_[1]; - - int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0; - OP_REQUIRES_OK(context, - GetWindowedOutputSize(input_rows, filter_rows, stride, - padding_, &out_rows, &pad_rows)); - OP_REQUIRES_OK(context, - GetWindowedOutputSize(input_cols, filter_cols, stride, - padding_, &out_cols, &pad_cols)); - TensorShape out_shape({batch, out_rows, out_cols, out_depth}); - OP_REQUIRES( - context, - FastBoundsCheck(out_shape.num_elements(), - std::numeric_limits::max()), - errors::InvalidArgument("Output elements too large for NEON kernel")); - - // Output tensor is of the following dimensions: - // [ in_batch, out_rows, out_cols, out_depth ] - Tensor* output = nullptr; - OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output)); - - VLOG(2) << "NeonDepthwiseConv2dNative: " - << " Input: [" << batch << ", " << input_rows << ", " << input_cols - << ", " << in_depth << "]; Filter: [" << filter_rows << ", " - << filter_cols << ", " << in_depth << ", " << depth_multiplier - << "]; stride = " << stride << ", pad_rows = " << pad_rows - << ", pad_cols = " << pad_cols << ", output: [" << batch << ", " - << out_rows << ", " << out_cols << ", " << out_depth << "]"; - - // If there is nothing to compute, return. - if (out_shape.num_elements() == 0) { - return; - } - - const float* input_ptr = input.template flat().data(); - const float* filter_ptr = filter.template flat().data(); - float* output_ptr = output->template flat().data(); - - auto input_neon_dims = ToNeonDims(input.shape()); - auto filter_neon_dims = FilterToNeonDims(filter.shape()); - auto bias_neon_dims = BiasNeonDims(filter.shape()); - - int64 bias_size = bias_neon_dims.sizes[0]; - float* bias_ptr = static_cast(port::AlignedMalloc( - bias_size * sizeof(float), Allocator::kAllocatorAlignment)); - memset(bias_ptr, 0, bias_size * sizeof(float)); - - neon::DepthwiseConv( - input_ptr, input_neon_dims, filter_ptr, filter_neon_dims, bias_ptr, - bias_neon_dims, stride, pad_cols, pad_rows, depth_multiplier, - output_ptr, ToNeonDims(out_shape)); - - port::AlignedFree(bias_ptr); - } - - private: - void SetNeonDimStrides(neon::Dims<4>* d) { - int64 stride = 1; - for (int i = 0; i < 4; ++i) { - d->strides[i] = stride; - stride *= d->sizes[i]; - } - } - - neon::Dims<4> ToNeonDims(const TensorShape& input) { - // Dims in the neon kernels are channel, x, y, batch order. - neon::Dims<4> result; - result.sizes[0] = input.dim_size(3); - result.sizes[1] = input.dim_size(2); - result.sizes[2] = input.dim_size(1); - result.sizes[3] = input.dim_size(0); - SetNeonDimStrides(&result); - return result; - } - - neon::Dims<4> FilterToNeonDims(const TensorShape& filter) { - // Dims in the neon kernels are channel, x, y, batch order. - neon::Dims<4> result; - result.sizes[0] = filter.dim_size(2) * filter.dim_size(3); - result.sizes[1] = filter.dim_size(1); - result.sizes[2] = filter.dim_size(0); - result.sizes[3] = 1; - SetNeonDimStrides(&result); - - return result; - } - - neon::Dims<4> BiasNeonDims(const TensorShape& filter) { - // Dims in the neon kernels are channel, x, y, batch order. - // Bias has only output channel set. - neon::Dims<4> result; - result.sizes[0] = - filter.dim_size(2) * filter.dim_size(3); // output channels - result.sizes[1] = 1; - result.sizes[2] = 1; - result.sizes[3] = 1; - SetNeonDimStrides(&result); - - return result; - } - - std::vector strides_; - Padding padding_; - - TF_DISALLOW_COPY_AND_ASSIGN(NeonDepthwiseConv2dNativeOp); -}; - -#define REGISTER_CPU_KERNEL(T) \ - REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNative") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .Label("neon"), \ - NeonDepthwiseConv2dNativeOp); - -TF_CALL_float(REGISTER_CPU_KERNEL); - -} // namespace tensorflow diff --git a/tensorflow/core/kernels/neon/types.h b/tensorflow/core/kernels/neon/types.h deleted file mode 100644 index 05ff1bcc6cd..00000000000 --- a/tensorflow/core/kernels/neon/types.h +++ /dev/null @@ -1,73 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CORE_KERNELS_NEON_TYPES_H_ -#define TENSORFLOW_CORE_KERNELS_NEON_TYPES_H_ - -#include "tensorflow/core/platform/logging.h" - -namespace tensorflow { -namespace neon { - -enum class FusedActivationFunctionType { kNone, kRelu6, kRelu1, kRelu }; - -template -struct Dims { - int sizes[N]; - int strides[N]; -}; - -inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) { - DCHECK(i0 >= 0 && i0 < dims.sizes[0]); - DCHECK(i1 >= 0 && i1 < dims.sizes[1]); - DCHECK(i2 >= 0 && i2 < dims.sizes[2]); - DCHECK(i3 >= 0 && i3 < dims.sizes[3]); - return i0 * dims.strides[0] + i1 * dims.strides[1] + i2 * dims.strides[2] + - i3 * dims.strides[3]; -} - -// Get array size, DCHECKing that the dim index is in range. -template -int ArraySize(const Dims& array, int index) { - DCHECK(index >= 0 && index < N); - return array.sizes[index]; -} - -// Get common array size, DCHECKing that they all agree. -template -int MatchingArraySize(const ArrayType1& array1, int index1, - const ArrayType2& array2, int index2) { - DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2)); - return ArraySize(array1, index1); -} - -template -int MatchingArraySize(const ArrayType1& array1, int index1, - const ArrayType2& array2, int index2, Args... args) { - DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2)); - return MatchingArraySize(array1, index1, args...); -} - -inline int RequiredBufferSizeForDims(const Dims<4>& dims) { - int max_offset = 0; - for (int i = 0; i < 4; i++) { - max_offset += (dims.sizes[i] - 1) * dims.strides[i]; - } - return max_offset + 1; -} - -} // end namespace neon -} // end namespace tensorflow - -#endif // TENSORFLOW_CORE_KERNELS_NEON_TYPES_H_ diff --git a/tensorflow/core/kernels/reshape_util.cc b/tensorflow/core/kernels/reshape_util.cc index 72fd0ebb871..d54c081ec90 100644 --- a/tensorflow/core/kernels/reshape_util.cc +++ b/tensorflow/core/kernels/reshape_util.cc @@ -32,15 +32,17 @@ limitations under the License. namespace tensorflow { using CPUDevice = Eigen::ThreadPoolDevice; +using GPUDevice = Eigen::GpuDevice; namespace functor { template <> struct ReshapeSparseTensorFunctor { - Status operator()(const TensorShape &input_shape, + Status operator()(OpKernelContext *context, const TensorShape &input_shape, const TensorShape &output_shape, typename TTypes::ConstMatrix input_indices, typename TTypes::Matrix output_indices) const { + (void)context; // Unused (only used in GPU implementation) const int64 input_rank = input_shape.dims(); const int64 output_rank = output_shape.dims(); const int64 nnz = input_indices.dimension(0); @@ -173,7 +175,7 @@ void ReshapeSparseTensor(OpKernelContext *context, &result_indices)); if (nnz > 0) { OP_REQUIRES_OK(context, functor::ReshapeSparseTensorFunctor()( - input_shape, output_shape, + context, input_shape, output_shape, input_indices_in.matrix(), result_indices->matrix())); } @@ -185,6 +187,10 @@ void ReshapeSparseTensor(OpKernelContext *context, const Tensor &input_shape_in, const Tensor &target_shape_in, \ int output_indices_idx, int output_shape_idx) EXPLICITLY_INSTANTIATE_FUNCTION(CPUDevice); + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +EXPLICITLY_INSTANTIATE_FUNCTION(GPUDevice); +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #undef EXPLICITLY_INSTANTIATE_FUNCTION } // namespace tensorflow diff --git a/tensorflow/core/kernels/reshape_util.h b/tensorflow/core/kernels/reshape_util.h index b3a35651e63..8584057645e 100644 --- a/tensorflow/core/kernels/reshape_util.h +++ b/tensorflow/core/kernels/reshape_util.h @@ -26,7 +26,7 @@ class OpKernelContext; class Tensor; // Reshapes the input indices and input shape to the target shape. -// Note: This template is explicitly instantiated for CPU device only. +// Note: This template is explicitly instantiated for CPU and GPU devices. template void ReshapeSparseTensor(OpKernelContext *context, const Tensor &input_indices_in, @@ -38,7 +38,7 @@ namespace functor { template struct ReshapeSparseTensorFunctor { - Status operator()(const TensorShape &input_shape, + Status operator()(OpKernelContext *context, const TensorShape &input_shape, const TensorShape &output_shape, typename TTypes::ConstMatrix input_indices, typename TTypes::Matrix output_indices) const; diff --git a/tensorflow/core/kernels/reshape_util_gpu.cu.cc b/tensorflow/core/kernels/reshape_util_gpu.cu.cc new file mode 100644 index 00000000000..80bb2122524 --- /dev/null +++ b/tensorflow/core/kernels/reshape_util_gpu.cu.cc @@ -0,0 +1,110 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#define EIGEN_USE_GPU + +#include "tensorflow/core/kernels/reshape_util.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" + +namespace tensorflow { + +using GPUDevice = Eigen::GpuDevice; + +namespace { + +template +__global__ void ReshapeSparseTensorKernel( + const Tindex nnz, const Tindex input_rank, const Tindex output_rank, + const Tindex* __restrict__ input_shape, + const Tindex* __restrict__ output_shape, + const Tindex* __restrict__ input_indices, + Tindex* __restrict__ output_indices) { + GPU_1D_KERNEL_LOOP(sparse_index, nnz) { + const Tindex* input_index = &input_indices[sparse_index * input_rank]; + Tindex* output_index = &output_indices[sparse_index * output_rank]; + int64 dense_index = 0; // int64 to avoid overflow if Tindex is int32 + // Flatten input index from slowest- to fastest-changing dimension. + for (int i = 0; i < input_rank; ++i) { + dense_index = dense_index * input_shape[i] + input_index[i]; + } + // Compute output index from fastest- to slowest-changing dimension. + for (int i = output_rank - 1; i >= 0; --i) { + Tindex output_size = output_shape[i]; + output_index[i] = dense_index % output_size; + dense_index /= output_size; + } + } +} + +} // namespace + +namespace functor { + +template <> +Status ReshapeSparseTensorFunctor::operator()( + OpKernelContext* context, const TensorShape& input_shape, + const TensorShape& output_shape, + typename TTypes::ConstMatrix input_indices, + typename TTypes::Matrix output_indices) const { + const int64 input_rank = input_shape.dims(); + const int64 output_rank = output_shape.dims(); + const int64 nnz = input_indices.dimension(0); + // We copy input_shape and output_shape to the GPU and then launch a kernel + // to compute output_indices. + Tensor input_shape_gpu_t; + TF_RETURN_IF_ERROR(context->allocate_temp(DT_INT64, TensorShape({input_rank}), + &input_shape_gpu_t)); + auto input_shape_gpu = input_shape_gpu_t.flat(); + Tensor output_shape_gpu_t; + TF_RETURN_IF_ERROR(context->allocate_temp( + DT_INT64, TensorShape({output_rank}), &output_shape_gpu_t)); + auto output_shape_gpu = output_shape_gpu_t.flat(); + se::Stream* stream = context->op_device_context()->stream(); + if (!stream) return errors::Internal("No GPU stream available."); + se::DeviceMemoryBase input_shape_gpu_mem(input_shape_gpu.data(), + input_rank * sizeof(int64)); + if (!stream + ->ThenMemcpy(&input_shape_gpu_mem, input_shape.dim_sizes().data(), + input_rank * sizeof(int64)) + .ok()) { + return errors::Internal("Failed to copy input_shape to device"); + } + se::DeviceMemoryBase output_shape_gpu_mem(output_shape_gpu.data(), + output_rank * sizeof(int64)); + if (!stream + ->ThenMemcpy(&output_shape_gpu_mem, output_shape.dim_sizes().data(), + output_rank * sizeof(int64)) + .ok()) { + return errors::Internal("Failed to copy output_shape to device"); + } + const GPUDevice& device = context->template eigen_device(); + auto config = GetGpuLaunchConfig(nnz, device); + return GpuLaunchKernel(ReshapeSparseTensorKernel, config.block_count, + config.thread_per_block, 0, device.stream(), nnz, + /*input_rank=*/input_rank, + /*output_rank=*/output_rank, + /*input_shape=*/input_shape_gpu.data(), + /*output_shape=*/output_shape_gpu.data(), + /*input_indices=*/input_indices.data(), + /*output_indices=*/output_indices.data()); +} + +} // namespace functor + +} // namespace tensorflow + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/sparse_reshape_op.cc b/tensorflow/core/kernels/sparse_reshape_op.cc index 472a7a270a5..13e9010dcbb 100644 --- a/tensorflow/core/kernels/sparse_reshape_op.cc +++ b/tensorflow/core/kernels/sparse_reshape_op.cc @@ -30,6 +30,7 @@ limitations under the License. namespace tensorflow { using CPUDevice = Eigen::ThreadPoolDevice; +using GPUDevice = Eigen::GpuDevice; template class SparseReshapeOp : public OpKernel { @@ -46,4 +47,13 @@ class SparseReshapeOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("SparseReshape").Device(DEVICE_CPU), SparseReshapeOp) +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +REGISTER_KERNEL_BUILDER(Name("SparseReshape") + .Device(DEVICE_GPU) + .HostMemory("input_shape") + .HostMemory("new_shape") + .HostMemory("output_shape"), + SparseReshapeOp) +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + } // namespace tensorflow diff --git a/tensorflow/core/lib/core/status_test.cc b/tensorflow/core/lib/core/status_test.cc index 32147889c53..eb117a49d28 100644 --- a/tensorflow/core/lib/core/status_test.cc +++ b/tensorflow/core/lib/core/status_test.cc @@ -202,6 +202,54 @@ TEST(Status, ErasePayloadRemovesIt) { ASSERT_EQ(s.GetPayload("Error key"), tensorflow::StringPiece()); } +TEST(Status, GetAllPayloads) { + Status s_error(error::INTERNAL, "Error message"); + s_error.SetPayload("Error key", "foo"); + auto payloads_error_status = s_error.GetAllPayloads(); + ASSERT_EQ(payloads_error_status.size(), 1); + ASSERT_EQ(payloads_error_status["Error key"], "foo"); + + Status s_ok = Status(); + auto payloads_ok_status = s_ok.GetAllPayloads(); + ASSERT_TRUE(payloads_ok_status.empty()); +} + +TEST(Status, OKStatusReplaceAllPayloadsFromErrorStatus) { + // An OK status will should not change after ReplaceAllPayloads() calls. + Status s_error(error::INTERNAL, "Error message"); + s_error.SetPayload("Error key", "foo"); + Status s_ok = Status(); + + s_ok.ReplaceAllPayloads(s_error.GetAllPayloads()); + auto payloads_ok_status = s_ok.GetAllPayloads(); + ASSERT_TRUE(payloads_ok_status.empty()); +} + +TEST(Status, ErrorStatusReplaceAllPayloadsFromOKStatus) { + // An ReplaceAllPayloads() call should not take effect from empty inputs. + Status s_error(error::INTERNAL, "Error message"); + s_error.SetPayload("Error key", "foo"); + Status s_ok = Status(); + + s_error.ReplaceAllPayloads(s_ok.GetAllPayloads()); + ASSERT_EQ(s_error.GetPayload("Error key"), "foo"); +} + +TEST(Status, ErrorStatusReplaceAllPayloadsFromErrorStatus) { + Status s_error1(error::INTERNAL, "Error message"); + s_error1.SetPayload("Error key 1", "foo"); + s_error1.SetPayload("Error key 2", "bar"); + Status s_error2(error::INTERNAL, "Error message"); + s_error2.SetPayload("Error key", "bar"); + ASSERT_EQ(s_error2.GetPayload("Error key"), "bar"); + + s_error2.ReplaceAllPayloads(s_error1.GetAllPayloads()); + ASSERT_EQ(s_error2.GetPayload("Error key 1"), "foo"); + ASSERT_EQ(s_error2.GetPayload("Error key 2"), "bar"); + auto payloads_error_status = s_error2.GetAllPayloads(); + ASSERT_EQ(payloads_error_status.size(), 2); +} + static void BM_TF_CHECK_OK(::testing::benchmark::State& state) { tensorflow::Status s = (state.max_iterations < 0) ? errors::InvalidArgument("Invalid") diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl index bef34eec8da..13969310358 100644 --- a/tensorflow/core/platform/default/build_config.bzl +++ b/tensorflow/core/platform/default/build_config.bzl @@ -152,7 +152,6 @@ def cc_proto_library( cc_libs = [], include = None, protoc = "@com_google_protobuf//:protoc", - internal_bootstrap_hack = False, use_grpc_plugin = False, use_grpc_namespace = False, make_default_target_header_only = False, @@ -169,10 +168,6 @@ def cc_proto_library( cc_library. include: a string indicating the include path of the .proto files. protoc: the label of the protocol compiler to generate the sources. - internal_bootstrap_hack: a flag indicate the cc_proto_library is used only - for bootstraping. When it is set to True, no files will be generated. - The rule will simply be a provider for .proto files, so that other - cc_proto_library can depend on it. use_grpc_plugin: a flag to indicate whether to call the grpc C++ plugin when processing the proto files. use_grpc_namespace: the namespace for the grpc services. @@ -194,25 +189,6 @@ def cc_proto_library( if protolib_name == None: protolib_name = name - if internal_bootstrap_hack: - # For pre-checked-in generated files, we add the internal_bootstrap_hack - # which will skip the codegen action. - proto_gen( - name = protolib_name + "_genproto", - srcs = srcs, - includes = includes, - protoc = protoc, - visibility = ["//visibility:public"], - deps = [s + "_genproto" for s in all_protolib_deps], - ) - - # An empty cc_library to make rule dependency consistent. - native.cc_library( - name = name, - **kargs - ) - return - grpc_cpp_plugin = None plugin_options = [] if use_grpc_plugin: @@ -272,10 +248,10 @@ def cc_proto_library( ) native.cc_library( name = header_only_name, + hdrs = gen_hdrs, deps = [ "@com_google_protobuf//:protobuf_headers", ] + header_only_deps + if_static([impl_name]), - hdrs = gen_hdrs, **kargs ) diff --git a/tensorflow/core/platform/default/build_config/BUILD b/tensorflow/core/platform/default/build_config/BUILD index 83eadf2c460..6cc7e731430 100644 --- a/tensorflow/core/platform/default/build_config/BUILD +++ b/tensorflow/core/platform/default/build_config/BUILD @@ -27,6 +27,20 @@ cc_library( deps = [], ) +cc_library( + name = "_empty_lib", + visibility = ["//visibility:private"], +) + +alias( + name = "_cuda_runtime", + actual = select({ + "//tensorflow:framework_shared_object": ":_empty_lib", + "//conditions:default": "//tensorflow/stream_executor/cuda:all_runtime", + }), + visibility = ["//visibility:private"], +) + tf_cuda_library( name = "stream_executor", cuda_deps = [ @@ -44,14 +58,14 @@ tf_cuda_library( "//tensorflow/stream_executor/platform:dso_loader", "//tensorflow/stream_executor/rocm:rocm_platform_id", ] + select({ - "@local_config_cuda//cuda:darwin": ["IOKit"], + "//tensorflow:macos": ["IOKit"], "//conditions:default": [], }) + select({ - "//tensorflow:using_cuda_clang": ["//tensorflow/stream_executor/cuda:all_runtime"], - "//tensorflow:using_cuda_nvcc": ["//tensorflow/stream_executor/cuda:all_runtime"], - "//tensorflow:using_cuda_clang_with_dynamic_build": [], - "//tensorflow:using_cuda_nvcc_with_dynamic_build": [], - "//tensorflow:using_rocm_hipcc": ["//tensorflow/stream_executor/rocm:all_runtime"], + "//tensorflow:using_cuda_clang": [":_cuda_runtime"], + "//tensorflow:using_cuda_nvcc": [":_cuda_runtime"], + "//tensorflow:using_rocm_hipcc": [ + "//tensorflow/stream_executor/rocm:all_runtime", + ], "//conditions:default": [], }), ) @@ -67,7 +81,7 @@ cc_library( ":cuda", ], }) + select({ - "@local_config_cuda//cuda:darwin": ["IOKit"], + "//tensorflow:macos": ["IOKit"], "//conditions:default": [], }), ) @@ -194,7 +208,7 @@ cc_library( "@local_config_cuda//cuda:cudart", ], linkopts = select({ - "@local_config_cuda//cuda:darwin": [ + "//tensorflow:macos": [ "-Wl,-rpath,../local_config_cuda/cuda/lib", "-Wl,-rpath,../local_config_cuda/cuda/extras/CUPTI/lib", ], diff --git a/tensorflow/core/platform/ram_file_system.h b/tensorflow/core/platform/ram_file_system.h index 8a6092beb34..e8cd0b45aac 100644 --- a/tensorflow/core/platform/ram_file_system.h +++ b/tensorflow/core/platform/ram_file_system.h @@ -109,7 +109,7 @@ class RamFileSystem : public FileSystem { const std::string& fname_, TransactionToken* token, std::unique_ptr* result) override { mutex_lock m(mu_); - auto fname = StripPrefix(fname_); + auto fname = StripRamFsPrefix(fname_); if (fs_.find(fname) == fs_.end()) { return errors::NotFound(""); @@ -125,7 +125,7 @@ class RamFileSystem : public FileSystem { Status NewWritableFile(const std::string& fname_, TransactionToken* token, std::unique_ptr* result) override { mutex_lock m(mu_); - auto fname = StripPrefix(fname_); + auto fname = StripRamFsPrefix(fname_); if (fs_.find(fname) == fs_.end()) { fs_[fname] = std::make_shared(); @@ -141,7 +141,7 @@ class RamFileSystem : public FileSystem { Status NewAppendableFile(const std::string& fname_, TransactionToken* token, std::unique_ptr* result) override { mutex_lock m(mu_); - auto fname = StripPrefix(fname_); + auto fname = StripRamFsPrefix(fname_); if (fs_.find(fname) == fs_.end()) { fs_[fname] = std::make_shared(); @@ -163,7 +163,7 @@ class RamFileSystem : public FileSystem { Status FileExists(const std::string& fname_, TransactionToken* token) override { FileStatistics stat; - auto fname = StripPrefix(fname_); + auto fname = StripRamFsPrefix(fname_); return Stat(fname, token, &stat); } @@ -171,14 +171,14 @@ class RamFileSystem : public FileSystem { Status GetChildren(const std::string& dir_, TransactionToken* token, std::vector* result) override { mutex_lock m(mu_); - auto dir = StripPrefix(dir_); + auto dir = StripRamFsPrefix(dir_); auto it = fs_.lower_bound(dir); - while (it != fs_.end() && absl::StartsWith(it->first, dir)) { - auto filename = absl::StripPrefix(absl::StripPrefix(it->first, dir), "/"); + while (it != fs_.end() && StartsWith(it->first, dir)) { + auto filename = StripPrefix(StripPrefix(it->first, dir), "/"); // It is not either (a) the parent directory itself or (b) a subdirectory if (!filename.empty() && filename.find("/") == std::string::npos) { - result->push_back(std::string(filename.data(), filename.size())); + result->push_back(filename); } ++it; } @@ -189,12 +189,12 @@ class RamFileSystem : public FileSystem { Status GetMatchingPaths(const std::string& pattern_, TransactionToken* token, std::vector* results) override { mutex_lock m(mu_); - auto pattern = StripPrefix(pattern_); + auto pattern = StripRamFsPrefix(pattern_); Env* env = Env::Default(); for (auto it = fs_.begin(); it != fs_.end(); ++it) { if (env->MatchPath(it->first, pattern)) { - results->push_back(absl::StrCat("ram://", it->first)); + results->push_back("ram://" + it->first); } } return Status::OK(); @@ -203,10 +203,10 @@ class RamFileSystem : public FileSystem { Status Stat(const std::string& fname_, TransactionToken* token, FileStatistics* stat) override { mutex_lock m(mu_); - auto fname = StripPrefix(fname_); + auto fname = StripRamFsPrefix(fname_); auto it = fs_.lower_bound(fname); - if (it == fs_.end() || !absl::StartsWith(it->first, fname)) { + if (it == fs_.end() || !StartsWith(it->first, fname)) { return errors::NotFound(""); } @@ -226,7 +226,7 @@ class RamFileSystem : public FileSystem { Status DeleteFile(const std::string& fname_, TransactionToken* token) override { mutex_lock m(mu_); - auto fname = StripPrefix(fname_); + auto fname = StripRamFsPrefix(fname_); if (fs_.find(fname) != fs_.end()) { fs_.erase(fname); @@ -239,7 +239,7 @@ class RamFileSystem : public FileSystem { Status CreateDir(const std::string& dirname_, TransactionToken* token) override { mutex_lock m(mu_); - auto dirname = StripPrefix(dirname_); + auto dirname = StripRamFsPrefix(dirname_); auto it = fs_.find(dirname); if (it != fs_.end() && it->second != nullptr) { @@ -253,9 +253,9 @@ class RamFileSystem : public FileSystem { Status RecursivelyCreateDir(const std::string& dirname_, TransactionToken* token) override { - auto dirname = StripPrefix(dirname_); + auto dirname = StripRamFsPrefix(dirname_); - std::vector dirs = absl::StrSplit(dirname, '/'); + std::vector dirs = StrSplit(dirname, "/"); Status last_status; std::string dir = dirs[0]; last_status = CreateDir(dir, token); @@ -270,7 +270,7 @@ class RamFileSystem : public FileSystem { Status DeleteDir(const std::string& dirname_, TransactionToken* token) override { mutex_lock m(mu_); - auto dirname = StripPrefix(dirname_); + auto dirname = StripRamFsPrefix(dirname_); auto it = fs_.find(dirname); if (it == fs_.end()) { @@ -287,7 +287,7 @@ class RamFileSystem : public FileSystem { Status GetFileSize(const std::string& fname_, TransactionToken* token, uint64* file_size) override { mutex_lock m(mu_); - auto fname = StripPrefix(fname_); + auto fname = StripRamFsPrefix(fname_); if (fs_.find(fname) != fs_.end()) { if (fs_[fname] == nullptr) { @@ -302,8 +302,8 @@ class RamFileSystem : public FileSystem { Status RenameFile(const std::string& src_, const std::string& target_, TransactionToken* token) override { mutex_lock m(mu_); - auto src = StripPrefix(src_); - auto target = StripPrefix(target_); + auto src = StripRamFsPrefix(src_); + auto target = StripRamFsPrefix(target_); if (fs_.find(src) != fs_.end()) { fs_[target] = fs_[src]; @@ -320,9 +320,34 @@ class RamFileSystem : public FileSystem { mutex mu_; std::map> fs_; - std::string StripPrefix(std::string name) { - auto sv = absl::StripSuffix(absl::StripPrefix(name, "ram://"), "/"); - return std::string(sv.data(), sv.size()); + std::vector StrSplit(std::string s, std::string delim) { + std::vector ret; + size_t curr_pos = 0; + while ((curr_pos = s.find(delim)) != std::string::npos) { + ret.push_back(s.substr(0, curr_pos)); + s.erase(0, curr_pos + delim.size()); + } + ret.push_back(s); + return ret; + } + + bool StartsWith(std::string s, std::string prefix) { + return s.find(prefix) == 0; + } + + string StripPrefix(std::string s, std::string prefix) { + if (s.find(prefix) == 0) { + return s.erase(0, prefix.size()); + } + return s; + } + + string StripRamFsPrefix(std::string name) { + std::string s = StripPrefix(name, "ram://"); + if (*(s.rbegin()) == '/') { + s.pop_back(); + } + return s; } }; diff --git a/tensorflow/core/platform/status.cc b/tensorflow/core/platform/status.cc index e5cc422b08b..b960b285e65 100644 --- a/tensorflow/core/platform/status.cc +++ b/tensorflow/core/platform/status.cc @@ -222,6 +222,19 @@ bool Status::ErasePayload(tensorflow::StringPiece type_url) { return true; } +const std::unordered_map Status::GetAllPayloads() + const { + if (ok()) return {}; + return state_->payloads; +} + +void Status::ReplaceAllPayloads( + const std::unordered_map& payloads) { + if (ok() || payloads.empty()) return; + if (state_ == nullptr) state_ = std::make_unique(); + state_->payloads = payloads; +} + std::ostream& operator<<(std::ostream& os, const Status& x) { os << x.ToString(); return os; diff --git a/tensorflow/core/platform/status.h b/tensorflow/core/platform/status.h index a61b6e68b70..8ad40f52d6c 100644 --- a/tensorflow/core/platform/status.h +++ b/tensorflow/core/platform/status.h @@ -113,7 +113,7 @@ class Status { // Sets the payload for a non-ok status using a `type_url` key, overwriting // any existing payload for that `type_url`. // - // NOTE: This function does nothing if the Status is ok. + // This function does nothing if the Status is ok. void SetPayload(tensorflow::StringPiece type_url, tensorflow::StringPiece payload); @@ -121,6 +121,15 @@ class Status { // the payload was present. bool ErasePayload(tensorflow::StringPiece type_url); + // Returns all the payload information. + // Returns an empty result if status is ok. + const std::unordered_map GetAllPayloads() const; + + // Copies all the payloads using the input and discards existing payloads. + // Does nothing if status is ok or 'payloads' is empty. + void ReplaceAllPayloads( + const std::unordered_map& payloads); + private: static const std::string& empty_string(); static const std::vector& empty_stack_trace(); diff --git a/tensorflow/core/profiler/convert/post_process_single_host_xplane.cc b/tensorflow/core/profiler/convert/post_process_single_host_xplane.cc index accbe4ecfba..c089a6ab884 100644 --- a/tensorflow/core/profiler/convert/post_process_single_host_xplane.cc +++ b/tensorflow/core/profiler/convert/post_process_single_host_xplane.cc @@ -28,7 +28,8 @@ void MergeHostPlanesAndSortLines(XSpace* space) { XPlane* host_plane = FindOrAddMutablePlaneWithName(space, kHostThreadsPlaneName); std::vector additional_host_planes = FindPlanesWithNames( - *space, {kCuptiDriverApiPlaneName, kPythonTracerPlaneName}); + *space, + {kTpuRuntimePlaneName, kCuptiDriverApiPlaneName, kPythonTracerPlaneName}); if (!additional_host_planes.empty()) { MergePlanes(additional_host_planes, host_plane); RemovePlanes(space, additional_host_planes); diff --git a/tensorflow/core/profiler/internal/gpu/BUILD b/tensorflow/core/profiler/internal/gpu/BUILD index be65a0317b8..00105cdb284 100644 --- a/tensorflow/core/profiler/internal/gpu/BUILD +++ b/tensorflow/core/profiler/internal/gpu/BUILD @@ -1,6 +1,6 @@ +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load( "//tensorflow:tensorflow.bzl", - "if_cuda_is_configured_compat", "tf_copts", "tf_cuda_library", ) @@ -90,7 +90,7 @@ tf_cuda_cc_test( tf_cuda_library( name = "cupti_interface", - hdrs = if_cuda_is_configured_compat(["cupti_interface.h"]), + hdrs = if_cuda(["cupti_interface.h"]), copts = tf_profiler_copts() + tf_copts(), visibility = ["//visibility:public"], deps = [ @@ -107,8 +107,8 @@ tf_cuda_library( # that the wrapper is about the only direct user. tf_cuda_library( name = "cupti_wrapper", - srcs = if_cuda_is_configured_compat(["cupti_wrapper.cc"]), - hdrs = if_cuda_is_configured_compat(["cupti_wrapper.h"]), + srcs = if_cuda(["cupti_wrapper.cc"]), + hdrs = if_cuda(["cupti_wrapper.h"]), copts = tf_profiler_copts() + tf_copts(), linkstatic = 1, visibility = ["//visibility:public"], @@ -119,8 +119,8 @@ tf_cuda_library( tf_cuda_library( name = "cupti_tracer", - srcs = if_cuda_is_configured_compat(["cupti_tracer.cc"]), - hdrs = if_cuda_is_configured_compat(["cupti_tracer.h"]), + srcs = if_cuda(["cupti_tracer.cc"]), + hdrs = if_cuda(["cupti_tracer.h"]), copts = tf_profiler_copts() + tf_copts(), visibility = ["//visibility:public"], deps = [ @@ -140,8 +140,8 @@ tf_cuda_library( tf_cuda_library( name = "nvtx_utils", - srcs = if_cuda_is_configured_compat(["nvtx_utils.cc"]), - hdrs = if_cuda_is_configured_compat(["nvtx_utils.h"]), + srcs = if_cuda(["nvtx_utils.cc"]), + hdrs = if_cuda(["nvtx_utils.h"]), copts = tf_profiler_copts() + tf_copts(), deps = [ "//tensorflow/core:lib", @@ -150,8 +150,8 @@ tf_cuda_library( tf_cuda_library( name = "cupti_collector", - srcs = if_cuda_is_configured_compat(["cupti_collector.cc"]), - hdrs = if_cuda_is_configured_compat(["cupti_collector.h"]), + srcs = if_cuda(["cupti_collector.cc"]), + hdrs = if_cuda(["cupti_collector.h"]), copts = tf_profiler_copts() + tf_copts(), visibility = ["//visibility:public"], deps = [ @@ -192,7 +192,7 @@ cc_library( tf_cuda_library( name = "cupti_utils", - srcs = if_cuda_is_configured_compat(["cupti_utils.cc"]), + srcs = if_cuda(["cupti_utils.cc"]), copts = tf_profiler_copts() + tf_copts(), cuda_deps = [ ":cupti_interface", diff --git a/tensorflow/core/profiler/lib/profiler_session.cc b/tensorflow/core/profiler/lib/profiler_session.cc index 68ca7cc6f57..0ccb1a8f03f 100644 --- a/tensorflow/core/profiler/lib/profiler_session.cc +++ b/tensorflow/core/profiler/lib/profiler_session.cc @@ -118,7 +118,7 @@ ProfilerSession::ProfilerSession(ProfileOptions options) options_(std::move(options)) { #if !defined(IS_MOBILE_PLATFORM) if (!active_) { - status_ = tensorflow::Status(error::UNAVAILABLE, + status_ = tensorflow::Status(error::ALREADY_EXISTS, "Another profiler session is active."); return; } diff --git a/tensorflow/core/profiler/utils/group_events.cc b/tensorflow/core/profiler/utils/group_events.cc index fdea87c66e6..784e30e2e4f 100644 --- a/tensorflow/core/profiler/utils/group_events.cc +++ b/tensorflow/core/profiler/utils/group_events.cc @@ -71,13 +71,6 @@ int64 GetEventType(bool is_host_plane, const EventNode& event) { // KernelExecute event types. return *kernel_event_type; } else { - absl::string_view name = event.GetEventVisitor().Name(); - // Legacy event names appended with arguments. - if (absl::StartsWith(name, "BatchingSessionRun")) { - return HostEventType::kBatchingSessionRun; - } else if (absl::StartsWith(name, "ProcessBatch")) { - return HostEventType::kProcessBatch; - } return HostEventType::kUnknownHostEventType; } } diff --git a/tensorflow/core/profiler/utils/xplane_schema.cc b/tensorflow/core/profiler/utils/xplane_schema.cc index c463cc94ae0..94b8177f32c 100644 --- a/tensorflow/core/profiler/utils/xplane_schema.cc +++ b/tensorflow/core/profiler/utils/xplane_schema.cc @@ -29,6 +29,7 @@ namespace profiler { const absl::string_view kHostThreadsPlaneName = "/host:CPU"; const absl::string_view kGpuPlanePrefix = "/device:GPU:"; const absl::string_view kTpuPlanePrefix = "/device:TPU:"; +const absl::string_view kTpuRuntimePlaneName = "/host:TPU-runtime"; const absl::string_view kCuptiDriverApiPlaneName = "/host:CUPTI"; const absl::string_view kMetadataPlaneName = "/host:metadata"; const absl::string_view kTFStreamzPlaneName = "/host:tfstreamz"; diff --git a/tensorflow/core/profiler/utils/xplane_schema.h b/tensorflow/core/profiler/utils/xplane_schema.h index 35b61ac51a6..5dd0817b6a9 100644 --- a/tensorflow/core/profiler/utils/xplane_schema.h +++ b/tensorflow/core/profiler/utils/xplane_schema.h @@ -33,6 +33,8 @@ TF_CONST_INIT extern const absl::string_view kHostThreadsPlaneName; TF_CONST_INIT extern const absl::string_view kGpuPlanePrefix; // Name prefix of XPlane that contains TPU events. TF_CONST_INIT extern const absl::string_view kTpuPlanePrefix; +// Name prefix of XPlane that contains TPU runtime events. +TF_CONST_INIT extern const absl::string_view kTpuRuntimePlaneName; // Name of XPlane that contains CUPTI driver API generated events. TF_CONST_INIT extern const absl::string_view kCuptiDriverApiPlaneName; // Name of XPlane that contains profile metadata such as XLA debug info. diff --git a/tensorflow/core/protobuf/tpu/compile_metadata.proto b/tensorflow/core/protobuf/tpu/compile_metadata.proto index 3d90cfb1cbf..5c21f078299 100644 --- a/tensorflow/core/protobuf/tpu/compile_metadata.proto +++ b/tensorflow/core/protobuf/tpu/compile_metadata.proto @@ -62,6 +62,10 @@ message TPUCompileMetadataProto { // Name of the node that the arg comes from. string name = 10; + + // Whether to use XLA collectives to broadcast this parameter to all + // replicas, instead of using TensorFlow Send/Recv among the tasks. + bool requires_xla_broadcast = 11; } repeated Arg args = 1; @@ -116,7 +120,5 @@ message TPUCompileMetadataProto { // requested. bool use_spmd_for_xla_partitioning = 15; - // Enables use of XLA collectives for broadcast of replicated parameters to - // all replicas, instead of using TensorFlow Send/Recv. - bool broadcast_replicated_parameters_via_collectives = 16; + reserved 16; // Was broadcast_replicated_parameters_via_collectives } diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index e4a925fa812..a831587b6b0 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 689 // Updated: 2021/2/26 +#define TF_GRAPH_DEF_VERSION 694 // Updated: 2021/3/3 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD index fee4f06e55d..b2c8f8b9c68 100644 --- a/tensorflow/core/tpu/BUILD +++ b/tensorflow/core/tpu/BUILD @@ -304,6 +304,7 @@ cc_library( "//tensorflow/core/profiler/protobuf:xplane_proto_cc", "//tensorflow/stream_executor/tpu:c_api_decl", "//tensorflow/stream_executor/tpu:proto_helper", + "@com_google_absl//absl/types:optional", ], alwayslink = True, ) diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc index a183c3dc522..13ba2ce0fc5 100644 --- a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc +++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc @@ -2299,6 +2299,42 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores( return Status::OK(); } +namespace { + +bool XlaBroadcastTypeSupported(const DataType dtype) { + return (dtype == DT_FLOAT || dtype == DT_BFLOAT16 || dtype == DT_INT32 || + dtype == DT_BOOL); +} + +bool XlaBroadcastKindSupported( + const DistributedTPURewritePass::ParameterInfo& params_info, + int param_num) { + // NOTE: This is intended to cover non-sharded data parallel variables, for + // training only. . Is it correct to just check if the arg_type is + // DT_RESOURCE? + return params_info.IsVariableArg(param_num) && + !(params_info.IsPerReplicaArg(param_num) || + params_info.IsDistributedArg(param_num) || + params_info.IsBroadcastArg(param_num) || + params_info.IsConstantArg(param_num)); +} + +bool EnableXlaParamBroadcast( + bool enable_xla_param_broadcast, + const DistributedTPURewritePass::ParameterInfo& params_info, int param_num, + DataType dtype, int num_cores_per_replica) { + // Conditions necessary to use XLA collectives for arg broadcast: + // 1. Globally enabled via enable_xla_param_broadcast. + // 2. DataType must be supported. + // 3. Parameter must be a variable, and not distributed or broadcasted. + // 4. Model parallelism is not currently supported. + return enable_xla_param_broadcast && XlaBroadcastTypeSupported(dtype) && + XlaBroadcastKindSupported(params_info, param_num) && + (num_cores_per_replica == 1); +} + +} // namespace + // Builds a TPUCompile node that compiles the bodies of the function call // `nodes`. Status DistributedTPURewritePass::BuildCompileNode( @@ -2315,7 +2351,7 @@ Status DistributedTPURewritePass::BuildCompileNode( int num_cores_per_replica, const string& compile_device, const xla::DeviceAssignment* xla_device_assignment, const std::vector& dynamic_shape_nodes, Graph* graph, - Node** compile_node, int64 autotuner_thresh) { + Node** compile_node, int64 autotuner_thresh, int num_tasks) { VLOG(1) << "BuildCompileNode"; tpu::TPUCompileMetadataProto proto; @@ -2334,8 +2370,6 @@ Status DistributedTPURewritePass::BuildCompileNode( return s.type() == xla::OpSharding::MAXIMAL; }); proto.set_use_spmd_for_xla_partitioning(use_spmd); - proto.set_broadcast_replicated_parameters_via_collectives( - enable_xla_param_broadcast_); // Get and fill padding map. if (replicate_node != nullptr) { @@ -2383,6 +2417,15 @@ Status DistributedTPURewritePass::BuildCompileNode( arg->set_kind(tpu::TPUCompileMetadataProto::Arg::PARAMETER); } } + + // Use XLA collective primitives to distribute variables to all replicas, + // for multi-host systems. + arg->set_requires_xla_broadcast( + num_tasks > 1 && + EnableXlaParamBroadcast(enable_xla_param_broadcast_, params_info, i, + arg_shape.handle_type /*arg.dtype?*/, + num_cores_per_replica)); + // As long as the argument is not a per-replica one, it should have the same // value for all replicas. For clarity, we keep the (redundant) checks for // variable, broadcast and constant types, to prevent bugs in case new types @@ -2686,20 +2729,39 @@ Status DistributedTPURewritePass::BuildVariableWrites( namespace { // Creates nodes for zero-initialized dummy arguments for TPUExecute nodes. -xla::StatusOr CreatePerHostDummyArgs(const InferredShape& raw_var_shape, - const string& host_cpu_device, - Node* var_read, - absl::string_view name_prefix, - Graph* graph) { +xla::StatusOr MaybeCreatePerHostDummyArgs( + const std::vector& arg_shapes, const string& host_cpu_device, + const DistributedTPURewritePass::ParameterInfo& params_info, Node* var_read, + int var_num, int num_cores_per_replica, Graph* graph) { Status status; - DataType dtype; - TF_RETURN_IF_ERROR(GetNodeAttr(var_read->def(), "dtype", &dtype)); - if (!(dtype == DT_FLOAT || dtype == DT_BFLOAT16 || dtype == DT_INT32 || - dtype == DT_BOOL)) { + if (num_cores_per_replica > 1) { + LOG_FIRST_N(WARNING, 1) << "XLA parameter broadcast is not supported for " + "model-partitioned parameters. Falling back to " + "non-broadcast mode for all parameters."; return var_read; } + DataType dtype; + TF_RETURN_IF_ERROR(GetNodeAttr(var_read->def(), "dtype", &dtype)); + + DeviceNameUtils::ParsedName parsed_device; + TF_RET_CHECK(DeviceNameUtils::ParseFullName(host_cpu_device, &parsed_device)); + TF_RET_CHECK(parsed_device.has_task); + + // Task 0 behaves as the primary task, where variables are assigned. Use the + // variable reads as arguments to TPUExecute. + // For other tasks, create dummies if the graph meets preconditions. + int64 orig_arg_num = var_num + params_info.NumPerReplicaArgs() + + params_info.NumDistributedArgs() + + params_info.NumBroadcastArgs(); + if (parsed_device.task == 0 || + !EnableXlaParamBroadcast(/*enable_xla_param_broadcast=*/true, params_info, + orig_arg_num, dtype, num_cores_per_replica)) { + return var_read; + } + + auto raw_var_shape = arg_shapes[orig_arg_num]; TensorShape var_shape; if (!raw_var_shape.handle_shape.AsTensorShape(&var_shape) && !raw_var_shape.shape.AsTensorShape(&var_shape)) { @@ -2707,6 +2769,8 @@ xla::StatusOr CreatePerHostDummyArgs(const InferredShape& raw_var_shape, } // Const - shape_as_tensor + const std::string name_prefix = strings::StrCat( + var_read->name(), absl::StrFormat("/dummy_%d", parsed_device.task)); NodeDef shape_tensor_def; shape_tensor_def.set_op("Const"); shape_tensor_def.set_name(graph->NewName( @@ -2801,12 +2865,6 @@ xla::StatusOr CreateOrGetPerHostVariableCopy( return it->second[var_index]; } - // Variable replication relies on identification of a master. - DeviceNameUtils::ParsedName parsed_device; - TF_RET_CHECK(DeviceNameUtils::ParseFullName(host_cpu_device, &parsed_device)); - TF_RET_CHECK(parsed_device.has_task); - VLOG(1) << "Creating per-host IdentityN node for task " << parsed_device.task; - DataTypeVector dtypes; // Per-variable data source for TPUExecute. std::vector index_mapping; @@ -2814,8 +2872,9 @@ xla::StatusOr CreateOrGetPerHostVariableCopy( dtypes.reserve(variable_reads.size()); for (int64 i = 0; i < variable_reads.size(); ++i) { Node* read = variable_reads[i]; - int64 orig_arg_num = - i + params_info.NumPerReplicaArgs() + params_info.NumBroadcastArgs(); + int64 orig_arg_num = i + params_info.NumPerReplicaArgs() + + params_info.NumDistributedArgs() + + params_info.NumBroadcastArgs(); if (arg_shardings[orig_arg_num].type() != xla::OpSharding::OTHER) { // We haven't built the IdentityN node yet, so temporarily use nullptr. index_mapping.push_back( @@ -2843,34 +2902,18 @@ xla::StatusOr CreateOrGetPerHostVariableCopy( if (index_mapping[i].node == nullptr) { // Fill index_mapping with the actual IdentityN node. index_mapping[i].node = id_node; - if (parsed_device.task == 0 || !enable_xla_param_broadcast) { - // XLA broadcast mode is not enabled, so use the variable reads as args - // to TPUExecuteOp. For task 0, variable reads are always used - // regardless of XLA broadcast. - + if (!enable_xla_param_broadcast) { // Add the variable read edge to id_node. graph->AddEdge(variable_reads[i], 0, id_node, index_mapping[i].index); } else { - // XLA broadcast mode is enabled. Create zero-valued dummy tensors to - // use as variable args in the TPUExecuteOp. - int64 orig_arg_num = i + params_info.NumPerReplicaArgs() + - params_info.NumBroadcastArgs(); - if (num_cores_per_replica > 1) { - LOG(WARNING) << "XLA parameter broadcast is only supported for " - "replicated parameters. Falling back to " - "non-broadcast mode for the parameter associated " - "with the following variable read: " - << variable_reads[i]->name(); - graph->AddEdge(variable_reads[i], 0, id_node, index_mapping[i].index); - continue; - } - string dummy_name = - strings::StrCat(variable_reads[i]->name(), - absl::StrFormat("/dummy_%d", parsed_device.task)); + // XLA param broadcast mode is enabled. Create zero-valued dummy + // tensors to use as variable args in the TPUExecuteOp, instead of + // original variable reads. TF_ASSIGN_OR_RETURN( Node * var_read, - CreatePerHostDummyArgs(arg_shapes[orig_arg_num], host_cpu_device, - variable_reads[i], dummy_name, graph)); + MaybeCreatePerHostDummyArgs(arg_shapes, host_cpu_device, + params_info, variable_reads[i], i, + num_cores_per_replica, graph)); graph->AddEdge(var_read, 0, id_node, index_mapping[i].index); } } @@ -4323,7 +4366,7 @@ DistributedTPURewritePass::BuildCompilationStatusReturnNodes( arg_types, guaranteed_constant_nodes, session_handle, arg_sharding, arg_fast_mem, arg_names, retval_sharding, num_cores_per_replica, /*compile_device=*/tpu_compilation_device, xla_device_assignment.get(), - dynamic_shape_nodes, graph, &compile_node, autotuner_thresh)); + dynamic_shape_nodes, graph, &compile_node, autotuner_thresh, num_tasks)); // Compilation must be sequenced after the control node if the TPU computation // in a control-flow construct, such as a loop. diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h index fd755bcefbc..acbe4e00963 100644 --- a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h +++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h @@ -362,7 +362,7 @@ class DistributedTPURewritePass : public GraphOptimizationPass { int num_cores_per_replica, const string& compile_device, const xla::DeviceAssignment* xla_device_assignment, const std::vector& dynamic_shape_nodes, Graph* graph, - Node** compile_node, int64 autotuner_thresh); + Node** compile_node, int64 autotuner_thresh, int num_tasks); // Builds a TPUCompileSucceededAssert node that verifies that compilation // succeeded and replaces the TPUCompilationStatus node in the graph. diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD index 08afc821053..96a4c2c2b97 100644 --- a/tensorflow/core/tpu/kernels/BUILD +++ b/tensorflow/core/tpu/kernels/BUILD @@ -956,6 +956,7 @@ cc_library( name = "tpu_ordinal_selector_op", srcs = ["tpu_ordinal_selector_op.cc"], deps = [ + ":tpu_ordinal_selector", "//tensorflow/core:framework", ], alwayslink = 1, @@ -968,3 +969,14 @@ cc_library( "//tensorflow/core:framework", ], ) + +cc_library( + name = "tpu_ordinal_selector", + hdrs = ["tpu_ordinal_selector.h"], + deps = [ + ":tpu_ordinal_selector_interface", + "//tensorflow/core:framework", + "//tensorflow/core/tpu:tpu_api", + "//tensorflow/core/tpu:tpu_ops_c_api_hdrs", + ], +) diff --git a/tensorflow/core/tpu/kernels/tpu_ordinal_selector.h b/tensorflow/core/tpu/kernels/tpu_ordinal_selector.h new file mode 100644 index 00000000000..faf78f97dc4 --- /dev/null +++ b/tensorflow/core/tpu/kernels/tpu_ordinal_selector.h @@ -0,0 +1,58 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_ORDINAL_SELECTOR_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_ORDINAL_SELECTOR_H_ + +#include "tensorflow/core/tpu/kernels/tpu_ordinal_selector_interface.h" +#include "tensorflow/core/tpu/tpu_api.h" +#include "tensorflow/core/tpu/tpu_ops_c_api.h" + +namespace tensorflow { +namespace tpu { + +// A reserved ID for deferred core selection. Intentionally set at a number +// that is more than the number of cores available in a future system. +constexpr int32 kDeferredCoreSelectionReserved = -8193; + +class TPUOrdinalSelector : TPUOrdinalSelectorInterface { + public: + explicit TPUOrdinalSelector(int num_cores_per_replica = 1) { + OpsApiFn()->TfTpuOrdinalSelector_CreateFn(&ordinal_selector_, + num_cores_per_replica); + } + ~TPUOrdinalSelector() override { + OpsApiFn()->TfTpuOrdinalSelector_DestroyFn(ordinal_selector_); + } + int64 GetOrdinal(absl::optional key, int64_t* req_id) override { + int64 ordinal; + OpsApiFn()->TfTpuOrdinalSelector_GetOrdinalFn(ordinal_selector_, key, + req_id, &ordinal); + return ordinal; + } + void DequeueFromCoreSelector(int32_t device_ordinal, + int64_t req_id) override { + OpsApiFn()->TfTpuOrdinalSelector_DequeueFromCoreSelectorFn( + ordinal_selector_, device_ordinal, req_id); + } + + private: + TfTpuOrdinalSelector* ordinal_selector_; +}; + +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_ORDINAL_SELECTOR_H_ diff --git a/tensorflow/core/tpu/kernels/tpu_ordinal_selector_interface.h b/tensorflow/core/tpu/kernels/tpu_ordinal_selector_interface.h index 4a08efae7ef..658e2f48295 100644 --- a/tensorflow/core/tpu/kernels/tpu_ordinal_selector_interface.h +++ b/tensorflow/core/tpu/kernels/tpu_ordinal_selector_interface.h @@ -24,7 +24,7 @@ namespace tpu { class TPUOrdinalSelectorInterface { public: virtual ~TPUOrdinalSelectorInterface() = default; - virtual int64 GetOrdinal(int64_t* req_id) = 0; + virtual int64 GetOrdinal(absl::optional key, int64_t* req_id) = 0; virtual void DequeueFromCoreSelector(int32_t device_ordinal, int64_t req_id) = 0; }; diff --git a/tensorflow/core/tpu/kernels/tpu_ordinal_selector_op.cc b/tensorflow/core/tpu/kernels/tpu_ordinal_selector_op.cc index 13a624b92f7..c6da029d417 100644 --- a/tensorflow/core/tpu/kernels/tpu_ordinal_selector_op.cc +++ b/tensorflow/core/tpu/kernels/tpu_ordinal_selector_op.cc @@ -19,14 +19,11 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/tpu/kernels/tpu_ordinal_selector.h" namespace tensorflow { namespace { -// A reserved ID for deferred core selection. Intentionally set at a number -// that is more than the number of cores available in a future system. -constexpr int32 kDeferredCoreSelectionReserved = -8193; - // TPUOrdinalSelectorOp is a no-op for backward compatibility. The core // selection algorithm happens inside TPUPartitionedCall. class TPUOrdinalSelectorOp : public OpKernel { @@ -37,7 +34,7 @@ class TPUOrdinalSelectorOp : public OpKernel { void Compute(OpKernelContext* ctx) override { Tensor output(DT_INT32, TensorShape({})); - output.flat().setValues({kDeferredCoreSelectionReserved}); + output.flat().setValues({tpu::kDeferredCoreSelectionReserved}); ctx->set_output(0, output); ctx->SetStatus(Status::OK()); } diff --git a/tensorflow/core/tpu/tpu_initializer_helper.cc b/tensorflow/core/tpu/tpu_initializer_helper.cc index 5fb5a6b1940..f89751d4ba1 100644 --- a/tensorflow/core/tpu/tpu_initializer_helper.cc +++ b/tensorflow/core/tpu/tpu_initializer_helper.cc @@ -53,14 +53,14 @@ bool TryAcquireTpuLock() { // This lock is held until the process exits intentionally. The underlying // TPU device will be held on until it quits. if (lockf(fd, F_TLOCK, 0) != 0) { - LOG(ERROR) << "libtpu.so already in used by another process. Not " - "attempting to load libtpu.so in this process."; + LOG(INFO) << "libtpu.so already in used by another process. Not " + "attempting to load libtpu.so in this process."; should_load_library = false; } else { should_load_library = true; } } else { - LOG(INFO) << "TPU_HOST_BOUNDS is set, allowing multiple libtpu.so loads."; + VLOG(1) << "TPU_HOST_BOUNDS is set, allowing multiple libtpu.so loads."; should_load_library = true; } } diff --git a/tensorflow/core/tpu/tpu_library_init_fns.inc b/tensorflow/core/tpu/tpu_library_init_fns.inc index 249830b1e2f..410cd844f97 100644 --- a/tensorflow/core/tpu/tpu_library_init_fns.inc +++ b/tensorflow/core/tpu/tpu_library_init_fns.inc @@ -82,6 +82,12 @@ tensorflow::Status SetTpuOpsStructFns(void* library_handle) { TFTPU_SET_FN(ops_api_fn, TfTpu_InitializeTpuModelServer); + TFTPU_SET_FN(ops_api_fn, TfTpuOrdinalSelector_Create); + TFTPU_SET_FN(ops_api_fn, TfTpuOrdinalSelector_Destroy); + TFTPU_SET_FN(ops_api_fn, TfTpuOrdinalSelector_GetOrdinal); + TFTPU_SET_FN(ops_api_fn, TfTpuOrdinalSelector_DequeueFromCoreSelector); + TFTPU_SET_FN(ops_api_fn, TfTpu_GetTpuPartitionedCallParams); + return tensorflow::Status::OK(); } diff --git a/tensorflow/core/tpu/tpu_ops_c_api.h b/tensorflow/core/tpu/tpu_ops_c_api.h index 75b5d07b037..a4eac987083 100644 --- a/tensorflow/core/tpu/tpu_ops_c_api.h +++ b/tensorflow/core/tpu/tpu_ops_c_api.h @@ -19,6 +19,7 @@ limitations under the License. #include +#include "absl/types/optional.h" #include "tensorflow/core/tpu/libtftpu.h" #include "tensorflow/stream_executor/tpu/c_api_decl.h" #include "tensorflow/stream_executor/tpu/proto_helper.h" @@ -84,6 +85,15 @@ struct CompilationCacheKeyResult { typedef struct XLA_TpuNodeContext XLA_TpuNodeContext; +typedef struct TfTpu_OrdinalSelector TfTpuOrdinalSelector; + +struct TpuPartitionedCall_Params { + bool input_shape_opt; + bool group_tensors_for_packing; + int32_t minimum_input_tensors_packing; + int32_t minimum_output_tensors_packing; +}; + // Compiles Mlir or TF function computation by lowering into HLO IR and returns // `count` number of TPU programs ready for execution. // The API allocates the `XLA_TpuProgram*[]` array `tpu_programs` and creates @@ -151,6 +161,23 @@ TFTPU_CAPI_EXPORT void TpuMeshState_Free(XLA_TpuMeshState* mesh_state); TFTPU_CAPI_EXPORT void* TpuMeshState_MeshCommonState( XLA_TpuMeshState* mesh_state); +TFTPU_CAPI_EXPORT void TfTpuOrdinalSelector_Create( + TfTpuOrdinalSelector** ordinal_selector, int num_cores_per_replica); + +TFTPU_CAPI_EXPORT void TfTpuOrdinalSelector_Destroy( + TfTpuOrdinalSelector* ordinal_selector); + +TFTPU_CAPI_EXPORT void TfTpuOrdinalSelector_GetOrdinal( + TfTpuOrdinalSelector* ordinal_selector, absl::optional key, + int64_t* req_id, int64_t* ordinal); + +TFTPU_CAPI_EXPORT void TfTpuOrdinalSelector_DequeueFromCoreSelector( + TfTpuOrdinalSelector* ordinal_selector, int32_t device_ordinal, + int64_t req_id); + +TFTPU_CAPI_EXPORT void TfTpu_GetTpuPartitionedCallParams( + TpuPartitionedCall_Params* params); + typedef struct TpuExecutable_LoadProgramAndEnqueueToStream_Params { int32_t struct_size; void* priv; @@ -500,6 +527,12 @@ struct TfTpu_OpsApiFn { TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_CompactionSupported); TFTPU_ADD_FN_IN_STRUCT(TfTpu_InitializeTpuModelServer); + + TFTPU_ADD_FN_IN_STRUCT(TfTpuOrdinalSelector_Create); + TFTPU_ADD_FN_IN_STRUCT(TfTpuOrdinalSelector_Destroy); + TFTPU_ADD_FN_IN_STRUCT(TfTpuOrdinalSelector_GetOrdinal); + TFTPU_ADD_FN_IN_STRUCT(TfTpuOrdinalSelector_DequeueFromCoreSelector); + TFTPU_ADD_FN_IN_STRUCT(TfTpu_GetTpuPartitionedCallParams); }; } // extern "C" diff --git a/tensorflow/core/util/mkl_threadpool.h b/tensorflow/core/util/mkl_threadpool.h index 493f7732b8f..713c8ea13fa 100644 --- a/tensorflow/core/util/mkl_threadpool.h +++ b/tensorflow/core/util/mkl_threadpool.h @@ -24,16 +24,18 @@ limitations under the License. #include #include #include + #include "mkldnn.hpp" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/platform/threadpool.h" #define EIGEN_USE_THREADS + +namespace tensorflow { + #ifdef ENABLE_MKLDNN_THREADPOOL using dnnl::stream_attr; using dnnl::threadpool_iface; -namespace tensorflow { - // Divide 'n' units of work equally among 'teams' threads. If 'n' is not // divisible by 'teams' and has a remainder 'r', the first 'r' teams have one // unit of work more than the rest. Returns the range of work that belongs to @@ -106,39 +108,17 @@ struct MklDnnThreadPool : public dnnl::threadpool_iface { Eigen::ThreadPoolInterface* eigen_interface_ = nullptr; }; -class MklDnnThreadPoolWrapper { - public: - static MklDnnThreadPoolWrapper& GetInstance() { - static MklDnnThreadPoolWrapper instance_; - return instance_; - } - MklDnnThreadPool* CreateThreadPoolPtr(OpKernelContext* ctx) { - mutex_lock l(m_); - if (threadpool_map_.empty() || - threadpool_map_.find(ctx->device()) == threadpool_map_.end()) { - auto tp_iface = new MklDnnThreadPool(ctx); - threadpool_map_.emplace(std::make_pair(ctx->device(), tp_iface)); - return tp_iface; - } else { - auto entry = threadpool_map_.find(ctx->device()); - return entry->second; - } - } +#else - private: - mutex m_; - std::unordered_map threadpool_map_; - MklDnnThreadPoolWrapper() {} - MklDnnThreadPoolWrapper(const MklDnnThreadPoolWrapper&) = delete; - MklDnnThreadPoolWrapper& operator=(const MklDnnThreadPoolWrapper&) = delete; - ~MklDnnThreadPoolWrapper() { - for (auto& tp : threadpool_map_) { - delete tp.second; - } - } +// This struct was just added to enable successful OMP-based build. +struct MklDnnThreadPool { + MklDnnThreadPool() = default; + MklDnnThreadPool(OpKernelContext* ctx) {} }; -} // namespace tensorflow #endif // ENABLE_MKLDNN_THREADPOOL + +} // namespace tensorflow + #endif // INTEL_MKL #endif // TENSORFLOW_CORE_UTIL_MKL_THREADPOOL_H_ diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index 4cfc380dfe1..0a997599357 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -222,13 +222,11 @@ inline bool array_cmp(const T* a1, const T* a2, size_t size) { return true; } -inline mkldnn::stream* CreateStream(OpKernelContext* ctx, +inline mkldnn::stream* CreateStream(MklDnnThreadPool* eigen_tp, const engine& engine) { #ifdef ENABLE_MKLDNN_THREADPOOL stream_attr tp_stream_attr(engine::kind::cpu); - if (ctx != nullptr) { - auto eigen_tp = - MklDnnThreadPoolWrapper::GetInstance().CreateThreadPoolPtr(ctx); + if (eigen_tp != nullptr) { tp_stream_attr.set_threadpool(eigen_tp); stream* tp_stream = new stream(engine, stream::flags::default_flags, tp_stream_attr); @@ -611,12 +609,18 @@ inline void ExecutePrimitive(const std::vector& net, OpKernelContext* context = nullptr) { DCHECK(net_args); DCHECK_EQ(net.size(), net_args->size()); - stream* cpu_stream = CreateStream(context, cpu_engine); + std::unique_ptr cpu_stream; + MklDnnThreadPool eigen_tp; + if (context != nullptr) { + eigen_tp = MklDnnThreadPool(context); + cpu_stream.reset(CreateStream(&eigen_tp, cpu_engine)); + } else { + cpu_stream.reset(CreateStream(nullptr, cpu_engine)); + } for (size_t i = 0; i < net.size(); ++i) { net.at(i).execute(*cpu_stream, net_args->at(i)); } cpu_stream->wait(); - delete cpu_stream; } template inline Status ConvertMklToTF(OpKernelContext* context, @@ -1494,7 +1498,13 @@ class MklDnnData { reorder_memory_ = new memory(op_md, engine); auto* prim = FindOrCreateReorder(user_memory_, reorder_memory_); std::shared_ptr cpu_stream; - cpu_stream.reset(CreateStream(context, prim->GetEngine())); + MklDnnThreadPool eigen_tp; + if (context != nullptr) { + eigen_tp = MklDnnThreadPool(context); + cpu_stream.reset(CreateStream(&eigen_tp, prim->GetEngine())); + } else { + cpu_stream.reset(CreateStream(nullptr, prim->GetEngine())); + } std::vector net; net.push_back(*(prim->GetPrimitive())); std::vector net_args; @@ -1555,7 +1565,13 @@ class MklDnnData { reorder_memory_ = new memory(op_md, engine, reorder_data_handle); auto* prim = FindOrCreateReorder(user_memory_, reorder_memory_); std::shared_ptr cpu_stream; - cpu_stream.reset(CreateStream(context, prim->GetEngine())); + MklDnnThreadPool eigen_tp; + if (context != nullptr) { + eigen_tp = MklDnnThreadPool(context); + cpu_stream.reset(CreateStream(&eigen_tp, prim->GetEngine())); + } else { + cpu_stream.reset(CreateStream(nullptr, prim->GetEngine())); + } std::vector net; net.push_back(*(prim->GetPrimitive())); std::vector net_args; @@ -1661,7 +1677,13 @@ class MklDnnData { net_args.push_back( {{MKLDNN_ARG_FROM, *reorder_memory_}, {MKLDNN_ARG_TO, *user_memory_}}); std::shared_ptr cpu_stream; - cpu_stream.reset(CreateStream(ctx, prim->GetEngine())); + MklDnnThreadPool eigen_tp; + if (ctx != nullptr) { + eigen_tp = MklDnnThreadPool(ctx); + cpu_stream.reset(CreateStream(&eigen_tp, prim->GetEngine())); + } else { + cpu_stream.reset(CreateStream(nullptr, prim->GetEngine())); + } execute_primitives(net, cpu_stream, net_args); } }; diff --git a/tensorflow/lite/core/shims/cc/kernels/register.h b/tensorflow/lite/core/shims/cc/kernels/register.h index 05c96eb47f7..0dedfc99068 100644 --- a/tensorflow/lite/core/shims/cc/kernels/register.h +++ b/tensorflow/lite/core/shims/cc/kernels/register.h @@ -21,6 +21,9 @@ namespace tflite_shims { namespace ops { namespace builtin { using BuiltinOpResolver = ::tflite::ops::builtin::BuiltinOpResolver; +using BuiltinOpResolverWithoutDefaultDelegates = + ::tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates; + } // namespace builtin } // namespace ops } // namespace tflite_shims diff --git a/tensorflow/lite/delegates/flex/allowlisted_flex_ops.cc b/tensorflow/lite/delegates/flex/allowlisted_flex_ops.cc index e53e9f7b5cd..67635a5781e 100644 --- a/tensorflow/lite/delegates/flex/allowlisted_flex_ops.cc +++ b/tensorflow/lite/delegates/flex/allowlisted_flex_ops.cc @@ -73,6 +73,7 @@ const std::set& GetFlexAllowlist() { "AvgPool3DGrad", "AvgPoolGrad", "BatchCholesky", + "BatchDatasetV2", "BatchMatMul", "BatchMatMulV2", "BatchMatrixDeterminant", @@ -266,6 +267,7 @@ const std::set& GetFlexAllowlist() { "LogicalNot", "LogicalOr", "LoopCond", + "MapDataset", "MatMul", "MatrixDeterminant", "MatrixDiag", @@ -301,6 +303,7 @@ const std::set& GetFlexAllowlist() { "Minimum", "MirrorPad", "MirrorPadGrad", + "ModelDataset", "Mul", "MulNoNan", "Multinomial", @@ -316,6 +319,7 @@ const std::set& GetFlexAllowlist() { "NotEqual", "OneHot", "OnesLike", + "OptimizeDatasetV2", "Pack", "Pad", "PadV2", @@ -392,6 +396,7 @@ const std::set& GetFlexAllowlist() { "Reciprocal", "ReciprocalGrad", "Recv", + "ReduceDataset", "ReduceJoin", "RefEnter", "RefExit", diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc index 6351f8debc3..93277e3b2c6 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc @@ -2541,15 +2541,14 @@ TfLiteIntArray* GetOpsToReplace(TfLiteContext* context, bool allow_quant_ops, // the tflite model representation tensors are created lazily, so there is no // guarantee that the order will match the source model tensors order. absl::Status PrecreateIOTensors( - TfLiteContext* context, GraphFloat32* graph, TfLiteIntArray* io_tensors, + TfLiteContext* context, GraphFloat32* graph, const std::vector& io_ids, absl::flat_hash_map* quant_conversion_map, absl::flat_hash_map* tensor_to_value) { - for (int i = 0; i < io_tensors->size; ++i) { - const int tensor_index = io_tensors->data[i]; - const TfLiteTensor& tflite_tensor = context->tensors[tensor_index]; + for (const auto& id : io_ids) { + const TfLiteTensor& tflite_tensor = context->tensors[id]; if (tflite::IsConstantTensor(&tflite_tensor)) continue; RETURN_IF_ERROR(ObjectReader::ReadNonConstantTensor( - context, tensor_to_value, quant_conversion_map, graph, tensor_index)); + context, tensor_to_value, quant_conversion_map, graph, id)); } return absl::OkStatus(); } @@ -2596,6 +2595,22 @@ absl::Status BuildModel(TfLiteContext* context, const TfLiteDelegateParams* delegate_params, GraphFloat32* graph, absl::flat_hash_map* quant_conversion_map) { + std::vector inputs(delegate_params->input_tensors->size); + std::vector outputs(delegate_params->output_tensors->size); + for (int i = 0; i < delegate_params->input_tensors->size; i++) { + inputs[i] = delegate_params->input_tensors->data[i]; + } + for (int i = 0; i < delegate_params->output_tensors->size; i++) { + outputs[i] = delegate_params->output_tensors->data[i]; + } + return BuildModelEnforceIO(context, delegate_params, inputs, outputs, graph, + quant_conversion_map); +} + +absl::Status BuildModelEnforceIO( + TfLiteContext* context, const TfLiteDelegateParams* delegate_params, + const std::vector& input_ids, const std::vector& output_ids, + GraphFloat32* graph, absl::flat_hash_map* quant_conversion_map) { std::vector> operations; std::vector tflite_nodes; for (int i = 0; i < delegate_params->nodes_to_replace->size; ++i) { @@ -2623,11 +2638,10 @@ absl::Status BuildModel(TfLiteContext* context, } absl::flat_hash_map tensor_to_value; std::vector variable_inputs_to_value_id; - RETURN_IF_ERROR(PrecreateIOTensors(context, graph, - delegate_params->input_tensors, + + RETURN_IF_ERROR(PrecreateIOTensors(context, graph, input_ids, quant_conversion_map, &tensor_to_value)); - RETURN_IF_ERROR(PrecreateIOTensors(context, graph, - delegate_params->output_tensors, + RETURN_IF_ERROR(PrecreateIOTensors(context, graph, output_ids, quant_conversion_map, &tensor_to_value)); for (int i = 0; i < operations.size(); ++i) { TfLiteNode* tflite_node; diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.h b/tensorflow/lite/delegates/gpu/common/model_builder.h index ab18f056d58..4529666883e 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.h +++ b/tensorflow/lite/delegates/gpu/common/model_builder.h @@ -49,6 +49,15 @@ absl::Status BuildModel( GraphFloat32* graph, absl::flat_hash_map* quant_conversion_map = nullptr); +// Same as BuildModel, but enforces user-provided input/output indices instead +// of using delegate_params->inputs and delegate_params->outputs for +// inputs/outputs preallocating. +absl::Status BuildModelEnforceIO( + TfLiteContext* context, const TfLiteDelegateParams* delegate_params, + const std::vector& input_ids, const std::vector& output_ids, + GraphFloat32* graph, + absl::flat_hash_map* quant_conversion_map = nullptr); + // Same as above but also apply all transformations on the final graph. // Prefer using this method instead of BuildModel. // diff --git a/tensorflow/lite/delegates/gpu/common/testing/tflite_model_reader.cc b/tensorflow/lite/delegates/gpu/common/testing/tflite_model_reader.cc index 7ba3de641ef..e4b59d8fb53 100644 --- a/tensorflow/lite/delegates/gpu/common/testing/tflite_model_reader.cc +++ b/tensorflow/lite/delegates/gpu/common/testing/tflite_model_reader.cc @@ -34,13 +34,21 @@ namespace { class DelegateContext { public: + struct DelegateData { + std::vector input_ids; + std::vector output_ids; + GraphFloat32* graph; + }; bool Init(TfLiteContext* context, const TfLiteDelegateParams* delegate_params) { - auto denormalized_graph = - reinterpret_cast(delegate_params->delegate->data_); - return denormalized_graph - ? BuildModel(context, delegate_params, denormalized_graph).ok() - : false; + const auto* delegate_data = + reinterpret_cast(delegate_params->delegate->data_); + + return delegate_data->graph && + BuildModelEnforceIO(context, delegate_params, + delegate_data->input_ids, + delegate_data->output_ids, delegate_data->graph) + .ok(); } }; @@ -82,7 +90,11 @@ absl::Status BuildFromFlatBuffer(const tflite::FlatBufferModel& flatbuffer, return absl::InternalError("Unable to prepare TfLite interpreter."); } TfLiteDelegate delegate; - delegate.data_ = graph; + + DelegateContext::DelegateData delegate_data{interpreter->inputs(), + interpreter->outputs(), graph}; + + delegate.data_ = &delegate_data; delegate.flags = kTfLiteDelegateFlagsNone; delegate.Prepare = DelegatePrepare; delegate.CopyFromBufferHandle = nullptr; diff --git a/tensorflow/lite/experimental/acceleration/configuration/BUILD b/tensorflow/lite/experimental/acceleration/configuration/BUILD index 720766503cc..5a437808af7 100644 --- a/tensorflow/lite/experimental/acceleration/configuration/BUILD +++ b/tensorflow/lite/experimental/acceleration/configuration/BUILD @@ -180,3 +180,15 @@ cc_library( ], alwayslink = 1, # For registration to always run. ) + +cc_library( + name = "xnnpack_plugin", + srcs = ["xnnpack_plugin.cc"], + deps = [ + ":configuration_fbs", + ":delegate_registry", + "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate", + "@com_google_absl//absl/memory", + ], + alwayslink = 1, # For registration to always run. +) diff --git a/tensorflow/lite/experimental/acceleration/configuration/xnnpack_plugin.cc b/tensorflow/lite/experimental/acceleration/configuration/xnnpack_plugin.cc new file mode 100644 index 00000000000..467cce1a604 --- /dev/null +++ b/tensorflow/lite/experimental/acceleration/configuration/xnnpack_plugin.cc @@ -0,0 +1,50 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "absl/memory/memory.h" +#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h" +#include "tensorflow/lite/experimental/acceleration/configuration/delegate_registry.h" + +namespace tflite { +namespace delegates { +class XNNPackPlugin : public DelegatePluginInterface { + public: + TfLiteDelegatePtr Create() override { + return TfLiteDelegatePtr(TfLiteXNNPackDelegateCreate(&options_), + TfLiteXNNPackDelegateDelete); + } + int GetDelegateErrno(TfLiteDelegate* from_delegate) override { return 0; } + static std::unique_ptr New( + const TFLiteSettings& acceleration) { + return absl::make_unique(acceleration); + } + explicit XNNPackPlugin(const TFLiteSettings& tflite_settings) + : options_(TfLiteXNNPackDelegateOptionsDefault()) { + const auto* xnnpack_settings = tflite_settings.xnnpack_settings(); + if (xnnpack_settings) { + options_.num_threads = xnnpack_settings->num_threads(); + } + } + + private: + TfLiteXNNPackDelegateOptions options_; +}; + +TFLITE_REGISTER_DELEGATE_FACTORY_FUNCTION(XNNPackPlugin, XNNPackPlugin::New); + +} // namespace delegates +} // namespace tflite diff --git a/tensorflow/lite/experimental/kernels/unidirectional_sequence_gru.cc b/tensorflow/lite/experimental/kernels/unidirectional_sequence_gru.cc index 1fd6dabb52f..2a957da3e91 100644 --- a/tensorflow/lite/experimental/kernels/unidirectional_sequence_gru.cc +++ b/tensorflow/lite/experimental/kernels/unidirectional_sequence_gru.cc @@ -62,6 +62,12 @@ void GruImpl(const TfLiteTensor* input, const TfLiteTensor* input_state, tflite::FullyConnectedParams fc_params; fc_params.float_activation_min = std::numeric_limits::lowest(); fc_params.float_activation_max = std::numeric_limits::max(); + + // The lhs is cacheable only when both gate weight & candidate weight are both + // constants. + fc_params.lhs_cacheable = + IsConstantTensor(gate_weight) && IsConstantTensor(candidate_weight); + fc_params.rhs_cacheable = false; for (int i = 0; i < n_time; ++i) { gru_cell::GruCell( input_shape, input_data, state_shape, input_state_data, diff --git a/tensorflow/lite/experimental/quantization_debugger/debugger.py b/tensorflow/lite/experimental/quantization_debugger/debugger.py index bfccd35a898..5dedcbd9f81 100644 --- a/tensorflow/lite/experimental/quantization_debugger/debugger.py +++ b/tensorflow/lite/experimental/quantization_debugger/debugger.py @@ -32,6 +32,8 @@ _DEFAULT_LAYER_DEBUG_METRICS = { 'mean_square_error': lambda diffs: np.average(diffs**2), } +_NUMERIC_VERIFY_OP_NAME = 'NumericVerify' + def _get_quant_params( tensor_detail: Mapping[str, Any]) -> Optional[Tuple[float, int]]: @@ -225,8 +227,7 @@ class QuantizationDebugger: for metric_name, metric in model_statistics.items() } - def _set_input_tensors(self, - interpreter: tf.lite.Interpreter, + def _set_input_tensors(self, interpreter: tf.lite.Interpreter, tensor_data: Sequence[np.ndarray], initialize: bool) -> None: """Sets input tensors into TFLite model Interpreter. @@ -286,17 +287,30 @@ class QuantizationDebugger: def _get_numeric_verify_tensor_details(self) -> List[str]: """Returns all names of all tensors from NumericVerify op.""" + # pylint: disable=protected-access if not self._numeric_verify_tensor_details: - self._numeric_verify_tensor_details = [ - detail for detail in self._quant_interpreter.get_tensor_details() - if detail['name'].startswith('NumericVerify') - ] + self._numeric_verify_tensor_details = [] + for op_info in self._quant_interpreter._get_ops_details(): + if op_info['op_name'] == _NUMERIC_VERIFY_OP_NAME: + self._numeric_verify_tensor_details.append( + self._quant_interpreter._get_tensor_details( + op_info['outputs'][0])) + # pylint: enable=protected-access return self._numeric_verify_tensor_details - def _get_operand_index(self, numeric_verify_name: str) -> int: - """Gets the index of NumericVerify Op's quantized input tensor.""" - tensor_idx = numeric_verify_name.rsplit(':', 1)[-1] - return int(tensor_idx) + def _get_operand_name_and_index(self, + numeric_verify_name: str) -> Tuple[str, int]: + """Gets the index and name of NumericVerify Op's quantized input tensor. + + Args: + numeric_verify_name: name of the NumericVerify op's output tensor. It has + format of `NumericVerify/{quantized_tensor_name}:{quantized_tensor_idx}` + + Returns: + Tuple of (tensor_name, tensor_idx) for quantized op's output tensor. + """ + tensor_name, tensor_idx = numeric_verify_name.rsplit(':', 1) + return (tensor_name[len(_NUMERIC_VERIFY_OP_NAME) + 1:], int(tensor_idx)) def layer_statistics_dump(self, file: IO[str]) -> None: """Dumps layer statistics into file, in csv format. @@ -304,15 +318,17 @@ class QuantizationDebugger: Args: file: file, or file-like object to write. """ - fields = ['op_name', 'op_idx'] + list( - self._layer_debug_metrics.keys()) + ['scales', 'zero_points'] + # order of `fields` is the order of fields in csv. + fields = ['op_name', 'tensor_idx'] + list(self._layer_debug_metrics.keys( + )) + ['scales', 'zero_points', 'tensor_name'] writer = csv.DictWriter(file, fields) writer.writeheader() for name, metrics in self.layer_statistics.items(): data = metrics.copy() - data['op_idx'] = self._get_operand_index(name) - data['op_name'] = self._defining_op[data['op_idx']] - details = self._quant_interpreter._get_tensor_details(data['op_idx']) # pylint: disable=protected-access + (data['tensor_name'], + data['tensor_idx']) = self._get_operand_name_and_index(name) + data['op_name'] = self._defining_op[data['tensor_idx']] + details = self._quant_interpreter._get_tensor_details(data['tensor_idx']) # pylint: disable=protected-access data['scales'], data['zero_points'] = ( details['quantization_parameters']['scales'], details['quantization_parameters']['zero_points']) diff --git a/tensorflow/lite/experimental/quantization_debugger/debugger_test.py b/tensorflow/lite/experimental/quantization_debugger/debugger_test.py index 4339f4848eb..ceef9da6b6e 100644 --- a/tensorflow/lite/experimental/quantization_debugger/debugger_test.py +++ b/tensorflow/lite/experimental/quantization_debugger/debugger_test.py @@ -131,9 +131,10 @@ class QuantizationDebuggerTest(test_util.TensorFlowTestCase, expected_values = expected_metrics.copy() expected_values.update({ 'op_name': 'CONV_2D', - 'op_idx': 7 if quantized_io else 8, + 'tensor_idx': 7 if quantized_io else 8, 'scales': [0.15686275], 'zero_points': [-128], + 'tensor_name': 'Identity' if quantized_io else 'Identity4' }) for key, value in expected_values.items(): if isinstance(value, str): diff --git a/tensorflow/lite/g3doc/guide/op_select_allowlist.md b/tensorflow/lite/g3doc/guide/op_select_allowlist.md index 6051b5b988d..56469e6ca4d 100644 --- a/tensorflow/lite/g3doc/guide/op_select_allowlist.md +++ b/tensorflow/lite/g3doc/guide/op_select_allowlist.md @@ -19,12 +19,12 @@ supported by TensorFlow Lite runtime with the Select TensorFlow Ops feature. * `raw_ops.All` * `raw_ops.Angle` * `raw_ops.Any` -* `raw_ops.ApplyAdaMax` * `raw_ops.ApplyAdadelta` * `raw_ops.ApplyAdagrad` * `raw_ops.ApplyAdagradDA` * `raw_ops.ApplyAdagradV2` * `raw_ops.ApplyAdam` +* `raw_ops.ApplyAdaMax` * `raw_ops.ApplyAddSign` * `raw_ops.ApplyCenteredRMSProp` * `raw_ops.ApplyFtrl` @@ -53,6 +53,7 @@ supported by TensorFlow Lite runtime with the Select TensorFlow Ops feature. * `raw_ops.AvgPool3DGrad` * `raw_ops.AvgPoolGrad` * `raw_ops.BatchCholesky` +* `raw_ops.BatchDatasetV2` * `raw_ops.BatchMatMul` * `raw_ops.BatchMatMulV2` * `raw_ops.BatchMatrixDiag` @@ -115,6 +116,8 @@ supported by TensorFlow Lite runtime with the Select TensorFlow Ops feature. * `raw_ops.CropAndResize` * `raw_ops.CropAndResizeGradBoxes` * `raw_ops.CropAndResizeGradImage` +* `raw_ops.CTCBeamSearchDecoder` +* `raw_ops.CTCGreedyDecoder` * `raw_ops.Cumprod` * `raw_ops.Cumsum` * `raw_ops.CumulativeLogsumexp` @@ -171,11 +174,6 @@ supported by TensorFlow Lite runtime with the Select TensorFlow Ops feature. * `raw_ops.Exp` * `raw_ops.ExpandDims` * `raw_ops.ExtractImagePatches` -* `raw_ops.FFT` -* `raw_ops.FFT2D` -* `raw_ops.FFT3D` -* `raw_ops.FIFOQueue` -* `raw_ops.FIFOQueueV2` * `raw_ops.FakeQuantWithMinMaxArgs` * `raw_ops.FakeQuantWithMinMaxArgsGradient` * `raw_ops.FakeQuantWithMinMaxVars` @@ -183,6 +181,11 @@ supported by TensorFlow Lite runtime with the Select TensorFlow Ops feature. * `raw_ops.FakeQuantWithMinMaxVarsPerChannel` * `raw_ops.FakeQuantWithMinMaxVarsPerChannelGradient` * `raw_ops.FakeQueue` +* `raw_ops.FFT` +* `raw_ops.FFT2D` +* `raw_ops.FFT3D` +* `raw_ops.FIFOQueue` +* `raw_ops.FIFOQueueV2` * `raw_ops.Fill` * `raw_ops.Fingerprint` * `raw_ops.Floor` @@ -205,27 +208,27 @@ supported by TensorFlow Lite runtime with the Select TensorFlow Ops feature. * `raw_ops.Greater` * `raw_ops.GreaterEqual` * `raw_ops.HistogramSummary` +* `raw_ops.Identity` +* `raw_ops.IdentityN` * `raw_ops.IFFT` * `raw_ops.IFFT2D` * `raw_ops.IFFT3D` -* `raw_ops.IRFFT` -* `raw_ops.IRFFT2D` -* `raw_ops.IRFFT3D` -* `raw_ops.Identity` -* `raw_ops.IdentityN` * `raw_ops.Imag` * `raw_ops.ImageProjectiveTransformV2` * `raw_ops.ImageProjectiveTransformV3` * `raw_ops.ImmutableConst` -* `raw_ops.InTopK` -* `raw_ops.InTopKV2` * `raw_ops.InplaceAdd` * `raw_ops.InplaceSub` * `raw_ops.InplaceUpdate` +* `raw_ops.InTopK` +* `raw_ops.InTopKV2` * `raw_ops.Inv` -* `raw_ops.InvGrad` * `raw_ops.Invert` * `raw_ops.InvertPermutation` +* `raw_ops.InvGrad` +* `raw_ops.IRFFT` +* `raw_ops.IRFFT2D` +* `raw_ops.IRFFT3D` * `raw_ops.IsBoostedTreesQuantileStreamResourceInitialized` * `raw_ops.IsFinite` * `raw_ops.IsNan` @@ -239,11 +242,13 @@ supported by TensorFlow Lite runtime with the Select TensorFlow Ops feature. * `raw_ops.LinSpace` * `raw_ops.ListDiff` * `raw_ops.Log` -* `raw_ops.LogSoftmax` * `raw_ops.LogicalAnd` * `raw_ops.LogicalNot` * `raw_ops.LogicalOr` +* `raw_ops.LogSoftmax` * `raw_ops.LoopCond` +* `raw_ops.LRN` +* `raw_ops.MapDataset` * `raw_ops.MatMul` * `raw_ops.MatrixDiag` * `raw_ops.MatrixDiagPart` @@ -257,6 +262,7 @@ supported by TensorFlow Lite runtime with the Select TensorFlow Ops feature. * `raw_ops.MatrixSetDiagV3` * `raw_ops.MatrixTriangularSolve` * `raw_ops.Max` +* `raw_ops.Maximum` * `raw_ops.MaxPool` * `raw_ops.MaxPool3D` * `raw_ops.MaxPool3DGrad` @@ -268,7 +274,6 @@ supported by TensorFlow Lite runtime with the Select TensorFlow Ops feature. * `raw_ops.MaxPoolGradWithArgmax` * `raw_ops.MaxPoolV2` * `raw_ops.MaxPoolWithArgmax` -* `raw_ops.Maximum` * `raw_ops.Mean` * `raw_ops.Merge` * `raw_ops.MergeSummary` @@ -278,26 +283,29 @@ supported by TensorFlow Lite runtime with the Select TensorFlow Ops feature. * `raw_ops.Minimum` * `raw_ops.MirrorPad` * `raw_ops.MirrorPadGrad` +* `raw_ops.ModelDataset` * `raw_ops.Mul` * `raw_ops.MulNoNan` * `raw_ops.Multinomial` * `raw_ops.Neg` * `raw_ops.NextIteration` -* `raw_ops.NoOp` * `raw_ops.NonMaxSuppression` * `raw_ops.NonMaxSuppressionV2` * `raw_ops.NonMaxSuppressionV3` * `raw_ops.NonMaxSuppressionV4` * `raw_ops.NonMaxSuppressionV5` * `raw_ops.NonMaxSuppressionWithOverlaps` +* `raw_ops.NoOp` * `raw_ops.NotEqual` * `raw_ops.OneHot` * `raw_ops.OnesLike` +* `raw_ops.OptimizeDatasetV2` * `raw_ops.Pack` * `raw_ops.Pad` * `raw_ops.PadV2` * `raw_ops.PaddingFIFOQueue` * `raw_ops.PaddingFIFOQueueV2` +* `raw_ops.PadV2` * `raw_ops.ParallelConcat` * `raw_ops.ParallelDynamicStitch` * `raw_ops.ParseExample` @@ -315,8 +323,6 @@ supported by TensorFlow Lite runtime with the Select TensorFlow Ops feature. * `raw_ops.Print` * `raw_ops.PrintV2` * `raw_ops.Prod` -* `raw_ops.QuantizeDownAndShrinkRange` -* `raw_ops.QuantizeV2` * `raw_ops.QuantizedAdd` * `raw_ops.QuantizedAvgPool` * `raw_ops.QuantizedBatchNormWithGlobalNormalization` @@ -327,10 +333,12 @@ supported by TensorFlow Lite runtime with the Select TensorFlow Ops feature. * `raw_ops.QuantizedMatMul` * `raw_ops.QuantizedMaxPool` * `raw_ops.QuantizedMul` +* `raw_ops.QuantizeDownAndShrinkRange` * `raw_ops.QuantizedRelu` * `raw_ops.QuantizedRelu6` * `raw_ops.QuantizedReshape` * `raw_ops.QuantizedResizeBilinear` +* `raw_ops.QuantizeV2` * `raw_ops.QueueClose` * `raw_ops.QueueCloseV2` * `raw_ops.QueueDequeue` @@ -347,9 +355,6 @@ supported by TensorFlow Lite runtime with the Select TensorFlow Ops feature. * `raw_ops.QueueIsClosedV2` * `raw_ops.QueueSize` * `raw_ops.QueueSizeV2` -* `raw_ops.RFFT` -* `raw_ops.RFFT2D` -* `raw_ops.RFFT3D` * `raw_ops.RaggedBincount` * `raw_ops.RaggedGather` * `raw_ops.RaggedRange` @@ -369,6 +374,7 @@ supported by TensorFlow Lite runtime with the Select TensorFlow Ops feature. * `raw_ops.Reciprocal` * `raw_ops.ReciprocalGrad` * `raw_ops.Recv` +* `raw_ops.ReduceDataset` * `raw_ops.ReduceJoin` * `raw_ops.RefEnter` * `raw_ops.RefExit` @@ -393,12 +399,12 @@ supported by TensorFlow Lite runtime with the Select TensorFlow Ops feature. * `raw_ops.ResizeBilinearGrad` * `raw_ops.ResizeNearestNeighbor` * `raw_ops.ResizeNearestNeighborGrad` -* `raw_ops.ResourceApplyAdaMax` * `raw_ops.ResourceApplyAdadelta` * `raw_ops.ResourceApplyAdagrad` * `raw_ops.ResourceApplyAdagradDA` * `raw_ops.ResourceApplyAdagradV2` * `raw_ops.ResourceApplyAdam` +* `raw_ops.ResourceApplyAdaMax` * `raw_ops.ResourceApplyAdamWithAmsgrad` * `raw_ops.ResourceApplyAddSign` * `raw_ops.ResourceApplyCenteredRMSProp` @@ -444,6 +450,9 @@ supported by TensorFlow Lite runtime with the Select TensorFlow Ops feature. * `raw_ops.Reverse` * `raw_ops.ReverseSequence` * `raw_ops.ReverseV2` +* `raw_ops.RFFT` +* `raw_ops.RFFT2D` +* `raw_ops.RFFT3D` * `raw_ops.RightShift` * `raw_ops.Roll` * `raw_ops.Round` @@ -663,10 +672,10 @@ supported by TensorFlow Lite runtime with the Select TensorFlow Ops feature. * `raw_ops.UnsortedSegmentSum` * `raw_ops.UnwrapDatasetVariant` * `raw_ops.VarHandleOp` -* `raw_ops.VarIsInitializedOp` * `raw_ops.Variable` * `raw_ops.VariableShape` * `raw_ops.VariableV2` +* `raw_ops.VarIsInitializedOp` * `raw_ops.Where` * `raw_ops.WrapDatasetVariant` * `raw_ops.Xdivy` diff --git a/tensorflow/lite/interpreter.h b/tensorflow/lite/interpreter.h index d19d4b87949..0507fe555ba 100644 --- a/tensorflow/lite/interpreter.h +++ b/tensorflow/lite/interpreter.h @@ -681,7 +681,7 @@ class Interpreter { TfLiteExternalContextType type, TfLiteExternalContext* ctx); - // Helper method that return the tensot index that corresponds to + // Helper method that return the tensor index that corresponds to // a name in a SignatureDef. Defined by 'signature_method_name', and // 'signature_tensor_name'. // If 'is_input' is true then the tensor is checked in input tensors, diff --git a/tensorflow/lite/java/BUILD b/tensorflow/lite/java/BUILD index 81f43b2b737..1ab48302dfa 100644 --- a/tensorflow/lite/java/BUILD +++ b/tensorflow/lite/java/BUILD @@ -452,7 +452,6 @@ alias( cc_library( name = "tensorflowlite_native", srcs = ["libtensorflowlite_jni.so"], - visibility = ["//visibility:private"], ) cc_library( diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index f9940f445b4..37a042c0d5b 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -349,9 +349,6 @@ cc_library( "//conditions:default": ["-DTFLITE_HAVE_CPUINFO"], }), deps = [ - # TODO(b/168923364): Remove deprecated_backends after it is added to all - # necessary targets. - ":deprecated_backends", ":tflite_with_ruy", ":op_macros", # For now this unconditionally depends on both ruy and gemmlowp. @@ -743,6 +740,7 @@ cc_library( "random_uniform.cc", ], hdrs = ["custom_ops_register.h"], + compatible_with = get_compatible_with_portable(), copts = tflite_copts(), deps = [ ":kernel_util", diff --git a/tensorflow/lite/kernels/random_uniform.cc b/tensorflow/lite/kernels/random_uniform.cc index 1b19a80cc43..1a248755907 100644 --- a/tensorflow/lite/kernels/random_uniform.cc +++ b/tensorflow/lite/kernels/random_uniform.cc @@ -140,7 +140,7 @@ TfLiteStatus EvalInt(TfLiteContext* context, TfLiteNode* node) { size_t output_size = tflite::NumElements(output); switch (output->type) { case kTfLiteInt8: - RandomUniformSample>( + RandomUniformSample>( params->rng, GetTensorData(output), output_size, min_value, max_value); break; diff --git a/tensorflow/lite/kernels/random_uniform_test.cc b/tensorflow/lite/kernels/random_uniform_test.cc index 28a795470af..d852f69e482 100644 --- a/tensorflow/lite/kernels/random_uniform_test.cc +++ b/tensorflow/lite/kernels/random_uniform_test.cc @@ -36,6 +36,11 @@ tflite::TensorType GetTTEnum() { return tflite::TensorType_FLOAT64; } +template <> +tflite::TensorType GetTTEnum() { + return tflite::TensorType_INT8; +} + template <> tflite::TensorType GetTTEnum() { return tflite::TensorType_INT32; @@ -150,7 +155,7 @@ class RandomUniformIntTest : public ::testing::Test { using Int = IntType; }; -using TestTypesInt = ::testing::Types; +using TestTypesInt = ::testing::Types; TYPED_TEST_SUITE(RandomUniformIntTest, TestTypesInt); diff --git a/tensorflow/lite/micro/README.md b/tensorflow/lite/micro/README.md index c8afdc02fc5..a5811f06635 100644 --- a/tensorflow/lite/micro/README.md +++ b/tensorflow/lite/micro/README.md @@ -36,6 +36,7 @@ Linux | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-b ## Community Supported Builds Build Type | Status | Artifacts ---------- | ----------- | --------- +Arduino | [![Status](https://github.com/antmicro/tensorflow-arduino-examples/actions/workflows/test_examples.yml/badge.svg)](https://github.com/antmicro/tensorflow-arduino-examples/actions/workflows/test_examples.yml) | Xtensa | [![Status](https://github.com/advaitjain/tensorflow/blob/local-continuous-builds/tensorflow/lite/micro/docs/local_continuous_builds/xtensa-build-status.svg)](https://github.com/advaitjain/tensorflow/tree/local-continuous-builds/tensorflow/lite/micro/docs/local_continuous_builds/xtensa.md) | diff --git a/tensorflow/lite/micro/docs/renode.md b/tensorflow/lite/micro/docs/renode.md index abbdbb7eb2d..7bbf90690d8 100644 --- a/tensorflow/lite/micro/docs/renode.md +++ b/tensorflow/lite/micro/docs/renode.md @@ -26,13 +26,15 @@ Here, we document how Renode is used as part of the TFLM project. For more general use of Renode, please refer to the [Renode documentation](https://renode.readthedocs.io/en/latest/). +You can also read more about Renode from a [publicly available slide deck](https://docs.google.com/presentation/d/1j0gjI4pVkgF9CWvxaxr5XuCKakEB25YX2n-iFxlYKnE/edit). + # Installation -Renode can be installed and used in a variety of ways, as documented -[here](https://renode.readthedocs.io/en/latest/). For the purpose of Tensorflow -Lite Micro, we make use of a portable version for Linux. +Renode can be installed and used in a variety of ways, as documented in the +[Renode README](https://github.com/renode/renode/blob/master/README.rst#installation/). For the purpose of Tensorflow +Lite Micro, we make use of the portable version for Linux. -Portable renode wil be automatically installed when using the TfLite Micro +Portable renode will be automatically installed when using the TfLite Micro Makefile to `tensorflow/lite/micro/tools/make/downloads/renode`. The Makefile internally calls the `renode_download.sh` script: diff --git a/tensorflow/lite/micro/examples/person_detection/README.md b/tensorflow/lite/micro/examples/person_detection/README.md index 9877343bca3..5fe66b3d954 100644 --- a/tensorflow/lite/micro/examples/person_detection/README.md +++ b/tensorflow/lite/micro/examples/person_detection/README.md @@ -227,18 +227,17 @@ build and upload the example. To test the camera, start by pointing the device's camera at something that is definitely not a person, or just covering it up. The next time the blue LED flashes, the device will capture a frame from the camera and begin to run -inference. Since the vision model we are using for person detection is -relatively large, it takes a long time to run inference—around 19 seconds at the -time of writing, though it's possible TensorFlow Lite has gotten faster since -then. +inference. The vision model we are using for person detection is relatively +large, but with cmsis-nn optimizations it only takes around 800ms to run the +model. -After 19 seconds or so, the inference result will be translated into another LED -being lit. Since you pointed the camera at something that isn't a person, the -red LED should light up. +After a moment, the inference result will be translated into another LED being +lit. Since you pointed the camera at something that isn't a person, the red LED +should light up. Now, try pointing the device's camera at yourself! The next time the blue LED flashes, the device will capture another image and begin to run inference. After -19 seconds, the green LED should light up! +a brief puase, the green LED should light up! Remember, image data is captured as a snapshot before each inference, whenever the blue LED flashes. Whatever the camera is pointed at during that moment is diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD index f69b4b53c85..f3cc166de3a 100644 --- a/tensorflow/lite/micro/kernels/BUILD +++ b/tensorflow/lite/micro/kernels/BUILD @@ -258,8 +258,9 @@ cc_library( "dequantize.cc", "detection_postprocess.cc", "elementwise.cc", - "exp.cc", "elu.cc", + "exp.cc", + "fill.cc", "floor.cc", "l2norm.cc", "logical.cc", @@ -589,6 +590,20 @@ cc_test( ], ) +cc_test( + name = "fill_test", + srcs = [ + "fill_test.cc", + ], + deps = [ + ":kernel_runner", + "//tensorflow/lite/c:common", + "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", + "//tensorflow/lite/micro/testing:micro_test", + ], +) + cc_test( name = "floor_test", srcs = [ diff --git a/tensorflow/lite/micro/kernels/dequantize.cc b/tensorflow/lite/micro/kernels/dequantize.cc index a04143dea7f..b488c41a420 100644 --- a/tensorflow/lite/micro/kernels/dequantize.cc +++ b/tensorflow/lite/micro/kernels/dequantize.cc @@ -59,8 +59,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE(context, input->type == kTfLiteUInt8 || input->type == kTfLiteInt8 || input->type == kTfLiteInt16); - TF_LITE_ENSURE( - context, output->type == kTfLiteFloat32 || output->type == kTfLiteInt32); + TF_LITE_ENSURE(context, output->type == kTfLiteFloat32); if (output->type == kTfLiteInt32) { const double effective_output_scale = @@ -112,24 +111,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteTypeGetName(output->type)); return kTfLiteError; } - } else if (output->type == kTfLiteInt32) { - int flat_size = MatchingFlatSize(tflite::micro::GetTensorShape(input), - tflite::micro::GetTensorShape(output)); - switch (input->type) { - case kTfLiteInt8: { - reference_ops::Requantize( - tflite::micro::GetTensorData(input), flat_size, - data->output_multiplier, data->output_shift, - data->quantization_params.zero_point, data->output_zero_point, - tflite::micro::GetTensorData(output)); - break; - } - default: - TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.", - TfLiteTypeGetName(input->type), - TfLiteTypeGetName(output->type)); - return kTfLiteError; - } } else { TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.", TfLiteTypeGetName(input->type), diff --git a/tensorflow/lite/micro/kernels/fill.cc b/tensorflow/lite/micro/kernels/fill.cc index a7839cd1e41..ca3d15e1b6c 100644 --- a/tensorflow/lite/micro/kernels/fill.cc +++ b/tensorflow/lite/micro/kernels/fill.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,75 +13,97 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/lite/kernels/internal/reference/fill.h" + #include #include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/kernel_util.h" -#include "tensorflow/lite/string_util.h" +#include "tensorflow/lite/micro/kernels/kernel_util.h" namespace tflite { namespace { +template +TfLiteStatus EnsureEqImpl(TfLiteContext* context, const TfLiteIntArray* array, + const TfLiteTensor* tensor) { + for (int i = 0; i < array->size; ++i) { + TF_LITE_ENSURE_EQ(context, array->data[i], GetTensorData(tensor)[i]); + } + return kTfLiteOk; +} + +// Ensure the equality of an int array and a tensor, which must be +// one-dimensional and of an integer type. +TfLiteStatus EnsureEq(TfLiteContext* context, const TfLiteIntArray* array, + const TfLiteTensor* tensor) { + TF_LITE_ENSURE_EQ(context, NumDimensions(tensor), 1); + const auto tensor_len = tensor->dims->data[0]; + TF_LITE_ENSURE_EQ(context, array->size, tensor_len); + + switch (tensor->type) { + case kTfLiteInt8: + return EnsureEqImpl(context, array, tensor); + case kTfLiteUInt8: + return EnsureEqImpl(context, array, tensor); + case kTfLiteInt16: + return EnsureEqImpl(context, array, tensor); + case kTfLiteInt32: + return EnsureEqImpl(context, array, tensor); + case kTfLiteInt64: + return EnsureEqImpl(context, array, tensor); + default: + TF_LITE_KERNEL_LOG(context, + "cannot compare int array to tensor of type %d.", + tensor->type); + return kTfLiteError; + } +} + constexpr int kDimsTensor = 0; constexpr int kValueTensor = 1; constexpr int kOutputTensor = 0; TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); - TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - + // Ensure inputs and outputs exist. const TfLiteTensor* dims; TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDimsTensor, &dims)); const TfLiteTensor* value; TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kValueTensor, &value)); - - // Make sure the 1st input tensor is 1-D. - TF_LITE_ENSURE_EQ(context, NumDimensions(dims), 1); - - // Make sure the 1st input tensor is int32 or int64. - const auto dtype = dims->type; - TF_LITE_ENSURE(context, dtype == kTfLiteInt32 || dtype == kTfLiteInt64); - - // Make sure the 2nd input tensor is a scalar. - TF_LITE_ENSURE_EQ(context, NumDimensions(value), 0); - TfLiteTensor* output; TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputTensor, &output)); - output->type = value->type; - if (IsConstantTensor(dims)) { - TF_LITE_ENSURE_OK(context, ResizeOutput(context, dims, output)); - } else { - SetTensorToDynamic(output); - } + // The value tensor must be a scalar. + TF_LITE_ENSURE_EQ(context, NumDimensions(value), 0); + + // The value type and output type must match. + TF_LITE_ENSURE_EQ(context, value->type, output->type); + + // The dims tensor must match the output tensor shape. As a byproduct, + // ensures the dims tensor is of an integer type. + TF_LITE_ENSURE_OK(context, EnsureEq(context, output->dims, dims)); + return kTfLiteOk; } +template +void FillImpl(const TfLiteEvalTensor* value, TfLiteEvalTensor* output) { + reference_ops::Fill( + micro::GetTensorShape(value), micro::GetTensorData(value), + micro::GetTensorShape(output), micro::GetTensorData(output)); +} + TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - const TfLiteTensor* value; - TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kValueTensor, &value)); + const TfLiteEvalTensor* value = + micro::GetEvalInput(context, node, kValueTensor); + TfLiteEvalTensor* output = micro::GetEvalOutput(context, node, kOutputTensor); - TfLiteTensor* output; - TF_LITE_ENSURE_OK(context, - GetOutputSafe(context, node, kOutputTensor, &output)); - - if (IsDynamicTensor(output)) { - const TfLiteTensor* dims; - TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDimsTensor, &dims)); - TF_LITE_ENSURE_OK(context, ResizeOutput(context, dims, output)); - } -#define TF_LITE_FILL(data_type) \ - reference_ops::Fill(GetTensorShape(value), GetTensorData(value), \ - GetTensorShape(output), \ - GetTensorData(output)) - switch (output->type) { + switch (value->type) { case kTfLiteFloat32: - TF_LITE_FILL(float); + FillImpl(value, output); break; default: TF_LITE_KERNEL_LOG( @@ -89,16 +111,21 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteTypeGetName(value->type)); return kTfLiteError; } -#undef TF_LITE_FILL + return kTfLiteOk; } } // namespace -TfLiteRegistration* Register_FILL() { - static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, - fill::Prepare, fill::Eval}; - return &r; +TfLiteRegistration Register_FILL() { + return {/*init=*/nullptr, + /*free=*/nullptr, + /*prepare=*/Prepare, + /*invoke=*/Eval, + /*profiling_string=*/nullptr, + /*builtin_code=*/0, + /*custom_name=*/nullptr, + /*version=*/0}; } } // namespace tflite diff --git a/tensorflow/lite/micro/kernels/fill_test.cc b/tensorflow/lite/micro/kernels/fill_test.cc index 77c9baf4028..8735ce580ee 100644 --- a/tensorflow/lite/micro/kernels/fill_test.cc +++ b/tensorflow/lite/micro/kernels/fill_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,31 +13,116 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/micro/kernels/kernel_runner.h" +#include "tensorflow/lite/micro/micro_utils.h" +#include "tensorflow/lite/micro/test_helpers.h" +#include "tensorflow/lite/micro/testing/micro_test.h" + namespace { -TEST_P(FillOpTest, FillFloat) { - FillOpModel m(TensorType_INT64, {3}, {2, 2, 2}, 4.0, - GetParam()); - m.Invoke(); - EXPECT_THAT(m.GetOutput(), - ElementsAreArray({4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0})); - EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2})); -} +template +void TestFill(int* dims_shape, DimsType* dims_data, int* value_shape, + ValueType* value_data, int* output_shape, + OutputType* output_data) { + using tflite::testing::CreateTensor; + using tflite::testing::IntArrayFromInts; -TEST_P(FillOpTest, FillFloatInt32Dims) { - FillOpModel m(TensorType_INT32, {3}, {2, 2, 2}, 4.0, - GetParam()); - m.Invoke(); - EXPECT_THAT(m.GetOutput(), - ElementsAreArray({4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0})); - EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2})); -} + TfLiteTensor tensors[] = { + CreateTensor(dims_data, IntArrayFromInts(dims_shape)), + CreateTensor(value_data, IntArrayFromInts(value_shape)), + CreateTensor(output_data, IntArrayFromInts(output_shape))}; + constexpr int dims_index = 0; + constexpr int value_index = 1; + constexpr int output_index = 2; + constexpr int inputs[] = {2, dims_index, value_index}; + constexpr int outputs[] = {1, output_index}; + const auto registration = tflite::Register_FILL(); + tflite::micro::KernelRunner runner{registration, + tensors, + sizeof(tensors) / sizeof(TfLiteTensor), + IntArrayFromInts(inputs), + IntArrayFromInts(outputs), + /*builtin_data=*/nullptr}; -TEST_P(FillOpTest, FillOutputScalar) { - FillOpModel m(TensorType_INT64, {0}, {}, 4.0, GetParam()); - m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({4.0})); - EXPECT_THAT(m.GetOutputShape(), IsEmpty()); + TF_LITE_MICRO_EXPECT_EQ(runner.InitAndPrepare(), kTfLiteOk); + TF_LITE_MICRO_EXPECT_EQ(runner.Invoke(), kTfLiteOk); + + // The output shape must match the shape requested via dims. + const auto output_rank = output_shape[0]; + const auto requested_rank = dims_shape[1]; // yes, 1 + if (output_rank == requested_rank) { + for (int i = 0; i < requested_rank; ++i) { + TF_LITE_MICRO_EXPECT_EQ(output_shape[i + 1], dims_data[i]); + } + } else { + TF_LITE_MICRO_FAIL("output shape does not match shape requested via dims"); + } + + // The output type matches the value type. + TF_LITE_MICRO_EXPECT_EQ(tensors[output_index].type, + tensors[value_index].type); + + // The output elements contain the fill value. + const auto elements = tflite::ElementCount(*IntArrayFromInts(output_shape)); + for (int i = 0; i < elements; ++i) { + TF_LITE_MICRO_EXPECT_EQ(output_data[i], value_data[0]); + } } } // namespace + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(FillFloatInt64Dims) { + constexpr int kDim1 = 2; + constexpr int kDim2 = 2; + constexpr int kDim3 = 2; + + int dims_shape[] = {1, 3}; + int64_t dims_data[] = {kDim1, kDim2, kDim3}; + + int value_shape[] = {0}; + float value_data[] = {4.0}; + + int output_shape[] = {3, kDim1, kDim2, kDim3}; + float output_data[kDim1 * kDim2 * kDim3]; + + TestFill(dims_shape, dims_data, value_shape, value_data, output_shape, + output_data); +} + +TF_LITE_MICRO_TEST(FillFloatInt32Dims) { + constexpr int kDim1 = 2; + constexpr int kDim2 = 2; + constexpr int kDim3 = 2; + + int dims_shape[] = {1, 3}; + int32_t dims_data[] = {kDim1, kDim2, kDim3}; + + int value_shape[] = {0}; + float value_data[] = {4.0}; + + int output_shape[] = {3, kDim1, kDim2, kDim3}; + float output_data[kDim1 * kDim2 * kDim3]; + + TestFill(dims_shape, dims_data, value_shape, value_data, output_shape, + output_data); +} + +TF_LITE_MICRO_TEST(FillScalar) { + int dims_shape[] = {1, 0}; + int64_t dims_data[] = {0}; + + int value_shape[] = {0}; + float value_data[] = {4.0}; + + int output_shape[] = {0}; + float output_data[] = {0}; + + TestFill(dims_shape, dims_data, value_shape, value_data, output_shape, + output_data); +} + +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/micro/kernels/micro_ops.h b/tensorflow/lite/micro/kernels/micro_ops.h index 8cccaa91d8c..54e04c14471 100644 --- a/tensorflow/lite/micro/kernels/micro_ops.h +++ b/tensorflow/lite/micro/kernels/micro_ops.h @@ -37,6 +37,7 @@ TfLiteRegistration Register_CONV_2D(); TfLiteRegistration Register_DEPTHWISE_CONV_2D(); TfLiteRegistration Register_ELU(); TfLiteRegistration Register_EXP(); +TfLiteRegistration Register_FILL(); TfLiteRegistration Register_QUANTIZE(); TfLiteRegistration Register_SHAPE(); TfLiteRegistration Register_SOFTMAX(); diff --git a/tensorflow/lite/micro/kernels/quantize.cc b/tensorflow/lite/micro/kernels/quantize.cc index f62addbb776..1f4946bc532 100644 --- a/tensorflow/lite/micro/kernels/quantize.cc +++ b/tensorflow/lite/micro/kernels/quantize.cc @@ -63,6 +63,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { if ((input->type == kTfLiteInt16 && output->type == kTfLiteInt8) || (input->type == kTfLiteInt8 && output->type == kTfLiteInt8) || + (input->type == kTfLiteInt8 && output->type == kTfLiteInt32) || (input->type == kTfLiteInt16 && output->type == kTfLiteInt16) || (input->type == kTfLiteInt16 && output->type == kTfLiteInt32)) { double effective_scale = static_cast(input->params.scale) / diff --git a/tensorflow/lite/micro/kernels/quantize_common.cc b/tensorflow/lite/micro/kernels/quantize_common.cc index 2c4a8d2c604..ea9f3f89938 100644 --- a/tensorflow/lite/micro/kernels/quantize_common.cc +++ b/tensorflow/lite/micro/kernels/quantize_common.cc @@ -103,6 +103,13 @@ TfLiteStatus EvalQuantizeReference(TfLiteContext* context, TfLiteNode* node) { data->input_zero_point, data->quantization_params.zero_point, tflite::micro::GetTensorData(output)); break; + case kTfLiteInt32: + reference_ops::Requantize( + tflite::micro::GetTensorData(input), size, + data->requantize_output_multiplier, data->requantize_output_shift, + data->input_zero_point, data->quantization_params.zero_point, + tflite::micro::GetTensorData(output)); + break; default: TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.", TfLiteTypeGetName(input->type), diff --git a/tensorflow/lite/micro/kernels/zeros_like.cc b/tensorflow/lite/micro/kernels/zeros_like.cc index 1b986cc1695..ce403927567 100644 --- a/tensorflow/lite/micro/kernels/zeros_like.cc +++ b/tensorflow/lite/micro/kernels/zeros_like.cc @@ -58,6 +58,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt32: resetZeros(tflite::micro::GetTensorData(output), flat_size); break; + case kTfLiteInt8: + resetZeros(tflite::micro::GetTensorData(output), flat_size); + break; case kTfLiteFloat32: resetZeros(tflite::micro::GetTensorData(output), flat_size); break; diff --git a/tensorflow/lite/micro/kernels/zeros_like_test.cc b/tensorflow/lite/micro/kernels/zeros_like_test.cc index 20385ea3bea..68b7807dcba 100644 --- a/tensorflow/lite/micro/kernels/zeros_like_test.cc +++ b/tensorflow/lite/micro/kernels/zeros_like_test.cc @@ -24,72 +24,9 @@ namespace tflite { namespace testing { namespace { -void TestZerosLikeFloat(const int* input_dims_data, const float* input_data, - const float* expected_output_data, float* output_data) { - TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); - TfLiteIntArray* output_dims = IntArrayFromInts(input_dims_data); - const int output_dims_count = ElementCount(*output_dims); - constexpr int inputs_size = 1; - constexpr int outputs_size = 1; - constexpr int tensors_size = inputs_size + outputs_size; - TfLiteTensor tensors[tensors_size] = { - CreateTensor(input_data, input_dims), - CreateTensor(output_data, output_dims), - }; - - int inputs_array_data[] = {1, 0}; - TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); - int outputs_array_data[] = {1, 1}; - TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); - - const TfLiteRegistration registration = Register_ZEROS_LIKE(); - micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array, - outputs_array, - /*builtin_data=*/nullptr); - - TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare()); - TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke()); - - for (int i = 0; i < output_dims_count; ++i) { - TF_LITE_MICRO_EXPECT_EQ(expected_output_data[i], output_data[i]); - } -} - -void TestZerosLikeInt32(const int* input_dims_data, const int32_t* input_data, - const int32_t* expected_output_data, - int32_t* output_data) { - TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); - TfLiteIntArray* output_dims = IntArrayFromInts(input_dims_data); - const int output_dims_count = ElementCount(*output_dims); - constexpr int inputs_size = 1; - constexpr int outputs_size = 1; - constexpr int tensors_size = inputs_size + outputs_size; - TfLiteTensor tensors[tensors_size] = { - CreateTensor(input_data, input_dims), - CreateTensor(output_data, output_dims), - }; - - int inputs_array_data[] = {1, 0}; - TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); - int outputs_array_data[] = {1, 1}; - TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); - - const TfLiteRegistration registration = Register_ZEROS_LIKE(); - micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array, - outputs_array, - /*builtin_data=*/nullptr); - - TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare()); - TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke()); - - for (int i = 0; i < output_dims_count; ++i) { - TF_LITE_MICRO_EXPECT_EQ(expected_output_data[i], output_data[i]); - } -} - -void TestZerosLikeInt64(const int* input_dims_data, const int64_t* input_data, - const int64_t* expected_output_data, - int64_t* output_data) { +template +void TestZerosLike(const int* input_dims_data, const T* input_data, + const T* expected_output_data, T* output_data) { TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); TfLiteIntArray* output_dims = IntArrayFromInts(input_dims_data); const int output_dims_count = ElementCount(*output_dims); @@ -130,8 +67,17 @@ TF_LITE_MICRO_TEST(TestZerosLikeFloat) { const int input_dims[] = {2, 2, 3}; const float input_values[] = {-2.0, -1.0, 0.0, 1.0, 2.0, 3.0}; const float golden[] = {0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; - tflite::testing::TestZerosLikeFloat(input_dims, input_values, golden, - output_data); + tflite::testing::TestZerosLike(input_dims, input_values, golden, + output_data); +} + +TF_LITE_MICRO_TEST(TestZerosLikeInt8) { + int8_t output_data[6]; + const int input_dims[] = {3, 1, 2, 3}; + const int8_t input_values[] = {-2, -1, 0, 1, 2, 3}; + const int8_t golden[] = {0, 0, 0, 0, 0, 0}; + tflite::testing::TestZerosLike(input_dims, input_values, golden, + output_data); } TF_LITE_MICRO_TEST(TestZerosLikeInt32) { @@ -139,8 +85,8 @@ TF_LITE_MICRO_TEST(TestZerosLikeInt32) { const int input_dims[] = {4, 1, 2, 2, 1}; const int32_t input_values[] = {-2, -1, 0, 3}; const int32_t golden[] = {0, 0, 0, 0}; - tflite::testing::TestZerosLikeInt32(input_dims, input_values, golden, - output_data); + tflite::testing::TestZerosLike(input_dims, input_values, golden, + output_data); } TF_LITE_MICRO_TEST(TestZerosLikeInt64) { @@ -148,8 +94,8 @@ TF_LITE_MICRO_TEST(TestZerosLikeInt64) { const int input_dims[] = {4, 1, 2, 2, 1}; const int64_t input_values[] = {-2, -1, 0, 3}; const int64_t golden[] = {0, 0, 0, 0}; - tflite::testing::TestZerosLikeInt64(input_dims, input_values, golden, - output_data); + tflite::testing::TestZerosLike(input_dims, input_values, golden, + output_data); } TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/micro/micro_mutable_op_resolver.h b/tensorflow/lite/micro/micro_mutable_op_resolver.h index af0a0c7fe68..757dac8c19a 100644 --- a/tensorflow/lite/micro/micro_mutable_op_resolver.h +++ b/tensorflow/lite/micro/micro_mutable_op_resolver.h @@ -188,6 +188,10 @@ class MicroMutableOpResolver : public MicroOpResolver { tflite::Register_DETECTION_POSTPROCESS()); } + TfLiteStatus AddElu() { + return AddBuiltin(BuiltinOperator_ELU, tflite::Register_ELU(), ParseElu); + } + TfLiteStatus AddEqual() { return AddBuiltin(BuiltinOperator_EQUAL, tflite::ops::micro::Register_EQUAL(), ParseEqual); diff --git a/tensorflow/lite/micro/testing/BUILD b/tensorflow/lite/micro/testing/BUILD index 71e503abafe..335953d1412 100644 --- a/tensorflow/lite/micro/testing/BUILD +++ b/tensorflow/lite/micro/testing/BUILD @@ -23,7 +23,6 @@ cc_library( visibility = [ ":micro", ":microfrontend", - "//learning/brain/contrib/micro/tflite:__pkg__", ], deps = [ "//tensorflow/lite/c:common", diff --git a/tensorflow/lite/micro/tools/make/Makefile b/tensorflow/lite/micro/tools/make/Makefile index 2e360e3e492..b96e6bb1363 100644 --- a/tensorflow/lite/micro/tools/make/Makefile +++ b/tensorflow/lite/micro/tools/make/Makefile @@ -270,8 +270,9 @@ tensorflow/lite/micro/kernels/depthwise_conv_test.cc \ tensorflow/lite/micro/kernels/dequantize_test.cc \ tensorflow/lite/micro/kernels/detection_postprocess_test.cc \ tensorflow/lite/micro/kernels/elementwise_test.cc \ -tensorflow/lite/micro/kernels/exp_test.cc \ tensorflow/lite/micro/kernels/elu_test.cc \ +tensorflow/lite/micro/kernels/exp_test.cc \ +tensorflow/lite/micro/kernels/fill_test.cc \ tensorflow/lite/micro/kernels/floor_test.cc \ tensorflow/lite/micro/kernels/fully_connected_test.cc \ tensorflow/lite/micro/kernels/hard_swish_test.cc \ @@ -327,6 +328,7 @@ tensorflow/lite/micro/kernels/elementwise.cc \ tensorflow/lite/micro/kernels/elu.cc \ tensorflow/lite/micro/kernels/ethosu.cc \ tensorflow/lite/micro/kernels/exp.cc \ +tensorflow/lite/micro/kernels/fill.cc \ tensorflow/lite/micro/kernels/flexbuffers_generated_data.cc \ tensorflow/lite/micro/kernels/floor.cc \ tensorflow/lite/micro/kernels/fully_connected.cc \ @@ -411,8 +413,9 @@ tensorflow/lite/kernels/internal/reference/conv.h \ tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h \ tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h \ tensorflow/lite/kernels/internal/reference/dequantize.h \ -tensorflow/lite/kernels/internal/reference/exp.h \ tensorflow/lite/kernels/internal/reference/elu.h \ +tensorflow/lite/kernels/internal/reference/exp.h \ +tensorflow/lite/kernels/internal/reference/fill.h \ tensorflow/lite/kernels/internal/reference/floor.h \ tensorflow/lite/kernels/internal/reference/fully_connected.h \ tensorflow/lite/kernels/internal/reference/hard_swish.h \ diff --git a/tensorflow/lite/micro/tools/make/ext_libs/cmsis_download.sh b/tensorflow/lite/micro/tools/make/ext_libs/cmsis_download.sh index 2a4f3eb3fd7..fdb02a3b84a 100755 --- a/tensorflow/lite/micro/tools/make/ext_libs/cmsis_download.sh +++ b/tensorflow/lite/micro/tools/make/ext_libs/cmsis_download.sh @@ -49,9 +49,9 @@ if [ -d ${DOWNLOADED_CMSIS_PATH} ]; then echo >&2 "${DOWNLOADED_CMSIS_PATH} already exists, skipping the download." else - ZIP_PREFIX="71627bc91534ed9eec2361c0ef6442cd057653e0" + ZIP_PREFIX="0d7e4fa7131241a17e23dfae18140e0b2e77728f" CMSIS_URL="http://github.com/ARM-software/CMSIS_5/archive/${ZIP_PREFIX}.zip" - CMSIS_MD5="207c49970758c663e2ce1cc0245972a9" + CMSIS_MD5="630bb4a0acd3d2f3ccdd8bcccb9d6400" # wget is much faster than git clone of the entire repo. So we wget a specific # version and can then apply a patch, as needed. diff --git a/tensorflow/lite/micro/tools/make/targets/cortex_m_corstone_300_makefile.inc b/tensorflow/lite/micro/tools/make/targets/cortex_m_corstone_300_makefile.inc index 435694f9d2f..71bab927bf3 100644 --- a/tensorflow/lite/micro/tools/make/targets/cortex_m_corstone_300_makefile.inc +++ b/tensorflow/lite/micro/tools/make/targets/cortex_m_corstone_300_makefile.inc @@ -154,8 +154,7 @@ EXCLUDED_TESTS := \ tensorflow/lite/micro/output_handler_test.cc \ tensorflow/lite/micro/memory_arena_threshold_test.cc \ tensorflow/lite/micro/recording_micro_allocator_test.cc \ - tensorflow/lite/micro/kernels/circular_buffer_test.cc \ - tensorflow/lite/micro/kernels/pooling_test.cc + tensorflow/lite/micro/kernels/circular_buffer_test.cc MICROLITE_TEST_SRCS := $(filter-out $(EXCLUDED_TESTS), $(MICROLITE_TEST_SRCS)) EXCLUDED_EXAMPLE_TESTS := \ tensorflow/lite/micro/examples/magic_wand/Makefile.inc \ diff --git a/tensorflow/lite/micro/tools/make/templates/README_KEIL.md.tpl b/tensorflow/lite/micro/tools/make/templates/README_KEIL.md.tpl index 945b9f9c1ae..5b4560e7a0d 100644 --- a/tensorflow/lite/micro/tools/make/templates/README_KEIL.md.tpl +++ b/tensorflow/lite/micro/tools/make/templates/README_KEIL.md.tpl @@ -1,4 +1,4 @@ -# TensorFlow Lite Micro Mbed Project +# TensorFlow Lite Micro Keil Project This folder has been autogenerated by TensorFlow, and contains source, header, and project files needed to build a single TensorFlow Lite Micro target using diff --git a/tensorflow/lite/profiling/profile_summarizer.cc b/tensorflow/lite/profiling/profile_summarizer.cc index 2fc04f99659..076062c760b 100644 --- a/tensorflow/lite/profiling/profile_summarizer.cc +++ b/tensorflow/lite/profiling/profile_summarizer.cc @@ -174,6 +174,10 @@ void ProfileSummarizer::ProcessProfiles( const memory::MemoryUsage node_mem_usage = event->end_mem_usage - event->begin_mem_usage; std::string node_name(event->tag); + if (node_name == "Invoke") { + // Don't count the overall Invoke for profiling. + continue; + } node_name += "/" + std::to_string(event->extra_event_metadata); stats_calculator->AddNodeStats(node_name, event->tag, node_num, start_us, node_exec_time, diff --git a/tensorflow/lite/profiling/profile_summarizer_test.cc b/tensorflow/lite/profiling/profile_summarizer_test.cc index 53b08eaf95b..62444c5a89a 100644 --- a/tensorflow/lite/profiling/profile_summarizer_test.cc +++ b/tensorflow/lite/profiling/profile_summarizer_test.cc @@ -124,6 +124,7 @@ TEST(ProfileSummarizerTest, Interpreter) { auto output = summarizer.GetOutputString(); // TODO(shashishekhar): Add a better test here. ASSERT_TRUE(output.find("SimpleOpEval") != std::string::npos) << output; + ASSERT_TRUE(output.find("Invoke") == std::string::npos) << output; // NOLINT } TEST(ProfileSummarizerTest, InterpreterPlusProfilingDetails) { diff --git a/tensorflow/lite/python/util.py b/tensorflow/lite/python/util.py index 50fa57e7307..5ee9c7b9b92 100644 --- a/tensorflow/lite/python/util.py +++ b/tensorflow/lite/python/util.py @@ -839,9 +839,8 @@ def _remove_redundant_quantize_ops(model): # This is a requantize op, so write down its input tensor index. if input_type != dtypes.float32 and output_type != dtypes.float32: redundant_quant_tensors[op.inputs[0]] = op - elif (op.opcodeIndex in dequant_opcode_idxs and - op.outputs[0] in subgraph.outputs): - # Mark quant-dequant op pairs right before outputs to be removed. + if op.opcodeIndex in dequant_opcode_idxs and \ + op.outputs[0] in subgraph.outputs: output_dequant_tensors[op.inputs[0]] = op # Remove all the quant ops which produce the redundant quant tensors. @@ -853,13 +852,12 @@ def _remove_redundant_quantize_ops(model): requantize_op.inputs[0] = op.inputs[0] operators.remove(op) - # Remove all the quant/dequant op pairs right before the outputs. + # Remove all the quant ops which connect to the output dequant op. for op in all_quant_ops: output_tensor_idx = op.outputs[0] if output_tensor_idx in output_dequant_tensors: dequant_op = output_dequant_tensors[output_tensor_idx] - output_idx = subgraph.outputs.index(dequant_op.outputs[0]) - subgraph.outputs[output_idx] = op.inputs[0] + subgraph.outputs[subgraph.outputs == dequant_op.outputs[0]] = op.inputs[0] operators.remove(op) operators.remove(dequant_op) @@ -867,7 +865,7 @@ def _remove_redundant_quantize_ops(model): def modify_model_io_type( model, inference_input_type=dtypes.float32, inference_output_type=dtypes.float32): - """Modifies the input/output type of a tflite model. + """Modify the input/output type of a tflite model. Args: model: A tflite model. @@ -879,7 +877,6 @@ def modify_model_io_type( (default tf.float32. If model output is int8 dequantized, it must be in {tf.float32, tf.int8,tf.uint8}, else if model output is int16 dequantized, it must be in {tf.float32, tf.int16}, else it must be tf.float32) - Returns: A tflite model with modified input/output type. diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files index e473136dc46..c83fb709ce8 100644 --- a/tensorflow/opensource_only.files +++ b/tensorflow/opensource_only.files @@ -226,11 +226,6 @@ tensorflow/third_party/tflite_mobilenet_quant.BUILD tensorflow/third_party/tflite_ovic_testdata.BUILD tensorflow/third_party/tflite_smartreply.BUILD tensorflow/third_party/toolchains/BUILD -tensorflow/third_party/toolchains/clang6/BUILD -tensorflow/third_party/toolchains/clang6/CROSSTOOL.tpl -tensorflow/third_party/toolchains/clang6/README.md -tensorflow/third_party/toolchains/clang6/clang.BUILD -tensorflow/third_party/toolchains/clang6/repo.bzl tensorflow/third_party/toolchains/cpus/arm/BUILD tensorflow/third_party/toolchains/cpus/arm/arm_compiler_configure.bzl tensorflow/third_party/toolchains/cpus/arm/cc_config.bzl.tpl @@ -259,6 +254,8 @@ tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc- tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.0/cc_toolchain_config.bzl tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1/BUILD tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1/cc_toolchain_config.bzl +tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11.2/BUILD +tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11.2/cc_toolchain_config.bzl tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11/BUILD tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11/cc_toolchain_config.bzl tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/BUILD @@ -273,11 +270,6 @@ tensorflow/third_party/toolchains/preconfig/win_1803/BUILD tensorflow/third_party/toolchains/preconfig/win_1803/py36/BUILD tensorflow/third_party/toolchains/preconfig/win_1803/py37/BUILD tensorflow/third_party/toolchains/preconfig/win_1803/py38/BUILD -tensorflow/third_party/toolchains/remote/BUILD -tensorflow/third_party/toolchains/remote/BUILD.tpl -tensorflow/third_party/toolchains/remote/configure.bzl -tensorflow/third_party/toolchains/remote/execution.bzl.tpl -tensorflow/third_party/toolchains/remote_config/configs.bzl tensorflow/third_party/typing_extensions.BUILD tensorflow/third_party/wrapt.BUILD tensorflow/third_party/zlib.BUILD diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index e9ced78ef1e..4cfc389eac6 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -767,6 +767,7 @@ py_library( "//tensorflow/python/util:_pywrap_transform_graph", "//tensorflow/python/util:_pywrap_util_port", ":_pywrap_utils", + ":_errors_test_helper", ":composite_tensor", ":config", ":convert_to_constants", @@ -1926,6 +1927,7 @@ tf_py_test( main = "framework/errors_test.py", python_version = "PY3", deps = [ + ":_errors_test_helper", ":client_testlib", ":errors", "//tensorflow/core:protos_all_py", @@ -5121,6 +5123,17 @@ tf_python_pybind_extension( ], ) +tf_python_pybind_extension( + name = "_errors_test_helper", + srcs = ["framework/errors_test_helper.cc"], + module_name = "_errors_test_helper", + deps = [ + "//tensorflow/core/platform:status", + "//tensorflow/python/lib/core:pybind11_status", + "@pybind11", + ], +) + cuda_py_tests( name = "device_lib_test", size = "small", diff --git a/tensorflow/python/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py index 07fa62bafb2..1115318c610 100644 --- a/tensorflow/python/autograph/impl/api.py +++ b/tensorflow/python/autograph/impl/api.py @@ -106,25 +106,10 @@ class _ErrorMetadata(error_utils.ErrorMetadataBase): message = self.get_message() init_args = tuple(init_argspec.args) # At the time of this writing, TF errors either take 3 or 4 arguments, - # with the fourth being error_code. - if init_args == ('self', 'node_def', 'op', 'message', 'error_code'): - return preferred_type( - node_def=source_error.node_def, - op=source_error.op, - message=message, - error_code=self.error_code) - elif init_args == ('self', 'node_def', 'op', 'message'): - if 'error_code' in init_argspec.kwonlyargs: - return preferred_type( - node_def=source_error.node_def, - op=source_error.op, - message=message, - errro_code=self.error_code) - else: - return preferred_type( - node_def=source_error.node_def, - op=source_error.op, - message=message) + # the argument '*args' may or may not be used. + if init_args == ('self', 'node_def', 'op', 'message'): + return preferred_type(source_error.node_def, source_error.op, message, + source_error.experimental_payloads) elif preferred_type in (errors.PyCTError, AutoGraphError, ConversionError, StagingError, errors_impl.InaccessibleTensorError, @@ -159,8 +144,8 @@ def _attach_error_metadata(e, f): cause_tb = traceback.extract_tb(sys.exc_info()[2])[1:] - e.ag_error_metadata = _ErrorMetadata( - cause_tb, metadata, message, source_map, __file__) + e.ag_error_metadata = _ErrorMetadata(cause_tb, metadata, message, source_map, + __file__) class StackTraceMapper(tf_stack.StackTraceMapper): @@ -312,11 +297,7 @@ def is_autograph_artifact(entity): return hasattr(entity, 'autograph_info__') -def converted_call(f, - args, - kwargs, - caller_fn_scope=None, - options=None): +def converted_call(f, args, kwargs, caller_fn_scope=None, options=None): """Converts a function call inline. For internal use only. @@ -437,8 +418,7 @@ def converted_call(f, return _fall_back_unconverted(f, args, kwargs, options, e) if not hasattr(target_entity, '__code__'): - logging.log(2, 'Permanently allowed: %s: native binding', - target_entity) + logging.log(2, 'Permanently allowed: %s: native binding', target_entity) return _call_unconverted(f, args, kwargs, options) elif (hasattr(target_entity.__code__, 'co_filename') and target_entity.__code__.co_filename == ''): @@ -611,6 +591,7 @@ def tf_convert(f, ctx, convert_by_default=True, user_requested=False): def call_with_unspecified_conversion_status(func): """Decorator that resets the conversion context to the unspecified status.""" + def wrapper(*args, **kwargs): with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.UNSPECIFIED): return func(*args, **kwargs) diff --git a/tensorflow/python/autograph/impl/api_test.py b/tensorflow/python/autograph/impl/api_test.py index 59ae5f4d856..50bc452a2a9 100644 --- a/tensorflow/python/autograph/impl/api_test.py +++ b/tensorflow/python/autograph/impl/api_test.py @@ -33,6 +33,7 @@ import types import numpy as np import six +from tensorflow.python import _errors_test_helper from tensorflow.python.autograph.core import ag_ctx from tensorflow.python.autograph.core import converter from tensorflow.python.autograph.core import converter_testing @@ -46,6 +47,7 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import def_function from tensorflow.python.eager import function from tensorflow.python.framework import constant_op +from tensorflow.python.framework import errors as tf_errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import math_ops @@ -1260,6 +1262,35 @@ class ApiTest(test.TestCase): self.assertEqual(5, tc.two_args(2)) + def test_raise_from_func_graph(self): + + @def_function.function + def raise_from_tf_function(n): + _errors_test_helper.TestRaiseFromStatus(n) + + for code, expected_exception in [ + (1, tf_errors.CancelledError), + (2, tf_errors.UnknownError), + (3, tf_errors.InvalidArgumentError), + (4, tf_errors.DeadlineExceededError), + (5, tf_errors.NotFoundError), + (6, tf_errors.AlreadyExistsError), + (7, tf_errors.PermissionDeniedError), + (16, tf_errors.UnauthenticatedError), + (8, tf_errors.ResourceExhaustedError), + (9, tf_errors.FailedPreconditionError), + (10, tf_errors.AbortedError), + (11, tf_errors.OutOfRangeError), + (12, tf_errors.UnimplementedError), + (13, tf_errors.InternalError), + (14, tf_errors.UnavailableError), + (15, tf_errors.DataLossError), + ]: + with self.assertRaises(expected_exception) as error: + raise_from_tf_function(code) + self.assertEqual(error.exception.experimental_payloads['key1'], 'value1') + self.assertEqual(error.exception.experimental_payloads['key2'], 'value2') + if __name__ == '__main__': os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '1' diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index fae583e878b..46f986d43de 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -33,7 +33,7 @@ from tensorflow.python.util.tf_export import tf_export # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2021, 2, 26) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2021, 3, 3) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None diff --git a/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py index ec9157d20f8..8af00e6860b 100644 --- a/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py +++ b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py @@ -163,6 +163,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): self._trt_test_params = None self._disable_non_trt_optimizers = False self._use_implicit_batch = True + self._profile_strategy = "Unknown" def setUp(self): """Setup method.""" @@ -264,8 +265,9 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): def DisableNonTrtOptimizers(self): self._disable_non_trt_optimizers = True - def DisableImplicitBatchMode(self): + def SetDynamicShapeModeAndProfileStrategy(self, profile_strategy="Range"): self._use_implicit_batch = False + self._profile_strategy = profile_strategy def GetParams(self): """Returns a TfTrtIntegrationTestParams for the test.""" @@ -453,11 +455,11 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): if run_params.is_v2: converter_v2 = trt_convert.TrtGraphConverterV2( input_saved_model_dir=saved_model_dir, - conversion_params=conversion_params) + conversion_params=conversion_params, + use_dynamic_shape=not self._use_implicit_batch, + dynamic_shape_profile_strategy=self._profile_strategy) if self._disable_non_trt_optimizers: converter_v2._test_only_disable_non_trt_optimizers = True # pylint: disable=protected-access - if not self._use_implicit_batch: - converter_v2._test_only_use_implicit_batch = False # pylint: disable=protected-access return converter_v2 converter_v1 = trt_convert.TrtGraphConverter( @@ -873,6 +875,10 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): all_op_names.append(node.name) if node.op == "TRTEngineOp": trt_op_names.append(node.name) + if not self._use_implicit_batch: + self.assertEqual( + self._ToString(node.attr["profile_strategy"].s).lower(), + self._profile_strategy.lower()) all_op_names = self._Canonicalize(all_op_names) trt_op_names = self._RemoveGraphSequenceNumber( diff --git a/tensorflow/python/compiler/tensorrt/test/trt_mode_test.py b/tensorflow/python/compiler/tensorrt/test/trt_mode_test.py index f5eb3c75653..b96d9b3b586 100644 --- a/tensorflow/python/compiler/tensorrt/test/trt_mode_test.py +++ b/tensorflow/python/compiler/tensorrt/test/trt_mode_test.py @@ -125,8 +125,8 @@ class ExplicitBatchTest(TrtModeTestBase): def setUp(self): super().setUp() - # Diable implicit batch mode for testing explicit batch mode. - self.DisableImplicitBatchMode() + self.SetDynamicShapeModeAndProfileStrategy( + profile_strategy="ImplicitBatchModeCompatible") class DynamicShapesTest(TrtModeTestBase): @@ -162,7 +162,9 @@ class DynamicShapesTest(TrtModeTestBase): def setUp(self): super().setUp() - self.DisableImplicitBatchMode() + self.SetDynamicShapeModeAndProfileStrategy( + profile_strategy="ImplicitBatchModeCompatible") + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/compiler/tensorrt/trt_convert.py b/tensorflow/python/compiler/tensorrt/trt_convert.py index d1579f244ac..a16e9ad429c 100644 --- a/tensorflow/python/compiler/tensorrt/trt_convert.py +++ b/tensorflow/python/compiler/tensorrt/trt_convert.py @@ -114,6 +114,19 @@ class TrtPrecisionMode(object): # so it can produce reasonable performance results with the default. DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES = 1 << 30 +PROFILE_STRATEGY_RANGE = "Range" +PROFILE_STRATEGY_OPTIMAL = "Optimal" +PROFILE_STRATEGY_RANGE_OPTIMAL = "Range+Optimal" +PROFILE_STRATEGY_IMPLICIT_BATCH_MODE_COMPATIBLE = "ImplicitBatchModeCompatible" + + +def supported_profile_strategies(): + return [ + PROFILE_STRATEGY_RANGE, PROFILE_STRATEGY_OPTIMAL, + PROFILE_STRATEGY_RANGE_OPTIMAL, + PROFILE_STRATEGY_IMPLICIT_BATCH_MODE_COMPATIBLE + ] + @tf_export("experimental.tensorrt.ConversionParams", v1=[]) class TrtConversionParams( @@ -234,7 +247,8 @@ def _get_tensorrt_rewriter_config(conversion_params, max_batch_size=None, is_v2=False, disable_non_trt_optimizers=False, - use_implicit_batch=True): + use_implicit_batch=True, + profile_strategy=PROFILE_STRATEGY_RANGE): """Returns a RewriterConfig proto for TRT transformation. Args: @@ -244,6 +258,7 @@ def _get_tensorrt_rewriter_config(conversion_params, is_v2: whether we're getting a RewriterConfig for TF 2.0. disable_non_trt_optimizers: Turn off all default Grappler optimizers. use_implicit_batch: Whether to use implicit batch or explicit batch. + profile_strategy: dynamic shape optimization profile strategy. Returns: A RewriterConfig proto which sets a TensorRTOptimizer to run Grappler. @@ -301,6 +316,11 @@ def _get_tensorrt_rewriter_config(conversion_params, if max_batch_size is not None: optimizer.parameter_map["max_batch_size"].i = max_batch_size optimizer.parameter_map["use_implicit_batch"].b = use_implicit_batch + # While we accept case insensitive strings from the users, we only pass the + # strings in lower cases to TF-TRT converter. + if not use_implicit_batch: + optimizer.parameter_map["profile_strategy"].s = _to_bytes( + profile_strategy.lower()) # Disabling optimizers should happen after defining the TF-TRT grappler pass # otherwise the template can overwrite the disablement. @@ -921,12 +941,35 @@ class TrtGraphConverterV2(object): # Save the TRT engine and the engines. converter.save(output_saved_model_dir) ``` + 4. To use dynamic shape, we need to call the build method with an input + function to generate profiles. This step is similar to the INT8 calibration + step described above. The converter also needs to be created with + use_dynamic_shape=True and one of the following profile_strategies for + creating profiles based on the inputs produced by the input function: + * `Range`: create one profile that works for inputs with dimension values + in the range of [min_dims, max_dims] where min_dims and max_dims are + derived from the provided inputs. + * `Optimal`: create one profile for each input. The profile only works for + inputs with the same dimensions as the input it is created for. The GPU + engine will be run with optimal performance with such inputs. + * `Range+Optimal`: create the profiles for both `Range` and `Optimal`. + * `ImplicitBatchModeCompatible`: create the profiles that will produce the + same GPU engines as the implicit_batch_mode would produce. """ + def _verify_profile_strategy(self, strategy): + supported_strategies = [s.lower() for s in supported_profile_strategies()] + if strategy.lower() not in supported_strategies: + raise ValueError( + ("profile_strategy '{}' is not supported. It should be one of {}" + ).format(strategy, supported_profile_strategies())) + def __init__(self, input_saved_model_dir=None, input_saved_model_tags=None, input_saved_model_signature_key=None, + use_dynamic_shape=None, + dynamic_shape_profile_strategy=None, conversion_params=None): """Initialize the converter. @@ -936,6 +979,11 @@ class TrtGraphConverterV2(object): input_saved_model_tags: list of tags to load the SavedModel. input_saved_model_signature_key: the key of the signature to optimize the graph for. + use_dynamic_shape: whether to enable dynamic shape support. None is + equivalent to False in the current implementation. + dynamic_shape_profile_strategy: one of the strings in + supported_profile_strategies(). None is equivalent to Range in the + current implementation. conversion_params: a TrtConversionParams instance. Raises: @@ -963,12 +1011,24 @@ class TrtGraphConverterV2(object): self._converted = False self._build_called_once = False + if use_dynamic_shape is None: + self._use_dynamic_shape = False + else: + self._use_dynamic_shape = use_dynamic_shape + + self._profile_strategy = "Unknown" + if self._use_dynamic_shape: + if dynamic_shape_profile_strategy is None: + self._profile_strategy = PROFILE_STRATEGY_RANGE + else: + self._verify_profile_strategy(dynamic_shape_profile_strategy) + self._profile_strategy = dynamic_shape_profile_strategy + # Fields to support TF-TRT testing and shouldn't be used for other purpose. self._test_only_disable_non_trt_optimizers = False - self._test_only_use_implicit_batch = True def _need_trt_profiles(self): - return not self._test_only_use_implicit_batch + return self._use_dynamic_shape def _run_conversion(self, meta_graph_def): """Run Grappler's OptimizeGraph() tool to convert the graph. @@ -985,7 +1045,8 @@ class TrtGraphConverterV2(object): is_dynamic_op=True, max_batch_size=None, disable_non_trt_optimizers=self._test_only_disable_non_trt_optimizers, - use_implicit_batch=self._test_only_use_implicit_batch) + use_implicit_batch=not self._use_dynamic_shape, + profile_strategy=self._profile_strategy) grappler_session_config.graph_options.rewrite_options.CopyFrom( custom_rewriter_config) return tf_optimizer.OptimizeGraph( diff --git a/tensorflow/python/data/experimental/kernel_tests/data_service_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/data_service_ops_test.py index 50672da090b..cfe4a713754 100644 --- a/tensorflow/python/data/experimental/kernel_tests/data_service_ops_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/data_service_ops_test.py @@ -38,6 +38,7 @@ from tensorflow.python.framework import random_seed from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_spec from tensorflow.python.ops import array_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import sparse_ops @@ -107,6 +108,18 @@ class DataServiceOpsTest(data_service_test_base.TestBase, self.assertAllEqual(results[1], [[0, 1, 2], [0, 1, 0]]) self.assertAllEqual(results[2], [[0, 1, 2, 3, 4, 5, 6, 7]]) + @combinations.generate(test_base.eager_only_combinations()) + def testDistributeLookupTable(self): + cluster = data_service_test_base.TestCluster(num_workers=1) + keys_tensor = constant_op.constant([1, 2]) + vals_tensor = constant_op.constant([11, 12]) + table = lookup_ops.StaticHashTable( + lookup_ops.KeyValueTensorInitializer(keys_tensor, vals_tensor), -1) + ds = dataset_ops.Dataset.range(3, output_type=dtypes.int32) + ds = ds.map(table.lookup) + ds = self.make_distributed_dataset(ds, cluster) + self.assertDatasetProduces(ds, [-1, 11, 12]) + @combinations.generate(test_base.eager_only_combinations()) def testDifferentShuffleOrders(self): random_seed.set_random_seed(None) diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD b/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD index 22af87ff0b9..9cf0ae0d58a 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD @@ -7,30 +7,6 @@ package( exports_files(["LICENSE"]) -py_library( - name = "dataset_serialization_test_base", - srcs = [ - "dataset_serialization_test_base.py", - ], - srcs_version = "PY3", - deps = [ - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:lookup_ops", - "//tensorflow/python:platform", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python:variables", - "//tensorflow/python/data/experimental/ops:iterator_ops", - "//tensorflow/python/data/kernel_tests:test_base", - "//tensorflow/python/data/ops:iterator_ops", - "//third_party/py/numpy", - ], -) - tf_py_test( name = "assert_cardinality_dataset_serialization_test", size = "small", @@ -41,9 +17,9 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/ops:dataset_ops", "@absl_py//absl/testing:parameterized", ], @@ -59,13 +35,13 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:sparse_tensor", "//tensorflow/python/data/experimental/ops:distribute", "//tensorflow/python/data/experimental/ops:interleave_ops", "//tensorflow/python/data/experimental/ops:readers", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], @@ -81,11 +57,11 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:sparse_tensor", "//tensorflow/python/data/experimental/ops:batching", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], @@ -101,9 +77,9 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/ops:dataset_ops", "@absl_py//absl/testing:parameterized", ], @@ -142,12 +118,12 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:math_ops", "//tensorflow/python/data/experimental/ops:batching", "//tensorflow/python/data/experimental/ops:optimization", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/ops:dataset_ops", ], ) @@ -162,9 +138,9 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python/data/experimental/ops:optimization", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/ops:dataset_ops", ], ) @@ -179,8 +155,8 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], @@ -197,8 +173,8 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "@absl_py//absl/testing:parameterized", @@ -215,10 +191,10 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_ops", "//tensorflow/python/data/experimental/ops:readers", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", ], ) @@ -232,9 +208,9 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], @@ -250,10 +226,10 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python:math_ops", "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], @@ -270,9 +246,9 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python/data/experimental/kernel_tests:reader_dataset_ops_test_base", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/ops:readers", ], ) @@ -287,7 +263,6 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", @@ -298,6 +273,7 @@ tf_py_test( "//tensorflow/python:sparse_ops", "//tensorflow/python:sparse_tensor", "//tensorflow/python:variable_scope", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/ops:dataset_ops", ], ) @@ -312,9 +288,9 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python/data/experimental/ops:grouping", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], @@ -330,9 +306,9 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python/data/experimental/ops:grouping", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], @@ -349,10 +325,10 @@ tf_py_test( "noasan", # TODO(b/337374867) fails with -fsanitize=null ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python/data/experimental/ops:error_ops", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], @@ -368,10 +344,10 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python:sparse_ops", "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", @@ -388,10 +364,10 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python:math_ops", "//tensorflow/python/data/experimental/ops:batching", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/ops:dataset_ops", ], ) @@ -406,7 +382,6 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", @@ -416,6 +391,7 @@ tf_py_test( "//tensorflow/python:random_ops", "//tensorflow/python:sparse_tensor", "//tensorflow/python:variable_scope", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], @@ -429,9 +405,9 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python/data/experimental/ops:matching_files", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], @@ -447,9 +423,9 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python/data/experimental/ops:optimization", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/ops:dataset_ops", ], ) @@ -464,9 +440,9 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python/data/experimental/ops:distribute", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/ops:dataset_ops", ], ) @@ -481,10 +457,10 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:string_ops", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], @@ -500,11 +476,11 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python:sparse_ops", "//tensorflow/python:sparse_tensor", "//tensorflow/python/data/experimental/ops:interleave_ops", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], @@ -520,7 +496,6 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", @@ -529,6 +504,7 @@ tf_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:random_ops", "//tensorflow/python:variable_scope", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], @@ -544,9 +520,9 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python/data/experimental/kernel_tests:reader_dataset_ops_test_base", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", ], ) @@ -560,8 +536,8 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/ops:dataset_ops", ], ) @@ -576,7 +552,6 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", @@ -585,6 +560,7 @@ tf_py_test( "//tensorflow/python:io_ops", "//tensorflow/python:parsing_ops", "//tensorflow/python:variables", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/ops:dataset_ops", ], ) @@ -599,9 +575,9 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python/data/experimental/ops:interleave_ops", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/ops:dataset_ops", ], ) @@ -616,9 +592,9 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python/data/experimental/ops:scan_ops", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/ops:dataset_ops", ], ) @@ -634,8 +610,8 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], @@ -671,8 +647,8 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/ops:dataset_ops", "@absl_py//absl/testing:parameterized", ], @@ -688,9 +664,9 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python/data/experimental/ops:shuffle_ops", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/ops:dataset_ops", ], ) @@ -705,11 +681,11 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_ops", "//tensorflow/python:training", "//tensorflow/python/data/experimental/ops:iterator_ops", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/ops:dataset_ops", ], ) @@ -724,12 +700,12 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python/data/experimental/kernel_tests:sql_dataset_test_base", "//tensorflow/python/data/experimental/ops:readers", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", ], ) @@ -743,12 +719,12 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_ops", "//tensorflow/python/data/experimental/ops:stats_aggregator", "//tensorflow/python/data/experimental/ops:stats_ops", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/ops:dataset_ops", ], ) @@ -763,9 +739,9 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python/data/experimental/ops:take_while_ops", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/ops:dataset_ops", "@absl_py//absl/testing:parameterized", ], @@ -782,9 +758,9 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python/data/experimental/kernel_tests:reader_dataset_ops_test_base", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/ops:readers", ], ) @@ -800,9 +776,9 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python/data/experimental/kernel_tests:reader_dataset_ops_test_base", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/ops:readers", ], ) @@ -817,9 +793,9 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python/data/experimental/ops:batching", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], @@ -835,9 +811,9 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python/data/experimental/ops:unique", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/ops:dataset_ops", ], ) @@ -852,8 +828,8 @@ tf_py_test( "no_windows", ], deps = [ - ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/assert_cardinality_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/assert_cardinality_dataset_serialization_test.py index 59332e31802..ef6892648b9 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/assert_cardinality_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/assert_cardinality_dataset_serialization_test.py @@ -12,24 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the AssertCardinalityDataset serialization.""" +"""Tests for checkpointing the AssertCardinalityDataset.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from absl.testing import parameterized -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.experimental.ops import cardinality +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import combinations from tensorflow.python.platform import test -class AssertCardinalityDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class AssertCardinalityDatasetCheckpointTest( + checkpoint_test_base.CheckpointTestBase, parameterized.TestCase): @combinations.generate(test_base.default_test_combinations()) def testCardinality(self): diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/auto_shard_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/auto_shard_dataset_serialization_test.py index 195181d14c6..479b45e7610 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/auto_shard_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/auto_shard_dataset_serialization_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the _AutoShard dataset serialization.""" +"""Tests for checkpointing the _AutoShardDataset.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -21,9 +21,9 @@ import os from absl.testing import parameterized -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.experimental.ops import distribute from tensorflow.python.data.experimental.ops import interleave_ops +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers @@ -33,9 +33,8 @@ from tensorflow.python.platform import test from tensorflow.python.util import compat -class AutoShardDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class AutoShardDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase, + parameterized.TestCase): def _record(self, f, r): return compat.as_bytes("Record %d of file %d" % (r, f)) diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/batch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/batch_dataset_serialization_test.py index f6603a4090b..0c03a01d2d4 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/batch_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/batch_dataset_serialization_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the BatchDataset serialization.""" +"""Tests for checkpointing the BatchDataset.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -20,8 +20,8 @@ from __future__ import print_function from absl.testing import parameterized import numpy as np -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.experimental.ops import batching +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import combinations @@ -30,9 +30,8 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class BatchDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class BatchDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase, + parameterized.TestCase): def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2): components = ( diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/cache_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/cache_dataset_serialization_test.py index a82a81b9e06..f79f7059f0c 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/cache_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/cache_dataset_serialization_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the CacheDataset serialization.""" +"""Tests for checkpointing the CacheDataset.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -21,7 +21,7 @@ import os from absl.testing import parameterized -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import combinations @@ -29,9 +29,8 @@ from tensorflow.python.framework import errors from tensorflow.python.platform import test -class CacheDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class CacheDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase, + parameterized.TestCase): def setUp(self): self.range_size = 10 diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/choose_fastest_branch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/choose_fastest_branch_dataset_serialization_test.py index 325dbc9187d..7da6546f46c 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/choose_fastest_branch_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/choose_fastest_branch_dataset_serialization_test.py @@ -12,15 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the ChooseFastestBranchDataset serialization.""" +"""Tests for checkpointing the ChooseFastestBranchDataset.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from absl.testing import parameterized -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.experimental.ops import optimization +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import combinations @@ -30,9 +30,8 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class ChooseFastestBranchDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class ChooseFastestBranchDatasetCheckpointTest( + checkpoint_test_base.CheckpointTestBase, parameterized.TestCase): @combinations.generate(test_base.default_test_combinations()) def testCore(self): diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/choose_fastest_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/choose_fastest_dataset_serialization_test.py index cdd2edfd617..0c4de603d02 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/choose_fastest_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/choose_fastest_dataset_serialization_test.py @@ -12,24 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the ChooseFastestDataset serialization.""" +"""Tests for checkpointing the ChooseFastestDataset.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from absl.testing import parameterized -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.experimental.ops import optimization +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import combinations from tensorflow.python.platform import test -class ChooseFastestDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class ChooseFastestDatasetCheckpointTest( + checkpoint_test_base.CheckpointTestBase, parameterized.TestCase): @combinations.generate(test_base.default_test_combinations()) def testCore(self): diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/concatenate_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/concatenate_dataset_serialization_test.py index 0e3bc637274..45696da74c9 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/concatenate_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/concatenate_dataset_serialization_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the ConcatenateDataset serialization.""" +"""Tests for checkpointing the ConcatenateDataset.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -20,16 +20,15 @@ from __future__ import print_function from absl.testing import parameterized import numpy as np -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import combinations from tensorflow.python.platform import test -class ConcatenateDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class ConcatenateDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase, + parameterized.TestCase): def _build_concatenate_dataset(self, var_array): input_components = (np.tile(np.array([[1], [2], [3], [4]]), 20), diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/csv_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/csv_dataset_serialization_test.py index 1540e67119e..f80449877ca 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/csv_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/csv_dataset_serialization_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the CsvDataset serialization.""" +"""Tests for checkpointing the CsvDataset.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -22,16 +22,15 @@ import os from absl.testing import parameterized -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.experimental.ops import readers from tensorflow.python.framework import combinations from tensorflow.python.platform import test -class CsvDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class CsvDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase, + parameterized.TestCase): def setUp(self): self._num_cols = 7 diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_constructor_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_constructor_serialization_test.py index 88fa7d4e022..d5979b9f01f 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_constructor_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_constructor_serialization_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the dataset constructors serialization.""" +"""Tests for checkpointing the dataset constructors.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -20,7 +20,7 @@ from __future__ import print_function from absl.testing import parameterized import numpy as np -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import combinations @@ -28,9 +28,8 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.platform import test -class FromTensorsSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class FromTensorsCheckpointTest(checkpoint_test_base.CheckpointTestBase, + parameterized.TestCase): def _build_tensor_dataset(self, variable_array): components = (variable_array, np.array([1, 2, 3]), np.array(37.0)) @@ -46,9 +45,8 @@ class FromTensorsSerializationTest( num_outputs) -class FromTensorSlicesSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class FromTensorSlicesCheckpointTest(checkpoint_test_base.CheckpointTestBase, + parameterized.TestCase): def _build_tensor_slices_dataset(self, components): return dataset_ops.Dataset.from_tensor_slices(components) @@ -68,9 +66,8 @@ class FromTensorSlicesSerializationTest( lambda: self._build_tensor_slices_dataset(dict_components), 3) -class FromSparseTensorSlicesSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class FromSparseTensorSlicesCheckpointTest( + checkpoint_test_base.CheckpointTestBase, parameterized.TestCase): def _build_sparse_tensor_slice_dataset(self, slices): indices = np.array( diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/filter_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/filter_dataset_serialization_test.py index 76675fcacbe..64e16912784 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/filter_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/filter_dataset_serialization_test.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the FilterDataset serialization.""" +"""Tests for checkpointing the FilterDataset.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from absl.testing import parameterized -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import combinations @@ -28,9 +28,8 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class FilterDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class FilterDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase, + parameterized.TestCase): def _build_filter_range_graph(self, div): return dataset_ops.Dataset.range(100).filter( diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py index 40ebc8a05bf..9f0cec8f759 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the FixedLengthRecordDataset serialization.""" +"""Tests for checkpointing the FixedLengthRecordDataset.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -20,17 +20,16 @@ from __future__ import print_function from absl.testing import parameterized from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import readers as core_readers from tensorflow.python.framework import combinations from tensorflow.python.platform import test -class FixedLengthRecordDatasetSerializationTest( +class FixedLengthRecordDatasetCheckpointTest( reader_dataset_ops_test_base.FixedLengthRecordDatasetTestBase, - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): + checkpoint_test_base.CheckpointTestBase, parameterized.TestCase): def _build_iterator_graph(self, num_epochs, compression_type=None): filenames = self._createFiles() diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/flat_map_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/flat_map_dataset_serialization_test.py index bfe9521f9c5..8325916ec52 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/flat_map_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/flat_map_dataset_serialization_test.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the FlatMapDataset serialization.""" +"""Tests for checkpointing the FlatMapDataset.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from absl.testing import parameterized -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import combinations @@ -35,9 +35,8 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test -class FlatMapDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class FlatMapDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase, + parameterized.TestCase): @combinations.generate(test_base.default_test_combinations()) def testCore(self): diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_reducer_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_reducer_serialization_test.py index 3763c1decd2..95076c6ee00 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_reducer_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_reducer_serialization_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the GroupByReducer serialization.""" +"""Tests for checkpointing the GroupByReducer.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -20,17 +20,16 @@ from __future__ import print_function from absl.testing import parameterized import numpy as np -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.experimental.ops import grouping +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import combinations from tensorflow.python.platform import test -class GroupByReducerSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class GroupByReducerCheckpointTest(checkpoint_test_base.CheckpointTestBase, + parameterized.TestCase): def _build_dataset(self, components): reducer = grouping.Reducer( diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_window_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_window_serialization_test.py index eaa416dc2fe..46a4fbb42fd 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_window_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_window_serialization_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the GroupByWindow serialization.""" +"""Tests for checkpointing the GroupByWindow datasets.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -20,17 +20,16 @@ from __future__ import print_function from absl.testing import parameterized import numpy as np -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.experimental.ops import grouping +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import combinations from tensorflow.python.platform import test -class GroupByWindowSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class GroupByWindowCheckpointTest(checkpoint_test_base.CheckpointTestBase, + parameterized.TestCase): def _build_dataset(self, components): return dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply( diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/ignore_errors_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/ignore_errors_serialization_test.py index 3c2e9276ca0..d051abe2554 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/ignore_errors_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/ignore_errors_serialization_test.py @@ -12,15 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the IgnoreErrors input pipeline ops.""" +"""Tests for checkpointing the IgnoreErrors input pipeline ops datasets.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from absl.testing import parameterized -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.experimental.ops import error_ops +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import combinations @@ -28,9 +28,8 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class IgnoreErrorsSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class IgnoreErrorsCheckpointTest(checkpoint_test_base.CheckpointTestBase, + parameterized.TestCase): def _build_ds(self): return dataset_ops.Dataset.range(5).map( diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/interleave_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/interleave_dataset_serialization_test.py index ff3f238f34b..44220eec976 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/interleave_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/interleave_dataset_serialization_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the InterleaveDataset serialization.""" +"""Tests for checkpointing the InterleaveDataset.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -20,7 +20,7 @@ from __future__ import print_function from absl.testing import parameterized import numpy as np -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import combinations @@ -29,9 +29,8 @@ from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test -class InterleaveDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class InterleaveDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase, + parameterized.TestCase): def _build_iterator_graph(self, input_values, cycle_length, block_length, num_parallel_calls): diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py index 450cb24fb5b..3a9755f8a2c 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the MapAndBatchDataset serialization.""" +"""Tests for checkpointing the MapAndBatchDataset.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -21,8 +21,8 @@ import math from absl.testing import parameterized -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.experimental.ops import batching +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import combinations @@ -31,9 +31,8 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class MapAndBatchDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class MapAndBatchDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase, + parameterized.TestCase): @combinations.generate(test_base.default_test_combinations()) def testNumParallelBatches(self): diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/map_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/map_dataset_serialization_test.py index a81bccf1c4e..04ba37b230a 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/map_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/map_dataset_serialization_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the MapDataset serialization.""" +"""Tests for checkpointing the MapDataset.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -20,7 +20,7 @@ from __future__ import print_function from absl.testing import parameterized import numpy as np -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import combinations @@ -35,9 +35,8 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test -class MapDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class MapDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase, + parameterized.TestCase): def setUp(self): self._tensor_slice_len = 7 diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/matching_files_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/matching_files_dataset_serialization_test.py index 909bab89f66..9ce5b90bf76 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/matching_files_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/matching_files_dataset_serialization_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the MatchingFilesDataset serialization.""" +"""Tests for checkpointing the MatchingFilesDataset.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -23,16 +23,15 @@ import tempfile from absl.testing import parameterized -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.experimental.ops import matching_files +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.framework import combinations from tensorflow.python.platform import test -class MatchingFilesDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class MatchingFilesDatasetCheckpointTest( + checkpoint_test_base.CheckpointTestBase, parameterized.TestCase): def _build_iterator_graph(self, test_patterns): return matching_files.MatchingFilesDataset(test_patterns) diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/optimize_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/optimize_dataset_serialization_test.py index 30d53165f85..ac124fa86d5 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/optimize_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/optimize_dataset_serialization_test.py @@ -12,24 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the OptimizeDataset serialization.""" +"""Tests for checkpointing the OptimizeDataset.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from absl.testing import parameterized -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.experimental.ops import optimization +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import combinations from tensorflow.python.platform import test -class OptimizeDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class OptimizeDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase, + parameterized.TestCase): @combinations.generate(test_base.default_test_combinations()) def testCore(self): diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/padded_batch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/padded_batch_dataset_serialization_test.py index 956279cb7a5..17d605ba6bd 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/padded_batch_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/padded_batch_dataset_serialization_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the PaddedBatchDataset serialization.""" +"""Tests for checkpointing the PaddedBatchDataset.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -20,7 +20,7 @@ from __future__ import print_function from absl.testing import parameterized import numpy as np -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import combinations @@ -29,9 +29,8 @@ from tensorflow.python.ops import string_ops from tensorflow.python.platform import test -class PaddedBatchDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class PaddedBatchDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase, + parameterized.TestCase): @combinations.generate(test_base.default_test_combinations()) def testPaddedBatch(self): diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py index 79ee2937d8a..d89b715aa3a 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the ParallelInterleaveDataset serialization.""" +"""Tests for checkpointing the ParallelInterleaveDataset.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -20,8 +20,8 @@ from __future__ import print_function from absl.testing import parameterized import numpy as np -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.experimental.ops import interleave_ops +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import combinations @@ -30,9 +30,8 @@ from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test -class ParallelInterleaveDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class ParallelInterleaveDatasetCheckpointTest( + checkpoint_test_base.CheckpointTestBase, parameterized.TestCase): def setUp(self): self.input_values = np.array([4, 5, 6], dtype=np.int64) @@ -47,7 +46,7 @@ class ParallelInterleaveDatasetSerializationTest( cycle_length, block_length, sloppy))) @combinations.generate(test_base.default_test_combinations()) - def testSerializationCore(self): + def testCheckpointCore(self): # cycle_length > 1, block_length > 1 cycle_length = 2 block_length = 3 @@ -65,7 +64,7 @@ class ParallelInterleaveDatasetSerializationTest( self.num_outputs) @combinations.generate(test_base.default_test_combinations()) - def testSerializationWithSloppy(self): + def testCheckpointWithSloppy(self): break_points = self.gen_break_points(self.num_outputs, 10) expected_outputs = np.repeat( np.concatenate([np.arange(10 * x, 11 * x) for x in self.input_values]), diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_map_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_map_dataset_serialization_test.py index 48081d16ac4..0ab48e5b1a2 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_map_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_map_dataset_serialization_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the ParallelMapDataset serialization.""" +"""Tests for checkpointing the ParallelMapDataset.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -20,7 +20,7 @@ from __future__ import print_function from absl.testing import parameterized import numpy as np -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import combinations @@ -34,9 +34,8 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test -class ParallelMapDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class ParallelMapDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase, + parameterized.TestCase): def setUp(self): self._tensor_slice_len = 7 diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/parse_example_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/parse_example_dataset_serialization_test.py index 738fb1ecdbe..d459444fa0e 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/parse_example_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/parse_example_dataset_serialization_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the ParseExampleDataset serialization.""" +"""Tests for checkpointing the ParseExampleDataset.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -20,16 +20,15 @@ from __future__ import print_function from absl.testing import parameterized from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.framework import combinations from tensorflow.python.platform import test -class ParseExampleDatasetSerializationTest( +class ParseExampleDatasetCheckpointTest( reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase, - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): + checkpoint_test_base.CheckpointTestBase, parameterized.TestCase): def _parse_example_dataset(self, num_repeat, batch_size): return self.make_batch_feature( @@ -40,7 +39,7 @@ class ParseExampleDatasetSerializationTest( parser_num_threads=10) @combinations.generate(test_base.default_test_combinations()) - def testSerializationCore(self): + def testCheckpointCore(self): num_repeat = 5 batch_size = 2 num_outputs = self._num_records * self._num_files * num_repeat // batch_size diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/prefetch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/prefetch_dataset_serialization_test.py index 98b89fca6ff..da261e193e8 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/prefetch_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/prefetch_dataset_serialization_test.py @@ -12,23 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the PrefetchDataset serialization.""" +"""Tests for checkpointing the PrefetchDataset.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from absl.testing import parameterized -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import combinations from tensorflow.python.platform import test -class PrefetchDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class PrefetchDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase, + parameterized.TestCase): def build_dataset(self, seed): return dataset_ops.Dataset.range(100).prefetch(10).shuffle( diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/range_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/range_dataset_serialization_test.py index 557bdc72a20..8378ded1f49 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/range_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/range_dataset_serialization_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the RangeDataset serialization.""" +"""Tests for checkpointing the RangeDataset.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -21,7 +21,7 @@ import os from absl.testing import parameterized -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import combinations @@ -35,9 +35,8 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import test -class RangeDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class RangeDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase, + parameterized.TestCase): def _iterator_checkpoint_prefix_local(self): return os.path.join(self.get_temp_dir(), "iterator") @@ -58,6 +57,7 @@ class RangeDatasetSerializationTest( iterator_state_variant) return restore_op + # TODO(vikoth18): implement eager mode checkpoint tests @combinations.generate( combinations.combine(tf_api_version=1, mode=["graph"])) def testSaveRestore(self): @@ -77,7 +77,7 @@ class RangeDatasetSerializationTest( break_point = 5 with ops.Graph().as_default() as g: init_op, get_next, save_op, _ = _build_graph(start, stop) - with self.session(graph=g) as sess: + with self.session(graph=g): self.evaluate(variables.global_variables_initializer()) self.evaluate(init_op) for i in range(start, break_point): @@ -86,7 +86,7 @@ class RangeDatasetSerializationTest( with ops.Graph().as_default() as g: init_op, get_next, _, restore_op = _build_graph(start, stop) - with self.session(graph=g) as sess: + with self.session(graph=g): self.evaluate(init_op) self.evaluate(restore_op) for i in range(break_point, stop): @@ -97,7 +97,7 @@ class RangeDatasetSerializationTest( # Saving and restoring in same session. with ops.Graph().as_default() as g: init_op, get_next, save_op, restore_op = _build_graph(start, stop) - with self.session(graph=g) as sess: + with self.session(graph=g): self.evaluate(variables.global_variables_initializer()) self.evaluate(init_op) for i in range(start, break_point): @@ -116,7 +116,6 @@ class RangeDatasetSerializationTest( def testRangeCore(self): start = 2 stop = 10 - stop_1 = 8 self.run_core_tests(lambda: self._build_range_dataset(start, stop), stop - start) diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/rebatch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/rebatch_dataset_serialization_test.py index fe4eac5b69d..1c50d80afa5 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/rebatch_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/rebatch_dataset_serialization_test.py @@ -12,24 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the _RebatchDataset serialization.""" +"""Tests for checkpointing the _RebatchDataset.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from absl.testing import parameterized -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.experimental.ops import distribute +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import combinations from tensorflow.python.platform import test -class LegacyRebatchDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class LegacyRebatchDatasetCheckpointTest( + checkpoint_test_base.CheckpointTestBase, parameterized.TestCase): @combinations.generate(test_base.default_test_combinations()) def testCore(self): @@ -43,9 +42,8 @@ class LegacyRebatchDatasetSerializationTest( self.run_core_tests(lambda: build_dataset(64, 8), 8) -class RebatchDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class RebatchDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase, + parameterized.TestCase): @combinations.generate(test_base.default_test_combinations()) def testCore(self): diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/sample_from_datasets_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/sample_from_datasets_serialization_test.py index 2dcc272615b..24fdee94f56 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/sample_from_datasets_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/sample_from_datasets_serialization_test.py @@ -12,24 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the SampleFromDatasets serialization.""" +"""Tests for checkpointing the SampleFromDatasets.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from absl.testing import parameterized -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.experimental.ops import interleave_ops +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import combinations from tensorflow.python.platform import test -class SampleFromDatasetsSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class SampleFromDatasetsCheckpointTest(checkpoint_test_base.CheckpointTestBase, + parameterized.TestCase): def _build_dataset(self, probs, num_samples): dataset = interleave_ops.sample_from_datasets( @@ -42,7 +41,7 @@ class SampleFromDatasetsSerializationTest( return dataset.take(num_samples) @combinations.generate(test_base.default_test_combinations()) - def testSerializationCore(self): + def testCheckpointCore(self): self.run_core_tests(lambda: self._build_dataset([0.5, 0.5], 100), 100) diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/scan_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/scan_dataset_serialization_test.py index 31e3e578402..2e21119919c 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/scan_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/scan_dataset_serialization_test.py @@ -12,24 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the ScanDataset serialization.""" +"""Tests for checkpointing the ScanDataset.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from absl.testing import parameterized -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.experimental.ops import scan_ops +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import combinations from tensorflow.python.platform import test -class ScanDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class ScanDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase, + parameterized.TestCase): def _build_dataset(self, num_elements): return dataset_ops.Dataset.from_tensors(1).repeat(num_elements).apply( diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/sequence_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/sequence_dataset_serialization_test.py index bab6c594072..ba317f66241 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/sequence_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/sequence_dataset_serialization_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the sequence datasets serialization.""" +"""Tests for checkpointing the sequence datasets.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -20,16 +20,15 @@ from __future__ import print_function from absl.testing import parameterized import numpy as np -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import combinations from tensorflow.python.platform import test -class SkipDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class SkipDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase, + parameterized.TestCase): def _build_skip_dataset(self, count): components = (np.arange(10),) @@ -58,9 +57,8 @@ class SkipDatasetSerializationTest( self.run_core_tests(lambda: self._build_skip_dataset([1, 2]), 0) -class TakeDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class TakeDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase, + parameterized.TestCase): def _build_take_dataset(self, count): components = (np.arange(10),) @@ -88,9 +86,8 @@ class TakeDatasetSerializationTest( self.run_core_tests(lambda: self._build_take_dataset([1, 2]), 0) -class RepeatDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class RepeatDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase, + parameterized.TestCase): def _build_repeat_dataset(self, count, take_count=3): components = (np.arange(10),) diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/shard_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/shard_dataset_serialization_test.py index 3745dad7d24..623964b3b3b 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/shard_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/shard_dataset_serialization_test.py @@ -12,23 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the ShardDataset serialization.""" +"""Tests for checkpointing the ShardDataset.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from absl.testing import parameterized -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import combinations from tensorflow.python.platform import test -class ShardDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class ShardDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase, + parameterized.TestCase): def _build_dataset(self, num_elements, num_shards, index): return dataset_ops.Dataset.range(num_elements).shard(num_shards, index) diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py index ae2715f51f1..69501f09bc8 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py @@ -12,24 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the ShuffleAndRepeatDataset serialization.""" +"""Tests for checkpointing the ShuffleAndRepeatDataset.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from absl.testing import parameterized -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.experimental.ops import shuffle_ops +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import combinations from tensorflow.python.platform import test -class ShuffleAndRepeatSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class ShuffleAndRepeatCheckpointTest(checkpoint_test_base.CheckpointTestBase, + parameterized.TestCase): def _build_ds(self, seed): return dataset_ops.Dataset.range(20).apply( diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_dataset_serialization_test.py index d11f0335549..a1c44c8c8e8 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_dataset_serialization_test.py @@ -12,15 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the ShuffleDataset serialization.""" +"""Tests for checkpointing the ShuffleDataset.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from absl.testing import parameterized -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import combinations @@ -29,9 +29,8 @@ from tensorflow.python.platform import test from tensorflow.python.training import saver as saver_lib -class ShuffleDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class ShuffleDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase, + parameterized.TestCase): def _build_shuffle_dataset( self, diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/snapshot_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/snapshot_dataset_serialization_test.py index d5daaacae9a..c8369c70eb7 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/snapshot_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/snapshot_dataset_serialization_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the MapDataset serialization.""" +"""Tests for checkpointing the SnapshotDataset.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -21,8 +21,8 @@ import os from absl.testing import parameterized -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.experimental.ops import snapshot +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import combinations @@ -30,9 +30,8 @@ from tensorflow.python.framework import ops from tensorflow.python.platform import test -class SnapshotDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class SnapshotDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase, + parameterized.TestCase): def _build_snapshot_dataset(self, repeat=False): @@ -121,9 +120,8 @@ class SnapshotDatasetSerializationTest( list(range(100)) + list(range(10)) + list(range(10, 100))) -class LegacySnapshotDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class LegacySnapshotDatasetCheckpointTest( + checkpoint_test_base.CheckpointTestBase, parameterized.TestCase): def _build_snapshot_dataset(self, num_threads=1, diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/sql_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/sql_dataset_serialization_test.py index 9094955e175..184e7f23d85 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/sql_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/sql_dataset_serialization_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the SqlDataset serialization.""" +"""Tests for checkpointing the SqlDataset.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -22,8 +22,8 @@ import os from absl.testing import parameterized from tensorflow.python.data.experimental.kernel_tests import sql_dataset_test_base -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.experimental.ops import readers +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.framework import combinations from tensorflow.python.framework import dtypes @@ -31,10 +31,9 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class SqlDatasetSerializationTest( - sql_dataset_test_base.SqlDatasetTestBase, - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class SqlDatasetCheckpointTest(sql_dataset_test_base.SqlDatasetTestBase, + checkpoint_test_base.CheckpointTestBase, + parameterized.TestCase): def _build_dataset(self, num_repeats): data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite") diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/stats_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/stats_dataset_serialization_test.py index 68bfd2aba35..6eb83d1f17e 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/stats_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/stats_dataset_serialization_test.py @@ -12,16 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the StatsDataset serialization.""" +"""Tests for checkpointing the StatsDataset.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from absl.testing import parameterized -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.experimental.ops import stats_aggregator from tensorflow.python.data.experimental.ops import stats_ops +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import combinations @@ -33,9 +33,8 @@ from tensorflow.python.platform import test # TODO(b/116814321): Can not checkpoint input_pipeline with the # transformation `stats_ops.set_stats_aggregator`, since we don't support # saving/restoring resources (StatsAggregator in this case) yet. -class StatsDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class StatsDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase, + parameterized.TestCase): def _build_dataset_bytes_stats(self, num_elements): return dataset_ops.Dataset.range(num_elements).map( diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/take_while_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/take_while_dataset_serialization_test.py index c189c13b458..6a8dc9b02e5 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/take_while_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/take_while_dataset_serialization_test.py @@ -12,24 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the TakeWhileDataset serialization.""" +"""Tests for checkpointing the TakeWhileDataset.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from absl.testing import parameterized -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.experimental.ops import take_while_ops +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import combinations from tensorflow.python.platform import test -class TakeWhileDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class TakeWhileDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase, + parameterized.TestCase): def _build_dataset(self, num_elements, upper_bound): return dataset_ops.Dataset.range(num_elements).apply( diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/textline_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/textline_dataset_serialization_test.py index 5203d75f095..65c64d0d6e3 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/textline_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/textline_dataset_serialization_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the TextLineDataset serialization.""" +"""Tests for checkpointing the TextLineDataset.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -20,17 +20,16 @@ from __future__ import print_function from absl.testing import parameterized from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import readers as core_readers from tensorflow.python.framework import combinations from tensorflow.python.platform import test -class TextLineDatasetSerializationTest( +class TextLineDatasetCheckpointTest( reader_dataset_ops_test_base.TextLineDatasetTestBase, - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): + checkpoint_test_base.CheckpointTestBase, parameterized.TestCase): def _build_iterator_graph(self, test_filenames, compression_type=None): return core_readers.TextLineDataset( diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/tf_record_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/tf_record_dataset_serialization_test.py index 3fa88bc1267..e1546ad90a5 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/tf_record_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/tf_record_dataset_serialization_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the TFRecordDataset serialization.""" +"""Tests for checkpointing the TFRecordDataset.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -24,17 +24,16 @@ import zlib from absl.testing import parameterized from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import readers as core_readers from tensorflow.python.framework import combinations from tensorflow.python.platform import test -class TFRecordDatasetSerializationTest( +class TFRecordDatasetCheckpointTest( reader_dataset_ops_test_base.TFRecordDatasetTestBase, - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): + checkpoint_test_base.CheckpointTestBase, parameterized.TestCase): def _build_iterator_graph(self, num_epochs, diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/unbatch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/unbatch_dataset_serialization_test.py index 6fa115bd2fe..1bbb9750e56 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/unbatch_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/unbatch_dataset_serialization_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the UnbatchDataset serialization.""" +"""Tests for checkpointing the UnbatchDataset.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -20,16 +20,15 @@ from __future__ import print_function from absl.testing import parameterized import numpy as np -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import combinations from tensorflow.python.platform import test -class UnbatchDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class UnbatchDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase, + parameterized.TestCase): def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2): components = ( diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/unique_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/unique_dataset_serialization_test.py index f9b77fe69e8..6db256f66c7 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/unique_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/unique_dataset_serialization_test.py @@ -12,24 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the UniqueDataset serialization.""" +"""Tests for checkpointing the UniqueDataset.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from absl.testing import parameterized -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.experimental.ops import unique +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import combinations from tensorflow.python.platform import test -class UniqueDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class UniqueDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase, + parameterized.TestCase): @combinations.generate(test_base.default_test_combinations()) def testUnique(self): diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/zip_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/zip_dataset_serialization_test.py index a1b7cfef093..b1f8a7abcf5 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/zip_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/zip_dataset_serialization_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the ZipDataset serialization.""" +"""Tests for checkpointing the ZipDataset.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -20,16 +20,15 @@ from __future__ import print_function from absl.testing import parameterized import numpy as np -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import combinations from tensorflow.python.platform import test -class ZipDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase, - parameterized.TestCase): +class ZipDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase, + parameterized.TestCase): def _build_dataset(self, arr): components = [ diff --git a/tensorflow/python/data/experimental/ops/threading_options.py b/tensorflow/python/data/experimental/ops/threading_options.py index 39da39353d6..9eb38da5439 100644 --- a/tensorflow/python/data/experimental/ops/threading_options.py +++ b/tensorflow/python/data/experimental/ops/threading_options.py @@ -48,7 +48,9 @@ class ThreadingOptions(options.OptionsBase): name="private_threadpool_size", ty=int, docstring= - "If set, the dataset will use a private threadpool of the given size.") + "If set, the dataset will use a private threadpool of the given size. " + "The value 0 can be used to indicate that the threadpool size should be " + "determined at runtime based on the number of available CPU cores.") def _to_proto(self): pb = dataset_options_pb2.ThreadingOptions() diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD index e6c0f3b6b71..5b3b26aa0eb 100644 --- a/tensorflow/python/data/kernel_tests/BUILD +++ b/tensorflow/python/data/kernel_tests/BUILD @@ -8,6 +8,30 @@ package( licenses = ["notice"], # Apache 2.0 ) +py_library( + name = "checkpoint_test_base", + srcs = [ + "checkpoint_test_base.py", + ], + srcs_version = "PY3", + deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:lookup_ops", + "//tensorflow/python:platform", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:training", + "//tensorflow/python:util", + "//tensorflow/python:variables", + "//tensorflow/python/data/experimental/ops:iterator_ops", + "//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python/data/ops:iterator_ops", + "//third_party/py/numpy", + ], +) + tf_py_test( name = "batch_test", size = "medium", diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py b/tensorflow/python/data/kernel_tests/checkpoint_test_base.py similarity index 94% rename from tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py rename to tensorflow/python/data/kernel_tests/checkpoint_test_base.py index 44fe30f6729..b810197afd4 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py +++ b/tensorflow/python/data/kernel_tests/checkpoint_test_base.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Base class for testing serializable datasets.""" +"""Base test class for checkpointing datasets.""" from __future__ import absolute_import from __future__ import division @@ -51,12 +51,12 @@ def remove_variants(get_next_op): return nest.map_structure(_remove_variant, get_next_op) -class DatasetSerializationTestBase(test.TestCase): - """Base class for testing serializable datasets.""" +class CheckpointTestBase(test.TestCase): + """Base test class for checkpointing datasets.""" def tearDown(self): self._delete_ckpt() - super(DatasetSerializationTestBase, self).tearDown() + super(CheckpointTestBase, self).tearDown() # TODO(b/72657739): Remove sparse_tensor argument, which is to test the # (deprecated) saveable `SparseTensorSliceDataset`, once the API @@ -72,7 +72,7 @@ class DatasetSerializationTestBase(test.TestCase): Raises: AssertionError if any test fails. """ - # NOTE: We disable all default optimizations in serialization tests in order + # NOTE: We disable all default optimizations in checkpoint tests in order # to test the actual dataset in question. options = dataset_ops.Options() options.experimental_optimization.apply_default_optimizations = False @@ -113,7 +113,9 @@ class DatasetSerializationTestBase(test.TestCase): sparse_tensors=sparse_tensors, verify_exhausted=verify_exhausted) - def verify_fully_used_iterator(self, ds_fn, num_outputs, + def verify_fully_used_iterator(self, + ds_fn, + num_outputs, sparse_tensors=False): """Verifies that saving and restoring a fully used iterator works. @@ -170,8 +172,8 @@ class DatasetSerializationTestBase(test.TestCase): Args: ds_fn: See `run_core_tests`. num_outputs: See `run_core_tests`. - num_breaks: The number of break points. These are uniformly spread in - [0, num_outputs] both inclusive. + num_breaks: The number of break points. These are uniformly spread in [0, + num_outputs] both inclusive. sparse_tensors: See `run_core_tests`. verify_exhausted: See `gen_outputs`. @@ -222,6 +224,7 @@ class DatasetSerializationTestBase(test.TestCase): verify_exhausted=False) actual = [] + # TODO(vikoth18): implement eager mode compatible checkpointing # Restore from checkpoint and then run init_op. with ops.Graph().as_default() as g: saver = self._import_meta_graph() @@ -259,6 +262,7 @@ class DatasetSerializationTestBase(test.TestCase): """ break_point = num_outputs // 2 if not break_point else break_point + # TODO(vikoth18): implement eager mode compatible checkpointing with ops.Graph().as_default() as g: init_op, get_next_op, saver = self._build_graph( ds_fn, sparse_tensors=sparse_tensors) @@ -327,10 +331,10 @@ class DatasetSerializationTestBase(test.TestCase): ds_fn: 0-argument function that returns the dataset. break_points: A list of integers. For each `break_point` in `break_points`, we produce outputs till `break_point` number of items - have been produced and then checkpoint the state. The current graph - and session are destroyed and a new graph and session are used to - produce outputs till next checkpoint or till `num_outputs` elements - have been produced. `break_point` must be <= `num_outputs`. + have been produced and then checkpoint the state. The current graph and + session are destroyed and a new graph and session are used to produce + outputs till next checkpoint or till `num_outputs` elements have been + produced. `break_point` must be <= `num_outputs`. num_outputs: The total number of outputs to produce from the iterator. ckpt_saved: Whether a checkpoint already exists. sparse_tensors: Whether dataset is built from SparseTensor(s). @@ -356,6 +360,7 @@ class DatasetSerializationTestBase(test.TestCase): ds_fn, sparse_tensors=sparse_tensors) return init_op, get_next_op, saver + # TODO(vikoth18): implement eager mode compatible checkpointing for i in range(len(break_points) + 1): with ops.Graph().as_default() as g: init_op, get_next_op, saver = get_ops() @@ -485,14 +490,17 @@ class DatasetSerializationTestBase(test.TestCase): return all_ops[0], nest.pack_sequence_as( self._get_output_types(ds_fn), get_next_list) + # TODO(vikoth18): replace with `element_spec` and add eager mode support def _get_output_types(self, ds_fn): with ops.Graph().as_default(): return dataset_ops.get_legacy_output_types(ds_fn()) + # TODO(vikoth18): replace with `element_spec` and add eager mode support def _get_output_shapes(self, ds_fn): with ops.Graph().as_default(): return dataset_ops.get_legacy_output_shapes(ds_fn()) + # TODO(vikoth18): replace with `element_spec` and add eager mode support def _get_output_classes(self, ds_fn): with ops.Graph().as_default(): return dataset_ops.get_legacy_output_classes(ds_fn()) diff --git a/tensorflow/python/data/kernel_tests/dataset_test.py b/tensorflow/python/data/kernel_tests/dataset_test.py index e5d616b5482..829b1d3610a 100644 --- a/tensorflow/python/data/kernel_tests/dataset_test.py +++ b/tensorflow/python/data/kernel_tests/dataset_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import collections +import os import warnings from absl.testing import parameterized @@ -43,6 +44,7 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.ops import array_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import random_ops from tensorflow.python.platform import test @@ -64,6 +66,33 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase): dataset._as_serialized_graph(external_state_policy=distribute_options .ExternalStatePolicy.FAIL)) + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + combinations.combine(init_from_file=[True, False]))) + def testLookupTableGraphSerialization(self, init_from_file): + if init_from_file: + file = os.path.join(self.get_temp_dir(), "lookup_table_graph_serialize") + with open(file, "w") as f: + f.write("10\n11\n") + initializer = lookup_ops.TextFileInitializer( + file, dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER, + dtypes.int64, lookup_ops.TextFileIndex.WHOLE_LINE) + else: + keys_tensor = constant_op.constant([0, 1], dtype=dtypes.int64) + vals_tensor = constant_op.constant([10, 11]) + initializer = lookup_ops.KeyValueTensorInitializer( + keys_tensor, vals_tensor) + + table = lookup_ops.StaticHashTable(initializer, -1) + dataset = dataset_ops.Dataset.range(3) + dataset = dataset.map(table.lookup) + self.evaluate(lookup_ops.tables_initializer()) + round_tripped = self.graphRoundTrip(dataset) + del table + del dataset + self.assertDatasetProduces( + round_tripped, [10, 11, -1], requires_initialization=True) + @combinations.generate(test_base.default_test_combinations()) def testAsFunctionWithMap(self): if not context.executing_eagerly(): diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 7c9316a6b2a..5a119418dc1 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -1541,16 +1541,26 @@ cuda_py_test( ":collective_all_reduce_strategy", ":combinations", ":distribute_lib", + ":distribute_utils", ":strategy_combinations", ":values", - "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", + "//tensorflow/python:config", "//tensorflow/python:constant_op", - "//tensorflow/python:framework_test_lib", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:func_graph", + "//tensorflow/python:math_ops", + "//tensorflow/python:rnn_cell", "//tensorflow/python:state_ops", "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", "//tensorflow/python/eager:context", + "//tensorflow/python/eager:def_function", "//tensorflow/python/eager:test", + "//tensorflow/python/saved_model:load", + "//tensorflow/python/saved_model:save", + "//tensorflow/python/training/tracking:util", ], ) @@ -1907,9 +1917,9 @@ tf_py_test( name = "parameter_server_strategy_v2_test", srcs = ["parameter_server_strategy_v2_test.py"], python_version = "PY3", + shard_count = 5, tags = [ "no_windows", # TODO(171349346) - "notsan", # TODO(b/168675975) ], deps = [ ":multi_worker_test_base", diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy.py b/tensorflow/python/distribute/collective_all_reduce_strategy.py index 46b69543b67..a45cbe4ef0a 100644 --- a/tensorflow/python/distribute/collective_all_reduce_strategy.py +++ b/tensorflow/python/distribute/collective_all_reduce_strategy.py @@ -521,7 +521,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): def _input_workers_with_options(self, options=None): host_device = device_util.get_host_for_device(self._worker_device) - if not options or options.experimental_prefetch_to_device: + if not options or options.experimental_fetch_to_device: return input_lib.InputWorkers([(host_device, self.worker_devices)]) else: return input_lib.InputWorkers([( diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy_test.py b/tensorflow/python/distribute/collective_all_reduce_strategy_test.py index f27164d32b5..bde919cb20c 100644 --- a/tensorflow/python/distribute/collective_all_reduce_strategy_test.py +++ b/tensorflow/python/distribute/collective_all_reduce_strategy_test.py @@ -292,7 +292,7 @@ class DistributedCollectiveAllReduceStrategyTest( input_options = None else: input_options = distribute_lib.InputOptions( - experimental_prefetch_to_device=prefetch_to_device) + experimental_fetch_to_device=prefetch_to_device) dataset = dataset_ops.Dataset.range(100) dataset = dataset.batch(distribution.num_replicas_in_sync) dataset = distribution.experimental_distribute_dataset( @@ -313,7 +313,7 @@ class DistributedCollectiveAllReduceStrategyTest( task_id=0, num_gpus=2) input_options = distribute_lib.InputOptions( - experimental_prefetch_to_device=False) + experimental_fetch_to_device=False) dataset = dataset_ops.Dataset.range(100) dataset = dataset.batch(distribution.num_replicas_in_sync) dataset = distribution.experimental_distribute_dataset( diff --git a/tensorflow/python/distribute/coordinator/cluster_coordinator.py b/tensorflow/python/distribute/coordinator/cluster_coordinator.py index c57304ae3a6..4ffcb1ac85d 100644 --- a/tensorflow/python/distribute/coordinator/cluster_coordinator.py +++ b/tensorflow/python/distribute/coordinator/cluster_coordinator.py @@ -825,13 +825,20 @@ class Worker(object): time.sleep(delay_secs) def _process_queue(self): - """Function running in a thread to process closure queues.""" + """Function running in a worker thread to process closure queues.""" self._maybe_delay() while self._should_worker_thread_run: closure = self._cluster._closure_queue.get() # pylint: disable=protected-access if not self._should_worker_thread_run or closure is None: return self._process_closure(closure) + # To properly stop the worker and preemption threads, it is important that + # `ClusterCoordinator` object is not held onto so its `__del__` can be + # called. By removing the reference to the `closure` that has already been + # processed, we ensure that the `closure` object is released, while + # getting the next `closure` at above `self._cluster._closure_queue.get()` + # call. + del closure def _create_resource(self, function, args=None, kwargs=None): """Synchronously creates a per-worker resource represented by a `RemoteValue`. diff --git a/tensorflow/python/distribute/custom_training_loop_input_test.py b/tensorflow/python/distribute/custom_training_loop_input_test.py index a5f135808f1..4470f87c075 100644 --- a/tensorflow/python/distribute/custom_training_loop_input_test.py +++ b/tensorflow/python/distribute/custom_training_loop_input_test.py @@ -459,8 +459,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase, input_iterator = iter( distribution.experimental_distribute_dataset( get_dataset_from_tensor_slices(data).batch(2), - distribute_lib.InputOptions( - experimental_prefetch_to_device=False))) + distribute_lib.InputOptions(experimental_fetch_to_device=False))) local_results = distribution.experimental_local_results( input_iterator.get_next()) @@ -479,7 +478,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase, input_iterator = iter( distribution.distribute_datasets_from_function( lambda _: get_dataset_from_tensor_slices(data), - distribute_lib.InputOptions(experimental_prefetch_to_device=False))) + distribute_lib.InputOptions(experimental_fetch_to_device=False))) local_results = distribution.experimental_local_results( input_iterator.get_next()) diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py index 28c5f094512..c43e070db0b 100644 --- a/tensorflow/python/distribute/distribute_lib.py +++ b/tensorflow/python/distribute/distribute_lib.py @@ -622,12 +622,11 @@ class RunOptions( @tf_export("distribute.InputOptions", v1=[]) class InputOptions( collections.namedtuple("InputOptions", [ - # TODO(b/180518705): Rename `experimental_prefetch_to_device` to - # `experimental_fetch_to_device` to better reflect its functionality. - "experimental_prefetch_to_device", + "experimental_fetch_to_device", "experimental_replication_mode", "experimental_place_dataset_on_device", "experimental_per_replica_buffer_size", + "experimental_prefetch_to_device", # Deprecated ])): """Run options for `experimental_distribute_dataset(s_from_function)`. @@ -652,11 +651,12 @@ class InputOptions( ``` Attributes: - experimental_prefetch_to_device: Boolean. Defaults to True. If True, dataset + experimental_fetch_to_device: Boolean. If True, dataset elements will be prefetched to accelerator device memory. When False, dataset elements are prefetched to host device memory. Must be False when - using TPUEmbedding API. experimental_prefetch_to_device can only be used - with experimental_replication_mode=PER_WORKER + using TPUEmbedding API. experimental_fetch_to_device can only be used + with experimental_replication_mode=PER_WORKER. Default behavior is same as + setting it to True. experimental_replication_mode: Replication mode for the input function. Currently, the InputReplicationMode.PER_REPLICA is only supported with tf.distribute.MirroredStrategy. @@ -670,19 +670,26 @@ class InputOptions( prefetch buffer size in the replica device memory. Users can set it to 0 to completely disable prefetching behavior, or a number greater than 1 to enable larger buffer size. Note that this option is still - valid with `experimental_prefetch_to_device=False`. + valid with `experimental_fetch_to_device=False`. """ def __new__(cls, - experimental_prefetch_to_device=True, + experimental_fetch_to_device=None, experimental_replication_mode=InputReplicationMode.PER_WORKER, experimental_place_dataset_on_device=False, - experimental_per_replica_buffer_size=1): + experimental_per_replica_buffer_size=1, + experimental_prefetch_to_device=True): + if experimental_fetch_to_device is None: + # TODO(b/180133992): Remove `experimental_prefetch_to_device` after + # replacing all its usages with `experimental_fetch_to_device`. + experimental_fetch_to_device = experimental_prefetch_to_device + return super(InputOptions, - cls).__new__(cls, experimental_prefetch_to_device, + cls).__new__(cls, experimental_fetch_to_device, experimental_replication_mode, experimental_place_dataset_on_device, - experimental_per_replica_buffer_size) + experimental_per_replica_buffer_size, + experimental_prefetch_to_device) # ------------------------------------------------------------------------------ # Base classes for all distribution strategies. @@ -2435,6 +2442,20 @@ class StrategyExtendedV2(object): reduced = replica_context.merge_call(merge_fn, args=(nest.flatten(value),)) return nest.pack_sequence_as(value, reduced) + def _replica_ctx_update(self, var, fn, args=(), kwargs=None): + """Run `fn` with `args` and `kwargs` to update `var`.""" + # This method is called by ReplicaContext.update. Strategies who'd like to + # remove merge_call in this path should override this method. + replica_context = distribution_strategy_context.get_replica_context() + if not replica_context: + raise ValueError("`StrategyExtended._replica_ctx_update` must be called " + "in a replica context.") + + def merge_fn(_, *merged_args, **merged_kwargs): + return self.update(var, fn, merged_args, merged_kwargs, group=True) + + return replica_context.merge_call(merge_fn, args=args, kwargs=kwargs) + def _gather_to(self, value, destinations, axis, options=None): """Gather `value` across replicas along axis-th dimension to `destinations`. @@ -3337,6 +3358,91 @@ class ReplicaContext(ReplicaContextBase): return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value))) + def _update(self, var, fn, args=(), kwargs=None): + """Run `fn` to update `var` with `args` and `kwargs` in replica context. + + `tf.distribute.ReplicaContext.update` takes a (distributed) variable `var` + to be updated, an update function `fn`, and `args` and `kwargs` for `fn`. + `fn` applies to each component variable of `var` with corresponding input + values from `args` and `kwargs`. + + Example usage: + + >>> strategy = tf.distribute.MirroredStrategy(['GPU:0', 'CPU:0']) # 2 replicas + >>> with strategy.scope(): + ... distributed_variable = tf.Variable(5.0) + >>> distributed_variable + MirroredVariable:{ + 0: , + 1: + } + >>> def replica_fn(v): + ... value = tf.identity(1.0) + ... replica_context = tf.distribute.get_replica_context() + ... update_fn = lambda var, value: var.assign(value) + ... replica_context._update(v, update_fn, args=(value,)) + >>> strategy.run(replica_fn, args=(distributed_variable,)) + >>> distributed_variable + MirroredVariable:{ + 0: , + 1: + } + + This API must be called in a replica context. + + Note that if `var` is a MirroredVariable (i.e., the type of variable created + under the scope of a synchronous strategy, and is synchronized on-write, see + `tf.VariableSynchronization` for more information) and `args`/`kwargs` + contains different values for different replicas, `var` will be dangerously + out of synchronization. Thus we recommend using `variable.assign(value)` as + long as you can, which under the hood aggregates the updates and guarantees + the synchronization. The case where you actually want this API instead of + `variable.assign(value)` is that before assigning `value` to the `variable`, + you'd like to conduct some pre-`assign` computation colocated with the + variable devices (i.e. where variables reside, for MirroredStrategy they are + the same as the compute device, for ParameterServerStrategy they refer to + parameter servers). E.g., + + ```python + strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1']) # 2 replicas + with strategy.scope(): + v = tf.Variable(5.0, aggregation=tf.VariableAggregation.SUM) + def replica_fn(inputs): + value = computation(inputs) + replica_context = tf.distribute.get_replica_context() + reduced_value = replica_context.all_reduce(value) + + def update_fn(var, value): + # this computation will colocate with `var`'s device + updated_value = post_reduce_pre_update_computation(value) + var.assign(value) + + replica_context._update(v, update_fn, args=(reduced_value,)) + + strategy.run(replica_fn, args=(inputs,)) + ``` + + This code snippet is consistent across all strategies. If you directly + compute and use `assign` in the replica context instead of wrapping it with + `update`, for strategies with fewer variable devices than compute devices + (e.g., parameter server strategy, usually), the + `post_reduce_pre_update_computation` will happen + N==number_of_compute_devices times which is less performant. + + + Args: + var: Variable, possibly distributed to multiple devices, to operate on. + fn: Function to call. Should take the variable as the first argument. + args: Tuple or list. Additional positional arguments to pass to `fn()`. + kwargs: Dict with keyword arguments to pass to `fn()`. + + Returns: + The return value of `fn` for the local replica. + """ + if kwargs is None: + kwargs = {} + return self._strategy.extended._replica_ctx_update(var, fn, args=args, kwargs=kwargs) # pylint: disable=protected-access + @tf_export(v1=["distribute.ReplicaContext"]) class ReplicaContextV1(ReplicaContextBase): diff --git a/tensorflow/python/distribute/distribute_utils.py b/tensorflow/python/distribute/distribute_utils.py index 3a7840b0492..a8506d1a28c 100644 --- a/tensorflow/python/distribute/distribute_utils.py +++ b/tensorflow/python/distribute/distribute_utils.py @@ -304,6 +304,16 @@ def create_mirrored_variable(strategy, real_mirrored_creator, class_mapping, # here. with tape.stop_recording(): value_list = real_mirrored_creator(**kwargs) + # MirroredVariable is recreated during saved_model loading, and its + # component variables (value_list) will have None initializer. We + # set their initializers to no_op so that consumer like + # `global_variables_initializer` wouldn't complain, as it groups all + # variables' initializers thus all variables have to have initializers. + for v in value_list: + # pylint:disable=protected-access + if hasattr(v, "_initializer_op") and v._initializer_op is None: + v._initializer_op = control_flow_ops.no_op() + # pylint:enable=protected-access if use_var_policy: var_policy_cls = policy_mapping.get(synchronization) var_policy = var_policy_cls(aggregation=aggregation) diff --git a/tensorflow/python/distribute/distribution_strategy_context.py b/tensorflow/python/distribute/distribution_strategy_context.py index b08c2313b1c..d0af66cf3e4 100644 --- a/tensorflow/python/distribute/distribution_strategy_context.py +++ b/tensorflow/python/distribute/distribution_strategy_context.py @@ -273,12 +273,12 @@ def experimental_set_strategy(strategy): @contextlib.contextmanager def enter_or_assert_strategy(strategy): - if not has_strategy(): - with strategy.scope(): - yield - else: + if has_strategy(): _assert_strategy(strategy) yield + else: + with strategy.scope(): + yield # ------------------------------------------------------------------------------ diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py index f1be2dd557b..4e9c877bfb3 100644 --- a/tensorflow/python/distribute/input_lib.py +++ b/tensorflow/python/distribute/input_lib.py @@ -97,9 +97,9 @@ def get_distributed_dataset(dataset, """ if tf2.enabled(): return DistributedDataset( - dataset, input_workers, strategy, + dataset, num_replicas_in_sync=num_replicas_in_sync, input_context=input_context, options=options) @@ -155,22 +155,22 @@ def get_distributed_datasets_from_function(dataset_fn, if (options is not None and options.experimental_replication_mode == InputReplicationMode.PER_REPLICA - and options.experimental_prefetch_to_device and + and options.experimental_fetch_to_device and options.experimental_place_dataset_on_device): raise ValueError( "`experimental_place_dataset_on_device` can not be set to True " - "when experimental_prefetch_to_device is True and " + "when experimental_fetch_to_device is True and " "replication mode is set to `PER_REPLICA`") if tf2.enabled(): - return DistributedDatasetsFromFunction(dataset_fn, input_workers, - input_contexts, strategy, options) + return DistributedDatasetsFromFunction(input_workers, strategy, + input_contexts, dataset_fn, options) else: return DistributedDatasetsFromFunctionV1( - dataset_fn, input_workers, - input_contexts, strategy, + input_contexts, + dataset_fn, options) @@ -757,16 +757,20 @@ class DistributedIteratorV1(DistributedIteratorBase): return self._element_spec -class DistributedIteratorSpec(type_spec.TypeSpec): - """Type specification for `DistributedIterator`.""" +class DistributedDatasetAndIteratorSpec(type_spec.TypeSpec): + """Common Type specification for `DistributedDataset and DistributedDatasetsFromFunction.""" __slots__ = [ "_input_workers", "_element_spec", "_strategy", "_enable_get_next_as_optional", "_options" ] - def __init__(self, input_workers, element_spec, strategy, - enable_get_next_as_optional, options): + def __init__(self, + input_workers, + element_spec, + strategy, + options, + enable_get_next_as_optional=None): # We don't want to allow deserialization of this class because we don't # serialize the strategy object. Currently the only places where # _deserialize is called is when we save/restore using SavedModels. @@ -780,10 +784,6 @@ class DistributedIteratorSpec(type_spec.TypeSpec): self._enable_get_next_as_optional = enable_get_next_as_optional self._options = options - @property - def value_type(self): - return DistributedIterator - def _serialize(self): # We cannot serialize the strategy object so we convert it to an id that we # can use for comparison. @@ -791,11 +791,10 @@ class DistributedIteratorSpec(type_spec.TypeSpec): id(self._strategy), id(self._options)) def _deserialize(self): - raise ValueError("Deserialization is currently unsupported for " - "DistributedIteratorSpec.") + raise ValueError( + f"Deserialization is currently unsupported for {type(self)}.") - # Overriding this method so that we can merge and reconstruct the spec object - def most_specific_compatible_type(self, other): + def sanity_check_type(self, other): """Returns the most specific TypeSpec compatible with `self` and `other`. Args: @@ -815,6 +814,34 @@ class DistributedIteratorSpec(type_spec.TypeSpec): if self._strategy is not other._strategy: raise ValueError("tf.distribute strategy is not compatible with both %s " "and %s" % (self, other)) + + +class DistributedIteratorSpec(DistributedDatasetAndIteratorSpec): + """Type specification for `DistributedIterator`.""" + + def __init__(self, input_workers, element_spec, strategy, + enable_get_next_as_optional, options): + super(DistributedIteratorSpec, + self).__init__(input_workers, element_spec, strategy, options, + enable_get_next_as_optional) + + @property + def value_type(self): + return DistributedIterator + + # Overriding this method so that we can merge and reconstruct the spec object + def most_specific_compatible_type(self, other): + """Returns the most specific TypeSpec compatible with `self` and `other`. + + Args: + other: A `TypeSpec`. + + Raises: + ValueError: If there is no TypeSpec that is compatible with both `self` + and `other`. + """ + # pylint: disable=protected-access + self.sanity_check_type(other) element_spec = nest.map_structure( lambda a, b: a.most_specific_compatible_type(b), self._element_spec, other._element_spec) @@ -961,15 +988,84 @@ class _IterableInput(DistributedDatasetInterface): return final_state -class DistributedDataset(_IterableInput): +class DistributedDatasetSpec(DistributedDatasetAndIteratorSpec): + """Type specification for `DistributedDataset.""" + + def __init__(self, input_workers, element_spec, strategy, + enable_get_next_as_optional, options): + super(DistributedDatasetSpec, + self).__init__(input_workers, element_spec, strategy, options, + enable_get_next_as_optional) + + @property + def value_type(self): + return DistributedDataset + + # Overriding this method so that we can merge and reconstruct the spec object + def most_specific_compatible_type(self, other): + """Returns the most specific TypeSpec compatible with `self` and `other`. + + Args: + other: A `TypeSpec`. + + Raises: + ValueError: If there is no TypeSpec that is compatible with both `self` + and `other`. + """ + # pylint: disable=protected-access + self.sanity_check_type(other) + element_spec = nest.map_structure( + lambda a, b: a.most_specific_compatible_type(b), self._element_spec, + other._element_spec) + return DistributedDatasetSpec(self._input_workers, element_spec, + self._strategy, + self._enable_get_next_as_optional, + self._options) + + @property + def _component_specs(self): + specs = [] + worker_device_pairs = self._input_workers._worker_device_pairs # pylint: disable=protected-access + + for i, _ in enumerate(worker_device_pairs): + element_spec = nest.map_structure( + functools.partial(_replace_per_replica_spec, i=i), self._element_spec) + specs.append(dataset_ops.DatasetSpec(element_spec)) + return specs + + def _to_components(self, value): + return value._cloned_datasets # pylint: disable=protected-access + + def _from_components(self, components): + return DistributedDataset( + input_workers=self._input_workers, + strategy=self._strategy, + components=components, + element_spec=self._element_spec, + enable_get_next_as_optional=self._enable_get_next_as_optional, + options=self._options) + + @staticmethod + def from_value(value): + # pylint: disable=protected-access + return DistributedDatasetSpec(value._input_workers, value._element_spec, + value._strategy, + value._enable_get_next_as_optional, + value._options) + + +class DistributedDataset(_IterableInput, composite_tensor.CompositeTensor): """Distributed dataset that supports prefetching to multiple devices.""" def __init__(self, - dataset, input_workers, strategy, + dataset=None, num_replicas_in_sync=None, input_context=None, + components=None, + element_spec=None, + enable_get_next_as_optional=None, options=None): """Distribute the dataset on all workers. @@ -979,10 +1075,14 @@ class DistributedDataset(_IterableInput): workers and replicas) is as expected. Args: - dataset: `tf.data.Dataset` that will be used as the input source. input_workers: an `InputWorkers` object. strategy: a `tf.distribute.Strategy` object, used to run all-reduce to handle last partial batch. + dataset: `tf.data.Dataset` that will be used as the input source. Either + dataset or components field should be passed when constructing + DistributedDataset. Use this when contructing DistributedDataset from a + new `tf.data.Dataset`. Use components when constructing using + DistributedDatasetSpec. num_replicas_in_sync: Optional integer. If this is not None, the value is used to decide how to rebatch datasets into smaller batches so that the total batch size for each step (across all workers and replicas) @@ -991,10 +1091,53 @@ class DistributedDataset(_IterableInput): graph multi-worker cases where there is only one `input_worker`. In these cases, we will shard based on the `input_pipeline_id` and `num_input_pipelines` in the `InputContext`. + components: datasets when DistributedDataset is constructed from + DistributedDatasetSpec. Either field dataset or components should be + passed. + element_spec: element spec for DistributedDataset when constructing from + DistributedDatasetSpec. This will be used to set the element_spec for + DistributedDataset and verified against element_spec from components. + enable_get_next_as_optional: this is required when components is passed + instead of dataset. options: `tf.distribute.InputOptions` used to control options on how this dataset is distributed. """ super(DistributedDataset, self).__init__(input_workers=input_workers) + if input_workers is None or strategy is None: + raise ValueError("input_workers and strategy are required arguments") + if dataset is not None and components is not None: + raise ValueError("Only one of dataset or components should be present") + if dataset is None and components is None: + raise ValueError("At least one of dataset or components should be passed") + + if dataset is not None: + self._create_cloned_datasets_from_dataset(dataset, input_context, + input_workers, strategy, + num_replicas_in_sync) + else: + if enable_get_next_as_optional is None: + raise ValueError( + "When constructing DistributedDataset with components, " + + "enable_get_next_as_optional should also be passed") + self._cloned_datasets = components + self._enable_get_next_as_optional = enable_get_next_as_optional + + self._input_workers = input_workers + self._strategy = strategy + self._options = options + + if element_spec is not None: + if element_spec != _create_distributed_tensor_spec( + self._strategy, self._cloned_datasets[0].element_spec): + raise ValueError("Mismatched element_spec from the passed components") + self._element_spec = element_spec + else: + self._element_spec = _create_distributed_tensor_spec( + self._strategy, self._cloned_datasets[0].element_spec) + + def _create_cloned_datasets_from_dataset(self, dataset, input_context, + input_workers, strategy, + num_replicas_in_sync): # We clone and shard the dataset on each worker. The current setup tries to # shard the dataset by files if possible so that each worker sees a # different subset of files. If that is not possible, will attempt to shard @@ -1012,7 +1155,6 @@ class DistributedDataset(_IterableInput): num_replicas_in_sync) else: rebatch_fn = None - self._cloned_datasets = [] if input_context: # Between-graph where we rely on the input_context for sharding @@ -1038,13 +1180,8 @@ class DistributedDataset(_IterableInput): num_replicas_in_sync) self._cloned_datasets.append(cloned_dataset) - self._input_workers = input_workers - self._strategy = strategy - self._options = options self._enable_get_next_as_optional = _enable_get_next_as_optional( - self._strategy, dataset) - self._element_spec = _create_distributed_tensor_spec( - self._strategy, self._cloned_datasets[0].element_spec) + strategy, dataset) def _make_rebatch_fn(self, dataset, num_workers, num_replicas_in_sync): """Returns a callable that rebatches the input dataset. @@ -1151,6 +1288,13 @@ class DistributedDataset(_IterableInput): _rebatch_as_dynamic, self._element_spec, expand_composites=False) return self._element_spec + @property + def _type_spec(self): + return DistributedDatasetSpec(self._input_workers, self._element_spec, + self._strategy, + self._enable_get_next_as_optional, + self._options) + class DistributedDatasetV1(DistributedDataset): """Distributed dataset that supports prefetching to multiple devices.""" @@ -1164,9 +1308,9 @@ class DistributedDatasetV1(DistributedDataset): options=None): self._input_workers = input_workers super(DistributedDatasetV1, self).__init__( - dataset, input_workers, strategy, + dataset, num_replicas_in_sync=num_replicas_in_sync, input_context=input_context, options=options) @@ -1239,45 +1383,136 @@ class DistributedDatasetV1(DistributedDataset): "or when eager execution is enabled.") +class DistributedDatasetsFromFunctionSpec(DistributedDatasetAndIteratorSpec): + """Type specification for `DistributedDatasetsFromFunction.""" + + def __init__(self, input_workers, element_spec, strategy, options): + super(DistributedDatasetsFromFunctionSpec, + self).__init__(input_workers, element_spec, strategy, options) + + @property + def value_type(self): + return DistributedDatasetsFromFunction + + @property + def _component_specs(self): + specs = [] + worker_device_pairs = self._input_workers._worker_device_pairs # pylint: disable=protected-access + + for i, _ in enumerate(worker_device_pairs): + element_spec = nest.map_structure( + functools.partial(_replace_per_replica_spec, i=i), self._element_spec) + specs.append(dataset_ops.DatasetSpec(element_spec)) + return specs + + # Overriding this method so that we can merge and reconstruct the spec object + def most_specific_compatible_type(self, other): + """Returns the most specific TypeSpec compatible with `self` and `other`. + + Args: + other: A `TypeSpec`. + + Raises: + ValueError: If there is no TypeSpec that is compatible with both `self` + and `other`. + """ + # pylint: disable=protected-access + self.sanity_check_type(other) + element_spec = nest.map_structure( + lambda a, b: a.most_specific_compatible_type(b), self._element_spec, + other._element_spec) # pylint: disable=protected-access + return DistributedDatasetsFromFunctionSpec(self._input_workers, + element_spec, self._strategy, + self._options) + + def _to_components(self, value): + return value._datasets # pylint: disable=protected-access + + def _from_components(self, components): + return DistributedDatasetsFromFunction( + input_workers=self._input_workers, + strategy=self._strategy, + components=components, + element_spec=self._element_spec, + options=self._options) + + @staticmethod + def from_value(value): + # pylint: disable=protected-access + return DistributedDatasetsFromFunctionSpec( + input_workers=value._input_workers, + element_spec=value._element_spec, + strategy=value._strategy, + options=value._options) + + # TODO(priyag): Add other replication modes. -class DistributedDatasetsFromFunction(_IterableInput): +class DistributedDatasetsFromFunction(_IterableInput, + composite_tensor.CompositeTensor): """Inputs created from dataset function.""" - def __init__(self, dataset_fn, input_workers, input_contexts, strategy, - options): + def __init__(self, + input_workers, + strategy, + input_contexts=None, + dataset_fn=None, + options=None, + components=None, + element_spec=None): """Makes an iterable from datasets created by the given function. Args: - dataset_fn: A function that returns a `Dataset` given an `InputContext`. input_workers: an `InputWorkers` object. + strategy: a `tf.distribute.Strategy` object, used to run all-reduce to + handle last partial batch. input_contexts: A list of `InputContext` instances to be passed to call(s) to `dataset_fn`. Length and order should match worker order in `worker_device_pairs`. - strategy: a `tf.distribute.Strategy` object, used to run all-reduce to - handle last partial batch. + dataset_fn: A function that returns a `Dataset` given an `InputContext`. + Either dataset_fn or components should be passed to construct + DistributedDatasetsFromFunction. Use this when contructing + DistributedDataset using a function. Use components when constructing + using DistributedDatasetsFromFunctionSpec. options: `tf.distribute.InputOptions` used to control options on how this dataset is distributed. + components: datasets when DistributedDatasetsFromFunction is constructed + from DistributedDatasetsFromFunctionSpec. Only one of dataset or + components should be passed. + element_spec: element spec for DistributedDataset when constructing from + DistributedDatasetSpec. This will be used to set the element_spec for + DistributedDatasetsFromFunctionSpec and verified against element_spec + from components. """ super(DistributedDatasetsFromFunction, self).__init__( input_workers=input_workers) - - if input_workers.num_workers != len(input_contexts): - raise ValueError( - "Number of input workers (%d) is not same as number of " - "input_contexts (%d)" % - (input_workers.num_workers, len(input_contexts))) - self._input_workers = input_workers - self._input_contexts = input_contexts self._strategy = strategy self._options = options - self._datasets, element_spec = ( - _create_datasets_from_function_with_input_context( - self._input_contexts, self._input_workers, dataset_fn)) + if dataset_fn is not None and components is not None: + raise ValueError("Only one of dataset_fn or components should be set") + if dataset_fn is None and components is None: + raise ValueError("At least one of dataset_fn or components should be set") + + if dataset_fn is not None: + if input_workers.num_workers != len(input_contexts): + raise ValueError( + "Number of input workers (%d) is not same as number of " + "input_contexts (%d)" % + (input_workers.num_workers, len(input_contexts))) + self._datasets, element_spec = ( + _create_datasets_from_function_with_input_context( + input_contexts, self._input_workers, dataset_fn)) + self._element_spec = _create_distributed_tensor_spec( + self._strategy, element_spec) + else: + if element_spec is None: + raise ValueError( + "element_spec should also be passed when passing components") + self._element_spec = element_spec + self._datasets = components + self._enable_get_next_as_optional = _enable_get_next_as_optional( self._strategy, self._datasets[0]) - self._element_spec = _create_distributed_tensor_spec( - self._strategy, element_spec) def __iter__(self): if (ops.executing_eagerly_outside_functions() or @@ -1331,6 +1566,12 @@ class DistributedDatasetsFromFunction(_IterableInput): _rebatch_as_dynamic, self._element_spec, expand_composites=False) return self._element_spec + @property + def _type_spec(self): + return DistributedDatasetsFromFunctionSpec(self._input_workers, + self._element_spec, + self._strategy, self._options) + class DistributedDatasetsFromFunctionV1(DistributedDatasetsFromFunction): """Inputs created from dataset function.""" @@ -1588,7 +1829,7 @@ class _SingleWorkerDatasetIteratorBase(object): """ if (self._options and self._options.experimental_replication_mode == InputReplicationMode.PER_REPLICA and - not self._options.experimental_prefetch_to_device): + not self._options.experimental_fetch_to_device): return [data_list] else: return data_list @@ -2028,7 +2269,7 @@ def _should_use_multi_device_iterator(options): options.experimental_replication_mode == InputReplicationMode.PER_WORKER or (options.experimental_replication_mode == InputReplicationMode.PER_REPLICA - and options.experimental_prefetch_to_device)): + and options.experimental_fetch_to_device)): return True return False diff --git a/tensorflow/python/distribute/input_lib_test.py b/tensorflow/python/distribute/input_lib_test.py index 2af3db2b27a..9ed4144ad12 100644 --- a/tensorflow/python/distribute/input_lib_test.py +++ b/tensorflow/python/distribute/input_lib_test.py @@ -112,9 +112,9 @@ class DistributedIteratorTestBase(test.TestCase): if input_type == "dataset": if tf2.enabled(): return input_lib.DistributedDataset( - dataset, input_workers, strategy, + dataset, num_replicas_in_sync=num_replicas_in_sync, input_context=input_context) else: @@ -1467,12 +1467,12 @@ class DistributedIteratorPerDeviceTest(DistributedIteratorTestBase, input_options=[ distribute_lib.InputOptions( experimental_place_dataset_on_device=False, - experimental_prefetch_to_device=True, + experimental_fetch_to_device=True, experimental_replication_mode=distribute_lib .InputReplicationMode.PER_WORKER), distribute_lib.InputOptions( experimental_place_dataset_on_device=False, - experimental_prefetch_to_device=True, + experimental_fetch_to_device=True, experimental_replication_mode=distribute_lib .InputReplicationMode.PER_REPLICA), ], @@ -1509,7 +1509,7 @@ class DistributedIteratorPerDeviceTest(DistributedIteratorTestBase, input_options=[ distribute_lib.InputOptions( experimental_place_dataset_on_device=False, - experimental_prefetch_to_device=False, + experimental_fetch_to_device=False, experimental_replication_mode=distribute_lib .InputReplicationMode.PER_WORKER) ], @@ -1541,12 +1541,12 @@ class DistributedIteratorPerDeviceTest(DistributedIteratorTestBase, input_options=[ distribute_lib.InputOptions( experimental_place_dataset_on_device=True, - experimental_prefetch_to_device=False, + experimental_fetch_to_device=False, experimental_replication_mode=distribute_lib .InputReplicationMode.PER_WORKER), distribute_lib.InputOptions( experimental_place_dataset_on_device=True, - experimental_prefetch_to_device=True, + experimental_fetch_to_device=True, experimental_replication_mode=distribute_lib .InputReplicationMode.PER_REPLICA) ], @@ -1572,11 +1572,11 @@ class DistributedIteratorPerDeviceTest(DistributedIteratorTestBase, input_options=[ distribute_lib.InputOptions( experimental_place_dataset_on_device=False, - experimental_prefetch_to_device=False, + experimental_fetch_to_device=False, experimental_per_replica_buffer_size=2), distribute_lib.InputOptions( experimental_place_dataset_on_device=False, - experimental_prefetch_to_device=True, + experimental_fetch_to_device=True, experimental_per_replica_buffer_size=2), ], mode=["eager"], @@ -1605,12 +1605,12 @@ class DistributedIteratorPerDeviceTest(DistributedIteratorTestBase, input_options=[ distribute_lib.InputOptions( experimental_place_dataset_on_device=False, - experimental_prefetch_to_device=False, + experimental_fetch_to_device=False, experimental_replication_mode=distribute_lib .InputReplicationMode.PER_WORKER), distribute_lib.InputOptions( experimental_place_dataset_on_device=False, - experimental_prefetch_to_device=True, + experimental_fetch_to_device=True, experimental_replication_mode=distribute_lib .InputReplicationMode.PER_WORKER), ], @@ -1641,17 +1641,17 @@ class DistributedIteratorPerDeviceTest(DistributedIteratorTestBase, input_options=[ distribute_lib.InputOptions( experimental_place_dataset_on_device=True, - experimental_prefetch_to_device=False, + experimental_fetch_to_device=False, experimental_replication_mode=distribute_lib .InputReplicationMode.PER_REPLICA), distribute_lib.InputOptions( experimental_place_dataset_on_device=False, - experimental_prefetch_to_device=False, + experimental_fetch_to_device=False, experimental_replication_mode=distribute_lib .InputReplicationMode.PER_REPLICA), distribute_lib.InputOptions( experimental_place_dataset_on_device=False, - experimental_prefetch_to_device=True, + experimental_fetch_to_device=True, experimental_replication_mode=distribute_lib .InputReplicationMode.PER_REPLICA), ], diff --git a/tensorflow/python/distribute/input_lib_type_spec_test.py b/tensorflow/python/distribute/input_lib_type_spec_test.py index d3830c3df62..54be959778b 100644 --- a/tensorflow/python/distribute/input_lib_type_spec_test.py +++ b/tensorflow/python/distribute/input_lib_type_spec_test.py @@ -268,9 +268,260 @@ class InputTypeSpecTest(test.TestCase, parameterized.TestCase): ], enable_get_next_as_optional=[True, False], experimental_place_dataset_on_device=[True, False], - experimental_prefetch_to_device=[True, False], + experimental_fetch_to_device=[True, False], )) def testFromFunctionInputSignatureForPerReplicaValuesWithOptions( + self, distribution, enable_get_next_as_optional, + experimental_place_dataset_on_device, experimental_fetch_to_device): + + if experimental_place_dataset_on_device and experimental_fetch_to_device: + self.skipTest("Setting experimental_place_dataset_on_device and " + "experimental_fetch_to_device to `True` is not " + "allowed when using " + "distribute_lib.InputReplicationMode.PER_REPLICA.") + + fname1 = os.path.join(self.get_temp_dir(), "1.txt") + _create_text_file(fname1, 5) + fname2 = os.path.join(self.get_temp_dir(), "2.txt") + _create_text_file(fname2, 9) + + def dataset_fn(input_context): + dataset = dataset_ops.DatasetV2.from_tensor_slices([fname1, fname2]) + dataset = dataset.shard(input_context.num_input_pipelines, + input_context.input_pipeline_id) + return readers.TextLineDatasetV2(dataset).map( + string_ops.string_to_number).batch( + input_context.get_per_replica_batch_size(4)) + + options = distribute_lib.InputOptions( + experimental_place_dataset_on_device=( + experimental_place_dataset_on_device), + experimental_fetch_to_device=experimental_fetch_to_device, + experimental_replication_mode=( + distribute_lib.InputReplicationMode.PER_REPLICA)) + + distribution.extended.experimental_enable_get_next_as_optional = ( + enable_get_next_as_optional) + ds = distribution.experimental_distribute_datasets_from_function( + dataset_fn, options) + + iterator = iter(ds) + _check_type_spec_structure(iterator) + spec = iterator._type_spec + tensor_list = spec._to_components(iterator) + re_iterator = spec._from_components(tensor_list) + + _check_type_spec_structure(iter(ds)) + element_spec = ds.element_spec + iter_element_spec = iter(ds).element_spec + nest.assert_same_structure(element_spec, iter_element_spec) + self.assertAllEqual( + nest.flatten(element_spec), nest.flatten(iter_element_spec)) + self.assertEqual(iterator._input_workers, re_iterator._input_workers) + self.assertAllEqual(iterator._iterators, re_iterator._iterators) + + @def_function.function(input_signature=[element_spec]) + def process_inputs(inputs): + distribution.run(lambda inputs: inputs, args=(inputs,)) + + for x in ds: + process_inputs(x) + + +class DistributedDatasetTypeSpecTest(test.TestCase, parameterized.TestCase): + + @combinations.generate( + combinations.combine( + mode=["eager"], + tf_api_version=2, + distribution=[ + strategy_combinations.mirrored_strategy_with_two_gpus, + strategy_combinations.mirrored_strategy_with_cpu_1_and_2, + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + ], + enable_get_next_as_optional=[True, False])) + def testTypeSpecBase(self, distribution, enable_get_next_as_optional): + + def create_dataset(): + dataset = dataset_ops.DatasetV2.range(10).batch(2) + return dataset + + distribution.extended.experimental_enable_get_next_as_optional = ( + enable_get_next_as_optional) + + dist_dataset = distribution.experimental_distribute_dataset( + create_dataset()) + + spec = dist_dataset._type_spec + self.assertEqual(spec._input_workers, dist_dataset._input_workers) + self.assertEqual( + spec._element_spec._value_specs, + (tensor_spec.TensorSpec(shape=(None,), dtype=dtypes.int64, name=None), + tensor_spec.TensorSpec(shape=(None,), dtype=dtypes.int64, name=None))) + + @combinations.generate( + combinations.combine( + mode=["eager"], + tf_api_version=2, + distribution=[ + strategy_combinations.mirrored_strategy_with_cpu_1_and_2, + ], + enable_get_next_as_optional=[True, False])) + def testTypeSpecReturnedFromTFFunction(self, distribution, + enable_get_next_as_optional): + # TODO(ishark): This is observed when tensor is copied from one device to + # other and since DatasetVariantWrapper does not have a copy + # function. Some Context: b/146981184 + # Try to renable with non-canonicalized input workers, which + # helped in PS Strategy for similar error. + self.skipTest("Failures observed in Ubuntu presubmit: No unary variant " + "device copy function found for direction: 1 and Variant " + "type_index:tensorflow::data::(anonymous namespace)::" + "DatasetVariantWrapper") + + @def_function.function + def create_dist_dataset(): + dataset = dataset_ops.DatasetV2.range(10).batch(2) + return distribution.experimental_distribute_dataset(dataset) + + distribution.extended.experimental_enable_get_next_as_optional = ( + enable_get_next_as_optional) + + dist_dataset = create_dist_dataset() + + spec = dist_dataset._type_spec + self.assertEqual(spec._input_workers, dist_dataset._input_workers) + self.assertEqual( + spec._element_spec._value_specs, + (tensor_spec.TensorSpec(shape=(None,), dtype=dtypes.int64, name=None), + tensor_spec.TensorSpec(shape=(None,), dtype=dtypes.int64, name=None))) + + # Read distributed data to confirm values are correct. + iterator = iter(dist_dataset) + data = [] + for it in iterator: + data.append(distribution.experimental_local_results(it)) + self.assertAllEqual( + nest.flatten(data), + list(dataset_ops.DatasetV2.range(10).batch(1).as_numpy_iterator())) + + @combinations.generate( + combinations.combine( + mode=["eager"], + tf_api_version=2, + distribution=[ + strategy_combinations.mirrored_strategy_with_two_gpus, + strategy_combinations.mirrored_strategy_with_cpu_1_and_2, + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + ], + enable_get_next_as_optional=[True, False])) + def testTypeSpecRaggedTensor(self, distribution, enable_get_next_as_optional): + ctx = distribute_lib.InputContext() + batch_size = ctx.get_per_replica_batch_size(8) + # Use 20 which isn't divisible by 8 to test partial batch behavior. + row_lengths = np.mod(np.arange(20), 4).astype(np.int64) + ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths( + np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths) + dataset = dataset_ops.DatasetV2.from_tensor_slices({ + "dense": ragged_tensor.to_tensor(), + "ragged": ragged_tensor, + "sparse": ragged_tensor.to_sparse(), + }) + dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id) + dataset = dataset.batch(batch_size) + + distribution.extended.experimental_enable_get_next_as_optional = ( + enable_get_next_as_optional) + dist_dataset = distribution.experimental_distribute_dataset(dataset) + spec = dist_dataset._type_spec + self.assertEqual(spec._input_workers, dist_dataset._input_workers) + self.assertEqual( + spec._element_spec, { + "sparse": + values.PerReplicaSpec( + sparse_tensor.SparseTensorSpec( + tensor_shape.TensorShape([None, 3]), dtypes.float32), + sparse_tensor.SparseTensorSpec( + tensor_shape.TensorShape([None, 3]), dtypes.float32)), + "dense": + values.PerReplicaSpec( + tensor_spec.TensorSpec( + shape=(None, 3), dtype=dtypes.float32, name=None), + tensor_spec.TensorSpec( + shape=(None, 3), dtype=dtypes.float32, name=None)), + "ragged": + values.PerReplicaSpec( + ragged_tensor_lib.RaggedTensorSpec( + tensor_shape.TensorShape([None, None]), dtypes.float32, + 1, dtypes.int64), + ragged_tensor_lib.RaggedTensorSpec( + tensor_shape.TensorShape([None, None]), dtypes.float32, + 1, dtypes.int64)) + }) + + @combinations.generate( + combinations.combine( + mode=["eager"], + tf_api_version=2, + distribution=[ + strategy_combinations.mirrored_strategy_with_two_gpus, + strategy_combinations.mirrored_strategy_with_cpu_1_and_2, + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + ], + enable_get_next_as_optional=[True, False], + experimental_place_dataset_on_device=[True, False], + experimental_prefetch_to_device=[True, False])) + def testTypeSpecComponents(self, distribution, enable_get_next_as_optional, + experimental_place_dataset_on_device, + experimental_prefetch_to_device): + dataset = dataset_ops.DatasetV2.range(10).batch(2) + distribution.extended.experimental_enable_get_next_as_optional = ( + enable_get_next_as_optional) + + options = distribute_lib.InputOptions( + experimental_place_dataset_on_device= + experimental_place_dataset_on_device, + experimental_prefetch_to_device=experimental_prefetch_to_device) + + dist_dataset = distribution.experimental_distribute_dataset( + dataset, options) + + spec = dist_dataset._type_spec + self.assertEqual(spec._input_workers, dist_dataset._input_workers) + self.assertEqual( + spec._element_spec._value_specs, + (tensor_spec.TensorSpec(shape=(None,), dtype=dtypes.int64, name=None), + tensor_spec.TensorSpec(shape=(None,), dtype=dtypes.int64, name=None))) + components = spec._to_components(dist_dataset) + re_dist_dataset = spec._from_components(components) + + self.assertEqual(dist_dataset._input_workers, + re_dist_dataset._input_workers) + self.assertAllEqual(dist_dataset._cloned_datasets, + re_dist_dataset._cloned_datasets) + self.assertEqual(dist_dataset._element_spec, re_dist_dataset._element_spec) + self.assertEqual(dist_dataset._enable_get_next_as_optional, + re_dist_dataset._enable_get_next_as_optional) + self.assertEqual(dist_dataset._options, re_dist_dataset._options) + + +class DistributedDatasetsFromFunctionSpecTest(test.TestCase, + parameterized.TestCase): + + @combinations.generate( + combinations.combine( + mode=["eager"], + tf_api_version=2, + distribution=[ + strategy_combinations.mirrored_strategy_with_two_gpus, + strategy_combinations.mirrored_strategy_with_cpu_1_and_2, + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + ], + enable_get_next_as_optional=[True, False], + experimental_place_dataset_on_device=[True, False], + experimental_prefetch_to_device=[True, False], + )) + def testDistributedDatasetsFromFunctionSpec( self, distribution, enable_get_next_as_optional, experimental_place_dataset_on_device, experimental_prefetch_to_device): @@ -305,20 +556,17 @@ class InputTypeSpecTest(test.TestCase, parameterized.TestCase): ds = distribution.experimental_distribute_datasets_from_function( dataset_fn, options) - iterator = iter(ds) - _check_type_spec_structure(iterator) - spec = iterator._type_spec - tensor_list = spec._to_components(iterator) - re_iterator = spec._from_components(tensor_list) + spec = ds._type_spec + components = spec._to_components(ds) + re_ds = spec._from_components(components) - _check_type_spec_structure(iter(ds)) - element_spec = ds.element_spec + element_spec = re_ds.element_spec iter_element_spec = iter(ds).element_spec nest.assert_same_structure(element_spec, iter_element_spec) self.assertAllEqual( nest.flatten(element_spec), nest.flatten(iter_element_spec)) - self.assertEqual(iterator._input_workers, re_iterator._input_workers) - self.assertAllEqual(iterator._iterators, re_iterator._iterators) + self.assertEqual(ds._input_workers, re_ds._input_workers) + self.assertEqual(ds._element_spec, re_ds._element_spec) @def_function.function(input_signature=[element_spec]) def process_inputs(inputs): @@ -424,8 +672,7 @@ class RaggedTensorDistributedIteratorTest(test.TestCase, (tpu_strategy.TPUStrategyV2, tpu_strategy.TPUStrategy)): # TPUStrategy does not support distributed datasets with device prefetch # when using sparse or ragged tensors. - options = distribute_lib.InputOptions( - experimental_prefetch_to_device=False) + options = distribute_lib.InputOptions(experimental_fetch_to_device=False) else: options = None @@ -486,8 +733,7 @@ class RaggedTensorDistributedIteratorTest(test.TestCase, (tpu_strategy.TPUStrategyV2, tpu_strategy.TPUStrategy)): # TPUStrategy does not support distributed datasets with device prefetch # when using sparse or ragged tensors. - options = distribute_lib.InputOptions( - experimental_prefetch_to_device=False) + options = distribute_lib.InputOptions(experimental_fetch_to_device=False) else: options = None diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py index 2ca8e8c55f0..d150eff3cc3 100644 --- a/tensorflow/python/distribute/mirrored_strategy.py +++ b/tensorflow/python/distribute/mirrored_strategy.py @@ -26,12 +26,14 @@ from tensorflow.python.distribute import cross_device_utils from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import distribute_utils +from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.distribute import input_lib from tensorflow.python.distribute import mirrored_run from tensorflow.python.distribute import multi_worker_util from tensorflow.python.distribute import numpy_dataset from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import values +from tensorflow.python.distribute import values_util from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver from tensorflow.python.eager import context from tensorflow.python.eager import tape @@ -453,7 +455,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1): for d in self._devices)) return input_lib.InputWorkers(self._input_workers_devices) else: - if not options.experimental_prefetch_to_device: + if not options.experimental_fetch_to_device: return input_lib.InputWorkers([ (host_device, (host_device,) * len(compute_devices)) for host_device, compute_devices in self._input_workers_devices @@ -773,6 +775,19 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1): **distribute_utils.select_replica(i, kwargs))) return distribute_utils.update_regroup(self, updates, group) + def _replica_ctx_update(self, var, fn, args, kwargs): + replica_context = distribution_strategy_context.get_replica_context() + assert replica_context + replica_id = values_util.get_current_replica_id_as_int() + name = "update_%d" % replica_id + + if isinstance(var, values.DistributedVariable): + var = var._get_replica(replica_id) # pylint: disable=protected-access + + with ops.device(var.device), ops.name_scope(name): + result = fn(var, *args, **kwargs) + return result + def _update_non_slot(self, colocate_with, fn, args, kwargs, group): assert isinstance(colocate_with, tuple) # TODO(josh11b): In eager mode, use one thread per device. diff --git a/tensorflow/python/distribute/mirrored_strategy_test.py b/tensorflow/python/distribute/mirrored_strategy_test.py index 156c23750bd..0bb9ebff770 100644 --- a/tensorflow/python/distribute/mirrored_strategy_test.py +++ b/tensorflow/python/distribute/mirrored_strategy_test.py @@ -236,7 +236,7 @@ class MirroredTwoDeviceDistributionTest( def test_prefetch_to_device_dataset(self, distribution): input_options = distribute_lib.InputOptions( - experimental_prefetch_to_device=True) + experimental_fetch_to_device=True) dataset = dataset_ops.Dataset.range(100) dataset = dataset.batch(distribution.num_replicas_in_sync) dataset = distribution.experimental_distribute_dataset( @@ -258,7 +258,7 @@ class MirroredTwoDeviceDistributionTest( def test_prefetch_to_host_dataset(self, distribution): input_options = distribute_lib.InputOptions( - experimental_prefetch_to_device=False) + experimental_fetch_to_device=False) dataset = dataset_ops.Dataset.range(100) dataset = dataset.batch(distribution.num_replicas_in_sync) dataset = distribution.experimental_distribute_dataset( diff --git a/tensorflow/python/distribute/mirrored_variable_test.py b/tensorflow/python/distribute/mirrored_variable_test.py index 53a18fb271b..acf6861b9a2 100644 --- a/tensorflow/python/distribute/mirrored_variable_test.py +++ b/tensorflow/python/distribute/mirrored_variable_test.py @@ -39,6 +39,9 @@ from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables +from tensorflow.python.saved_model import load +from tensorflow.python.saved_model import save +from tensorflow.python.training.tracking import util as tracking_util def _replica_id(): @@ -590,6 +593,29 @@ class MirroredVariableCreationTest(test.TestCase): self.assertIs(distribution, mirrored.distribute_strategy) self.assertIs(distribution, sync_on_read.distribute_strategy) + def testInitializer(self, distribution, mode): + if mode == "graph": + self.skipTest("Skip graph mode") + + temp_dir = self.get_temp_dir() + + class Model(tracking_util.Checkpoint): + + def __init__(self): + self._v = variables.Variable(1.0) + + with distribution.scope(): + m = Model() + save.save(m, temp_dir) + + g = ops.Graph() + with g.as_default(): + with distribution.scope(): + load.load(temp_dir) + + for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES): + self.assertIsNotNone(v.initializer) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/distribute/one_device_strategy.py b/tensorflow/python/distribute/one_device_strategy.py index 4dc791062e1..08ec7c8e45d 100644 --- a/tensorflow/python/distribute/one_device_strategy.py +++ b/tensorflow/python/distribute/one_device_strategy.py @@ -261,7 +261,7 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1): self._input_device = device_util.get_host_for_device(self._device) def _input_workers_with_options(self, options=None): - if not options or options.experimental_prefetch_to_device: + if not options or options.experimental_fetch_to_device: return input_lib.InputWorkers([(self._input_device, (self._device,))]) else: return input_lib.InputWorkers([(self._input_device, diff --git a/tensorflow/python/distribute/one_device_strategy_test.py b/tensorflow/python/distribute/one_device_strategy_test.py index 238d0150100..4f96cd76bcc 100644 --- a/tensorflow/python/distribute/one_device_strategy_test.py +++ b/tensorflow/python/distribute/one_device_strategy_test.py @@ -121,7 +121,7 @@ class OneDeviceStrategyTest( def test_prefetch_to_device_dataset(self, distribution): input_options = distribute_lib.InputOptions( - experimental_prefetch_to_device=True) + experimental_fetch_to_device=True) dataset = dataset_ops.Dataset.range(100) dataset = dataset.batch(distribution.num_replicas_in_sync) dataset = distribution.experimental_distribute_dataset( @@ -142,7 +142,7 @@ class OneDeviceStrategyTest( def test_prefetch_to_host_dataset(self, distribution): input_options = distribute_lib.InputOptions( - experimental_prefetch_to_device=False) + experimental_fetch_to_device=False) dataset = dataset_ops.Dataset.range(100) dataset = dataset.batch(distribution.num_replicas_in_sync) dataset = distribution.experimental_distribute_dataset( diff --git a/tensorflow/python/distribute/parameter_server_strategy.py b/tensorflow/python/distribute/parameter_server_strategy.py index 702bfee413b..6c65f3f884d 100644 --- a/tensorflow/python/distribute/parameter_server_strategy.py +++ b/tensorflow/python/distribute/parameter_server_strategy.py @@ -323,7 +323,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1): compute_devices, self._variable_device) def _input_workers_with_options(self, options=None): - if not options or options.experimental_prefetch_to_device: + if not options or options.experimental_fetch_to_device: return input_lib.InputWorkers( [(self._worker_device, self._compute_devices)]) else: diff --git a/tensorflow/python/distribute/parameter_server_strategy_test.py b/tensorflow/python/distribute/parameter_server_strategy_test.py index 2c0b73de14d..89a04c415aa 100644 --- a/tensorflow/python/distribute/parameter_server_strategy_test.py +++ b/tensorflow/python/distribute/parameter_server_strategy_test.py @@ -782,7 +782,7 @@ class ParameterServerStrategyTest( input_options = None else: input_options = distribute_lib.InputOptions( - experimental_prefetch_to_device=prefetch_to_device) + experimental_fetch_to_device=prefetch_to_device) dataset = dataset_ops.Dataset.range(100) dataset = dataset.batch(distribution.num_replicas_in_sync) dataset = distribution.experimental_distribute_dataset( @@ -804,7 +804,7 @@ class ParameterServerStrategyTest( task_id=0, num_gpus=2) input_options = distribute_lib.InputOptions( - experimental_prefetch_to_device=False) + experimental_fetch_to_device=False) dataset = dataset_ops.Dataset.range(100) dataset = dataset.batch(distribution.num_replicas_in_sync) dataset = distribution.experimental_distribute_dataset( diff --git a/tensorflow/python/distribute/ps_values.py b/tensorflow/python/distribute/ps_values.py index a257a022dfa..c5172c9ca99 100644 --- a/tensorflow/python/distribute/ps_values.py +++ b/tensorflow/python/distribute/ps_values.py @@ -155,6 +155,9 @@ class AggregatingVariable(variables_lib.Variable, core.Tensor): def op(self): return self._v.op + def value(self): + return self._v.value() + def read_value(self): return self._v.read_value() diff --git a/tensorflow/python/distribute/strategy_common_test.py b/tensorflow/python/distribute/strategy_common_test.py index f200638ef19..d578c2685eb 100644 --- a/tensorflow/python/distribute/strategy_common_test.py +++ b/tensorflow/python/distribute/strategy_common_test.py @@ -237,6 +237,51 @@ class ReduceTest(test.TestCase, parameterized.TestCase): self.assertEqual(3 * strategy.num_replicas_in_sync, x_s) +@combinations.generate( + combinations.combine( + strategy=[ + strategy_combinations.default_strategy, + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + strategy_combinations.tpu_strategy, + strategy_combinations.tpu_strategy_packed_var, + strategy_combinations.multi_worker_mirrored_2x1_cpu, + strategy_combinations.multi_worker_mirrored_2x2_gpu, + ], + update_fn=['assign', 'assign_add', 'assign_sub'], + tf_function=[True, False], + mode=['eager'])) +class ReplicaCtxUpdateTest(test.TestCase, parameterized.TestCase): + + def testDenseUpdate(self, strategy, tf_function, update_fn): + if isinstance(strategy, tpu_strategy.TPUStrategy) and (not tf_function): + self.skipTest('Skip TPUStrategy + eager combination.') + with strategy.scope(): + distributed_variable1 = variables.Variable(5.0) + + def replica_fn(): + value = array_ops.constant(2.) + python_literal = 1. + replica_context = ds_context.get_replica_context() + fn_sets = { + 'assign': lambda var, value: var.assign(value), + 'assign_add': lambda var, value: var.assign_add(value), + 'assign_sub': lambda var, value: var.assign_sub(value), + } + replica_context._update( + distributed_variable1, fn_sets[update_fn], args=(value,)) + replica_context._update( + distributed_variable1, fn_sets[update_fn], args=(python_literal,)) + + if tf_function: + replica_fn = def_function.function(replica_fn) + strategy.run(replica_fn) + + expected_result = {'assign': 1., 'assign_add': 8., 'assign_sub': 2.} + self.assertAllEqual( + strategy.experimental_local_results(distributed_variable1), + [expected_result[update_fn]] * _get_num_replicas_per_client(strategy)) + + @combinations.generate( combinations.combine( strategy=[ diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py index a11e0f43355..497583ada5f 100644 --- a/tensorflow/python/distribute/tpu_strategy.py +++ b/tensorflow/python/distribute/tpu_strategy.py @@ -909,7 +909,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): session) def _get_input_workers(self, options): - if not options or options.experimental_prefetch_to_device: + if not options or options.experimental_fetch_to_device: return input_lib.InputWorkers( tuple(self._device_input_worker_devices.items())) else: @@ -928,7 +928,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): "distributed datasets with device prefetch when using sparse or " "ragged tensors. If you intend to use sparse or ragged tensors, " "please pass a tf.distribute.InputOptions object with " - "experimental_prefetch_to_device set to False to your dataset " + "experimental_fetch_to_device set to False to your dataset " "distribution function.".format(path, type(spec))) def _experimental_distribute_dataset(self, dataset, options): @@ -939,7 +939,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): "is only supported in " "`experimental_distribute_datasets_from_function`." ) - if options is None or options.experimental_prefetch_to_device: + if options is None or options.experimental_fetch_to_device: self._check_spec(dataset.element_spec) return input_lib.get_distributed_dataset( @@ -974,7 +974,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): options=options) # We can only check after the dataset_fn is called. - if options is None or options.experimental_prefetch_to_device: + if options is None or options.experimental_fetch_to_device: self._check_spec(distributed_dataset.element_spec) return distributed_dataset diff --git a/tensorflow/python/distribute/tpu_strategy_test.py b/tensorflow/python/distribute/tpu_strategy_test.py index 5b5fec3d81c..94e5ad9eb86 100644 --- a/tensorflow/python/distribute/tpu_strategy_test.py +++ b/tensorflow/python/distribute/tpu_strategy_test.py @@ -760,7 +760,7 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase): dataset = iter( strategy.distribute_datasets_from_function( dataset_fn, - distribute_lib.InputOptions(experimental_prefetch_to_device=False))) + distribute_lib.InputOptions(experimental_fetch_to_device=False))) sparse, result = sparse_lookup(dataset) @@ -810,7 +810,7 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase): dataset = iter( strategy.distribute_datasets_from_function( dataset_fn, - distribute_lib.InputOptions(experimental_prefetch_to_device=False))) + distribute_lib.InputOptions(experimental_fetch_to_device=False))) output = sparse_lookup(dataset) @@ -866,7 +866,7 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase): strategy.distribute_datasets_from_function( dataset_fn, options=distribute_lib.InputOptions( - experimental_prefetch_to_device=False))) + experimental_fetch_to_device=False))) result = sparse_lookup(dataset) self.assertAllEqual(result, [[0.0, 2.0], [1.5, 5.0]]) @@ -916,7 +916,7 @@ class TPUStrategyDataPrefetchTest(test.TestCase): output_type=dtypes.float32).batch(strategy.num_replicas_in_sync) input_options = distribute_lib.InputOptions( - experimental_prefetch_to_device=True) + experimental_fetch_to_device=True) dataset_item = next(iter(strategy.experimental_distribute_dataset( dataset, options=input_options))) dataset_location = tf_device.DeviceSpec.from_string( @@ -931,7 +931,7 @@ class TPUStrategyDataPrefetchTest(test.TestCase): # Should be CPU when prefetch_to_device is False. input_options = distribute_lib.InputOptions( - experimental_prefetch_to_device=False) + experimental_fetch_to_device=False) dataset_item = next(iter(strategy.experimental_distribute_dataset( dataset, options=input_options))) dataset_location = tf_device.DeviceSpec.from_string( diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index 6e8d4363505..cb113fba8b9 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -66,6 +66,7 @@ cc_library( "//tensorflow/python:ndarray_tensor_bridge", "//tensorflow/python:numpy_lib", "//tensorflow/python:py_exception_registry", + "//tensorflow/python:pybind11_status", "//tensorflow/python/lib/core:py_seq_tensor", "//tensorflow/python/lib/core:py_util", "//tensorflow/python/lib/core:safe_ptr", diff --git a/tensorflow/python/eager/def_function_xla_jit_test.py b/tensorflow/python/eager/def_function_xla_jit_test.py index 9f3a400d4e6..47326af9459 100644 --- a/tensorflow/python/eager/def_function_xla_jit_test.py +++ b/tensorflow/python/eager/def_function_xla_jit_test.py @@ -580,7 +580,6 @@ class DefFunctionTest(xla_test.XLATestCase): self.assertEqual(inner_retracings, 1) - @test_util.disable_mlir_bridge('b/180951174') def testUpdateVariable(self): with ops.device('device:{}:0'.format(self.device)): diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc index 549a86ba882..755c50ea0d6 100644 --- a/tensorflow/python/eager/pywrap_tensor.cc +++ b/tensorflow/python/eager/pywrap_tensor.cc @@ -38,6 +38,7 @@ limitations under the License. #include "tensorflow/python/lib/core/numpy.h" #include "tensorflow/python/lib/core/py_exception_registry.h" #include "tensorflow/python/lib/core/py_seq_tensor.h" +#include "tensorflow/python/lib/core/pybind11_status.h" #include "tensorflow/python/lib/core/safe_ptr.h" // forward declare @@ -186,13 +187,9 @@ int ConvertDeviceName(PyObject* obj, const char** dst) { return 1; } -void RaiseExceptionTypeFromTFStatus(TF_Status* status) { - TF_Code code = TF_GetCode(status); - PyObject* exception = tensorflow::PyExceptionRegistry::Lookup(code); - PyErr_SetObject(exception, - pybind11::make_tuple(pybind11::none(), pybind11::none(), - TF_Message(status)) - .ptr()); +void RaiseExceptionTypeFromTFStatus(TF_Status* tf_status) { + auto status = tensorflow::StatusFromTF_Status(tf_status); + SetRegisteredErrFromStatus(status); } } // namespace diff --git a/tensorflow/python/framework/errors_impl.py b/tensorflow/python/framework/errors_impl.py index 34daf43372a..c2716f43125 100644 --- a/tensorflow/python/framework/errors_impl.py +++ b/tensorflow/python/framework/errors_impl.py @@ -67,7 +67,7 @@ class OpError(Exception): of `OpError` from the `tf.errors` module. """ - def __init__(self, node_def, op, message, error_code): + def __init__(self, node_def, op, message, error_code, *args): """Creates a new `OpError` indicating that a particular op failed. Args: @@ -76,12 +76,19 @@ class OpError(Exception): op: The `ops.Operation` that failed, if known; otherwise None. message: The message string describing the failure. error_code: The `error_codes_pb2.Code` describing the error. + *args: If not empty, it should contain a dictionary describing details + about the error. This argument is inspired by Abseil payloads: + https://github.com/abseil/abseil-cpp/blob/master/absl/status/status.h """ super(OpError, self).__init__() self._node_def = node_def self._op = op self._message = message self._error_code = error_code + if args: + self._experimental_payloads = args[0] + else: + self._experimental_payloads = {} def __reduce__(self): # Allow the subclasses to accept less arguments in their __init__. @@ -120,10 +127,19 @@ class OpError(Exception): """The `NodeDef` proto representing the op that failed.""" return self._node_def + @property + def experimental_payloads(self): + """A dictionary describing the details of the error.""" + return self._experimental_payloads + def __str__(self): if self._op is not None: - output = ["%s\n\nOriginal stack trace for %r:\n" % (self.message, - self._op.name,)] + output = [ + "%s\n\nOriginal stack trace for %r:\n" % ( + self.message, + self._op.name, + ) + ] curr_traceback_list = traceback.format_list( _compact_stack_trace(self._op)) output.extend(curr_traceback_list) @@ -132,8 +148,8 @@ class OpError(Exception): # pylint: enable=protected-access while original_op is not None: output.append( - "\n...which was originally created as op %r, defined at:\n" - % (original_op.name,)) + "\n...which was originally created as op %r, defined at:\n" % + (original_op.name,)) prev_traceback_list = curr_traceback_list curr_traceback_list = traceback.format_list( _compact_stack_trace(original_op)) @@ -157,9 +173,10 @@ class OpError(Exception): else: if is_eliding: if elide_count > 0: - output.extend( - ["[elided %d identical lines from previous traceback]\n" - % (elide_count - 1,), last_elided_line]) + output.extend([ + "[elided %d identical lines from previous traceback]\n" % + (elide_count - 1,), last_elided_line + ]) is_eliding = False output.extend(line) @@ -228,9 +245,12 @@ class CancelledError(OpError): @@__init__ """ - def __init__(self, node_def, op, message): + def __init__(self, node_def, op, message, *args): """Creates a `CancelledError`.""" - super(CancelledError, self).__init__(node_def, op, message, CANCELLED) + super(CancelledError, self).__init__(node_def, op, message, CANCELLED, + *args) + + # pylint: enable=line-too-long @@ -247,9 +267,9 @@ class UnknownError(OpError): @@__init__ """ - def __init__(self, node_def, op, message, error_code=UNKNOWN): + def __init__(self, node_def, op, message, *args): """Creates an `UnknownError`.""" - super(UnknownError, self).__init__(node_def, op, message, error_code) + super(UnknownError, self).__init__(node_def, op, message, UNKNOWN, *args) @tf_export("errors.InvalidArgumentError") @@ -267,10 +287,10 @@ class InvalidArgumentError(OpError): @@__init__ """ - def __init__(self, node_def, op, message): + def __init__(self, node_def, op, message, *args): """Creates an `InvalidArgumentError`.""" super(InvalidArgumentError, self).__init__(node_def, op, message, - INVALID_ARGUMENT) + INVALID_ARGUMENT, *args) @tf_export("errors.DeadlineExceededError") @@ -282,10 +302,10 @@ class DeadlineExceededError(OpError): @@__init__ """ - def __init__(self, node_def, op, message): + def __init__(self, node_def, op, message, *args): """Creates a `DeadlineExceededError`.""" super(DeadlineExceededError, self).__init__(node_def, op, message, - DEADLINE_EXCEEDED) + DEADLINE_EXCEEDED, *args) @tf_export("errors.NotFoundError") @@ -300,9 +320,9 @@ class NotFoundError(OpError): @@__init__ """ - def __init__(self, node_def, op, message): + def __init__(self, node_def, op, message, *args): """Creates a `NotFoundError`.""" - super(NotFoundError, self).__init__(node_def, op, message, NOT_FOUND) + super(NotFoundError, self).__init__(node_def, op, message, NOT_FOUND, *args) @tf_export("errors.AlreadyExistsError") @@ -317,10 +337,10 @@ class AlreadyExistsError(OpError): @@__init__ """ - def __init__(self, node_def, op, message): + def __init__(self, node_def, op, message, *args): """Creates an `AlreadyExistsError`.""" super(AlreadyExistsError, self).__init__(node_def, op, message, - ALREADY_EXISTS) + ALREADY_EXISTS, *args) @tf_export("errors.PermissionDeniedError") @@ -335,10 +355,10 @@ class PermissionDeniedError(OpError): @@__init__ """ - def __init__(self, node_def, op, message): + def __init__(self, node_def, op, message, *args): """Creates a `PermissionDeniedError`.""" super(PermissionDeniedError, self).__init__(node_def, op, message, - PERMISSION_DENIED) + PERMISSION_DENIED, *args) @tf_export("errors.UnauthenticatedError") @@ -350,10 +370,10 @@ class UnauthenticatedError(OpError): @@__init__ """ - def __init__(self, node_def, op, message): + def __init__(self, node_def, op, message, *args): """Creates an `UnauthenticatedError`.""" super(UnauthenticatedError, self).__init__(node_def, op, message, - UNAUTHENTICATED) + UNAUTHENTICATED, *args) @tf_export("errors.ResourceExhaustedError") @@ -366,10 +386,10 @@ class ResourceExhaustedError(OpError): @@__init__ """ - def __init__(self, node_def, op, message): + def __init__(self, node_def, op, message, *args): """Creates a `ResourceExhaustedError`.""" super(ResourceExhaustedError, self).__init__(node_def, op, message, - RESOURCE_EXHAUSTED) + RESOURCE_EXHAUSTED, *args) @tf_export("errors.FailedPreconditionError") @@ -383,10 +403,10 @@ class FailedPreconditionError(OpError): @@__init__ """ - def __init__(self, node_def, op, message): + def __init__(self, node_def, op, message, *args): """Creates a `FailedPreconditionError`.""" super(FailedPreconditionError, self).__init__(node_def, op, message, - FAILED_PRECONDITION) + FAILED_PRECONDITION, *args) @tf_export("errors.AbortedError") @@ -402,9 +422,9 @@ class AbortedError(OpError): @@__init__ """ - def __init__(self, node_def, op, message): + def __init__(self, node_def, op, message, *args): """Creates an `AbortedError`.""" - super(AbortedError, self).__init__(node_def, op, message, ABORTED) + super(AbortedError, self).__init__(node_def, op, message, ABORTED, *args) @tf_export("errors.OutOfRangeError") @@ -420,10 +440,10 @@ class OutOfRangeError(OpError): @@__init__ """ - def __init__(self, node_def, op, message): + def __init__(self, node_def, op, message, *args): """Creates an `OutOfRangeError`.""" - super(OutOfRangeError, self).__init__(node_def, op, message, - OUT_OF_RANGE) + super(OutOfRangeError, self).__init__(node_def, op, message, OUT_OF_RANGE, + *args) @tf_export("errors.UnimplementedError") @@ -439,10 +459,10 @@ class UnimplementedError(OpError): @@__init__ """ - def __init__(self, node_def, op, message): + def __init__(self, node_def, op, message, *args): """Creates an `UnimplementedError`.""" super(UnimplementedError, self).__init__(node_def, op, message, - UNIMPLEMENTED) + UNIMPLEMENTED, *args) @tf_export("errors.InternalError") @@ -455,9 +475,9 @@ class InternalError(OpError): @@__init__ """ - def __init__(self, node_def, op, message): + def __init__(self, node_def, op, message, *args): """Creates an `InternalError`.""" - super(InternalError, self).__init__(node_def, op, message, INTERNAL) + super(InternalError, self).__init__(node_def, op, message, INTERNAL, *args) @tf_export("errors.UnavailableError") @@ -469,10 +489,10 @@ class UnavailableError(OpError): @@__init__ """ - def __init__(self, node_def, op, message): + def __init__(self, node_def, op, message, *args): """Creates an `UnavailableError`.""" - super(UnavailableError, self).__init__(node_def, op, message, - UNAVAILABLE) + super(UnavailableError, self).__init__(node_def, op, message, UNAVAILABLE, + *args) @tf_export("errors.DataLossError") @@ -486,9 +506,9 @@ class DataLossError(OpError): @@__init__ """ - def __init__(self, node_def, op, message): + def __init__(self, node_def, op, message, *args): """Creates a `DataLossError`.""" - super(DataLossError, self).__init__(node_def, op, message, DATA_LOSS) + super(DataLossError, self).__init__(node_def, op, message, DATA_LOSS, *args) _CODE_TO_EXCEPTION_CLASS = { @@ -513,7 +533,8 @@ _CODE_TO_EXCEPTION_CLASS = { _pywrap_py_exception_registry.PyExceptionRegistry_Init(_CODE_TO_EXCEPTION_CLASS) _EXCEPTION_CLASS_TO_CODE = { - class_: code for code, class_ in _CODE_TO_EXCEPTION_CLASS.items()} + class_: code for code, class_ in _CODE_TO_EXCEPTION_CLASS.items() +} @tf_export(v1=["errors.exception_type_from_error_code"]) @@ -555,8 +576,7 @@ class raise_exception_on_not_ok_status(object): try: if c_api.TF_GetCode(self.status.status) != 0: raise _make_specific_exception( - None, None, - compat.as_text(c_api.TF_Message(self.status.status)), + None, None, compat.as_text(c_api.TF_Message(self.status.status)), c_api.TF_GetCode(self.status.status)) # Delete the underlying status object from memory otherwise it stays alive # as there is a reference to status from this from the traceback due to diff --git a/tensorflow/python/framework/errors_test.py b/tensorflow/python/framework/errors_test.py index c0a5e8e6811..ee03459f936 100644 --- a/tensorflow/python/framework/errors_test.py +++ b/tensorflow/python/framework/errors_test.py @@ -23,6 +23,7 @@ import pickle import warnings from tensorflow.core.lib.core import error_codes_pb2 +from tensorflow.python import _errors_test_helper from tensorflow.python.framework import c_api_util from tensorflow.python.framework import errors from tensorflow.python.framework import errors_impl @@ -146,6 +147,53 @@ class ErrorsTest(test.TestCase): self.assertEqual(exc.message, unpickled.message) self.assertEqual(exc.error_code, unpickled.error_code) + def testErrorPayloadsFromStatus(self): + for code, expected_exception in [ + (1, errors.CancelledError), + (2, errors.UnknownError), + (3, errors.InvalidArgumentError), + (4, errors.DeadlineExceededError), + (5, errors.NotFoundError), + (6, errors.AlreadyExistsError), + (7, errors.PermissionDeniedError), + (16, errors.UnauthenticatedError), + (8, errors.ResourceExhaustedError), + (9, errors.FailedPreconditionError), + (10, errors.AbortedError), + (11, errors.OutOfRangeError), + (12, errors.UnimplementedError), + (13, errors.InternalError), + (14, errors.UnavailableError), + (15, errors.DataLossError), + ]: + with self.assertRaises(expected_exception) as error: + _errors_test_helper.TestRaiseFromStatus(code) + self.assertEqual(error.exception.experimental_payloads["key1"], "value1") + self.assertEqual(error.exception.experimental_payloads["key2"], "value2") + + def testErrorPayloadsDefaultValue(self): + for exception_type in [ + (errors.CancelledError), + (errors.UnknownError), + (errors.InvalidArgumentError), + (errors.DeadlineExceededError), + (errors.NotFoundError), + (errors.AlreadyExistsError), + (errors.PermissionDeniedError), + (errors.UnauthenticatedError), + (errors.ResourceExhaustedError), + (errors.FailedPreconditionError), + (errors.AbortedError), + (errors.OutOfRangeError), + (errors.UnimplementedError), + (errors.InternalError), + (errors.UnavailableError), + (errors.DataLossError), + ]: + e = exception_type(None, None, None) + self.assertEqual(type(e.experimental_payloads), dict) + self.assertEqual(len(e.experimental_payloads), 0) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/framework/errors_test_helper.cc b/tensorflow/python/framework/errors_test_helper.cc new file mode 100644 index 00000000000..5f13a37d38f --- /dev/null +++ b/tensorflow/python/framework/errors_test_helper.cc @@ -0,0 +1,31 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "pybind11/pybind11.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/python/lib/core/pybind11_status.h" + +namespace tensorflow { +PYBIND11_MODULE(_errors_test_helper, m) { + m.def("TestRaiseFromStatus", [](int code) { + tensorflow::Status status(static_cast(code), + "test message"); + status.SetPayload("key1", "value1"); + status.SetPayload("key2", "value2"); + MaybeRaiseRegisteredFromStatus(status); + return 0; + }); +} +} // namespace tensorflow diff --git a/tensorflow/python/framework/graph_util_impl.py b/tensorflow/python/framework/graph_util_impl.py index a2055241517..f0af4e39ca1 100644 --- a/tensorflow/python/framework/graph_util_impl.py +++ b/tensorflow/python/framework/graph_util_impl.py @@ -377,6 +377,7 @@ def remove_training_nodes(input_graph, protected_nodes=None): return output_graph +@tf_export("__internal__.graph_util.graph_defs_equal", v1=[]) def graph_defs_equal(graph_def_1: graph_pb2.GraphDef, graph_def_2: graph_pb2.GraphDef, treat_nan_as_equal: bool = False) -> bool: @@ -387,6 +388,25 @@ def graph_defs_equal(graph_def_1: graph_pb2.GraphDef, Additionally, it checks that the functions in the function library are equal as sets. + Example usage: + + ``` + with tf.Graph().as_default() as g1: + tf.constant(1) + + with tf.Graph().as_default() as g2: + tf.constant(2) + + with tf.Graph().as_default() as g3: + tf.constant(1) + + assert tf.__internal__.graph_util.graph_defs_equal(g1.as_graph_def(), + g3.as_graph_def()) + + assert not tf.__internal__.graph_util.graph_defs_equal(g1.as_graph_def(), + g2.as_graph_def()) + ``` + Args: graph_def_1: Instance of `graph_pb2.GraphDef` to compare. graph_def_2: Instance of `graph_pb2.GraphDef` to compare. diff --git a/tensorflow/python/framework/meta_graph.py b/tensorflow/python/framework/meta_graph.py index dbc2a894d65..695828315e9 100644 --- a/tensorflow/python/framework/meta_graph.py +++ b/tensorflow/python/framework/meta_graph.py @@ -119,7 +119,8 @@ def _read_file(filename): if not file_io.file_exists(filename): raise IOError("File %s does not exist." % filename) # First try to read it as a binary file. - file_content = file_io.FileIO(filename, "rb").read() + with file_io.FileIO(filename, "rb") as f: + file_content = f.read() try: graph_def.ParseFromString(file_content) return graph_def @@ -629,7 +630,8 @@ def read_meta_graph_file(filename): if not file_io.file_exists(filename): raise IOError("File %s does not exist." % filename) # First try to read it as a binary file. - file_content = file_io.FileIO(filename, "rb").read() + with file_io.FileIO(filename, "rb") as f: + file_content = f.read() try: meta_graph_def.ParseFromString(file_content) return meta_graph_def diff --git a/tensorflow/python/grappler/cost_analyzer_tool.py b/tensorflow/python/grappler/cost_analyzer_tool.py index 7dbaf449cad..92ff471de58 100644 --- a/tensorflow/python/grappler/cost_analyzer_tool.py +++ b/tensorflow/python/grappler/cost_analyzer_tool.py @@ -21,6 +21,8 @@ from __future__ import print_function import argparse import sys +from absl import app + from google.protobuf import message from google.protobuf import text_format from tensorflow.contrib.fused_conv.ops import gen_fused_conv2d_bias_activation_op # pylint: disable=unused-import @@ -32,7 +34,6 @@ from tensorflow.python.framework import importer from tensorflow.python.framework import ops from tensorflow.python.grappler import cost_analyzer from tensorflow.python.grappler import tf_optimizer -from tensorflow.python.platform import app from tensorflow.python.platform import gfile from tensorflow.python.training import saver diff --git a/tensorflow/python/grappler/graph_analyzer.py b/tensorflow/python/grappler/graph_analyzer.py index c46a74ea64c..2e6dc640388 100644 --- a/tensorflow/python/grappler/graph_analyzer.py +++ b/tensorflow/python/grappler/graph_analyzer.py @@ -25,8 +25,9 @@ from __future__ import print_function import argparse import sys +from absl import app + from tensorflow.python import _pywrap_graph_analyzer as tf_wrap -from tensorflow.python.platform import app def main(_): diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index 69d045e6a77..515a729663a 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -244,7 +244,6 @@ py_library( srcs_version = "PY3", deps = [ ":backend", - "//tensorflow/python/keras/utils:control_flow_util", "//tensorflow/python/keras/utils:engine_utils", ], ) diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index c4855aa429c..0bed2ba9542 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -1999,6 +1999,11 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector): * Activation histograms * Sampled profiling + When used in `Model.evaluate`, in addition to epoch summaries, there will be + a summary that records evaluation metrics vs `Model.optimizer.iterations` + written. The metric names will be prepended with `evaluation`, with + `Model.optimizer.iterations` being the step in the visualized TensorBoard. + If you have installed TensorFlow with pip, you should be able to launch TensorBoard from the command line: @@ -2372,6 +2377,13 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector): self._push_writer(self._val_writer, self._val_step) def on_test_end(self, logs=None): + if self.model.optimizer and hasattr(self.model.optimizer, 'iterations'): + with summary_ops_v2.record_if(True), self._val_writer.as_default(): + for name, value in logs.items(): + summary_ops_v2.scalar( + 'evaluation_' + name + '_vs_iterations', + value, + step=self.model.optimizer.iterations.read_value()) self._pop_writer() def _implements_train_batch_hooks(self): diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py index bc5910a434d..8b082131521 100644 --- a/tensorflow/python/keras/callbacks_test.py +++ b/tensorflow/python/keras/callbacks_test.py @@ -751,6 +751,7 @@ class KerasCallbacksTest(keras_parameterized.TestCase): test_samples=TEST_SAMPLES, input_shape=(INPUT_DIM,), num_classes=NUM_CLASSES) + y_train = np_utils.to_categorical(y_train, num_classes=NUM_CLASSES) model.fit( x_train, @@ -760,7 +761,6 @@ class KerasCallbacksTest(keras_parameterized.TestCase): verbose=0) # Check that the filepath is a SavedModel directory. self.assertIn('saved_model.pb', os.listdir(filepath)) - os.remove(filepath) def _get_dummy_resource_for_model_checkpoint_testing(self): @@ -1978,6 +1978,8 @@ class TestTensorBoardV2(keras_parameterized.TestCase): summary_file.scalars, { _ObservedSummary(logdir=train_dir, tag='epoch_loss'), _ObservedSummary(logdir=validation_dir, tag='epoch_loss'), + _ObservedSummary( + logdir=validation_dir, tag='evaluation_loss_vs_iterations'), }) def test_TensorBoard_basic(self): @@ -1998,6 +2000,9 @@ class TestTensorBoardV2(keras_parameterized.TestCase): summary_file.scalars, { _ObservedSummary(logdir=self.train_dir, tag='epoch_loss'), _ObservedSummary(logdir=self.validation_dir, tag='epoch_loss'), + _ObservedSummary( + logdir=self.validation_dir, + tag='evaluation_loss_vs_iterations'), }) def test_TensorBoard_across_invocations(self): @@ -2023,6 +2028,9 @@ class TestTensorBoardV2(keras_parameterized.TestCase): summary_file.scalars, { _ObservedSummary(logdir=self.train_dir, tag='epoch_loss'), _ObservedSummary(logdir=self.validation_dir, tag='epoch_loss'), + _ObservedSummary( + logdir=self.validation_dir, + tag='evaluation_loss_vs_iterations'), }) def test_TensorBoard_no_spurious_event_files(self): @@ -2063,6 +2071,9 @@ class TestTensorBoardV2(keras_parameterized.TestCase): _ObservedSummary(logdir=self.train_dir, tag='batch_loss'), _ObservedSummary(logdir=self.train_dir, tag='epoch_loss'), _ObservedSummary(logdir=self.validation_dir, tag='epoch_loss'), + _ObservedSummary( + logdir=self.validation_dir, + tag='evaluation_loss_vs_iterations'), }, ) @@ -2144,6 +2155,9 @@ class TestTensorBoardV2(keras_parameterized.TestCase): { _ObservedSummary(logdir=self.train_dir, tag='epoch_loss'), _ObservedSummary(logdir=self.validation_dir, tag='epoch_loss'), + _ObservedSummary( + logdir=self.validation_dir, + tag='evaluation_loss_vs_iterations'), }, ) self.assertEqual( @@ -2175,6 +2189,9 @@ class TestTensorBoardV2(keras_parameterized.TestCase): { _ObservedSummary(logdir=self.train_dir, tag='epoch_loss'), _ObservedSummary(logdir=self.validation_dir, tag='epoch_loss'), + _ObservedSummary( + logdir=self.validation_dir, + tag='evaluation_loss_vs_iterations'), }, ) self.assertEqual( @@ -2274,6 +2291,9 @@ class TestTensorBoardV2(keras_parameterized.TestCase): { _ObservedSummary(logdir=self.train_dir, tag='epoch_loss'), _ObservedSummary(logdir=self.validation_dir, tag='epoch_loss'), + _ObservedSummary( + logdir=self.validation_dir, + tag='evaluation_loss_vs_iterations'), _ObservedSummary(logdir=self.train_dir, tag='batch_loss'), _ObservedSummary( logdir=self.train_dir, diff --git a/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py b/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py index bb9d303ee3d..8b37c204a20 100644 --- a/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py +++ b/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py @@ -20,7 +20,6 @@ from __future__ import print_function import json import os -import sys from absl.testing import parameterized @@ -36,11 +35,6 @@ from tensorflow.python.lib.io import file_io from tensorflow.python.platform import test -def _is_oss(): - """Returns whether the test is run under OSS.""" - return len(sys.argv) >= 1 and 'bazel' in sys.argv[0] - - def checkpoint_exists(filepath): """Returns whether the checkpoint `filepath` refers to exists.""" if filepath.endswith('.h5'): @@ -189,8 +183,6 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase): def proc_model_checkpoint_works_with_same_file_path(test_obj, saving_filepath): - if _is_oss(): - test_obj.skipTest('TODO(b/170838633): Failing in OSS') model, _, train_ds, steps = _model_setup(test_obj, file_format='') num_epoch = 4 diff --git a/tensorflow/python/keras/distribute/parameter_server_training_test.py b/tensorflow/python/keras/distribute/parameter_server_training_test.py index 6b67aef4d2f..e8d88531730 100644 --- a/tensorflow/python/keras/distribute/parameter_server_training_test.py +++ b/tensorflow/python/keras/distribute/parameter_server_training_test.py @@ -31,6 +31,7 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import combinations from tensorflow.python.distribute import multi_worker_test_base from tensorflow.python.distribute import parameter_server_strategy_v2 +from tensorflow.python.distribute import sharded_variable from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib from tensorflow.python.eager import backprop @@ -71,10 +72,11 @@ def make_cluster(num_workers, num_ps): return SimpleClusterResolver(ClusterSpec(cluster_def), rpc_layer="grpc") -def make_coordinator(num_workers, num_ps): +def make_coordinator(num_workers, num_ps, variable_partitioner=None): return coordinator_lib.ClusterCoordinator( parameter_server_strategy_v2.ParameterServerStrategyV2( - make_cluster(num_workers, num_ps))) + make_cluster(num_workers, num_ps), + variable_partitioner=variable_partitioner)) # TODO(yuefengz): move this to keras/integration_tests. @@ -243,7 +245,10 @@ class KPLTest(test.TestCase, parameterized.TestCase): class ModelFitTest(test.TestCase, parameterized.TestCase): - def _model_compile(self, steps_per_execution=1, run_eagerly=False): + def _model_compile(self, + steps_per_execution=1, + run_eagerly=False, + with_normalization_layer=False): class ResultAssertingCallback(callbacks_lib.Callback): @@ -260,9 +265,15 @@ class ModelFitTest(test.TestCase, parameterized.TestCase): raise RuntimeError("loss is supposed to be in the logs and float.") strategy = parameter_server_strategy_v2.ParameterServerStrategyV2( - make_cluster(3, 2)) + make_cluster(3, 2), + variable_partitioner=sharded_variable.FixedShardsPartitioner(2)) with strategy.scope(): model = sequential.Sequential([core_layers.Dense(10)]) + if with_normalization_layer: + norm = keras.layers.BatchNormalization( + axis=-1, input_shape=(4, 4, 3), momentum=0.8) + model.add(norm) + model.compile( gradient_descent.SGD(), loss="mse", @@ -275,8 +286,10 @@ class ModelFitTest(test.TestCase, parameterized.TestCase): validation_data=None, x=None, steps_per_epoch=10, - run_eagerly=False): - model, callbacks = self._model_compile(steps_per_execution, run_eagerly) + run_eagerly=False, + with_normalization_layer=False): + model, callbacks = self._model_compile(steps_per_execution, run_eagerly, + with_normalization_layer) def dataset_fn(input_context): del input_context @@ -300,6 +313,12 @@ class ModelFitTest(test.TestCase, parameterized.TestCase): def testModelFit(self): model = self._model_fit() self.assertEqual(model.optimizer.iterations, 100) + return model + + @combinations.generate(combinations.combine(mode=["eager"])) + def testModelFitWithNormalizationLayer(self): + model = self._model_fit(with_normalization_layer=True) + self.assertEqual(model.optimizer.iterations, 100) @combinations.generate(combinations.combine(mode=["eager"])) def testModelFitWithStepsPerExecution(self): diff --git a/tensorflow/python/keras/distribute/sidecar_evaluator.py b/tensorflow/python/keras/distribute/sidecar_evaluator.py index 83d2110fa47..99c2bafac8d 100644 --- a/tensorflow/python/keras/distribute/sidecar_evaluator.py +++ b/tensorflow/python/keras/distribute/sidecar_evaluator.py @@ -24,7 +24,6 @@ import re # pylint: disable=g-direct-tensorflow-import from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl -from tensorflow.python.ops import summary_ops_v2 from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import checkpoint_utils @@ -78,8 +77,8 @@ class SidecarEvaluator(object): data=data, checkpoint_dir='/tmp/checkpoint_dir', # dir for training-saved checkpoint steps=None, # Eval until dataset is exhausted - log_dir='/tmp/log_dir', - max_evaluations=None # The evaluation needs to be stopped manually + max_evaluations=None, # The evaluation needs to be stopped manually + callbacks=[tf.keras.callbacks.TensorBoard(log_dir='/tmp/log_dir')] ).start() ``` @@ -87,7 +86,7 @@ class SidecarEvaluator(object): files which can be visualized by tensorboard (which provides a webpage link): ```bash - $ tensorboard --logdir=/tmp/log_dir + $ tensorboard --logdir=/tmp/log_dir/validation ... TensorBoard 2.4.0a0 at http://host:port (Press CTRL+C to quit) ``` @@ -109,9 +108,9 @@ class SidecarEvaluator(object): model, data, checkpoint_dir, - log_dir=None, steps=None, - max_evaluations=None): + max_evaluations=None, + callbacks=None): """Initializes an `SidecarEvaluator` object. Args: @@ -123,7 +122,6 @@ class SidecarEvaluator(object): types that Keras `model.evaluate` supports as the input data `x`, such as a `tf.data.Dataset`. checkpoint_dir: Directory where checkpoint files are saved. - log_dir: Directory where summary files for TensorBoard are saved. steps: Number of steps to perform evaluation for, when evaluating a single checkpoint file. If `None`, evaluation continues until the dataset is exhausted. For repeated evaluation dataset, user must specify `steps` to @@ -142,21 +140,19 @@ class SidecarEvaluator(object): checkpoints may be skipped if evaluation is slower than checkpoint creation. If `None`, `SidecarEvaluator` will evaluate indefinitely, and the user must terminate evaluator program themselves. + callbacks: List of `keras.callbacks.Callback` instances to apply during + evaluation. See [callbacks](/api_docs/python/tf/keras/callbacks). """ self.model = model self.data = data self.checkpoint_dir = checkpoint_dir - if log_dir: - self._summary_writer = summary_ops_v2.create_file_writer_v2( - logdir=log_dir) - else: - self._summary_writer = None self._iterations = variables.Variable( name='iterations', initial_value=_ITERATIONS_UNINITIALIZED, dtype=dtypes.int64) self.max_evaluations = max_evaluations self.steps = steps + self.callbacks = callbacks or [] def start(self): """Starts the evaluation loop.""" @@ -185,6 +181,12 @@ class SidecarEvaluator(object): if re.match(r'^layer_with_weights-[\d+]', attribute) is not None: self.model.load_weights(latest_checkpoint) break + else: + # The model checkpoint might not include optimizer in cases, e.g. + # using a custom training loop. Directly assign the iterations + # property to be used in callbacks. + if self.model.optimizer: + self.model.optimizer.iterations.assign(self._iterations) except (errors_impl.OpError,) as e: # A couple errors can happen here with the coordinator racing to write # checkpoint: @@ -208,8 +210,7 @@ class SidecarEvaluator(object): 'Evaluation starts: Model weights loaded from latest ' 'checkpoint file: %s.', latest_checkpoint) - # TODO(rchao): Support arbitrary callback for extensibility. - self.model.evaluate(self.data, steps=self.steps) + self.model.evaluate(self.data, steps=self.steps, callbacks=self.callbacks) logging.info( 'End of evaluation. Metrics: %s', ' '.join([ @@ -218,14 +219,6 @@ class SidecarEvaluator(object): for metric in self.model.metrics ])) - if self._summary_writer: - with summary_ops_v2.record_if(True), self._summary_writer.as_default(): - for metric in self.model.metrics: - summary_ops_v2.scalar( - metric.name, - metric.result(), - step=self._iterations.read_value()) - # TODO(rchao): Make the max evaluation robust in case users save the # checkpoints with epoch format {epoch:03d}. if (self.max_evaluations and diff --git a/tensorflow/python/keras/distribute/sidecar_evaluator_test.py b/tensorflow/python/keras/distribute/sidecar_evaluator_test.py index da3faaebb3e..c4e16cc5b4f 100644 --- a/tensorflow/python/keras/distribute/sidecar_evaluator_test.py +++ b/tensorflow/python/keras/distribute/sidecar_evaluator_test.py @@ -59,13 +59,17 @@ class SidecarEvaluatorTest(test.TestCase): # Asserts the content of the summary file. event_pb_written = False event_tags = [] - for event_pb in summary_iterator.summary_iterator( - os.path.join(log_dir, summary_files[0])): - if event_pb.step > 0: - self.assertEqual(event_pb.step, 32) - event_tags.append(event_pb.summary.value[0].tag) - event_pb_written = True - self.assertCountEqual(event_tags, ['categorical_accuracy', 'loss']) + for summary_file in summary_files: + for event_pb in summary_iterator.summary_iterator( + os.path.join(log_dir, summary_file)): + if event_pb.step > 0: + self.assertEqual(event_pb.step, 32) + event_tags.append(event_pb.summary.value[0].tag) + event_pb_written = True + self.assertCountEqual(event_tags, [ + 'evaluation_categorical_accuracy_vs_iterations', + 'evaluation_loss_vs_iterations' + ]) # Verifying at least one non-zeroth step is written to summary. self.assertTrue(event_pb_written) @@ -88,7 +92,7 @@ class SidecarEvaluatorTest(test.TestCase): checkpoint_manager.save() sidecar_evaluator = sidecar_evaluator_lib.SidecarEvaluator( - model, data=None, checkpoint_dir=checkpoint_dir, log_dir=None) + model, data=None, checkpoint_dir=checkpoint_dir) with self.assertRaisesRegexp( RuntimeError, '`iterations` cannot be loaded ' 'from the checkpoint file.'): @@ -124,14 +128,14 @@ class SidecarEvaluatorTest(test.TestCase): eval_model, data=dataset, checkpoint_dir=checkpoint_dir, - log_dir=log_dir, - max_evaluations=1).start() + max_evaluations=1, + callbacks=[keras.callbacks.TensorBoard(log_dir=log_dir)]).start() # Eval model has been restored to the same state as the original model, so # their weights should match. If not, restoration of the model didn't # work. self.assertModelsSameVariables(model, eval_model) - self.assertSummaryEventsWritten(log_dir) + self.assertSummaryEventsWritten(os.path.join(log_dir, 'validation')) def testSidecarEvaluatorOutputsSummarySavedWithCallback(self): checkpoint_dir = os.path.join(self.get_temp_dir(), 'checkpoints') @@ -158,8 +162,8 @@ class SidecarEvaluatorTest(test.TestCase): eval_model, data=dataset, checkpoint_dir=checkpoint_dir, - log_dir=log_dir, - max_evaluations=1) + max_evaluations=1, + callbacks=[keras.callbacks.TensorBoard(log_dir=log_dir)]) sidecar_evaluator.start() # Eval model has been restored to the same state as the original model, so @@ -170,7 +174,7 @@ class SidecarEvaluatorTest(test.TestCase): # check the iterations is restored. self.assertEqual(sidecar_evaluator._iterations.numpy(), _BATCH_SIZE) - self.assertSummaryEventsWritten(log_dir) + self.assertSummaryEventsWritten(os.path.join(log_dir, 'validation')) if __name__ == '__main__': diff --git a/tensorflow/python/keras/engine/BUILD b/tensorflow/python/keras/engine/BUILD index 2a0579cb965..ade3cfa2b11 100644 --- a/tensorflow/python/keras/engine/BUILD +++ b/tensorflow/python/keras/engine/BUILD @@ -75,7 +75,6 @@ py_library( "//tensorflow/python/keras/mixed_precision:loss_scale_optimizer", "//tensorflow/python/keras/mixed_precision:policy", "//tensorflow/python/keras/saving", - "//tensorflow/python/keras/utils:control_flow_util", "//tensorflow/python/keras/utils:engine_utils", "//tensorflow/python/keras/utils:metrics_utils", "//tensorflow/python/keras/utils:mode_keys", @@ -186,7 +185,6 @@ py_library( "//tensorflow/python:framework_ops", "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/keras/utils:control_flow_util", "//tensorflow/python/keras/utils:dataset_creator", "//tensorflow/python/keras/utils:engine_utils", "//tensorflow/python/keras/utils:tf_utils", diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index f5474d165b7..f4ac96066c5 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -2978,12 +2978,15 @@ class Layer(module.Module, version_utils.LayerVersionSelector): self.__class__._call_accepts_kwargs.fget.cache.pop(self, None) call_fn_args = self._call_fn_args + call_fn_args += self._call_full_argspec.kwonlyargs or [] self._expects_training_arg = ('training' in call_fn_args or self._call_accepts_kwargs) # The default training arg will be any (non-None) default specified in the # method signature, or None if no value is specified. - self._default_training_arg = self._call_fn_arg_defaults.get( - 'training') + call_fn_arg_defaults = self._call_fn_arg_defaults.copy() + call_fn_arg_defaults.update(self._call_full_argspec.kwonlydefaults or {}) + self._default_training_arg = call_fn_arg_defaults.get('training') + self._expects_mask_arg = ('mask' in call_fn_args or self._call_accepts_kwargs) diff --git a/tensorflow/python/keras/engine/base_layer_test.py b/tensorflow/python/keras/engine/base_layer_test.py index 4da642121d2..d5fec0754ec 100644 --- a/tensorflow/python/keras/engine/base_layer_test.py +++ b/tensorflow/python/keras/engine/base_layer_test.py @@ -757,6 +757,88 @@ class BaseLayerTest(keras_parameterized.TestCase): else: return self._nested_layer(inputs) * 0.5 + self._test_custom_layer_training_arg( + CustomLayerNoTrainingArg=CustomLayerNoTrainingArg, + CustomLayerDefaultTrainingMissing=CustomLayerDefaultTrainingMissing, + CustomLayerDefaultTrainingNone=CustomLayerDefaultTrainingNone, + CustomLayerDefaultTrainingFalse=CustomLayerDefaultTrainingFalse, + CustomLayerDefaultTrainingTrue=CustomLayerDefaultTrainingTrue) + + @combinations.generate(combinations.combine(mode=['eager'])) + def test_custom_layer_training_arg_kwargonly(self): + class CustomLayerNoTrainingArg(base_layer.Layer): + + def __init__(self, nested_layer=None): + super(CustomLayerNoTrainingArg, self).__init__() + self._nested_layer = nested_layer or array_ops.identity + + def call(self, inputs): + return self._nested_layer(inputs) + + class CustomLayerDefaultTrainingMissing(base_layer.Layer): + + def __init__(self, nested_layer=None): + super(CustomLayerDefaultTrainingMissing, self).__init__() + self._nested_layer = nested_layer or array_ops.identity + + def call(self, inputs, *, training): + if training: + return self._nested_layer(inputs) + else: + return self._nested_layer(inputs) * 0.5 + + class CustomLayerDefaultTrainingNone(base_layer.Layer): + + def __init__(self, nested_layer=None): + super(CustomLayerDefaultTrainingNone, self).__init__() + self._nested_layer = nested_layer or array_ops.identity + + def call(self, inputs, *, training=None): + if training: + return self._nested_layer(inputs) + else: + return self._nested_layer(inputs) * 0.5 + + class CustomLayerDefaultTrainingFalse(base_layer.Layer): + + def __init__(self, nested_layer=None): + super(CustomLayerDefaultTrainingFalse, self).__init__() + self._nested_layer = nested_layer or array_ops.identity + + def call(self, inputs, *, training=False): + if training: + return self._nested_layer(inputs) + else: + return self._nested_layer(inputs) * 0.5 + + class CustomLayerDefaultTrainingTrue(base_layer.Layer): + + def __init__(self, nested_layer=None): + super(CustomLayerDefaultTrainingTrue, self).__init__() + self._nested_layer = nested_layer or array_ops.identity + + def call(self, inputs, *, training=True): + if training: + return self._nested_layer(inputs) + else: + return self._nested_layer(inputs) * 0.5 + + self._test_custom_layer_training_arg( + CustomLayerNoTrainingArg=CustomLayerNoTrainingArg, + CustomLayerDefaultTrainingMissing=CustomLayerDefaultTrainingMissing, + CustomLayerDefaultTrainingNone=CustomLayerDefaultTrainingNone, + CustomLayerDefaultTrainingFalse=CustomLayerDefaultTrainingFalse, + CustomLayerDefaultTrainingTrue=CustomLayerDefaultTrainingTrue) + + def _test_custom_layer_training_arg(self, + # pylint: disable=invalid-name + CustomLayerNoTrainingArg, + CustomLayerDefaultTrainingMissing, + CustomLayerDefaultTrainingNone, + CustomLayerDefaultTrainingFalse, + CustomLayerDefaultTrainingTrue, + # pylint: enable=invalid-name + ): x = array_ops.ones(shape=(1, 1)) # If the layer signature doesn't specify a default training arg, diff --git a/tensorflow/python/keras/engine/data_adapter.py b/tensorflow/python/keras/engine/data_adapter.py index f1625c475a8..0c175e939e3 100644 --- a/tensorflow/python/keras/engine/data_adapter.py +++ b/tensorflow/python/keras/engine/data_adapter.py @@ -39,11 +39,11 @@ from tensorflow.python.eager import monitoring from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import smart_cond from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.keras import backend from tensorflow.python.keras.engine import training_utils -from tensorflow.python.keras.utils import control_flow_util from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import dataset_creator from tensorflow.python.keras.utils import tf_utils @@ -559,11 +559,12 @@ class CompositeTensorDataAdapter(DataAdapter): flat_inputs += nest.flatten(y) def _is_composite(v): - # Dataset/iterator inherits from CompositeTensor but should be handled - # by DatasetAdapter and GeneratorAdapter. + # Dataset/iterator/DistributedDataset inherits from CompositeTensor but + # should be handled by DatasetAdapter and GeneratorAdapter. if (tf_utils.is_extension_type(v) and - not isinstance(v, (dataset_ops.DatasetV2, - iterator_ops.IteratorBase))): + not isinstance(v, + (dataset_ops.DatasetV2, iterator_ops.IteratorBase)) and + not _is_distributed_dataset(v)): return True # Support Scipy sparse tensors if scipy is installed if scipy_sparse is not None and scipy_sparse.issparse(v): @@ -1414,7 +1415,7 @@ def _make_class_weight_map_fn(class_weight): raise ValueError("`class_weight` not supported for " "3+ dimensional targets.") - y_classes = control_flow_util.smart_cond( + y_classes = smart_cond.smart_cond( y.shape.rank == 2 and backend.shape(y)[1] > 1, lambda: backend.argmax(y, axis=1), lambda: math_ops.cast(backend.reshape(y, (-1,)), dtypes.int64)) diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index bc3de6f9352..2d1d2e16ba5 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -1034,8 +1034,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): workers: Integer. Used for generator or `keras.utils.Sequence` input only. Maximum number of processes to spin up when using process-based threading. If unspecified, `workers` - will default to 1. If 0, will execute the generator on the main - thread. + will default to 1. use_multiprocessing: Boolean. Used for generator or `keras.utils.Sequence` input only. If `True`, use process-based threading. If unspecified, `use_multiprocessing` will default to @@ -1371,8 +1370,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): `max_queue_size` will default to 10. workers: Integer. Used for generator or `keras.utils.Sequence` input only. Maximum number of processes to spin up when using process-based - threading. If unspecified, `workers` will default to 1. If 0, will - execute the generator on the main thread. + threading. If unspecified, `workers` will default to 1. use_multiprocessing: Boolean. Used for generator or `keras.utils.Sequence` input only. If `True`, use process-based threading. If unspecified, `use_multiprocessing` will default to @@ -1613,7 +1611,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): workers: Integer. Used for generator or `keras.utils.Sequence` input only. Maximum number of processes to spin up when using process-based threading. If unspecified, `workers` will default - to 1. If 0, will execute the generator on the main thread. + to 1. use_multiprocessing: Boolean. Used for generator or `keras.utils.Sequence` input only. If `True`, use process-based threading. If unspecified, `use_multiprocessing` will default to diff --git a/tensorflow/python/keras/engine/training_utils_v1.py b/tensorflow/python/keras/engine/training_utils_v1.py index 743c0c7dbd3..e61d1deebc0 100644 --- a/tensorflow/python/keras/engine/training_utils_v1.py +++ b/tensorflow/python/keras/engine/training_utils_v1.py @@ -40,6 +40,7 @@ from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import smart_cond from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util @@ -47,7 +48,6 @@ from tensorflow.python.keras import backend as K from tensorflow.python.keras import callbacks as cbks from tensorflow.python.keras import losses from tensorflow.python.keras import metrics as metrics_module -from tensorflow.python.keras.utils import control_flow_util from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import losses_utils @@ -1037,7 +1037,7 @@ def standardize_weights(y, weight_vector[:] = np.nan weight_vector[keys] = values - y_classes = control_flow_util.smart_cond( + y_classes = smart_cond.smart_cond( len(y.shape.as_list()) == 2 and K.shape(y)[1] > 1, lambda: K.argmax(y, axis=1), lambda: math_ops.cast(K.reshape(y, (-1,)), dtypes.int64)) diff --git a/tensorflow/python/keras/feature_column/BUILD b/tensorflow/python/keras/feature_column/BUILD index a64f88b639a..dc382e5a9f9 100644 --- a/tensorflow/python/keras/feature_column/BUILD +++ b/tensorflow/python/keras/feature_column/BUILD @@ -66,6 +66,7 @@ py_library( ":dense_features", "//tensorflow/python:framework_ops", "//tensorflow/python/feature_column:feature_column_v2", + "//tensorflow/python/keras/utils:tf_contextlib", "//tensorflow/python/util:tf_export", ], ) diff --git a/tensorflow/python/keras/feature_column/base_feature_layer.py b/tensorflow/python/keras/feature_column/base_feature_layer.py index 1ca74846b56..3127ddedf65 100644 --- a/tensorflow/python/keras/feature_column/base_feature_layer.py +++ b/tensorflow/python/keras/feature_column/base_feature_layer.py @@ -21,8 +21,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import re from tensorflow.python.feature_column import feature_column_v2 +from tensorflow.python.framework import tensor_shape from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.utils import generic_utils from tensorflow.python.ops import array_ops @@ -75,8 +77,7 @@ class _BaseFeaturesLayer(Layer): with variable_scope.variable_scope( self.name, partitioner=self._partitioner): with variable_scope.variable_scope( - feature_column_v2._sanitize_column_name_for_variable_scope( # pylint: disable=protected-access - column.name)): + _sanitize_column_name_for_variable_scope(column.name)): column.create_state(self._state_manager) super(_BaseFeaturesLayer, self).build(None) @@ -114,8 +115,7 @@ class _BaseFeaturesLayer(Layer): def _verify_and_concat_tensors(self, output_tensors): """Verifies and concatenates the dense output of several columns.""" - feature_column_v2._verify_static_batch_size_equality( # pylint: disable=protected-access - output_tensors, self._feature_columns) + _verify_static_batch_size_equality(output_tensors, self._feature_columns) return array_ops.concat(output_tensors, -1) def get_config(self): @@ -142,3 +142,36 @@ class _BaseFeaturesLayer(Layer): config['partitioner'], custom_objects) return cls(**config_cp) + + +def _sanitize_column_name_for_variable_scope(name): + """Sanitizes user-provided feature names for use as variable scopes.""" + invalid_char = re.compile('[^A-Za-z0-9_.\\-]') + return invalid_char.sub('_', name) + + +def _verify_static_batch_size_equality(tensors, columns): + """Verify equality between static batch sizes. + + Args: + tensors: iterable of input tensors. + columns: Corresponding feature columns. + + Raises: + ValueError: in case of mismatched batch sizes. + """ + expected_batch_size = None + for i in range(0, len(tensors)): + # bath_size is a Dimension object. + batch_size = tensor_shape.Dimension(tensor_shape.dimension_value( + tensors[i].shape[0])) + if batch_size.value is not None: + if expected_batch_size is None: + bath_size_column_index = i + expected_batch_size = batch_size + elif not expected_batch_size.is_compatible_with(batch_size): + raise ValueError( + 'Batch size (first dimension) of each feature must be same. ' + 'Batch size of columns ({}, {}): ({}, {})'.format( + columns[bath_size_column_index].name, columns[i].name, + expected_batch_size, batch_size)) diff --git a/tensorflow/python/keras/feature_column/dense_features_v2.py b/tensorflow/python/keras/feature_column/dense_features_v2.py index ae1294c6fca..5f541206f14 100644 --- a/tensorflow/python/keras/feature_column/dense_features_v2.py +++ b/tensorflow/python/keras/feature_column/dense_features_v2.py @@ -22,6 +22,8 @@ from tensorflow.python.feature_column import feature_column_v2 as fc from tensorflow.python.framework import ops from tensorflow.python.keras.feature_column import base_feature_layer as kfc from tensorflow.python.keras.feature_column import dense_features +from tensorflow.python.keras.utils import tf_contextlib +from tensorflow.python.training.tracking import base as trackable from tensorflow.python.util.tf_export import keras_export @@ -85,7 +87,7 @@ class DenseFeatures(dense_features.DenseFeatures): trainable=trainable, name=name, **kwargs) - self._state_manager = fc._StateManagerImplV2(self, self.trainable) # pylint: disable=protected-access + self._state_manager = _StateManagerImplV2(self, self.trainable) def build(self, _): for column in self._feature_columns: @@ -94,3 +96,65 @@ class DenseFeatures(dense_features.DenseFeatures): # We would like to call Layer.build and not _DenseFeaturesHelper.build. # pylint: disable=protected-access super(kfc._BaseFeaturesLayer, self).build(None) # pylint: disable=bad-super-call + + +class _StateManagerImplV2(fc._StateManagerImpl): # pylint: disable=protected-access + """Manages the state of DenseFeatures.""" + + def create_variable(self, + feature_column, + name, + shape, + dtype=None, + trainable=True, + use_resource=True, + initializer=None): + if name in self._cols_to_vars_map[feature_column]: + raise ValueError('Variable already exists.') + + # We explicitly track these variables since `name` is not guaranteed to be + # unique and disable manual tracking that the add_weight call does. + with no_manual_dependency_tracking_scope(self._layer): + var = self._layer.add_weight( + name=name, + shape=shape, + dtype=dtype, + initializer=initializer, + trainable=self._trainable and trainable, + use_resource=use_resource) + if isinstance(var, trackable.Trackable): + self._layer._track_trackable(var, feature_column.name + '/' + name) # pylint: disable=protected-access + self._cols_to_vars_map[feature_column][name] = var + return var + + +@tf_contextlib.contextmanager +def no_manual_dependency_tracking_scope(obj): + """A context that disables manual dependency tracking for the given `obj`. + + Sometimes library methods might track objects on their own and we might want + to disable that and do the tracking on our own. One can then use this context + manager to disable the tracking the library method does and do your own + tracking. + + For example: + + class TestLayer(tf.keras.Layer): + def build(): + with no_manual_dependency_tracking_scope(self): + var = self.add_variable("name1") # Creates a var and doesn't track it + self._track_trackable("name2", var) # We track variable with name `name2` + + Args: + obj: A trackable object. + + Yields: + a scope in which the object doesn't track dependencies manually. + """ + # pylint: disable=protected-access + previous_value = getattr(obj, '_manual_tracking', True) + obj._manual_tracking = False + try: + yield + finally: + obj._manual_tracking = previous_value diff --git a/tensorflow/python/keras/feature_column/sequence_feature_column.py b/tensorflow/python/keras/feature_column/sequence_feature_column.py index ff68d4ccf7a..38882cfe42b 100644 --- a/tensorflow/python/keras/feature_column/sequence_feature_column.py +++ b/tensorflow/python/keras/feature_column/sequence_feature_column.py @@ -163,8 +163,8 @@ class SequenceFeatures(kfc._BaseFeaturesLayer): sequence_lengths.append(sequence_length) # Check and process sequence lengths. - fc._verify_static_batch_size_equality(sequence_lengths, - self._feature_columns) + kfc._verify_static_batch_size_equality( # pylint: disable=protected-access + sequence_lengths, self._feature_columns) sequence_length = _assert_all_equal_and_return(sequence_lengths) return self._verify_and_concat_tensors(output_tensors), sequence_length diff --git a/tensorflow/python/keras/layers/convolutional.py b/tensorflow/python/keras/layers/convolutional.py index bcfb8b02a6c..4d43390bcfe 100644 --- a/tensorflow/python/keras/layers/convolutional.py +++ b/tensorflow/python/keras/layers/convolutional.py @@ -534,7 +534,8 @@ class Conv2D(Conv): provide the keyword argument `input_shape` (tuple of integers or `None`, does not include the sample axis), e.g. `input_shape=(128, 128, 3)` for 128x128 RGB pictures - in `data_format="channels_last"`. + in `data_format="channels_last"`. You can use `None` when + a dimension has variable size. Examples: diff --git a/tensorflow/python/keras/layers/multi_head_attention.py b/tensorflow/python/keras/layers/multi_head_attention.py index d57ce570eb1..82182bdf12a 100644 --- a/tensorflow/python/keras/layers/multi_head_attention.py +++ b/tensorflow/python/keras/layers/multi_head_attention.py @@ -377,21 +377,35 @@ class MultiHeadAttention(Layer): # These computations could be wrapped into the keras attention layer once # it support mult-head einsum computations. self._build_attention(output_rank) - if self._output_shape: - if not isinstance(self._output_shape, collections.abc.Sized): - output_shape = [self._output_shape] - else: - output_shape = self._output_shape + self._output_dense = self._make_output_dense( + free_dims, common_kwargs, "attention_output") + + def _make_output_dense(self, free_dims, common_kwargs, name=None): + """Builds the output projection matrix. + + Args: + free_dims: Number of free dimensions for einsum equation building. + common_kwargs: Common keyword arguments for einsum layer. + name: the name for the projection layer. + + Returns: + Projection layer. + """ + if self._output_shape: + if not isinstance(self._output_shape, collections.abc.Sized): + output_shape = [self._output_shape] else: - output_shape = [self._query_shape[-1]] - einsum_equation, bias_axes, output_rank = _build_proj_equation( - free_dims, bound_dims=2, output_dims=len(output_shape)) - self._output_dense = einsum_dense.EinsumDense( - einsum_equation, - output_shape=_get_output_shape(output_rank - 1, output_shape), - bias_axes=bias_axes if self._use_bias else None, - name="attention_output", - **common_kwargs) + output_shape = self._output_shape + else: + output_shape = [self._query_shape[-1]] + einsum_equation, bias_axes, output_rank = _build_proj_equation( + free_dims, bound_dims=2, output_dims=len(output_shape)) + return einsum_dense.EinsumDense( + einsum_equation, + output_shape=_get_output_shape(output_rank - 1, output_shape), + bias_axes=bias_axes if self._use_bias else None, + name=name, + **common_kwargs) def _build_attention(self, rank): """Builds multi-head dot-product attention computations. diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py index 6ba0e50725e..9d9bab21eaa 100644 --- a/tensorflow/python/keras/layers/normalization.py +++ b/tensorflow/python/keras/layers/normalization.py @@ -518,22 +518,33 @@ class BatchNormalizationBase(Layer): self.built = True def _assign_moving_average(self, variable, value, momentum, inputs_size): + + def calculate_update_delta(): + decay = ops.convert_to_tensor_v2_with_dispatch( + 1.0 - momentum, name='decay') + if decay.dtype != variable.dtype.base_dtype: + decay = math_ops.cast(decay, variable.dtype.base_dtype) + update_delta = (variable - math_ops.cast(value, variable.dtype)) * decay + if inputs_size is not None: + update_delta = array_ops.where(inputs_size > 0, update_delta, + K.zeros_like(update_delta)) + return update_delta + with K.name_scope('AssignMovingAvg') as scope: - with ops.colocate_with(variable): - decay = ops.convert_to_tensor_v2_with_dispatch( - 1.0 - momentum, name='decay') - if decay.dtype != variable.dtype.base_dtype: - decay = math_ops.cast(decay, variable.dtype.base_dtype) - update_delta = (variable - math_ops.cast(value, variable.dtype)) * decay - if inputs_size is not None: - update_delta = array_ops.where(inputs_size > 0, update_delta, - K.zeros_like(update_delta)) - return state_ops.assign_sub(variable, update_delta, name=scope) + if ops.executing_eagerly_outside_functions(): + return variable.assign_sub(calculate_update_delta(), name=scope) + else: + with ops._colocate_with(variable): # pylint: disable=protected-access + return state_ops.assign_sub( + variable, calculate_update_delta(), name=scope) def _assign_new_value(self, variable, value): with K.name_scope('AssignNewValue') as scope: - with ops.colocate_with(variable): - return state_ops.assign(variable, value, name=scope) + if ops.executing_eagerly_outside_functions(): + return variable.assign(value, name=scope) + else: + with ops._colocate_with(variable): # pylint: disable=protected-access + return state_ops.assign(variable, value, name=scope) def _fused_batch_norm(self, inputs, training): """Returns the output of fused batch norm.""" diff --git a/tensorflow/python/keras/losses.py b/tensorflow/python/keras/losses.py index 645ed5bc8d9..c2751b57a5e 100644 --- a/tensorflow/python/keras/losses.py +++ b/tensorflow/python/keras/losses.py @@ -28,10 +28,10 @@ from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops +from tensorflow.python.framework import smart_cond from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util from tensorflow.python.keras import backend as K -from tensorflow.python.keras.utils import control_flow_util from tensorflow.python.keras.utils import losses_utils from tensorflow.python.keras.utils import tf_utils from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object @@ -1424,8 +1424,8 @@ def _maybe_convert_labels(y_true): # Convert the binary labels to -1 or 1. return 2. * y_true - 1. - updated_y_true = control_flow_util.smart_cond( - is_binary, _convert_binary_labels, lambda: y_true) + updated_y_true = smart_cond.smart_cond(is_binary, _convert_binary_labels, + lambda: y_true) return updated_y_true @@ -1642,8 +1642,8 @@ def categorical_crossentropy(y_true, num_classes = math_ops.cast(array_ops.shape(y_true)[-1], y_pred.dtype) return y_true * (1.0 - label_smoothing) + (label_smoothing / num_classes) - y_true = control_flow_util.smart_cond( - label_smoothing, _smooth_labels, lambda: y_true) + y_true = smart_cond.smart_cond(label_smoothing, _smooth_labels, + lambda: y_true) return K.categorical_crossentropy(y_true, y_pred, from_logits=from_logits) @@ -1739,8 +1739,8 @@ def binary_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0): def _smooth_labels(): return y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing - y_true = control_flow_util.smart_cond( - label_smoothing, _smooth_labels, lambda: y_true) + y_true = smart_cond.smart_cond(label_smoothing, _smooth_labels, + lambda: y_true) return K.mean( K.binary_crossentropy(y_true, y_pred, from_logits=from_logits), axis=-1) diff --git a/tensorflow/python/keras/mixed_precision/BUILD b/tensorflow/python/keras/mixed_precision/BUILD index 585c3e19a6f..baf126e54a4 100644 --- a/tensorflow/python/keras/mixed_precision/BUILD +++ b/tensorflow/python/keras/mixed_precision/BUILD @@ -186,7 +186,6 @@ py_library( "//tensorflow/python/distribute:one_device_strategy", "//tensorflow/python/distribute:tpu_strategy", "//tensorflow/python/keras/optimizer_v2", - "//tensorflow/python/keras/utils:control_flow_util", "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/python/keras/mixed_precision/loss_scale_optimizer.py b/tensorflow/python/keras/mixed_precision/loss_scale_optimizer.py index 7cb1473f2c8..2c6871bbd49 100644 --- a/tensorflow/python/keras/mixed_precision/loss_scale_optimizer.py +++ b/tensorflow/python/keras/mixed_precision/loss_scale_optimizer.py @@ -27,11 +27,11 @@ from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import smart_cond from tensorflow.python.keras import backend from tensorflow.python.keras import optimizers from tensorflow.python.keras.mixed_precision import loss_scale as keras_loss_scale_module from tensorflow.python.keras.optimizer_v2 import optimizer_v2 -from tensorflow.python.keras.utils import control_flow_util from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope @@ -742,8 +742,8 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): # DistributionStrategy does not support having a cond in a replica context # with a branch that calls `merge_call`, and self._optimizer.apply_gradients # calls `merge_call`. - maybe_apply_op = control_flow_util.smart_cond( - should_apply_grads, apply_fn, do_not_apply_fn) + maybe_apply_op = smart_cond.smart_cond(should_apply_grads, apply_fn, + do_not_apply_fn) return control_flow_ops.group(maybe_apply_op, loss_scale_update_op) def _apply_gradients(self, grads, wrapped_vars, name, diff --git a/tensorflow/python/keras/saving/BUILD b/tensorflow/python/keras/saving/BUILD index d0dca28df41..9b7606e7db0 100644 --- a/tensorflow/python/keras/saving/BUILD +++ b/tensorflow/python/keras/saving/BUILD @@ -172,7 +172,6 @@ tf_py_test( tags = [ "no_rocm", "no_windows", - "notap", # TODO(b/161198218): flaky timeout ], deps = [ "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/keras/saving/save_test.py b/tensorflow/python/keras/saving/save_test.py index 69489e99061..7d8750c3e4b 100644 --- a/tensorflow/python/keras/saving/save_test.py +++ b/tensorflow/python/keras/saving/save_test.py @@ -1012,27 +1012,28 @@ class TestWholeModelSaving(keras_parameterized.TestCase): e.g. "head_0_accuracy" should not become "head_0_head_0_accuracy" after saving and loading a model. """ - input_ = keras.Input((4,)) - model = keras.Model( - input_, - [keras.layers.Softmax(name='head_0')(keras.layers.Dense(3)(input_)), - keras.layers.Softmax(name='head_1')(keras.layers.Dense(5)(input_))]) - metric = keras.metrics.BinaryAccuracy() - model.compile(optimizer='rmsprop', - loss='mse', - metrics={'head_0': [metric, 'accuracy']}) + with self.cached_session(): + input_ = keras.Input((4,)) + model = keras.Model( + input_, + [keras.layers.Softmax(name='head_0')(keras.layers.Dense(3)(input_)), + keras.layers.Softmax(name='head_1')(keras.layers.Dense(5)(input_))]) + metric = keras.metrics.BinaryAccuracy() + model.compile(optimizer='rmsprop', + loss='mse', + metrics={'head_0': [metric, 'accuracy']}) - # Run one iteration. - x = np.random.rand(2, 4) - y = {'head_0': np.random.randint(2, size=(2, 3)), - 'head_1': np.random.randint(2, size=(2, 5))} - model.fit(x, y, verbose=0) + # Run one iteration. + x = np.random.rand(2, 4) + y = {'head_0': np.random.randint(2, size=(2, 3)), + 'head_1': np.random.randint(2, size=(2, 5))} + model.fit(x, y, verbose=0) - # Save and reload. - save_format = testing_utils.get_save_format() - saved_model_dir = self._save_model_dir() - keras.models.save_model(model, saved_model_dir, save_format=save_format) - loaded = keras.models.load_model(saved_model_dir) + # Save and reload. + save_format = testing_utils.get_save_format() + saved_model_dir = self._save_model_dir() + keras.models.save_model(model, saved_model_dir, save_format=save_format) + loaded = keras.models.load_model(saved_model_dir) # Make sure the metrics names from the model before saving match the loaded # model. diff --git a/tensorflow/python/keras/saving/saved_model/layer_serialization.py b/tensorflow/python/keras/saving/saved_model/layer_serialization.py index 3f59a8ee726..bc82ca5acea 100644 --- a/tensorflow/python/keras/saving/saved_model/layer_serialization.py +++ b/tensorflow/python/keras/saving/saved_model/layer_serialization.py @@ -155,11 +155,13 @@ class RNNSavedModelSaver(LayerSavedModelSaver): super(RNNSavedModelSaver, self)._get_serialized_attributes_internal( serialization_cache)) states = data_structures.wrap_or_unwrap(self.obj.states) - # Force the tuple into TupleWrapper which is a trackable object. The - # save/load code requires all the objects to be trackable. - # Tuple is not converted to TupleWrapper by data_structures.wrap_or_unwrap() - # if it doesn't contains any trackable objects. + # SaveModel require all the objects to be Trackable when saving. + # If the states is still a tuple after wrap_or_unwrap, it means it doesn't + # contain any trackable item within it, eg empty tuple or (None, None) for + # stateless ConvLSTM2D. We convert them to list so that wrap_or_unwrap can + # make it a Trackable again for saving. When loaded, ConvLSTM2D is + # able to handle the tuple/list conversion. if isinstance(states, tuple): - states = data_structures._TupleWrapper(states) # pylint: disable=protected-access + states = data_structures.wrap_or_unwrap(list(states)) objects['states'] = states return objects, functions diff --git a/tensorflow/python/keras/saving/saved_model/load.py b/tensorflow/python/keras/saving/saved_model/load.py index 4f9f9a812e1..59818f9fb23 100644 --- a/tensorflow/python/keras/saving/saved_model/load.py +++ b/tensorflow/python/keras/saving/saved_model/load.py @@ -696,7 +696,7 @@ class KerasObjectLoader(object): model.__init__(inputs, outputs, name=config['name']) functional_lib.connect_ancillary_layers(model, created_layers) - # Set model dtype and trainable status. + # Set model dtype. _set_network_attributes_from_metadata(model) # Unblock models that are dependent on this model. @@ -1161,7 +1161,7 @@ def _set_network_attributes_from_metadata(revived_obj): metadata = revived_obj._serialized_attributes['metadata'] if metadata.get('dtype') is not None: revived_obj._set_dtype_policy(metadata['dtype']) - revived_obj.trainable = metadata['trainable'] + revived_obj._trainable = metadata['trainable'] # pylint:enable=protected-access diff --git a/tensorflow/python/keras/saving/saved_model/saved_model_test.py b/tensorflow/python/keras/saving/saved_model/saved_model_test.py index 3229dfbd1d5..9d2b26c2916 100644 --- a/tensorflow/python/keras/saving/saved_model/saved_model_test.py +++ b/tensorflow/python/keras/saving/saved_model/saved_model_test.py @@ -124,7 +124,7 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase): self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True) return os.path.join(temp_dir, dirname) - def _test_save_and_load(self, use_dataset=False): + def _get_model(self): model = testing_utils.get_small_mlp(1, 4, input_dim=3) model.layers[-1].activity_regularizer = regularizers.get('l2') model.activity_regularizer = regularizers.get('l2') @@ -134,7 +134,9 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase): def callable_loss(): return math_ops.reduce_sum(model.weights[0]) model.add_loss(callable_loss) + return model + def _train_model(self, model, use_dataset=False): x = np.random.random((1, 3)) y = np.random.random((1, 4)) @@ -150,9 +152,14 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase): else: model.train_on_batch(x, y) + def _save_and_load(self, model): saved_model_dir = self._save_model_dir() tf_save.save(model, saved_model_dir) loaded = keras_load.load(saved_model_dir) + return loaded + + def _test_evaluation(self, model, loaded): + # Assert that original and loaded models have the same results when called. self.evaluate(variables.variables_initializer(loaded.variables)) self.assertAllClose(self.evaluate(model.weights), self.evaluate(loaded.weights)) @@ -175,13 +182,20 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase): @keras_parameterized.run_with_all_model_types def test_model_save_and_load(self): - self._test_save_and_load(use_dataset=True) + model = self._get_model() + self._train_model(model, use_dataset=False) + loaded = self._save_and_load(model) + self._test_evaluation(model, loaded) @keras_parameterized.run_with_all_model_types def test_model_save_and_load_dataset(self): - self._test_save_and_load(use_dataset=True) + model = self._get_model() + self._train_model(model, use_dataset=True) + loaded = self._save_and_load(model) + self._test_evaluation(model, loaded) def test_trainable_weights(self): + """Tests that trainable status of individual weights is preserved.""" layer = keras.layers.Dense(4, name='custom_layer') layer.build([3,]) layer.add_weight( @@ -208,6 +222,31 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase): self.assertAllClose(self.evaluate(getattr(layer, attr)), self.evaluate(getattr(loaded, attr))) + @keras_parameterized.run_with_all_model_types + def test_trainable_layers(self): + """Tests that trainable status of individual layers is preserved.""" + model = model = self._get_model() + # Set the last layer to *not* be trainable. + model.layers[-1].trainable = False + self._train_model(model, use_dataset=True) + loaded = self._save_and_load(model) + + self._test_evaluation(model, loaded) + self.assertFalse(model.layers[-1].trainable) + self.assertFalse(loaded.layers[-1].trainable) + + def test_trainable_custom_model_false(self): + """Tests that overall False trainable status of Model is preserved.""" + # Set all layers to *not* be trainable. + model = testing_utils.SmallSubclassMLP(1, 4, trainable=False) + model.compile(loss='mse', optimizer='rmsprop') + self._train_model(model, use_dataset=False) + loaded = self._save_and_load(model) + + self._test_evaluation(model, loaded) + self.assertEmpty(model.trainable_variables) + self.assertEmpty(loaded.trainable_variables) + def test_maintains_losses(self): """Tests that the layer losses do not change before and after export.""" model = keras.models.Sequential([LayerWithLoss()]) @@ -781,13 +820,15 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase): self.assertAllClose(layer.states, loaded_layer.states) self.assertAllClose(model(input_arr), loaded(input_arr)) - def testSaveStatelessConvLSTM2D(self): + @parameterized.named_parameters([('stateful', True), ('stateless', False)]) + def testSaveConvLSTM2D(self, stateful): data_format = 'channels_first' batch, timesteps, channels, rows, cols = 12, 10, 8, 4, 4 input_arr = np.ones( (batch, timesteps, channels, rows, cols)).astype('float32') layer = keras.layers.ConvLSTM2D( - filters=16, kernel_size=(1, 1), data_format=data_format) + filters=16, kernel_size=(1, 1), data_format=data_format, + stateful=stateful) x = keras.Input(batch_shape=(batch, timesteps, channels, rows, cols)) y = layer(x) model = keras.Model(x, y) @@ -798,6 +839,8 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase): del model loaded = keras_load.load(saved_model_dir) + if stateful: + loaded.reset_states() predict_2 = loaded(input_arr) self.assertAllClose(predict_1, predict_2) @@ -1200,11 +1243,15 @@ class MetricTest(test.TestCase, parameterized.TestCase): class CustomMetric(keras.metrics.MeanSquaredError): pass - model = testing_utils.get_small_mlp(1, 4, input_dim=3) - model.compile(loss='mse', optimizer='rmsprop', metrics=[CustomMetric()]) + with self.cached_session(): + metric = CustomMetric() + model = testing_utils.get_small_mlp(1, 4, input_dim=3) + model.compile(loss='mse', optimizer='rmsprop', metrics=[metric]) + self.evaluate(variables.global_variables_initializer()) + self.evaluate([v.initializer for v in metric.variables]) - saved_model_dir = self._save_model_dir() - tf_save.save(model, saved_model_dir) + saved_model_dir = self._save_model_dir() + tf_save.save(model, saved_model_dir) with self.assertRaisesRegex(ValueError, 'custom_objects'): keras_load.load(saved_model_dir) diff --git a/tensorflow/python/keras/testing_utils.py b/tensorflow/python/keras/testing_utils.py index bf3de9a7f37..785beeabb49 100644 --- a/tensorflow/python/keras/testing_utils.py +++ b/tensorflow/python/keras/testing_utils.py @@ -469,8 +469,13 @@ def get_small_functional_mlp(num_hidden, num_classes, input_dim): class SmallSubclassMLP(models.Model): """A subclass model based small MLP.""" - def __init__(self, num_hidden, num_classes, use_bn=False, use_dp=False): - super(SmallSubclassMLP, self).__init__(name='test_model') + def __init__(self, + num_hidden, + num_classes, + use_bn=False, + use_dp=False, + **kwargs): + super(SmallSubclassMLP, self).__init__(name='test_model', **kwargs) self.use_bn = use_bn self.use_dp = use_dp diff --git a/tensorflow/python/keras/utils/BUILD b/tensorflow/python/keras/utils/BUILD index fc28e0e758e..73a25b37d4b 100644 --- a/tensorflow/python/keras/utils/BUILD +++ b/tensorflow/python/keras/utils/BUILD @@ -52,12 +52,7 @@ py_library( name = "control_flow_util", srcs = ["control_flow_util.py"], srcs_version = "PY3", - deps = [ - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:tensor_util", - "//tensorflow/python:variables", - ], + deps = [], ) py_library( @@ -111,7 +106,9 @@ py_library( deps = [ ":object_identity", "//tensorflow/python:composite_tensor", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_ops", + "//tensorflow/python:smart_cond", "//tensorflow/python:tensor_shape", "//tensorflow/python:tensor_util", "//tensorflow/python:util", diff --git a/tensorflow/python/keras/utils/control_flow_util.py b/tensorflow/python/keras/utils/control_flow_util.py index 53d876b21f9..9a52149ca02 100644 --- a/tensorflow/python/keras/utils/control_flow_util.py +++ b/tensorflow/python/keras/utils/control_flow_util.py @@ -23,6 +23,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import ops +from tensorflow.python.framework import smart_cond as smart_module from tensorflow.python.framework import tensor_util from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import variables @@ -110,21 +111,8 @@ def smart_cond(pred, true_fn=None, false_fn=None, name=None): # pylint: disable if isinstance(pred, variables.Variable): return control_flow_ops.cond( pred, true_fn=true_fn, false_fn=false_fn, name=name) - - if not callable(true_fn): - raise TypeError("`true_fn` must be callable.") - if not callable(false_fn): - raise TypeError("`false_fn` must be callable.") - - pred_value = constant_value(pred) - if pred_value is not None: - if pred_value: - return true_fn() - else: - return false_fn() - else: - return control_flow_ops.cond(pred, true_fn=true_fn, false_fn=false_fn, - name=name) + return smart_module.smart_cond( + pred, true_fn=true_fn, false_fn=false_fn, name=name) def constant_value(pred): # pylint: disable=invalid-name diff --git a/tensorflow/python/keras/utils/data_utils.py b/tensorflow/python/keras/utils/data_utils.py index 22b15860d63..f258b6cdeac 100644 --- a/tensorflow/python/keras/utils/data_utils.py +++ b/tensorflow/python/keras/utils/data_utils.py @@ -695,8 +695,6 @@ class SequenceEnqueuer(object): class OrderedEnqueuer(SequenceEnqueuer): """Builds a Enqueuer from a Sequence. - Used in `fit_generator`, `evaluate_generator`, `predict_generator`. - Args: sequence: A `tf.keras.utils.data_utils.Sequence` object. use_multiprocessing: use multiprocessing if True, otherwise threading @@ -833,20 +831,17 @@ class GeneratorEnqueuer(SequenceEnqueuer): The provided generator can be finite in which case the class will throw a `StopIteration` exception. - Used in `fit_generator`, `evaluate_generator`, `predict_generator`. - Args: generator: a generator function which yields data use_multiprocessing: use multiprocessing if True, otherwise threading - wait_time: time to sleep in-between calls to `put()` random_seed: Initial seed for workers, will be incremented by one for each worker. """ - def __init__(self, sequence, + def __init__(self, generator, use_multiprocessing=False, random_seed=None): - super(GeneratorEnqueuer, self).__init__(sequence, use_multiprocessing) + super(GeneratorEnqueuer, self).__init__(generator, use_multiprocessing) self.random_seed = random_seed def _get_executor_init(self, workers): diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 0634f485a12..40abeb9d2a9 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -3021,22 +3021,6 @@ cuda_py_test( ], ) -tf_py_test( - name = "neon_depthwise_conv_op_test", - size = "medium", - srcs = ["neon_depthwise_conv_op_test.py"], - tags = ["no_windows"], - deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:nn", - "//tensorflow/python:nn_grad", - "//tensorflow/python:nn_ops", - "//third_party/py/numpy", - ], -) - cuda_py_test( name = "division_future_test", size = "medium", diff --git a/tensorflow/python/kernel_tests/embedding_ops_test.py b/tensorflow/python/kernel_tests/embedding_ops_test.py index 917d7ae3f0e..f277575d5a1 100644 --- a/tensorflow/python/kernel_tests/embedding_ops_test.py +++ b/tensorflow/python/kernel_tests/embedding_ops_test.py @@ -288,7 +288,7 @@ class EmbeddingLookupTest(test.TestCase): norms = math_ops.sqrt( math_ops.reduce_sum(embeddings * embeddings, axis=1)) normalized = embeddings / array_ops.stack([norms, norms], axis=1) - self.assertAllEqual(embedding, 2 * self.evaluate(normalized)) + self.assertAllClose(embedding, 2 * self.evaluate(normalized)) @test_util.run_deprecated_v1 def testSimpleShardedPartitionedVariable(self): diff --git a/tensorflow/python/kernel_tests/neon_depthwise_conv_op_test.py b/tensorflow/python/kernel_tests/neon_depthwise_conv_op_test.py deleted file mode 100644 index e5ae9574e38..00000000000 --- a/tensorflow/python/kernel_tests/neon_depthwise_conv_op_test.py +++ /dev/null @@ -1,291 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Functional tests for neon kernel for depthwise convolutional operations.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import test_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import nn_impl -from tensorflow.python.ops import nn_ops -import tensorflow.python.ops.nn_grad # pylint: disable=unused-import -from tensorflow.python.platform import test - - -def ConfigsToTest(): - """Iterator for different convolution shapes, strides and paddings. - - Yields: - Tuple (input_size, filter_size, out_size, stride, padding), the depthwise - convolution parameters. - """ - input_sizes = [[4, 5, 5, 48], [4, 8, 8, 84], [4, 17, 17, 48], [4, 35, 35, 2], - [4, 147, 147, 2], [3, 299, 299, 3], [5, 183, 183, 1]] - filter_sizes = [[1, 1, 48, 2], [1, 3, 84, 1], [3, 1, 48, 4], [5, 5, 2, 1], - [3, 3, 2, 8], [2, 2, 3, 8], [5, 5, 1, 2]] - out_sizes = [[4, 5, 5, 96], [4, 8, 8, 84], [4, 17, 17, 192], [4, 35, 35, 2], - [4, 49, 49, 16], [3, 150, 150, 24], [5, 92, 92, 2]] - strides = [1, 1, 1, 1, 3, 2, 2] - # pylint: disable=invalid-name - VALID = "VALID" - SAME = "SAME" - # pylint: enable=invalid-name - paddings = [SAME, SAME, SAME, SAME, VALID, SAME, SAME, SAME] - for i, f, o, s, p in zip(input_sizes, filter_sizes, out_sizes, strides, - paddings): - yield i, f, o, s, p - - -def CheckGradConfigsToTest(): - """Iterator for different convolution shapes, strides and paddings. - - compute_gradient_error() is very expensive. So the configs should be - relatively small. - - Yields: - Tuple (input_size, filter_size, out_size, stride, padding), the depthwise - convolution parameters. - """ - input_sizes = [[2, 5, 8, 1], [4, 5, 5, 1], [2, 4, 4, 2], [1, 15, 15, 2], - [2, 15, 16, 1]] - filter_sizes = [[4, 4, 1, 2], [2, 2, 1, 2], [3, 1, 2, 2], [1, 3, 2, 1], - [3, 3, 1, 2]] - out_sizes = [[2, 5, 8, 2], [4, 2, 2, 2], [2, 4, 4, 4], [1, 15, 15, 2], - [2, 5, 5, 2]] - strides = [1, 2, 1, 1, 3] - # pylint: disable=invalid-name - VALID = "VALID" - SAME = "SAME" - # pylint: enable=invalid-name - paddings = [SAME, VALID, SAME, SAME, VALID] - for i, f, o, s, p in zip(input_sizes, filter_sizes, out_sizes, strides, - paddings): - yield i, f, o, s, p - - -class DepthwiseConv2DTest(test.TestCase): - - # This is testing that depthwise_conv2d and depthwise_conv2d_native - # produce the same results. It also tests that NCHW and NHWC - # formats agree, by comparing the depthwise_conv2d_native with - # 'NCHW' format (with transposition) matches the 'NHWC' format using - # the higher level interface. - def _VerifyValues(self, - tensor_in_sizes, - filter_in_sizes, - stride, - padding, - use_gpu, - data_format="NHWC"): - """Verifies the output values of the convolution function. - - Args: - tensor_in_sizes: Input tensor dimensions in - [batch, input_rows, input_cols, input_depth]. - filter_in_sizes: Filter tensor dimensions in - [filter_rows, filter_cols, input_depth, depth_multiplier]. - stride: Stride. - padding: Padding type. - use_gpu: Whether to use GPU. - data_format: The data_format of the input. "NHWC" or "NCHW". - """ - total_size_1 = 1 - total_size_2 = 1 - for s in tensor_in_sizes: - total_size_1 *= s - for s in filter_in_sizes: - total_size_2 *= s - # Initializes the input and filter tensor with numbers incrementing from 1. - x1 = [f * 1.0 for f in range(1, total_size_1 + 1)] - x2 = [f * 1.0 for f in range(1, total_size_2 + 1)] - with self.cached_session(use_gpu=use_gpu) as sess: - with sess.graph._kernel_label_map({"DepthwiseConv2dNative": "neon"}): - t1 = constant_op.constant(x1, shape=tensor_in_sizes) - t1.set_shape(tensor_in_sizes) - t2 = constant_op.constant(x2, shape=filter_in_sizes) - - native_t1 = t1 - strides = [1, stride, stride, 1] - if data_format == "NCHW": - # Transpose from NHWC input to NCHW - # Ex. [4, 5, 5, 48] to [4, 48, 5, 5] - native_t1 = array_ops.transpose(t1, [0, 3, 1, 2]) - strides = [1, 1, stride, stride] - - conv_native = nn_ops.depthwise_conv2d_native( - native_t1, - t2, - strides=strides, - data_format=data_format, - padding=padding) - - if data_format == "NCHW": - # Transpose back from NCHW to NHWC - conv_native = array_ops.transpose(conv_native, [0, 2, 3, 1]) - - conv_interface = nn_impl.depthwise_conv2d( - t1, t2, strides=[1, stride, stride, 1], padding=padding) - - native_result = self.evaluate(conv_native) - interface_result = self.evaluate(conv_interface) - - print("depthwise conv_2d: ", tensor_in_sizes, "*", filter_in_sizes, - ", stride:", stride, ", padding: ", padding, ", max diff: ", - np.amax(np.absolute(native_result - interface_result))) - self.assertAllClose( - np.ravel(native_result), np.ravel(interface_result), 1e-5) - self.assertShapeEqual(native_result, conv_native) - self.assertShapeEqual(native_result, conv_interface) - - @test_util.run_deprecated_v1 - def testDepthwiseConv2D(self): - for index, (input_size, filter_size, _, stride, - padding) in enumerate(ConfigsToTest()): - print("Processing ", index, "th config.") - if index == 2: - self._VerifyValues( - input_size, filter_size, stride, padding, use_gpu=True) - self._VerifyValues( - input_size, filter_size, stride, padding, use_gpu=False) - - @test_util.run_deprecated_v1 - def testDepthwiseConv2DFormat(self): - if not test.is_gpu_available(): - return - - for index, (input_size, filter_size, _, stride, - padding) in enumerate(ConfigsToTest()): - print("Processing ", index, "th config.") - self._VerifyValues( - input_size, - filter_size, - stride, - padding, - use_gpu=True, - data_format="NCHW") - -# This is testing against hand calculated results. - - def _VerifyHandValues(self, tensor_in_sizes, filter_in_sizes, stride, padding, - expected, use_gpu): - """Verifies the output values of the depthwise convolution function. - - Args: - tensor_in_sizes: Input tensor dimensions in - [batch, input_rows, input_cols, input_depth]. - filter_in_sizes: Filter tensor dimensions in - [filter_rows, filter_cols, input_depth, depth_multiplier]. - stride: Stride. - padding: Padding type. - expected: An array containing the expected operation outputs. - use_gpu: Whether to use GPU. - """ - total_size_1 = 1 - total_size_2 = 1 - for s in tensor_in_sizes: - total_size_1 *= s - for s in filter_in_sizes: - total_size_2 *= s - # Initializes the input tensor with array containing incrementing - # numbers from 1. - x1 = [f * 1.0 for f in range(1, total_size_1 + 1)] - x2 = [f * 1.0 for f in range(1, total_size_2 + 1)] - with self.cached_session(use_gpu=use_gpu) as sess: - with sess.graph._kernel_label_map({"DepthwiseConv2dNative": "neon"}): - t1 = constant_op.constant(x1, shape=tensor_in_sizes) - t1.set_shape(tensor_in_sizes) - t2 = constant_op.constant(x2, shape=filter_in_sizes) - conv = nn_ops.depthwise_conv2d_native( - t1, t2, strides=[1, stride, stride, 1], padding=padding) - value = self.evaluate(conv) - print("value = ", value) - self.assertAllClose(expected, np.ravel(value), 1e-5) - self.assertShapeEqual(value, conv) - - @test_util.run_deprecated_v1 - def testConv2D2x2Filter(self): - # The inputs look like this (it's a 3 x 2 matrix, each of depth 2): - # - # [ (1.0, 2.0), (3.0, 4.0), ( 5.0, 6.0) ] - # [ (7.0, 8.0), (9.0, 10.0), (11.0, 12.0) ] - # We can view this as two inputs - # - # input depth 0: - # - # [ 1.0, 3.0, 5.0 ] - # [ 7.0, 9.0, 11.0 ] - # - # input depth 1: - # - # [ 2.0, 4.0, 6.0 ] - # [ 8.0, 10.0, 12.0 ] - # - # The filter looks like this (it has two 2 x 2 patches, each generating 2 - # depths): - # - # filter #0: - # - # [ (1.0, 3.0), ( 5.0, 7.0)] - # [ (9.0, 11.0), (13.0, 15.0)] - # - # filter #1: - # - # [ ( 2.0, 4.0), ( 6.0, 8.0)] - # [ (10.0, 12.0), (14.0, 16.0)] - # - # So the outputs are: - # - # (position 0, 0: in_depth 0, output_depth 0 -- using filter #0) - # 1.0 * 1.0 + 7.0 * 9.0 + 3.0 * 5.0 + 9.0 * 13.0 = 196 - # (position 0, 0: in_depth 0, output_depth 1 -- using filter #1) - # 1.0 * 2.0 + 7.0 * 10.0 + 3.0 * 6.0 + 9.0 * 14.0 = 216 - # (position 0, 0: in_depth 1, output_depth 2 -- using filter #0) - # 2.0 * 3.0 + 8.0 * 11.0 + 4.0 * 7.0 + 10.0 * 15.0 = 272 - # (position 0, 0: in_depth 1, output_depth 3 -- using filter #1) - # 2.0 * 4.0 + 8.0 * 12.0 + 4.0 * 8.0 + 10.0 * 16.0 = 296 - # - # (position 1, 0: in_depth 0, output_depth 0 -- using filter #0) - # 3.0 * 1.0 + 9.0 * 9.0 + 5.0 * 5.0 + 11.0 * 13.0 = 252 - # (position 1, 0: in_depth 0, output_depth 1 -- using filter #1) - # 3.0 * 2.0 + 9.0 * 10.0 + 5.0 * 6.0 + 11.0 * 14.0 = 280 - # (position 1, 0: in_depth 1, output_depth 2 -- using filter #0) - # 4.0 * 3.0 + 10.0 * 11.0 + 6.0 * 7.0 + 12.0 * 15.0 = 344 - # (position 1, 0: in_depth 1, output_depth 3 -- using filter #1) - # 4.0 * 4.0 + 10.0 * 12.0 + 6.0 * 8.0 + 12.0 * 16.0 = 376 - expected_output = [196, 216, 272, 296, 252, 280, 344, 376] - self._VerifyHandValues( - tensor_in_sizes=[1, 2, 3, 2], - filter_in_sizes=[2, 2, 2, 2], - stride=1, - padding="VALID", - expected=expected_output, - use_gpu=False) - - self._VerifyHandValues( - tensor_in_sizes=[1, 2, 3, 2], - filter_in_sizes=[2, 2, 2, 2], - stride=1, - padding="VALID", - expected=expected_output, - use_gpu=True) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/python/kernel_tests/sparse_reshape_op_test.py b/tensorflow/python/kernel_tests/sparse_reshape_op_test.py index 946774e7275..ab98c9a3deb 100644 --- a/tensorflow/python/kernel_tests/sparse_reshape_op_test.py +++ b/tensorflow/python/kernel_tests/sparse_reshape_op_test.py @@ -94,7 +94,7 @@ class SparseReshapeTest(test.TestCase): self.assertAllEqual((2, 3 * 4), sp_output.shape) def testSameShape(self): - with self.session(use_gpu=False) as sess: + with self.session() as sess: input_val = self._SparseTensorValue_5x6() sp_output = sparse_ops.sparse_reshape(input_val, [5, 6]) @@ -105,7 +105,7 @@ class SparseReshapeTest(test.TestCase): @test_util.run_deprecated_v1 def testFeedSameShape(self): - with self.session(use_gpu=False) as sess: + with self.session() as sess: sp_input = self._SparseTensorPlaceholder() input_val = self._SparseTensorValue_5x6() sp_output = sparse_ops.sparse_reshape(sp_input, [5, 6]) @@ -117,7 +117,7 @@ class SparseReshapeTest(test.TestCase): @test_util.run_deprecated_v1 def testWorksWellWithTfShape(self): - with self.session(use_gpu=False) as sess: + with self.session() as sess: sp_input = self._SparseTensorPlaceholder() input_val = self._SparseTensorValue_5x6() shape = array_ops.shape(sp_input) # tf.shape generates int32 output @@ -130,7 +130,7 @@ class SparseReshapeTest(test.TestCase): @test_util.run_deprecated_v1 def testFeedSameShapeWithInferredDim(self): - with self.session(use_gpu=False) as sess: + with self.session() as sess: sp_input = self._SparseTensorPlaceholder() input_val = self._SparseTensorValue_5x6() sp_output = sparse_ops.sparse_reshape(sp_input, [-1, 6]) @@ -142,7 +142,7 @@ class SparseReshapeTest(test.TestCase): @test_util.run_deprecated_v1 def testFeedNewShapeSameRank(self): - with self.session(use_gpu=False) as sess: + with self.session() as sess: sp_input = self._SparseTensorPlaceholder() input_val = self._SparseTensorValue_5x6() sp_output = sparse_ops.sparse_reshape(sp_input, [3, 10]) @@ -156,7 +156,7 @@ class SparseReshapeTest(test.TestCase): @test_util.run_deprecated_v1 def testFeedNewShapeSameRankWithInferredDim(self): - with self.session(use_gpu=False) as sess: + with self.session() as sess: sp_input = self._SparseTensorPlaceholder() input_val = self._SparseTensorValue_5x6() sp_output = sparse_ops.sparse_reshape(sp_input, [3, -1]) @@ -169,7 +169,7 @@ class SparseReshapeTest(test.TestCase): self.assertAllEqual(output_val.dense_shape, [3, 10]) def testUpRank(self): - with self.session(use_gpu=False) as sess: + with self.session() as sess: input_val = self._SparseTensorValue_5x6() sp_output = sparse_ops.sparse_reshape(input_val, [2, 3, 5]) @@ -182,7 +182,7 @@ class SparseReshapeTest(test.TestCase): @test_util.run_deprecated_v1 def testFeedUpRank(self): - with self.session(use_gpu=False) as sess: + with self.session() as sess: sp_input = self._SparseTensorPlaceholder() input_val = self._SparseTensorValue_5x6() sp_output = sparse_ops.sparse_reshape(sp_input, [2, 3, 5]) @@ -196,7 +196,7 @@ class SparseReshapeTest(test.TestCase): @test_util.run_deprecated_v1 def testFeedUpRankWithInferredDim(self): - with self.session(use_gpu=False) as sess: + with self.session() as sess: sp_input = self._SparseTensorPlaceholder() input_val = self._SparseTensorValue_5x6() sp_output = sparse_ops.sparse_reshape(sp_input, [2, -1, 5]) @@ -210,7 +210,7 @@ class SparseReshapeTest(test.TestCase): @test_util.run_deprecated_v1 def testFeedDownRank(self): - with self.session(use_gpu=False) as sess: + with self.session() as sess: sp_input = self._SparseTensorPlaceholder() input_val = self._SparseTensorValue_2x3x4() sp_output = sparse_ops.sparse_reshape(sp_input, [6, 4]) @@ -224,7 +224,7 @@ class SparseReshapeTest(test.TestCase): @test_util.run_deprecated_v1 def testFeedDownRankWithInferredDim(self): - with self.session(use_gpu=False) as sess: + with self.session() as sess: sp_input = self._SparseTensorPlaceholder() input_val = self._SparseTensorValue_2x3x4() sp_output = sparse_ops.sparse_reshape(sp_input, [6, -1]) @@ -238,7 +238,7 @@ class SparseReshapeTest(test.TestCase): @test_util.run_deprecated_v1 def testFeedMultipleInferredDims(self): - with self.session(use_gpu=False) as sess: + with self.session() as sess: sp_input = self._SparseTensorPlaceholder() input_val = self._SparseTensorValue_5x6() sp_output = sparse_ops.sparse_reshape(sp_input, [4, -1, -1]) @@ -254,7 +254,7 @@ class SparseReshapeTest(test.TestCase): @test_util.run_deprecated_v1 def testFeedMismatchedSizes(self): - with self.session(use_gpu=False) as sess: + with self.session() as sess: sp_input = self._SparseTensorPlaceholder() input_val = self._SparseTensorValue_5x6() sp_output = sparse_ops.sparse_reshape(sp_input, [4, 7]) @@ -264,7 +264,7 @@ class SparseReshapeTest(test.TestCase): @test_util.run_deprecated_v1 def testFeedMismatchedSizesWithInferredDim(self): - with self.session(use_gpu=False) as sess: + with self.session() as sess: sp_input = self._SparseTensorPlaceholder() input_val = self._SparseTensorValue_5x6() sp_output = sparse_ops.sparse_reshape(sp_input, [4, -1]) @@ -273,7 +273,7 @@ class SparseReshapeTest(test.TestCase): @test_util.run_deprecated_v1 def testFeedPartialShapes(self): - with self.session(use_gpu=False): + with self.session(): # Incorporate new rank into shape information if known sp_input = self._SparseTensorPlaceholder() sp_output = sparse_ops.sparse_reshape(sp_input, [2, 3, 5]) @@ -299,7 +299,7 @@ class SparseReshapeTest(test.TestCase): @test_util.run_deprecated_v1 def testFeedDenseReshapeSemantics(self): - with self.session(use_gpu=False) as sess: + with self.session() as sess: # Compute a random rank-5 initial shape and new shape, randomly sparsify # it, and check that the output of SparseReshape has the same semantics # as a dense reshape. diff --git a/tensorflow/python/lib/core/pybind11_status.h b/tensorflow/python/lib/core/pybind11_status.h index 3f9991c6577..a2035ab6f57 100644 --- a/tensorflow/python/lib/core/pybind11_status.h +++ b/tensorflow/python/lib/core/pybind11_status.h @@ -49,6 +49,15 @@ inline PyObject* TFStatusToPyExc(const TF_Status* status) { return CodeToPyExc(TF_GetCode(status)); } +inline pybind11::dict StatusPayloadToDict(const Status& status) { + pybind11::dict dict; + const auto& payloads = status.GetAllPayloads(); + for (auto& pair : payloads) { + dict[pair.first.c_str()] = pair.second.c_str(); + } + return dict; +} + } // namespace internal inline void MaybeRaiseFromStatus(const Status& status) { @@ -59,12 +68,17 @@ inline void MaybeRaiseFromStatus(const Status& status) { } } +inline void SetRegisteredErrFromStatus(const tensorflow::Status& status) { + PyErr_SetObject(PyExceptionRegistry::Lookup(status.code()), + pybind11::make_tuple(pybind11::none(), pybind11::none(), + status.error_message(), + internal::StatusPayloadToDict(status)) + .ptr()); +} + inline void MaybeRaiseRegisteredFromStatus(const tensorflow::Status& status) { if (!status.ok()) { - PyErr_SetObject(PyExceptionRegistry::Lookup(status.code()), - pybind11::make_tuple(pybind11::none(), pybind11::none(), - status.error_message()) - .ptr()); + SetRegisteredErrFromStatus(status); throw pybind11::error_already_set(); } } @@ -74,11 +88,7 @@ inline void MaybeRaiseRegisteredFromStatusWithGIL( if (!status.ok()) { // Acquire GIL for throwing exception. pybind11::gil_scoped_acquire acquire; - - PyErr_SetObject(PyExceptionRegistry::Lookup(status.code()), - pybind11::make_tuple(pybind11::none(), pybind11::none(), - status.error_message()) - .ptr()); + SetRegisteredErrFromStatus(status); throw pybind11::error_already_set(); } } diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py index 3c6bb4fb829..c5821ce8fcb 100644 --- a/tensorflow/python/ops/math_grad.py +++ b/tensorflow/python/ops/math_grad.py @@ -1178,7 +1178,7 @@ def _SigmoidGradGrad(op, grad): def _SignGrad(op, _): """Returns 0.""" x = op.inputs[0] - return array_ops.zeros(array_ops.shape(x), dtype=x.dtype) + return array_ops.zeros_like(x) @ops.RegisterGradient("Sin") @@ -1560,11 +1560,9 @@ def _MaximumMinimumGrad(op, grad, selector_op): # No gradient skipping, so do the full gradient computation pass x = op.inputs[0] - gdtype = grad.dtype sx = array_ops.shape(x) sy = array_ops.shape(y) - gradshape = array_ops.shape(grad) - zeros = array_ops.zeros(gradshape, gdtype) + zeros = array_ops.zeros_like(grad) xmask = selector_op(x, y) rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) if skip_input_indices is not None and 0 in skip_input_indices: diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py index 70b137a57a8..089297171bf 100644 --- a/tensorflow/python/ops/nn_grad.py +++ b/tensorflow/python/ops/nn_grad.py @@ -251,10 +251,8 @@ def _MaxPool3DGrad(op, grad): @ops.RegisterGradient("MaxPool3DGrad") def _MaxPool3DGradGrad(op, grad): - return (array_ops.zeros( - shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype), - array_ops.zeros( - shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype), + return (array_ops.zeros_like(op.inputs[0]), + array_ops.zeros_like(op.inputs[1]), gen_nn_ops.max_pool3d_grad_grad( op.inputs[0], op.inputs[1], @@ -267,10 +265,8 @@ def _MaxPool3DGradGrad(op, grad): @ops.RegisterGradient("MaxPool3DGradGrad") def _MaxPool3DGradGradGrad(op, grad): - return (array_ops.zeros( - shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype), - array_ops.zeros( - shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype), + return (array_ops.zeros_like(op.inputs[0]), + array_ops.zeros_like(op.inputs[1]), gen_nn_ops.max_pool3d_grad( op.inputs[0], op.inputs[1], @@ -441,8 +437,7 @@ def _Relu6Grad(op, grad): @ops.RegisterGradient("Relu6Grad") def _Relu6GradGrad(op, grad): x = op.inputs[1] - return (gen_nn_ops.relu6_grad(grad, x), - array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype)) + return (gen_nn_ops.relu6_grad(grad, x), array_ops.zeros_like(x)) @ops.RegisterGradient("LeakyRelu") @@ -456,8 +451,8 @@ def _LeakyReluGrad(op, grad): def _LeakyReluGradGrad(op, grad): x = op.inputs[1] alpha = op.get_attr("alpha") - return (gen_nn_ops.leaky_relu_grad(grad, x, alpha=alpha), - array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype)) + return (gen_nn_ops.leaky_relu_grad(grad, x, + alpha=alpha), array_ops.zeros_like(x)) @ops.RegisterGradient("Elu") @@ -496,8 +491,7 @@ def _SoftsignGrad(op, grad): @ops.RegisterGradient("ReluGrad") def _ReluGradGrad(op, grad): x = op.inputs[1] - return (gen_nn_ops.relu_grad(grad, x), - array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype)) + return (gen_nn_ops.relu_grad(grad, x), array_ops.zeros_like(x)) def _BroadcastMul(vec, mat): @@ -721,10 +715,8 @@ def _MaxPoolGradWithArgmax(op, grad, unused_argmax_grad): @ops.RegisterGradient("MaxPoolGrad") def _MaxPoolGradGrad(op, grad): - return (array_ops.zeros( - shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype), - array_ops.zeros( - shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype), + return (array_ops.zeros_like(op.inputs[0]), + array_ops.zeros_like(op.inputs[1]), gen_nn_ops.max_pool_grad_grad( op.inputs[0], op.inputs[1], @@ -739,10 +731,8 @@ def _MaxPoolGradGrad(op, grad): def _MaxPoolGradGradV2(op, grad): ksize = op.inputs[3] strides = op.inputs[4] - return (array_ops.zeros( - shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype), - array_ops.zeros( - shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype), + return (array_ops.zeros_like(op.inputs[0]), + array_ops.zeros_like(op.inputs[1]), gen_nn_ops.max_pool_grad_grad_v2( op.inputs[0], op.inputs[1], @@ -755,10 +745,8 @@ def _MaxPoolGradGradV2(op, grad): @ops.RegisterGradient("MaxPoolGradGrad") def _MaxPoolGradGradGrad(op, grad): - return (array_ops.zeros( - shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype), - array_ops.zeros( - shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype), + return (array_ops.zeros_like(op.inputs[0]), + array_ops.zeros_like(op.inputs[1]), gen_nn_ops.max_pool_grad( op.inputs[0], op.inputs[1], diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index 0c3cf93dfbe..d9925b06c9d 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -3476,6 +3476,7 @@ def leaky_relu(features, alpha=0.2, name=None): Source: [Rectifier Nonlinearities Improve Neural Network Acoustic Models. AL Maas, AY Hannun, AY Ng - Proc. ICML, 2013] (https://ai.stanford.edu/~amaas/papers/relu_hybrid_icml2013_final.pdf). + Args: features: A `Tensor` representing preactivation values. Must be one of the following types: `float16`, `float32`, `float64`, `int32`, `int64`. diff --git a/tensorflow/python/saved_model/loader_impl.py b/tensorflow/python/saved_model/loader_impl.py index 6daa631084f..0a0132ff570 100644 --- a/tensorflow/python/saved_model/loader_impl.py +++ b/tensorflow/python/saved_model/loader_impl.py @@ -94,15 +94,17 @@ def parse_saved_model(export_dir): # Parse the SavedModel protocol buffer. saved_model = saved_model_pb2.SavedModel() if file_io.file_exists(path_to_pb): + with file_io.FileIO(path_to_pb, "rb") as f: + file_content = f.read() try: - file_content = file_io.FileIO(path_to_pb, "rb").read() saved_model.ParseFromString(file_content) return saved_model except message.DecodeError as e: raise IOError("Cannot parse file %s: %s." % (path_to_pb, str(e))) elif file_io.file_exists(path_to_pbtxt): + with file_io.FileIO(path_to_pbtxt, "rb") as f: + file_content = f.read() try: - file_content = file_io.FileIO(path_to_pbtxt, "rb").read() text_format.Merge(file_content.decode("utf-8"), saved_model) return saved_model except text_format.ParseError as e: diff --git a/tensorflow/python/tools/api/generator/api_init_files.bzl b/tensorflow/python/tools/api/generator/api_init_files.bzl index e736d1a112d..4369ad385c2 100644 --- a/tensorflow/python/tools/api/generator/api_init_files.bzl +++ b/tensorflow/python/tools/api/generator/api_init_files.bzl @@ -13,6 +13,7 @@ TENSORFLOW_API_INIT_FILES = [ "__internal__/distribute/multi_process_runner/__init__.py", "__internal__/eager_context/__init__.py", "__internal__/function/__init__.py", + "__internal__/graph_util/__init__.py", "__internal__/monitoring/__init__.py", "__internal__/nest/__init__.py", "__internal__/ops/__init__.py", diff --git a/tensorflow/python/tools/saved_model_utils.py b/tensorflow/python/tools/saved_model_utils.py index b88ecb5ead0..64d070f6807 100644 --- a/tensorflow/python/tools/saved_model_utils.py +++ b/tensorflow/python/tools/saved_model_utils.py @@ -57,15 +57,17 @@ def read_saved_model(saved_model_dir): # Parse the SavedModel protocol buffer. saved_model = saved_model_pb2.SavedModel() if file_io.file_exists(path_to_pb): + with file_io.FileIO(path_to_pb, "rb") as f: + file_content = f.read() try: - file_content = file_io.FileIO(path_to_pb, "rb").read() saved_model.ParseFromString(file_content) return saved_model except message.DecodeError as e: raise IOError("Cannot parse file %s: %s." % (path_to_pb, str(e))) elif file_io.file_exists(path_to_pbtxt): + with file_io.FileIO(path_to_pbtxt, "rb") as f: + file_content = f.read() try: - file_content = file_io.FileIO(path_to_pbtxt, "rb").read() text_format.Merge(file_content.decode("utf-8"), saved_model) return saved_model except text_format.ParseError as e: diff --git a/tensorflow/python/tpu/tpu_embedding_v2.py b/tensorflow/python/tpu/tpu_embedding_v2.py index 83efe88ba07..8a792b3fdef 100644 --- a/tensorflow/python/tpu/tpu_embedding_v2.py +++ b/tensorflow/python/tpu/tpu_embedding_v2.py @@ -156,7 +156,7 @@ class TPUEmbedding(tracking.AutoTrackable): strategy.distribute_datasets_from_function( dataset_fn=..., options=tf.distribute.InputOptions( - experimental_prefetch_to_device=False)) + experimental_fetch_to_device=False)) dataset_iterator = iter(distributed_dataset) ``` @@ -592,7 +592,7 @@ class TPUEmbedding(tracking.AutoTrackable): strategy.distribute_datasets_from_function( dataset_fn=..., options=tf.distribute.InputOptions( - experimental_prefetch_to_device=False)) + experimental_fetch_to_device=False)) dataset_iterator = iter(distributed_dataset) @tf.function @@ -689,7 +689,7 @@ class TPUEmbedding(tracking.AutoTrackable): strategy.distribute_datasets_from_function( dataset_fn=..., options=tf.distribute.InputOptions( - experimental_prefetch_to_device=False)) + experimental_fetch_to_device=False)) dataset_iterator = iter(distributed_dataset) @tf.function @@ -1102,10 +1102,9 @@ class TPUEmbedding(tracking.AutoTrackable): "Received input tensor {} which is on a TPU input device {}. Input " "tensors for TPU embeddings must be placed on the CPU. Please " "ensure that your dataset is prefetching tensors to the host by " - "setting the 'experimental_prefetch_to_device' option of the " + "setting the 'experimental_fetch_to_device' option of the " "dataset distribution function. See the documentation of the " - "enqueue method for an example.".format( - path, device_string)) + "enqueue method for an example.".format(path, device_string)) # expand_composites here is important, we need to check the device of each # underlying tensor. @@ -1145,7 +1144,7 @@ class TPUEmbedding(tracking.AutoTrackable): strategy.distribute_datasets_from_function( dataset_fn=..., options=tf.distribute.InputOptions( - experimental_prefetch_to_device=False)) + experimental_fetch_to_device=False)) dataset_iterator = iter(distributed_dataset) @tf.function diff --git a/tensorflow/python/tpu/tpu_embedding_v2_correctness_test.py b/tensorflow/python/tpu/tpu_embedding_v2_correctness_test.py index c90ae959db6..0ad800532d0 100644 --- a/tensorflow/python/tpu/tpu_embedding_v2_correctness_test.py +++ b/tensorflow/python/tpu/tpu_embedding_v2_correctness_test.py @@ -163,8 +163,7 @@ class TPUEmbeddingCorrectness(parameterized.TestCase, test.TestCase): dist = strategy.experimental_distribute_dataset( dataset, - options=distribute_lib.InputOptions( - experimental_prefetch_to_device=False)) + options=distribute_lib.InputOptions(experimental_fetch_to_device=False)) dist_iter = iter(dist) @def_function.function @@ -446,8 +445,7 @@ class TPUEmbeddingCorrectness(parameterized.TestCase, test.TestCase): input_fn = self._create_dense_input_fn(strategy) dist = strategy.distribute_datasets_from_function( input_fn, - options=distribute_lib.InputOptions( - experimental_prefetch_to_device=False)) + options=distribute_lib.InputOptions(experimental_fetch_to_device=False)) dist_iter = iter(dist) @def_function.function @@ -497,10 +495,12 @@ class TPUEmbeddingCorrectness(parameterized.TestCase, test.TestCase): dataset = self._create_sparse_dataset(strategy) else: dataset = self._create_ragged_dataset(strategy) - data = next(iter(strategy.experimental_distribute_dataset( - dataset, - options=distribute_lib.InputOptions( - experimental_prefetch_to_device=False)))) + data = next( + iter( + strategy.experimental_distribute_dataset( + dataset, + options=distribute_lib.InputOptions( + experimental_fetch_to_device=False)))) @def_function.function def embedding_and_set_gradients(data): diff --git a/tensorflow/python/tpu/tpu_embedding_v2_test.py b/tensorflow/python/tpu/tpu_embedding_v2_test.py index 6c649ee02c4..2e353983aaa 100644 --- a/tensorflow/python/tpu/tpu_embedding_v2_test.py +++ b/tensorflow/python/tpu/tpu_embedding_v2_test.py @@ -458,10 +458,12 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase): strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd') mid_level_api.build(self.batch_size) dataset = self._create_sparse_dataset(strategy) - data = next(iter(strategy.experimental_distribute_dataset( - dataset, - options=distribute_lib.InputOptions( - experimental_prefetch_to_device=False)))) + data = next( + iter( + strategy.experimental_distribute_dataset( + dataset, + options=distribute_lib.InputOptions( + experimental_fetch_to_device=False)))) @def_function.function def embedding_and_set_gradients(data): @@ -549,8 +551,7 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase): input_fn = self._create_dense_input_fn(strategy, include_weights=True) dist = strategy.distribute_datasets_from_function( input_fn, - options=distribute_lib.InputOptions( - experimental_prefetch_to_device=False)) + options=distribute_lib.InputOptions(experimental_fetch_to_device=False)) dist_iter = iter(dist) @def_function.function @@ -570,14 +571,16 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase): sparse = self._create_sparse_dataset(strategy) ragged = self._create_ragged_dataset(strategy, include_weights=True) - sparse_iter = iter(strategy.experimental_distribute_dataset( - sparse, - options=distribute_lib.InputOptions( - experimental_prefetch_to_device=False))) - ragged_iter = iter(strategy.experimental_distribute_dataset( - ragged, - options=distribute_lib.InputOptions( - experimental_prefetch_to_device=False))) + sparse_iter = iter( + strategy.experimental_distribute_dataset( + sparse, + options=distribute_lib.InputOptions( + experimental_fetch_to_device=False))) + ragged_iter = iter( + strategy.experimental_distribute_dataset( + ragged, + options=distribute_lib.InputOptions( + experimental_fetch_to_device=False))) @def_function.function def test_fn(): @@ -598,14 +601,16 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase): sparse = self._create_sparse_dataset(strategy, include_weights=True) ragged = self._create_ragged_dataset(strategy) - sparse_iter = iter(strategy.experimental_distribute_dataset( - sparse, - options=distribute_lib.InputOptions( - experimental_prefetch_to_device=False))) - ragged_iter = iter(strategy.experimental_distribute_dataset( - ragged, - options=distribute_lib.InputOptions( - experimental_prefetch_to_device=False))) + sparse_iter = iter( + strategy.experimental_distribute_dataset( + sparse, + options=distribute_lib.InputOptions( + experimental_fetch_to_device=False))) + ragged_iter = iter( + strategy.experimental_distribute_dataset( + ragged, + options=distribute_lib.InputOptions( + experimental_fetch_to_device=False))) @def_function.function def test_fn(): @@ -626,14 +631,16 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase): sparse = self._create_sparse_dataset(strategy) ragged = self._create_ragged_dataset(strategy) - sparse_iter = iter(strategy.experimental_distribute_dataset( - sparse, - options=distribute_lib.InputOptions( - experimental_prefetch_to_device=False))) - ragged_iter = iter(strategy.experimental_distribute_dataset( - ragged, - options=distribute_lib.InputOptions( - experimental_prefetch_to_device=False))) + sparse_iter = iter( + strategy.experimental_distribute_dataset( + sparse, + options=distribute_lib.InputOptions( + experimental_fetch_to_device=False))) + ragged_iter = iter( + strategy.experimental_distribute_dataset( + ragged, + options=distribute_lib.InputOptions( + experimental_fetch_to_device=False))) @def_function.function def test_fn(): @@ -654,10 +661,11 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase): strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd') sparse = self._create_sparse_dataset(strategy) - sparse_iter = iter(strategy.experimental_distribute_dataset( - sparse, - options=distribute_lib.InputOptions( - experimental_prefetch_to_device=False))) + sparse_iter = iter( + strategy.experimental_distribute_dataset( + sparse, + options=distribute_lib.InputOptions( + experimental_fetch_to_device=False))) @def_function.function def test_fn(): @@ -677,10 +685,11 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase): strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd') sparse = self._create_sparse_dataset(strategy, include_weights=True) - sparse_iter = iter(strategy.experimental_distribute_dataset( - sparse, - options=distribute_lib.InputOptions( - experimental_prefetch_to_device=False))) + sparse_iter = iter( + strategy.experimental_distribute_dataset( + sparse, + options=distribute_lib.InputOptions( + experimental_fetch_to_device=False))) @def_function.function def test_fn(): @@ -701,14 +710,16 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase): sparse = self._create_sparse_dataset(strategy) ragged = self._create_ragged_dataset(strategy) - sparse_iter = iter(strategy.experimental_distribute_dataset( - sparse, - options=distribute_lib.InputOptions( - experimental_prefetch_to_device=False))) - ragged_iter = iter(strategy.experimental_distribute_dataset( - ragged, - options=distribute_lib.InputOptions( - experimental_prefetch_to_device=False))) + sparse_iter = iter( + strategy.experimental_distribute_dataset( + sparse, + options=distribute_lib.InputOptions( + experimental_fetch_to_device=False))) + ragged_iter = iter( + strategy.experimental_distribute_dataset( + ragged, + options=distribute_lib.InputOptions( + experimental_fetch_to_device=False))) @def_function.function def test_fn(): @@ -735,10 +746,11 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase): strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd') sparse = self._create_sparse_dataset(strategy) - sparse_iter = iter(strategy.experimental_distribute_dataset( - sparse, - options=distribute_lib.InputOptions( - experimental_prefetch_to_device=False))) + sparse_iter = iter( + strategy.experimental_distribute_dataset( + sparse, + options=distribute_lib.InputOptions( + experimental_fetch_to_device=False))) @def_function.function def test_fn(): @@ -827,10 +839,11 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase): weight=weight) mid_level_api.build(self.batch_size) - dataset_iter = iter(strategy.experimental_distribute_dataset( - dataset, - options=distribute_lib.InputOptions( - experimental_prefetch_to_device=False))) + dataset_iter = iter( + strategy.experimental_distribute_dataset( + dataset, + options=distribute_lib.InputOptions( + experimental_fetch_to_device=False))) @def_function.function def enqueue_and_get(features, weights): @@ -869,10 +882,11 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase): strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd') mid_level_api.build(self.batch_size) dataset = self._create_sparse_dataset(strategy) - dataset_iter = iter(strategy.experimental_distribute_dataset( - dataset, - options=distribute_lib.InputOptions( - experimental_prefetch_to_device=False))) + dataset_iter = iter( + strategy.experimental_distribute_dataset( + dataset, + options=distribute_lib.InputOptions( + experimental_fetch_to_device=False))) @def_function.function def enqueue_with_outside_compilation(data): @@ -906,10 +920,11 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase): strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd') dataset = self._create_sparse_dataset(strategy) - dataset_iter = iter(strategy.experimental_distribute_dataset( - dataset, - options=distribute_lib.InputOptions( - experimental_prefetch_to_device=False))) + dataset_iter = iter( + strategy.experimental_distribute_dataset( + dataset, + options=distribute_lib.InputOptions( + experimental_fetch_to_device=False))) # This is one way to force the enqueue in some control flow. @tf.functions # aren't inlined in the calling tf.function. An alternative would be to @@ -934,10 +949,11 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase): strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd') mid_level_api.build(self.batch_size) dataset = self._create_sparse_dataset(strategy) - dataset_iter = iter(strategy.experimental_distribute_dataset( - dataset, - options=distribute_lib.InputOptions( - experimental_prefetch_to_device=False))) + dataset_iter = iter( + strategy.experimental_distribute_dataset( + dataset, + options=distribute_lib.InputOptions( + experimental_fetch_to_device=False))) @def_function.function def enqueue_with_outside_compilation(): @@ -957,10 +973,11 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase): strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd') mid_level_api.build(self.batch_size) dataset = self._create_sparse_dataset(strategy) - dataset_iter = iter(strategy.experimental_distribute_dataset( - dataset, - options=distribute_lib.InputOptions( - experimental_prefetch_to_device=False))) + dataset_iter = iter( + strategy.experimental_distribute_dataset( + dataset, + options=distribute_lib.InputOptions( + experimental_fetch_to_device=False))) @def_function.function def enqueue_with_no_gradient_apply(data): @@ -1162,8 +1179,7 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase): dist = strategy.distribute_datasets_from_function( input_fn, - options=distribute_lib.InputOptions( - experimental_prefetch_to_device=False)) + options=distribute_lib.InputOptions(experimental_fetch_to_device=False)) dist_iter = iter(dist) @def_function.function diff --git a/tensorflow/python/training/tracking/data_structures.py b/tensorflow/python/training/tracking/data_structures.py index 38cc50a43f8..8e1e1d24b1a 100644 --- a/tensorflow/python/training/tracking/data_structures.py +++ b/tensorflow/python/training/tracking/data_structures.py @@ -37,6 +37,8 @@ from tensorflow.python.training.tracking import base from tensorflow.python.training.tracking import layer_utils from tensorflow.python.util import lazy_loader from tensorflow.python.util.compat import collections_abc +from tensorflow.python.util.tf_export import tf_export + module = lazy_loader.LazyLoader( "module", globals(), "tensorflow.python.module.module") @@ -83,8 +85,23 @@ def _should_wrap_tuple(t): return False +@tf_export("__internal__.tracking.wrap", v1=[]) def wrap_or_unwrap(value): - """Wraps basic data structures, unwraps NoDependency objects.""" + """Wraps input value into trackable data structures. + + This is mostly useful for containers like list, dict, etc, which could contain + trackable objects in it. Wrapped data structure will be tracked when + associated with a `tf.Module`, so that save model/checkpoint can properly + track the dependency. + + It will also unwrap NoDependency objects. + + Args: + value: the input object to be wrapped. + + Returns: + Wrapped trackable data structure. + """ # pylint: disable=unidiomatic-typecheck # Exact type checking to avoid mucking up custom logic in list/dict # subclasses, e.g. collections.Counter. diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc index 731fc0b71be..4332787d77b 100644 --- a/tensorflow/stream_executor/cuda/cuda_blas.cc +++ b/tensorflow/stream_executor/cuda/cuda_blas.cc @@ -3338,9 +3338,7 @@ class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan { private: // In some cases cublasLt does not support large batch sizes, so we need to // split up such cases into multiple calls. - // TODO(reedwm): Making this static or constexpr causes a link error with gcc - // in debug mode for unknown reasons. Investigate why. - const int kMaxBatchCount = 65535; + static constexpr int kMaxBatchCount = 65535; blas::BlasLtMatmulPlanParams params_; blas::DataType scale_type_; UniqueOpDesc op_desc_; @@ -3358,6 +3356,8 @@ class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan { UniqueLayoutDesc d_remainder_desc_; }; +/*static*/ constexpr int CUDABlasLtMatmulPlan::kMaxBatchCount; + bool CUDABlasLtMatmulPlan::SetBiasPointer(const void *bias) const { return SetCublasLtAttr(op_desc_.get(), CUBLASLT_MATMUL_DESC_BIAS_POINTER, bias) diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index 01113f89f5e..593619ff084 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -1186,6 +1186,9 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor { } // Create the params handle. + // TODO(kaixih@nvidia.com): Should be removed when cudnnRNNForward*** and + // cudnnRNNForward***Ex are removed from the codebase, since the new API + // doesn't need param descriptors any more. SE_ASSIGN_OR_RETURN(auto params_desc, CudnnRnnParamsDescriptor::Create( cudnn, input_size, data_type, rnn_desc.get(), @@ -1659,10 +1662,16 @@ port::Status CheckRNNParameterSize( const CudnnHandle& cudnn, const CudnnRnnDescriptor& rnn_desc, const CudnnRnnSequenceTensorDescriptor& input_desc) { size_t params_size_in_bytes = 0; +#if CUDNN_VERSION >= 8000 + RETURN_IF_CUDNN_ERROR(cudnnGetRNNWeightSpaceSize( + /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), + /*sizeInBytes=*/¶ms_size_in_bytes)); +#else RETURN_IF_CUDNN_ERROR(cudnnGetRNNParamsSize( /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), /*xDesc=*/input_desc.handles()[0], /*sizeInBytes=*/¶ms_size_in_bytes, /*dataType=*/rnn_desc.data_type())); +#endif if (static_cast(params_size_in_bytes) != rnn_desc.ParamsSizeInBytes()) { return port::Status(port::error::INVALID_ARGUMENT, @@ -1747,6 +1756,7 @@ port::Status CudnnSupport::DoRnnForwardImpl( Stream* stream, const CudnnRnnDescriptor& rnn_desc, const CudnnRnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const CudnnRnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const CudnnRnnStateTensorDescriptor& input_c_desc, @@ -1770,6 +1780,78 @@ port::Status CudnnSupport::DoRnnForwardImpl( auto cudnn = cudnn_->GetHandle(parent_, stream); SE_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc)); + + // In CUDNN v8.0, the cudnnRNNForward*** and cudnnRNNForward***Ex have been + // deprecated. Instead, we use the cudnnRNNForward which requires the + // sequence_lengths parameter. +#if CUDNN_VERSION >= 8000 + if (input_desc.is_var_seq_lengths()) { + DeviceMemory workspace; + DeviceMemory reserve_space; + cudnnForwardMode_t rnn_fwd_mode; + if (is_training) { + rnn_fwd_mode = CUDNN_FWD_MODE_TRAINING; + } else { + rnn_fwd_mode = CUDNN_FWD_MODE_INFERENCE; + } + size_t reserve_space_size_in_bytes = 0; + size_t workspace_size_in_bytes = 0; + RETURN_IF_CUDNN_ERROR(cudnnGetRNNTempSpaceSizes( + /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), + /*fMode=*/rnn_fwd_mode, /*xDesc=*/input_desc.data_handle(), + /*workSpaceSize=*/&workspace_size_in_bytes, + /*reserveSpaceSize=*/&reserve_space_size_in_bytes)); + + if (workspace_size_in_bytes > 0) { + SE_ASSIGN_OR_RETURN(workspace, workspace_allocator->AllocateBytes( + workspace_size_in_bytes)); + } + if (reserve_space_size_in_bytes > 0) { + SE_ASSIGN_OR_RETURN(reserve_space, reserve_space_allocator->AllocateBytes( + reserve_space_size_in_bytes)); + } + + std::unique_ptr timer; + const bool is_profiling = output_profile_result != nullptr; + if (is_profiling) { + timer.reset(new GpuTimer(parent_)); + // The start and stop of the timer should be as close to the Cudnn call as + // possible. It is still possible for other threads to issue workload on + // to this stream. So it could take multiple profiling measurements. + if (!timer->Init() || !timer->Start(AsGpuStream(stream))) { + return port::Status(port::error::INTERNAL, "Failed to start timer"); + } + } + + RETURN_IF_CUDNN_ERROR(cudnnRNNForward( + /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), + /*fwdMode=*/rnn_fwd_mode, + /*devSeqLengths=*/ + reinterpret_cast(seq_lengths_data.opaque()), + /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(), + /*yDesc=*/output_desc.data_handle(), /*y=*/output_data->opaque(), + /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(), + /*hy=*/output_h_data->opaque(), + /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(), + /*cy=*/output_c_data->opaque(), + /*weightSpaceSize=*/rnn_desc.ParamsSizeInBytes(), + /*weightSpace=*/params.opaque(), + /*workSpaceSize=*/workspace.size(), /*workspace=*/workspace.opaque(), + /*reserveSpaceSizeInBytes=*/reserve_space.size(), + /*reserveSpace=*/reserve_space.opaque())); + + if (is_profiling) { + if (!timer->Stop(AsGpuStream(stream))) { + return port::Status(port::error::INTERNAL, "Failed to stop timer"); + } + auto algo_desc = *rnn_desc.algorithm_config().algorithm(); + output_profile_result->set_algorithm(algo_desc); + output_profile_result->set_elapsed_time_in_ms( + timer->GetElapsedMilliseconds()); + } + return port::Status::OK(); + } +#endif SE_ASSIGN_OR_RETURN(DeviceMemory workspace, CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc, workspace_allocator)) @@ -1834,7 +1916,6 @@ port::Status CudnnSupport::DoRnnForwardImpl( } } else { if (input_desc.is_var_seq_lengths()) { - // cudnnSetRNNPaddingMode(rnn_desc.handle(), CUDNN_RNN_PADDED_IO_ENABLED); RETURN_IF_CUDNN_ERROR(cudnnRNNForwardTrainingEx( /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(), @@ -1887,6 +1968,7 @@ port::Status CudnnSupport::DoRnnBackwardImpl( Stream* stream, const CudnnRnnDescriptor& rnn_desc, const CudnnRnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const CudnnRnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const CudnnRnnStateTensorDescriptor& input_c_desc, @@ -1917,6 +1999,91 @@ port::Status CudnnSupport::DoRnnBackwardImpl( auto cudnn = cudnn_->GetHandle(parent_, stream); SE_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc)); + + // In CUDNN v8.0, the cudnnRNNForward*** and cudnnRNNForward***Ex have been + // deprecated. Instead, we use the cudnnRNNForward which requires the + // sequence_lengths parameter. +#if CUDNN_VERSION >= 8000 + if (input_desc.is_var_seq_lengths()) { + DeviceMemory workspace; + size_t workspace_size_in_bytes = 0; + RETURN_IF_CUDNN_ERROR(cudnnGetRNNTempSpaceSizes( + /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), + /*fMode=*/CUDNN_FWD_MODE_TRAINING, /*xDesc=*/input_desc.data_handle(), + /*workSpaceSize=*/&workspace_size_in_bytes, + /*reserveSpaceSize=*/NULL)); + if (workspace_size_in_bytes > 0) { + SE_ASSIGN_OR_RETURN(workspace, workspace_allocator->AllocateBytes( + workspace_size_in_bytes)); + } + + std::unique_ptr timer; + const bool is_profiling = output_profile_result != nullptr; + if (is_profiling) { + timer.reset(new GpuTimer(parent_)); + // The start and stop of the timer should be as close to the Cudnn call as + // possible. It is still possible for other threads to issue workload on + // to this stream. So it could take multiple profiling measurements. + if (!timer->Init() || !timer->Start(AsGpuStream(stream))) { + return port::Status(port::error::INTERNAL, "Failed to start timer"); + } + } + + RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardData_v8( + /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), + /*devSeqLengths=*/ + reinterpret_cast(seq_lengths_data.opaque()), + /*yDesc=*/output_desc.data_handle(), /*y=*/output_data.opaque(), + /*dy=*/output_backprop_data.opaque(), + /*xDesc=*/input_desc.data_handle(), + /*dx=*/input_backprop_data->opaque(), + /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(), + /*dhy=*/output_h_backprop_data.opaque(), + /*dhx=*/input_h_backprop_data->opaque(), + /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(), + /*dcy=*/output_c_backprop_data.opaque(), + /*dcx=*/input_c_backprop_data->opaque(), + /*weightSpaceSize=*/rnn_desc.ParamsSizeInBytes(), + /*weightSpace=*/params.opaque(), + /*workSpaceSize=*/workspace.size(), /*workSpace=*/workspace.opaque(), + /*reserveSpaceSize=*/reserve_space_data->size(), + /*reserveSpace=*/reserve_space_data->opaque())); + + if (params_backprop_data != nullptr) { + // Clear the dw to zeros. + stream->ThenMemZero(params_backprop_data, params_backprop_data->size()); + RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeights_v8( + /*handle=*/cudnn.handle(), + /*rnnDesc=*/rnn_desc.handle(), + /*addGrad=*/CUDNN_WGRAD_MODE_ADD, + /*devSeqLengths=*/ + reinterpret_cast(seq_lengths_data.opaque()), + /*xDesc=*/input_desc.data_handle(), + /*x=*/input_data.opaque(), + /*hDesc=*/input_h_desc.handle(), + /*hx=*/input_h_data.opaque(), + /*yDesc=*/output_desc.data_handle(), + /*y=*/output_data.opaque(), + /*weightSpaceSize=*/rnn_desc.ParamsSizeInBytes(), + /*dweightSpace=*/params_backprop_data->opaque(), + /*workSpaceSize=*/workspace.size(), + /*workSpace=*/workspace.opaque(), + /*reserveSpaceSize=*/reserve_space_data->size(), + /*reserveSpace=*/reserve_space_data->opaque())); + } + + if (is_profiling) { + if (!timer->Stop(AsGpuStream(stream))) { + return port::Status(port::error::INTERNAL, "Failed to stop timer"); + } + auto algo_desc = *rnn_desc.algorithm_config().algorithm(); + output_profile_result->set_algorithm(algo_desc); + output_profile_result->set_elapsed_time_in_ms( + timer->GetElapsedMilliseconds()); + } + return port::Status::OK(); + } +#endif SE_ASSIGN_OR_RETURN(DeviceMemory workspace, CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc, workspace_allocator)); @@ -2127,6 +2294,7 @@ bool CudnnSupport::DoRnnForward( Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -2158,10 +2326,11 @@ bool CudnnSupport::DoRnnForward( return IsStatusOk( DoRnnForwardImpl( stream, cudnn_rnn_desc, cudnn_input_desc, input_data, - cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data, - params, cudnn_output_desc, output_data, cudnn_output_h_desc, - output_h_data, cudnn_output_c_desc, output_c_data, is_training, - reserve_space_allocator, workspace_allocator, output_profile_result), + seq_lengths_data, cudnn_input_h_desc, input_h_data, + cudnn_input_c_desc, input_c_data, params, cudnn_output_desc, + output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc, + output_c_data, is_training, reserve_space_allocator, + workspace_allocator, output_profile_result), /*report_error=*/!output_profile_result); } @@ -2169,6 +2338,7 @@ bool CudnnSupport::DoRnnForward( Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -2199,10 +2369,11 @@ bool CudnnSupport::DoRnnForward( return IsStatusOk( DoRnnForwardImpl( stream, cudnn_rnn_desc, cudnn_input_desc, input_data, - cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data, - params, cudnn_output_desc, output_data, cudnn_output_h_desc, - output_h_data, cudnn_output_c_desc, output_c_data, is_training, - reserve_space_allocator, workspace_allocator, output_profile_result), + seq_lengths_data, cudnn_input_h_desc, input_h_data, + cudnn_input_c_desc, input_c_data, params, cudnn_output_desc, + output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc, + output_c_data, is_training, reserve_space_allocator, + workspace_allocator, output_profile_result), /*report_error=*/!output_profile_result); } @@ -2210,6 +2381,7 @@ bool CudnnSupport::DoRnnForward( Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -2241,10 +2413,11 @@ bool CudnnSupport::DoRnnForward( return IsStatusOk( DoRnnForwardImpl( stream, cudnn_rnn_desc, cudnn_input_desc, input_data, - cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data, - params, cudnn_output_desc, output_data, cudnn_output_h_desc, - output_h_data, cudnn_output_c_desc, output_c_data, is_training, - reserve_space_allocator, workspace_allocator, output_profile_result), + seq_lengths_data, cudnn_input_h_desc, input_h_data, + cudnn_input_c_desc, input_c_data, params, cudnn_output_desc, + output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc, + output_c_data, is_training, reserve_space_allocator, + workspace_allocator, output_profile_result), /*report_error=*/!output_profile_result); } @@ -2252,6 +2425,7 @@ bool CudnnSupport::DoRnnBackward( Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -2290,13 +2464,13 @@ bool CudnnSupport::DoRnnBackward( return IsStatusOk( DoRnnBackwardImpl( stream, cudnn_rnn_desc, cudnn_input_desc, input_data, - cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data, - params, cudnn_output_desc, output_data, cudnn_output_h_desc, - output_h_data, cudnn_output_c_desc, output_c_data, - output_backprop_data, output_h_backprop_data, output_c_backprop_data, - input_backprop_data, input_h_backprop_data, input_c_backprop_data, - params_backprop_data, reserve_space_data, workspace_allocator, - output_profile_result), + seq_lengths_data, cudnn_input_h_desc, input_h_data, + cudnn_input_c_desc, input_c_data, params, cudnn_output_desc, + output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc, + output_c_data, output_backprop_data, output_h_backprop_data, + output_c_backprop_data, input_backprop_data, input_h_backprop_data, + input_c_backprop_data, params_backprop_data, reserve_space_data, + workspace_allocator, output_profile_result), /*report_error=*/!output_profile_result); } @@ -2304,6 +2478,7 @@ bool CudnnSupport::DoRnnBackward( Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -2341,13 +2516,13 @@ bool CudnnSupport::DoRnnBackward( return IsStatusOk( DoRnnBackwardImpl( stream, cudnn_rnn_desc, cudnn_input_desc, input_data, - cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data, - params, cudnn_output_desc, output_data, cudnn_output_h_desc, - output_h_data, cudnn_output_c_desc, output_c_data, - output_backprop_data, output_h_backprop_data, output_c_backprop_data, - input_backprop_data, input_h_backprop_data, input_c_backprop_data, - params_backprop_data, reserve_space_data, workspace_allocator, - output_profile_result), + seq_lengths_data, cudnn_input_h_desc, input_h_data, + cudnn_input_c_desc, input_c_data, params, cudnn_output_desc, + output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc, + output_c_data, output_backprop_data, output_h_backprop_data, + output_c_backprop_data, input_backprop_data, input_h_backprop_data, + input_c_backprop_data, params_backprop_data, reserve_space_data, + workspace_allocator, output_profile_result), /*report_error=*/!output_profile_result); } @@ -2355,6 +2530,7 @@ bool CudnnSupport::DoRnnBackward( Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -2393,13 +2569,13 @@ bool CudnnSupport::DoRnnBackward( return IsStatusOk( DoRnnBackwardImpl( stream, cudnn_rnn_desc, cudnn_input_desc, input_data, - cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data, - params, cudnn_output_desc, output_data, cudnn_output_h_desc, - output_h_data, cudnn_output_c_desc, output_c_data, - output_backprop_data, output_h_backprop_data, output_c_backprop_data, - input_backprop_data, input_h_backprop_data, input_c_backprop_data, - params_backprop_data, reserve_space_data, workspace_allocator, - output_profile_result), + seq_lengths_data, cudnn_input_h_desc, input_h_data, + cudnn_input_c_desc, input_c_data, params, cudnn_output_desc, + output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc, + output_c_data, output_backprop_data, output_h_backprop_data, + output_c_backprop_data, input_backprop_data, input_h_backprop_data, + input_c_backprop_data, params_backprop_data, reserve_space_data, + workspace_allocator, output_profile_result), /*report_error=*/!output_profile_result); } diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h index 9cab982c9a1..941260e460c 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.h +++ b/tensorflow/stream_executor/cuda/cuda_dnn.h @@ -74,6 +74,7 @@ class CudnnSupport : public dnn::DnnSupport { bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -92,6 +93,7 @@ class CudnnSupport : public dnn::DnnSupport { bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -110,6 +112,7 @@ class CudnnSupport : public dnn::DnnSupport { bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -128,6 +131,7 @@ class CudnnSupport : public dnn::DnnSupport { bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -153,6 +157,7 @@ class CudnnSupport : public dnn::DnnSupport { bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -178,6 +183,7 @@ class CudnnSupport : public dnn::DnnSupport { bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -641,6 +647,7 @@ class CudnnSupport : public dnn::DnnSupport { Stream* stream, const CudnnRnnDescriptor& rnn_desc, const CudnnRnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const CudnnRnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const CudnnRnnStateTensorDescriptor& input_c_desc, @@ -660,6 +667,7 @@ class CudnnSupport : public dnn::DnnSupport { Stream* stream, const CudnnRnnDescriptor& rnn_desc, const CudnnRnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const CudnnRnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const CudnnRnnStateTensorDescriptor& input_c_desc, diff --git a/tensorflow/stream_executor/cuda/cudnn_8_0.inc b/tensorflow/stream_executor/cuda/cudnn_8_0.inc index 9161dbc8cf9..52a9d7cd2bd 100644 --- a/tensorflow/stream_executor/cuda/cudnn_8_0.inc +++ b/tensorflow/stream_executor/cuda/cudnn_8_0.inc @@ -1786,6 +1786,16 @@ cudnnStatus_t CUDNNWINAPI cudnnSetPersistentRNNPlan( return func_ptr(rnnDesc, plan); } +cudnnStatus_t CUDNNWINAPI +cudnnGetRNNWeightSpaceSize(cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, + size_t *weightSpaceSize) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, + cudnnRNNDescriptor_t, size_t *); + static auto func_ptr = LoadSymbol("cudnnGetRNNWeightSpaceSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, rnnDesc, weightSpaceSize); +} + cudnnStatus_t CUDNNWINAPI cudnnGetRNNWorkspaceSize( cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, const int seqLength, const cudnnTensorDescriptor_t *xDesc, @@ -1798,6 +1808,19 @@ cudnnStatus_t CUDNNWINAPI cudnnGetRNNWorkspaceSize( return func_ptr(handle, rnnDesc, seqLength, xDesc, sizeInBytes); } +cudnnStatus_t CUDNNWINAPI cudnnGetRNNTempSpaceSizes( + cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, + cudnnForwardMode_t fMode, cudnnRNNDataDescriptor_t xDesc, + size_t *workSpaceSize, size_t *reserveSpaceSize) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnRNNDescriptor_t, cudnnForwardMode_t, + cudnnRNNDataDescriptor_t, size_t *, size_t *); + static auto func_ptr = LoadSymbol("cudnnGetRNNTempSpaceSizes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, rnnDesc, fMode, xDesc, workSpaceSize, + reserveSpaceSize); +} + cudnnStatus_t CUDNNWINAPI cudnnGetRNNParamsSize(cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, const cudnnTensorDescriptor_t xDesc, size_t *sizeInBytes, @@ -2748,6 +2771,28 @@ cudnnStatus_t CUDNNWINAPI cudnnRNNForwardTrainingEx( reserveSpace, reserveSpaceSizeInBytes); } +cudnnStatus_t CUDNNWINAPI cudnnRNNForward( + cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, + cudnnForwardMode_t fwdMode, const int32_t devSeqLengths[], + cudnnRNNDataDescriptor_t xDesc, const void *x, + cudnnRNNDataDescriptor_t yDesc, void *y, cudnnTensorDescriptor_t hDesc, + const void *hx, void *hy, cudnnTensorDescriptor_t cDesc, const void *cx, + void *cy, size_t weightSpaceSize, const void *weightSpace, + size_t workSpaceSize, void *workSpace, size_t reserveSpaceSize, + void *reserveSpace) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnRNNDescriptor_t, cudnnForwardMode_t, const int32_t[], + cudnnRNNDataDescriptor_t, const void *, cudnnRNNDataDescriptor_t, void *, + cudnnTensorDescriptor_t, const void *, void *, cudnnTensorDescriptor_t, + const void *, void *, size_t, const void *, size_t, void *, size_t, + void *); + static auto func_ptr = LoadSymbol("cudnnRNNForward"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, rnnDesc, fwdMode, devSeqLengths, xDesc, x, yDesc, y, + hDesc, hx, hy, cDesc, cx, cy, weightSpaceSize, weightSpace, + workSpaceSize, workSpace, reserveSpaceSize, reserveSpace); +} + cudnnStatus_t CUDNNWINAPI cudnnRNNBackwardDataEx( cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, const cudnnRNNDataDescriptor_t yDesc, const void *y, @@ -2787,6 +2832,28 @@ cudnnStatus_t CUDNNWINAPI cudnnRNNBackwardDataEx( reserveSpaceSizeInBytes); } +cudnnStatus_t CUDNNWINAPI cudnnRNNBackwardData_v8( + cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, + const int32_t devSeqLengths[], cudnnRNNDataDescriptor_t yDesc, + const void *y, const void *dy, cudnnRNNDataDescriptor_t xDesc, void *dx, + cudnnTensorDescriptor_t hDesc, const void *hx, const void *dhy, void *dhx, + cudnnTensorDescriptor_t cDesc, const void *cx, const void *dcy, void *dcx, + size_t weightSpaceSize, const void *weightSpace, size_t workSpaceSize, + void *workSpace, size_t reserveSpaceSize, void *reserveSpace) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnRNNDescriptor_t, const int32_t[], + cudnnRNNDataDescriptor_t, const void *, const void *, + cudnnRNNDataDescriptor_t, void *, cudnnTensorDescriptor_t, const void *, + const void *, void *, cudnnTensorDescriptor_t, const void *, const void *, + void *, size_t, const void *, size_t, void *, size_t, void *); + static auto func_ptr = LoadSymbol("cudnnRNNBackwardData_v8"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, rnnDesc, devSeqLengths, yDesc, y, dy, xDesc, dx, + hDesc, hx, dhy, dhx, cDesc, cx, dcy, dcx, weightSpaceSize, + weightSpace, workSpaceSize, workSpace, reserveSpaceSize, + reserveSpace); +} + cudnnStatus_t CUDNNWINAPI cudnnRNNBackwardWeightsEx( cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, const cudnnRNNDataDescriptor_t xDesc, const void *x, @@ -2806,6 +2873,26 @@ cudnnStatus_t CUDNNWINAPI cudnnRNNBackwardWeightsEx( reserveSpaceSizeInBytes); } +cudnnStatus_t CUDNNWINAPI cudnnRNNBackwardWeights_v8( + cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, + cudnnWgradMode_t addGrad, const int32_t devSeqLengths[], + cudnnRNNDataDescriptor_t xDesc, const void *x, + cudnnTensorDescriptor_t hDesc, const void *hx, + cudnnRNNDataDescriptor_t yDesc, const void *y, size_t weightSpaceSize, + void *dweightSpace, size_t workSpaceSize, void *workSpace, + size_t reserveSpaceSize, void *reserveSpace) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnRNNDescriptor_t, cudnnWgradMode_t, const int32_t[], + cudnnRNNDataDescriptor_t, const void *, cudnnTensorDescriptor_t, + const void *, cudnnRNNDataDescriptor_t, const void *, size_t, void *, + size_t, void *, size_t, void *); + static auto func_ptr = LoadSymbol("cudnnRNNBackwardWeights_v8"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, rnnDesc, addGrad, devSeqLengths, xDesc, x, hDesc, hx, + yDesc, y, weightSpaceSize, dweightSpace, workSpaceSize, + workSpace, reserveSpaceSize, reserveSpace); +} + cudnnStatus_t CUDNNWINAPI cudnnMultiHeadAttnBackwardData( cudnnHandle_t handle, const cudnnAttnDescriptor_t attnDesc, const int loWinIdx[], const int hiWinIdx[], const int devSeqLengthsDQDO[], diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h index 920f5fe246c..6ca42340d5b 100644 --- a/tensorflow/stream_executor/dnn.h +++ b/tensorflow/stream_executor/dnn.h @@ -2185,6 +2185,7 @@ class DnnSupport { virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -2206,6 +2207,7 @@ class DnnSupport { virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -2227,6 +2229,7 @@ class DnnSupport { virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -2289,6 +2292,7 @@ class DnnSupport { Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -2317,6 +2321,7 @@ class DnnSupport { Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -2345,6 +2350,7 @@ class DnnSupport { Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, diff --git a/tensorflow/stream_executor/rocm/rocm_dnn.cc b/tensorflow/stream_executor/rocm/rocm_dnn.cc index 8c1596331f3..2e0a865e41e 100644 --- a/tensorflow/stream_executor/rocm/rocm_dnn.cc +++ b/tensorflow/stream_executor/rocm/rocm_dnn.cc @@ -2578,6 +2578,7 @@ bool MIOpenSupport::DoRnnForward( Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -2621,6 +2622,7 @@ bool MIOpenSupport::DoRnnForward( Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -2663,6 +2665,7 @@ bool MIOpenSupport::DoRnnForward( Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -2685,6 +2688,7 @@ bool MIOpenSupport::DoRnnBackward( Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -2737,6 +2741,7 @@ bool MIOpenSupport::DoRnnBackward( Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -2788,6 +2793,7 @@ bool MIOpenSupport::DoRnnBackward( Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, diff --git a/tensorflow/stream_executor/rocm/rocm_dnn.h b/tensorflow/stream_executor/rocm/rocm_dnn.h index 654a1bf8f3a..11f1a1dd86d 100644 --- a/tensorflow/stream_executor/rocm/rocm_dnn.h +++ b/tensorflow/stream_executor/rocm/rocm_dnn.h @@ -101,6 +101,7 @@ class MIOpenSupport : public dnn::DnnSupport { bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -119,6 +120,7 @@ class MIOpenSupport : public dnn::DnnSupport { bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -137,6 +139,7 @@ class MIOpenSupport : public dnn::DnnSupport { bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -155,6 +158,7 @@ class MIOpenSupport : public dnn::DnnSupport { bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -180,6 +184,7 @@ class MIOpenSupport : public dnn::DnnSupport { bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -205,6 +210,7 @@ class MIOpenSupport : public dnn::DnnSupport { bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc index 4ad9fc128cc..ccdb467a03d 100644 --- a/tensorflow/stream_executor/stream.cc +++ b/tensorflow/stream_executor/stream.cc @@ -4539,6 +4539,7 @@ Stream &Stream::ThenRnnForward( const dnn::RnnDescriptor &rnn_desc, const dnn::RnnSequenceTensorDescriptor &input_desc, const DeviceMemory &input_data, + const DeviceMemory &seq_lengths_data, const dnn::RnnStateTensorDescriptor &input_h_desc, const DeviceMemory &input_h_data, const dnn::RnnStateTensorDescriptor &input_c_desc, @@ -4556,10 +4557,11 @@ Stream &Stream::ThenRnnForward( // TODO(zhengxq): add VLOG PARAM calls. if (dnn::DnnSupport *dnn = parent_->AsDnn()) { auto status = dnn->DoRnnForward( - this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data, - input_c_desc, input_c_data, params, output_desc, output_data, - output_h_desc, output_h_data, output_c_desc, output_c_data, is_training, - reserve_space_allocator, workspace_allocator, output_profile_result); + this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc, + input_h_data, input_c_desc, input_c_data, params, output_desc, + output_data, output_h_desc, output_h_data, output_c_desc, output_c_data, + is_training, reserve_space_allocator, workspace_allocator, + output_profile_result); if (!status && !output_profile_result) { SetError(); } @@ -4573,6 +4575,7 @@ Stream &Stream::ThenRnnForward( const dnn::RnnDescriptor &rnn_desc, const dnn::RnnSequenceTensorDescriptor &input_desc, const DeviceMemory &input_data, + const DeviceMemory &seq_lengths_data, const dnn::RnnStateTensorDescriptor &input_h_desc, const DeviceMemory &input_h_data, const dnn::RnnStateTensorDescriptor &input_c_desc, @@ -4589,10 +4592,11 @@ Stream &Stream::ThenRnnForward( // TODO(zhengxq): add VLOG PARAM calls. if (dnn::DnnSupport *dnn = parent_->AsDnn()) { auto status = dnn->DoRnnForward( - this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data, - input_c_desc, input_c_data, params, output_desc, output_data, - output_h_desc, output_h_data, output_c_desc, output_c_data, is_training, - reserve_space_allocator, workspace_allocator, output_profile_result); + this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc, + input_h_data, input_c_desc, input_c_data, params, output_desc, + output_data, output_h_desc, output_h_data, output_c_desc, output_c_data, + is_training, reserve_space_allocator, workspace_allocator, + output_profile_result); if (!status && !output_profile_result) { SetError(); } @@ -4606,6 +4610,7 @@ Stream &Stream::ThenRnnForward( const dnn::RnnDescriptor &rnn_desc, const dnn::RnnSequenceTensorDescriptor &input_desc, const DeviceMemory &input_data, + const DeviceMemory &seq_lengths_data, const dnn::RnnStateTensorDescriptor &input_h_desc, const DeviceMemory &input_h_data, const dnn::RnnStateTensorDescriptor &input_c_desc, @@ -4623,10 +4628,11 @@ Stream &Stream::ThenRnnForward( // TODO(zhengxq): add VLOG PARAM calls. if (dnn::DnnSupport *dnn = parent_->AsDnn()) { auto status = dnn->DoRnnForward( - this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data, - input_c_desc, input_c_data, params, output_desc, output_data, - output_h_desc, output_h_data, output_c_desc, output_c_data, is_training, - reserve_space_allocator, workspace_allocator, output_profile_result); + this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc, + input_h_data, input_c_desc, input_c_data, params, output_desc, + output_data, output_h_desc, output_h_data, output_c_desc, output_c_data, + is_training, reserve_space_allocator, workspace_allocator, + output_profile_result); if (!status && !output_profile_result) { SetError(); } @@ -4640,6 +4646,7 @@ Stream &Stream::ThenRnnBackward( const dnn::RnnDescriptor &rnn_desc, const dnn::RnnSequenceTensorDescriptor &input_desc, const DeviceMemory &input_data, + const DeviceMemory &seq_lengths_data, const dnn::RnnStateTensorDescriptor &input_h_desc, const DeviceMemory &input_h_data, const dnn::RnnStateTensorDescriptor &input_c_desc, @@ -4664,9 +4671,9 @@ Stream &Stream::ThenRnnBackward( // TODO(zhengxq): add VLOG PARAM calls. if (dnn::DnnSupport *dnn = parent_->AsDnn()) { auto status = dnn->DoRnnBackward( - this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data, - input_c_desc, input_c_data, params, output_desc, output_data, - output_h_desc, output_h_data, output_c_desc, output_c_data, + this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc, + input_h_data, input_c_desc, input_c_data, params, output_desc, + output_data, output_h_desc, output_h_data, output_c_desc, output_c_data, output_backprop_data, output_h_backprop_data, output_c_backprop_data, input_backprop_data, input_h_backprop_data, input_c_backprop_data, params_backprop_data, reserve_space_data, workspace_allocator, @@ -4685,6 +4692,7 @@ Stream &Stream::ThenRnnBackward( const dnn::RnnDescriptor &rnn_desc, const dnn::RnnSequenceTensorDescriptor &input_desc, const DeviceMemory &input_data, + const DeviceMemory &seq_lengths_data, const dnn::RnnStateTensorDescriptor &input_h_desc, const DeviceMemory &input_h_data, const dnn::RnnStateTensorDescriptor &input_c_desc, @@ -4708,9 +4716,9 @@ Stream &Stream::ThenRnnBackward( // TODO(zhengxq): add VLOG PARAM calls. if (dnn::DnnSupport *dnn = parent_->AsDnn()) { auto status = dnn->DoRnnBackward( - this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data, - input_c_desc, input_c_data, params, output_desc, output_data, - output_h_desc, output_h_data, output_c_desc, output_c_data, + this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc, + input_h_data, input_c_desc, input_c_data, params, output_desc, + output_data, output_h_desc, output_h_data, output_c_desc, output_c_data, output_backprop_data, output_h_backprop_data, output_c_backprop_data, input_backprop_data, input_h_backprop_data, input_c_backprop_data, params_backprop_data, reserve_space_data, workspace_allocator, @@ -4729,6 +4737,7 @@ Stream &Stream::ThenRnnBackward( const dnn::RnnDescriptor &rnn_desc, const dnn::RnnSequenceTensorDescriptor &input_desc, const DeviceMemory &input_data, + const DeviceMemory &seq_lengths_data, const dnn::RnnStateTensorDescriptor &input_h_desc, const DeviceMemory &input_h_data, const dnn::RnnStateTensorDescriptor &input_c_desc, @@ -4753,9 +4762,9 @@ Stream &Stream::ThenRnnBackward( // TODO(zhengxq): add VLOG PARAM calls. if (dnn::DnnSupport *dnn = parent_->AsDnn()) { auto status = dnn->DoRnnBackward( - this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data, - input_c_desc, input_c_data, params, output_desc, output_data, - output_h_desc, output_h_data, output_c_desc, output_c_data, + this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc, + input_h_data, input_c_desc, input_c_data, params, output_desc, + output_data, output_h_desc, output_h_data, output_c_desc, output_c_data, output_backprop_data, output_h_backprop_data, output_c_backprop_data, input_backprop_data, input_h_backprop_data, input_c_backprop_data, params_backprop_data, reserve_space_data, workspace_allocator, diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h index cb038c9ee67..e214ee47513 100644 --- a/tensorflow/stream_executor/stream.h +++ b/tensorflow/stream_executor/stream.h @@ -1779,6 +1779,7 @@ class Stream { Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc, const dnn::RnnSequenceTensorDescriptor &input_desc, const DeviceMemory &input_data, + const DeviceMemory &seq_lengths_data, const dnn::RnnStateTensorDescriptor &input_h_desc, const DeviceMemory &input_h_data, const dnn::RnnStateTensorDescriptor &input_c_desc, @@ -1798,6 +1799,7 @@ class Stream { Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc, const dnn::RnnSequenceTensorDescriptor &input_desc, const DeviceMemory &input_data, + const DeviceMemory &seq_lengths_data, const dnn::RnnStateTensorDescriptor &input_h_desc, const DeviceMemory &input_h_data, const dnn::RnnStateTensorDescriptor &input_c_desc, @@ -1816,6 +1818,7 @@ class Stream { Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc, const dnn::RnnSequenceTensorDescriptor &input_desc, const DeviceMemory &input_data, + const DeviceMemory &seq_lengths_data, const dnn::RnnStateTensorDescriptor &input_h_desc, const DeviceMemory &input_h_data, const dnn::RnnStateTensorDescriptor &input_c_desc, @@ -1837,6 +1840,7 @@ class Stream { const dnn::RnnDescriptor &rnn_desc, const dnn::RnnSequenceTensorDescriptor &input_desc, const DeviceMemory &input_data, + const DeviceMemory &seq_lengths_data, const dnn::RnnStateTensorDescriptor &input_h_desc, const DeviceMemory &input_h_data, const dnn::RnnStateTensorDescriptor &input_c_desc, @@ -1862,6 +1866,7 @@ class Stream { Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc, const dnn::RnnSequenceTensorDescriptor &input_desc, const DeviceMemory &input_data, + const DeviceMemory &seq_lengths_data, const dnn::RnnStateTensorDescriptor &input_h_desc, const DeviceMemory &input_h_data, const dnn::RnnStateTensorDescriptor &input_c_desc, @@ -1887,6 +1892,7 @@ class Stream { Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc, const dnn::RnnSequenceTensorDescriptor &input_desc, const DeviceMemory &input_data, + const DeviceMemory &seq_lengths_data, const dnn::RnnStateTensorDescriptor &input_h_desc, const DeviceMemory &input_h_data, const dnn::RnnStateTensorDescriptor &input_c_desc, diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 46ea618b8e9..3e1da5984c6 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -19,10 +19,6 @@ load( "@local_config_tensorrt//:build_defs.bzl", "if_tensorrt", ) -load( - "//tensorflow/core/platform/default:cuda_build_defs.bzl", - "if_cuda_is_configured", -) load( "@local_config_cuda//cuda:build_defs.bzl", "cuda_library", @@ -101,21 +97,6 @@ def if_nvcc(a): "//conditions:default": [], }) -# In Google builds, this corresponds to whether `--config=cuda` has been -# specified. In OSS, this corresponds to whether the environment contains -# TF_NEED_CUDA=1, which is in turn triggered by --config=using_cuda through -# .bazelrc, which is again triggered by --config=cuda. -# -# In other words, --config=cuda is sufficient for this function to return -# x both for Google and OSS builds. But for OSS builds it is not necessary. -# We are working on a plan to clean up this complicated setup. -def if_cuda_is_configured_compat(x): - # copybara:uncomment_begin(--config=cuda is necessary and sufficient) - # return if_cuda(x) - # copybara:uncomment_end_and_comment_begin - return if_cuda_is_configured(x) - # copybara:comment_end - def if_xla_available(if_true, if_false = []): return select({ clean_dep("//tensorflow:with_xla_support"): if_true, @@ -379,7 +360,7 @@ def tf_copts( if_libtpu(["-DLIBTPU_ON_GCE"], []) + if_xla_available(["-DTENSORFLOW_USE_XLA=1"]) + if_tensorrt(["-DGOOGLE_TENSORRT=1"]) + - if_mkl(["-DINTEL_MKL=1", "-DENABLE_MKLDNN_V1", "-DENABLE_INTEL_MKL_BFLOAT16", "-DINTEL_MKL_DNN_ONLY"]) + + if_mkl(["-DINTEL_MKL=1"]) + if_mkldnn_threadpool(["-DENABLE_MKLDNN_THREADPOOL"]) + if_enable_mkl(["-DENABLE_MKL"]) + if_android_arm(["-mfpu=neon"]) + @@ -1183,9 +1164,7 @@ def tf_gpu_cc_test( }), suffix = "_gpu", tags = tags + tf_gpu_tests_tags(), - deps = deps + if_cuda_is_configured([ - clean_dep("//tensorflow/core:gpu_runtime"), - ]) + if_rocm_is_configured([ + deps = deps + if_cuda_or_rocm([ clean_dep("//tensorflow/core:gpu_runtime"), ]), ) @@ -1211,9 +1190,7 @@ def tf_gpu_cc_test( }), suffix = "_2gpu", tags = cleaned_tags, - deps = deps + if_cuda_is_configured([ - clean_dep("//tensorflow/core:gpu_runtime"), - ]) + if_rocm_is_configured([ + deps = deps + if_cuda_or_rocm([ clean_dep("//tensorflow/core:gpu_runtime"), ]), ) @@ -1396,14 +1373,14 @@ def _cuda_copts(opts = []): """ return select({ "//conditions:default": [], - "@local_config_cuda//cuda:using_nvcc": ([ + "@local_config_cuda//cuda:using_nvcc": [ "-nvcc_options=relaxed-constexpr", "-nvcc_options=ftz=true", - ]), - "@local_config_cuda//cuda:using_clang": ([ + ] + opts, + "@local_config_cuda//cuda:using_clang": [ "-fcuda-flush-denormals-to-zero", - ]), - }) + if_cuda_is_configured_compat(opts) + ] + opts, + }) # Build defs for TensorFlow kernels @@ -1428,10 +1405,9 @@ def tf_gpu_kernel_library( srcs = srcs, hdrs = hdrs, copts = copts, - deps = deps + if_cuda_is_configured_compat([ + deps = deps + if_cuda([ clean_dep("//tensorflow/stream_executor/cuda:cudart_stub"), - clean_dep("//tensorflow/core:gpu_lib"), - ]) + if_rocm_is_configured([ + ]) + if_cuda_or_rocm([ clean_dep("//tensorflow/core:gpu_lib"), ]), alwayslink = 1, @@ -1441,32 +1417,33 @@ def tf_gpu_kernel_library( def tf_gpu_library(deps = None, cuda_deps = None, copts = tf_copts(), **kwargs): """Generate a cc_library with a conditional set of CUDA dependencies. - When the library is built with --config=cuda: + When the library is built with --config=cuda: - - Both deps and cuda_deps are used as dependencies. - - The cuda runtime is added as a dependency (if necessary). - - The library additionally passes -DGOOGLE_CUDA=1 to the list of copts. - - In addition, when the library is also built with TensorRT enabled, it - additionally passes -DGOOGLE_TENSORRT=1 to the list of copts. + - Both deps and cuda_deps are used as dependencies. + - The cuda runtime is added as a dependency (if necessary). + - The library additionally passes -DGOOGLE_CUDA=1 to the list of copts. + - In addition, when the library is also built with TensorRT enabled, it + additionally passes -DGOOGLE_TENSORRT=1 to the list of copts. - Args: - - cuda_deps: BUILD dependencies which will be linked if and only if: - '--config=cuda' is passed to the bazel command line. - - deps: dependencies which will always be linked. - - copts: copts always passed to the cc_library. - - kwargs: Any other argument to cc_library. - """ + Args: + cuda_deps: BUILD dependencies which will be linked if and only if: + '--config=cuda' is passed to the bazel command line. + deps: dependencies which will always be linked. + copts: copts always passed to the cc_library. + **kwargs: Any other argument to cc_library. + """ if not deps: deps = [] if not cuda_deps: cuda_deps = [] kwargs["features"] = kwargs.get("features", []) + ["-use_header_modules"] + deps = deps + if_cuda_or_rocm(cuda_deps) cc_library( - deps = deps + if_cuda_is_configured_compat(cuda_deps + [ + deps = deps + if_cuda([ clean_dep("//tensorflow/stream_executor/cuda:cudart_stub"), "@local_config_cuda//cuda:cuda_headers", - ]) + if_rocm_is_configured(cuda_deps + [ + ]) + if_rocm_is_configured([ "@local_config_rocm//rocm:rocm_headers", ]), copts = (copts + if_cuda(["-DGOOGLE_CUDA=1"]) + if_rocm(["-DTENSORFLOW_USE_ROCM=1"]) + if_xla_available(["-DTENSORFLOW_USE_XLA=1"]) + if_mkl(["-DINTEL_MKL=1"]) + if_mkl_open_source_only(["-DINTEL_MKL_DNN_ONLY"]) + if_enable_mkl(["-DENABLE_MKL"]) + if_tensorrt(["-DGOOGLE_TENSORRT=1"])), @@ -1862,15 +1839,12 @@ check_deps = rule( def tf_custom_op_library(name, srcs = [], gpu_srcs = [], deps = [], linkopts = [], copts = [], **kwargs): """Helper to build a dynamic library (.so) from the sources containing implementations of custom ops and kernels. """ - cuda_deps = [ + deps = deps + if_cuda_or_rocm([ clean_dep("//tensorflow/core:stream_executor_headers_lib"), + ]) + if_cuda([ "@local_config_cuda//cuda:cuda_headers", "@local_config_cuda//cuda:cudart_static", - ] - rocm_deps = [ - clean_dep("//tensorflow/core:stream_executor_headers_lib"), - ] - deps = deps + tf_custom_op_library_additional_deps() + ]) + tf_custom_op_library_additional_deps() # Override EIGEN_STRONG_INLINE to inline when # --define=override_eigen_strong_inline=true to avoid long compiling time. @@ -1884,11 +1858,10 @@ def tf_custom_op_library(name, srcs = [], gpu_srcs = [], deps = [], linkopts = [ srcs = gpu_srcs, copts = copts + tf_copts() + _cuda_copts() + rocm_copts() + if_tensorrt(["-DGOOGLE_TENSORRT=1"]), - deps = deps + if_cuda_is_configured_compat(cuda_deps) + if_rocm_is_configured(rocm_deps), + deps = deps, **kwargs ) - cuda_deps.extend([":" + basename + "_gpu"]) - rocm_deps.extend([":" + basename + "_gpu"]) + deps = deps + [":" + basename + "_gpu"] check_deps( name = name + "_check_deps", @@ -1896,12 +1869,12 @@ def tf_custom_op_library(name, srcs = [], gpu_srcs = [], deps = [], linkopts = [ clean_dep("//tensorflow/core:framework"), clean_dep("//tensorflow/core:lib"), ], - deps = deps + if_cuda_is_configured_compat(cuda_deps) + if_rocm_is_configured(rocm_deps), + deps = deps, ) tf_cc_shared_object( name = name, srcs = srcs, - deps = deps + if_cuda_is_configured_compat(cuda_deps) + if_rocm_is_configured(rocm_deps), + deps = deps, data = if_static([name + "_check_deps"]), copts = copts + tf_copts(is_external = True), features = ["windows_export_all_symbols"], diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-op-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-op-error.pbtxt index 7e59615534f..c0add5ddd32 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-op-error.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-op-error.pbtxt @@ -10,6 +10,10 @@ tf_class { name: "error_code" mtype: "" } + member { + name: "experimental_payloads" + mtype: "" + } member { name: "message" mtype: "" @@ -24,6 +28,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'node_def\', \'op\', \'message\', \'error_code\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\', \'error_code\'], varargs=args, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.errors.-aborted-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.errors.-aborted-error.pbtxt index ea9186b0b9d..aa5841341bc 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.errors.-aborted-error.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.errors.-aborted-error.pbtxt @@ -11,6 +11,10 @@ tf_class { name: "error_code" mtype: "" } + member { + name: "experimental_payloads" + mtype: "" + } member { name: "message" mtype: "" @@ -25,6 +29,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=args, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.errors.-already-exists-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.errors.-already-exists-error.pbtxt index 4e155081dd2..4c26ce8f539 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.errors.-already-exists-error.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.errors.-already-exists-error.pbtxt @@ -11,6 +11,10 @@ tf_class { name: "error_code" mtype: "" } + member { + name: "experimental_payloads" + mtype: "" + } member { name: "message" mtype: "" @@ -25,6 +29,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=args, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.errors.-cancelled-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.errors.-cancelled-error.pbtxt index b02a0e023aa..3ccf29652bc 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.errors.-cancelled-error.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.errors.-cancelled-error.pbtxt @@ -11,6 +11,10 @@ tf_class { name: "error_code" mtype: "" } + member { + name: "experimental_payloads" + mtype: "" + } member { name: "message" mtype: "" @@ -25,6 +29,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=args, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.errors.-data-loss-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.errors.-data-loss-error.pbtxt index c1fa66342a7..8129457e243 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.errors.-data-loss-error.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.errors.-data-loss-error.pbtxt @@ -11,6 +11,10 @@ tf_class { name: "error_code" mtype: "" } + member { + name: "experimental_payloads" + mtype: "" + } member { name: "message" mtype: "" @@ -25,6 +29,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=args, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.errors.-deadline-exceeded-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.errors.-deadline-exceeded-error.pbtxt index 8e037936191..1317add50a0 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.errors.-deadline-exceeded-error.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.errors.-deadline-exceeded-error.pbtxt @@ -11,6 +11,10 @@ tf_class { name: "error_code" mtype: "" } + member { + name: "experimental_payloads" + mtype: "" + } member { name: "message" mtype: "" @@ -25,6 +29,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=args, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.errors.-failed-precondition-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.errors.-failed-precondition-error.pbtxt index 384d4b534c6..197b911f840 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.errors.-failed-precondition-error.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.errors.-failed-precondition-error.pbtxt @@ -11,6 +11,10 @@ tf_class { name: "error_code" mtype: "" } + member { + name: "experimental_payloads" + mtype: "" + } member { name: "message" mtype: "" @@ -25,6 +29,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=args, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.errors.-internal-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.errors.-internal-error.pbtxt index ac5c4d7879b..1486e29c995 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.errors.-internal-error.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.errors.-internal-error.pbtxt @@ -11,6 +11,10 @@ tf_class { name: "error_code" mtype: "" } + member { + name: "experimental_payloads" + mtype: "" + } member { name: "message" mtype: "" @@ -25,6 +29,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=args, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.errors.-invalid-argument-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.errors.-invalid-argument-error.pbtxt index 161edd4a7c5..299471b5872 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.errors.-invalid-argument-error.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.errors.-invalid-argument-error.pbtxt @@ -11,6 +11,10 @@ tf_class { name: "error_code" mtype: "" } + member { + name: "experimental_payloads" + mtype: "" + } member { name: "message" mtype: "" @@ -25,6 +29,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=args, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.errors.-not-found-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.errors.-not-found-error.pbtxt index 1e64730ac6d..2b33e4eaab7 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.errors.-not-found-error.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.errors.-not-found-error.pbtxt @@ -11,6 +11,10 @@ tf_class { name: "error_code" mtype: "" } + member { + name: "experimental_payloads" + mtype: "" + } member { name: "message" mtype: "" @@ -25,6 +29,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=args, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.errors.-op-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.errors.-op-error.pbtxt index b1f14c0457d..1a677bd0f60 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.errors.-op-error.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.errors.-op-error.pbtxt @@ -10,6 +10,10 @@ tf_class { name: "error_code" mtype: "" } + member { + name: "experimental_payloads" + mtype: "" + } member { name: "message" mtype: "" @@ -24,6 +28,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'node_def\', \'op\', \'message\', \'error_code\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\', \'error_code\'], varargs=args, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.errors.-out-of-range-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.errors.-out-of-range-error.pbtxt index 6365e472868..32e532bb6bb 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.errors.-out-of-range-error.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.errors.-out-of-range-error.pbtxt @@ -11,6 +11,10 @@ tf_class { name: "error_code" mtype: "" } + member { + name: "experimental_payloads" + mtype: "" + } member { name: "message" mtype: "" @@ -25,6 +29,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=args, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.errors.-permission-denied-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.errors.-permission-denied-error.pbtxt index dc8a66f9ead..7cec82752d0 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.errors.-permission-denied-error.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.errors.-permission-denied-error.pbtxt @@ -11,6 +11,10 @@ tf_class { name: "error_code" mtype: "" } + member { + name: "experimental_payloads" + mtype: "" + } member { name: "message" mtype: "" @@ -25,6 +29,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=args, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.errors.-resource-exhausted-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.errors.-resource-exhausted-error.pbtxt index 85bb384b469..e26ab44bcb2 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.errors.-resource-exhausted-error.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.errors.-resource-exhausted-error.pbtxt @@ -11,6 +11,10 @@ tf_class { name: "error_code" mtype: "" } + member { + name: "experimental_payloads" + mtype: "" + } member { name: "message" mtype: "" @@ -25,6 +29,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=args, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.errors.-unauthenticated-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.errors.-unauthenticated-error.pbtxt index d57d7ac2f20..b0f09a2572b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.errors.-unauthenticated-error.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.errors.-unauthenticated-error.pbtxt @@ -11,6 +11,10 @@ tf_class { name: "error_code" mtype: "" } + member { + name: "experimental_payloads" + mtype: "" + } member { name: "message" mtype: "" @@ -25,6 +29,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=args, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.errors.-unavailable-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.errors.-unavailable-error.pbtxt index cc33e6ed8d1..8a2e01b8ba8 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.errors.-unavailable-error.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.errors.-unavailable-error.pbtxt @@ -11,6 +11,10 @@ tf_class { name: "error_code" mtype: "" } + member { + name: "experimental_payloads" + mtype: "" + } member { name: "message" mtype: "" @@ -25,6 +29,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=args, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.errors.-unimplemented-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.errors.-unimplemented-error.pbtxt index b8c2e22dbd7..01bb5af6967 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.errors.-unimplemented-error.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.errors.-unimplemented-error.pbtxt @@ -11,6 +11,10 @@ tf_class { name: "error_code" mtype: "" } + member { + name: "experimental_payloads" + mtype: "" + } member { name: "message" mtype: "" @@ -25,6 +29,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=args, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.errors.-unknown-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.errors.-unknown-error.pbtxt index 8ffcfae95b8..4c6d97d3b76 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.errors.-unknown-error.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.errors.-unknown-error.pbtxt @@ -11,6 +11,10 @@ tf_class { name: "error_code" mtype: "" } + member { + name: "experimental_payloads" + mtype: "" + } member { name: "message" mtype: "" @@ -25,6 +29,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'node_def\', \'op\', \'message\', \'error_code\'], varargs=None, keywords=None, defaults=[\'2\'], " + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=args, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-generator-enqueuer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-generator-enqueuer.pbtxt index 6f5ad2dc963..c253afe559a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-generator-enqueuer.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-generator-enqueuer.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'sequence\', \'use_multiprocessing\', \'random_seed\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + argspec: "args=[\'self\', \'generator\', \'use_multiprocessing\', \'random_seed\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " } member_method { name: "get" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.graph_util.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.graph_util.pbtxt new file mode 100644 index 00000000000..acd92245c11 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.graph_util.pbtxt @@ -0,0 +1,7 @@ +path: "tensorflow.__internal__.graph_util" +tf_module { + member_method { + name: "graph_defs_equal" + argspec: "args=[\'graph_def_1\', \'graph_def_2\', \'treat_nan_as_equal\'], varargs=None, keywords=None, defaults=[\'False\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.pbtxt index 523a80614b4..bca7676cd6e 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.pbtxt @@ -36,6 +36,10 @@ tf_module { name: "function" mtype: "" } + member { + name: "graph_util" + mtype: "" + } member { name: "monitoring" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.tracking.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.tracking.pbtxt index 6014c72d730..223a2472e93 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.tracking.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.tracking.pbtxt @@ -8,4 +8,8 @@ tf_module { name: "Trackable" mtype: "" } + member_method { + name: "wrap" + argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None" + } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-input-options.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-input-options.pbtxt index bce836e7c78..70f800c6e4c 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-input-options.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-input-options.pbtxt @@ -3,6 +3,10 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + member { + name: "experimental_fetch_to_device" + mtype: "" + } member { name: "experimental_per_replica_buffer_size" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.errors.-aborted-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.errors.-aborted-error.pbtxt index ea9186b0b9d..aa5841341bc 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.errors.-aborted-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.errors.-aborted-error.pbtxt @@ -11,6 +11,10 @@ tf_class { name: "error_code" mtype: "" } + member { + name: "experimental_payloads" + mtype: "" + } member { name: "message" mtype: "" @@ -25,6 +29,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=args, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.errors.-already-exists-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.errors.-already-exists-error.pbtxt index 4e155081dd2..4c26ce8f539 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.errors.-already-exists-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.errors.-already-exists-error.pbtxt @@ -11,6 +11,10 @@ tf_class { name: "error_code" mtype: "" } + member { + name: "experimental_payloads" + mtype: "" + } member { name: "message" mtype: "" @@ -25,6 +29,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=args, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.errors.-cancelled-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.errors.-cancelled-error.pbtxt index b02a0e023aa..3ccf29652bc 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.errors.-cancelled-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.errors.-cancelled-error.pbtxt @@ -11,6 +11,10 @@ tf_class { name: "error_code" mtype: "" } + member { + name: "experimental_payloads" + mtype: "" + } member { name: "message" mtype: "" @@ -25,6 +29,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=args, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.errors.-data-loss-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.errors.-data-loss-error.pbtxt index c1fa66342a7..8129457e243 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.errors.-data-loss-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.errors.-data-loss-error.pbtxt @@ -11,6 +11,10 @@ tf_class { name: "error_code" mtype: "" } + member { + name: "experimental_payloads" + mtype: "" + } member { name: "message" mtype: "" @@ -25,6 +29,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=args, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.errors.-deadline-exceeded-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.errors.-deadline-exceeded-error.pbtxt index 8e037936191..1317add50a0 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.errors.-deadline-exceeded-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.errors.-deadline-exceeded-error.pbtxt @@ -11,6 +11,10 @@ tf_class { name: "error_code" mtype: "" } + member { + name: "experimental_payloads" + mtype: "" + } member { name: "message" mtype: "" @@ -25,6 +29,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=args, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.errors.-failed-precondition-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.errors.-failed-precondition-error.pbtxt index 384d4b534c6..197b911f840 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.errors.-failed-precondition-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.errors.-failed-precondition-error.pbtxt @@ -11,6 +11,10 @@ tf_class { name: "error_code" mtype: "" } + member { + name: "experimental_payloads" + mtype: "" + } member { name: "message" mtype: "" @@ -25,6 +29,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=args, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.errors.-internal-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.errors.-internal-error.pbtxt index ac5c4d7879b..1486e29c995 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.errors.-internal-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.errors.-internal-error.pbtxt @@ -11,6 +11,10 @@ tf_class { name: "error_code" mtype: "" } + member { + name: "experimental_payloads" + mtype: "" + } member { name: "message" mtype: "" @@ -25,6 +29,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=args, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.errors.-invalid-argument-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.errors.-invalid-argument-error.pbtxt index 161edd4a7c5..299471b5872 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.errors.-invalid-argument-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.errors.-invalid-argument-error.pbtxt @@ -11,6 +11,10 @@ tf_class { name: "error_code" mtype: "" } + member { + name: "experimental_payloads" + mtype: "" + } member { name: "message" mtype: "" @@ -25,6 +29,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=args, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.errors.-not-found-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.errors.-not-found-error.pbtxt index 1e64730ac6d..2b33e4eaab7 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.errors.-not-found-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.errors.-not-found-error.pbtxt @@ -11,6 +11,10 @@ tf_class { name: "error_code" mtype: "" } + member { + name: "experimental_payloads" + mtype: "" + } member { name: "message" mtype: "" @@ -25,6 +29,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=args, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.errors.-op-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.errors.-op-error.pbtxt index b1f14c0457d..1a677bd0f60 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.errors.-op-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.errors.-op-error.pbtxt @@ -10,6 +10,10 @@ tf_class { name: "error_code" mtype: "" } + member { + name: "experimental_payloads" + mtype: "" + } member { name: "message" mtype: "" @@ -24,6 +28,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'node_def\', \'op\', \'message\', \'error_code\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\', \'error_code\'], varargs=args, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.errors.-out-of-range-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.errors.-out-of-range-error.pbtxt index 6365e472868..32e532bb6bb 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.errors.-out-of-range-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.errors.-out-of-range-error.pbtxt @@ -11,6 +11,10 @@ tf_class { name: "error_code" mtype: "" } + member { + name: "experimental_payloads" + mtype: "" + } member { name: "message" mtype: "" @@ -25,6 +29,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=args, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.errors.-permission-denied-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.errors.-permission-denied-error.pbtxt index dc8a66f9ead..7cec82752d0 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.errors.-permission-denied-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.errors.-permission-denied-error.pbtxt @@ -11,6 +11,10 @@ tf_class { name: "error_code" mtype: "" } + member { + name: "experimental_payloads" + mtype: "" + } member { name: "message" mtype: "" @@ -25,6 +29,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=args, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.errors.-resource-exhausted-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.errors.-resource-exhausted-error.pbtxt index 85bb384b469..e26ab44bcb2 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.errors.-resource-exhausted-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.errors.-resource-exhausted-error.pbtxt @@ -11,6 +11,10 @@ tf_class { name: "error_code" mtype: "" } + member { + name: "experimental_payloads" + mtype: "" + } member { name: "message" mtype: "" @@ -25,6 +29,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=args, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.errors.-unauthenticated-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.errors.-unauthenticated-error.pbtxt index d57d7ac2f20..b0f09a2572b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.errors.-unauthenticated-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.errors.-unauthenticated-error.pbtxt @@ -11,6 +11,10 @@ tf_class { name: "error_code" mtype: "" } + member { + name: "experimental_payloads" + mtype: "" + } member { name: "message" mtype: "" @@ -25,6 +29,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=args, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.errors.-unavailable-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.errors.-unavailable-error.pbtxt index cc33e6ed8d1..8a2e01b8ba8 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.errors.-unavailable-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.errors.-unavailable-error.pbtxt @@ -11,6 +11,10 @@ tf_class { name: "error_code" mtype: "" } + member { + name: "experimental_payloads" + mtype: "" + } member { name: "message" mtype: "" @@ -25,6 +29,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=args, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.errors.-unimplemented-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.errors.-unimplemented-error.pbtxt index b8c2e22dbd7..01bb5af6967 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.errors.-unimplemented-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.errors.-unimplemented-error.pbtxt @@ -11,6 +11,10 @@ tf_class { name: "error_code" mtype: "" } + member { + name: "experimental_payloads" + mtype: "" + } member { name: "message" mtype: "" @@ -25,6 +29,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=args, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.errors.-unknown-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.errors.-unknown-error.pbtxt index 8ffcfae95b8..4c6d97d3b76 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.errors.-unknown-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.errors.-unknown-error.pbtxt @@ -11,6 +11,10 @@ tf_class { name: "error_code" mtype: "" } + member { + name: "experimental_payloads" + mtype: "" + } member { name: "message" mtype: "" @@ -25,6 +29,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'node_def\', \'op\', \'message\', \'error_code\'], varargs=None, keywords=None, defaults=[\'2\'], " + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=args, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.experimental.tensorrt.-converter.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.experimental.tensorrt.-converter.pbtxt index ec2b641cd53..74c027854f7 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.experimental.tensorrt.-converter.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.experimental.tensorrt.-converter.pbtxt @@ -4,7 +4,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'input_saved_model_dir\', \'input_saved_model_tags\', \'input_saved_model_signature_key\', \'conversion_params\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'input_saved_model_dir\', \'input_saved_model_tags\', \'input_saved_model_signature_key\', \'use_dynamic_shape\', \'dynamic_shape_profile_strategy\', \'conversion_params\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "build" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-generator-enqueuer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-generator-enqueuer.pbtxt index 6f5ad2dc963..c253afe559a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-generator-enqueuer.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-generator-enqueuer.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'sequence\', \'use_multiprocessing\', \'random_seed\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + argspec: "args=[\'self\', \'generator\', \'use_multiprocessing\', \'random_seed\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " } member_method { name: "get" diff --git a/tensorflow/tools/api/lib/python_object_to_proto_visitor.py b/tensorflow/tools/api/lib/python_object_to_proto_visitor.py index 86994248cc5..bebe4daaedc 100644 --- a/tensorflow/tools/api/lib/python_object_to_proto_visitor.py +++ b/tensorflow/tools/api/lib/python_object_to_proto_visitor.py @@ -102,6 +102,8 @@ _NORMALIZE_TYPE[( _NORMALIZE_TYPE['typing.Generic'] = "" # TODO(mdan): Remove once the golden files are generated in Python 3.7. _NORMALIZE_TYPE[""] = 'typing.Union' +# TODO(mdan): Remove once the golden files are generated in Python 3.9. +_NORMALIZE_TYPE[""] = 'typing.Union' if sys.version_info.major == 3 and sys.version_info.minor >= 8: diff --git a/tensorflow/tools/ci_build/install/install_centos_pip_packages.sh b/tensorflow/tools/ci_build/install/install_centos_pip_packages.sh index 0f0f182a1bc..20d06642aed 100755 --- a/tensorflow/tools/ci_build/install/install_centos_pip_packages.sh +++ b/tensorflow/tools/ci_build/install/install_centos_pip_packages.sh @@ -74,7 +74,7 @@ pip3 install py-cpuinfo # pylint tests require the following: pip2 install pylint==1.6.4 -pip3 install pylint==1.6.4 +pip3 install pylint==2.6.2 astroid==2.5 # pycodestyle tests require the following: pip2 install pycodestyle diff --git a/tensorflow/tools/ci_build/install/install_pip_packages.sh b/tensorflow/tools/ci_build/install/install_pip_packages.sh index 9afc778a0e8..a6ef0716ae3 100755 --- a/tensorflow/tools/ci_build/install/install_pip_packages.sh +++ b/tensorflow/tools/ci_build/install/install_pip_packages.sh @@ -86,7 +86,7 @@ pip3 install lazy-object-proxy==1.4.3 # pylint tests require the following version. pylint==1.6.4 hangs erratically, # thus using the updated version of 2.5.3 only for python3 as python2 is EOL # and this version is not available. -pip3 install pylint==2.5.3 +pip3 install pylint==2.6.2 astroid==2.5 # pycodestyle tests require the following: pip3 install pycodestyle diff --git a/tensorflow/tools/ci_build/nightly_release/macos/cpu_py39.sh b/tensorflow/tools/ci_build/nightly_release/macos/cpu_py39.sh index e30d3363b99..4addd598bd6 100644 --- a/tensorflow/tools/ci_build/nightly_release/macos/cpu_py39.sh +++ b/tensorflow/tools/ci_build/nightly_release/macos/cpu_py39.sh @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - -# Warning: as of Jan 20, 2020, MacOS(_EXTERNAL) images do not support Python3.9. set -e set -x @@ -31,7 +29,7 @@ PYENV_ROOT="$(pwd)/pyenv" export PYENV_ROOT export PATH="$PYENV_ROOT/bin:$PATH" eval "$(pyenv init -)" -PY_VERSION=3.9.1 +PY_VERSION=3.9.0 pyenv install -s "${PY_VERSION}" pyenv local "${PY_VERSION}" python --version diff --git a/tensorflow/tools/ci_build/presubmit/ubuntu_16/gpu_py36_full/build.sh b/tensorflow/tools/ci_build/presubmit/ubuntu_16/gpu_py36_full/build.sh index dd461b7578f..0bed0399c7b 100644 --- a/tensorflow/tools/ci_build/presubmit/ubuntu_16/gpu_py36_full/build.sh +++ b/tensorflow/tools/ci_build/presubmit/ubuntu_16/gpu_py36_full/build.sh @@ -89,7 +89,8 @@ function run_build () { --define=with_default_optimizations=true \ --define=framework_shared_object=true \ --define=with_xla_support=true \ - --define=using_cuda_nvcc=true \ + --@local_config_cuda//:enable_cuda \ + --@local_config_cuda//:cuda_compiler=nvcc \ --define=use_fast_cpp_protos=true \ --define=allow_oversize_protos=true \ --define=grpc_no_ares=true \ diff --git a/tensorflow/tools/ci_build/presubmit/ubuntu_16/sanity/build.sh b/tensorflow/tools/ci_build/presubmit/ubuntu_16/sanity/build.sh index cdc6bc644c4..4dd10dc80dc 100644 --- a/tensorflow/tools/ci_build/presubmit/ubuntu_16/sanity/build.sh +++ b/tensorflow/tools/ci_build/presubmit/ubuntu_16/sanity/build.sh @@ -28,7 +28,7 @@ function install_pylint () { # TODO(mihaimaruseac): this is used in the release build in the same way, # maybe extract out to a common? sudo python3.8 -m pip install setuptools --upgrade - sudo python3.8 -m pip install pylint==2.4.4 + sudo python3.8 -m pip install pylint==2.6.2 astroid==2.5 } function run_sanity_checks () { diff --git a/tensorflow/tools/ci_build/pylint_allowlist b/tensorflow/tools/ci_build/pylint_allowlist index e14b4388ab7..b952475cb6a 100644 --- a/tensorflow/tools/ci_build/pylint_allowlist +++ b/tensorflow/tools/ci_build/pylint_allowlist @@ -12,6 +12,7 @@ ^tensorflow/python/autograph/converters/directives_test.py.*\[E1123.*unexpected-keyword-arg ^tensorflow/python/autograph/impl/api_test.py.*\[E0202.*method-hidden ^tensorflow/python/autograph/pyct/static_analysis/activity_test.py.*\[W0611.*unused-import +^tensorflow/python/autograph/pyct/static_analysis/annos.py.*\[E0306.*invalid-repr-returned ^tensorflow/python/client/session.py.*\[E1120.*no-value-for-parameter ^tensorflow/python/compiler/tensorrt/test/base_test.py.*\[E1003.*bad-super-call ^tensorflow/python/compiler/tensorrt/test/biasadd_matmul_test.py.*\[E1003.*bad-super-call @@ -24,25 +25,35 @@ ^tensorflow/python/data/experimental/kernel_tests/data_service_ops_ft_test.py.*\[W0622.*redefined-builtin ^tensorflow/python/data/experimental/kernel_tests/parse_example_dataset_test.py.*\[E1121.*too-many-function-args ^tensorflow/python/data/experimental/ops/optimization.py.*\[E1120.*no-value-for-parameter +^tensorflow/python/data/kernel_tests/iterator_test.py.*\[E1123.*unexpected-keyword-arg ^tensorflow/python/data/kernel_tests/multi_device_iterator_test.py.*\[E1120.*no-value-for-parameter ^tensorflow/python/debug/lib/grpc_debug_server.py.*\[W0622.*redefined-builtin ^tensorflow/python/distribute/combinations_test.py.*\[E1124.*redundant-keyword-arg ^tensorflow/python/distribute/coordinator/cluster_coordinator.py.*\[E0702.*raising-bad-type ^tensorflow/python/distribute/parallel_device/saving.py.*\[E1120.*no-value-for-parameter ^tensorflow/python/distribute/parallel_device/saving.py.*\[E1123.*unexpected-keyword-arg +^tensorflow/python/distribute/parameter_server_strategy_test.py.*\[E1111.*assignment-from-no-return ^tensorflow/python/distribute/strategy_test_lib.py.*\[E1111.*assignment-from-no-return ^tensorflow/python/eager/function_test.py.*\[C0326.*bad-whitespace ^tensorflow/python/eager/pywrap_tfe_test.py.*\[E1121.*too-many-function-args ^tensorflow/python/eager/tensor_test.py.*\[E1120.*no-value-for-parameter ^tensorflow/python/feature_column/feature_column_test\.py.*\[E0110.*abstract-class-instantiated +^tensorflow/python/framework/dtypes.py.*\[E1120.*no-value-for-parameter ^tensorflow/python/framework/dtypes_test.py.*\[W0611.*unused-import ^tensorflow/python/framework/function_test\.py.*\[E1123.*noinline +^tensorflow/python/framework/op_callbacks_test.py.*\[E1102.*not-callable ^tensorflow/python/framework/ops.py.*\[E1102.*not-callable ^tensorflow/python/framework/ops_test.py.*\[E1130.*invalid-unary-operand-type ^tensorflow/python/framework/test_util.py.*\[C0326.*bad-whitespace +^tensorflow/python/framework/test_util_test.py.*\[E1111.*assignment-from-no-return ^tensorflow/python/framework/type_spec_test.py.*\[E1003.*bad-super-call ^tensorflow/python/grappler/tf_optimizer_test.py.*\[E1130.*invalid-unary-operand-type ^tensorflow/python/keras/callbacks\.py.*\[E1133.*not-an-iterable +^tensorflow/python/keras/datasets/boston_housing.py.*\[E1123.*unexpected-keyword-arg +^tensorflow/python/keras/datasets/imdb.py.*\[E1123.*unexpected-keyword-arg +^tensorflow/python/keras/datasets/mnist.py.*\[E1123.*unexpected-keyword-arg +^tensorflow/python/keras/datasets/reuters.py.*\[E1123.*unexpected-keyword-arg +^tensorflow/python/keras/distribute/checkpointing_test.py.*\[E1123.*unexpected-keyword-arg ^tensorflow/python/keras/engine/base_layer.py.*\[E0202.*method-hidden ^tensorflow/python/keras/engine/base_layer.py.*\[E0203.*access-member-before-definition ^tensorflow/python/keras/engine/base_layer.py.*\[E1003.*bad-super-call @@ -65,6 +76,11 @@ ^tensorflow/python/keras/optimizer_v2/nadam.py.*\[E1130.*invalid-unary-operand-type ^tensorflow/python/keras/preprocessing/image\.py.*\[E0240.*Inconsistent method resolution ^tensorflow/python/keras/saving/saved_model/json_utils.py.*\[E0202.*method-hidden +^tensorflow/python/keras/tests/tracking_util_test.py.*\[E1120.*no-value-for-parameter +^tensorflow/python/keras/tests/tracking_util_test.py.*\[E1123.*unexpected-keyword-arg +^tensorflow/python/keras/tests/tracking_util_with_v1_optimizers_test.py.*\[E1120.*no-value-for-parameter +^tensorflow/python/keras/tests/tracking_util_with_v1_optimizers_test.py.*\[E1123.*unexpected-keyword-arg +^tensorflow/python/keras/utils/composite_tensor_support_test.py.*\[E1123.*unexpected-keyword-arg ^tensorflow/python/keras/utils/conv_utils_test.py.*\[E1133.*not-an-iterable ^tensorflow/python/keras/utils/data_utils.py.*\[E1102.*not-callable ^tensorflow/python/keras/utils/generic_utils_test.py.*\[E0102.*function-redefined @@ -83,9 +99,11 @@ ^tensorflow/python/kernel_tests/parse_single_example_op_test.py.*\[E1121.*too-many-function-args ^tensorflow/python/kernel_tests/parsing_ops_test.py.*\[E1121.*too-many-function-args ^tensorflow/python/kernel_tests/partitioned_variables_test.py.*\[E1121.*too-many-function-args +^tensorflow/python/kernel_tests/random/multinomial_op_big_test.py.*\[E1123.*unexpected-keyword-arg ^tensorflow/python/kernel_tests/unicode_encode_op_test.py.*\[E1120.*no-value-for-parameter ^tensorflow/python/kernel_tests/unicode_encode_op_test.py.*\[E1123.*unexpected-keyword-arg ^tensorflow/python/lib/core/bfloat16_test.py.*\[E1121.*too-many-function-args +^tensorflow/python/lib/io/file_io_test.py.*\[E1123.*unexpected-keyword-arg ^tensorflow/python/modules_with_exports.py.*\[W0622.*redefined-builtin ^tensorflow/python/ops/control_flow_grad.py.*\[W0622.*redefined-builtin ^tensorflow/python/ops/distributions/bernoulli.py.*\[E1130.*invalid-unary-operand-type @@ -104,6 +122,7 @@ ^tensorflow/python/ops/linalg/linear_operator_circulant.py.*\[E1102.*not-callable ^tensorflow/python/ops/linalg/linear_operator_householder.py.*\[E1130.*invalid-unary-operand-type ^tensorflow/python/ops/linalg/linear_operator_identity.py.*\[E1102.*not-callable +^tensorflow/python/ops/linalg/linear_operator_test_util.py.*\[E1121.*too-many-function-args ^tensorflow/python/ops/linalg_grad.py.*\[E1130.*invalid-unary-operand-type ^tensorflow/python/ops/math_grad.py.*\[E1130.*invalid-unary-operand-type ^tensorflow/python/ops/math_ops.py.*\[E1130.*invalid-unary-operand-type @@ -124,11 +143,14 @@ ^tensorflow/python/ops/standard_ops.py.*\[W0622.*redefined-builtin ^tensorflow/python/platform/default/_gfile\.py.*\[E0301.*non-iterator ^tensorflow/python/platform/default/_googletest\.py.*\[E0102.*function\salready\sdefined +^tensorflow/python/platform/flags_test.py.*\[E1120.*no-value-for-parameter ^tensorflow/python/platform/gfile\.py.*\[E0301.*non-iterator ^tensorflow/python/platform/googletest.py.*\[E0001.*syntax-error ^tensorflow/python/profiler/profile_context.py.*\[E1111.*assignment-from-no-return ^tensorflow/python/pywrap_tensorflow.py.*\[W0622.*redefined-builtin ^tensorflow/python/saved_model/model_utils/mode_keys_test.py.*\[E1137.*unsupported-assignment-operation +^tensorflow/python/summary/writer/event_file_writer_v2.py.*\[E1111.*assignment-from-no-return +^tensorflow/python/tools/saved_model_cli.py.*\[E1123.*unexpected-keyword-arg ^tensorflow/python/tpu/async_checkpoint.py.*\[C0326.*bad-whitespace ^tensorflow/python/tpu/bfloat16.py.*\[C0326.*bad-whitespace ^tensorflow/python/tpu/datasets.py.*\[C0326.*bad-whitespace @@ -144,7 +166,12 @@ ^tensorflow/python/training/tracking/data_structures_test.py.*\[E1102.*not-callable ^tensorflow/python/training/tracking/data_structures_test.py.*\[E1120.*no-value-for-parameter ^tensorflow/python/training/tracking/data_structures_test.py.*\[E1138.*unsupported-delete-operation +^tensorflow/python/training/tracking/python_state_test.py.*\[E1123.*unexpected-keyword-arg ^tensorflow/python/training/tracking/tracking\.py.*\[E0202.*method-hidden +^tensorflow/python/training/tracking/util_test.py.*\[E1120.*no-value-for-parameter +^tensorflow/python/training/tracking/util_test.py.*\[E1123.*unexpected-keyword-arg +^tensorflow/python/training/tracking/util_with_v1_optimizers_test.py.*\[E1120.*no-value-for-parameter +^tensorflow/python/training/tracking/util_with_v1_optimizers_test.py.*\[E1123.*unexpected-keyword-arg ^tensorflow/python/util/function_utils_test.py.*\[E1120.*no-value-for-parameter ^tensorflow/python/util/tf_stack_test.py.*\[E1121.*too-many-function-args ^tensorflow/security/fuzzing/raggedCountSparseOutput_fuzz.py.*\[C0330.*bad-continuation diff --git a/tensorflow/tools/ci_build/rel/macos/cpu_py39_nonpip.sh b/tensorflow/tools/ci_build/rel/macos/cpu_py39_nonpip.sh index 0ddeec34a66..a0cbe7eced0 100644 --- a/tensorflow/tools/ci_build/rel/macos/cpu_py39_nonpip.sh +++ b/tensorflow/tools/ci_build/rel/macos/cpu_py39_nonpip.sh @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - -# Warning: as of Jan 20, 2020, MacOS(_EXTERNAL) images do not support Python3.9. set -e set -x @@ -32,7 +30,7 @@ PYENV_ROOT="$(pwd)/pyenv" export PYENV_ROOT export PATH="$PYENV_ROOT/bin:$PATH" eval "$(pyenv init -)" -PY_VERSION=3.9.1 +PY_VERSION=3.9.0 pyenv install -s "${PY_VERSION}" pyenv local "${PY_VERSION}" python --version diff --git a/tensorflow/tools/ci_build/rel/macos/cpu_py39_pip.sh b/tensorflow/tools/ci_build/rel/macos/cpu_py39_pip.sh index e8dc4b285d4..76671a0921f 100644 --- a/tensorflow/tools/ci_build/rel/macos/cpu_py39_pip.sh +++ b/tensorflow/tools/ci_build/rel/macos/cpu_py39_pip.sh @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - -# Warning: as of Jan 20, 2020, MacOS(_EXTERNAL) images do not support Python3.9. set -e set -x @@ -31,7 +29,7 @@ PYENV_ROOT="$(pwd)/pyenv" export PYENV_ROOT export PATH="$PYENV_ROOT/bin:$PATH" eval "$(pyenv init -)" -PY_VERSION=3.9.1 +PY_VERSION=3.9.0 pyenv install -s "${PY_VERSION}" pyenv local "${PY_VERSION}" python --version diff --git a/tensorflow/tools/ci_build/rel/ubuntu/sanity.sh b/tensorflow/tools/ci_build/rel/ubuntu/sanity.sh index 0dcd90ec827..fa87d0a3a22 100644 --- a/tensorflow/tools/ci_build/rel/ubuntu/sanity.sh +++ b/tensorflow/tools/ci_build/rel/ubuntu/sanity.sh @@ -25,7 +25,7 @@ sudo python3.8 -m pip install pep8 # Install pylint sudo python3.8 -m pip install setuptools --upgrade -sudo python3.8 -m pip install pylint==2.4.4 +sudo python3.8 -m pip install pylint==2.6.2 astroid==2.5 python3.8 -m pylint --version # Run tensorflow sanity checks. diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/cpu_libtensorflow.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/cpu_libtensorflow.sh new file mode 100644 index 00000000000..66adddf96ff --- /dev/null +++ b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/cpu_libtensorflow.sh @@ -0,0 +1,52 @@ +#!/bin/bash +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +set -e + +# Source the external common scripts. +source tensorflow/tools/ci_build/release/common.sh + + +# Install latest bazel +install_bazelisk +which bazel + +# Install realpath +sudo apt-get install realpath + +# Update the version string to nightly +if [ -n "${IS_NIGHTLY}" ]; then + ./tensorflow/tools/ci_build/update_version.py --nightly +fi + +./tensorflow/tools/ci_build/linux/libtensorflow.sh + +# Copy the nightly version update script +if [ -n "${IS_NIGHTLY}" ]; then + cp tensorflow/tools/ci_build/builds/libtensorflow_nightly_symlink.sh lib_package + + echo "This package was built on $(date)" >> lib_package/build_time.txt + + tar -zcvf ubuntu_cpu_libtensorflow_binaries.tar.gz lib_package + + gsutil cp ubuntu_cpu_libtensorflow_binaries.tar.gz gs://libtensorflow-nightly/prod/tensorflow/release/ubuntu_16/latest/cpu +fi + +# Upload to go/tf-sizetracker +python3 ./tensorflow/tools/ci_build/sizetrack_helper.py \ + --team tensorflow_libtensorflow \ + --artifact_id ubuntu_cpu_nightly \ + --upload \ + --artifact "$(find lib_package -iname "libtensorflow*.tar.gz" -not -iname "*jni*" | head -n 1)" diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/cpu_py36_nonpip.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/cpu_py36_nonpip.sh new file mode 100644 index 00000000000..abf79c17246 --- /dev/null +++ b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/cpu_py36_nonpip.sh @@ -0,0 +1,49 @@ +#!/bin/bash +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +set -e +set -x + +source tensorflow/tools/ci_build/release/common.sh + +install_ubuntu_16_python_pip_deps python3.6 +# Update bazel +install_bazelisk + +# Run configure. +export TF_NEED_GCP=1 +export TF_NEED_HDFS=1 +export TF_NEED_S3=1 +export TF_NEED_CUDA=0 +export CC_OPT_FLAGS='-mavx -march=native' +export PYTHON_BIN_PATH=$(which python3.6) +export TF2_BEHAVIOR=1 +yes "" | "$PYTHON_BIN_PATH" configure.py +tag_filters="-no_oss,-oss_serial,-gpu,-tpu,-benchmark-test,-no_oss_py36,-v1only" + +# Get the default test targets for bazel. +source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh + +# Run tests +set +e +bazel test --test_output=errors --config=opt --test_lang_filters=py \ + --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain \ + --linkopt=-lrt \ + --action_env=TF2_BEHAVIOR="${TF2_BEHAVIOR}" \ + --local_test_jobs=8 \ + --build_tag_filters="${tag_filters}" \ + --test_tag_filters="${tag_filters}" -- \ + ${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... +test_xml_summary_exit diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/cpu_py36_pip.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/cpu_py36_pip.sh new file mode 100644 index 00000000000..fdade918558 --- /dev/null +++ b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/cpu_py36_pip.sh @@ -0,0 +1,47 @@ +#!/bin/bash +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +set -e +set -x + +source tensorflow/tools/ci_build/release/common.sh + +install_ubuntu_16_python_pip_deps python3.6 +# Update bazel +install_bazelisk + +# Export required variables for running pip.sh +export OS_TYPE="UBUNTU" +export CONTAINER_TYPE="CPU" +export TF_PYTHON_VERSION='python3.6' + +# Run configure. +export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION}) +yes "" | "$PYTHON_BIN_PATH" configure.py + +# Get the default test targets for bazel. +source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh + +# Export optional variables for running pip.sh +export TF_BUILD_FLAGS="--config=release_cpu_linux" +export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py --test_output=errors --verbose_failures=true --keep_going --test_env=TF2_BEHAVIOR=1" +export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... " +export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" +export TF_TEST_FILTER_TAGS='-no_oss,-oss_serial,-no_oss_py36,-v1only' +#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. +export TF_PROJECT_NAME="tensorflow_cpu" +export TF_PIP_TEST_ROOT="pip_test" + +./tensorflow/tools/ci_build/builds/pip_new.sh diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/cpu_py37_nonpip.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/cpu_py37_nonpip.sh new file mode 100644 index 00000000000..5ddf0f17cbe --- /dev/null +++ b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/cpu_py37_nonpip.sh @@ -0,0 +1,49 @@ +#!/bin/bash +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +set -e +set -x + +source tensorflow/tools/ci_build/release/common.sh + +install_ubuntu_16_python_pip_deps python3.7 +# Update bazel +install_bazelisk + +# Run configure. +export TF_NEED_GCP=1 +export TF_NEED_HDFS=1 +export TF_NEED_S3=1 +export TF_NEED_CUDA=0 +export CC_OPT_FLAGS='-mavx -march=native' +export PYTHON_BIN_PATH=$(which python3.7) +export TF2_BEHAVIOR=1 +yes "" | "$PYTHON_BIN_PATH" configure.py +tag_filters="-no_oss,-oss_serial,-gpu,-tpu,-benchmark-test,-no_oss_py37,-v1only" + +# Get the default test targets for bazel. +source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh + +# Run tests +set +e +bazel test --test_output=errors --config=opt --test_lang_filters=py \ + --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain \ + --linkopt=-lrt \ + --action_env=TF2_BEHAVIOR="${TF2_BEHAVIOR}" \ + --local_test_jobs=8 \ + --build_tag_filters="${tag_filters}" \ + --test_tag_filters="${tag_filters}" -- \ + ${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... +test_xml_summary_exit diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/cpu_py37_pip.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/cpu_py37_pip.sh new file mode 100644 index 00000000000..a728bced348 --- /dev/null +++ b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/cpu_py37_pip.sh @@ -0,0 +1,47 @@ +#!/bin/bash +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +set -e +set -x + +source tensorflow/tools/ci_build/release/common.sh + +install_ubuntu_16_python_pip_deps python3.7 +# Update bazel +install_bazelisk + +# Export required variables for running pip.sh +export OS_TYPE="UBUNTU" +export CONTAINER_TYPE="CPU" +export TF_PYTHON_VERSION='python3.7' + +# Run configure. +export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION}) +yes "" | "$PYTHON_BIN_PATH" configure.py + +# Get the default test targets for bazel. +source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh + +# Export optional variables for running pip.sh +export TF_BUILD_FLAGS="--config=release_cpu_linux" +export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py --test_output=errors --verbose_failures=true --keep_going --test_env=TF2_BEHAVIOR=1" +export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... " +export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" +export TF_TEST_FILTER_TAGS='-no_oss,-oss_serial,-no_oss_py37,-v1only' +#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. +export TF_PROJECT_NAME="tensorflow_cpu" +export TF_PIP_TEST_ROOT="pip_test" + +./tensorflow/tools/ci_build/builds/pip_new.sh diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/cpu_py38_nonpip.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/cpu_py38_nonpip.sh new file mode 100644 index 00000000000..107cf3d4b45 --- /dev/null +++ b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/cpu_py38_nonpip.sh @@ -0,0 +1,49 @@ +#!/bin/bash +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +set -e +set -x + +source tensorflow/tools/ci_build/release/common.sh + +install_ubuntu_16_python_pip_deps python3.8 +# Update bazel +install_bazelisk + +# Run configure. +export TF_NEED_GCP=1 +export TF_NEED_HDFS=1 +export TF_NEED_S3=1 +export TF_NEED_CUDA=0 +export CC_OPT_FLAGS='-mavx -march=native' +export PYTHON_BIN_PATH=$(which python3.8) +export TF2_BEHAVIOR=1 +yes "" | "$PYTHON_BIN_PATH" configure.py +tag_filters="-no_oss,-oss_serial,-gpu,-tpu,-benchmark-test,-no_oss_py38,-v1only" + +# Get the default test targets for bazel. +source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh + +# Run tests +set +e +bazel test --test_output=errors --config=opt --test_lang_filters=py \ + --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain \ + --linkopt=-lrt \ + --action_env=TF2_BEHAVIOR="${TF2_BEHAVIOR}" \ + --local_test_jobs=8 \ + --build_tag_filters="${tag_filters}" \ + --test_tag_filters="${tag_filters}" -- \ + ${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... +test_xml_summary_exit diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/cpu_py38_pip.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/cpu_py38_pip.sh new file mode 100644 index 00000000000..c68f0832507 --- /dev/null +++ b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/cpu_py38_pip.sh @@ -0,0 +1,47 @@ +#!/bin/bash +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +set -e +set -x + +source tensorflow/tools/ci_build/release/common.sh + +install_ubuntu_16_python_pip_deps python3.8 +# Update bazel +install_bazelisk + +# Export required variables for running pip.sh +export OS_TYPE="UBUNTU" +export CONTAINER_TYPE="CPU" +export TF_PYTHON_VERSION='python3.8' + +# Run configure. +export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION}) +yes "" | "$PYTHON_BIN_PATH" configure.py + +# Get the default test targets for bazel. +source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh + +# Export optional variables for running pip.sh +export TF_BUILD_FLAGS="--config=release_cpu_linux" +export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py --test_output=errors --verbose_failures=true --keep_going --test_env=TF2_BEHAVIOR=1" +export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... " +export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" +export TF_TEST_FILTER_TAGS='-no_oss,-oss_serial,-no_oss_py38,-v1only' +#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. +export TF_PROJECT_NAME="tensorflow_cpu" +export TF_PIP_TEST_ROOT="pip_test" + +./tensorflow/tools/ci_build/builds/pip_new.sh diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/cpu_py39_nonpip.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/cpu_py39_nonpip.sh new file mode 100644 index 00000000000..86e50ccacfd --- /dev/null +++ b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/cpu_py39_nonpip.sh @@ -0,0 +1,49 @@ +#!/bin/bash +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +set -e +set -x + +source tensorflow/tools/ci_build/release/common.sh + +install_ubuntu_16_python_pip_deps python3.9 +# Update bazel +install_bazelisk + +# Run configure. +export TF_NEED_GCP=1 +export TF_NEED_HDFS=1 +export TF_NEED_S3=1 +export TF_NEED_CUDA=0 +export CC_OPT_FLAGS='-mavx -march=native' +export PYTHON_BIN_PATH=$(which python3.9) +export TF2_BEHAVIOR=1 +yes "" | "$PYTHON_BIN_PATH" configure.py +tag_filters="-no_oss,-oss_serial,-gpu,-tpu,-benchmark-test,-no_oss_py39,-v1only" + +# Get the default test targets for bazel. +source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh + +# Run tests +set +e +bazel test --test_output=errors --config=opt --test_lang_filters=py \ + --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain \ + --linkopt=-lrt \ + --action_env=TF2_BEHAVIOR="${TF2_BEHAVIOR}" \ + --local_test_jobs=8 \ + --build_tag_filters="${tag_filters}" \ + --test_tag_filters="${tag_filters}" -- \ + ${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... +test_xml_summary_exit diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/cpu_py39_pip.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/cpu_py39_pip.sh new file mode 100644 index 00000000000..7a637791f53 --- /dev/null +++ b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/cpu_py39_pip.sh @@ -0,0 +1,47 @@ +#!/bin/bash +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +set -e +set -x + +source tensorflow/tools/ci_build/release/common.sh + +install_ubuntu_16_python_pip_deps python3.9 +# Update bazel +install_bazelisk + +# Export required variables for running pip.sh +export OS_TYPE="UBUNTU" +export CONTAINER_TYPE="CPU" +export TF_PYTHON_VERSION='python3.9' + +# Run configure. +export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION}) +yes "" | "$PYTHON_BIN_PATH" configure.py + +# Get the default test targets for bazel. +source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh + +# Export optional variables for running pip.sh +export TF_BUILD_FLAGS="--config=release_cpu_linux" +export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py --test_output=errors --verbose_failures=true --keep_going --test_env=TF2_BEHAVIOR=1" +export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... " +export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" +export TF_TEST_FILTER_TAGS='-no_oss,-oss_serial,-no_oss_py39,-v1only' +#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. +export TF_PROJECT_NAME="tensorflow_cpu" +export TF_PIP_TEST_ROOT="pip_test" + +./tensorflow/tools/ci_build/builds/pip_new.sh diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/gpu_libtensorflow.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/gpu_libtensorflow.sh new file mode 100644 index 00000000000..edbcbeafa53 --- /dev/null +++ b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/gpu_libtensorflow.sh @@ -0,0 +1,46 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +set -e + +# Source the external common scripts. +source tensorflow/tools/ci_build/release/common.sh + + +# Install latest bazel +install_bazelisk +which bazel + +# Install realpath +sudo apt-get install realpath + +export TF_NEED_CUDA=1 + +# Update the version string to nightly +if [ -n "${IS_NIGHTLY}" ]; then + ./tensorflow/tools/ci_build/update_version.py --nightly +fi + +./tensorflow/tools/ci_build/linux/libtensorflow.sh + +# Copy the nightly version update script +if [ -n "${IS_NIGHTLY}" ]; then + cp tensorflow/tools/ci_build/builds/libtensorflow_nightly_symlink.sh lib_package + + echo "This package was built on $(date)" >> lib_package/build_time.txt + + tar -zcvf ubuntu_gpu_libtensorflow_binaries.tar.gz lib_package + + gsutil cp ubuntu_gpu_libtensorflow_binaries.tar.gz gs://libtensorflow-nightly/prod/tensorflow/release/ubuntu_16/latest/gpu +fi diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/gpu_pip_on_cpu.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/gpu_pip_on_cpu.sh new file mode 100755 index 00000000000..5962d1c46ec --- /dev/null +++ b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/gpu_pip_on_cpu.sh @@ -0,0 +1,61 @@ +#!/bin/bash +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +set -e +set -x + +source tensorflow/tools/ci_build/release/common.sh + +install_ubuntu_16_python_pip_deps python3.6 +# Update Bazel to the desired version +install_bazelisk + +# Run configure. +export TF_NEED_GCP=1 +export TF_NEED_HDFS=1 +export TF_NEED_S3=1 +export TF_NEED_CUDA=1 +export TF_CUDA_VERSION=11 +export TF_CUDNN_VERSION=8 +export TF_NEED_TENSORRT=1 +export TENSORRT_INSTALL_PATH=/usr/local/tensorrt +export CC_OPT_FLAGS='-mavx -march=native' +export PYTHON_BIN_PATH=$(which python3.6) +export LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$TENSORRT_INSTALL_PATH/lib" +export TF_CUDA_COMPUTE_CAPABILITIES=sm_35,sm_50,sm_60,sm_70,sm_75,compute_80 + +yes "" | "$PYTHON_BIN_PATH" configure.py + +######################## +## Build GPU pip package +######################## +bazel build --config=opt \ + --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11:toolchain \ + tensorflow/tools/pip_package:build_pip_package + +# Set TF nightly flag so we get the proper version of estimator +if [[ "$IS_NIGHTLY" == 1 ]]; then + NIGHTLY_FLAG="--nightly_flag" +fi + +PIP_WHL_DIR=whl +mkdir -p ${PIP_WHL_DIR} +PIP_WHL_DIR=$(readlink -f ${PIP_WHL_DIR}) # Get absolute path +bazel-bin/tensorflow/tools/pip_package/build_pip_package "${PIP_WHL_DIR}" "${NIGHTLY_FLAG}" +WHL_PATH=$(ls "${PIP_WHL_DIR}"/*.whl) + +cp "${WHL_PATH}" "$(pwd)"/. +chmod +x tensorflow/tools/ci_build/builds/docker_cpu_pip.sh +docker run -e "BAZEL_VERSION=${BAZEL_VERSION}" -e "CI_BUILD_USER=$(id -u -n)" -e "CI_BUILD_UID=$(id -u)" -e "CI_BUILD_GROUP=$(id -g -n)" -e "CI_BUILD_GID=$(id -g)" -e "CI_BUILD_HOME=/bazel_pip" -v "$(pwd)":/bazel_pip tensorflow/tensorflow:devel "./bazel_pip/tensorflow/tools/ci_build/builds/with_the_same_user" "./bazel_pip/tensorflow/tools/ci_build/builds/docker_cpu_pip.sh" diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/gpu_py36_nonpip.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/gpu_py36_nonpip.sh new file mode 100644 index 00000000000..cc1c5fbe1ef --- /dev/null +++ b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/gpu_py36_nonpip.sh @@ -0,0 +1,60 @@ +#!/bin/bash +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +set -e +set -x + +source tensorflow/tools/ci_build/release/common.sh + +install_ubuntu_16_python_pip_deps python3.6 +# Update bazel +install_bazelisk + +# Run configure. +export TF_NEED_GCP=1 +export TF_NEED_HDFS=1 +export TF_NEED_S3=1 +export TF_NEED_CUDA=1 +export TF_CUDA_VERSION=11 +export TF_CUDNN_VERSION=8 +export TF_NEED_TENSORRT=1 +export TENSORRT_INSTALL_PATH=/usr/local/tensorrt +export CC_OPT_FLAGS='-mavx -march=native' +export PYTHON_BIN_PATH=$(which python3.6) +export TF2_BEHAVIOR=1 +export PROJECT_NAME="tensorflow_gpu" +export LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$TENSORRT_INSTALL_PATH/lib" +export TF_CUDA_COMPUTE_CAPABILITIES=sm_35,sm_50,sm_60,sm_70,sm_75,compute_80 + +yes "" | "$PYTHON_BIN_PATH" configure.py + +# Get the default test targets for bazel. +source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh + +tag_filters="gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_oss_py36,-no_cuda11" + +set +e +bazel test --config=cuda --config=opt \ + --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11:toolchain \ + --linkopt=-lrt \ + --action_env=TF2_BEHAVIOR="${TF2_BEHAVIOR}" \ + --test_lang_filters=py \ + --test_tag_filters=${tag_filters} \ + --build_tag_filters=${tag_filters} \ + --test_timeout="300,450,1200,3600" --local_test_jobs=4 \ + --test_output=errors --verbose_failures=true --keep_going \ + --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute \ + -- ${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... +test_xml_summary_exit diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/gpu_py36_pip.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/gpu_py36_pip.sh new file mode 100644 index 00000000000..f67801f33c8 --- /dev/null +++ b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/gpu_py36_pip.sh @@ -0,0 +1,55 @@ +#!/bin/bash +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +set -e +set -x + +source tensorflow/tools/ci_build/release/common.sh + +install_ubuntu_16_python_pip_deps python3.6 +# Update bazel +install_bazelisk + +# Export required variables for running pip.sh +export OS_TYPE="UBUNTU" +export CONTAINER_TYPE="GPU" +export TF_PYTHON_VERSION='python3.6' + +# Run configure. +export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION}) +yes "" | "$PYTHON_BIN_PATH" configure.py + +# Get the default test targets for bazel. +source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh + +# Export optional variables for running pip.sh +export TF_TEST_FILTER_TAGS='gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_oss_py36,-no_cuda11' +export TF_BUILD_FLAGS="--config=release_gpu_linux_cuda_11_2 " +export TF_TEST_FLAGS="--test_tag_filters=${TF_TEST_FILTER_TAGS} --build_tag_filters=${TF_TEST_FILTER_TAGS} \ +--distinct_host_configuration=false \ +--action_env=TF_CUDA_VERSION=11.2 --action_env=TF_CUDNN_VERSION=8.1 --test_env=TF2_BEHAVIOR=1 \ +--config=cuda --test_output=errors --local_test_jobs=4 --test_lang_filters=py \ +--verbose_failures=true --keep_going --define=no_tensorflow_py_deps=true \ +--run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute " +export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... " +export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" +#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. +export TF_PROJECT_NAME="tensorflow" # single pip package! +export TF_PIP_TEST_ROOT="pip_test" + +# To build both tensorflow and tensorflow-gpu pip packages +export TF_BUILD_BOTH_GPU_PACKAGES=1 + +./tensorflow/tools/ci_build/builds/pip_new.sh diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/gpu_py37_nonpip.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/gpu_py37_nonpip.sh new file mode 100644 index 00000000000..80d6a6a19dc --- /dev/null +++ b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/gpu_py37_nonpip.sh @@ -0,0 +1,60 @@ +#!/bin/bash +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +set -e +set -x + +source tensorflow/tools/ci_build/release/common.sh + +install_ubuntu_16_python_pip_deps python3.7 +# Update bazel +install_bazelisk + +# Run configure. +export TF_NEED_GCP=1 +export TF_NEED_HDFS=1 +export TF_NEED_S3=1 +export TF_NEED_CUDA=1 +export TF_CUDA_VERSION=11 +export TF_CUDNN_VERSION=8 +export TF_NEED_TENSORRT=1 +export TENSORRT_INSTALL_PATH=/usr/local/tensorrt +export CC_OPT_FLAGS='-mavx -march=native' +export PYTHON_BIN_PATH=$(which python3.7) +export TF2_BEHAVIOR=1 +export PROJECT_NAME="tensorflow_gpu" +export LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$TENSORRT_INSTALL_PATH/lib" +export TF_CUDA_COMPUTE_CAPABILITIES=sm_35,sm_50,sm_60,sm_70,sm_75,compute_80 + +yes "" | "$PYTHON_BIN_PATH" configure.py + +# Get the default test targets for bazel. +source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh + +tag_filters="gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_oss_py37,-no_cuda11" + +set +e +bazel test --config=cuda --config=opt \ + --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11:toolchain \ + --linkopt=-lrt \ + --action_env=TF2_BEHAVIOR="${TF2_BEHAVIOR}" \ + --test_lang_filters=py \ + --build_tag_filters=${tag_filters} \ + --test_tag_filters=${tag_filters} \ + --test_timeout="300,450,1200,3600" --local_test_jobs=4 \ + --test_output=errors --verbose_failures=true --keep_going \ + --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute \ + -- ${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... +test_xml_summary_exit diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/gpu_py37_pip.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/gpu_py37_pip.sh new file mode 100644 index 00000000000..8bddc1dc7c8 --- /dev/null +++ b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/gpu_py37_pip.sh @@ -0,0 +1,55 @@ +#!/bin/bash +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +set -e +set -x + +source tensorflow/tools/ci_build/release/common.sh + +install_ubuntu_16_python_pip_deps python3.7 +# Update bazel +install_bazelisk + +# Export required variables for running pip.sh +export OS_TYPE="UBUNTU" +export CONTAINER_TYPE="GPU" +export TF_PYTHON_VERSION='python3.7' + +# Run configure. +export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION}) +yes "" | "$PYTHON_BIN_PATH" configure.py + +# Get the default test targets for bazel. +source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh + +# Export optional variables for running pip.sh +export TF_TEST_FILTER_TAGS='gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_oss_py37,-no_cuda11' +export TF_BUILD_FLAGS="--config=release_gpu_linux " +export TF_TEST_FLAGS="--test_tag_filters=${TF_TEST_FILTER_TAGS} --build_tag_filters=${TF_TEST_FILTER_TAGS} \ +--distinct_host_configuration=false \ +--action_env=TF_CUDA_VERSION=11 --action_env=TF_CUDNN_VERSION=8 --test_env=TF2_BEHAVIOR=1 \ +--config=cuda --test_output=errors --local_test_jobs=4 --test_lang_filters=py \ +--verbose_failures=true --keep_going --define=no_tensorflow_py_deps=true \ +--run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute " +export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... " +export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" +#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. +export TF_PROJECT_NAME="tensorflow" # single pip package! +export TF_PIP_TEST_ROOT="pip_test" + +# To build both tensorflow and tensorflow-gpu pip packages +export TF_BUILD_BOTH_GPU_PACKAGES=1 + +./tensorflow/tools/ci_build/builds/pip_new.sh diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/gpu_py38_nonpip.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/gpu_py38_nonpip.sh new file mode 100644 index 00000000000..6b2116a5721 --- /dev/null +++ b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/gpu_py38_nonpip.sh @@ -0,0 +1,60 @@ +#!/bin/bash +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +set -e +set -x + +source tensorflow/tools/ci_build/release/common.sh + +install_ubuntu_16_python_pip_deps python3.8 +# Update bazel +update_bazel_linux + +# Run configure. +export TF_NEED_GCP=1 +export TF_NEED_HDFS=1 +export TF_NEED_S3=1 +export TF_NEED_CUDA=1 +export TF_CUDA_VERSION=11 +export TF_CUDNN_VERSION=8 +export TF_NEED_TENSORRT=1 +export TENSORRT_INSTALL_PATH=/usr/local/tensorrt +export CC_OPT_FLAGS='-mavx -march=native' +export PYTHON_BIN_PATH=$(which python3.8) +export TF2_BEHAVIOR=1 +export PROJECT_NAME="tensorflow_gpu" +export LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$TENSORRT_INSTALL_PATH/lib" +export TF_CUDA_COMPUTE_CAPABILITIES=sm_35,sm_50,sm_60,sm_70,sm_75,compute_80 + +yes "" | "$PYTHON_BIN_PATH" configure.py + +# Get the default test targets for bazel. +source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh + +tag_filters="gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_oss_py38,-no_cuda11" + +test +e +bazel test --config=cuda --config=opt \ + --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11:toolchain \ + --linkopt=-lrt \ + --action_env=TF2_BEHAVIOR="${TF2_BEHAVIOR}" \ + --test_lang_filters=py \ + --build_tag_filters=${tag_filters} \ + --test_tag_filters=${tag_filters} \ + --test_timeout="300,450,1200,3600" --local_test_jobs=4 \ + --test_output=errors --verbose_failures=true --keep_going \ + --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute \ + -- ${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... +test_xml_summary_exit diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/gpu_py38_pip.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/gpu_py38_pip.sh new file mode 100644 index 00000000000..5b7f026a1b6 --- /dev/null +++ b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/gpu_py38_pip.sh @@ -0,0 +1,55 @@ +#!/bin/bash +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +set -e +set -x + +source tensorflow/tools/ci_build/release/common.sh + +install_ubuntu_16_python_pip_deps python3.8 +# Update bazel +update_bazel_linux + +# Export required variables for running pip.sh +export OS_TYPE="UBUNTU" +export CONTAINER_TYPE="GPU" +export TF_PYTHON_VERSION='python3.8' + +# Run configure. +export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION}) +yes "" | "$PYTHON_BIN_PATH" configure.py + +# Get the default test targets for bazel. +source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh + +# Export optional variables for running pip.sh +export TF_TEST_FILTER_TAGS='gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_oss_py38,-no_cuda11' +export TF_BUILD_FLAGS="--config=release_gpu_linux " +export TF_TEST_FLAGS="--test_tag_filters=${TF_TEST_FILTER_TAGS} --build_tag_filters=${TF_TEST_FILTER_TAGS} \ +--distinct_host_configuration=false \ +--action_env=TF_CUDA_VERSION=11 --action_env=TF_CUDNN_VERSION=8 --test_env=TF2_BEHAVIOR=1 \ +--config=cuda --test_output=errors --local_test_jobs=4 --test_lang_filters=py \ +--verbose_failures=true --keep_going --define=no_tensorflow_py_deps=true \ +--run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute " +export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... " +export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" +#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. +export TF_PROJECT_NAME="tensorflow" # single pip package! +export TF_PIP_TEST_ROOT="pip_test" + +# To build both tensorflow and tensorflow-gpu pip packages +export TF_BUILD_BOTH_GPU_PACKAGES=1 + +./tensorflow/tools/ci_build/builds/pip_new.sh diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/gpu_py39_nonpip.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/gpu_py39_nonpip.sh new file mode 100644 index 00000000000..020dd06c206 --- /dev/null +++ b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/gpu_py39_nonpip.sh @@ -0,0 +1,60 @@ +#!/bin/bash +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +set -e +set -x + +source tensorflow/tools/ci_build/release/common.sh + +install_ubuntu_16_python_pip_deps python3.9 +# Update bazel +update_bazel_linux + +# Run configure. +export TF_NEED_GCP=1 +export TF_NEED_HDFS=1 +export TF_NEED_S3=1 +export TF_NEED_CUDA=1 +export TF_CUDA_VERSION=11 +export TF_CUDNN_VERSION=8 +export TF_NEED_TENSORRT=1 +export TENSORRT_INSTALL_PATH=/usr/local/tensorrt +export CC_OPT_FLAGS='-mavx -march=native' +export PYTHON_BIN_PATH=$(which python3.9) +export TF2_BEHAVIOR=1 +export PROJECT_NAME="tensorflow_gpu" +export LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$TENSORRT_INSTALL_PATH/lib" +export TF_CUDA_COMPUTE_CAPABILITIES=sm_35,sm_50,sm_60,sm_70,sm_75,compute_80 + +yes "" | "$PYTHON_BIN_PATH" configure.py + +# Get the default test targets for bazel. +source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh + +tag_filters="gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_oss_py39,-no_cuda11" + +test +e +bazel test --config=cuda --config=opt \ + --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11:toolchain \ + --linkopt=-lrt \ + --action_env=TF2_BEHAVIOR="${TF2_BEHAVIOR}" \ + --test_lang_filters=py \ + --build_tag_filters=${tag_filters} \ + --test_tag_filters=${tag_filters} \ + --test_timeout="300,450,1200,3600" --local_test_jobs=4 \ + --test_output=errors --verbose_failures=true --keep_going \ + --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute \ + -- ${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... +test_xml_summary_exit diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/gpu_py39_pip.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/gpu_py39_pip.sh new file mode 100644 index 00000000000..e3ed31c4305 --- /dev/null +++ b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/gpu_py39_pip.sh @@ -0,0 +1,55 @@ +#!/bin/bash +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +set -e +set -x + +source tensorflow/tools/ci_build/release/common.sh + +install_ubuntu_16_python_pip_deps python3.9 +# Update bazel +update_bazel_linux + +# Export required variables for running pip.sh +export OS_TYPE="UBUNTU" +export CONTAINER_TYPE="GPU" +export TF_PYTHON_VERSION='python3.9' + +# Run configure. +export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION}) +yes "" | "$PYTHON_BIN_PATH" configure.py + +# Get the default test targets for bazel. +source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh + +# Export optional variables for running pip.sh +export TF_TEST_FILTER_TAGS='gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_oss_py39,-no_cuda11' +export TF_BUILD_FLAGS="--config=release_gpu_linux " +export TF_TEST_FLAGS="--test_tag_filters=${TF_TEST_FILTER_TAGS} --build_tag_filters=${TF_TEST_FILTER_TAGS} \ +--distinct_host_configuration=false \ +--action_env=TF_CUDA_VERSION=11 --action_env=TF_CUDNN_VERSION=8 --test_env=TF2_BEHAVIOR=1 \ +--config=cuda --test_output=errors --local_test_jobs=4 --test_lang_filters=py \ +--verbose_failures=true --keep_going --define=no_tensorflow_py_deps=true \ +--run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute " +export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... " +export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" +#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. +export TF_PROJECT_NAME="tensorflow" # single pip package! +export TF_PIP_TEST_ROOT="pip_test" + +# To build both tensorflow and tensorflow-gpu pip packages +export TF_BUILD_BOTH_GPU_PACKAGES=1 + +./tensorflow/tools/ci_build/builds/pip_new.sh diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/sanity.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/sanity.sh new file mode 100644 index 00000000000..0dcd90ec827 --- /dev/null +++ b/tensorflow/tools/ci_build/rel/ubuntu_cuda11_2/sanity.sh @@ -0,0 +1,32 @@ +#!/bin/bash +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +set -e + +# Install latest bazel +source tensorflow/tools/ci_build/release/common.sh +install_bazelisk +which bazel + +# We need py3 lint +sudo python3.8 -m pip install pep8 + +# Install pylint +sudo python3.8 -m pip install setuptools --upgrade +sudo python3.8 -m pip install pylint==2.4.4 +python3.8 -m pylint --version + +# Run tensorflow sanity checks. +tensorflow/tools/ci_build/ci_sanity.sh diff --git a/tensorflow/tools/ci_build/rel/windows_cuda11_2/cpu_libtensorflow.bat b/tensorflow/tools/ci_build/rel/windows_cuda11_2/cpu_libtensorflow.bat new file mode 100644 index 00000000000..07c5456600a --- /dev/null +++ b/tensorflow/tools/ci_build/rel/windows_cuda11_2/cpu_libtensorflow.bat @@ -0,0 +1,22 @@ +:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. +:: +:: Licensed under the Apache License, Version 2.0 (the "License"); +:: you may not use this file except in compliance with the License. +:: You may obtain a copy of the License at +:: +:: http://www.apache.org/licenses/LICENSE-2.0 +:: +:: Unless required by applicable law or agreed to in writing, software +:: distributed under the License is distributed on an "AS IS" BASIS, +:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +:: See the License for the specific language governing permissions and +:: limitations under the License. +:: ============================================================================= + +CALL tensorflow\tools\ci_build\release\common_win.bat + +call tensorflow\tools\ci_build\windows\cpu\bazel\run_libtensorflow.bat || exit /b 1 + +copy lib_package %TF_ARTIFACTS_DIR%\lib_package + +CALL gsutil cp windows_cpu_libtensorflow_binaries.tar.gz gs://libtensorflow-nightly/prod/tensorflow/release/windows/latest/cpu diff --git a/tensorflow/tools/ci_build/rel/windows_cuda11_2/cpu_py36.bat b/tensorflow/tools/ci_build/rel/windows_cuda11_2/cpu_py36.bat new file mode 100644 index 00000000000..fde52ca24a5 --- /dev/null +++ b/tensorflow/tools/ci_build/rel/windows_cuda11_2/cpu_py36.bat @@ -0,0 +1,24 @@ +:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. +:: +:: Licensed under the Apache License, Version 2.0 (the "License"); +:: you may not use this file except in compliance with the License. +:: You may obtain a copy of the License at +:: +:: http://www.apache.org/licenses/LICENSE-2.0 +:: +:: Unless required by applicable law or agreed to in writing, software +:: distributed under the License is distributed on an "AS IS" BASIS, +:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +:: See the License for the specific language governing permissions and +:: limitations under the License. +:: ============================================================================= + +SET PYTHON_DIRECTORY=Python36 + +CALL tensorflow\tools\ci_build\release\common_win.bat + +if "%IS_NIGHTLY%" == "1" ( + call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --extra_test_flags "--test_env=TF2_BEHAVIOR=1" +) else ( + call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --release_build --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow_cpu" +) diff --git a/tensorflow/tools/ci_build/rel/windows_cuda11_2/cpu_py37.bat b/tensorflow/tools/ci_build/rel/windows_cuda11_2/cpu_py37.bat new file mode 100644 index 00000000000..4b696bb744e --- /dev/null +++ b/tensorflow/tools/ci_build/rel/windows_cuda11_2/cpu_py37.bat @@ -0,0 +1,24 @@ +:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. +:: +:: Licensed under the Apache License, Version 2.0 (the "License"); +:: you may not use this file except in compliance with the License. +:: You may obtain a copy of the License at +:: +:: http://www.apache.org/licenses/LICENSE-2.0 +:: +:: Unless required by applicable law or agreed to in writing, software +:: distributed under the License is distributed on an "AS IS" BASIS, +:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +:: See the License for the specific language governing permissions and +:: limitations under the License. +:: ============================================================================= + +SET PYTHON_DIRECTORY=Python37 + +CALL tensorflow\tools\ci_build\release\common_win.bat + +if "%IS_NIGHTLY%" == "1" ( + call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --extra_test_flags "--test_env=TF2_BEHAVIOR=1" +) else ( + call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --release_build --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow_cpu" +) diff --git a/tensorflow/tools/ci_build/rel/windows_cuda11_2/cpu_py38.bat b/tensorflow/tools/ci_build/rel/windows_cuda11_2/cpu_py38.bat new file mode 100644 index 00000000000..a1657b077cb --- /dev/null +++ b/tensorflow/tools/ci_build/rel/windows_cuda11_2/cpu_py38.bat @@ -0,0 +1,24 @@ +:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. +:: +:: Licensed under the Apache License, Version 2.0 (the "License"); +:: you may not use this file except in compliance with the License. +:: You may obtain a copy of the License at +:: +:: http://www.apache.org/licenses/LICENSE-2.0 +:: +:: Unless required by applicable law or agreed to in writing, software +:: distributed under the License is distributed on an "AS IS" BASIS, +:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +:: See the License for the specific language governing permissions and +:: limitations under the License. +:: ============================================================================= + +SET PYTHON_DIRECTORY=Python38 + +CALL tensorflow\tools\ci_build\release\common_win.bat + +if "%IS_NIGHTLY%" == "1" ( + call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --extra_test_flags "--test_env=TF2_BEHAVIOR=1" +) else ( + call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --release_build --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow_cpu" +) diff --git a/tensorflow/tools/ci_build/rel/windows_cuda11_2/cpu_py39.bat b/tensorflow/tools/ci_build/rel/windows_cuda11_2/cpu_py39.bat new file mode 100644 index 00000000000..2bb9a74bf49 --- /dev/null +++ b/tensorflow/tools/ci_build/rel/windows_cuda11_2/cpu_py39.bat @@ -0,0 +1,24 @@ +:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. +:: +:: Licensed under the Apache License, Version 2.0 (the "License"); +:: you may not use this file except in compliance with the License. +:: You may obtain a copy of the License at +:: +:: http://www.apache.org/licenses/LICENSE-2.0 +:: +:: Unless required by applicable law or agreed to in writing, software +:: distributed under the License is distributed on an "AS IS" BASIS, +:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +:: See the License for the specific language governing permissions and +:: limitations under the License. +:: ============================================================================= + +SET PYTHON_DIRECTORY=Python39 + +CALL tensorflow\tools\ci_build\release\common_win.bat + +if "%IS_NIGHTLY%" == "1" ( + call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --extra_test_flags "--test_env=TF2_BEHAVIOR=1" +) else ( + call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --release_build --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow_cpu" +) diff --git a/tensorflow/tools/ci_build/rel/windows_cuda11_2/gpu_libtensorflow.bat b/tensorflow/tools/ci_build/rel/windows_cuda11_2/gpu_libtensorflow.bat new file mode 100644 index 00000000000..4a766fc0088 --- /dev/null +++ b/tensorflow/tools/ci_build/rel/windows_cuda11_2/gpu_libtensorflow.bat @@ -0,0 +1,22 @@ +:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. +:: +:: Licensed under the Apache License, Version 2.0 (the "License"); +:: you may not use this file except in compliance with the License. +:: You may obtain a copy of the License at +:: +:: http://www.apache.org/licenses/LICENSE-2.0 +:: +:: Unless required by applicable law or agreed to in writing, software +:: distributed under the License is distributed on an "AS IS" BASIS, +:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +:: See the License for the specific language governing permissions and +:: limitations under the License. +:: ============================================================================= + +CALL tensorflow\tools\ci_build\release\common_win.bat + +call tensorflow\tools\ci_build\windows\gpu\bazel\run_libtensorflow.bat || exit /b + +copy lib_package %TF_ARTIFACTS_DIR%\lib_package + +CALL gsutil cp windows_gpu_libtensorflow_binaries.tar.gz gs://libtensorflow-nightly/prod/tensorflow/release/windows/latest/gpu diff --git a/tensorflow/tools/ci_build/rel/windows_cuda11_2/gpu_pip_on_cpu.bat b/tensorflow/tools/ci_build/rel/windows_cuda11_2/gpu_pip_on_cpu.bat new file mode 100644 index 00000000000..213de532069 --- /dev/null +++ b/tensorflow/tools/ci_build/rel/windows_cuda11_2/gpu_pip_on_cpu.bat @@ -0,0 +1,21 @@ +:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. +:: +:: Licensed under the Apache License, Version 2.0 (the "License"); +:: you may not use this file except in compliance with the License. +:: You may obtain a copy of the License at +:: +:: http://www.apache.org/licenses/LICENSE-2.0 +:: +:: Unless required by applicable law or agreed to in writing, software +:: distributed under the License is distributed on an "AS IS" BASIS, +:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +:: See the License for the specific language governing permissions and +:: limitations under the License. +:: ============================================================================= + +SET PYTHON_DIRECTORY=Python36 + +CALL tensorflow\tools\ci_build\release\common_win.bat + +call tensorflow\tools\ci_build\windows\integration\gpu_pip_on_cpu\run.bat + diff --git a/tensorflow/tools/ci_build/rel/windows_cuda11_2/gpu_py36.bat b/tensorflow/tools/ci_build/rel/windows_cuda11_2/gpu_py36.bat new file mode 100644 index 00000000000..3d16ff1e5a6 --- /dev/null +++ b/tensorflow/tools/ci_build/rel/windows_cuda11_2/gpu_py36.bat @@ -0,0 +1,25 @@ +:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. +:: +:: Licensed under the Apache License, Version 2.0 (the "License"); +:: you may not use this file except in compliance with the License. +:: You may obtain a copy of the License at +:: +:: http://www.apache.org/licenses/LICENSE-2.0 +:: +:: Unless required by applicable law or agreed to in writing, software +:: distributed under the License is distributed on an "AS IS" BASIS, +:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +:: See the License for the specific language governing permissions and +:: limitations under the License. +:: ============================================================================= + +SET PYTHON_DIRECTORY=Python36 + +CALL tensorflow\tools\ci_build\release\common_win.bat + +if "%IS_NIGHTLY%" == "1" ( + call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --extra_test_flags "--test_env=TF2_BEHAVIOR=1" +) else ( + call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --release_build --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow" + bash -l tensorflow\tools\ci_build\release\windows\gpu_py36_full\release_pip_rename.sh +) diff --git a/tensorflow/tools/ci_build/rel/windows_cuda11_2/gpu_py37.bat b/tensorflow/tools/ci_build/rel/windows_cuda11_2/gpu_py37.bat new file mode 100644 index 00000000000..2b7a3e72750 --- /dev/null +++ b/tensorflow/tools/ci_build/rel/windows_cuda11_2/gpu_py37.bat @@ -0,0 +1,25 @@ +:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. +:: +:: Licensed under the Apache License, Version 2.0 (the "License"); +:: you may not use this file except in compliance with the License. +:: You may obtain a copy of the License at +:: +:: http://www.apache.org/licenses/LICENSE-2.0 +:: +:: Unless required by applicable law or agreed to in writing, software +:: distributed under the License is distributed on an "AS IS" BASIS, +:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +:: See the License for the specific language governing permissions and +:: limitations under the License. +:: ============================================================================= + +SET PYTHON_DIRECTORY=Python37 + +CALL tensorflow\tools\ci_build\release\common_win.bat + +if "%IS_NIGHTLY%" == "1" ( + call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --extra_test_flags "--test_env=TF2_BEHAVIOR=1" +) else ( + call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --release_build --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow" + bash -l tensorflow\tools\ci_build\release\windows\gpu_py37_full\release_pip_rename.sh +) diff --git a/tensorflow/tools/ci_build/rel/windows_cuda11_2/gpu_py38.bat b/tensorflow/tools/ci_build/rel/windows_cuda11_2/gpu_py38.bat new file mode 100644 index 00000000000..15f7495b9c1 --- /dev/null +++ b/tensorflow/tools/ci_build/rel/windows_cuda11_2/gpu_py38.bat @@ -0,0 +1,25 @@ +:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. +:: +:: Licensed under the Apache License, Version 2.0 (the "License"); +:: you may not use this file except in compliance with the License. +:: You may obtain a copy of the License at +:: +:: http://www.apache.org/licenses/LICENSE-2.0 +:: +:: Unless required by applicable law or agreed to in writing, software +:: distributed under the License is distributed on an "AS IS" BASIS, +:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +:: See the License for the specific language governing permissions and +:: limitations under the License. +:: ============================================================================= + +SET PYTHON_DIRECTORY=Python38 + +CALL tensorflow\tools\ci_build\release\common_win.bat + +if "%IS_NIGHTLY%" == "1" ( + call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --extra_test_flags "--test_env=TF2_BEHAVIOR=1" +) else ( + call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --release_build --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow" + bash -l tensorflow\tools\ci_build\release\windows\gpu_py38_full\release_pip_rename.sh +) diff --git a/tensorflow/tools/ci_build/rel/windows_cuda11_2/gpu_py39.bat b/tensorflow/tools/ci_build/rel/windows_cuda11_2/gpu_py39.bat new file mode 100644 index 00000000000..70370c84460 --- /dev/null +++ b/tensorflow/tools/ci_build/rel/windows_cuda11_2/gpu_py39.bat @@ -0,0 +1,25 @@ +:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. +:: +:: Licensed under the Apache License, Version 2.0 (the "License"); +:: you may not use this file except in compliance with the License. +:: You may obtain a copy of the License at +:: +:: http://www.apache.org/licenses/LICENSE-2.0 +:: +:: Unless required by applicable law or agreed to in writing, software +:: distributed under the License is distributed on an "AS IS" BASIS, +:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +:: See the License for the specific language governing permissions and +:: limitations under the License. +:: ============================================================================= + +SET PYTHON_DIRECTORY=Python39 + +CALL tensorflow\tools\ci_build\release\common_win.bat + +if "%IS_NIGHTLY%" == "1" ( + call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --extra_test_flags "--test_env=TF2_BEHAVIOR=1" +) else ( + call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --release_build --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow" + bash -l tensorflow\tools\ci_build\release\windows\gpu_py39_full\release_pip_rename.sh +) diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/sanity/build.sh b/tensorflow/tools/ci_build/release/ubuntu_16/sanity/build.sh index 0dcd90ec827..fa87d0a3a22 100644 --- a/tensorflow/tools/ci_build/release/ubuntu_16/sanity/build.sh +++ b/tensorflow/tools/ci_build/release/ubuntu_16/sanity/build.sh @@ -25,7 +25,7 @@ sudo python3.8 -m pip install pep8 # Install pylint sudo python3.8 -m pip install setuptools --upgrade -sudo python3.8 -m pip install pylint==2.4.4 +sudo python3.8 -m pip install pylint==2.6.2 astroid==2.5 python3.8 -m pylint --version # Run tensorflow sanity checks. diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index c2fae97daea..baa95512b09 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -116,13 +116,13 @@ COMMON_PIP_DEPS = [ "//tensorflow/python:memory_checker", "//tensorflow/python:meta_graph_testdata", "//tensorflow/python/data/benchmarks:benchmark_base", - "//tensorflow/python/data/experimental/kernel_tests/serialization:dataset_serialization_test_base", "//tensorflow/python/data/experimental/kernel_tests:data_service_test_base", "//tensorflow/python/data/experimental/kernel_tests:reader_dataset_ops_test_base", "//tensorflow/python/data/experimental/kernel_tests:stats_dataset_test_base", "//tensorflow/python/data/experimental/ops:testing", "//tensorflow/python/data/experimental/service:server_lib", "//tensorflow/python/ops/ragged:ragged_tensor_test_ops", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/debug:debug_pip", "//tensorflow/python/distribute:combinations", diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl index 8907896b0ec..8ab807713c0 100644 --- a/tensorflow/workspace2.bzl +++ b/tensorflow/workspace2.bzl @@ -1,5 +1,6 @@ """TensorFlow workspace initialization. Consult the WORKSPACE on how to use it.""" +# Import third party config rules. load("//tensorflow:version_check.bzl", "check_bazel_version_at_least") load("//third_party/gpus:cuda_configure.bzl", "cuda_configure") load("//third_party/gpus:rocm_configure.bzl", "rocm_configure") @@ -8,13 +9,13 @@ load("//third_party/nccl:nccl_configure.bzl", "nccl_configure") load("//third_party/git:git_configure.bzl", "git_configure") load("//third_party/py:python_configure.bzl", "python_configure") load("//third_party/systemlibs:syslibs_configure.bzl", "syslibs_configure") -load("//third_party/toolchains/remote:configure.bzl", "remote_execution_configure") -load("//third_party/toolchains/clang6:repo.bzl", "clang6_configure") load("//third_party/toolchains/cpus/arm:arm_compiler_configure.bzl", "arm_compiler_configure") load("//third_party/toolchains/embedded/arm-linux:arm_linux_toolchain_configure.bzl", "arm_linux_toolchain_configure") load("//third_party:repo.bzl", "tf_http_archive") load("//third_party/clang_toolchain:cc_configure_clang.bzl", "cc_download_clang_toolchain") load("//tensorflow/tools/def_file_filter:def_file_filter_configure.bzl", "def_file_filter_configure") + +# Import third party repository rules. See go/tfbr-thirdparty. load("//third_party/FP16:workspace.bzl", FP16 = "repo") load("//third_party/absl:workspace.bzl", absl = "repo") load("//third_party/aws:workspace.bzl", aws = "repo") @@ -39,50 +40,46 @@ load("//third_party/psimd:workspace.bzl", psimd = "repo") load("//third_party/ruy:workspace.bzl", ruy = "repo") load("//third_party/sobol_data:workspace.bzl", sobol_data = "repo") load("//third_party/vulkan_headers:workspace.bzl", vulkan_headers = "repo") + +# Import external repository rules. load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_file") load("@bazel_tools//tools/build_defs/repo:java.bzl", "java_import_external") load("@io_bazel_rules_closure//closure:defs.bzl", "filegroup_external") load("@tf_toolchains//toolchains/remote_config:configs.bzl", "initialize_rbe_configs") +load("@tf_toolchains//toolchains/remote:configure.bzl", "remote_execution_configure") +load("@tf_toolchains//toolchains/clang6:repo.bzl", "clang6_configure") def _initialize_third_party(): """ Load third party repositories. See above load() statements. """ FP16() + absl() aws() clog() cpuinfo() dlpack() + eigen3() + farmhash() flatbuffers() + gemmlowp() hexagon_nn() highwayhash() hwloc() icu() - kissfft() jpeg() + kissfft() nasm() opencl_headers() pasta() psimd() + ruy() sobol_data() vulkan_headers() - ruy() - -# Sanitize a dependency so that it works correctly from code that includes -# TensorFlow as a submodule. -def _clean_dep(dep): - return str(Label(dep)) # Toolchains & platforms required by Tensorflow to build. def _tf_toolchains(): native.register_execution_platforms("@local_execution_config_platform//:platform") native.register_toolchains("@local_execution_config_python//:py_toolchain") -# Define all external repositories required by TensorFlow -def _tf_repositories(): - """All external dependencies for TF builds.""" - - # Initialize toolchains and platforms. - _tf_toolchains() - # Loads all external repos to configure RBE builds. initialize_rbe_configs() @@ -98,8 +95,6 @@ def _tf_repositories(): rocm_configure(name = "local_config_rocm") remote_execution_configure(name = "local_config_remote_execution") - _initialize_third_party() - # For windows bazel build # TODO: Remove def file filter when TensorFlow can export symbols properly on Windows. def_file_filter_configure(name = "local_config_def_file_filter") @@ -120,6 +115,10 @@ def _tf_repositories(): armhf_repo = "../armhf_linux_toolchain", ) +# Define all external repositories required by TensorFlow +def _tf_repositories(): + """All external dependencies for TF builds.""" + # To update any of the dependencies bellow: # a) update URL and strip_prefix to the new git commit hash # b) get the sha256 hash of the commit by running: @@ -177,10 +176,6 @@ def _tf_repositories(): ], ) - absl("com_google_absl") - - eigen3(name = "eigen_archive") - tf_http_archive( name = "arm_compiler", build_file = "//:arm_compiler.BUILD", @@ -300,10 +295,6 @@ def _tf_repositories(): ], ) - gemmlowp("gemmlowp") - - farmhash("farmhash_archive") - tf_http_archive( name = "png", build_file = "//third_party:png.BUILD", @@ -538,15 +529,15 @@ def _tf_repositories(): tf_http_archive( name = "com_google_protobuf", patch_file = "//third_party/protobuf:protobuf.patch", - sha256 = "cfcba2df10feec52a84208693937c17a4b5df7775e1635c1e3baffc487b24c9b", - strip_prefix = "protobuf-3.9.2", + sha256 = "9748c0d90e54ea09e5e75fb7fac16edce15d2028d4356f32211cfa3c0e956564", + strip_prefix = "protobuf-3.11.4", system_build_file = "//third_party/systemlibs:protobuf.BUILD", system_link_files = { "//third_party/systemlibs:protobuf.bzl": "protobuf.bzl", }, urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/protocolbuffers/protobuf/archive/v3.9.2.zip", - "https://github.com/protocolbuffers/protobuf/archive/v3.9.2.zip", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/protocolbuffers/protobuf/archive/v3.11.4.zip", + "https://github.com/protocolbuffers/protobuf/archive/v3.11.4.zip", ], ) @@ -1078,8 +1069,17 @@ def workspace(): # those rules rely on the version we require here. check_bazel_version_at_least("1.0.0") - # Load tf_repositories() before loading dependencies for other repository so - # that dependencies like com_google_protobuf won't be overridden. + # Initialize toolchains and platforms. + _tf_toolchains() + + # Import third party repositories according to go/tfbr-thirdparty. + _initialize_third_party() + + # Import all other repositories. This should happen before initializing + # any external repositories, because those come with their own + # dependencies. Those recursive dependencies will only be imported if they + # don't already exist (at least if the external repository macros were + # written according to common practice to query native.existing_rule()). _tf_repositories() # Alias so it can be loaded without assigning to a different symbol to prevent diff --git a/tensorflow/workspace3.bzl b/tensorflow/workspace3.bzl index de0144ba11f..218cc7cd75f 100644 --- a/tensorflow/workspace3.bzl +++ b/tensorflow/workspace3.bzl @@ -15,11 +15,11 @@ def workspace(): http_archive( name = "tf_toolchains", - sha256 = "d60f9637c64829e92dac3f4477a2c45cdddb9946c5da0dd46db97765eb9de08e", - strip_prefix = "toolchains-1.1.5", + sha256 = "abc6d1b705e3a36e029fef4a85009deafbd7dfc725de2a224c2d22f0778ef092", + strip_prefix = "toolchains-1.1.7", urls = [ - "http://mirror.tensorflow.org/github.com/tensorflow/toolchains/archive/v1.1.5.tar.gz", - "https://github.com/tensorflow/toolchains/archive/v1.1.5.tar.gz", + "http://mirror.tensorflow.org/github.com/tensorflow/toolchains/archive/v1.1.7.tar.gz", + "https://github.com/tensorflow/toolchains/archive/v1.1.7.tar.gz", ], ) diff --git a/third_party/absl/workspace.bzl b/third_party/absl/workspace.bzl index e1987f475c3..a6ab27b33bf 100644 --- a/third_party/absl/workspace.bzl +++ b/third_party/absl/workspace.bzl @@ -2,7 +2,7 @@ load("//third_party:repo.bzl", "tf_http_archive") -def repo(name): +def repo(): """Imports absl.""" # Attention: tools parse and update these lines. @@ -10,7 +10,7 @@ def repo(name): ABSL_SHA256 = "f368a8476f4e2e0eccf8a7318b98dafbe30b2600f4e3cf52636e5eb145aba06a" tf_http_archive( - name = name, + name = "com_google_absl", sha256 = ABSL_SHA256, build_file = "//third_party/absl:com_google_absl.BUILD", # TODO: Remove the patch when https://github.com/abseil/abseil-cpp/issues/326 is resolved diff --git a/third_party/eigen3/workspace.bzl b/third_party/eigen3/workspace.bzl index 96dc3fc4f3a..b4d4b9b7686 100644 --- a/third_party/eigen3/workspace.bzl +++ b/third_party/eigen3/workspace.bzl @@ -2,7 +2,7 @@ load("//third_party:repo.bzl", "tf_http_archive") -def repo(name): +def repo(): """Imports Eigen.""" # Attention: tools parse and update these lines. @@ -10,7 +10,7 @@ def repo(name): EIGEN_SHA256 = "6ae281a5a32d0f4185856e790c06f58858ffc16594483281621746ffb74d88a2" tf_http_archive( - name = name, + name = "eigen_archive", build_file = "//third_party/eigen3:eigen_archive.BUILD", sha256 = EIGEN_SHA256, strip_prefix = "eigen-{commit}".format(commit = EIGEN_COMMIT), diff --git a/third_party/farmhash/workspace.bzl b/third_party/farmhash/workspace.bzl index a15ae663cd2..f72fb746949 100644 --- a/third_party/farmhash/workspace.bzl +++ b/third_party/farmhash/workspace.bzl @@ -2,7 +2,7 @@ load("//third_party:repo.bzl", "tf_http_archive") -def repo(name): +def repo(): """Imports farmhash.""" # Attention: tools parse and update these lines. @@ -10,7 +10,7 @@ def repo(name): FARMHASH_SHA256 = "6560547c63e4af82b0f202cb710ceabb3f21347a4b996db565a411da5b17aba0" tf_http_archive( - name = name, + name = "farmhash_archive", build_file = "//third_party/farmhash:farmhash.BUILD", sha256 = FARMHASH_SHA256, strip_prefix = "farmhash-{commit}".format(commit = FARMHASH_COMMIT), diff --git a/third_party/gemmlowp/workspace.bzl b/third_party/gemmlowp/workspace.bzl index 7b71e43e6dd..da53dfcb745 100644 --- a/third_party/gemmlowp/workspace.bzl +++ b/third_party/gemmlowp/workspace.bzl @@ -2,7 +2,7 @@ load("//third_party:repo.bzl", "tf_http_archive") -def repo(name): +def repo(): """Imports gemmlowp.""" # Attention: tools parse and update these lines. @@ -10,7 +10,7 @@ def repo(name): GEMMLOWP_SHA256 = "43146e6f56cb5218a8caaab6b5d1601a083f1f31c06ff474a4378a7d35be9cfb" tf_http_archive( - name = name, + name = "gemmlowp", sha256 = GEMMLOWP_SHA256, strip_prefix = "gemmlowp-{commit}".format(commit = GEMMLOWP_COMMIT), urls = [ diff --git a/third_party/gpus/cuda/BUILD.tpl b/third_party/gpus/cuda/BUILD.tpl index bf4ee62e688..8f783e58d66 100644 --- a/third_party/gpus/cuda/BUILD.tpl +++ b/third_party/gpus/cuda/BUILD.tpl @@ -1,48 +1,57 @@ load(":build_defs.bzl", "cuda_header_library") load("@bazel_skylib//:bzl_library.bzl", "bzl_library") +load("@bazel_skylib//lib:selects.bzl", "selects") licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like package(default_visibility = ["//visibility:public"]) -config_setting( - name = "using_nvcc", - values = { - "define": "using_cuda_nvcc=true", - }, +# Config setting whether TensorFlow is built with CUDA support using clang. +# +# TODO(b/174244321), DEPRECATED: this target will be removed when all users +# have been converted to :is_cuda_enabled (most) or :is_cuda_compiler_clang. +selects.config_setting_group( + name = "using_clang", + match_all = [ + "@local_config_cuda//:is_cuda_enabled", + "@local_config_cuda//:is_cuda_compiler_clang", + ], ) -config_setting( - name = "using_clang", - values = { - "define": "using_cuda_clang=true", - }, +# Config setting whether TensorFlow is built with CUDA support using nvcc. +# +# TODO(b/174244321), DEPRECATED: this target will be removed when all users +# have been converted to :is_cuda_enabled (most) or :is_cuda_compiler_nvcc. +selects.config_setting_group( + name = "using_nvcc", + match_all = [ + "@local_config_cuda//:is_cuda_enabled", + "@local_config_cuda//:is_cuda_compiler_nvcc", + ], ) # Equivalent to using_clang && -c opt. -config_setting( +selects.config_setting_group( name = "using_clang_opt", - values = { - "define": "using_cuda_clang=true", - "compilation_mode": "opt", - }, + match_all = [ + ":using_clang", + ":_opt", + ], ) config_setting( - name = "darwin", - values = {"cpu": "darwin"}, -) - -config_setting( - name = "freebsd", - values = {"cpu": "freebsd"}, + name = "_opt", + values = {"compilation_mode": "opt"}, + visibility = ["//visibility:private"], ) +# Provides CUDA headers for '#include "third_party/gpus/cuda/include/cuda.h"' +# All clients including TensorFlow should use these directives. cuda_header_library( name = "cuda_headers", hdrs = [ "cuda/cuda_config.h", - ":cuda-include" + ":cuda-include", ], include_prefix = "third_party/gpus", includes = [ @@ -54,10 +63,8 @@ cuda_header_library( cc_library( name = "cudart_static", srcs = ["cuda/lib/%{cudart_static_lib}"], - linkopts = select({ - ":freebsd": [], - "//conditions:default": ["-ldl"], - }) + [ + linkopts = [ + "-ldl", "-lpthread", %{cudart_static_linkopt} ], @@ -79,45 +86,45 @@ cuda_header_library( name = "cublas_headers", hdrs = [":cublas-include"], include_prefix = "third_party/gpus/cuda/include", + includes = ["cublas/include"], strip_include_prefix = "cublas/include", deps = [":cuda_headers"], - includes = ["cublas/include"], ) cuda_header_library( name = "cusolver_headers", hdrs = [":cusolver-include"], include_prefix = "third_party/gpus/cuda/include", + includes = ["cusolver/include"], strip_include_prefix = "cusolver/include", deps = [":cuda_headers"], - includes = ["cusolver/include"], ) cuda_header_library( name = "cufft_headers", hdrs = [":cufft-include"], include_prefix = "third_party/gpus/cuda/include", + includes = ["cufft/include"], strip_include_prefix = "cufft/include", deps = [":cuda_headers"], - includes = ["cufft/include"], ) cuda_header_library( name = "cusparse_headers", hdrs = [":cusparse-include"], include_prefix = "third_party/gpus/cuda/include", + includes = ["cusparse/include"], strip_include_prefix = "cusparse/include", deps = [":cuda_headers"], - includes = ["cusparse/include"], ) cuda_header_library( name = "curand_headers", hdrs = [":curand-include"], include_prefix = "third_party/gpus/cuda/include", + includes = ["curand/include"], strip_include_prefix = "curand/include", deps = [":cuda_headers"], - includes = ["curand/include"], ) cc_library( @@ -186,13 +193,13 @@ cc_library( alias( name = "cub_headers", - actual = "%{cub_actual}" + actual = "%{cub_actual}", ) cuda_header_library( name = "cupti_headers", hdrs = [":cuda-extras"], - include_prefix="third_party/gpus", + include_prefix = "third_party/gpus", includes = ["cuda/extras/CUPTI/include/"], deps = [":cuda_headers"], ) @@ -225,8 +232,7 @@ bzl_library( py_library( name = "cuda_config_py", - srcs = ["cuda/cuda_config.py"] + srcs = ["cuda/cuda_config.py"], ) %{copy_rules} - diff --git a/third_party/gpus/cuda/BUILD.windows.tpl b/third_party/gpus/cuda/BUILD.windows.tpl index 65c152e92ca..e34a29e2a65 100644 --- a/third_party/gpus/cuda/BUILD.windows.tpl +++ b/third_party/gpus/cuda/BUILD.windows.tpl @@ -1,41 +1,48 @@ load(":build_defs.bzl", "cuda_header_library") load("@bazel_skylib//:bzl_library.bzl", "bzl_library") +load("@bazel_skylib//lib:selects.bzl", "selects") licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like package(default_visibility = ["//visibility:public"]) -config_setting( - name = "using_nvcc", - values = { - "define": "using_cuda_nvcc=true", - }, +# Config setting whether TensorFlow is built with CUDA support using clang. +# +# TODO(b/174244321), DEPRECATED: this target will be removed when all users +# have been converted to :is_cuda_enabled (most) or :is_cuda_compiler_clang. +selects.config_setting_group( + name = "using_clang", + match_all = [ + "@local_config_cuda//:is_cuda_enabled", + "@local_config_cuda//:is_cuda_compiler_clang", + ], ) -config_setting( - name = "using_clang", - values = { - "define": "using_cuda_clang=true", - }, +# Config setting whether TensorFlow is built with CUDA support using nvcc. +# +# TODO(b/174244321), DEPRECATED: this target will be removed when all users +# have been converted to :is_cuda_enabled (most) or :is_cuda_compiler_nvcc. +selects.config_setting_group( + name = "using_nvcc", + match_all = [ + "@local_config_cuda//:is_cuda_enabled", + "@local_config_cuda//:is_cuda_compiler_nvcc", + ], ) # Equivalent to using_clang && -c opt. -config_setting( +selects.config_setting_group( name = "using_clang_opt", - values = { - "define": "using_cuda_clang=true", - "compilation_mode": "opt", - }, + match_all = [ + ":using_clang", + ":_opt", + ], ) config_setting( - name = "darwin", - values = {"cpu": "darwin"}, -) - -config_setting( - name = "freebsd", - values = {"cpu": "freebsd"}, + name = "_opt", + values = {"compilation_mode": "opt"}, + visibility = ["//visibility:private"], ) # Provides CUDA headers for '#include "third_party/gpus/cuda/include/cuda.h"' @@ -44,7 +51,7 @@ cuda_header_library( name = "cuda_headers", hdrs = [ "cuda/cuda_config.h", - ":cuda-include" + ":cuda-include", ], include_prefix = "third_party/gpus", includes = [ @@ -127,6 +134,12 @@ cc_import( system_provided = 1, ) +cc_import( + name = "cublasLt", + interface_library = "cuda/lib/%{cublasLt_lib}", + system_provided = 1, +) + cc_import( name = "cusolver", interface_library = "cuda/lib/%{cusolver_lib}", @@ -163,6 +176,7 @@ cc_library( name = "cuda", deps = [ ":cublas", + ":cublasLt", ":cuda_headers", ":cudart", ":cudnn", @@ -173,13 +187,13 @@ cc_library( alias( name = "cub_headers", - actual = "%{cub_actual}" + actual = "%{cub_actual}", ) cuda_header_library( name = "cupti_headers", hdrs = [":cuda-extras"], - include_prefix="third_party/gpus", + include_prefix = "third_party/gpus", includes = ["cuda/extras/CUPTI/include/"], deps = [":cuda_headers"], ) @@ -211,7 +225,7 @@ bzl_library( py_library( name = "cuda_config_py", - srcs = ["cuda/cuda_config.py"] + srcs = ["cuda/cuda_config.py"], ) %{copy_rules} diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl index 8b2c4ce773c..bc370456fe8 100644 --- a/third_party/gpus/cuda_configure.bzl +++ b/third_party/gpus/cuda_configure.bzl @@ -1378,6 +1378,8 @@ def _create_remote_cuda_repository(repository_ctx, remote_config_repo): def _cuda_autoconf_impl(repository_ctx): """Implementation of the cuda_autoconf repository rule.""" + build_file = Label("//third_party/gpus:local_config_cuda.BUILD") + if not enable_cuda(repository_ctx): _create_dummy_repository(repository_ctx) elif get_host_environ(repository_ctx, _TF_CUDA_CONFIG_REPO) != None: @@ -1393,6 +1395,8 @@ def _cuda_autoconf_impl(repository_ctx): else: _create_local_cuda_repository(repository_ctx) + repository_ctx.symlink(build_file, "BUILD") + # For @bazel_tools//tools/cpp:windows_cc_configure.bzl _MSVC_ENVVARS = [ "BAZEL_VC", diff --git a/third_party/gpus/local_config_cuda.BUILD b/third_party/gpus/local_config_cuda.BUILD new file mode 100644 index 00000000000..52cfd31a135 --- /dev/null +++ b/third_party/gpus/local_config_cuda.BUILD @@ -0,0 +1,50 @@ +load( + "@bazel_skylib//rules:common_settings.bzl", + "bool_flag", + "string_flag", +) + +package(default_visibility = ["//visibility:public"]) + +# Build flag to enable CUDA support. +# +# Enable with '--@local_config_cuda//:enable_cuda', or indirectly with +# ./configure or '--config=cuda'. +bool_flag( + name = "enable_cuda", + build_setting_default = False, +) + +# Config setting whether CUDA support has been requested. +# +# Enable path: ./configure > --config=cuda (.tf_configure.bazelrc) +# > --//tensorflow:enable_cuda (.bazelrc) > :is_cuda_enabled +config_setting( + name = "is_cuda_enabled", + flag_values = {":enable_cuda": "True"}, +) + +# Build flag to select CUDA compiler. +# +# Set with '--@local_config_cuda//:cuda_compiler=...', or indirectly with +# ./configure, '--config=cuda' or '--config=cuda_clang'. +string_flag( + name = "cuda_compiler", + build_setting_default = "nvcc", + values = [ + "clang", + "nvcc", + ], +) + +# Config setting whether CUDA device code should be compiled with clang. +config_setting( + name = "is_cuda_compiler_clang", + flag_values = {":cuda_compiler": "clang"}, +) + +# Config setting whether CUDA device code should be compiled with nvcc. +config_setting( + name = "is_cuda_compiler_nvcc", + flag_values = {":cuda_compiler": "nvcc"}, +) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index f4dddd974a8..ab9800343e6 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "bf6380c0966b26a2aec7f2072efd0a1a9b6328f2" - LLVM_SHA256 = "b225983215a6ba47cbc344df25c158c60baf8cf2cfe5e990e13e028f7ff6dcdc" + LLVM_COMMIT = "99a6d003edbe97fcb94854547276ffad3382ec1d" + LLVM_SHA256 = "763f933a1db16a857b8fd5a4a70bb8e9a4ffb83be5e9c2cd86a62b4ed2cb652a" tf_http_archive( name = name, diff --git a/third_party/mkl_dnn/mkldnn.BUILD b/third_party/mkl_dnn/mkldnn.BUILD index f15bb5bf995..8123cc6c12a 100644 --- a/third_party/mkl_dnn/mkldnn.BUILD +++ b/third_party/mkl_dnn/mkldnn.BUILD @@ -5,14 +5,6 @@ load( "template_rule", ) -config_setting( - name = "clang_linux_x86_64", - values = { - "cpu": "k8", - "define": "using_clang=true", - }, -) - template_rule( name = "mkldnn_config_h", src = "include/mkldnn_config.h.in", diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD index f6ec1688b9c..982a3a80217 100644 --- a/third_party/mlir/BUILD +++ b/third_party/mlir/BUILD @@ -1167,6 +1167,8 @@ cc_library( ":StandardToLLVM", ":StandardToSPIRV", ":TosaToLinalg", + ":TosaToSCF", + ":TosaToStandard", ":VectorToLLVM", ":VectorToROCDL", ":VectorToSCF", @@ -5112,6 +5114,54 @@ cc_library( ], ) +cc_library( + name = "TosaToSCF", + srcs = glob([ + "lib/Conversion/TosaToSCF/*.cpp", + "lib/Conversion/TosaToSCF/*.h", + ]) + ["lib/Conversion/PassDetail.h"], + hdrs = glob([ + "include/mlir/Conversion/TosaToSCF/*.h", + ]), + includes = [ + "include", + "lib/Conversion/TosaToSCF", + ], + deps = [ + ":ConversionPassIncGen", + ":IR", + ":Pass", + ":SCFDialect", + ":TensorDialect", + ":TosaDialect", + ":Transforms", + ], +) + +cc_library( + name = "TosaToStandard", + srcs = glob([ + "lib/Conversion/TosaToStandard/*.cpp", + "lib/Conversion/TosaToStandard/*.h", + ]) + ["lib/Conversion/PassDetail.h"], + hdrs = glob([ + "include/mlir/Conversion/TosaToStandard/*.h", + ]), + includes = [ + "include", + "lib/Conversion/TosaToStandard", + ], + deps = [ + ":ConversionPassIncGen", + ":IR", + ":Pass", + ":StandardOps", + ":TensorDialect", + ":TosaDialect", + ":Transforms", + ], +) + filegroup( name = "ComplexOpsTdFiles", srcs = [ diff --git a/third_party/mlir/tblgen.bzl b/third_party/mlir/tblgen.bzl index 45d53281952..f303af5228e 100644 --- a/third_party/mlir/tblgen.bzl +++ b/third_party/mlir/tblgen.bzl @@ -1,87 +1,375 @@ """BUILD extensions for MLIR table generation.""" -def gentbl(name, tblgen, td_file, tbl_outs, td_srcs = [], td_includes = [], td_relative_includes = [], strip_include_prefix = None, test = False, **kwargs): - """gentbl() generates tabular code from a table definition file. +TdInfo = provider( + "Holds tablegen files and the dependencies and include paths necessary to" + + " build them.", + fields = { + "transitive_sources": "td files transitively used by this rule.", + "transitive_includes": ( + "include arguments to add to the final tablegen invocation. These" + + " are the absolute directory paths that will be added with '-I'." + ), + }, +) + +# For now we allow anything that provides DefaultInfo to just forward its files. +# In particular, this allows filegroups to be used. This is mostly to ease +# transition. In the future, the TdInfo provider will be required. +# TODO(gcmn): Switch to enforcing TdInfo provider. +def _get_dep_transitive_srcs(dep): + """Extract TdInfo.transitive_sources, falling back to DefaultInfo.files.""" + if TdInfo in dep: + return dep[TdInfo].transitive_sources + return dep[DefaultInfo].files + +def _get_dep_transitive_includes(dep): + """Extract TdInfo.transitive_includes, falling back to an empty depset().""" + if TdInfo in dep: + return dep[TdInfo].transitive_includes + return depset() + +def _get_transitive_srcs(srcs, deps): + """Obtain the source files for a target and its transitive dependencies. Args: - name: The name of the build rule for use in dependencies. + srcs: a list of source files + deps: a list of targets that are direct dependencies + Returns: + a collection of the transitive sources + """ + return depset( + direct = srcs, + transitive = [_get_dep_transitive_srcs(dep) for dep in deps], + ) + +def _get_transitive_includes(includes, deps): + """Obtain the includes paths for a target and its transitive dependencies. + + Args: + includes: a list of include paths + deps: a list of targets that are direct dependencies + Returns: + a collection of the transitive include paths + """ + return depset( + direct = includes, + transitive = [_get_dep_transitive_includes(dep) for dep in deps], + ) + +def _prefix_roots(ctx, includes): + """Map the given includes to be relative to all root directories. + + This will expand them to be relative to all the root directories available + in the execution environment for ctx.run (bin and genfiles in addition to + the normal source root) + """ + prefixed_includes = [] + for include in includes: + prefixed_includes.append(include) + prefixed_includes.append(ctx.genfiles_dir.path + "/" + include) + prefixed_includes.append(ctx.bin_dir.path + "/" + include) + return prefixed_includes + +def _resolve_includes(ctx, includes): + """Resolves include paths to paths relative to the execution root. + + Relative paths are interpreted as relative to the current label's package. + Absolute paths are interpreted as relative to the current label's workspace + root.""" + package = ctx.label.package + workspace_root = ctx.label.workspace_root + workspace_root = workspace_root if workspace_root else "." + resolved_includes = [] + for include in includes: + if not include.startswith("/"): + include = "/" + package + "/" + include + include = workspace_root + include + resolved_includes.extend(_prefix_roots(ctx, [include])) + return resolved_includes + +def _td_library_impl(ctx): + trans_srcs = _get_transitive_srcs(ctx.files.srcs, ctx.attr.deps) + trans_includes = _get_transitive_includes( + _resolve_includes(ctx, ctx.attr.includes), + ctx.attr.deps, + ) + return [ + DefaultInfo(files = trans_srcs), + TdInfo( + transitive_sources = trans_srcs, + transitive_includes = trans_includes, + ), + ] + +td_library = rule( + _td_library_impl, + attrs = { + "srcs": attr.label_list(allow_files = True), + "includes": attr.string_list( + doc = "Include paths to be added to the final tablegen tool" + + " invocation. Relative paths are interpreted as relative to" + + " the current label's package. Absolute paths are" + + " interpreted as relative to the current label's workspace", + ), + # TODO(gcmn): limit to TdInfo providers. + "deps": attr.label_list( + doc = "Dependencies providing tablegen source files and include" + + " paths.", + ), + }, +) + +def _gentbl_rule_impl(ctx): + td_file = ctx.file.td_file + + trans_srcs = _get_transitive_srcs( + ctx.files.td_srcs + [td_file], + ctx.attr.deps, + ) + + # Note that we have two types of includes here. The deprecated ones expanded + # only by "_prefix_roots" are already relative to the execution root, i.e. + # may contain an `external/` prefix if the current workspace + # is not the main workspace (where workspace_name is something configured + # per-project and therefore generally not known). Note that dirname also + # already includes this prefix. The new style of includes have it prepended + # automatically by `_resolve_includes` to avoid BUILD files having to depend + # on project specific configurations and Bazel implementation details. + trans_includes = _get_transitive_includes( + _resolve_includes(ctx, ctx.attr.includes + ["/"]) + + _prefix_roots(ctx, ctx.attr.td_includes + [td_file.dirname]), + ctx.attr.deps, + ) + + args = ctx.actions.args() + args.add_all(ctx.attr.opts) + args.add(td_file) + args.add_all(trans_includes, before_each = "-I") + + args.add("-o", ctx.outputs.out.path) + + ctx.actions.run( + outputs = [ctx.outputs.out], + inputs = trans_srcs, + executable = ctx.executable.tblgen, + arguments = [args], + ) + + return [DefaultInfo()] + +gentbl_rule = rule( + _gentbl_rule_impl, + doc = "Generates tabular code from a table definition file.", + # Match genrule behavior + output_to_genfiles = True, + attrs = { + "tblgen": attr.label( + doc = "The tablegen executable with which to generate `out`.", + executable = True, + cfg = "exec", + ), + "td_file": attr.label( + doc = "The tablegen file to run through `tblgen`.", + allow_single_file = True, + mandatory = True, + ), + "td_srcs": attr.label_list( + doc = "Additional tablegen files included by `td_file`. It is not" + + " necessary to list td_file here (though not an error).", + allow_files = True, + ), + # TODO(gcmn): limit to TdInfo providers. + "deps": attr.label_list( + doc = "Dependencies providing tablegen source files and include" + + " paths.", + ), + "out": attr.output( + doc = "The output file for the tablegen invocation.", + mandatory = True, + ), + "opts": attr.string_list( + doc = "Additional command line options to add to the tablegen" + + " invocation. For include arguments, prefer to use" + + " `includes`.", + ), + "includes": attr.string_list( + doc = "Include paths to be added to the final tablegen tool" + + " invocation. Relative paths are interpreted as relative to" + + " the current label's package. Absolute paths are" + + " interpreted as relative to the current label's workspace." + + " Includes are applied from all roots available in the" + + " execution environment (source, genfiles, and bin" + + " directories). The execution roots themselves and the " + + " directory of td_file are always added.", + ), + "td_includes": attr.string_list( + doc = "Include paths to add to the tablegen invocation. Paths are" + + " interpreted as relative to the current label's workspace" + + " root and applied from all roots available in the" + + " execution environment (source, genfiles, and bin" + + " directories). Deprecated. Use `includes` instead.", + ), + }, +) + +# TODO(gcmn): Figure out how to reduce duplication with _gentbl_rule_impl +def _gentbl_test_impl(ctx): + td_file = ctx.file.td_file + + trans_srcs = _get_transitive_srcs( + ctx.files.td_srcs + [td_file], + ctx.attr.deps, + ) + + # Note that we have two types of includes here. The deprecated ones expanded + # only by "_prefix_roots" are already relative to the execution root, i.e. + # may contain an `external/` prefix if the current workspace + # is not the main workspace (where workspace_name is something configured + # per-project and therefore generally not known). Note that dirname also + # already includes this prefix. The new style of includes have it prepended + # automatically by `_resolve_includes` to avoid BUILD files having to depend + # on project specific configurations and Bazel implementation details. + trans_includes = _get_transitive_includes( + _resolve_includes(ctx, ctx.attr.includes + ["/"]) + + _prefix_roots(ctx, ctx.attr.td_includes + [td_file.dirname]), + ctx.attr.deps, + ) + + test_args = [ctx.executable.tblgen.short_path] + test_args.extend(ctx.attr.opts) + test_args.append(td_file.path) + test_args.extend(["-I " + include for include in trans_includes.to_list()]) + + test_args.extend(["-o", "/dev/null"]) + + ctx.actions.write( + ctx.outputs.executable, + content = " ".join(test_args), + is_executable = True, + ) + + return [DefaultInfo( + runfiles = ctx.runfiles( + [ctx.executable.tblgen], + transitive_files = trans_srcs, + ), + )] + +gentbl_test = rule( + _gentbl_test_impl, + test = True, + doc = "A shell test that tests the given tablegen invocation. Note" + + " that unlike gentbl_rule, this builds and invokes `tblgen` in the" + + " target configuration. Takes all the same arguments as gentbl_rule" + + " except for `out` (as it does not generate any output)", + # Match genrule behavior + output_to_genfiles = True, + attrs = { + "tblgen": attr.label( + doc = "The tablegen executable run in the shell command. Note" + + " that this is built in the target configuration.", + executable = True, + cfg = "target", + ), + "td_file": attr.label( + doc = "See gentbl_rule.td_file", + allow_single_file = True, + mandatory = True, + ), + "td_srcs": attr.label_list( + doc = "See gentbl_rule.td_srcs", + allow_files = True, + ), + "deps": attr.label_list(doc = "See gentbl_rule.deps"), + "opts": attr.string_list(doc = "See gentbl_rule.opts"), + "includes": attr.string_list(doc = "See gentbl_rule.includes"), + "td_includes": attr.string_list(doc = "See gentbl_rule.td_includes"), + }, +) + +def gentbl( + name, + tblgen, + td_file, + tbl_outs, + td_srcs = [], + td_includes = [], + includes = [], + td_relative_includes = [], + deps = [], + strip_include_prefix = None, + test = False, + **kwargs): + """Create multiple tablegen generated files using the same tool and input. + + All generated outputs are bundled in a cc_library rule. + + Args: + name: The name of the generated cc_library rule for use in dependencies. tblgen: The binary used to produce the output. td_file: The primary table definitions file. tbl_outs: A list of tuples (opts, out), where each opts is a string of options passed to tblgen, and the out is the corresponding output file produced. - td_srcs: A list of table definition files included transitively. - td_includes: A list of include paths for relative includes, provided as build targets. - td_relative_includes: A list of include paths for relative includes, provided as relative path. - strip_include_prefix: Attribute to pass through to cc_library. - test: Whether to create a test to invoke the tool too. - **kwargs: Extra keyword arguments to pass to native rules such as cc_library below. + td_srcs: See gentbl_rule.td_srcs + includes: See gentbl_rule.includes + td_includes: See gentbl_rule.td_includes + td_relative_includes: An alias for "includes". Deprecated. Use includes + instead. + deps: See gentbl_rule.deps + strip_include_prefix: attribute to pass through to cc_library. + test: whether to create a shell test that invokes the tool too. + **kwargs: Extra keyword arguments to pass to all generated rules. """ - srcs = [] - srcs += td_srcs - if td_file not in td_srcs: - srcs += [td_file] + for (opts_string, out) in tbl_outs: + # TODO(gcmn): The API of opts as single string is preserved for backward + # compatibility. Change to taking a sequence. + opts = opts_string.split(" ") if opts_string else [] - td_includes_cmd = [ - "-I external/llvm-project/mlir/include -I external/org_tensorflow", - "-I $(GENDIR)/external/llvm-project/mlir/include -I $(GENDIR)/external/org_tensorflow", - ] - for td_include in td_includes: - td_includes_cmd += [ - "-I%s" % td_include, - "-I$(GENDIR)/%s" % td_include, - ] - for td_include in td_relative_includes: - td_includes_cmd += [ - "-I%s/%s -Iexternal/org_tensorflow/%s/%s" % (native.package_name(), td_include, native.package_name(), td_include), - "-I$(GENDIR)/%s/%s" % (native.package_name(), td_include), - ] + # Filter out empty options + opts = [opt for opt in opts if opt] - local_inc = "-I $$(dirname $(location %s))" % td_file - - if test: - # Rule to generate shell script to invoke tblgen. This generates a very - # bare shell file which the sh_test uses. - native.genrule( - name = "%s_genrule_sh" % name, - srcs = srcs, - outs = ["%s.gen.sh" % name], - cmd = ("echo \"\\$$1\" %s \\$${@:2} -o /dev/null > $@" % local_inc), - executable = 1, + first_opt = opts[0] if opts else "" + rule_suffix = "_{}_{}".format( + first_opt.replace("-", "_").replace("=", "_"), + str(hash(opts_string)), + ) + gentbl_name = "%s_%s_genrule" % (name, rule_suffix) + gentbl_rule( + name = gentbl_name, + td_file = td_file, + tblgen = tblgen, + opts = opts, + td_srcs = td_srcs, + deps = deps, + includes = includes + td_relative_includes, + # TODO(gcmn): Update callers to td_library and explicit includes and + # drop this hardcoded include. + td_includes = td_includes + [ + "external/llvm-project/mlir/include", + ], + out = out, **kwargs ) - - for (opts, out) in tbl_outs: - # All arguments to generate the output except output destination. - base_args = [ - "$(location %s)" % tblgen, - "%s" % opts, - "$(location %s)" % td_file, - "-I$(GENDIR)", - ] + td_includes_cmd - first_opt = opts.split(" ", 1)[0] - rule_suffix = "_{}_{}".format(first_opt.replace("-", "_").replace("=", "_"), str(hash(opts))) - - # Rule to generate code using generated shell script. - native.genrule( - name = "%s_%s_genrule" % (name, rule_suffix), - srcs = srcs, - outs = [out], - tools = [tblgen], - message = "Generating code from table: %s" % td_file, - cmd = (" ".join(base_args) + " %s -o $@" % local_inc), - **kwargs - ) - - # Optionally generate rule to test tblgen invocation. - # Disable these on windows, because $(location ...) does not seem to - # work as expected on windows. if test: - native.sh_test( - name = "%s_%s_genrule_test" % (name, rule_suffix), - srcs = ["%s.gen.sh" % name], - args = base_args, - data = srcs + [tblgen], + # Also run the generator in the target configuration as a test. This + # means it gets run with asserts and sanitizers and such when they + # are enabled and is counted in coverage. + gentbl_test( + name = "%s_test" % (gentbl_name,), + td_file = td_file, + tblgen = tblgen, + opts = opts, + td_srcs = td_srcs, + deps = deps, + includes = includes + td_relative_includes, + # TODO(gcmn): Update callers to td_library and explicit includes + # and drop this hardcoded include. + td_includes = td_includes + [ + "external/llvm-project/mlir/include", + ], + # Shell files not executable on Windows. + # TODO(gcmn): Support windows. tags = ["no_windows"], **kwargs ) @@ -91,7 +379,8 @@ def gentbl(name, tblgen, td_file, tbl_outs, td_srcs = [], td_includes = [], td_r hdrs = [f for (opts, f) in tbl_outs if opts not in skip_opts] native.cc_library( name = name, - # include_prefix does not apply to textual_hdrs. + # strip_include_prefix does not apply to textual_hdrs. + # https://github.com/bazelbuild/bazel/issues/12424 hdrs = hdrs if strip_include_prefix else [], strip_include_prefix = strip_include_prefix, textual_hdrs = hdrs, diff --git a/third_party/ortools/BUILD b/third_party/ortools/BUILD deleted file mode 100644 index 2f5d02becb9..00000000000 --- a/third_party/ortools/BUILD +++ /dev/null @@ -1 +0,0 @@ -# Dummy BUILD file to make this directory a package. diff --git a/third_party/ortools/BUILD.bazel b/third_party/ortools/BUILD.bazel deleted file mode 100644 index 61191e3d271..00000000000 --- a/third_party/ortools/BUILD.bazel +++ /dev/null @@ -1,13 +0,0 @@ -# Google's software suite for combinatorial optimization - -licenses(["notice"]) # Apache2 license - -exports_files(["LICENSE-2.0.txt"]) - -native.cc_library( - name = "linear_solver_glop", - deps = [ - "@ortools_archive//linear_solver:linear_solver_glop", - ], - visibility = ["//visibility:public"], -) diff --git a/third_party/ortools/workspace.bzl b/third_party/ortools/workspace.bzl deleted file mode 100644 index 42eb122f7a6..00000000000 --- a/third_party/ortools/workspace.bzl +++ /dev/null @@ -1,15 +0,0 @@ -"""loads the aws library, used by TF.""" - -load("//third_party:repo.bzl", "tf_http_archive") - -def repo(): - tf_http_archive( - name = "ortools_archive", - urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/or-tools/archive/v6.7.2.tar.gz", - "https://github.com/google/or-tools/archive/v6.7.2.tar.gz", - ], - sha256 = "d025a95f78b5fc5eaa4da5f395f23d11c23cf7dbd5069f1f627f002de87b86b9", - strip_prefix = "or-tools-6.7.2/src", - build_file = "//third_party/ortools:BUILD.bazel", - ) diff --git a/third_party/protobuf/protobuf.patch b/third_party/protobuf/protobuf.patch index 8ce4a843759..98f31908b8b 100644 --- a/third_party/protobuf/protobuf.patch +++ b/third_party/protobuf/protobuf.patch @@ -1,25 +1,25 @@ diff --git a/BUILD b/BUILD -index dbae719ff..87dc38470 100644 +index b32a1f161..9fa411783 100644 --- a/BUILD +++ b/BUILD -@@ -23,7 +23,7 @@ config_setting( +@@ -51,7 +51,7 @@ GTEST_MAIN = select({ # ZLIB configuration ################################################################################ - + -ZLIB_DEPS = ["@zlib//:zlib"] +ZLIB_DEPS = ["@zlib"] - + ################################################################################ # Protobuf Runtime Library -@@ -143,6 +143,7 @@ cc_library( +@@ -198,6 +198,7 @@ cc_library( copts = COPTS, includes = ["src/"], linkopts = LINK_OPTS, + alwayslink = 1, visibility = ["//visibility:public"], ) - -@@ -213,6 +214,7 @@ cc_library( + +@@ -270,6 +271,7 @@ cc_library( copts = COPTS, includes = ["src/"], linkopts = LINK_OPTS, @@ -27,17 +27,25 @@ index dbae719ff..87dc38470 100644 visibility = ["//visibility:public"], deps = [":protobuf_lite"] + PROTOBUF_DEPS, ) +@@ -849,7 +851,7 @@ py_proto_library( + py_extra_srcs = glob(["python/**/__init__.py"]), + py_libs = [ + ":python_srcs", +- "@six//:six", ++ "//external:six", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], diff --git a/protobuf.bzl b/protobuf.bzl -index e0653321f..253d9cbb5 100644 +index e2821f5b5..9f6a13ef5 100644 --- a/protobuf.bzl +++ b/protobuf.bzl -@@ -84,7 +84,9 @@ def _proto_gen_impl(ctx): - +@@ -88,6 +88,8 @@ def _proto_gen_impl(ctx): for dep in ctx.attr.deps: import_flags += dep.proto.import_flags deps += dep.proto.deps + import_flags = depset(import_flags).to_list() + deps = depset(deps).to_list() - + if not ctx.attr.gen_cc and not ctx.attr.gen_py and not ctx.executable.plugin: - return struct( \ No newline at end of file + return struct( diff --git a/third_party/toolchains/clang6/BUILD b/third_party/toolchains/clang6/BUILD deleted file mode 100644 index ffd0fb0cdc5..00000000000 --- a/third_party/toolchains/clang6/BUILD +++ /dev/null @@ -1 +0,0 @@ -package(default_visibility = ["//visibility:public"]) diff --git a/third_party/toolchains/clang6/CROSSTOOL.tpl b/third_party/toolchains/clang6/CROSSTOOL.tpl deleted file mode 100644 index 16e36a1cd1f..00000000000 --- a/third_party/toolchains/clang6/CROSSTOOL.tpl +++ /dev/null @@ -1,583 +0,0 @@ -major_version: "v1" -minor_version: "llvm:6.0.0" -default_target_cpu: "k8" - -default_toolchain { - cpu: "k8" - toolchain_identifier: "k8-clang-6.0-cxx-4.8-linux-gnu" -} - -toolchain { - compiler: "clang6" # bazel build --compiler=clang6 - target_cpu: "k8" # bazel build --cpu=k8 - target_libc: "GLIBC_2.19" # bazel build --glibc=GLIBC_2.19 - - abi_libc_version: "2.19" - abi_version: "gcc-4.8-cxx11" - builtin_sysroot: "" - cc_target_os: "linux-gnu" - default_python_version: "python2.7" - dynamic_runtimes_filegroup: "dynamic-runtime-libs-k8" - host_system_name: "x86_64-unknown-linux-gnu" - needsPic: true - static_runtimes_filegroup: "static-runtime-libs-k8" - supports_embedded_runtimes: true - supports_fission: true - supports_gold_linker: true - supports_incremental_linker: true - supports_interface_shared_objects: true - supports_normalizing_ar: true - supports_start_end_lib: true - supports_thin_archives: true - target_system_name: "x86_64-unknown-linux-gnu" - toolchain_identifier: "k8-clang-6.0-cxx-4.8-linux-gnu" - - tool_path { name: "ar" path: "%package(@local_config_clang6//clang6)%/llvm/bin/llvm-ar" } - tool_path { name: "as" path: "%package(@local_config_clang6//clang6)%/llvm/bin/llvm-as" } - tool_path { name: "compat-ld" path: "%package(@local_config_clang6//clang6)%/llvm/bin/ld.lld" } - tool_path { name: "cpp" path: "%package(@local_config_clang6//clang6)%/llvm/bin/llvm-cpp" } - tool_path { name: "dwp" path: "%package(@local_config_clang6//clang6)%/llvm/bin/llvm-dwp" } - tool_path { name: "gcc" path: "%package(@local_config_clang6//clang6)%/llvm/bin/clang" } - tool_path { name: "gcov" path: "%package(@local_config_clang6//clang6)%/llvm/bin/llvm-cov" } - tool_path { name: "ld" path: "%package(@local_config_clang6//clang6)%/llvm/bin/ld.lld" } - tool_path { name: "llvm-profdata" path: "%package(@local_config_clang6//clang6)%/llvm/bin/llvm-profdata" } - tool_path { name: "nm" path: "%package(@local_config_clang6//clang6)%/llvm/bin/llvm-nm" } - tool_path { name: "objcopy" path: "%package(@local_config_clang6//clang6)%/llvm/bin/llvm-objcopy" } - tool_path { name: "objdump" path: "%package(@local_config_clang6//clang6)%/sbin/objdump" } - tool_path { name: "strip" path: "%package(@local_config_clang6//clang6)%/sbin/strip" } - - unfiltered_cxx_flag: "-no-canonical-prefixes" - - # Make C++ compilation deterministic. Use linkstamping instead of these - # compiler symbols. - unfiltered_cxx_flag: "-Wno-builtin-macro-redefined" - unfiltered_cxx_flag: "-D__DATE__=\"redacted\"" - unfiltered_cxx_flag: "-D__TIMESTAMP__=\"redacted\"" - unfiltered_cxx_flag: "-D__TIME__=\"redacted\"" - - objcopy_embed_flag: "-I" - objcopy_embed_flag: "binary" - - # This action_config makes features flags propagate - # to CC_FLAGS for genrules, and eventually skylark. - action_config { - action_name: "cc-flags-make-variable" - config_name: "cc-flags-make-variable" - } - - # Security hardening on by default. - # Conservative choice; -D_FORTIFY_SOURCE=2 may be unsafe in some cases. - # We need to undef it before redefining it as some distributions now have - # it enabled by default. - compiler_flag: "-U_FORTIFY_SOURCE" - compiler_flag: "-D_FORTIFY_SOURCE=1" - compiler_flag: "-fstack-protector" - linker_flag: "-Wl,-z,relro,-z,now" - - # TODO(b/151234342): Clean up the following options. - # This adds a little bit more durability to our Clang build. - # - # Folks who do maintenance work on TF Bazel Clang should consider - # commenting out these lines, while doing that work, to gain a better - # understanding of what the intersection of support looks like between GCC - # and Clang. Please note that Bazel does not support -Xclang-only. - compiler_flag: "-Wno-unknown-warning-option" - compiler_flag: "-Wno-unused-command-line-argument" - compiler_flag: "-Wno-ignored-optimization-argument" - - #### Common compiler options. #### - compiler_flag: "-D_REENTRANT" - compiler_flag: "-D__STDC_FORMAT_MACROS" - compiler_flag: "-DSUPPRESS_USE_FILE_OFFSET64" - compiler_flag: "-Wall" - compiler_flag: "-Wformat-security" - compiler_flag: "-Wframe-larger-than=16384" - compiler_flag: "-Wno-char-subscripts" - compiler_flag: "-Wno-error=deprecated-declarations" - compiler_flag: "-Wno-uninitialized" - compiler_flag: "-Wno-sign-compare" - compiler_flag: "-Wno-strict-overflow" - compiler_flag: "-Wno-unused-function" - compiler_flag: "-fdiagnostics-show-option" - compiler_flag: "-fmessage-length=0" - compiler_flag: "-fno-exceptions" - compiler_flag: "-fno-omit-frame-pointer" - compiler_flag: "-fno-strict-aliasing" - compiler_flag: "-fno-use-init-array" - compiler_flag: "-funsigned-char" - compiler_flag: "-gmlt" - cxx_flag: "-Wno-deprecated" - cxx_flag: "-Wno-invalid-offsetof" # Needed for protobuf code (2017-11-07) - cxx_flag: "-fshow-overloads=best" - compiler_flag: "-Wthread-safety-analysis" - - # Python extensions unfortunately make this go wild. - compiler_flag: "-Wno-writable-strings" - - # GCC's warning produces too many false positives: - cxx_flag: "-Woverloaded-virtual" - cxx_flag: "-Wnon-virtual-dtor" - - # Enable coloring even if there's no attached terminal. Bazel removes the - # escape sequences if --nocolor is specified. This isn't supported by gcc - # on Ubuntu 14.04. - compiler_flag: "-fcolor-diagnostics" - - # Disable some broken warnings from Clang. - compiler_flag: "-Wno-ambiguous-member-template" - compiler_flag: "-Wno-pointer-sign" - - # These warnings have a low signal to noise ratio. - compiler_flag: "-Wno-reserved-user-defined-literal" - compiler_flag: "-Wno-return-type-c-linkage" - compiler_flag: "-Wno-invalid-source-encoding" - - # Per default we switch off any layering related warnings. - compiler_flag: "-Wno-private-header" - - # Clang-specific warnings that we explicitly enable for TensorFlow. Some of - # these aren't on by default, or under -Wall, or are subsets of warnings - # turned off above. - compiler_flag: "-Wfloat-overflow-conversion" - compiler_flag: "-Wfloat-zero-conversion" - compiler_flag: "-Wfor-loop-analysis" - compiler_flag: "-Wgnu-redeclared-enum" - compiler_flag: "-Winfinite-recursion" - compiler_flag: "-Wliteral-conversion" - compiler_flag: "-Wself-assign" - compiler_flag: "-Wstring-conversion" - compiler_flag: "-Wtautological-overlap-compare" - compiler_flag: "-Wunused-comparison" - compiler_flag: "-Wvla" - cxx_flag: "-Wdeprecated-increment-bool" - - # Clang code-generation flags for performance optimization. - compiler_flag: "-faligned-allocation" - compiler_flag: "-fnew-alignment=8" - - # Clang defaults to C99 while GCC defaults to C89. GCC plugins are written in - # C89 and don't have a BUILD rule we could add a copts flag to. - gcc_plugin_compiler_flag: "-std=gnu89" - - compilation_mode_flags { - mode: FASTBUILD - } - - compilation_mode_flags { - mode: DBG - compiler_flag: "-g" - } - - compilation_mode_flags { - mode: OPT - compiler_flag: "-g0" - compiler_flag: "-fdebug-types-section" - compiler_flag: "-DNDEBUG" - compiler_flag: "-fno-split-dwarf-inlining" - compiler_flag: "-Os" - compiler_flag: "-fexperimental-new-pass-manager" - compiler_flag: "-fdebug-info-for-profiling" - compiler_flag: "-ffunction-sections" - compiler_flag: "-fdata-sections" - linker_flag: "-Wl,--gc-sections" - linker_flag: "-Wl,-z,relro,-z,now" - } - - # Features indicating whether this is a host compile or not. Exactly one of - # these will be implicitly provided by bazel. - feature { name: "host" } - feature { name: "nonhost" } - - # Features indicating which compiler will be used for code generation. - feature { - name: "llvm_codegen" - provides: "codegen" - enabled: true - } - - # Features for compilation modes. Exactly one of these will be implicitly - # provided by bazel. - feature { name: "fastbuild" } - feature { name: "dbg" } - feature { name: "opt" } - - # Features controlling the C++ language mode. - feature { - name: "c++11" - provides: "c++std" - flag_set { - action: "c++-compile" - action: "c++-header-parsing" - action: "c++-header-preprocessing" - action: "c++-module-compile" - action: "linkstamp-compile" - flag_group { - flag: "-nostdinc++" - flag: "-std=c++11" - flag: "-Wc++14-extensions" - flag: "-Wc++2a-extensions" - flag: "-Wno-binary-literal" - } - } - } - feature { - name: "c++14" - provides: "c++std" - flag_set { - action: "c++-compile" - action: "c++-header-parsing" - action: "c++-header-preprocessing" - action: "c++-module-compile" - action: "linkstamp-compile" - flag_group { - flag: "-nostdinc++" - flag: "-std=c++14" - flag: "-Wc++11-compat" - flag: "-Wno-c++11-compat-binary-literal" - flag: "-Wc++2a-extensions" - } - } - } - feature { - name: "c++17" - provides: "c++std" - flag_set { - action: "c++-compile" - action: "c++-header-parsing" - action: "c++-header-preprocessing" - action: "c++-module-compile" - action: "linkstamp-compile" - flag_group { - flag: "-nostdinc++" - flag: "-std=c++17" - flag: "-Wc++11-compat" - flag: "-Wno-c++11-compat-binary-literal" - flag: "-Wc++2a-extensions" - } - } - } - feature { - name: "c++2a" - provides: "c++std" - flag_set { - action: "c++-compile" - action: "c++-header-parsing" - action: "c++-header-preprocessing" - action: "c++-module-compile" - action: "linkstamp-compile" - flag_group { - flag: "-nostdinc++" - flag: "-std=c++2a" - flag: "-Wc++11-compat" - flag: "-Wno-c++11-compat-binary-literal" - } - } - } - feature { - name: "c++default" - enabled: true - flag_set { - # Provide the c++11 flags if no standard is selected - with_feature { - not_feature: "c++11" - not_feature: "c++14" - not_feature: "c++17" - not_feature: "c++2a" - } - action: "c++-compile" - action: "c++-header-parsing" - action: "c++-header-preprocessing" - action: "c++-module-compile" - action: "linkstamp-compile" - flag_group { - flag: "-nostdinc++" - flag: "-std=c++11" - flag: "-Wc++14-extensions" - flag: "-Wc++2a-extensions" - flag: "-Wno-binary-literal" - } - } - } - - feature { - name: "use_compiler_rt" - requires { feature: "llvm_codegen" } - # TODO(saugustine): At the moment, "use_compiler_rt" also - # requires "linking_mode_flags { mode: FULLY_STATIC" ... }, - # but that isn't a feature. We should probably convert it. - flag_set { - action: "c++-link" - action: "c++-link-interface-dynamic-library" - action: "c++-link-dynamic-library" - action: "c++-link-executable" - # "link" is a misnomer for these actions. They are really just - # invocations of ar. - #action: "c++-link-pic-static-library" - #action: "c++-link-static-library" - #action: "c++-link-alwayslink-static-library" - #action: "c++-link-pic-static-library" - #action: "c++-link-alwayslink-pic-static-library" - flag_group { - flag: "-rtlib=compiler-rt" - flag: "-lunwind" - } - } - } - - feature { - name: "pie" - flag_set { - action: "assemble" - action: "preprocess-assemble" - action: "c-compile" - action: "c++-compile" - action: "c++-header-parsing" - action: "c++-header-preprocessing" - action: "c++-module-compile" - action: "c++-module-codegen" - action: "cc-flags-make-variable" - action: "lto-backend" - action: "linkstamp-compile" - flag_group { - flag: "-mpie-copy-relocations" - flag: "-fPIE" - } - } - flag_set { - action: "cc-flags-make-variable" - action: "c++-link-executable" - flag_group { - flag: "-pie" - } - } - } - - # Pic must appear after pie, because pic may need to override pie, and bazel - # turns it on selectively. These don't interact with other options. - # - # TODO: In practice, normal vs pic vs pie is a ternary mode. We should - # implement it that way. This will require changes to bazel, which only - # calculates whether or not pic is needed, not pie. - # - # NOTE: Bazel might make this all a moot point. - feature { - name: "pic" - flag_set { - action: "assemble" - action: "preprocess-assemble" - action: "c-compile" - action: "c++-compile" - action: "c++-module-codegen" - action: "c++-module-compile" - action: "linkstamp-compile" - expand_if_all_available: "pic" - flag_group { - flag: "-fPIC" - } - } - } - - feature { - name: "gold" - enabled: true - flag_set { - action: "c++-link-executable" - action: "c++-link-dynamic-library" - action: "c++-link-interface-dynamic-library" - flag_group { - expand_if_none_available: "lto" - flag: "-fuse-ld=gold" - } - } - } - - # This is great if you want linking TensorFlow to take ten minutes. - feature { - name: "lto" - requires { feature: "nonhost" } - flag_set { - action: "c-compile" - action: "c++-compile" - flag_group { - flag: "-flto=thin" - } - } - flag_set { - action: "c++-link-executable" - action: "c++-link-dynamic-library" - action: "c++-link-interface-dynamic-library" - flag_group { - flag: "-flto=thin" - } - } - } - - feature { - name: "parse_headers" - flag_set { - action: "c++-header-parsing" - flag_group { - flag: "-xc++-header" - flag: "-fsyntax-only" - } - } - } - - feature { - name: "preprocess_headers" - flag_set { - action: "c++-header-preprocessing" - flag_group { - flag: "-xc++" - flag: "-E" - } - } - } - - feature { - name: "per_object_debug_info" - flag_set { - action: "c-compile" - action: "c++-compile" - action: "c++-module-codegen" - action: "assemble" - action: "preprocess-assemble" - action: "lto-backend" - flag_group { - flag: "-gsplit-dwarf" - flag: "-ggnu-pubnames" - } - } - flag_set { - action: "c++-link-executable" - action: "c++-link-dynamic-library" - action: "c++-link-interface-dynamic-library" - flag_group { - expand_if_all_available: "is_using_fission" - flag: "-Wl,--gdb-index" - } - } - } - - feature { - name: "xray" - requires { - feature: "llvm_codegen" - feature: "nonhost" - } - flag_set { - action: "c-compile" - action: "c++-compile" - action: "c++-header-parsing" - action: "c++-header-preprocessing" - action: "c++-module-compile" - action: "c++-link-interface-dynamic-library" - action: "c++-link-dynamic-library" - action: "c++-link-executable" - flag_group { - flag: "-fxray-instrument" - } - } - } - - feature { - name: "minimal_ubsan" - requires { feature: "llvm_codegen" } - flag_set { - action: "c-compile" - action: "c++-compile" - action: "c++-header-parsing" - action: "c++-header-preprocessing" - action: "c++-module-compile" - action: "c++-module-codegen" - flag_group { - flag: "-fsanitize=return,returns-nonnull-attribute,vla-bound,unreachable,float-cast-overflow" - flag: "-fsanitize-trap=all" - flag: "-DUNDEFINED_BEHAVIOR_SANITIZER" - } - } - } - - feature { - name: "minimal_ubsan_enabled_by_default" - requires { - feature: "llvm_codegen" - feature: "fastbuild" - } - enabled: true - implies: "minimal_ubsan" - } - - cxx_builtin_include_directory: "%package(@local_config_clang6//clang6)%/llvm/lib/clang/6.0.0/include" - cxx_builtin_include_directory: "/usr/include" - - unfiltered_cxx_flag: "-cxx-isystem" - unfiltered_cxx_flag: "/usr/include/c++/4.8" - unfiltered_cxx_flag: "-cxx-isystem" - unfiltered_cxx_flag: "/usr/include/x86_64-linux-gnu/c++/4.8" - unfiltered_cxx_flag: "-isystem" - unfiltered_cxx_flag: "%package(@local_config_clang6//clang6)%/llvm/lib/clang/6.0.0/include" - unfiltered_cxx_flag: "-isystem" - unfiltered_cxx_flag: "/usr/include/x86_64-linux-gnu" - unfiltered_cxx_flag: "-isystem" - unfiltered_cxx_flag: "/usr/include" - - linker_flag: "-Wl,--build-id=md5" - linker_flag: "-Wl,--fatal-warnings" - linker_flag: "-Wl,--hash-style=gnu" - linker_flag: "-no-canonical-prefixes" - linker_flag: "--target=x86_64-unknown-linux-gnu" - - linker_flag: "-L/usr/lib/gcc/x86_64-linux-gnu/4.8" - - # This is the minimum x86 architecture TensorFlow supports. - compiler_flag: "-m64" - - # These are for Linux. - ld_embed_flag: "-melf_x86_64" - linker_flag: "-Wl,--eh-frame-hdr" - linker_flag: "-Wl,-z,max-page-size=0x1000" - - # Google never uses the stack like a heap, e.g. alloca(), because tcmalloc - # and jemalloc are so fast. However copts=["$(STACK_FRAME_UNLIMITED)"] can be - # specified when that can't be the case. - make_variable { - name: "STACK_FRAME_UNLIMITED" - value: "-Wframe-larger-than=100000000 -Wno-vla" - } - - # These flags are for folks who build C/C++ code inside genrules. - make_variable { - name: "CC_FLAGS" - value: "-no-canonical-prefixes --target=x86_64-unknown-linux-gnu -fno-omit-frame-pointer -fno-tree-vrp -msse3" - } - - feature { - name: "copts" - flag_set { - expand_if_all_available: "copts" - action: "assemble" - action: "preprocess-assemble" - action: "c-compile" - action: "c++-compile" - action: "c++-header-parsing" - action: "c++-header-preprocessing" - action: "c++-module-compile" - action: "c++-module-codegen" - action: "lto-backend" - flag_group { - iterate_over: "copts" - flag: "%{copts}" - } - } - } - - # Please do not statically link libstdc++. This would probably lead to a lot - # of bloat since OpKernels need to use linkstatic=1 because b/27630669 and - # it could cause memory leaks since Python uses dlopen() on our libraries: - # https://stackoverflow.com/a/35015415 - linker_flag: "-lstdc++" - linker_flag: "-lm" - linker_flag: "-lpthread" - linker_flag: "-l:/lib/x86_64-linux-gnu/libc-2.19.so" -} diff --git a/third_party/toolchains/clang6/README.md b/third_party/toolchains/clang6/README.md deleted file mode 100644 index 0c6be25a0ed..00000000000 --- a/third_party/toolchains/clang6/README.md +++ /dev/null @@ -1,101 +0,0 @@ -# TensorFlow Bazel Clang - -This is a specialized toolchain that uses an old Debian with a new Clang that -can cross compile to any x86_64 microarchitecture. It's intended to build Linux -binaries that only require the following ABIs: - -- GLIBC_2.18 -- CXXABI_1.3.7 (GCC 4.8.3) -- GCC_4.2.0 - -Which are available on at least the following Linux platforms: - -- Ubuntu 14+ -- CentOS 7+ -- Debian 8+ -- SuSE 13.2+ -- Mint 17.3+ -- Manjaro 0.8.11 - -# System Install - -On Debian 8 (Jessie) Clang 6.0 can be installed as follows: - -```sh -cat >>/etc/apt/sources.list <<'EOF' -deb http://apt.llvm.org/jessie/ llvm-toolchain-jessie main -deb-src http://apt.llvm.org/jessie/ llvm-toolchain-jessie main -EOF -wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key | apt-key add - -apt-key fingerprint |& grep '6084 F3CF 814B 57C1 CF12 EFD5 15CF 4D18 AF4F 7421' -apt-get update -apt-get install clang lld -``` - -# Bazel Configuration - -This toolchain can compile TensorFlow in 2m30s on a 96-core Skylake GCE VM if -the following `.bazelrc` settings are added: - -``` -startup --host_jvm_args=-Xmx30G -startup --host_jvm_args=-Xms30G -startup --host_jvm_args=-XX:MaxNewSize=3g -startup --host_jvm_args=-XX:-UseAdaptiveSizePolicy -startup --host_jvm_args=-XX:+UseConcMarkSweepGC -startup --host_jvm_args=-XX:TargetSurvivorRatio=70 -startup --host_jvm_args=-XX:SurvivorRatio=6 -startup --host_jvm_args=-XX:+UseCMSInitiatingOccupancyOnly -startup --host_jvm_args=-XX:CMSFullGCsBeforeCompaction=1 -startup --host_jvm_args=-XX:CMSInitiatingOccupancyFraction=75 - -build --jobs=100 -build --local_resources=200000,100,100 -build --crosstool_top=@local_config_clang6//clang6 -build --noexperimental_check_output_files -build --nostamp -build --config=opt -build --noexperimental_check_output_files -build --copt=-march=native -build --host_copt=-march=native -``` - -# x86_64 Microarchitectures - -## Intel CPU Line - -- 2003 P6 M SSE SSE2 -- 2004 prescott SSE3 SSSE3 (-march=prescott) -- 2006 core X64 SSE4.1 (only on 45nm variety) (-march=core2) -- 2008 nehalem SSE4.2 VT-x VT-d (-march=nehalem) -- 2010 westmere CLMUL AES (-march=westmere) -- 2012 sandybridge AVX TXT (-march=sandybridge) -- 2012 ivybridge F16C MOVBE (-march=ivybridge) -- 2013 haswell AVX2 TSX BMI2 FMA (-march=haswell) -- 2014 broadwell RDSEED ADCX PREFETCHW (-march=broadwell - works on trusty gcc4.9) -- 2015 skylake SGX ADX MPX AVX-512[xeon-only] (-march=skylake / -march=skylake-avx512 - needs gcc7) -- 2018 cannonlake AVX-512 SHA (-march=cannonlake - needs clang5) - -## Intel Low Power CPU Line - -- 2013 silvermont SSE4.1 SSE4.2 VT-x (-march=silvermont) -- 2016 goldmont SHA (-march=goldmont - needs clang5) - -## AMD CPU Line - -- 2003 k8 SSE SSE2 (-march=k8) -- 2005 k8 (Venus) SSE3 (-march=k8-sse3) -- 2008 barcelona SSE4a?! (-march=barcelona) -- 2011 bulldozer SSE4.1 SSE4.2 CLMUL AVX AES FMA4?! (-march=bdver1) -- 2011 piledriver FMA (-march=bdver2) -- 2015 excavator AVX2 BMI2 MOVBE (-march=bdver4) - -## Google Compute Engine Supported CPUs - -- 2012 sandybridge 2.6gHz -march=sandybridge -- 2012 ivybridge 2.5gHz -march=ivybridge -- 2013 haswell 2.3gHz -march=haswell -- 2014 broadwell 2.2gHz -march=broadwell -- 2015 skylake 2.0gHz -march=skylake-avx512 - -See: diff --git a/third_party/toolchains/clang6/clang.BUILD b/third_party/toolchains/clang6/clang.BUILD deleted file mode 100644 index 094d69271a9..00000000000 --- a/third_party/toolchains/clang6/clang.BUILD +++ /dev/null @@ -1,160 +0,0 @@ -package(default_visibility = ["//visibility:public"]) - -# Please note that the output of these tools is unencumbered. -licenses(["restricted"]) # NCSA, GPLv3 (e.g. gold) - -filegroup( - name = "ar", - srcs = ["llvm/bin/llvm-ar"], - output_licenses = ["unencumbered"], -) - -filegroup( - name = "as", - srcs = ["llvm/bin/llvm-as"], - output_licenses = ["unencumbered"], -) - -filegroup( - name = "cpp", - srcs = ["llvm/bin/llvm-cpp"], - output_licenses = ["unencumbered"], -) - -filegroup( - name = "dwp", - srcs = ["llvm/bin/llvm-dwp"], - output_licenses = ["unencumbered"], -) - -filegroup( - name = "gcc", - srcs = ["llvm/bin/clang"], - output_licenses = ["unencumbered"], -) - -filegroup( - name = "gcov", - srcs = ["llvm/bin/llvm-cov"], - output_licenses = ["unencumbered"], -) - -filegroup( - name = "ld", - srcs = ["llvm/bin/ld.lld"], - output_licenses = ["unencumbered"], -) - -filegroup( - name = "nm", - srcs = ["llvm/bin/llvm-nm"], - output_licenses = ["unencumbered"], -) - -filegroup( - name = "objcopy", - srcs = ["llvm/bin/llvm-objcopy"], - output_licenses = ["unencumbered"], -) - -filegroup( - name = "objdump", - srcs = ["llvm/bin/llvm-objdump"], - output_licenses = ["unencumbered"], -) - -filegroup( - name = "profdata", - srcs = ["llvm/bin/llvm-profdata"], - output_licenses = ["unencumbered"], -) - -filegroup( - name = "strip", - srcs = ["sbin/strip"], - output_licenses = ["unencumbered"], -) - -filegroup( - name = "xray", - srcs = ["llvm/bin/llvm-xray"], - output_licenses = ["unencumbered"], -) - -filegroup( - name = "includes", - srcs = glob(["llvm/lib/clang/6.0.0/include/**"]), - output_licenses = ["unencumbered"], -) - -filegroup( - name = "libraries", - srcs = glob([ - "lib/*.*", - "lib/clang/6.0.0/lib/linux/*.*", - ]), - output_licenses = ["unencumbered"], -) - -filegroup( - name = "compiler_files", - srcs = [ - ":as", - ":gcc", - ":includes", - ], - output_licenses = ["unencumbered"], -) - -filegroup( - name = "linker_files", - srcs = [ - ":ar", - ":ld", - ":libraries", - ], - output_licenses = ["unencumbered"], -) - -filegroup( - name = "all_files", - srcs = [ - ":compiler_files", - ":dwp", - ":gcov", - ":linker_files", - ":nm", - ":objcopy", - ":objdump", - ":profdata", - ":strip", - ":xray", - ], - output_licenses = ["unencumbered"], -) - -filegroup( - name = "empty", - srcs = [], # bazel crashes without this - output_licenses = ["unencumbered"], -) - -cc_toolchain_suite( - name = "clang6", - toolchains = { - "k8|clang6": ":clang6-k8", - }, -) - -cc_toolchain( - name = "clang6-k8", - all_files = ":all_files", - compiler_files = ":compiler_files", - cpu = "k8", - dwp_files = ":dwp", - linker_files = ":linker_files", - objcopy_files = ":objcopy", - output_licenses = ["unencumbered"], - strip_files = ":strip", - supports_param_files = 1, -) diff --git a/third_party/toolchains/clang6/repo.bzl b/third_party/toolchains/clang6/repo.bzl deleted file mode 100644 index e4b6422c96d..00000000000 --- a/third_party/toolchains/clang6/repo.bzl +++ /dev/null @@ -1,37 +0,0 @@ -"""Repository rule for Debian 8 Jessie Clang-6.0 portable Linux builds.""" - -def _clang6_configure(ctx): - # TODO(jart): It'd probably be better to use Bazel's struct.to_proto() - # method to generate a gigantic CROSSTOOL file that allows - # Clang to support everything. - ctx.symlink( - ctx.os.environ.get( - "TF_LLVM_PATH", - "/usr/lib/llvm-6.0", - ), - "clang6/llvm", - ) - ctx.symlink( - ctx.os.environ.get("STRIP", "/usr/bin/strip"), - "clang6/sbin/strip", - ) - ctx.symlink( - ctx.os.environ.get("OBJDUMP", "/usr/bin/objdump"), - "clang6/sbin/objdump", - ) - ctx.symlink(ctx.attr._build, "clang6/BUILD") - ctx.template("clang6/CROSSTOOL", ctx.attr._crosstool, { - "%package(@local_config_clang6//clang6)%": str(ctx.path("clang6")), - }) - -clang6_configure = repository_rule( - implementation = _clang6_configure, - attrs = { - "_build": attr.label( - default = str(Label("//third_party/toolchains/clang6:clang.BUILD")), - ), - "_crosstool": attr.label( - default = str(Label("//third_party/toolchains/clang6:CROSSTOOL.tpl")), - ), - }, -) diff --git a/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11.2/BUILD b/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11.2/BUILD new file mode 100755 index 00000000000..fc2064c816d --- /dev/null +++ b/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11.2/BUILD @@ -0,0 +1,175 @@ +# This file is expanded from a template by cuda_configure.bzl +# Update cuda_configure.bzl#verify_build_defines when adding new variables. + +load(":cc_toolchain_config.bzl", "cc_toolchain_config") + +licenses(["restricted"]) + +package(default_visibility = ["//visibility:public"]) + +toolchain( + name = "toolchain-linux-x86_64", + exec_compatible_with = [ + "@bazel_tools//platforms:linux", + "@bazel_tools//platforms:x86_64", + ], + target_compatible_with = [ + "@bazel_tools//platforms:linux", + "@bazel_tools//platforms:x86_64", + ], + toolchain = ":cc-compiler-local", + toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", +) + +cc_toolchain_suite( + name = "toolchain", + toolchains = { + "local|compiler": ":cc-compiler-local", + "darwin|compiler": ":cc-compiler-darwin", + "x64_windows|msvc-cl": ":cc-compiler-windows", + "x64_windows": ":cc-compiler-windows", + "arm": ":cc-compiler-local", + "aarch64": ":cc-compiler-local", + "k8": ":cc-compiler-local", + "piii": ":cc-compiler-local", + "ppc": ":cc-compiler-local", + "darwin": ":cc-compiler-darwin", + }, +) + +cc_toolchain( + name = "cc-compiler-local", + all_files = ":crosstool_wrapper_driver_is_not_gcc", + ar_files = ":crosstool_wrapper_driver_is_not_gcc", + as_files = ":crosstool_wrapper_driver_is_not_gcc", + compiler_files = ":crosstool_wrapper_driver_is_not_gcc", + dwp_files = ":empty", + linker_files = ":crosstool_wrapper_driver_is_not_gcc", + objcopy_files = ":empty", + strip_files = ":empty", + # To support linker flags that need to go to the start of command line + # we need the toolchain to support parameter files. Parameter files are + # last on the command line and contain all shared libraries to link, so all + # regular options will be left of them. + supports_param_files = 1, + toolchain_config = ":cc-compiler-local-config", + toolchain_identifier = "local_linux", +) + +cc_toolchain_config( + name = "cc-compiler-local-config", + builtin_include_directories = [ + "/dt7/usr/include/c++/7", + "/dt7/usr/include/c++/7/x86_64-pc-linux-gnu", + "/dt7/usr/include/c++/7/backward", + "/dt7/usr/lib/gcc/x86_64-pc-linux-gnu/7/include", + "/dt7/usr/lib/gcc/x86_64-pc-linux-gnu/7/include-fixed", + "/dt7/usr/include", + "/usr/local/cuda-11.2/targets/x86_64-linux/include", + "/usr/local/cuda-11.2/include", + "/usr/local/cuda-11.2/extras/CUPTI/include", + "/usr/include", + ], + builtin_sysroot = "", + cpu = "local", + cuda_path = "", + extra_no_canonical_prefixes_flags = ["-fno-canonical-system-headers"], + host_compiler_path = "clang/bin/crosstool_wrapper_driver_is_not_gcc", + host_compiler_prefix = "/usr/bin", + host_compiler_warnings = [], + host_unfiltered_compile_flags = [], + linker_bin_path = "/usr/bin", +) + +cc_toolchain( + name = "cc-compiler-darwin", + all_files = ":crosstool_wrapper_driver_is_not_gcc", + ar_files = ":crosstool_wrapper_driver_is_not_gcc", + as_files = ":crosstool_wrapper_driver_is_not_gcc", + compiler_files = ":crosstool_wrapper_driver_is_not_gcc", + dwp_files = ":empty", + linker_files = ":crosstool_wrapper_driver_is_not_gcc", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 0, + toolchain_config = ":cc-compiler-local-darwin", + toolchain_identifier = "local_darwin", +) + +cc_toolchain_config( + name = "cc-compiler-local-darwin", + builtin_include_directories = [ + "/dt7/usr/include/c++/7", + "/dt7/usr/include/c++/7/x86_64-pc-linux-gnu", + "/dt7/usr/include/c++/7/backward", + "/dt7/usr/lib/gcc/x86_64-pc-linux-gnu/7/include", + "/dt7/usr/lib/gcc/x86_64-pc-linux-gnu/7/include-fixed", + "/dt7/usr/include", + "/usr/local/cuda-11.2/targets/x86_64-linux/include", + "/usr/local/cuda-11.2/include", + "/usr/local/cuda-11.2/extras/CUPTI/include", + "/usr/include", + ], + cpu = "darwin", + extra_no_canonical_prefixes_flags = ["-fno-canonical-system-headers"], + host_compiler_path = "clang/bin/crosstool_wrapper_driver_is_not_gcc", + host_compiler_prefix = "/usr/bin", + host_compiler_warnings = [], + host_unfiltered_compile_flags = [], + linker_bin_path = "/usr/bin", +) + +cc_toolchain( + name = "cc-compiler-windows", + all_files = ":windows_msvc_wrapper_files", + ar_files = ":windows_msvc_wrapper_files", + as_files = ":windows_msvc_wrapper_files", + compiler_files = ":windows_msvc_wrapper_files", + dwp_files = ":empty", + linker_files = ":windows_msvc_wrapper_files", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":cc-compiler-windows-config", + toolchain_identifier = "local_windows", +) + +cc_toolchain_config( + name = "cc-compiler-windows-config", + builtin_include_directories = [ + "/dt7/usr/include/c++/7", + "/dt7/usr/include/c++/7/x86_64-pc-linux-gnu", + "/dt7/usr/include/c++/7/backward", + "/dt7/usr/lib/gcc/x86_64-pc-linux-gnu/7/include", + "/dt7/usr/lib/gcc/x86_64-pc-linux-gnu/7/include-fixed", + "/dt7/usr/include", + "/usr/local/cuda-11.2/targets/x86_64-linux/include", + "/usr/local/cuda-11.2/include", + "/usr/local/cuda-11.2/extras/CUPTI/include", + "/usr/include", + ], + cpu = "x64_windows", + msvc_cl_path = "msvc_not_used", + msvc_env_include = "msvc_not_used", + msvc_env_lib = "msvc_not_used", + msvc_env_path = "msvc_not_used", + msvc_env_tmp = "msvc_not_used", + msvc_lib_path = "msvc_not_used", + msvc_link_path = "msvc_not_used", + msvc_ml_path = "msvc_not_used", +) + +filegroup( + name = "empty", + srcs = [], +) + +filegroup( + name = "crosstool_wrapper_driver_is_not_gcc", + srcs = ["clang/bin/crosstool_wrapper_driver_is_not_gcc"], +) + +filegroup( + name = "windows_msvc_wrapper_files", + srcs = glob(["windows/msvc_*"]), +) diff --git a/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11.2/cc_toolchain_config.bzl b/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11.2/cc_toolchain_config.bzl new file mode 100755 index 00000000000..70197628811 --- /dev/null +++ b/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11.2/cc_toolchain_config.bzl @@ -0,0 +1,1516 @@ +"""cc_toolchain_config rule for configuring CUDA toolchains on Linux, Mac, and Windows.""" + +load( + "@bazel_tools//tools/cpp:cc_toolchain_config_lib.bzl", + "action_config", + "env_entry", + "env_set", + "feature", + "feature_set", + "flag_group", + "flag_set", + "tool", + "tool_path", + "variable_with_value", +) +load( + "@bazel_tools//tools/build_defs/cc:action_names.bzl", + "ASSEMBLE_ACTION_NAME", + "CC_FLAGS_MAKE_VARIABLE_ACTION_NAME", + "CLIF_MATCH_ACTION_NAME", + "CPP_COMPILE_ACTION_NAME", + "CPP_HEADER_PARSING_ACTION_NAME", + "CPP_LINK_DYNAMIC_LIBRARY_ACTION_NAME", + "CPP_LINK_EXECUTABLE_ACTION_NAME", + "CPP_LINK_NODEPS_DYNAMIC_LIBRARY_ACTION_NAME", + "CPP_LINK_STATIC_LIBRARY_ACTION_NAME", + "CPP_MODULE_CODEGEN_ACTION_NAME", + "CPP_MODULE_COMPILE_ACTION_NAME", + "C_COMPILE_ACTION_NAME", + "LINKSTAMP_COMPILE_ACTION_NAME", + "LTO_BACKEND_ACTION_NAME", + "LTO_INDEXING_ACTION_NAME", + "OBJCPP_COMPILE_ACTION_NAME", + "OBJCPP_EXECUTABLE_ACTION_NAME", + "OBJC_ARCHIVE_ACTION_NAME", + "OBJC_COMPILE_ACTION_NAME", + "OBJC_EXECUTABLE_ACTION_NAME", + "OBJC_FULLY_LINK_ACTION_NAME", + "PREPROCESS_ASSEMBLE_ACTION_NAME", + "STRIP_ACTION_NAME", +) + +ACTION_NAMES = struct( + c_compile = C_COMPILE_ACTION_NAME, + cpp_compile = CPP_COMPILE_ACTION_NAME, + linkstamp_compile = LINKSTAMP_COMPILE_ACTION_NAME, + cc_flags_make_variable = CC_FLAGS_MAKE_VARIABLE_ACTION_NAME, + cpp_module_codegen = CPP_MODULE_CODEGEN_ACTION_NAME, + cpp_header_parsing = CPP_HEADER_PARSING_ACTION_NAME, + cpp_module_compile = CPP_MODULE_COMPILE_ACTION_NAME, + assemble = ASSEMBLE_ACTION_NAME, + preprocess_assemble = PREPROCESS_ASSEMBLE_ACTION_NAME, + lto_indexing = LTO_INDEXING_ACTION_NAME, + lto_backend = LTO_BACKEND_ACTION_NAME, + cpp_link_executable = CPP_LINK_EXECUTABLE_ACTION_NAME, + cpp_link_dynamic_library = CPP_LINK_DYNAMIC_LIBRARY_ACTION_NAME, + cpp_link_nodeps_dynamic_library = CPP_LINK_NODEPS_DYNAMIC_LIBRARY_ACTION_NAME, + cpp_link_static_library = CPP_LINK_STATIC_LIBRARY_ACTION_NAME, + strip = STRIP_ACTION_NAME, + objc_archive = OBJC_ARCHIVE_ACTION_NAME, + objc_compile = OBJC_COMPILE_ACTION_NAME, + objc_executable = OBJC_EXECUTABLE_ACTION_NAME, + objc_fully_link = OBJC_FULLY_LINK_ACTION_NAME, + objcpp_compile = OBJCPP_COMPILE_ACTION_NAME, + objcpp_executable = OBJCPP_EXECUTABLE_ACTION_NAME, + clif_match = CLIF_MATCH_ACTION_NAME, + objcopy_embed_data = "objcopy_embed_data", + ld_embed_data = "ld_embed_data", +) + +def _impl(ctx): + if (ctx.attr.cpu == "darwin"): + toolchain_identifier = "local_darwin" + elif (ctx.attr.cpu == "local"): + toolchain_identifier = "local_linux" + elif (ctx.attr.cpu == "x64_windows"): + toolchain_identifier = "local_windows" + else: + fail("Unreachable") + + host_system_name = "local" + + target_system_name = "local" + + if (ctx.attr.cpu == "darwin"): + target_cpu = "darwin" + elif (ctx.attr.cpu == "local"): + target_cpu = "local" + elif (ctx.attr.cpu == "x64_windows"): + target_cpu = "x64_windows" + else: + fail("Unreachable") + + if (ctx.attr.cpu == "local"): + target_libc = "local" + elif (ctx.attr.cpu == "darwin"): + target_libc = "macosx" + elif (ctx.attr.cpu == "x64_windows"): + target_libc = "msvcrt" + else: + fail("Unreachable") + + if (ctx.attr.cpu == "darwin" or + ctx.attr.cpu == "local"): + compiler = "compiler" + elif (ctx.attr.cpu == "x64_windows"): + compiler = "msvc-cl" + else: + fail("Unreachable") + + abi_version = "local" + + abi_libc_version = "local" + + cc_target_os = None + + builtin_sysroot = ctx.attr.builtin_sysroot + + all_link_actions = [ + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ] + + cpp_link_dynamic_library_action = action_config( + action_name = ACTION_NAMES.cpp_link_dynamic_library, + implies = [ + "nologo", + "shared_flag", + "linkstamps", + "output_execpath_flags", + "input_param_flags", + "user_link_flags", + "linker_subsystem_flag", + "linker_param_file", + "msvc_env", + "no_stripping", + "has_configured_linker_path", + "def_file", + ], + tools = [tool(path = ctx.attr.msvc_link_path)], + ) + + cpp_link_nodeps_dynamic_library_action = action_config( + action_name = ACTION_NAMES.cpp_link_nodeps_dynamic_library, + implies = [ + "nologo", + "shared_flag", + "linkstamps", + "output_execpath_flags", + "input_param_flags", + "user_link_flags", + "linker_subsystem_flag", + "linker_param_file", + "msvc_env", + "no_stripping", + "has_configured_linker_path", + "def_file", + ], + tools = [tool(path = ctx.attr.msvc_link_path)], + ) + + cpp_link_static_library_action = action_config( + action_name = ACTION_NAMES.cpp_link_static_library, + implies = [ + "nologo", + "archiver_flags", + "input_param_flags", + "linker_param_file", + "msvc_env", + ], + tools = [tool(path = ctx.attr.msvc_lib_path)], + ) + + assemble_action = action_config( + action_name = ACTION_NAMES.assemble, + implies = [ + "compiler_input_flags", + "compiler_output_flags", + "nologo", + "msvc_env", + "sysroot", + ], + tools = [tool(path = ctx.attr.msvc_ml_path)], + ) + + preprocess_assemble_action = action_config( + action_name = ACTION_NAMES.preprocess_assemble, + implies = [ + "compiler_input_flags", + "compiler_output_flags", + "nologo", + "msvc_env", + "sysroot", + ], + tools = [tool(path = ctx.attr.msvc_ml_path)], + ) + + c_compile_action = action_config( + action_name = ACTION_NAMES.c_compile, + implies = [ + "compiler_input_flags", + "compiler_output_flags", + "nologo", + "msvc_env", + "parse_showincludes", + "user_compile_flags", + "sysroot", + "unfiltered_compile_flags", + ], + tools = [tool(path = ctx.attr.msvc_cl_path)], + ) + + cpp_compile_action = action_config( + action_name = ACTION_NAMES.cpp_compile, + implies = [ + "compiler_input_flags", + "compiler_output_flags", + "nologo", + "msvc_env", + "parse_showincludes", + "user_compile_flags", + "sysroot", + "unfiltered_compile_flags", + ], + tools = [tool(path = ctx.attr.msvc_cl_path)], + ) + + cpp_link_executable_action = action_config( + action_name = ACTION_NAMES.cpp_link_executable, + implies = [ + "nologo", + "linkstamps", + "output_execpath_flags", + "input_param_flags", + "user_link_flags", + "linker_subsystem_flag", + "linker_param_file", + "msvc_env", + "no_stripping", + ], + tools = [tool(path = ctx.attr.msvc_link_path)], + ) + + if (ctx.attr.cpu == "darwin" or + ctx.attr.cpu == "local"): + action_configs = [] + elif (ctx.attr.cpu == "x64_windows"): + action_configs = [ + assemble_action, + preprocess_assemble_action, + c_compile_action, + cpp_compile_action, + cpp_link_executable_action, + cpp_link_dynamic_library_action, + cpp_link_nodeps_dynamic_library_action, + cpp_link_static_library_action, + ] + else: + fail("Unreachable") + + no_windows_export_all_symbols_feature = feature(name = "no_windows_export_all_symbols") + + pic_feature = feature( + name = "pic", + enabled = True, + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [ + flag_group(flags = ["-fPIC"], expand_if_available = "pic"), + flag_group( + flags = ["-fPIE"], + expand_if_not_available = "pic", + ), + ], + ), + ], + ) + + preprocessor_defines_feature = feature( + name = "preprocessor_defines", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ], + flag_groups = [ + flag_group( + flags = ["/D%{preprocessor_defines}"], + iterate_over = "preprocessor_defines", + ), + ], + ), + ], + ) + + generate_pdb_file_feature = feature( + name = "generate_pdb_file", + requires = [ + feature_set(features = ["dbg"]), + feature_set(features = ["fastbuild"]), + ], + ) + + linkstamps_feature = feature( + name = "linkstamps", + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group( + flags = ["%{linkstamp_paths}"], + iterate_over = "linkstamp_paths", + expand_if_available = "linkstamp_paths", + ), + ], + ), + ], + ) + + unfiltered_compile_flags_feature = feature( + name = "unfiltered_compile_flags", + flag_sets = ([ + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ], + flag_groups = [ + flag_group( + flags = ctx.attr.host_unfiltered_compile_flags, + ), + ], + ), + ] if ctx.attr.host_unfiltered_compile_flags else []), + ) + + determinism_feature = feature( + name = "determinism", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [ + flag_group( + flags = [ + "-Wno-builtin-macro-redefined", + "-D__DATE__=\"redacted\"", + "-D__TIMESTAMP__=\"redacted\"", + "-D__TIME__=\"redacted\"", + ], + ), + ], + ), + ], + ) + + nologo_feature = feature( + name = "nologo", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ACTION_NAMES.cpp_link_static_library, + ], + flag_groups = [flag_group(flags = ["/nologo"])], + ), + ], + ) + + supports_pic_feature = feature(name = "supports_pic", enabled = True) + + output_execpath_flags_feature = feature( + name = "output_execpath_flags", + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group( + flags = ["/OUT:%{output_execpath}"], + expand_if_available = "output_execpath", + ), + ], + ), + ], + ) + + default_link_flags_feature = feature( + name = "default_link_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["/MACHINE:X64"])], + ), + ], + ) + + if (ctx.attr.cpu == "local"): + hardening_feature = feature( + name = "hardening", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [ + flag_group( + flags = [ + "-U_FORTIFY_SOURCE", + "-D_FORTIFY_SOURCE=1", + "-fstack-protector", + ], + ), + ], + ), + flag_set( + actions = [ + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ], + flag_groups = [flag_group(flags = ["-Wl,-z,relro,-z,now"])], + ), + flag_set( + actions = [ACTION_NAMES.cpp_link_executable], + flag_groups = [flag_group(flags = ["-pie", "-Wl,-z,relro,-z,now"])], + ), + ], + ) + elif (ctx.attr.cpu == "darwin"): + hardening_feature = feature( + name = "hardening", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [ + flag_group( + flags = [ + "-U_FORTIFY_SOURCE", + "-D_FORTIFY_SOURCE=1", + "-fstack-protector", + ], + ), + ], + ), + flag_set( + actions = [ACTION_NAMES.cpp_link_executable], + flag_groups = [flag_group(flags = ["-pie"])], + ), + ], + ) + else: + hardening_feature = None + + supports_dynamic_linker_feature = feature(name = "supports_dynamic_linker", enabled = True) + + targets_windows_feature = feature( + name = "targets_windows", + enabled = True, + implies = ["copy_dynamic_libraries_to_binary"], + ) + + msvc_env_feature = feature( + name = "msvc_env", + env_sets = [ + env_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ACTION_NAMES.cpp_link_static_library, + ], + env_entries = [ + env_entry(key = "PATH", value = ctx.attr.msvc_env_path), + env_entry( + key = "INCLUDE", + value = ctx.attr.msvc_env_include, + ), + env_entry(key = "LIB", value = ctx.attr.msvc_env_lib), + env_entry(key = "TMP", value = ctx.attr.msvc_env_tmp), + env_entry(key = "TEMP", value = ctx.attr.msvc_env_tmp), + ], + ), + ], + ) + + linker_subsystem_flag_feature = feature( + name = "linker_subsystem_flag", + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["/SUBSYSTEM:CONSOLE"])], + ), + ], + ) + + dynamic_link_msvcrt_no_debug_feature = feature( + name = "dynamic_link_msvcrt_no_debug", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/MD"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["/DEFAULTLIB:msvcrt.lib"])], + ), + ], + requires = [ + feature_set(features = ["fastbuild"]), + feature_set(features = ["opt"]), + ], + ) + + warnings_feature = feature( + name = "warnings", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [ + flag_group( + flags = ["-Wall"] + ctx.attr.host_compiler_warnings, + ), + ], + ), + ], + ) + + dynamic_link_msvcrt_debug_feature = feature( + name = "dynamic_link_msvcrt_debug", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/MDd"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["/DEFAULTLIB:msvcrtd.lib"])], + ), + ], + requires = [feature_set(features = ["dbg"])], + ) + + compiler_output_flags_feature = feature( + name = "compiler_output_flags", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.assemble], + flag_groups = [ + flag_group( + flag_groups = [ + flag_group( + flags = ["/Fo%{output_file}", "/Zi"], + expand_if_not_available = "output_preprocess_file", + ), + ], + expand_if_available = "output_file", + expand_if_not_available = "output_assembly_file", + ), + ], + ), + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ], + flag_groups = [ + flag_group( + flag_groups = [ + flag_group( + flags = ["/Fo%{output_file}"], + expand_if_not_available = "output_preprocess_file", + ), + ], + expand_if_available = "output_file", + expand_if_not_available = "output_assembly_file", + ), + flag_group( + flag_groups = [ + flag_group( + flags = ["/Fa%{output_file}"], + expand_if_available = "output_assembly_file", + ), + ], + expand_if_available = "output_file", + ), + flag_group( + flag_groups = [ + flag_group( + flags = ["/P", "/Fi%{output_file}"], + expand_if_available = "output_preprocess_file", + ), + ], + expand_if_available = "output_file", + ), + ], + ), + ], + ) + + default_compile_flags_feature = feature( + name = "default_compile_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.lto_backend, + ACTION_NAMES.clif_match, + ], + flag_groups = [ + flag_group( + flags = [ + "/DCOMPILER_MSVC", + "/DNOMINMAX", + "/D_WIN32_WINNT=0x0600", + "/D_CRT_SECURE_NO_DEPRECATE", + "/D_CRT_SECURE_NO_WARNINGS", + "/D_SILENCE_STDEXT_HASH_DEPRECATION_WARNINGS", + "/bigobj", + "/Zm500", + "/J", + "/Gy", + "/GF", + "/EHsc", + "/wd4351", + "/wd4291", + "/wd4250", + "/wd4996", + ], + ), + ], + ), + ], + ) + + static_link_msvcrt_debug_feature = feature( + name = "static_link_msvcrt_debug", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/MTd"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["/DEFAULTLIB:libcmtd.lib"])], + ), + ], + requires = [feature_set(features = ["dbg"])], + ) + + static_link_msvcrt_feature = feature(name = "static_link_msvcrt") + + if (ctx.attr.cpu == "darwin" or + ctx.attr.cpu == "local"): + dbg_feature = feature( + name = "dbg", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["-g"])], + ), + ], + implies = ["common"], + ) + elif (ctx.attr.cpu == "x64_windows"): + dbg_feature = feature( + name = "dbg", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/Od", "/Z7", "/DDEBUG"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["/DEBUG:FULL", "/INCREMENTAL:NO"])], + ), + ], + implies = ["generate_pdb_file"], + ) + else: + dbg_feature = None + + undefined_dynamic_feature = feature( + name = "undefined-dynamic", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ACTION_NAMES.cpp_link_executable, + ], + flag_groups = [flag_group(flags = ["-undefined", "dynamic_lookup"])], + ), + ], + ) + + parse_showincludes_feature = feature( + name = "parse_showincludes", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_header_parsing, + ], + flag_groups = [flag_group(flags = ["/showIncludes"])], + ), + ], + ) + + linker_param_file_feature = feature( + name = "linker_param_file", + flag_sets = [ + flag_set( + actions = all_link_actions + + [ACTION_NAMES.cpp_link_static_library], + flag_groups = [ + flag_group( + flags = ["@%{linker_param_file}"], + expand_if_available = "linker_param_file", + ), + ], + ), + ], + ) + + static_link_msvcrt_no_debug_feature = feature( + name = "static_link_msvcrt_no_debug", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/MT"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["/DEFAULTLIB:libcmt.lib"])], + ), + ], + requires = [ + feature_set(features = ["fastbuild"]), + feature_set(features = ["opt"]), + ], + ) + + supports_interface_shared_libraries_feature = feature( + name = "supports_interface_shared_libraries", + enabled = True, + ) + + disable_assertions_feature = feature( + name = "disable-assertions", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["-DNDEBUG"])], + ), + ], + ) + + if (ctx.attr.cpu == "x64_windows"): + fastbuild_feature = feature( + name = "fastbuild", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/Od", "/Z7", "/DDEBUG"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group(flags = ["/DEBUG:FASTLINK", "/INCREMENTAL:NO"]), + ], + ), + ], + implies = ["generate_pdb_file"], + ) + elif (ctx.attr.cpu == "darwin" or + ctx.attr.cpu == "local"): + fastbuild_feature = feature(name = "fastbuild", implies = ["common"]) + else: + fastbuild_feature = None + + user_compile_flags_feature = feature( + name = "user_compile_flags", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ], + flag_groups = [ + flag_group( + flags = ["%{user_compile_flags}"], + iterate_over = "user_compile_flags", + expand_if_available = "user_compile_flags", + ), + ], + ), + ], + ) + + compiler_input_flags_feature = feature( + name = "compiler_input_flags", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ], + flag_groups = [ + flag_group( + flags = ["/c", "%{source_file}"], + expand_if_available = "source_file", + ), + ], + ), + ], + ) + + no_legacy_features_feature = feature(name = "no_legacy_features") + + archiver_flags_feature = feature( + name = "archiver_flags", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.cpp_link_static_library], + flag_groups = [ + flag_group( + flags = ["/OUT:%{output_execpath}"], + expand_if_available = "output_execpath", + ), + ], + ), + ], + ) + + redirector_feature = feature( + name = "redirector", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ], + flag_groups = [ + flag_group( + flags = [ + "-B", + "external/local_config_cuda/crosstool/windows/msvc_wrapper_for_nvcc.py", + ], + ), + ], + ), + ], + ) + + linker_bin_path_feature = feature( + name = "linker-bin-path", + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["-B" + ctx.attr.linker_bin_path])], + ), + ], + ) + + if (ctx.attr.cpu == "local"): + opt_feature = feature( + name = "opt", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [ + flag_group( + flags = ["-g0", "-O2", "-ffunction-sections", "-fdata-sections"], + ), + ], + ), + flag_set( + actions = [ + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ACTION_NAMES.cpp_link_executable, + ], + flag_groups = [flag_group(flags = ["-Wl,--gc-sections"])], + ), + ], + implies = ["common", "disable-assertions"], + ) + elif (ctx.attr.cpu == "darwin"): + opt_feature = feature( + name = "opt", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [ + flag_group( + flags = ["-g0", "-O2", "-ffunction-sections", "-fdata-sections"], + ), + ], + ), + ], + implies = ["common", "disable-assertions"], + ) + elif (ctx.attr.cpu == "x64_windows"): + opt_feature = feature( + name = "opt", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/O2", "/DNDEBUG"])], + ), + ], + ) + else: + opt_feature = None + + include_paths_feature = feature( + name = "include_paths", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ], + flag_groups = [ + flag_group( + flags = ["/I%{quote_include_paths}"], + iterate_over = "quote_include_paths", + ), + flag_group( + flags = ["/I%{include_paths}"], + iterate_over = "include_paths", + ), + flag_group( + flags = ["/I%{system_include_paths}"], + iterate_over = "system_include_paths", + ), + ], + ), + ], + ) + + shared_flag_feature = feature( + name = "shared_flag", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ], + flag_groups = [flag_group(flags = ["/DLL"])], + ), + ], + ) + + windows_export_all_symbols_feature = feature(name = "windows_export_all_symbols") + + frame_pointer_feature = feature( + name = "frame-pointer", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["-fno-omit-frame-pointer"])], + ), + ], + ) + + build_id_feature = feature( + name = "build-id", + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group( + flags = ["-Wl,--build-id=md5", "-Wl,--hash-style=gnu"], + ), + ], + ), + ], + ) + + sysroot_feature = feature( + name = "sysroot", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ], + flag_groups = [ + flag_group( + flags = ["--sysroot=%{sysroot}"], + iterate_over = "sysroot", + expand_if_available = "sysroot", + ), + ], + ), + ], + ) + + cuda_path_feature = feature( + name = "cuda_path", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ], + flag_groups = [ + flag_group( + flags = ["--cuda-path=" + ctx.attr.cuda_path], + ), + ], + ), + ], + ) + + def_file_feature = feature( + name = "def_file", + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group( + flags = ["/DEF:%{def_file_path}", "/ignore:4070"], + expand_if_available = "def_file_path", + ), + ], + ), + ], + ) + + if (ctx.attr.cpu == "darwin"): + stdlib_feature = feature( + name = "stdlib", + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["-lc++"])], + ), + ], + ) + elif (ctx.attr.cpu == "local"): + stdlib_feature = feature( + name = "stdlib", + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["-lstdc++"])], + ), + ], + ) + else: + stdlib_feature = None + + no_stripping_feature = feature(name = "no_stripping") + + alwayslink_feature = feature( + name = "alwayslink", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ACTION_NAMES.cpp_link_executable, + ], + flag_groups = [flag_group(flags = ["-Wl,-no-as-needed"])], + ), + ], + ) + + input_param_flags_feature = feature( + name = "input_param_flags", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ], + flag_groups = [ + flag_group( + flags = ["/IMPLIB:%{interface_library_output_path}"], + expand_if_available = "interface_library_output_path", + ), + ], + ), + flag_set( + actions = all_link_actions + + [ACTION_NAMES.cpp_link_static_library], + flag_groups = [ + flag_group( + iterate_over = "libraries_to_link", + flag_groups = [ + flag_group( + iterate_over = "libraries_to_link.object_files", + flag_groups = [flag_group(flags = ["%{libraries_to_link.object_files}"])], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "object_file_group", + ), + ), + flag_group( + flag_groups = [flag_group(flags = ["%{libraries_to_link.name}"])], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "object_file", + ), + ), + flag_group( + flag_groups = [flag_group(flags = ["%{libraries_to_link.name}"])], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "interface_library", + ), + ), + flag_group( + flag_groups = [ + flag_group( + flags = ["%{libraries_to_link.name}"], + expand_if_false = "libraries_to_link.is_whole_archive", + ), + flag_group( + flags = ["/WHOLEARCHIVE:%{libraries_to_link.name}"], + expand_if_true = "libraries_to_link.is_whole_archive", + ), + ], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "static_library", + ), + ), + ], + expand_if_available = "libraries_to_link", + ), + ], + ), + ], + ) + + if (ctx.attr.cpu == "local"): + no_canonical_prefixes_feature = feature( + name = "no-canonical-prefixes", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ], + flag_groups = [ + flag_group( + flags = [ + "-no-canonical-prefixes", + ] + ctx.attr.extra_no_canonical_prefixes_flags, + ), + ], + ), + ], + ) + elif (ctx.attr.cpu == "darwin"): + no_canonical_prefixes_feature = feature( + name = "no-canonical-prefixes", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ], + flag_groups = [flag_group(flags = ["-no-canonical-prefixes"])], + ), + ], + ) + else: + no_canonical_prefixes_feature = None + + has_configured_linker_path_feature = feature(name = "has_configured_linker_path") + + copy_dynamic_libraries_to_binary_feature = feature(name = "copy_dynamic_libraries_to_binary") + + user_link_flags_feature = feature( + name = "user_link_flags", + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group( + flags = ["%{user_link_flags}"], + iterate_over = "user_link_flags", + expand_if_available = "user_link_flags", + ), + ], + ), + ], + ) + + cpp11_feature = feature( + name = "c++11", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["-std=c++11"])], + ), + ], + ) + + if (ctx.attr.cpu == "local"): + common_feature = feature( + name = "common", + implies = [ + "stdlib", + "c++11", + "determinism", + "alwayslink", + "hardening", + "warnings", + "frame-pointer", + "build-id", + "no-canonical-prefixes", + "linker-bin-path", + ], + ) + elif (ctx.attr.cpu == "darwin"): + common_feature = feature( + name = "common", + implies = [ + "stdlib", + "c++11", + "determinism", + "hardening", + "warnings", + "frame-pointer", + "no-canonical-prefixes", + "linker-bin-path", + "undefined-dynamic", + ], + ) + else: + common_feature = None + + if (ctx.attr.cpu == "local"): + features = [ + cpp11_feature, + stdlib_feature, + determinism_feature, + alwayslink_feature, + pic_feature, + hardening_feature, + warnings_feature, + frame_pointer_feature, + build_id_feature, + no_canonical_prefixes_feature, + disable_assertions_feature, + linker_bin_path_feature, + common_feature, + opt_feature, + fastbuild_feature, + dbg_feature, + supports_dynamic_linker_feature, + supports_pic_feature, + ] + if ctx.attr.cuda_path: + features.append(cuda_path_feature) + elif (ctx.attr.cpu == "darwin"): + features = [ + cpp11_feature, + stdlib_feature, + determinism_feature, + pic_feature, + hardening_feature, + warnings_feature, + frame_pointer_feature, + no_canonical_prefixes_feature, + disable_assertions_feature, + linker_bin_path_feature, + undefined_dynamic_feature, + common_feature, + opt_feature, + fastbuild_feature, + dbg_feature, + supports_dynamic_linker_feature, + supports_pic_feature, + ] + elif (ctx.attr.cpu == "x64_windows"): + features = [ + no_legacy_features_feature, + redirector_feature, + nologo_feature, + has_configured_linker_path_feature, + no_stripping_feature, + targets_windows_feature, + copy_dynamic_libraries_to_binary_feature, + default_compile_flags_feature, + msvc_env_feature, + include_paths_feature, + preprocessor_defines_feature, + parse_showincludes_feature, + generate_pdb_file_feature, + shared_flag_feature, + linkstamps_feature, + output_execpath_flags_feature, + archiver_flags_feature, + input_param_flags_feature, + linker_subsystem_flag_feature, + user_link_flags_feature, + default_link_flags_feature, + linker_param_file_feature, + static_link_msvcrt_feature, + static_link_msvcrt_no_debug_feature, + dynamic_link_msvcrt_no_debug_feature, + static_link_msvcrt_debug_feature, + dynamic_link_msvcrt_debug_feature, + dbg_feature, + fastbuild_feature, + opt_feature, + user_compile_flags_feature, + sysroot_feature, + unfiltered_compile_flags_feature, + compiler_output_flags_feature, + compiler_input_flags_feature, + def_file_feature, + windows_export_all_symbols_feature, + no_windows_export_all_symbols_feature, + supports_dynamic_linker_feature, + supports_interface_shared_libraries_feature, + ] + else: + fail("Unreachable") + + cxx_builtin_include_directories = ctx.attr.builtin_include_directories + + if (ctx.attr.cpu == "x64_windows"): + tool_paths = [ + tool_path(name = "ar", path = ctx.attr.msvc_lib_path), + tool_path(name = "ml", path = ctx.attr.msvc_ml_path), + tool_path(name = "cpp", path = ctx.attr.msvc_cl_path), + tool_path(name = "gcc", path = ctx.attr.msvc_cl_path), + tool_path(name = "gcov", path = "wrapper/bin/msvc_nop.bat"), + tool_path(name = "ld", path = ctx.attr.msvc_link_path), + tool_path(name = "nm", path = "wrapper/bin/msvc_nop.bat"), + tool_path( + name = "objcopy", + path = "wrapper/bin/msvc_nop.bat", + ), + tool_path( + name = "objdump", + path = "wrapper/bin/msvc_nop.bat", + ), + tool_path( + name = "strip", + path = "wrapper/bin/msvc_nop.bat", + ), + ] + elif (ctx.attr.cpu == "local"): + tool_paths = [ + tool_path(name = "gcc", path = ctx.attr.host_compiler_path), + tool_path(name = "ar", path = ctx.attr.host_compiler_prefix + "/ar"), + tool_path(name = "compat-ld", path = ctx.attr.host_compiler_prefix + "/ld"), + tool_path(name = "cpp", path = ctx.attr.host_compiler_prefix + "/cpp"), + tool_path(name = "dwp", path = ctx.attr.host_compiler_prefix + "/dwp"), + tool_path(name = "gcov", path = ctx.attr.host_compiler_prefix + "/gcov"), + tool_path(name = "ld", path = ctx.attr.host_compiler_prefix + "/ld"), + tool_path(name = "nm", path = ctx.attr.host_compiler_prefix + "/nm"), + tool_path(name = "objcopy", path = ctx.attr.host_compiler_prefix + "/objcopy"), + tool_path(name = "objdump", path = ctx.attr.host_compiler_prefix + "/objdump"), + tool_path(name = "strip", path = ctx.attr.host_compiler_prefix + "/strip"), + ] + elif (ctx.attr.cpu == "darwin"): + tool_paths = [ + tool_path(name = "gcc", path = ctx.attr.host_compiler_path), + tool_path(name = "ar", path = ctx.attr.host_compiler_prefix + "/libtool"), + tool_path(name = "compat-ld", path = ctx.attr.host_compiler_prefix + "/ld"), + tool_path(name = "cpp", path = ctx.attr.host_compiler_prefix + "/cpp"), + tool_path(name = "dwp", path = ctx.attr.host_compiler_prefix + "/dwp"), + tool_path(name = "gcov", path = ctx.attr.host_compiler_prefix + "/gcov"), + tool_path(name = "ld", path = ctx.attr.host_compiler_prefix + "/ld"), + tool_path(name = "nm", path = ctx.attr.host_compiler_prefix + "/nm"), + tool_path(name = "objcopy", path = ctx.attr.host_compiler_prefix + "/objcopy"), + tool_path(name = "objdump", path = ctx.attr.host_compiler_prefix + "/objdump"), + tool_path(name = "strip", path = ctx.attr.host_compiler_prefix + "/strip"), + ] + else: + fail("Unreachable") + + out = ctx.actions.declare_file(ctx.label.name) + ctx.actions.write(out, "Fake executable") + return [ + cc_common.create_cc_toolchain_config_info( + ctx = ctx, + features = features, + action_configs = action_configs, + artifact_name_patterns = [], + cxx_builtin_include_directories = cxx_builtin_include_directories, + toolchain_identifier = toolchain_identifier, + host_system_name = host_system_name, + target_system_name = target_system_name, + target_cpu = target_cpu, + target_libc = target_libc, + compiler = compiler, + abi_version = abi_version, + abi_libc_version = abi_libc_version, + tool_paths = tool_paths, + make_variables = [], + builtin_sysroot = builtin_sysroot, + cc_target_os = cc_target_os, + ), + DefaultInfo( + executable = out, + ), + ] + +cc_toolchain_config = rule( + implementation = _impl, + attrs = { + "cpu": attr.string(mandatory = True, values = ["darwin", "local", "x64_windows"]), + "builtin_include_directories": attr.string_list(), + "extra_no_canonical_prefixes_flags": attr.string_list(), + "host_compiler_path": attr.string(), + "host_compiler_prefix": attr.string(), + "host_compiler_warnings": attr.string_list(), + "host_unfiltered_compile_flags": attr.string_list(), + "linker_bin_path": attr.string(), + "builtin_sysroot": attr.string(), + "cuda_path": attr.string(), + "msvc_cl_path": attr.string(default = "msvc_not_used"), + "msvc_env_include": attr.string(default = "msvc_not_used"), + "msvc_env_lib": attr.string(default = "msvc_not_used"), + "msvc_env_path": attr.string(default = "msvc_not_used"), + "msvc_env_tmp": attr.string(default = "msvc_not_used"), + "msvc_lib_path": attr.string(default = "msvc_not_used"), + "msvc_link_path": attr.string(default = "msvc_not_used"), + "msvc_ml_path": attr.string(default = "msvc_not_used"), + }, + provides = [CcToolchainConfigInfo], + executable = True, +) diff --git a/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11.2/clang/bin/crosstool_wrapper_driver_is_not_gcc b/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11.2/clang/bin/crosstool_wrapper_driver_is_not_gcc new file mode 100755 index 00000000000..575f4b21415 --- /dev/null +++ b/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11.2/clang/bin/crosstool_wrapper_driver_is_not_gcc @@ -0,0 +1,289 @@ +#!/usr/bin/env python +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Crosstool wrapper for compiling CUDA programs. + +SYNOPSIS: + crosstool_wrapper_is_not_gcc [options passed in by cc_library() + or cc_binary() rule] + +DESCRIPTION: + This script is expected to be called by the cc_library() or cc_binary() bazel + rules. When the option "-x cuda" is present in the list of arguments passed + to this script, it invokes the nvcc CUDA compiler. Most arguments are passed + as is as a string to --compiler-options of nvcc. When "-x cuda" is not + present, this wrapper invokes hybrid_driver_is_not_gcc with the input + arguments as is. + +NOTES: + Changes to the contents of this file must be propagated from + //third_party/gpus/crosstool/crosstool_wrapper_is_not_gcc to + //third_party/gpus/crosstool/v*/*/clang/bin/crosstool_wrapper_is_not_gcc +""" + +from __future__ import print_function + +__author__ = 'keveman@google.com (Manjunath Kudlur)' + +from argparse import ArgumentParser +import os +import subprocess +import re +import sys +import pipes + +# Template values set by cuda_autoconf. +CPU_COMPILER = ('/dt7/usr/bin/gcc') +GCC_HOST_COMPILER_PATH = ('/dt7/usr/bin/gcc') + +NVCC_PATH = '/usr/local/cuda-11.2/bin/nvcc' +PREFIX_DIR = os.path.dirname(GCC_HOST_COMPILER_PATH) +NVCC_VERSION = '10.1' + +def Log(s): + print('gpus/crosstool: {0}'.format(s)) + + +def GetOptionValue(argv, option): + """Extract the list of values for option from the argv list. + + Args: + argv: A list of strings, possibly the argv passed to main(). + option: The option whose value to extract, with the leading '-'. + + Returns: + A list of values, either directly following the option, + (eg., -opt val1 val2) or values collected from multiple occurrences of + the option (eg., -opt val1 -opt val2). + """ + + parser = ArgumentParser() + parser.add_argument(option, nargs='*', action='append') + option = option.lstrip('-').replace('-', '_') + args, _ = parser.parse_known_args(argv) + if not args or not vars(args)[option]: + return [] + else: + return sum(vars(args)[option], []) + + +def GetHostCompilerOptions(argv): + """Collect the -isystem, -iquote, and --sysroot option values from argv. + + Args: + argv: A list of strings, possibly the argv passed to main(). + + Returns: + The string that can be used as the --compiler-options to nvcc. + """ + + parser = ArgumentParser() + parser.add_argument('-isystem', nargs='*', action='append') + parser.add_argument('-iquote', nargs='*', action='append') + parser.add_argument('--sysroot', nargs=1) + parser.add_argument('-g', nargs='*', action='append') + parser.add_argument('-fno-canonical-system-headers', action='store_true') + parser.add_argument('-no-canonical-prefixes', action='store_true') + + args, _ = parser.parse_known_args(argv) + + opts = '' + + if args.isystem: + opts += ' -isystem ' + ' -isystem '.join(sum(args.isystem, [])) + if args.iquote: + opts += ' -iquote ' + ' -iquote '.join(sum(args.iquote, [])) + if args.g: + opts += ' -g' + ' -g'.join(sum(args.g, [])) + if args.fno_canonical_system_headers: + opts += ' -fno-canonical-system-headers' + if args.no_canonical_prefixes: + opts += ' -no-canonical-prefixes' + if args.sysroot: + opts += ' --sysroot ' + args.sysroot[0] + + return opts + +def _update_options(nvcc_options): + if NVCC_VERSION in ("7.0",): + return nvcc_options + + update_options = { "relaxed-constexpr" : "expt-relaxed-constexpr" } + return [ update_options[opt] if opt in update_options else opt + for opt in nvcc_options ] + +def GetNvccOptions(argv): + """Collect the -nvcc_options values from argv. + + Args: + argv: A list of strings, possibly the argv passed to main(). + + Returns: + The string that can be passed directly to nvcc. + """ + + parser = ArgumentParser() + parser.add_argument('-nvcc_options', nargs='*', action='append') + + args, _ = parser.parse_known_args(argv) + + if args.nvcc_options: + options = _update_options(sum(args.nvcc_options, [])) + return ' '.join(['--'+a for a in options]) + return '' + +def system(cmd): + """Invokes cmd with os.system(). + + Args: + cmd: The command. + + Returns: + The exit code if the process exited with exit() or -signal + if the process was terminated by a signal. + """ + retv = os.system(cmd) + if os.WIFEXITED(retv): + return os.WEXITSTATUS(retv) + else: + return -os.WTERMSIG(retv) + +def InvokeNvcc(argv, log=False): + """Call nvcc with arguments assembled from argv. + + Args: + argv: A list of strings, possibly the argv passed to main(). + log: True if logging is requested. + + Returns: + The return value of calling system('nvcc ' + args) + """ + + host_compiler_options = GetHostCompilerOptions(argv) + nvcc_compiler_options = GetNvccOptions(argv) + opt_option = GetOptionValue(argv, '-O') + m_options = GetOptionValue(argv, '-m') + m_options = ''.join([' -m' + m for m in m_options if m in ['32', '64']]) + include_options = GetOptionValue(argv, '-I') + out_file = GetOptionValue(argv, '-o') + depfiles = GetOptionValue(argv, '-MF') + defines = GetOptionValue(argv, '-D') + defines = ''.join([' -D' + define for define in defines]) + undefines = GetOptionValue(argv, '-U') + undefines = ''.join([' -U' + define for define in undefines]) + std_options = GetOptionValue(argv, '-std') + # Supported -std flags as of CUDA 9.0. Only keep last to mimic gcc/clang. + nvcc_allowed_std_options = ["c++03", "c++11", "c++14"] + std_options = ''.join([' -std=' + define + for define in std_options if define in nvcc_allowed_std_options][-1:]) + fatbin_options = ''.join([' --fatbin-options=' + option + for option in GetOptionValue(argv, '-Xcuda-fatbinary')]) + + # The list of source files get passed after the -c option. I don't know of + # any other reliable way to just get the list of source files to be compiled. + src_files = GetOptionValue(argv, '-c') + + # Pass -w through from host to nvcc, but don't do anything fancier with + # warnings-related flags, since they're not necessarily the same across + # compilers. + warning_options = ' -w' if '-w' in argv else '' + + if len(src_files) == 0: + return 1 + if len(out_file) != 1: + return 1 + + opt = (' -O2' if (len(opt_option) > 0 and int(opt_option[0]) > 0) + else ' -g') + + includes = (' -I ' + ' -I '.join(include_options) + if len(include_options) > 0 + else '') + + # Unfortunately, there are other options that have -c prefix too. + # So allowing only those look like C/C++ files. + src_files = [f for f in src_files if + re.search('\.cpp$|\.cc$|\.c$|\.cxx$|\.C$', f)] + srcs = ' '.join(src_files) + out = ' -o ' + out_file[0] + + nvccopts = '-D_FORCE_INLINES ' + for capability in GetOptionValue(argv, "--cuda-gpu-arch"): + capability = capability[len('sm_'):] + nvccopts += r'-gencode=arch=compute_%s,\"code=sm_%s\" ' % (capability, + capability) + for capability in GetOptionValue(argv, '--cuda-include-ptx'): + capability = capability[len('sm_'):] + nvccopts += r'-gencode=arch=compute_%s,\"code=compute_%s\" ' % (capability, + capability) + nvccopts += nvcc_compiler_options + nvccopts += undefines + nvccopts += defines + nvccopts += std_options + nvccopts += m_options + nvccopts += warning_options + nvccopts += fatbin_options + + if depfiles: + # Generate the dependency file + depfile = depfiles[0] + cmd = (NVCC_PATH + ' ' + nvccopts + + ' --compiler-options "' + host_compiler_options + '"' + + ' --compiler-bindir=' + GCC_HOST_COMPILER_PATH + + ' -I .' + + ' -x cu ' + opt + includes + ' ' + srcs + ' -M -o ' + depfile) + if log: Log(cmd) + exit_status = system(cmd) + if exit_status != 0: + return exit_status + + cmd = (NVCC_PATH + ' ' + nvccopts + + ' --compiler-options "' + host_compiler_options + ' -fPIC"' + + ' --compiler-bindir=' + GCC_HOST_COMPILER_PATH + + ' -I .' + + ' -x cu ' + opt + includes + ' -c ' + srcs + out) + + # TODO(zhengxq): for some reason, 'gcc' needs this help to find 'as'. + # Need to investigate and fix. + cmd = 'PATH=' + PREFIX_DIR + ':$PATH ' + cmd + if log: Log(cmd) + return system(cmd) + + +def main(): + parser = ArgumentParser() + parser.add_argument('-x', nargs=1) + parser.add_argument('--cuda_log', action='store_true') + args, leftover = parser.parse_known_args(sys.argv[1:]) + + if args.x and args.x[0] == 'cuda': + if args.cuda_log: Log('-x cuda') + leftover = [pipes.quote(s) for s in leftover] + if args.cuda_log: Log('using nvcc') + return InvokeNvcc(leftover, log=args.cuda_log) + + # Strip our flags before passing through to the CPU compiler for files which + # are not -x cuda. We can't just pass 'leftover' because it also strips -x. + # We not only want to pass -x to the CPU compiler, but also keep it in its + # relative location in the argv list (the compiler is actually sensitive to + # this). + cpu_compiler_flags = [flag for flag in sys.argv[1:] + if not flag.startswith(('--cuda_log'))] + + return subprocess.call([CPU_COMPILER] + cpu_compiler_flags) + +if __name__ == '__main__': + sys.exit(main()) diff --git a/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11.2/crosstool_wrapper_driver_is_not_gcc b/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11.2/crosstool_wrapper_driver_is_not_gcc new file mode 100755 index 00000000000..575f4b21415 --- /dev/null +++ b/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11.2/crosstool_wrapper_driver_is_not_gcc @@ -0,0 +1,289 @@ +#!/usr/bin/env python +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Crosstool wrapper for compiling CUDA programs. + +SYNOPSIS: + crosstool_wrapper_is_not_gcc [options passed in by cc_library() + or cc_binary() rule] + +DESCRIPTION: + This script is expected to be called by the cc_library() or cc_binary() bazel + rules. When the option "-x cuda" is present in the list of arguments passed + to this script, it invokes the nvcc CUDA compiler. Most arguments are passed + as is as a string to --compiler-options of nvcc. When "-x cuda" is not + present, this wrapper invokes hybrid_driver_is_not_gcc with the input + arguments as is. + +NOTES: + Changes to the contents of this file must be propagated from + //third_party/gpus/crosstool/crosstool_wrapper_is_not_gcc to + //third_party/gpus/crosstool/v*/*/clang/bin/crosstool_wrapper_is_not_gcc +""" + +from __future__ import print_function + +__author__ = 'keveman@google.com (Manjunath Kudlur)' + +from argparse import ArgumentParser +import os +import subprocess +import re +import sys +import pipes + +# Template values set by cuda_autoconf. +CPU_COMPILER = ('/dt7/usr/bin/gcc') +GCC_HOST_COMPILER_PATH = ('/dt7/usr/bin/gcc') + +NVCC_PATH = '/usr/local/cuda-11.2/bin/nvcc' +PREFIX_DIR = os.path.dirname(GCC_HOST_COMPILER_PATH) +NVCC_VERSION = '10.1' + +def Log(s): + print('gpus/crosstool: {0}'.format(s)) + + +def GetOptionValue(argv, option): + """Extract the list of values for option from the argv list. + + Args: + argv: A list of strings, possibly the argv passed to main(). + option: The option whose value to extract, with the leading '-'. + + Returns: + A list of values, either directly following the option, + (eg., -opt val1 val2) or values collected from multiple occurrences of + the option (eg., -opt val1 -opt val2). + """ + + parser = ArgumentParser() + parser.add_argument(option, nargs='*', action='append') + option = option.lstrip('-').replace('-', '_') + args, _ = parser.parse_known_args(argv) + if not args or not vars(args)[option]: + return [] + else: + return sum(vars(args)[option], []) + + +def GetHostCompilerOptions(argv): + """Collect the -isystem, -iquote, and --sysroot option values from argv. + + Args: + argv: A list of strings, possibly the argv passed to main(). + + Returns: + The string that can be used as the --compiler-options to nvcc. + """ + + parser = ArgumentParser() + parser.add_argument('-isystem', nargs='*', action='append') + parser.add_argument('-iquote', nargs='*', action='append') + parser.add_argument('--sysroot', nargs=1) + parser.add_argument('-g', nargs='*', action='append') + parser.add_argument('-fno-canonical-system-headers', action='store_true') + parser.add_argument('-no-canonical-prefixes', action='store_true') + + args, _ = parser.parse_known_args(argv) + + opts = '' + + if args.isystem: + opts += ' -isystem ' + ' -isystem '.join(sum(args.isystem, [])) + if args.iquote: + opts += ' -iquote ' + ' -iquote '.join(sum(args.iquote, [])) + if args.g: + opts += ' -g' + ' -g'.join(sum(args.g, [])) + if args.fno_canonical_system_headers: + opts += ' -fno-canonical-system-headers' + if args.no_canonical_prefixes: + opts += ' -no-canonical-prefixes' + if args.sysroot: + opts += ' --sysroot ' + args.sysroot[0] + + return opts + +def _update_options(nvcc_options): + if NVCC_VERSION in ("7.0",): + return nvcc_options + + update_options = { "relaxed-constexpr" : "expt-relaxed-constexpr" } + return [ update_options[opt] if opt in update_options else opt + for opt in nvcc_options ] + +def GetNvccOptions(argv): + """Collect the -nvcc_options values from argv. + + Args: + argv: A list of strings, possibly the argv passed to main(). + + Returns: + The string that can be passed directly to nvcc. + """ + + parser = ArgumentParser() + parser.add_argument('-nvcc_options', nargs='*', action='append') + + args, _ = parser.parse_known_args(argv) + + if args.nvcc_options: + options = _update_options(sum(args.nvcc_options, [])) + return ' '.join(['--'+a for a in options]) + return '' + +def system(cmd): + """Invokes cmd with os.system(). + + Args: + cmd: The command. + + Returns: + The exit code if the process exited with exit() or -signal + if the process was terminated by a signal. + """ + retv = os.system(cmd) + if os.WIFEXITED(retv): + return os.WEXITSTATUS(retv) + else: + return -os.WTERMSIG(retv) + +def InvokeNvcc(argv, log=False): + """Call nvcc with arguments assembled from argv. + + Args: + argv: A list of strings, possibly the argv passed to main(). + log: True if logging is requested. + + Returns: + The return value of calling system('nvcc ' + args) + """ + + host_compiler_options = GetHostCompilerOptions(argv) + nvcc_compiler_options = GetNvccOptions(argv) + opt_option = GetOptionValue(argv, '-O') + m_options = GetOptionValue(argv, '-m') + m_options = ''.join([' -m' + m for m in m_options if m in ['32', '64']]) + include_options = GetOptionValue(argv, '-I') + out_file = GetOptionValue(argv, '-o') + depfiles = GetOptionValue(argv, '-MF') + defines = GetOptionValue(argv, '-D') + defines = ''.join([' -D' + define for define in defines]) + undefines = GetOptionValue(argv, '-U') + undefines = ''.join([' -U' + define for define in undefines]) + std_options = GetOptionValue(argv, '-std') + # Supported -std flags as of CUDA 9.0. Only keep last to mimic gcc/clang. + nvcc_allowed_std_options = ["c++03", "c++11", "c++14"] + std_options = ''.join([' -std=' + define + for define in std_options if define in nvcc_allowed_std_options][-1:]) + fatbin_options = ''.join([' --fatbin-options=' + option + for option in GetOptionValue(argv, '-Xcuda-fatbinary')]) + + # The list of source files get passed after the -c option. I don't know of + # any other reliable way to just get the list of source files to be compiled. + src_files = GetOptionValue(argv, '-c') + + # Pass -w through from host to nvcc, but don't do anything fancier with + # warnings-related flags, since they're not necessarily the same across + # compilers. + warning_options = ' -w' if '-w' in argv else '' + + if len(src_files) == 0: + return 1 + if len(out_file) != 1: + return 1 + + opt = (' -O2' if (len(opt_option) > 0 and int(opt_option[0]) > 0) + else ' -g') + + includes = (' -I ' + ' -I '.join(include_options) + if len(include_options) > 0 + else '') + + # Unfortunately, there are other options that have -c prefix too. + # So allowing only those look like C/C++ files. + src_files = [f for f in src_files if + re.search('\.cpp$|\.cc$|\.c$|\.cxx$|\.C$', f)] + srcs = ' '.join(src_files) + out = ' -o ' + out_file[0] + + nvccopts = '-D_FORCE_INLINES ' + for capability in GetOptionValue(argv, "--cuda-gpu-arch"): + capability = capability[len('sm_'):] + nvccopts += r'-gencode=arch=compute_%s,\"code=sm_%s\" ' % (capability, + capability) + for capability in GetOptionValue(argv, '--cuda-include-ptx'): + capability = capability[len('sm_'):] + nvccopts += r'-gencode=arch=compute_%s,\"code=compute_%s\" ' % (capability, + capability) + nvccopts += nvcc_compiler_options + nvccopts += undefines + nvccopts += defines + nvccopts += std_options + nvccopts += m_options + nvccopts += warning_options + nvccopts += fatbin_options + + if depfiles: + # Generate the dependency file + depfile = depfiles[0] + cmd = (NVCC_PATH + ' ' + nvccopts + + ' --compiler-options "' + host_compiler_options + '"' + + ' --compiler-bindir=' + GCC_HOST_COMPILER_PATH + + ' -I .' + + ' -x cu ' + opt + includes + ' ' + srcs + ' -M -o ' + depfile) + if log: Log(cmd) + exit_status = system(cmd) + if exit_status != 0: + return exit_status + + cmd = (NVCC_PATH + ' ' + nvccopts + + ' --compiler-options "' + host_compiler_options + ' -fPIC"' + + ' --compiler-bindir=' + GCC_HOST_COMPILER_PATH + + ' -I .' + + ' -x cu ' + opt + includes + ' -c ' + srcs + out) + + # TODO(zhengxq): for some reason, 'gcc' needs this help to find 'as'. + # Need to investigate and fix. + cmd = 'PATH=' + PREFIX_DIR + ':$PATH ' + cmd + if log: Log(cmd) + return system(cmd) + + +def main(): + parser = ArgumentParser() + parser.add_argument('-x', nargs=1) + parser.add_argument('--cuda_log', action='store_true') + args, leftover = parser.parse_known_args(sys.argv[1:]) + + if args.x and args.x[0] == 'cuda': + if args.cuda_log: Log('-x cuda') + leftover = [pipes.quote(s) for s in leftover] + if args.cuda_log: Log('using nvcc') + return InvokeNvcc(leftover, log=args.cuda_log) + + # Strip our flags before passing through to the CPU compiler for files which + # are not -x cuda. We can't just pass 'leftover' because it also strips -x. + # We not only want to pass -x to the CPU compiler, but also keep it in its + # relative location in the argv list (the compiler is actually sensitive to + # this). + cpu_compiler_flags = [flag for flag in sys.argv[1:] + if not flag.startswith(('--cuda_log'))] + + return subprocess.call([CPU_COMPILER] + cpu_compiler_flags) + +if __name__ == '__main__': + sys.exit(main()) diff --git a/third_party/toolchains/remote/BUILD b/third_party/toolchains/remote/BUILD deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/third_party/toolchains/remote/BUILD.tpl b/third_party/toolchains/remote/BUILD.tpl deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/third_party/toolchains/remote/configure.bzl b/third_party/toolchains/remote/configure.bzl deleted file mode 100644 index cc5b9842648..00000000000 --- a/third_party/toolchains/remote/configure.bzl +++ /dev/null @@ -1,43 +0,0 @@ -"""Repository rule for remote GPU autoconfiguration. - -This rule creates the starlark file -//third_party/toolchains/remote:execution.bzl -providing the function `gpu_test_tags`. - -`gpu_test_tags` will return: - - * `local`: if `REMOTE_GPU_TESTING` is false, allowing CPU tests to run - remotely and GPU tests to run locally in the same bazel invocation. - * `remote-gpu`: if `REMOTE_GPU_TESTING` is true; this allows rules to - set an execution requirement that enables a GPU-enabled remote platform. -""" - -_REMOTE_GPU_TESTING = "REMOTE_GPU_TESTING" - -def _flag_enabled(repository_ctx, flag_name): - if flag_name not in repository_ctx.os.environ: - return False - return repository_ctx.os.environ[flag_name].strip() == "1" - -def _remote_execution_configure(repository_ctx): - # If we do not support remote gpu test execution, mark them as local, so we - # can combine remote builds with local gpu tests. - gpu_test_tags = "\"local\"" - if _flag_enabled(repository_ctx, _REMOTE_GPU_TESTING): - gpu_test_tags = "\"remote-gpu\"" - repository_ctx.template( - "remote_execution.bzl", - Label("//third_party/toolchains/remote:execution.bzl.tpl"), - { - "%{gpu_test_tags}": gpu_test_tags, - }, - ) - repository_ctx.template( - "BUILD", - Label("//third_party/toolchains/remote:BUILD.tpl"), - ) - -remote_execution_configure = repository_rule( - implementation = _remote_execution_configure, - environ = [_REMOTE_GPU_TESTING], -) diff --git a/third_party/toolchains/remote/execution.bzl.tpl b/third_party/toolchains/remote/execution.bzl.tpl deleted file mode 100644 index 18858cc0dc0..00000000000 --- a/third_party/toolchains/remote/execution.bzl.tpl +++ /dev/null @@ -1,2 +0,0 @@ -def gpu_test_tags(): - return [%{gpu_test_tags}] diff --git a/third_party/toolchains/remote_config/configs.bzl b/third_party/toolchains/remote_config/configs.bzl deleted file mode 100644 index 2c72e4c2efe..00000000000 --- a/third_party/toolchains/remote_config/configs.bzl +++ /dev/null @@ -1,109 +0,0 @@ -"""Configurations of RBE builds used with remote config.""" - -load("//third_party/toolchains/remote_config:rbe_config.bzl", "tensorflow_local_config", "tensorflow_rbe_config", "tensorflow_rbe_win_config") - -def initialize_rbe_configs(): - tensorflow_local_config( - name = "local_execution", - ) - - tensorflow_rbe_config( - name = "ubuntu16.04-manylinux2010-py3", - os = "ubuntu16.04-manylinux2010", - python_versions = ["3"], - compiler = "", - ) - - tensorflow_rbe_config( - name = "ubuntu16.04-py3-gcc7_manylinux2010-cuda10.0-cudnn7-tensorrt5.1", - compiler = "/dt7/usr/bin/gcc", - compiler_prefix = "/usr/bin", - cuda_version = "10.0", - cudnn_version = "7", - os = "ubuntu16.04-manylinux2010", - python_versions = ["3"], - tensorrt_install_path = "/usr", - tensorrt_version = "5.1", - ) - - tensorflow_rbe_config( - name = "ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0", - compiler = "/dt7/usr/bin/gcc", - compiler_prefix = "/usr/bin", - cuda_version = "10.1", - cudnn_version = "7", - os = "ubuntu16.04-manylinux2010-multipython", - python_versions = ["2.7", "3.5", "3.6", "3.7", "3.8"], - tensorrt_install_path = "/usr", - tensorrt_version = "6.0", - python_install_path = "/usr/local", - ) - - tensorflow_rbe_config( - name = "ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0", - compiler = "/dt7/usr/bin/gcc", - compiler_prefix = "/usr/bin", - cuda_version = "10.1", - cudnn_version = "7", - os = "ubuntu18.04-manylinux2010-multipython", - python_versions = ["2.7", "3.5", "3.6", "3.7", "3.8"], - tensorrt_install_path = "/usr", - tensorrt_version = "6.0", - python_install_path = "/usr/local", - ) - - tensorflow_rbe_config( - name = "ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1", - compiler = "/dt7/usr/bin/gcc", - compiler_prefix = "/usr/bin", - cuda_version = "11.0", - cudnn_version = "8", - os = "ubuntu18.04-manylinux2010-multipython", - python_versions = ["2.7", "3.5", "3.6", "3.7", "3.8"], - tensorrt_install_path = "/usr", - tensorrt_version = "7.1", - python_install_path = "/usr/local", - ) - - # TODO(klimek): Delete this once all users are migrated to a python-version - # independent configuration. In the future, use - # "ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0" instead. - tensorflow_rbe_config( - name = "ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0", - compiler = "/dt7/usr/bin/gcc", - compiler_prefix = "/usr/bin", - cuda_version = "10.1", - cudnn_version = "7", - os = "ubuntu16.04-manylinux2010", - python_versions = ["3"], - tensorrt_install_path = "/usr", - tensorrt_version = "6.0", - ) - - tensorflow_rbe_config( - name = "ubuntu18.04-clang_manylinux2010-cuda11.0-cudnn8-tensorrt7.1", - compiler = "/clang_r7f6f9f4cf966c78a315d15d6e913c43cfa45c47c/bin/clang", - cuda_version = "11.0", - cudnn_version = "8", - os = "ubuntu18.04-manylinux2010-multipython", - python_versions = ["2.7", "3.5", "3.6", "3.7", "3.8"], - tensorrt_install_path = "/usr", - tensorrt_version = "7.1", - sysroot = "/dt7", - python_install_path = "/usr/local", - ) - - tensorflow_rbe_config( - name = "ubuntu18.04-gcc7_manylinux2010-rocm", - compiler = "/dt7/usr/bin/gcc", - compiler_prefix = "/usr/bin", - rocm_version = "3.5", # Any version will do. - os = "ubuntu18.04-manylinux2010-multipython", - python_versions = ["2.7", "3.5", "3.6", "3.7", "3.8"], - python_install_path = "/usr/local", - ) - - tensorflow_rbe_win_config( - name = "windows_py37", - python_bin_path = "C:/Python37/python.exe", - )