diff --git a/.bazelrc b/.bazelrc index 3ad93cdf49f..451cc60fdd1 100644 --- a/.bazelrc +++ b/.bazelrc @@ -100,9 +100,9 @@ build --apple_platform_type=macos # iOS configs for each architecture and the fat binary builds. build:ios --apple_platform_type=ios build:ios --apple_bitcode=embedded --copt=-fembed-bitcode +build:ios --copt=-Wno-c++11-narrowing build:ios_armv7 --config=ios build:ios_armv7 --cpu=ios_armv7 -build:ios_armv7 --copt -Wno-c++11-narrowing build:ios_arm64 --config=ios build:ios_arm64 --cpu=ios_arm64 build:ios_i386 --config=ios @@ -111,7 +111,6 @@ build:ios_x86_64 --config=ios build:ios_x86_64 --cpu=ios_x86_64 build:ios_fat --config=ios build:ios_fat --ios_multi_cpus=armv7,arm64,i386,x86_64 -build:ios_fat --copt -Wno-c++11-narrowing # Config to use a mostly-static build and disable modular op registration # support (this will revert to loading TensorFlow with RTLD_GLOBAL in Python). @@ -202,18 +201,25 @@ build --define=allow_oversize_protos=true build --spawn_strategy=standalone build -c opt -# By default, build TF in C++ 14 mode. -build --cxxopt=-std=c++14 -build --host_cxxopt=-std=c++14 - # Make Bazel print out all options from rc files. build --announce_rc # Other build flags. build --define=grpc_no_ares=true -# Prevent regression of https://github.com/bazelbuild/bazel/issues/7362 -build --incompatible_remove_legacy_whole_archive +# See https://github.com/bazelbuild/bazel/issues/7362 for information on what +# --incompatible_remove_legacy_whole_archive flag does. +# This flag is set to true in Bazel 1.0 and newer versions. We tried to migrate +# Tensorflow to the default, however test coverage wasn't enough to catch the +# errors. +# There is ongoing work on Bazel team's side to provide support for transitive +# shared libraries. As part of migrating to transitive shared libraries, we +# hope to provide a better mechanism for control over symbol exporting, and +# then tackle this issue again. +# +# TODO: Remove this line once TF doesn't depend on Bazel wrapping all library +# archives in -whole_archive -no_whole_archive. +build --noincompatible_remove_legacy_whole_archive # Modular TF build options build:dynamic_kernels --define=dynamic_loaded_kernels=true @@ -224,13 +230,55 @@ build:c++17 --cxxopt=-std=c++1z build:c++17 --cxxopt=-stdlib=libc++ build:c++1z --config=c++17 -# Default paths for TF_SYSTEM_LIBS -build --define=PREFIX=/usr -build --define=LIBDIR=$(PREFIX)/lib -build --define=INCLUDEDIR=$(PREFIX)/include +# Enable using platform specific build settings +build --enable_platform_specific_config # Suppress C++ compiler warnings, otherwise build logs become 10s of MBs. -build --copt=-w +build:linux --copt=-w +build:macos --copt=-w +build:windows --copt=/w + +# Default paths for TF_SYSTEM_LIBS +build:linux --define=PREFIX=/usr +build:linux --define=LIBDIR=$(PREFIX)/lib +build:linux --define=INCLUDEDIR=$(PREFIX)/include +build:macos --define=PREFIX=/usr +build:macos --define=LIBDIR=$(PREFIX)/lib +build:macos --define=INCLUDEDIR=$(PREFIX)/include +# TF_SYSTEM_LIBS do not work on windows. + +# By default, build TF in C++ 14 mode. +build:linux --cxxopt=-std=c++14 +build:linux --host_cxxopt=-std=c++14 +build:macos --cxxopt=-std=c++14 +build:macos --host_cxxopt=-std=c++14 +build:windows --cxxopt=/std:c++14 +build:windows --host_cxxopt=/std:c++14 + +# On windows, we still link everything into a single DLL. 
+build:windows --config=monolithic + +# Make sure to include as little of windows.h as possible +build:windows --copt=-DWIN32_LEAN_AND_MEAN +build:windows --host_copt=-DWIN32_LEAN_AND_MEAN +build:windows --copt=-DNOGDI +build:windows --host_copt=-DNOGDI + +# Misc build options we need for windows. +build:windows --linkopt=/DEBUG +build:windows --host_linkopt=/DEBUG +build:windows --linkopt=/OPT:REF +build:windows --host_linkopt=/OPT:REF +build:windows --linkopt=/OPT:ICF +build:windows --host_linkopt=/OPT:ICF +build:windows --experimental_strict_action_env=true +build:windows --incompatible_windows_native_test_wrapper + +# Verbose failure logs when something goes wrong +build:windows --verbose_failures + +# On windows, we never cross compile +build:windows --distinct_host_configuration=false # Suppress all warning messages. build:short_logs --output_filter=DONT_MATCH_ANYTHING @@ -335,20 +383,6 @@ build:rbe_win --javabase="@org_tensorflow//third_party/toolchains/preconfig/win_ build:rbe_win --platforms="@org_tensorflow//third_party/toolchains/preconfig/win_1803:rbe_windows_1803" build:rbe_win --shell_executable=C:\\tools\\msys64\\usr\\bin\\bash.exe -# Misc build options we need for windows -build:rbe_win --copt=-DWIN32_LEAN_AND_MEAN -build:rbe_win --host_copt=-DWIN32_LEAN_AND_MEAN -build:rbe_win --copt=-DNOGDI -build:rbe_win --host_copt=-DNOGDI -build:rbe_win --linkopt=/DEBUG -build:rbe_win --host_linkopt=/DEBUG -build:rbe_win --linkopt=/OPT:REF -build:rbe_win --host_linkopt=/OPT:REF -build:rbe_win --linkopt=/OPT:ICF -build:rbe_win --host_linkopt=/OPT:ICF -build:rbe_win --config=monolithic -build:rbe_win --experimental_strict_action_env=true -build:rbe_win --incompatible_windows_native_test_wrapper # TODO(gunan): Remove once we use MSVC 2019 with latest patches. build:rbe_win --define=override_eigen_strong_inline=true diff --git a/WORKSPACE b/WORKSPACE index babb14b509e..48536a5d1d0 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -89,7 +89,7 @@ swift_rules_dependencies() # files, in case the parsing of those build files depends on the bazel # version we require here. load("//tensorflow:version_check.bzl", "check_bazel_version_at_least") -check_bazel_version_at_least("0.19.0") +check_bazel_version_at_least("1.0.0") load("//third_party/android:android_configure.bzl", "android_configure") android_configure(name="local_config_android") diff --git a/configure.py b/configure.py index e02428a25a2..fedbd470f2d 100644 --- a/configure.py +++ b/configure.py @@ -49,7 +49,7 @@ _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc' _TF_WORKSPACE_ROOT = '' _TF_BAZELRC = '' _TF_CURRENT_BAZEL_VERSION = None -_TF_MIN_BAZEL_VERSION = '0.27.1' +_TF_MIN_BAZEL_VERSION = '1.0.0' _TF_MAX_BAZEL_VERSION = '1.1.0' NCCL_LIB_PATHS = [ @@ -1232,20 +1232,6 @@ def is_reduced_optimize_huge_functions_available(environ_cp): def set_windows_build_flags(environ_cp): """Set Windows specific build options.""" - # The non-monolithic build is not supported yet - write_to_bazelrc('build --config monolithic') - # Suppress warning messages - write_to_bazelrc('build --copt=-w --host_copt=-w') - # Fix winsock2.h conflicts - write_to_bazelrc( - 'build --copt=-DWIN32_LEAN_AND_MEAN --host_copt=-DWIN32_LEAN_AND_MEAN ' - '--copt=-DNOGDI --host_copt=-DNOGDI') - # Output more verbose information when something goes wrong - write_to_bazelrc('build --verbose_failures') - # The host and target platforms are the same in Windows build. So we don't - # have to distinct them. This avoids building the same targets twice. 
- write_to_bazelrc('build --distinct_host_configuration=false') - if is_reduced_optimize_huge_functions_available(environ_cp): write_to_bazelrc( 'build --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions' diff --git a/tensorflow/BUILD b/tensorflow/BUILD index ec879fe2c45..0f299ec13f8 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -2,10 +2,7 @@ # TensorFlow is a computational framework, primarily for use in machine # learning applications. -load("//tensorflow:tensorflow.bzl", "VERSION") -load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object") -load("//tensorflow:tensorflow.bzl", "tf_custom_op_library_additional_deps_impl") -load("//tensorflow:tensorflow.bzl", "tf_native_cc_binary") +load("//tensorflow:tensorflow.bzl", "VERSION", "tf_cc_shared_object", "tf_custom_op_library_additional_deps_impl", "tf_native_cc_binary") load( "//tensorflow/core/platform:build_config.bzl", "tf_additional_binary_deps", @@ -450,6 +447,7 @@ config_setting( package_group( name = "internal", packages = [ + "//learning/brain/swift/x10/...", "//perftools/accelerators/xprof/api/...", "//tensorflow/...", "//tensorflow_estimator/python/estimator/...", diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index 56d65d45faf..c515cc76b9a 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -119,11 +119,11 @@ def _running_from_pip_package(): _current_file_location.startswith(dir_) for dir_ in _site_packages_dirs) if _running_from_pip_package(): - for s in _site_packages_dirs: + for _s in _site_packages_dirs: # TODO(gunan): Add sanity checks to loaded modules here. - plugin_dir = _os.path.join(s, 'tensorflow-plugins') - if _fi.file_exists(plugin_dir): - _ll.load_library(plugin_dir) + _plugin_dir = _os.path.join(_s, 'tensorflow-plugins') + if _fi.file_exists(_plugin_dir): + _ll.load_library(_plugin_dir) # Add module aliases if hasattr(_current_module, 'keras'): @@ -136,3 +136,5 @@ if hasattr(_current_module, 'keras'): setattr(_current_module, "optimizers", optimizers) setattr(_current_module, "initializers", initializers) # pylint: enable=undefined-variable + +# __all__ PLACEHOLDER diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py index 97478a18b8a..2b2899c3fe0 100644 --- a/tensorflow/api_template_v1.__init__.py +++ b/tensorflow/api_template_v1.__init__.py @@ -132,9 +132,10 @@ def _running_from_pip_package(): _current_file_location.startswith(dir_) for dir_ in _site_packages_dirs) if _running_from_pip_package(): - for s in _site_packages_dirs: + for _s in _site_packages_dirs: # TODO(gunan): Add sanity checks to loaded modules here. 
- plugin_dir = _os.path.join(s, 'tensorflow-plugins') - if _fi.file_exists(plugin_dir): - _ll.load_library(plugin_dir) + _plugin_dir = _os.path.join(_s, 'tensorflow-plugins') + if _fi.file_exists(_plugin_dir): + _ll.load_library(_plugin_dir) +# __all__ PLACEHOLDER diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index fd58c1b173e..cabc3b21e45 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -196,6 +196,12 @@ cc_library( }), ) +cc_library( + name = "tf_status_headers", + hdrs = ["tf_status.h"], + visibility = ["//visibility:public"], +) + cc_library( name = "tf_file_statistics", hdrs = ["tf_file_statistics.h"], diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 89e6a04abd7..bcdfa019459 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -231,7 +231,14 @@ cc_library( srcs = ["shape_inference_helpers.cc"], hdrs = ["shape_inference_helpers.h"], visibility = [":friends"], - deps = ["//tensorflow/core:graph"], + deps = select({ + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + ], + "//conditions:default": [ + "//tensorflow/core:graph", + ], + }), ) # Internal targets below this point. diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index 1a459477ac1..8f24ad441a6 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -880,11 +880,8 @@ static OwningModuleRef FlatBufferFileToMlirTrans(llvm::SourceMgr* source_mgr, mlir::FileLineColLoc::get(input->getBufferIdentifier(), 0, 0, context); // Parses output_arrays_order from command line option. - absl::flat_hash_set output_set; - std::vector output_arrays_order; - if (!tensorflow::ParseOutputArrayInfo(output_arrays_string, &output_set, - &output_arrays_order) - .ok()) { + std::vector outputs; + if (!tensorflow::ParseOutputArrayInfo(output_arrays_string, &outputs).ok()) { return emitError(loc, "parsing output array info failed ") << output_arrays_string, nullptr; @@ -892,7 +889,7 @@ static OwningModuleRef FlatBufferFileToMlirTrans(llvm::SourceMgr* source_mgr, return tflite::FlatBufferToMlir( absl::string_view(input->getBufferStart(), input->getBufferSize()), - context, loc, output_arrays_order); + context, loc, outputs); } static mlir::TranslateToMLIRRegistration FlatBufferFileToMlirTransReg( diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc index ba945086470..9205ebb101c 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc @@ -384,7 +384,7 @@ class Translator { const ::tensorflow::NodeDef& node_def, const mlir::Location& loc); // Returns opcode index for op identified by the op_name, if already - // available. Otherwise, creates a new OperactorCode using the given `builtin` + // available. Otherwise, creates a new OperatorCode using the given `builtin` // operator and associates it with `op_name`. 
uint32_t GetOpcodeIndex(const std::string& op_name, tflite::BuiltinOperator builtin); diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index 0549eadc88a..221c9aa2adc 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -720,7 +720,8 @@ static LogicalResult Verify(PackOp op) { for (Value *operand : op.getOperands()) { auto other_type = operand->getType().cast(); if (input_type != other_type) - return op.emitOpError("operands should be of the same type"); + return op.emitOpError("operands should be of the same type. got ") + << input_type << ", " << other_type; } return success(); @@ -857,8 +858,8 @@ void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, // // => Value [5, 8, 9] // TODO(b/133341698): Move to tablegen when variadic is supported. -struct RemoveRedunantUnpackPack : public RewritePattern { - explicit RemoveRedunantUnpackPack(MLIRContext *context) +struct RemoveRedundantUnpackPack : public RewritePattern { - explicit RemoveRedundantUnpackPack(MLIRContext *context) : RewritePattern(PackOp::getOperationName(), 2, context) {} PatternMatchResult matchAndRewrite(Operation *op, @@ -896,7 +897,7 @@ struct RemoveRedunantUnpackPack : public RewritePattern { void PackOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.insert<RemoveRedunantUnpackPack>(context); + results.insert<RemoveRedundantUnpackPack>(context); } //===----------------------------------------------------------------------===// @@ -1041,7 +1042,7 @@ struct DropFakeQuant : public RewritePattern { } void rewrite(Operation *op, PatternRewriter &rewriter) const override { - // Replace the matched FakeQuantOp by its primiary operand. + // Replace the matched FakeQuantOp by its primary operand. 
rewriter.replaceOp(op, op->getOperand(0)); } }; diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index e91f5fa1e8e..c339d96baed 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -32,7 +32,7 @@ def TFL_Dialect : Dialect { Invariants: * All values are of Tensor type (in particular, scalars are - represented using zero-dimentional tensors); + represented using zero-dimensional tensors); }]; let cppNamespace = "TFL"; @@ -603,7 +603,7 @@ def TFL_FCWO_Default : StrEnumAttrCase<"DEFAULT">; def TFL_FCWO_Shuffled4x16i8 : StrEnumAttrCase<"SHUFFLED4x16INT8">; def TFL_FullyConnectedOptionsWeightFormatAttr : - StrEnumAttr<"FullyConectedOptionsWeightsFormat", + StrEnumAttr<"FullyConnectedOptionsWeightsFormat", "fully connected options weights format", [ TFL_FCWO_Default, TFL_FCWO_Shuffled4x16i8 ]>; @@ -1873,9 +1873,9 @@ def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect, x -> max(0, x) }]; - let arguments = (ins AnyTensor:$x); + let arguments = (ins TensorOf<[F32, QUI8, I8]>:$x); - let results = (outs AnyTensor:$y); + let results = (outs TensorOf<[F32, QUI8, I8]>:$y); } def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect, @@ -1888,9 +1888,24 @@ def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect, x -> max(0, min(6, x)) }]; - let arguments = (ins AnyTensor:$x); + let arguments = (ins TensorOf<[F32, QUI8, I8]>:$x); - let results = (outs AnyTensor:$y); + let results = (outs TensorOf<[F32, QUI8, I8]>:$y); +} + +def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [NoSideEffect, + SameOperandsAndResultShape, + SameOperandsAndResultsScale]> { + let summary = "Relu1 operator"; + + let description = [{ + Element-wise Relu1 operator + x -> max(-1, min(1, x)) + }]; + + let arguments = (ins TensorOf<[F32, QUI8, I8]>:$x); + + let results = (outs TensorOf<[F32, QUI8, I8]>:$y); } def TFL_ReshapeOp: TFL_Op<"reshape", [ diff --git a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc index 895d12f61ef..51bd1e4540c 100644 --- a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc @@ -237,8 +237,8 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, // Parse output arrays. std::vector output_arrays(model_flags.output_arrays().begin(), model_flags.output_arrays().end()); - TF_RETURN_IF_ERROR(tensorflow::ParseOutputArrayInfo( - output_arrays, &specs.output_arrays, &specs.output_arrays_order)); + TF_RETURN_IF_ERROR( + tensorflow::ParseOutputArrayInfo(output_arrays, &specs.outputs)); // Other flags. bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops(); diff --git a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc index 4beb4ef9ecf..0326d122c07 100644 --- a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc +++ b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc @@ -82,7 +82,7 @@ class ImportQuantStatsPass : public FunctionPass { res->getType().cast().getElementType().isa(); } - // A method to retrive the name for the given op. + // A method to retrieve the name for the given op. 
OperationToName op_to_name_; // We split the normal names and regex names, since the former can use hash diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization.td b/tensorflow/compiler/mlir/lite/quantization/quantization.td index 33347c85458..f9fcf0e83a0 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization.td +++ b/tensorflow/compiler/mlir/lite/quantization/quantization.td @@ -102,7 +102,7 @@ class AffineOpCoefficient : NativeOpTrait< StrJoinInt<[dim, index]>.result, ">::Impl")>; -// Specify this trait if the op doesn't have quantizable ouput. We shouldn't +// Specify this trait if the op doesn't have quantizable output. We shouldn't // apply quantization on this op. def NoQuantizableResult : NativeOpTrait<"quant::NoQuantizableResult">; diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_traits.h b/tensorflow/compiler/mlir/lite/quantization/quantization_traits.h index 9f027d27bc2..3830d11afe4 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_traits.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_traits.h @@ -54,7 +54,7 @@ class SameOperandsAndResultsScale // OpTrait::quant::FixedResultUniformScale< // 8, -128, 390625, -8, 0, 255, false>::Impl> { // -// TODO(fengliuai): create a better way to epxress floating point scale in the +// TODO(fengliuai): create a better way to express floating point scale in the // template argument list. template diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h index 5a08cfa93e9..c9f9d6619a3 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h @@ -133,10 +133,10 @@ struct ConvertStatsToQDQs : public OpRewritePattern { // quantization parameters are annotated by the Q/DQ op pairs. Each // matched pattern are rewritten by its quantized alternatives. // -// The concret pattern, extends from this base pattern, can specify whether it +// The concrete pattern, extends from this base pattern, can specify whether it // allows "hybrid" operands or results. These "hybrid" operands and results // don't have quantization parameters propagated to, so will be in float in the -// quantized results. The concret pattern should define the following two +// quantized results. 
The concrete pattern should define the following two // functions: // // bool AllowHybridOperand() const diff --git a/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir b/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir index 3130c5c2042..ef77288ad27 100644 --- a/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir @@ -114,8 +114,8 @@ func @fakequant_notdropfakequant(tensor, f32, f32) -> tensor { // ----- -// CHECK-LABEL: @RemoveRedunantUnpackPack -func @RemoveRedunantUnpackPack(%arg0: tensor<2x5xf32>) -> tensor<2x5xf32> { +// CHECK-LABEL: @RemoveRedundantUnpackPack +func @RemoveRedundantUnpackPack(%arg0: tensor<2x5xf32>) -> tensor<2x5xf32> { %0:2 = "tfl.unpack"(%arg0) {axis = 0 : i32, num = 2 : i32} : (tensor<2x5xf32>) -> (tensor<5xf32>, tensor<5xf32>) %1 = "tfl.pack"(%0#0, %0#1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<5xf32>, tensor<5xf32>) -> (tensor<2x5xf32>) return %1: tensor<2x5xf32> @@ -125,8 +125,8 @@ func @RemoveRedunantUnpackPack(%arg0: tensor<2x5xf32>) -> tensor<2x5xf32> { // ----- -// CHECK-LABEL: @RemoveRedunantPack -func @RemoveRedunantPack(%arg0: tensor<2x5xf32>) -> (tensor<2x5xf32>, tensor<5xf32>) { +// CHECK-LABEL: @RemoveRedundantPack +func @RemoveRedundantPack(%arg0: tensor<2x5xf32>) -> (tensor<2x5xf32>, tensor<5xf32>) { %0:2 = "tfl.unpack"(%arg0) {axis = 0 : i32, num = 2 : i32} : (tensor<2x5xf32>) -> (tensor<5xf32>, tensor<5xf32>) %1 = "tfl.pack"(%0#0, %0#1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<5xf32>, tensor<5xf32>) -> (tensor<2x5xf32>) return %1, %0#0: tensor<2x5xf32>, tensor<5xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/extract-ophint.mlir b/tensorflow/compiler/mlir/lite/tests/extract-ophint.mlir index 215c2d6d94e..bde800897c5 100644 --- a/tensorflow/compiler/mlir/lite/tests/extract-ophint.mlir +++ b/tensorflow/compiler/mlir/lite/tests/extract-ophint.mlir @@ -106,8 +106,8 @@ func @extractStackInputOutputOphint() { // CHECK: %[[PACK:[0-9]*]] = "tfl.pack"(%0, %1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> tensor<2x1x16x1xf32> // CHECK: %[[OP_HINT_CALL:[0-9]*]] = call @b92ed354b9f011e99426dc4a3e957995(%[[PACK]]) : (tensor<2x1x16x1xf32>) -> tensor<2x1x16x1xf32> // CHECK: %[[UNPACK:[0-9]*]]:2 = "tfl.unpack"(%[[OP_HINT_CALL]]) {axis = 0 : i32, num = 2 : i32} : (tensor<2x1x16x1xf32>) -> (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -// CHECK-DAG: %[[OUPUT:[0-9]*]] = "tf.Identity"(%[[UNPACK]]#1) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_sort_index = 1 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-1-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> -// CHECK-DAG: %[[OUPUT_1:[0-9]*]] = "tf.Identity"(%[[UNPACK]]#0) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> +// CHECK-DAG: %[[OUTPUT:[0-9]*]] = "tf.Identity"(%[[UNPACK]]#1) {T = "tfdtype$DT_FLOAT", 
_tflite_function_aggregate = "stack", _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_sort_index = 1 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-1-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> +// CHECK-DAG: %[[OUTPUT_1:[0-9]*]] = "tf.Identity"(%[[UNPACK]]#0) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> %0 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x1xf32> %1 = "tf.Identity"(%0) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index c2653f3d6f1..27eff39c397 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -1,22 +1,22 @@ // RUN: tf-opt %s -tfl-legalize-tf | FileCheck %s --dump-input-on-failure -func @addRelu(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> { - %0 = "tf.Add"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> - %1 = "tf.Add"(%arg0, %0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> - %2 = "tf.Relu"(%1) : (tensor<1xi32>) -> tensor<1xi32> - %3 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> - %4 = "tf.Add"(%3, %2) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> - %5 = "tf.Relu6"(%4) : (tensor<1xi32>) -> tensor<1xi32> - %6 = "tfl.add"(%5, %3) {fused_activation_function = "NONE"} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> - %7 = "tf.Relu6"(%6) : (tensor<1xi32>) -> tensor<1xi32> - return %7: tensor<1xi32> +func @addRelu(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { + %0 = "tf.Add"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + %1 = "tf.Add"(%arg0, %0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + %2 = "tf.Relu"(%1) : (tensor<1xf32>) -> tensor<1xf32> + %3 = "tf.Relu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32> + %4 = "tf.Add"(%3, %2) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + %5 = "tf.Relu6"(%4) : (tensor<1xf32>) -> tensor<1xf32> + %6 = "tfl.add"(%5, %3) {fused_activation_function = "NONE"} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + %7 = "tf.Relu6"(%6) : (tensor<1xf32>) -> tensor<1xf32> + return %7: tensor<1xf32> // CHECK-LABEL: addRelu -// CHECK: tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xi32> -// CHECK: %1 = tfl.add %arg0, %0 {fused_activation_function = "RELU"} : tensor<1xi32> -// CHECK: %2 = "tfl.relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> -// CHECK: %3 
= tfl.add %2, %1 {fused_activation_function = "RELU6"} : tensor<1xi32> -// CHECK: %4 = tfl.add %3, %2 {fused_activation_function = "RELU6"} : tensor<1xi32> +// CHECK: tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xf32> +// CHECK: %1 = tfl.add %arg0, %0 {fused_activation_function = "RELU"} : tensor<1xf32> +// CHECK: %2 = "tfl.relu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32> +// CHECK: %3 = tfl.add %2, %1 {fused_activation_function = "RELU6"} : tensor<1xf32> +// CHECK: %4 = tfl.add %3, %2 {fused_activation_function = "RELU6"} : tensor<1xf32> // CHECK: return } @@ -244,32 +244,32 @@ func @zeros_like(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { // CHECK: "tfl.zeros_like"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32> } -func @divRelu(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> { - %0 = "tf.Div"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> - %1 = "tf.Div"(%arg0, %0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> - %2 = "tf.Relu"(%1) : (tensor<1xi32>) -> tensor<1xi32> - %3 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> - %4 = "tf.Div"(%3, %2) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> - %5 = "tf.Relu6"(%4) : (tensor<1xi32>) -> tensor<1xi32> - return %5: tensor<1xi32> +func @divRelu(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { + %0 = "tf.Div"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + %1 = "tf.Div"(%arg0, %0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + %2 = "tf.Relu"(%1) : (tensor<1xf32>) -> tensor<1xf32> + %3 = "tf.Relu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32> + %4 = "tf.Div"(%3, %2) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + %5 = "tf.Relu6"(%4) : (tensor<1xf32>) -> tensor<1xf32> + return %5: tensor<1xf32> // CHECK-LABEL: divRelu -// CHECK: tfl.div %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xi32> -// CHECK: %1 = tfl.div %arg0, %0 {fused_activation_function = "RELU"} : tensor<1xi32> -// CHECK: %2 = "tfl.relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> -// CHECK: %3 = tfl.div %2, %1 {fused_activation_function = "RELU6"} : tensor<1xi32> +// CHECK: tfl.div %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xf32> +// CHECK: %1 = tfl.div %arg0, %0 {fused_activation_function = "RELU"} : tensor<1xf32> +// CHECK: %2 = "tfl.relu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32> +// CHECK: %3 = tfl.div %2, %1 {fused_activation_function = "RELU6"} : tensor<1xf32> // CHECK: return } -func @squaredDifferenceRelu(tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> { -^bb0(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>): - %0 = "tf.SquaredDifference"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> - %1 = "tf.Relu6"(%0) : (tensor<1xi32>) -> tensor<1xi32> - return %1: tensor<1xi32> +func @squaredDifferenceRelu(tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> { +^bb0(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>): + %0 = "tf.SquaredDifference"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + %1 = "tf.Relu6"(%0) : (tensor<1xf32>) -> tensor<1xf32> + return %1: tensor<1xf32> // CHECK-LABEL: squaredDifferenceRelu -// CHECK: tfl.squared_difference %arg0, %arg1 : tensor<1xi32> -// CHECK: %1 = "tfl.relu6"(%0) : (tensor<1xi32>) -> tensor<1xi32> +// CHECK: tfl.squared_difference %arg0, %arg1 : tensor<1xf32> +// CHECK: %1 = "tfl.relu6"(%0) : (tensor<1xf32>) -> tensor<1xf32> // CHECK: return } diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir index 9cf26d35810..ad3b5540dd7 100644 --- 
a/tensorflow/compiler/mlir/lite/tests/ops.mlir +++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir @@ -434,7 +434,7 @@ func @testEluI32(%arg0: tensor) -> tensor { // ----- -func @testFusedActiviationFunction(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) { +func @testFusedActivationFunction(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) { // CHECK: "NONE" %0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<4xi32> // CHECK: "RELU" @@ -452,7 +452,7 @@ func @testFusedActiviationFunction(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) - // ----- -func @testFusedActiviationFunction(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { +func @testFusedActivationFunction(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { // expected-error @+1 {{attribute 'fused_activation_function' failed to satisfy constraint: fused activation enum}} %0 = tfl.add %arg0, %arg1 {fused_activation_function = "Relu6"} : tensor<4xi32> return %0: tensor<4xi32> @@ -1047,7 +1047,7 @@ func @testConcatInvalidOperandDimSize(%arg0: tensor<1x2xi32>, %arg1: tensor<1x3x // ----- -func @testConcatInvalidOperandDimSizeComaredToPrevInput(%arg0: tensor<1x2xi32>, %arg1: tensor<1x3xi32>) -> tensor { +func @testConcatInvalidOperandDimSizeComparedToPrevInput(%arg0: tensor<1x2xi32>, %arg1: tensor<1x3xi32>) -> tensor { // expected-error @+1 {{'tfl.concatenation' op dimension size of dimension #1 of operand #1 must be equal to dimension size of dimension #1 of operand #0, expected 2, got 3}} %0 = "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xi32>, tensor<1x3xi32>) -> tensor return %0 : tensor diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index 7d63d8df11b..f7913f11f72 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -278,7 +278,7 @@ func @notFuseMulIntoDepthwiseConv2d(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x %cst2 = constant dense<3.0> : tensor<112x2xf32> %0 = "tfl.depthwise_conv_2d"(%arg0, %cst0, %cst1) {depth_multiplier = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x112x2xf32>, tensor<1x3x3x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32> - // We cannot fuse this tfl.mul into the preceding conv op becuase %cst2 is not broadcast-compatible to %cst0. + // We cannot fuse this tfl.mul into the preceding conv op because %cst2 is not broadcast-compatible to %cst0. 
%1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "RELU6"} : (tensor<1x112x112x2xf32>, tensor<112x2xf32>) -> tensor<1x112x112x2xf32> return %1 : tensor<1x112x112x2xf32> @@ -600,3 +600,25 @@ func @squeezeToReshape(%arg0: tensor<1x1x2xf32>) -> tensor<2xf32> { // CHECK: %[[RESULT:.*]] = "tfl.reshape"(%arg0, %[[CONST:.*]]) : (tensor<1x1x2xf32>, tensor<1xi32>) -> tensor<2xf32> // CHECK: return %[[RESULT]] } + +// CHECK-LABEL: Relu1 +func @Relu1(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { + %cst = constant dense<-1.0> : tensor + %cst1 = constant dense<1.0> : tensor + %0 = "tfl.maximum"(%arg0, %cst) : (tensor<2x3xf32>, tensor) -> tensor<2x3xf32> + %1 = "tfl.minimum"(%0, %cst1) : (tensor<2x3xf32>, tensor) -> tensor<2x3xf32> + return %1 : tensor<2x3xf32> + + // CHECK: %[[relu_n1_to_1:[0-9].*]] = "tfl.relu_n1_to_1" +} + +// CHECK-LABEL: Relu1_2 +func @Relu1_2(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { + %cst = constant dense<-1.0> : tensor + %cst1 = constant dense<1.0> : tensor + %0 = "tfl.minimum"(%arg0, %cst1) : (tensor<2x3xf32>, tensor) -> tensor<2x3xf32> + %1 = "tfl.maximum"(%0, %cst) : (tensor<2x3xf32>, tensor) -> tensor<2x3xf32> + return %1 : tensor<2x3xf32> + + // CHECK: %[[relu_n1_to_1:[0-9].*]] = "tfl.relu_n1_to_1" +} diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.cc index b7fa6245cbb..57ce43ec28a 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.cc @@ -30,7 +30,7 @@ opt output_file_name("o", llvm::cl::desc(""), opt use_splatted_constant( "use-splatted-constant", llvm::cl::desc( - "Replace constants with randonmly generated splatted tensors"), + "Replace constants with randomly generated splatted tensors"), llvm::cl::init(false), llvm::cl::Hidden); // NOLINTNEXTLINE opt input_mlir( diff --git a/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc b/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc index 6c51b5fb1c6..63cf4240224 100644 --- a/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc +++ b/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc @@ -282,7 +282,7 @@ struct OphintCompositeOp { // Since we have different aggregation strategies, e.g., "first", "last", // "stack". We don't somehow aggregated to get the outputs for the funcOp. // This function is simply compute the RankedTensorType (shape & element type) - std::map GetAggregatedOuputTypes(OpBuilder* builder) { + std::map GetAggregatedOutputTypes(OpBuilder* builder) { std::map aggregated_output_types; for (const auto& kv : outputs) { const AggregatedOperand& operand = kv.second; @@ -387,11 +387,12 @@ struct OphintCompositeOp { // inputs/outputs indicate edges) Assume the graph is acyclic. The preprocess // does the following: // Compute each operations's in-degress (how many input nodes they're taken) -// Get all consumer operations for every operations. (operation_to_ouputs) +// Get all consumer operations for every operations. (operation_to_outputs) // Get the init_queue (those operations will be processed first). 
void PreprocessTopoSortGraph( Block* block, std::queue* init_queue, - llvm::DenseMap>* operation_to_ouputs, + llvm::DenseMap>* + operation_to_outputs, llvm::DenseMap* operation_to_in_degrees) { for (auto& op : *block) { if (&op == block->getTerminator()) continue; @@ -412,9 +413,9 @@ void PreprocessTopoSortGraph( } operation_to_in_degrees->try_emplace(&op, input_ops.size()); for (auto* input_op : input_ops) { - auto preceeding_op_it = operation_to_ouputs->find(input_op); - if (preceeding_op_it == operation_to_ouputs->end()) { - auto result = operation_to_ouputs->try_emplace( + auto preceeding_op_it = operation_to_outputs->find(input_op); + if (preceeding_op_it == operation_to_outputs->end()) { + auto result = operation_to_outputs->try_emplace( input_op, llvm::DenseSet()); preceeding_op_it = result.first; } @@ -442,19 +443,19 @@ bool IsSideEffectOp(Operation* op) { // Also assume the block has no arguments. LogicalResult TopoSortOperations(OpBuilder* builder) { std::queue init_queue; - llvm::DenseMap> operation_to_ouputs; + llvm::DenseMap> operation_to_outputs; llvm::DenseMap operation_to_in_degrees; std::vector sorted_ops; PreprocessTopoSortGraph(builder->getBlock(), &init_queue, - &operation_to_ouputs, &operation_to_in_degrees); + &operation_to_outputs, &operation_to_in_degrees); while (!init_queue.empty()) { Operation* current_op = init_queue.front(); init_queue.pop(); sorted_ops.push_back(current_op); - auto current_op_to_output_it = operation_to_ouputs.find(current_op); - if (current_op_to_output_it == operation_to_ouputs.end()) { + auto current_op_to_output_it = operation_to_outputs.find(current_op); + if (current_op_to_output_it == operation_to_outputs.end()) { continue; } for (Operation* output_op : current_op_to_output_it->second) { @@ -467,7 +468,7 @@ LogicalResult TopoSortOperations(OpBuilder* builder) { operation_to_in_degrees.erase(output_op_it); } } - operation_to_ouputs.erase(current_op_to_output_it); + operation_to_outputs.erase(current_op_to_output_it); } // Before we performs the sort. We need to make sure we didn't mess the @@ -629,11 +630,11 @@ LogicalResult ConvertOphintToStub(StringRef stub_name, // Step 4, get aggregated output types. const std::map& aggregated_output_types = - ophint_composite_op.GetAggregatedOuputTypes(builder); + ophint_composite_op.GetAggregatedOutputTypes(builder); // Step 5, create & place the fused op and rewire the inputs. // Here we use a funcOp to represent the fused op. This "funcOp" will be - // coonverted to other ops (like UnidirectionalSequenceRNNOp) in the + // converted to other ops (like UnidirectionalSequenceRNNOp) in the // legalization phase. 
Operation* inserted_before_op = ophint_composite_op.GetFirstOutputOp(); Operation* fused_op = BuildFusedFuncOp( diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc index 8d3a78f49fe..ed3a9ea5000 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc @@ -191,10 +191,10 @@ LogicalResult BuildUnidirectionalSequenceLSTMOp(FuncOp composite_func_op, return success(); } -LogicalResult ConvertTfLiteFusedOpIfAvaiable(StringRef func_name, - FuncOp composite_func_op, - CallOp call_op, - OpBuilder* builder) { +LogicalResult ConvertTfLiteFusedOpIfAvailable(StringRef func_name, + FuncOp composite_func_op, + CallOp call_op, + OpBuilder* builder) { Operation* fused_op = nullptr; if (func_name == kUnidirectionalSequenceRnn) { // TODO(renjieliu): Validate the func op inputs. @@ -243,8 +243,8 @@ LogicalResult ConvertCallOps(llvm::StringMap* composite_func_ops, StringRef func_name = composite_func_op.getAttr(kTfLiteFunctionName) .cast() .getValue(); - if (failed(ConvertTfLiteFusedOpIfAvaiable(func_name, composite_func_op, - call_op, &builder))) + if (failed(ConvertTfLiteFusedOpIfAvailable(func_name, composite_func_op, + call_op, &builder))) return failure(); composite_func_ops->erase(it); diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc index 43a84b4406a..d8697a8c4e0 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc @@ -140,6 +140,8 @@ ElementsAttr ExpandTo4DForDepthwiseConv(Attribute a) { return ExpandTo4DForConvImpl(a, true); } +// Returns shape of a ranked tensor. +// Precondition: output_val's type is a ranked tensor. DenseElementsAttr GetShape(Value *output_val) { auto output_type = output_val->getType().cast(); auto shape_vector = output_type.getShape(); diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index 78a14f3b409..905f01d8413 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -267,9 +267,27 @@ multiclass FuseTileBroadcastIntoFollowingBinary { foreach BroadcastingOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp] in defm : FuseTileBroadcastIntoFollowingBinary; +// Returns shape of a ranked tensor. +// If called without a ranked tensor, it will fail. 
def GetShape: NativeCodeCall<"GetShape($0)">; def : Pat<(TFL_SqueezeOp:$squeeze_op $input, $squeeze_dims), (TFL_ReshapeOp $input, (ConstantOp (GetShape $squeeze_op))), [(AnyStaticShapeTensor $squeeze_op)]>; + +class ValueEquals<string val> : Constraint<CPred< + "$0.cast<DenseElementsAttr>().getNumElements() == 1 &&" + "*$0.cast<DenseElementsAttr>().getValues<float>().begin() == " # val>>; + +def : Pat<(TFL_MinimumOp (TFL_MaximumOp $input, + (ConstantOp $NegOne)), + (ConstantOp $One)), + (TFL_Relu1Op $input), + [(ValueEquals<"-1"> $NegOne), (ValueEquals<"1"> $One)]>; + +def : Pat<(TFL_MaximumOp (TFL_MinimumOp $input, + (ConstantOp $One)), + (ConstantOp $NegOne)), + (TFL_Relu1Op $input), + [(ValueEquals<"-1"> $NegOne), (ValueEquals<"1"> $One)]>; diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc index 5701b8c1154..c299064a136 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc @@ -68,11 +68,11 @@ class ConvertEmbeddedLookupFunc { if (func_.getNumArguments() != 2) { return func_.emitError() << "Invalid number of arguments in the embedding " - "matmal composite function"; + "matmul composite function"; } if (func_.getType().getNumResults() != 1) { return func_.emitError() << "Invalid number of results in the embedding " - "matmal composite function"; + "matmul composite function"; } return success(); } diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc index 44d1cecfd3b..5d139f83933 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc @@ -34,7 +34,7 @@ limitations under the License. // NOLINTNEXTLINE static llvm::cl::list<std::string> quantize_whitelist( "tfl-test-quantize-whitelist", llvm::cl::value_desc("list"), - llvm::cl::desc("comma seprarated list of whitelisted functions to be " + llvm::cl::desc("comma separated list of whitelisted functions to be " "quantized. Only used in tests"), llvm::cl::CommaSeparated); diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc index 3da45e930d3..823efdc3ef5 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc @@ -400,7 +400,7 @@ class ConvertTFDepthwiseConv2dNative } }; -// StridedSlice can have complicated atributes like begin_axis_mask, +// StridedSlice can have complicated attributes like begin_axis_mask, // end_axis_mask, ellipsis_axis_mask, new_axis_mask, shrink_axis_mask. These // masks will complicate the strided_slice computation logic, we can simplify // the logic by inserting a reshape op to pad the inputs so strided_slice can diff --git a/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.cc b/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.cc index 8b354cc9875..61d33a5233e 100644 --- a/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.cc +++ b/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.cc @@ -247,7 +247,7 @@ PatternMatchResult ConvertTFBatchMatMulOp::matchAndRewrite( } if (lhs_shape[dims_a - 1] != rhs_shape[dims_b - 2]) { - // Input dimensions must be compatible for multipication. + // Input dimensions must be compatible for multiplication. 
return this->matchFailure(); } diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils.h b/tensorflow/compiler/mlir/lite/utils/lstm_utils.h index 30449328cec..235d4387faf 100644 --- a/tensorflow/compiler/mlir/lite/utils/lstm_utils.h +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils.h @@ -126,7 +126,7 @@ class ConvertLSTMCellSimpleToFusedLSTM { Value* input2cell_; Value* input2output_; - // reccurrent -> cifg + // recurrent -> cifg Value* rec2input_; Value* rec2forget_; Value* rec2cell_; diff --git a/tensorflow/compiler/mlir/runlit.site.cfg.py b/tensorflow/compiler/mlir/runlit.site.cfg.py index 1ae06e36c25..e14199ed43b 100644 --- a/tensorflow/compiler/mlir/runlit.site.cfg.py +++ b/tensorflow/compiler/mlir/runlit.site.cfg.py @@ -28,7 +28,7 @@ config.llvm_tools_dir = os.path.join(os.environ['TEST_SRCDIR'], 'llvm') config.mlir_obj_root = os.path.join(os.environ['TEST_SRCDIR']) config.mlir_tools_dir = os.path.join(os.environ['TEST_SRCDIR'], 'local_config_mlir') -# TODO(jpienaar): Replace with sufffices in build rule. +# TODO(jpienaar): Replace with suffices in build rule. config.suffixes = ['.td', '.mlir', '.pbtxt'] mlir_tf_tools_dirs = [ diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 683c66eced3..5484988d0f5 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -217,6 +217,7 @@ cc_library( "transforms/shape_inference.cc", "transforms/shape_inference_pass.cc", "transforms/sink_constant.cc", + "transforms/test_side_effect_analysis.cc", "transforms/tpu_cluster_formation.cc", "transforms/tpu_merge_variables_with_execute.cc", "transforms/tpu_rewrite_pass.cc", @@ -239,6 +240,7 @@ cc_library( ":error_util", ":export_tf_dialect_op", ":mangling_util", + ":side_effect_analysis", ":tensorflow", ":tensorflow_optimize_inc_gen", ":tpu_rewrite_device_util", @@ -467,6 +469,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:types", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", @@ -669,6 +672,7 @@ cc_library( ":import_utils", ":mangling_util", ":mlir_roundtrip_flags", + "//tensorflow/core:graph", "//tensorflow/core:lib_proto_parsing", "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", @@ -981,3 +985,19 @@ cc_library( "@local_config_mlir//:Pass", ], ) + +cc_library( + name = "side_effect_analysis", + srcs = ["analysis/side_effect_analysis.cc"], + hdrs = ["analysis/side_effect_analysis.h"], + deps = [ + ":tensorflow", + "//tensorflow/compiler/tf2xla:resource_operation_table", + "//tensorflow/core:framework", + "@com_google_absl//absl/strings", + "@llvm//:support", + "@local_config_mlir//:IR", + "@local_config_mlir//:Pass", + "@local_config_mlir//:Support", + ], +) diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc new file mode 100644 index 00000000000..8d43c9330d0 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc @@ -0,0 +1,374 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" + +#include +#include + +#include "absl/strings/str_cat.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "mlir/IR/Attributes.h" // TF:local_config_mlir +#include "mlir/IR/Block.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:local_config_mlir +#include "mlir/IR/Location.h" // TF:local_config_mlir +#include "mlir/IR/Operation.h" // TF:local_config_mlir +#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir +#include "mlir/Support/LLVM.h" // TF:local_config_mlir +#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/tf2xla/resource_operation_table.h" +#include "tensorflow/core/framework/resource_mgr.h" + +namespace mlir { +namespace TF { + +namespace { + +constexpr int64_t kUnknownResourceId = -1; + +// Returns if a VarHandleOp is anonymous, which means it always creates a new +// variable. +bool IsResourceHandleAnonymous(TF::VarHandleOp handle) { + return handle.shared_name() == tensorflow::ResourceHandle::ANONYMOUS_NAME; +} + +// Returns a string unique identifier for a non-anonymous VarHandleOp. +std::string GetVarHandleStringId(TF::VarHandleOp handle) { + auto device = handle.getAttrOfType("device"); + return absl::StrCat(handle.container().str(), "/", handle.shared_name().str(), + "/", device ? device.getValue().str() : std::string("")); +} + +// Finds a unique ID for a VarHandleOp's output. If it is anonymous, always +// creates a new ID; otherwise, tries to reuse the existing ID for the +// referenced variable if it exists, or creates a new one if not. +int64_t GetOrCreateIdForVarHandle(TF::VarHandleOp handle, int64_t* next_id, + llvm::StringMap* name_id_map) { + // Always create a new ID for anonymous handle. + if (IsResourceHandleAnonymous(handle)) return (*next_id)++; + + auto name = GetVarHandleStringId(handle); + auto emplace_res = name_id_map->try_emplace(name, *next_id); + // New ID created, increment next_id. + if (emplace_res.second) ++(*next_id); + return emplace_res.first->second; +} + +} // namespace + +ResourceAliasAnalysis::ResourceAliasAnalysis(Operation* op) { + auto func_op = llvm::dyn_cast(op); + if (!func_op) return; + AnalyzeFunction(func_op); +} + +void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) { + // This function populates resource_value_to_ids_. + // + // TODO(yuanzx): Pass variable aliasing information to functions so we can + // properly resolve aliasing arguments. + // + // Before having that, we assume function arguments do not alias each other. 
+ int64_t next_unique_id = 0; + for (auto arg : func_op.getArguments()) { + if (!mlir::getElementTypeOrSelf(arg->getType()).isa()) + continue; + resource_value_to_ids_[arg].insert(next_unique_id++); + } + llvm::StringMap var_handle_name_id_map; + auto forward_input_to_output = [&](Value* operand, Value* result) { + if (!mlir::getElementTypeOrSelf(result->getType()).isa()) + return; + auto operand_it = resource_value_to_ids_.find(operand); + assert(operand_it != resource_value_to_ids_.end() && + "A resource-type output does not have the corresponding " + "resource-type input."); + resource_value_to_ids_[result].insert(operand_it->getSecond().begin(), + operand_it->getSecond().end()); + }; + // TODO(yuanzx): Consider control-flow ops. + func_op.walk([&](Operation* op) { + if (auto var_handle = llvm::dyn_cast(op)) { + resource_value_to_ids_[var_handle.resource()].insert( + GetOrCreateIdForVarHandle(var_handle, &next_unique_id, + &var_handle_name_id_map)); + } else if (llvm::isa(op) || + llvm::isa(op)) { + for (auto operand_and_result : + llvm::zip(op->getOperands(), op->getResults())) { + forward_input_to_output(std::get<0>(operand_and_result), + std::get<1>(operand_and_result)); + } + } else { + for (auto result : op->getResults()) { + if (!mlir::getElementTypeOrSelf(result->getType()) + .isa()) + continue; + resource_value_to_ids_[result].insert(kUnknownResourceId); + } + } + }); +} + +bool ResourceAliasAnalysis::IsUnknownResource(const Value* resource) const { + auto it = resource_value_to_ids_.find(resource); + assert(it != resource_value_to_ids_.end() && !it->getSecond().empty()); + // The set is sorted so we only need to check the first element since + // kUnknownResourceId < 0. + static_assert(kUnknownResourceId < 0, + "kUnknownResourceId should be negative"); + return *it->getSecond().begin() == kUnknownResourceId; +} + +const llvm::SmallSet& ResourceAliasAnalysis::GetResourceUniqueIds( + const Value* resource) const { + auto it = resource_value_to_ids_.find(resource); + assert(it != resource_value_to_ids_.end() && "Unseen resource was queried"); + return it->getSecond(); +} + +namespace { + +// Returns a set that contains only kUnknownResourceId. +llvm::SmallDenseSet UnknownResourceSet() { + llvm::SmallDenseSet unknown_set; + unknown_set.insert(kUnknownResourceId); + return unknown_set; +} + +// Returns all resources that could be accessed by op, or UnknownResourceSet() +// if we cannot find all of them. +llvm::SmallDenseSet FindAccessedResources( + Operation* op, const ResourceAliasAnalysis& alias_analysis) { + llvm::SmallDenseSet resources; + + for (auto operand : op->getOperands()) { + if (!mlir::getElementTypeOrSelf(operand->getType()).isa()) + continue; + if (alias_analysis.IsUnknownResource(operand)) return UnknownResourceSet(); + const auto& ids = alias_analysis.GetResourceUniqueIds(operand); + resources.insert(ids.begin(), ids.end()); + } + for (auto result : op->getResults()) { + if (!mlir::getElementTypeOrSelf(result->getType()).isa()) + continue; + if (alias_analysis.IsUnknownResource(result)) return UnknownResourceSet(); + const auto& ids = alias_analysis.GetResourceUniqueIds(result); + resources.insert(ids.begin(), ids.end()); + } + return resources; +} + +// Returns an XlaResourceOpInfo (or nullptr if it does not exist) that specifies +// the resource access type of the op. It tells whether the op is read only, +// etc. +// +// TODO(yuanzx): Define this information in a different place. Currently we use +// tensorflow/compiler/tf2xla/resource_operation_table.h. 
+const tensorflow::XlaResourceOpInfo* GetResourceInfoForOp(Operation* op) { + auto op_name = op->getName().getStringRef().str(); + if (op->getName().getDialect() != + TF::TensorFlowDialect::getDialectNamespace()) { + return nullptr; + } + return tensorflow::GetResourceOpInfoForOp( + op->getName().getStringRef().split('.').second.str()); +} + +// Returns whether `op` accesses resources and it is known to be read-only. +bool OpIsReadOnly(Operation* op) { + auto resource_op_info = GetResourceInfoForOp(op); + return resource_op_info && + resource_op_info->kind() == tensorflow::XlaResourceOpKind::kRead; +} + +// Returns if `op` is a resource declaration. +bool OpIsDeclaration(Operation* op, + const ResourceAliasAnalysis& alias_analysis) { + // TODO(yuanzx): Add other types of resources. + return llvm::isa(op) || + ((llvm::isa(op) || llvm::isa(op)) && + !FindAccessedResources(op, alias_analysis).empty()); +} + +} // namespace + +void SideEffectAnalysis::TrackAccess(int64_t resource_id, Operation* op, + bool read_only) { + if (resource_id == kUnknownResourceId) { + if (read_only) { + // New unknown read is not tracked by any known resource access. + for (auto& entry : per_resource_access_info_) { + entry.getSecond().tracked_last_unknown_read = false; + } + } else { + // Unknown write can clear all other tracked information, since it acts + // like a barrier. + per_resource_access_info_.clear(); + } + } + auto& info = per_resource_access_info_[resource_id]; + if (read_only) { + info.reads_since_last_write.push_back(op); + // Resource read must have carried control dependencies of unknown write. + info.tracked_last_unknown_write = true; + } else { + // Resource write must have carried control dependencies of unknown access. + info.tracked_last_unknown_write = true; + info.tracked_last_unknown_read = true; + info.last_write = op; + info.reads_since_last_write.clear(); + } +} + +void SideEffectAnalysis::AddPredecessorsForAccess(int64_t resource_id, + Operation* op, + bool read_only) { + auto it = per_resource_access_info_.find(resource_id); + if (it == per_resource_access_info_.end()) return; + const auto& access_info = it->getSecond(); + auto& control_predecessors = control_predecessors_[op]; + bool read_tracked = false; + if (!read_only) { + control_predecessors.insert(access_info.reads_since_last_write.begin(), + access_info.reads_since_last_write.end()); + read_tracked = !access_info.reads_since_last_write.empty(); + } + if (access_info.last_write && !read_tracked) { + control_predecessors.insert(access_info.last_write); + } +} + +void SideEffectAnalysis::AnalyzeFunction( + FuncOp func_op, const ResourceAliasAnalysis& alias_analysis) { + // This function populates control_predecessors_ and control_successors_ by + // walking through func_op's body, and tracking resource accesses in + // per_resource_access_info_. + + // Returns whether an access to `resource` can skip control edges from + // previous accesses to unknown resources, because earlier accesses to + // `resource` already indirectly tracked previous accesses to unknown + // resources. `read_only` specifies the type of access of the current op being + // considered. 
+ auto unknown_access_indirectly_tracked_by_resource = [&](int64_t resource, + bool read_only) { + auto it = per_resource_access_info_.find(resource); + if (it == per_resource_access_info_.end()) return false; + auto unknown_it = per_resource_access_info_.find(kUnknownResourceId); + const bool no_unknown_read = + unknown_it == per_resource_access_info_.end() || + unknown_it->getSecond().reads_since_last_write.empty(); + return read_only + ? it->second.tracked_last_unknown_write + : it->second.tracked_last_unknown_write && + (it->second.tracked_last_unknown_read || no_unknown_read); + }; + + func_op.walk([&](Operation* op) { + // We do not need explicit control edges for declaration ops. + if (OpIsDeclaration(op, alias_analysis)) return; + + auto resource_op_info = GetResourceInfoForOp(op); + if (!resource_op_info && op->hasNoSideEffect()) return; + + llvm::SmallDenseSet resources = + resource_op_info ? FindAccessedResources(op, alias_analysis) + : UnknownResourceSet(); + assert(!resources.empty()); + const bool is_unknown = resources.count(kUnknownResourceId) > 0; + const bool read_only = OpIsReadOnly(op); + bool indirectly_tracked_unknown_access = false; + // First add edges from known resources. + if (is_unknown) { + for (auto& entry : per_resource_access_info_) { + if (entry.getFirst() == kUnknownResourceId) continue; + AddPredecessorsForAccess(entry.getFirst(), op, read_only); + indirectly_tracked_unknown_access |= + unknown_access_indirectly_tracked_by_resource(entry.getFirst(), + read_only); + } + } else { + for (int64_t resource : resources) { + AddPredecessorsForAccess(resource, op, read_only); + indirectly_tracked_unknown_access |= + unknown_access_indirectly_tracked_by_resource(resource, read_only); + // Update access info for known resources. + TrackAccess(resource, op, read_only); + } + } + // If not indirectly tracked, add edges from the unknown resource. + if (!indirectly_tracked_unknown_access) { + AddPredecessorsForAccess(kUnknownResourceId, op, read_only); + } + if (is_unknown) { + // Update access info for unknown resource. + TrackAccess(kUnknownResourceId, op, read_only); + } + }); + + // Populate control_successors_ based on control_predecessors_. 
+  for (auto& entry : control_predecessors_) {
+    auto op = entry.getFirst();
+    for (auto predecessor : entry.getSecond()) {
+      control_successors_[predecessor].insert(op);
+    }
+  }
+}
+
+llvm::SmallVector SideEffectAnalysis::DirectControlPredecessors(
+    Operation* op, llvm::function_ref filter) const {
+  llvm::SmallVector result;
+  auto it = control_predecessors_.find(op);
+  if (it == control_predecessors_.end()) return result;
+  result.reserve(it->getSecond().size());
+  for (auto predecessor : it->getSecond()) {
+    if (!filter || filter(predecessor)) result.push_back(predecessor);
+  }
+  llvm::sort(result,
+             [](Operation* a, Operation* b) { return a->isBeforeInBlock(b); });
+  return result;
+}
+
+llvm::SmallVector SideEffectAnalysis::DirectControlSuccessors(
+    Operation* op, llvm::function_ref filter) const {
+  llvm::SmallVector result;
+  auto it = control_successors_.find(op);
+  if (it == control_successors_.end()) return result;
+  result.reserve(it->getSecond().size());
+  for (auto successor : it->getSecond()) {
+    if (!filter || filter(successor)) result.push_back(successor);
+  }
+  llvm::sort(result,
+             [](Operation* a, Operation* b) { return a->isBeforeInBlock(b); });
+  return result;
+}
+
+SideEffectAnalysis::SideEffectAnalysis(Operation* op) {
+  auto func_op = llvm::dyn_cast(op);
+  if (!func_op) return;
+  ResourceAliasAnalysis alias_analysis(op);
+  AnalyzeFunction(func_op, alias_analysis);
+}
+
+}  // namespace TF
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h
new file mode 100644
index 00000000000..5eee28a6ae0
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h
@@ -0,0 +1,125 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_SIDE_EFFECT_ANALYSIS_H_
+#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_SIDE_EFFECT_ANALYSIS_H_
+
+#include
+#include
+
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringMap.h"
+#include "mlir/IR/Function.h"  // TF:local_config_mlir
+#include "mlir/IR/Operation.h"  // TF:local_config_mlir
+#include "mlir/IR/Region.h"  // TF:local_config_mlir
+#include "mlir/Support/LogicalResult.h"  // TF:local_config_mlir
+
+namespace mlir {
+namespace TF {
+
+// An analysis that runs on a function and maps each resource-type value to a
+// set of unique int64_t IDs representing the possible resources it could alias.
+class ResourceAliasAnalysis {
+ public:
+  explicit ResourceAliasAnalysis(Operation* op);
+  ~ResourceAliasAnalysis() = default;
+  ResourceAliasAnalysis(ResourceAliasAnalysis&&) = default;
+
+  // Returns whether the analysis fails to resolve a resource-type value.
+  bool IsUnknownResource(const Value* resource) const;
+
+  // Returns the set of unique IDs which `resource` could alias. Requires that
+  // IsUnknownResource(resource) == false.
+  const llvm::SmallSet& GetResourceUniqueIds(
+      const Value* resource) const;
+
+ private:
+  ResourceAliasAnalysis() = default;
+
+  // Runs the analysis on `func_op` and populates resource_value_to_ids_.
+  void AnalyzeFunction(FuncOp func_op);
+
+  // Maps each resource-type value to a set of unique IDs that it could alias.
+  llvm::SmallDenseMap, 8>
+      resource_value_to_ids_;
+};
+
+// An analysis that runs on a function and infers the control predecessors and
+// successors for each op, based on side effects on known and unknown
+// resources. Side-effecting ops on unknown resources are conservatively
+// treated as interfering with all known resource op accesses. It distinguishes
+// accesses based on whether they are read-only, and read-only ops do not
+// interfere with each other.
+class SideEffectAnalysis {
+ public:
+  explicit SideEffectAnalysis(Operation* op);
+  SideEffectAnalysis(SideEffectAnalysis&& other) = default;
+  ~SideEffectAnalysis() = default;
+
+  // Returns a vector of ops that are direct control predecessors of `op`,
+  // sorted in program order. If `filter` is provided, only predecessors that
+  // pass the filter (returning true) will be included.
+  llvm::SmallVector DirectControlPredecessors(
+      Operation* op,
+      llvm::function_ref filter = nullptr) const;
+
+  // Returns a vector of ops that are direct control successors of `op`,
+  // sorted in program order. If `filter` is provided, only successors that
+  // pass the filter (returning true) will be included.
+  llvm::SmallVector DirectControlSuccessors(
+      Operation* op,
+      llvm::function_ref filter = nullptr) const;
+
+ private:
+  // Runs the analysis on `func_op` and populates control_predecessors_ and
+  // control_successors_.
+  void AnalyzeFunction(FuncOp func_op,
+                       const ResourceAliasAnalysis& alias_analysis);
+
+  // Updates control_predecessors_ for `op` that is being visited, on the
+  // given `resource_id`.
+  void AddPredecessorsForAccess(int64_t resource_id, Operation* op,
+                                bool read_only);
+
+  // Adds op's access to per_resource_access_info_.
+  void TrackAccess(int64_t resource_id, Operation* op, bool read_only);
+
+  // Maps from an op to its control predecessors.
+  llvm::SmallDenseMap, 8>
+      control_predecessors_;
+  // Maps from an op to its control successors.
+  llvm::SmallDenseMap, 8>
+      control_successors_;
+
+  // Internal per-resource data structure used when we build the dependencies.
+  struct PerResourceAcessInfo {
+    // Last op that writes the resource before the current op being analyzed.
+    Operation* last_write = nullptr;
+    // Read ops since last_write before the current op being analyzed.
+    llvm::SmallVector reads_since_last_write;
+    // Whether previous accesses of this resource already tracked the last
+    // unknown read/write.
+ bool tracked_last_unknown_read = false; + bool tracked_last_unknown_write = false; + }; + llvm::SmallDenseMap + per_resource_access_info_; +}; + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_SIDE_EFFECT_ANALYSIS_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/dialect_registration.cc b/tensorflow/compiler/mlir/tensorflow/ir/dialect_registration.cc index 4bac3be1d1e..ac468d9810c 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/dialect_registration.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/dialect_registration.cc @@ -26,7 +26,7 @@ static DialectRegistration tf_control_flow_ops; static DialectRegistration tf_ops; static DialectRegistration - tf_excutor_dialect; + tf_executor_dialect; static DialectRegistration tf_device_dialect; static DialectRegistration diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc index da3d26c1b72..20483691a92 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc @@ -26,11 +26,13 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/Support/SMLoc.h" #include "mlir/IR/Attributes.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:local_config_mlir #include "mlir/IR/MLIRContext.h" // TF:local_config_mlir #include "mlir/IR/OpDefinition.h" // TF:local_config_mlir #include "mlir/IR/OpImplementation.h" // TF:local_config_mlir #include "mlir/IR/OperationSupport.h" // TF:local_config_mlir #include "mlir/IR/PatternMatch.h" // TF:local_config_mlir +#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir #include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir #include "mlir/IR/Types.h" // TF:local_config_mlir #include "mlir/IR/Value.h" // TF:local_config_mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index d15cb91edca..cdc545d5681 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -576,6 +576,44 @@ endian orderings will give different results. let hasCanonicalizer = 1; } +def TF_BitwiseOrOp : TF_Op<"BitwiseOr", [Broadcastable, Commutative, NoSideEffect]>, + WithBroadcastableBinOpBuilder { + let summary = "Elementwise computes the bitwise OR of `x` and `y`."; + + let description = [{ +The result will have those bits set, that are set in `x`, `y` or both. The +computation is performed on the underlying representations of `x` and `y`. + +For example: + +```python +import tensorflow as tf +from tensorflow.python.ops import bitwise_ops +dtype_list = [tf.int8, tf.int16, tf.int32, tf.int64, + tf.uint8, tf.uint16, tf.uint32, tf.uint64] + +for dtype in dtype_list: + lhs = tf.constant([0, 5, 3, 14], dtype=dtype) + rhs = tf.constant([5, 0, 7, 11], dtype=dtype) + exp = tf.constant([5, 5, 7, 15], dtype=tf.float32) + + res = bitwise_ops.bitwise_or(lhs, rhs) + tf.assert_equal(tf.cast(res, tf.float32), exp) # TRUE +``` + }]; + + let arguments = (ins + TF_IntTensor:$x, + TF_IntTensor:$y + ); + + let results = (outs + TF_IntTensor:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_BroadcastGradientArgsOp : TF_Op<"BroadcastGradientArgs", [NoSideEffect]> { let summary = [{ Return the reduction indices for computing gradients of s0 op s1 with broadcast. 
@@ -1320,6 +1358,10 @@ Comparison with `numpy.einsum`: TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>; + + let verifier = [{ + return Verify(*this); + }]; } def TF_EluOp : TF_Op<"Elu", [NoSideEffect, SameOperandsAndResultType]> { @@ -5726,6 +5768,8 @@ If two elements are equal, the lower-index element appears first. ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + + let verifier = [{ return Verify(*this); }]; } def TF_TransposeOp : TF_Op<"Transpose", [NoSideEffect]> { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td index b2d99acfb4a..c3a51613357 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td @@ -39,7 +39,7 @@ This dialect maps to TensorFlow operations. Invariants: * All values are of Tensor type (in particular, scalars are - represented using zero-dimentional tensors); + represented using zero-dimensional tensors); TODO: Make invariants more structured so that we can reference them in ops. }]; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index 5920ee0d9d1..1bd9accbb78 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -513,7 +513,7 @@ void ConstOp::build(Builder *builder, OperationState &result, Attribute value) { } else if (value.isa() || value.isa() || value.isa()) { // All TensorFlow types must be tensor types. In the build() method, - // we want to provide more flexiblity by allowing attributes of scalar + // we want to provide more flexibility by allowing attributes of scalar // types. But we need to wrap it up with ElementsAttr to construct // valid TensorFlow constants. type = RankedTensorType::get(/*shape=*/{}, value.getType()); @@ -674,6 +674,21 @@ void DivOp::getCanonicalizationPatterns(OwningRewritePatternList &results, results.insert(context); } +//===----------------------------------------------------------------------===// +// EinsumOp +//===----------------------------------------------------------------------===// + +// Verifies that, +// * Arity of the op is at most two. +// +// TODO(hinsu): Verify einsum equation attribute. 
+static LogicalResult Verify(EinsumOp op) { + if (op.N() > 2) { + return op.emitOpError("supports at most two operands"); + } + return success(); +} + //===----------------------------------------------------------------------===// // EmptyTensorListOp //===----------------------------------------------------------------------===// @@ -1683,6 +1698,21 @@ static LogicalResult Verify(TensorListStackOp op) { return success(); } +//===----------------------------------------------------------------------===// +// TopKV2Op +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(TopKV2Op op) { + if (!HasRankAtLeast(op.input(), 1)) + return op.emitOpError( + "requires input operand to have at least 1 dimension"); + + if (!IsOfRankOrUnranked(op.k(), 0)) + return op.emitOpError("requires k operand to be 0D tensor"); + + return success(); +} + //===----------------------------------------------------------------------===// // TransposeOp //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h index 47004e438eb..e9aaed56afc 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h @@ -41,7 +41,7 @@ class TensorFlowDialect : public Dialect { static StringRef getDialectNamespace() { return "tf"; } - // Gradient attribute ("tf.gradient") in the list of NamedAttibutes in a + // Gradient attribute ("tf.gradient") in the list of NamedAttributes in a // function references to its gradient function. This attribute in TensorFlow // Dialect is used to model TF GradientDef. GetGradientAttrName() returns the // string description of gradient attribute. diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc index ad49c8970cf..c672d624944 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "llvm/Support/Casting.h" #include "llvm/Support/raw_ostream.h" #include "mlir/IR/Attributes.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:local_config_mlir #include "mlir/IR/Function.h" // TF:local_config_mlir #include "mlir/IR/Identifier.h" // TF:local_config_mlir #include "mlir/IR/Module.h" // TF:local_config_mlir diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir index 2db64262094..a2cc33a8201 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir @@ -101,6 +101,17 @@ func @testDifferentCastType(%arg0: tensor<8x16x32x64xf32>) -> (tensor<8x16x32x64 // CHECK: return %0, %1 } +// CHECK-LABEL: testCompatibleCastType +func @testCompatibleCastType(%arg0: tensor) -> (tensor<10xf32>, tensor<10xf32>) { + %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor) -> tensor<10xf32> + %1 = "tf.Cast"(%arg0) {Truncate = true} : (tensor) -> tensor<10xf32> + return %0, %1: tensor<10xf32>, tensor<10xf32> + +// CHECK: %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor) -> tensor<10xf32> +// CHECK: %1 = "tf.Cast"(%arg0) {Truncate = true} : (tensor) -> tensor<10xf32> +// CHECK: return %0, %1 +} + // CHECK-LABEL: testSameCastTypeAcrossBasicBlocks func @testSameCastTypeAcrossBasicBlocks(tensor<8x16x32x64xf32>) -> tensor<8x16x32x64xf32> { ^bb0(%arg0: tensor<8x16x32x64xf32>): diff --git a/tensorflow/compiler/mlir/tensorflow/tests/cluster_formation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/cluster_formation.mlir index 8a5375d6716..4bf6b0d2f28 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/cluster_formation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/cluster_formation.mlir @@ -232,12 +232,12 @@ module { // ----- -// Single device with non-continous instructions in original block. +// Single device with non-continuous instructions in original block. module { - // CHECK-LABEL: func @noncontinoussinglecluster + // CHECK-LABEL: func @noncontinuoussinglecluster // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) - func @noncontinoussinglecluster(%arg0: tensor) -> tensor { + func @noncontinuoussinglecluster(%arg0: tensor) -> tensor { %0 = tf_executor.graph { %1:2 = tf_executor.island { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir b/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir new file mode 100644 index 00000000000..c6eb4663e57 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir @@ -0,0 +1,237 @@ +// RUN: tf-opt -split-input-file -tf-test-side-effect-analysis -verify-diagnostics %s | FileCheck %s --dump-input=fail + +// Tests that the pass tracks control dependencies for reads/writes on the same +// resource. 
+ +// CHECK-LABEL: func @non_aliasing_reads_writes +func @non_aliasing_reads_writes( +// expected-remark@above {{ID: 13}} +// expected-remark@above {{Predecessors: {12}}} + %arg0: tensor<*x!tf.resource>>, + %arg1: tensor<*x!tf.resource>>, + %arg2: tensor<32xf32>) -> (tensor<32xf32>) { + %graph = tf_executor.graph { + // expected-remark@above {{ID: 11}} + // expected-remark@above {{Predecessors: {10}}} + // expected-remark@above {{Successors: {12}}} + // CHECK: tf_executor.island + %island:2 = tf_executor.island { + // expected-remark@above {{ID: 9}} + // expected-remark@above {{Predecessors: {8}}} + // expected-remark@above {{Successors: {10}}} + %read0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + // expected-remark@above {{ID: 0}} + // expected-remark@above {{Successors: {1}}} + "tf.AssignVariableOp"(%arg0, %arg2) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + // expected-remark@above {{ID: 1}} + // expected-remark@above {{Predecessors: {0}}} + // expected-remark@above {{Successors: {6}}} + %read1 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + // expected-remark@above {{ID: 2}} + // expected-remark@above {{Successors: {5}}} + %var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> tensor<*x!tf.resource>> + // expected-remark@above {{ID: 3}} + %read2 = "tf.ReadVariableOp"(%var_handle) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + // expected-remark@above {{ID: 4}} + // expected-remark@above {{Successors: {8}}} + "tf.AssignVariableOp"(%arg1, %read0) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + // expected-remark@above {{ID: 5}} + // expected-remark@above {{Predecessors: {2}}} + // expected-remark@above {{Successors: {8}}} + "tf.AssignVariableOp"(%arg0, %read2) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + // expected-remark@above {{ID: 6}} + // expected-remark@above {{Predecessors: {1}}} + // expected-remark@above {{Successors: {7}}} + %read3 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + // expected-remark@above {{ID: 7}} + // expected-remark@above {{Predecessors: {6}}} + // expected-remark@above {{Successors: {8}}} + tf_executor.yield %read3 : tensor<32xf32> + // expected-remark@above {{ID: 8}} + // expected-remark@above {{Predecessors: {4,5,7}}} + // expected-remark@above {{Successors: {9}}} + } + tf_executor.fetch %island#0 : tensor<32xf32> + // expected-remark@above {{ID: 10}} + // expected-remark@above {{Predecessors: {9}}} + // expected-remark@above {{Successors: {11}}} + } + return %graph : tensor<32xf32> + // expected-remark@above {{ID: 12}} + // expected-remark@above {{Predecessors: {11}}} + // expected-remark@above {{Successors: {13}}} +} + +// ----- + +// Tests that the pass tracks control dependencies for reads/writes on the two +// resource handles that refer to the same variable. 
+ +// CHECK-LABEL: func @aliasing_reads_writes +func @aliasing_reads_writes(%arg0: tensor<32xf32>) -> () { +// expected-remark@above {{ID: 14}} +// expected-remark@above {{Predecessors: {13}}} + tf_executor.graph { + // expected-remark@above {{ID: 12}} + // expected-remark@above {{Predecessors: {11}}} + // expected-remark@above {{Successors: {13}}} + // CHECK: tf_executor.island + %island = tf_executor.island { + // expected-remark@above {{ID: 10}} + // expected-remark@above {{Predecessors: {9}}} + // expected-remark@above {{Successors: {11}}} + %vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> tensor<*x!tf.resource>> + // expected-remark@above {{ID: 0}} + %vh1 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> tensor<*x!tf.resource>> + // expected-remark@above {{ID: 1}} + %vh1_id:2 = "tf.IdentityN"(%vh1, %arg0) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> (tensor<*x!tf.resource>>, tensor<32xf32>) + // expected-remark@above {{ID: 2}} + %read0 = "tf.ReadVariableOp"(%vh0) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + // expected-remark@above {{ID: 3}} + // expected-remark@above {{Successors: {4}}} + "tf.AssignVariableOp"(%vh1_id#0, %arg0) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + // expected-remark@above {{ID: 4}} + // expected-remark@above {{Predecessors: {3}}} + // expected-remark@above {{Successors: {5,6}}} + %read1 = "tf.ReadVariableOp"(%vh0) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + // expected-remark@above {{ID: 5}} + // expected-remark@above {{Predecessors: {4}}} + // expected-remark@above {{Successors: {7}}} + %read2 = "tf.ReadVariableOp"(%vh1) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + // expected-remark@above {{ID: 6}} + // expected-remark@above {{Predecessors: {4}}} + // expected-remark@above {{Successors: {7}}} + "tf.AssignVariableOp"(%vh0, %read2) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + // expected-remark@above {{ID: 7}} + // expected-remark@above {{Predecessors: {5,6}}} + // expected-remark@above {{Successors: {8}}} + "tf.AssignVariableOp"(%vh1_id#0, %read1) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + // expected-remark@above {{ID: 8}} + // expected-remark@above {{Predecessors: {7}}} + // expected-remark@above {{Successors: {9}}} + tf_executor.yield + // expected-remark@above {{ID: 9}} + // expected-remark@above {{Predecessors: {8}}} + // expected-remark@above {{Successors: {10}}} + } + tf_executor.fetch %island : !tf_executor.control + // expected-remark@above {{ID: 11}} + // expected-remark@above {{Predecessors: {10}}} + // expected-remark@above {{Successors: {12}}} + } + return + // expected-remark@above {{ID: 13}} + // expected-remark@above {{Predecessors: {12}}} + // expected-remark@above {{Successors: {14}}} +} + +// ----- + +// Tests that the pass tracks control dependencies for side-effecting on unknown +// resources. 
+ +// CHECK-LABEL: func @unknown_side_effecting_op +func @unknown_side_effecting_op(%arg0: tensor<32xf32>) -> () { +// expected-remark@above {{ID: 13}} +// expected-remark@above {{Predecessors: {12}}} + tf_executor.graph { + // expected-remark@above {{ID: 11}} + // expected-remark@above {{Predecessors: {10}}} + // expected-remark@above {{Successors: {12}}} + // CHECK: tf_executor.island + %island = tf_executor.island { + // expected-remark@above {{ID: 9}} + // expected-remark@above {{Predecessors: {8}}} + // expected-remark@above {{Successors: {10}}} + %vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> tensor<*x!tf.resource>> + // expected-remark@above {{ID: 0}} + %vh1 = "tf.VarHandleOp"() {container = "c", shared_name = "v1"} : () -> tensor<*x!tf.resource>> + // expected-remark@above {{ID: 1}} + %read0 = "tf.ReadVariableOp"(%vh0) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + // expected-remark@above {{ID: 2}} + // expected-remark@above {{Successors: {4}}} + "tf.AssignVariableOp"(%vh1, %arg0) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + // expected-remark@above {{ID: 3}} + // expected-remark@above {{Successors: {4}}} + "tf._UnknownSideEffectingOp_"() : () -> () + // expected-remark@above {{ID: 4}} + // expected-remark@above {{Predecessors: {2,3}}} + // expected-remark@above {{Successors: {5,6}}} + %read1 = "tf.ReadVariableOp"(%vh1) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + // expected-remark@above {{ID: 5}} + // expected-remark@above {{Predecessors: {4}}} + // expected-remark@above {{Successors: {7}}} + "tf.AssignVariableOp"(%vh0, %read1) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + // expected-remark@above {{ID: 6}} + // expected-remark@above {{Predecessors: {4}}} + // expected-remark@above {{Successors: {8}}} + "tf.AssignVariableOp"(%vh1, %read0) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + // expected-remark@above {{ID: 7}} + // expected-remark@above {{Predecessors: {5}}} + // expected-remark@above {{Successors: {8}}} + tf_executor.yield + // expected-remark@above {{ID: 8}} + // expected-remark@above {{Predecessors: {6,7}}} + // expected-remark@above {{Successors: {9}}} + } + tf_executor.fetch %island : !tf_executor.control + // expected-remark@above {{ID: 10}} + // expected-remark@above {{Predecessors: {9}}} + // expected-remark@above {{Successors: {11}}} + } + return + // expected-remark@above {{ID: 12}} + // expected-remark@above {{Predecessors: {11}}} + // expected-remark@above {{Successors: {13}}} +} + +// ----- + +// Tests that the pass tracks control dependencies for read-only ops on unknown +// resources. 
+ +// CHECK-LABEL: func @read_only_unknown_resource +func @read_only_unknown_resource(%arg0: tensor<32xf32>) -> () { +// expected-remark@above {{ID: 10}} +// expected-remark@above {{Predecessors: {9}}} + tf_executor.graph { + // expected-remark@above {{ID: 8}} + // expected-remark@above {{Predecessors: {7}}} + // expected-remark@above {{Successors: {9}}} + // CHECK: tf_executor.island + %island = tf_executor.island { + // expected-remark@above {{ID: 6}} + // expected-remark@above {{Predecessors: {5}}} + // expected-remark@above {{Successors: {7}}} + %vh0 = "tf._UnknownSideEffectingOp_"() : () -> tensor<*x!tf.resource>> + // expected-remark@above {{ID: 0}} + // expected-remark@above {{Successors: {2,3}}} + %vh1 = "tf.VarHandleOp"() {container = "c", shared_name = "v1"} : () -> tensor<*x!tf.resource>> + // expected-remark@above {{ID: 1}} + %read0 = "tf.ReadVariableOp"(%vh0) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + // expected-remark@above {{ID: 2}} + // expected-remark@above {{Predecessors: {0}}} + // expected-remark@above {{Successors: {4}}} + %read1 = "tf.ReadVariableOp"(%vh1) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + // expected-remark@above {{ID: 3}} + // expected-remark@above {{Predecessors: {0}}} + // expected-remark@above {{Successors: {4}}} + "tf.AssignVariableOp"(%vh1, %read0) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + // expected-remark@above {{ID: 4}} + // expected-remark@above {{Predecessors: {2,3}}} + // expected-remark@above {{Successors: {5}}} + tf_executor.yield + // expected-remark@above {{ID: 5}} + // expected-remark@above {{Predecessors: {4}}} + // expected-remark@above {{Successors: {6}}} + } + tf_executor.fetch %island : !tf_executor.control + // expected-remark@above {{ID: 7}} + // expected-remark@above {{Predecessors: {6}}} + // expected-remark@above {{Successors: {8}}} + } + return + // expected-remark@above {{ID: 9}} + // expected-remark@above {{Predecessors: {8}}} + // expected-remark@above {{Successors: {10}}} +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir index b8e7ba71198..e064c1a53ef 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir @@ -1650,3 +1650,27 @@ func @testSplitSmallSplitDim(%input: tensor<4x8xf32>) { %0:3 = "tf.Split"(%cst, %input) : (tensor, tensor<4x8xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) return } + +// ----- + +func @testTernaryEinsum(%arg0: tensor<2x3xf32>){ + // expected-error @+1 {{supports at most two operands}} + %0 = "tf.Einsum"(%arg0, %arg0, %arg0) {equation = "ab,cd,ef->"} : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) -> (tensor<*xf32>) + return +} + +// ----- + +func @testTopKV2WrongInputRank(%input: tensor, %k: tensor) { + // expected-error @+1 {{op requires input operand to have at least 1 dimension}} + %0:2 = "tf.TopKV2"(%input, %k) : (tensor, tensor) -> (tensor<*xf32>, tensor<*xi32>) + return +} + +// ----- + +func @testTopKV2WrongKRank(%input: tensor<8xf32>, %k: tensor<5xi32>) { + // expected-error @+1 {{op requires k operand to be 0D tensor}} + %0:2 = "tf.TopKV2"(%input, %k) : (tensor<8xf32>, tensor<5xi32>) -> (tensor<*xf32>, tensor<*xi32>) + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir index 022d181eb24..d0aa1414723 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir +++ 
b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir @@ -523,7 +523,7 @@ func @invalid_merge(%arg0: tensor<*x!tf.resource>, %arg1: tensor<4x!tf.resource> // ----- // Check that if result is a ref type, all operands need to be ref too. -func @inavlid_merge(%arg0: tensor<4x!tf.f32ref>, %arg1: tensor<4xf32>) -> tensor<4x!tf.f32ref> { +func @invalid_merge(%arg0: tensor<4x!tf.f32ref>, %arg1: tensor<4xf32>) -> tensor<4x!tf.f32ref> { %result = tf_executor.graph { %value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<4x!tf.f32ref>, tensor<4xf32>) -> (tensor<4x!tf.f32ref>, tensor, !tf_executor.control) // expected-error@-1 {{'tf_executor.Merge' op expects same operand and output element type but got 'tensor<4xf32>' vs 'tensor<4x!tf.f32ref>'}} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc index fccedf4057a..a7f45c41f15 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc @@ -68,14 +68,17 @@ tensorflow::Status TPUBridge(ModuleOp module, bool enable_logging) { namespace TF { tensorflow::Status RunBridgeWithStandardPipeline(ModuleOp module, - bool enable_logging) { + bool enable_logging, + bool enable_inliner) { PassManager bridge(module.getContext()); // Add logger to bridge passmanager. if (enable_logging) bridge.addInstrumentation(std::make_unique()); - CreateTFStandardPipeline(bridge); + StandardPipelineOptions pipeline_options; + pipeline_options.enable_inliner.setValue(enable_inliner); + CreateTFStandardPipeline(bridge, pipeline_options); mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext()); LogicalResult result = bridge.run(module); (void)result; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.h b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.h index 844b9095dba..ff446af24f5 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.h @@ -31,11 +31,13 @@ tensorflow::Status TPUBridge(ModuleOp module, bool enable_logging); namespace TF { -// Run all passes involved in transforming or optimizing an MLIR graph without +// Runs all passes involved in transforming or optimizing an MLIR graph without // any target specialization. When enable_logging is true, enables -// tensorflow::BridgeLogger. +// tensorflow::BridgeLogger. When enable_inliner is true, enables the inliner +// pass. 
tensorflow::Status RunBridgeWithStandardPipeline(ModuleOp module, - bool enable_logging); + bool enable_logging, + bool enable_inliner); } // namespace TF diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge_pass.cc index b19bb0f8cd5..0208dc2f579 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge_pass.cc @@ -29,10 +29,4 @@ mlir::PassPipelineRegistration<> tpu_pipeline( "that it is suitable for targeting TPUs.", mlir::TFTPU::CreateTPUBridge); -mlir::PassPipelineRegistration<> standard_pipeline( - "tf-standard-bridge", - "Run all passes involved in transforming or optimizing an MLIR graph" - "without any target specialization.", - mlir::TF::CreateTFStandardPipeline); - } // anonymous namespace diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td index beb7583fc57..7c38b78f239 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td @@ -22,6 +22,9 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" def SingleResultAndOperandHaveSameElementType : Constraint< CPred<"getElementTypeOrSelf($0) == getElementTypeOrSelf($1)">>; +def SingleResultAndOperandHaveSameType : Constraint< + CPred<"$0->getType() == $1->getType()">>; + def IsRank2Tensor : Type, "Rank 2 tensor">; //===----------------------------------------------------------------------===// @@ -75,8 +78,7 @@ def BitcastNested : Pat<(TF_BitcastOp (TF_BitcastOp $arg)), def CastSameType : Pat<(TF_CastOp:$res $arg, $truncate), (replaceWithValue $arg), - [(SingleResultAndOperandHaveSameElementType $res, - $arg)]>; + [(SingleResultAndOperandHaveSameType $res, $arg)]>; //===----------------------------------------------------------------------===// // Conj op patterns. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc index c6e3a0ab895..b0420663bde 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc @@ -45,7 +45,8 @@ struct TFOptimizePass : public FunctionPass { } // namespace // NOLINTNEXTLINE - MLIR contract is pass by mutable reference. -void CreateTFStandardPipeline(OpPassManager &pm) { +void CreateTFStandardPipeline(OpPassManager &pm, + const StandardPipelineOptions &options) { OpPassManager &func_pm = pm.nest(); // First operates on the executor dialect: @@ -59,8 +60,12 @@ void CreateTFStandardPipeline(OpPassManager &pm) { // Hopefully there is a single island left, or there wasn't any to begin with. // We now run the optimizer which operates mostly inside islands. func_pm.addPass(createCanonicalizerPass()); - func_pm.addPass(CreateTFOptimizePass()); - func_pm.addPass(createCSEPass()); + if (options.enable_inliner) { + pm.addPass(createInlinerPass()); + } + pm.addNestedPass(CreateTFShapeInferencePass()); + pm.addNestedPass(CreateTFOptimizePass()); + pm.addNestedPass(createCSEPass()); } std::unique_ptr> CreateTFOptimizePass() { @@ -70,7 +75,7 @@ std::unique_ptr> CreateTFOptimizePass() { static PassRegistration pass("tf-optimize", "Optimizes TF."); // Registers a pipeline builder function for the default canonicalize/optimizer. 
-static mlir::PassPipelineRegistration<> pipeline( +static mlir::PassPipelineRegistration pipeline( "tf-standard-pipeline", "Run all the passes involved in transforming/optimizing the graph after " "importing into MLIR, without any target specialization.", diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index 7a5c060f5dc..30ee91f4aea 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -46,10 +46,17 @@ std::unique_ptr> CreateTFShapeInferencePass(); // Optimizes Tensorflow graph. std::unique_ptr> CreateTFOptimizePass(); +struct StandardPipelineOptions : public PassOptions { + Option enable_inliner{*this, "enable-inliner", + llvm::cl::desc("Enable inliner."), + llvm::cl::init(false)}; +}; + // Propagates the pass manager with the passes involved in transforming or // optimizing an MLIR graph without any target specialization. // NOLINTNEXTLINE - MLIR contract is pass by mutable reference. -void CreateTFStandardPipeline(OpPassManager& pm); +void CreateTFStandardPipeline(OpPassManager& pm, + const StandardPipelineOptions& options); } // namespace TF namespace TFControlFlow { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/test_side_effect_analysis.cc b/tensorflow/compiler/mlir/tensorflow/transforms/test_side_effect_analysis.cc new file mode 100644 index 00000000000..f0b7964389d --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/test_side_effect_analysis.cc @@ -0,0 +1,77 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" +#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/Pass/PassManager.h" // TF:local_config_mlir +#include "mlir/Support/LLVM.h" // TF:local_config_mlir +#include "mlir/Transforms/Passes.h" // TF:local_config_mlir +#include "mlir/Transforms/RegionUtils.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" + +namespace mlir { +namespace tf_executor { + +namespace { + +// A pass that adds "Predecessors" and "Successors" remarks for each op based on +// SideEffectAnalysis result. For testing purpose only. 
+struct TestSideEffectAnalysis + : public mlir::FunctionPass { + void runOnFunction() override { + int64_t next_id = 0; + llvm::SmallDenseMap ids; + getFunction().walk([&](Operation* op) { + ids[op] = next_id++; + op->emitRemark("ID: ") << ids[op]; + }); + auto join_ids = [&](const llvm::ArrayRef ops) { + llvm::SmallVector id_vec; + id_vec.reserve(ops.size()); + for (auto op : ops) id_vec.push_back(std::to_string(ids[op])); + return llvm::join(id_vec, ","); + }; + auto& analysis = getAnalysis(); + getFunction().walk([&](Operation* op) { + if (!analysis.DirectControlPredecessors(op).empty()) { + op->emitRemark("Predecessors: ") + << "{" << join_ids(analysis.DirectControlPredecessors(op)) << "}"; + } + if (!analysis.DirectControlSuccessors(op).empty()) { + op->emitRemark("Successors: ") + << "{" << join_ids(analysis.DirectControlSuccessors(op)) << "}"; + } + }); + } +}; + +static mlir::PassRegistration pass( + "tf-test-side-effect-analysis", + "Add remarks based on side-effect analysis result, for testing purpose."); + +} // anonymous namespace + +} // namespace tf_executor +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc index 4b74f3e6ca3..2eb12c80efe 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc @@ -141,7 +141,7 @@ static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); // NOLINTNEXTLINE static llvm::cl::list cl_pass_list( "graph-passes", llvm::cl::value_desc("list"), - llvm::cl::desc("comma seprarated list of GraphOptimizationPass to run."), + llvm::cl::desc("comma separated list of GraphOptimizationPass to run."), llvm::cl::CommaSeparated, llvm::cl::cat(clOptionsCategory)); class GraphOptByNamePass : public GraphOptPass { diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index e735dfa1a4c..da2e6a67445 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/strings/escaping.h" #include "absl/strings/numbers.h" @@ -75,6 +76,7 @@ limitations under the License. 
#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/protobuf.h" @@ -500,7 +502,8 @@ Status ImporterBase::GetInputOutputNodes( TF_RETURN_IF_ERROR(add_node(input.first)); } - for (const auto& output_node_name : specs_.output_arrays) { + for (const auto& output : specs_.outputs) { + auto output_node_name = std::string(ParseTensorName(output).first); TF_RETURN_IF_ERROR(add_node(output_node_name)); } @@ -535,7 +538,7 @@ Status ImporterBase::AddNodesToShapeRefiner() { auto node_name = node->op_def().name(); if (node_name != "Placeholder" && node_name != "LegacyFedInput" && node_name != FunctionLibraryDefinition::kArgOp) { - // We do not handle the case where the input node has multple outputs + // We do not handle the case where the input node has multiple outputs if (node->num_outputs() > 1) { return errors::FailedPrecondition(absl::StrCat( "Input arrays can only have op with single output. Node op:", @@ -1588,7 +1591,7 @@ StatusOr GraphDefImporter::Convert( llvm::SmallVector attrs; if (specs.graph_as_function) { if (specs.prune_unused_nodes || !specs.inputs.empty() || - !specs.output_arrays.empty() || !specs.output_arrays_order.empty()) + !specs.outputs.empty()) return errors::InvalidArgument( "Pruning of graph is currently unsupported when the main graph is " "converted to a function."); @@ -1622,7 +1625,7 @@ StatusOr GraphDefImporter::Convert( // TODO(prakalps): Refactor to keep attribute strings (tf.entry_function, // tf.versions) shared by importer and exporter in a centralized place. // Record the input and output mapping. - if (!specs.inputs.empty() || !specs.output_arrays.empty()) { + if (!specs.inputs.empty() || !specs.outputs.empty()) { mlir::Builder b(context); std::string s; llvm::raw_string_ostream ss(s); @@ -1632,7 +1635,7 @@ StatusOr GraphDefImporter::Convert( ","); auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str())); s.clear(); - mlir::interleave(specs.output_arrays_order, ss, ","); + mlir::interleave(specs.outputs, ss, ","); auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str())); attrs.push_back(b.getNamedAttr("tf.entry_function", @@ -1665,9 +1668,13 @@ StatusOr GraphDefImporter::InferMainFunctionType( absl::InlinedVector* arg_nodes, absl::InlinedVector* ret_nodes) { // Finds out all the input nodes and output nodes. - if (!specs.inputs.empty() || !specs.output_arrays.empty()) { + absl::flat_hash_set output_node_names; + for (const auto& output_tensor : specs.outputs) { + output_node_names.insert(ParseTensorName(output_tensor).node()); + } + if (!specs.inputs.empty() || !specs.outputs.empty()) { arg_nodes->resize(specs.inputs.size()); - ret_nodes->resize(specs.output_arrays_order.size()); + ret_nodes->resize(specs.outputs.size()); for (Node* n : GetOrderedNodes()) { // Handle inputs/arguments. @@ -1677,17 +1684,17 @@ StatusOr GraphDefImporter::InferMainFunctionType( } // Handle outputs/returns. 
- if (specs.output_arrays.find(n->name()) != specs.output_arrays.end()) { - for (int i = 0, e = specs.output_arrays_order.size(); i != e; ++i) { + if (output_node_names.contains(n->name())) { + for (int i = 0, e = specs.outputs.size(); i != e; ++i) { std::pair name_and_port = - absl::StrSplit(specs.output_arrays_order[i], ':'); + absl::StrSplit(specs.outputs[i], ':'); auto name = name_and_port.first; if (name != n->name()) continue; int port = 0; if (!name_and_port.second.empty() && !absl::SimpleAtoi(name_and_port.second, &port)) { return errors::InvalidArgument("Invalid port specification: ", - specs.output_arrays_order[i]); + specs.outputs[i]); } (*ret_nodes)[i] = {n, port}; } @@ -1726,10 +1733,10 @@ StatusOr GraphDefImporter::InferMainFunctionType( } llvm::SmallVector ret_types; - ret_types.reserve(specs.output_arrays.size()); - for (int i = 0, e = specs.output_arrays_order.size(); i != e; ++i) { + ret_types.reserve(specs.outputs.size()); + for (int i = 0, e = specs.outputs.size(); i != e; ++i) { if (ret_nodes->at(i).node == nullptr) { - return errors::InvalidArgument("Output ", specs.output_arrays_order[i], + return errors::InvalidArgument("Output ", specs.outputs[i], " was not found in graph"); } } diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc index 133e4831356..b2cf906be0d 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc @@ -33,19 +33,16 @@ limitations under the License. namespace tensorflow { Status ParseOutputArrayInfo(absl::string_view array_names, - absl::flat_hash_set* array, - std::vector* order) { + std::vector* outputs) { std::vector output_names = absl::StrSplit(array_names, ','); - return ParseOutputArrayInfo(output_names, array, order); + return ParseOutputArrayInfo(output_names, outputs); } Status ParseOutputArrayInfo(const std::vector& output_names, - absl::flat_hash_set* array, - std::vector* order) { + std::vector* outputs) { for (auto& output_name : output_names) { if (output_name.empty()) continue; - array->insert(string(*absl::StrSplit(output_name, ':').begin())); - order->push_back(output_name); + outputs->push_back(output_name); } return Status::OK(); } diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h index ebc862999e9..9b260883638 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h @@ -40,11 +40,9 @@ struct GraphImportConfig { llvm::MapVector>; // Maps input node names to node data types and shapes. InputArrays inputs; - // Output node names. - absl::flat_hash_set output_arrays; - // nodes:index strings for the output as specified on the command line. - std::vector output_arrays_order; - // setting prune_unused_nodes to true, would prune unreachable nodes if + // name:index strings for the output as specified on the command line. + std::vector outputs; + // Setting prune_unused_nodes to true, would prune unreachable nodes if // output_arrays is specified. bool prune_unused_nodes = false; // If true, inputs of type LegacyFedInput are replaced with Placeholder ops. @@ -73,12 +71,10 @@ struct GraphExportConfig { // Parses the command line flag strings to the specification of nodes in // the Graph. 
Status ParseOutputArrayInfo(absl::string_view array_names, - absl::flat_hash_set* array, - std::vector* order); + std::vector* outputs); Status ParseOutputArrayInfo(const std::vector& output_names, - absl::flat_hash_set* array, - std::vector* order); + std::vector* outputs); // Parses the command line flag strings to the specification of nodes in // the Graph. `data_types` input string can be empty since the flag is optional. diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc index cd422a66bc5..5c59eace5cc 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/grappler/utils/transitive_fanin.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/protobuf/graph_debug_info.pb.h" @@ -63,16 +64,18 @@ static StatusOr GraphdefToMlirImport( specs.upgrade_legacy = upgrade_legacy; TF_RETURN_IF_ERROR(ParseInputArrayInfo(input_arrays, input_dtypes, input_shapes, &specs.inputs)); - TF_RETURN_IF_ERROR(ParseOutputArrayInfo(output_arrays, &specs.output_arrays, - &specs.output_arrays_order)); + TF_RETURN_IF_ERROR(ParseOutputArrayInfo(output_arrays, &specs.outputs)); // TODO(b/142828368): Pruning should not be needed when TF import // supports importing graphs w/ unregistered ops natively. GraphDef pruned_graph_def; if (specs.prune_unused_nodes) { - std::vector terminal_nodes(specs.output_arrays.begin(), - specs.output_arrays.end()); - for (const auto entry : specs.inputs) { - terminal_nodes.push_back(entry.first); + std::vector terminal_nodes; + terminal_nodes.reserve(specs.outputs.size() + specs.inputs.size()); + for (const auto& output : specs.outputs) { + terminal_nodes.push_back(std::string(ParseTensorName(output).node())); + } + for (const auto& input : specs.inputs) { + terminal_nodes.push_back(input.first); } TF_RETURN_IF_ERROR(tensorflow::grappler::SetTransitiveFaninGraph( graphdef, &pruned_graph_def, terminal_nodes)); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc index 3ee644bef4d..3574b336f9a 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc @@ -36,7 +36,7 @@ xla::StatusOr TestShapeRepresentation(const TensorShape& shape, return xla_shape; } -TEST(CompileSerializedMlirToXlaHloTest, InvalidSerliazedMlirModule) { +TEST(CompileSerializedMlirToXlaHloTest, InvalidSerializedMlirModule) { string invalid_mlir_module = "totally @invalid MLIR module {here} <-"; std::vector arg_shapes; XlaCompiler::CompilationResult compilation_result; @@ -101,7 +101,7 @@ ENTRY %main.6 (arg_tuple.1: (f32[], f32[])) -> (f32[]) { xla::ShapeUtil::MakeTupleShape({output_shape}); EXPECT_EQ(compilation_result.xla_output_shape, tuple_output_shape); - // Expect exactly 1 OutputDescrpition. + // Expect exactly 1 OutputDescription. 
EXPECT_EQ(compilation_result.outputs.size(), 1); const XlaCompiler::OutputDescription& output_desc = compilation_result.outputs.front(); diff --git a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc index f12dde278ae..9ab31265a33 100644 --- a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc @@ -24,6 +24,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // TF:local_config_mlir #include "mlir/Support/FileUtilities.h" // TF:local_config_mlir #include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "mlir/Support/ToolUtilities.h" // TF:local_config_mlir #include "mlir/Support/TranslateClParser.h" // TF:local_config_mlir #include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" @@ -40,6 +41,13 @@ static llvm::cl::opt output_filename( "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"), llvm::cl::init("-")); +// NOLINTNEXTLINE +static llvm::cl::opt splitInputFile( + "split-input-file", + llvm::cl::desc("Split the input file into pieces and process each chunk " + "independently"), + llvm::cl::init(false)); + // NOLINTNEXTLINE static llvm::cl::opt import_saved_model( "savedmodel-to-mlir", @@ -85,13 +93,12 @@ int main(int argc, char** argv) { return 1; } - mlir::MLIRContext context; - if (import_saved_model) { std::unordered_set tags = absl::StrSplit(saved_model_tags, ','); std::vector exported_names = absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty()); + mlir::MLIRContext context; auto module = tensorflow::SavedModelToMlirImport( input_filename, tags, absl::Span(exported_names), @@ -107,12 +114,23 @@ int main(int argc, char** argv) { return 1; } - llvm::SourceMgr source_mgr; - source_mgr.AddNewSourceBuffer(std::move(input), llvm::SMLoc()); - mlir::SourceMgrDiagnosticHandler diagnostic_handler(source_mgr, &context); + // Processes the memory buffer with a new MLIRContext. + auto processBuffer = [&](std::unique_ptr ownedBuffer, + llvm::raw_ostream& os) { + llvm::SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), llvm::SMLoc()); + mlir::MLIRContext context; + mlir::SourceMgrDiagnosticHandler diagnostic_handler(sourceMgr, &context); + return (*requested_translation)(sourceMgr, os, &context); + }; - if (failed((*requested_translation)(source_mgr, output->os(), &context))) - return 1; + if (splitInputFile) { + if (failed(mlir::splitAndProcessBuffer(std::move(input), processBuffer, + output->os()))) + return 1; + } else { + if (failed(processBuffer(std::move(input), output->os()))) return 1; + } } output->keep(); diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index 3ed3fb6fc40..ac3475cebc4 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -404,6 +404,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:matrix", "//tensorflow/compiler/xla/service:hlo", "@llvm//:support", "@local_config_mlir//:Analysis", diff --git a/tensorflow/compiler/mlir/xla/convert_op_folder.cc b/tensorflow/compiler/mlir/xla/convert_op_folder.cc index d26bec292cc..8245b4a0585 100644 --- a/tensorflow/compiler/mlir/xla/convert_op_folder.cc +++ b/tensorflow/compiler/mlir/xla/convert_op_folder.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir #include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir +namespace mlir { namespace xla { mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements, @@ -82,3 +83,4 @@ mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements, } } // namespace xla +} // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/convert_op_folder.h b/tensorflow/compiler/mlir/xla/convert_op_folder.h index 1c3f75489f8..63ac0e61df5 100644 --- a/tensorflow/compiler/mlir/xla/convert_op_folder.h +++ b/tensorflow/compiler/mlir/xla/convert_op_folder.h @@ -19,6 +19,7 @@ limitations under the License. #include "mlir/IR/Attributes.h" // TF:local_config_mlir #include "mlir/IR/StandardTypes.h" // TF:local_config_mlir +namespace mlir { namespace xla { // Converts the given elements attr to the specified elements type. @@ -27,5 +28,6 @@ namespace xla { mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements, mlir::Type new_type); } // namespace xla +} // namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_ diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc index b996b12119d..b2f02bdf76f 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc @@ -47,13 +47,11 @@ limitations under the License. #include "mlir/Transforms/InliningUtils.h" // TF:local_config_mlir #include "tensorflow/compiler/mlir/xla/convert_op_folder.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h.inc" +#include "tensorflow/compiler/mlir/xla/ir/hlo_utils.h" namespace mlir { #include "tensorflow/compiler/mlir/xla/ir/hlo_structs.cc.inc" -} // namespace mlir - -using namespace mlir; -using namespace mlir::xla_hlo; +namespace xla_hlo { Operation* XlaHloDialect::materializeConstant(OpBuilder& builder, Attribute value, Type type, @@ -160,7 +158,7 @@ void ConstOp::build(Builder* builder, OperationState& result, Attribute value) { } else if (value.isa() || value.isa() || value.isa()) { // All XLA types must be tensor types. In the build() method, we want to - // provide more flexiblity by allowing attributes of scalar types. But we + // provide more flexibility by allowing attributes of scalar types. But we // need to wrap it up with ElementsAttr to construct valid XLA constants. type = RankedTensorType::get(/*shape=*/{}, value.getType()); value = DenseElementsAttr::get(type.cast(), value); @@ -212,9 +210,9 @@ void AbsOp::build(Builder* builder, OperationState& result, Value* operand) { new_type = operand->getType(); } else if (shaped_type.hasRank()) { new_type = - mlir::RankedTensorType::get(shaped_type.getShape(), operand->getType()); + RankedTensorType::get(shaped_type.getShape(), operand->getType()); } else { - new_type = mlir::UnrankedTensorType::get(operand->getType()); + new_type = UnrankedTensorType::get(operand->getType()); } return AbsOp::build(builder, result, new_type, operand); @@ -241,8 +239,8 @@ OpFoldResult ConvertOp::fold(ArrayRef operands) { // If the operand is constant, we can do the conversion now. 
if (auto elementsAttr = operands.front().dyn_cast_or_null()) { - return ::xla::ConvertElementsAttr(elementsAttr, - getElementTypeOrSelf(getResult())); + return xla::ConvertElementsAttr(elementsAttr, + getElementTypeOrSelf(getResult())); } return {}; @@ -436,7 +434,7 @@ static LogicalResult Verify(ClampOp op) { void ComplexOp::build(Builder* builder, OperationState& state, Value* lhs, Value* rhs) { auto type = lhs->getType(); - auto element_ty = mlir::ComplexType::get(getElementTypeOrSelf(type)); + auto element_ty = ComplexType::get(getElementTypeOrSelf(type)); Type result_ty; if (auto ranked_type = type.dyn_cast()) { result_ty = RankedTensorType::get(ranked_type.getShape(), element_ty); @@ -843,6 +841,70 @@ Type SliceOp::InferOutputTypes(Builder* builder, Value* operand, return RankedTensorType::get(shape, ranked_ty.getElementType()); } +//===----------------------------------------------------------------------===// +// SortOp +//===----------------------------------------------------------------------===// + +void SortOp::build(Builder* builder, OperationState& state, + ArrayRef operands, int64_t dimension, + bool is_stable) { + state.addOperands(operands); + state.addAttribute("dimension", builder->getI64IntegerAttr(dimension)); + state.addAttribute("is_stable", builder->getBoolAttr(is_stable)); + + SmallVector element_types; + element_types.reserve(operands.size()); + for (Value* operand : operands) element_types.push_back(operand->getType()); + state.addTypes(builder->getTupleType(element_types)); + + state.addRegion(); +} + +static LogicalResult Verify(SortOp op) { + Operation::operand_range operands = op.operands(); + if (operands.empty()) return op.emitOpError("requires at least one input"); + + // TODO(antiagainst): verify partially dynamic shapes + if (llvm::all_of(operands, [](Value* operand) { + return operand->getType().cast().hasRank(); + })) { + ArrayRef input_shape = + (*operands.begin())->getType().cast().getShape(); + + if (llvm::any_of(llvm::drop_begin(operands, 1), [&](Value* operand) { + return operand->getType().cast().getShape() != + input_shape; + })) + return op.emitOpError("requires all inputs to have the same dimensions"); + + if (op.dimension().getSExtValue() >= input_shape.size()) + return op.emitOpError( + "dimension attribute value must be less than input rank"); + } + + Block& block = op.comparator().front(); + size_t num_operands = op.getOperation()->getNumOperands(); + if (block.getNumArguments() != 2 * num_operands) + return op.emitOpError("comparator block should have ") + << 2 * num_operands << " arguments"; + + for (auto indexed_operand : llvm::enumerate(operands)) { + int index = indexed_operand.index(); + Type element_type = + indexed_operand.value()->getType().cast().getElementType(); + Type tensor_type = RankedTensorType::get({}, element_type); + for (int i : {2 * index, 2 * index + 1}) { + Type arg_type = block.getArgument(i)->getType(); + if (arg_type != tensor_type) + return op.emitOpError("comparator block argument #") + << i << " should be of type " << tensor_type << " but got " + << arg_type; + } + } + + return success(); +} + //===----------------------------------------------------------------------===// // TransposeOp //===----------------------------------------------------------------------===// @@ -938,6 +1000,15 @@ void TupleOp::build(Builder* builder, OperationState& result, build(builder, result, builder->getTupleType(types), values); } +//===----------------------------------------------------------------------===// +// 
UnaryEinsumOp +//===----------------------------------------------------------------------===// + +void UnaryEinsumOp::getCanonicalizationPatterns( + OwningRewritePatternList& results, MLIRContext* context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // CompareOp //===----------------------------------------------------------------------===// @@ -990,3 +1061,6 @@ XlaHloDialect::XlaHloDialect(MLIRContext* context) // Support unknown operations because not all XLA operations are registered. // allowUnknownOperations(); } + +} // namespace xla_hlo +} // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td index f036dec92b9..c9b3e7985fc 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td @@ -437,9 +437,6 @@ def HLO_GetTupleElementOp: HLO_Op<"get_tuple_element", [NoSideEffect]>, BASE_HLO let builders = [OpBuilder< "Builder *builder, OperationState &results, " "Value* value, int32_t index">]; - - // GetTupleElementOp has special conversion logic to HLO. - let hasCustomHLOConverter = 1; } def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp { @@ -730,6 +727,43 @@ def HLO_DotGeneralOp: HLO_Op<"dot_general", [NoSideEffect]>, BASE_HLO_DotGeneral let results = (outs HLO_Tensor); } +def BASE_EinsumOp { + string summary = "Einsum operator"; + + string description = [{ + Returns a tensor whose elements are defined by equation, which is written + in a shorthand form inspired by the Einstein summation convention. + }]; +} + +def HLO_EinsumOp: HLO_Op<"einsum", [NoSideEffect]> { + let arguments = (ins + HLO_Tensor:$lhs, + HLO_Tensor:$rhs, + StrAttr:$einsum_config + ); + + let results = (outs HLO_Tensor); + + // TODO(hinsu): Canonicalize to lower this client side HLO op to server + // side HLO ops. +} + +def HLO_UnaryEinsumOp: HLO_Op<"unary_einsum", [NoSideEffect]> { + let arguments = (ins + HLO_Tensor:$operand, + StrAttr:$einsum_config + ); + + let results = (outs HLO_Tensor); + + let hasCanonicalizer = 1; + + // UnaryEinsumOp is unconditionally canonicalized to the binary EinsumOp so + // the HLO converter shouldn't be invoked. + let hasCustomHLOConverter = 1; +} + def HLO_FftOp: HLO_Op<"fft", [NoSideEffect]>, BASE_HLO_FftOp { let arguments = (ins HLO_Tensor:$operand, @@ -834,6 +868,26 @@ def HLO_SelectAndScatterOp: HLO_Op<"select_and_scatter", let hasCustomHLOConverter = 1; } +def HLO_SortOp : HLO_Op<"sort", [NoSideEffect]>, BASE_HLO_SortOp { + let arguments = (ins + Variadic:$operands, + DefaultValuedAttr:$dimension, + DefaultValuedAttr:$is_stable + ); + + let results = (outs HLO_TensorOrTuple); + + let regions = (region SizedRegion<1>:$comparator); + + let builders = [OpBuilder< + "Builder *builder, OperationState &state, ArrayRef operands, " + "int64_t dimension, bool is_stable" + >]; + + // TODO(b/129422361): SortOp has special conversion logic to HLO. 
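To make the sort builder/verifier contract above concrete, here is a minimal C++ sketch (not code from this patch) of creating a sort op and populating its comparator block; `builder`, `loc`, `values`, and `indices` are assumed to be an existing OpBuilder, a Location, and two equally shaped tensor values (Value* in this code base):

// Sketch: sort two operands along dimension 1; the result is a tuple of the
// operand types, and the builder creates the comparator region empty.
llvm::SmallVector<mlir::Value*, 2> operands{values, indices};
auto sort = builder.create<mlir::xla_hlo::SortOp>(loc, operands,
                                                  /*dimension=*/1,
                                                  /*is_stable=*/true);

// The verifier requires a comparator block with 2 * num_operands rank-0
// tensor arguments: an lhs/rhs pair per operand, matching element types.
auto* block = new mlir::Block();
sort.comparator().push_back(block);
mlir::Type sf32 = mlir::RankedTensorType::get({}, builder.getF32Type());
mlir::Type si32 = mlir::RankedTensorType::get({}, builder.getIntegerType(32));
for (mlir::Type t : {sf32, sf32, si32, si32}) block->addArgument(t);
// An "xla_hlo.compare" plus "xla_hlo.return" would then be created inside
// `block`, as in the TopKV2 lowering test later in this change.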
+ let hasCustomHLOConverter = 1; +} + def HLO_ReverseOp: HLO_Op<"reverse", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ReverseOp { let arguments = (ins diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td index e6c5aeb9dff..a6d4210b60c 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td @@ -708,7 +708,7 @@ class BASE_HLO_ClampOp { } class BASE_HLO_ConcatenateOp { - string summary = "XLA's concantenate op"; + string summary = "XLA's concatenate op"; string description = [{ Concatenates a set of tensors along the specified dimension. @@ -832,6 +832,17 @@ class BASE_HLO_SelectAndScatterOp { }]; } +class BASE_HLO_SortOp { + string summary = "Sort operator"; + + string description = [{ + Sorts the given `operands` at the given `dimension` with the given + `comparator`. + + See https://www.tensorflow.org/xla/operation_semantics#sort. + }]; +} + class BASE_HLO_ReverseOp { string summary = "Reverse operator"; diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_utils.cc b/tensorflow/compiler/mlir/xla/ir/hlo_utils.cc index 82b7032d542..7d3e2ca2384 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_utils.cc +++ b/tensorflow/compiler/mlir/xla/ir/hlo_utils.cc @@ -17,6 +17,8 @@ limitations under the License. #include +#include "mlir/IR/Attributes.h" // TF:local_config_mlir + namespace mlir { namespace xla { @@ -51,5 +53,18 @@ DenseIntElementsAttr getBroadcastDimensionsAttr(Builder *b, Value *x, return DenseIntElementsAttr::get(type, broadcastDimensions); } +DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) { + RankedTensorType scalar_ty = RankedTensorType::get({}, ty); + + DenseElementsAttr attr; + if (auto float_ty = ty.dyn_cast()) { + APFloat value(float_ty.getFloatSemantics(), raw_value); + return DenseElementsAttr::get(scalar_ty, value); + } + auto int_ty = ty.cast(); + APInt value(int_ty.getWidth(), static_cast(raw_value), true); + return DenseElementsAttr::get(scalar_ty, value); +} + } // namespace xla } // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_utils.h b/tensorflow/compiler/mlir/xla/ir/hlo_utils.h index d81abf6a0be..86c90b49f16 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_utils.h +++ b/tensorflow/compiler/mlir/xla/ir/hlo_utils.h @@ -18,6 +18,7 @@ limitations under the License. #include "mlir/IR/Attributes.h" // TF:local_config_mlir #include "mlir/IR/Builders.h" // TF:local_config_mlir +#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir #include "mlir/IR/StandardTypes.h" // TF:local_config_mlir #include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir #include "tensorflow/compiler/mlir/xla/convert_op_folder.h" @@ -48,6 +49,12 @@ static ElementsAttr getSplat(Builder* b, Value* val, T constant) { return DenseElementsAttr::get(valType, elementAttr); } + +// Returns a DenseElementsAttr of rank zero with the given element type and +// value. +// Requires `ty` to be either FloatType or IntegerType. +DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value); + } // namespace xla } // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_utils.td b/tensorflow/compiler/mlir/xla/ir/hlo_utils.td index bd1a448b80f..97b29bf0851 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_utils.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_utils.td @@ -18,20 +18,23 @@ limitations under the License. 
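A short usage sketch for the new GetScalarOfType helper (illustrative only; `ctx` is assumed to be an existing MLIRContext):

// GetScalarOfType builds a rank-0 splat: float element types go through
// APFloat, integer element types are sign-extended into an APInt.
mlir::Builder b(&ctx);
auto one_f32 = mlir::xla::GetScalarOfType(b.getF32Type(), 1);            // dense<1.0> : tensor<f32>
auto neg_one_i32 = mlir::xla::GetScalarOfType(b.getIntegerType(32), -1); // dense<-1> : tensor<i32>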
#ifndef HLO_UTILS #define HLO_UTILS -#ifndef OP_BASE include "mlir/IR/OpBase.td" -#endif // OP_BASE def NullArrayAttr : NativeCodeCall<"ArrayAttr()">; def CastIntElementsAttr : NativeCodeCall<"$0.cast()">; class ConstantSplat : NativeCodeCall< - "getSplat(&$_builder, $0, " # value # ")">; + "xla::getSplat(&$_builder, $0, " # value # ")">; def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">; def BinBroadcastDimensions : NativeCodeCall< - "getBroadcastDimensionsAttr(&$_builder, $0, $1)">; + "xla::getBroadcastDimensionsAttr(&$_builder, $0, $1)">; + +// Here, the element type can be any integer or float type. But, note that only +// 32 bit integers are supported for the value. +class GetScalarOfType : NativeCodeCall< + "xla::GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">; #endif // HLO_UTILS diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index ec93bb9020b..e9bf3bac44b 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -33,6 +33,7 @@ limitations under the License. #include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/type_to_shape.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/comparison_util.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -77,6 +78,10 @@ static double ConvertAPFloat(llvm::APFloat value) { return value.convertToDouble(); } +static absl::string_view ConvertStringRef(mlir::StringRef value) { + return {value.data(), value.size()}; +} + static std::vector ConvertDenseIntAttr(mlir::DenseIntElementsAttr attr) { auto values = attr.getValues(); return {values.begin(), values.end()}; @@ -494,13 +499,6 @@ LogicalResult ExportXlaOp(GatherOp op, OpLoweringContext ctx) { return failure(); } -LogicalResult ExportXlaOp(GetTupleElementOp op, OpLoweringContext ctx) { - auto& value_map = *ctx.values; - value_map[op] = xla::GetTupleElement(value_map[op.getOperand()], - op.index().getSExtValue()); - return success(); -} - LogicalResult ExportXlaOp(IotaOp op, OpLoweringContext ctx) { auto& value_map = *ctx.values; value_map[op] = xla::Iota(ctx.builder, xla::TypeToShape(op.getType()), @@ -626,12 +624,30 @@ LogicalResult ExportXlaOp(SliceOp op, OpLoweringContext ctx) { return failure(); } +LogicalResult ExportXlaOp(SortOp op, OpLoweringContext ctx) { + xla::XlaComputation comparator; + if (failed(ctx.converter->LowerRegionAsComputation(&op.comparator(), + &comparator))) + return failure(); + + auto& value_map = *ctx.values; + value_map[op] = xla::Sort(GetTuple(op.operands(), ctx), comparator, + op.dimension().getSExtValue(), op.is_stable()); + return success(); +} + LogicalResult ExportXlaOp(TupleOp op, OpLoweringContext ctx) { auto& value_map = *ctx.values; value_map[op] = xla::Tuple(ctx.builder, GetTuple(op.val(), ctx)); return success(); } +LogicalResult ExportXlaOp(UnaryEinsumOp op, OpLoweringContext ctx) { + // Intentional as UnaryEinsumOp is always lowered to the EinsumOp with two + // operands. 
+ return failure(); +} + LogicalResult ExportXlaOp(WhileOp op, OpLoweringContext ctx) { xla::XlaComputation condition; xla::XlaComputation body; @@ -773,7 +789,7 @@ LogicalResult ConvertToHloModule::LowerFunctionCall( LogicalResult ConvertToHloModule::RunOnFunction(mlir::FuncOp f) { if (lowered_computation_.count(f)) return success(); if (f.getBlocks().size() != 1) { - return f.emitError("only single block Function suppored"); + return f.emitError("only single block Function supported"); } // Create a sub-builder if this is not the main function. diff --git a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc index 4a9555a256a..acc3c17baf5 100644 --- a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc +++ b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc @@ -32,17 +32,20 @@ using llvm::raw_ostream; using llvm::RecordKeeper; using llvm::StringRef; using mlir::interleaveComma; +using mlir::tblgen::Attribute; using mlir::tblgen::NamedAttribute; using mlir::tblgen::NamedTypeConstraint; using mlir::tblgen::Operator; static std::string GetDefaultAttrExport( const mlir::tblgen::NamedAttribute& named_attr) { - auto storage_type = named_attr.attr.getStorageType(); + Attribute attr = named_attr.attr; + StringRef storage_type = attr.getStorageType(); // For some attribute types we have a general conversion, so use that. - if (storage_type.endswith("IntegerAttr") || - storage_type.endswith("FloatAttr")) { - return "Convert" + named_attr.attr.getReturnType().str(); + if (!attr.isEnumAttr() && (storage_type.endswith("IntegerAttr") || + storage_type.endswith("FloatAttr") || + storage_type.endswith("StringAttr"))) { + return "Convert" + attr.getReturnType().str(); } return "Convert_" + named_attr.name.str(); } diff --git a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir index e6d99b9e7d8..fa39b77918a 100644 --- a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir @@ -48,3 +48,11 @@ func @complex_collapse_fold(%arg0: tensor<4xcomplex>) -> tensor<4xcomplex> } + +// CHECK-LABEL: @unary_einsum +func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> { + // CHECK: %[[ONE:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor + // CHECK: "xla_hlo.einsum"(%[[ONE]], %arg0) {einsum_config = ",ab->aa"} + %0 = "xla_hlo.unary_einsum"(%arg0) {einsum_config = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index 6c8737737ec..8aa9b5ef101 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -180,6 +180,27 @@ func @or_dynamic(%arg0: tensor, %arg1: tensor<1xi1>) -> tensor { return %0: tensor } +// CHECK-LABEL: func @bitwise_or +func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + // CHECK-NEXT: xla_hlo.or + %0 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %0: tensor<4xi32> +} + +// CHECK-LABEL: func @bitwise_or_broadcast +func @bitwise_or_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> { + // CHECK-NEXT: xla_hlo.or + %0 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> + return %0: tensor<1x4xi8> +} + +// CHECK-LABEL: func @bitwise_or_dynamic +func @bitwise_or_dynamic(%arg0: tensor, %arg1: 
tensor<1xi32>) -> tensor { + // CHECK-NEXT: xla_hlo.or + %0 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor + return %0: tensor +} + // CHECK-LABEL: func @pow func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK-NEXT: xla_hlo.pow @@ -194,6 +215,20 @@ func @pow_dynamic(%arg0: tensor) -> tensor { return %0: tensor } +// CHECK-LABEL: func @einsum +func @einsum(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>) -> tensor<2x4xf32> { + // CHECK: xla_hlo.einsum + %0 = "tf.Einsum"(%arg0, %arg1) {equation = "ab,bc->ac"} : (tensor<2x3xf32>, tensor<3x4xf32>) -> tensor<2x4xf32> + return %0: tensor<2x4xf32> +} + +// CHECK-LABEL: func @unary_einsum +func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> { + // CHECK: xla_hlo.unary_einsum + %0 = "tf.Einsum"(%arg0) {equation = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32> + return %0: tensor<2x2xf32> +} + // CHECK-LABEL: func @floordiv_broadcast_i32 func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> { // CHECK-DAG: [[ZEROS1:%.+]] = xla_hlo.constant dense<0> @@ -589,6 +624,17 @@ func @padv2_2D(%arg0: tensor<3x2xf32>, %arg1: tensor) -> tensor<6x9xf32> { return %1 : tensor<6x9xf32> } +// CHECK-LABEL: func @padv2_i32_paddings +func @padv2_i32_paddings(%arg0: tensor<3x2xf32>, %arg1: tensor) -> tensor<6x9xf32> { + %padding = "tf.Const"() { value = dense<[[1,2],[3,4]]> : tensor<2x2xi32> } : () -> tensor<2x2xi32> + // CHECK: "xla_hlo.pad"(%arg0, %arg1) { + // CHECK-SAME: edge_padding_high = dense<[2, 4]> : tensor<2xi64>, + // CHECK-SAME: edge_padding_low = dense<[1, 3]> : tensor<2xi64>, + // CHECK-SAME: interior_padding = dense<0> : tensor<2xi64> + %1 = "tf.PadV2"(%arg0, %padding, %arg1) : (tensor<3x2xf32>, tensor<2x2xi32>, tensor) -> tensor<6x9xf32> + return %1 : tensor<6x9xf32> +} + //===----------------------------------------------------------------------===// // Identity op legalizations. 
//===----------------------------------------------------------------------===// @@ -1888,3 +1934,42 @@ func @split_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor<4x2xf // CHECK: return %[[ONE]], %[[TWO]], %[[THREE]] return %0#0, %0#1, %0#2 : tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32> } + +//===----------------------------------------------------------------------===// +// tf.TopKV2 legalization +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: topk_v2_non_const_k +func @topk_v2_non_const_k(%input: tensor<16xf32>, %k: tensor) -> (tensor, tensor) { + // CHECK: tf.TopKV2 + %0:2 = "tf.TopKV2"(%input, %k): (tensor<16xf32>, tensor) -> (tensor, tensor) + return %0#0, %0#1: tensor, tensor +} + +// CHECK-LABEL: topk_v2_unknown_input_last_dim +func @topk_v2_unknown_input_last_dim(%input: tensor<16x?xf32>) -> (tensor<16x?xf32>, tensor<16x?xi32>) { + %k = "tf.Const"() {value = dense<8> : tensor} : () -> tensor + // CHECK: tf.TopKV2 + %0:2 = "tf.TopKV2"(%input, %k): (tensor<16x?xf32>, tensor) -> (tensor<16x?xf32>, tensor<16x?xi32>) + return %0#0, %0#1: tensor<16x?xf32>, tensor<16x?xi32> +} + +// CHECK-LABEL: topk_v2 +// CHECK-SAME: %[[INPUT:.*]]: tensor<16x16xf32> +func @topk_v2(%input: tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>) { + %k = "tf.Const"() {value = dense<8> : tensor} : () -> tensor + + // CHECK: %[[IOTA:.*]] = "xla_hlo.iota"() {iota_dimension = 1 : i64} + // CHECK-NEXT: %[[SORT:.*]] = "xla_hlo.sort"(%[[INPUT]], %[[IOTA]]) ( { + // CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor, %[[RHS:.*]]: tensor, %{{.*}}: tensor, %{{.*}}: tensor): + // CHECK-NEXT: %[[CMP:.*]] = "xla_hlo.compare"(%[[LHS]], %[[RHS]]) {comparison_direction = "GT"} + // CHECK-NEXT: "xla_hlo.return"(%[[CMP]]) + // CHECK-NEXT: }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + // CHECK-NEXT: %[[TUPL0:.*]] = "xla_hlo.get_tuple_element"(%[[SORT]]) {index = 0 : i32} + // CHECK-NEXT: %[[TUPL1:.*]] = "xla_hlo.get_tuple_element"(%[[SORT]]) {index = 1 : i32} + // CHECK-NEXT: %[[VAL:.*]] = "xla_hlo.slice"(%[[TUPL0]]) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + // CHECK-NEXT: %[[IDX:.*]] = "xla_hlo.slice"(%[[TUPL1]]) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + // CHECK-NEXT: return %[[VAL]], %[[IDX]] + %0:2 = "tf.TopKV2"(%input, %k): (tensor<16x16xf32>, tensor) -> (tensor<16x8xf32>, tensor<16x8xi32>) + return %0#0, %0#1: tensor<16x8xf32>, tensor<16x8xi32> +} diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir index 00c37da8c5e..2d23a5fb1f9 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir @@ -1,7 +1,7 @@ // RUN: tf-opt -lhlo-fuse-linalg %s -o - | FileCheck %s #map0 = (d0, d1) -> (d0, d1) -#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], n_loop_types = [2, 0, 0], n_views = [2, 1]} +#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"], n_views = [2, 1]} func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>, %summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) { %temp_result = alloc() {temp = true} : memref<2x2xf32> diff --git a/tensorflow/compiler/mlir/xla/tests/ops.mlir 
b/tensorflow/compiler/mlir/xla/tests/ops.mlir index 225fc97bb22..4f142f294e4 100644 --- a/tensorflow/compiler/mlir/xla/tests/ops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/ops.mlir @@ -416,3 +416,98 @@ func @constants() -> () { %3 = "xla_hlo.constant"() {extra_attr = 3 : i32, value = dense<0> : tensor} : () -> (tensor<*xi32>) return } + +// ----- + +func @sort(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { + // CHECK: xla_hlo.sort + %0 = "xla_hlo.sort"(%input0, %input1) ( { + ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): + %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "xla_hlo.return"(%7) : (tensor) -> () + }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + return +} + +// ----- + +func @sort_no_operands() { + // expected-error @+1 {{op requires at least one input}} + %0 = "xla_hlo.sort"() ( { + ^bb0(%arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor): + %7 = "xla_hlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "xla_hlo.return"(%7) : (tensor) -> () + }) {dimension = 1 : i64, is_stable = true} : () -> tuple<> + return +} + +// ----- + +func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) { + %0 = "xla_hlo.sort"(%input0, %input1) ( { + ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): + %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "xla_hlo.return"(%7) : (tensor) -> () + }) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + return +} + +// ----- + +func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) { + // expected-error @+1 {{comparator block argument #0 should be of type 'tensor' but got 'tensor'}} + %0 = "xla_hlo.sort"(%input0, %input1) ( { + ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): + %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "xla_hlo.return"(%7) : (tensor) -> () + }) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + return +} + +// ----- + +func @sort_different_dims(%input0: tensor<16x8xf32>, %input1: tensor<16x16xi32>) { + // expected-error @+1 {{op requires all inputs to have the same dimensions}} + %0 = "xla_hlo.sort"(%input0, %input1) ( { + ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): + %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "xla_hlo.return"(%7) : (tensor) -> () + }) {dimension = 1 : i64, is_stable = true} : (tensor<16x8xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + return +} + +// ----- + +func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { + // expected-error @+1 {{op dimension attribute value must be less than input rank}} + %0 = "xla_hlo.sort"(%input0, %input1) ( { + ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): + %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "xla_hlo.return"(%7) : (tensor) -> () + }) {dimension = 10 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + return +} + +// ----- + +func @sort_wrong_block_arg_count(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { + // expected-error @+1 {{op comparator block should have 4 
arguments}} + %0 = "xla_hlo.sort"(%input0, %input1) ( { + ^bb0(%arg0: tensor, %arg1: tensor): + %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "xla_hlo.return"(%7) : (tensor) -> () + }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + return +} + +// ----- + +func @sort_wrong_block_arg_type(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { + // expected-error @+1 {{op comparator block argument #3 should be of type 'tensor' but got 'tensor'}} + %0 = "xla_hlo.sort"(%input0, %input1) ( { + ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): + %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "xla_hlo.return"(%7) : (tensor) -> () + }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + return +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/all_reduce.mlir b/tensorflow/compiler/mlir/xla/tests/translate/all_reduce.mlir deleted file mode 100644 index 6c418799da8..00000000000 --- a/tensorflow/compiler/mlir/xla/tests/translate/all_reduce.mlir +++ /dev/null @@ -1,26 +0,0 @@ -// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s - -func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> { - %0 = "xla_hlo.all_reduce"(%arg0) ({ - // Perform max reduction inside the region - ^bb0(%lhs: tensor, %rhs: tensor): - %max = xla_hlo.max %lhs, %rhs : tensor - "xla_hlo.return"(%max) : (tensor) -> () - }) - { - replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, - channel_id = { - handle = 5 : i64, - type = 2 : i64 - } - } : (tensor<10xf32>) -> tensor<10xf32> - return %0 : tensor<10xf32> -} - -// CHECK: %[[COMPUTATION:.*]] ({{.*}}: f32[], {{.*}}: f32[]) -> f32[] -// CHECK-LABEL: ENTRY -// CHECK: %[[ARG0:.*]] = f32[10] parameter(0) -// CHECK: ROOT %[[RESULT:.*]] = f32[10] all-reduce(f32[10] %[[ARG0]]) -// CHECK-SAME: channel_id=5 -// CHECK-SAME: replica_groups={{[{][{]}}0,2,4,6},{1,3,5,7{{[}][}]}} -// CHECK-SAME: to_apply=%[[COMPUTATION]] diff --git a/tensorflow/compiler/mlir/xla/tests/translate/batch_norm_grad.mlir b/tensorflow/compiler/mlir/xla/tests/translate/batch_norm_grad.mlir deleted file mode 100644 index fff194c627b..00000000000 --- a/tensorflow/compiler/mlir/xla/tests/translate/batch_norm_grad.mlir +++ /dev/null @@ -1,15 +0,0 @@ -// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s - -func @main(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %mean: tensor<2xf32>, %variance: tensor<2xf32>, %grad_output: tensor<2x2x2x2xf32>) -> tuple, tensor<2xf32>, tensor<2xf32>> { - %0 = "xla_hlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) {epsilon = 0.001 : f32, feature_index = 0 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2x2x2x2xf32>) -> tuple, tensor<2xf32>, tensor<2xf32>> - return %0 : tuple, tensor<2xf32>, tensor<2xf32>> -} - -// CHECK-LABEL: ENTRY -// CHECK: [[VAL_1:%.*]] = f32[2,2,2,2] parameter(0) -// CHECK: [[VAL_2:%.*]] = f32[2] parameter(1) -// CHECK: [[VAL_3:%.*]] = f32[2] parameter(2) -// CHECK: [[VAL_4:%.*]] = f32[2] parameter(3) -// CHECK: [[VAL_5:%.*]] = f32[2,2,2,2] parameter(4) -// CHECK-LABEL: ROOT -// CHECK-SAME: (f32[2,2,2,2], f32[2], f32[2]) batch-norm-grad(f32[2,2,2,2] [[VAL_1]], f32[2] [[VAL_2]], f32[2] [[VAL_3]], f32[2] [[VAL_4]], f32[2,2,2,2] [[VAL_5]]), epsilon=0.001, feature_index=0 diff --git 
a/tensorflow/compiler/mlir/xla/tests/translate/batch_norm_training.mlir b/tensorflow/compiler/mlir/xla/tests/translate/batch_norm_training.mlir deleted file mode 100644 index d51e801b438..00000000000 --- a/tensorflow/compiler/mlir/xla/tests/translate/batch_norm_training.mlir +++ /dev/null @@ -1,13 +0,0 @@ -// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s - -func @main(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %offset: tensor<2xf32>) -> tuple, tensor<2xf32>, tensor<2xf32>> { - %0 = "xla_hlo.batch_norm_training" (%input, %scale, %offset) {epsilon = 0.001 : f32, feature_index = 3 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -> tuple, tensor<2xf32>, tensor<2xf32>> - return %0 : tuple, tensor<2xf32>, tensor<2xf32>> -} - -// CHECK-LABEL: ENTRY -// CHECK: [[VAL_1:%.*]] = f32[2,2,2,2] parameter(0) -// CHECK: [[VAL_2:%.*]] = f32[2] parameter(1) -// CHECK: [[VAL_3:%.*]] = f32[2] parameter(2) -// CHECK-LABEL: ROOT -// CHECK-SAME: (f32[2,2,2,2], f32[2], f32[2]) batch-norm-training(f32[2,2,2,2] [[VAL_1]], f32[2] [[VAL_2]], f32[2] [[VAL_3]]), epsilon=0.001, feature_index=3 diff --git a/tensorflow/compiler/mlir/xla/tests/translate/binary_arithmetic.mlir b/tensorflow/compiler/mlir/xla/tests/translate/binary_arithmetic.mlir deleted file mode 100644 index 50f10739816..00000000000 --- a/tensorflow/compiler/mlir/xla/tests/translate/binary_arithmetic.mlir +++ /dev/null @@ -1,23 +0,0 @@ -// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s - -module { - func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) { - // CHECK: [[VAL_1:%.*]] = s32[4] parameter(0) - // CHECK: [[VAL_2:%.*]] = s32[4] parameter(1) - // CHECK: [[ATAN2:%.*]] = s32[4] atan2(s32[4] [[VAL_1]], s32[4] [[VAL_2]]) - %0 = xla_hlo.atan2 %arg0, %arg1 : tensor<4xi32> - - // CHECK: [[SHL:%.*]] = s32[4] shift-left(s32[4] [[VAL_1]], s32[4] [[VAL_2]]) - %1 = xla_hlo.shift_left %arg0, %arg1 : tensor<4xi32> - - // CHECK: [[SHRA:%.*]] = s32[4] shift-right-arithmetic(s32[4] [[VAL_1]], s32[4] [[VAL_2]]) - %2 = xla_hlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32> - - // CHECK: [[SHRL:%.*]] = s32[4] shift-right-logical(s32[4] [[VAL_1]], s32[4] [[VAL_2]]) - %3 = xla_hlo.shift_right_logical %arg0, %arg1 : tensor<4xi32> - - // CHECK-LABEL: ROOT - // CHECK-SAME: [[VAL_7:%.*]] = (s32[4], s32[4], s32[4], s32[4]) tuple(s32[4] [[ATAN2]], s32[4] [[SHL]], s32[4] [[SHRA]], s32[4] [[SHRL]]) - return %0, %1, %2, %3 : tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32> - } -} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/binary_op_broadcast.mlir b/tensorflow/compiler/mlir/xla/tests/translate/binary_op_broadcast.mlir deleted file mode 100644 index 38aa4f04bad..00000000000 --- a/tensorflow/compiler/mlir/xla/tests/translate/binary_op_broadcast.mlir +++ /dev/null @@ -1,26 +0,0 @@ -// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s - -// CHECK-LABEL: ENTRY %main.13 (Arg_0.1: s32[1,4], Arg_1.2: s32[2,4], Arg_2.3: s32[2,3,4]) -> s32[2,3,4] { -func @main(%arg0: tensor<1x4xi32>, %arg1: tensor<2x4xi32>, %arg2: tensor<2x3x4xi32>) -> tensor<2x3x4xi32> { - // Same rank degenerate broadcast - // CHECK-NEXT: %Arg_0.1 = s32[1,4] parameter(0) - // CHECK-NEXT: %reshape.4 = s32[4] reshape(s32[1,4] %Arg_0.1) - // CHECK-NEXT: %broadcast.5 = s32[2,4] broadcast(s32[4] %reshape.4) - // CHECK-NEXT: %Arg_1.2 = s32[2,4] parameter(1) - // CHECK-NEXT: %add.6 = s32[2,4] add(s32[2,4] %broadcast.5, s32[2,4] %Arg_1.2) - %0 = "xla_hlo.add"(%arg0, 
%arg1) : (tensor<1x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> - - // Broadcast up rank - // CHECK-NEXT: %broadcast.7 = s32[2,3,4] broadcast(s32[2,4] %Arg_1.2), dimensions={0,2} - // CHECK-NEXT: %Arg_2.3 = s32[2,3,4] parameter(2) - // CHECK-NEXT: %add.8 = s32[2,3,4] add(s32[2,3,4] %broadcast.7, s32[2,3,4] %Arg_2.3) - %1 = "xla_hlo.add"(%arg1, %arg2) {broadcast_dimensions = dense<[0,2]> : tensor<2xi64>} : (tensor<2x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32> - - // Broadcast up rank + degenerate broadcast - // CHECK-NEXT: %broadcast.9 = s32[2,1,4] broadcast(s32[1,4] %Arg_0.1), dimensions={1,2} - // CHECK-NEXT: %reshape.10 = s32[2,4] reshape(s32[2,1,4] %broadcast.9) - // CHECK-NEXT: %broadcast.11 = s32[2,3,4] broadcast(s32[2,4] %reshape.10), dimensions={0,2} - // CHECK-NEXT: ROOT %add.12 = s32[2,3,4] add(s32[2,3,4] %broadcast.11, s32[2,3,4] %Arg_2.3) - %2 = "xla_hlo.add"(%arg0, %arg2) {broadcast_dimensions = dense<[1,2]> : tensor<2xi64>} : (tensor<1x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32> - return %2 : tensor<2x3x4xi32> -} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/broadcast.mlir b/tensorflow/compiler/mlir/xla/tests/translate/broadcast.mlir deleted file mode 100644 index 0b64ab23d54..00000000000 --- a/tensorflow/compiler/mlir/xla/tests/translate/broadcast.mlir +++ /dev/null @@ -1,9 +0,0 @@ -// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s - -// CHECK-LABEL: ENTRY %main.3 (Arg_0.1: s32[4]) -> s32[1,2,3,4] { -func @main(%arg0: tensor<4xi32>) -> tensor<1x2x3x4xi32> { - // CHECK-NEXT: %Arg_0.1 = s32[4] parameter(0) - // CHECK-NEXT: ROOT %broadcast.2 = s32[1,2,3,4] broadcast(s32[4] %Arg_0.1), dimensions={3} - %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[1,2,3]> : tensor<3xi64>} : (tensor<4xi32>) -> tensor<1x2x3x4xi32> - return %0 : tensor<1x2x3x4xi32> -} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/broadcast_in_dim.mlir b/tensorflow/compiler/mlir/xla/tests/translate/broadcast_in_dim.mlir deleted file mode 100644 index ac53ba9dbbe..00000000000 --- a/tensorflow/compiler/mlir/xla/tests/translate/broadcast_in_dim.mlir +++ /dev/null @@ -1,12 +0,0 @@ -// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s - -func @main(%arg0: tensor<1xf32>) -> tensor<1x10xf32> { - %result = "xla_hlo.broadcast_in_dim"(%arg0) { - broadcast_dimensions = dense<0> : tensor<1xi64> - } : (tensor<1xf32>) -> tensor<1x10xf32> - return %result : tensor<1x10xf32> -} - -// CHECK: ENTRY %main.3 ([[ARG0:.*]]: f32[1]) -> f32[1,10] { -// CHECK: %[[ARG0]] = f32[1] parameter(0) -// CHECK: ROOT %broadcast.2 = f32[1,10] broadcast(f32[1] %[[ARG0]]), dimensions={0} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/call.mlir b/tensorflow/compiler/mlir/xla/tests/translate/call.mlir deleted file mode 100644 index e9cfefc308d..00000000000 --- a/tensorflow/compiler/mlir/xla/tests/translate/call.mlir +++ /dev/null @@ -1,30 +0,0 @@ -// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s - -func @main(%arg0: tensor<4xi32>) -> tensor<4xi32> { - %0 = call @callee(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - %1 = call @callee(%0, %0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - return %1 : tensor<4xi32> -} - -func @callee(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { - %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - return %0 : tensor<4xi32> -} - -// CHECK: [[CALLEE_1:%.*]] ([[ARG_1:.*]]: s32[4], [[ARG_2:.*]]: s32[4]) -> s32[4] { -// CHECK: %[[ARG_1]] = s32[4] parameter(0) 
-// CHECK: %[[ARG_2]] = s32[4] parameter(1) -// CHECK-LABEL: ROOT -// CHECK-SAME: s32[4] add(s32[4] %[[ARG_1]], s32[4] %[[ARG_2]]) - -// CHECK: [[CALLEE_2:%.*]] ([[ARG_3:.*]]: s32[4], [[ARG_4:.*]]: s32[4]) -> s32[4] { -// CHECK: %[[ARG_3]] = s32[4] parameter(0) -// CHECK: %[[ARG_4]] = s32[4] parameter(1) -// CHECK-LABEL: ROOT -// CHECK-SAME: s32[4] add(s32[4] %[[ARG_3]], s32[4] %[[ARG_4]]) - -// CHECK: ENTRY [[MAIN:%.*]] ([[ARG:.*]]: s32[4]) -> s32[4] { -// CHECK: %[[ARG]] = s32[4] parameter(0) -// CHECK: [[CALL_OUT:%.*]] = s32[4] call(s32[4] %[[ARG]], s32[4] %[[ARG]]), to_apply=[[CALLEE_1]] -// CHECK-LABEL: ROOT -// CHECK-SAME: s32[4] call(s32[4] [[CALL_OUT]], s32[4] [[CALL_OUT]]), to_apply=[[CALLEE_2]] diff --git a/tensorflow/compiler/mlir/xla/tests/translate/call_multiple_results.mlir b/tensorflow/compiler/mlir/xla/tests/translate/call_multiple_results.mlir deleted file mode 100644 index 3276cb71090..00000000000 --- a/tensorflow/compiler/mlir/xla/tests/translate/call_multiple_results.mlir +++ /dev/null @@ -1,24 +0,0 @@ -// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s - -func @main(%arg0: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>) { - %0:2 = call @callee(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>) - return %0#0, %0#1 : tensor<4xi32>, tensor<4xi32> -} - -func @callee(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>) { - %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - %1 = "xla_hlo.mul"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - return %0, %1 : tensor<4xi32>, tensor<4xi32> -} - -// Get name of callee computation -// CHECK: [[CALLEE:%.*]] ({{.*}}) -> ({{.*}}) { - -// CHECK-LABEL: ENTRY -// CHECK-SAME: [[MAIN:%.*]] ([[ARG:.*]]: s32[4]) -> (s32[4], s32[4]) { -// CHECK: %[[ARG]] = s32[4] parameter(0) -// CHECK: [[CALL_OUT:%.*]] = (s32[4], s32[4]) call(s32[4] %[[ARG]], s32[4] %[[ARG]]), to_apply=[[CALLEE]] -// CHECK: [[OUT_0:%.*]] = s32[4] get-tuple-element((s32[4], s32[4]) [[CALL_OUT]]), index=0 -// CHECK: [[OUT_1:%.*]] = s32[4] get-tuple-element((s32[4], s32[4]) [[CALL_OUT]]), index=1 -// CHECK-LABEL: ROOT -// CHECK-SAME: (s32[4], s32[4]) tuple(s32[4] [[OUT_0]], s32[4] [[OUT_1]]) diff --git a/tensorflow/compiler/mlir/xla/tests/translate/concatenate.mlir b/tensorflow/compiler/mlir/xla/tests/translate/concatenate.mlir deleted file mode 100644 index 593c2e2f4e6..00000000000 --- a/tensorflow/compiler/mlir/xla/tests/translate/concatenate.mlir +++ /dev/null @@ -1,17 +0,0 @@ -// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s - -func @main(%arg0 : tensor<5x2xf32>, - %arg1 : tensor<5x5xf32>, - %arg2 : tensor<5x7xf32>) -> tensor<5x14xf32> { - %result = "xla_hlo.concatenate"(%arg0, %arg1, %arg2) { - dimension = 1 : i64 - } : (tensor<5x2xf32>, tensor<5x5xf32>, tensor<5x7xf32>) -> tensor<5x14xf32> - return %result : tensor<5x14xf32> -} - - -// CHECK-LABEL: main -// CHECK: %[[ARG0:.*]] = f32[5,2] parameter(0) -// CHECK: %[[ARG1:.*]] = f32[5,5] parameter(1) -// CHECK: %[[ARG2:.*]] = f32[5,7] parameter(2) -// CHECK: ROOT %[[RESULT:.*]] = f32[5,14] concatenate(f32[5,2] %[[ARG0]], f32[5,5] %[[ARG1]], f32[5,7] %[[ARG2]]), dimensions={1} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/conditional.mlir b/tensorflow/compiler/mlir/xla/tests/translate/conditional.mlir index 1eb1e5ca7a5..e69d677a8cc 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/conditional.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/conditional.mlir @@ 
-42,13 +42,13 @@ func @main(%arg0: tensor) -> tuple> { // CHECK: %[[VAL3:.+]] = (f32[]) conditional(pred[] %[[VAL1]], (f32[]) %[[VAL2]], (f32[]) %[[VAL2]]), true_computation=[[R0]], false_computation=[[R1]] %2 = "xla_hlo.conditional"(%0, %1, %1) ( { - ^bb0(%arg1: tuple>): // no predecessors + ^bb0(%arg1: tuple>): %6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple>) -> tensor %7 = "xla_hlo.log"(%6) : (tensor) -> tensor %8 = "xla_hlo.tuple"(%7) : (tensor) -> tuple> "xla_hlo.return"(%8) : (tuple>) -> () }, { - ^bb0(%arg1: tuple>): // no predecessors + ^bb0(%arg1: tuple>): %6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple>) -> tensor %7 = "xla_hlo.exp"(%6) : (tensor) -> tensor %8 = "xla_hlo.tuple"(%7) : (tensor) -> tuple> diff --git a/tensorflow/compiler/mlir/xla/tests/translate/const.mlir b/tensorflow/compiler/mlir/xla/tests/translate/const.mlir deleted file mode 100644 index 42d9c5dc963..00000000000 --- a/tensorflow/compiler/mlir/xla/tests/translate/const.mlir +++ /dev/null @@ -1,30 +0,0 @@ -// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s --dump-input-on-failure - -// CHECK-LABEL: ENTRY %main -func @main() -> tensor<2x2x1x1xf32> { - // CHECK: constant.{{.*}} = s64[] constant(1) - %cst = constant dense<1> : tensor - // CHECK: constant.{{.*}} = f32[2,2,1,1] - // CHECK-SAME: { { /*i0=0*/ { /*i1=0*/ {1} }, { /*i1=1*/ {2} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {4} } } } - %cst_0 = constant dense< - [[[[1.000000e+00]], [[2.000000e+00]]], [[[3.000000e+00]], [[4.000000e+00]]]] - > : tensor<2x2x1x1xf32> - - // CHECK: s32[1] constant({1}) - %cst_1 = constant dense<1> : tensor<1xi32> - - // CHECK: %[[C:.*]] = s32[] constant(1) - // CHECK: s32[10] broadcast(s32[] %[[C]]) - %cst_2 = constant dense<1> : tensor<10xi32> - - // CHECK: s32[4] constant({1, 2, 3, 4}) - %cst_3 = constant dense<[1, 2, 3, 4]> : tensor<4xi32> - - // CHECK: s32[2,2] constant({ { 1, 2 }, { 3, 4 } }) - %cst_4 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> - - // CHECK: s32[2,2] constant({ { 3, 2 }, { 1, 4 } }) - %cst_5 = constant dense<[[3, 2], [1, 4]]> : tensor<2x2xi32> - - return %cst_0 : tensor<2x2x1x1xf32> -} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/conv.mlir b/tensorflow/compiler/mlir/xla/tests/translate/conv.mlir deleted file mode 100644 index 5cdc65b49af..00000000000 --- a/tensorflow/compiler/mlir/xla/tests/translate/conv.mlir +++ /dev/null @@ -1,31 +0,0 @@ -// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s - -func @main(%arg0 : tensor<100x26x26x32xf32>, %arg1 : tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> { - %result = "xla_hlo.conv"(%arg0, %arg1) { - batch_group_count = 1 : i64, - dimension_numbers = { - input_batch_dimension = 0 : i64, - input_feature_dimension = 3 : i64, - input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, - kernel_input_feature_dimension = 3 : i64, - kernel_output_feature_dimension = 2 : i64, - kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, - output_batch_dimension = 0 : i64, - output_feature_dimension = 3 : i64, - output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64> - }, - feature_group_count = 1 : i64, - lhs_dilation = dense<1> : tensor<2xi64>, - padding = dense<2> : tensor<2x2xi64>, - rhs_dilation = dense<1> : tensor<2xi64>, - window_strides = dense<1> : tensor<2xi64> - } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> - return %result : tensor<100x28x28x1xf32> -} - -// CHECK-LABEL: main -// CHECK: %[[ARG0:.*]] = f32[100,26,26,32] parameter(0) -// 
CHECK: %[[ARG1:.*]] = f32[3,3,1,32] parameter(1) -// CHECK: ROOT %[[RESULT:.*]] = f32[100,28,28,1] convolution(f32[100,26,26,32] %[[ARG0]], f32[3,3,1,32] %[[ARG1]]), -// CHECK-SAME: window={size=3x3 pad=2_2x2_2}, -// CHECK-SAME: dim_labels=b01f_01oi->b01f diff --git a/tensorflow/compiler/mlir/xla/tests/translate/convert.mlir b/tensorflow/compiler/mlir/xla/tests/translate/convert.mlir deleted file mode 100644 index dd839df38b2..00000000000 --- a/tensorflow/compiler/mlir/xla/tests/translate/convert.mlir +++ /dev/null @@ -1,10 +0,0 @@ -// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s - -func @main(%arg0: tensor<2xi32>) -> tensor<2xf32> { - %0 = "xla_hlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32> - return %0 : tensor<2xf32> -} - -// CHECK: ENTRY %main -// CHECK: %[[ARG:.*]] = s32[2] parameter(0) -// CHECK: ROOT %[[RESULT:.*]] = f32[2] convert(s32[2] %[[ARG]]) diff --git a/tensorflow/compiler/mlir/xla/tests/translate/copy.mlir b/tensorflow/compiler/mlir/xla/tests/translate/copy.mlir deleted file mode 100644 index f6e5ef8fd98..00000000000 --- a/tensorflow/compiler/mlir/xla/tests/translate/copy.mlir +++ /dev/null @@ -1,10 +0,0 @@ -// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s - -func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { - %0 = "xla_hlo.copy"(%arg0) : (tensor<2xi32>) -> tensor<2xi32> - return %0 : tensor<2xi32> -} - -// CHECK: ENTRY %main -// CHECK: [[ARG:%.*]] = s32[2] parameter(0) -// CHECK: ROOT %[[RESULT:.*]] = s32[2] copy(s32[2] [[ARG]]) diff --git a/tensorflow/compiler/mlir/xla/tests/translate/cross_replica_sum.mlir b/tensorflow/compiler/mlir/xla/tests/translate/cross_replica_sum.mlir deleted file mode 100644 index 2e094c76516..00000000000 --- a/tensorflow/compiler/mlir/xla/tests/translate/cross_replica_sum.mlir +++ /dev/null @@ -1,16 +0,0 @@ -// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s - -func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> { - %0 = xla_hlo.constant dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi32> - %1 = "xla_hlo.cross-replica-sum"(%arg0) {replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>} : (tensor<10xf32>) -> tensor<10xf32> - return %1 : tensor<10xf32> -} - -// CHECK: %[[SUM_COMPUTATION:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> f32[] -// CHECK: ROOT %[[RESULT:.*]] = f32[] add(f32[] %[[ARG0]], f32[] %[[ARG1]]) - -// CHECK: ENTRY %main -// CHECK: %[[ARG0:.*]] = f32[10] parameter(0) -// CHECK: ROOT %[[RESULT:.*]] = f32[10] all-reduce(f32[10] %[[ARG0]]) -// CHECK-SAME: replica_groups={{[{][{]}}0,2,4,6},{1,3,5,7{{[}][}]}} -// CHECK-SAME: to_apply=%[[SUM_COMPUTATION]] diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir new file mode 100644 index 00000000000..ffcc1cc9df3 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir @@ -0,0 +1,640 @@ +// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s + +// CHECK-LABEL: HloModule +func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> { + %0 = "xla_hlo.all_reduce"(%arg0) ({ + // Perform max reduction inside the region + ^bb0(%lhs: tensor, %rhs: tensor): + %max = xla_hlo.max %lhs, %rhs : tensor + "xla_hlo.return"(%max) : (tensor) -> () + }) + { + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, + channel_id = { + handle = 5 : i64, + type = 2 : i64 + } + } : (tensor<10xf32>) -> tensor<10xf32> + return %0 : tensor<10xf32> +} + +// CHECK: %[[COMPUTATION:.*]] ({{.*}}: f32[], {{.*}}: 
f32[]) -> f32[] +// CHECK-LABEL: ENTRY +// CHECK: %[[ARG0:.*]] = f32[10] parameter(0) +// CHECK: ROOT %[[RESULT:.*]] = f32[10] all-reduce(f32[10] %[[ARG0]]) +// CHECK-SAME: channel_id=5 +// CHECK-SAME: replica_groups={{[{][{]}}0,2,4,6},{1,3,5,7{{[}][}]}} +// CHECK-SAME: to_apply=%[[COMPUTATION]] + +// ----- + +// CHECK-LABEL: HloModule +func @main(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %mean: tensor<2xf32>, %variance: tensor<2xf32>, %grad_output: tensor<2x2x2x2xf32>) -> tuple, tensor<2xf32>, tensor<2xf32>> { + %0 = "xla_hlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) {epsilon = 0.001 : f32, feature_index = 0 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2x2x2x2xf32>) -> tuple, tensor<2xf32>, tensor<2xf32>> + return %0 : tuple, tensor<2xf32>, tensor<2xf32>> +} + +// CHECK-LABEL: ENTRY +// CHECK: [[VAL_1:%.*]] = f32[2,2,2,2] parameter(0) +// CHECK: [[VAL_2:%.*]] = f32[2] parameter(1) +// CHECK: [[VAL_3:%.*]] = f32[2] parameter(2) +// CHECK: [[VAL_4:%.*]] = f32[2] parameter(3) +// CHECK: [[VAL_5:%.*]] = f32[2,2,2,2] parameter(4) +// CHECK-LABEL: ROOT +// CHECK-SAME: (f32[2,2,2,2], f32[2], f32[2]) batch-norm-grad(f32[2,2,2,2] [[VAL_1]], f32[2] [[VAL_2]], f32[2] [[VAL_3]], f32[2] [[VAL_4]], f32[2,2,2,2] [[VAL_5]]), epsilon=0.001, feature_index=0 + +// ----- + +// CHECK-LABEL: HloModule +func @main(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %offset: tensor<2xf32>) -> tuple, tensor<2xf32>, tensor<2xf32>> { + %0 = "xla_hlo.batch_norm_training" (%input, %scale, %offset) {epsilon = 0.001 : f32, feature_index = 3 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -> tuple, tensor<2xf32>, tensor<2xf32>> + return %0 : tuple, tensor<2xf32>, tensor<2xf32>> +} + +// CHECK-LABEL: ENTRY +// CHECK: [[VAL_1:%.*]] = f32[2,2,2,2] parameter(0) +// CHECK: [[VAL_2:%.*]] = f32[2] parameter(1) +// CHECK: [[VAL_3:%.*]] = f32[2] parameter(2) +// CHECK-LABEL: ROOT +// CHECK-SAME: (f32[2,2,2,2], f32[2], f32[2]) batch-norm-training(f32[2,2,2,2] [[VAL_1]], f32[2] [[VAL_2]], f32[2] [[VAL_3]]), epsilon=0.001, feature_index=3 + +// ----- + +// CHECK-LABEL: HloModule +func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) { + // CHECK: [[VAL_1:%.*]] = s32[4] parameter(0) + // CHECK: [[VAL_2:%.*]] = s32[4] parameter(1) + // CHECK: [[ATAN2:%.*]] = s32[4] atan2(s32[4] [[VAL_1]], s32[4] [[VAL_2]]) + %0 = xla_hlo.atan2 %arg0, %arg1 : tensor<4xi32> + + // CHECK: [[SHL:%.*]] = s32[4] shift-left(s32[4] [[VAL_1]], s32[4] [[VAL_2]]) + %1 = xla_hlo.shift_left %arg0, %arg1 : tensor<4xi32> + + // CHECK: [[SHRA:%.*]] = s32[4] shift-right-arithmetic(s32[4] [[VAL_1]], s32[4] [[VAL_2]]) + %2 = xla_hlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32> + + // CHECK: [[SHRL:%.*]] = s32[4] shift-right-logical(s32[4] [[VAL_1]], s32[4] [[VAL_2]]) + %3 = xla_hlo.shift_right_logical %arg0, %arg1 : tensor<4xi32> + + // CHECK-LABEL: ROOT + // CHECK-SAME: [[VAL_7:%.*]] = (s32[4], s32[4], s32[4], s32[4]) tuple(s32[4] [[ATAN2]], s32[4] [[SHL]], s32[4] [[SHRA]], s32[4] [[SHRL]]) + return %0, %1, %2, %3 : tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32> +} + +// ----- + +// CHECK-LABEL: HloModule +func @main(%arg0: tensor<1x4xi32>, %arg1: tensor<2x4xi32>, %arg2: tensor<2x3x4xi32>) -> tensor<2x3x4xi32> { + // Same rank degenerate broadcast + // CHECK: [[ARG_0:%.*]] = s32[1,4] parameter(0) + // CHECK-NEXT: [[RESHAPE_1:%.*]] = s32[4] reshape(s32[1,4] [[ARG_0]]) + // CHECK-NEXT: 
[[BROADCAST_1:%.*]] = s32[2,4] broadcast(s32[4] [[RESHAPE_1]]) + // CHECK-NEXT: [[ARG_1:%.*]] = s32[2,4] parameter(1) + // CHECK-NEXT: s32[2,4] add(s32[2,4] [[BROADCAST_1]], s32[2,4] [[ARG_1]]) + %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<1x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> + + // Broadcast up rank + // CHECK-NEXT: [[BROADCAST_2:%.*]] = s32[2,3,4] broadcast(s32[2,4] [[ARG_1]]), dimensions={0,2} + // CHECK-NEXT: [[ARG_2:%.*]] = s32[2,3,4] parameter(2) + // CHECK-NEXT: s32[2,3,4] add(s32[2,3,4] [[BROADCAST_2]], s32[2,3,4] [[ARG_2]]) + %1 = "xla_hlo.add"(%arg1, %arg2) {broadcast_dimensions = dense<[0,2]> : tensor<2xi64>} : (tensor<2x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32> + + // Broadcast up rank + degenerate broadcast + // CHECK-NEXT: [[BROADCAST_3:%.*]] = s32[2,1,4] broadcast(s32[1,4] [[ARG_0]]), dimensions={1,2} + // CHECK-NEXT: [[RESHAPE_2:%.*]] = s32[2,4] reshape(s32[2,1,4] [[BROADCAST_3]]) + // CHECK-NEXT: [[BROADCAST_4:%.*]] = s32[2,3,4] broadcast(s32[2,4] [[RESHAPE_2]]), dimensions={0,2} + // CHECK-LABEL: ROOT + // CHECK-SAME: s32[2,3,4] add(s32[2,3,4] [[BROADCAST_4]], s32[2,3,4] [[ARG_2]]) + %2 = "xla_hlo.add"(%arg0, %arg2) {broadcast_dimensions = dense<[1,2]> : tensor<2xi64>} : (tensor<1x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32> + return %2 : tensor<2x3x4xi32> +} + +// ----- + +// CHECK-LABEL: HloModule +func @main(%arg0: tensor<4xi32>) -> tensor<1x2x3x4xi32> { + // CHECK: [[ARG:%.*]] = s32[4] parameter(0) + // CHECK-NEXT: ROOT %broadcast.2 = s32[1,2,3,4] broadcast(s32[4] [[ARG]]), dimensions={3} + %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[1,2,3]> : tensor<3xi64>} : (tensor<4xi32>) -> tensor<1x2x3x4xi32> + return %0 : tensor<1x2x3x4xi32> +} + +// ----- + +// CHECK-LABEL: HloModule +func @main(%arg0: tensor<1xf32>) -> tensor<1x10xf32> { + %result = "xla_hlo.broadcast_in_dim"(%arg0) { + broadcast_dimensions = dense<0> : tensor<1xi64> + } : (tensor<1xf32>) -> tensor<1x10xf32> + return %result : tensor<1x10xf32> +} + +// CHECK-LABEL: ENTRY +// CHECK: [[ARG:%.*]] = f32[1] parameter(0) +// CHECK: ROOT %broadcast.2 = f32[1,10] broadcast(f32[1] [[ARG]]), dimensions={0} + +// ----- + +// CHECK-LABEL: HloModule +func @main(%arg0: tensor<4xi32>) -> tensor<4xi32> { + %0 = call @callee(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %1 = call @callee(%0, %0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %1 : tensor<4xi32> +} + +func @callee(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %0 : tensor<4xi32> +} + +// CHECK: [[CALLEE_1:%.*]] ([[ARG_1:.*]]: s32[4], [[ARG_2:.*]]: s32[4]) -> s32[4] { +// CHECK: %[[ARG_1]] = s32[4] parameter(0) +// CHECK: %[[ARG_2]] = s32[4] parameter(1) +// CHECK-LABEL: ROOT +// CHECK-SAME: s32[4] add(s32[4] %[[ARG_1]], s32[4] %[[ARG_2]]) + +// CHECK: [[CALLEE_2:%.*]] ([[ARG_3:.*]]: s32[4], [[ARG_4:.*]]: s32[4]) -> s32[4] { +// CHECK: %[[ARG_3]] = s32[4] parameter(0) +// CHECK: %[[ARG_4]] = s32[4] parameter(1) +// CHECK-LABEL: ROOT +// CHECK-SAME: s32[4] add(s32[4] %[[ARG_3]], s32[4] %[[ARG_4]]) + +// CHECK: ENTRY [[MAIN:%.*]] ([[ARG:.*]]: s32[4]) -> s32[4] { +// CHECK: %[[ARG]] = s32[4] parameter(0) +// CHECK: [[CALL_OUT:%.*]] = s32[4] call(s32[4] %[[ARG]], s32[4] %[[ARG]]), to_apply=[[CALLEE_1]] +// CHECK-LABEL: ROOT +// CHECK-SAME: s32[4] call(s32[4] [[CALL_OUT]], s32[4] [[CALL_OUT]]), to_apply=[[CALLEE_2]] + +// ----- + +// CHECK-LABEL: HloModule +func @main(%arg0: tensor<4xi32>) -> 
(tensor<4xi32>, tensor<4xi32>) { + %0:2 = call @callee(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>) + return %0#0, %0#1 : tensor<4xi32>, tensor<4xi32> +} + +func @callee(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>) { + %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %1 = "xla_hlo.mul"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %0, %1 : tensor<4xi32>, tensor<4xi32> +} + +// Get name of callee computation +// CHECK: [[CALLEE:%.*]] ({{.*}}) -> ({{.*}}) { + +// CHECK-LABEL: ENTRY +// CHECK-SAME: [[MAIN:%.*]] ([[ARG:.*]]: s32[4]) -> (s32[4], s32[4]) { +// CHECK: %[[ARG]] = s32[4] parameter(0) +// CHECK: [[CALL_OUT:%.*]] = (s32[4], s32[4]) call(s32[4] %[[ARG]], s32[4] %[[ARG]]), to_apply=[[CALLEE]] +// CHECK: [[OUT_0:%.*]] = s32[4] get-tuple-element((s32[4], s32[4]) [[CALL_OUT]]), index=0 +// CHECK: [[OUT_1:%.*]] = s32[4] get-tuple-element((s32[4], s32[4]) [[CALL_OUT]]), index=1 +// CHECK-LABEL: ROOT +// CHECK-SAME: (s32[4], s32[4]) tuple(s32[4] [[OUT_0]], s32[4] [[OUT_1]]) + +// ----- + +// CHECK-LABEL: HloModule +func @main(%arg0 : tensor<5x2xf32>, + %arg1 : tensor<5x5xf32>, + %arg2 : tensor<5x7xf32>) -> tensor<5x14xf32> { + %result = "xla_hlo.concatenate"(%arg0, %arg1, %arg2) { + dimension = 1 : i64 + } : (tensor<5x2xf32>, tensor<5x5xf32>, tensor<5x7xf32>) -> tensor<5x14xf32> + return %result : tensor<5x14xf32> +} + +// CHECK-LABEL: ENTRY +// CHECK: %[[ARG0:.*]] = f32[5,2] parameter(0) +// CHECK: %[[ARG1:.*]] = f32[5,5] parameter(1) +// CHECK: %[[ARG2:.*]] = f32[5,7] parameter(2) +// CHECK: ROOT %[[RESULT:.*]] = f32[5,14] concatenate(f32[5,2] %[[ARG0]], f32[5,5] %[[ARG1]], f32[5,7] %[[ARG2]]), dimensions={1} + +// ----- + +// CHECK-LABEL: HloModule +func @main() -> tensor<2x2x1x1xf32> { + // CHECK: constant.{{.*}} = s64[] constant(1) + %cst = constant dense<1> : tensor + // CHECK: constant.{{.*}} = f32[2,2,1,1] + // CHECK-SAME: { { /*i0=0*/ { /*i1=0*/ {1} }, { /*i1=1*/ {2} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {4} } } } + %cst_0 = constant dense< + [[[[1.000000e+00]], [[2.000000e+00]]], [[[3.000000e+00]], [[4.000000e+00]]]] + > : tensor<2x2x1x1xf32> + + // CHECK: s32[1] constant({1}) + %cst_1 = constant dense<1> : tensor<1xi32> + + // CHECK: %[[C:.*]] = s32[] constant(1) + // CHECK: s32[10] broadcast(s32[] %[[C]]) + %cst_2 = constant dense<1> : tensor<10xi32> + + // CHECK: s32[4] constant({1, 2, 3, 4}) + %cst_3 = constant dense<[1, 2, 3, 4]> : tensor<4xi32> + + // CHECK: s32[2,2] constant({ { 1, 2 }, { 3, 4 } }) + %cst_4 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> + + // CHECK: s32[2,2] constant({ { 3, 2 }, { 1, 4 } }) + %cst_5 = constant dense<[[3, 2], [1, 4]]> : tensor<2x2xi32> + + return %cst_0 : tensor<2x2x1x1xf32> +} + +// ----- + +// CHECK-LABEL: HloModule +func @main(%arg0 : tensor<100x26x26x32xf32>, %arg1 : tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> { + %result = "xla_hlo.conv"(%arg0, %arg1) { + batch_group_count = 1 : i64, + dimension_numbers = { + input_batch_dimension = 0 : i64, + input_feature_dimension = 3 : i64, + input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, + kernel_input_feature_dimension = 3 : i64, + kernel_output_feature_dimension = 2 : i64, + kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, + output_batch_dimension = 0 : i64, + output_feature_dimension = 3 : i64, + output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64> + }, + feature_group_count = 1 : i64, + lhs_dilation = dense<1> : 
tensor<2xi64>, + padding = dense<2> : tensor<2x2xi64>, + rhs_dilation = dense<1> : tensor<2xi64>, + window_strides = dense<1> : tensor<2xi64> + } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> + return %result : tensor<100x28x28x1xf32> +} + +// CHECK-LABEL: ENTRY +// CHECK: %[[ARG0:.*]] = f32[100,26,26,32] parameter(0) +// CHECK: %[[ARG1:.*]] = f32[3,3,1,32] parameter(1) +// CHECK: ROOT %[[RESULT:.*]] = f32[100,28,28,1] convolution(f32[100,26,26,32] %[[ARG0]], f32[3,3,1,32] %[[ARG1]]), +// CHECK-SAME: window={size=3x3 pad=2_2x2_2}, +// CHECK-SAME: dim_labels=b01f_01oi->b01f + +// ----- + +// CHECK-LABEL: HloModule +func @main(%arg0: tensor<2xi32>) -> tensor<2xf32> { + %0 = "xla_hlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + +// CHECK-LABEL: ENTRY +// CHECK: %[[ARG:.*]] = s32[2] parameter(0) +// CHECK: ROOT %[[RESULT:.*]] = f32[2] convert(s32[2] %[[ARG]]) + +// ----- + +// CHECK-LABEL: HloModule +func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { + %0 = "xla_hlo.copy"(%arg0) : (tensor<2xi32>) -> tensor<2xi32> + return %0 : tensor<2xi32> +} + +// CHECK-LABEL: ENTRY +// CHECK: [[ARG:%.*]] = s32[2] parameter(0) +// CHECK: ROOT %[[RESULT:.*]] = s32[2] copy(s32[2] [[ARG]]) + +// ----- + +// CHECK-LABEL: HloModule +func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> { + %0 = xla_hlo.constant dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi32> + %1 = "xla_hlo.cross-replica-sum"(%arg0) {replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>} : (tensor<10xf32>) -> tensor<10xf32> + return %1 : tensor<10xf32> +} + +// CHECK: %[[SUM_COMPUTATION:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> f32[] +// CHECK: ROOT %[[RESULT:.*]] = f32[] add(f32[] %[[ARG0]], f32[] %[[ARG1]]) + +// CHECK-LABEL: ENTRY +// CHECK: %[[ARG0:.*]] = f32[10] parameter(0) +// CHECK: ROOT %[[RESULT:.*]] = f32[10] all-reduce(f32[10] %[[ARG0]]) +// CHECK-SAME: replica_groups={{[{][{]}}0,2,4,6},{1,3,5,7{{[}][}]}} +// CHECK-SAME: to_apply=%[[SUM_COMPUTATION]] + +// ----- + +// CHECK-LABEL: HloModule +func @main(%arg0: tensor<3x4xi32>, %arg1: tensor<4x5xi32>) -> tensor<3x5xi32> { + // Simple einsum is lowered to HLO dot op. 
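+  // With the equation "ab,bc->ac" used below, the shared label "b" is the
+  // contracted axis (dimension 1 of the lhs, dimension 0 of the rhs), which is
+  // what the lhs_contracting_dims/rhs_contracting_dims on the emitted dot encode.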
+ // CHECK: dot(s32[3,4] %{{.*}}, s32[4,5] %{{.*}}), lhs_contracting_dims={1}, rhs_contracting_dims={0} + %0 = "xla_hlo.einsum"(%arg0, %arg1) {einsum_config = "ab,bc->ac"} : (tensor<3x4xi32>, tensor<4x5xi32>) -> tensor<3x5xi32> + return %0 : tensor<3x5xi32> +} + +// ----- + +// CHECK-LABEL: HloModule +func @main(%arg: tensor<4x2xf32>) -> tensor { + %0 = "xla_hlo.get_dimension_size"(%arg) {dimension = 1 : i32} : (tensor<4x2xf32>) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: ENTRY +// CHECK: [[ARG:%.*]] = f32[4,2] parameter(0) +// CHECK: s32[] get-dimension-size(f32[4,2] [[ARG]]), dimensions={1} + +// ----- + +// CHECK-LABEL: HloModule +func @main(%arg0: tuple, tensor>) -> tensor { + %0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple, tensor>) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: ENTRY +// CHECK: %[[ARG0:.*]] = (f32[], s32[]) parameter(0) +// CHECK: ROOT %[[RESULT:.*]] = f32[] get-tuple-element((f32[], s32[]) %[[ARG0]]), index=0 + +// ----- + +// CHECK-LABEL: HloModule +func @main() -> tensor<1x10xf32> { + %result = "xla_hlo.iota"() { + iota_dimension = 1 : i64 + } : () -> tensor<1x10xf32> + return %result : tensor<1x10xf32> +} + +// CHECK-LABEL: ENTRY +// CHECK: ROOT %[[RESULT:.*]] = f32[1,10] iota(), iota_dimension=1 + +// ----- + +// CHECK-LABEL: HloModule +func @main(%arg: tensor<4x6xf32>, %pad: tensor) -> tensor<13x19xf32> { + %0 = "xla_hlo.pad"(%arg, %pad) {edge_padding_high = dense<[4,5]> : tensor<2xi64>, edge_padding_low = dense<[2,3]> : tensor<2xi64>, interior_padding = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>, tensor) -> tensor<13x19xf32> + return %0 : tensor<13x19xf32> +} + +// CHECK-LABEL: ENTRY +// CHECK: [[ARG:%.*]] = f32[4,6] parameter(0) +// CHECK: [[PADDING_VAL:%.*]] = f32[] parameter(1) +// CHECK-LABEL: ROOT +// CHECK-SAME: f32[13,19] pad(f32[4,6] [[ARG]], f32[] [[PADDING_VAL]]), padding=2_4_1x3_5_1 + +// ----- + +// CHECK-LABEL: HloModule +func @main(%arg0 : tensor<1x10xf32>, %arg1 : tensor<1x10xi32>, %arg2 : tensor, %arg3 : tensor) -> (tensor<1xf32>, tensor<1xi32>) { + %result0, %result1 = "xla_hlo.reduce"(%arg0, %arg1, %arg2, %arg3) ( { + ^bb0(%fa: tensor, %ia : tensor, %fb: tensor, %ib: tensor): // no predecessors + %fmax = "xla_hlo.max"(%fa, %fb) {} : (tensor, tensor) -> tensor + %imax = "xla_hlo.max"(%ia, %ib) {} : (tensor, tensor) -> tensor + "xla_hlo.return"(%fmax, %imax) : (tensor, tensor) -> () + }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<1x10xi32>, tensor, tensor) -> (tensor<1xf32>, tensor<1xi32>) + return %result0, %result1 : tensor<1xf32>, tensor<1xi32> +} + +// CHECK: %[[REGION:region_[0-9]+]] +// CHECK-SAME: ([[ARG_FA:.*]]: f32[], [[ARG_IA:.*]]: s32[], [[ARG_FB:.*]]: f32[], [[ARG_IB:.*]]: s32[]) -> (f32[], s32[]) +// CHECK: %[[FMAX:.*]] = f32[] maximum(f32[] %[[ARG_FA]], f32[] %[[ARG_FB]]) +// CHECK: %[[IMAX:.*]] = s32[] maximum(s32[] %[[ARG_IA]], s32[] %[[ARG_IB]]) +// CHECK: ROOT %[[RESULT_REGION:.*]] = (f32[], s32[]) tuple(f32[] %[[FMAX]], s32[] %[[IMAX]]) + +// CHECK-LABEL: ENTRY +// CHECK-SAME: ([[ARG0:.*]]: f32[1,10], [[ARG1:.*]]: s32[1,10], [[ARG2:.*]]: f32[], [[ARG3:.*]]: s32[]) -> (f32[1], s32[1]) +// CHECK: %[[RESULT:.*]] = (f32[1], s32[1]) reduce(f32[1,10] %[[ARG0]], s32[1,10] %[[ARG1]], f32[] %[[ARG2]], s32[] %[[ARG3]]), dimensions={1}, to_apply=%[[REGION]] +// CHECK: %[[RESULT0:.*]] = f32[1] get-tuple-element((f32[1], s32[1]) %[[RESULT]]), index=0 +// CHECK: %[[RESULT1:.*]] = s32[1] get-tuple-element((f32[1], s32[1]) %[[RESULT]]), index=1 +// CHECK: ROOT %[[RESULT:.*]] = 
(f32[1], s32[1]) tuple(f32[1] %[[RESULT0]], s32[1] %[[RESULT1]]) + +// ----- + +// CHECK-LABEL: HloModule +func @main(%arg0: tensor<2x17x31x7xi32>) -> tensor<2x3x5x7xi32> { + %0 = xla_hlo.constant dense<-2147483648> : tensor + %1 = "xla_hlo.reduce_window"(%arg0, %0) ( { + ^bb0(%arg1: tensor, %arg2: tensor): // no predecessors + %2 = xla_hlo.max %arg1, %arg2 : tensor + "xla_hlo.return"(%2) : (tensor) -> () + }) { + window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, + window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>, + padding = dense<[[0, 0], [2, 0], [0, 2], [0, 0]]> : tensor<4x2xi64>, + base_dilations = dense<[1, 1, 1, 1]> : tensor<4xi64>, + window_dilations = dense<[1, 2, 2, 1]> : tensor<4xi64> + } : (tensor<2x17x31x7xi32>, tensor) -> tensor<2x3x5x7xi32> + return %1 : tensor<2x3x5x7xi32> +} + +// CHECK: %[[MAX_COMPUTATION:.*]] ([[ARG0:.*]]: s32[], [[ARG1:.*]]: s32[]) -> s32[] +// CHECK: ROOT %[[RESULT:.*]] = s32[] maximum(s32[] %[[ARG0]], s32[] %[[ARG1]]) + +// CHECK-LABEL: ENTRY +// CHECK-DAG: %[[ARG0:.*]] = s32[2,17,31,7] parameter(0) +// CHECK-DAG: %[[INIT:.*]] = s32[] constant(-2147483648) +// CHECK: ROOT %[[RESULT:.*]] = s32[2,5,8,7] reduce-window(s32[2,17,31,7] %[[ARG0]], s32[] %constant.2), +// CHECK-SAME: window={size=1x2x2x1 stride=1x4x4x1 pad=0_0x2_0x0_2x0_0 rhs_dilate=1x2x2x1}, +// CHECK-SAME: to_apply=%[[MAX_COMPUTATION]] + +// ----- + +// CHECK-LABEL: HloModule +func @main(%arg0: tensor<2xf32>) -> tensor<1x2xf32> { + %0 = "xla_hlo.reshape"(%arg0) : (tensor<2xf32>) -> tensor<1x2xf32> + return %0 : tensor<1x2xf32> +} + +// CHECK-LABEL: ENTRY +// CHECK: %[[ARG0:.*]] = f32[2] parameter(0) +// CHECK: ROOT %[[RESULT:.*]] = f32[1,2] reshape(f32[2] %[[ARG0]]) + +// ----- + +// CHECK-LABEL: HloModule +func @main(%arg0 : tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> { + %result = "xla_hlo.reverse"(%arg0) { + dimensions = dense<[1,2]> : tensor<2xi64> + } : (tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> + return %result : tensor<10x11x12x13xf32> +} + +// CHECK-LABEL: ENTRY +// CHECK: %[[ARG0:.*]] = f32[10,11,12,13] parameter(0) +// CHECK: ROOT %[[RESULT:.*]] = f32[10,11,12,13] reverse(f32[10,11,12,13] %[[ARG0]]), dimensions={1,2} + +// ----- + +// CHECK-LABEL: HloModule +func @main() -> tensor<2x3x5xf32> { + %0 = xla_hlo.constant dense<0.000000e+00> : tensor + %1 = xla_hlo.constant dense<1.000000e+00> : tensor + %2 = xla_hlo.constant dense<[2, 3, 5]> : tensor<3xi64> + %3 = "xla_hlo.rng_uniform"(%0, %1, %2) : (tensor, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> + return %3 : tensor<2x3x5xf32> +} + +// CHECK-LABEL: ENTRY +// CHECK-DAG: %[[A:.*]] = f32[] constant(0) +// CHECK-DAG: %[[B:.*]] = f32[] constant(1) +// CHECK: ROOT %[[RESULT:.*]] = f32[2,3,5] rng(f32[] %[[A]], f32[] %[[B]]), distribution=rng_uniform + +// ----- + +// CHECK-LABEL: HloModule +func @main(%input_tensor: tensor<200x100x300xf32>, %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xf32>) -> tensor<200x100x300xf32> { + %0 = "xla_hlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ + ^bb0(%lhs: tensor, %rhs: tensor): // no predecessors + %add = xla_hlo.add %lhs, %rhs : tensor + "xla_hlo.return"(%add) : (tensor) -> () + }) { + scatter_dimension_numbers = { + update_window_dims = dense<[1]> : tensor<1xi64>, + inserted_window_dims = dense<[0, 1]> : tensor<2xi64>, + scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>, + index_vector_dim = 1 : i64 + }, + indices_are_sorted = true, + unique_indices = true + } : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> 
tensor<200x100x300xf32> + return %0 : tensor<200x100x300xf32> +} + +// CHECK: [[COMPUTATION:%.*]] ({{.*}}: f32[], {{.*}}: f32[]) -> f32[] +// CHECK-LABEL: ENTRY +// CHECK: [[VAL_1:%.*]] = f32[200,100,300] parameter(0) +// CHECK: [[VAL_2:%.*]] = s32[10,2] parameter(1) +// CHECK: [[VAL_3:%.*]] = f32[10,300] parameter(2) +// CHECK-LABEL: ROOT +// CHECK-SAME: f32[200,100,300] scatter(f32[200,100,300] [[VAL_1]], s32[10,2] [[VAL_2]], f32[10,300] [[VAL_3]]), update_window_dims={1}, inserted_window_dims={0,1}, scatter_dims_to_operand_dims={0,1}, index_vector_dim=1, indices_are_sorted=true, unique_indices=true, to_apply=[[COMPUTATION]] + +// ----- + +// CHECK-LABEL: HloModule +func @main(%arg0: tensor, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> { + // CHECK: %[[ARG0:.*]] = pred[] parameter(0) + // CHECK: %[[COND:.*]] = pred[2,3] broadcast(pred[] %[[ARG0]]), dimensions={} + // CHECK: %[[ARG1:.*]] = s32[2,3] parameter(1) + // CHECK: %[[ARG2:.*]] = s32[2,3] parameter(2) + + // CHECK: ROOT %[[RES:.*]] = s32[2,3] select(pred[2,3] %[[COND]], s32[2,3] %[[ARG1]], s32[2,3] %[[ARG2]]) + %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) {name = "select.4"} : (tensor, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + return %0 : tensor<2x3xi32> +} + +// ----- + +// CHECK-LABEL: HloModule +func @main(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> { + %0 = xla_hlo.constant dense<0.000000e+00> : tensor + %1 = "xla_hlo.select_and_scatter"(%arg0, %arg1, %0) ( { + ^bb0(%arg3: tensor, %arg4: tensor): // no predecessors + %2 = "xla_hlo.compare"(%arg3, %arg4) {comparison_direction = "GE"} : (tensor, tensor) -> tensor + "xla_hlo.return"(%2) : (tensor) -> () + }, { + ^bb0(%arg3: tensor, %arg4: tensor): // no predecessors + %2 = xla_hlo.add %arg3, %arg4 : tensor + "xla_hlo.return"(%2) : (tensor) -> () + }) { + window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64> + } : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor) -> tensor<10x24x24x64xf32> + return %1 : tensor<10x24x24x64xf32> +} + +// CHECK: %[[SELECT_COMPUTATION:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> pred[] { +// CHECK: ROOT %[[RESULT:.*]] = pred[] compare(f32[] %[[ARG0]], f32[] %[[ARG1]]), direction=GE + +// CHECK: %[[SCATTER_COMPUTATION:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> f32[] { +// CHECK: ROOT %[[RESULT:.*]] = f32[] add(f32[] %[[ARG0]], f32[] %[[ARG1]]) + +// CHECK-LABEL: ENTRY +// CHECK: %[[ARG0:.*]] = f32[10,24,24,64] parameter(0) +// CHECK: %[[ARG1:.*]] = f32[10,12,12,64] parameter(1) +// CHECK: %[[INIT:.*]] = f32[] constant(0) + +// CHECK: ROOT %[[RESULT:.*]] = f32[10,24,24,64] +// CHECK-SAME: select-and-scatter(f32[10,24,24,64] %[[ARG0]], f32[10,12,12,64] %[[ARG1]], f32[] %[[INIT]]), +// CHECK-SAME: window={size=1x2x2x1 stride=1x2x2x1}, +// CHECK-SAME: select=%[[SELECT_COMPUTATION]], scatter=%[[SCATTER_COMPUTATION]] + +// ----- + +// CHECK-LABEL: HloModule +func @main(%arg: tensor<3x4xi32>) -> tensor<1x2xi32> { + %0 = "xla_hlo.slice"(%arg) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x2xi32> + return %0 : tensor<1x2xi32> +} + +// CHECK-LABEL: ENTRY +// CHECK: [[ARG:%.*]] = s32[3,4] parameter(0) +// CHECK-LABEL: ROOT +// CHECK-SAME: s32[1,2] slice(s32[3,4] [[ARG]]), slice={[1:2:1], [0:4:2]} + +// ----- + +// CHECK-LABEL: HloModule +func @main(%arg0: tensor<1x2x3x4xi32>) -> 
tensor<2x1x4x3xi32> { + // CHECK: [[ARG:%.*]] = s32[1,2,3,4] parameter(0) + + // CHECK-NEXT: ROOT %transpose.2 = s32[2,1,4,3] transpose(s32[1,2,3,4] [[ARG]]), dimensions={1,0,3,2} + %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> + return %0 : tensor<2x1x4x3xi32> +} + +// ----- + +// CHECK-LABEL: HloModule +func @main(%arg0: tensor, %arg1 : tensor) -> tuple, tensor> { + %result = "xla_hlo.tuple"(%arg0, %arg1) {} : (tensor, tensor) -> tuple, tensor> + return %result : tuple, tensor> +} + +// CHECK-LABEL: ENTRY +// CHECK: %[[ARG0:.*]] = f32[] parameter(0) +// CHECK: %[[ARG1:.*]] = s32[] parameter(1) +// CHECK: ROOT %[[RESULT:.*]] = (f32[], s32[]) tuple(f32[] %[[ARG0]], s32[] %[[ARG1]]) + +// ----- + +// CHECK-LABEL: HloModule +func @main(%arg_f32: tensor<4xf32>, %arg_i32: tensor<4xi32>) -> (tensor<4xf32>, tensor<4xf32>, tensor<4xi32>, tensor<4xi32>) { + // CHECK: [[ARG_F32:%.*]] = f32[4] parameter(0) + // CHECK: [[EXPM1:%.*]] = f32[4] exponential-minus-one(f32[4] [[ARG_F32]]) + %expm1 = "xla_hlo.exponential_minus_one"(%arg_f32) : (tensor<4xf32>) -> tensor<4xf32> + + // CHECK: [[LOG1P:%.*]] = f32[4] log-plus-one(f32[4] [[ARG_F32]]) + %log1p = "xla_hlo.log_plus_one"(%arg_f32) : (tensor<4xf32>) -> tensor<4xf32> + + // CHECK: [[ARG_I32:%.*]] = s32[4] parameter(1) + // CHECK: [[NOT:%.*]] = s32[4] not(s32[4] [[ARG_I32]]) + %not = "xla_hlo.not"(%arg_i32) : (tensor<4xi32>) -> tensor<4xi32> + + // CHECK: [[POPCNT:%.*]] = s32[4] popcnt(s32[4] [[ARG_I32]]) + %popcnt = "xla_hlo.popcnt"(%arg_i32) : (tensor<4xi32>) -> tensor<4xi32> + + return %expm1, %log1p, %not, %popcnt : tensor<4xf32>, tensor<4xf32>, tensor<4xi32>, tensor<4xi32> +} + +// ----- + +// CHECK-LABEL: HloModule +func @main(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { + // CHECK: [[VAL_1:%.*]] = pred[4] parameter(0) + // CHECK: [[VAL_2:%.*]] = pred[4] parameter(1) + %0 = xla_hlo.xor %arg0, %arg1 : tensor<4xi1> + // CHECK: ROOT [[VAL_3:%.*]] = pred[4] xor(pred[4] [[VAL_1]], pred[4] [[VAL_2]]) + return %0 : tensor<4xi1> +} + +// ----- + +// CHECK-LABEL: HloModule +func @main(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { + %0 = "xla_hlo.sort"(%input0, %input1) ( { + ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): + %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "xla_hlo.return"(%7) : (tensor) -> () + }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + return +} + +// CHECK: %[[SORT_CMP:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[], {{.*}}: s32[], {{.*}}: s32[]) -> pred[] { +// CHECK: ROOT %compare.8 = pred[] compare(f32[] %[[ARG0]], f32[] %[[ARG1]]), direction=GT + +// CHECK: ENTRY %{{.*}} ([[MAIN_ARG0:.*]]: f32[16,16], [[MAIN_ARG1:.*]]: s32[16,16]) -> (f32[16,16], s32[16,16]) { +// CHECK: ROOT %{{.*}} = (f32[16,16], s32[16,16]) sort(f32[16,16] %[[MAIN_ARG0]], s32[16,16] %[[MAIN_ARG1]]), dimensions={1}, is_stable=true, to_apply=%[[SORT_CMP]] diff --git a/tensorflow/compiler/mlir/xla/tests/translate/get_dimension_size.mlir b/tensorflow/compiler/mlir/xla/tests/translate/get_dimension_size.mlir deleted file mode 100644 index 44ff3f144f6..00000000000 --- a/tensorflow/compiler/mlir/xla/tests/translate/get_dimension_size.mlir +++ /dev/null @@ -1,10 +0,0 @@ -// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s - -func @main(%arg: tensor<4x2xf32>) -> tensor { - %0 = "xla_hlo.get_dimension_size"(%arg) 
{dimension = 1 : i32} : (tensor<4x2xf32>) -> tensor - return %0 : tensor -} - -// CHECK-LABEL: ENTRY -// CHECK: [[ARG:%.*]] = f32[4,2] parameter(0) -// CHECK: s32[] get-dimension-size(f32[4,2] [[ARG]]), dimensions={1} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/get_element_tuple.mlir b/tensorflow/compiler/mlir/xla/tests/translate/get_element_tuple.mlir deleted file mode 100644 index 8897a6fab33..00000000000 --- a/tensorflow/compiler/mlir/xla/tests/translate/get_element_tuple.mlir +++ /dev/null @@ -1,10 +0,0 @@ -// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s - -func @main(%arg0: tuple, tensor>) -> tensor { - %0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple, tensor>) -> tensor - return %0 : tensor -} - -// CHECK-LABEL: main -// CHECK: %[[ARG0:.*]] = (f32[], s32[]) parameter(0) -// CHECK: ROOT %[[RESULT:.*]] = f32[] get-tuple-element((f32[], s32[]) %[[ARG0]]), index=0 diff --git a/tensorflow/compiler/mlir/xla/tests/translate/ops.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt similarity index 100% rename from tensorflow/compiler/mlir/xla/tests/translate/ops.hlotxt rename to tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt diff --git a/tensorflow/compiler/mlir/xla/tests/translate/iota.mlir b/tensorflow/compiler/mlir/xla/tests/translate/iota.mlir deleted file mode 100644 index e7df347a734..00000000000 --- a/tensorflow/compiler/mlir/xla/tests/translate/iota.mlir +++ /dev/null @@ -1,11 +0,0 @@ -// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s - -func @main() -> tensor<1x10xf32> { - %result = "xla_hlo.iota"() { - iota_dimension = 1 : i64 - } : () -> tensor<1x10xf32> - return %result : tensor<1x10xf32> -} - -// CHECK-LABEL:main -// CHECK: ROOT %[[RESULT:.*]] = f32[1,10] iota(), iota_dimension=1 diff --git a/tensorflow/compiler/mlir/xla/tests/translate/pad.mlir b/tensorflow/compiler/mlir/xla/tests/translate/pad.mlir deleted file mode 100644 index d4fba830403..00000000000 --- a/tensorflow/compiler/mlir/xla/tests/translate/pad.mlir +++ /dev/null @@ -1,12 +0,0 @@ -// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s - -func @main(%arg: tensor<4x6xf32>, %pad: tensor) -> tensor<13x19xf32> { - %0 = "xla_hlo.pad"(%arg, %pad) {edge_padding_high = dense<[4,5]> : tensor<2xi64>, edge_padding_low = dense<[2,3]> : tensor<2xi64>, interior_padding = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>, tensor) -> tensor<13x19xf32> - return %0 : tensor<13x19xf32> -} - -// CHECK-LABEL: ENTRY -// CHECK: [[ARG:%.*]] = f32[4,6] parameter(0) -// CHECK: [[PADDING_VAL:%.*]] = f32[] parameter(1) -// CHECK-LABEL: ROOT -// CHECK-SAME: f32[13,19] pad(f32[4,6] [[ARG]], f32[] [[PADDING_VAL]]), padding=2_4_1x3_5_1 diff --git a/tensorflow/compiler/mlir/xla/tests/translate/reduce.mlir b/tensorflow/compiler/mlir/xla/tests/translate/reduce.mlir deleted file mode 100644 index db16a2219cc..00000000000 --- a/tensorflow/compiler/mlir/xla/tests/translate/reduce.mlir +++ /dev/null @@ -1,24 +0,0 @@ -// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s - -func @main(%arg0 : tensor<1x10xf32>, %arg1 : tensor<1x10xi32>, %arg2 : tensor, %arg3 : tensor) -> (tensor<1xf32>, tensor<1xi32>) { - %result0, %result1 = "xla_hlo.reduce"(%arg0, %arg1, %arg2, %arg3) ( { - ^bb0(%fa: tensor, %ia : tensor, %fb: tensor, %ib: tensor): // no predecessors - %fmax = "xla_hlo.max"(%fa, %fb) {} : (tensor, tensor) -> tensor - %imax = "xla_hlo.max"(%ia, %ib) {} : (tensor, tensor) -> tensor - "xla_hlo.return"(%fmax, %imax) : (tensor, tensor) -> () - }) 
{dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<1x10xi32>, tensor, tensor) -> (tensor<1xf32>, tensor<1xi32>) - return %result0, %result1 : tensor<1xf32>, tensor<1xi32> -} - -// CHECK: %[[REGION:region_[0-9]+]] -// CHECK-SAME: ([[ARG_FA:.*]]: f32[], [[ARG_IA:.*]]: s32[], [[ARG_FB:.*]]: f32[], [[ARG_IB:.*]]: s32[]) -> (f32[], s32[]) -// CHECK: %[[FMAX:.*]] = f32[] maximum(f32[] %[[ARG_FA]], f32[] %[[ARG_FB]]) -// CHECK: %[[IMAX:.*]] = s32[] maximum(s32[] %[[ARG_IA]], s32[] %[[ARG_IB]]) -// CHECK: ROOT %[[RESULT_REGION:.*]] = (f32[], s32[]) tuple(f32[] %[[FMAX]], s32[] %[[IMAX]]) - -// CHECK: ENTRY %main -// CHECK-SAME: ([[ARG0:.*]]: f32[1,10], [[ARG0:.*]]: s32[1,10], [[ARG0:.*]]: f32[], [[ARG0:.*]]: s32[]) -> (f32[1], s32[1]) -// CHECK: %[[RESULT:.*]] = (f32[1], s32[1]) reduce(f32[1,10] %Arg_0.1, s32[1,10] %Arg_1.2, f32[] %Arg_2.3, s32[] %Arg_3.4), dimensions={1}, to_apply=%[[REGION]] -// CHECK: %[[RESULT0:.*]] = f32[1] get-tuple-element((f32[1], s32[1]) %[[RESULT]]), index=0 -// CHECK: %[[RESULT1:.*]] = s32[1] get-tuple-element((f32[1], s32[1]) %[[RESULT]]), index=1 -// CHECK: ROOT %[[RESULT:.*]] = (f32[1], s32[1]) tuple(f32[1] %[[RESULT0]], s32[1] %[[RESULT1]]) diff --git a/tensorflow/compiler/mlir/xla/tests/translate/reduce_window.mlir b/tensorflow/compiler/mlir/xla/tests/translate/reduce_window.mlir deleted file mode 100644 index 4ef1d1a6057..00000000000 --- a/tensorflow/compiler/mlir/xla/tests/translate/reduce_window.mlir +++ /dev/null @@ -1,27 +0,0 @@ -// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s - -func @main(%arg0: tensor<2x17x31x7xi32>) -> tensor<2x3x5x7xi32> { - %0 = xla_hlo.constant dense<-2147483648> : tensor - %1 = "xla_hlo.reduce_window"(%arg0, %0) ( { - ^bb0(%arg1: tensor, %arg2: tensor): // no predecessors - %2 = xla_hlo.max %arg1, %arg2 : tensor - "xla_hlo.return"(%2) : (tensor) -> () - }) { - window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, - window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>, - padding = dense<[[0, 0], [2, 0], [0, 2], [0, 0]]> : tensor<4x2xi64>, - base_dilations = dense<[1, 1, 1, 1]> : tensor<4xi64>, - window_dilations = dense<[1, 2, 2, 1]> : tensor<4xi64> - } : (tensor<2x17x31x7xi32>, tensor) -> tensor<2x3x5x7xi32> - return %1 : tensor<2x3x5x7xi32> -} - -// CHECK: %[[MAX_COMPUTATION:.*]] ([[ARG0:.*]]: s32[], [[ARG1:.*]]: s32[]) -> s32[] -// ROOT %[[RESULT:.*]] = s32[] maximum(s32[] %[[ARG0]], s32[] %[[ARG1]]) - -// CHECK: ENTRY %main -// CHECK-DAG: %[[ARG0:.*]] = s32[2,17,31,7] parameter(0) -// CHECK-DAG: %[[INIT:.*]] = s32[] constant(-2147483648) -// CHECK: ROOT %[[RESULT:.*]] = s32[2,5,8,7] reduce-window(s32[2,17,31,7] %[[ARG0]], s32[] %constant.2), -// CHECK-SAME: window={size=1x2x2x1 stride=1x4x4x1 pad=0_0x2_0x0_2x0_0 rhs_dilate=1x2x2x1}, -// CHECK-SAME: to_apply=%[[MAX_COMPUTATION]] diff --git a/tensorflow/compiler/mlir/xla/tests/translate/reshape.mlir b/tensorflow/compiler/mlir/xla/tests/translate/reshape.mlir deleted file mode 100644 index b0bb8fedb74..00000000000 --- a/tensorflow/compiler/mlir/xla/tests/translate/reshape.mlir +++ /dev/null @@ -1,10 +0,0 @@ -// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s - -func @main(%arg0: tensor<2xf32>) -> tensor<1x2xf32> { - %0 = "xla_hlo.reshape"(%arg0) : (tensor<2xf32>) -> tensor<1x2xf32> - return %0 : tensor<1x2xf32> -} - -// CHECK: ENTRY %main -// CHECK: %[[ARG0:.*]] = f32[2] parameter(0) -// CHECK: ROOT %[[RESULT:.*]] = f32[1,2] reshape(f32[2] %[[ARG0]]) diff --git a/tensorflow/compiler/mlir/xla/tests/translate/reverse.mlir 
b/tensorflow/compiler/mlir/xla/tests/translate/reverse.mlir deleted file mode 100644 index b3393952ed6..00000000000 --- a/tensorflow/compiler/mlir/xla/tests/translate/reverse.mlir +++ /dev/null @@ -1,12 +0,0 @@ -// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s - -func @main(%arg0 : tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> { - %result = "xla_hlo.reverse"(%arg0) { - dimensions = dense<[1,2]> : tensor<2xi64> - } : (tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> - return %result : tensor<10x11x12x13xf32> -} - -// CHECK-LABEL: main -// CHECK: %[[ARG0:.*]] = f32[10,11,12,13] parameter(0) -// CHECK: ROOT %[[RESULT:.*]] = f32[10,11,12,13] reverse(f32[10,11,12,13] %[[ARG0]]), dimensions={1,2} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/rng_uniform.mlir b/tensorflow/compiler/mlir/xla/tests/translate/rng_uniform.mlir deleted file mode 100644 index 505d6b43b06..00000000000 --- a/tensorflow/compiler/mlir/xla/tests/translate/rng_uniform.mlir +++ /dev/null @@ -1,14 +0,0 @@ -// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s - -func @main() -> tensor<2x3x5xf32> { - %0 = xla_hlo.constant dense<0.000000e+00> : tensor - %1 = xla_hlo.constant dense<1.000000e+00> : tensor - %2 = xla_hlo.constant dense<[2, 3, 5]> : tensor<3xi64> - %3 = "xla_hlo.rng_uniform"(%0, %1, %2) : (tensor, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> - return %3 : tensor<2x3x5xf32> -} - -// CHECK: ENTRY %main -// CHECK-DAG: %[[A:.*]] = f32[] constant(0) -// CHECK-DAG: %[[B:.*]] = f32[] constant(1) -// CHECK: ROOT %[[RESULT:.*]] = f32[2,3,5] rng(f32[] %[[A]], f32[] %[[B]]), distribution=rng_uniform diff --git a/tensorflow/compiler/mlir/xla/tests/translate/scatter.mlir b/tensorflow/compiler/mlir/xla/tests/translate/scatter.mlir deleted file mode 100644 index 227a45bab18..00000000000 --- a/tensorflow/compiler/mlir/xla/tests/translate/scatter.mlir +++ /dev/null @@ -1,27 +0,0 @@ -// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s - -func @main(%input_tensor: tensor<200x100x300xf32>, %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xf32>) -> tensor<200x100x300xf32> { - %0 = "xla_hlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ - ^bb0(%lhs: tensor, %rhs: tensor): // no predecessors - %add = xla_hlo.add %lhs, %rhs : tensor - "xla_hlo.return"(%add) : (tensor) -> () - }) { - scatter_dimension_numbers = { - update_window_dims = dense<[1]> : tensor<1xi64>, - inserted_window_dims = dense<[0, 1]> : tensor<2xi64>, - scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>, - index_vector_dim = 1 : i64 - }, - indices_are_sorted = true, - unique_indices = true - } : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> tensor<200x100x300xf32> - return %0 : tensor<200x100x300xf32> -} - -// CHECK: [[COMPUTATION:%.*]] ({{.*}}: f32[], {{.*}}: f32[]) -> f32[] -// CHECK-LABEL: ENTRY -// CHECK: [[VAL_1:%.*]] = f32[200,100,300] parameter(0) -// CHECK: [[VAL_2:%.*]] = s32[10,2] parameter(1) -// CHECK: [[VAL_3:%.*]] = f32[10,300] parameter(2) -// CHECK-LABEL: ROOT -// CHECK-SAME: f32[200,100,300] scatter(f32[200,100,300] [[VAL_1]], s32[10,2] [[VAL_2]], f32[10,300] [[VAL_3]]), update_window_dims={1}, inserted_window_dims={0,1}, scatter_dims_to_operand_dims={0,1}, index_vector_dim=1, indices_are_sorted=true, unique_indices=true, to_apply=[[COMPUTATION]] diff --git a/tensorflow/compiler/mlir/xla/tests/translate/select.mlir b/tensorflow/compiler/mlir/xla/tests/translate/select.mlir deleted file mode 100644 index e4cc1b3babd..00000000000 --- 
a/tensorflow/compiler/mlir/xla/tests/translate/select.mlir +++ /dev/null @@ -1,14 +0,0 @@ -// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s --dump-input-on-failure - -// CHECK-LABEL: ENTRY %main -func @main(%arg0: tensor, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> { - // CHECK: %[[ARG0:.*]] = pred[] parameter(0) - // CHECK: %[[COND:.*]] = pred[2,3] broadcast(pred[] %Arg_0.1), dimensions={} - // CHECK: %[[ARG1:.*]] = s32[2,3] parameter(1) - // CHECK: %[[ARG2:.*]] = s32[2,3] parameter(2) - - // CHECK: ROOT %[[RES:.*]] = s32[2,3] select(pred[2,3] %[[COND]], s32[2,3] %[[ARG1]], s32[2,3] %[[ARG2]]) - %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) {name = "select.4"} : (tensor, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> - return %0 : tensor<2x3xi32> -} - diff --git a/tensorflow/compiler/mlir/xla/tests/translate/select_and_scatter.mlir b/tensorflow/compiler/mlir/xla/tests/translate/select_and_scatter.mlir deleted file mode 100644 index 4a8d3bbfcf3..00000000000 --- a/tensorflow/compiler/mlir/xla/tests/translate/select_and_scatter.mlir +++ /dev/null @@ -1,34 +0,0 @@ -// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s - -func @main(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> { - %0 = xla_hlo.constant dense<0.000000e+00> : tensor - %1 = "xla_hlo.select_and_scatter"(%arg0, %arg1, %0) ( { - ^bb0(%arg3: tensor, %arg4: tensor): // no predecessors - %2 = "xla_hlo.compare"(%arg3, %arg4) {comparison_direction = "GE"} : (tensor, tensor) -> tensor - "xla_hlo.return"(%2) : (tensor) -> () - }, { - ^bb0(%arg3: tensor, %arg4: tensor): // no predecessors - %2 = xla_hlo.add %arg3, %arg4 : tensor - "xla_hlo.return"(%2) : (tensor) -> () - }) { - window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, - window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64> - } : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor) -> tensor<10x24x24x64xf32> - return %1 : tensor<10x24x24x64xf32> -} - -// CHECK: %[[SELECT_COMPUTATION:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> pred[] { -// CHECK: ROOT %[[RESULT:.*]] = pred[] compare(f32[] %[[ARG0]], f32[] %[[ARG1]]), direction=GE - -// CHECK: %[[SCATTER_COMPUTATION:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> f32[] { -// CHECK: ROOT %[[RESULT:.*]] = f32[] add(f32[] %[[ARG0]], f32[] %[[ARG1]]) - -// CHECK: ENTRY %main -// CHECK: %[[ARG0:.*]] = f32[10,24,24,64] parameter(0) -// CHECK: %[[ARG1:.*]] = f32[10,12,12,64] parameter(1) -// CHECK: %[[INIT:.*]] = f32[] constant(0) - -// CHECK: ROOT %[[RESULT:.*]] = f32[10,24,24,64] -// CHECK-SAME: select-and-scatter(f32[10,24,24,64] %[[ARG0]], f32[10,12,12,64] %[[ARG1]], f32[] %[[INIT]]), -// CHECK-SAME: window={size=1x2x2x1 stride=1x2x2x1}, -// CHECK-SAME: select=%[[SELECT_COMPUTATION]], scatter=%[[SCATTER_COMPUTATION]] diff --git a/tensorflow/compiler/mlir/xla/tests/translate/slice.mlir b/tensorflow/compiler/mlir/xla/tests/translate/slice.mlir deleted file mode 100644 index 3f31a008c1c..00000000000 --- a/tensorflow/compiler/mlir/xla/tests/translate/slice.mlir +++ /dev/null @@ -1,11 +0,0 @@ -// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s - -func @main(%arg: tensor<3x4xi32>) -> tensor<1x2xi32> { - %0 = "xla_hlo.slice"(%arg) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x2xi32> - return %0 : tensor<1x2xi32> -} - -// CHECK-LABEL: ENTRY -// CHECK: [[ARG:%.*]] = s32[3,4] parameter(0) -// 
CHECK-LABEL: ROOT -// CHECK-SAME: s32[1,2] slice(s32[3,4] [[ARG]]), slice={[1:2:1], [0:4:2]} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/transpose.mlir b/tensorflow/compiler/mlir/xla/tests/translate/transpose.mlir deleted file mode 100644 index 77048e6c902..00000000000 --- a/tensorflow/compiler/mlir/xla/tests/translate/transpose.mlir +++ /dev/null @@ -1,11 +0,0 @@ -// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s - -// CHECK-LABEL: ENTRY %main -func @main(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { - // CHECK-NEXT: %Arg_0.1 = s32[1,2,3,4] parameter(0) - - // CHECK-NEXT: ROOT %transpose.2 = s32[2,1,4,3] transpose(s32[1,2,3,4] %Arg_0.1), dimensions={1,0,3,2} - %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> - return %0 : tensor<2x1x4x3xi32> -} - diff --git a/tensorflow/compiler/mlir/xla/tests/translate/tuple.mlir b/tensorflow/compiler/mlir/xla/tests/translate/tuple.mlir deleted file mode 100644 index 5024a66dfe6..00000000000 --- a/tensorflow/compiler/mlir/xla/tests/translate/tuple.mlir +++ /dev/null @@ -1,11 +0,0 @@ -// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s - -func @main(%arg0: tensor, %arg1 : tensor) -> tuple, tensor> { - %result = "xla_hlo.tuple"(%arg0, %arg1) {} : (tensor, tensor) -> tuple, tensor> - return %result : tuple, tensor> -} - -// CHECK-LABEL: main -// CHECK: %[[ARG0:.*]] = f32[] parameter(0) -// CHECK: %[[ARG1:.*]] = s32[] parameter(1) -// CHECK: ROOT %[[RESULT:.*]] = (f32[], s32[]) tuple(f32[] %[[ARG0]], s32[] %[[ARG1]]) diff --git a/tensorflow/compiler/mlir/xla/tests/translate/unary_ops.mlir b/tensorflow/compiler/mlir/xla/tests/translate/unary_ops.mlir deleted file mode 100644 index c4138010543..00000000000 --- a/tensorflow/compiler/mlir/xla/tests/translate/unary_ops.mlir +++ /dev/null @@ -1,21 +0,0 @@ -// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s - -module { - func @main(%arg_f32: tensor<4xf32>, %arg_i32: tensor<4xi32>) -> (tensor<4xf32>, tensor<4xf32>, tensor<4xi32>, tensor<4xi32>) { - // CHECK: [[ARG_F32:%.*]] = f32[4] parameter(0) - // CHECK: [[EXPM1:%.*]] = f32[4] exponential-minus-one(f32[4] [[ARG_F32]]) - %expm1 = "xla_hlo.exponential_minus_one"(%arg_f32) : (tensor<4xf32>) -> tensor<4xf32> - - // CHECK: [[LOG1P:%.*]] = f32[4] log-plus-one(f32[4] [[ARG_F32]]) - %log1p = "xla_hlo.log_plus_one"(%arg_f32) : (tensor<4xf32>) -> tensor<4xf32> - - // CHECK: [[ARG_I32:%.*]] = s32[4] parameter(1) - // CHECK: [[NOT:%.*]] = s32[4] not(s32[4] [[ARG_I32]]) - %not = "xla_hlo.not"(%arg_i32) : (tensor<4xi32>) -> tensor<4xi32> - - // CHECK: [[POPCNT:%.*]] = s32[4] popcnt(s32[4] [[ARG_I32]]) - %popcnt = "xla_hlo.popcnt"(%arg_i32) : (tensor<4xi32>) -> tensor<4xi32> - - return %expm1, %log1p, %not, %popcnt : tensor<4xf32>, tensor<4xf32>, tensor<4xi32>, tensor<4xi32> - } -} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/xor.mlir b/tensorflow/compiler/mlir/xla/tests/translate/xor.mlir deleted file mode 100644 index 3ad79d633c7..00000000000 --- a/tensorflow/compiler/mlir/xla/tests/translate/xor.mlir +++ /dev/null @@ -1,11 +0,0 @@ -// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s - -module { - func @main(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { - // CHECK: [[VAL_1:%.*]] = pred[4] parameter(0) - // CHECK: [[VAL_2:%.*]] = pred[4] parameter(1) - %0 = xla_hlo.xor %arg0, %arg1 : tensor<4xi1> - // CHECK: ROOT [[VAL_3:%.*]] = pred[4] xor(pred[4] [[VAL_1]], pred[4] [[VAL_2]]) - return %0 : 
tensor<4xi1> - } -} diff --git a/tensorflow/compiler/mlir/xla/transforms/canonicalize.td b/tensorflow/compiler/mlir/xla/transforms/canonicalize.td index bc44117910b..37f6d7deaa3 100644 --- a/tensorflow/compiler/mlir/xla/transforms/canonicalize.td +++ b/tensorflow/compiler/mlir/xla/transforms/canonicalize.td @@ -17,7 +17,7 @@ limitations under the License. include "mlir/IR/OpBase.td" include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td" - +include "tensorflow/compiler/mlir/xla/ir/hlo_utils.td" //===----------------------------------------------------------------------===// // DynamicSlice op patterns. @@ -37,3 +37,13 @@ def DynamicSliceToSlice: Pat<(HLO_DynamicSliceOp HLO_Tensor:$input, (HLO_SliceOp $input, (CastIntElementsAttr $starting_indices), (BuildSliceLimits $starting_indices, $slice_sizes), (BuildSliceStrides $input))>; + +def UnaryToBianryEinsumEq : NativeCodeCall< + "$_builder.getStringAttr(\",\" + $0.getValue().str())">; + +// Convert UnaryEinsumOp to EinsumOp with two operands with redundant first +// operand. +def UnaryEinsumToEinsum : Pat< + (HLO_UnaryEinsumOp $operand, $equation), + (HLO_EinsumOp (HLO_ConstOp (GetScalarOfType<1> $operand)), + $operand, (UnaryToBianryEinsumEq $equation))>; diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index b1e94927ecb..f0ba67e2fd5 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -23,6 +23,7 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir #include "mlir/IR/Attributes.h" // TF:local_config_mlir #include "mlir/IR/Diagnostics.h" // TF:local_config_mlir @@ -47,9 +48,10 @@ limitations under the License. #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" -using namespace mlir; - +namespace mlir { +namespace xla_hlo { namespace { + class LegalizeTF : public FunctionPass { public: struct Options : public PassOptions { @@ -72,12 +74,6 @@ class LegalizeTF : public FunctionPass { private: bool allow_partial_conversion_; }; -} // end anonymous namespace - -std::unique_ptr> -mlir::xla_hlo::createLegalizeTFPass(bool allow_partial_conversion) { - return std::make_unique(allow_partial_conversion); -} /// Returns if the given TF data format string is the default format. static bool isDefaultDataFormat(StringRef format) { return format == "NHWC"; } @@ -131,10 +127,9 @@ static llvm::Optional GetIntegerHLOAxisFromTFAxis(Value *value, /// Returns a `ConvertOp` that casts the elements to a i64 type while retaining /// the shape of the input value. -static xla_hlo::ConvertOp CastElementsToI64(Location loc, Value *value, - PatternRewriter *rewriter) { - return rewriter->create(loc, value, - rewriter->getIntegerType(64)); +static ConvertOp CastElementsToI64(Location loc, Value *value, + PatternRewriter *rewriter) { + return rewriter->create(loc, value, rewriter->getIntegerType(64)); } // Returns size of dimension at the specified index, if ranked tensor. @@ -155,8 +150,8 @@ tensorflow::TensorShape ToTensorShape(llvm::ArrayRef sizes) { } // Returns minimum value for the given int or float element type. 
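// (Used further below as the init value when seeding max-based reductions,
// for example in the tf.MaxPool lowering.)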
-static xla_hlo::ConstOp GetMinValueForType(Type ty, Location loc, - PatternRewriter *rewriter) { +static ConstOp GetMinValueForType(Type ty, Location loc, + PatternRewriter *rewriter) { RankedTensorType scalar_ty = RankedTensorType::get({}, ty); DenseElementsAttr attr; @@ -169,26 +164,14 @@ static xla_hlo::ConstOp GetMinValueForType(Type ty, Location loc, APInt min_val = APInt::getSignedMinValue(int_ty.getWidth()); attr = DenseElementsAttr::get(scalar_ty, min_val); } - return rewriter->create(loc, attr); + return rewriter->create(loc, attr); } // Returns int or float scalar DenseElementsAttr attribute with the given // element type and the value. -static xla_hlo::ConstOp GetScalarOfType(Type ty, Location loc, - int64_t raw_value, - PatternRewriter *rewriter) { - RankedTensorType scalar_ty = RankedTensorType::get({}, ty); - - DenseElementsAttr attr; - if (auto float_ty = ty.dyn_cast_or_null()) { - APFloat value(float_ty.getFloatSemantics(), raw_value); - attr = DenseElementsAttr::get(scalar_ty, value); - } else { - auto int_ty = ty.cast(); - APInt value(int_ty.getWidth(), static_cast(raw_value), true); - attr = DenseElementsAttr::get(scalar_ty, value); - } - return rewriter->create(loc, attr); +static ConstOp GetScalarConstOfType(Type ty, Location loc, int64_t raw_value, + PatternRewriter *rewriter) { + return rewriter->create(loc, xla::GetScalarOfType(ty, raw_value)); } // Builds body for reduce op by using the using the template binary op as the @@ -207,7 +190,7 @@ static void BuildReduceBody(Type element_type, Region *body, auto reducer = builder->create(loc, block->getArgument(0), block->getArgument(1), /*broadcast_dimensions=*/nullptr); - builder->create(loc, reducer.getResult()); + builder->create(loc, reducer.getResult()); } //===----------------------------------------------------------------------===// @@ -249,11 +232,14 @@ static DenseIntElementsAttr Get2DTransposePerm(BoolAttr transpose, Builder *b) { // Pad op utilities. //===----------------------------------------------------------------------===// +// Slices input attribute of rank two and returns the specified column. +// +// Always returns 64 bit integer attribute regardless of bitwidth of the input +// attribute. static DenseIntElementsAttr SliceDenseIntElementsAttrColumn2D( - Builder *b, ElementsAttr input, int column) { + ElementsAttr input, int column) { auto int_attr = input.cast(); auto shaped_type = int_attr.getType(); - auto element_type = shaped_type.getElementType(); auto shape = shaped_type.getShape(); if (shape.size() != 2) return DenseIntElementsAttr(); @@ -267,10 +253,20 @@ static DenseIntElementsAttr SliceDenseIntElementsAttrColumn2D( } } + auto element_type = IntegerType::get(64, input.getContext()); return DenseIntElementsAttr::get( RankedTensorType::get({shape[0]}, element_type), values); } +// Returns interior padding to use in HLO Pad op based on the TensorFlow padding +// in TensorFlow PadV2 op. +static DenseIntElementsAttr GetInteriorPadding(ElementsAttr tf_padding) { + auto length = tf_padding.getType().getShape()[0]; + auto element_type = IntegerType::get(64, tf_padding.getContext()); + return DenseIntElementsAttr::get( + RankedTensorType::get({length}, element_type), 0); +} + //===----------------------------------------------------------------------===// // Binary op utilities. 
//===----------------------------------------------------------------------===// @@ -360,17 +356,17 @@ static void BuildArgMinMaxReductionBody(Type input_element_type, Location loc = body->getLoc(); StringAttr compare_direction = StringAttr::get(direction, builder->getContext()); - Value *compare = builder->create( + Value *compare = builder->create( loc, block->getArgument(0), block->getArgument(2), /*broadcast_dimensions=*/nullptr, compare_direction); - Value *selected_input = builder->create( + Value *selected_input = builder->create( loc, input_type, compare, block->getArgument(0), block->getArgument(2)); - Value *selected_index = builder->create( + Value *selected_index = builder->create( loc, index_type, compare, block->getArgument(1), block->getArgument(3)); Value *return_values[] = {selected_input, selected_index}; - builder->create(loc, return_values); + builder->create(loc, return_values); } //===----------------------------------------------------------------------===// @@ -443,12 +439,40 @@ static DenseIntElementsAttr TFSliceSizes2HLOSliceSizes( } //===----------------------------------------------------------------------===// -// Op converters. +// Sort op utilities. //===----------------------------------------------------------------------===// -namespace mlir { -namespace xla { -namespace { +// Builds the region `body` for xla_hlo.sort's comparator: for each type in +// `element_types`, creates two block arguments, one for lhs and one for rhs, and +// generates an xla_hlo.compare op to compare them with the given `direction`. +// +// Note that this right now only does comparison on the first pair of block +// arguments. +static void BuildSortComparisonBody(llvm::ArrayRef element_types, + StringRef direction, Region *body, + OpBuilder *builder) { + OpBuilder::InsertionGuard insertion_point_guard(*builder); + + Block *block = builder->createBlock(body); + // Add two arguments for each element type. + for (Type element_type : element_types) { + TensorType tensor_type = RankedTensorType::get({}, element_type); + block->addArguments({tensor_type, tensor_type}); + } + + Location loc = body->getLoc(); + StringAttr compare_direction = + StringAttr::get(direction, builder->getContext()); + Value *compare = builder->create( + loc, block->getArgument(0), block->getArgument(1), + /*broadcast_dimensions=*/nullptr, compare_direction); + + builder->create(loc, compare); +} + +//===----------------------------------------------------------------------===// +// Op converters. 
+//===----------------------------------------------------------------------===// NamedAttribute GetConvDimensionNumbersAttr( ArrayRef spatial_dim_indices, tensorflow::TensorFormat format, @@ -474,7 +498,7 @@ NamedAttribute GetConvDimensionNumbersAttr( return builder->getNamedAttr( "dimension_numbers", - mlir::xla_hlo::ConvDimensionNumbers::get( + ConvDimensionNumbers::get( batch_dim, feature_dim, spatial_dims, kernel_input_feature_dim, kernel_output_feature_dim, kernel_spatial_dimensions, batch_dim, feature_dim, spatial_dims, builder->getContext())); @@ -602,8 +626,8 @@ class ConvertConv : public OpRewritePattern { NamedAttribute attrs[] = {rhs_dilations_attr, window_strides_attr, dimension_numbers_attr, feature_group_count_attr, batch_group_count_attr, paddings_attr}; - rewriter.replaceOpWithNewOp(op, op.getType(), operands, - llvm::makeArrayRef(attrs)); + rewriter.replaceOpWithNewOp(op, op.getType(), operands, + llvm::makeArrayRef(attrs)); return Pattern::matchSuccess(); } }; @@ -635,23 +659,46 @@ class ConvertBF16FloorDivOp : public OpRewritePattern { auto out_type = op.z()->getType().cast(); - l = rewriter.create(op.getLoc(), l, - rewriter.getF32Type()); - r = rewriter.create(op.getLoc(), r, - rewriter.getF32Type()); + l = rewriter.create(op.getLoc(), l, rewriter.getF32Type()); + r = rewriter.create(op.getLoc(), r, rewriter.getF32Type()); auto intermediate = rewriter.create( op.getLoc(), ChangeTensorElementType(&rewriter, out_type, rewriter.getF32Type()), l, r); - auto floor_op = rewriter.create(op.getLoc(), out_type, - intermediate); + auto floor_op = + rewriter.create(op.getLoc(), out_type, intermediate); rewriter.replaceOp(op, floor_op.getResult()); return Pattern::matchSuccess(); } }; +// Converts TensorFlow EinsumOp to either HLO EinsumOp or UnaryEinsumOp +// depending on arity of the op. +class ConvertEinsumOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(TF::EinsumOp op, + PatternRewriter &rewriter) const override { + StringAttr equation = op.getAttrOfType("equation"); + if (op.N() == 1) { + rewriter.replaceOpWithNewOp( + op, op.getType(), *op.inputs().begin(), equation); + } else if (op.N() == 2) { + auto inputs = llvm::to_vector<2>(op.inputs()); + rewriter.replaceOpWithNewOp(op, op.getType(), inputs[0], + inputs[1], equation); + } else { + // TensorFlow EinsumOp verifies that the number of operands are at most + // two. + return Pattern::matchFailure(); + } + return Pattern::matchSuccess(); + } +}; + // Converts MaxPool op to HLO ReduceWindow op by setting appropriate window // dimensions with max as the reduction function. 
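// The window dimensions and strides come directly from the op's ksize and
// strides attributes, and the window is seeded with the element type's minimum
// value (GetMinValueForType) so that the reduction computes a running max.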
// @@ -674,15 +721,15 @@ class ConvertMaxPoolOp : public OpRewritePattern { op.input()->getType().cast().getElementType(); if (!element_type.isIntOrFloat()) return matchFailure(); Location loc = op.getLoc(); - xla_hlo::ConstOp init = GetMinValueForType(element_type, loc, &rewriter); + ConstOp init = GetMinValueForType(element_type, loc, &rewriter); - auto reduce = rewriter.create( + auto reduce = rewriter.create( loc, op.getType(), op.input(), init.getResult(), GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()), /*base_dilations=*/DenseIntElementsAttr(), /*window_dilations=*/DenseIntElementsAttr(), /*paddings=*/DenseIntElementsAttr()); - BuildReduceBody(element_type, &reduce.body(), &rewriter); + BuildReduceBody(element_type, &reduce.body(), &rewriter); rewriter.replaceOp(op, reduce.getResult()); return matchSuccess(); @@ -717,28 +764,28 @@ class ConvertSigmoidOp : public OpRewritePattern { PatternRewriter &rewriter) const override { auto operand = op.getOperand(); - auto scalar_one = rewriter.create( + auto scalar_one = rewriter.create( op.getLoc(), rewriter.getFloatAttr(getElementTypeOrSelf(operand->getType()), 0.5)); auto shaped_type = operand->getType().cast(); - auto constant_ones = rewriter.create( + auto constant_ones = rewriter.create( op.getLoc(), shaped_type, scalar_one, DenseIntElementsAttr::get( RankedTensorType::get({shaped_type.getRank()}, rewriter.getIntegerType(64)), shaped_type.getShape())); - auto scaled_input = rewriter.create( + auto scaled_input = rewriter.create( op.getLoc(), operand, constant_ones, DenseIntElementsAttr()); - auto tanh_op = rewriter.create( - op.getLoc(), operand->getType(), scaled_input); - auto mul_op = rewriter.create( - op.getLoc(), tanh_op, constant_ones, - /*DenseIntElementsAttr=*/DenseIntElementsAttr()); - auto add_op = rewriter.create( - op.getLoc(), mul_op, constant_ones, - /*DenseIntElementsAttr=*/DenseIntElementsAttr()); + auto tanh_op = + rewriter.create(op.getLoc(), operand->getType(), scaled_input); + auto mul_op = + rewriter.create(op.getLoc(), tanh_op, constant_ones, + /*DenseIntElementsAttr=*/DenseIntElementsAttr()); + auto add_op = + rewriter.create(op.getLoc(), mul_op, constant_ones, + /*DenseIntElementsAttr=*/DenseIntElementsAttr()); rewriter.replaceOp(op, add_op.getResult()); return matchSuccess(); @@ -803,11 +850,11 @@ class ConvertSoftmaxOp : public OpRewritePattern { auto max_logits = rewriter.create(loc, logits, reduce_dim, /*keep_dims=*/rewriter.getBoolAttr(false)); - auto shifted_logits = rewriter.create( - loc, type, logits, max_logits, batch_dims); + auto shifted_logits = + rewriter.create(loc, type, logits, max_logits, batch_dims); // Exponentiate the inputs. - Value *exp = rewriter.create(loc, type, shifted_logits); + Value *exp = rewriter.create(loc, type, shifted_logits); // Compute summation of the exponentials. 
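    // softmax(x) = exp(x - max(x)) / sum(exp(x - max(x))); the log-softmax
    // branch below computes (x - max(x)) - log(sum) instead.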
auto exp_sum = @@ -816,11 +863,10 @@ class ConvertSoftmaxOp : public OpRewritePattern { Value *sum = exp_sum.getResult(); if (use_log) { - Value *log = rewriter.create(loc, sum); - rewriter.replaceOpWithNewOp(op, shifted_logits, log, - batch_dims); + Value *log = rewriter.create(loc, sum); + rewriter.replaceOpWithNewOp(op, shifted_logits, log, batch_dims); } else { - rewriter.replaceOpWithNewOp(op, exp, sum, batch_dims); + rewriter.replaceOpWithNewOp(op, exp, sum, batch_dims); } return Pattern::matchSuccess(); } @@ -861,13 +907,13 @@ class ConvertSizeOp : public OpRewritePattern { const int64_t rank = input_ty.getRank(); auto result_type = op.getResult()->getType(); Operation *size = - GetScalarOfType(result_type.cast().getElementType(), - op.getLoc(), 1, &rewriter); + GetScalarConstOfType(result_type.cast().getElementType(), + op.getLoc(), 1, &rewriter); for (int64_t i = 0; i < rank; ++i) { - auto dim = rewriter.create( + auto dim = rewriter.create( op.getLoc(), result_type, input, rewriter.getIntegerAttr(rewriter.getIntegerType(32), i)); - size = rewriter.create( + size = rewriter.create( op.getLoc(), size->getResult(0), dim.getResult(), /*DenseIntElementsAttr=*/DenseIntElementsAttr()); } @@ -953,11 +999,11 @@ class ConvertSplitOp : public OpRewritePattern { for (int i = 0; i < num_splits; ++i) { begin_indices[dim_index] = i * slice_size; end_indices[dim_index] = (i + 1) * slice_size; - slices.push_back(rewriter.create( - op.getLoc(), slice_type, op.value(), - GetI64ElementsAttr(begin_indices, &rewriter), - GetI64ElementsAttr(end_indices, &rewriter), - GetI64ElementsAttr(strides, &rewriter))); + slices.push_back( + rewriter.create(op.getLoc(), slice_type, op.value(), + GetI64ElementsAttr(begin_indices, &rewriter), + GetI64ElementsAttr(end_indices, &rewriter), + GetI64ElementsAttr(strides, &rewriter))); } rewriter.replaceOp(op, slices); @@ -1059,10 +1105,10 @@ class ConvertStridedSliceOp : public OpRewritePattern { } Location loc = op.getLoc(); - auto reversed = rewriter.create( + auto reversed = rewriter.create( loc, input_ty, op.input(), GetI64ElementsAttr(dims_to_reverse, &rewriter)); - auto sliced = rewriter.create( + auto sliced = rewriter.create( loc, reversed.getResult(), GetI64ElementsAttr(hlo_begin_indices, &rewriter), GetI64ElementsAttr(hlo_end_indices, &rewriter), @@ -1070,7 +1116,7 @@ class ConvertStridedSliceOp : public OpRewritePattern { // Reshape slice result so that the shape is updated depending on // 'new_axis_mask' or 'shrink_axis_mask' attributes. - rewriter.replaceOpWithNewOp(op, op.getType(), sliced); + rewriter.replaceOpWithNewOp(op, op.getType(), sliced); return matchSuccess(); } }; @@ -1104,14 +1150,14 @@ class ConvertRangeOp : public OpRewritePattern { return matchFailure(); } - auto iota = rewriter.create(op.getLoc(), result_type, - rewriter.getI64IntegerAttr(0)); - auto scaled = rewriter.create( + auto iota = rewriter.create(op.getLoc(), result_type, + rewriter.getI64IntegerAttr(0)); + auto scaled = rewriter.create( op.getLoc(), result_type, iota, op.delta(), - getBroadcastDimensionsAttr(&rewriter, iota, op.delta())); - rewriter.replaceOpWithNewOp( + xla::getBroadcastDimensionsAttr(&rewriter, iota, op.delta())); + rewriter.replaceOpWithNewOp( op, result_type, scaled, op.start(), - getBroadcastDimensionsAttr(&rewriter, scaled, op.start())); + xla::getBroadcastDimensionsAttr(&rewriter, scaled, op.start())); return matchSuccess(); } }; @@ -1159,13 +1205,13 @@ class GenericConvertReductionOp : public OpRewritePattern { // repeated arithmetic operations. 
Type reduce_element_type = is_accumulation ? GetAccumulationType(element_type) : element_type; - auto casted_input = rewriter.create( - loc, op.input(), reduce_element_type); + auto casted_input = + rewriter.create(loc, op.input(), reduce_element_type); // Each reduction op can have a different initial value. Value *init = Derived::GetInitialValue(reduce_element_type, loc, rewriter); - auto reduction = rewriter.create( + auto reduction = rewriter.create( loc, casted_input.getResult(), init, GetI64ElementsAttr(xla_dimensions, &rewriter)); BuildReduceBody(reduce_element_type, &reduction.body(), @@ -1183,19 +1229,19 @@ class GenericConvertReductionOp : public OpRewritePattern { divisor_count *= input_shape[i]; } } - auto divisor = - GetScalarOfType(reduce_element_type, loc, divisor_count, &rewriter); + auto divisor = GetScalarConstOfType(reduce_element_type, loc, + divisor_count, &rewriter); auto broadcast_dims = GetI64ElementsAttr({}, &rewriter); - result = rewriter.create(loc, result, divisor.getResult(), - broadcast_dims); + result = rewriter.create(loc, result, divisor.getResult(), + broadcast_dims); } - result = rewriter.create(loc, result, element_type); + result = rewriter.create(loc, result, element_type); // Need to reshape back after the reduction if we're keeping the reduced // dimensions. if (op.keep_dims()) { - result = rewriter.create(loc, op.getType(), result); + result = rewriter.create(loc, op.getType(), result); } rewriter.replaceOp(op, {result}, {op.reduction_indices()}); @@ -1211,14 +1257,13 @@ class GenericConvertReductionOp : public OpRewritePattern { // %divisor = constant dense<...> : tensor // %mean = "xla_hlo.div"(%sum, %divisor) class ConvertMeanOp - : public GenericConvertReductionOp { + : public GenericConvertReductionOp { public: using GenericConvertReductionOp::GenericConvertReductionOp; static Value *GetInitialValue(Type reduce_element_type, Location loc, PatternRewriter &rewriter) { - return GetScalarOfType(reduce_element_type, loc, 0, &rewriter); + return GetScalarConstOfType(reduce_element_type, loc, 0, &rewriter); } }; @@ -1227,14 +1272,14 @@ class ConvertMeanOp // %init = constant dense<...> : tensor // %sum = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.add"] // {dimensions = ...} -class ConvertSumOp : public GenericConvertReductionOp { +class ConvertSumOp + : public GenericConvertReductionOp { public: using GenericConvertReductionOp::GenericConvertReductionOp; static Value *GetInitialValue(Type reduce_element_type, Location loc, PatternRewriter &rewriter) { - return GetScalarOfType(reduce_element_type, loc, 0, &rewriter); + return GetScalarConstOfType(reduce_element_type, loc, 0, &rewriter); } }; @@ -1244,7 +1289,7 @@ class ConvertSumOp : public GenericConvertReductionOp { public: using GenericConvertReductionOp::GenericConvertReductionOp; @@ -1289,7 +1334,7 @@ class ConvertArgMinMaxOp : public OpRewritePattern { Type index_element_type = output_type.getElementType(); Value *index_init_value = - GetScalarOfType(index_element_type, loc, 0, &rewriter); + GetScalarConstOfType(index_element_type, loc, 0, &rewriter); RankedTensorType index_type = RankedTensorType::get(input_type.getShape(), index_element_type); @@ -1304,7 +1349,7 @@ class ConvertArgMinMaxOp : public OpRewritePattern { IntegerAttr iota_dimension = IntegerAttr::get(rewriter.getIntegerType(64), axis); Value *index_values = - rewriter.create(loc, index_type, iota_dimension); + rewriter.create(loc, index_type, iota_dimension); std::vector dimensions = input_type.getShape(); 
dimensions.erase(dimensions.begin() + axis); @@ -1315,7 +1360,7 @@ class ConvertArgMinMaxOp : public OpRewritePattern { DenseIntElementsAttr reduction_dimensions = GetI64ElementsAttr({axis}, &rewriter); - auto reduction = rewriter.create( + auto reduction = rewriter.create( loc, llvm::ArrayRef(operands), llvm::ArrayRef(init_values), reduction_dimensions); StringRef direction = Derived::GetDirection(); @@ -1403,12 +1448,12 @@ class ConvertTileOp : public OpRewritePattern { RankedTensorType::get(broadcasted_shape, element_type); Type output_type = op.getType(); - Value *result = rewriter.create( + Value *result = rewriter.create( loc, broadcasted_type, op.input(), GetI64ElementsAttr(broadcast_dimensions, &rewriter)); if (output_type != broadcasted_type) { - result = rewriter.create(loc, output_type, result); + result = rewriter.create(loc, output_type, result); } rewriter.replaceOp(op, {result}, {op.multiples()}); @@ -1431,13 +1476,13 @@ class ConvertMaxPoolGradOp : public OpRewritePattern { Type element_type = op.orig_input()->getType().cast().getElementType(); - auto result = rewriter.create( + auto result = rewriter.create( loc, op.getType(), op.orig_input(), op.grad(), - GetScalarOfType(element_type, loc, 0, &rewriter), + GetScalarConstOfType(element_type, loc, 0, &rewriter), GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()), nullptr); - BuildReduceBody(element_type, &result.scatter(), &rewriter); + BuildReduceBody(element_type, &result.scatter(), &rewriter); { OpBuilder::InsertionGuard guard(rewriter); Block *block = rewriter.createBlock(&result.select()); @@ -1446,11 +1491,11 @@ class ConvertMaxPoolGradOp : public OpRewritePattern { Type type = RankedTensorType::get(/*shape=*/{}, element_type); block->addArguments({type, type}); - auto reducer = rewriter.create( + auto reducer = rewriter.create( loc, block->getArgument(0), block->getArgument(1), /*broadcast_dimensions=*/nullptr, StringAttr::get("GE", rewriter.getContext())); - rewriter.create(loc, reducer.getResult()); + rewriter.create(loc, reducer.getResult()); } rewriter.replaceOp(op, {result}, {op.orig_output()}); @@ -1533,7 +1578,7 @@ class ConvertConv2DBackpropInputOp return matchFailure(); } - // Compute xla_hlo::ConvDimensionNumbers, dilation, and padding. + // Compute ConvDimensionNumbers, dilation, and padding. SmallVector kernel_spatial_dims(num_spatial_dims); SmallVector conv_paddings(num_spatial_dims * 2); SmallVector lhs_dilation(num_spatial_dims); @@ -1550,7 +1595,7 @@ class ConvertConv2DBackpropInputOp lhs_dilation[i] = dims.spatial_dims[i].stride; rhs_dilation[i] = dilations[dim]; } - RankedTensorType paddings_ty = mlir::RankedTensorType::get( + RankedTensorType paddings_ty = RankedTensorType::get( {num_spatial_dims, 2}, rewriter.getIntegerType(64)); auto paddings_attr = DenseIntElementsAttr::get(paddings_ty, conv_paddings); auto spatial_dims_attr = GetI64ElementsAttr(spatial_dims, &rewriter); @@ -1567,17 +1612,17 @@ class ConvertConv2DBackpropInputOp } // Mirror the filter in the spatial dimensions. 
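    // The gradient with respect to a convolution's input is itself a
    // convolution against the spatially reversed filter, hence the reverse
    // over the kernel spatial dimensions right below.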
- filter = rewriter.create( + filter = rewriter.create( loc, filter, GetI64ElementsAttr(kernel_spatial_dims, &rewriter)); // activation gradients // = gradients (with padding and dilation) mirrored_weights - Value *result = rewriter.create( + Value *result = rewriter.create( loc, op.getType(), op.out_backprop(), filter, /*window_strides=*/GetI64ElementsAttr(ones, &rewriter), /*padding=*/paddings_attr, GetI64ElementsAttr(lhs_dilation, &rewriter), GetI64ElementsAttr(rhs_dilation, &rewriter), - xla_hlo::ConvDimensionNumbers::get( + ConvDimensionNumbers::get( /*input_batch_dimension=*/batch_dim_attr, /*input_feature_dimension=*/feature_dim_attr, /*input_spatial_dimensions=*/spatial_dims_attr, @@ -1689,7 +1734,7 @@ class ConvertConv2DBackpropFilterOp return matchFailure(); } - // Compute xla_hlo::ConvDimensionNumbers, dilation, and padding. + // Compute ConvDimensionNumbers, dilation, and padding. SmallVector conv_padding(num_spatial_dims * 2); SmallVector rhs_dilation(num_spatial_dims); SmallVector window_strides(num_spatial_dims); @@ -1761,7 +1806,7 @@ class ConvertConv2DBackpropFilterOp conv_padding[i * 2 + 1] = pad_total - pad_before; } - RankedTensorType paddings_ty = mlir::RankedTensorType::get( + RankedTensorType paddings_ty = RankedTensorType::get( {num_spatial_dims, 2}, rewriter.getIntegerType(64)); auto paddings_attr = DenseIntElementsAttr::get(paddings_ty, conv_padding); auto out_spatial_dims_attr = @@ -1773,12 +1818,12 @@ class ConvertConv2DBackpropFilterOp auto feature_dim_attr = rewriter.getI64IntegerAttr(feature_dim); Location loc = op.getLoc(); - Value *result = rewriter.create( + Value *result = rewriter.create( loc, op.getType(), op.input(), op.out_backprop(), /*window_strides=*/GetI64ElementsAttr(window_strides, &rewriter), /*padding=*/paddings_attr, GetI64ElementsAttr(lhs_dilation, &rewriter), GetI64ElementsAttr(rhs_dilation, &rewriter), - xla_hlo::ConvDimensionNumbers::get( + ConvDimensionNumbers::get( // Swap batch_dim and feature_dim in the activations. /*input_batch_dimension=*/feature_dim_attr, /*input_feature_dimension=*/batch_dim_attr, @@ -1836,21 +1881,21 @@ class ConvertOneHotOp : public OpRewritePattern { Location loc = op.getLoc(); auto index_type = RankedTensorType::get(output_dims, element_type); - Value *compare = rewriter.create( + Value *compare = rewriter.create( loc, op.indices(), - rewriter.create( + rewriter.create( loc, index_type, IntegerAttr::get(rewriter.getIntegerType(64), axis)), GetI64ElementsAttr(broadcast_dims, &rewriter), StringAttr::get("EQ", rewriter.getContext())); - Value *on_value = rewriter.create( + Value *on_value = rewriter.create( loc, op.getType(), op.on_value(), GetI64ElementsAttr(output_dims, &rewriter)); - Value *off_value = rewriter.create( + Value *off_value = rewriter.create( loc, op.getType(), op.off_value(), GetI64ElementsAttr(output_dims, &rewriter)); - Value *result = rewriter.create( - loc, op.getType(), compare, on_value, off_value); + Value *result = rewriter.create(loc, op.getType(), compare, + on_value, off_value); rewriter.replaceOp( op, {result}, @@ -1860,42 +1905,132 @@ class ConvertOneHotOp : public OpRewritePattern { } }; -#include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc" -} // end anonymous namespace -} // end namespace xla -} // end namespace mlir +// Converts tf.TopKV2 to XLA HLO iota, sort, and slice ops when k is a constant. +// +// tf.TopKV2 sorts along last dimension of the input tensor and then returns +// the top K components' values and indices. 
This is translated into a few +// ops in XLA HLO: first generating an integer sequence for the indices, +// then sorting both the original input tensor and the indices together, and +// finally slicing out the top K components. +// +// For example, for the following IR: +// +// %k = "tf.Const"() {value = dense<8> : tensor} : () -> tensor +// %0:2 = "tf.TopKV2"(%input, %k): (tensor<16x16xf32>, tensor) -> +// (tensor<16x8xf32>, tensor<16x8xi32>) +// +// We will get: +// +// %1 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<16x16xi32> +// %2 = "xla_hlo.sort"(%input, %1) ( { +// ^bb0(%arg1: tensor, %arg2: tensor, +// %arg3: tensor, %arg4: tensor): +// %7 = "xla_hlo.compare"(%arg1, %arg2) {comparison_direction = "GT"}: ... +// "xla_hlo.return"(%7) : (tensor) -> () +// }) {dimension = 1 : i64, is_stable = true} : ... +// %3 = "xla_hlo.get_tuple_element"(%2) {index = 0 : i32} : ... +// %4 = "xla_hlo.get_tuple_element"(%2) {index = 1 : i32} : ... +// %5 = "xla_hlo.slice"(%3) {limit_indices = dense<[16, 8]> : tensor<2xi64>, +// start_indices dense<0> : tensor<2xi64>, +// strides = dense<1> : tensor<2xi64>} : +// (tensor<16x16xf32>) -> tensor<16x8xf32> +// %6 = "xla_hlo.slice"(%4) ... +class ConvertTopKV2Op : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; -LogicalResult mlir::xla_hlo::legalizeTF(Operation *op, - bool allow_partial_conversion) { + PatternMatchResult matchAndRewrite(TF::TopKV2Op op, + PatternRewriter &rewriter) const override { + // We can only match when the `k` operand is a constant scalar. + DenseIntElementsAttr k_attr; + if (!matchPattern(op.k(), m_Constant(&k_attr))) return matchFailure(); + + // The last dimension of the input tensor's shape should be known so we can + // have clamped end_indices for slices. + TensorType input_type = op.input()->getType().cast(); + if (!input_type.hasRank()) return matchFailure(); + int64_t input_rank = input_type.getRank(); + int64_t last_dim_index = input_rank - 1; + int64_t last_dim_size = input_type.getDimSize(last_dim_index); + if (last_dim_size == ShapedType::kDynamicSize) return matchFailure(); + + // Create an Iota op for indices. + auto i32_type = rewriter.getIntegerType(32); + Type iota_type = RankedTensorType::get(input_type.getShape(), i32_type); + Value *iota_op = rewriter.create( + op.getLoc(), iota_type, rewriter.getI64IntegerAttr(last_dim_index)); + + // Create the sort op. It takes two inputs, one for the original input, the + // other for the indices. + auto sort_op = rewriter.create( + op.getLoc(), llvm::ArrayRef{op.input(), iota_op}, + last_dim_index, /*is_stable=*/true); + BuildSortComparisonBody({input_type.getElementType(), i32_type}, + /*direction=*/"GT", &sort_op.comparator(), + &rewriter); + + // Get the sorted input and index tuple elements. + auto tuple_first_element = + rewriter.create(op.getLoc(), sort_op, 0); + auto tuple_second_element = + rewriter.create(op.getLoc(), sort_op, 1); + + SmallVector begin_indices(input_rank, 0); + auto end_indices = llvm::to_vector<4>(input_type.getShape()); + end_indices.back() = + std::min((*k_attr.begin()).getSExtValue(), last_dim_size); + SmallVector strides(input_rank, 1); + + // Get the slice for the top K elements.
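For intuition before the slicing code below, the iota + stable descending sort + slice sequence this pattern emits computes the same result as the following NumPy sketch. It is illustrative only; topk_v2_reference is a hypothetical helper, and k is assumed constant, matching the matchPattern check above.

import numpy as np

def topk_v2_reference(values, k):
    # Sort descending along the last dimension; a stable sort keeps ties in
    # their original index order, like the is_stable "GT" comparator above.
    order = np.argsort(-values, axis=-1, kind="stable")
    top_values = np.take_along_axis(values, order, axis=-1)[..., :k]
    top_indices = order[..., :k].astype(np.int32)
    return top_values, top_indices

vals = np.array([[1.0, 9.0, 3.0, 9.0]], dtype=np.float32)
top_v, top_i = topk_v2_reference(vals, k=2)
# top_v == [[9., 9.]] and top_i == [[1, 3]] (stable tie-breaking).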
+ + Value *values = rewriter.create( + op.getLoc(), tuple_first_element, + GetI64ElementsAttr(begin_indices, &rewriter), + GetI64ElementsAttr(end_indices, &rewriter), + GetI64ElementsAttr(strides, &rewriter)); + + Value *indices = rewriter.create( + op.getLoc(), tuple_second_element, + GetI64ElementsAttr(begin_indices, &rewriter), + GetI64ElementsAttr(end_indices, &rewriter), + GetI64ElementsAttr(strides, &rewriter)); + + rewriter.replaceOp(op, {values, indices}); + return matchSuccess(); + } +}; + +#include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc" + +LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) { MLIRContext *context = op->getContext(); // Add lowering patterns to the list. OwningRewritePatternList patterns; - xla::populateWithGenerated(context, &patterns); + populateWithGenerated(context, &patterns); // Add patterns that lower some of the high level TensorFlow ops to lower // level TensorFlow ops. So, we don't have to target all the TensorFlow ops // here for lowering to HLO. - mlir::TF::PopulateLoweringTFPatterns(context, &patterns); - patterns.insert, - mlir::xla::ConvertSoftmaxOp, - mlir::xla::ConvertSplitOp, mlir::xla::ConvertStridedSliceOp, - mlir::xla::ConvertMeanOp, mlir::xla::ConvertSumOp, - mlir::xla::ConvertMaxOp, mlir::xla::ConvertTileOp, - mlir::xla::ConvertMaxPoolGradOp, mlir::xla::ConvertOneHotOp, - mlir::xla::ConvertConv2DBackpropInputOp, - mlir::xla::ConvertConv2DBackpropFilterOp>(op->getContext()); + TF::PopulateLoweringTFPatterns(context, &patterns); + patterns + .insert, + ConvertSoftmaxOp, ConvertSplitOp, + ConvertStridedSliceOp, ConvertTopKV2Op, ConvertMeanOp, + ConvertSumOp, ConvertMaxOp, ConvertTileOp, ConvertMaxPoolGradOp, + ConvertOneHotOp, ConvertConv2DBackpropInputOp, + ConvertConv2DBackpropFilterOp>(op->getContext()); ConversionTarget target(*context); target.addLegalDialect(); if (!allow_partial_conversion) { - target.addLegalOp(); + // Fully qualify ReturnOp here as xla_hlo dialect also defines a ReturnOp. + target.addLegalOp(); return applyFullConversion(op, target, patterns); } @@ -1904,10 +2039,19 @@ LogicalResult mlir::xla_hlo::legalizeTF(Operation *op, /// Performs the lowering to XLA dialect. void LegalizeTF::runOnFunction() { - if (failed( - mlir::xla_hlo::legalizeTF(getFunction(), allow_partial_conversion_))) + if (failed(legalizeTF(getFunction(), allow_partial_conversion_))) signalPassFailure(); } static PassRegistration pass( "xla-legalize-tf", "Legalize from TensorFlow to the XLA dialect"); + +} // end namespace + +std::unique_ptr> createLegalizeTFPass( + bool allow_partial_conversion) { + return std::make_unique(allow_partial_conversion); +} + +} // end namespace xla_hlo +} // end namespace mlir diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index 4bf7ee16d0a..fb8c6736309 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -157,15 +157,16 @@ def : Pat<(TF_BroadcastToOp:$result AnyRankedTensor:$input, $shape), [(AnyRankedTensor $result)]>; //===----------------------------------------------------------------------===// -// Logical binary op patterns. +// Logical & bitwise binary op patterns. 
//===----------------------------------------------------------------------===// class DirectLogicalBinaryPat - : Pat<(FromOp I1Tensor:$l, I1Tensor:$r), + : Pat<(FromOp IntegerTensor:$l, IntegerTensor:$r), (ToOp $l, $r, (BinBroadcastDimensions $l, $r))>; foreach fromToBinPair = [[TF_LogicalAndOp, HLO_AndOp], - [TF_LogicalOrOp, HLO_OrOp]] in + [TF_LogicalOrOp, HLO_OrOp], + [TF_BitwiseOrOp, HLO_OrOp]] in def : DirectLogicalBinaryPat; //===----------------------------------------------------------------------===// @@ -236,7 +237,7 @@ def : Pat<(TF_ConcatV2Op $inputs, (TF_ConstOp OneElementAttr:$axis)), //===----------------------------------------------------------------------===// def CastElementsToI64Elements : NativeCodeCall< - "::xla::ConvertElementsAttr(" + "xla::ConvertElementsAttr(" "$0, $_builder.getIntegerType(64)).cast()">; def : Pat<(TF_CrossReplicaSumOp $input, (TF_ConstOp $group_assignment)), @@ -255,25 +256,21 @@ def : Pat<(TF_RFFTOp $input, (TF_ConstOp I32ElementsAttr:$fft_length)), // Pad op patterns. //===----------------------------------------------------------------------===// -def ZeroPaddingAttr : NativeCodeCall < - "DenseIntElementsAttr::get(" - "RankedTensorType::get($0.getType().getShape()[0]," - " getElementTypeOrSelf($0.getType())), " - "{$_builder.getZeroAttr(getElementTypeOrSelf($0.getType()))})">; - class SliceDenseIntElementsAttrColumn2D : NativeCodeCall< - "SliceDenseIntElementsAttrColumn2D(" - "&$_builder, $0, " # column # " )">; + "SliceDenseIntElementsAttrColumn2D($0, " # column # " )">; class SliceDenseIntElementsAttr : NativeCodeCall< - "SliceDenseIntElementsAttr(&$_builder, $0, " # index # ", " # axis # ")">; + "SliceDenseIntElementsAttr($0, " # index # ", " # axis # ")">; +// Interior padding attribute based on the TF padding. +def GetInteriorPadding : NativeCodeCall < + "GetInteriorPadding($0)">; -def : Pat<(TF_PadV2Op $input, (TF_ConstOp I64ElementsAttr:$padding), $c), +def : Pat<(TF_PadV2Op $input, (TF_ConstOp $padding), $c), (HLO_PadOp $input, $c, (SliceDenseIntElementsAttrColumn2D<"0"> $padding), (SliceDenseIntElementsAttrColumn2D<"1"> $padding), - (ZeroPaddingAttr $padding))>; + (GetInteriorPadding $padding))>; //===----------------------------------------------------------------------===// // Identity op patterns. diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_linalg.cc index 8920204abf3..28bacfa87f0 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_linalg.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_linalg.cc @@ -38,6 +38,15 @@ namespace mlir { namespace xla_lhlo { namespace { +ArrayAttr GetNParallelLoopsAttrs(unsigned nParallelLoops, Builder b) { + auto parallelLoopTypeAttr = b.getStringAttr("parallel"); + SmallVector iteratorTypes; + for (int i = 0; i < nParallelLoops; ++i) { + iteratorTypes.push_back(parallelLoopTypeAttr); + } + return b.getArrayAttr(iteratorTypes); +} + template class PointwiseToLinalgConverter : public OpConversionPattern { public: @@ -78,11 +87,6 @@ class PointwiseToLinalgConverter : public OpConversionPattern { result_or_body_arg.emplace_back(memrefType.getElementType()); } - // Pointwise-ops have all surrounding loops parallel, so the loop triple is - // [argDim, 0, 0]. - SmallVector loop_types{rewriter.getI64IntegerAttr(nloops), - rewriter.getI64IntegerAttr(0), - rewriter.getI64IntegerAttr(0)}; // Define the number of input memref/output memrefs. 
SmallVector nmemrefs{ rewriter.getI64IntegerAttr(bodyArgTypes.size()), @@ -90,7 +94,8 @@ class PointwiseToLinalgConverter : public OpConversionPattern { auto linalgOp = rewriter.create( loc, args, rewriter.getArrayAttr(indexingMaps), - rewriter.getArrayAttr(loop_types), rewriter.getArrayAttr(nmemrefs), + GetNParallelLoopsAttrs(nloops, rewriter), + rewriter.getArrayAttr(nmemrefs), /*doc=*/nullptr, /*fun=*/nullptr, /*library_call=*/nullptr); // Add a block to the region. @@ -158,11 +163,6 @@ class BroadcastInDimConverter : public OpConversionPattern { indexingMaps.emplace_back( AffineMapAttr::get(rewriter.getMultiDimIdentityMap(nloops))); - // Broadcast op has all surrounding loops parallel, so the loop triple is - // [argDim, 0, 0]. - SmallVector loop_types{rewriter.getI64IntegerAttr(nloops), - rewriter.getI64IntegerAttr(0), - rewriter.getI64IntegerAttr(0)}; // Define the number of input memref/output memrefs. SmallVector nmemrefs{ rewriter.getI64IntegerAttr(bodyArgTypes.size()), @@ -171,7 +171,8 @@ class BroadcastInDimConverter : public OpConversionPattern { auto loc = broadcastOp.getLoc(); auto linalgOp = rewriter.create( loc, args, rewriter.getArrayAttr(indexingMaps), - rewriter.getArrayAttr(loop_types), rewriter.getArrayAttr(nmemrefs), + GetNParallelLoopsAttrs(nloops, rewriter), + rewriter.getArrayAttr(nmemrefs), /*doc=*/nullptr, /*fun=*/nullptr, /*library_call=*/nullptr); // Add a block to the region. @@ -207,11 +208,6 @@ class IotaConverter : public OpConversionPattern { indexingMaps.emplace_back( AffineMapAttr::get(rewriter.getMultiDimIdentityMap(nloops))); - // Pointwise-ops have all surrounding loops parallel, so the loop triple is - // [argDim, 0, 0]. - SmallVector loop_types{rewriter.getI64IntegerAttr(nloops), - rewriter.getI64IntegerAttr(0), - rewriter.getI64IntegerAttr(0)}; // Define the number of input memref/output memrefs. SmallVector nmemrefs{rewriter.getI64IntegerAttr(0), rewriter.getI64IntegerAttr(1)}; @@ -219,7 +215,8 @@ class IotaConverter : public OpConversionPattern { auto loc = iotaOp.getLoc(); auto linalgOp = rewriter.create( loc, args, rewriter.getArrayAttr(indexingMaps), - rewriter.getArrayAttr(loop_types), rewriter.getArrayAttr(nmemrefs), + GetNParallelLoopsAttrs(nloops, rewriter), + rewriter.getArrayAttr(nmemrefs), /*doc=*/nullptr, /*fun=*/nullptr, /*library_call=*/nullptr); // Add a block to the region. 
@@ -277,7 +274,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, // "linalg.yield"(%0) : (f32) -> () // }) { // indexing_maps = [#map0, #map0, #map0], -// n_loop_types = [2, 0, 0], +// iterator_types = ["parallel", "parallel"], // n_views = [2, 1] // } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () // } diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index c800e5005d2..90c28e03d4d 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -1217,26 +1217,26 @@ static void InitializeTrtPlugins(nvinfer1::ILogger* trt_logger) { // static StatusOr> Converter::Create( - nvinfer1::IBuilder* trt_builder, TrtPrecisionMode precision_mode, - bool use_calibration, nvinfer1::ILogger* trt_logger) { + TrtPrecisionMode precision_mode, bool use_calibration, + nvinfer1::ILogger* trt_logger) { std::unique_ptr converter = absl::WrapUnique( - new Converter(trt_builder, precision_mode, use_calibration, trt_logger)); - TF_RETURN_IF_ERROR(converter->Init()); + new Converter(precision_mode, use_calibration, trt_logger)); + TF_RETURN_IF_ERROR(converter->Init(trt_logger)); return converter; } -Converter::Converter(nvinfer1::IBuilder* trt_builder, - TrtPrecisionMode precision_mode, bool use_calibration, +Converter::Converter(TrtPrecisionMode precision_mode, bool use_calibration, nvinfer1::ILogger* trt_logger) - : trt_builder_(trt_builder), - precision_mode_(precision_mode), - use_calibration_(use_calibration) { + : precision_mode_(precision_mode), use_calibration_(use_calibration) { InitializeTrtPlugins(trt_logger); this->RegisterOpConverters(); } -Status Converter::Init() { - // Create the network. +Status Converter::Init(nvinfer1::ILogger* trt_logger) { + VLOG(1) << "Creating TensorRT builder"; + trt_builder_.reset(nvinfer1::createInferBuilder(*trt_logger)); + + VLOG(1) << "Creating TensorRT network"; trt_network_.reset(trt_builder_->createNetwork()); if (!trt_network_) { return errors::Internal("Failed to create TensorRT network object"); @@ -1369,13 +1369,33 @@ Status Converter::RenameAndMarkOutputTensors( } Status Converter::BuildCudaEngine( - TrtUniquePtrType* engine) { - VLOG(1) << "Starting engine creation"; + TrtUniquePtrType* engine, int max_batch_size, + size_t max_workspace_size_bytes, nvinfer1::IGpuAllocator* allocator, + TRTInt8Calibrator* calibrator) { + VLOG(1) << "Configuring TensorRT builder"; + trt_builder_->setMaxBatchSize(max_batch_size); + trt_builder_->setMaxWorkspaceSize(max_workspace_size_bytes); + trt_builder_->setGpuAllocator(allocator); + if (precision_mode_ == TrtPrecisionMode::FP16) { + trt_builder_->setFp16Mode(true); + } else if (precision_mode_ == TrtPrecisionMode::INT8) { + // Setting FP16 mode as well allows TRT to also consider FP16 kernels and + // use them in situations where they are faster than INT8 or where INT8 is + // not supported for a given layer. 
+ trt_builder_->setFp16Mode(true); + trt_builder_->setInt8Mode(true); + if (use_calibration_) { + trt_builder_->setInt8Calibrator(calibrator); + } else { + trt_builder_->setInt8Calibrator(nullptr); + } + } + + VLOG(1) << "Building TensorRT engine"; engine->reset(trt_builder_->buildCudaEngine(*network())); if (engine->get() == nullptr) { return errors::Internal("Failed to build TensorRT engine"); } - VLOG(1) << "Finished conversion"; return Status::OK(); } @@ -5620,37 +5640,13 @@ Status ConvertGraphDefToEngine( engine->reset(); if (convert_successfully) *convert_successfully = false; - // Create the builder. - TrtUniquePtrType builder( - nvinfer1::createInferBuilder(*trt_logger)); - builder->setMaxBatchSize(max_batch_size); - builder->setMaxWorkspaceSize(max_workspace_size_bytes); - builder->setGpuAllocator(allocator); - if (precision_mode == TrtPrecisionMode::FP16) { - builder->setFp16Mode(true); - } else if (precision_mode == TrtPrecisionMode::INT8) { - // Setting FP16 mode as well allows TRT to also consider FP16 kernels and - // use them in situations where they are faster than INT8 or where INT8 is - // not supported for a given layer. - builder->setFp16Mode(true); - builder->setInt8Mode(true); - if (use_calibration) { - builder->setInt8Calibrator(calibrator); - } else { - builder->setInt8Calibrator(nullptr); - } - } - - // Build the network - if (VLOG_IS_ON(1)) { - string mode_str; - TF_RETURN_IF_ERROR(TrtPrecisionModeToName(precision_mode, &mode_str)); - VLOG(1) << "Starting engine conversion, precision mode: " << mode_str; - } - auto statusor = Converter::Create(builder.get(), precision_mode, - use_calibration, trt_logger); + // Creating converter, TensorRT builder and network + auto statusor = + Converter::Create(precision_mode, use_calibration, trt_logger); TF_RETURN_IF_ERROR(statusor.status()); auto converter = std::move(statusor.ValueOrDie()); + + VLOG(1) << "Starting to convert TensorFlow ops to TensorRT layers"; std::vector output_tensors; // Graph nodes are already topologically sorted during construction for (const auto& node_def : gdef.node()) { @@ -5737,7 +5733,10 @@ Status ConvertGraphDefToEngine( converter->MaybeApplyQuantizationRanges(); // Build the engine. - TF_RETURN_IF_ERROR(converter->BuildCudaEngine(engine)); + TF_RETURN_IF_ERROR(converter->BuildCudaEngine( + engine, max_batch_size, max_workspace_size_bytes, allocator, calibrator)); + + VLOG(1) << "Finished conversion"; return Status::OK(); } diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h index 00099396308..eb51ec1b3f6 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h @@ -446,8 +446,8 @@ class Converter { }; static StatusOr> Create( - nvinfer1::IBuilder* trt_builder, TrtPrecisionMode precision_mode, - bool use_calibration, nvinfer1::ILogger* trt_logger); + TrtPrecisionMode precision_mode, bool use_calibration, + nvinfer1::ILogger* trt_logger); ////////////////////////////////////////////////////////////////////////////// // Methods used by the TRT engine builder to build a TRT network from a TF @@ -467,7 +467,10 @@ class Converter { const std::vector& output_tensors); // Build a TRT engine using the created network. 
- Status BuildCudaEngine(TrtUniquePtrType* engine); + Status BuildCudaEngine(TrtUniquePtrType* engine, + int max_batch_size, size_t max_workspace_size_bytes, + nvinfer1::IGpuAllocator* allocator, + TRTInt8Calibrator* calibrator); ////////////////////////////////////////////////////////////////////////////// // Methods used by op converters to convert individual TF node and add layers @@ -526,10 +529,10 @@ class Converter { const nvinfer1::Dims& dims); private: - Converter(nvinfer1::IBuilder* trt_builder, TrtPrecisionMode precision_mode, - bool use_calibration, nvinfer1::ILogger* trt_logger); + Converter(TrtPrecisionMode precision_mode, bool use_calibration, + nvinfer1::ILogger* trt_logger); - Status Init(); + Status Init(nvinfer1::ILogger* trt_logger); // Verify the provided batch_size is consistent with batch_size_ and update it // if necessary. @@ -560,7 +563,7 @@ class Converter { std::unordered_map trt_tensors_; // The TRT builder used to create the network and build the engine. Not owned. - nvinfer1::IBuilder* trt_builder_; + TrtUniquePtrType trt_builder_; // The TRT network being built. TrtUniquePtrType trt_network_; diff --git a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc index 7617f10a372..ddfeb1a6b5a 100644 --- a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc +++ b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc @@ -77,8 +77,7 @@ Status ConvertOutputInfo(const tf2xla::Config& config, array_names.push_back(fetch.id().node_name()); } - return ParseOutputArrayInfo(array_names, &specs->output_arrays, - &specs->output_arrays_order); + return ParseOutputArrayInfo(array_names, &specs->outputs); } } // namespace @@ -110,7 +109,7 @@ Status ConvertGraphDefToXlaViaMlir(const GraphDef& graph_def, AddDevicesToOp(*module, &device_set); TF_RETURN_IF_ERROR(mlir::TF::RunBridgeWithStandardPipeline( - *module, /*enable_logging=*/VLOG_IS_ON(1))); + *module, /*enable_logging=*/VLOG_IS_ON(1), /*enable_inliner=*/true)); // Convert the MLIR module to XLA computation. If the input graph can't be // lowered down to a single graph node with a single island by the previous diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index e82546def46..8dc44eac51a 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -503,8 +503,7 @@ Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) { ParseShardingFromDevice( *possible_match, /*num_cores_per_replica=*/std::numeric_limits::max())); - if (sharding.has_value()) { - TF_RET_CHECK(sharding.value().type() == xla::OpSharding::MAXIMAL); + if (sharding && sharding->type() == xla::OpSharding::MAXIMAL) { const int core_annotation = sharding.value().tile_assignment_devices(0); if (core == -1 || core > core_annotation) { core = core_annotation; diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 7bb1ad27467..74247bbaec7 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -111,6 +111,13 @@ DataType XlaHelpers::SumAccumulationType(const DataType& dtype) { if (dtype == DT_BFLOAT16 || dtype == DT_HALF) { return DT_FLOAT; } + // Upcast small integer types to 32 bit to avoid overflow. 
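The integer branches added just below exist purely to avoid accumulator overflow, mirroring the existing bf16/half-to-float rule; a quick NumPy illustration of why (not TF code):

import numpy as np

# DT_BFLOAT16 / DT_HALF already accumulate in DT_FLOAT; the new branches do
# the analogous thing for small integers: int8/int16 -> int32 and
# uint8/uint16 -> uint32.
x = np.full(100, 100, dtype=np.int8)
print(x.sum(dtype=np.int8))   # accumulates in int8 and wraps: 10000 % 256 == 16
print(x.sum(dtype=np.int32))  # 10000, the intended sum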
+ if (dtype == DT_INT8 || dtype == DT_INT16) { + return DT_INT32; + } + if (dtype == DT_UINT8 || dtype == DT_UINT16) { + return DT_UINT32; + } return dtype; } diff --git a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py index 52d6444e5bb..64c85b37504 100644 --- a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py +++ b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py @@ -188,7 +188,21 @@ def assign_device(tensor, device, assign_tuple_sharding=False): return tensor -def tile(tensor, tile_assignment, assign_tuple_sharding=False): +def tile(tensor, + tile_assignment, + assign_tuple_sharding=False, + use_sharding_op=False): + """Returns a tensor that has tiled sharding. + + Args: + tensor: A tf.Tensor to shard. + tile_assignment: An np.ndarray describing the topology of the tiling and + which device will compute which part of the topology. + assign_tuple_sharding: If the sharding type should be a tuple. + use_sharding_op: If true, adds a sharding op to set the sharding. + """ + if use_sharding_op: + tensor = tf2xla.sharding(tensor) Sharding.tile(tile_assignment).apply_to_tensor( tensor, assign_tuple_sharding=assign_tuple_sharding diff --git a/tensorflow/compiler/xla/parse_flags_from_env.cc b/tensorflow/compiler/xla/parse_flags_from_env.cc index e1e22f78417..3a914c694dc 100644 --- a/tensorflow/compiler/xla/parse_flags_from_env.cc +++ b/tensorflow/compiler/xla/parse_flags_from_env.cc @@ -17,9 +17,12 @@ limitations under the License. // modules to parse flags from an environtment variable, or a file named by the // environment variable. +#include "tensorflow/compiler/xla/parse_flags_from_env.h" + #include #include #include + #include #include #include @@ -28,7 +31,6 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" -#include "tensorflow/compiler/xla/parse_flags_from_env.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" @@ -218,9 +220,9 @@ bool ParseFlagsFromEnvAndDieIfUnknown( alternate_envvar); } - LOG(FATAL) << "Unknown flag" << (unknown_flags.size() > 1 ? "s" : "") - << " in " << envvar << ": " << absl::StrJoin(unknown_flags, " ") - << did_you_mean; + LOG(QFATAL) << "Unknown flag" << (unknown_flags.size() > 1 ? "s" : "") + << " in " << envvar << ": " << absl::StrJoin(unknown_flags, " ") + << did_you_mean; return false; } return result; diff --git a/tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.h b/tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.h index a3d146ff299..a64df225f2a 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.h +++ b/tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.h @@ -24,7 +24,7 @@ namespace tpu_driver { xla::StatusOr> CreateGrpcTpuDriver( const TpuDriverConfig& config, - std::shared_ptr credentials); + std::shared_ptr credentials); } // namespace tpu_driver diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index 4398561c5c4..054c1da9e03 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -110,6 +110,17 @@ StatusOr GetComputationHloDotGraph( RenderedGraphFormat::kDot); } +// Hashes the HLO module. 
+StatusOr HashComputation(const XlaComputation& computation) { + TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config, + HloModule::CreateModuleConfigFromProto( + computation.proto(), GetDebugOptionsFromFlags())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr hlo_module, + HloModule::CreateFromProto(computation.proto(), module_config)); + return hlo_module->Hash(); +} + // Registers a 'fn_capsule' as a CPU custom call target. // 'fn_capsule' must be a void* pointer encapsulated in a PyCapsule object, // with name "xla._CUSTOM_CALL_TARGET". @@ -577,7 +588,8 @@ PYBIND11_MODULE(xla_extension, m) { .def("GetProgramShape", &XlaComputation::GetProgramShape) .def("GetSerializedProto", &GetComputationSerializedProto) .def("GetHloText", &GetComputationHloText) - .def("GetHloDotGraph", &GetComputationHloDotGraph); + .def("GetHloDotGraph", &GetComputationHloDotGraph) + .def("Hash", &HashComputation); py::class_(m, "XlaOp"); diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 65db35a6988..c8f66f704d7 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -594,6 +594,9 @@ class Computation(object): def GetReturnValueShape(self): return self._c_computation.GetProgramShape().result_shape() + def Hash(self): + return self._c_computation.Hash() + # An Executable is a C++ class that duck types with the following API: # class Executable(object): diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index db1c01293ab..f490a05e25d 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -112,6 +112,24 @@ class ComputationPrinting(absltest.TestCase): self.assertTrue(hlo_dot_graph.startswith("digraph ")) +class ComputationHashTest(absltest.TestCase): + + def testHash(self): + builder0 = xla_client.ComputationBuilder("computation0") + p0 = builder0.ParameterFromNumpy(np.float32(0)) + p1 = builder0.ParameterFromNumpy(np.zeros((4,), np.float32)) + builder0.Mul(p0, p1) + computation0 = builder0.Build() + + builder1 = xla_client.ComputationBuilder("computation1") + p0 = builder1.ParameterFromNumpy(np.float32(0)) + p1 = builder1.ParameterFromNumpy(np.zeros((4,), np.float32)) + builder1.Mul(p0, p1) + computation1 = builder1.Build() + + self.assertEqual(computation0.Hash(), computation1.Hash()) + + class ComputationsWithConstantsTest(ComputationTest): """Tests focusing on Constant ops.""" diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 23d203850fc..a6300d2dc73 100755 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -4197,6 +4197,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/service:hlo_replication_analysis", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", ], diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 4b6b91af122..2fe8c309cb0 100755 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -1825,7 +1825,8 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { // If the lhs or rhs have only batch and contracting dimensions, a dot can be // rewritten as 
reduce(mul(broadcast(transpose(x)),broadcast(transpose(y)))) if (options_.enable_dot_strength_reduction() && - ShapeUtil::ElementIsFloating(dot->shape()) && + (ShapeUtil::ElementIsFloating(dot->shape()) || + ShapeUtil::ElementIsComplex(dot->shape())) && ((dot->dot_dimension_numbers().lhs_batch_dimensions_size() + dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == lhs->shape().rank()) || @@ -1886,7 +1887,10 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { MakeBinaryHlo(HloOpcode::kMultiply, new_lhs, new_rhs)); std::vector reduce_dims( dot->dot_dimension_numbers().lhs_contracting_dimensions_size()); - PrimitiveType dot_type = dot->shape().element_type() == F64 ? F64 : F32; + PrimitiveType dot_type = + ShapeUtil::ElementIsComplex(dot->shape()) + ? dot->shape().element_type() + : dot->shape().element_type() == F64 ? F64 : F32; new_dot = AsType(new_dot, dot_type); const int64 outer_dims = std::max(rhs_outer_dims, lhs_outer_dims); absl::c_iota( diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 2618a12673f..88282986560 100755 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -4706,12 +4706,11 @@ TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) { EXPECT_EQ(has_no_dot, dot_should_be_transformed); } -INSTANTIATE_TEST_SUITE_P(BatchDotStrengthReductionTestInstantiation, - BatchDotStrengthReductionTest, - ::testing::Combine(::testing::Values(-1, 1, 2), - ::testing::Values(-1, 1, 2), - ::testing::Values(-1, 1, 2), - ::testing::Values(F64, F32, BF16))); +INSTANTIATE_TEST_SUITE_P( + BatchDotStrengthReductionTestInstantiation, BatchDotStrengthReductionTest, + ::testing::Combine(::testing::Values(-1, 1, 2), ::testing::Values(-1, 1, 2), + ::testing::Values(-1, 1, 2), + ::testing::Values(C128, C64, F64, F32, BF16))); class DotStrengthReductionTest : public AlgebraicSimplifierTest, @@ -4775,7 +4774,8 @@ INSTANTIATE_TEST_SUITE_P( DotStrengthReductionTestInstantiation, DotStrengthReductionTest, ::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2), ::testing::Values(1, 2), ::testing::Bool(), - ::testing::Bool(), ::testing::Values(F64, F32, BF16))); + ::testing::Bool(), + ::testing::Values(C128, C64, F64, F32, BF16))); struct DotOfConcatTestSpec { int64 m; diff --git a/tensorflow/compiler/xla/service/all_reduce_simplifier.cc b/tensorflow/compiler/xla/service/all_reduce_simplifier.cc index e541bfea11f..541006f04d5 100644 --- a/tensorflow/compiler/xla/service/all_reduce_simplifier.cc +++ b/tensorflow/compiler/xla/service/all_reduce_simplifier.cc @@ -28,12 +28,17 @@ limitations under the License. namespace xla { StatusOr AllReduceSimplifier::Run(HloModule* module) { - TF_ASSIGN_OR_RETURN(auto replication, HloReplicationAnalysis::Run(module)); + TF_ASSIGN_OR_RETURN( + auto replication, + HloReplicationAnalysis::Run(module, /*cross_partition_spmd=*/false)); std::vector all_reduces_to_replace; for (auto computation : module->computations()) { for (HloInstruction* inst : computation->MakeInstructionPostOrder()) { if (!inst->shape().IsArray()) { // We currently do not change tuple-shaped all-reduce. + // Until XLA supports token-fed AllReduce(), the PyTorch client code + // uses a fake data token (a constant) fed within a tuple input, and + // relies on this pass not to optimize it out.
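To see why the strength-reduction rewrite mentioned above, dot lowered to reduce(mul(broadcast(transpose(x)),broadcast(transpose(y)))), is also valid for the newly allowed complex element types, here is a small NumPy check (illustrative only):

import numpy as np

rng = np.random.default_rng(0)
x = (rng.standard_normal((4, 3)) + 1j * rng.standard_normal((4, 3))).astype(np.complex64)
y = (rng.standard_normal(3) + 1j * rng.standard_normal(3)).astype(np.complex64)

# A dot with a vector rhs rewritten as an elementwise multiply broadcast over
# the contracting dimension followed by a reduce-sum; complex multiply/add
# make this just as valid for C64/C128 as for floating-point types, which is
# why the pass now keeps the complex element type instead of forcing F32/F64.
rewritten = (x * y).sum(axis=1)
assert np.allclose(rewritten, x @ y)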
continue; } if (inst->IsCrossReplicaAllReduce() && diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.cc b/tensorflow/compiler/xla/service/ar_crs_combiner.cc index ae39906ef52..06aaad351e6 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_replication_analysis.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -240,7 +241,8 @@ bool ArCrsCombiner::TupleElementsComputeSameValue( /* static */ bool ArCrsCombiner::TestInstructionsComputeSameValue(HloInstruction* i1, HloInstruction* i2) { - ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/1); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/1, + /*spmd_partition=*/false); auto module = i1->parent()->parent(); CHECK_EQ(module, i2->parent()->parent()); combiner.call_graph_ = CallGraph::Build(module); @@ -363,14 +365,14 @@ void ArCrsCombiner::GroupAllReducesById(HloModule* module) { } } -void ArCrsCombiner::KeepProvablyEqualInstructionGroups() { +Status ArCrsCombiner::KeepProvablyEqualInstructionGroupsMPMD() { for (auto it : all_reduce_map_) { auto channel_id = it.first; VLOG(2) << "KeepProvablyEqualInstructionGroups. Checking AllReduce channel id: " << channel_id << "\n"; auto pairs_vec = it.second; - CHECK_EQ(pairs_vec.size(), num_spatial_partitions_); + TF_RET_CHECK(pairs_vec.size() == num_spatial_partitions_); auto instr_0 = pairs_vec[0].ar; for (int i = 1; i < pairs_vec.size(); ++i) { auto instr_i = pairs_vec[i].ar; @@ -393,6 +395,44 @@ void ArCrsCombiner::KeepProvablyEqualInstructionGroups() { } } } + return Status::OK(); +} + +Status ArCrsCombiner::KeepProvablyEqualInstructionGroupsSPMD( + HloModule* module) { + // For SPMD mode, use HloReplicationAnalysis to figure out HLO value + // equivalence across partitions. + TF_ASSIGN_OR_RETURN( + auto replication_analysis, + HloReplicationAnalysis::Run(module, /*cross_partition_spmd=*/true)); + + for (auto it : all_reduce_map_) { + auto channel_id = it.first; + VLOG(2) + << "KeepProvablyEqualInstructionGroups. Checking AllReduce channel id: " + << channel_id << "\n"; + auto pairs_vec = it.second; + TF_RET_CHECK(pairs_vec.size() == 1); + auto instr = pairs_vec[0].ar; + auto next = instr->users()[0]; + while (true) { + // The patterns we detect in ArCrsCombiner::MatchesArCrsPattern() + // guarantee that the HLO produces an array. + TF_RET_CHECK(next->shape().IsArray()); + if (!replication_analysis->HloInstructionIsReplicatedAt(next, {})) { + all_reduce_map_.erase(channel_id); + VLOG(2) << "KeepProvablyEqualInstructionGroups. 
Erased AllReduce " + "channel id: " + << channel_id << "\n"; + break; + } + if (next->IsCrossReplicaAllReduce()) { + break; + } + next = next->users()[0]; + } + } + return Status::OK(); } StatusOr ArCrsCombiner::RewriteGraph() { @@ -460,7 +500,11 @@ StatusOr ArCrsCombiner::Run(HloModule* module) { GroupAllReducesById(module); - KeepProvablyEqualInstructionGroups(); + if (spmd_partition_) { + TF_RETURN_IF_ERROR(KeepProvablyEqualInstructionGroupsSPMD(module)); + } else { + TF_RETURN_IF_ERROR(KeepProvablyEqualInstructionGroupsMPMD()); + } return RewriteGraph(); } diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.h b/tensorflow/compiler/xla/service/ar_crs_combiner.h index a85e18d328c..95443c0c74a 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner.h +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.h @@ -25,18 +25,21 @@ limitations under the License. namespace xla { -// When the HLO graph contains a cross-module AllReduce, followed by some simple -// linear operations, followed by a cross-replica AllReduce (also known as -// cross-replica sum, or CRS), we can combine the CMAR and the CRAR, to use an -// efficient AllReduce implementation that fully utilizes the interconnect -// bandwidth. -// Such sequences appear in spatially partitioned models. +// When the HLO graph contains a cross-module AllReduce (N separate AllReduce +// ops that share the same channel_id for MPMD partitioning, or 1 AllReduce op +// for SPMD partitioning), followed by some simple linear operations, followed +// by a cross-replica AllReduce (also known as cross-replica sum, or CRS), we +// can combine the CMAR and the CRAR, to use an efficient AllReduce +// implementation that fully utilizes the interconnect bandwidth. +// +// Such sequences appear in spatially partitioned models (either MPMD or SPMD). // This pass must run right after spatial partitioning, when the code is still // in a single HLO module. // // The steps are: // 1) Find CMARs followed by simple ops followed by CRARs. -// 2) Group CMARs by channel_id. They must all be rewritten. +// 2) Group CMARs by channel_id. They must all be rewritten. For SPMD +// partitioning, there will only be a single CMAR for each channel_id. // 3) Prove that the CMAR patterns in each core produce the same result. // 4) Eliminate the CMAR, and if it feeds an addition/subtraction, divide the // other operand by the number of spatial partitions. @@ -69,9 +72,11 @@ namespace xla { // class ArCrsCombiner : public HloModulePass { public: - ArCrsCombiner(int num_spatial_partitions, int num_replicas) + ArCrsCombiner(int num_spatial_partitions, int num_replicas, + bool spmd_partition) : num_spatial_partitions_(num_spatial_partitions), - num_replicas_(num_replicas) {} + num_replicas_(num_replicas), + spmd_partition_(spmd_partition) {} absl::string_view name() const override { return "ar-crs-combiner"; } StatusOr Run(HloModule* module) override; @@ -153,7 +158,10 @@ class ArCrsCombiner : public HloModulePass { // Looks at each AllReduce group in all_reduce_map_, and keeps only the // groups for which it's safe to move the AllReduce later in the HLO graph. - void KeepProvablyEqualInstructionGroups(); + Status KeepProvablyEqualInstructionGroupsMPMD(); + + // Same as above, but runs on SPMD partitioned module instead of MPMD. + Status KeepProvablyEqualInstructionGroupsSPMD(HloModule* module); // Performs the graph rewrite that eliminates the early AllReduce and turns // the later CRS into an AllReduce. 
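The "divide the other summand by the number of spatial partitions" step referenced in the rewritten header comment is what keeps the combined AllReduce numerically equivalent; a small NumPy sanity check of that identity (illustrative, with made-up variable names):

import numpy as np

R, P = 2, 4                     # replicas, spatial partitions
x = np.arange(R * P, dtype=np.float64).reshape(R, P)
c = 5.0                         # constant summand between the two all-reduces

# Original pattern: per-replica cross-module AllReduce over partitions,
# add c, then a cross-replica AllReduce over replicas.
original = sum(x[r, :].sum() + c for r in range(R))

# Combined pattern: drop the first AllReduce, divide c by the partition
# count, and run a single AllReduce over all replicas * partitions.
combined = (x + c / P).sum()

assert np.isclose(original, combined)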
@@ -163,6 +171,15 @@ class ArCrsCombiner : public HloModulePass { int num_replicas_; + // Run this combiner pass assuming the input module is an SPMD partitioned + // module (as opposed to MPMD partitioned). + // + // The main difference between the two w.r.t. this pass is that there would be + // N all-reduce ops for each channel in MPMD mode, whereas there is only 1 + // for each channel in SPMD mode. Also we use HloReplicationAnalysis for HLO + // equivalence check in SPMD mode. + bool spmd_partition_; + // Map from all-reduce ids to the AR/CRS pairs. absl::flat_hash_map> all_reduce_map_; diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc index accc0684e8e..609da2c33a0 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc @@ -452,7 +452,8 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) { auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); - ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + /*spmd_partition=*/false); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_TRUE(changed); EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -464,6 +465,55 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) { CompareReplicaGroups(replica_groups_before, replica_groups_after); } +TEST_F(ArCrsCombinerTest, RewriteArConvertCrsSPMD) { + const char* module_str = R"( +HloModule foobar + +%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] { + %a = bf16[] parameter(0) + %b = bf16[] parameter(1) + ROOT %add = bf16[] add(%a, %b) +} + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: bf16[]) -> (f32[]) { + %p = bf16[] parameter(0) + %all-reduce.ar.1 = bf16[] + all-reduce(%p), + replica_groups={{0},{1}}, + channel_id=1, + to_apply=%sum.bf16 + %convert.1 = f32[] convert(%all-reduce.ar.1) + %all-reduce.1 = f32[] + all-reduce(%convert.1), + replica_groups={{0,1}}, + to_apply=%sum.f32 + ROOT %tuple = (f32[]) tuple(%all-reduce.1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + true); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce(op::Convert(op::Parameter())))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + TEST_F(ArCrsCombinerTest, RewriteArBitcastCrs) { const char* module_str = R"( HloModule foobar @@ -520,7 +570,8 @@ ENTRY %entrycomp (p: f32[2,1]) -> (f32[2], f32[2]) { auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); - ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + /*spmd_partition=*/false); auto 
changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_TRUE(changed); EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -587,7 +638,8 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); - ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + /*spmd_partition=*/false); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_TRUE(changed); EXPECT_THAT( @@ -600,6 +652,47 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { CompareReplicaGroups(replica_groups_before, replica_groups_after); } +TEST_F(ArCrsCombinerTest, RewriteArMultiplyCrsSPMD) { + const char* module_str = R"( +HloModule foobar + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[]) { + %p = f32[] parameter(0) + %constant.f32 = f32[] constant(123) + + %all-reduce.ar.1 = f32[] all-reduce(%p), replica_groups={{0},{1}}, + channel_id=1, to_apply=%sum.f32 + %multiply.1 = f32[] multiply(%all-reduce.ar.1, %constant.f32) + %all-reduce.1 = f32[] all-reduce(%multiply.1), replica_groups={{0,1}}, + to_apply=%sum.f32, sharding={maximal device=0} + ROOT %tuple = (f32[]) tuple(%all-reduce.1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + /*spmd_partition=*/true); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce(op::Multiply(op::Parameter(), op::Constant())))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + TEST_F(ArCrsCombinerTest, RewriteArConvertAddCrs) { const char* module_str = R"( HloModule foobar @@ -668,7 +761,8 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); - ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + /*spmd_partition=*/false); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_TRUE(changed); EXPECT_THAT( @@ -684,6 +778,55 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { CompareReplicaGroups(replica_groups_before, replica_groups_after); } +TEST_F(ArCrsCombinerTest, RewriteArConvertAddCrsSPMD) { + const char* module_str = R"( +HloModule foobar + +%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] { + %a = bf16[] parameter(0) + %b = bf16[] parameter(1) + ROOT %add = bf16[] add(%a, %b) +} + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[]) { + %p = f32[] parameter(0) + %constant.bf16 = bf16[] constant(1) + %constant.f32 = f32[] constant(2) + + %all-reduce.ar.1 = bf16[] all-reduce(%constant.bf16), replica_groups={{0},{1}}, + 
channel_id=1, to_apply=%sum.bf16 + %convert.1 = f32[] convert(%all-reduce.ar.1), sharding={maximal device=0} + %add.1 = f32[] add(%constant.f32, %convert.1) + %all-reduce.1 = f32[] all-reduce(%add.1), replica_groups={{0,1}}, + to_apply=%sum.f32 + ROOT %tuple = (f32[]) tuple(%all-reduce.1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + /*spmd_partition=*/true); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce(op::Add( + op::Divide(op::Constant(), op::Constant()), op::Convert())))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + TEST_F(ArCrsCombinerTest, OtherSummandNotTheSameDontRewrite) { const char* module_str = R"( HloModule foobar @@ -750,7 +893,46 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(module_str)); - ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + /*spmd_partition=*/false); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_FALSE(changed); +} + +TEST_F(ArCrsCombinerTest, OtherSummandNotTheSameDontRewriteSPMD) { + const char* module_str = R"( +HloModule foobar + +%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] { + %a = bf16[] parameter(0) + %b = bf16[] parameter(1) + ROOT %add = bf16[] add(%a, %b) +} + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[]) { + %p = f32[] parameter(0) + %constant.bf16 = bf16[] constant(1) + %constant.f32.1 = f32[] constant(2) + + %all-reduce.ar.1 = bf16[] all-reduce(%constant.bf16), replica_groups={{0},{1}}, + channel_id=1, to_apply=%sum.bf16 + %convert.1 = f32[] convert(%all-reduce.ar.1) + %add.1 = f32[] add(%p, %convert.1) + %all-reduce.1 = f32[] all-reduce(%add.1), replica_groups={{0,1}}, to_apply=%sum.f32 + ROOT %tuple = (f32[]) tuple(%all-reduce.1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + /*spmd_partition=*/true); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_FALSE(changed); } @@ -810,7 +992,8 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); - ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + /*spmd_partition=*/false); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_TRUE(changed); EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -884,7 +1067,8 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); - ArCrsCombiner 
combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + /*spmd_partition=*/false); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_TRUE(changed); EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -902,6 +1086,50 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { CompareReplicaGroups(replica_groups_before, replica_groups_after); } +TEST_F(ArCrsCombinerTest, RewriteMultipleAddsSPMD) { + const char* module_str = R"( +HloModule foobar + +%sum (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[]) { + %p = f32[] parameter(0) + %constant.1 = f32[] constant(1) + %constant.2 = f32[] constant(2) + + %all-reduce.ar.1 = f32[] all-reduce(%p), replica_groups={{0},{1}}, + channel_id=1, to_apply=%sum + %add.11 = f32[] add(%constant.1, %all-reduce.ar.1) + %add.12 = f32[] add(%constant.2, %add.11) + %all-reduce.1 = f32[] all-reduce(%add.12), replica_groups={{0,1}}, to_apply=%sum + ROOT %tuple = (f32[]) tuple(%all-reduce.1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + /*spmd_partition=*/true); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce( + op::Add(op::Divide(op::Constant(), op::Constant()), + op::Add(op::Divide(op::Constant(), op::Constant()), + op::Parameter()))))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + TEST_F(ArCrsCombinerTest, RewriteArSubtractCrs) { const char* module_str = R"( HloModule foobar @@ -957,7 +1185,8 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); - ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + /*spmd_partition=*/false); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_TRUE(changed); EXPECT_THAT( @@ -973,6 +1202,47 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { CompareReplicaGroups(replica_groups_before, replica_groups_after); } +TEST_F(ArCrsCombinerTest, RewriteArSubtractCrsSPMD) { + const char* module_str = R"( +HloModule foobar + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[]) { + %p = f32[] parameter(0) + %constant.f32 = f32[] constant(123) + %all-reduce.ar.1 = f32[] all-reduce(%p), replica_groups={{0},{1}}, + channel_id=1, to_apply=%sum.f32 + %sub.1 = f32[] subtract(%constant.f32, %all-reduce.ar.1) + %all-reduce.1 = f32[] all-reduce(%sub.1), replica_groups={{0,1}}, + to_apply=%sum.f32 + ROOT %tuple = (f32[]) tuple(%all-reduce.1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + 
module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + /*spmd_partition=*/true); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce(op::Subtract( + op::Divide(op::Constant(), op::Constant()), op::Parameter())))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + TEST_F(ArCrsCombinerTest, RewriteMultipleARsLeft) { const char* module_str = R"( HloModule foobar @@ -1047,7 +1317,8 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); - ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + /*spmd_partition=*/false); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_TRUE(changed); EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -1065,6 +1336,53 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { CompareReplicaGroups(replica_groups_before, replica_groups_after); } +TEST_F(ArCrsCombinerTest, RewriteMultipleARsLeftSPMD) { + const char* module_str = R"( +HloModule foobar + +%sum (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[]) { + %p = f32[] parameter(0) + %const1 = f32[] constant(1) + %const2 = f32[] constant(2) + + %ar11 = f32[] all-reduce(%p), replica_groups={{0},{1}}, channel_id=1, + to_apply=%sum + %add11 = f32[] add(%ar11, %const1) + %ar12 = f32[] all-reduce(%p), replica_groups={{0},{1}}, channel_id=2, + to_apply=%sum + %add12 = f32[] add(%add11, %ar12) + %crs1 = f32[] all-reduce(%add12), replica_groups={{0,1}}, + to_apply=%sum + ROOT %tuple = (f32[]) tuple(%crs1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + /*spmd_partition=*/true); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce(op::Add( + op::Add(op::Parameter(), op::Divide(op::Constant(), op::Constant())), + op::Parameter())))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + TEST_F(ArCrsCombinerTest, RewriteMultipleARsRight) { const char* module_str = R"( HloModule foobar @@ -1139,7 +1457,8 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); - ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + /*spmd_partition=*/false); auto changed = 
combiner.Run(module.get()).ValueOrDie(); EXPECT_TRUE(changed); EXPECT_THAT( @@ -1159,6 +1478,51 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { CompareReplicaGroups(replica_groups_before, replica_groups_after); } +TEST_F(ArCrsCombinerTest, RewriteMultipleARsRightSPMD) { + const char* module_str = R"( +HloModule foobar + +%sum (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[]) { + %p = f32[] parameter(0) + %const1 = f32[] constant(1) + %const2 = f32[] constant(2) + + %ar11 = f32[] all-reduce(%p), replica_groups={{0},{1}}, channel_id=1, to_apply=%sum + %ar12 = f32[] all-reduce(%p), replica_groups={{0},{1}}, channel_id=2, to_apply=%sum + %add11 = f32[] add(%ar12, %const1) + %add12 = f32[] add(%ar11, %add11) + %crs1 = f32[] all-reduce(%add12), replica_groups={{0,1}}, to_apply=%sum + ROOT %tuple = (f32[]) tuple(%crs1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + /*spmd_partition=*/true); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce(op::Add( + op::Parameter(), + op::Add(op::Parameter(), + op::Divide(op::Constant(), op::Constant())))))); + + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + TEST_F(ArCrsCombinerTest, OneReplicaDontRewrite) { const char* module_str = R"( HloModule foobar @@ -1217,7 +1581,45 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(module_str)); - ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/1, + /*spmd_partition=*/false); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_FALSE(changed); +} + +TEST_F(ArCrsCombinerTest, OneReplicaDontRewriteSPMD) { + const char* module_str = R"( +HloModule foobar + +%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] { + %a = bf16[] parameter(0) + %b = bf16[] parameter(1) + ROOT %add = bf16[] add(%a, %b) +} + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: bf16[]) -> (f32[]) { + %p = bf16[] parameter(0) + %constant.bf16 = bf16[] constant(1) + + %all-reduce.ar.1 = bf16[] all-reduce(%p), replica_groups={{0}}, + channel_id=1, to_apply=%sum.bf16 + %convert.1 = f32[] convert(%all-reduce.ar.1) + %all-reduce.1 = f32[] all-reduce(%convert.1), + replica_groups={{0}}, to_apply=%sum.f32 + ROOT %tuple = (f32[]) tuple(%all-reduce.1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/1, + /*spmd_partition=*/true); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_FALSE(changed); } @@ -1291,7 +1693,36 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(module_str)); - ArCrsCombiner 
combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + /*spmd_partition=*/false); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_FALSE(changed); +} + +TEST_F(ArCrsCombinerTest, AllReduceWithReplicasSPMD) { + const char* module_str = R"( +HloModule foobar + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: bf16[]) -> (f32[]) { + %p = bf16[] parameter(0) + %all-reduce.0 = f32[] all-reduce(%p), channel_id=1, replica_groups={{0,1}}, + to_apply=%sum.f32 + %all-reduce.2 = f32[] all-reduce(%all-reduce.0), replica_groups={{0,1}}, + to_apply=%sum.f32 + ROOT %tuple = (f32[]) tuple(%all-reduce.2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + /*spmd_partition=*/true); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_FALSE(changed); } diff --git a/tensorflow/compiler/xla/service/collective_ops_utils.h b/tensorflow/compiler/xla/service/collective_ops_utils.h index f2fa9640f85..2c5f2d64d1f 100644 --- a/tensorflow/compiler/xla/service/collective_ops_utils.h +++ b/tensorflow/compiler/xla/service/collective_ops_utils.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COLLECTIVE_OPS_UTILS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_COLLECTIVE_OPS_UTILS_H_ +#include #include #include "tensorflow/compiler/xla/executable_run_options.h" @@ -188,6 +189,48 @@ class Rendezvous { virtual ~Rendezvous() {} explicit Rendezvous(const RendezvousKey& k) : key_(k) {} + // Submit a participant to the rendezvous. We get the rendezvous from + // `rendezvous_getter`, which we can then use to drop the existing reference. + static StatusOr SubmitParticipant( + std::function>()> rendezvous_getter, + AllReduceParticipantData participant) { + std::shared_ptr> rendezvous = rendezvous_getter(); + TF_ASSIGN_OR_RETURN(auto p, rendezvous->SubmitParticipant(participant)); + + // Drop our reference to the Rendezvous and wait for all other threads to do + // the same. If we didn't do this, one of the threads could run past this + // point, reenter ExecuteOnStream for another all-reduce, and attempt to + // reuse the Rendezvous! + // + // An alternative way of accomplishing this goal would be to implement + // RefcountingHashMap::erase() and call it during SubmitParticipant. But + // erase() is deceptively complex to implement correctly. + std::shared_ptr blocking_counter = p.second; + rendezvous.reset(); + blocking_counter->DecrementCount(); + xla::WaitAndLogIfStuck(blocking_counter.get(), [&] { + return absl::StrFormat( + "participant waiting for all threads to drop their reference to the " + "rendezvous: %p", + rendezvous.get()); + }); + return p.first; + } + + protected: + // Returns domain-specific output O and whether this replica is primary. + virtual StatusOr> SubmitParticipantImpl( + AllReduceParticipantData participant) = 0; + + virtual void CleanupImpl(O handle, bool is_primary) {} + + tensorflow::mutex mu_; + + bool initialized_ GUARDED_BY(mu_) = false; + + std::vector participants_ GUARDED_BY(mu_); + + private: // Runs the all-reduce on the given thread. If successful, returns // - a handle to the clique that was used, so that the caller may keep the // clique alive if it chooses. 
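// Sketch (not part of the patch): the reference-dropping protocol that the new
// static Rendezvous::SubmitParticipant helper above centralizes. Every
// participant obtains the rendezvous from a getter, releases its shared_ptr
// *before* waiting, and then blocks on a shared counter until all other
// participants have released theirs; only then may any thread start the next
// collective. The Toy* types below are hypothetical stand-ins, not XLA types.
#include <condition_variable>
#include <functional>
#include <memory>
#include <mutex>

class ToyBlockingCounter {
 public:
  explicit ToyBlockingCounter(int n) : count_(n) {}
  void DecrementAndWait() {
    std::unique_lock<std::mutex> lock(mu_);
    if (--count_ == 0) cv_.notify_all();
    cv_.wait(lock, [this] { return count_ == 0; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int count_;
};

struct ToyRendezvous {
  explicit ToyRendezvous(int num_participants)
      : counter(std::make_shared<ToyBlockingCounter>(num_participants)) {}
  std::shared_ptr<ToyBlockingCounter> counter;
};

// Mirrors SubmitParticipant: fetch the rendezvous lazily, do the collective,
// drop the local reference, then wait for everyone else to drop theirs.
void SubmitToyParticipant(
    const std::function<std::shared_ptr<ToyRendezvous>()>& rendezvous_getter) {
  std::shared_ptr<ToyRendezvous> rendezvous = rendezvous_getter();
  std::shared_ptr<ToyBlockingCounter> counter = rendezvous->counter;
  // ... this participant's share of the all-reduce would run here ...
  rendezvous.reset();           // drop our reference first,
  counter->DecrementAndWait();  // then wait until all participants have too.
}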
@@ -248,21 +291,6 @@ class Rendezvous { return std::make_pair(handle, returned_blocking_counter_); } - - protected: - // Returns domain-specific output O and whether this replica is primary. - virtual StatusOr> SubmitParticipantImpl( - AllReduceParticipantData participant) = 0; - - virtual void CleanupImpl(O handle, bool is_primary) {} - - tensorflow::mutex mu_; - - bool initialized_ GUARDED_BY(mu_) = false; - - std::vector participants_ GUARDED_BY(mu_); - - private: const RendezvousKey key_; tensorflow::BlockingCounter all_participants_present_{ diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 411ae8f7d64..bec66aea27f 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -242,9 +242,15 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_dataflow_analysis", "//tensorflow/compiler/xla/service:hlo_execution_profile", "//tensorflow/compiler/xla/service:logical_buffer", + "//tensorflow/compiler/xla/service:maybe_owning_device_memory", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:logging", + "//tensorflow/core/platform:macros", + "//tensorflow/core/platform:mutex", + "//tensorflow/core/platform:platform_port", + "//tensorflow/core/platform:types", "//tensorflow/core/profiler/lib:traceme", "//tensorflow/stream_executor:device_memory_allocator", "//tensorflow/stream_executor/host:host_stream", @@ -482,6 +488,7 @@ cc_library( "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:refcounting_hash_map", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto_cc", @@ -495,6 +502,7 @@ cc_library( "//tensorflow/core/platform:macros", "//tensorflow/core/platform:mutex", "//tensorflow/core/platform:platform_port", + "//tensorflow/core/platform:status", "//tensorflow/core/platform:types", "//tensorflow/core/profiler/lib:traceme", "//tensorflow/stream_executor", diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 9b79e8ca8d7..d19cf4fb015 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -44,6 +45,7 @@ limitations under the License. 
#include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/device_memory_allocator.h" #include "tensorflow/stream_executor/host/host_stream.h" namespace xla { @@ -73,11 +75,12 @@ CpuExecutable::CpuExecutable( << reinterpret_cast(compute_function_); } -StatusOr, - std::vector>> +StatusOr, + std::vector, + std::vector>> CpuExecutable::CreateBufferTable( se::DeviceMemoryAllocator* memory_allocator, int device_ordinal, - absl::Span arguments) { + std::vector> arguments) { std::vector unowning_buffers( assignment_->Allocations().size()); std::vector owning_buffers( @@ -91,8 +94,9 @@ CpuExecutable::CreateBufferTable( VLOG(3) << allocation.ToString(); if (allocation.is_entry_computation_parameter()) { - unowning_buffers[i] = arguments[allocation.parameter_number()]->buffer( - allocation.param_shape_index()); + unowning_buffers[i] = arguments[allocation.parameter_number()] + .element(allocation.param_shape_index()) + .AsDeviceMemoryBase(); CHECK_EQ(allocation.size(), unowning_buffers[i].size()) << "Size mismatch on param " << allocation.parameter_number() << " at shape index " << allocation.param_shape_index().ToString(); @@ -134,7 +138,17 @@ CpuExecutable::CreateBufferTable( assignment_->GetUniqueTopLevelOutputSlice()); VLOG(3) << "result index: " << result_slice.index(); - return {{std::move(unowning_buffers), std::move(owning_buffers)}}; + std::vector buffers_to_free; + for (ShapeTree& argument : arguments) { + for (std::pair& buffer : argument) { + auto maybe_owning_buffer = buffer.second.Release(); + if (maybe_owning_buffer) { + buffers_to_free.push_back(std::move(*maybe_owning_buffer)); + } + } + } + return std::make_tuple(std::move(unowning_buffers), std::move(owning_buffers), + std::move(buffers_to_free)); } Status CpuExecutable::ExecuteComputeFunction( @@ -268,9 +282,9 @@ StatusOr CpuExecutable::CreateResultShapedBuffer( return std::move(result_buffer); } -StatusOr CpuExecutable::ExecuteAsyncOnStream( +StatusOr CpuExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - absl::Span arguments, + std::vector> arguments, HloExecutionProfile* hlo_execution_profile) { if (GetRootValueSet().IsAmbiguous()) { return Unimplemented("Points-to set of root instruction is ambiguous"); @@ -283,7 +297,7 @@ StatusOr CpuExecutable::ExecuteAsyncOnStream( for (int64 i = 0; i < entry_comp->num_parameters(); ++i) { const Shape& expected_shape = entry_comp->parameter_instruction(i)->shape(); - const Shape& actual_shape = arguments[i]->on_device_shape(); + const Shape& actual_shape = arguments[i].shape(); CHECK(expected_shape == actual_shape) << absl::StreamFormat( "Shape mismatch on argument %d. 
Expected %s, but was %s.", i, expected_shape.ToString(/*print_layout=*/true), @@ -297,10 +311,11 @@ StatusOr CpuExecutable::ExecuteAsyncOnStream( se::DeviceMemoryAllocator* memory_allocator = run_options->allocator(); std::vector owning_buffers; std::vector unowning_buffers; + std::vector buffers_to_release; TF_ASSIGN_OR_RETURN( - std::tie(unowning_buffers, owning_buffers), + std::tie(unowning_buffers, owning_buffers, buffers_to_release), CreateBufferTable(memory_allocator, stream->parent()->device_ordinal(), - arguments)); + std::move(arguments))); TF_ASSIGN_OR_RETURN( ScopedShapedBuffer result, @@ -339,7 +354,8 @@ StatusOr CpuExecutable::ExecuteAsyncOnStream( std::move(owning_buffers)), hlo_execution_profile}); - return std::move(result); + return ExecutionOutput(std::move(result), std::move(buffers_to_release), {}, + se::OwningDeviceMemory()); } /*static*/ int64 CpuExecutable::ShapeSizeBytes(const Shape& shape) { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index 37af630a2d9..6f8a7c3315a 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -55,9 +55,9 @@ class CpuExecutable : public Executable { std::unique_ptr hlo_profile_index_map); ~CpuExecutable() override {} - StatusOr ExecuteAsyncOnStream( + StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - absl::Span arguments, + std::vector> arguments, HloExecutionProfile* hlo_execution_profile) override; // This should be called after set_ir_module_string. @@ -96,11 +96,15 @@ class CpuExecutable : public Executable { // allocated by this routine. This routine allocates buffers for temporary // storage and the live-out buffer into which the computation writes it // result. - StatusOr, - std::vector>> + // + // - buffers_to_free: buffers whose ownership was donated by the caller that + // are to be freed by the caller. + StatusOr, + std::vector, + std::vector>> CreateBufferTable(se::DeviceMemoryAllocator* memory_allocator, int device_ordinal, - absl::Span arguments); + std::vector> arguments); // Calls the generated function performing the computation with the given // arguments using the supplied buffers. diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 2cb15f6ec4d..56d663f7b24 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/synchronization/mutex.h" #include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/refcounting_hash_map.h" #include "tensorflow/compiler/xla/service/collective_ops_utils.h" @@ -36,6 +37,7 @@ limitations under the License. 
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mem.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/stream_executor/device_memory.h" @@ -414,40 +416,30 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce( xla::RendezvousKey rendezvous_key(run_options->run_id(), participating_replicas_vec, op_kind, op_id); - std::shared_ptr rendezvous = - GlobalRendezvousMap()[rendezvous_key]; auto shape_str = ShapeString(shape_ptr, shape_length); VLOG(2) << "All-reduce input/output shape : " << shape_str; xla::Shape shape = DecodeSelfDescribingShapeConstant(shape_ptr, shape_length).ValueOrDie(); + CHECK(xla::LayoutUtil::IsDenseArray(shape)) + << "All-reduce on CPU is implemented only for dense arrays"; xla::AllReduceParticipantData participant(rendezvous_key); - - CHECK_LE(shape.dimensions_size(), 1); participant.element_count = xla::ShapeUtil::ElementsIn(shape); participant.device_ordinal = device_ordinal; participant.primitive_type = shape.element_type(); participant.stream = run_options->stream(); - - se::DeviceMemoryBase input(input_buffer, xla::ShapeUtil::ByteSizeOf(shape)); - se::DeviceMemoryBase output(output_buffer, xla::ShapeUtil::ByteSizeOf(shape)); - participant.source_data = input; - participant.destination_data = output; + participant.source_data = + se::DeviceMemoryBase(input_buffer, xla::ShapeUtil::ByteSizeOf(shape)); + participant.destination_data = + se::DeviceMemoryBase(output_buffer, xla::ShapeUtil::ByteSizeOf(shape)); participant.reduction_kind = static_cast(reduction_kind); - auto p = rendezvous->SubmitParticipant(participant).ValueOrDie(); - std::shared_ptr blocking_counter = p.second; - blocking_counter->DecrementCount(); - xla::WaitAndLogIfStuck(blocking_counter.get(), [&] { - return absl::StrFormat( - "participant waiting for all threads to drop their reference to the " - "rendezvous: %s", - rendezvous_key.ToString()); - }); - - rendezvous.reset(); + TF_CHECK_OK( + CpuAllReduceRendezvous::SubmitParticipant( + [&] { return GlobalRendezvousMap()[rendezvous_key]; }, participant) + .status()); } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ReplicaId( diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 9510d1fecde..cf167a57087 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -1511,6 +1511,7 @@ Status IrEmitter::HandleAllReduce(HloInstruction* crs) { } Status IrEmitter::HandleReplicaId(HloInstruction* hlo) { + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo)); llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext()); llvm::FunctionType* replica_id_function_ty = llvm::FunctionType::get(b_.getVoidTy(), diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc index 1274c095b95..14ea6f988cb 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc @@ -419,14 +419,47 @@ Status DynamicDimensionInferenceVisitor::HandleConvolution( Status DynamicDimensionInferenceVisitor::HandleConcatenate( HloInstruction* hlo) { + // First handle concatenate dimensions. We do this by iterating through all + // operands while tracking both dynamic and static dimensions. 
+ + // static_size is used to keep track of the concated size of static + // dimensions. + int64 static_size = 0; + std::vector dynamic_concat_dims; + for (int64 i = 0; i < hlo->operand_count(); ++i) { + HloInstruction* dynamic_size = parent_->GetDynamicSize( + hlo->mutable_operand(i), {}, hlo->concatenate_dimension()); + if (dynamic_size == nullptr) { + // This is a static dimension. + static_size += + hlo->operand(i)->shape().dimensions(hlo->concatenate_dimension()); + } else { + dynamic_concat_dims.push_back(dynamic_size); + } + } + // If concat dimension is dynamic, calculate its size by summing up static + // dims and dynamic dims together. + if (!dynamic_concat_dims.empty()) { + HloInstruction* dim_size_total = + hlo->parent()->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(static_size))); + for (HloInstruction* dynamic_dim : dynamic_concat_dims) { + dim_size_total = hlo->parent()->AddInstruction( + HloInstruction::CreateBinary(dim_size_total->shape(), HloOpcode::kAdd, + dim_size_total, dynamic_dim)); + } + parent_->SetDynamicSize(hlo, {}, hlo->concatenate_dimension(), + dim_size_total, {.stride = 1, .multiple_of = 1}); + } + + // Simply pass through non-concat dynamic dimensions. return ForEachOperandDynamicDimension( hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, int64 operand_index, HloInstruction* dynamic_size, DimensionConstraint constraint) { int64 concatenate_dimension = hlo->concatenate_dimension(); if (concatenate_dimension == dimension) { - return Unimplemented("Dynamic concatenation is not supported yet: %s", - operand->ToString()); + return Status::OK(); } parent_->SetDynamicSize(hlo, index, dimension, dynamic_size, constraint); @@ -1318,9 +1351,9 @@ Status DynamicDimensionInference::ForwardDynamicSize(HloInstruction* inst, DynamicDimension dynamic_dimension{inst, index, dim}; auto iter = dynamic_mapping_.find(dynamic_dimension); if (iter != dynamic_mapping_.end()) { - dynamic_mapping_.try_emplace(dynamic_dimension_new, iter->second); - constraint_mapping_.try_emplace(dynamic_dimension_new, - constraint_mapping_[dynamic_dimension]); + dynamic_mapping_.insert({dynamic_dimension_new, iter->second}); + constraint_mapping_.insert( + {dynamic_dimension_new, constraint_mapping_[dynamic_dimension]}); auto iter = per_hlo_dynamic_dimensions_.try_emplace(new_inst); iter.first->second.emplace(dynamic_dimension_new); } diff --git a/tensorflow/compiler/xla/service/dynamic_padder.cc b/tensorflow/compiler/xla/service/dynamic_padder.cc index 8d92b7e985a..c94a2594f3b 100644 --- a/tensorflow/compiler/xla/service/dynamic_padder.cc +++ b/tensorflow/compiler/xla/service/dynamic_padder.cc @@ -567,7 +567,55 @@ Status RewriteDynamicReshapeSingleDim( } return Status::OK(); } - +StatusOr RewriteDynamicConcat( + HloInstruction* concat, + DynamicDimensionInference* dynamic_dimension_inference) { + const int64 concat_dim = concat->concatenate_dimension(); + HloComputation* comp = concat->parent(); + if (dynamic_dimension_inference->GetDynamicSize(concat, {}, concat_dim) == + nullptr) { + // Concat dimension is not dynamic -- no rewrite needed. + return false; + } + std::vector offsets; + for (int64 i = 0; i < concat->shape().dimensions_size(); ++i) { + offsets.push_back(comp->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)))); + } + HloInstruction* rewritten_concat = concat; + // Keep track of previous users before rewrite so that we can update their + // operands later. 
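// Sketch of the arithmetic HandleConcatenate builds above, with plain integers
// standing in for the constant and kAdd HLO instructions the pass actually
// emits: the size of a dynamic concatenate dimension is the sum of the static
// extents of the static operands plus the runtime sizes of the dynamic ones.
#include <cstdint>
#include <vector>

struct ToyConcatOperand {
  int64_t static_extent;  // shape().dimensions(concatenate_dimension)
  int64_t dynamic_size;   // runtime size, or -1 if this operand's dim is static
};

int64_t ConcatDimensionSize(const std::vector<ToyConcatOperand>& operands) {
  int64_t total = 0;  // mirrors dim_size_total, seeded with static_size
  for (const ToyConcatOperand& op : operands) {
    total += (op.dynamic_size < 0) ? op.static_extent : op.dynamic_size;
  }
  return total;
}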
+ auto prev_users = concat->users(); + for (int64 i = 0; i < concat->operand_count(); ++i) { + // Rewrite the concat by dynamic update slicing operand into the concat dim. + HloInstruction* operand = concat->mutable_operand(i); + rewritten_concat = + comp->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + rewritten_concat->shape(), rewritten_concat, operand, offsets)); + // Update the offset of concat dimension by adding the size of the concat + // dimension of the operand to it. + HloInstruction* dynamic_size = + dynamic_dimension_inference->GetDynamicSize(operand, {}, concat_dim); + if (dynamic_size == nullptr) { + HloInstruction* static_size = comp->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0( + operand->shape().dimensions(concat_dim)))); + offsets[concat_dim] = comp->AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeScalarShape(S32), HloOpcode::kAdd, offsets[concat_dim], + static_size)); + } else { + offsets[concat_dim] = comp->AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeScalarShape(S32), HloOpcode::kAdd, offsets[concat_dim], + dynamic_size)); + } + } + for (HloInstruction* user : prev_users) { + TF_RETURN_IF_ERROR(concat->ReplaceUseWith(user, rewritten_concat)); + } + TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize( + concat, rewritten_concat, {})); + return true; +} StatusOr RewriteDynamicReshape( HloInstruction* reshape, DynamicDimensionInference* dynamic_dimension_inference) { @@ -709,6 +757,11 @@ StatusOr DynamicPadder::Run(HloModule* module) { for (HloComputation* computation : module->computations()) { for (HloInstruction* inst : computation->instructions()) { + if (inst->opcode() == HloOpcode::kConcatenate) { + TF_ASSIGN_OR_RETURN( + changed, RewriteDynamicConcat(inst, &dynamic_dimension_inference)); + continue; + } for (int64 operand_num = 0; operand_num < inst->operand_count(); ++operand_num) { HloInstruction* original_operand = inst->mutable_operand(operand_num); diff --git a/tensorflow/compiler/xla/service/dynamic_padder_test.cc b/tensorflow/compiler/xla/service/dynamic_padder_test.cc index 6c3f0bec493..0e60e420d47 100644 --- a/tensorflow/compiler/xla/service/dynamic_padder_test.cc +++ b/tensorflow/compiler/xla/service/dynamic_padder_test.cc @@ -496,6 +496,51 @@ ENTRY main { EXPECT_EQ(result, expected); } +XLA_TEST_F(ExecutionTest, DynamicConcat) { + // Concatting a list of {dynamic_operand, static_operand, dynamic_operand}. + const string hlo_text = R"( +HloModule DynamicConcat + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(lhs, rhs) +} + +ENTRY main { + param_0 = s32[3] parameter(0) + param_1 = s32[3] parameter(1) + param_2 = s32[3] parameter(2) + size = s32[] constant(2) + param_padded_0 = s32[3] set-dimension-size(param_0, size), dimensions={0} + param_padded_2 = s32[3] set-dimension-size(param_2, size), dimensions={0} + %concatenate = s32[9] + concatenate(s32[3] param_padded_0, s32[3] param_1, s32[3] param_padded_2), + dimensions={0} + init = s32[] constant(0) + ROOT reduce = s32[] reduce(concatenate, init), + dimensions={0}, + to_apply=update_s32 +} +)"; + + // Input has upper bound of 3, dynamic dimension is 2. Using -1 as padding. + Literal operand_0 = + LiteralUtil::CreateR1({1, 2, -1}); // Dynamic operand. + Literal operand_1 = + LiteralUtil::CreateR1({3, 4, 5}); // Static operand. + Literal operand_2 = + LiteralUtil::CreateR1({6, 7, -1}); // Dynamic operand. 
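// Sketch (plain vectors instead of HLO) of the rewrite RewriteDynamicConcat
// performs above: each operand is written into the result with a
// dynamic-update-slice at a running offset, and the offset advances by that
// operand's real (possibly dynamic) size, so the next operand overwrites the
// previous operand's padding.
#include <cstdint>
#include <vector>

struct ToyConcatInput {
  std::vector<int32_t> padded_values;  // operand data, padded to static extent
  int64_t size;                        // runtime size <= padded_values.size()
};

std::vector<int32_t> ConcatViaDynamicUpdateSlice(
    const std::vector<ToyConcatInput>& operands, int64_t result_extent) {
  std::vector<int32_t> result(result_extent, 0);
  int64_t offset = 0;  // mirrors offsets[concat_dim]
  for (const ToyConcatInput& op : operands) {
    // The dynamic-update-slice writes the whole padded operand at `offset`.
    for (int64_t i = 0;
         i < static_cast<int64_t>(op.padded_values.size()) &&
         offset + i < result_extent;
         ++i) {
      result[offset + i] = op.padded_values[i];
    }
    // The offset only advances by the operand's real size.
    offset += op.size;
  }
  return result;
}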
+ auto module = GetHloModule(hlo_text); + + Literal result = + PadAndExecute(std::move(module), {&operand_0, &operand_1, &operand_2}); + + Literal expected = LiteralUtil::CreateR0(28); + + EXPECT_EQ(result, expected); +} + XLA_TEST_F(ExecutionTest, DynamicDimensionReduce) { const string hlo_text = R"( HloModule TensorFlowScatterV1 diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index c21721c9339..9ece6172d12 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/dump.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/lib/core/status.h" @@ -43,9 +44,36 @@ StatusOr Executable::ExecuteOnStream( return result; } +static ShapeTree MakeMaybeOwningDeviceMemoryTree( + const ShapedBuffer& shaped_buffer) { + ShapeTree result(shaped_buffer.on_device_shape()); + auto in_it = shaped_buffer.buffers().begin(); + auto out_it = result.begin(); + for (; in_it != shaped_buffer.buffers().end(); ++in_it, ++out_it) { + DCHECK(out_it != result.end()); + out_it->second = MaybeOwningDeviceMemory(in_it->second); + } + return result; +} + +StatusOr Executable::ExecuteAsyncOnStream( + const ServiceExecutableRunOptions* run_options, + absl::Span arguments, + HloExecutionProfile* hlo_execution_profile) { + std::vector> args(arguments.size()); + auto out_it = args.begin(); + for (const ShapedBuffer* arg : arguments) { + *out_it++ = MakeMaybeOwningDeviceMemoryTree(*arg); + } + TF_ASSIGN_OR_RETURN(ExecutionOutput out, + ExecuteAsyncOnStream(run_options, std::move(args), + hlo_execution_profile)); + return out.ConsumeResult(); +} + StatusOr Executable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, - std::vector> arguments, + std::vector> arguments, HloExecutionProfile* hlo_execution_profile) { StatusOr result = ExecuteAsyncOnStream( run_options, std::move(arguments), hlo_execution_profile); @@ -55,14 +83,6 @@ StatusOr Executable::ExecuteOnStream( return result; } -StatusOr Executable::ExecuteAsyncOnStream( - const ServiceExecutableRunOptions* /*run_options*/, - std::vector> /*arguments*/, - HloExecutionProfile* /*hlo_execution_profile*/) { - return Unimplemented( - "MaybeOwningDeviceMemory version of overload is not implemented "); -} - StatusOr> Executable::ExecuteOnStreams( absl::Span run_options, absl::Span> arguments) { diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index 971dab95bfd..496599e7aaf 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -160,22 +160,22 @@ class Executable { // If the hlo_execution_profile is provided as non-nullptr, profiling will be // enabled. Note that profiling is tricky to use correctly, as the profiling // objects (when they exist) must out-live the task. - virtual StatusOr ExecuteAsyncOnStream( + StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, absl::Span arguments, - HloExecutionProfile* hlo_execution_profile) = 0; + HloExecutionProfile* hlo_execution_profile); // Same as ExecuteAsyncOnStream(), but blocks waiting for the computation to // complete. 
StatusOr ExecuteOnStream( const ServiceExecutableRunOptions* run_options, - std::vector> arguments, + std::vector> arguments, HloExecutionProfile* hlo_execution_profile); virtual StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - std::vector> arguments, - HloExecutionProfile* hlo_execution_profile); + std::vector> arguments, + HloExecutionProfile* hlo_execution_profile) = 0; // Same as ExecuteOnStream(), but runs this executable on multiple // streams. arguments[i] contains the arguments to the execution on diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index c1475d83197..9634401fe96 100755 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -11,7 +11,6 @@ load( ) load( "//tensorflow:tensorflow.bzl", - "if_nccl", "tf_cc_test", "tf_copts", "tf_cuda_library", @@ -25,6 +24,7 @@ load( "@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured", ) +load("//tensorflow:tensorflow.bzl", "if_nccl") package( default_visibility = [":friends"], @@ -422,7 +422,7 @@ tf_cuda_library( "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/stream_executor/cuda:cuda_activation", "//tensorflow/stream_executor/cuda:cuda_gpu_executor", - ] + if_cuda([ + ] + if_nccl([ "@local_config_nccl//:nccl", ]), ) diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 99bc0f7fee0..93af1cd995e 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -299,11 +299,14 @@ GpuExecutable::ResolveConstantGlobals(se::Stream* stream) { return &module_globals_.emplace(executor, std::move(globals)).first->second; } -StatusOr GpuExecutable::Execute( +StatusOr GpuExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - absl::Span arguments, - HloExecutionProfile* hlo_execution_profile, bool block_host_until_done) { - se::DeviceMemoryAllocator* memory_allocator = run_options->allocator(); + std::vector> arguments, + HloExecutionProfile* hlo_execution_profile) { + se::DeviceMemoryAllocator* const memory_allocator = run_options->allocator(); + // Force synchronous execution if the allocator requires it. + const bool block_host_until_done = + !memory_allocator->AllowsAsynchronousDeallocation(); if (GetRootValueSet().IsAmbiguous()) { return Unimplemented("Points-to set of root instruction is ambiguous"); @@ -334,7 +337,9 @@ StatusOr GpuExecutable::Execute( if (allocation.is_entry_computation_parameter()) { auto param_no = allocation.parameter_number(); se::DeviceMemoryBase buffer = - arguments[param_no]->buffer(allocation.param_shape_index()); + arguments[param_no] + .element(allocation.param_shape_index()) + .AsDeviceMemoryBase(); // All top-level buffers and sub-buffers must have an explicit, non-null // pointer, except for zero-sized buffers, which may be null. @@ -423,19 +428,17 @@ StatusOr GpuExecutable::Execute( })); TF_RETURN_IF_ERROR(buffer_allocations->TearDown(buffers_in_result)); - return std::move(shaped_buffer); -} - -StatusOr GpuExecutable::ExecuteAsyncOnStream( - const ServiceExecutableRunOptions* run_options, - absl::Span arguments, - HloExecutionProfile* hlo_execution_profile) { - se::DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - // Force synchronous execution if the allocator requires it. 
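// Sketch (toy classes, not the XLA Executable hierarchy) of the interface flip
// made in executable.h above: the ShapedBuffer-based entry point stops being
// virtual and becomes a base-class adapter that wraps every argument as an
// unowned buffer, while the maybe-owning overload becomes the single pure
// virtual hook each backend implements.
#include <utility>
#include <vector>

struct ToyResult {};
struct ToyBorrowedArg { void* ptr = nullptr; };
struct ToyMaybeOwningArg {
  void* ptr = nullptr;
  bool owns = false;  // false: borrowed from the caller, never freed by us
};

class ToyExecutable {
 public:
  virtual ~ToyExecutable() = default;

  // Legacy entry point: now non-virtual, implemented once in the base class.
  ToyResult Execute(const std::vector<ToyBorrowedArg>& args) {
    std::vector<ToyMaybeOwningArg> wrapped;
    wrapped.reserve(args.size());
    for (const ToyBorrowedArg& arg : args) {
      wrapped.push_back(ToyMaybeOwningArg{arg.ptr, /*owns=*/false});
    }
    return ExecuteMaybeOwning(std::move(wrapped));
  }

 protected:
  // The one hook backends override (pure virtual, like the ShapeTree-based
  // ExecuteAsyncOnStream in the real class).
  virtual ToyResult ExecuteMaybeOwning(std::vector<ToyMaybeOwningArg> args) = 0;
};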
- bool block_host_until_done = - !memory_allocator->AllowsAsynchronousDeallocation(); - return Execute(run_options, arguments, hlo_execution_profile, - block_host_until_done); + std::vector buffers_to_free; + for (ShapeTree& argument : arguments) { + for (std::pair& buffer : argument) { + auto maybe_owning_buffer = buffer.second.Release(); + if (maybe_owning_buffer) { + buffers_to_free.push_back(std::move(*maybe_owning_buffer)); + } + } + } + return ExecutionOutput(std::move(shaped_buffer), std::move(buffers_to_free), + {}, {}); } const InstructionValueSet& GpuExecutable::GetRootValueSet() const { diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index 66f86d768be..51e86a9f8ee 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -82,9 +82,9 @@ class GpuExecutable : public Executable { // ExecuteAsyncOnStream will fail if the compute capability of the stream // doesn't match the compute capability passed to this object's constructor. - StatusOr ExecuteAsyncOnStream( + StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - absl::Span arguments, + std::vector> arguments, HloExecutionProfile* hlo_execution_profile) override; std::shared_ptr GetBufferAssignment() const { @@ -92,11 +92,6 @@ class GpuExecutable : public Executable { } private: - StatusOr Execute( - const ServiceExecutableRunOptions* run_options, - absl::Span arguments, - HloExecutionProfile* hlo_execution_profile, bool block_host_until_done); - // If `block_host_until_done` is false, execution will not block the host // until the kernels have completed. This is used as an optimization for // clients, such as Tensorflow, that use a single stream of execution for diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc index 599eef4e600..24738683a19 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc @@ -154,11 +154,9 @@ bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1, // operand shape) and the reduction dimensions need to match. auto* instr_1 = get_real_hero(&instr1); auto* instr_2 = get_real_hero(&instr2); - // TODO(tjoerg): Relax the shape constraint. The datatype does not matter. if (IsReductionFromOrToContiguousDimensions(*instr_1) && IsReductionFromOrToContiguousDimensions(*instr_2) && - (!ShapeUtil::Equal(instr_1->shape(), instr_2->shape()) || - instr_1->dimensions() != instr_2->dimensions())) { + !AreFusedReductionOutputsConsistent({instr_1, instr_2}, instr_1)) { return false; } // The elementwise output shapes must be the same (including layout). diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 26a6deb8030..72f69ca2017 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -405,5 +405,33 @@ llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b) { EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b))); } +bool AreFusedReductionOutputsConsistent( + absl::Span output_instructions, + const HloInstruction* first_reduce) { + for (const HloInstruction* inst : output_instructions) { + if (IsReductionFromOrToContiguousDimensions(*inst)) { + // Shapes, layouts and dimensions must be the same for all reduces + // inside of this fusion. 
+ // TODO(tjoerg): Relax the shape constraint. The datatype does not matter. + if (!(ShapeUtil::Equal(first_reduce->shape(), inst->shape()) && + ShapeUtil::Equal(first_reduce->operand(0)->shape(), + inst->operand(0)->shape()) && + ShapeUtil::Equal(first_reduce->operand(1)->shape(), + inst->operand(1)->shape()) && + first_reduce->dimensions() == inst->dimensions())) { + return false; + } + } else { + if (!(ShapeUtil::CompatibleIgnoringElementType( + first_reduce->operand(0)->shape(), inst->shape()) && + LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(), + inst->shape().layout()))) { + return false; + } + } + } + return true; +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index f269cf87062..db3cd228841 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -200,6 +200,11 @@ llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, // block 0 of the kernel. llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b); +// Returns whether the outputs of a fusion with reduction are consistent. +bool AreFusedReductionOutputsConsistent( + absl::Span output_instructions, + const HloInstruction* first_reduce); + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 06a00d2178a..dbc2c95773a 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -2779,32 +2779,6 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { } namespace { -// Checks that the outputs of a fusion with reduction are consistent. -Status AreFusedReductionOutputsConsistent( - absl::Span output_instructions, - const HloInstruction* first_reduce) { - for (const HloInstruction* inst : output_instructions) { - if (IsReductionFromOrToContiguousDimensions(*inst)) { - // Shapes, layouts and dimensions must be the same for all reduces - // inside of this fusion. - TF_RET_CHECK(ShapeUtil::Equal(first_reduce->shape(), inst->shape())); - TF_RET_CHECK(ShapeUtil::Equal(first_reduce->operand(0)->shape(), - inst->operand(0)->shape())); - TF_RET_CHECK(ShapeUtil::Equal(first_reduce->operand(1)->shape(), - inst->operand(1)->shape())); - TF_RET_CHECK(first_reduce->dimensions() == inst->dimensions()); - } else { - // For extra outputs we can relax shape equality to allow different - // types (with the same number of elements). Layouts still have to - // match. - TF_RET_CHECK(ShapeUtil::CompatibleIgnoringElementType( - first_reduce->operand(0)->shape(), inst->shape())); - TF_RET_CHECK(LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(), - inst->shape().layout())); - } - } - return Status::OK(); -} // Returns true if all the transitive users of hlo before hitting users in // use_chain_endings are elementwise operations. 
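// Sketch, with toy structs in place of HloInstruction, of the rule the new
// AreFusedReductionOutputsConsistent helper above encodes: every reduce output
// of the fusion must match the first reduce exactly (output shape and layout,
// input shape and layout, reduced dimensions), while non-reduce outputs only
// have to cover the same iteration space and layout as the first reduce's
// input; their element type may differ.
#include <cstdint>
#include <vector>

struct ToyShape {
  std::vector<int64_t> dims;
  std::vector<int64_t> layout;  // minor-to-major
};

struct ToyFusionOutput {
  bool is_reduce = false;
  ToyShape shape;
  ToyShape input_shape;              // operand(0) shape, meaningful for reduces
  std::vector<int64_t> reduce_dims;  // reduced dimensions, for reduces
};

bool ToyOutputsConsistent(const std::vector<ToyFusionOutput>& outputs,
                          const ToyFusionOutput& first_reduce) {
  for (const ToyFusionOutput& out : outputs) {
    if (out.is_reduce) {
      if (out.shape.dims != first_reduce.shape.dims ||
          out.shape.layout != first_reduce.shape.layout ||
          out.input_shape.dims != first_reduce.input_shape.dims ||
          out.input_shape.layout != first_reduce.input_shape.layout ||
          out.reduce_dims != first_reduce.reduce_dims) {
        return false;
      }
    } else {
      // Element types may differ; dimensions and layout may not.
      if (out.shape.dims != first_reduce.input_shape.dims ||
          out.shape.layout != first_reduce.input_shape.layout) {
        return false;
      }
    }
  }
  return true;
}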
@@ -2994,8 +2968,10 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions( const HloInstruction* first_reduce = reduce_instructions.at(0); if (output_instructions.size() > 1) { - TF_RETURN_IF_ERROR( - AreFusedReductionOutputsConsistent(output_instructions, first_reduce)); + if (!AreFusedReductionOutputsConsistent(output_instructions, + first_reduce)) { + return InternalError("Inconsistent reduction fusion outputs"); + } } // Build a kernel thunk to compute all the outputs. diff --git a/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h b/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h index ae6f4e39560..345abbd0935 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h +++ b/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h @@ -120,10 +120,6 @@ class KernelMappingScheme { return dims_in_blocks_; } - int64 GetNumberOfTilesInTotal() const { - return absl::c_accumulate(dims_in_tiles_, 1LL, std::multiplies()); - } - int64 GetNumberOfTilesInOneBlock() const { return block_size_z_; } int64 BlockSizeZ() const { return block_size_z_; } diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc index ac80552d032..2fb1fc07056 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/types/optional.h" #include "absl/types/span.h" #include "third_party/nccl/nccl.h" +#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/refcounting_hash_map.h" #include "tensorflow/compiler/xla/service/collective_ops_utils.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" @@ -454,7 +455,8 @@ struct NcclAllReduceThunk::AuxData { return MatchReductionComputation(crs->to_apply()).has_value() && DatatypeToNccl(AllReducePrimitiveType(crs)).has_value() && crs->IsCrossReplicaAllReduce() && - crs->operand_count() == 1; // One array to reduce. + crs->operand_count() == 1 && // One array to reduce. + LayoutUtil::IsDenseArray(crs->operand(0)->shape()); } /*static*/ absl::flat_hash_set @@ -497,11 +499,8 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) { // Find or create the rendezvous for this collective operation. RendezvousKey rendezvous_key = RendezvousKey::FromInstruction( params.run_id, participating_replicas, hlo_instruction()); - std::shared_ptr rendezvous = - GlobalRendezvousMap()[rendezvous_key]; VLOG(2) << "Rendezvous key: " << rendezvous_key.ToString() - << ", rendezvous: " << rendezvous.get() << ", participating replicas: " << absl::StrJoin(participating_replicas, ", "); @@ -519,19 +518,10 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) { participant.reduction_kind = *reduction_kind; participant.primitive_type = AllReducePrimitiveType(hlo_instruction()); - // Do the operation. 
- StatusOr, - std::shared_ptr>> - result = rendezvous->SubmitParticipant(participant); - if (!result.ok()) { - VLOG(1) << "NcclAllReduceThunk::ExecuteOnStream failed: " - << result.status().ToString(); - return result.status(); - } - - std::shared_ptr clique; - std::shared_ptr blocking_counter; - std::tie(clique, blocking_counter) = std::move(result).ValueOrDie(); + TF_ASSIGN_OR_RETURN( + std::shared_ptr clique, + RendezvousNcclAllReduce::SubmitParticipant( + [&] { return GlobalRendezvousMap()[rendezvous_key]; }, participant)); // Keep the clique we used alive for as long as this Thunk lives. Creating // new NCCL cliques is expensive, and this is how we avoid thrashing them. @@ -539,24 +529,6 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) { tensorflow::mutex_lock lock(aux_data_->mu); aux_data_->cliques.insert(std::move(clique)); } - - // Drop our reference to the Rendezvous and wait for all other threads to do - // the same. If we didn't do this, one of the threads could run past this - // point, reenter ExecuteOnStream for another all-reduce, and attempt to reuse - // the Rendezvous! - // - // An alternative way of accomplishing this goal would be to implement - // RefcountingHashMap::erase() and call it during SubmitParticipant. But - // erase() is deceptively complex to implement correctly. - rendezvous.reset(); - blocking_counter->DecrementCount(); - WaitAndLogIfStuck(blocking_counter.get(), [&] { - return absl::StrFormat( - "participant for device ordinal %d, stream %p waiting for " - "all threads to drop their reference to the rendezvous: %s", - device_ordinal, params.stream, rendezvous_key.ToString()); - }); - return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc index 1c5b166a801..3e82e3271bb 100644 --- a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc @@ -151,7 +151,8 @@ absl::optional HloInputOutputAliasConfig::GetAliasedOutput( absl::optional HloInputOutputAliasConfig::GetAliasedParameter( const ShapeIndex& output_index) const { - CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index)); + CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index)) + << ToString() << " " << alias_.shape().ToString() << " " << output_index; return alias_.element(output_index); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 5e2e53ea6db..5855911650d 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -113,7 +113,7 @@ class HloPrintOptions { .set_print_subcomputation_mode(PrintSubcomputationMode::kFullBodies) .set_print_metadata(false) .set_print_backend_config(false) - .set_compact_operands(true) + .set_compact_operands(false) .set_print_operand_names(false) .set_print_operand_shape(true) .set_print_program_shape(false) diff --git a/tensorflow/compiler/xla/service/hlo_replication_analysis.cc b/tensorflow/compiler/xla/service/hlo_replication_analysis.cc index e11d3920f95..3a896d4a113 100644 --- a/tensorflow/compiler/xla/service/hlo_replication_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_replication_analysis.cc @@ -35,13 +35,45 @@ namespace { // knowledge in hlo_replication. 
bool DetermineHloInstructionIsReplicated( const HloInstruction* hlo, const ShapeIndex& index, + bool cross_partition_spmd, const absl::flat_hash_map>& hlo_replication) { + // Returns true if all operands are known to be replicated. + const auto all_operands_replicated = + [&hlo_replication](const HloInstruction* inst) { + for (auto operand : inst->operands()) { + auto operand_it = hlo_replication.find(operand); + if (operand_it == hlo_replication.end() || + !operand_it->second.element({})) { + return false; + } + } + return true; + }; + + if (hlo->IsCrossReplicaAllReduce()) { + if (cross_partition_spmd) { + // Cross-replica all-reduce returns same values across partitions as long + // as its operands are replicated. + return all_operands_replicated(hlo); + } + // Only all-reduce across all cores are replicated, which means there + // is only one subgroup. + return hlo->replica_groups().empty() || hlo->replica_groups().size() == 1; + } + if (hlo->IsCrossModuleAllReduce()) { + return cross_partition_spmd; + } if (hlo->HasSideEffectNoRecurse()) { return false; } if (hlo->opcode() == HloOpcode::kReplicaId) { - return false; + // ReplicaId returns the same value for all partitions in each replica. + return cross_partition_spmd; + } + if (hlo->opcode() == HloOpcode::kPartitionId) { + // PartitionId returns the same value for all replicas in each partition. + return !cross_partition_spmd; } auto it = hlo_replication.find(hlo); if (hlo->opcode() == HloOpcode::kParameter) { @@ -55,11 +87,6 @@ bool DetermineHloInstructionIsReplicated( if (hlo->opcode() == HloOpcode::kConstant) { return true; } - if (hlo->opcode() == HloOpcode::kAllReduce) { - // Only all-reduce across all cores are replicated, which means there - // is only one subgroup. - return hlo->replica_groups().empty() || hlo->replica_groups().size() == 1; - } if (hlo->IsElementwise() || // hlo->opcode() == HloOpcode::kConcatenate || // @@ -80,14 +107,7 @@ bool DetermineHloInstructionIsReplicated( hlo->opcode() == HloOpcode::kDynamicUpdateSlice || // hlo->opcode() == HloOpcode::kReduceWindow || // hlo->opcode() == HloOpcode::kCopy) { - for (auto operand : hlo->operands()) { - auto operand_it = hlo_replication.find(operand); - if (operand_it == hlo_replication.end() || - !operand_it->second.element({})) { - return false; - } - } - return true; + return all_operands_replicated(hlo); } return false; } @@ -235,8 +255,8 @@ bool HloReplicationAnalysis::ComputeHloReplicationOnComputation( ShapeUtil::ForEachSubshape( inst->shape(), [&](const Shape& subshape, const ShapeIndex& index) { *shape_tree.mutable_element(index) = - DetermineHloInstructionIsReplicated(inst, index, - hlo_replication_); + DetermineHloInstructionIsReplicated( + inst, index, cross_partition_spmd_, hlo_replication_); return Status::OK(); }); changed |= assign_or_combine_shapetree(std::move(shape_tree), inst); @@ -248,23 +268,39 @@ bool HloReplicationAnalysis::ComputeHloReplicationOnComputation( void HloReplicationAnalysis::ComputeHloReplication() { // Add entry parameters to the above sets according to user annotation. + // Replicated modules read from `parameter_replicated_at_leaf_buffers` whereas + // SPMD partitioned modules read from HloSharding attributes. 
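// Sketch of the decision table DetermineHloInstructionIsReplicated now
// implements above, on a toy instruction type: the answer depends on whether
// the analysis runs across partitions (SPMD) or across replicas, and
// elementwise-like ops simply propagate replication from their operands via
// the all_operands_replicated helper.
#include <functional>
#include <vector>

enum class ToyOpcode {
  kCrossReplicaAllReduce,
  kCrossModuleAllReduce,
  kReplicaId,
  kPartitionId,
  kElementwise,
};

struct ToyHlo {
  ToyOpcode opcode;
  std::vector<const ToyHlo*> operands;
  bool single_replica_group = false;  // all-reduce over all cores
};

bool ToyIsReplicated(
    const ToyHlo& hlo, bool cross_partition_spmd,
    const std::function<bool(const ToyHlo&)>& operand_is_replicated) {
  auto all_operands_replicated = [&] {
    for (const ToyHlo* operand : hlo.operands) {
      if (!operand_is_replicated(*operand)) return false;
    }
    return true;
  };
  switch (hlo.opcode) {
    case ToyOpcode::kCrossReplicaAllReduce:
      // Across partitions the result is identical as long as the inputs are;
      // across replicas it must be a single all-cores group.
      return cross_partition_spmd ? all_operands_replicated()
                                  : hlo.single_replica_group;
    case ToyOpcode::kCrossModuleAllReduce:
      return cross_partition_spmd;
    case ToyOpcode::kReplicaId:
      return cross_partition_spmd;   // same value in every partition
    case ToyOpcode::kPartitionId:
      return !cross_partition_spmd;  // same value in every replica
    case ToyOpcode::kElementwise:
      return all_operands_replicated();
  }
  return false;
}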
auto entry = module_->entry_computation(); for (int i = 0; i < entry->num_parameters(); ++i) { auto param = entry->parameter_instruction(i); ShapeTree shape_tree(param->shape(), false); - const auto& replication = param->parameter_replicated_at_leaf_buffers(); - int leaf_index = 0; - ShapeUtil::ForEachSubshape( - param->shape(), [&](const Shape& subshape, const ShapeIndex& index) { - if (!ShapeUtil::IsLeafIndex(param->shape(), index)) { + if (cross_partition_spmd_ && param->has_sharding()) { + auto sharding_tree = + param->sharding().AsShapeTree(param->shape()).ValueOrDie(); + ShapeUtil::ForEachSubshape( + param->shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (!ShapeUtil::IsLeafIndex(param->shape(), index)) { + return Status::OK(); + } + *shape_tree.mutable_element(index) = + sharding_tree.element(index).IsReplicated(); return Status::OK(); - } - if (replication && replication->at(leaf_index)) { - *shape_tree.mutable_element(index) = true; - } - ++leaf_index; - return Status::OK(); - }); + }); + } else if (!cross_partition_spmd_) { + const auto& replication = param->parameter_replicated_at_leaf_buffers(); + int leaf_index = 0; + ShapeUtil::ForEachSubshape( + param->shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (!ShapeUtil::IsLeafIndex(param->shape(), index)) { + return Status::OK(); + } + if (replication && replication->at(leaf_index)) { + *shape_tree.mutable_element(index) = true; + } + ++leaf_index; + return Status::OK(); + }); + } hlo_replication_[param] = std::move(shape_tree); } ComputeHloReplicationOnComputation(entry, @@ -281,17 +317,18 @@ bool HloReplicationAnalysis::HloInstructionIsReplicatedAt( } /* static */ StatusOr> -HloReplicationAnalysis::Run(const HloModule* module) { +HloReplicationAnalysis::Run(const HloModule* module, + bool cross_partition_spmd) { const absl::flat_hash_set empty; - return Run(module, &empty); + return Run(module, cross_partition_spmd, &empty); } /* static */ StatusOr> -HloReplicationAnalysis::Run(const HloModule* module, +HloReplicationAnalysis::Run(const HloModule* module, bool cross_partition_spmd, const absl::flat_hash_set* loops_known_with_same_iterations) { - auto analysis = absl::WrapUnique( - new HloReplicationAnalysis(module, loops_known_with_same_iterations)); + auto analysis = absl::WrapUnique(new HloReplicationAnalysis( + module, cross_partition_spmd, loops_known_with_same_iterations)); analysis->ComputeHloReplication(); return analysis; } diff --git a/tensorflow/compiler/xla/service/hlo_replication_analysis.h b/tensorflow/compiler/xla/service/hlo_replication_analysis.h index 3175fc35102..18b2363e454 100644 --- a/tensorflow/compiler/xla/service/hlo_replication_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_replication_analysis.h @@ -25,32 +25,35 @@ limitations under the License. namespace xla { // An HLO pass that determines whether each instruction in the module outputs -// the same value across replicas. It propagates sources of replicated values to +// the same value across replicas or across partitions (depending on the value +// `cross_partition_spmd`). It propagates sources of replicated values to // the rest of the module, where sources include cross-replica-sum, annotated // entry parameters, and constants. class HloReplicationAnalysis { public: // Runs the analysis on module and returns the result or an error. 
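// Sketch of how ComputeHloReplication above seeds entry parameters, using a
// toy parameter type: under cross-partition SPMD the per-leaf "replicated"
// bits come from the parameter's sharding annotation, otherwise from the
// parameter_replicated_at_leaf_buffers annotation; with neither, every leaf
// stays non-replicated.
#include <vector>

struct ToyParameter {
  int num_leaves = 0;
  bool has_sharding = false;
  std::vector<bool> sharding_replicated_per_leaf;  // from HloSharding
  std::vector<bool> replicated_at_leaf_buffers;    // user annotation
};

std::vector<bool> SeedParameterReplication(const ToyParameter& param,
                                           bool cross_partition_spmd) {
  if (cross_partition_spmd && param.has_sharding) {
    return param.sharding_replicated_per_leaf;
  }
  if (!cross_partition_spmd && !param.replicated_at_leaf_buffers.empty()) {
    return param.replicated_at_leaf_buffers;
  }
  return std::vector<bool>(param.num_leaves, false);
}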
static StatusOr> Run( - const HloModule* module); + const HloModule* module, bool cross_partition_spmd); // Same as above, but the caller can provide additional annotations: a set of // while loops that are known to have the same iteration counts across - // replicas. + // replicas or partitions. static StatusOr> Run( - const HloModule* module, const absl::flat_hash_set* - loops_known_with_same_iterations); + const HloModule* module, bool cross_partition_spmd, + const absl::flat_hash_set* + loops_known_with_same_iterations); // Returns if the HLO instruction outputs the same value (i.e., replicated) at - // the given index across all replicas. + // the given index across all replicas or partitions. bool HloInstructionIsReplicatedAt(const HloInstruction* inst, const ShapeIndex& index) const; private: - HloReplicationAnalysis(const HloModule* module, + HloReplicationAnalysis(const HloModule* module, bool cross_partition_spmd, const absl::flat_hash_set* loops_known_with_same_iterations) : module_(module), + cross_partition_spmd_(cross_partition_spmd), loops_known_with_same_iterations_(*loops_known_with_same_iterations) {} // Computes hlo_replication_. @@ -63,14 +66,25 @@ class HloReplicationAnalysis { const HloModule* module_; + // If true, run this replication analysis for replicated values across + // partitions (not across replicas) on an SPMD partitioned module. This means + // that HloInstructionIsReplicatedAt() returns true if the value is identical + // across partitions for each replica. The module-level parameter and root + // instructions may have HloSharding attributes that indicate whether values + // are identical across partitions. + // + // If false, HloReplicationAnalysis runs across replicas. + bool cross_partition_spmd_; + // A set of while loops that are known to have the same iteration counts - // across replicas. This is provided by the caller as additional annotations. + // across replicas or partitions. This is provided by the caller as additional + // annotations. const absl::flat_hash_set& loops_known_with_same_iterations_; // A map from each analyzed HLO instruction to a shape tree that represents - // whether the instruction outputs the same value across replicas at each - // shape index. + // whether the instruction outputs the same value across replicas or + // partitions at each shape index. 
absl::flat_hash_map> hlo_replication_; }; diff --git a/tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc index 958e99dedb8..56cc8542ac4 100644 --- a/tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc @@ -42,16 +42,30 @@ sum { ROOT add.2 = f32[] add(a, b) } +sum.u32 { + a = u32[] parameter(0) + b = u32[] parameter(1) + ROOT add.2 = u32[] add(a, b) +} + ENTRY entry { param = (f32[4096,4096]{1,0}, f32[4096,4096]{1,0}) parameter(0) get-tuple-element.2 = f32[4096,4096]{1,0} get-tuple-element(param), index=0 get-tuple-element.3 = f32[4096,4096]{1,0} get-tuple-element(param), index=1 after-all.1 = token[] after-all() + replica-id = u32[] replica-id() + partition-id = u32[] partition-id() infeed = (f32[4096,4096]{1,0}, token[]) infeed(after-all.1) get-tuple-element.5 = f32[4096,4096]{1,0} get-tuple-element(infeed), index=0 - dot = f32[4096,4096]{1,0} dot(get-tuple-element.5, get-tuple-element.3), lhs_contracting_dims={1}, rhs_contracting_dims={0} - all-reduce = f32[4096,4096]{1,0} all-reduce(dot), replica_groups={}, to_apply=sum + dot = f32[4096,4096]{1,0} dot(get-tuple-element.5, get-tuple-element.3), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + all-reduce = f32[4096,4096]{1,0} all-reduce(dot), replica_groups={}, + to_apply=sum subtract = f32[4096,4096]{1,0} subtract(get-tuple-element.3, all-reduce) + all-reduce-partitions = u32[] all-reduce(partition-id), channel_id=1, + to_apply=sum.u32 + all-reduce-subgroup = u32[] all-reduce(partition-id), + replica_groups={{0,1},{2,3}}, to_apply=sum.u32 ROOT add = f32[4096,4096]{1,0} add(get-tuple-element.2, subtract) } )"; @@ -62,7 +76,8 @@ ENTRY entry { param->set_parameter_replicated_at_leaf_buffers( absl::Span{false, true}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis, - HloReplicationAnalysis::Run(module.get())); + HloReplicationAnalysis::Run( + module.get(), /*cross_partition_spmd=*/false)); EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( FindInstruction(module.get(), "get-tuple-element.2"), {})); EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( @@ -77,6 +92,92 @@ ENTRY entry { FindInstruction(module.get(), "subtract"), {})); EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( FindInstruction(module.get(), "add"), {})); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "replica-id"), {})); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "partition-id"), {})); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "all-reduce-partitions"), {})); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "all-reduce-subgroup"), {})); +} + +TEST_F(HloReplicationAnalysisTest, NoControlFlowSPMD) { + const string module_str = R"( +HloModule NoControlFlow + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add.2 = f32[] add(a, b) +} + +sum.u32 { + a = u32[] parameter(0) + b = u32[] parameter(1) + ROOT add.2 = u32[] add(a, b) +} + +ENTRY entry { + param = (f32[4096,4096]{1,0}, f32[4096,4096]{1,0}) parameter(0), + sharding={{maximal device=0}, {replicated}} + get-tuple-element.2 = f32[4096,4096]{1,0} get-tuple-element(param), index=0 + get-tuple-element.3 = f32[4096,4096]{1,0} get-tuple-element(param), index=1 + after-all.1 = token[] after-all() + replica-id = u32[] replica-id() + partition-id = u32[] partition-id() + infeed = 
(f32[4096,4096]{1,0}, token[]) infeed(after-all.1) + get-tuple-element.5 = f32[4096,4096]{1,0} get-tuple-element(infeed), index=0 + dot = f32[4096,4096]{1,0} dot(get-tuple-element.5, get-tuple-element.3), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + all-reduce = f32[4096,4096]{1,0} all-reduce(dot), replica_groups={}, + to_apply=sum + all-reduce-subgroup = f32[4096,4096]{1,0} all-reduce(dot), + replica_groups={{0,1},{2,3}}, to_apply=sum + all-reduce-partitions = f32[4096,4096]{1,0} all-reduce(get-tuple-element.2), + channel_id=1, to_apply=sum + subtract = f32[4096,4096]{1,0} subtract(get-tuple-element.3, + all-reduce-partitions) + all-reduce-same-operand = u32[] all-reduce(replica-id), to_apply=sum.u32 + all-reduce-same-operand-subgroup = u32[] all-reduce(replica-id), + replica_groups={{0,1},{2,3}}, to_apply=sum.u32 + all-reduce-different-operand = u32[] all-reduce(partition-id), + to_apply=sum.u32 + ROOT add = f32[4096,4096]{1,0} add(get-tuple-element.2, subtract) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr analysis, + HloReplicationAnalysis::Run(module.get(), /*cross_partition_spmd=*/true)); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "get-tuple-element.2"), {})); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "get-tuple-element.3"), {})); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "get-tuple-element.5"), {})); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "dot"), {})); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "all-reduce"), {})); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "subtract"), {})); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "add"), {})); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "replica-id"), {})); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "partition-id"), {})); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "all-reduce-partitions"), {})); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "all-reduce-same-operand"), {})); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "all-reduce-same-operand-subgroup"), {})); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "all-reduce-different-operand"), {})); } TEST_F(HloReplicationAnalysisTest, NestedCall) { @@ -111,7 +212,8 @@ ENTRY entry { param->set_parameter_replicated_at_leaf_buffers( absl::Span{true, false}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis, - HloReplicationAnalysis::Run(module.get())); + HloReplicationAnalysis::Run( + module.get(), /*cross_partition_spmd=*/false)); EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( FindInstruction(module.get(), "get-tuple-element"), {})); EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( @@ -163,7 +265,8 @@ ENTRY SimpleWhileLoop { param->set_parameter_replicated_at_leaf_buffers( absl::Span{true, true}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis, - HloReplicationAnalysis::Run(module.get())); + HloReplicationAnalysis::Run( + module.get(), /*cross_partition_spmd=*/false)); EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( 
FindInstruction(module.get(), "tuple"), {0})); EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( @@ -212,7 +315,8 @@ ENTRY WhileLoopParameterAliasingNonReplicatedOutput { param->set_parameter_replicated_at_leaf_buffers( absl::Span{true, true}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis, - HloReplicationAnalysis::Run(module.get())); + HloReplicationAnalysis::Run( + module.get(), /*cross_partition_spmd=*/false)); EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( FindInstruction(module.get(), "multiply"), {})); EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( @@ -258,7 +362,8 @@ ENTRY WhileLoopDifferentCondition { param->set_parameter_replicated_at_leaf_buffers( absl::Span{true, true}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis, - HloReplicationAnalysis::Run(module.get())); + HloReplicationAnalysis::Run( + module.get(), /*cross_partition_spmd=*/false)); EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( FindInstruction(module.get(), "while"), {0})); EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( @@ -307,7 +412,8 @@ ENTRY entry { param->set_parameter_replicated_at_leaf_buffers( absl::Span{true, true, true, true, false, true, true}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis, - HloReplicationAnalysis::Run(module.get())); + HloReplicationAnalysis::Run( + module.get(), /*cross_partition_spmd=*/false)); EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( FindInstruction(module.get(), "tuple"), {0})); EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( @@ -371,7 +477,8 @@ ENTRY entry { param->set_parameter_replicated_at_leaf_buffers( absl::Span{true, true, true, true, true, true}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis, - HloReplicationAnalysis::Run(module.get())); + HloReplicationAnalysis::Run( + module.get(), /*cross_partition_spmd=*/false)); EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( FindInstruction(module.get(), "tuple"), {0})); EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( @@ -409,7 +516,8 @@ ENTRY entry { param->set_parameter_replicated_at_leaf_buffers( absl::Span{true, false, true, true, true}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis, - HloReplicationAnalysis::Run(module.get())); + HloReplicationAnalysis::Run( + module.get(), /*cross_partition_spmd=*/false)); EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( FindInstruction(module.get(), "tuple-select"), {0})); EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( @@ -435,7 +543,8 @@ ENTRY entry { param->set_parameter_replicated_at_leaf_buffers( absl::Span{true, true, true, true, false}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis, - HloReplicationAnalysis::Run(module.get())); + HloReplicationAnalysis::Run( + module.get(), /*cross_partition_spmd=*/false)); EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( FindInstruction(module.get(), "tuple-select"), {0})); EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index 3073c68c975..552c8eb1ae5 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -89,10 +89,14 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_evaluator", "//tensorflow/compiler/xla/service:hlo_execution_profile", "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:maybe_owning_device_memory", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/core:lib", 
"//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:macros", + "//tensorflow/core/platform:mutex", + "//tensorflow/core/platform:types", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", ], diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 0dab86d986c..f82a439fdb0 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/interpreter/executor.h" +#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -39,24 +40,39 @@ namespace interpreter { InterpreterExecutable::InterpreterExecutable( std::unique_ptr hlo_module, std::unique_ptr evaluator) - : Executable(std::move(hlo_module), /*hlo_profile_printer=*/nullptr, + : Executable(std::move(hlo_module), /*hlo_profile_printer_data=*/nullptr, /*hlo_profile_index_map=*/nullptr), evaluator_(std::move(evaluator)) {} InterpreterExecutable::~InterpreterExecutable() {} -StatusOr InterpreterExecutable::ExecuteAsyncOnStream( +StatusOr InterpreterExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - absl::Span arguments, + std::vector> arguments, HloExecutionProfile* hlo_execution_profile) { se::Stream* stream = run_options->stream(); se::StreamExecutor* executor = stream->parent(); const se::Platform* platform = executor->platform(); + // Convert the ShapeTree to a ShapedBuffer. We do this so we can call + // TransferManager methods below. + std::vector argument_buffers; + argument_buffers.reserve(arguments.size()); + for (const ShapeTree& arg : arguments) { + argument_buffers.push_back(ShapedBuffer(arg.shape(), arg.shape(), + /*platform=*/nullptr, + /*device_ordinal=*/0)); + auto in_it = arg.begin(); + auto out_it = argument_buffers.back().buffers().begin(); + for (; in_it != arg.end(); ++in_it, ++out_it) { + out_it->second = in_it->second.AsDeviceMemoryBase(); + } + } + VLOG(1) << "Execute " << module().name(); if (VLOG_IS_ON(2)) { - for (const auto& a : arguments) { - VLOG(2) << "-- argument " << *a; + for (const auto& a : argument_buffers) { + VLOG(2) << "-- argument " << a; } } @@ -71,7 +87,7 @@ StatusOr InterpreterExecutable::ExecuteAsyncOnStream( // Check that the args have the right shape. 
for (int64 i = 0; i < computation->num_parameters(); ++i) { const auto& expected_shape = computation->parameter_instruction(i)->shape(); - const auto& actual_shape = arguments[i]->on_device_shape(); + const auto& actual_shape = argument_buffers[i].on_device_shape(); if (!Shape::Equal().MinorToMajorOnlyInLayout()(expected_shape, actual_shape)) { return InvalidArgument( @@ -90,7 +106,7 @@ StatusOr<ScopedShapedBuffer> InterpreterExecutable::ExecuteAsyncOnStream( for (int64 p = 0; p < computation->num_parameters(); ++p) { TF_ASSIGN_OR_RETURN(Literal arg_literal, transfer_manager->TransferLiteralFromDevice( - run_options->stream(), *arguments[p])); + run_options->stream(), argument_buffers[p])); arg_literals.push_back(std::move(arg_literal)); } @@ -119,7 +135,16 @@ StatusOr<ScopedShapedBuffer> InterpreterExecutable::ExecuteAsyncOnStream( profile->set_compute_time_ns(std::max(nanoseconds, 1.0)); } - return std::move(result); + std::vector<se::OwningDeviceMemory> buffers_to_free; + for (ShapeTree<MaybeOwningDeviceMemory>& argument : arguments) { + for (std::pair<ShapeIndex, MaybeOwningDeviceMemory>& buffer : argument) { + auto maybe_owning_buffer = buffer.second.Release(); + if (maybe_owning_buffer) { + buffers_to_free.push_back(std::move(*maybe_owning_buffer)); + } + } + } + return ExecutionOutput(std::move(result), std::move(buffers_to_free), {}, {}); } /*static*/ int64 InterpreterExecutable::ShapeSizeBytes(const Shape& shape) { diff --git a/tensorflow/compiler/xla/service/interpreter/executable.h b/tensorflow/compiler/xla/service/interpreter/executable.h index ba010de76bd..1bea6773fdd 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.h +++ b/tensorflow/compiler/xla/service/interpreter/executable.h @@ -46,9 +46,9 @@ class InterpreterExecutable : public Executable { std::unique_ptr<HloEvaluator> evaluator); ~InterpreterExecutable() override; - StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream( + StatusOr<ExecutionOutput> ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - absl::Span<const ShapedBuffer* const> arguments, + std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments, HloExecutionProfile* hlo_execution_profile) override LOCKS_EXCLUDED(evaluator_lock_); diff --git a/tensorflow/compiler/xla/service/maybe_owning_device_memory.cc b/tensorflow/compiler/xla/service/maybe_owning_device_memory.cc index 5fe5fea71ac..c4bf48bcc00 100644 --- a/tensorflow/compiler/xla/service/maybe_owning_device_memory.cc +++ b/tensorflow/compiler/xla/service/maybe_owning_device_memory.cc @@ -17,7 +17,8 @@ limitations under the License. #include "absl/types/variant.h" namespace xla { -tensorflow::se::DeviceMemoryBase MaybeOwningDeviceMemory::AsDeviceMemoryBase() { +tensorflow::se::DeviceMemoryBase MaybeOwningDeviceMemory::AsDeviceMemoryBase() + const { if (HasOwnership()) { return *absl::get<tensorflow::se::OwningDeviceMemory>(mem_); } else { diff --git a/tensorflow/compiler/xla/service/maybe_owning_device_memory.h b/tensorflow/compiler/xla/service/maybe_owning_device_memory.h index 8edd64cf681..7d23d178130 100644 --- a/tensorflow/compiler/xla/service/maybe_owning_device_memory.h +++ b/tensorflow/compiler/xla/service/maybe_owning_device_memory.h @@ -49,7 +49,7 @@ class MaybeOwningDeviceMemory { // Fetches the underlying DeviceMemoryBase from a MaybeOwningDeviceMemory. The // caller of this function is *not* responsible for freeing the memory. - tensorflow::se::DeviceMemoryBase AsDeviceMemoryBase(); + tensorflow::se::DeviceMemoryBase AsDeviceMemoryBase() const; // Release the tensorflow::se::OwningDeviceMemory without freeing it, and // moves the ownership of the memory buffer from the object to the caller.
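The executable.cc hunk above defines the ownership hand-off behind the new ExecuteAsyncOnStream contract: each argument leaf is a MaybeOwningDeviceMemory that either borrows the caller's device memory or owns it, and owned leaves are handed back through the ExecutionOutput so the runtime can recycle them. A minimal sketch of that pattern, reusing only names that appear in the hunk (an illustration, not an addition to the patch):

// Each argument leaf either borrows the caller's device memory or owns it.
std::vector<se::OwningDeviceMemory> buffers_to_free;
for (ShapeTree<MaybeOwningDeviceMemory>& argument : arguments) {
  for (std::pair<ShapeIndex, MaybeOwningDeviceMemory>& buffer : argument) {
    // Release() yields a value only for owned leaves; borrowed leaves stay
    // with the caller and must not be freed by the executable.
    auto maybe_owning_buffer = buffer.second.Release();
    if (maybe_owning_buffer) {
      buffers_to_free.push_back(std::move(*maybe_owning_buffer));
    }
  }
}
// The result and the donated input buffers travel back together.
return ExecutionOutput(std::move(result), std::move(buffers_to_free), {}, {});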
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc index 751d258142a..c1dc635eb81 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc @@ -244,6 +244,32 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { auto colocated_intervals = GetSortedColocatedIntervals(interval); + if (AreIntervalsReservedInAlternateMemory(colocated_intervals)) { + VLOG(4) << "Interval " << interval.buffer->ToShortString() + << " is reserved in the alternate memory. Total reserved bytes = " + << reserved_in_bytes_; + for (const BufferInterval* colocated_interval : colocated_intervals) { + const HloValue* value = colocated_interval->buffer; + // Color all of the aliased reserved buffers here because reserved + // alternate memory allocations will not have an entry in preset + // allocations that is normally used for coloring. + for (auto& position : value->positions()) { + VLOG(3) << "Coloring " << position.ToString(); + Shape* shape = ShapeUtil::GetMutableSubshape( + position.instruction->mutable_shape(), position.index); + CHECK(shape->IsArray()) << "Coloring a shape that is not an array: " + << position.ToString(); + shape->mutable_layout()->set_memory_space( + options_.alternate_memory_space); + } + } + // Increment the reserved part of alternate memory so that it is not + // available for other buffers. Since all colocated intervals should have + // the same size, just use the first one. + reserved_in_bytes_ += options_.size_fn(*colocated_intervals[0]->buffer); + continue; + } + if (colocated_intervals.size() > 1 && !options_.allocate_across_sequential_calls) { VLOG(4) << "Not allocating " << interval.buffer->ToShortString() @@ -366,10 +392,8 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { } void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() { - // Go through the parameters and outputs and pin them to default memory by - // adding a required assignment. - // TODO(berkin): If these values are already marked alternate memory, use - // those instead. + // Go through the parameters and outputs and pin them to the corresponding + // memory by adding a required assignment. const HloModule& module = alias_analysis_.dataflow_analysis().module(); const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); HloComputation* entry_computation = module.entry_computation(); @@ -379,16 +403,22 @@ void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() { instruction_schedule.at(parameter_instruction); ShapeUtil::ForEachSubshape( parameter_instruction->shape(), - [&](const Shape& /*subshape*/, const ShapeIndex& index) { + [&](const Shape& subshape, const ShapeIndex& index) { + MemorySpace memory_space = MemorySpace::kDefault; + if (subshape.has_layout() && subshape.layout().memory_space() == + options_.alternate_memory_space) { + memory_space = MemorySpace::kAlternate; + } for (const HloBuffer* buffer : alias_analysis_.ComputeBuffersAt(parameter_instruction, index)) { for (const HloValue* value : buffer->values()) { VLOG(3) << "Adding required assignment for parameter value = " << value->ToShortString() - << " time = " << parameter_instruction_time; + << " time = " << parameter_instruction_time << " space = " + << (memory_space == MemorySpace::kDefault ? 
"def" + : "alt"); required_assignments_[value].push_back( - {/*memory_space=*/MemorySpace::kDefault, - /*time=*/parameter_instruction_time}); + {memory_space, /*time=*/parameter_instruction_time}); } } }); @@ -397,21 +427,56 @@ void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() { int64 root_instruction_time = instruction_schedule.at(root_instruction); ShapeUtil::ForEachSubshape( root_instruction->shape(), - [&](const Shape& /*subshape*/, const ShapeIndex& index) { + [&](const Shape& subshape, const ShapeIndex& index) { + MemorySpace memory_space = MemorySpace::kDefault; + if (subshape.has_layout() && subshape.layout().memory_space() == + options_.alternate_memory_space) { + memory_space = MemorySpace::kAlternate; + } for (const HloBuffer* buffer : alias_analysis_.ComputeBuffersAt(root_instruction, index)) { for (const HloValue* value : buffer->values()) { VLOG(3) << "Adding required assignment for output value = " << value->ToShortString() - << " time = " << root_instruction_time; + << " time = " << root_instruction_time << " space = " + << (memory_space == MemorySpace::kDefault ? "def" : "alt"); required_assignments_[value].push_back( - {/*memory_space=*/MemorySpace::kDefault, - /*time=*/root_instruction_time}); + {memory_space, /*time=*/root_instruction_time}); } } }); } +bool AlternateMemoryBestFitHeap::AreIntervalsReservedInAlternateMemory( + absl::Span colocated_intervals) const { + auto is_position_in_alternate_memory = [&](const HloPosition& position) { + const Shape& shape = position.shape(); + return shape.has_layout() && + shape.layout().memory_space() == options_.alternate_memory_space; + }; + + const HloModule& module = alias_analysis_.dataflow_analysis().module(); + const HloComputation* entry_computation = module.entry_computation(); + const HloInstruction* root_instruction = + entry_computation->root_instruction(); + for (const BufferInterval* colocated_interval : colocated_intervals) { + const HloValue* value = colocated_interval->buffer; + if (value->defining_instruction()->opcode() == HloOpcode::kParameter && + value->defining_instruction()->parent() == entry_computation && + is_position_in_alternate_memory(value->defining_position())) { + return true; + } + + for (const HloPosition& position : value->positions()) { + if (position.instruction == root_instruction && + is_position_in_alternate_memory(position)) { + return true; + } + } + } + return false; +} + void AlternateMemoryBestFitHeap::CommitPendingChunks() { for (auto interval_and_chunk : pending_chunks_) { VLOG(3) << "Committing chunk: " << interval_and_chunk.first.start << "-" @@ -482,8 +547,11 @@ bool AlternateMemoryBestFitHeap::FindAllocation( if (required_assignment_it != required_assignments_.end()) { for (const RequiredMemoryAssignment& required_assignment : required_assignment_it->second) { - VLOG(3) << "Required assignment at time = " << required_assignment.time; - // TODO(berkin): Handle memory requirements for alternate memory space. + VLOG(3) << "Required assignment at time = " << required_assignment.time + << " space = " + << (required_assignment.memory_space == MemorySpace::kDefault + ? "def" + : "alt"); if (required_assignment.memory_space == MemorySpace::kDefault) { if (required_assignment.time == start_time) { definition_requires_buffer_in_default_mem = true; @@ -613,7 +681,7 @@ bool AlternateMemoryBestFitHeap::FindAllocation( } ChunkCandidate chunk_candidate = FindChunkCandidate(alternate_mem_interval); // Check if the new heap size fits within limits. 
- if (chunk_candidate.heap_size < options_.max_size_in_bytes) { + if (chunk_candidate.heap_size < available_heap_size()) { VLOG(3) << "Move the buffer to alternate memory at " << alternate_mem_interval.start << ". Offset = " << chunk_candidate.chunk.offset @@ -748,7 +816,7 @@ bool AlternateMemoryBestFitHeap::TryAllocatingInAlternateMemoryNoCopy( alternate_mem_interval.end = end_time; // Check if the new heap size fits within limits. Also ensure if a // preferred offset was provided, that offset was used. - if (chunk_candidate.heap_size <= options_.max_size_in_bytes && + if (chunk_candidate.heap_size <= available_heap_size() && (preferred_offset == -1 || preferred_offset == chunk_candidate.chunk.offset)) { VLOG(3) << "Keep the buffer in alternate memory. Offset = " @@ -1130,7 +1198,15 @@ Status MemorySpaceAssignment::SimplifyGraph() { // Ensure the exported preset assignments don't contain a refence to // the removed instruction. preset_assignments_->RemoveAssignmentForInstruction(instruction); - flattened_instruction_sequence_.remove_instruction(instruction); + // Instead of deleting the instruction from the schedule, replace it + // with a nullptr. This is needed because FixSchedule relies on the + // logical time that is the index into flattened_instructions_ for + // scheduling asynchronous copies. + auto instruction_it = + absl::c_find(flattened_instructions_, instruction); + if (instruction_it != flattened_instructions_.end()) { + *instruction_it = nullptr; + } TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction)); computation_modified = true; } else if (instruction->opcode() == HloOpcode::kGetTupleElement) { @@ -1228,12 +1304,12 @@ void MemorySpaceAssignment::ScheduleAsynchronousCopies() { // If the copy start doesn't happen to be scheduled at the correct // computation, delay it until the correct computation starts. - const auto& flattened_instructions = - flattened_instruction_sequence_.instructions(); int64 copy_start_schedule_after = copy_allocation->copy_start_schedule_after(); + // Accessing flattened_instructions_ here without checking if it is + // nullptr is safe because this method is called before SimplifyGraph. while (copy_allocation->instruction()->parent() != - flattened_instructions[copy_start_schedule_after]->parent()) { + flattened_instructions_[copy_start_schedule_after]->parent()) { VLOG(4) << "Delaying CopyStart (" << copy_start_schedule_after << " to " << (copy_start_schedule_after + 1) << ") for " << copy_allocation->copy_start()->ToString() @@ -1264,8 +1340,7 @@ Status MemorySpaceAssignment::FixSchedule() { VLOG(4) << "Scheduling: " << computation->ToString(); for (int64 instruction_index = 0; - instruction_index < - flattened_instruction_sequence_.instructions().size(); + instruction_index < flattened_instructions_.size(); ++instruction_index) { auto insts_before_iter = schedule_before_.find(instruction_index); if (insts_before_iter != schedule_before_.end()) { @@ -1276,10 +1351,11 @@ Status MemorySpaceAssignment::FixSchedule() { } } } - HloInstruction* instruction = - flattened_instruction_sequence_.instructions()[instruction_index]; - // Insert only if not previously inserted. - if (!inserted_instructions.contains(instruction) && + HloInstruction* instruction = flattened_instructions_[instruction_index]; + // Insert only if it is not deleted (SimplifyGraph sets it to nullptr if + // it was deleted) and not previously inserted. 
+ if (instruction != nullptr && + !inserted_instructions.contains(instruction) && instruction->parent() == computation) { EnsureInstructionAndOperandsInserted(instruction, &new_sequence, &inserted_instructions); diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h index bfc91664bea..20551feb715 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.h +++ b/tensorflow/compiler/xla/service/memory_space_assignment.h @@ -450,7 +450,8 @@ class MemorySpaceAssignment { absl::Span flattened_instructions) : module_(module), alternate_memory_space_(alternate_memory_space), - flattened_instruction_sequence_(flattened_instructions), + flattened_instructions_(flattened_instructions.begin(), + flattened_instructions.end()), preset_assignments_(absl::make_unique()) {} // Process calls Process methods of the allocations after the allocations have @@ -479,7 +480,7 @@ class MemorySpaceAssignment { HloModule* module_; int64 alternate_memory_space_; - HloInstructionSequence flattened_instruction_sequence_; + std::vector flattened_instructions_; AllocationMap allocation_map_; std::unique_ptr preset_assignments_; @@ -546,6 +547,12 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { // Adds input and outputs as required assignments. void AddInputAndOutputRequiredAssignments(); + // Returns true if the colocated intervals in the argument are in a parameter + // or root instruction of the entry computation and are reserved by the user + // to be in the alternate memory space. + bool AreIntervalsReservedInAlternateMemory( + absl::Span colocated_intervals) const; + // Given a buffer interval, returns the colocated intervals. Unlike the // similar GlobalDecreasingSizeBestFitHeap::GetTransitiveColocations, it // returns the colocated intervals sorted by scheduled time. @@ -574,6 +581,11 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { const ChunkCandidate& chunk_candidate); void CommitPendingChunks(); + // Returns the available heap size in the alternate memory. + int64 available_heap_size() const { + return options_.max_size_in_bytes - reserved_in_bytes_; + } + MemorySpaceAssignment::AllocationMap* allocation_map_; const MemorySpaceAssignment::Options& options_; const HloAliasAnalysis& alias_analysis_; @@ -587,6 +599,8 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { // and outputs). absl::flat_hash_map> required_assignments_; + // Number of bytes reserved in alternate memory space. + int64 reserved_in_bytes_ = 0; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc index 6041b96636e..068834e5701 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc @@ -2032,6 +2032,167 @@ TEST_P(MemorySpaceAssignmentTest, SimpleWhileTupleTest) { EXPECT_THAT(while0, op::ShapeWithLayout(t_s32_f32v1_in_default_mem)); } +TEST_P(MemorySpaceAssignmentTest, EvictionsShouldntBeDelayed) { + // This test reproduces an eviction scheduling bug where evictions to default + // memory can happen later than intended, causing memory corruption. This test + // is a variant of MemoryBoundednessBufferIntervalCompare but uses f32[4,3] + // tensors instead, so at most two tensors should fit in the alternate memory + // space at a given time. 
We have a number of redundant operations + // (tanh_redundant ops) that do not have users. The bug was due to + // SimplifyGraph removing dead instructions, and removing them from the + // schedule. However, the CopyStart/CopyDone insertion relies on the schedule + // indexes, so they could be inserted too late. + HloComputation::Builder builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4, 3}); + HloInstruction* p0 = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); + HloInstruction* tanh0 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0)); + HloInstruction* tanh_redundant0 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0)); + HloInstruction* tanh_redundant1 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0)); + HloInstruction* tanh_redundant2 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0)); + HloInstruction* tanh_redundant3 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0)); + HloInstruction* tanh_redundant4 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0)); + HloInstruction* tanh_redundant5 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0)); + HloInstruction* tanh_redundant6 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0)); + HloInstruction* negate0 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, tanh0)); + HloInstruction* tanh1 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kTanh, negate0)); + HloInstruction* negate1 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0)); + HloInstruction* tanh2 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kTanh, tanh1)); + HloInstruction* negate2 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1)); + HloInstruction* tanh3 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kTanh, tanh2)); + HloInstruction* negate3 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2)); + HloInstruction* tuple = builder.AddInstruction( + HloInstruction::CreateTuple({tanh3, negate3, tanh0})); + + auto module = CreateNewVerifiedModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + HloSchedule schedule(module.get()); + schedule.set_sequence( + computation, + {p0, tanh0, tanh_redundant0, tanh_redundant1, tanh_redundant2, + tanh_redundant3, tanh_redundant4, tanh_redundant5, tanh_redundant6, + negate0, tanh1, negate1, tanh2, negate2, tanh3, negate3, tuple}); + TF_CHECK_OK(module->set_schedule(schedule)); + + AssignMemorySpaceUsingCostAnalysis(module.get()); + + TF_ASSERT_OK_AND_ASSIGN(auto alias_analysis, + HloAliasAnalysis::Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_live_range, + HloLiveRange::Run(module->schedule(), *alias_analysis, + module->entry_computation())); + + std::vector num_live_buffers_in_alternate_mem( + hlo_live_range->flattened_instruction_sequence().size() + 1, 0); + + // Go through each value and for those that are allocated in the alternate + // memory space, increment (inclusive) num_live_buffers_in_alternate_mem for + // every time step that they are live. 
+ for (const HloValue* value : alias_analysis->dataflow_analysis().values()) { + const Shape& shape = value->shape(); + if (!shape.has_layout() || + shape.layout().memory_space() == kDefaultMemorySpace) { + continue; + } + + HloLiveRange::TimeBound time_bound = + hlo_live_range->buffer_live_ranges().at(value); + for (int i = time_bound.start; i <= time_bound.end; ++i) { + ++num_live_buffers_in_alternate_mem[i]; + } + } + + // The test memory can at most hold two f32[4,3] buffers at a time. If there + // is more than that, it means we have memory corruption. + for (int i = 0; i < num_live_buffers_in_alternate_mem.size(); ++i) { + EXPECT_LE(num_live_buffers_in_alternate_mem[i], 2); + } +} + +TEST_P(MemorySpaceAssignmentTest, + InputOutputsInAlternateMemShouldntBeAssigned) { + // When input/outputs are marked to be in the alternate memory (e.g. + // go/tpu-fast-mem-inference), do not allocate those and assume they will live + // in the alternate memory for the entire computation. The BufferAssignment + // pass, which is run after this, will allocate those buffers. + HloComputation::Builder builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); + Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout( + F32, {2, 3}, + /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0, + kAlternateMemorySpace); + // p0 is in the default memory space. + HloInstruction* p0 = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); + // p1 is in the alternate memory space. + HloInstruction* p1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, shape_in_alternate_mem, "p1")); + HloInstruction* negate0 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0)); + HloInstruction* negate1 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0)); + HloInstruction* negate2 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1)); + HloInstruction* negate3 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2)); + HloInstruction* negate4 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3)); + HloInstruction* negate5 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4)); + HloInstruction* negate6 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5)); + HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary( + shape_in_alternate_mem, HloOpcode::kAdd, negate6, p1)); + // Index {0} of the root instruction is in the alternate memory space, index + // {1} is in the default memory space. + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({add, negate5})); + + auto module = CreateNewVerifiedModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + HloSchedule schedule(module.get()); + schedule.set_sequence(computation, + {p0, p1, negate0, negate1, negate2, negate3, negate4, + negate5, negate6, add, tuple}); + TF_CHECK_OK(module->set_schedule(schedule)); + + std::unique_ptr preset_assignments = + AssignMemorySpace(module.get()); + + // Ensure that p1 is in the alternate memory and add, which has p1 as an + // operand, has a direct dependency to p1 (no CopyStart/CopyDone). 
+ EXPECT_THAT(p1, op::ShapeWithLayout(shape_in_alternate_mem)); + EXPECT_THAT(add, op::Add(op::Negate(), op::Parameter(1))); + // Make sure add is still in the alternate memory space. + EXPECT_THAT(add, op::ShapeWithLayout(shape_in_alternate_mem)); + + // Check the preset assignments and ensure the inputs/outputs in the alternate + // memory space aren't in the preset assignments. Inputs/outputs in the + // alternate memory space are left to BufferAssignment to be allocated. + for (const auto& position_and_chunk : preset_assignments->chunks()) { + const HloPosition& position = position_and_chunk.first; + EXPECT_NE(position.instruction, p1); + EXPECT_NE(position.instruction, add); + } +} + INSTANTIATE_TEST_SUITE_P(MemorySpaceAssignmentInstantiation, MemorySpaceAssignmentTest, ::testing::Values(false, true)); diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc index ea887926338..7cbbb3ec44e 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc @@ -46,7 +46,7 @@ Status LowerLHLOToGPU(mlir::ModuleOp module) { // Transform element-wise operations to LinAlg. pm.addPass(::mlir::xla_lhlo::createLegalizeToLinalgPass()); // Go from affine to normal loops. - pm.addPass(::mlir::linalg::createLowerLinalgToLoopsPass()); + pm.addPass(::mlir::linalg::createConvertLinalgToLoopsPass()); // Lower affine to ordinary loops. pm.addPass(::mlir::createLowerAffinePass()); // Move constants out of the loop. diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc index ddfa28a9b42..b035a8ddcb5 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc @@ -344,15 +344,26 @@ Status InsertBufferLoadPreduleIntoKernel( loc, entry_type, builder.getI64IntegerAttr(extent.value())); builder.create(loc, extentValue, shapeEntryPtr); } - // Finally, fill the strides with all ones. + // Finally, fill the strides. + // TODO(b/137624192): Take assigned layout into account. entry_type = struct_type.getStructElementType(4).getArrayElementType(); - for (int64 idx = 0; idx < shape.rank(); ++idx) { + Value* accumulator = nullptr; + for (int64 idx = shape.rank() - 1; idx >= 0; --idx) { auto indexValue = builder.create( loc, offset_type, builder.getI64IntegerAttr(idx)); auto strideEntryPtr = builder.create( loc, entry_type, descPtr, llvm::ArrayRef{zero, strideIndex, indexValue}); - builder.create(loc, one, strideEntryPtr); + if (accumulator) { + auto strideValue = builder.create( + loc, entry_type, + builder.getI64IntegerAttr(shape.dimensions(idx + 1))); + accumulator = builder.create( + loc, entry_type, accumulator, strideValue); + } else { + accumulator = one; + } + builder.create(loc, accumulator, strideEntryPtr); } // Now we can use the descriptor instead of the original argument. 
value->replaceAllUsesWith(descPtr); diff --git a/tensorflow/compiler/xla/tests/collective_ops_test.cc b/tensorflow/compiler/xla/tests/collective_ops_test.cc index a2de65502b4..8de508e876e 100644 --- a/tensorflow/compiler/xla/tests/collective_ops_test.cc +++ b/tensorflow/compiler/xla/tests/collective_ops_test.cc @@ -41,7 +41,7 @@ using ::testing::UnorderedElementsAre; class CollectiveOpsTest : public HloTestBase { protected: std::unique_ptr MakeCrsModule( - int64 num_elems, std::vector> replica_groups, + const Shape& shape, std::vector> replica_groups, const HloModuleConfig& config, std::string op = "add", std::string datatype = "f32") { std::string hlo_template = R"( @@ -54,11 +54,11 @@ class CollectiveOpsTest : public HloTestBase { } ENTRY test_computation { - p = DATATYPE[NUM_ELEMS] parameter(0) - p2 = DATATYPE[NUM_ELEMS] bitcast(p) - crs = DATATYPE[NUM_ELEMS] all-reduce(p2), replica_groups=REPLICA_GROUPS, to_apply=apply_op - copy = DATATYPE[NUM_ELEMS] copy(crs) - ROOT out = DATATYPE[NUM_ELEMS] bitcast(copy) + p = SHAPE parameter(0) + p2 = SHAPE bitcast(p) + crs = SHAPE all-reduce(p2), replica_groups=REPLICA_GROUPS, to_apply=apply_op + copy = SHAPE copy(crs) + ROOT out = SHAPE bitcast(copy) } )"; std::vector replica_group_strs; @@ -66,71 +66,70 @@ class CollectiveOpsTest : public HloTestBase { replica_group_strs.push_back( absl::StrFormat("{%s}", absl::StrJoin(g, ","))); } - if (num_elems == 1) { + std::string shape_str = shape.ToString(/*print_layout=*/false); + if (shape_str == "f32[1]") { // Exercise the scalar codepath. hlo_template = absl::StrReplaceAll( hlo_template, - {{"DATATYPE[NUM_ELEMS] bitcast(p)", "DATATYPE[] bitcast(p)"}, - {"DATATYPE[NUM_ELEMS] all-reduce", "DATATYPE[] all-reduce"}, - {"DATATYPE[NUM_ELEMS] copy", "DATATYPE[] copy"}}); + {{"DATATYPE[SHAPE] bitcast(p)", "DATATYPE[] bitcast(p)"}, + {"DATATYPE[SHAPE] all-reduce", "DATATYPE[] all-reduce"}, + {"DATATYPE[SHAPE] copy", "DATATYPE[] copy"}}); } - return ParseAndReturnVerifiedModule( - absl::StrReplaceAll( - hlo_template, - {{"NUM_ELEMS", absl::StrCat(num_elems)}, - {"REPLICA_GROUPS", - absl::StrFormat("{%s}", - absl::StrJoin(replica_group_strs, ", "))}, - {"OP", op}, - {"DATATYPE", datatype}}), - config) - .ValueOrDie(); + std::string parameterized_hlo = absl::StrReplaceAll( + hlo_template, + {{"SHAPE", shape_str}, + {"REPLICA_GROUPS", + absl::StrFormat("{%s}", absl::StrJoin(replica_group_strs, ", "))}, + {"OP", op}, + {"DATATYPE", datatype}}); + return ParseAndReturnVerifiedModule(parameterized_hlo, config).ValueOrDie(); } template - void TestTwoReplicasOneOperand(std::string op, - std::vector input_value, - std::vector expected_value) { + void TestTwoReplicasOneOperand(std::string op, Literal input_value, + Literal expected_value) { const int kNumReplicas = 2; std::string dtype = primitive_util::LowercasePrimitiveTypeName( primitive_util::NativeToPrimitiveType()); auto config = GetModuleConfigForTest(); config.set_replica_count(kNumReplicas); - auto module = MakeCrsModule(/*num_elems=*/input_value.size(), - /*replica_groups=*/{}, config, - /*op=*/op, /*datatype=*/dtype); - auto literal = LiteralUtil::CreateR1(input_value); - auto expected = LiteralUtil::CreateR1(expected_value); + auto module = MakeCrsModule( + /*shape_str=*/input_value.shape(), + /*replica_groups=*/{}, config, + /*op=*/op, /*datatype=*/dtype); TF_ASSERT_OK_AND_ASSIGN(std::vector results, - ExecuteReplicated(std::move(module), {&literal}, + ExecuteReplicated(std::move(module), {&input_value}, /*num_replicas=*/kNumReplicas, /*use_threads=*/true)); 
for (int replica_idx = 0; replica_idx < kNumReplicas; replica_idx++) { - EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected, results[replica_idx], - ErrorSpec{1e-5, 1e-5})); + EXPECT_TRUE(LiteralTestUtil::NearOrEqual( + expected_value, results[replica_idx], ErrorSpec{1e-5, 1e-5})); } } template void TestAllOps() { auto cast = [&](int value) { return static_cast(value); }; - std::vector input_value = {cast(1), cast(2), cast(3)}; + auto to_literal = [&](absl::Span values) { + return LiteralUtil::CreateR1(values); + }; + Literal input_value = to_literal({cast(1), cast(2), cast(3)}); TestTwoReplicasOneOperand( "add", - /*input_value=*/input_value, - /*expected_value=*/{cast(2), cast(4), cast(6)}); + /*input_value=*/input_value.Clone(), + /*expected_value=*/to_literal({cast(2), cast(4), cast(6)})); TestTwoReplicasOneOperand( "multiply", - /*input_value=*/input_value, - /*expected_value=*/{cast(1), cast(4), cast(9)}); + /*input_value=*/input_value.Clone(), + /*expected_value=*/to_literal({cast(1), cast(4), cast(9)})); TestTwoReplicasOneOperand( "maximum", - /*input_value=*/input_value, - /*expected_value=*/{cast(1), cast(2), cast(3)}); + /*input_value=*/input_value.Clone(), + /*expected_value=*/to_literal({cast(1), cast(2), cast(3)})); TestTwoReplicasOneOperand( "minimum", - /*input_value=*/input_value, - /*expected_value=*/{cast(1), cast(2), cast(3)}); + /*input_value=*/input_value.Clone(), + /*expected_value=*/to_literal({cast(1), cast(2), cast(3)})); } }; @@ -169,10 +168,18 @@ static Eigen::half ToHalf(T value) { return static_cast(value); } +XLA_TEST_F(CollectiveOpsTest, AllReduce_sum_float32_2D) { + TestTwoReplicasOneOperand( + "add", + /*input_value=*/LiteralUtil::CreateR2({{1, 2}, {3, 4}}), + /*expected_value=*/LiteralUtil::CreateR2({{2, 4}, {6, 8}})); +} + XLA_TEST_F(CollectiveOpsTest, AllReduceSingleOutput_float32) { - TestTwoReplicasOneOperand("add", - /*input_value=*/{1}, - /*expected_value=*/{2}); + TestTwoReplicasOneOperand( + "add", + /*input_value=*/LiteralUtil::CreateR1({1}), + /*expected_value=*/LiteralUtil::CreateR1({2})); } XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int8) { @@ -227,12 +234,13 @@ XLA_TEST_F(CollectiveOpsTest, AllReduce_AllCombinations) { config.set_replica_count(devices.size()); config.set_static_device_assignment(device_assn); - auto module = MakeCrsModule(kNumElems, /*replica_groups=*/{}, config); - std::vector input_vec(kNumElems); absl::c_iota(input_vec, 0); auto input_literal = LiteralUtil::CreateR1(input_vec); + auto module = MakeCrsModule(input_literal.shape(), + /*replica_groups=*/{}, config); + TF_ASSERT_OK_AND_ASSIGN( std::vector results, ExecuteReplicated(std::move(module), {&input_literal}, @@ -270,7 +278,8 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllReduce_NcclChannelCaching)) { auto config = GetModuleConfigForTest(); config.set_replica_count(devices.size()); config.set_static_device_assignment(e.device_assn); - auto module = MakeCrsModule(kNumElems, /*replica_groups=*/{}, config); + auto module = MakeCrsModule(input_literal.shape(), + /*replica_groups=*/{}, config); e.executable = test_runner_ .CreateExecutable(std::move(module), /*run_hlo_passes=*/true) @@ -325,20 +334,21 @@ XLA_TEST_F(CollectiveOpsTest, AllReduce_ManyConcurrentAllReduces) { const int64 kNumThreads = 200; const int64 kRunsPerThread = 10; + std::vector input_vec(kNumElems); + absl::c_iota(input_vec, 0); + auto input_literal = LiteralUtil::CreateR1(input_vec); + auto config = GetModuleConfigForTest(); config.set_replica_count(2); auto executable = 
test_runner_ - .CreateExecutable( - MakeCrsModule(kNumElems, /*replica_groups=*/{}, config), - /*run_hlo_passes=*/true) + .CreateExecutable(MakeCrsModule(input_literal.shape(), + /*replica_groups=*/{}, config), + /*run_hlo_passes=*/true) .ValueOrDie(); std::vector devices = {0, 1}; auto device_assn = MakeDeviceAssn(devices); - std::vector input_vec(kNumElems); - absl::c_iota(input_vec, 0); - auto input_literal = LiteralUtil::CreateR1(input_vec); HloRunner::ReplicatedExecuteOptions opts; opts.num_replicas = devices.size(); opts.use_threads = true; @@ -368,11 +378,12 @@ XLA_TEST_F(CollectiveOpsTest, AllReduce_ThreeReplicaGroups) { auto config = GetModuleConfigForTest(); config.set_replica_count(4); - auto module = MakeCrsModule(/*num_elems=*/kNumElems, - /*replica_groups=*/{{0}, {1, 2}, {3}}, config); std::vector input_vec(kNumElems); absl::c_iota(input_vec, 0); auto input_literal = LiteralUtil::CreateR1(input_vec); + auto module = MakeCrsModule( + /*shape_str=*/input_literal.shape(), + /*replica_groups=*/{{0}, {1, 2}, {3}}, config); TF_ASSERT_OK_AND_ASSIGN( std::vector results, @@ -397,7 +408,8 @@ XLA_TEST_F(CollectiveOpsTest, ReplicaId) { const char* const kModuleStr = R"( HloModule test ENTRY test_computation { - ROOT id = u32[] replica-id() + id = u32[] replica-id() + ROOT out = u32[] copy(id) } )"; const int64 kNumReplicas = 4; diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 09d2d76cabf..da20d28ea81 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -95,6 +95,14 @@ tf_cc_binary( ], ) +tf_cc_binary( + name = "replay_computation_mlir_gpu", + deps = [ + ":replay_computation_library", + "//tensorflow/compiler/xla/service:mlir_gpu_plugin", + ], +) + tf_cc_binary( name = "replay_computation_interpreter", deps = [ diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 74859e11a79..588420eb1b6 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -237,6 +237,7 @@ tf_proto_library( make_default_target_header_only = True, protodeps = [ ":error_codes_proto_impl", + ":test_log_proto_impl", ":core_protos", "//tensorflow/core/framework:protos_all", "//tensorflow/core/lib/core:error_codes_proto", @@ -396,6 +397,7 @@ filegroup( "//tensorflow/core/platform:env.h", "//tensorflow/core/platform:file_statistics.h", "//tensorflow/core/platform:file_system.h", + "//tensorflow/core/platform:path.h", ], visibility = ["//visibility:private"], ) @@ -501,12 +503,6 @@ cc_library( cc_library( name = "lib", hdrs = [ - "lib/monitoring/collected_metrics.h", - "lib/monitoring/collection_registry.h", - "lib/monitoring/counter.h", - "lib/monitoring/gauge.h", - "lib/monitoring/metric_def.h", - "lib/monitoring/sampler.h", ":platform_base_hdrs", ":platform_env_hdrs", ":platform_file_system_hdrs", @@ -520,6 +516,7 @@ cc_library( "//tensorflow/core/lib/histogram:legacy_lib_histogram_all_headers", "//tensorflow/core/lib/io:legacy_lib_io_headers", "//tensorflow/core/lib/math:math_util.h", + "//tensorflow/core/lib/monitoring:legacy_lib_monitoring_lib_headers", "//tensorflow/core/lib/random:legacy_lib_random_headers", "//tensorflow/core/lib/strings:legacy_lib_string_headers", ], @@ -571,6 +568,24 @@ cc_library( ], ) +# TODO(gunan): Move this to core/util/BUILD once the file is created +cc_library( + name = "util_reporter", + srcs = ["util/reporter.cc"], + hdrs = ["util/reporter.h"], + # This should only be used in tensorflow/core/platform:test_benchmark + visibility = ["//tensorflow/core/platform:__subpackages__"], + deps 
= [ + ":test_log_proto_impl_cc", + "//tensorflow/core/platform:env", + "//tensorflow/core/platform:errors", + "//tensorflow/core/platform:macros", + "//tensorflow/core/platform:mutex", + "//tensorflow/core/platform:str_util", + "//tensorflow/core/platform:types", + ], +) + # Test support library needed for all tests # This is currently public, but may be made internal in the # future. Try to avoid depending on it. @@ -578,7 +593,6 @@ cc_library( name = "test", testonly = 1, srcs = [ - "util/reporter.cc", "//tensorflow/core/platform:legacy_test_srcs", ], hdrs = [ @@ -1548,6 +1562,8 @@ filegroup( "//tensorflow/core/lib/io:legacy_lib_io_all_headers", "//tensorflow/core/lib/io:legacy_lib_io_all_srcs", "//tensorflow/core/lib/math:math_util.h", + "//tensorflow/core/lib/monitoring:legacy_lib_monitoring_all_headers", + "//tensorflow/core/lib/monitoring:legacy_lib_monitoring_all_srcs", "//tensorflow/core/lib/random:legacy_lib_random_all_headers", "//tensorflow/core/lib/random:legacy_lib_random_all_srcs", "//tensorflow/core/lib/strings:legacy_lib_strings_all_headers", @@ -1979,32 +1995,6 @@ tf_pyclif_proto_library( ] ] -# The following targets were moved to core/framework. The aliases are only temporary -# since moving existing users will require several CLs over several projects. - -[ - alias( - name = "framework_%s_pyclif%s" % (proto_name, target_suffix), - actual = "//tensorflow/core/framework:%s_pyclif%s" % (proto_name, target_suffix), - visibility = ["//visibility:public"], - ) - for target_suffix in [ - "", - "_pb2", - ] - for proto_name in [ - "cost_graph", - "tensor", - "kernel_def", - "node_def", - "function", - "graph", - "step_stats", - "types", - "variable", - ] -] - # ----------------------------------------------------------------------------- # Internal targets @@ -2095,6 +2085,7 @@ LIB_INTERNAL_PRIVATE_HEADERS = [ "//tensorflow/core/lib/gtl:legacy_lib_gtl_all_headers", "//tensorflow/core/lib/histogram:legacy_lib_histogram_all_headers", "//tensorflow/core/lib/hash:legacy_lib_hash_all_headers", + "//tensorflow/core/lib/monitoring:legacy_lib_monitoring_all_headers", "//tensorflow/core/lib/io:legacy_lib_io_all_headers", "//tensorflow/core/lib/random:legacy_lib_random_all_headers", "//tensorflow/core/lib/strings:legacy_lib_strings_all_headers", @@ -2116,12 +2107,11 @@ LIB_INTERNAL_PUBLIC_HEADERS = [ "//tensorflow/core/lib/gtl:legacy_lib_internal_public_gtl_headers", "//tensorflow/core/lib/hash:legacy_lib_internal_public_headers", "//tensorflow/core/lib/io:legacy_lib_internal_public_headers", - "lib/monitoring/mobile_counter.h", - "lib/monitoring/mobile_gauge.h", - "lib/monitoring/mobile_sampler.h", + "//tensorflow/core/lib/monitoring:legacy_lib_monitoring_lib_internal_public_headers", "//tensorflow/core/lib/random:legacy_lib_internal_public_random_headers", "//tensorflow/core/lib/strings:legacy_lib_internal_public_string_headers", "lib/wav/wav_io.h", + "//tensorflow/core/platform:blocking_counter.h", "//tensorflow/core/platform:demangle.h", "//tensorflow/core/platform:denormal.h", "//tensorflow/core/platform:host_info.h", @@ -2185,6 +2175,7 @@ cc_library( deps = tf_additional_lib_deps() + [ ":core_stringpiece", ":lib_proto_parsing", + ":util_reporter", # TODO(gunan): REMOVE as soon as cc_shared_library is supported. 
"@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "//third_party/eigen3", @@ -2240,6 +2231,15 @@ cc_library( "//tensorflow/core/lib/io:zlib_inputstream", "//tensorflow/core/lib/io:zlib_outputbuffer", "//tensorflow/core/lib/math:math_util", + "//tensorflow/core/lib/monitoring:collected_metrics", + "//tensorflow/core/lib/monitoring:collection_registry", + "//tensorflow/core/lib/monitoring:counter", + "//tensorflow/core/lib/monitoring:gauge", + "//tensorflow/core/lib/monitoring:metric_def", + "//tensorflow/core/lib/monitoring:mobile_counter", + "//tensorflow/core/lib/monitoring:mobile_gauge", + "//tensorflow/core/lib/monitoring:mobile_sampler", + "//tensorflow/core/lib/monitoring:sampler", "//tensorflow/core/lib/random:exact_uniform_int", "//tensorflow/core/lib/random:philox", "//tensorflow/core/lib/random:philox_random", @@ -2255,6 +2255,8 @@ cc_library( "//tensorflow/core/lib/strings:strcat", "//tensorflow/core/lib/strings:stringprintf", "//tensorflow/core/platform:abi", + "//tensorflow/core/platform:blocking_counter", + "//tensorflow/core/platform:coding", "//tensorflow/core/platform:context", "//tensorflow/core/platform:cord", "//tensorflow/core/platform:cpu_feature_guard", @@ -2265,6 +2267,7 @@ cc_library( "//tensorflow/core/platform:errors", "//tensorflow/core/platform:file_statistics", "//tensorflow/core/platform:fingerprint", + "//tensorflow/core/platform:hash", "//tensorflow/core/platform:load_library", "//tensorflow/core/platform:logger", "//tensorflow/core/platform:mutex", @@ -2272,6 +2275,7 @@ cc_library( "//tensorflow/core/platform:net", "//tensorflow/core/platform:null_file_system", "//tensorflow/core/platform:numbers", + "//tensorflow/core/platform:path", "//tensorflow/core/platform:platform_port", "//tensorflow/core/platform:platform_strings", "//tensorflow/core/platform:prefetch", @@ -2466,6 +2470,15 @@ tf_proto_library( make_default_target_header_only = True, ) +tf_proto_library( + name = "test_log_proto_impl", + srcs = ["util/test_log.proto"], + cc_api_version = 2, + make_default_target_header_only = True, + # Not to be used outside this file. 
+ visibility = ["//visibility:private"], +) + tf_proto_library( name = "core_protos", srcs = COMMON_PROTO_SRCS + [ @@ -2489,12 +2502,12 @@ tf_proto_library( "protobuf/tensorflow_server.proto", "protobuf/trackable_object_graph.proto", "protobuf/transport_options.proto", - "util/test_log.proto", ], cc_api_version = 2, make_default_target_header_only = True, protodeps = [ ":error_codes_proto_impl", + ":test_log_proto_impl", "//tensorflow/core/framework:protos_all", "//tensorflow/core/lib/core:error_codes_proto", "//tensorflow/core/profiler/protobuf:xplane_proto", @@ -3455,11 +3468,6 @@ tf_cc_tests( name = "low_level_library_tests", size = "small", srcs = [ - "lib/monitoring/collection_registry_test.cc", - "lib/monitoring/counter_test.cc", - "lib/monitoring/gauge_test.cc", - "lib/monitoring/metric_def_test.cc", - "lib/monitoring/sampler_test.cc", "lib/wav/wav_io_test.cc", "//tensorflow/core/lib/core:legacy_lib_core_all_tests", "//tensorflow/core/lib/gtl:legacy_lib_gtl_tests", @@ -3467,6 +3475,11 @@ tf_cc_tests( "//tensorflow/core/lib/histogram:legacy_lib_histogram_all_tests", "//tensorflow/core/lib/io:legacy_lib_io_all_tests", "//tensorflow/core/lib/math:math_util_test.cc", + "//tensorflow/core/lib/monitoring:collection_registry_test.cc", + "//tensorflow/core/lib/monitoring:counter_test.cc", + "//tensorflow/core/lib/monitoring:gauge_test.cc", + "//tensorflow/core/lib/monitoring:metric_def_test.cc", + "//tensorflow/core/lib/monitoring:sampler_test.cc", "//tensorflow/core/lib/random:legacy_lib_random_tests", "//tensorflow/core/lib/strings:legacy_low_level_library_tests", "//tensorflow/core/platform:fingerprint_test.cc", diff --git a/tensorflow/core/api_def/base_api/api_def_CSRSparseMatrixToSparseTensor.pbtxt b/tensorflow/core/api_def/base_api/api_def_CSRSparseMatrixToSparseTensor.pbtxt index 2b932378339..21d1983b4b3 100644 --- a/tensorflow/core/api_def/base_api/api_def_CSRSparseMatrixToSparseTensor.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_CSRSparseMatrixToSparseTensor.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "CSRSparseMatrixToSparseTensor" + visibility: HIDDEN in_arg { name: "sparse_matrix" description: "A (possibly batched) CSRSparseMatrix." diff --git a/tensorflow/core/api_def/base_api/api_def_DenseToCSRSparseMatrix.pbtxt b/tensorflow/core/api_def/base_api/api_def_DenseToCSRSparseMatrix.pbtxt index 9e578c0f123..23822cbf438 100644 --- a/tensorflow/core/api_def/base_api/api_def_DenseToCSRSparseMatrix.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_DenseToCSRSparseMatrix.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "DenseToCSRSparseMatrix" + visibility: HIDDEN in_arg { name: "dense_input" description: "A Dense tensor." diff --git a/tensorflow/core/api_def/base_api/api_def_SparseMatrixAdd.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseMatrixAdd.pbtxt index 78c20141b67..58328c1941f 100644 --- a/tensorflow/core/api_def/base_api/api_def_SparseMatrixAdd.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SparseMatrixAdd.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "SparseMatrixAdd" + visibility: HIDDEN in_arg { name: "a" description: "A CSRSparseMatrix." 
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseMatrixMatMul.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseMatrixMatMul.pbtxt index 8d4da45cd8a..679edada8f8 100644 --- a/tensorflow/core/api_def/base_api/api_def_SparseMatrixMatMul.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SparseMatrixMatMul.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "SparseMatrixMatMul" + visibility: HIDDEN in_arg { name: "a" description: "A CSRSparseMatrix." diff --git a/tensorflow/core/api_def/base_api/api_def_SparseMatrixMul.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseMatrixMul.pbtxt index 0f9a8b30351..aa7554f7104 100644 --- a/tensorflow/core/api_def/base_api/api_def_SparseMatrixMul.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SparseMatrixMul.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "SparseMatrixMul" + visibility: HIDDEN in_arg { name: "a" description: "A CSRSparseMatrix." diff --git a/tensorflow/core/api_def/base_api/api_def_SparseMatrixNNZ.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseMatrixNNZ.pbtxt index 7e19822a6d7..cc04b94c82a 100644 --- a/tensorflow/core/api_def/base_api/api_def_SparseMatrixNNZ.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SparseMatrixNNZ.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "SparseMatrixNNZ" + visibility: HIDDEN in_arg { name: "sparse_matrix" description: "A CSRSparseMatrix." diff --git a/tensorflow/core/api_def/base_api/api_def_SparseMatrixOrderingAMD.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseMatrixOrderingAMD.pbtxt index 32704f2cf33..9e7842c0f68 100644 --- a/tensorflow/core/api_def/base_api/api_def_SparseMatrixOrderingAMD.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SparseMatrixOrderingAMD.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "SparseMatrixOrderingAMD" + visibility: HIDDEN in_arg { name: "input" description: "A `CSRSparseMatrix`." diff --git a/tensorflow/core/api_def/base_api/api_def_SparseMatrixSoftmax.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseMatrixSoftmax.pbtxt index bf868e5ff5c..31d9aaf44b0 100644 --- a/tensorflow/core/api_def/base_api/api_def_SparseMatrixSoftmax.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SparseMatrixSoftmax.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "SparseMatrixSoftmax" + visibility: HIDDEN in_arg { name: "logits" description: "A CSRSparseMatrix." diff --git a/tensorflow/core/api_def/base_api/api_def_SparseMatrixSoftmaxGrad.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseMatrixSoftmaxGrad.pbtxt index bb7961b94fd..0705ec91132 100644 --- a/tensorflow/core/api_def/base_api/api_def_SparseMatrixSoftmaxGrad.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SparseMatrixSoftmaxGrad.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "SparseMatrixSoftmaxGrad" + visibility: HIDDEN in_arg { name: "softmax" description: "A CSRSparseMatrix." diff --git a/tensorflow/core/api_def/base_api/api_def_SparseMatrixSparseCholesky.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseMatrixSparseCholesky.pbtxt index e69814e9f91..f7cdd3574ac 100644 --- a/tensorflow/core/api_def/base_api/api_def_SparseMatrixSparseCholesky.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SparseMatrixSparseCholesky.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "SparseMatrixSparseCholesky" + visibility: HIDDEN in_arg { name: "input" description: "A `CSRSparseMatrix`." 
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseMatrixSparseMatMul.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseMatrixSparseMatMul.pbtxt index 8c9cc0ba151..f84b3948be4 100644 --- a/tensorflow/core/api_def/base_api/api_def_SparseMatrixSparseMatMul.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SparseMatrixSparseMatMul.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "SparseMatrixSparseMatMul" + visibility: HIDDEN in_arg { name: "a" description: "A CSRSparseMatrix." diff --git a/tensorflow/core/api_def/base_api/api_def_SparseMatrixTranspose.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseMatrixTranspose.pbtxt index 5a3cfba8cce..179bb312ade 100644 --- a/tensorflow/core/api_def/base_api/api_def_SparseMatrixTranspose.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SparseMatrixTranspose.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "SparseMatrixTranspose" + visibility: HIDDEN in_arg { name: "input" description: "A CSRSparseMatrix." diff --git a/tensorflow/core/api_def/base_api/api_def_SparseMatrixZeros.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseMatrixZeros.pbtxt index c535bba6876..08a5cc16e7d 100644 --- a/tensorflow/core/api_def/base_api/api_def_SparseMatrixZeros.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SparseMatrixZeros.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "SparseMatrixZeros" + visibility: HIDDEN in_arg { name: "dense_shape" description: "The desired matrix shape." diff --git a/tensorflow/core/api_def/base_api/api_def_SparseTensorToCSRSparseMatrix.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseTensorToCSRSparseMatrix.pbtxt index dc8c229056b..9deb28c61f5 100644 --- a/tensorflow/core/api_def/base_api/api_def_SparseTensorToCSRSparseMatrix.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SparseTensorToCSRSparseMatrix.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "SparseTensorToCSRSparseMatrix" + visibility: HIDDEN in_arg { name: "indices" description: "SparseTensor indices." diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 133a6c31a93..c836cb23898 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -521,7 +521,6 @@ Status DirectSession::RunInternal( executor_step_count, &debugger_state)); } - run_state.rendez.reset(new IntraProcessRendezvous(device_mgr_.get())); #ifndef __ANDROID__ // Set up for collectives if ExecutorsAndKeys declares a key. if (executors_and_keys->collective_graph_key != @@ -616,7 +615,6 @@ Status DirectSession::RunInternal( Executor::Args args; args.step_id = step_id; args.call_frame = call_frame; - args.rendezvous = run_state.rendez.get(); args.collective_executor = (run_state.collective_executor ? run_state.collective_executor->get() : nullptr); @@ -688,14 +686,21 @@ Status DirectSession::RunInternal( }; if (can_execute_synchronously) { + PrivateIntraProcessRendezvous rendezvous(device_mgr_.get()); + args.rendezvous = &rendezvous; + const auto& item = executors_and_keys->items[0]; set_threadpool_args_for_item(item, &args); run_status = item.executor->Run(args); } else { + core::RefCountPtr rendezvous( + new RefCountedIntraProcessRendezvous(device_mgr_.get())); + args.rendezvous = rendezvous.get(); + // `barrier` will delete itself after the final executor finishes. 
Notification executors_done; ExecutorBarrier* barrier = - new ExecutorBarrier(num_executors, run_state.rendez.get(), + new ExecutorBarrier(num_executors, rendezvous.get(), [&run_state, &executors_done](const Status& ret) { { mutex_lock l(run_state.mu); @@ -1139,7 +1144,7 @@ Status DirectSession::SendPRunInputs(const NamedTensorList& inputs, Status DirectSession::RecvPRunOutputs( const std::vector& output_names, - const ExecutorsAndKeys* executors_and_keys, RunState* run_state, + const ExecutorsAndKeys* executors_and_keys, PartialRunState* run_state, std::vector* outputs) { Status s; if (!output_names.empty()) { diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h index a272633b4e2..7bbb198ef44 100644 --- a/tensorflow/core/common_runtime/direct_session.h +++ b/tensorflow/core/common_runtime/direct_session.h @@ -191,7 +191,6 @@ class DirectSession : public Session { struct RunState { mutex mu; Status status GUARDED_BY(mu); - core::RefCountPtr rendez = nullptr; std::unique_ptr collective_executor; std::unique_ptr collector; TensorStore tensor_store; @@ -208,6 +207,7 @@ class DirectSession : public Session { Notification executors_done; std::unordered_map pending_inputs; // true if fed std::unordered_map pending_outputs; // true if fetched + core::RefCountPtr rendez = nullptr; PartialRunState(const std::vector& pending_input_names, const std::vector& pending_output_names, @@ -282,7 +282,7 @@ class DirectSession : public Session { // tensors are computed. ::tensorflow::Status RecvPRunOutputs( const std::vector& output_names, - const ExecutorsAndKeys* executors_and_keys, RunState* run_state, + const ExecutorsAndKeys* executors_and_keys, PartialRunState* run_state, std::vector* outputs); // Check if the specified fetches can be computed from the feeds diff --git a/tensorflow/core/common_runtime/eager/copy_to_device_node.h b/tensorflow/core/common_runtime/eager/copy_to_device_node.h index 144184fac9a..53f3ff94d78 100644 --- a/tensorflow/core/common_runtime/eager/copy_to_device_node.h +++ b/tensorflow/core/common_runtime/eager/copy_to_device_node.h @@ -43,7 +43,7 @@ class CopyToDeviceNode : public EagerNode { MEMDEBUG_CACHE_OP(MEMDEBUG_CACHE_VAL ? 
MEMDEBUG_CACHE_VAL : "eager::CopyToDeviceNode"); TF_RETURN_IF_ERROR(src_->CopyToDevice(ctx_, dstd_, &tensor)); - return dst_->SetTensor(tensor); + return dst_->SetTensor(std::move(tensor)); } void Abort(Status status) override { dst_->Poison(status); } diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index e2c424e8ed6..32fdb21c1b4 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -642,8 +642,10 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals, graph_collector = ctx->GetGraphCollector(); } + const bool async = executor.Async(); for (int i = 0; i < num_outputs; ++i) { - TF_RETURN_IF_ERROR(TensorHandle::CreateAsyncLocalHandle( + TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle( + async, /* d= */ ctx->CanonicalDevice(kernel->OutputDevice(i)), /* op_device= */ kernel->device(), /* resource_device= */ kernel->OutputResourceDevice(i), @@ -651,7 +653,7 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals, } Status s; - if (executor.Async()) { + if (async) { auto node = absl::make_unique( ctx, op->Inputs(), op->remote_func_params(), std::move(kernel), graph_collector, output_dtypes, op->GetCancellationManager(), @@ -1038,24 +1040,21 @@ Status EagerKernelExecute( profiler::TraceMe activity("EagerKernelExecute", profiler::TraceMeLevel::kInfo); std::vector outputs(1); - gtl::InlinedVector input_vector(op_inputs.size()); - std::unique_ptr inputs; - TF_RETURN_IF_ERROR(ExecuteNodeArgs::CreateExecuteNodeArgs( - std::move(input_vector), ctx, op_inputs, &inputs)); + ExecuteNodeArgs inputs(op_inputs.size()); + TF_RETURN_IF_ERROR(inputs.Init(ctx, op_inputs)); // TODO(apassos) figure out how to record stats for ops which are a part of // functions. - // TODO(agarwal): change Run to take vector of handles ? // TODO(b/111859745): When we support recovering from kernel/device errors, we // would need to call XlaDevice::EnsureDeviceContextOk() before using an XLA // device. We don't call it now because it is an unneeded overhead (it // acquires a lock) and we can't recover from errors anyway. ScopedStepContainer* container = ctx->StepContainer(); if (container == nullptr) { - TF_RETURN_IF_ERROR(kernel->Run(*inputs, &outputs, cancellation_manager, + TF_RETURN_IF_ERROR(kernel->Run(inputs, &outputs, cancellation_manager, remote_func_params)); } else { - TF_RETURN_IF_ERROR(kernel->Run(container, *inputs, &outputs, + TF_RETURN_IF_ERROR(kernel->Run(container, inputs, &outputs, cancellation_manager, remote_func_params)); } if (graph_collector != nullptr) { @@ -1089,7 +1088,7 @@ Status EagerKernelExecute( DCHECK_EQ(ctx->CanonicalDevice(kernel->OutputDevice(i)), retvals[i]->device()); - TF_RETURN_IF_ERROR(retvals[i]->SetTensor(outputs[i])); + TF_RETURN_IF_ERROR(retvals[i]->SetTensor(std::move(outputs[i]))); } return Status::OK(); } @@ -1100,9 +1099,9 @@ Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx, EagerExecutor* executor, Device* dstd, TensorHandle** result) { TF_RETURN_IF_ERROR(executor->status()); - TF_RETURN_IF_ERROR(TensorHandle::CreateAsyncLocalHandle( - ctx->CanonicalDevice(dstd), dstd, h->resource_device(), h->dtype, ctx, - result)); + TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle( + true, ctx->CanonicalDevice(dstd), dstd, h->resource_device(), h->dtype, + ctx, result)); // Note that `h` may not be currently ready. However execution order will // make sure that `h` is ready before the copy is actually done. 
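
The execute.cc hunks above switch eager op outputs from TensorHandle::CreateAsyncLocalHandle to CreateEmptyLocalHandle(async, ...) and hand the produced tensor to SetTensor with std::move: an output handle is now created empty and later filled by moving the result into it instead of copying. The standalone sketch below is illustrative only, not TensorFlow code; the class and method names are invented, and the real change additionally threads the `async` flag into the handle's readiness bookkeeping (shown separately further down).

#include <cassert>
#include <optional>
#include <string>
#include <utility>

template <typename T>
class EmptyThenFilledHandle {
 public:
  bool IsReady() const { return value_.has_value(); }

  // The producer moves its output in exactly once; afterwards the handle is ready.
  void SetValue(T&& value) {
    assert(!IsReady() && "SetValue is only called on non-ready handles");
    value_ = std::move(value);
  }

  const T& value() const {
    assert(IsReady() && "value() is only called on ready handles");
    return *value_;
  }

 private:
  std::optional<T> value_;
};

int main() {
  EmptyThenFilledHandle<std::string> handle;  // Created empty, like an op's output handle.
  std::string produced(1 << 20, 'x');         // Large payload produced elsewhere.
  handle.SetValue(std::move(produced));       // Moved, not copied, into the handle.
  return handle.IsReady() ? 0 : 1;
}
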
@@ -1150,10 +1149,9 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx, } uint64 recv_op_id = 0; if (recver_is_local) { - TF_RETURN_IF_ERROR(TensorHandle::CreateAsyncLocalHandle( - /* d= */ device, - /* op_device= */ device, /*resource_device=*/nullptr, h->dtype, ctx, - result)); + TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle( + true, /* d= */ device, /* op_device= */ device, + /*resource_device=*/nullptr, h->dtype, ctx, result)); } else { uint64 context_id = ctx->GetContextId(); string remote_task; diff --git a/tensorflow/core/common_runtime/eager/execute_node.h b/tensorflow/core/common_runtime/eager/execute_node.h index 3fb53736078..08cecf56098 100644 --- a/tensorflow/core/common_runtime/eager/execute_node.h +++ b/tensorflow/core/common_runtime/eager/execute_node.h @@ -45,16 +45,12 @@ namespace tensorflow { class ExecuteNodeArgs : public EagerKernelArgs { public: - static Status CreateExecuteNodeArgs( - gtl::InlinedVector&& tensor_args, EagerContext* ctx, - const gtl::InlinedVector& op_inputs, - std::unique_ptr* args) { - args->reset(new ExecuteNodeArgs(std::move(tensor_args))); - return (*args)->Init(ctx, op_inputs); - } - + explicit ExecuteNodeArgs(int count) : EagerKernelArgs(count) {} ~ExecuteNodeArgs() override; + Status Init(EagerContext* ctx, + const gtl::InlinedVector& op_inputs); + bool HasRemoteInputs() const override { return has_remote_inputs_; }; #if !defined(IS_MOBILE_PLATFORM) @@ -65,12 +61,6 @@ class ExecuteNodeArgs : public EagerKernelArgs { #endif // IS_MOBILE_PLATFORM private: - explicit ExecuteNodeArgs(gtl::InlinedVector&& tensor_args) - : EagerKernelArgs(std::move(tensor_args)) {} - - Status Init(EagerContext* ctx, - const gtl::InlinedVector& op_inputs); - bool has_remote_inputs_ = false; TensorReferenceVector protected_tensors_; #if !defined(IS_MOBILE_PLATFORM) diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.h b/tensorflow/core/common_runtime/eager/kernel_and_device.h index 395dcc98f78..04d97f2b80c 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.h +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.h @@ -59,6 +59,8 @@ class EagerKernelArgs : public FunctionArgsInterface { public: EagerKernelArgs() {} + explicit EagerKernelArgs(int count) : tensor_args_(count) {} + explicit EagerKernelArgs(gtl::InlinedVector&& tensor_args) : tensor_args_(std::move(tensor_args)) {} diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc index e4b297b646f..717ec586eef 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.cc +++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc @@ -131,10 +131,10 @@ TensorHandle::TensorHandle(std::unique_ptr t, #endif ctx_(ctx), is_remote_(false), + is_async_(false), + is_ready_(true), tensor_handle_data_(std::move(t)) { DVLOG(3) << "Creating Local TensorHandle: " << this << " device: " << device_; - // Notify immediately since this handle is already ready. - is_ready_notification_.Notify(); } TensorHandle::TensorHandle(std::unique_ptr t, @@ -150,25 +150,26 @@ TensorHandle::TensorHandle(std::unique_ptr t, #endif ctx_(ctx), is_remote_(false), + is_async_(false), + is_ready_(true), handle_dtypes_and_shapes_(resource_handle.dtypes_and_shapes()), tensor_handle_data_(std::move(t)) { DVLOG(3) << "Creating Local TensorHandle: " << this << " device: " << device_; - // Notify immediately since this handle is already ready. 
- is_ready_notification_.Notify(); } -Status TensorHandle::CreateAsyncLocalHandle(Device* d, Device* op_device, +Status TensorHandle::CreateEmptyLocalHandle(bool async, Device* d, + Device* op_device, Device* resource_device, DataType dtype, EagerContext* ctx, TensorHandle** h) { - *h = new TensorHandle(absl::make_unique(), d, - op_device, resource_device, dtype, ctx); + *h = new TensorHandle(absl::make_unique(), async, + d, op_device, resource_device, dtype, ctx); return Status::OK(); } -TensorHandle::TensorHandle(std::unique_ptr t, - Device* d, Device* op_device, +TensorHandle::TensorHandle(std::unique_ptr t, + bool async, Device* d, Device* op_device, Device* resource_device, DataType dtype, EagerContext* ctx) : dtype(dtype), @@ -181,6 +182,8 @@ TensorHandle::TensorHandle(std::unique_ptr t, #endif ctx_(ctx), is_remote_(false), + is_async_(async), + is_ready_(!async), tensor_handle_data_(std::move(t)) { DVLOG(3) << "Creating Async Local TensorHandle: " << this << " device: " << device_; @@ -219,11 +222,11 @@ TensorHandle::TensorHandle(std::unique_ptr t, remote_output_num_(t->output_num()), ctx_(ctx), is_remote_(true), + is_async_(false), + is_ready_(true), tensor_handle_data_(std::move(t)) { DVLOG(3) << "Creating Remote TensorHandle: " << this << " device: " << device_; - // Notify immediately since this handle is already ready. - is_ready_notification_.Notify(); } Status TensorHandle::CreateUnshapedRemoteHandle( @@ -255,21 +258,30 @@ TensorHandle::TensorHandle(std::unique_ptr t, remote_context_id_(t->context_id()), ctx_(ctx), is_remote_(true), + is_async_(true), + is_ready_(false), tensor_handle_data_(std::move(t)) { DVLOG(3) << "Creating Unshaped Remote TensorHandle: " << this << " device: " << device_; } #endif -bool TensorHandle::IsReady() { - return is_ready_notification_.HasBeenNotified(); +bool TensorHandle::IsReady() const { + // Avoid mutex acquisition for local sync handles + if (!is_async_ && !is_remote_) { + return true; + } + + tf_shared_lock l(mu_); + return is_ready_; } Status TensorHandle::WaitReady(const char* caller) { if (!IsReady()) { profiler::TraceMe activity(absl::StrCat(caller, " WaitReady"), profiler::TraceMeLevel::kInfo); - is_ready_notification_.WaitForNotification(); + tf_shared_lock l(mu_); + mu_.Await(Condition(&is_ready_)); } return is_poisoned_; } @@ -401,7 +413,7 @@ Status TensorHandle::NumElements(int64* num_elements) { Status TensorHandle::RemoteAddress(Device* d, int64* op_id, int32* output_num) const { if (d != device_) { - tf_shared_lock l(remote_mirrors_mutex_); + tf_shared_lock l(mu_); auto mirror = remote_mirrors_.find(d); if (mirror != remote_mirrors_.end()) { *op_id = mirror->second->op_id(); @@ -439,7 +451,7 @@ void TensorHandle::SetRemoteOpIdAndOutputNumToLocalTensorHandle( } bool TensorHandle::HasRemoteMirror(Device* d) { - tf_shared_lock l(remote_mirrors_mutex_); + tf_shared_lock l(mu_); auto mirror = remote_mirrors_.find(d); if (mirror != remote_mirrors_.end()) { return true; @@ -454,7 +466,7 @@ bool TensorHandle::HasRemoteMirror(Device* d) { } bool TensorHandle::HasResourceShapeMirror(Device* d) { - tf_shared_lock l(resource_shape_mirrors_mutex_); + tf_shared_lock l(mu_); auto mirror = resource_shape_mirrors_.find(d); if (mirror != resource_shape_mirrors_.end()) { return true; @@ -464,7 +476,7 @@ bool TensorHandle::HasResourceShapeMirror(Device* d) { Status TensorHandle::AddUnshapedRemoteMirror( std::unique_ptr t, Device* d) { - mutex_lock l(remote_mirrors_mutex_); + mutex_lock l(mu_); if (remote_mirrors_.find(d) != remote_mirrors_.end()) 
{ return errors::Internal("Attempted to duplicate a remote mirror."); } @@ -480,7 +492,7 @@ Status TensorHandle::AddUnshapedRemoteMirror( Status TensorHandle::AddResourceShapeMirror( std::unique_ptr t, Device* d) { - mutex_lock l(resource_shape_mirrors_mutex_); + mutex_lock l(mu_); auto ret = resource_shape_mirrors_.insert(std::make_pair(d, std::move(t))); if (!ret.second) { return errors::Internal("Attempted to duplicate a resource shape mirror."); @@ -491,7 +503,7 @@ Status TensorHandle::AddResourceShapeMirror( Status TensorHandle::AddRemoteMirror(std::unique_ptr t, Device* d) { - mutex_lock l(remote_mirrors_mutex_); + mutex_lock l(mu_); auto ret = remote_mirrors_.insert(std::make_pair(d, std::move(t))); if (!ret.second) { return errors::Internal("Attempted to duplicate a remote mirror."); @@ -505,7 +517,7 @@ Status TensorHandle::SetRemoteShape(const TensorShape& shape, DVLOG(3) << "SetRemoteShape on TensorHandle: " << this << " device: " << d; if (d != device_) { - mutex_lock l(remote_mirrors_mutex_); + mutex_lock l(mu_); if (remote_mirrors_.find(d) != remote_mirrors_.end()) { return errors::Internal( "Attempted to set remote shape for existing mirror."); @@ -528,8 +540,7 @@ Status TensorHandle::SetRemoteShape(const TensorShape& shape, } DCHECK(is_remote_) << "SeRemoteShape is only called on remote handles."; - DCHECK(!is_ready_notification_.HasBeenNotified()) - << "SetRemoteShape is only called on non-ready handles."; + DCHECK(!IsReady()) << "SetRemoteShape is only called on non-ready handles."; UnshapedRemoteTensorHandleData* p = reinterpret_cast( @@ -539,15 +550,16 @@ Status TensorHandle::SetRemoteShape(const TensorShape& shape, remote_op_id_, remote_output_num_, shape, remote_task_, remote_context_id_, ctx_); is_poisoned_ = Status::OK(); - is_ready_notification_.Notify(); + mutex_lock l(mu_); + is_ready_ = true; return Status::OK(); } #endif -Status TensorHandle::SetTensor(const tensorflow::Tensor& tensor) { +Status TensorHandle::SetTensor(tensorflow::Tensor&& tensor) { DCHECK(!is_remote_) << "SetTensor is not called on remote handles."; - DCHECK(!is_ready_notification_.HasBeenNotified()) + DCHECK(!is_async_ || !IsReady()) << "SetTensor is only called on non-ready handles."; DVLOG(3) << "SetTensor on TensorHandle: " << this; @@ -557,19 +569,24 @@ Status TensorHandle::SetTensor(const tensorflow::Tensor& tensor) { handle_dtypes_and_shapes_ = resource_handle.dtypes_and_shapes(); } tensor_handle_data_ = absl::make_unique(tensor); - is_poisoned_ = Status::OK(); - is_ready_notification_.Notify(); + if (is_async_) { + is_poisoned_ = Status::OK(); + mutex_lock l(mu_); + is_ready_ = true; + } + return Status::OK(); } void TensorHandle::Poison(Status status) { - DCHECK(!is_ready_notification_.HasBeenNotified()) + DCHECK(!is_async_ || !IsReady()) << "Poison(status) can only be called on non-ready handle: " << this; DVLOG(3) << "Poison on TensorHandle: " << this; is_poisoned_ = status; - is_ready_notification_.Notify(); + mutex_lock l(mu_); + is_ready_ = true; } Status TensorHandle::CopyToDevice(EagerContext* ctx, tensorflow::Device* dstd, diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h index a8b05e34b43..c32ec834071 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.h +++ b/tensorflow/core/common_runtime/eager/tensor_handle.h @@ -67,9 +67,9 @@ class TensorHandle : public core::RefCounted { TensorHandle(std::unique_ptr t, const ResourceHandle& resource_handle, Device* d, Device* op_device, EagerContext* ctx); 
- TensorHandle(std::unique_ptr t, Device* d, - Device* op_device, Device* resource_device, DataType dtype, - EagerContext* ctx); + TensorHandle(std::unique_ptr t, bool async, + Device* d, Device* op_device, Device* resource_device, + DataType dtype, EagerContext* ctx); #if !defined(IS_MOBILE_PLATFORM) TensorHandle(std::unique_ptr t, DataType dtype, @@ -87,7 +87,7 @@ class TensorHandle : public core::RefCounted { static Status CreateLocalHandle(const class Tensor& t, Device* d, Device* op_device, EagerContext* ctx, TensorHandle** h); - static Status CreateAsyncLocalHandle(Device* d, Device* op_device, + static Status CreateEmptyLocalHandle(bool async, Device* d, Device* op_device, Device* resource_device, DataType dtype, EagerContext* ctx, TensorHandle** h); #if !defined(IS_MOBILE_PLATFORM) @@ -158,7 +158,7 @@ class TensorHandle : public core::RefCounted { // Sets the `tensor` for this async non-ready handle making it ready. // This method or Poison must be called exactly once for non-ready async // handles to make them ready. - Status SetTensor(const tensorflow::Tensor& tensor); + Status SetTensor(tensorflow::Tensor&& tensor); // Poisons this non-ready handle with an error `status`. // Poisoning means that the handle will become ready and methods trying @@ -167,8 +167,6 @@ class TensorHandle : public core::RefCounted { // on a non-ready tensor. void Poison(Status status); - bool IsReady(); - Status CopyToDevice(EagerContext* ctx, tensorflow::Device* dstd, tensorflow::Tensor* output); @@ -207,6 +205,12 @@ class TensorHandle : public core::RefCounted { std::vector* result); private: + // The TensorHandleData can either represent a local or remote tensor handle. + // Further, it can be in a non-ready state. It would become ready with a call + // to either SetTensor or SetRemoteShape which replaces the underlying data + // with a ready version of the tensor handle data. + bool IsReady() const; + // If the contents of the Tensor pointed to by this handle is yet to be // computed by a EagerNode, this function will block till that computation is // done and the handle is "ready". @@ -232,24 +236,24 @@ class TensorHandle : public core::RefCounted { // backing the resource. Else resource_device_ is nullptr. tensorflow::Device* const resource_device_; + mutable mutex mu_; + #if !defined(IS_MOBILE_PLATFORM) // TODO(yujingzhang): Remove resource_shape_mirrors_ once scalable per-replica // variable is ready, since we could get the shape locally without remote copy // then. - mutable mutex resource_shape_mirrors_mutex_; std::map> - resource_shape_mirrors_ GUARDED_BY(resource_shape_mirrors_mutex_); + resource_shape_mirrors_ GUARDED_BY(mu_); - mutable mutex remote_mirrors_mutex_; // TODO(gjn): Unshaped remote mirrors are long expected to be long-lived. // Consider replacing the unshaped_remote_mirrors_ map with something more // efficient. std::map> - unshaped_remote_mirrors_ GUARDED_BY(remote_mirrors_mutex_); + unshaped_remote_mirrors_ GUARDED_BY(mu_); // TODO(gjn): Is std::map the most optimal choice here? Perhaps this should be // a fixed size map. std::map> - remote_mirrors_ GUARDED_BY(remote_mirrors_mutex_); + remote_mirrors_ GUARDED_BY(mu_); // IDs required when this class is representing a remote tensor handle. int64 remote_op_id_; @@ -263,24 +267,18 @@ class TensorHandle : public core::RefCounted { // `ctx` object is not owned and should outlive this handle. 
EagerContext* const ctx_; - // Explanation for NOLINT below: absl has clang-tidy macro to rename - // 'tensorflow::Notification' to 'absl::Notification'. TF does not use - // absl::Notification in open source now, so we can't follow clang-tidy - tensorflow::Notification is_ready_notification_; // NOLINT // Does not need synchronization because it can be accessed only after // WaitReady() has returned. At that point, is_poisoned_ is immutable. Status is_poisoned_; const bool is_remote_; + const bool is_async_; + bool is_ready_ GUARDED_BY(mu_); // If this TensorHandle 1) is a local tensor, and 2) is a resource handle or // refers to a remote resource handle, we store data types and shapes for // the underlying resource. std::vector handle_dtypes_and_shapes_; - // The TensorHandleData can either represent a local or remote tensor handle. - // Further, it can be in a non-ready state. It would become ready with a call - // to either SetTensor or SetRemoteShape which replaces the underlying data - // with a ready version of the tensor handle data. // Does not need synchronization because it can be accessed only after // WaitReady() has returned. At that point, tensor_handle_data_ is immutable. std::unique_ptr tensor_handle_data_; diff --git a/tensorflow/core/common_runtime/eager/tensor_handle_data.cc b/tensorflow/core/common_runtime/eager/tensor_handle_data.cc index 4fb44269584..d718e39687f 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle_data.cc +++ b/tensorflow/core/common_runtime/eager/tensor_handle_data.cc @@ -58,44 +58,44 @@ Status LocalTensorHandleData::NumElements(int64* num_elements) const { return Status::OK(); } -Status AsyncLocalTensorHandleData::Tensor(const tensorflow::Tensor** t) const { +Status EmptyLocalTensorHandleData::Tensor(const tensorflow::Tensor** t) const { return errors::Unavailable( - "Unable to get a tensor for an async handle. " + "Unable to get a tensor for an empty handle. " "Please wait until it is ready"); } -Status AsyncLocalTensorHandleData::TensorValue(tensorflow::TensorValue* t) { +Status EmptyLocalTensorHandleData::TensorValue(tensorflow::TensorValue* t) { return errors::Unavailable( - "Unable to get a tensor for an async handle. " + "Unable to get a tensor for an empty handle. " "Please wait until it is ready"); } -Status AsyncLocalTensorHandleData::Shape(TensorShape* shape) const { +Status EmptyLocalTensorHandleData::Shape(TensorShape* shape) const { return errors::Unavailable( - "Unable to get shape information for an async handle. " + "Unable to get shape information for an empty handle. " "Please wait until it is ready"); } -Status AsyncLocalTensorHandleData::NumDims(int* num_dims) const { +Status EmptyLocalTensorHandleData::NumDims(int* num_dims) const { return errors::Unavailable( - "Unable to get shape information for an async handle. " + "Unable to get shape information for an empty handle. " "Please wait until it is ready"); } -Status AsyncLocalTensorHandleData::Dim(int dim_index, int64* dim) const { +Status EmptyLocalTensorHandleData::Dim(int dim_index, int64* dim) const { return errors::Unavailable( - "Unable to get shape information for an async handle. " + "Unable to get shape information for an empty handle. " "Please wait until it is ready"); } -Status AsyncLocalTensorHandleData::NumElements(int64* num_elements) const { +Status EmptyLocalTensorHandleData::NumElements(int64* num_elements) const { return errors::Unavailable( - "Unable to get shape information for an async handle. 
" + "Unable to get shape information for an empty handle. " "Please wait until it is ready"); } -string AsyncLocalTensorHandleData::DebugString() const { - return "AsyncLocalTensorHandleData"; +string EmptyLocalTensorHandleData::DebugString() const { + return "EmptyLocalTensorHandleData"; } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eager/tensor_handle_data.h b/tensorflow/core/common_runtime/eager/tensor_handle_data.h index c9be6592426..e50200277f1 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle_data.h +++ b/tensorflow/core/common_runtime/eager/tensor_handle_data.h @@ -58,15 +58,14 @@ class LocalTensorHandleData : public TensorHandleData { tensorflow::Tensor tensor_; }; -// Async Local Tensor Handle: A non-ready local tensor handle used in async -// eager execution. Once the execution is complete this is replaced by a local -// tensor handle. -class AsyncLocalTensorHandleData : public TensorHandleData { +// Empty Local Tensor Handle: Once the execution is complete this is replaced by +// a local tensor handle. +class EmptyLocalTensorHandleData : public TensorHandleData { public: - AsyncLocalTensorHandleData() {} - ~AsyncLocalTensorHandleData() override {} + EmptyLocalTensorHandleData() {} + ~EmptyLocalTensorHandleData() override {} - // Async tensor handles are not ready and hence cannot satisfy any of these + // Empty tensor handles are not ready and hence cannot satisfy any of these // requests. Status Tensor(const tensorflow::Tensor** t) const override; Status TensorValue(tensorflow::TensorValue* t) override; diff --git a/tensorflow/core/common_runtime/eager/tensor_handle_test.cc b/tensorflow/core/common_runtime/eager/tensor_handle_test.cc index 05801d6e564..ea81cda6199 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle_test.cc +++ b/tensorflow/core/common_runtime/eager/tensor_handle_test.cc @@ -33,13 +33,12 @@ TEST(TensorHandle_ShapeTest, AsyncShape) { TensorHandle* sync_th; EXPECT_TRUE(TensorHandle::CreateLocalHandle(t, &sync_th).ok()); TensorHandle* async_th; - EXPECT_TRUE(TensorHandle::CreateAsyncLocalHandle(nullptr, nullptr, nullptr, - DataType::DT_UINT16, nullptr, - &async_th) + EXPECT_TRUE(TensorHandle::CreateEmptyLocalHandle(true, nullptr, nullptr, + nullptr, DataType::DT_UINT16, + nullptr, &async_th) .ok()); EXPECT_TRUE(async_th->CopyInferenceShape(sync_th).ok()); - EXPECT_FALSE(async_th->IsReady()); TensorShape sync_shape; TensorShape async_shape; diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index e1135ec488c..1c04adf7872 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -1287,7 +1287,7 @@ class ExecutorState { int64 step_id_; // Not owned. 
- Rendezvous* rendezvous_; + RendezvousInterface* rendezvous_; Executor::RendezvousFactory* create_rendezvous_ = nullptr; CollectiveExecutor* collective_executor_ = nullptr; SessionState* session_state_; diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h index 42d5b9eab4f..c147deee694 100644 --- a/tensorflow/core/common_runtime/executor.h +++ b/tensorflow/core/common_runtime/executor.h @@ -88,7 +88,7 @@ class Executor { struct Args { int64 step_id = 0; - Rendezvous* rendezvous = nullptr; + RendezvousInterface* rendezvous = nullptr; StepStatsCollectorInterface* stats_collector = nullptr; CallFrameInterface* call_frame = nullptr; CancellationManager* cancellation_manager = nullptr; diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 668a671d5e8..501002e1f7f 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -1017,7 +1017,7 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle, Item* item, DoneCallback done) { string target_device = parent_->GetDeviceName(handle); string source_device = opts.source_device; - Rendezvous* rendezvous = opts.rendezvous; + RendezvousInterface* rendezvous = opts.rendezvous; DeviceContext* device_context; Status s = parent_->GetDeviceContext(target_device, &device_context); if (!s.ok()) { @@ -1116,11 +1116,11 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle, } Options run_opts = opts; if (opts.create_rendezvous) { - Rendezvous* rendezvous = new IntraProcessRendezvous(device_mgr_); + auto* rendezvous = new PrivateIntraProcessRendezvous(device_mgr_); run_opts.rendezvous = rendezvous; run_opts.create_rendezvous = false; - done = [done = std::move(done), rendezvous](const Status& status) { - rendezvous->Unref(); + done = [done = std::move(done), rendezvous](const Status& status) mutable { + delete rendezvous; done(status); }; } @@ -1187,11 +1187,11 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle, Options run_opts = opts; if (opts.create_rendezvous) { - Rendezvous* rendezvous = new IntraProcessRendezvous(device_mgr_); + auto* rendezvous = new PrivateIntraProcessRendezvous(device_mgr_); run_opts.rendezvous = rendezvous; run_opts.create_rendezvous = false; - done = [done = std::move(done), rendezvous](const Status& status) { - rendezvous->Unref(); + done = [done = std::move(done), rendezvous](const Status& status) mutable { + delete rendezvous; done(status); }; } diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index c3d6e948f1e..89e4daa50b3 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -1854,7 +1854,8 @@ TEST_F(FunctionLibraryRuntimeTest, CrossDevice) { Tensor y; FunctionLibraryRuntime::Options opts; - opts.rendezvous = new IntraProcessRendezvous(device_mgr_.get()); + PrivateIntraProcessRendezvous rendezvous(device_mgr_.get()); + opts.rendezvous = &rendezvous; opts.source_device = "/device:CPU:1"; // Run on flr1_, flr2_ and make sure that the device it ran on was cpu:1. 
TF_CHECK_OK(Run(flr1_, handle, opts, {}, {&y}, true)); @@ -1869,7 +1870,6 @@ TEST_F(FunctionLibraryRuntimeTest, CrossDevice) { y, test::AsTensor({"/job:localhost/replica:0/task:0/device:CPU:1"}, TensorShape({}))); - opts.rendezvous->Unref(); } namespace { diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc index 0e230a5d2bd..2287bf889ab 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc @@ -61,7 +61,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #if GOOGLE_CUDA #include "third_party/gpus/cudnn/cudnn.h" -#include "tensorflow/core/platform/cuda.h" +#include "tensorflow/stream_executor/cuda/cuda_activation.h" #elif TENSORFLOW_USE_ROCM #include "tensorflow/core/platform/rocm.h" #endif diff --git a/tensorflow/core/common_runtime/graph_runner.cc b/tensorflow/core/common_runtime/graph_runner.cc index 48ee4b11a33..0a7d50f9ea4 100644 --- a/tensorflow/core/common_runtime/graph_runner.cc +++ b/tensorflow/core/common_runtime/graph_runner.cc @@ -45,7 +45,7 @@ namespace { // A simple rendezvous class. // Assumes a single sender and a single receiver, no duplicate sends, and no // sends of dead tensors. -class SimpleRendezvous : public Rendezvous { +class SimpleRendezvous : public RendezvousInterface { public: explicit SimpleRendezvous() {} @@ -124,8 +124,7 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library, std::unique_ptr graph_to_run(new Graph(graph->op_registry())); CopyGraph(*graph, graph_to_run.get()); - SimpleRendezvous* rendez = new SimpleRendezvous; - core::ScopedUnref rendez_unref(rendez); + SimpleRendezvous rendez; // Extract the input names and keys, and feed in the inputs. std::vector input_names; @@ -136,8 +135,8 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library, tensor_name, FrameAndIter(0, 0)); Rendezvous::ParsedKey parsed; TF_RETURN_IF_ERROR(Rendezvous::ParseKey(full_key, &parsed)); - TF_RETURN_IF_ERROR(rendez->Send(parsed, Rendezvous::Args(), in.second, - false /* is_dead */)); + TF_RETURN_IF_ERROR(rendez.Send(parsed, Rendezvous::Args(), in.second, + false /* is_dead */)); } // Call RewriteGraphForExecution @@ -180,7 +179,7 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library, // called via this method. args.step_id = LogMemory::CONSTANT_FOLDING_STEP_ID; args.runner = runner; - args.rendezvous = rendez; + args.rendezvous = &rendez; // NOTE: Use of graph runner is limited to single-device executions // so a CollectiveExecutor should never be required. args.collective_executor = nullptr; @@ -201,7 +200,7 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library, bool is_dead; Tensor output_tensor; TF_RETURN_IF_ERROR( - rendez->Recv(parsed, Rendezvous::Args(), &output_tensor, &is_dead)); + rendez.Recv(parsed, Rendezvous::Args(), &output_tensor, &is_dead)); // Does a deep copy so that ownership of the tensor isn't tied to the // allocator of the cpu device we created above. The allocator could be // deleted along with the device. 
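
The tensor_handle.cc/.h hunks above drop the one-shot is_ready_notification_ in favour of a single mutex plus an is_ready_ flag, with IsReady() skipping the lock entirely for handles that are ready at construction (local sync and shaped remote handles) and WaitReady() blocking on the condition. The sketch below is illustrative only, not TensorFlow code: it uses std::condition_variable where the real code uses tensorflow's mutex::Await, and it collapses the handle to just its readiness state.

#include <condition_variable>
#include <mutex>
#include <thread>

class ReadyFlag {
 public:
  // Handles created by a synchronous executor are ready from the start and
  // never need the lock; async handles become ready via SetReady().
  explicit ReadyFlag(bool async) : async_(async), ready_(!async) {}

  bool IsReady() const {
    if (!async_) return true;  // Fast path: no mutex acquisition.
    std::lock_guard<std::mutex> l(mu_);
    return ready_;
  }

  void WaitReady() {
    if (!async_) return;
    std::unique_lock<std::mutex> l(mu_);
    cv_.wait(l, [this] { return ready_; });
  }

  void SetReady() {
    if (!async_) return;
    {
      std::lock_guard<std::mutex> l(mu_);
      ready_ = true;
    }
    cv_.notify_all();
  }

 private:
  const bool async_;
  mutable std::mutex mu_;
  std::condition_variable cv_;
  bool ready_;
};

int main() {
  ReadyFlag flag(/*async=*/true);
  std::thread producer([&flag] { flag.SetReady(); });
  flag.WaitReady();  // Blocks until the producer flips the flag.
  producer.join();
  return flag.IsReady() ? 0 : 1;
}
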
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index 5c08d330ccf..4c01978e6d5 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -122,7 +122,7 @@ Status ProcessFunctionLibraryRuntime::SendTensors( const string& key_prefix, int64 src_incarnation, gtl::ArraySlice tensors_to_send, DeviceContext* device_context, const std::vector& alloc_attrs, - Rendezvous* rendezvous) { + RendezvousInterface* rendezvous) { std::vector keys; for (int i = 0; i < tensors_to_send.size(); ++i) { string name = strings::StrCat(key_prefix, i); @@ -140,8 +140,9 @@ void ProcessFunctionLibraryRuntime::ReceiveTensorsAsync( const string& source_device, const string& target_device, const string& key_prefix, int64 src_incarnation, int64 num_tensors, DeviceContext* device_context, - const std::vector& alloc_attrs, Rendezvous* rendezvous, - std::vector* received_tensors, StatusCallback done) { + const std::vector& alloc_attrs, + RendezvousInterface* rendezvous, std::vector* received_tensors, + StatusCallback done) { std::vector keys; for (int64 i = 0; i < num_tensors; ++i) { string name = strings::StrCat(key_prefix, i); diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h index 0166267b3ab..ee5d8bf2b16 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.h +++ b/tensorflow/core/common_runtime/process_function_library_runtime.h @@ -92,7 +92,7 @@ class ProcessFunctionLibraryRuntime { gtl::ArraySlice tensors_to_send, DeviceContext* device_context, const std::vector& alloc_attrs, - Rendezvous* rendezvous); + RendezvousInterface* rendezvous); // Receives `received_tensors` from `target_device` (originally sent from // `source_device`) using `rendezvous`. 
Uses `key_prefix` to construct the @@ -105,7 +105,7 @@ class ProcessFunctionLibraryRuntime { const string& key_prefix, int64 src_incarnation, int64 num_tensors, DeviceContext* device_context, const std::vector& alloc_attrs, - Rendezvous* rendezvous, std::vector* received_tensors, + RendezvousInterface* rendezvous, std::vector* received_tensors, StatusCallback done); static const char kDefaultFLRDevice[]; diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc index 1a5ed3caa11..55bc408f9c5 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc @@ -110,12 +110,6 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { } } - ~ProcessFunctionLibraryRuntimeTest() override { - if (rendezvous_ != nullptr) { - rendezvous_->Unref(); - } - } - void Init(const std::vector& flib, const SessionMetadata* session_metadata = nullptr) { FunctionDefLibrary proto; @@ -127,7 +121,8 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { device_mgr_.get(), Env::Default(), /*config=*/nullptr, TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, nullptr, cluster_flr_.get(), nullptr, session_metadata)); - rendezvous_ = new IntraProcessRendezvous(device_mgr_.get()); + rendezvous_ = + absl::make_unique(device_mgr_.get()); } Status Instantiate( @@ -263,7 +258,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { test::function::FunctionTestSchedClosure(fn); }; - opts.rendezvous = rendezvous_; + opts.rendezvous = rendezvous_.get(); opts.runner = &runner; Status status; Notification done; @@ -292,7 +287,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { std::unique_ptr lib_def_; std::unique_ptr cluster_flr_; std::unique_ptr proc_flr_; - IntraProcessRendezvous* rendezvous_ = nullptr; + std::unique_ptr rendezvous_ = nullptr; }; TEST_F(ProcessFunctionLibraryRuntimeTest, GetFLRNull) { @@ -344,7 +339,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCall) { Init({test::function::XTimesTwo()}); FunctionLibraryRuntime::Options opts; opts.source_device = "/job:a/replica:0/task:0/cpu:0"; - opts.rendezvous = rendezvous_; + opts.rendezvous = rendezvous_.get(); opts.remote_execution = true; FunctionLibraryRuntime::InstantiateOptions instantiate_opts; instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0"; @@ -359,7 +354,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCallFindDevice) { Init({test::function::FindDevice()}); FunctionLibraryRuntime::Options opts; opts.source_device = "/job:a/replica:0/task:0/cpu:0"; - opts.rendezvous = rendezvous_; + opts.rendezvous = rendezvous_.get(); opts.remote_execution = true; FunctionLibraryRuntime::InstantiateOptions instantiate_opts; instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0"; @@ -375,7 +370,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceXTimes) { auto x = test::AsTensor({1, 2, 3, 4}); FunctionLibraryRuntime::Options opts; opts.source_device = "/job:a/replica:0/task:0/cpu:0"; - opts.rendezvous = rendezvous_; + opts.rendezvous = rendezvous_.get(); opts.remote_execution = true; FunctionLibraryRuntime::InstantiateOptions instantiate_opts; instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0"; @@ -392,7 +387,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceFindDevice) { Init({test::function::FindDevice()}); FunctionLibraryRuntime::Options opts; 
opts.source_device = "/job:a/replica:0/task:0/cpu:0"; - opts.rendezvous = rendezvous_; + opts.rendezvous = rendezvous_.get(); opts.remote_execution = true; FunctionLibraryRuntime::InstantiateOptions instantiate_opts; instantiate_opts.target = "/job:a/replica:0/task:0/cpu:1"; @@ -411,7 +406,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsDiffDeviceFindDevice) { Init({test::function::FindDevice()}); FunctionLibraryRuntime::Options opts; opts.source_device = "/job:a/replica:0/task:0/cpu:0"; - opts.rendezvous = rendezvous_; + opts.rendezvous = rendezvous_.get(); opts.remote_execution = true; Tensor y; FunctionLibraryRuntime::InstantiateOptions instantiate_opts_0; @@ -432,7 +427,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, ClusterFLRSerialTest) { Init({test::function::FindDevice()}); FunctionLibraryRuntime::Options opts; opts.source_device = "/job:a/replica:0/task:0/cpu:0"; - opts.rendezvous = rendezvous_; + opts.rendezvous = rendezvous_.get(); opts.remote_execution = true; FunctionLibraryRuntime::InstantiateOptions instantiate_opts; instantiate_opts.target = "/job:b/replica:0/task:0/device:CPU:0"; @@ -462,7 +457,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, ClusterFLRParallelTest) { Init({test::function::FindDevice()}); FunctionLibraryRuntime::Options opts; opts.source_device = "/job:a/replica:0/task:0/cpu:0"; - opts.rendezvous = rendezvous_; + opts.rendezvous = rendezvous_.get(); opts.remote_execution = true; FunctionLibraryRuntime::InstantiateOptions instantiate_opts; instantiate_opts.target = "/job:b/replica:0/task:0/device:CPU:0"; @@ -509,7 +504,7 @@ void TestTwoDeviceMult( const string& error = "") { fixture->Init({test::function::TwoDeviceMult()}); FunctionLibraryRuntime::Options opts; - opts.rendezvous = fixture->rendezvous_; + opts.rendezvous = fixture->rendezvous_.get(); auto x = test::AsTensor({1, 2, 3}); Tensor y_cpu; Tensor y_gpu; @@ -542,7 +537,7 @@ void TestTwoDeviceInputOutput( fixture->Init({test::function::TwoDeviceInputOutput()}); FunctionLibraryRuntime::Options opts; - opts.rendezvous = fixture->rendezvous_; + opts.rendezvous = fixture->rendezvous_.get(); Tensor x1 = test::AsTensor({1, 2}); if (absl::StrContains(inst_opts.input_devices[0], "GPU")) { x1 = fixture->CPUToGPU(x1); @@ -743,7 +738,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_ResourceOutput_GPU) { // Run the function taking a resource and outputing it FunctionLibraryRuntime::Options opts; - opts.rendezvous = rendezvous_; + opts.rendezvous = rendezvous_.get(); Tensor x1 = CPUToGPU(test::AsTensor({1, 2})); Tensor x2 = GetResourceHandle("my_gpu_var", mgr->default_container(), "/job:a/replica:0/task:0/device:GPU:0"); @@ -985,7 +980,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SessionMetadataAbsent) { Init({SessionMetadataReaderOpFn()}, /*session_metadata=*/nullptr); FunctionLibraryRuntime::Options opts; opts.source_device = "/job:a/replica:0/task:0/cpu:0"; - opts.rendezvous = rendezvous_; + opts.rendezvous = rendezvous_.get(); opts.remote_execution = true; FunctionLibraryRuntime::InstantiateOptions instantiate_opts; instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0"; @@ -1001,7 +996,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SessionMetadataPresent) { Init({SessionMetadataReaderOpFn()}, &session_metadata); FunctionLibraryRuntime::Options opts; opts.source_device = "/job:a/replica:0/task:0/cpu:0"; - opts.rendezvous = rendezvous_; + opts.rendezvous = rendezvous_.get(); opts.remote_execution = true; FunctionLibraryRuntime::InstantiateOptions instantiate_opts; 
instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0"; @@ -1027,7 +1022,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SessionMetadataPresentAfterCloning) { TF_ASSERT_OK(flr->Clone(&cloned_lib_def, &cloned_proc_flr, &cloned_flr)); FunctionLibraryRuntime::Options opts; opts.source_device = "/job:a/replica:0/task:0/cpu:0"; - opts.rendezvous = rendezvous_; + opts.rendezvous = rendezvous_.get(); opts.remote_execution = true; FunctionLibraryRuntime::InstantiateOptions instantiate_opts; instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0"; diff --git a/tensorflow/core/common_runtime/rendezvous_mgr.cc b/tensorflow/core/common_runtime/rendezvous_mgr.cc index d6fea8bd5d5..6ed7df2cc1e 100644 --- a/tensorflow/core/common_runtime/rendezvous_mgr.cc +++ b/tensorflow/core/common_runtime/rendezvous_mgr.cc @@ -32,38 +32,12 @@ limitations under the License. namespace tensorflow { -IntraProcessRendezvous::IntraProcessRendezvous(const DeviceMgr* device_mgr) - : device_mgr_(device_mgr), local_(NewLocalRendezvous()) {} - -IntraProcessRendezvous::~IntraProcessRendezvous() { local_->Unref(); } - -Status IntraProcessRendezvous::Send(const ParsedKey& parsed, - const Rendezvous::Args& args, - const Tensor& val, const bool is_dead) { - VLOG(1) << "IntraProcessRendezvous Send " << this << " " << parsed.FullKey(); - { - mutex_lock l(mu_); - if (!status_.ok()) return status_; - } - - // Buffers "val" and "device_context" in local_. - return local_->Send(parsed, args, val, is_dead); -} - -Status IntraProcessRendezvous::ParseKey(const string& key, bool is_src, - Rendezvous::ParsedKey* parsed) { - { - mutex_lock l(mu_); - if (!status_.ok()) return status_; - } - TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, parsed)); - return Status::OK(); -} - -void IntraProcessRendezvous::SameWorkerRecvDone( - const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& send_args, - const Rendezvous::Args& recv_args, const Tensor& in, Tensor* out, - StatusCallback done) { +namespace { +void SameWorkerRecvDone(const DeviceMgr* device_mgr, + const Rendezvous::ParsedKey& parsed, + const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, const Tensor& in, + Tensor* out, StatusCallback done) { // Do a quick copy (sharing the underlying buffer) if both tensors // are on host memory. const bool src_host = @@ -88,13 +62,13 @@ void IntraProcessRendezvous::SameWorkerRecvDone( } Device* src_device; - Status s = device_mgr_->LookupDevice(parsed.src_device, &src_device); + Status s = device_mgr->LookupDevice(parsed.src_device, &src_device); if (!s.ok()) { done(s); return; } Device* dst_device; - s = device_mgr_->LookupDevice(parsed.dst_device, &dst_device); + s = device_mgr->LookupDevice(parsed.dst_device, &dst_device); if (!s.ok()) { done(s); return; @@ -131,16 +105,18 @@ void IntraProcessRendezvous::SameWorkerRecvDone( out, 0 /*dev_to_dev_stream_index*/, std::move(done), sync_dst_compute); } -void IntraProcessRendezvous::RecvAsync(const ParsedKey& parsed, - const Rendezvous::Args& recv_args, - DoneCallback done) { - VLOG(1) << "IntraProcessRendezvous Recv " << this << " " << parsed.FullKey(); +void IntraProcessRecvAsyncImpl(const DeviceMgr* device_mgr, + LocalRendezvous* local, + const RendezvousInterface::ParsedKey& parsed, + const Rendezvous::Args& recv_args, + RendezvousInterface::DoneCallback done) { + VLOG(1) << "IntraProcessRendezvous Recv " << local << " " << parsed.FullKey(); MEMDEBUG_CACHE_OP("RecvAsync"); // Recv the tensor from local_. 
- local_->RecvAsync( + local->RecvAsync( parsed, recv_args, - [this, parsed, done = std::move(done)]( + [device_mgr, parsed, done = std::move(done)]( const Status& status, const Rendezvous::Args& send_args, const Rendezvous::Args& recv_args, const Tensor& in, bool is_dead) mutable { @@ -156,7 +132,7 @@ void IntraProcessRendezvous::RecvAsync(const ParsedKey& parsed, }; if (status.ok() && in.IsInitialized()) { - SameWorkerRecvDone(parsed, send_args, recv_args, in, out, + SameWorkerRecvDone(device_mgr, parsed, send_args, recv_args, in, out, std::move(final_callback)); } else { final_callback(status); @@ -164,9 +140,57 @@ void IntraProcessRendezvous::RecvAsync(const ParsedKey& parsed, }); } -void IntraProcessRendezvous::StartAbort(const Status& s) { - CHECK(!s.ok()); - local_->StartAbort(s); +} // namespace + +RefCountedIntraProcessRendezvous::RefCountedIntraProcessRendezvous( + const DeviceMgr* device_mgr) + : device_mgr_(device_mgr) {} + +RefCountedIntraProcessRendezvous::~RefCountedIntraProcessRendezvous() {} + +Status RefCountedIntraProcessRendezvous::Send(const ParsedKey& key, + const Rendezvous::Args& args, + const Tensor& val, + const bool is_dead) { + VLOG(1) << "IntraProcessRendezvous Send " << this << " " << key.FullKey(); + return local_.Send(key, args, val, is_dead); +} + +void RefCountedIntraProcessRendezvous::RecvAsync(const ParsedKey& key, + const Rendezvous::Args& args, + DoneCallback done) { + VLOG(1) << "IntraProcessRendezvous Recv " << this << " " << key.FullKey(); + IntraProcessRecvAsyncImpl(device_mgr_, &local_, key, args, std::move(done)); +} + +void RefCountedIntraProcessRendezvous::StartAbort(const Status& s) { + local_.StartAbort(s); +} + +PrivateIntraProcessRendezvous::PrivateIntraProcessRendezvous( + const DeviceMgr* device_mgr) + : device_mgr_(device_mgr) {} + +PrivateIntraProcessRendezvous::~PrivateIntraProcessRendezvous() {} + +Status PrivateIntraProcessRendezvous::Send(const ParsedKey& key, + const Rendezvous::Args& args, + const Tensor& val, + const bool is_dead) { + DVLOG(1) << "IntraProcessRendezvous Send " << this << " " << key.FullKey(); + return local_.Send(key, args, val, is_dead); +} + +void PrivateIntraProcessRendezvous::RecvAsync(const ParsedKey& key, + const Rendezvous::Args& args, + DoneCallback done) { + DVLOG(1) << "StackAllocatedIntraProcessRendezvous Recv " << this << " " + << key.FullKey(); + IntraProcessRecvAsyncImpl(device_mgr_, &local_, key, args, std::move(done)); +} + +void PrivateIntraProcessRendezvous::StartAbort(const Status& s) { + local_.StartAbort(s); } } // end namespace tensorflow diff --git a/tensorflow/core/common_runtime/rendezvous_mgr.h b/tensorflow/core/common_runtime/rendezvous_mgr.h index b4d8ab4eb2b..eea5fbe388c 100644 --- a/tensorflow/core/common_runtime/rendezvous_mgr.h +++ b/tensorflow/core/common_runtime/rendezvous_mgr.h @@ -20,6 +20,7 @@ limitations under the License. #include #include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/framework/local_rendezvous.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" @@ -29,60 +30,61 @@ limitations under the License. namespace tensorflow { -// IntraProcessRendezvous is a Rendezvous which expects all producers -// and consumers to be devices immediately accessible within the -// process. 
That is, it will never be necessary to perform an RPC to +// The IntraProcessRendezvous classes are implementations of a Rendezvous that +// expects all producers and consumers to be devices immediately accessible +// within the process. That is, it will never be necessary to perform an RPC to // communicate with either. // -// Buffering of Tensor values is delegated to a "local" Rendezvous -// obtained from NewLocalRendezvous(). This class just adds -// functionality to coordinate multiple process-local devices. -class IntraProcessRendezvous : public Rendezvous { - public: - explicit IntraProcessRendezvous(const DeviceMgr* device_mgr); +// Buffering of Tensor values is delegated to a `LocalRendezvous`. An +// IntraProcessRendezvous. just adds functionality to coordinate multiple +// process-local devices. - // Forwards to local_, where the Tensor "val" will be buffered and - // any waiting callback stored. +// Reference-counted implementation that may be shared between multiple threads. +class RefCountedIntraProcessRendezvous : public Rendezvous { + public: + explicit RefCountedIntraProcessRendezvous(const DeviceMgr* device_mgr); + + // Implementation of RendezvousInterface methods. Status Send(const ParsedKey& key, const Rendezvous::Args& args, const Tensor& val, const bool is_dead) override; - - // This method is called only by the RecvOp. It tests to see - // whether the value will be produced by a local or remote device - // and handles accordingly. In the local case it forwards to - // local_, in the remote case it initiates an RPC request. void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args, DoneCallback done) override; - void StartAbort(const Status& status) override; private: const DeviceMgr* device_mgr_; - Rendezvous* local_; // Owns a Ref on this object. + LocalRendezvous local_; - mutable mutex mu_; + ~RefCountedIntraProcessRendezvous() override; - // Status given by StartAbort() if any. - Status status_ GUARDED_BY(mu_); + TF_DISALLOW_COPY_AND_ASSIGN(RefCountedIntraProcessRendezvous); +}; - ~IntraProcessRendezvous() override; +// RefCountedIntraProcessRendezvous is aliased to IntraProcessRendezvous for +// backwards compatibility with existing users. +using IntraProcessRendezvous = RefCountedIntraProcessRendezvous; - // Parses "key" into "parsed". If "is_src" is true, checks that the - // rendezvous key's source is in this process. If "is_src" is false, - // checks that the rendezvous key's destination is in this process. - Status ParseKey(const string& key, bool is_src, - Rendezvous::ParsedKey* parsed); +// Non-reference-counted implementation that may be stack-allocated for +// performance. +// +// Prefer to use PrivateIntraProcessRendezvous in new code. +class PrivateIntraProcessRendezvous : public RendezvousInterface { + public: + explicit PrivateIntraProcessRendezvous(const DeviceMgr* device_mgr); + ~PrivateIntraProcessRendezvous() override; - // Callback handling the case when a rendezvous has been - // accomplished in local_ and the consumer is local to this process. - // Tensor "in" will be copied into "out". The key "parsed" encodes - // the src and dst devices. - typedef std::function StatusCallback; - void SameWorkerRecvDone(const Rendezvous::ParsedKey& parsed, - const Rendezvous::Args& send_args, - const Rendezvous::Args& recv_args, const Tensor& in, - Tensor* out, StatusCallback done); + // Implementation of RendezvousInterface methods. 
+ Status Send(const ParsedKey& key, const Rendezvous::Args& args, + const Tensor& val, const bool is_dead) override; + void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args, + DoneCallback done) override; + void StartAbort(const Status& status) override; - TF_DISALLOW_COPY_AND_ASSIGN(IntraProcessRendezvous); + private: + const DeviceMgr* device_mgr_; + LocalRendezvous local_; + + TF_DISALLOW_COPY_AND_ASSIGN(PrivateIntraProcessRendezvous); }; } // end namespace tensorflow diff --git a/tensorflow/core/common_runtime/rendezvous_util.cc b/tensorflow/core/common_runtime/rendezvous_util.cc index 43ca3f1e3e0..df3e9a2452d 100644 --- a/tensorflow/core/common_runtime/rendezvous_util.cc +++ b/tensorflow/core/common_runtime/rendezvous_util.cc @@ -20,7 +20,7 @@ limitations under the License. namespace tensorflow { Status SendTensorsToRendezvous( - Rendezvous* rendezvous, DeviceContext* device_context, + RendezvousInterface* rendezvous, DeviceContext* device_context, const std::vector& alloc_attrs, const std::vector& keys, gtl::ArraySlice tensors_to_send) { if (keys.size() != tensors_to_send.size()) { @@ -54,7 +54,7 @@ Status SendTensorsToRendezvous( } void RecvOutputsFromRendezvousAsync( - Rendezvous* rendezvous, DeviceContext* device_context, + RendezvousInterface* rendezvous, DeviceContext* device_context, const std::vector& alloc_attrs, const std::vector& keys, std::vector* received_tensors, StatusCallback done) { @@ -118,7 +118,8 @@ void RecvOutputsFromRendezvousAsync( status_cb->Unref(); } -Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out, +Status RecvOutputsFromRendezvous(RendezvousInterface* rendezvous, + NamedTensors* out, const Rendezvous::Args& args) { // Receives values requested by the caller. Rendezvous::ParsedKey parsed; diff --git a/tensorflow/core/common_runtime/rendezvous_util.h b/tensorflow/core/common_runtime/rendezvous_util.h index deb9a7c8225..fe95dc0ef57 100644 --- a/tensorflow/core/common_runtime/rendezvous_util.h +++ b/tensorflow/core/common_runtime/rendezvous_util.h @@ -31,7 +31,7 @@ typedef std::function StatusCallback; // allocated. `alloc_attrs` should either be {} or should match the length of // `keys`. Status SendTensorsToRendezvous( - Rendezvous* rendezvous, DeviceContext* device_context, + RendezvousInterface* rendezvous, DeviceContext* device_context, const std::vector& alloc_attrs, const std::vector& keys, gtl::ArraySlice tensors_to_send); @@ -40,12 +40,13 @@ Status SendTensorsToRendezvous( // information as how to store the received tensors. Should be {} or match the // length of `keys`. 
void RecvOutputsFromRendezvousAsync( - Rendezvous* rendezvous, DeviceContext* device_context, + RendezvousInterface* rendezvous, DeviceContext* device_context, const std::vector& alloc_attrs, const std::vector& keys, std::vector* received_tensors, StatusCallback done); -Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out, +Status RecvOutputsFromRendezvous(RendezvousInterface* rendezvous, + NamedTensors* out, const Rendezvous::Args& args); } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc b/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc index 55b16b2587f..0dfcd82d737 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc +++ b/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc @@ -277,7 +277,7 @@ void RemoteCopyNode::StartRecv(StatusCallback done) { done(status); return; } - status = captured_state_->dst()->SetTensor(outputs[0]); + status = captured_state_->dst()->SetTensor(std::move(outputs[0])); done(status); } else { // Handles captured_state_->dst_ internally. diff --git a/tensorflow/core/framework/cancellation.cc b/tensorflow/core/framework/cancellation.cc index 35c4be63ce4..a91442fcbad 100644 --- a/tensorflow/core/framework/cancellation.cc +++ b/tensorflow/core/framework/cancellation.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/core/framework/cancellation.h" +#include + #include "absl/memory/memory.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" @@ -29,20 +31,13 @@ CancellationManager::CancellationManager() next_cancellation_token_(0) {} CancellationManager::CancellationManager(CancellationManager* parent) - : is_cancelling_(false), - is_cancelled_(false), - next_cancellation_token_(0), - parent_(parent), - parent_token_(parent->get_cancellation_token()) { - bool registered = parent->RegisterCallback(parent_token_, - [this]() { this->StartCancel(); }); - if (!registered) { - is_cancelled_ = true; - } + : is_cancelling_(false), next_cancellation_token_(0), parent_(parent) { + is_cancelled_ = parent->RegisterChild(this); } void CancellationManager::StartCancel() { gtl::FlatMap callbacks_to_run; + std::forward_list children_to_cancel; Notification* cancelled_notification = nullptr; { mutex_lock l(mu_); @@ -52,6 +47,16 @@ void CancellationManager::StartCancel() { is_cancelling_ = true; if (state_) { std::swap(state_->callbacks, callbacks_to_run); + + // Remove all children from the list of children. + CancellationManager* child = state_->first_child; + while (child != nullptr) { + children_to_cancel.push_front(child); + child->is_removed_from_parent_ = true; + child = child->next_sibling_; + } + state_->first_child = nullptr; + cancelled_notification = &state_->cancelled_notification; } } @@ -63,6 +68,9 @@ void CancellationManager::StartCancel() { for (auto key_and_value : callbacks_to_run) { key_and_value.second(); } + for (CancellationManager* child : children_to_cancel) { + child->StartCancel(); + } { mutex_lock l(mu_); is_cancelling_ = false; @@ -113,6 +121,65 @@ bool CancellationManager::DeregisterCallback(CancellationToken token) { } } +bool CancellationManager::RegisterChild(CancellationManager* child) { + mutex_lock l(mu_); + if (is_cancelled_.load(std::memory_order_relaxed) || is_cancelling_) { + child->is_removed_from_parent_ = true; + return true; + } + + if (!state_) { + state_ = absl::make_unique(); + } + + // Push `child` onto the front of the list of children. 
+ CancellationManager* current_head = state_->first_child; + state_->first_child = child; + child->prev_sibling_ = nullptr; + child->next_sibling_ = current_head; + if (current_head) { + current_head->prev_sibling_ = child; + } + + return false; +} + +void CancellationManager::DeregisterChild(CancellationManager* child) { + DCHECK_EQ(child->parent_, this); + Notification* cancelled_notification = nullptr; + { + mutex_lock l(mu_); + if (!child->is_removed_from_parent_) { + // Remove the child from this manager's list of children. + DCHECK(state_); + + if (child->prev_sibling_ == nullptr) { + // The child was at the head of the list. + DCHECK_EQ(state_->first_child, child); + state_->first_child = child->next_sibling_; + } else { + child->prev_sibling_->next_sibling_ = child->next_sibling_; + } + + if (child->next_sibling_ != nullptr) { + child->next_sibling_->prev_sibling_ = child->prev_sibling_; + } + + child->is_removed_from_parent_ = true; + } + if (is_cancelling_) { + cancelled_notification = &state_->cancelled_notification; + } + } + + // Wait for an ongoing call to StartCancel() to finish. This wait ensures that + // the caller of DeregisterChild does not return immediately and free a child + // that may currently be being cancelled by StartCancel(). + if (cancelled_notification) { + cancelled_notification->WaitForNotification(); + } +} + bool CancellationManager::TryDeregisterCallback(CancellationToken token) { mutex_lock lock(mu_); if (is_cancelled_ || is_cancelling_) { @@ -127,7 +194,7 @@ bool CancellationManager::TryDeregisterCallback(CancellationToken token) { CancellationManager::~CancellationManager() { if (parent_) { - parent_->DeregisterCallback(parent_token_); + parent_->DeregisterChild(this); } if (state_) { StartCancel(); diff --git a/tensorflow/core/framework/cancellation.h b/tensorflow/core/framework/cancellation.h index 7a98ee992a9..3e1727ae54a 100644 --- a/tensorflow/core/framework/cancellation.h +++ b/tensorflow/core/framework/cancellation.h @@ -147,14 +147,33 @@ class CancellationManager { struct State { Notification cancelled_notification; gtl::FlatMap callbacks; + + // If this CancellationManager has any children, this member points to the + // head of a doubly-linked list of its children. + CancellationManager* first_child = nullptr; // Not owned. }; + bool RegisterChild(CancellationManager* child); + void DeregisterChild(CancellationManager* child); + bool is_cancelling_; std::atomic_bool is_cancelled_; std::atomic next_cancellation_token_; CancellationManager* const parent_ = nullptr; // Not owned. - const CancellationToken parent_token_ = kInvalidToken; + + // If this CancellationManager is associated with a parent, this member will + // be set to `true` after this is removed from the parent's list of children. + bool is_removed_from_parent_ GUARDED_BY(parent_->mu_) = false; + + // If this CancellationManager is associated with a parent, these members form + // a doubly-linked list of that parent's children. + // + // These fields are valid only when `this->is_removed_from_parent_` is false. + CancellationManager* prev_sibling_ GUARDED_BY(parent_->mu_) = + nullptr; // Not owned. + CancellationManager* next_sibling_ GUARDED_BY(parent_->mu_) = + nullptr; // Not owned. 
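RegisterChild/DeregisterChild replace the old per-child callback registration with the intrusive doubly-linked list declared above: cancelling a parent walks its children once under the parent's mutex, and a destroyed child merely splices itself out. A short usage sketch, assuming only the public API visible in this patch (the parent-taking constructor, get_cancellation_token, RegisterCallback, StartCancel, IsCancelled); the callback body is illustrative.

#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

void CancelParentAndChildExample() {
  CancellationManager parent;
  CancellationManager child(&parent);  // Linked into the parent's child list.

  // Callbacks registered on the child fire when either the child or any of
  // its ancestors is cancelled.
  const CancellationToken token = child.get_cancellation_token();
  const bool registered =
      child.RegisterCallback(token, [] { LOG(INFO) << "child cancelled"; });
  CHECK(registered);

  parent.StartCancel();        // Cancels the parent and all live children.
  CHECK(child.IsCancelled());  // The child observed the cancellation.
}  // `child` is destroyed before `parent` and deregisters itself on the way out.

}  // namespace tensorflow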
mutex mu_; std::unique_ptr state_ GUARDED_BY(mu_); diff --git a/tensorflow/core/framework/cancellation_test.cc b/tensorflow/core/framework/cancellation_test.cc index df73526d594..e4994350ddd 100644 --- a/tensorflow/core/framework/cancellation_test.cc +++ b/tensorflow/core/framework/cancellation_test.cc @@ -15,7 +15,11 @@ limitations under the License. #include "tensorflow/core/framework/cancellation.h" +#include +#include +#include #include + #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/platform/test.h" @@ -199,4 +203,32 @@ TEST(Cancellation, Parent_AlreadyCancelled) { EXPECT_TRUE(child.IsCancelled()); } +TEST(Cancellation, Parent_RandomDestructionOrder) { + CancellationManager parent; + std::random_device rd; + std::mt19937 g(rd()); + + // To cover the linked-list codepaths, perform multiple randomized rounds of + // registering and deregistering children with `parent`. + for (int rounds = 0; rounds < 100; ++rounds) { + std::vector> children; + + // 1. Register a random number of children with the parent. + std::uniform_int_distribution dist(1, 9); + const size_t round_size = dist(rd); + for (size_t i = 0; i < round_size; ++i) { + children.push_back(absl::make_unique(&parent)); + EXPECT_FALSE(children.back()->IsCancelled()); + } + + // 2. Deregister the children in a random order. + std::vector destruction_order(round_size); + std::iota(destruction_order.begin(), destruction_order.end(), 0); + std::shuffle(destruction_order.begin(), destruction_order.end(), g); + for (size_t index : destruction_order) { + children[index].reset(); + } + } +} + } // namespace tensorflow diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc index f27fa75eb7d..a1625b48408 100644 --- a/tensorflow/core/framework/dataset.cc +++ b/tensorflow/core/framework/dataset.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/variant_op_registry.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/profiler/lib/traceme.h" @@ -389,7 +390,7 @@ Status DatasetBase::DatasetGraphDefBuilder::AddInputDataset( TF_RETURN_IF_ERROR(AddPlaceholder(t, output)); DCHECK_NE(ctx->input_list(), nullptr); ctx->input_list()->emplace_back((*output)->name(), std::move(t)); - LOG(WARNING) + LOG_EVERY_N_SEC(WARNING, 30) << "Input of " << dataset->DebugString() << " will not be optimized because the dataset does not implement the " "AsGraphDefInternal() method needed to apply optimizations."; diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 09a94d8b550..0e260d26592 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -687,7 +687,7 @@ class FunctionLibraryRuntime { // tensors to the remote TensorHandles in the default device. 
absl::optional op_id = absl::nullopt; - Rendezvous* rendezvous = nullptr; + RendezvousInterface* rendezvous = nullptr; CancellationManager* cancellation_manager = nullptr; CollectiveExecutor* collective_executor = nullptr; ScopedStepContainer* step_container = nullptr; diff --git a/tensorflow/core/framework/local_rendezvous.cc b/tensorflow/core/framework/local_rendezvous.cc new file mode 100644 index 00000000000..c21974552e7 --- /dev/null +++ b/tensorflow/core/framework/local_rendezvous.cc @@ -0,0 +1,300 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/local_rendezvous.h" + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/gtl/manual_constructor.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Represents a blocked Send() or Recv() call in the rendezvous. +struct LocalRendezvous::Item { + enum Type { kSend = 0, kRecv = 1 }; + + Item(Rendezvous::Args send_args, const Tensor& value, bool is_dead) + : Item(send_args, kSend) { + send_state.value.Init(value); + send_state.is_dead = is_dead; + } + + Item(Rendezvous::Args recv_args, Rendezvous::DoneCallback waiter, + CancellationToken cancellation_token) + : Item(recv_args, kRecv) { + recv_state.waiter.Init(std::move(waiter)); + recv_state.cancellation_token = cancellation_token; + } + + ~Item() { + if (args.device_context) { + args.device_context->Unref(); + } + if (type == kSend) { + send_state.value.Destroy(); + } else { + recv_state.waiter.Destroy(); + } + } + + const Rendezvous::Args args; + const Type type; + + // Link to next item in an ItemQueue. + Item* next = nullptr; + + // The validity of `send_state` or `recv_state` is determined by `type == + // kSend` or `type == kRecv` respectively. + union { + struct { + ManualConstructor value; + bool is_dead; + } send_state; + struct { + ManualConstructor waiter; + CancellationToken cancellation_token; + } recv_state; + }; + + private: + Item(Rendezvous::Args args, Type type) : args(args), type(type) { + if (args.device_context) { + args.device_context->Ref(); + } + } +}; + +void LocalRendezvous::ItemQueue::push_back(Item* item) { + if (TF_PREDICT_TRUE(head == nullptr)) { + // The queue is empty. 
+ head = item; + tail = item; + } else { + DCHECK_EQ(tail->type, item->type); + tail->next = item; + tail = item; + } +} + +LocalRendezvous::~LocalRendezvous() { + if (!table_.empty()) { + StartAbort(errors::Cancelled("LocalRendezvous deleted")); + } +} + +namespace { +uint64 KeyHash(const StringPiece& k) { return Hash64(k.data(), k.size()); } +} // namespace + +Status LocalRendezvous::Send(const Rendezvous::ParsedKey& key, + const Rendezvous::Args& send_args, + const Tensor& val, const bool is_dead) { + uint64 key_hash = KeyHash(key.FullKey()); + DVLOG(2) << "Send " << this << " " << key_hash << " " << key.FullKey(); + + mu_.lock(); + if (!status_.ok()) { + // Rendezvous has been aborted. + Status s = status_; + mu_.unlock(); + return s; + } + + ItemQueue* queue = &table_[key_hash]; + if (queue->head == nullptr || queue->head->type == Item::kSend) { + // There is no waiter for this message. Append the message + // into the queue. The waiter will pick it up when arrives. + // Only send-related fields need to be filled. + // TODO(b/143786186): Investigate moving the allocation of `Item` outside + // the lock. + DVLOG(2) << "Enqueue Send Item (key:" << key.FullKey() << "). "; + queue->push_back(new Item(send_args, val, is_dead)); + mu_.unlock(); + return Status::OK(); + } + + DVLOG(2) << "Consume Recv Item (key:" << key.FullKey() << "). "; + // There is an earliest waiter to consume this message. + Item* item = queue->head; + + // Delete the queue when the last element has been consumed. + if (item->next == nullptr) { + DVLOG(2) << "Clean up Send/Recv queue (key:" << key.FullKey() << "). "; + table_.erase(key_hash); + } else { + queue->head = item->next; + } + mu_.unlock(); + + // Notify the waiter by invoking its done closure, outside the + // lock. + DCHECK_EQ(item->type, Item::kRecv); + (*item->recv_state.waiter)(Status::OK(), send_args, item->args, val, is_dead); + delete item; + return Status::OK(); +} + +void LocalRendezvous::RecvAsync(const Rendezvous::ParsedKey& key, + const Rendezvous::Args& recv_args, + Rendezvous::DoneCallback done) { + uint64 key_hash = KeyHash(key.FullKey()); + DVLOG(2) << "Recv " << this << " " << key_hash << " " << key.FullKey(); + + mu_.lock(); + if (!status_.ok()) { + // Rendezvous has been aborted. + Status s = status_; + mu_.unlock(); + done(s, Rendezvous::Args(), recv_args, Tensor(), false); + return; + } + + ItemQueue* queue = &table_[key_hash]; + if (queue->head == nullptr || queue->head->type == Item::kRecv) { + // There is no message to pick up. + // Only recv-related fields need to be filled. + CancellationManager* cm = recv_args.cancellation_manager; + CancellationToken token = CancellationManager::kInvalidToken; + bool already_cancelled = false; + if (cm != nullptr) { + token = cm->get_cancellation_token(); + already_cancelled = !cm->RegisterCallback(token, [this, token, key_hash] { + Item* item = nullptr; + { + mutex_lock l(mu_); + ItemQueue* queue = &table_[key_hash]; + // Find an item in the queue with a cancellation token that matches + // `token`, and remove it. + if (queue->head != nullptr && queue->head->type == Item::kRecv) { + for (Item *prev = nullptr, *curr = queue->head; curr != nullptr; + prev = curr, curr = curr->next) { + if (curr->recv_state.cancellation_token == token) { + item = curr; + if (queue->head->next == nullptr) { + // We have a single-element queue, so we can erase it from + // the table. + table_.erase(key_hash); + } else { + // Remove the current item from the queue. 
+ if (curr == queue->head) { + DCHECK_EQ(prev, nullptr); + queue->head = curr->next; + } else { + DCHECK_NE(prev, nullptr); + prev->next = curr->next; + } + if (queue->tail == curr) { + queue->tail = prev; + } + } + break; + } + } + } + } + + if (item != nullptr) { + (*item->recv_state.waiter)( + StatusGroup::MakeDerived( + errors::Cancelled("RecvAsync is cancelled.")), + Rendezvous::Args(), item->args, Tensor(), /*is_dead=*/false); + delete item; + } + }); + } + if (already_cancelled) { + mu_.unlock(); + done(StatusGroup::MakeDerived( + errors::Cancelled("RecvAsync is cancelled.")), + Rendezvous::Args(), recv_args, Tensor(), /*is_dead=*/false); + return; + } + + DVLOG(2) << "Enqueue Recv Item (key:" << key.FullKey() << "). "; + + // TODO(b/143786186): Investigate moving the allocation of `Item` outside + // the lock. + if (cm != nullptr) { + // NOTE(mrry): We must wrap `done` with code that deregisters the + // cancellation callback before calling the `done` callback, because the + // cancellation manager may no longer be live after `done` is called. + queue->push_back(new Item( + recv_args, + [cm, token, done = std::move(done)]( + const Status& s, const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, const Tensor& v, bool dead) { + cm->TryDeregisterCallback(token); + done(s, send_args, recv_args, v, dead); + }, + token)); + } else { + queue->push_back(new Item(recv_args, std::move(done), token)); + } + + mu_.unlock(); + return; + } + + DVLOG(2) << "Consume Send Item (key:" << key.FullKey() << "). "; + // A message has already arrived and is queued in the table under + // this key. Consumes the message and invokes the done closure. + Item* item = queue->head; + + // Delete the queue when the last element has been consumed. + if (item->next == nullptr) { + DVLOG(2) << "Clean up Send/Recv queue (key:" << key.FullKey() << "). "; + table_.erase(key_hash); + } else { + queue->head = item->next; + } + mu_.unlock(); + + // Invoke done() without holding the table lock. + DCHECK_EQ(item->type, Item::kSend); + done(Status::OK(), item->args, recv_args, *item->send_state.value, + item->send_state.is_dead); + delete item; +} + +void LocalRendezvous::StartAbort(const Status& status) { + CHECK(!status.ok()); + Table table; + { + mutex_lock l(mu_); + status_.Update(status); + table_.swap(table); + } + for (auto& p : table) { + Item* item = p.second.head; + while (item != nullptr) { + if (item->type == Item::kRecv) { + (*item->recv_state.waiter)(status, Rendezvous::Args(), + Rendezvous::Args(), Tensor(), false); + } + Item* to_delete = item; + item = item->next; + delete to_delete; + } + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/local_rendezvous.h b/tensorflow/core/framework/local_rendezvous.h new file mode 100644 index 00000000000..07c52712de7 --- /dev/null +++ b/tensorflow/core/framework/local_rendezvous.h @@ -0,0 +1,75 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_LOCAL_RENDEZVOUS_H_ +#define TENSORFLOW_CORE_FRAMEWORK_LOCAL_RENDEZVOUS_H_ + +#include "tensorflow/core/framework/rendezvous.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Implements the basic logic of matching Send and Recv operations. See +// RendezvousInterface for more details. +// +// NOTE: Most users will use a class that wraps LocalRendezvous, such as +// IntraProcessRendezvous or RemoteRendezvous. This class does not implement +// RendezvousInterface because virtual dispatch to LocalRendezvous methods +// is not expected to be needed. +class LocalRendezvous { + public: + LocalRendezvous() = default; + ~LocalRendezvous(); + + Status Send(const Rendezvous::ParsedKey& key, + const Rendezvous::Args& send_args, const Tensor& val, + const bool is_dead); + void RecvAsync(const Rendezvous::ParsedKey& key, + const Rendezvous::Args& recv_args, + Rendezvous::DoneCallback done); + void StartAbort(const Status& status); + + private: + struct Item; + + // By invariant, the item queue under each key is of the form + // [item.type == kSend]* meaning each item is a sent message. + // or + // [item.type == kRecv]* meaning each item is a waiter. + struct ItemQueue { + void push_back(Item* item); + + Item* head = nullptr; + Item* tail = nullptr; + }; + + typedef gtl::FlatMap Table; + + // TODO(zhifengc): shard table_. + mutex mu_; + Table table_ GUARDED_BY(mu_); + Status status_ GUARDED_BY(mu_); + + TF_DISALLOW_COPY_AND_ASSIGN(LocalRendezvous); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_LOCAL_RENDEZVOUS_H_ diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index 66bb57f736b..959be0781be 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -277,10 +277,11 @@ OpKernelContext::OpKernelContext(Params* params) params, static_cast(params->op_kernel->output_types().size())) {} OpKernelContext::OpKernelContext(Params* params, int num_outputs) - : params_(params), - outputs_(num_outputs), - temp_memory_allocated_(0), - persistent_memory_allocated_(0) { + : params_(params), outputs_(num_outputs) { + if (params_->record_tensor_accesses || params_->track_allocations) { + tracking_state_ = absl::make_unique(); + } + params_->ensure_eigen_gpu_device(); if (params_->eigen_gpu_device != nullptr) { Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes()); @@ -291,9 +292,6 @@ OpKernelContext::OpKernelContext(Params* params, int num_outputs) SetStatus(s); } } - if (params_->record_tensor_accesses) { - referenced_tensors_.Init(); - } } OpKernelContext::~OpKernelContext() { @@ -302,12 +300,12 @@ OpKernelContext::~OpKernelContext() { delete value.tensor; } } - if (params_->record_tensor_accesses) referenced_tensors_.Destroy(); - if (params_->track_allocations && !wrapped_allocators_.empty()) { + if (params_->track_allocations && + !tracking_state_->wrapped_allocators.empty()) { LOG(WARNING) << "OpKernelContext is tracking allocations but they are not " << "being consumed by the StepStatsCollector."; - for (auto& wrapped_alloator : wrapped_allocators_) { - wrapped_alloator.second->GetRecordsAndUnRef(); + for (auto& 
wrapped_allocator : tracking_state_->wrapped_allocators) { + wrapped_allocator.second->GetRecordsAndUnRef(); } } } @@ -321,15 +319,17 @@ Allocator* OpKernelContext::get_allocator(AllocatorAttributes attr) { allocator = params_->device->GetAllocator(attr); } if (TF_PREDICT_FALSE(track_allocations())) { - mutex_lock lock(mu_); - for (const auto& wrapped : wrapped_allocators_) { + DCHECK(tracking_state_); + mutex_lock lock(tracking_state_->mu); + for (const auto& wrapped : tracking_state_->wrapped_allocators) { if (wrapped.first == allocator) { return wrapped.second; } } TrackingAllocator* wrapped_allocator = new TrackingAllocator(allocator, params_->track_allocations); - wrapped_allocators_.push_back(std::make_pair(allocator, wrapped_allocator)); + tracking_state_->wrapped_allocators.push_back( + std::make_pair(allocator, wrapped_allocator)); return wrapped_allocator; } else { return allocator; @@ -341,9 +341,10 @@ void OpKernelContext::SetStatus(const Status& status) { } void OpKernelContext::really_record_tensor_reference(const Tensor& tensor) { - mutex_lock l(mu_); + DCHECK(tracking_state_); + mutex_lock l(tracking_state_->mu); // Keep a reference to the underlying memory around. - referenced_tensors_->Add(tensor); + tracking_state_->referenced_tensors.Add(tensor); } Status OpKernelContext::input(StringPiece name, const Tensor** tensor) { @@ -804,8 +805,9 @@ Status OpKernelContext::allocate_temp( record_temp_memory_allocation(alloc_size, *out_temp); } } else if (record_memory_consumption_) { - mutex_lock l(stats_mu_); - temp_memory_allocated_ += out_temp->TotalBytes(); + DCHECK(tracking_state_); + mutex_lock l(tracking_state_->stats_mu); + tracking_state_->temp_memory_allocated += out_temp->TotalBytes(); } return s; } @@ -917,20 +919,18 @@ void OpKernelContext::set_output(int index, const Tensor& tensor) { record_tensor_reference(tensor); outputs_[index] = TensorValue(new Tensor(tensor)); if (track_allocations() && tensor.TotalBytes() > 0) { - mutex_lock l(stats_mu_); - if (!temp_tensor_buffer_and_size_) { - return; - } + DCHECK(tracking_state_); + mutex_lock l(tracking_state_->stats_mu); const auto it = std::find_if( - temp_tensor_buffer_and_size_->begin(), - temp_tensor_buffer_and_size_->end(), + tracking_state_->temp_tensor_buffer_and_size.begin(), + tracking_state_->temp_tensor_buffer_and_size.end(), [&tensor](const std::pair& e) { return e.first == static_cast(tensor.tensor_data().data()); }); - if (it != temp_tensor_buffer_and_size_->end()) { - temp_memory_allocated_ -= it->second; - temp_tensor_buffer_and_size_->erase(it); + if (it != tracking_state_->temp_tensor_buffer_and_size.end()) { + tracking_state_->temp_memory_allocated -= it->second; + tracking_state_->temp_tensor_buffer_and_size.erase(it); } } } @@ -1000,57 +1000,67 @@ Status OpKernelContext::MatchSignature(const DataTypeSlice expected_inputs, void OpKernelContext::record_temp_memory_allocation(int64 size, const Tensor& t) { - mutex_lock l(stats_mu_); - temp_memory_allocated_ += size; - if (!temp_tensor_buffer_and_size_) { - temp_tensor_buffer_and_size_.reset( - new gtl::InlinedVector, 2>()); + if (tracking_state_) { + mutex_lock l(tracking_state_->stats_mu); + tracking_state_->temp_memory_allocated += size; + tracking_state_->temp_tensor_buffer_and_size.emplace_back( + static_cast(t.tensor_data().data()), size); } - temp_tensor_buffer_and_size_->emplace_back( - static_cast(t.tensor_data().data()), size); } int64 OpKernelContext::temp_memory_allocated() const { - mutex_lock l(stats_mu_); - return temp_memory_allocated_; + 
if (tracking_state_) { + mutex_lock l(tracking_state_->stats_mu); + return tracking_state_->temp_memory_allocated; + } else { + return 0; + } } void OpKernelContext::record_persistent_memory_allocation(int64 size, int64 alloc_id) { - mutex_lock l(stats_mu_); - persistent_memory_allocated_ += size; - if (alloc_id >= 0) { - if (!persistent_alloc_ids_) { - persistent_alloc_ids_.reset(new gtl::InlinedVector()); + if (tracking_state_) { + mutex_lock l(tracking_state_->stats_mu); + tracking_state_->persistent_memory_allocated += size; + if (alloc_id >= 0) { + tracking_state_->persistent_alloc_ids.push_back(alloc_id); } - persistent_alloc_ids_->push_back(alloc_id); } } int64 OpKernelContext::persistent_memory_allocated() const { - mutex_lock l(stats_mu_); - return persistent_memory_allocated_; + if (tracking_state_) { + mutex_lock l(tracking_state_->stats_mu); + return tracking_state_->persistent_memory_allocated; + } else { + return 0; + } } std::vector OpKernelContext::persistent_alloc_ids() const { - mutex_lock l(stats_mu_); - if (persistent_alloc_ids_) { - return std::vector(persistent_alloc_ids_->begin(), - persistent_alloc_ids_->end()); + if (tracking_state_) { + mutex_lock l(tracking_state_->stats_mu); + return std::vector(tracking_state_->persistent_alloc_ids.begin(), + tracking_state_->persistent_alloc_ids.end()); } else { return std::vector(); } } void OpKernelContext::clear_recorded_memory() { - mutex_lock l(stats_mu_); - temp_memory_allocated_ = 0; - persistent_memory_allocated_ = 0; - if (temp_tensor_buffer_and_size_) { - temp_tensor_buffer_and_size_->clear(); + if (tracking_state_) { + mutex_lock l(tracking_state_->stats_mu); + tracking_state_->temp_memory_allocated = 0; + tracking_state_->persistent_memory_allocated = 0; + tracking_state_->temp_tensor_buffer_and_size.clear(); + tracking_state_->persistent_alloc_ids.clear(); } - if (persistent_alloc_ids_) { - persistent_alloc_ids_->clear(); +} + +void OpKernelContext::set_record_memory_consumption(bool v) { + record_memory_consumption_ = v; + if (v && !tracking_state_) { + tracking_state_ = absl::make_unique(); } } diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index 8372359e7ae..7f9895f7771 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -672,7 +672,7 @@ class OpKernelContext { // Mechanism used by this op kernel invocation to communicate with // computations running on other devices. - Rendezvous* rendezvous = nullptr; + RendezvousInterface* rendezvous = nullptr; const std::function* create_rendezvous; @@ -726,8 +726,8 @@ class OpKernelContext { const int* forward_from_array = nullptr; // For tracking actively running deferred ops. - std::function inc_num_deferred_ops_function = []() {}; - std::function dec_num_deferred_ops_function = []() {}; + std::function inc_num_deferred_ops_function; + std::function dec_num_deferred_ops_function; }; // params must outlive the OpKernelContext. @@ -1090,9 +1090,11 @@ class OpKernelContext { } gtl::InlinedVector ConsumeWrappedAllocators() { - mutex_lock lock(mu_); gtl::InlinedVector retrieved; - retrieved.swap(wrapped_allocators_); + if (tracking_state_) { + mutex_lock lock(tracking_state_->mu); + retrieved.swap(tracking_state_->wrapped_allocators); + } return retrieved; } @@ -1100,7 +1102,7 @@ class OpKernelContext { // // An op kernel communicates with outside environment through // Rendezvous Send() and Recv(). 
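Note that dropping the default-constructed lambdas from Params (above) means an OpKernelContext::Params no longer carries two live std::function targets that most kernels never invoke; the accessors in the next hunk substitute a no-op only when a caller actually asks for the hook. A small stand-alone sketch of that pattern, with hypothetical names (Hooks, StartedHookOrNoop) rather than the TensorFlow ones:

#include <functional>
#include <iostream>

// An empty std::function means "not set"; the fallback lambda is materialized
// only at the point of use, mirroring inc/dec_num_deferred_ops_function().
struct Hooks {
  std::function<void()> on_deferred_op_started;  // Empty by default.
};

std::function<void()> StartedHookOrNoop(const Hooks& hooks) {
  return hooks.on_deferred_op_started ? hooks.on_deferred_op_started
                                      : []() {};
}

int main() {
  Hooks hooks;
  StartedHookOrNoop(hooks)();  // Safe: resolves to the no-op lambda.

  hooks.on_deferred_op_started = [] { std::cout << "deferred op started\n"; };
  StartedHookOrNoop(hooks)();  // Now runs the registered hook.
  return 0;
}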
- Rendezvous* rendezvous() const { return params_->rendezvous; } + RendezvousInterface* rendezvous() const { return params_->rendezvous; } Status create_rendezvous(const int64 step_id, const DeviceMgr* device_mgr, Rendezvous** r) const { return (*params_->create_rendezvous)(step_id, device_mgr, r); @@ -1233,27 +1235,29 @@ class OpKernelContext { // Records temp memory allocation. Tensor object is recorded to identify the // case where temp memory is used as output memory. void record_temp_memory_allocation(int64 size, const Tensor& t) - LOCKS_EXCLUDED(stats_mu_); + LOCKS_EXCLUDED(tracking_state_->stats_mu); // Returns recorded size of temporary memory; - int64 temp_memory_allocated() const LOCKS_EXCLUDED(stats_mu_); + int64 temp_memory_allocated() const LOCKS_EXCLUDED(tracking_state_->stats_mu); // Records persistent memory allocation, size can be negative indicating // deallocation. void record_persistent_memory_allocation(int64 size, int64 alloc_id = -1) - LOCKS_EXCLUDED(stats_mu_); + LOCKS_EXCLUDED(tracking_state_->stats_mu); // Returns recorded size and ids of persistent memory. - int64 persistent_memory_allocated() const LOCKS_EXCLUDED(stats_mu_); + int64 persistent_memory_allocated() const + LOCKS_EXCLUDED(tracking_state_->stats_mu); - std::vector persistent_alloc_ids() const LOCKS_EXCLUDED(stats_mu_); + std::vector persistent_alloc_ids() const + LOCKS_EXCLUDED(tracking_state_->stats_mu); // Resets counters for temp and persistent memory and recorded ids. - void clear_recorded_memory() LOCKS_EXCLUDED(stats_mu_); + void clear_recorded_memory() LOCKS_EXCLUDED(tracking_state_->stats_mu); bool input_is_ref(int index) const; - void set_record_memory_consumption(bool v) { record_memory_consumption_ = v; } + void set_record_memory_consumption(bool v); // Used by OpKernel implementations to track actively running deferred ops. // @@ -1267,10 +1271,14 @@ class OpKernelContext { // functions. It then must call these two functions in pairs, before and after // device execution, respectively. TF_MUST_USE_RESULT std::function inc_num_deferred_ops_function() { - return params_->inc_num_deferred_ops_function; + return params_->inc_num_deferred_ops_function + ? params_->inc_num_deferred_ops_function + : []() {}; } TF_MUST_USE_RESULT std::function dec_num_deferred_ops_function() { - return params_->dec_num_deferred_ops_function; + return params_->dec_num_deferred_ops_function + ? params_->dec_num_deferred_ops_function + : []() {}; } Allocator* get_allocator(AllocatorAttributes attr); @@ -1312,26 +1320,30 @@ class OpKernelContext { Status status_; friend class CollectiveExecutor; // for access to params_ Params* params_; // not owned - mutable mutex mu_; // mutable so const accessors can acquire the lock - gtl::InlinedVector wrapped_allocators_ GUARDED_BY(mu_); gtl::InlinedVector outputs_; // Keep track of calls to ScopedAllocator. // TODO(ayushd): change to absl::flat_hash_set. std::unique_ptr> allocated_scope_ids_; - // Constructed only if record_tensor_accesses>. - ManualConstructor referenced_tensors_ GUARDED_BY(mu_); - // The following data members are only used when allocation tracking is - // enabled. - mutable mutex stats_mu_; - int64 temp_memory_allocated_ GUARDED_BY(stats_mu_); - int64 persistent_memory_allocated_ GUARDED_BY(stats_mu_); - std::unique_ptr, 2>> - temp_tensor_buffer_and_size_ GUARDED_BY(stats_mu_); - std::unique_ptr> persistent_alloc_ids_ - GUARDED_BY(stats_mu_); + // enabled, memory consumption is being recorded, or tensor access is being + // recorded. 
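The OpKernelContext changes in this region fold the tracking members (wrapped allocators, referenced tensors, temp and persistent byte counters) into one TrackingState that is heap-allocated only when tracking is actually requested, so the common case pays for a single null unique_ptr instead of two mutexes and several containers. A generic sketch of that lazily-allocated-state pattern, using hypothetical names (Profiler, Stats) rather than the TensorFlow types:

#include <cstdint>
#include <memory>
#include <mutex>
#include <vector>

// Cheap to construct; the accessors degrade to a default value when tracking
// was never enabled, as temp_memory_allocated() does above.
class Profiler {
 public:
  void EnableTracking() {
    if (!stats_) stats_ = std::make_unique<Stats>();
  }

  void RecordAllocation(int64_t bytes) {
    if (!stats_) return;  // Tracking disabled: recording is a no-op.
    std::lock_guard<std::mutex> lock(stats_->mu);
    stats_->bytes_allocated += bytes;
    stats_->allocation_sizes.push_back(bytes);
  }

  int64_t bytes_allocated() const {
    if (!stats_) return 0;  // Mirrors the "return 0 when untracked" behavior.
    std::lock_guard<std::mutex> lock(stats_->mu);
    return stats_->bytes_allocated;
  }

 private:
  struct Stats {
    mutable std::mutex mu;
    int64_t bytes_allocated = 0;
    std::vector<int64_t> allocation_sizes;
  };
  std::unique_ptr<Stats> stats_;  // Null until EnableTracking() is called.
};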
+ struct TrackingState { + mutable mutex mu; + gtl::InlinedVector wrapped_allocators GUARDED_BY(mu); + + UniqueTensorReferences referenced_tensors GUARDED_BY(mu); + + mutable mutex stats_mu; + int64 temp_memory_allocated GUARDED_BY(stats_mu) = 0; + + int64 persistent_memory_allocated GUARDED_BY(stats_mu) = 0; + gtl::InlinedVector, 2> + temp_tensor_buffer_and_size GUARDED_BY(stats_mu); + gtl::InlinedVector persistent_alloc_ids GUARDED_BY(stats_mu); + }; + std::unique_ptr tracking_state_; TF_DISALLOW_COPY_AND_ASSIGN(OpKernelContext); }; @@ -1618,8 +1630,9 @@ inline void OpKernelContext::record_tensor_reference(const Tensor& tensor) { inline void OpKernelContext::retrieve_accessed_tensors( TensorReferenceVector* out_vector) { if (params_->record_tensor_accesses) { - mutex_lock l(mu_); - referenced_tensors_->FreezeAndReturnReferences(out_vector); + DCHECK(tracking_state_); + mutex_lock l(tracking_state_->mu); + tracking_state_->referenced_tensors.FreezeAndReturnReferences(out_vector); } } diff --git a/tensorflow/core/framework/rendezvous.cc b/tensorflow/core/framework/rendezvous.cc index ad3cf912d23..764f8995d02 100644 --- a/tensorflow/core/framework/rendezvous.cc +++ b/tensorflow/core/framework/rendezvous.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "tensorflow/core/framework/local_rendezvous.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/gtl/flatmap.h" @@ -113,10 +114,10 @@ Status Rendezvous::ParseKey(StringPiece key, ParsedKey* out) { return errors::InvalidArgument("Invalid rendezvous key: ", key); } -Rendezvous::~Rendezvous() {} +RendezvousInterface::~RendezvousInterface() {} -Status Rendezvous::Recv(const ParsedKey& key, const Args& recv_args, - Tensor* val, bool* is_dead, int64 timeout_ms) { +Status RendezvousInterface::Recv(const ParsedKey& key, const Args& recv_args, + Tensor* val, bool* is_dead, int64 timeout_ms) { Status ret; Notification n; RecvAsync(key, recv_args, @@ -141,308 +142,36 @@ Status Rendezvous::Recv(const ParsedKey& key, const Args& recv_args, return ret; } -Status Rendezvous::Recv(const ParsedKey& key, const Args& args, Tensor* val, - bool* is_dead) { +Status RendezvousInterface::Recv(const ParsedKey& key, const Args& args, + Tensor* val, bool* is_dead) { const int64 no_timeout = 0; return Recv(key, args, val, is_dead, no_timeout); } namespace { -class LocalRendezvousImpl : public Rendezvous { +class LocalRendezvousWrapper : public Rendezvous { public: - explicit LocalRendezvousImpl() {} + LocalRendezvousWrapper() = default; Status Send(const ParsedKey& key, const Args& send_args, const Tensor& val, const bool is_dead) override { - uint64 key_hash = KeyHash(key.FullKey()); - DVLOG(2) << "Send " << this << " " << key_hash << " " << key.FullKey(); - - mu_.lock(); - if (!status_.ok()) { - // Rendezvous has been aborted. - Status s = status_; - mu_.unlock(); - return s; - } - - ItemQueue* queue = &table_[key_hash]; - if (queue->head == nullptr || queue->head->type == Item::kSend) { - // There is no waiter for this message. Append the message - // into the queue. The waiter will pick it up when arrives. - // Only send-related fields need to be filled. - // TODO(b/143786186): Investigate moving the allocation of `Item` outside - // the lock. - DVLOG(2) << "Enqueue Send Item (key:" << key.FullKey() << "). 
"; - queue->push_back(new Item(send_args, val, is_dead)); - mu_.unlock(); - return Status::OK(); - } - - DVLOG(2) << "Consume Recv Item (key:" << key.FullKey() << "). "; - // There is an earliest waiter to consume this message. - Item* item = queue->head; - - // Delete the queue when the last element has been consumed. - if (item->next == nullptr) { - DVLOG(2) << "Clean up Send/Recv queue (key:" << key.FullKey() << "). "; - table_.erase(key_hash); - } else { - queue->head = item->next; - } - mu_.unlock(); - - // Notify the waiter by invoking its done closure, outside the - // lock. - DCHECK_EQ(item->type, Item::kRecv); - (*item->recv_state.waiter)(Status::OK(), send_args, item->args, val, - is_dead); - delete item; - return Status::OK(); + return impl_.Send(key, send_args, val, is_dead); } void RecvAsync(const ParsedKey& key, const Args& recv_args, DoneCallback done) override { - uint64 key_hash = KeyHash(key.FullKey()); - DVLOG(2) << "Recv " << this << " " << key_hash << " " << key.FullKey(); - - mu_.lock(); - if (!status_.ok()) { - // Rendezvous has been aborted. - Status s = status_; - mu_.unlock(); - done(s, Args(), recv_args, Tensor(), false); - return; - } - - ItemQueue* queue = &table_[key_hash]; - if (queue->head == nullptr || queue->head->type == Item::kRecv) { - // There is no message to pick up. - // Only recv-related fields need to be filled. - CancellationManager* cm = recv_args.cancellation_manager; - CancellationToken token = CancellationManager::kInvalidToken; - bool already_cancelled = false; - if (cm != nullptr) { - token = cm->get_cancellation_token(); - already_cancelled = !cm->RegisterCallback(token, [this, token, - key_hash] { - Item* item = nullptr; - { - mutex_lock l(mu_); - ItemQueue* queue = &table_[key_hash]; - // Find an item in the queue with a cancellation token that matches - // `token`, and remove it. - if (queue->head != nullptr && queue->head->type == Item::kRecv) { - for (Item *prev = nullptr, *curr = queue->head; curr != nullptr; - prev = curr, curr = curr->next) { - if (curr->recv_state.cancellation_token == token) { - item = curr; - if (queue->head->next == nullptr) { - // We have a single-element queue, so we can erase it from - // the table. - table_.erase(key_hash); - } else { - // Remove the current item from the queue. - if (curr == queue->head) { - DCHECK_EQ(prev, nullptr); - queue->head = curr->next; - } else { - DCHECK_NE(prev, nullptr); - prev->next = curr->next; - } - if (queue->tail == curr) { - queue->tail = prev; - } - } - break; - } - } - } - } - - if (item != nullptr) { - (*item->recv_state.waiter)( - StatusGroup::MakeDerived( - errors::Cancelled("RecvAsync is cancelled.")), - Args(), item->args, Tensor(), /*is_dead=*/false); - delete item; - } - }); - } - if (already_cancelled) { - mu_.unlock(); - done(StatusGroup::MakeDerived( - errors::Cancelled("RecvAsync is cancelled.")), - Args(), recv_args, Tensor(), /*is_dead=*/false); - return; - } - - DVLOG(2) << "Enqueue Recv Item (key:" << key.FullKey() << "). "; - - // TODO(b/143786186): Investigate moving the allocation of `Item` outside - // the lock. - if (cm != nullptr) { - // NOTE(mrry): We must wrap `done` with code that deregisters the - // cancellation callback before calling the `done` callback, because the - // cancellation manager may no longer be live after `done` is called. 
- queue->push_back(new Item( - recv_args, - [cm, token, done = std::move(done)]( - const Status& s, const Args& send_args, const Args& recv_args, - const Tensor& v, bool dead) { - cm->TryDeregisterCallback(token); - done(s, send_args, recv_args, v, dead); - }, - token)); - } else { - queue->push_back(new Item(recv_args, std::move(done), token)); - } - - mu_.unlock(); - return; - } - - DVLOG(2) << "Consume Send Item (key:" << key.FullKey() << "). "; - // A message has already arrived and is queued in the table under - // this key. Consumes the message and invokes the done closure. - Item* item = queue->head; - - // Delete the queue when the last element has been consumed. - if (item->next == nullptr) { - DVLOG(2) << "Clean up Send/Recv queue (key:" << key.FullKey() << "). "; - table_.erase(key_hash); - } else { - queue->head = item->next; - } - mu_.unlock(); - - // Invoke done() without holding the table lock. - DCHECK_EQ(item->type, Item::kSend); - done(Status::OK(), item->args, recv_args, *item->send_state.value, - item->send_state.is_dead); - delete item; + impl_.RecvAsync(key, recv_args, std::move(done)); } - void StartAbort(const Status& status) override { - CHECK(!status.ok()); - Table table; - { - mutex_lock l(mu_); - status_.Update(status); - table_.swap(table); - } - for (auto& p : table) { - Item* item = p.second.head; - while (item != nullptr) { - if (item->type == Item::kRecv) { - (*item->recv_state.waiter)(status, Args(), Args(), Tensor(), false); - } - Item* to_delete = item; - item = item->next; - delete to_delete; - } - } - } + void StartAbort(const Status& status) override { impl_.StartAbort(status); } private: - typedef LocalRendezvousImpl ME; + LocalRendezvous impl_; - // Represents a blocked Send() or Recv() call in the rendezvous. - struct Item { - enum Type { kSend = 0, kRecv = 1 }; - - Item(Args send_args, const Tensor& value, bool is_dead) - : Item(send_args, kSend) { - send_state.value.Init(value); - send_state.is_dead = is_dead; - } - - Item(Args recv_args, DoneCallback waiter, - CancellationToken cancellation_token) - : Item(recv_args, kRecv) { - recv_state.waiter.Init(std::move(waiter)); - recv_state.cancellation_token = cancellation_token; - } - - ~Item() { - if (args.device_context) { - args.device_context->Unref(); - } - if (type == kSend) { - send_state.value.Destroy(); - } else { - recv_state.waiter.Destroy(); - } - } - - const Args args; - const Type type; - - // Link to next item in an ItemQueue. - Item* next = nullptr; - - // The validity of `send_state` or `recv_state` is determined by `type == - // kSend` or `type == kRecv` respectively. - union { - struct { - ManualConstructor value; - bool is_dead; - } send_state; - struct { - ManualConstructor waiter; - CancellationToken cancellation_token; - } recv_state; - }; - - private: - Item(Args args, Type type) : args(args), type(type) { - if (args.device_context) { - args.device_context->Ref(); - } - } - }; - - // We key the hash table by KeyHash of the Rendezvous::CreateKey string - static uint64 KeyHash(const StringPiece& k) { - return Hash64(k.data(), k.size()); - } - - // By invariant, the item queue under each key is of the form - // [item.type == kSend]* meaning each item is a sent message. - // or - // [item.type == kRecv]* meaning each item is a waiter. - struct ItemQueue { - void push_back(Item* item) { - if (TF_PREDICT_TRUE(head == nullptr)) { - // The queue is empty. 
- head = item; - tail = item; - } else { - DCHECK_EQ(tail->type, item->type); - tail->next = item; - tail = item; - } - } - - Item* head = nullptr; - Item* tail = nullptr; - }; - typedef gtl::FlatMap Table; - - // TODO(zhifengc): shard table_. - mutex mu_; - Table table_ GUARDED_BY(mu_); - Status status_ GUARDED_BY(mu_); - - ~LocalRendezvousImpl() override { - if (!table_.empty()) { - StartAbort(errors::Cancelled("LocalRendezvousImpl deleted")); - } - } - - TF_DISALLOW_COPY_AND_ASSIGN(LocalRendezvousImpl); + TF_DISALLOW_COPY_AND_ASSIGN(LocalRendezvousWrapper); }; } // namespace -Rendezvous* NewLocalRendezvous() { return new LocalRendezvousImpl(); } +Rendezvous* NewLocalRendezvous() { return new LocalRendezvousWrapper; } } // end namespace tensorflow diff --git a/tensorflow/core/framework/rendezvous.h b/tensorflow/core/framework/rendezvous.h index d6e910da991..b9172f63df6 100644 --- a/tensorflow/core/framework/rendezvous.h +++ b/tensorflow/core/framework/rendezvous.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_FRAMEWORK_RENDEZVOUS_H_ -#define TENSORFLOW_FRAMEWORK_RENDEZVOUS_H_ +#ifndef TENSORFLOW_CORE_FRAMEWORK_RENDEZVOUS_H_ +#define TENSORFLOW_CORE_FRAMEWORK_RENDEZVOUS_H_ #include @@ -44,7 +44,7 @@ namespace tensorflow { // been produced. A consumer has the choice of making a blocking call // or providing a callback: in either case, the consumer receives the // Tensor as soon as it is available. A producer never blocks. -class Rendezvous : public core::RefCounted { +class RendezvousInterface { public: struct Args { DeviceContext* device_context = nullptr; @@ -52,13 +52,6 @@ class Rendezvous : public core::RefCounted { CancellationManager* cancellation_manager = nullptr; // not owned. }; - // Constructs a rendezvous key for the tensor of "name" sent from - // "src_device" to "dst_device". The tensor is generated in the frame - // and iteration specified by "frame_iter". - static string CreateKey(const string& src_device, uint64 src_incarnation, - const string& dst_device, const string& name, - const FrameAndIter& frame_iter); - // Parses the key constructed by CreateKey and parse src/dst device // names into structures respectively. struct ParsedKey { @@ -81,7 +74,6 @@ class Rendezvous : public core::RefCounted { friend class RecvOp; string buf_; }; - static Status ParseKey(StringPiece key, ParsedKey* out); // The caller is a tensor producer and it sends a message (a tensor // "val" and a bool "is_dead") under the given "key". @@ -123,12 +115,28 @@ class Rendezvous : public core::RefCounted { virtual void StartAbort(const Status& status) = 0; protected: - ~Rendezvous() override; + virtual ~RendezvousInterface(); virtual bool is_cross_process() { return false; } friend class ProcessFunctionLibraryRuntime; }; +// A reference-counted implementation of RendezvousInterface. +// +// This class is used in cases where a rendezvous may be shared between multiple +// threads with no clear owner. +class Rendezvous : public RendezvousInterface, public core::RefCounted { + public: + // Constructs a rendezvous key for the tensor of "name" sent from + // "src_device" to "dst_device". The tensor is generated in the frame + // and iteration specified by "frame_iter". 
+ static string CreateKey(const string& src_device, uint64 src_incarnation, + const string& dst_device, const string& name, + const FrameAndIter& frame_iter); + + static Status ParseKey(StringPiece key, ParsedKey* out); +}; + // Returns a Rendezvous instance that is limited to use only by // producers and consumers in the local process. The caller assumes // ownership of one Ref() on the returned object. @@ -136,4 +144,4 @@ Rendezvous* NewLocalRendezvous(); } // end namespace tensorflow -#endif // TENSORFLOW_FRAMEWORK_RENDEZVOUS_H_ +#endif // TENSORFLOW_CORE_FRAMEWORK_RENDEZVOUS_H_ diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index aa0a6247312..b11df7e5d8a 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -306,7 +306,7 @@ class InferenceContext { // idx can be negative for an offset from end of dimensions. // idx must be in the range [-1 * s.rank, s.rank). DimensionHandle Dim(ShapeHandle s, int64 idx) { - if (s->rank_ == kUnknownRank) { + if (!s.Handle() || s->rank_ == kUnknownRank) { return UnknownDim(); } return DimKnownRank(s, idx); diff --git a/tensorflow/core/grappler/BUILD b/tensorflow/core/grappler/BUILD index 3f79c023caf..fd2ea4f893c 100644 --- a/tensorflow/core/grappler/BUILD +++ b/tensorflow/core/grappler/BUILD @@ -23,14 +23,20 @@ cc_library( hdrs = ["utils.h"], visibility = ["//visibility:public"], deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:graph", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - ], + ] + select({ + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + ], + "//conditions:default": [ + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + ], + }), ) tf_cc_test( diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index 8e60d391cd8..751bf952213 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -684,28 +684,28 @@ OpLevelCostEstimator::ConvolutionDimensionsFromInputs( int x_index, y_index, channel_index; const string& data_format = GetDataFormat(op_info); if (data_format == "NCHW") { - x_index = 2; - y_index = 3; channel_index = 1; + y_index = 2; + x_index = 3; } else { // Use NHWC. - x_index = 1; - y_index = 2; + y_index = 1; + x_index = 2; channel_index = 3; } const string& filter_format = GetFilterFormat(op_info); int filter_x_index, filter_y_index, in_channel_index, out_channel_index; if (filter_format == "HWIO") { - filter_x_index = 0; - filter_y_index = 1; + filter_y_index = 0; + filter_x_index = 1; in_channel_index = 2; out_channel_index = 3; } else { // Use OIHW - filter_x_index = 2; - filter_y_index = 3; - in_channel_index = 1; out_channel_index = 0; + in_channel_index = 1; + filter_y_index = 2; + filter_x_index = 3; } int64 batch = image_shape.dim(0).size(); int64 ix = image_shape.dim(x_index).size(); @@ -1311,9 +1311,9 @@ Costs OpLevelCostEstimator::PredictFusedConv2DBiasActivation( // TODO(varomodt): should we centralize the Conv2D input/output shapes? 
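The index swaps in op_level_cost_estimator.cc above are the substance of the fix: `x` is the width dimension and `y` the height dimension in both layouts, so NCHW (batch, channels, height, width) needs y_index = 2 and x_index = 3, while NHWC (batch, height, width, channels) needs y_index = 1 and x_index = 2; the old code had height and width transposed, which only changes the estimates when the spatial dimensions differ. A tiny self-contained illustration of the corrected mapping (the struct and function names are invented for the example):

#include <cassert>
#include <string>

// Hypothetical helper mirroring the corrected index logic in
// ConvolutionDimensionsFromInputs/OpDimensionsFromInputs above.
struct SpatialIndices {
  int channel_index;
  int y_index;  // height
  int x_index;  // width
};

SpatialIndices IndicesForDataFormat(const std::string& data_format) {
  if (data_format == "NCHW") {
    return {/*channel_index=*/1, /*y_index=*/2, /*x_index=*/3};
  }
  // NHWC: batch, height, width, channels.
  return {/*channel_index=*/3, /*y_index=*/1, /*x_index=*/2};
}

int main() {
  // For an NCHW tensor of shape {32, 64, 224, 112}: 64 channels,
  // height 224 at index 2, width 112 at index 3.
  const SpatialIndices nchw = IndicesForDataFormat("NCHW");
  assert(nchw.y_index == 2 && nchw.x_index == 3);

  const SpatialIndices nhwc = IndicesForDataFormat("NHWC");
  assert(nhwc.y_index == 1 && nhwc.x_index == 2);
  return 0;
}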
OpInfo::TensorProperties output; if (data_format == "NCHW") { - output = DescribeTensor(DT_FLOAT, {dims.batch, dims.oz, dims.ox, dims.oy}); + output = DescribeTensor(DT_FLOAT, {dims.batch, dims.oz, dims.oy, dims.ox}); } else if (data_format == "NHWC") { - output = DescribeTensor(DT_FLOAT, {dims.batch, dims.ox, dims.oy, dims.oz}); + output = DescribeTensor(DT_FLOAT, {dims.batch, dims.oy, dims.ox, dims.oz}); } // Add the operations the fused op always computes. @@ -1768,12 +1768,12 @@ OpLevelCostEstimator::OpDimensionsFromInputs( int x_index, y_index, channel_index; const string& data_format = GetDataFormat(op_info); if (data_format == "NCHW") { - x_index = 2; - y_index = 3; channel_index = 1; - } else { - x_index = 1; y_index = 2; + x_index = 3; + } else { + y_index = 1; + x_index = 2; channel_index = 3; } int64 batch = image_shape.dim(0).size(); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index d2ff480c29d..7f6940fb31d 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -1455,7 +1455,7 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { if (IsInPreserveSet(*node)) return false; if (IsConcat(*node) && node->attr().count("N") != 0) { const int n = node->attr().at("N").i(); - return n > 1; + return n > 1 && FirstNInputsAreUnique(*node, n); } else if ((IsSplit(*node) || IsSplitV(*node)) && node->attr().count("num_split") != 0) { const int num_split = node->attr().at("num_split").i(); @@ -1489,6 +1489,17 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { } private: + bool FirstNInputsAreUnique(const NodeDef& node, int n) const { + if (n > node.input_size()) return false; + absl::flat_hash_set unique_inputs; + const int start = node.op() == "Concat" ? 1 : 0; + const int end = start + n; + for (int i = start; i < end; ++i) { + unique_inputs.insert(node.input(i)); + } + return unique_inputs.size() == n; + } + // Returns the length of the common unary chain of ops that can be // hoisted to the other side of concat or split. 
Status FindCommonUnaryOpChain(const NodeDef& root_node, int* prefix_length, @@ -1525,7 +1536,7 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { Status HoistUnaryOpChain(const int prefix_length, const ChainLinkSet& tails, std::set* ctrl_inputs, NodeDef* root_node) { VLOG(3) << "Hoist unary op chain:" - << " root=" << root_node->name() + << " root=" << root_node->DebugString() << " prefix_length=" << prefix_length << " ctrl_inputs=[" << absl::StrJoin(*ctrl_inputs, ", ") << "]"; diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD index 8941d5552b6..7572141d415 100644 --- a/tensorflow/core/grappler/utils/BUILD +++ b/tensorflow/core/grappler/utils/BUILD @@ -386,11 +386,17 @@ cc_library( hdrs = ["transitive_fanin.h"], visibility = ["//visibility:public"], deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:utils", - ], + ] + select({ + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + ], + "//conditions:default": [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], + }), ) tf_cc_test( diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index f66102ab3ac..0872c0b0611 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -8210,4 +8210,72 @@ tf_cc_shared_object( ], ) -exports_files(["ops_testutil.h"]) +exports_files([ + "cwise_op_abs.cc", + "cwise_op_add_1.cc", + "cwise_op_add_2.cc", + "cwise_op_atan2.cc", + "cwise_op_cos.cc", + "cwise_op_div.cc", + "cwise_op_equal_to_1.cc", + "cwise_op_equal_to_2.cc", + "cwise_op_exp.cc", + "cwise_op_floor.cc", + "cwise_op_floor_div.cc", + "cwise_op_floor_mod.cc", + "cwise_op_gpu_add.cu.cc", + "cwise_op_gpu_atan2.cu.cc", + "cwise_op_gpu_cos.cu.cc", + "cwise_op_gpu_div.cu.cc", + "cwise_op_gpu_equal_to.cu.cc", + "cwise_op_gpu_exp.cu.cc", + "cwise_op_gpu_floor.cu.cc", + "cwise_op_gpu_floor_div.cu.cc", + "cwise_op_gpu_greater.cu.cc", + "cwise_op_gpu_greater_equal.cu.cc", + "cwise_op_gpu_less.cu.cc", + "cwise_op_gpu_less_equal.cu.cc", + "cwise_op_gpu_logical_not.cu.cc", + "cwise_op_gpu_maximum.cu.cc", + "cwise_op_gpu_minimum.cu.cc", + "cwise_op_gpu_mul.cu.cc", + "cwise_op_gpu_neg.cu.cc", + "cwise_op_gpu_round.cu.cc", + "cwise_op_gpu_rsqrt.cu.cc", + "cwise_op_gpu_select.cu.cc", + "cwise_op_gpu_sigmoid.cu.cc", + "cwise_op_gpu_sin.cu.cc", + "cwise_op_gpu_sqrt.cu.cc", + "cwise_op_gpu_squared_difference.cu.cc", + "cwise_op_gpu_sub.cu.cc", + "cwise_op_gpu_tanh.cu.cc", + "cwise_op_greater.cc", + "cwise_op_greater_equal.cc", + "cwise_op_less.cc", + "cwise_op_less_equal.cc", + "cwise_op_logical_not.cc", + "cwise_op_maximum.cc", + "cwise_op_minimum.cc", + "cwise_op_mul_1.cc", + "cwise_op_mul_2.cc", + "cwise_op_neg.cc", + "cwise_op_not_equal_to_2.cc", + "cwise_op_round.cc", + "cwise_op_rsqrt.cc", + "cwise_op_select.cc", + "cwise_op_sigmoid.cc", + "cwise_op_sin.cc", + "cwise_op_sqrt.cc", + "cwise_op_square.cc", + "cwise_op_squared_difference.cc", + "cwise_op_sub.cc", + "cwise_op_tanh.cc", + "dequantize_op.cc", + "ops_testutil.h", + "quantize_and_dequantize_op.cc", + "quantize_op.cc", + "sparse_cross_op.cc", + "sparse_fill_empty_rows_op.cc", + "sparse_reshape_op.cc", + "unary_ops_composition.cc", +]) diff --git a/tensorflow/core/kernels/check_numerics_op.cc b/tensorflow/core/kernels/check_numerics_op.cc index 63c52a11624..17cf5c37ae9 100644 --- a/tensorflow/core/kernels/check_numerics_op.cc +++ 
b/tensorflow/core/kernels/check_numerics_op.cc @@ -31,7 +31,7 @@ limitations under the License. #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #if GOOGLE_CUDA -#include "tensorflow/core/platform/cuda.h" +#include "tensorflow/stream_executor/cuda/cuda_activation.h" #elif TENSORFLOW_USE_ROCM #include "tensorflow/core/platform/rocm.h" #endif diff --git a/tensorflow/core/kernels/conv_grad_filter_ops_benchmark_test.cc b/tensorflow/core/kernels/conv_grad_filter_ops_benchmark_test.cc index bb6eb846408..97148945331 100644 --- a/tensorflow/core/kernels/conv_grad_filter_ops_benchmark_test.cc +++ b/tensorflow/core/kernels/conv_grad_filter_ops_benchmark_test.cc @@ -38,21 +38,22 @@ static Tensor MakeRandomTensor(const TensorShape& shape) { template static Graph* Conv2DBackpropFilter(int batch, int height, int width, - int in_depth, int filter_w, int filter_h, - int out_depth, TensorFormat data_format) { + int in_depth, int filter_h, int filter_w, + int out_depth, int stride_h, int stride_w, + Padding padding, TensorFormat data_format) { auto* graph = new Graph(OpRegistry::Global()); Tensor input_t = data_format == FORMAT_NHWC ? MakeRandomTensor({batch, height, width, in_depth}) : MakeRandomTensor({batch, in_depth, height, width}); Tensor filter_t = - MakeRandomTensor({filter_w, filter_h, in_depth, out_depth}); + MakeRandomTensor({filter_h, filter_w, in_depth, out_depth}); // Compute dimensions for the `out_backprop` tensor. Conv2DParameters params; params.dilations = {1, 1, 1, 1}; - params.strides = {1, 1, 1, 1}; - params.padding = Padding::SAME; + params.strides = {1, stride_h, stride_w, 1}; + params.padding = padding; params.data_format = data_format; Conv2DDimensions conv2d_dims; @@ -83,23 +84,19 @@ static Graph* Conv2DBackpropFilter(int batch, int height, int width, .Input(filter_dims) .Input(backprop) .Attr("T", DataTypeToEnum::value) - .Attr("strides", {1, 1, 1, 1}) - .Attr("padding", "SAME") + .Attr("strides", {1, stride_h, stride_w, 1}) + .Attr("padding", padding == Padding::SAME + ? "SAME" + : padding == Padding::VALID ? "VALID" : "N/A") .Attr("data_format", ToString(data_format)) .Finalize(graph, &conv2d)); return graph; } -// -------------------------------------------------------------------------- // -// The following benchmarks are used to compare different data format -// performance for different data types. They make sense only when CUDA enabled, -// because on CPU we only support data in NHWC. 
-// -------------------------------------------------------------------------- // - // Macro arguments names: --------------------------------------------------- // // T: data type -// FORMAT: data format (NHWC or NCHW) +// FMT: data format (NHWC or NCHW) // N: batch size // H: height // W: width @@ -107,57 +104,79 @@ static Graph* Conv2DBackpropFilter(int batch, int height, int width, // FC: filter count // FH: filter height // FW: filter width +// SH: stride height +// SW: stride width -#define BM_NAME(name, type, T, FORMAT, N, H, W, C, FW, FH, FC) \ - name##_##T##_##FORMAT##_##type##_##N##_##H##_##W##_##C##_##FW##_##FH##_##FC +#define BM_CONCAT(a, b) a##_##b -#define BM_Conv2DBwdFilterFmt(T, FORMAT, N, H, W, C, FW, FH, FC, type) \ - static void BM_NAME(BM_Conv2DBackpropFilter, type, T, FORMAT, N, H, W, C, \ - FW, FH, FC)(int iters) { \ - testing::ItemsProcessed(static_cast(iters) * (N) * (H) * (W) * \ - (C)); \ - test::Benchmark(#type, Conv2DBackpropFilter(N, H, W, C, FW, FH, FC, \ - FORMAT_##FORMAT)) \ - .Run(iters); \ - } \ - BENCHMARK(BM_NAME(BM_Conv2DBackpropFilter, type, T, FORMAT, N, H, W, C, FW, \ - FH, FC)); +#define BM_NAME(name, type, T, FMT, N, H, W, C, FH, FW, FC, SH, SW, PADDING) \ + BM_CONCAT(name##_##T##_##FMT##_##type##_in##N##x##H##x##W##x##C, \ + f##FH##x##FW##x##FC##_##s##SH##x##SW##_##PADDING) + +#define BM_Conv2DBwdFilter(T, FMT, N, H, W, C, FH, FW, FC, SH, SW, PADDING, \ + type) \ + static void BM_NAME(BM_Conv2DBackpropFilter, type, T, FMT, N, H, W, C, FH, \ + FW, FC, SH, SW, PADDING)(int iters) { \ + testing::ItemsProcessed(static_cast(iters) * (N) * (H) * (W) * \ + (C)); \ + test::Benchmark(#type, Conv2DBackpropFilter(N, H, W, C, FH, FW, FC, SH, \ + SW, PADDING, FORMAT_##FMT)) \ + .Run(iters); \ + } \ + BENCHMARK(BM_NAME(BM_Conv2DBackpropFilter, type, T, FMT, N, H, W, C, FH, FW, \ + FC, SH, SW, PADDING)); + +// ResNet50-ish convolutions. +#define BENCHMARK_DTYPE(FMT, BATCH, T, D) \ + BM_Conv2DBwdFilter(T, FMT, BATCH, 112, 112, 64, 2, 2, 64, 2, 2, SAME, D); \ + BM_Conv2DBwdFilter(T, FMT, BATCH, 56, 56, 128, 2, 2, 128, 2, 2, SAME, D); \ + BM_Conv2DBwdFilter(T, FMT, BATCH, 28, 28, 256, 2, 2, 256, 2, 2, SAME, D); \ + \ + BM_Conv2DBwdFilter(T, FMT, BATCH, 112, 112, 64, 2, 2, 64, 2, 2, VALID, D); \ + BM_Conv2DBwdFilter(T, FMT, BATCH, 56, 56, 128, 2, 2, 128, 2, 2, VALID, D); \ + BM_Conv2DBwdFilter(T, FMT, BATCH, 28, 28, 256, 2, 2, 256, 2, 2, VALID, D); \ + \ + BM_Conv2DBwdFilter(T, FMT, BATCH, 56, 56, 64, 1, 1, 64, 1, 1, SAME, D); \ + BM_Conv2DBwdFilter(T, FMT, BATCH, 56, 56, 64, 1, 1, 256, 1, 1, SAME, D); \ + BM_Conv2DBwdFilter(T, FMT, BATCH, 56, 56, 256, 1, 1, 64, 1, 1, SAME, D); \ + BM_Conv2DBwdFilter(T, FMT, BATCH, 56, 56, 64, 3, 3, 64, 1, 1, SAME, D); \ + \ + BM_Conv2DBwdFilter(T, FMT, BATCH, 28, 28, 128, 1, 1, 128, 1, 1, SAME, D); \ + BM_Conv2DBwdFilter(T, FMT, BATCH, 28, 28, 128, 1, 1, 512, 1, 1, SAME, D); \ + BM_Conv2DBwdFilter(T, FMT, BATCH, 28, 28, 512, 1, 1, 128, 1, 1, SAME, D); \ + BM_Conv2DBwdFilter(T, FMT, BATCH, 28, 28, 512, 3, 3, 128, 1, 1, SAME, D); \ + \ + BM_Conv2DBwdFilter(T, FMT, BATCH, 14, 14, 256, 1, 1, 256, 1, 1, SAME, D); \ + BM_Conv2DBwdFilter(T, FMT, BATCH, 14, 14, 256, 1, 1, 1024, 1, 1, SAME, D); \ + BM_Conv2DBwdFilter(T, FMT, BATCH, 14, 14, 1024, 1, 1, 256, 1, 1, SAME, D); \ + BM_Conv2DBwdFilter(T, FMT, BATCH, 14, 14, 256, 3, 3, 256, 1, 1, SAME, D); -#if GOOGLE_CUDA using fp32 = float; using fp16 = Eigen::half; -// ResNet50-ish convolutions. 
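The BM_CONCAT/BM_NAME pair introduced above exists only to build one long, human-readable benchmark identifier (input shape, filter, stride, padding) out of the macro parameters via preprocessor token pasting. A minimal sketch of the same trick with invented names, independent of the TF benchmark framework:

#include <cstdio>

#define EXAMPLE_CONCAT(a, b) a##_##b
#define EXAMPLE_NAME(prefix, N, H, W, PADDING) \
  EXAMPLE_CONCAT(prefix##_in##N##x##H##x##W, PADDING)

// Expands to: void BM_Example_in32x56x56_SAME() { ... }
void EXAMPLE_NAME(BM_Example, 32, 56, 56, SAME)() {
  std::printf("BM_Example_in32x56x56_SAME\n");
}

int main() { EXAMPLE_NAME(BM_Example, 32, 56, 56, SAME)(); }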
-#define BENCHMARK_DTYPE(FORMAT, BATCH, T) \ - BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 56, 56, 64, 1, 1, 64, gpu); \ - BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 56, 56, 64, 1, 1, 256, gpu); \ - BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 56, 56, 256, 1, 1, 64, gpu); \ - BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 56, 56, 64, 3, 3, 64, gpu); \ - \ - BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 28, 28, 128, 1, 1, 128, gpu); \ - BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 28, 28, 128, 1, 1, 512, gpu); \ - BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 28, 28, 512, 1, 1, 128, gpu); \ - BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 28, 28, 512, 3, 3, 128, gpu); \ - \ - BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 14, 14, 256, 1, 1, 256, gpu); \ - BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 14, 14, 256, 1, 1, 1024, gpu); \ - BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 14, 14, 1024, 1, 1, 256, gpu); \ - BM_Conv2DBwdFilterFmt(T, FORMAT, BATCH, 14, 14, 256, 3, 3, 256, gpu); +BENCHMARK_DTYPE(NHWC, 8, fp32, cpu); +BENCHMARK_DTYPE(NHWC, 16, fp32, cpu); +BENCHMARK_DTYPE(NHWC, 32, fp32, cpu); -BENCHMARK_DTYPE(NHWC, 32, fp32); -BENCHMARK_DTYPE(NCHW, 32, fp32); +#if GOOGLE_CUDA +// -------------------------------------------------------------------------- // +// The following benchmarks are used to compare different data format +// performance for different data types. They make sense only when CUDA enabled, +// because on CPU we only support data in NHWC. +// -------------------------------------------------------------------------- // -BENCHMARK_DTYPE(NHWC, 32, fp16); -BENCHMARK_DTYPE(NCHW, 32, fp16); +BENCHMARK_DTYPE(NHWC, 32, fp32, gpu); +BENCHMARK_DTYPE(NCHW, 32, fp32, gpu); -BENCHMARK_DTYPE(NHWC, 64, fp32); -BENCHMARK_DTYPE(NCHW, 64, fp32); +BENCHMARK_DTYPE(NHWC, 32, fp16, gpu); +BENCHMARK_DTYPE(NCHW, 32, fp16, gpu); -BENCHMARK_DTYPE(NHWC, 64, fp16); -BENCHMARK_DTYPE(NCHW, 64, fp16); +BENCHMARK_DTYPE(NHWC, 64, fp32, gpu); +BENCHMARK_DTYPE(NCHW, 64, fp32, gpu); + +BENCHMARK_DTYPE(NHWC, 64, fp16, gpu); +BENCHMARK_DTYPE(NCHW, 64, fp16, gpu); #endif // GOOGLE_CUDA -BM_Conv2DBwdFilterFmt(float, NHWC, 8, 32, 32, 128, 1, 1, 128, cpu); - } // namespace tensorflow diff --git a/tensorflow/core/kernels/conv_grad_input_ops_benchmark_test.cc b/tensorflow/core/kernels/conv_grad_input_ops_benchmark_test.cc index ee4d2800ca7..713c935dcf7 100644 --- a/tensorflow/core/kernels/conv_grad_input_ops_benchmark_test.cc +++ b/tensorflow/core/kernels/conv_grad_input_ops_benchmark_test.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include "tensorflow/cc/ops/standard_ops.h" @@ -38,8 +37,9 @@ static Tensor MakeRandomTensor(const TensorShape& shape) { template static Graph* Conv2DBackpropInput(int batch, int height, int width, - int in_depth, int filter_w, int filter_h, - int out_depth, TensorFormat data_format) { + int in_depth, int filter_h, int filter_w, + int out_depth, int stride_h, int stride_w, + Padding padding, TensorFormat data_format) { auto* graph = new Graph(OpRegistry::Global()); Tensor input_t = data_format == FORMAT_NHWC @@ -51,8 +51,8 @@ static Graph* Conv2DBackpropInput(int batch, int height, int width, // Compute dimensions for the `out_backprop` tensor. 
Conv2DParameters params; params.dilations = {1, 1, 1, 1}; - params.strides = {1, 1, 1, 1}; - params.padding = Padding::SAME; + params.strides = {1, stride_h, stride_w, 1}; + params.padding = padding; params.data_format = data_format; Conv2DDimensions conv2d_dims; @@ -83,23 +83,19 @@ static Graph* Conv2DBackpropInput(int batch, int height, int width, .Input(filter) .Input(backprop) .Attr("T", DataTypeToEnum::value) - .Attr("strides", {1, 1, 1, 1}) - .Attr("padding", "SAME") + .Attr("strides", {1, stride_h, stride_w, 1}) + .Attr("padding", padding == Padding::SAME + ? "SAME" + : padding == Padding::VALID ? "VALID" : "N/A") .Attr("data_format", ToString(data_format)) .Finalize(graph, &conv2d)); return graph; } -// -------------------------------------------------------------------------- // -// The following benchmarks are used to compare different data format -// performance for different data types. They make sense only when CUDA enabled, -// because on CPU we only support data in NHWC. -// -------------------------------------------------------------------------- // - // Macro arguments names: --------------------------------------------------- // // T: data type -// FORMAT: data format (NHWC or NCHW) +// FMT: data format (NHWC or NCHW) // N: batch size // H: height // W: width @@ -107,54 +103,78 @@ static Graph* Conv2DBackpropInput(int batch, int height, int width, // FC: filter count // FH: filter height // FW: filter width +// SH: stride height +// SW: stride width -#define BM_NAME(name, type, T, FORMAT, N, H, W, C, FW, FH, FC) \ - name##_##T##_##FORMAT##_##type##_##N##_##H##_##W##_##C##_##FW##_##FH##_##FC +#define BM_CONCAT(a, b) a##_##b -#define BM_Conv2DBwdInputFmt(T, FORMAT, N, H, W, C, FW, FH, FC, type) \ - static void BM_NAME(BM_Conv2DBackpropInput, type, T, FORMAT, N, H, W, C, FW, \ - FH, FC)(int iters) { \ - testing::ItemsProcessed(static_cast(iters) * (N) * (H) * (W) * \ - (C)); \ - test::Benchmark(#type, Conv2DBackpropInput(N, H, W, C, FW, FH, FC, \ - FORMAT_##FORMAT)) \ - .Run(iters); \ - } \ - BENCHMARK(BM_NAME(BM_Conv2DBackpropInput, type, T, FORMAT, N, H, W, C, FW, \ - FH, FC)); +#define BM_NAME(name, type, T, FMT, N, H, W, C, FH, FW, FC, SH, SW, PADDING) \ + BM_CONCAT(name##_##T##_##FMT##_##type##_in##N##x##H##x##W##x##C, \ + f##FH##x##FW##x##FC##_##s##SH##x##SW##_##PADDING) + +#define BM_Conv2DBwdInput(T, FMT, N, H, W, C, FW, FH, FC, SH, SW, PADDING, \ + type) \ + static void BM_NAME(BM_Conv2DBackpropInput, type, T, FMT, N, H, W, C, FH, \ + FW, FC, SH, SW, PADDING)(int iters) { \ + testing::ItemsProcessed(static_cast(iters) * (N) * (H) * (W) * \ + (C)); \ + test::Benchmark(#type, Conv2DBackpropInput(N, H, W, C, FH, FW, FC, SH, \ + SW, PADDING, FORMAT_##FMT)) \ + .Run(iters); \ + } \ + BENCHMARK(BM_NAME(BM_Conv2DBackpropInput, type, T, FMT, N, H, W, C, FH, FW, \ + FC, SH, SW, PADDING)); -#if GOOGLE_CUDA using fp32 = float; using fp16 = Eigen::half; // ResNet50-ish convolutions. 
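For the ResNet50-ish configurations that follow, the `out_backprop` spatial dimensions computed above are simply the input dimensions reduced by the stride: with a 2x2 filter and stride 2, both SAME and VALID give 112->56, 56->28 and 28->14. A standalone check of that arithmetic, using the usual SAME/VALID output-size formulas (stated here as an assumption, not quoted from the TF headers):

#include <cstdio>

// SAME: ceil(in / stride); VALID: floor((in - filter) / stride) + 1.
int OutDim(int in, int filter, int stride, bool same_padding) {
  return same_padding ? (in + stride - 1) / stride
                      : (in - filter) / stride + 1;
}

int main() {
  // 112x112 input, 2x2 filter, stride 2 -> a 56x56 out_backprop either way.
  std::printf("SAME: %d  VALID: %d\n", OutDim(112, 2, 2, true),
              OutDim(112, 2, 2, false));
  return 0;
}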
-#define BENCHMARK_DTYPE(FORMAT, BATCH, T) \ - BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 56, 56, 64, 1, 1, 64, gpu); \ - BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 56, 56, 64, 1, 1, 256, gpu); \ - BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 56, 56, 256, 1, 1, 64, gpu); \ - BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 56, 56, 64, 3, 3, 64, gpu); \ - \ - BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 28, 28, 128, 1, 1, 128, gpu); \ - BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 28, 28, 128, 1, 1, 512, gpu); \ - BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 28, 28, 512, 1, 1, 128, gpu); \ - BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 28, 28, 512, 3, 3, 128, gpu); \ - \ - BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 14, 14, 256, 1, 1, 256, gpu); \ - BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 14, 14, 256, 1, 1, 1024, gpu); \ - BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 14, 14, 1024, 1, 1, 256, gpu); \ - BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 14, 14, 256, 3, 3, 256, gpu); +#define BENCHMARK_DTYPE(FMT, BATCH, T, D) \ + BM_Conv2DBwdInput(T, FMT, BATCH, 112, 112, 64, 2, 2, 64, 2, 2, SAME, D); \ + BM_Conv2DBwdInput(T, FMT, BATCH, 56, 56, 128, 2, 2, 128, 2, 2, SAME, D); \ + BM_Conv2DBwdInput(T, FMT, BATCH, 28, 28, 256, 2, 2, 256, 2, 2, SAME, D); \ + \ + BM_Conv2DBwdInput(T, FMT, BATCH, 112, 112, 64, 2, 2, 64, 2, 2, VALID, D); \ + BM_Conv2DBwdInput(T, FMT, BATCH, 56, 56, 128, 2, 2, 128, 2, 2, VALID, D); \ + BM_Conv2DBwdInput(T, FMT, BATCH, 28, 28, 256, 2, 2, 256, 2, 2, VALID, D); \ + \ + BM_Conv2DBwdInput(T, FMT, BATCH, 56, 56, 64, 1, 1, 64, 1, 1, SAME, D); \ + BM_Conv2DBwdInput(T, FMT, BATCH, 56, 56, 64, 1, 1, 256, 1, 1, SAME, D); \ + BM_Conv2DBwdInput(T, FMT, BATCH, 56, 56, 256, 1, 1, 64, 1, 1, SAME, D); \ + BM_Conv2DBwdInput(T, FMT, BATCH, 56, 56, 64, 3, 3, 64, 1, 1, SAME, D); \ + \ + BM_Conv2DBwdInput(T, FMT, BATCH, 28, 28, 128, 1, 1, 128, 1, 1, SAME, D); \ + BM_Conv2DBwdInput(T, FMT, BATCH, 28, 28, 128, 1, 1, 512, 1, 1, SAME, D); \ + BM_Conv2DBwdInput(T, FMT, BATCH, 28, 28, 512, 1, 1, 128, 1, 1, SAME, D); \ + BM_Conv2DBwdInput(T, FMT, BATCH, 28, 28, 512, 3, 3, 128, 1, 1, SAME, D); \ + \ + BM_Conv2DBwdInput(T, FMT, BATCH, 14, 14, 256, 1, 1, 256, 1, 1, SAME, D); \ + BM_Conv2DBwdInput(T, FMT, BATCH, 14, 14, 256, 1, 1, 1024, 1, 1, SAME, D); \ + BM_Conv2DBwdInput(T, FMT, BATCH, 14, 14, 1024, 1, 1, 256, 1, 1, SAME, D); \ + BM_Conv2DBwdInput(T, FMT, BATCH, 14, 14, 256, 3, 3, 256, 1, 1, SAME, D); -BENCHMARK_DTYPE(NHWC, 32, fp32); -BENCHMARK_DTYPE(NCHW, 32, fp32); +BENCHMARK_DTYPE(NHWC, 8, fp32, cpu); +BENCHMARK_DTYPE(NHWC, 16, fp32, cpu); +BENCHMARK_DTYPE(NHWC, 32, fp32, cpu); -BENCHMARK_DTYPE(NHWC, 32, fp16); -BENCHMARK_DTYPE(NCHW, 32, fp16); +#if GOOGLE_CUDA +// -------------------------------------------------------------------------- // +// The following benchmarks are used to compare different data format +// performance for different data types. They make sense only when CUDA enabled, +// because on CPU we only support data in NHWC. 
+// -------------------------------------------------------------------------- // -BENCHMARK_DTYPE(NHWC, 64, fp32); -BENCHMARK_DTYPE(NCHW, 64, fp32); +BENCHMARK_DTYPE(NHWC, 32, fp32, gpu); +BENCHMARK_DTYPE(NCHW, 32, fp32, gpu); -BENCHMARK_DTYPE(NHWC, 64, fp16); -BENCHMARK_DTYPE(NCHW, 64, fp16); +BENCHMARK_DTYPE(NHWC, 32, fp16, gpu); +BENCHMARK_DTYPE(NCHW, 32, fp16, gpu); + +BENCHMARK_DTYPE(NHWC, 64, fp32, gpu); +BENCHMARK_DTYPE(NCHW, 64, fp32, gpu); + +BENCHMARK_DTYPE(NHWC, 64, fp16, gpu); +BENCHMARK_DTYPE(NCHW, 64, fp16, gpu); #endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/crop_and_resize_op.cc b/tensorflow/core/kernels/crop_and_resize_op.cc index 892fadd51dd..5223501997e 100644 --- a/tensorflow/core/kernels/crop_and_resize_op.cc +++ b/tensorflow/core/kernels/crop_and_resize_op.cc @@ -40,7 +40,7 @@ limitations under the License. #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #if GOOGLE_CUDA -#include "tensorflow/core/platform/cuda.h" +#include "tensorflow/stream_executor/cuda/cuda_activation.h" using stream_executor::cuda::ScopedActivateExecutorContext; #elif TENSORFLOW_USE_ROCM #include "tensorflow/core/platform/rocm.h" diff --git a/tensorflow/core/kernels/cuda_solvers.cc b/tensorflow/core/kernels/cuda_solvers.cc index 9abf5439571..1c569204265 100644 --- a/tensorflow/core/kernels/cuda_solvers.cc +++ b/tensorflow/core/kernels/cuda_solvers.cc @@ -30,10 +30,10 @@ #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" -#include "tensorflow/core/platform/cuda.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/cuda/cuda_activation.h" // The CUDA cublas_api.h API contains const-correctness errors. Instead of // casting away constness on our data, we instead reinterpret the CuBLAS diff --git a/tensorflow/core/kernels/reduction_ops.h b/tensorflow/core/kernels/reduction_ops.h index 3c62dcfc081..46d8051fff1 100644 --- a/tensorflow/core/kernels/reduction_ops.h +++ b/tensorflow/core/kernels/reduction_ops.h @@ -72,6 +72,34 @@ struct ReduceEigenImpl \ + struct ReduceEigenImpl> { \ + void operator()(const Device& d, OUT_T out, IN_T in, \ + const ReductionAxes& reduction_axes, \ + const functor::MeanReducer& reducer) { \ + static_assert(std::is_same::value, \ + ""); \ + Eigen::internal::SumReducer sum_reducer; \ + out.device(d) = (in.template cast().reduce( \ + reduction_axes, sum_reducer) / \ + static_cast(in.size() / out.size())) \ + .template cast(); \ + } \ + } + +CASTING_SPECIALIZATION(uint8, uint64); +CASTING_SPECIALIZATION(uint16, uint64); +CASTING_SPECIALIZATION(uint32, uint64); +CASTING_SPECIALIZATION(int8, int64); +CASTING_SPECIALIZATION(int16, int64); +CASTING_SPECIALIZATION(int32, int64); +#undef CASTING_SPECIALIZATION + // TODO(rmlarsen): Refactor this such that taking the sqrt can be optional // controlled by an attribute. template { std::vector ys(out_height + 1); std::vector xs(out_width + 1); + // Compute the cached interpolation weights on the x and y dimensions. if (half_pixel_centers) { compute_interpolation_weights(HalfPixelScaler(), out_height, in_height, height_scale, ys.data()); @@ -237,7 +238,6 @@ struct ResizeBilinear { width_scale, xs.data()); } else { - // Compute the cached interpolation weights on the x and y dimensions. 
// Compute the cached interpolation weights on the x and y dimensions. compute_interpolation_weights(LegacyScaler(), out_height, in_height, height_scale, ys.data()); compute_interpolation_weights(LegacyScaler(), out_width, in_width, @@ -309,13 +309,16 @@ struct ResizeBilinearGrad { output_grad.setZero(); - // Each resized pixel was computed as a weighted average of four input - // pixels. Here we find the pixels that contributed to each output pixel - // and add the corresponding coefficient to the gradient. - // resized(b, y, x, c) = top_left * (1 - y) * (1 - x) - // + top_right * (1 - y) * x - // + bottom_left * y * (1 - x) - // + bottom_right * y * x + // Each resized output pixel was computed as a weighted average of four + // input pixels. Here we find the four input pixel locations that + // contributed to each output pixel and propagate the gradient at the output + // pixel location to each of those four input pixel locations in the same + // proportions that they originally contributed to the output pixel. + // Here is the forward-propagation pseudo-code, for reference: + // resized(b, y, x, c) = top_left * (1 - y) * (1 - x) + // + top_right * (1 - y) * x + // + bottom_left * y * (1 - x) + // + bottom_right * y * x for (Eigen::Index b = 0; b < batch; ++b) { for (Eigen::Index y = 0; y < resized_height; ++y) { const float in_y = scaler(y, height_scale); diff --git a/tensorflow/core/kernels/segment_reduction_ops_impl.h b/tensorflow/core/kernels/segment_reduction_ops_impl.h index 5aa05faab97..a472655d3e0 100644 --- a/tensorflow/core/kernels/segment_reduction_ops_impl.h +++ b/tensorflow/core/kernels/segment_reduction_ops_impl.h @@ -45,7 +45,7 @@ limitations under the License. #if GOOGLE_CUDA #include "tensorflow/core/kernels/cuda_solvers.h" -#include "tensorflow/core/platform/cuda.h" +#include "tensorflow/stream_executor/cuda/cuda_activation.h" using stream_executor::cuda::ScopedActivateExecutorContext; #elif TENSORFLOW_USE_ROCM #include "tensorflow/core/platform/rocm.h" diff --git a/tensorflow/core/kernels/sparse/dense_to_csr_sparse_matrix_op.cc b/tensorflow/core/kernels/sparse/dense_to_csr_sparse_matrix_op.cc index 94cbae3185f..6e0397c8d27 100644 --- a/tensorflow/core/kernels/sparse/dense_to_csr_sparse_matrix_op.cc +++ b/tensorflow/core/kernels/sparse/dense_to_csr_sparse_matrix_op.cc @@ -36,7 +36,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" #include "tensorflow/core/kernels/cuda_solvers.h" #include "tensorflow/core/kernels/cuda_sparse.h" -#include "tensorflow/core/platform/cuda.h" +#include "tensorflow/stream_executor/cuda/cuda_activation.h" using ::perftools::gputools::cuda::ScopedActivateExecutorContext; #endif diff --git a/tensorflow/core/kernels/sparse/sparse_tensor_to_csr_sparse_matrix_op.cc b/tensorflow/core/kernels/sparse/sparse_tensor_to_csr_sparse_matrix_op.cc index f791b6c5105..3ecebfe0ac7 100644 --- a/tensorflow/core/kernels/sparse/sparse_tensor_to_csr_sparse_matrix_op.cc +++ b/tensorflow/core/kernels/sparse/sparse_tensor_to_csr_sparse_matrix_op.cc @@ -34,7 +34,7 @@ limitations under the License.
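The rewritten comment above describes the backward rule: each output pixel's gradient is scattered to the four input pixels that produced it, weighted by the same bilinear coefficients used in the forward pass. A minimal standalone sketch of that scatter step (buffer layout and names are illustrative, not the TF implementation):

#include <vector>

// in_grad is a flattened height x width gradient buffer for one channel.
// dy, dx are the fractional offsets of the sampling point inside the cell
// spanned by (top, left) .. (bottom, right).
void ScatterBilinearGrad(float out_grad, int top, int bottom, int left,
                         int right, float dy, float dx, int width,
                         std::vector<float>* in_grad) {
  auto at = [&](int y, int x) -> float& { return (*in_grad)[y * width + x]; };
  at(top, left) += out_grad * (1 - dy) * (1 - dx);
  at(top, right) += out_grad * (1 - dy) * dx;
  at(bottom, left) += out_grad * dy * (1 - dx);
  at(bottom, right) += out_grad * dy * dx;
}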
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" #include "tensorflow/core/kernels/cuda_solvers.h" #include "tensorflow/core/kernels/cuda_sparse.h" -#include "tensorflow/core/platform/cuda.h" +#include "tensorflow/stream_executor/cuda/cuda_activation.h" using ::perftools::gputools::cuda::ScopedActivateExecutorContext; #endif diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc index 3490dc1ee80..467087b7864 100644 --- a/tensorflow/core/kernels/training_ops.cc +++ b/tensorflow/core/kernels/training_ops.cc @@ -670,7 +670,9 @@ namespace functor { DECLARE_GPU_SPEC(Eigen::half); DECLARE_GPU_SPEC(float); DECLARE_GPU_SPEC(double); -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS DECLARE_GPU_SPEC(complex64); DECLARE_GPU_SPEC(complex128); @@ -682,7 +684,9 @@ DECLARE_GPU_SPEC(complex128); REGISTER_KERNELS(GPU, Eigen::half); REGISTER_KERNELS(GPU, float); REGISTER_KERNELS(GPU, double); -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS REGISTER_KERNELS(GPU, complex64); REGISTER_KERNELS(GPU, complex128); @@ -849,7 +853,9 @@ namespace functor { DECLARE_GPU_SPEC(Eigen::half); DECLARE_GPU_SPEC(float); DECLARE_GPU_SPEC(double); -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS DECLARE_GPU_SPEC(complex64); DECLARE_GPU_SPEC(complex128); @@ -861,7 +867,9 @@ DECLARE_GPU_SPEC(complex128); REGISTER_KERNELS(GPU, Eigen::half); REGISTER_KERNELS(GPU, float); REGISTER_KERNELS(GPU, double); -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS REGISTER_KERNELS(GPU, complex64); REGISTER_KERNELS(GPU, complex128); @@ -1340,7 +1348,9 @@ namespace functor { DECLARE_GPU_SPEC(Eigen::half); DECLARE_GPU_SPEC(float); DECLARE_GPU_SPEC(double); -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS DECLARE_GPU_SPEC(complex64); DECLARE_GPU_SPEC(complex128); @@ -1352,7 +1362,9 @@ DECLARE_GPU_SPEC(complex128); REGISTER_KERNELS(GPU, Eigen::half); REGISTER_KERNELS(GPU, float); REGISTER_KERNELS(GPU, double); -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS REGISTER_KERNELS(GPU, complex64); REGISTER_KERNELS(GPU, complex128); @@ -1456,7 +1468,9 @@ namespace functor { DECLARE_GPU_SPEC(Eigen::half); DECLARE_GPU_SPEC(float); DECLARE_GPU_SPEC(double); -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS 
DECLARE_GPU_SPEC(complex64); DECLARE_GPU_SPEC(complex128); @@ -1468,7 +1482,9 @@ DECLARE_GPU_SPEC(complex128); REGISTER_KERNELS(GPU, Eigen::half); REGISTER_KERNELS(GPU, float); REGISTER_KERNELS(GPU, double); -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS REGISTER_KERNELS(GPU, complex64); REGISTER_KERNELS(GPU, complex128); @@ -2957,7 +2973,9 @@ namespace functor { DECLARE_GPU_SPEC(Eigen::half); DECLARE_GPU_SPEC(float); DECLARE_GPU_SPEC(double); -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS DECLARE_GPU_SPEC(complex64); DECLARE_GPU_SPEC(complex128); @@ -2969,7 +2987,9 @@ DECLARE_GPU_SPEC(complex128); REGISTER_KERNELS(GPU, Eigen::half); REGISTER_KERNELS(GPU, float); REGISTER_KERNELS(GPU, double); -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS REGISTER_KERNELS(GPU, complex64); REGISTER_KERNELS(GPU, complex128); @@ -3195,7 +3215,9 @@ namespace functor { DECLARE_GPU_SPEC(Eigen::half); DECLARE_GPU_SPEC(float); DECLARE_GPU_SPEC(double); -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS DECLARE_GPU_SPEC(complex64); DECLARE_GPU_SPEC(complex128); @@ -3207,7 +3229,9 @@ DECLARE_GPU_SPEC(complex128); REGISTER_KERNELS(GPU, Eigen::half); REGISTER_KERNELS(GPU, float); REGISTER_KERNELS(GPU, double); -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS REGISTER_KERNELS(GPU, complex64); REGISTER_KERNELS(GPU, complex128); @@ -3337,7 +3361,9 @@ DECLARE_GPU_SPEC(float, int32); DECLARE_GPU_SPEC(float, int64); DECLARE_GPU_SPEC(double, int32); DECLARE_GPU_SPEC(double, int64); -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS DECLARE_GPU_SPEC(complex64, int32); DECLARE_GPU_SPEC(complex64, int64); @@ -3355,7 +3381,9 @@ DECLARE_GPU_SPEC(complex128, int64); REGISTER_GPU_KERNELS(Eigen::half); REGISTER_GPU_KERNELS(float); REGISTER_GPU_KERNELS(double); -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS REGISTER_GPU_KERNELS(complex64); REGISTER_GPU_KERNELS(complex128); @@ -3622,7 +3650,9 @@ namespace functor { DECLARE_GPU_SPEC(Eigen::half); DECLARE_GPU_SPEC(float); DECLARE_GPU_SPEC(double); -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS 
DECLARE_GPU_SPEC(complex64); DECLARE_GPU_SPEC(complex128); @@ -3634,7 +3664,9 @@ DECLARE_GPU_SPEC(complex128); REGISTER_KERNELS(GPU, Eigen::half); REGISTER_KERNELS(GPU, float); REGISTER_KERNELS(GPU, double); -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS REGISTER_KERNELS(GPU, complex64); REGISTER_KERNELS(GPU, complex128); @@ -4151,7 +4183,9 @@ namespace functor { DECLARE_GPU_SPEC(Eigen::half); DECLARE_GPU_SPEC(float); DECLARE_GPU_SPEC(double); -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS DECLARE_GPU_SPEC(complex64); DECLARE_GPU_SPEC(complex128); @@ -4163,7 +4197,9 @@ DECLARE_GPU_SPEC(complex128); REGISTER_KERNELS(GPU, Eigen::half); REGISTER_KERNELS(GPU, float); REGISTER_KERNELS(GPU, double); -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS REGISTER_KERNELS(GPU, complex64); REGISTER_KERNELS(GPU, complex128); diff --git a/tensorflow/core/kernels/training_ops_gpu.cu.cc b/tensorflow/core/kernels/training_ops_gpu.cu.cc index 0995b31e734..8b7f5dc2e40 100644 --- a/tensorflow/core/kernels/training_ops_gpu.cu.cc +++ b/tensorflow/core/kernels/training_ops_gpu.cu.cc @@ -524,7 +524,9 @@ struct ApplyPowerSign { template struct functor::ApplyGradientDescent; template struct functor::ApplyGradientDescent; template struct functor::ApplyGradientDescent; -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS template struct functor::ApplyGradientDescent; template struct functor::ApplyGradientDescent; @@ -534,7 +536,9 @@ template struct functor::ApplyGradientDescent; template struct functor::ApplyAdagrad; template struct functor::ApplyAdagrad; template struct functor::ApplyAdagrad; -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS template struct functor::ApplyAdagrad; template struct functor::ApplyAdagrad; @@ -544,7 +548,9 @@ template struct functor::ApplyAdagrad; template struct functor::ApplyAdagradV2; template struct functor::ApplyAdagradV2; template struct functor::ApplyAdagradV2; -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS template struct functor::ApplyAdagradV2; template struct functor::ApplyAdagradV2; @@ -554,7 +560,9 @@ template struct functor::ApplyAdagradV2; template struct functor::ApplyAdadelta; template struct functor::ApplyAdadelta; template struct functor::ApplyAdadelta; -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef 
PLATFORM_WINDOWS template struct functor::ApplyAdadelta; template struct functor::ApplyAdadelta; @@ -572,7 +580,9 @@ template struct functor::ApplyFtrlV2; template struct functor::ApplyMomentum; template struct functor::ApplyMomentum; template struct functor::ApplyMomentum; -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS template struct functor::ApplyMomentum; template struct functor::ApplyMomentum; @@ -582,7 +592,9 @@ template struct functor::ApplyMomentum; template struct functor::ApplyKerasMomentum; template struct functor::ApplyKerasMomentum; template struct functor::ApplyKerasMomentum; -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS template struct functor::ApplyKerasMomentum; template struct functor::ApplyKerasMomentum; @@ -597,7 +609,9 @@ template struct functor::SparseApplyKerasMomentum; template struct functor::SparseApplyKerasMomentum; template struct functor::SparseApplyKerasMomentum; template struct functor::SparseApplyKerasMomentum; -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS template struct functor::SparseApplyKerasMomentum; template struct functor::SparseApplyKerasMomentum; @@ -609,7 +623,9 @@ template struct functor::SparseApplyKerasMomentum; template struct functor::ApplyAdam; template struct functor::ApplyAdam; template struct functor::ApplyAdam; -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS template struct functor::ApplyAdam; template struct functor::ApplyAdam; @@ -627,7 +643,9 @@ template struct functor::ApplyAdaMax; template struct functor::ApplyRMSProp; template struct functor::ApplyRMSProp; template struct functor::ApplyRMSProp; -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS template struct functor::ApplyRMSProp; template struct functor::ApplyRMSProp; @@ -637,7 +655,9 @@ template struct functor::ApplyRMSProp; template struct functor::ApplyCenteredRMSProp; template struct functor::ApplyCenteredRMSProp; template struct functor::ApplyCenteredRMSProp; -#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +#if !defined(TENSORFLOW_USE_NVCC) && \ + !defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support + // complex sqrt #ifndef PLATFORM_WINDOWS template struct functor::ApplyCenteredRMSProp; template struct functor::ApplyCenteredRMSProp; diff --git a/tensorflow/core/kernels/where_op.cc b/tensorflow/core/kernels/where_op.cc index 2030512bfd2..318894bfce4 100644 --- a/tensorflow/core/kernels/where_op.cc +++ b/tensorflow/core/kernels/where_op.cc @@ -41,7 +41,7 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" #include "tensorflow/core/kernels/cuda_solvers.h" #if GOOGLE_CUDA -#include "tensorflow/core/platform/cuda.h" +#include "tensorflow/stream_executor/cuda/cuda_activation.h" using stream_executor::cuda::ScopedActivateExecutorContext; #elif TENSORFLOW_USE_ROCM #include "tensorflow/core/platform/rocm.h" diff --git a/tensorflow/core/lib/core/BUILD b/tensorflow/core/lib/core/BUILD index baf81113029..a3ed21f8771 100644 --- a/tensorflow/core/lib/core/BUILD +++ b/tensorflow/core/lib/core/BUILD @@ -37,8 +37,7 @@ cc_library( name = "blocking_counter", hdrs = ["blocking_counter.h"], deps = [ - "//tensorflow/core/platform:logging", - "//tensorflow/core/platform:mutex", + "//tensorflow/core/platform:blocking_counter", ], ) @@ -53,13 +52,9 @@ cc_library( cc_library( name = "coding", - srcs = ["coding.cc"], hdrs = ["coding.h"], deps = [ - "//tensorflow/core/lib/core:raw_coding", - "//tensorflow/core/lib/core:stringpiece", - "//tensorflow/core/platform:byte_order", - "//tensorflow/core/platform:types", + "//tensorflow/core/platform:coding", ], ) @@ -172,7 +167,6 @@ filegroup( srcs = [ "arena.cc", "bitmap.cc", - "coding.cc", ], visibility = ["//tensorflow/core:__pkg__"], ) diff --git a/tensorflow/core/lib/core/blocking_counter.h b/tensorflow/core/lib/core/blocking_counter.h index 5dab07dbef9..8355a7ac870 100644 --- a/tensorflow/core/lib/core/blocking_counter.h +++ b/tensorflow/core/lib/core/blocking_counter.h @@ -16,65 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_LIB_CORE_BLOCKING_COUNTER_H_ #define TENSORFLOW_LIB_CORE_BLOCKING_COUNTER_H_ -#include - -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/mutex.h" - -namespace tensorflow { - -class BlockingCounter { - public: - BlockingCounter(int initial_count) - : state_(initial_count << 1), notified_(false) { - CHECK_GE(initial_count, 0); - DCHECK_EQ((initial_count << 1) >> 1, initial_count); - } - - ~BlockingCounter() {} - - inline void DecrementCount() { - unsigned int v = state_.fetch_sub(2, std::memory_order_acq_rel) - 2; - if (v != 1) { - DCHECK_NE(((v + 2) & ~1), 0); - return; // either count has not dropped to 0, or waiter is not waiting - } - mutex_lock l(mu_); - DCHECK(!notified_); - notified_ = true; - cond_var_.notify_all(); - } - - inline void Wait() { - unsigned int v = state_.fetch_or(1, std::memory_order_acq_rel); - if ((v >> 1) == 0) return; - mutex_lock l(mu_); - while (!notified_) { - cond_var_.wait(l); - } - } - // Wait for the specified time, return false iff the count has not dropped to - // zero before the timeout expired. - inline bool WaitFor(std::chrono::milliseconds ms) { - unsigned int v = state_.fetch_or(1, std::memory_order_acq_rel); - if ((v >> 1) == 0) return true; - mutex_lock l(mu_); - while (!notified_) { - const std::cv_status status = cond_var_.wait_for(l, ms); - if (status == std::cv_status::timeout) { - return false; - } - } - return true; - } - - private: - mutex mu_; - condition_variable cond_var_; - std::atomic state_; // low bit is waiter flag - bool notified_; -}; - -} // namespace tensorflow +#include "tensorflow/core/platform/blocking_counter.h" #endif // TENSORFLOW_LIB_CORE_BLOCKING_COUNTER_H_ diff --git a/tensorflow/core/lib/core/coding.h b/tensorflow/core/lib/core/coding.h index bfab80dd007..a1121c888dd 100644 --- a/tensorflow/core/lib/core/coding.h +++ b/tensorflow/core/lib/core/coding.h @@ -21,49 +21,6 @@ limitations under the License. 
#ifndef TENSORFLOW_CORE_LIB_CORE_CODING_H_ #define TENSORFLOW_CORE_LIB_CORE_CODING_H_ -#include "tensorflow/core/lib/core/raw_coding.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { -namespace core { - -// Maximum number of bytes occupied by a varint32. -static const int kMaxVarint32Bytes = 5; - -// Maximum number of bytes occupied by a varint64. -static const int kMaxVarint64Bytes = 10; - -// Lower-level versions of Put... that write directly into a character buffer -// REQUIRES: dst has enough space for the value being written -extern void EncodeFixed16(char* dst, uint16 value); -extern void EncodeFixed32(char* dst, uint32 value); -extern void EncodeFixed64(char* dst, uint64 value); -extern void PutFixed16(string* dst, uint16 value); -extern void PutFixed32(string* dst, uint32 value); -extern void PutFixed64(string* dst, uint64 value); - -extern void PutVarint32(string* dst, uint32 value); -extern void PutVarint64(string* dst, uint64 value); - -extern bool GetVarint32(StringPiece* input, uint32* value); -extern bool GetVarint64(StringPiece* input, uint64* value); - -extern const char* GetVarint32Ptr(const char* p, const char* limit, uint32* v); -extern const char* GetVarint64Ptr(const char* p, const char* limit, uint64* v); - -// Internal routine for use by fallback path of GetVarint32Ptr -extern const char* GetVarint32PtrFallback(const char* p, const char* limit, - uint32* value); -extern const char* GetVarint32Ptr(const char* p, const char* limit, - uint32* value); -extern char* EncodeVarint32(char* dst, uint32 v); -extern char* EncodeVarint64(char* dst, uint64 v); - -// Returns the length of the varint32 or varint64 encoding of "v" -extern int VarintLength(uint64_t v); - -} // namespace core -} // namespace tensorflow +#include "tensorflow/core/platform/coding.h" #endif // TENSORFLOW_CORE_LIB_CORE_CODING_H_ diff --git a/tensorflow/core/lib/hash/BUILD b/tensorflow/core/lib/hash/BUILD index a44e7836cab..de2eebc785f 100644 --- a/tensorflow/core/lib/hash/BUILD +++ b/tensorflow/core/lib/hash/BUILD @@ -41,13 +41,9 @@ cc_library( cc_library( name = "hash", - srcs = ["hash.cc"], hdrs = ["hash.h"], deps = [ - "//tensorflow/core/lib/core:raw_coding", - "//tensorflow/core/lib/core:stringpiece", - "//tensorflow/core/platform:macros", - "//tensorflow/core/platform:types", + "//tensorflow/core/platform:hash", ], ) @@ -65,7 +61,6 @@ filegroup( srcs = [ "crc32c.cc", "crc32c_accelerate.cc", - "hash.cc", ], visibility = ["//tensorflow/core:__pkg__"], ) diff --git a/tensorflow/core/lib/hash/hash.h b/tensorflow/core/lib/hash/hash.h index 675bab71919..fa2cc295b15 100644 --- a/tensorflow/core/lib/hash/hash.h +++ b/tensorflow/core/lib/hash/hash.h @@ -18,96 +18,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_HASH_HASH_H_ #define TENSORFLOW_CORE_LIB_HASH_HASH_H_ -#include -#include - -#include -#include - -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { - -extern uint32 Hash32(const char* data, size_t n, uint32 seed); -extern uint64 Hash64(const char* data, size_t n, uint64 seed); - -inline uint64 Hash64(const char* data, size_t n) { - return Hash64(data, n, 0xDECAFCAFFE); -} - -inline uint64 Hash64(const string& str) { - return Hash64(str.data(), str.size()); -} - -inline uint64 Hash64Combine(uint64 a, uint64 b) { - return a ^ (b + 0x9e3779b97f4a7800ULL + (a << 10) + (a >> 4)); -} - -// Combine two hashes in an order-independent way. 
This operation should be -// associative and compute the same hash for a collection of elements -// independent of traversal order. Note that it is better to combine hashes -// symmetrically with addition rather than XOR, since (x^x) == 0 but (x+x) != 0. -inline uint64 Hash64CombineUnordered(uint64 a, uint64 b) { return a + b; } - -// Hash functor suitable for use with power-of-two sized hashtables. Use -// instead of std::hash. -// -// In particular, tensorflow::hash is not the identity function for pointers. -// This is important for power-of-two sized hashtables like FlatMap and FlatSet, -// because otherwise they waste the majority of their hash buckets. -// -// The second type argument is only used for SFNIAE below. -template -struct hash { - size_t operator()(const T& t) const { return std::hash()(t); } -}; - -template -struct hash::value>::type> { - size_t operator()(T value) const { - // This works around a defect in the std::hash C++ spec that isn't fixed in - // (at least) gcc 4.8.4: - // http://www.open-std.org/jtc1/sc22/wg21/docs/lwg-defects.html#2148 - // - // We should be able to remove this and use the default - // tensorflow::hash() once we stop building with GCC versions old - // enough to not have this defect fixed. - return std::hash()(static_cast(value)); - } -}; - -template -struct hash { - size_t operator()(const T* t) const { - // Hash pointers as integers, but bring more entropy to the lower bits. - size_t k = static_cast(reinterpret_cast(t)); - return k + (k >> 6); - } -}; - -template <> -struct hash { - size_t operator()(const string& s) const { - return static_cast(Hash64(s)); - } -}; - -template <> -struct hash { - size_t operator()(StringPiece sp) const { - return static_cast(Hash64(sp.data(), sp.size())); - } -}; -using StringPieceHasher = ::tensorflow::hash; - -template -struct hash> { - size_t operator()(const std::pair& p) const { - return Hash64Combine(hash()(p.first), hash()(p.second)); - } -}; - -} // namespace tensorflow +#include "tensorflow/core/platform/hash.h" #endif // TENSORFLOW_CORE_LIB_HASH_HASH_H_ diff --git a/tensorflow/core/lib/histogram/BUILD b/tensorflow/core/lib/histogram/BUILD index 5eba33b0430..9108a09dd15 100644 --- a/tensorflow/core/lib/histogram/BUILD +++ b/tensorflow/core/lib/histogram/BUILD @@ -2,6 +2,8 @@ package( default_visibility = [ # tensorflow/core:lib effectively exposes all targets under tensorflow/core/lib/** "//tensorflow/core:__pkg__", + # tensorflow/core/lib/monitoring:sampler uses histogram + "//tensorflow/core/lib/monitoring:__pkg__", ], licenses = ["notice"], # Apache 2.0 ) diff --git a/tensorflow/core/lib/io/BUILD b/tensorflow/core/lib/io/BUILD index aa4f34d45c5..123e24db3c7 100644 --- a/tensorflow/core/lib/io/BUILD +++ b/tensorflow/core/lib/io/BUILD @@ -4,8 +4,6 @@ package( "//tensorflow/c/experimental/filesystem/plugins/posix:__pkg__", # tensorflow/core:lib effectively exposes all targets under tensorflow/core/lib/** "//tensorflow/core:__pkg__", - # tensorflow/core/platform:env uses :path - "//tensorflow/core/platform:__pkg__", ], licenses = ["notice"], # Apache 2.0 ) @@ -103,15 +101,9 @@ cc_library( cc_library( name = "path", - srcs = ["path.cc"], hdrs = ["path.h"], deps = [ - "//tensorflow/core/platform:logging", - "//tensorflow/core/platform:mutex", - "//tensorflow/core/platform:scanner", - "//tensorflow/core/platform:strcat", - "//tensorflow/core/platform:stringpiece", - "//tensorflow/core/platform:types", + "//tensorflow/core/platform:path", ], alwayslink = True, ) @@ -322,7 +314,6 @@ filegroup( 
"inputbuffer.cc", "inputstream_interface.cc", "iterator.cc", - "path.cc", "random_inputstream.cc", "record_reader.cc", "record_writer.cc", diff --git a/tensorflow/core/lib/io/compression.cc b/tensorflow/core/lib/io/compression.cc index 0aa4caaaef8..116608dcc55 100644 --- a/tensorflow/core/lib/io/compression.cc +++ b/tensorflow/core/lib/io/compression.cc @@ -22,6 +22,7 @@ namespace compression { const char kNone[] = ""; const char kGzip[] = "GZIP"; const char kSnappy[] = "SNAPPY"; +const char kZlib[] = "ZLIB"; } // namespace compression } // namespace io diff --git a/tensorflow/core/lib/io/compression.h b/tensorflow/core/lib/io/compression.h index 10981846d0a..7856ea8bb00 100644 --- a/tensorflow/core/lib/io/compression.h +++ b/tensorflow/core/lib/io/compression.h @@ -23,6 +23,7 @@ namespace compression { extern const char kNone[]; extern const char kGzip[]; extern const char kSnappy[]; +extern const char kZlib[]; } // namespace compression } // namespace io diff --git a/tensorflow/core/lib/io/path.h b/tensorflow/core/lib/io/path.h index 7cfd9809fdb..f5deacd1026 100644 --- a/tensorflow/core/lib/io/path.h +++ b/tensorflow/core/lib/io/path.h @@ -16,83 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_IO_PATH_H_ #define TENSORFLOW_CORE_LIB_IO_PATH_H_ -#include "tensorflow/core/platform/stringpiece.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { -namespace io { -namespace internal { -string JoinPathImpl(std::initializer_list paths); -} - -// Utility routines for processing filenames - -#ifndef SWIG // variadic templates -// Join multiple paths together, without introducing unnecessary path -// separators. -// For example: -// -// Arguments | JoinPath -// ---------------------------+---------- -// '/foo', 'bar' | /foo/bar -// '/foo/', 'bar' | /foo/bar -// '/foo', '/bar' | /foo/bar -// -// Usage: -// string path = io::JoinPath("/mydir", filename); -// string path = io::JoinPath(FLAGS_test_srcdir, filename); -// string path = io::JoinPath("/full", "path", "to", "filename"); -template -string JoinPath(const T&... args) { - return internal::JoinPathImpl({args...}); -} -#endif /* SWIG */ - -// Return true if path is absolute. -bool IsAbsolutePath(tensorflow::StringPiece path); - -// Returns the part of the path before the final "/". If there is a single -// leading "/" in the path, the result will be the leading "/". If there is -// no "/" in the path, the result is the empty prefix of the input. -tensorflow::StringPiece Dirname(tensorflow::StringPiece path); - -// Returns the part of the path after the final "/". If there is no -// "/" in the path, the result is the same as the input. -tensorflow::StringPiece Basename(tensorflow::StringPiece path); - -// Returns the part of the basename of path after the final ".". If -// there is no "." in the basename, the result is empty. -tensorflow::StringPiece Extension(tensorflow::StringPiece path); - -// Collapse duplicate "/"s, resolve ".." and "." path elements, remove -// trailing "/". -// -// NOTE: This respects relative vs. absolute paths, but does not -// invoke any system calls (getcwd(2)) in order to resolve relative -// paths with respect to the actual working directory. That is, this is purely -// string manipulation, completely independent of process state. -string CleanPath(tensorflow::StringPiece path); - -// Populates the scheme, host, and path from a URI. scheme, host, and path are -// guaranteed by this function to point into the contents of uri, even if -// empty. 
-// -// Corner cases: -// - If the URI is invalid, scheme and host are set to empty strings and the -// passed string is assumed to be a path -// - If the URI omits the path (e.g. file://host), then the path is left empty. -void ParseURI(tensorflow::StringPiece uri, tensorflow::StringPiece* scheme, - tensorflow::StringPiece* host, tensorflow::StringPiece* path); - -// Creates a URI from a scheme, host, and path. If the scheme is empty, we just -// return the path. -string CreateURI(tensorflow::StringPiece scheme, tensorflow::StringPiece host, - tensorflow::StringPiece path); - -// Creates a temporary file name with an extension. -string GetTempFilename(const string& extension); - -} // namespace io -} // namespace tensorflow +#include "tensorflow/core/platform/path.h" #endif // TENSORFLOW_CORE_LIB_IO_PATH_H_ diff --git a/tensorflow/core/lib/io/record_reader.cc b/tensorflow/core/lib/io/record_reader.cc index 2c24a74f54b..1af81bd902c 100644 --- a/tensorflow/core/lib/io/record_reader.cc +++ b/tensorflow/core/lib/io/record_reader.cc @@ -31,7 +31,7 @@ namespace io { RecordReaderOptions RecordReaderOptions::CreateRecordReaderOptions( const string& compression_type) { RecordReaderOptions options; - if (compression_type == "ZLIB") { + if (compression_type == compression::kZlib) { options.compression_type = io::RecordReaderOptions::ZLIB_COMPRESSION; #if defined(IS_SLIM_BUILD) LOG(ERROR) << "Compression is not supported but compression_type is set." diff --git a/tensorflow/core/lib/io/record_writer.cc b/tensorflow/core/lib/io/record_writer.cc index 405e65a2a6a..52d0ef9a358 100644 --- a/tensorflow/core/lib/io/record_writer.cc +++ b/tensorflow/core/lib/io/record_writer.cc @@ -31,7 +31,7 @@ bool IsZlibCompressed(RecordWriterOptions options) { RecordWriterOptions RecordWriterOptions::CreateRecordWriterOptions( const string& compression_type) { RecordWriterOptions options; - if (compression_type == "ZLIB") { + if (compression_type == compression::kZlib) { options.compression_type = io::RecordWriterOptions::ZLIB_COMPRESSION; #if defined(IS_SLIM_BUILD) LOG(ERROR) << "Compression is not supported but compression_type is set." diff --git a/tensorflow/core/lib/monitoring/BUILD b/tensorflow/core/lib/monitoring/BUILD new file mode 100644 index 00000000000..35c59079231 --- /dev/null +++ b/tensorflow/core/lib/monitoring/BUILD @@ -0,0 +1,195 @@ +package( + default_visibility = [ + # tensorflow/core:lib effectively exposes all targets under tensorflow/core/lib/** + "//tensorflow/core:__pkg__", + # tensorflow/core/platform:monitoring depends on this package + "//tensorflow/core/platform:__subpackages__", + ], + licenses = ["notice"], # Apache 2.0 +) + +# Todo(bmzhao): Remaining targets to add are: all tests. 
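The counter/gauge/sampler targets declared below each pair a full implementation with a mobile_* header. The usual intent of that split (sketched here with invented names and an invented macro; this is not the TF monitoring API) is to compile a no-op metric class on mobile builds so call sites stay unchanged while the bookkeeping overhead disappears:

#include <atomic>

// Hypothetical illustration only: a full counter vs. a mobile no-op stub with
// the same interface, selected at compile time via an assumed build macro.
#if defined(EXAMPLE_MOBILE_BUILD)
class ExampleCounter {
 public:
  void IncrementBy(long delta) { (void)delta; }  // Metrics disabled on mobile.
  long value() const { return 0; }
};
#else
class ExampleCounter {
 public:
  void IncrementBy(long delta) { value_.fetch_add(delta); }
  long value() const { return value_.load(); }

 private:
  std::atomic<long> value_{0};
};
#endif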
+ +cc_library( + name = "collected_metrics", + hdrs = ["collected_metrics.h"], + deps = [ + ":metric_def", + "//tensorflow/core/framework:summary_proto_cc", + ], +) + +cc_library( + name = "collection_registry", + srcs = ["collection_registry.cc"], + hdrs = ["collection_registry.h"], + deps = [ + ":collected_metrics", + ":metric_def", + "//tensorflow/core/framework:summary_proto_cc", + "//tensorflow/core/platform:env", + "//tensorflow/core/platform:logging", + "//tensorflow/core/platform:macros", + "//tensorflow/core/platform:mutex", + "//tensorflow/core/platform:stringpiece", + "//tensorflow/core/platform:thread_annotations", + "//tensorflow/core/platform:types", + ], +) + +cc_library( + name = "counter", + hdrs = ["counter.h"], + deps = [ + ":collection_registry", + ":metric_def", + ":mobile_counter", + "//tensorflow/core/lib/core:status", + "//tensorflow/core/platform", + "//tensorflow/core/platform:logging", + "//tensorflow/core/platform:macros", + "//tensorflow/core/platform:mutex", + "//tensorflow/core/platform:thread_annotations", + ], +) + +cc_library( + name = "gauge", + hdrs = ["gauge.h"], + deps = [ + ":collection_registry", + ":metric_def", + ":mobile_gauge", + "//tensorflow/core/lib/core:status", + "//tensorflow/core/platform", + "//tensorflow/core/platform:macros", + "//tensorflow/core/platform:mutex", + "//tensorflow/core/platform:thread_annotations", + "//tensorflow/core/platform:types", + ], +) + +cc_library( + name = "metric_def", + hdrs = ["metric_def.h"], + deps = [ + "//tensorflow/core/framework:summary_proto_cc", + "//tensorflow/core/platform:stringpiece", + "//tensorflow/core/platform:types", + ], +) + +cc_library( + name = "mobile_counter", + hdrs = ["mobile_counter.h"], + deps = [ + "//tensorflow/core/lib/core:status", + "//tensorflow/core/platform:macros", + "//tensorflow/core/platform:types", + ], +) + +cc_library( + name = "mobile_gauge", + hdrs = ["mobile_gauge.h"], + deps = [ + "//tensorflow/core/lib/core:status", + "//tensorflow/core/platform:macros", + "//tensorflow/core/platform:types", + ], +) + +cc_library( + name = "mobile_sampler", + hdrs = ["mobile_sampler.h"], + deps = [ + ":metric_def", + "//tensorflow/core/framework:summary_proto_cc", + "//tensorflow/core/lib/core:status", + "//tensorflow/core/platform:macros", + "//tensorflow/core/platform:types", + ], +) + +cc_library( + name = "sampler", + srcs = ["sampler.cc"], + hdrs = ["sampler.h"], + deps = [ + ":collection_registry", + ":metric_def", + ":mobile_sampler", + "//tensorflow/core/framework:summary_proto_cc", + "//tensorflow/core/lib/core:status", + "//tensorflow/core/lib/histogram", + "//tensorflow/core/platform", + "//tensorflow/core/platform:macros", + "//tensorflow/core/platform:mutex", + "//tensorflow/core/platform:thread_annotations", + ], +) + +filegroup( + name = "legacy_lib_monitoring_lib_headers", + srcs = [ + "collected_metrics.h", + "collection_registry.h", + "counter.h", + "gauge.h", + "metric_def.h", + "sampler.h", + ], + visibility = ["//tensorflow/core:__pkg__"], +) + +filegroup( + name = "legacy_lib_monitoring_lib_internal_public_headers", + srcs = [ + "mobile_counter.h", + "mobile_gauge.h", + "mobile_sampler.h", + ], + visibility = ["//tensorflow/core:__pkg__"], +) + +filegroup( + name = "legacy_lib_monitoring_all_headers", + srcs = [ + "collected_metrics.h", + "collection_registry.h", + "counter.h", + "gauge.h", + "metric_def.h", + "mobile_counter.h", + "mobile_gauge.h", + "mobile_sampler.h", + "sampler.h", + ], + visibility = ["//tensorflow/core:__pkg__"], +) + +filegroup( + 
name = "legacy_lib_monitoring_all_srcs", + srcs = [ + "collection_registry.cc", + "sampler.cc", + ], + visibility = ["//tensorflow/core:__pkg__"], +) + +# Note(bmzhao): Ideally we would use a filegroup to represent these tests instead. +# However, that causes tf_cc_tests to link all of these tests into a single object +# file. This breaks collection_registry_test, because sample_test.cc has static variables +# that instantiate metrics with the same names that collection_registry_test tries +# to create ("/tensorflow/test/sampler_with_labels" and +# "/tensorflow/test/sampler_without_labels"). +exports_files( + [ + "collection_registry_test.cc", + "counter_test.cc", + "gauge_test.cc", + "metric_def_test.cc", + "sampler_test.cc", + ], + visibility = ["//tensorflow/core:__pkg__"], +) diff --git a/tensorflow/core/lib/monitoring/collection_registry.h b/tensorflow/core/lib/monitoring/collection_registry.h index 9e4e1989dd8..b3db7079d12 100644 --- a/tensorflow/core/lib/monitoring/collection_registry.h +++ b/tensorflow/core/lib/monitoring/collection_registry.h @@ -20,13 +20,13 @@ limitations under the License. #include #include "tensorflow/core/framework/summary.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/monitoring/collected_metrics.h" #include "tensorflow/core/lib/monitoring/metric_def.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/core/lib/monitoring/metric_def.h b/tensorflow/core/lib/monitoring/metric_def.h index bc4365e439c..84b915f360c 100644 --- a/tensorflow/core/lib/monitoring/metric_def.h +++ b/tensorflow/core/lib/monitoring/metric_def.h @@ -20,7 +20,7 @@ limitations under the License. #include #include "tensorflow/core/framework/summary.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { diff --git a/tensorflow/core/lib/random/BUILD b/tensorflow/core/lib/random/BUILD index c9d48689849..7360e72f233 100644 --- a/tensorflow/core/lib/random/BUILD +++ b/tensorflow/core/lib/random/BUILD @@ -64,11 +64,9 @@ cc_library( cc_library( name = "random", - srcs = ["random.cc"], hdrs = ["random.h"], deps = [ - "//tensorflow/core/platform:mutex", - "//tensorflow/core/platform:types", + "//tensorflow/core/platform:random", ], ) @@ -133,7 +131,6 @@ filegroup( name = "legacy_lib_random_all_srcs", srcs = [ "distribution_sampler.cc", - "random.cc", "random_distributions.cc", "simple_philox.cc", "weighted_picker.cc", diff --git a/tensorflow/core/lib/random/random.h b/tensorflow/core/lib/random/random.h index 5335c8cc3c9..e280d98d551 100644 --- a/tensorflow/core/lib/random/random.h +++ b/tensorflow/core/lib/random/random.h @@ -16,20 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_LIB_RANDOM_RANDOM_H_ #define TENSORFLOW_LIB_RANDOM_RANDOM_H_ -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { -namespace random { - -// Return a 64-bit random value. Different sequences are generated -// in different processes. -uint64 New64(); - -// Return a 64-bit random value. Uses -// std::mersenne_twister_engine::default_seed as seed value. 
-uint64 New64DefaultSeed(); - -} // namespace random -} // namespace tensorflow +#include "tensorflow/core/platform/random.h" #endif // TENSORFLOW_LIB_RANDOM_RANDOM_H_ diff --git a/tensorflow/core/lib/strings/BUILD b/tensorflow/core/lib/strings/BUILD index 598a8bc5a47..31425aabc10 100644 --- a/tensorflow/core/lib/strings/BUILD +++ b/tensorflow/core/lib/strings/BUILD @@ -14,11 +14,9 @@ package( cc_library( name = "base64", - srcs = ["base64.cc"], hdrs = ["base64.h"], deps = [ - "//tensorflow/core/lib/core:errors", - "//tensorflow/core/lib/core:status", + "//tensorflow/core/platform:base64", ], ) @@ -113,7 +111,6 @@ filegroup( filegroup( name = "legacy_lib_strings_all_srcs", srcs = [ - "base64.cc", "ordered_code.cc", "proto_serialization.cc", "proto_text_util.cc", diff --git a/tensorflow/core/lib/strings/base64.h b/tensorflow/core/lib/strings/base64.h index 15a273b36a9..bb7cbfb3777 100644 --- a/tensorflow/core/lib/strings/base64.h +++ b/tensorflow/core/lib/strings/base64.h @@ -16,43 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_STRINGS_BASE64_H_ #define TENSORFLOW_CORE_LIB_STRINGS_BASE64_H_ -#include -#include "tensorflow/core/lib/core/status.h" - -namespace tensorflow { - -/// \brief Converts data into web-safe base64 encoding. -/// -/// See https://en.wikipedia.org/wiki/Base64 -template -Status Base64Encode(StringPiece source, bool with_padding, T* encoded); -template -Status Base64Encode(StringPiece source, - T* encoded); // with_padding=false. - -/// \brief Converts data from web-safe base64 encoding. -/// -/// See https://en.wikipedia.org/wiki/Base64 -template -Status Base64Decode(StringPiece data, T* decoded); - -// Explicit instantiations defined in base64.cc. -extern template Status Base64Decode(StringPiece data, string* decoded); -extern template Status Base64Encode(StringPiece source, - string* encoded); -extern template Status Base64Encode(StringPiece source, - bool with_padding, string* encoded); - -#ifdef USE_TSTRING -extern template Status Base64Decode(StringPiece data, - tstring* decoded); -extern template Status Base64Encode(StringPiece source, - tstring* encoded); -extern template Status Base64Encode(StringPiece source, - bool with_padding, - tstring* encoded); -#endif // USE_TSTRING - -} // namespace tensorflow +#include "tensorflow/core/platform/base64.h" #endif // TENSORFLOW_CORE_LIB_STRINGS_BASE64_H_ diff --git a/tensorflow/core/nccl/BUILD b/tensorflow/core/nccl/BUILD index c449b7290df..35157bad58f 100644 --- a/tensorflow/core/nccl/BUILD +++ b/tensorflow/core/nccl/BUILD @@ -42,6 +42,7 @@ cc_library( "//tensorflow/core:gpu_headers_lib", "//tensorflow/core:lib", "//tensorflow/core:stream_executor", + "//tensorflow/core/profiler/lib:traceme", ]), alwayslink = 1, ) diff --git a/tensorflow/core/nccl/nccl_manager.cc b/tensorflow/core/nccl/nccl_manager.cc index 133772a96df..aadd2a00f3c 100644 --- a/tensorflow/core/nccl/nccl_manager.cc +++ b/tensorflow/core/nccl/nccl_manager.cc @@ -21,8 +21,9 @@ limitations under the License. 
#include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/profiler/lib/traceme.h" #if GOOGLE_CUDA -#include "tensorflow/core/platform/cuda.h" +#include "tensorflow/stream_executor/cuda/cuda_activation.h" #elif TENSORFLOW_USE_ROCM #include "tensorflow/core/platform/rocm.h" #endif diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD index 4a09e0dfd1e..8e6fd49d1ab 100644 --- a/tensorflow/core/platform/BUILD +++ b/tensorflow/core/platform/BUILD @@ -70,6 +70,7 @@ tf_instantiate_platform_libraries(names = [ "strong_hash", "subprocess", "test", + "test_benchmark", "tracing", "types", "unbounded_work_queue", @@ -83,11 +84,42 @@ cc_library( deps = [":types"], ) +cc_library( + name = "base64", + srcs = ["base64.cc"], + hdrs = ["base64.h"], + deps = [ + ":errors", + ":status", + ], +) + +cc_library( + name = "blocking_counter", + hdrs = ["blocking_counter.h"], + deps = [ + ":logging", + ":mutex", + ], +) + cc_library( name = "byte_order", hdrs = ["byte_order.h"], ) +cc_library( + name = "coding", + srcs = ["coding.cc"], + hdrs = ["coding.h"], + deps = [ + ":byte_order", + ":raw_coding", + ":stringpiece", + ":types", + ], +) + cc_library( name = "context", textual_hdrs = ["context.h"], @@ -173,9 +205,9 @@ cc_library( hdrs = ["error.h"], deps = [ ":platform", + ":status", + ":strcat", ":types", - "//tensorflow/core/lib/core:status", - "//tensorflow/core/platform:strcat", ], ) @@ -215,6 +247,18 @@ cc_library( ], ) +cc_library( + name = "hash", + srcs = ["hash.cc"], + hdrs = ["hash.h"], + deps = [ + ":macros", + ":raw_coding", + ":stringpiece", + ":types", + ], +) + cc_library( name = "human_readable_json", textual_hdrs = ["human_readable_json.h"], @@ -303,6 +347,21 @@ cc_library( ], ) +cc_library( + name = "path", + srcs = ["path.cc"], + hdrs = ["path.h"], + deps = [ + ":logging", + ":mutex", + ":scanner", + ":strcat", + ":stringpiece", + ":types", + ], + alwayslink = True, +) + cc_library( name = "platform", hdrs = ["platform.h"], @@ -369,13 +428,23 @@ cc_library( name = "protobuf_internal", hdrs = ["protobuf_internal.h"], deps = [ + ":errors", ":platform", ":protobuf", ":types", - "//tensorflow/core/lib/core:errors", ] + if_static(["@com_google_protobuf//:protobuf"]), ) +cc_library( + name = "random", + srcs = ["random.cc"], + hdrs = ["random.h"], + deps = [ + ":mutex", + ":types", + ], +) + cc_library( name = "raw_coding", hdrs = ["raw_coding.h"], @@ -473,7 +542,6 @@ cc_library( ":logging", ":macros", ":mutex", - ":stacktrace", ":str_util", ":strcat", ":stringpiece", @@ -551,13 +619,13 @@ cc_library( srcs = ["tensor_coding.cc"], hdrs = ["tensor_coding.h"], deps = [ + ":coding", ":platform", ":protobuf", + ":refcount", ":stringpiece", ":strcat", ":types", - "//tensorflow/core/lib/core:coding", - "//tensorflow/core/lib/core:refcount", ] + tf_additional_tensor_coding_deps(), ) @@ -568,6 +636,16 @@ cc_library( deps = tf_mobile_aware_deps("test_impl"), ) +cc_library( + name = "test_benchmark", + testonly = True, + hdrs = ["test_benchmark.h"], + deps = [ + ":platform", + ":test_benchmark_impl", + ], +) + cc_library( name = "thread_annotations", hdrs = ["thread_annotations.h"], @@ -710,10 +788,14 @@ filegroup( "**/unbounded_work_queue.cc", "**/windows_file_system.cc", "abi.cc", + "coding.cc", "cpu_info.cc", + "hash.cc", "numbers.cc", + "path.cc", "platform_strings.cc", "protobuf.cc", + "random.cc", "scanner.cc", "strcat.cc", "stringprintf.cc", @@ -813,6 +895,7 @@ filegroup( 
"**/human_readable_json.cc", "**/rocm_rocdl_path.cc", "abi.cc", + "coding.cc", "cpu_info.cc", "cpu_feature_guard.cc", "denormal.cc", @@ -820,11 +903,14 @@ filegroup( "error.cc", "file_system.cc", "file_system_helper.cc", + "hash.cc", "logger.cc", "numbers.cc", + "path.cc", "platform_strings.cc", "protobuf.cc", "protobuf_util.cc", + "random.cc", "scanner.cc", "setround.cc", "status.cc", diff --git a/tensorflow/core/lib/strings/base64.cc b/tensorflow/core/platform/base64.cc similarity index 98% rename from tensorflow/core/lib/strings/base64.cc rename to tensorflow/core/platform/base64.cc index 80eec3a9403..0ff690f1b32 100644 --- a/tensorflow/core/lib/strings/base64.cc +++ b/tensorflow/core/platform/base64.cc @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/lib/strings/base64.h" +#include "tensorflow/core/platform/base64.h" #include #include -#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/errors.h" namespace tensorflow { namespace { diff --git a/tensorflow/core/platform/base64.h b/tensorflow/core/platform/base64.h new file mode 100644 index 00000000000..7b764732dc9 --- /dev/null +++ b/tensorflow/core/platform/base64.h @@ -0,0 +1,58 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_BASE64_H_ +#define TENSORFLOW_CORE_PLATFORM_BASE64_H_ + +#include +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +/// \brief Converts data into web-safe base64 encoding. +/// +/// See https://en.wikipedia.org/wiki/Base64 +template +Status Base64Encode(StringPiece source, bool with_padding, T* encoded); +template +Status Base64Encode(StringPiece source, + T* encoded); // with_padding=false. + +/// \brief Converts data from web-safe base64 encoding. +/// +/// See https://en.wikipedia.org/wiki/Base64 +template +Status Base64Decode(StringPiece data, T* decoded); + +// Explicit instantiations defined in base64.cc. 
+extern template Status Base64Decode<string>(StringPiece data, string* decoded);
+extern template Status Base64Encode<string>(StringPiece source,
+                                            string* encoded);
+extern template Status Base64Encode<string>(StringPiece source,
+                                            bool with_padding, string* encoded);
+
+#ifdef USE_TSTRING
+extern template Status Base64Decode<tstring>(StringPiece data,
+                                             tstring* decoded);
+extern template Status Base64Encode<tstring>(StringPiece source,
+                                             tstring* encoded);
+extern template Status Base64Encode<tstring>(StringPiece source,
+                                             bool with_padding,
+                                             tstring* encoded);
+#endif  // USE_TSTRING
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_PLATFORM_BASE64_H_
diff --git a/tensorflow/core/platform/blocking_counter.h b/tensorflow/core/platform/blocking_counter.h
new file mode 100644
index 00000000000..9e7ca004024
--- /dev/null
+++ b/tensorflow/core/platform/blocking_counter.h
@@ -0,0 +1,80 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_BLOCKING_COUNTER_H_
+#define TENSORFLOW_CORE_PLATFORM_BLOCKING_COUNTER_H_
+
+#include <atomic>
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+class BlockingCounter {
+ public:
+  BlockingCounter(int initial_count)
+      : state_(initial_count << 1), notified_(false) {
+    CHECK_GE(initial_count, 0);
+    DCHECK_EQ((initial_count << 1) >> 1, initial_count);
+  }
+
+  ~BlockingCounter() {}
+
+  inline void DecrementCount() {
+    unsigned int v = state_.fetch_sub(2, std::memory_order_acq_rel) - 2;
+    if (v != 1) {
+      DCHECK_NE(((v + 2) & ~1), 0);
+      return;  // either count has not dropped to 0, or waiter is not waiting
+    }
+    mutex_lock l(mu_);
+    DCHECK(!notified_);
+    notified_ = true;
+    cond_var_.notify_all();
+  }
+
+  inline void Wait() {
+    unsigned int v = state_.fetch_or(1, std::memory_order_acq_rel);
+    if ((v >> 1) == 0) return;
+    mutex_lock l(mu_);
+    while (!notified_) {
+      cond_var_.wait(l);
+    }
+  }
+  // Wait for the specified time, return false iff the count has not dropped to
+  // zero before the timeout expired.
+ inline bool WaitFor(std::chrono::milliseconds ms) { + unsigned int v = state_.fetch_or(1, std::memory_order_acq_rel); + if ((v >> 1) == 0) return true; + mutex_lock l(mu_); + while (!notified_) { + const std::cv_status status = cond_var_.wait_for(l, ms); + if (status == std::cv_status::timeout) { + return false; + } + } + return true; + } + + private: + mutex mu_; + condition_variable cond_var_; + std::atomic state_; // low bit is waiter flag + bool notified_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_BLOCKING_COUNTER_H_ diff --git a/tensorflow/core/platform/cloud/BUILD b/tensorflow/core/platform/cloud/BUILD index 27321b3be0e..7b194e78911 100644 --- a/tensorflow/core/platform/cloud/BUILD +++ b/tensorflow/core/platform/cloud/BUILD @@ -35,7 +35,10 @@ cc_library( name = "file_block_cache", hdrs = ["file_block_cache.h"], copts = tf_copts(), - deps = ["//tensorflow/core:lib"], + deps = [ + "//tensorflow/core:lib", + "//tensorflow/core/platform:stringpiece", + ], ) cc_library( @@ -47,6 +50,7 @@ cc_library( deps = [ ":file_block_cache", "//tensorflow/core:lib", + "//tensorflow/core/platform:stringpiece", ], ) @@ -95,6 +99,10 @@ cc_library( "//tensorflow/core:framework_headers_lib", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/platform:numbers", + "//tensorflow/core/platform:path", + "//tensorflow/core/platform:str_util", + "//tensorflow/core/platform:stringprintf", "@jsoncpp_git//:jsoncpp", ], alwayslink = 1, @@ -127,6 +135,10 @@ cc_library( "//tensorflow/core:framework_headers_lib", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/platform:numbers", + "//tensorflow/core/platform:path", + "//tensorflow/core/platform:str_util", + "//tensorflow/core/platform:stringprintf", "@jsoncpp_git//:jsoncpp", ], alwayslink = 1, @@ -139,6 +151,7 @@ cc_library( deps = [ "//tensorflow/core:framework_headers_lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/platform:stringpiece", ], ) @@ -151,6 +164,9 @@ cc_library( ":http_request", "//tensorflow/core:framework_headers_lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/platform:scanner", + "//tensorflow/core/platform:str_util", + "//tensorflow/core/platform:stringpiece", "@curl", ], ) @@ -167,6 +183,9 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:test", + "//tensorflow/core/platform:errors", + "//tensorflow/core/platform:status", + "//tensorflow/core/platform:stringpiece", "@curl", ], ) @@ -185,6 +204,10 @@ cc_library( ":retrying_utils", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/platform:base64", + "//tensorflow/core/platform:errors", + "//tensorflow/core/platform:path", + "//tensorflow/core/platform:status", "@com_google_absl//absl/strings", "@jsoncpp_git//:jsoncpp", ], @@ -222,6 +245,9 @@ cc_library( ":compute_engine_metadata_client", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/platform:errors", + "//tensorflow/core/platform:status", + "//tensorflow/core/platform:str_util", ], ) @@ -250,6 +276,9 @@ cc_library( ":http_request", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/platform:base64", + "//tensorflow/core/platform:errors", + "//tensorflow/core/platform:status", "@boringssl//:crypto", "@jsoncpp_git//:jsoncpp", ], @@ -322,6 +351,7 @@ tf_cc_test( "//tensorflow/core:lib_internal", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/platform:blocking_counter", ], ) @@ 
-335,6 +365,8 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/platform:errors", + "//tensorflow/core/platform:str_util", ], ) @@ -348,6 +380,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/platform:str_util", ], ) @@ -361,6 +394,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/platform:str_util", ], ) @@ -373,6 +407,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/platform:path", ], ) @@ -391,6 +426,9 @@ tf_cc_test( "//tensorflow/core:lib_internal", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/platform:base64", + "//tensorflow/core/platform:path", + "//tensorflow/core/platform:scanner", "@boringssl//:crypto", ], ) @@ -411,6 +449,7 @@ tf_cc_test( "//tensorflow/core:lib_internal", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/platform:path", ], ) @@ -452,6 +491,7 @@ tf_cc_test( "//tensorflow/core:lib_internal", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/platform:str_util", ], ) @@ -476,5 +516,6 @@ tf_cc_test( "//tensorflow/core:lib_internal", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/platform:str_util", ], ) diff --git a/tensorflow/core/platform/cloud/auth_provider.h b/tensorflow/core/platform/cloud/auth_provider.h index 7347bc626d8..954c861169b 100644 --- a/tensorflow/core/platform/cloud/auth_provider.h +++ b/tensorflow/core/platform/cloud/auth_provider.h @@ -17,8 +17,9 @@ limitations under the License. #define TENSORFLOW_CORE_PLATFORM_CLOUD_AUTH_PROVIDER_H_ #include -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" + +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" namespace tensorflow { diff --git a/tensorflow/core/platform/cloud/compute_engine_metadata_client.h b/tensorflow/core/platform/cloud/compute_engine_metadata_client.h index 7f060327da5..d7611615606 100644 --- a/tensorflow/core/platform/cloud/compute_engine_metadata_client.h +++ b/tensorflow/core/platform/cloud/compute_engine_metadata_client.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_COMPUTE_ENGINE_METADATA_CLIENT_H_ #define TENSORFLOW_CORE_PLATFORM_CLOUD_COMPUTE_ENGINE_METADATA_CLIENT_H_ -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/cloud/http_request.h" #include "tensorflow/core/platform/cloud/retrying_utils.h" +#include "tensorflow/core/platform/status.h" namespace tensorflow { diff --git a/tensorflow/core/platform/cloud/compute_engine_zone_provider.cc b/tensorflow/core/platform/cloud/compute_engine_zone_provider.cc index e147d883710..8008c9cc9ec 100644 --- a/tensorflow/core/platform/cloud/compute_engine_zone_provider.cc +++ b/tensorflow/core/platform/cloud/compute_engine_zone_provider.cc @@ -16,7 +16,8 @@ limitations under the License. 
#include "tensorflow/core/platform/cloud/compute_engine_zone_provider.h" #include -#include "tensorflow/core/lib/strings/str_util.h" + +#include "tensorflow/core/platform/str_util.h" namespace tensorflow { namespace { diff --git a/tensorflow/core/platform/cloud/curl_http_request.cc b/tensorflow/core/platform/cloud/curl_http_request.cc index c64e215ea99..b3646eba391 100644 --- a/tensorflow/core/platform/cloud/curl_http_request.cc +++ b/tensorflow/core/platform/cloud/curl_http_request.cc @@ -13,15 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - #include "tensorflow/core/platform/cloud/curl_http_request.h" +#include + #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/lib/strings/scanner.h" -#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/cloud/curl_http_request.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/scanner.h" +#include "tensorflow/core/platform/str_util.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/public/version.h" diff --git a/tensorflow/core/platform/cloud/curl_http_request.h b/tensorflow/core/platform/cloud/curl_http_request.h index 9ad75e52f20..b8e9aeb3399 100644 --- a/tensorflow/core/platform/cloud/curl_http_request.h +++ b/tensorflow/core/platform/cloud/curl_http_request.h @@ -19,14 +19,15 @@ limitations under the License. #include #include #include + #include -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/cloud/http_request.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { diff --git a/tensorflow/core/platform/cloud/curl_http_request_test.cc b/tensorflow/core/platform/cloud/curl_http_request_test.cc index e31901a7a0f..754f3e4b4b9 100644 --- a/tensorflow/core/platform/cloud/curl_http_request_test.cc +++ b/tensorflow/core/platform/cloud/curl_http_request_test.cc @@ -14,10 +14,12 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/platform/cloud/curl_http_request.h" + #include + #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/mem.h" +#include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { diff --git a/tensorflow/core/platform/cloud/file_block_cache.h b/tensorflow/core/platform/cloud/file_block_cache.h index c98b10640fa..d2453016a1c 100644 --- a/tensorflow/core/platform/cloud/file_block_cache.h +++ b/tensorflow/core/platform/cloud/file_block_cache.h @@ -22,11 +22,12 @@ limitations under the License. 
#include #include #include -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" + #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/notification.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/core/platform/cloud/gcs_dns_cache_test.cc b/tensorflow/core/platform/cloud/gcs_dns_cache_test.cc index 77850906c6c..09644767152 100644 --- a/tensorflow/core/platform/cloud/gcs_dns_cache_test.cc +++ b/tensorflow/core/platform/cloud/gcs_dns_cache_test.cc @@ -14,7 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/platform/cloud/gcs_dns_cache.h" -#include "tensorflow/core/lib/strings/str_util.h" + +#include "tensorflow/core/platform/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc index 55bbec7cb88..b6b988047c8 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system.cc @@ -27,12 +27,7 @@ limitations under the License. #endif #include "absl/base/macros.h" #include "include/json/json.h" -#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/cloud/curl_http_request.h" #include "tensorflow/core/platform/cloud/file_block_cache.h" #include "tensorflow/core/platform/cloud/google_auth_provider.h" @@ -40,8 +35,13 @@ limitations under the License. #include "tensorflow/core/platform/cloud/retrying_utils.h" #include "tensorflow/core/platform/cloud/time_util.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/numbers.h" +#include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/str_util.h" +#include "tensorflow/core/platform/stringprintf.h" #include "tensorflow/core/platform/thread_annotations.h" #ifdef _WIN32 diff --git a/tensorflow/core/platform/cloud/gcs_file_system.h b/tensorflow/core/platform/cloud/gcs_file_system.h index 0bd95f7c6b6..a4d3bcc8f05 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.h +++ b/tensorflow/core/platform/cloud/gcs_file_system.h @@ -21,7 +21,6 @@ limitations under the License. #include #include -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/cloud/auth_provider.h" #include "tensorflow/core/platform/cloud/compute_engine_metadata_client.h" #include "tensorflow/core/platform/cloud/compute_engine_zone_provider.h" @@ -32,6 +31,7 @@ limitations under the License. 
#include "tensorflow/core/platform/cloud/http_request.h" #include "tensorflow/core/platform/cloud/retrying_file_system.h" #include "tensorflow/core/platform/file_system.h" +#include "tensorflow/core/platform/status.h" namespace tensorflow { diff --git a/tensorflow/core/platform/cloud/gcs_file_system_test.cc b/tensorflow/core/platform/cloud/gcs_file_system_test.cc index 566ad45a43c..71121afbd98 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system_test.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system_test.cc @@ -14,11 +14,14 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/platform/cloud/gcs_file_system.h" + #include + #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/cloud/http_request_fake.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { diff --git a/tensorflow/core/platform/cloud/gcs_throttle_test.cc b/tensorflow/core/platform/cloud/gcs_throttle_test.cc index e8eebc5fbc3..404e922502d 100644 --- a/tensorflow/core/platform/cloud/gcs_throttle_test.cc +++ b/tensorflow/core/platform/cloud/gcs_throttle_test.cc @@ -14,8 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/platform/cloud/gcs_throttle.h" + #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { diff --git a/tensorflow/core/platform/cloud/google_auth_provider.cc b/tensorflow/core/platform/cloud/google_auth_provider.cc index e91a9f89757..b8d2acd83ff 100644 --- a/tensorflow/core/platform/cloud/google_auth_provider.cc +++ b/tensorflow/core/platform/cloud/google_auth_provider.cc @@ -22,13 +22,14 @@ limitations under the License. #endif #include #include + #include "absl/strings/match.h" #include "include/json/json.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/base64.h" +#include "tensorflow/core/platform/base64.h" #include "tensorflow/core/platform/cloud/retrying_utils.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/path.h" namespace tensorflow { diff --git a/tensorflow/core/platform/cloud/google_auth_provider_test.cc b/tensorflow/core/platform/cloud/google_auth_provider_test.cc index 8c7e107037a..5bee2072034 100644 --- a/tensorflow/core/platform/cloud/google_auth_provider_test.cc +++ b/tensorflow/core/platform/cloud/google_auth_provider_test.cc @@ -14,10 +14,12 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/core/platform/cloud/google_auth_provider.h" + #include + #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/cloud/http_request_fake.h" +#include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { diff --git a/tensorflow/core/platform/cloud/http_request.h b/tensorflow/core/platform/cloud/http_request.h index e925eefb1f2..5681293915f 100644 --- a/tensorflow/core/platform/cloud/http_request.h +++ b/tensorflow/core/platform/cloud/http_request.h @@ -19,12 +19,13 @@ limitations under the License. #include #include #include -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" + #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { diff --git a/tensorflow/core/platform/cloud/http_request_fake.h b/tensorflow/core/platform/cloud/http_request_fake.h index 0a1164b64a7..9564aa7d30b 100644 --- a/tensorflow/core/platform/cloud/http_request_fake.h +++ b/tensorflow/core/platform/cloud/http_request_fake.h @@ -20,14 +20,15 @@ limitations under the License. #include #include #include + #include -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/cloud/curl_http_request.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/core/platform/cloud/oauth_client.cc b/tensorflow/core/platform/cloud/oauth_client.cc index 89b1056be7d..bd4b3ae0b5c 100644 --- a/tensorflow/core/platform/cloud/oauth_client.cc +++ b/tensorflow/core/platform/cloud/oauth_client.cc @@ -22,14 +22,15 @@ limitations under the License. #include #endif #include + #include #include #include #include -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/base64.h" +#include "tensorflow/core/platform/base64.h" #include "tensorflow/core/platform/cloud/curl_http_request.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/errors.h" namespace tensorflow { diff --git a/tensorflow/core/platform/cloud/oauth_client.h b/tensorflow/core/platform/cloud/oauth_client.h index 519d69acf98..ed8bf257253 100644 --- a/tensorflow/core/platform/cloud/oauth_client.h +++ b/tensorflow/core/platform/cloud/oauth_client.h @@ -17,10 +17,11 @@ limitations under the License. 
#define TENSORFLOW_CORE_PLATFORM_CLOUD_OAUTH_CLIENT_H_ #include + #include "include/json/json.h" -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/cloud/http_request.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/status.h" namespace tensorflow { diff --git a/tensorflow/core/platform/cloud/oauth_client_test.cc b/tensorflow/core/platform/cloud/oauth_client_test.cc index 7b76e4c6c16..8dfff63873f 100644 --- a/tensorflow/core/platform/cloud/oauth_client_test.cc +++ b/tensorflow/core/platform/cloud/oauth_client_test.cc @@ -14,16 +14,18 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/platform/cloud/oauth_client.h" + #include + #include #include #include #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/base64.h" -#include "tensorflow/core/lib/strings/scanner.h" +#include "tensorflow/core/platform/base64.h" #include "tensorflow/core/platform/cloud/http_request_fake.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/scanner.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -33,7 +35,7 @@ constexpr char kTestData[] = "core/platform/cloud/testdata/"; constexpr char kTokenJson[] = R"( { - "access_token":"1/fFAGRNJru1FTz70BzhT3Zg", + "access_token":"WITH_FAKE_ACCESS_TOKEN_TEST_SHOULD_BE_HAPPY", "expires_in":3920, "token_type":"Bearer" })"; @@ -54,7 +56,7 @@ TEST(OAuthClientTest, ParseOAuthResponse) { uint64 expiration_timestamp; TF_EXPECT_OK(OAuthClient().ParseOAuthResponse(kTokenJson, request_timestamp, &token, &expiration_timestamp)); - EXPECT_EQ("1/fFAGRNJru1FTz70BzhT3Zg", token); + EXPECT_EQ("WITH_FAKE_ACCESS_TOKEN_TEST_SHOULD_BE_HAPPY", token); EXPECT_EQ(4020, expiration_timestamp); } @@ -62,7 +64,7 @@ TEST(OAuthClientTest, GetTokenFromRefreshTokenJson) { const string credentials_json = R"( { "client_id": "test_client_id", - "client_secret": "test_client_secret", + "client_secret": "@@@test_client_secret@@@", "refresh_token": "test_refresh_token", "type": "authorized_user" })"; @@ -73,7 +75,7 @@ TEST(OAuthClientTest, GetTokenFromRefreshTokenJson) { std::vector requests({new FakeHttpRequest( "Uri: https://www.googleapis.com/oauth2/v3/token\n" "Post body: client_id=test_client_id&" - "client_secret=test_client_secret&" + "client_secret=@@@test_client_secret@@@&" "refresh_token=test_refresh_token&grant_type=refresh_token\n", kTokenJson)}); FakeEnv env; @@ -85,7 +87,7 @@ TEST(OAuthClientTest, GetTokenFromRefreshTokenJson) { TF_EXPECT_OK(client.GetTokenFromRefreshTokenJson( json, "https://www.googleapis.com/oauth2/v3/token", &token, &expiration_timestamp)); - EXPECT_EQ("1/fFAGRNJru1FTz70BzhT3Zg", token); + EXPECT_EQ("WITH_FAKE_ACCESS_TOKEN_TEST_SHOULD_BE_HAPPY", token); EXPECT_EQ(13920, expiration_timestamp); } @@ -111,7 +113,7 @@ TEST(OAuthClientTest, GetTokenFromServiceAccountJson) { TF_EXPECT_OK(client.GetTokenFromServiceAccountJson( json, "https://www.googleapis.com/oauth2/v3/token", "https://test-token-scope.com", &token, &expiration_timestamp)); - EXPECT_EQ("1/fFAGRNJru1FTz70BzhT3Zg", token); + EXPECT_EQ("WITH_FAKE_ACCESS_TOKEN_TEST_SHOULD_BE_HAPPY", token); EXPECT_EQ(13920, expiration_timestamp); // Now look at the JWT claim that was sent to the OAuth server. 
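As context for the base64 move above (lib/strings/base64.h now forwards to platform/base64.h, and oauth_client picks up the new include), a minimal usage sketch of the templated API declared in platform/base64.h could look like the snippet below. This is not part of the diff; the function name RoundTripExample and its input are invented purely for illustration.

// Sketch only (not part of this change): round-trips a string through the
// web-safe base64 helpers declared in tensorflow/core/platform/base64.h.
// RoundTripExample is a hypothetical name used for illustration.
#include "tensorflow/core/platform/base64.h"
#include "tensorflow/core/platform/status.h"

namespace tensorflow {

Status RoundTripExample() {
  string encoded;
  // The two-argument overload encodes without padding, per the declaration
  // in platform/base64.h shown earlier in this diff.
  Status s = Base64Encode("some raw bytes", &encoded);
  if (!s.ok()) return s;
  string decoded;
  s = Base64Decode(encoded, &decoded);
  if (!s.ok()) return s;
  return Status::OK();  // decoded now equals "some raw bytes"
}

}  // namespace tensorflow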
diff --git a/tensorflow/core/platform/cloud/ram_file_block_cache.h b/tensorflow/core/platform/cloud/ram_file_block_cache.h index 46fb9a35b88..97105ff046a 100644 --- a/tensorflow/core/platform/cloud/ram_file_block_cache.h +++ b/tensorflow/core/platform/cloud/ram_file_block_cache.h @@ -22,12 +22,13 @@ limitations under the License. #include #include #include -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" + #include "tensorflow/core/platform/cloud/file_block_cache.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/notification.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/core/platform/cloud/ram_file_block_cache_test.cc b/tensorflow/core/platform/cloud/ram_file_block_cache_test.cc index 9f37be65943..e018333b1b7 100644 --- a/tensorflow/core/platform/cloud/ram_file_block_cache_test.cc +++ b/tensorflow/core/platform/cloud/ram_file_block_cache_test.cc @@ -14,9 +14,11 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/platform/cloud/ram_file_block_cache.h" + #include -#include "tensorflow/core/lib/core/blocking_counter.h" + #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/blocking_counter.h" #include "tensorflow/core/platform/cloud/now_seconds_env.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/notification.h" diff --git a/tensorflow/core/platform/cloud/retrying_file_system.h b/tensorflow/core/platform/cloud/retrying_file_system.h index 9659edd890e..12bbc7d6abb 100644 --- a/tensorflow/core/platform/cloud/retrying_file_system.h +++ b/tensorflow/core/platform/cloud/retrying_file_system.h @@ -20,12 +20,12 @@ limitations under the License. #include #include -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/platform/cloud/retrying_utils.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/file_system.h" +#include "tensorflow/core/platform/status.h" namespace tensorflow { diff --git a/tensorflow/core/platform/cloud/retrying_file_system_test.cc b/tensorflow/core/platform/cloud/retrying_file_system_test.cc index 2b26f27f82c..1df371a6080 100644 --- a/tensorflow/core/platform/cloud/retrying_file_system_test.cc +++ b/tensorflow/core/platform/cloud/retrying_file_system_test.cc @@ -14,9 +14,11 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/platform/cloud/retrying_file_system.h" + #include + #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { diff --git a/tensorflow/core/platform/cloud/retrying_utils.cc b/tensorflow/core/platform/cloud/retrying_utils.cc index 9c963dd82f2..1f0c41824bf 100644 --- a/tensorflow/core/platform/cloud/retrying_utils.cc +++ b/tensorflow/core/platform/cloud/retrying_utils.cc @@ -14,9 +14,10 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/core/platform/cloud/retrying_utils.h" -#include "tensorflow/core/lib/core/errors.h" + #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/file_system.h" namespace tensorflow { diff --git a/tensorflow/core/platform/cloud/retrying_utils.h b/tensorflow/core/platform/cloud/retrying_utils.h index 1a7ce1b122b..70b98463477 100644 --- a/tensorflow/core/platform/cloud/retrying_utils.h +++ b/tensorflow/core/platform/cloud/retrying_utils.h @@ -17,7 +17,8 @@ limitations under the License. #define TENSORFLOW_CORE_PLATFORM_CLOUD_RETRYING_UTILS_H_ #include -#include "tensorflow/core/lib/core/status.h" + +#include "tensorflow/core/platform/status.h" namespace tensorflow { diff --git a/tensorflow/core/platform/cloud/retrying_utils_test.cc b/tensorflow/core/platform/cloud/retrying_utils_test.cc index 771bb44285e..7a2dbacacc8 100644 --- a/tensorflow/core/platform/cloud/retrying_utils_test.cc +++ b/tensorflow/core/platform/cloud/retrying_utils_test.cc @@ -14,10 +14,12 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/platform/cloud/retrying_utils.h" + #include + #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { diff --git a/tensorflow/core/platform/cloud/time_util.cc b/tensorflow/core/platform/cloud/time_util.cc index afd06efa854..c780bd25e85 100644 --- a/tensorflow/core/platform/cloud/time_util.cc +++ b/tensorflow/core/platform/cloud/time_util.cc @@ -21,7 +21,7 @@ limitations under the License. #ifdef _WIN32 #define timegm _mkgmtime #endif -#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/errors.h" namespace tensorflow { diff --git a/tensorflow/core/platform/cloud/time_util.h b/tensorflow/core/platform/cloud/time_util.h index d6d4bc499fe..944efe9bbd4 100644 --- a/tensorflow/core/platform/cloud/time_util.h +++ b/tensorflow/core/platform/cloud/time_util.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_TIME_UTIL_H_ #define TENSORFLOW_CORE_PLATFORM_CLOUD_TIME_UTIL_H_ -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/status.h" namespace tensorflow { diff --git a/tensorflow/core/platform/cloud/zone_provider.h b/tensorflow/core/platform/cloud/zone_provider.h index 421b6a7e1af..d1682fa81cc 100644 --- a/tensorflow/core/platform/cloud/zone_provider.h +++ b/tensorflow/core/platform/cloud/zone_provider.h @@ -17,8 +17,9 @@ limitations under the License. 
#define TENSORFLOW_CORE_PLATFORM_CLOUD_ZONE_PROVIDER_H_ #include -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" + +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" namespace tensorflow { diff --git a/tensorflow/core/lib/core/coding.cc b/tensorflow/core/platform/coding.cc similarity index 99% rename from tensorflow/core/lib/core/coding.cc rename to tensorflow/core/platform/coding.cc index 4c33dfa211e..ef0df8fa42a 100644 --- a/tensorflow/core/lib/core/coding.cc +++ b/tensorflow/core/platform/coding.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/platform/coding.h" #include "tensorflow/core/platform/byte_order.h" diff --git a/tensorflow/core/platform/coding.h b/tensorflow/core/platform/coding.h new file mode 100644 index 00000000000..cd66e54dfdb --- /dev/null +++ b/tensorflow/core/platform/coding.h @@ -0,0 +1,69 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Endian-neutral encoding: +// * Fixed-length numbers are encoded with least-significant byte first +// * In addition we support variable length "varint" encoding +// * Strings are encoded prefixed by their length in varint format + +#ifndef TENSORFLOW_CORE_PLATFORM_CODING_H_ +#define TENSORFLOW_CORE_PLATFORM_CODING_H_ + +#include "tensorflow/core/platform/raw_coding.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace core { + +// Maximum number of bytes occupied by a varint32. +static const int kMaxVarint32Bytes = 5; + +// Maximum number of bytes occupied by a varint64. +static const int kMaxVarint64Bytes = 10; + +// Lower-level versions of Put... 
that write directly into a character buffer +// REQUIRES: dst has enough space for the value being written +extern void EncodeFixed16(char* dst, uint16 value); +extern void EncodeFixed32(char* dst, uint32 value); +extern void EncodeFixed64(char* dst, uint64 value); +extern void PutFixed16(string* dst, uint16 value); +extern void PutFixed32(string* dst, uint32 value); +extern void PutFixed64(string* dst, uint64 value); + +extern void PutVarint32(string* dst, uint32 value); +extern void PutVarint64(string* dst, uint64 value); + +extern bool GetVarint32(StringPiece* input, uint32* value); +extern bool GetVarint64(StringPiece* input, uint64* value); + +extern const char* GetVarint32Ptr(const char* p, const char* limit, uint32* v); +extern const char* GetVarint64Ptr(const char* p, const char* limit, uint64* v); + +// Internal routine for use by fallback path of GetVarint32Ptr +extern const char* GetVarint32PtrFallback(const char* p, const char* limit, + uint32* value); +extern const char* GetVarint32Ptr(const char* p, const char* limit, + uint32* value); +extern char* EncodeVarint32(char* dst, uint32 v); +extern char* EncodeVarint64(char* dst, uint64 v); + +// Returns the length of the varint32 or varint64 encoding of "v" +extern int VarintLength(uint64_t v); + +} // namespace core +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_CODING_H_ diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl index 33e815c2a3f..a95de6632ce 100644 --- a/tensorflow/core/platform/default/build_config.bzl +++ b/tensorflow/core/platform/default/build_config.bzl @@ -10,6 +10,26 @@ load( "if_mkl_ml", ) +def well_known_proto_libs(): + """Set of standard protobuf protos, like Any and Timestamp. + + This list should be provided by protobuf.bzl, but it's not. + """ + return [ + "@com_google_protobuf//:any_proto", + "@com_google_protobuf//:api_proto", + "@com_google_protobuf//:compiler_plugin_proto", + "@com_google_protobuf//:descriptor_proto", + "@com_google_protobuf//:duration_proto", + "@com_google_protobuf//:empty_proto", + "@com_google_protobuf//:field_mask_proto", + "@com_google_protobuf//:source_context_proto", + "@com_google_protobuf//:struct_proto", + "@com_google_protobuf//:timestamp_proto", + "@com_google_protobuf//:type_proto", + "@com_google_protobuf//:wrappers_proto", + ] + # Appends a suffix to a list of deps. def tf_deps(deps, suffix): tf_deps = [] @@ -259,18 +279,6 @@ def cc_proto_library( **kargs ) - # Temporarily also add an alias with the 'protolib_name'. So far we relied - # on copybara to switch dependencies to the _cc dependencies. Now that these - # copybara rules are removed, we need to first change the internal BUILD - # files to depend on the correct targets instead, then this can be removed. - # TODO(b/143648532): Remove this once all reverse dependencies are migrated. - if protolib_name != name: - native.alias( - name = protolib_name, - actual = name, - visibility = kargs["visibility"], - ) - # Re-defined protocol buffer rule to bring in the change introduced in commit # https://github.com/google/protobuf/commit/294b5758c373cbab4b72f35f4cb62dc1d8332b68 # which was not part of a stable protobuf release in 04/2018. @@ -386,19 +394,6 @@ def tf_proto_library_cc( deps = [s + "_genproto" for s in protolib_deps], ) - # Temporarily also add an alias with 'name'. So far we relied on - # copybara to switch dependencies to the _cc dependencies. 
Now that these - # copybara rules are removed, we need to change the internal BUILD files to - # depend on the correct targets instead. - # TODO(b/143648532): Remove this once all reverse dependencies are - # migrated. - native.alias( - name = name, - actual = cc_name, - testonly = testonly, - visibility = visibility, - ) - native.alias( name = cc_name + "_headers_only", actual = cc_name, @@ -504,8 +499,20 @@ def tf_proto_library( make_default_target_header_only = False, exports = []): """Make a proto library, possibly depending on other proto libraries.""" + + # TODO(b/145545130): Add docstring explaining what rules this creates and how + # opensource projects importing TF in bazel can use them safely (i.e. w/o ODR or + # ABI violations). _ignore = (js_codegen, exports) + native.proto_library( + name = name, + srcs = srcs, + deps = protodeps + well_known_proto_libs(), + visibility = visibility, + testonly = testonly, + ) + tf_proto_library_cc( name = name, testonly = testonly, diff --git a/tensorflow/core/platform/default/build_refactor.bzl b/tensorflow/core/platform/default/build_refactor.bzl index e7eddeb3343..4f11699f766 100644 --- a/tensorflow/core/platform/default/build_refactor.bzl +++ b/tensorflow/core/platform/default/build_refactor.bzl @@ -39,11 +39,9 @@ TF_DEFAULT_PLATFORM_LIBRARIES = { ], "deps": [ "@local_config_cuda//cuda:cuda_headers", - "//tensorflow/core:lib", - # TODO(bmzhao): When bazel gains cc_shared_library support, the targets below are - # the actual granular targets we should depend on, instead of tf/core:lib. - # "//tensorflow/core/platform:logging", - # "//tensorflow/core/platform:types", + "//tensorflow/core/platform:logging", + "//tensorflow/core/platform:path", + "//tensorflow/core/platform:types", ], "visibility": ["//visibility:private"], "tags": ["no_oss", "manual"], @@ -77,26 +75,26 @@ TF_DEFAULT_PLATFORM_LIBRARIES = { "@com_google_absl//absl/time", "@com_google_absl//absl/types:optional", "//third_party/eigen3", - "//tensorflow/core/lib/core:blocking_counter", "//tensorflow/core/lib/core:error_codes_proto_cc", - "//tensorflow/core/lib/core:errors", - "//tensorflow/core/lib/core:status", "//tensorflow/core/lib/core:stringpiece", - "//tensorflow/core/lib/io:path", "//tensorflow/core/platform", + "//tensorflow/core/platform:blocking_counter", "//tensorflow/core/platform:context", "//tensorflow/core/platform:cord", "//tensorflow/core/platform:denormal", "//tensorflow/core/platform:error", + "//tensorflow/core/platform:errors", "//tensorflow/core/platform:env_time", "//tensorflow/core/platform:file_statistics", "//tensorflow/core/platform:load_library", "//tensorflow/core/platform:logging", "//tensorflow/core/platform:macros", "//tensorflow/core/platform:mutex", + "//tensorflow/core/platform:path", "//tensorflow/core/platform:platform_port", "//tensorflow/core/platform:protobuf", "//tensorflow/core/platform:setround", + "//tensorflow/core/platform:status", "//tensorflow/core/platform:stringpiece", "//tensorflow/core/platform:stringprintf", "//tensorflow/core/platform:strcat", @@ -131,10 +129,10 @@ TF_DEFAULT_PLATFORM_LIBRARIES = { "//tensorflow/core/platform:default/human_readable_json.cc", ], "deps": [ - "//tensorflow/core/lib/core:errors", - "//tensorflow/core/lib/core:status", - "//tensorflow/core/platform:strcat", + "//tensorflow/core/platform:errors", "//tensorflow/core/platform:protobuf", + "//tensorflow/core/platform:status", + "//tensorflow/core/platform:strcat", ], "visibility": ["//visibility:private"], "tags": ["no_oss", "manual"], @@ -148,8 +146,8 
@@ TF_DEFAULT_PLATFORM_LIBRARIES = { "//tensorflow/core/platform:default/load_library.cc", ], "deps": [ - "//tensorflow/core/lib/core:errors", - "//tensorflow/core/lib/core:status", + "//tensorflow/core/platform:errors", + "//tensorflow/core/platform:status", ], "visibility": ["//visibility:private"], "tags": ["no_oss", "manual"], @@ -236,12 +234,9 @@ TF_DEFAULT_PLATFORM_LIBRARIES = { ], "deps": [ "@local_config_rocm//rocm:rocm_headers", - "//tensorflow/core:lib", - # TODO(bmzhao): When bazel gains cc_shared_library support, the targets below are - # the actual granular targets we should depend on, instead of tf/core:lib. - # "//tensorflow/core/lib/io:path", - # "//tensorflow/core/platform:logging", - # "//tensorflow/core/platform:types", + "//tensorflow/core/platform:path", + "//tensorflow/core/platform:logging", + "//tensorflow/core/platform:types", ], "visibility": ["//visibility:private"], "tags": ["no_oss", "manual"], @@ -327,6 +322,25 @@ TF_DEFAULT_PLATFORM_LIBRARIES = { "tags": ["no_oss", "manual"], "visibility": ["//visibility:private"], }, + "test_benchmark": { + "name": "test_benchmark_impl", + "testonly": True, + "srcs": [ + "//tensorflow/core/platform:default/test_benchmark.cc", + ], + "hdrs": [ + "//tensorflow/core/platform:default/test_benchmark.h", + ], + "deps": [ + "//tensorflow/core/platform", + "//tensorflow/core/platform:env", + "//tensorflow/core/platform:macros", + "//tensorflow/core/platform:types", + "//tensorflow/core:util_reporter", + ], + "tags": ["no_oss", "manual"], + "visibility": ["//visibility:private"], + }, "tracing": { "name": "tracing_impl", "textual_hdrs": [ @@ -342,6 +356,7 @@ TF_DEFAULT_PLATFORM_LIBRARIES = { "deps": [ "//tensorflow/core/lib/hash", "//tensorflow/core/platform", + "//tensorflow/core/platform:hash", "//tensorflow/core/platform:logging", "//tensorflow/core/platform:macros", "//tensorflow/core/platform:strcat", @@ -404,24 +419,24 @@ TF_WINDOWS_PLATFORM_LIBRARIES = { "//third_party/eigen3", "//tensorflow/core/lib/core:blocking_counter", "//tensorflow/core/lib/core:error_codes_proto_cc", - "//tensorflow/core/lib/core:errors", - "//tensorflow/core/lib/core:status", "//tensorflow/core/lib/core:stringpiece", - "//tensorflow/core/lib/io:path", "//tensorflow/core/platform", "//tensorflow/core/platform:context", "//tensorflow/core/platform:cord", "//tensorflow/core/platform:denormal", "//tensorflow/core/platform:error", + "//tensorflow/core/platform:errors", "//tensorflow/core/platform:env_time", "//tensorflow/core/platform:file_statistics", "//tensorflow/core/platform:load_library", "//tensorflow/core/platform:logging", "//tensorflow/core/platform:macros", "//tensorflow/core/platform:mutex", + "//tensorflow/core/platform:path", "//tensorflow/core/platform:platform_port", "//tensorflow/core/platform:protobuf", "//tensorflow/core/platform:setround", + "//tensorflow/core/platform:status", "//tensorflow/core/platform:stringpiece", "//tensorflow/core/platform:stringprintf", "//tensorflow/core/platform:strcat", @@ -457,8 +472,8 @@ TF_WINDOWS_PLATFORM_LIBRARIES = { "//tensorflow/core/platform:windows/load_library.cc", ], "deps": [ - "//tensorflow/core/lib/core:errors", - "//tensorflow/core/lib/core:status", + "//tensorflow/core/platform:errors", + "//tensorflow/core/platform:status", "//tensorflow/core/platform:windows_wide_char_impl", ], "visibility": ["//visibility:private"], diff --git a/tensorflow/core/platform/default/human_readable_json.cc b/tensorflow/core/platform/default/human_readable_json.cc index c3a61a3d58c..88ab9aa87fc 100644 --- 
a/tensorflow/core/platform/default/human_readable_json.cc +++ b/tensorflow/core/platform/default/human_readable_json.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/core/platform/human_readable_json.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/strcat.h" namespace tensorflow { diff --git a/tensorflow/core/platform/default/load_library.cc b/tensorflow/core/platform/default/load_library.cc index eaa68e66704..ef9edcc4501 100644 --- a/tensorflow/core/platform/default/load_library.cc +++ b/tensorflow/core/platform/default/load_library.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/errors.h" namespace tensorflow { diff --git a/tensorflow/core/platform/default/posix_file_system.cc b/tensorflow/core/platform/default/posix_file_system.cc index 56c00279e6b..106a0412fb7 100644 --- a/tensorflow/core/platform/default/posix_file_system.cc +++ b/tensorflow/core/platform/default/posix_file_system.cc @@ -28,12 +28,12 @@ limitations under the License. #include #include -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/default/posix_file_system.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/error.h" #include "tensorflow/core/platform/file_system_helper.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/strcat.h" #include "tensorflow/core/protobuf/error_codes.pb.h" diff --git a/tensorflow/core/platform/default/posix_file_system.h b/tensorflow/core/platform/default/posix_file_system.h index 752eccea66b..c418a08e944 100644 --- a/tensorflow/core/platform/default/posix_file_system.h +++ b/tensorflow/core/platform/default/posix_file_system.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PLATFORM_POSIX_POSIX_FILE_SYSTEM_H_ #define TENSORFLOW_CORE_PLATFORM_POSIX_POSIX_FILE_SYSTEM_H_ -#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/path.h" namespace tensorflow { diff --git a/tensorflow/core/platform/default/rocm_rocdl_path.cc b/tensorflow/core/platform/default/rocm_rocdl_path.cc index 0831544f616..55075969cbd 100644 --- a/tensorflow/core/platform/default/rocm_rocdl_path.cc +++ b/tensorflow/core/platform/default/rocm_rocdl_path.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/path.h" #if !defined(PLATFORM_GOOGLE) && TENSORFLOW_USE_ROCM #include "rocm/rocm_config.h" diff --git a/tensorflow/core/platform/default/test_benchmark.cc b/tensorflow/core/platform/default/test_benchmark.cc index dedab42bd73..533c4ac1df1 100644 --- a/tensorflow/core/platform/default/test_benchmark.cc +++ b/tensorflow/core/platform/default/test_benchmark.cc @@ -13,16 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/core/platform/test_benchmark.h" - -#include -#include +#include "tensorflow/core/platform/default/test_benchmark.h" #include +#include +#include #include -#include "tensorflow/core/lib/strings/str_util.h" + #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/str_util.h" #include "tensorflow/core/util/reporter.h" namespace tensorflow { diff --git a/tensorflow/core/platform/default/test_benchmark.h b/tensorflow/core/platform/default/test_benchmark.h new file mode 100644 index 00000000000..203a8a045ff --- /dev/null +++ b/tensorflow/core/platform/default/test_benchmark.h @@ -0,0 +1,105 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Simple benchmarking facility. +#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_TEST_BENCHMARK_H_ +#define TENSORFLOW_CORE_PLATFORM_DEFAULT_TEST_BENCHMARK_H_ + +#include +#include + +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/platform.h" +#include "tensorflow/core/platform/types.h" + +#define BENCHMARK(n) \ + static ::tensorflow::testing::Benchmark* TF_BENCHMARK_CONCAT( \ + __benchmark_, n, __LINE__) TF_ATTRIBUTE_UNUSED = \ + (new ::tensorflow::testing::Benchmark(#n, (n))) +#define TF_BENCHMARK_CONCAT(a, b, c) TF_BENCHMARK_CONCAT2(a, b, c) +#define TF_BENCHMARK_CONCAT2(a, b, c) a##b##c + +namespace tensorflow { +namespace testing { + +// The DoNotOptimize(...) function can be used to prevent a value or +// expression from being optimized away by the compiler. This function is +// intended to add little to no overhead. +// See: http://stackoverflow.com/questions/28287064 +// +// The specific guarantees of DoNotOptimize(x) are: +// 1) x, and any data it transitively points to, will exist (in a register or +// in memory) at the current point in the program. +// 2) The optimizer will assume that DoNotOptimize(x) could mutate x or +// anything it transitively points to (although it actually doesn't). +// +// To see this in action: +// +// void BM_multiply(benchmark::State& state) { +// int a = 2; +// int b = 4; +// for (auto _ : state) { +// testing::DoNotOptimize(a); +// testing::DoNotOptimize(b); +// int c = a * b; +// testing::DoNotOptimize(c); +// } +// } +// BENCHMARK(BM_multiply); +// +// Guarantee (2) applied to 'a' and 'b' prevents the compiler lifting the +// multiplication outside of the loop. Guarantee (1) applied to 'c' prevents the +// compiler from optimizing away 'c' as dead code. 
+template +void DoNotOptimize(const T& var) { + asm volatile("" : "+m"(const_cast(var))); +} + +class Benchmark { + public: + Benchmark(const char* name, void (*fn)(int)); + Benchmark(const char* name, void (*fn)(int, int)); + Benchmark(const char* name, void (*fn)(int, int, int)); + + Benchmark* Arg(int x); + Benchmark* ArgPair(int x, int y); + Benchmark* Range(int lo, int hi); + Benchmark* RangePair(int lo1, int hi1, int lo2, int hi2); + static void Run(const char* pattern); + + private: + string name_; + int num_args_; + std::vector > args_; + void (*fn0_)(int) = nullptr; + void (*fn1_)(int, int) = nullptr; + void (*fn2_)(int, int, int) = nullptr; + + void Register(); + void Run(int arg1, int arg2, int* run_count, double* run_seconds); +}; + +void RunBenchmarks(); +void SetLabel(const std::string& label); +void BytesProcessed(int64); +void ItemsProcessed(int64); +void StartTiming(); +void StopTiming(); +void UseRealTime(); + +} // namespace testing +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_DEFAULT_TEST_BENCHMARK_H_ diff --git a/tensorflow/core/platform/env.cc b/tensorflow/core/platform/env.cc index 602915540ee..ee4ae92f905 100644 --- a/tensorflow/core/platform/env.cc +++ b/tensorflow/core/platform/env.cc @@ -21,10 +21,10 @@ limitations under the License. #include #include -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env_time.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/host_info.h" +#include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/platform.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/stringprintf.h" diff --git a/tensorflow/core/platform/env.h b/tensorflow/core/platform/env.h index be8399c879b..d5a22b1de2d 100644 --- a/tensorflow/core/platform/env.h +++ b/tensorflow/core/platform/env.h @@ -23,15 +23,15 @@ limitations under the License. #include #include -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/env_time.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/file_system.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/numa.h" #include "tensorflow/core/platform/platform.h" #include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/core/platform/env_test.cc b/tensorflow/core/platform/env_test.cc index 1da8aaab743..1f4bd7c6a79 100644 --- a/tensorflow/core/platform/env_test.cc +++ b/tensorflow/core/platform/env_test.cc @@ -20,13 +20,13 @@ limitations under the License. 
#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/cord.h" #include "tensorflow/core/platform/null_file_system.h" +#include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/str_util.h" +#include "tensorflow/core/platform/strcat.h" +#include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { diff --git a/tensorflow/core/platform/error.cc b/tensorflow/core/platform/error.cc index 00ddf1dc241..cb09a3a86cc 100644 --- a/tensorflow/core/platform/error.cc +++ b/tensorflow/core/platform/error.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/strcat.h" namespace tensorflow { diff --git a/tensorflow/core/platform/error.h b/tensorflow/core/platform/error.h index 3ba3e749c34..0b08ac36682 100644 --- a/tensorflow/core/platform/error.h +++ b/tensorflow/core/platform/error.h @@ -18,8 +18,8 @@ limitations under the License. #include -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/platform.h" +#include "tensorflow/core/platform/status.h" namespace tensorflow { diff --git a/tensorflow/core/platform/file_system.cc b/tensorflow/core/platform/file_system.cc index 58d14e3d2d3..fb013b82570 100644 --- a/tensorflow/core/platform/file_system.cc +++ b/tensorflow/core/platform/file_system.cc @@ -20,9 +20,9 @@ limitations under the License. #include #include -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/platform.h" #include "tensorflow/core/platform/str_util.h" #include "tensorflow/core/platform/strcat.h" diff --git a/tensorflow/core/platform/file_system.h b/tensorflow/core/platform/file_system.h index ade98a12637..caeedbffbc1 100644 --- a/tensorflow/core/platform/file_system.h +++ b/tensorflow/core/platform/file_system.h @@ -23,8 +23,8 @@ limitations under the License. #include #include -#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/cord.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/file_statistics.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/platform.h" diff --git a/tensorflow/core/platform/file_system_helper.cc b/tensorflow/core/platform/file_system_helper.cc index e47fa133c6d..da3acba7d1a 100644 --- a/tensorflow/core/platform/file_system_helper.cc +++ b/tensorflow/core/platform/file_system_helper.cc @@ -19,11 +19,11 @@ limitations under the License. 
#include #include -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/file_system.h" +#include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/platform.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/str_util.h" #include "tensorflow/core/platform/threadpool.h" diff --git a/tensorflow/core/platform/file_system_helper.h b/tensorflow/core/platform/file_system_helper.h index 8d812b0e381..7427dea77ef 100644 --- a/tensorflow/core/platform/file_system_helper.h +++ b/tensorflow/core/platform/file_system_helper.h @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/status.h" namespace tensorflow { diff --git a/tensorflow/core/platform/file_system_test.cc b/tensorflow/core/platform/file_system_test.cc index a931634a3c8..278561f4f0d 100644 --- a/tensorflow/core/platform/file_system_test.cc +++ b/tensorflow/core/platform/file_system_test.cc @@ -18,10 +18,10 @@ limitations under the License. #include #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/null_file_system.h" +#include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/str_util.h" +#include "tensorflow/core/platform/strcat.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { diff --git a/tensorflow/core/platform/hadoop/BUILD b/tensorflow/core/platform/hadoop/BUILD index fc6ae4dc6b7..49d9e9975cf 100644 --- a/tensorflow/core/platform/hadoop/BUILD +++ b/tensorflow/core/platform/hadoop/BUILD @@ -18,6 +18,8 @@ cc_library( deps = [ "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/platform:path", + "//tensorflow/core/platform:strcat", "//third_party/hadoop:hdfs", ], alwayslink = 1, @@ -58,5 +60,7 @@ tf_cc_test( "//tensorflow/core:lib_internal", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/platform:path", + "//tensorflow/core/platform:str_util", ], ) diff --git a/tensorflow/core/platform/hadoop/hadoop_file_system.cc b/tensorflow/core/platform/hadoop/hadoop_file_system.cc index 59c3fe2540f..34dc1cf305b 100644 --- a/tensorflow/core/platform/hadoop/hadoop_file_system.cc +++ b/tensorflow/core/platform/hadoop/hadoop_file_system.cc @@ -17,15 +17,15 @@ limitations under the License. 
#include -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/error.h" #include "tensorflow/core/platform/file_system.h" #include "tensorflow/core/platform/file_system_helper.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/strcat.h" #include "third_party/hadoop/hdfs.h" namespace tensorflow { diff --git a/tensorflow/core/platform/hadoop/hadoop_file_system_test.cc b/tensorflow/core/platform/hadoop/hadoop_file_system_test.cc index 1242e2547fc..3104addc4e0 100644 --- a/tensorflow/core/platform/hadoop/hadoop_file_system_test.cc +++ b/tensorflow/core/platform/hadoop/hadoop_file_system_test.cc @@ -16,9 +16,9 @@ limitations under the License. #include "tensorflow/core/platform/hadoop/hadoop_file_system.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/file_system.h" +#include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { diff --git a/tensorflow/core/lib/hash/hash.cc b/tensorflow/core/platform/hash.cc similarity index 96% rename from tensorflow/core/lib/hash/hash.cc rename to tensorflow/core/platform/hash.cc index dc9d300d00e..74a18f8f05e 100644 --- a/tensorflow/core/lib/hash/hash.cc +++ b/tensorflow/core/platform/hash.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/hash.h" -#include "tensorflow/core/lib/core/raw_coding.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/raw_coding.h" #include "tensorflow/core/platform/types.h" #include diff --git a/tensorflow/core/platform/hash.h b/tensorflow/core/platform/hash.h new file mode 100644 index 00000000000..3a9de99f2bc --- /dev/null +++ b/tensorflow/core/platform/hash.h @@ -0,0 +1,113 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// Simple hash functions used for internal data structures + +#ifndef TENSORFLOW_CORE_PLATFORM_HASH_H_ +#define TENSORFLOW_CORE_PLATFORM_HASH_H_ + +#include +#include + +#include +#include + +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +extern uint32 Hash32(const char* data, size_t n, uint32 seed); +extern uint64 Hash64(const char* data, size_t n, uint64 seed); + +inline uint64 Hash64(const char* data, size_t n) { + return Hash64(data, n, 0xDECAFCAFFE); +} + +inline uint64 Hash64(const string& str) { + return Hash64(str.data(), str.size()); +} + +inline uint64 Hash64Combine(uint64 a, uint64 b) { + return a ^ (b + 0x9e3779b97f4a7800ULL + (a << 10) + (a >> 4)); +} + +// Combine two hashes in an order-independent way. This operation should be +// associative and compute the same hash for a collection of elements +// independent of traversal order. Note that it is better to combine hashes +// symmetrically with addition rather than XOR, since (x^x) == 0 but (x+x) != 0. +inline uint64 Hash64CombineUnordered(uint64 a, uint64 b) { return a + b; } + +// Hash functor suitable for use with power-of-two sized hashtables. Use +// instead of std::hash. +// +// In particular, tensorflow::hash is not the identity function for pointers. +// This is important for power-of-two sized hashtables like FlatMap and FlatSet, +// because otherwise they waste the majority of their hash buckets. +// +// The second type argument is only used for SFNIAE below. +template +struct hash { + size_t operator()(const T& t) const { return std::hash()(t); } +}; + +template +struct hash::value>::type> { + size_t operator()(T value) const { + // This works around a defect in the std::hash C++ spec that isn't fixed in + // (at least) gcc 4.8.4: + // http://www.open-std.org/jtc1/sc22/wg21/docs/lwg-defects.html#2148 + // + // We should be able to remove this and use the default + // tensorflow::hash() once we stop building with GCC versions old + // enough to not have this defect fixed. + return std::hash()(static_cast(value)); + } +}; + +template +struct hash { + size_t operator()(const T* t) const { + // Hash pointers as integers, but bring more entropy to the lower bits. + size_t k = static_cast(reinterpret_cast(t)); + return k + (k >> 6); + } +}; + +template <> +struct hash { + size_t operator()(const string& s) const { + return static_cast(Hash64(s)); + } +}; + +template <> +struct hash { + size_t operator()(StringPiece sp) const { + return static_cast(Hash64(sp.data(), sp.size())); + } +}; +using StringPieceHasher = ::tensorflow::hash; + +template +struct hash> { + size_t operator()(const std::pair& p) const { + return Hash64Combine(hash()(p.first), hash()(p.second)); + } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_HASH_H_ diff --git a/tensorflow/core/platform/human_readable_json.h b/tensorflow/core/platform/human_readable_json.h index 49908eac7c8..f6830e20207 100644 --- a/tensorflow/core/platform/human_readable_json.h +++ b/tensorflow/core/platform/human_readable_json.h @@ -16,8 +16,8 @@ limitations under the License. 
#ifndef TENSORFLOW_CORE_PLATFORM_HUMAN_READABLE_JSON_H_ #define TENSORFLOW_CORE_PLATFORM_HUMAN_READABLE_JSON_H_ -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/status.h" namespace tensorflow { diff --git a/tensorflow/core/platform/load_library.h b/tensorflow/core/platform/load_library.h index c7eeb2918ca..01efd4c1d01 100644 --- a/tensorflow/core/platform/load_library.h +++ b/tensorflow/core/platform/load_library.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PLATFORM_LOAD_LIBRARY_H_ #define TENSORFLOW_CORE_PLATFORM_LOAD_LIBRARY_H_ -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/status.h" namespace tensorflow { diff --git a/tensorflow/core/lib/io/path.cc b/tensorflow/core/platform/path.cc similarity index 99% rename from tensorflow/core/lib/io/path.cc rename to tensorflow/core/platform/path.cc index ea9d93629c9..864bf49b2bb 100644 --- a/tensorflow/core/lib/io/path.cc +++ b/tensorflow/core/platform/path.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/path.h" #include #include diff --git a/tensorflow/core/platform/path.h b/tensorflow/core/platform/path.h new file mode 100644 index 00000000000..db0348d8960 --- /dev/null +++ b/tensorflow/core/platform/path.h @@ -0,0 +1,98 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_PATH_H_ +#define TENSORFLOW_CORE_PLATFORM_PATH_H_ + +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace io { +namespace internal { +string JoinPathImpl(std::initializer_list paths); +} + +// Utility routines for processing filenames + +#ifndef SWIG // variadic templates +// Join multiple paths together, without introducing unnecessary path +// separators. +// For example: +// +// Arguments | JoinPath +// ---------------------------+---------- +// '/foo', 'bar' | /foo/bar +// '/foo/', 'bar' | /foo/bar +// '/foo', '/bar' | /foo/bar +// +// Usage: +// string path = io::JoinPath("/mydir", filename); +// string path = io::JoinPath(FLAGS_test_srcdir, filename); +// string path = io::JoinPath("/full", "path", "to", "filename"); +template +string JoinPath(const T&... args) { + return internal::JoinPathImpl({args...}); +} +#endif /* SWIG */ + +// Return true if path is absolute. +bool IsAbsolutePath(tensorflow::StringPiece path); + +// Returns the part of the path before the final "/". If there is a single +// leading "/" in the path, the result will be the leading "/". If there is +// no "/" in the path, the result is the empty prefix of the input. 
+tensorflow::StringPiece Dirname(tensorflow::StringPiece path); + +// Returns the part of the path after the final "/". If there is no +// "/" in the path, the result is the same as the input. +tensorflow::StringPiece Basename(tensorflow::StringPiece path); + +// Returns the part of the basename of path after the final ".". If +// there is no "." in the basename, the result is empty. +tensorflow::StringPiece Extension(tensorflow::StringPiece path); + +// Collapse duplicate "/"s, resolve ".." and "." path elements, remove +// trailing "/". +// +// NOTE: This respects relative vs. absolute paths, but does not +// invoke any system calls (getcwd(2)) in order to resolve relative +// paths with respect to the actual working directory. That is, this is purely +// string manipulation, completely independent of process state. +string CleanPath(tensorflow::StringPiece path); + +// Populates the scheme, host, and path from a URI. scheme, host, and path are +// guaranteed by this function to point into the contents of uri, even if +// empty. +// +// Corner cases: +// - If the URI is invalid, scheme and host are set to empty strings and the +// passed string is assumed to be a path +// - If the URI omits the path (e.g. file://host), then the path is left empty. +void ParseURI(tensorflow::StringPiece uri, tensorflow::StringPiece* scheme, + tensorflow::StringPiece* host, tensorflow::StringPiece* path); + +// Creates a URI from a scheme, host, and path. If the scheme is empty, we just +// return the path. +string CreateURI(tensorflow::StringPiece scheme, tensorflow::StringPiece host, + tensorflow::StringPiece path); + +// Creates a temporary file name with an extension. +string GetTempFilename(const string& extension); + +} // namespace io +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_PATH_H_ diff --git a/tensorflow/core/platform/platform_strings_test.cc b/tensorflow/core/platform/platform_strings_test.cc index a4eb845c25e..a4af143d58f 100644 --- a/tensorflow/core/platform/platform_strings_test.cc +++ b/tensorflow/core/platform/platform_strings_test.cc @@ -15,6 +15,8 @@ limitations under the License. // Test for the platform_strings.h header file. +#include "tensorflow/core/platform/platform_strings.h" + #include #include #include @@ -23,12 +25,11 @@ limitations under the License. #include #include -#include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/platform_strings.h" +#include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/str_util.h" // Embed the platform strings in this binary. TF_PLATFORM_STRINGS() diff --git a/tensorflow/core/platform/port_test.cc b/tensorflow/core/platform/port_test.cc index 94a9e4d4589..4f59ed6f1c5 100644 --- a/tensorflow/core/platform/port_test.cc +++ b/tensorflow/core/platform/port_test.cc @@ -15,12 +15,12 @@ limitations under the License. 
#include -#include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/env_time.h" #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/threadpool.h" namespace tensorflow { namespace port { diff --git a/tensorflow/core/platform/profile_utils/android_armv7a_cpu_utils_helper.cc b/tensorflow/core/platform/profile_utils/android_armv7a_cpu_utils_helper.cc index 12dc9c58b38..0534443d17c 100644 --- a/tensorflow/core/platform/profile_utils/android_armv7a_cpu_utils_helper.cc +++ b/tensorflow/core/platform/profile_utils/android_armv7a_cpu_utils_helper.cc @@ -28,8 +28,8 @@ limitations under the License. #include #include -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stringprintf.h" namespace tensorflow { namespace profile_utils { diff --git a/tensorflow/core/platform/protobuf_internal.h b/tensorflow/core/platform/protobuf_internal.h index d0cfde09bc1..bf72968a157 100644 --- a/tensorflow/core/platform/protobuf_internal.h +++ b/tensorflow/core/platform/protobuf_internal.h @@ -17,7 +17,7 @@ limitations under the License. #define TENSORFLOW_CORE_PLATFORM_PROTOBUF_INTERNAL_H_ #include "google/protobuf/any.pb.h" -#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/platform.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/core/lib/random/random.cc b/tensorflow/core/platform/random.cc similarity index 96% rename from tensorflow/core/lib/random/random.cc rename to tensorflow/core/platform/random.cc index 82dc8295073..d7252810021 100644 --- a/tensorflow/core/lib/random/random.cc +++ b/tensorflow/core/platform/random.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/platform/random.h" #include #include "tensorflow/core/platform/mutex.h" diff --git a/tensorflow/core/platform/random.h b/tensorflow/core/platform/random.h new file mode 100644 index 00000000000..f605fd9e477 --- /dev/null +++ b/tensorflow/core/platform/random.h @@ -0,0 +1,35 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_RANDOM_H_ +#define TENSORFLOW_CORE_PLATFORM_RANDOM_H_ + +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace random { + +// Return a 64-bit random value. Different sequences are generated +// in different processes. +uint64 New64(); + +// Return a 64-bit random value. 
Uses +// std::mersenne_twister_engine::default_seed as seed value. +uint64 New64DefaultSeed(); + +} // namespace random +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_RANDOM_H_ diff --git a/tensorflow/core/platform/rocm_rocdl_path_test.cc b/tensorflow/core/platform/rocm_rocdl_path_test.cc index 4a4d9b89c59..3436dafac6d 100644 --- a/tensorflow/core/platform/rocm_rocdl_path_test.cc +++ b/tensorflow/core/platform/rocm_rocdl_path_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/core/platform/rocm_rocdl_path.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { diff --git a/tensorflow/core/platform/s3/BUILD b/tensorflow/core/platform/s3/BUILD index d518bfb71a2..a5494d5c318 100644 --- a/tensorflow/core/platform/s3/BUILD +++ b/tensorflow/core/platform/s3/BUILD @@ -65,6 +65,7 @@ cc_library( deps = [ "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/platform:stringprintf", "@aws", ], alwayslink = 1, @@ -83,6 +84,8 @@ cc_library( ":aws_logging", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/platform:path", + "//tensorflow/core/platform:str_util", "@aws", ], alwayslink = 1, @@ -103,6 +106,7 @@ tf_cc_test( "//tensorflow/core:lib_internal", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/platform:path", "@aws", ], ) diff --git a/tensorflow/core/platform/s3/aws_logging.cc b/tensorflow/core/platform/s3/aws_logging.cc index dac56908893..1d549a2a61e 100644 --- a/tensorflow/core/platform/s3/aws_logging.cc +++ b/tensorflow/core/platform/s3/aws_logging.cc @@ -13,9 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "tensorflow/core/platform/s3/aws_logging.h" -#include "tensorflow/core/lib/strings/stringprintf.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/mutex.h" #include #include @@ -23,6 +20,10 @@ limitations under the License. #include +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/stringprintf.h" + namespace tensorflow { AWSLogSystem::AWSLogSystem(Aws::Utils::Logging::LogLevel log_level) diff --git a/tensorflow/core/platform/s3/s3_file_system.cc b/tensorflow/core/platform/s3/s3_file_system.cc index d32158f70bd..936339079cf 100644 --- a/tensorflow/core/platform/s3/s3_file_system.cc +++ b/tensorflow/core/platform/s3/s3_file_system.cc @@ -13,12 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "tensorflow/core/platform/s3/s3_file_system.h" -#include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/platform/file_system_helper.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/s3/aws_crypto.h" -#include "tensorflow/core/platform/s3/aws_logging.h" #include #include @@ -38,6 +32,13 @@ limitations under the License. 
#include +#include "tensorflow/core/platform/file_system_helper.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/s3/aws_crypto.h" +#include "tensorflow/core/platform/s3/aws_logging.h" +#include "tensorflow/core/platform/str_util.h" + namespace tensorflow { namespace { diff --git a/tensorflow/core/platform/s3/s3_file_system_test.cc b/tensorflow/core/platform/s3/s3_file_system_test.cc index e7c3e4a8904..98778495f47 100644 --- a/tensorflow/core/platform/s3/s3_file_system_test.cc +++ b/tensorflow/core/platform/s3/s3_file_system_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/core/platform/s3/s3_file_system.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/file_system.h" +#include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { diff --git a/tensorflow/core/platform/status.cc b/tensorflow/core/platform/status.cc index a7fd3e693a1..d9cd02a27fb 100644 --- a/tensorflow/core/platform/status.cc +++ b/tensorflow/core/platform/status.cc @@ -22,7 +22,6 @@ limitations under the License. #include "absl/base/call_once.h" #include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/stacktrace.h" #include "tensorflow/core/platform/str_util.h" #include "tensorflow/core/platform/strcat.h" #include "tensorflow/core/platform/stringprintf.h" @@ -92,8 +91,6 @@ Status::Status(tensorflow::error::Code code, StringPiece msg) { state_ = std::unique_ptr(new State); state_->code = code; state_->msg = string(msg); - VLOG(5) << "Generated non-OK status: \"" << *this << "\". " - << CurrentStackTrace(); } void Status::Update(const Status& new_status) { diff --git a/tensorflow/core/platform/tensor_coding.cc b/tensorflow/core/platform/tensor_coding.cc index f115da2b4d6..c12810a42d6 100644 --- a/tensorflow/core/platform/tensor_coding.cc +++ b/tensorflow/core/platform/tensor_coding.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/platform/coding.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/strcat.h" #include "tensorflow/core/platform/stringpiece.h" diff --git a/tensorflow/core/platform/tensor_coding.h b/tensorflow/core/platform/tensor_coding.h index 63e47a880a9..fcfa5469e18 100644 --- a/tensorflow/core/platform/tensor_coding.h +++ b/tensorflow/core/platform/tensor_coding.h @@ -19,9 +19,9 @@ limitations under the License. #include -#include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/platform/platform.h" #include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/refcount.h" #include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/core/platform/test_benchmark.h b/tensorflow/core/platform/test_benchmark.h index 61fcd0d372c..aff1edb2d51 100644 --- a/tensorflow/core/platform/test_benchmark.h +++ b/tensorflow/core/platform/test_benchmark.h @@ -17,102 +17,13 @@ limitations under the License. 
#ifndef TENSORFLOW_CORE_PLATFORM_TEST_BENCHMARK_H_ #define TENSORFLOW_CORE_PLATFORM_TEST_BENCHMARK_H_ -#include -#include -#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/platform.h" -#include "tensorflow/core/platform/types.h" #if defined(PLATFORM_GOOGLE) -#include "tensorflow/core/platform/google/build_config/benchmark.h" - +#include "tensorflow/core/platform/google/test_benchmark.h" #else -#define BENCHMARK(n) \ - static ::tensorflow::testing::Benchmark* TF_BENCHMARK_CONCAT( \ - __benchmark_, n, __LINE__) TF_ATTRIBUTE_UNUSED = \ - (new ::tensorflow::testing::Benchmark(#n, (n))) -#define TF_BENCHMARK_CONCAT(a, b, c) TF_BENCHMARK_CONCAT2(a, b, c) -#define TF_BENCHMARK_CONCAT2(a, b, c) a##b##c - +#include "tensorflow/core/platform/default/test_benchmark.h" #endif // PLATFORM_GOOGLE -namespace tensorflow { -namespace testing { - -#if defined(PLATFORM_GOOGLE) - -using ::testing::Benchmark; -using ::testing::DoNotOptimize; - -#else - -// The DoNotOptimize(...) function can be used to prevent a value or -// expression from being optimized away by the compiler. This function is -// intended to add little to no overhead. -// See: http://stackoverflow.com/questions/28287064 -// -// The specific guarantees of DoNotOptimize(x) are: -// 1) x, and any data it transitively points to, will exist (in a register or -// in memory) at the current point in the program. -// 2) The optimizer will assume that DoNotOptimize(x) could mutate x or -// anything it transitively points to (although it actually doesn't). -// -// To see this in action: -// -// void BM_multiply(benchmark::State& state) { -// int a = 2; -// int b = 4; -// for (auto _ : state) { -// testing::DoNotOptimize(a); -// testing::DoNotOptimize(b); -// int c = a * b; -// testing::DoNotOptimize(c); -// } -// } -// BENCHMARK(BM_multiply); -// -// Guarantee (2) applied to 'a' and 'b' prevents the compiler lifting the -// multiplication outside of the loop. Guarantee (1) applied to 'c' prevents the -// compiler from optimizing away 'c' as dead code. -template -void DoNotOptimize(const T& var) { - asm volatile("" : "+m"(const_cast(var))); -} - -class Benchmark { - public: - Benchmark(const char* name, void (*fn)(int)); - Benchmark(const char* name, void (*fn)(int, int)); - Benchmark(const char* name, void (*fn)(int, int, int)); - - Benchmark* Arg(int x); - Benchmark* ArgPair(int x, int y); - Benchmark* Range(int lo, int hi); - Benchmark* RangePair(int lo1, int hi1, int lo2, int hi2); - static void Run(const char* pattern); - - private: - string name_; - int num_args_; - std::vector > args_; - void (*fn0_)(int) = nullptr; - void (*fn1_)(int, int) = nullptr; - void (*fn2_)(int, int, int) = nullptr; - - void Register(); - void Run(int arg1, int arg2, int* run_count, double* run_seconds); -}; -#endif - -void RunBenchmarks(); -void SetLabel(const std::string& label); -void BytesProcessed(int64); -void ItemsProcessed(int64); -void StartTiming(); -void StopTiming(); -void UseRealTime(); - -} // namespace testing -} // namespace tensorflow #endif // TENSORFLOW_CORE_PLATFORM_TEST_BENCHMARK_H_ diff --git a/tensorflow/core/platform/threadpool.cc b/tensorflow/core/platform/threadpool.cc index fa22ad3867b..18aa7684aba 100644 --- a/tensorflow/core/platform/threadpool.cc +++ b/tensorflow/core/platform/threadpool.cc @@ -19,7 +19,7 @@ limitations under the License. 
#include "absl/types/optional.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/core/lib/core/blocking_counter.h" +#include "tensorflow/core/platform/blocking_counter.h" #include "tensorflow/core/platform/context.h" #include "tensorflow/core/platform/denormal.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/core/platform/tracing.cc b/tensorflow/core/platform/tracing.cc index 30aa664ae01..a7745903d4b 100644 --- a/tensorflow/core/platform/tracing.cc +++ b/tensorflow/core/platform/tracing.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/hash.h" namespace tensorflow { namespace tracing { diff --git a/tensorflow/core/platform/unbounded_work_queue_test.cc b/tensorflow/core/platform/unbounded_work_queue_test.cc index 03d91cd4893..ada99c5e1a3 100644 --- a/tensorflow/core/platform/unbounded_work_queue_test.cc +++ b/tensorflow/core/platform/unbounded_work_queue_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/core/platform/unbounded_work_queue.h" #include "absl/memory/memory.h" -#include "tensorflow/core/lib/core/blocking_counter.h" #include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/platform/blocking_counter.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { diff --git a/tensorflow/core/platform/windows/load_library.cc b/tensorflow/core/platform/windows/load_library.cc index 177253debdc..f95e770cc6b 100644 --- a/tensorflow/core/platform/windows/load_library.cc +++ b/tensorflow/core/platform/windows/load_library.cc @@ -25,7 +25,7 @@ limitations under the License. #undef LoadLibrary #undef ERROR -#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/windows/wide_char.h" #pragma comment(lib, "Shlwapi.lib") diff --git a/tensorflow/core/platform/windows/windows_file_system.h b/tensorflow/core/platform/windows/windows_file_system.h index 255f6d59a6f..2e0de725762 100644 --- a/tensorflow/core/platform/windows/windows_file_system.h +++ b/tensorflow/core/platform/windows/windows_file_system.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PLATFORM_WINDOWS_WINDOWS_FILE_SYSTEM_H_ #define TENSORFLOW_CORE_PLATFORM_WINDOWS_WINDOWS_FILE_SYSTEM_H_ -#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/file_system.h" +#include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/platform.h" #ifdef PLATFORM_WINDOWS diff --git a/tensorflow/core/profiler/internal/traceme_recorder.cc b/tensorflow/core/profiler/internal/traceme_recorder.cc index d191a49fc94..3257a347d66 100644 --- a/tensorflow/core/profiler/internal/traceme_recorder.cc +++ b/tensorflow/core/profiler/internal/traceme_recorder.cc @@ -199,7 +199,9 @@ void TraceMeRecorder::RegisterThread(int32 tid, ThreadLocalRecorder* thread) { void TraceMeRecorder::UnregisterThread(TraceMeRecorder::ThreadEvents&& events) { mutex_lock lock(mutex_); threads_.erase(events.thread.tid); - orphaned_events_.push_back(std::move(events)); + if (!events.events.empty()) { + orphaned_events_.push_back(std::move(events)); + } } // This method is performance critical and should be kept fast. 
It is called @@ -211,7 +213,10 @@ TraceMeRecorder::Events TraceMeRecorder::Clear() { std::swap(orphaned_events_, result); for (const auto& entry : threads_) { auto* recorder = entry.second; - result.push_back(recorder->Clear()); + TraceMeRecorder::ThreadEvents events = recorder->Clear(); + if (!events.events.empty()) { + result.push_back(std::move(events)); + } } return result; } diff --git a/tensorflow/core/protobuf/debug_event.proto b/tensorflow/core/protobuf/debug_event.proto index 06499c2406c..8f9680f38d9 100644 --- a/tensorflow/core/protobuf/debug_event.proto +++ b/tensorflow/core/protobuf/debug_event.proto @@ -87,8 +87,7 @@ message DebugEvent { // a Python function). GraphOpCreation graph_op_creation = 7; - // Information about a debugged graph, including its graph def and - // list of the graph's ops that are instrumented. + // Information about a debugged graph. DebuggedGraph debugged_graph = 8; // Execution of an op or a Graph (e.g., a tf.function). @@ -200,6 +199,9 @@ message DebuggedGraph { // An encoded version of a GraphDef. // This graph may include the debugger-inserted ops. bytes instrumented_graph_def = 5; + + // IDs of the immediate enclosing context (graph), if any. + string outer_context_id = 6; } // Data relating to the eager execution of an op or a Graph. diff --git a/tensorflow/core/util/reporter.cc b/tensorflow/core/util/reporter.cc index eb69e292116..8e9d863b4c2 100644 --- a/tensorflow/core/util/reporter.cc +++ b/tensorflow/core/util/reporter.cc @@ -15,9 +15,9 @@ limitations under the License. #include "tensorflow/core/util/reporter.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/str_util.h" namespace tensorflow { diff --git a/tensorflow/lite/build_def.bzl b/tensorflow/lite/build_def.bzl index 2fb3ea5e714..f37ab23a67a 100644 --- a/tensorflow/lite/build_def.bzl +++ b/tensorflow/lite/build_def.bzl @@ -153,6 +153,7 @@ def tflite_cc_shared_object( linkstatic = 1, deps = [], visibility = None, + per_os_targets = False, tags = None): """Builds a shared object for TFLite.""" tf_cc_shared_object( @@ -164,6 +165,7 @@ def tflite_cc_shared_object( deps = deps, visibility = visibility, tags = tags, + per_os_targets = per_os_targets, ) def tf_to_tflite(name, src, options, out): diff --git a/tensorflow/lite/c/BUILD b/tensorflow/lite/c/BUILD index 37b996c565c..629320370cb 100644 --- a/tensorflow/lite/c/BUILD +++ b/tensorflow/lite/c/BUILD @@ -1,8 +1,128 @@ +load( + "//tensorflow/lite:build_def.bzl", + "tflite_cc_shared_object", + "tflite_copts", +) + package( - default_visibility = ["//visibility:public"], + default_visibility = [":experimental"], licenses = ["notice"], # Apache 2.0 ) +package_group( + name = "experimental", + packages = [ + "//tensorflow/lite/...", + "//third_party/dart/tflite_native/...", # whitelisted + ], +) + +# Generates a platform-specific shared library containing the TensorFlow Lite C +# API implementation as define in `c_api.h`. 
The exact output library name +# is platform dependent: +# - Linux/Android: `libtensorflowlite_c.so` +# - Mac: `libtensorflowlite_c.dylib` +# - Windows: `tensorflowlite_c.dll` +tflite_cc_shared_object( + name = "tensorflowlite_c", + linkopts = select({ + "//tensorflow:macos": [ + "-Wl,-exported_symbols_list,$(location //tensorflow/lite/c:exported_symbols.lds)", + ], + "//tensorflow:windows": [], + "//conditions:default": [ + "-z defs", + "-Wl,--version-script,$(location //tensorflow/lite/c:version_script.lds)", + ], + }), + per_os_targets = True, + deps = [ + ":c_api", + ":c_api_experimental", + ":exported_symbols.lds", + ":version_script.lds", + ], +) + +cc_library( + name = "c_api_internal", + srcs = [ + "c_api.h", + "common.h", + ], + hdrs = ["c_api_internal.h"], + copts = tflite_copts(), + visibility = ["//visibility:private"], + deps = [ + ":common", + "//tensorflow/lite:framework", + ], +) + +cc_library( + name = "c_api", + srcs = ["c_api.cc"], + hdrs = [ + "c_api.h", + "common.h", + ], + copts = tflite_copts(), + visibility = [ + ":experimental", + ], + deps = [ + ":c_api_internal", + ":common", + "//tensorflow/lite:framework", + "//tensorflow/lite:version", + "//tensorflow/lite/kernels:builtin_ops", + ], + alwayslink = 1, +) + +cc_library( + name = "c_api_experimental", + srcs = ["c_api_experimental.cc"], + hdrs = ["c_api_experimental.h"], + copts = tflite_copts(), + deps = [ + ":c_api", + ":c_api_internal", + "//tensorflow/lite:kernel_api", + ], + alwayslink = 1, +) + +cc_test( + name = "c_api_test", + size = "small", + srcs = ["c_api_test.cc"], + data = [ + "//tensorflow/lite:testdata/add.bin", + "//tensorflow/lite:testdata/add_quantized.bin", + ], + deps = [ + ":c_api", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + +cc_test( + name = "c_api_experimental_test", + size = "small", + srcs = ["c_api_experimental_test.cc"], + data = ["//tensorflow/lite:testdata/add.bin"], + deps = [ + ":c_api", + ":c_api_experimental", + "//tensorflow/lite:kernel_api", + "//tensorflow/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + cc_library( name = "common", srcs = ["common.c"], @@ -13,6 +133,7 @@ cc_library( visibility = [ "//tensorflow/lite:__subpackages__", ], + alwayslink = 1, ) # For use with library targets that can't use relative paths. diff --git a/tensorflow/lite/c/README.md b/tensorflow/lite/c/README.md new file mode 100644 index 00000000000..06579199393 --- /dev/null +++ b/tensorflow/lite/c/README.md @@ -0,0 +1,48 @@ +# TensorFlow Lite C API + +This directory contains C APIs for TensorFlow Lite. This includes C APIs +for common types, like kernels and delegates, as well as an explicit C API +for inference. + +## Header summary + +Each public C header contains types and methods for specific uses: + +* `common.h` - Contains common C enums, types and methods used throughout + TensorFlow Lite. This includes everything from error codes, to the kernel + and delegate APIs. +* `builtin_op_data.h` - Contains op-specific data that is used for builtin + kernels. This should only be used when (re)implementing a builtin operator. +* `c_api.h` - Contains the TensorFlow Lite C API for inference. The + functionality here is largely equivalent (though a strict subset of) the + functionality provided by the C++ `Interpreter` API. +* `c_api_experimental.h` - Contains experimental C API methods for inference. + These methods are useful and usable, but aren't yet part of the stable API. 
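As a quick orientation on the header summary above, the includes below sketch which of these headers a client would typically pull in for each use case. This snippet is illustrative only and is not part of the patch itself:

```c++
// Illustrative include choices for the public C headers listed above (sketch only).
#include "tensorflow/lite/c/c_api.h"               // stable C API for inference
#include "tensorflow/lite/c/c_api_experimental.h"  // experimental additions on top of c_api.h
#include "tensorflow/lite/c/common.h"              // shared enums/types (TfLiteStatus, TfLiteTensor, ...)
#include "tensorflow/lite/c/builtin_op_data.h"     // per-op parameter structs for builtin kernels
```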
+ +## Using the C API + +See the [`c_api.h`](c_api.h) header for API usage details. + +## Building the C API + +A native shared library target that contains the C API for inference has been +provided. Assuming a working [bazel](https://bazel.build/versions/master/docs/install.html) +configuration, this can be built as follows: + +```sh +bazel build -c opt --cxxopt=--std=c++11 //tensorflow/lite/c:tensorflowlite_c +``` + +and for Android (replace `android_arm` with `android_arm64` for 64-bit), +assuming you've [configured your project for Android builds](../g3doc/guide/android.md): + +```sh +bazel build -c opt --cxxopt=--std=c++11 --config=android_arm \ + //tensorflow/lite/c:tensorflowlite_c +``` + +The generated shared library will be available in your +`bazel-bin/tensorflow/lite/c` directory. A target which packages the shared +library together with the necessary headers (`c_api.h`, `c_api_experimental.h` +and `common.h`) will be available soon, and will also be released as a prebuilt +archive (together with existing prebuilt packages for Android/iOS). diff --git a/tensorflow/lite/experimental/c/c_api.cc b/tensorflow/lite/c/c_api.cc similarity index 97% rename from tensorflow/lite/experimental/c/c_api.cc rename to tensorflow/lite/c/c_api.cc index ab3ee961bb1..7ceddab4ecf 100644 --- a/tensorflow/lite/experimental/c/c_api.cc +++ b/tensorflow/lite/c/c_api.cc @@ -12,13 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/c/c_api.h" +#include "tensorflow/lite/c/c_api.h" #include +#include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/error_reporter.h" -#include "tensorflow/lite/experimental/c/c_api_internal.h" -#include "tensorflow/lite/experimental/c/c_api_types.h" #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model.h" diff --git a/tensorflow/lite/experimental/c/c_api.h b/tensorflow/lite/c/c_api.h similarity index 83% rename from tensorflow/lite/experimental/c/c_api.h rename to tensorflow/lite/c/c_api.h index 09a045b1f2a..036df27b5d1 100644 --- a/tensorflow/lite/experimental/c/c_api.h +++ b/tensorflow/lite/c/c_api.h @@ -12,28 +12,59 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_C_C_API_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_C_C_API_H_ +#ifndef TENSORFLOW_LITE_C_C_API_H_ +#define TENSORFLOW_LITE_C_C_API_H_ #include #include -// Eventually the various C APIs defined in context.h will be migrated into -// the appropriate /c/c_api*.h header. For now, we pull in existing definitions -// for convenience. -#include "c_api_types.h" +#include "common.h" // -------------------------------------------------------------------------- -// Experimental C API for TensorFlowLite. -// -// The API leans towards simplicity and uniformity instead of convenience, as -// most usage will be by language-specific wrappers. -// -// Conventions: -// * We use the prefix TfLite for everything in the API. -// * size_t is used to represent byte sizes of objects that are -// materialized in the address space of the calling process. 
-// * int is used as an index into arrays. +/// C API for TensorFlow Lite. +/// +/// The API leans towards simplicity and uniformity instead of convenience, as +/// most usage will be by language-specific wrappers. It provides largely the +/// same set of functionality as that of the C++ TensorFlow Lite `Interpreter` +/// API, but is useful for shared libraries where having a stable ABI boundary +/// is important. +/// +/// Conventions: +/// * We use the prefix TfLite for everything in the API. +/// * size_t is used to represent byte sizes of objects that are +/// materialized in the address space of the calling process. +/// * int is used as an index into arrays. +/// +/// Usage: +///

+/// // Create the model and interpreter options.
+/// TfLiteModel* model = TfLiteModelCreateFromFile("/path/to/model.tflite");
+/// TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();
+/// TfLiteInterpreterOptionsSetNumThreads(options, 2);
+///
+/// // Create the interpreter.
+/// TfLiteInterpreter* interpreter = TfLiteInterpreterCreate(model, options);
+///
+/// // Allocate tensors and populate the input tensor data.
+/// TfLiteInterpreterAllocateTensors(interpreter);
+/// TfLiteTensor* input_tensor =
+///     TfLiteInterpreterGetInputTensor(interpreter, 0);
+/// TfLiteTensorCopyFromBuffer(input_tensor, input.data(),
+///                            input.size() * sizeof(float));
+///
+/// // Execute inference.
+/// TfLiteInterpreterInvoke(interpreter);
+///
+/// // Extract the output tensor data.
+/// const TfLiteTensor* output_tensor =
+///     TfLiteInterpreterGetOutputTensor(interpreter, 0);
+/// TfLiteTensorCopyToBuffer(output_tensor, output.data(),
+///                          output.size() * sizeof(float));
+///
+/// // Dispose of the model and interpreter objects.
+/// TfLiteInterpreterDelete(interpreter);
+/// TfLiteInterpreterOptionsDelete(options);
+/// TfLiteModelDelete(model);
 
 #ifdef SWIG
 #define TFL_CAPI_EXPORT
@@ -235,4 +266,4 @@ TFL_CAPI_EXPORT extern TfLiteStatus TfLiteTensorCopyToBuffer(
 }  // extern "C"
 #endif  // __cplusplus
 
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_C_C_API_H_
+#endif  // TENSORFLOW_LITE_C_C_API_H_
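To complement the usage notes embedded in the header above, here is a compact, self-contained sketch of the same flow with status checks added. The model path `/tmp/model.tflite` and the 4-element float input/output shapes are assumptions for illustration, not anything this patch prescribes:

```c++
// Minimal sketch of end-to-end inference through the TF Lite C API.
// Assumes a float model at a hypothetical path whose first input and output
// tensors each hold four floats.
#include <cstdio>
#include <vector>

#include "tensorflow/lite/c/c_api.h"

int main() {
  TfLiteModel* model = TfLiteModelCreateFromFile("/tmp/model.tflite");
  if (!model) return 1;

  TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();
  TfLiteInterpreterOptionsSetNumThreads(options, 2);

  TfLiteInterpreter* interpreter = TfLiteInterpreterCreate(model, options);
  if (!interpreter ||
      TfLiteInterpreterAllocateTensors(interpreter) != kTfLiteOk) {
    return 1;
  }

  // Populate the first input tensor.
  std::vector<float> input = {1.f, 2.f, 3.f, 4.f};
  TfLiteTensor* input_tensor = TfLiteInterpreterGetInputTensor(interpreter, 0);
  TfLiteTensorCopyFromBuffer(input_tensor, input.data(),
                             input.size() * sizeof(float));

  if (TfLiteInterpreterInvoke(interpreter) != kTfLiteOk) return 1;

  // Read back the first output tensor.
  std::vector<float> output(4);
  const TfLiteTensor* output_tensor =
      TfLiteInterpreterGetOutputTensor(interpreter, 0);
  TfLiteTensorCopyToBuffer(output_tensor, output.data(),
                           output.size() * sizeof(float));
  std::printf("output[0] = %f\n", output[0]);

  // Dispose of the objects, interpreter first, as in the header's usage notes.
  TfLiteInterpreterDelete(interpreter);
  TfLiteInterpreterOptionsDelete(options);
  TfLiteModelDelete(model);
  return 0;
}
```

Linking against the `//tensorflow/lite/c:c_api` target (or the `tensorflowlite_c` shared library described in the README above) is assumed.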
diff --git a/tensorflow/lite/experimental/c/c_api_experimental.cc b/tensorflow/lite/c/c_api_experimental.cc
similarity index 93%
rename from tensorflow/lite/experimental/c/c_api_experimental.cc
rename to tensorflow/lite/c/c_api_experimental.cc
index 5bc305ef64b..4b812172937 100644
--- a/tensorflow/lite/experimental/c/c_api_experimental.cc
+++ b/tensorflow/lite/c/c_api_experimental.cc
@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/lite/experimental/c/c_api_experimental.h"
+#include "tensorflow/lite/c/c_api_experimental.h"
 
-#include "tensorflow/lite/experimental/c/c_api_internal.h"
+#include "tensorflow/lite/c/c_api_internal.h"
 
 #ifdef __cplusplus
 extern "C" {
diff --git a/tensorflow/lite/experimental/c/c_api_experimental.h b/tensorflow/lite/c/c_api_experimental.h
similarity index 92%
rename from tensorflow/lite/experimental/c/c_api_experimental.h
rename to tensorflow/lite/c/c_api_experimental.h
index ce1a4a37293..a8f1a4294f5 100644
--- a/tensorflow/lite/experimental/c/c_api_experimental.h
+++ b/tensorflow/lite/c/c_api_experimental.h
@@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_C_C_API_EXPERIMENTAL_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_C_C_API_EXPERIMENTAL_H_
+#ifndef TENSORFLOW_LITE_C_C_API_EXPERIMENTAL_H_
+#define TENSORFLOW_LITE_C_C_API_EXPERIMENTAL_H_
 
 #include "tensorflow/lite/builtin_ops.h"
-#include "tensorflow/lite/experimental/c/c_api.h"
+#include "tensorflow/lite/c/c_api.h"
 
 #ifdef __cplusplus
 extern "C" {
diff --git a/tensorflow/lite/experimental/c/c_api_experimental_test.cc b/tensorflow/lite/c/c_api_experimental_test.cc
similarity index 94%
rename from tensorflow/lite/experimental/c/c_api_experimental_test.cc
rename to tensorflow/lite/c/c_api_experimental_test.cc
index 0d383998a29..ce72954774c 100644
--- a/tensorflow/lite/experimental/c/c_api_experimental_test.cc
+++ b/tensorflow/lite/c/c_api_experimental_test.cc
@@ -13,11 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/lite/experimental/c/c_api_experimental.h"
+#include "tensorflow/lite/c/c_api_experimental.h"
 
 #include <gtest/gtest.h>
 #include "tensorflow/lite/builtin_ops.h"
-#include "tensorflow/lite/experimental/c/c_api.h"
+#include "tensorflow/lite/c/c_api.h"
 #include "tensorflow/lite/testing/util.h"
 
 namespace {
diff --git a/tensorflow/lite/experimental/c/c_api_internal.h b/tensorflow/lite/c/c_api_internal.h
similarity index 91%
rename from tensorflow/lite/experimental/c/c_api_internal.h
rename to tensorflow/lite/c/c_api_internal.h
index 8f5c301bc1d..474482d159a 100644
--- a/tensorflow/lite/experimental/c/c_api_internal.h
+++ b/tensorflow/lite/c/c_api_internal.h
@@ -12,16 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_C_C_API_INTERNAL_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_C_C_API_INTERNAL_H_
+#ifndef TENSORFLOW_LITE_C_C_API_INTERNAL_H_
+#define TENSORFLOW_LITE_C_C_API_INTERNAL_H_
 
-#include "tensorflow/lite/experimental/c/c_api.h"
+#include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/interpreter.h"
 #include "tensorflow/lite/model.h"
 #include "tensorflow/lite/op_resolver.h"
 
 // Internal structures used by the C API. These are likely to change and should
-// not be depended on.
+// not be depended on directly by any C API clients.
 //
 // NOTE: This header does not follow C conventions and does not define a C API.
 // It is effectively an (internal) implementation detail of the C API.
diff --git a/tensorflow/lite/experimental/c/c_api_test.cc b/tensorflow/lite/c/c_api_test.cc
similarity index 99%
rename from tensorflow/lite/experimental/c/c_api_test.cc
rename to tensorflow/lite/c/c_api_test.cc
index 8de0f414086..eb2a70f9f0b 100644
--- a/tensorflow/lite/experimental/c/c_api_test.cc
+++ b/tensorflow/lite/c/c_api_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/lite/experimental/c/c_api.h"
+#include "tensorflow/lite/c/c_api.h"
 
 #include 
 #include 
diff --git a/tensorflow/lite/experimental/c/exported_symbols.lds b/tensorflow/lite/c/exported_symbols.lds
similarity index 100%
rename from tensorflow/lite/experimental/c/exported_symbols.lds
rename to tensorflow/lite/c/exported_symbols.lds
diff --git a/tensorflow/lite/experimental/c/version_script.lds b/tensorflow/lite/c/version_script.lds
similarity index 100%
rename from tensorflow/lite/experimental/c/version_script.lds
rename to tensorflow/lite/c/version_script.lds
diff --git a/tensorflow/lite/core/api/profiler.h b/tensorflow/lite/core/api/profiler.h
index aea70cd73f8..7bc296510d4 100644
--- a/tensorflow/lite/core/api/profiler.h
+++ b/tensorflow/lite/core/api/profiler.h
@@ -93,9 +93,6 @@ class ScopedOperatorProfile : public ScopedProfile {
   tflite::ScopedOperatorProfile TFLITE_VARNAME_UNIQ(_profile_, __COUNTER__)( \
       (profiler), (tag), (node_index))
 
-#define TFLITE_SCOPED_OPERATOR_PROFILE(profiler, node_index) \
-  TFLITE_SCOPED_TAGGED_OPERATOR_PROFILE((profiler), "OpInvoke", (node_index))
-
 #define TFLITE_SCOPED_DELEGATE_OPERATOR_PROFILE(profiler, node_index)   \
   TFLITE_SCOPED_TAGGED_OPERATOR_PROFILE((profiler), "DelegateOpInvoke", \
                                         (node_index))
diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc
index 38a9d24d782..e453ff2ff7e 100644
--- a/tensorflow/lite/core/subgraph.cc
+++ b/tensorflow/lite/core/subgraph.cc
@@ -758,7 +758,17 @@ TfLiteStatus Subgraph::Invoke() {
     TfLiteNode& node = nodes_and_registration_[node_index].first;
     const TfLiteRegistration& registration =
         nodes_and_registration_[node_index].second;
-    TFLITE_SCOPED_OPERATOR_PROFILE(profiler_.get(), node_index);
+
+    const char* op_name = nullptr;
+    if (profiler_) {
+      if (registration.builtin_code == tflite::BuiltinOperator_CUSTOM) {
+        const char* const custom_name = registration.custom_name;
+        op_name = custom_name ? custom_name : "UnknownCustomOp";
+      } else {
+        op_name = tflite::EnumNamesBuiltinOperator()[registration.builtin_code];
+      }
+    }
+    TFLITE_SCOPED_TAGGED_OPERATOR_PROFILE(profiler_.get(), op_name, node_index);
 
     // TODO(ycling): This is an extra loop through inputs to check if the data
     // need to be copied from Delegate buffer to raw memory, which is often not
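
The Invoke() loop now resolves a human-readable operator name (the custom op's name,
or the builtin enum name) and passes it as the profiling tag, replacing the removed
fixed-tag TFLITE_SCOPED_OPERATOR_PROFILE macro. A hedged standalone sketch of the same
lookup (GetOpName is a hypothetical helper; the logic mirrors the lines added above):

```c++
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/schema/schema_generated.h"

// Returns the tag used for a node's profiler scope.
const char* GetOpName(const TfLiteRegistration& registration) {
  if (registration.builtin_code == tflite::BuiltinOperator_CUSTOM) {
    // Custom ops carry their own name; fall back to a generic label.
    return registration.custom_name ? registration.custom_name
                                    : "UnknownCustomOp";
  }
  // Builtin ops are named via the flatbuffer-generated enum-name table.
  return tflite::EnumNamesBuiltinOperator()[registration.builtin_code];
}
```
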
diff --git a/tensorflow/lite/delegates/gpu/BUILD b/tensorflow/lite/delegates/gpu/BUILD
index 27940fd88be..4cfbeff2081 100644
--- a/tensorflow/lite/delegates/gpu/BUILD
+++ b/tensorflow/lite/delegates/gpu/BUILD
@@ -37,6 +37,7 @@ cc_library(
         "//conditions:default": [],
     }),
     deps = [
+        "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/types:span",
         "//tensorflow/lite:kernel_api",
         "//tensorflow/lite:minimal_logging",
@@ -106,7 +107,7 @@ objc_library(
     ],
 )
 
-# build -c opt --config android_arm64 --copt -Os --copt -DTFLITE_GPU_BINARY_RELEASE --copt -fvisibility=hidden --linkopt -s --strip always :libtensorflowlite_gpu_gl.so
+# build -c opt --config android_arm64 --copt -Os --copt -DTFLITE_GPU_BINARY_RELEASE --linkopt -s --strip always :libtensorflowlite_gpu_gl.so
 cc_binary(
     name = "libtensorflowlite_gpu_gl.so",
     linkopts = [
@@ -115,8 +116,12 @@ cc_binary(
         "//tensorflow:android": [
             "-lEGL",
             "-lGLESv3",
+            "-fvisibility=hidden",
+        ],
+        "//tensorflow:windows": [],
+        "//conditions:default": [
+            "-fvisibility=hidden",
         ],
-        "//conditions:default": [],
     }),
     linkshared = 1,
     linkstatic = 1,
@@ -127,7 +132,7 @@ cc_binary(
     deps = [":gl_delegate"],
 )
 
-# build -c opt --config android_arm64 --copt -Os --copt -DTFLITE_GPU_BINARY_RELEASE --copt -fvisibility=hidden --linkopt -s --strip always :libtensorflowlite_gpu_delegate.so
+# build -c opt --config android_arm64 --copt -Os --copt -DTFLITE_GPU_BINARY_RELEASE --linkopt -s --strip always :libtensorflowlite_gpu_delegate.so
 cc_binary(
     name = "libtensorflowlite_gpu_delegate.so",
     linkopts = [
@@ -136,8 +141,12 @@ cc_binary(
         "//tensorflow:android": [
             "-lEGL",
             "-lGLESv3",
+            "-fvisibility=hidden",
+        ],
+        "//tensorflow:windows": [],
+        "//conditions:default": [
+            "-fvisibility=hidden",
         ],
-        "//conditions:default": [],
     }),
     linkshared = 1,
     linkstatic = 1,
diff --git a/tensorflow/lite/delegates/gpu/common/BUILD b/tensorflow/lite/delegates/gpu/common/BUILD
index 20ff67677a9..4da852b0565 100644
--- a/tensorflow/lite/delegates/gpu/common/BUILD
+++ b/tensorflow/lite/delegates/gpu/common/BUILD
@@ -19,6 +19,19 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "custom_parsers",
+    srcs = ["custom_parsers.cc"],
+    hdrs = ["custom_parsers.h"],
+    deps = [
+        "//tensorflow/lite/delegates/gpu/common:shape",
+        "//tensorflow/lite/delegates/gpu/common:status",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:any",
+        "@flatbuffers",
+    ],
+)
+
 cc_library(
     name = "access_type",
     hdrs = ["access_type.h"],
@@ -96,6 +109,7 @@ cc_library(
     srcs = ["model_builder.cc"],
     hdrs = ["model_builder.h"],
     deps = [
+        ":custom_parsers",
         ":data_type",
         ":model",
         ":operations",
diff --git a/tensorflow/lite/delegates/gpu/common/custom_parsers.cc b/tensorflow/lite/delegates/gpu/common/custom_parsers.cc
new file mode 100644
index 00000000000..d46a9247c81
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/common/custom_parsers.cc
@@ -0,0 +1,36 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/delegates/gpu/common/custom_parsers.h"
+
+#include 
+
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/any.h"
+#include "tensorflow/lite/delegates/gpu/common/shape.h"
+#include "tensorflow/lite/delegates/gpu/common/status.h"
+
+namespace tflite {
+namespace gpu {
+
+Status ParseCustomAttributes(absl::string_view op_name, const void* data,
+                             uint32_t data_size, absl::any* attr,
+                             BHWC* output_shape) {
+  return UnimplementedError(absl::StrCat(
+      "Attributes parsing is not enabled for ", op_name, " operation"));
+}
+
+}  // namespace gpu
+}  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/common/custom_parsers.h b/tensorflow/lite/delegates/gpu/common/custom_parsers.h
new file mode 100644
index 00000000000..e9a191d46cb
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/common/custom_parsers.h
@@ -0,0 +1,37 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_CUSTOM_PARSERS_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_CUSTOM_PARSERS_H_
+
+#include 
+
+#include "absl/strings/string_view.h"
+#include "absl/types/any.h"
+#include "tensorflow/lite/delegates/gpu/common/shape.h"
+#include "tensorflow/lite/delegates/gpu/common/status.h"
+
+namespace tflite {
+namespace gpu {
+
+// Matches the custom operation by the string name and parses attributes stored
+// as flexbuffers.
+Status ParseCustomAttributes(absl::string_view op_name, const void* data,
+                             uint32_t data_size, absl::any* attr,
+                             BHWC* output_shape);
+
+}  // namespace gpu
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_CUSTOM_PARSERS_H_
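
The default ParseCustomAttributes() above always returns UnimplementedError, so a
caller treats any non-OK status as "this custom op cannot be converted"; presumably a
different implementation of the same symbol can be linked in to enable specific ops.
A hedged sketch of a call site (the flexbuffer blob is elided, which exercises the
unimplemented path):

```c++
#include "absl/types/any.h"
#include "tensorflow/lite/delegates/gpu/common/custom_parsers.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"

void TryParseTransformTensor() {
  absl::any attributes;
  tflite::gpu::BHWC output_shape;
  tflite::gpu::Status status = tflite::gpu::ParseCustomAttributes(
      "transform_tensor", /*data=*/nullptr, /*data_size=*/0, &attributes,
      &output_shape);
  if (!status.ok()) {
    // With the default implementation this branch is always taken.
  }
}
```
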
diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc
index d7fe8938699..8e33c4eeb75 100644
--- a/tensorflow/lite/delegates/gpu/common/model_builder.cc
+++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc
@@ -36,6 +36,7 @@ limitations under the License.
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/context.h"
+#include "tensorflow/lite/delegates/gpu/common/custom_parsers.h"
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
@@ -843,7 +844,7 @@ class Conv2DOperationParser : public TFLiteOperationParser {
   Status IsSupported(const TfLiteContext* context,
                      const TfLiteNode* tflite_node,
                      const TfLiteRegistration* registration) final {
-    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
+    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
     RETURN_IF_ERROR(
         CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
     RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
@@ -2227,6 +2228,110 @@ class SpaceToBatchOperationParser : public TFLiteOperationParser {
   }
 };
 
+class RoIToTransformMatrixOperationParser : public TFLiteOperationParser {
+ public:
+  Status IsSupported(const TfLiteContext* context,
+                     const TfLiteNode* tflite_node,
+                     const TfLiteRegistration* registration) final {
+    RETURN_IF_ERROR(
+        CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
+    return OkStatus();
+  }
+
+  Status Parse(const TfLiteNode* tflite_node,
+               const TfLiteRegistration* registration, GraphFloat32* graph,
+               ObjectReader* reader) final {
+    Node* node = graph->NewNode();
+    RETURN_IF_ERROR(reader->AddInput(node, 0));  // bbox
+    RETURN_IF_ERROR(reader->AddOutputs(node));
+
+    std::string op_name = "roi_to_transform_matrix";
+    node->operation.type = op_name;
+    BHWC output_shape;
+    RETURN_IF_ERROR(
+        ParseCustomAttributes(op_name, tflite_node->custom_initial_data,
+                              tflite_node->custom_initial_data_size,
+                              &(node->operation.attributes), &output_shape));
+
+    auto output_value = graph->FindOutputs(node->id)[0];
+    output_value->tensor.shape = output_shape;
+    return OkStatus();
+  }
+
+ private:
+};
+
+class TransformTensorOperationParser : public TFLiteOperationParser {
+ public:
+  Status IsSupported(const TfLiteContext* context,
+                     const TfLiteNode* tflite_node,
+                     const TfLiteRegistration* registration) final {
+    RETURN_IF_ERROR(
+        CheckInputsOutputs(context, tflite_node, /*inputs=*/2, /*outputs=*/1));
+    return OkStatus();
+  }
+
+  Status Parse(const TfLiteNode* tflite_node,
+               const TfLiteRegistration* registration, GraphFloat32* graph,
+               ObjectReader* reader) final {
+    Node* node = graph->NewNode();
+    RETURN_IF_ERROR(reader->AddInput(node, 0));  // data
+    RETURN_IF_ERROR(reader->AddInput(node, 1));  // bbox
+    RETURN_IF_ERROR(reader->AddOutputs(node));
+
+    std::string op_name = "transform_tensor";
+    node->operation.type = op_name;
+    BHWC output_shape;
+    RETURN_IF_ERROR(
+        ParseCustomAttributes(op_name, tflite_node->custom_initial_data,
+                              tflite_node->custom_initial_data_size,
+                              &(node->operation.attributes), &output_shape));
+
+    auto output_value = graph->FindOutputs(node->id)[0];
+
+    output_value->tensor.shape =
+        BHWC(1, output_shape.h, output_shape.w,
+             graph->FindInputs(node->id)[0]->tensor.shape.c);
+    return OkStatus();
+  }
+
+ private:
+};
+
+class TransformLandmarksOperationParser : public TFLiteOperationParser {
+ public:
+  Status IsSupported(const TfLiteContext* context,
+                     const TfLiteNode* tflite_node,
+                     const TfLiteRegistration* registration) final {
+    RETURN_IF_ERROR(
+        CheckInputsOutputs(context, tflite_node, /*inputs=*/2, /*outputs=*/1));
+    return OkStatus();
+  }
+
+  Status Parse(const TfLiteNode* tflite_node,
+               const TfLiteRegistration* registration, GraphFloat32* graph,
+               ObjectReader* reader) final {
+    Node* node = graph->NewNode();
+    RETURN_IF_ERROR(reader->AddInput(node, 0));  // data
+    RETURN_IF_ERROR(reader->AddInput(node, 1));  // bbox
+    RETURN_IF_ERROR(reader->AddOutputs(node));
+    std::string op_name = "transform_landmarks";
+    node->operation.type = op_name;
+    BHWC output_shape;
+    RETURN_IF_ERROR(
+        ParseCustomAttributes(op_name, tflite_node->custom_initial_data,
+                              tflite_node->custom_initial_data_size,
+                              &(node->operation.attributes), &output_shape));
+
+    auto output_value = graph->FindOutputs(node->id)[0];
+
+    output_value->tensor.shape = graph->FindInputs(node->id)[0]->tensor.shape;
+    return OkStatus();
+  }
+
+ private:
+};
+
 class UnsupportedOperationParser : public TFLiteOperationParser {
  public:
   Status IsSupported(const TfLiteContext* context,
@@ -2332,6 +2437,18 @@ std::unique_ptr<TFLiteOperationParser> NewOperationParser(
       if (custom_name == "MaxUnpooling2D") {
         return absl::make_unique();
       }
+      if (custom_name == "RoIToTransformMatrix") {
+        return absl::make_unique<RoIToTransformMatrixOperationParser>();
+      }
+
+      if (custom_name == "TransformTensor") {
+        return absl::make_unique<TransformTensorOperationParser>();
+      }
+
+      if (custom_name == "TransformLandmarks") {
+        return absl::make_unique<TransformLandmarksOperationParser>();
+      }
+
       break;
   }
   return absl::make_unique<UnsupportedOperationParser>();
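
Each of the three custom ops follows the same recipe: a parser validates the node's
inputs/outputs, creates a graph node whose operation.type is a string name, fills its
attributes (and output shape) via ParseCustomAttributes(), and is then matched by
custom_name in NewOperationParser(). A hedged sketch of wiring up one more custom op
in the same style; "MyCustomOp" and the class name are illustrative only:

```c++
// Hypothetical parser, intended to sit next to the parsers added above in
// model_builder.cc (it reuses that file's TFLiteOperationParser helpers).
class MyCustomOpOperationParser : public TFLiteOperationParser {
 public:
  Status IsSupported(const TfLiteContext* context,
                     const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration) final {
    return CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1);
  }

  Status Parse(const TfLiteNode* tflite_node,
               const TfLiteRegistration* registration, GraphFloat32* graph,
               ObjectReader* reader) final {
    Node* node = graph->NewNode();
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddOutputs(node));
    node->operation.type = "my_custom_op";
    BHWC output_shape;
    RETURN_IF_ERROR(ParseCustomAttributes(
        "my_custom_op", tflite_node->custom_initial_data,
        tflite_node->custom_initial_data_size, &node->operation.attributes,
        &output_shape));
    graph->FindOutputs(node->id)[0]->tensor.shape = output_shape;
    return OkStatus();
  }
};

// ...and matched in NewOperationParser():
//   if (custom_name == "MyCustomOp") {
//     return absl::make_unique<MyCustomOpOperationParser>();
//   }
```
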
diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/softmax.cc b/tensorflow/lite/delegates/gpu/gl/kernels/softmax.cc
index 871cd505368..efaf39390d9 100644
--- a/tensorflow/lite/delegates/gpu/gl/kernels/softmax.cc
+++ b/tensorflow/lite/delegates/gpu/gl/kernels/softmax.cc
@@ -62,14 +62,17 @@ class Softmax : public NodeShader {
     std::string source = R"(
   highp float sum = 0.0;
   for (int d = 0; d < $src_depth$ - 1; ++d) {
-    sum += dot(vec4(1.0), exp($input_data_0[gid.x, gid.y, d]$));
+    highp vec4 v = $input_data_0[gid.x, gid.y, d]$;
+    sum += dot(vec4(1.0), exp(v));
   }
   {
     int d = $src_depth$ - 1;
-    sum += dot($mask$, exp($input_data_0[gid.x, gid.y, d]$));
+    highp vec4 v = $input_data_0[gid.x, gid.y, d]$;
+    sum += dot($mask$, exp(v));
   }
   for (int d = 0; d < $src_depth$; ++d) {
-    vec4 temp_sum = exp($input_data_0[gid.x, gid.y, d]$) / sum;
+    highp vec4 v = $input_data_0[gid.x, gid.y, d]$;
+    vec4 temp_sum = exp(v) / sum;
     $output_data_0[gid.x, gid.y, d] = temp_sum$;
   }
 )";
diff --git a/tensorflow/lite/delegates/gpu/gl_delegate.h b/tensorflow/lite/delegates/gpu/gl_delegate.h
index f1d30fd946e..bfc15fb120e 100644
--- a/tensorflow/lite/delegates/gpu/gl_delegate.h
+++ b/tensorflow/lite/delegates/gpu/gl_delegate.h
@@ -19,6 +19,7 @@ limitations under the License.
 #include 
 
 #include 
+#include "absl/base/macros.h"
 #include "tensorflow/lite/c/common.h"
 
 #ifdef SWIG
@@ -39,6 +40,15 @@ limitations under the License.
 extern "C" {
 #endif  // __cplusplus
 
+// WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING
+//
+// The GPU delegate declared in this file is OBSOLETE and has been replaced by
+// the delegate declared in delegate.h, which combines the GL, CL, and (soon)
+// Vulkan-based implementations in one.
+// Please migrate before the end of 2019.
+//
+// WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING
+
 // LINT.IfChange
 enum TfLiteGlObjectType {
   TFLITE_GL_OBJECT_TYPE_FASTEST = 0,
@@ -109,6 +119,7 @@ TFL_CAPI_EXPORT TfLiteGpuDelegateOptions TfLiteGpuDelegateOptionsDefault();
 //   .preferred_gl_object_type = TFLITE_GL_OBJECT_TYPE_FASTEST,
 //   .dynamic_batch_enabled = false,
 // },
+ABSL_DEPRECATED("Use TfLiteGpuDelegateV2Create defined in delegate.h instead.")
 TFL_CAPI_EXPORT TfLiteDelegate* TfLiteGpuDelegateCreate(
     const TfLiteGpuDelegateOptions* options);
 
diff --git a/tensorflow/lite/delegates/nnapi/BUILD b/tensorflow/lite/delegates/nnapi/BUILD
index 572adc5a0cc..54251676da3 100644
--- a/tensorflow/lite/delegates/nnapi/BUILD
+++ b/tensorflow/lite/delegates/nnapi/BUILD
@@ -103,6 +103,7 @@ cc_library(
     }),
     deps = [
         ":nnapi_delegate",
+        "//tensorflow/lite/nnapi:nnapi_handler",
         "//tensorflow/lite/nnapi:nnapi_implementation",
         "@com_google_absl//absl/memory",
         "@com_google_googletest//:gtest",
@@ -156,6 +157,30 @@ cc_test(
     ],
 )
 
+cc_test(
+    name = "nnapi_delegate_device_selection_test",
+    size = "small",
+    srcs = [
+        "nnapi_delegate_device_selection_test.cc",
+    ],
+    tags = [
+        "no_mac",
+        "no_windows",
+        "tflite_not_portable_ios",
+    ],
+    deps = [
+        ":nnapi_delegate",
+        ":nnapi_delegate_mock_test",
+        "//tensorflow/lite:framework",
+        "//tensorflow/lite:minimal_logging",
+        "//tensorflow/lite/c:common",
+        "//tensorflow/lite/kernels:test_util",
+        "//tensorflow/lite/nnapi:nnapi_implementation",
+        "//tensorflow/lite/nnapi:nnapi_lib",
+        "@com_google_googletest//:gtest",
+    ],
+)
+
 cc_test(
     name = "quant_lstm_sup_test",
     size = "small",
diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
index 3e4967aebfc..cc73f3020e5 100644
--- a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
+++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
@@ -2873,12 +2873,34 @@ TfLiteStatus NNAPIDelegateKernel::Init(TfLiteContext* context,
   const auto delegate_options =
       StatefulNnApiDelegate::GetOptions(params->delegate);
   const char* device_name_ptr = delegate_options.accelerator_name;
-  // user specified an acclelerator to use.
-  if (nnapi_->android_sdk_version >= kMinSdkVersionForNNAPI12 &&
-      device_name_ptr != nullptr) {
-    nnapi_device_ = GetDeviceHandle(context, device_name_ptr);
-    if (nnapi_device_ == nullptr) {
-      return kTfLiteError;
+  if (nnapi_->android_sdk_version >= kMinSdkVersionForNNAPI12) {
+    if (device_name_ptr != nullptr) {
+      // User specified an accelerator to use.
+      ANeuralNetworksDevice* nnapi_device =
+          GetDeviceHandle(context, device_name_ptr);
+      if (nnapi_device == nullptr) {
+        return kTfLiteError;
+      }
+      nnapi_devices_.push_back(nnapi_device);
+    } else if (delegate_options.disallow_nnapi_cpu) {
+      std::string nnapi_cpu("nnapi-reference");
+      uint32_t num_devices = 0;
+      NnApiImplementation()->ANeuralNetworks_getDeviceCount(&num_devices);
+
+      for (uint32_t i = 0; i < num_devices; i++) {
+        ANeuralNetworksDevice* device = nullptr;
+        const char* buffer = nullptr;
+        NnApiImplementation()->ANeuralNetworks_getDevice(i, &device);
+        NnApiImplementation()->ANeuralNetworksDevice_getName(device, &buffer);
+        if (nnapi_cpu != buffer) {
+          nnapi_devices_.push_back(device);
+        }
+      }
+      if (nnapi_devices_.empty()) {
+        context->ReportError(
+            context, "NNAPI delegate requested but no accelerators available.");
+        return kTfLiteError;
+      }
     }
   }
 
@@ -2898,12 +2920,13 @@ TfLiteStatus NNAPIDelegateKernel::Init(TfLiteContext* context,
 
   if (!nn_compilation_) {
     ANeuralNetworksCompilation* compilation = nullptr;
-    if (nnapi_device_ != nullptr) {
+    if (!nnapi_devices_.empty()) {
       // Compile for the selected accelerator.
       RETURN_TFLITE_ERROR_IF_NN_ERROR(
           context,
           nnapi_->ANeuralNetworksCompilation_createForDevices(
-              nn_model_.get(), &nnapi_device_, 1, &compilation),
+              nn_model_.get(), nnapi_devices_.data(), nnapi_devices_.size(),
+              &compilation),
           nnapi_errno);
     } else {
       RETURN_TFLITE_ERROR_IF_NN_ERROR(context,
@@ -3587,6 +3610,7 @@ StatefulNnApiDelegate::StatefulNnApiDelegate(Options options)
   if (options.model_token) {
     delegate_data_.model_token = options.model_token;
   }
+  delegate_data_.disallow_nnapi_cpu = options.disallow_nnapi_cpu;
   TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO,
                        "Created TensorFlow Lite delegate for NNAPI.");
   Prepare = DoPrepare;
@@ -3613,6 +3637,7 @@ const StatefulNnApiDelegate::Options StatefulNnApiDelegate::GetOptions(
   options.model_token = delegate_data->model_token.empty()
                             ? nullptr
                             : delegate_data->model_token.c_str();
+  options.disallow_nnapi_cpu = delegate_data->disallow_nnapi_cpu;
   return options;
 }
 
diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate.h b/tensorflow/lite/delegates/nnapi/nnapi_delegate.h
index 9fdbe626320..022e9ed53ac 100644
--- a/tensorflow/lite/delegates/nnapi/nnapi_delegate.h
+++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate.h
@@ -63,6 +63,12 @@ class StatefulNnApiDelegate : public TfLiteDelegate {
     // NOTE: when using compilation caching, it is not recommended to use the
     // same delegate instance for multiple models.
     const char* model_token = nullptr;
+
+    // Whether to disallow NNAPI CPU usage. Only effective on Android 10 and
+    // above. The NNAPI CPU implementation typically performs worse than the
+    // built-in TfLite kernels, but allowing it enables partial acceleration of
+    // models. If set to true, NNAPI is used only if the whole model is accelerated.
+    bool disallow_nnapi_cpu = false;
   };
 
   // Uses default options.
@@ -131,6 +137,8 @@ class StatefulNnApiDelegate : public TfLiteDelegate {
     std::string cache_dir;
     // The unique token string for NNAPI model.
     std::string model_token;
+    // Whether to disallow NNAPI CPU.
+    bool disallow_nnapi_cpu;
     // Tensor to ANeuralNetworksMemory mapping.
     std::vector tensor_memory_map;
     // Contains a non zero value if any NNAPI method call
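
The new disallow_nnapi_cpu flag flows through StatefulNnApiDelegate::Options: on
Android 10+ the delegate enumerates NNAPI devices, drops "nnapi-reference", and fails
if no real accelerator remains. A hedged usage sketch, assuming an already-built
tflite::Interpreter named "interpreter":

```c++
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"

tflite::StatefulNnApiDelegate::Options options;
options.disallow_nnapi_cpu = true;  // only real accelerators, no nnapi-reference
tflite::StatefulNnApiDelegate delegate(options);
if (interpreter->ModifyGraphWithDelegate(&delegate) != kTfLiteOk) {
  // No accelerator available (or delegation failed); stay on TfLite CPU kernels.
}
```
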
diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate_device_selection_test.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate_device_selection_test.cc
new file mode 100644
index 00000000000..146bf1eaa47
--- /dev/null
+++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_device_selection_test.cc
@@ -0,0 +1,190 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <sys/mman.h>
+
+#include <gtest/gtest.h>
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
+#include "tensorflow/lite/delegates/nnapi/nnapi_delegate_mock_test.h"
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/model.h"
+#include "tensorflow/lite/nnapi/NeuralNetworksTypes.h"
+#include "tensorflow/lite/nnapi/nnapi_implementation.h"
+
+namespace tflite {
+namespace {
+
+class SingleOpModelWithNNAPI : public SingleOpModel {
+ public:
+  SingleOpModelWithNNAPI() = default;
+  void Init(tflite::StatefulNnApiDelegate::Options options) {
+    stateful_delegate_.reset(new StatefulNnApiDelegate(options));
+    auto* delegate = stateful_delegate_.get();
+    this->SetApplyDelegate([delegate, this](Interpreter* interpreter) {
+      compilation_status_ = interpreter->ModifyGraphWithDelegate(delegate);
+    });
+  }
+
+  StatefulNnApiDelegate* GetDelegate() { return stateful_delegate_.get(); }
+
+  void SetBufferHandle(int index, TfLiteBufferHandle handle) {
+    interpreter_->SetBufferHandle(index, handle, stateful_delegate_.get());
+  }
+  TfLiteStatus GetCompilationStatus() { return compilation_status_; }
+
+ private:
+  std::unique_ptr<StatefulNnApiDelegate> stateful_delegate_;
+  TfLiteStatus compilation_status_;
+};
+
+class FloatAddOpModel : public SingleOpModelWithNNAPI {
+ public:
+  FloatAddOpModel() = default;
+  void Init(tflite::StatefulNnApiDelegate::Options options,
+            const TensorData& input1, const TensorData& input2,
+            const TensorData& output, ActivationFunctionType activation_type,
+            bool allow_fp32_relax_to_fp16 = false) {
+    SingleOpModelWithNNAPI::Init(options);
+    input1_ = AddInput(input1);
+    input2_ = AddInput(input2);
+    output_ = AddOutput(output);
+    SetBuiltinOp(BuiltinOperator_ADD, BuiltinOptions_AddOptions,
+                 CreateAddOptions(builder_, activation_type).Union());
+    BuildInterpreter({GetShape(input1_), GetShape(input2_)},
+                     allow_fp32_relax_to_fp16);
+  }
+
+  int input1() { return input1_; }
+  int input2() { return input2_; }
+
+  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ protected:
+  int input1_;
+  int input2_;
+  int output_;
+
+ private:
+};
+
+struct NnApiDeviceSelectionTest
+    : ::tflite::delegate::nnapi::NnApiDelegateMockTest {
+  void SetUp() override {
+    ::tflite::delegate::nnapi::NnApiDelegateMockTest::SetUp();
+    nnapi_->ANeuralNetworks_getDeviceCount = [](uint32_t* numDevices) -> int {
+      *numDevices = 3;
+      return 0;
+    };
+    nnapi_->ANeuralNetworks_getDevice =
+        [](uint32_t devIndex, ANeuralNetworksDevice** device) -> int {
+      *device = reinterpret_cast<ANeuralNetworksDevice*>(devIndex + 1);
+      return 0;
+    };
+    nnapi_->ANeuralNetworksDevice_getName =
+        [](const ANeuralNetworksDevice* device, const char** name) -> int {
+      if (device == reinterpret_cast<ANeuralNetworksDevice*>(1)) {
+        *name = "dsp";
+      } else if (device == reinterpret_cast<ANeuralNetworksDevice*>(2)) {
+        *name = "gpu";
+      } else {
+        *name = "nnapi-reference";
+      }
+      return 0;
+    };
+  }
+  void InitWithOptions(tflite::StatefulNnApiDelegate::Options options) {
+    m.Init(options, {TensorType_FLOAT32, {1, 2, 2, 1}},
+           {TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {}},
+           ActivationFunctionType_NONE);
+    m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
+    m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
+  }
+  FloatAddOpModel m;
+};
+
+TEST_F(NnApiDeviceSelectionTest, DoesntSetDevicesWithoutFlags) {
+  nnapi_->ANeuralNetworksCompilation_createForDevices =
+      [](ANeuralNetworksModel* model,
+         const ANeuralNetworksDevice* const* devices, uint32_t numDevices,
+         ANeuralNetworksCompilation** compilation) -> int {
+    EXPECT_TRUE(false) << "Should not call createForDevices";
+    return 1;
+  };
+
+  tflite::StatefulNnApiDelegate::Options options;
+  InitWithOptions(options);
+  m.Invoke();
+  EXPECT_EQ(m.GetCompilationStatus(), kTfLiteOk);
+}
+
+TEST_F(NnApiDeviceSelectionTest, SetsDeviceBasedOnOptions) {
+  nnapi_mock_->CompilationCreateReturns<1>();
+  nnapi_->ANeuralNetworksCompilation_createForDevices =
+      [](ANeuralNetworksModel* model,
+         const ANeuralNetworksDevice* const* devices, uint32_t numDevices,
+         ANeuralNetworksCompilation** compilation) -> int {
+    EXPECT_EQ(numDevices, 1);
+    EXPECT_EQ(devices[0], reinterpret_cast<ANeuralNetworksDevice*>(1));
+    if (numDevices != 1 ||
+        devices[0] != reinterpret_cast<ANeuralNetworksDevice*>(1)) {
+      return 1;
+    } else {
+      *compilation = reinterpret_cast<ANeuralNetworksCompilation*>(3);
+      return 0;
+    }
+  };
+
+  tflite::StatefulNnApiDelegate::Options options;
+  options.accelerator_name = "dsp";
+  InitWithOptions(options);
+  m.Invoke();
+  EXPECT_EQ(m.GetCompilationStatus(), kTfLiteOk);
+}
+
+TEST_F(NnApiDeviceSelectionTest, DisallowsCPUBasedOnOptions) {
+  nnapi_mock_->CompilationCreateReturns<1>();
+  nnapi_->ANeuralNetworksCompilation_createForDevices =
+      [](ANeuralNetworksModel* model,
+         const ANeuralNetworksDevice* const* devices, uint32_t numDevices,
+         ANeuralNetworksCompilation** compilation) -> int {
+    EXPECT_EQ(numDevices, 2);
+    EXPECT_EQ(devices[0], reinterpret_cast<ANeuralNetworksDevice*>(1));
+    EXPECT_EQ(devices[1], reinterpret_cast<ANeuralNetworksDevice*>(2));
+    if (numDevices != 2 ||
+        devices[0] != reinterpret_cast<ANeuralNetworksDevice*>(1) ||
+        devices[1] != reinterpret_cast<ANeuralNetworksDevice*>(2)) {
+      return 1;
+    } else {
+      *compilation = reinterpret_cast<ANeuralNetworksCompilation*>(3);
+      return 0;
+    }
+  };
+
+  tflite::StatefulNnApiDelegate::Options options;
+  options.disallow_nnapi_cpu = true;
+  InitWithOptions(options);
+  m.Invoke();
+  EXPECT_EQ(m.GetCompilationStatus(), kTfLiteOk);
+}
+
+}  // namespace
+}  // namespace tflite
+
+int main(int argc, char** argv) {
+  ::tflite::LogToStderr();
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate_kernel.h b/tensorflow/lite/delegates/nnapi/nnapi_delegate_kernel.h
index 5390b181583..6a9493f9f4d 100644
--- a/tensorflow/lite/delegates/nnapi/nnapi_delegate_kernel.h
+++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_kernel.h
@@ -285,7 +285,7 @@ class NNAPIDelegateKernel {
   // Access to NNApi.
   const NnApi* nnapi_;
   // ANN device handle.
-  ANeuralNetworksDevice* nnapi_device_ = nullptr;
+  std::vector<ANeuralNetworksDevice*> nnapi_devices_;
   // ANN API state.
   std::unique_ptr nn_model_;
   std::unique_ptr
diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate_mock_test.h b/tensorflow/lite/delegates/nnapi/nnapi_delegate_mock_test.h
index 8551bdea0a8..4a48409de1e 100644
--- a/tensorflow/lite/delegates/nnapi/nnapi_delegate_mock_test.h
+++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_mock_test.h
@@ -28,134 +28,17 @@ limitations under the License.
 #include 
 #include "absl/memory/memory.h"
 #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
+#include "tensorflow/lite/nnapi/nnapi_handler.h"
 #include "tensorflow/lite/nnapi/nnapi_implementation.h"
 
 namespace tflite {
 namespace delegate {
 namespace nnapi {
 
-class NnApiMock {
+class NnApiMock : public ::tflite::nnapi::NnApiHandler {
  public:
-  template 
-  void GetDeviceCountReturns() {
-    nnapi_->ANeuralNetworks_getDeviceCount = [](uint32_t* numDevices) -> int {
-      *numDevices = 2;
-      return Value;
-    };
-  }
-
-  template 
-  void ModelCreateReturns() {
-    nnapi_->ANeuralNetworksModel_create = [](ANeuralNetworksModel** model) {
-      *model = reinterpret_cast(1);
-      return Value;
-    };
-  }
-
-  template 
-  void AddOperandReturns() {
-    nnapi_->ANeuralNetworksModel_addOperand =
-        [](ANeuralNetworksModel* model,
-           const ANeuralNetworksOperandType* type) { return Value; };
-  }
-
-  template 
-  void SetOperandValueReturns() {
-    nnapi_->ANeuralNetworksModel_setOperandValue =
-        [](ANeuralNetworksModel* model, int32_t index, const void* buffer,
-           size_t length) { return Value; };
-  }
-
-  template 
-  void AddOperationReturns() {
-    nnapi_->ANeuralNetworksModel_addOperation =
-        [](ANeuralNetworksModel* model, ANeuralNetworksOperationType type,
-           uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount,
-           const uint32_t* outputs) { return Value; };
-  }
-
-  template 
-  void IdentifyInputAndOutputsReturns() {
-    nnapi_->ANeuralNetworksModel_identifyInputsAndOutputs =
-        [](ANeuralNetworksModel* model, uint32_t inputCount,
-           const uint32_t* inputs, uint32_t outputCount,
-           const uint32_t* outputs) { return Value; };
-  }
-
-  template 
-  void RelaxComputationFloatReturns() {
-    nnapi_->ANeuralNetworksModel_relaxComputationFloat32toFloat16 =
-        [](ANeuralNetworksModel* model, bool allow) { return Value; };
-  }
-
-  template 
-  void ModelFinishReturns() {
-    nnapi_->ANeuralNetworksModel_finish = [](ANeuralNetworksModel* model) {
-      return Value;
-    };
-  }
-
-  template 
-  void MemoryCreateFromFdReturns() {
-    nnapi_->ANeuralNetworksMemory_createFromFd =
-        [](size_t size, int protect, int fd, size_t offset,
-           ANeuralNetworksMemory** memory) {
-          *memory = reinterpret_cast(2);
-          return Value;
-        };
-  }
-
-  template 
-  void CompilationCreateReturns() {
-    nnapi_->ANeuralNetworksCompilation_create =
-        [](ANeuralNetworksModel* model,
-           ANeuralNetworksCompilation** compilation) {
-          *compilation = reinterpret_cast(3);
-          return Value;
-        };
-  }
-
-  template 
-  void CompilationFinishReturns() {
-    nnapi_->ANeuralNetworksCompilation_finish =
-        [](ANeuralNetworksCompilation* compilation) { return Value; };
-  }
-
-  template 
-  void ExecutionCreateReturns() {
-    nnapi_->ANeuralNetworksExecution_create =
-        [](ANeuralNetworksCompilation* compilation,
-           ANeuralNetworksExecution** execution) {
-          if (compilation == nullptr) return 1;
-          *execution = reinterpret_cast(4);
-          return Value;
-        };
-  }
-  template 
-  void ExecutionSetInputFromMemoryReturns() {
-    nnapi_->ANeuralNetworksExecution_setInputFromMemory =
-        [](ANeuralNetworksExecution* execution, int32_t index,
-           const ANeuralNetworksOperandType* type,
-           const ANeuralNetworksMemory* memory, size_t offset,
-           size_t length) { return Value; };
-  }
-  template 
-  void ExecutionSetOutputFromMemoryReturns() {
-    nnapi_->ANeuralNetworksExecution_setOutputFromMemory =
-        [](ANeuralNetworksExecution* execution, int32_t index,
-           const ANeuralNetworksOperandType* type,
-           const ANeuralNetworksMemory* memory, size_t offset,
-           size_t length) { return Value; };
-  }
-
-  template 
-  void ExecutionComputeReturns() {
-    nnapi_->ANeuralNetworksExecution_compute =
-        [](ANeuralNetworksExecution* execution) { return Value; };
-  }
-
   explicit NnApiMock(NnApi* nnapi, int android_sdk_version = 29)
-      : nnapi_(nnapi), prev_nnapi_(*nnapi) {
+      : ::tflite::nnapi::NnApiHandler(nnapi) {
     nnapi_->nnapi_exists = true;
     nnapi_->android_sdk_version = android_sdk_version;
 
@@ -186,23 +69,16 @@ class NnApiMock {
     ExecutionComputeReturns<0>();
   }
 
-  ~NnApiMock() {
-    // Restores global NNAPI to original value for non mocked tests
-    *nnapi_ = prev_nnapi_;
-  }
-
- private:
-  NnApi* nnapi_;
-  NnApi prev_nnapi_;
+  ~NnApiMock() { Reset(); }
 };
 
 class NnApiDelegateMockTest : public ::testing::Test {
+ protected:
   void SetUp() override {
     nnapi_ = const_cast<NnApi*>(NnApiImplementation());
     nnapi_mock_ = absl::make_unique<NnApiMock>(nnapi_);
   }
 
- protected:
   NnApi* nnapi_;
   std::unique_ptr<NnApiMock> nnapi_mock_;
 };
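
NnApiMock now inherits the canned stubs from //tensorflow/lite/nnapi:nnapi_handler,
and its destructor's Reset() restores the real NnApi table for non-mocked tests. Test
fixtures can still override individual NNAPI entry points on top of
NnApiDelegateMockTest, as the device-selection test above does; a hedged sketch with
an illustrative fixture name:

```c++
class MyNnApiTest : public ::tflite::delegate::nnapi::NnApiDelegateMockTest {
 protected:
  void SetUp() override {
    ::tflite::delegate::nnapi::NnApiDelegateMockTest::SetUp();
    // Override a single NNAPI entry point for this test only; everything else
    // keeps the defaults installed by NnApiMock.
    nnapi_->ANeuralNetworks_getDeviceCount = [](uint32_t* num_devices) -> int {
      *num_devices = 1;
      return 0;
    };
  }
};
```
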
diff --git a/tensorflow/lite/examples/label_image/label_image.cc b/tensorflow/lite/examples/label_image/label_image.cc
index 452fa0c9682..a3d07a66a02 100644
--- a/tensorflow/lite/examples/label_image/label_image.cc
+++ b/tensorflow/lite/examples/label_image/label_image.cc
@@ -60,7 +60,9 @@ TfLiteDelegatePtr CreateGPUDelegate(Settings* s) {
   TfLiteGpuDelegateOptionsV2 gpu_opts = TfLiteGpuDelegateOptionsV2Default();
   gpu_opts.inference_preference =
       TFLITE_GPU_INFERENCE_PREFERENCE_SUSTAINED_SPEED;
-  gpu_opts.is_precision_loss_allowed = s->allow_fp16 ? 1 : 0;
+  gpu_opts.inference_priority1 =
+      s->allow_fp16 ? TFLITE_GPU_INFERENCE_PRIORITY_MIN_LATENCY
+                    : TFLITE_GPU_INFERENCE_PRIORITY_MAX_PRECISION;
   return evaluation::CreateGPUDelegate(s->model, &gpu_opts);
 #else
   return evaluation::CreateGPUDelegate(s->model);
diff --git a/tensorflow/lite/experimental/c/BUILD b/tensorflow/lite/experimental/c/BUILD
deleted file mode 100644
index 8e6b4803155..00000000000
--- a/tensorflow/lite/experimental/c/BUILD
+++ /dev/null
@@ -1,120 +0,0 @@
-load(
-    "//tensorflow/lite:build_def.bzl",
-    "tflite_cc_shared_object",
-    "tflite_copts",
-)
-
-package(
-    default_visibility = [":experimental"],
-    licenses = ["notice"],  # Apache 2.0
-)
-
-package_group(
-    name = "experimental",
-    packages = [
-        "//tensorflow/lite/experimental/...",
-        "//third_party/dart/tflite_native/...",  # whitelisted
-    ],
-)
-
-tflite_cc_shared_object(
-    name = "libtensorflowlite_c.so",
-    linkopts = select({
-        "//tensorflow:macos": [
-            "-Wl,-exported_symbols_list,$(location //tensorflow/lite/experimental/c:exported_symbols.lds)",
-            "-Wl,-install_name,@rpath/libtensorflowlite_c.so",
-        ],
-        "//tensorflow:windows": [],
-        "//conditions:default": [
-            "-z defs",
-            "-Wl,--version-script,$(location //tensorflow/lite/experimental/c:version_script.lds)",
-        ],
-    }),
-    deps = [
-        ":c_api",
-        ":c_api_experimental",
-        ":exported_symbols.lds",
-        ":version_script.lds",
-    ],
-)
-
-cc_library(
-    name = "c_api_internal",
-    srcs = [
-        "c_api.h",
-        "c_api_types.h",
-    ],
-    hdrs = ["c_api_internal.h"],
-    copts = tflite_copts(),
-    visibility = [
-        "//tensorflow/lite/experimental/c:__subpackages__",
-    ],
-    deps = [
-        "//tensorflow/lite:framework",
-        "//tensorflow/lite/c:common",
-    ],
-)
-
-cc_library(
-    name = "c_api",
-    srcs = ["c_api.cc"],
-    hdrs = [
-        "c_api.h",
-        "c_api_types.h",
-    ],
-    copts = tflite_copts(),
-    visibility = [
-        ":experimental",
-    ],
-    deps = [
-        ":c_api_internal",
-        "//tensorflow/lite:framework",
-        "//tensorflow/lite:version",
-        "//tensorflow/lite/c:common",
-        "//tensorflow/lite/kernels:builtin_ops",
-    ],
-    alwayslink = 1,
-)
-
-cc_library(
-    name = "c_api_experimental",
-    srcs = ["c_api_experimental.cc"],
-    hdrs = ["c_api_experimental.h"],
-    copts = tflite_copts(),
-    deps = [
-        ":c_api",
-        ":c_api_internal",
-        "//tensorflow/lite:kernel_api",
-    ],
-    alwayslink = 1,
-)
-
-cc_test(
-    name = "c_api_test",
-    size = "small",
-    srcs = ["c_api_test.cc"],
-    data = [
-        "//tensorflow/lite:testdata/add.bin",
-        "//tensorflow/lite:testdata/add_quantized.bin",
-    ],
-    deps = [
-        ":c_api",
-        "//tensorflow/lite/c:common",
-        "//tensorflow/lite/testing:util",
-        "@com_google_googletest//:gtest",
-    ],
-)
-
-cc_test(
-    name = "c_api_experimental_test",
-    size = "small",
-    srcs = ["c_api_experimental_test.cc"],
-    data = ["//tensorflow/lite:testdata/add.bin"],
-    deps = [
-        ":c_api",
-        ":c_api_experimental",
-        "//tensorflow/lite:kernel_api",
-        "//tensorflow/lite/testing:util",
-        "@com_google_googletest//:gtest",
-    ],
-)
diff --git a/tensorflow/lite/experimental/c/README.md b/tensorflow/lite/experimental/c/README.md
new file mode 100644
index 00000000000..a17f7f8f2c7
--- /dev/null
+++ b/tensorflow/lite/experimental/c/README.md
@@ -0,0 +1 @@
+The C API has been migrated to [lite/c](../../c/README.md).
diff --git a/tensorflow/lite/experimental/c/c_api_types.h b/tensorflow/lite/experimental/c/c_api_types.h
deleted file mode 100644
index b3b0ddc059d..00000000000
--- a/tensorflow/lite/experimental/c/c_api_types.h
+++ /dev/null
@@ -1,673 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// This file defines common C types and APIs for implementing operations,
-// delegates and other constructs in TensorFlow Lite. The actual operations and
-// delegtes can be defined using C++, but the interface between the interpreter
-// and the operations are C.
-//
-// Summary of abstractions
-// TF_LITE_ENSURE - Self-sufficient error checking
-// TfLiteStatus - Status reporting
-// TfLiteIntArray - stores tensor shapes (dims),
-// TfLiteContext - allows an op to access the tensors
-// TfLiteTensor - tensor (a multidimensional array)
-// TfLiteNode - a single node or operation
-// TfLiteRegistration - the implementation of a conceptual operation.
-// TfLiteDelegate - allows delegation of nodes to alternative backends.
-//
-// Some abstractions in this file are created and managed by Interpreter.
-
-#ifndef TENSORFLOW_LITE_C_COMMON_H_
-#define TENSORFLOW_LITE_C_COMMON_H_
-
-#include <stdbool.h>
-#include <stddef.h>
-#include <stdint.h>
-
-#ifdef __cplusplus
-extern "C" {
-#endif  // __cplusplus
-
-typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus;
-
-// The list of external context types known to TF Lite. This list exists solely
-// to avoid conflicts and to ensure ops can share the external contexts they
-// need. Access to the external contexts is controled by one of the
-// corresponding support files.
-typedef enum {
-  kTfLiteEigenContext = 0,       // include eigen_support.h to use.
-  kTfLiteGemmLowpContext = 1,    // include gemm_support.h to use.
-  kTfLiteEdgeTpuContext = 2,     // Placeholder for Edge TPU support.
-  kTfLiteCpuBackendContext = 3,  // include cpu_backend_support.h to use.
-  kTfLiteMaxExternalContexts = 4
-} TfLiteExternalContextType;
-
-// Forward declare so dependent structs and methods can reference these types
-// prior to the struct definitions.
-struct TfLiteContext;
-struct TfLiteDelegate;
-struct TfLiteRegistration;
-
-// An external context is a collection of information unrelated to the TF Lite
-// framework, but useful to a subset of the ops. TF Lite knows very little
-// about about the actual contexts, but it keeps a list of them, and is able to
-// refresh them if configurations like the number of recommended threads
-// change.
-typedef struct {
-  TfLiteExternalContextType type;
-  TfLiteStatus (*Refresh)(struct TfLiteContext* context);
-} TfLiteExternalContext;
-
-#define kTfLiteOptionalTensor (-1)
-
-// Fixed size list of integers. Used for dimensions and inputs/outputs tensor
-// indices
-typedef struct {
-  int size;
-// gcc 6.1+ have a bug where flexible members aren't properly handled
-// https://github.com/google/re2/commit/b94b7cd42e9f02673cd748c1ac1d16db4052514c
-#if !defined(__clang__) && defined(__GNUC__) && __GNUC__ == 6 && \
-    __GNUC_MINOR__ >= 1
-  int data[0];
-#else
-  int data[];
-#endif
-} TfLiteIntArray;
-
-// Given the size (number of elements) in a TfLiteIntArray, calculate its size
-// in bytes.
-int TfLiteIntArrayGetSizeInBytes(int size);
-
-// Create a array of a given `size` (uninitialized entries).
-// This returns a pointer, that you must free using TfLiteIntArrayFree().
-TfLiteIntArray* TfLiteIntArrayCreate(int size);
-
-// Check if two intarrays are equal. Returns 1 if they are equal, 0 otherwise.
-int TfLiteIntArrayEqual(const TfLiteIntArray* a, const TfLiteIntArray* b);
-
-// Check if an intarray equals an array. Returns 1 if equals, 0 otherwise.
-int TfLiteIntArrayEqualsArray(const TfLiteIntArray* a, int b_size,
-                              const int b_data[]);
-
-// Create a copy of an array passed as `src`.
-// You are expected to free memory with TfLiteIntArrayFree
-TfLiteIntArray* TfLiteIntArrayCopy(const TfLiteIntArray* src);
-
-// Free memory of array `a`.
-void TfLiteIntArrayFree(TfLiteIntArray* a);
-
-// Fixed size list of floats. Used for per-channel quantization.
-typedef struct {
-  int size;
-// gcc 6.1+ have a bug where flexible members aren't properly handled
-// https://github.com/google/re2/commit/b94b7cd42e9f02673cd748c1ac1d16db4052514c
-#if !defined(__clang__) && defined(__GNUC__) && __GNUC__ == 6 && \
-    __GNUC_MINOR__ >= 1
-  float data[0];
-#else
-  float data[];
-#endif
-} TfLiteFloatArray;
-
-// Given the size (number of elements) in a TfLiteFloatArray, calculate its size
-// in bytes.
-int TfLiteFloatArrayGetSizeInBytes(int size);
-
-// Create a array of a given `size` (uninitialized entries).
-// This returns a pointer, that you must free using TfLiteFloatArrayFree().
-TfLiteFloatArray* TfLiteFloatArrayCreate(int size);
-
-// Free memory of array `a`.
-void TfLiteFloatArrayFree(TfLiteFloatArray* a);
-
-// Since we must not depend on any libraries, define a minimal subset of
-// error macros while avoiding names that have pre-conceived meanings like
-// assert and check.
-
-// Check whether value is true, and if not return kTfLiteError from
-// the current function (and report the error string msg).
-#define TF_LITE_ENSURE_MSG(context, value, msg)            \
-  do {                                                     \
-    if (!(value)) {                                        \
-      (context)->ReportError((context), __FILE__ " " msg); \
-      return kTfLiteError;                                 \
-    }                                                      \
-  } while (0)
-
-// Check whether the value `a` is true, and if not return kTfLiteError from
-// the current function, while also reporting the location of the error.
-#define TF_LITE_ENSURE(context, a)                                          \
-  do {                                                                      \
-    if (!(a)) {                                                             \
-      (context)->ReportError((context), "%s:%d %s was not true.", __FILE__, \
-                             __LINE__, #a);                                 \
-      return kTfLiteError;                                                  \
-    }                                                                       \
-  } while (0)
-
-#define TF_LITE_ENSURE_STATUS(a) \
-  do {                           \
-    if ((a) != kTfLiteOk) {      \
-      return kTfLiteError;       \
-    }                            \
-  } while (0)
-
-// Check whether the value `a == b` is true, and if not return kTfLiteError from
-// the current function, while also reporting the location of the error.
-// `a` and `b` may be evaluated more than once, so no side effects or
-// extremely expensive computations should be done.
-#define TF_LITE_ENSURE_EQ(context, a, b)                                       \
-  do {                                                                         \
-    if ((a) != (b)) {                                                          \
-      (context)->ReportError((context), "%s:%d %s != %s (%d != %d)", __FILE__, \
-                             __LINE__, #a, #b, (a), (b));                      \
-      return kTfLiteError;                                                     \
-    }                                                                          \
-  } while (0)
-
-#define TF_LITE_ENSURE_TYPES_EQ(context, a, b)                                 \
-  do {                                                                         \
-    if ((a) != (b)) {                                                          \
-      (context)->ReportError((context), "%s:%d %s != %s (%s != %s)", __FILE__, \
-                             __LINE__, #a, #b, TfLiteTypeGetName(a),           \
-                             TfLiteTypeGetName(b));                            \
-      return kTfLiteError;                                                     \
-    }                                                                          \
-  } while (0)
-
-#define TF_LITE_ENSURE_OK(context, status) \
-  do {                                     \
-    if ((status) != kTfLiteOk) {           \
-      return kTfLiteError;                 \
-    }                                      \
-  } while (0)
-
-// Single-precision complex data type compatible with the C99 definition.
-typedef struct {
-  float re, im;  // real and imaginary parts, respectively.
-} TfLiteComplex64;
-
-// Half precision data type compatible with the C99 definition.
-typedef struct {
-  uint16_t data;
-} TfLiteFloat16;
-
-// Types supported by tensor
-typedef enum {
-  kTfLiteNoType = 0,
-  kTfLiteFloat32 = 1,
-  kTfLiteInt32 = 2,
-  kTfLiteUInt8 = 3,
-  kTfLiteInt64 = 4,
-  kTfLiteString = 5,
-  kTfLiteBool = 6,
-  kTfLiteInt16 = 7,
-  kTfLiteComplex64 = 8,
-  kTfLiteInt8 = 9,
-  kTfLiteFloat16 = 10,
-} TfLiteType;
-
-// Return the name of a given type, for error reporting purposes.
-const char* TfLiteTypeGetName(TfLiteType type);
-
-// SupportedQuantizationTypes.
-typedef enum {
-  // No quantization.
-  kTfLiteNoQuantization = 0,
-  // Affine quantization (with support for per-channel quantization).
-  // Corresponds to TfLiteAffineQuantization.
-  kTfLiteAffineQuantization = 1,
-} TfLiteQuantizationType;
-
-// Structure specifying the quantization used by the tensor, if-any.
-typedef struct {
-  // The type of quantization held by params.
-  TfLiteQuantizationType type;
-  // Holds a reference to one of the quantization param structures specified
-  // below.
-  void* params;
-} TfLiteQuantization;
-
-// Legacy. Will be deprecated in favor of TfLiteAffineQuantization.
-// If per-layer quantization is specified this field will still be populated in
-// addition to TfLiteAffineQuantization.
-// Parameters for asymmetric quantization. Quantized values can be converted
-// back to float using:
-//     real_value = scale * (quantized_value - zero_point)
-typedef struct {
-  float scale;
-  int32_t zero_point;
-} TfLiteQuantizationParams;
-
-// Parameters for asymmetric quantization across a dimension (i.e per output
-// channel quantization).
-// quantized_dimension specifies which dimension the scales and zero_points
-// correspond to.
-// For a particular value in quantized_dimension, quantized values can be
-// converted back to float using:
-//     real_value = scale * (quantized_value - zero_point)
-typedef struct {
-  TfLiteFloatArray* scale;
-  TfLiteIntArray* zero_point;
-  int32_t quantized_dimension;
-} TfLiteAffineQuantization;
-
-/* A union of pointers that points to memory for a given tensor. */
-typedef union {
-  /* Do not access these members directly, if possible, use
-   * GetTensorData(tensor) instead, otherwise only access .data, as other
-   * members are deprecated. */
-  int32_t* i32;
-  int64_t* i64;
-  float* f;
-  TfLiteFloat16* f16;
-  char* raw;
-  const char* raw_const;
-  uint8_t* uint8;
-  bool* b;
-  int16_t* i16;
-  TfLiteComplex64* c64;
-  int8_t* int8;
-  /* Only use this member. */
-  void* data;
-} TfLitePtrUnion;
-
-// Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped
-// data (or data externally allocated). kTfLiteArenaRw is arena allocated
-// data. kTfLiteDynamic is for tensors that are allocated during evaluation.
-typedef enum {
-  kTfLiteMemNone = 0,
-  kTfLiteMmapRo,
-  kTfLiteArenaRw,
-  kTfLiteArenaRwPersistent,
-  kTfLiteDynamic,
-} TfLiteAllocationType;
-
-// The delegates should use zero or positive integers to represent handles.
-// -1 is reserved from unallocated status.
-typedef int TfLiteBufferHandle;
-enum {
-  kTfLiteNullBufferHandle = -1,
-};
-
-// An tensor in the interpreter system which is a wrapper around a buffer of
-// data including a dimensionality (or NULL if not currently defined).
-typedef struct {
-  // The data type specification for data stored in `data`. This affects
-  // what member of `data` union should be used.
-  TfLiteType type;
-  // A union of data pointers. The appropriate type should be used for a typed
-  // tensor based on `type`.
-  TfLitePtrUnion data;
-  // A pointer to a structure representing the dimensionality interpretation
-  // that the buffer should have. NOTE: the product of elements of `dims`
-  // and the element datatype size should be equal to `bytes` below.
-  TfLiteIntArray* dims;
-  // Quantization information.
-  TfLiteQuantizationParams params;
-  // How memory is mapped
-  //  kTfLiteMmapRo: Memory mapped read only.
-  //  i.e. weights
-  //  kTfLiteArenaRw: Arena allocated read write memory
-  //  (i.e. temporaries, outputs).
-  TfLiteAllocationType allocation_type;
-  // The number of bytes required to store the data of this Tensor. I.e.
-  // (bytes of each element) * dims[0] * ... * dims[n-1].  For example, if
-  // type is kTfLiteFloat32 and dims = {3, 2} then
-  // bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24.
-  size_t bytes;
-
-  // An opaque pointer to a tflite::MMapAllocation
-  const void* allocation;
-
-  // Null-terminated name of this tensor.
-  const char* name;
-
-  // The delegate which knows how to handle `buffer_handle`.
-  // WARNING: This is an experimental interface that is subject to change.
-  struct TfLiteDelegate* delegate;
-
-  // An integer buffer handle that can be handled by `delegate`.
-  // The value is valid only when delegate is not null.
-  // WARNING: This is an experimental interface that is subject to change.
-  TfLiteBufferHandle buffer_handle;
-
-  // If the delegate uses its own buffer (e.g. GPU memory), the delegate is
-  // responsible to set data_is_stale to true.
-  // `delegate->CopyFromBufferHandle` can be called to copy the data from
-  // delegate buffer.
-  // WARNING: This is an // experimental interface that is subject to change.
-  bool data_is_stale;
-
-  // True if the tensor is a variable.
-  bool is_variable;
-
-  // Quantization information. Replaces params field above.
-  TfLiteQuantization quantization;
-} TfLiteTensor;
-
-// Free data memory of tensor `t`.
-void TfLiteTensorDataFree(TfLiteTensor* t);
-
-// Free quantization data.
-void TfLiteQuantizationFree(TfLiteQuantization* quantization);
-
-// Free memory of tensor `t`.
-void TfLiteTensorFree(TfLiteTensor* t);
-
-// Set all of a tensor's fields (and free any previously allocated data).
-void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
-                       TfLiteQuantizationParams quantization, char* buffer,
-                       size_t size, TfLiteAllocationType allocation_type,
-                       const void* allocation, bool is_variable,
-                       TfLiteTensor* tensor);
-
-// Resize the allocated data of a (dynamic) tensor. Tensors with allocation
-// types other than kTfLiteDynamic will be ignored.
-void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor);
-
-// A structure representing an instance of a node.
-// This structure only exhibits the inputs, outputs and user defined data, not
-// other features like the type.
-typedef struct {
-  // Inputs to this node expressed as indices into the simulator's tensors.
-  TfLiteIntArray* inputs;
-
-  // Outputs to this node expressed as indices into the simulator's tensors.
-  TfLiteIntArray* outputs;
-
-  // intermediate tensors to this node expressed as indices into the simulator's
-  // tensors.
-  TfLiteIntArray* intermediates;
-
-  // Temporary tensors uses during the computations. This usually contains no
-  // tensors, but ops are allowed to change that if they need scratch space of
-  // any sort.
-  TfLiteIntArray* temporaries;
-
-  // Opaque data provided by the node implementer through `Registration.init`.
-  void* user_data;
-
-  // Opaque data provided to the node if the node is a builtin. This is usually
-  // a structure defined in builtin_op_data.h
-  void* builtin_data;
-
-  // Custom initial data. This is the opaque data provided in the flatbuffer.
-  // WARNING: This is an experimental interface that is subject to change.
-  const void* custom_initial_data;
-  int custom_initial_data_size;
-
-  // The pointer to the delegate. This is non-null only when the node is
-  // created by calling `interpreter.ModifyGraphWithDelegate`.
-  // WARNING: This is an experimental interface that is subject to change.
-  struct TfLiteDelegate* delegate;
-} TfLiteNode;
-
-typedef struct TfLiteContext {
-  // Number of tensors in the context.
-  size_t tensors_size;
-
-  // The execution plan contains a list of the node indices in execution
-  // order. execution_plan->size is the current number of nodes. And,
-  // execution_plan->data[0] is the first node that needs to be run.
-  // TfLiteDelegates can traverse the current execution plan by iterating
-  // through each member of this array and using GetNodeAndRegistration() to
-  // access details about a node. i.e.
-  // TfLiteIntArray* execution_plan;
-  // TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &execution_plan));
-  // for (int exec_index = 0; exec_index < execution_plan->size; exec_index++) {
-  //    int node_index = execution_plan->data[exec_index];
-  //    TfLiteNode* node;
-  //    TfLiteRegistration* reg;
-  //    context->GetNodeAndRegistration(context, node_index, &node, &reg);
-  // }
-  // WARNING: This is an experimental interface that is subject to change.
-  TfLiteStatus (*GetExecutionPlan)(struct TfLiteContext* context,
-                                   TfLiteIntArray** execution_plan);
-
-  // An array of tensors in the interpreter context (of length `tensors_size`)
-  TfLiteTensor* tensors;
-
-  // opaque full context ptr (an opaque c++ data structure)
-  void* impl_;
-
-  // Request memory pointer be resized. Updates dimensions on the tensor.
-  // NOTE: ResizeTensor takes ownership of newSize.
-  TfLiteStatus (*ResizeTensor)(struct TfLiteContext*, TfLiteTensor* tensor,
-                               TfLiteIntArray* new_size);
-  // Request that an error be reported with format string msg.
-  void (*ReportError)(struct TfLiteContext*, const char* msg, ...);
-
-  // Add `tensors_to_add` tensors, preserving pre-existing Tensor entries.  If
-  // non-null, the value pointed to by `first_new_tensor_index` will be set to
-  // the index of the first new tensor.
-  TfLiteStatus (*AddTensors)(struct TfLiteContext*, int tensors_to_add,
-                             int* first_new_tensor_index);
-
-  // Get a Tensor node by node_index.
-  // WARNING: This is an experimental interface that is subject to change.
-  TfLiteStatus (*GetNodeAndRegistration)(
-      struct TfLiteContext*, int node_index, TfLiteNode** node,
-      struct TfLiteRegistration** registration);
-
-  // Replace ops with one or more stub delegate operations. This function
-  // does not take ownership of `nodes_to_replace`.
-  TfLiteStatus (*ReplaceNodeSubsetsWithDelegateKernels)(
-      struct TfLiteContext*, struct TfLiteRegistration registration,
-      const TfLiteIntArray* nodes_to_replace, struct TfLiteDelegate* delegate);
-
-  // Number of threads that are recommended to subsystems like gemmlowp and
-  // eigen.
-  int recommended_num_threads;
-
-  // Access external contexts by type.
-  // WARNING: This is an experimental interface that is subject to change.
-  TfLiteExternalContext* (*GetExternalContext)(struct TfLiteContext*,
-                                               TfLiteExternalContextType);
-  // Set the value of a external context. Does not take ownership of the
-  // pointer.
-  // WARNING: This is an experimental interface that is subject to change.
-  void (*SetExternalContext)(struct TfLiteContext*, TfLiteExternalContextType,
-                             TfLiteExternalContext*);
-
-  // Flag for allowing float16 precision for FP32 calculation.
-  // default: false.
-  // WARNING: This is an experimental API and subject to change.
-  bool allow_fp32_relax_to_fp16;
-
-  // Pointer to the op-level profiler, if set; nullptr otherwise.
-  void* profiler;
-
-  // Allocate memory for op data. This method should only be used in `Init`
-  // method and the allocated memory will be available until `Free` method is
-  // called.
-  // On TFL, it allocates memory from heap using malloc, but for micro, this
-  // will be allocating from the allocator.
-  // WARNING: This is an experimental interface that is subject to change.
-  void* (*AllocateOpData)(struct TfLiteContext* ctx, size_t size);
-
-  // Deallocate memory holding op data. This method should only be used inside
-  // `Free` method. Caller needs to make sure that that `buffer` is allocated by
-  // `AllocateOpData` method.
-  // On TFL, it will free the buffer, and for micro, this method is a no-op.
-  // WARNING: This is an experimental interface that is subject to change.
-  void (*DeallocateOpData)(struct TfLiteContext* ctx, void* buffer);
-
-  // Allocate a temporary tensor to the node. This method also makes a copy of
-  // the shape array internally so the shape array could be deallocated right
-  // afterwards. WARNING: This is an experimental interface that is subject to
-  // change.
-  TfLiteStatus (*AllocateTemporaryTensor)(struct TfLiteContext* ctx,
-                                          TfLiteNode* node, int dims,
-                                          int* shape, TfLiteType data_type,
-                                          TfLiteAllocationType allocation_type,
-                                          int* new_tensor_index);
-
-  // Deallocate all temporary tensors associated to the node (including
-  // kTfLiteArenaRwPersistent persistent tensors). It also deallocates
-  // all the shape tensors.
-  // WARNING: This is an experimental interface that is subject to change.
-  void (*DeallocateAllTemporaryTensors)(struct TfLiteContext* ctx,
-                                        TfLiteNode* node);
-
-  // Resize the memory pointer of the `tensor`. This method behaves the same as
-  // `ResizeTensor`, except that it makes a copy of the shape array internally
-  // so the shape array could be deallocated right afterwards.
-  // WARNING: This is an experimental interface that is subject to change.
-  TfLiteStatus (*ResizeTensorExplicit)(struct TfLiteContext* ctx,
-                                       TfLiteTensor* tensor, int dims,
-                                       const int* shape);
-} TfLiteContext;
-
-typedef struct TfLiteRegistration {
-  // Initializes the op from serialized data.
-  // If a built-in op:
-  //   `buffer` is the op's params data (TfLiteLSTMParams*).
-  //   `length` is zero.
-  // If custom op:
-  //   `buffer` is the op's `custom_options`.
-  //   `length` is the size of the buffer.
-  //
-  // Returns a type-punned (i.e. void*) opaque data (e.g. a primitive pointer
-  // or an instance of a struct).
-  //
-  // The returned pointer will be stored with the node in the `user_data` field,
-  // accessible within prepare and invoke functions below.
-  // NOTE: if the data is already in the desired format, simply implement this
-  // function to return `nullptr` and implement the free function to be a no-op.
-  void* (*init)(TfLiteContext* context, const char* buffer, size_t length);
-
-  // The pointer `buffer` is the data previously returned by an init invocation.
-  void (*free)(TfLiteContext* context, void* buffer);
-
-  // prepare is called when the inputs this node depends on have been resized.
-  // context->ResizeTensor() can be called to request output tensors to be
-  // resized.
-  //
-  // Returns kTfLiteOk on success.
-  TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node);
-
-  // Execute the node (should read node->inputs and output to node->outputs).
-  // Returns kTfLiteOk on success.
-  TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node);
-
-  // profiling_string is called during summarization of profiling information
-  // in order to group executions together. Providing a value here will cause a
-  // given op to appear multiple times is the profiling report. This is
-  // particularly useful for custom ops that can perform significantly
-  // different calculations depending on their `user-data`.
-  const char* (*profiling_string)(const TfLiteContext* context,
-                                  const TfLiteNode* node);
-
-  // Builtin codes. If this kernel refers to a builtin this is the code
-  // of the builtin. This is so we can do marshaling to other frameworks like
-  // NN API.
-  // Note: It is the responsibility of the registration binder to set this
-  // properly.
-  int32_t builtin_code;
-
-  // Custom op name. If the op is a builtin, this will be null.
-  // Note: It is the responsibility of the registration binder to set this
-  // properly.
-  // WARNING: This is an experimental interface that is subject to change.
-  const char* custom_name;
-
-  // The version of the op.
-  // Note: It is the responsibility of the registration binder to set this
-  // properly.
-  int version;
-} TfLiteRegistration;
-
-// The flags used in `TfLiteDelegate`. Note that this is a bitmask, so the
-// values should be 1, 2, 4, 8, ...etc.
-typedef enum {
-  kTfLiteDelegateFlagsNone = 0,
-  // The flag is set if the delegate can handle dynamic sized tensors.
-  // For example, the output shape of a `Resize` op with non-constant shape
-  // can only be inferred when the op is invoked.
-  // In this case, the Delegate is responsible for calling
-  // `SetTensorToDynamic` to mark the tensor as a dynamic tensor, and calling
-  // `ResizeTensor` when invoking the op.
-  //
-  // If the delegate isn't capable to handle dynamic tensors, this flag need
-  // to be set to false.
-  kTfLiteDelegateFlagsAllowDynamicTensors = 1
-} TfLiteDelegateFlags;
-
-// WARNING: This is an experimental interface that is subject to change.
-typedef struct TfLiteDelegate {
-  // Data that delegate needs to identify itself. This data is owned by the
-  // delegate. The delegate is owned in the user code, so the delegate is
-  // responsible for doing this when it is destroyed.
-  void* data_;
-
-  // Invoked by ModifyGraphWithDelegate. This prepare is called, giving the
-  // delegate a view of the current graph through TfLiteContext*. It typically
-  // will look at the nodes and call ReplaceNodeSubsetsWithDelegateKernels()
-  // to ask the TensorFlow lite runtime to create macro-nodes to represent
-  // delegated subgraphs of the original graph.
-  TfLiteStatus (*Prepare)(TfLiteContext* context,
-                          struct TfLiteDelegate* delegate);
-
-  // Copy the data from delegate buffer handle into raw memory of the given
-  // 'tensor'. This cannot be null. The delegate is allowed to allocate the raw
-  // bytes as long as it follows the rules for kTfLiteDynamic tensors.
-  TfLiteStatus (*CopyFromBufferHandle)(TfLiteContext* context,
-                                       struct TfLiteDelegate* delegate,
-                                       TfLiteBufferHandle buffer_handle,
-                                       TfLiteTensor* tensor);
-
-  // Copy the data from raw memory of the given 'tensor' to delegate buffer
-  // handle. This can be null if the delegate doesn't use its own buffer.
-  TfLiteStatus (*CopyToBufferHandle)(TfLiteContext* context,
-                                     struct TfLiteDelegate* delegate,
-                                     TfLiteBufferHandle buffer_handle,
-                                     TfLiteTensor* tensor);
-
-  // Free the Delegate Buffer Handle. Note: This only frees the handle, but
-  // this doesn't release the underlying resource (e.g. textures). The
-  // resources are either owned by application layer or the delegate.
-  // This can be null if the delegate doesn't use its own buffer.
-  void (*FreeBufferHandle)(TfLiteContext* context,
-                           struct TfLiteDelegate* delegate,
-                           TfLiteBufferHandle* handle);
-
-  // Bitmask flags. See the comments in `TfLiteDelegateFlags`.
-  int64_t flags;
-} TfLiteDelegate;
-
-// Build a 'null' delegate, with all the fields properly set to their default
-// values.
-TfLiteDelegate TfLiteDelegateCreate();
-
-// WARNING: This is an experimental interface that is subject to change.
-//
-// Currently, TfLiteDelegateParams has to be allocated in a way that it's
-// trivially destructable. It will be stored as `builtin_data` field in
-// `TfLiteNode` of the delegate node.
-//
-// See also the `CreateDelegateParams` function in `interpreter.cc` details.
-typedef struct {
-  TfLiteDelegate* delegate;
-  TfLiteIntArray* nodes_to_replace;
-  TfLiteIntArray* input_tensors;
-  TfLiteIntArray* output_tensors;
-} TfLiteDelegateParams;
-
-#ifdef __cplusplus
-}  // extern "C"
-#endif  // __cplusplus
-#endif  // TENSORFLOW_LITE_C_COMMON_H_
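The structs removed here form the C kernel ABI; the rest of this patch points dependents at the same declarations under `tensorflow/lite/c`. For orientation, a hypothetical custom op wires its callbacks into a `TfLiteRegistration` roughly as in the sketch below. The `MyOp*` names and `Register_MY_OP` are made up for illustration; only the `TfLite*` types, `ResizeTensor`, and `TfLiteIntArrayCopy` come from the header.

```c++
// Hypothetical custom op ("MY_OP") showing how the TfLiteRegistration
// callbacks use TfLiteContext and TfLiteNode from this header.
#include "tensorflow/lite/c/common.h"

namespace {

void* MyOpInit(TfLiteContext* context, const char* buffer, size_t length) {
  // For a custom op, `buffer` carries the flatbuffer custom_options blob.
  // Returning nullptr is fine when no per-node state is needed.
  return nullptr;
}

void MyOpFree(TfLiteContext* context, void* buffer) {}

TfLiteStatus MyOpPrepare(TfLiteContext* context, TfLiteNode* node) {
  // node->inputs / node->outputs hold indices into context->tensors.
  const TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
  TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
  // Ask the runtime for an output buffer shaped like the input.
  return context->ResizeTensor(context, output, TfLiteIntArrayCopy(input->dims));
}

TfLiteStatus MyOpInvoke(TfLiteContext* context, TfLiteNode* node) {
  // Read the input tensors, write the output tensors, then report success.
  return kTfLiteOk;
}

}  // namespace

TfLiteRegistration* Register_MY_OP() {
  // Initializer order matches the struct: init, free, prepare, invoke.
  static TfLiteRegistration r = {MyOpInit, MyOpFree, MyOpPrepare, MyOpInvoke};
  return &r;
}
```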
diff --git a/tensorflow/lite/experimental/examples/lstm/BUILD b/tensorflow/lite/experimental/examples/lstm/BUILD
index 2531889dafb..719e59c6a8c 100644
--- a/tensorflow/lite/experimental/examples/lstm/BUILD
+++ b/tensorflow/lite/experimental/examples/lstm/BUILD
@@ -35,7 +35,7 @@ py_library(
 
 py_test(
     name = "unidirectional_sequence_lstm_test",
-    size = "large",
+    size = "medium",
     srcs = ["unidirectional_sequence_lstm_test.py"],
     python_version = "PY3",
     srcs_version = "PY2AND3",
@@ -58,7 +58,7 @@ py_test(
 
 py_test(
     name = "unidirectional_sequence_rnn_test",
-    size = "large",
+    size = "medium",
     srcs = ["unidirectional_sequence_rnn_test.py"],
     python_version = "PY3",
     srcs_version = "PY2AND3",
@@ -81,7 +81,7 @@ py_test(
 
 py_test(
     name = "bidirectional_sequence_lstm_test",
-    size = "large",
+    size = "medium",
     srcs = ["bidirectional_sequence_lstm_test.py"],
     python_version = "PY3",
     srcs_version = "PY2AND3",
@@ -104,13 +104,14 @@ py_test(
 
 py_test(
     name = "bidirectional_sequence_rnn_test",
-    size = "large",
+    size = "medium",
     srcs = ["bidirectional_sequence_rnn_test.py"],
     python_version = "PY3",
     srcs_version = "PY2AND3",
     tags = [
         "no_oss",
         "no_pip",
+        "notap",  # b/141373014
     ],
     deps = [
         ":rnn",
diff --git a/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_lstm_test.py b/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_lstm_test.py
index f04a265714f..d4b5e2b663a 100644
--- a/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_lstm_test.py
+++ b/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_lstm_test.py
@@ -27,7 +27,9 @@ from tensorflow.python.framework import test_util
 from tensorflow.python.platform import test
 
 # Number of steps to train model.
-TRAIN_STEPS = 1
+# Setting this to 0 disables training entirely: all the weights simply keep
+# their initial values. This helps keep the test small.
+TRAIN_STEPS = 0
 
 CONFIG = tf.ConfigProto(device_count={"GPU": 0})
 
@@ -37,7 +39,8 @@ class BidirectionalSequenceLstmTest(test_util.TensorFlowTestCase):
   def setUp(self):
     tf.reset_default_graph()
     # Import MNIST dataset
-    self.mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
+    self.mnist = input_data.read_data_sets(
+        "/tmp/data/", fake_data=True, one_hot=True)
 
     # Define constants
     # Unrolled through 28 time steps
@@ -144,8 +147,10 @@ class BidirectionalSequenceLstmTest(test_util.TensorFlowTestCase):
     sess.run(init)
     for _ in range(TRAIN_STEPS):
       batch_x, batch_y = self.mnist.train.next_batch(
-          batch_size=self.batch_size, shuffle=False)
+          batch_size=self.batch_size, fake_data=True)
 
+      batch_x = np.array(batch_x)
+      batch_y = np.array(batch_y)
       batch_x = batch_x.reshape((self.batch_size, self.time_steps,
                                  self.n_input))
       sess.run(opt, feed_dict={x: batch_x, y: batch_y})
@@ -200,7 +205,8 @@ class BidirectionalSequenceLstmTest(test_util.TensorFlowTestCase):
       - Expected output.
 
     """
-    b1, _ = self.mnist.train.next_batch(batch_size=1)
+    b1, _ = self.mnist.train.next_batch(batch_size=1, fake_data=True)
+    b1 = np.array(b1, dtype=np.dtype("float32"))
     sample_input = np.reshape(b1, (1, self.time_steps, self.n_input))
 
     expected_output = sess.run(output_class, feed_dict={x: sample_input})
diff --git a/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_rnn_test.py b/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_rnn_test.py
index 606f969b92a..b90d4d52b29 100644
--- a/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_rnn_test.py
+++ b/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_rnn_test.py
@@ -31,7 +31,9 @@ from tensorflow.python.platform import test
 FLAGS = flags.FLAGS
 
 # Number of steps to train model.
-TRAIN_STEPS = 1
+# Setting this to 0 disables training entirely: all the weights simply keep
+# their initial values. This helps keep the test small.
+TRAIN_STEPS = 0
 
 CONFIG = tf.ConfigProto(device_count={"GPU": 0})
 
@@ -58,7 +60,8 @@ class BidirectionalSequenceRnnTest(test_util.TensorFlowTestCase):
     super(BidirectionalSequenceRnnTest, self).setUp()
     # Import MNIST dataset
     data_dir = tempfile.mkdtemp(dir=FLAGS.test_tmpdir)
-    self.mnist = input_data.read_data_sets(data_dir, one_hot=True)
+    self.mnist = input_data.read_data_sets(
+        data_dir, fake_data=True, one_hot=True)
 
   def buildRnnLayer(self):
     return tf.keras.layers.StackedRNNCells([
@@ -165,8 +168,10 @@ class BidirectionalSequenceRnnTest(test_util.TensorFlowTestCase):
     sess.run(init)
     for _ in range(TRAIN_STEPS):
       batch_x, batch_y = self.mnist.train.next_batch(
-          batch_size=self.batch_size, shuffle=False)
+          batch_size=self.batch_size, shuffle=False, fake_data=True)
 
+      batch_x = np.array(batch_x)
+      batch_y = np.array(batch_y)
       batch_x = batch_x.reshape((self.batch_size, self.time_steps,
                                  self.n_input))
       sess.run(opt, feed_dict={x: batch_x, y: batch_y})
@@ -228,7 +233,8 @@ class BidirectionalSequenceRnnTest(test_util.TensorFlowTestCase):
       - Expected output.
 
     """
-    b1, _ = self.mnist.train.next_batch(batch_size=1)
+    b1, _ = self.mnist.train.next_batch(batch_size=1, fake_data=True)
+    b1 = np.array(b1, dtype=np.dtype("float32"))
     sample_input = np.reshape(b1, (1, self.time_steps, self.n_input))
 
     expected_output = sess.run(output_class, feed_dict={x: sample_input})
diff --git a/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py b/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py
index d937a111529..ba936a4e8cd 100644
--- a/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py
+++ b/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py
@@ -27,7 +27,9 @@ from tensorflow.python.platform import test
 
 
 # Number of steps to train model.
-TRAIN_STEPS = 1
+# Setting this to 0 disables training entirely: all the weights simply keep
+# their initial values. This helps keep the test small.
+TRAIN_STEPS = 0
 
 CONFIG = tf.ConfigProto(device_count={"GPU": 0})
 
@@ -37,7 +39,8 @@ class UnidirectionalSequenceLstmTest(test_util.TensorFlowTestCase):
   def setUp(self):
     tf.reset_default_graph()
     # Import MNIST dataset
-    self.mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
+    self.mnist = input_data.read_data_sets(
+        "/tmp/data/", fake_data=True, one_hot=True)
 
     # Define constants
     # Unrolled through 28 time steps
@@ -133,8 +136,10 @@ class UnidirectionalSequenceLstmTest(test_util.TensorFlowTestCase):
     sess.run(init)
     for _ in range(TRAIN_STEPS):
       batch_x, batch_y = self.mnist.train.next_batch(
-          batch_size=self.batch_size, shuffle=False)
+          batch_size=self.batch_size, fake_data=True)
 
+      batch_x = np.array(batch_x)
+      batch_y = np.array(batch_y)
       batch_x = batch_x.reshape((self.batch_size, self.time_steps,
                                  self.n_input))
       sess.run(opt, feed_dict={x: batch_x, y: batch_y})
@@ -184,7 +189,8 @@ class UnidirectionalSequenceLstmTest(test_util.TensorFlowTestCase):
       - Expected output.
 
     """
-    b1, _ = self.mnist.train.next_batch(batch_size=1)
+    b1, _ = self.mnist.train.next_batch(batch_size=1, fake_data=True)
+    b1 = np.array(b1, dtype=np.dtype("float32"))
     sample_input = np.reshape(b1, (1, self.time_steps, self.n_input))
 
     expected_output = sess.run(output_class, feed_dict={x: sample_input})
diff --git a/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_rnn_test.py b/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_rnn_test.py
index a3859e1ad40..49c3d5e7757 100644
--- a/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_rnn_test.py
+++ b/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_rnn_test.py
@@ -30,7 +30,9 @@ from tensorflow.python.platform import test
 FLAGS = flags.FLAGS
 
 # Number of steps to train model.
-TRAIN_STEPS = 1
+# Setting this to 0 disables training entirely: all the weights simply keep
+# their initial values. This helps keep the test small.
+TRAIN_STEPS = 0
 
 CONFIG = tf.ConfigProto(device_count={"GPU": 0})
 
@@ -57,7 +59,8 @@ class UnidirectionalSequenceRnnTest(test_util.TensorFlowTestCase):
     super(UnidirectionalSequenceRnnTest, self).setUp()
     # Import MNIST dataset
     data_dir = tempfile.mkdtemp(dir=FLAGS.test_tmpdir)
-    self.mnist = input_data.read_data_sets(data_dir, one_hot=True)
+    self.mnist = input_data.read_data_sets(
+        data_dir, fake_data=True, one_hot=True)
 
   def buildRnnLayer(self):
     return tf.keras.layers.StackedRNNCells([
@@ -128,8 +131,10 @@ class UnidirectionalSequenceRnnTest(test_util.TensorFlowTestCase):
     sess.run(tf.global_variables_initializer())
     for _ in range(TRAIN_STEPS):
       batch_x, batch_y = self.mnist.train.next_batch(
-          batch_size=self.batch_size, shuffle=False)
+          batch_size=self.batch_size, fake_data=True)
 
+      batch_x = np.array(batch_x)
+      batch_y = np.array(batch_y)
       batch_x = batch_x.reshape((self.batch_size, self.time_steps,
                                  self.n_input))
       sess.run(opt, feed_dict={x: batch_x, y: batch_y})
@@ -179,7 +184,8 @@ class UnidirectionalSequenceRnnTest(test_util.TensorFlowTestCase):
       - Expected output.
 
     """
-    b1, _ = self.mnist.train.next_batch(batch_size=1)
+    b1, _ = self.mnist.train.next_batch(batch_size=1, fake_data=True)
+    b1 = np.array(b1, dtype=np.dtype("float32"))
     sample_input = np.reshape(b1, (1, self.time_steps, self.n_input))
 
     expected_output = sess.run(output_class, feed_dict={x: sample_input})
diff --git a/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/README.md b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/README.md
index 6daca3e4f5c..cbd1d016b83 100644
--- a/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/README.md
+++ b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/README.md
@@ -6,21 +6,18 @@ Unity by way of a C# `Interpreter` wrapper.
 
 Note that the native TF Lite plugin(s) *must* be built before using the Unity
 Plugin, and placed in Assets/TensorFlowLite/SDK/Plugins/. For the editor (note
-that this has only been tested on Linux; the syntax may differ on Mac/Windows):
+that the generated shared library name and suffix are platform-dependent):
 
 ```sh
-bazel build -c opt --cxxopt=--std=c++11 \
-  //tensorflow/lite/experimental/c:libtensorflowlite_c.so
+bazel build -c opt --cxxopt=--std=c++11 //tensorflow/lite/c:tensorflowlite_c
 ```
 
 and for Android (replace `android_arm` with `android_arm64` for 64-bit):
 
 ```sh
 bazel build -c opt --cxxopt=--std=c++11 --config=android_arm \
-  //tensorflow/lite/experimental/c:libtensorflowlite_c.so
+  //tensorflow/lite/c:tensorflowlite_c
 ```
 
 If you encounter issues with native plugin discovery on Mac ("Darwin")
-platforms, try renaming `libtensorflowlite_c.so` to `tensorflowlite_c.bundle`.
-Similarly, on Windows you'll likely need to rename `libtensorflowlite_c.so` to
-`tensorflowlite_c.dll`.
+platforms, try renaming `libtensorflowlite_c.dylib` to `tensorflowlite_c.bundle`.
diff --git a/tensorflow/lite/experimental/ios/BUILD.apple b/tensorflow/lite/experimental/ios/BUILD.apple
index 011d1725e3e..cf81057b167 100644
--- a/tensorflow/lite/experimental/ios/BUILD.apple
+++ b/tensorflow/lite/experimental/ios/BUILD.apple
@@ -5,23 +5,20 @@ load("//tensorflow/lite/experimental/ios:ios.bzl", "TFL_MINIMUM_OS_VERSION")
 load("@build_bazel_rules_apple//apple:ios.bzl", "ios_static_framework")
 
 package(
-    default_visibility = ["//tensorflow/lite/experimental/c:experimental"],
+    default_visibility = ["//tensorflow/lite/c:experimental"],
     licenses = ["notice"],  # Apache 2.0
 )
 
 TFL_LIBRARY_HDRS = [
     "//tensorflow/lite/delegates/gpu:metal_delegate.h",
-    "//tensorflow/lite/experimental/c:c_api.h",
-]
-
-TFL_FRAMEWORK_HDRS = TFL_LIBRARY_HDRS + [
-    "//tensorflow/lite/experimental/c:c_api_types.h",
+    "//tensorflow/lite/c:c_api.h",
+    "//tensorflow/lite/c:common.h",
 ]
 
 # bazel build -c opt --config=ios_fat //tensorflow/lite/experimental/ios:TensorFlowLiteC_framework
 ios_static_framework(
     name = "TensorFlowLiteC_framework",
-    hdrs = TFL_FRAMEWORK_HDRS,
+    hdrs = TFL_LIBRARY_HDRS,
     bundle_name = "TensorFlowLiteC",
     minimum_os_version = TFL_MINIMUM_OS_VERSION,
     deps = [
@@ -29,18 +26,6 @@ ios_static_framework(
     ],
 )
 
-# bazel build -c opt --config=ios --ios_multi_cpus=armv7,arm64,x86_64 //tensorflow/lite/experimental/ios:TensorFlowLiteCWithSelectTfOps_framework
-ios_static_framework(
-    name = "TensorFlowLiteCWithSelectTfOps_framework",
-    hdrs = TFL_FRAMEWORK_HDRS,
-    bundle_name = "TensorFlowLiteC",
-    minimum_os_version = TFL_MINIMUM_OS_VERSION,
-    deps = [
-        ":TensorFlowLiteC",
-        "//tensorflow/lite/delegates/flex:delegate",
-    ],
-)
-
 objc_library(
     name = "TensorFlowLiteC",
     hdrs = TFL_LIBRARY_HDRS,
@@ -53,6 +38,24 @@ objc_library(
     ],
 )
 
+# This target builds the flex delegate as a separate static framework, which
+# does not include the TensorFlow Lite runtime. Since it does not contain the
+# runtime, it is intended to be linked along with the TensorFlowLiteC framework
+# above in a composable way.
+#
+# The flex delegate cannot be built for i386, so it cannot be built with the
+# ios_fat config.
+#
+# bazel build -c opt --config=ios --ios_multi_cpus=armv7,arm64,x86_64 //tensorflow/lite/experimental/ios:TensorFlowLiteSelectTfOps_framework
+ios_static_framework(
+    name = "TensorFlowLiteSelectTfOps_framework",
+    bundle_name = "TensorFlowLiteSelectTfOps",
+    minimum_os_version = TFL_MINIMUM_OS_VERSION,
+    deps = [
+        "//tensorflow/lite/delegates/flex:delegate",
+    ],
+)
+
 # Using this intermediate target is a workaround for a bug in bazel build rules
 # involving mixed objc_library & cc_library deps mentioned in (b/74809458).
 # When these dependencies are declared directly under the "TensorFlowLiteC"
@@ -68,16 +71,20 @@ cc_library(
     hdrs = TFL_LIBRARY_HDRS,
     tags = ["nobuilder"],
     deps = [
+        "//tensorflow/lite/c:c_api",
         "//tensorflow/lite/delegates/gpu:metal_delegate",
-        "//tensorflow/lite/experimental/c:c_api",
     ],
 )
 
 # Used for building TensorFlowLiteC framework.
 build_test(
     name = "framework_build_test",
+    tags = [
+        "nomsan",  # b/145205324
+        "notsan",  # b/145205324
+    ],
     targets = [
         ":TensorFlowLiteC_framework",
-        ":TensorFlowLiteCWithSelectTfOps_framework",
+        ":TensorFlowLiteSelectTfOps_framework",
     ],
 )
diff --git a/tensorflow/lite/experimental/ios/TensorFlowLiteSelectTfOps.md b/tensorflow/lite/experimental/ios/TensorFlowLiteSelectTfOps.md
new file mode 100644
index 00000000000..525049db2b7
--- /dev/null
+++ b/tensorflow/lite/experimental/ios/TensorFlowLiteSelectTfOps.md
@@ -0,0 +1,19 @@
+# TensorFlow Lite with Select TensorFlow ops
+
+To enable the Select TensorFlow ops in your TensorFlow Lite app, add the
+`TensorFlowLiteSelectTfOps` pod to your Podfile, in addition to the
+`TensorFlowLiteSwift` or `TensorFlowLiteObjC` pod, depending on your primary
+language.
+
+After that, you should also force-load the framework from your project. Add the
+following line to `Other Linker Flags` on your project's Build Settings
+page.
+
+```
+-force_load "$(PROJECT_DIR)/Pods/TensorFlowLiteSelectTfOps/Frameworks/TensorFlowLiteSelectTfOps.framework/TensorFlowLiteSelectTfOps"
+```
+
+Please refer to the [Select operators from TensorFlow][ops-select] guide for
+more details.
+
+[ops-select]: https://www.tensorflow.org/lite/guide/ops_select#ios
diff --git a/tensorflow/lite/experimental/ios/TensorFlowLiteSelectTfOps.podspec.template b/tensorflow/lite/experimental/ios/TensorFlowLiteSelectTfOps.podspec.template
new file mode 100644
index 00000000000..7a91e4a08ce
--- /dev/null
+++ b/tensorflow/lite/experimental/ios/TensorFlowLiteSelectTfOps.podspec.template
@@ -0,0 +1,21 @@
+Pod::Spec.new do |s|
+  s.name             = 'TensorFlowLiteSelectTfOps'
+  s.version          = '${TFL_BUILD_VERSION}'
+  s.authors          = 'Google Inc.'
+  s.license          = { :type => 'Apache' }
+  s.homepage         = 'https://github.com/tensorflow/tensorflow'
+  s.source           = { :http => "${TFL_DOWNLOAD_URL}" }
+  s.summary          = 'TensorFlow Lite'
+  s.description      = <<-DESC
+
+  This pod can be used in addition to the `TensorFlowLiteSwift` or
+  `TensorFlowLiteObjC` pod in order to enable the Select TensorFlow ops. The
+  resulting framework should also be force-loaded into the final app binary.
+                       DESC
+
+  s.ios.deployment_target = '9.0'
+
+  s.module_name = 'TensorFlowLiteSelectTfOps'
+  s.library = 'c++'
+  s.vendored_frameworks = 'Frameworks/TensorFlowLiteSelectTfOps.framework'
+end
diff --git a/tensorflow/lite/experimental/kernels/hashtable_ops_test.cc b/tensorflow/lite/experimental/kernels/hashtable_ops_test.cc
index ddc31a5abfd..8790a2c9960 100644
--- a/tensorflow/lite/experimental/kernels/hashtable_ops_test.cc
+++ b/tensorflow/lite/experimental/kernels/hashtable_ops_test.cc
@@ -695,7 +695,7 @@ TEST(HashtableOpsTest, TestHashtable) {
 
 template <typename T>
 TfLiteTensor CreateTensor(TfLiteType type, std::vector<T> vec) {
-  TfLiteTensor tensor;
+  TfLiteTensor tensor = {};
   TfLiteIntArray* dims = TfLiteIntArrayCreate(1);
   dims->data[0] = vec.size();
   tensor.dims = dims;
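The `= {}` added above matters because `TfLiteTensor` is a plain C struct: default-initializing it leaves every field, pointers included, indeterminate, while value-initialization zeroes them. A minimal illustration, assuming the usual `tensorflow/lite/c/common.h` layout:

```c++
#include <cassert>
#include "tensorflow/lite/c/common.h"

void IllustrateZeroInit() {
  TfLiteTensor uninitialized;  // POD: data.raw, dims, bytes, ... hold garbage.
  (void)uninitialized;         // Reading any field before assignment is UB.

  TfLiteTensor zeroed = {};    // Value-initialized: all fields are zero/null.
  assert(zeroed.data.raw == nullptr);
  assert(zeroed.dims == nullptr);
  assert(zeroed.bytes == 0);
}
```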
diff --git a/tensorflow/lite/experimental/micro/kernels/pooling_test.cc b/tensorflow/lite/experimental/micro/kernels/pooling_test.cc
index d2f8f41edcd..03909b994f8 100644
--- a/tensorflow/lite/experimental/micro/kernels/pooling_test.cc
+++ b/tensorflow/lite/experimental/micro/kernels/pooling_test.cc
@@ -54,7 +54,7 @@ void TestAveragePoolingFloat(std::initializer_list<int> input_dims_data,
       resolver.FindOp(tflite::BuiltinOperator_AVERAGE_POOL_2D, 1);
   TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
 
-  TfLiteConvParams builtin_data = {padding,      stride_width,  stride_height,
+  TfLitePoolParams builtin_data = {padding,      stride_width,  stride_height,
                                    filter_width, filter_height, activation};
   const char* init_data = reinterpret_cast<const char*>(&builtin_data);
   size_t init_data_size = 0;
@@ -125,7 +125,7 @@ void TestAveragePoolingUint8(
       resolver.FindOp(tflite::BuiltinOperator_AVERAGE_POOL_2D, 1);
   TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
 
-  TfLiteConvParams builtin_data = {padding,      stride_width,  stride_height,
+  TfLitePoolParams builtin_data = {padding,      stride_width,  stride_height,
                                    filter_width, filter_height, activation};
   const char* init_data = reinterpret_cast<const char*>(&builtin_data);
   size_t init_data_size = 0;
@@ -198,7 +198,7 @@ void TestAveragePoolingInt8(std::initializer_list<int> input_dims_data,
       resolver.FindOp(tflite::BuiltinOperator_AVERAGE_POOL_2D, 1);
   TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
 
-  TfLiteConvParams builtin_data = {padding,      stride_width,  stride_height,
+  TfLitePoolParams builtin_data = {padding,      stride_width,  stride_height,
                                    filter_width, filter_height, activation};
   const char* init_data = reinterpret_cast<const char*>(&builtin_data);
   size_t init_data_size = 0;
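The fix above replaces `TfLiteConvParams` with `TfLitePoolParams` so that the positional brace-initializer lines up with the pooling struct's fields. Written field-by-field (a sketch assuming the usual `builtin_op_data.h` member names, with the same local variables the tests already define), the intent is:

```c++
// Equivalent, order-independent form of the initializer used in these tests.
TfLitePoolParams builtin_data = {};
builtin_data.padding = padding;
builtin_data.stride_width = stride_width;
builtin_data.stride_height = stride_height;
builtin_data.filter_width = filter_width;
builtin_data.filter_height = filter_height;
builtin_data.activation = activation;
```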
diff --git a/tensorflow/lite/experimental/micro/micro_allocator.cc b/tensorflow/lite/experimental/micro/micro_allocator.cc
index 700016af510..82b3b350c23 100644
--- a/tensorflow/lite/experimental/micro/micro_allocator.cc
+++ b/tensorflow/lite/experimental/micro/micro_allocator.cc
@@ -89,6 +89,11 @@ MicroAllocator::MicroAllocator(TfLiteContext* context, const Model* model,
       reinterpret_cast<TfLiteTensor*>(memory_allocator_.AllocateFromTail(
           sizeof(TfLiteTensor) * context_->tensors_size,
           alignof(TfLiteTensor)));
+  if (context_->tensors == nullptr) {
+    error_reporter_->Report(
+        "Failed to allocate memory for context->tensors, %d bytes required",
+        sizeof(TfLiteTensor) * context_->tensors_size);
+  }
 
   // Null all inputs so we can later perform a null check to avoid re-allocating
   // registered pre-allocated inputs.
@@ -230,6 +235,12 @@ TfLiteStatus MicroAllocator::FinishTensorAllocation() {
   TensorInfo* tensor_info =
       reinterpret_cast<TensorInfo*>(tmp_allocator.AllocateFromTail(
           sizeof(TensorInfo) * tensors_size, alignof(TensorInfo)));
+  if (tensor_info == nullptr) {
+    error_reporter_->Report(
+        "Failed to allocate memory for tensor_info, %d bytes required",
+        sizeof(TensorInfo) * tensors_size);
+    return kTfLiteError;
+  }
 
   // Set up the runtime data structures for all tensors.
   for (size_t i = 0; i < tensors_size; ++i) {
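Both new checks follow the same pattern: the micro allocator hands out memory from a fixed arena via `AllocateFromTail`, which returns null once the arena is exhausted, so each result is tested and reported before use. A condensed sketch of that pattern; the helper name is made up, and `SimpleMemoryAllocator`/`ErrorReporter` are the interfaces already used in this file:

```c++
// Hypothetical helper condensing the check-and-report pattern shown above.
template <typename T>
T* AllocateFromTailOrReport(tflite::SimpleMemoryAllocator* allocator,
                            size_t count, tflite::ErrorReporter* reporter) {
  T* ptr = reinterpret_cast<T*>(
      allocator->AllocateFromTail(sizeof(T) * count, alignof(T)));
  if (ptr == nullptr) {
    reporter->Report("Failed to allocate %d bytes", sizeof(T) * count);
  }
  return ptr;
}
```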
diff --git a/tensorflow/lite/experimental/micro/tools/ci_build/test_arduino_library.sh b/tensorflow/lite/experimental/micro/tools/ci_build/test_arduino_library.sh
index 8911f0d4274..04a5a617655 100755
--- a/tensorflow/lite/experimental/micro/tools/ci_build/test_arduino_library.sh
+++ b/tensorflow/lite/experimental/micro/tools/ci_build/test_arduino_library.sh
@@ -23,7 +23,6 @@ set -e
 ARDUINO_HOME_DIR=${HOME}/Arduino
 ARDUINO_LIBRARIES_DIR=${ARDUINO_HOME_DIR}/libraries
 ARDUINO_CLI_TOOL=/tmp/arduino-cli
-# Necessary due to bug in arduino-cli that allows it to build files in pwd
 TEMP_BUILD_DIR=/tmp/tflite-arduino-build
 
 LIBRARY_ZIP=${1}
@@ -57,11 +56,9 @@ InstallLibraryDependencies () {
 
 InstallLibraryDependencies
 
-# Change into this dir before running the tests
-cd ${TEMP_BUILD_DIR}
-
 for f in ${ARDUINO_LIBRARIES_DIR}/tensorflow_lite/examples/*/*.ino; do
-  ${ARDUINO_CLI_TOOL} compile --fqbn arduino:mbed:nano33ble $f
+  ${ARDUINO_CLI_TOOL} compile --build-cache-path ${TEMP_BUILD_DIR} --build-path ${TEMP_BUILD_DIR} --fqbn arduino:mbed:nano33ble $f
 done
 
 rm -rf ${ARDUINO_LIBRARIES_DIR}
+rm -rf ${TEMP_BUILD_DIR}
diff --git a/tensorflow/lite/experimental/objc/BUILD.apple b/tensorflow/lite/experimental/objc/BUILD.apple
index 09e672ceff3..198e90c1cbc 100644
--- a/tensorflow/lite/experimental/objc/BUILD.apple
+++ b/tensorflow/lite/experimental/objc/BUILD.apple
@@ -44,7 +44,7 @@ RELEASE_COPTS = [
     # Warns if an @selector() expression is encountered with a method name that hasn't been defined yet.
     "-Wundeclared-selector",
     # Turn off warnings for headers not part of TensorFlow Lite Objective-C API.
-    "--system-header-prefix=tensorflow/lite/experimental/c/",
+    "--system-header-prefix=tensorflow/lite/c/",
 ]
 
 # Compiler flags for building test libraries.
@@ -63,7 +63,7 @@ objc_library(
     tags = TFL_DEFAULT_TAGS,
     visibility = ios_visibility_whitelist(),
     deps = [
-        "//tensorflow/lite/experimental/c:c_api",
+        "//tensorflow/lite/c:c_api",
     ],
     alwayslink = 1,
 )
diff --git a/tensorflow/lite/experimental/objc/TensorFlowLite.tulsiproj/Configs/TensorFlowLite.tulsigen b/tensorflow/lite/experimental/objc/TensorFlowLite.tulsiproj/Configs/TensorFlowLite.tulsigen
index feacdbad8de..bbd35902dce 100644
--- a/tensorflow/lite/experimental/objc/TensorFlowLite.tulsiproj/Configs/TensorFlowLite.tulsigen
+++ b/tensorflow/lite/experimental/objc/TensorFlowLite.tulsiproj/Configs/TensorFlowLite.tulsigen
@@ -1,7 +1,7 @@
 {
   "sourceFilters" : [
     "tensorflow/lite",
-    "tensorflow/lite/experimental/c",
+    "tensorflow/lite/c",
     "tensorflow/lite/experimental/objc",
     "tensorflow/lite/experimental/objc/apis",
     "tensorflow/lite/experimental/objc/apps/TestApp/TestApp",
diff --git a/tensorflow/lite/experimental/objc/TensorFlowLiteObjC-nightly.podspec b/tensorflow/lite/experimental/objc/TensorFlowLiteObjC-nightly.podspec
index 762ba7b83c9..2447f432664 100644
--- a/tensorflow/lite/experimental/objc/TensorFlowLiteObjC-nightly.podspec
+++ b/tensorflow/lite/experimental/objc/TensorFlowLiteObjC-nightly.podspec
@@ -25,7 +25,7 @@ Pod::Spec.new do |s|
   s.source_files = [
     objc_dir + '{apis,sources}/*.{h,m,mm}',
     tfl_dir + 'experimental/c/c_api.h',
-    tfl_dir + 'experimental/c/c_api_types.h',
+    tfl_dir + 'experimental/c/common.h',
   ]
   s.module_map = objc_dir + 'apis/framework.modulemap'
   s.dependency 'TensorFlowLiteC', "~> #{s.version}"
diff --git a/tensorflow/lite/experimental/objc/TensorFlowLiteObjC.podspec b/tensorflow/lite/experimental/objc/TensorFlowLiteObjC.podspec
index 3af0eff111e..b3ece575fd8 100644
--- a/tensorflow/lite/experimental/objc/TensorFlowLiteObjC.podspec
+++ b/tensorflow/lite/experimental/objc/TensorFlowLiteObjC.podspec
@@ -25,7 +25,7 @@ Pod::Spec.new do |s|
   s.source_files = [
     objc_dir + '{apis,sources}/*.{h,m,mm}',
     tfl_dir + 'experimental/c/c_api.h',
-    tfl_dir + 'experimental/c/c_api_types.h',
+    tfl_dir + 'experimental/c/common.h',
   ]
   s.module_map = objc_dir + 'apis/framework.modulemap'
   s.dependency 'TensorFlowLiteC', "#{s.version}"
diff --git a/tensorflow/lite/experimental/objc/sources/TFLInterpreter.mm b/tensorflow/lite/experimental/objc/sources/TFLInterpreter.mm
index e8e69484e21..8ef4c571558 100644
--- a/tensorflow/lite/experimental/objc/sources/TFLInterpreter.mm
+++ b/tensorflow/lite/experimental/objc/sources/TFLInterpreter.mm
@@ -20,7 +20,7 @@
 #import "tensorflow/lite/experimental/objc/apis/TFLInterpreterOptions.h"
 #import "tensorflow/lite/experimental/objc/apis/TFLTensor.h"
 
-#include "tensorflow/lite/experimental/c/c_api.h"
+#include "tensorflow/lite/c/c_api.h"
 
 NS_ASSUME_NONNULL_BEGIN
 
diff --git a/tensorflow/lite/experimental/ruy/allocator.cc b/tensorflow/lite/experimental/ruy/allocator.cc
index 8c4536bdeb1..d702f70e9fb 100644
--- a/tensorflow/lite/experimental/ruy/allocator.cc
+++ b/tensorflow/lite/experimental/ruy/allocator.cc
@@ -26,19 +26,19 @@ namespace ruy {
 
 namespace detail {
 
-void *AlignedAllocator::SystemAlignedAlloc(std::ptrdiff_t num_bytes) {
+void *SystemAlignedAlloc(std::ptrdiff_t num_bytes) {
 #ifdef _WIN32
-  return _aligned_malloc(num_bytes, kAlignment);
+  return _aligned_malloc(num_bytes, kMinimumBlockAlignment);
 #else
   void *ptr;
-  if (posix_memalign(&ptr, kAlignment, num_bytes)) {
+  if (posix_memalign(&ptr, kMinimumBlockAlignment, num_bytes)) {
     return nullptr;
   }
   return ptr;
 #endif
 }
 
-void AlignedAllocator::SystemAlignedFree(void *ptr) {
+void SystemAlignedFree(void *ptr) {
 #ifdef _WIN32
   _aligned_free(ptr);
 #else
diff --git a/tensorflow/lite/experimental/ruy/allocator.h b/tensorflow/lite/experimental/ruy/allocator.h
index f233090ce49..2f5c98d6870 100644
--- a/tensorflow/lite/experimental/ruy/allocator.h
+++ b/tensorflow/lite/experimental/ruy/allocator.h
@@ -34,38 +34,49 @@ inline void* VoidPtrAdd(void* p, std::ptrdiff_t offset) {
   return reinterpret_cast<void*>(addr);
 }
 
-// Simple allocator designed to converge to a steady-state where all
+// Minimum alignment for blocks.
+//
+// Considerations:
+//  - This needs to be at least the alignment of any usual data type.
+//  - It's useful that this is at least the size of a cache line to limit
+//    possible cache side effects (if only on performance behavior).
+//  - It's useful that this is at least the size of SIMD registers, as
+//    some SIMD instruction sets have at least performance behavior
+//    differences (e.g. NEON) or even different requirements (e.g. SSE)
+//    based on that.
+//  - It's useful that this is at least the size of an "exclusive reservation
+//    granule" on ARM, meaning that if we use this Allocator to allocate
+//    an atomic variable, there will be no side effects from other things
+//    contending for exclusive/atomic memory accesses to it. While the
+//    ARM reference manual mentions that this granule size may be as large
+//    as 2048 bytes, in practice we observe it to be 64 bytes. It can
+//    be queried cheaply, at runtime, from userspace, if needed.
+static constexpr std::ptrdiff_t kMinimumBlockAlignment = 64;
+
+// Primitive allocation functions obtaining aligned memory from the
+// operating system.
+void* SystemAlignedAlloc(std::ptrdiff_t num_bytes);
+void SystemAlignedFree(void* ptr);
+
+// Specialized allocator designed to converge to a steady-state where all
 // allocations are bump-ptr allocations from an already-allocated buffer.
 //
 // To support these constraints, this allocator only supports two
 // operations.
 // - AllocateAlignedBytes: allocates a pointer to storage of a specified
-// size, which must be aligned to kAlignment.
+// size, which must be aligned to kMinimumBlockAlignment.
 // - FreeAll: frees all previous allocations (but retains the internal
 // buffer to minimize future calls into the system allocator).
 //
+// This class is specialized for supporting just those two operations
+// under this specific steady-state usage pattern. Extending this class
+// with new allocation interfaces that don't fit that pattern is probably not
+// the right choice. Instead, build a new class on top of
+// SystemAlignedAlloc/SystemAlignedFree.
+//
 // All operations happen on aligned blocks for simplicity.
 class AlignedAllocator {
  public:
-  // Alignment of allocated blocks.
-  //
-  // Considerations:
-  //  - This needs to be at least the alignment of any usual data type.
-  //  - It's useful that this is at least the size of a cache line to limit
-  //    possible cache side effects (if only on performance behavior).
-  //  - It's useful that this is at least the size of SIMD registers, as
-  //    some SIMD instruction sets have at least performance behavior
-  //    differences (e.g. NEON) or even different requirements (e.g. SSE)
-  //    based on that.
-  //  - It's useful that this is at least the size of an "exclusive reservation
-  //    granule" on ARM, meaning that if we use this Allocator to allocate
-  //    an atomic variable, there will be no side effects from other things
-  //    contending for exclusive/atomic memory accesses to it. While the
-  //    ARM reference manual mentions that this granule size may be as large
-  //    as 2048 bytes, in practice we observe it to be 64 bytes. It can
-  //    be queried cheaply, at runtime, from userspace, if needed.
-  static constexpr std::ptrdiff_t kAlignment = 64;
-
   void operator=(const AlignedAllocator&) = delete;
   ~AlignedAllocator() {
     FreeAll();
@@ -74,7 +85,7 @@ class AlignedAllocator {
 
   void* AllocateAlignedBytes(std::ptrdiff_t num_bytes) {
     RUY_DCHECK_GT(num_bytes, 0);
-    RUY_DCHECK((num_bytes & (kAlignment - 1)) == 0);
+    RUY_DCHECK((num_bytes & (kMinimumBlockAlignment - 1)) == 0);
     if (void* p = AllocateFast(num_bytes)) {
       return p;
     }
@@ -105,17 +116,7 @@ class AlignedAllocator {
     fallback_blocks_total_size_ = 0;
   }
 
-  void FreeOne(void* ptr) {
-    for (auto p = fallback_blocks_.begin(); p != fallback_blocks_.end(); ++p) {
-      if (*p == ptr) {
-        SystemAlignedFree(ptr);
-        fallback_blocks_.erase(p);
-        return;
-      }
-    }
-    RUY_DCHECK(false);  // Trying to free pointer we did not allocate.
-  }
-
+ private:
   void* AllocateFast(std::ptrdiff_t num_bytes) {
     if (current_ + num_bytes > size_) {
       return nullptr;
@@ -132,12 +133,6 @@ class AlignedAllocator {
     return p;
   }
 
- private:
-  // Primitive allocation functions obtaining aligned memory from the
-  // operating system.
-  void* SystemAlignedAlloc(std::ptrdiff_t num_bytes);
-  void SystemAlignedFree(void* ptr);
-
   // Theory of operation:
   //
   // - ptr_, current_, and size_ implement a basic bump-ptr allocator.
@@ -171,7 +166,7 @@ class Allocator {
       return nullptr;
     }
     return aligned.AllocateAlignedBytes(
-        round_up_pot(num_bytes, detail::AlignedAllocator::kAlignment));
+        round_up_pot(num_bytes, detail::kMinimumBlockAlignment));
   }
   template <typename Pointer>
   void Allocate(std::ptrdiff_t count, Pointer* out) {
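In short, `AlignedAllocator` now exposes only the bump-pointer path (`AllocateAlignedBytes` plus `FreeAll`), with `SystemAlignedAlloc`/`SystemAlignedFree` available as free functions for callers that need individually freeable blocks. A rough usage sketch of the public `ruy::Allocator` wrapper, assuming it keeps the `Allocate`/`FreeAll` interface shown here:

```c++
#include <cstdint>
#include "tensorflow/lite/experimental/ruy/allocator.h"

// Hypothetical hot loop: after the first iteration the internal buffer has
// grown to fit, so later Allocate() calls are pure bump-pointer operations.
void PackAndRun(int num_iterations, int rows, int cols) {
  ruy::Allocator allocator;
  for (int i = 0; i < num_iterations; ++i) {
    std::int32_t* sums = nullptr;
    std::uint8_t* packed = nullptr;
    allocator.Allocate(rows, &sums);           // rounded up to kMinimumBlockAlignment
    allocator.Allocate(rows * cols, &packed);  // likewise
    // ... pack matrices into `packed`, accumulate into `sums`, run the kernel ...
    allocator.FreeAll();  // resets the bump pointer but keeps the buffer
  }
}
```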
diff --git a/tensorflow/lite/experimental/ruy/prepacked_cache.cc b/tensorflow/lite/experimental/ruy/prepacked_cache.cc
index 93fc4363044..2bd23f834c4 100644
--- a/tensorflow/lite/experimental/ruy/prepacked_cache.cc
+++ b/tensorflow/lite/experimental/ruy/prepacked_cache.cc
@@ -58,19 +58,14 @@ void PrepackedCache::EjectOne() {
   PrepackedMatrix &pmatrix = oldest->second.first;
   cache_size_ -= pmatrix.data_size;
   cache_size_ -= pmatrix.sums_size;
-  allocator_.FreeOne(pmatrix.data);
-  allocator_.FreeOne(pmatrix.sums);
+  allocator_.Free(pmatrix.data);
+  allocator_.Free(pmatrix.sums);
   cache_.erase(oldest);
 }
 
 void PrepackedCache::AllocatePrepackedMatrix(PrepackedMatrix *pmatrix) {
-  pmatrix->data = AllocateBytes(pmatrix->data_size);
-  pmatrix->sums = AllocateBytes(pmatrix->sums_size);
-}
-
-void *PrepackedCache::AllocateBytes(std::ptrdiff_t num_bytes) {
-  // Force system allocation for now to enable easy ejections.
-  return allocator_.AllocateSlow(num_bytes);
+  pmatrix->data = allocator_.Alloc(pmatrix->data_size);
+  pmatrix->sums = allocator_.Alloc(pmatrix->sums_size);
 }
 
 void PrepackedCache::DoInsert(const CacheKey &key,
diff --git a/tensorflow/lite/experimental/ruy/prepacked_cache.h b/tensorflow/lite/experimental/ruy/prepacked_cache.h
index 053108e61ed..9c77c48cf69 100644
--- a/tensorflow/lite/experimental/ruy/prepacked_cache.h
+++ b/tensorflow/lite/experimental/ruy/prepacked_cache.h
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_PREPACKED_CACHE_H_
 #define TENSORFLOW_LITE_EXPERIMENTAL_RUY_PREPACKED_CACHE_H_
 
+#include 
 #include 
 #include 
 #include 
@@ -27,6 +28,40 @@ limitations under the License.
 
 namespace ruy {
 
+namespace detail {
+
+// Tracks a set of blocks allocated from the underlying system allocator.
+class SystemBlockAllocator {
+ public:
+  void *Alloc(std::ptrdiff_t num_bytes) {
+    void *p = detail::SystemAlignedAlloc(num_bytes);
+    blocks_.push_back(p);
+    return p;
+  }
+
+  void Free(void *block) {
+    for (auto it = blocks_.begin(); it != blocks_.end(); ++it) {
+      if (*it == block) {
+        detail::SystemAlignedFree(block);
+        blocks_.erase(it);
+        return;
+      }
+    }
+    RUY_DCHECK(false);  // Trying to free pointer we did not allocate.
+  }
+
+  ~SystemBlockAllocator() {
+    for (void *block : blocks_) {
+      detail::SystemAlignedFree(block);
+    }
+  }
+
+ private:
+  std::vector<void *> blocks_;
+};
+
+}  // namespace detail
+
 enum CachePolicy { kNoCache, kCacheLHSOnGemV };
 
 // "Low effort" Least Recently Used Cache for Prepacked Matrices
@@ -80,12 +115,8 @@ class PrepackedCache {
 
  private:
   void EjectOne();
-  void *AllocateBytes(std::ptrdiff_t num_bytes);
   void DoInsert(const CacheKey &key, const PrepackedMatrix &matrix);
-  // Since this cache is used in the context of "pre-packing", we need to
-  // handle allocating the space for the packed matrix ourselves, so we need
-  // our own allocator.
-  AlignedAllocator allocator_;
+  detail::SystemBlockAllocator allocator_;
   std::map cache_;
   const int32_t ejection_threshold_;
   size_t cache_size_;
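The net effect of the refactor: the cache no longer needs a `FreeOne` carved into `AlignedAllocator`'s bump-pointer design; each prepacked matrix now gets its own system-allocated blocks, so a single LRU ejection releases exactly that matrix. A rough sketch of the new ownership model (field names follow the `PrepackedMatrix` usage visible in this patch; the namespaces are assumptions):

```c++
#include "tensorflow/lite/experimental/ruy/prepacked_cache.h"

// Hypothetical sketch of per-matrix block ownership under SystemBlockAllocator.
void AllocateThenEject(ruy::PrepackedMatrix* pmatrix,
                       ruy::detail::SystemBlockAllocator* allocator) {
  pmatrix->data = allocator->Alloc(pmatrix->data_size);
  pmatrix->sums = allocator->Alloc(pmatrix->sums_size);
  // ... the entry sits in the LRU cache until it is chosen for ejection ...
  allocator->Free(pmatrix->data);  // releases only this matrix's blocks;
  allocator->Free(pmatrix->sums);  // other cached matrices are untouched
}
```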
diff --git a/tensorflow/lite/experimental/ruy/prepacked_cache_test.cc b/tensorflow/lite/experimental/ruy/prepacked_cache_test.cc
index 0e912495d09..efb6f2b358c 100644
--- a/tensorflow/lite/experimental/ruy/prepacked_cache_test.cc
+++ b/tensorflow/lite/experimental/ruy/prepacked_cache_test.cc
@@ -33,7 +33,7 @@ TEST(PrepackedCacheTest, TestCacheEjection) {
   mat1.data_size = 16;
   mat1.sums_size = 8;
   prepacked_cache.AllocatePrepackedMatrix(&mat1);
-  auto cache_key1 = std::make_pair(reinterpret_cast<void*>(0), mat1.data);
+  auto cache_key1 = std::make_pair(nullptr, mat1.data);
   prepacked_cache.Insert(cache_key1, mat1);
   std::this_thread::sleep_for(std::chrono::milliseconds(10));
   // Get a time point after the insertion into the cache.
@@ -49,7 +49,7 @@ TEST(PrepackedCacheTest, TestCacheEjection) {
   mat2.sums_size = 4;
   prepacked_cache.AllocatePrepackedMatrix(&mat2);
 
-  auto cache_key2 = std::make_pair(reinterpret_cast<void*>(0), mat2.data);
+  auto cache_key2 = std::make_pair(nullptr, mat2.data);
   prepacked_cache.Insert(cache_key2, mat2);
   // The cache size was exceeded by inserting mat2. Ensure that mat1 was
   // ejected.
@@ -67,7 +67,7 @@ TEST(PrepackedCacheTest, TestCacheBasic) {
   mat1.sums_size = 8;
   prepacked_cache.AllocatePrepackedMatrix(&mat1);
 
-  auto cache_key1 = std::make_pair(reinterpret_cast<void*>(0), mat1.data);
+  auto cache_key1 = std::make_pair(nullptr, mat1.data);
   prepacked_cache.Insert(cache_key1, mat1);
   std::this_thread::sleep_for(std::chrono::milliseconds(10));
   EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend());
@@ -77,7 +77,7 @@ TEST(PrepackedCacheTest, TestCacheBasic) {
   mat2.sums_size = 4;
   prepacked_cache.AllocatePrepackedMatrix(&mat2);
 
-  auto cache_key2 = std::make_pair(reinterpret_cast<void*>(0), mat2.data);
+  auto cache_key2 = std::make_pair(nullptr, mat2.data);
   std::this_thread::sleep_for(std::chrono::milliseconds(10));
   prepacked_cache.Insert(cache_key2, mat2);
   // The cache size was not exceeded by inserting mat2. Ensure that mat1 was not
@@ -95,7 +95,7 @@ TEST(PrepackedCacheTest, TestCacheEjection2) {
   mat1.data_size = 16;
   mat1.sums_size = 8;
   prepacked_cache.AllocatePrepackedMatrix(&mat1);
-  auto cache_key1 = std::make_pair(reinterpret_cast<void*>(0), mat1.data);
+  auto cache_key1 = std::make_pair(nullptr, mat1.data);
   prepacked_cache.Insert(cache_key1, mat1);
   std::this_thread::sleep_for(std::chrono::milliseconds(10));
 
@@ -104,7 +104,7 @@ TEST(PrepackedCacheTest, TestCacheEjection2) {
   mat2.data_size = 16;
   mat2.sums_size = 8;
   prepacked_cache.AllocatePrepackedMatrix(&mat2);
-  auto cache_key2 = std::make_pair(reinterpret_cast<void*>(0), mat2.data);
+  auto cache_key2 = std::make_pair(nullptr, mat2.data);
   prepacked_cache.Insert(cache_key2, mat2);
   std::this_thread::sleep_for(std::chrono::milliseconds(10));
 
@@ -113,7 +113,7 @@ TEST(PrepackedCacheTest, TestCacheEjection2) {
   mat31.data_size = 16;
   mat31.sums_size = 8;
   prepacked_cache.AllocatePrepackedMatrix(&mat31);
-  auto cache_key3 = std::make_pair(reinterpret_cast<void*>(0), mat31.data);
+  auto cache_key3 = std::make_pair(nullptr, mat31.data);
   prepacked_cache.Insert(cache_key3, mat31);
   std::this_thread::sleep_for(std::chrono::milliseconds(10));
 
@@ -128,7 +128,7 @@ TEST(PrepackedCacheTest, TestCacheEjection2) {
   mat4.data_size = 16;
   mat4.sums_size = 8;
   prepacked_cache.AllocatePrepackedMatrix(&mat4);
-  auto cache_key4 = std::make_pair(reinterpret_cast<void*>(0), mat4.data);
+  auto cache_key4 = std::make_pair(nullptr, mat4.data);
   prepacked_cache.Insert(cache_key4, mat4);
   std::this_thread::sleep_for(std::chrono::milliseconds(10));
 
diff --git a/tensorflow/lite/experimental/support/java/README.md b/tensorflow/lite/experimental/support/java/README.md
index bc00123af70..d5f3e121f3a 100644
--- a/tensorflow/lite/experimental/support/java/README.md
+++ b/tensorflow/lite/experimental/support/java/README.md
@@ -114,8 +114,9 @@ try{
 }
 
 // Running inference
-if(null != tflite)
+if(null != tflite) {
     tflite.run(tImage.getBuffer(), probabilityBuffer.getBuffer());
+}
 ```
 
 ### Accessing the result
@@ -138,9 +139,9 @@ import org.tensorflow.lite.support.common.FileUtil;
 final String ASSOCIATED_AXIS_LABELS = "labels.txt";
 List associatedAxisLabels = null;
 
-try{
+try {
     associatedAxisLabels = FileUtil.loadLabels(this, ASSOCIATED_AXIS_LABELS);
-} catch (IOException e){
+} catch (IOException e) {
     Log.e("tfliteSupport", "Error reading label file", e);
 }
 ```
@@ -192,11 +193,11 @@ int size = height > width ? width : height;
 ImageProcessor imageProcessor =
     new ImageProcessor.Builder()
         // Center crop the image to the largest square possible
-        .add(new ResizeWithCropOrPadOp(size , size))
+        .add(new ResizeWithCropOrPadOp(size, size))
         // Resize using Bilinear or Nearest neighbour
         .add(new ResizeOp(224, 224, ResizeOp.ResizeMethod.BILINEAR));
         // Rotation counter-clockwise in 90 degree increments
-        .add(new Rot90Op(rotateDegrees/90))
+        .add(new Rot90Op(rotateDegrees / 90))
         .build();
 ```
 
@@ -229,5 +230,5 @@ TensorProcessor probabilityProcessor =
     new TensorProcessor.Builder().add(new NormalizeOp(0, 255)).build();
 
 // Post-processor which dequantize the result
-TensorBuffer dequantizedBuffer = probabilityProcessor.process(probabilityBuffer)
+TensorBuffer dequantizedBuffer = probabilityProcessor.process(probabilityBuffer);
 ```
diff --git a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java
index 286f91f5037..ea6a085a3bc 100644
--- a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java
+++ b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java
@@ -61,10 +61,7 @@ public abstract class TensorBuffer {
    * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.UINT8);
    * </pre>
    *
-   * <p>The size of a fixed-size TensorBuffer cannot be changed once it is created. However, loading
-   * arraies or data buffers of the same buffer size but different shapes is allowed.
-   *
-   * <p>TODO(b/139782181): Shall we make it fixed-size or fixed-shape?
+   * <p>The size of a fixed-size TensorBuffer cannot be changed once it is created.
    *
    * @param shape The shape of the {@link TensorBuffer} to be created.
    * @param dataType The dataType of the {@link TensorBuffer} to be created.
@@ -87,7 +84,7 @@ public abstract class TensorBuffer {
    * Creates an empty dynamic {@link TensorBuffer} with specified {@link DataType}. The shape of the
    * created {@link TensorBuffer} is {0}.
    *
-   * <p>Dynamic TensorBuffers will reallocate memory when Loading arraies or data buffers of
+   * <p>Dynamic TensorBuffers will reallocate memory when loading arrays or data buffers of
    * different buffer sizes.
    *
    * @param dataType The dataType of the {@link TensorBuffer} to be created.
@@ -326,7 +323,7 @@ public abstract class TensorBuffer {
       allocateMemory(shape);
     } else {
       // Make sure the new shape fits the buffer size when TensorBuffer has fixed size.
-      SupportPreconditions.checkArgument(flatSize == computeFlatSize(shape));
+      SupportPreconditions.checkArgument(Arrays.equals(shape, this.shape));
       this.shape = shape.clone();
     }
   }
diff --git a/tensorflow/lite/experimental/swift/TensorFlowLite.tulsiproj/Configs/TensorFlowLite.tulsigen b/tensorflow/lite/experimental/swift/TensorFlowLite.tulsiproj/Configs/TensorFlowLite.tulsigen
index 7ad7e33cf09..d919ada871d 100644
--- a/tensorflow/lite/experimental/swift/TensorFlowLite.tulsiproj/Configs/TensorFlowLite.tulsigen
+++ b/tensorflow/lite/experimental/swift/TensorFlowLite.tulsiproj/Configs/TensorFlowLite.tulsigen
@@ -1,6 +1,6 @@
 {
   "sourceFilters" : [
-    "tensorflow/lite/experimental/c",
+    "tensorflow/lite/c",
     "tensorflow/lite/experimental/swift",
     "tensorflow/lite/experimental/swift/Sources",
     "tensorflow/lite/experimental/swift/TestApp/TestApp",
diff --git a/tensorflow/lite/g3doc/guide/hosted_models.md b/tensorflow/lite/g3doc/guide/hosted_models.md
index ff3feb2e113..8bfdaf94538 100644
--- a/tensorflow/lite/g3doc/guide/hosted_models.md
+++ b/tensorflow/lite/g3doc/guide/hosted_models.md
@@ -13,12 +13,12 @@ models to find the optimal balance between size, performance, and accuracy.
 ## Image classification
 
 For more information about image classification, see
-Image classification.
+Image classification.
 
 ## Question and Answer
 
 For more information about text classification with Mobile BERT, see
-Question And Answer.
+Question And Answer.
 
 ### Quantized models
diff --git a/tensorflow/lite/g3doc/guide/ios.md b/tensorflow/lite/g3doc/guide/ios.md
index fc997bccf9d..0c7e5dc9c90 100644
--- a/tensorflow/lite/g3doc/guide/ios.md
+++ b/tensorflow/lite/g3doc/guide/ios.md
@@ -7,7 +7,7 @@ example:
 image classification example
 
 For an explanation of the source code, you should also read
-[TensorFlow Lite iOS image classification](https://www.tensorflow.org/lite/models/image_classification/ios).
+[TensorFlow Lite iOS image classification](https://www.tensorflow.org/code/py/tensorflow_examples/lite/examples/image_classification/ios/EXPLORE_THE_CODE.md).
 
 This example app uses
 [image classification](https://www.tensorflow.org/lite/models/image_classification/overview)
diff --git a/tensorflow/lite/g3doc/microcontrollers/index.md b/tensorflow/lite/g3doc/microcontrollers/index.md
index b78b131784e..64e80686116 100644
--- a/tensorflow/lite/g3doc/microcontrollers/index.md
+++ b/tensorflow/lite/g3doc/microcontrollers/index.md
@@ -35,6 +35,8 @@ There are example applications available for the following development boards:
 
 * [Arduino Nano 33 BLE Sense](https://store.arduino.cc/usa/nano-33-ble-sense-with-headers)
 * [SparkFun Edge](https://www.sparkfun.com/products/15170)
 * [STM32F746 Discovery kit](https://www.st.com/en/evaluation-tools/32f746gdiscovery.html)
+* [Adafruit EdgeBadge](https://www.adafruit.com/product/4400)
+* [Adafruit TensorFlow Lite for Microcontrollers Kit](https://www.adafruit.com/product/4317)
 
 To learn more about the libraries and examples, see
 [Get started with microcontrollers](get_started.md).
diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index 6bbc6561143..b3657228e63 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -354,17 +354,6 @@ cc_test( ], ) -cc_library( - name = "activation_functor", - hdrs = [ - "activation_functor.h", - ], - copts = tflite_copts(), - deps = [ - "//tensorflow/lite/c:common", - ], -) - cc_library( name = "op_macros", hdrs = [ @@ -614,7 +603,6 @@ cc_library( "//tensorflow/lite/kernels/internal:kernel_utils", "//tensorflow/lite/kernels/internal:tensor", "//tensorflow/lite/kernels/internal:tensor_utils", - "//third_party/eigen3", "@gemmlowp", ], ) diff --git a/tensorflow/lite/kernels/activation_functor.h b/tensorflow/lite/kernels/activation_functor.h deleted file mode 100644 index 60e93c185a9..00000000000 --- a/tensorflow/lite/kernels/activation_functor.h +++ /dev/null @@ -1,58 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_LITE_KERNELS_ACTIVATION_FUNCTOR_H_ -#define TENSORFLOW_LITE_KERNELS_ACTIVATION_FUNCTOR_H_ - -#include -#include -#include - -#include "tensorflow/lite/c/builtin_op_data.h" - -namespace tflite { - -// Dynamic (non-fused) activation functor. perhaps it is worth having -// template instantiation? -// TODO(aselle): Make this more efficient by pulling the switch to conv_eval -// using template inlining. -class ActivationFunctor { - public: - explicit ActivationFunctor(TfLiteFusedActivation act) : act_(act) {} - - float operator()(float a) const { - switch (act_) { - case kTfLiteActNone: - return a; - case kTfLiteActRelu: - return a < 0.f ? 0.f : a; - case kTfLiteActRelu6: - return std::max(0.f, std::min(a, 6.f)); - case kTfLiteActTanh: - return std::tanh(a); - case kTfLiteActSigmoid: - return 1.0f / (1.0f + std::exp(-a)); - default: - // TODO(aselle): More informative fatal error! 
- exit(1); - } - } - - private: - TfLiteFusedActivation act_; -}; - -} // namespace tflite - -#endif // TENSORFLOW_LITE_KERNELS_ACTIVATION_FUNCTOR_H_ diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD index c547b7921dc..646f14680ac 100644 --- a/tensorflow/lite/kernels/internal/BUILD +++ b/tensorflow/lite/kernels/internal/BUILD @@ -459,7 +459,6 @@ cc_library( ":types", ":scoped_profiling_label_wrapper", "@gemmlowp//:fixedpoint", - "@gemmlowp//:profiler", "//third_party/eigen3", "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:op_macros", @@ -512,6 +511,7 @@ cc_library( ":compatibility", ":quantization_util", ":round", + ":scoped_profiling_label_wrapper", ":strided_slice_logic", ":legacy_types", ":tensor", @@ -577,7 +577,6 @@ cc_library( ":compatibility", ":round", "//tensorflow/lite/c:common", - "//tensorflow/lite/kernels:activation_functor", "//tensorflow/lite/kernels:cpu_backend_context", "@gemmlowp", ], @@ -670,14 +669,10 @@ cc_library( ], copts = tflite_copts() + NEON_FLAGS_IF_APPLICABLE, deps = [ - ":common", - ":compatibility", ":cpu_check", - ":types", + "//third_party/eigen3", "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:cpu_backend_context", - "//tensorflow/lite/kernels:op_macros", - "@gemmlowp//:fixedpoint", ] + select({ ":aarch64": [ ":neon_tensor_utils", diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/mean.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/mean.h index 44f7040c089..5a9d4df9aa6 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/mean.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/mean.h @@ -23,18 +23,10 @@ limitations under the License. namespace tflite { namespace optimized_integer_ops { -#ifdef USE_NEON - -using optimized_ops::DivideSumForMeanImpl; -using optimized_ops::RoundToNearest; - -#endif // USE_NEON - inline void MeanImpl(const tflite::MeanParams& op_params, const RuntimeShape& input_shape, const int8_t* input_data, - int32 input_zero_point, float input_scale, + int32 multiplier, int32 shift, int32 bias, const RuntimeShape& output_shape, int8_t* output_data, - int32 output_zero_point, float output_scale, int start_depth, int end_depth) { gemmlowp::ScopedProfilingLabel label("Mean4D/Int8/MeanImpl"); @@ -45,7 +37,6 @@ inline void MeanImpl(const tflite::MeanParams& op_params, const int output_width = output_shape.Dims(2); const int input_height = input_shape.Dims(1); const int input_width = input_shape.Dims(2); - const float num_elements_in_axis = input_width * input_height; TFLITE_CHECK_EQ(op_params.axis_count, 2); TFLITE_CHECK((op_params.axis[0] == 1 && op_params.axis[1] == 2) || @@ -53,82 +44,98 @@ inline void MeanImpl(const tflite::MeanParams& op_params, TFLITE_CHECK_EQ(output_height, 1); TFLITE_CHECK_EQ(output_width, 1); - const bool ordinary_mean = - (input_zero_point == output_zero_point && input_scale == output_scale); - float scale = 0.0f, bias = 0.0f; - if (!ordinary_mean) { - scale = input_scale / output_scale; - bias = -input_zero_point * scale + 0.5; - } + constexpr static int32_t kMinValue = std::numeric_limits::min(); + constexpr static int32_t kMaxValue = std::numeric_limits::max(); #ifdef USE_NEON - const float32x4_t num_elements_dup = vdupq_n_f32(num_elements_in_axis); - // This is only an approximation as NEON does not offer division instruction. 
- const float32x4_t scale_dup = vdupq_n_f32(scale); - const float32x4_t num_elements_reverse = vrecpeq_f32(num_elements_dup); - float32x4_t zero_point_with_bias_dup = vdupq_n_f32(output_zero_point + bias); + const int32x4_t bias_dup = vdupq_n_s32(bias); + const int32x4_t min_dup = vdupq_n_s32(kMinValue); + const int32x4_t max_dup = vdupq_n_s32(kMaxValue); #endif // USE_NEON - for (int out_b = 0; out_b < output_batch; ++out_b) { int out_d = start_depth; #ifdef USE_NEON - for (; out_d < end_depth - 8; out_d += 8) { - float32x4_t temp_sum_1 = vdupq_n_f32(0); - float32x4_t temp_sum_2 = vdupq_n_f32(0); + for (; out_d <= end_depth - 16; out_d += 16) { + int32x4x4_t temp_sum; + temp_sum.val[0] = vdupq_n_s32(0); + temp_sum.val[1] = vdupq_n_s32(0); + temp_sum.val[2] = vdupq_n_s32(0); + temp_sum.val[3] = vdupq_n_s32(0); for (int in_h = 0; in_h < input_height; ++in_h) { for (int in_w = 0; in_w < input_width; ++in_w) { const int8_t* input_data_ptr = input_data + Offset(input_shape, out_b, in_h, in_w, out_d); - int8x8_t input_data_val = vld1_s8(input_data_ptr); - int16x8_t input_data_val_shift = vmovl_s8(input_data_val); - float32x4_t input_float_1 = - vcvtq_f32_s32(vmovl_s16(vget_high_s16(input_data_val_shift))); - float32x4_t input_float_2 = - vcvtq_f32_s32(vmovl_s16(vget_low_s16(input_data_val_shift))); - temp_sum_1 = vaddq_f32(temp_sum_1, input_float_1); - temp_sum_2 = vaddq_f32(temp_sum_2, input_float_2); + int8x16_t input_data_val = vld1q_s8(input_data_ptr); + + int16x8_t input_data_low_shift = + vmovl_s8(vget_low_s8(input_data_val)); + int16x8_t input_data_high_shift = + vmovl_s8(vget_high_s8(input_data_val)); + + int32x4_t input_low_low = + vmovl_s16(vget_low_s16(input_data_low_shift)); + int32x4_t input_high_low = + vmovl_s16(vget_high_s16(input_data_low_shift)); + int32x4_t input_low_high = + vmovl_s16(vget_low_s16(input_data_high_shift)); + int32x4_t input_high_high = + vmovl_s16(vget_high_s16(input_data_high_shift)); + + temp_sum.val[0] = vaddq_s32(temp_sum.val[0], input_low_low); + temp_sum.val[1] = vaddq_s32(temp_sum.val[1], input_high_low); + temp_sum.val[2] = vaddq_s32(temp_sum.val[2], input_low_high); + temp_sum.val[3] = vaddq_s32(temp_sum.val[3], input_high_high); } } - const float32x4_t mean_1 = - DivideSumForMeanImpl(temp_sum_1, num_elements_reverse, ordinary_mean, - scale_dup, zero_point_with_bias_dup); - const float32x4_t mean_2 = - DivideSumForMeanImpl(temp_sum_2, num_elements_reverse, ordinary_mean, - scale_dup, zero_point_with_bias_dup); + temp_sum = optimized_ops::MultiplyByQuantizedMultiplier4Rows( + temp_sum, multiplier, shift); + + temp_sum.val[0] = vaddq_s32(temp_sum.val[0], bias_dup); + temp_sum.val[1] = vaddq_s32(temp_sum.val[1], bias_dup); + temp_sum.val[2] = vaddq_s32(temp_sum.val[2], bias_dup); + temp_sum.val[3] = vaddq_s32(temp_sum.val[3], bias_dup); + + temp_sum.val[0] = vminq_s32(vmaxq_s32(temp_sum.val[0], min_dup), max_dup); + temp_sum.val[1] = vminq_s32(vmaxq_s32(temp_sum.val[1], min_dup), max_dup); + temp_sum.val[2] = vminq_s32(vmaxq_s32(temp_sum.val[2], min_dup), max_dup); + temp_sum.val[3] = vminq_s32(vmaxq_s32(temp_sum.val[3], min_dup), max_dup); + + int16x4_t narrowed_low_low = vmovn_s32(temp_sum.val[0]); + int16x4_t narrowed_high_low = vmovn_s32(temp_sum.val[1]); + int16x4_t narrowed_low_high = vmovn_s32(temp_sum.val[2]); + int16x4_t narrowed_high_high = vmovn_s32(temp_sum.val[3]); + + int16x8_t combined_low = + vcombine_s16(narrowed_low_low, narrowed_high_low); + int16x8_t combined_high = + vcombine_s16(narrowed_low_high, narrowed_high_high); + + int8x8_t 
narrowed_low = vmovn_s16(combined_low); + int8x8_t narrowed_high = vmovn_s16(combined_high); + + int8x16_t combined_output = vcombine_s8(narrowed_low, narrowed_high); - int32x4_t casted_mean_1 = RoundToNearest(mean_1); - int16x4_t narrow_range_mean_1 = vmovn_s32(casted_mean_1); - int32x4_t casted_mean_2 = RoundToNearest(mean_2); - int16x4_t narrow_range_mean_2 = vmovn_s32(casted_mean_2); - int16x8_t combined_mean = - vcombine_s16(narrow_range_mean_2, narrow_range_mean_1); - int8x8_t narrowed_combined_mean = vmovn_s16(combined_mean); int8_t* output_data_ptr = output_data + Offset(output_shape, out_b, 0, 0, out_d); - vst1_s8(output_data_ptr, narrowed_combined_mean); + vst1q_s8(output_data_ptr, combined_output); } #endif // USE_NEON for (; out_d < end_depth; ++out_d) { - float temp_value = 0; + int acc = 0; for (int in_h = 0; in_h < input_height; ++in_h) { for (int in_w = 0; in_w < input_width; ++in_w) { - temp_value += - input_data[Offset(input_shape, out_b, in_h, in_w, out_d)]; + acc += input_data[Offset(input_shape, out_b, in_h, in_w, out_d)]; } } - temp_value = temp_value / num_elements_in_axis; - if (ordinary_mean) { - output_data[Offset(output_shape, out_b, 0, 0, out_d)] = - static_cast(round(temp_value)); - } else { - output_data[Offset(output_shape, out_b, 0, 0, out_d)] = - static_cast(round(temp_value * scale + bias)) + - output_zero_point; - } + acc = MultiplyByQuantizedMultiplier(acc, multiplier, shift); + acc += bias; + acc = std::min(std::max(acc, kMinValue), kMaxValue); + output_data[Offset(output_shape, out_b, 0, 0, out_d)] = + static_cast(acc); } } } @@ -136,38 +143,34 @@ inline void MeanImpl(const tflite::MeanParams& op_params, struct MeanWorkerTask : cpu_backend_threadpool::Task { MeanWorkerTask(const tflite::MeanParams& op_params, const RuntimeShape& input_shape, const int8_t* input_data, - int32 input_zero_point, float input_scale, + int32 multiplier, int32 shift, int32 bias, const RuntimeShape& output_shape, int8_t* output_data, - int32 output_zero_point, float output_scale, int start_height, - int end_height) + int start_height, int end_height) : op_params(op_params), input_shape(input_shape), input_data(input_data), - input_zero_point(input_zero_point), - input_scale(input_scale), + multiplier(multiplier), + shift(shift), + bias(bias), output_shape(output_shape), output_data(output_data), - output_zero_point(output_zero_point), - output_scale(output_scale), start_height(start_height), end_height(end_height) {} void Run() override { - MeanImpl(op_params, input_shape, input_data, input_zero_point, input_scale, - output_shape, output_data, output_zero_point, output_scale, - start_height, end_height); + MeanImpl(op_params, input_shape, input_data, multiplier, shift, bias, + output_shape, output_data, start_height, end_height); } private: const tflite::MeanParams& op_params; const RuntimeShape& input_shape; const int8_t* input_data; - int32 input_zero_point; - float input_scale; + int32 multiplier; + int32 shift; + int32 bias; const RuntimeShape& output_shape; int8_t* output_data; - int32 output_zero_point; - float output_scale; int start_height; int end_height; }; @@ -197,6 +200,18 @@ inline void Mean(const tflite::MeanParams& op_params, TFLITE_CHECK_EQ(output_height, 1); TFLITE_CHECK_EQ(output_width, 1); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const float num_elements_in_axis = input_width * input_height; + + int32 bias = + output_zero_point - + static_cast(input_zero_point * input_scale / output_scale); + float 
real_scale = input_scale / (num_elements_in_axis * output_scale); + + int32 multiplier, shift; + QuantizeMultiplier(real_scale, &multiplier, &shift); + constexpr int kMinDepthPerThread = 8; int thread_count = output_depth / kMinDepthPerThread; thread_count = thread_count > 0 ? thread_count : 1; @@ -204,9 +219,8 @@ inline void Mean(const tflite::MeanParams& op_params, std::min(thread_count, cpu_backend_context->max_num_threads()); if (capped_thread_count == 1) { - MeanImpl(op_params, input_shape, input_data, input_zero_point, input_scale, - output_shape, output_data, output_zero_point, output_scale, 0, - output_depth); + MeanImpl(op_params, input_shape, input_data, multiplier, shift, bias, + output_shape, output_data, 0, output_depth); } else { // Instead parrallel for batch, we loop for the output_depth since batch // is typical 1. @@ -219,9 +233,8 @@ inline void Mean(const tflite::MeanParams& op_params, // Try to distribute the tasks as even as possible. int depth_end = depth_start + (output_depth - depth_start) / (capped_thread_count - i); - tasks.emplace_back(op_params, input_shape, input_data, input_zero_point, - input_scale, output_shape, output_data, - output_zero_point, output_scale, depth_start, + tasks.emplace_back(op_params, input_shape, input_data, multiplier, shift, + bias, output_shape, output_data, depth_start, depth_end); depth_start = depth_end; } diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc index b23e0305990..d5c1f227b9a 100644 --- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc +++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc @@ -1970,6 +1970,37 @@ void NeonSub1Vector(const int16_t* vector, int v_size, int16_t* result) { } } +namespace { + +#if __aarch64__ +inline bool IsAllZero(const uint32x4_t u32x4) { + const uint32_t u32 = vmaxvq_u32(u32x4); + return !u32; +} +#else +inline bool IsAllZero(const uint32x4_t u32x4) { + const uint32x2_t u32x2 = vqadd_u32(vget_high_u32(u32x4), vget_low_u32(u32x4)); + const uint64x1_t u64 = vreinterpret_u64_u32(u32x2); + return !vget_lane_u64(u64, 0); +} +#endif + +#ifndef __SSE__ +// With Intel NEON-2-SSE translator library, this is a redefinition.. +inline bool IsAllZero(const int8x16_t v) { + return IsAllZero(vreinterpretq_u32_s8(v)); +} +#endif + +inline bool IsAllZero(const float32x4_t v_f32x4) { + const float32x4_t zero_f32x4 = vmovq_n_f32(0.0f); + // Compare-absolute greater-than, |v| > |0|, equivalently v != 0 + const uint32x4_t cmp_result = vcagtq_f32(v_f32x4, zero_f32x4); + return IsAllZero(cmp_result); +} + +} // namespace + bool NeonIsZeroVector(const float* vector, int v_size) { // If v_size is not divisible by the vector size, then we need to process the // final few elements sequentially. 
postamble_start shows the start index @@ -1977,15 +2008,10 @@ bool NeonIsZeroVector(const float* vector, int v_size) { const int postamble_start = RoundDownVectors(v_size); - const float32x4_t zero_x4_float = vmovq_n_f32(0.0f); int v = 0; for (; v < postamble_start; v += kFloatValuesPerNeonVector) { - const float32x4_t i_x4_float = vld1q_f32(vector + v); - uint32x4_t cmp_result = vceqq_f32(i_x4_float, zero_x4_float); - if (vgetq_lane_u32(cmp_result, 0) == 0) return false; - if (vgetq_lane_u32(cmp_result, 1) == 0) return false; - if (vgetq_lane_u32(cmp_result, 2) == 0) return false; - if (vgetq_lane_u32(cmp_result, 3) == 0) return false; + const float32x4_t v_f32x4 = vld1q_f32(vector + v); + if (!IsAllZero(v_f32x4)) return false; } // Postamble loop for (; v < v_size; ++v) { @@ -2001,15 +2027,10 @@ bool NeonIsZeroVector(const int8_t* vector, int v_size) { const int postamble_start = RoundDownVectors(v_size); - static const int32x4_t zero_x4_int32 = vmovq_n_s32(0); int v = 0; for (; v < postamble_start; v += kInt8ValuesPerNeonVector) { - const int32x4_t i_x4_int32 = vreinterpretq_s32_s8(vld1q_s8(vector + v)); - const uint32x4_t cmp_result = vceqq_s32(i_x4_int32, zero_x4_int32); - if (vgetq_lane_u32(cmp_result, 0) == 0) return false; - if (vgetq_lane_u32(cmp_result, 1) == 0) return false; - if (vgetq_lane_u32(cmp_result, 2) == 0) return false; - if (vgetq_lane_u32(cmp_result, 3) == 0) return false; + const int8x16_t v_s8x16 = vld1q_s8(vector + v); + if (!IsAllZero(v_s8x16)) return false; } // Postamble loop for (; v < v_size; ++v) { diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h index c4d3d0e13be..626afbe5d8d 100644 --- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h +++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h @@ -196,15 +196,6 @@ void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch, PortableVectorBatchVectorAdd(vector, v_size, n_batch, batch_vector); } -void ApplySigmoidToVector(const float* vector, int v_size, float* result) { - PortableApplySigmoidToVector(vector, v_size, result); -} - -void ApplyActivationToVector(const float* vector, int v_size, - TfLiteFusedActivation activation, float* result) { - PortableApplyActivationToVector(vector, v_size, activation, result); -} - void Sub1Vector(const float* vector, int v_size, float* result) { NEON_OR_PORTABLE(Sub1Vector, vector, v_size, result); } diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h index e478fb87720..26005e069a7 100644 --- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h @@ -195,6 +195,71 @@ MatrixMap MapAsMatrixWithGivenNumberOfRows(Scalar* data, return MatrixMap(data, rows, cols); } +// TODO(renjieliu): Refactor this to merge with other +// MultiplyByQuantizedMultipler. +#ifdef USE_NEON +inline int32x4x4_t MultiplyByQuantizedMultiplier4Rows( + int32x4x4_t input_val, int32 quantized_multiplier, int shift) { + using gemmlowp::RoundingDivideByPOT; + using gemmlowp::SaturatingRoundingDoublingHighMul; + const int left_shift = shift > 0 ? shift : 0; + const int right_shift = shift > 0 ? 0 : -shift; + int32x4x4_t result; + // The vector type support for SaturatingRoundingDoublingHighMulth in gemmlowp + // is limited to NEON. 
+#ifdef GEMMLOWP_NEON + const int32x4_t left_shifted_one_dup = vdupq_n_s32(1 << left_shift); + result.val[0] = + RoundingDivideByPOT(SaturatingRoundingDoublingHighMul( + vmulq_s32(input_val.val[0], left_shifted_one_dup), + quantized_multiplier), + right_shift); + result.val[1] = + RoundingDivideByPOT(SaturatingRoundingDoublingHighMul( + vmulq_s32(input_val.val[1], left_shifted_one_dup), + quantized_multiplier), + right_shift); + result.val[2] = + RoundingDivideByPOT(SaturatingRoundingDoublingHighMul( + vmulq_s32(input_val.val[2], left_shifted_one_dup), + quantized_multiplier), + right_shift); + result.val[3] = + RoundingDivideByPOT(SaturatingRoundingDoublingHighMul( + vmulq_s32(input_val.val[3], left_shifted_one_dup), + quantized_multiplier), + right_shift); +#else + for (int i = 0; i < 4; ++i) { + int32_t vals[4]; + vals[0] = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul( + vgetq_lane_s32(input_val.val[i], 0) * (1 << left_shift), + quantized_multiplier), + right_shift); + vals[1] = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul( + vgetq_lane_s32(input_val.val[i], 1) * (1 << left_shift), + quantized_multiplier), + right_shift); + vals[2] = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul( + vgetq_lane_s32(input_val.val[i], 2) * (1 << left_shift), + quantized_multiplier), + right_shift); + vals[3] = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul( + vgetq_lane_s32(input_val.val[i], 3) * (1 << left_shift), + quantized_multiplier), + right_shift); + + result.val[i] = vld1q_s32(reinterpret_cast(&vals)); + } +#endif + return result; +} +#endif + inline void AddBiasAndEvalActivationFunction(float output_activation_min, float output_activation_max, const RuntimeShape& bias_shape, @@ -849,9 +914,8 @@ inline uint32x4_t RoundToNearestUnsigned(const float32x4_t input) { inline void MeanImpl(const tflite::MeanParams& op_params, const RuntimeShape& input_shape, const uint8_t* input_data, - int32 input_zero_point, float input_scale, + int32 multiplier, int32 shift, int32 bias, const RuntimeShape& output_shape, uint8_t* output_data, - int32 output_zero_point, float output_scale, int start_depth, int end_depth) { gemmlowp::ScopedProfilingLabel label("Mean4D/Uint8/MeanImpl"); @@ -862,7 +926,6 @@ inline void MeanImpl(const tflite::MeanParams& op_params, const int output_width = output_shape.Dims(2); const int input_height = input_shape.Dims(1); const int input_width = input_shape.Dims(2); - const float num_elements_in_axis = input_width * input_height; TFLITE_CHECK_EQ(op_params.axis_count, 2); TFLITE_CHECK((op_params.axis[0] == 1 && op_params.axis[1] == 2) || @@ -870,83 +933,103 @@ inline void MeanImpl(const tflite::MeanParams& op_params, TFLITE_CHECK_EQ(output_height, 1); TFLITE_CHECK_EQ(output_width, 1); - const bool ordinary_mean = - (input_zero_point == output_zero_point && input_scale == output_scale); - float scale = 0.0f, bias = 0.0f; - if (!ordinary_mean) { - scale = input_scale / output_scale; - bias = -input_zero_point * scale + 0.5; - } + constexpr int32_t kMinValue = std::numeric_limits::min(); + constexpr int32_t kMaxValue = std::numeric_limits::max(); #ifdef USE_NEON - const float32x4_t num_elements_dup = vdupq_n_f32(num_elements_in_axis); - // This is only an approximation as NEON does not offer division instruction. 
- const float32x4_t scale_dup = vdupq_n_f32(scale); - const float32x4_t num_elements_reverse = vrecpeq_f32(num_elements_dup); - float32x4_t zero_point_with_bias_dup = vdupq_n_f32(output_zero_point + bias); + const int32x4_t bias_dup = vdupq_n_s32(bias); + const int32x4_t min_dup = vdupq_n_s32(kMinValue); + const int32x4_t max_dup = vdupq_n_s32(kMaxValue); #endif // USE_NEON for (int out_b = 0; out_b < output_batch; ++out_b) { int out_d = start_depth; #ifdef USE_NEON - for (; out_d < end_depth - 8; out_d += 8) { - float32x4_t temp_sum_1 = vdupq_n_f32(0); - float32x4_t temp_sum_2 = vdupq_n_f32(0); + for (; out_d <= end_depth - 16; out_d += 16) { + int32x4x4_t temp_sum; + temp_sum.val[0] = vdupq_n_s32(0); + temp_sum.val[1] = vdupq_n_s32(0); + temp_sum.val[2] = vdupq_n_s32(0); + temp_sum.val[3] = vdupq_n_s32(0); for (int in_h = 0; in_h < input_height; ++in_h) { for (int in_w = 0; in_w < input_width; ++in_w) { const uint8_t* input_data_ptr = input_data + Offset(input_shape, out_b, in_h, in_w, out_d); - uint8x8_t input_data_val = vld1_u8(input_data_ptr); - int16x8_t input_data_val_shift = - vreinterpretq_s16_u16(vmovl_u8(input_data_val)); - float32x4_t input_float_1 = - vcvtq_f32_s32(vmovl_s16(vget_high_s16(input_data_val_shift))); - float32x4_t input_float_2 = - vcvtq_f32_s32(vmovl_s16(vget_low_s16(input_data_val_shift))); - temp_sum_1 = vaddq_f32(temp_sum_1, input_float_1); - temp_sum_2 = vaddq_f32(temp_sum_2, input_float_2); + uint8x16_t input_data_val = vld1q_u8(input_data_ptr); + + int16x8_t input_data_low_shift = + vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(input_data_val))); + int16x8_t input_data_high_shift = + vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(input_data_val))); + + int32x4_t input_low_low = + vmovl_s16(vget_low_s16(input_data_low_shift)); + int32x4_t input_high_low = + vmovl_s16(vget_high_s16(input_data_low_shift)); + int32x4_t input_low_high = + vmovl_s16(vget_low_s16(input_data_high_shift)); + int32x4_t input_high_high = + vmovl_s16(vget_high_s16(input_data_high_shift)); + + temp_sum.val[0] = vaddq_s32(temp_sum.val[0], input_low_low); + temp_sum.val[1] = vaddq_s32(temp_sum.val[1], input_high_low); + temp_sum.val[2] = vaddq_s32(temp_sum.val[2], input_low_high); + temp_sum.val[3] = vaddq_s32(temp_sum.val[3], input_high_high); } } - const float32x4_t mean_1 = - DivideSumForMeanImpl(temp_sum_1, num_elements_reverse, ordinary_mean, - scale_dup, zero_point_with_bias_dup); - const float32x4_t mean_2 = - DivideSumForMeanImpl(temp_sum_2, num_elements_reverse, ordinary_mean, - scale_dup, zero_point_with_bias_dup); + temp_sum = + MultiplyByQuantizedMultiplier4Rows(temp_sum, multiplier, shift); + + temp_sum.val[0] = vaddq_s32(temp_sum.val[0], bias_dup); + temp_sum.val[1] = vaddq_s32(temp_sum.val[1], bias_dup); + temp_sum.val[2] = vaddq_s32(temp_sum.val[2], bias_dup); + temp_sum.val[3] = vaddq_s32(temp_sum.val[3], bias_dup); + + temp_sum.val[0] = vminq_s32(vmaxq_s32(temp_sum.val[0], min_dup), max_dup); + temp_sum.val[1] = vminq_s32(vmaxq_s32(temp_sum.val[1], min_dup), max_dup); + temp_sum.val[2] = vminq_s32(vmaxq_s32(temp_sum.val[2], min_dup), max_dup); + temp_sum.val[3] = vminq_s32(vmaxq_s32(temp_sum.val[3], min_dup), max_dup); + + uint16x4_t narrowed_low_low = + vmovn_u32(vreinterpretq_u32_s32(temp_sum.val[0])); + uint16x4_t narrowed_high_low = + vmovn_u32(vreinterpretq_u32_s32(temp_sum.val[1])); + uint16x4_t narrowed_low_high = + vmovn_u32(vreinterpretq_u32_s32(temp_sum.val[2])); + uint16x4_t narrowed_high_high = + vmovn_u32(vreinterpretq_u32_s32(temp_sum.val[3])); + + uint16x8_t 
combined_low = + vcombine_u16(narrowed_low_low, narrowed_high_low); + uint16x8_t combined_high = + vcombine_u16(narrowed_low_high, narrowed_high_high); + + uint8x8_t narrowed_low = vmovn_u16(combined_low); + uint8x8_t narrowed_high = vmovn_u16(combined_high); + + uint8x16_t combined_output = vcombine_u8(narrowed_low, narrowed_high); - uint32x4_t casted_mean_1 = RoundToNearestUnsigned(mean_1); - uint16x4_t narrow_range_mean_1 = vmovn_u32(casted_mean_1); - uint32x4_t casted_mean_2 = RoundToNearestUnsigned(mean_2); - uint16x4_t narrow_range_mean_2 = vmovn_u32(casted_mean_2); - uint16x8_t combined_mean = - vcombine_u16(narrow_range_mean_2, narrow_range_mean_1); - uint8x8_t narrowed_combined_mean = vmovn_u16(combined_mean); uint8_t* output_data_ptr = output_data + Offset(output_shape, out_b, 0, 0, out_d); - vst1_u8(output_data_ptr, narrowed_combined_mean); + vst1q_u8(output_data_ptr, combined_output); } #endif // USE_NEON for (; out_d < end_depth; ++out_d) { - float temp_value = 0; + int acc = 0; for (int in_h = 0; in_h < input_height; ++in_h) { for (int in_w = 0; in_w < input_width; ++in_w) { - temp_value += - input_data[Offset(input_shape, out_b, in_h, in_w, out_d)]; + acc += input_data[Offset(input_shape, out_b, in_h, in_w, out_d)]; } } - temp_value = temp_value / num_elements_in_axis; - if (ordinary_mean) { - output_data[Offset(output_shape, out_b, 0, 0, out_d)] = - static_cast(round(temp_value)); - } else { - output_data[Offset(output_shape, out_b, 0, 0, out_d)] = - static_cast(round(temp_value * scale + bias)) + - output_zero_point; - } + acc = MultiplyByQuantizedMultiplier(acc, multiplier, shift); + acc += bias; + acc = std::min(std::max(acc, kMinValue), kMaxValue); + output_data[Offset(output_shape, out_b, 0, 0, out_d)] = + static_cast(acc); } } } @@ -954,40 +1037,36 @@ inline void MeanImpl(const tflite::MeanParams& op_params, struct MeanWorkerTask : cpu_backend_threadpool::Task { MeanWorkerTask(const tflite::MeanParams& op_params, const RuntimeShape& input_shape, const uint8_t* input_data, - int32 input_zero_point, float input_scale, + int32 multiplier, int32 shift, int32 bias, const RuntimeShape& output_shape, uint8_t* output_data, - int32 output_zero_point, float output_scale, int start_height, - int end_height) - : op_params_(op_params), - input_shape_(input_shape), - input_data_(input_data), - input_zero_point_(input_zero_point), - input_scale_(input_scale), - output_shape_(output_shape), - output_data_(output_data), - output_zero_point_(output_zero_point), - output_scale_(output_scale), - start_height_(start_height), - end_height_(end_height) {} + int start_height, int end_height) + : op_params(op_params), + input_shape(input_shape), + input_data(input_data), + multiplier(multiplier), + shift(shift), + bias(bias), + output_shape(output_shape), + output_data(output_data), + start_height(start_height), + end_height(end_height) {} void Run() override { - MeanImpl(op_params_, input_shape_, input_data_, input_zero_point_, - input_scale_, output_shape_, output_data_, output_zero_point_, - output_scale_, start_height_, end_height_); + MeanImpl(op_params, input_shape, input_data, multiplier, shift, bias, + output_shape, output_data, start_height, end_height); } private: - const tflite::MeanParams& op_params_; - const RuntimeShape& input_shape_; - const uint8_t* input_data_; - int32 input_zero_point_; - float input_scale_; - const RuntimeShape& output_shape_; - uint8_t* output_data_; - int32 output_zero_point_; - float output_scale_; - int start_height_; - int end_height_; + const 
tflite::MeanParams& op_params; + const RuntimeShape& input_shape; + const uint8_t* input_data; + int32 multiplier; + int32 shift; + int32 bias; + const RuntimeShape& output_shape; + uint8_t* output_data; + int start_height; + int end_height; }; inline void Mean(const tflite::MeanParams& op_params, @@ -1015,6 +1094,18 @@ inline void Mean(const tflite::MeanParams& op_params, TFLITE_CHECK_EQ(output_height, 1); TFLITE_CHECK_EQ(output_width, 1); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const float num_elements_in_axis = input_width * input_height; + + int32 bias = + output_zero_point - + static_cast(input_zero_point * input_scale / output_scale); + float real_scale = input_scale / (num_elements_in_axis * output_scale); + + int32 multiplier, shift; + QuantizeMultiplier(real_scale, &multiplier, &shift); + constexpr int kMinDepthPerThread = 8; int thread_count = output_depth / kMinDepthPerThread; thread_count = thread_count > 0 ? thread_count : 1; @@ -1022,9 +1113,8 @@ inline void Mean(const tflite::MeanParams& op_params, std::min(thread_count, cpu_backend_context->max_num_threads()); if (capped_thread_count == 1) { - MeanImpl(op_params, input_shape, input_data, input_zero_point, input_scale, - output_shape, output_data, output_zero_point, output_scale, 0, - output_depth); + MeanImpl(op_params, input_shape, input_data, multiplier, shift, bias, + output_shape, output_data, 0, output_depth); } else { // Instead parrallel for batch, we loop for the output_depth since batch // is typical 1. @@ -1037,9 +1127,8 @@ inline void Mean(const tflite::MeanParams& op_params, // Try to distribute the tasks as even as possible. int depth_end = depth_start + (output_depth - depth_start) / (capped_thread_count - i); - tasks.emplace_back(op_params, input_shape, input_data, input_zero_point, - input_scale, output_shape, output_data, - output_zero_point, output_scale, depth_start, + tasks.emplace_back(op_params, input_shape, input_data, multiplier, shift, + bias, output_shape, output_data, depth_start, depth_end); depth_start = depth_end; } @@ -5465,71 +5554,6 @@ inline void TransposeConvV2( } } -// TODO(renjieliu): Refactor this to merge with other -// MultiplyByQuantizedMultipler. -#ifdef USE_NEON -inline int32x4x4_t MultiplyByQuantizedMultiplier4Rows( - int32x4x4_t input_val, int32 quantized_multiplier, int shift) { - using gemmlowp::RoundingDivideByPOT; - using gemmlowp::SaturatingRoundingDoublingHighMul; - const int left_shift = shift > 0 ? shift : 0; - const int right_shift = shift > 0 ? 0 : -shift; - int32x4x4_t result; - // The vector type support for SaturatingRoundingDoublingHighMulth in gemmlowp - // is limited to NEON. 
-#ifdef GEMMLOWP_NEON - const int32x4_t left_shifted_one_dup = vdupq_n_s32(1 << left_shift); - result.val[0] = - RoundingDivideByPOT(SaturatingRoundingDoublingHighMul( - vmulq_s32(input_val.val[0], left_shifted_one_dup), - quantized_multiplier), - right_shift); - result.val[1] = - RoundingDivideByPOT(SaturatingRoundingDoublingHighMul( - vmulq_s32(input_val.val[1], left_shifted_one_dup), - quantized_multiplier), - right_shift); - result.val[2] = - RoundingDivideByPOT(SaturatingRoundingDoublingHighMul( - vmulq_s32(input_val.val[2], left_shifted_one_dup), - quantized_multiplier), - right_shift); - result.val[3] = - RoundingDivideByPOT(SaturatingRoundingDoublingHighMul( - vmulq_s32(input_val.val[3], left_shifted_one_dup), - quantized_multiplier), - right_shift); -#else - for (int i = 0; i < 4; ++i) { - int32_t vals[4]; - vals[0] = RoundingDivideByPOT( - SaturatingRoundingDoublingHighMul( - vgetq_lane_s32(input_val.val[i], 0) * (1 << left_shift), - quantized_multiplier), - right_shift); - vals[1] = RoundingDivideByPOT( - SaturatingRoundingDoublingHighMul( - vgetq_lane_s32(input_val.val[i], 1) * (1 << left_shift), - quantized_multiplier), - right_shift); - vals[2] = RoundingDivideByPOT( - SaturatingRoundingDoublingHighMul( - vgetq_lane_s32(input_val.val[i], 2) * (1 << left_shift), - quantized_multiplier), - right_shift); - vals[3] = RoundingDivideByPOT( - SaturatingRoundingDoublingHighMul( - vgetq_lane_s32(input_val.val[i], 3) * (1 << left_shift), - quantized_multiplier), - right_shift); - - result.val[i] = vld1q_s32(reinterpret_cast(&vals)); - } -#endif - return result; -} -#endif - inline void Quantize(int32_t multiplier, int32_t shift, int32_t total_size, int32_t output_zp, int32_t* scratch, uint8_t* output) { gemmlowp::ScopedProfilingLabel label("Quantize/uint8"); diff --git a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h index 3659b6f4e1a..37c1c5ce05a 100644 --- a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h +++ b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h @@ -206,15 +206,6 @@ void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch, PortableVectorBatchVectorAdd(vector, v_size, n_batch, batch_vector); } -void ApplySigmoidToVector(const float* vector, int v_size, float* result) { - PortableApplySigmoidToVector(vector, v_size, result); -} - -void ApplyActivationToVector(const float* vector, int v_size, - TfLiteFusedActivation activation, float* result) { - PortableApplyActivationToVector(vector, v_size, activation, result); -} - void Sub1Vector(const float* vector, int v_size, float* result) { NEON_OR_PORTABLE(Sub1Vector, vector, v_size, result); } diff --git a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc index dba6079009a..1ba34d45987 100644 --- a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc +++ b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc @@ -21,7 +21,6 @@ limitations under the License. 
#include "fixedpoint/fixedpoint.h" #include "tensorflow/lite/c/builtin_op_data.h" -#include "tensorflow/lite/kernels/activation_functor.h" #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/compatibility.h" @@ -591,23 +590,6 @@ void PortableVectorBatchVectorAdd(const float* vector, int v_size, int n_batch, } } -void PortableApplySigmoidToVector(const float* vector, int v_size, - float* result) { - auto sigmoid_func = ActivationFunctor(kTfLiteActSigmoid); - for (int v = 0; v < v_size; v++) { - *result++ = (sigmoid_func)(*vector++); - } -} - -void PortableApplyActivationToVector(const float* vector, int v_size, - TfLiteFusedActivation activation, - float* result) { - auto activation_func = ActivationFunctor(activation); - for (int v = 0; v < v_size; v++) { - *result++ = (activation_func)(*vector++); - } -} - void PortableSub1Vector(const float* vector, int v_size, float* result) { for (int v = 0; v < v_size; v++) { *result++ = 1.0f - *vector++; diff --git a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h index 9d8cf4e2b9a..587501fe2cb 100644 --- a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h +++ b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h @@ -229,15 +229,6 @@ void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch, PortableVectorBatchVectorAdd(vector, v_size, n_batch, batch_vector); } -void ApplySigmoidToVector(const float* vector, int v_size, float* result) { - PortableApplySigmoidToVector(vector, v_size, result); -} - -void ApplyActivationToVector(const float* vector, int v_size, - TfLiteFusedActivation activation, float* result) { - PortableApplyActivationToVector(vector, v_size, activation, result); -} - void Sub1Vector(const float* vector, int v_size, float* result) { PortableSub1Vector(vector, v_size, result); } diff --git a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h index ddc400bb0c9..954ef6716b6 100644 --- a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h +++ b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h @@ -171,15 +171,6 @@ void PortableVectorBatchVectorAssign(const float* vector, int v_size, void PortableVectorBatchVectorAdd(const float* vector, int v_size, int n_batch, float* batch_vector); -// Apply sigmoid to elements of a vector. -void PortableApplySigmoidToVector(const float* vector, int v_size, - float* result); - -// Apply activation function to elements of a vector. -void PortableApplyActivationToVector(const float* vector, int v_size, - TfLiteFusedActivation activation, - float* result); - // Compute "1.0f - elements of vector" (used in CIFG). void PortableSub1Vector(const float* vector, int v_size, float* result); diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h index d91c00f755e..53b2049d74a 100644 --- a/tensorflow/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h @@ -28,7 +28,6 @@ limitations under the License. 
#include "third_party/eigen3/Eigen/Core" #include "fixedpoint/fixedpoint.h" -#include "profiling/instrumentation.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" @@ -54,6 +53,7 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/reference/softmax.h" #include "tensorflow/lite/kernels/internal/reference/strided_slice.h" #include "tensorflow/lite/kernels/internal/round.h" +#include "tensorflow/lite/kernels/internal/scoped_profiling_label_wrapper.h" #include "tensorflow/lite/kernels/internal/strided_slice_logic.h" #include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/internal/types.h" @@ -191,7 +191,7 @@ inline void Relu(const RuntimeShape& input_shape, const T* input_data, template inline void Relu1(const RuntimeShape& input_shape, const T* input_data, const RuntimeShape& output_shape, T* output_data) { - gemmlowp::ScopedProfilingLabel label("Relu1 (not fused)"); + ScopedProfilingLabelWrapper label("Relu1 (not fused)"); const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; ++i) { const T val = input_data[i]; @@ -204,7 +204,7 @@ inline void Relu1(const RuntimeShape& input_shape, const T* input_data, inline void Relu6(const RuntimeShape& input_shape, const float* input_data, const RuntimeShape& output_shape, float* output_data) { - gemmlowp::ScopedProfilingLabel label("Relu6 (not fused)"); + ScopedProfilingLabelWrapper label("Relu6 (not fused)"); const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; ++i) { const float val = input_data[i]; @@ -219,7 +219,7 @@ template inline void ReluX(const tflite::ReluParams& params, const RuntimeShape& input_shape, const T* input_data, const RuntimeShape& output_shape, T* output_data) { - gemmlowp::ScopedProfilingLabel label("Quantized ReluX (not fused)"); + ScopedProfilingLabelWrapper label("Quantized ReluX (not fused)"); const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; ++i) { const int32 val = static_cast(input_data[i]); @@ -237,7 +237,7 @@ template inline void ReluX(const tflite::ActivationParams& params, const RuntimeShape& input_shape, const T* input_data, const RuntimeShape& output_shape, T* output_data) { - gemmlowp::ScopedProfilingLabel label("Quantized ReluX (not fused)"); + ScopedProfilingLabelWrapper label("Quantized ReluX (not fused)"); const int flat_size = MatchingFlatSize(input_shape, output_shape); const T max_value = params.quantized_activation_max; const T min_value = params.quantized_activation_min; @@ -252,7 +252,7 @@ inline void ReluX(const tflite::ActivationParams& params, inline void LeakyRelu(const tflite::LeakyReluParams& params, const RuntimeShape& input_shape, const float* input_data, const RuntimeShape& output_shape, float* output_data) { - gemmlowp::ScopedProfilingLabel label("LeakyRelu (not fused)"); + ScopedProfilingLabelWrapper label("LeakyRelu (not fused)"); const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; ++i) { const float val = input_data[i]; @@ -267,7 +267,7 @@ inline void QuantizeLeakyRelu(const LeakyReluParams& params, T q_alpha, const T* input_data, const RuntimeShape& output_shape, T* output_data) { - gemmlowp::ScopedProfilingLabel label("LeakyRelu (not fused)"); + ScopedProfilingLabelWrapper label("LeakyRelu (not fused)"); const int flat_size = 
MatchingFlatSize(input_shape, output_shape); static const int32 quantized_min = std::numeric_limits::min(); static const int32 quantized_max = std::numeric_limits::max(); @@ -420,12 +420,11 @@ inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params, } } - inline void Mul(const ArithmeticParams& params, const RuntimeShape& input1_shape, const int16* input1_data, const RuntimeShape& input2_shape, const int16* input2_data, const RuntimeShape& output_shape, int16* output_data) { - gemmlowp::ScopedProfilingLabel label("Mul/Int16"); + ScopedProfilingLabelWrapper label("Mul/Int16"); const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape); @@ -444,7 +443,7 @@ inline void Mul(const ArithmeticParams& params, const RuntimeShape& input1_shape, const int16* input1_data, const RuntimeShape& input2_shape, const int16* input2_data, const RuntimeShape& output_shape, uint8* output_data) { - gemmlowp::ScopedProfilingLabel label("Mul/Int16Uint8"); + ScopedProfilingLabelWrapper label("Mul/Int16Uint8"); int32 output_offset = params.output_offset; int32 output_activation_min = params.quantized_activation_min; int32 output_activation_max = params.quantized_activation_max; @@ -581,7 +580,7 @@ inline void Div(const ArithmeticParams& params, const RuntimeShape& output_shape, uint8* output_data) { TFLITE_DCHECK_LE(params.quantized_activation_min, params.quantized_activation_max); - gemmlowp::ScopedProfilingLabel label("Div/8bit"); + ScopedProfilingLabelWrapper label("Div/8bit"); const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape); @@ -695,7 +694,7 @@ inline void BroadcastSub4DSlow(const ArithmeticParams& params, const float* input2_data, const RuntimeShape& output_shape, float* output_data) { - gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/float"); + ScopedProfilingLabelWrapper label("BroadcastSub4DSlow/float"); NdArrayDesc<4> desc1; NdArrayDesc<4> desc2; NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, @@ -736,7 +735,7 @@ inline void BroadcastSub4DSlow(const ArithmeticParams& params, const uint8* input2_data, const RuntimeShape& output_shape, uint8* output_data) { - gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/uint8"); + ScopedProfilingLabelWrapper label("BroadcastSub4DSlow/uint8"); NdArrayDesc<4> desc1; NdArrayDesc<4> desc2; NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, @@ -800,7 +799,7 @@ inline void BroadcastSub4DSlow(const ArithmeticParams& params, const int32* input2_data, const RuntimeShape& output_shape, int32* output_data) { - gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/int32"); + ScopedProfilingLabelWrapper label("BroadcastSub4DSlow/int32"); NdArrayDesc<4> desc1; NdArrayDesc<4> desc2; NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, @@ -840,7 +839,7 @@ void BroadcastSub4DSlow(const ArithmeticParams& params, const RuntimeShape& input1_shape, const T* input1_data, const RuntimeShape& input2_shape, const T* input2_data, const RuntimeShape& output_shape, T* output_data) { - gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/templated"); + ScopedProfilingLabelWrapper label("BroadcastSub4DSlow/templated"); NdArrayDesc<4> desc1; NdArrayDesc<4> desc2; NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, @@ -918,7 +917,7 @@ inline void SubWithActivation(const ArithmeticParams& params, const int32* input2_data, const RuntimeShape& output_shape, int32* output_data) { - gemmlowp::ScopedProfilingLabel 
label("SubWithActivation"); + ScopedProfilingLabelWrapper label("SubWithActivation"); const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape); for (int i = 0; i < flat_size; ++i) { @@ -948,7 +947,7 @@ inline void Sub16(const ArithmeticParams& params, const RuntimeShape& input1_shape, const int16_t* input1_data, const RuntimeShape& input2_shape, const int16_t* input2_data, const RuntimeShape& output_shape, int16_t* output_data) { - gemmlowp::ScopedProfilingLabel label("Sub/Int16"); + ScopedProfilingLabelWrapper label("Sub/Int16"); const int input1_shift = params.input1_shift; const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape); @@ -996,7 +995,7 @@ template void Pack(const PackParams& params, const RuntimeShape* const* input_shapes, const Scalar* const* input_data, const RuntimeShape& output_shape, Scalar* output_data) { - gemmlowp::ScopedProfilingLabel label("Pack"); + ScopedProfilingLabelWrapper label("Pack"); const int dimensions = output_shape.DimensionsCount(); int axis = params.axis; int inputs_count = params.inputs_count; @@ -1024,7 +1023,7 @@ template void Unpack(const UnpackParams& params, const RuntimeShape& input_shape, const Scalar* input_data, const RuntimeShape& output_shape, Scalar* const* output_datas) { - gemmlowp::ScopedProfilingLabel label("Unpack"); + ScopedProfilingLabelWrapper label("Unpack"); const int dimensions = input_shape.DimensionsCount(); const int outputs_count = params.num_split; @@ -1058,7 +1057,7 @@ void PackWithScaling(const PackParams& params, const RuntimeShape* const* input_shapes, const uint8* const* input_data, const RuntimeShape& output_shape, uint8* output_data) { - gemmlowp::ScopedProfilingLabel label("PackWithScaling"); + ScopedProfilingLabelWrapper label("PackWithScaling"); const int dimensions = output_shape.DimensionsCount(); int axis = params.axis; const int32* input_zeropoint = params.input_zeropoint; @@ -1108,7 +1107,7 @@ void DepthConcatenation(const ConcatenationParams& params, const RuntimeShape* const* input_shapes, const Scalar* const* input_data, const RuntimeShape& output_shape, Scalar* output_data) { - gemmlowp::ScopedProfilingLabel label("DepthConcatenation"); + ScopedProfilingLabelWrapper label("DepthConcatenation"); auto params_copy = params; params_copy.axis = 3; Concatenation(params_copy, input_shapes, input_data, output_shape, @@ -1512,7 +1511,7 @@ template void Split(const SplitParams& params, const RuntimeShape& input_shape, const Scalar* input_data, const RuntimeShape* const* output_shapes, Scalar* const* output_data) { - gemmlowp::ScopedProfilingLabel label("Split"); + ScopedProfilingLabelWrapper label("Split"); const int split_dimensions = input_shape.DimensionsCount(); int axis = params.axis < 0 ? 
params.axis + split_dimensions : params.axis; int outputs_count = params.num_split; @@ -1616,7 +1615,7 @@ inline void LogSoftmax(const SoftmaxParams& params, inline void LogSoftmax(const SoftmaxParams& params, const RuntimeShape& input_shape, const uint8* input_data, const RuntimeShape& output_shape, uint8* output_data) { - gemmlowp::ScopedProfilingLabel label("LogSoftmax/8bit"); + ScopedProfilingLabelWrapper label("LogSoftmax/8bit"); const int32 input_multiplier = params.input_multiplier; const int32 input_left_shift = params.input_left_shift; const int32 reverse_scaling_divisor = params.reverse_scaling_divisor; @@ -1770,7 +1769,7 @@ inline void Requantize(const input_type* input_data, int32_t size, int32_t effective_scale_multiplier, int32_t effective_scale_shift, int32_t input_zeropoint, int32_t output_zeropoint, output_type* output_data) { - gemmlowp::ScopedProfilingLabel label("Requantize"); + ScopedProfilingLabelWrapper label("Requantize"); const bool same_scale = (effective_scale_multiplier == 1 << 30 && effective_scale_shift == 1); if (same_scale) { @@ -1807,7 +1806,7 @@ inline void Requantize(const input_type* input_data, int32_t size, inline void FakeQuant(const tflite::FakeQuantParams& op_params, const RuntimeShape& input_shape, const float* input_data, const RuntimeShape& output_shape, float* output_data) { - gemmlowp::ScopedProfilingLabel label("FakeQuant"); + ScopedProfilingLabelWrapper label("FakeQuant"); float rmin = op_params.minmax.min; float rmax = op_params.minmax.max; int num_bits = op_params.num_bits; @@ -1860,7 +1859,7 @@ inline void Gather(const tflite::GatherParams& op_params, const RuntimeShape& input_shape, const T* input_data, const RuntimeShape& coords_shape, const CoordsT* coords_data, const RuntimeShape& output_shape, T* output_data) { - gemmlowp::ScopedProfilingLabel label("Gather"); + ScopedProfilingLabelWrapper label("Gather"); int axis = op_params.axis; if (axis < 0) { axis += input_shape.DimensionsCount(); @@ -1898,7 +1897,7 @@ inline void GatherNd(const RuntimeShape& params_shape, const RuntimeShape& indices_shape, const IndicesT* indices_data, const RuntimeShape& output_shape, ParamsT* output_data) { - gemmlowp::ScopedProfilingLabel label("GatherNd"); + ScopedProfilingLabelWrapper label("GatherNd"); int n_slices = 1; int slice_size = 1; @@ -1935,7 +1934,7 @@ inline void ScatterNd(const RuntimeShape& indices_shape, const RuntimeShape& updates_shape, const UpdatesT* updates_data, const RuntimeShape& output_shape, UpdatesT* output_data) { - gemmlowp::ScopedProfilingLabel label("ScatterNd"); + ScopedProfilingLabelWrapper label("ScatterNd"); int n_slices = 1; int slice_size = 1; @@ -2043,7 +2042,7 @@ inline void SpaceToBatchND( const RuntimeShape& unextended_input2_shape, const int32* block_shape_data, const RuntimeShape& unextended_input3_shape, const int32* paddings_data, const RuntimeShape& unextended_output_shape, T* output_data) { - gemmlowp::ScopedProfilingLabel label("SpaceToBatchND"); + ScopedProfilingLabelWrapper label("SpaceToBatchND"); TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4); TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); const RuntimeShape input1_shape = @@ -2101,7 +2100,7 @@ inline void BatchToSpaceND( const RuntimeShape& unextended_input2_shape, const int32* block_shape_data, const RuntimeShape& unextended_input3_shape, const int32* crops_data, const RuntimeShape& unextended_output_shape, T* output_data) { - gemmlowp::ScopedProfilingLabel label("BatchToSpaceND"); + ScopedProfilingLabelWrapper 
label("BatchToSpaceND"); TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4); TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); const RuntimeShape input1_shape = @@ -2351,7 +2350,7 @@ inline void Slice(const tflite::SliceParams& op_params, template inline void Exp(const T* input_data, const size_t num_elements, T* output_data) { - gemmlowp::ScopedProfilingLabel label("Exp"); + ScopedProfilingLabelWrapper label("Exp"); for (size_t idx = 0; idx < num_elements; ++idx) { output_data[idx] = std::exp(input_data[idx]); } @@ -2482,7 +2481,7 @@ inline bool Mean(const T* input_data, const int* input_dims, const int* output_dims, const int output_num_dims, const int* axis, const int num_axis_dimensions, bool keep_dims, int* temp_index, int* resolved_axis, U* temp_sum) { - gemmlowp::ScopedProfilingLabel label("Mean"); + ScopedProfilingLabelWrapper label("Mean"); // Reset output data. size_t num_outputs = 1; for (int idx = 0; idx < output_num_dims; ++idx) { @@ -2536,7 +2535,7 @@ inline void Mean(const tflite::MeanParams& op_params, const RuntimeShape& unextended_input_shape, const T* input_data, const RuntimeShape& unextended_output_shape, T* output_data) { - gemmlowp::ScopedProfilingLabel label("Mean4D"); + ScopedProfilingLabelWrapper label("Mean4D"); // Current implementation only supports dimension equals 4 and simultaneous // reduction over width and height. @@ -2581,7 +2580,7 @@ inline void Mean(const tflite::MeanParams& op_params, float input_scale, const RuntimeShape& unextended_output_shape, uint8_t* output_data, int32 output_zero_point, float output_scale) { - gemmlowp::ScopedProfilingLabel label("Mean4D/Uint8"); + ScopedProfilingLabelWrapper label("Mean4D/Uint8"); // Current implementation only supports dimension equals 4 and simultaneous // reduction over width and height. @@ -2623,7 +2622,7 @@ inline void Mean(const tflite::MeanParams& op_params, acc += input_data[Offset(input_shape, out_b, in_h, in_w, out_d)]; } } - MultiplyByQuantizedMultiplier(acc, multiplier, shift); + acc = MultiplyByQuantizedMultiplier(acc, multiplier, shift); acc += bias; acc = std::min(std::max(acc, kMinValue), kMaxValue); output_data[Offset(output_shape, out_b, 0, 0, out_d)] = @@ -2647,11 +2646,9 @@ inline bool QuantizedMeanOrSum(const T* input_data, int32 input_zero_point, bool compute_sum) { const bool uint8_case = std::is_same::value; if (uint8_case) { - gemmlowp::ScopedProfilingLabel label(compute_sum ? "Sum/Uint8" - : "Mean/Uint8"); + ScopedProfilingLabelWrapper label(compute_sum ? "Sum/Uint8" : "Mean/Uint8"); } else { - gemmlowp::ScopedProfilingLabel label(compute_sum ? "Sum/Int8" - : "Mean/Int8"); + ScopedProfilingLabelWrapper label(compute_sum ? "Sum/Int8" : "Mean/Int8"); } // Reset output data. 
size_t num_outputs = 1; @@ -3248,7 +3245,7 @@ template void Reverse(int axis, const RuntimeShape& input_shape, const Scalar* input_data, const RuntimeShape& output_shape, Scalar* output_data) { - gemmlowp::ScopedProfilingLabel label("Reverse"); + ScopedProfilingLabelWrapper label("Reverse"); int outer_size = 1; for (int i = 0; i < axis; ++i) { @@ -3276,7 +3273,7 @@ void ReverseSequence(const TS* seq_lengths, const int seq_dim, const int batch_dim, const RuntimeShape& input_shape, const Scalar* input_data, const RuntimeShape& output_shape, Scalar* output_data) { - gemmlowp::ScopedProfilingLabel label("ReverseSequence"); + ScopedProfilingLabelWrapper label("ReverseSequence"); int outer_size = 1; int outer_dim = std::min(batch_dim, seq_dim); @@ -3353,7 +3350,7 @@ void ReverseSequence(const TS* seq_lengths, const int seq_dim, template inline void HardSwish(const RuntimeShape& input_shape, const T* input_data, const RuntimeShape& output_shape, T* output_data) { - gemmlowp::ScopedProfilingLabel label("ReferenceHardSwish/Float"); + ScopedProfilingLabelWrapper label("ReferenceHardSwish/Float"); auto matching_size = MatchingFlatSize(input_shape, output_shape); const T* in_end = input_data + matching_size; for (; input_data < in_end; input_data++, output_data++) { @@ -3387,7 +3384,7 @@ template inline void HardSwish(const HardSwishParams& params, const RuntimeShape& input_shape, const T* input_data, const RuntimeShape& output_shape, T* output_data) { - gemmlowp::ScopedProfilingLabel label("ReferenceHardSwish/Quantized"); + ScopedProfilingLabelWrapper label("ReferenceHardSwish/Quantized"); const int flat_size = MatchingFlatSize(input_shape, output_shape); diff --git a/tensorflow/lite/kernels/internal/tensor_utils.h b/tensorflow/lite/kernels/internal/tensor_utils.h index 7121403532a..fccd058bea5 100644 --- a/tensorflow/lite/kernels/internal/tensor_utils.h +++ b/tensorflow/lite/kernels/internal/tensor_utils.h @@ -16,7 +16,9 @@ limitations under the License. #define TENSORFLOW_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_ #include +#include +#include "third_party/eigen3/Eigen/Core" #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/kernels/cpu_backend_context.h" @@ -401,12 +403,78 @@ void VectorBatchVectorAssign(const T* vector, int v_size, int n_batch, } } -// Apply sigmoid to elements of a vector. -void ApplySigmoidToVector(const float* vector, int v_size, float* result); +// Apply Rectified Linear to elements of a vector. +inline void ApplyReluToVector(const float* __restrict__ vector, int v_size, + float* __restrict__ result) { + for (int v = 0; v < v_size; v++) { + result[v] = std::max(0.0f, vector[v]); + } +} -// Apply activation function to elements of a vector. 
-void ApplyActivationToVector(const float* vector, int v_size, - TfLiteFusedActivation activation, float* result); +// Apply Rectified Linear 1 (cap to [-1;1]) to elements of a vector +inline void ApplyRelu1ToVector(const float* __restrict__ vector, int v_size, + float* __restrict__ result) { + for (int v = 0; v < v_size; v++) { + result[v] = std::max(-1.0f, std::min(vector[v], 1.0f)); + } +} + +// Apply Rectified Linear 6 (cap to [0;6]) to elements of a vector +inline void ApplyRelu6ToVector(const float* __restrict__ vector, int v_size, + float* __restrict__ result) { + for (int v = 0; v < v_size; v++) { + result[v] = std::max(0.0f, std::min(vector[v], 6.0f)); + } +} + +// Apply tanh to elements of a vector +inline void ApplyTanhToVector(const float* __restrict__ vector, int v_size, + float* __restrict__ result) { + using VectorMap = Eigen::Map>; + VectorMap input_map(const_cast(vector), v_size); + VectorMap output_map(result, v_size); + output_map.array() = input_map.array().tanh(); +} + +// Apply signbit to elements of a vector +inline void ApplySignbitToVector(const float* __restrict__ vector, int v_size, + float* __restrict__ result) { + for (int v = 0; v < v_size; v++) { + result[v] = std::signbit(vector[v]); + } +} + +// Apply sigmoid to elements of a vector. +inline void ApplySigmoidToVector(const float* __restrict__ vector, int v_size, + float* __restrict__ result) { + using VectorMap = Eigen::Map>; + VectorMap input_map(const_cast(vector), v_size); + VectorMap output_map(result, v_size); + output_map.array() = input_map.array().logistic(); +} + +// Apply appropriate activation function to elements of a vector. +inline void ApplyActivationToVector(const float* __restrict__ vector, + int v_size, + TfLiteFusedActivation activation, + float* __restrict__ result) { + switch (activation) { + case kTfLiteActNone: + return; + case kTfLiteActRelu: + return ApplyReluToVector(vector, v_size, result); + case kTfLiteActRelu1: + return ApplyRelu1ToVector(vector, v_size, result); + case kTfLiteActRelu6: + return ApplyRelu6ToVector(vector, v_size, result); + case kTfLiteActTanh: + return ApplyTanhToVector(vector, v_size, result); + case kTfLiteActSignBit: + return ApplySignbitToVector(vector, v_size, result); + case kTfLiteActSigmoid: + return ApplySigmoidToVector(vector, v_size, result); + } +} // Compute "1.0f - elements of vector" (used in CIFG). void Sub1Vector(const float* vector, int v_size, float* result); diff --git a/tensorflow/lite/kernels/lstm_eval.cc b/tensorflow/lite/kernels/lstm_eval.cc index 6773b691cfd..ba631a6ee24 100644 --- a/tensorflow/lite/kernels/lstm_eval.cc +++ b/tensorflow/lite/kernels/lstm_eval.cc @@ -24,7 +24,6 @@ limitations under the License. 
#include "profiling/profiler.h" #endif -#include "third_party/eigen3/Eigen/Core" #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/kernel_utils.h" @@ -363,28 +362,6 @@ inline void LstmStepWithAuxInput( } } -void ApplyActivationsToVector(float* input, int input_size, - TfLiteFusedActivation activation_type, - float* output) { - using VectorMap = Eigen::Map>; - VectorMap input_map(input, input_size, 1); - VectorMap output_map(output, input_size, 1); - switch (activation_type) { - case kTfLiteActSigmoid: { - output_map.array() = input_map.array().logistic(); - break; - } - case kTfLiteActTanh: { - output_map.array() = input_map.array().tanh(); - break; - } - default: { - tensor_utils::ApplyActivationToVector(input, input_size, activation_type, - output); - } - } -} - // Same as above but with quantized weight matrices. In detail: // Input of size 'n_batch * n_input': // input_ptr_batch @@ -699,8 +676,8 @@ inline void LstmStepWithAuxInput( tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch, input_gate_scratch); } - ApplyActivationsToVector(input_gate_scratch, n_cell * n_batch, - kTfLiteActSigmoid, input_gate_scratch); + tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch, + input_gate_scratch); } // For each batch and cell: update forget gate. @@ -721,8 +698,8 @@ inline void LstmStepWithAuxInput( tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch, forget_gate_scratch); } - ApplyActivationsToVector(forget_gate_scratch, n_cell * n_batch, - kTfLiteActSigmoid, forget_gate_scratch); + tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch, + forget_gate_scratch); // For each batch and cell: update the cell. 
tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr, @@ -736,8 +713,8 @@ inline void LstmStepWithAuxInput( tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch, cell_scratch); } - ApplyActivationsToVector(cell_scratch, n_batch * n_cell, params->activation, - cell_scratch); + tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, + params->activation, cell_scratch); if (use_cifg) { tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell, forget_gate_scratch); @@ -772,10 +749,10 @@ inline void LstmStepWithAuxInput( tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch, output_gate_scratch); } - ApplyActivationsToVector(output_gate_scratch, n_batch * n_cell, - kTfLiteActSigmoid, output_gate_scratch); - ApplyActivationsToVector(cell_state_ptr, n_batch * n_cell, params->activation, - cell_scratch); + tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell, + output_gate_scratch); + tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell, + params->activation, cell_scratch); tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch, n_batch * n_cell, output_gate_scratch); diff --git a/tensorflow/lite/kernels/lstm_test.cc b/tensorflow/lite/kernels/lstm_test.cc index d6a5f9a23cc..ac2f28cf278 100644 --- a/tensorflow/lite/kernels/lstm_test.cc +++ b/tensorflow/lite/kernels/lstm_test.cc @@ -2079,6 +2079,480 @@ TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest, VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm); } +class LSTMIntegerOpModel : public SingleOpModel { + public: + LSTMIntegerOpModel(int n_batch, int n_input, int n_cell, int n_output, + bool use_cifg, bool use_peephole, + bool use_projection_weights, bool use_projection_bias, + bool use_layer_norm, float cell_clip, float proj_clip, + const std::vector>& input_shapes, + const std::vector>& ranges, + const std::vector>& intermediates) + : n_batch_(n_batch), + n_input_(n_input), + n_cell_(n_cell), + n_output_(n_output) { + EXPECT_EQ(input_shapes.size() + 1, ranges.size()); + EXPECT_EQ(intermediates.size(), 5); + input_ = AddInput( + {TensorType_INT8, input_shapes[0], ranges[0].first, ranges[0].second}); + + if (use_cifg) { + input_to_input_weights_ = AddNullInput(); + } else { + input_to_input_weights_ = AddInput({TensorType_INT8, input_shapes[1], + ranges[1].first, ranges[1].second}); + } + input_to_forget_weights_ = AddInput( + {TensorType_INT8, input_shapes[2], ranges[2].first, ranges[2].second}); + input_to_cell_weights_ = AddInput( + {TensorType_INT8, input_shapes[3], ranges[3].first, ranges[3].second}); + input_to_output_weights_ = AddInput( + {TensorType_INT8, input_shapes[4], ranges[4].first, ranges[4].second}); + + if (use_cifg) { + recurrent_to_input_weights_ = AddNullInput(); + } else { + recurrent_to_input_weights_ = + AddInput({TensorType_INT8, input_shapes[5], ranges[5].first, + ranges[5].second}); + } + recurrent_to_forget_weights_ = AddInput( + {TensorType_INT8, input_shapes[6], ranges[6].first, ranges[6].second}); + recurrent_to_cell_weights_ = AddInput( + {TensorType_INT8, input_shapes[7], ranges[7].first, ranges[7].second}); + recurrent_to_output_weights_ = AddInput( + {TensorType_INT8, input_shapes[8], ranges[8].first, ranges[8].second}); + + if (use_peephole) { + if (use_cifg) { + cell_to_input_weights_ = AddNullInput(); + } else { + cell_to_input_weights_ = AddInput({TensorType_INT16, input_shapes[9], + ranges[9].first, ranges[9].second}); + } + cell_to_forget_weights_ = 
AddInput({TensorType_INT16, input_shapes[10], + ranges[10].first, ranges[10].second}); + cell_to_output_weights_ = AddInput({TensorType_INT8, input_shapes[11], + ranges[11].first, ranges[11].second}); + } else { + cell_to_input_weights_ = AddNullInput(); + cell_to_forget_weights_ = AddNullInput(); + cell_to_output_weights_ = AddNullInput(); + } + + if (use_cifg) { + input_gate_bias_ = AddNullInput(); + } else { + input_gate_bias_ = AddInput({TensorType_INT32, input_shapes[12], + ranges[12].first, ranges[12].second}); + } + forget_gate_bias_ = AddInput({TensorType_INT32, input_shapes[13], + ranges[13].first, ranges[13].second}); + cell_bias_ = AddInput({TensorType_INT32, input_shapes[14], ranges[14].first, + ranges[14].second}); + output_gate_bias_ = AddInput({TensorType_INT32, input_shapes[15], + ranges[15].first, ranges[15].second}); + + if (use_projection_weights) { + projection_weights_ = AddInput({TensorType_INT8, input_shapes[16], + ranges[16].first, ranges[16].second}); + if (use_projection_bias) { + projection_bias_ = AddInput({TensorType_INT32, input_shapes[17], + ranges[17].first, ranges[17].second}); + } else { + projection_bias_ = AddNullInput(); + } + } else { + projection_weights_ = AddNullInput(); + projection_bias_ = AddNullInput(); + } + + // Adding the 2 input state tensors. + input_activation_state_ = AddInput({TensorType_INT16, input_shapes[18], + ranges[18].first, ranges[18].second}, + true); + input_cell_state_ = AddInput({TensorType_INT16, input_shapes[19], + ranges[19].first, ranges[19].second}, + true); + + // Layer norm weights. + if (use_layer_norm) { + if (use_cifg) { + input_layer_norm_coefficients_ = AddNullInput(); + } else { + input_layer_norm_coefficients_ = + AddInput({TensorType_INT16, input_shapes[20], ranges[20].first, + ranges[20].second}); + } + forget_layer_norm_coefficients_ = + AddInput({TensorType_INT16, input_shapes[21], ranges[21].first, + ranges[21].second}); + cell_layer_norm_coefficients_ = + AddInput({TensorType_INT16, input_shapes[22], ranges[22].first, + ranges[22].second}); + output_layer_norm_coefficients_ = + AddInput({TensorType_INT16, input_shapes[23], ranges[23].first, + ranges[23].second}); + } + + for (int i = 0; i < intermediates.size(); ++i) { + intermediates_[i] = + AddIntermediate(TensorType_INT16, {intermediates[i].first}, + {intermediates[i].second}); + } + + output_ = AddOutput({TensorType_INT8, + {n_batch, n_output}, + ranges[24].first, + ranges[24].second}); + + SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions, + CreateLSTMOptions(builder_, ActivationFunctionType_TANH, + cell_clip, proj_clip) + .Union()); + + // Do not apply delegate yet since tensor values are not known (and more + // specifically scales in quantized tensors are not known). 
+ BuildInterpreter(input_shapes, /*allow_fp32_relax_to_fp16=*/false, + /*apply_delegate=*/false); + } + + void SetInputToInputWeights(const std::vector& f) { + QuantizeAndPopulate(input_to_input_weights_, f); + } + + void SetInputToForgetWeights(const std::vector& f) { + QuantizeAndPopulate(input_to_forget_weights_, f); + } + + void SetInputToCellWeights(const std::vector& f) { + QuantizeAndPopulate(input_to_cell_weights_, f); + } + + void SetInputToOutputWeights(const std::vector& f) { + QuantizeAndPopulate(input_to_output_weights_, f); + } + + void SetRecurrentToInputWeights(const std::vector& f) { + QuantizeAndPopulate(recurrent_to_input_weights_, f); + } + + void SetRecurrentToForgetWeights(const std::vector& f) { + QuantizeAndPopulate(recurrent_to_forget_weights_, f); + } + + void SetRecurrentToCellWeights(const std::vector& f) { + QuantizeAndPopulate(recurrent_to_cell_weights_, f); + } + + void SetRecurrentToOutputWeights(const std::vector& f) { + QuantizeAndPopulate(recurrent_to_output_weights_, f); + } + + void SetCellToInputWeights(const std::vector& f) { + QuantizeAndPopulate(cell_to_input_weights_, f); + } + + void SetCellToForgetWeights(const std::vector& f) { + QuantizeAndPopulate(cell_to_forget_weights_, f); + } + + void SetCellToOutputWeights(const std::vector& f) { + QuantizeAndPopulate(cell_to_output_weights_, f); + } + + void SetInputLayerNormCoefficients(const std::vector& f) { + QuantizeAndPopulate(input_layer_norm_coefficients_, f); + } + + void SetForgetLayerNormCoefficients(const std::vector& f) { + QuantizeAndPopulate(forget_layer_norm_coefficients_, f); + } + + void SetCellLayerNormCoefficients(const std::vector& f) { + QuantizeAndPopulate(cell_layer_norm_coefficients_, f); + } + + void SetOutputLayerNormCoefficients(const std::vector& f) { + QuantizeAndPopulate(output_layer_norm_coefficients_, f); + } + + void SetInputGateBias(const std::vector& f) { + QuantizeAndPopulate(input_gate_bias_, f); + } + + void SetForgetGateBias(const std::vector& f) { + QuantizeAndPopulate(forget_gate_bias_, f); + } + + void SetCellBias(const std::vector& f) { + QuantizeAndPopulate(cell_bias_, f); + } + + void SetOutputGateBias(const std::vector& f) { + QuantizeAndPopulate(output_gate_bias_, f); + } + + void SetProjectionWeights(const std::vector& f) { + QuantizeAndPopulate(projection_weights_, f); + } + + void SetProjectionBias(const std::vector& f) { + QuantizeAndPopulate(projection_bias_, f); + } + + void SetInput(const std::vector& f) { + QuantizeAndPopulate(input_, f); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + int num_inputs() { return n_input_; } + int num_outputs() { return n_output_; } + int num_cells() { return n_cell_; } + int num_batches() { return n_batch_; } + + protected: + int input_; + int input_to_input_weights_; + int input_to_forget_weights_; + int input_to_cell_weights_; + int input_to_output_weights_; + + int recurrent_to_input_weights_; + int recurrent_to_forget_weights_; + int recurrent_to_cell_weights_; + int recurrent_to_output_weights_; + + int cell_to_input_weights_; + int cell_to_forget_weights_; + int cell_to_output_weights_; + + int input_layer_norm_coefficients_; + int forget_layer_norm_coefficients_; + int cell_layer_norm_coefficients_; + int output_layer_norm_coefficients_; + + int input_gate_bias_; + int forget_gate_bias_; + int cell_bias_; + int output_gate_bias_; + + int projection_weights_; + int projection_bias_; + int input_activation_state_; + int input_cell_state_; + + int intermediates_[5]; + + int output_; + int 
output_state_; + int cell_state_; + + int n_batch_; + int n_input_; + int n_cell_; + int n_output_; +}; + +TEST(LSTMIntegerOpModel, NoCifgYesLayerNormNoYesProjectionNoPeephole) { + // Hyper parameters. + const int n_batch = 2; + const int n_input = 5; + const int n_cell = 4; + const int n_output = 3; + const float cell_clip = 0.0; + const float proj_clip = 0.0; + + // Model related weights. + const std::vector input_to_input_weights = { + 0.5, 0.6, 0.7, -0.8, -0.9, 0.1, 0.2, 0.3, -0.4, 0.5, + -0.8, 0.7, -0.6, 0.5, -0.4, -0.5, -0.4, -0.3, -0.2, -0.1}; + + const std::vector input_to_forget_weights = { + -0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2, -0.4, 0.3, -0.8, + -0.4, 0.3, -0.5, -0.4, -0.6, 0.3, -0.4, -0.6, -0.5, -0.5}; + + const std::vector input_to_cell_weights = { + -0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2, -0.3, -0.2, -0.6, + 0.6, -0.1, -0.4, -0.3, -0.7, 0.7, -0.9, -0.5, 0.8, 0.6}; + + const std::vector input_to_output_weights = { + -0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3, -0.3, -0.8, -0.2, + 0.6, -0.2, 0.4, -0.7, -0.3, -0.5, 0.1, 0.5, -0.6, -0.4}; + + const std::vector input_gate_bias = {0.03, 0.15, 0.22, 0.38}; + + const std::vector forget_gate_bias = {0.1, -0.3, -0.2, 0.1}; + + const std::vector cell_gate_bias = {-0.05, 0.72, 0.25, 0.08}; + + const std::vector output_gate_bias = {0.05, -0.01, 0.2, 0.1}; + + const std::vector recurrent_to_input_weights = { + -0.2, -0.3, 0.4, 0.1, -0.5, 0.9, -0.2, -0.3, -0.7, 0.05, -0.2, -0.6}; + + const std::vector recurrent_to_cell_weights = { + -0.3, 0.2, 0.1, -0.3, 0.8, -0.08, -0.2, 0.3, 0.8, -0.6, -0.1, 0.2}; + + const std::vector recurrent_to_forget_weights = { + -0.5, -0.3, -0.5, -0.2, 0.6, 0.4, 0.9, 0.3, -0.1, 0.2, 0.5, 0.2}; + + const std::vector recurrent_to_output_weights = { + 0.3, -0.1, 0.1, -0.2, -0.5, -0.7, -0.2, -0.6, -0.1, -0.4, -0.7, -0.2}; + + const std::vector input_layer_norm_coefficients = {0.1, 0.2, 0.3, 0.5}; + const std::vector forget_layer_norm_coefficients = {0.2, 0.2, 0.4, + 0.3}; + const std::vector cell_layer_norm_coefficients = {0.7, 0.2, 0.3, 0.8}; + const std::vector output_layer_norm_coefficients = {0.6, 0.2, 0.2, + 0.5}; + + const std::vector projection_weights = { + -0.1, 0.2, 0.01, -0.2, 0.1, 0.5, 0.3, 0.08, 0.07, 0.2, -0.4, 0.2}; + + // Input shapes. + const std::vector> inputs = { + {n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {0}, // cell_to_forget_weight tensor + {0}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {n_output, n_cell}, // projection_weight tensor + {0}, // projection_bias tensor + + {n_batch, n_output}, // activation_state tensor + {n_batch, n_cell}, // cell_state tensor + + {n_cell}, // input_layer_norm_coefficient tensor + {n_cell}, // forget_layer_norm_coefficient tensor + {n_cell}, // cell_layer_norm_coefficient tensor + {n_cell}, // output_layer_norm_coefficient tensor + }; + + // Input ranges. 
+ const std::vector> ranges = { + {-1.0, 127.0 / 128}, // input tensor + {-1.0, 0.9}, // input_to_input_weight tensor + {-1.0, 1.0}, // input_to_forget_weight tensor + {-1.0, 1.0}, // input_to_cell_weight tensor + {-1.0, 0.8}, // input_to_output_weight tensor + + {-0.8, 1.0}, // recurrent_to_input_weight tensor + {-0.8, 0.9}, // recurrent_to_forget_weight tensor + {-0.8, 1.0}, // recurrent_to_cell_weight tensor + {-1.0, 1.0}, // recurrent_to_output_weight tensor + + {-1, 1}, // cell_to_input_weight tensor + {-1, 1}, // cell_to_forget_weight tensor + {-1, 1}, // cell_to_output_weight tensor + + {-100, 100}, // input_gate_bias tensor + {-100, 80}, // forget_gate_bias tensor + {-100, 100}, // cell_bias tensor + {-100, 100}, // output_gate_bias tensor + + {-0.5, 0.5}, // projection_weight tensor + {-1, 1}, // projection_bias tensor + + {-1.0, 32767.0 / 32768}, // activation_state tensor + {-1, 1}, // cell_state tensor + + {0, 0.5}, // input_layer_norm_coefficient tensor + {0, 0.5}, // forget_layer_norm_coefficient tensor + {0, 1.0}, // cell_layer_norm_coefficient tensor + {0, 1.0}, // output_layer_norm_coefficient tensor + // Output scale is the same as input activation scale and only activation + // scale is used in the op, so this is only provided for clarity. + {-1.0, 32767.0 / 32768}, // output tensor. + }; + + // The scale and zero point of intermediate tensors. + std::vector> intermediates = { + {0.007059, 0}, {0.007812, 0}, {0.007059, 0}, {0.007812, 0}, {0.007, 0}}; + + // Create model. + LSTMIntegerOpModel lstm(n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/false, + /*use_projection_weights=*/true, + /*use_projection_bias=*/false, + /*use_layer_norm=*/true, cell_clip, proj_clip, inputs, + ranges, intermediates); + + // Set weights. + lstm.SetInputToInputWeights(input_to_input_weights); + lstm.SetInputToCellWeights(input_to_cell_weights); + lstm.SetInputToForgetWeights(input_to_forget_weights); + lstm.SetInputToOutputWeights(input_to_output_weights); + + lstm.SetInputGateBias(input_gate_bias); + lstm.SetCellBias(cell_gate_bias); + lstm.SetForgetGateBias(forget_gate_bias); + lstm.SetOutputGateBias(output_gate_bias); + + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights); + + lstm.SetProjectionWeights(projection_weights); + + lstm.SetInputLayerNormCoefficients(input_layer_norm_coefficients); + lstm.SetForgetLayerNormCoefficients(forget_layer_norm_coefficients); + lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients); + lstm.SetOutputLayerNormCoefficients(output_layer_norm_coefficients); + + // Model inputs. sequence -batch - input + const std::vector> lstm_input = { + { + 0.7, 0.8, 0.1, 0.2, 0.3, // + 0.8, 0.1, 0.2, 0.4, 0.5, // + }, + { + 0.2, 0.7, 0.7, 0.1, 0.7, // + 0.3, 0.2, 0.9, 0.8, 0.1, // + }, + { + 0.7, 0.8, 0.1, 0.2, 0.3, // + 0.3, 0.2, 0.9, 0.8, 0.1, // + }, + }; + + // Expected outputs. + const std::vector> expected_output = { + {107, 127, 127, -41, 127, 127}, + {53, 127, 127, 22, 127, 127}, + {90, 127, 127, 34, 127, 127}, + }; + + // Invoke and verify the result. 
+ const int input_sequence_size = lstm_input.size(); + EXPECT_GT(input_sequence_size, 0); + for (int i = 0; i < input_sequence_size; ++i) { + lstm.SetInput(lstm_input[i]); + lstm.Invoke(); + const auto x = lstm.GetOutput(); + EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(expected_output[i])); + } +} + #ifdef GTEST_HAS_DEATH_TEST TEST(LSTMOpModel, InvalidTypeTest) { const int n_batch = 1; diff --git a/tensorflow/lite/kernels/reduce.cc b/tensorflow/lite/kernels/reduce.cc index 5685d4c4ff9..7c412334ab1 100644 --- a/tensorflow/lite/kernels/reduce.cc +++ b/tensorflow/lite/kernels/reduce.cc @@ -445,9 +445,25 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) { case kTfLiteUInt8: { // TODO(b/139102329): Handle all the cases in the combined reference // method. - if (op_context.input->params.zero_point == - op_context.output->params.zero_point && - op_context.input->params.scale == op_context.output->params.scale) { + tflite::MeanParams op_params; + op_params.axis_count = num_axis; + ResolveAxis(GetTensorData(op_context.axis), num_axis, &op_params); + if (op_context.params->keep_dims && + NumDimensions(op_context.input) == 4 && op_params.axis_count == 2 && + ((op_params.axis[0] == 1 && op_params.axis[1] == 2) || + (op_params.axis[0] == 2 && op_params.axis[1] == 1))) { + reference_ops::Mean(op_params, GetTensorShape(op_context.input), + GetTensorData(op_context.input), + op_context.input->params.zero_point, + op_context.input->params.scale, + GetTensorShape(op_context.output), + GetTensorData(op_context.output), + op_context.output->params.zero_point, + op_context.output->params.scale); + } else if (op_context.input->params.zero_point == + op_context.output->params.zero_point && + op_context.input->params.scale == + op_context.output->params.scale) { TF_LITE_ENSURE( context, reference_ops::Mean( @@ -726,6 +742,7 @@ TfLiteRegistration* Register_MEAN() { return Register_MEAN_REF(); #endif } + TfLiteRegistration* Register_SUM() { return Register_SUM_REF(); } TfLiteRegistration* Register_REDUCE_PROD() { return Register_REDUCE_PROD_REF(); diff --git a/tensorflow/lite/kernels/reduce_test.cc b/tensorflow/lite/kernels/reduce_test.cc index 87c178fc673..2bcfedaaf9f 100644 --- a/tensorflow/lite/kernels/reduce_test.cc +++ b/tensorflow/lite/kernels/reduce_test.cc @@ -287,18 +287,21 @@ TEST(ConstFloatMeanOpTest, KeepDims4DMeanUInt8) { TEST(ConstFloatMeanOpTest, KeepDims4DMeanLargeDepthUInt8) { float kQuantizedTolerance = GetTolerance(-5.0, 5.0); - std::vector data = {0.1, 0.2, 0.3, 0.4, 0.2, 0.3, 0.4, 0.5, 0.1, - 0.1, 0.1, 0.1, 0.4, 0.2, 0.2, 0.2, 0.9, 0.9, - 0.9, 0.9, 0.2, 0.3, 0.7, 0.7, 0.1, 0.1, 0.3, - 0.3, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4}; - MeanOpConstModel m({TensorType_UINT8, {1, 2, 2, 9}, -1.0, 1.0}, + std::vector data = { + 0.1, 0.2, 0.3, 0.4, 0.2, 0.3, 0.4, 0.5, 0.1, 0.1, 0.1, 0.1, 0.4, 0.2, 0.2, + 0.2, 0.9, 0.9, 0.9, 0.9, 0.2, 0.3, 0.7, 0.7, 0.1, 0.1, 0.3, 0.3, 0.1, 0.2, + 0.3, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4, 0.2, 0.3, 0.4, 0.5, 0.1, + 0.1, 0.1, 0.1, 0.4, 0.2, 0.2, 0.2, 0.9, 0.9, 0.9, 0.9, 0.2, 0.3, 0.7, 0.7, + 0.1, 0.1, 0.3, 0.3, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4}; + MeanOpConstModel m({TensorType_UINT8, {1, 2, 2, 18}, -1.0, 1.0}, {TensorType_UINT8, {2}, -1.0, 1.0}, {2}, {1, 2}, true); m.QuantizeAndPopulate(m.Input(), data); m.Invoke(); - EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 9})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 18})); EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( - {0.35, 0.325, 0.2, 
0.35, 0.375, 0.325, 0.225, 0.45, 0.425}, + {0.5, 0.55, 0.25, 0.35, 0.45, 0.5, 0.25, 0.3, 0.2, 0.2, 0.1, + 0.15, 0.35, 0.3, 0.15, 0.2, 0.6, 0.65}, kQuantizedTolerance))); } @@ -455,6 +458,26 @@ TEST(ConstInt8MeanOpTest, QuantizedDifferentScale) { kQuantizedTolerance))); } +TEST(ConstFloatMeanOpTest, KeepDims4DMeanLargeDepthInt8) { + float kQuantizedTolerance = GetTolerance(-5.0, 5.0); + std::vector data = { + 0.1, 0.2, 0.3, 0.4, 0.2, 0.3, 0.4, 0.5, 0.1, 0.1, 0.1, 0.1, 0.4, 0.2, 0.2, + 0.2, 0.9, 0.9, 0.9, 0.9, 0.2, 0.3, 0.7, 0.7, 0.1, 0.1, 0.3, 0.3, 0.1, 0.2, + 0.3, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4, 0.2, 0.3, 0.4, 0.5, 0.1, + 0.1, 0.1, 0.1, 0.4, 0.2, 0.2, 0.2, 0.9, 0.9, 0.9, 0.9, 0.2, 0.3, 0.7, 0.7, + 0.1, 0.1, 0.3, 0.3, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4}; + MeanOpConstModel m({TensorType_INT8, {1, 2, 2, 18}, -1.0, 1.0}, + {TensorType_INT8, {2}, -1.0, 1.0}, {2}, {1, 2}, true); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 18})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + {0.5, 0.55, 0.25, 0.35, 0.45, 0.5, 0.25, 0.3, 0.2, 0.2, 0.1, + 0.15, 0.35, 0.3, 0.15, 0.2, 0.6, 0.65}, + kQuantizedTolerance))); +} + TEST(DynamicUint8MeanOpTest, NotKeepDims) { float kQuantizedTolerance = GetTolerance(-5.0, 2.0); std::vector data = {1.3, -4.8, -3.6, 0.24}; diff --git a/tensorflow/lite/kernels/test_util.cc b/tensorflow/lite/kernels/test_util.cc index 12cde4cc9d1..67cd514e1e8 100644 --- a/tensorflow/lite/kernels/test_util.cc +++ b/tensorflow/lite/kernels/test_util.cc @@ -88,6 +88,24 @@ int SingleOpModel::AddInput(const TensorData& t, bool is_variable) { return id; } +int SingleOpModel::AddIntermediate(TensorType type, + const std::vector& scale, + const std::vector& zero_point) { + // Currently supports only int16 intermediate types. + // TODO(jianlijianli): make use of the type. + int id = tensors_.size(); + flatbuffers::Offset q_params = + CreateQuantizationParameters(builder_, /*min=*/0, /*max=*/0, + builder_.CreateVector(scale), + builder_.CreateVector(zero_point)); + tensors_.push_back(CreateTensor(builder_, builder_.CreateVector({}), + type, + /*buffer=*/0, + /*name=*/0, q_params, false)); + intermediates_.push_back(id); + return id; +} + int SingleOpModel::AddNullInput() { int id = kTfLiteOptionalTensor; inputs_.push_back(id); @@ -108,7 +126,8 @@ void SingleOpModel::SetBuiltinOp(BuiltinOperator type, builder_, /*opcode_index=*/0, builder_.CreateVector(inputs_), builder_.CreateVector(outputs_), builtin_options_type, builtin_options, - /*custom_options=*/0, CustomOptionsFormat_FLEXBUFFERS)); + /*custom_options=*/0, CustomOptionsFormat_FLEXBUFFERS, 0, + builder_.CreateVector(intermediates_))); } void SingleOpModel::SetCustomOp( diff --git a/tensorflow/lite/kernels/test_util.h b/tensorflow/lite/kernels/test_util.h index 380d9b10e89..d9f3bc9d584 100644 --- a/tensorflow/lite/kernels/test_util.h +++ b/tensorflow/lite/kernels/test_util.h @@ -165,6 +165,9 @@ class SingleOpModel { } int AddInput(const TensorData& t, bool is_variable = false); + int AddIntermediate(TensorType type, const std::vector& scale, + const std::vector& zero_point); + // Templated version of AddConstInput(). 
template int AddConstInput(const TensorData& t, std::initializer_list data) { @@ -587,6 +590,7 @@ class SingleOpModel { std::map tensor_data_; std::vector inputs_; + std::vector intermediates_; std::vector outputs_; std::vector> tensors_; std::vector> opcodes_; diff --git a/tensorflow/lite/nnapi/BUILD b/tensorflow/lite/nnapi/BUILD index e26d9567337..0a687e83131 100644 --- a/tensorflow/lite/nnapi/BUILD +++ b/tensorflow/lite/nnapi/BUILD @@ -57,7 +57,7 @@ cc_library( "//conditions:default": ["-lrt"], }), deps = [ - "//tensorflow/lite/nnapi:nnapi_lib", + ":nnapi_lib", ], ) @@ -76,7 +76,29 @@ cc_test( name = "nnapi_implementation_test", srcs = ["nnapi_implementation_test.cc"], deps = [ - "//tensorflow/lite/nnapi:nnapi_implementation", + ":nnapi_implementation", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "nnapi_handler", + srcs = ["nnapi_handler.cc"], + hdrs = ["nnapi_handler.h"], + deps = [ + ":nnapi_implementation", + ":nnapi_lib", + "//tensorflow/core/platform:logging", + "//tensorflow/lite:framework", + ], +) + +cc_test( + name = "nnapi_handler_test", + srcs = ["nnapi_handler_test.cc"], + deps = [ + ":nnapi_handler", + ":nnapi_implementation", "@com_google_googletest//:gtest_main", ], ) diff --git a/tensorflow/lite/nnapi/nnapi_handler.cc b/tensorflow/lite/nnapi/nnapi_handler.cc new file mode 100644 index 00000000000..354ad66463c --- /dev/null +++ b/tensorflow/lite/nnapi/nnapi_handler.cc @@ -0,0 +1,44 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/nnapi/nnapi_handler.h" + +#include + +#include "tensorflow/lite/nnapi/nnapi_implementation.h" + +namespace tflite { +namespace nnapi { + +const NnApi* NnApiPassthroughInstance() { + static const NnApi orig_nnapi_copy = *NnApiImplementation(); + return &orig_nnapi_copy; +} + +// static +NnApiHandler* NnApiHandler::Instance() { + // Ensuring that the original copy of nnapi is saved before we return + // access to NnApiHandler + NnApiPassthroughInstance(); + static NnApiHandler handler{const_cast(NnApiImplementation())}; + return &handler; +} + +void NnApiHandler::Reset() { + // Restores global NNAPI to original value + *nnapi_ = *NnApiPassthroughInstance(); +} + +} // namespace nnapi +} // namespace tflite diff --git a/tensorflow/lite/nnapi/nnapi_handler.h b/tensorflow/lite/nnapi/nnapi_handler.h new file mode 100644 index 00000000000..70406ba2c6e --- /dev/null +++ b/tensorflow/lite/nnapi/nnapi_handler.h @@ -0,0 +1,197 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_NNAPI_NNAPI_HANDLER_H_ +#define TENSORFLOW_LITE_NNAPI_NNAPI_HANDLER_H_ + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/lite/nnapi/nnapi_implementation.h" + +namespace tflite { +namespace nnapi { + +// Offers an interface to alter the behaviour of the NNAPI instance. +// As for NNAPI, it is designed to be a singleton. +// It allows to change the behaviour of some of the methods with some stub +// implementation and then to reset the behavior to the original one using +// Reset(). +// +class NnApiHandler { + public: + // No destructor defined to allow this class to be used as singleton. + + // Factory method, only one instance per process/jni library. + static NnApiHandler* Instance(); + + // Makes the current object a transparent proxy again, resetting any + // applied changes to its methods. + void Reset(); + + // Using templates in the ...Returns methods because the functions need to be + // stateless and the template generated code is more readable than using a + // file-local variable in the method implementation to store the configured + // result. + + template + void GetDeviceCountReturns() { + nnapi_->ANeuralNetworks_getDeviceCount = [](uint32_t* numDevices) -> int { + *numDevices = 2; + return Value; + }; + } + + void StubGetDeviceCountWith(int(stub)(uint32_t*)) { + nnapi_->ANeuralNetworks_getDeviceCount = stub; + } + + template + void ModelCreateReturns() { + nnapi_->ANeuralNetworksModel_create = [](ANeuralNetworksModel** model) { + *model = reinterpret_cast(1); + return Value; + }; + } + + template + void AddOperandReturns() { + nnapi_->ANeuralNetworksModel_addOperand = + [](ANeuralNetworksModel* model, + const ANeuralNetworksOperandType* type) { return Value; }; + } + + template + void SetOperandValueReturns() { + nnapi_->ANeuralNetworksModel_setOperandValue = + [](ANeuralNetworksModel* model, int32_t index, const void* buffer, + size_t length) { return Value; }; + } + + template + void AddOperationReturns() { + nnapi_->ANeuralNetworksModel_addOperation = + [](ANeuralNetworksModel* model, ANeuralNetworksOperationType type, + uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount, + const uint32_t* outputs) { return Value; }; + } + + template + void IdentifyInputAndOutputsReturns() { + nnapi_->ANeuralNetworksModel_identifyInputsAndOutputs = + [](ANeuralNetworksModel* model, uint32_t inputCount, + const uint32_t* inputs, uint32_t outputCount, + const uint32_t* outputs) { return Value; }; + } + + template + void RelaxComputationFloatReturns() { + nnapi_->ANeuralNetworksModel_relaxComputationFloat32toFloat16 = + [](ANeuralNetworksModel* model, bool allow) { return Value; }; + } + + template + void ModelFinishReturns() { + nnapi_->ANeuralNetworksModel_finish = [](ANeuralNetworksModel* model) { + return Value; + }; + } + + template + void MemoryCreateFromFdReturns() { + nnapi_->ANeuralNetworksMemory_createFromFd = + [](size_t size, int protect, int fd, size_t offset, + ANeuralNetworksMemory** memory) { + *memory = reinterpret_cast(2); + return Value; + 
}; + } + + template + void CompilationCreateReturns() { + nnapi_->ANeuralNetworksCompilation_create = + [](ANeuralNetworksModel* model, + ANeuralNetworksCompilation** compilation) { + *compilation = reinterpret_cast(3); + return Value; + }; + } + + template + void CompilationFinishReturns() { + nnapi_->ANeuralNetworksCompilation_finish = + [](ANeuralNetworksCompilation* compilation) { return Value; }; + } + + template + void ExecutionCreateReturns() { + nnapi_->ANeuralNetworksExecution_create = + [](ANeuralNetworksCompilation* compilation, + ANeuralNetworksExecution** execution) { + if (compilation == nullptr) return 1; + *execution = reinterpret_cast(4); + return Value; + }; + } + template + void ExecutionSetInputFromMemoryReturns() { + nnapi_->ANeuralNetworksExecution_setInputFromMemory = + [](ANeuralNetworksExecution* execution, int32_t index, + const ANeuralNetworksOperandType* type, + const ANeuralNetworksMemory* memory, size_t offset, + size_t length) { return Value; }; + } + template + void ExecutionSetOutputFromMemoryReturns() { + nnapi_->ANeuralNetworksExecution_setOutputFromMemory = + [](ANeuralNetworksExecution* execution, int32_t index, + const ANeuralNetworksOperandType* type, + const ANeuralNetworksMemory* memory, size_t offset, + size_t length) { return Value; }; + } + + template + void ExecutionComputeReturns() { + nnapi_->ANeuralNetworksExecution_compute = + [](ANeuralNetworksExecution* execution) { return Value; }; + } + + protected: + explicit NnApiHandler(NnApi* nnapi) : nnapi_(nnapi) { DCHECK(nnapi); } + + NnApi* nnapi_; +}; + +// Returns a pointer to an unaltered instance of NNAPI. Is intended +// to be used by stub methods when wanting to pass-through to original +// implementation for example: +// +// NnApiTestUtility()->StubGetDeviceWith( +// [](uint32_t devIndex, ANeuralNetworksDevice** device) -> int { +// static int count = 0; +// if (count++ < 1) { +// NnApiPassthroughInstance()->ANeuralNetworks_getDevice( +// devIndex, device); +// } else { +// return ANEURALNETWORKS_BAD_DATA; +// } +// }); +const NnApi* NnApiPassthroughInstance(); + +// Returns an instance of NnApiProxy that can be used to alter +// the behaviour of the TFLite wide instance of NnApi. +NnApiHandler* NnApiProxyInstance(); + +} // namespace nnapi +} // namespace tflite + +#endif // TENSORFLOW_LITE_NNAPI_NNAPI_HANDLER_H_ diff --git a/tensorflow/lite/nnapi/nnapi_handler_test.cc b/tensorflow/lite/nnapi/nnapi_handler_test.cc new file mode 100644 index 00000000000..aea766ef036 --- /dev/null +++ b/tensorflow/lite/nnapi/nnapi_handler_test.cc @@ -0,0 +1,143 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/lite/nnapi/nnapi_handler.h" + +#include +#include + +#include +#include +#include "tensorflow/lite/nnapi/nnapi_implementation.h" + +namespace tflite { +namespace nnapi { + +using testing::Eq; +using testing::Ne; +using testing::NotNull; + +void ExpectEquals(const NnApi& left, const NnApi& right); + +class NnApiHandlerTest : public ::testing::Test { + protected: + ~NnApiHandlerTest() override { NnApiHandler::Instance()->Reset(); } +}; + +TEST_F(NnApiHandlerTest, ShouldAlterNnApiInstanceBehaviour) { + const NnApi* nnapi = NnApiImplementation(); + + const auto device_count_stub = [](uint32_t* device_count) -> int { + *device_count = 999; + return ANEURALNETWORKS_NO_ERROR; + }; + + NnApiHandler::Instance()->StubGetDeviceCountWith(device_count_stub); + + ASSERT_THAT(nnapi->ANeuralNetworks_getDeviceCount, NotNull()); + + uint32_t device_count = 0; + nnapi->ANeuralNetworks_getDeviceCount(&device_count); + EXPECT_THAT(device_count, Eq(999)); +} + +TEST_F(NnApiHandlerTest, ShouldRestoreNnApiToItsOriginalValueWithReset) { + NnApi nnapi_orig_copy = *NnApiImplementation(); + + auto device_count_override = [](uint32_t* device_count) -> int { + *device_count = 777; + return ANEURALNETWORKS_NO_ERROR; + }; + + NnApiHandler::Instance()->StubGetDeviceCountWith(device_count_override); + + EXPECT_THAT(nnapi_orig_copy.ANeuralNetworks_getDeviceCount, + Ne(NnApiImplementation()->ANeuralNetworks_getDeviceCount)); + + NnApiHandler::Instance()->Reset(); + + ExpectEquals(nnapi_orig_copy, *NnApiImplementation()); +} + +int (*device_count_ptr)(uint32_t*); +TEST_F(NnApiHandlerTest, ShouldSupportPassthroughCalls) { + const NnApi* nnapi = NnApiImplementation(); + device_count_ptr = nnapi->ANeuralNetworks_getDeviceCount; + + NnApiHandler::Instance()->StubGetDeviceCountWith( + [](uint32_t* device_count) -> int { + return NnApiPassthroughInstance()->ANeuralNetworks_getDeviceCount == + device_count_ptr; + }); + + uint32_t device_count = 0; + EXPECT_THAT(nnapi->ANeuralNetworks_getDeviceCount(&device_count), Eq(1)); +} + +void ExpectEquals(const NnApi& left, const NnApi& right) { +#define EXPECT_NNAPI_MEMBER_EQ(name) EXPECT_EQ(left.name, right.name) + + EXPECT_NNAPI_MEMBER_EQ(nnapi_exists); + EXPECT_NNAPI_MEMBER_EQ(android_sdk_version); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksMemory_createFromFd); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksMemory_free); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksModel_create); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksModel_free); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksModel_finish); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksModel_addOperand); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksModel_setOperandValue); + EXPECT_NNAPI_MEMBER_EQ( + ANeuralNetworksModel_setOperandSymmPerChannelQuantParams); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksModel_setOperandValueFromMemory); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksModel_addOperation); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksModel_identifyInputsAndOutputs); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksModel_relaxComputationFloat32toFloat16); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksCompilation_create); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksCompilation_free); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksCompilation_setPreference); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksCompilation_finish); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksExecution_create); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksExecution_free); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksExecution_setInput); + 
EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksExecution_setInputFromMemory); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksExecution_setOutput); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksExecution_setOutputFromMemory); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksExecution_startCompute); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksEvent_wait); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksEvent_free); + EXPECT_NNAPI_MEMBER_EQ(ASharedMemory_create); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworks_getDeviceCount); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworks_getDevice); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksDevice_getName); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksDevice_getVersion); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksDevice_getFeatureLevel); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksDevice_getType); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksModel_getSupportedOperationsForDevices); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksCompilation_createForDevices); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksCompilation_setCaching); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksExecution_compute); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksExecution_getOutputOperandRank); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksExecution_getOutputOperandDimensions); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksBurst_create); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksBurst_free); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksExecution_burstCompute); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksMemory_createFromAHardwareBuffer); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksExecution_setMeasureTiming); + EXPECT_NNAPI_MEMBER_EQ(ANeuralNetworksExecution_getDuration); + +#undef EXPECT_NNAPI_MEMBER_EQ +} + +} // namespace nnapi +} // namespace tflite diff --git a/tensorflow/lite/optional_debug_tools.cc b/tensorflow/lite/optional_debug_tools.cc index 2d33f79c1f0..4e9b7d4e0a4 100644 --- a/tensorflow/lite/optional_debug_tools.cc +++ b/tensorflow/lite/optional_debug_tools.cc @@ -115,6 +115,14 @@ void PrintInterpreterState(Interpreter* interpreter) { PrintTfLiteIntVector(node.inputs); printf(" Outputs:"); PrintTfLiteIntVector(node.outputs); + if (node.intermediates && node.intermediates->size) { + printf(" Intermediates:"); + PrintTfLiteIntVector(node.intermediates); + } + if (node.temporaries && node.temporaries->size) { + printf(" Temporaries:"); + PrintTfLiteIntVector(node.temporaries); + } } } diff --git a/tensorflow/lite/profiling/memory_info.cc b/tensorflow/lite/profiling/memory_info.cc index 9e658f65408..39f94f0d250 100644 --- a/tensorflow/lite/profiling/memory_info.cc +++ b/tensorflow/lite/profiling/memory_info.cc @@ -35,20 +35,18 @@ MemoryUsage GetMemoryUsage() { result.max_rss_kb = res.ru_maxrss; } const auto mem = mallinfo(); - result.total_allocated_bytes = mem.uordblks; + result.total_allocated_bytes = mem.arena; + result.in_use_allocated_bytes = mem.uordblks; #endif return result; } -void MemoryUsage::SummaryToStream(std::ostream* stream) const { - *stream << "memory usage: max resident set size = " << max_rss_kb / 1024.0 +void MemoryUsage::AllStatsToStream(std::ostream* stream) const { + *stream << "max resident set size = " << max_rss_kb / 1024.0 << " MB, total malloc-ed size = " - << total_allocated_bytes / 1024.0 / 1024.0 << " MB"; -} - -void MemoryUsage::ShortSummaryToStream(std::ostream* stream) const { - *stream << "max_rss_mb=" << max_rss_kb / 1024.0 - << " total_malloced_mb=" << total_allocated_bytes / 1024.0 / 1024.0; + << total_allocated_bytes / 1024.0 / 1024.0 + << " MB, in-use allocated/mmapped size = " + << in_use_allocated_bytes / 1024.0 / 1024.0 << " MB"; } } // namespace memory 
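For reference, the two glibc mallinfo fields used above measure different things: arena is the total number of non-mmapped bytes the allocator has obtained from the system, while uordblks is the number of bytes currently handed out to the program, which is why the patch records them as total_allocated_bytes and in_use_allocated_bytes respectively. A minimal usage sketch of the updated API follows; the enclosing namespace is assumed to be tflite::profiling::memory (only the innermost "memory" namespace is visible in this patch), everything else is taken from the header changed below.

  #include <iostream>
  #include <vector>

  #include "tensorflow/lite/profiling/memory_info.h"

  int main() {
    // Namespace assumed; see note above.
    const auto before = tflite::profiling::memory::GetMemoryUsage();

    // Work whose memory footprint we want to attribute.
    std::vector<float> scratch(1 << 20, 0.0f);

    const auto after = tflite::profiling::memory::GetMemoryUsage();
    // MemoryUsage defines a field-wise operator- and an operator<< that calls
    // AllStatsToStream, so this prints the growth in max RSS, arena size, and
    // in-use allocated bytes.
    std::cout << (after - before) << std::endl;
    return 0;
  }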
diff --git a/tensorflow/lite/profiling/memory_info.h b/tensorflow/lite/profiling/memory_info.h index 370ca3d8ebf..b5bc3a07cf7 100644 --- a/tensorflow/lite/profiling/memory_info.h +++ b/tensorflow/lite/profiling/memory_info.h @@ -26,21 +26,30 @@ struct MemoryUsage { static const int kValueNotSet; MemoryUsage() - : max_rss_kb(kValueNotSet), total_allocated_bytes(kValueNotSet) {} + : max_rss_kb(kValueNotSet), + total_allocated_bytes(kValueNotSet), + in_use_allocated_bytes(kValueNotSet) {} // The maximum memory size (in kilobytes) occupied by an OS process that is // held in main memory (RAM). Such memory usage information is generally // referred as resident set size (rss). This is an alias to rusage::ru_maxrss. int64_t max_rss_kb; - // Total allocated space in bytes. This is an alias to mallinfo::uordblks. + // Total non-mmapped space allocated from system in bytes. This is an alias to + // mallinfo::arena. int total_allocated_bytes; + // Total allocated (including mmapped) bytes that's in use (i.e. excluding + // those are freed). This is an alias to mallinfo::uordblks. + int in_use_allocated_bytes; + MemoryUsage operator+(MemoryUsage const& obj) const { MemoryUsage res; res.max_rss_kb = max_rss_kb + obj.max_rss_kb; res.total_allocated_bytes = total_allocated_bytes + obj.total_allocated_bytes; + res.in_use_allocated_bytes = + in_use_allocated_bytes + obj.in_use_allocated_bytes; return res; } @@ -49,15 +58,16 @@ struct MemoryUsage { res.max_rss_kb = max_rss_kb - obj.max_rss_kb; res.total_allocated_bytes = total_allocated_bytes - obj.total_allocated_bytes; + res.in_use_allocated_bytes = + in_use_allocated_bytes - obj.in_use_allocated_bytes; return res; } - void SummaryToStream(std::ostream* stream) const; - void ShortSummaryToStream(std::ostream* stream) const; + void AllStatsToStream(std::ostream* stream) const; friend std::ostream& operator<<(std::ostream& stream, const MemoryUsage& obj) { - obj.SummaryToStream(&stream); + obj.AllStatsToStream(&stream); return stream; } }; diff --git a/tensorflow/lite/profiling/memory_info_test.cc b/tensorflow/lite/profiling/memory_info_test.cc index de595a2b2f1..5a359134160 100644 --- a/tensorflow/lite/profiling/memory_info_test.cc +++ b/tensorflow/lite/profiling/memory_info_test.cc @@ -25,23 +25,28 @@ TEST(MemoryUsage, AddAndSub) { MemoryUsage mem1, mem2; mem1.max_rss_kb = 5; mem1.total_allocated_bytes = 7000; + mem1.in_use_allocated_bytes = 2000; mem2.max_rss_kb = 3; - mem2.total_allocated_bytes = 5000; + mem2.total_allocated_bytes = 7000; + mem2.in_use_allocated_bytes = 4000; const auto add_mem = mem1 + mem2; EXPECT_EQ(8, add_mem.max_rss_kb); - EXPECT_EQ(12000, add_mem.total_allocated_bytes); + EXPECT_EQ(14000, add_mem.total_allocated_bytes); + EXPECT_EQ(6000, add_mem.in_use_allocated_bytes); const auto sub_mem = mem1 - mem2; EXPECT_EQ(2, sub_mem.max_rss_kb); - EXPECT_EQ(2000, sub_mem.total_allocated_bytes); + EXPECT_EQ(0, sub_mem.total_allocated_bytes); + EXPECT_EQ(-2000, sub_mem.in_use_allocated_bytes); } TEST(MemoryUsage, GetMemoryUsage) { MemoryUsage result; EXPECT_EQ(MemoryUsage::kValueNotSet, result.max_rss_kb); EXPECT_EQ(MemoryUsage::kValueNotSet, result.total_allocated_bytes); + EXPECT_EQ(MemoryUsage::kValueNotSet, result.in_use_allocated_bytes); #ifdef __linux__ // Just allocate some space in heap so that we could meaningful memory usage diff --git a/tensorflow/lite/profiling/profile_summarizer.cc b/tensorflow/lite/profiling/profile_summarizer.cc index d69b0a697d7..b004bc2e361 100644 --- a/tensorflow/lite/profiling/profile_summarizer.cc +++ 
b/tensorflow/lite/profiling/profile_summarizer.cc @@ -27,7 +27,7 @@ namespace { struct OperatorDetails { uint32_t subgraph_index; uint32_t node_index; - std::string name; + std::string op_description; std::vector inputs; std::vector outputs; }; @@ -74,20 +74,11 @@ OperatorDetails GetOperatorDetails(const tflite::Interpreter& interpreter, auto node_reg = subgraph->node_and_registration(node_index); auto inputs = node_reg->first.inputs; auto outputs = node_reg->first.outputs; - int code = node_reg->second.builtin_code; - const char* op_name = nullptr; - if (code == tflite::BuiltinOperator_CUSTOM) { - const char* custom_name = node_reg->second.custom_name; - op_name = custom_name ? custom_name : "UnknownCustomOp"; - } else { - op_name = tflite::EnumNamesBuiltinOperator()[code]; - } const char* profiling_string = interpreter.OpProfilingString(node_reg->second, &node_reg->first); OperatorDetails details; - details.name = op_name; if (profiling_string) { - details.name += ":" + std::string(profiling_string); + details.op_description = std::string(profiling_string); } details.inputs = GetTensorNames(interpreter, inputs); details.outputs = GetTensorNames(interpreter, outputs); @@ -132,9 +123,6 @@ void ProfileSummarizer::ProcessProfiles( int64_t base_start_us = events[0]->begin_timestamp_us; int node_num = 0; - auto tag_string = [](const string& s, const string& t) { - return (t == "OpInvoke" || t == "DelegateOpInvoke") ? s : s + "/" + t; - }; // Total time will be accumulated per subgraph. std::map total_us_per_subgraph_map; @@ -154,13 +142,16 @@ void ProfileSummarizer::ProcessProfiles( const auto op_details = GetOperatorDetails(interpreter, subgraph_index, node_index); - const auto type_in_stats = tag_string(op_details.name, event->tag); + std::string type_in_stats(event->tag); + if (!op_details.op_description.empty()) { + type_in_stats += "/" + op_details.op_description; + } const auto node_name = ToString(op_details.outputs); // Append node index to node name because 'stats_calculator' can not // distinguish two nodes w/ the same 'node_name'. const auto node_name_in_stats = - tag_string(node_name + ":" + std::to_string(node_index), event->tag); + node_name + ":" + std::to_string(node_index); stats_calculator->AddNodeStats(node_name_in_stats, type_in_stats, node_num, start_us, node_exec_time, diff --git a/tensorflow/lite/profiling/profile_summarizer_test.cc b/tensorflow/lite/profiling/profile_summarizer_test.cc index 6340921bc0e..0c4b9fcd88f 100644 --- a/tensorflow/lite/profiling/profile_summarizer_test.cc +++ b/tensorflow/lite/profiling/profile_summarizer_test.cc @@ -141,7 +141,7 @@ TEST(ProfileSummarizerTest, InterpreterPlusProfilingDetails) { summarizer.ProcessProfiles(profiler.GetProfileEvents(), *interpreter); auto output = summarizer.GetOutputString(); // TODO(shashishekhar): Add a better test here. 
- ASSERT_TRUE(output.find("SimpleOpEval:Profile") != std::string::npos) + ASSERT_TRUE(output.find("SimpleOpEval/Profile") != std::string::npos) << output; } diff --git a/tensorflow/lite/profiling/profiler.h b/tensorflow/lite/profiling/profiler.h index e75c90bf6b6..ff398698616 100644 --- a/tensorflow/lite/profiling/profiler.h +++ b/tensorflow/lite/profiling/profiler.h @@ -32,6 +32,5 @@ using Profiler = NoopProfiler; } // namespace tflite #define SCOPED_TAGGED_OPERATOR_PROFILE TFLITE_SCOPED_TAGGED_OPERATOR_PROFILE -#define SCOPED_OPERATOR_PROFILE TFLITE_SCOPED_OPERATOR_PROFILE #endif // TENSORFLOW_LITE_PROFILING_PROFILER_H_ diff --git a/tensorflow/lite/profiling/profiler_test.cc b/tensorflow/lite/profiling/profiler_test.cc index 57da951c8ce..cedb109697d 100644 --- a/tensorflow/lite/profiling/profiler_test.cc +++ b/tensorflow/lite/profiling/profiler_test.cc @@ -97,13 +97,13 @@ TEST(ProfilingTest, ProfilesAreCollected) { TEST(ProfilingTest, NullProfiler) { Profiler* profiler = nullptr; - { SCOPED_OPERATOR_PROFILE(profiler, 1); } + { SCOPED_TAGGED_OPERATOR_PROFILE(profiler, "noop", 1); } } TEST(ProfilingTest, ScopedProfile) { BufferedProfiler profiler(1024); profiler.StartProfiling(); - { SCOPED_OPERATOR_PROFILE(&profiler, 1); } + { SCOPED_TAGGED_OPERATOR_PROFILE(&profiler, "noop", 1); } profiler.StopProfiling(); auto profile_events = profiler.GetProfileEvents(); EXPECT_EQ(1, profile_events.size()); @@ -112,7 +112,7 @@ TEST(ProfilingTest, ScopedProfile) { TEST(ProfilingTest, NoopProfiler) { NoopProfiler profiler; profiler.StartProfiling(); - { SCOPED_OPERATOR_PROFILE(&profiler, 1); } + { SCOPED_TAGGED_OPERATOR_PROFILE(&profiler, "noop", 1); } profiler.StopProfiling(); auto profile_events = profiler.GetProfileEvents(); EXPECT_EQ(0, profile_events.size()); diff --git a/tensorflow/lite/python/tflite_convert.py b/tensorflow/lite/python/tflite_convert.py index d66fe0bb5a9..59e43be807a 100644 --- a/tensorflow/lite/python/tflite_convert.py +++ b/tensorflow/lite/python/tflite_convert.py @@ -205,8 +205,9 @@ def _convert_tf1_model(flags): if flags.conversion_summary_dir: converter.conversion_summary_dir = flags.conversion_summary_dir - if flags.experimental_new_converter: - converter.experimental_new_converter = True + # TODO(b/145312675): Enable the new converter by default. It requires to + # add a new command line argument like `experimental_legacy_converter`. + converter.experimental_new_converter = flags.experimental_new_converter # Convert model. output_data = converter.convert() @@ -230,8 +231,9 @@ def _convert_tf2_model(flags): model = keras.models.load_model(flags.keras_model_file) converter = lite.TFLiteConverterV2.from_keras_model(model) - if flags.experimental_new_converter: - converter.experimental_new_converter = True + # TODO(b/145312675): Enable the new converter by default. It requires to + # add a new command line argument like `experimental_legacy_converter`. + converter.experimental_new_converter = flags.experimental_new_converter # Convert the model. 
tflite_model = converter.convert() diff --git a/tensorflow/lite/testing/BUILD b/tensorflow/lite/testing/BUILD index b893ee0524f..25da7cedf01 100644 --- a/tensorflow/lite/testing/BUILD +++ b/tensorflow/lite/testing/BUILD @@ -491,8 +491,10 @@ edgetpu_ops = [ "depthwiseconv", # high error "fully_connected", "l2norm", # high error + "maximum", "max_pool", "mean", + "minimum", "mul", "pad", # high error "relu", diff --git a/tensorflow/lite/testing/generate_examples_lib.py b/tensorflow/lite/testing/generate_examples_lib.py index f57e25c68b5..1d257e1f3c7 100644 --- a/tensorflow/lite/testing/generate_examples_lib.py +++ b/tensorflow/lite/testing/generate_examples_lib.py @@ -237,6 +237,7 @@ class Options(object): # test sets. # TODO(juhoha): Separate the state from the options. self.multi_gen_state = None + self.use_experimental_converter = False def _prepare_dir(options): diff --git a/tensorflow/lite/testing/generated_examples_zip_test.cc b/tensorflow/lite/testing/generated_examples_zip_test.cc index 16b4675bb0d..d1b3d267eba 100644 --- a/tensorflow/lite/testing/generated_examples_zip_test.cc +++ b/tensorflow/lite/testing/generated_examples_zip_test.cc @@ -156,14 +156,22 @@ const std::map& GetKnownBrokenNnapiTests() { const std::map& GetKnownQuantizeBrokenTests() { static const std::map* const kQuantizeBrokenTests = new std::map({ - {R"(^\/conv.*fully_quantize=True)", "134594898"}, - {R"(^\/depthwiseconv.*fully_quantize=True)", "134594898"}, {R"(^\/sum.*fully_quantize=True)", "134594898"}, {R"(^\/l2norm.*fully_quantize=True)", "134594898"}, }); return *kQuantizeBrokenTests; } +const std::map& GetQuantizeTestsError() { + static const std::map* const kQuantizeBrokenTests = + new std::map({ + {R"(^\/conv_relu1.*fully_quantize=True)", 18}, + {R"(^\/conv_relu6.*fully_quantize=True)", 8}, + {R"(^\/maximum.*fully_quantize=True)", 8}, + }); + return *kQuantizeBrokenTests; +} + // Allows test data to be unarchived into a temporary directory and makes // sure those temporary directories are removed later. class ArchiveEnvironment : public ::testing::Environment { @@ -299,10 +307,16 @@ TEST_P(OpsTest, RunZipTests) { tflite::testing::TfLiteDriver test_driver( FLAGS_use_nnapi ? TfLiteDriver::DelegateType::kNnapi : TfLiteDriver::DelegateType::kNone); + + auto quantized_tests_error = GetQuantizeTestsError(); bool fully_quantize = false; if (test_path.find("fully_quantize=True") != std::string::npos) { - // TODO(b/134594898): Tighten this constraint. - test_driver.SetThreshold(0.2, 0.1); + for (const auto& p : quantized_tests_error) { + if (RE2::PartialMatch(test_name, p.first)) { + test_driver.SetQuantizationErrorMultiplier(p.second); + break; + } + } fully_quantize = true; } @@ -313,7 +327,6 @@ TEST_P(OpsTest, RunZipTests) { auto kBrokenNnapiTests = GetKnownBrokenNnapiTests(); broken_tests.insert(kBrokenNnapiTests.begin(), kBrokenNnapiTests.end()); } - auto quantize_broken_tests = GetKnownQuantizeBrokenTests(); bool result = tflite::testing::ParseAndRunTests(&tflite_stream, &test_driver); string message = test_driver.GetErrorMessage(); @@ -346,7 +359,7 @@ TEST_P(OpsTest, RunZipTests) { if (!result) { string bug_number; // See if the tests are potential quantize failures. 
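// The per-test multipliers from GetQuantizeTestsError() above are applied via
// TfLiteDriver::SetQuantizationErrorMultiplier(); the driver's QuantizedCheck
// (later in this patch, in tflite_driver.cc) then accepts an int8 output
// element when it lies within multiplier * scale of the float reference, i.e.
// the tolerance is expressed in units of the output's quantization step. A
// minimal sketch of that comparison (all names below are local to the sketch,
// not the driver's; requires <cmath> and <cstdint>):
//
//   bool WithinQuantizationTolerance(int8_t q, float scale, int32_t zero_point,
//                                    float reference, int multiplier) {
//     const float dequantized = scale * (q - zero_point);
//     return std::abs(dequantized - reference) <= multiplier * scale;
//   }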
- for (const auto& p : quantize_broken_tests) { + for (const auto& p : GetKnownQuantizeBrokenTests()) { if (RE2::PartialMatch(test_name, p.first)) { bug_number = p.second; break; diff --git a/tensorflow/lite/testing/op_tests/binary_op.py b/tensorflow/lite/testing/op_tests/binary_op.py index c9900dc288d..88702b0542f 100644 --- a/tensorflow/lite/testing/op_tests/binary_op.py +++ b/tensorflow/lite/testing/op_tests/binary_op.py @@ -129,7 +129,10 @@ def make_binary_op_tests(options, name="input2", shape=parameters["input_shape_2"]) out = binary_operator(input1, input2) - if parameters["activation"]: + # TODO(karimnosseir): Update condition after moving to new converter. + if parameters["activation"] and (not options.use_experimental_converter or + (parameters["dtype"] != tf.int32 and + parameters["dtype"] != tf.int64)): out = tf.nn.relu(out) return [input1, input2], [out] diff --git a/tensorflow/lite/testing/op_tests/relu1.py b/tensorflow/lite/testing/op_tests/relu1.py index 21c03c89454..ac92bac1cb2 100644 --- a/tensorflow/lite/testing/op_tests/relu1.py +++ b/tensorflow/lite/testing/op_tests/relu1.py @@ -30,8 +30,7 @@ def make_relu1_tests(options): # Chose a set of parameters test_parameters = [{ - "input_shape": [[], [1, 1, 1, 1], [1, 3, 4, 3], [3, 15, 14, 3], - [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]], + "input_shape": [[], [1, 1, 1, 1], [1, 3, 4, 3], [3, 15, 14, 3]], "fully_quantize": [True, False], "input_range": [(-2, 8)] }] diff --git a/tensorflow/lite/testing/tflite_driver.cc b/tensorflow/lite/testing/tflite_driver.cc index 795fb1fee99..47293016ab6 100644 --- a/tensorflow/lite/testing/tflite_driver.cc +++ b/tensorflow/lite/testing/tflite_driver.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include #include - #include "absl/strings/escaping.h" #include "tensorflow/lite/builtin_op_data.h" #include "tensorflow/lite/delegates/flex/delegate.h" @@ -37,6 +36,22 @@ namespace { const double kRelativeThreshold = 1e-2f; const double kAbsoluteThreshold = 1e-4f; +// For quantized tests, we use a different error measurement from float ones. +// Assumes the baseline is a always a float TF model. +// Error of a quantized model compared to the baseline comes from two sources: +// 1. the math done with quantized inputs, and +// 2. quantization of the output. +// Assumes there is no error introduced by source 1, the theoretical maximum +// error allowed for the output is 0.5 * scale, because scale is equal to the +// size of the quantization bucket. +// +// As a result, we use `scale` as a unit for measuring the quantization error. +// To add the error introduced by source 1 as well, we need to relax the +// multiplier from 0.5 to a larger number, which is model/op dependent. +// The number below is good enough to account for both the two sources of error +// for most quantized op tests to pass. +const int kQuantizationErrorMultiplier = 4; + // Returns the value in the given position in a tensor. 
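To make the tolerance rule described in the new comment concrete: each int8 output element is dequantized with the tensor's scale and zero point and accepted if it falls within multiplier * scale of the float reference, which is also consistent with the quantized expectations later in this change being written in dequantized float units. A minimal sketch under the assumption of per-tensor quantization (names are illustrative):

def quantized_output_ok(int8_values, scale, zero_point, float_reference,
                        multiplier=4):
  """Accepts a quantized output if every dequantized element is close enough."""
  for q, ref in zip(int8_values, float_reference):
    dequantized = scale * (q - zero_point)
    if abs(dequantized - ref) > multiplier * scale:
      return False
  return True

# With scale=0.0117 and zero_point=0, the int8 value 1 dequantizes to 0.0117,
# so a float expectation of 0.0117 matches exactly.
print(quantized_output_ok([1], 0.0117, 0, [0.0117]))  # True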
template T Value(void* data, int index) { @@ -58,15 +73,31 @@ unique_void_ptr make_type_erased_array(size_t size) { [](void* data) { delete[] static_cast(data); }); } +bool IsQuantized(const TfLiteTensor& tensor) { + if (tensor.type != kTfLiteInt8) return false; + + if (tensor.quantization.params != nullptr) { + auto* quantization = + reinterpret_cast(tensor.quantization.params); + if (quantization->scale != nullptr && quantization->scale->size == 1 && + quantization->zero_point != nullptr && + quantization->zero_point->size == 1) { + return true; + } + } + return false; +} } // namespace class TfLiteDriver::DataExpectation { public: - DataExpectation(double relative_threshold, double absolute_threshold) + DataExpectation(double relative_threshold, double absolute_threshold, + int quantization_error_multiplier) : data_(nullptr, nullptr), num_elements_(0), relative_threshold_(relative_threshold), - absolute_threshold_(absolute_threshold) {} + absolute_threshold_(absolute_threshold), + quantization_error_multiplier_(quantization_error_multiplier) {} template void SetData(const string& csv_values) { @@ -128,11 +159,13 @@ class TfLiteDriver::DataExpectation { } bool TypedCheckString(bool verbose, const TfLiteTensor& tensor); + bool QuantizedCheck(bool verbose, const TfLiteTensor& tensor); unique_void_ptr data_; size_t num_elements_; double relative_threshold_; double absolute_threshold_; + int quantization_error_multiplier_; }; class TfLiteDriver::ShapeExpectation { @@ -218,8 +251,37 @@ bool TfLiteDriver::DataExpectation::TypedCheckString( return true; } +bool TfLiteDriver::DataExpectation::QuantizedCheck(bool verbose, + const TfLiteTensor& tensor) { + auto* quantization = + reinterpret_cast(tensor.quantization.params); + const float scale = quantization->scale->data[0]; + const int32 zero_point = quantization->zero_point->data[0]; + + bool good_result = true; + for (int i = 0; i < tensor.bytes; i++) { + const int32 computed = tensor.data.int8[i]; + const float dequantized = + static_cast(scale * (computed - zero_point)); + const float reference = Value(data_.get(), i); + if (std::abs(dequantized - reference) > + quantization_error_multiplier_ * scale) { + if (verbose) { + std::cerr << " index " << i << ": got " << dequantized + << ", but expected " << reference << std::endl; + } + good_result = false; + } + } + return good_result; +} + bool TfLiteDriver::DataExpectation::Check(bool verbose, const TfLiteTensor& tensor) { + if (IsQuantized(tensor)) { + return QuantizedCheck(verbose, tensor); + } + switch (tensor.type) { case kTfLiteFloat32: return TypedCheck(verbose, tensor); @@ -247,7 +309,8 @@ bool TfLiteDriver::DataExpectation::Check(bool verbose, TfLiteDriver::TfLiteDriver(DelegateType delegate_type, bool reference_kernel) : delegate_(nullptr, nullptr), relative_threshold_(kRelativeThreshold), - absolute_threshold_(kAbsoluteThreshold) { + absolute_threshold_(kAbsoluteThreshold), + quantization_error_multiplier_(kQuantizationErrorMultiplier) { if (reference_kernel) { resolver_.reset(new ops::builtin::BuiltinRefOpResolver); } else { @@ -395,6 +458,11 @@ void TfLiteDriver::SetThreshold(double relative_threshold, absolute_threshold_ = absolute_threshold; } +void TfLiteDriver::SetQuantizationErrorMultiplier( + int quantization_error_multiplier) { + quantization_error_multiplier_ = quantization_error_multiplier; +} + void TfLiteDriver::SetExpectation(int id, const string& csv_values) { if (!IsValid()) return; auto* tensor = interpreter_->tensor(id); @@ -402,7 +470,14 @@ void 
TfLiteDriver::SetExpectation(int id, const string& csv_values) { Invalidate(absl::StrCat("Overridden expectation for tensor '", id, "'")); } expected_output_[id].reset( - new DataExpectation(relative_threshold_, absolute_threshold_)); + new DataExpectation(relative_threshold_, absolute_threshold_, + quantization_error_multiplier_)); + + if (IsQuantized(*tensor)) { + expected_output_[id]->SetData(csv_values); + return; + } + switch (tensor->type) { case kTfLiteFloat32: expected_output_[id]->SetData(csv_values); diff --git a/tensorflow/lite/testing/tflite_driver.h b/tensorflow/lite/testing/tflite_driver.h index ae843d1cba7..258902606a5 100644 --- a/tensorflow/lite/testing/tflite_driver.h +++ b/tensorflow/lite/testing/tflite_driver.h @@ -64,6 +64,7 @@ class TfLiteDriver : public TestRunner { bool CheckResults() override; string ReadOutput(int id) override; void SetThreshold(double relative_threshold, double absolute_threshold); + void SetQuantizationErrorMultiplier(int quantization_error_multiplier); protected: Interpreter::TfLiteDelegatePtr delegate_; @@ -95,6 +96,7 @@ class TfLiteDriver : public TestRunner { std::map tensors_to_deallocate_; double relative_threshold_; double absolute_threshold_; + int quantization_error_multiplier_; }; } // namespace testing diff --git a/tensorflow/lite/testing/tflite_driver_test.cc b/tensorflow/lite/testing/tflite_driver_test.cc index 99efd2d66d1..6dac9565dde 100644 --- a/tensorflow/lite/testing/tflite_driver_test.cc +++ b/tensorflow/lite/testing/tflite_driver_test.cc @@ -112,7 +112,7 @@ TEST(TfliteDriverTest, AddQuantizedInt8Test) { runner->SetInput(1, "1,1,1,1"); - runner->SetExpectation(2, "3,3,3,3"); + runner->SetExpectation(2, "0.0117,0.0117,0.0117,0.0117"); runner->Invoke(); ASSERT_TRUE(runner->IsValid()); diff --git a/tensorflow/lite/testing/toco_convert.py b/tensorflow/lite/testing/toco_convert.py index f4072b241a0..e8d1e8eec12 100644 --- a/tensorflow/lite/testing/toco_convert.py +++ b/tensorflow/lite/testing/toco_convert.py @@ -112,9 +112,16 @@ def toco_convert(options, graph_def, input_tensors, output_tensors, **kwargs): graphdef_file.flush() input_shapes = zip_test_utils.get_input_shapes_map(input_tensors) - converter = tf.compat.v1.lite.TocoConverter.from_frozen_graph( + converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph( graphdef_file.name, input_arrays, output_tensors, input_shapes) + # TODO(b/145313371): Evaluate should we make it work with the new + # converter. + # Note: Currently this line is a non-functional change because the new + # converter is disabled by default. Since this code path doesn't work + # with new converter yet, it's explicitly disabled for easier testing. 
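For context on the converter changes here and in tflite_convert.py earlier in this diff: the experimental_new_converter attribute is now assigned unconditionally from the flag rather than only when it is True. A hedged usage sketch (the file path, tensor names, and shapes are hypothetical; the API calls are the ones already used in these files):

import tensorflow as tf

converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
    "model.pb",  # hypothetical frozen graph and tensor names
    input_arrays=["input"],
    output_arrays=["output"],
    input_shapes={"input": [1, 224, 224, 3]})
# True selects the new converter, False keeps the legacy path; assigning the
# flag value directly makes False an explicit choice rather than a no-op.
converter.experimental_new_converter = False
tflite_model = converter.convert()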
+ converter.experimental_new_converter = False + def representative_dataset(input_tensors): calibration_inputs = [] for _, shape, _ in input_tensors: @@ -139,6 +146,8 @@ def toco_convert(options, graph_def, input_tensors, output_tensors, **kwargs): if extra_toco_options.inference_output_type: converter.inference_output_type = ( extra_toco_options.inference_output_type) + else: + converter.inference_output_type = tf.int8 try: tflite_model = converter.convert() diff --git a/tensorflow/lite/toco/tflite/op_version.cc b/tensorflow/lite/toco/tflite/op_version.cc index 39258339e0e..a7a829e77e3 100644 --- a/tensorflow/lite/toco/tflite/op_version.cc +++ b/tensorflow/lite/toco/tflite/op_version.cc @@ -74,7 +74,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) { {{OperatorType::kCast, 1}, "1.5.0"}, {{OperatorType::kConcatenation, 1}, "1.5.0"}, {{OperatorType::kConcatenation, 2}, "1.14.0"}, - {{OperatorType::kDepthToSpace, 1}, kPendingReleaseOpVersion}, + {{OperatorType::kDepthToSpace, 1}, "2.1.0"}, {{OperatorType::kFakeQuant, 1}, "1.5.0"}, {{OperatorType::kFakeQuant, 2}, "1.10.0"}, {{OperatorType::kFullyConnected, 1}, "1.5.0"}, @@ -82,7 +82,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) { {{OperatorType::kFullyConnected, 3}, "1.14.0"}, {{OperatorType::kFullyConnected, 4}, "1.14.0"}, {{OperatorType::kFullyConnected, 5}, "2.0.0"}, - {{OperatorType::kFullyConnected, 6}, kPendingReleaseOpVersion}, + {{OperatorType::kFullyConnected, 6}, "2.1.0"}, {{OperatorType::kGather, 1}, "1.6.0"}, {{OperatorType::kGather, 2}, "1.14.0"}, {{OperatorType::kGather, 3}, "1.15.0"}, @@ -145,7 +145,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) { {{OperatorType::kSplitV, 1}, "1.13.1"}, {{OperatorType::kStridedSlice, 1}, "1.6.0"}, {{OperatorType::kStridedSlice, 2}, "1.14.0"}, - {{OperatorType::kStridedSlice, 3}, kPendingReleaseOpVersion}, + {{OperatorType::kStridedSlice, 3}, "2.1.0"}, {{OperatorType::kTopK_V2, 1}, "1.7.0"}, {{OperatorType::kTopK_V2, 2}, "1.14.0"}, {{OperatorType::kArgMax, 1}, "1.9.0"}, @@ -205,7 +205,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) { {{OperatorType::kElu, 1}, "1.14.0"}, {{OperatorType::kRound, 1}, "1.14.0"}, {{OperatorType::kRelu, 1}, "1.5.0"}, - {{OperatorType::kRelu, 2}, kPendingReleaseOpVersion}, + {{OperatorType::kRelu, 2}, "2.1.0"}, {{OperatorType::kRelu1, 1}, "1.5.0"}, {{OperatorType::kPRelu, 1}, "1.8.0"}, {{OperatorType::kExp, 1}, "1.7.0"}, diff --git a/tensorflow/lite/tools/benchmark/benchmark_model.cc b/tensorflow/lite/tools/benchmark/benchmark_model.cc index f7ce6d86ab3..6c3fccc5e22 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_model.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_model.cc @@ -192,8 +192,13 @@ TfLiteStatus BenchmarkModel::Run() { inference_time_us, init_mem_usage, overall_mem_usage}); - TFLITE_LOG(INFO) << "Init " << init_mem_usage << std::endl - << "Overall " << overall_mem_usage; + TFLITE_LOG(INFO) + << "Note: as the benchmark tool itself affects memory footprint, the " + "following is only APPROXIMATE to the actual memory footprint of the " + "model at runtime. 
Take the information at your discretion."; + TFLITE_LOG(INFO) << "Peak memory footprint (MB): init=" + << init_mem_usage.max_rss_kb / 1024.0 + << " overall=" << overall_mem_usage.max_rss_kb / 1024.0; return status; } diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc index 89c83520ab8..dc4a43ee6cb 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc @@ -492,13 +492,15 @@ TfLiteStatus BenchmarkTfLiteModel::PrepareInputData() { } else if (t->type == kTfLiteUInt8) { int low = has_value_range ? low_range : 0; int high = has_value_range ? high_range : 254; + // std::uniform_int_distribution is specified not to support char types. t_data = CreateInputTensorData( - num_elements, std::uniform_int_distribution(low, high)); + num_elements, std::uniform_int_distribution(low, high)); } else if (t->type == kTfLiteInt8) { int low = has_value_range ? low_range : -127; int high = has_value_range ? high_range : 127; + // std::uniform_int_distribution is specified not to support char types. t_data = CreateInputTensorData( - num_elements, std::uniform_int_distribution(low, high)); + num_elements, std::uniform_int_distribution(low, high)); } else if (t->type == kTfLiteString) { // TODO(haoliang): No need to cache string tensors right now. } else { diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h index ca7731eed33..a6fc38a6180 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h +++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h @@ -90,8 +90,9 @@ class BenchmarkTfLiteModel : public BenchmarkModel { InputTensorData tmp; tmp.bytes = sizeof(T) * num_elements; T* raw = new T[num_elements]; - std::generate_n(raw, num_elements, - [&]() { return distribution(random_engine_); }); + std::generate_n(raw, num_elements, [&]() { + return static_cast(distribution(random_engine_)); + }); // Now initialize the type-erased unique_ptr (with custom deleter) from // 'raw'. 
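Stepping back to the toco_convert.py hunk above: for fully quantized runs a representative dataset drives calibration and, unless the test overrides it, the output type now defaults to int8. A rough sketch of how such a conversion is typically wired up (hedged; the exact option combinations for the zip tests are assembled by toco_convert itself, and the paths and shapes below are made up):

import numpy as np
import tensorflow as tf


def representative_dataset():
  # Hypothetical calibration samples matching the model's single input.
  for _ in range(100):
    yield [np.random.uniform(-1, 1, size=(1, 224, 224, 3)).astype(np.float32)]


converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
    "model.pb", input_arrays=["input"], output_arrays=["output"])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.inference_output_type = tf.int8  # the default applied above when unset
tflite_model = converter.convert()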
tmp.data = std::unique_ptr( diff --git a/tensorflow/lite/tools/evaluation/utils.cc b/tensorflow/lite/tools/evaluation/utils.cc index ef48aabc399..290e7549908 100644 --- a/tensorflow/lite/tools/evaluation/utils.cc +++ b/tensorflow/lite/tools/evaluation/utils.cc @@ -121,7 +121,7 @@ Interpreter::TfLiteDelegatePtr CreateGPUDelegate( tflite::FlatBufferModel* model) { #if defined(__ANDROID__) TfLiteGpuDelegateOptionsV2 options = TfLiteGpuDelegateOptionsV2Default(); - options.is_precision_loss_allowed = 1; + options.inference_priority1 = TFLITE_GPU_INFERENCE_PRIORITY_MIN_LATENCY; options.inference_preference = TFLITE_GPU_INFERENCE_PREFERENCE_SUSTAINED_SPEED; diff --git a/tensorflow/lite/tools/optimize/BUILD b/tensorflow/lite/tools/optimize/BUILD index 9e2b92ceb0c..71c368679aa 100644 --- a/tensorflow/lite/tools/optimize/BUILD +++ b/tensorflow/lite/tools/optimize/BUILD @@ -244,6 +244,7 @@ tf_cc_test( "//tensorflow/lite/tools/optimize:testdata/split.bin", "//tensorflow/lite/tools/optimize:testdata/maximum.bin", "//tensorflow/lite/tools/optimize:testdata/minimum.bin", + "//tensorflow/lite/tools/optimize:testdata/unpack.bin", ], tags = [ "tflite_not_portable_android", diff --git a/tensorflow/lite/tools/optimize/operator_property.cc b/tensorflow/lite/tools/optimize/operator_property.cc index c2b35e4e916..6a8258fc27c 100644 --- a/tensorflow/lite/tools/optimize/operator_property.cc +++ b/tensorflow/lite/tools/optimize/operator_property.cc @@ -70,9 +70,9 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index, property.version = 2; break; case BuiltinOperator_SPLIT: - property.arbitrary_outputs = true; // We skip input 0 since it is the split dim which is not real valued. property.inputs = {{1, {}}}; + property.arbitrary_outputs = true; property.restrict_same_input_output_scale = true; property.version = 2; break; @@ -391,6 +391,12 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index, property.restrict_same_input_output_scale = true; property.version = 2; break; + case BuiltinOperator_UNPACK: + property.inputs = {{0, {}}}; + property.arbitrary_outputs = true; + property.restrict_same_input_output_scale = true; + property.version = 1; + break; default: // No quantized implementation exists for this operation. property.quantizable = false; diff --git a/tensorflow/lite/tools/optimize/quantize_model.cc b/tensorflow/lite/tools/optimize/quantize_model.cc index 304aad618d8..42db16eb965 100644 --- a/tensorflow/lite/tools/optimize/quantize_model.cc +++ b/tensorflow/lite/tools/optimize/quantize_model.cc @@ -711,6 +711,11 @@ TfLiteStatus QuantizeIntemediateTensors(ModelT* model, // Quantize tensros that have shared range. For example, in LSTM, the output // tensor and input state tensor should share the same range because they are // using the same scale and zero point. +// We have to model this explicitly because the output is modeled as an extra +// tensor in LSTM. In the calibrator, state tensors are logged both before and after +// the inference so the range is fully captured. But the output, although it is +// identical to the activation, is not a state tensor, so the input value (range) of the +// very first inference is not captured.
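In other words, "sharing a range" roughly means forcing the LSTM output tensor to carry the same scale and zero point as the state tensor it mirrors, so the uncaptured first-step output range cannot distort the quantization grid. A toy illustration of the idea with plain dictionaries standing in for tensors (this is not the ModelT API, and the real pass works on calibrated min/max statistics):

def share_range(tensor_a, tensor_b):
  """Gives both tensors a single set of int8 quantization parameters."""
  lo = min(tensor_a["min"], tensor_b["min"])
  hi = max(tensor_a["max"], tensor_b["max"])
  scale = (hi - lo) / 255.0  # 256 int8 buckets span the shared range
  zero_point = int(round(-128 - lo / scale))
  for t in (tensor_a, tensor_b):
    t.update(min=lo, max=hi, scale=scale, zero_point=zero_point)


state = {"min": -1.0, "max": 1.0}
output = {"min": -0.4, "max": 0.7}  # first-step output, narrower than the state
share_range(state, output)          # both now quantize on the same grid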
TfLiteStatus QuantizeSharedRange(ModelT* model, ErrorReporter* error_reporter) { for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size(); subgraph_idx++) { diff --git a/tensorflow/lite/tools/optimize/quantize_model_test.cc b/tensorflow/lite/tools/optimize/quantize_model_test.cc index eaf0cfbd694..95c71fe0861 100644 --- a/tensorflow/lite/tools/optimize/quantize_model_test.cc +++ b/tensorflow/lite/tools/optimize/quantize_model_test.cc @@ -1206,10 +1206,61 @@ TEST_P(QuantizeMinimumMaximumTest, VerifyMinimumMaximum) { EXPECT_EQ(subgraph->tensors[5]->name, "input"); EXPECT_EQ(subgraph->tensors[6]->name, "output"); } + INSTANTIATE_TEST_SUITE_P(MinimumMaximumTestInst, QuantizeMinimumMaximumTest, testing::ValuesIn({internal::kModelWithMinimumOp, internal::kModelWithMaximumOp})); +class QuantizeUnpackTest : public QuantizeModelTest { + protected: + QuantizeUnpackTest() { + input_model_ = ReadModel(internal::kModelWithUnpack); + + readonly_model_ = input_model_->GetModel(); + readonly_model_->UnPackTo(&model_); + } +}; + +TEST_F(QuantizeUnpackTest, VerifyUnpack) { + auto status = QuantizeModel(&builder_, &model_, &error_reporter_); + + ASSERT_EQ(kTfLiteOk, status); + + const auto subgraph = model_.subgraphs[0].get(); + auto op = subgraph->operators[1].get(); + + auto float_graph = readonly_model_->subgraphs()->Get(0); + + ASSERT_EQ(model_.operator_codes[op->opcode_index].get()->builtin_code, + BuiltinOperator_UNPACK); + + // Get unpack input and output tensors + auto unpack_input = subgraph->tensors[op->inputs[0]].get(); + auto unpack_output_0 = subgraph->tensors[op->outputs[0]].get(); + auto unpack_output_1 = subgraph->tensors[op->outputs[1]].get(); + + // Verify Unpack input is quantized. + ASSERT_EQ(float_graph->tensors()->Get(op->inputs[0])->type(), + TensorType_FLOAT32); + EXPECT_EQ(unpack_input->type, TensorType_INT8); + + // The model should only have one input and 2 outputs. + EXPECT_EQ(subgraph->inputs.size(), 1); + EXPECT_EQ(subgraph->outputs.size(), 2); + + // Ensure quantization parameters before and after unpack + // are preserved after quantization for all outputs of + // unpack. + EXPECT_FLOAT_EQ(unpack_input->quantization->scale[0], + unpack_output_0->quantization->scale[0]); + EXPECT_FLOAT_EQ(unpack_input->quantization->scale[0], + unpack_output_1->quantization->scale[0]); + EXPECT_FLOAT_EQ(unpack_input->quantization->zero_point[0], + unpack_output_0->quantization->zero_point[0]); + EXPECT_FLOAT_EQ(unpack_input->quantization->zero_point[0], + unpack_output_1->quantization->zero_point[0]); +} + } // namespace } // namespace optimize } // namespace tflite diff --git a/tensorflow/lite/tools/optimize/test_util.cc b/tensorflow/lite/tools/optimize/test_util.cc index d66b355dfd1..4aea0bb0fed 100644 --- a/tensorflow/lite/tools/optimize/test_util.cc +++ b/tensorflow/lite/tools/optimize/test_util.cc @@ -55,6 +55,8 @@ const char* kLstmQuantized = "lstm_quantized.bin"; const char* kModelWithMinimumOp = "minimum.bin"; const char* kModelWithMaximumOp = "maximum.bin"; +const char* kModelWithUnpack = "unpack.bin"; + int FailOnErrorReporter::Report(const char* format, va_list args) { char buf[1024]; vsnprintf(buf, sizeof(buf), format, args); diff --git a/tensorflow/lite/tools/optimize/test_util.h b/tensorflow/lite/tools/optimize/test_util.h index 6f19c2058f0..845dfd813a4 100644 --- a/tensorflow/lite/tools/optimize/test_util.h +++ b/tensorflow/lite/tools/optimize/test_util.h @@ -86,6 +86,9 @@ extern const char* kModelWithMinimumOp; // Test model with a maximum op. 
extern const char* kModelWithMaximumOp; +// Test model with an unpack op. +extern const char* kModelWithUnpack; + // An error reporter that fails on testing. class FailOnErrorReporter : public ErrorReporter { public: diff --git a/tensorflow/lite/tools/optimize/testdata/unpack.bin b/tensorflow/lite/tools/optimize/testdata/unpack.bin new file mode 100644 index 00000000000..72e58bfa1ea Binary files /dev/null and b/tensorflow/lite/tools/optimize/testdata/unpack.bin differ diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 2f0241146b8..e9e74e85ffa 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -3,6 +3,21 @@ # Public targets: # ":platform" - Low-level and platform-specific Python code. +load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "if_not_windows", "py_test", "py_tests", "tf_cc_shared_object", "tf_cuda_library", "tf_gen_op_wrapper_py", "tf_py_build_info_genrule", "tf_py_test") +load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension") +load("//tensorflow:tensorflow.bzl", "pybind_extension") +load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") +load("//tensorflow:tensorflow.bzl", "cuda_py_test") +load("//tensorflow:tensorflow.bzl", "cuda_py_tests") +load("//tensorflow:tensorflow.bzl", "tf_external_workspace_visible") +load("//tensorflow/core/platform:build_config.bzl", "pyx_library", "tf_additional_all_protos", "tf_additional_cupti_test_flags", "tf_additional_lib_deps", "tf_proto_library", "tf_proto_library_py", "tf_protos_grappler") # @unused +load("//tensorflow/core/platform:build_config_root.bzl", "if_static", "tf_additional_plugin_deps") +load("//tensorflow/python:build_defs.bzl", "tf_gen_op_wrapper_private_py") +load( + "//third_party/ngraph:build_defs.bzl", + "if_ngraph", +) + visibility = [ "//engedu/ml/tf_from_scratch:__pkg__", "//third_party/cloud_tpu/convergence_tools:__subpackages__", @@ -19,20 +34,6 @@ visibility = [ "//bazel_pip/tensorflow/lite/toco/python:__pkg__", ] -load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "if_not_windows", "py_test", "py_tests", "tf_cc_shared_object", "tf_cuda_library", "tf_gen_op_wrapper_py", "tf_py_build_info_genrule", "tf_py_test") -load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension") -load("//tensorflow:tensorflow.bzl", "pybind_extension") -load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") -load("//tensorflow:tensorflow.bzl", "cuda_py_test") -load("//tensorflow:tensorflow.bzl", "cuda_py_tests") -load("//tensorflow/core/platform:build_config.bzl", "pyx_library", "tf_additional_all_protos", "tf_additional_cupti_test_flags", "tf_additional_lib_deps", "tf_proto_library", "tf_proto_library_py", "tf_protos_grappler") # @unused -load("//tensorflow/core/platform:build_config_root.bzl", "if_static", "tf_additional_plugin_deps") -load("//tensorflow/python:build_defs.bzl", "tf_gen_op_wrapper_private_py") -load( - "//third_party/ngraph:build_defs.bzl", - "if_ngraph", -) - package( default_visibility = visibility, licenses = ["notice"], # Apache 2.0 @@ -400,13 +401,26 @@ cc_library( ], ) +tf_python_pybind_extension( + name = "_pywrap_bfloat16", + srcs = ["lib/core/bfloat16_wrapper.cc"], + hdrs = ["lib/core/bfloat16.h"], + module_name = "_pywrap_bfloat16", + deps = [ + "//third_party/python_runtime:headers", + "@pybind11", + ], +) + cc_library( name = "ndarray_tensor_bridge", srcs = ["lib/core/ndarray_tensor_bridge.cc"], hdrs = ["lib/core/ndarray_tensor_bridge.h"], - visibility = visibility + [ - "//learning/deepmind/courier:__subpackages__", - ], + 
visibility = tf_external_workspace_visible( + visibility + [ + "//learning/deepmind/courier:__subpackages__", + ], + ), deps = [ ":bfloat16_lib", ":numpy_lib", @@ -421,7 +435,7 @@ cc_library( srcs = ["lib/core/py_exception_registry.cc"], hdrs = ["lib/core/py_exception_registry.h"], deps = [ - "//tensorflow/c:tf_status", + "//tensorflow/c:tf_status_headers", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//third_party/python_runtime:headers", @@ -432,6 +446,7 @@ cc_library( name = "pybind11_absl", hdrs = ["lib/core/pybind11_absl.h"], features = ["-parse_headers"], + visibility = tf_external_workspace_visible(visibility), deps = [ "//tensorflow/core/platform:stringpiece", "@pybind11", @@ -442,6 +457,7 @@ cc_library( name = "pybind11_lib", hdrs = ["lib/core/pybind11_lib.h"], features = ["-parse_headers"], + visibility = tf_external_workspace_visible(visibility), deps = [ "@pybind11", ], @@ -454,9 +470,10 @@ cc_library( "//tensorflow/c:headers", ], features = ["-parse_headers"], + visibility = tf_external_workspace_visible(visibility), deps = [ ":py_exception_registry", - "//tensorflow/c:tf_status", + "//tensorflow/c:tf_status_headers", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//third_party/python_runtime:headers", @@ -468,6 +485,7 @@ cc_library( name = "pybind11_proto", hdrs = ["lib/core/pybind11_proto.h"], features = ["-parse_headers"], + visibility = tf_external_workspace_visible(visibility), deps = [ "@com_google_absl//absl/strings", "@pybind11", @@ -769,9 +787,9 @@ cc_library( name = "ndarray_tensor", srcs = ["lib/core/ndarray_tensor.cc"], hdrs = ["lib/core/ndarray_tensor.h"], - visibility = visibility + [ + visibility = tf_external_workspace_visible(visibility + [ "//learning/deepmind/courier:__subpackages__", - ], + ]), deps = [ ":bfloat16_lib", ":ndarray_tensor_bridge", @@ -1158,6 +1176,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":_dtypes", + ":_pywrap_bfloat16", ":pywrap_tensorflow", "//tensorflow/core:protos_all_py", ], @@ -3687,9 +3706,9 @@ py_library( srcs = ["ops/math_ops.py"], srcs_version = "PY2AND3", deps = [ - "constant_op", ":array_ops", ":common_shapes", + ":constant_op", ":control_flow_ops_gen", ":data_flow_ops_gen", ":dtypes", @@ -5442,7 +5461,6 @@ tf_py_wrap_cc( "grappler/cost_analyzer.i", "grappler/item.i", "grappler/tf_optimizer.i", - "lib/core/bfloat16.i", "lib/core/strings.i", "lib/io/file_io.i", "lib/io/py_record_reader.i", @@ -5528,6 +5546,7 @@ WIN_LIB_FILES_FOR_EXPORTED_SYMBOLS = [ ":numpy_lib", # checkpoint_reader ":safe_ptr", # checkpoint_reader ":python_op_gen", # python_op_gen + ":bfloat16_lib", # bfloat16 "//tensorflow/core/util/tensor_bundle", # checkpoint_reader ] @@ -5586,17 +5605,20 @@ filegroup( cc_import( name = "_pywrap_tensorflow_internal_linux", shared_library = "//tensorflow/python:lib_pywrap_tensorflow_internal.so", + visibility = tf_external_workspace_visible(visibility), ) cc_import( name = "_pywrap_tensorflow_internal_macos", shared_library = "//tensorflow/python:lib_pywrap_tensorflow_internal.dylib", + visibility = tf_external_workspace_visible(visibility), ) cc_import( name = "_pywrap_tensorflow_internal_windows", interface_library = "//tensorflow/python:pywrap_tensorflow_import_lib_file", shared_library = "//tensorflow/python:_pywrap_tensorflow_internal.dll", + visibility = tf_external_workspace_visible(visibility), ) # Rename the import library for _pywrap_tensorflow_internal.pyd to _pywrap_tensorflow_internal.lib diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index 
10034a9ed65..7ba4d4278fc 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -106,6 +106,7 @@ from tensorflow.python.ops import sets from tensorflow.python.ops import stateful_random_ops from tensorflow.python.ops.distributions import distributions from tensorflow.python.ops.linalg import linalg +from tensorflow.python.ops.linalg.sparse import sparse from tensorflow.python.ops.losses import losses from tensorflow.python.ops.signal import signal from tensorflow.python.profiler import profiler diff --git a/tensorflow/python/autograph/converters/arg_defaults_test.py b/tensorflow/python/autograph/converters/arg_defaults_test.py index 33dabe52839..6448f3124db 100644 --- a/tensorflow/python/autograph/converters/arg_defaults_test.py +++ b/tensorflow/python/autograph/converters/arg_defaults_test.py @@ -20,7 +20,7 @@ from __future__ import print_function from tensorflow.python.autograph.converters import arg_defaults from tensorflow.python.autograph.core import converter_testing -from tensorflow.python.autograph.pyct import compiler +from tensorflow.python.autograph.pyct import parser from tensorflow.python.platform import test @@ -28,8 +28,7 @@ class ArgDefaultsTransformerTest(converter_testing.TestCase): def assertTransformedFirstLineIs(self, node, expected): self.assertEqual( - compiler.ast_to_source(node, - include_encoding_marker=False).split('\n')[0], + parser.unparse(node, include_encoding_marker=False).split('\n')[0], expected) def test_no_args(self): diff --git a/tensorflow/python/autograph/converters/conditional_expressions.py b/tensorflow/python/autograph/converters/conditional_expressions.py index 4538b16660c..125ef5375be 100644 --- a/tensorflow/python/autograph/converters/conditional_expressions.py +++ b/tensorflow/python/autograph/converters/conditional_expressions.py @@ -27,8 +27,15 @@ class ConditionalExpressionTransformer(converter.Base): def visit_IfExp(self, node): return templates.replace_as_expression( - '''ag__.if_stmt(test, lambda: true_expr, - lambda: false_expr, lambda: (), lambda _: None)''', + '''ag__.if_stmt( + test, + lambda: true_expr, + lambda: false_expr, + lambda: (), + lambda _: None, + ('',), + ()) + ''', test=node.test, true_expr=node.body, false_expr=node.orelse) diff --git a/tensorflow/python/autograph/converters/control_flow.py b/tensorflow/python/autograph/converters/control_flow.py index 8f4170281b4..5bf488cd209 100644 --- a/tensorflow/python/autograph/converters/control_flow.py +++ b/tensorflow/python/autograph/converters/control_flow.py @@ -383,6 +383,9 @@ class ControlFlowTransformer(converter.Base): composite_symbol_names = tuple( gast.Str(str(symbol)) for symbol in composite_loop_vars) + # TODO(b/140125096): Populate. + opts = gast.Dict([], []) + # TODO(mdan): Use a single template. # If the body and test functions took a single tuple for loop_vars, instead # of *loop_vars, then a single template could be used. 
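The new template means a conditional expression compiles to the same functional if_stmt call as a full if statement, just with empty state accessors and a placeholder symbol name. To see the shape of that call, here is a simplified eager-only stand-in for the operator (the real ag__.if_stmt also handles tf.cond dispatch and state capture):

def if_stmt(cond, body, orelse, get_state, set_state, basic_names,
            composite_names):
  # Minimal eager-only stand-in: pick a branch, ignore the state machinery.
  del get_state, set_state, basic_names, composite_names
  return body() if cond else orelse()


# `x = a if cond else b` is rewritten roughly into:
cond, a, b = True, 1, 2
x = if_stmt(cond, lambda: a, lambda: b, lambda: (), lambda _: None, ('',), ())
assert x == 1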
@@ -401,7 +404,8 @@ class ControlFlowTransformer(converter.Base): state_setter_name, (loop_vars,), (basic_symbol_names,), - (composite_symbol_names,)) + (composite_symbol_names,), + opts) """ node = templates.replace( template, @@ -415,7 +419,8 @@ class ControlFlowTransformer(converter.Base): state_getter_name=state_getter_name, state_setter_name=state_setter_name, basic_symbol_names=basic_symbol_names, - composite_symbol_names=composite_symbol_names) + composite_symbol_names=composite_symbol_names, + opts=opts) else: template = """ state_functions @@ -431,7 +436,8 @@ class ControlFlowTransformer(converter.Base): state_setter_name, (), (), - (composite_symbol_names,)) + (composite_symbol_names,), + opts) """ node = templates.replace( template, @@ -442,7 +448,8 @@ class ControlFlowTransformer(converter.Base): state_functions=state_functions, state_getter_name=state_getter_name, state_setter_name=state_setter_name, - composite_symbol_names=composite_symbol_names) + composite_symbol_names=composite_symbol_names, + opts=opts) undefined_assigns = self._create_undefined_assigns(possibly_undefs) return undefined_assigns + node @@ -500,6 +507,9 @@ class ControlFlowTransformer(converter.Base): composite_symbol_names = tuple( gast.Str(str(symbol)) for symbol in composite_loop_vars) + # TODO(b/140125096): Populate. + opts = gast.Dict([], []) + # TODO(mdan): Use a single template. # If the body and test functions took a single tuple for loop_vars, instead # of *loop_vars, then a single template could be used. @@ -520,7 +530,8 @@ class ControlFlowTransformer(converter.Base): state_setter_name, (loop_vars,), (basic_symbol_names,), - (composite_symbol_names,)) + (composite_symbol_names,), + opts) """ return templates.replace( template, @@ -538,7 +549,8 @@ class ControlFlowTransformer(converter.Base): state_getter_name=state_getter_name, state_setter_name=state_setter_name, basic_symbol_names=basic_symbol_names, - composite_symbol_names=composite_symbol_names) + composite_symbol_names=composite_symbol_names, + opts=opts) else: template = """ undefined_assigns @@ -556,7 +568,8 @@ class ControlFlowTransformer(converter.Base): state_setter_name, (), (), - (composite_symbol_names,)) + (composite_symbol_names,), + opts) """ return templates.replace( template, @@ -571,7 +584,8 @@ class ControlFlowTransformer(converter.Base): state_functions=state_functions, state_getter_name=state_getter_name, state_setter_name=state_setter_name, - composite_symbol_names=composite_symbol_names) + composite_symbol_names=composite_symbol_names, + opts=opts) def transform(node, ctx): diff --git a/tensorflow/python/autograph/converters/directives.py b/tensorflow/python/autograph/converters/directives.py index b712c21d364..fe1c75a5864 100644 --- a/tensorflow/python/autograph/converters/directives.py +++ b/tensorflow/python/autograph/converters/directives.py @@ -98,9 +98,9 @@ class DirectivesTransformer(converter.Base): raise ValueError( '"%s" must be used inside a statement' % directive.__name__) target = self.get_local(ENCLOSING_LOOP) - node_anno = anno.getanno(target, converter.AgAnno.DIRECTIVES, {}) + node_anno = anno.getanno(target, anno.Basic.DIRECTIVES, {}) node_anno[directive] = _map_args(call_node, directive) - anno.setanno(target, converter.AgAnno.DIRECTIVES, node_anno) + anno.setanno(target, anno.Basic.DIRECTIVES, node_anno) return call_node def visit_Name(self, node): diff --git a/tensorflow/python/autograph/converters/directives_test.py b/tensorflow/python/autograph/converters/directives_test.py index 
62de7d6229a..545094521ec 100644 --- a/tensorflow/python/autograph/converters/directives_test.py +++ b/tensorflow/python/autograph/converters/directives_test.py @@ -20,7 +20,6 @@ from __future__ import print_function from tensorflow.python.autograph.converters import directives as directives_converter from tensorflow.python.autograph.core import converter_testing -from tensorflow.python.autograph.core.converter import AgAnno from tensorflow.python.autograph.lang import directives from tensorflow.python.autograph.pyct import anno from tensorflow.python.autograph.pyct import parser @@ -68,7 +67,7 @@ class DirectivesTest(converter_testing.TestCase): node, ctx = self.prepare(test_fn, {'directives': directives}) node = directives_converter.transform(node, ctx) - d = anno.getanno(node.body[1], AgAnno.DIRECTIVES) + d = anno.getanno(node.body[1], anno.Basic.DIRECTIVES) d = d[directives.set_loop_options] self.assertEqual(d['parallel_iterations'].n, 10) self.assertEqual(d['back_prop'].id, 'a') diff --git a/tensorflow/python/autograph/core/config.py b/tensorflow/python/autograph/core/config.py index 41d05ce6502..b336ea771d3 100644 --- a/tensorflow/python/autograph/core/config.py +++ b/tensorflow/python/autograph/core/config.py @@ -49,6 +49,7 @@ CONVERSION_RULES = ( # Known libraries DoNotConvert('numpy'), DoNotConvert('tensorflow'), + DoNotConvert('PIL'), # TODO(b/133417201): Remove. DoNotConvert('tensorflow_probability'), diff --git a/tensorflow/python/autograph/core/converter.py b/tensorflow/python/autograph/core/converter.py index e9bf009d029..e286e38d855 100644 --- a/tensorflow/python/autograph/core/converter.py +++ b/tensorflow/python/autograph/core/converter.py @@ -69,7 +69,6 @@ import enum from tensorflow.python.autograph.pyct import anno from tensorflow.python.autograph.pyct import ast_util from tensorflow.python.autograph.pyct import cfg -from tensorflow.python.autograph.pyct import compiler from tensorflow.python.autograph.pyct import parser from tensorflow.python.autograph.pyct import qual_names from tensorflow.python.autograph.pyct import templates @@ -329,10 +328,10 @@ class Base(transformer.Base): for other_value in arg_values_found[1:]: if not ast_util.matches(first_value, other_value): qn = anno.getanno(node, anno.Basic.QN) - raise ValueError('%s has ambiguous annotations for %s(%s): %s, %s' % - (qn, directive.__name__, arg, - compiler.ast_to_source(other_value).strip(), - compiler.ast_to_source(first_value).strip())) + raise ValueError( + '%s has ambiguous annotations for %s(%s): %s, %s' % + (qn, directive.__name__, arg, parser.unparse(other_value).strip(), + parser.unparse(first_value).strip())) return first_value def visit(self, node): @@ -355,15 +354,6 @@ class AnnotatedDef(reaching_definitions.Definition): self.directives = {} -class AgAnno(enum.Enum): - """Annotation labels specific to AutoGraph. See anno.py.""" - - DIRECTIVES = 'User directives associated with the annotated statement.' - - def __repr__(self): - return self.name - - def standard_analysis(node, context, is_initial=False): """Performs a complete static analysis of the given code. 
diff --git a/tensorflow/python/autograph/core/converter_test.py b/tensorflow/python/autograph/core/converter_test.py index 2d5b33465e0..030ec761d95 100644 --- a/tensorflow/python/autograph/core/converter_test.py +++ b/tensorflow/python/autograph/core/converter_test.py @@ -21,7 +21,7 @@ from __future__ import print_function from tensorflow.python.autograph.core import converter from tensorflow.python.autograph.core import converter_testing from tensorflow.python.autograph.pyct import anno -from tensorflow.python.autograph.pyct import compiler +from tensorflow.python.autograph.pyct import loader from tensorflow.python.autograph.pyct import parser from tensorflow.python.autograph.pyct import templates from tensorflow.python.platform import test @@ -43,7 +43,7 @@ class ConversionOptionsTest(converter_testing.TestCase): ''' opts_packed = templates.replace(template, opts_ast=opts_ast) - reparsed, _, _ = compiler.ast_to_object(opts_packed) + reparsed, _, _ = loader.load_ast(opts_packed) reparsed.__dict__['ag__'] = self.make_fake_mod( 'fake_ag', converter.ConversionOptions, converter.Feature) diff --git a/tensorflow/python/autograph/core/converter_testing.py b/tensorflow/python/autograph/core/converter_testing.py index b11f210a951..4ea1187f8ed 100644 --- a/tensorflow/python/autograph/core/converter_testing.py +++ b/tensorflow/python/autograph/core/converter_testing.py @@ -32,7 +32,7 @@ from tensorflow.python.autograph.core import converter from tensorflow.python.autograph.core import function_wrappers from tensorflow.python.autograph.core import naming from tensorflow.python.autograph.lang import special_functions -from tensorflow.python.autograph.pyct import compiler +from tensorflow.python.autograph.pyct import loader from tensorflow.python.autograph.pyct import origin_info from tensorflow.python.autograph.pyct import parser from tensorflow.python.autograph.pyct import pretty_printer @@ -97,7 +97,7 @@ class TestCase(test.TestCase): return f(*args, **kwargs) try: - result, source, source_map = compiler.ast_to_object( + result, source, source_map = loader.load_ast( node, include_source_map=True) # TODO(mdan): Move the unparsing from converter into pyct and reuse here. @@ -120,7 +120,7 @@ class TestCase(test.TestCase): if source is None: print('Offending AST:\n%s' % pretty_printer.fmt(node, color=False)) else: - print('Offending compiled code:\n%s' % source) + print('Offending source code:\n%s' % source) raise @contextlib.contextmanager diff --git a/tensorflow/python/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py index b8b0eeee63c..17104c10c1b 100644 --- a/tensorflow/python/autograph/impl/api.py +++ b/tensorflow/python/autograph/impl/api.py @@ -422,6 +422,27 @@ def converted_call(f, logging.log(2, 'Whitelisted: %s: AutoGraph is disabled in context', f) return _call_unconverted(f, args, kwargs, options, False) + if is_autograph_artifact(f): + logging.log(2, 'Permanently whitelisted: %s: AutoGraph artifact', f) + return _call_unconverted(f, args, kwargs, options) + + # If this is a partial, unwrap it and redo all the checks. 
+ if isinstance(f, functools.partial): + new_kwargs = {} + if f.keywords is not None: + new_kwargs = f.keywords + if kwargs is not None: + new_kwargs.update(kwargs) + new_args = f.args + args + logging.log(3, 'Forwarding call of partial %s with\n%s\n%s\n', f, new_args, + new_kwargs) + return converted_call( + f.func, + new_args, + new_kwargs, + caller_fn_scope=caller_fn_scope, + options=options) + if inspect_utils.isbuiltin(f): if f is eval: return py_builtins.eval_in_original_context(f, args, caller_fn_scope) @@ -432,10 +453,6 @@ def converted_call(f, else: return py_builtins.overload_of(f)(*args) - if is_autograph_artifact(f): - logging.log(2, 'Permanently whitelisted: %s: AutoGraph artifact', f) - return _call_unconverted(f, args, kwargs, options) - # TODO(b/122265385): Remove this bypass. if (_is_known_loaded_type(f, 'wrapt', 'FunctionWrapper') or _is_known_loaded_type(f, 'wrapt', 'BoundFunctionWrapper')): @@ -453,7 +470,7 @@ def converted_call(f, # Constructors are permanently whitelisted. # TODO(mdan): Toggle as experimental feature instead. # TODO(b/124016764): Remove this limitation. - if tf_inspect.isclass(f): + if inspect_utils.isconstructor(f): logging.log(2, 'Permanently whitelisted: %s: constructor', f) return _call_unconverted(f, args, kwargs, options) @@ -484,19 +501,6 @@ def converted_call(f, # TODO(mdan): Move this entire block inside to_graph. try: # Begin of transformation error guards - # Unwrap functools.partial objects - # TODO(mdan): Consider sharing unwrapping logic with tf_inspect. - # TODO(b/120224672): This unwrapping should be done before the checks above. - while isinstance(f, functools.partial): - args = f.args + args - new_kwargs = {} - if f.keywords is not None: - new_kwargs.update(f.keywords) - if kwargs is not None: - new_kwargs.update(kwargs) - kwargs = new_kwargs - f = f.func - if tf_inspect.isfunction(f) or tf_inspect.ismethod(f): # Regular functions target_entity = f @@ -508,19 +512,12 @@ def converted_call(f, else: effective_args = args - elif hasattr(f, '__call__') and hasattr(f, '__class__'): - # Callable objects - target_entity = f.__call__ + elif hasattr(f, '__class__') and hasattr(f.__class__, '__call__'): + # Callable objects. Dunder methods have special lookup rules, see: + # https://docs.python.org/3/reference/datamodel.html#specialnames + target_entity = f.__class__.__call__ effective_args = (f,) + args - elif tf_inspect.isclass(f): - # Constructors - # Note: Until we support class constructurs, and enable whole-class - # conversion with an experimental flag, this branch is dead code. - # TODO(mdan): Consider removing unless there is a compelling use case. 
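The early functools.partial unwrapping above simply merges the partial's bound positional and keyword arguments with the call-site ones and re-dispatches on the underlying function. The merge in isolation (standalone sketch; converted_call itself then redoes all of its other checks):

import functools


def unwrap_partial(f, args, kwargs):
  """Mirrors the merge: the partial's keywords first, call-site kwargs win."""
  new_kwargs = dict(f.keywords or {})
  new_kwargs.update(kwargs or {})
  new_args = f.args + args
  return f.func, new_args, new_kwargs


def greet(greeting, name, punctuation='.'):
  return '%s, %s%s' % (greeting, name, punctuation)


p = functools.partial(greet, 'Hello', punctuation='!')
func, merged_args, merged_kwargs = unwrap_partial(p, ('Ada',), None)
assert func(*merged_args, **merged_kwargs) == 'Hello, Ada!'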
- target_entity = f - effective_args = args - else: target_entity = f raise NotImplementedError('unknown callable type "%s"' % type(f)) diff --git a/tensorflow/python/autograph/impl/api_test.py b/tensorflow/python/autograph/impl/api_test.py index a3d9a870def..e9b9fc75150 100644 --- a/tensorflow/python/autograph/impl/api_test.py +++ b/tensorflow/python/autograph/impl/api_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import abc import collections import contextlib import functools @@ -416,21 +417,50 @@ class ApiTest(test.TestCase): def test_converted_call_callable_metaclass(self): + test_self = self + class TestMetaclass(type): def __call__(cls): self.assertTrue(converter_testing.is_inside_generated_code()) inst = object.__new__(cls) inst.__init__() + + def instance_call(unused_self): + test_self.fail( + 'The class-bound __call__ should be called, not the instance' + ' bound one.') + + inst.__call__ = instance_call return inst tmc = TestMetaclass('TestClass', (), {}) - # This functools.partial will hide the class form the constructor - # check. Not ideal. See b/120224672. - tc = api.converted_call( - functools.partial(tmc), (), None, options=DEFAULT_RECURSIVE) + tc = api.converted_call(tmc, (), None, options=DEFAULT_RECURSIVE) self.assertIsInstance(tc, tmc) + def test_converted_call_callable_abc(self): + + test_self = self + + @six.add_metaclass(abc.ABCMeta) + class TestBase(object): + + @abc.abstractmethod + def __call__(self): + test_self.fail('This should not be called') + + class TestSubclass(TestBase): + + def __init__(self): + test_self.assertFalse(converter_testing.is_inside_generated_code()) + + def __call__(self, expected): + test_self.assertTrue(expected) + test_self.assertTrue(converter_testing.is_inside_generated_code()) + + tc = api.converted_call(TestSubclass, (), None, options=DEFAULT_RECURSIVE) + api.converted_call(tc, (True,), None, options=DEFAULT_RECURSIVE) + @test_util.run_deprecated_v1 def test_converted_call_constructor(self): @@ -467,6 +497,15 @@ class ApiTest(test.TestCase): ag_logging.set_verbosity(0, False) os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '1' + def test_converted_call_partial_of_whitelisted_method(self): + + def test_fn(_): + self.assertFalse(converter_testing.is_inside_generated_code()) + + converter_testing.whitelist(test_fn) + api.converted_call( + functools.partial(test_fn, None), (), None, options=DEFAULT_RECURSIVE) + def test_converted_call_already_converted(self): def f(x): @@ -973,7 +1012,7 @@ class ApiTest(test.TestCase): return x # Just check that the output is parseable Python code. 
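The switch from f.__call__ to f.__class__.__call__, and the metaclass test above, hinge on the fact that Python resolves special methods on the type rather than on the instance. A self-contained illustration of that lookup rule:

class Greeter(object):

  def __call__(self):
    return 'class-bound __call__'


g = Greeter()
g.__call__ = lambda: 'instance attribute'  # ignored by the call syntax

assert g() == 'class-bound __call__'         # type(g).__call__ is what runs
assert g.__call__() == 'instance attribute'  # explicit attribute access differs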
- self.assertIsNotNone(parser.parse_str(api.to_code(test_fn))) + self.assertIsNotNone(parser.parse(api.to_code(test_fn))) def test_to_code_with_wrapped_function(self): diff --git a/tensorflow/python/autograph/impl/conversion.py b/tensorflow/python/autograph/impl/conversion.py index 4c8555eb293..c256b4e8e65 100644 --- a/tensorflow/python/autograph/impl/conversion.py +++ b/tensorflow/python/autograph/impl/conversion.py @@ -52,7 +52,7 @@ from tensorflow.python.autograph.core import naming from tensorflow.python.autograph.core import unsupported_features_checker from tensorflow.python.autograph.lang import special_functions from tensorflow.python.autograph.pyct import ast_util -from tensorflow.python.autograph.pyct import compiler +from tensorflow.python.autograph.pyct import loader from tensorflow.python.autograph.pyct import inspect_utils from tensorflow.python.autograph.pyct import origin_info from tensorflow.python.autograph.pyct import parser @@ -282,8 +282,7 @@ def _convert_with_cache(entity, program_ctx, free_nonglobal_var_names): free_nonglobal_var_names, entity_info.future_features) - module, _, source_map = compiler.ast_to_object( - nodes, include_source_map=True) + module, _, source_map = loader.load_ast(nodes, include_source_map=True) module_name = module.__name__ converted_entity_info = _ConvertedEntityFactoryInfo( @@ -519,8 +518,7 @@ def convert_entity_to_ast(o, program_ctx): 'supported for now.' % (o, type(o))) if logging.has_verbosity(2): - logging.log(2, 'Compiled output of %s:\n\n%s\n', o, - compiler.ast_to_source(nodes)) + logging.log(2, 'Compiled output of %s:\n\n%s\n', o, parser.unparse(nodes)) if logging.has_verbosity(4): for n in nodes: logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o, diff --git a/tensorflow/python/autograph/impl/conversion_test.py b/tensorflow/python/autograph/impl/conversion_test.py index 4cdffe6d6e2..a6336ef0dab 100644 --- a/tensorflow/python/autograph/impl/conversion_test.py +++ b/tensorflow/python/autograph/impl/conversion_test.py @@ -30,7 +30,7 @@ from tensorflow.python.autograph.core import config from tensorflow.python.autograph.core import converter from tensorflow.python.autograph.impl import api from tensorflow.python.autograph.impl import conversion -from tensorflow.python.autograph.pyct import compiler +from tensorflow.python.autograph.pyct import parser from tensorflow.python.framework import constant_op from tensorflow.python.keras.engine import training from tensorflow.python.platform import test @@ -128,9 +128,8 @@ class ConversionTest(test.TestCase): self.assertIsInstance(fn_node, gast.FunctionDef) self.assertEqual('tf__f', name) self.assertEqual( - compiler.ast_to_source( - fn_node.args.defaults[0], include_encoding_marker=False).strip(), - 'None') + parser.unparse(fn_node.args.defaults[0], + include_encoding_marker=False).strip(), 'None') def test_convert_entity_to_ast_call_tree(self): diff --git a/tensorflow/python/autograph/operators/control_flow.py b/tensorflow/python/autograph/operators/control_flow.py index bbfee424315..c862379e1d0 100644 --- a/tensorflow/python/autograph/operators/control_flow.py +++ b/tensorflow/python/autograph/operators/control_flow.py @@ -109,160 +109,140 @@ def _disallow_undefs_into_loop(*values): 'return statements are not supported within a TensorFlow loop.') -def _shape_greater_than_or_equal(shape1, shape2): - """Check whether the shape2 is equal or more specific than shape1.""" - - # The following logic was mirrored from control_flow_ops.py's - # _ShapeLessThanOrEqual function. 
- if shape1.dims is None: +def _is_subshape(left, right): + """Returns True if left shape is at least as specific as right shape.""" + # TODO(mdan): This code should be in TensorShape. + # Note: this is not the same as TensorShape.is_compatible_with, which is + # symmetric. + # This code also duplicates _ShapeLessThanOrEqual from control_flow_ops.py. + if right.dims is None: return True - if shape1.ndims != shape2.ndims: + if left.ndims != right.ndims: return False - for dim1, dim2 in zip(shape1.dims, shape2.dims): - if dim1.value is not None and dim1.value != dim2.value: + for ldim, rdim in zip(left.dims, right.dims): + if rdim.value is not None and ldim.value != rdim.value: return False return True -def _verify_tf_loop_vars(init_loop_vars, - first_iter_vars, - basic_symbol_names, - composite_symbol_names, - include_shapes=True): - """Verifies loop variables for consistency.""" +def _verify_single_loop_var(name, check_shape, init_loop_var, first_iter_var): + """Verifies whether init_loop_var and first_iter_var are consistent.""" + if isinstance(init_loop_var, (bool, int, float, str)): + init_loop_var = ops.convert_to_tensor_v2(init_loop_var) - # The whole point of _verify_tf_loop_vars is to give more useful error message - # than tf-level exception by including variable names. If it's not available, - # there is no point at performing this verification here. As of 2019-07-31, - # operators:control_flow_test does not pass the names. - if basic_symbol_names is None: + if isinstance(first_iter_var, (bool, int, float, str)): + first_iter_var = ops.convert_to_tensor_v2(first_iter_var) + + if (not tensor_util.is_tensor(init_loop_var) or + not tensor_util.is_tensor(first_iter_var)): return - output_symbol_names = basic_symbol_names + composite_symbol_names + # TODO(mdan): Properly account for CompositeTensors. + if (not hasattr(init_loop_var, 'dtype') or + not hasattr(first_iter_var, 'dtype')): + return + if (not hasattr(init_loop_var, 'shape') or + not hasattr(first_iter_var, 'shape')): + return - assert len(init_loop_vars) == len(first_iter_vars) == len(output_symbol_names) + if init_loop_var.dtype != first_iter_var.dtype: + raise TypeError( + '"{}" has dtype {} before the loop, but dtype {} after one' + ' iteration. TensorFlow control flow requires it stays the' + ' same.'.format( + name, + init_loop_var.dtype.name, + first_iter_var.dtype.name, + )) - for init_loop_var, first_iter_var, name in zip(init_loop_vars, - first_iter_vars, - output_symbol_names): + if check_shape: + init_shape = init_loop_var.shape + first_iter_shape = first_iter_var.shape + # TODO(b/135183013): Update needed once we support shape_invariants. + if not _is_subshape(first_iter_shape, init_shape): + raise ValueError( + '"{}" has shape {} before the loop, but shape {} after one' + ' iteration. TensorFlow control flow requires it stays the' + ' same or be more specific.'.format(name, init_shape, + first_iter_shape)) + +def _verify_tf_loop_vars(init_loop_vars, + first_iter_vars, + symbol_names, + opts, + check_shapes=True): + """Verifies loop variables for consistency.""" + # TODO(b/140125096): Use this. 
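_is_subshape accepts the post-iteration shape only when it is at least as specific as the pre-loop one: an unknown right-hand rank accepts anything, ranks must otherwise match, and every known right-hand dimension has to be reproduced exactly. The same rule restated over tf.TensorShape as a quick sanity check (assuming the usual TensorShape semantics):

import tensorflow as tf


def is_subshape(left, right):
  # True when `left` is at least as specific as `right`.
  if right.dims is None:
    return True
  if left.ndims != right.ndims:
    return False
  return all(r.value is None or l.value == r.value
             for l, r in zip(left.dims, right.dims))


assert is_subshape(tf.TensorShape([2, 3]), tf.TensorShape([None, 3]))
assert not is_subshape(tf.TensorShape([2, 4]), tf.TensorShape([None, 3]))
assert is_subshape(tf.TensorShape([2, 3]), tf.TensorShape(None))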
+ del opts + + named_vars = zip(symbol_names, init_loop_vars, first_iter_vars) + for name, init_loop_var, first_iter_var in named_vars: try: nest.assert_same_structure( init_loop_var, first_iter_var, expand_composites=True) except (ValueError, TypeError) as e: raise TypeError('"{}" does not have the same nested structure after one' ' iteration.\n\n{}'.format(name, e)) - - def _check_same_type(name, init_loop_var, first_iter_var): - """Ensures init_loop_var and first_iter_var are consistent.""" - if isinstance(init_loop_var, (bool, int, float, str)): - init_loop_var = ops.convert_to_tensor_v2(init_loop_var) - - if isinstance(first_iter_var, (bool, int, float, str)): - first_iter_var = ops.convert_to_tensor_v2(first_iter_var) - - if (not tensor_util.is_tensor(init_loop_var) or - not tensor_util.is_tensor(first_iter_var)): - return - - # TODO(mdan): Properly account for CompositeTensors. - if (not hasattr(init_loop_var, 'dtype') or - not hasattr(first_iter_var, 'dtype')): - return - if (not hasattr(init_loop_var, 'shape') or - not hasattr(first_iter_var, 'shape')): - return - - if init_loop_var.dtype != first_iter_var.dtype: - raise TypeError( - '"{}" has dtype {} before the loop, but dtype {} after one' - ' iteration. TensorFlow control flow requires it stays the' - ' same.'.format( - name, - init_loop_var.dtype.name, - first_iter_var.dtype.name, - )) - - if include_shapes: - init_shape = init_loop_var.shape - first_iter_shape = first_iter_var.shape - # TODO(b/135183013): Update needed once we support shape_invariants. - if not _shape_greater_than_or_equal(init_shape, first_iter_shape): - raise ValueError( - '"{}" has shape {} before the loop, but shape {} after one' - ' iteration. TensorFlow control flow requires it stays the' - ' same or be more specific.'.format(name, init_shape, - first_iter_shape)) - nest.map_structure( - functools.partial(_check_same_type, name), init_loop_var, - first_iter_var) + functools.partial(_verify_single_loop_var, name, check_shapes), + init_loop_var, first_iter_var) -def _verify_tf_cond_vars(body_outputs, orelse_outputs, basic_symbol_names, - composite_symbol_names): - """Verifies variables manipulated by a conditional for consistency.""" +def _verify_single_cond_var(name, body_var, orelse_var): + """Verifies whether body_var and orelse_var are consistent.""" + if isinstance(body_var, (bool, int, float, str)): + body_var = ops.convert_to_tensor_v2(body_var) - # The whole point of _verify_tf_cond_vars is to give more useful error message - # than tf-level exception by including variable names. If it's not available, - # there is no point at performing this verification here. As of 2019-07-31, - # conditional expression does not pass the names. - if basic_symbol_names is None: + if isinstance(orelse_var, (bool, int, float, str)): + orelse_var = ops.convert_to_tensor_v2(orelse_var) + + if (not tensor_util.is_tensor(body_var) or + not tensor_util.is_tensor(orelse_var)): return - output_symbol_names = basic_symbol_names + composite_symbol_names + # TODO(mdan): Properly account for CompositeTensors. + if (not hasattr(body_var, 'dtype') or + not hasattr(orelse_var, 'dtype')): + return - basic_body_outputs, composite_body_outputs = body_outputs - basic_orelse_outputs, composite_orelse_outputs = orelse_outputs - assert isinstance(composite_body_outputs, tuple) - assert isinstance(composite_orelse_outputs, tuple) + if body_var.dtype != orelse_var.dtype: + raise TypeError( + '"{}" has dtype {} in the TRUE branch, but dtype={} in the FALSE' + ' branch. 
TensorFlow control flow requires that they are the' + ' same.'.format(name, body_var.dtype.name, + orelse_var.dtype.name)) + + +def _verify_tf_cond_vars(body_vars, orelse_vars, symbol_names): + """Verifies variables manipulated by a conditional for consistency.""" + basic_body_vars, composite_body_vars = body_vars + basic_orelse_vars, composite_orelse_vars = orelse_vars + assert isinstance(composite_body_vars, tuple) + assert isinstance(composite_orelse_vars, tuple) # TODO(kkimlabs): Make this more consistent. # The basic outputs should always be a tuple. - if not isinstance(basic_body_outputs, tuple): - basic_body_outputs = (basic_body_outputs,) - if not isinstance(basic_orelse_outputs, tuple): - basic_orelse_outputs = (basic_orelse_outputs,) + if not isinstance(basic_body_vars, tuple): + basic_body_vars = (basic_body_vars,) + if not isinstance(basic_orelse_vars, tuple): + basic_orelse_vars = (basic_orelse_vars,) - body_outputs = basic_body_outputs + composite_body_outputs - orelse_outputs = basic_orelse_outputs + composite_orelse_outputs + body_vars = basic_body_vars + composite_body_vars + orelse_vars = basic_orelse_vars + composite_orelse_vars - for body_output, orelse_output, name in zip(body_outputs, orelse_outputs, - output_symbol_names): + named_vars = zip(symbol_names, body_vars, orelse_vars) + for name, body_var, orelse_var in named_vars: try: nest.assert_same_structure( - body_output, orelse_output, expand_composites=True) + body_var, orelse_var, expand_composites=True) except (ValueError, TypeError) as e: raise TypeError( '"{}" does not have the same nested structure in the TRUE and FALSE' ' branches.\n\n{}'.format(name, str(e))) - def _check_same_type(name, body_output_var, orelse_output_var): - """Verfies that body_output_var and orelse_output_var have same dtype.""" - if isinstance(body_output_var, (bool, int, float, str)): - body_output_var = ops.convert_to_tensor_v2(body_output_var) - - if isinstance(orelse_output_var, (bool, int, float, str)): - orelse_output_var = ops.convert_to_tensor_v2(orelse_output_var) - - if (not tensor_util.is_tensor(body_output_var) or - not tensor_util.is_tensor(orelse_output_var)): - return - - # TODO(mdan): Properly account for CompositeTensors. - if (not hasattr(body_output_var, 'dtype') or - not hasattr(orelse_output_var, 'dtype')): - return - - if body_output_var.dtype != orelse_output_var.dtype: - raise TypeError( - '"{}" has dtype {} in the TRUE branch, but dtype={} in the FALSE' - ' branch. TensorFlow control flow requires that they are the' - ' same.'.format(name, body_output_var.dtype.name, - orelse_output_var.dtype.name)) - nest.map_structure( - functools.partial(_check_same_type, name), body_output, orelse_output) + functools.partial(_verify_single_cond_var, name), body_var, orelse_var) def for_stmt(iter_, @@ -271,8 +251,9 @@ def for_stmt(iter_, get_state, set_state, init_vars, - basic_symbol_names=None, - composite_symbol_names=None): + basic_symbol_names, + composite_symbol_names, + opts): """Functional form of a for statement. The loop operates on a state, which includes all symbols that are @@ -308,6 +289,7 @@ def for_stmt(iter_, init_vars: Tuple containing the initial state. basic_symbol_names: Tuple containing basic loop var names. composite_symbol_names: Tuple containing composite loop var names. + opts: Optional dict of extra loop parameters. Returns: Tuple containing the final state. 
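_verify_single_cond_var first converts Python literals to tensors (via convert_to_tensor_v2) and then compares only dtypes across the two branches; nested structure is checked separately and shape is not checked here. The classic case it rejects is an int assigned in one branch and a float in the other (a small sketch; tf.constant stands in for the conversion):

import tensorflow as tf

body_var = tf.constant(1)      # `x = 1` in the TRUE branch -> int32
orelse_var = tf.constant(1.0)  # `x = 1.0` in the FALSE branch -> float32

# This is the mismatch the verifier reports as a TypeError naming "x".
print(body_var.dtype.name, orelse_var.dtype.name,
      body_var.dtype == orelse_var.dtype)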
@@ -316,26 +298,26 @@ def for_stmt(iter_, if tensors.is_range_tensor(iter_): return _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state, init_vars, basic_symbol_names, - composite_symbol_names) + composite_symbol_names, opts) else: return _known_len_tf_for_stmt(iter_, extra_test, body, get_state, set_state, init_vars, basic_symbol_names, - composite_symbol_names) + composite_symbol_names, opts) if isinstance(iter_, dataset_ops.DatasetV2): return _tf_dataset_for_stmt(iter_, extra_test, body, get_state, set_state, init_vars, basic_symbol_names, - composite_symbol_names) + composite_symbol_names, opts) if isinstance(iter_, iterator_ops.OwnedIterator): return _tf_iterator_for_stmt(iter_, extra_test, body, get_state, set_state, init_vars, basic_symbol_names, - composite_symbol_names) + composite_symbol_names, opts) if isinstance(iter_, ragged_tensor.RaggedTensor): return _tf_ragged_for_stmt(iter_, extra_test, body, get_state, set_state, init_vars, basic_symbol_names, - composite_symbol_names) + composite_symbol_names, opts) # Note: This experimental interface is subject to change. custom_handler = getattr(iter_, '_autograph_for_loop', None) @@ -360,9 +342,15 @@ def _py_for_stmt(iter_, extra_test, body, get_state, set_state, init_vars): return state -def _known_len_tf_for_stmt(iter_, extra_test, body, get_state, set_state, - init_vars, basic_symbol_names, - composite_symbol_names): +def _known_len_tf_for_stmt(iter_, + extra_test, + body, + get_state, + set_state, + init_vars, + basic_symbol_names, + composite_symbol_names, + opts): """Overload of for_stmt that iterates over TF entities that admit a length.""" _disallow_undefs_into_loop(*init_vars) @@ -377,8 +365,6 @@ def _known_len_tf_for_stmt(iter_, extra_test, body, get_state, set_state, """Main loop body.""" iterate = iter_.read(iterate_index) new_vars = body(iterate, *loop_vars) - _verify_tf_loop_vars(loop_vars, new_vars, basic_symbol_names, - composite_symbol_names) loop_vars = (iterate_index + 1,) if new_vars: @@ -388,13 +374,12 @@ def _known_len_tf_for_stmt(iter_, extra_test, body, get_state, set_state, def while_cond(iterate_index, *loop_vars): if extra_test is not None: - return control_flow_ops.cond( - iterate_index < n, lambda: extra_test(*loop_vars), lambda: False) + return control_flow_ops.cond(iterate_index < n, + lambda: extra_test(*loop_vars), + lambda: False) return iterate_index < n - opts = {} - # TODO(b/134181679): We do not always set maximum_iterations since that - # is significantly slower on GPU. + # TODO(b/134181679): Let the op itself handle optimizations. 
if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()): opts['maximum_iterations'] = n @@ -403,10 +388,10 @@ def _known_len_tf_for_stmt(iter_, extra_test, body, get_state, set_state, while_body, get_state, set_state, - (0,) + init_vars, - None, - None, - opts=opts, + (array_ops.zeros_like(n),) + init_vars, + ('',) + basic_symbol_names, + composite_symbol_names, + opts, ) # Note: the iteration index is not returned by the while loop, however @@ -422,9 +407,15 @@ def _known_len_tf_for_stmt(iter_, extra_test, body, get_state, set_state, return results -def _tf_ragged_for_stmt(iter_, extra_test, body, get_state, set_state, - init_vars, basic_symbol_names, - composite_symbol_names): +def _tf_ragged_for_stmt(iter_, + extra_test, + body, + get_state, + set_state, + init_vars, + basic_symbol_names, + composite_symbol_names, + opts): """Overload of for_stmt that iterates over TF ragged tensors.""" _disallow_undefs_into_loop(*init_vars) @@ -438,8 +429,6 @@ def _tf_ragged_for_stmt(iter_, extra_test, body, get_state, set_state, """Main loop body.""" iterate = iter_[iterate_index] new_vars = body(iterate, *loop_vars) - _verify_tf_loop_vars(loop_vars, new_vars, basic_symbol_names, - composite_symbol_names) loop_vars = (iterate_index + 1,) if new_vars: @@ -450,10 +439,13 @@ def _tf_ragged_for_stmt(iter_, extra_test, body, get_state, set_state, def while_cond(iterate_index, *loop_vars): if extra_test is not None: return control_flow_ops.cond( - iterate_index < n, lambda: extra_test(*loop_vars), lambda: False) + iterate_index < n, + lambda: extra_test(*loop_vars), + lambda: False, + ) return iterate_index < n - opts = {'maximum_iterations': n} + opts['maximum_iterations'] = n results = _tf_while_stmt( while_cond, @@ -461,9 +453,9 @@ def _tf_ragged_for_stmt(iter_, extra_test, body, get_state, set_state, get_state, set_state, (array_ops.zeros_like(n),) + init_vars, - None, - None, - opts=opts, + ('',) + basic_symbol_names, + composite_symbol_names, + opts, ) if isinstance(results, (tuple, list)): @@ -476,8 +468,15 @@ def _tf_ragged_for_stmt(iter_, extra_test, body, get_state, set_state, return results -def _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state, init_vars, - basic_symbol_names, composite_symbol_names): +def _tf_range_for_stmt(iter_, + extra_test, + body, + get_state, + set_state, + init_vars, + basic_symbol_names, + composite_symbol_names, + opts): """Overload of for_stmt that iterates over a TF range (and elides it).""" _disallow_undefs_into_loop(*init_vars) @@ -497,8 +496,9 @@ def _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state, init_vars, def build_main_test(): """Main iteration condition.""" - # Note(b/138857806): LogicalAnd is slow on GPU so we avoid adding it if - # `delta` is a compile time constant. + # TODO(b/138857806): The optimizer should handle this. + # LogicalAnd is slow on GPU so we avoid adding it if `delta` is a + # compile time constant. delta_const = tensor_util.constant_value(delta) if delta_const is not None: # Support single element arrays. @@ -515,16 +515,13 @@ def _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state, init_vars, main_test = build_main_test() if extra_test is not None: return control_flow_ops.cond( - main_test, lambda: extra_test(*loop_vars), lambda: False) + main_test, + lambda: extra_test(*loop_vars), + lambda: False, + ) return main_test - # The first loopvar corresponds to the iterate variable which is internal. 
- if isinstance(basic_symbol_names, tuple): - basic_symbol_names = (None,) + basic_symbol_names - - opts = {} - # TODO(b/134181679): We do not always set maximum_iterations since that - # is significantly slower on GPU. + # TODO(b/134181679): The op should handle this optimization. if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()): # This specific dtype is required by while_loop. opts['maximum_iterations'] = math_ops.cast( @@ -536,9 +533,9 @@ def _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state, init_vars, get_state, set_state, (start,) + init_vars, - basic_symbol_names, + ('',) + basic_symbol_names, composite_symbol_names, - opts=opts, + opts, ) # Note: the iteration index is not returned by the while loop, however @@ -556,21 +553,24 @@ def _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state, init_vars, def _tf_iterator_for_stmt(itr, extra_test, body, get_state, set_state, init_vars, basic_symbol_names, - composite_symbol_names): + composite_symbol_names, opts): """Overload of for_stmt that iterates over TF Iterators. See for_loop.""" _disallow_undefs_into_loop(*init_vars) def while_body_actual(opt_iterate, *loop_vars): """Actual main loop body.""" new_vars = body(opt_iterate.get_value(), *loop_vars) - _verify_tf_loop_vars(loop_vars, new_vars, basic_symbol_names, - composite_symbol_names) # TODO(mdan): Fix this inconsistency in the converter. if new_vars is None: new_vars = () + # Note: this verification duplicates the one performed in tf_while_stmt, + but needs to be done earlier to prevent the tf.cond inside while_body + from blowing up first. + _verify_tf_loop_vars(loop_vars, new_vars, + basic_symbol_names + composite_symbol_names, opts) return new_vars - def while_body(has_next, loop_vars): + def while_body(has_next, *loop_vars): """Main loop body.""" opt_iterate = iterator_ops.get_next_as_optional(itr) has_next = opt_iterate.has_value() @@ -591,30 +591,32 @@ def _tf_iterator_for_stmt(itr, extra_test, body, get_state, set_state, if dummy_state: new_vars = new_vars[1:] - return has_next, new_vars + return (has_next,) + new_vars - def while_cond(has_next, loop_vars): + def while_cond(has_next, *loop_vars): if extra_test is not None: return control_flow_ops.cond( - has_next, lambda: extra_test(*loop_vars), lambda: False) + has_next, + lambda: extra_test(*loop_vars), + lambda: False, + ) return has_next - # The first loopvar corresponds to the iterate variable which is internal.
- _, final_vars = _tf_while_stmt( + final_vars = _tf_while_stmt( while_cond, while_body, get_state, set_state, - (True, init_vars), - None, - None, - opts=None, + (True,) + init_vars, + ('',) + basic_symbol_names, + composite_symbol_names, + opts, ) - return final_vars + return final_vars[1:] def _tf_dataset_for_stmt(ds, extra_test, body, get_state, set_state, init_vars, - basic_symbol_names, composite_symbol_names): + basic_symbol_names, composite_symbol_names, opts): """Overload of for_stmt that iterates over TF Datasets.""" _disallow_undefs_into_loop(*init_vars) @@ -623,11 +625,11 @@ def _tf_dataset_for_stmt(ds, extra_test, body, get_state, set_state, init_vars, return _dataset_for_stmt_with_extra_test(ds, extra_test, body, get_state, set_state, init_vars, basic_symbol_names, - composite_symbol_names) + composite_symbol_names, opts) return _dataset_for_stmt_no_extra_test(ds, body, get_state, set_state, init_vars, basic_symbol_names, - composite_symbol_names) + composite_symbol_names, opts) def _general_purpose_scan(ds, init_state, body): @@ -646,7 +648,7 @@ def _general_purpose_scan(ds, init_state, body): def _dataset_for_stmt_with_extra_test(ds, extra_test, body, get_state, set_state, init_vars, basic_symbol_names, - composite_symbol_names): + composite_symbol_names, opts): """Overload of _dataset_for_stmt with early stopping. See for_stmt.""" # TODO(mdan): Simplify this - following it is extremely difficult. @@ -661,14 +663,17 @@ def _dataset_for_stmt_with_extra_test(ds, extra_test, body, get_state, _verify_tf_loop_vars( loop_vars + state, outputs + state, - basic_symbol_names, - composite_symbol_names, - include_shapes=False) + basic_symbol_names + composite_symbol_names, + opts, + check_shapes=False) return outputs, get_state() extra_cond = extra_test(*loop_vars) new_vars, new_state = control_flow_ops.cond( - extra_cond, true_fn, lambda: (loop_vars, state)) + extra_cond, + true_fn, + lambda: (loop_vars, state), + ) scan_outputs = new_vars, new_state, extra_cond # Note: new_aug_vars is the actual state of scan; scan_outputs is its output @@ -696,12 +701,15 @@ def _dataset_for_stmt_with_extra_test(ds, extra_test, body, get_state, def _dataset_for_stmt_no_extra_test(ds, body, get_state, set_state, init_vars, - basic_symbol_names, composite_symbol_names): + basic_symbol_names, composite_symbol_names, + opts): """Overload of _dataset_for_stmt without early stopping. See for_stmt.""" init_state = get_state() assert isinstance(init_vars, tuple) assert isinstance(init_state, tuple) + symbol_names = basic_symbol_names + composite_symbol_names + # Workaround for Dataset.reduce not allowing empty state tensors - create # a dummy state variable that remains unused. # TODO(mdan): reduce should allow and match empty structures. 
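To make the dummy-state workaround above concrete, a small standalone sketch of the constraint it works around (assumed for illustration, not taken from the patch): Dataset.reduce requires a non-empty initial state, so when the converted loop has no variables of its own a throwaway scalar is threaded through and then discarded.

# Illustrative sketch only (assumed): Dataset.reduce maps
# (old_state, element) -> new_state and needs a non-empty initial state,
# hence the unused placeholder below.
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op

ds = dataset_ops.Dataset.range(5)

def reduce_body(dummy_state, elem):
  del elem  # a real loop body would act on the element here
  return dummy_state  # the placeholder state passes through unchanged

_ = ds.reduce(constant_op.constant(0), reduce_body)  # final dummy state is ignored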
@@ -710,10 +718,10 @@ def _dataset_for_stmt_no_extra_test(ds, body, get_state, set_state, init_vars, if no_vars: init_vars = (constant_op.constant(0),) - if isinstance(basic_symbol_names, tuple): - basic_symbol_names = (None,) + basic_symbol_names + symbol_names = ('',) + symbol_names if no_state: init_state = (constant_op.constant(0),) + symbol_names = symbol_names + ('',) def scan_body(aug_vars, iterate): """The main loop body wrapper.""" @@ -735,9 +743,9 @@ def _dataset_for_stmt_no_extra_test(ds, body, get_state, set_state, init_vars, _verify_tf_loop_vars( loop_vars + state, new_vars + new_state, - basic_symbol_names, - composite_symbol_names, - include_shapes=False) + symbol_names, + opts, + check_shapes=False) scan_outputs = new_vars, new_state # Note: new_aug_vars is the actual state of scan; scan_outputs is its output @@ -760,16 +768,14 @@ def _dataset_for_stmt_no_extra_test(ds, body, get_state, set_state, init_vars, return final_vars -def while_stmt( - test, - body, - get_state, - set_state, - init_vars, - basic_symbol_names=None, - composite_symbol_names=None, - opts=None, -): +def while_stmt(test, + body, + get_state, + set_state, + init_vars, + basic_symbol_names, + composite_symbol_names, + opts): """Functional form of a while statement. The loop operates on a so-called state, which includes all symbols that are @@ -818,17 +824,11 @@ def while_stmt( return _py_while_stmt(test, body, get_state, set_state, init_vars, opts) -# TODO(kkimlabs): Some callers set basic_symbol_names=None and -# composite_symbol_names=None and call _verify_tf_loop_vars(...) itself. We can -# remove these arguments once all callers do that. def _tf_while_stmt(test, body, get_state, set_state, init_vars, basic_symbol_names, composite_symbol_names, opts): """Overload of while_stmt that stages a TF while_stmt.""" _disallow_undefs_into_loop(*init_vars) - if opts is None: - opts = {} - # TODO(mdan): Simplify this. loop_vars_slice = slice(len(init_vars)) state_slice = slice(len(init_vars), None) @@ -839,12 +839,13 @@ def _tf_while_stmt(test, body, get_state, set_state, init_vars, return test(*aug_loop_vars[loop_vars_slice]) def aug_body(*aug_loop_vars): + """Main loop body.""" state = aug_loop_vars[state_slice] set_state(state) loop_vars = body(*aug_loop_vars[loop_vars_slice]) new_state = loop_vars + get_state() - _verify_tf_loop_vars(aug_loop_vars, new_state, basic_symbol_names, - composite_symbol_names) + _verify_tf_loop_vars(aug_loop_vars, new_state, + basic_symbol_names + composite_symbol_names, opts) return new_state @@ -948,8 +949,8 @@ def if_stmt(cond, orelse, get_state, set_state, - basic_symbol_names=None, - composite_symbol_names=None): + basic_symbol_names, + composite_symbol_names): """Functional form of an if statement. 
Args: @@ -1005,14 +1006,14 @@ def tf_if_stmt(cond, body, orelse, get_state, set_state, basic_symbol_names, result[body_branch] = body() if result[orelse_branch] is not None: _verify_tf_cond_vars(result[body_branch], result[orelse_branch], - basic_symbol_names, composite_symbol_names) + basic_symbol_names + composite_symbol_names) return result[body_branch] def error_checking_orelse(): result[orelse_branch] = orelse() if result[body_branch] is not None: _verify_tf_cond_vars(result[body_branch], result[orelse_branch], - basic_symbol_names, composite_symbol_names) + basic_symbol_names + composite_symbol_names) return result[orelse_branch] final_vars, final_state = control_flow_ops.cond(cond, error_checking_body, diff --git a/tensorflow/python/autograph/operators/control_flow_test.py b/tensorflow/python/autograph/operators/control_flow_test.py index 2290d61c6fd..a85d74246a1 100644 --- a/tensorflow/python/autograph/operators/control_flow_test.py +++ b/tensorflow/python/autograph/operators/control_flow_test.py @@ -50,7 +50,10 @@ class ForLoopTest(test.TestCase): body=lambda i, s: (s * 10 + i,), get_state=lambda: (), set_state=lambda _: None, - init_vars=(0,)) + init_vars=(0,), + basic_symbol_names=('s',), + composite_symbol_names=(), + opts={}) self.assertEqual(self.evaluate(s), (1234,)) def test_range_tensor(self): @@ -60,7 +63,10 @@ class ForLoopTest(test.TestCase): body=lambda i, s: (s * 10 + i,), get_state=lambda: (), set_state=lambda _: None, - init_vars=(0,)) + init_vars=(0,), + basic_symbol_names=('s',), + composite_symbol_names=(), + opts={}) self.assertEqual(self.evaluate(s), (1234,)) def test_range_tensor_random_delta(self): @@ -71,7 +77,10 @@ class ForLoopTest(test.TestCase): body=lambda i, s: (s * 10 + i,), get_state=lambda: (), set_state=lambda _: None, - init_vars=(0,)) + init_vars=(0,), + basic_symbol_names=('s',), + composite_symbol_names=(), + opts={}) self.assertEqual(self.evaluate(s), (1234,)) def test_range_tensor_explicit_limit_delta(self): @@ -81,7 +90,10 @@ class ForLoopTest(test.TestCase): body=lambda i, s: (s * 100 + i,), get_state=lambda: (), set_state=lambda _: None, - init_vars=(0,)) + init_vars=(0,), + basic_symbol_names=('s',), + composite_symbol_names=(), + opts={}) self.assertEqual(self.evaluate(s), (-171207,)) def test_range_tensor_random_negative_delta(self): @@ -92,7 +104,10 @@ class ForLoopTest(test.TestCase): body=lambda i, s: (s * 100 + i,), get_state=lambda: (), set_state=lambda _: None, - init_vars=(0,)) + init_vars=(0,), + basic_symbol_names=('s',), + composite_symbol_names=(), + opts={}) self.assertEqual(self.evaluate(s), (171207,)) def test_range_tensor_negative_delta(self): @@ -102,7 +117,10 @@ class ForLoopTest(test.TestCase): body=lambda i, s: (s * 100 + i,), get_state=lambda: (), set_state=lambda _: None, - init_vars=(0,)) + init_vars=(0,), + basic_symbol_names=('s',), + composite_symbol_names=(), + opts={}) self.assertEqual(self.evaluate(s), (171207,)) def test_tensor_with_extra_test_only_python_state(self): @@ -128,7 +146,10 @@ class ForLoopTest(test.TestCase): extra_test=lambda: state.field_1 < 6, get_state=get_state, set_state=set_state, - init_vars=()) + init_vars=(), + basic_symbol_names=(), + composite_symbol_names=(), + opts={}) self.assertEqual(self.evaluate(state.field_1), 6) self.assertEqual(self.evaluate(state.field_2), 6) @@ -139,7 +160,10 @@ class ForLoopTest(test.TestCase): body=lambda i, s: (s * 10 + i,), get_state=None, set_state=None, - init_vars=(0,)) + init_vars=(0,), + basic_symbol_names=('s',), + composite_symbol_names=(), + 
opts={}) self.assertEqual(s, (1234,)) def test_tf_dataset(self): @@ -149,7 +173,10 @@ class ForLoopTest(test.TestCase): body=lambda i, s: (s * 10 + i,), get_state=lambda: (), set_state=lambda _: None, - init_vars=(constant_op.constant(0, dtype=dtypes.int64),)) + init_vars=(constant_op.constant(0, dtype=dtypes.int64),), + basic_symbol_names=('s',), + composite_symbol_names=(), + opts={}) self.assertEqual(self.evaluate(s), (1234,)) def test_dataset_with_extra_test(self): @@ -159,7 +186,10 @@ class ForLoopTest(test.TestCase): body=lambda i, s: (s + i,), get_state=lambda: (), set_state=lambda _: None, - init_vars=(constant_op.constant(0, dtype=dtypes.int64),)) + init_vars=(constant_op.constant(0, dtype=dtypes.int64),), + basic_symbol_names=('s',), + composite_symbol_names=(), + opts={}) self.assertEqual(self.evaluate(s), (3,)) def test_dataset_with_extra_test_and_state(self): @@ -181,7 +211,10 @@ class ForLoopTest(test.TestCase): body=body, get_state=get_state, set_state=set_state, - init_vars=(constant_op.constant(0, dtype=dtypes.int64),)) + init_vars=(constant_op.constant(0, dtype=dtypes.int64),), + basic_symbol_names=('s',), + composite_symbol_names=(), + opts={}) self.assertEqual(self.evaluate(s), (3,)) self.assertEqual(self.evaluate(state[0]), (3,)) @@ -197,7 +230,10 @@ class ForLoopTest(test.TestCase): body=guarded_body, get_state=lambda: (), set_state=lambda _: None, - init_vars=(constant_op.constant(0, dtype=dtypes.int64),)) + init_vars=(constant_op.constant(0, dtype=dtypes.int64),), + basic_symbol_names=('s',), + composite_symbol_names=(), + opts={}) self.assertEqual(self.evaluate(s), (3,)) def test_tf_dataset_no_loop_vars(self): @@ -217,7 +253,10 @@ class ForLoopTest(test.TestCase): body=stateless_with_side_effects, get_state=lambda: (), set_state=lambda _: None, - init_vars=()) + init_vars=(), + basic_symbol_names=('i',), + composite_symbol_names=(), + opts={}) self.evaluate(test_fn()) self.assertEqual(self.evaluate(v.read_value()), 1234) @@ -233,7 +272,10 @@ class ForLoopTest(test.TestCase): body=lambda i, s: (s * 10 + i,), get_state=lambda: (), set_state=lambda _: None, - init_vars=(constant_op.constant(0, dtype=dtypes.int64),)) + init_vars=(constant_op.constant(0, dtype=dtypes.int64),), + basic_symbol_names=('s',), + composite_symbol_names=(), + opts={}) s, = test_fn() self.assertAllEqual(s, 1234) @@ -253,7 +295,10 @@ class ForLoopTest(test.TestCase): body=stateless_with_side_effects, get_state=lambda: (), set_state=lambda _: None, - init_vars=()) + init_vars=(), + basic_symbol_names=('i',), + composite_symbol_names=(), + opts={}) self.evaluate(test_fn()) self.assertEqual(self.evaluate(v.read_value()), 1234) @@ -265,7 +310,10 @@ class ForLoopTest(test.TestCase): body=lambda i, s: (s * 10 + i[0],), get_state=lambda: (), set_state=lambda _: None, - init_vars=(0,)) + init_vars=(0,), + basic_symbol_names=('s',), + composite_symbol_names=(), + opts={}) self.assertEqual(self.evaluate(s), (123,)) def test_tf_ragged_tensor_higher_dimensional(self): @@ -279,7 +327,10 @@ class ForLoopTest(test.TestCase): body=lambda i, s: (s * 10 + i[0][0],), get_state=lambda: (), set_state=lambda _: None, - init_vars=(0,)) + init_vars=(0,), + basic_symbol_names=('s',), + composite_symbol_names=(), + opts={}) self.assertEqual(self.evaluate(s), (12,)) def test_tf_ragged_tensor_no_loop_vars(self): @@ -298,7 +349,10 @@ class ForLoopTest(test.TestCase): body=stateless_with_side_effects, get_state=lambda: (), set_state=lambda _: None, - init_vars=()) + init_vars=(), + basic_symbol_names=(), + 
composite_symbol_names=(), + opts={}) self.evaluate(test_fn()) # Note: 123 = ((0*10 + 1)*10+2)*10+3 (first element of each row). @@ -315,7 +369,10 @@ class WhileLoopTest(test.TestCase): body=lambda i, s: (i + 1, s + i), get_state=lambda: (), set_state=lambda _: None, - init_vars=(0, 0)) + init_vars=(0, 0), + basic_symbol_names=('i', 's'), + composite_symbol_names=(), + opts={}) self.assertEqual((5, 10), self.evaluate(results)) def test_tensor_with_tf_side_effects_in_cond(self): @@ -334,7 +391,10 @@ class WhileLoopTest(test.TestCase): body=lambda i: (i + 1,), get_state=lambda: (), set_state=lambda _: None, - init_vars=(0,)) + init_vars=(0,), + basic_symbol_names=('i',), + composite_symbol_names=(), + opts={}) results = test_fn() @@ -364,7 +424,10 @@ class WhileLoopTest(test.TestCase): body=body, get_state=get_state, set_state=set_state, - init_vars=(0, 0)) + init_vars=(0, 0), + basic_symbol_names=('i',), + composite_symbol_names=(), + opts={}) self.assertEqual(self.evaluate(s), (5, 10)) self.assertEqual(self.evaluate(state.field), 10) @@ -375,7 +438,10 @@ class WhileLoopTest(test.TestCase): body=lambda i, s: (i + 1, s + i), get_state=lambda: (), set_state=lambda _: None, - init_vars=(0, constant_op.constant(0))) + init_vars=(0, constant_op.constant(0)), + basic_symbol_names=('i', 's'), + composite_symbol_names=(), + opts={}) result_i, result_s = results self.assertEqual(5, result_i) self.assertEqual(10, self.evaluate(result_s)) @@ -387,7 +453,10 @@ class WhileLoopTest(test.TestCase): body=lambda i, s: (i + 1, s + i), get_state=None, set_state=None, - init_vars=(0, 0)) + init_vars=(0, 0), + basic_symbol_names=('i', 's'), + composite_symbol_names=(), + opts={}) self.assertEqual((5, 10), results) def test_python_infinite_loop(self): @@ -399,7 +468,10 @@ class WhileLoopTest(test.TestCase): body=lambda i: (i + 1,), get_state=None, set_state=None, - init_vars=(0,)) + init_vars=(0,), + basic_symbol_names=('i',), + composite_symbol_names=(), + opts={}) def test_python_long_loop_unroll_warning(self): if __debug__: @@ -415,7 +487,10 @@ class WhileLoopTest(test.TestCase): body=lambda i, _: (i + 1, gen_math_ops.add(i, 1),), get_state=None, set_state=None, - init_vars=(0, None)) + init_vars=(0, None), + basic_symbol_names=('i',), + composite_symbol_names=(), + opts={}) self.assertTrue(re.match( r'.*ops.*loop.*large.*iterations.*Add.*', out_capturer.getvalue())) @@ -432,7 +507,9 @@ class IfStmtTest(test.TestCase): body=lambda: constant_op.constant(1), orelse=lambda: constant_op.constant(-1), get_state=lambda: (), - set_state=lambda _: None) + set_state=lambda _: None, + basic_symbol_names=('_',), + composite_symbol_names=()) self.assertEqual(1, self.evaluate(test_fn(constant_op.constant(True)))) self.assertEqual(-1, self.evaluate(test_fn(constant_op.constant(False)))) @@ -445,7 +522,9 @@ class IfStmtTest(test.TestCase): body=lambda: (constant_op.constant(1), constant_op.constant(2)), orelse=lambda: (constant_op.constant(-1), constant_op.constant(-2)), get_state=lambda: (), - set_state=lambda _: None) + set_state=lambda _: None, + basic_symbol_names=('_',), + composite_symbol_names=()) self.assertEqual((1, 2), self.evaluate(test_fn(constant_op.constant(True)))) self.assertEqual((-1, -2), @@ -459,7 +538,9 @@ class IfStmtTest(test.TestCase): body=lambda: 1, orelse=lambda: -1, get_state=lambda: (), - set_state=lambda _: None) + set_state=lambda _: None, + basic_symbol_names=('_',), + composite_symbol_names=()) self.assertEqual(1, test_fn(True)) self.assertEqual(-1, test_fn(False)) @@ -472,7 +553,9 @@ class 
IfStmtTest(test.TestCase): body=lambda: (1, 2), orelse=lambda: (-1, -2), get_state=lambda: (), - set_state=lambda _: None) + set_state=lambda _: None, + basic_symbol_names=('_',), + composite_symbol_names=()) self.assertEqual((1, 2), test_fn(True)) self.assertEqual((-1, -2), test_fn(False)) diff --git a/tensorflow/python/autograph/pyct/BUILD b/tensorflow/python/autograph/pyct/BUILD index f7ae813c41d..b9931236428 100644 --- a/tensorflow/python/autograph/pyct/BUILD +++ b/tensorflow/python/autograph/pyct/BUILD @@ -25,10 +25,10 @@ py_library( "anno.py", "ast_util.py", "cfg.py", - "compiler.py", "error_utils.py", "errors.py", "inspect_utils.py", + "loader.py", "origin_info.py", "parser.py", "pretty_printer.py", @@ -83,8 +83,8 @@ py_test( ) py_test( - name = "compiler_test", - srcs = ["compiler_test.py"], + name = "loader_test", + srcs = ["loader_test.py"], python_version = "PY3", srcs_version = "PY2AND3", deps = [ diff --git a/tensorflow/python/autograph/pyct/anno.py b/tensorflow/python/autograph/pyct/anno.py index e1f4af46cd7..a8ae864cd88 100644 --- a/tensorflow/python/autograph/pyct/anno.py +++ b/tensorflow/python/autograph/pyct/anno.py @@ -55,6 +55,8 @@ class Basic(NoValue): ' `name_map` allows renaming symbols.') ORIGIN = ('Information about the source code that converted code originated' ' from. See origin_information.py.') + DIRECTIVES = ('User directives associated with a statement or a variable.' + ' Typically, they affect the immediately-enclosing statement.') class Static(NoValue): diff --git a/tensorflow/python/autograph/pyct/ast_util_test.py b/tensorflow/python/autograph/pyct/ast_util_test.py index bc7c3f93ac5..7ed0f7b6b85 100644 --- a/tensorflow/python/autograph/pyct/ast_util_test.py +++ b/tensorflow/python/autograph/pyct/ast_util_test.py @@ -26,7 +26,7 @@ import gast from tensorflow.python.autograph.pyct import anno from tensorflow.python.autograph.pyct import ast_util -from tensorflow.python.autograph.pyct import compiler +from tensorflow.python.autograph.pyct import loader from tensorflow.python.autograph.pyct import parser from tensorflow.python.autograph.pyct import qual_names from tensorflow.python.platform import test @@ -39,28 +39,28 @@ class AstUtilTest(test.TestCase): self._invocation_counts = collections.defaultdict(lambda: 0) def test_rename_symbols_basic(self): - node = parser.parse_str('a + b') + node = parser.parse('a + b') node = qual_names.resolve(node) node = ast_util.rename_symbols( node, {qual_names.QN('a'): qual_names.QN('renamed_a')}) self.assertIsInstance(node.value.left.id, str) - source = compiler.ast_to_source(node, include_encoding_marker=False) + source = parser.unparse(node, include_encoding_marker=False) self.assertEqual(source.strip(), 'renamed_a + b') def test_rename_symbols_attributes(self): - node = parser.parse_str('b.c = b.c.d') + node = parser.parse('b.c = b.c.d') node = qual_names.resolve(node) node = ast_util.rename_symbols( node, {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')}) - source = compiler.ast_to_source(node, include_encoding_marker=False) + source = parser.unparse(node, include_encoding_marker=False) self.assertEqual(source.strip(), 'renamed_b_c = renamed_b_c.d') def test_rename_symbols_annotations(self): - node = parser.parse_str('a[i]') + node = parser.parse('a[i]') node = qual_names.resolve(node) anno.setanno(node, 'foo', 'bar') orig_anno = anno.getanno(node, 'foo') @@ -71,7 +71,7 @@ class AstUtilTest(test.TestCase): self.assertIs(anno.getanno(node, 'foo'), orig_anno) def test_copy_clean(self): - node = 
parser.parse_str( + node = parser.parse( textwrap.dedent(""" def f(a): return a + 1 @@ -82,7 +82,7 @@ class AstUtilTest(test.TestCase): self.assertFalse(hasattr(new_node, '__foo')) def test_copy_clean_preserves_annotations(self): - node = parser.parse_str( + node = parser.parse( textwrap.dedent(""" def f(a): return a + 1 @@ -98,9 +98,9 @@ class AstUtilTest(test.TestCase): d = ast_util.keywords_to_dict(keywords) # Make sure we generate a usable dict node by attaching it to a variable and # compiling everything. - node = parser.parse_str('def f(b): pass') + node = parser.parse('def f(b): pass') node.body.append(ast.Return(d)) - result, _, _ = compiler.ast_to_object(node) + result, _, _ = loader.load_ast(node) self.assertDictEqual(result.f(3), {'a': 3, 'c': 1, 'd': 'e'}) def assertMatch(self, target_str, pattern_str): @@ -131,12 +131,12 @@ class AstUtilTest(test.TestCase): 'super(Bar, _).__init__(_)') def _mock_apply_fn(self, target, source): - target = compiler.ast_to_source(target, include_encoding_marker=False) - source = compiler.ast_to_source(source, include_encoding_marker=False) + target = parser.unparse(target, include_encoding_marker=False) + source = parser.unparse(source, include_encoding_marker=False) self._invocation_counts[(target.strip(), source.strip())] += 1 def test_apply_to_single_assignments_dynamic_unpack(self): - node = parser.parse_str('a, b, c = d') + node = parser.parse('a, b, c = d') ast_util.apply_to_single_assignments(node.targets, node.value, self._mock_apply_fn) self.assertDictEqual(self._invocation_counts, { @@ -146,7 +146,7 @@ class AstUtilTest(test.TestCase): }) def test_apply_to_single_assignments_static_unpack(self): - node = parser.parse_str('a, b, c = d, e, f') + node = parser.parse('a, b, c = d, e, f') ast_util.apply_to_single_assignments(node.targets, node.value, self._mock_apply_fn) self.assertDictEqual(self._invocation_counts, { @@ -160,7 +160,7 @@ class AstUtilTest(test.TestCase): def f(a): return a + 1 """ - node = parser.parse_str(textwrap.dedent(src)) + node = parser.parse(textwrap.dedent(src)) for child_a, child_b in ast_util.parallel_walk(node, node): self.assertEqual(child_a, child_b) @@ -169,22 +169,22 @@ class AstUtilTest(test.TestCase): def f(a): global g """ - node = parser.parse_str(textwrap.dedent(src)) + node = parser.parse(textwrap.dedent(src)) for child_a, child_b in ast_util.parallel_walk(node, node): self.assertEqual(child_a, child_b) def test_parallel_walk_inconsistent_trees(self): - node_1 = parser.parse_str( + node_1 = parser.parse( textwrap.dedent(""" def f(a): return a + 1 """)) - node_2 = parser.parse_str( + node_2 = parser.parse( textwrap.dedent(""" def f(a): return a + (a * 2) """)) - node_3 = parser.parse_str( + node_3 = parser.parse( textwrap.dedent(""" def f(a): return a + 2 @@ -204,12 +204,11 @@ class AstUtilTest(test.TestCase): for node in matching_nodes: self.assertIsInstance(node, gast.Lambda) self.assertIn( - compiler.ast_to_source(node.body, - include_encoding_marker=False).strip(), + parser.unparse(node.body, include_encoding_marker=False).strip(), expected_bodies) def test_find_matching_definitions_lambda(self): - node = parser.parse_str( + node = parser.parse( textwrap.dedent(""" f = lambda x: 1 """)) @@ -218,7 +217,7 @@ class AstUtilTest(test.TestCase): self.assertLambdaNodes(nodes, ('(1)',)) def test_find_matching_definitions_lambda_multiple_matches(self): - node = parser.parse_str( + node = parser.parse( textwrap.dedent(""" f = lambda x: 1, lambda x: 2 """)) @@ -227,7 +226,7 @@ class AstUtilTest(test.TestCase): 
self.assertLambdaNodes(nodes, ('(1)', '(2)')) def test_find_matching_definitions_lambda_uses_arg_names(self): - node = parser.parse_str( + node = parser.parse( textwrap.dedent(""" f = lambda x: 1, lambda y: 2 """)) diff --git a/tensorflow/python/autograph/pyct/cfg.py b/tensorflow/python/autograph/pyct/cfg.py index ca3a8e55cf4..c2da09ef72b 100644 --- a/tensorflow/python/autograph/pyct/cfg.py +++ b/tensorflow/python/autograph/pyct/cfg.py @@ -38,7 +38,7 @@ from enum import Enum import gast # pylint:enable=g-bad-import-order -from tensorflow.python.autograph.pyct import compiler +from tensorflow.python.autograph.pyct import parser class Node(object): @@ -77,10 +77,9 @@ class Node(object): elif isinstance(self.ast_node, gast.ClassDef): return 'class %s' % self.ast_node.name elif isinstance(self.ast_node, gast.withitem): - return compiler.ast_to_source( + return parser.unparse( self.ast_node.context_expr, include_encoding_marker=False).strip() - return compiler.ast_to_source( - self.ast_node, include_encoding_marker=False).strip() + return parser.unparse(self.ast_node, include_encoding_marker=False).strip() class Graph( diff --git a/tensorflow/python/autograph/pyct/common_transformers/anf_test.py b/tensorflow/python/autograph/pyct/common_transformers/anf_test.py index f3bbba20925..e4a5a0accd5 100644 --- a/tensorflow/python/autograph/pyct/common_transformers/anf_test.py +++ b/tensorflow/python/autograph/pyct/common_transformers/anf_test.py @@ -22,7 +22,7 @@ import textwrap import gast -from tensorflow.python.autograph.pyct import compiler +from tensorflow.python.autograph.pyct import loader from tensorflow.python.autograph.pyct import parser from tensorflow.python.autograph.pyct import transformer from tensorflow.python.autograph.pyct.common_transformers import anf @@ -76,9 +76,9 @@ class AnfTestBase(test.TestCase): return transformer.Context(entity_info) def assert_same_ast(self, expected_node, node, msg=None): - expected_source = compiler.ast_to_source(expected_node, indentation=' ') + expected_source = parser.unparse(expected_node, indentation=' ') expected_str = textwrap.dedent(expected_source).strip() - got_source = compiler.ast_to_source(node, indentation=' ') + got_source = parser.unparse(node, indentation=' ') got_str = textwrap.dedent(got_source).strip() self.assertEqual(expected_str, got_str, msg=msg) @@ -112,7 +112,7 @@ class AnfTransformerTest(AnfTestBase): node, _ = parser.parse_entity(test_function, future_features=()) node = anf.transform(node, self._simple_context()) - result, _, _ = compiler.ast_to_object(node) + result, _, _ = loader.load_ast(node) self.assertEqual(test_function(), result.test_function()) def test_binop_basic(self): @@ -463,13 +463,13 @@ class AnfNonTransformationTest(AnfTransformerTest): # syntax highlights nicely, but Python doesn't try to execute the # statements. 
node, _ = parser.parse_entity(test_fn, future_features=()) - orig_source = compiler.ast_to_source(node, indentation=' ') + orig_source = parser.unparse(node, indentation=' ') orig_str = textwrap.dedent(orig_source).strip() config = [(anf.ANY, anf.LEAVE)] # Configuration to trasform nothing node = anf.transform( node, self._simple_context(), config=config, gensym_source=DummyGensym) - new_source = compiler.ast_to_source(node, indentation=' ') + new_source = parser.unparse(node, indentation=' ') new_str = textwrap.dedent(new_source).strip() self.assertEqual(orig_str, new_str) diff --git a/tensorflow/python/autograph/pyct/inspect_utils.py b/tensorflow/python/autograph/pyct/inspect_utils.py index 47c52d2e8bb..ca9a0c9ea5d 100644 --- a/tensorflow/python/autograph/pyct/inspect_utils.py +++ b/tensorflow/python/autograph/pyct/inspect_utils.py @@ -93,6 +93,26 @@ def isbuiltin(f): return False +def isconstructor(cls): + """Returns True if the argument is an object constructor. + + In general, any object of type class is a constructor, with the exception + of classes created using a callable metaclass. + See below for why a callable metaclass is not a trivial combination: + https://docs.python.org/2.7/reference/datamodel.html#customizing-class-creation + + Args: + cls: Any + Returns: + Bool + """ + return ( + inspect.isclass(cls) + and not (issubclass(cls.__class__, type) + and hasattr(cls.__class__, '__call__') + and cls.__class__.__call__ is not type.__call__)) + + def _fix_linecache_record(obj): """Fixes potential corruption of linecache in the presence of functools.wraps. @@ -351,4 +371,3 @@ def getfutureimports(entity): return tuple() return tuple(sorted(name for name, value in entity.__globals__.items() if getattr(value, '__module__', None) == '__future__')) - diff --git a/tensorflow/python/autograph/pyct/inspect_utils_test.py b/tensorflow/python/autograph/pyct/inspect_utils_test.py index f8bd427becc..93b7d8237c5 100644 --- a/tensorflow/python/autograph/pyct/inspect_utils_test.py +++ b/tensorflow/python/autograph/pyct/inspect_utils_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import abc import collections import functools import imp @@ -517,14 +518,14 @@ class InspectUtilsTest(test.TestCase): def baz(self): pass - self.assertTrue( - inspect_utils.getdefiningclass(Subclass.foo, Subclass) is Subclass) - self.assertTrue( - inspect_utils.getdefiningclass(Subclass.bar, Subclass) is Superclass) - self.assertTrue( - inspect_utils.getdefiningclass(Subclass.baz, Subclass) is Subclass) - self.assertTrue( - inspect_utils.getdefiningclass(Subclass.class_method, Subclass) is + self.assertIs( + inspect_utils.getdefiningclass(Subclass.foo, Subclass), Subclass) + self.assertIs( + inspect_utils.getdefiningclass(Subclass.bar, Subclass), Superclass) + self.assertIs( + inspect_utils.getdefiningclass(Subclass.baz, Subclass), Subclass) + self.assertIs( + inspect_utils.getdefiningclass(Subclass.class_method, Subclass), Superclass) def test_isbuiltin(self): @@ -537,6 +538,53 @@ class InspectUtilsTest(test.TestCase): self.assertTrue(inspect_utils.isbuiltin(zip)) self.assertFalse(inspect_utils.isbuiltin(function_decorator)) + def test_isconstructor(self): + + class OrdinaryClass(object): + pass + + class OrdinaryCallableClass(object): + + def __call__(self): + pass + + class Metaclass(type): + pass + + class CallableMetaclass(type): + + def __call__(cls): + pass + + self.assertTrue(inspect_utils.isconstructor(OrdinaryClass)) + 
self.assertTrue(inspect_utils.isconstructor(OrdinaryCallableClass)) + self.assertTrue(inspect_utils.isconstructor(Metaclass)) + self.assertTrue(inspect_utils.isconstructor(Metaclass('TestClass', (), {}))) + self.assertTrue(inspect_utils.isconstructor(CallableMetaclass)) + + self.assertFalse(inspect_utils.isconstructor( + CallableMetaclass('TestClass', (), {}))) + + def test_isconstructor_abc_callable(self): + + @six.add_metaclass(abc.ABCMeta) + class AbcBase(object): + + @abc.abstractmethod + def __call__(self): + pass + + class AbcSubclass(AbcBase): + + def __init__(self): + pass + + def __call__(self): + pass + + self.assertTrue(inspect_utils.isconstructor(AbcBase)) + self.assertTrue(inspect_utils.isconstructor(AbcSubclass)) + def test_getfutureimports_functions(self): self.assertEqual( inspect_utils.getfutureimports(basic_definitions.function_with_print), diff --git a/tensorflow/python/autograph/pyct/compiler.py b/tensorflow/python/autograph/pyct/loader.py similarity index 55% rename from tensorflow/python/autograph/pyct/compiler.py rename to tensorflow/python/autograph/pyct/loader.py index 297f28cfeaf..3690833b793 100644 --- a/tensorflow/python/autograph/pyct/compiler.py +++ b/tensorflow/python/autograph/pyct/loader.py @@ -29,64 +29,14 @@ import imp import os import tempfile -import astor -import gast import six from tensorflow.python.autograph.pyct import origin_info +from tensorflow.python.autograph.pyct import parser from tensorflow.python.autograph.utils import ag_logging -def ast_to_source(node, indentation=' ', include_encoding_marker=True): - """Return the source code of given AST. - - Args: - node: The code to compile, as an AST object. - indentation: The string to use for indentation. - include_encoding_marker: Bool, thether to include a comment on the first - line to explicitly specify UTF-8 encoding. - - Returns: - code: The source code generated from the AST object - source_mapping: A mapping between the user and AutoGraph generated code. - """ - if not isinstance(node, (list, tuple)): - node = (node,) - generator = astor.code_gen.SourceGenerator(indentation, False, - astor.string_repr.pretty_string) - - for n in node: - if isinstance(n, gast.AST): - n = gast.gast_to_ast(n) - generator.visit(n) - generator.result.append('\n') - - # In some versions of Python, literals may appear as actual values. This - # ensures everything is string. - code = ''.join(map(str, generator.result)) - - # Strip leading blank lines. - code_lines = code.split('\n') - trimmed_code_lines = [] - for l in code_lines: - if l.rstrip() or trimmed_code_lines: - trimmed_code_lines.append(l) - code = '\n'.join(trimmed_code_lines) - - # Work around the reference cycle generated by astor. - # See https://github.com/berkerpeksag/astor/blob/55dd323f7d8d696610c703c0296763c567685c31/astor/code_gen.py#L162 # pylint:disable=line-too-long - # Reference cycles are quite disliked by TensorFlow's tests. 
- if hasattr(generator, 'write'): - generator.write = None - del generator - - if include_encoding_marker: - code = '# coding=utf-8\n' + code - - return code - - -def source_to_entity(source, delete_on_exit): +def load_source(source, delete_on_exit): """Loads the given source code as a Python module.""" if six.PY2: source = source.encode('utf-8') @@ -104,23 +54,22 @@ def source_to_entity(source, delete_on_exit): return imp.load_source(module_name, f.name), f.name -# TODO(mdan): Rename: ast_to_entity -def ast_to_object(nodes, - indentation=' ', - include_source_map=False, - delete_on_exit=True): - """Return the Python objects represented by given AST. +def load_ast(nodes, + indentation=' ', + include_source_map=False, + delete_on_exit=True): + """Loads the given AST as a Python module. Compiling the AST code this way ensures that the source code is readable by e.g. `pdb` or `inspect`. Args: nodes: Union[ast.AST, Iterable[ast.AST]], the code to compile, as an AST - object. + object. indentation: Text, the string to use for indentation. include_source_map: bool, whether return a source map. delete_on_exit: bool, whether to delete the temporary file used for - compilation on exit. + compilation on exit. Returns: Tuple[module, Text, Dict[LineLocation, OriginInfo]], containing: @@ -131,8 +80,8 @@ def ast_to_object(nodes, if not isinstance(nodes, (list, tuple)): nodes = (nodes,) - source = ast_to_source(nodes, indentation=indentation) - module, _ = source_to_entity(source, delete_on_exit) + source = parser.unparse(nodes, indentation=indentation) + module, _ = load_source(source, delete_on_exit) if include_source_map: source_map = origin_info.create_source_map(nodes, source, module.__file__) diff --git a/tensorflow/python/autograph/pyct/compiler_test.py b/tensorflow/python/autograph/pyct/loader_test.py similarity index 71% rename from tensorflow/python/autograph/pyct/compiler_test.py rename to tensorflow/python/autograph/pyct/loader_test.py index 3be0060612a..da7e336c5bc 100644 --- a/tensorflow/python/autograph/pyct/compiler_test.py +++ b/tensorflow/python/autograph/pyct/loader_test.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Tests for compiler module.""" +"""Tests for loader module.""" from __future__ import absolute_import from __future__ import division @@ -23,15 +23,15 @@ import textwrap import gast -from tensorflow.python.autograph.pyct import compiler +from tensorflow.python.autograph.pyct import loader from tensorflow.python.autograph.pyct import parser from tensorflow.python.platform import test from tensorflow.python.util import tf_inspect -class CompilerTest(test.TestCase): +class LoaderTest(test.TestCase): - def test_parser_compile_identity(self): + def test_parse_load_identity(self): def test_fn(x): a = True @@ -41,37 +41,13 @@ class CompilerTest(test.TestCase): return b node, _ = parser.parse_entity(test_fn, future_features=()) - module, _, _ = compiler.ast_to_object(node) + module, _, _ = loader.load_ast(node) self.assertEqual( textwrap.dedent(tf_inspect.getsource(test_fn)), tf_inspect.getsource(module.test_fn)) - def test_ast_to_source(self): - node = gast.If( - test=gast.Num(1), - body=[ - gast.Assign( - targets=[gast.Name('a', gast.Store(), None)], - value=gast.Name('b', gast.Load(), None)) - ], - orelse=[ - gast.Assign( - targets=[gast.Name('a', gast.Store(), None)], - value=gast.Str('c')) - ]) - - source = compiler.ast_to_source(node, indentation=' ') - self.assertEqual( - textwrap.dedent(""" - # coding=utf-8 - if 1: - a = b - else: - a = 'c' - """).strip(), source.strip()) - - def test_ast_to_object(self): + def test_load_ast(self): node = gast.FunctionDef( name='f', args=gast.arguments( @@ -91,7 +67,7 @@ class CompilerTest(test.TestCase): decorator_list=[], returns=None) - module, source, _ = compiler.ast_to_object(node) + module, source, _ = loader.load_ast(node) expected_source = """ # coding=utf-8 @@ -107,14 +83,14 @@ class CompilerTest(test.TestCase): textwrap.dedent(expected_source).strip(), temp_output.read().strip()) - def test_source_to_entity(self): + def test_load_source(self): test_source = textwrap.dedent(u""" # coding=utf-8 def f(a): '日本語 Δθₜ ← Δθₜ₋₁ + ∇Q(sₜ, aₜ)(rₜ + γₜ₊₁ max Q(⋅))' return a + 1 """) - module, _ = compiler.source_to_entity(test_source, delete_on_exit=True) + module, _ = loader.load_source(test_source, delete_on_exit=True) self.assertEqual(module.f(1), 2) self.assertEqual( module.f.__doc__, '日本語 Δθₜ ← Δθₜ₋₁ + ∇Q(sₜ, aₜ)(rₜ + γₜ₊₁ max Q(⋅))') diff --git a/tensorflow/python/autograph/pyct/origin_info.py b/tensorflow/python/autograph/pyct/origin_info.py index 5479fefbb22..ae1d5e18334 100644 --- a/tensorflow/python/autograph/pyct/origin_info.py +++ b/tensorflow/python/autograph/pyct/origin_info.py @@ -102,7 +102,7 @@ def create_source_map(nodes, code, filepath): Dict[LineLocation, OriginInfo], mapping locations in code to locations indicated by origin annotations in node. 
""" - reparsed_nodes = parser.parse_str(code, preamble_len=0, single_node=False) + reparsed_nodes = parser.parse(code, preamble_len=0, single_node=False) for node in reparsed_nodes: resolve(node, code, filepath, node.lineno, node.col_offset) diff --git a/tensorflow/python/autograph/pyct/origin_info_test.py b/tensorflow/python/autograph/pyct/origin_info_test.py index 91c6ee5778f..01ded4cc559 100644 --- a/tensorflow/python/autograph/pyct/origin_info_test.py +++ b/tensorflow/python/autograph/pyct/origin_info_test.py @@ -39,7 +39,7 @@ class OriginInfoTest(test.TestCase): """ source = textwrap.dedent(source) - node = parser.parse_str(source) + node = parser.parse(source) fake_origin = origin_info.OriginInfo( loc=origin_info.Location('fake_filename', 3, 7), function_name='fake_function_name', @@ -118,7 +118,7 @@ class OriginInfoTest(test.TestCase): return x # comment """ source = textwrap.dedent(source) - node = parser.parse_str(source) + node = parser.parse(source) origin_info.resolve(node, source, 'test_file', 10, 10) def_origin = anno.getanno(node, anno.Basic.ORIGIN) diff --git a/tensorflow/python/autograph/pyct/parser.py b/tensorflow/python/autograph/pyct/parser.py index c5b2fe5832a..1b745fa4219 100644 --- a/tensorflow/python/autograph/pyct/parser.py +++ b/tensorflow/python/autograph/pyct/parser.py @@ -25,6 +25,7 @@ import re import textwrap import tokenize +import astor import gast import six @@ -109,7 +110,7 @@ def dedent_block(code_string): def _attempt_to_parse_normal_source(source, future_features): - return parse_str(source, preamble_len=len(future_features)), source + return parse(source, preamble_len=len(future_features)), source def _attempt_to_parse_lambda_source(source, original_source, @@ -131,17 +132,17 @@ def _attempt_to_parse_lambda_source(source, original_source, source: the processed source code of `entity`. original_source: the source code of `entity`, as it was reported by `inspect.getsource`. - future_features: see `parse_str`. + future_features: see `parse`. try_fallback: whether to attempt to remove extra code from `source` before one more attempt to parse it. Returns: - Same as `parse_str`. + Same as `parse`. """ try: - return parse_str(source, preamble_len=len(future_features)), source + return parse(source, preamble_len=len(future_features)), source - # Note: the ValueError may be raised by parse_str. + # Note: the ValueError may be raised by parse. except (SyntaxError, ValueError) as e: def fail(): raise errors.UnsupportedLanguageElementError( @@ -209,7 +210,7 @@ def parse_entity(entity, future_features): # TODO(mdan): This should take futures as input instead. -def parse_str(src, preamble_len=0, single_node=True): +def parse(src, preamble_len=0, single_node=True): """Returns the AST of given piece of code. Args: @@ -244,9 +245,58 @@ def parse_expression(src): ValueError: if src does not consist of a single Expression. """ src = STANDARD_PREAMBLE + src.strip() - node = parse_str(src, preamble_len=STANDARD_PREAMBLE_LEN, single_node=True) + node = parse(src, preamble_len=STANDARD_PREAMBLE_LEN, single_node=True) if __debug__: if not isinstance(node, gast.Expr): raise ValueError( 'expected a single expression, found instead {}'.format(node)) return node.value + + +def unparse(node, indentation=' ', include_encoding_marker=True): + """Returns the source code of given AST. + + Args: + node: The code to compile, as an AST object. + indentation: The string to use for indentation. 
+ include_encoding_marker: Bool, thether to include a comment on the first + line to explicitly specify UTF-8 encoding. + + Returns: + code: The source code generated from the AST object + source_mapping: A mapping between the user and AutoGraph generated code. + """ + if not isinstance(node, (list, tuple)): + node = (node,) + generator = astor.code_gen.SourceGenerator(indentation, False, + astor.string_repr.pretty_string) + + for n in node: + if isinstance(n, gast.AST): + n = gast.gast_to_ast(n) + generator.visit(n) + generator.result.append('\n') + + # In some versions of Python, literals may appear as actual values. This + # ensures everything is string. + code = ''.join(map(str, generator.result)) + + # Strip leading blank lines. + code_lines = code.split('\n') + trimmed_code_lines = [] + for l in code_lines: + if l.rstrip() or trimmed_code_lines: + trimmed_code_lines.append(l) + code = '\n'.join(trimmed_code_lines) + + # Work around the reference cycle generated by astor. + # See https://github.com/berkerpeksag/astor/blob/55dd323f7d8d696610c703c0296763c567685c31/astor/code_gen.py#L162 # pylint:disable=line-too-long + # Reference cycles are quite disliked by TensorFlow's tests. + if hasattr(generator, 'write'): + generator.write = None + del generator + + if include_encoding_marker: + code = '# coding=utf-8\n' + code + + return code diff --git a/tensorflow/python/autograph/pyct/parser_test.py b/tensorflow/python/autograph/pyct/parser_test.py index ef62d140525..f5c1dcb7021 100644 --- a/tensorflow/python/autograph/pyct/parser_test.py +++ b/tensorflow/python/autograph/pyct/parser_test.py @@ -18,6 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import textwrap + +import gast + from tensorflow.python.autograph.pyct import parser from tensorflow.python.platform import test @@ -130,6 +134,30 @@ string""") self.assertEqual('a', node.value.id) self.assertEqual('b', node.attr) + def test_unparse(self): + node = gast.If( + test=gast.Num(1), + body=[ + gast.Assign( + targets=[gast.Name('a', gast.Store(), None)], + value=gast.Name('b', gast.Load(), None)) + ], + orelse=[ + gast.Assign( + targets=[gast.Name('a', gast.Store(), None)], + value=gast.Str('c')) + ]) + + source = parser.unparse(node, indentation=' ') + self.assertEqual( + textwrap.dedent(""" + # coding=utf-8 + if 1: + a = b + else: + a = 'c' + """).strip(), source.strip()) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/autograph/pyct/qual_names_test.py b/tensorflow/python/autograph/pyct/qual_names_test.py index 48db7bd7fe0..f32bf19e946 100644 --- a/tensorflow/python/autograph/pyct/qual_names_test.py +++ b/tensorflow/python/autograph/pyct/qual_names_test.py @@ -192,7 +192,7 @@ class QNResolverTest(test.TestCase): [f, (g.h.i)] j(k, l) """ - nodes = parser.parse_str(textwrap.dedent(samples), single_node=False) + nodes = parser.parse(textwrap.dedent(samples), single_node=False) nodes = tuple(resolve(node).value for node in nodes) self.assertQNStringIs(nodes[0], 'a') @@ -218,7 +218,7 @@ class QNResolverTest(test.TestCase): a.b[c[d]].e.f a.b[c[d.e.f].g].h """ - nodes = parser.parse_str(textwrap.dedent(samples), single_node=False) + nodes = parser.parse(textwrap.dedent(samples), single_node=False) nodes = tuple(resolve(node).value for node in nodes) self.assertQNStringIs(nodes[0], 'x[i]') @@ -241,7 +241,7 @@ class QNResolverTest(test.TestCase): z[i]() z()[i] """ - nodes = parser.parse_str(textwrap.dedent(samples), single_node=False) + nodes = 
parser.parse(textwrap.dedent(samples), single_node=False) nodes = tuple(resolve(node).value for node in nodes) self.assertQNStringIs(nodes[0], 'a.b') self.assertQNStringIs(nodes[1].func, 'a.b') diff --git a/tensorflow/python/autograph/pyct/templates.py b/tensorflow/python/autograph/pyct/templates.py index 253e2943a12..165319ef02b 100644 --- a/tensorflow/python/autograph/pyct/templates.py +++ b/tensorflow/python/autograph/pyct/templates.py @@ -120,6 +120,7 @@ class ReplaceTransformer(gast.NodeTransformer): self.preserved_annos = { anno.Basic.ORIGIN, anno.Basic.SKIP_PROCESSING, + anno.Basic.DIRECTIVES, anno.Static.ORIG_DEFINITIONS, 'extra_test', 'function_context_name', @@ -260,7 +261,7 @@ def replace(template, **replacements): for k in replacements: replacements[k] = _convert_to_ast(replacements[k]) template_str = parser.STANDARD_PREAMBLE + textwrap.dedent(template) - nodes = parser.parse_str( + nodes = parser.parse( template_str, preamble_len=parser.STANDARD_PREAMBLE_LEN, single_node=False) diff --git a/tensorflow/python/autograph/pyct/templates_test.py b/tensorflow/python/autograph/pyct/templates_test.py index 5ed10d9c937..2085e555ff4 100644 --- a/tensorflow/python/autograph/pyct/templates_test.py +++ b/tensorflow/python/autograph/pyct/templates_test.py @@ -23,7 +23,7 @@ import imp from absl.testing import parameterized import gast -from tensorflow.python.autograph.pyct import compiler +from tensorflow.python.autograph.pyct import loader from tensorflow.python.autograph.pyct import parser from tensorflow.python.autograph.pyct import qual_names as qn from tensorflow.python.autograph.pyct import templates @@ -75,7 +75,7 @@ class TemplatesTest(test.TestCase, parameterized.TestCase): """ node = templates.replace(template, b=('a', 'c'))[0] - result, _, _ = compiler.ast_to_object(node) + result, _, _ = loader.load_ast(node) self.assertEqual((2, 3), result.test_fn(2, 3)) @@ -88,7 +88,7 @@ class TemplatesTest(test.TestCase, parameterized.TestCase): """ node = templates.replace(template, a='b')[0] - result, _, _ = compiler.ast_to_object(node) + result, _, _ = loader.load_ast(node) self.assertEqual(7, result.test_fn(2)) def test_replace_function_name(self): @@ -100,7 +100,7 @@ class TemplatesTest(test.TestCase, parameterized.TestCase): """ node = templates.replace(template, fname='test_fn')[0] - result, _, _ = compiler.ast_to_object(node) + result, _, _ = loader.load_ast(node) self.assertEqual(7, result.test_fn(2)) def test_replace_code_block(self): @@ -117,7 +117,7 @@ class TemplatesTest(test.TestCase, parameterized.TestCase): gast.Name('a', None, None) ], gast.BinOp(gast.Name('a', None, None), gast.Add(), gast.Num(1))), ] * 2)[0] - result, _, _ = compiler.ast_to_object(node) + result, _, _ = loader.load_ast(node) self.assertEqual(3, result.test_fn(1)) def test_replace_attribute(self): @@ -127,7 +127,7 @@ class TemplatesTest(test.TestCase, parameterized.TestCase): """ node = templates.replace(template, foo='b')[0] - result, _, _ = compiler.ast_to_object(node) + result, _, _ = loader.load_ast(node) mod = imp.new_module('test') mod.b = 3 self.assertEqual(3, result.test_fn(mod)) @@ -217,7 +217,7 @@ class TemplatesTest(test.TestCase, parameterized.TestCase): source = parser.parse_expression('f(d=3, f=5)') node = templates.replace(template, kws=source.keywords)[0] - result, _, _ = compiler.ast_to_object(node) + result, _, _ = loader.load_ast(node) self.assertEqual(9, result.test_fn()) with self.assertRaises(ValueError): @@ -237,7 +237,7 @@ class TemplatesTest(test.TestCase, parameterized.TestCase): 
source = parser.parse_expression('f()(b)') node = templates.replace(template, foo=source)[0] - result, _, _ = compiler.ast_to_object(node) + result, _, _ = loader.load_ast(node) self.assertEqual(15, result.test_fn()) def test_replace_name_with_dict(self): @@ -248,7 +248,7 @@ class TemplatesTest(test.TestCase, parameterized.TestCase): source = parser.parse_expression('{\'bar\': 3}') node = templates.replace(template, foo=source)[0] - result, _, _ = compiler.ast_to_object(node) + result, _, _ = loader.load_ast(node) self.assertEqual(3, result.test_fn()) def test_replace_as_expression(self): diff --git a/tensorflow/python/autograph/pyct/transformer.py b/tensorflow/python/autograph/pyct/transformer.py index 592ff0c45e6..ddc31737155 100644 --- a/tensorflow/python/autograph/pyct/transformer.py +++ b/tensorflow/python/autograph/pyct/transformer.py @@ -23,7 +23,7 @@ import collections import gast from tensorflow.python.autograph.pyct import anno -from tensorflow.python.autograph.pyct import compiler +from tensorflow.python.autograph.pyct import loader from tensorflow.python.autograph.pyct import pretty_printer from tensorflow.python.autograph.pyct import templates @@ -301,7 +301,7 @@ class Base(gast.NodeTransformer): def debug_print_src(self, node): """Helper method useful for debugging. Prints the AST as code.""" if __debug__: - print(compiler.ast_to_source(node)) + print(loader.load_ast(node)) return node def create_assignment(self, target, expression): @@ -436,7 +436,7 @@ class Base(gast.NodeTransformer): def _get_source(self, node): try: - source, _ = compiler.ast_to_source(node) + source, _ = loader.load_ast(node) return source # pylint: disable=broad-except # This function is used for error reporting. If an exception occurs here, diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 7c1d9f9f238..71427c9c237 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -31,7 +31,7 @@ from tensorflow.python.util.tf_export import tf_export # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. 
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2019, 11, 27) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2019, 12, 4) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None diff --git a/tensorflow/python/data/experimental/benchmarks/BUILD b/tensorflow/python/data/experimental/benchmarks/BUILD index 0540fa069d3..0759a345b29 100644 --- a/tensorflow/python/data/experimental/benchmarks/BUILD +++ b/tensorflow/python/data/experimental/benchmarks/BUILD @@ -7,6 +7,11 @@ package( exports_files(["LICENSE"]) +exports_files( + ["autotune_benchmark.py"], + visibility = ["//tensorflow:internal"], +) + tf_py_test( name = "autotune_benchmark", srcs = ["autotune_benchmark.py"], diff --git a/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py index 73e68ebcf42..5f13bdae849 100644 --- a/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py @@ -23,8 +23,8 @@ from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_ from tensorflow.python.data.experimental.ops import distribute from tensorflow.python.data.experimental.ops import distribute_options from tensorflow.python.data.experimental.ops import interleave_ops -from tensorflow.python.data.experimental.ops import testing from tensorflow.python.data.experimental.ops import readers +from tensorflow.python.data.experimental.ops import testing from tensorflow.python.data.experimental.ops import unique from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops diff --git a/tensorflow/python/data/experimental/kernel_tests/matching_files_test.py b/tensorflow/python/data/experimental/kernel_tests/matching_files_test.py index 1bf07a98f28..1240b704119 100644 --- a/tensorflow/python/data/experimental/kernel_tests/matching_files_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/matching_files_test.py @@ -34,10 +34,12 @@ class MatchingFilesDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): def setUp(self): + super(MatchingFilesDatasetTest, self).setUp() self.tmp_dir = tempfile.mkdtemp() def tearDown(self): shutil.rmtree(self.tmp_dir, ignore_errors=True) + super(MatchingFilesDatasetTest, self).tearDown() def _touchTempFiles(self, filenames): for filename in filenames: diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD index 1e9d1ca1d00..4cd2a3d1fcd 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD @@ -100,6 +100,24 @@ tf_py_test( ], ) +tf_py_test( + name = "inject_prefetch_test", + size = "small", + srcs = ["inject_prefetch_test.py"], + additional_deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python/data/experimental/ops:testing", + "//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python/data/ops:dataset_ops", + ], + tags = [ + "no_oss", + "no_pip", + "no_windows", + ], +) + tf_py_test( name = "latency_all_edges_test", size = "small", diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/choose_fastest_branch_dataset_test.py 
b/tensorflow/python/data/experimental/kernel_tests/optimization/choose_fastest_branch_dataset_test.py index ee05ae0603d..bb7849fb213 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/choose_fastest_branch_dataset_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/choose_fastest_branch_dataset_test.py @@ -23,18 +23,18 @@ from tensorflow.python.data.experimental.ops import optimization from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context +from tensorflow.python.framework import combinations from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.framework import test_util from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -@test_util.run_all_in_graph_and_eager_modes class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): + @combinations.generate(test_base.default_test_combinations()) def testSimple(self): dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2, 3, 4]) @@ -49,6 +49,7 @@ class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase, expected_output=[0, 1, 2, 3, 4], expected_shapes=dataset_ops.get_legacy_output_shapes(dataset)) + @combinations.generate(test_base.default_test_combinations()) def testCaptureSimple(self): dataset = dataset_ops.Dataset.range(10) @@ -67,6 +68,7 @@ class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase, self.assertDatasetProduces( choose_fastest, expected_output=list(range(1, 11))) + @combinations.generate(test_base.default_test_combinations()) def testDifferentFunctions(self): dataset = dataset_ops.Dataset.range(100) @@ -83,6 +85,7 @@ class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase, choose_fastest, expected_output=[list(range(10 * x, 10 * x + 10)) for x in range(10)]) + @combinations.generate(test_base.default_test_combinations()) def testWithRepeatBeforeAndAfter(self): dataset = dataset_ops.Dataset.from_tensors(0).repeat(10) @@ -99,6 +102,7 @@ class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase, self.assertDatasetProduces( choose_fastest, expected_output=[[0] * 10 for _ in range(10)]) + @combinations.generate(test_base.default_test_combinations()) def testWithPrefetch(self): """Should maintain ordering even if the branches do prefetching.""" dataset = dataset_ops.Dataset.range(100) @@ -114,6 +118,7 @@ class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase, self.assertDatasetProduces(choose_fastest, expected_output=list(range(100))) + @combinations.generate(test_base.default_test_combinations()) def testWithMoreOutputThanInput(self): dataset = dataset_ops.Dataset.from_tensors(0).repeat(1000).batch(100) @@ -128,6 +133,7 @@ class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase, self.assertDatasetProduces(choose_fastest, expected_output=[0] * 1000) + @combinations.generate(test_base.default_test_combinations()) def testWithBadNumElements(self): dataset = dataset_ops.Dataset.from_tensors(0).repeat(1000).batch(100) @@ -153,6 +159,7 @@ class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase, choose_fastest, expected_error=(errors.InvalidArgumentError, expected_error_msg)) + @combinations.generate(test_base.default_test_combinations()) def testErrorWithRepeat(self): dataset = dataset_ops.Dataset.from_tensors(0) diff --git 
a/tensorflow/python/data/experimental/kernel_tests/optimization/choose_fastest_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/choose_fastest_dataset_test.py index 3e51de9f1ee..6e0d9842c48 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/choose_fastest_dataset_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/choose_fastest_dataset_test.py @@ -23,15 +23,15 @@ from tensorflow.python.data.experimental.ops import optimization from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context +from tensorflow.python.framework import combinations from tensorflow.python.framework import errors -from tensorflow.python.framework import test_util from tensorflow.python.platform import test -@test_util.run_all_in_graph_and_eager_modes class ChooseFastestDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): + @combinations.generate(test_base.default_test_combinations()) def testChooseFastestSimple(self): dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2, 3, 4]) merge = optimization._ChooseFastestDataset([dataset, dataset]) @@ -40,6 +40,7 @@ class ChooseFastestDatasetTest(test_base.DatasetTestBase, expected_output=[0, 1, 2, 3, 4], expected_shapes=dataset_ops.get_legacy_output_shapes(dataset)) + @combinations.generate(test_base.default_test_combinations()) def testChooseFastestManyInputs(self): dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2, 3, 4]) merge = optimization._ChooseFastestDataset([dataset for _ in range(5)]) @@ -48,6 +49,7 @@ class ChooseFastestDatasetTest(test_base.DatasetTestBase, expected_output=[0, 1, 2, 3, 4], expected_shapes=dataset_ops.get_legacy_output_shapes(dataset)) + @combinations.generate(test_base.default_test_combinations()) def testChooseFastest(self): dataset = dataset_ops.Dataset.range(600) f = lambda x: 2 * x @@ -61,11 +63,25 @@ class ChooseFastestDatasetTest(test_base.DatasetTestBase, ], expected_shapes=dataset_ops.get_legacy_output_shapes(dataset_a)) - @parameterized.named_parameters( - ("Shapes", [0], [[1, 2, 3]], "must have compatible output shapes."), - ("Types", [0], [0.0], "must have the same output types."), - ("NumComponents", [0], ([0], [1]), "must have the same output types."), - ("Cardinality", [1, 2, 3], [1], "must have compatible cardinalities.")) + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine( + slices_a=[[0]], + slices_b=[[[1, 2, 3]]], + error_msg="must have compatible output shapes.") + + combinations.combine( + slices_a=[[0]], + slices_b=[[0.0]], + error_msg="must have the same output types.") + + combinations.combine( + slices_a=[[0]], + slices_b=[([0], [1])], + error_msg="must have the same output types.") + + combinations.combine( + slices_a=[[1, 2, 3]], + slices_b=[[0]], + error_msg="must have compatible cardinalities."))) def testChooseFastestErrorWithIncompatibleInput(self, slices_a, slices_b, error_msg): dataset_a = dataset_ops.Dataset.from_tensor_slices(slices_a) diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/filter_fusion_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/filter_fusion_test.py index 1aa3d636f02..949f9e2e25c 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/filter_fusion_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/filter_fusion_test.py @@ -22,47 +22,16 @@ from 
absl.testing import parameterized from tensorflow.python.data.experimental.ops import testing from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import combinations from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import test_util from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -def _filter_fusion_test_cases(): - """Generates test cases for the FilterFusion optimization.""" - - take_all = lambda x: constant_op.constant(True) - is_zero = lambda x: math_ops.equal(x, 0) - greater = lambda x: math_ops.greater(x + 5, 0) - - tests = [] - filters = [take_all, is_zero, greater] - identity = lambda x: x - for x, predicate_1 in enumerate(filters): - for y, predicate_2 in enumerate(filters): - tests.append(("Mixed{}{}".format(x, y), identity, - [predicate_1, predicate_2])) - for z, predicate_3 in enumerate(filters): - tests.append(("Mixed{}{}{}".format(x, y, z), identity, - [predicate_1, predicate_2, predicate_3])) - - take_all_multiple = lambda x, y: constant_op.constant(True) - # Multi output - tests.append(("Multi1", lambda x: (x, x), - [take_all_multiple, take_all_multiple])) - tests.append(("Multi2", lambda x: (x, 2), [ - take_all_multiple, - lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0) - ])) - return tuple(tests) - - -@test_util.run_all_in_graph_and_eager_modes class FilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase): - @parameterized.named_parameters(*_filter_fusion_test_cases()) - def testFilterFusion(self, map_function, predicates): + def _testFilterFusion(self, map_function, predicates): dataset = dataset_ops.Dataset.range(5).apply( testing.assert_next(["Map", "Filter", "MemoryCacheImpl"])).map(map_function) @@ -91,6 +60,26 @@ class FilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase): expected_output.append(r) self.assertDatasetProduces(dataset, expected_output=expected_output) + @combinations.generate(test_base.default_test_combinations()) + def testFilterFusionScalar(self): + take_all = lambda x: constant_op.constant(True) + is_zero = lambda x: math_ops.equal(x, 0) + greater = lambda x: math_ops.greater(x + 5, 0) + predicates = [take_all, is_zero, greater] + for x in predicates: + for y in predicates: + self._testFilterFusion(lambda x: x, [x, y]) + for z in predicates: + self._testFilterFusion(lambda x: x, [x, y, z]) + + @combinations.generate(test_base.default_test_combinations()) + def testFilterFusionTuple(self): + take_all = lambda x, y: constant_op.constant(True) + is_zero = lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0) + + self._testFilterFusion(lambda x: (x, x), [take_all, take_all]) + self._testFilterFusion(lambda x: (x, 2), [take_all, is_zero]) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/filter_with_random_uniform_fusion_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/filter_with_random_uniform_fusion_test.py index 2b130f40fc9..76006252367 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/filter_with_random_uniform_fusion_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/filter_with_random_uniform_fusion_test.py @@ -17,17 +17,20 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from 
absl.testing import parameterized + from tensorflow.python.data.experimental.ops import testing from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import test_util +from tensorflow.python.framework import combinations from tensorflow.python.ops import random_ops from tensorflow.python.platform import test -@test_util.run_all_in_graph_and_eager_modes -class FilterWithRandomUniformFusionTest(test_base.DatasetTestBase): +class FilterWithRandomUniformFusionTest(test_base.DatasetTestBase, + parameterized.TestCase): + @combinations.generate(test_base.default_test_combinations()) def testFilterWithRandomUniformFusion(self): dataset = dataset_ops.Dataset.range(10000000).apply( testing.assert_next(["Sampling"])) diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/hoist_random_uniform_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/hoist_random_uniform_test.py index 928b435fe5c..59f50fa1752 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/hoist_random_uniform_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/hoist_random_uniform_test.py @@ -22,44 +22,17 @@ from absl.testing import parameterized from tensorflow.python.data.experimental.ops import testing from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import combinations from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.platform import test -def _hoist_random_uniform_test_cases(): - """Generates test cases for the HoistRandomUniform optimization.""" - - plus_one = lambda x: x + 1 - - def random(_): - return random_ops.random_uniform([], - minval=1, - maxval=10, - dtype=dtypes.float32, - seed=42) - - def random_with_assert(x): - y = random(x) - assert_op = control_flow_ops.Assert(math_ops.greater_equal(y, 1), [y]) - with ops.control_dependencies([assert_op]): - return y - - twice_random = lambda x: (random(x) + random(x)) / 2. 
- - tests = [("PlusOne", plus_one, False), ("RandomUniform", random, True), - ("RandomWithAssert", random_with_assert, True), - ("TwiceRandom", twice_random, False)] - return tuple(tests) - - -@test_util.run_all_in_graph_and_eager_modes class HoistRandomUniformTest(test_base.DatasetTestBase, parameterized.TestCase): def _testDataset(self, dataset): @@ -78,11 +51,10 @@ class HoistRandomUniformTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) - @parameterized.named_parameters(*_hoist_random_uniform_test_cases()) - def testHoisting(self, function, will_optimize): + def _testHoistFunction(self, function, should_optimize): dataset = dataset_ops.Dataset.range(5).apply( testing.assert_next( - ["Zip[0]", "Map"] if will_optimize else ["Map"])).map(function) + ["Zip[0]", "Map"] if should_optimize else ["Map"])).map(function) options = dataset_ops.Options() options.experimental_optimization.apply_default_optimizations = False @@ -90,6 +62,32 @@ class HoistRandomUniformTest(test_base.DatasetTestBase, parameterized.TestCase): dataset = dataset.with_options(options) self._testDataset(dataset) + @combinations.generate(test_base.default_test_combinations()) + def testNoRandom(self): + self._testHoistFunction(lambda x: x + 1, should_optimize=False) + + @combinations.generate(test_base.default_test_combinations()) + def testRandom(self): + + def random(_): + return random_ops.random_uniform([], + minval=1, + maxval=10, + dtype=dtypes.float32, + seed=42) + + def random_with_assert(x): + y = random(x) + assert_op = control_flow_ops.Assert(math_ops.greater_equal(y, 1), [y]) + with ops.control_dependencies([assert_op]): + return y + + self._testHoistFunction(random, should_optimize=True) + self._testHoistFunction(random_with_assert, should_optimize=True) + self._testHoistFunction( + lambda x: (random(x) + random(x)) / 2, should_optimize=False) + + @combinations.generate(test_base.default_test_combinations()) def testCapturedInputs(self): a = constant_op.constant(1, dtype=dtypes.float32) b = constant_op.constant(0, dtype=dtypes.float32) diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/inject_prefetch_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/inject_prefetch_test.py index 89f61f141b0..d1a45d7328e 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/inject_prefetch_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/inject_prefetch_test.py @@ -17,35 +17,38 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.data.experimental.ops import optimization +from absl.testing import parameterized + +from tensorflow.python.data.experimental.ops import testing from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import test_util +from tensorflow.python.framework import combinations from tensorflow.python.platform import test -@test_util.run_all_in_graph_and_eager_modes -class InjectPrefetchTest(test_base.DatasetTestBase): +class InjectPrefetchTest(test_base.DatasetTestBase, parameterized.TestCase): def _enable_autotune_buffers(self, dataset): options = dataset_ops.Options() options.experimental_optimization.autotune_buffers = True return dataset.with_options(options) + @combinations.generate(test_base.default_test_combinations()) def testParallelMap(self): dataset = 
dataset_ops.Dataset.range(100) dataset = dataset.apply( - optimization.assert_next(["ParallelMap", "Prefetch", "FiniteTake"])) + testing.assert_next(["ParallelMap", "Prefetch", "FiniteTake"])) dataset = dataset.map( lambda x: x + 1, num_parallel_calls=dataset_ops.AUTOTUNE) dataset = dataset.take(50) dataset = self._enable_autotune_buffers(dataset) self.assertDatasetProduces(dataset, range(1, 51)) + @combinations.generate(test_base.default_test_combinations()) def testMapAndBatch(self): dataset = dataset_ops.Dataset.range(100) dataset = dataset.apply( - optimization.assert_next(["MapAndBatch", "Prefetch", "FiniteTake"])) + testing.assert_next(["MapAndBatch", "Prefetch", "FiniteTake"])) dataset = dataset.map( lambda x: x + 1, num_parallel_calls=dataset_ops.AUTOTUNE) dataset = dataset.batch(10) @@ -54,10 +57,11 @@ class InjectPrefetchTest(test_base.DatasetTestBase): self.assertDatasetProduces( dataset, [list(range(i + 1, i + 11)) for i in range(0, 50, 10)]) + @combinations.generate(test_base.default_test_combinations()) def testParallelInterleaveV2(self): dataset = dataset_ops.Dataset.range(100) dataset = dataset.apply( - optimization.assert_next( + testing.assert_next( ["ParallelInterleaveV2", "Prefetch", "FiniteTake"])) dataset = dataset.interleave( lambda x: dataset_ops.Dataset.from_tensors(x + 1), @@ -66,10 +70,11 @@ class InjectPrefetchTest(test_base.DatasetTestBase): dataset = self._enable_autotune_buffers(dataset) self.assertDatasetProduces(dataset, range(1, 51)) + @combinations.generate(test_base.default_test_combinations()) def testChainedParallelDatasets(self): dataset = dataset_ops.Dataset.range(100) dataset = dataset.apply( - optimization.assert_next([ + testing.assert_next([ "ParallelMap", "Prefetch", "ParallelInterleaveV2", "Prefetch", "MapAndBatch", "Prefetch", "FiniteTake" ])) @@ -85,9 +90,10 @@ class InjectPrefetchTest(test_base.DatasetTestBase): dataset = self._enable_autotune_buffers(dataset) self.assertDatasetProduces(dataset, [[i] for i in range(3, 53)]) + @combinations.generate(test_base.default_test_combinations()) def testNoRegularMap(self): dataset = dataset_ops.Dataset.range(100) - dataset = dataset.apply(optimization.assert_next(["Map", "FiniteTake"])) + dataset = dataset.apply(testing.assert_next(["Map", "FiniteTake"])) dataset = dataset.map(lambda x: x + 1).take(50) dataset = self._enable_autotune_buffers(dataset) self.assertDatasetProduces(dataset, range(1, 51)) diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/latency_all_edges_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/latency_all_edges_test.py index f6e5111cf32..d9ebc1cc719 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/latency_all_edges_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/latency_all_edges_test.py @@ -17,15 +17,22 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized + from tensorflow.python.data.experimental.kernel_tests import stats_dataset_test_base -from tensorflow.python.data.experimental.ops import testing from tensorflow.python.data.experimental.ops import stats_aggregator +from tensorflow.python.data.experimental.ops import testing +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import combinations from tensorflow.python.platform import test -class 
LatencyAllEdgesTest(stats_dataset_test_base.StatsDatasetTestBase): +class LatencyAllEdgesTest(stats_dataset_test_base.StatsDatasetTestBase, + parameterized.TestCase): + # TODO(jsimsa): Investigate why are graph-mode tests failing. + @combinations.generate(test_base.eager_only_combinations()) def testLatencyStatsOptimization(self): aggregator = stats_aggregator.StatsAggregator() dataset = dataset_ops.Dataset.from_tensors(1).apply( diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_and_batch_fusion_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_and_batch_fusion_test.py index c7e6fbbf377..622b6ca5671 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/map_and_batch_fusion_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_and_batch_fusion_test.py @@ -17,16 +17,18 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized + from tensorflow.python.data.experimental.ops import testing from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import test_util +from tensorflow.python.framework import combinations from tensorflow.python.platform import test -@test_util.run_all_in_graph_and_eager_modes -class MapAndBatchFusionTest(test_base.DatasetTestBase): +class MapAndBatchFusionTest(test_base.DatasetTestBase, parameterized.TestCase): + @combinations.generate(test_base.default_test_combinations()) def testMapAndBatchFusion(self): dataset = dataset_ops.Dataset.range(10).apply( testing.assert_next( diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_and_filter_fusion_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_and_filter_fusion_test.py index 1e53b4394ae..a0257f76e93 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/map_and_filter_fusion_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_and_filter_fusion_test.py @@ -22,50 +22,16 @@ from absl.testing import parameterized from tensorflow.python.data.experimental.ops import testing from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import combinations from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import test_util from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -def _map_and_filter_fusion_test_cases(): - """Generates test cases for the MapAndFilterFusion optimization.""" - - identity = lambda x: x - increment = lambda x: x + 1 - minus_five = lambda x: x - 5 - - def increment_and_square(x): - y = x + 1 - return y * y - - take_all = lambda x: constant_op.constant(True) - is_zero = lambda x: math_ops.equal(x, 0) - is_odd = lambda x: math_ops.equal(x % 2, 0) - greater = lambda x: math_ops.greater(x + 5, 0) - - functions = [identity, increment, minus_five, increment_and_square] - filters = [take_all, is_zero, is_odd, greater] - tests = [] - - for x, fun in enumerate(functions): - for y, predicate in enumerate(filters): - tests.append(("Mixed{}{}".format(x, y), fun, predicate)) - - # Multi output - tests.append(("Multi1", lambda x: (x, x), - lambda x, y: constant_op.constant(True))) - tests.append( - ("Multi2", lambda x: (x, 2), - lambda x, y: 
math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0))) - return tuple(tests) - - -@test_util.run_all_in_graph_and_eager_modes class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase): - def _testMapAndFilter(self, dataset, function, predicate): + def _testDataset(self, dataset, function, predicate): expected_output = [] for x in range(10): r = function(x) @@ -77,8 +43,7 @@ class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase): expected_output.append(r) self.assertDatasetProduces(dataset, expected_output=expected_output) - @parameterized.named_parameters(*_map_and_filter_fusion_test_cases()) - def testMapFilterFusion(self, function, predicate): + def _testMapAndFilterFusion(self, function, predicate): dataset = dataset_ops.Dataset.range(10).apply( testing.assert_next(["Map", "Filter", "Map"])).map(function).filter(predicate) @@ -86,8 +51,44 @@ class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase): options.experimental_optimization.apply_default_optimizations = False options.experimental_optimization.map_and_filter_fusion = True dataset = dataset.with_options(options) - self._testMapAndFilter(dataset, function, predicate) + self._testDataset(dataset, function, predicate) + @combinations.generate(test_base.default_test_combinations()) + def testMapAndFilterFusionScalar(self): + identity = lambda x: x + increment = lambda x: x + 1 + minus_five = lambda x: x - 5 + + def increment_and_square(x): + y = x + 1 + return y * y + + functions = [identity, increment, minus_five, increment_and_square] + + take_all = lambda x: constant_op.constant(True) + is_zero = lambda x: math_ops.equal(x, 0) + is_odd = lambda x: math_ops.equal(x % 2, 0) + greater = lambda x: math_ops.greater(x + 5, 0) + predicates = [take_all, is_zero, is_odd, greater] + + for function in functions: + for predicate in predicates: + self._testMapAndFilterFusion(function, predicate) + + @combinations.generate(test_base.default_test_combinations()) + def testMapAndFilterFusionTuple(self): + replicate = lambda x: (x, x) + with_two = lambda x: (x, 2) + functions = [replicate, with_two] + take_all = lambda x, y: constant_op.constant(True) + is_zero = lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0) + predicates = [take_all, is_zero] + + for function in functions: + for predicate in predicates: + self._testMapAndFilterFusion(function, predicate) + + @combinations.generate(test_base.default_test_combinations()) def testCapturedInputs(self): a = constant_op.constant(3, dtype=dtypes.int64) b = constant_op.constant(4, dtype=dtypes.int64) @@ -104,7 +105,7 @@ class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase): options.experimental_optimization.apply_default_optimizations = False options.experimental_optimization.map_and_filter_fusion = True dataset = dataset.with_options(options) - self._testMapAndFilter(dataset, function, predicate) + self._testDataset(dataset, function, predicate) if __name__ == "__main__": diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_fusion_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_fusion_test.py index 10f27dc277f..28da0474bc9 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/map_fusion_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_fusion_test.py @@ -22,51 +22,13 @@ from absl.testing import parameterized from tensorflow.python.data.experimental.ops import testing from 
tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import test_util +from tensorflow.python.framework import combinations from tensorflow.python.platform import test -def _map_fusion_test_cases(): - """Generates test cases for the MapFusion optimization.""" - - identity = lambda x: x - increment = lambda x: x + 1 - - def increment_and_square(x): - y = x + 1 - return y * y - - functions = [identity, increment, increment_and_square] - tests = [] - for i, fun1 in enumerate(functions): - for j, fun2 in enumerate(functions): - tests.append(( - "Test{}{}".format(i, j), - [fun1, fun2], - )) - for k, fun3 in enumerate(functions): - tests.append(( - "Test{}{}{}".format(i, j, k), - [fun1, fun2, fun3], - )) - - swap = lambda x, n: (n, x) - tests.append(( - "Swap1", - [lambda x: (x, 42), swap], - )) - tests.append(( - "Swap2", - [lambda x: (x, 42), swap, swap], - )) - return tuple(tests) - - -@test_util.run_all_in_graph_and_eager_modes class MapFusionTest(test_base.DatasetTestBase, parameterized.TestCase): - @parameterized.named_parameters(*_map_fusion_test_cases()) - def testMapFusion(self, functions): + def _testMapFusion(self, functions): dataset = dataset_ops.Dataset.range(5).apply( testing.assert_next(["Map", "MemoryCacheImpl"])) for function in functions: @@ -88,6 +50,31 @@ class MapFusionTest(test_base.DatasetTestBase, parameterized.TestCase): expected_output.append(r) self.assertDatasetProduces(dataset, expected_output=expected_output) + @combinations.generate(test_base.default_test_combinations()) + def testMapFusionScalar(self): + identity = lambda x: x + increment = lambda x: x + 1 + + def increment_and_square(x): + y = x + 1 + return y * y + + functions = [identity, increment, increment_and_square] + + for x in functions: + for y in functions: + self._testMapFusion([x, y]) + for z in functions: + self._testMapFusion([x, y, z]) + + @combinations.generate(test_base.default_test_combinations()) + def testMapAndFilterFusionTuple(self): + with_42 = lambda x: (x, 42) + swap = lambda x, y: (y, x) + + self._testMapFusion([with_42, swap]) + self._testMapFusion([with_42, swap, swap]) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_parallelization_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_parallelization_test.py index 668ab28c64c..a28a3052abc 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/map_parallelization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_parallelization_test.py @@ -22,38 +22,20 @@ from absl.testing import parameterized from tensorflow.python.data.experimental.ops import testing from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import combinations from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test -def _map_parallelization_test_cases(): - """Generates test cases for the MapParallelization optimization.""" - - identity = lambda x: x - increment = lambda x: x + 1 - - def assert_greater(x): - assert_op = 
control_flow_ops.Assert(math_ops.greater(x, -1), [x]) - with ops.control_dependencies([assert_op]): - return x - - return (("Identity", identity, True), - ("Increment", increment, True), - ("AssertGreater", assert_greater, True)) - - -@test_util.run_all_in_graph_and_eager_modes class MapParallelizationTest(test_base.DatasetTestBase, parameterized.TestCase): - @parameterized.named_parameters(*_map_parallelization_test_cases()) - def testMapParallelization(self, function, should_be_parallel): - next_nodes = ["ParallelMap"] if should_be_parallel else ["Map"] + def _testMapParallelization(self, function, should_optimize): + next_nodes = ["ParallelMap"] if should_optimize else ["Map"] dataset = dataset_ops.Dataset.range(5).apply( testing.assert_next(next_nodes)).map(function) options = dataset_ops.Options() @@ -63,9 +45,26 @@ class MapParallelizationTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertDatasetProduces( dataset, expected_output=[function(x) for x in range(5)]) - def testMapParallelizationWithCapturedConstant(self): - """Tests that functions with captured constants are parallelized.""" + @combinations.generate(test_base.default_test_combinations()) + def testIdentity(self): + self._testMapParallelization(lambda x: x, should_optimize=True) + @combinations.generate(test_base.default_test_combinations()) + def testIncrement(self): + self._testMapParallelization(lambda x: x + 1, should_optimize=True) + + @combinations.generate(test_base.default_test_combinations()) + def testAssert(self): + + def assert_greater(x): + assert_op = control_flow_ops.Assert(math_ops.greater(x, -1), [x]) + with ops.control_dependencies([assert_op]): + return x + + self._testMapParallelization(assert_greater, should_optimize=True) + + @combinations.generate(test_base.default_test_combinations()) + def testCapturedConstant(self): captured_t = constant_op.constant(42, dtype=dtypes.int64) def fn(x): return x + captured_t @@ -78,9 +77,8 @@ class MapParallelizationTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertDatasetProduces( dataset, expected_output=[x + 42 for x in range(5)]) - def testMapParallelizationWithCapturedVariable(self): - """Tests that functions with captured variables are not parallelized.""" - + @combinations.generate(test_base.default_test_combinations()) + def testCapturedVariable(self): captured_t = variables.Variable(42, dtype=dtypes.int64) def fn(x): return x + captured_t diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py index f17d863e555..4569f171f75 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools + from absl.testing import parameterized import numpy as np @@ -26,12 +28,12 @@ from tensorflow.python.data.experimental.ops import batching from tensorflow.python.data.experimental.ops import testing from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import combinations from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework 
import ops from tensorflow.python.framework import sparse_tensor -from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import bitwise_ops from tensorflow.python.ops import check_ops @@ -43,21 +45,45 @@ from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import test -def _generate_unary_cwise_math_cases(): - # TODO(rachelim): Consolidate tests with pfor when APIs are somewhat shared. - bitwise_cases = [("Invert", bitwise_ops.invert)] - logical_cases = [("LogicalNot", math_ops.logical_not)] - complex_cases = [ +def _generate_test_combinations(cases): + + def reduce_fn(x, y): + name, fn = y + return x + combinations.combine(map_fn=combinations.NamedObject(name, fn)) + + return functools.reduce(reduce_fn, cases, []) + + +def _unary_bitwise_test_combinations(): + cases = [("Invert", bitwise_ops.invert)] + return _generate_test_combinations(cases) + + +def _unary_logical_test_combinations(): + cases = [("LogicalNot", math_ops.logical_not)] + return _generate_test_combinations(cases) + + +def _unary_complex_test_combinations(): + cases = [ ("Angle", math_ops.angle), ("ComplexAbs", math_ops.abs), ("Conj", math_ops.conj), ("Imag", math_ops.imag), ("Real", math_ops.real), ] - real_cases = [ + return _generate_test_combinations(cases) + + +def _unary_real_test_combinations(): + # acosh requires values x >= 1 + def safe_acosh(x): + return math_ops.acosh(1 + math_ops.square(x)) + + cases = [ ("Abs", math_ops.abs), ("Acos", math_ops.acos), - ("Acosh", lambda x: math_ops.acosh(1 + math_ops.square(x))), + ("Acosh", safe_acosh), ("Asin", math_ops.asin), ("Asinh", math_ops.asinh), ("Atan", math_ops.atan), @@ -99,45 +125,26 @@ def _generate_unary_cwise_math_cases(): ("Tan", math_ops.tan), ("Tanh", math_ops.tanh), ] - random_input = np.random.rand(3, 5) - complex_component = np.random.rand(3, 5) - random_int = np.random.randint(0, 10, (7, 3, 5)) - - def bitwise_dataset_factory(): - return dataset_ops.Dataset.from_tensor_slices(random_int) - - def logical_dataset_factory(): - return dataset_ops.Dataset.from_tensor_slices(random_input > 0) - - def random_dataset_factory(): - return dataset_ops.Dataset.from_tensor_slices(random_input) - - def complex_dataset_factory(): - return dataset_ops.Dataset.from_tensor_slices( - math_ops.complex(random_input, complex_component)) - - case_factory_pairs = [ - (bitwise_cases, bitwise_dataset_factory), - (logical_cases, logical_dataset_factory), - (complex_cases, complex_dataset_factory), - (real_cases, random_dataset_factory), - ] - return [(case[0], case[1], factory) - for cases, factory in case_factory_pairs - for case in cases] + return _generate_test_combinations(cases) -def _generate_binary_cwise_math_cases(): - bitwise_cases = [("BitwiseAnd", bitwise_ops.bitwise_and), - ("BitwiseOr", bitwise_ops.bitwise_or), - ("BitwiseXor", bitwise_ops.bitwise_xor), - ("LeftShift", bitwise_ops.left_shift), - ("RightShift", bitwise_ops.right_shift)] +def _binary_bitwise_test_combinations(): + cases = [("BitwiseAnd", bitwise_ops.bitwise_and), + ("BitwiseOr", bitwise_ops.bitwise_or), + ("BitwiseXor", bitwise_ops.bitwise_xor), + ("LeftShift", bitwise_ops.left_shift), + ("RightShift", bitwise_ops.right_shift)] + return _generate_test_combinations(cases) - logical_cases = [("LogicalAnd", math_ops.logical_and), - ("LogicalOr", math_ops.logical_or)] - # Wrapper functions restricting the range of inputs of zeta and polygamma. 
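# Illustrative sketch (not part of the patch), reusing the helpers this file
# defines above: _generate_test_combinations turns ("Name", fn) pairs into a
# list of combinations, one per case, with the callable wrapped in
# combinations.NamedObject so generated test names stay readable;
# combinations.times() then crosses that list with the default API-version /
# execution-mode combinations and any extra parameters. The "Square" case is
# hypothetical and used only for illustration.
square_case = _generate_test_combinations([("Square", math_ops.square)])
full_matrix = combinations.times(
    test_base.default_test_combinations(),
    square_case,
    combinations.combine(num_parallel_calls=[None, 12]))
# full_matrix parameterizes a test once per (api version, mode, map_fn,
# num_parallel_calls) tuple, matching the @combinations.generate decorators
# used on the tests below.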
+def _binary_logical_test_combinations(): + cases = [("LogicalAnd", math_ops.logical_and), + ("LogicalOr", math_ops.logical_or)] + return _generate_test_combinations(cases) + + +def _binary_real_test_combinations(): + def safe_polygamma(x, y): return math_ops.polygamma( math_ops.round(clip_ops.clip_by_value(y, 1, 10)), x * x + 1) @@ -145,7 +152,7 @@ def _generate_binary_cwise_math_cases(): def safe_zeta(x, y): return math_ops.zeta(x * x + 1, y * y) - real_cases = [ + cases = [ ("Add", math_ops.add), ("AddV2", math_ops.add_v2), ("Atan2", math_ops.atan2), @@ -174,150 +181,10 @@ def _generate_binary_cwise_math_cases(): ("TruncateMod", math_ops.truncate_mod), ("Zeta", safe_zeta), ] - - # Exercises broadcasting capabilities - x = np.random.rand(7, 3, 5) - y = np.random.rand(3, 5) - - x_int = np.random.randint(0, 10, (7, 3, 5)) - y_int = np.random.randint(0, 10, (3, 5)) - - def bitwise_dataset_factory(): - return dataset_ops.Dataset.from_tensors((x_int, y_int)) - - def logical_dataset_factory(): - return dataset_ops.Dataset.from_tensors((x > 0, y > 0)) - - def random_dataset_factory(): - return dataset_ops.Dataset.from_tensors((x, y)) - - case_factory_pairs = [ - (bitwise_cases, bitwise_dataset_factory), - (logical_cases, logical_dataset_factory), - (real_cases, random_dataset_factory), - ] - return [(case[0], case[1], factory) - for cases, factory in case_factory_pairs - for case in cases] + return _generate_test_combinations(cases) -def _generate_cwise_test_cases(): - return _generate_unary_cwise_math_cases() + _generate_binary_cwise_math_cases( - ) - - -def _generate_csv_test_case(): - - def csv_factory(): - return dataset_ops.Dataset.from_tensor_slices(["1.0:2:a", - "2.4:5:c"]).repeat(5) - - def decode_csv_fn(x): - return parsing_ops.decode_csv( - x, - record_defaults=[ - constant_op.constant([], dtypes.float32), - constant_op.constant([], dtypes.int32), - constant_op.constant([], dtypes.string) - ], - field_delim=":") - - return decode_csv_fn, csv_factory - - -def _generate_parse_single_example_test_case(): - # When sparse tensors are used, map_vectorization is not - # attempted because the output_shapes of the map dataset are not defined. - # TODO(rachelim): Consider being more lax with checking the output_shapes of - # the map node. 
- - def parse_example_factory(): - - def _int64_feature(*values): - return feature_pb2.Feature(int64_list=feature_pb2.Int64List(value=values)) - - def _bytes_feature(*values): - return feature_pb2.Feature( - bytes_list=feature_pb2.BytesList( - value=[v.encode("utf-8") for v in values])) - - return dataset_ops.Dataset.from_tensor_slices( - constant_op.constant([ - example_pb2.Example( - features=feature_pb2.Features( - feature={ - "dense_int": _int64_feature(i), - "dense_str": _bytes_feature(str(i)), - })).SerializeToString() for i in range(10) - ])) - - def parse_single_example_fn(x): - features = { - "dense_int": parsing_ops.FixedLenFeature((), dtypes.int64, 0), - "dense_str": parsing_ops.FixedLenFeature((), dtypes.string, ""), - } - return parsing_ops.parse_single_example(x, features) - - return parse_single_example_fn, parse_example_factory - - -def _generate_optimization_test_cases(): - - def base_dataset_factory(): - return dataset_ops.Dataset.from_tensors(np.random.rand(10, 3)).repeat(5) - - rand_val = np.random.rand(1, 1, 1, 1, 1, 1) - - csv_test_case = _generate_csv_test_case() - parse_fn, parse_base = _generate_parse_single_example_test_case() - - def dense_output_only_parse_fn(x): - # Since we haven't implemented a vectorizer for SerializeSparse, any - # function with sparse outputs will only be naively vectorized. - parse_result = parse_fn(x) - return [ - y for y in parse_result if not isinstance(y, sparse_tensor.SparseTensor) - ] - - def map_fn_with_cycle(x): - c = lambda i: math_ops.less(i, 10) - b = lambda i: math_ops.add(i, 1) - return control_flow_ops.while_loop(c, b, [x]) - - # Misc test cases - test_cases = [ - ("Basic", lambda x: (x, x + 1), base_dataset_factory), - ("Broadcast", lambda x: x + rand_val, base_dataset_factory), - ("Cycle", map_fn_with_cycle, lambda: dataset_ops.Dataset.from_tensors(1)), - ("Const", lambda x: 2, base_dataset_factory), - ("Cast", lambda x: math_ops.cast(x, dtypes.float64), - base_dataset_factory), - ("Reshape", lambda x: array_ops.reshape(x, (-1, 30)), - base_dataset_factory), - ("Transpose", array_ops.transpose, base_dataset_factory), - ("Unpack", array_ops.unstack, base_dataset_factory), - ("UnpackNegativeAxis", lambda x: array_ops.unstack(x, axis=-1), - base_dataset_factory), - # Parsing ops - ("DecodeCSV", csv_test_case[0], csv_test_case[1]), - ("ParseSingleExample", parse_fn, parse_base), - ("ParseSingleExampleDenseOutputOnly", dense_output_only_parse_fn, - parse_base), - ] + _generate_cwise_test_cases() - - return [{ - "testcase_name": - x[0] + "Parallel" if num_parallel_calls is not None else x[0], - "map_fn": - x[1], - "base_dataset_factory": - x[2], - "num_parallel_calls": - num_parallel_calls - } for x in test_cases for num_parallel_calls in (None, 12)] - - -@test_util.run_all_in_graph_and_eager_modes +# TODO(rachelim): Consolidate tests with pfor when APIs are somewhat shared. 
class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase): def _enable_map_vectorization(self, dataset, use_choose=True): @@ -370,13 +237,223 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase): optimized = self._enable_map_vectorization(optimized) return unoptimized, optimized - @parameterized.named_parameters(_generate_optimization_test_cases()) - def testOptimization(self, map_fn, base_dataset_factory, num_parallel_calls): - base_dataset = base_dataset_factory() - unoptimized, optimized = self._get_test_datasets(base_dataset, map_fn, + def _testOptimization(self, map_fn, dataset_factory, num_parallel_calls): + dataset = dataset_factory() + unoptimized, optimized = self._get_test_datasets(dataset, map_fn, num_parallel_calls) self.assertDatasetsEqual(unoptimized, optimized) + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + combinations.combine(num_parallel_calls=[None, 12]))) + def testBasic(self, num_parallel_calls): + data = np.random.rand(10, 3) + dataset_factory = lambda: dataset_ops.Dataset.from_tensors(data).repeat(5) + map_fn = lambda x: (x, x + 1) + self._testOptimization(map_fn, dataset_factory, num_parallel_calls) + + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + combinations.combine(num_parallel_calls=[None, 12]))) + def testBroadcast(self, num_parallel_calls): + data = np.random.rand(10, 3) + dataset_factory = lambda: dataset_ops.Dataset.from_tensors(data).repeat(5) + value = np.random.rand(1, 1, 1, 1, 1, 1) + map_fn = lambda x: x + value + self._testOptimization(map_fn, dataset_factory, num_parallel_calls) + + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + combinations.combine(num_parallel_calls=[None, 12]))) + def testCast(self, num_parallel_calls): + data = np.random.rand(10, 3) + dataset_factory = lambda: dataset_ops.Dataset.from_tensors(data).repeat(5) + map_fn = lambda x: math_ops.cast(x, dtypes.float64) + self._testOptimization(map_fn, dataset_factory, num_parallel_calls) + + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + combinations.combine(num_parallel_calls=[None, 12]))) + def testConst(self, num_parallel_calls): + data = np.random.rand(10, 3) + dataset_factory = lambda: dataset_ops.Dataset.from_tensors(data).repeat(5) + map_fn = lambda x: 2 + self._testOptimization(map_fn, dataset_factory, num_parallel_calls) + + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + combinations.combine(num_parallel_calls=[None, 12]))) + def testCycle(self, num_parallel_calls): + dataset_factory = lambda: dataset_ops.Dataset.from_tensors(1) + + def map_fn(x): + c = lambda i: math_ops.less(i, 10) + b = lambda i: math_ops.add(i, 1) + return control_flow_ops.while_loop(c, b, [x]) + + self._testOptimization(map_fn, dataset_factory, num_parallel_calls) + + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + combinations.combine(num_parallel_calls=[None, 12]))) + def testReshape(self, num_parallel_calls): + data = np.random.rand(10, 3) + dataset_factory = lambda: dataset_ops.Dataset.from_tensors(data).repeat(5) + map_fn = lambda x: array_ops.reshape(x, (-1, 30)) + self._testOptimization(map_fn, dataset_factory, num_parallel_calls) + + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + combinations.combine(num_parallel_calls=[None, 12]))) + def testTranspose(self, num_parallel_calls): + 
data = np.random.rand(10, 3) + dataset_factory = lambda: dataset_ops.Dataset.from_tensors(data).repeat(5) + map_fn = array_ops.transpose + self._testOptimization(map_fn, dataset_factory, num_parallel_calls) + + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + combinations.combine(num_parallel_calls=[None, 12]))) + def testUnstack(self, num_parallel_calls): + data = np.random.rand(10, 3) + dataset_factory = lambda: dataset_ops.Dataset.from_tensors(data).repeat(5) + map_fns = [array_ops.unstack, lambda x: array_ops.unstack(x, axis=-1)] + for map_fn in map_fns: + self._testOptimization(map_fn, dataset_factory, num_parallel_calls) + + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + _unary_bitwise_test_combinations(), + combinations.combine(num_parallel_calls=[None, 12]))) + def testUnaryBitwiseOperations(self, map_fn, num_parallel_calls): + x = np.random.randint(0, 10, (7, 3, 5)) + dataset_factory = lambda: dataset_ops.Dataset.from_tensor_slices(x) + self._testOptimization(map_fn, dataset_factory, num_parallel_calls) + + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + _unary_logical_test_combinations(), + combinations.combine(num_parallel_calls=[None, 12]))) + def testUnaryLogicalOperations(self, map_fn, num_parallel_calls): + x = np.random.rand(3, 5) + dataset_factory = lambda: dataset_ops.Dataset.from_tensor_slices(x > 0) + self._testOptimization(map_fn, dataset_factory, num_parallel_calls) + + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + _unary_complex_test_combinations(), + combinations.combine(num_parallel_calls=[None, 12]))) + def testUnaryComplexOperations(self, map_fn, num_parallel_calls): + x = math_ops.complex(np.random.rand(3, 5), np.random.rand(3, 5)) + dataset_factory = lambda: dataset_ops.Dataset.from_tensor_slices(x) + self._testOptimization(map_fn, dataset_factory, num_parallel_calls) + + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + _unary_real_test_combinations(), + combinations.combine(num_parallel_calls=[None, 12]))) + def testUnaryRealOperations(self, map_fn, num_parallel_calls): + x = np.random.rand(3, 5) + dataset_factory = lambda: dataset_ops.Dataset.from_tensor_slices(x) + self._testOptimization(map_fn, dataset_factory, num_parallel_calls) + + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + _binary_bitwise_test_combinations(), + combinations.combine(num_parallel_calls=[None, 12]))) + def testBinaryBitwiseOperations(self, map_fn, num_parallel_calls): + x = np.random.randint(0, 10, (7, 3, 5)) + y = np.random.randint(0, 10, (3, 5)) + dataset_factory = lambda: dataset_ops.Dataset.from_tensors((x, y)) + self._testOptimization(map_fn, dataset_factory, num_parallel_calls) + + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + _binary_logical_test_combinations(), + combinations.combine(num_parallel_calls=[None, 12]))) + def testBinaryLogicalOperations(self, map_fn, num_parallel_calls): + x = np.random.rand(7, 3, 5) + y = np.random.rand(3, 5) + dataset_factory = lambda: dataset_ops.Dataset.from_tensors((x > 0, y > 0)) + self._testOptimization(map_fn, dataset_factory, num_parallel_calls) + + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + _binary_real_test_combinations(), + combinations.combine(num_parallel_calls=[None, 12]))) + def testBinaryRealOperations(self, map_fn, 
num_parallel_calls): + x = np.random.rand(7, 3, 5) + y = np.random.rand(3, 5) + dataset_factory = lambda: dataset_ops.Dataset.from_tensors((x, y)) + self._testOptimization(map_fn, dataset_factory, num_parallel_calls) + + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + combinations.combine(num_parallel_calls=[None, 12]))) + def testDecodeCsv(self, num_parallel_calls): + + def dataset_factory(): + return dataset_ops.Dataset.from_tensor_slices(["1.0:2:a", + "2.4:5:c"]).repeat(5) + + def decode_csv_fn(x): + return parsing_ops.decode_csv( + x, + record_defaults=[ + constant_op.constant([], dtypes.float32), + constant_op.constant([], dtypes.int32), + constant_op.constant([], dtypes.string) + ], + field_delim=":") + + self._testOptimization(decode_csv_fn, dataset_factory, num_parallel_calls) + + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + combinations.combine(num_parallel_calls=[None, 12]))) + def testParseSingleExample(self, num_parallel_calls): + + def dataset_factory(): + + def _int64_feature(*values): + return feature_pb2.Feature( + int64_list=feature_pb2.Int64List(value=values)) + + def _bytes_feature(*values): + return feature_pb2.Feature( + bytes_list=feature_pb2.BytesList( + value=[v.encode("utf-8") for v in values])) + + # pylint:disable=g-complex-comprehension + return dataset_ops.Dataset.from_tensor_slices( + constant_op.constant([ + example_pb2.Example( + features=feature_pb2.Features( + feature={ + "dense_int": _int64_feature(i), + "dense_str": _bytes_feature(str(i)), + })).SerializeToString() for i in range(10) + ])) + + def parse_fn(x): + features = { + "dense_int": parsing_ops.FixedLenFeature((), dtypes.int64, 0), + "dense_str": parsing_ops.FixedLenFeature((), dtypes.string, ""), + } + return parsing_ops.parse_single_example(x, features) + + def dense_only_parse_fn(x): + return [ + y for y in parse_fn(x) + if not isinstance(y, sparse_tensor.SparseTensor) + ] + + map_fns = [parse_fn, dense_only_parse_fn] + + for map_fn in map_fns: + self._testOptimization(map_fn, dataset_factory, num_parallel_calls) + + @combinations.generate(test_base.default_test_combinations()) def testOptimizationBadMapFn(self): # Test map functions that give an error def map_fn(x): @@ -391,6 +468,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase): nxt = dataset_ops.make_one_shot_iterator(optimized).get_next() self.evaluate(nxt) + @combinations.generate(test_base.default_test_combinations()) def testOptimizationWithCapturedInputs(self): # Tests that vectorization works with captured inputs. y = constant_op.constant(1, shape=(2,)) @@ -405,6 +483,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase): base_dataset, map_fn, expect_optimized=True) self.assertDatasetsEqual(optimized, unoptimized) + @combinations.generate(test_base.default_test_combinations()) def testOptimizationWithMapAndBatchFusion(self): # Tests that vectorization works on fused map and batch. 
def map_fn(x): @@ -425,12 +504,11 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase): optimized = self._enable_map_vectorization(optimized) self.assertDatasetsEqual(optimized, unoptimized) - @parameterized.named_parameters( - ("1", True, True), - ("2", True, False), - ("3", False, True), - ("4", False, False), - ) + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine( + fuse_first=[True, False], fuse_second=[True, False]))) def testOptimizationWithChainedMapAndBatch(self, fuse_first, fuse_second): # Tests that vectorization works on chained map and batch functions. def map_fn(x): @@ -474,6 +552,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase): optimized = self._enable_map_vectorization(optimized) self.assertDatasetsEqual(optimized, unoptimized) + @combinations.generate(test_base.default_test_combinations()) def testOptimizationIgnoreStateful(self): def map_fn(x): @@ -488,6 +567,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase): get_next = self.getNext(dataset) self.evaluate(get_next()) + @combinations.generate(test_base.default_test_combinations()) def testOptimizationIgnoreRagged(self): # Make sure we ignore inputs that might not be uniformly sized def map_fn(x): @@ -499,6 +579,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase): base_dataset, map_fn, expect_optimized=False) self.assertDatasetsEqual(unoptimized, optimized) + @combinations.generate(test_base.default_test_combinations()) def testOptimizationIgnoreRaggedMap(self): # Don't optimize when the output of the map fn shapes are unknown. def map_fn(x): @@ -512,6 +593,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase): get_next = self.getNext(dataset) self.evaluate(get_next()) + @combinations.generate(test_base.default_test_combinations()) def testOptimizationWithUnknownBatchShape(self): tensor = sparse_tensor.SparseTensor( indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]) @@ -526,6 +608,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase): optimized = self._enable_map_vectorization(unoptimized) self.assertDatasetsEqual(unoptimized, optimized) + @combinations.generate(test_base.default_test_combinations()) def testOptimizationWithSparseTensor(self): base_dataset = dataset_ops.Dataset.from_tensors(0) @@ -542,6 +625,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase): optimized = self._enable_map_vectorization(unoptimized) self.assertDatasetsEqual(unoptimized, optimized) + @combinations.generate(test_base.default_test_combinations()) def testOptimizationWithPrefetch(self): dataset = dataset_ops.Dataset.range(10) dataset = dataset.map(lambda x: x) @@ -550,6 +634,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase): dataset = self._enable_map_vectorization(dataset) self.assertDatasetProduces(dataset, [list(range(10))]) + @combinations.generate(test_base.default_test_combinations()) def testOptimizationWithoutChooseFastest(self): dataset = dataset_ops.Dataset.range(10) dataset = dataset.map(lambda x: x**2) diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/noop_elimination_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/noop_elimination_test.py index a401a5c8baf..84ef45d9593 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/noop_elimination_test.py +++ 
b/tensorflow/python/data/experimental/kernel_tests/optimization/noop_elimination_test.py @@ -17,19 +17,21 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized + from tensorflow.python.data.experimental.ops import testing from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import combinations from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import test_util from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -@test_util.run_all_in_graph_and_eager_modes -class NoopEliminationTest(test_base.DatasetTestBase): +class NoopEliminationTest(test_base.DatasetTestBase, parameterized.TestCase): + @combinations.generate(test_base.default_test_combinations()) def testNoopElimination(self): a = constant_op.constant(1, dtype=dtypes.int64) b = constant_op.constant(2, dtype=dtypes.int64) diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/shuffle_and_repeat_fusion_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/shuffle_and_repeat_fusion_test.py index 4da7fa27d58..ad1a98134b8 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/shuffle_and_repeat_fusion_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/shuffle_and_repeat_fusion_test.py @@ -17,19 +17,22 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized + from tensorflow.python import tf2 from tensorflow.python.data.experimental.ops import testing from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context +from tensorflow.python.framework import combinations from tensorflow.python.framework import errors -from tensorflow.python.framework import test_util from tensorflow.python.platform import test -@test_util.run_all_in_graph_and_eager_modes -class ShuffleAndRepeatFusionTest(test_base.DatasetTestBase): +class ShuffleAndRepeatFusionTest(test_base.DatasetTestBase, + parameterized.TestCase): + @combinations.generate(test_base.default_test_combinations()) def testShuffleAndRepeatFusion(self): if tf2.enabled() and context.executing_eagerly(): expected = "Shuffle" diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py b/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py index f6ab5a1cde2..aea4934260e 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py @@ -57,6 +57,7 @@ class DatasetSerializationTestBase(test.TestCase): def tearDown(self): self._delete_ckpt() + super(DatasetSerializationTestBase, self).tearDown() # TODO(b/72657739): Remove sparse_tensor argument, which is to test the # (deprecated) saveable `SparseTensorSliceDataset`, once the API diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD index 2b1bda4138a..db749da77f8 100644 --- a/tensorflow/python/data/kernel_tests/BUILD +++ b/tensorflow/python/data/kernel_tests/BUILD @@ -287,6 +287,7 @@ tf_py_test( size = 
"small", srcs = ["iterator_cluster_test.py"], additional_deps = [ + ":test_base", "//tensorflow/core:protos_all_py", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:iterator_ops", @@ -400,6 +401,7 @@ tf_py_test( "//tensorflow/python:variable_scope", "//tensorflow/python/ops/ragged", ], + shard_count = 4, ) cuda_py_test( diff --git a/tensorflow/python/data/kernel_tests/as_numpy_iterator_test.py b/tensorflow/python/data/kernel_tests/as_numpy_iterator_test.py index 23a928a817c..b704906a3ae 100644 --- a/tensorflow/python/data/kernel_tests/as_numpy_iterator_test.py +++ b/tensorflow/python/data/kernel_tests/as_numpy_iterator_test.py @@ -61,30 +61,30 @@ class AsNumpyIteratorTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(RuntimeError): ds.as_numpy_iterator() - def checkInvalidElement(self, element): + def _testInvalidElement(self, element): ds = dataset_ops.Dataset.from_tensors(element) with self.assertRaisesRegex(TypeError, '.*does not support datasets containing.*'): ds.as_numpy_iterator() @combinations.generate(test_base.eager_only_combinations()) - def testInvalidElements(self): - self.checkInvalidElement(sparse_tensor.SparseTensorValue([[0]], [0], [1])) + def testSparseElement(self): + self._testInvalidElement(sparse_tensor.SparseTensorValue([[0]], [0], [1])) @combinations.generate(test_base.eager_only_combinations()) def testRaggedElement(self): - self.checkInvalidElement( + self._testInvalidElement( ragged_tensor_value.RaggedTensorValue( np.array([0, 1, 2]), np.array([0, 1, 3], dtype=np.int64))) @combinations.generate(test_base.eager_only_combinations()) def testDatasetElement(self): - self.checkInvalidElement(dataset_ops.Dataset.range(3)) + self._testInvalidElement(dataset_ops.Dataset.range(3)) @combinations.generate(test_base.eager_only_combinations()) def testNestedNonTensorElement(self): tuple_elem = (constant_op.constant([1, 2, 3]), dataset_ops.Dataset.range(3)) - self.checkInvalidElement(tuple_elem) + self._testInvalidElement(tuple_elem) if __name__ == '__main__': diff --git a/tensorflow/python/data/kernel_tests/cache_test.py b/tensorflow/python/data/kernel_tests/cache_test.py index 6d3dc04a3e0..1a923645a04 100644 --- a/tensorflow/python/data/kernel_tests/cache_test.py +++ b/tensorflow/python/data/kernel_tests/cache_test.py @@ -45,9 +45,9 @@ class FileCacheTest(test_base.DatasetTestBase, parameterized.TestCase): self.cache_prefix = path.join(self.tmp_dir, "cache") def tearDown(self): - super(FileCacheTest, self).tearDown() if self.tmp_dir: shutil.rmtree(self.tmp_dir, ignore_errors=True) + super(FileCacheTest, self).tearDown() @combinations.generate(test_base.default_test_combinations()) def testCacheDatasetPassthrough(self): diff --git a/tensorflow/python/data/kernel_tests/checkpoint_test.py b/tensorflow/python/data/kernel_tests/checkpoint_test.py index 738d09b97fe..4441d5642d5 100644 --- a/tensorflow/python/data/kernel_tests/checkpoint_test.py +++ b/tensorflow/python/data/kernel_tests/checkpoint_test.py @@ -42,11 +42,11 @@ from tensorflow.python.training.tracking import util as trackable_utils class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase): def tearDown(self): - super(CheckpointTest, self).tearDown() prefix = self._iterator_checkpoint_prefix() pattern = prefix + "*" files = gfile.Glob(pattern) map(gfile.Remove, files) + super(CheckpointTest, self).tearDown() def _iterator_checkpoint_prefix(self): return os.path.join(self.get_temp_dir(), "iterator") @@ -66,8 +66,7 @@ class 
CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase): iterator_state_variant) return restore_op - @combinations.generate( - combinations.combine(tf_api_version=[1, 2], mode="graph")) + @combinations.generate(test_base.graph_only_combinations()) def testSaveRestore(self): def _build_graph(start, stop): @@ -118,8 +117,7 @@ class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - @combinations.generate( - combinations.combine(tf_api_version=[1, 2], mode="graph")) + @combinations.generate(test_base.graph_only_combinations()) def testInitThenRestore(self): # Note: Calling init_op before restore_op is redundant. This test just makes # sure we do not fail if restore is called on an already initialized @@ -157,8 +155,7 @@ class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - @combinations.generate( - combinations.combine(tf_api_version=[1, 2], mode="graph")) + @combinations.generate(test_base.graph_only_combinations()) def testMultipleSaves(self): def _build_graph(start, stop): @@ -204,8 +201,7 @@ class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - @combinations.generate( - combinations.combine(tf_api_version=[1, 2], mode="graph")) + @combinations.generate(test_base.graph_only_combinations()) def testSaveRestoreWithRepeat(self): def _build_graph(start, stop, num_epochs): @@ -253,8 +249,7 @@ class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - @combinations.generate( - combinations.combine(tf_api_version=[1, 2], mode="graph")) + @combinations.generate(test_base.graph_only_combinations()) def testSaveRestoreExhaustedIterator(self): def _build_graph(start, stop, num_epochs): @@ -295,8 +290,7 @@ class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - @combinations.generate( - combinations.combine(tf_api_version=[1, 2], mode="eager")) + @combinations.generate(test_base.eager_only_combinations()) def testSaveRestoreOneShotIterator(self): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") @@ -319,8 +313,7 @@ class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): get_next() - @combinations.generate( - combinations.combine(tf_api_version=[1, 2], mode="eager")) + @combinations.generate(test_base.eager_only_combinations()) def testSaveRestoreMultipleIterator(self): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") @@ -353,8 +346,7 @@ class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertAllEqual([1, 4], get_next_2()) self.assertAllEqual(3, get_next_3()) - @combinations.generate( - combinations.combine(tf_api_version=[1, 2], mode="eager")) + @combinations.generate(test_base.eager_only_combinations()) def testRestoreExhaustedIterator(self): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") @@ -373,8 +365,7 @@ class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): get_next() - @combinations.generate( - combinations.combine(tf_api_version=[1, 2], 
mode="eager")) + @combinations.generate(test_base.eager_only_combinations()) def testRestoreInReconstructedIteratorInitializable(self): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") diff --git a/tensorflow/python/data/kernel_tests/dataset_test.py b/tensorflow/python/data/kernel_tests/dataset_test.py index b35c4ff1b29..df151a85db0 100644 --- a/tensorflow/python/data/kernel_tests/dataset_test.py +++ b/tensorflow/python/data/kernel_tests/dataset_test.py @@ -43,7 +43,6 @@ from tensorflow.python.framework import tensor_spec from tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops from tensorflow.python.platform import test -from tensorflow.python.platform import tf_logging class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase): @@ -89,13 +88,13 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase): variant, original_dataset.element_spec) self.assertDatasetProduces(revived_dataset, list(original_dataset)) - def checkNumInputs(self, dataset, num_inputs): + def _testNumInputs(self, dataset, num_inputs): self.assertLen(dataset._inputs(), num_inputs) @combinations.generate(test_base.default_test_combinations()) def testFixedLengthRecordInputs(self): dataset = readers.FixedLengthRecordDataset("", 42) - self.checkNumInputs(dataset, 0) + self._testNumInputs(dataset, 0) @combinations.generate(test_base.default_test_combinations()) def testFromGeneratorInputs(self): @@ -103,27 +102,27 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase): yield 42 dataset = dataset_ops.Dataset.from_generator(gen, dtypes.int32) - self.checkNumInputs(dataset, 1) + self._testNumInputs(dataset, 1) @combinations.generate(test_base.default_test_combinations()) def testFromTensorsInputs(self): dataset = dataset_ops.Dataset.from_tensors([42]) - self.checkNumInputs(dataset, 0) + self._testNumInputs(dataset, 0) @combinations.generate(test_base.default_test_combinations()) def testRangeInputs(self): dataset = dataset_ops.Dataset.range(10) - self.checkNumInputs(dataset, 0) + self._testNumInputs(dataset, 0) @combinations.generate(test_base.default_test_combinations()) def testTextLineInputs(self): dataset = readers.TextLineDataset("") - self.checkNumInputs(dataset, 0) + self._testNumInputs(dataset, 0) @combinations.generate(test_base.default_test_combinations()) def testTFRecordInputs(self): dataset = readers.TFRecordDataset("") - self.checkNumInputs(dataset, 1) + self._testNumInputs(dataset, 1) @combinations.generate( combinations.combine(tf_api_version=1, mode=["eager", "graph"])) @@ -135,58 +134,58 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase): dense_shape=np.array([3, 1]))) self.assertEmpty(dataset_fn._inputs()) - def checkUnaryInputs(self, dataset_fn): + def _testUnaryInputs(self, dataset_fn): input_dataset = dataset_ops.Dataset.range(0) self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs()) @combinations.generate(test_base.default_test_combinations()) def testBatchInputs(self): - self.checkUnaryInputs(lambda x: x.batch(10)) + self._testUnaryInputs(lambda x: x.batch(10)) @combinations.generate(test_base.default_test_combinations()) def testCacheInputs(self): - self.checkUnaryInputs(lambda x: x.cache()) + self._testUnaryInputs(lambda x: x.cache()) @combinations.generate(test_base.default_test_combinations()) def testFilterInputs(self): - self.checkUnaryInputs(lambda x: x.filter(lambda x: True)) + self._testUnaryInputs(lambda x: x.filter(lambda x: 
True)) @combinations.generate(test_base.default_test_combinations()) def testFlatMapInputs(self): - self.checkUnaryInputs( + self._testUnaryInputs( lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0))) @combinations.generate(test_base.default_test_combinations()) def testMapInputs(self): - self.checkUnaryInputs(lambda x: x.map(lambda x: x)) + self._testUnaryInputs(lambda x: x.map(lambda x: x)) @combinations.generate(test_base.default_test_combinations()) def testPaddedBatchInputs(self): - self.checkUnaryInputs(lambda x: x.padded_batch(10, [])) + self._testUnaryInputs(lambda x: x.padded_batch(10, [])) @combinations.generate(test_base.default_test_combinations()) def testParallelMapInputs(self): - self.checkUnaryInputs(lambda x: x.map(lambda x: x, num_parallel_calls=2)) + self._testUnaryInputs(lambda x: x.map(lambda x: x, num_parallel_calls=2)) @combinations.generate(test_base.default_test_combinations()) def testRepeatInputs(self): - self.checkUnaryInputs(lambda x: x.repeat()) + self._testUnaryInputs(lambda x: x.repeat()) @combinations.generate(test_base.default_test_combinations()) def testShuffleInputs(self): - self.checkUnaryInputs(lambda x: x.shuffle(10)) + self._testUnaryInputs(lambda x: x.shuffle(10)) @combinations.generate(test_base.default_test_combinations()) def testSkipInputs(self): - self.checkUnaryInputs(lambda x: x.skip(1)) + self._testUnaryInputs(lambda x: x.skip(1)) @combinations.generate(test_base.default_test_combinations()) def testTakeInputs(self): - self.checkUnaryInputs(lambda x: x.take(1)) + self._testUnaryInputs(lambda x: x.take(1)) @combinations.generate(test_base.default_test_combinations()) def testWindowInputs(self): - self.checkUnaryInputs(lambda x: x.window(10)) + self._testUnaryInputs(lambda x: x.window(10)) @combinations.generate(test_base.default_test_combinations()) def testUnaryTransformationInputsApply(self): @@ -195,7 +194,7 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertEqual([input_dataset], dataset._inputs()) - def checkInputsWithInterleaveFn(self, dataset_fn, interleave_parallelism): + def _testInputsWithInterleaveFn(self, dataset_fn, interleave_parallelism): input_dataset = dataset_ops.Dataset.range(0) dataset = input_dataset.interleave( lambda x: dataset_ops.Dataset.range(0), @@ -205,11 +204,11 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate(test_base.default_test_combinations()) def testParallelInterleaveInputs(self): - self.checkInputsWithInterleaveFn(lambda: dataset_ops.range(0), 2) + self._testInputsWithInterleaveFn(lambda: dataset_ops.range(0), 2) @combinations.generate(test_base.default_test_combinations()) def testInterleaveInputs(self): - self.checkInputsWithInterleaveFn(lambda: dataset_ops.range(0), None) + self._testInputsWithInterleaveFn(lambda: dataset_ops.range(0), None) @combinations.generate(test_base.default_test_combinations()) def testNoWarnings(self): @@ -218,16 +217,16 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase): lambda x: dataset_ops.Dataset.range(0), cycle_length=2) self.assertEmpty(mock_log.call_args_list) - def checkBinaryInputs(self, dataset_fn): + def _testBinaryInputs(self, dataset_fn): input1 = dataset_ops.Dataset.range(0) input2 = dataset_ops.Dataset.range(1) self.assertEqual([input1, input2], dataset_fn(input1, input2)._inputs()) @combinations.generate(test_base.default_test_combinations()) def testConcatenateInputs(self): - self.checkBinaryInputs(lambda x, y: x.concatenate(y)) + 
self._testBinaryInputs(lambda x, y: x.concatenate(y)) - def checkVariadicInputs(self, dataset_fn, input_datasets): + def _testVariadicInputs(self, dataset_fn, input_datasets): self.assertEqual( nest.flatten(input_datasets), dataset_fn(input_datasets)._inputs()) @@ -235,20 +234,20 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate(test_base.default_test_combinations()) def testZipOneInputs(self): input_datasets = dataset_ops.Dataset.range(0) - self.checkVariadicInputs(dataset_ops.Dataset.zip, input_datasets) + self._testVariadicInputs(dataset_ops.Dataset.zip, input_datasets) @combinations.generate(test_base.default_test_combinations()) def testZipNestInputs(self): input_datasets = (dataset_ops.Dataset.range(0), (dataset_ops.Dataset.range(1), dataset_ops.Dataset.range(2))) - self.checkVariadicInputs(dataset_ops.Dataset.zip, input_datasets) + self._testVariadicInputs(dataset_ops.Dataset.zip, input_datasets) @combinations.generate(test_base.default_test_combinations()) def testZipTupleInputs(self): input_datasets = (dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1)) - self.checkVariadicInputs(dataset_ops.Dataset.zip, input_datasets) + self._testVariadicInputs(dataset_ops.Dataset.zip, input_datasets) @combinations.generate(test_base.default_test_combinations()) def testFunctions(self): @@ -273,7 +272,7 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertEqual(2, inputs.count(ds2)) self.assertEqual(1, inputs.count(ds3)) - def checkDatasetSpec(self, tf_value, expected_element_structure): + def _testDatasetSpec(self, tf_value, expected_element_structure): dataset = dataset_ops.Dataset.from_tensors(0).map(lambda _: tf_value) dataset_structure = structure.type_spec_from_value(dataset) self.assertIsInstance(dataset_structure, dataset_ops.DatasetSpec) @@ -307,12 +306,12 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate(test_base.default_test_combinations()) def testTensorDatasetSpec(self): - self.checkDatasetSpec( + self._testDatasetSpec( constant_op.constant(37.0), tensor_spec.TensorSpec([], dtypes.float32)) @combinations.generate(test_base.default_test_combinations()) def testSparseTensorDatasetSpec(self): - self.checkDatasetSpec( + self._testDatasetSpec( sparse_tensor.SparseTensor( indices=[[0]], values=constant_op.constant([0], dtype=dtypes.int32), @@ -320,7 +319,7 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate(test_base.default_test_combinations()) def testNestDatasetSpec(self): - self.checkDatasetSpec( + self._testDatasetSpec( { "a": constant_op.constant(37.0), "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar")) @@ -335,20 +334,19 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate(test_base.default_test_combinations()) def testDatasetDatasetSpec(self): - self.checkDatasetSpec( + self._testDatasetSpec( dataset_ops.Dataset.from_tensor_slices( constant_op.constant([1, 2, 3])), dataset_ops.DatasetSpec(tensor_spec.TensorSpec([], dtypes.int32))) @combinations.generate(test_base.default_test_combinations()) def testOptionalDatasetSpec(self): - self.checkDatasetSpec( + self._testDatasetSpec( optional_ops.Optional.from_value(37.0), optional_ops.OptionalSpec(tensor_spec.TensorSpec([], dtypes.float32))) - @combinations.generate( - combinations.combine(tf_api_version=[1], mode=["graph"])) - def testSkipEagerSameGraphErrorOneShot(self): + 
@combinations.generate(test_base.graph_only_combinations()) + def testSameGraphError(self): dataset = dataset_ops.Dataset.range(10) with ops.Graph().as_default(): with self.assertRaisesRegexp(ValueError, "must be from the same graph"): @@ -356,26 +354,27 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate( combinations.combine(tf_api_version=[1], mode=["graph"])) - def testSkipEagerSameGraphErrorOneShotSimple(self): + def testSameGraphErrorOneShot(self): dataset = dataset_ops.Dataset.range(10) with ops.Graph().as_default(): - with test.mock.patch.object(tf_logging, "warning") as mock_log: + with self.assertRaisesRegexp( + ValueError, "Please ensure that all datasets in the pipeline are " + "created in the same graph as the iterator."): _ = dataset_ops.make_one_shot_iterator(dataset) - self.assertRegexpMatches( - str(mock_log.call_args), "Please ensure that all datasets in the " - "pipeline are created in the same graph as the iterator.") @combinations.generate( combinations.combine(tf_api_version=[1], mode=["graph"])) - def testSkipEagerSameGraphErrorInitializable(self): + def testSameGraphErrorInitializable(self): dataset = dataset_ops.Dataset.range(10) with ops.Graph().as_default(): - with self.assertRaisesRegexp(ValueError, "must be from the same graph"): - dataset = dataset.batch(2) + with self.assertRaisesRegexp( + ValueError, "Please ensure that all datasets in the pipeline are " + "created in the same graph as the iterator."): + _ = dataset_ops.make_initializable_iterator(dataset) @combinations.generate( combinations.times( - combinations.combine(tf_api_version=[1, 2], mode="eager"), + test_base.eager_only_combinations(), combinations.combine(execution_mode=[context.ASYNC, context.SYNC]))) def testEagerIteration(self, execution_mode): with context.execution_mode(execution_mode): diff --git a/tensorflow/python/data/kernel_tests/filter_test.py b/tensorflow/python/data/kernel_tests/filter_test.py index 05b538a46ce..f6bdcb12020 100644 --- a/tensorflow/python/data/kernel_tests/filter_test.py +++ b/tensorflow/python/data/kernel_tests/filter_test.py @@ -30,28 +30,31 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -def new_and_legacy_filter_fn_combinations(): +def _test_combinations(): - def new_filter_fn(dataset, predicate): + def filter_fn(dataset, predicate): return dataset.filter(predicate) def legacy_filter_fn(dataset, predicate): return dataset.filter_with_legacy_function(predicate) - return (combinations.combine( + filter_combinations = combinations.combine( tf_api_version=[1, 2], mode=["eager", "graph"], - apply_filter=combinations.NamedObject("new_filter_fn", new_filter_fn)) + - combinations.combine( - tf_api_version=1, - mode=["eager", "graph"], - apply_filter=combinations.NamedObject("legacy_filter_fn", - legacy_filter_fn))) + apply_filter=combinations.NamedObject("filter_fn", filter_fn)) + + legacy_filter_combinations = combinations.combine( + tf_api_version=1, + mode=["eager", "graph"], + apply_filter=combinations.NamedObject("legacy_filter_fn", + legacy_filter_fn)) + + return filter_combinations + legacy_filter_combinations class FilterTest(test_base.DatasetTestBase, parameterized.TestCase): - @combinations.generate(new_and_legacy_filter_fn_combinations()) + @combinations.generate(_test_combinations()) def testFilterDataset(self, apply_filter): components = (np.arange(7, dtype=np.int64), np.array([[1, 2, 3]], dtype=np.int64) * @@ -87,14 +90,14 @@ class FilterTest(test_base.DatasetTestBase, 
parameterized.TestCase): # Test an empty dataset. do_test(0, 1) - @combinations.generate(new_and_legacy_filter_fn_combinations()) + @combinations.generate(_test_combinations()) def testFilterRange(self, apply_filter): dataset = dataset_ops.Dataset.range(4) dataset = apply_filter(dataset, lambda x: math_ops.not_equal(math_ops.mod(x, 3), 2)) self.assertDatasetProduces(dataset, expected_output=[0, 1, 3]) - @combinations.generate(new_and_legacy_filter_fn_combinations()) + @combinations.generate(_test_combinations()) def testFilterDict(self, apply_filter): dataset = dataset_ops.Dataset.range(10).map( lambda x: {"foo": x * 2, "bar": x**2}) @@ -104,7 +107,7 @@ class FilterTest(test_base.DatasetTestBase, parameterized.TestCase): dataset, expected_output=[(i * 2 + i**2) for i in range(10) if not (i**2) % 2]) - @combinations.generate(new_and_legacy_filter_fn_combinations()) + @combinations.generate(_test_combinations()) def testUseStepContainerInFilter(self, apply_filter): input_data = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64) @@ -119,7 +122,7 @@ class FilterTest(test_base.DatasetTestBase, parameterized.TestCase): dataset = apply_filter(dataset, _predicate) self.assertDatasetProduces(dataset, expected_output=[input_data[0]]) - @combinations.generate(new_and_legacy_filter_fn_combinations()) + @combinations.generate(_test_combinations()) def testSparse(self, apply_filter): def _map_fn(i): @@ -137,7 +140,7 @@ class FilterTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertDatasetProduces( dataset, expected_output=[_map_fn(i * 2)[0] for i in range(5)]) - @combinations.generate(new_and_legacy_filter_fn_combinations()) + @combinations.generate(_test_combinations()) def testShortCircuit(self, apply_filter): dataset = dataset_ops.Dataset.zip( (dataset_ops.Dataset.range(10), @@ -146,7 +149,7 @@ class FilterTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertDatasetProduces( dataset, expected_output=[(i, True) for i in range(10)]) - @combinations.generate(new_and_legacy_filter_fn_combinations()) + @combinations.generate(_test_combinations()) def testParallelFilters(self, apply_filter): dataset = dataset_ops.Dataset.range(10) dataset = apply_filter(dataset, lambda x: math_ops.equal(x % 2, 0)) diff --git a/tensorflow/python/data/kernel_tests/flat_map_test.py b/tensorflow/python/data/kernel_tests/flat_map_test.py index 3afed61fc7f..00b6e400ea7 100644 --- a/tensorflow/python/data/kernel_tests/flat_map_test.py +++ b/tensorflow/python/data/kernel_tests/flat_map_test.py @@ -66,10 +66,8 @@ class FlatMapTest(test_base.DatasetTestBase, parameterized.TestCase): expected_output.extend([i] * i) self.assertDatasetProduces(dataset, expected_output=expected_output) - # Note: no eager mode coverage, session specific test. 
- @combinations.generate( - combinations.combine(tf_api_version=[1, 2], mode=["graph"])) - def testSkipEagerSharedResourceNestedFlatMapDataset(self): + @combinations.generate(test_base.graph_only_combinations()) + def testSharedResourceNestedFlatMapDataset(self): repeats = [[1, 2], [3, 4], [5, 0], [1, 7]] components = np.array(repeats, dtype=np.int64) iterator = ( diff --git a/tensorflow/python/data/kernel_tests/from_generator_test.py b/tensorflow/python/data/kernel_tests/from_generator_test.py index dfa467add62..49753babacb 100644 --- a/tensorflow/python/data/kernel_tests/from_generator_test.py +++ b/tensorflow/python/data/kernel_tests/from_generator_test.py @@ -32,62 +32,83 @@ from tensorflow.python.ops import script_ops from tensorflow.python.platform import test -class DatasetConstructorTest(test_base.DatasetTestBase, parameterized.TestCase): +class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase): def _testFromGenerator(self, generator, elem_sequence, num_repeats, - output_types=None): - if output_types is None: - output_types = dtypes.int64 - dataset = dataset_ops.Dataset.from_generator( - generator, output_types=output_types).repeat(num_repeats).prefetch(5) - self.assertDatasetProduces( - dataset, - elem_sequence * num_repeats, - requires_initialization=True, - num_test_iterations=2) - - def _testFromGeneratorOneShot(self, generator, elem_sequence, num_repeats): + requires_initialization): dataset = dataset_ops.Dataset.from_generator( generator, output_types=dtypes.int64).repeat(num_repeats).prefetch(5) self.assertDatasetProduces( - dataset, elem_sequence * num_repeats, num_test_iterations=2) + dataset, + elem_sequence * num_repeats, + requires_initialization=requires_initialization, + num_test_iterations=2) + + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine( + num_repeats=[1, 5], requires_initialization=[True, False]))) + def testFromGeneratorUsingFn(self, num_repeats, requires_initialization): - @combinations.generate(test_base.default_test_combinations()) - def testFromGeneratorUsingFunction(self): def generator(): for i in range(1, 100): yield [i] * i - elem_sequence = list(generator()) - self._testFromGenerator(generator, elem_sequence, 1) - self._testFromGenerator(generator, elem_sequence, 5) - self._testFromGeneratorOneShot(generator, elem_sequence, 1) - self._testFromGeneratorOneShot(generator, elem_sequence, 5) - @combinations.generate(test_base.default_test_combinations()) - def testFromGeneratorUsingList(self): + elem_sequence = list(generator()) + self._testFromGenerator( + generator, + elem_sequence, + num_repeats=num_repeats, + requires_initialization=requires_initialization) + + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine( + num_repeats=[1, 5], requires_initialization=[True, False]))) + def testFromGeneratorUsingList(self, num_repeats, requires_initialization): generator = lambda: [[i] * i for i in range(1, 100)] elem_sequence = list(generator()) - self._testFromGenerator(generator, elem_sequence, 1) - self._testFromGenerator(generator, elem_sequence, 5) + self._testFromGenerator( + generator, + elem_sequence, + num_repeats=num_repeats, + requires_initialization=requires_initialization) - @combinations.generate(test_base.default_test_combinations()) - def testFromGeneratorUsingNdarray(self): + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine( + num_repeats=[1, 
5], requires_initialization=[True, False]))) + def testFromGeneratorUsingNdarray(self, num_repeats, requires_initialization): generator = lambda: np.arange(100, dtype=np.int64) elem_sequence = list(generator()) - self._testFromGenerator(generator, elem_sequence, 1, output_types=np.int64) - self._testFromGenerator(generator, elem_sequence, 5, output_types=np.int64) + self._testFromGenerator( + generator, + elem_sequence, + num_repeats=num_repeats, + requires_initialization=requires_initialization) - @combinations.generate(test_base.default_test_combinations()) - def testFromGeneratorUsingGeneratorExpression(self): - # NOTE(mrry): Generator *expressions* are not repeatable (or in - # general reusable), because they eagerly evaluate the `for` - # expression as `iter(range(1, 100))` and discard the means of - # reconstructing `range(1, 100)`. Wrapping the generator - # expression in a `lambda` makes it repeatable. + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine( + num_repeats=[1, 5], requires_initialization=[True, False]))) + def testFromGeneratorUsingGeneratorExpression(self, num_repeats, + requires_initialization): + # NOTE(mrry): Generator *expressions* are not repeatable (or in general + # reusable), because they eagerly evaluate the `for` expression as + # `iter(range(1, 100))` and discard the means of reconstructing + # `range(1, 100)`. Wrapping the generator expression in a `lambda` makes + # it repeatable. generator = lambda: ([i] * i for i in range(1, 100)) elem_sequence = list(generator()) - self._testFromGenerator(generator, elem_sequence, 1) - self._testFromGenerator(generator, elem_sequence, 5) + self._testFromGenerator( + generator, + elem_sequence, + num_repeats=num_repeats, + requires_initialization=requires_initialization) @combinations.generate(test_base.default_test_combinations()) def testFromMultipleConcurrentGenerators(self): @@ -392,7 +413,6 @@ class DatasetConstructorTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertAllEqual(37, self.evaluate(get_next())) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) - self.assertTrue(event.is_set()) @combinations.generate(test_base.default_test_combinations()) def testSharedName(self): diff --git a/tensorflow/python/data/kernel_tests/from_tensors_test.py b/tensorflow/python/data/kernel_tests/from_tensors_test.py index e293383403f..c899c156739 100644 --- a/tensorflow/python/data/kernel_tests/from_tensors_test.py +++ b/tensorflow/python/data/kernel_tests/from_tensors_test.py @@ -237,8 +237,8 @@ class FromTensorsTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertEqual([3], get_next().shape) # TODO(b/121264236): needs mechanism for multiple device in eager mode. 
- @combinations.generate(test_base.default_test_combinations()) - def testSkipEagerSplitPipeline(self): + @combinations.generate(test_base.graph_only_combinations()) + def testSplitPipeline(self): with session.Session( target="", config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: diff --git a/tensorflow/python/data/kernel_tests/iterator_cluster_test.py b/tensorflow/python/data/kernel_tests/iterator_cluster_test.py index 0384f9fc18a..0a40c212006 100644 --- a/tensorflow/python/data/kernel_tests/iterator_cluster_test.py +++ b/tensorflow/python/data/kernel_tests/iterator_cluster_test.py @@ -17,12 +17,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.framework import combinations from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -37,9 +40,9 @@ from tensorflow.python.ops import string_ops from tensorflow.python.platform import test -class IteratorClusterTest(test.TestCase): +class IteratorClusterTest(test.TestCase, parameterized.TestCase): - @test_util.run_v1_only("b/120545219") + @combinations.generate(test_base.graph_only_combinations()) def testRemoteIteratorWithoutRemoteCallFail(self): worker_config = config_pb2.ConfigProto() worker_config.device_count["CPU"] = 2 @@ -95,7 +98,7 @@ class IteratorClusterTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(remote_op, feed_dict={target_placeholder: device1}) - @test_util.run_v1_only("b/120545219") + @combinations.generate(test_base.graph_only_combinations()) def testRemoteIteratorUsingRemoteCallOp(self): worker_config = config_pb2.ConfigProto() worker_config.device_count["CPU"] = 2 @@ -106,7 +109,7 @@ class IteratorClusterTest(test.TestCase): "/job:worker/replica:0/task:0/cpu:1", worker[0].target) - @test_util.run_v1_only("b/120545219") + @combinations.generate(test_base.graph_only_combinations()) def testRemoteIteratorUsingRemoteCallOpCrossProcess(self): workers, _ = test_util.create_local_cluster(2, 1) @@ -114,7 +117,7 @@ class IteratorClusterTest(test.TestCase): "/job:worker/replica:0/task:1/cpu:0", workers[0].target) - @test_util.run_v1_only("b/120545219") + @combinations.generate(test_base.graph_only_combinations()) def testCaptureHashTableInSharedIterator(self): worker, _ = test_util.create_local_cluster(1, 1) @@ -131,10 +134,10 @@ class IteratorClusterTest(test.TestCase): input_sentences = dataset_ops.Dataset.from_tensor_slices( ["brain brain tank salad surgery", "surgery brain"]) - iterator = ( - input_sentences.map(lambda x: string_ops.string_split([x]).values).map( - table.lookup) - .make_initializable_iterator(shared_name="shared_iterator")) + dataset = input_sentences.map( + lambda x: string_ops.string_split([x]).values).map(table.lookup) + iterator = dataset_ops.make_initializable_iterator( + dataset, shared_name="shared_iterator") init_op = iterator.initializer get_next = iterator.get_next() @@ -148,7 +151,7 @@ class IteratorClusterTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - @test_util.run_v1_only("b/120545219") + 
@combinations.generate(test_base.graph_only_combinations()) def testImplicitDisposeParallelMapDataset(self): # Tests whether a parallel map dataset will be cleaned up correctly when # the pipeline does not run it until exhaustion. diff --git a/tensorflow/python/data/kernel_tests/iterator_test.py b/tensorflow/python/data/kernel_tests/iterator_test.py index 70f03f3e4d2..fcb2e4c0b1f 100644 --- a/tensorflow/python/data/kernel_tests/iterator_test.py +++ b/tensorflow/python/data/kernel_tests/iterator_test.py @@ -56,8 +56,7 @@ from tensorflow.python.util import compat class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): - @combinations.generate( - combinations.combine(tf_api_version=[1, 2], mode=["graph"])) + @combinations.generate(test_base.graph_only_combinations()) def testNoGradients(self): component = constant_op.constant([1.]) side = constant_op.constant(0.) @@ -68,8 +67,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertIsNone(gradients_impl.gradients(value, side)[0]) self.assertIsNone(gradients_impl.gradients(value, [component, side])[0]) - @combinations.generate( - combinations.combine(tf_api_version=[1, 2], mode=["graph"])) + @combinations.generate(test_base.graph_only_combinations()) def testCapturingStateInOneShotRaisesException(self): var = variables.Variable(37.0, name="myvar") dataset = ( @@ -80,8 +78,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): "datasets that capture stateful objects.+myvar"): dataset_ops.make_one_shot_iterator(dataset) - @combinations.generate( - combinations.combine(tf_api_version=[1, 2], mode=["graph"])) + @combinations.generate(test_base.graph_only_combinations()) def testOneShotIterator(self): components = (np.arange(7), np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], @@ -107,8 +104,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - @combinations.generate( - combinations.combine(tf_api_version=[1, 2], mode=["graph"])) + @combinations.generate(test_base.graph_only_combinations()) def testOneShotIteratorCaptureByValue(self): components = (np.arange(7), np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], @@ -172,8 +168,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - @combinations.generate( - combinations.combine(tf_api_version=[1, 2], mode=["graph"])) + @combinations.generate(test_base.graph_only_combinations()) def testOneShotIteratorNonBlocking(self): dataset = dataset_ops.Dataset.from_tensors([1, 2, 3]).map(lambda x: x * x) iterator = dataset_ops.make_one_shot_iterator(dataset) @@ -207,13 +202,11 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): for t in threads: t.join() - self.assertEqual(num_threads, len(results)) - self.assertEqual(num_threads - 1, - len([None for r in results if r is None])) + self.assertLen(results, num_threads) + self.assertLen([None for r in results if r is None], num_threads - 1) self.assertAllEqual([[1, 4, 9]], [r for r in results if r is not None]) - @combinations.generate( - combinations.combine(tf_api_version=[1, 2], mode=["graph"])) + @combinations.generate(test_base.graph_only_combinations()) def testOneShotIteratorInitializerFails(self): # Define a dataset whose initialization will always fail. 
dataset = dataset_ops.Dataset.from_tensors(array_ops.gather([0], [4])) @@ -243,8 +236,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): for t in threads: t.join() - @combinations.generate( - combinations.combine(tf_api_version=[1, 2], mode=["graph"])) + @combinations.generate(test_base.graph_only_combinations()) def testSimpleSharedResource(self): components = (np.array(1, dtype=np.int64), np.array([1, 2, 3], dtype=np.int64), @@ -294,8 +286,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - @combinations.generate( - combinations.combine(tf_api_version=[1, 2], mode=["graph"])) + @combinations.generate(test_base.graph_only_combinations()) def testNotInitializedError(self): components = (np.array(1), np.array([1, 2, 3]), np.array(37.0)) iterator = dataset_ops.make_initializable_iterator( @@ -307,8 +298,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): "iterator has not been initialized"): sess.run(get_next) - @combinations.generate( - combinations.combine(tf_api_version=[1, 2], mode=["graph"])) + @combinations.generate(test_base.graph_only_combinations()) def testReinitializableIterator(self): dataset_3 = dataset_ops.Dataset.from_tensors( constant_op.constant([1, 2, 3])) @@ -353,8 +343,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - @combinations.generate( - combinations.combine(tf_api_version=[1, 2], mode=["graph"])) + @combinations.generate(test_base.graph_only_combinations()) def testReinitializableIteratorWithFunctions(self): def g(): @@ -415,8 +404,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): (constant_op.constant([1, 2, 3], dtype=dtypes.int64), constant_op.constant([4., 5., 6., 7.], dtype=dtypes.float64)))) - @combinations.generate( - combinations.combine(tf_api_version=[1, 2], mode=["graph"])) + @combinations.generate(test_base.graph_only_combinations()) def testIteratorStringHandle(self): dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40]) @@ -474,8 +462,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): sess.run( next_element, feed_dict={handle_placeholder: iterator_4_handle}) - @combinations.generate( - combinations.combine(tf_api_version=[1, 2], mode=["graph"])) + @combinations.generate(test_base.graph_only_combinations()) def testIteratorStringHandleFuture(self): with forward_compat.forward_compatibility_horizon(2018, 8, 4): dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) @@ -541,8 +528,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): sess.run( next_element, feed_dict={handle_placeholder: iterator_4_handle}) - @combinations.generate( - combinations.combine(tf_api_version=[1, 2], mode=["graph"])) + @combinations.generate(test_base.graph_only_combinations()) def testIteratorStringHandleReuseTensorObject(self): dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) one_shot_iterator = dataset_ops.make_one_shot_iterator(dataset) @@ -571,8 +557,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertEqual("foo_1", handle_with_same_name.op.name) self.assertIsNot(handle_with_name, handle_with_same_name) - @combinations.generate( - combinations.combine(tf_api_version=[1, 2], mode=["graph"])) + 
@combinations.generate(test_base.graph_only_combinations()) def testIteratorStringHandleError(self): dataset_int_scalar = ( dataset_ops.Dataset.from_tensor_slices([1, 2, 3]).repeat()) @@ -613,8 +598,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): feedable_int_vector.get_next(), feed_dict={handle_placeholder: handle_float_vector})) - @combinations.generate( - combinations.combine(tf_api_version=[1, 2], mode=["graph"])) + @combinations.generate(test_base.graph_only_combinations()) def testRemoteIteratorUsingRemoteCallOpDirectSession(self): worker_config = config_pb2.ConfigProto() worker_config.device_count["CPU"] = 3 @@ -672,8 +656,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" }) - @combinations.generate( - combinations.combine(tf_api_version=[1, 2], mode=["graph"])) + @combinations.generate(test_base.graph_only_combinations()) def testRemoteIteratorUsingRemoteCallOpMultiWorkers(self): s1 = server_lib.Server.create_local_server() s2 = server_lib.Server.create_local_server() @@ -727,8 +710,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(n) - @combinations.generate( - combinations.combine(tf_api_version=[1, 2], mode=["graph"])) + @combinations.generate(test_base.graph_only_combinations()) def testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU(self): if not test_util.is_gpu_available(): self.skipTest("No GPU available") @@ -785,8 +767,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" }) - @combinations.generate( - combinations.combine(tf_api_version=[1, 2], mode=["graph"])) + @combinations.generate(test_base.graph_only_combinations()) def testRepeatedGetNextWarning(self): iterator = dataset_ops.make_one_shot_iterator(dataset_ops.Dataset.range(10)) warnings.simplefilter("always") @@ -929,7 +910,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertEqual(val, foo.numpy()) val += 1 - @combinations.generate(combinations.combine(tf_api_version=2, mode="eager")) + @combinations.generate(test_base.eager_only_combinations()) def testOwnedIteratorFunction(self): queue = data_flow_ops.FIFOQueue(10, dtypes.int64) @@ -946,7 +927,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): for i in range(10): self.assertEqual(queue.dequeue().numpy(), i) - @combinations.generate(combinations.combine(tf_api_version=2, mode="eager")) + @combinations.generate(test_base.eager_only_combinations()) def testOwnedIteratorFunctionError(self): # In this test we verify that a function that raises an error ends up # properly deallocating the iterator resource. 
@@ -976,7 +957,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertEqual(queue.size().numpy(), 2) - @combinations.generate(combinations.combine(tf_api_version=2, mode="eager")) + @combinations.generate(test_base.eager_only_combinations()) def testLimitedRetracing(self): trace_count = [0] @@ -996,7 +977,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertEqual(self.evaluate(f(iter(dataset2))), 45) self.assertEqual(trace_count[0], 1) - @combinations.generate(combinations.combine(tf_api_version=2, mode="eager")) + @combinations.generate(test_base.eager_only_combinations()) def testNestedFunctionsIteratorResource(self): @def_function.function diff --git a/tensorflow/python/data/kernel_tests/list_files_test.py b/tensorflow/python/data/kernel_tests/list_files_test.py index 52ce300f537..40b4b77116c 100644 --- a/tensorflow/python/data/kernel_tests/list_files_test.py +++ b/tensorflow/python/data/kernel_tests/list_files_test.py @@ -35,10 +35,12 @@ from tensorflow.python.util import compat class ListFilesTest(test_base.DatasetTestBase, parameterized.TestCase): def setUp(self): + super(ListFilesTest, self).setUp() self.tmp_dir = tempfile.mkdtemp() def tearDown(self): shutil.rmtree(self.tmp_dir, ignore_errors=True) + super(ListFilesTest, self).tearDown() def _touchTempFiles(self, filenames): for filename in filenames: diff --git a/tensorflow/python/data/kernel_tests/map_test.py b/tensorflow/python/data/kernel_tests/map_test.py index 0847cdd7a0d..c8b23edbc7f 100644 --- a/tensorflow/python/data/kernel_tests/map_test.py +++ b/tensorflow/python/data/kernel_tests/map_test.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools from collections import namedtuple import threading import time @@ -31,13 +32,13 @@ from tensorflow.python.data.experimental.ops import threading_options from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context +from tensorflow.python.framework import combinations from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_util -from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops @@ -57,10 +58,70 @@ from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import test -def _make_coordinated_sloppy_dataset(num_elements, num_parallel_calls): +def _test_combinations_with_mode_v1(mode): + + def new_map_fn(dataset, *args, **kwargs): + return dataset.map(*args, **kwargs) + + def legacy_map_fn(dataset, *args, **kwargs): + return dataset.map_with_legacy_function(*args, **kwargs) + + new_map_combinations = combinations.combine( + tf_api_version=1, + mode=mode, + apply_map=combinations.NamedObject("map_fn", new_map_fn)) + + legacy_map_combinations = combinations.combine( + tf_api_version=1, + mode=mode, + apply_map=combinations.NamedObject("legacy_map_fn", legacy_map_fn)) + + return new_map_combinations + legacy_map_combinations + + +def _test_combinations_with_mode_v2(mode): + + def new_map_fn(dataset, *args, **kwargs): + return dataset.map(*args, 
**kwargs) + + return combinations.combine( + tf_api_version=2, + mode=mode, + apply_map=combinations.NamedObject("map_fn", new_map_fn)) + + +def _test_combinations_with_mode(mode): + return _test_combinations_with_mode_v1( + mode) + _test_combinations_with_mode_v2(mode) + + +def _test_combinations(): + return _test_combinations_with_mode("eager") + _test_combinations_with_mode( + "graph") + + +def _short_circuit_test_cases(): + cases = [ + ("Identity", None, lambda x: x), + ("Replicate", None, lambda x: (x, x)), + ("Swap", (None, None), lambda x, y: (y, x)), + ("Project", (None, None), lambda x, y: x) + ] + + def reduce_fn(x, y): + name, structure, fn = y + return x + combinations.combine( + structure=structure, fn=combinations.NamedObject(name, fn)) + + return functools.reduce(reduce_fn, cases, []) + + +def _make_coordinated_sloppy_dataset(apply_map, num_elements, + num_parallel_calls): """Produces a dataset iterator and events to control the order of elements. Args: + apply_map: method that applies the `map` transformation num_elements: the number of input elements num_parallel_calls: the degree of map parallelism @@ -84,28 +145,27 @@ def _make_coordinated_sloppy_dataset(num_elements, num_parallel_calls): options = dataset_ops.Options() options.experimental_deterministic = False - dataset = dataset_ops.Dataset.range(num_elements).map( - map_fn, num_parallel_calls).with_options(options) + dataset = dataset_ops.Dataset.range(num_elements) + dataset = apply_map(dataset, map_fn, num_parallel_calls).with_options(options) return dataset, coordination_events -# TODO(jsimsa): Add tests for `map_with_legacy_function`. -@test_util.run_all_in_graph_and_eager_modes class MapTest(test_base.DatasetTestBase, parameterized.TestCase): - def _buildMapDataset(self, components, count): + def _map_dataset_factory(self, components, apply_map, count): def _map_fn(x, y, z): return math_ops.square(x), math_ops.square(y), math_ops.square(z) - dataset = dataset_ops.Dataset.from_tensor_slices(components).map( - _map_fn).repeat(count) + dataset = dataset_ops.Dataset.from_tensor_slices(components) + dataset = apply_map(dataset, _map_fn).repeat(count) self.assertEqual( [c.shape[1:] for c in components], [shape for shape in dataset_ops.get_legacy_output_shapes(dataset)]) return dataset - def testMapDataset(self): + @combinations.generate(_test_combinations()) + def testMapDataset(self, apply_map): """Test an dataset that maps a TF function across its input elements.""" # The pipeline is TensorSliceDataset -> MapDataset(square_3) -> # RepeatDataset(count). @@ -114,7 +174,8 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): np.array(37.0) * np.arange(7)) # Test single-threaded access to the iterator. - get_next = self.getNext(self._buildMapDataset(components, 14)) + get_next = self.getNext( + self._map_dataset_factory(components, apply_map, count=14)) for _ in range(14): for i in range(7): result = self.evaluate(get_next()) @@ -123,15 +184,15 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) - # TODO(b/117581999): add eager coverage, different threads run in graph - # context. - @test_util.run_v1_only("b/120545219") - def testSkipEagerMapDatasetMultithreaded(self): + # TODO(b/117581999): add eager coverage + @combinations.generate(_test_combinations_with_mode("graph")) + def testMapDatasetMultiThreaded(self, apply_map): # Test multi-threaded access to the same iterator. 
components = (np.arange(7), np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], np.array(37.0) * np.arange(7)) - get_next = self.getNext(self._buildMapDataset(components, 18)) + get_next = self.getNext( + self._map_dataset_factory(components, apply_map, count=18)) results = [] with self.cached_session() as sess: def iterator_thread(): @@ -157,94 +218,99 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): results[i * 18 + j]): self.assertAllEqual(component[i]**2, result_component) - def _buildParallelMapDataset(self, components, count, num_parallel_calls, - output_buffer_size): + def _parallel_map_dataset_factory(self, components, apply_map, count, + num_parallel_calls, buffer_size): def _map_fn(x, y, z): return math_ops.square(x), math_ops.square(y), math_ops.square(z) - dataset = dataset_ops.Dataset.from_tensor_slices(components).map( - _map_fn, num_parallel_calls=num_parallel_calls).prefetch( - output_buffer_size).repeat(count) + dataset = dataset_ops.Dataset.from_tensor_slices(components) + dataset = apply_map(dataset, _map_fn, num_parallel_calls=num_parallel_calls) + dataset = dataset.prefetch(buffer_size).repeat(count) self.assertEqual( [c.shape[1:] for c in components], [shape for shape in dataset_ops.get_legacy_output_shapes(dataset)]) return dataset - def testParallelMapDataset(self): + @combinations.generate( + combinations.times( + _test_combinations(), + combinations.combine(num_parallel_calls=1, buffer_size=1) + + combinations.combine(num_parallel_calls=1, buffer_size=2) + + combinations.combine(num_parallel_calls=2, buffer_size=2) + + combinations.combine(num_parallel_calls=2, buffer_size=4) + + combinations.combine(num_parallel_calls=8, buffer_size=8) + + combinations.combine(num_parallel_calls=8, buffer_size=16))) + def testParallelMapDataset(self, apply_map, num_parallel_calls, buffer_size): """Test an dataset that maps a TF function across its input elements.""" # The pipeline is TensorSliceDataset -> ParallelMapDataset(square_3) -> # RepeatDataset(count). - def do_test(num_parallel_calls, output_buffer_size): + components = (np.arange(7), + np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], + np.array(37.0) * np.arange(7)) + # Test single-threaded access to the iterator. + get_next = self.getNext( + self._parallel_map_dataset_factory(components, apply_map, 14, + num_parallel_calls, buffer_size)) + for _ in range(14): + for i in range(7): + result = self.evaluate(get_next()) + for component, result_component in zip(components, result): + self.assertAllEqual(component[i]**2, result_component) + with self.assertRaises(errors.OutOfRangeError): + self.evaluate(get_next()) - components = (np.arange(7), - np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], - np.array(37.0) * np.arange(7)) - # Test single-threaded access to the iterator. 
- get_next = self.getNext( - self._buildParallelMapDataset(components, 14, num_parallel_calls, - output_buffer_size)) - for _ in range(14): - for i in range(7): - result = self.evaluate(get_next()) - for component, result_component in zip(components, result): + # TODO(b/117581999): add eager coverage + @combinations.generate( + combinations.times( + _test_combinations_with_mode("graph"), + combinations.combine(num_parallel_calls=1, buffer_size=1) + + combinations.combine(num_parallel_calls=1, buffer_size=2) + + combinations.combine(num_parallel_calls=2, buffer_size=2) + + combinations.combine(num_parallel_calls=2, buffer_size=4) + + combinations.combine(num_parallel_calls=8, buffer_size=8) + + combinations.combine(num_parallel_calls=8, buffer_size=16))) + def testParallelMapDatasetMultiThreaded(self, apply_map, num_parallel_calls, + buffer_size): + + # Test multi-threaded access to the same iterator. + components = (np.arange(7), + np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], + np.array(37.0) * np.arange(7)) + get_next = self.getNext( + self._parallel_map_dataset_factory(components, apply_map, 18, + num_parallel_calls, buffer_size)) + results = [] + with self.cached_session() as sess: + + def iterator_thread(): + while True: + try: + results.append(sess.run(get_next())) + except errors.OutOfRangeError: + return + + threads = [self.checkedThread(target=iterator_thread) for _ in range(64)] + for t in threads: + t.start() + for t in threads: + t.join() + + # `results` will contain the same elements components**2 + # repeated 18 times, but in a non-deterministic order. Sort the + # results, and assert that each element of components**2 is + # produced 18 times. + results.sort(key=lambda x: x[0]) + for i in range(7): + for j in range(18): + for component, result_component in zip(components, + results[i * 18 + j]): self.assertAllEqual(component[i]**2, result_component) - with self.assertRaises(errors.OutOfRangeError): - self.evaluate(get_next()) - for num_parallel_calls_val, output_buffer_size_val in [(1, 1), (1, 2), (2, - 2), - (2, 4), (8, 8), - (8, 16)]: - do_test(num_parallel_calls_val, output_buffer_size_val) - - # TODO(b/117581999): add eager coverage, different threads run in graph - # context. - @test_util.run_v1_only("b/120545219") - def testSkipEagerParallelMapDatasetMultithreaded(self): - - def do_test(num_parallel_calls, output_buffer_size): - # Test multi-threaded access to the same iterator. - components = (np.arange(7), - np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], - np.array(37.0) * np.arange(7)) - get_next = self.getNext( - self._buildParallelMapDataset(components, 18, num_parallel_calls, - output_buffer_size)) - results = [] - with self.cached_session() as sess: - - def iterator_thread(): - while True: - try: - results.append(sess.run(get_next())) - except errors.OutOfRangeError: - return - threads = [self.checkedThread(target=iterator_thread) - for _ in range(64)] - for t in threads: - t.start() - for t in threads: - t.join() - - # `results` will contain the same elements components**2 - # repeated 18 times, but in a non-deterministic order. Sort the - # results, and assert that each element of components**2 is - # produced 18 times. 
- results.sort(key=lambda x: x[0]) - for i in range(7): - for j in range(18): - for component, result_component in zip(components, - results[i * 18 + j]): - self.assertAllEqual(component[i]**2, result_component) - - for num_parallel_calls_val, output_buffer_size_val in [ - (1, 1), (1, 2), (2, 2), (2, 4), (8, 8), (8, 16)]: - do_test(num_parallel_calls_val, output_buffer_size_val) - - def testImplicitDisposeParallelMapDataset(self): + @combinations.generate(_test_combinations()) + def testImplicitDisposeParallelMapDataset(self, apply_map): # Tests whether a parallel map dataset will be cleaned up correctly when # the pipeline does not run it until exhaustion. # The pipeline is TensorSliceDataset -> MapDataset(square_3) -> @@ -253,7 +319,8 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis], np.array(37.0) * np.arange(1000)) - dataset = self._buildParallelMapDataset(components, 1000, 100, 100) + dataset = self._parallel_map_dataset_factory(components, apply_map, 1000, + 100, 100) # NOTE(mrry): Also test that the prefetching thread is cancelled correctly. dataset = dataset.prefetch(100) get_next = self.getNext(dataset) @@ -261,23 +328,29 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): for _ in range(3): self.evaluate(get_next()) - def testParallelMapUnspecifiedOutputSize(self): + @combinations.generate(_test_combinations()) + def testParallelMapUnspecifiedOutputSize(self, apply_map): components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) - dataset = (dataset_ops.Dataset.from_tensor_slices(components) - .map(lambda x: array_ops.check_numerics(x, "message"), - num_parallel_calls=2)) + dataset = dataset_ops.Dataset.from_tensor_slices(components) + dataset = apply_map( + dataset, + lambda x: array_ops.check_numerics(x, "message"), + num_parallel_calls=2) get_next = self.getNext(dataset) for _ in range(3): self.evaluate(get_next()) - def testParallelMapError(self): + @combinations.generate(_test_combinations()) + def testParallelMapError(self, apply_map): components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) - dataset = (dataset_ops.Dataset.from_tensor_slices(components) - .map(lambda x: array_ops.check_numerics(x, "message"), - num_parallel_calls=2)) + dataset = dataset_ops.Dataset.from_tensor_slices(components) + dataset = apply_map( + dataset, + lambda x: array_ops.check_numerics(x, "message"), + num_parallel_calls=2) get_next = self.getNext(dataset) for _ in range(3): @@ -289,13 +362,13 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) - def testPrefetchError(self): + @combinations.generate(_test_combinations()) + def testPrefetchError(self, apply_map): components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) - dataset = (dataset_ops.Dataset.from_tensor_slices(components) - .map(lambda x: array_ops.check_numerics(x, "message")) - .prefetch(2)) - + dataset = dataset_ops.Dataset.from_tensor_slices(components) + dataset = apply_map( + dataset, lambda x: array_ops.check_numerics(x, "message")).prefetch(2) get_next = self.getNext(dataset) for _ in range(3): @@ -307,7 +380,8 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) - def testCaptureIterator(self): + @combinations.generate(_test_combinations()) + def testCaptureIterator(self, apply_map): def _build_ds(iterator): @@ -315,7 +389,7 @@ 
class MapTest(test_base.DatasetTestBase, parameterized.TestCase): get_next = iterator.get_next() return x * get_next - return dataset_ops.Dataset.range(10).map(_map_fn) + return apply_map(dataset_ops.Dataset.range(10), _map_fn) def _build_graph(): if context.executing_eagerly(): @@ -335,7 +409,8 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) - def testCaptureHashTable(self): + @combinations.generate(_test_combinations()) + def testCaptureHashTable(self, apply_map): # NOTE(mrry): We must use the V2 variants of `HashTable` # etc. because these produce a `tf.resource`-typed output that is # compatible with the in-graph function implementation. @@ -348,8 +423,9 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): input_sentences = dataset_ops.Dataset.from_tensor_slices( ["brain brain tank salad surgery", "surgery brain"]) - dataset = input_sentences.map(lambda x: string_ops.string_split([x]).values - ).map(table.lookup) + dataset = apply_map(input_sentences, + lambda x: string_ops.string_split([x]).values) + dataset = apply_map(dataset, table.lookup) get_next = self.getNext(dataset, requires_initialization=True) @@ -359,14 +435,15 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) - @test_util.run_v1_only("b/123904513") - def testCaptureQueue(self): + # TODO(b/123904513) + @combinations.generate(_test_combinations_with_mode_v1("graph")) + def testCaptureQueue(self, apply_map): elements = np.random.randint(100, size=[200]) queue = data_flow_ops.FIFOQueue(200, dtypes.int64, shapes=[]) enqueue_op = queue.enqueue_many(elements) close_op = queue.close() - dataset = dataset_ops.Dataset.from_tensors(0).repeat( - -1).map(lambda _: queue.dequeue()) + dataset = dataset_ops.Dataset.from_tensors(0).repeat(-1) + dataset = apply_map(dataset, lambda _: queue.dequeue()) get_next = self.getNext(dataset, requires_initialization=True) self.evaluate(enqueue_op) @@ -378,8 +455,8 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): self.evaluate(get_next()) # TODO(b/117581999): Possible deadlock in eager mode, debug. 
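For reference, the tests above now receive their map operation through the injected `apply_map` parameter instead of calling `dataset.map()` directly, so a single body can exercise `Dataset.map` and, presumably, `map_with_legacy_function` in the TF 1.x cases. A minimal sketch of such a parameterized test, assuming the `_test_combinations()` helper defined at the top of map_test.py and the file's existing imports; the test name `testSquare` is hypothetical and not part of the patch:

    @combinations.generate(_test_combinations())
    def testSquare(self, apply_map):
      # `apply_map` is bound to a combinations.NamedObject for this generated
      # case; calling it applies the selected map variant to the dataset.
      dataset = dataset_ops.Dataset.range(5)
      dataset = apply_map(dataset, lambda x: x * x)
      self.assertDatasetProduces(dataset, expected_output=[0, 1, 4, 9, 16])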
- @test_util.run_v1_only("b/120545219") - def testSkipEagerCaptureSameResourceMultipleTimes(self): + @combinations.generate(_test_combinations_with_mode_v1("graph")) + def testCaptureSameResourceMultipleTimes(self, apply_map): elements = np.random.randint(100, size=[200]) queue = data_flow_ops.FIFOQueue( 200, dtypes.int64, shapes=[], shared_name="shared_queue") @@ -389,8 +466,8 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): enqueue_op = queue.enqueue_many(elements) close_op = queue.close() - dataset = dataset_ops.Dataset.from_tensors(0).repeat( - -1).map(lambda _: (queue.dequeue(), queue_2.dequeue())) + dataset = dataset_ops.Dataset.from_tensors(0).repeat(-1) + dataset = apply_map(dataset, lambda _: (queue.dequeue(), queue_2.dequeue())) self.evaluate(enqueue_op) self.evaluate(close_op) @@ -401,9 +478,11 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) - def testSeededStatefulOperatorIsProperlyStateful(self): - dataset = dataset_ops.Dataset.from_tensors(0).repeat( - 10).map(lambda _: random_ops.random_uniform((), seed=11)).batch(2) + @combinations.generate(_test_combinations()) + def testSeededStatefulOperatorIsProperlyStateful(self, apply_map): + dataset = dataset_ops.Dataset.from_tensors(0).repeat(10) + fn = lambda _: random_ops.random_uniform((), seed=11) + dataset = apply_map(dataset, fn).batch(2) get_next = self.getNext(dataset, requires_initialization=True) random_values = [] @@ -422,9 +501,11 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): # Randomness is repeatable given same seed self.assertAllClose(random_values, random_values_2) - def testStatefulMapKeepsStateAcrossIterators(self): - dataset = dataset_ops.Dataset.from_tensors(0).repeat(10).map( - lambda _: random_ops.random_uniform((), seed=11)).repeat(1000).batch(10) + @combinations.generate(_test_combinations()) + def testStatefulMapKeepsStateAcrossIterators(self, apply_map): + dataset = dataset_ops.Dataset.from_tensors(0).repeat(10) + fn = lambda _: random_ops.random_uniform((), seed=11) + dataset = apply_map(dataset, fn).repeat(1000).batch(10) get_next = self.getNext(dataset) random_values = self.evaluate(get_next()) @@ -438,7 +519,8 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): i += 1 self.assertLess(i, 99) - def testStatefulOperationInShortCircuit(self): + @combinations.generate(_test_combinations()) + def testStatefulOperationInShortCircuit(self, apply_map): counter_var = variable_scope.get_variable( "counter", (), dtypes.int32, use_resource=True) @@ -446,7 +528,8 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): counter_var.assign_add(1) return x - dataset = dataset_ops.Dataset.range(10).map(increment_fn) + dataset = dataset_ops.Dataset.range(10) + dataset = apply_map(dataset, increment_fn) get_next = self.getNext(dataset, requires_initialization=True) @@ -459,22 +542,24 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): self.evaluate(get_next()) self.assertEqual(10, self.evaluate(counter_var)) - def testMapDict(self): - dataset = dataset_ops.Dataset.range(10).map( - lambda x: {"foo": x * 2, "bar": x**2}).map( - lambda d: d["foo"] + d["bar"]) + @combinations.generate(_test_combinations()) + def testMapDict(self, apply_map): + dataset = dataset_ops.Dataset.range(10) + dataset = apply_map(dataset, lambda x: {"foo": x * 2, "bar": x**2}) + dataset = apply_map(dataset, lambda d: d["foo"] + d["bar"]) self.assertDatasetProduces( 
dataset, expected_output=[i * 2 + i**2 for i in range(10)]) - def testMapNamedtuple(self, count=10): + @combinations.generate(_test_combinations()) + def testMapNamedtuple(self, apply_map): # construct dataset of tuples - labels = dataset_ops.Dataset.range(count) - images = labels.map(lambda l: -l) + labels = dataset_ops.Dataset.range(10) + images = apply_map(labels, lambda l: -l) dataset_tuple = dataset_ops.Dataset.zip((labels, images)) # convert dataset of tuples to dataset of namedtuples example = namedtuple("Example", ["label", "image"]) - dataset_namedtuple = dataset_tuple.map(example) + dataset_namedtuple = apply_map(dataset_tuple, example) def preprocess_tuple(label, image): image = 2 * image @@ -484,14 +569,14 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): return example._replace(image=2 * example.image) # preprocess both datasets - dataset_tuple = dataset_tuple.map(preprocess_tuple) - dataset_namedtuple = dataset_namedtuple.map(preprocess_namedtuple) + dataset_tuple = apply_map(dataset_tuple, preprocess_tuple) + dataset_namedtuple = apply_map(dataset_namedtuple, preprocess_namedtuple) next_tuple = self.getNext(dataset_tuple) next_namedtuple = self.getNext(dataset_namedtuple) # make sure both datasets contain the same data - for i in range(count): + for i in range(10): tuple_, namedtuple_ = self.evaluate([next_tuple(), next_namedtuple()]) self.assertEqual(tuple_, namedtuple_) self.assertEqual(tuple_, (i, -2 * i)) @@ -499,13 +584,16 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_namedtuple()) - def testUseStepContainerInMap(self): + @combinations.generate(_test_combinations()) + def testUseStepContainerInMap(self, apply_map): row = np.arange(6) - dataset = dataset_ops.Dataset.from_tensors( - row).map(lambda elems: map_fn.map_fn(lambda x: x * x, elems)) + dataset = dataset_ops.Dataset.from_tensors(row) + dataset = apply_map(dataset, + lambda elems: map_fn.map_fn(lambda x: x * x, elems)) self.assertDatasetProduces(dataset, expected_output=[row**2]) - def testCaseAndCondInMap(self): + @combinations.generate(_test_combinations()) + def testCaseAndCondInMap(self, apply_map): def control_map_fn(x, y): @@ -531,13 +619,12 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): pred_fn_pairs, default=multiply, exclusive=True) def build_dataset(row, num): - dataset = dataset_ops.Dataset.from_tensor_slices( - row).map(lambda x: control_map_fn(x, num)) - return self.getNext(dataset) + dataset = dataset_ops.Dataset.from_tensor_slices(row) + return apply_map(dataset, lambda x: control_map_fn(x, num)) row = np.arange(6) for num in [2, 3, 4]: - get_next = build_dataset(row, num) + get_next = self.getNext(build_dataset(row, num)) for i in range(6): self.assertEqual( (i // 2 if i % 2 else i * 2) if (num == 2 or num == 3) else i * 2, @@ -545,7 +632,8 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) - def testCaseInWhileInMap(self): + @combinations.generate(_test_combinations()) + def testCaseInWhileInMap(self, apply_map): def control_map_fn(x, y): @@ -564,22 +652,22 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): pred_fn_pairs, default=multiply, exclusive=True) def build_dataset(row, num): - # pylint: disable=g-long-lambda - dataset = dataset_ops.Dataset.from_tensors( - row).map(lambda elems: map_fn.map_fn( - lambda x: control_map_fn(x, num), elems)) - return 
self.getNext(dataset) + dataset = dataset_ops.Dataset.from_tensors(row) + return apply_map( + dataset, + lambda elems: map_fn.map_fn(lambda x: control_map_fn(x, num), elems)) row = np.arange(6) for num in [2, 3, 4]: - get_next = build_dataset(row, num) + get_next = self.getNext(build_dataset(row, num)) self.assertAllEqual( [x // 2 if (num == 2 or num == 3) else x * 2 for x in row], self.evaluate(get_next())) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) - def testCaseAndCondInWhileInMap(self): + @combinations.generate(_test_combinations()) + def testCaseAndCondInWhileInMap(self, apply_map): def control_map_fn(x, y): @@ -606,11 +694,10 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): row = np.arange(6) num = 2 - # pylint: disable=g-long-lambda - dataset = dataset_ops.Dataset.from_tensors( - row).map(lambda elems: map_fn.map_fn( - lambda x: control_map_fn(x, num), elems)) - # pylint: enable=g-long-lambda + dataset = dataset_ops.Dataset.from_tensors(row) + dataset = apply_map( + dataset, + lambda elems: map_fn.map_fn(lambda x: control_map_fn(x, num), elems)) get_next = self.getNext(dataset) self.assertAllEqual([(x // 2 if x % 2 else x * 2) if @@ -619,17 +706,20 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) - def testNestedListMapDataset(self): - dataset = dataset_ops.Dataset.from_tensors( - [0, 1, 2]).repeat(10).map(lambda a: ([a[1], a[0] + a[2]], a[1])) - + @combinations.generate(_test_combinations()) + def testNestedListMapDataset(self, apply_map): + dataset = dataset_ops.Dataset.from_tensors([0, 1, 2]).repeat(10) + dataset = apply_map(dataset, lambda a: ([a[1], a[0] + a[2]], a[1])) expected_output = [(np.array([1, 2]), 1)] * 10 self.assertDatasetProduces(dataset, expected_output=expected_output) - def testPrefetch(self): - # We will use this event to test that `_map_py_func()` has been - # invoked a certain number of times (6 times, to be exact) after - # consuming fewer elements from the iterator. + @combinations.generate( + combinations.times(_test_combinations(), + combinations.combine(buffer_size=[1, 2, 3, 4]))) + def testPrefetch(self, apply_map, buffer_size): + # We will use this event to test that `_map_py_func()` has been invoked a + # certain number of times (6 times, to be exact) after consuming fewer + # elements from the iterator. ev = threading.Event() set_event_during_invocation = 5 @@ -642,56 +732,38 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): def _map_fn(x): return script_ops.py_func(_map_py_func, [x], x.dtype) - def do_test(buffer_size): - dataset = dataset_ops.Dataset.range(100).map(_map_fn).prefetch( - buffer_size) - - get_next = self.getNext(dataset) - # Simple test that prefetch yields the expected values in the - # expected order. - for i in range(100): - self.assertEqual(i * i, self.evaluate(get_next())) - with self.assertRaises(errors.OutOfRangeError): - self.evaluate(get_next()) - - for buffer_size in [1, 10, 100, 1000]: - do_test(buffer_size) - - # We can indirectly observe that varying the buffer size has the - # intended effect by observing when `ev` is set (on the 6th - # invocation of `_map_py_func()`). + # We can indirectly observe that varying the buffer size has the intended + # effect by observing when `ev` is set (on the 6th invocation of + # `_map_py_func()`). 
# NOTE(mrry): We do not test with `buffer_size == - # set_event_during_invocation`, because we must consume at least - # one element to start the prefetching. - def do_test_ev(buffer_size): - dataset = dataset_ops.Dataset.range(100).map(_map_fn).prefetch( - buffer_size) + # set_event_during_invocation`, because we must consume at least one element + # to start the prefetching. + dataset = dataset_ops.Dataset.range(100) + dataset = apply_map(dataset, _map_fn).prefetch(buffer_size) + get_next = self.getNext(dataset) - get_next = self.getNext(dataset) + event_will_be_set_after_consuming = ( + set_event_during_invocation - buffer_size + 1) - event_will_be_set_after_consuming = ( - set_event_during_invocation - buffer_size + 1) + ev.clear() + for i in range(event_will_be_set_after_consuming): + self.assertFalse(ev.is_set()) + self.assertEqual(i * i, self.evaluate(get_next())) + ev.wait() + for i in range(event_will_be_set_after_consuming, 100): + self.assertEqual(i * i, self.evaluate(get_next())) + with self.assertRaises(errors.OutOfRangeError): + self.evaluate(get_next()) - ev.clear() - for i in range(event_will_be_set_after_consuming): - self.assertFalse(ev.is_set()) - self.assertEqual(i * i, self.evaluate(get_next())) - ev.wait() - for i in range(event_will_be_set_after_consuming, 100): - self.assertEqual(i * i, self.evaluate(get_next())) - with self.assertRaises(errors.OutOfRangeError): - self.evaluate(get_next()) - - for buffer_size in range(1, set_event_during_invocation): - do_test_ev(buffer_size) - - def testReturnList(self): - dataset = dataset_ops.Dataset.range( - 10).map(lambda x: [x, constant_op.constant(37.0)]) + @combinations.generate(_test_combinations()) + def testReturnList(self, apply_map): + dataset = dataset_ops.Dataset.range(10) + dataset = apply_map(dataset, lambda x: [x, constant_op.constant(37.0)]) self.assertDatasetProduces( dataset, expected_output=[(i, 37.0) for i in range(10)]) - def testMultiOutputPyFunc(self): + @combinations.generate(_test_combinations()) + def testMultiOutputPyFunc(self, apply_map): # The `tf.py_func()` op returns a list of tensors for its outputs. 
def _map_fn(x_tensor): def _map_py_func(x): @@ -699,11 +771,13 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): return script_ops.py_func( _map_py_func, [x_tensor], [dtypes.int64, dtypes.float64]) - dataset = dataset_ops.Dataset.range(10).map(_map_fn) + dataset = dataset_ops.Dataset.range(10) + dataset = apply_map(dataset, _map_fn) self.assertDatasetProduces( dataset, expected_output=[(i, 37.0) for i in range(10)]) - def testSparse(self): + @combinations.generate(_test_combinations()) + def testSparse(self, apply_map): def _sparse(i): return sparse_tensor.SparseTensorValue( @@ -711,11 +785,13 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): values=(i * np.array([1])), dense_shape=np.array([1, 1])) - dataset = dataset_ops.Dataset.range(10).map(_sparse) + dataset = dataset_ops.Dataset.range(10) + dataset = apply_map(dataset, _sparse) self.assertDatasetProduces( dataset, expected_output=[_sparse(i) for i in range(10)]) - def testSparseChain(self): + @combinations.generate(_test_combinations()) + def testSparseChain(self, apply_map): def _sparse(i): return sparse_tensor.SparseTensorValue( @@ -727,37 +803,38 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertTrue(sparse_tensor.is_sparse(i)) return sparse_ops.sparse_concat(0, [i, i]) - dataset = dataset_ops.Dataset.range(10).map(_sparse).map(_check) + dataset = dataset_ops.Dataset.range(10) + dataset = apply_map(dataset, _sparse) + dataset = apply_map(dataset, _check) self.assertDatasetProduces( dataset, expected_output=[self.evaluate(_check(_sparse(i))) for i in range(10)]) - def testSparseMapShapeInference(self): - if not context.executing_eagerly(): - self.skipTest("SparseTensor shape inference requires eager mode") + @combinations.generate(_test_combinations_with_mode("eager")) + def testSparseMapShapeInference(self, apply_map): row_lengths = np.random.randint(0, 4, size=128) values = np.ones(np.sum(row_lengths)) sparse = ragged_tensor.RaggedTensor.from_row_lengths( values, row_lengths).to_sparse() dataset = dataset_ops.Dataset.from_tensor_slices(sparse) dataset = dataset.batch(32, drop_remainder=True) - dataset = dataset.map(lambda x: x) + dataset = apply_map(dataset, lambda x: x) self.assertEqual((32, 3), dataset.element_spec.shape) - def testSparseMapShapeInferencePartial(self): - if not context.executing_eagerly(): - self.skipTest("SparseTensor shape inference requires eager mode") + @combinations.generate(_test_combinations_with_mode("eager")) + def testSparseMapShapeInferencePartial(self, apply_map): row_lengths = np.random.randint(0, 4, size=128) values = np.ones(np.sum(row_lengths)) sparse = ragged_tensor.RaggedTensor.from_row_lengths( values, row_lengths).to_sparse() dataset = dataset_ops.Dataset.from_tensor_slices(sparse) dataset = dataset.batch(32, drop_remainder=False) - dataset = dataset.map(lambda x: x) + dataset = apply_map(dataset, lambda x: x) self.assertEqual([None, 3], dataset.element_spec.shape.as_list()) - def testTensorArray(self): + @combinations.generate(_test_combinations()) + def testTensorArray(self, apply_map): def _tensor_array(i): i = math_ops.cast(i, dtypes.int32) @@ -765,11 +842,13 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): tensor_array_ops.TensorArray(dtypes.int32, element_shape=(), size=i) .unstack(math_ops.range(i, dtype=dtypes.int32))) - dataset = dataset_ops.Dataset.range(10).map(_tensor_array) + dataset = dataset_ops.Dataset.range(10) + dataset = apply_map(dataset, _tensor_array) self.assertDatasetProduces( 
dataset, expected_output=[list(range(i)) for i in range(10)]) - def testTensorArrayChain(self): + @combinations.generate(_test_combinations()) + def testTensorArrayChain(self, apply_map): def _tensor_array(i): i = math_ops.cast(i, dtypes.int32) @@ -781,23 +860,28 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertIsInstance(x, tensor_array_ops.TensorArray) return x.identity() - dataset = dataset_ops.Dataset.range(10).map(_tensor_array).map(_check) + dataset = dataset_ops.Dataset.range(10) + dataset = apply_map(dataset, _tensor_array) + dataset = apply_map(dataset, _check) self.assertDatasetProduces( dataset, expected_output=[list(range(i)) for i in range(10)]) - def testRagged(self): + @combinations.generate(_test_combinations()) + def testRagged(self, apply_map): def _ragged(i): return ragged_tensor.RaggedTensor.from_tensor(i * [[1]]) - dataset = dataset_ops.Dataset.range(5).map(_ragged) + dataset = dataset_ops.Dataset.range(5) + dataset = apply_map(dataset, _ragged) self.assertDatasetProduces( dataset, expected_output=[ragged_factory_ops.constant([[i]]) for i in range(5)]) - def testRaggedChain(self): + @combinations.generate(_test_combinations()) + def testRaggedChain(self, apply_map): def _ragged(i): return ragged_tensor.RaggedTensor.from_tensor(i * [[1]]) @@ -806,7 +890,9 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertTrue(ragged_tensor.is_ragged(i)) return ragged_concat_ops.concat([i, i], 0) - dataset = dataset_ops.Dataset.range(10).map(_ragged).map(_concat) + dataset = dataset_ops.Dataset.range(10) + dataset = apply_map(dataset, _ragged) + dataset = apply_map(dataset, _concat) self.assertDatasetProduces( dataset, @@ -815,15 +901,19 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): for i in range(10) ]) - @test_util.run_v1_only("b/123904513") - def testParallelMapOutOfRangeError(self): + # TODO(b/123904513) + @combinations.generate(_test_combinations_with_mode_v1("graph")) + def testParallelMapOutOfRangeError(self, apply_map): + def raising_py_func(i): if i == 100: raise StopIteration() else: return i - dataset = dataset_ops.Dataset.range(105).map( + dataset = dataset_ops.Dataset.range(105) + dataset = apply_map( + dataset, lambda x: script_ops.py_func(raising_py_func, [x], dtypes.int64), num_parallel_calls=2) get_next = self.getNext(dataset) @@ -832,11 +922,15 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) - def testConstantOutput(self): - dataset = dataset_ops.Dataset.range(10).map(lambda x: [x, "hello", 10]) + @combinations.generate(_test_combinations()) + def testConstantOutput(self, apply_map): + dataset = dataset_ops.Dataset.range(10) + dataset = apply_map(dataset, lambda x: [x, "hello", 10]) self.assertDatasetProduces(dataset, [(i, b"hello", 10) for i in range(10)]) - def testWarnOnLookupTable(self): + @combinations.generate(_test_combinations()) + def testWarnOnLookupTable(self, apply_map): + def collecting_function(x): _ = lookup_ops.HashTable( lookup_ops.KeyValueTensorInitializer(["a"], [1.]), 0.0, name="t1") @@ -844,30 +938,8 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): warnings.simplefilter("always") with warnings.catch_warnings(record=True) as w: - _ = dataset_ops.Dataset.range(10).map(collecting_function) - # NOTE(mrry): Python 3 prints other warnings in addition to the one we are - # testing, so we search for the expected warning. 
- self.assertGreaterEqual(len(w), 1) - found_warning = False - for warning in w: - if ("Creating resources inside a function passed to Dataset.map() is " - "not supported." in str(warning)): - found_warning = True - break - self.assertTrue(found_warning) - - @test_util.run_v1_only("map_with_legacy_function v1 only") - def testWarnOnLookupTableLegacyFunction(self): - - def collecting_function(x): - _ = lookup_ops.HashTable( - lookup_ops.KeyValueTensorInitializer(["a"], [1.]), 0.0, name="t1") - return x - - warnings.simplefilter("always") - with warnings.catch_warnings(record=True) as w: - _ = dataset_ops.Dataset.range(10).map_with_legacy_function( - collecting_function) + dataset = dataset_ops.Dataset.range(10) + _ = apply_map(dataset, collecting_function) # NOTE(mrry): Python 3 prints other warnings in addition to the one we are # testing, so we search for the expected warning. self.assertGreaterEqual(len(w), 1) @@ -879,21 +951,25 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): break self.assertTrue(found_warning) + @combinations.generate(test_base.default_test_combinations()) def testWarnOnSeedFromOuterGraph(self): with ops.Graph().as_default() as g: g.seed = 10 warnings.simplefilter("always") + def _check_warning(caught_warnings, expected_result): + found_warning = False + for warning in caught_warnings: + if ("Explicitly set the seed in the function if this is not the " + "intended behavior" in str(warning)): + found_warning = True + break + self.assertEqual(found_warning, expected_result) + # map_fun doesn't use seed, so no warning is generated. with warnings.catch_warnings(record=True) as w: _ = dataset_ops.Dataset.range(10).map(math_ops.square) - found_warning = False - for warning in w: - if ("Explicitly set the seed in the function if this is not the " - "intended behavior" in str(warning)): - found_warning = True - break - self.assertFalse(found_warning) + _check_warning(w, False) def random_func(x): x = math_ops.add(x, 1) @@ -902,14 +978,7 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): with warnings.catch_warnings(record=True) as w: _ = dataset_ops.Dataset.range(10).map(random_func) - self.assertGreaterEqual(len(w), 1) - found_warning = False - for warning in w: - if ("Explicitly set the seed in the function if this is not the " - "intended behavior" in str(warning)): - found_warning = True - break - self.assertTrue(found_warning) + _check_warning(w, True) def random_func_seeded(x): ops.get_default_graph().seed = None @@ -918,41 +987,30 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): with warnings.catch_warnings(record=True) as w: _ = dataset_ops.Dataset.range(10).batch(2).map(random_func_seeded) - found_warning = False - for warning in w: - if ("Explicitly set the seed in the function if this is not the " - "intended behavior" in str(warning)): - found_warning = True - break - self.assertFalse(found_warning) + _check_warning(w, False) with warnings.catch_warnings(record=True) as w: - _ = dataset_ops.Dataset.range(10).batch( - 2).map(lambda x: random_ops.random_shuffle(x, seed=37)) - found_warning = False - for warning in w: - if ("Explicitly set the seed in the function if this is not the " - "intended behavior" in str(warning)): - found_warning = True - break - self.assertFalse(found_warning) + _ = dataset_ops.Dataset.range(10).batch(2).map( + lambda x: random_ops.random_shuffle(x, seed=37)) + _check_warning(w, False) - def testNestedDatasetMap(self): - # TODO(b/110122868): When iterators can yield a 
`tf.data.Dataset`, remove - # the `get_single_element()` call. - dataset = dataset_ops.Dataset.from_tensors([1.0, 2.0, 3.0]).map( - dataset_ops.Dataset.from_tensor_slices).map( - lambda ds: ds.batch(3)).flat_map(lambda x: x) + @combinations.generate(_test_combinations()) + def testNestedDatasetMap(self, apply_map): + dataset = dataset_ops.Dataset.from_tensors([1.0, 2.0, 3.0]) + dataset = apply_map(dataset, dataset_ops.Dataset.from_tensor_slices) + dataset = apply_map(dataset, lambda ds: ds.batch(3)).flat_map(lambda x: x) self.assertDatasetProduces(dataset, expected_output=[[1.0, 2.0, 3.0]]) - def testReturnValueError(self): + @combinations.generate(_test_combinations()) + def testReturnValueError(self, apply_map): dataset = dataset_ops.Dataset.from_tensors([1.0, 2.0, 3.0]) with self.assertRaisesRegexp( TypeError, r"Unsupported return value from function passed to " r"Dataset.map\(\): None."): - _ = dataset.map(lambda x: None) + _ = apply_map(dataset, lambda x: None) + @combinations.generate(test_base.default_test_combinations()) def testBrokenFunctionErrorOnInitialization(self): dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0, 3.0]) @@ -965,8 +1023,7 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): value, dtype=dtypes.float32, shape=[0], verify_shape=False)) dtype_value = attr_value_pb2.AttrValue(type=dtypes.int32.as_datatype_enum) - # Create a "Const" op with a `tf.float32` value and a `tf.int32` type - # attr. + # Create a "Const" op with a `tf.float32` value and a `tf.int32` type. const_tensor = ops.get_default_graph().create_op( "Const", [], [dtypes.int32], attrs={ @@ -980,15 +1037,11 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertDatasetProduces( dataset, expected_error=(errors.InvalidArgumentError, "BrokenConst")) -# pylint: disable=g-long-lambda - @parameterized.named_parameters( - ("Map", lambda dataset, func: - dataset_ops.MapDataset(dataset, func, use_inter_op_parallelism=False)), - ("ParallelMap", lambda dataset, func: - dataset_ops.ParallelMapDataset(dataset, func, num_parallel_calls=1, - use_inter_op_parallelism=False)), - ) - def testNoInterOpParallelism(self, make_dataset_fn): + @combinations.generate( + combinations.times( + _test_combinations_with_mode("graph"), + combinations.combine(num_parallel_calls=[None, 12]))) + def testNoInterOpParallelism(self, apply_map, num_parallel_calls): dataset = dataset_ops.Dataset.from_tensors(0) def _get_tid(): @@ -1000,58 +1053,54 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): tids.append(script_ops.py_func(_get_tid, [], dtypes.int64)) return tids - dataset = make_dataset_fn(dataset, _map_fn) + dataset = apply_map(dataset, _map_fn) + dataset._variant_tensor.op._set_attr("use_inter_op_parallelism", + attr_value_pb2.AttrValue(b=False)) get_next = self.getNext(dataset) tids = self.evaluate(get_next()) self.assertTrue(all(tids[0] == tid for tid in tids)) -# pylint: enable=g-long-lambda - @parameterized.named_parameters( - ("SequentialIdentity", None, lambda x: x, None), - ("SequentialReplicate", None, lambda x: (x, x), None), - ("SequentialSwap", (None, None), lambda x, y: (y, x), None), - ("SequentialProject", (None, None), lambda x, y: x, None), - ("ParallelIdentity", None, lambda x: x, 10), - ("ParallelReplicate", None, lambda x: (x, x), 10), - ("ParallelSwap", (None, None), lambda x, y: (y, x), 10), - ("ParallelProject", (None, None), lambda x, y: x, 10), - ) - def testShortCircuit(self, structure, map_fn, num_parallel_calls): - dataset = 
self.structuredDataset(structure).repeat().map( - map_fn, num_parallel_calls=num_parallel_calls) + @combinations.generate( + combinations.times(_test_combinations(), _short_circuit_test_cases(), + combinations.combine(num_parallel_calls=[None, 12]))) + def testShortCircuit(self, apply_map, structure, fn, num_parallel_calls): + dataset = self.structuredDataset(structure).repeat() + dataset = apply_map(dataset, fn, num_parallel_calls=num_parallel_calls) get_next = self.getNext(dataset) if isinstance(structure, tuple): - expected = map_fn(*self.evaluate(self.structuredElement(structure))) + expected = fn(*self.evaluate(self.structuredElement(structure))) else: - expected = map_fn(self.evaluate(self.structuredElement(structure))) + expected = fn(self.evaluate(self.structuredElement(structure))) self.assertEqual(expected, self.evaluate(get_next())) - @parameterized.named_parameters( - ("Sequential", None), - ("Parallel", 10), - ) - def testShortCircuitCapturedInput(self, num_parallel_calls): + @combinations.generate( + combinations.times(_test_combinations(), + combinations.combine(num_parallel_calls=[None, 12]))) + def testShortCircuitCapturedInput(self, apply_map, num_parallel_calls): captured_t = variables.Variable(42) - dataset = self.structuredDataset(None).repeat().map( - lambda x: captured_t, num_parallel_calls=num_parallel_calls) + dataset = self.structuredDataset(None).repeat() + dataset = apply_map( + dataset, lambda x: captured_t, num_parallel_calls=num_parallel_calls) self.evaluate(variables.global_variables_initializer()) get_next = self.getNext(dataset, requires_initialization=True) self.assertEqual(42, self.evaluate(get_next())) - @parameterized.named_parameters( - ("1", 1, 1), - ("2", 10, 1), - ("3", 10, 10), - ("4", 100, 1), - ("5", 100, 10), - ("6", 100, 100), - ) - def testSloppyInterleaveInOrder(self, num_elements, num_parallel_calls): + @combinations.generate( + combinations.times( + _test_combinations(), + combinations.combine(num_elements=1, num_parallel_calls=1) + + combinations.combine(num_elements=10, num_parallel_calls=1) + + combinations.combine(num_elements=10, num_parallel_calls=10) + + combinations.combine(num_elements=100, num_parallel_calls=1) + + combinations.combine(num_elements=100, num_parallel_calls=10) + + combinations.combine(num_elements=100, num_parallel_calls=100))) + def testSloppyInterleaveInOrder(self, apply_map, num_elements, + num_parallel_calls): dataset, coordination_events = _make_coordinated_sloppy_dataset( - num_elements, num_parallel_calls) + apply_map, num_elements, num_parallel_calls) options = dataset_ops.Options() options.experimental_threading = threading_options.ThreadingOptions() options.experimental_threading.private_threadpool_size = ( @@ -1064,14 +1113,16 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) - @parameterized.named_parameters( - ("1", 10, 10), - ("2", 100, 10), - ("3", 100, 100), - ) - def testSloppyInterleaveOutOfOrder(self, num_elements, num_parallel_calls): + @combinations.generate( + combinations.times( + _test_combinations(), + combinations.combine(num_elements=10, num_parallel_calls=10) + + combinations.combine(num_elements=100, num_parallel_calls=10) + + combinations.combine(num_elements=100, num_parallel_calls=100))) + def testSloppyInterleaveOutOfOrder(self, apply_map, num_elements, + num_parallel_calls): dataset, coordination_events = _make_coordinated_sloppy_dataset( - num_elements, num_parallel_calls) + apply_map, 
num_elements, num_parallel_calls) options = dataset_ops.Options() options.experimental_threading = threading_options.ThreadingOptions() options.experimental_threading.private_threadpool_size = ( @@ -1090,25 +1141,25 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) - @parameterized.named_parameters( - ("Map", None), - ("ParallelMap", 12), - ) + @combinations.generate( + combinations.combine( + tf_api_version=2, + mode=["eager", "graph"], + num_parallel_calls=[None, 12])) def testPreserveCardinality(self, num_parallel_calls): def py_fn(_): raise StopIteration() - dataset = dataset_ops.DatasetV2.from_tensors(0).map( + dataset = dataset_ops.Dataset.from_tensors(0).map( lambda x: script_ops.py_func(py_fn, [x], dtypes.int64), num_parallel_calls=num_parallel_calls) get_next = self.getNext(dataset) with self.assertRaises(errors.InvalidArgumentError): self.evaluate(get_next()) - # NOTE: collection test is specific to graph mode only, no eager coverage. - @test_util.run_v1_only("graph specific test") - def testSkipEagerCollectionCopy(self): + @combinations.generate(_test_combinations_with_mode("graph")) + def testCollectionCopy(self, apply_map): w = variable_scope.get_variable("w", []) self.assertIn(w, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)) @@ -1117,22 +1168,21 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): return x dataset = dataset_ops.Dataset.from_tensors(constant_op.constant(1.0)) - dataset.map(func) + _ = apply_map(dataset, func) - @parameterized.named_parameters( - ("Sequential", None), - ("Parallel", 12), - ) - @test_util.run_v1_only("graph-mode specific test") - def testSkipEagerMapCancellation(self, num_parallel_calls): + @combinations.generate( + combinations.times( + _test_combinations_with_mode_v1("graph"), + combinations.combine(num_parallel_calls=[None, 12]))) + def testMapCancellation(self, apply_map, num_parallel_calls): # Checks that a cancellation of is threaded through to map transformation. queue = data_flow_ops.FIFOQueue(10, dtypes.int32, ()) def fn(_): return queue.dequeue() - dataset = dataset_ops.Dataset.range(1).map( - fn, num_parallel_calls=num_parallel_calls) + dataset = dataset_ops.Dataset.range(1) + dataset = apply_map(dataset, fn, num_parallel_calls=num_parallel_calls) get_next = self.getNext(dataset, requires_initialization=True) with self.cached_session() as sess: @@ -1143,17 +1193,11 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): thread.join() -# TODO(shivaniagarwal): separate out `map` and `map_with_legacy_function` tests -# as later would not work in v2. -@test_util.run_all_in_graph_and_eager_modes -class MapWithCapturedVariableTests(test_base.DatasetTestBase, - parameterized.TestCase): - # TODO(b/126553094): map doesnt work with variable defined inside function in # eager mode, possible Graph tensors leak out of the function building context # from function graph in eager mode as variables are created in init_scope. 
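As a side note on the parameterization used in the surrounding tests: `combinations.times` composes its arguments as a cross product, and list-valued arguments to `combinations.combine` expand into one combination per element. A rough sketch under those assumptions, reusing names that appear in this file (variable names are illustrative only):

    graph_v1_cases = _test_combinations_with_mode_v1("graph")
    parallelism_cases = combinations.combine(num_parallel_calls=[None, 12])
    # Every generated test receives one (apply_map, num_parallel_calls) pair,
    # so a test such as testMapCancellation runs each map variant once with
    # sequential execution (None) and once with 12 parallel calls.
    all_cases = combinations.times(graph_v1_cases, parallelism_cases)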
- @test_util.run_v1_only("b/126553094") - def testSkipEagerCreateVariableInsideFunctionWithGetter(self): + @combinations.generate(test_base.graph_only_combinations()) + def testCreateVariableInsideFunctionWithGetter(self): def func(_): with variable_scope.variable_scope( @@ -1162,12 +1206,13 @@ class MapWithCapturedVariableTests(test_base.DatasetTestBase, "counter", (), dtypes.int32, use_resource=True) return counter_var.assign_add(1) - # NOTE: In the legacy function, resource is captured by value for variable - # getter. dataset = dataset_ops.Dataset.from_tensors(0).repeat(10) - with self.assertRaisesWithPredicateMatch( - AttributeError, "'Tensor' object has no attribute 'assign_add'"): - dataset.map_with_legacy_function(func) + + if hasattr(dataset, "map_with_legacy_function"): + # NOTE: In the legacy function, resource is captured by value. + with self.assertRaisesWithPredicateMatch( + AttributeError, "'Tensor' object has no attribute 'assign_add'"): + dataset.map_with_legacy_function(func) dataset = dataset.map(func) self.evaluate(variables.global_variables_initializer()) @@ -1179,18 +1224,12 @@ class MapWithCapturedVariableTests(test_base.DatasetTestBase, with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) - @parameterized.named_parameters( - ("MapLegacyFunction", - lambda dataset, func: dataset.map_with_legacy_function(func)), - ("Map", lambda dataset, func: dataset.map(func)), - ) - @test_util.run_v1_only("map_with_legacy_function is only available in v1.") - def testCaptureVariable(self, transformation_function): + @combinations.generate(_test_combinations()) + def testCaptureVariable(self, apply_map): counter_var = variable_scope.get_variable( "counter", (), dtypes.int32, use_resource=True) dataset = dataset_ops.Dataset.from_tensors(0).repeat(10) - dataset = transformation_function( - dataset, lambda _: counter_var.assign_add(1)) + dataset = apply_map(dataset, lambda _: counter_var.assign_add(1)) get_next = self.getNext(dataset, requires_initialization=True) self.evaluate(counter_var.initializer) @@ -1203,34 +1242,20 @@ class MapWithCapturedVariableTests(test_base.DatasetTestBase, self.evaluate(get_next()) self.assertEqual(10, self.evaluate(counter_var)) - # NOTE: no need to explicitly initialize variables in eager mode. - @parameterized.named_parameters( - ("MapLegacyFunction", - lambda dataset, func: dataset.map_with_legacy_function(func)), - ("Map", lambda dataset, func: dataset.map(func)), - ) - @test_util.run_v1_only("this test is meant to run in graph mode only.") - def testSkipEagerCaptureUninitializedVariableError(self, - transformation_function): + @combinations.generate(_test_combinations_with_mode_v1("graph")) + def testCaptureUninitializedVariableError(self, apply_map): counter_var = variable_scope.get_variable( "counter", (), dtypes.int32, use_resource=True) dataset = dataset_ops.Dataset.from_tensors(0).repeat(10) - dataset = transformation_function( - dataset, lambda _: counter_var.assign_add(1)) + dataset = apply_map(dataset, lambda _: counter_var.assign_add(1)) get_next = self.getNext(dataset, requires_initialization=True) with self.assertRaises(errors.NotFoundError): self.evaluate(get_next()) # TODO(b/121264236): add eager mode coverage when we have multi-device setup. 
- @parameterized.named_parameters( - ("MapLegacyFunction", - lambda dataset, func: dataset.map_with_legacy_function(func)), - ("Map", lambda dataset, func: dataset.map(func)), - ) - @test_util.run_v1_only("b/121264236") - def testSkipEagerCaptureConstantsWithConflictingDevices( - self, transformation_function): + @combinations.generate(_test_combinations_with_mode_v1("graph")) + def testCaptureConstantsWithConflictingDevices(self, apply_map): config = config_pb2.ConfigProto(device_count={"CPU": 3}) with self.cached_session(config=config): with ops.device("/device:CPU:0"): @@ -1242,13 +1267,13 @@ class MapWithCapturedVariableTests(test_base.DatasetTestBase, return math_ops.add(a, b) dataset = dataset_ops.Dataset.from_tensors(0).repeat(10) - dataset = transformation_function(dataset, func) + dataset = apply_map(dataset, func) expected_output = [8.0] * 10 self.assertDatasetProduces(dataset, expected_output=expected_output) # TODO(b/121264236): add eager mode coverage when we have multi-device setup. - @test_util.run_v1_only("b/121264236") - def testSkipEagerRefVariablesWithMultipleDevices(self): + @combinations.generate(_test_combinations_with_mode_v1("graph")) + def testReferenceVariablesWithMultipleDevices(self, apply_map): config = config_pb2.ConfigProto(device_count={"CPU": 3}) with self.cached_session(config=config): @@ -1262,7 +1287,7 @@ class MapWithCapturedVariableTests(test_base.DatasetTestBase, # NOTE: Use the legacy function implementation as eager function will # convert RefVariables to ResourceVariables. dataset = dataset_ops.Dataset.from_tensors(0).repeat(10) - dataset = dataset.map_with_legacy_function(func) + dataset = apply_map(dataset, func) self.evaluate(variables.global_variables_initializer()) expected_output = [8.0] * 10 self.assertDatasetProduces( @@ -1271,8 +1296,8 @@ class MapWithCapturedVariableTests(test_base.DatasetTestBase, requires_initialization=True) # TODO(b/121264236): add eager mode coverage when we have multi-device setup. - @test_util.run_v1_only("b/121264236") - def testSkipEagerResourceVariablesWithMultipleDevices(self): + @combinations.generate(_test_combinations_with_mode_v1("graph")) + def testResourceVariablesWithMultipleDevices(self, apply_map): config = config_pb2.ConfigProto(device_count={"CPU": 3}) def func(_): @@ -1287,25 +1312,10 @@ class MapWithCapturedVariableTests(test_base.DatasetTestBase, "b", (), dtypes.int32, use_resource=True) return math_ops.add(a_var, b_var) - g_1 = ops.Graph() - with self.session(config=config, graph=g_1): - # The MapDataset node ends up with two ResourceVariable inputs, one on - # device CPU:0 and the other on device CPU:1. + g = ops.Graph() + with self.session(config=config, graph=g): dataset = dataset_ops.Dataset.from_tensors(0).repeat(10) - dataset = dataset.map(func) - self.evaluate(variables.global_variables_initializer()) - expected_output = [1] * 10 - self.assertDatasetProduces( - dataset, - expected_output=expected_output, - requires_initialization=True) - - g_2 = ops.Graph() - with self.session(config=config, graph=g_2): - # In old-Defun variable is captured as value, hence there is no colocation - # error. 
- dataset = dataset_ops.Dataset.from_tensors(0).repeat(10) - dataset = dataset.map_with_legacy_function(func) + dataset = apply_map(dataset, func) self.evaluate(variables.global_variables_initializer()) expected_output = [1] * 10 self.assertDatasetProduces( diff --git a/tensorflow/python/data/kernel_tests/memory_cleanup_test.py b/tensorflow/python/data/kernel_tests/memory_cleanup_test.py index a2015ef47d1..5b0ea02a054 100644 --- a/tensorflow/python/data/kernel_tests/memory_cleanup_test.py +++ b/tensorflow/python/data/kernel_tests/memory_cleanup_test.py @@ -119,8 +119,7 @@ class MemoryCleanupTest(test_base.DatasetTestBase, parameterized.TestCase): ] self.assertEmpty(tensors, "%d Tensors are still alive." % len(tensors)) - @combinations.generate( - combinations.combine(tf_api_version=[1, 2], mode="eager")) + @combinations.generate(test_base.eager_only_combinations()) def testFilter(self): def get_dataset(): @@ -144,8 +143,7 @@ class MemoryCleanupTest(test_base.DatasetTestBase, parameterized.TestCase): self._testIteratorMemoryLeak(get_dataset) - @combinations.generate( - combinations.combine(tf_api_version=[1, 2], mode="eager")) + @combinations.generate(test_base.eager_only_combinations()) def testFlatMap(self): def get_dataset(): @@ -157,8 +155,7 @@ class MemoryCleanupTest(test_base.DatasetTestBase, parameterized.TestCase): self._testIteratorMemoryLeak(get_dataset) - @combinations.generate( - combinations.combine(tf_api_version=[1, 2], mode="eager")) + @combinations.generate(test_base.eager_only_combinations()) def testFromGenerator(self): def get_dataset(): @@ -171,8 +168,8 @@ class MemoryCleanupTest(test_base.DatasetTestBase, parameterized.TestCase): self._testIteratorMemoryLeak(get_dataset) @combinations.generate( - combinations.combine( - tf_api_version=[1, 2], mode="eager", num_parallel_calls=[None, 10])) + combinations.times(test_base.eager_only_combinations(), + combinations.combine(num_parallel_calls=[None, 10]))) def testMap(self, num_parallel_calls): def get_dataset(): @@ -201,8 +198,8 @@ class MemoryCleanupTest(test_base.DatasetTestBase, parameterized.TestCase): self._testIteratorMemoryLeak(get_dataset) @combinations.generate( - combinations.combine( - tf_api_version=[1, 2], mode="eager", num_parallel_calls=[None, 10])) + combinations.times(test_base.eager_only_combinations(), + combinations.combine(num_parallel_calls=[None, 10]))) def testInterleave(self, num_parallel_calls): def get_dataset(): diff --git a/tensorflow/python/data/kernel_tests/optional_test.py b/tensorflow/python/data/kernel_tests/optional_test.py index 3ab6717b9c3..f0795563d09 100644 --- a/tensorflow/python/data/kernel_tests/optional_test.py +++ b/tensorflow/python/data/kernel_tests/optional_test.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools + from absl.testing import parameterized import numpy as np @@ -27,6 +29,7 @@ from tensorflow.python.data.ops import optional_ops from tensorflow.python.data.util import structure from tensorflow.python.eager import context from tensorflow.python.eager import def_function +from tensorflow.python.framework import combinations from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -40,14 +43,90 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -@test_util.run_all_in_graph_and_eager_modes +def _optional_spec_test_combinations(): + # pylint: 
disable=g-long-lambda + cases = [ + ("Dense", lambda: constant_op.constant(37.0), + tensor_spec.TensorSpec([], dtypes.float32)), + ("Sparse", lambda: sparse_tensor.SparseTensor( + indices=[[0, 1]], + values=constant_op.constant([0], dtype=dtypes.int32), + dense_shape=[10, 10]), + sparse_tensor.SparseTensorSpec([10, 10], dtypes.int32)), + ("Nest", lambda: { + "a": constant_op.constant(37.0), + "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar")) + }, { + "a": + tensor_spec.TensorSpec([], dtypes.float32), + "b": ( + tensor_spec.TensorSpec([1], dtypes.string), + tensor_spec.TensorSpec([], dtypes.string), + ) + }), + ("Optional", lambda: optional_ops.Optional.from_value(37.0), + optional_ops.OptionalSpec(tensor_spec.TensorSpec([], dtypes.float32))), + ] + + def reduce_fn(x, y): + name, value_fn, expected_structure = y + return x + combinations.combine( + tf_value_fn=combinations.NamedObject(name, value_fn), + expected_value_structure=expected_structure) + + return functools.reduce(reduce_fn, cases, []) + + +def _get_next_as_optional_test_combinations(): + # pylint: disable=g-long-lambda + cases = [ + ("Dense", np.array([1, 2, 3], dtype=np.int32), + lambda: constant_op.constant([4, 5, 6], dtype=dtypes.int32), True), + ("Sparse", + sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 1]], + values=np.array([-1., 1.], dtype=np.float32), + dense_shape=[2, 2]), + lambda: sparse_tensor.SparseTensor( + indices=[[0, 1], [1, 0]], values=[37.0, 42.0], dense_shape=[2, 2]), + False), + ("Nest", { + "a": + np.array([1, 2, 3], dtype=np.int32), + "b": + sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 1]], + values=np.array([-1., 1.], dtype=np.float32), + dense_shape=[2, 2]) + }, lambda: { + "a": + constant_op.constant([4, 5, 6], dtype=dtypes.int32), + "b": + sparse_tensor.SparseTensor( + indices=[[0, 1], [1, 0]], + values=[37.0, 42.0], + dense_shape=[2, 2]) + }, False), + ] + + def reduce_fn(x, y): + name, value, value_fn, gpu_compatible = y + return x + combinations.combine( + np_value=value, tf_value_fn=combinations.NamedObject(name, value_fn), + gpu_compatible=gpu_compatible) + + return functools.reduce(reduce_fn, cases, []) + + class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase): + @combinations.generate(test_base.default_test_combinations()) def testFromValue(self): opt = optional_ops.Optional.from_value(constant_op.constant(37.0)) self.assertTrue(self.evaluate(opt.has_value())) self.assertEqual(37.0, self.evaluate(opt.get_value())) + @combinations.generate(test_base.default_test_combinations()) def testFromStructuredValue(self): opt = optional_ops.Optional.from_value({ "a": constant_op.constant(37.0), @@ -59,6 +138,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase): "b": ([b"Foo"], b"Bar") }, self.evaluate(opt.get_value())) + @combinations.generate(test_base.default_test_combinations()) def testFromSparseTensor(self): st_0 = sparse_tensor.SparseTensorValue( indices=np.array([[0]]), @@ -77,6 +157,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertAllEqual(expected.dense_shape, self.evaluate(actual.dense_shape)) + @combinations.generate(test_base.default_test_combinations()) def testFromNone(self): value_structure = tensor_spec.TensorSpec([], dtypes.float32) opt = optional_ops.Optional.none_from_structure(value_structure) @@ -91,6 +172,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(errors.InvalidArgumentError): self.evaluate(opt.get_value()) + 
@combinations.generate(test_base.default_test_combinations()) def testAddN(self): devices = ["/cpu:0"] if test_util.is_gpu_available(): @@ -117,6 +199,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase): opt_none1.value_structure) self.assertFalse(self.evaluate(add_opt.has_value())) + @combinations.generate(test_base.default_test_combinations()) def testNestedAddN(self): devices = ["/cpu:0"] if test_util.is_gpu_available(): @@ -137,6 +220,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase): opt1.value_structure) self.assertAllEqual(inner_add_opt.get_value(), [4, 6.0]) + @combinations.generate(test_base.default_test_combinations()) def testZerosLike(self): devices = ["/cpu:0"] if test_util.is_gpu_available(): @@ -159,6 +243,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase): opt_none.value_structure) self.assertFalse(self.evaluate(zeros_opt.has_value())) + @combinations.generate(test_base.default_test_combinations()) def testNestedZerosLike(self): devices = ["/cpu:0"] if test_util.is_gpu_available(): @@ -175,6 +260,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase): opt1.value_structure) self.assertEqual(self.evaluate(inner_zeros_opt.get_value()), 0.0) + @combinations.generate(test_base.default_test_combinations()) def testCopyToGPU(self): if not test_util.is_gpu_available(): self.skipTest("No GPU available") @@ -204,6 +290,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase): self.evaluate(gpu_optional_with_value_values)) self.assertFalse(self.evaluate(gpu_optional_none_has_value)) + @combinations.generate(test_base.default_test_combinations()) def testNestedCopyToGPU(self): if not test_util.is_gpu_available(): self.skipTest("No GPU available") @@ -239,42 +326,10 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertFalse(self.evaluate(inner_none.has_value())) self.assertEqual(1.0, self.evaluate(gpu_nested_optional_values[2])) - def _assertElementValueEqual(self, expected, actual): - if isinstance(expected, dict): - self.assertItemsEqual(list(expected.keys()), list(actual.keys())) - for k in expected.keys(): - self._assertElementValueEqual(expected[k], actual[k]) - elif isinstance(expected, sparse_tensor.SparseTensorValue): - self.assertAllEqual(expected.indices, actual.indices) - self.assertAllEqual(expected.values, actual.values) - self.assertAllEqual(expected.dense_shape, actual.dense_shape) - else: - self.assertAllEqual(expected, actual) - - # pylint: disable=g-long-lambda - @parameterized.named_parameters( - ("Tensor", lambda: constant_op.constant(37.0), - tensor_spec.TensorSpec([], dtypes.float32)), - ("SparseTensor", lambda: sparse_tensor.SparseTensor( - indices=[[0, 1]], - values=constant_op.constant([0], dtype=dtypes.int32), - dense_shape=[10, 10]), - sparse_tensor.SparseTensorSpec([10, 10], dtypes.int32)), - ("Nest", lambda: { - "a": constant_op.constant(37.0), - "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar")) - }, { - "a": - tensor_spec.TensorSpec([], dtypes.float32), - "b": ( - tensor_spec.TensorSpec([1], dtypes.string), - tensor_spec.TensorSpec([], dtypes.string), - ) - }), - ("Optional", lambda: optional_ops.Optional.from_value(37.0), - optional_ops.OptionalSpec( - tensor_spec.TensorSpec([], dtypes.float32))), - ) + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + _optional_spec_test_combinations())) def testOptionalSpec(self, tf_value_fn, expected_value_structure): tf_value 
= tf_value_fn() opt = optional_ops.Optional.from_value(tf_value) @@ -304,36 +359,21 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase): round_trip_opt = opt_structure._from_tensor_list( opt_structure._to_tensor_list(opt)) if isinstance(tf_value, optional_ops.Optional): - self._assertElementValueEqual( + self.assertValuesEqual( self.evaluate(tf_value.get_value()), self.evaluate(round_trip_opt.get_value().get_value())) else: - self._assertElementValueEqual( + self.assertValuesEqual( self.evaluate(tf_value), self.evaluate(round_trip_opt.get_value())) - @parameterized.named_parameters( - ("Tensor", np.array([1, 2, 3], dtype=np.int32), - lambda: constant_op.constant([4, 5, 6], dtype=dtypes.int32), True), - ("SparseTensor", sparse_tensor.SparseTensorValue( - indices=[[0, 0], [1, 1]], - values=np.array([-1., 1.], dtype=np.float32), dense_shape=[2, 2]), - lambda: sparse_tensor.SparseTensor( - indices=[[0, 1], [1, 0]], values=[37.0, 42.0], dense_shape=[2, 2]), - False), - ("Nest", {"a": np.array([1, 2, 3], dtype=np.int32), - "b": sparse_tensor.SparseTensorValue( - indices=[[0, 0], [1, 1]], - values=np.array([-1., 1.], dtype=np.float32), - dense_shape=[2, 2])}, - lambda: {"a": constant_op.constant([4, 5, 6], dtype=dtypes.int32), - "b": sparse_tensor.SparseTensor( - indices=[[0, 1], [1, 0]], values=[37.0, 42.0], - dense_shape=[2, 2])}, False), - ) + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + _get_next_as_optional_test_combinations())) def testIteratorGetNextAsOptional(self, np_value, tf_value_fn, - works_on_gpu): - if not works_on_gpu and test.is_gpu_available(): + gpu_compatible): + if not gpu_compatible and test.is_gpu_available(): self.skipTest("Test case not yet supported on GPU.") ds = dataset_ops.Dataset.from_tensors(np_value).repeat(3) @@ -348,7 +388,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase): next_elem.value_structure, structure.type_spec_from_value(tf_value_fn()))) self.assertTrue(next_elem.has_value()) - self._assertElementValueEqual(np_value, next_elem.get_value()) + self.assertValuesEqual(np_value, next_elem.get_value()) # After exhausting the iterator, `next_elem.has_value()` will evaluate to # false, and attempting to get the value will fail. for _ in range(2): @@ -379,7 +419,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase): elem_has_value, elem_value = self.evaluate( [elem_has_value_t, elem_value_t]) self.assertTrue(elem_has_value) - self._assertElementValueEqual(np_value, elem_value) + self.assertValuesEqual(np_value, elem_value) # After exhausting the iterator, `next_elem.has_value()` will evaluate to # false, and attempting to get the value will fail. 
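For readers unfamiliar with the parameterization used throughout these hunks, the pattern is roughly the following (a minimal sketch, not part of the patch; the `count` values are made up for illustration). Each hunk below applies this same pattern, only swapping in the test-specific parameters.

from absl.testing import parameterized

from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.platform import test


class ExampleTest(test_base.DatasetTestBase, parameterized.TestCase):

  # default_test_combinations() yields the standard (tf_api_version, mode)
  # grid; combinations.times() takes its cross product with the extra
  # `count` values, and combinations.generate() turns every resulting
  # combination into a separate parameterized test case.
  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(count=[0, 3, 7])))
  def testExample(self, count):
    dataset = dataset_ops.Dataset.range(count)
    self.assertDatasetProduces(dataset, expected_output=range(count))


if __name__ == "__main__":
  test.main()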
@@ -388,6 +428,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(errors.InvalidArgumentError): self.evaluate(elem_value_t) + @combinations.generate(test_base.default_test_combinations()) def testFunctionBoundaries(self): @def_function.function def get_optional(): @@ -407,6 +448,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase): val = consume_optional(opt_tensor) self.assertEqual(self.evaluate(val), 1.0) + @combinations.generate(test_base.default_test_combinations()) def testLimitedRetracing(self): trace_count = [0] diff --git a/tensorflow/python/data/kernel_tests/options_test.py b/tensorflow/python/data/kernel_tests/options_test.py index 222d8c6f1a6..b38d008b833 100644 --- a/tensorflow/python/data/kernel_tests/options_test.py +++ b/tensorflow/python/data/kernel_tests/options_test.py @@ -18,25 +18,31 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized + from tensorflow.python.data.experimental.ops import optimization_options from tensorflow.python.data.experimental.ops import stats_options from tensorflow.python.data.experimental.ops import threading_options from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import combinations from tensorflow.python.platform import test -class OptionsTest(test_base.DatasetTestBase): +class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase): + @combinations.generate(test_base.default_test_combinations()) def testOptionsDefault(self): ds = dataset_ops.Dataset.range(0) self.assertEqual(dataset_ops.Options(), ds.options()) + @combinations.generate(test_base.default_test_combinations()) def testOptionsOnce(self): options = dataset_ops.Options() ds = dataset_ops.Dataset.range(0).with_options(options).cache() self.assertEqual(options, ds.options()) + @combinations.generate(test_base.default_test_combinations()) def testOptionsTwiceSame(self): options = dataset_ops.Options() options.experimental_optimization.autotune = True @@ -44,6 +50,7 @@ class OptionsTest(test_base.DatasetTestBase): options) self.assertEqual(options, ds.options()) + @combinations.generate(test_base.default_test_combinations()) def testOptionsTwiceDifferent(self): options1 = dataset_ops.Options() options1.experimental_optimization.autotune = True @@ -55,6 +62,7 @@ class OptionsTest(test_base.DatasetTestBase): # Explicitly check that flag is False since assertFalse allows None self.assertIs(ds.options().experimental_deterministic, False) + @combinations.generate(test_base.default_test_combinations()) def testOptionsTwiceDifferentError(self): options1 = dataset_ops.Options() options1.experimental_optimization.autotune = True @@ -64,6 +72,7 @@ class OptionsTest(test_base.DatasetTestBase): "Cannot merge incompatible values"): dataset_ops.Dataset.range(0).with_options(options1).with_options(options2) + @combinations.generate(test_base.default_test_combinations()) def testOptionsMergeOptionsFromMultipleInputs(self): options1 = dataset_ops.Options() options1.experimental_optimization.autotune = True @@ -75,6 +84,7 @@ class OptionsTest(test_base.DatasetTestBase): self.assertTrue(ds.options().experimental_optimization.autotune) self.assertTrue(ds.options().experimental_deterministic) + @combinations.generate(test_base.default_test_combinations()) def testOptionsHaveDefaults(self): options1 = dataset_ops.Options() options2 = 
dataset_ops.Options() @@ -84,12 +94,11 @@ class OptionsTest(test_base.DatasetTestBase): options2.experimental_stats) self.assertIsNot(options1.experimental_threading, options2.experimental_threading) - self.assertEquals(options1.experimental_optimization, - optimization_options.OptimizationOptions()) - self.assertEquals(options1.experimental_stats, - stats_options.StatsOptions()) - self.assertEquals(options1.experimental_threading, - threading_options.ThreadingOptions()) + self.assertEqual(options1.experimental_optimization, + optimization_options.OptimizationOptions()) + self.assertEqual(options1.experimental_stats, stats_options.StatsOptions()) + self.assertEqual(options1.experimental_threading, + threading_options.ThreadingOptions()) if __name__ == "__main__": diff --git a/tensorflow/python/data/kernel_tests/padded_batch_test.py b/tensorflow/python/data/kernel_tests/padded_batch_test.py index 39339c0063a..a3b8f3945f3 100644 --- a/tensorflow/python/data/kernel_tests/padded_batch_test.py +++ b/tensorflow/python/data/kernel_tests/padded_batch_test.py @@ -23,43 +23,30 @@ import numpy as np from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import combinations from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import string_ops from tensorflow.python.platform import test from tensorflow.python.util import compat -def _random_seq_lens(count): - return np.random.randint(20, size=(count,)).astype(np.int32) - - -@test_util.run_all_in_graph_and_eager_modes class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase): - @parameterized.named_parameters( - ('default_padding', _random_seq_lens(32), 4, [-1], False), - ('constant_padding', _random_seq_lens(32), 4, [25], False), - ('uneven_with_remainder', _random_seq_lens(34), 4, [-1], False), - ('uneven_without_remainder', _random_seq_lens(34), 4, [-1], True), - ) - def testPaddedBatchDataset(self, seq_lens, batch_size, padded_shapes, - drop_remainder): - """Tests the padded batch dataset logic for various input configurations. 
- - Args: - seq_lens: the input sequence lengths - batch_size: the batch size - padded_shapes: the padded shapes to use - drop_remainder: whether a smaller batch size should be produced if batch - size does not divide number of inputs evenly - """ - + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine( + count=[32, 34], + padded_shapes=[[None], [25]], + drop_remainder=[True, False]))) + def testPaddedBatchDataset(self, count, padded_shapes, drop_remainder): + seq_lens = np.random.randint(20, size=(count,)).astype(np.int32) + batch_size = 4 dataset = dataset_ops.Dataset.from_tensor_slices(seq_lens).map( lambda x: array_ops.fill([x], x)).padded_batch( batch_size=batch_size, @@ -81,7 +68,9 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase): if not drop_remainder and len(seq_lens) % batch_size > 0: result = self.evaluate(get_next()) - padded_len = np.max(result) if result.size > 0 else 0 + padded_len = padded_shapes[0] + if padded_len is None or padded_len == -1: + padded_len = np.max(result) if result.size > 0 else 0 self.assertEqual((len(seq_lens) % batch_size, padded_len), result.shape) for j in range(len(seq_lens) % batch_size): seq_len = seq_lens[num_full_batches * batch_size + j] @@ -93,7 +82,7 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) - @test_util.run_deprecated_v1 + @combinations.generate(test_base.default_test_combinations()) def testPaddedBatchShortPadding(self): dataset = ( dataset_ops.Dataset.from_tensor_slices( @@ -102,6 +91,7 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertDatasetProduces( dataset, expected_error=(errors.DataLossError, '')) + @combinations.generate(test_base.default_test_combinations()) def testPaddedBatchEmptyTensors(self): dataset = ( dataset_ops.Dataset.from_tensor_slices( @@ -109,6 +99,7 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase): batch_size=4, padded_shapes=[-1])) self.assertDatasetProduces(dataset, expected_output=[[[], [], [], []]]) + @combinations.generate(test_base.default_test_combinations()) def testPaddedBatchDatasetNonDefaultPadding(self): def fill_tuple(x): @@ -139,6 +130,7 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) + @combinations.generate(test_base.default_test_combinations()) def testPaddedBatchDatasetUnicode(self): # See GitHub issue 16149 def generator(): @@ -156,9 +148,8 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase): next_element = self.getNext(padded_dataset) self.evaluate(next_element()) - # NOTE: This test is specific to graph mode and is skipped in eager mode. 
- @test_util.run_deprecated_v1 - def testSkipEagerPaddedBatchDatasetShapeSpecifications(self): + @combinations.generate(test_base.graph_only_combinations()) + def testPaddedBatchDatasetShapeSpecifications(self): int_placeholder = array_ops.placeholder(dtypes.int32) float_placeholder = array_ops.placeholder(dtypes.float32) string_placeholder = array_ops.placeholder(dtypes.string) @@ -190,6 +181,7 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertEqual([None, None, None], dataset_output_shapes[1].as_list()) self.assertEqual([None, 37], dataset_output_shapes[2].as_list()) + @combinations.generate(test_base.default_test_combinations()) def testPaddedBatchSparseError(self): def _map_fn(i): @@ -199,6 +191,7 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(TypeError): _ = dataset_ops.Dataset.range(10).map(_map_fn).padded_batch(10) + @combinations.generate(test_base.default_test_combinations()) def testPaddedBatchShapeError(self): with self.assertRaisesRegexp( ValueError, r'The padded shape \(1,\) is not compatible with the ' @@ -230,9 +223,8 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase): _ = dataset_ops.Dataset.range(10).padded_batch( 5, padded_shapes=shape_as_tensor) - # NOTE: This test is specific to graph mode and is skipped in eager mode. - @test_util.run_deprecated_v1 - def testSkipEagerPaddedBatchShapeError(self): + @combinations.generate(test_base.graph_only_combinations()) + def testPaddedBatchShapeErrorPlaceholder(self): with self.assertRaisesRegexp( ValueError, r'The padded shape \((\?|None), (\?|None)\) is not compatible with the ' diff --git a/tensorflow/python/data/kernel_tests/prefetch_test.py b/tensorflow/python/data/kernel_tests/prefetch_test.py index 427fbf1d29f..c6d2877ee7c 100644 --- a/tensorflow/python/data/kernel_tests/prefetch_test.py +++ b/tensorflow/python/data/kernel_tests/prefetch_test.py @@ -23,36 +23,41 @@ from absl.testing import parameterized from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import combinations from tensorflow.python.framework import errors -from tensorflow.python.framework import test_util from tensorflow.python.platform import test -@test_util.run_all_in_graph_and_eager_modes class PrefetchTest(test_base.DatasetTestBase, parameterized.TestCase): - @parameterized.parameters((-1), (0), (5)) + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + combinations.combine(buffer_size=[-1, None, 0, 42]))) def testBufferSize(self, buffer_size): dataset = dataset_ops.Dataset.range(10).prefetch(buffer_size=buffer_size) self.assertDatasetProduces(dataset, expected_output=range(10)) - @parameterized.parameters((-2), (-42)) + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + combinations.combine(buffer_size=[-2, -42]))) def testInvalidBufferSize(self, buffer_size): with self.assertRaises(errors.InvalidArgumentError): dataset = dataset_ops.Dataset.range(10).prefetch(buffer_size=buffer_size) self.evaluate(dataset._variant_tensor) - @parameterized.parameters(*[(buffer_size, slack_period) - for buffer_size in (-1, None, 0, 5) - for slack_period in (1, 8)]) + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine( + buffer_size=[-1, None, 0, 42], slack_period=[1, 8]))) def testPrefetchWithSlack(self, buffer_size, slack_period): dataset = 
dataset_ops.Dataset.range(100) dataset = dataset_ops.PrefetchDataset( dataset, buffer_size, slack_period=slack_period) self.assertDatasetProduces(dataset, expected_output=range(100)) - @test_util.run_v1_only("graph-mode specific test") - def testSkipEagerPrefetchCancellation(self): + @combinations.generate(combinations.combine(tf_api_version=1, mode="graph")) + def testPrefetchCancellation(self): def map_py_fn(x): while x > -1: diff --git a/tensorflow/python/data/kernel_tests/range_test.py b/tensorflow/python/data/kernel_tests/range_test.py index b7ac60c3fff..d136565ce42 100644 --- a/tensorflow/python/data/kernel_tests/range_test.py +++ b/tensorflow/python/data/kernel_tests/range_test.py @@ -17,51 +17,60 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized + from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import combinations from tensorflow.python.framework import errors -from tensorflow.python.framework import test_util from tensorflow.python.platform import test -@test_util.run_all_in_graph_and_eager_modes -class RangeTest(test_base.DatasetTestBase): +class RangeTest(test_base.DatasetTestBase, parameterized.TestCase): + @combinations.generate(test_base.default_test_combinations()) def testStop(self): dataset = dataset_ops.Dataset.range(5) self.assertDatasetProduces(dataset, expected_output=range(5)) + @combinations.generate(test_base.default_test_combinations()) def testStartStop(self): start, stop = 2, 5 dataset = dataset_ops.Dataset.range(start, stop) self.assertDatasetProduces(dataset, expected_output=range(2, 5)) + @combinations.generate(test_base.default_test_combinations()) def testStartStopStep(self): start, stop, step = 2, 10, 2 dataset = dataset_ops.Dataset.range(start, stop, step) self.assertDatasetProduces(dataset, expected_output=range(2, 10, 2)) + @combinations.generate(test_base.default_test_combinations()) def testZeroStep(self): start, stop, step = 2, 10, 0 with self.assertRaises(errors.InvalidArgumentError): dataset = dataset_ops.Dataset.range(start, stop, step) self.evaluate(dataset._variant_tensor) + @combinations.generate(test_base.default_test_combinations()) def testNegativeStep(self): start, stop, step = 2, 10, -1 dataset = dataset_ops.Dataset.range(start, stop, step) self.assertDatasetProduces(dataset, expected_output=range(2, 10, -1)) + @combinations.generate(test_base.default_test_combinations()) def testStopLessThanStart(self): start, stop = 10, 2 dataset = dataset_ops.Dataset.range(start, stop) self.assertDatasetProduces(dataset, expected_output=range(10, 2)) + @combinations.generate(test_base.default_test_combinations()) def testStopLessThanStartWithPositiveStep(self): start, stop, step = 10, 2, 2 dataset = dataset_ops.Dataset.range(start, stop, step) self.assertDatasetProduces(dataset, expected_output=range(10, 2, 2)) + @combinations.generate(test_base.default_test_combinations()) def testStopLessThanStartWithNegativeStep(self): start, stop, step = 10, 2, -1 dataset = dataset_ops.Dataset.range(start, stop, step) diff --git a/tensorflow/python/data/kernel_tests/repeat_test.py b/tensorflow/python/data/kernel_tests/repeat_test.py index 8a8537b30cf..c4262fcc08c 100644 --- a/tensorflow/python/data/kernel_tests/repeat_test.py +++ b/tensorflow/python/data/kernel_tests/repeat_test.py @@ -17,43 +17,33 @@ from __future__ import absolute_import from __future__ import division 
from __future__ import print_function +from absl.testing import parameterized import numpy as np from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import test_util +from tensorflow.python.framework import combinations from tensorflow.python.platform import test -@test_util.run_all_in_graph_and_eager_modes -class RepeatTest(test_base.DatasetTestBase): +class RepeatTest(test_base.DatasetTestBase, parameterized.TestCase): - def testRepeatTensorDataset(self): + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + combinations.combine(count=[0, 3, 7]))) + def testFiniteRepeat(self, count): """Test a dataset that repeats its input multiple times.""" components = (np.array(1), np.array([1, 2, 3]), np.array(37.0)) - # This placeholder can be fed when dataset-definition subgraph - # runs (i.e. `init_op` below) to configure the number of - # repetitions used in a particular iterator. + dataset = dataset_ops.Dataset.from_tensors(components).repeat(count) + self.assertEqual( + [c.shape for c in components], + [shape for shape in dataset_ops.get_legacy_output_shapes(dataset)]) + self.assertDatasetProduces(dataset, [components] * count) - def do_test(count): - dataset = dataset_ops.Dataset.from_tensors(components).repeat(count) - self.assertEqual( - [c.shape for c in components], - [shape for shape in dataset_ops.get_legacy_output_shapes(dataset)]) - self.assertDatasetProduces(dataset, [components] * count) - - # Test a finite repetition. - do_test(3) - - # test a different finite repetition. - do_test(7) - - # Test an empty repetition. - do_test(0) - - # Test an infinite repetition. - # NOTE(mrry): There's not a good way to test that the sequence - # actually is infinite. + @combinations.generate(test_base.default_test_combinations()) + def testInfiniteRepeat(self): + # NOTE(mrry): There's not a good way to test that the sequence is infinite. 
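One way to exercise the infinite case without hanging, sketched here for illustration only (not part of this change), is to bound the pipeline with `take` inside the same test fixture:

# Illustrative sketch, reusing the `components` tuple and the test fixture
# defined in the surrounding test file; truncating with take() keeps the
# assertion finite even though repeat(-1) itself never ends.
components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
dataset = dataset_ops.Dataset.from_tensors(components).repeat(-1).take(17)
self.assertDatasetProduces(dataset, [components] * 17)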
+ components = (np.array(1), np.array([1, 2, 3]), np.array(37.0)) dataset = dataset_ops.Dataset.from_tensors(components).repeat(-1) self.assertEqual( [c.shape for c in components], @@ -64,7 +54,8 @@ class RepeatTest(test_base.DatasetTestBase): for component, result_component in zip(components, results): self.assertAllEqual(component, result_component) - def testRepeatRepeatTensorDataset(self): + @combinations.generate(test_base.default_test_combinations()) + def testRepeatRepeat(self): """Test the composition of repeat datasets.""" components = (np.array(1), np.array([1, 2, 3]), np.array(37.0)) inner_count, outer_count = 7, 14 @@ -77,11 +68,6 @@ class RepeatTest(test_base.DatasetTestBase): self.assertDatasetProduces(dataset, [components] * (inner_count * outer_count)) - def testRepeatEmptyDataset(self): - """Test that repeating an empty dataset does not hang.""" - dataset = dataset_ops.Dataset.from_tensors(0).repeat(10).skip(10).repeat(-1) - self.assertDatasetProduces(dataset, []) - if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/kernel_tests/shard_test.py b/tensorflow/python/data/kernel_tests/shard_test.py index 9fc70ff6075..5830b66d61c 100644 --- a/tensorflow/python/data/kernel_tests/shard_test.py +++ b/tensorflow/python/data/kernel_tests/shard_test.py @@ -17,66 +17,79 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized + from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import combinations from tensorflow.python.framework import errors -from tensorflow.python.framework import test_util from tensorflow.python.platform import test -@test_util.run_v1_only("deprecated API, no eager or V2 test coverage") -class ShardTest(test_base.DatasetTestBase): +class ShardTest(test_base.DatasetTestBase, parameterized.TestCase): + @combinations.generate(test_base.default_test_combinations()) def testSimpleCase(self): dataset = dataset_ops.Dataset.range(10).shard(5, 2) self.assertDatasetProduces(dataset, expected_output=[2, 7]) + @combinations.generate(test_base.default_test_combinations()) def testNestedData(self): dataset_a = dataset_ops.Dataset.range(10) dataset_b = dataset_ops.Dataset.range(10, 0, -1) dataset = dataset_ops.Dataset.zip((dataset_a, dataset_b)).shard(5, 2) self.assertDatasetProduces(dataset, expected_output=[(2, 8), (7, 3)]) + @combinations.generate(test_base.default_test_combinations()) def testOffsetZero(self): dataset = dataset_ops.Dataset.range(10).shard(5, 0) self.assertDatasetProduces(dataset, expected_output=[0, 5]) + @combinations.generate(test_base.default_test_combinations()) def testOffsetGreaterNumShards(self): with self.assertRaises(errors.InvalidArgumentError): dataset = dataset_ops.Dataset.range(10).shard(5, 7) self.evaluate(self.getNext(dataset)()) + @combinations.generate(test_base.default_test_combinations()) def testNegativeOffset(self): with self.assertRaises(errors.InvalidArgumentError): dataset = dataset_ops.Dataset.range(10).shard(5, -3) self.evaluate(self.getNext(dataset)()) + @combinations.generate(test_base.default_test_combinations()) def testNegativeNumShards(self): with self.assertRaises(errors.InvalidArgumentError): dataset = dataset_ops.Dataset.range(10).shard(-3, 1) self.evaluate(self.getNext(dataset)()) + @combinations.generate(test_base.default_test_combinations()) def testZeroNumShards(self): with 
self.assertRaises(errors.InvalidArgumentError): dataset = dataset_ops.Dataset.range(10).shard(0, 1) self.evaluate(self.getNext(dataset)()) + @combinations.generate(test_base.default_test_combinations()) def testIteratorEndsBeforeFirstElem(self): dataset = dataset_ops.Dataset.range(1).shard(5, 2) self.assertDatasetProduces(dataset, expected_output=[]) + @combinations.generate(test_base.default_test_combinations()) def testLargerWorkerPool(self): dataset = dataset_ops.Dataset.range(10).shard(7, 5) self.assertDatasetProduces(dataset, expected_output=[5]) + @combinations.generate(test_base.default_test_combinations()) def testIndexEqualsNumShards(self): dataset = dataset_ops.Dataset.range(10).shard(5, 4) self.assertDatasetProduces(dataset, expected_output=[4, 9]) + @combinations.generate(test_base.default_test_combinations()) def testIndexEqualsNumShards2(self): dataset = dataset_ops.Dataset.range(10).shard(4, 3) self.assertDatasetProduces(dataset, expected_output=[3, 7]) + @combinations.generate(test_base.default_test_combinations()) def testNumShardsLargerThanDataset(self): dataset = dataset_ops.Dataset.range(10).shard(20, 5) self.assertDatasetProduces(dataset, expected_output=[5]) diff --git a/tensorflow/python/data/kernel_tests/shuffle_test.py b/tensorflow/python/data/kernel_tests/shuffle_test.py index 7f801e1b5f4..c9d17b79016 100644 --- a/tensorflow/python/data/kernel_tests/shuffle_test.py +++ b/tensorflow/python/data/kernel_tests/shuffle_test.py @@ -40,7 +40,7 @@ from tensorflow.python.platform import test class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate(test_base.default_test_combinations()) - def testShuffleDataset(self): + def testBasic(self): components = ( np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]), np.array([9.0, 10.0, 11.0, 12.0]) @@ -160,7 +160,7 @@ class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate( combinations.times( - combinations.combine(tf_api_version=[1, 2], mode="graph"), + test_base.graph_only_combinations(), combinations.combine(reshuffle=[True, False]), combinations.combine(graph_seed=38, op_seed=None) + combinations.combine(graph_seed=None, op_seed=42) + @@ -188,7 +188,7 @@ class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase): # TODO(b/117581999): enable this test for eager-mode. 
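The `test_base.graph_only_combinations()` and `test_base.eager_only_combinations()` helpers substituted in these shuffle_test hunks are, judging from the expressions they replace, roughly equivalent to the following sketch (the real definitions live in tensorflow/python/data/kernel_tests/test_base.py):

from tensorflow.python.framework import combinations


# Rough equivalents inferred from the replaced expressions in these hunks.
def graph_only_combinations():
  # Both TF API versions, graph mode only.
  return combinations.combine(tf_api_version=[1, 2], mode="graph")


def eager_only_combinations():
  # Both TF API versions, eager mode only.
  return combinations.combine(tf_api_version=[1, 2], mode="eager")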
@combinations.generate( combinations.times( - combinations.combine(tf_api_version=[1, 2], mode="graph"), + test_base.graph_only_combinations(), combinations.combine( reshuffle=[True, False], initializable=[True, False]))) def testMultipleIterators(self, reshuffle, initializable): @@ -278,7 +278,7 @@ class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate( combinations.times( - combinations.combine(tf_api_version=[1, 2], mode="eager"), + test_base.eager_only_combinations(), combinations.combine(reshuffle=[True, False], seed=[None, 42]))) def testReshuffleSeparateTransformations(self, reshuffle, seed): dataset = dataset_ops.Dataset.range(10) diff --git a/tensorflow/python/data/kernel_tests/skip_test.py b/tensorflow/python/data/kernel_tests/skip_test.py index 74dc8b7f55c..176893d90d2 100644 --- a/tensorflow/python/data/kernel_tests/skip_test.py +++ b/tensorflow/python/data/kernel_tests/skip_test.py @@ -17,46 +17,30 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import test_util +from tensorflow.python.framework import combinations from tensorflow.python.platform import test -@test_util.run_all_in_graph_and_eager_modes -class SkipTest(test_base.DatasetTestBase): +class SkipTest(test_base.DatasetTestBase, parameterized.TestCase): - def testSkipTensorDataset(self): + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + combinations.combine(count=[-1, 0, 4, 10, 25]))) + def testBasic(self, count): components = (np.arange(10),) - - def do_test(count): - dataset = dataset_ops.Dataset.from_tensor_slices(components).skip(count) - self.assertEqual( - [c.shape[1:] for c in components], - [shape for shape in dataset_ops.get_legacy_output_shapes(dataset)]) - start_range = min(count, 10) if count != -1 else 10 - self.assertDatasetProduces( - dataset, - [tuple(components[0][i:i + 1]) for i in range(start_range, 10)]) - - # Skip fewer than input size, we should skip - # the first 4 elements and then read the rest. - do_test(4) - - # Skip more than input size: get nothing. - do_test(25) - - # Skip exactly input size. - do_test(10) - - # Set -1 for 'count': skip the entire dataset. 
- do_test(-1) - - # Skip nothing - do_test(0) - + dataset = dataset_ops.Dataset.from_tensor_slices(components).skip(count) + self.assertEqual( + [c.shape[1:] for c in components], + [shape for shape in dataset_ops.get_legacy_output_shapes(dataset)]) + start_range = min(count, 10) if count != -1 else 10 + self.assertDatasetProduces( + dataset, + [tuple(components[0][i:i + 1]) for i in range(start_range, 10)]) if __name__ == "__main__": diff --git a/tensorflow/python/data/kernel_tests/take_test.py b/tensorflow/python/data/kernel_tests/take_test.py index 665ed59a7bc..14796551e16 100644 --- a/tensorflow/python/data/kernel_tests/take_test.py +++ b/tensorflow/python/data/kernel_tests/take_test.py @@ -17,40 +17,30 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import test_util +from tensorflow.python.framework import combinations from tensorflow.python.platform import test -@test_util.run_all_in_graph_and_eager_modes -class TakeTest(test_base.DatasetTestBase): +class TakeTest(test_base.DatasetTestBase, parameterized.TestCase): - def testTakeTensorDataset(self): + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + combinations.combine(count=[-1, 0, 4, 10, 25]))) + def testBasic(self, count): components = (np.arange(10),) + dataset = dataset_ops.Dataset.from_tensor_slices(components).take(count) + self.assertEqual( + [c.shape[1:] for c in components], + [shape for shape in dataset_ops.get_legacy_output_shapes(dataset)]) + num_output = min(count, 10) if count != -1 else 10 + self.assertDatasetProduces( + dataset, [tuple(components[0][i:i + 1]) for i in range(num_output)]) - def do_test(count): - dataset = dataset_ops.Dataset.from_tensor_slices(components).take(count) - self.assertEqual( - [c.shape[1:] for c in components], - [shape for shape in dataset_ops.get_legacy_output_shapes(dataset)]) - num_output = min(count, 10) if count != -1 else 10 - self.assertDatasetProduces( - dataset, [tuple(components[0][i:i + 1]) for i in range(num_output)]) - - # Take fewer than input size - do_test(4) - - # Take more than input size - do_test(25) - - # Take all of input - do_test(-1) - - # Take nothing - do_test(0) if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/kernel_tests/test_base.py b/tensorflow/python/data/kernel_tests/test_base.py index 6dfee4cc0f7..60796b178bf 100644 --- a/tensorflow/python/data/kernel_tests/test_base.py +++ b/tensorflow/python/data/kernel_tests/test_base.py @@ -58,7 +58,11 @@ class DatasetTestBase(test.TestCase): def assertValuesEqual(self, expected, actual): """Asserts that two values are equal.""" - if sparse_tensor.is_sparse(expected): + if isinstance(expected, dict): + self.assertItemsEqual(list(expected.keys()), list(actual.keys())) + for k in expected.keys(): + self.assertValuesEqual(expected[k], actual[k]) + elif sparse_tensor.is_sparse(expected): self.assertAllEqual(expected.indices, actual.indices) self.assertAllEqual(expected.values, actual.values) self.assertAllEqual(expected.dense_shape, actual.dense_shape) diff --git a/tensorflow/python/data/kernel_tests/text_line_dataset_test.py b/tensorflow/python/data/kernel_tests/text_line_dataset_test.py index c62d4ec8270..35b479faa21 100644 --- a/tensorflow/python/data/kernel_tests/text_line_dataset_test.py +++ 
b/tensorflow/python/data/kernel_tests/text_line_dataset_test.py @@ -21,11 +21,12 @@ import gzip import os import zlib +from absl.testing import parameterized + from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers -from tensorflow.python.eager import context -from tensorflow.python.framework import test_util +from tensorflow.python.framework import combinations from tensorflow.python.platform import test from tensorflow.python.util import compat @@ -37,8 +38,7 @@ except ImportError: psutil_import_succeeded = False -@test_util.run_all_in_graph_and_eager_modes -class TextLineDatasetTest(test_base.DatasetTestBase): +class TextLineDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): def _lineText(self, f, l): return compat.as_bytes("%d: %d" % (f, l)) @@ -76,7 +76,11 @@ class TextLineDatasetTest(test_base.DatasetTestBase): return filenames - def _testTextLineDataset(self, compression_type=None): + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine(compression_type=[None, "GZIP", "ZLIB"]))) + def testTextLineDataset(self, compression_type): test_filenames = self._createFiles( 2, 5, crlf=True, compression_type=compression_type) @@ -115,6 +119,7 @@ class TextLineDatasetTest(test_base.DatasetTestBase): expected_output=[[self._lineText(0, i) for i in range(5)], [self._lineText(1, i) for i in range(5)]] * 10) + @combinations.generate(test_base.default_test_combinations()) def testTextLineDatasetParallelRead(self): test_filenames = self._createFiles(10, 10) files = dataset_ops.Dataset.from_tensor_slices(test_filenames).repeat(10) @@ -125,15 +130,7 @@ class TextLineDatasetTest(test_base.DatasetTestBase): self.assertDatasetProduces( dataset, expected_output=expected_output * 10, assert_items_equal=True) - def testTextLineDatasetNoCompression(self): - self._testTextLineDataset() - - def testTextLineDatasetGzipCompression(self): - self._testTextLineDataset(compression_type="GZIP") - - def testTextLineDatasetZlibCompression(self): - self._testTextLineDataset(compression_type="ZLIB") - + @combinations.generate(test_base.default_test_combinations()) def testTextLineDatasetBuffering(self): test_filenames = self._createFiles(2, 5, crlf=True) @@ -143,33 +140,33 @@ class TextLineDatasetTest(test_base.DatasetTestBase): expected_output.extend([self._lineText(j, i) for i in range(5)]) self.assertDatasetProduces(repeat_dataset, expected_output=expected_output) + @combinations.generate(test_base.eager_only_combinations()) def testIteratorResourceCleanup(self): filename = os.path.join(self.get_temp_dir(), "text.txt") with open(filename, "wt") as f: for i in range(3): f.write("%d\n" % (i,)) - with context.eager_mode(): - first_iterator = iter(readers.TextLineDataset(filename)) - self.assertEqual(b"0", next(first_iterator).numpy()) - second_iterator = iter(readers.TextLineDataset(filename)) - self.assertEqual(b"0", next(second_iterator).numpy()) - # Eager kernel caching is based on op attributes, which includes the - # Dataset's output shape. Create a different kernel to test that they - # don't create resources with the same names. 
- different_kernel_iterator = iter( - readers.TextLineDataset(filename).repeat().batch(16)) - self.assertEqual([16], next(different_kernel_iterator).shape) - # Remove our references to the Python Iterator objects, which (assuming no - # reference cycles) is enough to trigger DestroyResourceOp and close the - # partially-read files. - del first_iterator - del second_iterator - del different_kernel_iterator - if not psutil_import_succeeded: - self.skipTest( - "psutil is required to check that we've closed our files.") - open_files = psutil.Process().open_files() - self.assertNotIn(filename, [open_file.path for open_file in open_files]) + first_iterator = iter(readers.TextLineDataset(filename)) + self.assertEqual(b"0", next(first_iterator).numpy()) + second_iterator = iter(readers.TextLineDataset(filename)) + self.assertEqual(b"0", next(second_iterator).numpy()) + # Eager kernel caching is based on op attributes, which includes the + # Dataset's output shape. Create a different kernel to test that they + # don't create resources with the same names. + different_kernel_iterator = iter( + readers.TextLineDataset(filename).repeat().batch(16)) + self.assertEqual([16], next(different_kernel_iterator).shape) + # Remove our references to the Python Iterator objects, which (assuming no + # reference cycles) is enough to trigger DestroyResourceOp and close the + # partially-read files. + del first_iterator + del second_iterator + del different_kernel_iterator + if not psutil_import_succeeded: + self.skipTest( + "psutil is required to check that we've closed our files.") + open_files = psutil.Process().open_files() + self.assertNotIn(filename, [open_file.path for open_file in open_files]) if __name__ == "__main__": diff --git a/tensorflow/python/data/kernel_tests/tf_record_dataset_test.py b/tensorflow/python/data/kernel_tests/tf_record_dataset_test.py index 5cf8308a55f..792c4926640 100644 --- a/tensorflow/python/data/kernel_tests/tf_record_dataset_test.py +++ b/tensorflow/python/data/kernel_tests/tf_record_dataset_test.py @@ -21,31 +21,31 @@ import gzip import os import zlib +from absl.testing import parameterized + from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers +from tensorflow.python.framework import combinations from tensorflow.python.framework import constant_op -from tensorflow.python.framework import test_util from tensorflow.python.lib.io import python_io from tensorflow.python.platform import test from tensorflow.python.util import compat -@test_util.run_all_in_graph_and_eager_modes -class TFRecordDatasetTest(test_base.DatasetTestBase): +class TFRecordDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): def setUp(self): super(TFRecordDatasetTest, self).setUp() self._num_files = 2 self._num_records = 7 - self.test_filenames = self._createFiles() - def dataset_fn(self, - filenames, - compression_type="", - num_epochs=1, - batch_size=None): + def _dataset_factory(self, + filenames, + compression_type="", + num_epochs=1, + batch_size=None): repeat_dataset = readers.TFRecordDataset( filenames, compression_type).repeat(num_epochs) @@ -67,6 +67,7 @@ class TFRecordDatasetTest(test_base.DatasetTestBase): writer.close() return filenames + @combinations.generate(test_base.default_test_combinations()) def testTFRecordDatasetConstructorErrorsTensorInput(self): with self.assertRaisesRegex(TypeError, "filenames.*must be.*Tensor.*string"): @@ -78,37 +79,40 @@ class 
TFRecordDatasetTest(test_base.DatasetTestBase): with self.assertRaises(Exception): readers.TFRecordDataset(object()) + @combinations.generate(test_base.default_test_combinations()) def testReadOneEpoch(self): # Basic test: read from file 0. - dataset = self.dataset_fn(self.test_filenames[0]) + dataset = self._dataset_factory(self.test_filenames[0]) self.assertDatasetProduces( dataset, expected_output=[self._record(0, i) for i in range(self._num_records)]) # Basic test: read from file 1. - dataset = self.dataset_fn(self.test_filenames[1]) + dataset = self._dataset_factory(self.test_filenames[1]) self.assertDatasetProduces( dataset, expected_output=[self._record(1, i) for i in range(self._num_records)]) # Basic test: read from both files. - dataset = self.dataset_fn(self.test_filenames) + dataset = self._dataset_factory(self.test_filenames) expected_output = [] for j in range(self._num_files): expected_output.extend( [self._record(j, i) for i in range(self._num_records)]) self.assertDatasetProduces(dataset, expected_output=expected_output) + @combinations.generate(test_base.default_test_combinations()) def testReadTenEpochs(self): - dataset = self.dataset_fn(self.test_filenames, num_epochs=10) + dataset = self._dataset_factory(self.test_filenames, num_epochs=10) expected_output = [] for j in range(self._num_files): expected_output.extend( [self._record(j, i) for i in range(self._num_records)]) self.assertDatasetProduces(dataset, expected_output=expected_output * 10) + @combinations.generate(test_base.default_test_combinations()) def testReadTenEpochsOfBatches(self): - dataset = self.dataset_fn( + dataset = self._dataset_factory( self.test_filenames, num_epochs=10, batch_size=self._num_records) expected_output = [] for j in range(self._num_files): @@ -116,6 +120,7 @@ class TFRecordDatasetTest(test_base.DatasetTestBase): [self._record(j, i) for i in range(self._num_records)]) self.assertDatasetProduces(dataset, expected_output=expected_output * 10) + @combinations.generate(test_base.default_test_combinations()) def testReadZlibFiles(self): zlib_files = [] for i, fn in enumerate(self.test_filenames): @@ -130,9 +135,10 @@ class TFRecordDatasetTest(test_base.DatasetTestBase): for j in range(self._num_files): expected_output.extend( [self._record(j, i) for i in range(self._num_records)]) - dataset = self.dataset_fn(zlib_files, compression_type="ZLIB") + dataset = self._dataset_factory(zlib_files, compression_type="ZLIB") self.assertDatasetProduces(dataset, expected_output=expected_output) + @combinations.generate(test_base.default_test_combinations()) def testReadGzipFiles(self): gzip_files = [] for i, fn in enumerate(self.test_filenames): @@ -145,9 +151,10 @@ class TFRecordDatasetTest(test_base.DatasetTestBase): for j in range(self._num_files): expected_output.extend( [self._record(j, i) for i in range(self._num_records)]) - dataset = self.dataset_fn(gzip_files, compression_type="GZIP") + dataset = self._dataset_factory(gzip_files, compression_type="GZIP") self.assertDatasetProduces(dataset, expected_output=expected_output) + @combinations.generate(test_base.default_test_combinations()) def testReadWithBuffer(self): one_mebibyte = 2**20 dataset = readers.TFRecordDataset( @@ -158,6 +165,7 @@ class TFRecordDatasetTest(test_base.DatasetTestBase): [self._record(j, i) for i in range(self._num_records)]) self.assertDatasetProduces(dataset, expected_output=expected_output) + @combinations.generate(test_base.default_test_combinations()) def testReadFromDatasetOfFiles(self): files = 
dataset_ops.Dataset.from_tensor_slices(self.test_filenames) expected_output = [] @@ -167,6 +175,7 @@ class TFRecordDatasetTest(test_base.DatasetTestBase): dataset = readers.TFRecordDataset(files) self.assertDatasetProduces(dataset, expected_output=expected_output) + @combinations.generate(test_base.default_test_combinations()) def testReadTenEpochsFromDatasetOfFilesInParallel(self): files = dataset_ops.Dataset.from_tensor_slices( self.test_filenames).repeat(10) diff --git a/tensorflow/python/data/kernel_tests/unbatch_test.py b/tensorflow/python/data/kernel_tests/unbatch_test.py index 5bb4852d534..44d949385b0 100644 --- a/tensorflow/python/data/kernel_tests/unbatch_test.py +++ b/tensorflow/python/data/kernel_tests/unbatch_test.py @@ -23,11 +23,11 @@ import numpy as np from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import combinations from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import sparse_tensor -from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import string_ops @@ -36,13 +36,14 @@ from tensorflow.python.platform import test from tensorflow.python.util import compat -@test_util.run_all_in_graph_and_eager_modes class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase): + @combinations.generate(test_base.default_test_combinations()) def testUnbatchWithUnknownRankInput(self): dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3]).unbatch() self.assertDatasetProduces(dataset, range(4)) + @combinations.generate(test_base.default_test_combinations()) def testUnbatchScalarDataset(self): data = tuple([math_ops.range(10) for _ in range(3)]) data = dataset_ops.Dataset.from_tensor_slices(data) @@ -54,12 +55,14 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertDatasetProduces(data, [(i,) * 3 for i in range(10)]) + @combinations.generate(test_base.default_test_combinations()) def testUnbatchNestedDataset(self): data = dataset_ops.Dataset.from_tensors( [dataset_ops.Dataset.range(10) for _ in range(10)]) data = data.unbatch().flat_map(lambda x: x) self.assertDatasetProduces(data, list(range(10)) * 10) + @combinations.generate(test_base.default_test_combinations()) def testUnbatchDatasetWithStrings(self): data = tuple([math_ops.range(10) for _ in range(3)]) data = dataset_ops.Dataset.from_tensor_slices(data) @@ -73,6 +76,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertDatasetProduces( data, [(i, compat.as_bytes(str(i)), i) for i in range(10)]) + @combinations.generate(test_base.default_test_combinations()) def testUnbatchDatasetWithSparseTensor(self): st = sparse_tensor.SparseTensorValue( indices=[[i, i] for i in range(10)], @@ -87,6 +91,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase): ] self.assertDatasetProduces(data, expected_output=expected_output) + @combinations.generate(test_base.default_test_combinations()) def testUnbatchDatasetWithDenseSparseAndRaggedTensor(self): st = sparse_tensor.SparseTensorValue( indices=[[i, i] for i in range(10)], @@ -104,6 +109,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertDatasetProduces( data, expected_output=expected_output) + 
@combinations.generate(test_base.default_test_combinations()) def testUnbatchDatasetWithRaggedTensor(self): rt = ragged_factory_ops.constant_value([[[0]], [[1]], [[2]], [[3]], [[4]], [[5]], [[6]], [[7]], [[8]], [[9]]]) @@ -119,6 +125,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertDatasetProduces( data, expected_output=expected_output) + @combinations.generate(test_base.default_test_combinations()) def testUnbatchSingleElementTupleDataset(self): data = tuple([(math_ops.range(10),) for _ in range(3)]) data = dataset_ops.Dataset.from_tensor_slices(data) @@ -130,6 +137,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertDatasetProduces(data, [((i,),) * 3 for i in range(10)]) + @combinations.generate(test_base.default_test_combinations()) def testUnbatchMultiElementTupleDataset(self): data = tuple([(math_ops.range(10 * i, 10 * i + 10), array_ops.fill([10], "hi")) for i in range(3)]) @@ -146,6 +154,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase): data, [((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")) for i in range(10)]) + @combinations.generate(test_base.default_test_combinations()) def testUnbatchEmpty(self): data = dataset_ops.Dataset.from_tensors( (constant_op.constant([]), constant_op.constant([], shape=[0, 4]), @@ -153,15 +162,15 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase): data = data.unbatch() self.assertDatasetProduces(data, []) + @combinations.generate(test_base.default_test_combinations()) def testUnbatchStaticShapeMismatch(self): data = dataset_ops.Dataset.from_tensors((np.arange(7), np.arange(8), np.arange(9))) with self.assertRaises(ValueError): data.unbatch() - # Note: dynamic shape mismatch is graph specific test. - @test_util.run_deprecated_v1 - def testSkipEagerUnbatchDynamicShapeMismatch(self): + @combinations.generate(test_base.graph_only_combinations()) + def testUnbatchDynamicShapeMismatch(self): ph1 = array_ops.placeholder(dtypes.int32, shape=[None]) ph2 = array_ops.placeholder(dtypes.int32, shape=None) data = dataset_ops.Dataset.from_tensors((ph1, ph2)) @@ -190,6 +199,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(errors.InvalidArgumentError): self.evaluate(next_element) + @combinations.generate(test_base.default_test_combinations()) def testUnbatchDatasetWithUintDtypes(self): components = ( np.tile(np.array([[0], [1], [2], [3]], dtype=np.uint8), 2), diff --git a/tensorflow/python/data/kernel_tests/window_test.py b/tensorflow/python/data/kernel_tests/window_test.py index 122e874f0a0..98b453a5900 100644 --- a/tensorflow/python/data/kernel_tests/window_test.py +++ b/tensorflow/python/data/kernel_tests/window_test.py @@ -24,43 +24,32 @@ from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.eager import context +from tensorflow.python.framework import combinations from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import sparse_tensor -from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -@test_util.run_all_in_graph_and_eager_modes class WindowTest(test_base.DatasetTestBase, parameterized.TestCase): - @parameterized.named_parameters( - ("1", 20, 14, 7, 1), - ("2", 20, 17, 9, 1), - 
("3", 20, 14, 14, 1), - ("4", 20, 10, 14, 1), - ("5", 20, 14, 19, 1), - ("6", 20, 4, 1, 2), - ("7", 20, 2, 1, 6), - ("8", 20, 4, 7, 2), - ("9", 20, 2, 7, 6), - ("10", 1, 10, 4, 1), - ("11", 0, 10, 4, 1), - ("12", 20, 14, 7, 1, False), - ("13", 20, 17, 9, 1, False), - ("14", 20, 14, 14, 1, False), - ("15", 20, 10, 14, 1, False), - ("16", 20, 14, 19, 1, False), - ("17", 20, 4, 1, 2, False), - ("18", 20, 2, 1, 6, False), - ("19", 20, 4, 7, 2, False), - ("20", 20, 2, 7, 6, False), - ("21", 1, 10, 4, 1, False), - ("22", 0, 10, 4, 1, False), - ) - def testWindowDataset(self, count, size, shift, stride, drop_remainder=True): + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine( + count=20, + size=[10, 14, 17], + shift=[7, 14], + stride=[1, 2, 6], + drop_remainder=[True, False]) + combinations.combine( + count=[0, 1], + size=10, + shift=4, + stride=1, + drop_remainder=[True, False]))) + def testWindowDataset(self, count, size, shift, stride, drop_remainder): """Tests a dataset that slides a window its input elements.""" components = (np.arange(7), np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], @@ -111,11 +100,12 @@ class WindowTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) - @parameterized.named_parameters( - ("1", 14, 0, 3, 1), - ("2", 14, 3, 0, 1), - ("3", 14, 3, 3, 0), - ) + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine(count=20, size=0, shift=3, stride=1) + + combinations.combine(count=20, size=3, shift=0, stride=1) + + combinations.combine(count=20, size=3, shift=3, stride=0))) def testWindowDatasetInvalid(self, count, size, shift, stride): with self.assertRaises(errors.InvalidArgumentError): ds = dataset_ops.Dataset.range(10).map(lambda x: x).repeat(count).window( @@ -123,12 +113,14 @@ class WindowTest(test_base.DatasetTestBase, parameterized.TestCase): stride=stride).flat_map(lambda x: x.batch(batch_size=size)) self.evaluate(ds._variant_tensor) + @combinations.generate(test_base.default_test_combinations()) def testWindowDifferentNestedStructures(self): ds = dataset_ops.Dataset.from_tensor_slices(([1, 2], [3, 4])).window(2) self.getNext(ds) ds = dataset_ops.Dataset.from_tensor_slices({"a": [1, 2]}).window(2) self.getNext(ds) + @combinations.generate(test_base.default_test_combinations()) def testWindowSparse(self): def _sparse(i): @@ -148,6 +140,7 @@ class WindowTest(test_base.DatasetTestBase, parameterized.TestCase): ] self.assertDatasetProduces(dataset, expected_output=expected_output) + @combinations.generate(test_base.default_test_combinations()) def testWindowSparseWithDifferentDenseShapes(self): def _sparse(i): @@ -177,6 +170,7 @@ class WindowTest(test_base.DatasetTestBase, parameterized.TestCase): dense_shape=[5, i * 3 + 5 - 1])) self.assertDatasetProduces(dataset, expected_output=expected_output) + @combinations.generate(test_base.default_test_combinations()) def testNestedWindowSparse(self): def _sparse(i): @@ -205,6 +199,7 @@ class WindowTest(test_base.DatasetTestBase, parameterized.TestCase): ] self.assertDatasetProduces(dataset, expected_output=expected_output) + @combinations.generate(test_base.default_test_combinations()) def testWindowShapeError(self): def generator(): @@ -222,6 +217,7 @@ class WindowTest(test_base.DatasetTestBase, parameterized.TestCase): r"Cannot batch tensors with different shapes in component 0. 
" r"First element had shape \[3\] and element 2 had shape \[4\].")) + @combinations.generate(test_base.default_test_combinations()) def testWindowIgnoreErrors(self): input_values = np.float32([1., np.nan, 2., np.nan, 3.]) dataset = dataset_ops.Dataset.from_tensor_slices(input_values).map( @@ -232,6 +228,7 @@ class WindowTest(test_base.DatasetTestBase, parameterized.TestCase): dataset, expected_output=[np.float32([1., 2.]), np.float32([2., 3.])]) + @combinations.generate(test_base.default_test_combinations()) def testNestedOutput(self): if not context.executing_eagerly(): self.skipTest("self.evaluate() does not work with a dataset") diff --git a/tensorflow/python/data/kernel_tests/zip_test.py b/tensorflow/python/data/kernel_tests/zip_test.py index 72f739e4e4e..c63091754c3 100644 --- a/tensorflow/python/data/kernel_tests/zip_test.py +++ b/tensorflow/python/data/kernel_tests/zip_test.py @@ -17,66 +17,68 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import combinations from tensorflow.python.framework import errors from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import test_util from tensorflow.python.platform import test -@test_util.run_all_in_graph_and_eager_modes -class ZipTest(test_base.DatasetTestBase): +def _dataset_factory(components): + datasets = tuple([ + dataset_ops.Dataset.from_tensor_slices(component) + for component in components + ]) + return dataset_ops.Dataset.zip(datasets) - def testZipDataset(self): - def dataset_fn(components): - datasets = tuple([ - dataset_ops.Dataset.from_tensor_slices(component) - for component in components - ]) - return dataset_ops.Dataset.zip(datasets) +class ZipTest(test_base.DatasetTestBase, parameterized.TestCase): - equal_length_components = [ + @combinations.generate(test_base.default_test_combinations()) + def testZipEqual(self): + components = [ np.tile(np.array([[1], [2], [3], [4]]), 20), np.tile(np.array([[12], [13], [14], [15]]), 22), np.array([37.0, 38.0, 39.0, 40.0]) ] - - get_next = self.getNext(dataset_fn(equal_length_components)) + get_next = self.getNext(_dataset_factory(components)) for i in range(4): results = self.evaluate(get_next()) - for component, result_component in zip(equal_length_components, results): + for component, result_component in zip(components, results): self.assertAllEqual(component[i], result_component) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) - variable_length_components = [[1, 2, 3, 4], [1, 2, 3, 4, 5], [1.0, 2.0]] - get_next = self.getNext(dataset_fn(variable_length_components)) + @combinations.generate(test_base.default_test_combinations()) + def testZipUnequal(self): + components = [[1, 2, 3, 4], [1, 2, 3, 4, 5], [1.0, 2.0]] + get_next = self.getNext(_dataset_factory(components)) for i in range(2): results = self.evaluate(get_next()) - for component, result_component in zip(variable_length_components, - results): + for component, result_component in zip(components, results): self.assertAllEqual(component[i], result_component) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) - def 
testNestedZipDataset(self): + @combinations.generate(test_base.default_test_combinations()) + def testNested(self): - equal_length_components = [ + components = [ np.tile(np.array([[1], [2], [3], [4]]), 20), np.tile(np.array([[12], [13], [14], [15]]), 22), np.array([37.0, 38.0, 39.0, 40.0]) ] datasets = [ dataset_ops.Dataset.from_tensor_slices(component) - for component in equal_length_components + for component in components ] dataset = dataset_ops.Dataset.zip((datasets[0], (datasets[1], datasets[2]))) @@ -88,9 +90,9 @@ class ZipTest(test_base.DatasetTestBase): get_next = self.getNext(dataset) for i in range(4): result1, (result2, result3) = self.evaluate(get_next()) - self.assertAllEqual(equal_length_components[0][i], result1) - self.assertAllEqual(equal_length_components[1][i], result2) - self.assertAllEqual(equal_length_components[2][i], result3) + self.assertAllEqual(components[0][i], result1) + self.assertAllEqual(components[1][i], result2) + self.assertAllEqual(components[2][i], result3) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) with self.assertRaises(errors.OutOfRangeError): diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index f67dec9d720..f3367023a7b 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -66,7 +66,6 @@ from tensorflow.python.ops import gen_io_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import script_ops from tensorflow.python.ops import string_ops -from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training.tracking import base as tracking_base from tensorflow.python.training.tracking import tracking from tensorflow.python.util import deprecation @@ -2437,26 +2436,25 @@ class DatasetV1Adapter(DatasetV1): def _ensure_same_dataset_graph(dataset): """Walks the dataset graph to ensure all datasets come from the same graph.""" + # pylint: disable=protected-access current_graph = ops.get_default_graph() bfs_q = Queue.Queue() - bfs_q.put(dataset) # pylint: disable=protected-access + bfs_q.put(dataset) visited = [] while not bfs_q.empty(): ds = bfs_q.get() visited.append(ds) - ds_graph = ds._graph # pylint: disable=protected-access + ds_graph = ds._graph if current_graph != ds_graph: - logging.warning("The graph (" + str(current_graph) + ") of the iterator " - "is different from the graph (" + str(ds_graph) + ") " - "the dataset: " + str(ds._variant_tensor) + " was " # pylint: disable=protected-access - "created in. If you are using the Estimator API, " - "make sure that no part of the dataset returned by the " - "`input_fn` function is defined outside the `input_fn` " - "function. Please ensure that all datasets in the " - "pipeline are created in the same graph as the iterator. " - "NOTE: This warning will become an error in future " - "versions of TensorFlow.") - for input_ds in ds._inputs(): # pylint: disable=protected-access + raise ValueError( + "The graph (" + str(current_graph) + ") of the iterator is different " + "from the graph (" + str(ds_graph) + ") the dataset: " + + str(ds._variant_tensor) + " was created in. If you are using the " + "Estimator API, make sure that no part of the dataset returned by " + "the `input_fn` function is defined outside the `input_fn` function. 
" + "Please ensure that all datasets in the pipeline are created in the " + "same graph as the iterator.") + for input_ds in ds._inputs(): if input_ds not in visited: bfs_q.put(input_ds) diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD index 1c30328e7dd..2bc35ef52af 100644 --- a/tensorflow/python/debug/BUILD +++ b/tensorflow/python/debug/BUILD @@ -654,6 +654,7 @@ py_test( deps = [ ":debug_events_reader", ":debug_events_writer", + ":dumping_callback_test_lib", "//tensorflow/core:protos_all_py", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_test_lib", @@ -725,6 +726,7 @@ cuda_py_test( "//tensorflow/python/keras", ], python_version = "PY3", + shard_count = 8, tags = [ "guitar", "multi_and_single_gpu", @@ -766,6 +768,7 @@ cuda_py_test( additional_deps = [ ":debug_events_reader", ":debug_events_writer", + ":dumping_callback_test_lib", "//third_party/py/numpy", "//tensorflow/python:debug_ops_gen", "//tensorflow/python:framework_test_lib", diff --git a/tensorflow/python/debug/lib/debug_events_reader.py b/tensorflow/python/debug/lib/debug_events_reader.py index 2a9f331439b..a20cc175ebb 100644 --- a/tensorflow/python/debug/lib/debug_events_reader.py +++ b/tensorflow/python/debug/lib/debug_events_reader.py @@ -55,6 +55,13 @@ class DebugEventsReader(object): self._readers = dict() # A map from file path to reader. self._readers_lock = threading.Lock() + def __enter__(self): + return self + + def __exit__(self, exception_type, exception_value, traceback): + del exception_type, exception_value, traceback # Unused + self.close() + def _generic_iterator(self, file_path): """A helper method that makes an iterator given a debug-events file path.""" # The following code uses the double-checked locking pattern to optimize @@ -93,3 +100,7 @@ class DebugEventsReader(object): def graph_execution_traces_iterator(self): return self._generic_iterator(self._graph_execution_traces_path) + + def close(self): + for reader in self._readers.values(): + reader.Close() diff --git a/tensorflow/python/debug/lib/debug_events_writer_test.py b/tensorflow/python/debug/lib/debug_events_writer_test.py index 5c85ec6dcdc..86e7fd26e1a 100644 --- a/tensorflow/python/debug/lib/debug_events_writer_test.py +++ b/tensorflow/python/debug/lib/debug_events_writer_test.py @@ -20,31 +20,19 @@ from __future__ import print_function import glob import os -import tempfile import threading from tensorflow.core.protobuf import debug_event_pb2 from tensorflow.python.debug.lib import debug_events_reader from tensorflow.python.debug.lib import debug_events_writer +from tensorflow.python.debug.lib import dumping_callback_test_lib from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util -from tensorflow.python.lib.io import file_io from tensorflow.python.platform import googletest -class PywrapeventsWriterTest(test_util.TensorFlowTestCase): - - def setUp(self): - super(PywrapeventsWriterTest, self).setUp() - self.dump_root = tempfile.mkdtemp() - - def tearDown(self): - if os.path.isdir(self.dump_root): - file_io.delete_recursively(self.dump_root) - super(PywrapeventsWriterTest, self).tearDown() +class DebugEventsWriterTest(dumping_callback_test_lib.DumpingCallbackTestBase): def testMultiThreadedConstructorCallWorks(self): - def InitWriter(): debug_events_writer.DebugEventsWriter(self.dump_root) @@ -68,14 +56,7 @@ class PywrapeventsWriterTest(test_util.TensorFlowTestCase): self.assertEqual(len(stack_frames_paths), 1) graphs_paths = 
glob.glob(os.path.join(self.dump_root, "*.graphs")) self.assertEqual(len(graphs_paths), 1) - - # Verify the content of the metadata file. - reader = debug_events_reader.DebugEventsReader(self.dump_root) - metadata_iter = reader.metadata_iterator() - debug_event = next(metadata_iter) - self.assertTrue(debug_event.debug_metadata.tensorflow_version) - self.assertTrue( - debug_event.debug_metadata.file_version.startswith("debug.Event:")) + self._readAndCheckMetadataFile() def testWriteSourceFilesAndStackFrames(self): writer = debug_events_writer.DebugEventsWriter(self.dump_root) @@ -94,21 +75,21 @@ class PywrapeventsWriterTest(test_util.TensorFlowTestCase): writer.FlushNonExecutionFiles() - reader = debug_events_reader.DebugEventsReader(self.dump_root) - actuals = list(reader.source_files_iterator()) - self.assertLen(actuals, num_protos) - for i in range(num_protos): - self.assertEqual(actuals[i].source_file.file_path, - "/home/tf2user/main.py") - self.assertEqual(actuals[i].source_file.host_name, "machine.cluster") - self.assertEqual(actuals[i].source_file.lines, ["print(%d)" % i]) + with debug_events_reader.DebugEventsReader(self.dump_root) as reader: + actuals = list(reader.source_files_iterator()) + self.assertLen(actuals, num_protos) + for i in range(num_protos): + self.assertEqual(actuals[i].source_file.file_path, + "/home/tf2user/main.py") + self.assertEqual(actuals[i].source_file.host_name, "machine.cluster") + self.assertEqual(actuals[i].source_file.lines, ["print(%d)" % i]) - actuals = list(reader.stack_frames_iterator()) - self.assertLen(actuals, num_protos) - for i in range(num_protos): - self.assertEqual(actuals[i].stack_frame_with_id.id, "stack_%d" % i) - self.assertEqual(actuals[i].stack_frame_with_id.file_line_col.file_index, - i * 10) + actuals = list(reader.stack_frames_iterator()) + self.assertLen(actuals, num_protos) + for i in range(num_protos): + self.assertEqual(actuals[i].stack_frame_with_id.id, "stack_%d" % i) + self.assertEqual( + actuals[i].stack_frame_with_id.file_line_col.file_index, i * 10) def testWriteGraphOpCreationAndDebuggedGraphs(self): writer = debug_events_writer.DebugEventsWriter(self.dump_root) @@ -188,15 +169,15 @@ class PywrapeventsWriterTest(test_util.TensorFlowTestCase): for thread in threads: thread.join() - reader = debug_events_reader.DebugEventsReader(self.dump_root) # Verify the content of the .source_files file. - source_files_iter = reader.source_files_iterator() - actuals = list(source_files_iter) - file_paths = sorted([actual.source_file.file_path for actual in actuals]) - self.assertEqual(file_paths, [ - "/home/tf2user/file_0.py", "/home/tf2user/file_1.py", - "/home/tf2user/file_2.py" - ]) + with debug_events_reader.DebugEventsReader(self.dump_root) as reader: + source_files_iter = reader.source_files_iterator() + actuals = list(source_files_iter) + file_paths = sorted([actual.source_file.file_path for actual in actuals]) + self.assertEqual(file_paths, [ + "/home/tf2user/file_0.py", "/home/tf2user/file_1.py", + "/home/tf2user/file_2.py" + ]) # Verify the content of the .stack_frames file. actuals = list(reader.stack_frames_iterator()) @@ -219,18 +200,16 @@ class PywrapeventsWriterTest(test_util.TensorFlowTestCase): execution.op_type = "OpType%d" % i writer.WriteExecution(execution) - reader = debug_events_reader.DebugEventsReader(self.dump_root) - actuals = list(reader.execution_iterator()) # Before FlushExecutionFiles() is called. No data should have been written # to the file. 
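# A minimal sketch of the context-manager protocol DebugEventsReader gains in
# this change: __enter__ returns the reader and __exit__ calls close(), which
# closes every cached record reader so the dump files are released promptly
# (e.g. before the dump root is deleted in tearDown()). `SketchReader` and
# `_FakeRecordReader` are illustrative stand-ins, not the real classes.
class _FakeRecordReader(object):

  def __init__(self):
    self.closed = False

  def Close(self):  # Mirrors the capitalized Close() of the native reader.
    self.closed = True


class SketchReader(object):

  def __init__(self, file_paths):
    self._readers = {path: _FakeRecordReader() for path in file_paths}

  def __enter__(self):
    return self

  def __exit__(self, exception_type, exception_value, traceback):
    del exception_type, exception_value, traceback  # Unused.
    self.close()

  def close(self):
    for reader in self._readers.values():
      reader.Close()


with SketchReader(["x.metadata", "x.execution"]) as sketch_reader:
  pass  # Iterate over debug events here.
assert all(r.closed for r in sketch_reader._readers.values())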
- self.assertEqual(len(actuals), 0) + executed_op_types, _, _, _, _ = self._readAndCheckExecutionFile() + self.assertFalse(executed_op_types) writer.FlushExecutionFiles() - actuals = list(reader.execution_iterator()) - self.assertLen(actuals, debug_events_writer.DEFAULT_CIRCULAR_BUFFER_SIZE) - for i in range(debug_events_writer.DEFAULT_CIRCULAR_BUFFER_SIZE): + executed_op_types, _, _, _, _ = self._readAndCheckExecutionFile() + for i, executed_op_type in enumerate(executed_op_types): self.assertEqual( - actuals[i].execution.op_type, + executed_op_type, "OpType%d" % (i + debug_events_writer.DEFAULT_CIRCULAR_BUFFER_SIZE)) def testWriteExecutionEventsWithoutCircularBufferBehavior(self): @@ -243,11 +222,10 @@ class PywrapeventsWriterTest(test_util.TensorFlowTestCase): writer.WriteExecution(execution) writer.FlushExecutionFiles() - reader = debug_events_reader.DebugEventsReader(self.dump_root) - actuals = list(reader.execution_iterator()) - self.assertLen(actuals, num_execution_events) - for i in range(num_execution_events): - self.assertEqual(actuals[i].execution.op_type, "OpType%d" % i) + executed_op_types, _, _, _, _ = self._readAndCheckExecutionFile() + self.assertLen(executed_op_types, num_execution_events) + for i, executed_op_type in enumerate(executed_op_types): + self.assertEqual(executed_op_type, "OpType%d" % i) def testWriteGraphExecutionTraceEventsWithCircularBuffer(self): writer = debug_events_writer.DebugEventsWriter(self.dump_root) @@ -257,19 +235,19 @@ class PywrapeventsWriterTest(test_util.TensorFlowTestCase): trace.op_name = "Op%d" % i writer.WriteGraphExecutionTrace(trace) - reader = debug_events_reader.DebugEventsReader(self.dump_root) - actuals = list(reader.graph_execution_traces_iterator()) - # Before FlushExecutionFiles() is called. No data should have been written - # to the file. - self.assertEqual(len(actuals), 0) + with debug_events_reader.DebugEventsReader(self.dump_root) as reader: + actuals = list(reader.graph_execution_traces_iterator()) + # Before FlushExecutionFiles() is called. No data should have been written + # to the file. + self.assertEqual(len(actuals), 0) - writer.FlushExecutionFiles() - actuals = list(reader.graph_execution_traces_iterator()) - self.assertLen(actuals, debug_events_writer.DEFAULT_CIRCULAR_BUFFER_SIZE) - for i in range(debug_events_writer.DEFAULT_CIRCULAR_BUFFER_SIZE): - self.assertEqual( - actuals[i].graph_execution_trace.op_name, - "Op%d" % (i + debug_events_writer.DEFAULT_CIRCULAR_BUFFER_SIZE)) + writer.FlushExecutionFiles() + actuals = list(reader.graph_execution_traces_iterator()) + self.assertLen(actuals, debug_events_writer.DEFAULT_CIRCULAR_BUFFER_SIZE) + for i in range(debug_events_writer.DEFAULT_CIRCULAR_BUFFER_SIZE): + self.assertEqual( + actuals[i].graph_execution_trace.op_name, + "Op%d" % (i + debug_events_writer.DEFAULT_CIRCULAR_BUFFER_SIZE)) def testWriteGraphExecutionTraceEventsWithoutCircularBufferBehavior(self): # A circular buffer size of 0 abolishes the circular buffer behavior. 
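# A rough model of the circular-buffer behavior the assertions above depend
# on: with buffer size B and 2 * B events written before flushing, only the
# last B survive, hence the `i + B` offset in the expected op types. This is
# a sketch using collections.deque, not the DebugEventsWriter implementation,
# and BUFFER_SIZE is an arbitrary stand-in for the default buffer size.
import collections

BUFFER_SIZE = 1000
circular_buffer = collections.deque(maxlen=BUFFER_SIZE)
for i in range(BUFFER_SIZE * 2):
  circular_buffer.append("OpType%d" % i)  # Older entries are evicted silently.

flushed = list(circular_buffer)  # What a flush would persist.
assert len(flushed) == BUFFER_SIZE
for i, op_type in enumerate(flushed):
  assert op_type == "OpType%d" % (i + BUFFER_SIZE)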
@@ -281,8 +259,8 @@ class PywrapeventsWriterTest(test_util.TensorFlowTestCase): writer.WriteGraphExecutionTrace(trace) writer.FlushExecutionFiles() - reader = debug_events_reader.DebugEventsReader(self.dump_root) - actuals = list(reader.graph_execution_traces_iterator()) + with debug_events_reader.DebugEventsReader(self.dump_root) as reader: + actuals = list(reader.graph_execution_traces_iterator()) self.assertLen(actuals, num_execution_events) for i in range(num_execution_events): self.assertEqual(actuals[i].graph_execution_trace.op_name, "Op%d" % i) @@ -324,18 +302,17 @@ class PywrapeventsWriterTest(test_util.TensorFlowTestCase): writer.FlushExecutionFiles() # Verify the content of the .execution file. - reader = debug_events_reader.DebugEventsReader(self.dump_root) - actuals = list(reader.execution_iterator()) - op_types = sorted([actual.execution.op_type for actual in actuals]) - self.assertLen(op_types, circular_buffer_size) - self.assertLen(op_types, len(set(op_types))) + executed_op_types, _, _, _, _ = self._readAndCheckExecutionFile() + self.assertLen(executed_op_types, circular_buffer_size) + self.assertLen(executed_op_types, len(set(executed_op_types))) # Verify the content of the .execution file. - actuals = list(reader.graph_execution_traces_iterator()) - op_names = sorted( - [actual.graph_execution_trace.op_name for actual in actuals]) - self.assertLen(op_names, circular_buffer_size) - self.assertLen(op_names, len(set(op_names))) + with debug_events_reader.DebugEventsReader(self.dump_root) as reader: + actuals = list(reader.graph_execution_traces_iterator()) + op_names = sorted( + [actual.graph_execution_trace.op_name for actual in actuals]) + self.assertLen(op_names, circular_buffer_size) + self.assertLen(op_names, len(set(op_names))) if __name__ == "__main__": diff --git a/tensorflow/python/debug/lib/debug_v2_ops_test.py b/tensorflow/python/debug/lib/debug_v2_ops_test.py index f4a8b46352c..08b0ec17316 100644 --- a/tensorflow/python/debug/lib/debug_v2_ops_test.py +++ b/tensorflow/python/debug/lib/debug_v2_ops_test.py @@ -19,30 +19,28 @@ from __future__ import division from __future__ import print_function import os -import tempfile import numpy as np from tensorflow.core.protobuf import debug_event_pb2 from tensorflow.python.debug.lib import debug_events_reader from tensorflow.python.debug.lib import debug_events_writer +from tensorflow.python.debug.lib import dumping_callback_test_lib from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_util -from tensorflow.python.lib.io import file_io from tensorflow.python.ops import gen_debug_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest -class DebugIdentityV2OpTest(test_util.TensorFlowTestCase): +class DebugIdentityV2OpTest(dumping_callback_test_lib.DumpingCallbackTestBase): def setUp(self): super(DebugIdentityV2OpTest, self).setUp() - self.dump_root = tempfile.mkdtemp() # Testing using a small circular-buffer size. 
self.circular_buffer_size = 4 self.writer = debug_events_writer.DebugEventsWriter( @@ -50,8 +48,6 @@ class DebugIdentityV2OpTest(test_util.TensorFlowTestCase): def tearDown(self): self.writer.Close() - if os.path.isdir(self.dump_root): - file_io.delete_recursively(self.dump_root) super(DebugIdentityV2OpTest, self).tearDown() @test_util.run_in_graph_and_eager_modes @@ -87,55 +83,55 @@ class DebugIdentityV2OpTest(test_util.TensorFlowTestCase): self.assertAllClose( write_debug_trace(x), [9.0 + np.sqrt(3.0), 16.0 + 2.0]) - reader = debug_events_reader.DebugEventsReader(self.dump_root) - metadata_iter = reader.metadata_iterator() - # Check that the .metadata DebugEvents data file has been created, even - # before FlushExecutionFiles() is called. - debug_event = next(metadata_iter) - self.assertGreater(debug_event.wall_time, 0) - self.assertTrue(debug_event.debug_metadata.tensorflow_version) - self.assertTrue( - debug_event.debug_metadata.file_version.startswith("debug.Event:")) - - graph_trace_iter = reader.graph_execution_traces_iterator() - # Before FlushExecutionFiles() is called, the .graph_execution_traces file - # ought to be empty. - with self.assertRaises(StopIteration): - next(graph_trace_iter) - - # Flush the circular buffer. - self.writer.FlushExecutionFiles() - graph_trace_iter = reader.graph_execution_traces_iterator() - - # The circular buffer has a size of 4. So only the data from the - # last two iterations should have been written to self.dump_root. - for _ in range(2): - debug_event = next(graph_trace_iter) + with debug_events_reader.DebugEventsReader(self.dump_root) as reader: + metadata_iter = reader.metadata_iterator() + # Check that the .metadata DebugEvents data file has been created, even + # before FlushExecutionFiles() is called. + debug_event = next(metadata_iter) self.assertGreater(debug_event.wall_time, 0) - trace = debug_event.graph_execution_trace - self.assertEqual(trace.tfdbg_context_id, "deadbeaf") - self.assertEqual(trace.op_name, "Square") - self.assertEqual(trace.output_slot, 0) - self.assertEqual(trace.tensor_debug_mode, - debug_event_pb2.TensorDebugMode.FULL_TENSOR) - tensor_value = tensor_util.MakeNdarray(trace.tensor_proto) - self.assertAllClose(tensor_value, [9.0, 16.0]) + self.assertTrue(debug_event.debug_metadata.tensorflow_version) + self.assertTrue( + debug_event.debug_metadata.file_version.startswith("debug.Event:")) - debug_event = next(graph_trace_iter) - self.assertGreater(debug_event.wall_time, 0) - trace = debug_event.graph_execution_trace - self.assertEqual(trace.tfdbg_context_id, "beafdead") - self.assertEqual(trace.op_name, "Sqrt") - self.assertEqual(trace.output_slot, 0) - self.assertEqual(trace.tensor_debug_mode, - debug_event_pb2.TensorDebugMode.FULL_TENSOR) - tensor_value = tensor_util.MakeNdarray(trace.tensor_proto) - self.assertAllClose(tensor_value, [np.sqrt(3.0), 2.0]) + graph_trace_iter = reader.graph_execution_traces_iterator() + # Before FlushExecutionFiles() is called, the .graph_execution_traces file + # ought to be empty. + with self.assertRaises(StopIteration): + next(graph_trace_iter) - # Only the graph-execution trace of the last iteration should be written - # to self.dump_root. - with self.assertRaises(StopIteration): - next(graph_trace_iter) + # Flush the circular buffer. + self.writer.FlushExecutionFiles() + graph_trace_iter = reader.graph_execution_traces_iterator() + + # The circular buffer has a size of 4. So only the data from the + # last two iterations should have been written to self.dump_root. 
+ for _ in range(2): + debug_event = next(graph_trace_iter) + self.assertGreater(debug_event.wall_time, 0) + trace = debug_event.graph_execution_trace + self.assertEqual(trace.tfdbg_context_id, "deadbeaf") + self.assertEqual(trace.op_name, "Square") + self.assertEqual(trace.output_slot, 0) + self.assertEqual(trace.tensor_debug_mode, + debug_event_pb2.TensorDebugMode.FULL_TENSOR) + tensor_value = tensor_util.MakeNdarray(trace.tensor_proto) + self.assertAllClose(tensor_value, [9.0, 16.0]) + + debug_event = next(graph_trace_iter) + self.assertGreater(debug_event.wall_time, 0) + trace = debug_event.graph_execution_trace + self.assertEqual(trace.tfdbg_context_id, "beafdead") + self.assertEqual(trace.op_name, "Sqrt") + self.assertEqual(trace.output_slot, 0) + self.assertEqual(trace.tensor_debug_mode, + debug_event_pb2.TensorDebugMode.FULL_TENSOR) + tensor_value = tensor_util.MakeNdarray(trace.tensor_proto) + self.assertAllClose(tensor_value, [np.sqrt(3.0), 2.0]) + + # Only the graph-execution trace of the last iteration should be written + # to self.dump_root. + with self.assertRaises(StopIteration): + next(graph_trace_iter) @test_util.run_in_graph_and_eager_modes def testControlFlow(self): @@ -162,28 +158,28 @@ class DebugIdentityV2OpTest(test_util.TensorFlowTestCase): self.evaluate(collatz(x)) self.writer.FlushExecutionFiles() - reader = debug_events_reader.DebugEventsReader(self.dump_root) - graph_trace_iter = reader.graph_execution_traces_iterator() - try: - x_values = [] - timestamp = 0 - while True: - debug_event = next(graph_trace_iter) - self.assertGreater(debug_event.wall_time, timestamp) - timestamp = debug_event.wall_time - trace = debug_event.graph_execution_trace - self.assertEqual(trace.tfdbg_context_id, "deadbeaf") - self.assertEqual(trace.op_name, "x") - self.assertEqual(trace.output_slot, 0) - self.assertEqual(trace.tensor_debug_mode, - debug_event_pb2.TensorDebugMode.FULL_TENSOR) - x_values.append(int(tensor_util.MakeNdarray(trace.tensor_proto))) - except StopIteration: - pass + with debug_events_reader.DebugEventsReader(self.dump_root) as reader: + graph_trace_iter = reader.graph_execution_traces_iterator() + try: + x_values = [] + timestamp = 0 + while True: + debug_event = next(graph_trace_iter) + self.assertGreater(debug_event.wall_time, timestamp) + timestamp = debug_event.wall_time + trace = debug_event.graph_execution_trace + self.assertEqual(trace.tfdbg_context_id, "deadbeaf") + self.assertEqual(trace.op_name, "x") + self.assertEqual(trace.output_slot, 0) + self.assertEqual(trace.tensor_debug_mode, + debug_event_pb2.TensorDebugMode.FULL_TENSOR) + x_values.append(int(tensor_util.MakeNdarray(trace.tensor_proto))) + except StopIteration: + pass - # Due to the circular buffer, only the last 4 iterations of - # [10, 5, 16, 8, 4, 2] should have been written. - self.assertAllEqual(x_values, [16, 8, 4, 2]) + # Due to the circular buffer, only the last 4 iterations of + # [10, 5, 16, 8, 4, 2] should have been written. 
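# Quick arithmetic check of the expectation stated above: the Collatz-style
# loop in testControlFlow starts from x = 10 and traces x on every iteration,
# visiting [10, 5, 16, 8, 4, 2]; a circular buffer of size 4 keeps only the
# last four of those values. Plain Python, independent of the test itself.
x, traced = 10, []
while x > 1:
  traced.append(x)
  x = x // 2 if x % 2 == 0 else 3 * x + 1
assert traced == [10, 5, 16, 8, 4, 2]
assert traced[-4:] == [16, 8, 4, 2]  # What the size-4 buffer retains.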
+ self.assertAllEqual(x_values, [16, 8, 4, 2]) @test_util.run_in_graph_and_eager_modes def testTwoDumpRoots(self): @@ -210,20 +206,20 @@ class DebugIdentityV2OpTest(test_util.TensorFlowTestCase): another_writer.Close() for debug_root in (self.dump_root, another_dump_root): - reader = debug_events_reader.DebugEventsReader(debug_root) - graph_trace_iter = reader.graph_execution_traces_iterator() + with debug_events_reader.DebugEventsReader(debug_root) as reader: + graph_trace_iter = reader.graph_execution_traces_iterator() - debug_event = next(graph_trace_iter) - trace = debug_event.graph_execution_trace - self.assertEqual(trace.tfdbg_context_id, "deadbeaf") - self.assertEqual(trace.op_name, "") - self.assertEqual(trace.tensor_debug_mode, - debug_event_pb2.TensorDebugMode.FULL_TENSOR) - tensor_value = tensor_util.MakeNdarray(trace.tensor_proto) - self.assertAllClose(tensor_value, [9.0, 16.0]) + debug_event = next(graph_trace_iter) + trace = debug_event.graph_execution_trace + self.assertEqual(trace.tfdbg_context_id, "deadbeaf") + self.assertEqual(trace.op_name, "") + self.assertEqual(trace.tensor_debug_mode, + debug_event_pb2.TensorDebugMode.FULL_TENSOR) + tensor_value = tensor_util.MakeNdarray(trace.tensor_proto) + self.assertAllClose(tensor_value, [9.0, 16.0]) - with self.assertRaises(StopIteration): - next(graph_trace_iter) + with self.assertRaises(StopIteration): + next(graph_trace_iter) @test_util.run_in_graph_and_eager_modes def testDebugNumericSummaryV2OpReduceInfNanTwoSlots(self): diff --git a/tensorflow/python/debug/lib/distributed_callbacks_test.py b/tensorflow/python/debug/lib/distributed_callbacks_test.py index bd9d908fd36..e1ff0f823c3 100644 --- a/tensorflow/python/debug/lib/distributed_callbacks_test.py +++ b/tensorflow/python/debug/lib/distributed_callbacks_test.py @@ -178,7 +178,7 @@ class DistributedDumpingCallbackTest( stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames() (context_ids, _, - op_name_to_op_type) = self._readAndCheckGraphsFile(stack_frame_by_id) + op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id) (op_names, device_names, _, tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids) executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names] @@ -261,7 +261,7 @@ class DistributedDumpingCallbackTest( stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames() (context_ids, _, - op_name_to_op_type) = self._readAndCheckGraphsFile(stack_frame_by_id) + op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id) (op_names, device_names, _, tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids) diff --git a/tensorflow/python/debug/lib/dumping_callback.py b/tensorflow/python/debug/lib/dumping_callback.py index 96427536e1d..adb924aefaa 100644 --- a/tensorflow/python/debug/lib/dumping_callback.py +++ b/tensorflow/python/debug/lib/dumping_callback.py @@ -119,6 +119,8 @@ class _DumpingCallback(object): """Get a unique ID for an op-construction context (e.g., a graph). If the graph has been encountered before, reuse the same unique ID. + When encountering a new context (graph), this methods writes a DebugEvent + proto with the debugged_graph field to the proper DebugEvent file. Args: context: A context to get the unique ID for. Must be hashable. E.g., a @@ -130,10 +132,34 @@ class _DumpingCallback(object): # Use the double-checked lock pattern to optimize the common case. if context in self._context_to_id: # 1st check, without lock. 
return self._context_to_id[context] + graph_is_new = False with self._context_to_id_lock: if context not in self._context_to_id: # 2nd check, with lock. - self._context_to_id[context] = _get_id() - return self._context_to_id[context] + graph_is_new = True + context_id = _get_id() + self._context_to_id[context] = context_id + if graph_is_new: + self.get_writer().WriteDebuggedGraph(debug_event_pb2.DebuggedGraph( + graph_id=context_id, + graph_name=getattr(context, "name", None), + outer_context_id=self._get_outer_context_id(context))) + return self._context_to_id[context] + + def _get_outer_context_id(self, graph): + """Get the ID of the immediate outer context of the input graph. + + Args: + graph: The graph (context) in question. + + Returns: + If an outer context exists, the immediate outer context name as a string. + If such as outer context does not exist (i.e., `graph` is itself + outermost), `None`. + """ + if hasattr(graph, "outer_graph") and graph.outer_graph: + return self._get_context_id(graph.outer_graph) + else: + return None def _write_source_file_content(self, file_path): """Send the content of a source file via debug-events writer. @@ -352,7 +378,7 @@ class _DumpingCallback(object): writer = self.get_writer() if graph: - context_id = self._get_context_id(graph) + context_id = self._get_context_id(graph) # Innermost context ID. assert op_name is not None output_tensor_ids = self._get_symbolic_tensor_ids(len(outputs)) graph_op_creation = debug_event_pb2.GraphOpCreation( diff --git a/tensorflow/python/debug/lib/dumping_callback_test.py b/tensorflow/python/debug/lib/dumping_callback_test.py index 8cc0242c062..d32d543b382 100644 --- a/tensorflow/python/debug/lib/dumping_callback_test.py +++ b/tensorflow/python/debug/lib/dumping_callback_test.py @@ -112,84 +112,84 @@ class TracingCallbackTest( # Before FlushExecutionFiles() is called, the .execution file should be # empty. - reader = debug_events_reader.DebugEventsReader(self.dump_root) - execution_iter = reader.execution_iterator() - with self.assertRaises(StopIteration): - next(execution_iter) + with debug_events_reader.DebugEventsReader(self.dump_root) as reader: + execution_iter = reader.execution_iterator() + with self.assertRaises(StopIteration): + next(execution_iter) - # After the flushing, the .execution file should hold the appropriate - # contents. - writer.FlushExecutionFiles() - execution_iter = reader.execution_iterator() - prev_wall_time = 1 - executed_op_types = [] - tensor_values = collections.defaultdict(lambda: []) - for debug_event in execution_iter: - self.assertGreaterEqual(debug_event.wall_time, prev_wall_time) - prev_wall_time = debug_event.wall_time - execution = debug_event.execution - executed_op_types.append(execution.op_type) - self.assertTrue(execution.input_tensor_ids) - self.assertTrue(execution.output_tensor_ids) - if tensor_debug_mode == "NO_TENSOR": - # Due to the NO_TENSOR tensor debug mode, tensor_protos ought to - # be empty. - self.assertFalse(execution.tensor_protos) - elif tensor_debug_mode == "FULL_TENSOR": - # Under the FULL_TENSOR mode, the value of the tensor should be - # available through `tensor_protos`. - tensor_value = float( - tensor_util.MakeNdarray(execution.tensor_protos[0])) - tensor_values[execution.op_type].append(tensor_value) - # Verify the code_location field. 
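# A generic sketch of the double-checked locking pattern that _get_context_id
# above uses: a lock-free lookup serves the common case, a miss re-checks
# under the lock before assigning a new ID, and the one-time side effect
# (writing the DebuggedGraph proto) happens only for the thread that actually
# inserted the entry. `IdCache` is an illustrative class, not the callback.
import itertools
import threading


class IdCache(object):

  def __init__(self):
    self._ids = {}
    self._lock = threading.Lock()
    self._counter = itertools.count()
    self.newly_seen = []  # Stand-in for the WriteDebuggedGraph() call.

  def get_id(self, key):
    if key in self._ids:          # 1st check, without the lock (fast path).
      return self._ids[key]
    is_new = False
    with self._lock:
      if key not in self._ids:    # 2nd check, with the lock.
        is_new = True
        self._ids[key] = "ctx_%d" % next(self._counter)
    if is_new:
      # Executed at most once per key, mirroring the one-time side effect in
      # _get_context_id for a newly encountered graph context.
      self.newly_seen.append(key)
    return self._ids[key]


cache = IdCache()
assert cache.get_id("graph_a") == cache.get_id("graph_a")
assert cache.newly_seen == ["graph_a"]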
- self.assertTrue(execution.code_location.stack_frame_ids) - for stack_frame_id in execution.code_location.stack_frame_ids: - self.assertIn(stack_frame_id, stack_frame_by_id) - if tensor_debug_mode == "FULL_TENSOR": - self.assertAllClose(tensor_values["Greater"], [1, 1, 1, 1, 1, 1, 0]) - self.assertAllClose(tensor_values["RealDiv"], [5, 8, 4, 2, 1]) - self.assertAllClose(tensor_values["Mul"], [15]) - self.assertAllClose(tensor_values["AddV2"], [16]) + # After the flushing, the .execution file should hold the appropriate + # contents. + writer.FlushExecutionFiles() + execution_iter = reader.execution_iterator() + prev_wall_time = 1 + executed_op_types = [] + tensor_values = collections.defaultdict(lambda: []) + for debug_event in execution_iter: + self.assertGreaterEqual(debug_event.wall_time, prev_wall_time) + prev_wall_time = debug_event.wall_time + execution = debug_event.execution + executed_op_types.append(execution.op_type) + self.assertTrue(execution.input_tensor_ids) + self.assertTrue(execution.output_tensor_ids) + if tensor_debug_mode == "NO_TENSOR": + # Due to the NO_TENSOR tensor debug mode, tensor_protos ought to + # be empty. + self.assertFalse(execution.tensor_protos) + elif tensor_debug_mode == "FULL_TENSOR": + # Under the FULL_TENSOR mode, the value of the tensor should be + # available through `tensor_protos`. + tensor_value = float( + tensor_util.MakeNdarray(execution.tensor_protos[0])) + tensor_values[execution.op_type].append(tensor_value) + # Verify the code_location field. + self.assertTrue(execution.code_location.stack_frame_ids) + for stack_frame_id in execution.code_location.stack_frame_ids: + self.assertIn(stack_frame_id, stack_frame_by_id) + if tensor_debug_mode == "FULL_TENSOR": + self.assertAllClose(tensor_values["Greater"], [1, 1, 1, 1, 1, 1, 0]) + self.assertAllClose(tensor_values["RealDiv"], [5, 8, 4, 2, 1]) + self.assertAllClose(tensor_values["Mul"], [15]) + self.assertAllClose(tensor_values["AddV2"], [16]) - self.assertEqual( - executed_op_types, - [ - "Greater", - "FloorMod", - "Equal", - "RealDiv", # 10 --> 5 - "Greater", - "FloorMod", - "Equal", - "Mul", - "AddV2", # 5 --> 16 - "Greater", - "FloorMod", - "Equal", - "RealDiv", # 16 --> 8 - "Greater", - "FloorMod", - "Equal", - "RealDiv", # 8 --> 4 - "Greater", - "FloorMod", - "Equal", - "RealDiv", # 4 --> 2 - "Greater", - "FloorMod", - "Equal", - "RealDiv", # 2 --> 1 - "Greater" - ]) + self.assertEqual( + executed_op_types, + [ + "Greater", + "FloorMod", + "Equal", + "RealDiv", # 10 --> 5 + "Greater", + "FloorMod", + "Equal", + "Mul", + "AddV2", # 5 --> 16 + "Greater", + "FloorMod", + "Equal", + "RealDiv", # 16 --> 8 + "Greater", + "FloorMod", + "Equal", + "RealDiv", # 8 --> 4 + "Greater", + "FloorMod", + "Equal", + "RealDiv", # 4 --> 2 + "Greater", + "FloorMod", + "Equal", + "RealDiv", # 2 --> 1 + "Greater" + ]) - # Due to the pure eager op execution, the .graph file and the - # .graph_execution_traces file ought to be empty. - graphs_iterator = reader.graphs_iterator() - with self.assertRaises(StopIteration): - next(graphs_iterator) - graph_trace_iter = reader.graph_execution_traces_iterator() - with self.assertRaises(StopIteration): - next(graph_trace_iter) + # Due to the pure eager op execution, the .graph file and the + # .graph_execution_traces file ought to be empty. 
+ graphs_iterator = reader.graphs_iterator() + with self.assertRaises(StopIteration): + next(graphs_iterator) + graph_trace_iter = reader.graph_execution_traces_iterator() + with self.assertRaises(StopIteration): + next(graph_trace_iter) @parameterized.named_parameters( ("NoTensor", "NO_TENSOR"), @@ -225,7 +225,7 @@ class TracingCallbackTest( stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames() (context_ids, op_types, - op_name_to_op_type) = self._readAndCheckGraphsFile(stack_frame_by_id) + op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id) self.assertIn("AddV2", op_types) self.assertIn("Log", op_types) self.assertIn("Sin", op_types) @@ -276,7 +276,7 @@ class TracingCallbackTest( stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames() (context_ids, op_types, - op_name_to_op_type) = self._readAndCheckGraphsFile(stack_frame_by_id) + op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id) self.assertIn("AddV2", op_types) self.assertIn("Log", op_types) self.assertIn("Sin", op_types) @@ -354,7 +354,7 @@ class TracingCallbackTest( writer.FlushExecutionFiles() stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames() (context_ids, _, - op_name_to_op_type) = self._readAndCheckGraphsFile(stack_frame_by_id) + op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id) (op_names, _, _, tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids) executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names] @@ -417,7 +417,7 @@ class TracingCallbackTest( stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames() # Verify the content of the .graphs file. - context_ids, op_types, op_name_to_op_type = ( + context_ids, op_types, op_name_to_op_type, _ = ( self._readAndCheckGraphsFile(stack_frame_by_id)) self.assertIn("Less", op_types) self.assertIn("Mul", op_types) @@ -425,66 +425,67 @@ class TracingCallbackTest( # Before FlushExecutionFiles() is called, the .execution and # .graph_execution_traces files should be both empty. - reader = debug_events_reader.DebugEventsReader(self.dump_root) - execution_iter = reader.execution_iterator() - graph_execution_traces_iter = reader.graph_execution_traces_iterator() - with self.assertRaises(StopIteration): - next(execution_iter) - with self.assertRaises(StopIteration): - next(graph_execution_traces_iter) + with debug_events_reader.DebugEventsReader(self.dump_root) as reader: + execution_iter = reader.execution_iterator() + graph_execution_traces_iter = reader.graph_execution_traces_iterator() + with self.assertRaises(StopIteration): + next(execution_iter) + with self.assertRaises(StopIteration): + next(graph_execution_traces_iter) - # TODO(cais): Backport execution instrumentation to tf.Session. - writer.FlushExecutionFiles() - # After the flushing, the .execution file should hold the appropriate - # contents. - if context.executing_eagerly(): - (executed_op_types, input_tensor_ids, output_tensor_ids, - tensor_debug_modes, tensor_values) = self._readAndCheckExecutionFile() - # NOTE(b/142486213): Execution of the TF function happens with - # Session.run() in v1 graph mode, hence it doesn't get logged to the - # .execution file. 
- self.assertLen(executed_op_types, 1) - self.assertIn("iterative_doubling", executed_op_types[0]) - self.assertLen(input_tensor_ids[0], 2) - self.assertLen(output_tensor_ids[0], 1) - self.assertEqual(tensor_debug_modes[0], - debug_event_pb2.TensorDebugMode.Value(tensor_debug_mode)) - if tensor_debug_mode == "FULL_TENSOR": - self.assertAllClose(tensor_values, [[8.0]]) + # TODO(cais): Backport execution instrumentation to tf.Session. + writer.FlushExecutionFiles() + # After the flushing, the .execution file should hold the appropriate + # contents. + if context.executing_eagerly(): + (executed_op_types, input_tensor_ids, output_tensor_ids, + tensor_debug_modes, tensor_values) = self._readAndCheckExecutionFile() + # NOTE(b/142486213): Execution of the TF function happens with + # Session.run() in v1 graph mode, hence it doesn't get logged to the + # .execution file. + self.assertLen(executed_op_types, 1) + self.assertIn("iterative_doubling", executed_op_types[0]) + self.assertLen(input_tensor_ids[0], 2) + self.assertLen(output_tensor_ids[0], 1) + self.assertEqual( + tensor_debug_modes[0], + debug_event_pb2.TensorDebugMode.Value(tensor_debug_mode)) + if tensor_debug_mode == "FULL_TENSOR": + self.assertAllClose(tensor_values, [[8.0]]) - (op_names, _, output_slots, - tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids) - executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names] - # The Less op should have been executed 5 times. - self.assertEqual(executed_op_types.count("Less"), 5) - # The last executed op should be Less. - self.assertEqual(executed_op_types[-1], "Less") - # The Mul op should have been executed 4 times. - self.assertEqual(executed_op_types.count("Mul"), 4) - # The AddV2 op should have been run, but we refrain from asserting on how - # many times it's executed. - self.assertIn("AddV2", executed_op_types) - for output_slot in output_slots: - self.assertEqual(output_slot, 0) - if tensor_debug_mode == "NO_TENSOR": - # Under the default NO_TENSOR tensor-debug mode, the tensor_proto ought to - # be an empty float32 tensor. - for tensor_value in tensor_values: - self.assertEqual(tensor_value.dtype, np.float32) - self.assertEqual(tensor_value.shape, (0,)) - elif tensor_debug_mode == "FULL_TENSOR": - less_values = [ - tensor_values[i] - for i, op_type in enumerate(executed_op_types) - if op_type == "Less" - ] - self.assertAllClose(less_values, [True, True, True, True, False]) - mul_values = [ - tensor_values[i] - for i, op_type in enumerate(executed_op_types) - if op_type == "Mul" - ] - self.assertAllClose(mul_values, [1.0, 2.0, 4.0, 8.0]) + (op_names, _, output_slots, + tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids) + executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names] + # The Less op should have been executed 5 times. + self.assertEqual(executed_op_types.count("Less"), 5) + # The last executed op should be Less. + self.assertEqual(executed_op_types[-1], "Less") + # The Mul op should have been executed 4 times. + self.assertEqual(executed_op_types.count("Mul"), 4) + # The AddV2 op should have been run, but we refrain from asserting on how + # many times it's executed. + self.assertIn("AddV2", executed_op_types) + for output_slot in output_slots: + self.assertEqual(output_slot, 0) + if tensor_debug_mode == "NO_TENSOR": + # Under the default NO_TENSOR tensor-debug mode, the tensor_proto ought + # to be an empty float32 tensor. 
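# A small sketch of the difference between the two tensor-debug modes checked
# above, once a dumped tensor_proto is decoded: NO_TENSOR protos deserialize
# to an empty float32 array, while FULL_TENSOR protos carry the real values.
# `decode_trace_value` is an illustrative helper, not part of the test.
import numpy as np

from tensorflow.python.framework import tensor_util


def decode_trace_value(tensor_proto, tensor_debug_mode):
  value = tensor_util.MakeNdarray(tensor_proto)
  if tensor_debug_mode == "NO_TENSOR":
    # Placeholder proto only: an empty float32 tensor of shape (0,).
    assert value.dtype == np.float32
    assert value.shape == (0,)
    return None
  elif tensor_debug_mode == "FULL_TENSOR":
    return value  # The proto carries the full tensor value.
  raise ValueError("Unsupported tensor debug mode: %s" % tensor_debug_mode)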
+ for tensor_value in tensor_values: + self.assertEqual(tensor_value.dtype, np.float32) + self.assertEqual(tensor_value.shape, (0,)) + elif tensor_debug_mode == "FULL_TENSOR": + less_values = [ + tensor_values[i] + for i, op_type in enumerate(executed_op_types) + if op_type == "Less" + ] + self.assertAllClose(less_values, [True, True, True, True, False]) + mul_values = [ + tensor_values[i] + for i, op_type in enumerate(executed_op_types) + if op_type == "Mul" + ] + self.assertAllClose(mul_values, [1.0, 2.0, 4.0, 8.0]) def testCallingEnableTracingTwiceWithTheSameDumpRootIsIdempotent(self): dumping_callback.enable_dump_debug_info(self.dump_root) @@ -497,17 +498,17 @@ class TracingCallbackTest( writer.FlushNonExecutionFiles() writer.FlushExecutionFiles() - reader = debug_events_reader.DebugEventsReader(self.dump_root) - execution_iter = reader.execution_iterator() - for _ in range(2): - debug_event = next(execution_iter) - self.assertGreater(debug_event.wall_time, 0) - execution = debug_event.execution - self.assertEqual(execution.op_type, "Unique") - self.assertEqual(execution.num_outputs, 2) - self.assertTrue(execution.code_location) - with self.assertRaises(StopIteration): - next(execution_iter) + with debug_events_reader.DebugEventsReader(self.dump_root) as reader: + execution_iter = reader.execution_iterator() + for _ in range(2): + debug_event = next(execution_iter) + self.assertGreater(debug_event.wall_time, 0) + execution = debug_event.execution + self.assertEqual(execution.op_type, "Unique") + self.assertEqual(execution.num_outputs, 2) + self.assertTrue(execution.code_location) + with self.assertRaises(StopIteration): + next(execution_iter) def testCallingEnableTracingTwiceWithDifferentDumpRootsOverwrites(self): dumping_callback.enable_dump_debug_info(self.dump_root) @@ -521,23 +522,24 @@ class TracingCallbackTest( writer.FlushNonExecutionFiles() writer.FlushExecutionFiles() - reader = debug_events_reader.DebugEventsReader(new_dump_root) - execution_iter = reader.execution_iterator() - for _ in range(2): - debug_event = next(execution_iter) - self.assertGreater(debug_event.wall_time, 0) - execution = debug_event.execution - self.assertEqual(execution.op_type, "Unique") - self.assertEqual(execution.num_outputs, 2) - self.assertTrue(execution.code_location) - with self.assertRaises(StopIteration): - next(execution_iter) + with debug_events_reader.DebugEventsReader(new_dump_root) as reader: + execution_iter = reader.execution_iterator() + for _ in range(2): + debug_event = next(execution_iter) + self.assertGreater(debug_event.wall_time, 0) + execution = debug_event.execution + self.assertEqual(execution.op_type, "Unique") + self.assertEqual(execution.num_outputs, 2) + self.assertTrue(execution.code_location) + with self.assertRaises(StopIteration): + next(execution_iter) - old_dump_root_reader = debug_events_reader.DebugEventsReader(self.dump_root) - execution_iter = old_dump_root_reader.execution_iterator() - # The old dump root shouldn't have been written to. - with self.assertRaises(StopIteration): - next(execution_iter) + with debug_events_reader.DebugEventsReader( + self.dump_root) as old_dump_root_reader: + execution_iter = old_dump_root_reader.execution_iterator() + # The old dump root shouldn't have been written to. + with self.assertRaises(StopIteration): + next(execution_iter) def testCallingEnableRepeatedlyWithDifferentTensorDebugMode(self): """Assert that calling enable_dump_debug_info() with different tensor-debug modes. 
@@ -555,7 +557,7 @@ class TracingCallbackTest( writer.FlushNonExecutionFiles() writer.FlushExecutionFiles() stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames() - context_ids, _, _ = self._readAndCheckGraphsFile(stack_frame_by_id) + context_ids, _, _, _ = self._readAndCheckGraphsFile(stack_frame_by_id) _, _, _, _, tensor_values = self._readAndCheckExecutionFile() self.assertEqual(tensor_values, [[]]) (_, _, _, @@ -586,17 +588,17 @@ class TracingCallbackTest( writer.FlushNonExecutionFiles() writer.FlushExecutionFiles() - reader = debug_events_reader.DebugEventsReader(self.dump_root) - source_files_iter = reader.source_files_iterator() - stack_frames_iter = reader.stack_frames_iterator() - execution_iter = reader.execution_iterator() - # No source-file, stack-frame or execution data should have been dumped. - with self.assertRaises(StopIteration): - next(source_files_iter) - with self.assertRaises(StopIteration): - next(stack_frames_iter) - with self.assertRaises(StopIteration): - next(execution_iter) + with debug_events_reader.DebugEventsReader(self.dump_root) as reader: + source_files_iter = reader.source_files_iterator() + stack_frames_iter = reader.stack_frames_iterator() + execution_iter = reader.execution_iterator() + # No source-file, stack-frame or execution data should have been dumped. + with self.assertRaises(StopIteration): + next(source_files_iter) + with self.assertRaises(StopIteration): + next(stack_frames_iter) + with self.assertRaises(StopIteration): + next(execution_iter) @parameterized.named_parameters( ("NoTensor", "NO_TENSOR"), @@ -630,15 +632,15 @@ class TracingCallbackTest( writer.FlushExecutionFiles() stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames() - reader = debug_events_reader.DebugEventsReader(self.dump_root) - execution_iter = reader.execution_iterator() - prev_wall_time = 1 - for debug_event in execution_iter: - self.assertGreaterEqual(debug_event.wall_time, prev_wall_time) - prev_wall_time = debug_event.wall_time + with debug_events_reader.DebugEventsReader(self.dump_root) as reader: + execution_iter = reader.execution_iterator() + prev_wall_time = 1 + for debug_event in execution_iter: + self.assertGreaterEqual(debug_event.wall_time, prev_wall_time) + prev_wall_time = debug_event.wall_time (context_ids, _, - op_name_to_op_type) = self._readAndCheckGraphsFile(stack_frame_by_id) + op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id) (op_names, _, output_slots, tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids) @@ -718,6 +720,57 @@ class TracingCallbackTest( v2_squared_values = tensor_values[executed_op_types.index("Pow")] self.assertAllClose(v2_squared_values, [9.0]) + @test_util.run_in_graph_and_eager_modes + def testNestedContextIsCapturedByGraphOpCreationHistory(self): + writer = dumping_callback.enable_dump_debug_info( + self.dump_root, tensor_debug_mode="NO_TENSOR") + + @def_function.function + def iterative_doubling(x, times): + i = constant_op.constant(0, dtype=dtypes.int32) + while i < times: + x = x * 2.0 - 1.0 + i += 1 + return x + + x = constant_op.constant(2.0, dtype=dtypes.float32) + times = constant_op.constant(4, dtype=dtypes.int32) + # 2 * 2 - 1 = 3; 3 * 2 - 1 = 5; 5 * 2 - 1 = 9; 9 * 2 - 1 = 17. 
+ self.assertAllClose(self.evaluate(iterative_doubling(x, times)), 17.0) + + writer.FlushNonExecutionFiles() + writer.FlushExecutionFiles() + + stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames() + (_, _, op_name_to_op_type, + op_name_to_context_id) = self._readAndCheckGraphsFile(stack_frame_by_id) + + less_op_names = [op_name for op_name in op_name_to_op_type + if op_name_to_op_type[op_name] == "Less"] + less_context_ids = [op_name_to_context_id[op_name] + for op_name in less_op_names] + mul_op_names = [op_name for op_name in op_name_to_op_type + if op_name_to_op_type[op_name] == "Mul"] + mul_context_ids = [op_name_to_context_id[op_name] + for op_name in mul_op_names] + sub_op_names = [op_name for op_name in op_name_to_op_type + if op_name_to_op_type[op_name] == "Sub"] + sub_context_ids = [op_name_to_context_id[op_name] + for op_name in sub_op_names] + self.assertLen(less_context_ids, 1) + self.assertLen(mul_context_ids, 1) + self.assertLen(sub_context_ids, 1) + self.assertTrue(less_context_ids[0]) + self.assertTrue(mul_context_ids[0]) + self.assertTrue(sub_context_ids[0]) + # The Less op is from the while-loop cond context and hence should have + # a different innermost context ID from the mul and sub ops, which are both + # from the while-loop body context. + self.assertNotEqual(less_context_ids[0], mul_context_ids[0]) + self.assertNotEqual(less_context_ids[0], sub_context_ids[0]) + # The Mul and Sub ops are from the same innermost context. + self.assertEqual(mul_context_ids[0], sub_context_ids[0]) + @parameterized.named_parameters( ("NoTensor", "NO_TENSOR"), ("FullTensor", "FULL_TENSOR"), @@ -736,7 +789,7 @@ class TracingCallbackTest( stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames() (context_ids, op_types, - op_name_to_op_type) = self._readAndCheckGraphsFile(stack_frame_by_id) + op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id) # Simply assert that graph are recorded and refrain from asserting on the # internal details of the Keras model. self.assertTrue(context_ids) @@ -803,7 +856,7 @@ class TracingCallbackTest( stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames() (context_ids, op_types, - op_name_to_op_type) = self._readAndCheckGraphsFile(stack_frame_by_id) + op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id) # Simply assert that graph are recorded and refrain from asserting on the # internal details of the Keras model. self.assertTrue(context_ids) @@ -876,7 +929,7 @@ class TracingCallbackTest( stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames() (context_ids, op_types, - op_name_to_op_type) = self._readAndCheckGraphsFile(stack_frame_by_id) + op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id) # Simply assert that graph are recorded and refrain from asserting on the # internal details of the Keras model. 
self.assertTrue(context_ids) diff --git a/tensorflow/python/debug/lib/dumping_callback_test_lib.py b/tensorflow/python/debug/lib/dumping_callback_test_lib.py index 2169ab9ce2b..74261f918ce 100644 --- a/tensorflow/python/debug/lib/dumping_callback_test_lib.py +++ b/tensorflow/python/debug/lib/dumping_callback_test_lib.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import os import shutil import socket @@ -49,59 +50,60 @@ class DumpingCallbackTestBase(test_util.TensorFlowTestCase): def _readAndCheckMetadataFile(self): """Read and check the .metadata debug-events file.""" - reader = debug_events_reader.DebugEventsReader(self.dump_root) - metadata_iter = reader.metadata_iterator() - metadata = next(metadata_iter).debug_metadata - self.assertEqual(metadata.tensorflow_version, versions.__version__) - self.assertTrue(metadata.file_version.startswith("debug.Event")) + with debug_events_reader.DebugEventsReader(self.dump_root) as reader: + metadata_iter = reader.metadata_iterator() + metadata = next(metadata_iter).debug_metadata + self.assertEqual(metadata.tensorflow_version, versions.__version__) + self.assertTrue(metadata.file_version.startswith("debug.Event")) def _readAndCheckSourceFilesAndStackFrames(self): """Read and verify the .source_files & .stack_frames debug-event files. Returns: - A dict mapping stack frame IDs to stack frames (FileLineCol). + An OrderedDict mapping stack frame IDs to stack frames (FileLineCol). """ - reader = debug_events_reader.DebugEventsReader(self.dump_root) - # Check the content of the .source_files file. - source_files_iter = reader.source_files_iterator() - source_file_paths = [] - prev_wall_time = 1 - for debug_event in source_files_iter: - self.assertGreaterEqual(debug_event.wall_time, prev_wall_time) - prev_wall_time = debug_event.wall_time - source_file = debug_event.source_file - self.assertEqual(source_file.host_name, socket.gethostname()) - self.assertTrue(source_file.file_path) - if source_file.lines: - self.assertTrue(os.path.isfile(source_file.file_path)) - source_file_paths.append(source_file.file_path) - # Assert the file paths are unique. - self.assertEqual(len(source_file_paths), len(set(source_file_paths))) + with debug_events_reader.DebugEventsReader(self.dump_root) as reader: + # Check the content of the .source_files file. + source_files_iter = reader.source_files_iterator() + source_file_paths = [] + prev_wall_time = 1 + for debug_event in source_files_iter: + self.assertGreaterEqual(debug_event.wall_time, prev_wall_time) + prev_wall_time = debug_event.wall_time + source_file = debug_event.source_file + self.assertEqual(source_file.host_name, socket.gethostname()) + self.assertTrue(source_file.file_path) + if source_file.lines: + self.assertTrue(os.path.isfile(source_file.file_path)) + source_file_paths.append(source_file.file_path) + # Assert the file paths are unique. + self.assertEqual(len(source_file_paths), len(set(source_file_paths))) - # Check the content of the .stack_frames file. - stack_frame_by_id = dict() # A map from ID to stack frame. 
- stack_frames_iter = reader.stack_frames_iterator() - prev_wall_time = 0 - for debug_event in stack_frames_iter: - self.assertGreaterEqual(debug_event.wall_time, prev_wall_time) - prev_wall_time = debug_event.wall_time - stack_frame_with_id = debug_event.stack_frame_with_id - stack_frame_id = stack_frame_with_id.id - file_line_col = stack_frame_with_id.file_line_col - self.assertTrue(stack_frame_id) - self.assertNotIn(stack_frame_id, stack_frame_by_id, - "Duplicate stack frame ID: %s" % id) - stack_frame_by_id[stack_frame_id] = (file_line_col.file_index, - file_line_col.line, - file_line_col.func) - self.assertGreaterEqual(file_line_col.file_index, 0) - self.assertLess(file_line_col.file_index, len(source_file_paths)) - self.assertTrue(file_line_col.line) # Line numbers are 1-based. - self.assertTrue(file_line_col.func) - # Assert the stack frames are unique. - self.assertEqual( - len(stack_frame_by_id.values()), len(set(stack_frame_by_id.values()))) - return stack_frame_by_id + # Check the content of the .stack_frames file. + # A map from ID to stack frame. + stack_frame_by_id = collections.OrderedDict() + stack_frames_iter = reader.stack_frames_iterator() + prev_wall_time = 0 + for debug_event in stack_frames_iter: + self.assertGreaterEqual(debug_event.wall_time, prev_wall_time) + prev_wall_time = debug_event.wall_time + stack_frame_with_id = debug_event.stack_frame_with_id + stack_frame_id = stack_frame_with_id.id + file_line_col = stack_frame_with_id.file_line_col + self.assertTrue(stack_frame_id) + self.assertNotIn(stack_frame_id, stack_frame_by_id, + "Duplicate stack frame ID: %s" % id) + stack_frame_by_id[stack_frame_id] = (file_line_col.file_index, + file_line_col.line, + file_line_col.func) + self.assertGreaterEqual(file_line_col.file_index, 0) + self.assertLess(file_line_col.file_index, len(source_file_paths)) + self.assertTrue(file_line_col.line) # Line numbers are 1-based. + self.assertTrue(file_line_col.func) + # Assert the stack frames are unique. + self.assertEqual( + len(stack_frame_by_id.values()), len(set(stack_frame_by_id.values()))) + return stack_frame_by_id def _readAndCheckGraphsFile(self, stack_frame_by_id): """Read and verify the content of the .graphs debug-event file. @@ -115,36 +117,72 @@ class DumpingCallbackTestBase(test_util.TensorFlowTestCase): `list` of `str`s. op_types: Types of the ops that are created, as a `list` of `str`s with the same length as `context_ids`. - op_name_to_op_type: A `dict` mapping op name to op type. + op_name_to_op_type: An `OrderedDict` mapping op name to op type. + op_name_to_context_id: A `dict` mapping op name to the ID of the innermost + containing graph (context). 
""" - reader = debug_events_reader.DebugEventsReader(self.dump_root) - graphs_iter = reader.graphs_iterator() - prev_wall_time = 0 - op_types = [] - op_name_to_op_type = dict() - context_ids = set() - symbolic_tensor_ids = set() - for debug_event in graphs_iter: - self.assertGreaterEqual(debug_event.wall_time, prev_wall_time) - prev_wall_time = debug_event.wall_time - graph_op_creation = debug_event.graph_op_creation - self.assertTrue(graph_op_creation.op_type) - op_types.append(graph_op_creation.op_type) - self.assertTrue(graph_op_creation.op_name) - op_name_to_op_type[graph_op_creation.op_name] = graph_op_creation.op_type - self.assertTrue(graph_op_creation.graph_id) - context_ids.add(graph_op_creation.graph_id) - self.assertTrue(graph_op_creation.code_location) - if graph_op_creation.num_outputs: - self.assertLen(graph_op_creation.output_tensor_ids, - graph_op_creation.num_outputs) - # Check that all symblic tensor IDs are unique. - for tensor_id in graph_op_creation.output_tensor_ids: - self.assertNotIn(tensor_id, symbolic_tensor_ids) - symbolic_tensor_ids.add(tensor_id) - for stack_frame_id in graph_op_creation.code_location.stack_frame_ids: - self.assertIn(stack_frame_id, stack_frame_by_id) - return context_ids, op_types, op_name_to_op_type + with debug_events_reader.DebugEventsReader(self.dump_root) as reader: + graphs_iter = reader.graphs_iterator() + prev_wall_time = 0 + op_types = [] + op_name_to_op_type = collections.OrderedDict() + op_name_to_context_id = dict() # Maps op name to ID of innermost context. + context_ids = set() + symbolic_tensor_ids = set() + # Maps context ID to ID of directly enclosing context (`None` for + # outermost contexts). + context_id_to_outer_id = dict() + + for debug_event in graphs_iter: + self.assertGreaterEqual(debug_event.wall_time, prev_wall_time) + prev_wall_time = debug_event.wall_time + # A DebugEvent in the .graphs file contains either of the two fields: + # - graph_op_creation for creation of a symbolic op in a graph context. + # - debugged_graph for information regarding the graph (context). + if debug_event.graph_op_creation.ByteSize(): + graph_op_creation = debug_event.graph_op_creation + self.assertTrue(graph_op_creation.op_type) + op_types.append(graph_op_creation.op_type) + self.assertTrue(graph_op_creation.op_name) + op_name_to_op_type[ + graph_op_creation.op_name] = graph_op_creation.op_type + op_name_to_context_id[ + graph_op_creation.op_name] = graph_op_creation.graph_id + self.assertTrue(graph_op_creation.graph_id) + context_ids.add(graph_op_creation.graph_id) + self.assertTrue(graph_op_creation.code_location) + if graph_op_creation.num_outputs: + self.assertLen(graph_op_creation.output_tensor_ids, + graph_op_creation.num_outputs) + # Check that all symblic tensor IDs are unique. + for tensor_id in graph_op_creation.output_tensor_ids: + self.assertNotIn(tensor_id, symbolic_tensor_ids) + symbolic_tensor_ids.add(tensor_id) + for stack_frame_id in graph_op_creation.code_location.stack_frame_ids: + self.assertIn(stack_frame_id, stack_frame_by_id) + else: + debugged_graph = debug_event.debugged_graph + if debugged_graph.outer_context_id: + inner_id = debugged_graph.graph_id + outer_id = debugged_graph.outer_context_id + if inner_id in context_id_to_outer_id: + # The outer context of a context must be always the same. + self.assertEqual(context_id_to_outer_id[inner_id], outer_id) + else: + context_id_to_outer_id[inner_id] = outer_id + else: + # This is an outermost context. 
+ if debugged_graph.graph_id in context_id_to_outer_id: + self.assertIsNone(context_id_to_outer_id[debugged_graph.graph_id]) + else: + context_id_to_outer_id[debugged_graph.graph_id] = None + + # If any graph is created, the graph context hierarchy must be populated. + # In addition, the context of each graph op must be locatable within the + # graph context hierarchy. + for context_id in op_name_to_context_id.values(): + self.assertIn(context_id, context_id_to_outer_id) + return context_ids, op_types, op_name_to_op_type, op_name_to_context_id def _readAndCheckExecutionFile(self, dump_root=None): """Read and verify the content of the .execution debug-event file. @@ -167,31 +205,30 @@ class DumpingCallbackTestBase(test_util.TensorFlowTestCase): output tensor slot of the executed op or Function. """ dump_root = self.dump_root if dump_root is None else dump_root - reader = debug_events_reader.DebugEventsReader(dump_root) - execution_iter = reader.execution_iterator() - prev_wall_time = 1 - executed_op_types = [] - input_tensor_ids = [] - output_tensor_ids = [] - tensor_debug_modes = [] - tensor_values = [] - for debug_event in execution_iter: - self.assertGreaterEqual(debug_event.wall_time, prev_wall_time) - prev_wall_time = debug_event.wall_time - execution = debug_event.execution - executed_op_types.append(execution.op_type) - input_tensor_ids.append(execution.input_tensor_ids) - output_tensor_ids.append(execution.output_tensor_ids) - tensor_debug_modes.append(execution.tensor_debug_mode) - tensor_values.append([ - tensor_util.MakeNdarray(tensor_proto) - for tensor_proto in execution.tensor_protos - ]) - - # TODO(cais): When tensor debug modes other than NO_TENSOR is supported, - # return tensor_values as well. - return (executed_op_types, input_tensor_ids, output_tensor_ids, - tensor_debug_modes, tensor_values) + with debug_events_reader.DebugEventsReader(dump_root) as reader: + execution_iter = reader.execution_iterator() + prev_wall_time = 1 + executed_op_types = [] + input_tensor_ids = [] + output_tensor_ids = [] + tensor_debug_modes = [] + tensor_values = [] + for debug_event in execution_iter: + self.assertGreaterEqual(debug_event.wall_time, prev_wall_time) + prev_wall_time = debug_event.wall_time + execution = debug_event.execution + executed_op_types.append(execution.op_type) + input_tensor_ids.append(execution.input_tensor_ids) + output_tensor_ids.append(execution.output_tensor_ids) + tensor_debug_modes.append(execution.tensor_debug_mode) + tensor_values.append([ + tensor_util.MakeNdarray(tensor_proto) + for tensor_proto in execution.tensor_protos + ]) + # TODO(cais): When tensor debug modes other than NO_TENSOR is supported, + # return tensor_values as well. + return (executed_op_types, input_tensor_ids, output_tensor_ids, + tensor_debug_modes, tensor_values) def _readAndCheckGraphExecutionTracesFile(self, context_ids): """Read & verify the content of the .graph_execution_trace debug-event file. @@ -210,29 +247,29 @@ class DumpingCallbackTestBase(test_util.TensorFlowTestCase): tensor_values: Tensor values or their concise summaries, depending on TensorDebugMode. 
""" - reader = debug_events_reader.DebugEventsReader(self.dump_root) - graph_execution_traces_iter = reader.graph_execution_traces_iterator() - op_names = [] - device_names = [] - output_slots = [] - tensor_values = [] - for debug_event in graph_execution_traces_iter: - self.assertGreaterEqual(debug_event.wall_time, 0) - graph_execution_trace = debug_event.graph_execution_trace - op_names.append(graph_execution_trace.op_name) - self.assertTrue(graph_execution_trace.device_name) - device_names.append(graph_execution_trace.device_name) - # All the ops in the graph have only one output. - self.assertTrue(graph_execution_trace.tfdbg_context_id) - self.assertIn(graph_execution_trace.tfdbg_context_id, context_ids) - output_slots.append(graph_execution_trace.output_slot) - dtype = dtypes.DType(graph_execution_trace.tensor_proto.dtype) - if (dtype.is_numpy_compatible and - dtype._type_enum != types_pb2.DT_STRING): # pylint:disable=protected-access - # TODO(cais): Figure out how to properly convert string tensor proto to - # numpy representation. - tensor_values.append( - tensor_util.MakeNdarray(graph_execution_trace.tensor_proto)) - else: - tensor_values.append(None) - return op_names, device_names, output_slots, tensor_values + with debug_events_reader.DebugEventsReader(self.dump_root) as reader: + graph_execution_traces_iter = reader.graph_execution_traces_iterator() + op_names = [] + device_names = [] + output_slots = [] + tensor_values = [] + for debug_event in graph_execution_traces_iter: + self.assertGreaterEqual(debug_event.wall_time, 0) + graph_execution_trace = debug_event.graph_execution_trace + op_names.append(graph_execution_trace.op_name) + self.assertTrue(graph_execution_trace.device_name) + device_names.append(graph_execution_trace.device_name) + # All the ops in the graph have only one output. + self.assertTrue(graph_execution_trace.tfdbg_context_id) + self.assertIn(graph_execution_trace.tfdbg_context_id, context_ids) + output_slots.append(graph_execution_trace.output_slot) + dtype = dtypes.DType(graph_execution_trace.tensor_proto.dtype) + if (dtype.is_numpy_compatible and + dtype._type_enum != types_pb2.DT_STRING): # pylint:disable=protected-access + # TODO(cais): Figure out how to properly convert string tensor proto + # to numpy representation. 
+ tensor_values.append( + tensor_util.MakeNdarray(graph_execution_trace.tensor_proto)) + else: + tensor_values.append(None) + return op_names, device_names, output_slots, tensor_values diff --git a/tensorflow/python/debug/lib/source_utils.py b/tensorflow/python/debug/lib/source_utils.py index 9c8f7733ef9..2a59ab97fc7 100644 --- a/tensorflow/python/debug/lib/source_utils.py +++ b/tensorflow/python/debug/lib/source_utils.py @@ -88,7 +88,7 @@ def guess_is_tensorflow_py_library(py_file_path): def load_source(source_file_path): - with open(source_file_path, "rU") as f: + with open(source_file_path, "r") as f: source_text = f.read() source_lines = source_text.split("\n") line_num_width = int(np.ceil(np.log10(len(source_lines)))) + 3 diff --git a/tensorflow/python/distribute/custom_training_loop_test.py b/tensorflow/python/distribute/custom_training_loop_test.py index 1db9bff21f0..55c2ae6a1ca 100644 --- a/tensorflow/python/distribute/custom_training_loop_test.py +++ b/tensorflow/python/distribute/custom_training_loop_test.py @@ -43,7 +43,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase): dataset = self._get_dataset() def train_step(data): - return data + return math_ops.square(data) dist_dataset = distribution.experimental_distribute_dataset(dataset) results = [] @@ -63,7 +63,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase): @def_function.function def train_step(data): - return data + return math_ops.square(data) dist_dataset = distribution.experimental_distribute_dataset(dataset) results = [] @@ -82,7 +82,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase): dataset = self._get_dataset() def train_step(data): - return data + return math_ops.square(data) @def_function.function def f_train_step(input_data): @@ -105,9 +105,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase): dataset = self._get_dataset() def train_step(data): - if math_ops.reduce_sum(data) < 0: - return -data - return data + return math_ops.square(data) @def_function.function def f_train_step(input_data): @@ -171,7 +169,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase): def testIterationInsideFunction(self, distribution): def step_fn(data): - return data + return math_ops.square(data) @def_function.function def train(dataset): @@ -199,7 +197,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase): def testIterationOutsideFunction(self, distribution): def train_step(data): - return data + return math_ops.square(data) @def_function.function def f_train_step(input_data): @@ -226,7 +224,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase): map(lambda x: math_ops.cast(x, dtypes.int32)).batch(2) def _validate_outputs(self, actual_results): - expected_results = [[i, i+1] for i in range(0, 10, 2)] + expected_results = [[i**2, (i+1)**2] for i in range(0, 10, 2)] self.assertEqual(len(expected_results), len(actual_results)) for i, expected_result in enumerate(expected_results): diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py index 94ba10783f7..4a2a8af1840 100644 --- a/tensorflow/python/distribute/distribute_lib.py +++ b/tensorflow/python/distribute/distribute_lib.py @@ -1164,8 +1164,8 @@ class StrategyExtendedV2(object): *Replica context vs. Cross-replica context* - _replica context_ is when we are in some function that is being called once - for each replica. 
Otherwise we are in cross-replica context, which is + A _replica context_ applies when we are in some function that is being called + once for each replica. Otherwise we are in cross-replica context, which is useful for calling `tf.distribute.Strategy` methods which operate across the replicas (like `reduce_to()`). By default you start in a replica context (the "default single replica context") and then some methods can switch you diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py index f1f9a0e872d..80d03ed438a 100644 --- a/tensorflow/python/distribute/input_lib.py +++ b/tensorflow/python/distribute/input_lib.py @@ -278,6 +278,9 @@ class DistributedIterator(object): except errors.OutOfRangeError: raise StopIteration + def __iter__(self): + return self + def get_next(self, name=None): """Returns the next input from the iterator for all replicas.""" if not self._enable_get_next_as_optional: diff --git a/tensorflow/python/distribute/input_lib_test.py b/tensorflow/python/distribute/input_lib_test.py index 96363053219..1cca10a77a2 100644 --- a/tensorflow/python/distribute/input_lib_test.py +++ b/tensorflow/python/distribute/input_lib_test.py @@ -417,6 +417,27 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase, expected_values, distribution) + @combinations.generate( + combinations.combine( + mode=["eager"], + distribution=[ + strategy_combinations.one_device_strategy, + strategy_combinations.mirrored_strategy_with_one_cpu + ])) + def testIterableIterator(self, distribution): + worker_device_pairs = [("", ["/device:CPU:0"])] + devices = nest.flatten([ds for _, ds in worker_device_pairs]) + device_map = values.ReplicaDeviceMap(devices) + input_workers = input_lib.InputWorkers(device_map, worker_device_pairs) + + dataset = dataset_ops.DatasetV2.range(10) + dist_dataset = input_lib.get_distributed_dataset(dataset, input_workers, + distribution) + + iterator = iter(dist_dataset) + for i, element in enumerate(iterator): + self.assertEqual(i, element.numpy()) + @combinations.generate( combinations.combine( mode=["graph", "eager"], diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py index 0fb8ae0aafb..85958724002 100644 --- a/tensorflow/python/distribute/mirrored_strategy.py +++ b/tensorflow/python/distribute/mirrored_strategy.py @@ -343,19 +343,86 @@ def all_devices(): @tf_export("distribute.MirroredStrategy", v1=[]) # pylint: disable=g-classes-have-attributes class MirroredStrategy(distribute_lib.Strategy): - """Mirrors vars to distribute across multiple devices and machines. + """Synchronous training across multiple replicas on one machine. - This strategy uses one replica per device and sync replication for its - multi-GPU version. + This strategy is typically used for training on one + machine with multiple GPUs. For TPUs, use + `tf.distribute.experimental.TPUStrategy`. To use `MirroredStrategy` with + multiple workers, please refer to + `tf.distribute.experimental.MultiWorkerMirroredStrategy`. - To use `MirroredStrategy` with multiple workers, please refer to - `tf.distribute.MultiWorkerMirroredStrategy`. + For example, a variable created under a `MirroredStrategy` is a + `MirroredVariable`. If no devices are specified in the constructor argument of + the strategy then it will use all the available GPUs. If no GPUs are found, it + will use the available CPUs. 
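Editor's note: the `__iter__` added to `DistributedIterator` in the input_lib hunk above is what lets the new `testIterableIterator` drive the iterator with `enumerate`. A hedged end-user sketch, assuming eager execution and a single-CPU strategy that yields one element per step:

```python
# Iterating a distributed dataset via the new __iter__ support
# (device list and unbatched dataset are illustrative assumptions).
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(devices=["/cpu:0"])
dataset = tf.data.Dataset.range(10)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

for i, element in enumerate(iter(dist_dataset)):
  assert i == element.numpy()
```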
Note that TensorFlow treats all CPUs on a + machine as a single device, and uses threads internally for parallelism. + + >>> strategy = tf.distribute.MirroredStrategy() + >>> with strategy.scope(): + ... x = tf.Variable(1.) + >>> x + MirroredVariable:{ + 0 /job:localhost/replica:0/task:0/device:CPU:0: + } + + While using distribution strategies, all the variable creation should be done + within the strategy's scope. This will replicate the variables across all the + replicas and keep them in sync using an all-reduce algorithm. + + Variables created inside a `MirroredStrategy` which is wrapped with a + `tf.function` are still `MirroredVariables`. + + >>> x = [] + >>> @tf.function # Wrap the function with tf.function. + ... def create_variable(): + ... if not x: + ... x.append(tf.Variable(1.)) + >>> strategy = tf.distribute.MirroredStrategy() + >>> with strategy.scope(): + ... create_variable() + ... print (x[0]) + MirroredVariable:{ + 0 /job:localhost/replica:0/task:0/device:CPU:0: + } + + `experimental_distribute_dataset` can be used to distribute the dataset across + the replicas when writing your own training loop. If you are using `.fit` and + `.compile` methods available in `tf.keras`, then `tf.keras` will handle the + distribution for you. + + For example: + + ```python + my_strategy = tf.distribute.MirroredStrategy() + with my_strategy.scope(): + @tf.function + def distribute_train_epoch(dataset): + def replica_fn(input): + # process input and return result + return result + + total_result = 0 + for x in dataset: + per_replica_result = my_strategy.experimental_run_v2(replica_fn, + args=(x,)) + total_result += my_strategy.reduce(tf.distribute.ReduceOp.SUM, + per_replica_result, axis=None) + return total_result + + dist_dataset = my_strategy.experimental_distribute_dataset(dataset) + for _ in range(EPOCHS): + train_result = distribute_train_epoch(dist_dataset) + ``` Args: - devices: a list of device strings. If `None`, all available GPUs are used. - If no GPUs are found, CPU is used. + devices: a list of device strings such as `['/gpu:0', '/gpu:1']`. If + `None`, all available GPUs are used. If no GPUs are found, CPU is used. cross_device_ops: optional, a descedant of `CrossDeviceOps`. If this is not - set, nccl will be used by default. + set, `NcclAllReduce()` will be used by default. One would customize this + if NCCL isn't available or if a special implementation that exploits + the particular hardware is available. 
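Editor's note: to make the two constructor arguments concrete, a hedged example; the device names are assumptions, and `HierarchicalCopyAllReduce` is only one of the `CrossDeviceOps` implementations one might substitute when NCCL is not an option:

```python
# Illustrative constructor call for the arguments documented above.
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(
    devices=["/gpu:0", "/gpu:1"],
    cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())
```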
""" def __init__(self, devices=None, cross_device_ops=None): diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index 23cfbd44972..62a808f44d7 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -307,6 +307,19 @@ class BackpropTest(test.TestCase, parameterized.TestCase): y = array_ops.identity(x) self.assertEqual(t.gradient(y, x).numpy(), 1.0) + def testFunctionIndexedSlicesGradient(self): + + @def_function.function + def f(x): + return x + 1 + + with backprop.GradientTape() as t: + x = constant_op.constant([1.0]) + t.watch(x) + y = f(x) + y = array_ops.gather(y, [0]) + self.assertAllEqual(t.gradient(y, x), [1.0]) + def testTapeGradientMultiTargetOneIsSource(self): x = constant_op.constant(2.0) with backprop.GradientTape() as t: diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 07619c882e5..2d8b442e1af 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -34,7 +34,6 @@ from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import function_pb2 from tensorflow.python import pywrap_tensorflow from tensorflow.python import _pywrap_utils -from tensorflow.python.compat import compat as fwd_compat from tensorflow.python.eager import backprop from tensorflow.python.eager import backprop_util from tensorflow.python.eager import context @@ -1030,18 +1029,8 @@ class _TapeGradientFunctions(object): with ops.get_default_graph()._override_gradient_function( # pylint: disable=protected-access {"PartitionedCall": gradient_function, "StatefulPartitionedCall": gradient_function}): - # Previously, we relyed on "_gradient_op_type" attribute to restore a - # function gradient in function_deserialization.py, So add a dummy - # value "PartitionedCallUnused" for the forward compatibility. - if fwd_compat.forward_compatible(2019, 11, 16): - forward_outputs = forward_function.call(context.context(), - forward_inputs) - else: - with ops.get_default_graph().gradient_override_map( - {"PartitionedCall": "PartitionedCallUnused", - "StatefulPartitionedCall": "PartitionedCallUnused"}): - forward_outputs = forward_function.call(context.context(), - forward_inputs) + forward_outputs = forward_function.call(context.context(), + forward_inputs) py_backward, _ = self._wrap_backward_function( self._func_graph, backward_function, forward_outputs) # We will never request backward tape gradients for this operation @@ -1235,6 +1224,12 @@ class _TapeGradientFunctions(object): processed_args = [] input_index = 0 for output_index, arg in enumerate(args): + # Convert IndexedSlices to dense tensors. The IndexedSlices optimization + # is only really effective when doing tf.gather(variable) as the + # adjoint functions for most operations are unlikely to preserve the + # sparsity in IndexedSlices. + if isinstance(arg, ops.IndexedSlices): + arg = ops.convert_to_tensor(arg) if output_index in skip_positions: continue if arg is None: @@ -1703,16 +1698,7 @@ class ConcreteFunction(object): with ops.get_default_graph()._override_gradient_function( # pylint: disable=protected-access {"PartitionedCall": self._get_gradient_function(), "StatefulPartitionedCall": self._get_gradient_function()}): - # Previously, we relyed on "_gradient_op_type" attribute to restore a - # function gradient in function_deserialization.py. So add a dummy - # value "PartitionedCallUnused" for the forward compatibility. 
- if fwd_compat.forward_compatible(2019, 11, 16): - flat_outputs = forward_function.call(ctx, args_with_tangents) - else: - with ops.get_default_graph().gradient_override_map( - {"PartitionedCall": "PartitionedCallUnused", - "StatefulPartitionedCall": "PartitionedCallUnused"}): - flat_outputs = forward_function.call(ctx, args_with_tangents) + flat_outputs = forward_function.call(ctx, args_with_tangents) forward_backward.record(flat_outputs) return self._build_call_outputs(flat_outputs) diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py index a9a8ac0518a..828f30c40eb 100644 --- a/tensorflow/python/framework/dtypes.py +++ b/tensorflow/python/framework/dtypes.py @@ -21,11 +21,15 @@ import numpy as np from six.moves import builtins from tensorflow.core.framework import types_pb2 +# We need to import pywrap_tensorflow prior to the bfloat wrapper to avoid +# protobuf errors where a file is defined twice on MacOS. +# pylint: disable=invalid-import-order,g-bad-import-order +from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import +from tensorflow.python import _pywrap_bfloat16 from tensorflow.python import _dtypes -from tensorflow.python import pywrap_tensorflow from tensorflow.python.util.tf_export import tf_export -_np_bfloat16 = pywrap_tensorflow.TF_bfloat16_type() +_np_bfloat16 = _pywrap_bfloat16.TF_bfloat16_type() # pylint: disable=slots-on-old-class diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index c3abf49ae59..d6fb60fd724 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -1187,6 +1187,7 @@ tf_py_test( "//third_party/py/numpy", "//tensorflow/python:client_testlib", ], + python_version = "PY3", shard_count = 6, tags = [ "noasan", # times out diff --git a/tensorflow/python/keras/activations.py b/tensorflow/python/keras/activations.py index 55440dd4017..17af5d36b41 100644 --- a/tensorflow/python/keras/activations.py +++ b/tensorflow/python/keras/activations.py @@ -260,9 +260,8 @@ def sigmoid(x): >>> a = tf.constant([-20, -1.0, 0.0, 1.0, 20], dtype = tf.float32) >>> b = tf.keras.activations.sigmoid(a) - >>> b.numpy() - array([0. , 0.26894143, 0.5 , 0.7310586 , 1. ], - dtype=float32) + >>> b.numpy() > 0.0 + array([False, True, True, True, True]) Arguments: x: Input tensor. diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 40fdb0c79a3..36570e36cc8 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -676,6 +676,7 @@ class Model(network.Network, version_utils.VersionSelector): - tuple `(x_val, y_val)` of Numpy arrays or tensors - tuple `(x_val, y_val, val_sample_weights)` of Numpy arrays - dataset + For the first two cases, `batch_size` must be provided. For the last case, `validation_steps` could be provided. shuffle: Boolean (whether to shuffle the training data diff --git a/tensorflow/python/keras/layers/convolutional.py b/tensorflow/python/keras/layers/convolutional.py index 3a48d2339f3..fefbd1951e9 100644 --- a/tensorflow/python/keras/layers/convolutional.py +++ b/tensorflow/python/keras/layers/convolutional.py @@ -54,6 +54,9 @@ class Conv(Layer): a bias vector is created and added to the outputs. Finally, if `activation` is not `None`, it is applied to the outputs as well. + Note: layer attributes cannot be modified after the layer has been called + once (except the `trainable` attribute). + Arguments: rank: An integer, the rank of the convolution, e.g. 
"2" for 2D convolution. filters: Integer, the dimensionality of the output space (i.e. the number diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py index aad66429b75..ee44fcbd946 100644 --- a/tensorflow/python/keras/layers/core.py +++ b/tensorflow/python/keras/layers/core.py @@ -818,6 +818,8 @@ class Lambda(Layer): return nest.map_structure(_add_batch, output_shapes) def call(self, inputs, mask=None, training=None): + # Disallow two variables with the same name. + self._variables_added_in_call = set() arguments = self.arguments if self._fn_expects_mask_arg: arguments['mask'] = mask @@ -828,8 +830,18 @@ class Lambda(Layer): def _variable_creator(self, next_creator, **kwargs): name = kwargs['name'] + + # Variable named "name" already created in this invocation of `call`. + if name in self._variables_added_in_call: + raise RuntimeError('`Variable`s in a `Lambda` layer must have unique ' + 'names, found duplicate name: {}'.format(name)) + self._variables_added_in_call.add(name) + + # Reuse Variables across invocations of `call`. if name in self._variable_dict: return self._variable_dict[name] + + # Variable was never created before. var = next_creator(**kwargs) self._variable_dict[name] = var if var.trainable: @@ -964,6 +976,8 @@ class Dense(Layer): Note: If the input to the layer has a rank greater than 2, then it is flattened prior to the initial dot product with `kernel`. + Besides, layer attributes cannot be modified after the layer has been called + once (except the `trainable` attribute). Example: diff --git a/tensorflow/python/keras/layers/core_test.py b/tensorflow/python/keras/layers/core_test.py index aa7b42d0e95..05f89053fcb 100644 --- a/tensorflow/python/keras/layers/core_test.py +++ b/tensorflow/python/keras/layers/core_test.py @@ -236,6 +236,17 @@ class LambdaLayerTest(keras_parameterized.TestCase): self.assertLen(layer.trainable_weights, 1) self.assertEqual(layer.trainable_weights[0].name, 'lambda/multiplier:0') + def test_lambda_with_duplicate_variable_names(self): + + def fn(x): + v1 = variables.Variable(2.) + v2 = variables.Variable(1.) + return x * v1 * v2 + + layer = keras.layers.Lambda(fn) + with self.assertRaisesRegexp(RuntimeError, 'must have unique names'): + layer(np.ones((10, 10), 'float32')) + def test_lambda_with_training_arg(self): def fn(x, training=True): diff --git a/tensorflow/python/keras/layers/dense_attention.py b/tensorflow/python/keras/layers/dense_attention.py index 054e840f48c..ff5e62d0d8b 100644 --- a/tensorflow/python/keras/layers/dense_attention.py +++ b/tensorflow/python/keras/layers/dense_attention.py @@ -243,7 +243,7 @@ class Attention(BaseDenseAttention): # Query embeddings of shape [batch_size, Tq, dimension]. query_embeddings = token_embedding(query_input) # Value embeddings of shape [batch_size, Tv, dimension]. - value_embeddings = token_embedding(query_input) + value_embeddings = token_embedding(value_input) # CNN layer. 
cnn_layer = tf.keras.layers.Conv1D( diff --git a/tensorflow/python/keras/layers/image_preprocessing_test.py b/tensorflow/python/keras/layers/image_preprocessing_test.py index d33acbf0de7..672cb181974 100644 --- a/tensorflow/python/keras/layers/image_preprocessing_test.py +++ b/tensorflow/python/keras/layers/image_preprocessing_test.py @@ -187,6 +187,11 @@ class RandomCropTest(keras_parameterized.TestCase): self._run_test(expected_height, expected_width) def test_training_with_mock(self): + if test.is_built_with_rocm(): + # TODO(rocm): + # re-enable this test once ROCm adds support for + # the StatefulUniformFullInt Op (on the GPU) + self.skipTest('Feature not supported on ROCm') np.random.seed(1337) height, width = 3, 4 height_offset = np.random.randint(low=0, high=3) @@ -207,6 +212,11 @@ class RandomCropTest(keras_parameterized.TestCase): ('random_crop_4_by_6', 4, 6), ('random_crop_3_by_2', 3, 2)) def test_random_crop_output_shape(self, expected_height, expected_width): + if test.is_built_with_rocm(): + # TODO(rocm): + # re-enable this test once ROCm adds support for + # the StatefulUniformFullInt Op (on the GPU) + self.skipTest('Feature not supported on ROCm') with CustomObjectScope({'RandomCrop': image_preprocessing.RandomCrop}): self._run_test(expected_height, expected_width) diff --git a/tensorflow/python/keras/layers/local.py b/tensorflow/python/keras/layers/local.py index d94092023aa..ec7392e754e 100644 --- a/tensorflow/python/keras/layers/local.py +++ b/tensorflow/python/keras/layers/local.py @@ -41,6 +41,9 @@ class LocallyConnected1D(Layer): that is, a different set of filters is applied at each different patch of the input. + Note: layer attributes cannot be modified after the layer has been called + once (except the `trainable` attribute). + Example: ```python # apply a unshared weight convolution 1d of length 3 to a sequence with @@ -340,6 +343,9 @@ class LocallyConnected2D(Layer): that is, a different set of filters is applied at each different patch of the input. + Note: layer attributes cannot be modified after the layer has been called + once (except the `trainable` attribute). 
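Editor's note: circling back to the `Lambda` change a few hunks above, two variables created with the same default name inside the wrapped function now raise instead of silently aliasing. A public-API reproduction of what the new `test_lambda_with_duplicate_variable_names` checks:

```python
# Reproduces the duplicate-variable-name error added in the Lambda hunk above
# (mirrors the new unit test, expressed with the public tf API).
import numpy as np
import tensorflow as tf

def fn(x):
  v1 = tf.Variable(2.)  # both variables default to the name "Variable"
  v2 = tf.Variable(1.)
  return x * v1 * v2

layer = tf.keras.layers.Lambda(fn)
try:
  layer(np.ones((10, 10), "float32"))
except RuntimeError as e:
  print(e)  # ...must have unique names, found duplicate name: Variable
```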
+ Examples: ```python # apply a 3x3 unshared weights convolution with 64 output filters on a diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py index 87a99f49164..eb8f43fd993 100644 --- a/tensorflow/python/keras/layers/recurrent.py +++ b/tensorflow/python/keras/layers/recurrent.py @@ -66,14 +66,17 @@ class StackedRNNCells(Layer): Examples: ```python - cells = [ - keras.layers.LSTMCell(output_dim), - keras.layers.LSTMCell(output_dim), - keras.layers.LSTMCell(output_dim), - ] + batch_size = 3 + sentence_max_length = 5 + n_features = 2 + new_shape = (batch_size, sentence_max_length, n_features) + x = tf.constant(np.reshape(np.arange(30), new_shape), dtype = tf.float32) - inputs = keras.Input((timesteps, input_dim)) - x = keras.layers.RNN(cells)(inputs) + rnn_cells = [tf.keras.layers.LSTMCell(128) for _ in range(2)] + stacked_lstm = tf.keras.layers.StackedRNNCells(rnn_cells) + lstm_layer = tf.keras.layers.RNN(stacked_lstm) + + result = lstm_layer(x) ``` """ diff --git a/tensorflow/python/keras/model_subclassing_compiled_test.py b/tensorflow/python/keras/model_subclassing_compiled_test.py index 180e8c8b735..bf27b3bf8a7 100644 --- a/tensorflow/python/keras/model_subclassing_compiled_test.py +++ b/tensorflow/python/keras/model_subclassing_compiled_test.py @@ -64,7 +64,7 @@ class ModelSubclassCompiledTest(keras_parameterized.TestCase): num_samples = 1000 input_dim = 50 - model = model_util.MultiIOTestModel( + model = model_util.get_multi_io_subclass_model( num_classes=num_classes, use_dp=True, use_bn=True) model.compile( loss='mse', @@ -111,7 +111,8 @@ class ModelSubclassCompiledTest(keras_parameterized.TestCase): num_samples = 100 input_dim = 50 - model = model_util.MultiIOTestModel(num_classes=num_classes, use_bn=True) + model = model_util.get_multi_io_subclass_model( + num_classes=num_classes, use_bn=True) x1 = np.ones((num_samples, input_dim)) x2 = np.ones((num_samples, input_dim)) @@ -211,7 +212,8 @@ class ModelSubclassCompiledTest(keras_parameterized.TestCase): y1 = np.zeros((num_samples, num_classes[0])) y2 = np.zeros((num_samples, num_classes[1])) - model = model_util.MultiIOTestModel(num_classes=num_classes, use_bn=True) + model = model_util.get_multi_io_subclass_model( + num_classes=num_classes, use_bn=True) model.compile( loss='mse', optimizer='rmsprop', @@ -224,7 +226,8 @@ class ModelSubclassCompiledTest(keras_parameterized.TestCase): model.fit([x1, x2], [y1, y2], epochs=2, batch_size=32, verbose=0, validation_data=([x1, x2], [y1, y2])) - model = model_util.MultiIOTestModel(num_classes=num_classes, use_bn=True) + model = model_util.get_multi_io_subclass_model( + num_classes=num_classes, use_bn=True) model.compile( loss='mse', optimizer='rmsprop', @@ -246,7 +249,8 @@ class ModelSubclassCompiledTest(keras_parameterized.TestCase): y1 = np.zeros((num_samples, num_classes[0])) y2 = np.zeros((num_samples, num_classes[1])) - model = model_util.MultiIOTestModel(num_classes=num_classes, use_bn=True) + model = model_util.get_multi_io_subclass_model( + num_classes=num_classes, use_bn=True) model.compile( loss='mse', optimizer='rmsprop', @@ -255,10 +259,12 @@ class ModelSubclassCompiledTest(keras_parameterized.TestCase): model.evaluate([x1, x2], [y1, y2]) model.test_on_batch([x1, x2], [y1, y2]) - model = model_util.MultiIOTestModel(num_classes=num_classes, use_bn=True) + model = model_util.get_multi_io_subclass_model( + num_classes=num_classes, use_bn=True) model.predict([x1, x2]) - model = model_util.MultiIOTestModel(num_classes=num_classes, 
use_bn=True) + model = model_util.get_multi_io_subclass_model( + num_classes=num_classes, use_bn=True) model.predict_on_batch([x1, x2]) def test_saving(self): @@ -271,7 +277,8 @@ class ModelSubclassCompiledTest(keras_parameterized.TestCase): y1 = np.zeros((num_samples, num_classes[0])) y2 = np.zeros((num_samples, num_classes[1])) - model = model_util.MultiIOTestModel(num_classes=num_classes, use_bn=True) + model = model_util.get_multi_io_subclass_model( + num_classes=num_classes, use_bn=True) model.compile( loss='mse', optimizer='rmsprop', @@ -286,7 +293,8 @@ class ModelSubclassCompiledTest(keras_parameterized.TestCase): hdf5_format_name = os.path.join(self.get_temp_dir(), 'weights.h5') model.save_weights(hdf5_format_name) - model = model_util.MultiIOTestModel(num_classes=num_classes, use_bn=True) + model = model_util.get_multi_io_subclass_model( + num_classes=num_classes, use_bn=True) if h5py is not None: with self.assertRaises(ValueError): diff --git a/tensorflow/python/keras/model_subclassing_test.py b/tensorflow/python/keras/model_subclassing_test.py index bbab5637a89..a4b8ac92b03 100644 --- a/tensorflow/python/keras/model_subclassing_test.py +++ b/tensorflow/python/keras/model_subclassing_test.py @@ -313,7 +313,7 @@ class ModelSubclassingTest(keras_parameterized.TestCase): batch_size = None num_samples = 1000 input_dim = 50 - model = model_util.MultiIOTestModel() + model = model_util.get_multi_io_subclass_model() self.assertFalse(model.built, 'Model should not have been built') self.assertFalse(model.weights, ('Model should have no weights since it ' 'has not been built.')) @@ -345,7 +345,7 @@ class ModelSubclassingTest(keras_parameterized.TestCase): self.assertTrue('Trainable params: 356' in print_fn.contents) # Multi-io - model = model_util.MultiIOTestModel( + model = model_util.get_multi_io_subclass_model( num_classes=(5, 6), use_bn=True, use_dp=True) model._set_inputs([np.ones((3, 4)), np.ones((3, 4))]) # need to build model first @@ -508,7 +508,7 @@ class GraphSpecificModelSubclassingTests(test.TestCase): input_dim = 50 with self.cached_session(): - model = model_util.MultiIOTestModel( + model = model_util.get_multi_io_subclass_model( num_classes=num_classes, use_dp=True, use_bn=True) model.compile(loss='mse', optimizer='rmsprop') @@ -595,7 +595,7 @@ class GraphSpecificModelSubclassingTests(test.TestCase): input_dim = 50 with self.cached_session(): - model = model_util.MultiIOTestModel( + model = model_util.get_multi_io_subclass_model( num_classes=num_classes, use_dp=True, use_bn=True) model.compile(loss='mse', optimizer='rmsprop') diff --git a/tensorflow/python/keras/model_subclassing_test_util.py b/tensorflow/python/keras/model_subclassing_test_util.py index 0f07c716b80..cf627b984a1 100644 --- a/tensorflow/python/keras/model_subclassing_test_util.py +++ b/tensorflow/python/keras/model_subclassing_test_util.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python import keras +from tensorflow.python.keras import testing_utils # pylint: disable=missing-docstring,not-callable @@ -62,31 +63,23 @@ class SimpleConvTestModel(keras.Model): return self.dense1(x) -class MultiIOTestModel(keras.Model): +def get_multi_io_subclass_model(use_bn=False, use_dp=False, num_classes=(2, 3)): + """Creates MultiIOModel for the tests of subclass model.""" + shared_layer = keras.layers.Dense(32, activation='relu') + branch_a = [shared_layer] + if use_dp: + branch_a.append(keras.layers.Dropout(0.5)) + 
branch_a.append(keras.layers.Dense(num_classes[0], activation='softmax')) - def __init__(self, use_bn=False, use_dp=False, num_classes=(2, 3)): - super(MultiIOTestModel, self).__init__(name='test_model') - self.use_bn = use_bn - self.use_dp = use_dp - self.num_classes = num_classes + branch_b = [shared_layer] + if use_bn: + branch_b.append(keras.layers.BatchNormalization()) + branch_b.append(keras.layers.Dense(num_classes[1], activation='softmax')) - self.dense1 = keras.layers.Dense(32, activation='relu') - self.dense2 = keras.layers.Dense(num_classes[0], activation='softmax') - self.dense3 = keras.layers.Dense(num_classes[1], activation='softmax') - if use_dp: - self.dp = keras.layers.Dropout(0.5) - if use_bn: - self.bn = keras.layers.BatchNormalization() - - def call(self, inputs): - x1, x2 = inputs - x1 = self.dense1(x1) - x2 = self.dense1(x2) - if self.use_dp: - x1 = self.dp(x1) - if self.use_bn: - x2 = self.bn(x2) - return [self.dense2(x1), self.dense3(x2)] + model = ( + testing_utils._MultiIOSubclassModel( # pylint: disable=protected-access + branch_a, branch_b, name='test_model')) + return model class NestedTestModel1(keras.Model): diff --git a/tensorflow/python/keras/optimizer_v2/adadelta_test.py b/tensorflow/python/keras/optimizer_v2/adadelta_test.py index 4dad9198b85..b3703fa07ea 100644 --- a/tensorflow/python/keras/optimizer_v2/adadelta_test.py +++ b/tensorflow/python/keras/optimizer_v2/adadelta_test.py @@ -35,7 +35,8 @@ from tensorflow.python.platform import test _DATA_TYPES = [dtypes.half, dtypes.float32, dtypes.float64] # TODO(b/143684500): Eigen to support complex sqrt -if not test_util.IsBuiltWithNvcc() and platform.system() != "Windows": +if (not test_util.IsBuiltWithNvcc() and platform.system() != "Windows" and + not test.is_built_with_rocm()): _DATA_TYPES += [dtypes.complex64, dtypes.complex128] diff --git a/tensorflow/python/keras/optimizer_v2/adagrad_test.py b/tensorflow/python/keras/optimizer_v2/adagrad_test.py index b0b661da8f7..9cbcd27b5d8 100644 --- a/tensorflow/python/keras/optimizer_v2/adagrad_test.py +++ b/tensorflow/python/keras/optimizer_v2/adagrad_test.py @@ -38,7 +38,8 @@ from tensorflow.python.platform import test _DATA_TYPES = [dtypes.half, dtypes.float32, dtypes.float64] # TODO(b/143684500): Eigen to support complex sqrt -if not test_util.IsBuiltWithNvcc() and platform.system() != "Windows": +if (not test_util.IsBuiltWithNvcc() and platform.system() != "Windows" and + not test.is_built_with_rocm()): _DATA_TYPES += [dtypes.complex64, dtypes.complex128] diff --git a/tensorflow/python/keras/optimizer_v2/rmsprop_test.py b/tensorflow/python/keras/optimizer_v2/rmsprop_test.py index 0482b6f00b7..a3480d62f21 100644 --- a/tensorflow/python/keras/optimizer_v2/rmsprop_test.py +++ b/tensorflow/python/keras/optimizer_v2/rmsprop_test.py @@ -41,7 +41,8 @@ from tensorflow.python.platform import test _DATA_TYPES = [dtypes.half, dtypes.float32, dtypes.float64] # TODO(b/143684500): Eigen to support complex sqrt -if not test_util.IsBuiltWithNvcc() and platform.system() != "Windows": +if (not test_util.IsBuiltWithNvcc() and platform.system() != "Windows" and + not test.is_built_with_rocm()): _DATA_TYPES += [dtypes.complex64, dtypes.complex128] _TEST_PARAM_VALUES = [ diff --git a/tensorflow/python/keras/testing_utils.py b/tensorflow/python/keras/testing_utils.py index e4c2406399f..aa4059cb50e 100644 --- a/tensorflow/python/keras/testing_utils.py +++ b/tensorflow/python/keras/testing_utils.py @@ -377,20 +377,10 @@ def saved_model_format_scope(value): 
_thread_local_data.saved_model_format = previous_value -def get_saved_model_format(): - """Gets the saved model format that should be tested.""" - if _thread_local_data.saved_model_format is None: - raise ValueError( - 'Cannot call `get_saved_model_format()` outside of a ' - '`saved_model_format_scope()` or `run_with_all_saved_model_formats` ' - 'decorator.') - return _thread_local_data.saved_model_format - - def get_save_format(): if _thread_local_data.saved_model_format is None: raise ValueError( - 'Cannot call `get_saved_model_format()` outside of a ' + 'Cannot call `get_save_format()` outside of a ' '`saved_model_format_scope()` or `run_with_all_saved_model_formats` ' 'decorator.') return _thread_local_data.saved_model_format @@ -608,8 +598,8 @@ class _MultiIOSubclassModel(keras.Model): """Multi IO Keras subclass model.""" def __init__(self, branch_a, branch_b, shared_input_branch=None, - shared_output_branch=None): - super(_MultiIOSubclassModel, self).__init__() + shared_output_branch=None, name=None): + super(_MultiIOSubclassModel, self).__init__(name=name) self._shared_input_branch = shared_input_branch self._branch_a = branch_a self._branch_b = branch_b diff --git a/tensorflow/python/keras/utils/data_utils_test.py b/tensorflow/python/keras/utils/data_utils_test.py index 0d3854890c5..e10d8064401 100644 --- a/tensorflow/python/keras/utils/data_utils_test.py +++ b/tensorflow/python/keras/utils/data_utils_test.py @@ -241,6 +241,7 @@ class TestEnqueuers(test.TestCase): # One epoch is completed so enqueuer will switch the Sequence acc = [] + self.skipTest('b/145555807 flakily timing out.') for _ in range(100): acc.append(next(gen_output2)[0, 0, 0, 0]) self.assertEqual(acc[-1], 99 * 15) diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py index d5cb1c49555..60d34a0e299 100644 --- a/tensorflow/python/kernel_tests/cond_v2_test.py +++ b/tensorflow/python/kernel_tests/cond_v2_test.py @@ -119,6 +119,29 @@ class CondV2Test(test.TestCase): output = build_cond_with_indexed_slices() self.assertAllEqual(output, [1.]) + def testReturnsNonesAndIndexedSlices(self): + + @def_function.function + def build_cond_with_indexed_slices(): + pred = constant_op.constant(True) + + def true_fn(): + return (None, None, None, + math_ops._as_indexed_slices(constant_op.constant([1.]))) + + def false_fn(): + return (None, None, None, + math_ops._as_indexed_slices(constant_op.constant([2.]))) + + result = cond_v2.cond_v2(pred, true_fn, false_fn) + self.assertIsNone(result[0]) + self.assertIsNone(result[1]) + self.assertIsNone(result[2]) + return ops.convert_to_tensor(result[3]) + + output = build_cond_with_indexed_slices() + self.assertAllEqual(output, [1.]) + def testExternalControlDependencies(self): with ops.Graph().as_default(), self.test_session(): v = variables.Variable(1.0) diff --git a/tensorflow/python/kernel_tests/reduction_ops_test.py b/tensorflow/python/kernel_tests/reduction_ops_test.py index acc66b7c3e6..1b5fa201d8f 100644 --- a/tensorflow/python/kernel_tests/reduction_ops_test.py +++ b/tensorflow/python/kernel_tests/reduction_ops_test.py @@ -438,6 +438,33 @@ class MeanReductionTest(BaseReductionTest): np_arr = self._makeIncremental((2,) * rank, dtypes.int32) self._compareAllAxes(np_arr) + @test_util.run_deprecated_v1 + def testUint8(self): + for rank in range(1, _MAX_RANK + 1): + np_arr = self._makeRandom((2,) * rank, dtypes.uint8) + self._compareAllAxes(np_arr) + + # This tests the issue reported in b/145030710. 
+ @test_util.run_deprecated_v1 + def testSizeOverflowUint8(self): + np_arr = self._makeRandom((2**8,), dtypes.uint8) + self._compareAllAxes(np_arr) + + @test_util.run_deprecated_v1 + def testSizeOverflowInt8(self): + np_arr = self._makeRandom((2**7,), dtypes.int8) + self._compareAllAxes(np_arr) + + @test_util.run_deprecated_v1 + def testSizeOverflowUint16(self): + np_arr = self._makeRandom((2**16,), dtypes.uint16) + self._compareAllAxes(np_arr) + + @test_util.run_deprecated_v1 + def testSizeOverflowInt16(self): + np_arr = self._makeRandom((2**15,), dtypes.int16) + self._compareAllAxes(np_arr) + @test_util.run_deprecated_v1 def testFloat32(self): for rank in range(1, _MAX_RANK + 1): diff --git a/tensorflow/python/kernel_tests/signal/shape_ops_test.py b/tensorflow/python/kernel_tests/signal/shape_ops_test.py index e9056accc71..6d9c77a0136 100644 --- a/tensorflow/python/kernel_tests/signal/shape_ops_test.py +++ b/tensorflow/python/kernel_tests/signal/shape_ops_test.py @@ -18,9 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import itertools - -from absl.testing import parameterized import numpy as np from tensorflow.python.eager import context @@ -28,7 +25,6 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.kernel_tests.signal import test_util from tensorflow.python.ops import array_ops @@ -38,7 +34,7 @@ from tensorflow.python.platform import test @tf_test_util.run_all_in_graph_and_eager_modes -class FrameTest(test.TestCase, parameterized.TestCase): +class FrameTest(test.TestCase): def test_mapping_of_indices_without_padding(self): tensor = constant_op.constant(np.arange(9152), dtypes.int32) @@ -356,46 +352,6 @@ class FrameTest(test.TestCase, parameterized.TestCase): rewritten_graph = test_util.grappler_optimize(g, [frames]) self.assertEqual(1, len(rewritten_graph.node)) - @parameterized.parameters( - itertools.product( - # length % step == 0 - ((32, 16), - # gcd(length, step) == 1 - (32, 15), - # gcd(length, step) == 5 - (25, 15), - # length == step - (32, 32)), - (False, True), # pad_end - (False, True), # use_mlir - (False, True))) # known_batch - def test_tflite_convert(self, length_step, pad_end, use_mlir, known_batch): - """Check for tf.lite compatibility in a variety of settings.""" - def fn(signal): - return shape_ops.frame( - signal, length_step[0], length_step[1], pad_end=pad_end) - - # TODO(b/144998258): unknown batch does not currently work with padding. - if not known_batch and pad_end: - return - - signal_length, dtype = 8001, dtypes.float32 - # If batch size is unknown, tf.lite assumes it's 1. Test batch_size > 1 - # only when batch size is known. 
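Editor's note: the size-overflow tests added to `MeanReductionTest` above guard against accumulating small integer types in their own dtype. The failure mode is easy to see in NumPy:

```python
# Why accumulating a mean in the input dtype is unsafe for small integer
# types (NumPy illustration of the overflow the tests above guard against).
import numpy as np

x = np.full((2**8,), 255, dtype=np.uint8)
wrapped = x.sum(dtype=np.uint8) / x.size  # accumulator wraps modulo 256 -> 0.0
correct = x.mean()                        # wider accumulator -> 255.0
```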
- batch_size = 2 if known_batch else 1 - static_batch_size = batch_size if known_batch else None - tflite_model = test_util.tflite_convert( - fn, [tensor_spec.TensorSpec( - shape=[static_batch_size, signal_length], dtype=dtype)], - use_mlir) - signal = np.random.normal(size=(batch_size, signal_length)).astype( - dtype.as_numpy_dtype) - actual_output, = test_util.evaluate_tflite_model( - tflite_model, [signal]) - - expected_output = self.evaluate(fn(signal)) - self.assertAllClose(actual_output, expected_output) - if __name__ == "__main__": test.main() diff --git a/tensorflow/python/lib/core/bfloat16.cc b/tensorflow/python/lib/core/bfloat16.cc index 54be76375c9..42b248a7ddb 100644 --- a/tensorflow/python/lib/core/bfloat16.cc +++ b/tensorflow/python/lib/core/bfloat16.cc @@ -532,7 +532,9 @@ struct Bfloat16GeFunctor { // Initializes the module. bool Initialize() { - // It's critical to import umath to avoid crash in open source build. + // It's critical to ImportNumpy and import umath + // to avoid crash in open source build. + ImportNumpy(); import_umath1(false); Safe_PyObjectPtr numpy_str = make_safe(MakePyString("numpy")); diff --git a/tensorflow/python/lib/core/bfloat16_test.py b/tensorflow/python/lib/core/bfloat16_test.py index bc928cd9e5e..32453ae2296 100644 --- a/tensorflow/python/lib/core/bfloat16_test.py +++ b/tensorflow/python/lib/core/bfloat16_test.py @@ -24,12 +24,12 @@ import math import numpy as np # pylint: disable=unused-import,g-bad-import-order -from tensorflow.python import pywrap_tensorflow +from tensorflow.python import _pywrap_bfloat16 from tensorflow.python.framework import dtypes from tensorflow.python.platform import test -bfloat16 = pywrap_tensorflow.TF_bfloat16_type() +bfloat16 = _pywrap_bfloat16.TF_bfloat16_type() class Bfloat16Test(test.TestCase): diff --git a/tensorflow/python/lib/core/bfloat16.i b/tensorflow/python/lib/core/bfloat16_wrapper.cc similarity index 70% rename from tensorflow/python/lib/core/bfloat16.i rename to tensorflow/python/lib/core/bfloat16_wrapper.cc index 10444b676b2..4a8e180c154 100644 --- a/tensorflow/python/lib/core/bfloat16.i +++ b/tensorflow/python/lib/core/bfloat16_wrapper.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,18 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -%{ +#include "include/pybind11/pybind11.h" #include "tensorflow/python/lib/core/bfloat16.h" -%} -%init %{ -tensorflow::RegisterNumpyBfloat16(); -%} +PYBIND11_MODULE(_pywrap_bfloat16, m) { + tensorflow::RegisterNumpyBfloat16(); -%{ -PyObject* TF_bfloat16_type() { - return tensorflow::Bfloat16PyType(); + m.def("TF_bfloat16_type", + [] { return pybind11::handle(tensorflow::Bfloat16PyType()); }); } -%} - -PyObject* TF_bfloat16_type(); diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py index c60d3cbc830..839da5fb90f 100644 --- a/tensorflow/python/ops/check_ops.py +++ b/tensorflow/python/ops/check_ops.py @@ -1658,12 +1658,12 @@ def assert_shapes(shapes, data=None, summarize=None, message=None, name=None): Example: ```python - tf.assert_shapes({ + tf.assert_shapes([ (x, ('N', 'Q')), (y, ('N', 'D')), (param, ('Q',)), (scalar, ()) - }) + ]) ``` Example of adding a dependency to an operation: diff --git a/tensorflow/python/ops/cond_v2.py b/tensorflow/python/ops/cond_v2.py index c0237d1bf9f..fd0102328d1 100644 --- a/tensorflow/python/ops/cond_v2.py +++ b/tensorflow/python/ops/cond_v2.py @@ -617,15 +617,18 @@ def _make_output_composite_tensors_match(op_type, branch_graphs): def _make_indexed_slices_indices_types_match(op_type, branch_graphs): """Match dtype of IndexedSlices.indices in outputs of branch_graphs.""" assert branch_graphs + # Indices of `IndexedSlices.indices` tensors in `branch_graphs[i].outputs`. indexed_slice_indices = [] current_index = 0 + # Note that this still contains Nones. We leave those in so that error + # messages contain the correct indices. We handle the Nones later when + # updating `current_index`. branch_outputs_flat_with_composites = [ nest.flatten(branch_graph.structured_outputs, expand_composites=False) for branch_graph in branch_graphs ] outs_per_branch = [len(outs) for outs in branch_outputs_flat_with_composites] assert len(set(outs_per_branch)) == 1, outs_per_branch - num_none_outputs = 0 # Store indices of IndexedSlices.indices in `indexed_slice_indices`. for output_idx, branch_outs in enumerate( zip(*branch_outputs_flat_with_composites)): @@ -640,17 +643,17 @@ def _make_indexed_slices_indices_types_match(op_type, branch_graphs): indexed_slice_indices.append(current_index + 1) if nest.is_sequence_or_composite(branch_outs[0]): current_index += len(nest.flatten(branch_outs[0], expand_composites=True)) - else: + elif branch_outs[0] is not None: + # `FuncGraph.outputs` does not contain Nones so no need to update the + # counter in that case. current_index += 1 - if branch_outs[0] is None: - num_none_outputs += 1 if not indexed_slice_indices: return # `FuncGraph.outputs` is the flattened `FuncGraph.structured_outputs` minus # the Nones. 
- if current_index != len(branch_graphs[0].outputs) + num_none_outputs: + if current_index != len(branch_graphs[0].outputs): raise ValueError("Insufficient elements in branch_graphs[0].outputs.\n" "Expected: %i\n" "Actual: %i" % diff --git a/tensorflow/python/ops/linalg/linalg_impl.py b/tensorflow/python/ops/linalg/linalg_impl.py index 9cdcaee6ac2..18d22968c94 100644 --- a/tensorflow/python/ops/linalg/linalg_impl.py +++ b/tensorflow/python/ops/linalg/linalg_impl.py @@ -298,11 +298,9 @@ def matrix_exponential(input, name=None): # pylint: disable=redefined-builtin math_ops.log(l1_norm / maxnorm) / math_ops.log(const(2.0))), 0) u3, v3 = _matrix_exp_pade3(matrix) u5, v5 = _matrix_exp_pade5(matrix) - u7, v7 = _matrix_exp_pade7(matrix / math_ops.pow( - constant_op.constant(2.0, dtype=matrix.dtype), - math_ops.cast( - squarings, - matrix.dtype))[..., array_ops.newaxis, array_ops.newaxis]) + u7, v7 = _matrix_exp_pade7(matrix / math_ops.cast( + math_ops.pow(const(2.0), squarings), + matrix.dtype)[..., array_ops.newaxis, array_ops.newaxis]) conds = (4.258730016922831e-001, 1.880152677804762e+000) u = _nest_where(conds, (u3, u5, u7)) v = _nest_where(conds, (v3, v5, v7)) @@ -315,11 +313,9 @@ def matrix_exponential(input, name=None): # pylint: disable=redefined-builtin u5, v5 = _matrix_exp_pade5(matrix) u7, v7 = _matrix_exp_pade7(matrix) u9, v9 = _matrix_exp_pade9(matrix) - u13, v13 = _matrix_exp_pade13(matrix / math_ops.pow( - constant_op.constant(2.0, dtype=matrix.dtype), - math_ops.cast( - squarings, - matrix.dtype))[..., array_ops.newaxis, array_ops.newaxis]) + u13, v13 = _matrix_exp_pade13(matrix / math_ops.cast( + math_ops.pow(const(2.0), squarings), + matrix.dtype)[..., array_ops.newaxis, array_ops.newaxis]) conds = (1.495585217958292e-002, 2.539398330063230e-001, 9.504178996162932e-001, 2.097847961257068e+000) u = _nest_where(conds, (u3, u5, u7, u9, u13)) diff --git a/tensorflow/python/ops/linalg/linear_operator_test_util.py b/tensorflow/python/ops/linalg/linear_operator_test_util.py index ae7b11778d4..33b8003ae2e 100644 --- a/tensorflow/python/ops/linalg/linear_operator_test_util.py +++ b/tensorflow/python/ops/linalg/linear_operator_test_util.py @@ -452,6 +452,11 @@ def _test_cond(use_placeholder, shapes_info, dtype): if 0 in shapes_info.shape[-2:]: return + # ROCm platform does not yet support complex types + if test.is_built_with_rocm() and \ + ((dtype == dtypes.complex64) or (dtype == dtypes.complex128)): + return + sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED # Ensure self-adjoint and PD so we get finite condition numbers. 
operator, mat = self.operator_and_matrix( diff --git a/tensorflow/python/ops/linalg/sparse/conjugate_gradient.py b/tensorflow/python/ops/linalg/sparse/conjugate_gradient.py index dd8798141db..613309f856d 100644 --- a/tensorflow/python/ops/linalg/sparse/conjugate_gradient.py +++ b/tensorflow/python/ops/linalg/sparse/conjugate_gradient.py @@ -30,7 +30,7 @@ from tensorflow.python.ops.linalg import linalg_impl as linalg from tensorflow.python.util.tf_export import tf_export -@tf_export('linalg.experimental.sparse.conjugate_gradient') +@tf_export('linalg.experimental.conjugate_gradient') def conjugate_gradient(operator, rhs, preconditioner=None, diff --git a/tensorflow/python/ops/linalg/sparse/sparse.py b/tensorflow/python/ops/linalg/sparse/sparse.py index ef7abdc6b81..6f9b2522335 100644 --- a/tensorflow/python/ops/linalg/sparse/sparse.py +++ b/tensorflow/python/ops/linalg/sparse/sparse.py @@ -20,7 +20,11 @@ from __future__ import print_function # go/tf-wildcard-import # pylint: disable=wildcard-import -from tensorflow.python.ops.linalg.sparse.conjugate_gradient import * +from tensorflow.python.ops.linalg.sparse.conjugate_gradient import conjugate_gradient from tensorflow.python.ops.linalg.sparse.sparse_csr_matrix_grad import * from tensorflow.python.ops.linalg.sparse.sparse_csr_matrix_ops import * # pylint: enable=wildcard-import + +__all__ = [ + 'conjugate_gradient' +] diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py index f7e01f57200..f9a75f6aecc 100644 --- a/tensorflow/python/ops/math_grad.py +++ b/tensorflow/python/ops/math_grad.py @@ -204,20 +204,14 @@ def _SumGrad(op, grad): input_shape = array_ops.shape(op.inputs[0]) - if compat.forward_compatible(2019, 10, 23): - if not op.get_attr("keep_dims"): - with ops.colocate_with(input_shape): - # TODO(apassos) remove this once device placement for eager ops makes - # more sense. - output_shape_kept_dims = math_ops.reduced_shape(input_shape, - op.inputs[1]) - grad = array_ops.reshape(grad, output_shape_kept_dims) - return [array_ops.broadcast_to(grad, input_shape), None] - with ops.colocate_with(input_shape): - output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1]) - tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims) - grad = array_ops.reshape(grad, output_shape_kept_dims) - return [array_ops.tile(grad, tile_scaling), None] + if not op.get_attr("keep_dims"): + with ops.colocate_with(input_shape): + # TODO(apassos) remove this once device placement for eager ops makes + # more sense. 
+ output_shape_kept_dims = math_ops.reduced_shape(input_shape, + op.inputs[1]) + grad = array_ops.reshape(grad, output_shape_kept_dims) + return [array_ops.broadcast_to(grad, input_shape), None] def _MinOrMaxGrad(op, grad): diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 6e6fd50a419..078219e2f23 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -75,7 +75,6 @@ import six from six.moves import builtins from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.python.compat import compat as fwd_compat from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -1364,10 +1363,7 @@ def tensor_equals(self, other): g = getattr(self, "graph", None) if (ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions() and (g is None or g._building_function)): # pylint: disable=protected-access - if fwd_compat.forward_compatible(2019, 9, 25): - return gen_math_ops.equal(self, other, incompatible_shape_error=False) - else: - return gen_math_ops.equal(self, other) + return gen_math_ops.equal(self, other, incompatible_shape_error=False) else: # In legacy graph mode, tensor equality is object equality return self is other @@ -1378,10 +1374,7 @@ def tensor_not_equals(self, other): if other is None: return True if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions(): - if fwd_compat.forward_compatible(2019, 9, 25): - return gen_math_ops.not_equal(self, other, incompatible_shape_error=False) - else: - return gen_math_ops.not_equal(self, other) + return gen_math_ops.not_equal(self, other, incompatible_shape_error=False) else: # In legacy graph mode, tensor equality is object equality return self is not other diff --git a/tensorflow/python/ops/random_ops.py b/tensorflow/python/ops/random_ops.py index 94ada0515cf..f9208cca551 100644 --- a/tensorflow/python/ops/random_ops.py +++ b/tensorflow/python/ops/random_ops.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function import numpy as np -import six from tensorflow.python.compat import compat from tensorflow.python.eager import context @@ -266,11 +265,10 @@ def random_uniform(shape, shape = tensor_util.shape_tensor(shape) # TODO(b/143079601): Remove this once the compatible window is passed. if compat.forward_compatible(2019, 12, 3): - # In case of [0,1) floating results, minval and maxval is unused. - minval_is_zero = isinstance(minval, six.integer_types + - (float,)) and minval == 0 - maxval_is_one = isinstance(maxval, six.integer_types + - (float,)) and maxval == 1 + # In case of [0,1) floating results, minval and maxval is unused. We do an + # `is` comparison here since this is cheaper than isinstance or __eq__. + minval_is_zero = minval is 0 # pylint: disable=literal-comparison + maxval_is_one = maxval is 1 # pylint: disable=literal-comparison if not minval_is_zero or not maxval_is_one or dtype.is_integer: minval = ops.convert_to_tensor(minval, dtype=dtype, name="min") maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max") diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i index 413b5126e77..761e6f376f8 100644 --- a/tensorflow/python/tensorflow.i +++ b/tensorflow/python/tensorflow.i @@ -21,8 +21,6 @@ limitations under the License. 
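Editor's note: the `_SumGrad` hunk above keeps only the `broadcast_to` path now that the forward-compatibility window has passed; the older tile-based expansion produced the same values, as a small NumPy check shows:

```python
# tile-after-reshape vs. broadcast_to: both expand a kept-dims gradient back
# to the input shape (NumPy check of the equivalence relied on above).
import numpy as np

grad = np.array([[3.], [7.]])               # gradient with reduced dim kept as 1
tiled = np.tile(grad, (1, 4))               # old path: materialized tiling
broadcast = np.broadcast_to(grad, (2, 4))   # new path: broadcasting
assert np.array_equal(tiled, broadcast)
```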
%include "tensorflow/python/client/tf_session.i" -%include "tensorflow/python/lib/core/bfloat16.i" - %include "tensorflow/python/lib/io/file_io.i" %include "tensorflow/python/lib/io/py_record_reader.i" diff --git a/tensorflow/python/tools/api/generator/api_init_files.bzl b/tensorflow/python/tools/api/generator/api_init_files.bzl index b2981b14209..45c1a959256 100644 --- a/tensorflow/python/tools/api/generator/api_init_files.bzl +++ b/tensorflow/python/tools/api/generator/api_init_files.bzl @@ -32,6 +32,7 @@ TENSORFLOW_API_INIT_FILES = [ "io/__init__.py", "queue/__init__.py", "linalg/__init__.py", + "linalg/experimental/__init__.py", "lite/__init__.py", "lite/experimental/__init__.py", "lite/experimental/microfrontend/__init__.py", diff --git a/tensorflow/python/tools/api/generator/api_init_files_v1.bzl b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl index 31e0c6ca457..a67afdcad29 100644 --- a/tensorflow/python/tools/api/generator/api_init_files_v1.bzl +++ b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl @@ -36,6 +36,7 @@ TENSORFLOW_API_INIT_FILES_V1 = [ "layers/__init__.py", "layers/experimental/__init__.py", "linalg/__init__.py", + "linalg/experimental/__init__.py", "lite/__init__.py", "lite/constants/__init__.py", "lite/experimental/__init__.py", diff --git a/tensorflow/python/tools/api/generator/create_python_api.py b/tensorflow/python/tools/api/generator/create_python_api.py index 3af677322d6..42119aeae34 100644 --- a/tensorflow/python/tools/api/generator/create_python_api.py +++ b/tensorflow/python/tools/api/generator/create_python_api.py @@ -243,15 +243,16 @@ class _ModuleInitCodeBuilder(object): # from it using * import. Don't need this for lazy_loading because the # underscore symbols are already included in __all__ when passed in and # handled by TFModuleWrapper. + root_module_footer = '' if not self._lazy_loading: underscore_names_str = ', '.join( '\'%s\'' % name for name in self._underscore_names_in_root) - module_text_map[''] = module_text_map.get('', '') + ''' + root_module_footer = """ _names_with_underscore = [%s] __all__ = [_s for _s in dir() if not _s.startswith('_')] __all__.extend([_s for _s in _names_with_underscore]) -''' % underscore_names_str +""" % underscore_names_str # Add module wrapper if we need to print deprecation messages # or if we use lazy loading. @@ -273,7 +274,7 @@ __all__.extend([_s for _s in _names_with_underscore]) footer_text_map[dest_module] = _DEPRECATION_FOOTER % ( dest_module, public_apis_name, deprecation, has_lite) - return module_text_map, footer_text_map + return module_text_map, footer_text_map, root_module_footer def format_import(self, source_module_name, source_name, dest_name): """Formats import statement. @@ -620,9 +621,12 @@ def create_api_files(output_files, packages, root_init_template, output_dir, os.makedirs(os.path.dirname(file_path)) open(file_path, 'a').close() - module_text_map, deprecation_footer_map = get_api_init_text( - packages, output_package, api_name, - api_version, compat_api_versions, lazy_loading, use_relative_imports) + ( + module_text_map, + deprecation_footer_map, + root_module_footer, + ) = get_api_init_text(packages, output_package, api_name, api_version, + compat_api_versions, lazy_loading, use_relative_imports) # Add imports to output files. 
missing_output_files = [] @@ -652,6 +656,7 @@ def create_api_files(output_files, packages, root_init_template, output_dir, with open(root_init_template, 'r') as root_init_template_file: contents = root_init_template_file.read() contents = contents.replace('# API IMPORTS PLACEHOLDER', text) + contents = contents.replace('# __all__ PLACEHOLDER', root_module_footer) elif module in compat_module_to_template: # Read base init file for compat module with open(compat_module_to_template[module], 'r') as init_template_file: diff --git a/tensorflow/python/tools/api/generator/create_python_api_test.py b/tensorflow/python/tools/api/generator/create_python_api_test.py index 010f189dcb2..76404d6c82b 100644 --- a/tensorflow/python/tools/api/generator/create_python_api_test.py +++ b/tensorflow/python/tools/api/generator/create_python_api_test.py @@ -62,7 +62,7 @@ class CreatePythonApiTest(test.TestCase): del sys.modules[_MODULE_NAME] def testFunctionImportIsAdded(self): - imports, _ = create_python_api.get_api_init_text( + imports, _, _ = create_python_api.get_api_init_text( packages=[create_python_api._DEFAULT_PACKAGE], output_package='tensorflow', api_name='tensorflow', @@ -97,7 +97,7 @@ class CreatePythonApiTest(test.TestCase): msg='compat.v1 in %s' % str(imports.keys())) def testClassImportIsAdded(self): - imports, _ = create_python_api.get_api_init_text( + imports, _, _ = create_python_api.get_api_init_text( packages=[create_python_api._DEFAULT_PACKAGE], output_package='tensorflow', api_name='tensorflow', @@ -116,7 +116,7 @@ class CreatePythonApiTest(test.TestCase): msg='%s not in %s' % (expected_import, str(imports))) def testConstantIsAdded(self): - imports, _ = create_python_api.get_api_init_text( + imports, _, _ = create_python_api.get_api_init_text( packages=[create_python_api._DEFAULT_PACKAGE], output_package='tensorflow', api_name='tensorflow', @@ -132,7 +132,7 @@ class CreatePythonApiTest(test.TestCase): msg='%s not in %s' % (expected, str(imports))) def testCompatModuleIsAdded(self): - imports, _ = create_python_api.get_api_init_text( + imports, _, _ = create_python_api.get_api_init_text( packages=[create_python_api._DEFAULT_PACKAGE], output_package='tensorflow', api_name='tensorflow', @@ -144,7 +144,7 @@ class CreatePythonApiTest(test.TestCase): msg='compat.v1.test not in %s' % str(imports.keys())) def testNestedCompatModulesAreAdded(self): - imports, _ = create_python_api.get_api_init_text( + imports, _, _ = create_python_api.get_api_init_text( packages=[create_python_api._DEFAULT_PACKAGE], output_package='tensorflow', api_name='tensorflow', diff --git a/tensorflow/python/tpu/tpu_embedding.py b/tensorflow/python/tpu/tpu_embedding.py index 76483693dfe..d6e77815041 100644 --- a/tensorflow/python/tpu/tpu_embedding.py +++ b/tensorflow/python/tpu/tpu_embedding.py @@ -1044,7 +1044,7 @@ class TPUEmbedding(object): sample_indices = ( enqueue_data.sample_indices if enqueue_data.sample_indices is not None else array_ops.zeros( - (0,), dtype=dtypes.int32)) + (0,), dtype=dtypes.int64)) sample_indices_list.append(sample_indices) aggregation_weights = ( diff --git a/tensorflow/python/training/moving_averages_test.py b/tensorflow/python/training/moving_averages_test.py index 3a52d7653f4..1aa8947fb1f 100644 --- a/tensorflow/python/training/moving_averages_test.py +++ b/tensorflow/python/training/moving_averages_test.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python import pywrap_tensorflow from 
tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -131,7 +130,6 @@ class MovingAveragesTest(test.TestCase): @test_util.deprecated_graph_mode_only def testWeightedMovingAverageBfloat16(self): - bfloat16 = pywrap_tensorflow.TF_bfloat16_type() with self.cached_session() as sess: decay = 0.5 weight = array_ops.placeholder(dtypes.bfloat16, []) @@ -154,7 +152,8 @@ class MovingAveragesTest(test.TestCase): wma_array = sess.run(wma, feed_dict={val: val_2, weight: weight_2}) numerator_2 = numerator_1 * decay + val_2 * weight_2 * (1.0 - decay) denominator_2 = denominator_1 * decay + weight_2 * (1.0 - decay) - self.assertAllClose(bfloat16(numerator_2 / denominator_2), wma_array) + self.assertAllClose( + dtypes._np_bfloat16(numerator_2 / denominator_2), wma_array) def _Repeat(value, dim): diff --git a/tensorflow/stream_executor/cuda/cudnn_stub.cc b/tensorflow/stream_executor/cuda/cudnn_stub.cc index 5a05437480e..f683cecdb52 100644 --- a/tensorflow/stream_executor/cuda/cudnn_stub.cc +++ b/tensorflow/stream_executor/cuda/cudnn_stub.cc @@ -53,7 +53,8 @@ cudnnStatus_t GetSymbolNotFoundError() { return CUDNN_STATUS_INTERNAL_ERROR; } #include "tensorflow/stream_executor/cuda/cudnn_6_0.inc" #elif CUDNN_MINOR < 1 #include "tensorflow/stream_executor/cuda/cudnn_7_0.inc" -#elif CUDNN_MINOR < 3 +// 2 instead of 3: see https://github.com/tensorflow/tensorflow/issues/32350 +#elif CUDNN_MINOR < 2 #include "tensorflow/stream_executor/cuda/cudnn_7_1.inc" #elif CUDNN_MINOR < 4 #include "tensorflow/stream_executor/cuda/cudnn_7_3.inc" diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 11598d9885e..5f9f2296c3c 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -2576,3 +2576,7 @@ def if_mlir(if_true, if_false = []): def tfcompile_extra_flags(): return "" + +def tf_external_workspace_visible(visibility): + # External workspaces can see this target. 
+ return ["//visibility:public"] diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.experimental.pbtxt new file mode 100644 index 00000000000..9c9a67ba712 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.linalg.experimental.pbtxt @@ -0,0 +1,7 @@ +path: "tensorflow.linalg.experimental" +tf_module { + member_method { + name: "conjugate_gradient" + argspec: "args=[\'operator\', \'rhs\', \'preconditioner\', \'x\', \'tol\', \'max_iter\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1e-05\', \'20\', \'conjugate_gradient\'], " + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt index f645db2f310..632400c6570 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt @@ -76,6 +76,10 @@ tf_module { name: "LinearOperatorZeros" mtype: "" } + member { + name: "experimental" + mtype: "" + } member_method { name: "adjoint" argspec: "args=[\'matrix\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index f09b683f0f9..604f676bf34 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -604,6 +604,18 @@ tf_module { name: "BytesProducedStatsDataset" argspec: "args=[\'input_dataset\', \'tag\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "CSRSparseMatrixComponents" + argspec: "args=[\'csr_sparse_matrix\', \'index\', \'type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "CSRSparseMatrixToDense" + argspec: "args=[\'sparse_input\', \'type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "CSRSparseMatrixToSparseTensor" + argspec: "args=[\'sparse_matrix\', \'type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "CSVDataset" argspec: "args=[\'filenames\', \'compression_type\', \'buffer_size\', \'header\', \'field_delim\', \'use_quote_delim\', \'na_value\', \'select_cols\', \'record_defaults\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -1032,6 +1044,10 @@ tf_module { name: "DeleteSessionTensor" argspec: "args=[\'handle\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "DenseToCSRSparseMatrix" + argspec: "args=[\'dense_input\', \'indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "DenseToDenseSetOperation" argspec: "args=[\'set1\', \'set2\', \'set_operation\', \'validate_indices\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " @@ -3952,6 +3968,50 @@ tf_module { name: "SparseMatMul" argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'a_is_sparse\', \'b_is_sparse\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'None\'], " } + member_method { + name: "SparseMatrixAdd" + argspec: "args=[\'a\', \'b\', \'alpha\', \'beta\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "SparseMatrixMatMul" + argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', 
\'adjoint_a\', \'adjoint_b\', \'transpose_output\', \'conjugate_output\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\', \'None\'], " + } + member_method { + name: "SparseMatrixMul" + argspec: "args=[\'a\', \'b\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "SparseMatrixNNZ" + argspec: "args=[\'sparse_matrix\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "SparseMatrixOrderingAMD" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "SparseMatrixSoftmax" + argspec: "args=[\'logits\', \'type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "SparseMatrixSoftmaxGrad" + argspec: "args=[\'softmax\', \'grad_softmax\', \'type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "SparseMatrixSparseCholesky" + argspec: "args=[\'input\', \'permutation\', \'type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "SparseMatrixSparseMatMul" + argspec: "args=[\'a\', \'b\', \'type\', \'transpose_a\', \'transpose_b\', \'adjoint_a\', \'adjoint_b\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'None\'], " + } + member_method { + name: "SparseMatrixTranspose" + argspec: "args=[\'input\', \'type\', \'conjugate\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "SparseMatrixZeros" + argspec: "args=[\'dense_shape\', \'type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "SparseReduceMax" argspec: "args=[\'input_indices\', \'input_values\', \'input_shape\', \'reduction_axes\', \'keep_dims\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " @@ -4048,6 +4108,10 @@ tf_module { name: "SparseTensorSliceDataset" argspec: "args=[\'indices\', \'values\', \'dense_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "SparseTensorToCSRSparseMatrix" + argspec: "args=[\'indices\', \'values\', \'dense_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "SparseToDense" argspec: "args=[\'sparse_indices\', \'output_shape\', \'sparse_values\', \'default_value\', \'validate_indices\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.experimental.pbtxt new file mode 100644 index 00000000000..9c9a67ba712 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.experimental.pbtxt @@ -0,0 +1,7 @@ +path: "tensorflow.linalg.experimental" +tf_module { + member_method { + name: "conjugate_gradient" + argspec: "args=[\'operator\', \'rhs\', \'preconditioner\', \'x\', \'tol\', \'max_iter\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1e-05\', \'20\', \'conjugate_gradient\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt index a58c988577a..041041f60ed 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt @@ -76,6 +76,10 @@ tf_module { name: "LinearOperatorZeros" 
mtype: "" } + member { + name: "experimental" + mtype: "" + } member_method { name: "adjoint" argspec: "args=[\'matrix\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index f09b683f0f9..604f676bf34 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -604,6 +604,18 @@ tf_module { name: "BytesProducedStatsDataset" argspec: "args=[\'input_dataset\', \'tag\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "CSRSparseMatrixComponents" + argspec: "args=[\'csr_sparse_matrix\', \'index\', \'type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "CSRSparseMatrixToDense" + argspec: "args=[\'sparse_input\', \'type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "CSRSparseMatrixToSparseTensor" + argspec: "args=[\'sparse_matrix\', \'type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "CSVDataset" argspec: "args=[\'filenames\', \'compression_type\', \'buffer_size\', \'header\', \'field_delim\', \'use_quote_delim\', \'na_value\', \'select_cols\', \'record_defaults\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -1032,6 +1044,10 @@ tf_module { name: "DeleteSessionTensor" argspec: "args=[\'handle\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "DenseToCSRSparseMatrix" + argspec: "args=[\'dense_input\', \'indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "DenseToDenseSetOperation" argspec: "args=[\'set1\', \'set2\', \'set_operation\', \'validate_indices\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " @@ -3952,6 +3968,50 @@ tf_module { name: "SparseMatMul" argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'a_is_sparse\', \'b_is_sparse\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'None\'], " } + member_method { + name: "SparseMatrixAdd" + argspec: "args=[\'a\', \'b\', \'alpha\', \'beta\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "SparseMatrixMatMul" + argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'adjoint_a\', \'adjoint_b\', \'transpose_output\', \'conjugate_output\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\', \'None\'], " + } + member_method { + name: "SparseMatrixMul" + argspec: "args=[\'a\', \'b\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "SparseMatrixNNZ" + argspec: "args=[\'sparse_matrix\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "SparseMatrixOrderingAMD" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "SparseMatrixSoftmax" + argspec: "args=[\'logits\', \'type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "SparseMatrixSoftmaxGrad" + argspec: "args=[\'softmax\', \'grad_softmax\', \'type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + 
name: "SparseMatrixSparseCholesky" + argspec: "args=[\'input\', \'permutation\', \'type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "SparseMatrixSparseMatMul" + argspec: "args=[\'a\', \'b\', \'type\', \'transpose_a\', \'transpose_b\', \'adjoint_a\', \'adjoint_b\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'None\'], " + } + member_method { + name: "SparseMatrixTranspose" + argspec: "args=[\'input\', \'type\', \'conjugate\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "SparseMatrixZeros" + argspec: "args=[\'dense_shape\', \'type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "SparseReduceMax" argspec: "args=[\'input_indices\', \'input_values\', \'input_shape\', \'reduction_axes\', \'keep_dims\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " @@ -4048,6 +4108,10 @@ tf_module { name: "SparseTensorSliceDataset" argspec: "args=[\'indices\', \'values\', \'dense_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "SparseTensorToCSRSparseMatrix" + argspec: "args=[\'indices\', \'values\', \'dense_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "SparseToDense" argspec: "args=[\'sparse_indices\', \'output_shape\', \'sparse_values\', \'default_value\', \'validate_indices\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " diff --git a/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/devel-horovod-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/devel-horovod-jupyter.Dockerfile new file mode 100644 index 00000000000..92118f0ade8 --- /dev/null +++ b/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/devel-horovod-jupyter.Dockerfile @@ -0,0 +1,181 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +# +# THIS IS A GENERATED DOCKERFILE. +# +# This file was assembled from multiple pieces, whose use is documented +# throughout. Please refer to the TensorFlow dockerfiles documentation +# for more information. 
+ +ARG UBUNTU_VERSION=18.04 + +FROM ubuntu:${UBUNTU_VERSION} AS base + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + curl \ + git \ + libcurl3-dev \ + libfreetype6-dev \ + libhdf5-serial-dev \ + libzmq3-dev \ + pkg-config \ + rsync \ + software-properties-common \ + sudo \ + unzip \ + zip \ + zlib1g-dev \ + openjdk-8-jdk \ + openjdk-8-jre-headless \ + && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +ENV CI_BUILD_PYTHON python + +# CACHE_STOP is used to rerun future commands, otherwise cloning tensorflow will be cached and will not pull the most recent version +ARG CACHE_STOP=1 +# Check out TensorFlow source code if --build-arg CHECKOUT_TF_SRC=1 +ARG CHECKOUT_TF_SRC=0 +RUN test "${CHECKOUT_TF_SRC}" -eq 1 && git clone https://github.com/tensorflow/tensorflow.git /tensorflow_src || true + +ARG USE_PYTHON_3_NOT_2 +ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} +ARG PYTHON=python${_PY_SUFFIX} +ARG PIP=pip${_PY_SUFFIX} + +# See http://bugs.python.org/issue19846 +ENV LANG C.UTF-8 + +RUN apt-get update && apt-get install -y \ + ${PYTHON} \ + ${PYTHON}-pip + +RUN ${PIP} --no-cache-dir install --upgrade \ + pip \ + setuptools + +# Some TF tools expect a "python" binary +RUN ln -s $(which ${PYTHON}) /usr/local/bin/python + +RUN apt-get update && apt-get install -y \ + build-essential \ + curl \ + git \ + wget \ + openjdk-8-jdk \ + ${PYTHON}-dev \ + virtualenv \ + swig + +RUN ${PIP} --no-cache-dir install \ + Pillow \ + h5py \ + keras_applications \ + keras_preprocessing \ + matplotlib \ + mock \ + numpy \ + scipy \ + sklearn \ + pandas \ + future \ + portpicker \ + && test "${USE_PYTHON_3_NOT_2}" -eq 1 && true || ${PIP} --no-cache-dir install \ + enum34 + +# Install bazel +ARG BAZEL_VERSION=0.24.1 +RUN mkdir /bazel && \ + wget -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \ + wget -O /bazel/LICENSE.txt "https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE" && \ + chmod +x /bazel/installer.sh && \ + /bazel/installer.sh && \ + rm -f /bazel/installer.sh + +# install libnuma, openssh, wget +RUN ( apt-get update && apt-get install -y --no-install-recommends --fix-missing \ + libnuma-dev \ + openssh-server \ + openssh-client \ + wget && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* ) || \ + ( yum -y update && yum -y install \ + numactl-devel \ + openssh-server \ + openssh-clients \ + wget && \ + yum clean all ) || \ + ( echo "Unsupported Linux distribution. Aborting!" 
&& exit 1 ) + +# Install Open MPI +# download realese version from official website as openmpi github master is not always stable +ARG OPENMPI_VERSION=openmpi-4.0.0 +ARG OPENMPI_DOWNLOAD_URL=https://www.open-mpi.org/software/ompi/v4.0/downloads/openmpi-4.0.0.tar.gz +RUN mkdir /tmp/openmpi && \ + cd /tmp/openmpi && \ + wget ${OPENMPI_DOWNLOAD_URL} && \ + tar zxf ${OPENMPI_VERSION}.tar.gz && \ + cd ${OPENMPI_VERSION} && \ + ./configure --enable-orterun-prefix-by-default && \ + make -j $(nproc) all && \ + make install && \ + ldconfig && \ + rm -rf /tmp/openmpi + +# Create a wrapper for OpenMPI to allow running as root by default +RUN mv /usr/local/bin/mpirun /usr/local/bin/mpirun.real && \ + echo '#!/bin/bash' > /usr/local/bin/mpirun && \ + echo 'mpirun.real --allow-run-as-root "$@"' >> /usr/local/bin/mpirun && \ + chmod a+x /usr/local/bin/mpirun + +# Configure OpenMPI to run good defaults: +RUN echo "btl_tcp_if_exclude = lo,docker0" >> /usr/local/etc/openmpi-mca-params.conf + +# Install OpenSSH for MPI to communicate between containers +RUN mkdir -p /var/run/sshd + +# Allow OpenSSH to talk to containers without asking for confirmation +RUN cat /etc/ssh/ssh_config | grep -v StrictHostKeyChecking > /etc/ssh/ssh_config.new && \ + echo " StrictHostKeyChecking no" >> /etc/ssh/ssh_config.new && \ + mv /etc/ssh/ssh_config.new /etc/ssh/ssh_config + +# Check out horovod source code if --build-arg CHECKOUT_HOROVOD_SRC=1 +ARG CHECKOUT_HOROVOD_SRC=0 +RUN test "${CHECKOUT_HOROVOD_SRC}" -eq 1 && git clone --recursive https://github.com/uber/horovod.git /horovod_src || true + +COPY bashrc /etc/bash.bashrc +RUN chmod a+rwx /etc/bash.bashrc + +RUN ${PIP} install jupyter matplotlib +RUN ${PIP} install jupyter_http_over_ws +RUN jupyter serverextension enable --py jupyter_http_over_ws + +RUN mkdir -p /tf/tensorflow-tutorials && chmod -R a+rwx /tf/ +RUN mkdir /.local && chmod a+rwx /.local +RUN apt-get install -y --no-install-recommends wget +WORKDIR /tf/tensorflow-tutorials +RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/basic_classification.ipynb +RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/basic_text_classification.ipynb +COPY readme-for-jupyter.md README.md +RUN apt-get autoremove -y && apt-get remove -y wget +WORKDIR /tf +EXPOSE 8888 + +RUN ${PYTHON} -m ipykernel.kernelspec + +CMD ["bash", "-c", "source /etc/bash.bashrc && jupyter notebook --notebook-dir=/tf --ip 0.0.0.0 --no-browser --allow-root"] diff --git a/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/devel-horovod.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/devel-horovod.Dockerfile new file mode 100644 index 00000000000..338474678d2 --- /dev/null +++ b/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/devel-horovod.Dockerfile @@ -0,0 +1,162 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +# +# THIS IS A GENERATED DOCKERFILE. +# +# This file was assembled from multiple pieces, whose use is documented +# throughout. Please refer to the TensorFlow dockerfiles documentation +# for more information. + +ARG UBUNTU_VERSION=18.04 + +FROM ubuntu:${UBUNTU_VERSION} AS base + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + curl \ + git \ + libcurl3-dev \ + libfreetype6-dev \ + libhdf5-serial-dev \ + libzmq3-dev \ + pkg-config \ + rsync \ + software-properties-common \ + sudo \ + unzip \ + zip \ + zlib1g-dev \ + openjdk-8-jdk \ + openjdk-8-jre-headless \ + && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +ENV CI_BUILD_PYTHON python + +# CACHE_STOP is used to rerun future commands, otherwise cloning tensorflow will be cached and will not pull the most recent version +ARG CACHE_STOP=1 +# Check out TensorFlow source code if --build-arg CHECKOUT_TF_SRC=1 +ARG CHECKOUT_TF_SRC=0 +RUN test "${CHECKOUT_TF_SRC}" -eq 1 && git clone https://github.com/tensorflow/tensorflow.git /tensorflow_src || true + +ARG USE_PYTHON_3_NOT_2 +ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} +ARG PYTHON=python${_PY_SUFFIX} +ARG PIP=pip${_PY_SUFFIX} + +# See http://bugs.python.org/issue19846 +ENV LANG C.UTF-8 + +RUN apt-get update && apt-get install -y \ + ${PYTHON} \ + ${PYTHON}-pip + +RUN ${PIP} --no-cache-dir install --upgrade \ + pip \ + setuptools + +# Some TF tools expect a "python" binary +RUN ln -s $(which ${PYTHON}) /usr/local/bin/python + +RUN apt-get update && apt-get install -y \ + build-essential \ + curl \ + git \ + wget \ + openjdk-8-jdk \ + ${PYTHON}-dev \ + virtualenv \ + swig + +RUN ${PIP} --no-cache-dir install \ + Pillow \ + h5py \ + keras_applications \ + keras_preprocessing \ + matplotlib \ + mock \ + numpy \ + scipy \ + sklearn \ + pandas \ + future \ + portpicker \ + && test "${USE_PYTHON_3_NOT_2}" -eq 1 && true || ${PIP} --no-cache-dir install \ + enum34 + +# Install bazel +ARG BAZEL_VERSION=0.24.1 +RUN mkdir /bazel && \ + wget -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \ + wget -O /bazel/LICENSE.txt "https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE" && \ + chmod +x /bazel/installer.sh && \ + /bazel/installer.sh && \ + rm -f /bazel/installer.sh + +# install libnuma, openssh, wget +RUN ( apt-get update && apt-get install -y --no-install-recommends --fix-missing \ + libnuma-dev \ + openssh-server \ + openssh-client \ + wget && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* ) || \ + ( yum -y update && yum -y install \ + numactl-devel \ + openssh-server \ + openssh-clients \ + wget && \ + yum clean all ) || \ + ( echo "Unsupported Linux distribution. Aborting!" 
&& exit 1 ) + +# Install Open MPI +# download realese version from official website as openmpi github master is not always stable +ARG OPENMPI_VERSION=openmpi-4.0.0 +ARG OPENMPI_DOWNLOAD_URL=https://www.open-mpi.org/software/ompi/v4.0/downloads/openmpi-4.0.0.tar.gz +RUN mkdir /tmp/openmpi && \ + cd /tmp/openmpi && \ + wget ${OPENMPI_DOWNLOAD_URL} && \ + tar zxf ${OPENMPI_VERSION}.tar.gz && \ + cd ${OPENMPI_VERSION} && \ + ./configure --enable-orterun-prefix-by-default && \ + make -j $(nproc) all && \ + make install && \ + ldconfig && \ + rm -rf /tmp/openmpi + +# Create a wrapper for OpenMPI to allow running as root by default +RUN mv /usr/local/bin/mpirun /usr/local/bin/mpirun.real && \ + echo '#!/bin/bash' > /usr/local/bin/mpirun && \ + echo 'mpirun.real --allow-run-as-root "$@"' >> /usr/local/bin/mpirun && \ + chmod a+x /usr/local/bin/mpirun + +# Configure OpenMPI to run good defaults: +RUN echo "btl_tcp_if_exclude = lo,docker0" >> /usr/local/etc/openmpi-mca-params.conf + +# Install OpenSSH for MPI to communicate between containers +RUN mkdir -p /var/run/sshd + +# Allow OpenSSH to talk to containers without asking for confirmation +RUN cat /etc/ssh/ssh_config | grep -v StrictHostKeyChecking > /etc/ssh/ssh_config.new && \ + echo " StrictHostKeyChecking no" >> /etc/ssh/ssh_config.new && \ + mv /etc/ssh/ssh_config.new /etc/ssh/ssh_config + +# Check out horovod source code if --build-arg CHECKOUT_HOROVOD_SRC=1 +ARG CHECKOUT_HOROVOD_SRC=0 +RUN test "${CHECKOUT_HOROVOD_SRC}" -eq 1 && git clone --recursive https://github.com/uber/horovod.git /horovod_src || true + +COPY bashrc /etc/bash.bashrc +RUN chmod a+rwx /etc/bash.bashrc diff --git a/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/horovod-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/horovod-jupyter.Dockerfile new file mode 100644 index 00000000000..5ba0fe65500 --- /dev/null +++ b/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/horovod-jupyter.Dockerfile @@ -0,0 +1,128 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +# +# THIS IS A GENERATED DOCKERFILE. +# +# This file was assembled from multiple pieces, whose use is documented +# throughout. Please refer to the TensorFlow dockerfiles documentation +# for more information. 
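The runtime image whose body follows installs a pip TensorFlow package plus Open MPI, OpenSSH, and Horovod rather than building from source. A minimal smoke test of a container built from it might look like the sketch below; the two-process count is arbitrary, hvd.init() and hvd.rank() are standard Horovod calls, and the MKL import check mirrors the import-mkl-horovod.sh test added further down in this change:

    # Hypothetical in-container smoke test; -np 2 is an arbitrary process count.
    python -c 'from tensorflow.python import pywrap_tensorflow; pywrap_tensorflow.IsMklEnabled() or exit(1)'
    mpirun -np 2 python -c 'import horovod.tensorflow as hvd; hvd.init(); print(hvd.rank())'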
+ +ARG UBUNTU_VERSION=18.04 + +FROM ubuntu:${UBUNTU_VERSION} as base + +ARG USE_PYTHON_3_NOT_2 +ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} +ARG PYTHON=python${_PY_SUFFIX} +ARG PIP=pip${_PY_SUFFIX} + +# See http://bugs.python.org/issue19846 +ENV LANG C.UTF-8 + +RUN apt-get update && apt-get install -y \ + ${PYTHON} \ + ${PYTHON}-pip + +RUN ${PIP} --no-cache-dir install --upgrade \ + pip \ + setuptools + +# Some TF tools expect a "python" binary +RUN ln -s $(which ${PYTHON}) /usr/local/bin/python + +# Options: +# tensorflow +# tensorflow-gpu +# tf-nightly +# tf-nightly-gpu +# Set --build-arg TF_PACKAGE_VERSION=1.11.0rc0 to install a specific version. +# Installs the latest version by default. +ARG TF_PACKAGE=tensorflow +ARG TF_PACKAGE_VERSION= +RUN ${PIP} install ${TF_PACKAGE}${TF_PACKAGE_VERSION:+==${TF_PACKAGE_VERSION}} + +# install libnuma, openssh, wget +RUN ( apt-get update && apt-get install -y --no-install-recommends --fix-missing \ + libnuma-dev \ + openssh-server \ + openssh-client \ + wget && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* ) || \ + ( yum -y update && yum -y install \ + numactl-devel \ + openssh-server \ + openssh-clients \ + wget && \ + yum clean all ) || \ + ( echo "Unsupported Linux distribution. Aborting!" && exit 1 ) + +# Install Open MPI +# download realese version from official website as openmpi github master is not always stable +ARG OPENMPI_VERSION=openmpi-4.0.0 +ARG OPENMPI_DOWNLOAD_URL=https://www.open-mpi.org/software/ompi/v4.0/downloads/openmpi-4.0.0.tar.gz +RUN mkdir /tmp/openmpi && \ + cd /tmp/openmpi && \ + wget ${OPENMPI_DOWNLOAD_URL} && \ + tar zxf ${OPENMPI_VERSION}.tar.gz && \ + cd ${OPENMPI_VERSION} && \ + ./configure --enable-orterun-prefix-by-default && \ + make -j $(nproc) all && \ + make install && \ + ldconfig && \ + rm -rf /tmp/openmpi + +# Create a wrapper for OpenMPI to allow running as root by default +RUN mv /usr/local/bin/mpirun /usr/local/bin/mpirun.real && \ + echo '#!/bin/bash' > /usr/local/bin/mpirun && \ + echo 'mpirun.real --allow-run-as-root "$@"' >> /usr/local/bin/mpirun && \ + chmod a+x /usr/local/bin/mpirun + +# Configure OpenMPI to run good defaults: +RUN echo "btl_tcp_if_exclude = lo,docker0" >> /usr/local/etc/openmpi-mca-params.conf + +# Install OpenSSH for MPI to communicate between containers +RUN mkdir -p /var/run/sshd + +# Allow OpenSSH to talk to containers without asking for confirmation +RUN cat /etc/ssh/ssh_config | grep -v StrictHostKeyChecking > /etc/ssh/ssh_config.new && \ + echo " StrictHostKeyChecking no" >> /etc/ssh/ssh_config.new && \ + mv /etc/ssh/ssh_config.new /etc/ssh/ssh_config + +# Install Horovod +ARG HOROVOD_VERSION=0.16.4 +RUN ${PIP} install --no-cache-dir horovod==${HOROVOD_VERSION} + +COPY bashrc /etc/bash.bashrc +RUN chmod a+rwx /etc/bash.bashrc + +RUN ${PIP} install jupyter matplotlib +RUN ${PIP} install jupyter_http_over_ws +RUN jupyter serverextension enable --py jupyter_http_over_ws + +RUN mkdir -p /tf/tensorflow-tutorials && chmod -R a+rwx /tf/ +RUN mkdir /.local && chmod a+rwx /.local +RUN apt-get install -y --no-install-recommends wget +WORKDIR /tf/tensorflow-tutorials +RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/basic_classification.ipynb +RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/basic_text_classification.ipynb +COPY readme-for-jupyter.md README.md +RUN apt-get autoremove -y && apt-get remove -y wget +WORKDIR /tf +EXPOSE 8888 + +RUN ${PYTHON} -m ipykernel.kernelspec + +CMD ["bash", "-c", 
"source /etc/bash.bashrc && jupyter notebook --notebook-dir=/tf --ip 0.0.0.0 --no-browser --allow-root"] diff --git a/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/horovod.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/horovod.Dockerfile new file mode 100644 index 00000000000..e08b910a1bb --- /dev/null +++ b/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/horovod.Dockerfile @@ -0,0 +1,109 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +# +# THIS IS A GENERATED DOCKERFILE. +# +# This file was assembled from multiple pieces, whose use is documented +# throughout. Please refer to the TensorFlow dockerfiles documentation +# for more information. + +ARG UBUNTU_VERSION=18.04 + +FROM ubuntu:${UBUNTU_VERSION} as base + +ARG USE_PYTHON_3_NOT_2 +ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} +ARG PYTHON=python${_PY_SUFFIX} +ARG PIP=pip${_PY_SUFFIX} + +# See http://bugs.python.org/issue19846 +ENV LANG C.UTF-8 + +RUN apt-get update && apt-get install -y \ + ${PYTHON} \ + ${PYTHON}-pip + +RUN ${PIP} --no-cache-dir install --upgrade \ + pip \ + setuptools + +# Some TF tools expect a "python" binary +RUN ln -s $(which ${PYTHON}) /usr/local/bin/python + +# Options: +# tensorflow +# tensorflow-gpu +# tf-nightly +# tf-nightly-gpu +# Set --build-arg TF_PACKAGE_VERSION=1.11.0rc0 to install a specific version. +# Installs the latest version by default. +ARG TF_PACKAGE=tensorflow +ARG TF_PACKAGE_VERSION= +RUN ${PIP} install ${TF_PACKAGE}${TF_PACKAGE_VERSION:+==${TF_PACKAGE_VERSION}} + +# install libnuma, openssh, wget +RUN ( apt-get update && apt-get install -y --no-install-recommends --fix-missing \ + libnuma-dev \ + openssh-server \ + openssh-client \ + wget && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* ) || \ + ( yum -y update && yum -y install \ + numactl-devel \ + openssh-server \ + openssh-clients \ + wget && \ + yum clean all ) || \ + ( echo "Unsupported Linux distribution. Aborting!" 
&& exit 1 ) + +# Install Open MPI +# download realese version from official website as openmpi github master is not always stable +ARG OPENMPI_VERSION=openmpi-4.0.0 +ARG OPENMPI_DOWNLOAD_URL=https://www.open-mpi.org/software/ompi/v4.0/downloads/openmpi-4.0.0.tar.gz +RUN mkdir /tmp/openmpi && \ + cd /tmp/openmpi && \ + wget ${OPENMPI_DOWNLOAD_URL} && \ + tar zxf ${OPENMPI_VERSION}.tar.gz && \ + cd ${OPENMPI_VERSION} && \ + ./configure --enable-orterun-prefix-by-default && \ + make -j $(nproc) all && \ + make install && \ + ldconfig && \ + rm -rf /tmp/openmpi + +# Create a wrapper for OpenMPI to allow running as root by default +RUN mv /usr/local/bin/mpirun /usr/local/bin/mpirun.real && \ + echo '#!/bin/bash' > /usr/local/bin/mpirun && \ + echo 'mpirun.real --allow-run-as-root "$@"' >> /usr/local/bin/mpirun && \ + chmod a+x /usr/local/bin/mpirun + +# Configure OpenMPI to run good defaults: +RUN echo "btl_tcp_if_exclude = lo,docker0" >> /usr/local/etc/openmpi-mca-params.conf + +# Install OpenSSH for MPI to communicate between containers +RUN mkdir -p /var/run/sshd + +# Allow OpenSSH to talk to containers without asking for confirmation +RUN cat /etc/ssh/ssh_config | grep -v StrictHostKeyChecking > /etc/ssh/ssh_config.new && \ + echo " StrictHostKeyChecking no" >> /etc/ssh/ssh_config.new && \ + mv /etc/ssh/ssh_config.new /etc/ssh/ssh_config + +# Install Horovod +ARG HOROVOD_VERSION=0.16.4 +RUN ${PIP} install --no-cache-dir horovod==${HOROVOD_VERSION} + +COPY bashrc /etc/bash.bashrc +RUN chmod a+rwx /etc/bash.bashrc diff --git a/tensorflow/tools/dockerfiles/partials/mkl_horovod/devel-horovod.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/mkl_horovod/devel-horovod.partial.Dockerfile new file mode 100644 index 00000000000..dab42914df3 --- /dev/null +++ b/tensorflow/tools/dockerfiles/partials/mkl_horovod/devel-horovod.partial.Dockerfile @@ -0,0 +1,3 @@ +# Check out horovod source code if --build-arg CHECKOUT_HOROVOD_SRC=1 +ARG CHECKOUT_HOROVOD_SRC=0 +RUN test "${CHECKOUT_HOROVOD_SRC}" -eq 1 && git clone --recursive https://github.com/uber/horovod.git /horovod_src || true diff --git a/tensorflow/tools/dockerfiles/partials/mkl_horovod/horovod.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/mkl_horovod/horovod.partial.Dockerfile new file mode 100644 index 00000000000..b2bb20f713d --- /dev/null +++ b/tensorflow/tools/dockerfiles/partials/mkl_horovod/horovod.partial.Dockerfile @@ -0,0 +1,3 @@ +# Install Horovod +ARG HOROVOD_VERSION=0.16.4 +RUN ${PIP} install --no-cache-dir horovod==${HOROVOD_VERSION} diff --git a/tensorflow/tools/dockerfiles/partials/mkl_horovod/mpi.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/mkl_horovod/mpi.partial.Dockerfile new file mode 100644 index 00000000000..67055ab244a --- /dev/null +++ b/tensorflow/tools/dockerfiles/partials/mkl_horovod/mpi.partial.Dockerfile @@ -0,0 +1,47 @@ +# install libnuma, openssh, wget +RUN ( apt-get update && apt-get install -y --no-install-recommends --fix-missing \ + libnuma-dev \ + openssh-server \ + openssh-client \ + wget && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* ) || \ + ( yum -y update && yum -y install \ + numactl-devel \ + openssh-server \ + openssh-clients \ + wget && \ + yum clean all ) || \ + ( echo "Unsupported Linux distribution. Aborting!" 
&& exit 1 ) + +# Install Open MPI +# download realese version from official website as openmpi github master is not always stable +ARG OPENMPI_VERSION=openmpi-4.0.0 +ARG OPENMPI_DOWNLOAD_URL=https://www.open-mpi.org/software/ompi/v4.0/downloads/openmpi-4.0.0.tar.gz +RUN mkdir /tmp/openmpi && \ + cd /tmp/openmpi && \ + wget ${OPENMPI_DOWNLOAD_URL} && \ + tar zxf ${OPENMPI_VERSION}.tar.gz && \ + cd ${OPENMPI_VERSION} && \ + ./configure --enable-orterun-prefix-by-default && \ + make -j $(nproc) all && \ + make install && \ + ldconfig && \ + rm -rf /tmp/openmpi + +# Create a wrapper for OpenMPI to allow running as root by default +RUN mv /usr/local/bin/mpirun /usr/local/bin/mpirun.real && \ + echo '#!/bin/bash' > /usr/local/bin/mpirun && \ + echo 'mpirun.real --allow-run-as-root "$@"' >> /usr/local/bin/mpirun && \ + chmod a+x /usr/local/bin/mpirun + +# Configure OpenMPI to run good defaults: +RUN echo "btl_tcp_if_exclude = lo,docker0" >> /usr/local/etc/openmpi-mca-params.conf + +# Install OpenSSH for MPI to communicate between containers +RUN mkdir -p /var/run/sshd + +# Allow OpenSSH to talk to containers without asking for confirmation +RUN cat /etc/ssh/ssh_config | grep -v StrictHostKeyChecking > /etc/ssh/ssh_config.new && \ + echo " StrictHostKeyChecking no" >> /etc/ssh/ssh_config.new && \ + mv /etc/ssh/ssh_config.new /etc/ssh/ssh_config diff --git a/tensorflow/tools/dockerfiles/partials/ubuntu/python.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/ubuntu/python.partial.Dockerfile index 6af47319538..602bdbf5606 100644 --- a/tensorflow/tools/dockerfiles/partials/ubuntu/python.partial.Dockerfile +++ b/tensorflow/tools/dockerfiles/partials/ubuntu/python.partial.Dockerfile @@ -15,4 +15,4 @@ RUN ${PIP} --no-cache-dir install --upgrade \ setuptools # Some TF tools expect a "python" binary -RUN ln -s $(which ${PYTHON}) /usr/local/bin/python +RUN ln -s $(which ${PYTHON}) /usr/local/bin/python diff --git a/tensorflow/tools/dockerfiles/spec.yml b/tensorflow/tools/dockerfiles/spec.yml index 79fb7785d8f..5a64b70bacb 100644 --- a/tensorflow/tools/dockerfiles/spec.yml +++ b/tensorflow/tools/dockerfiles/spec.yml @@ -1,5 +1,5 @@ header: | - # Copyright 2018 The TensorFlow Authors. All Rights Reserved. + # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -83,6 +83,21 @@ slice_sets: - ubuntu/python - tensorflow - shell + - add_to_name: "-horovod" + dockerfile_exclusive_name: "horovod" + dockerfile_subdirectory: "mkl_horovod" + partials: + - ubuntu/version + - ubuntu/cpu + - ubuntu/python + - tensorflow + - mkl_horovod/mpi + - mkl_horovod/horovod + - shell + tests: + - import-mkl-horovod.sh + args: + - TF_PACKAGE=intel-tensorflow - add_to_name: "-gpu" dockerfile_exclusive_name: "gpu" args: @@ -110,6 +125,22 @@ slice_sets: - build-cpu.sh args: - CHECKOUT_TF_SRC=1 + - add_to_name: "devel-horovod" + dockerfile_exclusive_name: "devel-horovod" + dockerfile_subdirectory: "mkl_horovod" + partials: + - ubuntu/version + - ubuntu/devel-cpu + - ubuntu/python + - ubuntu/bazel + - mkl_horovod/mpi + - mkl_horovod/devel-horovod + - shell + tests: + - build-mkl-horovod.sh + args: + - CHECKOUT_TF_SRC=1 + - CHECKOUT_HOROVOD_SRC=1 - add_to_name: "devel-gpu" dockerfile_exclusive_name: "devel-gpu" partials: diff --git a/tensorflow/tools/dockerfiles/tests/build-mkl-horovod.sh b/tensorflow/tools/dockerfiles/tests/build-mkl-horovod.sh new file mode 100755 index 00000000000..62c2ffbc471 --- /dev/null +++ b/tensorflow/tools/dockerfiles/tests/build-mkl-horovod.sh @@ -0,0 +1,46 @@ +#!/usr/bin/env bash + +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + + + +# Download and build TensorFlow. +set -euxo pipefail +git clone --branch=master --depth=1 https://github.com/tensorflow/tensorflow.git /tensorflow +cd /tensorflow + +ln -s $(which ${PYTHON}) /usr/local/bin/python + +# Build TensorFlow with support for Intel(R) MKL-DNN +yes "" | ${PYTHON} configure.py && \ + bazel build -c opt --config=mkl --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0" \ + tensorflow/tools/pip_package:build_pip_package && \ + bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/pip && \ + pip --no-cache-dir install --upgrade /tmp/pip/tensorflow-*.whl && \ + rm -rf /tmp/pip && \ + rm -rf /root/.cache + + +# download and build Horovod +git clone --recursive https://github.com/uber/horovod.git +cd horovod +# export environment +export HOROVOD_WITHOUT_PYTORCH=1 +export HOROVOD_WITH_TENSORFLOW=1 +python setup.py sdist +pip --no-cache-dir install --upgrade sdist/horovod*.tar.gz && \ + rm -rf sdist && \ + rm -rf /root/.cache diff --git a/tensorflow/tools/dockerfiles/tests/import-mkl-horovod.sh b/tensorflow/tools/dockerfiles/tests/import-mkl-horovod.sh new file mode 100755 index 00000000000..b1cae48c6ee --- /dev/null +++ b/tensorflow/tools/dockerfiles/tests/import-mkl-horovod.sh @@ -0,0 +1,18 @@ +#!/usr/bin/env bash + +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +python -c 'from tensorflow.python import pywrap_tensorflow; pywrap_tensorflow.IsMklEnabled() or exit(1); import horovod.tensorflow as hvd' diff --git a/tensorflow/virtual_root_template_v1.__init__.py b/tensorflow/virtual_root_template_v1.__init__.py index 236e9f52258..9a45bc0355d 100644 --- a/tensorflow/virtual_root_template_v1.__init__.py +++ b/tensorflow/virtual_root_template_v1.__init__.py @@ -132,7 +132,4 @@ try: except NameError: pass -# Manually patch keras and estimator so tf.keras and tf.estimator work -keras = _sys.modules["tensorflow.keras"] -if not _root_estimator: estimator = _sys.modules["tensorflow.estimator"] # LINT.ThenChange(//tensorflow/virtual_root_template_v2.__init__.py.oss) diff --git a/tensorflow/virtual_root_template_v2.__init__.py b/tensorflow/virtual_root_template_v2.__init__.py index 83c020182a8..bd8c903e455 100644 --- a/tensorflow/virtual_root_template_v2.__init__.py +++ b/tensorflow/virtual_root_template_v2.__init__.py @@ -126,14 +126,4 @@ try: except NameError: pass -# TODO(mihaimaruseac): Revisit all of this once we release 2.1 -# Manually patch keras and estimator so tf.keras and tf.estimator work -keras = _sys.modules["tensorflow.keras"] -if not _root_estimator: estimator = _sys.modules["tensorflow.estimator"] -# Also import module aliases -try: - from tensorflow_core import losses, metrics, initializers, optimizers -except ImportError: - pass - # LINT.ThenChange(//tensorflow/virtual_root_template_v1.__init__.py.oss) diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 9f1d4cce63a..a79444c221c 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -172,11 +172,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): name = "eigen_archive", build_file = clean_dep("//third_party:eigen.BUILD"), patch_file = clean_dep("//third_party/eigen3:gpu_packet_math.patch"), - sha256 = "6d8ed482addd14892d7b0bd98fec2c02f18fdab97775bda68c3f2a99ffb190fb", - strip_prefix = "eigen-eigen-66be6c76fc01", + sha256 = "9edd4860b52813eaf8c023f0de1767ec58e2d67a290b718e6702469208ac5be1", + strip_prefix = "eigen-eigen-54bca9936424", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/bitbucket.org/eigen/eigen/get/66be6c76fc01.tar.gz", - "https://bitbucket.org/eigen/eigen/get/66be6c76fc01.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/bitbucket.org/eigen/eigen/get/54bca9936424.tar.gz", + "https://bitbucket.org/eigen/eigen/get/54bca9936424.tar.gz", ], ) diff --git a/third_party/llvm/llvm.bzl b/third_party/llvm/llvm.bzl index 5a478827980..0d06b7e8df7 100644 --- a/third_party/llvm/llvm.bzl +++ b/third_party/llvm/llvm.bzl @@ -294,7 +294,7 @@ win32_cmake_vars = { # ThreadPoolExecutor global destructor and thread handshaking do not work # on this platform when used as a DLL. 
- # See: https://github.com/google/iree/issues/114 + # See: https://bugs.llvm.org/show_bug.cgi?id=44211 "LLVM_ENABLE_THREADS": 0, } diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD index 3b09b3cb470..26e03c46df9 100644 --- a/third_party/mlir/BUILD +++ b/third_party/mlir/BUILD @@ -79,7 +79,7 @@ cc_library( "lib/IR/Diagnostics.cpp", "lib/IR/Dialect.cpp", "lib/IR/Function.cpp", - "lib/IR/FunctionSupport.cpp", + "lib/IR/FunctionImplementation.cpp", "lib/IR/IntegerSet.cpp", "lib/IR/IntegerSetDetail.h", "lib/IR/Location.cpp", @@ -114,6 +114,7 @@ cc_library( "include/mlir/IR/DialectImplementation.h", "include/mlir/IR/DialectInterface.h", "include/mlir/IR/Function.h", + "include/mlir/IR/FunctionImplementation.h", "include/mlir/IR/FunctionSupport.h", "include/mlir/IR/Identifier.h", "include/mlir/IR/IntegerSet.h", @@ -195,13 +196,11 @@ cc_library( includes = ["include"], deps = [ ":AffineOps", - ":Analysis", ":IR", ":LoopOps", ":StandardOps", ":Support", ":TransformUtils", - ":VectorOps", "@llvm//:support", ], ) @@ -474,15 +473,20 @@ cc_library( name = "VectorOps", srcs = [ "lib/Dialect/VectorOps/VectorOps.cpp", + "lib/Dialect/VectorOps/VectorToVector.cpp", ], hdrs = [ + "include/mlir/Dialect/VectorOps/Utils.h", "include/mlir/Dialect/VectorOps/VectorOps.h", + "include/mlir/Dialect/VectorOps/VectorTransforms.h", ], includes = ["include"], deps = [ + ":EDSC", ":IR", ":Support", ":VectorOpsIncGen", + ":VectorTransformPatterns", "@llvm//:support", ], ) @@ -950,7 +954,9 @@ filegroup( "include/mlir/Dialect/SPIRV/SPIRVCastOps.td", "include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td", "include/mlir/Dialect/SPIRV/SPIRVGLSLOps.td", + "include/mlir/Dialect/SPIRV/SPIRVGroupOps.td", "include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td", + "include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td", "include/mlir/Dialect/SPIRV/SPIRVOps.td", "include/mlir/Dialect/SPIRV/SPIRVStructureOps.td", ":OpBaseTdFiles", @@ -1135,6 +1141,7 @@ cc_library( srcs = [ "lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp", "lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp", + "lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp", ], hdrs = [ "include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h", @@ -1298,6 +1305,7 @@ cc_library( ":StandardOps", ":Support", ":TransformUtils", + ":VectorAnalysis", ":VectorOps", "@llvm//:support", ], @@ -1452,7 +1460,6 @@ cc_library( "lib/Analysis/TestMemRefDependenceCheck.cpp", "lib/Analysis/TestParallelismDetection.cpp", "lib/Analysis/Utils.cpp", - "lib/Analysis/VectorAnalysis.cpp", "lib/Analysis/Verifier.cpp", ], hdrs = [ @@ -1467,7 +1474,6 @@ cc_library( "include/mlir/Analysis/Passes.h", "include/mlir/Analysis/SliceAnalysis.h", "include/mlir/Analysis/Utils.h", - "include/mlir/Analysis/VectorAnalysis.h", "include/mlir/Analysis/Verifier.h", ], includes = ["include"], @@ -1480,6 +1486,23 @@ cc_library( ":Pass", ":StandardOps", ":Support", + "@llvm//:support", + ], + alwayslink = 1, +) + +cc_library( + name = "VectorAnalysis", + srcs = [ + "lib/Analysis/VectorAnalysis.cpp", + ], + includes = ["include"], + deps = [ + ":AffineOps", + ":Analysis", + ":IR", + ":StandardOps", + ":Support", ":VectorOps", "@llvm//:support", ], @@ -1665,7 +1688,8 @@ cc_library( ":StandardToSPIRVConversions", ":Support", ":Transforms", - ":VectorConversions", + ":VectorToLLVM", + ":VectorToLoops", ":ViewOpGraph", ":ViewRegionGraph", "@llvm//:support", @@ -2198,7 +2222,7 @@ cc_library( ":StandardOps", ":Support", ":Transforms", - ":VectorConversions", + ":VectorToLLVM", "@llvm//:core", 
"@llvm//:support", ], @@ -2212,8 +2236,8 @@ cc_library( "lib/Dialect/Linalg/IR/LinalgOps.cpp", "lib/Dialect/Linalg/IR/LinalgTypes.cpp", "lib/Dialect/Linalg/Transforms/Fusion.cpp", + "lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp", "lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp", - "lib/Dialect/Linalg/Transforms/LowerToLoops.cpp", "lib/Dialect/Linalg/Transforms/Promotion.cpp", "lib/Dialect/Linalg/Transforms/Tiling.cpp", "lib/Dialect/Linalg/Utils/Utils.cpp", @@ -2368,18 +2392,40 @@ gentbl( ) cc_library( - name = "VectorConversions", + name = "VectorToLLVM", srcs = [ - "lib/Conversion/VectorConversions/VectorToLLVM.cpp", - "lib/Conversion/VectorConversions/VectorToLoops.cpp", - "lib/Conversion/VectorConversions/VectorToVector.cpp", # TODO(transforms?) + "lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp", ], hdrs = [ - "include/mlir/Conversion/VectorConversions/VectorConversions.h", + "include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h", + ], + includes = ["include"], + deps = [ + ":EDSC", + ":IR", + ":LLVMDialect", + ":LLVMTransforms", + ":Pass", + ":StandardOps", + ":Support", + ":Transforms", + ":VectorOps", + "@llvm//:core", + "@llvm//:support", + ], + alwayslink = 1, +) + +cc_library( + name = "VectorToLoops", + srcs = [ + "lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp", + ], + hdrs = [ + "include/mlir/Conversion/VectorToLoops/ConvertVectorToLoops.h", ], includes = ["include"], deps = [ - ":Analysis", ":EDSC", ":IR", ":LLVMDialect", @@ -2389,7 +2435,6 @@ cc_library( ":Support", ":Transforms", ":VectorOps", - ":VectorTransformPatterns", "@llvm//:core", "@llvm//:support", ], diff --git a/third_party/mlir/CMakeLists.txt b/third_party/mlir/CMakeLists.txt index c8ffa759376..d6767fa75a8 100644 --- a/third_party/mlir/CMakeLists.txt +++ b/third_party/mlir/CMakeLists.txt @@ -12,6 +12,27 @@ function(mlir_tablegen ofn) PARENT_SCOPE) endfunction() +function(add_mlir_dialect dialect) + set(LLVM_TARGET_DEFINITIONS ${dialect}.td) + mlir_tablegen(${dialect}.h.inc -gen-op-decls) + mlir_tablegen(${dialect}.cpp.inc -gen-op-defs) + add_public_tablegen_target(MLIR${dialect}IncGen) + + # Generate Dialect Documentation + tablegen(MLIR ${dialect}.md -gen-op-doc "-I${MLIR_MAIN_SRC_DIR}" "-I${MLIR_INCLUDE_DIR}") + set(GEN_DOC_FILE ${MLIR_BINARY_DIR}/docs/Dialects/${dialect}.md) + add_custom_command( + OUTPUT ${GEN_DOC_FILE} + COMMAND ${CMAKE_COMMAND} -E copy + ${CMAKE_CURRENT_BINARY_DIR}/${dialect}.md + ${GEN_DOC_FILE} + DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/${dialect}.md) + add_custom_target(${dialect}DocGen DEPENDS ${GEN_DOC_FILE}) + add_dependencies(mlir-doc ${dialect}DocGen) +endfunction() + +add_custom_target(mlir-doc) + # TODO: This is to handle the current static registration, but should be # factored out a bit. 
function(whole_archive_link target) diff --git a/third_party/mlir/bindings/python/BUILD b/third_party/mlir/bindings/python/BUILD index f9941ca1336..64ade7f43e2 100644 --- a/third_party/mlir/bindings/python/BUILD +++ b/third_party/mlir/bindings/python/BUILD @@ -7,14 +7,7 @@ licenses(["notice"]) # Apache 2.0 exports_files(["BUILD"]) package( - default_visibility = [":friends"], -) - -package_group( - name = "friends", - packages = [ - "@local_config_mlir//bindings/...", - ], + default_visibility = ["@local_config_mlir//:friends"], ) # diff --git a/third_party/mlir/bindings/python/pybind.cpp b/third_party/mlir/bindings/python/pybind.cpp index a458837f77a..b1be0d21336 100644 --- a/third_party/mlir/bindings/python/pybind.cpp +++ b/third_party/mlir/bindings/python/pybind.cpp @@ -31,6 +31,8 @@ #include "mlir/EDSC/Intrinsics.h" #include "mlir/ExecutionEngine/ExecutionEngine.h" #include "mlir/ExecutionEngine/OptUtils.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Function.h" #include "mlir/IR/Module.h" @@ -62,6 +64,8 @@ struct PythonExpr; struct PythonFunctionContext; struct PythonStmt; struct PythonBlock; +struct PythonAffineExpr; +struct PythonAffineMap; struct PythonType { PythonType() : type{nullptr} {} @@ -191,6 +195,28 @@ struct PythonMLIRModule { // Create a boolean attribute. PythonAttribute boolAttr(bool value); + // Creates an Array attribute. + PythonAttribute arrayAttr(const std::vector &values); + + // Creates an AffineMap attribute. + PythonAttribute affineMapAttr(PythonAffineMap value); + + // Creates an affine constant expression. + PythonAffineExpr affineConstantExpr(int64_t value); + + // Creates an affine symbol expression. + PythonAffineExpr affineSymbolExpr(unsigned position); + + // Creates an affine dimension expression. + PythonAffineExpr affineDimExpr(unsigned position); + + // Creates a single constant result affine map. + PythonAffineMap affineConstantMap(int64_t value); + + // Creates an affine map. + PythonAffineMap affineMap(unsigned dimCount, unsigned symbolCount, + const std::vector &results); + // Compile the module save the execution engine. "optLevel" and // "codegenOptLevel" contain the levels of optimization to run (0 to 3) for // transformations and codegen. -1 means ExecutionEngine default. @@ -467,14 +493,15 @@ struct PythonAttribute { PythonAttribute(const PythonAttribute &other) = default; operator mlir_attr_t() { return attr; } + operator Attribute() const { return Attribute::getFromOpaquePointer(attr); } + std::string str() const { if (!attr) return "##null attr##"; std::string res; llvm::raw_string_ostream os(res); - Attribute::getFromOpaquePointer(reinterpret_cast(attr)) - .print(os); + Attribute().print(os); return res; } @@ -532,6 +559,48 @@ private: std::unordered_map attrs; }; +// Wraps mlir::AffineExpr. +struct PythonAffineExpr { + PythonAffineExpr() : affine_expr() {} + PythonAffineExpr(const AffineExpr &a) : affine_expr(a) {} + PythonAffineExpr(const PythonAffineExpr &other) = default; + + operator AffineExpr() const { return affine_expr; } + operator AffineExpr &() { return affine_expr; } + + AffineExpr get() const { return affine_expr; } + + std::string str() const { + std::string res; + llvm::raw_string_ostream os(res); + affine_expr.print(os); + return res; + } + +private: + AffineExpr affine_expr; +}; + +// Wraps mlir::AffineMap. 
+struct PythonAffineMap { + PythonAffineMap() : affine_map() {} + PythonAffineMap(const AffineMap &a) : affine_map(a) {} + PythonAffineMap(const PythonAffineMap &other) = default; + + operator AffineMap() const { return affine_map; } + operator AffineMap &() { return affine_map; } + + std::string str() const { + std::string res; + llvm::raw_string_ostream os(res); + affine_map.print(os); + return res; + } + +private: + AffineMap affine_map; +}; + struct PythonIndexedValue { explicit PythonIndexedValue(PythonType type) : indexed(Type::getFromOpaquePointer(type.type)) {} @@ -640,6 +709,42 @@ PythonAttribute PythonMLIRModule::boolAttr(bool value) { return PythonAttribute(::makeBoolAttr(&mlirContext, value)); } +PythonAttribute +PythonMLIRModule::arrayAttr(const std::vector &values) { + std::vector mlir_attributes(values.begin(), values.end()); + auto array_attr = ArrayAttr::get( + llvm::ArrayRef(mlir_attributes), &mlirContext); + return PythonAttribute(array_attr.getAsOpaquePointer()); +} + +PythonAttribute PythonMLIRModule::affineMapAttr(PythonAffineMap value) { + return PythonAttribute(AffineMapAttr::get(value).getAsOpaquePointer()); +} + +PythonAffineExpr PythonMLIRModule::affineConstantExpr(int64_t value) { + return PythonAffineExpr(getAffineConstantExpr(value, &mlirContext)); +} + +PythonAffineExpr PythonMLIRModule::affineSymbolExpr(unsigned position) { + return PythonAffineExpr(getAffineSymbolExpr(position, &mlirContext)); +} + +PythonAffineExpr PythonMLIRModule::affineDimExpr(unsigned position) { + return PythonAffineExpr(getAffineDimExpr(position, &mlirContext)); +} + +PythonAffineMap PythonMLIRModule::affineConstantMap(int64_t value) { + return PythonAffineMap(AffineMap::getConstantMap(value, &mlirContext)); +} + +PythonAffineMap +PythonMLIRModule::affineMap(unsigned dimCount, unsigned SymbolCount, + const std::vector &results) { + std::vector mlir_results(results.begin(), results.end()); + return PythonAffineMap(AffineMap::get( + dimCount, SymbolCount, llvm::ArrayRef(mlir_results))); +} + PYBIND11_MODULE(pybind, m) { m.doc() = "Python bindings for MLIR Embedded Domain-Specific Components (EDSCs)"; @@ -801,6 +906,12 @@ PYBIND11_MODULE(pybind, m) { "integerAttr", &PythonMLIRModule::integerAttr, "Creates an mlir::IntegerAttr of the given type with the given value " "in the context associated with this MLIR module.") + .def("arrayAttr", &PythonMLIRModule::arrayAttr, + "Creates an mlir::ArrayAttr of the given type with the given values " + "in the context associated with this MLIR module.") + .def("affineMapAttr", &PythonMLIRModule::affineMapAttr, + "Creates an mlir::AffineMapAttr of the given type with the given " + "value in the context associated with this MLIR module.") .def("declare_function", &PythonMLIRModule::declareFunction, "Declares a new mlir::FuncOp in the current mlir::ModuleOp. The " "function arguments can have attributes. The function has no " @@ -831,6 +942,16 @@ PYBIND11_MODULE(pybind, m) { .def("get_engine_address", &PythonMLIRModule::getEngineAddress, "Returns the address of the compiled ExecutionEngine. 
This is used " "for in-process execution.") + .def("affine_constant_expr", &PythonMLIRModule::affineConstantExpr, + "Returns an affine constant expression.") + .def("affine_symbol_expr", &PythonMLIRModule::affineSymbolExpr, + "Returns an affine symbol expression.") + .def("affine_dim_expr", &PythonMLIRModule::affineDimExpr, + "Returns an affine dim expression.") + .def("affine_constant_map", &PythonMLIRModule::affineConstantMap, + "Returns an affine map with single constant result.") + .def("affine_map", &PythonMLIRModule::affineMap, "Returns an affine map.", + py::arg("dimCount"), py::arg("symbolCount"), py::arg("resuls")) .def("__str__", &PythonMLIRModule::getIR, "Get the string representation of the module"); @@ -940,6 +1061,68 @@ PYBIND11_MODULE(pybind, m) { .def(py::init()) .def("load", &PythonIndexedValue::load) .def("store", &PythonIndexedValue::store); + + py::class_(m, "AffineExpr", + "A wrapper around mlir::AffineExpr") + .def(py::init()) + .def("__add__", + [](PythonAffineExpr lhs, int64_t rhs) -> PythonAffineExpr { + return PythonAffineExpr(lhs.get() + rhs); + }) + .def("__add__", + [](PythonAffineExpr lhs, PythonAffineExpr rhs) -> PythonAffineExpr { + return PythonAffineExpr(lhs.get() + rhs.get()); + }) + .def("__neg__", + [](PythonAffineExpr lhs) -> PythonAffineExpr { + return PythonAffineExpr(-lhs.get()); + }) + .def("__sub__", + [](PythonAffineExpr lhs, int64_t rhs) -> PythonAffineExpr { + return PythonAffineExpr(lhs.get() - rhs); + }) + .def("__sub__", + [](PythonAffineExpr lhs, PythonAffineExpr rhs) -> PythonAffineExpr { + return PythonAffineExpr(lhs.get() - rhs.get()); + }) + .def("__mul__", + [](PythonAffineExpr lhs, int64_t rhs) -> PythonAffineExpr { + return PythonAffineExpr(lhs.get() * rhs); + }) + .def("__mul__", + [](PythonAffineExpr lhs, PythonAffineExpr rhs) -> PythonAffineExpr { + return PythonAffineExpr(lhs.get() * rhs.get()); + }) + .def("__floordiv__", + [](PythonAffineExpr lhs, uint64_t rhs) -> PythonAffineExpr { + return PythonAffineExpr(lhs.get().floorDiv(rhs)); + }) + .def("__floordiv__", + [](PythonAffineExpr lhs, PythonAffineExpr rhs) -> PythonAffineExpr { + return PythonAffineExpr(lhs.get().floorDiv(rhs.get())); + }) + .def("ceildiv", + [](PythonAffineExpr lhs, uint64_t rhs) -> PythonAffineExpr { + return PythonAffineExpr(lhs.get().ceilDiv(rhs)); + }) + .def("ceildiv", + [](PythonAffineExpr lhs, PythonAffineExpr rhs) -> PythonAffineExpr { + return PythonAffineExpr(lhs.get().ceilDiv(rhs.get())); + }) + .def("__mod__", + [](PythonAffineExpr lhs, uint64_t rhs) -> PythonAffineExpr { + return PythonAffineExpr(lhs.get() % rhs); + }) + .def("__mod__", + [](PythonAffineExpr lhs, PythonAffineExpr rhs) -> PythonAffineExpr { + return PythonAffineExpr(lhs.get() % rhs.get()); + }) + .def("__str__", &PythonAffineExpr::str); + + py::class_(m, "AffineMap", + "A wrapper around mlir::AffineMap") + .def(py::init()) + .def("__str__", &PythonAffineMap::str); } } // namespace python diff --git a/third_party/mlir/bindings/python/test/test_py2and3.py b/third_party/mlir/bindings/python/test/test_py2and3.py index 2f4281ee59a..678e5023173 100644 --- a/third_party/mlir/bindings/python/test/test_py2and3.py +++ b/third_party/mlir/bindings/python/test/test_py2and3.py @@ -285,6 +285,49 @@ class EdscTest: # CHECK-LABEL: testFunctionDeclaration # CHECK: func @foo(memref<10xf32>, memref<10xf32> {llvm.noalias = true}, memref<10xf32> {readonly = true}) + def testFunctionDeclarationWithAffineAttr(self): + self.setUp() + a1 = self.module.affine_constant_expr(23) + a2 = 
self.module.affine_constant_expr(44) + a3 = self.module.affine_dim_expr(1) + s0 = self.module.affine_symbol_expr(0) + aMap1 = self.module.affine_map(2, 0, [a1, a2, s0]) + aMap2 = self.module.affine_constant_map(42) + aMap3 = self.module.affine_map( + 2, 0, + [a1 + a2 * a3, a1 // a3 % a2, + a1.ceildiv(a2), a1 - 2, a2 * 2, -a3]) + + affineAttr1 = self.module.affineMapAttr(aMap1) + affineAttr2 = self.module.affineMapAttr(aMap2) + affineAttr3 = self.module.affineMapAttr(aMap3) + + t = self.module.make_memref_type(self.f32Type, [10]) + t_with_attr = t({ + "affine_attr_1": affineAttr1, + "affine_attr_2": affineAttr2, + "affine_attr_3": affineAttr3, + }) + + f = self.module.declare_function("foo", [t, t_with_attr], []) + printWithCurrentFunctionName(str(self.module)) + # CHECK-LABEL: testFunctionDeclarationWithAffineAttr + # CHECK: func @foo(memref<10xf32>, memref<10xf32> {affine_attr_1 = (d0, d1) -> (23, 44, s0), affine_attr_2 = () -> (42), affine_attr_3 = (d0, d1) -> (d1 * 44 + 23, (23 floordiv d1) mod 44, 1, 21, 88, -d1)}) + + def testFunctionDeclarationWithArrayAttr(self): + self.setUp() + arrayAttr = self.module.arrayAttr([ + self.module.integerAttr(self.i32Type, 43), + self.module.integerAttr(self.i32Type, 33), + ]) + t = self.module.make_memref_type(self.f32Type, [10]) + t_with_attr = t({"array_attr": arrayAttr}) + + f = self.module.declare_function("foo", [t, t_with_attr], []) + printWithCurrentFunctionName(str(self.module)) + # CHECK-LABEL: testFunctionDeclarationWithArrayAttr + # CHECK: func @foo(memref<10xf32>, memref<10xf32> {array_attr = [43 : i32, 33 : i32]}) + def testFunctionMultiple(self): self.setUp() with self.module.function_context("foo", [], []): diff --git a/third_party/mlir/g3doc/DeclarativeRewrites.md b/third_party/mlir/g3doc/DeclarativeRewrites.md index 2d9fb5b5219..c7276daccd8 100644 --- a/third_party/mlir/g3doc/DeclarativeRewrites.md +++ b/third_party/mlir/g3doc/DeclarativeRewrites.md @@ -144,6 +144,13 @@ Also note that we only need to add `TypeConstraint` or `AttributeConstraint` when we need to further limit the match criteria. If all valid cases to the op are acceptable, then we can leave the constraint unspecified. +`$_` is a special symbol to mean ignore capturing an argument. For example, +`def : Pat<(AOp $_, $b), ...>` means only `$b` is interesting to capture and +will be referenced later in result patterns. It's still possible to place +additional constraints even if the symbol is not to be captured; for such case, +you can simply use just the `TypeConstraint` or `AttributeConstraint` without a +bound symbol, for example, `def : Pat<(AOp $a, F32Attr), ...>`. + #### Matching DAG of operations To match an DAG of ops, use nested `dag` objects: @@ -479,7 +486,7 @@ on **naming convention**: a `__N` suffix is added to a symbol to indicate the #### `__N` suffix -The `__N` sufix is specifying the `N`-th result as a whole (which can be +The `__N` suffix is specifying the `N`-th result as a whole (which can be [variadic](#supporting-variadic-ops)). For example, we can bind a symbol to some multi-result op and reference a specific result later: @@ -674,7 +681,7 @@ mlir-tblgen --gen-rewriters -I /path/to/mlir/include /path/to/input/td/file ### Compilation error: no matching member function for call to 'build' -This is because DRR is failing to call a `build()` mehtod with result type +This is because DRR is failing to call a `build()` method with result type deduction ability. See [building operations](#building-operations) for more details. 
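
For reference, here is a minimal DRR sketch of the `$_` placeholder and the unbound-constraint form documented in the DeclarativeRewrites.md change above. `AOp` and `BOp` are hypothetical ops assumed only for illustration (an `F32Attr` plus one operand, and a single operand, respectively), with the usual OpBase.td includes and op definitions assumed to be in scope; they are not defined by this patch.

```tablegen
// Assumed (hypothetical) ops: AOp takes (F32Attr, operand), BOp takes (operand).

// Capture only the second argument of AOp; `$_` skips capturing the first.
def : Pat<(AOp $_, $b), (BOp $b)>;

// Leave the first argument unbound but still constrain it to F32Attr.
def : Pat<(AOp F32Attr, $b), (BOp $b)>;
```

Both forms follow directly from the semantics spelled out above: the first pattern simply declines to bind the argument, while the second attaches an `AttributeConstraint` without introducing a symbol.
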
diff --git a/third_party/mlir/g3doc/Dialects/LLVM.md b/third_party/mlir/g3doc/Dialects/LLVM.md index ed0cad2df1f..9791352aa56 100644 --- a/third_party/mlir/g3doc/Dialects/LLVM.md +++ b/third_party/mlir/g3doc/Dialects/LLVM.md @@ -72,6 +72,11 @@ llvm.func @foo(%arg0: !llvm.i64) { llvm.return } +// A function with `internal` linkage. +llvm.func internal @internal_func() { + llvm.return +} + ``` ### LLVM IR operations diff --git a/third_party/mlir/g3doc/OpDefinitions.md b/third_party/mlir/g3doc/OpDefinitions.md index ea794964033..25865593800 100644 --- a/third_party/mlir/g3doc/OpDefinitions.md +++ b/third_party/mlir/g3doc/OpDefinitions.md @@ -382,27 +382,86 @@ def OpWithInferTypeInterfaceOp : Op<... [DeclareOpInterfaceMethods]> { ... } ``` -### Custom builder methods +### Builder methods -For each operation, there are two builders automatically generated based on the -arguments and returns types: +For each operation, there are a few builders automatically generated based on +the arguments and returns types. For example, given the following op definition: + +```tablegen +def MyOp : ... { + let arguments = (ins + I32:$i32_operand, + F32:$f32_operand, + ..., + + I32Attr:$i32_attr, + F32Attr:$f32_attr, + ... + ); + + let results = (outs + I32:$i32_result, + F32:$f32_result, + ... + ); +} +``` + +The following builders are generated: ```c++ -static void build(Builder *, OperationState &tblgen_state, - Type , Type , ..., - Value , Value , ..., - Attribute , Attribute , ...); - -static void build(Builder *, OperationState &tblgen_state, +// All result-types/operands/attributes have one aggregate parameter. +static void build(Builder *tblgen_builder, OperationState &tblgen_state, ArrayRef resultTypes, ArrayRef operands, ArrayRef attributes); + +// Each result-type/operand/attribute has a separate parameter. The parameters +// for attributes are of mlir::Attribute types. +static void build(Builder *tblgen_builder, OperationState &tblgen_state, + Type i32_result, Type f32_result, ..., + Value *i32_operand, Value *f32_operand, ..., + IntegerAttr i32_attr, FloatAttr f32_attr, ...); + +// Each result-type/operand/attribute has a separate parameter. The parameters +// for attributes are raw values unwrapped with mlir::Attribute instances. +// (Note that this builder will not always be generated. See the following +// explanation for more details.) +static void build(Builder *tblgen_builder, OperationState &tblgen_state, + Type i32_result, Type f32_result, ..., + Value *i32_operand, Value *f32_operand, ..., + APInt i32_attr, StringRef f32_attr, ...); + +// (And potentially others depending on the specific op.) ``` -The above cases make sure basic uniformity so that we can create ops using the +The first form provides basic uniformity so that we can create ops using the same form regardless of the exact op. This is particularly useful for implementing declarative pattern rewrites. +The second and third forms are good for use in manually written code given that +they provide better guarantee via signatures. + +The third form will be generated if any of the op's attribute has different +`Attr.returnType` from `Attr.storageType` and we know how to build an attribute +from an unwrapped value (i.e., `Attr.constBuilderCall` is defined.) +Additionally, for the third form, if an attribute appearing later in the +`arguments` list has a default value, the default value will be supplied in the +declaration. This works for `BoolAttr`, `StrAttr`, `EnumAttr` for now and the +list can grow in the future. 
So if possible, default valued attribute should be +placed at the end of the `arguments` list to leverage this feature. (This +behavior is essentially due to C++ function parameter default value placement +restrictions.) Otherwise, the builder of the third form will still be generated +but default values for the attributes not at the end of the `arguments` list +will not be supplied in the builder's signature. + +And there may potentially exist other builders depending on the specific op; +please refer to the +[generated C++ file](#run-mlir-tblgen-to-see-the-generated-content) for the +complete list. + +#### Custom builder methods + However, if the above cases cannot satisfy all needs, you can define additional convenience build methods with `OpBuilder`. diff --git a/third_party/mlir/g3doc/Tutorials/Toy/Ch-1.md b/third_party/mlir/g3doc/Tutorials/Toy/Ch-1.md index 650a135619e..cb7f97cb3f6 100644 --- a/third_party/mlir/g3doc/Tutorials/Toy/Ch-1.md +++ b/third_party/mlir/g3doc/Tutorials/Toy/Ch-1.md @@ -62,7 +62,7 @@ def main() { var b<2, 3> = [1, 2, 3, 4, 5, 6]; # transpose() and print() are the only builtin, the following will transpose - # b and perform an element-wise multiplication before printing the result. + # a and b and perform an element-wise multiplication before printing the result. print(transpose(a) * transpose(b)); } ``` @@ -145,7 +145,7 @@ Module: var: b @test/ast.toy:21:30 var: c @test/ast.toy:21:33 ] - VarDecl e<> @test/ast.toy:24:3 + VarDecl f<> @test/ast.toy:24:3 Call 'multiply_transpose' [ @test/ast.toy:24:11 Call 'transpose' [ @test/ast.toy:24:30 var: a @test/ast.toy:24:40 diff --git a/third_party/mlir/g3doc/Tutorials/Toy/Ch-2.md b/third_party/mlir/g3doc/Tutorials/Toy/Ch-2.md index 056b25779cb..d797624ed72 100755 --- a/third_party/mlir/g3doc/Tutorials/Toy/Ch-2.md +++ b/third_party/mlir/g3doc/Tutorials/Toy/Ch-2.md @@ -72,7 +72,7 @@ Let's break down the anatomy of this MLIR operation: are always constant. Here we define a boolean attribute named 'inplace' that has a constant value of true. -- `(tensor<2x3xf64) -> tensor<3x2xf64>` +- `(tensor<2x3xf64>) -> tensor<3x2xf64>` * This refers to the type of the operation in a functional form, spelling the types of the arguments in parentheses and the type of the return @@ -270,16 +270,17 @@ types, etc). We can always get an instance of our toy operation by using LLVM's casting infrastructure: ```c++ -void processConstantOp(mlir::Operation *op) { - ConstantOp op = llvm::dyn_cast(op); +void processConstantOp(mlir::Operation *operation) { + ConstantOp op = llvm::dyn_cast(operation); // This operation is not an instance of `ConstantOp`. if (!op) return; // Get the internal operation instance back. - mlir::Operation *internalOp = op.getOperation(); - assert(internalOp == op && "these operation instances are the same"); + mlir::Operation *internalOperation = op.getOperation(); + assert(internalOperation == operation && + "these operation instances are the same"); } ``` @@ -395,7 +396,7 @@ documents. ```tablegen def ConstantOp : Toy_Op<"constant", [NoSideEffect]> { // Provide a summary and description for this operation. This can be used to - // auto-generate documenatation of the operations within our dialect. + // auto-generate documentation of the operations within our dialect. let summary = "constant operation"; let description = [{ Constant operation turns a literal into an SSA value. The data is attached @@ -473,7 +474,7 @@ the implementation inline. 
```tablegen def ConstantOp : Toy_Op<"constant", [NoSideEffect]> { // Provide a summary and description for this operation. This can be used to - // auto-generate documenatation of the operations within our dialect. + // auto-generate documentation of the operations within our dialect. let summary = "constant operation"; let description = [{ Constant operation turns a literal into an SSA value. The data is attached diff --git a/third_party/mlir/g3doc/WritingAPass.md b/third_party/mlir/g3doc/WritingAPass.md index df0d153ad1a..1e4564aa21d 100644 --- a/third_party/mlir/g3doc/WritingAPass.md +++ b/third_party/mlir/g3doc/WritingAPass.md @@ -116,12 +116,20 @@ the following: * Provide a valid constructor taking an `Operation*`. * Must not modify the given operation. -The base `OperationPass` class provide utilities for querying and preserving -analyses for the current operation being processed. Using the example passes -defined above, let's see some examples: +An analysis may provide additional hooks to control various behavior: + +* `bool isInvalidated(const AnalysisManager::PreservedAnalyses &)` + +Given a preserved analysis set, the analysis returns true if it should truly be +invalidated. This allows for more fine-tuned invalidation in cases where an +analysis wasn't explicitly marked preserved, but may be preserved(or +invalidated) based upon other properties such as analyses sets. ### Querying Analyses +The base `OperationPass` class provide utilities for querying and preserving +analyses for the current operation being processed. + * OperationPass automatically provides the following utilities for querying analyses: * `getAnalysis<>` @@ -137,7 +145,7 @@ defined above, let's see some examples: - Get an analysis for a given child operation, constructing it if necessary. -A few example usages are shown below: +Using the example passes defined above, let's see some examples: ```c++ /// An interesting analysis. diff --git a/third_party/mlir/include/mlir/Analysis/LoopAnalysis.h b/third_party/mlir/include/mlir/Analysis/LoopAnalysis.h index 7763a2bd262..8832c1469bc 100644 --- a/third_party/mlir/include/mlir/Analysis/LoopAnalysis.h +++ b/third_party/mlir/include/mlir/Analysis/LoopAnalysis.h @@ -31,8 +31,9 @@ namespace mlir { class AffineExpr; class AffineForOp; class AffineMap; -class Operation; class MemRefType; +class NestedPattern; +class Operation; class Value; /// Returns the trip count of the loop as an affine map with its corresponding @@ -91,14 +92,16 @@ using VectorizableLoopFun = std::function; /// 1. no conditionals are nested under the loop; /// 2. all nested load/stores are to scalar MemRefs. /// TODO(ntv): relax the no-conditionals restriction -bool isVectorizableLoopBody(AffineForOp loop); +bool isVectorizableLoopBody(AffineForOp loop, + NestedPattern &vectorTransferMatcher); /// Checks whether the loop is structurally vectorizable and that all the LoadOp /// and StoreOp matched have access indexing functions that are are either: /// 1. invariant along the loop induction variable created by 'loop'; /// 2. varying along at most one memory dimension. If such a unique dimension /// is found, it is written into `memRefDim`. -bool isVectorizableLoopBody(AffineForOp loop, int *memRefDim); +bool isVectorizableLoopBody(AffineForOp loop, int *memRefDim, + NestedPattern &vectorTransferMatcher); /// Checks where SSA dominance would be violated if a for op's body /// operations are shifted by the specified shifts. 
This method checks if a diff --git a/third_party/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h b/third_party/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h index 69db8171ed0..4caa6d9de77 100644 --- a/third_party/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h +++ b/third_party/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h @@ -26,12 +26,19 @@ namespace mlir { class SPIRVTypeConverter; + /// Appends to a pattern list additional patterns for translating StandardOps to -/// SPIR-V ops. +/// SPIR-V ops. Also adds the patterns legalize ops not directly translated to +/// SPIR-V dialect. void populateStandardToSPIRVPatterns(MLIRContext *context, SPIRVTypeConverter &typeConverter, OwningRewritePatternList &patterns); +/// Appends to a pattern list patterns to legalize ops that are not directly +/// lowered to SPIR-V. +void populateStdLegalizationPatternsForSPIRVLowering( + MLIRContext *context, OwningRewritePatternList &patterns); + } // namespace mlir #endif // MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRV_H diff --git a/third_party/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h b/third_party/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h index 1bf497708de..e8a71feb8b2 100644 --- a/third_party/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h +++ b/third_party/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h @@ -25,8 +25,13 @@ #include "mlir/Pass/Pass.h" namespace mlir { + /// Pass to convert StandardOps to SPIR-V ops. std::unique_ptr> createConvertStandardToSPIRVPass(); + +/// Pass to legalize ops that are not directly lowered to SPIR-V. +std::unique_ptr createLegalizeStdOpsForSPIRVLoweringPass(); + } // namespace mlir #endif // MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRVPASS_H diff --git a/third_party/mlir/include/mlir/Conversion/VectorConversions/VectorConversions.h b/third_party/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h similarity index 54% rename from third_party/mlir/include/mlir/Conversion/VectorConversions/VectorConversions.h rename to third_party/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h index 56862ca0dad..a87e1c658a6 100644 --- a/third_party/mlir/include/mlir/Conversion/VectorConversions/VectorConversions.h +++ b/third_party/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h @@ -1,4 +1,4 @@ -//===- VectorConversions.h - Utils to convert from the vector dialect -----===// +//===- ConvertVectorToLLVM.h - Utils to convert from the vector dialect ---===// // // Copyright 2019 The MLIR Authors. // @@ -14,31 +14,16 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef MLIR_CONVERSION_VECTORCONVERSIONS_VECTORCONVERSIONS_H_ -#define MLIR_CONVERSION_VECTORCONVERSIONS_VECTORCONVERSIONS_H_ +#ifndef MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLLVM_H_ +#define MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLLVM_H_ #include "mlir/Transforms/DialectConversion.h" namespace mlir { class LLVMTypeConverter; -class MLIRContext; class ModuleOp; template class OpPassBase; -/// Collect a set of patterns to convert from the Vector dialect to affine loops -/// surrounding ops in different dialects (vector, std etc). 
-/// This is the general place where we want to implement Vector -> Vector and -/// Vector -> Std legalizations. -void populateVectorToAffineLoopsConversionPatterns( - MLIRContext *context, OwningRewritePatternList &patterns); - -/// Collect a set of patterns to convert from the Vector dialect to itself. -/// Should be merged with populateVectorToAffineLoopsConversionPatterns. -void populateVectorToVectorConversionPatterns( - MLIRContext *context, OwningRewritePatternList &patterns, - ArrayRef coarseVectorShape = {}, - ArrayRef fineVectorShape = {}); - /// Collect a set of patterns to convert from the Vector dialect to LLVM. void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter, OwningRewritePatternList &patterns); @@ -48,4 +33,4 @@ OpPassBase *createLowerVectorToLLVMPass(); } // namespace mlir -#endif // MLIR_CONVERSION_VECTORCONVERSIONS_VECTORCONVERSIONS_H_ +#endif // MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLLVM_H_ diff --git a/third_party/mlir/include/mlir/Conversion/VectorToLoops/ConvertVectorToLoops.h b/third_party/mlir/include/mlir/Conversion/VectorToLoops/ConvertVectorToLoops.h new file mode 100644 index 00000000000..198eaceda41 --- /dev/null +++ b/third_party/mlir/include/mlir/Conversion/VectorToLoops/ConvertVectorToLoops.h @@ -0,0 +1,36 @@ +//===- ConvertVectorToLoops.h - Utils to convert from the vector dialect --===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +#ifndef MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLOOPS_H_ +#define MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLOOPS_H_ + +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +class MLIRContext; +class ModuleOp; +template class OpPassBase; + +/// Collect a set of patterns to convert from the Vector dialect to loops + std. +void populateVectorToAffineLoopsConversionPatterns( + MLIRContext *context, OwningRewritePatternList &patterns); + +/// Create a pass to convert vector operations to affine loops + std dialect. 
+OpPassBase *createLowerVectorToLoopsPass(); + +} // namespace mlir + +#endif // MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLOOPS_H_ diff --git a/third_party/mlir/include/mlir/Dialect/AffineOps/CMakeLists.txt b/third_party/mlir/include/mlir/Dialect/AffineOps/CMakeLists.txt index 6c5a58c957b..8f812b39593 100644 --- a/third_party/mlir/include/mlir/Dialect/AffineOps/CMakeLists.txt +++ b/third_party/mlir/include/mlir/Dialect/AffineOps/CMakeLists.txt @@ -1,4 +1 @@ -set(LLVM_TARGET_DEFINITIONS AffineOps.td) -mlir_tablegen(AffineOps.h.inc -gen-op-decls) -mlir_tablegen(AffineOps.cpp.inc -gen-op-defs) -add_public_tablegen_target(MLIRAffineOpsIncGen) +add_mlir_dialect(AffineOps) diff --git a/third_party/mlir/include/mlir/Dialect/FxpMathOps/CMakeLists.txt b/third_party/mlir/include/mlir/Dialect/FxpMathOps/CMakeLists.txt index eaf72d214f8..a8fb5e08ee5 100644 --- a/third_party/mlir/include/mlir/Dialect/FxpMathOps/CMakeLists.txt +++ b/third_party/mlir/include/mlir/Dialect/FxpMathOps/CMakeLists.txt @@ -1,4 +1 @@ -set(LLVM_TARGET_DEFINITIONS FxpMathOps.td) -mlir_tablegen(FxpMathOps.h.inc -gen-op-decls) -mlir_tablegen(FxpMathOps.cpp.inc -gen-op-defs) -add_public_tablegen_target(MLIRFxpMathOpsIncGen) +add_mlir_dialect(FxpMathOps) diff --git a/third_party/mlir/include/mlir/Dialect/GPU/CMakeLists.txt b/third_party/mlir/include/mlir/Dialect/GPU/CMakeLists.txt index 5ba59a1026c..bdb5dec79b9 100644 --- a/third_party/mlir/include/mlir/Dialect/GPU/CMakeLists.txt +++ b/third_party/mlir/include/mlir/Dialect/GPU/CMakeLists.txt @@ -1,4 +1 @@ -set(LLVM_TARGET_DEFINITIONS GPUOps.td) -mlir_tablegen(GPUOps.h.inc -gen-op-decls) -mlir_tablegen(GPUOps.cpp.inc -gen-op-defs) -add_public_tablegen_target(MLIRGPUOpsIncGen) +add_mlir_dialect(GPUOps) diff --git a/third_party/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/third_party/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt index 3e5a0346ed6..4ecc71aef08 100644 --- a/third_party/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt +++ b/third_party/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt @@ -4,14 +4,10 @@ mlir_tablegen(LLVMOps.cpp.inc -gen-op-defs) mlir_tablegen(LLVMOpsEnums.h.inc -gen-enum-decls) mlir_tablegen(LLVMOpsEnums.cpp.inc -gen-enum-defs) add_public_tablegen_target(MLIRLLVMOpsIncGen) -set(LLVM_TARGET_DEFINITIONS NVVMOps.td) -mlir_tablegen(NVVMOps.h.inc -gen-op-decls) -mlir_tablegen(NVVMOps.cpp.inc -gen-op-defs) -add_public_tablegen_target(MLIRNVVMOpsIncGen) -set(LLVM_TARGET_DEFINITIONS ROCDLOps.td) -mlir_tablegen(ROCDLOps.h.inc -gen-op-decls) -mlir_tablegen(ROCDLOps.cpp.inc -gen-op-defs) -add_public_tablegen_target(MLIRROCDLOpsIncGen) + +add_mlir_dialect(NVVMOps) +add_mlir_dialect(ROCDLOps) + set(LLVM_TARGET_DEFINITIONS LLVMOps.td) mlir_tablegen(LLVMConversions.inc -gen-llvmir-conversions) add_public_tablegen_target(MLIRLLVMConversionsIncGen) diff --git a/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h index eb39537c03d..83c30e64b9f 100644 --- a/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -195,7 +195,8 @@ private: /// global and use it to compute the address of the first character in the /// string (operations inserted at the builder insertion point). 
Value *createGlobalString(Location loc, OpBuilder &builder, StringRef name, - StringRef value, LLVM::LLVMDialect *llvmDialect); + StringRef value, LLVM::Linkage linkage, + LLVM::LLVMDialect *llvmDialect); } // end namespace LLVM } // end namespace mlir diff --git a/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 3d697b78374..573542ba838 100644 --- a/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -467,8 +467,36 @@ def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable", []> { let printer = [{ p << getOperationName(); }]; } +//////////////////////////////////////////////////////////////////////////////// // Auxiliary operations (do not appear in LLVM IR but necessary for the dialect // to work correctly). +//////////////////////////////////////////////////////////////////////////////// + +// Linkage attribute is used on functions and globals. The order follows that of +// https://llvm.org/docs/LangRef.html#linkage-types. The names are equivalent to +// visible names in the IR rather than to enum values names in llvm::GlobalValue +// since the latter is easier to change. +def LinkagePrivate : I64EnumAttrCase<"Private", 0>; +def LinkageInternal : I64EnumAttrCase<"Internal", 1>; +def LinkageAvailableExternally : I64EnumAttrCase<"AvailableExternally", 2>; +def LinkageLinkonce : I64EnumAttrCase<"Linkonce", 3>; +def LinkageWeak : I64EnumAttrCase<"Weak", 4>; +def LinkageCommon : I64EnumAttrCase<"Common", 5>; +def LinkageAppending : I64EnumAttrCase<"Appending", 6>; +def LinkageExternWeak : I64EnumAttrCase<"ExternWeak", 7>; +def LinkageLinkonceODR : I64EnumAttrCase<"LinkonceODR", 8>; +def LinkageWeakODR : I64EnumAttrCase<"WeakODR", 9>; +def LinkageExternal : I64EnumAttrCase<"External", 10>; +def Linkage : I64EnumAttr< + "Linkage", + "LLVM linkage types", + [LinkagePrivate, LinkageInternal, LinkageAvailableExternally, + LinkageLinkonce, LinkageWeak, LinkageCommon, LinkageAppending, + LinkageExternWeak, LinkageLinkonceODR, LinkageWeakODR, LinkageExternal]> { + let cppNamespace = "::mlir::LLVM"; +} + + def LLVM_AddressOfOp : LLVM_OneResultOp<"mlir.addressof">, Arguments<(ins FlatSymbolRefAttr:$global_name)> { @@ -501,6 +529,7 @@ def LLVM_GlobalOp [IsolatedFromAbove, SingleBlockImplicitTerminator<"ReturnOp">, Symbol]>, Arguments<(ins TypeAttr:$type, UnitAttr:$constant, StrAttr:$sym_name, + Linkage:$linkage, OptionalAttr:$value, DefaultValuedAttr:$addr_space)> { let summary = "LLVM dialect global."; @@ -522,8 +551,8 @@ def LLVM_GlobalOp let builders = [ OpBuilder<"Builder *builder, OperationState &result, LLVMType type, " - "bool isConstant, StringRef name, Attribute value, " - "ArrayRef attrs = {}"> + "bool isConstant, Linkage linkage, StringRef name, " + "Attribute value, ArrayRef attrs = {}"> ]; let extraClassDeclaration = [{ @@ -554,9 +583,12 @@ def LLVM_GlobalOp let verifier = "return ::verify(*this);"; } -def LLVM_LLVMFuncOp : LLVM_ZeroResultOp<"func", - [NativeOpTrait<"IsIsolatedFromAbove">, NativeOpTrait<"FunctionLike">, - Symbol]> { +def LLVM_LLVMFuncOp + : LLVM_ZeroResultOp<"func", + [NativeOpTrait<"IsIsolatedFromAbove">, + NativeOpTrait<"FunctionLike">, Symbol]>, + Arguments<(ins DefaultValuedAttr:$linkage)> { let summary = "LLVM dialect function, has wrapped LLVM IR function type"; let regions = (region AnyRegion:$body); @@ -565,7 +597,8 @@ def LLVM_LLVMFuncOp : LLVM_ZeroResultOp<"func", let builders = [ OpBuilder<"Builder *builder, OperationState 
&result, StringRef name, " - "LLVMType type, ArrayRef attrs = {}, " + "LLVMType type, LLVM::Linkage linkage = LLVM::Linkage::External, " + "ArrayRef attrs = {}, " "ArrayRef argAttrs = {}"> ]; @@ -598,10 +631,7 @@ def LLVM_LLVMFuncOp : LLVM_ZeroResultOp<"func", let verifier = [{ return ::verify(*this); }]; let printer = [{ printLLVMFuncOp(p, *this); }]; - let parser = [{ - return impl::parseFunctionLikeOp(parser, result, /*allowVariadic=*/true, - buildLLVMFunctionType); - }]; + let parser = [{ return parseLLVMFuncOp(parser, result); }]; } def LLVM_NullOp diff --git a/third_party/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt b/third_party/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt index b175e9ad044..2a883a138a5 100644 --- a/third_party/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt +++ b/third_party/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt @@ -1,7 +1,4 @@ -set(LLVM_TARGET_DEFINITIONS LinalgOps.td) -mlir_tablegen(LinalgOps.h.inc -gen-op-decls) -mlir_tablegen(LinalgOps.cpp.inc -gen-op-defs) -add_public_tablegen_target(MLIRLinalgOpsIncGen) +add_mlir_dialect(LinalgOps) set(LLVM_TARGET_DEFINITIONS LinalgLibraryOps.td) mlir_tablegen(LinalgLibraryOps.h.inc -gen-op-decls) mlir_tablegen(LinalgLibraryOps.cpp.inc -gen-op-defs) diff --git a/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td b/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td index 92b325b5943..afaf039ffd5 100644 --- a/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td +++ b/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td @@ -368,7 +368,7 @@ def ConvOp : LinalgLibrary_Op<"conv", [NInputsAndOutputs<2, 1>]> { class GenericOpBase : LinalgLibraryBase_Op { let arguments = (ins Variadic:$views, AffineMapArrayAttr:$indexing_maps, - I64ArrayAttr:$n_loop_types, + ArrayAttr:$iterator_types, I64ArrayAttr:$n_views, OptionalAttr:$doc, OptionalAttr:$fun, @@ -377,7 +377,7 @@ class GenericOpBase : LinalgLibraryBase_Op { let extraClassDeclaration = [{ SmallVector linalgTraitAttrNames() { return SmallVector{ - "doc", "fun", "indexing_maps", "library_call", "n_loop_types", "n_views" + "doc", "fun", "indexing_maps", "library_call", "iterator_types", "n_views" }; } unsigned getNumInputs() { @@ -395,26 +395,35 @@ class GenericOpBase : LinalgLibraryBase_Op { return val.getZExtValue(); } unsigned getNumParallelLoops() { - if (!getAttr("n_loop_types") || n_loop_types().getValue().size() != 3) + if (!getAttr("iterator_types") || iterator_types().getValue().size() == 0) return 0; - auto val = n_loop_types().getValue()[0].cast().getValue(); - assert(val.getSExtValue() >= 0); - return val.getZExtValue(); + unsigned nPar = 0; + for (auto ty : iterator_types()) { + if (ty.cast().getValue() == "parallel") + nPar++; + } + return nPar; } unsigned getNumReductionLoops() { - if (!getAttr("n_loop_types") || n_loop_types().getValue().size() != 3) + if (!getAttr("iterator_types") || iterator_types().getValue().size() == 0) return 0; - auto val = n_loop_types().getValue()[1].cast().getValue(); - assert(val.getSExtValue() >= 0); - return val.getZExtValue(); - } - unsigned getNumWindowLoops() { - if (!getAttr("n_loop_types") || n_loop_types().getValue().size() != 3) + unsigned nRed = 0; + for (auto ty : iterator_types()) { + if (ty.cast().getValue() == "reduction") + nRed++; + } + return nRed; + } + unsigned getNumWindowLoops() { + if (!getAttr("iterator_types") || iterator_types().getValue().size() == 0) return 0; - auto val = n_loop_types().getValue()[2].cast().getValue(); - 
assert(val.getSExtValue() >= 0); - return val.getZExtValue(); - } + unsigned nWin = 0; + for (auto ty : iterator_types()) { + if (ty.cast().getValue() == "window") + nWin++; + } + return nWin; + } unsigned getNumLoops() { return getNumParallelLoops() + getNumReductionLoops() + getNumWindowLoops(); @@ -474,8 +483,9 @@ def GenericOp : GenericOpBase<"generic"> { The external library is assumed to be dynamically linked and no strong compile-time guarantees are provided. In the absence of such a library call, linalg.generic will always lower to loops. - - n_loops: a triple of I64Attr representing the number of enclosing - [parallel, reduction, window] loops respectively. + - iterator_types: an ArrayAttr they type of the enclosing loops; Each element of + the list represents and iterator of one of the following types: + parallel, reduction, window - n_views: a pair of I64Attr representing the number of input (readonly) and output (readwrite) views. @@ -498,7 +508,7 @@ def GenericOp : GenericOpBase<"generic"> { indexing_maps = #matmul_accesses, library_call = "linalg_matmul", n_views = [2, 1], - n_loop_types = [2, 1, 0] + iterator_types = ["parallel", "parallel", "reduction"] } ``` @@ -557,8 +567,8 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> { To support inplace updates in a generic fashion, the signature of the function must be: ``` - fun([input views element types], [output views element types]) - -> ([output views element types]) + fun([index types for induction variables], [input views element types], + [output views element types]) -> ([output views element types]) ``` - indexing_maps: a list of AffineMapAttr, one AffineMapAttr per each input and output view. Such AffineMapAttr specifies the mapping between the @@ -568,15 +578,16 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> { maps to. The external library is assumed to be dynamically linked and no strong compile-time guarantees are provided. In the absence of such a library call, linalg.indexed_generic will always lower to loops. - - n_loops: a triple of I64Attr representing the number of enclosing - [parallel, reduction, window] loops respectively. + - iterator_types: an ArrayAttr they type of the enclosing loops; Each element of + the list represents and iterator of one of the following types: + parallel, reduction, window - n_views: a pair of I64Attr representing the number of input (readonly) and output (readwrite) views. Example: Defining a #matmul_trait attribute in MLIR can be done as follows: ```mlir - func @fma(%a: f32, %b: f32, %c: f32) -> f32 { + func @fma(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32) -> f32 { %d = mulf %a, %b: f32 %e = addf %c, %d: f32 return %e: f32 @@ -592,7 +603,7 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> { indexing_maps = #matmul_accesses, library_call = "linalg_matmul", n_views = [2, 1], - n_loop_types = [2, 1, 0] + iterator_types = ["parallel", "parallel", "reduction"] } ``` diff --git a/third_party/mlir/include/mlir/Dialect/Linalg/Passes.h b/third_party/mlir/include/mlir/Dialect/Linalg/Passes.h index 5ecd50070da..7ae3877f01e 100644 --- a/third_party/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/third_party/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -39,9 +39,16 @@ createLinalgTilingPass(ArrayRef tileSizes = {}); std::unique_ptr> createLinalgPromotionPass(bool dynamicBuffers); -std::unique_ptr> createLowerLinalgToLoopsPass(); +/// Create a pass to convert Linalg operations to loop.for loops and +/// std.load/std.store accesses. 
+std::unique_ptr> createConvertLinalgToLoopsPass(); -/// Create a pass to convert vector operations to the LLVMIR dialect. +/// Create a pass to convert Linalg operations to affine.for loops and +/// affine_load/affine_store accesses. +/// Placeholder for now, this is NYI. +std::unique_ptr> createConvertLinalgToAffineLoopsPass(); + +/// Create a pass to convert Linalg operations to the LLVMIR dialect. std::unique_ptr> createConvertLinalgToLLVMPass(); } // namespace linalg diff --git a/third_party/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td b/third_party/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td index d243bb23f2c..8bc0eaf2097 100644 --- a/third_party/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td +++ b/third_party/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td @@ -62,4 +62,15 @@ class TileLinalgOp sizes, string value> : NativeCodeCall< StrJoinInt.result # "}, \"" # value # "\")))" # " return matchFailure();">; +//===----------------------------------------------------------------------===// +// Linalg to loop patterns. +//===----------------------------------------------------------------------===// +class LinalgOpToLoops : NativeCodeCall< + "if (failed(linalgOpToLoops<" # OpType # ">($_builder, $0))) " # + " return matchFailure();">; + +class LinalgOpToAffineLoops : NativeCodeCall< + "if (failed(linalgOpToAffineLoops<" # OpType # ">($_builder, $0))) " # + " return matchFailure();">; + #endif // LINALG_TRANSFORMS diff --git a/third_party/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h b/third_party/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h index 56ae94f32c6..966b8f93135 100644 --- a/third_party/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h +++ b/third_party/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h @@ -35,20 +35,6 @@ struct LinalgTransforms { static const StringLiteral kLinalgTransformMarker; }; -// Declarative transformation used in tablegen patterns. -// Tiles `op` by `sizes` and sets the attribute `kLinalgTransformMarker` to -// `linalgMarker`. -LogicalResult tileLinalgOpAndSetMarker(PatternRewriter &rewriter, Operation *op, - ArrayRef sizes, - StringRef linalgMarker); - -// Declarative transformation used in tablegen patterns. -// Tiles `op` by `sizes`, fuses the producers of `operandIndicesToFuse` and sets -// the attribute `kLinalgTransformMarker` to `linalgMarker`. -LogicalResult tileAndFuseLinalgOpAndSetMarker( - PatternRewriter &rewriter, Operation *op, ArrayRef sizes, - ArrayRef operandIndicesToFuse, StringRef linalgMarker); - namespace detail { // Implementation detail of isProducedByOpOfType avoids the need for explicit // template instantiations. @@ -65,6 +51,33 @@ bool isProducedByOpOfType(Operation *consumerOp, Value *consumedView) { consumerOp, consumedView, [](Operation *op) { return isa(op); }); } +//////////////////////////////////////////////////////////////////////////////// +// The following Declarative Rewrite Rule (DRR) helpers are used in rewrite +// patterns. As such, they must not call into `rewriter.erase/replace` APIs and +// it is the responsibility of the enclosing PatternRewriter to erase on +// success. +//////////////////////////////////////////////////////////////////////////////// + +// Tiles `op` by `sizes` and sets the attribute `kLinalgTransformMarker` to +// `linalgMarker`. 
+LogicalResult tileLinalgOpAndSetMarker(PatternRewriter &rewriter, Operation *op, + ArrayRef sizes, + StringRef linalgMarker); + +// Tiles `op` by `sizes`, fuses the producers of `operandIndicesToFuse` and sets +// the attribute `kLinalgTransformMarker` to `linalgMarker`. +LogicalResult tileAndFuseLinalgOpAndSetMarker( + PatternRewriter &rewriter, Operation *op, ArrayRef sizes, + ArrayRef operandIndicesToFuse, StringRef linalgMarker); + +// Emits a loop nest of `loop.for` with the proper body for `op`. +template +LogicalResult linalgOpToLoops(PatternRewriter &rewriter, Operation *op); + +// Emits a loop nest of `affine.for` with the proper body for `op`. +template +LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter, Operation *op); + } // namespace linalg } // namespace mlir diff --git a/third_party/mlir/include/mlir/Dialect/LoopOps/CMakeLists.txt b/third_party/mlir/include/mlir/Dialect/LoopOps/CMakeLists.txt index 2d699580c04..9f5863f2be9 100644 --- a/third_party/mlir/include/mlir/Dialect/LoopOps/CMakeLists.txt +++ b/third_party/mlir/include/mlir/Dialect/LoopOps/CMakeLists.txt @@ -1,4 +1 @@ -set(LLVM_TARGET_DEFINITIONS LoopOps.td) -mlir_tablegen(LoopOps.h.inc -gen-op-decls) -mlir_tablegen(LoopOps.cpp.inc -gen-op-defs) -add_public_tablegen_target(MLIRLoopOpsIncGen) +add_mlir_dialect(LoopOps) diff --git a/third_party/mlir/include/mlir/Dialect/QuantOps/CMakeLists.txt b/third_party/mlir/include/mlir/Dialect/QuantOps/CMakeLists.txt index 3e3b9462b88..f95532ecf6e 100644 --- a/third_party/mlir/include/mlir/Dialect/QuantOps/CMakeLists.txt +++ b/third_party/mlir/include/mlir/Dialect/QuantOps/CMakeLists.txt @@ -1,4 +1 @@ -set(LLVM_TARGET_DEFINITIONS QuantOps.td) -mlir_tablegen(QuantOps.h.inc -gen-op-decls) -mlir_tablegen(QuantOps.cpp.inc -gen-op-defs) -add_public_tablegen_target(MLIRQuantOpsIncGen) +add_mlir_dialect(QuantOps) diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt b/third_party/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt index c18d6534261..b6759a9111b 100644 --- a/third_party/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt +++ b/third_party/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt @@ -3,10 +3,7 @@ mlir_tablegen(SPIRVLowering.h.inc -gen-struct-attr-decls) mlir_tablegen(SPIRVLowering.cpp.inc -gen-struct-attr-defs) add_public_tablegen_target(MLIRSPIRVLoweringStructGen) -set(LLVM_TARGET_DEFINITIONS SPIRVOps.td) -mlir_tablegen(SPIRVOps.h.inc -gen-op-decls) -mlir_tablegen(SPIRVOps.cpp.inc -gen-op-defs) -add_public_tablegen_target(MLIRSPIRVOpsIncGen) +add_mlir_dialect(SPIRVOps) set(LLVM_TARGET_DEFINITIONS SPIRVBase.td) mlir_tablegen(SPIRVEnums.h.inc -gen-enum-decls) diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td index cbcd9303626..00ce72f5b2a 100644 --- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td +++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td @@ -292,6 +292,8 @@ def SPV_IAddOp : SPV_ArithmeticBinaryOp<"IAdd", SPV_Integer, [Commutative]> { ``` }]; + + let hasFolder = 1; } // ----- @@ -328,6 +330,8 @@ def SPV_IMulOp : SPV_ArithmeticBinaryOp<"IMul", SPV_Integer, [Commutative]> { ``` }]; + + let hasFolder = 1; } // ----- diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td index 07cdd7ac790..2ee8f3bdd43 100644 --- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ 
-54,236 +54,6 @@ def SPV_Dialect : Dialect { let cppNamespace = "spirv"; } -//===----------------------------------------------------------------------===// -// SPIR-V opcode specification -//===----------------------------------------------------------------------===// - -class SPV_OpCode { - // Name used as reference to retrieve the opcode - string opname = name; - - // Opcode associated with the name - int opcode = val; -} - -// Begin opcode section. Generated from SPIR-V spec; DO NOT MODIFY! - -def SPV_OC_OpNop : I32EnumAttrCase<"OpNop", 0>; -def SPV_OC_OpUndef : I32EnumAttrCase<"OpUndef", 1>; -def SPV_OC_OpSourceContinued : I32EnumAttrCase<"OpSourceContinued", 2>; -def SPV_OC_OpSource : I32EnumAttrCase<"OpSource", 3>; -def SPV_OC_OpSourceExtension : I32EnumAttrCase<"OpSourceExtension", 4>; -def SPV_OC_OpName : I32EnumAttrCase<"OpName", 5>; -def SPV_OC_OpMemberName : I32EnumAttrCase<"OpMemberName", 6>; -def SPV_OC_OpString : I32EnumAttrCase<"OpString", 7>; -def SPV_OC_OpExtension : I32EnumAttrCase<"OpExtension", 10>; -def SPV_OC_OpExtInstImport : I32EnumAttrCase<"OpExtInstImport", 11>; -def SPV_OC_OpExtInst : I32EnumAttrCase<"OpExtInst", 12>; -def SPV_OC_OpMemoryModel : I32EnumAttrCase<"OpMemoryModel", 14>; -def SPV_OC_OpEntryPoint : I32EnumAttrCase<"OpEntryPoint", 15>; -def SPV_OC_OpExecutionMode : I32EnumAttrCase<"OpExecutionMode", 16>; -def SPV_OC_OpCapability : I32EnumAttrCase<"OpCapability", 17>; -def SPV_OC_OpTypeVoid : I32EnumAttrCase<"OpTypeVoid", 19>; -def SPV_OC_OpTypeBool : I32EnumAttrCase<"OpTypeBool", 20>; -def SPV_OC_OpTypeInt : I32EnumAttrCase<"OpTypeInt", 21>; -def SPV_OC_OpTypeFloat : I32EnumAttrCase<"OpTypeFloat", 22>; -def SPV_OC_OpTypeVector : I32EnumAttrCase<"OpTypeVector", 23>; -def SPV_OC_OpTypeArray : I32EnumAttrCase<"OpTypeArray", 28>; -def SPV_OC_OpTypeRuntimeArray : I32EnumAttrCase<"OpTypeRuntimeArray", 29>; -def SPV_OC_OpTypeStruct : I32EnumAttrCase<"OpTypeStruct", 30>; -def SPV_OC_OpTypePointer : I32EnumAttrCase<"OpTypePointer", 32>; -def SPV_OC_OpTypeFunction : I32EnumAttrCase<"OpTypeFunction", 33>; -def SPV_OC_OpConstantTrue : I32EnumAttrCase<"OpConstantTrue", 41>; -def SPV_OC_OpConstantFalse : I32EnumAttrCase<"OpConstantFalse", 42>; -def SPV_OC_OpConstant : I32EnumAttrCase<"OpConstant", 43>; -def SPV_OC_OpConstantComposite : I32EnumAttrCase<"OpConstantComposite", 44>; -def SPV_OC_OpConstantNull : I32EnumAttrCase<"OpConstantNull", 46>; -def SPV_OC_OpSpecConstantTrue : I32EnumAttrCase<"OpSpecConstantTrue", 48>; -def SPV_OC_OpSpecConstantFalse : I32EnumAttrCase<"OpSpecConstantFalse", 49>; -def SPV_OC_OpSpecConstant : I32EnumAttrCase<"OpSpecConstant", 50>; -def SPV_OC_OpSpecConstantComposite : I32EnumAttrCase<"OpSpecConstantComposite", 51>; -def SPV_OC_OpFunction : I32EnumAttrCase<"OpFunction", 54>; -def SPV_OC_OpFunctionParameter : I32EnumAttrCase<"OpFunctionParameter", 55>; -def SPV_OC_OpFunctionEnd : I32EnumAttrCase<"OpFunctionEnd", 56>; -def SPV_OC_OpFunctionCall : I32EnumAttrCase<"OpFunctionCall", 57>; -def SPV_OC_OpVariable : I32EnumAttrCase<"OpVariable", 59>; -def SPV_OC_OpLoad : I32EnumAttrCase<"OpLoad", 61>; -def SPV_OC_OpStore : I32EnumAttrCase<"OpStore", 62>; -def SPV_OC_OpAccessChain : I32EnumAttrCase<"OpAccessChain", 65>; -def SPV_OC_OpDecorate : I32EnumAttrCase<"OpDecorate", 71>; -def SPV_OC_OpMemberDecorate : I32EnumAttrCase<"OpMemberDecorate", 72>; -def SPV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>; -def SPV_OC_OpConvertFToU : I32EnumAttrCase<"OpConvertFToU", 109>; -def SPV_OC_OpConvertFToS : 
I32EnumAttrCase<"OpConvertFToS", 110>; -def SPV_OC_OpConvertSToF : I32EnumAttrCase<"OpConvertSToF", 111>; -def SPV_OC_OpConvertUToF : I32EnumAttrCase<"OpConvertUToF", 112>; -def SPV_OC_OpUConvert : I32EnumAttrCase<"OpUConvert", 113>; -def SPV_OC_OpSConvert : I32EnumAttrCase<"OpSConvert", 114>; -def SPV_OC_OpFConvert : I32EnumAttrCase<"OpFConvert", 115>; -def SPV_OC_OpBitcast : I32EnumAttrCase<"OpBitcast", 124>; -def SPV_OC_OpFNegate : I32EnumAttrCase<"OpFNegate", 127>; -def SPV_OC_OpIAdd : I32EnumAttrCase<"OpIAdd", 128>; -def SPV_OC_OpFAdd : I32EnumAttrCase<"OpFAdd", 129>; -def SPV_OC_OpISub : I32EnumAttrCase<"OpISub", 130>; -def SPV_OC_OpFSub : I32EnumAttrCase<"OpFSub", 131>; -def SPV_OC_OpIMul : I32EnumAttrCase<"OpIMul", 132>; -def SPV_OC_OpFMul : I32EnumAttrCase<"OpFMul", 133>; -def SPV_OC_OpUDiv : I32EnumAttrCase<"OpUDiv", 134>; -def SPV_OC_OpSDiv : I32EnumAttrCase<"OpSDiv", 135>; -def SPV_OC_OpFDiv : I32EnumAttrCase<"OpFDiv", 136>; -def SPV_OC_OpUMod : I32EnumAttrCase<"OpUMod", 137>; -def SPV_OC_OpSRem : I32EnumAttrCase<"OpSRem", 138>; -def SPV_OC_OpSMod : I32EnumAttrCase<"OpSMod", 139>; -def SPV_OC_OpFRem : I32EnumAttrCase<"OpFRem", 140>; -def SPV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>; -def SPV_OC_OpLogicalEqual : I32EnumAttrCase<"OpLogicalEqual", 164>; -def SPV_OC_OpLogicalNotEqual : I32EnumAttrCase<"OpLogicalNotEqual", 165>; -def SPV_OC_OpLogicalOr : I32EnumAttrCase<"OpLogicalOr", 166>; -def SPV_OC_OpLogicalAnd : I32EnumAttrCase<"OpLogicalAnd", 167>; -def SPV_OC_OpLogicalNot : I32EnumAttrCase<"OpLogicalNot", 168>; -def SPV_OC_OpSelect : I32EnumAttrCase<"OpSelect", 169>; -def SPV_OC_OpIEqual : I32EnumAttrCase<"OpIEqual", 170>; -def SPV_OC_OpINotEqual : I32EnumAttrCase<"OpINotEqual", 171>; -def SPV_OC_OpUGreaterThan : I32EnumAttrCase<"OpUGreaterThan", 172>; -def SPV_OC_OpSGreaterThan : I32EnumAttrCase<"OpSGreaterThan", 173>; -def SPV_OC_OpUGreaterThanEqual : I32EnumAttrCase<"OpUGreaterThanEqual", 174>; -def SPV_OC_OpSGreaterThanEqual : I32EnumAttrCase<"OpSGreaterThanEqual", 175>; -def SPV_OC_OpULessThan : I32EnumAttrCase<"OpULessThan", 176>; -def SPV_OC_OpSLessThan : I32EnumAttrCase<"OpSLessThan", 177>; -def SPV_OC_OpULessThanEqual : I32EnumAttrCase<"OpULessThanEqual", 178>; -def SPV_OC_OpSLessThanEqual : I32EnumAttrCase<"OpSLessThanEqual", 179>; -def SPV_OC_OpFOrdEqual : I32EnumAttrCase<"OpFOrdEqual", 180>; -def SPV_OC_OpFUnordEqual : I32EnumAttrCase<"OpFUnordEqual", 181>; -def SPV_OC_OpFOrdNotEqual : I32EnumAttrCase<"OpFOrdNotEqual", 182>; -def SPV_OC_OpFUnordNotEqual : I32EnumAttrCase<"OpFUnordNotEqual", 183>; -def SPV_OC_OpFOrdLessThan : I32EnumAttrCase<"OpFOrdLessThan", 184>; -def SPV_OC_OpFUnordLessThan : I32EnumAttrCase<"OpFUnordLessThan", 185>; -def SPV_OC_OpFOrdGreaterThan : I32EnumAttrCase<"OpFOrdGreaterThan", 186>; -def SPV_OC_OpFUnordGreaterThan : I32EnumAttrCase<"OpFUnordGreaterThan", 187>; -def SPV_OC_OpFOrdLessThanEqual : I32EnumAttrCase<"OpFOrdLessThanEqual", 188>; -def SPV_OC_OpFUnordLessThanEqual : I32EnumAttrCase<"OpFUnordLessThanEqual", 189>; -def SPV_OC_OpFOrdGreaterThanEqual : I32EnumAttrCase<"OpFOrdGreaterThanEqual", 190>; -def SPV_OC_OpFUnordGreaterThanEqual : I32EnumAttrCase<"OpFUnordGreaterThanEqual", 191>; -def SPV_OC_OpShiftRightLogical : I32EnumAttrCase<"OpShiftRightLogical", 194>; -def SPV_OC_OpShiftRightArithmetic : I32EnumAttrCase<"OpShiftRightArithmetic", 195>; -def SPV_OC_OpShiftLeftLogical : I32EnumAttrCase<"OpShiftLeftLogical", 196>; -def SPV_OC_OpBitwiseOr : I32EnumAttrCase<"OpBitwiseOr", 197>; -def SPV_OC_OpBitwiseXor : 
I32EnumAttrCase<"OpBitwiseXor", 198>; -def SPV_OC_OpBitwiseAnd : I32EnumAttrCase<"OpBitwiseAnd", 199>; -def SPV_OC_OpNot : I32EnumAttrCase<"OpNot", 200>; -def SPV_OC_OpBitFieldInsert : I32EnumAttrCase<"OpBitFieldInsert", 201>; -def SPV_OC_OpBitFieldSExtract : I32EnumAttrCase<"OpBitFieldSExtract", 202>; -def SPV_OC_OpBitFieldUExtract : I32EnumAttrCase<"OpBitFieldUExtract", 203>; -def SPV_OC_OpBitReverse : I32EnumAttrCase<"OpBitReverse", 204>; -def SPV_OC_OpBitCount : I32EnumAttrCase<"OpBitCount", 205>; -def SPV_OC_OpControlBarrier : I32EnumAttrCase<"OpControlBarrier", 224>; -def SPV_OC_OpMemoryBarrier : I32EnumAttrCase<"OpMemoryBarrier", 225>; -def SPV_OC_OpPhi : I32EnumAttrCase<"OpPhi", 245>; -def SPV_OC_OpLoopMerge : I32EnumAttrCase<"OpLoopMerge", 246>; -def SPV_OC_OpSelectionMerge : I32EnumAttrCase<"OpSelectionMerge", 247>; -def SPV_OC_OpLabel : I32EnumAttrCase<"OpLabel", 248>; -def SPV_OC_OpBranch : I32EnumAttrCase<"OpBranch", 249>; -def SPV_OC_OpBranchConditional : I32EnumAttrCase<"OpBranchConditional", 250>; -def SPV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>; -def SPV_OC_OpReturnValue : I32EnumAttrCase<"OpReturnValue", 254>; -def SPV_OC_OpUnreachable : I32EnumAttrCase<"OpUnreachable", 255>; -def SPV_OC_OpModuleProcessed : I32EnumAttrCase<"OpModuleProcessed", 330>; - -def SPV_OpcodeAttr : - I32EnumAttr<"Opcode", "valid SPIR-V instructions", [ - SPV_OC_OpNop, SPV_OC_OpUndef, SPV_OC_OpSourceContinued, SPV_OC_OpSource, - SPV_OC_OpSourceExtension, SPV_OC_OpName, SPV_OC_OpMemberName, SPV_OC_OpString, - SPV_OC_OpExtension, SPV_OC_OpExtInstImport, SPV_OC_OpExtInst, - SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint, SPV_OC_OpExecutionMode, - SPV_OC_OpCapability, SPV_OC_OpTypeVoid, SPV_OC_OpTypeBool, SPV_OC_OpTypeInt, - SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeArray, - SPV_OC_OpTypeRuntimeArray, SPV_OC_OpTypeStruct, SPV_OC_OpTypePointer, - SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue, SPV_OC_OpConstantFalse, - SPV_OC_OpConstant, SPV_OC_OpConstantComposite, SPV_OC_OpConstantNull, - SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant, - SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter, - SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad, - SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate, - SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpConvertFToU, - SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF, - SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast, - SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, - SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, - SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, - SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr, - SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual, - SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, - SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, - SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, - SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual, - SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan, - SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan, - SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual, - SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual, - SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic, - 
SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr, SPV_OC_OpBitwiseXor, - SPV_OC_OpBitwiseAnd, SPV_OC_OpNot, SPV_OC_OpBitFieldInsert, - SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract, SPV_OC_OpBitReverse, - SPV_OC_OpBitCount, SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier, - SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, - SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn, - SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpModuleProcessed - ]> { - let returnType = "::mlir::spirv::Opcode"; - let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())"; - let cppNamespace = "::mlir::spirv"; -} - -// End opcode section. Generated from SPIR-V spec; DO NOT MODIFY! - -//===----------------------------------------------------------------------===// -// SPIR-V type definitions -//===----------------------------------------------------------------------===// - -def SPV_IsPtrType : CPred<"$_self.isa<::mlir::spirv::PointerType>()">; -def SPV_IsArrayType : CPred<"$_self.isa<::mlir::spirv::ArrayType>()">; -def SPV_IsRTArrayType : CPred<"$_self.isa<::mlir::spirv::RuntimeArrayType>()">; -def SPV_IsStructType : CPred<"$_self.isa<::mlir::spirv::StructType>()">; - -// See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_types -// for the definition of the following types and type categories. - -def SPV_Void : TypeAlias; -def SPV_Bool : IntOfWidths<[1]>; -def SPV_Integer : IntOfWidths<[8, 16, 32, 64]>; -def SPV_Float : FloatOfWidths<[16, 32, 64]>; -def SPV_Float16or32 : FloatOfWidths<[16, 32]>; -def SPV_Vector : VectorOfLengthAndType<[2, 3, 4], - [SPV_Bool, SPV_Integer, SPV_Float]>; -// Component type check is done in the type parser for the following SPIR-V -// dialect-specific types so we use "Any" here. -def SPV_AnyPtr : Type; -def SPV_AnyArray : Type; -def SPV_AnyRTArray : Type; -def SPV_AnyStruct : Type; - -def SPV_Numerical : AnyTypeOf<[SPV_Integer, SPV_Float]>; -def SPV_Scalar : AnyTypeOf<[SPV_Numerical, SPV_Bool]>; -def SPV_Aggregate : AnyTypeOf<[SPV_AnyArray, SPV_AnyStruct]>; -def SPV_Composite : - AnyTypeOf<[SPV_Vector, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct]>; -def SPV_Type : AnyTypeOf<[ - SPV_Void, SPV_Bool, SPV_Integer, SPV_Float, SPV_Vector, - SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct - ]>; - -class SPV_ScalarOrVectorOf : - AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4], [type]>]>; - -def SPV_ScalarOrVector : AnyTypeOf<[SPV_Scalar, SPV_Vector]>; -def SPV_ScalarOrVectorOrPtr : AnyTypeOf<[SPV_ScalarOrVector, SPV_AnyPtr]>; - -// TODO(antiagainst): Use a more appropriate way to model optional operands -class SPV_Optional : Variadic; - -// TODO(ravishankarm): From 1.4, this should also include Composite type. -def SPV_SelectType : AnyTypeOf<[SPV_Scalar, SPV_Vector, SPV_AnyPtr]>; - //===----------------------------------------------------------------------===// // SPIR-V extension definitions //===----------------------------------------------------------------------===// @@ -316,153 +86,6 @@ def SPV_ExtensionAttr : // Begin enum section. Generated from SPIR-V spec; DO NOT MODIFY! 
-def SPV_AM_Logical : I32EnumAttrCase<"Logical", 0>; -def SPV_AM_Physical32 : I32EnumAttrCase<"Physical32", 1>; -def SPV_AM_Physical64 : I32EnumAttrCase<"Physical64", 2>; -def SPV_AM_PhysicalStorageBuffer64 : I32EnumAttrCase<"PhysicalStorageBuffer64", 5348>; - -def SPV_AddressingModelAttr : - I32EnumAttr<"AddressingModel", "valid SPIR-V AddressingModel", [ - SPV_AM_Logical, SPV_AM_Physical32, SPV_AM_Physical64, - SPV_AM_PhysicalStorageBuffer64 - ]> { - let cppNamespace = "::mlir::spirv"; -} - -def SPV_BI_Position : I32EnumAttrCase<"Position", 0>; -def SPV_BI_PointSize : I32EnumAttrCase<"PointSize", 1>; -def SPV_BI_ClipDistance : I32EnumAttrCase<"ClipDistance", 3>; -def SPV_BI_CullDistance : I32EnumAttrCase<"CullDistance", 4>; -def SPV_BI_VertexId : I32EnumAttrCase<"VertexId", 5>; -def SPV_BI_InstanceId : I32EnumAttrCase<"InstanceId", 6>; -def SPV_BI_PrimitiveId : I32EnumAttrCase<"PrimitiveId", 7>; -def SPV_BI_InvocationId : I32EnumAttrCase<"InvocationId", 8>; -def SPV_BI_Layer : I32EnumAttrCase<"Layer", 9>; -def SPV_BI_ViewportIndex : I32EnumAttrCase<"ViewportIndex", 10>; -def SPV_BI_TessLevelOuter : I32EnumAttrCase<"TessLevelOuter", 11>; -def SPV_BI_TessLevelInner : I32EnumAttrCase<"TessLevelInner", 12>; -def SPV_BI_TessCoord : I32EnumAttrCase<"TessCoord", 13>; -def SPV_BI_PatchVertices : I32EnumAttrCase<"PatchVertices", 14>; -def SPV_BI_FragCoord : I32EnumAttrCase<"FragCoord", 15>; -def SPV_BI_PointCoord : I32EnumAttrCase<"PointCoord", 16>; -def SPV_BI_FrontFacing : I32EnumAttrCase<"FrontFacing", 17>; -def SPV_BI_SampleId : I32EnumAttrCase<"SampleId", 18>; -def SPV_BI_SamplePosition : I32EnumAttrCase<"SamplePosition", 19>; -def SPV_BI_SampleMask : I32EnumAttrCase<"SampleMask", 20>; -def SPV_BI_FragDepth : I32EnumAttrCase<"FragDepth", 22>; -def SPV_BI_HelperInvocation : I32EnumAttrCase<"HelperInvocation", 23>; -def SPV_BI_NumWorkgroups : I32EnumAttrCase<"NumWorkgroups", 24>; -def SPV_BI_WorkgroupSize : I32EnumAttrCase<"WorkgroupSize", 25>; -def SPV_BI_WorkgroupId : I32EnumAttrCase<"WorkgroupId", 26>; -def SPV_BI_LocalInvocationId : I32EnumAttrCase<"LocalInvocationId", 27>; -def SPV_BI_GlobalInvocationId : I32EnumAttrCase<"GlobalInvocationId", 28>; -def SPV_BI_LocalInvocationIndex : I32EnumAttrCase<"LocalInvocationIndex", 29>; -def SPV_BI_WorkDim : I32EnumAttrCase<"WorkDim", 30>; -def SPV_BI_GlobalSize : I32EnumAttrCase<"GlobalSize", 31>; -def SPV_BI_EnqueuedWorkgroupSize : I32EnumAttrCase<"EnqueuedWorkgroupSize", 32>; -def SPV_BI_GlobalOffset : I32EnumAttrCase<"GlobalOffset", 33>; -def SPV_BI_GlobalLinearId : I32EnumAttrCase<"GlobalLinearId", 34>; -def SPV_BI_SubgroupSize : I32EnumAttrCase<"SubgroupSize", 36>; -def SPV_BI_SubgroupMaxSize : I32EnumAttrCase<"SubgroupMaxSize", 37>; -def SPV_BI_NumSubgroups : I32EnumAttrCase<"NumSubgroups", 38>; -def SPV_BI_NumEnqueuedSubgroups : I32EnumAttrCase<"NumEnqueuedSubgroups", 39>; -def SPV_BI_SubgroupId : I32EnumAttrCase<"SubgroupId", 40>; -def SPV_BI_SubgroupLocalInvocationId : I32EnumAttrCase<"SubgroupLocalInvocationId", 41>; -def SPV_BI_VertexIndex : I32EnumAttrCase<"VertexIndex", 42>; -def SPV_BI_InstanceIndex : I32EnumAttrCase<"InstanceIndex", 43>; -def SPV_BI_SubgroupEqMask : I32EnumAttrCase<"SubgroupEqMask", 4416>; -def SPV_BI_SubgroupGeMask : I32EnumAttrCase<"SubgroupGeMask", 4417>; -def SPV_BI_SubgroupGtMask : I32EnumAttrCase<"SubgroupGtMask", 4418>; -def SPV_BI_SubgroupLeMask : I32EnumAttrCase<"SubgroupLeMask", 4419>; -def SPV_BI_SubgroupLtMask : I32EnumAttrCase<"SubgroupLtMask", 4420>; -def SPV_BI_BaseVertex : 
I32EnumAttrCase<"BaseVertex", 4424>; -def SPV_BI_BaseInstance : I32EnumAttrCase<"BaseInstance", 4425>; -def SPV_BI_DrawIndex : I32EnumAttrCase<"DrawIndex", 4426>; -def SPV_BI_DeviceIndex : I32EnumAttrCase<"DeviceIndex", 4438>; -def SPV_BI_ViewIndex : I32EnumAttrCase<"ViewIndex", 4440>; -def SPV_BI_BaryCoordNoPerspAMD : I32EnumAttrCase<"BaryCoordNoPerspAMD", 4992>; -def SPV_BI_BaryCoordNoPerspCentroidAMD : I32EnumAttrCase<"BaryCoordNoPerspCentroidAMD", 4993>; -def SPV_BI_BaryCoordNoPerspSampleAMD : I32EnumAttrCase<"BaryCoordNoPerspSampleAMD", 4994>; -def SPV_BI_BaryCoordSmoothAMD : I32EnumAttrCase<"BaryCoordSmoothAMD", 4995>; -def SPV_BI_BaryCoordSmoothCentroidAMD : I32EnumAttrCase<"BaryCoordSmoothCentroidAMD", 4996>; -def SPV_BI_BaryCoordSmoothSampleAMD : I32EnumAttrCase<"BaryCoordSmoothSampleAMD", 4997>; -def SPV_BI_BaryCoordPullModelAMD : I32EnumAttrCase<"BaryCoordPullModelAMD", 4998>; -def SPV_BI_FragStencilRefEXT : I32EnumAttrCase<"FragStencilRefEXT", 5014>; -def SPV_BI_ViewportMaskNV : I32EnumAttrCase<"ViewportMaskNV", 5253>; -def SPV_BI_SecondaryPositionNV : I32EnumAttrCase<"SecondaryPositionNV", 5257>; -def SPV_BI_SecondaryViewportMaskNV : I32EnumAttrCase<"SecondaryViewportMaskNV", 5258>; -def SPV_BI_PositionPerViewNV : I32EnumAttrCase<"PositionPerViewNV", 5261>; -def SPV_BI_ViewportMaskPerViewNV : I32EnumAttrCase<"ViewportMaskPerViewNV", 5262>; -def SPV_BI_FullyCoveredEXT : I32EnumAttrCase<"FullyCoveredEXT", 5264>; -def SPV_BI_TaskCountNV : I32EnumAttrCase<"TaskCountNV", 5274>; -def SPV_BI_PrimitiveCountNV : I32EnumAttrCase<"PrimitiveCountNV", 5275>; -def SPV_BI_PrimitiveIndicesNV : I32EnumAttrCase<"PrimitiveIndicesNV", 5276>; -def SPV_BI_ClipDistancePerViewNV : I32EnumAttrCase<"ClipDistancePerViewNV", 5277>; -def SPV_BI_CullDistancePerViewNV : I32EnumAttrCase<"CullDistancePerViewNV", 5278>; -def SPV_BI_LayerPerViewNV : I32EnumAttrCase<"LayerPerViewNV", 5279>; -def SPV_BI_MeshViewCountNV : I32EnumAttrCase<"MeshViewCountNV", 5280>; -def SPV_BI_MeshViewIndicesNV : I32EnumAttrCase<"MeshViewIndicesNV", 5281>; -def SPV_BI_BaryCoordNV : I32EnumAttrCase<"BaryCoordNV", 5286>; -def SPV_BI_BaryCoordNoPerspNV : I32EnumAttrCase<"BaryCoordNoPerspNV", 5287>; -def SPV_BI_FragSizeEXT : I32EnumAttrCase<"FragSizeEXT", 5292>; -def SPV_BI_FragInvocationCountEXT : I32EnumAttrCase<"FragInvocationCountEXT", 5293>; -def SPV_BI_LaunchIdNV : I32EnumAttrCase<"LaunchIdNV", 5319>; -def SPV_BI_LaunchSizeNV : I32EnumAttrCase<"LaunchSizeNV", 5320>; -def SPV_BI_WorldRayOriginNV : I32EnumAttrCase<"WorldRayOriginNV", 5321>; -def SPV_BI_WorldRayDirectionNV : I32EnumAttrCase<"WorldRayDirectionNV", 5322>; -def SPV_BI_ObjectRayOriginNV : I32EnumAttrCase<"ObjectRayOriginNV", 5323>; -def SPV_BI_ObjectRayDirectionNV : I32EnumAttrCase<"ObjectRayDirectionNV", 5324>; -def SPV_BI_RayTminNV : I32EnumAttrCase<"RayTminNV", 5325>; -def SPV_BI_RayTmaxNV : I32EnumAttrCase<"RayTmaxNV", 5326>; -def SPV_BI_InstanceCustomIndexNV : I32EnumAttrCase<"InstanceCustomIndexNV", 5327>; -def SPV_BI_ObjectToWorldNV : I32EnumAttrCase<"ObjectToWorldNV", 5330>; -def SPV_BI_WorldToObjectNV : I32EnumAttrCase<"WorldToObjectNV", 5331>; -def SPV_BI_HitTNV : I32EnumAttrCase<"HitTNV", 5332>; -def SPV_BI_HitKindNV : I32EnumAttrCase<"HitKindNV", 5333>; -def SPV_BI_IncomingRayFlagsNV : I32EnumAttrCase<"IncomingRayFlagsNV", 5351>; -def SPV_BI_WarpsPerSMNV : I32EnumAttrCase<"WarpsPerSMNV", 5374>; -def SPV_BI_SMCountNV : I32EnumAttrCase<"SMCountNV", 5375>; -def SPV_BI_WarpIDNV : I32EnumAttrCase<"WarpIDNV", 5376>; -def SPV_BI_SMIDNV : I32EnumAttrCase<"SMIDNV", 
5377>; - -def SPV_BuiltInAttr : - I32EnumAttr<"BuiltIn", "valid SPIR-V BuiltIn", [ - SPV_BI_Position, SPV_BI_PointSize, SPV_BI_ClipDistance, SPV_BI_CullDistance, - SPV_BI_VertexId, SPV_BI_InstanceId, SPV_BI_PrimitiveId, SPV_BI_InvocationId, - SPV_BI_Layer, SPV_BI_ViewportIndex, SPV_BI_TessLevelOuter, - SPV_BI_TessLevelInner, SPV_BI_TessCoord, SPV_BI_PatchVertices, - SPV_BI_FragCoord, SPV_BI_PointCoord, SPV_BI_FrontFacing, SPV_BI_SampleId, - SPV_BI_SamplePosition, SPV_BI_SampleMask, SPV_BI_FragDepth, - SPV_BI_HelperInvocation, SPV_BI_NumWorkgroups, SPV_BI_WorkgroupSize, - SPV_BI_WorkgroupId, SPV_BI_LocalInvocationId, SPV_BI_GlobalInvocationId, - SPV_BI_LocalInvocationIndex, SPV_BI_WorkDim, SPV_BI_GlobalSize, - SPV_BI_EnqueuedWorkgroupSize, SPV_BI_GlobalOffset, SPV_BI_GlobalLinearId, - SPV_BI_SubgroupSize, SPV_BI_SubgroupMaxSize, SPV_BI_NumSubgroups, - SPV_BI_NumEnqueuedSubgroups, SPV_BI_SubgroupId, - SPV_BI_SubgroupLocalInvocationId, SPV_BI_VertexIndex, SPV_BI_InstanceIndex, - SPV_BI_SubgroupEqMask, SPV_BI_SubgroupGeMask, SPV_BI_SubgroupGtMask, - SPV_BI_SubgroupLeMask, SPV_BI_SubgroupLtMask, SPV_BI_BaseVertex, - SPV_BI_BaseInstance, SPV_BI_DrawIndex, SPV_BI_DeviceIndex, SPV_BI_ViewIndex, - SPV_BI_BaryCoordNoPerspAMD, SPV_BI_BaryCoordNoPerspCentroidAMD, - SPV_BI_BaryCoordNoPerspSampleAMD, SPV_BI_BaryCoordSmoothAMD, - SPV_BI_BaryCoordSmoothCentroidAMD, SPV_BI_BaryCoordSmoothSampleAMD, - SPV_BI_BaryCoordPullModelAMD, SPV_BI_FragStencilRefEXT, SPV_BI_ViewportMaskNV, - SPV_BI_SecondaryPositionNV, SPV_BI_SecondaryViewportMaskNV, - SPV_BI_PositionPerViewNV, SPV_BI_ViewportMaskPerViewNV, SPV_BI_FullyCoveredEXT, - SPV_BI_TaskCountNV, SPV_BI_PrimitiveCountNV, SPV_BI_PrimitiveIndicesNV, - SPV_BI_ClipDistancePerViewNV, SPV_BI_CullDistancePerViewNV, - SPV_BI_LayerPerViewNV, SPV_BI_MeshViewCountNV, SPV_BI_MeshViewIndicesNV, - SPV_BI_BaryCoordNV, SPV_BI_BaryCoordNoPerspNV, SPV_BI_FragSizeEXT, - SPV_BI_FragInvocationCountEXT, SPV_BI_LaunchIdNV, SPV_BI_LaunchSizeNV, - SPV_BI_WorldRayOriginNV, SPV_BI_WorldRayDirectionNV, SPV_BI_ObjectRayOriginNV, - SPV_BI_ObjectRayDirectionNV, SPV_BI_RayTminNV, SPV_BI_RayTmaxNV, - SPV_BI_InstanceCustomIndexNV, SPV_BI_ObjectToWorldNV, SPV_BI_WorldToObjectNV, - SPV_BI_HitTNV, SPV_BI_HitKindNV, SPV_BI_IncomingRayFlagsNV, - SPV_BI_WarpsPerSMNV, SPV_BI_SMCountNV, SPV_BI_WarpIDNV, SPV_BI_SMIDNV - ]> { - let cppNamespace = "::mlir::spirv"; -} - def SPV_C_Matrix : I32EnumAttrCase<"Matrix", 0>; def SPV_C_Shader : I32EnumAttrCase<"Shader", 1>; def SPV_C_Geometry : I32EnumAttrCase<"Geometry", 2>; @@ -671,6 +294,153 @@ def SPV_CapabilityAttr : let cppNamespace = "::mlir::spirv"; } +def SPV_AM_Logical : I32EnumAttrCase<"Logical", 0>; +def SPV_AM_Physical32 : I32EnumAttrCase<"Physical32", 1>; +def SPV_AM_Physical64 : I32EnumAttrCase<"Physical64", 2>; +def SPV_AM_PhysicalStorageBuffer64 : I32EnumAttrCase<"PhysicalStorageBuffer64", 5348>; + +def SPV_AddressingModelAttr : + I32EnumAttr<"AddressingModel", "valid SPIR-V AddressingModel", [ + SPV_AM_Logical, SPV_AM_Physical32, SPV_AM_Physical64, + SPV_AM_PhysicalStorageBuffer64 + ]> { + let cppNamespace = "::mlir::spirv"; +} + +def SPV_BI_Position : I32EnumAttrCase<"Position", 0>; +def SPV_BI_PointSize : I32EnumAttrCase<"PointSize", 1>; +def SPV_BI_ClipDistance : I32EnumAttrCase<"ClipDistance", 3>; +def SPV_BI_CullDistance : I32EnumAttrCase<"CullDistance", 4>; +def SPV_BI_VertexId : I32EnumAttrCase<"VertexId", 5>; +def SPV_BI_InstanceId : I32EnumAttrCase<"InstanceId", 6>; +def SPV_BI_PrimitiveId : I32EnumAttrCase<"PrimitiveId", 7>; +def 
SPV_BI_InvocationId : I32EnumAttrCase<"InvocationId", 8>; +def SPV_BI_Layer : I32EnumAttrCase<"Layer", 9>; +def SPV_BI_ViewportIndex : I32EnumAttrCase<"ViewportIndex", 10>; +def SPV_BI_TessLevelOuter : I32EnumAttrCase<"TessLevelOuter", 11>; +def SPV_BI_TessLevelInner : I32EnumAttrCase<"TessLevelInner", 12>; +def SPV_BI_TessCoord : I32EnumAttrCase<"TessCoord", 13>; +def SPV_BI_PatchVertices : I32EnumAttrCase<"PatchVertices", 14>; +def SPV_BI_FragCoord : I32EnumAttrCase<"FragCoord", 15>; +def SPV_BI_PointCoord : I32EnumAttrCase<"PointCoord", 16>; +def SPV_BI_FrontFacing : I32EnumAttrCase<"FrontFacing", 17>; +def SPV_BI_SampleId : I32EnumAttrCase<"SampleId", 18>; +def SPV_BI_SamplePosition : I32EnumAttrCase<"SamplePosition", 19>; +def SPV_BI_SampleMask : I32EnumAttrCase<"SampleMask", 20>; +def SPV_BI_FragDepth : I32EnumAttrCase<"FragDepth", 22>; +def SPV_BI_HelperInvocation : I32EnumAttrCase<"HelperInvocation", 23>; +def SPV_BI_NumWorkgroups : I32EnumAttrCase<"NumWorkgroups", 24>; +def SPV_BI_WorkgroupSize : I32EnumAttrCase<"WorkgroupSize", 25>; +def SPV_BI_WorkgroupId : I32EnumAttrCase<"WorkgroupId", 26>; +def SPV_BI_LocalInvocationId : I32EnumAttrCase<"LocalInvocationId", 27>; +def SPV_BI_GlobalInvocationId : I32EnumAttrCase<"GlobalInvocationId", 28>; +def SPV_BI_LocalInvocationIndex : I32EnumAttrCase<"LocalInvocationIndex", 29>; +def SPV_BI_WorkDim : I32EnumAttrCase<"WorkDim", 30>; +def SPV_BI_GlobalSize : I32EnumAttrCase<"GlobalSize", 31>; +def SPV_BI_EnqueuedWorkgroupSize : I32EnumAttrCase<"EnqueuedWorkgroupSize", 32>; +def SPV_BI_GlobalOffset : I32EnumAttrCase<"GlobalOffset", 33>; +def SPV_BI_GlobalLinearId : I32EnumAttrCase<"GlobalLinearId", 34>; +def SPV_BI_SubgroupSize : I32EnumAttrCase<"SubgroupSize", 36>; +def SPV_BI_SubgroupMaxSize : I32EnumAttrCase<"SubgroupMaxSize", 37>; +def SPV_BI_NumSubgroups : I32EnumAttrCase<"NumSubgroups", 38>; +def SPV_BI_NumEnqueuedSubgroups : I32EnumAttrCase<"NumEnqueuedSubgroups", 39>; +def SPV_BI_SubgroupId : I32EnumAttrCase<"SubgroupId", 40>; +def SPV_BI_SubgroupLocalInvocationId : I32EnumAttrCase<"SubgroupLocalInvocationId", 41>; +def SPV_BI_VertexIndex : I32EnumAttrCase<"VertexIndex", 42>; +def SPV_BI_InstanceIndex : I32EnumAttrCase<"InstanceIndex", 43>; +def SPV_BI_SubgroupEqMask : I32EnumAttrCase<"SubgroupEqMask", 4416>; +def SPV_BI_SubgroupGeMask : I32EnumAttrCase<"SubgroupGeMask", 4417>; +def SPV_BI_SubgroupGtMask : I32EnumAttrCase<"SubgroupGtMask", 4418>; +def SPV_BI_SubgroupLeMask : I32EnumAttrCase<"SubgroupLeMask", 4419>; +def SPV_BI_SubgroupLtMask : I32EnumAttrCase<"SubgroupLtMask", 4420>; +def SPV_BI_BaseVertex : I32EnumAttrCase<"BaseVertex", 4424>; +def SPV_BI_BaseInstance : I32EnumAttrCase<"BaseInstance", 4425>; +def SPV_BI_DrawIndex : I32EnumAttrCase<"DrawIndex", 4426>; +def SPV_BI_DeviceIndex : I32EnumAttrCase<"DeviceIndex", 4438>; +def SPV_BI_ViewIndex : I32EnumAttrCase<"ViewIndex", 4440>; +def SPV_BI_BaryCoordNoPerspAMD : I32EnumAttrCase<"BaryCoordNoPerspAMD", 4992>; +def SPV_BI_BaryCoordNoPerspCentroidAMD : I32EnumAttrCase<"BaryCoordNoPerspCentroidAMD", 4993>; +def SPV_BI_BaryCoordNoPerspSampleAMD : I32EnumAttrCase<"BaryCoordNoPerspSampleAMD", 4994>; +def SPV_BI_BaryCoordSmoothAMD : I32EnumAttrCase<"BaryCoordSmoothAMD", 4995>; +def SPV_BI_BaryCoordSmoothCentroidAMD : I32EnumAttrCase<"BaryCoordSmoothCentroidAMD", 4996>; +def SPV_BI_BaryCoordSmoothSampleAMD : I32EnumAttrCase<"BaryCoordSmoothSampleAMD", 4997>; +def SPV_BI_BaryCoordPullModelAMD : I32EnumAttrCase<"BaryCoordPullModelAMD", 4998>; +def SPV_BI_FragStencilRefEXT : 
I32EnumAttrCase<"FragStencilRefEXT", 5014>; +def SPV_BI_ViewportMaskNV : I32EnumAttrCase<"ViewportMaskNV", 5253>; +def SPV_BI_SecondaryPositionNV : I32EnumAttrCase<"SecondaryPositionNV", 5257>; +def SPV_BI_SecondaryViewportMaskNV : I32EnumAttrCase<"SecondaryViewportMaskNV", 5258>; +def SPV_BI_PositionPerViewNV : I32EnumAttrCase<"PositionPerViewNV", 5261>; +def SPV_BI_ViewportMaskPerViewNV : I32EnumAttrCase<"ViewportMaskPerViewNV", 5262>; +def SPV_BI_FullyCoveredEXT : I32EnumAttrCase<"FullyCoveredEXT", 5264>; +def SPV_BI_TaskCountNV : I32EnumAttrCase<"TaskCountNV", 5274>; +def SPV_BI_PrimitiveCountNV : I32EnumAttrCase<"PrimitiveCountNV", 5275>; +def SPV_BI_PrimitiveIndicesNV : I32EnumAttrCase<"PrimitiveIndicesNV", 5276>; +def SPV_BI_ClipDistancePerViewNV : I32EnumAttrCase<"ClipDistancePerViewNV", 5277>; +def SPV_BI_CullDistancePerViewNV : I32EnumAttrCase<"CullDistancePerViewNV", 5278>; +def SPV_BI_LayerPerViewNV : I32EnumAttrCase<"LayerPerViewNV", 5279>; +def SPV_BI_MeshViewCountNV : I32EnumAttrCase<"MeshViewCountNV", 5280>; +def SPV_BI_MeshViewIndicesNV : I32EnumAttrCase<"MeshViewIndicesNV", 5281>; +def SPV_BI_BaryCoordNV : I32EnumAttrCase<"BaryCoordNV", 5286>; +def SPV_BI_BaryCoordNoPerspNV : I32EnumAttrCase<"BaryCoordNoPerspNV", 5287>; +def SPV_BI_FragSizeEXT : I32EnumAttrCase<"FragSizeEXT", 5292>; +def SPV_BI_FragInvocationCountEXT : I32EnumAttrCase<"FragInvocationCountEXT", 5293>; +def SPV_BI_LaunchIdNV : I32EnumAttrCase<"LaunchIdNV", 5319>; +def SPV_BI_LaunchSizeNV : I32EnumAttrCase<"LaunchSizeNV", 5320>; +def SPV_BI_WorldRayOriginNV : I32EnumAttrCase<"WorldRayOriginNV", 5321>; +def SPV_BI_WorldRayDirectionNV : I32EnumAttrCase<"WorldRayDirectionNV", 5322>; +def SPV_BI_ObjectRayOriginNV : I32EnumAttrCase<"ObjectRayOriginNV", 5323>; +def SPV_BI_ObjectRayDirectionNV : I32EnumAttrCase<"ObjectRayDirectionNV", 5324>; +def SPV_BI_RayTminNV : I32EnumAttrCase<"RayTminNV", 5325>; +def SPV_BI_RayTmaxNV : I32EnumAttrCase<"RayTmaxNV", 5326>; +def SPV_BI_InstanceCustomIndexNV : I32EnumAttrCase<"InstanceCustomIndexNV", 5327>; +def SPV_BI_ObjectToWorldNV : I32EnumAttrCase<"ObjectToWorldNV", 5330>; +def SPV_BI_WorldToObjectNV : I32EnumAttrCase<"WorldToObjectNV", 5331>; +def SPV_BI_HitTNV : I32EnumAttrCase<"HitTNV", 5332>; +def SPV_BI_HitKindNV : I32EnumAttrCase<"HitKindNV", 5333>; +def SPV_BI_IncomingRayFlagsNV : I32EnumAttrCase<"IncomingRayFlagsNV", 5351>; +def SPV_BI_WarpsPerSMNV : I32EnumAttrCase<"WarpsPerSMNV", 5374>; +def SPV_BI_SMCountNV : I32EnumAttrCase<"SMCountNV", 5375>; +def SPV_BI_WarpIDNV : I32EnumAttrCase<"WarpIDNV", 5376>; +def SPV_BI_SMIDNV : I32EnumAttrCase<"SMIDNV", 5377>; + +def SPV_BuiltInAttr : + I32EnumAttr<"BuiltIn", "valid SPIR-V BuiltIn", [ + SPV_BI_Position, SPV_BI_PointSize, SPV_BI_ClipDistance, SPV_BI_CullDistance, + SPV_BI_VertexId, SPV_BI_InstanceId, SPV_BI_PrimitiveId, SPV_BI_InvocationId, + SPV_BI_Layer, SPV_BI_ViewportIndex, SPV_BI_TessLevelOuter, + SPV_BI_TessLevelInner, SPV_BI_TessCoord, SPV_BI_PatchVertices, + SPV_BI_FragCoord, SPV_BI_PointCoord, SPV_BI_FrontFacing, SPV_BI_SampleId, + SPV_BI_SamplePosition, SPV_BI_SampleMask, SPV_BI_FragDepth, + SPV_BI_HelperInvocation, SPV_BI_NumWorkgroups, SPV_BI_WorkgroupSize, + SPV_BI_WorkgroupId, SPV_BI_LocalInvocationId, SPV_BI_GlobalInvocationId, + SPV_BI_LocalInvocationIndex, SPV_BI_WorkDim, SPV_BI_GlobalSize, + SPV_BI_EnqueuedWorkgroupSize, SPV_BI_GlobalOffset, SPV_BI_GlobalLinearId, + SPV_BI_SubgroupSize, SPV_BI_SubgroupMaxSize, SPV_BI_NumSubgroups, + SPV_BI_NumEnqueuedSubgroups, SPV_BI_SubgroupId, + 
SPV_BI_SubgroupLocalInvocationId, SPV_BI_VertexIndex, SPV_BI_InstanceIndex, + SPV_BI_SubgroupEqMask, SPV_BI_SubgroupGeMask, SPV_BI_SubgroupGtMask, + SPV_BI_SubgroupLeMask, SPV_BI_SubgroupLtMask, SPV_BI_BaseVertex, + SPV_BI_BaseInstance, SPV_BI_DrawIndex, SPV_BI_DeviceIndex, SPV_BI_ViewIndex, + SPV_BI_BaryCoordNoPerspAMD, SPV_BI_BaryCoordNoPerspCentroidAMD, + SPV_BI_BaryCoordNoPerspSampleAMD, SPV_BI_BaryCoordSmoothAMD, + SPV_BI_BaryCoordSmoothCentroidAMD, SPV_BI_BaryCoordSmoothSampleAMD, + SPV_BI_BaryCoordPullModelAMD, SPV_BI_FragStencilRefEXT, SPV_BI_ViewportMaskNV, + SPV_BI_SecondaryPositionNV, SPV_BI_SecondaryViewportMaskNV, + SPV_BI_PositionPerViewNV, SPV_BI_ViewportMaskPerViewNV, SPV_BI_FullyCoveredEXT, + SPV_BI_TaskCountNV, SPV_BI_PrimitiveCountNV, SPV_BI_PrimitiveIndicesNV, + SPV_BI_ClipDistancePerViewNV, SPV_BI_CullDistancePerViewNV, + SPV_BI_LayerPerViewNV, SPV_BI_MeshViewCountNV, SPV_BI_MeshViewIndicesNV, + SPV_BI_BaryCoordNV, SPV_BI_BaryCoordNoPerspNV, SPV_BI_FragSizeEXT, + SPV_BI_FragInvocationCountEXT, SPV_BI_LaunchIdNV, SPV_BI_LaunchSizeNV, + SPV_BI_WorldRayOriginNV, SPV_BI_WorldRayDirectionNV, SPV_BI_ObjectRayOriginNV, + SPV_BI_ObjectRayDirectionNV, SPV_BI_RayTminNV, SPV_BI_RayTmaxNV, + SPV_BI_InstanceCustomIndexNV, SPV_BI_ObjectToWorldNV, SPV_BI_WorldToObjectNV, + SPV_BI_HitTNV, SPV_BI_HitKindNV, SPV_BI_IncomingRayFlagsNV, + SPV_BI_WarpsPerSMNV, SPV_BI_SMCountNV, SPV_BI_WarpIDNV, SPV_BI_SMIDNV + ]> { + let cppNamespace = "::mlir::spirv"; +} + def SPV_D_RelaxedPrecision : I32EnumAttrCase<"RelaxedPrecision", 0>; def SPV_D_SpecId : I32EnumAttrCase<"SpecId", 1>; def SPV_D_Block : I32EnumAttrCase<"Block", 2>; @@ -1101,7 +871,7 @@ def SPV_StorageClassAttr : // End enum section. Generated from SPIR-V spec; DO NOT MODIFY! -// Enums added manually that are not part of SPIRV spec +// Enums added manually that are not part of SPIR-V spec def SPV_IDI_NoDepth : I32EnumAttrCase<"NoDepth", 0>; def SPV_IDI_IsDepth : I32EnumAttrCase<"IsDepth", 1>; @@ -1141,6 +911,58 @@ def SPV_SamplerUseAttr: let cppNamespace = "::mlir::spirv"; } +//===----------------------------------------------------------------------===// +// SPIR-V type definitions +//===----------------------------------------------------------------------===// + +def SPV_IsPtrType : CPred<"$_self.isa<::mlir::spirv::PointerType>()">; +def SPV_IsArrayType : CPred<"$_self.isa<::mlir::spirv::ArrayType>()">; +def SPV_IsRTArrayType : CPred<"$_self.isa<::mlir::spirv::RuntimeArrayType>()">; +def SPV_IsStructType : CPred<"$_self.isa<::mlir::spirv::StructType>()">; + +// See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_types +// for the definition of the following types and type categories. + +def SPV_Void : TypeAlias; +def SPV_Bool : IntOfWidths<[1]>; +def SPV_Integer : IntOfWidths<[8, 16, 32, 64]>; +def SPV_Float : FloatOfWidths<[16, 32, 64]>; +def SPV_Float16or32 : FloatOfWidths<[16, 32]>; +def SPV_Vector : VectorOfLengthAndType<[2, 3, 4], + [SPV_Bool, SPV_Integer, SPV_Float]>; +// Component type check is done in the type parser for the following SPIR-V +// dialect-specific types so we use "Any" here. 
+def SPV_AnyPtr : Type; +def SPV_AnyArray : Type; +def SPV_AnyRTArray : Type; +def SPV_AnyStruct : Type; + +def SPV_Numerical : AnyTypeOf<[SPV_Integer, SPV_Float]>; +def SPV_Scalar : AnyTypeOf<[SPV_Numerical, SPV_Bool]>; +def SPV_Aggregate : AnyTypeOf<[SPV_AnyArray, SPV_AnyStruct]>; +def SPV_Composite : + AnyTypeOf<[SPV_Vector, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct]>; +def SPV_Type : AnyTypeOf<[ + SPV_Void, SPV_Bool, SPV_Integer, SPV_Float, SPV_Vector, + SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct + ]>; + +class SPV_ScalarOrVectorOf : + AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4], [type]>]>; + +def SPV_ScalarOrVector : AnyTypeOf<[SPV_Scalar, SPV_Vector]>; +def SPV_ScalarOrVectorOrPtr : AnyTypeOf<[SPV_ScalarOrVector, SPV_AnyPtr]>; + +class SPV_Vec4 : VectorOfLengthAndType<[4], [type]>; +def SPV_IntVec4 : SPV_Vec4; +def SPV_I32Vec4 : SPV_Vec4; + +// TODO(antiagainst): Use a more appropriate way to model optional operands +class SPV_Optional : Variadic; + +// TODO(ravishankarm): From 1.4, this should also include Composite type. +def SPV_SelectType : AnyTypeOf<[SPV_Scalar, SPV_Vector, SPV_AnyPtr]>; + //===----------------------------------------------------------------------===// // SPIR-V OpTrait definitions //===----------------------------------------------------------------------===// @@ -1155,6 +977,189 @@ def InModuleScope : PredOpTrait< "op must appear in a 'spv.module' block", CPred<"llvm::isa_and_nonnull($_op.getParentOp())">>; +//===----------------------------------------------------------------------===// +// SPIR-V opcode specification +//===----------------------------------------------------------------------===// + +class SPV_OpCode { + // Name used as reference to retrieve the opcode + string opname = name; + + // Opcode associated with the name + int opcode = val; +} + +// Begin opcode section. Generated from SPIR-V spec; DO NOT MODIFY! 
+ +def SPV_OC_OpNop : I32EnumAttrCase<"OpNop", 0>; +def SPV_OC_OpUndef : I32EnumAttrCase<"OpUndef", 1>; +def SPV_OC_OpSourceContinued : I32EnumAttrCase<"OpSourceContinued", 2>; +def SPV_OC_OpSource : I32EnumAttrCase<"OpSource", 3>; +def SPV_OC_OpSourceExtension : I32EnumAttrCase<"OpSourceExtension", 4>; +def SPV_OC_OpName : I32EnumAttrCase<"OpName", 5>; +def SPV_OC_OpMemberName : I32EnumAttrCase<"OpMemberName", 6>; +def SPV_OC_OpString : I32EnumAttrCase<"OpString", 7>; +def SPV_OC_OpExtension : I32EnumAttrCase<"OpExtension", 10>; +def SPV_OC_OpExtInstImport : I32EnumAttrCase<"OpExtInstImport", 11>; +def SPV_OC_OpExtInst : I32EnumAttrCase<"OpExtInst", 12>; +def SPV_OC_OpMemoryModel : I32EnumAttrCase<"OpMemoryModel", 14>; +def SPV_OC_OpEntryPoint : I32EnumAttrCase<"OpEntryPoint", 15>; +def SPV_OC_OpExecutionMode : I32EnumAttrCase<"OpExecutionMode", 16>; +def SPV_OC_OpCapability : I32EnumAttrCase<"OpCapability", 17>; +def SPV_OC_OpTypeVoid : I32EnumAttrCase<"OpTypeVoid", 19>; +def SPV_OC_OpTypeBool : I32EnumAttrCase<"OpTypeBool", 20>; +def SPV_OC_OpTypeInt : I32EnumAttrCase<"OpTypeInt", 21>; +def SPV_OC_OpTypeFloat : I32EnumAttrCase<"OpTypeFloat", 22>; +def SPV_OC_OpTypeVector : I32EnumAttrCase<"OpTypeVector", 23>; +def SPV_OC_OpTypeArray : I32EnumAttrCase<"OpTypeArray", 28>; +def SPV_OC_OpTypeRuntimeArray : I32EnumAttrCase<"OpTypeRuntimeArray", 29>; +def SPV_OC_OpTypeStruct : I32EnumAttrCase<"OpTypeStruct", 30>; +def SPV_OC_OpTypePointer : I32EnumAttrCase<"OpTypePointer", 32>; +def SPV_OC_OpTypeFunction : I32EnumAttrCase<"OpTypeFunction", 33>; +def SPV_OC_OpConstantTrue : I32EnumAttrCase<"OpConstantTrue", 41>; +def SPV_OC_OpConstantFalse : I32EnumAttrCase<"OpConstantFalse", 42>; +def SPV_OC_OpConstant : I32EnumAttrCase<"OpConstant", 43>; +def SPV_OC_OpConstantComposite : I32EnumAttrCase<"OpConstantComposite", 44>; +def SPV_OC_OpConstantNull : I32EnumAttrCase<"OpConstantNull", 46>; +def SPV_OC_OpSpecConstantTrue : I32EnumAttrCase<"OpSpecConstantTrue", 48>; +def SPV_OC_OpSpecConstantFalse : I32EnumAttrCase<"OpSpecConstantFalse", 49>; +def SPV_OC_OpSpecConstant : I32EnumAttrCase<"OpSpecConstant", 50>; +def SPV_OC_OpSpecConstantComposite : I32EnumAttrCase<"OpSpecConstantComposite", 51>; +def SPV_OC_OpFunction : I32EnumAttrCase<"OpFunction", 54>; +def SPV_OC_OpFunctionParameter : I32EnumAttrCase<"OpFunctionParameter", 55>; +def SPV_OC_OpFunctionEnd : I32EnumAttrCase<"OpFunctionEnd", 56>; +def SPV_OC_OpFunctionCall : I32EnumAttrCase<"OpFunctionCall", 57>; +def SPV_OC_OpVariable : I32EnumAttrCase<"OpVariable", 59>; +def SPV_OC_OpLoad : I32EnumAttrCase<"OpLoad", 61>; +def SPV_OC_OpStore : I32EnumAttrCase<"OpStore", 62>; +def SPV_OC_OpAccessChain : I32EnumAttrCase<"OpAccessChain", 65>; +def SPV_OC_OpDecorate : I32EnumAttrCase<"OpDecorate", 71>; +def SPV_OC_OpMemberDecorate : I32EnumAttrCase<"OpMemberDecorate", 72>; +def SPV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>; +def SPV_OC_OpConvertFToU : I32EnumAttrCase<"OpConvertFToU", 109>; +def SPV_OC_OpConvertFToS : I32EnumAttrCase<"OpConvertFToS", 110>; +def SPV_OC_OpConvertSToF : I32EnumAttrCase<"OpConvertSToF", 111>; +def SPV_OC_OpConvertUToF : I32EnumAttrCase<"OpConvertUToF", 112>; +def SPV_OC_OpUConvert : I32EnumAttrCase<"OpUConvert", 113>; +def SPV_OC_OpSConvert : I32EnumAttrCase<"OpSConvert", 114>; +def SPV_OC_OpFConvert : I32EnumAttrCase<"OpFConvert", 115>; +def SPV_OC_OpBitcast : I32EnumAttrCase<"OpBitcast", 124>; +def SPV_OC_OpFNegate : I32EnumAttrCase<"OpFNegate", 127>; +def SPV_OC_OpIAdd : I32EnumAttrCase<"OpIAdd", 128>; 
+def SPV_OC_OpFAdd : I32EnumAttrCase<"OpFAdd", 129>; +def SPV_OC_OpISub : I32EnumAttrCase<"OpISub", 130>; +def SPV_OC_OpFSub : I32EnumAttrCase<"OpFSub", 131>; +def SPV_OC_OpIMul : I32EnumAttrCase<"OpIMul", 132>; +def SPV_OC_OpFMul : I32EnumAttrCase<"OpFMul", 133>; +def SPV_OC_OpUDiv : I32EnumAttrCase<"OpUDiv", 134>; +def SPV_OC_OpSDiv : I32EnumAttrCase<"OpSDiv", 135>; +def SPV_OC_OpFDiv : I32EnumAttrCase<"OpFDiv", 136>; +def SPV_OC_OpUMod : I32EnumAttrCase<"OpUMod", 137>; +def SPV_OC_OpSRem : I32EnumAttrCase<"OpSRem", 138>; +def SPV_OC_OpSMod : I32EnumAttrCase<"OpSMod", 139>; +def SPV_OC_OpFRem : I32EnumAttrCase<"OpFRem", 140>; +def SPV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>; +def SPV_OC_OpLogicalEqual : I32EnumAttrCase<"OpLogicalEqual", 164>; +def SPV_OC_OpLogicalNotEqual : I32EnumAttrCase<"OpLogicalNotEqual", 165>; +def SPV_OC_OpLogicalOr : I32EnumAttrCase<"OpLogicalOr", 166>; +def SPV_OC_OpLogicalAnd : I32EnumAttrCase<"OpLogicalAnd", 167>; +def SPV_OC_OpLogicalNot : I32EnumAttrCase<"OpLogicalNot", 168>; +def SPV_OC_OpSelect : I32EnumAttrCase<"OpSelect", 169>; +def SPV_OC_OpIEqual : I32EnumAttrCase<"OpIEqual", 170>; +def SPV_OC_OpINotEqual : I32EnumAttrCase<"OpINotEqual", 171>; +def SPV_OC_OpUGreaterThan : I32EnumAttrCase<"OpUGreaterThan", 172>; +def SPV_OC_OpSGreaterThan : I32EnumAttrCase<"OpSGreaterThan", 173>; +def SPV_OC_OpUGreaterThanEqual : I32EnumAttrCase<"OpUGreaterThanEqual", 174>; +def SPV_OC_OpSGreaterThanEqual : I32EnumAttrCase<"OpSGreaterThanEqual", 175>; +def SPV_OC_OpULessThan : I32EnumAttrCase<"OpULessThan", 176>; +def SPV_OC_OpSLessThan : I32EnumAttrCase<"OpSLessThan", 177>; +def SPV_OC_OpULessThanEqual : I32EnumAttrCase<"OpULessThanEqual", 178>; +def SPV_OC_OpSLessThanEqual : I32EnumAttrCase<"OpSLessThanEqual", 179>; +def SPV_OC_OpFOrdEqual : I32EnumAttrCase<"OpFOrdEqual", 180>; +def SPV_OC_OpFUnordEqual : I32EnumAttrCase<"OpFUnordEqual", 181>; +def SPV_OC_OpFOrdNotEqual : I32EnumAttrCase<"OpFOrdNotEqual", 182>; +def SPV_OC_OpFUnordNotEqual : I32EnumAttrCase<"OpFUnordNotEqual", 183>; +def SPV_OC_OpFOrdLessThan : I32EnumAttrCase<"OpFOrdLessThan", 184>; +def SPV_OC_OpFUnordLessThan : I32EnumAttrCase<"OpFUnordLessThan", 185>; +def SPV_OC_OpFOrdGreaterThan : I32EnumAttrCase<"OpFOrdGreaterThan", 186>; +def SPV_OC_OpFUnordGreaterThan : I32EnumAttrCase<"OpFUnordGreaterThan", 187>; +def SPV_OC_OpFOrdLessThanEqual : I32EnumAttrCase<"OpFOrdLessThanEqual", 188>; +def SPV_OC_OpFUnordLessThanEqual : I32EnumAttrCase<"OpFUnordLessThanEqual", 189>; +def SPV_OC_OpFOrdGreaterThanEqual : I32EnumAttrCase<"OpFOrdGreaterThanEqual", 190>; +def SPV_OC_OpFUnordGreaterThanEqual : I32EnumAttrCase<"OpFUnordGreaterThanEqual", 191>; +def SPV_OC_OpShiftRightLogical : I32EnumAttrCase<"OpShiftRightLogical", 194>; +def SPV_OC_OpShiftRightArithmetic : I32EnumAttrCase<"OpShiftRightArithmetic", 195>; +def SPV_OC_OpShiftLeftLogical : I32EnumAttrCase<"OpShiftLeftLogical", 196>; +def SPV_OC_OpBitwiseOr : I32EnumAttrCase<"OpBitwiseOr", 197>; +def SPV_OC_OpBitwiseXor : I32EnumAttrCase<"OpBitwiseXor", 198>; +def SPV_OC_OpBitwiseAnd : I32EnumAttrCase<"OpBitwiseAnd", 199>; +def SPV_OC_OpNot : I32EnumAttrCase<"OpNot", 200>; +def SPV_OC_OpBitFieldInsert : I32EnumAttrCase<"OpBitFieldInsert", 201>; +def SPV_OC_OpBitFieldSExtract : I32EnumAttrCase<"OpBitFieldSExtract", 202>; +def SPV_OC_OpBitFieldUExtract : I32EnumAttrCase<"OpBitFieldUExtract", 203>; +def SPV_OC_OpBitReverse : I32EnumAttrCase<"OpBitReverse", 204>; +def SPV_OC_OpBitCount : I32EnumAttrCase<"OpBitCount", 205>; +def SPV_OC_OpControlBarrier : 
I32EnumAttrCase<"OpControlBarrier", 224>; +def SPV_OC_OpMemoryBarrier : I32EnumAttrCase<"OpMemoryBarrier", 225>; +def SPV_OC_OpPhi : I32EnumAttrCase<"OpPhi", 245>; +def SPV_OC_OpLoopMerge : I32EnumAttrCase<"OpLoopMerge", 246>; +def SPV_OC_OpSelectionMerge : I32EnumAttrCase<"OpSelectionMerge", 247>; +def SPV_OC_OpLabel : I32EnumAttrCase<"OpLabel", 248>; +def SPV_OC_OpBranch : I32EnumAttrCase<"OpBranch", 249>; +def SPV_OC_OpBranchConditional : I32EnumAttrCase<"OpBranchConditional", 250>; +def SPV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>; +def SPV_OC_OpReturnValue : I32EnumAttrCase<"OpReturnValue", 254>; +def SPV_OC_OpUnreachable : I32EnumAttrCase<"OpUnreachable", 255>; +def SPV_OC_OpModuleProcessed : I32EnumAttrCase<"OpModuleProcessed", 330>; +def SPV_OC_OpGroupNonUniformBallot : I32EnumAttrCase<"OpGroupNonUniformBallot", 339>; +def SPV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>; + +def SPV_OpcodeAttr : + I32EnumAttr<"Opcode", "valid SPIR-V instructions", [ + SPV_OC_OpNop, SPV_OC_OpUndef, SPV_OC_OpSourceContinued, SPV_OC_OpSource, + SPV_OC_OpSourceExtension, SPV_OC_OpName, SPV_OC_OpMemberName, SPV_OC_OpString, + SPV_OC_OpExtension, SPV_OC_OpExtInstImport, SPV_OC_OpExtInst, + SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint, SPV_OC_OpExecutionMode, + SPV_OC_OpCapability, SPV_OC_OpTypeVoid, SPV_OC_OpTypeBool, SPV_OC_OpTypeInt, + SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeArray, + SPV_OC_OpTypeRuntimeArray, SPV_OC_OpTypeStruct, SPV_OC_OpTypePointer, + SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue, SPV_OC_OpConstantFalse, + SPV_OC_OpConstant, SPV_OC_OpConstantComposite, SPV_OC_OpConstantNull, + SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant, + SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter, + SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad, + SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate, + SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpConvertFToU, + SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF, + SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast, + SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, + SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, + SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, + SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr, + SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual, + SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, + SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, + SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, + SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual, + SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan, + SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan, + SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual, + SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual, + SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic, + SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr, SPV_OC_OpBitwiseXor, + SPV_OC_OpBitwiseAnd, SPV_OC_OpNot, SPV_OC_OpBitFieldInsert, + SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract, SPV_OC_OpBitReverse, + SPV_OC_OpBitCount, SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier, + SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, + SPV_OC_OpBranch, 
SPV_OC_OpBranchConditional, SPV_OC_OpReturn, + SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpModuleProcessed, + SPV_OC_OpGroupNonUniformBallot, SPV_OC_OpSubgroupBallotKHR + ]> { + let cppNamespace = "::mlir::spirv"; +} + +// End opcode section. Generated from SPIR-V spec; DO NOT MODIFY! + //===----------------------------------------------------------------------===// // SPIR-V op definitions //===----------------------------------------------------------------------===// diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td new file mode 100644 index 00000000000..5f60e6b0135 --- /dev/null +++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td @@ -0,0 +1,74 @@ +//===-- SPIRVGroupOps.td - MLIR SPIR-V (Sub)Group Ops ------*- tablegen -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file contains group and subgroup ops for the SPIR-V dialect. It +// corresponds to "3.32.21. Group and Subgroup Instructions" of the SPIR-V +// specification. +// +//===----------------------------------------------------------------------===// + +#ifndef SPIRV_GROUP_OPS +#define SPIRV_GROUP_OPS + +// ----- + +def SPV_SubgroupBallotKHROp : SPV_Op<"SubgroupBallotKHR", []> { + let summary = "See extension SPV_KHR_shader_ballot"; + + let description = [{ + Computes a bitfield value combining the Predicate value from all invocations + in the current Subgroup that execute the same dynamic instance of this + instruction. The bit is set to one if the corresponding invocation is active + and the predicate is evaluated to true; otherwise, it is set to zero. + + Predicate must be a Boolean type. + + Result Type must be a 4 component vector of 32 bit integer types. + + Result is a set of bitfields where the first invocation is represented in bit + 0 of the first vector component and the last (up to SubgroupSize) is the + higher bit number of the last bitmask needed to represent all bits of the + subgroup invocations. + + ### Custom assembly form + + ``` {.ebnf} + subgroup-ballot-op ::= ssa-id `=` `spv.SubgroupBallotKHR` + ssa-use `:` `vector` `<` 4 `x` `i32` `>` + ``` + + For example: + + ``` + %0 = spv.SubgroupBallotKHR %predicate : vector<4xi32> + ``` + }]; + + let arguments = (ins + SPV_Bool:$predicate + ); + + let results = (outs + SPV_I32Vec4:$result + ); + + let verifier = [{ return success(); }]; +} + +// ----- + +#endif // SPIRV_GROUP_OPS diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h index a5b3fc27413..306f2b9f309 100644 --- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h +++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h @@ -30,21 +30,23 @@ namespace mlir { -/// Converts a function type according to the requirements of a SPIR-V entry -/// function. 
The arguments need to be converted to spv.GlobalVariables of -/// spv.ptr types so that they could be bound by the runtime. +/// Type conversion from standard types to SPIR-V types for shader interface. +/// +/// For composite types, this converter additionally performs type wrapping to +/// satisfy shader interface requirements: shader interface types must be +/// pointers to structs. class SPIRVTypeConverter final : public TypeConverter { public: using TypeConverter::TypeConverter; - /// Converts types to SPIR-V types using the basic type converter. - Type convertType(Type t) override; + /// Converts the given standard `type` to its SPIR-V correspondence. + Type convertType(Type type) override; - /// Gets the index type equivalent in SPIR-V. - Type getIndexType(MLIRContext *context); + /// Gets the SPIR-V correspondence for the standard index type. + static Type getIndexType(MLIRContext *context); }; -/// Base class to define a conversion pattern to translate Ops into SPIR-V. +/// Base class to define a conversion pattern to lower `SourceOp` into SPIR-V. template class SPIRVOpLowering : public OpConversionPattern { public: @@ -54,10 +56,7 @@ public: typeConverter(typeConverter) {} protected: - /// Type lowering class. SPIRVTypeConverter &typeConverter; - -private: }; #include "mlir/Dialect/SPIRV/SPIRVLowering.h.inc" diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td new file mode 100644 index 00000000000..a37f5b576fd --- /dev/null +++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td @@ -0,0 +1,78 @@ +//===-- SPIRVNonUniformOps.td - MLIR SPIR-V NonUniform Ops -*- tablegen -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file contains non-uniform ops for the SPIR-V dialect. It corresponds to +// "3.32.24. Non-Uniform Instructions" of the SPIR-V specification. +// +//===----------------------------------------------------------------------===// + +#ifndef SPIRV_NON_UNIFORM_OPS +#define SPIRV_NON_UNIFORM_OPS + +// ----- + +def SPV_GroupNonUniformBallotOp : SPV_Op<"GroupNonUniformBallot", []> { + let summary = [{ + Returns a bitfield value combining the Predicate value from all + invocations in the group that execute the same dynamic instance of this + instruction. The bit is set to one if the corresponding invocation is + active and the Predicate for that invocation evaluated to true; + otherwise, it is set to zero. + }]; + + let description = [{ + Result Type must be a vector of four components of integer type scalar, + whose Signedness operand is 0. + + Result is a set of bitfields where the first invocation is represented + in the lowest bit of the first vector component and the last (up to the + size of the group) is the higher bit number of the last bitmask needed + to represent all bits of the group invocations.
+ + Execution must be Workgroup or Subgroup Scope. + + Predicate must be a Boolean type. + + ### Custom assembly form + + ``` {.ebnf} + scope ::= `"Workgroup"` | `"Subgroup"` + non-uniform-ballot-op ::= ssa-id `=` `spv.GroupNonUniformBallot` scope + ssa-use `:` `vector` `<` 4 `x` `integer-type` `>` + ``` + + For example: + + ``` + %0 = spv.GroupNonUniformBallot "Subgroup" %predicate : vector<4xi32> + ``` + }]; + + let arguments = (ins + SPV_ScopeAttr:$execution_scope, + SPV_Bool:$predicate + ); + + let results = (outs + SPV_IntVec4:$result + ); +} + +// ----- + +#endif // SPIRV_NON_UNIFORM_OPS + diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h index 104a4798e7c..353004b6c76 100644 --- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h +++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h @@ -26,6 +26,8 @@ #include "mlir/IR/Function.h" namespace mlir { +class OpBuilder; + namespace spirv { #define GET_OP_CLASSES diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td index 41d729da777..149c2359fda 100644 --- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td +++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td @@ -20,11 +20,12 @@ // //===----------------------------------------------------------------------===// -// Note that for each op in this file, we use a tool to automatically generate -// certain sections in its definition: basic structure, summary, description. -// So modifications to these sections will not be respected. Modifications to -// op traits, arguments, results, and sections after the results are retained. -// Besides, ops in this file must be separated via the '// -----' marker. +// Note that for each op in this file and the included files for specific op +// categories, we use a tool to automatically generate certain sections in its +// definition: basic structure, summary, description. So modifications to these +// sections will not be respected. Modifications to op traits, arguments, +// results, and sections after the results are retained. Besides, ops must be +// separated via the '// -----' marker. #ifndef SPIRV_OPS #define SPIRV_OPS @@ -34,11 +35,11 @@ include "mlir/Dialect/SPIRV/SPIRVArithmeticOps.td" include "mlir/Dialect/SPIRV/SPIRVBitOps.td" include "mlir/Dialect/SPIRV/SPIRVCastOps.td" include "mlir/Dialect/SPIRV/SPIRVControlFlowOps.td" -include "mlir/Dialect/SPIRV/SPIRVLogicalOps.td" -// Pull in ops for defining the SPIR-V module structure -include "mlir/Dialect/SPIRV/SPIRVStructureOps.td" -// Pull in ops for extended instruction set for GLSL include "mlir/Dialect/SPIRV/SPIRVGLSLOps.td" +include "mlir/Dialect/SPIRV/SPIRVGroupOps.td" +include "mlir/Dialect/SPIRV/SPIRVLogicalOps.td" +include "mlir/Dialect/SPIRV/SPIRVNonUniformOps.td" +include "mlir/Dialect/SPIRV/SPIRVStructureOps.td" // ----- diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td index 1ec825aab5c..34b386ebc17 100644 --- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td +++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td @@ -118,6 +118,13 @@ def SPV_ConstantOp : SPV_Op<"constant", [NoSideEffect]> { let extraClassDeclaration = [{ // Returns true if a constant can be built for the given `type`.
static bool isBuildableWith(Type type); + + // Creates a constant zero/one of the given `type` at the current insertion + // point of `builder` and returns it. + static spirv::ConstantOp getZero(Type type, Location loc, + OpBuilder *builder); + static spirv::ConstantOp getOne(Type type, Location loc, + OpBuilder *builder); }]; let hasOpcode = 0; diff --git a/third_party/mlir/include/mlir/Dialect/StandardOps/Ops.td b/third_party/mlir/include/mlir/Dialect/StandardOps/Ops.td index e2731acf47f..51c7bfbccdc 100644 --- a/third_party/mlir/include/mlir/Dialect/StandardOps/Ops.td +++ b/third_party/mlir/include/mlir/Dialect/StandardOps/Ops.td @@ -1206,7 +1206,7 @@ def ViewOp : Std_Op<"view", [NoSideEffect]> { // ViewOp with dynamic offset and one dynamic size. %2 = view %0[%offset_1024][%size0] - : memref<2048xi8> to memref (d0 * 4 + d1 + s0) + : memref<2048xi8> to memref (d0 * 4 + d1 + s0)> // ViewOp creating 3D shape where two of the dim sizes are dynamic. // *) The dynamic offset specified in the ViewOp is applied to the @@ -1219,7 +1219,7 @@ def ViewOp : Std_Op<"view", [NoSideEffect]> { // shape and dynamic sizes. %3 = view %0[%offset_1024][%size0, %size1] : memref<2048xi8> to memref (d0 * s1 + d1 * 4 + d2 + s0) + (d0, d1, d2)[s0, s1] -> (d0 * s1 + d1 * 4 + d2 + s0)> }]; let arguments = (ins MemRefRankOf<[I8], [1]>:$source, @@ -1248,7 +1248,7 @@ def ViewOp : Std_Op<"view", [NoSideEffect]> { let hasCanonicalizer = 1; } -def SubViewOp : Std_Op<"subview", [NoSideEffect]> { +def SubViewOp : Std_Op<"subview", [AttrSizedOperandSegments, NoSideEffect]> { let summary = "memref subview operation"; let description = [{ The "subview" operation converts a memref type to another memref type @@ -1356,23 +1356,25 @@ def SubViewOp : Std_Op<"subview", [NoSideEffect]> { // TODO(b/144779634, ravishankarm) : Use different arguments for // offsets, sizes and strides. - let arguments = (ins AnyMemRef:$source, I32Attr:$num_offsets, - I32Attr:$num_sizes, I32Attr:$num_strides, - Variadic:$operands); + let arguments = (ins + AnyMemRef:$source, + Variadic:$offsets, + Variadic:$sizes, + Variadic:$strides, + I32ElementsAttr:$operand_segment_sizes + ); let results = (outs AnyMemRef); - let builders = [OpBuilder< - "Builder *b, OperationState &result, Value *source, " - "ArrayRef offsets, ArrayRef sizes, " - "ArrayRef strides, Type resultType = Type(), " - "ArrayRef attrs = {}">, + let builders = [ OpBuilder< - "Builder *builder, OperationState &result, Type resultType, Value *source">, + "Builder *b, OperationState &result, Value *source, " + "ArrayRef offsets, ArrayRef sizes, " + "ArrayRef strides, Type resultType = Type(), " + "ArrayRef attrs = {}">, OpBuilder< - "Builder *builder, OperationState &result, Type resultType, Value *source, " - "unsigned num_offsets, unsigned num_sizes, unsigned num_strides, " - "ArrayRef offsets, ArrayRef sizes, " - "ArrayRef strides">]; + "Builder *builder, OperationState &result, " + "Type resultType, Value *source"> + ]; let extraClassDeclaration = [{ /// Returns the type of the base memref operand. @@ -1384,28 +1386,21 @@ def SubViewOp : Std_Op<"subview", [NoSideEffect]> { MemRefType getType() { return getResult()->getType().cast(); } /// Returns as integer value the number of offset operands. - int64_t getNumOffsets() { - return num_offsets().getSExtValue(); - } + int64_t getNumOffsets() { return llvm::size(offsets()); } /// Returns as integer value the number of size operands. 
- int64_t getNumSizes() { - return num_sizes().getSExtValue(); - } + int64_t getNumSizes() { return llvm::size(sizes()); } /// Returns as integer value the number of stride operands. - int64_t getNumStrides() { - return num_strides().getSExtValue(); - } - - /// Returns the dynamic offsets for this subview operation. - operand_range getDynamicOffsets(); + int64_t getNumStrides() { return llvm::size(strides()); } /// Returns the dynamic sizes for this subview operation if specified. - operand_range getDynamicSizes(); + operand_range getDynamicSizes() { return sizes(); } - /// Returns the dynamic strides for this subview operation if specified. - operand_range getDynamicStrides(); + /// Returns in `staticStrides` the static value of the stride + /// operands. Returns failure() if the static value of the stride + /// operands could not be retrieved. + LogicalResult getStaticStrides(SmallVectorImpl &staticStrides); // Auxiliary range data structure and helper function that unpacks the // offset, size and stride operands of the SubViewOp into a list of triples. diff --git a/third_party/mlir/include/mlir/Dialect/VectorOps/CMakeLists.txt b/third_party/mlir/include/mlir/Dialect/VectorOps/CMakeLists.txt index 3849dd7ffdf..c165c5e676d 100644 --- a/third_party/mlir/include/mlir/Dialect/VectorOps/CMakeLists.txt +++ b/third_party/mlir/include/mlir/Dialect/VectorOps/CMakeLists.txt @@ -1,7 +1,4 @@ -set(LLVM_TARGET_DEFINITIONS VectorOps.td) -mlir_tablegen(VectorOps.h.inc -gen-op-decls) -mlir_tablegen(VectorOps.cpp.inc -gen-op-defs) -add_public_tablegen_target(MLIRVectorOpsIncGen) +add_mlir_dialect(VectorOps) set(LLVM_TARGET_DEFINITIONS VectorTransformPatterns.td) mlir_tablegen(VectorTransformPatterns.h.inc -gen-rewriters) diff --git a/third_party/mlir/include/mlir/Analysis/VectorAnalysis.h b/third_party/mlir/include/mlir/Dialect/VectorOps/Utils.h similarity index 96% rename from third_party/mlir/include/mlir/Analysis/VectorAnalysis.h rename to third_party/mlir/include/mlir/Dialect/VectorOps/Utils.h index 350bdfd8cce..2cff8795304 100644 --- a/third_party/mlir/include/mlir/Analysis/VectorAnalysis.h +++ b/third_party/mlir/include/mlir/Dialect/VectorOps/Utils.h @@ -1,4 +1,4 @@ -//===- VectorAnalysis.h - Analysis for Vectorization -------*- C++ -*-=======// +//===- Utils.h - VectorOps Utils ----------------------------*- C++ -*-=======// // // Copyright 2019 The MLIR Authors. // @@ -15,8 +15,8 @@ // limitations under the License. // ============================================================================= -#ifndef MLIR_ANALYSIS_VECTORANALYSIS_H_ -#define MLIR_ANALYSIS_VECTORANALYSIS_H_ +#ifndef MLIR_DIALECT_VECTOROPS_UTILS_H_ +#define MLIR_DIALECT_VECTOROPS_UTILS_H_ #include "mlir/Support/LLVM.h" @@ -140,4 +140,4 @@ bool operatesOnSuperVectorsOf(Operation &op, VectorType subVectorType); } // end namespace matcher } // end namespace mlir -#endif // MLIR_ANALYSIS_VECTORANALYSIS_H_ +#endif // MLIR_DIALECT_VECTOROPS_UTILS_H_ diff --git a/third_party/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/third_party/mlir/include/mlir/Dialect/VectorOps/VectorOps.td index c78334dd54a..d34fa9a245d 100644 --- a/third_party/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +++ b/third_party/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -171,7 +171,24 @@ def Vector_BroadcastOp : let summary = "broadcast operation"; let description = [{ Broadcasts the scalar or k-D vector value in the source operand - to a n-D result vector such that the broadcast makes sense. 
+ to a n-D result vector such that the broadcast makes sense, i.e., + the source operand is duplicated to match the given rank and sizes + in the result vector. The legality rules are: + * the source operand must have the same element type as the result type + * a k-D vector can be broadcast to + a n-D vector if + * k <= n, and + * the sizes in the trailing dimensions n-k < i <= n with j=i+k-n + match exactly as s_j = t_i or s_j = 1: + ``` + t_1 x .. t_n-k x t_n-k+1 x .. x t_i x .. x t_n + s_1 x .. x s_j x .. x s_k + + ``` + The source operand is duplicated over all the missing leading dimensions + and streched over the trailing dimensions where the source has a non-equal + dimension of 1. These rules imply that any scalar broadcast (k=0) to any + shaped vector with the same element type is always legal. Examples: ``` @@ -610,7 +627,37 @@ def Vector_TypeCastOp : }]; } -// TODO(andydavis) Morph this operation into a Vector_MaskOp. +// TODO(andydavis) Add constant folding support. +def Vector_CreateMaskOp : + Vector_Op<"create_mask", [NoSideEffect]>, + Arguments<(ins Variadic:$operands)>, Results<(outs VectorOf<[I1]>)> { + let summary = "creates a vector mask"; + let description = [{ + Creates and returns a vector mask where elements of the result vector + are set to '0' or '1', based on whether the element indices are contained + within a hyper-rectangular region specified by the operands. Specifically, + each operand specifies a range [0, operand-value) for a unique dimension in + the vector result. The conjunction of the operand ranges define + hyper-rectangular region within which elements values are set to 1 + (otherwise element values are set to 0). + + Example: create a vector mask of size 4x3xi1 where elements in range + 0 <= row <= 2 and 0 <= col <= 1 are set to 1 (others to 0). + + %1 = vector.create_mask %c3, %c2 : vector<4x3xi1> + + print %1 + columns + 0 1 2 + |------------ + 0 | 1 1 0 + rows 1 | 1 1 0 + 2 | 1 1 0 + 3 | 0 0 0 + }]; +} + +// TODO(andydavis) Delete this op once ContractOp is converted to use VectorMask def Vector_IndexTupleOp : Vector_Op<"make_index_tuple", [NoSideEffect]>, Arguments<(ins Variadic:$operands)>, diff --git a/third_party/mlir/include/mlir/Dialect/VectorOps/VectorTransforms.h b/third_party/mlir/include/mlir/Dialect/VectorOps/VectorTransforms.h new file mode 100644 index 00000000000..2c2e4e7c4fa --- /dev/null +++ b/third_party/mlir/include/mlir/Dialect/VectorOps/VectorTransforms.h @@ -0,0 +1,82 @@ +//===- VectorTransforms.h - Vector transformations as patterns --*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef DIALECT_VECTOROPS_VECTORTRANSFORMS_H_ +#define DIALECT_VECTOROPS_VECTORTRANSFORMS_H_ + +#include "mlir/IR/PatternMatch.h" + +namespace mlir { +class MLIRContext; +class OwningRewritePatternList; + +/// Collect a set of patterns to convert from the Vector dialect to itself. 
+/// Should be merged with populateVectorToAffineLoopsConversionPatterns. +void populateVectorToVectorConversionPatterns( + MLIRContext *context, OwningRewritePatternList &patterns, + ArrayRef coarseVectorShape = {}, + ArrayRef fineVectorShape = {}); + +//////////////////////////////////////////////////////////////////////////////// +// The following Declarative Rewrite Rule (DRR) helpers are used in rewrite +// patterns. As such, they must not call into `rewriter.erase/replace` APIs and +// it is the responsibility of the enclosing PatternRewriter to erase on +// success. +//////////////////////////////////////////////////////////////////////////////// + +namespace vector { + +// Entry point for unrolling declarative pattern rewrites. +// `op` is unrolled to the `targetShape` as follows, for each of its operands: +// 1. the unrolled type `unrolledVectorType` and number of unrolled instances +// `numUnrolledInstances` are computed from the `targetShape`. For now it is +// assumed the unrolling factors divide the vector sizes. +// 2. a fakeFork cast op is inserted that takes the operand and returns +// `numUnrolledInstances` results of type `unrolledVectorType`. +// 3. the original op is cloned `numUnrolledInstances` times, once for each +// result of the fakeFork cast op. +// 4. a fakeJoin cast op takes all these results and merges them into a single +// aggregate vector result whose size matches the original non-unrolled op +// operand types. +// +// Example: +// +// opA(operand0, operand1) // numUnrolledInstances = 3 +// +// operand0 operand1 +// | | +// fork fork +// <----------gather all fork ops ---------> +// /|\ /|\ +// f00 f01 f02 f10 f11 f12 +// <---------- clone op 3 times ---------> +// opA0(f00, f10), opA1(f01, f11), opA2(f02, f12) +// \ | / +// <-------------------- join -------------------------> +// +// Other local patterns then kick in iteratively (including DCE) and compose +// until all the fakeFork and fakeJoin ops are removed. +// +// This will be extended in the future to support more advanced use cases than +// simple pointwise ops. +Value *unrollSingleResultOpMatchingType(PatternRewriter &builder, Operation *op, + ArrayRef targetShape); + +} // namespace vector +} // namespace mlir + +#endif // DIALECT_VECTOROPS_VECTORTRANSFORMS_H_ diff --git a/third_party/mlir/include/mlir/EDSC/Builders.h b/third_party/mlir/include/mlir/EDSC/Builders.h index 1927ce60eab..5940f1c244f 100644 --- a/third_party/mlir/include/mlir/EDSC/Builders.h +++ b/third_party/mlir/include/mlir/EDSC/Builders.h @@ -26,7 +26,6 @@ #include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/LoopOps/LoopOps.h" #include "mlir/Dialect/StandardOps/Ops.h" -#include "mlir/Dialect/VectorOps/VectorOps.h" #include "mlir/IR/Builders.h" #include "mlir/Transforms/FoldUtils.h" diff --git a/third_party/mlir/include/mlir/EDSC/Intrinsics.h b/third_party/mlir/include/mlir/EDSC/Intrinsics.h index 6e1c49f66cc..468cc1c4240 100644 --- a/third_party/mlir/include/mlir/EDSC/Intrinsics.h +++ b/third_party/mlir/include/mlir/EDSC/Intrinsics.h @@ -215,7 +215,6 @@ using select = ValueBuilder; using std_load = ValueBuilder; using std_store = OperationBuilder; using subi = ValueBuilder; -using vector_type_cast = ValueBuilder; using view = ValueBuilder; /// Branches into the mlir::Block* captured by BlockHandle `b` with `operands`. 
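As a rough illustration of driving the unrolling entry point declared in VectorTransforms.h above from a rewrite pattern: the anchor op (`std.addf` on vectors), the 2x2 target shape and the assumption that a null result signals failure are all illustrative, and `targetShape` is assumed to hold `int64_t` extents.

```
#include "mlir/Dialect/VectorOps/VectorTransforms.h"
#include "mlir/IR/PatternMatch.h"

// Hypothetical pattern: unroll a single-result pointwise op into 2x2 pieces
// and let the local follow-up patterns clean up the fakeFork/fakeJoin casts.
struct UnrollToTwoByTwo : public mlir::RewritePattern {
  UnrollToTwoByTwo(mlir::MLIRContext *ctx)
      : RewritePattern("std.addf", /*benefit=*/1, ctx) {}

  mlir::PatternMatchResult
  matchAndRewrite(mlir::Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    // Assumed: a null value means the op could not be unrolled.
    mlir::Value *unrolled = mlir::vector::unrollSingleResultOpMatchingType(
        rewriter, op, /*targetShape=*/{2, 2});
    if (!unrolled)
      return matchFailure();
    rewriter.replaceOp(op, {unrolled});
    return matchSuccess();
  }
};
```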
diff --git a/third_party/mlir/include/mlir/IR/Builders.h b/third_party/mlir/include/mlir/IR/Builders.h index 01ad38cfc11..c5ed7b16b56 100644 --- a/third_party/mlir/include/mlir/IR/Builders.h +++ b/third_party/mlir/include/mlir/IR/Builders.h @@ -120,6 +120,8 @@ public: IntegerAttr getI32IntegerAttr(int32_t value); IntegerAttr getI64IntegerAttr(int64_t value); + DenseIntElementsAttr getI32VectorAttr(ArrayRef values); + ArrayAttr getAffineMapArrayAttr(ArrayRef values); ArrayAttr getI32ArrayAttr(ArrayRef values); ArrayAttr getI64ArrayAttr(ArrayRef values); diff --git a/third_party/mlir/include/mlir/IR/FunctionImplementation.h b/third_party/mlir/include/mlir/IR/FunctionImplementation.h new file mode 100644 index 00000000000..241d5615acf --- /dev/null +++ b/third_party/mlir/include/mlir/IR/FunctionImplementation.h @@ -0,0 +1,109 @@ +//===- FunctionImplementation.h - Function-like Op utilities ----*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file provides utility functions for implementing function-like +// operations, in particular, parsing, printing and verification components +// common to function-like operations. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_FUNCTIONIMPLEMENTATION_H_ +#define MLIR_IR_FUNCTIONIMPLEMENTATION_H_ + +#include "mlir/IR/FunctionSupport.h" +#include "mlir/IR/OpImplementation.h" + +namespace mlir { + +namespace impl { + +/// A named class for passing around the variadic flag. +class VariadicFlag { +public: + explicit VariadicFlag(bool variadic) : variadic(variadic) {} + bool isVariadic() const { return variadic; } + +private: + /// Underlying storage. + bool variadic; +}; + +/// Adds argument and result attributes, provided as `argAttrs` and +/// `resultAttrs` arguments, to the list of operation attributes in `result`. +/// Internally, argument and result attributes are stored as dict attributes +/// with special names given by getResultAttrName, getArgumentAttrName. +void addArgAndResultAttrs(Builder &builder, OperationState &result, + ArrayRef> argAttrs, + ArrayRef> resultAttrs); + +/// Callback type for `parseFunctionLikeOp`, the callback should produce the +/// type that will be associated with a function-like operation from lists of +/// function arguments and results, VariadicFlag indicates whether the function +/// should have variadic arguments; in case of error, it may populate the last +/// argument with a message. +using FuncTypeBuilder = llvm::function_ref, ArrayRef, VariadicFlag, std::string &)>; + +/// Parses a function signature using `parser`. The `allowVariadic` argument +/// indicates whether functions with variadic arguments are supported. The +/// trailing arguments are populated by this function with names, types and +/// attributes of the arguments and those of the results. 
+ParseResult parseFunctionSignature( + OpAsmParser &parser, bool allowVariadic, + SmallVectorImpl &argNames, + SmallVectorImpl &argTypes, + SmallVectorImpl> &argAttrs, bool &isVariadic, + SmallVectorImpl &resultTypes, + SmallVectorImpl> &resultAttrs); + +/// Parser implementation for function-like operations. Uses +/// `funcTypeBuilder` to construct the custom function type given lists of +/// input and output types. If `allowVariadic` is set, the parser will accept +/// trailing ellipsis in the function signature and indicate to the builder +/// whether the function is variadic. If the builder returns a null type, +/// `result` will not contain the `type` attribute. The caller can then add a +/// type, report the error or delegate the reporting to the op's verifier. +ParseResult parseFunctionLikeOp(OpAsmParser &parser, OperationState &result, + bool allowVariadic, + FuncTypeBuilder funcTypeBuilder); + +/// Printer implementation for function-like operations. Accepts lists of +/// argument and result types to use while printing. +void printFunctionLikeOp(OpAsmPrinter &p, Operation *op, + ArrayRef argTypes, bool isVariadic, + ArrayRef resultTypes); + +/// Prints the signature of the function-like operation `op`. Assumes `op` has +/// the FunctionLike trait and passed the verification. +void printFunctionSignature(OpAsmPrinter &p, Operation *op, + ArrayRef argTypes, bool isVariadic, + ArrayRef resultTypes); + +/// Prints the list of function prefixed with the "attributes" keyword. The +/// attributes with names listed in "elided" as well as those used by the +/// function-like operation internally are not printed. Nothing is printed +/// if all attributes are elided. Assumes `op` has the `FunctionLike` trait and +/// passed the verification. +void printFunctionAttributes(OpAsmPrinter &p, Operation *op, unsigned numInputs, + unsigned numResults, + ArrayRef elided = {}); + +} // namespace impl + +} // namespace mlir + +#endif // MLIR_IR_FUNCTIONIMPLEMENTATION_H_ diff --git a/third_party/mlir/include/mlir/IR/FunctionSupport.h b/third_party/mlir/include/mlir/IR/FunctionSupport.h index 38e406e8f08..4656c35a9c2 100644 --- a/third_party/mlir/include/mlir/IR/FunctionSupport.h +++ b/third_party/mlir/include/mlir/IR/FunctionSupport.h @@ -24,12 +24,12 @@ #define MLIR_IR_FUNCTIONSUPPORT_H #include "mlir/IR/OpDefinition.h" -#include "mlir/IR/OpImplementation.h" #include "llvm/ADT/SmallString.h" namespace mlir { namespace impl { + /// Return the name of the attribute used for function types. inline StringRef getTypeAttrName() { return "type"; } @@ -73,77 +73,6 @@ inline ArrayRef getResultAttrs(Operation *op, unsigned index) { return resultDict ? resultDict.getValue() : llvm::None; } -/// A named class for passing around the variadic flag. -class VariadicFlag { -public: - explicit VariadicFlag(bool variadic) : variadic(variadic) {} - bool isVariadic() const { return variadic; } - -private: - /// Underlying storage. - bool variadic; -}; - -/// Adds argument and result attributes, provided as `argAttrs` and -/// `resultAttrs` arguments, to the list of operation attributes in `result`. -/// Internally, argument and result attributes are stored as dict attributes -/// with special names given by getResultAttrName, getArgumentAttrName. 
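To show how the relocated helpers fit together, here is a sketch of a custom parser for a hypothetical function-like op, wired through `parseFunctionLikeOp` with a `FuncTypeBuilder` callback that simply builds a standard `FunctionType`; everything except the two `impl::` entry points is made up.

```
#include "mlir/IR/FunctionImplementation.h"

static mlir::ParseResult parseMyFuncLikeOp(mlir::OpAsmParser &parser,
                                           mlir::OperationState &result) {
  // Build the op's type attribute from the parsed argument and result lists.
  auto buildFuncType = [](mlir::Builder &builder,
                          llvm::ArrayRef<mlir::Type> argTypes,
                          llvm::ArrayRef<mlir::Type> resultTypes,
                          mlir::impl::VariadicFlag /*isVariadic*/,
                          std::string & /*errorMessage*/) -> mlir::Type {
    return builder.getFunctionType(argTypes, resultTypes);
  };
  return mlir::impl::parseFunctionLikeOp(parser, result,
                                         /*allowVariadic=*/false,
                                         buildFuncType);
}
```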
-void addArgAndResultAttrs(Builder &builder, OperationState &result, - ArrayRef> argAttrs, - ArrayRef> resultAttrs); - -/// Callback type for `parseFunctionLikeOp`, the callback should produce the -/// type that will be associated with a function-like operation from lists of -/// function arguments and results, VariadicFlag indicates whether the function -/// should have variadic arguments; in case of error, it may populate the last -/// argument with a message. -using FuncTypeBuilder = llvm::function_ref, ArrayRef, VariadicFlag, std::string &)>; - -/// Parses a function signature using `parser`. The `allowVariadic` argument -/// indicates whether functions with variadic arguments are supported. The -/// trailing arguments are populated by this function with names, types and -/// attributes of the arguments and those of the results. -ParseResult parseFunctionSignature( - OpAsmParser &parser, bool allowVariadic, - SmallVectorImpl &argNames, - SmallVectorImpl &argTypes, - SmallVectorImpl> &argAttrs, bool &isVariadic, - SmallVectorImpl &resultTypes, - SmallVectorImpl> &resultAttrs); - -/// Parser implementation for function-like operations. Uses -/// `funcTypeBuilder` to construct the custom function type given lists of -/// input and output types. If `allowVariadic` is set, the parser will accept -/// trailing ellipsis in the function signature and indicate to the builder -/// whether the function is variadic. If the builder returns a null type, -/// `result` will not contain the `type` attribute. The caller can then add a -/// type, report the error or delegate the reporting to the op's verifier. -ParseResult parseFunctionLikeOp(OpAsmParser &parser, OperationState &result, - bool allowVariadic, - FuncTypeBuilder funcTypeBuilder); - -/// Printer implementation for function-like operations. Accepts lists of -/// argument and result types to use while printing. -void printFunctionLikeOp(OpAsmPrinter &p, Operation *op, - ArrayRef argTypes, bool isVariadic, - ArrayRef resultTypes); - -/// Prints the signature of the function-like operation `op`. Assumes `op` has -/// the FunctionLike trait and passed the verification. -void printFunctionSignature(OpAsmPrinter &p, Operation *op, - ArrayRef argTypes, bool isVariadic, - ArrayRef resultTypes); - -/// Prints the list of function prefixed with the "attributes" keyword. The -/// attributes with names listed in "elided" as well as those used by the -/// function-like operation internally are not printed. Nothing is printed -/// if all attributes are elided. Assumes `op` has the `FunctionLike` trait and -/// passed the verification. 
-void printFunctionAttributes(OpAsmPrinter &p, Operation *op, unsigned numInputs, - unsigned numResults, - ArrayRef elided = {}); - } // namespace impl namespace OpTrait { diff --git a/third_party/mlir/include/mlir/IR/Operation.h b/third_party/mlir/include/mlir/IR/Operation.h index 8a7bad13d69..27bc1b17b63 100644 --- a/third_party/mlir/include/mlir/IR/Operation.h +++ b/third_party/mlir/include/mlir/IR/Operation.h @@ -63,7 +63,7 @@ public: static Operation *create(Location location, OperationName name, ArrayRef resultTypes, ArrayRef operands, - const NamedAttributeList &attributes, + NamedAttributeList attributes, ArrayRef successors, unsigned numRegions, bool resizableOperandList); @@ -74,7 +74,7 @@ public: static Operation *create(Location location, OperationName name, ArrayRef resultTypes, ArrayRef operands, - const NamedAttributeList &attributes, + NamedAttributeList attributes, ArrayRef successors = {}, ArrayRef> regions = {}, bool resizableOperandList = false); diff --git a/third_party/mlir/include/mlir/Pass/AnalysisManager.h b/third_party/mlir/include/mlir/Pass/AnalysisManager.h index 163ecf6356f..6c37223ad91 100644 --- a/third_party/mlir/include/mlir/Pass/AnalysisManager.h +++ b/third_party/mlir/include/mlir/Pass/AnalysisManager.h @@ -76,9 +76,36 @@ private: SmallPtrSet preservedIDs; }; +namespace analysis_impl { +/// Trait to check if T provides a static 'isInvalidated' method. +template +using has_is_invalidated = decltype(std::declval().isInvalidated( + std::declval())); + +/// Implementation of 'isInvalidated' if the analysis provides a definition. +template +std::enable_if_t::value, bool> +isInvalidated(AnalysisT &analysis, const PreservedAnalyses &pa) { + return analysis.isInvalidated(pa); +} +/// Default implementation of 'isInvalidated'. +template +std::enable_if_t::value, bool> +isInvalidated(AnalysisT &analysis, const PreservedAnalyses &pa) { + return !pa.isPreserved(); +} +} // end namespace analysis_impl + /// The abstract polymorphic base class representing an analysis. struct AnalysisConcept { virtual ~AnalysisConcept() = default; + + /// A hook used to query analyses for invalidation. Given a preserved analysis + /// set, returns true if it should truly be invalidated. This allows for more + /// fine-tuned invalidation in cases where an analysis wasn't explicitly + /// marked preserved, but may be preserved(or invalidated) based upon other + /// properties such as analyses sets. + virtual bool isInvalidated(const PreservedAnalyses &pa) = 0; }; /// A derived analysis model used to hold a specific analysis object. @@ -87,6 +114,12 @@ template struct AnalysisModel : public AnalysisConcept { explicit AnalysisModel(Args &&... args) : analysis(std::forward(args)...) {} + /// A hook used to query analyses for invalidation. + bool isInvalidated(const PreservedAnalyses &pa) final { + return analysis_impl::isInvalidated(analysis, pa); + } + + /// The actual analysis object. AnalysisT analysis; }; @@ -147,11 +180,11 @@ public: /// Invalidate any cached analyses based upon the given set of preserved /// analyses. - void invalidate(const detail::PreservedAnalyses &pa) { - // Remove any analyses not marked as preserved. + void invalidate(const PreservedAnalyses &pa) { + // Remove any analyses that were invalidated. 
for (auto it = analyses.begin(), e = analyses.end(); it != e;) { auto curIt = it++; - if (!pa.isPreserved(curIt->first)) + if (curIt->second->isInvalidated(pa)) analyses.erase(curIt); } } @@ -170,7 +203,7 @@ struct NestedAnalysisMap { Operation *getOperation() const { return analyses.getOperation(); } /// Invalidate any non preserved analyses. - void invalidate(const detail::PreservedAnalyses &pa); + void invalidate(const PreservedAnalyses &pa); /// The cached analyses for nested operations. llvm::DenseMap> childAnalyses; @@ -195,6 +228,8 @@ class AnalysisManager { const AnalysisManager *>; public: + using PreservedAnalyses = detail::PreservedAnalyses; + // Query for a cached analysis on the given parent operation. The analysis may // not exist and if it does it may be out-of-date. template @@ -240,7 +275,7 @@ public: AnalysisManager slice(Operation *op); /// Invalidate any non preserved analyses, - void invalidate(const detail::PreservedAnalyses &pa) { impl->invalidate(pa); } + void invalidate(const PreservedAnalyses &pa) { impl->invalidate(pa); } /// Clear any held analyses. void clear() { diff --git a/third_party/mlir/include/mlir/TableGen/Attribute.h b/third_party/mlir/include/mlir/TableGen/Attribute.h index 60f95156bb5..242376e24ff 100644 --- a/third_party/mlir/include/mlir/TableGen/Attribute.h +++ b/third_party/mlir/include/mlir/TableGen/Attribute.h @@ -81,10 +81,10 @@ public: // built upon. Attribute getBaseAttr() const; - // Returns whether this attribute has a default value's initializer. - bool hasDefaultValueInitializer() const; - // Returns the default value's initializer for this attribute. - StringRef getDefaultValueInitializer() const; + // Returns whether this attribute has a default value. + bool hasDefaultValue() const; + // Returns the default value for this attribute. + StringRef getDefaultValue() const; // Returns whether this attribute is optional. bool isOptional() const; diff --git a/third_party/mlir/include/mlir/TableGen/Operator.h b/third_party/mlir/include/mlir/TableGen/Operator.h index 7b636ddb79e..89fd4ed8d2e 100644 --- a/third_party/mlir/include/mlir/TableGen/Operator.h +++ b/third_party/mlir/include/mlir/TableGen/Operator.h @@ -103,6 +103,7 @@ public: llvm::iterator_range getAttributes() const; int getNumAttributes() const { return attributes.size(); } + int getNumNativeAttributes() const { return numNativeAttributes; } // Op attribute accessors. NamedAttribute &getAttribute(int index) { return attributes[index]; } diff --git a/third_party/mlir/lib/Analysis/LoopAnalysis.cpp b/third_party/mlir/lib/Analysis/LoopAnalysis.cpp index f01e548a3df..b297a63cb62 100644 --- a/third_party/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/third_party/mlir/lib/Analysis/LoopAnalysis.cpp @@ -25,7 +25,6 @@ #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/NestedMatcher.h" #include "mlir/Dialect/AffineOps/AffineOps.h" -#include "mlir/Dialect/VectorOps/VectorOps.h" #include "mlir/Support/MathExtras.h" #include "llvm/ADT/DenseSet.h" @@ -273,15 +272,12 @@ static bool isVectorElement(LoadOrStoreOpPointer memoryOp) { return memRefType.getElementType().template isa(); } -static bool isVectorTransferReadOrWrite(Operation &op) { - return isa(op) || isa(op); -} - using VectorizableOpFun = std::function; static bool isVectorizableLoopBodyWithOpCond(AffineForOp loop, - VectorizableOpFun isVectorizableOp) { + VectorizableOpFun isVectorizableOp, + NestedPattern &vectorTransferMatcher) { auto *forOp = loop.getOperation(); // No vectorization across conditionals for now. 
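The effect of the invalidation hook added to the analysis manager above can be illustrated with a small, hypothetical analysis: when an analysis defines `isInvalidated`, that definition is what the manager consults instead of the plain preserved-set lookup. All names below other than the hook itself are illustrative.

```
#include "mlir/IR/Operation.h"
#include "mlir/Pass/AnalysisManager.h"

// Hypothetical analysis caching the number of nested operations.
struct OperationCountAnalysis {
  explicit OperationCountAnalysis(mlir::Operation *op) {
    op->walk([&](mlir::Operation *) { ++count; });
  }

  // Keep the cached result unless this analysis was explicitly invalidated;
  // a real analysis could also inspect other preserved analyses here.
  bool isInvalidated(const mlir::detail::PreservedAnalyses &pa) {
    return !pa.isPreserved<OperationCountAnalysis>();
  }

  unsigned count = 0;
};
```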
@@ -303,9 +299,8 @@ isVectorizableLoopBodyWithOpCond(AffineForOp loop, return false; } - auto vectorTransfers = matcher::Op(isVectorTransferReadOrWrite); SmallVector vectorTransfersMatched; - vectorTransfers.match(forOp, &vectorTransfersMatched); + vectorTransferMatcher.match(forOp, &vectorTransfersMatched); if (!vectorTransfersMatched.empty()) { return false; } @@ -331,18 +326,20 @@ isVectorizableLoopBodyWithOpCond(AffineForOp loop, return true; } -bool mlir::isVectorizableLoopBody(AffineForOp loop, int *memRefDim) { +bool mlir::isVectorizableLoopBody(AffineForOp loop, int *memRefDim, + NestedPattern &vectorTransferMatcher) { VectorizableOpFun fun([memRefDim](AffineForOp loop, Operation &op) { auto load = dyn_cast(op); auto store = dyn_cast(op); return load ? isContiguousAccess(loop.getInductionVar(), load, memRefDim) : isContiguousAccess(loop.getInductionVar(), store, memRefDim); }); - return isVectorizableLoopBodyWithOpCond(loop, fun); + return isVectorizableLoopBodyWithOpCond(loop, fun, vectorTransferMatcher); } -bool mlir::isVectorizableLoopBody(AffineForOp loop) { - return isVectorizableLoopBodyWithOpCond(loop, nullptr); +bool mlir::isVectorizableLoopBody(AffineForOp loop, + NestedPattern &vectorTransferMatcher) { + return isVectorizableLoopBodyWithOpCond(loop, nullptr, vectorTransferMatcher); } /// Checks whether SSA dominance would be violated if a for op's body diff --git a/third_party/mlir/lib/Analysis/SliceAnalysis.cpp b/third_party/mlir/lib/Analysis/SliceAnalysis.cpp index 718db0e76d2..700321ebb40 100644 --- a/third_party/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/third_party/mlir/lib/Analysis/SliceAnalysis.cpp @@ -20,7 +20,6 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/SliceAnalysis.h" -#include "mlir/Analysis/VectorAnalysis.h" #include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/LoopOps/LoopOps.h" #include "mlir/IR/Function.h" diff --git a/third_party/mlir/lib/Analysis/VectorAnalysis.cpp b/third_party/mlir/lib/Analysis/VectorAnalysis.cpp index 666ee071c63..42d3f10b14c 100644 --- a/third_party/mlir/lib/Analysis/VectorAnalysis.cpp +++ b/third_party/mlir/lib/Analysis/VectorAnalysis.cpp @@ -15,11 +15,11 @@ // limitations under the License. 
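With the signature change above, callers of `isVectorizableLoopBody` now own the vector-transfer matcher that used to be built internally. A caller-side sketch, mirroring the deleted `isVectorTransferReadOrWrite` helper; the wrapper function itself is hypothetical.

```
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/Analysis/NestedMatcher.h"
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/VectorOps/VectorOps.h"

static bool isVectorTransferReadOrWrite(mlir::Operation &op) {
  return llvm::isa<mlir::vector::TransferReadOp>(op) ||
         llvm::isa<mlir::vector::TransferWriteOp>(op);
}

// Hypothetical wrapper reproducing the old single-argument behaviour.
static bool isVectorizable(mlir::AffineForOp loop, int *memRefDim) {
  auto vectorTransferMatcher = mlir::matcher::Op(isVectorTransferReadOrWrite);
  return mlir::isVectorizableLoopBody(loop, memRefDim, vectorTransferMatcher);
}
```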
// ============================================================================= -#include "mlir/Analysis/VectorAnalysis.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/Dialect/VectorOps/Utils.h" #include "mlir/Dialect/VectorOps/VectorOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/IntegerSet.h" diff --git a/third_party/mlir/lib/Conversion/CMakeLists.txt b/third_party/mlir/lib/Conversion/CMakeLists.txt index 6d370f714e2..c791d214d30 100644 --- a/third_party/mlir/lib/Conversion/CMakeLists.txt +++ b/third_party/mlir/lib/Conversion/CMakeLists.txt @@ -8,4 +8,5 @@ add_subdirectory(LoopsToGPU) add_subdirectory(LoopToStandard) add_subdirectory(StandardToLLVM) add_subdirectory(StandardToSPIRV) -add_subdirectory(VectorConversions) +add_subdirectory(VectorToLLVM) +add_subdirectory(VectorToLoops) diff --git a/third_party/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/third_party/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h index d681e4c86ea..e06e88b92f1 100644 --- a/third_party/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h +++ b/third_party/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h @@ -97,7 +97,7 @@ private: return llvm::cast(*funcOp); mlir::OpBuilder b(op->getParentOfType()); - return b.create(op->getLoc(), funcName, funcType, llvm::None); + return b.create(op->getLoc(), funcName, funcType); } const std::string f32Func; diff --git a/third_party/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp b/third_party/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp index 9d8c8942051..f342083bee7 100644 --- a/third_party/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp +++ b/third_party/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp @@ -320,7 +320,7 @@ Value *GpuLaunchFuncToCudaCallsPass::generateKernelNameConstant( std::string globalName = llvm::formatv("{0}_kernel_name", name); return LLVM::createGlobalString( loc, builder, globalName, StringRef(kernelName.data(), kernelName.size()), - llvmDialect); + LLVM::Linkage::Internal, llvmDialect); } // Emits LLVM IR to launch a kernel function. Expects the module that contains @@ -368,7 +368,8 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls( SmallString<128> nameBuffer(*kernelModule.getName()); nameBuffer.append(kCubinStorageSuffix); Value *data = LLVM::createGlobalString( - loc, builder, nameBuffer.str(), cubinAttr.getValue(), getLLVMDialect()); + loc, builder, nameBuffer.str(), cubinAttr.getValue(), + LLVM::Linkage::Internal, getLLVMDialect()); // Emit the load module call to load the module data. Error checking is done // in the called helper function. 
diff --git a/third_party/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/third_party/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index f56508dfeba..54dd18e7492 100644 --- a/third_party/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/third_party/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -387,8 +387,8 @@ private: builder.getNamedAttr("addr_space", builder.getI32IntegerAttr(3)); auto globalOp = builder.create( loc, arrayType.cast(), - /*isConstant=*/false, name, /*value=*/Attribute(), - llvm::makeArrayRef(addrSpace)); + /*isConstant=*/false, LLVM::Linkage::Internal, name, + /*value=*/Attribute(), llvm::makeArrayRef(addrSpace)); return rewriter.create(loc, globalOp); } diff --git a/third_party/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/third_party/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp index ebb0fd75753..709dd3af7f0 100644 --- a/third_party/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/third_party/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -20,7 +20,7 @@ #include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" -#include "mlir/Conversion/VectorConversions/VectorConversions.h" +#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" @@ -340,6 +340,28 @@ public: } }; +template +static SmallVector ExtractOperandTypes(Operation *op) { + return SmallVector{op->getOperandTypes()}; +} + +template <> +SmallVector ExtractOperandTypes(Operation *op) { + auto ctx = op->getContext(); + auto indexedGenericOp = cast(op); + auto numLoops = indexedGenericOp.getNumLoops(); + + SmallVector result; + result.reserve(numLoops + op->getNumOperands()); + for (unsigned i = 0; i < numLoops; ++i) { + result.push_back(IndexType::get(ctx)); + } + for (auto type : op->getOperandTypes()) { + result.push_back(type); + } + return result; +} + // Get a SymbolRefAttr containing the library function name for the LinalgOp. // If the library function does not exist, insert a declaration. template @@ -359,7 +381,7 @@ static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op, return fnNameAttr; } - SmallVector inputTypes(op->getOperandTypes()); + SmallVector inputTypes(ExtractOperandTypes(op)); assert(op->getNumResults() == 0 && "Library call for linalg operation can be generated only for ops that " "have void return types"); @@ -430,6 +452,40 @@ public: } }; +/// Conversion pattern specialization for IndexedGenericOp. +template <> +class LinalgOpConversion + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(IndexedGenericOp op, + PatternRewriter &rewriter) const override { + auto libraryCallName = + getLibraryCallSymbolRef(op, rewriter); + if (!libraryCallName) + return this->matchFailure(); + + // TODO(pifon, ntv): Use induction variables values instead of zeros, when + // IndexedGenericOp is tiled. 
+ auto zero = rewriter.create( + op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); + auto indexedGenericOp = cast(op); + auto numLoops = indexedGenericOp.getNumLoops(); + SmallVector operands; + operands.reserve(numLoops + op.getNumOperands()); + for (unsigned i = 0; i < numLoops; ++i) { + operands.push_back(zero); + } + for (auto operand : op.getOperands()) { + operands.push_back(operand); + } + rewriter.replaceOpWithNewOp(op, libraryCallName.getValue(), + ArrayRef{}, operands); + return this->matchSuccess(); + } +}; + /// A non-conversion rewrite pattern kicks in to convert CopyOp with /// permutations into a sequence of TransposeOp and permutation-free CopyOp. /// This interplays together with TransposeOpConversion and diff --git a/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index ae2b7837c40..793997e9045 100644 --- a/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -443,9 +443,11 @@ struct FuncOpConversion : public LLVMLegalizationPattern { attributes.push_back(attr); } - // Create an LLVM funcion. + // Create an LLVM funcion, use external linkage by default until MLIR + // functions have linkage. auto newFuncOp = rewriter.create( - op->getLoc(), funcOp.getName(), llvmType, attributes); + op->getLoc(), funcOp.getName(), llvmType, LLVM::Linkage::External, + attributes); rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), newFuncOp.end()); @@ -1476,7 +1478,6 @@ struct SubViewOpLowering : public LLVMLegalizationPattern { ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto viewOp = cast(op); - SubViewOpOperandAdaptor adaptor(operands); // TODO(b/144779634, ravishankarm) : After Tblgen is adapted to support // having multiple variadic operands where each operand can have different // number of entries, clean all of this up. @@ -1505,10 +1506,12 @@ struct SubViewOpLowering : public LLVMLegalizationPattern { if (!sourceElementTy || !targetDescTy) return matchFailure(); - // Early exit for 0-D and operands lesser than `rank` corner cases. + // Currently, only rank > 0 and full or no operands are supported. Fail to + // convert otherwise. unsigned rank = sourceMemRefType.getRank(); - if (viewMemRefType.getRank() == 0 || rank != dynamicOffsets.size() || - rank != dynamicSizes.size() || rank != dynamicStrides.size()) + if (viewMemRefType.getRank() == 0 || (rank != dynamicOffsets.size()) || + (!dynamicSizes.empty() && rank != dynamicSizes.size()) || + (!dynamicStrides.empty() && rank != dynamicStrides.size())) return matchFailure(); int64_t offset; @@ -1518,7 +1521,7 @@ struct SubViewOpLowering : public LLVMLegalizationPattern { return matchFailure(); // Create the descriptor. - MemRefDescriptor sourceMemRef(adaptor.source()); + MemRefDescriptor sourceMemRef(operands.front()); auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy); // Copy the buffer pointer from the old descriptor to the new one. @@ -1538,6 +1541,17 @@ struct SubViewOpLowering : public LLVMLegalizationPattern { for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) strideValues.push_back(sourceMemRef.stride(rewriter, loc, i)); + // Fill in missing dynamic sizes. 
+ auto llvmIndexType = lowering.convertType(rewriter.getIndexType()); + if (dynamicSizes.empty()) { + dynamicSizes.reserve(viewMemRefType.getRank()); + auto shape = viewMemRefType.getShape(); + for (auto extent : shape) { + dynamicSizes.push_back(rewriter.create( + loc, llvmIndexType, rewriter.getI64IntegerAttr(extent))); + } + } + // Offset. Value *baseOffset = sourceMemRef.offset(rewriter, loc); for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) { @@ -1551,9 +1565,14 @@ struct SubViewOpLowering : public LLVMLegalizationPattern { // Update sizes and strides. for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) { targetMemRef.setSize(rewriter, loc, i, dynamicSizes[i]); - targetMemRef.setStride(rewriter, loc, i, - rewriter.create( - loc, dynamicStrides[i], strideValues[i])); + Value *newStride; + if (dynamicStrides.empty()) + newStride = rewriter.create( + loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i])); + else + newStride = rewriter.create(loc, dynamicStrides[i], + strideValues[i]); + targetMemRef.setStride(rewriter, loc, i, newStride); } rewriter.replaceOp(op, {targetMemRef}); @@ -1644,13 +1663,14 @@ struct ViewOpLowering : public LLVMLegalizationPattern { // Field 3: Copy the offset in aligned pointer. unsigned numDynamicSizes = llvm::size(viewOp.getDynamicSizes()); (void)numDynamicSizes; + bool hasDynamicOffset = offset == MemRefType::getDynamicStrideOrOffset(); auto sizeAndOffsetOperands = adaptor.operands(); - assert(llvm::size(sizeAndOffsetOperands) == numDynamicSizes + 1 || - offset != MemRefType::getDynamicStrideOrOffset()); - Value *baseOffset = (offset != MemRefType::getDynamicStrideOrOffset()) + assert(llvm::size(sizeAndOffsetOperands) == + numDynamicSizes + (hasDynamicOffset ? 1 : 0)); + Value *baseOffset = !hasDynamicOffset ? createIndexConstant(rewriter, loc, offset) // TODO(ntv): better adaptor. - : sizeAndOffsetOperands.back(); + : sizeAndOffsetOperands.front(); targetMemRef.setOffset(rewriter, loc, baseOffset); // Early exit for 0-D corner case. @@ -1662,10 +1682,14 @@ struct ViewOpLowering : public LLVMLegalizationPattern { return op->emitWarning("cannot cast to non-contiguous shape"), matchFailure(); Value *stride = nullptr, *nextSize = nullptr; + // Drop the dynamic stride from the operand list, if present. + ArrayRef sizeOperands(sizeAndOffsetOperands); + if (hasDynamicOffset) + sizeOperands = sizeOperands.drop_front(); for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) { // Update size. - Value *size = getSize(rewriter, loc, viewMemRefType.getShape(), - sizeAndOffsetOperands, i); + Value *size = + getSize(rewriter, loc, viewMemRefType.getShape(), sizeOperands, i); targetMemRef.setSize(rewriter, loc, i, size); // Update stride. 
stride = getStride(rewriter, loc, strides, nextSize, stride, i); diff --git a/third_party/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt b/third_party/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt index 351216216f1..fcced23a95e 100644 --- a/third_party/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt +++ b/third_party/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt @@ -5,6 +5,7 @@ add_public_tablegen_target(MLIRStandardToSPIRVIncGen) add_llvm_library(MLIRStandardToSPIRVTransforms ConvertStandardToSPIRV.cpp ConvertStandardToSPIRVPass.cpp + LegalizeStandardForSPIRV.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV diff --git a/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp index 62cabf66a0d..c2ca4c94878 100644 --- a/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ b/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -28,6 +28,48 @@ using namespace mlir; +//===----------------------------------------------------------------------===// +// Utility functions for operation conversion +//===----------------------------------------------------------------------===// + +/// Performs the index computation to get to the element pointed to by +/// `indices` using the layout map of `baseType`. + +// TODO(ravishankarm) : This method assumes that the `origBaseType` is a +// MemRefType with AffineMap that has static strides. Handle dynamic strides +spirv::AccessChainOp getElementPtr(OpBuilder &builder, + SPIRVTypeConverter &typeConverter, + Location loc, MemRefType origBaseType, + Value *basePtr, ArrayRef indices) { + // Get base and offset of the MemRefType and verify they are static. + int64_t offset; + SmallVector strides; + if (failed(getStridesAndOffset(origBaseType, strides, offset)) || + llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) { + return nullptr; + } + + auto indexType = typeConverter.getIndexType(builder.getContext()); + + Value *ptrLoc = nullptr; + assert(indices.size() == strides.size()); + for (auto index : enumerate(indices)) { + Value *strideVal = builder.create( + loc, indexType, IntegerAttr::get(indexType, strides[index.index()])); + Value *update = + builder.create(loc, strideVal, index.value()); + ptrLoc = + (ptrLoc ? builder.create(loc, ptrLoc, update).getResult() + : update); + } + SmallVector linearizedIndices; + // Add a '0' at the start to index into the struct. + linearizedIndices.push_back(builder.create( + loc, indexType, IntegerAttr::get(indexType, 0))); + linearizedIndices.push_back(ptrLoc); + return builder.create(loc, basePtr, linearizedIndices); +} + //===----------------------------------------------------------------------===// // Operation conversion //===----------------------------------------------------------------------===// @@ -38,6 +80,7 @@ namespace { /// operation. Since IndexType is not used within SPIR-V dialect, this needs /// special handling to make sure the result type and the type of the value /// attribute are consistent. +// TODO(ravishankarm) : This should be moved into DRR. class ConstantIndexOpConversion final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; @@ -112,6 +155,7 @@ public: /// the type of the return value of the replacement operation differs from /// that of the replaced operation. This is not handled in tablegen-based /// pattern specification. 
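The address computation in the new `getElementPtr` above boils down to a dot product of the load/store indices with the memref's static strides, prefixed by a constant zero to step into the wrapping struct. A plain-arithmetic sketch of that offset, outside of MLIR, just to spell out the math:

```
#include <cstdint>
#include <vector>

// Illustrative only: the linear offset materialized by the IMul/IAdd chain
// before it is appended to the [0, offset] access chain.
int64_t linearizedOffset(const std::vector<int64_t> &indices,
                         const std::vector<int64_t> &strides) {
  int64_t offset = 0;
  for (std::size_t i = 0, e = indices.size(); i < e; ++i)
    offset += indices[i] * strides[i];
  return offset;
}
// For example, a memref<12x42xf32> has static strides [42, 1], so element
// (i, j) resolves to offset i * 42 + j.
```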
+// TODO(ravishankarm) : This should be moved into DRR. template class IntegerOpConversion final : public SPIRVOpLowering { public: @@ -128,37 +172,10 @@ public: } }; -// If 'basePtr' is the result of lowering a value of MemRefType, and 'indices' -// are the indices used to index into the original value (for load/store), -// perform the equivalent address calculation in SPIR-V. -spirv::AccessChainOp getElementPtr(OpBuilder &builder, Location loc, - Value *basePtr, ArrayRef indices, - SPIRVTypeConverter &typeConverter) { - // MemRefType is converted to a - // spirv::StructType>> - auto ptrType = basePtr->getType().cast(); - (void)ptrType; - auto structType = ptrType.getPointeeType().cast(); - (void)structType; - assert(structType.getNumElements() == 1); - auto indexType = typeConverter.getIndexType(builder.getContext()); - - // Need to add a '0' at the beginning of the index list for accessing into the - // struct that wraps the nested array types. - Value *zero = builder.create( - loc, indexType, builder.getIntegerAttr(indexType, 0)); - SmallVector accessIndices; - accessIndices.reserve(1 + indices.size()); - accessIndices.push_back(zero); - accessIndices.append(indices.begin(), indices.end()); - return builder.create(loc, basePtr, accessIndices); -} - /// Convert load -> spv.LoadOp. The operands of the replaced operation are of /// IndexType while that of the replacement operation are of type i32. This is /// not supported in tablegen based pattern specification. -// TODO(ravishankarm) : These could potentially be templated on the operation -// being converted, since the same logic should work for linalg.load. +// TODO(ravishankarm) : This should be moved into DRR. class LoadOpConversion final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; @@ -167,9 +184,9 @@ public: matchAndRewrite(LoadOp loadOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { LoadOpOperandAdaptor loadOperands(operands); - auto basePtr = loadOperands.memref(); - auto loadPtr = getElementPtr(rewriter, loadOp.getLoc(), basePtr, - loadOperands.indices(), typeConverter); + auto loadPtr = getElementPtr(rewriter, typeConverter, loadOp.getLoc(), + loadOp.memref()->getType().cast(), + loadOperands.memref(), loadOperands.indices()); rewriter.replaceOpWithNewOp(loadOp, loadPtr, /*memory_access =*/nullptr, /*alignment =*/nullptr); @@ -178,6 +195,7 @@ public: }; /// Convert return -> spv.Return. +// TODO(ravishankarm) : This should be moved into DRR. class ReturnToSPIRVConversion final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; @@ -194,6 +212,7 @@ public: }; /// Convert select -> spv.Select +// TODO(ravishankarm) : This should be moved into DRR. class SelectOpConversion final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; @@ -211,8 +230,7 @@ public: /// Convert store -> spv.StoreOp. The operands of the replaced operation are /// of IndexType while that of the replacement operation are of type i32. This /// is not supported in tablegen based pattern specification. -// TODO(ravishankarm) : These could potentially be templated on the operation -// being converted, since the same logic should work for linalg.store. +// TODO(ravishankarm) : This should be moved into DRR. 
class StoreOpConversion final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; @@ -221,11 +239,12 @@ public: matchAndRewrite(StoreOp storeOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { StoreOpOperandAdaptor storeOperands(operands); - auto value = storeOperands.value(); - auto basePtr = storeOperands.memref(); - auto storePtr = getElementPtr(rewriter, storeOp.getLoc(), basePtr, - storeOperands.indices(), typeConverter); - rewriter.replaceOpWithNewOp(storeOp, storePtr, value, + auto storePtr = + getElementPtr(rewriter, typeConverter, storeOp.getLoc(), + storeOp.memref()->getType().cast(), + storeOperands.memref(), storeOperands.indices()); + rewriter.replaceOpWithNewOp(storeOp, storePtr, + storeOperands.value(), /*memory_access =*/nullptr, /*alignment =*/nullptr); return matchSuccess(); @@ -243,8 +262,8 @@ namespace mlir { void populateStandardToSPIRVPatterns(MLIRContext *context, SPIRVTypeConverter &typeConverter, OwningRewritePatternList &patterns) { + // Add patterns that lower operations into SPIR-V dialect. populateWithGenerated(context, &patterns); - // Add the return op conversion. patterns .insert, diff --git a/third_party/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/third_party/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp new file mode 100644 index 00000000000..1e8afbf43e1 --- /dev/null +++ b/third_party/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp @@ -0,0 +1,192 @@ +//===- LegalizeStandardForSPIRV.cpp - Legalize ops for SPIR-V lowering ----===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This transformation pass legalizes operations before the conversion to SPIR-V +// dialect to handle ops that cannot be lowered directly. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h" +#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace { +/// Merges subview operation with load operation. +class LoadOpOfSubViewFolder final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(LoadOp loadOp, + PatternRewriter &rewriter) const override; +}; + +/// Merges subview operation with store operation. +class StoreOpOfSubViewFolder final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(StoreOp storeOp, + PatternRewriter &rewriter) const override; +}; +} // namespace + +//===----------------------------------------------------------------------===// +// Utility functions for op legalization. 
+//===----------------------------------------------------------------------===// + +/// Given the 'indices' of an load/store operation where the memref is a result +/// of a subview op, returns the indices w.r.t to the source memref of the +/// subview op. For example +/// +/// %0 = ... : memref<12x42xf32> +/// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to +/// memref<4x4xf32, offset=?, strides=[?, ?]> +/// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]> +/// +/// could be folded into +/// +/// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] : +/// memref<12x42xf32> +static LogicalResult +resolveSourceIndices(Location loc, PatternRewriter &rewriter, + SubViewOp subViewOp, ArrayRef indices, + SmallVectorImpl &sourceIndices) { + // TODO: Aborting when the offsets are static. There might be a way to fold + // the subview op with load even if the offsets have been canonicalized + // away. + if (subViewOp.getNumOffsets() == 0) + return failure(); + + SmallVector opOffsets = llvm::to_vector<2>(subViewOp.offsets()); + SmallVector opStrides; + if (subViewOp.getNumStrides()) { + // If the strides are dynamic, get the stride operands. + opStrides = llvm::to_vector<2>(subViewOp.strides()); + } else { + // When static, the stride operands can be retrieved by taking the strides + // of the result of the subview op, and dividing the strides of the base + // memref. + SmallVector staticStrides; + if (failed(subViewOp.getStaticStrides(staticStrides))) { + return failure(); + } + opStrides.reserve(opOffsets.size()); + for (auto stride : staticStrides) { + auto constValAttr = rewriter.getIntegerAttr( + IndexType::get(rewriter.getContext()), stride); + opStrides.emplace_back(rewriter.create(loc, constValAttr)); + } + } + assert(opOffsets.size() == opStrides.size()); + + // New indices for the load are the current indices * subview_stride + + // subview_offset. + assert(indices.size() == opStrides.size()); + sourceIndices.resize(indices.size()); + for (auto index : enumerate(indices)) { + auto offset = opOffsets[index.index()]; + auto stride = opStrides[index.index()]; + auto mul = rewriter.create(loc, index.value(), stride); + sourceIndices[index.index()] = + rewriter.create(loc, offset, mul).getResult(); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// Folding SubViewOp and LoadOp. +//===----------------------------------------------------------------------===// + +PatternMatchResult +LoadOpOfSubViewFolder::matchAndRewrite(LoadOp loadOp, + PatternRewriter &rewriter) const { + auto subViewOp = + dyn_cast_or_null(loadOp.memref()->getDefiningOp()); + if (!subViewOp) { + return matchFailure(); + } + SmallVector sourceIndices, + indices = llvm::to_vector<4>(loadOp.indices()); + if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp, indices, + sourceIndices))) + return matchFailure(); + + rewriter.replaceOpWithNewOp(loadOp, subViewOp.source(), + sourceIndices); + return matchSuccess(); +} + +//===----------------------------------------------------------------------===// +// Folding SubViewOp and StoreOp. 
+//===----------------------------------------------------------------------===// + +PatternMatchResult +StoreOpOfSubViewFolder::matchAndRewrite(StoreOp storeOp, + PatternRewriter &rewriter) const { + auto subViewOp = + dyn_cast_or_null(storeOp.memref()->getDefiningOp()); + if (!subViewOp) { + return matchFailure(); + } + SmallVector sourceIndices, + indices = llvm::to_vector<4>(storeOp.indices()); + if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp, + indices, sourceIndices))) + return matchFailure(); + + rewriter.replaceOpWithNewOp(storeOp, storeOp.value(), + subViewOp.source(), sourceIndices); + return matchSuccess(); +} + +//===----------------------------------------------------------------------===// +// Hook for adding patterns. +//===----------------------------------------------------------------------===// + +void mlir::populateStdLegalizationPatternsForSPIRVLowering( + MLIRContext *context, OwningRewritePatternList &patterns) { + patterns.insert(context); +} + +//===----------------------------------------------------------------------===// +// Pass for testing just the legalization patterns. +//===----------------------------------------------------------------------===// + +namespace { +struct SPIRVLegalization final : public OperationPass { + void runOnOperation() override; +}; +} // namespace + +void SPIRVLegalization::runOnOperation() { + OwningRewritePatternList patterns; + auto *context = &getContext(); + populateStdLegalizationPatternsForSPIRVLowering(context, patterns); + applyPatternsGreedily(getOperation()->getRegions(), patterns); +} + +std::unique_ptr mlir::createLegalizeStdOpsForSPIRVLoweringPass() { + return std::make_unique(); +} + +static PassRegistration + pass("legalize-std-for-spirv", "Legalize standard ops for SPIR-V lowering"); diff --git a/third_party/mlir/lib/Conversion/VectorConversions/CMakeLists.txt b/third_party/mlir/lib/Conversion/VectorConversions/CMakeLists.txt deleted file mode 100644 index c8d699e4462..00000000000 --- a/third_party/mlir/lib/Conversion/VectorConversions/CMakeLists.txt +++ /dev/null @@ -1,18 +0,0 @@ -add_llvm_library(MLIRVectorConversions - VectorToLLVM.cpp - VectorToLoops.cpp - VectorToVector.cpp - - ADDITIONAL_HEADER_DIRS - ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorConversions -) -set(LIBS - MLIRLLVMIR - MLIRTransforms - LLVMCore - LLVMSupport - ) - -add_dependencies(MLIRVectorConversions ${LIBS}) -add_dependencies(MLIRVectorConversions MLIRVectorTransformPatternsIncGen) -target_link_libraries(MLIRVectorConversions ${LIBS}) diff --git a/third_party/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/third_party/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt new file mode 100644 index 00000000000..2aaec68f6c4 --- /dev/null +++ b/third_party/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt @@ -0,0 +1,15 @@ +add_llvm_library(MLIRVectorToLLVM + ConvertVectorToLLVM.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToLLVM +) +set(LIBS + MLIRLLVMIR + MLIRTransforms + LLVMCore + LLVMSupport + ) + +add_dependencies(MLIRVectorToLLVM ${LIBS}) +target_link_libraries(MLIRVectorToLLVM ${LIBS}) diff --git a/third_party/mlir/lib/Conversion/VectorConversions/VectorToLLVM.cpp b/third_party/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp similarity index 99% rename from third_party/mlir/lib/Conversion/VectorConversions/VectorToLLVM.cpp rename to third_party/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 5420ad05ae1..7221998ce25 100644 --- 
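A minimal sketch of exercising the new legalization from C++ through the registered pass; it assumes the `createLegalizeStdOpsForSPIRVLoweringPass` factory is declared in the pass header the new file includes, and the module/context setup is left out.

```
#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
#include "mlir/IR/Module.h"
#include "mlir/Pass/PassManager.h"

// Fold subview ops into loads/stores so the later std-to-SPIR-V conversion
// does not have to handle them.
mlir::LogicalResult legalizeForSPIRV(mlir::ModuleOp module,
                                     mlir::MLIRContext *context) {
  mlir::PassManager pm(context);
  pm.addPass(mlir::createLegalizeStdOpsForSPIRVLoweringPass());
  return pm.run(module);
}
```

The same patterns can also be pulled into an existing pass via `populateStdLegalizationPatternsForSPIRVLowering`, as the test-only `SPIRVLegalization` pass above does.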
a/third_party/mlir/lib/Conversion/VectorConversions/VectorToLLVM.cpp +++ b/third_party/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -17,7 +17,7 @@ #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" -#include "mlir/Conversion/VectorConversions/VectorConversions.h" +#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/VectorOps/VectorOps.h" #include "mlir/IR/Attributes.h" diff --git a/third_party/mlir/lib/Conversion/VectorToLoops/CMakeLists.txt b/third_party/mlir/lib/Conversion/VectorToLoops/CMakeLists.txt new file mode 100644 index 00000000000..e213bc9bcce --- /dev/null +++ b/third_party/mlir/lib/Conversion/VectorToLoops/CMakeLists.txt @@ -0,0 +1,15 @@ +add_llvm_library(MLIRVectorToLoops + ConvertVectorToLoops.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToLoops +) +set(LIBS + MLIRLLVMIR + MLIRTransforms + LLVMCore + LLVMSupport + ) + +add_dependencies(MLIRVectorToLoops ${LIBS}) +target_link_libraries(MLIRVectorToLoops ${LIBS}) diff --git a/third_party/mlir/lib/Conversion/VectorConversions/VectorToLoops.cpp b/third_party/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp similarity index 99% rename from third_party/mlir/lib/Conversion/VectorConversions/VectorToLoops.cpp rename to third_party/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp index 74479b9922d..43ad91ce878 100644 --- a/third_party/mlir/lib/Conversion/VectorConversions/VectorToLoops.cpp +++ b/third_party/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp @@ -21,8 +21,8 @@ #include -#include "mlir/Conversion/VectorConversions/VectorConversions.h" #include "mlir/Dialect/VectorOps/VectorOps.h" +#include "mlir/Conversion/VectorToLoops/ConvertVectorToLoops.h" #include "mlir/EDSC/Builders.h" #include "mlir/EDSC/Helpers.h" #include "mlir/IR/AffineExpr.h" @@ -41,6 +41,8 @@ using vector::TransferWriteOp; namespace { +using vector_type_cast = edsc::intrinsics::ValueBuilder; + /// Implements lowering of TransferReadOp and TransferWriteOp to a /// proper abstraction for the hardware. /// @@ -356,7 +358,6 @@ PatternMatchResult VectorTransferRewriter::matchAndRewrite( } } // namespace -/// Populate the given list with patterns that convert from Vector to LLVM. 
void mlir::populateVectorToAffineLoopsConversionPatterns( MLIRContext *context, OwningRewritePatternList &patterns) { patterns.insert, diff --git a/third_party/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/third_party/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 5fc1cade760..8d84fadae8a 100644 --- a/third_party/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/third_party/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -24,6 +24,7 @@ #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" +#include "mlir/IR/FunctionImplementation.h" #include "mlir/IR/Module.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" diff --git a/third_party/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/third_party/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index ca71db5fd8d..00911012c1d 100644 --- a/third_party/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/third_party/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -22,6 +22,7 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/FunctionImplementation.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" #include "mlir/IR/StandardTypes.h" @@ -861,9 +862,13 @@ static ParseResult parseConstantOp(OpAsmParser &parser, // Builder, printer and verifier for LLVM::GlobalOp. //===----------------------------------------------------------------------===// +/// Returns the name used for the linkge attribute. This *must* correspond to +/// the name of the attribute in ODS. +static StringRef getLinkageAttrName() { return "linkage"; } + void GlobalOp::build(Builder *builder, OperationState &result, LLVMType type, - bool isConstant, StringRef name, Attribute value, - ArrayRef attrs) { + bool isConstant, Linkage linkage, StringRef name, + Attribute value, ArrayRef attrs) { result.addAttribute(SymbolTable::getSymbolAttrName(), builder->getStringAttr(name)); result.addAttribute("type", TypeAttr::get(type)); @@ -871,12 +876,50 @@ void GlobalOp::build(Builder *builder, OperationState &result, LLVMType type, result.addAttribute("constant", builder->getUnitAttr()); if (value) result.addAttribute("value", value); + result.addAttribute(getLinkageAttrName(), builder->getI64IntegerAttr( + static_cast(linkage))); result.attributes.append(attrs.begin(), attrs.end()); result.addRegion(); } +// Returns the textual representation of the given linkage. +static StringRef linkageToStr(LLVM::Linkage linkage) { + switch (linkage) { + case LLVM::Linkage::Private: + return "private"; + case LLVM::Linkage::Internal: + return "internal"; + case LLVM::Linkage::AvailableExternally: + return "available_externally"; + case LLVM::Linkage::Linkonce: + return "linkonce"; + case LLVM::Linkage::Weak: + return "weak"; + case LLVM::Linkage::Common: + return "common"; + case LLVM::Linkage::Appending: + return "appending"; + case LLVM::Linkage::ExternWeak: + return "extern_weak"; + case LLVM::Linkage::LinkonceODR: + return "linkonce_odr"; + case LLVM::Linkage::WeakODR: + return "weak_odr"; + case LLVM::Linkage::External: + return "external"; + } + llvm_unreachable("unknown linkage type"); +} + +// Prints the keyword for the linkage type using the printer. 
+static void printLinkage(OpAsmPrinter &p, LLVM::Linkage linkage) { + p << linkageToStr(linkage); +} + static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) { p << op.getOperationName() << ' '; + printLinkage(p, op.linkage()); + p << ' '; if (op.constant()) p << "constant "; p.printSymbolName(op.sym_name()); @@ -884,8 +927,9 @@ static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) { if (auto value = op.getValueOrNull()) p.printAttribute(value); p << ')'; - p.printOptionalAttrDict(op.getAttrs(), {SymbolTable::getSymbolAttrName(), - "type", "constant", "value"}); + p.printOptionalAttrDict(op.getAttrs(), + {SymbolTable::getSymbolAttrName(), "type", "constant", + "value", getLinkageAttrName()}); // Print the trailing type unless it's a string global. if (op.getValueOrNull().dyn_cast_or_null()) @@ -898,12 +942,46 @@ static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) { p.printRegion(initializer, /*printEntryBlockArgs=*/false); } -// ::= `llvm.mlir.global` `constant`? `@` identifier -// `(` attribute? `)` attribute-list? (`:` type)? region? +// Parses one of the keywords provided in the list `keywords` and returns the +// position of the parsed keyword in the list. If none of the keywords from the +// list is parsed, returns -1. +static int parseOptionalKeywordAlternative(OpAsmParser &parser, + ArrayRef keywords) { + for (auto en : llvm::enumerate(keywords)) { + if (succeeded(parser.parseOptionalKeyword(en.value()))) + return en.index(); + } + return -1; +} + +// Parses one of the linkage keywords and, if succeeded, appends the "linkage" +// integer attribute with the corresponding value to `result`. +// +// linkage ::= `private` | `internal` | `available_externally` | `linkonce` +// | `weak` | `common` | `appending` | `extern_weak` +// | `linkonce_odr` | `weak_odr` | `external +static ParseResult parseOptionalLinkageKeyword(OpAsmParser &parser, + OperationState &result) { + int index = parseOptionalKeywordAlternative( + parser, {"private", "internal", "available_externally", "linkonce", + "weak", "common", "appending", "extern_weak", "linkonce_odr", + "weak_odr", "external"}); + if (index == -1) + return failure(); + result.addAttribute(getLinkageAttrName(), + parser.getBuilder().getI64IntegerAttr(index)); + return success(); +} + +// operation ::= `llvm.mlir.global` linkage `constant`? `@` identifier +// `(` attribute? `)` attribute-list? (`:` type)? region? // // The type can be omitted for string attributes, in which case it will be // inferred from the value of the string as [strlen(value) x i8]. 
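// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the patch: with the linkage-aware
// GlobalOp::build introduced above, a global can be created programmatically
// roughly as below. `createInternalI32Global` is a hypothetical helper; the
// trailing NamedAttribute list is assumed to default to empty.
static LLVM::GlobalOp createInternalI32Global(OpBuilder &builder, Location loc,
                                              LLVM::LLVMDialect *llvmDialect) {
  auto i32 = LLVM::LLVMType::getInt32Ty(llvmDialect);
  return builder.create<LLVM::GlobalOp>(
      loc, i32, /*isConstant=*/true, LLVM::Linkage::Internal, "answer",
      builder.getI32IntegerAttr(42));
}
// Per printGlobalOp/parseGlobalOp below, this is expected to round-trip
// through textual IR roughly as:
//   llvm.mlir.global internal constant @answer(42 : i32)
// ---------------------------------------------------------------------------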
static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) { + if (failed(parseOptionalLinkageKeyword(parser, result))) + return parser.emitError(parser.getCurrentLocation(), "expected linkage"); + if (succeeded(parser.parseOptionalKeyword("constant"))) result.addAttribute("constant", parser.getBuilder().getUnitAttr()); @@ -1039,12 +1117,15 @@ static ParseResult parseShuffleVectorOp(OpAsmParser &parser, //===----------------------------------------------------------------------===// void LLVMFuncOp::build(Builder *builder, OperationState &result, StringRef name, - LLVMType type, ArrayRef attrs, + LLVMType type, LLVM::Linkage linkage, + ArrayRef attrs, ArrayRef argAttrs) { result.addRegion(); result.addAttribute(SymbolTable::getSymbolAttrName(), builder->getStringAttr(name)); result.addAttribute("type", TypeAttr::get(type)); + result.addAttribute(getLinkageAttrName(), builder->getI64IntegerAttr( + static_cast(linkage))); result.attributes.append(attrs.begin(), attrs.end()); if (argAttrs.empty()) return; @@ -1058,15 +1139,16 @@ void LLVMFuncOp::build(Builder *builder, OperationState &result, StringRef name, result.addAttribute(getArgAttrName(i, argAttrName), argDict); } -// Build an LLVM function type from the given lists of input and output types. +// Builds an LLVM function type from the given lists of input and output types. // Returns a null type if any of the types provided are non-LLVM types, or if // there is more than one output type. -static Type buildLLVMFunctionType(Builder &b, ArrayRef inputs, - ArrayRef outputs, - impl::VariadicFlag variadicFlag, - std::string &errorMessage) { +static Type buildLLVMFunctionType(OpAsmParser &parser, llvm::SMLoc loc, + ArrayRef inputs, ArrayRef outputs, + impl::VariadicFlag variadicFlag) { + Builder &b = parser.getBuilder(); if (outputs.size() > 1) { - errorMessage = "expected zero or one function result"; + parser.emitError(loc, "failed to construct function type: expected zero or " + "one function result"); return {}; } @@ -1075,7 +1157,8 @@ static Type buildLLVMFunctionType(Builder &b, ArrayRef inputs, for (auto t : inputs) { auto llvmTy = t.dyn_cast(); if (!llvmTy) { - errorMessage = "expected LLVM type for function arguments"; + parser.emitError(loc, "failed to construct function type: expected LLVM " + "type for function arguments"); return {}; } llvmInputs.push_back(llvmTy); @@ -1091,16 +1174,71 @@ static Type buildLLVMFunctionType(Builder &b, ArrayRef inputs, LLVMType llvmOutput = outputs.empty() ? LLVMType::getVoidTy(dialect) : outputs.front().dyn_cast(); if (!llvmOutput) { - errorMessage = "expected LLVM type for function results"; + parser.emitError(loc, "failed to construct function type: expected LLVM " + "type for function results"); return {}; } return LLVMType::getFunctionTy(llvmOutput, llvmInputs, variadicFlag.isVariadic()); } -// Print the LLVMFuncOp. Collects argument and result types and passes them -// to the trait printer. Drops "void" result since it cannot be parsed back. +// Parses an LLVM function. +// +// operation ::= `llvm.func` linkage? function-signature function-attributes? +// function-body +// +static ParseResult parseLLVMFuncOp(OpAsmParser &parser, + OperationState &result) { + // Default to external linkage if no keyword is provided. 
+ if (failed(parseOptionalLinkageKeyword(parser, result))) + result.addAttribute(getLinkageAttrName(), + parser.getBuilder().getI64IntegerAttr( + static_cast(LLVM::Linkage::External))); + + StringAttr nameAttr; + SmallVector entryArgs; + SmallVector, 1> argAttrs; + SmallVector, 1> resultAttrs; + SmallVector argTypes; + SmallVector resultTypes; + bool isVariadic; + + auto signatureLocation = parser.getCurrentLocation(); + if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), + result.attributes) || + impl::parseFunctionSignature(parser, /*allowVariadic=*/true, entryArgs, + argTypes, argAttrs, isVariadic, resultTypes, + resultAttrs)) + return failure(); + + auto type = + buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes, + impl::VariadicFlag(isVariadic)); + if (!type) + return failure(); + result.addAttribute(impl::getTypeAttrName(), TypeAttr::get(type)); + + if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) + return failure(); + impl::addArgAndResultAttrs(parser.getBuilder(), result, argAttrs, + resultAttrs); + + auto *body = result.addRegion(); + return parser.parseOptionalRegion( + *body, entryArgs, entryArgs.empty() ? llvm::ArrayRef() : argTypes); +} + +// Print the LLVMFuncOp. Collects argument and result types and passes them to +// helper functions. Drops "void" result since it cannot be parsed back. Skips +// the external linkage since it is the default value. static void printLLVMFuncOp(OpAsmPrinter &p, LLVMFuncOp op) { + p << op.getOperationName() << ' '; + if (op.linkage() != LLVM::Linkage::External) { + printLinkage(p, op.linkage()); + p << ' '; + } + p.printSymbolName(op.getName()); + LLVMType fnType = op.getType(); SmallVector argTypes; SmallVector resTypes; @@ -1112,7 +1250,15 @@ static void printLLVMFuncOp(OpAsmPrinter &p, LLVMFuncOp op) { if (!returnType.getUnderlyingType()->isVoidTy()) resTypes.push_back(returnType); - impl::printFunctionLikeOp(p, op, argTypes, op.isVarArg(), resTypes); + impl::printFunctionSignature(p, op, argTypes, op.isVarArg(), resTypes); + impl::printFunctionAttributes(p, op, argTypes.size(), resTypes.size(), + {getLinkageAttrName()}); + + // Print the body if this is not an external function. + Region &body = op.body(); + if (!body.empty()) + p.printRegion(body, /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/true); } // Hook for OpTrait::FunctionLike, called after verifying that the 'type' @@ -1148,9 +1294,26 @@ unsigned LLVMFuncOp::getNumFuncResults() { return 1; } +// Verifies LLVM- and implementation-specific properties of the LLVM func Op: +// - functions don't have 'common' linkage +// - external functions have 'external' or 'extern_weak' linkage; +// - vararg is (currently) only supported for external functions; +// - entry block arguments are of LLVM types and match the function signature. 
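// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the patch: a declaration built with the new
// linkage-aware LLVMFuncOp builder. `declareMalloc` is a hypothetical helper;
// the attribute lists are assumed to default to empty.
static LLVM::LLVMFuncOp declareMalloc(OpBuilder &builder, Location loc,
                                      LLVM::LLVMDialect *llvmDialect) {
  auto i64 = LLVM::LLVMType::getInt64Ty(llvmDialect);
  auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
  auto mallocType =
      LLVM::LLVMType::getFunctionTy(i8Ptr, {i64}, /*isVarArg=*/false);
  return builder.create<LLVM::LLVMFuncOp>(loc, "malloc", mallocType,
                                          LLVM::Linkage::External);
}
// With external linkage and no body, the verifier below accepts this, and the
// printer omits the `external` keyword since it is the default.
// ---------------------------------------------------------------------------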
static LogicalResult verify(LLVMFuncOp op) { - if (op.isExternal()) + if (op.linkage() == LLVM::Linkage::Common) + return op.emitOpError() + << "functions cannot have '" << linkageToStr(LLVM::Linkage::Common) + << "' linkage"; + + if (op.isExternal()) { + if (op.linkage() != LLVM::Linkage::External && + op.linkage() != LLVM::Linkage::ExternWeak) + return op.emitOpError() + << "external functions must have '" + << linkageToStr(LLVM::Linkage::External) << "' or '" + << linkageToStr(LLVM::Linkage::ExternWeak) << "' linkage"; return success(); + } if (op.isVarArg()) return op.emitOpError("only external functions can be variadic"); @@ -1488,6 +1651,7 @@ LLVMType LLVMType::getVoidTy(LLVMDialect *dialect) { Value *mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder, StringRef name, StringRef value, + LLVM::Linkage linkage, LLVM::LLVMDialect *llvmDialect) { assert(builder.getInsertionBlock() && builder.getInsertionBlock()->getParentOp() && @@ -1501,7 +1665,8 @@ Value *mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder, auto type = LLVM::LLVMType::getArrayTy(LLVM::LLVMType::getInt8Ty(llvmDialect), value.size()); auto global = moduleBuilder.create( - loc, type, /*isConstant=*/true, name, builder.getStringAttr(value)); + loc, type, /*isConstant=*/true, linkage, name, + builder.getStringAttr(value)); // Get the pointer to the first character in the global string. Value *globalPtr = builder.create(loc, global); diff --git a/third_party/mlir/lib/Dialect/Linalg/CMakeLists.txt b/third_party/mlir/lib/Dialect/Linalg/CMakeLists.txt index 4b7cd81be94..a4ce5038891 100644 --- a/third_party/mlir/lib/Dialect/Linalg/CMakeLists.txt +++ b/third_party/mlir/lib/Dialect/Linalg/CMakeLists.txt @@ -5,7 +5,7 @@ add_llvm_library(MLIRLinalg IR/LinalgTypes.cpp Transforms/Fusion.cpp Transforms/LinalgTransforms.cpp - Transforms/LowerToLoops.cpp + Transforms/LinalgToLoops.cpp Transforms/Promotion.cpp Transforms/Tiling.cpp Utils/Utils.cpp diff --git a/third_party/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp b/third_party/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp similarity index 64% rename from third_party/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp rename to third_party/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp index 0bf4ceaa33b..cf0b235f57f 100644 --- a/third_party/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp +++ b/third_party/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -19,6 +19,7 @@ #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h" #include "mlir/Dialect/Linalg/Utils/Intrinsics.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/LoopOps/LoopOps.h" @@ -41,12 +42,14 @@ using namespace mlir::linalg; using namespace mlir::linalg::intrinsics; using IndexedStdValue = TemplatedIndexedValue; +using IndexedAffineValue = TemplatedIndexedValue; + using edsc::op::operator+; using edsc::op::operator==; static SmallVector -foldedAffineApplies(OpBuilder &b, Location loc, AffineMap map, - ArrayRef vals, OperationFolder *folder) { +makeCanonicalAffineApplies(OpBuilder &b, Location loc, AffineMap map, + ArrayRef vals) { assert(map.getNumSymbols() == 0); assert(map.getNumInputs() == vals.size()); SmallVector res; @@ -56,17 +59,16 @@ foldedAffineApplies(OpBuilder &b, Location loc, AffineMap map, auto exprMap = AffineMap::get(dims, 0, e); SmallVector operands(vals.begin(), vals.end()); 
canonicalizeMapAndOperands(&exprMap, &operands); - res.push_back(affine_apply(folder, exprMap, operands)); + res.push_back(affine_apply(exprMap, operands)); } return res; } static SmallVector permuteIvs(ArrayRef ivs, - Optional permutation, - OperationFolder *folder) { + Optional permutation) { return permutation ? applyMapToValues(ScopedContext::getBuilder(), ScopedContext::getLocation(), - permutation.getValue(), ivs, folder) + permutation.getValue(), ivs) : SmallVector(ivs.begin(), ivs.end()); } @@ -75,20 +77,17 @@ static SmallVector permuteIvs(ArrayRef ivs, // which new loops will be created. static SmallVector emitLoopRanges(OpBuilder &b, Location loc, AffineMap map, - ArrayRef allViewSizes, - OperationFolder *folder); + ArrayRef allViewSizes); SmallVector emitLoopRanges(OpBuilder &b, Location loc, AffineMap map, - ArrayRef allViewSizes, - OperationFolder *folder) { + ArrayRef allViewSizes) { // Apply `map` to get view sizes in loop order. - auto sizes = applyMapToValues(b, loc, map, allViewSizes, folder); + auto sizes = applyMapToValues(b, loc, map, allViewSizes); // Create a new range with the applied tile sizes. ScopedContext scope(b, loc); SmallVector res; for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx) { - res.push_back(range(constant_index(folder, 0), sizes[idx], - constant_index(folder, 1))); + res.push_back(range(constant_index(0), sizes[idx], constant_index(1))); } return res; } @@ -99,14 +98,14 @@ class LinalgScopedEmitter {}; template class LinalgScopedEmitter { public: - static void emitScalarImplementation(ArrayRef allIvs, CopyOp copyOp, - OperationFolder *folder) { + static void emitScalarImplementation(ArrayRef allIvs, + CopyOp copyOp) { auto nPar = copyOp.getNumParallelLoops(); assert(nPar == allIvs.size()); auto inputIvs = - permuteIvs(allIvs.take_front(nPar), copyOp.inputPermutation(), folder); + permuteIvs(allIvs.take_front(nPar), copyOp.inputPermutation()); auto outputIvs = - permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation(), folder); + permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation()); SmallVector iivs(inputIvs.begin(), inputIvs.end()); SmallVector oivs(outputIvs.begin(), outputIvs.end()); IndexedValueType O(copyOp.getOutput(0)), I(copyOp.getInput(0)); @@ -122,8 +121,8 @@ public: template class LinalgScopedEmitter { public: - static void emitScalarImplementation(ArrayRef allIvs, FillOp fillOp, - OperationFolder *folder) { + static void emitScalarImplementation(ArrayRef allIvs, + FillOp fillOp) { auto nPar = fillOp.getNumParallelLoops(); assert(nPar == allIvs.size()); auto ivs = @@ -139,8 +138,7 @@ public: template class LinalgScopedEmitter { public: - static void emitScalarImplementation(ArrayRef allIvs, DotOp dotOp, - OperationFolder *folder) { + static void emitScalarImplementation(ArrayRef allIvs, DotOp dotOp) { assert(allIvs.size() == 1); IndexHandle r_i(allIvs[0]); IndexedValueType A(dotOp.getInput(0)), B(dotOp.getInput(1)), @@ -154,8 +152,7 @@ template class LinalgScopedEmitter { public: static void emitScalarImplementation(ArrayRef allIvs, - MatvecOp matvecOp, - OperationFolder *folder) { + MatvecOp matvecOp) { assert(allIvs.size() == 2); IndexHandle i(allIvs[0]), r_j(allIvs[1]); IndexedValueType A(matvecOp.getInput(0)), B(matvecOp.getInput(1)), @@ -169,8 +166,7 @@ template class LinalgScopedEmitter { public: static void emitScalarImplementation(ArrayRef allIvs, - MatmulOp matmulOp, - OperationFolder *folder) { + MatmulOp matmulOp) { assert(allIvs.size() == 3); IndexHandle i(allIvs[0]), j(allIvs[1]), 
r_k(allIvs[2]); IndexedValueType A(matmulOp.getInput(0)), B(matmulOp.getInput(1)), @@ -183,17 +179,17 @@ public: template class LinalgScopedEmitter { public: - static void emitScalarImplementation(ArrayRef allIvs, ConvOp convOp, - OperationFolder *folder) { + static void emitScalarImplementation(ArrayRef allIvs, + ConvOp convOp) { auto b = ScopedContext::getBuilder(); auto loc = ScopedContext::getLocation(); auto maps = loopToOperandRangesMaps(convOp); SmallVector fIdx( - foldedAffineApplies(b, loc, maps[0], allIvs, folder)); + makeCanonicalAffineApplies(b, loc, maps[0], allIvs)); SmallVector imIdx( - foldedAffineApplies(b, loc, maps[1], allIvs, folder)); + makeCanonicalAffineApplies(b, loc, maps[1], allIvs)); SmallVector oIdx( - foldedAffineApplies(b, loc, maps[2], allIvs, folder)); + makeCanonicalAffineApplies(b, loc, maps[2], allIvs)); IndexedValueType F(convOp.filter()), I(convOp.input()), O(convOp.output()); // Emit scalar form. O(oIdx) += F(fIdx) * I(imIdx); @@ -234,8 +230,7 @@ template class LinalgScopedEmitter { public: static void emitScalarImplementation(ArrayRef allIvs, - GenericOp genericOp, - OperationFolder *folder) { + GenericOp genericOp) { auto b = ScopedContext::getBuilder(); auto loc = ScopedContext::getLocation(); using edsc::intrinsics::detail::ValueHandleArray; @@ -245,15 +240,15 @@ public: // 1.a. Emit std_load from input views. for (unsigned i = 0; i < nInputs; ++i) { - ValueHandleArray indexing(foldedAffineApplies( - b, loc, genericOp.getInputIndexingMap(i), allIvs, folder)); + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, genericOp.getInputIndexingMap(i), allIvs)); indexedValues[i] = std_load(genericOp.getInput(i), indexing); } // 1.b. Emit std_load from output views. for (unsigned i = 0; i < nOutputs; ++i) { - ValueHandleArray indexing(foldedAffineApplies( - b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder)); + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, genericOp.getOutputIndexingMap(i), allIvs)); indexedValues[nInputs + i] = std_load(genericOp.getOutput(i), indexing); } @@ -265,8 +260,8 @@ public: // 3. Emit std_store. for (unsigned i = 0; i < nOutputs; ++i) { - ValueHandleArray indexing(foldedAffineApplies( - b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder)); + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, genericOp.getOutputIndexingMap(i), allIvs)); std_store(callOp->getResult(i), genericOp.getOutput(i), indexing); } return; @@ -288,8 +283,8 @@ public: auto *yieldOp = cast(block.back()).getOperation(); assert(yieldOp->getNumOperands() == nOutputs); for (unsigned i = 0; i < nOutputs; ++i) { - ValueHandleArray indexing(foldedAffineApplies( - b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder)); + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, genericOp.getOutputIndexingMap(i), allIvs)); std_store(map.lookup(yieldOp->getOperand(i)), genericOp.getOutput(i), indexing); } @@ -330,8 +325,7 @@ template class LinalgScopedEmitter { public: static void emitScalarImplementation(ArrayRef allIvs, - IndexedGenericOp indexedGenericOp, - OperationFolder *folder) { + IndexedGenericOp indexedGenericOp) { auto b = ScopedContext::getBuilder(); auto loc = ScopedContext::getLocation(); using edsc::intrinsics::detail::ValueHandleArray; @@ -346,16 +340,16 @@ public: // 1.a. Emit std_load from input views. 
for (unsigned i = 0; i < nInputs; ++i) { - ValueHandleArray indexing(foldedAffineApplies( - b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs, folder)); + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs)); indexedValues[nLoops + i] = std_load(indexedGenericOp.getInput(i), indexing); } // 1.b. Emit std_load from output views. for (unsigned i = 0; i < nOutputs; ++i) { - ValueHandleArray indexing(foldedAffineApplies( - b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs, folder)); + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); indexedValues[nLoops + nInputs + i] = std_load(indexedGenericOp.getOutput(i), indexing); } @@ -367,8 +361,8 @@ public: // 3. Emit std_store. for (unsigned i = 0; i < nOutputs; ++i) { - ValueHandleArray indexing(foldedAffineApplies( - b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs, folder)); + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); std_store(callOp->getResult(i), indexedGenericOp.getOutput(i), indexing); } @@ -391,96 +385,110 @@ public: auto *yieldOp = cast(block.back()).getOperation(); assert(yieldOp->getNumOperands() == nOutputs); for (unsigned i = 0; i < nOutputs; ++i) { - ValueHandleArray indexing(foldedAffineApplies( - b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs, folder)); + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); std_store(map.lookup(yieldOp->getOperand(i)), indexedGenericOp.getOutput(i), indexing); } } }; +namespace { +// This struct is for factoring out the implementation and support template +// instantiations in the following 2 cases: +// 1. Appending to a list of patterns via RewritePatternList. +// 2. Direct invocation via `linalgOpToLoops` and `linalgOpToAffineLoops`. +// The implementation must work both in DRR and inside a RewritePattern. As a +// consequence, (1) it is only allowed to emit new ops if the match is +// guaranteed to be a success, (2) it is not allowed erase/replace, and (3) an +// encompassing pattern must take care of the erasure logic. +template +class LinalgOpToLoopsImpl { +public: + static LogicalResult doit(Operation *op, PatternRewriter &rewriter); +}; +} // namespace + +template +LogicalResult LinalgOpToLoopsImpl::doit( + Operation *op, PatternRewriter &rewriter) { + OpBuilder b(op); + ScopedContext scope(b, op->getLoc()); + + // The flattened loopToOperandRangesMaps is expected to be an invertible + // permutation map (which is asserted in the inverse calculation). 
+ auto linalgOp = cast(op); + auto invertedMap = + inversePermutation(concatAffineMaps(loopToOperandRangesMaps(linalgOp))); + if (!invertedMap) { + LinalgScopedEmitter::emitScalarImplementation( + {}, linalgOp); + return success(); + } + + auto nPar = linalgOp.getNumParallelLoops(); + auto nRed = linalgOp.getNumReductionLoops(); + auto nWin = linalgOp.getNumWindowLoops(); + SmallVector allIvs(nPar + nRed + nWin); + SmallVector allPIvs = makeIndexHandlePointers(allIvs); + auto loopRanges = emitLoopRanges(scope.getBuilder(), scope.getLocation(), + invertedMap, getViewSizes(linalgOp)); + assert(loopRanges.size() == allIvs.size()); + + LoopNestRangeBuilder(allPIvs, loopRanges)([&] { + auto allIvValues = extractValues(allIvs); + LinalgScopedEmitter::emitScalarImplementation( + allIvValues, linalgOp); + }); + return success(); +} + template class LinalgRewritePattern : public RewritePattern { public: explicit LinalgRewritePattern(MLIRContext *context) - : RewritePattern(ConcreteOp::getOperationName(), /*benefit=*/1, context), - folder(context) {} + : RewritePattern(ConcreteOp::getOperationName(), 1, context) {} PatternMatchResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - OpBuilder b(op); - ScopedContext scope(b, op->getLoc()); - - // The flattened loopToOperandRangesMaps is expected to be an invertible - // permutation map (which is asserted in the inverse calculation). - auto linalgOp = cast(op); - auto invertedMap = - inversePermutation(concatAffineMaps(loopToOperandRangesMaps(linalgOp))); - if (!invertedMap) { - LinalgScopedEmitter::emitScalarImplementation({}, linalgOp, - &folder); - rewriter.eraseOp(op); - return matchSuccess(); - } - - auto nPar = linalgOp.getNumParallelLoops(); - auto nRed = linalgOp.getNumReductionLoops(); - auto nWin = linalgOp.getNumWindowLoops(); - SmallVector allIvs(nPar + nRed + nWin); - SmallVector allPIvs = makeIndexHandlePointers(allIvs); - auto loopRanges = - emitLoopRanges(scope.getBuilder(), scope.getLocation(), invertedMap, - getViewSizes(linalgOp), &folder); - assert(loopRanges.size() == allIvs.size()); - - // clang-format off; - LoopNestRangeBuilder(allPIvs, loopRanges)([&] { - auto allIvValues = extractValues(allIvs); - LinalgScopedEmitter::emitScalarImplementation(allIvValues, - linalgOp, - &folder); - }); - // clang-format on + using Impl = LinalgOpToLoopsImpl; + if (failed(Impl::doit(op, rewriter))) + return matchFailure(); rewriter.eraseOp(op); return matchSuccess(); } - - mutable OperationFolder folder; }; // Helper classes for type list expansion. template -class ConversionList; +class RewritePatternList; template -class ConversionList { +class RewritePatternList { public: static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) {} }; template -class ConversionList { +class RewritePatternList { public: static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) { patterns .insert>( ctx); - ConversionList::build(patterns, - ctx); + RewritePatternList::build( + patterns, ctx); } }; /// Populate the given list with patterns that convert from Linalg to LLVM. template -void ForOpRewritePatterns(OwningRewritePatternList &patterns, - MLIRContext *ctx) { - ConversionList::build(patterns, ctx); + >::build(patterns, ctx); } namespace { @@ -491,28 +499,114 @@ struct LowerLinalgToLoopsPass }; } // namespace +// Local folding pattern for AffineApplyOp that we can apply greedily. +// This replaces AffineApplyOp by the proper value in cases where the associated +// map is trivial. 
A trivial map here is defined as a map with a single result +// and either: +// 1. Zero operand + returns a single AffineConstantExpr +// 2. One operand + returns a single AffineDimExpr +// 3. One operands + returns a single AffineSymbolExpr +// +// In the first case, the AffineApplyOp is replaced by a new constant. In the +// other cases, it is replaced by its unique operand. +struct FoldAffineOp : public RewritePattern { + FoldAffineOp(MLIRContext *context) + : RewritePattern(AffineApplyOp::getOperationName(), 0, context) {} + + PatternMatchResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + AffineApplyOp affineApplyOp = cast(op); + auto map = affineApplyOp.getAffineMap(); + if (map.getNumResults() != 1 || map.getNumInputs() > 1) + return matchFailure(); + + AffineExpr expr = map.getResult(0); + if (map.getNumInputs() == 0) { + if (auto val = expr.dyn_cast()) { + rewriter.replaceOpWithNewOp(op, val.getValue()); + return matchSuccess(); + } + return matchFailure(); + } + if (expr.dyn_cast() || expr.dyn_cast()) { + rewriter.replaceOp(op, op->getOperand(0)); + return matchSuccess(); + } + return matchFailure(); + } +}; + template void LowerLinalgToLoopsPass::runOnFunction() { + auto *context = &this->getContext(); OwningRewritePatternList patterns; - ForOpRewritePatterns(patterns, - &this->getContext()); - - ConversionTarget target(this->getContext()); - target.addLegalDialect(); - target.addLegalDialect(); - target.addLegalDialect(); - if (failed(applyPartialConversion(this->getFunction(), target, patterns))) { - this->signalPassFailure(); - } + // Canonicalization and folding patterns applied greedily allow cleaning up + // the emitted IR on the fly. + // TODO(ntv) fold view and subview ops? + FillRewritePatterns(patterns, context); + DimOp::getCanonicalizationPatterns(patterns, context); + AffineApplyOp::getCanonicalizationPatterns(patterns, context); + patterns.insert(context); + // Just apply the patterns greedily. + applyPatternsGreedily(this->getFunction(), patterns); } +/// Create a pass to convert Linalg operations to loop.for loops and +/// std.load/std.store accesses. std::unique_ptr> -mlir::linalg::createLowerLinalgToLoopsPass() { +mlir::linalg::createConvertLinalgToLoopsPass() { return std::make_unique< LowerLinalgToLoopsPass>(); } +/// Create a pass to convert Linalg operations to affine.for loops and +/// affine_load/affine_store accesses. +/// Placeholder for now, this is NYI. +std::unique_ptr> +mlir::linalg::createConvertLinalgToAffineLoopsPass() { + return std::make_unique< + LowerLinalgToLoopsPass>(); +} + +// Emits a loop nest of `loop.for` with the proper body for `op`. +template +LogicalResult mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter, + Operation *op) { + return LinalgOpToLoopsImpl::doit( + op, rewriter); +} + +// Emits a loop nest of `affine.for` with the proper body for `op`. +template +LogicalResult mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter, + Operation *op) { + return LinalgOpToLoopsImpl::doit( + op, rewriter); +} + +// TODO(ntv) Need to make these instantiations more future-proof to avoid the +// need to update as soon as we add new ops. 
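// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the patch: how a downstream pattern can use
// the new public linalgOpToLoops entry point; erasure stays with the caller,
// as documented on LinalgOpToLoopsImpl above. `HypotheticalMatmulLowering` is
// a made-up name, mirroring LinalgRewritePattern in this file.
struct HypotheticalMatmulLowering : public RewritePattern {
  HypotheticalMatmulLowering(MLIRContext *ctx)
      : RewritePattern(MatmulOp::getOperationName(), /*benefit=*/1, ctx) {}
  PatternMatchResult matchAndRewrite(Operation *op,
                                     PatternRewriter &rewriter) const override {
    if (failed(linalgOpToLoops<MatmulOp>(rewriter, op)))
      return matchFailure();
    rewriter.eraseOp(op);
    return matchSuccess();
  }
};
// ---------------------------------------------------------------------------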
+#define INSTANTIATE_LINALG_OP_TO_LOOPS(OP_TYPE) \ + template LogicalResult mlir::linalg::linalgOpToLoops( \ + PatternRewriter & rewriter, Operation * op); \ + template LogicalResult mlir::linalg::linalgOpToAffineLoops( \ + PatternRewriter & rewriter, Operation * op); + +INSTANTIATE_LINALG_OP_TO_LOOPS(CopyOp) +INSTANTIATE_LINALG_OP_TO_LOOPS(FillOp) +INSTANTIATE_LINALG_OP_TO_LOOPS(DotOp) +INSTANTIATE_LINALG_OP_TO_LOOPS(MatvecOp) +INSTANTIATE_LINALG_OP_TO_LOOPS(MatmulOp) +INSTANTIATE_LINALG_OP_TO_LOOPS(ConvOp) +INSTANTIATE_LINALG_OP_TO_LOOPS(GenericOp) +INSTANTIATE_LINALG_OP_TO_LOOPS(IndexedGenericOp) + static PassRegistration> structuredLoopsPass( - "linalg-lower-to-loops", + "convert-linalg-to-loops", "Lower the operations from the linalg dialect into loops"); + +static PassRegistration> + affineLoopsPass( + "convert-linalg-to-affine-loops", + "Lower the operations from the linalg dialect into affine loops"); diff --git a/third_party/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/third_party/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp index 3c571add56a..e3b550223e5 100644 --- a/third_party/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp +++ b/third_party/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp @@ -68,8 +68,7 @@ mlir::spirv::getEntryPointABIAttr(ArrayRef localSize, // Type Conversion //===----------------------------------------------------------------------===// -namespace { -Type convertIndexType(MLIRContext *context) { +Type SPIRVTypeConverter::getIndexType(MLIRContext *context) { // Convert to 32-bit integers for now. Might need a way to control this in // future. // TODO(ravishankarm): It is porbably better to make it 64-bit integers. To @@ -82,27 +81,54 @@ Type convertIndexType(MLIRContext *context) { // TODO(ravishankarm): This is a utility function that should probably be // exposed by the SPIR-V dialect. Keeping it local till the use case arises. -Optional getTypeNumBytes(Type t) { +static Optional getTypeNumBytes(Type t) { if (auto integerType = t.dyn_cast()) { return integerType.getWidth() / 8; } else if (auto floatType = t.dyn_cast()) { return floatType.getWidth() / 8; + } else if (auto memRefType = t.dyn_cast()) { + // TODO: Layout should also be controlled by the ABI attributes. For now + // using the layout from MemRef. + int64_t offset; + SmallVector strides; + if (!memRefType.hasStaticShape() || + failed(getStridesAndOffset(memRefType, strides, offset))) { + return llvm::None; + } + // To get the size of the memref object in memory, the total size is the + // max(stride * dimension-size) computed for all dimensions times the size + // of the element. + auto elementSize = getTypeNumBytes(memRefType.getElementType()); + if (!elementSize) { + return llvm::None; + } + auto dims = memRefType.getShape(); + if (llvm::is_contained(dims, ShapedType::kDynamicSize) || + offset == MemRefType::getDynamicStrideOrOffset() || + llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) { + return llvm::None; + } + int64_t memrefSize = -1; + for (auto shape : enumerate(dims)) { + memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]); + } + return (offset + memrefSize) * elementSize.getValue(); } // TODO: Add size computation for other types. return llvm::None; } -Type typeConversionImpl(Type t) { - // Check if the type is SPIR-V supported. If so return the type. - if (spirv::SPIRVDialect::isValidType(t)) { - return t; +static Type convertStdType(Type type) { + // If the type is already valid in SPIR-V, directly return. 
+ if (spirv::SPIRVDialect::isValidType(type)) { + return type; } - if (auto indexType = t.dyn_cast()) { - return convertIndexType(t.getContext()); + if (auto indexType = type.dyn_cast()) { + return SPIRVTypeConverter::getIndexType(type.getContext()); } - if (auto memRefType = t.dyn_cast()) { + if (auto memRefType = type.dyn_cast()) { // TODO(ravishankarm): For now only support default memory space. The memory // space description is not set is stone within MLIR, i.e. it depends on the // context it is being used. To map this to SPIR-V storage classes, we @@ -111,60 +137,46 @@ Type typeConversionImpl(Type t) { if (memRefType.getMemorySpace()) { return Type(); } - auto elementType = typeConversionImpl(memRefType.getElementType()); + + auto elementType = convertStdType(memRefType.getElementType()); if (!elementType) { return Type(); } + auto elementSize = getTypeNumBytes(elementType); if (!elementSize) { return Type(); } // TODO(ravishankarm) : Handle dynamic shapes. if (memRefType.hasStaticShape()) { - // Get the strides and offset - int64_t offset; - SmallVector strides; - if (failed(getStridesAndOffset(memRefType, strides, offset)) || - offset == MemRefType::getDynamicStrideOrOffset() || - llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) { - // TODO(ravishankarm) : Handle dynamic strides and offsets. + auto arraySize = getTypeNumBytes(memRefType); + if (!arraySize) { return Type(); } - // Convert to a multi-dimensional spv.array if size is known. - auto shape = memRefType.getShape(); - assert(shape.size() == strides.size()); - for (int i = shape.size(); i > 0; --i) { - elementType = spirv::ArrayType::get( - elementType, shape[i - 1], strides[i - 1] * elementSize.getValue()); - } - // For the offset, need to wrap the array in a struct. - auto structType = - spirv::StructType::get(elementType, offset * elementSize.getValue()); + auto arrayType = spirv::ArrayType::get( + elementType, arraySize.getValue() / elementSize.getValue(), + elementSize.getValue()); + auto structType = spirv::StructType::get(arrayType, 0); // For now initialize the storage class to StorageBuffer. This will be // updated later based on whats passed in w.r.t to the ABI attributes. return spirv::PointerType::get(structType, spirv::StorageClass::StorageBuffer); } } + return Type(); } -} // namespace -Type SPIRVTypeConverter::convertType(Type t) { return typeConversionImpl(t); } - -Type SPIRVTypeConverter::getIndexType(MLIRContext *context) { - return convertType(IndexType::get(context)); -} +Type SPIRVTypeConverter::convertType(Type type) { return convertStdType(type); } //===----------------------------------------------------------------------===// // Builtin Variables //===----------------------------------------------------------------------===// -namespace { /// Look through all global variables in `moduleOp` and check if there is a /// spv.globalVariable that has the same `builtin` attribute. -spirv::GlobalVariableOp getBuiltinVariable(spirv::ModuleOp &moduleOp, - spirv::BuiltIn builtin) { +static spirv::GlobalVariableOp getBuiltinVariable(spirv::ModuleOp &moduleOp, + spirv::BuiltIn builtin) { for (auto varOp : moduleOp.getBlock().getOps()) { if (auto builtinAttr = varOp.getAttrOfType(convertToSnakeCase( stringifyDecoration(spirv::Decoration::BuiltIn)))) { @@ -178,15 +190,14 @@ spirv::GlobalVariableOp getBuiltinVariable(spirv::ModuleOp &moduleOp, } /// Gets name of global variable for a buitlin. 
-std::string getBuiltinVarName(spirv::BuiltIn builtin) { +static std::string getBuiltinVarName(spirv::BuiltIn builtin) { return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() + "__"; } /// Gets or inserts a global variable for a builtin within a module. -spirv::GlobalVariableOp getOrInsertBuiltinVariable(spirv::ModuleOp &moduleOp, - Location loc, - spirv::BuiltIn builtin, - OpBuilder &builder) { +static spirv::GlobalVariableOp +getOrInsertBuiltinVariable(spirv::ModuleOp &moduleOp, Location loc, + spirv::BuiltIn builtin, OpBuilder &builder) { if (auto varOp = getBuiltinVariable(moduleOp, builtin)) { return varOp; } @@ -217,7 +228,6 @@ spirv::GlobalVariableOp getOrInsertBuiltinVariable(spirv::ModuleOp &moduleOp, builder.restoreInsertionPoint(ip); return newVarOp; } -} // namespace /// Gets the global variable associated with a builtin and add /// it if it doesnt exist. diff --git a/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 6bb052d49d7..89abbe894e6 100644 --- a/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -26,6 +26,7 @@ #include "mlir/Dialect/SPIRV/SPIRVTypes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/StandardTypes.h" @@ -335,6 +336,9 @@ static void printVariableDecorations(Operation *op, OpAsmPrinter &printer, // `indices`. Returns a null Attribute if error happens. static Attribute extractCompositeElement(Attribute composite, ArrayRef indices) { + // Check that given composite is a constant. + if (!composite) + return {}; // Return composite itself if we reach the end of the index chain. 
if (indices.empty()) return composite; @@ -381,7 +385,9 @@ static inline bool isMergeBlock(Block &block) { // TableGen'erated canonicalizers //===----------------------------------------------------------------------===// +namespace { #include "SPIRVCanonicalization.inc" +} //===----------------------------------------------------------------------===// // Common parsers and printers @@ -1168,6 +1174,35 @@ bool spirv::ConstantOp::isBuildableWith(Type type) { return true; } +spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc, + OpBuilder *builder) { + if (auto intType = type.dyn_cast()) { + unsigned width = intType.getWidth(); + Attribute val; + if (width == 1) + return builder->create(loc, type, + builder->getBoolAttr(false)); + return builder->create( + loc, type, builder->getIntegerAttr(type, APInt(width, 0))); + } + + llvm_unreachable("unimplemented types for ConstantOp::getZero()"); +} + +spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc, + OpBuilder *builder) { + if (auto intType = type.dyn_cast()) { + unsigned width = intType.getWidth(); + if (width == 1) + return builder->create(loc, type, + builder->getBoolAttr(true)); + return builder->create( + loc, type, builder->getIntegerAttr(type, APInt(width, 1))); + } + + llvm_unreachable("unimplemented types for ConstantOp::getOne()"); +} + //===----------------------------------------------------------------------===// // spv.ControlBarrier //===----------------------------------------------------------------------===// @@ -1518,6 +1553,73 @@ static LogicalResult verify(spirv::GlobalVariableOp varOp) { return success(); } +//===----------------------------------------------------------------------===// +// spv.GroupNonUniformBallotOp +//===----------------------------------------------------------------------===// + +static ParseResult parseGroupNonUniformBallotOp(OpAsmParser &parser, + OperationState &state) { + spirv::Scope executionScope; + OpAsmParser::OperandType operandInfo; + Type resultType; + IntegerType i1Type = parser.getBuilder().getI1Type(); + if (parseEnumAttribute(executionScope, parser, state, + kExecutionScopeAttrName) || + parser.parseOperand(operandInfo) || parser.parseColonType(resultType) || + parser.resolveOperand(operandInfo, i1Type, state.operands)) + return failure(); + + return parser.addTypeToList(resultType, state.types); +} + +static void print(spirv::GroupNonUniformBallotOp ballotOp, + OpAsmPrinter &printer) { + printer << spirv::GroupNonUniformBallotOp::getOperationName() << " \"" + << stringifyScope(ballotOp.execution_scope()) << "\" "; + printer.printOperand(ballotOp.predicate()); + printer << " : " << ballotOp.getType(); +} + +static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) { + // TODO(antiagainst): check the result integer type's signedness bit is 0. 
+ + spirv::Scope scope = ballotOp.execution_scope(); + if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) + return ballotOp.emitOpError( + "execution scope must be 'Workgroup' or 'Subgroup'"); + + return success(); +} + +//===----------------------------------------------------------------------===// +// spv.IAdd +//===----------------------------------------------------------------------===// + +OpFoldResult spirv::IAddOp::fold(ArrayRef operands) { + assert(operands.size() == 2 && "spv.IAdd expects two operands"); + // lhs + 0 = lhs + if (matchPattern(operand2(), m_Zero())) + return operand1(); + + return nullptr; +} + +//===----------------------------------------------------------------------===// +// spv.IMul +//===----------------------------------------------------------------------===// + +OpFoldResult spirv::IMulOp::fold(ArrayRef operands) { + assert(operands.size() == 2 && "spv.IMul expects two operands"); + // lhs * 0 == 0 + if (matchPattern(operand2(), m_Zero())) + return operand2(); + // lhs * 1 = lhs + if (matchPattern(operand2(), m_One())) + return operand1(); + + return nullptr; +} + //===----------------------------------------------------------------------===// // spv.LoadOp //===----------------------------------------------------------------------===// @@ -2464,6 +2566,28 @@ static LogicalResult verify(spirv::StoreOp storeOp) { return verifyMemoryAccessAttribute(storeOp); } +//===----------------------------------------------------------------------===// +// spv.SubgroupBallotKHROp +//===----------------------------------------------------------------------===// + +static ParseResult parseSubgroupBallotKHROp(OpAsmParser &parser, + OperationState &state) { + OpAsmParser::OperandType operandInfo; + Type resultType; + IntegerType i1Type = parser.getBuilder().getI1Type(); + if (parser.parseOperand(operandInfo) || parser.parseColonType(resultType) || + parser.resolveOperand(operandInfo, i1Type, state.operands)) + return failure(); + + return parser.addTypeToList(resultType, state.types); +} + +static void print(spirv::SubgroupBallotKHROp ballotOp, OpAsmPrinter &printer) { + printer << spirv::SubgroupBallotKHROp::getOperationName() << ' '; + printer.printOperand(ballotOp.predicate()); + printer << " : " << ballotOp.getType(); +} + //===----------------------------------------------------------------------===// // spv.Undef //===----------------------------------------------------------------------===// @@ -2533,11 +2657,10 @@ static ParseResult parseVariableOp(OpAsmParser &parser, OperationState &state) { state.addTypes(ptrType); // Resolve the initializer operand - SmallVector init; if (initInfo) { - if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(), init)) + if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(), + state.operands)) return failure(); - state.addOperands(init); } auto attr = parser.getBuilder().getI32IntegerAttr( diff --git a/third_party/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/third_party/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp index e9d36f66369..d48b31fe491 100644 --- a/third_party/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/third_party/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -194,8 +194,8 @@ FuncOpLowering::matchAndRewrite(FuncOp funcOp, ArrayRef operands, if (isScalarOrVectorType(argType.value())) { auto indexType = typeConverter.convertType(IndexType::get(funcOp.getContext())); - auto zero = rewriter.create( - funcOp.getLoc(), 
indexType, rewriter.getIntegerAttr(indexType, 0)); + auto zero = + spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), &rewriter); auto loadPtr = rewriter.create( funcOp.getLoc(), replacement, zero.constant()); replacement = rewriter.create(funcOp.getLoc(), loadPtr, diff --git a/third_party/mlir/lib/Dialect/StandardOps/Ops.cpp b/third_party/mlir/lib/Dialect/StandardOps/Ops.cpp index 0bf562337a9..9f6510d0f17 100644 --- a/third_party/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/third_party/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -1370,7 +1370,7 @@ OpFoldResult DimOp::fold(ArrayRef operands) { // Fold dim to the size argument of a SubViewOp. auto memref = memrefOrTensor()->getDefiningOp(); if (auto subview = dyn_cast_or_null(memref)) { - auto sizes = subview.getDynamicSizes(); + auto sizes = subview.sizes(); if (!sizes.empty()) return *(sizes.begin() + getIndex()); } @@ -2327,9 +2327,15 @@ static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) { SmallVector sizesInfo; auto indexType = parser.getBuilder().getIndexType(); Type srcType, dstType; + llvm::SMLoc offsetLoc; + if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) || + parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square)) + return failure(); + + if (offsetInfo.size() > 1) + return parser.emitError(offsetLoc) << "expects 0 or 1 offset operand"; + return failure( - parser.parseOperand(srcInfo) || - parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square) || parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(srcType) || @@ -2563,35 +2569,23 @@ static Type inferSubViewResultType(MemRefType memRefType) { memRefType.getMemorySpace()); } -void mlir::SubViewOp::build(Builder *b, OperationState &result, Type resultType, - Value *source, unsigned num_offsets, - unsigned num_sizes, unsigned num_strides, - ArrayRef offsets, ArrayRef sizes, - ArrayRef strides) { - SmallVector operands; - operands.reserve(num_offsets + num_sizes + num_strides); - operands.append(offsets.begin(), offsets.end()); - operands.append(sizes.begin(), sizes.end()); - operands.append(strides.begin(), strides.end()); - build(b, result, resultType, source, b->getI32IntegerAttr(num_offsets), - b->getI32IntegerAttr(num_sizes), b->getI32IntegerAttr(num_strides), - operands); -} - void mlir::SubViewOp::build(Builder *b, OperationState &result, Value *source, ArrayRef offsets, ArrayRef sizes, ArrayRef strides, Type resultType, ArrayRef attrs) { if (!resultType) resultType = inferSubViewResultType(source->getType().cast()); - build(b, result, resultType, source, offsets.size(), sizes.size(), - strides.size(), offsets, sizes, strides); + auto segmentAttr = b->getI32VectorAttr( + {1, static_cast(offsets.size()), static_cast(sizes.size()), + static_cast(strides.size())}); + build(b, result, resultType, source, offsets, sizes, strides, segmentAttr); result.addAttributes(attrs); } void mlir::SubViewOp::build(Builder *b, OperationState &result, Type resultType, Value *source) { - build(b, result, resultType, source, 0, 0, 0, {}, {}, {}); + build(b, result, source, /*offsets=*/{}, /*sizes=*/{}, /*strides=*/{}, + resultType); } static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) { @@ -2607,12 +2601,13 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) { parser.parseOperandList(stridesInfo, OpAsmParser::Delimiter::Square)) { return failure(); } + auto builder = 
parser.getBuilder(); - result.addAttribute("num_offsets", - builder.getI32IntegerAttr(offsetsInfo.size())); - result.addAttribute("num_sizes", builder.getI32IntegerAttr(sizesInfo.size())); - result.addAttribute("num_strides", - builder.getI32IntegerAttr(stridesInfo.size())); + result.addAttribute( + SubViewOp::getOperandSegmentSizeAttr(), + builder.getI32VectorAttr({1, static_cast(offsetsInfo.size()), + static_cast(sizesInfo.size()), + static_cast(stridesInfo.size())})); return failure( parser.parseOptionalAttrDict(result.attributes) || @@ -2627,14 +2622,15 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) { static void print(OpAsmPrinter &p, SubViewOp op) { p << op.getOperationName() << ' ' << *op.getOperand(0) << '['; - p.printOperands(op.getDynamicOffsets()); + p.printOperands(op.offsets()); p << "]["; - p.printOperands(op.getDynamicSizes()); + p.printOperands(op.sizes()); p << "]["; - p.printOperands(op.getDynamicStrides()); + p.printOperands(op.strides()); p << ']'; - SmallVector elidedAttrs = {"num_offsets", "num_sizes", - "num_strides"}; + + SmallVector elidedAttrs = { + SubViewOp::getOperandSegmentSizeAttr()}; p.printOptionalAttrDict(op.getAttrs(), elidedAttrs); p << " : " << op.getOperand(0)->getType() << " to " << op.getType(); } @@ -2689,14 +2685,16 @@ static LogicalResult verify(SubViewOp op) { } // Verify that if the shape of the subview type is static, then sizes are not - // dynamic values, and viceversa. + // dynamic values, and vice versa. if ((subViewType.hasStaticShape() && op.getNumSizes() != 0) || (op.getNumSizes() == 0 && !subViewType.hasStaticShape())) { return op.emitError("invalid to specify dynamic sizes when subview result " "type is statically shaped and viceversa"); } + + // Verify that if dynamic sizes are specified, then the result memref type + // have full dynamic dimensions. if (op.getNumSizes() > 0) { - // Verify that non if the shape values of the result type are static. if (llvm::any_of(subViewType.getShape(), [](int64_t dim) { return dim != ShapedType::kDynamicSize; })) { @@ -2758,12 +2756,47 @@ SmallVector SubViewOp::getRanges() { unsigned rank = getType().getRank(); res.reserve(rank); for (unsigned i = 0; i < rank; ++i) - res.emplace_back(Range{*(getDynamicOffsets().begin() + i), - *(getDynamicSizes().begin() + i), - *(getDynamicStrides().begin() + i)}); + res.emplace_back(Range{*(offsets().begin() + i), *(sizes().begin() + i), + *(strides().begin() + i)}); return res; } +LogicalResult +SubViewOp::getStaticStrides(SmallVectorImpl &staticStrides) { + // If the strides are dynamic return failure. + if (getNumStrides()) + return failure(); + + // When static, the stride operands can be retrieved by taking the strides of + // the result of the subview op, and dividing the strides of the base memref. 
+ int64_t resultOffset, baseOffset; + SmallVector resultStrides, baseStrides; + if (failed( + getStridesAndOffset(getBaseMemRefType(), baseStrides, baseOffset)) || + llvm::is_contained(baseStrides, MemRefType::getDynamicStrideOrOffset()) || + failed(getStridesAndOffset(getType(), resultStrides, resultOffset))) + return failure(); + + assert(static_cast(resultStrides.size()) == getType().getRank() && + baseStrides.size() == resultStrides.size() && + "base and result memrefs must have the same rank"); + assert(!llvm::is_contained(resultStrides, + MemRefType::getDynamicStrideOrOffset()) && + "strides of subview op must be static, when there are no dynamic " + "strides specified"); + staticStrides.resize(getType().getRank()); + for (auto resultStride : enumerate(resultStrides)) { + auto baseStride = baseStrides[resultStride.index()]; + // The result stride is expected to be a multiple of the base stride. Abort + // if that is not the case. + if (resultStride.value() < baseStride || + resultStride.value() % baseStride != 0) + return failure(); + staticStrides[resultStride.index()] = resultStride.value() / baseStride; + } + return success(); +} + static bool hasConstantOffsetSizesAndStrides(MemRefType memrefType) { if (memrefType.getNumDynamicDims() > 0) return false; @@ -2792,13 +2825,13 @@ public: // Follow all or nothing approach for shapes for now. If all the operands // for sizes are constants then fold it into the type of the result memref. if (subViewType.hasStaticShape() || - llvm::any_of(subViewOp.getDynamicSizes(), [](Value *operand) { + llvm::any_of(subViewOp.sizes(), [](Value *operand) { return !matchPattern(operand, m_ConstantIndex()); })) { return matchFailure(); } SmallVector staticShape(subViewOp.getNumSizes()); - for (auto size : enumerate(subViewOp.getDynamicSizes())) { + for (auto size : enumerate(subViewOp.sizes())) { auto defOp = size.value()->getDefiningOp(); assert(defOp); staticShape[size.index()] = cast(defOp).getValue(); @@ -2808,12 +2841,12 @@ public: subViewType.getMemorySpace()); auto newSubViewOp = rewriter.create( subViewOp.getLoc(), subViewOp.source(), - llvm::to_vector<4>(subViewOp.getDynamicOffsets()), ArrayRef(), - llvm::to_vector<4>(subViewOp.getDynamicStrides()), newMemRefType); + llvm::to_vector<4>(subViewOp.offsets()), ArrayRef(), + llvm::to_vector<4>(subViewOp.strides()), newMemRefType); // Insert a memref_cast for compatibility of the uses of the op. 
rewriter.replaceOpWithNewOp( - llvm::to_vector<4>(subViewOp.getDynamicSizes()), subViewOp, - newSubViewOp, subViewOp.getType()); + llvm::to_vector<4>(subViewOp.sizes()), subViewOp, newSubViewOp, + subViewOp.getType()); return matchSuccess(); } }; @@ -2839,14 +2872,14 @@ public: failed(getStridesAndOffset(subViewType, resultStrides, resultOffset)) || llvm::is_contained(baseStrides, MemRefType::getDynamicStrideOrOffset()) || - llvm::any_of(subViewOp.getDynamicStrides(), [](Value *stride) { + llvm::any_of(subViewOp.strides(), [](Value *stride) { return !matchPattern(stride, m_ConstantIndex()); })) { return matchFailure(); } SmallVector staticStrides(subViewOp.getNumStrides()); - for (auto stride : enumerate(subViewOp.getDynamicStrides())) { + for (auto stride : enumerate(subViewOp.strides())) { auto defOp = stride.value()->getDefiningOp(); assert(defOp); assert(baseStrides[stride.index()] > 0); @@ -2858,15 +2891,15 @@ public: MemRefType newMemRefType = MemRefType::get(subViewType.getShape(), subViewType.getElementType(), layoutMap, subViewType.getMemorySpace()); - auto newSubViewOp = rewriter.create( - subViewOp.getLoc(), subViewOp.source(), - llvm::to_vector<4>(subViewOp.getDynamicOffsets()), - llvm::to_vector<4>(subViewOp.getDynamicSizes()), ArrayRef(), - newMemRefType); + auto newSubViewOp = + rewriter.create(subViewOp.getLoc(), subViewOp.source(), + llvm::to_vector<4>(subViewOp.offsets()), + llvm::to_vector<4>(subViewOp.sizes()), + ArrayRef(), newMemRefType); // Insert a memref_cast for compatibility of the uses of the op. rewriter.replaceOpWithNewOp( - llvm::to_vector<4>(subViewOp.getDynamicStrides()), subViewOp, - newSubViewOp, subViewOp.getType()); + llvm::to_vector<4>(subViewOp.strides()), subViewOp, newSubViewOp, + subViewOp.getType()); return matchSuccess(); } }; @@ -2893,14 +2926,14 @@ public: llvm::is_contained(baseStrides, MemRefType::getDynamicStrideOrOffset()) || baseOffset == MemRefType::getDynamicStrideOrOffset() || - llvm::any_of(subViewOp.getDynamicOffsets(), [](Value *stride) { + llvm::any_of(subViewOp.offsets(), [](Value *stride) { return !matchPattern(stride, m_ConstantIndex()); })) { return matchFailure(); } auto staticOffset = baseOffset; - for (auto offset : enumerate(subViewOp.getDynamicOffsets())) { + for (auto offset : enumerate(subViewOp.offsets())) { auto defOp = offset.value()->getDefiningOp(); assert(defOp); assert(baseStrides[offset.index()] > 0); @@ -2915,39 +2948,17 @@ public: layoutMap, subViewType.getMemorySpace()); auto newSubViewOp = rewriter.create( subViewOp.getLoc(), subViewOp.source(), ArrayRef(), - llvm::to_vector<4>(subViewOp.getDynamicSizes()), - llvm::to_vector<4>(subViewOp.getDynamicStrides()), newMemRefType); + llvm::to_vector<4>(subViewOp.sizes()), + llvm::to_vector<4>(subViewOp.strides()), newMemRefType); // Insert a memref_cast for compatibility of the uses of the op. 
rewriter.replaceOpWithNewOp( - llvm::to_vector<4>(subViewOp.getDynamicOffsets()), subViewOp, - newSubViewOp, subViewOp.getType()); + llvm::to_vector<4>(subViewOp.offsets()), subViewOp, newSubViewOp, + subViewOp.getType()); return matchSuccess(); } }; } // end anonymous namespace -SubViewOp::operand_range SubViewOp::getDynamicOffsets() { - auto numOffsets = getNumOffsets(); - assert(getNumOperands() >= numOffsets + 1); - return {operand_begin() + 1, operand_begin() + 1 + numOffsets}; -} - -SubViewOp::operand_range SubViewOp::getDynamicSizes() { - auto numSizes = getNumSizes(); - auto numOffsets = getNumOffsets(); - assert(getNumOperands() >= numSizes + numOffsets + 1); - return {operand_begin() + 1 + numOffsets, - operand_begin() + 1 + numOffsets + numSizes}; -} - -SubViewOp::operand_range SubViewOp::getDynamicStrides() { - auto numSizes = getNumSizes(); - auto numOffsets = getNumOffsets(); - auto numStrides = getNumStrides(); - assert(getNumOperands() >= numSizes + numOffsets + numStrides + 1); - return {operand_begin() + (1 + numOffsets + numSizes), - operand_begin() + (1 + numOffsets + numSizes + numStrides)}; -} void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { diff --git a/third_party/mlir/lib/Dialect/VectorOps/CMakeLists.txt b/third_party/mlir/lib/Dialect/VectorOps/CMakeLists.txt index 590eeed6f41..754e62de14e 100644 --- a/third_party/mlir/lib/Dialect/VectorOps/CMakeLists.txt +++ b/third_party/mlir/lib/Dialect/VectorOps/CMakeLists.txt @@ -1,11 +1,13 @@ add_llvm_library(MLIRVectorOps DialectRegistration.cpp VectorOps.cpp + VectorToVector.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/VectorOps ) add_dependencies(MLIRVectorOps MLIRVectorOpsIncGen) +add_dependencies(MLIRVectorOps MLIRVectorTransformPatternsIncGen) target_link_libraries(MLIRVectorOps MLIRIR) diff --git a/third_party/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/third_party/mlir/lib/Dialect/VectorOps/VectorOps.cpp index fe320b91439..7f3be9d9fa9 100644 --- a/third_party/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/third_party/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -386,10 +386,17 @@ static LogicalResult verify(BroadcastOp op) { if (srcVectorType) { const int64_t srcRank = srcVectorType.getRank(); const int64_t dstRank = dstVectorType.getRank(); - // TODO(ajcbik): implement proper rank testing for broadcast; - // this is just a temporary placeholder check. - if (srcRank > dstRank) { + if (srcRank > dstRank) return op.emitOpError("source rank higher than destination rank"); + // Source has an exact match or singleton value for all trailing dimensions + // (all leading dimensions are simply duplicated). + const int64_t lead = dstRank - srcRank; + for (int64_t i = 0; i < srcRank; i++) { + const int64_t srcDim = srcVectorType.getDimSize(i); + const int64_t dstDim = dstVectorType.getDimSize(lead + i); + if (srcDim != 1 && srcDim != dstDim) + return op.emitOpError("dimension mismatch (") + << srcDim << " vs. 
" << dstDim << ")"; } } return success(); @@ -988,6 +995,37 @@ static LogicalResult verify(TypeCastOp &op) { return success(); } +//===----------------------------------------------------------------------===// +// CreateMaskOp +//===----------------------------------------------------------------------===// + +ParseResult parseCreateMaskOp(OpAsmParser &parser, OperationState &result) { + auto indexType = parser.getBuilder().getIndexType(); + Type resultType; + SmallVector operandInfo; + return failure( + parser.parseOperandList(operandInfo) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(resultType) || + parser.resolveOperands(operandInfo, indexType, result.operands) || + parser.addTypeToList(resultType, result.types)); +} + +static void print(OpAsmPrinter &p, CreateMaskOp &op) { + p << op.getOperationName() << ' '; + p.printOperands(op.operands()); + p << " : " << op.getResult()->getType(); +} + +static LogicalResult verify(CreateMaskOp &op) { + // Verify that an operand was specified for each result vector each dimension. + if (op.getNumOperands() != + op.getResult()->getType().cast().getRank()) + return op.emitOpError( + "must specify an operand for each result vector dimension"); + return success(); +} + //===----------------------------------------------------------------------===// // IndexTupleOp //===----------------------------------------------------------------------===// diff --git a/third_party/mlir/lib/Conversion/VectorConversions/VectorToVector.cpp b/third_party/mlir/lib/Dialect/VectorOps/VectorToVector.cpp similarity index 98% rename from third_party/mlir/lib/Conversion/VectorConversions/VectorToVector.cpp rename to third_party/mlir/lib/Dialect/VectorOps/VectorToVector.cpp index 74687d449af..1e2e651189f 100644 --- a/third_party/mlir/lib/Conversion/VectorConversions/VectorToVector.cpp +++ b/third_party/mlir/lib/Dialect/VectorOps/VectorToVector.cpp @@ -21,10 +21,9 @@ #include -#include "mlir/Analysis/VectorAnalysis.h" -#include "mlir/Conversion/VectorConversions/VectorConversions.h" -#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/Dialect/VectorOps/Utils.h" #include "mlir/Dialect/VectorOps/VectorOps.h" +#include "mlir/Dialect/VectorOps/VectorTransforms.h" #include "mlir/EDSC/Builders.h" #include "mlir/EDSC/Helpers.h" #include "mlir/IR/AffineExpr.h" @@ -198,7 +197,7 @@ static bool hasShape(Value *v, ArrayRef shape) { // // This will be extended in the future to support more advanced use cases than // simple pointwise ops. 
-static Value *unrollSingleResultOpMatchingType(PatternRewriter &builder, +Value * mlir::vector::unrollSingleResultOpMatchingType(PatternRewriter &builder, Operation *op, ArrayRef targetShape) { LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE diff --git a/third_party/mlir/lib/EDSC/Intrinsics.cpp b/third_party/mlir/lib/EDSC/Intrinsics.cpp index f80726866fc..1b19f9aa0bf 100644 --- a/third_party/mlir/lib/EDSC/Intrinsics.cpp +++ b/third_party/mlir/lib/EDSC/Intrinsics.cpp @@ -16,7 +16,6 @@ // ============================================================================= #include "mlir/EDSC/Intrinsics.h" -#include "mlir/Dialect/VectorOps/VectorOps.h" #include "mlir/EDSC/Builders.h" #include "mlir/IR/AffineExpr.h" diff --git a/third_party/mlir/lib/IR/Builders.cpp b/third_party/mlir/lib/IR/Builders.cpp index afdeefd023c..4d6cd3550ca 100644 --- a/third_party/mlir/lib/IR/Builders.cpp +++ b/third_party/mlir/lib/IR/Builders.cpp @@ -100,6 +100,14 @@ IntegerAttr Builder::getI64IntegerAttr(int64_t value) { return IntegerAttr::get(getIntegerType(64), APInt(64, value)); } +DenseIntElementsAttr Builder::getI32VectorAttr(ArrayRef values) { + return DenseElementsAttr::get( + VectorType::get(static_cast(values.size()), + getIntegerType(32)), + values) + .cast(); +} + IntegerAttr Builder::getI32IntegerAttr(int32_t value) { return IntegerAttr::get(getIntegerType(32), APInt(32, value)); } diff --git a/third_party/mlir/lib/IR/Function.cpp b/third_party/mlir/lib/IR/Function.cpp index 4e103508af0..e5e854260f3 100644 --- a/third_party/mlir/lib/IR/Function.cpp +++ b/third_party/mlir/lib/IR/Function.cpp @@ -20,6 +20,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" +#include "mlir/IR/FunctionImplementation.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" #include "mlir/IR/OpImplementation.h" diff --git a/third_party/mlir/lib/IR/FunctionSupport.cpp b/third_party/mlir/lib/IR/FunctionImplementation.cpp similarity index 98% rename from third_party/mlir/lib/IR/FunctionSupport.cpp rename to third_party/mlir/lib/IR/FunctionImplementation.cpp index c6f2673ef2a..66c0d8af6d3 100644 --- a/third_party/mlir/lib/IR/FunctionSupport.cpp +++ b/third_party/mlir/lib/IR/FunctionImplementation.cpp @@ -1,4 +1,4 @@ -//===- FunctionSupport.cpp - Utility types for function-like ops ----------===// +//===- FunctionImplementation.cpp - Utilities for function-like ops -------===// // // Copyright 2019 The MLIR Authors. // @@ -15,9 +15,9 @@ // limitations under the License. // ============================================================================= -#include "mlir/IR/FunctionSupport.h" +#include "mlir/IR/FunctionImplementation.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/FunctionSupport.h" #include "mlir/IR/SymbolTable.h" using namespace mlir; @@ -71,7 +71,8 @@ parseArgumentList(OpAsmParser &parser, bool allowVariadic, }; // Parse the function arguments. - if (parser.parseOptionalRParen()) { + isVariadic = false; + if (failed(parser.parseOptionalRParen())) { do { unsigned numTypedArguments = argTypes.size(); if (parseArgument()) diff --git a/third_party/mlir/lib/IR/Operation.cpp b/third_party/mlir/lib/IR/Operation.cpp index e5ec43c699b..d079033e39b 100644 --- a/third_party/mlir/lib/IR/Operation.cpp +++ b/third_party/mlir/lib/IR/Operation.cpp @@ -125,17 +125,17 @@ Operation *Operation::create(Location location, OperationName name, /// Create a new Operation from operation state. 
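The new Builder::getI32VectorAttr helper above packs a list of 32-bit integers into a dense attribute of vector type; a rough usage sketch (variable names invented):

// Builds dense<[2, 4, 8]> : vector<3xi32>.
Builder builder(ctx); // ctx is an MLIRContext *
DenseIntElementsAttr sizes = builder.getI32VectorAttr({2, 4, 8});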
Operation *Operation::create(const OperationState &state) { - return Operation::create( - state.location, state.name, state.types, state.operands, - NamedAttributeList(state.attributes).getDictionary(), state.successors, - state.regions, state.resizableOperandList); + return Operation::create(state.location, state.name, state.types, + state.operands, NamedAttributeList(state.attributes), + state.successors, state.regions, + state.resizableOperandList); } /// Create a new Operation with the specific fields. Operation *Operation::create(Location location, OperationName name, ArrayRef resultTypes, ArrayRef operands, - const NamedAttributeList &attributes, + NamedAttributeList attributes, ArrayRef successors, ArrayRef> regions, bool resizableOperandList) { @@ -153,7 +153,7 @@ Operation *Operation::create(Location location, OperationName name, Operation *Operation::create(Location location, OperationName name, ArrayRef resultTypes, ArrayRef operands, - const NamedAttributeList &attributes, + NamedAttributeList attributes, ArrayRef successors, unsigned numRegions, bool resizableOperandList) { unsigned numSuccessors = successors.size(); @@ -901,18 +901,21 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) { return success(); } -static LogicalResult verifyBBArguments(Operation::operand_range operands, - Block *destBB, Operation *op) { - unsigned operandCount = std::distance(operands.begin(), operands.end()); +static LogicalResult verifySuccessor(Operation *op, unsigned succNo) { + Operation::operand_range operands = op->getSuccessorOperands(succNo); + unsigned operandCount = op->getNumSuccessorOperands(succNo); + Block *destBB = op->getSuccessor(succNo); if (operandCount != destBB->getNumArguments()) return op->emitError() << "branch has " << operandCount - << " operands, but target block has " + << " operands for successor #" << succNo + << ", but target block has " << destBB->getNumArguments(); auto operandIt = operands.begin(); for (unsigned i = 0, e = operandCount; i != e; ++i, ++operandIt) { if ((*operandIt)->getType() != destBB->getArgument(i)->getType()) - return op->emitError() << "type mismatch in bb argument #" << i; + return op->emitError() << "type mismatch for bb argument #" << i + << " of successor #" << succNo; } return success(); @@ -926,7 +929,7 @@ static LogicalResult verifyTerminatorSuccessors(Operation *op) { auto *succ = op->getSuccessor(i); if (succ->getParent() != parent) return op->emitError("reference to block defined in another region"); - if (failed(verifyBBArguments(op->getSuccessorOperands(i), succ, op))) + if (failed(verifySuccessor(op, i))) return failure(); } return success(); diff --git a/third_party/mlir/lib/Pass/Pass.cpp b/third_party/mlir/lib/Pass/Pass.cpp index a195bb0c0c8..6d8e230eeec 100644 --- a/third_party/mlir/lib/Pass/Pass.cpp +++ b/third_party/mlir/lib/Pass/Pass.cpp @@ -533,7 +533,7 @@ static LogicalResult runWithCrashRecovery(OpPassManager &pm, outputFile->keep(); return reproducerModule->emitError() - << "A crash has been detected while processing the MLIR module, a " + << "A failure has been detected while processing the MLIR module, a " "reproducer has been generated in '" << crashReproducerFileName << "'"; } diff --git a/third_party/mlir/lib/TableGen/Attribute.cpp b/third_party/mlir/lib/TableGen/Attribute.cpp index c2b673a7c93..ec946a855fc 100644 --- a/third_party/mlir/lib/TableGen/Attribute.cpp +++ b/third_party/mlir/lib/TableGen/Attribute.cpp @@ -107,12 +107,12 @@ tblgen::Attribute tblgen::Attribute::getBaseAttr() const 
{ return *this; } -bool tblgen::Attribute::hasDefaultValueInitializer() const { +bool tblgen::Attribute::hasDefaultValue() const { const auto *init = def->getValueInit("defaultValue"); return !getValueAsString(init).empty(); } -StringRef tblgen::Attribute::getDefaultValueInitializer() const { +StringRef tblgen::Attribute::getDefaultValue() const { const auto *init = def->getValueInit("defaultValue"); return getValueAsString(init); } diff --git a/third_party/mlir/lib/TableGen/Pattern.cpp b/third_party/mlir/lib/TableGen/Pattern.cpp index d3c1dddd21e..098dba3ae6e 100644 --- a/third_party/mlir/lib/TableGen/Pattern.cpp +++ b/third_party/mlir/lib/TableGen/Pattern.cpp @@ -564,7 +564,8 @@ void tblgen::Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, // We can only bind symbols to op arguments in source pattern. Those // symbols are referenced in result patterns. auto treeArgName = tree.getArgName(i); - if (!treeArgName.empty()) { + // `$_` is a special symbol meaning ignore the current argument. + if (!treeArgName.empty() && treeArgName != "_") { LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: " << treeArgName << '\n'); if (!infoMap.bindOpArgument(treeArgName, op, i)) { diff --git a/third_party/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/third_party/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp index fd4e4134d8b..6cf975bcce2 100644 --- a/third_party/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/third_party/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -215,6 +215,37 @@ Attribute Importer::getConstantAsAttr(llvm::Constant *value) { return Attribute(); } +/// Converts LLVM global variable linkage type into the LLVM dialect predicate. +static LLVM::Linkage +processLinkage(llvm::GlobalVariable::LinkageTypes linkage) { + switch (linkage) { + case llvm::GlobalValue::PrivateLinkage: + return LLVM::Linkage::Private; + case llvm::GlobalValue::InternalLinkage: + return LLVM::Linkage::Internal; + case llvm::GlobalValue::AvailableExternallyLinkage: + return LLVM::Linkage::AvailableExternally; + case llvm::GlobalValue::LinkOnceAnyLinkage: + return LLVM::Linkage::Linkonce; + case llvm::GlobalValue::WeakAnyLinkage: + return LLVM::Linkage::Weak; + case llvm::GlobalValue::CommonLinkage: + return LLVM::Linkage::Common; + case llvm::GlobalValue::AppendingLinkage: + return LLVM::Linkage::Appending; + case llvm::GlobalValue::ExternalWeakLinkage: + return LLVM::Linkage::ExternWeak; + case llvm::GlobalValue::LinkOnceODRLinkage: + return LLVM::Linkage::LinkonceODR; + case llvm::GlobalValue::WeakODRLinkage: + return LLVM::Linkage::WeakODR; + case llvm::GlobalValue::ExternalLinkage: + return LLVM::Linkage::External; + } + + llvm_unreachable("unhandled linkage type"); +} + GlobalOp Importer::processGlobal(llvm::GlobalVariable *GV) { auto it = globals.find(GV); if (it != globals.end()) @@ -224,9 +255,10 @@ GlobalOp Importer::processGlobal(llvm::GlobalVariable *GV) { Attribute valueAttr; if (GV->hasInitializer()) valueAttr = getConstantAsAttr(GV->getInitializer()); - GlobalOp op = b.create(UnknownLoc::get(context), - processType(GV->getValueType()), - GV->isConstant(), GV->getName(), valueAttr); + GlobalOp op = b.create( + UnknownLoc::get(context), processType(GV->getValueType()), + GV->isConstant(), processLinkage(GV->getLinkage()), GV->getName(), + valueAttr); if (GV->hasInitializer() && !valueAttr) { Region &r = op.getInitializerRegion(); currentEntryBlock = b.createBlock(&r); diff --git a/third_party/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp 
b/third_party/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 7f3ce5a738f..f985fed3991 100644 --- a/third_party/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/third_party/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -279,6 +279,35 @@ LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments) { return success(); } +// Convert the LLVM dialect linkage type to LLVM IR linkage type. +llvm::GlobalVariable::LinkageTypes convertLinkageType(LLVM::Linkage linkage) { + switch (linkage) { + case LLVM::Linkage::Private: + return llvm::GlobalValue::PrivateLinkage; + case LLVM::Linkage::Internal: + return llvm::GlobalValue::InternalLinkage; + case LLVM::Linkage::AvailableExternally: + return llvm::GlobalValue::AvailableExternallyLinkage; + case LLVM::Linkage::Linkonce: + return llvm::GlobalValue::LinkOnceAnyLinkage; + case LLVM::Linkage::Weak: + return llvm::GlobalValue::WeakAnyLinkage; + case LLVM::Linkage::Common: + return llvm::GlobalValue::CommonLinkage; + case LLVM::Linkage::Appending: + return llvm::GlobalValue::AppendingLinkage; + case LLVM::Linkage::ExternWeak: + return llvm::GlobalValue::ExternalWeakLinkage; + case LLVM::Linkage::LinkonceODR: + return llvm::GlobalValue::LinkOnceODRLinkage; + case LLVM::Linkage::WeakODR: + return llvm::GlobalValue::WeakODRLinkage; + case LLVM::Linkage::External: + return llvm::GlobalValue::ExternalLinkage; + } + llvm_unreachable("unknown linkage type"); +} + // Create named global variables that correspond to llvm.mlir.global // definitions. void ModuleTranslation::convertGlobals() { @@ -308,11 +337,15 @@ void ModuleTranslation::convertGlobals() { cst = cast(valueMapping.lookup(ret.getOperand(0))); } + auto linkage = convertLinkageType(op.linkage()); + bool anyExternalLinkage = + (linkage == llvm::GlobalVariable::ExternalLinkage || + linkage == llvm::GlobalVariable::ExternalWeakLinkage); auto addrSpace = op.addr_space().getLimitedValue(); auto *var = new llvm::GlobalVariable( - *llvmModule, type, op.constant(), llvm::GlobalValue::InternalLinkage, - cst, op.sym_name(), /*InsertBefore=*/nullptr, - llvm::GlobalValue::NotThreadLocal, addrSpace); + *llvmModule, type, op.constant(), linkage, + anyExternalLinkage ? nullptr : cst, op.sym_name(), + /*InsertBefore=*/nullptr, llvm::GlobalValue::NotThreadLocal, addrSpace); globalsMapping.try_emplace(op, var); } diff --git a/third_party/mlir/lib/Transforms/LoopFusion.cpp b/third_party/mlir/lib/Transforms/LoopFusion.cpp index 7985ca1c5ef..cda35297abc 100644 --- a/third_party/mlir/lib/Transforms/LoopFusion.cpp +++ b/third_party/mlir/lib/Transforms/LoopFusion.cpp @@ -1005,17 +1005,19 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, // Checks if node 'srcId' can be safely fused into node 'dstId'. Node 'srcId' // may write to multiple memrefs but it is required that only one of them, -// 'srcLiveOutStoreOp', have an output edge. +// 'srcLiveOutStoreOp', has output edges. // Returns true if 'dstNode's read/write region to 'memref' is a super set of -// 'srcNode's write region to 'memref'. +// 'srcNode's write region to 'memref' and 'srcId' has only one output edge. // TODO(andydavis) Generalize this to handle more live in/out cases. 
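The two linkage switches introduced above (processLinkage in ConvertFromLLVMIR.cpp and convertLinkageType in ModuleTranslation.cpp) are intended as inverse mappings, and on the export side the initializer is dropped for external linkage so such globals are emitted as declarations. A sanity sketch, assuming both helpers were visible from one translation unit (in the patch they are file-local):

// Round trip: LLVM IR linkage -> LLVM dialect linkage -> LLVM IR linkage.
auto original = llvm::GlobalValue::LinkOnceODRLinkage;
assert(convertLinkageType(processLinkage(original)) == original);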
static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, AffineStoreOp srcLiveOutStoreOp, MemRefDependenceGraph *mdg) { assert(srcLiveOutStoreOp && "Expected a valid store op"); - assert(mdg->getOutEdgeCount(srcId) == 1 && "Expected only one output edge"); auto *dstNode = mdg->getNode(dstId); Value *memref = srcLiveOutStoreOp.getMemRef(); + // Return false if 'srcNode' has more than one output edge on 'memref'. + if (mdg->getOutEdgeCount(srcId, memref) > 1) + return false; // Compute MemRefRegion 'srcWriteRegion' for 'srcStoreOp' on 'memref'. MemRefRegion srcWriteRegion(srcLiveOutStoreOp.getLoc()); diff --git a/third_party/mlir/lib/Transforms/MaterializeVectors.cpp b/third_party/mlir/lib/Transforms/MaterializeVectors.cpp index 874eac6e4e6..33f5927d88e 100644 --- a/third_party/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/third_party/mlir/lib/Transforms/MaterializeVectors.cpp @@ -26,9 +26,9 @@ #include "mlir/Analysis/NestedMatcher.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Analysis/Utils.h" -#include "mlir/Analysis/VectorAnalysis.h" #include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/Dialect/VectorOps/Utils.h" #include "mlir/Dialect/VectorOps/VectorOps.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" diff --git a/third_party/mlir/lib/Transforms/Vectorize.cpp b/third_party/mlir/lib/Transforms/Vectorize.cpp index 2a0ce092f81..c1e0a9c0e13 100644 --- a/third_party/mlir/lib/Transforms/Vectorize.cpp +++ b/third_party/mlir/lib/Transforms/Vectorize.cpp @@ -24,9 +24,9 @@ #include "mlir/Analysis/NestedMatcher.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Analysis/Utils.h" -#include "mlir/Analysis/VectorAnalysis.h" #include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/Dialect/VectorOps/Utils.h" #include "mlir/Dialect/VectorOps/VectorOps.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Builders.h" @@ -589,6 +589,13 @@ makePatterns(const llvm::DenseSet ¶llelLoops, int vectorRank, } } +static NestedPattern &vectorTransferPattern() { + static auto pattern = matcher::Op([](Operation &op) { + return isa(op) || isa(op); + }); + return pattern; +} + namespace { /// Base state for the vectorize pass. @@ -893,7 +900,8 @@ isVectorizableLoopPtrFactory(const llvm::DenseSet ¶llelLoops, if (parallelIt == parallelLoops.end()) return false; int memRefDim = -1; - auto vectorizableBody = isVectorizableLoopBody(loop, &memRefDim); + auto vectorizableBody = + isVectorizableLoopBody(loop, &memRefDim, vectorTransferPattern()); if (!vectorizableBody) return false; return memRefDim == -1 || fastestVaryingMemRefDimension == -1 || @@ -1172,7 +1180,7 @@ static LogicalResult vectorizeRootMatch(NestedMatch m, // vectorizable. If a pattern is not vectorizable anymore, we just skip it. // TODO(ntv): implement a non-greedy profitability analysis that keeps only // non-intersecting patterns. 
- if (!isVectorizableLoopBody(loop)) { + if (!isVectorizableLoopBody(loop, vectorTransferPattern())) { LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ loop is not vectorizable"); return failure(); } diff --git a/third_party/mlir/test/BUILD b/third_party/mlir/test/BUILD index 85680008d6d..25f7b8399eb 100644 --- a/third_party/mlir/test/BUILD +++ b/third_party/mlir/test/BUILD @@ -133,9 +133,9 @@ cc_library( "lib/Transforms/TestLoopFusion.cpp", "lib/Transforms/TestLoopMapping.cpp", "lib/Transforms/TestLoopParametricTiling.cpp", - "lib/Transforms/TestLowerVectorTransfers.cpp", "lib/Transforms/TestMemRefStrideCalculation.cpp", "lib/Transforms/TestOpaqueLoc.cpp", + "lib/Transforms/TestVectorToLoopsConversion.cpp", "lib/Transforms/TestVectorToVectorConversion.cpp", "lib/Transforms/TestVectorizationUtils.cpp", ], @@ -155,8 +155,9 @@ cc_library( "@local_config_mlir//:Support", "@local_config_mlir//:TransformUtils", "@local_config_mlir//:Transforms", - "@local_config_mlir//:VectorConversions", "@local_config_mlir//:VectorOps", + "@local_config_mlir//:VectorToLLVM", + "@local_config_mlir//:VectorToLoops", ], alwayslink = 1, ) diff --git a/third_party/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td b/third_party/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td index 839671c866a..97e0cb21704 100644 --- a/third_party/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td +++ b/third_party/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td @@ -73,4 +73,11 @@ def : Pattern<(DotOp:$op $a, $b, $c), [(TileLinalgOp<[8], "REG"> $op)], [(Constraint> $op)]>; +//===----------------------------------------------------------------------===// +// Linalg to loops patterns. +//===----------------------------------------------------------------------===// +def : Pattern<(DotOp:$op $a, $b, $c), + [(LinalgOpToLoops<"DotOp"> $op)], + [(Constraint> $op)]>; + #endif // TEST_LINALG_TRANSFORMS_PATTERNS diff --git a/third_party/mlir/test/lib/TestDialect/TestOps.td b/third_party/mlir/test/lib/TestDialect/TestOps.td index 6bb0cbc7f0c..6952eaa7717 100644 --- a/third_party/mlir/test/lib/TestDialect/TestOps.td +++ b/third_party/mlir/test/lib/TestDialect/TestOps.td @@ -479,6 +479,18 @@ def OpJ : TEST_Op<"op_j">, Arguments<(ins)>, Results<(outs I32)>; def OpK : TEST_Op<"op_k">, Arguments<(ins)>, Results<(outs I32)>; def : Pat<(OpJ), (OpK)>; +// Test `$_` for ignoring op argument match. 
+def TestIgnoreArgMatchSrcOp : TEST_Op<"ignore_arg_match_src"> {
+  let arguments = (ins
+    AnyType:$a, AnyType:$b, AnyType:$c,
+    AnyAttr:$d, AnyAttr:$e, AnyAttr:$f);
+}
+def TestIgnoreArgMatchDstOp : TEST_Op<"ignore_arg_match_dst"> {
+  let arguments = (ins AnyType:$b, AnyAttr:$f);
+}
+def : Pat<(TestIgnoreArgMatchSrcOp $_, $b, I32, I64Attr:$_, $_, $f),
+          (TestIgnoreArgMatchDstOp $b, $f)>;
+
 def OpInterleavedOperandAttribute1 : TEST_Op<"interleaved_operand_attr1"> {
   let arguments = (ins
     I32:$input1,
diff --git a/third_party/mlir/test/lib/Transforms/CMakeLists.txt b/third_party/mlir/test/lib/Transforms/CMakeLists.txt
index 2d482e5f1a5..8bc9c736187 100644
--- a/third_party/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/third_party/mlir/test/lib/Transforms/CMakeLists.txt
@@ -6,9 +6,9 @@ add_llvm_library(MLIRTestTransforms
   TestLinalgTransforms.cpp
   TestLoopMapping.cpp
   TestLoopParametricTiling.cpp
-  TestLowerVectorTransfers.cpp
   TestOpaqueLoc.cpp
   TestMemRefStrideCalculation.cpp
+  TestVectorToLoopsConversion.cpp
   TestVectorToVectorConversion.cpp
   TestVectorizationUtils.cpp
diff --git a/third_party/mlir/test/lib/Transforms/TestLowerVectorTransfers.cpp b/third_party/mlir/test/lib/Transforms/TestVectorToLoopsConversion.cpp
similarity index 71%
rename from third_party/mlir/test/lib/Transforms/TestLowerVectorTransfers.cpp
rename to third_party/mlir/test/lib/Transforms/TestVectorToLoopsConversion.cpp
index 8341777f6a4..e5f5f749bd0 100644
--- a/third_party/mlir/test/lib/Transforms/TestLowerVectorTransfers.cpp
+++ b/third_party/mlir/test/lib/Transforms/TestVectorToLoopsConversion.cpp
@@ -1,4 +1,4 @@
-//===- TestLowerVectorTransfers.cpp - Test VectorTransfers lowering -------===//
+//===- TestVectorToLoopsConversion.cpp - Test VectorTransfers lowering ----===//
 //
 // Copyright 2019 The MLIR Authors.
// @@ -17,7 +17,7 @@ #include -#include "mlir/Conversion/VectorConversions/VectorConversions.h" +#include "mlir/Conversion/VectorToLoops/ConvertVectorToLoops.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/Passes.h" @@ -26,8 +26,8 @@ using namespace mlir; namespace { -struct TestLowerVectorTransfersPass - : public FunctionPass { +struct TestVectorToLoopsPass + : public FunctionPass { void runOnFunction() override { OwningRewritePatternList patterns; auto *context = &getContext(); @@ -38,7 +38,6 @@ struct TestLowerVectorTransfersPass } // end anonymous namespace -static PassRegistration - pass("test-affine-lower-vector-transfers", - "Materializes vector transfer ops to a " - "proper abstraction for the hardware"); +static PassRegistration + pass("test-convert-vector-to-loops", + "Converts vector transfer ops to loops over scalars and vector casts"); diff --git a/third_party/mlir/test/lib/Transforms/TestVectorToVectorConversion.cpp b/third_party/mlir/test/lib/Transforms/TestVectorToVectorConversion.cpp index 2550796ade2..9f9b8a554fe 100644 --- a/third_party/mlir/test/lib/Transforms/TestVectorToVectorConversion.cpp +++ b/third_party/mlir/test/lib/Transforms/TestVectorToVectorConversion.cpp @@ -18,7 +18,7 @@ #include -#include "mlir/Conversion/VectorConversions/VectorConversions.h" +#include "mlir/Dialect/VectorOps/VectorTransforms.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/Passes.h" diff --git a/third_party/mlir/test/lib/Transforms/TestVectorizationUtils.cpp b/third_party/mlir/test/lib/Transforms/TestVectorizationUtils.cpp index f0f1f6b0b23..7efc74f2304 100644 --- a/third_party/mlir/test/lib/Transforms/TestVectorizationUtils.cpp +++ b/third_party/mlir/test/lib/Transforms/TestVectorizationUtils.cpp @@ -22,8 +22,8 @@ #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/NestedMatcher.h" #include "mlir/Analysis/SliceAnalysis.h" -#include "mlir/Analysis/VectorAnalysis.h" #include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/Dialect/VectorOps/Utils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/StandardTypes.h" diff --git a/third_party/mlir/tools/mlir-opt/CMakeLists.txt b/third_party/mlir/tools/mlir-opt/CMakeLists.txt index e38b43d59b8..b30d7e39ce8 100644 --- a/third_party/mlir/tools/mlir-opt/CMakeLists.txt +++ b/third_party/mlir/tools/mlir-opt/CMakeLists.txt @@ -21,7 +21,7 @@ set(LIBS MLIRAffineToStandard MLIRLoopsToGPU MLIRLinalgToLLVM - + MLIRLoopToStandard MLIREDSC MLIRFxpMathOps @@ -51,7 +51,8 @@ set(LIBS MLIRTestTransforms MLIRSupport MLIRVectorOps - MLIRVectorConversions + MLIRVectorToLLVM + MLIRVectorToLoops ) if(MLIR_CUDA_CONVERSIONS_ENABLED) list(APPEND LIBS diff --git a/third_party/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/third_party/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 864f7734f8a..16894ad4cb3 100644 --- a/third_party/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/third_party/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -32,6 +32,8 @@ #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" +#define DEBUG_TYPE "mlir-tblgen-opdefgen" + using namespace llvm; using namespace mlir; using namespace mlir::tblgen; @@ -113,6 +115,14 @@ static std::string getArgumentName(const Operator &op, int index) { return formatv("{0}_{1}", generatedArgName, index); } +// Returns true if we can use unwrapped value for the given `attr` in builders. 
+static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) { + return attr.getReturnType() != attr.getStorageType() && + // We need to wrap the raw value into an attribute in the builder impl + // so we need to make sure that the attribute specifies how to do that. + !attr.getConstBuilderTemplate().empty(); +} + namespace { // Simple RAII helper for defining ifdef-undef-endif scopes. class IfDefScope { @@ -506,46 +516,66 @@ private: void genBuilder(); // Generates the build() method that takes each result-type/operand/attribute - // as a stand-alone parameter. This build() method also requires specifying - // result types for all results. - void genSeparateParamBuilder(); + // as a stand-alone parameter. Attributes will take wrapped mlir::Attribute + // values. The generated build() method also requires specifying result types + // for all results. + void genSeparateParamWrappedAttrBuilder(); + + // Generates the build() method that takes each result-type/operand/attribute + // as a stand-alone parameter. Attributes will take raw values without + // mlir::Attribute wrapper. The generated build() method also requires + // specifying result types for all results. + void genSeparateParamUnwrappedAttrBuilder(); // Generates the build() method that takes a single parameter for all the // result types and a separate parameter for each operand/attribute. void genCollectiveTypeParamBuilder(); // Generates the build() method that takes each operand/attribute as a - // stand-alone parameter. This build() method uses first operand's type - // as all results' types. + // stand-alone parameter. The generated build() method uses first operand's + // type as all results' types. void genUseOperandAsResultTypeSeparateParamBuilder(); // Generates the build() method that takes all operands/attributes - // collectively as one parameter. This build() method uses first operand's - // type as all results' types. + // collectively as one parameter. The generated build() method uses first + // operand's type as all results' types. void genUseOperandAsResultTypeCollectiveParamBuilder(); // Generates the build() method that takes each operand/attribute as a - // stand-alone parameter. This build() method uses first attribute's type - // as all result's types. + // stand-alone parameter. The generated build() method uses first attribute's + // type as all result's types. void genUseAttrAsResultTypeBuilder(); // Generates the build() method that takes all result types collectively as // one parameter. Similarly for operands and attributes. void genCollectiveParamBuilder(); - enum class TypeParamKind { None, Separate, Collective }; + // The kind of parameter to generate for result types in builders. + enum class TypeParamKind { + None, // No result type in parameter list. + Separate, // A separate parameter for each result type. + Collective, // An ArrayRef for all result types. + }; + + // The kind of parameter to generate for attributes in builders. + enum class AttrParamKind { + WrappedAttr, // A wrapped MLIR Attribute instance. + UnwrappedValue, // A raw value without MLIR Attribute wrapper. + }; // Builds the parameter list for build() method of this op. This method writes - // to `paramList` the comma-separated parameter list. If `includeResultTypes` - // is true then `paramList` will also contain the parameters for all results - // and `resultTypeNames` will be populated with the parameter name for each - // result type. 
+ // to `paramList` the comma-separated parameter list and updates + // `resultTypeNames` with the names for parameters for specifying result + // types. The given `typeParamKind` and `attrParamKind` controls how result + // types and attributes are placed in the parameter list. void buildParamList(std::string ¶mList, SmallVectorImpl &resultTypeNames, - TypeParamKind kind); + TypeParamKind typeParamKind, + AttrParamKind attrParamKind = AttrParamKind::WrappedAttr); // Adds op arguments and regions into operation state for build() methods. - void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body); + void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body, + bool isRawValueAttr = false); // Generates canonicalizer declaration for the operation. void genCanonicalizerDecls(); @@ -650,18 +680,18 @@ void OpEmitter::genAttrGetters() { // Return the queried attribute with the correct return type. auto attrVal = - (attr.hasDefaultValueInitializer() || attr.isOptional()) + (attr.hasDefaultValue() || attr.isOptional()) ? formatv("this->getAttr(\"{0}\").dyn_cast_or_null<{1}>()", name, attr.getStorageType()) : formatv("this->getAttr(\"{0}\").cast<{1}>()", name, attr.getStorageType()); body << " auto attr = " << attrVal << ";\n"; - if (attr.hasDefaultValueInitializer()) { + if (attr.hasDefaultValue()) { // Returns the default value if not set. // TODO: this is inefficient, we are recreating the attribute for every // call. This should be set instead. - std::string defaultValue = tgfmt(attr.getConstBuilderTemplate(), &fctx, - attr.getDefaultValueInitializer()); + std::string defaultValue = + tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()); body << " if (!attr)\n return " << tgfmt(attr.getConvertFromStorageCall(), &fctx.withSelf(defaultValue)) @@ -847,7 +877,7 @@ void OpEmitter::genNamedRegionGetters() { } } -void OpEmitter::genSeparateParamBuilder() { +void OpEmitter::genSeparateParamWrappedAttrBuilder() { std::string paramList; llvm::SmallVector resultNames; buildParamList(paramList, resultNames, TypeParamKind::Separate); @@ -862,6 +892,42 @@ void OpEmitter::genSeparateParamBuilder() { } } +void OpEmitter::genSeparateParamUnwrappedAttrBuilder() { + // If this op does not have native attributes at all, return directly to avoid + // redefining builders. + if (op.getNumNativeAttributes() == 0) + return; + + bool canGenerate = false; + // We are generating builders that take raw values for attributes. We need to + // make sure the native attributes have a meaningful "unwrapped" value type + // different from the wrapped mlir::Attribute type to avoid redefining + // builders. This checks for the op has at least one such native attribute. + for (int i = 0, e = op.getNumNativeAttributes(); i < e; ++i) { + NamedAttribute &namedAttr = op.getAttribute(i); + if (canUseUnwrappedRawValue(namedAttr.attr)) { + canGenerate = true; + break; + } + } + if (!canGenerate) + return; + + std::string paramList; + llvm::SmallVector resultNames; + buildParamList(paramList, resultNames, TypeParamKind::Separate, + AttrParamKind::UnwrappedValue); + + auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static); + genCodeForAddingArgAndRegionForBuilder(m.body(), /*isRawValueAttr=*/true); + + // Push all result types to the operation state. 
+ for (int i = 0, e = op.getNumResults(); i < e; ++i) { + m.body() << " " << builderOpState << ".addTypes(" << resultNames[i] + << ");\n"; + } +} + void OpEmitter::genCollectiveTypeParamBuilder() { auto numResults = op.getNumResults(); @@ -1006,7 +1072,8 @@ void OpEmitter::genBuilder() { // We generate three builders here: // 1. one having a stand-alone parameter for each result type / operand / // attribute, and - genSeparateParamBuilder(); + genSeparateParamWrappedAttrBuilder(); + genSeparateParamUnwrappedAttrBuilder(); // 2. one having a stand-alone parameter for each operand / attribute and // an aggregated parameter for all result types, and genCollectiveTypeParamBuilder(); @@ -1069,15 +1136,16 @@ void OpEmitter::genCollectiveParamBuilder() { void OpEmitter::buildParamList(std::string ¶mList, SmallVectorImpl &resultTypeNames, - TypeParamKind kind) { + TypeParamKind typeParamKind, + AttrParamKind attrParamKind) { resultTypeNames.clear(); auto numResults = op.getNumResults(); resultTypeNames.reserve(numResults); - paramList = "Builder *, OperationState &"; + paramList = "Builder *tblgen_builder, OperationState &"; paramList.append(builderOpState); - switch (kind) { + switch (typeParamKind) { case TypeParamKind::None: break; case TypeParamKind::Separate: { @@ -1100,10 +1168,36 @@ void OpEmitter::buildParamList(std::string ¶mList, } break; } + // Add parameters for all arguments (operands and attributes). + int numOperands = 0; int numAttrs = 0; - // Add parameters for all arguments (operands and attributes). + int defaultValuedAttrStartIndex = op.getNumArgs(); + if (attrParamKind == AttrParamKind::UnwrappedValue) { + // Calculate the start index from which we can attach default values in the + // builder declaration. + for (int i = op.getNumArgs() - 1; i >= 0; --i) { + auto *namedAttr = op.getArg(i).dyn_cast(); + if (!namedAttr || !namedAttr->attr.hasDefaultValue()) + break; + + if (!canUseUnwrappedRawValue(namedAttr->attr)) + break; + + // Creating an APInt requires us to provide bitwidth, value, and + // signedness, which is complicated compared to others. Similarly + // for APFloat. + // TODO(b/144412160) Adjust the 'returnType' field of such attributes + // to support them. + StringRef retType = namedAttr->attr.getReturnType(); + if (retType == "APInt" || retType == "APFloat") + break; + + defaultValuedAttrStartIndex = i; + } + } + for (int i = 0, e = op.getNumArgs(); i < e; ++i) { auto argument = op.getArg(i); if (argument.is()) { @@ -1113,24 +1207,46 @@ void OpEmitter::buildParamList(std::string ¶mList, paramList.append(getArgumentName(op, numOperands)); ++numOperands; } else { - // TODO(antiagainst): Support default initializer for attributes const auto &namedAttr = op.getAttribute(numAttrs); const auto &attr = namedAttr.attr; paramList.append(", "); + if (attr.isOptional()) paramList.append("/*optional*/"); - paramList.append(attr.getStorageType()); + + switch (attrParamKind) { + case AttrParamKind::WrappedAttr: + paramList.append(attr.getStorageType()); + break; + case AttrParamKind::UnwrappedValue: + if (canUseUnwrappedRawValue(attr)) { + paramList.append(attr.getReturnType()); + } else { + paramList.append(attr.getStorageType()); + } + break; + } paramList.append(" "); paramList.append(namedAttr.name); + + // Attach default value if requested and possible. 
+ if (attrParamKind == AttrParamKind::UnwrappedValue && + i >= defaultValuedAttrStartIndex) { + bool isString = attr.getReturnType() == "StringRef"; + paramList.append(" = "); + if (isString) + paramList.append("\""); + paramList.append(attr.getDefaultValue()); + if (isString) + paramList.append("\""); + } ++numAttrs; } } - - if (numOperands + numAttrs != op.getNumArgs()) - PrintFatalError("op arguments must be either operands or attributes"); } -void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body) { +void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body, + bool isRawValueAttr) { // Push all operands to the result for (int i = 0, e = op.getNumOperands(); i < e; ++i) { body << " " << builderOpState << ".addOperands(" << getArgumentName(op, i) @@ -1139,13 +1255,25 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body) { // Push all attributes to the result for (const auto &namedAttr : op.getAttributes()) { - if (!namedAttr.attr.isDerivedAttr()) { - bool emitNotNullCheck = namedAttr.attr.isOptional(); + auto &attr = namedAttr.attr; + if (!attr.isDerivedAttr()) { + bool emitNotNullCheck = attr.isOptional(); if (emitNotNullCheck) { body << formatv(" if ({0}) ", namedAttr.name) << "{\n"; } - body << formatv(" {0}.addAttribute(\"{1}\", {1});\n", builderOpState, - namedAttr.name); + if (isRawValueAttr && canUseUnwrappedRawValue(attr)) { + // If this is a raw value, then we need to wrap it in an Attribute + // instance. + FmtContext fctx; + fctx.withBuilder("(*tblgen_builder)"); + std::string value = + tgfmt(attr.getConstBuilderTemplate(), &fctx, namedAttr.name); + body << formatv(" {0}.addAttribute(\"{1}\", {2});\n", builderOpState, + namedAttr.name, value); + } else { + body << formatv(" {0}.addAttribute(\"{1}\", {1});\n", builderOpState, + namedAttr.name); + } if (emitNotNullCheck) { body << " }\n"; } @@ -1282,8 +1410,7 @@ void OpEmitter::genVerifier() { body << formatv(" auto {0} = this->getAttr(\"{1}\");\n", varName, attrName); - bool allowMissingAttr = - attr.hasDefaultValueInitializer() || attr.isOptional(); + bool allowMissingAttr = attr.hasDefaultValue() || attr.isOptional(); if (allowMissingAttr) { // If the attribute has a default value, then only verify the predicate if // set. This does effectively assume that the default value is valid. diff --git a/third_party/mlir/tools/mlir-tblgen/RewriterGen.cpp b/third_party/mlir/tools/mlir-tblgen/RewriterGen.cpp index d2776e05805..d321b204f4e 100644 --- a/third_party/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/third_party/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -315,7 +315,8 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth, // Capture the value auto name = tree.getArgName(argIndex); - if (!name.empty()) { + // `$_` is a special symbol to ignore op argument matching. + if (!name.empty() && name != "_") { // We need to subtract the number of attributes before this operand to get // the index in the operand list. 
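To make the unwrapped-attribute builders generated by the OpDefinitionsGen changes above concrete: for a hypothetical op with one result and a default-valued string attribute, the generator would now emit roughly the two signatures below (a sketch only; the exact parameter names and default come from the op's ODS definition):

// Existing builder taking the wrapped attribute:
static void build(Builder *tblgen_builder, OperationState &state,
                  Type resultType, StringAttr layout);
// New builder taking the raw value; its body wraps the value through the
// attribute's constBuilderCall, roughly (*tblgen_builder).getStringAttr(layout).
static void build(Builder *tblgen_builder, OperationState &state,
                  Type resultType, StringRef layout = "row_major");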
auto numPrevAttrs = std::count_if( @@ -329,6 +330,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth, void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth, int indent) { + Operator &op = tree.getDialectOp(opMap); auto *namedAttr = op.getArg(argIndex).get(); const auto &attr = namedAttr->attr; @@ -340,10 +342,10 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth, attr.getStorageType(), namedAttr->name); // TODO(antiagainst): This should use getter method to avoid duplication. - if (attr.hasDefaultValueInitializer()) { + if (attr.hasDefaultValue()) { os.indent(indent) << "if (!tblgen_attr) tblgen_attr = " << tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, - attr.getDefaultValueInitializer()) + attr.getDefaultValue()) << ";\n"; } else if (attr.isOptional()) { // For a missing attribute that is optional according to definition, we @@ -371,7 +373,8 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth, // Capture the value auto name = tree.getArgName(argIndex); - if (!name.empty()) { + // `$_` is a special symbol to ignore op argument matching. + if (!name.empty() && name != "_") { os.indent(indent) << formatv("{0} = tblgen_attr;\n", name); } diff --git a/third_party/mlir/tools/mlir-tblgen/mlir-tblgen.cpp b/third_party/mlir/tools/mlir-tblgen/mlir-tblgen.cpp index 50b680d904d..993a05d7095 100644 --- a/third_party/mlir/tools/mlir-tblgen/mlir-tblgen.cpp +++ b/third_party/mlir/tools/mlir-tblgen/mlir-tblgen.cpp @@ -74,7 +74,10 @@ const mlir::GenInfo *generator; // TableGenMain requires a function pointer so this function is passed in which // simply wraps the call to the generator. static bool MlirTableGenMain(raw_ostream &os, RecordKeeper &records) { - assert(generator && "no generator specified"); + if (!generator) { + os << records; + return false; + } return generator->invoke(records, os); } diff --git a/third_party/mlir/utils/spirv/gen_spirv_dialect.py b/third_party/mlir/utils/spirv/gen_spirv_dialect.py index 9aed98dba70..d1530f77d5a 100755 --- a/third_party/mlir/utils/spirv/gen_spirv_dialect.py +++ b/third_party/mlir/utils/spirv/gen_spirv_dialect.py @@ -303,7 +303,10 @@ def update_td_enum_attrs(path, operand_kinds, filter_list): # Sort alphabetically according to enum name defs.sort(key=lambda enum : enum[0]) # Only keep the definitions from now on - defs = [enum[1] for enum in defs] + # Put Capability's definition at the very beginning because capability cases + # will be referenced later + defs = [enum[1] for enum in defs if enum[0] == 'Capability' + ] + [enum[1] for enum in defs if enum[0] != 'Capability'] # Substitute the old section content = content[0] + AUTOGEN_ENUM_SECTION_MARKER + '\n\n' + \ @@ -423,7 +426,11 @@ def get_op_definition(instruction, doc, existing_info): # Make sure we have ', ' to separate the category arguments from traits category_args = category_args.rstrip(', ') + ', ' - summary, text = doc.split('\n', 1) + if '\n' in doc: + summary, text = doc.split('\n', 1) + else: + summary = doc + text = '' wrapper = textwrap.TextWrapper( width=76, initial_indent=' ', subsequent_indent=' ')