diff --git a/ISSUES.md b/ISSUES.md index aabd3158b39..a6c77f76950 100644 --- a/ISSUES.md +++ b/ISSUES.md @@ -1,7 +1,9 @@ -If you open a GitHub Issue, here is our policy: 1. It must be a bug/performance -issue or a feature request or a build issue or a documentation issue (for small -doc fixes please send a PR instead). 2. Make sure the Issue Template is filled -out. 3. The issue should be related to the repo it is created in. +If you open a GitHub Issue, here is our policy: + +1. It must be a bug/performance issue or a feature request or a build issue or + a documentation issue (for small doc fixes please send a PR instead). +1. Make sure the Issue Template is filled out. +1. The issue should be related to the repo it is created in. **Here's why we have this policy:** We want to focus on the work that benefits the whole community, e.g., fixing bugs and adding features. Individual support diff --git a/RELEASE.md b/RELEASE.md index 237e03f11fa..9ccef55583a 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -54,6 +54,7 @@ * Corrected higher-order gradients of control flow constructs (`tf.cond`, `tf.while_loop`, and compositions like `tf.foldl`) computed with `tf.GradientTape` inside a `tf.function`. + * Changed the default step size in `gradient_checker_v2.compute_gradients` to be exactly representable as a binary floating point number. This avoids polluting gradient approximations needlessly, which in some cases leads to false negatives in op gradient tests. * `tf.summary`: * New `tf.summary.graph` allows manual write of TensorFlow graph @@ -65,6 +66,19 @@ supported MSVC version to 16.4 (current: 16.8). * See: https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion +* TensorRT + * Removed the deprecated `session_config` parameter for the TF1-TRT + converter `TrtGraphConverter`. Previously, we issued a warning when the + value of the parameter was not None. + * The TF2-TRT converter `TrtGraphConverterV2` takes an object of class + TrtConversionParams as a parameter. Removed three deprecated fields from + this class: `rewriter_config_template`, `is_dynamic_op`, and + `max_batch_size`. Previously, we issued a warning when the value of + `rewriter_config_template` was not None. We issued an error when the + value of `is_dynamic_op` was not True. We did not use the value of + `max_batch_size` when building TensorRT engines. + * Issue a warning when the function `get_tensorrt_rewriter_config` is used. + ## Thanks to our Contributors This release contains contributions from many people at Google, as well as: diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 659bed6e6b8..5e78db99a52 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -72,6 +72,14 @@ config_setting( visibility = ["//visibility:public"], ) +# Config setting that disables the default logger, only logging +# to registered TFLogSinks +config_setting( + name = "no_default_logger", + define_values = {"no_default_logger": "true"}, + visibility = ["//visibility:public"], +) + # Config setting for determining if we are building for Android.
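The release note above about `gradient_checker_v2.compute_gradients` rests on a floating point detail: if the finite-difference step `h` is not exactly representable in binary (for example `1e-3`), the perturbation actually applied, `(x + h) - x`, differs slightly from the `h` used in the divisor, which adds noise to the gradient estimate. A minimal standalone C++ sketch of the effect (illustrative only, not TensorFlow code; the step values are chosen for the demonstration):

```cpp
#include <cstdio>

int main() {
  const double x = 1.0;

  // A "round" decimal step is not exactly representable in binary, so the
  // perturbation actually applied differs from the nominal step.
  const double h_decimal = 1e-3;
  std::printf("decimal step: applied - nominal = %.3e\n",
              ((x + h_decimal) - x) - h_decimal);  // small but non-zero

  // A power-of-two step is exact, so the applied perturbation equals the
  // nominal one and the finite-difference quotient is not polluted.
  const double h_binary = 0.0009765625;  // 2^-10
  std::printf("binary step:  applied - nominal = %.3e\n",
              ((x + h_binary) - x) - h_binary);  // exactly zero
  return 0;
}
```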
config_setting( name = "android", @@ -732,6 +740,7 @@ tf_cc_shared_object( "//tensorflow/c/experimental/filesystem:filesystem_interface", "//tensorflow/c/experimental/stream_executor:stream_executor_hdrs", "//tensorflow/c:kernels_hdrs", + "//tensorflow/c:logging", "//tensorflow/c:ops_hdrs", "//tensorflow/cc/saved_model:loader_lite_impl", "//tensorflow/core/common_runtime:core_cpu_impl", diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 08dd5d0820e..279e3108318 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -301,7 +301,7 @@ tf_cuda_cc_test( ], args = ["--heap_check=local"], linkstatic = tf_kernel_tests_linkstatic(), - tags = tf_cuda_tests_tags(), + tags = tf_cuda_tests_tags() + ["no_cuda_asan"], # b/173654156 deps = [ ":c_api_experimental", ":c_api_unified_internal", @@ -469,6 +469,7 @@ tf_cuda_cc_test( linkstatic = tf_kernel_tests_linkstatic(), tags = tf_cuda_tests_tags() + [ "nomac", + "no_cuda_asan", # b/173825513 ], deps = [ ":abstract_tensor_handle", diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 88705cf3058..00b45209edb 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -61,6 +61,7 @@ limitations under the License. // PLATFORM_GOOGLE, IS_MOBILE_PLATFORM, etc. #if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE) #include "tensorflow/core/tfrt/eager/c_api_tfrt.h" +#include "tensorflow/core/tfrt/eager/c_api_tfrt_distributed.h" #endif // PLATFORM_GOOGLE && !LIBTPU_ON_GCE #if !defined(IS_MOBILE_PLATFORM) @@ -101,7 +102,14 @@ void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { if (opts->use_tfrt) { #if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE) - return tensorflow::wrap(new tfrt::tf::ContextInterface(opts->async)); + tfrt::tf::ContextInterface* tfrt_context = + new tfrt::tf::ContextInterface(opts->async); +#if !defined(IS_MOBILE_PLATFORM) + tfrt_context->SetDistributedManager( + std::make_unique<tfrt::tf::DistributedManagerContextInterface>( + tfrt_context->GetCoreRuntime()->GetHostContext())); +#endif // !IS_MOBILE_PLATFORM + return tensorflow::wrap(tfrt_context); #else status->status = tensorflow::errors::Unimplemented("TFRT is not supported"); return nullptr; diff --git a/tensorflow/c/eager/gradients.cc b/tensorflow/c/eager/gradients.cc index 58ffcf247cf..8adc51328fa 100644 --- a/tensorflow/c/eager/gradients.cc +++ b/tensorflow/c/eager/gradients.cc @@ -226,7 +226,7 @@ void TapeVSpace::DeleteGradient(AbstractTensorHandle* gradient) const { // Helper functions which delegate to `AbstractOperation`, update // the state of the ForwardOperation and call the tape as appropriate. -// These APIs are mainly to faciliate testing and are subject to change. +// These APIs are mainly to facilitate testing and are subject to change. namespace internal { Status Reset(AbstractOperation* op_, const char* op, const char* raw_device_name, ForwardOperation* forward_op_) { diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h index a0860da7b04..b54dcf942c7 100644 --- a/tensorflow/compiler/jit/flags.h +++ b/tensorflow/compiler/jit/flags.h @@ -39,7 +39,7 @@ struct XlaAutoJitFlag { int32 optimization_level_general; }; -// Sets the xla_auto_jit_flag based on the given flag sting. Supported syntax +// Sets the xla_auto_jit_flag based on the given flag string. Supported syntax // is: // <number>: sets general and single_gpu setting to the provided number. 
// single-gpu(<number>): sets the single_gpu setting to the provided number. diff --git a/tensorflow/compiler/jit/xla_kernel_creator.cc b/tensorflow/compiler/jit/xla_kernel_creator.cc index a549c99d9b7..baca8b99088 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator.cc +++ b/tensorflow/compiler/jit/xla_kernel_creator.cc @@ -103,7 +103,9 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr, if (flr->config_proto()) { config_proto = *flr->config_proto(); } - if (!IsMlirBridgePassEnabled(*fbody->graph, config_proto)) { + MlirBridgeRolloutPolicy policy = + GetMlirBridgeRolloutPolicy(*fbody->graph, config_proto); + if (policy != MlirBridgeRolloutPolicy::kEnabledByUser) { RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map; if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) { std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo> diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index 0ebf3ac6bd1..dd6d3b3fbaf 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -1046,6 +1046,7 @@ def HLO_ReshapeOp: HLO_Op<"reshape", let results = (outs HLO_StaticShapeTensor); let hasFolder = 1; + let hasCanonicalizer = 1; let hasCustomHLOConverter = 1; } diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td index b3708bf4ff1..4613627c36d 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td @@ -202,9 +202,9 @@ def LHLOGPU_CholeskyOp : LHLOGPU_Op<"cholesky"> { let arguments = (ins Arg<LHLO_Buffer, "", [MemRead]>:$input, Arg<LHLO_Buffer, "", [MemWrite]>:$output, - Arg<UntypedBuffer, "", [MemWrite]>:$scratch, + Arg<LHLO_Buffer, "", [MemWrite]>:$scratch, Arg<I32Buffer, "", [MemWrite]>:$info, - BoolAttr:$is_upper); + BoolAttr:$is_lower); } #endif // LHLO_GPU_OPS diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td index 0ee44f3202f..ff052b53098 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td @@ -197,10 +197,11 @@ def LHLO_XorOp : LHLO_BinaryElementwiseOp<"xor", LHLO_PredOrIntBuffer>, BASE_HLO //===----------------------------------------------------------------------===// // TODO(b/139813999): specify required function signature in a type-safe way. -def LHLO_ReduceOp: LHLO_Op<"reduce", [ - SameVariadicOperandSize, - SingleBlockImplicitTerminator<"TerminatorOp"> - ]>, BASE_HLO_ReduceOp { +// +// The region `body` may return lmhlo.TerminatorOp or mhlo.ReturnOp. We are +// moving towards mhlo.ReturnOp, but some code that needs cleanup still assumes lmhlo.TerminatorOp. +// TODO(timshen): cleanup lmhlo.TerminatorOp. 
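The `let hasCanonicalizer = 1;` added to `mhlo.reshape` above hooks the op into `getCanonicalizationPatterns`, and the declarative patterns it registers (`IdentityBroadcastReshape`, `IdentityBroadcastInDimReshape`) appear further down in `hlo_patterns.td`. As a rough illustration of what one of those DRR patterns amounts to, here is a hypothetical hand-written C++ equivalent; the accessor names (`operand()`) and header paths are assumptions rather than taken from the diff:

```cpp
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/IR/PatternMatch.h"

// Hypothetical hand-written counterpart of IdentityBroadcastReshape:
// reshape(broadcast(x)) -> x whenever the reshape result type is already
// identical to the type of x.
struct IdentityBroadcastReshapePattern
    : public mlir::OpRewritePattern<mlir::mhlo::ReshapeOp> {
  using mlir::OpRewritePattern<mlir::mhlo::ReshapeOp>::OpRewritePattern;

  mlir::LogicalResult matchAndRewrite(
      mlir::mhlo::ReshapeOp op,
      mlir::PatternRewriter &rewriter) const override {
    // Match only a reshape whose operand is produced by a broadcast.
    auto broadcast = op.operand().getDefiningOp<mlir::mhlo::BroadcastOp>();
    if (!broadcast) return mlir::failure();

    // The HasSameType constraint: the reshape result must already have
    // exactly the broadcast input's type, so the pair is a no-op.
    mlir::Value input = broadcast.operand();
    if (input.getType() != op.getType()) return mlir::failure();

    rewriter.replaceOp(op, input);
    return mlir::success();
  }
};
```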
+def LHLO_ReduceOp: LHLO_Op<"reduce", [SameVariadicOperandSize]>, BASE_HLO_ReduceOp { let arguments = (ins Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands, Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$init_values, diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc index 389c5794c91..2b92afe956b 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -1939,6 +1939,12 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) { return {}; } +void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList& results, + MLIRContext* context) { + results.insert<IdentityBroadcastReshape, IdentityBroadcastInDimReshape>( + context); +} + //===----------------------------------------------------------------------===// // Case Op //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_patterns.td b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_patterns.td index 776732b178b..01564b86381 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_patterns.td +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_patterns.td @@ -31,3 +31,15 @@ def DynamicBroadcastToOwnShape_2 : Pat< def ShapeOfDynamicReshape : Pat< (Shape_ShapeOfOp (HLO_DynamicReshapeOp $x, $shape)), (replaceWithValue $shape)>; + +def HasSameType : Constraint<CPred<"$0.getType() == $1.getType()">>; + +def IdentityBroadcastReshape : Pat< + (HLO_ReshapeOp:$op (HLO_BroadcastOp $input, $dims)), + (replaceWithValue $input), + [(HasSameType $input, $op)]>; + +def IdentityBroadcastInDimReshape : Pat< + (HLO_ReshapeOp:$op (HLO_BroadcastInDimOp $input, $dims)), + (replaceWithValue $input), + [(HasSameType $input, $op)]>; diff --git a/tensorflow/compiler/mlir/hlo/lib/utils/convert_op_folder.cc b/tensorflow/compiler/mlir/hlo/lib/utils/convert_op_folder.cc index 0751d2c626c..a29f0a628c4 100644 --- a/tensorflow/compiler/mlir/hlo/lib/utils/convert_op_folder.cc +++ b/tensorflow/compiler/mlir/hlo/lib/utils/convert_op_folder.cc @@ -61,12 +61,16 @@ mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements, // mapValues always takes a function returning APInt, even when the output // is actually float. using func_type = llvm::APInt(const llvm::APInt&); + + // TODO(hinsu): Correctly handle unsigned element types. + bool is_bool = old_type.isInteger(1); if (auto newFloatType = new_type.dyn_cast<mlir::FloatType>()) { // Int -> Float return elements.mapValues( - new_type, llvm::function_ref<func_type>([&newFloatType]( + new_type, llvm::function_ref<func_type>([&newFloatType, &is_bool]( const llvm::APInt& intVal) { - llvm::APFloat newDouble(static_cast<double>(intVal.getSExtValue())); + int64_t val = is_bool ? intVal.getZExtValue() : intVal.getSExtValue(); + llvm::APFloat newDouble(static_cast<double>(val)); bool loses_info = false; newDouble.convert(newFloatType.getFloatSemantics(), llvm::APFloat::rmNearestTiesToEven, &loses_info); @@ -76,9 +80,10 @@ mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements, // new_type is Integer // Int -> Int return elements.mapValues( - new_type, - llvm::function_ref<func_type>([&bit_width](const llvm::APInt& intVal) { - return llvm::APInt(bit_width, intVal.getSExtValue()); + new_type, llvm::function_ref<func_type>([&bit_width, &is_bool]( + const llvm::APInt& intVal) { + int64_t val = is_bool ? 
intVal.getZExtValue() : intVal.getSExtValue(); + return llvm::APInt(bit_width, val); })); } diff --git a/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir b/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir index 8470f363fcb..41eedeeabe5 100644 --- a/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir @@ -1483,3 +1483,21 @@ func @pad_fold() -> tensor<4x5xi32> { // CHECK-SAME: [1, 1, 1, 1, 1], [2, 1, 3, 1, 1], [4, 1, 5, 1, 1], [1, 1, 1, 1, 1] // CHECK-SAME: ]> : tensor<4x5xi32> } + +// CHECK-LABEL: @identity_broadcast_reshape +func @identity_broadcast_reshape(%arg0: tensor<128xf32>) -> tensor<128xf32> { + %0 = "mhlo.broadcast"(%arg0) { + broadcast_sizes = dense<[1]> : tensor<1xi64>} : (tensor<128xf32>) -> tensor<1x128xf32> + %1 = "mhlo.reshape"(%0) : (tensor<1x128xf32>) -> tensor<128xf32> + return %1 : tensor<128xf32> + // CHECK: return %arg0 : tensor<128xf32> +} + +// CHECK-LABEL: @identity_broadcast_in_dim_reshape +func @identity_broadcast_in_dim_reshape(%arg0: tensor<128xf32>) -> tensor<128xf32> { + %0 = "mhlo.broadcast_in_dim"(%arg0) { + broadcast_dimensions = dense<[1]> : tensor<1xi64> } : (tensor<128xf32>) -> tensor<1x128xf32> + %1 = "mhlo.reshape"(%0) : (tensor<1x128xf32>) -> tensor<128xf32> + return %1 : tensor<128xf32> + // CHECK: return %arg0 : tensor<128xf32> +} diff --git a/tensorflow/compiler/mlir/hlo/tests/convert.mlir b/tensorflow/compiler/mlir/hlo/tests/convert.mlir index dab395c52cd..246cf415d27 100644 --- a/tensorflow/compiler/mlir/hlo/tests/convert.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/convert.mlir @@ -123,6 +123,17 @@ func @const_int_bf16() -> tensor<bf16> { // ----- +// CHECK-LABEL: func @const_bool_f32 +func @const_bool_f32() -> tensor<2xf32> { + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[0.000000e+00, 1.000000e+00]> : tensor<2xf32> + %cst = mhlo.constant dense<[0, 1]> : tensor<2xi1> + %0 = "mhlo.convert"(%cst) : (tensor<2xi1>) -> tensor<2xf32> + // CHECK-NEXT: return [[CST]] + return %0 : tensor<2xf32> +} + +// ----- + // CHECK-LABEL: func @const_bf16_int func @const_bf16_int() -> tensor<i16> { // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i16> @@ -145,8 +156,8 @@ func @const_int_narrowing() -> tensor<i32> { // ----- -// CHECK-LABEL: func @const_int_widening -func @const_int_widening() -> tensor<i64> { +// CHECK-LABEL: func @const_bool_widening +func @const_bool_widening() -> tensor<i64> { // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i64> %cst = mhlo.constant dense<42> : tensor<i32> %0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<i64> @@ -156,6 +167,17 @@ func @const_int_widening() -> tensor<i64> { // ----- +// CHECK-LABEL: func @const_int_widening +func @const_int_widening() -> tensor<2xi32> { + // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[0, 1]> : tensor<2xi32> + %cst = mhlo.constant dense<[0, 1]> : tensor<2xi1> + %0 = "mhlo.convert"(%cst) : (tensor<2xi1>) -> tensor<2xi32> + // CHECK-NEXT: return [[CST]] + return %0 : tensor<2xi32> +} + +// ----- + // CHECK-LABEL: func @const_negative_int_widening func @const_negative_int_widening() -> tensor<i64> { // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<-42> : tensor<i64> diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo_gpu_ops.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo_gpu_ops.mlir index 9e5ce67f39a..a939cab6d10 100644 --- a/tensorflow/compiler/mlir/hlo/tests/lhlo_gpu_ops.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/lhlo_gpu_ops.mlir @@ -93,7 +93,7 @@ func @gemm_bias(%lhs: memref<5x4xf32>, 
%rhs: memref<4x5xf32>, func @cholesky(%arg : memref<10x10xf32>, %out: memref<10x10xf32>) { %scratch = alloc() : memref<32xi8> %info = alloc() : memref<32xi32> - "lmhlo_gpu.cholesky"(%arg, %out, %scratch, %info) { is_upper = true } + "lmhlo_gpu.cholesky"(%arg, %out, %scratch, %info) { is_lower = true } : (memref<10x10xf32>, memref<10x10xf32>, memref<32xi8>, memref<32xi32>) -> () return } diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 63f7a26899f..978d0bbbfa9 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -704,6 +704,7 @@ cc_library( ":convert_type", ":flatbuffer_tflite_operator_lib", ":tensorflow_lite", + "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:mangling_util", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index a3cd4a8ad51..2861c14b32b 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -64,6 +64,7 @@ limitations under the License. #include "mlir/Translation.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -149,7 +150,7 @@ StatusOr<QuantizedType> GetQuantizedType(const TensorT& tensor, Builder builder, int64_t storage_min = QuantizedType::getDefaultMinimumForInteger( is_signed, storage_type.getWidth()) + - is_weight_buffer; + static_cast<int>(is_weight_buffer); int64_t storage_max = QuantizedType::getDefaultMaximumForInteger( is_signed, storage_type.getWidth()); uint32_t flags = @@ -177,12 +178,25 @@ StatusOr<QuantizedType> GetQuantizedType(const TensorT& tensor, Builder builder, quant_params.zero_point.at(0), storage_min, storage_max); } +// import float tensor with calibration value into calibrated quantized type. 
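The `is_bool` branch added to `ConvertElementsAttr` in `convert_op_folder.cc` above exists because a 1-bit integer must be zero-extended rather than sign-extended: sign-extending an `i1` holding `true` yields -1, which is why boolean constants previously folded to `-1.0` / `-1` instead of `1.0` / `1` (the new `@const_bool_f32` and i1-to-i32 `@const_int_widening` tests cover this). A small sketch of the underlying behaviour, assuming the LLVM `APInt` header is available on the include path:

```cpp
#include <cstdio>

#include "llvm/ADT/APInt.h"

int main() {
  // A 1-bit APInt representing boolean "true".
  llvm::APInt bit(/*numBits=*/1, /*val=*/1);

  // Sign extension treats the single bit as the sign bit, giving -1;
  // zero extension gives the intended value 1.
  std::printf("getSExtValue() = %lld\n",
              static_cast<long long>(bit.getSExtValue()));  // prints -1
  std::printf("getZExtValue() = %llu\n",
              static_cast<unsigned long long>(bit.getZExtValue()));  // prints 1
  return 0;
}
```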
+StatusOr<QuantizedType> GetCalibratedQuantizedType(const TensorT& tensor, + Builder builder) { + if (tensor.quantization == nullptr) { + return errors::InvalidArgument("The tensor is not quantized."); + } + auto raw_elem_type = ConvertElementType(tensor.type, builder); + float min = tensor.quantization->min[0]; + float max = tensor.quantization->max[0]; + return mlir::quant::CalibratedQuantizedType::get(raw_elem_type, min, max); +} + // TODO(b/138222071) Remove shapeless_are_scalars once we can reliably // make that distinction and don't have to rely on context // (input to main and constants must have static shape) StatusOr<mlir::TensorType> GetTensorType(const TensorT& tensor, Builder builder, bool shapeless_are_scalars = false, - bool is_constant = false) { + bool is_constant = false, + bool is_intermediate = false) { mlir::Type elem_type = ConvertElementType(tensor.type, builder); // TODO(b/139554398) Store min/max (even for non-quantized tensors) somewhere // if it's set @@ -191,6 +205,13 @@ StatusOr<mlir::TensorType> GetTensorType(const TensorT& tensor, Builder builder, GetQuantizedType(tensor, builder, is_constant)); } + // Intermediate tensors with calibration value (but not scale and zero points) + // should return calibrated quantized type. + if (is_intermediate && tensor.quantization != nullptr && + !IsQuantized(tensor)) { + TF_ASSIGN_OR_RETURN(elem_type, GetCalibratedQuantizedType(tensor, builder)); + } + if (IsScalar(tensor) || (shapeless_are_scalars && tensor.shape.empty())) { return RankedTensorType::get({}, elem_type); } @@ -1033,7 +1054,7 @@ StatusOr<FuncOp> ConvertSubgraph( } } - // Intermediate tensors for tfl.lstm are used to carry quantization range + // Intermediate tensors for LSTMs are used to carry quantization range // in their types, so we only need and extract their types. std::vector<mlir::TensorType> intermediate_types; intermediate_types.reserve(5); @@ -1041,7 +1062,8 @@ StatusOr<FuncOp> ConvertSubgraph( TF_ASSIGN_OR_RETURN( auto type, GetTensorType(*subgraph.tensors[intermediate], builder, /*shapeless_are_scalars=*/true, - /*is_constant=*/true)); + /*is_constant=*/false, + /*is_intermediate=*/true)); intermediate_types.emplace_back(type); } @@ -1135,7 +1157,6 @@ OwningModuleRef tflite::FlatBufferToMlir( auto builder = Builder(context); - std::vector<std::string> func_names; for (auto& subgraph : model->subgraphs) { func_names.push_back(subgraph->name); diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index 215812a6d1d..1ec6a4c28a7 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -1978,6 +1978,10 @@ OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) { OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) { assert(operands.size() == 1); + if (getElementTypeOrSelf(input()) == getElementTypeOrSelf(getType())) { + return input(); + } + // For now, only supports cast between integer types. auto elements_attr = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); if (!elements_attr) { diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 51300978c39..d6ab33e0673 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -483,9 +483,9 @@ value of each element in `x`. For example, if x is an input element and y is an output element, this operation computes \\(y = |x|\\). 
}]; - let arguments = (ins TFL_FpTensor:$x); + let arguments = (ins TFL_TensorOf<[F32, QI8, QI16]>:$x); - let results = (outs TFL_FpTensor:$y); + let results = (outs TFL_TensorOf<[F32, QI8, QI16]>:$y); let hasFolder = 1; } @@ -587,15 +587,15 @@ def TFL_TransposeConvOp: TFL_Op<"transpose_conv", [ let arguments = (ins TFL_I32Tensor:$output_shape, - TFL_TensorOf<[F32, QI8, QUI8]>:$weights, - TFL_TensorOf<[F32, QI8, QUI8]>:$input, - TFL_TensorOfOrNone<[F32, QI32]>:$bias, + TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$weights, + TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$input, + TFL_TensorOfOrNone<[F32, QI32, I64]>:$bias, TFL_PaddingAttr:$padding, Confined<I32Attr, [IntPositive]>:$stride_h, Confined<I32Attr, [IntPositive]>:$stride_w ); - let results = (outs TFL_TensorOf<[F32, QI8, QUI8]>:$output); + let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$output); let hasOptions = 1; @@ -624,7 +624,7 @@ def TFL_AveragePool2DOp: }]; let arguments = ( - ins TFL_TensorOf<[F32, QI8, QUI8]>:$input, + ins TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$input, I32Attr:$filter_height, I32Attr:$filter_width, TFL_PaddingAttr:$padding, @@ -633,7 +633,7 @@ def TFL_AveragePool2DOp: TFL_AFAttr:$fused_activation_function ); - let results = (outs TFL_TensorOf<[F32, QI8, QUI8]>:$output); + let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$output); let hasOptions = 1; let customOption = "Pool2DOptions"; @@ -947,7 +947,7 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [ let arguments = (ins TFL_TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$input, - TFL_TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$filter, + TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$filter, TFL_TensorOfOrNone<[F32, QI32, QUI32]>:$bias, TFL_AFAttr:$fused_activation_function, @@ -999,14 +999,14 @@ in the batch dimensions and broadcasting. }]; let arguments = (ins - TFL_TensorOf<[F32, QI8]>:$x, - TFL_TensorOf<[F32, QI8]>:$y, + TFL_TensorOf<[F32, QI8, QI16]>:$x, + TFL_TensorOf<[F32, QI8, QI16]>:$y, DefaultValuedAttr<BoolAttr, "false">:$adj_x, DefaultValuedAttr<BoolAttr, "false">:$adj_y ); let results = (outs - TFL_TensorOf<[F32, QI8]>:$output + TFL_TensorOf<[F32, QI8, QI16]>:$output ); let hasOptions = 1; @@ -1026,7 +1026,7 @@ def TFL_GatherOp : TFL_Op<"gather", [ }]; let arguments = (ins - TFL_TensorOf<[F32, I1, I8, I32, I64, TFL_Str, UI8, QI8, QUI8]>:$params, + TFL_TensorOf<[F32, I1, I8, I32, I64, TFL_Str, UI8, QI8, QUI8, QI16]>:$params, TFL_TensorOf<[I32, I64]>:$indices, I32Attr:$axis ); @@ -1038,7 +1038,7 @@ def TFL_GatherOp : TFL_Op<"gather", [ ]; let results = (outs - TFL_TensorOf<[F32, I1, I8, I32, I64, TFL_Str, UI8, QI8, QUI8]>:$output + TFL_TensorOf<[F32, I1, I8, I32, I64, TFL_Str, UI8, QI8, QUI8, QI16]>:$output ); let hasOptions = 1; @@ -1750,12 +1750,12 @@ def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [ }]; let arguments = ( - ins TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$input, + ins TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8, QI16]>:$input, // Slope of the activation function at x < 0. 
F32Attr:$alpha ); - let results = (outs TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$output); + let results = (outs TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8, QI16]>:$output); let hasOptions = 0b1; } @@ -1977,12 +1977,12 @@ def TFL_MaximumOp : TFL_Op<"maximum", [ }]; let arguments = ( - ins TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$lhs, - TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$rhs + ins TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$lhs, + TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$rhs ); let results = (outs - TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$max + TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$max ); let builders = [TFL_BroadcastableBinaryBuilder]; @@ -2005,13 +2005,13 @@ def TFL_MeanOp : TFL_Op<"mean", [ }]; let arguments = (ins - TFL_TensorOf<[F32, I32, I64, QI8, QUI8, UI8]>:$input, + TFL_TensorOf<[F32, I32, I64, QI8, QUI8, UI8, QI16]>:$input, TFL_TensorOf<[I32, I64]>:$axis, BoolAttr:$keep_dims ); let results = (outs - TFL_TensorOf<[F32, I32, I64, QI8, QUI8, UI8]>:$output); + TFL_TensorOf<[F32, I32, I64, QI8, QUI8, UI8, QI16]>:$output); let hasOptions = 1; let customOption = "ReducerOptions"; @@ -2090,13 +2090,13 @@ equivalent to setting: }]; let arguments = (ins - TFL_TensorOf<[F32, I32, I64, I8, UI8, I1, TFL_Str, QI8, QUI8, TFL_Quint8]>:$input, + TFL_TensorOf<[F32, I32, I64, I8, UI8, I1, TFL_Str, QI8, QUI8, TFL_Quint8, QI16]>:$input, TFL_I32OrI64Tensor:$begin, TFL_I32OrI64Tensor:$size ); let results = (outs - TFL_TensorOf<[F32, I32, I64, I8, UI8, I1, TFL_Str, QI8, QUI8, TFL_Quint8]>:$output + TFL_TensorOf<[F32, I32, I64, I8, UI8, I1, TFL_Str, QI8, QUI8, TFL_Quint8, QI16]>:$output ); let verifier = [{ return Verify(*this); }]; @@ -2116,12 +2116,12 @@ def TFL_SumOp: TFL_Op<"sum", [ }]; let arguments = (ins - TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input, + TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$input, TFL_I32Tensor:$axes, BoolAttr:$keep_dims ); - let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output); + let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$output); let hasOptions = 1; let customOption = "ReducerOptions"; @@ -2139,13 +2139,13 @@ def TFL_ReduceMinOp: TFL_Op<"reduce_min", [ }]; let arguments = (ins - TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input, + TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$input, TFL_I32Tensor:$axes, BoolAttr:$keep_dims ); let results = (outs - TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output); + TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$output); let hasOptions = 1; let customOption = "ReducerOptions"; @@ -2163,13 +2163,13 @@ def TFL_ReduceMaxOp: TFL_Op<"reduce_max", [ }]; let arguments = (ins - TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input, + TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$input, TFL_I32Tensor:$axes, BoolAttr:$keep_dims ); let results = (outs - TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output); + TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$output); let hasOptions = 1; let customOption = "ReducerOptions"; @@ -2186,13 +2186,13 @@ def TFL_ReduceProdOp: TFL_Op<"reduce_prod", [ }]; let arguments = (ins - TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input, + TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$input, TFL_I32Tensor:$axes, BoolAttr:$keep_dims ); let results = (outs - TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output); + TFL_TensorOf<[F32, I32, I64, QI8, 
QUI8, TFL_Quint8, QI16]>:$output); let hasOptions = 1; let customOption = "ReducerOptions"; @@ -2210,12 +2210,12 @@ def TFL_MinimumOp : TFL_Op<"minimum", [ }]; let arguments = ( - ins TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$lhs, - TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$rhs + ins TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$lhs, + TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$rhs ); let results = (outs - TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$min + TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$min ); let builders = [TFL_BroadcastableBinaryBuilder]; @@ -2364,10 +2364,10 @@ def TFL_PadOp : TFL_Op<"pad", [ ``` }]; - let arguments = (ins TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input, + let arguments = (ins TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$input, TFL_I32OrI64Tensor:$padding); - let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output); + let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$output); let hasOptions = 1; } @@ -2500,9 +2500,9 @@ def TFL_ReluOp: TFL_Op<"relu", [ x -> max(0, x) }]; - let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8]>:$x); + let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8, QI16]>:$x); - let results = (outs TFL_TensorOf<[F32, QUI8, QI8]>:$y); + let results = (outs TFL_TensorOf<[F32, QUI8, QI8, QI16]>:$y); // This builder doesn't work with quantized type, so it can only be used by // non-quantization tablegen patterns. Currently, it is used by the @@ -2828,11 +2828,11 @@ def TFL_SoftmaxOp : TFL_Op<"softmax", [ }]; let arguments = ( - ins TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$input, + ins TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8, QI16]>:$input, F32Attr:$beta ); - let results = (outs TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$output); + let results = (outs TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8, QI16]>:$output); let hasOptions = 1; @@ -3058,12 +3058,12 @@ def TFL_TransposeOp : TFL_Op<"transpose", [ }]; let arguments = (ins - TFL_TensorOf<[I32, F32, I8, UI8, QI8, QUI8, TFL_Quint8, I1, I64]>:$input, + TFL_TensorOf<[I32, F32, I8, UI8, QI8, QUI8, TFL_Quint8, I1, I64, QI16]>:$input, TFL_TensorOf<[I32]>:$perm ); let results = (outs - TFL_TensorOf<[I32, F32, I8, UI8, QI8, QUI8, TFL_Quint8, I1, I64]>:$output + TFL_TensorOf<[I32, F32, I8, UI8, QI8, QUI8, TFL_Quint8, I1, I64, QI16]>:$output ); let verifier = [{ return Verify(*this); }]; @@ -3330,14 +3330,14 @@ def TFL_ResizeNearestNeighborOp : TFL_Op<"resize_nearest_neighbor", [ }]; let arguments = (ins - TFL_TensorOf<[F32, TFL_Quint8, QUI8, QI8]>:$input, + TFL_TensorOf<[F32, TFL_Quint8, QUI8, QI8, QI16]>:$input, TFL_I32Tensor:$size, BoolAttr:$align_corners, DefaultValuedAttr<BoolAttr, "false">:$half_pixel_centers ); let results = (outs - TFL_TensorOf<[F32, TFL_Quint8, QUI8, QI8]>:$output + TFL_TensorOf<[F32, TFL_Quint8, QUI8, QI8, QI16]>:$output ); let hasOptions = 1; @@ -3830,6 +3830,9 @@ Ba et al. 'Layer Normalization' // Types of the optional intermediate tensors, which exist for fully // quantized LSTM op and hold the ranges of the intermediate tensors. + // The type for intermediate tenssors are be quant.calibrated when imported + // to only store calibrated min, max values. The proper quantization spec is + // determined while going through quantization passes. 
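Concretely, the calibrated type mentioned in the comment above records only an expressed float type plus the min/max observed during calibration, and it prints as `!quant.calibrated<f32<min, max>>` (the form checked by the new `lstm.json` import test below). A short sketch of constructing one, assuming the MLIR quant dialect header path; the importer's `GetCalibratedQuantizedType` in `flatbuffer_import.cc` makes essentially this call:

```cpp
#include "mlir/Dialect/Quant/QuantTypes.h"

// Wrap an expressed float type (e.g. f32) into a calibrated quantized type
// that stores only the calibrated min/max, with no scale or zero point yet.
mlir::Type MakeCalibratedType(mlir::Type expressed, double min, double max) {
  return mlir::quant::CalibratedQuantizedType::get(expressed, min, max);
}
```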
OptionalAttr<TypeAttr>:$input_to_input_intermediate, OptionalAttr<TypeAttr>:$input_to_forget_intermediate, OptionalAttr<TypeAttr>:$input_to_cell_intermediate, @@ -3945,6 +3948,9 @@ def TFL_UnidirectionalSequenceLSTMOp : // Types of the optional intermediate tensors, which exist for fully // quantized op and hold the ranges of the intermediate tensors. + // The type for intermediate tenssors are be quant.calibrated when imported + // to only store calibrated min, max values. The proper quantization spec is + // determined while going through quantization passes. OptionalAttr<TypeAttr>:$input_to_input_intermediate, OptionalAttr<TypeAttr>:$input_to_forget_intermediate, OptionalAttr<TypeAttr>:$input_to_cell_intermediate, diff --git a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc index 897e623b37b..015323df5e3 100644 --- a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc @@ -89,6 +89,11 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops(); pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops; pass_config.lower_tensor_list_ops = true; + // Disable the unfolding of the 16x16 TF::BatchMatMulOp to avoid the + // conversion to an unsupported 16x16 TFL::FullyConnectedOp. + if (toco_flags.inference_type() == toco::IODataType::QUANTIZED_INT16) { + pass_config.unfold_batch_matmul = false; + } return internal::ConvertMLIRToTFLiteFlatBuffer( toco_flags, std::move(module), pass_config, /*saved_model_tags=*/{}, diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc index e16dbdea02a..01f8070b37a 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc @@ -174,6 +174,11 @@ Status ConvertSavedModelToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops(); pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops; pass_config.lower_tensor_list_ops = true; + // Disable the unfolding of the 16x16 TF::BatchMatMulOp to avoid the + // conversion to an unsupported 16x16 TFL::FullyConnectedOp. + if (toco_flags.inference_type() == toco::IODataType::QUANTIZED_INT16) { + pass_config.unfold_batch_matmul = false; + } // TODO(b/153507667): Pass the session object when importing logic is removed. auto status = internal::ConvertMLIRToTFLiteFlatBuffer( diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_config.h b/tensorflow/compiler/mlir/lite/quantization/quantization_config.h index 7343853fea0..0e2f4906a7a 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_config.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_config.h @@ -53,10 +53,11 @@ struct QuantizationSpecs { bool disable_per_channel = false; // When set to true, the fixed output ranges of the activation ops (tanh, - // sigmoid, etc.) are not enforced. Then, to quantize these ops, quantization - // emulation ops should be specified after the ops in the input graph. This - // flag should be set to false for post-training quantization. - bool disable_enforced_fixed_output_range = false; + // sigmoid, etc.) and the weight constants are not inferred. 
Then, to quantize + // these ops, quantization emulation ops should be placed after the ops in the + // input graph. This flag should be set to false for post-training + // quantization. + bool disable_infer_tensor_range = false; // The node type when the model is exported. Currently this is limited to // DT_FLOAT, DT_HALF, DT_QINT8, and DT_QUINT8. When DT_HALF is used, the diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc index 53674972f5b..63afab0cf12 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc @@ -100,13 +100,13 @@ class QuantizationDriver { explicit QuantizationDriver(FuncOp fn, bool is_signed, bool disable_per_channel, OpQuantSpecGetter op_quant_spec_getter, - bool enforce_fixed_output_range) + bool infer_tensor_range) : fn_(fn), builder_(fn.getBody()), is_signed_(is_signed), disable_per_channel_(disable_per_channel), op_quant_spec_getter_(op_quant_spec_getter), - enforce_fixed_output_range_(enforce_fixed_output_range) {} + infer_tensor_range_(infer_tensor_range) {} // The entry point of the quantization parameters propagation. void Run(); @@ -384,7 +384,9 @@ class QuantizationDriver { OpQuantSpecGetter op_quant_spec_getter_; - bool enforce_fixed_output_range_; + // Infer output ranges for activation ops and constants. This is usually + // required for post-training quantization. + bool infer_tensor_range_; }; } // namespace @@ -670,33 +672,43 @@ void QuantizationDriver::PreprocessConstantOps() { Value value = cst.getResult(); builder_.setInsertionPoint(cst); - for (auto indexed_use : llvm::enumerate(value.getUses())) { - auto &use = indexed_use.value(); - auto spec = GetQuantSpec(use.getOwner()); - auto biases = spec->biases_params; - Operation *user = use.getOwner(); - int operand_num = use.getOperandNumber(); + // The following loop will change the value uses, thus we cache all the uses + // needs to be changed. + llvm::SmallVector<std::pair<Operation *, int>, 4> uses; + for (auto &use : value.getUses()) { + uses.push_back({use.getOwner(), use.getOperandNumber()}); + } + for (auto indexed_use : llvm::enumerate(uses)) { + Operation *user = indexed_use.value().first; + int operand_num = indexed_use.value().second; + + auto spec = GetQuantSpec(user); + auto biases = spec->biases_params; + + // The quantization parameters of a `weight` shouldn't be determined by + // other values. So any constants which are not bias, an operand of an + // op with same scale requirements, and haven't been quantized are + // weights. if (biases.find(operand_num) == biases.end() && !llvm::dyn_cast<mlir::SameScalesOpInterface>(user) && !llvm::dyn_cast<quant::QuantizeCastOp>(user)) { - // Needs to scan the content to get the quantization parameters if there - // are no quantization parameters (FakeQuant ops). - // For this case, the weight isn't duplicated. + // Needs to scan the content of weights to get the quantization + // parameters if there are no quantization parameters (FakeQuant ops). + // For this case, the weight will not be duplicated. 
weights_.insert(cst); auto affine_user = llvm::dyn_cast<mlir::AffineQuantizedOpInterface>(user); - if (affine_user && - affine_user.GetAffineOperandIndex() == use.getOperandNumber() && + if (affine_user && affine_user.GetAffineOperandIndex() == operand_num && affine_user.RequiredNarrowRangeAffineOperand()) { optimized_weights_.insert( {cst, affine_user.GetQuantizationDimIndex()}); } } else { - // This is a bias, so the quantization parameter isn't determined by the - // local content. Same if the user can have quantization parameter - // propagated from other places. - // Duplicate this constant in case it is shared by different users. + // This is a bias or an operand of an op with same scale requirements, + // so the quantization parameter are propagated from or determined by + // other values. Duplicate this constant in case it is shared by + // different users. if (indexed_use.index() > 0) { cst = builder_.create<ConstantOp>(cst.getLoc(), cst.getValue()); } @@ -786,12 +798,14 @@ bool QuantizationDriver::PropagateParams() { quantized_.insert(op); if (auto cst = llvm::dyn_cast<ConstantOp>(op)) { - // If it isn't a weight or has been quantized, skip. - if (!IsWeight(cst) || IsQuantized(op)) continue; - - // The quantization parameters are determined by the content of the - // constant. - changed |= SetConstantResultParams(op); + // If the workflow requires inferring ranges from the content + // (post-training quantization) and it is weight (filter) and hasn't + // been quantized, we infer the quantization parameters from the content. + if (infer_tensor_range_ && IsWeight(cst) && !IsQuantized(op)) { + // The quantization parameters are determined by the content of the + // constant. + changed |= SetConstantResultParams(op); + } continue; } @@ -826,7 +840,9 @@ bool QuantizationDriver::PropagateParams() { // TODO(fengliuai): make the bit width configurable. auto restricted = llvm::dyn_cast<FixedOutputRangeInterface>(op); - if (restricted && enforce_fixed_output_range_) { + if (restricted && infer_tensor_range_) { + // Infer ranges from the activation ops. This is usually required for + // the post-training quantization workflow. // TODO(fengliuai): different result can have different fixed range. auto params = restricted.GetFixedOutputRange(is_signed_, /*bit_width=*/8); for (auto i = 0; i < op->getNumResults(); ++i) { @@ -903,9 +919,9 @@ void QuantizationDriver::Run() { void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed, bool disable_per_channel, OpQuantSpecGetter op_quant_spec_getter, - bool post_training_quantization) { + bool infer_tensor_ranges) { QuantizationDriver(func, is_signed, disable_per_channel, op_quant_spec_getter, - post_training_quantization) + infer_tensor_ranges) .Run(); } diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h index 5463ba15e18..cb131453a7b 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h @@ -490,13 +490,13 @@ quant::QuantizedType GetUniformQuantizedTypeForBias( // and the propagation results are materialized by inserting pairs of quantize // and dequantize ops to this function. Set `disable_per_channel` to true to not // use per channel quantization even the op supports it. -// Setting `enforce_fixed_output_range` to true, to infer quantization -// parameters from the fixed output range ops. This is only used for -// post-training quantization. 
+// Setting `infer_tensor_range` to true, to infer quantization parameters from +// the activation ops and weight constants. This is only used for post-training +// quantization. void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed, bool disable_per_channel, OpQuantSpecGetter op_quant_spec_getter, - bool enforce_fixed_output_range); + bool infer_tensor_ranges); // The function might contain more stats ops than required, and it will // introduce requantize if the calibration stats have conflicts. This method diff --git a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir index 27a7068cda6..88b084220e8 100644 --- a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir +++ b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir @@ -638,4 +638,10 @@ func @cast_ui8_to_i32() -> tensor<4xi32> { // CHECK: return %[[CST]] } +// CHECK-LABEL: @cast_identity +func @cast_identity(%arg0 : tensor<7xf32>) -> tensor<7xf32> { + %0 = "tfl.cast"(%arg0) : (tensor<7xf32>) -> tensor<7xf32> + return %0 : tensor<7xf32> + // CHECK: return %arg0 : tensor<7xf32> +} diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.json b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.json new file mode 100644 index 00000000000..7052d289796 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.json @@ -0,0 +1,335 @@ +// RUN: json_to_flatbuffer %p/test_schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck %s + +// CHECK: effective_hidden_scale_intermediate = tensor<!quant.calibrated<f32<-5.000000e-01, 5.000000e-01>>> +// CHECK: input_to_cell_intermediate = tensor<!quant.calibrated<f32<-4.000000e+00, 4.000000e+00>>> +// CHECK: input_to_forget_intermediate = tensor<!quant.calibrated<f32<-1.600000e+01, 1.600000e+01>>> +// CHECK: input_to_input_intermediate = tensor<!quant.calibrated<f32<-3.200000e+01, 3.200000e+01>>> +// CHECK: input_to_output_intermediate = tensor<!quant.calibrated<f32<-1.000000e+00, 1.000000e+00>>> + +{ + "version": 3, + "operator_codes": [ + { + "builtin_code": "UNIDIRECTIONAL_SEQUENCE_LSTM" + } + ], + "subgraphs": [ + { + "tensors": [ + { + "shape": [1, 5], + "name": "input0" + }, + { + "shape": [2, 5], + "buffer": 1, + "name": "input2input_weights1" + }, + { + "shape": [2, 5], + "buffer": 2, + "name": "input2forget_weights2" + }, + { + "shape": [2, 5], + "buffer": 3, + "name": "input2cell_weights3" + }, + { + "shape": [2, 5], + "buffer": 4, + "name": "input2output_weights4" + }, + { + "shape": [2, 4], + "buffer": 5, + "name": "rec2input_weights5" + }, + { + "shape": [2, 4], + "buffer": 6, + "name": "rec2forget_weights6" + }, + { + "shape": [2, 4], + "buffer": 7, + "name": "rec2cell_weights7" + }, + { + "shape": [2, 4], + "buffer": 8, + "name": "rec2output_weights8" + }, + { + "shape": [2], + "buffer": 9, + "name": "cell2input_weights9" + }, + { + "shape": [2], + "buffer": 10, + "name": "cell2forget_weights10" + }, + { + "shape": [2], + "buffer": 11, + "name": "cell2output_weights11" + }, + { + "shape": [2], + "buffer": 12, + "name": "input_gate_bias12" + }, + { + "shape": [2], + "buffer": 13, + "name": "forget_gate_bias13" + }, + { + "shape": [2], + "buffer": 14, + "name": "cell_gate_bias14" + }, + { + "shape": [2], + "buffer": 15, + "name": "output_gate_bias15" + }, + { + "shape": [4, 2], + "buffer": 16, + "name": "proj_weights16" + }, + { + "shape": [4], + "buffer": 17, + "name": "proj_bias17" + }, + { + "shape": [1, 4], + "name": "input_activation_state18", + 
"is_variable": true, + "quantization": { + "min": [-0.9], + "max": [0.9] + } + }, + { + "shape": [1, 2], + "name": "input_cell_state19", + "is_variable": true, + "quantization": { + "min": [-0.8], + "max": [0.8] + } + }, + { + "shape": [2], + "buffer": 18, + "name": "input_norm20" + }, + { + "shape": [2], + "buffer": 19, + "name": "forget_norm21" + }, + { + "shape": [2], + "buffer": 20, + "name": "cell_norm22" + }, + { + "shape": [2], + "buffer": 21, + "name": "output_norm23" + }, + { + "shape": [], + "name": "output24" + }, + { + "shape": [], + "name": "intermediate_0", + "is_variable": true, + "quantization": { + "min": [-32], + "max": [32] + } + }, + { + "shape": [], + "name": "intermediate_1", + "is_variable": true, + "quantization": { + "min": [-16], + "max": [16] + } + }, + { + "shape": [], + "name": "intermediate_2", + "is_variable": true, + "quantization": { + "min": [-4], + "max": [4] + } + }, + { + "shape": [], + "name": "intermediate_3", + "is_variable": true, + "quantization": { + "min": [-1.0], + "max": [1.0] + } + }, + { + "shape": [], + "name": "intermediate_4", + "is_variable": true, + "quantization": { + "min": [-0.5], + "max": [0.5] + } + } + ], + "inputs": [0], + "outputs": [24], + "operators": [ + { + "inputs": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23 + ], + "outputs": [24], + "intermediates": [ + 25, 26, 27, 28, 29 + ], + "builtin_options_type": "UnidirectionalSequenceLSTMOptions", + "builtin_options": { + "fused_activation_function": "TANH", + "cell_clip": 50.0 + }, + "mutating_variable_inputs": [ + false, + false, false, false, false, + false, false, false, false, + false, false, false, + false, false, false, false, + true, true, + false, false, false, false + ] + } + ] + } + ], + "buffers": [ + { + "data": [] + }, + { + "data": [ + 36, 167, 168, 63, 0, 140, 72, 191, 120, 20, 147, 62, 20, 152, 196, 190, 121, 98, 82, 187, 95, 128, 213, 61, 189, 3, 138, 63, 54, 103, 13, 62, 46, 224, 66, 63, 157, 204, 180, 191 + ] + }, + { + "data": [ + 223, 20, 21, 64, 246, 166, 31, 191, 6, 51, 157, 188, 114, 90, 167, 62, 118, 240, 59, 63, 49, 162, 255, 62, 17, 91, 160, 63, 32, 47, 26, 63, 40, 136, 178, 191, 243, 154, 236, 61 + ] + }, + { + "data": [ + 137, 231, 86, 63, 41, 154, 16, 63, 239, 37, 77, 191, 55, 189, 24, 189, 86, 63, 18, 63, 42, 55, 13, 191, 110, 139, 138, 191, 219, 148, 181, 63, 71, 232, 108, 191, 66, 226, 145, 191 + ] + }, + { + "data": [ + 245, 179, 225, 190, 51, 202, 176, 189, 132, 47, 53, 191, 155, 25, 50, 191, 197, 130, 240, 191, 98, 125, 45, 62, 243, 70, 83, 62, 85, 155, 139, 63, 113, 239, 11, 192, 35, 251, 139, 62 + ] + }, + { + "data": [ + 248, 188, 211, 191, 142, 11, 73, 62, 36, 8, 84, 63, 186, 208, 11, 191, 76, 208, 190, 191, 223, 200, 210, 63, 183, 170, 103, 63, 116, 129, 145, 63 + ] + }, + { + "data": [ + 235, 202, 222, 190, 159, 201, 112, 191, 217, 248, 166, 63, 165, 199, 131, 191, 130, 59, 47, 63, 179, 11, 186, 62, 55, 168, 18, 192, 152, 213, 26, 64 + ] + }, + { + "data": [ + 245, 123, 138, 62, 213, 106, 231, 59, 211, 218, 250, 62, 25, 157, 134, 63, 147, 22, 164, 63, 25, 221, 139, 62, 1, 230, 247, 62, 210, 185, 142, 63 + ] + }, + { + "data": [ + 197, 123, 23, 192, 45, 96, 178, 190, 174, 87, 165, 62, 213, 225, 200, 191, 119, 248, 15, 191, 128, 125, 171, 189, 90, 125, 222, 63, 4, 76, 95, 62 + ] + }, + { + "data": [ + 210, 73, 183, 63, 248, 177, 13, 191 + ] + }, + { + "data": [ + 78, 251, 212, 191, 169, 29, 147, 63 + ] + }, + { + "data": [ + 178, 227, 203, 191, 247, 155, 103, 63 + ] + }, + { + "data": [ + 206, 
111, 165, 190, 153, 77, 227, 63 + ] + }, + { + "data": [ + 255, 114, 132, 191, 253, 202, 140, 191 + ] + }, + { + "data": [ + 90, 247, 1, 192, 125, 120, 209, 191 + ] + }, + { + "data": [ + 65, 75, 243, 191, 58, 122, 146, 190 + ] + }, + { + "data": [ + 40, 135, 20, 63, 109, 50, 220, 191, 56, 241, 189, 63, 65, 12, 92, 63, 61, 14, 162, 62, 157, 138, 81, 63, 125, 61, 191, 61, 102, 231, 20, 63 + ] + }, + { + "data": [ + 145, 79, 49, 189, 175, 235, 220, 190, 182, 111, 157, 190, 144, 236, 97, 191 + ] + }, + { + "data": [ + 76, 188, 109, 63, 228, 150, 201, 190 + ] + }, + { + "data": [ + 6, 146, 66, 191, 122, 127, 100, 191 + ] + }, + { + "data": [ + 216, 59, 169, 190, 161, 178, 215, 191 + ] + }, + { + "data": [ + 208, 144, 101, 191, 127, 233, 195, 190 + ] + } + ] +} diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.mlir index 07962e15c33..b2a721c947d 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.mlir @@ -20,20 +20,20 @@ func @main(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32 func @testFullyQuantizedLSTM(%arg0: tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, %arg1: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg2: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.031925998628139496>>, %arg3: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.056272000074386597>>, %arg4: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.063763998448848724>>, %arg5: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.013358999975025654>>, %arg6: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.022830000147223473>>, %arg7: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.032276000827550888>>, %arg8: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.035427000373601913>>, %arg9: tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, %arg10: tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, %arg11: tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, %arg12: tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, %arg13: tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, %arg14: tensor<640x!quant.uniform<i32:f32, 1.601389680352559E-4>>, %arg15: tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, %arg16: tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, %arg17: tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, %arg18: tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>, %arg19: tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, %arg20: tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>> { %cst = constant unit - %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) ({}) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", input_to_input_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0049890000373125076>>, input_to_forget_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0078849997371435165>>, input_to_cell_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0087630003690719604>>, input_to_output_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0057529998011887074>>, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 
0.0075630000792443752:2>>, kernel_type = "FULL", proj_clip = 0.01 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 1.601389680352559E-4>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>> + %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) ({}) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", input_to_input_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0049890000373125076>>, input_to_forget_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0078849997371435165>>, input_to_cell_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0087630003690719604>>, input_to_output_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0057529998011887074>>, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8:f32, 0.0075630000792443752:2>>, kernel_type = "FULL", proj_clip = 0.01 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 1.601389680352559E-4>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, 
tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>> return %0 : tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>> // CHECK-LABEL: testFullyQuantizedLSTM // CHECK: %cst = constant unit // CHECK: %[[RES0:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) -// CHECK: }) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0075630000792443752:2>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0087630003690719604>>, input_to_forget_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0078849997371435165>>, input_to_input_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0049890000373125076>>, input_to_output_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0057529998011887074>>, kernel_type = "FULL", proj_clip = 0.00999999977 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform<i8:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform<i8:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform<i8:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform<i8:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform<i8:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform<i8:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, tensor<640x2048x!quant.uniform<i8:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 1.6013896674849093E-4>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>> +// CHECK: }) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8:f32, 0.0075630000792443752:2>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0087630003690719604>>, input_to_forget_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0078849997371435165>>, input_to_input_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0049890000373125076>>, input_to_output_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0057529998011887074>>, kernel_type = "FULL", proj_clip = 0.00999999977 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform<i8:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform<i8:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform<i8:f32, 0.013358999975025654>>, 
tensor<2048x640x!quant.uniform<i8:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform<i8:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform<i8:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, tensor<640x2048x!quant.uniform<i8:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 1.6013896674849093E-4>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>> } // ----- // CHECK-LABEL: testUnidirectionalSequenceLstmWithIntermediates func @testUnidirectionalSequenceLstmWithIntermediates(%arg0: tensor<? x ? x f32>, %arg1: tensor<? x ? x f32>, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> { - // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> - %0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, 
tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> + // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> + %0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> return %0 : tensor<?xf32> } diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index 2542c6c9af7..d991452b187 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -318,6 +318,15 @@ func @any(%arg0: tensor<2x2xi1>, %arg1: tensor<i32>) -> tensor<i1> { // CHECK: "tfl.reduce_any"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi1>, tensor<i32>) -> tensor<i1> } +func @any_i64axes(%arg0: tensor<8x16x16xi1>, %arg1: tensor<2xi64>) -> tensor<?xi1> { + %0 = "tf.Any"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xi1>, tensor<2xi64>) -> tensor<?xi1> + return %0 : tensor<?xi1> + + // CHECK-LABEL: any_i64axes + // CHECK: %[[V0:.*]] = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32> + // CHECK: "tfl.reduce_any"(%arg0, %[[V0]]) {keep_dims = false} : (tensor<8x16x16xi1>, tensor<2xi32>) -> tensor<?xi1> +} + func @ceil(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { %0 = "tf.Ceil"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32> return %0 : tensor<8x16xf32> @@ -972,6 +981,15 @@ func @sum_true(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> 
tensor<?xf32 // CHECK: "tfl.sum"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32> } +func @sum_i64axes(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi64>) -> tensor<?xf32> { + %0 = "tf.Sum"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi64>) -> tensor<?xf32> + return %0 : tensor<?xf32> + + // CHECK-LABEL: sum_i64axes + // CHECK: %[[V0:.*]] = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32> + // CHECK: "tfl.sum"(%arg0, %[[V0]]) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32> +} + func @reduce_min(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor<?xf32> { %0 = "tf.Min"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32> return %0 : tensor<?xf32> @@ -988,6 +1006,15 @@ func @reduce_min_true(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tenso // CHECK: "tfl.reduce_min"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32> } +func @reduce_min_i64axes(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi64>) -> tensor<?xf32> { + %0 = "tf.Min"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi64>) -> tensor<?xf32> + return %0 : tensor<?xf32> + + // CHECK-LABEL: reduce_min_i64axes + // CHECK: %[[V0:.*]] = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32> + // CHECK: "tfl.reduce_min"(%arg0, %[[V0]]) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32> +} + func @reduce_max(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor<?xf32> { %0 = "tf.Max"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32> return %0 : tensor<?xf32> @@ -1004,6 +1031,15 @@ func @reduce_max_true(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tenso // CHECK: "tfl.reduce_max"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32> } +func @reduce_max_i64axes(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi64>) -> tensor<?xf32> { + %0 = "tf.Max"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi64>) -> tensor<?xf32> + return %0 : tensor<?xf32> + + // CHECK-LABEL: reduce_max_i64axes + // CHECK: %[[V0:.*]] = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32> + // CHECK: "tfl.reduce_max"(%arg0, %[[V0]]) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32> +} + func @reduce_prod(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor<?xf32> { %0 = "tf.Prod"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32> return %0 : tensor<?xf32> @@ -1020,6 +1056,15 @@ func @reduce_prod_true(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tens // CHECK: "tfl.reduce_prod"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32> } +func @reduce_prod_i64axes(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi64>) -> tensor<?xf32> { + %0 = "tf.Prod"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi64>) -> tensor<?xf32> + return %0 : tensor<?xf32> + + // CHECK-LABEL: reduce_prod_i64axes + // CHECK: %[[V0:.*]] = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32> + // CHECK: "tfl.reduce_prod"(%arg0, %[[V0]]) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32> +} + func @batch_to_space_nd(%arg0: tensor<4x2x2x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<2x2xi32>) -> tensor<?xf32> { %0 = "tf.BatchToSpaceND"(%arg0, %arg1, %arg2) : (tensor<4x2x2x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<?xf32> return 
%0 : tensor<?xf32> @@ -1172,10 +1217,10 @@ func @strided_slice_with_constant_attributes(%arg0: tensor<10x10x10xf32>, %arg1: %0 = "tf.StridedSlice"(%arg0, %cst, %cst_1, %cst_2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<10x10x10xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<10x10xf32> return %0 : tensor<10x10xf32> // CHECK-LABEL: strided_slice_with_constant_attributes - // CHECK-DAG: [[BEGIN:%cst.*]] = constant dense<[-1, 0, 0]> : tensor<3xi32> - // CHECK-DAG: [[END:%cst.*]] = constant dense<[0, 10, 10]> : tensor<3xi32> - // CHECK-DAG: [[STRIDES:%cst.*]] = constant dense<1> : tensor<3xi32> - // CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 6 : i32, ellipsis_mask = 0 : i32, end_mask = 6 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32} : (tensor<10x10x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<10x10xf32> + // CHECK-DAG: [[BEGIN:%cst.*]] = constant dense<-1> : tensor<1xi32> + // CHECK-DAG: [[END:%cst.*]] = constant dense<0> : tensor<1xi32> + // CHECK-DAG: [[STRIDES:%cst.*]] = constant dense<1> : tensor<1xi32> + // CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32} : (tensor<10x10x10xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<10x10xf32> } func @strided_slice_with_string(%arg0: tensor<12x2x2x5x!tf.string>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5x!tf.string> { @@ -1185,6 +1230,39 @@ func @strided_slice_with_string(%arg0: tensor<12x2x2x5x!tf.string>, %arg1: tenso // CHECK: "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5x!tf.string>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf.string> } +func @strided_slice_with_unranked_input_and_i64_parameters(%arg0: tensor<*xf32>, %arg1: tensor<1xi64>, %arg2: tensor<1xi64>, %arg3: tensor<1xi64>) -> tensor<*xf32> { + %0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<*xf32>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<*xf32> + return %0 : tensor<*xf32> + // CHECK-LABEL: strided_slice_with_unranked_input_and_i64_parameters + // CHECK-DAG: [[BEGIN:%.*]] = "tfl.cast"(%arg1) : (tensor<1xi64>) -> tensor<1xi32> + // CHECK-DAG: [[END:%.*]] = "tfl.cast"(%arg2) : (tensor<1xi64>) -> tensor<1xi32> + // CHECK-DAG: [[STRIDES:%.*]] = "tfl.cast"(%arg3) : (tensor<1xi64>) -> tensor<1xi32> + // CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<*xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xf32> +} + +func @strided_slice_with_i64_parameters(%arg0: tensor<12x2x2x5xf32>, %arg1: tensor<1xi64>, %arg2: tensor<1xi64>, %arg3: tensor<1xi64>) -> tensor<1x2x2x5xf32> { + %0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<12x2x2x5xf32>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<1x2x2x5xf32> + return %0 : 
tensor<1x2x2x5xf32> + // CHECK-LABEL: strided_slice_with_i64_parameters + // CHECK-DAG: [[BEGIN:%.*]] = "tfl.cast"(%arg1) : (tensor<1xi64>) -> tensor<1xi32> + // CHECK-DAG: [[END:%.*]] = "tfl.cast"(%arg2) : (tensor<1xi64>) -> tensor<1xi32> + // CHECK-DAG: [[STRIDES:%.*]] = "tfl.cast"(%arg3) : (tensor<1xi64>) -> tensor<1xi32> + // CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32> +} + +func @strided_slice_with_i64_constant_attributes(%arg0: tensor<10x10x10xf32>) -> tensor<10x10xf32> { + %cst = constant dense<-1> : tensor<1xi64> + %cst_1 = constant dense<0> : tensor<1xi64> + %cst_2 = constant dense<1> : tensor<1xi64> + %0 = "tf.StridedSlice"(%arg0, %cst, %cst_1, %cst_2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<10x10x10xf32>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<10x10xf32> + return %0 : tensor<10x10xf32> + // CHECK-LABEL: strided_slice_with_i64_constant_attributes + // CHECK-DAG: [[BEGIN:%cst.*]] = constant dense<-1> : tensor<1xi32> + // CHECK-DAG: [[END:%cst.*]] = constant dense<0> : tensor<1xi32> + // CHECK-DAG: [[STRIDES:%cst.*]] = constant dense<1> : tensor<1xi32> + // CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32} : (tensor<10x10x10xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<10x10xf32> +} + func @slice1Tensor(%arg0: tensor<2x3x5xf32>, %arg1: tensor<3xi32>, %arg2: tensor<3xi32>) -> tensor<?x3x5xf32> { %0 = "tf.Slice"(%arg0, %arg1, %arg2) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<?x3x5xf32> return %0 : tensor<?x3x5xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir index 53caf15bc8f..5407e8dfdae 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir @@ -625,6 +625,35 @@ func @QuantizeSharedBiases2( // CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq]]) } +// Make sure constants are duplicated for all users.
+// CHECK-LABEL: QuantizeSharedConstantsMultipleUsers +func @QuantizeSharedConstantsMultipleUsers( + %arg0: tensor<32x!quant.uniform<u8:f32, 1.0>>, + %arg1: tensor<32x!quant.uniform<u8:f32, 2.0>>, + %arg2: tensor<32x!quant.uniform<u8:f32, 3.0>>, + %arg3: tensor<32x!quant.uniform<u8:f32, 4.0>>) -> (tensor<32xf32>, tensor<32xf32>, tensor<32xf32>, tensor<32xf32>) { + %cst = constant dense<0.0> : tensor<32xf32> + %0 = "tfl.dequantize"(%arg0) : (tensor<32x!quant.uniform<u8:f32, 1.0>>) -> tensor<32xf32> + %1 = "tfl.dequantize"(%arg1) : (tensor<32x!quant.uniform<u8:f32, 2.0>>) -> tensor<32xf32> + %2 = "tfl.dequantize"(%arg2) : (tensor<32x!quant.uniform<u8:f32, 3.0>>) -> tensor<32xf32> + %3 = "tfl.dequantize"(%arg3) : (tensor<32x!quant.uniform<u8:f32, 4.0>>) -> tensor<32xf32> + + %4 = "tfl.minimum"(%0, %cst) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32> + %5 = "tfl.minimum"(%1, %cst) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32> + %6 = "tfl.minimum"(%2, %cst) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32> + %7 = "tfl.minimum"(%3, %cst) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32> + return %4, %5, %6, %7 : tensor<32xf32>, tensor<32xf32>, tensor<32xf32>, tensor<32xf32> + +// CHECK: %[[cst1:.*]] = "tfl.dequantize"(%{{.*}}) : (tensor<32x!quant.uniform<u8:f32, 2.000000e+00>>) -> tensor<32xf32> +// CHECK: %[[cst2:.*]] = "tfl.dequantize"(%{{.*}}) : (tensor<32x!quant.uniform<u8:f32, 3.000000e+00>>) -> tensor<32xf32> +// CHECK: %[[cst3:.*]] = "tfl.dequantize"(%{{.*}}) : (tensor<32x!quant.uniform<u8:f32, 4.000000e+00>>) -> tensor<32xf32> +// CHECK: %[[cst4:.*]] = "tfl.dequantize"(%{{.*}}) : (tensor<32x!quant.uniform<u8:f32, 1.000000e+00>>) -> tensor<32xf32> +// CHECK: "tfl.minimum"(%{{.*}}, %[[cst4]]) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32> +// CHECK: "tfl.minimum"(%{{.*}}, %[[cst1]]) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32> +// CHECK: "tfl.minimum"(%{{.*}}, %[[cst2]]) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32> +// CHECK: "tfl.minimum"(%{{.*}}, %[[cst3]]) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32> +} + // Make sure quantization parameters are scanned from weight, but not from bias. 
// CHECK-LABEL: QuantizeWeight func @QuantizeWeight(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> { diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir index 53fc1bf88be..50a93ecf4b5 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir @@ -551,6 +551,19 @@ func @StridedSliceRewriteMasks(%arg0: tensor<8x4x16x2xf32>) -> tensor<8x4x16x1xf return %0 : tensor<8x4x16x1xf32> } +func @strided_slice_with_constant_attributes(%arg0: tensor<10x10x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<10x10xf32> { + %cst = constant dense<-1> : tensor<1xi32> + %cst_1 = constant dense<0> : tensor<1xi32> + %cst_2 = constant dense<1> : tensor<1xi32> + %0 = "tf.StridedSlice"(%arg0, %cst, %cst_1, %cst_2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<10x10x10xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<10x10xf32> + return %0 : tensor<10x10xf32> + // CHECK-LABEL: strided_slice_with_constant_attributes + // CHECK-DAG: [[BEGIN:%cst.*]] = constant dense<[-1, 0, 0]> : tensor<3xi32> + // CHECK-DAG: [[END:%cst.*]] = constant dense<[0, 10, 10]> : tensor<3xi32> + // CHECK-DAG: [[STRIDES:%cst.*]] = constant dense<1> : tensor<3xi32> + // CHECK-NEXT: "tf.StridedSlice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<10x10x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<10x10xf32> +} + func @broadcast_to_f32_low_dim(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> { %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32> return %0: tensor<3x3xf32> diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td index 5b35c71b0bb..f4bc89b9801 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td @@ -328,23 +328,28 @@ def LegalizeMean : Pat<(TF_MeanOp $arg0, $arg1, BoolAttr:$arg2), (TFL_MeanOp $arg0, $arg1, $arg2)>; def LegalizeSum : Pat<(TF_SumOp $arg, $axes, BoolAttr:$arg2), - (TFL_SumOp $arg, $axes, $arg2)>; + (TFL_SumOp $arg, (CreateTFCastToInt32Op $axes), $arg2)>; // TopK in TFL is always sorted so we ignore that attribute here. 
def LegalizeTopKV2 : Pat<(TF_TopKV2Op $input, $k, $ignored_sorted), (TFL_TopKV2Op $input, $k)>; -def LegalizeMin : Pat<(TF_MinOp $arg0, $arg1, BoolAttr:$arg2), - (TFL_ReduceMinOp $arg0, $arg1, $arg2)>; +def LegalizeMin : Pat< + (TF_MinOp $arg0, $axes, BoolAttr:$arg2), + (TFL_ReduceMinOp $arg0, (CreateTFCastToInt32Op $axes), $arg2)>; -def LegalizeMax : Pat<(TF_MaxOp $arg0, $arg1, BoolAttr:$arg2), - (TFL_ReduceMaxOp $arg0, $arg1, $arg2)>; +def LegalizeMax : Pat< + (TF_MaxOp $arg0, $axes, BoolAttr:$arg2), + (TFL_ReduceMaxOp $arg0, (CreateTFCastToInt32Op $axes), $arg2)>; -def LegalizeProd : Pat<(TF_ProdOp $arg0, $arg1, BoolAttr:$arg2), - (TFL_ReduceProdOp $arg0, $arg1, $arg2)>; +def LegalizeProd : Pat< + (TF_ProdOp $arg0, $axes, BoolAttr:$arg2), + (TFL_ReduceProdOp $arg0, (CreateTFCastToInt32Op $axes), $arg2)>; -def LegalizeAny : Pat<(TF_AnyOp $input, $reduction_indices, $keep_dims), - (TFL_ReduceAnyOp $input, $reduction_indices, $keep_dims)>; +def LegalizeAny : Pat< + (TF_AnyOp $input, $reduction_indices, $keep_dims), + (TFL_ReduceAnyOp $input, (CreateTFCastToInt32Op $reduction_indices), + $keep_dims)>; def LegalizeCast : Pat<(TF_CastOp $arg0, BoolAttr:$arg1), (TFL_CastOp $arg0)>; @@ -471,3 +476,14 @@ def LegalizeCumsum : Pat< def LegalizeReshape : Pat< (TF_ReshapeOp $input, $shape), (TFL_ReshapeOp $input, (CreateTFCastToInt32Op $shape))>; + +def LegalizeStridedSlice : Pat< + (TF_StridedSliceOp + $input, $begin, $end, $strides, $begin_mask, $end_mask, $ellipsis_mask, + $new_axis_mask, $shrink_axis_mask), + (TFL_StridedSliceOp $input, + (CreateTFCastToInt32Op $begin), (CreateTFCastToInt32Op $end), + (CreateTFCastToInt32Op $strides), (convertIntAttrTo32Bit $begin_mask), + (convertIntAttrTo32Bit $end_mask), (convertIntAttrTo32Bit $ellipsis_mask), + (convertIntAttrTo32Bit $new_axis_mask), + (convertIntAttrTo32Bit $shrink_axis_mask))>; diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index bf623976423..febbd7d2c83 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -148,7 +148,6 @@ DECL_CONVERT_OP(MatrixDiagV3); DECL_CONVERT_OP(Pack); DECL_CONVERT_OP(Split); DECL_CONVERT_OP(SplitV); -DECL_CONVERT_OP(StridedSlice); DECL_CONVERT_OP(Unpack); DECL_CONVERT_OP(RandomUniform); @@ -325,81 +324,6 @@ LogicalResult ConvertTFSplitVOp::matchAndRewrite( return success(); } -Value PadStridedSliceAttributeArray(Operation* op, PatternRewriter& rewriter, - Value attribute, - ArrayRef<int32_t> padding_val, int* mask) { - DenseIntElementsAttr dense_elem_attr; - SmallVector<int32_t, 8> padded_val; - - auto ranked_attr_type = attribute.getType().dyn_cast<RankedTensorType>(); - if (!ranked_attr_type || - !matchPattern(attribute, m_Constant(&dense_elem_attr))) { - // If the input attribute is neither ranked type nor constant, we - // can't do any padding. Instead we just return it. 
- return attribute; - } - for (const auto& idx : dense_elem_attr.getIntValues()) { - padded_val.push_back(idx.getSExtValue()); - } - auto attr_dim_count = ranked_attr_type.getShape()[0]; - int full_dim_count = padding_val.size(); - for (int i = attr_dim_count; i < full_dim_count; ++i) { - padded_val.push_back(padding_val[i]); - if (mask) *mask |= 1 << i; - } - auto type = - RankedTensorType::get({full_dim_count}, rewriter.getIntegerType(32)); - auto attr = DenseElementsAttr::get<int32_t>(type, padded_val); - return rewriter.create<ConstantOp>(op->getLoc(), type, attr); -} - -LogicalResult ConvertTFStridedSliceOp::matchAndRewrite( - Operation* op, PatternRewriter& rewriter) const { - auto tf_strided_slice_op = cast<TF::StridedSliceOp>(op); - auto ranked_input_type = - tf_strided_slice_op.input().getType().dyn_cast<RankedTensorType>(); - if (!ranked_input_type) { - // If input is not a ranked tensor, we can't deduce the padding dimensions - // from it, so we just do a plain conversion here. - rewriter.replaceOpWithNewOp<TFL::StridedSliceOp>( - op, tf_strided_slice_op.output().getType(), tf_strided_slice_op.input(), - tf_strided_slice_op.begin(), tf_strided_slice_op.end(), - tf_strided_slice_op.strides(), - rewriter.getI32IntegerAttr(tf_strided_slice_op.begin_mask()), - rewriter.getI32IntegerAttr(tf_strided_slice_op.end_mask()), - rewriter.getI32IntegerAttr(tf_strided_slice_op.ellipsis_mask()), - rewriter.getI32IntegerAttr(tf_strided_slice_op.new_axis_mask()), - rewriter.getI32IntegerAttr(tf_strided_slice_op.shrink_axis_mask())); - return success(); - } - - int num_input_dims = ranked_input_type.getRank(); - // Pad `begin` array with zero values and update the `begin_mask`. - SmallVector<int32_t, 8> begin_pad_val(num_input_dims, 0); - int begin_mask = tf_strided_slice_op.begin_mask(); - Value padded_begin = PadStridedSliceAttributeArray( - op, rewriter, tf_strided_slice_op.begin(), begin_pad_val, &begin_mask); - // Pad `end` array with `input_shape` and update the `end_mask`. - int end_mask = tf_strided_slice_op.end_mask(); - auto input_shape = ranked_input_type.getShape(); - SmallVector<int32_t, 8> end_pad_val(input_shape.begin(), input_shape.end()); - Value padded_end = PadStridedSliceAttributeArray( - op, rewriter, tf_strided_slice_op.end(), end_pad_val, &end_mask); - // Pad `strides` array with ones. 
- SmallVector<int32_t, 8> strides_pad_val(num_input_dims, 1); - Value padded_strides = PadStridedSliceAttributeArray( - op, rewriter, tf_strided_slice_op.strides(), strides_pad_val, nullptr); - rewriter.replaceOpWithNewOp<TFL::StridedSliceOp>( - op, tf_strided_slice_op.output().getType(), tf_strided_slice_op.input(), - padded_begin, padded_end, padded_strides, - rewriter.getI32IntegerAttr(begin_mask), - rewriter.getI32IntegerAttr(end_mask), - rewriter.getI32IntegerAttr(tf_strided_slice_op.ellipsis_mask()), - rewriter.getI32IntegerAttr(tf_strided_slice_op.new_axis_mask()), - rewriter.getI32IntegerAttr(tf_strided_slice_op.shrink_axis_mask())); - return success(); -} - LogicalResult ConvertTFUnpackOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_unpack_op = cast<TF::UnpackOp>(op); @@ -769,8 +693,8 @@ void addPatterns(MLIRContext* context, OwningRewritePatternList& patterns) { patterns .insert<ConvertTFConcatV2Op, ConvertTFMatMulOp, ConvertTFMatrixDiagV2Op, ConvertTFMatrixDiagV3Op, ConvertTFPackOp, ConvertTFSplitOp, - ConvertTFSplitVOp, ConvertTFStridedSliceOp, ConvertTFUnpackOp, - ConvertTFAssertOp, ConvertTFRandomUniformOp>(context); + ConvertTFSplitVOp, ConvertTFUnpackOp, ConvertTFAssertOp, + ConvertTFRandomUniformOp>(context); // Ophint python converter converted tf node pattern. patterns.insert<LegalizeUnidirectionalSequenceLstm, diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc index a2af9b6e9d7..235117fcb92 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc @@ -376,14 +376,17 @@ void PrepareQuantizePass::runOnFunction() { OwningRewritePatternList patterns; bool is_signed = quant_specs_.IsSignedInferenceType(); int bit_width = quant_specs_.GetQuantizationTypeWidth(); - bool quantization_aware_training_mode = ContainsQuantizeOps(func); - // Enforce fixed output range for post-training quantization and - // when the model has quantization emulation ops, unless it was disabled - // explicitly by the flag. - bool enforced_output_range = - (quant_specs_.post_training_quantization || - quantization_aware_training_mode) && - !quant_specs_.disable_enforced_fixed_output_range; + // When this is true, the quantizer will try its best to extract the + // quantization parameters from the op quantization property and constant + // content. This is also set to true when the `quantize_allowlist` and + // `quantize_signed` test flags are enabled. + bool eager_quantize = ContainsQuantizeOps(func) || + (!quantize_allowlist.empty() || quantize_signed); + // Infer the tensor range for the activation ops and weight constants unless + // it is disabled explicitly. + bool infer_tensor_range = + (quant_specs_.post_training_quantization || eager_quantize) && + !quant_specs_.disable_infer_tensor_range; if (is_signed) { patterns.insert<quant::ConvertUnsignedToSigned<quant::QuantizeCastOp>>(ctx); // Convert quant stats to int8 quantization parameters. @@ -403,7 +406,7 @@ void PrepareQuantizePass::runOnFunction() { // values (tensors). 
ApplyQuantizationParamsPropagation( func, is_signed, disable_per_channel || quant_specs_.disable_per_channel, - GetOpQuantSpec, enforced_output_range); + GetOpQuantSpec, infer_tensor_range); ConvertMlirQuantOpsToTFLQuantOps(func); } diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc index bb44788e961..f56c8bc0d06 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc @@ -712,6 +712,23 @@ struct ConvertTFStridedSlice : public RewritePattern { return success(); } + void PadStridedSliceAttributeArray(DenseIntElementsAttr dense_elem_attr, + SmallVectorImpl<int32_t> &val, + SmallVectorImpl<int32_t> &padded_val, + ArrayRef<int32_t> padding_val, + int *mask) const { + for (const auto &idx : dense_elem_attr.getIntValues()) { + val.push_back(idx.getSExtValue()); + padded_val.push_back(idx.getSExtValue()); + } + int attr_dim_count = val.size(); + int full_dim_count = padding_val.size(); + for (int i = attr_dim_count; i < full_dim_count; ++i) { + padded_val.push_back(padding_val[i]); + if (mask) *mask |= 1 << i; + } + } + LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op); @@ -719,17 +736,102 @@ struct ConvertTFStridedSlice : public RewritePattern { // Handle new axis mask. if (strided_slice_op.new_axis_mask() != 0) { // We currently don't handle simultaneous shrink_ and new_axis masks. - if (strided_slice_op.shrink_axis_mask()) { - return failure(); + if (!strided_slice_op.shrink_axis_mask()) { + return RewriteNewAxisMask(strided_slice_op, rewriter); } - return RewriteNewAxisMask(strided_slice_op, rewriter); } // Handle ellipsis mask. 
if (strided_slice_op.ellipsis_mask() != 0) { return RewriteEllipsisMask(strided_slice_op, rewriter); } - return failure(); + + auto ranked_input_type = + strided_slice_op.input().getType().dyn_cast<RankedTensorType>(); + if (!ranked_input_type) { + return failure(); + } + + auto begin_attr = strided_slice_op.begin(); + auto end_attr = strided_slice_op.end(); + auto strides_attr = strided_slice_op.strides(); + + auto begin_attr_type = begin_attr.getType().dyn_cast<RankedTensorType>(); + auto end_attr_type = end_attr.getType().dyn_cast<RankedTensorType>(); + auto strides_attr_type = + strides_attr.getType().dyn_cast<RankedTensorType>(); + + DenseIntElementsAttr begin_elem_attr; + DenseIntElementsAttr end_elem_attr; + DenseIntElementsAttr strides_elem_attr; + + if (!begin_attr_type || + !matchPattern(begin_attr, m_Constant(&begin_elem_attr))) { + return failure(); + } + if (!end_attr_type || !matchPattern(end_attr, m_Constant(&end_elem_attr))) { + return failure(); + } + if (!strides_attr_type || + !matchPattern(strides_attr, m_Constant(&strides_elem_attr))) { + return failure(); + } + + SmallVector<int32_t, 4> begin, end, strides; + SmallVector<int32_t, 4> padded_begin, padded_end, padded_strides; + + int num_input_dims = ranked_input_type.getRank(); + SmallVector<int32_t, 4> padding_begin(num_input_dims, 0); + auto input_shape = ranked_input_type.getShape(); + SmallVector<int32_t, 4> padding_end(input_shape.begin(), input_shape.end()); + SmallVector<int32_t, 4> padding_strides(num_input_dims, 1); + + int begin_mask = strided_slice_op.begin_mask(); + int end_mask = strided_slice_op.end_mask(); + + PadStridedSliceAttributeArray(begin_elem_attr, begin, padded_begin, + padding_begin, &begin_mask); + PadStridedSliceAttributeArray(end_elem_attr, end, padded_end, padding_end, + &end_mask); + PadStridedSliceAttributeArray(strides_elem_attr, strides, padded_strides, + padding_strides, nullptr); + + if (begin == padded_begin && end == padded_end && + strides == padded_strides && + begin_mask == strided_slice_op.begin_mask() && + end_mask == strided_slice_op.end_mask()) { + return failure(); + } + + auto begin_end_type = + RankedTensorType::get({num_input_dims}, rewriter.getIntegerType(32)); + auto new_begin_attr = rewriter.create<ConstantOp>( + op->getLoc(), begin_end_type, + DenseElementsAttr::get<int32_t>(begin_end_type, padded_begin)); + auto new_end_attr = rewriter.create<ConstantOp>( + op->getLoc(), begin_end_type, + DenseElementsAttr::get<int32_t>(begin_end_type, padded_end)); + auto strides_type = + RankedTensorType::get({static_cast<long>(padded_strides.size())}, + rewriter.getIntegerType(32)); + auto new_strides_attr = rewriter.create<ConstantOp>( + op->getLoc(), strides_type, + DenseElementsAttr::get<int32_t>(strides_type, padded_strides)); + + auto attribute_type = rewriter.getIntegerType(64); + rewriter.replaceOpWithNewOp<TF::StridedSliceOp>( + op, strided_slice_op.output().getType(), strided_slice_op.input(), + new_begin_attr, new_end_attr, new_strides_attr, + rewriter.getIntegerAttr(attribute_type, begin_mask), + rewriter.getIntegerAttr(attribute_type, end_mask), + rewriter.getIntegerAttr(attribute_type, + strided_slice_op.ellipsis_mask()), + rewriter.getIntegerAttr(attribute_type, + strided_slice_op.new_axis_mask()), + rewriter.getIntegerAttr(attribute_type, + strided_slice_op.shrink_axis_mask())); + + return success(); } }; diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index bedb8d0c5ea..d26f09c5c91 100644 --- 
a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -337,6 +337,7 @@ cc_library( ":tensorflow", "//tensorflow/compiler/mlir/hlo", "//tensorflow/core:framework", + "//tensorflow/core:lib", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:Dialect", @@ -859,6 +860,7 @@ cc_library( "transforms/mark_ops_for_outside_compilation.cc", "transforms/materialize_mlir_passthrough_op.cc", "transforms/optimize.cc", + "transforms/outside_compiled_to_host_launch.cc", "transforms/parallel_execute_to_islands.cc", "transforms/parallelize_embedding_params_ops_pass.cc", "transforms/promote_resources_to_args.cc", diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index c8d1a83058f..76e8836fef1 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -31,7 +31,7 @@ limitations under the License. include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td" include "mlir/Interfaces/InferTypeOpInterface.td" -def TF_AbsOp : TF_Op<"Abs", [NoSideEffect, SameOperandsAndResultType]> { +def TF_AbsOp : TF_Op<"Abs", [Idempotent, NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes the absolute value of a tensor."; let description = [{ @@ -1002,13 +1002,13 @@ reverse of SpaceToBatch. See below for a precise description. TF_Tensor:$output ); - let verifier = [{ - return Verify(*this); - }]; - TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr Tcrops = TF_DerivedOperandTypeAttr<2>; TF_DerivedOperandTypeAttr Tblock_shape = TF_DerivedOperandTypeAttr<1>; + + let verifier = [{ + return Verify(*this); + }]; } def TF_BetaincOp : TF_Op<"Betainc", [NoSideEffect]> { @@ -1486,7 +1486,7 @@ def TF_CastOp : TF_Op<"Cast", [NoSideEffect, SameOperandsAndResultShape]> { let hasFolder = 1; } -def TF_CeilOp : TF_Op<"Ceil", [NoSideEffect, SameOperandsAndResultType]> { +def TF_CeilOp : TF_Op<"Ceil", [Idempotent, NoSideEffect, SameOperandsAndResultType]> { let summary = "Returns element-wise smallest integer not less than x."; let arguments = (ins @@ -3502,8 +3502,8 @@ tf.math.equal(x, y) ==> array([True, True]) }]; let arguments = (ins - TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x, - TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$y, + TF_Tensor:$x, + TF_Tensor:$y, DefaultValuedAttr<BoolAttr, "true">:$incompatible_shape_error ); @@ -3843,8 +3843,8 @@ def TF_FakeQuantWithMinMaxArgsGradientOp : TF_Op<"FakeQuantWithMinMaxArgsGradien let summary = "Compute gradients for a FakeQuantWithMinMaxArgs operation."; let arguments = (ins - F32Tensor:$gradients, - F32Tensor:$inputs, + TF_Float32Tensor:$gradients, + TF_Float32Tensor:$inputs, DefaultValuedAttr<F32Attr, "-6.0f">:$min, DefaultValuedAttr<F32Attr, "6.0f">:$max, @@ -3853,7 +3853,7 @@ def TF_FakeQuantWithMinMaxArgsGradientOp : TF_Op<"FakeQuantWithMinMaxArgsGradien ); let results = (outs - F32Tensor:$backprops + TF_Float32Tensor:$backprops ); } @@ -3911,19 +3911,19 @@ def TF_FakeQuantWithMinMaxVarsGradientOp : 
TF_Op<"FakeQuantWithMinMaxVarsGradien let summary = "Compute gradients for a FakeQuantWithMinMaxVars operation."; let arguments = (ins - F32Tensor:$gradients, - F32Tensor:$inputs, - F32Tensor:$min, - F32Tensor:$max, + TF_Float32Tensor:$gradients, + TF_Float32Tensor:$inputs, + TF_Float32Tensor:$min, + TF_Float32Tensor:$max, DefaultValuedAttr<I64Attr, "8">:$num_bits, DefaultValuedAttr<BoolAttr, "false">:$narrow_range ); let results = (outs - F32Tensor:$backprops_wrt_input, - F32Tensor:$backprop_wrt_min, - F32Tensor:$backprop_wrt_max + TF_Float32Tensor:$backprops_wrt_input, + TF_Float32Tensor:$backprop_wrt_min, + TF_Float32Tensor:$backprop_wrt_max ); } @@ -4026,7 +4026,7 @@ fill([2, 3], 9) ==> [[9, 9, 9] ]; } -def TF_FloorOp : TF_Op<"Floor", [NoSideEffect, SameOperandsAndResultType]> { +def TF_FloorOp : TF_Op<"Floor", [Idempotent, NoSideEffect, SameOperandsAndResultType]> { let summary = "Returns element-wise largest integer not greater than x."; let arguments = (ins @@ -4977,13 +4977,13 @@ $$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$ }]; let arguments = (ins - F32Tensor:$predictions, + TF_Float32Tensor:$predictions, TF_I32OrI64Tensor:$targets, TF_I32OrI64Tensor:$k ); let results = (outs - I1Tensor:$precision + TF_BoolTensor:$precision ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; @@ -6912,7 +6912,7 @@ retained with length 1. ]; } -def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect, TF_FoldOperandsTransposeInterface]> { +def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect, TF_FoldOperandsTransposeInterface, TF_LayoutSensitiveInterface]> { let summary = "Performs max pooling on the input."; let arguments = (ins @@ -6936,6 +6936,9 @@ def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect, TF_FoldOperandsTransposeInter SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; } SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; } LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation); + // TF_LayoutSensitiveInterface: + StringRef GetOptimalLayout(const RuntimeDevices& devices); + LogicalResult UpdateDataFormat(StringRef data_format); }]; } @@ -7852,8 +7855,8 @@ def TF_NotEqualOp : TF_Op<"NotEqual", [Commutative, NoSideEffect]> { }]; let arguments = (ins - TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x, - TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$y, + TF_Tensor:$x, + TF_Tensor:$y, DefaultValuedAttr<BoolAttr, "true">:$incompatible_shape_error ); @@ -8031,7 +8034,7 @@ times by rerunning "MakeIterator". 
); } -def TF_OnesLikeOp : TF_Op<"OnesLike", [NoSideEffect, SameOperandsAndResultType]> { +def TF_OnesLikeOp : TF_Op<"OnesLike", [Idempotent, NoSideEffect, SameOperandsAndResultType]> { let summary = "Returns a tensor of ones with the same shape and type as x."; let arguments = (ins @@ -8429,6 +8432,10 @@ def TF_QrOp : TF_Op<"Qr", [NoSideEffect]> { Computes the QR decomposition of each inner matrix in `tensor` such that `tensor[..., :, :] = q[..., :, :] * r[..., :,:])` +Currently, the gradient for the QR decomposition is well-defined only when +the first `P` columns of the inner matrix are linearly independent, where +`P` is the minimum of `M` and `N`, the 2 inner-most dimensions of `tensor`. + ```python # a is a tensor. # q is a tensor of orthonormal matrices. @@ -9114,7 +9121,7 @@ most one RecvTPUEmbeddingActivations op in the TPU graph. TF_DerivedResultSizeAttr num_outputs = TF_DerivedResultSizeAttr<0>; } -def TF_ReluOp : TF_Op<"Relu", [NoSideEffect, SameOperandsAndResultType, TF_ContractionFusableInterface, TF_LayoutAgnostic]> { +def TF_ReluOp : TF_Op<"Relu", [Idempotent, NoSideEffect, SameOperandsAndResultType, TF_ContractionFusableInterface, TF_LayoutAgnostic]> { let summary = "Computes rectified linear: `max(features, 0)`."; let description = [{ @@ -9140,7 +9147,7 @@ array([ 0., 0., -0., 3.], dtype=float32) }]; } -def TF_Relu6Op : TF_Op<"Relu6", [NoSideEffect, SameOperandsAndResultType]> { +def TF_Relu6Op : TF_Op<"Relu6", [Idempotent, NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes rectified linear 6: `min(max(features, 0), 6)`."; let arguments = (ins @@ -10538,7 +10545,7 @@ bitwise_ops.right_shift(lhs, rhs) TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_RintOp : TF_Op<"Rint", [NoSideEffect, SameOperandsAndResultType]> { +def TF_RintOp : TF_Op<"Rint", [Idempotent, NoSideEffect, SameOperandsAndResultType]> { let summary = "Returns element-wise integer closest to x."; let description = [{ @@ -10605,7 +10612,7 @@ roll(t, shift=[2, -3], axis=[1, 1]) ==> [[1, 2, 3, 4, 0], [6, 7, 8, 9, 5]] TF_DerivedOperandTypeAttr Taxis = TF_DerivedOperandTypeAttr<2>; } -def TF_RoundOp : TF_Op<"Round", [NoSideEffect, SameOperandsAndResultType]> { +def TF_RoundOp : TF_Op<"Round", [Idempotent, NoSideEffect, SameOperandsAndResultType]> { let summary = [{ Rounds the values of a tensor to the nearest integer, element-wise. }]; @@ -11335,7 +11342,7 @@ Specifically, `grad = dy * y * (1 - y)`, where `y = sigmoid(x)`, and TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_SignOp : TF_Op<"Sign", [NoSideEffect, SameOperandsAndResultType]> { +def TF_SignOp : TF_Op<"Sign", [Idempotent, NoSideEffect, SameOperandsAndResultType]> { let summary = "Returns an element-wise indication of the sign of a number."; let description = [{ @@ -12345,14 +12352,14 @@ The outputs are a deterministic function of `shape`, `seed`, and `alpha`. } def TF_StatelessRandomGetAlgOp : TF_Op<"StatelessRandomGetAlg", []> { - let summary = [{ -Picks the best counter-based RNG algorithm based on device. - }]; + let summary = "Picks the best counter-based RNG algorithm based on device."; let description = [{ This op picks the best counter-based RNG algorithm based on device. }]; + let arguments = (ins); + let results = (outs TF_Int32Tensor:$alg ); @@ -14088,73 +14095,35 @@ This operation is very similar to `tf.scatter_nd`, except that the updates are scattered onto an existing tensor (as opposed to a zero-tensor).
If the memory for the existing tensor cannot be re-used, a copy is made and updated. -If `indices` contains duplicates, then their updates are accumulated (summed). +If `indices` contains duplicates, then we pick the last update for the index. -**WARNING**: The order in which updates are applied is nondeterministic, so the -output will be nondeterministic if `indices` contains duplicates -- because -of some numerical approximation issues, numbers summed in different order -may yield different results. +If an out of bound index is found on CPU, an error is returned. + +**WARNING**: There are some GPU specific semantics for this operation. +- If an out of bound index is found, the index is ignored. +- The order in which updates are applied is nondeterministic, so the output +will be nondeterministic if `indices` contains duplicates. `indices` is an integer tensor containing indices into a new tensor of shape -`shape`. The last dimension of `indices` can be at most the rank of `shape`: +`shape`. - indices.shape[-1] <= shape.rank +* `indices` must have at least 2 axes: `(num_updates, index_depth)`. +* The last axis of `indices` is how deep to index into `tensor` so this index + depth must be less than the rank of `tensor`: `indices.shape[-1] <= tensor.ndim` -The last dimension of `indices` corresponds to indices into elements -(if `indices.shape[-1] = shape.rank`) or slices -(if `indices.shape[-1] < shape.rank`) along dimension `indices.shape[-1]` of -`shape`. `updates` is a tensor with shape +if `indices.shape[-1] = tensor.rank` this Op indexes and updates scalar elements. +if `indices.shape[-1] < tensor.rank` it indexes and updates slices of the input +`tensor`. - indices.shape[:-1] + shape[indices.shape[-1]:] +Each `update` has a rank of `tensor.rank - indices.shape[-1]`. +The overall shape of `updates` is: -The simplest form of scatter is to insert individual elements in a tensor by -index. For example, say we want to insert 4 scattered elements in a rank-1 -tensor with 8 elements. +``` +indices.shape[:-1] + tensor.shape[indices.shape[-1]:] +``` -<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"> -<img style="width:100%" src="https://www.tensorflow.org/images/ScatterNd1.png" alt> -</div> - -In Python, this scatter operation would look like this: - - >>> indices = tf.constant([[4], [3], [1], [7]]) - >>> updates = tf.constant([9, 10, 11, 12]) - >>> tensor = tf.ones([8], dtype=tf.int32) - >>> print(tf.tensor_scatter_nd_update(tensor, indices, updates)) - tf.Tensor([ 1 11 1 10 9 1 1 12], shape=(8,), dtype=int32) - -We can also, insert entire slices of a higher rank tensor all at once. For -example, if we wanted to insert two slices in the first dimension of a -rank-3 tensor with two matrices of new values. - -In Python, this scatter operation would look like this: - - >>> indices = tf.constant([[0], [2]]) - >>> updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6], - ... [7, 7, 7, 7], [8, 8, 8, 8]], - ... [[5, 5, 5, 5], [6, 6, 6, 6], - ... [7, 7, 7, 7], [8, 8, 8, 8]]]) - >>> tensor = tf.ones([4, 4, 4], dtype=tf.int32) - >>> print(tf.tensor_scatter_nd_update(tensor, indices, updates).numpy()) - [[[5 5 5 5] - [6 6 6 6] - [7 7 7 7] - [8 8 8 8]] - [[1 1 1 1] - [1 1 1 1] - [1 1 1 1] - [1 1 1 1]] - [[5 5 5 5] - [6 6 6 6] - [7 7 7 7] - [8 8 8 8]] - [[1 1 1 1] - [1 1 1 1] - [1 1 1 1] - [1 1 1 1]]] - -Note that on CPU, if an out of bound index is found, an error is returned. -On GPU, if an out of bound index is found, the index is ignored. 
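As a concrete illustration of the update-shape rule described above, here is a minimal Python sketch using the public `tf.tensor_scatter_nd_update` API on a rank-1 tensor (the values mirror the worked example shown above):

```python
import tensorflow as tf

# `indices` has shape (num_updates, index_depth) = (4, 1); index_depth <= tensor rank.
indices = tf.constant([[4], [3], [1], [7]])
# `updates` has shape indices.shape[:-1] + tensor.shape[indices.shape[-1]:] = (4,),
# i.e. one scalar update per index, because index_depth equals the tensor rank.
updates = tf.constant([9, 10, 11, 12])
tensor = tf.ones([8], dtype=tf.int32)
print(tf.tensor_scatter_nd_update(tensor, indices, updates))
# tf.Tensor([ 1 11  1 10  9  1  1 12], shape=(8,), dtype=int32)
```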
+For usage examples see the python [tf.tensor_scatter_nd_update]( +https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update) function }]; let arguments = (ins @@ -15080,7 +15049,7 @@ https://www.tensorflow.org/xla/operation_semantics#gather }]; let arguments = (ins - Arg<TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>, [{The array we're gathering from.}]>:$operand, + Arg<TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>, [{The array we're gathering from.}]>:$operand, Arg<TF_I32OrI64Tensor, [{Array containing the starting indices of the slices we gather.}]>:$start_indices, Arg<TF_I32OrI64Tensor, [{slice_sizes[i] is the bounds for the slice on dimension i.}]>:$slice_sizes, @@ -15089,7 +15058,7 @@ https://www.tensorflow.org/xla/operation_semantics#gather ); let results = (outs - TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; @@ -15457,7 +15426,7 @@ def TF_XlogyOp : TF_Op<"Xlogy", [NoSideEffect, ResultsBroadcastableShape, TF_Sam TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_ZerosLikeOp : TF_Op<"ZerosLike", [NoSideEffect, SameOperandsAndResultType]> { +def TF_ZerosLikeOp : TF_Op<"ZerosLike", [Idempotent, NoSideEffect, SameOperandsAndResultType]> { let summary = "Returns a tensor of zeros with the same shape and type as x."; let arguments = (ins diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td index 817dfc4cfe2..20a3a9d8a7e 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td @@ -684,12 +684,23 @@ body: A function that takes a list of tensors and returns another FlatSymbolRefAttr:$cond, FlatSymbolRefAttr:$body, - DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes, DefaultValuedAttr<I64Attr, "10">:$parallel_iterations, // Used to map StatelessWhile and While op defined in TensorFlow to a common // op. - BoolAttr:$is_stateless + BoolAttr:$is_stateless, + + // In TensorFlow, While has a special behavior where if `output_shapes` + // attribute is not empty, those shapes are used in its shape function + // as result shapes instead of propagating operand shapes as result shapes. + // This allows for different result shapes from operand shapes. While these + // shapes are imported and set as a part of the result type, there is no + // indicator differentiating between having no output shapes compared to + // having all unranked shapes. Thus this attribute is set to determine + // which shape function behavior to use for this op, specifically + // propagating operand shapes as result shapes when this attribute is not + // set, or preserving result shapes as is when this attribute is set. 
+ UnitAttr:$shape_invariant ); let results = (outs @@ -697,6 +708,7 @@ body: A function that takes a list of tensors and returns another ); TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>; + TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>; let verifier = [{ return Verify(*this); @@ -752,12 +764,23 @@ def TF_WhileRegionOp : TF_Op<"WhileRegion", let arguments = (ins Variadic<AnyTensor>:$input, - DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes, DefaultValuedAttr<I64Attr, "10">:$parallel_iterations, // Used to map StatelessWhile and While op defined in TensorFlow to a common // op. - BoolAttr:$is_stateless + BoolAttr:$is_stateless, + + // In TensorFlow, While has a special behavior where if `output_shapes` + // attribute is not empty, those shapes are used in its shape function + // as result shapes instead of propagating operand shapes as result shapes. + // This allows for different result shapes from operand shapes. While these + // shapes are imported and set as a part of the result type, there is no + // indicator differentiating between having no output shapes compared to + // having all unranked shapes. Thus this attribute is set to determine + // which shape function behavior to use for this op, specifically + // propagating operand shapes as result shapes when this attribute is not + // set, or preserving result shapes as is when this attribute is set. + UnitAttr:$shape_invariant ); let results = (outs Variadic<AnyTensor>:$output); diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc index 9dde6c5769e..b48338e3ced 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc @@ -2467,6 +2467,33 @@ LogicalResult MaxPoolOp::FoldOperandsPermutation( permutation, this, {{"strides", strides()}, {"ksize", ksize()}}); } +LogicalResult MaxPoolOp::UpdateDataFormat(StringRef new_data_format) { + StringRef src_data_format = data_format(); + + auto perm = GetDataFormatPermutation(src_data_format, new_data_format); + if (perm.empty()) return failure(); + + // Update data_format attribute and result types. + if (failed(::mlir::TF::UpdateDataFormat(new_data_format, this))) + return failure(); + + stridesAttr(ShuffleArrayAttr(strides(), perm)); + explicit_paddingsAttr(ShuffleArrayAttr(explicit_paddings(), perm, 2)); + ksizeAttr(ShuffleArrayAttr(ksize(), perm)); + + return success(); +} + +StringRef MaxPoolOp::GetOptimalLayout(const RuntimeDevices &devices) { + // Keep current data format if no GPUs are available or if explicit placement + // does not allow to use GPU for this operation. + if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation())) + return data_format(); + + // Defaults to NCHW. 
+ return "NCHW"; +} + //===----------------------------------------------------------------------===// // MaxPoolGradOp //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/tests/fold-broadcast.mlir b/tensorflow/compiler/mlir/tensorflow/tests/fold-broadcast.mlir index afc9e1e51ed..b4738ed5605 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/fold-broadcast.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/fold-broadcast.mlir @@ -41,3 +41,34 @@ func @broadcast_mul_implicit_no_fold(%arg0: tensor<5x7xf32>, %arg1: tensor<5xf32 // CHECK: %[[V1:.*]] = "tf.Mul"(%arg0, %[[V0]]) : (tensor<5x7xf32>, tensor<3x5x7xf32>) -> tensor<3x5x7xf32> // CHECK: %[[V1]] : tensor<3x5x7xf32> } + +// CHECK-LABEL: @broadcast_eq +func @broadcast_eq(%arg0: tensor<5x7xf32>, %arg1: tensor<7xf32>) -> tensor<5x7xi1> { + %cst = constant dense<[5, 7]> : tensor<2xi32> + %0 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<7xf32>, tensor<2xi32>) -> tensor<5x7xf32> + %1 = "tf.Equal"(%arg0, %0) {incompatible_shape_error = true} : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<5x7xi1> + return %1 : tensor<5x7xi1> + // CHECK: %[[V0:.*]] = "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = true} : (tensor<5x7xf32>, tensor<7xf32>) -> tensor<5x7xi1> + // CHECK: %[[V0]] : tensor<5x7xi1> +} + +// CHECK-LABEL: @broadcast_neq +func @broadcast_neq(%arg0: tensor<5x7xf32>, %arg1: tensor<7xf32>) -> tensor<5x7xi1> { + %cst = constant dense<[5, 7]> : tensor<2xi32> + %0 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<7xf32>, tensor<2xi32>) -> tensor<5x7xf32> + %1 = "tf.NotEqual"(%arg0, %0) {incompatible_shape_error = true} : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<5x7xi1> + return %1 : tensor<5x7xi1> + // CHECK: %[[V0:.*]] = "tf.NotEqual"(%arg0, %arg1) {incompatible_shape_error = true} : (tensor<5x7xf32>, tensor<7xf32>) -> tensor<5x7xi1> + // CHECK: %[[V0]] : tensor<5x7xi1> +} + +// CHECK-LABEL: @broadcast_both_operand +func @broadcast_both_operand(%arg0: tensor<7xf32>, %arg1: tensor<5x1xf32>) -> tensor<5x7xf32> { + %cst = constant dense<[5, 7]> : tensor<2xi64> + %0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<7xf32>, tensor<2xi64>) -> tensor<5x7xf32> + %1 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<5x1xf32>, tensor<2xi64>) -> tensor<5x7xf32> + %2 = "tf.Add"(%0, %1) : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<5x7xf32> + return %2 : tensor<5x7xf32> + // CHECK: %[[V0:.*]] = "tf.Add"(%arg0, %arg1) : (tensor<7xf32>, tensor<5x1xf32>) -> tensor<5x7xf32> + // CHECK: %[[V0]] : tensor<5x7xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-while-ops.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-while-ops.pbtxt index dbfc42a11ae..890ca03b4a7 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-while-ops.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-while-ops.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=iter,val -tf-input-data-types=DT_INT32,DT_FLOAT -tf-input-shapes=':' -tf-output-arrays=StatefulWhile:1,StatelessWhile:1 -o - -mlir-print-debuginfo -mlir-print-local-scope | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=iter,val -tf-input-data-types=DT_INT32,DT_FLOAT -tf-input-shapes=':' -tf-output-arrays=StatefulWhile:1,StatelessWhile:1,WhileWithOutputShapes:1 -o - -mlir-print-debuginfo 
-mlir-print-local-scope | FileCheck %s # Verify that TensorFlow While and StatelessWhile ops are mapped to the # composite While op in MLIR with is_stateless attribute set accordingly to @@ -6,6 +6,7 @@ # CHECK-DAG: "tf.While"{{.*}} is_stateless = false{{.*}} loc("StatefulWhile") # CHECK-DAG: "tf.While"{{.*}} is_stateless = true{{.*}} loc("StatelessWhile") +# CHECK-DAG: "tf.While"{{.*}} is_stateless = false{{.*}} shape_invariant{{.*}} -> (tensor<i32>, tensor<*xf32>) loc("WhileWithOutputShapes") node { name: "StatefulWhile" @@ -73,6 +74,51 @@ node { experimental_debug_info { } } +node { + name: "WhileWithOutputShapes" + op: "While" + input: "iter" + input: "val" + attr { + key: "T" + value { + list { + type: DT_INT32 + type: DT_FLOAT + } + } + } + attr { + key: "body" + value { + func { + name: "body" + } + } + } + attr { + key: "cond" + value { + func { + name: "cond" + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + shape { + unknown_rank: true + } + } + } + } + experimental_debug_info { + } +} node { name: "main" op: "_Retval" @@ -107,6 +153,23 @@ node { } } } +node { + name: "main2" + op: "_Retval" + input: "WhileWithOutputShapes:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "index" + value { + i: 2 + } + } +} node { name: "iter" op: "Placeholder" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir index 0034d3f4308..342804f1230 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir @@ -82,3 +82,20 @@ func @bias_add_nchw(%arg0: tensor<1x256x150x150xf32>, %arg1: tensor<256xf32>) -> %0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NCHW", device = ""} : (tensor<1x256x150x150xf32>, tensor<256xf32>) -> tensor<1x256x150x150xf32> return %0 : tensor<1x256x150x150xf32> } + +// CHECK-LABEL: maxpool_nchw +func @maxpool_nchw(%arg0: tensor<1x64x112x112xf32>) -> tensor<1x64x56x56xf32> { + // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} + // CHECK: %[[R0:.*]] = "tf.Transpose"(%arg0, %[[CST]]) + // CHECK: %[[R1:.*]] = "tf.MaxPool"(%[[R0]]) {data_format = "NHWC", explicit_paddings = [], ksize = [1, 3, 3, 1], padding = "SAME", strides = [1, 2, 2, 1]} + // CHECK: %[[CST_0:.*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} + // CHECK: "tf.Transpose"(%[[R1]], %[[CST_0]]) + %0 = "tf.MaxPool"(%arg0) + { + data_format = "NCHW", + ksize = [1, 1, 3, 3], + padding = "SAME", + strides = [1, 1, 2, 2] + } : (tensor<1x64x112x112xf32>) -> tensor<1x64x56x56xf32> + return %0 : tensor<1x64x56x56xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir index 4be65bd0a3b..311b84f0374 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir @@ -1703,3 +1703,110 @@ func @convert_iota_3d() -> tensor<5x7x9xi32> { return %0 : tensor<5x7x9xi32> } +// CHECK-LABEL: func @convert_avgpool_valid( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) {data_format = "NHWC", ksize = [1, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 1]} : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> +// CHECK: 
return %[[VAL_1]] : tensor<4x7x7x8xf32> +// CHECK: } +func @convert_avgpool_valid(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> { + %0 = mhlo.constant dense<0.0> : tensor<f32> + %1 = mhlo.constant dense<9.0> : tensor<4x7x7x8xf32> + %2 = "mhlo.reduce_window"(%arg0, %0) ( { + ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): + %5 = mhlo.add %arg1, %arg2 : tensor<f32> + "mhlo.return"(%5) : (tensor<f32>) -> () + }) { + base_dilations = dense<1> : tensor<4xi64>, + padding = dense<0> : tensor<4x2xi64>, + window_dilations = dense<1> : tensor<4xi64>, + window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x8xf32> + %3 = mhlo.divide %2, %1 : tensor<4x7x7x8xf32> + return %3 : tensor<4x7x7x8xf32> +} + +// CHECK-LABEL: func @convert_avgpool_valid_rw( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) {data_format = "NHWC", ksize = [1, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 1]} : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> +// CHECK: return %[[VAL_1]] : tensor<4x7x7x8xf32> +// CHECK: } +func @convert_avgpool_valid_rw(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> { + %0 = mhlo.constant dense<1.0> : tensor<4x16x16x8xf32> + %1 = mhlo.constant dense<0.0> : tensor<f32> + %2 = "mhlo.reduce_window"(%arg0, %1) ( { + ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): + %6 = mhlo.add %arg1, %arg2 : tensor<f32> + "mhlo.return"(%6) : (tensor<f32>) -> () + }) { + base_dilations = dense<1> : tensor<4xi64>, + padding = dense<[[0, 0], [0, 0], [0, 0], [0, 0]]> : tensor<4x2xi64>, + window_dilations = dense<1> : tensor<4xi64>, + window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x8xf32> + %3 = "mhlo.reduce_window"(%0, %1) ( { + ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): + %6 = mhlo.add %arg1, %arg2 : tensor<f32> + "mhlo.return"(%6) : (tensor<f32>) -> () + }) { + base_dilations = dense<1> : tensor<4xi64>, + padding = dense<[[0, 0], [0, 0], [0, 0], [0, 0]]> : tensor<4x2xi64>, + window_dilations = dense<1> : tensor<4xi64>, + window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x8xf32> + %4 = mhlo.divide %2, %3 : tensor<4x7x7x8xf32> + return %4 : tensor<4x7x7x8xf32> +} + +// CHECK-LABEL: func @convert_avgpool_valid_3d( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x16x16x16x8xf32>) -> tensor<4x7x7x7x8xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.AvgPool3D"(%[[VAL_0]]) {data_format = "NDHWC", ksize = [1, 3, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 2, 1]} : (tensor<4x16x16x16x8xf32>) -> tensor<4x7x7x7x8xf32> +// CHECK: return %[[VAL_1]] : tensor<4x7x7x7x8xf32> +// CHECK: } +func @convert_avgpool_valid_3d(%arg0: tensor<4x16x16x16x8xf32>) -> tensor<4x7x7x7x8xf32> { + %0 = mhlo.constant dense<0.0> : tensor<f32> + %1 = mhlo.constant dense<27.0> : tensor<4x7x7x7x8xf32> + %2 = "mhlo.reduce_window"(%arg0, %0) ( { + ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): + %5 = mhlo.add %arg1, %arg2 : tensor<f32> + "mhlo.return"(%5) : (tensor<f32>) -> () + }) { + base_dilations = dense<1> : tensor<5xi64>, + padding = dense<0> : tensor<5x2xi64>, + window_dilations = dense<1> : tensor<5xi64>, + window_dimensions = dense<[1, 3, 3, 3, 1]> : tensor<5xi64>, + window_strides = dense<[1, 2, 2, 2, 1]> : tensor<5xi64>} : 
(tensor<4x16x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x7x8xf32> + %3 = mhlo.divide %2, %1 : tensor<4x7x7x7x8xf32> + return %3 : tensor<4x7x7x7x8xf32> +} + +// CHECK-LABEL: func @convert_avgpool_same( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) {data_format = "NHWC", ksize = [1, 3, 3, 1], padding = "SAME", strides = [1, 2, 2, 1]} : (tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> +// CHECK: return %[[VAL_1]] : tensor<4x8x8x8xf32> +// CHECK: } +func @convert_avgpool_same(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> { + %0 = mhlo.constant dense<1.0> : tensor<4x16x16x8xf32> + %1 = mhlo.constant dense<0.0> : tensor<f32> + %2 = "mhlo.reduce_window"(%arg0, %1) ( { + ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): + %6 = mhlo.add %arg1, %arg2 : tensor<f32> + "mhlo.return"(%6) : (tensor<f32>) -> () + }) { + base_dilations = dense<1> : tensor<4xi64>, + padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>, + window_dilations = dense<1> : tensor<4xi64>, + window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x8x8x8xf32> + %3 = "mhlo.reduce_window"(%0, %1) ( { + ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): + %6 = mhlo.add %arg1, %arg2 : tensor<f32> + "mhlo.return"(%6) : (tensor<f32>) -> () + }) { + base_dilations = dense<1> : tensor<4xi64>, + padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>, + window_dilations = dense<1> : tensor<4xi64>, + window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x8x8x8xf32> + %4 = mhlo.divide %2, %3 : tensor<4x8x8x8xf32> + return %4 : tensor<4x8x8x8xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir index dc2a6e8bd05..52937fc0a5b 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir @@ -244,9 +244,9 @@ func @fourdim_space_to_batch_nd(%input: tensor<3x5x7x10xf32>, %block_shape: tens // CHECK-DAG: [[PAD_DEFAULT:%.+]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} // CHECK-DAG: [[PADDED:%.+]] = "tf.PadV2"(%arg0, [[FULL_PADDINGS]], [[PAD_DEFAULT]]) // CHECK-DAG: [[PADDINGS:%.+]]:2 = "tf.Unpack"([[FULL_PADDINGS]]) {axis = 1 : i64} - // CHECK-DAG: [[PADDINGS_SUM:%.+]] = "tf.Add"([[PADDINGS]]#0, [[PADDINGS]]#1) + // CHECK-DAG: [[PADDINGS_SUM:%.+]] = "tf.AddV2"([[PADDINGS]]#0, [[PADDINGS]]#1) // CHECK-DAG: [[INPUT_SHAPE:%.+]] = "tf.Const"() {value = dense<[3, 5, 7, 10]> : tensor<4xi64>} - // CHECK-DAG: [[PADDED_SHAPE:%.+]] = "tf.Add"([[PADDINGS_SUM]], [[INPUT_SHAPE]]) + // CHECK-DAG: [[PADDED_SHAPE:%.+]] = "tf.AddV2"([[PADDINGS_SUM]], [[INPUT_SHAPE]]) // CHECK-DAG: [[PADDED_SHAPE_SPLITS:%.+]]:4 = "tf.Split"([[ZERO_I32]], [[PADDED_SHAPE]]) // CHECK-DAG: [[BLOCK_SHAPE_SPLITS:%.+]]:2 = "tf.Split"([[ZERO_I32]], %arg1) // CHECK-DAG: [[OUTER_SHAPE_0:%.+]] = "tf.Div"([[PADDED_SHAPE_SPLITS]]#1, [[BLOCK_SHAPE_SPLITS]]#0) @@ -338,10 +338,10 @@ func @fake_quant_with_min_max_args(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> { // CHECK-DAG: [[VAL5:%.+]] = "tf.ClipByValue"(%arg0, [[VAL2]], [[VAL1]]) // CHECK-DAG: [[VAL6:%.+]] = "tf.Sub"([[VAL5]], [[VAL2]]) // CHECK-DAG: [[VAL7:%.+]] = "tf.Mul"([[VAL6]], [[VAL0]]) - // CHECK-DAG: [[VAL8:%.+]] = "tf.Add"([[VAL7]], [[VAL4]]) + // 
CHECK-DAG: [[VAL8:%.+]] = "tf.AddV2"([[VAL7]], [[VAL4]]) // CHECK-DAG: [[VAL9:%.+]] = "tf.Floor"([[VAL8]]) // CHECK-DAG: [[VAL10:%.+]] = "tf.Mul"([[VAL9]], [[VAL3]]) - // CHECK-DAG: [[VAL11:%.+]] = "tf.Add"([[VAL10]], [[VAL2]]) + // CHECK-DAG: [[VAL11:%.+]] = "tf.AddV2"([[VAL10]], [[VAL2]]) %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {max = 1.0 : f32, min = -1.0 : f32, narrow_range = false, num_bits = 8 : i64} : (tensor<?x?xf32>) -> tensor<?x?xf32> // CHECK: return [[VAL11]] @@ -361,7 +361,7 @@ func @fake_quant_with_min_max_vars(%arg0 : tensor<?x?xf32>, %arg1 : tensor<f32>, // CHECK-DAG: [[VAL9:%.+]] = "tf.Floor"([[VAL8]]) // CHECK-DAG: [[VAL10:%.+]] = "tf.Sub"([[VAL8]], [[VAL9]]) // CHECK-DAG: [[VAL11:%.+]] = "tf.Less"([[VAL10]], [[VAL3]]) - // CHECK-DAG: [[VAL12:%.+]] = "tf.Add"([[VAL2]], [[VAL9]]) + // CHECK-DAG: [[VAL12:%.+]] = "tf.AddV2"([[VAL9]], [[VAL2]]) // CHECK-DAG: [[VAL13:%.+]] = "tf.Select"([[VAL11]], [[VAL9]], [[VAL12]]) // CHECK-DAG: [[VAL14:%.+]] = "tf.ClipByValue"([[VAL13]], [[VAL0]], [[VAL1]]) : // CHECK-DAG: [[VAL15:%.+]] = "tf.Sub"([[VAL0]], [[VAL14]]) @@ -371,10 +371,10 @@ func @fake_quant_with_min_max_vars(%arg0 : tensor<?x?xf32>, %arg1 : tensor<f32>, // CHECK-DAG: [[VAL19:%.+]] = "tf.ClipByValue"(%arg0, [[VAL17]], [[VAL18]]) // CHECK-DAG: [[VAL20:%.+]] = "tf.Sub"([[VAL19]], [[VAL17]]) // CHECK-DAG: [[VAL21:%.+]] = "tf.Mul"([[VAL20]], [[VAL6]]) - // CHECK-DAG: [[VAL22:%.+]] = "tf.Add"([[VAL21]], [[VAL3]]) + // CHECK-DAG: [[VAL22:%.+]] = "tf.AddV2"([[VAL21]], [[VAL3]]) // CHECK-DAG: [[VAL23:%.+]] = "tf.Floor"([[VAL22]]) // CHECK-DAG: [[VAL24:%.+]] = "tf.Mul"([[VAL23]], [[VAL5]]) - // CHECK-DAG: [[VAL25:%.+]] = "tf.Add"([[VAL24]], [[VAL17]]) + // CHECK-DAG: [[VAL25:%.+]] = "tf.AddV2"([[VAL24]], [[VAL17]]) %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {narrow_range = false, num_bits = 8 : i64} : (tensor<?x?xf32>, tensor<f32>, tensor<f32>) -> tensor<?x?xf32> // CHECK: return [[VAL25]] @@ -746,7 +746,7 @@ func @round(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK-DAG: [[HALF:%.+]] = "tf.Const"() {value = dense<5.000000e-01> : tensor<f32>} // CHECK-DAG: [[CMP:%.+]] = "tf.Less"([[SUB]], [[HALF]]) // CHECK-DAG: [[ONE:%.+]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} - // CHECK-DAG: [[ADD:%.+]] = "tf.Add"([[ONE]], [[FLOOR]]) + // CHECK-DAG: [[ADD:%.+]] = "tf.AddV2"([[FLOOR]], [[ONE]]) // CHECK-DAG: [[SELECT:%.+]] = "tf.Select"([[CMP]], [[FLOOR]], [[ADD]]) %0 = "tf.Round"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> @@ -761,7 +761,7 @@ func @round_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { // CHECK-DAG: [[HALF:%.+]] = "tf.Const"() {value = dense<5.000000e-01> : tensor<f32>} // CHECK-DAG: [[CMP:%.+]] = "tf.Less"([[SUB]], [[HALF]]) // CHECK-DAG: [[ONE:%.+]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} - // CHECK-DAG: [[ADD:%.+]] = "tf.Add"([[ONE]], [[FLOOR]]) + // CHECK-DAG: [[ADD:%.+]] = "tf.AddV2"([[FLOOR]], [[ONE]]) // CHECK-DAG: [[SELECT:%.+]] = "tf.Select"([[CMP]], [[FLOOR]], [[ADD]]) %0 = "tf.Round"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-while-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-while-ops.mlir index 4bb324d0d85..8f5066d4162 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-while-ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-while-ops.mlir @@ -1,12 +1,13 @@ // RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s -func @main(%arg0: 
tensor<i32>, %arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) { - %0:2 = tf_executor.graph { - %outputs_2:2, %control_3 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = false, output_shapes = [#tf.shape<>, #tf.shape<5>]} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatefulWhile") - %outputs_4:2, %control_5 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = true, output_shapes = [#tf.shape<>, #tf.shape<5>]} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatelessWhile") - tf_executor.fetch %outputs_2#1, %outputs_4#1 : tensor<5xf32>, tensor<5xf32> +func @main(%arg0: tensor<i32>, %arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>, tensor<5xf32>) { + %0:3 = tf_executor.graph { + %outputs_2:2, %control_3 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = false} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatefulWhile") + %outputs_4:2, %control_5 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = true} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatelessWhile") + %outputs_6:2, %control_7 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = false, shape_invariant} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("WhileWithOutputShapes") + tf_executor.fetch %outputs_2#1, %outputs_4#1, %outputs_6#1 : tensor<5xf32>, tensor<5xf32>, tensor<5xf32> } - return %0#0, %0#1 : tensor<5xf32>, tensor<5xf32> + return %0#0, %0#1, %0#2 : tensor<5xf32>, tensor<5xf32>, tensor<5xf32> } func @cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor<i1> { @@ -36,6 +37,7 @@ func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor // CHECK-NOT: name: // CHECK: op: "While" // CHECK-NOT: is_stateless +// CHECK-NOT: shape_invariant // CHECK: attr { // CHECK: key: "output_shapes" // CHECK: value { @@ -54,6 +56,7 @@ func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor // CHECK-NOT: name: // CHECK: op: "StatelessWhile" // CHECK-NOT: is_stateless +// CHECK-NOT: shape_invariant // CHECK: attr { // CHECK: key: "output_shapes" // CHECK: value { @@ -67,3 +70,20 @@ func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor // CHECK: } // CHECK: } +// CHECK: name: "WhileWithOutputShapes" +// CHECK-NOT: name: +// CHECK: op: "While" +// CHECK-NOT: is_stateless +// CHECK-NOT: shape_invariant +// CHECK: attr { +// CHECK: key: "output_shapes" +// CHECK: value { +// CHECK: list { +// CHECK: shape { +// CHECK: dim { +// CHECK: size: 5 +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/outside_compiled_to_host_launch.mlir b/tensorflow/compiler/mlir/tensorflow/tests/outside_compiled_to_host_launch.mlir new file mode 100644 index 00000000000..44aee930fc3 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/outside_compiled_to_host_launch.mlir @@ -0,0 +1,153 @@ +// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-outside-compiled-to-host-launch | FILECHECK_OPTS="" FileCheck %s + +// expected-error@+1 {{'module' op bad 'tf.devices' attribute at index 0, not a string}} +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = [1]} { + // Tests that missing `_xla_outside_compilation` attribute value results in an error. 
+ func @invalid_device_attribute() -> tensor<?xi32> { + %0 = "tf_device.cluster"() ( { + %1 = "tf.A"() : () -> tensor<?xi32> + %2 = "tf.B"(%1) {_xla_outside_compilation = ""}: (tensor<?xi32>) -> tensor<?xi32> + tf_device.return %2 : tensor<?xi32> + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32> + return %0 : tensor<?xi32> + } +} + +// ----- + +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { + // Tests that missing `_xla_outside_compilation` attribute value results in an error. + func @empty_outside_compilation_attribute() -> tensor<?xi32> { + %0 = "tf_device.cluster"() ( { + %1 = "tf.A"() : () -> tensor<?xi32> + // expected-error@+1 {{'tf.B' op requires non empty '_xla_outside_compilation' string attribute}} + %2 = "tf.B"(%1) {_xla_outside_compilation = ""}: (tensor<?xi32>) -> tensor<?xi32> + tf_device.return %2 : tensor<?xi32> + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32> + return %0 : tensor<?xi32> + } +} + +// ----- + +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { + + // Tests that TPU cluster with no outside compilation does not generate launch op. + + // CHECK-LABEL: func @no_outside_compilation + // CHECK-NOT: "tf_device.launch" + func @no_outside_compilation() -> tensor<?xi32> { + %0 = "tf_device.cluster"() ( { + %1 = "tf.A"() : () -> tensor<?xi32> + %2 = "tf.B"(%1) : (tensor<?xi32>) -> tensor<?xi32> + tf_device.return %2 : tensor<?xi32> + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32> + return %0 : tensor<?xi32> + } + + + // Tests the launch wrap of a single outside compiled cluster with no input or output dependencies. + + // CHECK-LABEL: func @nodep_single_outside_compilation + func @nodep_single_outside_compilation() -> () { + // CHECK: "tf.A" + // CHECK: "tf_device.launch" + // CHECK-NEXT: "tf.B" + // CHECK-NOT: _xla_outside_compilation + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" + // CHECK: device_assignment = [], num_cores_per_replica = 1 : i64, topology = "" + "tf_device.cluster"() ( { + "tf.A"() : () -> () + "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> () + "tf.C"() : () -> () + tf_device.return + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> () + return + } + + // Tests the launch wrap of a single outside compiled cluster with data parallelism. 
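+  // The launch is placed on the replicated host alias "TPU_REPLICATED_HOST"
+  // rather than a concrete CPU device.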
+ + // CHECK-LABEL: func @single_outside_compilation_with_replicate + func @single_outside_compilation_with_replicate(%arg0: tensor<?xi32>) -> () { + // CHECK: "tf.A" + // CHECK: tf_device.replicate + // CHECK-NEXT: "tf_device.cluster" + // CHECK-NEXT: "tf.B" + // CHECK-NEXT: "tf_device.launch" + // CHECK-NEXT: "tf.C" + // CHECK-NOT: _xla_outside_compilation + // CHECK: tf_device.return + // CHECK-NEXT: device = "TPU_REPLICATED_HOST" + // CHECK: device_assignment = [], num_cores_per_replica = 1 : i64, topology = "" + %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32> + tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} { + "tf_device.cluster"() ( { + "tf.B"() : () -> () + "tf.C"(%ri_0) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> () + "tf.D"() : () -> () + tf_device.return + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> () + tf_device.return + } + return + } + + // Tests launch wrap of a single outside compiled cluster with input/output. + + // CHECK-LABEL: func @single_outside_compilation_input_output + func @single_outside_compilation_input_output(%arg0: tensor<?xi32>) -> tensor<?xi32> { + %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32> + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: "tf_device.cluster" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK-NEXT: %[[LAUNCH_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[A_OUTPUT]]) + // CHECK: tf_device.return %[[B_OUTPUT]] + // CHECK: "tf.C"(%[[LAUNCH_OUTPUT]]) + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} { + %2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor<?xi32>) + %4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> tensor<?xi32> + %5 = "tf.C"(%4) : (tensor<?xi32>) -> tensor<?xi32> + tf_device.return %5 : tensor<?xi32> + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32> + tf_device.return %2 : tensor<?xi32> + } + + return %1 : tensor<?xi32> + } + + // Tests launch wrap of multiple outside compiled cluster with input/output. 
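+  // Ops carrying different `_xla_outside_compilation` values ("cluster1" and
+  // "cluster2") are wrapped in separate launch ops.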
+ + // CHECK-LABEL: func @multiple_outside_compilation_input_output + func @multiple_outside_compilation_input_output(%arg0: tensor<?xi32>) -> tensor<?xi32> { + %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32> + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: "tf_device.cluster" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK-NEXT: %[[LAUNCH_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[A_OUTPUT]]) + // CHECK: tf_device.return %[[B_OUTPUT]] + // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[LAUNCH_OUTPUT]]) + // CHECK-NEXT: %[[LAUNCH_OUTPUT2:[0-9]*]] = "tf_device.launch" + // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[C_OUTPUT]]) + // CHECK: %[[E_OUTPUT:[0-9]*]] = "tf.E"(%[[D_OUTPUT]]) + // CHECK: tf_device.return %[[E_OUTPUT]] + // CHECK: "tf.F"(%[[LAUNCH_OUTPUT2]]) + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} { + %2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor<?xi32>) + %4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> tensor<?xi32> + %5 = "tf.C"(%4) : (tensor<?xi32>) -> tensor<?xi32> + %6 = "tf.D"(%5) {_xla_outside_compilation = "cluster2"} : (tensor<?xi32>) -> tensor<?xi32> + %7 = "tf.E"(%6) {_xla_outside_compilation = "cluster2"} : (tensor<?xi32>) -> tensor<?xi32> + %8 = "tf.F"(%7) : (tensor<?xi32>) -> tensor<?xi32> + tf_device.return %8 : tensor<?xi32> + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32> + tf_device.return %2 : tensor<?xi32> + } + + return %1 : tensor<?xi32> + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc b/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc index c2760000e82..3f58514f6e9 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc @@ -22,6 +22,7 @@ limitations under the License. #include "mlir/Dialect/Traits.h" // from @llvm-project #include "mlir/IR/Function.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project @@ -38,6 +39,12 @@ class ConvertResultsBroadcastableShapeOp : public RewritePattern { LogicalResult matchAndRewrite(Operation* op, PatternRewriter& rewriter) const override; + + private: + template <typename Op> + LogicalResult RewriteEqOp(Operation* op, PatternRewriter& rewriter) const; + + LogicalResult RewriteOp(Operation* op, PatternRewriter& rewriter) const; }; class BroadcastFoldPass : public PassWrapper<BroadcastFoldPass, FunctionPass> { @@ -47,7 +54,27 @@ class BroadcastFoldPass : public PassWrapper<BroadcastFoldPass, FunctionPass> { LogicalResult ConvertResultsBroadcastableShapeOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { - if (!op->hasTrait<OpTrait::ResultsBroadcastableShape>()) return failure(); + if (op->hasTrait<OpTrait::ResultsBroadcastableShape>()) + return RewriteOp(op, rewriter); + + // tf.Equal and tf.NotEqual ops only satisfy ResultsBroadcastableShape when + // incompatible_shape_error is `true` (what is also checked by the verifier). 
+ if (succeeded(RewriteEqOp<TF::EqualOp>(op, rewriter))) return success(); + if (succeeded(RewriteEqOp<TF::NotEqualOp>(op, rewriter))) return success(); + + return failure(); +} + +template <typename Op> +LogicalResult ConvertResultsBroadcastableShapeOp::RewriteEqOp( + Operation* op, PatternRewriter& rewriter) const { + auto eq_op = llvm::dyn_cast_or_null<Op>(op); + if (eq_op && eq_op.incompatible_shape_error()) return RewriteOp(op, rewriter); + return failure(); +} + +LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp( + Operation* op, PatternRewriter& rewriter) const { if (op->getNumOperands() != 2 || op->getResultTypes().size() != 1) return failure(); @@ -56,6 +83,7 @@ LogicalResult ConvertResultsBroadcastableShapeOp::matchAndRewrite( op->getResultTypes().front().dyn_cast_or_null<RankedTensorType>(); if (!result_type || !result_type.hasStaticShape()) return failure(); + bool changed = false; for (uint64_t i = 0, e = op->getNumOperands(); i < e; ++i) { // Check that the i'th operand is a broadcast. auto broadcast = llvm::dyn_cast_or_null<TF::BroadcastToOp>( @@ -89,10 +117,9 @@ LogicalResult ConvertResultsBroadcastableShapeOp::matchAndRewrite( // Update the operand of the op to be the operand of the broadcast. rewriter.updateRootInPlace( op, [&]() { op->getOpOperand(i).set(broadcast.input()); }); - return success(); + changed = true; } - - return failure(); + return success(changed); } void BroadcastFoldPass::runOnFunction() { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc index a92d3f367cf..e865597c92e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc @@ -112,8 +112,8 @@ LogicalResult ConvertIfOp(IfOp if_op) { LogicalResult ConvertWhileOp(WhileOp while_op) { auto while_region = OpBuilder(while_op).create<TF::WhileRegionOp>( while_op.getLoc(), while_op.getResultTypes(), while_op.input(), - while_op.output_shapes(), while_op.parallel_iterations(), - while_op.is_stateless()); + while_op.parallel_iterations(), while_op.is_stateless(), + while_op.shape_invariant()); CopyDeviceAndUnderscoredAttributes(while_op, while_region); YieldOp cond_yield = diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc index 4e9f9871964..418f6c0f1c6 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc @@ -21,15 +21,18 @@ limitations under the License. #include <numeric> #include <vector> +#include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project @@ -44,6 +47,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/core/framework/kernel_shape_util.h" +#include "tensorflow/core/lib/math/math_util.h" namespace mlir { namespace TF { @@ -479,18 +483,16 @@ template <typename ReductionOp> LogicalResult MatchBinaryReduceFunction(mlir::Region &function) { Block &body = function.front(); if (body.getNumArguments() != 2) return failure(); - if (body.getOperations().size() != 2) return failure(); - - ReductionOp reduce_op = dyn_cast<ReductionOp>(body.front()); - if (!reduce_op) return failure(); - if (reduce_op.lhs() != body.getArgument(0) || - reduce_op.rhs() != body.getArgument(1)) - return failure(); mhlo::ReturnOp return_op = dyn_cast<mhlo::ReturnOp>(body.back()); if (!return_op) return failure(); - if (return_op.getNumOperands() != 1 || - return_op.results().front() != reduce_op) + if (return_op.getNumOperands() != 1) return failure(); + + ReductionOp reduce_op = dyn_cast_or_null<ReductionOp>( + return_op.getOperands().front().getDefiningOp()); + if (!reduce_op) return failure(); + if (reduce_op.lhs() != body.getArgument(0) || + reduce_op.rhs() != body.getArgument(1)) return failure(); return success(); @@ -654,6 +656,190 @@ class ConvertIotaOpToTfRange : public OpConversionPattern<mhlo::IotaOp> { } }; +// Maps the following represenattions of AvgPool in MHLO into a tf.AvgPool{3D} +// operation when they cleanly map to 2D or 3D average pool with VALID or SAME +// padding: +// * div(reduce_sum_window(x), constant(sizeof(window))) +// * div(reduce_sum_window(x), reduce_sum_window(constant(1))) +class ConvertAvgPoolOp : public OpConversionPattern<mhlo::DivOp> { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::DivOp div_op, ArrayRef<Value> args, + ConversionPatternRewriter &rewriter) const final { + auto rw = + dyn_cast_or_null<mhlo::ReduceWindowOp>(div_op.lhs().getDefiningOp()); + if (!rw) return failure(); + + // Check that the reduce-window is a sum-reduce-window. + if (failed(MatchBinaryReduceFunction<mhlo::AddOp>(rw.body()))) + return failure(); + + // Check that this is a floating point reduce window with a rank of 4 or 5. + RankedTensorType rw_type = rw.getType().dyn_cast<RankedTensorType>(); + if (!rw_type || !rw_type.getElementType().isa<FloatType>() || + rw_type.getRank() <= 3 || rw_type.getRank() > 5) + return failure(); + + // Check that the Div op doesn't do broadcasting on the output of the reduce + // window. + if (div_op.getType() != rw.getType()) return failure(); + + // tf.avg_pool need at least 3 dimensions (batch, spatial, channel) + const uint64_t rank = rw.window_dimensions().size(); + if (rank <= 2) return failure(); + + // If the init value isn't zero then it can't be an average pool. + if (!isFloatZero(rw.init_value())) return failure(); + + llvm::SmallVector<int64_t, 5> window_strides; + if (rw.window_strides().hasValue()) { + window_strides.insert(window_strides.end(), + rw.window_strides()->getValues<int64_t>().begin(), + rw.window_strides()->getValues<int64_t>().end()); + } else { + window_strides.resize(rank, 1); + } + + llvm::SmallVector<int64_t, 10> padding; + if (rw.padding().hasValue()) { + padding.insert(padding.begin(), + rw.padding()->getValues<int64_t>().begin(), + rw.padding()->getValues<int64_t>().end()); + } else { + padding.resize(2 * rank, 0); + } + + // Check that we don't do any reduction along the batch (first) and channel + // (last) dimensions. 
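+    // tf.AvgPool{3D} only pools over the spatial dimensions, so a non-trivial
+    // window, stride, or padding on batch/channel has no direct TF equivalent.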
+ const uint64_t batch_dim = 0; + const uint64_t channel_dim = rank - 1; + if (rw.window_dimensions().getValue<int64_t>({batch_dim}) != 1 || + rw.window_dimensions().getValue<int64_t>({channel_dim}) != 1 || + window_strides[batch_dim] != 1 || window_strides[channel_dim] != 1 || + padding[2 * batch_dim] != 0 || padding[2 * batch_dim + 1] != 0 || + padding[2 * channel_dim] != 0 || padding[2 * channel_dim + 1] != 0) + return failure(); + + if (rw.window_dilations().hasValue() && + !(rw.window_dilations()->isSplat() && + rw.window_dilations()->getSplatValue<APInt>() == 1)) + return failure(); + + if (rw.base_dilations().hasValue() && + !(rw.base_dilations()->isSplat() && + rw.base_dilations()->getSplatValue<APInt>() == 1)) + return failure(); + + DenseFPElementsAttr divisor; + if (matchPattern(div_op.rhs(), m_Constant(&divisor))) { + // If the divisor is a constant then check that it matches with the number + // of elements inside the window what is required for a VALID AvgPool. + if (!divisor.isSplat()) return failure(); + int64_t window_size = 1; + for (int64_t w : rw.window_dimensions().getValues<int64_t>()) { + window_size *= w; + } + if (!divisor.getSplatValue<APFloat>().isExactlyValue(window_size)) + return failure(); + + // Check that we have no padding. + if (!llvm::all_of(padding, [](int64_t i) { return i == 0; })) + return failure(); + + return replaceWithAvgPool( + div_op, rw.operand(), + llvm::to_vector<4>(rw.window_dimensions().getValues<int64_t>()), + window_strides, "VALID", rewriter); + } + + auto rw_rhs = + dyn_cast_or_null<mhlo::ReduceWindowOp>(div_op.rhs().getDefiningOp()); + if (rw_rhs) { + // Check that RHS is a sum-reduce-window. + if (failed(MatchBinaryReduceFunction<mhlo::AddOp>(rw_rhs.body()))) + return failure(); + + // Check that the RHS is a reduce_window over a constant 1 input with 0 as + // the init value. + DenseFPElementsAttr rhs_input; + if (!isFloatZero(rw_rhs.init_value()) || + !matchPattern(rw_rhs.operand(), m_Constant(&rhs_input)) || + !rhs_input.isSplat() || + !rhs_input.getSplatValue<APFloat>().isExactlyValue(1.0)) + return failure(); + + // Check that the two reduce window have the same window configuration. + if (rw.window_dimensions() != rw_rhs.window_dimensions() || + rw.window_strides() != rw_rhs.window_strides() || + rw.window_dilations() != rw_rhs.window_dilations() || + rw.base_dilations() != rw_rhs.base_dilations() || + rw.padding() != rw_rhs.padding()) + return failure(); + + if (llvm::all_of(padding, [](int64_t i) { return i == 0; })) + return replaceWithAvgPool( + div_op, rw.operand(), + llvm::to_vector<4>(rw.window_dimensions().getValues<int64_t>()), + window_strides, "VALID", rewriter); + + RankedTensorType input_type = + rw.operand().getType().dyn_cast<RankedTensorType>(); + RankedTensorType output_type = rw.getType().dyn_cast<RankedTensorType>(); + if (!input_type || !output_type) return failure(); + + // Check that the individual padding values are corresponding to SAME + // padding from TensorFlow. 
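+      // For SAME padding TensorFlow computes
+      //   pad_total = (output_size - 1) * stride + ksize - input_size
+      // and splits it as floor(pad_total / 2) before and ceil(pad_total / 2)
+      // after the data, e.g. input 16, ksize 3, stride 2 -> output 8,
+      // pad_total = 1, i.e. padding (0, 1).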
+ for (uint64_t i = 1; i < rank - 1; ++i) { + int64_t padding_size = + (output_type.getShape()[i] - 1) * window_strides[i] + + rw.window_dimensions().getValue<int64_t>({i}) - + input_type.getShape()[i]; + if (padding[2 * i] != + tensorflow::MathUtil::FloorOfRatio(padding_size, int64_t(2)) || + padding[2 * i + 1] != + tensorflow::MathUtil::CeilOfRatio(padding_size, int64_t(2))) + return failure(); + } + return replaceWithAvgPool( + div_op, rw.operand(), + llvm::to_vector<4>(rw.window_dimensions().getValues<int64_t>()), + window_strides, "SAME", rewriter); + } + return failure(); + } + + private: + bool isFloatZero(Value value) const { + DenseFPElementsAttr initial_value; + return matchPattern(value, m_Constant(&initial_value)) && + initial_value.getNumElements() == 1 && + initial_value.getValue<APFloat>({}).isZero(); + } + + LogicalResult replaceWithAvgPool(mhlo::DivOp op, Value input, + llvm::ArrayRef<int64_t> ksizes, + llvm::ArrayRef<int64_t> kstrides, + llvm::StringRef padding, + ConversionPatternRewriter &rewriter) const { + if (ksizes.size() == 4) { + rewriter.replaceOpWithNewOp<AvgPoolOp>( + op, op.getType(), input, rewriter.getI64ArrayAttr(ksizes), + rewriter.getI64ArrayAttr(kstrides), rewriter.getStringAttr(padding), + rewriter.getStringAttr("NHWC")); + return success(); + } else if (ksizes.size() == 5) { + rewriter.replaceOpWithNewOp<AvgPool3DOp>( + op, op.getType(), input, rewriter.getI64ArrayAttr(ksizes), + rewriter.getI64ArrayAttr(kstrides), rewriter.getStringAttr(padding), + rewriter.getStringAttr("NDHWC")); + return success(); + } + return failure(); + } +}; + class LegalizeHloToTf : public PassWrapper<LegalizeHloToTf, FunctionPass> { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert<TF::TensorFlowDialect>(); @@ -794,10 +980,10 @@ static PassRegistration<LegalizeHloToTf> pass( void PopulateLegalizeHloToTfPatterns(OwningRewritePatternList *patterns, MLIRContext *context) { + patterns->insert<ConvertAvgPoolOp, ConvertConvOp, ConvertSliceOp, + ConvertReduceOpToTfMax, ConvertReduceOpToTfMin, + ConvertReduceOpToTfSum, ConvertIotaOpToTfRange>(context); populateWithGenerated(context, *patterns); - patterns->insert<ConvertConvOp, ConvertSliceOp, ConvertReduceOpToTfMax, - ConvertReduceOpToTfMin, ConvertReduceOpToTfSum, - ConvertIotaOpToTfRange>(context); } std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeHloToTfPass() { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc index 9d43eb17624..2d9c9023ac1 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc @@ -332,12 +332,13 @@ class LowerDynamicStitchOp : public RewritePattern { class ConvertFakeQuantWithMinMaxVarsOp : public RewritePattern { public: explicit ConvertFakeQuantWithMinMaxVarsOp(MLIRContext *context) - : RewritePattern(FakeQuantWithMinMaxVarsOp::getOperationName(), - {SubOp::getOperationName(), ConstOp::getOperationName(), - MulOp::getOperationName(), FloorOp::getOperationName(), - ClipByValueOp::getOperationName(), - DivOp::getOperationName(), RoundOp::getOperationName()}, - 1, context) {} + : RewritePattern( + FakeQuantWithMinMaxVarsOp::getOperationName(), + {AddV2Op::getOperationName(), SubOp::getOperationName(), + ConstOp::getOperationName(), MulOp::getOperationName(), + FloorOp::getOperationName(), ClipByValueOp::getOperationName(), + DivOp::getOperationName(), RoundOp::getOperationName()}, + 1, context) {} LogicalResult 
matchAndRewrite(Operation *src_op, PatternRewriter &rewriter) const override { @@ -419,8 +420,8 @@ class ConvertFakeQuantWithMinMaxVarsOp : public RewritePattern { op.getLoc(), DenseElementsAttr::get(scalar_ty, ConvertToAPFloat(0.5, element_ty))); - quantized_input = rewriter.create<AddOp>(op.getLoc(), input_ty, - quantized_input, half_val); + quantized_input = rewriter.create<AddV2Op>(op.getLoc(), input_ty, + quantized_input, half_val); quantized_input = rewriter.create<FloorOp>(op.getLoc(), quantized_input); @@ -428,8 +429,8 @@ class ConvertFakeQuantWithMinMaxVarsOp : public RewritePattern { Value output = rewriter.create<MulOp>(op.getLoc(), input_ty, quantized_input, quant_to_float); - output = - rewriter.create<AddOp>(op.getLoc(), input_ty, output, nudged_float_min); + output = rewriter.create<AddV2Op>(op.getLoc(), input_ty, output, + nudged_float_min); rewriter.replaceOp(op, {output}); return success(); @@ -811,7 +812,7 @@ class LowerSpaceToBatchNDOp : public RewritePattern { CastOp::getOperationName(), ConstOp::getOperationName(), ConcatV2Op::getOperationName(), - AddOp::getOperationName(), + AddV2Op::getOperationName(), PadOp::getOperationName(), SplitOp::getOperationName(), UnpackOp::getOperationName(), @@ -907,8 +908,8 @@ class LowerSpaceToBatchNDOp : public RewritePattern { auto paddings_split = rewriter.create<UnpackOp>( loc, TypeRange({paddings_sum_type, paddings_sum_type}), full_paddings, rewriter.getI64IntegerAttr(1)); - auto paddings_sum = rewriter.create<AddOp>(loc, paddings_split.getResult(0), - paddings_split.getResult(1)); + auto paddings_sum = rewriter.create<AddV2Op>( + loc, paddings_split.getResult(0), paddings_split.getResult(1)); auto input_shape_tensor = rewriter.create<ConstOp>( loc, @@ -918,7 +919,7 @@ class LowerSpaceToBatchNDOp : public RewritePattern { // padded_shape_tensor is the shape of padded. auto padded_shape_tensor = - rewriter.create<AddOp>(loc, paddings_sum, input_shape_tensor); + rewriter.create<AddV2Op>(loc, paddings_sum, input_shape_tensor); auto zero_i32 = rewriter.create<ConstOp>( loc, GetScalarOfType(rewriter.getIntegerType(32), 0)); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td index d0bd97501e1..f667080f0c2 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td @@ -237,7 +237,7 @@ def : Pat<(TF_RoundOp:$res TF_FloatTensor:$input), (TF_SubOp $input, (TF_FloorOp:$floor $input)), (TF_ConstOp (GetScalarOfFloatType<"0.5"> $input))), $floor, - (TF_AddOp + (TF_AddV2Op (TF_ConstOp (GetScalarOfType<1> $input)), $floor))>; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/outside_compiled_to_host_launch.cc b/tensorflow/compiler/mlir/tensorflow/transforms/outside_compiled_to_host_launch.cc new file mode 100644 index 00000000000..225f6c0cc15 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/outside_compiled_to_host_launch.cc @@ -0,0 +1,184 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h" + +namespace mlir { +namespace TFTPU { + +namespace { + +constexpr char kDeviceAttr[] = "device"; +constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation"; + +// Mapping for `_xla_outside_compilation` attribute to ops of a cluster. +using OutsideClusterMap = + llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<Operation*, 8>, 8>; + +// This pass wraps ops with the same `_xla_outside_compilation` +// attribute value in a tf_device.launch op with host device assignment. +// +// A simple example: +// "tf_device.cluster"() ( { +// "tf.A"() +// "tf.B"() {_xla_outside_compilation = "cluster1"} +// "tf.C"() +// tf_device.return +// }) {num_cores_per_replica = 1, topology = "", device_assignment = []} +// +// Would become the following ops (unimportant attribute, type are omitted): +// "tf_device.cluster"() ( { +// "tf.A"() +// "tf_device.launch"() { +// "tf.B"() {_xla_outside_compilation = "cluster1"} +// tf_device.return +// } {device = "TPU_REPLICATED_HOST"} : () -> () +// "tf.C"() +// tf_device.return +// }) {num_cores_per_replica = 1, topology = "", device_assignment = []} +// + +struct OutsideCompiledToHostLaunch + : public PassWrapper<OutsideCompiledToHostLaunch, OperationPass<ModuleOp>> { + void runOnOperation() override; +}; + +// Collects and clusters ops in `block` with the same `_xla_outside_compilation` +// attribute into `clusters` This returns an error if a +// `_xla_outside_compilation` attribute of an op is empty. +LogicalResult CollectAndGroupOutsideClusterOps(Block* block, + OutsideClusterMap* clusters) { + auto walk_result = block->walk([&](Operation* op) { + if (auto attr = op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) { + if (attr.getValue().empty()) { + op->emitOpError() << "requires non empty '" + << kXlaOutsideCompilationAttr << "' string attribute"; + return WalkResult::interrupt(); + } + + auto it = clusters->try_emplace(attr.getValue()); + it.first->getSecond().push_back(op); + } + return WalkResult::advance(); + }); + + return failure(walk_result.wasInterrupted()); +} + +// Extracts all externally used outputs of `cluster_ops`. 
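+// That is, values defined by an op in `cluster_ops` that have at least one
+// user outside of the cluster.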
+llvm::SmallSetVector<Value, 4> GetExternalOutputs( + llvm::ArrayRef<Operation*> cluster_ops) { + llvm::SmallSetVector<Value, 4> external_outputs; + llvm::SmallPtrSet<Operation*, 4> host_cluster_ops_set; + for (auto op : cluster_ops) { + op->walk([&](Operation* host_cluster_op) { + host_cluster_ops_set.insert(host_cluster_op); + }); + } + + for (Operation* op : cluster_ops) { + for (Operation* user : op->getUsers()) { + bool is_external = llvm::none_of( + host_cluster_ops_set, + [&](Operation* cluster_op) { return user == cluster_op; }); + if (!is_external) continue; + for (Value v : user->getOperands()) { + if (v.getDefiningOp() == op) external_outputs.insert(v); + } + } + } + + return external_outputs; +} + +void WrapClusterInLaunch(llvm::ArrayRef<Operation*> cluster_ops, + llvm::StringRef host_device) { + auto* last_cluster_op = cluster_ops.back(); + OpBuilder builder(last_cluster_op); + llvm::SmallVector<Type, 4> launch_output_types; + auto external_outputs = GetExternalOutputs(cluster_ops); + for (const auto& external_output : external_outputs) + launch_output_types.push_back(external_output.getType()); + + auto launch_op = builder.create<tf_device::LaunchOp>( + last_cluster_op->getLoc(), builder.getStringAttr(host_device), + /*result_types=*/launch_output_types); + for (auto result : llvm::zip(external_outputs, launch_op.getResults())) { + std::get<0>(result).replaceAllUsesWith(std::get<1>(result)); + } + + launch_op.body().push_back(new Block); + builder.setInsertionPointToEnd(&launch_op.GetBody()); + auto* return_op = + builder + .create<tf_device::ReturnOp>(last_cluster_op->getLoc(), + external_outputs.getArrayRef()) + .getOperation(); + MLIRContext* context = launch_op.getContext(); + for (Operation* cluster_op : cluster_ops) { + cluster_op->removeAttr( + Identifier::get(kXlaOutsideCompilationAttr, context)); + cluster_op->removeAttr(Identifier::get(kDeviceAttr, context)); + cluster_op->moveBefore(return_op); + } +} + +void OutsideCompiledToHostLaunch::runOnOperation() { + // Get runtime devices information from the closest parent module. 
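+  // They are used below to determine the host device assigned to each launch.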
+ auto module = getOperation(); + mlir::TF::RuntimeDevices devices; + if (failed(tensorflow::GetDevicesFromOp(module, &devices))) + return signalPassFailure(); + + auto result = module.walk([&](tf_device::ClusterOp tpu_cluster) { + OutsideClusterMap clusters; + if (failed(CollectAndGroupOutsideClusterOps(&tpu_cluster.GetBody(), + &clusters))) + return WalkResult::interrupt(); + + if (clusters.empty()) return WalkResult::advance(); + + std::string host_device; + tensorflow::GetHostDeviceOutsideComputation(devices, tpu_cluster, + &host_device); + for (const auto& cluster : clusters) { + WrapClusterInLaunch(cluster.getSecond(), host_device); + } + return WalkResult::advance(); + }); + if (result.wasInterrupted()) return signalPassFailure(); +} + +} // anonymous namespace + +std::unique_ptr<OperationPass<ModuleOp>> +CreateOutsideCompiledToHostLaunchPass() { + return std::make_unique<OutsideCompiledToHostLaunch>(); +} + +static PassRegistration<OutsideCompiledToHostLaunch> pass( + "tf-outside-compiled-to-host-launch", + "Wraps ops with ithe same _xla_outside_compiled attribute in " + "tf_device.launch on replicated host device."); + +} // namespace TFTPU +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index 3293ef7810e..25945182c20 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -345,6 +345,11 @@ std::unique_ptr<OperationPass<FuncOp>> CreateTPUColocateCompositeResourceOps(); // run-time according to compilation result. std::unique_ptr<OperationPass<ModuleOp>> CreateTPUVariableReformattingPass(); +// Creates a pass that wraps ops with the same `_xla_outside_compilation` +// attribute value in a tf_device.launch op with host device assignment. +std::unique_ptr<OperationPass<ModuleOp>> +CreateOutsideCompiledToHostLaunchPass(); + // Creates a pass that groups outside compiled operations (CPU ops inside TPU // cluster) into clusters that can be extracted and run on the CPU. std::unique_ptr<OperationPass<ModuleOp>> diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc b/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc index 2363663cb5a..a61fa650973 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc @@ -398,8 +398,8 @@ LogicalResult RegionControlFlowToFunctional::ConvertWhileOp( OpBuilder builder(while_region); auto while_op = builder.create<WhileOp>( while_region.getLoc(), new_result_types, new_inputs, cond_name, body_name, - while_region.output_shapes(), while_region.parallel_iterations(), - while_region.is_stateless()); + while_region.parallel_iterations(), while_region.is_stateless(), + while_region.shape_invariant()); CopyDeviceAndUnderscoredAttributes(while_region, while_op); // Redirect old results to new results. 
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc index 7953dfe1832..caf4e6b9197 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc @@ -255,8 +255,7 @@ TF::WhileRegionOp CloneEmptyWhile(bool is_stateless, OpBuilder& builder) { auto host_side_while = builder.create<TF::WhileRegionOp>( loc, /*output=*/ArrayRef<Type>{}, /*input=*/ArrayRef<Value>{}, - /*output_shapes=*/builder.getArrayAttr({}), parallel_iterations, - is_stateless); + parallel_iterations, is_stateless, /*shape_invariant=*/false); // Create empty else branch region. auto& body = host_side_while.body(); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc index 0057e498cea..125994b928d 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc @@ -155,6 +155,9 @@ StatusOr<absl::flat_hash_set<absl::string_view>> GetAttributesToIgnore( if (llvm::isa<mlir::TF::CaseOp, mlir::TF::IfOp, mlir::TF::WhileOp>(inst)) attrs_to_ignore.insert("is_stateless"); + if (llvm::isa<mlir::TF::WhileOp>(inst)) + attrs_to_ignore.insert("shape_invariant"); + return attrs_to_ignore; } diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 4242cdd349e..986b570f4fc 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -971,6 +971,16 @@ StatusOr<mlir::Type> ImporterBase::InferOutputType(const Node& node, int idx, etype.getContext())); } + if (node.IsWhileNode()) { + auto* output_shapes = node.attrs().Find("output_shapes"); + auto* element_types = node.attrs().Find("T"); + if (output_shapes && !output_shapes->list().shape().empty()) { + const auto& output_shape = output_shapes->list().shape(idx); + const auto& element_type = element_types->list().type(idx); + return ConvertToMlirTensorType(output_shape, element_type, &builder); + } + } + // Returns a simple, more conservative unranked tensor type. auto default_type = [&]() -> StatusOr<mlir::Type> { mlir::Type element_type; @@ -1907,7 +1917,13 @@ Status ImporterBase::ConvertNode(const Node& node) { // Case/If/While op in MLIR and add the differentiating attribute. if (node.IsCaseNode()) composite_control_flow_op("Case"); if (node.IsIfNode()) composite_control_flow_op("If"); - if (node.IsWhileNode()) composite_control_flow_op("While"); + if (node.IsWhileNode()) { + composite_control_flow_op("While"); + auto* output_shapes = node.attrs().Find("output_shapes"); + if (output_shapes && !output_shapes->list().shape().empty()) + result.attributes.push_back( + builder_.getNamedAttr("shape_invariant", builder_.getUnitAttr())); + } // Register the mapping between the TF node and the newly created operation. 
node_values_[node.id()] = @@ -3420,6 +3436,8 @@ SavedModelSignatureDefImporterLite::ConvertGraph( const std::vector<std::pair<std::string, TensorInfo>>& inputs, const std::vector<std::pair<std::string, TensorInfo>>& outputs, const std::vector<std::string> control_outputs) { + VLOG(1) << "Importing Signature: " << name; + GraphImportConfig specs; specs.prune_unused_nodes = true; specs.inputs = ParseInputArrays(inputs); @@ -3491,6 +3509,9 @@ SavedModelSignatureDefImporterLite::ParseInputArrays( // Only dense tensor is supported. DCHECK_EQ(tensor_info.encoding_case(), tensorflow::TensorInfo::kName); + VLOG(1) << "Importing Signature Input: input_name = " << iter.first + << ", tensor_info = " << tensor_info.DebugString(); + ArrayInfo array_info; array_info.imported_dtype = tensor_info.dtype(); array_info.shape = tensor_info.tensor_shape(); diff --git a/tensorflow/compiler/mlir/tfr/examples/mnist/BUILD b/tensorflow/compiler/mlir/tfr/examples/mnist/BUILD index 735ee64596e..d088adc3f9b 100644 --- a/tensorflow/compiler/mlir/tfr/examples/mnist/BUILD +++ b/tensorflow/compiler/mlir/tfr/examples/mnist/BUILD @@ -68,6 +68,7 @@ distribute_py_test( tags = [ "no_cuda_asan", # b/173431253 "no_oss", + "notap", # b/173661843 "notsan", # b/173246447 ], xla_enable_strict_auto_jit = False, # b/173254861 diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc index e7e9025c4aa..a8a99f1cd10 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc @@ -90,8 +90,17 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only, pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass()); } - // Legalize only hlo operations to lhlo, keep the rest as tensors. - pm.addPass(mlir::kernel_gen::transforms::CreateHloBufferizePass()); + // Partial bufferization: Transforms in particular HLO operations to their + // corresponding LHLO operations and converts the function signature. Leaves + // shape operations untouched. + pm.addPass(mlir::kernel_gen::transforms::CreateBufferizePass( + /*allow_partial_bufferization=*/true)); + // Run CSE to ensure that loads and stores to the same location get recognized + // as such. + pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass()); + // Forward stores to buffers to loads. + pm.addNestedPass<mlir::FuncOp>(xla::mlir_gpu::createStoreForwardingPass()); + // Clean up the IR for further processing. pm.addPass(mlir::createCanonicalizerPass()); pm.addNestedPass<mlir::FuncOp>(mlir::createCSEPass()); @@ -100,18 +109,22 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only, // needed. llvm::SmallVector<unsigned, 4> tiling_for_unrolling; llvm::SmallVector<int64_t, 4> as_int64; - if (!unroll_factors.empty()) { - tiling_for_unrolling.reserve(tile_sizes.size()); - for (auto pair : llvm::zip(tile_sizes, unroll_factors)) { - tiling_for_unrolling.push_back(std::get<0>(pair) * std::get<1>(pair)); - as_int64.push_back(std::get<1>(pair)); - } - } else { - tiling_for_unrolling.append(tile_sizes.begin(), tile_sizes.end()); + tiling_for_unrolling.reserve(tile_sizes.size()); + for (auto pair : llvm::zip(tile_sizes, unroll_factors)) { + tiling_for_unrolling.push_back(std::get<0>(pair) * std::get<1>(pair)); + as_int64.push_back(std::get<1>(pair)); } + tiling_for_unrolling.append( + tile_sizes.drop_front(unroll_factors.size()).begin(), tile_sizes.end()); // Transform LHLO operations to LinAlg.
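The reworked tiling logic above no longer special-cases an empty `unroll_factors`: the first `unroll_factors.size()` tile sizes are scaled by their unroll factor, and the remaining tile sizes are appended unchanged. The standalone sketch below restates that computation with `std::vector` in place of `llvm::SmallVector` (the parallel `as_int64` bookkeeping is omitted).

```cpp
#include <cstdint>
#include <vector>

// Scale the leading tile sizes by their unroll factor, keep the rest as-is.
// With an empty unroll_factors this degenerates to a plain copy, which is why
// the previous if/else could be folded away.
std::vector<unsigned> TilingForUnrolling(
    const std::vector<unsigned>& tile_sizes,
    const std::vector<int64_t>& unroll_factors) {
  std::vector<unsigned> result;
  result.reserve(tile_sizes.size());
  for (size_t i = 0; i < unroll_factors.size() && i < tile_sizes.size(); ++i)
    result.push_back(tile_sizes[i] * unroll_factors[i]);
  for (size_t i = unroll_factors.size(); i < tile_sizes.size(); ++i)
    result.push_back(tile_sizes[i]);
  return result;
}
```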
pm.addNestedPass<mlir::FuncOp>( ::mlir::lmhlo::createLegalizeLhloToLinalgPass()); + if (!gpu_binary_only) { + // Find candidates for buffer reuse. This is only successful if buffer size + // equality can be determined based on `linalg.generic` operations. + pm.addNestedPass<mlir::FuncOp>( + mlir::kernel_gen::transforms::CreateBufferReusePass()); + } // Fuse linalg operations. pm.addNestedPass<mlir::FuncOp>(::mlir::lmhlo::createLhloFuseLinalgPass( /*use_parallel_loops=*/true, tiling_for_unrolling)); @@ -141,11 +154,6 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only, // Some basic cleanup. pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass()); pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass()); - if (!gpu_binary_only) { - // Find candidates for buffer reuse. - pm.addNestedPass<mlir::FuncOp>( - mlir::kernel_gen::transforms::CreateBufferReusePass()); - } // Greedily map the remaining loop to GPU hardware dimensions. pm.addNestedPass<::mlir::FuncOp>(xla::mlir_gpu::createMapParallelLoopsPass()); @@ -157,7 +165,7 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only, pm.addPass(mlir::kernel_gen::transforms::CreateShapeToDescriptorsPass()); pm.addPass(mlir::createCanonicalizerPass()); pm.addNestedPass<mlir::FuncOp>(mlir::createCSEPass()); - pm.addPass(mlir::kernel_gen::transforms::CreateFinalBufferizePass()); + pm.addPass(mlir::kernel_gen::transforms::CreateBufferizePass()); pm.addNestedPass<mlir::FuncOp>(mlir::createPromoteBuffersToStackPass(64)); // TODO(herhut): Enabled this to avoid leaks once fixed. // pm.addNestedPass<mlir::FuncOp>(::mlir::createBufferDeallocationPass()); @@ -189,10 +197,6 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only, pm.addPass(::mlir::createLowerAffinePass()); // Constraints are removed as late as possible and before lowering to CFG. pm.addNestedPass<::mlir::FuncOp>(::mlir::createConvertShapeConstraintsPass()); - if (embed_memref_prints) { - pm.addNestedPass<::mlir::FuncOp>( - mlir::kernel_gen::transforms::CreateEmbedMemRefPrintsPass()); - } pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass()); // TODO(herhut): Remove this pass once the LowerToCFG pass can handle it. pm.addNestedPass<mlir::FuncOp>( @@ -200,6 +204,10 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only, pm.addPass(::mlir::createLowerToCFGPass()); // Map allocs, asserts, etc. to the tensorflow framework. 
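For context on the `createStoreForwardingPass` step added earlier in this pipeline, and on why CSE runs just before it: store-to-load forwarding replaces a load with the most recent value stored to the same buffer, which only works once CSE has made identical buffer references syntactically equal. The toy, straight-line version below is a sketch of the idea, not the XLA/MLIR implementation; `Instr` and its fields are invented for illustration.

```cpp
#include <optional>
#include <string>
#include <unordered_map>
#include <vector>

// A toy instruction in a straight-line buffer program.
struct Instr {
  enum { kStore, kLoad } kind;
  std::string buffer;                    // buffer name after CSE
  std::string value;                     // stored value (kStore) or result name (kLoad)
  std::optional<std::string> forwarded;  // filled in for forwarded loads
};

// Replace each load with the last value stored to the same buffer, provided
// no intervening store wrote to that buffer.
void ForwardStoresToLoads(std::vector<Instr>& block) {
  std::unordered_map<std::string, std::string> last_store;
  for (Instr& instr : block) {
    if (instr.kind == Instr::kStore) {
      last_store[instr.buffer] = instr.value;
    } else {
      auto it = last_store.find(instr.buffer);
      if (it != last_store.end()) instr.forwarded = it->second;
    }
  }
}
```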
pm.addPass(mlir::kernel_gen::tf_framework::CreateEmbedTFFrameworkPass()); + if (embed_memref_prints) { + pm.addNestedPass<::mlir::FuncOp>( + mlir::kernel_gen::transforms::CreateEmbedMemRefPrintsPass()); + } if (failed(pm.run(module))) { return InternalError("Lowering to GPU kernels failed."); } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/buffer_reuse.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/buffer_reuse.mlir index b297c72af5d..bb86e25414d 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/buffer_reuse.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/buffer_reuse.mlir @@ -49,12 +49,58 @@ func @local_reuse_with_memref_maps( iterator_types = ["parallel"] } ins(%arg : memref<?xi64, offset: 2, strides: [3]>) outs(%result : memref<?xi64, offset: 2, strides: [3]>) { - ^bb0(%a: i64, %b: i64): + ^bb0(%a : i64, %b : i64): linalg.yield %a : i64 } return %result : memref<?xi64, offset: 2, strides: [3]> } +// CHECK-LABEL: @memref_reinterpret_cast_alias +func @memref_reinterpret_cast_alias(%arg : memref<f32>, %n : index) + -> memref<?xf32> attributes {tf_entry} { + %c0 = constant 0 : index + %reinterpreted = memref_reinterpret_cast %arg to + offset: [0], + sizes: [%n], + strides: [%c0]: memref<f32> to memref<?xf32> + + // CHECK: alloc + // CHECK-SAME: reuse_input_candidates = [0 : index] + %result = alloc(%n) : memref<?xf32> + + // reinterpreted (arg) and result are of same size. + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%reinterpreted : memref<?xf32>) outs(%result : memref<?xf32>) { + ^bb0(%a : f32, %b : f32): + linalg.yield %a : f32 + } + + return %result : memref<?xf32> +} + +// CHECK-LABEL: @memref_cast_alias +func @memref_cast_alias(%arg : memref<*xf32>, %n : index) + -> memref<?xf32> attributes {tf_entry} { + %casted = memref_cast %arg : memref<*xf32> to memref<?xf32> + + // CHECK: alloc + // CHECK-SAME: reuse_input_candidates = [0 : index] + %result = alloc(%n) : memref<?xf32> + + // reinterpreted (arg) and result are of same size. 
+ linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%casted : memref<?xf32>) outs(%result : memref<?xf32>) { + ^bb0(%a : f32, %b : f32): + linalg.yield %a : f32 + } + + return %result : memref<?xf32> +} + // CHECK-LABEL: @indirect_size_equality func @indirect_size_equality(%arg0 : memref<?xi64>, %arg1 : memref<?xi64>, @@ -65,7 +111,7 @@ func @indirect_size_equality(%arg0 : memref<?xi64>, indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"] } ins(%arg0 : memref<?xi64>) outs(%arg1 : memref<?xi64>) { - ^bb0(%a: i64, %b: i64): + ^bb0(%a : i64, %b : i64): linalg.yield %a : i64 } @@ -78,7 +124,7 @@ func @indirect_size_equality(%arg0 : memref<?xi64>, indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"] } ins(%arg0 : memref<?xi64>) outs(%result : memref<?xi64>) { - ^bb0(%a: i64, %b: i64): + ^bb0(%a : i64, %b : i64): linalg.yield %a : i64 } @@ -317,7 +363,7 @@ func @abs_unranked_i64(%arg : memref<*xi64>, indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"] } ins(%flat_arg : memref<?xi64>) outs(%flat_result : memref<?xi64>) { - ^bb0(%a: i64, %b: i64): + ^bb0(%a : i64, %b : i64): %c0 = constant 0 : i64 %a_pos = cmpi "sge", %a, %c0 : i64 %a_neg = subi %c0, %a : i64 @@ -360,3 +406,41 @@ func @index_element_type(%arg : memref<2x3xindex>) -> memref<2x3xindex> %result = alloc() : memref<2x3xindex> return %result : memref<2x3xindex> } + +// Example as it occurs in the `tf.Abs` kernel for `f32`. +// CHECK-LABEL: @abs_f32 +func @abs_f32(%arg0: memref<*xf32>) -> memref<*xf32> + attributes {llvm.emit_c_interface, tf_entry} { + %c0 = constant 0 : index + %0 = shape.shape_of %arg0 : memref<*xf32> -> tensor<?xindex> + %1 = shape.num_elements %0 : tensor<?xindex> -> index + // CHECK-LABEL: alloc + // CHECK-SAME: reuse_input_candidates = [] + %2 = alloc() : memref<1xindex> + store %1, %2[%c0] : memref<1xindex> + %3 = memref_reshape %arg0(%2) + : (memref<*xf32>, memref<1xindex>) -> memref<?xf32> + %4 = dim %3, %c0 : memref<?xf32> + %5 = index_cast %4 : index to i64 + // CHECK-LABEL: alloc + // CHECK-SAME: reuse_input_candidates = [] + %6 = alloc() : memref<1xi64> + store %5, %6[%c0] : memref<1xi64> + %7 = load %6[%c0] : memref<1xi64> + %8 = index_cast %7 : i64 to index + // CHECK-LABEL: alloc + // CHECK-SAME: reuse_input_candidates = [0 : index] + %9 = alloc(%8) : memref<?xf32> + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%3 : memref<?xf32>) outs(%9 : memref<?xf32>) { + ^bb0(%arg1: f32, %arg2: f32): // no predecessors + %12 = absf %arg1 : f32 + linalg.yield %12 : f32 + } + %10 = tensor_to_memref %0 : memref<?xindex> + %11 = memref_reshape %9(%10) + : (memref<?xf32>, memref<?xindex>) -> memref<*xf32> + return %11 : memref<*xf32> +} diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/bufferize.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/bufferize.mlir index 244fa24dc6b..eca3d586294 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/bufferize.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/bufferize.mlir @@ -1,4 +1,4 @@ -// RUN: kernel-gen-opt %s --final-bufferize | FileCheck %s +// RUN: kernel-gen-opt %s --bufferize | FileCheck %s // CHECK-LABEL: @extract_element // CHECK-SAME: (%[[ARG:.*]]: memref<?xf32>) -> f32 diff --git 
a/tensorflow/compiler/mlir/tools/kernel_gen/tests/isinf.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/isinf.mlir index d251c40981a..27912cb7b57 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/isinf.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/isinf.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s --test-tf-lower-tf --xla-legalize-tf | mlir-hlo-opt --transform-unranked-hlo | kernel-gen-opt -allow-unregistered-dialect --canonicalize --shape-to-descriptors --canonicalize --hlo-bufferize --std-bufferize --func-bufferize | FileCheck %s +// RUN: tf-opt %s --test-tf-lower-tf --xla-legalize-tf | mlir-hlo-opt --transform-unranked-hlo | kernel-gen-opt -allow-unregistered-dialect --canonicalize --shape-to-descriptors --canonicalize --bufferize | FileCheck %s // Test whether all shape computations required for isinf can be lowered to // the standard dialect, scf and descriptors. diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/print_memrefs.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/print_memrefs.mlir index 1b119f56377..c89e1ca7362 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/print_memrefs.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/print_memrefs.mlir @@ -24,7 +24,7 @@ func @print_memrefs( return %output : memref<*xf16> } -// CHECK: func private @print_memref_index(memref<*xindex>) +// CHECK: func private @print_memref_i64(memref<*xi64>) // CHECK-LABEL: func @print_memrefs @@ -33,12 +33,16 @@ func @print_memrefs( // CHECK: [[NUM_ELEM:%.*]] = alloca() : memref<1xindex> // CHECK: store {{%.*}}, [[NUM_ELEM]] -// CHECK: [[UNRANKED_NUM_ELEM:%.*]] = memref_cast [[NUM_ELEM]] -// CHECK-NEXT: call @print_memref_index([[UNRANKED_NUM_ELEM]]) +// CHECK: [[NUM_ELEM_I64:%.*]] = index_cast [[NUM_ELEM]] +// CHECK-SAME: : memref<1xindex> to memref<1xi64> +// CHECK-NEXT: [[UNRANKED_NUM_ELEM:%.*]] = memref_cast [[NUM_ELEM_I64]] +// CHECK-NEXT: call @print_memref_i64([[UNRANKED_NUM_ELEM]]) // CHECK: memref_reshape // CHECK: tf_framework.alloc -// CHECK: [[UNRANKED_SHAPE:%.*]] = memref_cast [[SHAPE]] -// CHECK-NEXT: call @print_memref_index([[UNRANKED_SHAPE]]) +// CHECK: [[SHAPE_I64:%.*]] = index_cast [[SHAPE]] +// CHECK-SAME: : memref<?xindex> to memref<?xi64> +// CHECK-NEXT: [[UNRANKED_SHAPE:%.*]] = memref_cast [[SHAPE_I64]] +// CHECK-NEXT: call @print_memref_i64([[UNRANKED_SHAPE]]) // CHECK: memref_reshape diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tanh.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tanh.mlir index 6a8f4107b65..6b9f011a109 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tanh.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tanh.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s --xla-legalize-tf | mlir-hlo-opt --transform-unranked-hlo | kernel-gen-opt -allow-unregistered-dialect --shape-to-descriptors --canonicalize --hlo-bufferize --std-bufferize --scf-bufferize --func-bufferize | FileCheck %s +// RUN: tf-opt %s --xla-legalize-tf | mlir-hlo-opt --transform-unranked-hlo | kernel-gen-opt -allow-unregistered-dialect --shape-to-descriptors --canonicalize --bufferize | FileCheck %s // Test whether all shape computations required for tanh can be lowered to // the standard dialect, scf and descriptors. 
We check for a sparse pattern here, diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf-legalize-to-lmhlo.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf-legalize-to-lmhlo.mlir index 9b3baf69488..2fc585d9e9d 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf-legalize-to-lmhlo.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf-legalize-to-lmhlo.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s --xla-legalize-tf='legalize-chlo=false' | mlir-hlo-opt --transform-unranked-hlo --chlo-legalize-to-hlo | kernel-gen-opt --shape-to-descriptors --canonicalize --hlo-bufferize +// RUN: tf-opt %s --xla-legalize-tf='legalize-chlo=false' | mlir-hlo-opt --transform-unranked-hlo --chlo-legalize-to-hlo | kernel-gen-opt --shape-to-descriptors --canonicalize --bufferize func @acos(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "tf.Acos"(%arg0) { } : (tensor<*xf32>) -> tensor<*xf32> diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc index 7d263169d60..62dd39b8897 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc @@ -75,7 +75,7 @@ extern "C" void _mlir_ciface_tf_dealloc(void* op_kernel_ctx, void* ptr) { extern "C" void _mlir_ciface_tf_report_error(void* op_kernel_ctx, int32_t error_code, char* msg) { Optional<ErrorCode> symbol = symbolizeErrorCode(error_code); - if (symbol.hasValue()) { + if (!symbol.hasValue()) { LOG(ERROR) << "No valid conversion from integer value = " << error_code << "to ErrorCode attribute"; return; diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/buffer_reuse_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/buffer_reuse_pass.cc index a184b80ea44..ac72f6b157c 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/buffer_reuse_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/buffer_reuse_pass.cc @@ -55,12 +55,14 @@ namespace { /// A temporary buffer size analysis that is correct but may be incomplete. class BufferSizeAnalysis { public: - explicit BufferSizeAnalysis(FuncOp f) { build(f); } + BufferSizeAnalysis(FuncOp f, const BufferAliasAnalysis &aliases) { + build(f, aliases); + } bool is_same_size(Value a, Value b) { return ecs_.isEquivalent(a, b); } private: - void build(FuncOp &f) { + void build(FuncOp &f, const BufferAliasAnalysis &aliases) { auto buffers = find_buffer_values(f); // Memrefs with statically known same shape and same symbol-free affine maps @@ -102,10 +104,16 @@ class BufferSizeAnalysis { } }); - // Operand and result of `reshape_memref_cast` must be of same size. - f.walk([&](MemRefReshapeOp reshapeOp) { - ecs_.unionSets(reshapeOp.result(), reshapeOp.source()); - }); + // All aliases of a memref must be of the same underlying buffer size. 
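The buffer-reuse change above merges every alias of a memref into one size-equivalence class via `llvm::EquivalenceClasses::unionSets`, so that `is_same_size` reduces to a class lookup. Below is a minimal union-find sketch of that data structure, with integer buffer ids standing in for `mlir::Value`.

```cpp
#include <numeric>
#include <vector>

// Minimal union-find, standing in for llvm::EquivalenceClasses: buffers are
// merged into one class when they alias each other (and hence share the same
// underlying allocation size), so IsSameSize becomes a Find() comparison.
class SizeEquivalence {
 public:
  explicit SizeEquivalence(int num_buffers) : parent_(num_buffers) {
    std::iota(parent_.begin(), parent_.end(), 0);
  }
  void UnionSets(int a, int b) { parent_[Find(a)] = Find(b); }
  bool IsSameSize(int a, int b) { return Find(a) == Find(b); }

 private:
  // Path-compressing find.
  int Find(int x) { return parent_[x] == x ? x : parent_[x] = Find(parent_[x]); }
  std::vector<int> parent_;
};
```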
+ for (auto e : aliases) { + Value value = e.getFirst(); + if (!value.getType().isa<BaseMemRefType>()) continue; + for (Value alias : e.getSecond()) { + assert(alias.getType().isa<BaseMemRefType>() && + "Expected aliases of memref to be memrefs."); + ecs_.unionSets(value, alias); + } + } } bool affine_maps_symbol_free_and_equal(ArrayRef<AffineMap> as, @@ -181,7 +189,7 @@ class BufferReuseAnalysis { void find_reuse_candiates(FuncOp &f, BufferAliasAnalysis &aliases) { Liveness liveness(f); - BufferSizeAnalysis size_equivalences(f); + BufferSizeAnalysis size_equivalences(f, aliases); f.walk([&](Block *block) { find_reuse_candiates(block, aliases, liveness.getLiveness(block), size_equivalences, f.getArguments()); @@ -204,50 +212,53 @@ class BufferReuseAnalysis { // Find reuse candidates for the regarded allocation. SmallVector<int64_t, 2> local_reuse_candidates; - for (auto it : llvm::enumerate(arguments)) { - int64_t old_buffer_index = it.index(); - Value old_buffer = it.value(); + for (BlockArgument old_buffer : arguments) { if (!old_buffer.getType().isa<BaseMemRefType>()) continue; - // Will not reuse buffers of different size as they may be too small. + // Size criterion: Do not reuse buffers of different size as they may be + // too small. if (!size_equivalences.is_same_size(new_buffer, old_buffer)) continue; - // Only reuse buffers that are no longer used on first reuse, i.e. they - // are no longer alive. - bool livetimes_compatible = true; + // Lifetime criterion: Only reuse buffers that are no longer used on + // first reuse, i.e. they are no longer alive. + bool lifetimes_compatible = true; for (Value old_buffer_alias : aliases.resolve(old_buffer)) { if (first_reuse == nullptr) { // If the first use is beyond the end of this block we look at the // block end. An argument buffer that is already reusable there is - // certainly reusable at any later actual use. + // certainly reusable at any later actual use. Otherwise, lifetimes + // are incompatible. if (liveness->isLiveOut(old_buffer_alias)) { - livetimes_compatible = false; + lifetimes_compatible = false; break; } } else { - // A buffer is *not* reusable if - // i) its last use is after the point of reuse, or - // ii) its last use is also its first reuse but the operation - // does not allow for local reuse. + // A buffer is reusable if + // i) its last use is before the point of reuse, or + // ii) its last use is also its first reuse and the operation + // allows for local reuse. + // Otherwise, lifetimes are incompatible. Operation *last_use = liveness->getEndOperation(old_buffer_alias, &block->front()); assert(last_use != nullptr && last_use->getBlock() == block && "Expected last use in same block."); if (first_reuse->isBeforeInBlock(last_use)) { - livetimes_compatible = false; + lifetimes_compatible = false; break; } if (first_reuse == last_use && !can_reuse_locally(first_reuse, old_buffer_alias, new_buffer)) { - livetimes_compatible = false; + lifetimes_compatible = false; break; } } } - // All criteria are fulfilled 🙂. - if (livetimes_compatible) + if (lifetimes_compatible) { + // All criteria are fulfilled 🙂. 
+ int64_t old_buffer_index = old_buffer.getArgNumber(); local_reuse_candidates.push_back(old_buffer_index); + } } reuse_candidates_[&op] = local_reuse_candidates; diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc index 46dbfbb8405..6199ea14b4b 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc @@ -48,40 +48,6 @@ namespace kernel_gen { namespace transforms { namespace { -#define GEN_PASS_CLASSES -#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc" - -struct HloBufferizePass : public HloBufferizePassBase<HloBufferizePass> { - // TODO(b/173201243): Move to tablegen. - void getDependentDialects(DialectRegistry& registry) const override { - registry.insert<lmhlo::LmhloDialect>(); - } - - public: - void runOnOperation() override { - OwningRewritePatternList patterns; - auto& context = getContext(); - ConversionTarget target(context); - target.addLegalDialect<lmhlo::LmhloDialect>(); - target.addLegalDialect<StandardOpsDialect>(); - target.addIllegalDialect<mhlo::MhloDialect>(); - - BufferizeTypeConverter converter; - // Configure bufferize pattern for functions and lhlo. - mhlo::populateHLOToLHLOConversionPattern(&context, &converter, &patterns); - - // Configure legality and structural patterns. - populateBufferizeMaterializationLegality(target); - populateShapeStructuralTypeConversionsAndLegality(&context, converter, - patterns, target); - scf::populateSCFStructuralTypeConversionsAndLegality(&context, converter, - patterns, target); - if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) - signalPassFailure(); - } -}; - // TODO(herhut) : This could become a real pattern in bufferize pass. What we // would need to do is insert a copy to model the semantics correctly. The same // is true for the TensorLoad pattern that is already in there. Then buffer @@ -104,11 +70,34 @@ class UnrankedTensorStoreTestOnlyPattern } }; -struct FinalBufferizePass : public FinalBufferizePassBase<FinalBufferizePass> { +// TODO(frgossen): Move this upstream to `populateFuncOpTypeConversionPattern` +// This pattern is merely needed to materialize type casts for return values so +// that they match the function signature conversion. +class ReturnOpTypeConversionPattern + : public OpConversionPattern<mlir::ReturnOp> { + public: + using OpConversionPattern<mlir::ReturnOp>::OpConversionPattern; + + LogicalResult matchAndRewrite( + ReturnOp op, ArrayRef<Value> operands, + ConversionPatternRewriter& rewriter) const final { + rewriter.replaceOpWithNewOp<ReturnOp>(op, operands); + return success(); + } +}; + +#define GEN_PASS_CLASSES +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc" + +struct BufferizePass : public BufferizePassBase<BufferizePass> { + explicit BufferizePass(bool allow_partial_bufferization) { + allow_partial_bufferization_ = allow_partial_bufferization; + } + // TODO(b/173201243): Move to tablegen. 
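The merged `BufferizePass` above differs from the old two-pass setup mainly in its legality rules: with `allow_partial_bufferization` the materialization ops (`TensorLoadOp`, `TensorToMemrefOp`) stay legal and only a partial conversion is applied, otherwise they are illegal and a full conversion must remove them. The sketch below captures that split as a plain predicate; the op-name strings are illustrative placeholders, not the dialect's exact spellings.

```cpp
#include <set>
#include <string>

// Legality split of the merged bufferize pass: some tensor ops are always
// rejected, while the bufferization materialization ops are rejected only
// when a full (non-partial) bufferization is requested.
bool IsLegalAfterBufferize(const std::string& op_name,
                           bool allow_partial_bufferization) {
  static const std::set<std::string> always_illegal = {
      "dynamic_tensor_from_elements", "extract_element",
      "tensor_from_elements", "tensor_cast"};
  if (always_illegal.count(op_name)) return false;
  if (!allow_partial_bufferization &&
      (op_name == "tensor_load" || op_name == "tensor_to_memref"))
    return false;
  return true;  // everything else is handled by dialect/type-based rules
}
```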
void getDependentDialects(DialectRegistry& registry) const override { registry.insert<AffineDialect, scf::SCFDialect, shape::ShapeDialect, - tf_framework::TFFrameworkDialect>(); + tf_framework::TFFrameworkDialect, lmhlo::LmhloDialect>(); } public: @@ -117,12 +106,17 @@ struct FinalBufferizePass : public FinalBufferizePassBase<FinalBufferizePass> { ConversionTarget target(context); target.addLegalDialect<scf::SCFDialect, StandardOpsDialect, tf_framework::TFFrameworkDialect, AffineDialect, - shape::ShapeDialect>(); + shape::ShapeDialect, lmhlo::LmhloDialect>(); target.addLegalOp<ModuleOp, ModuleTerminatorOp>(); + target.addIllegalDialect<mhlo::MhloDialect>(); target.addIllegalOp<DynamicTensorFromElementsOp, ExtractElementOp, - TensorFromElementsOp, TensorCastOp, TensorLoadOp, - TensorToMemrefOp>(); + TensorFromElementsOp, TensorCastOp>(); + + if (!allow_partial_bufferization_) { + target.addIllegalOp<TensorLoadOp, TensorToMemrefOp>(); + } + // Certain operations are no longer legal on tensors but otherwise are. target.addDynamicallyLegalOp<ConstantOp, SelectOp>([&](Operation* op) { return llvm::none_of(op->getResultTypes(), @@ -144,8 +138,8 @@ struct FinalBufferizePass : public FinalBufferizePassBase<FinalBufferizePass> { return converter.isLegal(inputs) && converter.isLegal(results) && converter.isLegal(&op.getBody()); }); - target.addDynamicallyLegalOp<CallOp, ConstantOp, DimOp, RankOp, SelectOp>( - typesAreLegal); + target.addDynamicallyLegalOp<CallOp, ConstantOp, DimOp, RankOp, SelectOp, + ReturnOp>(typesAreLegal); OwningRewritePatternList patterns; mhlo::populateHLOToLHLOConversionPattern(&context, &converter, &patterns); @@ -160,22 +154,27 @@ struct FinalBufferizePass : public FinalBufferizePassBase<FinalBufferizePass> { scf::populateSCFStructuralTypeConversionsAndLegality(&context, converter, patterns, target); patterns.insert<UnrankedTensorStoreTestOnlyPattern>(&context); + patterns.insert<ReturnOpTypeConversionPattern>(converter, &context); auto module = getOperation(); - if (failed(applyFullConversion(module, target, std::move(patterns)))) { - signalPassFailure(); + if (allow_partial_bufferization_) { + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + signalPassFailure(); + } + } else { + if (failed( + mlir::applyFullConversion(module, target, std::move(patterns)))) { + signalPassFailure(); + } } } }; } // namespace -std::unique_ptr<OperationPass<ModuleOp> > CreateHloBufferizePass() { - return std::make_unique<HloBufferizePass>(); -} - -std::unique_ptr<OperationPass<ModuleOp> > CreateFinalBufferizePass() { - return std::make_unique<FinalBufferizePass>(); +std::unique_ptr<OperationPass<ModuleOp> > CreateBufferizePass( + bool allow_partial_bufferization) { + return std::make_unique<BufferizePass>(allow_partial_bufferization); } } // namespace transforms diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_memref_prints.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_memref_prints.cc index 1d845f5ae5d..50d514eeb0a 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_memref_prints.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_memref_prints.cc @@ -37,9 +37,8 @@ namespace { using tf_framework::TFFrameworkDialect; -Operation* emitCallToFunc(Location loc, StringRef func_name, - ArrayRef<Type> return_types, ValueRange args, - OpBuilder* b) { +Operation* emitCallToPrint(Location loc, StringRef func_name, Value arg, + OpBuilder* b) { auto caller_func = 
b->getInsertionBlock()->getParent()->getParentOfType<FuncOp>(); auto callee_func = @@ -49,12 +48,12 @@ Operation* emitCallToFunc(Location loc, StringRef func_name, auto module = caller_func.getParentOfType<ModuleOp>(); b->setInsertionPointToStart(module.getBody()); - auto func_type = - FunctionType::get(args.getTypes(), return_types, b->getContext()); + auto func_type = FunctionType::get(arg.getType(), /*results=*/llvm::None, + b->getContext()); callee_func = b->create<FuncOp>(module.getLoc(), func_name, func_type); callee_func.setPrivate(); } - return b->create<CallOp>(loc, callee_func, args); + return b->create<CallOp>(loc, callee_func, arg); } void EmitPrint(Operation* op, Liveness& liveness, OpBuilder* b) { @@ -70,28 +69,32 @@ void EmitPrint(Operation* op, Liveness& liveness, OpBuilder* b) { liveness.getLiveness(op->getBlock())->getEndOperation(memref, op); b->setInsertionPoint(end_op); + if (element_type.isIndex()) { + element_type = b->getI64Type(); + memref_type = MemRefType::get(memref_type.getShape(), element_type, + memref_type.getAffineMaps(), + memref_type.getMemorySpace()); + memref = b->create<IndexCastOp>(loc, memref, memref_type); + } + auto unranked_type = UnrankedMemRefType::get(element_type, memref_type.getMemorySpace()); Value unranked_memref = b->create<MemRefCastOp>(loc, memref, unranked_type); if (element_type.isF32()) { - emitCallToFunc(loc, "print_memref_f32", {}, {unranked_memref}, b); + emitCallToPrint(loc, "print_memref_f32", unranked_memref, b); return; } if (element_type.isF64()) { - emitCallToFunc(loc, "print_memref_f64", {}, {unranked_memref}, b); - return; - } - if (element_type.isIndex()) { - emitCallToFunc(loc, "print_memref_index", {}, {unranked_memref}, b); + emitCallToPrint(loc, "print_memref_f64", unranked_memref, b); return; } if (element_type.isInteger(32)) { - emitCallToFunc(loc, "print_memref_i32", {}, {unranked_memref}, b); + emitCallToPrint(loc, "print_memref_i32", unranked_memref, b); return; } - if (element_type.isInteger(64)) { - emitCallToFunc(loc, "print_memref_i64", {}, {unranked_memref}, b); + if (element_type.isInteger(64) || element_type.isIndex()) { + emitCallToPrint(loc, "print_memref_i64", unranked_memref, b); return; } } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h index e2d92c3c8c3..c067d2b779d 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h @@ -47,13 +47,10 @@ std::unique_ptr<OperationPass<ModuleOp> > CreateTFKernelToLLVMPass(); // using memref descriptors. std::unique_ptr<OperationPass<ModuleOp> > CreateShapeToDescriptorsPass(); -// Pass to tranform hlo-level computations on values to their corresponding -// parts on buffers. -std::unique_ptr<OperationPass<ModuleOp>> CreateHloBufferizePass(); - -// Pass to tranform late-dialect level computations (essentially all non-hlo -// dialects) on values to their corresponding parts on buffers. -std::unique_ptr<OperationPass<ModuleOp>> CreateFinalBufferizePass(); +// Pass to transform operations on values to their corresponding parts on +// buffers. +std::unique_ptr<OperationPass<ModuleOp>> CreateBufferizePass( + bool allow_partial_bufferization = false); // Pass to materialize broadcasts.
std::unique_ptr<FunctionPass> CreateMaterializeBroadcastsPass(); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td index b6f8202217a..96e66297771 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td @@ -38,16 +38,14 @@ def ShapeToDescriptorsPass : Pass<"shape-to-descriptors", "ModuleOp"> { let constructor = "transforms::CreateShapeToDescriptorsPass()"; } -def HloBufferizePass : Pass<"hlo-bufferize", "ModuleOp"> { - let summary = "Pass to transform hlo operations on values to buffer based " - "ones."; - let constructor = "transforms::CreateHloBufferizePass()"; -} - -def FinalBufferizePass : Pass<"final-bufferize", "ModuleOp"> { - let summary = "Pass to transform operations from all non-hlo dialects on " - "values to buffer-based ones."; - let constructor = "transforms::CreateFinalBufferizePass()"; +def BufferizePass : Pass<"bufferize", "ModuleOp"> { + let summary = "Pass to transform operations on values to buffer-based ones."; + let options = [ + Option<"allow_partial_bufferization_", "allow-partial-bufferization", + "bool", /*default=*/"false", "Allow partial bufferization. " + "Value-based operations may remain, e.g. for shape operations.">, + ]; + let constructor = "transforms::CreateBufferizePass()"; } def MaterializeBroadcastsPass : FunctionPass<"materialize-broadcast"> { diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index c1e08284d21..5dc89adffc4 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -140,6 +140,7 @@ cc_library( ":translate_cl_options", "//tensorflow/compiler/mlir/hlo", "//tensorflow/compiler/mlir/hlo:lhlo", + "//tensorflow/compiler/mlir/hlo:lhlo_gpu", "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -148,6 +149,8 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", + "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt index 95595d850d2..395d4bb8f9f 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt @@ -63,8 +63,11 @@ HloModule SelectAndScatter ROOT %add.11 = f32[] add(f32[] %lhs.9, f32[] %rhs.10) } +// CHECK-LABEL: module +// CHECK: global_memref "private" constant @[[$GLOBAL:.*]] : memref<f32> = dense<0.000000e+00> // CHECK-LABEL: func @main -// CHECK: "lmhlo.select_and_scatter" +// CHECK: %[[GLOBAL_MEMREF:.*]] = get_global_memref @[[$GLOBAL]] : memref<f32> +// CHECK: "lmhlo.select_and_scatter"(%{{.*}}, %{{.*}}, %[[GLOBAL_MEMREF]], %{{.*}}) // CHECK: ^bb0(%[[ARG0:.*]]: tensor<f32>, %[[ARG1:.*]]: tensor<f32>): // CHECK: %[[COMPARE:.*]] = "mhlo.compare"(%[[ARG0]], %[[ARG1]]) {comparison_direction = "GE"} // CHECK: "mhlo.return"(%[[COMPARE]]) : (tensor<i1>) -> () @@ -78,7 +81,7 @@ HloModule SelectAndScatter ENTRY main () -> f32[6] { %operand = f32[6]{0} 
parameter(0) %source = f32[2]{0} parameter(1) - %init = f32[] parameter(2) + %init = f32[] constant(0) ROOT %select-and-scatter.12 = f32[6]{0} select-and-scatter(f32[6]{0} %operand, f32[2]{0} %source, f32[] %init), window={size=3 stride=3}, select=%ge_F32, scatter=%add_F32 } @@ -101,4 +104,20 @@ ENTRY main { s32[] %static), custom_call_target="SliceToDynamic", backend_config="" -} \ No newline at end of file +} + +// ----- + + +HloModule Cholesky + +// CHECK-LABEL: func @main +// CHECK: "lmhlo_gpu.cholesky" +// CHECK-SAME: is_lower = true +ENTRY main { + %param = f32[3,3] parameter(0) + ROOT %custom-call = (f32[3,3], f32[3], s32[]) custom-call(f32[3,3] %param), + custom_call_target="__cusolver$cholesky", + operand_layout_constraints={f32[3,3]}, + backend_config="{\"lower\":true}" +} diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir index e7312e2114c..cd72707123c 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir @@ -374,3 +374,23 @@ func @main(%arg0: tuple<tuple<tensor<f32>>, tensor<f32>>, %arg1: tuple<tensor<f3 return %result : tuple<tensor<f32>, tensor<f32>, tensor<f32>> } + +// ----- + +// CHECK-LABEL: func @main +// CHECK: "lmhlo.reduce"({{.*}}) ( { +// CHECK: ^bb0(%[[VAL1:.*]]: tensor<f32>, %[[VAL2:.*]]: tensor<i32>, %[[VAL3:.*]]: tensor<f32>, %[[VAL4:.*]]: tensor<i32>): // no predecessors +// CHECK: %[[VAL5:.*]] = mhlo.maximum %[[VAL1]], %[[VAL3]] : tensor<f32> +// CHECK: %[[VAL6:.*]] = mhlo.maximum %[[VAL2]], %[[VAL4:.*]] : tensor<i32> +// CHECK: %[[VAL7:.*]] = "mhlo.tuple"(%[[VAL5]], %[[VAL6:.*]]) : (tensor<f32>, tensor<i32>) -> tuple<tensor<f32>, tensor<i32>> +// CHECK: "mhlo.return"(%[[VAL7:.*]]) : (tuple<tensor<f32>, tensor<i32>>) -> () +// CHECK: }) +func @main(%arg0 : tensor<1x10xf32>, %arg1 : tensor<1x10xi32>, %arg2 : tensor<f32>, %arg3 : tensor<i32>) -> (tensor<1xf32>, tensor<1xi32>) { + %result0, %result1 = "mhlo.reduce"(%arg0, %arg1, %arg2, %arg3) ( { + ^bb0(%fa: tensor<f32>, %ia : tensor<i32>, %fb: tensor<f32>, %ib: tensor<i32>): // no predecessors + %fmax = "mhlo.maximum"(%fa, %fb) {} : (tensor<f32>, tensor<f32>) -> tensor<f32> + %imax = "mhlo.maximum"(%ia, %ib) {} : (tensor<i32>, tensor<i32>) -> tensor<i32> + "mhlo.return"(%fmax, %imax) : (tensor<f32>, tensor<i32>) -> () + }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<1x10xi32>, tensor<f32>, tensor<i32>) -> (tensor<1xf32>, tensor<1xi32>) + return %result0, %result1 : tensor<1xf32>, tensor<1xi32> +} diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir index 7f37dbb0479..f24b1b6e6bc 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir @@ -13,10 +13,8 @@ // CHECK-LABEL: func @add func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> { // CHECK-NEXT: %[[SUM0:.*]] = mhlo.add %arg0, %arg0 : tensor<2xi32> - // CHECK-NEXT: %[[SUM1:.*]] = mhlo.add %[[SUM0]], %arg0 : tensor<2xi32> - // CHECK-NEXT: return %[[SUM1]] : tensor<2xi32> - %0 = "tf.Add"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - %1 = "tf.AddV2"(%0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + // CHECK-NEXT: return %[[SUM0]] : tensor<2xi32> + %1 = "tf.AddV2"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> 
tensor<2xi32> return %1: tensor<2xi32> } @@ -27,7 +25,7 @@ func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> { func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { // CHECK-NEXT: %[[LHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-NEXT: mhlo.add %[[LHS_BCAST]], %arg1 - %0 = "tf.Add"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0: tensor<1x2xi32> } @@ -37,7 +35,7 @@ func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2x func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> { // CHECK-NEXT: %[[LHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} // CHECK-NEXT: mhlo.add %[[LHS_BCAST]], %arg1 - %0 = "tf.Add"(%arg0, %arg1) : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> + %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> return %0: tensor<4x4x4x4xi32> } @@ -55,7 +53,7 @@ func @add_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi3 // CHECK-NEXT: %[[RHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} // CHECK-NEXT: %[[RESULT:.+]] = mhlo.add %[[LHS_BCAST]], %[[RHS_BCAST]] : tensor<?x?xi32> // CHECK-NEXT: shape.assuming_yield %[[RESULT]] - %0 = "tf.Add"(%arg0, %arg1) : (tensor<?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32> + %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32> return %0: tensor<?x?xi32> } @@ -64,7 +62,7 @@ func @add_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi3 func @broadcast_add_unranked(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> { // CHECK: tf.Add // CHLO: chlo.broadcast_add %arg0, %arg1 - %0 = "tf.Add"(%arg0, %arg1) : (tensor<1xi32>, tensor<*xi32>) -> tensor<*xi32> + %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<1xi32>, tensor<*xi32>) -> tensor<*xi32> return %0: tensor<*xi32> } @@ -264,6 +262,13 @@ func @equal_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi1> return %0: tensor<*xi1> } +// CHECK-LABEL: func @equal_unsupported_type +func @equal_unsupported_type(%arg0: tensor<*x!tf.string>, %arg1: tensor<*x!tf.string>) -> tensor<*xi1> { + // CHECK: "tf.Equal" + %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<*x!tf.string>, tensor<*x!tf.string>) -> tensor<*xi1> + return %0: tensor<*xi1> +} + // CHECK-LABEL: func @notequal func @notequal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { // CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "NE"} diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-full-conversion.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-full-conversion.mlir index 0660af4ed1c..0646a391eca 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-full-conversion.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-full-conversion.mlir @@ -26,7 +26,7 @@ func @tf_unknown_op(%arg0: tensor<2xi32>) -> tensor<2xi32> { // ----- func @tf_known_op(%arg0: tensor<2xi32>) -> tensor<2xi32> { - %0 = "tf.Add"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %0 = "tf.AddV2"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> return %0: tensor<2xi32> } @@ -38,7 +38,7 @@ func 
@tf_unknown_known_mix(%arg0: tensor<2xi32>) -> tensor<2xi32> { // expected-error@+1 {{'tf.OpA' op is not legalizable}} %0 = "tf.OpA"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> %1 = "tf.OpB"(%0, %0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - %2 = "tf.Add"(%1, %1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %2 = "tf.AddV2"(%1, %1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> %3 = "tf.OpB"(%2, %2) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> return %2: tensor<2xi32> } diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index fd6ed5a0219..5061de06c6f 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -3984,31 +3984,35 @@ func @assert(%arg0: tensor<i1>, %arg1: tensor<*xf32>) { // tf.Unpack legalization //===----------------------------------------------------------------------===// -// TODO(b/156340000): Re-enable when fixed. -// // C-HECK-LABEL: @unpack -// func @unpack(%input: tensor<4x3x6xf32>) -> (tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32>) { -// // C-HECK: %[[SLICE1:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 1, 6]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> -// // C-HECK: %[[RES1:.*]] = "mhlo.reshape"(%[[SLICE1]]) : (tensor<4x1x6xf32>) -> tensor<4x?xf32> -// // C-HECK: %[[SLICE2:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 2, 6]> : tensor<3xi64>, start_indices = dense<[0, 1, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> -// // C-HECK: %[[RES2:.*]] = "mhlo.reshape"(%[[SLICE2]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32> -// // C-HECK: %[[SLICE3:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 3, 6]> : tensor<3xi64>, start_indices = dense<[0, 2, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> -// // C-HECK: %[[RES3:.*]] = "mhlo.reshape"(%[[SLICE3]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32> +// CHECK-LABEL: @unpack +func @unpack(%input: tensor<4x3x6xf32>) -> (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) { + // CHECK: %[[SLICE1:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 1, 6]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> + // CHECK: %[[RES1:.*]] = "mhlo.reshape"(%[[SLICE1]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[SLICE2:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 2, 6]> : tensor<3xi64>, start_indices = dense<[0, 1, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> + // CHECK: %[[RES2:.*]] = "mhlo.reshape"(%[[SLICE2]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[SLICE3:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 3, 6]> : tensor<3xi64>, start_indices = dense<[0, 2, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> + // CHECK: %[[RES3:.*]] = "mhlo.reshape"(%[[SLICE3]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32> -// %0:3 = "tf.Unpack"(%input) {axis = 1} : (tensor<4x3x6xf32>) -> (tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -// // return %[[RES1]], %[[RES2]], %[[RES3]] -// return %0#0, %0#1, %0#2 : tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32> -// } + %0:3 = "tf.Unpack"(%input) {axis = 1} : (tensor<4x3x6xf32>) 
-> (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) + // return %[[RES1]], %[[RES2]], %[[RES3]] + return %0#0, %0#1, %0#2 : tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32> +} -// // C-HECK-LABEL: @unpack_dynamic -// func @unpack_dynamic(%input: tensor<?x?x2xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) { -// // C-HECK: %[[SLICE1:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[-1, -1, 1]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<?x?x2xf32>) -> tensor<?x?x1xf32> -// // C-HECK: "mhlo.reshape"(%[[SLICE1]]) : (tensor<?x?x1xf32>) -> tensor<?x?xf32> -// // C-HECK: %[[SLICE2:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[-1, -1, 2]> : tensor<3xi64>, start_indices = dense<[0, 0, 1]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<?x?x2xf32>) -> tensor<?x?x1xf32> -// // C-HECK: "mhlo.reshape"(%[[SLICE2]]) : (tensor<?x?x1xf32>) -> tensor<?x?xf32> +// CHECK-LABEL: @unpack_dynamic +func @unpack_dynamic(%input: tensor<?x?x2xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) { -// %0:2 = "tf.Unpack"(%input) {axis = -1} : (tensor<?x?x2xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) -// return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xf32> -// } + // CHECK: tf.Unpack + %0:2 = "tf.Unpack"(%input) {axis = -1} : (tensor<?x?x2xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) + return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xf32> +} + +// CHECK-LABEL: @unpack_unranked +func @unpack_unranked(%input: tensor<*xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) { + + // CHECK: tf.Unpack + %0:2 = "tf.Unpack"(%input) {axis = -1} : (tensor<*xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) + return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xf32> +} //===----------------------------------------------------------------------===// // tf.UnsortedSegment{Max|Min|Prod|Sum} legalization diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index fe75a2a8dff..f30fdeb0991 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -3311,7 +3311,7 @@ class ConvertStridedSliceOp : public OpRewritePattern<TF::StridedSliceOp> { auto input_val = GetScalarConstOfType(begin_element_ty, loc, input_shape[d], &rewriter); auto wrapped_index = - rewriter.create<TF::AddOp>(loc, input_val, reshaped_index); + rewriter.create<TF::AddV2Op>(loc, input_val, reshaped_index); auto final_index = rewriter.create<SelectOp>( loc, type, index_negative, wrapped_index, reshaped_index); slice_begin_indices.push_back(final_index); @@ -4808,7 +4808,7 @@ class ConvertUnpackOp : public OpRewritePattern<TF::UnpackOp> { LogicalResult matchAndRewrite(TF::UnpackOp op, PatternRewriter &rewriter) const override { - auto value_type = op.value().getType().cast<RankedTensorType>(); + auto value_type = op.value().getType().dyn_cast<RankedTensorType>(); if (!value_type) return failure(); int64_t value_rank = value_type.getRank(); @@ -4820,7 +4820,7 @@ class ConvertUnpackOp : public OpRewritePattern<TF::UnpackOp> { auto end_indices = llvm::to_vector<4>(value_type.getShape()); SmallVector<int64_t, 4> strides(value_rank, 1); - // All HLO slice+reshape results used to replace the original tf.Unpack op. + // All HLO slice+squeeze results used to replace the original tf.Unpack op. 
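The new `tf.Unpack` lowering above emits one `mhlo.slice` per output and then a `tf.Squeeze` on the unpacked axis instead of a reshape. As a standalone illustration of what that computes for a rank-3, axis-1 unpack on row-major data (helper name and flat-buffer representation are invented for the example):

```cpp
#include <cstddef>
#include <vector>

// Unpack a row-major buffer of shape {d0, d1, d2} along axis 1: each output is
// a strided slice of the input with the unpacked dimension squeezed away,
// i.e. shape {d0, d2} rather than {d0, 1, d2}.
std::vector<std::vector<float>> UnpackAxis1(const std::vector<float>& input,
                                            size_t d0, size_t d1, size_t d2) {
  std::vector<std::vector<float>> outputs(d1);
  for (size_t i = 0; i < d1; ++i) {
    outputs[i].reserve(d0 * d2);
    for (size_t a = 0; a < d0; ++a)
      for (size_t c = 0; c < d2; ++c)
        outputs[i].push_back(input[(a * d1 + i) * d2 + c]);
  }
  return outputs;
}
```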
SmallVector<Value, 4> results; results.reserve(op.getNumResults()); @@ -4833,9 +4833,10 @@ class ConvertUnpackOp : public OpRewritePattern<TF::UnpackOp> { GetI64ElementsAttr(end_indices, &rewriter), GetI64ElementsAttr(strides, &rewriter)); // Reshape to drop the axis dimension. - auto reshape_op = rewriter.create<mhlo::ReshapeOp>( - op.getLoc(), op.getType(i), slice_op); - results.push_back(reshape_op); + auto result = + rewriter.create<TF::SqueezeOp>(op.getLoc(), op.getType(i), slice_op, + rewriter.getI64ArrayAttr(op.axis())); + results.push_back(result); } rewriter.replaceOp(op, results); diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index 11af809ffb7..113d88158d2 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -89,8 +89,7 @@ class DirectBinaryPat<Op FromOp, Op ToOp> : Pat<(FromOp AnyTensor:$l, AnyTensor:$r), (ToOp $l, $r, (BinBroadcastDimensions $l, $r))>; -foreach fromToBinPair = [[TF_AddOp, HLOClient_BroadcastAddOp], - [TF_AddV2Op, HLOClient_BroadcastAddOp], +foreach fromToBinPair = [[TF_AddV2Op, HLOClient_BroadcastAddOp], [TF_DivOp, HLOClient_BroadcastDivOp], [TF_LeftShiftOp, HLOClient_BroadcastShiftLeftOp], [TF_MaximumOp, HLOClient_BroadcastMaxOp], @@ -225,7 +224,7 @@ class EqualityPat<Op FromOp, StrEnumAttrCase direction> (HLOClient_BroadcastCompareOp $l, $r, (BinBroadcastDimensions $l, $r), direction, (HLO_DEFAULT_COMPARISON_TYPE)), - [(AreBroadcastCompatible $l, $r)]>; + [(AreBroadcastCompatible $l, $r), (HLO_Tensor $l)]>; def : EqualityPat<TF_EqualOp, HLO_COMPARISON_DIRECTION_EQ>; def : EqualityPat<TF_NotEqualOp, HLO_COMPARISON_DIRECTION_NE>; diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc index edc7edf8068..505bdcdcf48 100644 --- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h" +#include <climits> #include <memory> #include <tuple> @@ -38,6 +39,7 @@ limitations under the License. #include "mlir/Pass/PassOptions.h" // from @llvm-project #include "mlir/Translation.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h" #include "tensorflow/compiler/mlir/xla/hlo_utils.h" @@ -46,18 +48,21 @@ limitations under the License. 
#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" using xla::BufferAllocation; using xla::BufferAssignment; using xla::HloComputation; +using xla::HloCustomCallInstruction; using xla::HloInstruction; using xla::HloModule; using xla::HloModuleProto; @@ -140,8 +145,9 @@ Status ConvertModule(std::unique_ptr<HloModule> hlo_module, ModuleOp module, class XlaHloToLhloPass : public PassWrapper<XlaHloToLhloPass, OperationPass<ModuleOp>> { void getDependentDialects(DialectRegistry& registry) const override { - registry.insert<mlir::StandardOpsDialect, mlir::mhlo::MhloDialect, - mlir::lmhlo::LmhloDialect>(); + registry + .insert<mlir::StandardOpsDialect, mlir::mhlo::MhloDialect, + mlir::lmhlo::LmhloDialect, mlir::lmhlo_gpu::LmhloGpuDialect>(); } public: @@ -274,6 +280,10 @@ StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(HloInstruction* instr) { return EmitSelectAndScatterOp(instr); case HloOpcode::kCustomCall: return EmitCustomCallOp(instr); + case HloOpcode::kConstant: + return EmitConstant(instr); + case HloOpcode::kReduce: + return EmitReduceOp(instr); default: llvm::errs() << instr->ToString(); return tensorflow::errors::Internal( @@ -485,13 +495,18 @@ StatusOr<lmhlo::SelectAndScatterOp> LhloDialectEmitter::EmitSelectAndScatterOp( return select_and_scatter; } -StatusOr<lmhlo::CustomCallOp> LhloDialectEmitter::EmitCustomCallOp( +StatusOr<mlir::Operation*> LhloDialectEmitter::EmitCustomCallOp( HloInstruction* instr) { + auto* custom_call_instr = ::xla::Cast<::xla::HloCustomCallInstruction>(instr); + + if (xla::gpu::IsCustomCallToCusolver(*instr)) { + return EmitCholesky(custom_call_instr); + } + size_t num_arguments, num_results; TF_ASSIGN_OR_RETURN(auto custom_call, CreateOpWithoutAttrs<lmhlo::CustomCallOp>( instr, num_arguments, num_results)); - auto* custom_call_instr = ::xla::Cast<::xla::HloCustomCallInstruction>(instr); custom_call.call_target_nameAttr( builder_.getStringAttr(custom_call_instr->custom_call_target())); custom_call.backend_configAttr( @@ -500,12 +515,102 @@ StatusOr<lmhlo::CustomCallOp> LhloDialectEmitter::EmitCustomCallOp( static_cast<int32_t>(num_results)}; custom_call.setAttr(lmhlo::CustomCallOp::getOperandSegmentSizeAttr(), builder_.getI32VectorAttr(segments)); - return custom_call; + return custom_call.getOperation(); +} + +StatusOr<lmhlo_gpu::CholeskyOp> LhloDialectEmitter::EmitCholesky( + HloCustomCallInstruction* custom_call) { + TF_ASSIGN_OR_RETURN(auto cholesky_op, + CreateOpWithoutAttrs<lmhlo_gpu::CholeskyOp>(custom_call)); + TF_ASSIGN_OR_RETURN(xla::CholeskyOptions options, + custom_call->backend_config<xla::CholeskyOptions>()); + cholesky_op.is_lowerAttr(builder_.getBoolAttr(options.lower())); + return cholesky_op; +} + +// Convert an XLA HLO constant to a global_memref + get_global_memref pair. 
+StatusOr<mlir::GetGlobalMemrefOp> LhloDialectEmitter::EmitConstant( + const HloInstruction* instr) { + // Insert a global_memref in the module. + Location loc = getLocation(instr); + + auto const_instr = ::xla::Cast<::xla::HloConstantInstruction>(instr); + TF_RET_CHECK(const_instr->shape().IsArray() && + const_instr->shape().is_static()); + TF_ASSIGN_OR_RETURN(Type type, ::xla::ConvertShapeToType<MemRefType>( + const_instr->shape(), builder_)); + auto memref_type = type.dyn_cast<MemRefType>(); + TF_RET_CHECK(memref_type != nullptr); + + TF_ASSIGN_OR_RETURN( + DenseElementsAttr initial_value, + CreateDenseElementsAttrFromLiteral(const_instr->literal(), builder_)); + + std::string constant_name = ::xla::llvm_ir::ConstantHloToGlobalName(*instr); + + // Insert the global memref at the top level. + { + OpBuilder::InsertionGuard guard(builder_); + builder_.clearInsertionPoint(); + auto global_var = builder_.create<GlobalMemrefOp>( + loc, constant_name, builder_.getStringAttr("private"), + TypeAttr::get(memref_type), initial_value, true); + SymbolTable(module_).insert(global_var); + global_var.getOperation()->moveBefore(&module_.front()); + } + + auto get_global_memref = + builder_.create<GetGlobalMemrefOp>(loc, memref_type, constant_name); + + // For operations that do not fold this constant value in their codegen, we + // still need to materialize it into a buffer. Since buffer allocation is + // already done, annotate the get_global_memref with the information to get to + // the allocated buffer slice for this constant if need be. + + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice, + assignment_.GetUniqueTopLevelSlice(instr)); + get_global_memref.setAttr("lmhlo.alloc", + builder_.getIndexAttr(slice.index())); + get_global_memref.setAttr("lmhlo.slice_offset", + builder_.getI64IntegerAttr(slice.offset())); + get_global_memref.setAttr("lmhlo.slice_size", + builder_.getI64IntegerAttr(slice.size())); + + // Update the cache to remember this value. + auto& cached_value = slices_[std::make_pair(instr, ::xla::ShapeIndex())]; + TF_RET_CHECK(cached_value == nullptr); + cached_value = get_global_memref; + return get_global_memref; +} + +StatusOr<lmhlo::ReduceOp> LhloDialectEmitter::EmitReduceOp( + HloInstruction* instr) { + TF_ASSIGN_OR_RETURN(auto reduce_op, + CreateOpWithoutAttrs<lmhlo::ReduceOp>(instr)); + auto* reduce = ::xla::Cast<::xla::HloReduceInstruction>(instr); + std::vector<int64_t> dimensions(reduce->dimensions().begin(), + reduce->dimensions().end()); + reduce_op.dimensionsAttr(GetI64DenseElementsAttr(dimensions)); + TF_RETURN_IF_ERROR(::xla::HloFunctionImporter::ImportAsRegion( + *instr->called_computations()[0], &reduce_op.body(), &builder_)); + return reduce_op; } StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView( const ::xla::HloInstruction* instr, const ::xla::Shape& current_shape, const ::xla::ShapeIndex& shape_index) { + // Cache generated ViewOp and StaticMemRefCastOp by (instruction, + // shape_index). + auto& cached_value = slices_[std::make_pair(instr, shape_index)]; + if (cached_value) { + return cached_value; + } + + if (instr->IsConstant() && shape_index.empty()) { + TF_ASSIGN_OR_RETURN(Value constant_memref, EmitConstant(instr)); + return cached_value = constant_memref; + } + // If the shape happens to have dynamic dimensions, create the memref using // the underlying static shape. 
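The caching added to `GetOrCreateArrayView` above uses a common map-reference idiom: indexing `slices_` creates the slot up front, an early return reuses a previously emitted value, and `return cached_value = ...` both stores and returns the new one. A standalone sketch of the idiom, with string placeholders for the XLA/MLIR types:

```cpp
#include <map>
#include <string>
#include <utility>

// operator[] default-constructs the cache slot; an empty string stands in for
// a null mlir::Value, and the key pair stands in for (HloInstruction*,
// ShapeIndex).
std::string GetOrCreateView(
    std::map<std::pair<std::string, int>, std::string>& cache,
    const std::string& instr, int shape_index) {
  auto& cached_value = cache[{instr, shape_index}];
  if (!cached_value.empty()) return cached_value;
  return cached_value = "view_of_" + instr;  // stand-in for emitting the op
}
```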
// TODO(jurahul): Revisit this when we can model memrefs with dynamic shape @@ -518,7 +623,7 @@ StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView( assignment_.GetUniqueSlice(instr, shape_index)); Value alloc = allocations_[slice.allocation()]; if (alloc.getType() == out_type && slice.offset() == 0) { - return alloc; + return cached_value = alloc; } auto out_memref_type = out_type.dyn_cast<MemRefType>(); @@ -527,13 +632,6 @@ StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView( "Expected memref type when creating a view for leaf type of a " "tuple."); - // Cache generated ViewOp and StaticMemRefCastOp by (instruction, - // shape_index). - auto& cached_value = slices_[std::make_pair(instr, shape_index)]; - if (cached_value) { - return cached_value; - } - Value byte_shift = builder_.create<ConstantIndexOp>(alloc.getLoc(), slice.offset()); @@ -695,8 +793,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createXlaHloToLhloWithXlaPass() { Status HloToLhloModule(const BufferAssignment& assignment, const HloModule& hlo_module, ModuleOp module) { module.getContext() - ->loadDialect<StandardOpsDialect, mhlo::MhloDialect, - lmhlo::LmhloDialect>(); + ->loadDialect<StandardOpsDialect, mhlo::MhloDialect, lmhlo::LmhloDialect, + lmhlo_gpu::LmhloGpuDialect>(); HloComputation* computation = hlo_module.entry_computation(); LhloDialectEmitter emitter(assignment, *computation, module); diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h index 29a4c67d366..169cbdc93c5 100644 --- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h +++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h @@ -16,13 +16,15 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_ #define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_ +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" namespace mlir { @@ -55,8 +57,18 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault { ::xla::StatusOr<lmhlo::ScatterOp> EmitScatterOp(::xla::HloInstruction* instr); ::xla::StatusOr<lmhlo::SelectAndScatterOp> EmitSelectAndScatterOp( ::xla::HloInstruction* instr); - ::xla::StatusOr<lmhlo::CustomCallOp> EmitCustomCallOp( - ::xla::HloInstruction* instr); + + ::xla::StatusOr<Operation*> EmitCustomCallOp(::xla::HloInstruction* instr); + ::xla::StatusOr<lmhlo_gpu::CholeskyOp> EmitCholesky( + ::xla::HloCustomCallInstruction* custom_call); + + ::xla::StatusOr<lmhlo::ReduceOp> EmitReduceOp(::xla::HloInstruction* instr); + ::xla::StatusOr<GetGlobalMemrefOp> EmitConstant( + ::xla::HloInstruction* instr) { + return EmitConstant(static_cast<const ::xla::HloInstruction*>(instr)); + } + ::xla::StatusOr<GetGlobalMemrefOp> EmitConstant( + const ::xla::HloInstruction* instr); ::xla::Status CreateOperands(::xla::HloInstruction* 
instr, SmallVectorImpl<Value>& operands, @@ -122,7 +134,7 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault { OpBuilder* b, Location loc); // Return an MLIR location for an HLO instruction. - Location getLocation(::xla::HloInstruction* inst) { + Location getLocation(const ::xla::HloInstruction* inst) { return NameLoc::get(builder_.getIdentifier(inst->name()), builder_.getContext()); } @@ -180,7 +192,7 @@ tensorflow::Status HloToLhloModule(const ::xla::BufferAssignment& assignment, ModuleOp module); OwningModuleRef HloTextToLhloTranslateFunction(llvm::StringRef input, - mlir::MLIRContext* context); + MLIRContext* context); } // namespace mlir diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index 99d7730bfe5..db631b9c5bb 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -85,6 +85,9 @@ namespace tensorflow { namespace tensorrt { namespace convert { +using absl::StrAppend; +using absl::StrCat; + bool IsEngineInput(absl::string_view name) { return absl::StartsWith(name, IONamePrefixes::kInputPHName); } @@ -92,47 +95,6 @@ bool IsEngineOutput(absl::string_view name) { return absl::StartsWith(name, IONamePrefixes::kOutputPHName); } -using absl::StrAppend; -using absl::StrCat; - -inline Status TfDataTypeToTrt(DataType tf_dtype, - nvinfer1::DataType* trt_dtype) { - switch (tf_dtype) { - case DataType::DT_FLOAT: - *trt_dtype = nvinfer1::DataType::kFLOAT; - break; - case DataType::DT_HALF: - *trt_dtype = nvinfer1::DataType::kHALF; - break; - case DataType::DT_INT32: - *trt_dtype = nvinfer1::DataType::kINT32; - break; - default: - return errors::InvalidArgument("Unsupported data type ", - DataTypeString(tf_dtype)); - } - return Status::OK(); -} - -inline Status TrtDataTypeToTf(nvinfer1::DataType trt_dtype, - DataType* tf_dtype) { - switch (trt_dtype) { - case nvinfer1::DataType::kFLOAT: - *tf_dtype = DataType::DT_FLOAT; - break; - case nvinfer1::DataType::kHALF: - *tf_dtype = DataType::DT_HALF; - break; - case nvinfer1::DataType::kINT32: - *tf_dtype = DataType::DT_INT32; - break; - default: - return errors::InvalidArgument("Unsupported data type ", - DebugString(trt_dtype)); - } - return Status::OK(); -} - class TFAttrs { public: explicit TFAttrs(const NodeDef& tf_node) { @@ -182,7 +144,7 @@ std::vector<float> TFAttrs::get<std::vector<float>>(const string& key) const { template <> nvinfer1::DataType TFAttrs::get<nvinfer1::DataType>(const string& key) const { nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT); - TF_CHECK_OK(TfDataTypeToTrt(this->at(key)->type(), &trt_dtype)); + TF_CHECK_OK(TfTypeToTrtType(this->at(key)->type(), &trt_dtype)); return trt_dtype; } @@ -271,7 +233,7 @@ Status ValidateTensorProperties(const string& producer_node_type, nvinfer1::DataType* trt_dtype, nvinfer1::Dims* trt_dims, int* batch_size) { // Convert data type. - TF_RETURN_IF_ERROR(TfDataTypeToTrt(dtype, trt_dtype)); + TF_RETURN_IF_ERROR(TfTypeToTrtType(dtype, trt_dtype)); // Convert shape. if (shape.dims() < 0) { @@ -512,7 +474,7 @@ Status CreateBroadcastableScalarConstant(OpConverterParams* params, float value, TFAttrs attrs(params->node_def); if (attrs.count(dtype_attr_name)) { DataType dtype = attrs.get<DataType>(dtype_attr_name); - TF_RETURN_IF_ERROR(TfDataTypeToTrt(dtype, &trt_type)); + TF_RETURN_IF_ERROR(TfTypeToTrtType(dtype, &trt_type)); } // In order to be broadcastable, the number of dims has to match. 
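For reference, the convert_nodes.cc call sites in this patch all move to the shared, Status-returning dtype helpers rather than the deleted file-local TfDataTypeToTrt/TrtDataTypeToTf. A minimal sketch of the resulting pattern follows; it assumes TfTypeToTrtType is declared in tensorflow/compiler/tf2tensorrt/convert/utils.h, and the wrapper name is hypothetical, for illustration only.

#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {
namespace tensorrt {

// Hypothetical wrapper (illustration only): converts a TF dtype to the
// corresponding TensorRT dtype, propagating the InvalidArgument error that the
// shared helper now emits for unsupported types.
Status ExampleGetTrtDType(DataType tf_dtype, nvinfer1::DataType* trt_dtype) {
  TF_RETURN_IF_ERROR(TfTypeToTrtType(tf_dtype, trt_dtype));
  return Status::OK();
}

}  // namespace tensorrt
}  // namespace tensorflow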
@@ -1091,7 +1053,7 @@ TRT_ShapedWeights TrtWeightStore::GetTempWeights(nvinfer1::DataType trt_dtype, DataType tf_dtype; // TODO(laigd): make it return a status. TF_CHECK_OK(TensorShapeUtils::MakeShape(dims.d, dims.nbDims, &shape)); - TF_CHECK_OK(TrtDataTypeToTf(trt_dtype, &tf_dtype)); + TF_CHECK_OK(TrtTypeToTfType(trt_dtype, &tf_dtype)); // TODO(jie): check weights size_bytes. 0 means type error Tensor tensor(tf_dtype, shape); TRT_ShapedWeights weights(trt_dtype, dims, tensor); @@ -2621,6 +2583,7 @@ Status Converter::DynamicReshape(nvinfer1::ITensor* input, nvinfer1::IConcatenationLayer* concat_layer = network()->addConcatenation( const_cast<nvinfer1::ITensor* const*>(concat_inputs.data()), concat_inputs.size()); + SetLayerName(concat_layer, params->node_def, "concat", op_instance); concat_layer->setAxis(0); nvinfer1::ITensor* new_shape = concat_layer->getOutput(0); // Reshape input using new shape @@ -4219,7 +4182,7 @@ Status TfTensorToTrtWeights(const Tensor& tensor, TrtWeightStore* weight_store, // Verify that the dtype is supported by TensorRT. Otherwise, return an error. nvinfer1::DataType trt_dtype; - TF_RETURN_IF_ERROR(TfDataTypeToTrt(converted_dtype, &trt_dtype)); + TF_RETURN_IF_ERROR(TfTypeToTrtType(converted_dtype, &trt_dtype)); if (tensor.NumElements() == 0) { // Return empty weights. @@ -6244,7 +6207,7 @@ Status ConvertGraphDefToEngine( TFAttrs attrs(node_def); DataType tf_dtype = attrs.get<DataType>("T"); nvinfer1::DataType trt_dtype; - TF_RETURN_IF_ERROR(TfDataTypeToTrt(tf_dtype, &trt_dtype)); + TF_RETURN_IF_ERROR(TfTypeToTrtType(tf_dtype, &trt_dtype)); if (output_tensors.size() <= slot_number) { output_tensors.resize(slot_number + 1); } diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index a33d5c28cb2..1d60ebbd047 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -135,20 +135,6 @@ std::ostream& operator<<(std::ostream& os, const std::vector<T>& v) { return os; } -nvinfer1::DataType TfDataTypeToTrt(DataType tf_type) { - nvinfer1::DataType trt_type; - Status status = TfTypeToTrtType(tf_type, &trt_type); - EXPECT_EQ(status, Status::OK()); - return trt_type; -} - -DataType TrtDataTypeToTf(nvinfer1::DataType trt_type) { - DataType tf_type; - Status status = TrtTypeToTfType(trt_type, &tf_type); - EXPECT_EQ(status, Status::OK()); - return tf_type; -} - NodeDef MakeNodeDef(const string& name, const string& op, const std::vector<string>& inputs, const std::map<string, AttrValue> attrs = {}) { @@ -1048,8 +1034,10 @@ TEST_F(ConverterTest, AddAndGetTensorOrWeights) { template <typename T> void TestGetWeightRange(ConverterTest* test, TrtWeightStore* weight_store) { - TRT_ShapedWeights weights = weight_store->GetTempWeights( - TfDataTypeToTrt(DataTypeToEnum<T>::v()), GetTestDims({2, 3})); + nvinfer1::DataType trt_type; + TF_ASSERT_OK(TfTypeToTrtType(DataTypeToEnum<T>::v(), &trt_type)); + TRT_ShapedWeights weights = + weight_store->GetTempWeights(trt_type, GetTestDims({2, 3})); const std::vector<T> values = {T(3), T(1), T(2), T(6), T(5), T(4)}; memcpy(weights.GetValues(), values.data(), weights.size_bytes()); @@ -1445,7 +1433,8 @@ class OpConverterTest : public ::testing::Test { ASSERT_NE(-1, input_index); const nvinfer1::DataType trt_dtype = engine_->getBindingDataType(input_index); - const DataType tf_type = TrtDataTypeToTf(trt_dtype); + DataType tf_type; + TF_ASSERT_OK(TrtTypeToTfType(trt_dtype, 
&tf_type)); ASSERT_EQ(data.tensor.dtype(), tf_type) << DataTypeString(data.tensor.dtype()) << " vs. " << DataTypeString(tf_type); @@ -1457,8 +1446,9 @@ class OpConverterTest : public ::testing::Test { // Mark the output tensor as TRT engine output. std::vector<Converter::EngineOutputInfo> output_info; for (const auto& data : *output_data) { - output_info.push_back( - {data.name, data.name, TfDataTypeToTrt(data.tensor.dtype())}); + nvinfer1::DataType trt_type; + TF_RETURN_IF_ERROR(TfTypeToTrtType(data.tensor.dtype(), &trt_type)); + output_info.push_back({data.name, data.name, trt_type}); } TF_RETURN_IF_ERROR(converter_->RenameAndMarkOutputTensors(output_info)); @@ -1519,7 +1509,8 @@ class OpConverterTest : public ::testing::Test { const string& name, const std::vector<int32>& dims, nvinfer1::DataType trt_type = nvinfer1::DataType::kFLOAT, Status add_input_status = Status::OK()) { - DataType tf_type = TrtDataTypeToTf(trt_type); + DataType tf_type; + TF_ASSERT_OK(TrtTypeToTfType(trt_type, &tf_type)); ops::Placeholder::Attrs attrs; TF_EXPECT_OK(TensorShapeUtils::MakeShape(dims, &attrs.shape_)); @@ -1566,7 +1557,8 @@ class OpConverterTest : public ::testing::Test { node_inputs_[name] = ops::Const(scope_.WithOpName(name), t); // Add weights for conversion. - const nvinfer1::DataType dtype = TfDataTypeToTrt(DataTypeToEnum<T>::v()); + nvinfer1::DataType dtype; + TF_ASSERT_OK(TfTypeToTrtType(DataTypeToEnum<T>::v(), &dtype)); const nvinfer1::Dims trt_dims = GetTestDims(dims); const int64_t num_elements = TrtWeightDimsNumElements(trt_dims); QCHECK_EQ(num_elements, values.size()) @@ -1800,8 +1792,9 @@ class ParameterizedOpConverterTestBase partial_shape = dims; } } - AddTestTensorWithTFDims(name, partial_shape, TfDataTypeToTrt(tf_type), - add_input_status); + nvinfer1::DataType trt_type; + TF_ASSERT_OK(TfTypeToTrtType(tf_type, &trt_type)); + AddTestTensorWithTFDims(name, partial_shape, trt_type, add_input_status); if (!values.empty()) { VLOG(2) << "Adding test tensor: " << name << " " << DataTypeString(tf_type); @@ -2032,7 +2025,7 @@ TEST_F(OpConverterTest, ConvertConst) { Reset(); NodeDef node_def = MakeConstNodeDef<double>("my_const", {}); RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, - "Unsupported data type double"); + "Unsupported tensorflow data type double"); } { Reset(); @@ -2805,8 +2798,9 @@ void TestAddN(OpConverterTest* test) { test->Reset(); DataVec input_data; for (const auto name : {"inp1", "inp2", "inp3"}) { - test->AddTestTensor(name, /*dims=*/{1, 2}, /*batch_size=*/2, - TfDataTypeToTrt(dtype)); + nvinfer1::DataType trt_type; + TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type)); + test->AddTestTensor(name, /*dims=*/{1, 2}, /*batch_size=*/2, trt_type); input_data.push_back({name, test->AsTensor<CType>({CType(1), CType(2), CType(3), CType(4)})}); } @@ -2828,8 +2822,9 @@ void TestAddN(OpConverterTest* test) { test->Reset(); DataVec input_data; for (const auto name : {"inp1", "inp2"}) { - test->AddTestTensor(name, /*dims=*/{1, 2}, /*batch_size=*/1, - TfDataTypeToTrt(dtype)); + nvinfer1::DataType trt_type; + TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type)); + test->AddTestTensor(name, /*dims=*/{1, 2}, /*batch_size=*/1, trt_type); input_data.push_back({name, test->AsTensor<CType>({CType(1), CType(2)})}); } test->AddTestWeights("inp3", /*dims=*/{1, 1, 2}, @@ -4252,8 +4247,9 @@ TEST_P(OpConverterTest1, ConvertConv2D) { Reset(); NodeDef node_def = get_conv2d_nodedef(); // Channel dim unknown, should fail. 
- AddTestTensorWithTFDims("input", {-1, -1, -1, -1}, - TfDataTypeToTrt(tf_type_)); + nvinfer1::DataType trt_type; + TF_ASSERT_OK(TfTypeToTrtType(tf_type_, &trt_type)); + AddTestTensorWithTFDims("input", {-1, -1, -1, -1}, trt_type); AddTestWeights<float>("weights", {1, 2, 1, 1}, {-1, 1}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, @@ -5018,8 +5014,9 @@ TEST_F(OpConverterTest, ConvertTopK) { { // K is a tensor, should fail. Reset(); - AddTestTensor("input", {1, 2, 3}, /*batch_size=*/1, - /*trt_dtype=*/TfDataTypeToTrt(dtype)); + nvinfer1::DataType trt_type; + TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type)); + AddTestTensor("input", {1, 2, 3}, /*batch_size=*/1, trt_type); AddTestTensor("weights", {2}); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, @@ -5590,8 +5587,10 @@ void TestConvertConcat(OpConverterTest* test) { NodeDef node_def = get_concat_nodedef(dtype, num_inputs); // Create inputs. for (int j = 0; j < num_inputs; ++j) { + nvinfer1::DataType trt_type; + TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type)); test->AddTestTensor(StrCat("values_", j), ok_params[i].input_shapes[j], 1, - TfDataTypeToTrt(dtype)); + trt_type); } test->AddTestWeights<int32>("axis", {1}, {ok_params[i].axis}); test->RunValidationAndConversion(node_def); @@ -5752,8 +5751,9 @@ void TestConvertSplit(OpConverterTest* test) { NodeDef node_def = get_split_nodedef(dtype, ok_params[i].num_split); // Create inputs. test->AddTestWeights<int32>("axis", {1}, {ok_params[i].axis}); - test->AddTestTensor("value", ok_params[i].input_shape, 1, - TfDataTypeToTrt(dtype)); + nvinfer1::DataType trt_type; + TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type)); + test->AddTestTensor("value", ok_params[i].input_shape, 1, trt_type); // Convert. test->RunValidationAndConversion(node_def); @@ -5929,8 +5929,9 @@ void TestConvertUnpack(OpConverterTest* test) { NodeDef node_def = get_unpack_nodedef(dtype, ok_params[i].num, ok_params[i].axis); // Create inputs. - test->AddTestTensor("value", ok_params[i].input_shape, 1, - TfDataTypeToTrt(dtype)); + nvinfer1::DataType trt_type; + TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type)); + test->AddTestTensor("value", ok_params[i].input_shape, 1, trt_type); // Convert. test->RunValidationAndConversion(node_def); @@ -6272,8 +6273,10 @@ void TestConvertArgMinMax(OpConverterTest* test) { NodeDef node_def = GetArgMinMaxNodeDef<OpType>(dtype, DT_INT32); // Create inputs. 
+ nvinfer1::DataType trt_type; + TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type)); test->AddTestTensor("input", params[i].input_shape, /*batch_size=*/1, - /*trt_dtype=*/TfDataTypeToTrt(dtype)); + /*trt_dtype=*/trt_type); test->AddTestWeights<int32>("dimension", {1}, {params[i].axis}); test->RunValidationAndConversion(node_def); @@ -6374,8 +6377,9 @@ void TestConvertDepthSpaceShuffle( NodeDef node_def = GetDepthSpaceShuffleNodeDef<OpType>( dtype, params[i].block_size, params[i].data_format); - test->AddTestTensor("input", params[i].input_dims, 1, - TfDataTypeToTrt(dtype)); + nvinfer1::DataType trt_type; + TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type)); + test->AddTestTensor("input", params[i].input_dims, 1, trt_type); test->RunValidationAndConversion(node_def); TRT_TensorOrWeights output; @@ -6648,7 +6652,9 @@ void TestConvertClipByValue(OpConverterTest* test) { test->Reset(); NodeDef node_def = GetClipByValueNodeDef(dtype); - test->AddTestTensor("t", params[i].dims, 1, TfDataTypeToTrt(dtype)); + nvinfer1::DataType trt_type; + TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type)); + test->AddTestTensor("t", params[i].dims, 1, trt_type); test->AddTestWeights<CType>("clip_value_min", {1}, {params[i].clip_value_min}); test->AddTestWeights<CType>("clip_value_max", {1}, @@ -6848,9 +6854,11 @@ void TestConvertResize(OpConverterTest* test) { // Create resize node. NodeDef node_def = MakeResizeNodeDef<OpType>("my_resize", dtype, params[i].align_corners); + nvinfer1::DataType trt_type; + TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type)); // Create input tensor test->AddTestTensor("input", params[i].input_dims, /*batch_size=*/1, - /*trt_dtype=*/TfDataTypeToTrt(dtype)); + /*trt_dtype=*/trt_type); // Create output size. test->AddTestWeights<int32>("size", {2}, params[i].output_resize_dims); @@ -6949,8 +6957,10 @@ void TestConvertPad(OpConverterTest* test) { // Create pad node. NodeDef node_def = MakePadNodeDef("my_pad", dtype); // Create input tensor + nvinfer1::DataType trt_type; + TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type)); test->AddTestTensor("input", params[i].input_dims, /*batch_size=*/1, - /*trt_dtype=*/TfDataTypeToTrt(dtype)); + /*trt_dtype=*/trt_type); // Create output size. 
test->AddTestWeights<int32>("padding", params[i].pad_dims, {0, 0, 1, 0, 0, 1, 0, 0}); diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.cc b/tensorflow/compiler/tf2tensorrt/convert/utils.cc index 1fc0d13c993..650aff5836f 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/utils.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/utils.cc @@ -198,7 +198,8 @@ Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type) { *trt_type = nvinfer1::DataType::kINT32; break; default: - return errors::Internal("Unsupported tensorflow type"); + return errors::InvalidArgument("Unsupported tensorflow data type ", + DataTypeString(tf_type)); } return Status::OK(); } @@ -215,7 +216,7 @@ Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type) { *tf_type = DT_INT32; break; default: - return errors::Internal("Invalid TRT type"); + return errors::InvalidArgument("Invalid TRT data type"); } return Status::OK(); } diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc index e0a731e502e..2d56209a068 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc @@ -103,13 +103,16 @@ class TRTEngineOp : public AsyncOpKernel { TRTEngineCacheResource* cache_res, AsyncHelper* helper); - // Construct a function handle for executing native funcdef graph - // These are the exact same function. + // Constructs a function handle for the segment of the TRTEngineOp. + StatusOr<FunctionLibraryRuntime::Handle> ConstructFunctionHandle( + FunctionLibraryRuntime* lib, const string& device_name, + bool allow_soft_placement = false, size_t num_inputs = 0, + size_t num_outputs = 0); - Status ConstructFunctionHandle(FunctionLibraryRuntime* lib, - const string& device_name, - bool allow_soft_placement = false, - size_t num_inputs = 0, size_t num_outputs = 0); + // Imports the GraphDef for the segment of the TRTEngineOp to + // segment_graph_def_. + Status ImportSegmentGraphDef(FunctionLibraryRuntime* lib, + const string& device_name); // Executes replaced native segment as function Op. void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper); @@ -175,12 +178,16 @@ class TRTEngineOp : public AsyncOpKernel { // Whether to build TensorRT engines at runtime. bool allow_build_at_runtime_; + // Whether to allow soft placement when the graph is executed with native + // TensorFlow. + bool allow_soft_placement_; + // Maximum number of cached engines. int max_cached_engines_; int64 workspace_size_; mutex engine_mutex_; - FunctionLibraryRuntime::Handle func_handle_; + FunctionLibraryRuntime::Handle native_execution_func_handle_; // The finalized calibrator for inference. 
std::unique_ptr<TRTInt8Calibrator> calibrator_; @@ -260,11 +267,9 @@ static Status FunctionDefToGraphDef(FunctionLibraryRuntime::Handle handle, return Status::OK(); } -Status TRTEngineOp::ConstructFunctionHandle(FunctionLibraryRuntime* lib, - const string& device_name, - bool allow_soft_placement, - size_t num_inputs, - size_t num_outputs) { +StatusOr<FunctionLibraryRuntime::Handle> TRTEngineOp::ConstructFunctionHandle( + FunctionLibraryRuntime* lib, const string& device_name, + bool allow_soft_placement, size_t num_inputs, size_t num_outputs) { VLOG(1) << "Constructing function handle"; if (lib == nullptr) { return errors::Internal("Context function library is null"); @@ -298,8 +303,20 @@ Status TRTEngineOp::ConstructFunctionHandle(FunctionLibraryRuntime* lib, inst_ops.config_proto.set_allow_soft_placement(true); } } - return lib->Instantiate(func_.name(), AttrSlice(&func_.attr()), inst_ops, - &func_handle_); + FunctionLibraryRuntime::Handle func_handle; + Status status = lib->Instantiate(func_.name(), AttrSlice(&func_.attr()), + inst_ops, &func_handle); + if (status.ok()) { + return func_handle; + } + return status; +} + +Status TRTEngineOp::ImportSegmentGraphDef(FunctionLibraryRuntime* lib, + const string& device_name) { + TF_ASSIGN_OR_RETURN(FunctionLibraryRuntime::Handle func_handle, + ConstructFunctionHandle(lib, device_name)); + return FunctionDefToGraphDef(func_handle, lib, &segment_graph_def_); } TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) @@ -335,14 +352,21 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) << context->device()->name() << ", thus setting _allow_build_at_runtime=true"; allow_build_at_runtime_ = true; + } else { + OP_REQUIRES_OK(context, status); } - func_handle_ = kInvalidHandle; + + status = context->GetAttr("_allow_soft_placement", &allow_soft_placement_); + if (status.code() == tensorflow::error::NOT_FOUND) { + allow_soft_placement_ = true; + } else { + OP_REQUIRES_OK(context, status); + } + + native_execution_func_handle_ = kInvalidHandle; if (!static_engine_) { - FunctionLibraryRuntime* lib = context->function_library(); - OP_REQUIRES_OK(context, - ConstructFunctionHandle(lib, context->device()->name())); - OP_REQUIRES_OK( - context, FunctionDefToGraphDef(func_handle_, lib, &segment_graph_def_)); + OP_REQUIRES_OK(context, ImportSegmentGraphDef(context->function_library(), + context->device()->name())); } // TODO(laigd): calibration_data is used in TF v1.x and we keep it only for // backward compatibility reasons. 
Remove it once all known users switch to @@ -411,13 +435,13 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper) { std::vector<Tensor> inputs; std::vector<Tensor>* outputs = new std::vector<Tensor>(); - if (func_handle_ == kInvalidHandle) { - OP_REQUIRES_OK_ASYNC( - ctx, + if (native_execution_func_handle_ == kInvalidHandle) { + StatusOr<FunctionLibraryRuntime::Handle> status_or_handle = ConstructFunctionHandle(ctx->function_library(), ctx->device()->name(), - /*allow_soft_placement=*/true, - ctx->num_inputs(), ctx->num_outputs()), - *helper); + allow_soft_placement_, ctx->num_inputs(), + ctx->num_outputs()); + OP_REQUIRES_OK_ASYNC(ctx, status_or_handle.status(), *helper); + native_execution_func_handle_ = status_or_handle.ValueOrDie(); } auto lib = ctx->function_library(); FunctionLibraryRuntime::Options opts; @@ -430,7 +454,7 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx, } helper->Ref(); // Increment count for calculating native graph VLOG(1) << "Executing native segment: " << name(); - lib->Run(opts, func_handle_, inputs, outputs, + lib->Run(opts, native_execution_func_handle_, inputs, outputs, [this, ctx, outputs, helper](const Status& s) { core::ScopedUnref sc(helper); OP_REQUIRES_OK_ASYNC(ctx, s, *helper); @@ -854,12 +878,8 @@ StatusOr<std::pair<EngineContext*, int>> TRTEngineOp::GetEngine( return std::pair<EngineContext*, int>(&empty_context, 0); } if (segment_graph_def_.node().empty()) { - FunctionLibraryRuntime* lib = ctx->function_library(); - auto status = ConstructFunctionHandle(lib, ctx->device()->name()); - if (status.ok()) { - status = - FunctionDefToGraphDef(func_handle_, lib, &segment_graph_def_); - } + Status status = ImportSegmentGraphDef(ctx->function_library(), + ctx->device()->name()); if (!status.ok()) { LOG_FIRST_FEW_WARNING_WITH_PREFIX << "Getting segment graph for " << name() << " failed. " diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc index 71193dc24cf..e7beb589f7b 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc @@ -91,6 +91,13 @@ class TRTEngineOpTestBase : public OpsTestBase { OpsTestBase::SetDevice(DEVICE_GPU, std::move(device)); NameAttrList function; function.set_name(StrCat(op_name, "_native_segment")); + // We disable allow_soft_placement when executing the native segment of the + // TRTEngineOp for the following reasons: + // OpsTestBase only allows one device in the device manager. + // We need to define the GPU device to test TRTEngineOp. + // When allow_soft_placement is true, the TensorFlow runtime produces an + // error if a CPU device is not defined + // (see ProcessFunctionLibraryRuntime::InstantiateMultiDevice).
TF_ASSERT_OK(NodeDefBuilder(op_name, "TRTEngineOp") .Input(FakeInput(1, dtype)) .Attr("input_shapes", {shape}) @@ -105,6 +112,7 @@ class TRTEngineOpTestBase : public OpsTestBase { .Attr("use_calibration", false) .Attr("_use_implicit_batch", use_implicit_batch) .Attr("_allow_build_at_runtime", allow_build_at_runtime) + .Attr("_allow_soft_placement", false) .Attr("OutT", {dtype}) .Finalize(OpsTestBase::node_def())); TF_ASSERT_OK(InitOpWithFunctionLibrary()); diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h index 350c198ee70..ba3dea924de 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h @@ -52,7 +52,11 @@ class MlirBridgeV1CompatPass : public MlirV1CompatOptimizationPass { bool IsEnabled(const ConfigProto& config_proto, const Graph& graph) const override { - return IsMlirBridgePassEnabled(graph, config_proto); + // Do not run the bridge if it's enabled by the graph analysis, + // only run if it's enabled by the user explicitly. + MlirBridgeRolloutPolicy policy = + GetMlirBridgeRolloutPolicy(graph, config_proto); + return policy == MlirBridgeRolloutPolicy::kEnabledByUser; } // This should be used as a thin mapper around mlir::ModulePass::runOnModule diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 9883a3f47c5..d44f9991936 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -361,6 +361,7 @@ tf_cc_test( ":util", ":xla_data_proto_cc", "//tensorflow/core:lib", + "//tensorflow/core:test", "//tensorflow/core:test_main", "@com_google_absl//absl/strings", ], diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 28f76829f14..171afa42351 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -234,6 +234,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_proto_cc", "//tensorflow/compiler/xla/service:shape_inference", "//tensorflow/core:lib", + "//tensorflow/stream_executor/lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index 92d222f32b2..c797f58274c 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -59,7 +59,6 @@ cc_library( srcs = ["comparators.cc"], hdrs = [ "comparators.h", - "//tensorflow/compiler/xla:literal_util", ], deps = [ ":constants", diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 1043528e752..33506535ddf 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/comparison_util.h" #include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -46,6 +47,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace xla { @@ -1266,7 +1268,8 @@ StatusOr<XlaOp> XlaBuilder::GetTupleElementInternal(const Shape& shape, } XlaOp XlaBuilder::Dot(XlaOp lhs, XlaOp rhs, - const PrecisionConfig* precision_config) { + const PrecisionConfig* precision_config, + absl::optional<PrimitiveType> preferred_element_type) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs)); @@ -1278,15 +1281,17 @@ XlaOp XlaBuilder::Dot(XlaOp lhs, XlaOp rhs, }); } -XlaOp XlaBuilder::DotGeneral(XlaOp lhs, XlaOp rhs, - const DotDimensionNumbers& dimension_numbers, - const PrecisionConfig* precision_config) { +XlaOp XlaBuilder::DotGeneral( + XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig* precision_config, + absl::optional<PrimitiveType> preferred_element_type) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs)); TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs)); - TF_ASSIGN_OR_RETURN(Shape shape, - ShapeInference::InferDotOpShape(*lhs_shape, *rhs_shape, - dimension_numbers)); + TF_ASSIGN_OR_RETURN( + Shape shape, + ShapeInference::InferDotOpShape( + *lhs_shape, *rhs_shape, dimension_numbers, preferred_element_type)); return DotGeneralInternal(shape, lhs, rhs, dimension_numbers, precision_config); }); @@ -1353,28 +1358,33 @@ Status XlaBuilder::VerifyConvolution( XlaOp XlaBuilder::Conv(XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, Padding padding, int64 feature_group_count, int64 batch_group_count, - const PrecisionConfig* precision_config) { + const PrecisionConfig* precision_config, + absl::optional<PrimitiveType> preferred_element_type) { return ConvWithGeneralDimensions( lhs, rhs, window_strides, padding, CreateDefaultConvDimensionNumbers(window_strides.size()), - feature_group_count, batch_group_count, precision_config); + feature_group_count, batch_group_count, precision_config, + preferred_element_type); } XlaOp XlaBuilder::ConvWithGeneralPadding( XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, absl::Span<const std::pair<int64, int64>> padding, int64 feature_group_count, int64 batch_group_count, - const PrecisionConfig* precision_config) { + const PrecisionConfig* precision_config, + absl::optional<PrimitiveType> preferred_element_type) { return ConvGeneral(lhs, rhs, window_strides, padding, CreateDefaultConvDimensionNumbers(window_strides.size()), - feature_group_count, batch_group_count, precision_config); + feature_group_count, batch_group_count, precision_config, + preferred_element_type); } XlaOp XlaBuilder::ConvWithGeneralDimensions( XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, int64 batch_group_count, - const PrecisionConfig* precision_config) { + const PrecisionConfig* precision_config, + absl::optional<PrimitiveType> preferred_element_type) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs)); TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs)); @@ -1402,7 +1412,8 @@ XlaOp XlaBuilder::ConvWithGeneralDimensions( MakePadding(base_area_dimensions, window_dimensions, window_strides, padding), dimension_numbers, feature_group_count, 
- batch_group_count, precision_config); + batch_group_count, precision_config, + preferred_element_type); }); } @@ -1411,10 +1422,12 @@ XlaOp XlaBuilder::ConvGeneral( absl::Span<const std::pair<int64, int64>> padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, int64 batch_group_count, - const PrecisionConfig* precision_config) { + const PrecisionConfig* precision_config, + absl::optional<PrimitiveType> preferred_element_type) { return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {}, dimension_numbers, feature_group_count, - batch_group_count, precision_config); + batch_group_count, precision_config, + preferred_element_type); } XlaOp XlaBuilder::ConvGeneralDilated( @@ -1423,7 +1436,8 @@ XlaOp XlaBuilder::ConvGeneralDilated( absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, int64 batch_group_count, - const PrecisionConfig* precision_config) { + const PrecisionConfig* precision_config, + absl::optional<PrimitiveType> preferred_element_type) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs)); TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs)); @@ -1442,10 +1456,11 @@ XlaOp XlaBuilder::ConvGeneralDilated( ShapeInference::InferWindowFromDimensions( window_dimensions, window_strides, padding, lhs_dilation, rhs_dilation)); - TF_ASSIGN_OR_RETURN(Shape shape, - ShapeInference::InferConvolveShape( - *lhs_shape, *rhs_shape, feature_group_count, - batch_group_count, window, dimension_numbers)); + TF_ASSIGN_OR_RETURN( + Shape shape, + ShapeInference::InferConvolveShape( + *lhs_shape, *rhs_shape, feature_group_count, batch_group_count, + window, dimension_numbers, preferred_element_type)); return ConvGeneralDilatedInternal(shape, lhs, rhs, window, window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count, @@ -1459,7 +1474,8 @@ StatusOr<HloInstructionProto> XlaBuilder::DynamicConvInstruction( absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, int64 batch_group_count, - const PrecisionConfig* precision_config, PaddingType padding_type) { + const PrecisionConfig* precision_config, PaddingType padding_type, + absl::optional<PrimitiveType> preferred_element_type) { TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs)); TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs)); std::vector<int64> window_dimensions( @@ -1472,10 +1488,11 @@ StatusOr<HloInstructionProto> XlaBuilder::DynamicConvInstruction( TF_ASSIGN_OR_RETURN(Window window, ShapeInference::InferWindowFromDimensions( window_dimensions, window_strides, padding, lhs_dilation, rhs_dilation)); - TF_ASSIGN_OR_RETURN(Shape shape, - ShapeInference::InferConvolveShape( - *lhs_shape, *rhs_shape, feature_group_count, - batch_group_count, window, dimension_numbers)); + TF_ASSIGN_OR_RETURN( + Shape shape, + ShapeInference::InferConvolveShape( + *lhs_shape, *rhs_shape, feature_group_count, batch_group_count, + window, dimension_numbers, preferred_element_type)); HloInstructionProto instr; *instr.mutable_shape() = shape.ToProto(); @@ -1499,14 +1516,15 @@ XlaOp XlaBuilder::DynamicConvInputGrad( absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, int64 batch_group_count, - const 
PrecisionConfig* precision_config, PaddingType padding_type) { + const PrecisionConfig* precision_config, PaddingType padding_type, + absl::optional<PrimitiveType> preferred_element_type) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { TF_ASSIGN_OR_RETURN( HloInstructionProto instr, - DynamicConvInstruction(lhs, rhs, window_strides, padding, lhs_dilation, - rhs_dilation, dimension_numbers, - feature_group_count, batch_group_count, - precision_config, padding_type)); + DynamicConvInstruction( + lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, + dimension_numbers, feature_group_count, batch_group_count, + precision_config, padding_type, preferred_element_type)); instr.set_custom_call_target("DynamicConvolutionInputGrad"); @@ -1521,14 +1539,16 @@ XlaOp XlaBuilder::DynamicConvKernelGrad( absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, int64 batch_group_count, - const PrecisionConfig* precision_config, PaddingType padding_type) { + const PrecisionConfig* precision_config, PaddingType padding_type, + absl::optional<PrimitiveType> preferred_element_type) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { TF_ASSIGN_OR_RETURN( HloInstructionProto instr, DynamicConvInstruction(activations, gradients, window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count, batch_group_count, - precision_config, padding_type)); + precision_config, padding_type, + preferred_element_type)); instr.set_custom_call_target("DynamicConvolutionKernelGrad"); // The gradient of kernel has kernel shape and shouldn't have any dynamic @@ -1545,14 +1565,15 @@ XlaOp XlaBuilder::DynamicConvForward( absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, int64 batch_group_count, - const PrecisionConfig* precision_config, PaddingType padding_type) { + const PrecisionConfig* precision_config, PaddingType padding_type, + absl::optional<PrimitiveType> preferred_element_type) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { TF_ASSIGN_OR_RETURN( HloInstructionProto instr, - DynamicConvInstruction(lhs, rhs, window_strides, padding, lhs_dilation, - rhs_dilation, dimension_numbers, - feature_group_count, batch_group_count, - precision_config, padding_type)); + DynamicConvInstruction( + lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, + dimension_numbers, feature_group_count, batch_group_count, + precision_config, padding_type, preferred_element_type)); instr.set_custom_call_target("DynamicConvolutionForward"); return AddInstruction(std::move(instr), HloOpcode::kCustomCall, {lhs, rhs}); @@ -3331,6 +3352,11 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) { } if (!need_rewrite) { + if (opcode == HloOpcode::kCompare) { + CHECK(!instr_proto->comparison_type().empty()); + new_instr->set_comparison_type( + ComparisonTypeToString(Comparison::DefaultComparisonType(PRED))); + } *new_instr->mutable_name() = GetFullName(instr_proto->opcode(), kNameSeparator, id); return Status::OK(); @@ -3990,11 +4016,26 @@ XlaOp Eq(const XlaOp lhs, const XlaOp rhs, return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq); } +static XlaOp CompareTotalOrder(const XlaOp lhs, const XlaOp rhs, + absl::Span<const int64> broadcast_dimensions, + ComparisonDirection comparison_direction) { + auto b = lhs.builder(); + return 
b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { + TF_ASSIGN_OR_RETURN(auto operand_shape, b->GetShape(lhs)); + auto operand_element_type = operand_shape.element_type(); + auto compare_type = + primitive_util::IsFloatingPointType(operand_element_type) + ? Comparison::Type::kFloatTotalOrder + : Comparison::DefaultComparisonType(operand_element_type); + return Compare(lhs, rhs, broadcast_dimensions, comparison_direction, + compare_type); + }); +} + XlaOp EqTotalOrder(const XlaOp lhs, const XlaOp rhs, absl::Span<const int64> broadcast_dimensions) { - auto compare_type = Comparison::Type::kFloatTotalOrder; - return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq, - compare_type); + return CompareTotalOrder(lhs, rhs, broadcast_dimensions, + ComparisonDirection::kEq); } XlaOp Ne(const XlaOp lhs, const XlaOp rhs, @@ -4004,9 +4045,8 @@ XlaOp Ne(const XlaOp lhs, const XlaOp rhs, XlaOp NeTotalOrder(const XlaOp lhs, const XlaOp rhs, absl::Span<const int64> broadcast_dimensions) { - auto compare_type = Comparison::Type::kFloatTotalOrder; - return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kNe, - compare_type); + return CompareTotalOrder(lhs, rhs, broadcast_dimensions, + ComparisonDirection::kNe); } XlaOp Ge(const XlaOp lhs, const XlaOp rhs, @@ -4016,9 +4056,8 @@ XlaOp Ge(const XlaOp lhs, const XlaOp rhs, XlaOp GeTotalOrder(const XlaOp lhs, const XlaOp rhs, absl::Span<const int64> broadcast_dimensions) { - auto compare_type = Comparison::Type::kFloatTotalOrder; - return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGe, - compare_type); + return CompareTotalOrder(lhs, rhs, broadcast_dimensions, + ComparisonDirection::kGe); } XlaOp Gt(const XlaOp lhs, const XlaOp rhs, @@ -4028,9 +4067,8 @@ XlaOp Gt(const XlaOp lhs, const XlaOp rhs, XlaOp GtTotalOrder(const XlaOp lhs, const XlaOp rhs, absl::Span<const int64> broadcast_dimensions) { - auto compare_type = Comparison::Type::kFloatTotalOrder; - return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGt, - compare_type); + return CompareTotalOrder(lhs, rhs, broadcast_dimensions, + ComparisonDirection::kGt); } XlaOp Le(const XlaOp lhs, const XlaOp rhs, @@ -4040,10 +4078,10 @@ XlaOp Le(const XlaOp lhs, const XlaOp rhs, XlaOp LeTotalOrder(const XlaOp lhs, const XlaOp rhs, absl::Span<const int64> broadcast_dimensions) { - auto compare_type = Comparison::Type::kFloatTotalOrder; - return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLe, - compare_type); + return CompareTotalOrder(lhs, rhs, broadcast_dimensions, + ComparisonDirection::kLe); } + XlaOp Lt(const XlaOp lhs, const XlaOp rhs, absl::Span<const int64> broadcast_dimensions) { return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt); @@ -4051,8 +4089,8 @@ XlaOp Lt(const XlaOp lhs, const XlaOp rhs, XlaOp LtTotalOrder(const XlaOp lhs, const XlaOp rhs, absl::Span<const int64> broadcast_dimensions) { - return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt, - Comparison::Type::kFloatTotalOrder); + return CompareTotalOrder(lhs, rhs, broadcast_dimensions, + ComparisonDirection::kLt); } XlaOp Compare(const XlaOp lhs, const XlaOp rhs, @@ -4074,44 +4112,49 @@ XlaOp Compare(const XlaOp lhs, const XlaOp rhs, ComparisonDirection direction) { } XlaOp Dot(const XlaOp lhs, const XlaOp rhs, - const PrecisionConfig* precision_config) { - return lhs.builder()->Dot(lhs, rhs, precision_config); + const PrecisionConfig* precision_config, + absl::optional<PrimitiveType> preferred_element_type) { + return 
lhs.builder()->Dot(lhs, rhs, precision_config, preferred_element_type); } XlaOp DotGeneral(const XlaOp lhs, const XlaOp rhs, const DotDimensionNumbers& dimension_numbers, - const PrecisionConfig* precision_config) { + const PrecisionConfig* precision_config, + absl::optional<PrimitiveType> preferred_element_type) { return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers, - precision_config); + precision_config, preferred_element_type); } XlaOp Conv(const XlaOp lhs, const XlaOp rhs, absl::Span<const int64> window_strides, Padding padding, int64 feature_group_count, int64 batch_group_count, - const PrecisionConfig* precision_config) { + const PrecisionConfig* precision_config, + absl::optional<PrimitiveType> preferred_element_type) { return lhs.builder()->Conv(lhs, rhs, window_strides, padding, feature_group_count, batch_group_count, - precision_config); + precision_config, preferred_element_type); } -XlaOp ConvWithGeneralPadding(const XlaOp lhs, const XlaOp rhs, - absl::Span<const int64> window_strides, - absl::Span<const std::pair<int64, int64>> padding, - int64 feature_group_count, int64 batch_group_count, - const PrecisionConfig* precision_config) { +XlaOp ConvWithGeneralPadding( + const XlaOp lhs, const XlaOp rhs, absl::Span<const int64> window_strides, + absl::Span<const std::pair<int64, int64>> padding, + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config, + absl::optional<PrimitiveType> preferred_element_type) { return lhs.builder()->ConvWithGeneralPadding( lhs, rhs, window_strides, padding, feature_group_count, batch_group_count, - precision_config); + precision_config, preferred_element_type); } XlaOp ConvWithGeneralDimensions( const XlaOp lhs, const XlaOp rhs, absl::Span<const int64> window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, int64 batch_group_count, - const PrecisionConfig* precision_config) { + const PrecisionConfig* precision_config, + absl::optional<PrimitiveType> preferred_element_type) { return lhs.builder()->ConvWithGeneralDimensions( lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count, - batch_group_count, precision_config); + batch_group_count, precision_config, preferred_element_type); } XlaOp ConvGeneral(const XlaOp lhs, const XlaOp rhs, @@ -4119,10 +4162,11 @@ XlaOp ConvGeneral(const XlaOp lhs, const XlaOp rhs, absl::Span<const std::pair<int64, int64>> padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, int64 batch_group_count, - const PrecisionConfig* precision_config) { - return lhs.builder()->ConvGeneral(lhs, rhs, window_strides, padding, - dimension_numbers, feature_group_count, - batch_group_count, precision_config); + const PrecisionConfig* precision_config, + absl::optional<PrimitiveType> preferred_element_type) { + return lhs.builder()->ConvGeneral( + lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count, + batch_group_count, precision_config, preferred_element_type); } XlaOp ConvGeneralDilated(const XlaOp lhs, const XlaOp rhs, @@ -4132,26 +4176,27 @@ XlaOp ConvGeneralDilated(const XlaOp lhs, const XlaOp rhs, absl::Span<const int64> rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, int64 batch_group_count, - const PrecisionConfig* precision_config) { + const PrecisionConfig* precision_config, + absl::optional<PrimitiveType> preferred_element_type) { return lhs.builder()->ConvGeneralDilated( lhs, rhs, window_strides, 
padding, lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count, batch_group_count, - precision_config); + precision_config, preferred_element_type); } -XlaOp DynamicConvInputGrad(XlaOp input_sizes, const XlaOp lhs, const XlaOp rhs, - absl::Span<const int64> window_strides, - absl::Span<const std::pair<int64, int64>> padding, - absl::Span<const int64> lhs_dilation, - absl::Span<const int64> rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, int64 batch_group_count, - const PrecisionConfig* precision_config, - PaddingType padding_type) { +XlaOp DynamicConvInputGrad( + XlaOp input_sizes, const XlaOp lhs, const XlaOp rhs, + absl::Span<const int64> window_strides, + absl::Span<const std::pair<int64, int64>> padding, + absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config, PaddingType padding_type, + absl::optional<PrimitiveType> preferred_element_type) { return lhs.builder()->DynamicConvInputGrad( input_sizes, lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count, batch_group_count, - precision_config, padding_type); + precision_config, padding_type, preferred_element_type); } XlaOp DynamicConvKernelGrad( @@ -4160,11 +4205,12 @@ XlaOp DynamicConvKernelGrad( absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, int64 batch_group_count, - const PrecisionConfig* precision_config, PaddingType padding_type) { + const PrecisionConfig* precision_config, PaddingType padding_type, + absl::optional<PrimitiveType> preferred_element_type) { return activations.builder()->DynamicConvKernelGrad( activations, gradients, window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count, batch_group_count, - precision_config, padding_type); + precision_config, padding_type, preferred_element_type); } XlaOp DynamicConvForward(const XlaOp lhs, const XlaOp rhs, @@ -4175,11 +4221,12 @@ XlaOp DynamicConvForward(const XlaOp lhs, const XlaOp rhs, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, int64 batch_group_count, const PrecisionConfig* precision_config, - PaddingType padding_type) { + PaddingType padding_type, + absl::optional<PrimitiveType> preferred_element_type) { return lhs.builder()->DynamicConvForward( lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count, batch_group_count, - precision_config, padding_type); + precision_config, padding_type, preferred_element_type); } XlaOp Fft(const XlaOp operand, FftType fft_type, diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index e54ed12c655..9e6a46f77b7 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -521,56 +521,63 @@ class XlaBuilder { XlaOp tuple_data, int64 index); - XlaOp Dot(XlaOp lhs, XlaOp rhs, - const PrecisionConfig* precision_config = nullptr); + XlaOp Dot( + XlaOp lhs, XlaOp rhs, const PrecisionConfig* precision_config = nullptr, + absl::optional<PrimitiveType> preferred_element_type = absl::nullopt); - XlaOp DotGeneral(XlaOp lhs, XlaOp rhs, - const DotDimensionNumbers& dimension_numbers, - const PrecisionConfig* precision_config = nullptr); + XlaOp 
DotGeneral( + XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig* precision_config = nullptr, + absl::optional<PrimitiveType> preferred_element_type = absl::nullopt); - XlaOp Conv(XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, - Padding padding, int64 feature_group_count = 1, - int64 batch_group_count = 1, - const PrecisionConfig* precision_config = nullptr); + XlaOp Conv( + XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, + Padding padding, int64 feature_group_count = 1, + int64 batch_group_count = 1, + const PrecisionConfig* precision_config = nullptr, + absl::optional<PrimitiveType> preferred_element_type = absl::nullopt); XlaOp ConvWithGeneralPadding( XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, absl::Span<const std::pair<int64, int64>> padding, int64 feature_group_count = 1, int64 batch_group_count = 1, - const PrecisionConfig* precision_config = nullptr); + const PrecisionConfig* precision_config = nullptr, + absl::optional<PrimitiveType> preferred_element_type = absl::nullopt); XlaOp ConvWithGeneralDimensions( XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count = 1, int64 batch_group_count = 1, - const PrecisionConfig* precision_config = nullptr); + const PrecisionConfig* precision_config = nullptr, + absl::optional<PrimitiveType> preferred_element_type = absl::nullopt); - XlaOp ConvGeneral(XlaOp lhs, XlaOp rhs, - absl::Span<const int64> window_strides, - absl::Span<const std::pair<int64, int64>> padding, - const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1, int64 batch_group_count = 1, - const PrecisionConfig* precision_config = nullptr); + XlaOp ConvGeneral( + XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, + absl::Span<const std::pair<int64, int64>> padding, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1, int64 batch_group_count = 1, + const PrecisionConfig* precision_config = nullptr, + absl::optional<PrimitiveType> preferred_element_type = absl::nullopt); - XlaOp ConvGeneralDilated(XlaOp lhs, XlaOp rhs, - absl::Span<const int64> window_strides, - absl::Span<const std::pair<int64, int64>> padding, - absl::Span<const int64> lhs_dilation, - absl::Span<const int64> rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1, - int64 batch_group_count = 1, - const PrecisionConfig* precision_config = nullptr); + XlaOp ConvGeneralDilated( + XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, + absl::Span<const std::pair<int64, int64>> padding, + absl::Span<const int64> lhs_dilation, + absl::Span<const int64> rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1, int64 batch_group_count = 1, + const PrecisionConfig* precision_config = nullptr, + absl::optional<PrimitiveType> preferred_element_type = absl::nullopt); - XlaOp DynamicConvForward(XlaOp lhs, XlaOp rhs, - absl::Span<const int64> window_strides, - absl::Span<const std::pair<int64, int64>> padding, - absl::Span<const int64> lhs_dilation, - absl::Span<const int64> rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, int64 batch_group_count, - const PrecisionConfig* precision_config, - PaddingType padding_type); + XlaOp DynamicConvForward( + XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, + 
absl::Span<const std::pair<int64, int64>> padding, + absl::Span<const int64> lhs_dilation, + absl::Span<const int64> rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config, PaddingType padding_type, + absl::optional<PrimitiveType> preferred_element_type = absl::nullopt); XlaOp DynamicConvInputGrad( XlaOp input_sizes, XlaOp lhs, XlaOp rhs, @@ -580,7 +587,8 @@ class XlaBuilder { absl::Span<const int64> rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, int64 batch_group_count, - const PrecisionConfig* precision_config, PaddingType padding_type); + const PrecisionConfig* precision_config, PaddingType padding_type, + absl::optional<PrimitiveType> preferred_element_type = absl::nullopt); XlaOp DynamicConvKernelGrad( XlaOp activations, XlaOp gradients, @@ -590,7 +598,8 @@ class XlaBuilder { absl::Span<const int64> rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, int64 batch_group_count, - const PrecisionConfig* precision_config, PaddingType padding_type); + const PrecisionConfig* precision_config, PaddingType padding_type, + absl::optional<PrimitiveType> preferred_element_type = absl::nullopt); StatusOr<HloInstructionProto> DynamicConvInstruction( XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, @@ -599,7 +608,8 @@ class XlaBuilder { absl::Span<const int64> rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, int64 batch_group_count, - const PrecisionConfig* precision_config, PaddingType padding_type); + const PrecisionConfig* precision_config, PaddingType padding_type, + absl::optional<PrimitiveType> preferred_element_type = absl::nullopt); virtual StatusOr<XlaOp> ConvGeneralDilatedInternal( const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window, @@ -1098,10 +1108,12 @@ class XlaBuilder { ComparisonDirection direction, Comparison::Type compare_type); friend XlaOp Dot(XlaOp lhs, XlaOp rhs, - const PrecisionConfig* precision_config); + const PrecisionConfig* precision_config, + absl::optional<PrimitiveType> preferred_element_type); friend XlaOp DotGeneral(XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_number, - const PrecisionConfig* precision_config); + const PrecisionConfig* precision_config, + absl::optional<PrimitiveType> preferred_element_type); virtual StatusOr<XlaOp> DotGeneralInternal( const Shape& shape, XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_number, @@ -1109,23 +1121,27 @@ class XlaBuilder { friend XlaOp Conv(XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, Padding padding, int64 feature_group_count, int64 batch_group_count, - const PrecisionConfig* precision_config); + const PrecisionConfig* precision_config, + absl::optional<PrimitiveType> preferred_element_type); friend XlaOp ConvWithGeneralPadding( XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, absl::Span<const std::pair<int64, int64>> padding, int64 feature_group_count, int64 batch_group_count, - const PrecisionConfig* precision_config); + const PrecisionConfig* precision_config, + absl::optional<PrimitiveType> preferred_element_type); friend XlaOp ConvWithGeneralDimensions( XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, int64 batch_group_count, - const PrecisionConfig* precision_confige); - friend 
XlaOp ConvGeneral(XlaOp lhs, XlaOp rhs, - absl::Span<const int64> window_strides, - absl::Span<const std::pair<int64, int64>> padding, - const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, int64 batch_group_count, - const PrecisionConfig* precision_config); + const PrecisionConfig* precision_config, + absl::optional<PrimitiveType> preferred_element_type); + friend XlaOp ConvGeneral( + XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, + absl::Span<const std::pair<int64, int64>> padding, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config, + absl::optional<PrimitiveType> preferred_element_type); friend XlaOp DynamicConvForward( XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, absl::Span<const std::pair<int64, int64>> padding, @@ -1133,7 +1149,8 @@ class XlaBuilder { absl::Span<const int64> rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, int64 batch_group_count, - const PrecisionConfig* precision_config, PaddingType padding_type); + const PrecisionConfig* precision_config, PaddingType padding_type, + absl::optional<PrimitiveType> preferred_element_type); friend XlaOp DynamicConvKernelGrad( XlaOp activations, XlaOp gradients, absl::Span<const int64> window_strides, @@ -1142,7 +1159,8 @@ class XlaBuilder { absl::Span<const int64> rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, int64 batch_group_count, - const PrecisionConfig* precision_config, PaddingType padding_type); + const PrecisionConfig* precision_config, PaddingType padding_type, + absl::optional<PrimitiveType> preferred_element_type); friend XlaOp DynamicConvInputGrad( XlaOp input_sizes, XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, @@ -1151,7 +1169,8 @@ class XlaBuilder { absl::Span<const int64> rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, int64 batch_group_count, - const PrecisionConfig* precision_config, PaddingType padding_type); + const PrecisionConfig* precision_config, PaddingType padding_type, + absl::optional<PrimitiveType> preferred_element_type); friend XlaOp ConvKernelGrad( XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, @@ -1160,7 +1179,8 @@ class XlaBuilder { absl::Span<const int64> rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, int64 batch_group_count, - const PrecisionConfig* precision_config); + const PrecisionConfig* precision_config, + absl::optional<PrimitiveType> preferred_element_type); friend XlaOp ConvGeneralDilated( XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, @@ -1169,7 +1189,8 @@ class XlaBuilder { absl::Span<const int64> rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, int64 batch_group_count, - const PrecisionConfig* precision_config); + const PrecisionConfig* precision_config, + absl::optional<PrimitiveType> preferred_element_type); friend XlaOp Fft(XlaOp operand, FftType fft_type, absl::Span<const int64> fft_length); friend XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, @@ -1813,28 +1834,31 @@ XlaOp Compare(XlaOp lhs, XlaOp rhs, ComparisonDirection direction); // Enqueues a dot instruction onto the computation. 
XlaOp Dot(XlaOp lhs, XlaOp rhs, - const PrecisionConfig* precision_config = nullptr); + const PrecisionConfig* precision_config = nullptr, + absl::optional<PrimitiveType> preferred_element_type = absl::nullopt); // Enqueues a general dot instruction onto the computation. -XlaOp DotGeneral(XlaOp lhs, XlaOp rhs, - const DotDimensionNumbers& dimension_numbers, - const PrecisionConfig* precision_config = nullptr); +XlaOp DotGeneral( + XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig* precision_config = nullptr, + absl::optional<PrimitiveType> preferred_element_type = absl::nullopt); // Enqueues a convolution instruction onto the computation, which uses the // default convolution dimension numbers. -XlaOp Conv(XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, - Padding padding, int64 feature_group_count = 1, - int64 batch_group_count = 1, - const PrecisionConfig* precision_config = nullptr); +XlaOp Conv( + XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, + Padding padding, int64 feature_group_count = 1, int64 batch_group_count = 1, + const PrecisionConfig* precision_config = nullptr, + absl::optional<PrimitiveType> preferred_element_type = absl::nullopt); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration in the format returned by MakePadding(). -XlaOp ConvWithGeneralPadding(XlaOp lhs, XlaOp rhs, - absl::Span<const int64> window_strides, - absl::Span<const std::pair<int64, int64>> padding, - int64 feature_group_count = 1, - int64 batch_group_count = 1, - const PrecisionConfig* precision_config = nullptr); +XlaOp ConvWithGeneralPadding( + XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, + absl::Span<const std::pair<int64, int64>> padding, + int64 feature_group_count = 1, int64 batch_group_count = 1, + const PrecisionConfig* precision_config = nullptr, + absl::optional<PrimitiveType> preferred_element_type = absl::nullopt); // Enqueues a convolution instruction onto the computation, with the caller // provided dimension numbers configuration. @@ -1842,47 +1866,48 @@ XlaOp ConvWithGeneralDimensions( XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count = 1, int64 batch_group_count = 1, - const PrecisionConfig* precision_config = nullptr); + const PrecisionConfig* precision_config = nullptr, + absl::optional<PrimitiveType> preferred_element_type = absl::nullopt); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration as well as the dimension numbers. -XlaOp ConvGeneral(XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, - absl::Span<const std::pair<int64, int64>> padding, - const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1, int64 batch_group_count = 1, - const PrecisionConfig* precision_config = nullptr); +XlaOp ConvGeneral( + XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, + absl::Span<const std::pair<int64, int64>> padding, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1, int64 batch_group_count = 1, + const PrecisionConfig* precision_config = nullptr, + absl::optional<PrimitiveType> preferred_element_type = absl::nullopt); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration, dilation factors and dimension numbers. 
-XlaOp ConvGeneralDilated(XlaOp lhs, XlaOp rhs, - absl::Span<const int64> window_strides, - absl::Span<const std::pair<int64, int64>> padding, - absl::Span<const int64> lhs_dilation, - absl::Span<const int64> rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1, - int64 batch_group_count = 1, - const PrecisionConfig* precision_config = nullptr); +XlaOp ConvGeneralDilated( + XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, + absl::Span<const std::pair<int64, int64>> padding, + absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1, int64 batch_group_count = 1, + const PrecisionConfig* precision_config = nullptr, + absl::optional<PrimitiveType> preferred_element_type = absl::nullopt); -XlaOp DynamicConvForward(XlaOp lhs, XlaOp rhs, - absl::Span<const int64> window_strides, - absl::Span<const std::pair<int64, int64>> padding, - absl::Span<const int64> lhs_dilation, - absl::Span<const int64> rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, int64 batch_group_count, - const PrecisionConfig* precision_config, - PaddingType padding_type); +XlaOp DynamicConvForward( + XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, + absl::Span<const std::pair<int64, int64>> padding, + absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config, PaddingType padding_type, + absl::optional<PrimitiveType> preferred_element_type = absl::nullopt); -XlaOp DynamicConvInputGrad(XlaOp input_sizes, XlaOp lhs, XlaOp rhs, - absl::Span<const int64> window_strides, - absl::Span<const std::pair<int64, int64>> padding, - absl::Span<const int64> lhs_dilation, - absl::Span<const int64> rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, int64 batch_group_count, - const PrecisionConfig* precision_config, - PaddingType padding_type); +XlaOp DynamicConvInputGrad( + XlaOp input_sizes, XlaOp lhs, XlaOp rhs, + absl::Span<const int64> window_strides, + absl::Span<const std::pair<int64, int64>> padding, + absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config, PaddingType padding_type, + absl::optional<PrimitiveType> preferred_element_type = absl::nullopt); XlaOp DynamicConvKernelGrad( XlaOp activations, XlaOp gradients, absl::Span<const int64> window_strides, @@ -1890,7 +1915,8 @@ XlaOp DynamicConvKernelGrad( absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, int64 batch_group_count, - const PrecisionConfig* precision_config, PaddingType padding_type); + const PrecisionConfig* precision_config, PaddingType padding_type, + absl::optional<PrimitiveType> preferred_element_type = absl::nullopt); // Enqueues an FFT instruction onto the computation, of the given type and // with the given FFT length. 
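The xla_builder.h hunks above add an optional trailing `preferred_element_type` to every dot and convolution entry point, defaulting to `absl::nullopt`, which keeps the old behaviour of deriving the result element type from the operands. A minimal usage sketch follows; it is not part of the patch, the helper name and include paths are assumptions, and it mirrors the `DotWithPreferredElementType` builder test added in the next file.

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Requests a U32 result for a U8 x U16 dot instead of the operand-derived type.
xla::XlaOp BuildWideningDot(xla::XlaBuilder* b) {
  xla::XlaOp p0 =
      xla::Parameter(b, 0, xla::ShapeUtil::MakeShape(xla::U8, {2, 3}), "p0");
  xla::XlaOp p1 =
      xla::Parameter(b, 1, xla::ShapeUtil::MakeShape(xla::U16, {3, 2}), "p1");
  xla::DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  return xla::DotGeneral(p0, p1, dnums, /*precision_config=*/nullptr,
                         /*preferred_element_type=*/xla::U32);
}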
diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc index 4fc6c848a38..cd21c6dc414 100644 --- a/tensorflow/compiler/xla/client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_builder_test.cc @@ -338,8 +338,7 @@ TEST_F(XlaBuilderTest, BroadcastInDimWithNegativeSize) { /*broadcast_dimensions=*/{0, 1, 2}); auto statusor = BuildHloModule(&b); ASSERT_FALSE(statusor.ok()); - EXPECT_THAT(statusor.status().error_message(), - HasSubstr("shape's dimensions must not be < 0")); + EXPECT_THAT(statusor.status().error_message(), HasSubstr("invalid shape")); } TEST_F(XlaBuilderTest, OperandFromWrongBuilder) { @@ -1066,6 +1065,56 @@ TEST_F(XlaBuilderTest, DynamicTranspose) { << result_shape; } +TEST_F(XlaBuilderTest, DotWithPreferredElementType) { + XlaBuilder b(TestName()); + Shape p0_shape = ShapeUtil::MakeShape(U8, {2, 3}); + Shape p1_shape = ShapeUtil::MakeShape(U16, {3, 2}); + auto p0 = Parameter(&b, 0, p0_shape, "p0"); + auto p1 = Parameter(&b, 1, p1_shape, "p1"); + + DotDimensionNumbers dnums; + dnums.add_lhs_contracting_dimensions(1); + dnums.add_rhs_contracting_dimensions(0); + DotGeneral(p0, p1, dnums, /*precision_config=*/nullptr, + /*preferred_element_type=*/U32); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + ASSERT_TRUE( + ShapeUtil::Equal(ShapeUtil::MakeShape(U32, {2, 2}), result_shape)); +} + +TEST_F(XlaBuilderTest, ConvolutionWithPreferredElementType) { + XlaBuilder b(TestName()); + Shape p0_shape = ShapeUtil::MakeShape(S16, {1, 2, 2, 128}); + Shape p1_shape = ShapeUtil::MakeShape(S8, {2, 2, 128, 8}); + auto p0 = Parameter(&b, 0, p0_shape, "p0"); + auto p1 = Parameter(&b, 1, p1_shape, "p1"); + + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + ConvWithGeneralDimensions(p0, p1, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/1, /*batch_group_count=*/1, + /*precision_config=*/nullptr, + /*preferred_element_type=*/S32); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + ASSERT_TRUE( + ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {1, 1, 1, 8}), result_shape)); +} + TEST_F(XlaBuilderTest, AfterAllWithNonTokenOperands) { XlaBuilder b(TestName()); AfterAll(&b, {CreateToken(&b), ConstantR0<float>(&b, 1.0)}); diff --git a/tensorflow/compiler/xla/pjrt/BUILD b/tensorflow/compiler/xla/pjrt/BUILD index 5a14e229fa3..8b6eee1f75e 100644 --- a/tensorflow/compiler/xla/pjrt/BUILD +++ b/tensorflow/compiler/xla/pjrt/BUILD @@ -243,6 +243,7 @@ cc_library( hdrs = ["nvidia_gpu_device.h"], deps = [ ":pjrt_client", + "@com_google_absl//absl/container:flat_hash_map", "//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:client_library", diff --git a/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc index 
057f8d498a7..570c78c2d70 100644 --- a/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc +++ b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc @@ -15,12 +15,15 @@ limitations under the License. #include "tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h" +#include "absl/container/flat_hash_map.h" + #ifdef NCCL_ENABLED #include "third_party/nccl/nccl.h" #endif // NCCL_ENABLED #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h" #include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/device/device_host_allocator.h" #include "tensorflow/core/common_runtime/device/device_id.h" @@ -166,56 +169,57 @@ std::unique_ptr<tensorflow::BFCAllocator> GetGpuHostAllocator( // A table mapping NcclCliqueKeys to ncclUniqueId values encoded as strings. // In a distributed setup the table of NCCL IDs is kept on the master node -// (node 0). Currently node 0 is the only node that generates ncclUniqueIds; -// see the TODO below. +// (node 0). The node of the first participating device will create the unique +// id. class NcclIdStore { public: - NcclIdStore(int node_id, std::shared_ptr<DistributedRuntimeClient> client) - : node_id_(node_id), client_(std::move(client)) {} + NcclIdStore(int node_id, std::shared_ptr<DistributedRuntimeClient> client, + absl::flat_hash_map<GlobalDeviceId, int> device_to_node) + : node_id_(node_id), + client_(std::move(client)), + device_to_node_(std::move(device_to_node)) {} StatusOr<std::string> GetNcclUniqueId(const NcclCliqueKey& key); private: const int node_id_; const std::shared_ptr<DistributedRuntimeClient> client_; + const absl::flat_hash_map<GlobalDeviceId, int> device_to_node_; absl::Mutex mu_; - absl::flat_hash_map<std::string, std::string> cache_ ABSL_GUARDED_BY(mu_); + absl::flat_hash_map<NcclCliqueKey, std::string> cache_ ABSL_GUARDED_BY(mu_); }; StatusOr<std::string> NcclIdStore::GetNcclUniqueId(const NcclCliqueKey& key) { - std::string key_string = GlobalDeviceIdsToString(key.devices()); + // The caller must ensure that threads calling this method concurrently have + // unique keys, otherwise the global key-value store may hold the wrong value. { absl::MutexLock lock(&mu_); - auto it = cache_.find(key_string); + auto it = cache_.find(key); if (it != cache_.end()) { return it->second; } } - auto result = [&]() -> StatusOr<std::string> { - // TODO(phawkins): this will deadlock if node 0 is not involved in the - // computation. Add support for computations that only use a subset of - // replicas. 
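The removed TODO above is resolved by the rewrite that continues below: instead of node 0 always generating the NCCL unique id, the node that owns the first device of the clique generates it and publishes it under `key.ToString()`, while every other participant blocks on the distributed key-value store. A rough sketch of the ownership rule, with namespaces and includes omitted and the helper name assumed:

// Illustrative only; `device_to_node` is the map built from the global
// topology in BuildDistributedDevices further down in this hunk.
int PrimaryNodeFor(const NcclCliqueKey& key,
                   const absl::flat_hash_map<GlobalDeviceId, int>& device_to_node) {
  // The owner of the first participating device generates the unique id;
  // all other nodes wait for it via BlockingKeyValueGet.
  return device_to_node.at(key.devices().front());
}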
- if (node_id_ == 0) { + std::string id_string; + int primary_node_id = device_to_node_.at(key.devices()[0]); + if (node_id_ == primary_node_id) { #ifdef NCCL_ENABLED - ncclUniqueId id; - ncclResult_t r = ncclGetUniqueId(&id); - TF_RET_CHECK(r == ncclSuccess); - std::string value(id.internal, NCCL_UNIQUE_ID_BYTES); - TF_RETURN_IF_ERROR(client_->KeyValueSet(key_string, value)); - return value; + ncclUniqueId id; + ncclResult_t r = ncclGetUniqueId(&id); + TF_RET_CHECK(r == ncclSuccess); + id_string = std::string(id.internal, NCCL_UNIQUE_ID_BYTES); + TF_RETURN_IF_ERROR(client_->KeyValueSet(key.ToString(), id_string)); #else - return FailedPrecondition("NCCL support was not built into XLA binary."); + return FailedPrecondition("NCCL support was not built into XLA binary."); #endif - } else { - return client_->BlockingKeyValueGet(key_string, absl::Minutes(5)); - } - }(); - if (!result.ok()) { - return result.status(); + } else { + TF_ASSIGN_OR_RETURN(id_string, client_->BlockingKeyValueGet( + key.ToString(), absl::Minutes(5))); } absl::MutexLock lock(&mu_); - return cache_.emplace(key_string, result.ValueOrDie()).first->second; + auto result = cache_.emplace(key, std::move(id_string)); + TF_RET_CHECK(result.second) << "Unique ID already in cache."; + return result.first->second; } std::vector<std::unique_ptr<PjRtDevice>> BuildLocalDevices( @@ -258,8 +262,11 @@ Status BuildDistributedDevices( distributed_client->EnumerateDevices(local_topology, &global_topology)); std::vector<GlobalDeviceId> gpu_device_ids(local_device_states.size()); + absl::flat_hash_map<GlobalDeviceId, int> device_to_node; for (const LocalTopologyProto& node : global_topology.nodes()) { for (const DeviceProto& device_proto : node.devices()) { + GlobalDeviceId global_device_id(device_proto.global_device_id()); + device_to_node[global_device_id] = node.node_id(); std::unique_ptr<LocalDeviceState> local_device; if (node.node_id() == node_id) { TF_RET_CHECK(device_proto.local_device_ordinal() >= 0 && @@ -269,8 +276,7 @@ Status BuildDistributedDevices( nullptr); local_device = std::move(local_device_states[device_proto.local_device_ordinal()]); - gpu_device_ids[device_proto.local_device_ordinal()] = - GlobalDeviceId(device_proto.global_device_id()); + gpu_device_ids[device_proto.local_device_ordinal()] = global_device_id; } auto device = absl::make_unique<GpuDevice>( device_proto.global_device_id(), std::move(local_device), @@ -283,8 +289,8 @@ Status BuildDistributedDevices( } gpu_executable_run_options->set_gpu_global_device_ids( std::move(gpu_device_ids)); - auto nccl_id_store = - std::make_shared<NcclIdStore>(node_id, distributed_client); + auto nccl_id_store = std::make_shared<NcclIdStore>( + node_id, distributed_client, device_to_node); gpu_executable_run_options->set_nccl_unique_id_callback( [nccl_id_store](const NcclCliqueKey& key) { return nccl_id_store->GetNcclUniqueId(key); diff --git a/tensorflow/compiler/xla/python/bfloat16.cc b/tensorflow/compiler/xla/python/bfloat16.cc index 5dcfc3b0dcc..5f96c494c25 100644 --- a/tensorflow/compiler/xla/python/bfloat16.cc +++ b/tensorflow/compiler/xla/python/bfloat16.cc @@ -597,19 +597,25 @@ struct TypeDescriptor<uint16> { static int Dtype() { return NPY_UINT16; } }; +// We register "int", "long", and "long long" types for portability across +// Linux, where "int" and "long" are the same type, and Windows, where "long" +// and "longlong" are the same type. 
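For reference, the data models differ as follows: on LP64 Linux `int` is 32-bit while `long` and `long long` are both 64-bit, and on LLP64 Windows `int` and `long` are both 32-bit while `long long` is 64-bit, so the fixed-width aliases map onto different built-in types on each platform. Registering `int`, `long`, and `long long` (and their unsigned counterparts) separately, as the specializations that follow do, covers NPY_INT, NPY_LONG, and NPY_LONGLONG on both platforms. A standalone check, not part of the patch:

#include <cstdio>

int main() {
  // LP64 (Linux x86-64): int=4, long=8, long long=8.
  // LLP64 (64-bit Windows): int=4, long=4, long long=8.
  std::printf("int=%zu long=%zu long long=%zu\n", sizeof(int), sizeof(long),
              sizeof(long long));
  return 0;
}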
template <> -struct TypeDescriptor<uint32> { - typedef uint32 T; - static int Dtype() { return NPY_UINT32; } +struct TypeDescriptor<unsigned int> { + typedef unsigned int T; + static int Dtype() { return NPY_UINT; } }; -template <typename Uint64Type> -struct TypeDescriptor< - Uint64Type, typename std::enable_if<std::is_integral<Uint64Type>::value && - !std::is_signed<Uint64Type>::value && - sizeof(Uint64Type) == 8>::type> { - typedef Uint64Type T; - static int Dtype() { return NPY_UINT64; } +template <> +struct TypeDescriptor<unsigned long> { // NOLINT + typedef unsigned long T; // NOLINT + static int Dtype() { return NPY_ULONG; } +}; + +template <> +struct TypeDescriptor<unsigned long long> { // NOLINT + typedef unsigned long long T; // NOLINT + static int Dtype() { return NPY_ULONGLONG; } }; template <> @@ -625,18 +631,21 @@ struct TypeDescriptor<int16> { }; template <> -struct TypeDescriptor<int32> { - typedef int32 T; - static int Dtype() { return NPY_INT32; } +struct TypeDescriptor<int> { + typedef int T; + static int Dtype() { return NPY_INT; } }; -template <typename Int64Type> -struct TypeDescriptor< - Int64Type, typename std::enable_if<std::is_integral<Int64Type>::value && - std::is_signed<Int64Type>::value && - sizeof(Int64Type) == 8>::type> { - typedef Int64Type T; - static int Dtype() { return NPY_INT64; } +template <> +struct TypeDescriptor<long> { // NOLINT + typedef long T; // NOLINT + static int Dtype() { return NPY_LONG; } +}; + +template <> +struct TypeDescriptor<long long> { // NOLINT + typedef long long T; // NOLINT + static int Dtype() { return NPY_LONGLONG; } }; template <> @@ -1354,7 +1363,15 @@ bool Initialize() { if (!RegisterBfloat16Cast<uint16>(NPY_UINT16, /*cast_is_safe=*/false)) { return false; } - if (!RegisterBfloat16Cast<uint32>(NPY_UINT32, /*cast_is_safe=*/false)) { + if (!RegisterBfloat16Cast<unsigned int>(NPY_UINT, /*cast_is_safe=*/false)) { + return false; + } + if (!RegisterBfloat16Cast<unsigned long>(NPY_ULONG, // NOLINT + /*cast_is_safe=*/false)) { + return false; + } + if (!RegisterBfloat16Cast<unsigned long long>( // NOLINT + NPY_ULONGLONG, /*cast_is_safe=*/false)) { return false; } if (!RegisterBfloat16Cast<uint64>(NPY_UINT64, /*cast_is_safe=*/false)) { @@ -1366,14 +1383,15 @@ bool Initialize() { if (!RegisterBfloat16Cast<int16>(NPY_INT16, /*cast_is_safe=*/false)) { return false; } - if (!RegisterBfloat16Cast<int32>(NPY_INT32, /*cast_is_safe=*/false)) { + if (!RegisterBfloat16Cast<int>(NPY_INT, /*cast_is_safe=*/false)) { return false; } - if (!RegisterBfloat16Cast<int64>(NPY_INT64, /*cast_is_safe=*/false)) { + if (!RegisterBfloat16Cast<long>(NPY_LONG, // NOLINT + /*cast_is_safe=*/false)) { return false; } - if (!RegisterBfloat16Cast<npy_longlong>(NPY_LONGLONG, - /*cast_is_safe=*/false)) { + if (!RegisterBfloat16Cast<long long>( // NOLINT + NPY_LONGLONG, /*cast_is_safe=*/false)) { return false; } // Following the numpy convention. 
imag part is dropped when converting to diff --git a/tensorflow/compiler/xla/python/bfloat16_test.py b/tensorflow/compiler/xla/python/bfloat16_test.py index 9aaa955d546..4c7321a5b7f 100644 --- a/tensorflow/compiler/xla/python/bfloat16_test.py +++ b/tensorflow/compiler/xla/python/bfloat16_test.py @@ -293,7 +293,7 @@ class Bfloat16NumPyTest(parameterized.TestCase): for dtype in [ np.float16, np.float32, np.float64, np.int8, np.int16, np.int32, np.int64, np.complex64, np.complex128, np.uint8, np.uint16, np.uint32, - np.uint64 + np.uint64, np.intc, np.int_, np.longlong, np.uintc, np.ulonglong ]: x = np.array([[1, 2, 3]], dtype=dtype) y = x.astype(bfloat16) diff --git a/tensorflow/compiler/xla/python/jax_jit.cc b/tensorflow/compiler/xla/python/jax_jit.cc index af02c1ef0d4..27bbde09f09 100644 --- a/tensorflow/compiler/xla/python/jax_jit.cc +++ b/tensorflow/compiler/xla/python/jax_jit.cc @@ -595,7 +595,7 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient, static const auto* xla_module = new py::module(py::module::import("jax.interpreters.xla")); - const auto& device_array = xla_module->attr("DeviceArray"); + const auto& device_array = xla_module->attr("_DeviceArray"); static const auto* numpy_module = new py::module(py::module::import("numpy")); const auto& np_array = numpy_module->attr("array"); @@ -613,7 +613,7 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient, for (py::handle arg : arguments.flat_dynamic_args) { // We specically only deal with DeviceArray (not ShardedDeviceArray). // (Can happen in jit(pmap), e.g. "test_jit_nested_donate_ignored"). - if (arg.get_type().is(device_array)) { + if (py::isinstance<PyBuffer>(arg) || arg.get_type().is(device_array)) { xla::PyBuffer* buffer; if (arg.attr("_device").is_none()) { // Skip non-sticky devices. continue; @@ -653,7 +653,7 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient, xla::PjRtClient* pjrt_client = data_device->client(); for (py::handle arg : arguments.flat_dynamic_args) { - if (arg.get_type().is(device_array)) { + if (py::isinstance<PyBuffer>(arg) || arg.get_type().is(device_array)) { if (!HasTrivialLazyExpr(arg)) { return InvalidArgument( "Non-trivial lazy expression not supported in C++. 
" diff --git a/tensorflow/compiler/xla/python/ops.cc b/tensorflow/compiler/xla/python/ops.cc index 04e68f9a563..60d47081e21 100644 --- a/tensorflow/compiler/xla/python/ops.cc +++ b/tensorflow/compiler/xla/python/ops.cc @@ -108,7 +108,8 @@ void BuildOpsSubmodule(py::module* m) { py::arg("lhs_dilation"), py::arg("rhs_dilation"), py::arg("dimension_numbers"), py::arg("feature_group_count") = 1, py::arg("batch_group_count") = 1, - py::arg("precision_config") = nullptr); + py::arg("precision_config") = nullptr, + py::arg("preferred_element_type") = absl::nullopt); ops.def("ConvertElementType", &ConvertElementType, py::arg("operand"), py::arg("new_element_type")); ops.def( @@ -136,9 +137,11 @@ void BuildOpsSubmodule(py::module* m) { py::arg("shape_with_layout"), py::arg("operand_shapes_with_layout"), py::arg("opaque") = py::bytes(""), py::arg("has_side_effect") = false); ops.def("Dot", &Dot, py::arg("lhs"), py::arg("rhs"), - py::arg("precision_config") = nullptr); + py::arg("precision_config") = nullptr, + py::arg("preferred_element_type") = absl::nullopt); ops.def("DotGeneral", &DotGeneral, py::arg("lhs"), py::arg("rhs"), - py::arg("dimension_numbers"), py::arg("precision_config") = nullptr); + py::arg("dimension_numbers"), py::arg("precision_config") = nullptr, + py::arg("preferred_element_type") = absl::nullopt); ops.def("DynamicSlice", static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaOp>, absl::Span<const int64>)>(&DynamicSlice), diff --git a/tensorflow/compiler/xla/python/py_buffer.cc b/tensorflow/compiler/xla/python/py_buffer.cc index cac14142b75..1f39266a989 100644 --- a/tensorflow/compiler/xla/python/py_buffer.cc +++ b/tensorflow/compiler/xla/python/py_buffer.cc @@ -238,4 +238,22 @@ PyBufferProcs PjRtBufferProcs = []() { return &PjRtBufferProcs; } +void PyBuffer::SetStickyDevice(pybind11::object sticky_device) { + if (sticky_device_ && !sticky_device_->equal(sticky_device)) { + throw std::invalid_argument( + "One cannot set again the stickyness of a buffer and needs to create " + "a new one or a `_DeviceArray`"); + } + sticky_device_ = sticky_device; +} + +void PyBuffer::SetAval(pybind11::object aval) { + if (aval_ && !aval_->equal(aval)) { + throw std::invalid_argument( + "One cannot set again the aval_ of a buffer and needs to create a " + "new one or a `_DeviceArray`"); + } + aval_ = aval; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/python/py_buffer.h b/tensorflow/compiler/xla/python/py_buffer.h index 0562beaa14a..c7c62d83225 100644 --- a/tensorflow/compiler/xla/python/py_buffer.h +++ b/tensorflow/compiler/xla/python/py_buffer.h @@ -17,8 +17,11 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_PYTHON_PY_BUFFER_H_ #include <memory> +#include <stdexcept> #include <vector> +#include "absl/types/optional.h" +#include "pybind11/pybind11.h" #include "tensorflow/compiler/xla/python/py_client.h" #include "tensorflow/compiler/xla/python/traceback.h" #include "tensorflow/compiler/xla/statusor.h" @@ -41,6 +44,9 @@ class DeviceArrayBase { // Python wrapper around PjRtBuffer. We use a wrapper class: // a) to keep the PjRtClient alive via a std::shared_ptr<> // b) to add Python-specific functionality. +// +// A `PyBuffer` can be used from Python without being wrapped in a Python +// `DeviceArray` object, at the condition there is no associated LazyExpr. 
class PyBuffer : public DeviceArrayBase { public: PyBuffer(std::shared_ptr<PyClient> client, std::unique_ptr<PjRtBuffer> buffer, @@ -57,8 +63,12 @@ class PyBuffer : public DeviceArrayBase { StatusOr<std::unique_ptr<PyBuffer>> CopyToDevice( const ClientAndPtr<PjRtDevice>& dst_device) const; - void Delete() { return buffer_->Delete(); } + void Delete() { + buffer_->Delete(); + npy_value_ = pybind11::none(); + } + // Returns xla::InvalidArgument if the buffer has been deleted. Status BlockHostUntilReady(); Status CopyToHostAsync() { return buffer_->CopyToHostAsync(); } @@ -75,13 +85,33 @@ class PyBuffer : public DeviceArrayBase { Traceback* traceback() { return traceback_.get(); } + // Returns the size (i.e. number of elements) of the (host) numpy array. + int64 size() { return ShapeUtil::ElementsIn(buffer()->on_host_shape()); } + + // Returns the number of dimensions of the (host) numpy array. + int ndim() const { return buffer()->on_host_shape().dimensions_size(); } + + void SetStickyDevice(pybind11::object sticky_device); + pybind11::object GetStickyDevice() const { return sticky_device_.value(); } + + void SetNpyValue(pybind11::object npy_value) { npy_value_ = npy_value; } + pybind11::object GetNpyValue() const { return npy_value_; } + + void SetAval(pybind11::object aval); + pybind11::object GetAval() const { return aval_.value(); } + private: friend class PyClient; std::shared_ptr<PyClient> client_; std::unique_ptr<PjRtBuffer> buffer_; std::shared_ptr<Traceback> traceback_; - + // The host numpy array caching the value when it has been copied to the host. + pybind11::object npy_value_ = pybind11::none(); + absl::optional<pybind11::object> sticky_device_ = absl::nullopt; + // TODO(jblespiau): It's currently there for convenience but maybe we can do + // without it (adding `weak_type` instead). + absl::optional<pybind11::object> aval_ = absl::nullopt; // Doubly-linked list of all buffers known to the client. Protected by the // GIL. PyBuffer* next_; diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index b7c5148a4cf..00af81b7064 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -73,6 +73,25 @@ bool IsOptimizedBuild() { #endif // NDEBUG } +StatusOr<py::object> BufferToPython(PyBuffer* buffer, py::handle& buffer_obj) { + GlobalPyRefManager()->CollectGarbage(); + if (buffer->buffer()->IsOnCpu() && + buffer->buffer()->on_device_shape().IsArray() && + buffer->buffer()->on_device_shape().element_type() != BF16) { + py::object out = + py::reinterpret_steal<py::object>(PyArray_FROM_O(buffer_obj.ptr())); + CHECK(out.ptr() != nullptr) << buffer->buffer()->on_host_shape().ToString( + /*print_layout=*/true); + return out; + } + std::shared_ptr<Literal> literal; + { + py::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(literal, buffer->buffer()->ToLiteral()); + } + return LiteralToPython(std::move(literal)); +} + } // namespace PYBIND11_MODULE(xla_extension, m) { @@ -261,33 +280,65 @@ PYBIND11_MODULE(xla_extension, m) { // TODO(phawkins): alias for backward compatibility. Remove after JAX no // longer uses this name. 
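The `BufferToPython` helper introduced above centralizes the host-conversion policy shared by `to_py` and the new cached `_value` property defined below: buffers that live on the CPU, have an array shape, and are not BF16 are handed to `PyArray_FROM_O` without a copy, while everything else is copied out through `ToLiteral` with the GIL released. A sketch of just the fast-path condition, not part of the patch and with the helper name assumed:

// Illustrative only: when it is safe to expose the device buffer to NumPy
// without copying.
bool CanZeroCopyToNumpy(const xla::PjRtBuffer& buffer) {
  return buffer.IsOnCpu() && buffer.on_device_shape().IsArray() &&
         buffer.on_device_shape().element_type() != xla::BF16;
}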
m.add_object("PyLocalBuffer", buffer); - buffer.def("copy_to_device", &PyBuffer::CopyToDevice) + buffer + .def_property_readonly("__array_priority__", + [](py::object) { return 100; }) + .def_property("_device", &PyBuffer::GetStickyDevice, + &PyBuffer::SetStickyDevice) + .def_property("aval", &PyBuffer::GetAval, &PyBuffer::SetAval) + .def_property_readonly("_lazy_expr", + [](py::object buffer) { return py::none(); }) + .def_property_readonly("device_buffer", + [](py::object buffer) { return buffer; }) + .def_property_readonly( + "shape", + [](const PyBuffer& pybuffer) -> pybind11::tuple { + return IntSpanToTuple( + pybuffer.buffer()->on_host_shape().dimensions()); + }) + .def_property_readonly( + "dtype", + [](const PyBuffer& buffer) { + PrimitiveType primitive = + buffer.buffer()->on_host_shape().element_type(); + return PrimitiveTypeToDtype(primitive).ValueOrDie(); + }) + .def_property_readonly("size", &PyBuffer::size) + .def_property_readonly("ndim", &PyBuffer::ndim) + .def_property_readonly( + "_value", + [](py::handle buffer_obj) -> pybind11::object { + PyBuffer* buffer = buffer_obj.cast<PyBuffer*>(); + if (buffer->is_deleted()) { + throw std::runtime_error("DeviceArray has been deleted."); + } + py::object npy_value_ = buffer->GetNpyValue(); + if (npy_value_.is_none()) { + npy_value_ = BufferToPython(buffer, buffer_obj).ValueOrDie(); + // TODO(jblspiau): Change `LiteralToPython` to return a + // `py::array`, so we can set more easily the attribute. + npy_value_.attr("flags").attr("writeable") = Py_False; + buffer->SetNpyValue(npy_value_); + } + return npy_value_; + }) + .def("copy_to_device", &PyBuffer::CopyToDevice) .def("delete", &PyBuffer::Delete) + // The GIL is released within BlockHostUntilReady. + .def("block_until_ready", + [](py::object buffer_obj) -> xla::StatusOr<py::object> { + PyBuffer* buffer = buffer_obj.cast<PyBuffer*>(); + TF_RETURN_IF_ERROR(buffer->BlockHostUntilReady()); + return buffer_obj; + }) .def("block_host_until_ready", &PyBuffer::BlockHostUntilReady) .def("copy_to_host_async", &PyBuffer::CopyToHostAsync, py::call_guard<py::gil_scoped_release>()) .def("to_py", - [](py::object buffer_obj) -> StatusOr<py::object> { - GlobalPyRefManager()->CollectGarbage(); + [](py::handle buffer_obj) { PyBuffer* buffer = buffer_obj.cast<PyBuffer*>(); - if (buffer->buffer()->IsOnCpu() && - buffer->buffer()->on_device_shape().IsArray() && - buffer->buffer()->on_device_shape().element_type() != BF16) { - py::object out = py::reinterpret_steal<py::object>( - PyArray_FROM_O(buffer_obj.ptr())); - CHECK(out.ptr() != nullptr) - << buffer->buffer()->on_host_shape().ToString( - /*print_layout=*/true); - return out; - } - std::shared_ptr<Literal> literal; - { - py::gil_scoped_release gil_release; - TF_ASSIGN_OR_RETURN(literal, buffer->buffer()->ToLiteral()); - } - return LiteralToPython(std::move(literal)); + return BufferToPython(buffer, buffer_obj); }) - .def("shape", &PyBuffer::shape) .def("xla_shape", &PyBuffer::shape) .def_property_readonly("client", &PyBuffer::client) .def("device", &PyBuffer::device) diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 63891c537ff..0e538fa8ca8 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -28,6 +28,8 @@ import inspect import os from typing import List, Sequence, Tuple, Union +from . import xla_extension as _xla + from absl import logging import numpy as np @@ -36,8 +38,6 @@ import numpy as np # of TensorFlow. 
If we use protocol buffers here, then importing both jaxlib # and TensorFlow may fail with duplicate protocol buffer message definitions. -from tensorflow.compiler.xla.python import xla_extension as _xla - # Most functions are snake_case for consistency with other modules, some # method names are CamelCase for consistency with XLA. # pylint: disable=invalid-name diff --git a/tensorflow/compiler/xla/python/xla_client_backend_independent_test.py b/tensorflow/compiler/xla/python/xla_client_backend_independent_test.py index 180bb040cc4..eb9c90941b6 100644 --- a/tensorflow/compiler/xla/python/xla_client_backend_independent_test.py +++ b/tensorflow/compiler/xla/python/xla_client_backend_independent_test.py @@ -38,8 +38,7 @@ ops = xla_client.ops class ShapeTest(absltest.TestCase): def testInvalidShapes(self): - with self.assertRaisesRegex(RuntimeError, - "shape's dimensions must not be < 0.*"): + with self.assertRaisesRegex(RuntimeError, "invalid shape"): xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, [-2, 4]) with self.assertRaisesRegex( diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index fd2a065c34f..fc30883880e 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -14,10 +14,6 @@ # ============================================================================== """Backend-dependent tests for the Python XLA client.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - import functools import itertools import re @@ -445,17 +441,10 @@ def TestFactory(xla_backend, cloud_tpu=False): with self.assertRaises(RuntimeError): compiled_c.execute([arg_buffer]) - def testShape(self): - pyval = np.array([[1., 2.]], np.float32) - local_buffer = self.backend.buffer_from_pyval(pyval) - xla_shape = local_buffer.shape() - self.assertEqual(xla_shape.dimensions(), (1, 2)) - self.assertEqual(np.dtype(xla_shape.element_type()), np.dtype(np.float32)) - def testXlaShape(self): pyval = np.array([[1., 2.]], np.float32) - buffer = self.backend.buffer_from_pyval(pyval) - xla_shape = buffer.xla_shape() + local_buffer = self.backend.buffer_from_pyval(pyval) + xla_shape = local_buffer.xla_shape() self.assertEqual(xla_shape.dimensions(), (1, 2)) self.assertEqual(np.dtype(xla_shape.element_type()), np.dtype(np.float32)) @@ -476,6 +465,27 @@ def TestFactory(xla_backend, cloud_tpu=False): "BlockHostUntilReady() called on deleted or donated buffer")): buffer.block_host_until_ready() + def testDeviceArrayBaseSignatures(self): + # When extending `DeviceArrayBase`, the object behaves as a `DeviceArray` + # and thus needs to correctly implement the following methods. 
+ arg = np.array([[1., 2., 3.]], np.float32) + buffer = self.backend.buffer_from_pyval(arg) + if not isinstance(buffer, xla_client.DeviceArrayBase): + raise unittest.SkipTest( + "The objectof type {} do not extend DeviceArrayBase".format( + type(buffer))) + + self.assertEqual(buffer.__array_priority__, 100) + self.assertEqual(buffer.shape, (1, 3)) + self.assertEqual(buffer.dtype, np.float32) + self.assertEqual(buffer.size, 3) + self.assertEqual(buffer.ndim, 2) + + self.assertIs(buffer, buffer.block_until_ready()) + buffer.delete() + with self.assertRaises(RuntimeError): + buffer.block_until_ready() + def testCopyToHost(self): arg0 = np.array([[1., 2.]], np.float32) arg1 = np.array([[3., 4.]], np.float32) diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index 007ba568527..c1e6b1b1a1d 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -465,7 +465,8 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( const Shape& shape = ShapeInference::InferConvolveShape( lhs_literal.shape(), rhs_literal.shape(), - /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums) + /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums, + /*preferred_element_type=*/absl::nullopt) .ConsumeValueOrDie(); HloInstruction* lhs_instruction = diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 85de316e569..5fe89e84995 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -3721,6 +3721,7 @@ cc_library( ":hlo_casting_utils", ":hlo_pass", ":shape_inference", + "//tensorflow/compiler/xla:comparison_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", @@ -3774,6 +3775,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 39d1eddc569..9a725cd541d 100755 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -1798,10 +1798,10 @@ StatusOr<bool> AlgebraicSimplifierVisitor::RemoveDegenerateDimensionFromDot( ShapeUtil::DropDegenerateDimensions(rhs_shape), dot->mutable_operand(1))) : dot->mutable_operand(1); - TF_ASSIGN_OR_RETURN(auto new_dot, MakeDotHlo(new_lhs, new_rhs, new_dnums, - dot->precision_config())); - // TODO(b/165824019): Add an optional preferred element type to MakeDotHlo. 
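The TODO removed above (b/165824019) is addressed throughout the rest of this patch: `MakeDotHlo`, `MakeConvolveHlo`, and `ShapeInference::InferConvolveShape` now take an optional `preferred_element_type`, so callers that rebuild an instruction pass the original element type instead of mutating the inferred shape afterwards (as the line removed just below did), and the purely mechanical call sites pass `absl::nullopt` to keep the operand-derived type. A sketch of the rebuild idiom, with namespaces omitted and the wrapper name assumed:

// Illustrative only: preserve the element type of the dot being replaced.
StatusOr<HloInstruction*> RebuildDot(HloInstruction* dot, HloInstruction* lhs,
                                     HloInstruction* rhs,
                                     const DotDimensionNumbers& dnums) {
  return MakeDotHlo(lhs, rhs, dnums, dot->precision_config(),
                    /*preferred_element_type=*/dot->shape().element_type());
}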
- new_dot->mutable_shape()->set_element_type(dot->shape().element_type()); + TF_ASSIGN_OR_RETURN( + auto new_dot, + MakeDotHlo(new_lhs, new_rhs, new_dnums, dot->precision_config(), + /*preferred_element_type=*/dot->shape().element_type())); if (ShapeUtil::Compatible(dot->shape(), new_dot->shape())) { TF_RETURN_IF_ERROR(ReplaceInstruction(dot, new_dot)); } else { @@ -4678,10 +4678,10 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { } } TF_ASSIGN_OR_RETURN( - auto new_dot, MakeDotHlo(lhs, rhs, new_dnums, dot->precision_config())); + auto new_dot, + MakeDotHlo(lhs, rhs, new_dnums, dot->precision_config(), + /*preferred_element_type=*/dot->shape().element_type())); dot->SetupDerivedInstruction(new_dot); - // TODO(b/165824019): Add an optional preferred element type to MakeDotHlo. - new_dot->mutable_shape()->set_element_type(dot->shape().element_type()); if (reduce_dims.empty()) { return ReplaceInstruction(hlo, new_dot); } @@ -5312,10 +5312,13 @@ StatusOr<bool> AlgebraicSimplifierVisitor::SwapConvOperands( if (!reverse_dimensions.empty()) { TF_ASSIGN_OR_RETURN(kernel, MakeReverseHlo(kernel, reverse_dimensions)); } - TF_ASSIGN_OR_RETURN(HloInstruction * new_convolution, - MakeConvolveHlo(kernel, input, /*feature_group_count=*/1, - /*batch_group_count=*/1, swapped_window, - swapped_dnums, precision_config)); + TF_ASSIGN_OR_RETURN( + HloInstruction * new_convolution, + MakeConvolveHlo( + kernel, input, /*feature_group_count=*/1, + /*batch_group_count=*/1, swapped_window, swapped_dnums, + precision_config, + /*preferred_element_type=*/convolution->shape().element_type())); convolution->SetupDerivedInstruction(new_convolution); TF_RETURN_IF_ERROR(ReplaceInstruction(convolution, new_convolution)); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 91c6c29ee80..1050ba9ad64 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -3785,9 +3785,11 @@ TEST_P(ConvInputPaddingTest, DoTest) { ParseWindow(absl::StrCat("size=3x3 ", testcase.orig_conv_window)) .ValueOrDie(); builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeInference::InferConvolveShape(lhs_pad->shape(), filter->shape(), - /*feature_group_count=*/1, - /*batch_group_count=*/1, window, dnums) + ShapeInference::InferConvolveShape( + lhs_pad->shape(), filter->shape(), + /*feature_group_count=*/1, + /*batch_group_count=*/1, window, dnums, + /*preferred_element_type=*/absl::nullopt) .ValueOrDie(), lhs_pad, filter, /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); @@ -3902,9 +3904,11 @@ TEST_P(ConvFilterPaddingTest, DoIt) { precision_config.add_operand_precision(PrecisionConfig::HIGHEST); builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeInference::InferConvolveShape(input->shape(), rhs_pad->shape(), - /*feature_group_count=*/1, - /*batch_group_count=*/1, window, dnums) + ShapeInference::InferConvolveShape( + input->shape(), rhs_pad->shape(), + /*feature_group_count=*/1, + /*batch_group_count=*/1, window, dnums, + /*preferred_element_type=*/absl::nullopt) .ValueOrDie(), input, rhs_pad, /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums, precision_config)); @@ -4050,7 +4054,8 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { b.AddInstruction(HloInstruction::CreateParameter(1, f_shape, "filter")); Shape out_shape = ShapeInference::InferConvolveShape( in_shape, 
f_shape, /*feature_group_count=*/1, - /*batch_group_count=*/1, window, dnums) + /*batch_group_count=*/1, window, dnums, + /*preferred_element_type=*/absl::nullopt) .ValueOrDie(); if (options.output_minor_to_major_layout) { out_shape = ShapeUtil::MakeShapeWithLayout(F32, out_shape.dimensions(), diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.cc b/tensorflow/compiler/xla/service/batch_dot_simplification.cc index 72112585cb3..942ee59a3aa 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification.cc +++ b/tensorflow/compiler/xla/service/batch_dot_simplification.cc @@ -87,9 +87,11 @@ BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot( 0, new_dim_numbers.rhs_contracting_dimensions(0) - degenerate_dims.size()); - TF_ASSIGN_OR_RETURN(HloInstruction * new_dot, - MakeDotHlo(new_lhs, new_rhs, new_dim_numbers, - batch_dot->precision_config())); + TF_ASSIGN_OR_RETURN( + HloInstruction * new_dot, + MakeDotHlo(new_lhs, new_rhs, new_dim_numbers, + batch_dot->precision_config(), + /*preferred_element_type=*/batch_dot->shape().element_type())); TF_ASSIGN_OR_RETURN(HloInstruction * new_dot_reshaped, MakeReshapeHlo(batch_dot->shape(), new_dot)); diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.cc b/tensorflow/compiler/xla/service/conditional_simplifier.cc index 199bc787b83..bff68574e1a 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier.cc @@ -144,6 +144,15 @@ StatusOr<bool> TryRemoveConditional(HloInstruction* conditional) { << conditional->ToShortString(); return false; } + + bool branch_empty = + ComputationIsEmptyWithArrayRoot(conditional->branch_computation(0)) || + ComputationIsEmptyWithArrayRoot(conditional->branch_computation(1)); + // Empty branch is faster to execute than select. + if (branch_empty) { + return false; + } + HloInstruction* true_call_op = create_call(0); HloInstruction* false_call_op = create_call(1); auto condition_broadcast = [&](const Shape& shape) { @@ -160,13 +169,6 @@ StatusOr<bool> TryRemoveConditional(HloInstruction* conditional) { hlo->shape().tuple_shapes(i), hlo, i)); }; - bool branch_empty = - ComputationIsEmptyWithArrayRoot(conditional->branch_computation(0)) || - ComputationIsEmptyWithArrayRoot(conditional->branch_computation(1)); - // Empty branch is faster to execute than select. 
- if (branch_empty) { - return false; - } std::function<HloInstruction*(HloInstruction*, HloInstruction*)> select = [&](HloInstruction* t, HloInstruction* f) { if (f->shape().IsToken()) { diff --git a/tensorflow/compiler/xla/service/convolution_group_converter.cc b/tensorflow/compiler/xla/service/convolution_group_converter.cc index f5506b894fd..b202c3adc41 100644 --- a/tensorflow/compiler/xla/service/convolution_group_converter.cc +++ b/tensorflow/compiler/xla/service/convolution_group_converter.cc @@ -299,9 +299,11 @@ Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) { window_dim->set_window_reversal(false); window_dim->set_window_dilation(1); HloInstruction* new_convolution = - MakeConvolveHlo(activation, filter, convolution->feature_group_count(), - /*batch_group_count=*/1, window, dim_numbers, - convolution->precision_config()) + MakeConvolveHlo( + activation, filter, convolution->feature_group_count(), + /*batch_group_count=*/1, window, dim_numbers, + convolution->precision_config(), + /*preferred_element_type=*/convolution->shape().element_type()) .ValueOrDie(); convolution->SetupDerivedInstruction(new_convolution); TF_CHECK_OK(computation_->ReplaceInstruction( @@ -650,9 +652,11 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { window_dim->set_window_reversal(false); window_dim->set_window_dilation(1); HloInstruction* new_convolution = - MakeConvolveHlo(activation, filter, /*feature_group_count=*/1, - /*batch_group_count=*/1, window, dim_numbers, - convolution->precision_config()) + MakeConvolveHlo( + activation, filter, /*feature_group_count=*/1, + /*batch_group_count=*/1, window, dim_numbers, + convolution->precision_config(), + /*preferred_element_type=*/convolution->shape().element_type()) .ValueOrDie(); convolution->SetupDerivedInstruction(new_convolution); changed_ = true; diff --git a/tensorflow/compiler/xla/service/cpu/test_target_triple_helper.h b/tensorflow/compiler/xla/service/cpu/test_target_triple_helper.h index 857de4a8143..bfd04ab60a7 100644 --- a/tensorflow/compiler/xla/service/cpu/test_target_triple_helper.h +++ b/tensorflow/compiler/xla/service/cpu/test_target_triple_helper.h @@ -21,8 +21,13 @@ limitations under the License. 
static const char kTargetCpuForHost[] = "ppc"; static const char kTargetTripleForHost[] = "ppc64le-ibm-linux-gnu"; #else +#if defined(__s390x__) +static const char kTargetCpuForHost[] = "s390x"; +static const char kTargetTripleForHost[] = "systemz-none-linux-gnu"; +#else static const char kTargetCpuForHost[] = ""; static const char kTargetTripleForHost[] = "x86_64-pc-linux"; #endif +#endif #endif diff --git a/tensorflow/compiler/xla/service/dot_as_convolution_util.cc b/tensorflow/compiler/xla/service/dot_as_convolution_util.cc index 3adde5f7d48..2abbe887a51 100644 --- a/tensorflow/compiler/xla/service/dot_as_convolution_util.cc +++ b/tensorflow/compiler/xla/service/dot_as_convolution_util.cc @@ -142,7 +142,8 @@ CreateShardedConvForDotGeneralConvolution( ShapeInference::InferConvolveShape( sharded_lhs_hlo->shape(), sharded_rhs_hlo->shape(), /*feature_group_count=*/conv.feature_group_count(), - /*batch_group_count=*/conv.batch_group_count(), window, conv_dnums)); + /*batch_group_count=*/conv.batch_group_count(), window, conv_dnums, + /*preferred_element_type=*/conv.shape().element_type())); *sharded_conv_shape.mutable_layout() = conv.shape().layout(); return HloInstruction::CreateConvolve( sharded_conv_shape, sharded_lhs_hlo, sharded_rhs_hlo, diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 5edbfebddda..1ebb82d7e81 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -268,6 +268,7 @@ cc_library( "//tensorflow/compiler/mlir:name_utils", "//tensorflow/compiler/mlir/hlo", "//tensorflow/compiler/mlir/hlo:lhlo", + "//tensorflow/compiler/mlir/hlo:lhlo_gpu", "//tensorflow/compiler/mlir/xla:hlo_module_importer", "//tensorflow/compiler/mlir/xla:hlo_utils", "//tensorflow/compiler/mlir/xla:mhlo_to_lhlo_with_xla", diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter_test.cc index b10c33da2fd..e9e65c9ebfa 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter_test.cc @@ -110,7 +110,8 @@ TEST_F(GpuConvRewriterTest, BackwardFilterConvolve) { ShapeInference::InferConvolveShape( activations->shape(), gradients->shape(), /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window, - tf_default_dnums_for_backward_filter_) + tf_default_dnums_for_backward_filter_, + /*preferred_element_type=*/absl::nullopt) .ConsumeValueOrDie(), activations, gradients, /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window, @@ -150,7 +151,8 @@ TEST_F(GpuConvRewriterTest, ShapeInference::InferConvolveShape( activations->shape(), gradients->shape(), /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window, - tf_default_dnums_for_backward_filter_) + tf_default_dnums_for_backward_filter_, + /*preferred_element_type=*/absl::nullopt) .ConsumeValueOrDie(), activations, gradients, /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window, @@ -292,11 +294,12 @@ TEST_F(GpuConvRewriterTest, BackwardInputConvolveEvenPadding) { DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. 
CHECK(ShapeUtil::Compatible( - conv->shape(), ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), - /*feature_group_count=*/1, /*batch_group_count=*/1, - conv_window, conv_dnums) - .ValueOrDie())); + conv->shape(), + ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), + /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window, + conv_dnums, /*preferred_element_type=*/absl::nullopt) + .ValueOrDie())); auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = @@ -337,10 +340,12 @@ TEST_F(GpuConvRewriterTest, BackwardInputConvolve1x1Filter) { conv_window.mutable_dimensions(1)->set_base_dilation(2); builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeInference::InferConvolveShape(output->shape(), kernel->shape(), - /*feature_group_count=*/1, - /*batch_group_count=*/1, conv_window, - tf_default_dnums_for_backward_input_) + ShapeInference::InferConvolveShape( + output->shape(), kernel->shape(), + /*feature_group_count=*/1, + /*batch_group_count=*/1, conv_window, + tf_default_dnums_for_backward_input_, + /*preferred_element_type=*/absl::nullopt) .ConsumeValueOrDie(), /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window, @@ -374,7 +379,8 @@ TEST_F(GpuConvRewriterTest, ShapeInference::InferConvolveShape( output->shape(), kernel->shape(), /*feature_group_count=*/1, /*batch_group_count=*/1, default_conv_window_, - tf_default_dnums_for_backward_input_) + tf_default_dnums_for_backward_input_, + /*preferred_element_type=*/absl::nullopt) .ConsumeValueOrDie(), /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1, /*batch_group_count=*/1, default_conv_window_, @@ -431,7 +437,8 @@ TEST_F(GpuConvRewriterTest, BackwardInputConvolveUnevenPaddingOnGradients) { conv->shape(), ShapeInference::InferConvolveShape( output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, /*batch_group_count=*/1, - conv_window, tf_default_dnums_for_backward_input_) + conv_window, tf_default_dnums_for_backward_input_, + /*preferred_element_type=*/absl::nullopt) .ValueOrDie())); auto module = CreateNewVerifiedModule(); @@ -481,7 +488,8 @@ TEST_F(GpuConvRewriterTest, BackwardInputConvolveLowPaddingTooLarge) { conv->shape(), ShapeInference::InferConvolveShape( output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, /*batch_group_count=*/1, - conv_window, tf_default_dnums_for_backward_input_) + conv_window, tf_default_dnums_for_backward_input_, + /*preferred_element_type=*/absl::nullopt) .ValueOrDie())); auto module = CreateNewVerifiedModule(); @@ -535,7 +543,8 @@ TEST_F(GpuConvRewriterTest, BackwardInputConvolveUnevenPaddingOnActivations) { conv->shape(), ShapeInference::InferConvolveShape( output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, /*batch_group_count=*/1, - conv_window, tf_default_dnums_for_backward_input_) + conv_window, tf_default_dnums_for_backward_input_, + /*preferred_element_type=*/absl::nullopt) .ValueOrDie())); auto module = CreateNewVerifiedModule(); @@ -590,7 +599,8 @@ TEST_F(GpuConvRewriterTest, conv->shape(), ShapeInference::InferConvolveShape( output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, /*batch_group_count=*/1, - conv_window, tf_default_dnums_for_backward_input_) + conv_window, tf_default_dnums_for_backward_input_, + /*preferred_element_type=*/absl::nullopt) .ValueOrDie())); auto module = CreateNewVerifiedModule(); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.cc 
b/tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.cc index b152962eb99..3566742a03b 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.cc @@ -37,6 +37,10 @@ NcclCliqueKey::NcclCliqueKey(std::vector<GlobalDeviceId> devices) << GlobalDeviceIdsToString(devices_); } +std::string NcclCliqueKey::ToString() const { + return GlobalDeviceIdsToString(devices_); +} + GpuExecutableRunOptions& GpuExecutableRunOptions::set_gpu_global_device_ids( absl::optional<std::vector<GlobalDeviceId>> gpu_global_device_ids) { gpu_global_device_ids_ = std::move(gpu_global_device_ids); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h b/tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h index 7a43c80121b..58284b2292b 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h @@ -53,6 +53,8 @@ class NcclCliqueKey { const std::vector<GlobalDeviceId>& devices() const { return devices_; } + std::string ToString() const; + private: std::vector<GlobalDeviceId> devices_; }; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h index 34b93ca5b3f..81cfaabd63a 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h @@ -20,6 +20,7 @@ limitations under the License. #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" @@ -50,9 +51,9 @@ class IrEmitterContext { profile_index_map_(profile_index_map), mlir_context_(mlir_context), llvm_module_(llvm_module) { - mlir_context_ - ->loadDialect<mlir::lmhlo::LmhloDialect, mlir::mhlo::MhloDialect, - mlir::StandardOpsDialect>(); + mlir_context_->loadDialect< + mlir::lmhlo::LmhloDialect, mlir::mhlo::MhloDialect, + mlir::StandardOpsDialect, mlir::lmhlo_gpu::LmhloGpuDialect>(); } // Disallow copy and assign. IrEmitterContext(const IrEmitterContext&) = delete; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 68e29bf68c6..6b7a0f03e23 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -167,18 +167,28 @@ int64_t GetAllocationIndex(mlir::BlockArgument func_arg) { .getSExtValue(); } +static int64_t GetMemRefSizeInBytes(mlir::MemRefType type) { + // For i1 memrefs, the underlying allocation is 8 bits. 
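To make the special case concrete (an editorial worked example, not part of the patch): each i1 element is stored in its own byte, so for a memref<100xi1> the helper below returns getNumElements() = 100 bytes, whereas the old getSizeInBits() / 8 computation would report 100 / 8 = 12 bytes and under-size the buffer slice. Non-i1 element types are unaffected; for example, memref<4x8xf32> still yields 4 * 8 * 32 / 8 = 128 bytes.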
+ if (type.getElementType().isInteger(/*width=*/1)) { + return type.getNumElements(); + } else { + return type.getSizeInBits() / CHAR_BIT; + } +} + StatusOr<BufferAllocation::Slice> GetAllocationSliceForMlir( mlir::Value v, absl::Span<const BufferAllocation> allocations) { - int64 size = v.getType().cast<mlir::MemRefType>().getSizeInBits() / 8; + int64 size = GetMemRefSizeInBytes(v.getType().cast<mlir::MemRefType>()); if (auto arg = v.dyn_cast<mlir::BlockArgument>()) { return BufferAllocation::Slice(&allocations[GetAllocationIndex(arg)], 0, size); } - // We match two patterns here: - // * v = ViewOp(arg); - // * v = MemRefReinterpretCastOp(ViewOp(arg)); + // We match the following patterns here: + // base := ViewOp(arg) | get_global_memref (global_memref) + // root := base | MemRefReinterpretCastOp(base) + if (mlir::Operation* op = v.getDefiningOp()) { if (auto cast = mlir::dyn_cast<mlir::MemRefReinterpretCastOp>(op)) { mlir::Value source = cast.getViewSource(); @@ -197,6 +207,14 @@ StatusOr<BufferAllocation::Slice> GetAllocationSliceForMlir( .getValue() .getSExtValue(), size); + } else if (mlir::isa<mlir::GetGlobalMemrefOp>(op)) { + int64_t index = + op->getAttrOfType<mlir::IntegerAttr>("lmhlo.alloc").getInt(); + int64_t offset = + op->getAttrOfType<mlir::IntegerAttr>("lmhlo.slice_offset").getInt(); + int64_t size = + op->getAttrOfType<mlir::IntegerAttr>("lmhlo.slice_size").getInt(); + return BufferAllocation::Slice(&allocations[index], offset, size); } return Unimplemented("MemRefReinterpretCastOp has to wrap a ViewOp"); } @@ -554,13 +572,24 @@ Status IrEmitterUnnested::DefaultActionForMlir(MlirEmitterInput input) { // because we don't have a fully functioning LMHLO graph yet. mlir::Location loc = input.op->getLoc(); - mlir::lmhlo::FusionOp fusion = nullptr; + mlir::lmhlo::FusionOp fusion = + mlir::OpBuilder(input.op).create<mlir::lmhlo::FusionOp>( + loc, llvm::ArrayRef<mlir::NamedAttribute>()); Shape output_shape; + mlir::OpBuilder b(&fusion.region()); + + const auto load_memrefs = [loc, &b](mlir::ValueRange range) { + std::vector<mlir::Value> operands; + for (mlir::Value memref : range) { + auto load = b.create<mlir::TensorLoadOp>(loc, memref); + HloFunctionImporter::SetLayoutForMlir(load, + TypeToShape(memref.getType())); + operands.push_back(load); + } + return operands; + }; + if (auto copy = mlir::dyn_cast<mlir::lmhlo::CopyOp>(input.op)) { - fusion = mlir::OpBuilder(copy).create<mlir::lmhlo::FusionOp>( - loc, llvm::ArrayRef<mlir::NamedAttribute>()); - copy.getOperation()->moveBefore(&fusion.region().front().back()); - mlir::OpBuilder b(copy); auto operand = b.create<mlir::TensorLoadOp>(loc, copy.operand()); HloFunctionImporter::SetLayoutForMlir( operand, TypeToShape(copy.operand().getType())); @@ -568,15 +597,41 @@ Status IrEmitterUnnested::DefaultActionForMlir(MlirEmitterInput input) { output_shape = TypeToShape(copy.output().getType()); HloFunctionImporter::SetLayoutForMlir(fused_copy, output_shape); b.create<mlir::TensorStoreOp>(loc, fused_copy, copy.output()); - copy.getOperation()->erase(); + } else if (auto reduce = mlir::dyn_cast<mlir::lmhlo::ReduceOp>(input.op)) { + std::vector<mlir::Value> operands = load_memrefs(reduce.operands()); + std::vector<mlir::Value> init_values = load_memrefs(reduce.init_values()); + auto fused_reduce = b.create<mlir::mhlo::ReduceOp>( + loc, operands, init_values, reduce.dimensions()); + fused_reduce.body().takeBody(reduce.body()); + CHECK_EQ(fused_reduce.getNumResults(), reduce.out().size()); + std::vector<Shape> output_shapes; + for (int i = 0; 
i < reduce.out().size(); i++) { + b.create<mlir::TensorStoreOp>(loc, fused_reduce.getResult(i), + reduce.out()[i]); + auto shape = TypeToShape(reduce.out()[i].getType()); + if (i == 0) { + HloFunctionImporter::SetLayoutForMlir(fused_reduce, shape); + } + output_shapes.push_back(shape); + } + if (output_shapes.size() == 1) { + output_shape = output_shapes[0]; + } else { + output_shape = ShapeUtil::MakeTupleShape(output_shapes); + } } else { input.op->dump(); LOG(FATAL) << "Unimplemented default action for mlir op"; } + input.op->erase(); input.op = fusion; - auto ret = EmitLoopFusionFromMlir( - input, output_shape, - ComputeMaxUnrollFactor(output_shape, hlo_module_config_)); + int unroll_factor = 1; + // TODO(timshen): Port MayPreventVectorization as we add more ops into this + // function. + if (output_shape.IsArray()) { + unroll_factor = ComputeMaxUnrollFactor(output_shape, hlo_module_config_); + } + auto ret = EmitLoopFusionFromMlir(input, output_shape, unroll_factor); return ret; } @@ -893,7 +948,8 @@ Status IrEmitterUnnested::HandleTriangularSolve(HloInstruction* hlo) { // This function won't be needed once ElementalIrEmitter migrates to take MHLO // instead. static Status ProcessFusionForConversion(mlir::Region* region, - std::vector<Shape>* operand_shapes) { + std::vector<Shape>* operand_shapes, + std::vector<Shape>* output_shapes) { std::vector<mlir::TensorLoadOp> loads; std::vector<mlir::TensorStoreOp> stores; @@ -913,8 +969,7 @@ static Status ProcessFusionForConversion(mlir::Region* region, auto arg = region->addArgument(load.getType()); load.replaceAllUsesWith(arg); Shape shape = TypeToShape(load.getType()); - auto attr = mlir::GetLayoutFromMlirHlo(load); - if (attr) { + if (auto attr = mlir::GetLayoutFromMlirHlo(load)) { std::vector<int64> minor_to_major; absl::c_transform( attr, std::back_inserter(minor_to_major), @@ -930,6 +985,16 @@ static Status ProcessFusionForConversion(mlir::Region* region, std::vector<mlir::Value> returned_values; for (auto store : stores) { + Shape shape = TypeToShape(store.memref().getType()); + if (auto attr = mlir::GetLayoutFromMlirHlo(store)) { + std::vector<int64> minor_to_major; + absl::c_transform( + attr, std::back_inserter(minor_to_major), + std::function<int64(const llvm::APInt&)>(&llvm::APInt::getZExtValue)); + *shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); + } + output_shapes->push_back(shape); + returned_values.push_back(store.tensor()); store.erase(); } @@ -1254,12 +1319,14 @@ Status IrEmitterUnnested::EmitExtraOutputsForReduce( } Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { + TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(reduce)); + if (IsReductionFromOrToContiguousDimensions(*reduce) && reduce->shape().IsArray()) { return EmitReductionFromOrToContiguousDimensions(reduce, {reduce}); } - return IrEmitter::HandleReduce(reduce); + return DefaultActionForMlir(input); } Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) { @@ -1325,23 +1392,23 @@ Status IrEmitterUnnested::HandleSelectAndScatter( "Dilation for SelectAndScatter not implemented on GPU."); } - TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> initializer_thunk, - BuildInitializerThunk(select_and_scatter)); - TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(select_and_scatter)); - return EmitSelectAndScatterFromMlir(input, std::move(initializer_thunk)); + return EmitSelectAndScatterFromMlir(input); } Status IrEmitterUnnested::EmitSelectAndScatterFromMlir( - MlirEmitterInput mlir_input, std::unique_ptr<Thunk>&& initializer_thunk) { + 
MlirEmitterInput mlir_input) { auto select_and_scatter_op = ::mlir::cast<::mlir::lmhlo::SelectAndScatterOp>(mlir_input.op); std::string name = mlir::GetNameFromLoc(select_and_scatter_op.getLoc()); std::vector<std::unique_ptr<Thunk>> thunks; - thunks.push_back(std::move(initializer_thunk)); - + thunks.emplace_back(); + TF_ASSIGN_OR_RETURN(thunks.back(), + BuildInitializerThunkForMlir( + mlir_input.op, select_and_scatter_op.init_value(), + select_and_scatter_op.out())); absl::Span<const BufferAllocation> allocations( ir_emitter_context_->buffer_assignment().Allocations()); @@ -1862,9 +1929,10 @@ IrEmitterUnnested::GetOrCreateSubComputationFromRegion(mlir::Region* region, bool is_fusion) { std::unique_ptr<HloModule>& module = scratch_nested_computations_[region]; if (module == nullptr) { - std::vector<Shape> operand_shapes; + std::vector<Shape> operand_shapes, output_shapes; if (is_fusion) { - TF_RETURN_IF_ERROR(ProcessFusionForConversion(region, &operand_shapes)); + TF_RETURN_IF_ERROR( + ProcessFusionForConversion(region, &operand_shapes, &output_shapes)); } xla::XlaComputation xla_computation; @@ -1878,6 +1946,62 @@ IrEmitterUnnested::GetOrCreateSubComputationFromRegion(mlir::Region* region, module, HloModule::CreateFromProto(xla_computation.proto(), HloModuleConfig(program_shape))); + if (is_fusion) { + HloComputation* fused_computation = module->entry_computation(); + CHECK_EQ(operand_shapes.size(), fused_computation->num_parameters()); + for (int i = 0; i < fused_computation->num_parameters(); i++) { + *fused_computation->parameter_instruction(i) + ->mutable_shape() + ->mutable_layout() = operand_shapes[i].layout(); + } + HloInstruction* root = fused_computation->root_instruction(); + // Manually fold Tuple(GTE(a, 0), GTE(a, 1), GTE(a, 2), ...) to a. + // FusedIrEmitter doesn't take GTE ops because we aim to elimiate tuples + // as much as possible. + if (root->opcode() == HloOpcode::kTuple) { + [&] { + HloInstruction* real_root = nullptr; + int expected_tuple_index = 0; + for (HloInstruction* operand : root->operands()) { + if (operand->opcode() != HloOpcode::kGetTupleElement) { + return; + } + if (real_root == nullptr) { + real_root = operand->mutable_operand(0); + } else if (real_root != operand->operand(0)) { + return; + } + if (expected_tuple_index != operand->tuple_index()) { + return; + } + expected_tuple_index++; + } + fused_computation->set_root_instruction(real_root); + std::vector<HloInstruction*> to_be_removed; + to_be_removed.push_back(root); + for (HloInstruction* operand : root->operands()) { + to_be_removed.push_back(operand); + } + for (auto instr : to_be_removed) { + TF_CHECK_OK(fused_computation->RemoveInstruction(instr)); + } + + root = real_root; + }(); + } + + if (output_shapes.size() > 1) { + CHECK(root->shape().IsTuple()); + CHECK_EQ(root->shape().tuple_shapes_size(), output_shapes.size()); + + for (int i = 0; i < output_shapes.size(); i++) { + *root->mutable_shape()->mutable_tuple_shapes(i) = output_shapes.at(i); + } + } else { + CHECK_EQ(1, output_shapes.size()); + *root->mutable_shape() = output_shapes[0]; + } + } // Post-process the generated computation: // * Sanitize constant names, so that they can be used as LLVM global // symbols. 
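// The hunk above manually folds a root of the form Tuple(GTE(x, 0), GTE(x, 1),
// ..., GTE(x, n-1)) back into x before the imported layouts are applied. As a
// reading aid, here is a minimal standalone sketch of the pattern that fold
// targets; FindForwardedTupleOperand is an assumed name and the helper is
// illustrative only, reusing the HloInstruction/HloOpcode calls that already
// appear in the hunk.

#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"

// Returns the tuple operand that `root` merely forwards, i.e. x when `root`
// is exactly Tuple(GTE(x, 0), GTE(x, 1), ..., GTE(x, n-1)) over a single x
// with indices in positional order; returns nullptr otherwise.
static HloInstruction* FindForwardedTupleOperand(HloInstruction* root) {
  if (root->opcode() != HloOpcode::kTuple) return nullptr;
  HloInstruction* source = nullptr;
  int64_t next_tuple_index = 0;
  for (HloInstruction* operand : root->operands()) {
    if (operand->opcode() != HloOpcode::kGetTupleElement ||
        operand->tuple_index() != next_tuple_index++) {
      return nullptr;  // Not a plain GTE chain in positional order.
    }
    if (source == nullptr) {
      source = operand->mutable_operand(0);
    } else if (source != operand->operand(0)) {
      return nullptr;  // The GTEs read from different tuples.
    }
  }
  return source;  // Usable as the new root; the tuple and GTEs become dead.
}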
@@ -1887,22 +2011,13 @@ IrEmitterUnnested::GetOrCreateSubComputationFromRegion(mlir::Region* region, if (instr->opcode() == HloOpcode::kConstant) { instr->SetAndSanitizeName(llvm_ir::SanitizeConstantName(*instr)); } - if (instr->shape().IsTuple()) { - TF_ASSIGN_OR_RETURN(*instr->mutable_shape(), - ShapeInference::InferVariadicOpShape( - instr->opcode(), instr->operands())); + if (instr->shape().IsTuple() && + computation == module->entry_computation() && + instr != computation->root_instruction()) { + return InternalError("Non-root tuple types are not handled."); } } } - if (is_fusion) { - HloComputation* fused_computation = module->entry_computation(); - CHECK_EQ(operand_shapes.size(), fused_computation->num_parameters()); - for (int i = 0; i < fused_computation->num_parameters(); i++) { - *fused_computation->parameter_instruction(i) - ->mutable_shape() - ->mutable_layout() = operand_shapes[i].layout(); - } - } } return module->entry_computation(); } @@ -2620,6 +2735,45 @@ IrEmitterUnnested::BuildKernelThunkForMlir( ir_arrays); } +std::unique_ptr<Thunk> IrEmitterUnnested::BuildConstantInitializerThunk( + absl::Span<const uint8> init_value, const BufferAllocation::Slice& dest, + const Shape& output_shape) { + int64 num_bytes = init_value.size(); + if (absl::c_all_of(init_value, [](uint8 byte) { return byte == 0; })) { + return absl::make_unique<MemzeroThunk>(Thunk::ThunkInfo(), dest); + } + + // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by + // repeating the literal 4 or 2 times, so long as the destination buffer is + // an even multiple of 32 bits long. + if ((num_bytes == 1 || num_bytes == 2) && + ShapeUtil::ByteSizeOf(output_shape) % 4 == 0) { + uint16 pattern16; + if (num_bytes == 1) { + uint8 b = init_value.front(); + pattern16 = uint16{b} | (uint16{b} << 8); + } else { + memcpy(&pattern16, init_value.data(), sizeof(pattern16)); + } + uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16); + return absl::make_unique<Memset32BitValueThunk>(Thunk::ThunkInfo(), + pattern32, dest); + } + + // If the literal is an even multiple of 32 bits wide, we can emit a 32-bit + // memset so long as all 32-bit words of the scalar are equal to each other. + if (num_bytes >= 4 && num_bytes % 4 == 0 && + memcmp(init_value.data(), init_value.data() + 4, init_value.size() - 4) == + 0) { + uint32 word; + memcpy(&word, init_value.data(), sizeof(word)); + return absl::make_unique<Memset32BitValueThunk>(Thunk::ThunkInfo(), word, + dest); + } + + return nullptr; +} + StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk( HloInstruction* hlo, const ShapeIndex& index) { bool fused = HloOpcode::kFusion == hlo->opcode(); @@ -2661,43 +2815,14 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk( CHECK(ShapeUtil::IsScalar(init_value->shape())); int64 num_bytes = ShapeUtil::ByteSizeOfElements(init_value->shape()); const auto& literal = init_value->literal(); - - // Are all the bytes of this scalar equal to 0? If so, we can create a - // MemzeroThunk. absl::Span<const uint8> literal_bytes( reinterpret_cast<const uint8*>(literal.untyped_data()), num_bytes); - if (absl::c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) { - return {absl::make_unique<MemzeroThunk>(Thunk::ThunkInfo(), - GetAllocationSlice(*hlo, index))}; - } - - // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by - // repeating the literal 4 or 2 times, so long as the destination buffer is - // an even multiple of 32 bits long. 
const Shape& output_shape = ShapeUtil::GetSubshape(hlo->shape(), index); - if ((num_bytes == 1 || num_bytes == 2) && - ShapeUtil::ByteSizeOf(output_shape) % 4 == 0) { - uint16 pattern16; - if (num_bytes == 1) { - uint8 b = literal_bytes.front(); - pattern16 = uint16{b} | (uint16{b} << 8); - } else { - memcpy(&pattern16, literal_bytes.data(), sizeof(pattern16)); - } - uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16); - return {absl::make_unique<Memset32BitValueThunk>( - Thunk::ThunkInfo(), pattern32, GetAllocationSlice(*hlo, index))}; - } - // If the literal is an even multiple of 32 bits wide, we can emit a 32-bit - // memset so long as all 32-bit words of the scalar are equal to each other. - if (num_bytes >= 4 && num_bytes % 4 == 0 && - memcmp(literal_bytes.data(), literal_bytes.data() + 4, - literal_bytes.size() - 4) == 0) { - uint32 word; - memcpy(&word, literal_bytes.data(), sizeof(word)); - return {absl::make_unique<Memset32BitValueThunk>( - Thunk::ThunkInfo(), word, GetAllocationSlice(*hlo, index))}; + auto thunk = BuildConstantInitializerThunk( + literal_bytes, GetAllocationSlice(*hlo, index), output_shape); + if (thunk) { + return thunk; } } @@ -2744,6 +2869,83 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk( return {std::move(kernel_thunk)}; } +StatusOr<std::unique_ptr<Thunk>> +IrEmitterUnnested::BuildInitializerThunkForMlir(mlir::Operation* op, + mlir::Value init_value, + mlir::Value dest) { + // initial value must be a scalar memref. + auto init_type = init_value.getType().dyn_cast<mlir::MemRefType>(); + TF_RET_CHECK(init_type.getRank() == 0); + + const Shape dest_shape = TypeToShape(dest.getType()); + if (auto get_global_memref = mlir::dyn_cast_or_null<mlir::GetGlobalMemrefOp>( + init_value.getDefiningOp())) { + auto global_memref = + mlir::SymbolTable::lookupNearestSymbolFrom<mlir::GlobalMemrefOp>( + get_global_memref, get_global_memref.name()); + if (global_memref.constant() && global_memref.initial_value()) { + // If the initial value happens to be a constant, generate a specialized + // thunk. + auto const_init = global_memref.initial_value() + .getValue() + .cast<mlir::DenseElementsAttr>(); + + Shape init_shape = TypeToShape(init_value.getType()); + CHECK(ShapeUtil::IsScalar(init_shape)); + int64 num_bytes = ShapeUtil::ByteSizeOfElements(init_shape); + bool bool_init; + absl::Span<const uint8> literal_bytes( + reinterpret_cast<const uint8*>(const_init.getRawData().data()), + num_bytes); + if (init_type.getElementTypeBitWidth() == 1) { + TF_RET_CHECK(num_bytes == 1); + bool_init = *const_init.getBoolValues().begin(); + literal_bytes = + absl::MakeSpan(reinterpret_cast<const uint8_t*>(&bool_init), 1); + } + + absl::Span<const BufferAllocation> allocations( + ir_emitter_context_->buffer_assignment().Allocations()); + TF_ASSIGN_OR_RETURN(auto dest_slice, + GetAllocationSliceForMlir(dest, allocations)); + + auto thunk = + BuildConstantInitializerThunk(literal_bytes, dest_slice, dest_shape); + if (thunk) { + return {std::move(thunk)}; + } + } + } + + // Otherwise fall back to our slow initializer code. The thunk in this case + // will just need the IR arrays for the initial value and the destination. 
+ std::vector<llvm_ir::IrArray> ir_arrays; + TF_ASSIGN_OR_RETURN( + std::unique_ptr<KernelThunk> kernel_thunk, + BuildKernelThunkForMlir(op, {init_value, dest}, Thunk::ThunkInfo(), {}, + &ir_arrays)); + const llvm_ir::IrArray init_array = ir_arrays[0]; + const llvm_ir::IrArray dest_array = ir_arrays[1]; + + LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + dest_shape, ir_emitter_context_->gpu_device_info()); + UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(), + ir_emitter_context_->llvm_module()); + + // TODO(jurahul): Handled fused cases. + // In the unfused case the element is already there, just read from it. + std::string name = mlir::GetNameFromLoc(op->getLoc()); + TF_RETURN_IF_ERROR(ParallelLoopEmitter( + [=](const IrArray::Index& index) { + return init_array.EmitReadArrayElement(index, &b_); + }, + dest_array, launch_dimensions, &b_) + .EmitLoop(mlir::GetNameFromLoc(op->getLoc()))); + + // Convert unique_ptr<KernelThunk> to StatusOr<unique_ptr<Thunk>>. + return {std::move(kernel_thunk)}; +} + namespace { // Checks that the buffers corresponding to the given two HLOs share the same diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 197a25e2cc1..68661ce3034 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -175,8 +175,7 @@ class IrEmitterUnnested : public IrEmitter, Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleReduce(HloInstruction* reduce) override; Status HandleSelectAndScatter(HloInstruction* instruction) override; - Status EmitSelectAndScatterFromMlir( - MlirEmitterInput mlir_input, std::unique_ptr<Thunk>&& initializer_thunk); + Status EmitSelectAndScatterFromMlir(MlirEmitterInput mlir_input); Status HandleTuple(HloInstruction* tuple) override; Status HandleWhile(HloInstruction* xla_while) override; Status HandleInfeed(HloInstruction* xla_infeed) override; @@ -633,6 +632,13 @@ class IrEmitterUnnested : public IrEmitter, StatusOr<std::unique_ptr<Thunk>> BuildInitializerThunk( HloInstruction* hlo, const ShapeIndex& index = {}); + std::unique_ptr<Thunk> BuildConstantInitializerThunk( + absl::Span<const uint8> init_value, const BufferAllocation::Slice& dest, + const Shape& output_shape); + + StatusOr<std::unique_ptr<Thunk>> BuildInitializerThunkForMlir( + mlir::Operation* op, mlir::Value init_value, mlir::Value dest); + // Returns a WhileThunk that invokes thunk sequences for 'condition' and // 'body' sub-computations of while instruction 'hlo'. 
StatusOr<std::unique_ptr<Thunk>> BuildWhileThunk(const HloInstruction* hlo); diff --git a/tensorflow/compiler/xla/service/gpu/tests/reduce_nested.hlo b/tensorflow/compiler/xla/service/gpu/tests/reduce_nested.hlo index baf3ccf03d5..8e9ee58fbf2 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/reduce_nested.hlo +++ b/tensorflow/compiler/xla/service/gpu/tests/reduce_nested.hlo @@ -13,15 +13,15 @@ // CHECK: %[[VAL_9:.*]] = alloca float, align 4 // CHECK: %[[VAL_10:.*]] = alloca float, align 4 // CHECK: %[[VAL_11:.*]] = getelementptr inbounds i8, i8* %[[VAL_12:.*]], i64 0 -// CHECK: %[[VAL_13:.*]] = bitcast i8* %[[VAL_11]] to [2 x i8*]* +// CHECK: %[[VAL_13:.*]] = bitcast i8* %[[VAL_11]] to [100 x [200 x [300 x float]]]* // CHECK: %[[VAL_14:.*]] = getelementptr inbounds i8, i8* %[[VAL_15:.*]], i64 0 -// CHECK: %[[VAL_16:.*]] = bitcast i8* %[[VAL_14]] to [200 x float]* +// CHECK: %[[VAL_16:.*]] = bitcast i8* %[[VAL_14]] to [100 x [200 x [300 x float]]]* // CHECK: %[[VAL_17:.*]] = getelementptr inbounds i8, i8* %[[VAL_18:.*]], i64 0 // CHECK: %[[VAL_19:.*]] = bitcast i8* %[[VAL_17]] to [200 x float]* // CHECK: %[[VAL_20:.*]] = getelementptr inbounds i8, i8* %[[VAL_21:.*]], i64 0 -// CHECK: %[[VAL_22:.*]] = bitcast i8* %[[VAL_20]] to [100 x [200 x [300 x float]]]* +// CHECK: %[[VAL_22:.*]] = bitcast i8* %[[VAL_20]] to [200 x float]* // CHECK: %[[VAL_23:.*]] = getelementptr inbounds i8, i8* %[[VAL_24:.*]], i64 0 -// CHECK: %[[VAL_25:.*]] = bitcast i8* %[[VAL_23]] to [100 x [200 x [300 x float]]]* +// CHECK: %[[VAL_25:.*]] = bitcast i8* %[[VAL_23]] to [2 x i8*]* // CHECK: %[[VAL_26:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x() // CHECK: %[[VAL_27:.*]] = icmp eq i32 0, %[[VAL_26]] // CHECK: %[[VAL_28:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() @@ -41,11 +41,11 @@ // CHECK: d.in_bounds-after: ; preds = %[[VAL_43:.*]], %[[VAL_32]] // CHECK: ret void // CHECK: emit_mof_tuple-true: ; preds = %[[VAL_33]] -// CHECK: %[[VAL_44:.*]] = bitcast [200 x float]* %[[VAL_16]] to i8* -// CHECK: %[[VAL_45:.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* %[[VAL_13]], i64 0, i64 0 +// CHECK: %[[VAL_44:.*]] = bitcast [200 x float]* %[[VAL_19]] to i8* +// CHECK: %[[VAL_45:.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* %[[VAL_25]], i64 0, i64 0 // CHECK: store i8* %[[VAL_44]], i8** %[[VAL_45]], align 8 -// CHECK: %[[VAL_46:.*]] = bitcast [200 x float]* %[[VAL_19]] to i8* -// CHECK: %[[VAL_47:.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* %[[VAL_13]], i64 0, i64 1 +// CHECK: %[[VAL_46:.*]] = bitcast [200 x float]* %[[VAL_22]] to i8* +// CHECK: %[[VAL_47:.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* %[[VAL_25]], i64 0, i64 1 // CHECK: store i8* %[[VAL_46]], i8** %[[VAL_47]], align 8 // CHECK: br label %[[VAL_32]] // CHECK: d.in_bounds-true: ; preds = %[[VAL_32]] @@ -55,23 +55,23 @@ // CHECK: store float %[[VAL_49]], float* %[[VAL_9]], align 4 // CHECK: store i32 0, i32* %[[VAL_8]], align 4 // CHECK: br label %[[VAL_50:.*]] -// CHECK: d.inner.loop_header.reduction_dim.0: ; preds = %[[VAL_51:.*]], %[[VAL_41]] +// CHECK: reduce.13.inner.loop_header.reduction_dim.0: ; preds = %[[VAL_51:.*]], %[[VAL_41]] // CHECK: %[[VAL_52:.*]] = load i32, i32* %[[VAL_8]], align 4 // CHECK: %[[VAL_53:.*]] = icmp uge i32 %[[VAL_52]], 100 // CHECK: br i1 %[[VAL_53]], label %[[VAL_43]], label %[[VAL_54:.*]] -// CHECK: d.inner.loop_body.reduction_dim.0: ; preds = %[[VAL_50]] +// CHECK: reduce.13.inner.loop_body.reduction_dim.0: ; preds = %[[VAL_50]] // CHECK: store i32 0, i32* %[[VAL_7]], align 4 // CHECK: br label 
%[[VAL_55:.*]] -// CHECK: d.inner.loop_header.reduction_dim.2: ; preds = %[[VAL_56:.*]], %[[VAL_54]] +// CHECK: reduce.13.inner.loop_header.reduction_dim.2: ; preds = %[[VAL_56:.*]], %[[VAL_54]] // CHECK: %[[VAL_57:.*]] = load i32, i32* %[[VAL_7]], align 4 // CHECK: %[[VAL_58:.*]] = icmp uge i32 %[[VAL_57]], 300 // CHECK: br i1 %[[VAL_58]], label %[[VAL_51]], label %[[VAL_56]] -// CHECK: d.inner.loop_body.reduction_dim.2: ; preds = %[[VAL_55]] +// CHECK: reduce.13.inner.loop_body.reduction_dim.2: ; preds = %[[VAL_55]] // CHECK: %[[VAL_59:.*]] = load float, float* %[[VAL_10]], align 4 // CHECK: %[[VAL_60:.*]] = load float, float* %[[VAL_9]], align 4 -// CHECK: %[[VAL_61:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_22]], i32 0, i32 %[[VAL_52]], i32 %[[VAL_39]], i32 %[[VAL_57]] +// CHECK: %[[VAL_61:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_13]], i32 0, i32 %[[VAL_52]], i32 %[[VAL_39]], i32 %[[VAL_57]] // CHECK: %[[VAL_62:.*]] = load float, float* %[[VAL_61]], align 4, !invariant.load !4 -// CHECK: %[[VAL_63:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_25]], i32 0, i32 %[[VAL_52]], i32 %[[VAL_39]], i32 %[[VAL_57]] +// CHECK: %[[VAL_63:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_16]], i32 0, i32 %[[VAL_52]], i32 %[[VAL_39]], i32 %[[VAL_57]] // CHECK: %[[VAL_64:.*]] = load float, float* %[[VAL_63]], align 4, !invariant.load !4 // CHECK: store float %[[VAL_59]], float* %[[VAL_5]], align 4 // CHECK: store float %[[VAL_60]], float* %[[VAL_4]], align 4 @@ -83,7 +83,7 @@ // CHECK: %[[VAL_67:.*]] = bitcast float* %[[VAL_1]] to i8* // CHECK: %[[VAL_68:.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* %[[VAL_6]], i64 0, i64 1 // CHECK: store i8* %[[VAL_67]], i8** %[[VAL_68]], align 8 -// CHECK: call void @Add(float* %[[VAL_5]], float* %[[VAL_4]], float* %[[VAL_3]], float* %[[VAL_2]], [2 x i8*]* %[[VAL_6]]) +// CHECK: call void @region_1_5(float* %[[VAL_5]], float* %[[VAL_4]], float* %[[VAL_3]], float* %[[VAL_2]], [2 x i8*]* %[[VAL_6]]) // CHECK: %[[VAL_69:.*]] = load float, float* %[[VAL_0]], align 4 // CHECK: %[[VAL_70:.*]] = load float, float* %[[VAL_1]], align 4 // CHECK: store float %[[VAL_69]], float* %[[VAL_10]], align 4 @@ -91,21 +91,21 @@ // CHECK: %[[VAL_71:.*]] = add nuw nsw i32 %[[VAL_57]], 1 // CHECK: store i32 %[[VAL_71]], i32* %[[VAL_7]], align 4 // CHECK: br label %[[VAL_55]] -// CHECK: d.inner.loop_exit.reduction_dim.2: ; preds = %[[VAL_55]] +// CHECK: reduce.13.inner.loop_exit.reduction_dim.2: ; preds = %[[VAL_55]] // CHECK: %[[VAL_72:.*]] = add nuw nsw i32 %[[VAL_52]], 1 // CHECK: store i32 %[[VAL_72]], i32* %[[VAL_8]], align 4 // CHECK: br label %[[VAL_50]] -// CHECK: d.inner.loop_exit.reduction_dim.0: ; preds = %[[VAL_50]] +// CHECK: reduce.13.inner.loop_exit.reduction_dim.0: ; preds = %[[VAL_50]] // CHECK: %[[VAL_73:.*]] = load float, float* %[[VAL_10]], align 4 // CHECK: %[[VAL_74:.*]] = insertvalue { float, float } undef, float %[[VAL_73]], 0 // CHECK: %[[VAL_75:.*]] = load float, float* %[[VAL_9]], align 4 // CHECK: %[[VAL_76:.*]] = insertvalue { float, float } %[[VAL_74]], float %[[VAL_75]], 1 // CHECK: %[[VAL_77:.*]] = extractvalue { float, float } %[[VAL_76]], 0 -// CHECK: %[[VAL_78:.*]] = bitcast [200 x float]* %[[VAL_16]] to float* +// CHECK: %[[VAL_78:.*]] = bitcast [200 x float]* %[[VAL_19]] to float* // CHECK: %[[VAL_79:.*]] = getelementptr inbounds 
float, float* %[[VAL_78]], i32 %[[VAL_37]] // CHECK: store float %[[VAL_77]], float* %[[VAL_79]], align 4 // CHECK: %[[VAL_80:.*]] = extractvalue { float, float } %[[VAL_76]], 1 -// CHECK: %[[VAL_81:.*]] = bitcast [200 x float]* %[[VAL_19]] to float* +// CHECK: %[[VAL_81:.*]] = bitcast [200 x float]* %[[VAL_22]] to float* // CHECK: %[[VAL_82:.*]] = getelementptr inbounds float, float* %[[VAL_81]], i32 %[[VAL_37]] // CHECK: store float %[[VAL_80]], float* %[[VAL_82]], align 4 // CHECK: br label %[[VAL_42]] diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index a90d2828000..30ee9853903 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -94,13 +94,15 @@ StatusOr<HloInstruction*> MakeConvolveHlo( HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count, int64 batch_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, - const PrecisionConfig& precision_config) { + const PrecisionConfig& precision_config, + absl::optional<PrimitiveType> preferred_element_type) { HloComputation* computation = lhs->parent(); CHECK_EQ(computation, rhs->parent()); - TF_ASSIGN_OR_RETURN(Shape convolve_shape, - ShapeInference::InferConvolveShape( - lhs->shape(), rhs->shape(), feature_group_count, - batch_group_count, window, dimension_numbers)); + TF_ASSIGN_OR_RETURN( + Shape convolve_shape, + ShapeInference::InferConvolveShape( + lhs->shape(), rhs->shape(), feature_group_count, batch_group_count, + window, dimension_numbers, preferred_element_type)); return computation->AddInstruction(HloInstruction::CreateConvolve( convolve_shape, lhs, rhs, feature_group_count, batch_group_count, window, dimension_numbers, precision_config)); @@ -281,14 +283,17 @@ HloInstruction* MakeIotaHlo(HloComputation* computation, const Shape& shape, HloInstruction::CreateIota(shape, iota_dimension)); } -StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, - const DotDimensionNumbers& dim_numbers, - const PrecisionConfig& precision_config) { +StatusOr<HloInstruction*> MakeDotHlo( + HloInstruction* lhs, HloInstruction* rhs, + const DotDimensionNumbers& dim_numbers, + const PrecisionConfig& precision_config, + absl::optional<PrimitiveType> preferred_element_type) { HloComputation* computation = lhs->parent(); CHECK_EQ(computation, rhs->parent()); TF_ASSIGN_OR_RETURN( Shape dot_shape, - ShapeInference::InferDotOpShape(lhs->shape(), rhs->shape(), dim_numbers)); + ShapeInference::InferDotOpShape(lhs->shape(), rhs->shape(), dim_numbers, + preferred_element_type)); return computation->AddInstruction(HloInstruction::CreateDot( dot_shape, lhs, rhs, dim_numbers, precision_config)); } diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h index 53eeeffb858..1fa2a3faeea 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -59,11 +59,14 @@ StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand, // Creates a convolution HLO instruction and adds it to the computation // containing `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation). +// If the result shape has integral element type, an optional +// preferred_element_type can be specified to override the element type. 
StatusOr<HloInstruction*> MakeConvolveHlo( HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count, int64 batch_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, - const PrecisionConfig& precision_config); + const PrecisionConfig& precision_config, + absl::optional<PrimitiveType> preferred_element_type); // Creates a transpose HLO instruction and adds it to the computation containing // `operand`. @@ -128,10 +131,14 @@ HloInstruction* MakeIotaHlo(HloComputation* computation, const Shape& shape, int64 iota_dimension); // Creates a Dot HLO instruction and adds it to the computation containing `lhs` -// and `rhs` (both must be in the same computation). -StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, - const DotDimensionNumbers& dim_numbers, - const PrecisionConfig& precision_config); +// and `rhs` (both must be in the same computation). If the result shape has +// integral element type, an optional preferred_element_type can be specified to +// override the element type. +StatusOr<HloInstruction*> MakeDotHlo( + HloInstruction* lhs, HloInstruction* rhs, + const DotDimensionNumbers& dim_numbers, + const PrecisionConfig& precision_config, + absl::optional<PrimitiveType> preferred_element_type); // Creates a Map HLO instruction and adds it to the computation containing the // operands. All operands must be in the same computation. diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 1480fdd53ac..57f96c8eea2 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -388,9 +388,10 @@ StatusOr<Literal> HloEvaluator::EvaluateDotOp( std::unique_ptr<HloInstruction> rhs_instr = HloInstruction::CreateConstant(rhs.Clone()); - TF_ASSIGN_OR_RETURN( - Shape dot_shape, - ShapeInference::InferDotOpShape(lhs.shape(), rhs.shape(), dim_numbers)); + TF_ASSIGN_OR_RETURN(Shape dot_shape, + ShapeInference::InferDotOpShape( + lhs.shape(), rhs.shape(), dim_numbers, + /*preferred_element_type=*/absl::nullopt)); std::unique_ptr<HloInstruction> cloned_instruction = HloInstruction::CreateDot(dot_shape, lhs_instr.get(), rhs_instr.get(), diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 01ad536e033..7f2157a85bd 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -4567,6 +4567,50 @@ TEST_F(HloEvaluatorTest, MapBF16) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } +TEST_F(HloEvaluatorTest, MapS16) { + const absl::string_view hlo_text = R"( + HloModule test + + map_computation { + p = s16[] parameter(0) + add = s16[] add(p, p) + ROOT conv = f32[] convert(add) + } + + ENTRY CopyStartCopyDone { + c = s16[3] constant({1, 2, 3}) + ROOT map = f32[3] map(c), to_apply=map_computation + } + )"; + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + Literal expected = LiteralUtil::CreateR1<float>({2.f, 4.f, 6.f}); + TF_ASSERT_OK_AND_ASSIGN( + Literal result, HloEvaluator().Evaluate(*m_->entry_computation(), {})); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); +} + +TEST_F(HloEvaluatorTest, MapU16) { + const absl::string_view hlo_text = R"( + HloModule test + + map_computation { + p = u16[] parameter(0) + add = u16[] add(p, p) + ROOT conv = f32[] convert(add) + } + + ENTRY CopyStartCopyDone { + c = u16[3] constant({1, 2, 3}) + ROOT map = 
f32[3] map(c), to_apply=map_computation + } + )"; + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + Literal expected = LiteralUtil::CreateR1<float>({2.f, 4.f, 6.f}); + TF_ASSERT_OK_AND_ASSIGN( + Literal result, HloEvaluator().Evaluate(*m_->entry_computation(), {})); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); +} + TEST_F(HloEvaluatorTest, DotUpcast) { const absl::string_view hlo_text = R"( HloModule test diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 808c2a4b769..953145df474 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -1290,7 +1290,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { TF_ASSIGN_OR_RETURN(auto inferred_return_shape, ShapeInference::InferConvolveShape( lhs_shape, rhs_shape, conv->feature_group_count(), - conv->batch_group_count(), window, dnums)); + conv->batch_group_count(), window, dnums, + /*preferred_element_type=*/absl::nullopt)); CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) << "return shape set to: " << ShapeUtil::HumanString(result_shape) << " but is inferred to be: " @@ -1769,6 +1770,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<uint8>(map)); break; } + case U16: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<uint16>(map)); + break; + } case U32: { TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<uint32>(map)); break; @@ -1781,6 +1786,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int8>(map)); break; } + case S16: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int16>(map)); + break; + } case S32: { TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int32>(map)); break; diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index f2ea03f063a..e43f68fd257 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 59b1ac31e9b..864f293420b 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -16,10 +16,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_rematerialization.h" #include <algorithm> +#include <iterator> #include <memory> #include <set> #include <string> +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" @@ -153,7 +155,22 @@ struct Item { int64 position; }; +// Data structure meant to record the user of the buffer defined from an Item. 
+// It records also the operand_number from where such use derives, so that +// indirect uses can be better identified (like for example a buffer used +// through a bitcast). +struct ItemUse { + Item* user; + int operand_number; + + ItemUse(Item* user, int op_num) : user(user), operand_number(op_num) {} + bool operator==(const ItemUse& other) const { + return user == other.user && operand_number == other.operand_number; + } +}; + using ItemList = absl::InlinedVector<Item*, 3>; +using UsesList = absl::InlinedVector<ItemUse, 3>; // Class which maintains an ordered list of instructions with fast insertion // before arbitrary elements. @@ -412,11 +429,11 @@ class InstructionList { // has_indirect_users to whether any of the uses is indirect. A use is indirect // if the instruction defining logical_buffer is not an operand of the use. This // can happen via buffer aliasing (eg, tuples). -ItemList GetUsers(const InstructionList& instruction_list, +UsesList GetUsers(const InstructionList& instruction_list, const LogicalBuffer* logical_buffer, const TuplePointsToAnalysis& points_to_analysis, bool* has_indirect_users) { - ItemList users; + UsesList users; // To identify uses iterate through all HloInstruction users of the // BufferAliases of the logical buffer. *has_indirect_users = false; @@ -431,14 +448,18 @@ ItemList GetUsers(const InstructionList& instruction_list, // instruction (the GTE instruction only uses the pointer vector). continue; } - if (buffer_alias.instruction() != logical_buffer->instruction()) { + if (buffer_alias.instruction() != logical_buffer->instruction() && + buffer_alias.instruction()->opcode() != HloOpcode::kBitcast) { *has_indirect_users = true; } // A buffer may be used by the instruction via more than one alias. For // example, a buffer which appears in more than one element of a tuple. Item* user_item = instruction_list.GetItem(user); - if (!absl::c_linear_search(users, user_item)) { - users.push_back(user_item); + for (int64 op_idx : user->OperandIndices(buffer_alias.instruction())) { + if (!absl::c_linear_search( + users, ItemUse{user_item, static_cast<int>(op_idx)})) { + users.push_back(ItemUse{user_item, static_cast<int>(op_idx)}); + } } } } @@ -516,7 +537,8 @@ class MemoryUsageTracker { // is remat_item. This method should be called after the HLO graph has // been transformed (rematerialization instruction created and connected // to uses). - Status AddRematerializedInstruction(Item* original_item, Item* remat_item); + Status AddRematerializedInstruction(Item* original_item, Item* remat_item, + absl::Span<Item*> bitcasts); // Selects and returns the best candidate instructions for rematerialization. // A sequence of candidate instructions of length between min_block_size and @@ -538,6 +560,9 @@ class MemoryUsageTracker { // Returns whether 'item' has any unplaced users. bool HasUnplacedUsers(Item* item) const; + // Returns the list of uses for a specific 'item'. + const UsesList GetItemUses(Item* item) const; + // Returns whether 'item' is currently in progress. bool IsInProgressItem(Item* item) const { return item == in_progress_item_; } @@ -588,7 +613,7 @@ class MemoryUsageTracker { bool has_indirect_uses; // The instructions which use this buffer. - ItemList users; + UsesList users; // The number of users (HloInstructions) of this buffer which have not yet // been placed in the sequence. 
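// The ItemUse bookkeeping introduced above keys every use of a buffer by
// (user, operand_number), so an instruction that reaches the same buffer
// through two operands (for instance once directly and once through a bitcast
// alias) contributes two ItemUse entries while still counting as a single
// unfinished user. A minimal self-contained sketch of that counting rule
// follows; Item and ItemUse here are simplified stand-ins and the names are
// illustrative only.

#include <vector>

#include "absl/container/flat_hash_set.h"

namespace sketch {

struct Item {};  // Stand-in for the rematerializer's Item.

struct ItemUse {
  Item* user;
  int operand_number;
};

// Mirrors get_num_of_unique_users in the hunk below: ItemUse entries are
// per-operand, but the unfinished-user count is per distinct user.
int CountUniqueUsers(const std::vector<ItemUse>& uses) {
  absl::flat_hash_set<Item*> users;
  for (const ItemUse& use : uses) users.insert(use.user);
  return static_cast<int>(users.size());
}

// Example: a buffer consumed by the same instruction both directly and via a
// bitcast alias.
//   Item concat;
//   std::vector<ItemUse> uses = {{&concat, 0}, {&concat, 1}};
//   uses.size() == 2, but CountUniqueUsers(uses) == 1.

}  // namespace sketch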
@@ -611,7 +636,7 @@ class MemoryUsageTracker { const LogicalBuffer* logical_buffer, const TuplePointsToAnalysis& points_to_analysis, bool live_out) { bool has_indirect_uses = false; - ItemList users = GetUsers(instruction_list_, logical_buffer, + UsesList users = GetUsers(instruction_list_, logical_buffer, points_to_analysis, &has_indirect_uses); return NewBuffer(instruction_list_.GetItem(logical_buffer->instruction()), logical_buffer->shape(), std::move(users), live_out, @@ -621,13 +646,13 @@ class MemoryUsageTracker { // Create a new buffer representing a rematerialization of given buffer for // the given uses. Buffer& RematerializeBuffer(const Buffer& original_buffer, Item* remat_item, - ItemList&& rematerialized_uses) { + UsesList&& rematerialized_uses) { CHECK(original_buffer.defining_instruction->placed) << original_buffer.defining_instruction->instruction->name(); CHECK(!original_buffer.has_indirect_uses) << original_buffer.ToString(); CHECK(!original_buffer.live_out) << original_buffer.ToString(); - for (Item* use : rematerialized_uses) { - CHECK(!use->placed) << use->instruction->name(); + for (ItemUse& use : rematerialized_uses) { + CHECK(!use.user->placed) << use.user->instruction->name(); } return NewBuffer(remat_item, original_buffer.shape, std::move(rematerialized_uses), /*live_out=*/false, @@ -692,11 +717,18 @@ class MemoryUsageTracker { // Create a new buffer, add it to buffers_, and return a reference. Buffer& NewBuffer(Item* defining_instruction, const Shape& shape, - ItemList&& users, bool live_out, bool has_indirect_uses) { + UsesList&& uses, bool live_out, bool has_indirect_uses) { int buffer_id = buffers_.size(); + auto get_num_of_unique_users = [](const UsesList& uses) -> int64 { + absl::flat_hash_set<Item*> users_set; + for (const ItemUse& use : uses) { + users_set.insert(use.user); + } + return users_set.size(); + }; buffers_.push_back(Buffer{ buffer_id, defining_instruction, size_function_(shape), shape, live_out, - has_indirect_uses, users, static_cast<int64>(users.size())}); + has_indirect_uses, uses, get_num_of_unique_users(uses)}); return buffers_.back(); } @@ -771,12 +803,15 @@ MemoryUsageTracker::MemoryUsageTracker( // Add users of while to Buffer users. bool unused; - for (Item* user_item : GetUsers(instruction_list_, logical_buffer, - points_to_analysis, &unused)) { - if (!absl::c_linear_search(buffer->users, user_item)) { - buffer->users.push_back(user_item); + for (ItemUse& user_item : GetUsers(instruction_list_, logical_buffer, + points_to_analysis, &unused)) { + auto existing_user_it = absl::c_find_if( + buffer->users, + [&](const ItemUse& use) { return user_item.user == use.user; }); + if (existing_user_it == buffer->users.end()) { buffer->unfinished_user_count++; - user_item->buffers_used.push_back(buffer->id); + user_item.user->buffers_used.push_back(buffer->id); + buffer->users.push_back(user_item); } } } else { @@ -784,8 +819,10 @@ MemoryUsageTracker::MemoryUsageTracker( logical_buffer, points_to_analysis, ContainsKey(live_out_set, logical_buffer)); item->buffers_defined.push_back(buffer->id); - for (Item* user : buffer->users) { - user->buffers_used.push_back(buffer->id); + for (ItemUse& user : buffer->users) { + if (!absl::c_linear_search(user.user->buffers_used, buffer->id)) { + user.user->buffers_used.push_back(buffer->id); + } } } @@ -1003,14 +1040,14 @@ Status MemoryUsageTracker::AddCompressInstructions(Item* original_item, // Compressed buffer is now alive. 
memory_usage_ += size_function_(compressed_item->instruction->shape()); - ItemList placed_users; - ItemList unplaced_users; + UsesList placed_users; + UsesList unplaced_users; CHECK_EQ(original_item->buffers_output.size(), 1); BufferId original_buffer_id = original_item->buffers_output[0]; Buffer& original_buffer = buffers_.at(original_buffer_id); - for (Item* user : original_buffer.users) { - if (user->placed) { - CHECK(IsFinished(user)) << user->instruction->name(); + for (ItemUse& user : original_buffer.users) { + if (user.user->placed) { + CHECK(IsFinished(user.user)) << user.user->instruction->name(); placed_users.push_back(user); } else { unplaced_users.push_back(user); @@ -1018,10 +1055,10 @@ Status MemoryUsageTracker::AddCompressInstructions(Item* original_item, } original_buffer.users = std::move(placed_users); original_buffer.unfinished_user_count = 0; - original_buffer.users.push_back(compressed_item); + original_buffer.users.push_back(ItemUse{compressed_item, 0}); Buffer& compressed_buffer = NewBuffer(compressed_item, compressed_item->instruction->shape(), - {uncompressed_item}, /*live_out=*/false, + {ItemUse{uncompressed_item, 0}}, /*live_out=*/false, /*has_indirect_uses=*/false); compressed_item->buffers_used = original_item->buffers_output; compressed_item->buffers_output = {compressed_buffer.id}; @@ -1036,8 +1073,8 @@ Status MemoryUsageTracker::AddCompressInstructions(Item* original_item, uncompressed_item->buffers_output = {uncompressed_buffer.id}; uncompressed_item->buffers_defined = {uncompressed_buffer.id}; - for (Item* user : uncompressed_buffer.users) { - BufferIdList& buffers_used = user->buffers_used; + for (ItemUse& user : uncompressed_buffer.users) { + BufferIdList& buffers_used = user.user->buffers_used; std::replace(buffers_used.begin(), buffers_used.end(), original_buffer_id, uncompressed_buffer.id); } @@ -1045,8 +1082,8 @@ Status MemoryUsageTracker::AddCompressInstructions(Item* original_item, return Status::OK(); } -Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item, - Item* remat_item) { +Status MemoryUsageTracker::AddRematerializedInstruction( + Item* original_item, Item* remat_item, absl::Span<Item*> bitcasts) { VLOG(3) << "AddRematerializedInstruction: original_instruction = " << original_item->instruction->name() << ", remat_instruction = " << remat_item->instruction->name(); @@ -1067,9 +1104,23 @@ Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item, // Buffer used by this instruction was dead, now is alive. 
memory_usage_ += AllocatedSize(buffer.id); } - buffer.unfinished_user_count++; - buffer.users.push_back(remat_item); + absl::InlinedVector<ItemUse, 2> filtered_users; + std::copy_if(buffer.users.begin(), buffer.users.end(), + std::back_inserter(filtered_users), + [&](const ItemUse& iu) { return iu.user == original_item; }); + for (ItemUse& u : filtered_users) { + buffer.users.push_back(ItemUse{remat_item, u.operand_number}); + } + } + + for (Item* bitcast : bitcasts) { + CHECK_EQ(bitcast->instruction->opcode(), HloOpcode::kBitcast); + for (BufferId buffer_id : bitcast->buffers_used) { + Buffer& buffer = buffers_.at(buffer_id); + buffer.unfinished_user_count++; + buffer.users.push_back(ItemUse{bitcast, 0}); + } } // Create a new set of Buffers defined by the new rematerialization @@ -1078,10 +1129,10 @@ Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item, for (BufferId old_buffer_id : original_item->buffers_defined) { Buffer& old_buffer = buffers_.at(old_buffer_id); - ItemList placed_users; - ItemList unplaced_users; - for (Item* user : old_buffer.users) { - if (user->placed) { + UsesList placed_users; + UsesList unplaced_users; + for (ItemUse& user : old_buffer.users) { + if (user.user->placed) { placed_users.push_back(user); } else { unplaced_users.push_back(user); @@ -1097,8 +1148,8 @@ Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item, RematerializeBuffer(old_buffer, remat_item, std::move(unplaced_users)); remat_item->buffers_defined.push_back(new_buffer.id); - for (Item* user : new_buffer.users) { - BufferIdList& buffers_used = user->buffers_used; + for (ItemUse& user : new_buffer.users) { + BufferIdList& buffers_used = user.user->buffers_used; std::replace(buffers_used.begin(), buffers_used.end(), old_buffer_id, new_buffer.id); } @@ -1131,6 +1182,10 @@ string MemoryUsageTracker::ToString() const { absl::StrAppend(&output, " ", buffer.ToString(), live, ", ", buffer.unfinished_user_count, " unfinished uses\n"); } + absl::StrAppend(&output, " Outputs:\n"); + for (BufferId buffer_id : item->buffers_output) { + absl::StrAppend(&output, " ", buffers_[buffer_id].ToString(), "\n"); + } absl::StrAppend(&output, " Uses:\n"); for (BufferId buffer_id : item->buffers_used) { absl::StrAppend(&output, " ", buffers_[buffer_id].ToString(), "\n"); @@ -1190,12 +1245,14 @@ bool MemoryUsageTracker::Check() const { } for (const Buffer& buffer : buffers_) { int64 unfinished_uses = 0; - for (Item* user : buffer.users) { - const BufferIdList& used_buffers = user->buffers_used; + absl::flat_hash_set<Item*> already_counted_user; + for (const ItemUse& user : buffer.users) { + const BufferIdList& used_buffers = user.user->buffers_used; CHECK(absl::c_linear_search(used_buffers, buffer.id)) - << "Instruction " << user->instruction->name() + << "Instruction " << user.user->instruction->name() << " used buffers is missing " << buffer.ToString(); - if (!IsFinished(user)) { + if (!IsFinished(user.user) && + already_counted_user.insert(user.user).second) { unfinished_uses++; } } @@ -1397,8 +1454,8 @@ MemoryUsageTracker::PickRematerializationCandidates( bool MemoryUsageTracker::HasUnplacedUsers(Item* item) const { for (BufferId buffer_id : item->buffers_defined) { const Buffer& buffer = buffers_.at(buffer_id); - for (Item* user : buffer.users) { - if (!user->placed) { + for (const ItemUse& user : buffer.users) { + if (!user.user->placed) { return true; } } @@ -1406,6 +1463,17 @@ bool MemoryUsageTracker::HasUnplacedUsers(Item* item) const { return false; } +const UsesList 
MemoryUsageTracker::GetItemUses(Item* item) const { + UsesList combined_users; + for (BufferId buffer_id : item->buffers_defined) { + const Buffer& buffer = buffers_.at(buffer_id); + for (const ItemUse& user : buffer.users) { + combined_users.push_back(user); + } + } + return combined_users; +} + StatusOr<int64> RematerializeInstructions( MemoryUsageTracker* memory_tracker, std::vector<Item*>* best_items, absl::flat_hash_set<const HloInstruction*>* remat_move_instructions, @@ -1443,18 +1511,30 @@ StatusOr<int64> RematerializeInstructions( Item* remat_item = instruction_list->CreateItem(remat); // Replace each remaining use of 'best' with the rematerialization. - std::vector<HloInstruction*> best_users_copy = best->users(); - for (HloInstruction* user : best_users_copy) { - if (!memory_tracker->IsPlaced(user)) { + absl::InlinedVector<Item*, 4> bitcasts; + for (auto& user : memory_tracker->GetItemUses(best_item)) { + if (!memory_tracker->IsPlaced(user.user->instruction)) { VLOG(2) << " Replacing use of " << best->name() << " in " - << user->name() << " with " << remat->name(); - TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, remat)); + << user.user->instruction->name() << " with " << remat->name(); + const int op_idx = user.operand_number; + auto* remat_use = remat; + if (user.user->instruction->operand(op_idx)->shape() != + remat->shape()) { + remat_use = computation->AddInstruction(HloInstruction::CreateUnary( + user.user->instruction->operand(op_idx)->shape(), + HloOpcode::kBitcast, remat)); + bitcasts.push_back(instruction_list->CreateItem(remat_use)); + bitcasts.back()->buffers_output = remat_item->buffers_defined; + bitcasts.back()->buffers_used = remat_item->buffers_defined; + } + TF_RETURN_IF_ERROR( + user.user->instruction->ReplaceOperandWith(op_idx, remat_use)); } } // Account for the rematerialization in the memory tracker. - TF_RETURN_IF_ERROR( - memory_tracker->AddRematerializedInstruction(best_item, remat_item)); + TF_RETURN_IF_ERROR(memory_tracker->AddRematerializedInstruction( + best_item, remat_item, absl::MakeSpan(bitcasts))); // Insert rematerialized instruction right before the earliest unplaced // use of the instruction *and* the earliest unplaced last use of any @@ -1463,7 +1543,14 @@ StatusOr<int64> RematerializeInstructions( // this could increase memory usage. ItemList place_before; for (auto user : remat->users()) { - place_before.push_back(instruction_list->GetItem(user)); + if (!absl::c_linear_search(bitcasts, instruction_list->GetItem(user))) { + place_before.push_back(instruction_list->GetItem(user)); + } + } + for (auto* bitcast : bitcasts) { + for (auto user : bitcast->instruction->users()) { + place_before.push_back(instruction_list->GetItem(user)); + } } for (auto* operand : remat->operands()) { for (auto* operand_user : operand->users()) { @@ -1486,6 +1573,9 @@ StatusOr<int64> RematerializeInstructions( } instruction_list->InsertBeforeInstructions(remat_item, place_before); + for (auto* bitcast : bitcasts) { + instruction_list->InsertBeforeInstructions(bitcast, place_before); + } // If the rematerialized instruction is dead then rematerialization is // essentially a move. 
Don't delete the instruction now because we don't // want duplicate HloInstruction* values during the course of the @@ -1501,8 +1591,12 @@ StatusOr<int64> RematerializeInstructions( instruction_list->Denylist(remat); } remat_move_instructions->insert(remat); + net_instructions_added += bitcasts.size(); } else { - net_instructions_added++; + net_instructions_added += bitcasts.size() + 1; + } + for (auto* bitcast : bitcasts) { + instruction_list->Denylist(bitcast->instruction); } } VLOG(1) << "Rematerializing instructions [" diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index 35f39e9a342..5e747f9076a 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -748,6 +748,73 @@ ENTRY %entry { op::Reduce(op::Copy(op::Copy(broadcast)), op::Constant())); } +// Test rematerialization of values through bitcasts +// Its expected that the broadcast gets rematerialized +TEST_F(HloRematerializationTest, ThroughBitcastRemat) { + const string& hlo_string = R"( +HloModule fusion, is_scheduled=true + +ENTRY %mycomp (param: f32[1]) -> f32[1] { + %param = f32[1]{0} parameter(0) + %reshape = f32[] reshape(f32[1]{0} %param) + %broadcast = f32[1024,1]{1,0} broadcast(f32[] %reshape), dimensions={} + %bitcast = f32[1024]{0} bitcast(f32[1024,1]{1,0} %broadcast) + %negate = f32[1024,1]{1,0} negate(f32[1024,1]{1,0} %broadcast) + %concatenate = f32[2048,1]{1,0} concatenate(f32[1024,1]{1,0} %negate, f32[1024,1]{1,0} %negate), dimensions={0} + %slice = f32[1,1]{1,0} slice(f32[2048,1]{1,0} %concatenate), slice={[0:1], [0:1]} + %bitcast.1 = f32[1]{0} bitcast(f32[1,1]{1,0} %slice) + %concatenate.1 = f32[1025]{0} concatenate(f32[1024]{0} %bitcast, f32[1]{0} %bitcast.1), dimensions={0} + ROOT %slice.1 = f32[1]{0} slice(f32[1025]{0} %concatenate.1), slice={[0:1]} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + auto* computation = module->entry_computation(); + // Find and save the original broadcast instruction which should be + // rematerialized. + const HloInstruction* slice = computation->root_instruction(); + ASSERT_THAT(slice, + op::Slice(op::Concatenate(op::Bitcast(op::Broadcast(_)), _))); + const HloInstruction* concat = slice->operand(0); + const HloInstruction* bcast = concat->operand(0)->operand(0); + + LOG(INFO) << module->ToString(); + // Computation requires 16KB without rematerialization, but uses only 12KB + // with rematerialization so pick a memory limit between these values (14KB). + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/14 * 1024, module.get())); + LOG(INFO) << module->ToString(); + EXPECT_TRUE(changed); + + // Root should not have changed. + EXPECT_EQ(computation->root_instruction(), slice); + + // The bitcast for the rematerialized broadcast + const HloInstruction* remat_bitcast = concat->operand(0); + // The broadcast should have been rematerialized. + const HloInstruction* remat_broadcast = remat_bitcast->operand(0); + + EXPECT_THAT(remat_broadcast, op::Broadcast(::testing::Ne(bcast))); + + // The rematerialized broadcast should be immediately before its bitcast + // and the bitcast before the concatenate in the sequence. 
+ EXPECT_EQ(module->schedule() + .sequence(computation) + .instructions()[computation->instruction_count() - 2], + concat); + EXPECT_EQ(module->schedule() + .sequence(computation) + .instructions()[computation->instruction_count() - 3], + remat_bitcast); + EXPECT_EQ(module->schedule() + .sequence(computation) + .instructions()[computation->instruction_count() - 4], + remat_broadcast); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 13c754438eb..84e4fe6e3fd 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -19,6 +19,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/comparison_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" @@ -136,13 +137,12 @@ Status ShapeVerifier::HandleCopy(HloInstruction* copy) { } Status ShapeVerifier::HandleDot(HloInstruction* dot) { - TF_ASSIGN_OR_RETURN(Shape expected, - ShapeInference::InferDotOpShape( - dot->operand(0)->shape(), dot->operand(1)->shape(), - dot->dot_dimension_numbers())); - if (ShapeUtil::CanUpcastIntegral(expected, dot->shape())) { - expected.set_element_type(dot->shape().element_type()); - } + TF_ASSIGN_OR_RETURN( + const Shape expected, + ShapeInference::InferDotOpShape( + dot->operand(0)->shape(), dot->operand(1)->shape(), + dot->dot_dimension_numbers(), + /*preferred_element_type=*/dot->shape().element_type())); return CheckShape(dot, expected); } @@ -152,10 +152,8 @@ Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) { ShapeInference::InferConvolveShape( convolution->operand(0)->shape(), convolution->operand(1)->shape(), convolution->feature_group_count(), convolution->batch_group_count(), - convolution->window(), convolution->convolution_dimension_numbers())); - if (ShapeUtil::CanUpcastIntegral(expected, convolution->shape())) { - expected.set_element_type(convolution->shape().element_type()); - } + convolution->window(), convolution->convolution_dimension_numbers(), + /*preferred_element_type=*/convolution->shape().element_type())); return CheckShape(convolution, expected); } @@ -1749,6 +1747,31 @@ Status CheckElementwiseInstruction(HloInstruction* instruction) { ShapeUtil::HumanString(operand_shape)); } } + if (auto* comparison = DynCast<HloCompareInstruction>(instruction)) { + const Shape& operand_shape = comparison->operand(1)->shape(); + PrimitiveType operand_element_type = operand_shape.element_type(); + Comparison::Type default_comparison_type = + Comparison::DefaultComparisonType(operand_element_type); + if (primitive_util::IsFloatingPointType(operand_element_type)) { + if (comparison->type() != Comparison::Type::kFloat && + comparison->type() != Comparison::Type::kFloatTotalOrder) { + return FailedPrecondition( + "Expected comparison type %s or %s.\n" + "actual: %s\noperand: %s\n", + ComparisonTypeToString(Comparison::Type::kFloat), + ComparisonTypeToString(Comparison::Type::kFloatTotalOrder), + ComparisonTypeToString(comparison->type()), + ShapeUtil::HumanString(operand_shape)); + } + } else if (comparison->type() != default_comparison_type) { + return FailedPrecondition( + "Expected comparison type %s.\n" + "actual: %s\noperand: %s\n", + ComparisonTypeToString(default_comparison_type), + 
ComparisonTypeToString(comparison->type()), + ShapeUtil::HumanString(operand_shape)); + } + } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index 0df30166a1c..c6c09e3dee1 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -1220,5 +1220,77 @@ TEST_F(HloVerifierTest, CollectivePermuteDoneNoCollectivePermuteStart) { "needs to be collective-permute-start, found tuple")); } +TEST_F(HloVerifierTest, ComparisonTypeFloat) { + const char* const hlo_string = R"( + HloModule Module + + ENTRY RngOperandElementTypesNotMatch { + p0 = f32[] parameter(0) + ROOT cmp = pred[] compare(f32[] p0, f32[] p0), direction=LT, type=UNSIGNED + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Expected comparison type FLOAT or TOTALORDER")); +} + +TEST_F(HloVerifierTest, ComparisonTypeSigned) { + const char* const hlo_string = R"( + HloModule Module + + ENTRY RngOperandElementTypesNotMatch { + p0 = s32[] parameter(0) + ROOT cmp = pred[] compare(s32[] p0, s32[] p0), direction=LT, type=UNSIGNED + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Expected comparison type SIGNED")); +} + +TEST_F(HloVerifierTest, ComparisonTypeUnsigned) { + const char* const hlo_string = R"( + HloModule Module + + ENTRY RngOperandElementTypesNotMatch { + p0 = u32[] parameter(0) + ROOT cmp = pred[] compare(u32[] p0, u32[] p0), direction=LT, type=SIGNED + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Expected comparison type UNSIGNED")); +} + +TEST_F(HloVerifierTest, ComparisonTypePred) { + const char* const hlo_string = R"( + HloModule Module + + ENTRY RngOperandElementTypesNotMatch { + p0 = pred[] parameter(0) + ROOT cmp = pred[] compare(pred[] p0, pred[] p0), direction=LT, type=SIGNED + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Expected comparison type UNSIGNED")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/integral_upcaster.cc b/tensorflow/compiler/xla/service/integral_upcaster.cc index d8383b25c84..9bb8e468ad4 100644 --- a/tensorflow/compiler/xla/service/integral_upcaster.cc +++ b/tensorflow/compiler/xla/service/integral_upcaster.cc @@ -26,12 +26,14 @@ StatusOr<absl::optional<Shape>> MaybeInferShape( case HloOpcode::kDot: return ShapeInference::InferDotOpShape( instruction->operand(0)->shape(), instruction->operand(1)->shape(), - instruction->dot_dimension_numbers()); + instruction->dot_dimension_numbers(), + /*preferred_element_type=*/absl::nullopt); case HloOpcode::kConvolution: return ShapeInference::InferConvolveShape( instruction->operand(0)->shape(), instruction->operand(1)->shape(), instruction->feature_group_count(), instruction->batch_group_count(), - instruction->window(), 
instruction->convolution_dimension_numbers()); + instruction->window(), instruction->convolution_dimension_numbers(), + /*preferred_element_type=*/absl::nullopt); default: return absl::make_optional<Shape>(); } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc index 1457fa5df1d..e92cd54e8fd 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc @@ -345,7 +345,9 @@ Status LhloDialectEmitter::HandleReduce(HloInstruction* instr) { CreateDenseIntElementsAttrFromVector(instr->dimensions(), builder_); auto reduce_op = builder.create<lhlo::ReduceOp>(loc, inputs, init_values, results, dimensions_attr); - reduce_op.ensureTerminator(reduce_op.body(), builder, getLocation(instr)); + builder.createBlock(&reduce_op.body()); + OpBuilder::atBlockEnd(&reduce_op.body().front()) + .create<lhlo::TerminatorOp>(getLocation(instr)); return SpliceHloComputation(OpBuilder{&reduce_op.body()}, loc, *instr->to_apply(), emission_context_); } diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 02895d9627b..badfe81625e 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -214,6 +214,37 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape, output_is_dynamic); } +StatusOr<PrimitiveType> MaybeUpcast( + PrimitiveType from_type, + absl::optional<PrimitiveType> preferred_element_type) { + if (!preferred_element_type.has_value() || + *preferred_element_type == from_type) { + return from_type; + } + if (primitive_util::IsIntegralType(from_type) != + primitive_util::IsIntegralType(*preferred_element_type)) { + return InvalidArgument( + "`preferred_element_type` and the original type must both be integral " + "or both be floating point."); + } + if (!primitive_util::IsSignedIntegralType(from_type) != + !primitive_util::IsSignedIntegralType(*preferred_element_type)) { + return InvalidArgument( + "`preferred_element_type` must have the same signedness as the " + "original type."); + } + if (primitive_util::BitWidth(*preferred_element_type) < + primitive_util::BitWidth(from_type)) { + if (primitive_util::IsFloatingPointType(from_type)) { + return from_type; + } + return InvalidArgument( + "`preferred_element_type` must not be narrower than the original " + "type."); + } + return *preferred_element_type; +} + } // namespace /* static */ StatusOr<Shape> ShapeInference::InferUnaryOpShape( @@ -622,7 +653,8 @@ Status ValidateDotDimensionNumbers( /* static */ StatusOr<Shape> ShapeInference::InferDotOpShape( const Shape& lhs, const Shape& rhs, - const DotDimensionNumbers& dimension_numbers) { + const DotDimensionNumbers& dimension_numbers, + absl::optional<PrimitiveType> preferred_element_type) { TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of dot")); TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of dot")); @@ -700,8 +732,11 @@ Status ValidateDotDimensionNumbers( is_dynamic.push_back(rhs.is_dynamic_dimension(i)); } } - Shape result = ShapeUtil::MakeShape( - ShapeUtil::HigherPrecisionElementType(lhs, rhs), dimensions, is_dynamic); + TF_ASSIGN_OR_RETURN( + PrimitiveType type, + MaybeUpcast(ShapeUtil::HigherPrecisionElementType(lhs, rhs), + preferred_element_type)); + Shape result = ShapeUtil::MakeShape(type, dimensions, is_dynamic); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(result)); VLOG(2) << 
"inferred dot shape: " << ShapeUtil::HumanString(result); @@ -1586,7 +1621,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr<Shape> ShapeInference::InferConvolveShape( const Shape& lhs, const Shape& rhs, int64 feature_group_count, int64 batch_group_count, const Window& window, - const ConvolutionDimensionNumbers& dnums) { + const ConvolutionDimensionNumbers& dnums, + absl::optional<PrimitiveType> preferred_element_type) { TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution")); TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution")); @@ -1833,8 +1869,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } } } - return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs), - dimensions, is_dynamic); + TF_ASSIGN_OR_RETURN( + PrimitiveType type, + MaybeUpcast(ShapeUtil::HigherPrecisionElementType(lhs, rhs), + preferred_element_type)); + return ShapeUtil::MakeShape(type, dimensions, is_dynamic); } /* static */ StatusOr<Shape> ShapeInference::InferFftShape( diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 8ef72f7ac3f..f6d55c45334 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -105,12 +105,14 @@ class ShapeInference { const Shape& output_grad_shape, int64 feature_index); - // Infers the shape produced by applying the given convolutional - // filter (rhs) to lhs in the way specified by the fields on window. + // Infers the shape produced by applying the given convolutional filter (rhs) + // to lhs in the way specified by the fields on window. An optional + // preferred_element_type can be specified to upcast the element type. static StatusOr<Shape> InferConvolveShape( const Shape& lhs, const Shape& rhs, int64 feature_group_count, int64 batch_group_count, const Window& window, - const ConvolutionDimensionNumbers& dimension_numbers); + const ConvolutionDimensionNumbers& dimension_numbers, + absl::optional<PrimitiveType> preferred_element_type); // Infers the shape produced by the given FFT type on the given operand. static StatusOr<Shape> InferFftShape(const Shape& in, FftType fft_type, @@ -298,10 +300,12 @@ class ShapeInference { absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply); // Helper that infers the shape produced by performing a dot operation with - // the given LHS and RHS shapes. + // the given LHS and RHS shapes. An optional preferred_element_type can be + // specified to upcast the element type. 
static StatusOr<Shape> InferDotOpShape( const Shape& lhs, const Shape& rhs, - const DotDimensionNumbers& dimension_numbers); + const DotDimensionNumbers& dimension_numbers, + absl::optional<PrimitiveType> preferred_element_type); // Helper that infers the shape of the tensor produced by a gather operation // with the given input shape, gather indices shape and gather dimension diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 73bbe5cb3bd..77f84a69205 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -437,7 +437,7 @@ TEST_F(ShapeInferenceTest, Convolve) { dim1->set_base_dilation(1); auto inferred_status = ShapeInference::InferConvolveShape( lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, - window, dnums); + window, dnums, /*preferred_element_type=*/absl::nullopt); ASSERT_IS_OK(inferred_status.status()); Shape inferred_shape = inferred_status.ValueOrDie(); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}), @@ -483,7 +483,7 @@ TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) { dim1->set_base_dilation(1); auto inferred_status = ShapeInference::InferConvolveShape( lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, - window, dnums); + window, dnums, /*preferred_element_type=*/absl::nullopt); ASSERT_IS_OK(inferred_status.status()); Shape inferred_shape = inferred_status.ValueOrDie(); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 31, 5}), @@ -529,7 +529,7 @@ TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) { dim1->set_base_dilation(2); auto inferred_status = ShapeInference::InferConvolveShape( lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, - window, dnums); + window, dnums, /*preferred_element_type=*/absl::nullopt); ASSERT_IS_OK(inferred_status.status()); Shape inferred_shape = inferred_status.ValueOrDie(); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 4, 9}), @@ -568,7 +568,7 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) { dim1->set_padding_high(1); auto inferred_status = ShapeInference::InferConvolveShape( lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, - window, dnums); + window, dnums, /*preferred_element_type=*/absl::nullopt); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), HasSubstr("each dimension exactly once")); @@ -605,12 +605,150 @@ TEST_F(ShapeInferenceTest, ConvolveBatchGroupCountUnequalOutputFeature) { dim1->set_window_dilation(2); auto inferred_status = ShapeInference::InferConvolveShape( lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/6, - window, dnums); + window, dnums, /*preferred_element_type=*/absl::nullopt); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), HasSubstr("to be a multiple of batch group count")); } +struct ConvolveArgs { + Shape lhs_shape; + Shape rhs_shape; + ConvolutionDimensionNumbers dnums; + Window window; +}; + +ConvolveArgs MakeConvolveArgs(PrimitiveType lhs_type, PrimitiveType rhs_type) { + ConvolveArgs args; + ConvolutionDimensionNumbers& dnums = args.dnums; + + // Dimension order: batch, feature, x0, x1 + args.lhs_shape = ShapeUtil::MakeShape(lhs_type, {10, 11, 3, 4}); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.set_input_feature_dimension(1); + 
dnums.set_output_feature_dimension(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(3); + dnums.add_output_spatial_dimensions(3); + + // Dimension order: x1, batch, feature, x0 + args.rhs_shape = ShapeUtil::MakeShape(rhs_type, {2, 12, 11, 3}); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(1); + dnums.add_kernel_spatial_dimensions(3); + dnums.add_kernel_spatial_dimensions(0); + + auto dim0 = args.window.add_dimensions(); + auto dim1 = args.window.add_dimensions(); + dim0->set_size(3); + dim0->set_stride(2); + dim0->set_padding_low(1); + dim0->set_padding_high(1); + dim0->set_window_dilation(1); + dim0->set_base_dilation(1); + dim1->set_size(2); + dim1->set_stride(1); + dim1->set_padding_low(0); + dim1->set_padding_high(0); + dim1->set_window_dilation(1); + dim1->set_base_dilation(1); + return args; +} + +TEST_F(ShapeInferenceTest, ConvolveWithPreferredElementType) { + ConvolveArgs args = MakeConvolveArgs(S8, S16); + TF_ASSERT_OK_AND_ASSIGN( + Shape inferred_shape, + ShapeInference::InferConvolveShape( + args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1, + /*batch_group_count=*/1, args.window, args.dnums, + /*preferred_element_type=*/S16)) + ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S16, {10, 12, 2, 3}), + inferred_shape)); +} + +TEST_F(ShapeInferenceTest, ConvolveWithPreferredElementTypeSameAsInferredType) { + ConvolveArgs args = MakeConvolveArgs(S8, S16); + TF_ASSERT_OK_AND_ASSIGN( + Shape inferred_shape, + ShapeInference::InferConvolveShape( + args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1, + /*batch_group_count=*/1, args.window, args.dnums, + /*preferred_element_type=*/S32)) + ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {10, 12, 2, 3}), + inferred_shape)); +} + +TEST_F(ShapeInferenceTest, + FloatingPointConvolveWithNarrowerPreferredElementType) { + ConvolveArgs args = MakeConvolveArgs(F32, F32); + TF_ASSERT_OK_AND_ASSIGN( + Shape inferred_shape, + ShapeInference::InferConvolveShape( + args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1, + /*batch_group_count=*/1, args.window, args.dnums, + /*preferred_element_type=*/BF16)) + ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}), + inferred_shape)); +} + +TEST_F(ShapeInferenceTest, + FloatingPointConvolveWithInvalidPreferredElementType) { + ConvolveArgs args = MakeConvolveArgs(BF16, BF16); + auto inferred_status = + ShapeInference::InferConvolveShape( + args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1, + /*batch_group_count=*/1, args.window, args.dnums, + /*preferred_element_type=*/S32) + .status(); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.error_message(), + HasSubstr("must both be integral or both be floating point")); +} + +TEST_F(ShapeInferenceTest, + IntegralConvolveWithFloatingPointPreferredElementType) { + ConvolveArgs args = MakeConvolveArgs(S8, S16); + auto inferred_status = + ShapeInference::InferConvolveShape( + args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1, + /*batch_group_count=*/1, args.window, args.dnums, + /*preferred_element_type=*/F32) + .status(); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.error_message(), + HasSubstr("must both be integral or both be floating point")); +} + +TEST_F(ShapeInferenceTest, + ConvolveWithPreferredElementTypeWithDifferentSignedness) { + ConvolveArgs args = MakeConvolveArgs(S8, S16); + auto inferred_status = + ShapeInference::InferConvolveShape( 
+ args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1, + /*batch_group_count=*/1, args.window, args.dnums, + /*preferred_element_type=*/U32) + .status(); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.error_message(), + HasSubstr("must have the same signedness as the original type")); +} + +TEST_F(ShapeInferenceTest, ConvolveWithNarrowerPreferredElementType) { + ConvolveArgs args = MakeConvolveArgs(S8, S16); + auto inferred_status = + ShapeInference::InferConvolveShape( + args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1, + /*batch_group_count=*/1, args.window, args.dnums, + /*preferred_element_type=*/S8) + .status(); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.error_message(), + HasSubstr("must not be narrower than the original type")); +} + namespace fft { static const char* unsupported_rank = "only supports ranks 1-3"; @@ -1282,8 +1420,8 @@ TEST_F(ShapeInferenceTest, BroadcastScalar) { // scalar <dot> vector: ok TEST_F(ShapeInferenceTest, ScalarDotVector) { DotDimensionNumbers dot_dnums; - auto inferred_status = - ShapeInference::InferDotOpShape(f32_, vector_32_, dot_dnums); + auto inferred_status = ShapeInference::InferDotOpShape( + f32_, vector_32_, dot_dnums, /*preferred_element_type=*/absl::nullopt); EXPECT_TRUE(inferred_status.ok()); EXPECT_EQ(inferred_status.ValueOrDie(), vector_32_); } @@ -1294,7 +1432,8 @@ TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) { dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); auto inferred_status = ShapeInference::InferDotOpShape( - ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, dot_dnums); + ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, dot_dnums, + /*preferred_element_type=*/absl::nullopt); EXPECT_TRUE(inferred_status.ok()); EXPECT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), ShapeUtil::MakeShape(F32, {32, 32, 64}))); @@ -1306,11 +1445,13 @@ TEST_F(ShapeInferenceTest, VectorDotVector) { dot_dnums.add_lhs_contracting_dimensions(0); dot_dnums.add_rhs_contracting_dimensions(0); auto inferred_status = - ShapeInference::InferDotOpShape(vector_64_, vector_64_, dot_dnums); + ShapeInference::InferDotOpShape(vector_64_, vector_64_, dot_dnums, + /*preferred_element_type=*/absl::nullopt); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie())); auto inferred_status_mismatch = - ShapeInference::InferDotOpShape(vector_64_, vector_32_, dot_dnums); + ShapeInference::InferDotOpShape(vector_64_, vector_32_, dot_dnums, + /*preferred_element_type=*/absl::nullopt); ASSERT_FALSE(inferred_status_mismatch.ok()); } @@ -1320,11 +1461,13 @@ TEST_F(ShapeInferenceTest, MatrixDotVector) { dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); auto inferred_status = - ShapeInference::InferDotOpShape(matrix_32_64_, vector_64_, dot_dnums); + ShapeInference::InferDotOpShape(matrix_32_64_, vector_64_, dot_dnums, + /*preferred_element_type=*/absl::nullopt); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_32_)); auto inferred_status_mismatch = - ShapeInference::InferDotOpShape(matrix_32_64_, vector_32_, dot_dnums); + ShapeInference::InferDotOpShape(matrix_32_64_, vector_32_, dot_dnums, + /*preferred_element_type=*/absl::nullopt); ASSERT_FALSE(inferred_status_mismatch.ok()); } @@ -1334,11 +1477,13 @@ TEST_F(ShapeInferenceTest, VectorDotMatrix) { dot_dnums.add_lhs_contracting_dimensions(0); 
dot_dnums.add_rhs_contracting_dimensions(0); auto inferred_status = - ShapeInference::InferDotOpShape(vector_32_, matrix_32_64_, dot_dnums); + ShapeInference::InferDotOpShape(vector_32_, matrix_32_64_, dot_dnums, + /*preferred_element_type=*/absl::nullopt); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_64_)); auto inferred_status_mismatch = - ShapeInference::InferDotOpShape(vector_64_, matrix_32_64_, dot_dnums); + ShapeInference::InferDotOpShape(vector_64_, matrix_32_64_, dot_dnums, + /*preferred_element_type=*/absl::nullopt); ASSERT_FALSE(inferred_status_mismatch.ok()); } @@ -1348,7 +1493,8 @@ TEST_F(ShapeInferenceTest, MatrixDotMatrix) { dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); auto inferred_status_match = - ShapeInference::InferDotOpShape(matrix_32_64_, matrix_64_48_, dot_dnums); + ShapeInference::InferDotOpShape(matrix_32_64_, matrix_64_48_, dot_dnums, + /*preferred_element_type=*/absl::nullopt); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE( ShapeUtil::Equal(inferred_status_match.ValueOrDie(), matrix_32_48_)) @@ -1356,7 +1502,8 @@ TEST_F(ShapeInferenceTest, MatrixDotMatrix) { << ShapeUtil::HumanString(inferred_status_match.ValueOrDie()) << " expected: " << ShapeUtil::HumanString(matrix_64_48_); auto inferred_status_mismatch = - ShapeInference::InferDotOpShape(matrix_32_64_, matrix_32_64_, dot_dnums); + ShapeInference::InferDotOpShape(matrix_32_64_, matrix_32_64_, dot_dnums, + /*preferred_element_type=*/absl::nullopt); ASSERT_FALSE(inferred_status_mismatch.ok()); } @@ -1376,7 +1523,8 @@ TEST_F(ShapeInferenceTest, DotGeneral) { dot_dnums.add_rhs_batch_dimensions(1); auto inferred_status_match = - ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums, + /*preferred_element_type=*/absl::nullopt); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE( ShapeUtil::Equal(inferred_status_match.ValueOrDie(), output_shape)) @@ -1399,7 +1547,8 @@ TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsFails) { dot_dnums.add_rhs_batch_dimensions(0); auto inferred_status = - ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums, + /*preferred_element_type=*/absl::nullopt); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), HasSubstr("Must specify the same number of contracting " @@ -1421,7 +1570,8 @@ TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsPasses) { dot_dnums.add_rhs_batch_dimensions(0); auto inferred_status = - ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums, + /*preferred_element_type=*/absl::nullopt); EXPECT_TRUE(inferred_status.ok()); EXPECT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), output_shape)); } @@ -1461,7 +1611,8 @@ TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimSizesFails) { dot_dnums.add_rhs_batch_dimensions(0); auto inferred_status = - ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums, + /*preferred_element_type=*/absl::nullopt); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), HasSubstr("Batch dimension sizes must match")); @@ -1480,7 +1631,8 @@ TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimNumbersPasses) { 
dot_dnums.add_rhs_batch_dimensions(1); auto inferred_status = - ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums, + /*preferred_element_type=*/absl::nullopt); ASSERT_TRUE(inferred_status.ok()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), ShapeUtil::MakeShape(F32, {2, 11, 14}))); @@ -1499,7 +1651,8 @@ TEST_F(ShapeInferenceTest, DotWithContractingDimNumberOutOfRange) { dot_dnums.add_rhs_batch_dimensions(1); auto inferred_status = - ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums, + /*preferred_element_type=*/absl::nullopt); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), HasSubstr("A dimension number is out of range")); @@ -1518,12 +1671,108 @@ TEST_F(ShapeInferenceTest, DotWithContractingNonUniqueDimNumber) { dot_dnums.add_rhs_batch_dimensions(1); auto inferred_status = - ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums, + /*preferred_element_type=*/absl::nullopt); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), HasSubstr("A dimension number is not unique")); } +TEST_F(ShapeInferenceTest, DotWithIntegralPreferredElementType) { + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + TF_ASSERT_OK_AND_ASSIGN(Shape inferred_shape, + ShapeInference::InferDotOpShape( + ShapeUtil::MakeShape(S8, {32, 32}), + ShapeUtil::MakeShape(S16, {32, 32}), dot_dnums, + /*preferred_element_type=*/S32)); + EXPECT_TRUE( + ShapeUtil::Equal(inferred_shape, ShapeUtil::MakeShape(S32, {32, 32}))); +} + +TEST_F(ShapeInferenceTest, DotWithPreferredElementTypeSameAsInferredType) { + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + TF_ASSERT_OK_AND_ASSIGN(Shape inferred_shape, + ShapeInference::InferDotOpShape( + ShapeUtil::MakeShape(BF16, {32, 32}), + ShapeUtil::MakeShape(F32, {32, 32}), dot_dnums, + /*preferred_element_type=*/F32)); + EXPECT_TRUE( + ShapeUtil::Equal(inferred_shape, ShapeUtil::MakeShape(F32, {32, 32}))); +} + +TEST_F(ShapeInferenceTest, FloatingPointDotWithNarrowerPreferredElementType) { + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + TF_ASSERT_OK_AND_ASSIGN(Shape inferred_shape, + ShapeInference::InferDotOpShape( + ShapeUtil::MakeShape(BF16, {32, 32}), + ShapeUtil::MakeShape(F32, {32, 32}), dot_dnums, + /*preferred_element_type=*/BF16)); + EXPECT_TRUE( + ShapeUtil::Equal(inferred_shape, ShapeUtil::MakeShape(F32, {32, 32}))); +} + +TEST_F(ShapeInferenceTest, FloatingPointDotWithInvalidPreferredElementType) { + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto inferred_status = ShapeInference::InferDotOpShape( + ShapeUtil::MakeShape(BF16, {32, 32}), + ShapeUtil::MakeShape(BF16, {32, 32}), dot_dnums, + /*preferred_element_type=*/S32) + .status(); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.error_message(), + HasSubstr("must both be integral or both be floating point")); +} + +TEST_F(ShapeInferenceTest, IntegralDotWithFloatingPointPreferredElementType) { + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); 
+ dot_dnums.add_rhs_contracting_dimensions(0); + auto inferred_status = ShapeInference::InferDotOpShape( + ShapeUtil::MakeShape(S8, {32, 32}), + ShapeUtil::MakeShape(S16, {32, 32}), dot_dnums, + /*preferred_element_type=*/F32) + .status(); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.error_message(), + HasSubstr("must both be integral or both be floating point")); +} + +TEST_F(ShapeInferenceTest, DotWithPreferredElementTypeWithDifferentSignedness) { + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto inferred_status = ShapeInference::InferDotOpShape( + ShapeUtil::MakeShape(S8, {32, 32}), + ShapeUtil::MakeShape(S16, {32, 32}), dot_dnums, + /*preferred_element_type=*/U32) + .status(); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.error_message(), + HasSubstr("must have the same signedness as the original type")); +} + +TEST_F(ShapeInferenceTest, DotWithNarrowerPreferredElementType) { + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto inferred_status = ShapeInference::InferDotOpShape( + ShapeUtil::MakeShape(S8, {32, 32}), + ShapeUtil::MakeShape(S16, {32, 32}), dot_dnums, + /*preferred_element_type=*/S8) + .status(); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.error_message(), + HasSubstr("must not be narrower than the original type")); +} + TEST_F(ShapeInferenceTest, BinOpBroadcastMatrixVector) { // Test variations of broadcasting a vector for a binary add with a // matrix. diff --git a/tensorflow/compiler/xla/service/space_to_batch_converter.cc b/tensorflow/compiler/xla/service/space_to_batch_converter.cc index 1b67f3563bc..05cdc6e24b7 100644 --- a/tensorflow/compiler/xla/service/space_to_batch_converter.cc +++ b/tensorflow/compiler/xla/service/space_to_batch_converter.cc @@ -20,6 +20,7 @@ limitations under the License. #include <map> #include <memory> #include <queue> +#include <tuple> #include <unordered_set> #include <utility> @@ -64,6 +65,19 @@ class ConvolutionVisitor { // Top-level function to begin space-to-batch conversion. Status PerformSpaceToBatchOnConvolution(HloInstruction* convolution); + // Struct containing details about a convolution. + struct ConvDetails { + int64 spatial_dimension_to_split, inherent_low_padding, + inherent_high_padding, stride, spatial_size, base_dilation_factor, + halo_size, high_padding_for_conv, low_padding_for_conv, + kernel_spatial_dim_size, input_dim_size; + }; + + // Return a struct containing various necessary information pieces for + // performing space-to-batch on a convolution. + ConvDetails GetConvolutionDetails(HloInstruction* convolution, + ConvolutionDimensionNumbers& dim_numbers); + // Function that determines if space-to-batch can be propagated into the // consumer. Such propagation is only possible when all required operands are // space-to-batch'ed. @@ -225,11 +239,29 @@ bool ConvolutionVisitor::IsConvSuitableForSpaceToBatch( return false; } - // TODO(b/168316428): Support base dilations. - if (convolution->window() - .dimensions(get_chosen_spatial_dim(convolution)) - .base_dilation() != 1) { - return false; + const ConvDetails c = GetConvolutionDetails(convolution, dim_numbers); + + const int64 low_pad = convolution->window() + .dimensions(get_chosen_spatial_dim(convolution)) + .padding_low(); + + // TODO(b/168316428): Support base dilations more generically. 
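// (Worked example of the constraint enforced below, assuming
// HaloDuplicateWithSlice remains the only halo mechanism: with
// base_dilation_factor == 2, space-to-batch is attempted only when stride == 1
// and either low_pad == 0 with a pointwise (size-1) kernel, or
// kernel_spatial_dim_size == 3 with low_pad == 1; any other combination bails
// out.)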
+ if (c.base_dilation_factor != 1) { + if (c.stride != 1) { + return false; + } + // For low pad of 0, only support a pointwise kernel. + if (low_pad == 0) { + if (c.kernel_spatial_dim_size != 1) { + return false; + } + } else if (c.kernel_spatial_dim_size != c.base_dilation_factor + 1 || + low_pad != c.base_dilation_factor - 1) { + // Only support dilations such that base dilation factor and low pad are + // compatible with kernel_spatial_dim_size to be compatible with + // HaloDuplicateWithSlice. + return false; + } } int64 activations_batch_dim = dim_numbers.input_batch_dimension(); @@ -240,42 +272,17 @@ bool ConvolutionVisitor::IsConvSuitableForSpaceToBatch( if (old_batch_size > limit_on_batch_size_) { return false; } - - auto kernel = convolution->mutable_operand(1); - const auto& kernel_shape = kernel->shape(); - const int64 kernel_spatial_dim_size = - kernel_shape.dimensions(dim_numbers.kernel_spatial_dimensions( - get_chosen_spatial_dim(convolution))); - - auto activations = convolution->mutable_operand(0); - - const int64 input_dim_size = - activations->shape().dimensions(dim_numbers.input_spatial_dimensions( - get_chosen_spatial_dim(convolution))); - - const int64 inherent_low_padding = - convolution->window() - .dimensions(get_chosen_spatial_dim(convolution)) - .padding_low(); - const int64 inherent_high_padding = - convolution->window() - .dimensions(get_chosen_spatial_dim(convolution)) - .padding_high(); - - const int64 spatial_size = - input_dim_size + inherent_low_padding + inherent_high_padding; - VLOG(1) << "spatial size " << spatial_size; - - const int64 num_splits = kNewBatchSize / old_batch_size; - // We currently only cater to evenly divisible cases. if (kNewBatchSize % old_batch_size != 0) { return false; } - // Splitting will be incorrect in these cases. - if (spatial_size < num_splits || - input_dim_size / num_splits < kernel_spatial_dim_size) { + VLOG(1) << "spatial size " << c.spatial_size; + + const int64 num_splits = kNewBatchSize / old_batch_size; + // If the ratio is not within the 2X range, we can't Halo Pad from the next + // split. + if (c.halo_size > CeilOfRatio(c.spatial_size, num_splits)) { return false; } VLOG(1) << "Legal space-to-batch convolution " << convolution->ToString(); @@ -292,8 +299,8 @@ StatusOr<HloInstruction*> ConvolutionVisitor::HaloDuplicateWithSlice( activations->shape().dimensions(spatial_dimension_to_split); const int64 batch_size = activations->shape().dimensions(activations_batch_dim); - CHECK_LT(low_padding, spatial_split_size); + CHECK_LE(std::abs(halo_size - low_padding), spatial_split_size); VLOG(1) << "In HaloDuplicateWithSlice with activations " << activations->ToString() << " batch_size " << batch_size << " spatial_split_size " << spatial_split_size << " low_padding " @@ -439,6 +446,7 @@ StatusOr<bool> ConvolutionVisitor::Run() { // Iterate through all instructions that we could not propagate through, and // turn their operands from batch-to-space as needed. 
for (auto instr : non_propagatable_instrs_) { + VLOG(1) << "Could not eventually propagate through " << instr->ToString(); absl::flat_hash_map<int64, HloInstruction*> operand_map; for (int64 i = 0; i < instr->operand_count(); ++i) { if (old_to_new_instrs_.count(instr->mutable_operand(i))) { @@ -480,8 +488,9 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer, if (!old_to_new_instrs_.contains(old_producer) && !broadcast_or_constant) { - VLOG(1) << "Cannot propagate on elementwise op " - << consumer->ToString(); + VLOG(1) << "Cannot propagate on elementwise op " << consumer->ToString() + << " because operand " << old_producer->ToString() + << " isn't ready "; return false; } else { if (broadcast_or_constant) { @@ -496,10 +505,11 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer, pivot_operand = old_producer; VLOG(2) << "Elementwise op: pivot " << old_producer->ToString(); } else { - VLOG(2) << "Elementwise op: checking for shape equivalence " - << consumer->ToString(); if (instr_to_dim_map_[pivot_operand] != instr_to_dim_map_[old_producer]) { + VLOG(2) << "Elementwise op: checking for shape equivalence " + << consumer->ToString() + << " failed due to changed batch space ordering "; return false; } auto pivot_new_instr = old_to_new_instrs_[pivot_operand]; @@ -509,13 +519,22 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer, for (int j = 0; j < pivot_permute_dims.size(); ++j) { // Ensure the dimension mapping is the same. if (pivot_permute_dims[j] != permute_dims[j]) { + VLOG(2) << "Elementwise op: checking for shape equivalence " + << consumer->ToString() + << " failed due to permuted dimensions "; return false; } // Make sure all other dimensions are of the same size. if (pivot_new_instr->shape().dimensions(j) != new_instr->shape().dimensions(j)) { - return false; + if (!(consumer->IsElementwiseBinary() && + j == instr_to_dim_map_[pivot_operand].second)) { + VLOG(2) << "Elementwise op: checking for shape equivalence " + << consumer->ToString() + << " failed due to changed shape sizes "; + return false; + } } } } @@ -769,6 +788,28 @@ StatusOr<bool> ConvolutionVisitor::Propagate(HloInstruction* consumer, if (IsTrivialElementwise(consumer)) { auto dim_map_val = instr_to_dim_map_[producer]; auto new_consumer = computation->AddInstruction(consumer->Clone()); + if (consumer->IsElementwiseBinary()) { + for (int64 i = 0; i < 2; ++i) { + if (consumer->operand(i)->opcode() == HloOpcode::kBroadcast) { + break; + } + CHECK(old_to_new_instrs_.contains(consumer->mutable_operand(i))); + if (i == 1) { + // Choose the larger shape to be used as the producer. 
+ if (old_to_new_instrs_[consumer->mutable_operand(0)] + ->shape() + .dimensions() > + old_to_new_instrs_[consumer->mutable_operand(1)] + ->shape() + .dimensions()) { + producer = consumer->mutable_operand(0); + } else { + producer = consumer->mutable_operand(1); + } + } + } + } + for (int64 i = 0; i < consumer->operand_count(); ++i) { if (consumer->operand(i)->opcode() == HloOpcode::kBroadcast) { CHECK(old_to_new_instrs_.contains(producer)); @@ -786,8 +827,66 @@ StatusOr<bool> ConvolutionVisitor::Propagate(HloInstruction* consumer, new_consumer->ReplaceOperandWithDifferentShape(i, new_broadcast)); } else { CHECK(old_to_new_instrs_.contains(consumer->mutable_operand(i))); - TF_CHECK_OK(new_consumer->ReplaceOperandWithDifferentShape( - i, old_to_new_instrs_[consumer->mutable_operand(i)])); + HloInstruction* operand_to_use = nullptr; + + auto result = instr_to_dim_map_[producer]; + const int64 old_batch_dim = result.first; + const int64 old_space_dim = result.second; + const int64 old_batch_size = + producer->shape().dimensions(old_batch_dim); + HloInstruction* new_instr = + old_to_new_instrs_[consumer->mutable_operand(i)]; + HloInstruction* pivot_new_instr = old_to_new_instrs_[producer]; + + auto permute_dims = instr_to_dim_permute_map_[new_instr]; + const int64 batch_dim = DimLookUp(permute_dims, old_batch_dim); + const int64 space_dim = DimLookUp(permute_dims, old_space_dim); + const int64 batch_size = new_instr->shape().dimensions(batch_dim); + + if (new_instr->shape().dimensions(space_dim) != + pivot_new_instr->shape().dimensions(space_dim)) { + // Because we do not propagate through transposes, the batch should + // always be followed by the split space dimension. + CHECK_EQ(batch_dim + 1, space_dim); + + // Reshape to 1D, pad to the producer's size, reshape back to 2D. 
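// (Sketch of the shape arithmetic that follows, using the local names below:
// the extra batch factor batch_size / old_batch_size is folded back into
// space_dim while batch_dim is restored to old_batch_size, the merged space
// dimension is then zero-padded on the high side up to pivot_space_size, and
// the result is reshaped to pivot_new_instr's dimensions.)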
+ std::vector<int64> new_dimensions( + new_instr->shape().dimensions().begin(), + new_instr->shape().dimensions().end()); + new_dimensions[space_dim] *= (batch_size / old_batch_size); + new_dimensions[batch_dim] = old_batch_size; + + TF_ASSIGN_OR_RETURN(HloInstruction * reshape, + MakeReshapeHlo(new_dimensions, new_instr)); + + const int64 pivot_space_size = + pivot_new_instr->shape().dimensions(space_dim) * batch_size / + old_batch_size; + + CHECK_GT(pivot_space_size, new_dimensions[space_dim]); + + PaddingConfig padding_config = + MakeNoPaddingConfig(reshape->shape().dimensions_size()); + padding_config.mutable_dimensions(space_dim)->set_edge_padding_high( + pivot_space_size - new_dimensions[space_dim]); + padding_config.mutable_dimensions(space_dim)->set_edge_padding_low(0); + HloInstruction* padding = + computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(reshape->shape().element_type()))); + + TF_ASSIGN_OR_RETURN(HloInstruction * padded_operand, + MakePadHlo(reshape, padding, padding_config)); + + TF_ASSIGN_OR_RETURN( + operand_to_use, + MakeReshapeHlo(pivot_new_instr->shape().dimensions(), + padded_operand)); + + } else { + operand_to_use = old_to_new_instrs_[consumer->mutable_operand(i)]; + } + TF_CHECK_OK( + new_consumer->ReplaceOperandWithDifferentShape(i, operand_to_use)); } } auto old_type = new_consumer->mutable_shape()->element_type(); @@ -1329,25 +1428,21 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) { original_conv_dims.input_spatial_dimensions(i))); } - int64 spatial_dimension_to_split = - permuted_conv_dims_numbers.input_spatial_dimensions( - get_chosen_spatial_dim(convolution)); - const int64 old_batch_dim = original_conv_dims.input_batch_dimension(); const int64 old_batch_size = activations_old->shape().dimensions(old_batch_dim); - const int64 input_dim_size = activations_old->shape().dimensions( - permuted_conv_dims_numbers.input_spatial_dimensions( - get_chosen_spatial_dim(convolution))); + ConvDetails c = + GetConvolutionDetails(convolution, permuted_conv_dims_numbers); VLOG(1) << "Propagating on conv activations_batch_dim " << activations_batch_dim << " spatial_dimension_to_split " - << spatial_dimension_to_split << " old_batch_size " << old_batch_size; - TF_ASSIGN_OR_RETURN( - activations_new, - BringSpaceNextToBatch(activations_new, permuted_conv_dims_numbers, - spatial_dimension_to_split, activations_batch_dim)); + << c.spatial_dimension_to_split << " old_batch_size " + << old_batch_size; + TF_ASSIGN_OR_RETURN(activations_new, + BringSpaceNextToBatch( + activations_new, permuted_conv_dims_numbers, + c.spatial_dimension_to_split, activations_batch_dim)); auto select_val = computation_->AddInstruction(HloInstruction::CreateConstant( LiteralUtil::Zero(activations_new->shape().element_type()))); @@ -1355,32 +1450,12 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) { TF_ASSIGN_OR_RETURN( activations_new, SelectValidPortion(activations_new, activations_old, select_val, - activations_batch_dim, spatial_dimension_to_split, + activations_batch_dim, c.spatial_dimension_to_split, old_batch_dim, old_space_dim)); // Create the new convolution dim numbers. 
auto new_dim_numbers = permuted_conv_dims_numbers; - auto kernel = convolution->mutable_operand(1); - const auto& kernel_shape = kernel->shape(); - const int64 kernel_spatial_dim_size = kernel_shape.dimensions( - permuted_conv_dims_numbers.kernel_spatial_dimensions( - get_chosen_spatial_dim(convolution))); - - const int64 inherent_low_padding = - convolution->window() - .dimensions(get_chosen_spatial_dim(convolution)) - .padding_low(); - const int64 inherent_high_padding = - convolution->window() - .dimensions(get_chosen_spatial_dim(convolution)) - .padding_high(); - const int64 stride = convolution->window() - .dimensions(get_chosen_spatial_dim(convolution)) - .stride(); - - const int64 spatial_size = - input_dim_size + inherent_low_padding + inherent_high_padding; - VLOG(1) << "spatial size " << spatial_size; + VLOG(1) << "spatial size " << c.spatial_size; const int64 num_splits = kNewBatchSize / old_batch_size; @@ -1390,18 +1465,18 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) { const int64 output_offsets_per_split = CeilOfRatio(output_offsets, num_splits); - int64 spatial_split_size = output_offsets_per_split * stride; - const int64 halo_size = - std::max(kernel_spatial_dim_size - stride, static_cast<int64>(0)); + int64 spatial_split_size = + CeilOfRatio(output_offsets_per_split, c.base_dilation_factor) * c.stride; + // Keep increasing the split size so that overall size isn't smaller than the // original spatial dimension. Unlike for the first space-to-batch'ed // convolution, while propagating, we can use the last halo_size as available // spatial size. - while (spatial_split_size * num_splits + halo_size - spatial_size < 0) { - spatial_split_size += stride; + while (spatial_split_size * num_splits + c.halo_size - c.spatial_size < 0) { + spatial_split_size += c.stride; } - int64 slice_size = spatial_split_size + halo_size; + int64 slice_size = spatial_split_size + c.halo_size; VLOG(1) << "spatial_split_size " << spatial_split_size << " slice_size " << slice_size; @@ -1409,7 +1484,7 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) { const int64 new_batch_size = activations_new->shape().dimensions(activations_batch_dim); const int64 new_space_size = - activations_new->shape().dimensions(spatial_dimension_to_split); + activations_new->shape().dimensions(c.spatial_dimension_to_split); // In the below case, we cannot use the activations directly for Halo // Duplication. We must reshape them. if (spatial_split_size > new_space_size) { @@ -1418,7 +1493,7 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) { activations_new->shape().dimensions().end()); const int64 reshaped_space_size = new_space_size * new_batch_size / old_batch_size; - new_dimensions[spatial_dimension_to_split] = reshaped_space_size; + new_dimensions[c.spatial_dimension_to_split] = reshaped_space_size; new_dimensions[activations_batch_dim] = old_batch_size; // Reshape the output of the new conv into the old convolutions shape. 
@@ -1427,10 +1502,10 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) { PaddingConfig padding_config = MakeNoPaddingConfig(reshaped_activations->shape().dimensions_size()); - padding_config.mutable_dimensions(spatial_dimension_to_split) + padding_config.mutable_dimensions(c.spatial_dimension_to_split) ->set_edge_padding_high(spatial_split_size * new_batch_size - reshaped_space_size); - padding_config.mutable_dimensions(spatial_dimension_to_split) + padding_config.mutable_dimensions(c.spatial_dimension_to_split) ->set_edge_padding_low(0); HloInstruction* padding = computation_->AddInstruction(HloInstruction::CreateConstant( @@ -1444,7 +1519,7 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) { reshaped_activations->shape().dimensions().begin(), reshaped_activations->shape().dimensions().end()); - reshape_back_dims[spatial_dimension_to_split] = spatial_split_size; + reshape_back_dims[c.spatial_dimension_to_split] = spatial_split_size; reshape_back_dims[activations_batch_dim] = new_batch_size; TF_ASSIGN_OR_RETURN( @@ -1453,34 +1528,38 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) { TF_ASSIGN_OR_RETURN( activations_new, - HaloDuplicateWithSlice(reshaped_activations, spatial_dimension_to_split, - activations_batch_dim, old_batch_size, - /*low_padding=*/inherent_low_padding, - /*high_padding=*/inherent_high_padding, - slice_size - spatial_split_size, - old_split_dim_size)); + HaloDuplicateWithSlice( + reshaped_activations, c.spatial_dimension_to_split, + activations_batch_dim, old_batch_size, + /*low_padding=*/c.base_dilation_factor != 1 && + c.inherent_low_padding != 0 + ? c.base_dilation_factor - 1 + : c.inherent_low_padding, + c.inherent_high_padding, slice_size - spatial_split_size, + old_split_dim_size)); } else { // If the ideal spatial_split_size was smaller than the incoming spatial // dimension size, we don't need reshaping. Instead, we determine the // additional space available, and adjust the required slice size (and - // thereby the halo size).'t need reshaping. Instead, we determine the - // additional space available, and adjust the required slice size (and // thereby the halo size). if (spatial_split_size < new_space_size) { - const int64 additional_space_present = spatial_split_size % stride; + const int64 additional_space_present = spatial_split_size % c.stride; spatial_split_size = new_space_size; slice_size = - spatial_split_size + - std::max(kernel_spatial_dim_size - stride - additional_space_present, - static_cast<int64>(0)); + spatial_split_size + std::max(c.kernel_spatial_dim_size - c.stride - + additional_space_present, + static_cast<int64>(0)); } TF_ASSIGN_OR_RETURN( activations_new, - HaloDuplicateWithSlice(activations_new, spatial_dimension_to_split, + HaloDuplicateWithSlice(activations_new, c.spatial_dimension_to_split, activations_batch_dim, old_batch_size, - /*low_padding=*/inherent_low_padding, - /*high_padding=*/inherent_high_padding, + /*low_padding=*/c.base_dilation_factor != 1 && + c.inherent_low_padding != 0 + ? 
c.base_dilation_factor - 1 + : c.inherent_low_padding, + c.inherent_high_padding, slice_size - spatial_split_size, old_split_dim_size)); } @@ -1515,15 +1594,16 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) { auto new_window = convolution->window(); new_window.mutable_dimensions(get_chosen_spatial_dim(convolution)) - ->set_padding_high(0); + ->set_padding_high(c.high_padding_for_conv); new_window.mutable_dimensions(get_chosen_spatial_dim(convolution)) - ->set_padding_low(0); + ->set_padding_low(c.low_padding_for_conv); TF_ASSIGN_OR_RETURN( HloInstruction * new_conv, - MakeConvolveHlo(activations_new, /*rhs=*/convolution->mutable_operand(1), - convolution->feature_group_count(), - convolution->batch_group_count(), new_window, - new_dim_numbers, convolution->precision_config())); + MakeConvolveHlo( + activations_new, /*rhs=*/convolution->mutable_operand(1), + convolution->feature_group_count(), convolution->batch_group_count(), + new_window, new_dim_numbers, convolution->precision_config(), + /*preferred_element_type=*/convolution->shape().element_type())); convolution->SetupDerivedInstruction(new_conv); old_to_new_instrs_[convolution] = new_conv; @@ -1800,10 +1880,11 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv( TF_ASSIGN_OR_RETURN( HloInstruction * new_conv, - MakeConvolveHlo(activations_new, kernel_new, - convolution->feature_group_count(), - convolution->batch_group_count(), new_window, - new_dim_numbers, convolution->precision_config())); + MakeConvolveHlo( + activations_new, kernel_new, convolution->feature_group_count(), + convolution->batch_group_count(), new_window, new_dim_numbers, + convolution->precision_config(), + /*preferred_element_type=*/convolution->shape().element_type())); convolution->SetupDerivedInstruction(new_conv); std::vector<int64> output_sizes(new_conv->shape().dimensions().begin(), @@ -1853,19 +1934,9 @@ HloInstruction* ConvolutionVisitor::DoesConvolutionFeedReduceWindow( return nullptr; } -Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution( - HloInstruction* convolution) { - VLOG(1) << "Handling conv " << convolution->ToString(); - - changed_ = false; - - ConvolutionDimensionNumbers dim_numbers = - convolution->convolution_dimension_numbers(); - - int64 activations_batch_dim = dim_numbers.input_batch_dimension(); - - const int64 old_batch_size = - convolution->operand(0)->shape().dimensions(activations_batch_dim); +ConvolutionVisitor::ConvDetails ConvolutionVisitor::GetConvolutionDetails( + HloInstruction* convolution, ConvolutionDimensionNumbers& dim_numbers) { + auto activations = convolution->mutable_operand(0); auto kernel = convolution->mutable_operand(1); const auto& kernel_shape = kernel->shape(); @@ -1873,14 +1944,11 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution( kernel_shape.dimensions(dim_numbers.kernel_spatial_dimensions( get_chosen_spatial_dim(convolution))); - auto activations = convolution->mutable_operand(0); - - int64 spatial_dimension_to_split = + const int64 spatial_dimension_to_split = dim_numbers.input_spatial_dimensions(get_chosen_spatial_dim(convolution)); const int64 input_dim_size = - activations->shape().dimensions(dim_numbers.input_spatial_dimensions( - get_chosen_spatial_dim(convolution))); + activations->shape().dimensions(spatial_dimension_to_split); const int64 inherent_low_padding = convolution->window() @@ -1890,26 +1958,75 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution( convolution->window() .dimensions(get_chosen_spatial_dim(convolution)) 
.padding_high(); - const bool inherent_padding_needed = - inherent_low_padding != 0 || inherent_high_padding != 0; const int64 stride = convolution->window() .dimensions(get_chosen_spatial_dim(convolution)) .stride(); + const int64 base_dilation_factor = + convolution->window() + .dimensions(get_chosen_spatial_dim(convolution)) + .base_dilation(); + const int64 spatial_size = - input_dim_size + inherent_low_padding + inherent_high_padding; - VLOG(1) << "spatial size " << spatial_size; + input_dim_size + (base_dilation_factor > 1 ? 0 : inherent_low_padding) + + inherent_high_padding; + + const int64 halo_size = + std::max(kernel_spatial_dim_size - stride - (base_dilation_factor - 1), + static_cast<int64>(0)); + const int64 high_padding_for_conv = base_dilation_factor == 1 ? 0 + : inherent_low_padding == 0 + ? base_dilation_factor - 1 + : 0; + const int64 low_padding_for_conv = + base_dilation_factor == 1 ? 0 : inherent_low_padding; + + return ConvDetails{spatial_dimension_to_split, + inherent_low_padding, + inherent_high_padding, + stride, + spatial_size, + base_dilation_factor, + halo_size, + high_padding_for_conv, + low_padding_for_conv, + kernel_spatial_dim_size, + input_dim_size}; +} + +Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution( + HloInstruction* convolution) { + VLOG(1) << "Handling conv " << convolution->ToString(); + + changed_ = false; + + ConvolutionDimensionNumbers dim_numbers = + convolution->convolution_dimension_numbers(); + + ConvDetails c = GetConvolutionDetails(convolution, dim_numbers); + + int64 activations_batch_dim = dim_numbers.input_batch_dimension(); + + const int64 old_batch_size = + convolution->operand(0)->shape().dimensions(activations_batch_dim); + + auto activations = convolution->mutable_operand(0); + + const bool inherent_padding_needed = + c.inherent_low_padding != 0 || c.inherent_high_padding != 0; + + VLOG(1) << "spatial size " << c.spatial_size; const int64 num_splits = kNewBatchSize / old_batch_size; auto original_conv = convolution; // We'd need transposition of activations here such that batch and space dim // that is being split are adjacent (in that order). - TF_ASSIGN_OR_RETURN( - activations, - BringSpaceNextToBatch(activations, dim_numbers, - spatial_dimension_to_split, activations_batch_dim)); + TF_ASSIGN_OR_RETURN(activations, + BringSpaceNextToBatch(activations, dim_numbers, + c.spatial_dimension_to_split, + activations_batch_dim)); // Create the new convolution dim numbers. auto new_dim_numbers = dim_numbers; @@ -1920,11 +2037,12 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution( const int64 output_offsets_per_split = CeilOfRatio(output_offsets, num_splits); - int64 spatial_split_size = output_offsets_per_split * stride; + int64 spatial_split_size = + CeilOfRatio(output_offsets_per_split, c.base_dilation_factor) * c.stride; // Keep increasing the split size so that overall size isn't smaller than the // original spatial dimension. - while (spatial_split_size * num_splits - spatial_size < 0) { - spatial_split_size += stride; + while (spatial_split_size * num_splits - c.spatial_size < 0) { + spatial_split_size += c.stride; } auto reduce_window = DoesConvolutionFeedReduceWindow(convolution); @@ -1936,33 +2054,32 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution( // windows. 
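// (For instance, with c.stride == 1 and a reduce-window stride of 2, the loop
// below bumps spatial_split_size until it is a multiple of 2; this is an
// illustrative reading of the adjustment, not an additional constraint.)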
const int64 red_win_stride = reduce_window->window().dimensions(output_spatial_dim).stride(); - while ((spatial_split_size / stride) % red_win_stride != 0) { - spatial_split_size += stride; + while ((spatial_split_size / c.stride) % red_win_stride != 0) { + spatial_split_size += c.stride; } } - const int64 slice_size = - spatial_split_size + - std::max(kernel_spatial_dim_size - stride, static_cast<int64>(0)); + const int64 slice_size = spatial_split_size + c.halo_size; // Pad spatial dim. - const int64 pad_size = spatial_split_size * num_splits - spatial_size; + const int64 pad_size = spatial_split_size * num_splits - c.spatial_size; VLOG(1) << "spatial_split_size " << spatial_split_size << " stride " - << stride; - VLOG(1) << "spatial_dimension_to_split " << spatial_dimension_to_split + << c.stride << " slice_size " << slice_size; + VLOG(1) << "spatial_dimension_to_split " << c.spatial_dimension_to_split << " num_splits " << num_splits << " kernel_spatial_dim_size " - << kernel_spatial_dim_size; + << c.kernel_spatial_dim_size; // Because we are splitting the spatial dimension, if convolution needed // padding in the spatial dimension, we materialize it. if (pad_size != 0 || inherent_padding_needed) { PaddingConfig padding_config = MakeNoPaddingConfig(activations->shape().dimensions_size()); - padding_config.mutable_dimensions(spatial_dimension_to_split) - ->set_edge_padding_high(inherent_high_padding + pad_size); - padding_config.mutable_dimensions(spatial_dimension_to_split) - ->set_edge_padding_low(inherent_low_padding); + padding_config.mutable_dimensions(c.spatial_dimension_to_split) + ->set_edge_padding_high(c.inherent_high_padding + pad_size); + padding_config.mutable_dimensions(c.spatial_dimension_to_split) + ->set_edge_padding_low( + c.base_dilation_factor == 1 ? 
c.inherent_low_padding : 0); HloInstruction* padding = computation_->AddInstruction(HloInstruction::CreateConstant( LiteralUtil::Zero(activations->shape().element_type()))); @@ -1989,7 +2106,7 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution( activations->shape().dimensions().begin(), activations->shape().dimensions().end()); - reshape_dimensions[spatial_dimension_to_split] = spatial_split_size; + reshape_dimensions[c.spatial_dimension_to_split] = spatial_split_size; reshape_dimensions[activations_batch_dim] = num_splits * old_batch_size; TF_ASSIGN_OR_RETURN(HloInstruction * batch_increased_reshape, @@ -1998,12 +2115,12 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution( VLOG(1) << "First reshape done " << batch_increased_reshape->ToString(); - TF_ASSIGN_OR_RETURN(activations, - HaloDuplicateWithSlice( - batch_increased_reshape, spatial_dimension_to_split, - activations_batch_dim, old_batch_size, - /*low_padding=*/0, /*high_padding=*/0, - slice_size - spatial_split_size, input_dim_size)); + TF_ASSIGN_OR_RETURN( + activations, HaloDuplicateWithSlice(batch_increased_reshape, + c.spatial_dimension_to_split, + activations_batch_dim, old_batch_size, + /*low_padding=*/0, /*high_padding=*/0, + c.halo_size, c.input_dim_size)); VLOG(1) << "Batch merge done " << activations->ToString(); @@ -2038,15 +2155,16 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution( << " batch dim " << new_dim_numbers.input_batch_dimension(); auto new_window = convolution->window(); new_window.mutable_dimensions(get_chosen_spatial_dim(convolution)) - ->set_padding_high(0); + ->set_padding_high(c.high_padding_for_conv); new_window.mutable_dimensions(get_chosen_spatial_dim(convolution)) - ->set_padding_low(0); + ->set_padding_low(c.low_padding_for_conv); TF_ASSIGN_OR_RETURN( HloInstruction * new_conv, - MakeConvolveHlo(activations, /*rhs=*/convolution->mutable_operand(1), - convolution->feature_group_count(), - convolution->batch_group_count(), new_window, - new_dim_numbers, convolution->precision_config())); + MakeConvolveHlo( + activations, /*rhs=*/convolution->mutable_operand(1), + convolution->feature_group_count(), convolution->batch_group_count(), + new_window, new_dim_numbers, convolution->precision_config(), + /*preferred_element_type=*/convolution->shape().element_type())); convolution->SetupDerivedInstruction(new_conv); VLOG(1) << "Space-to-batched convolution " << new_conv->ToString(); diff --git a/tensorflow/compiler/xla/service/space_to_batch_converter_test.cc b/tensorflow/compiler/xla/service/space_to_batch_converter_test.cc index d53bb7d75f3..8921d98cad0 100644 --- a/tensorflow/compiler/xla/service/space_to_batch_converter_test.cc +++ b/tensorflow/compiler/xla/service/space_to_batch_converter_test.cc @@ -113,7 +113,7 @@ TEST_F(ConvolutionSpaceToBatchConverterTest, Batch1WithStrideAndPad) { EXPECT_GT(reshape->operand(0)->shape().dimensions(batch_dim), 4); } -TEST_F(ConvolutionSpaceToBatchConverterTest, Batch1WithKernelDilation) { +TEST_F(ConvolutionSpaceToBatchConverterTest, Batch1WithBaseDilation) { string hlo_string = R"( HloModule module @@ -129,8 +129,22 @@ ENTRY computation { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, ParseAndReturnVerifiedModule(hlo_string)); + auto computation = module->entry_computation(); ConvolutionSpaceToBatchConverter converter; - ASSERT_FALSE(converter.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(converter.Run(module.get()).ValueOrDie()); + + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, 
op::Transpose()); + EXPECT_THAT(root->operand(0), op::Slice()); + auto reshape = root->operand(0)->operand(0); + EXPECT_THAT(reshape, op::Reshape()); + EXPECT_THAT(reshape->operand(0)->operand(1), op::Convolution()); + const int64 batch_dim = reshape->operand(0) + ->operand(1) + ->convolution_dimension_numbers() + .output_batch_dimension(); + + EXPECT_GT(reshape->operand(0)->shape().dimensions(batch_dim), 4); } } // namespace diff --git a/tensorflow/compiler/xla/service/spmd/convolution_handler.cc b/tensorflow/compiler/xla/service/spmd/convolution_handler.cc index 0d34c5b62e9..7130ba3dd5b 100644 --- a/tensorflow/compiler/xla/service/spmd/convolution_handler.cc +++ b/tensorflow/compiler/xla/service/spmd/convolution_handler.cc @@ -950,7 +950,8 @@ StatusOr<std::unique_ptr<HloInstruction>> CreateShardedConvConvolution( Shape sharded_conv_shape, ShapeInference::InferConvolveShape( sharded_lhs_hlo->shape(), sharded_rhs_hlo->shape(), - feature_group_count, batch_group_count, window, conv_dnums)); + feature_group_count, batch_group_count, window, conv_dnums, + /*preferred_element_type=*/conv.shape().element_type())); *sharded_conv_shape.mutable_layout() = conv.shape().layout(); return HloInstruction::CreateConvolve( sharded_conv_shape, sharded_lhs_hlo, sharded_rhs_hlo, feature_group_count, diff --git a/tensorflow/compiler/xla/service/spmd/dot_handler.cc b/tensorflow/compiler/xla/service/spmd/dot_handler.cc index a346d8778d6..bb666221a56 100644 --- a/tensorflow/compiler/xla/service/spmd/dot_handler.cc +++ b/tensorflow/compiler/xla/service/spmd/dot_handler.cc @@ -80,8 +80,9 @@ Status SpmdPartitioningVisitor::HandleDot(HloInstruction* hlo) { const Window& conv_window) -> StatusOr<HloInstruction*> { TF_ASSIGN_OR_RETURN( auto sharded_dot_shape, - ShapeInference::InferDotOpShape(l->shape(), r->shape(), - hlo->dot_dimension_numbers())); + ShapeInference::InferDotOpShape( + l->shape(), r->shape(), hlo->dot_dimension_numbers(), + /*preferred_element_type=*/hlo->shape().element_type())); return b->AddInstruction(HloInstruction::CreateDot( sharded_dot_shape, l, r, hlo->dot_dimension_numbers(), hlo->precision_config())); @@ -1289,6 +1290,14 @@ StatusOr<HloInstruction*> PartitionDot( const SpmdPartitionerOptions& options, SpmdBuilder* b, std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>* windowed_dot_general_loops) { + // If lhs' hlo and rhs' hlo are identical, make a copy for rhs. + if (lhs.hlo() == rhs.hlo()) { + auto copy_hlo = b->AddInstruction(HloInstruction::CreateUnary( + rhs.hlo()->shape(), HloOpcode::kCopy, rhs.hlo())); + copy_hlo->set_sharding(rhs.sharding()); + rhs = PartitionedHlo(copy_hlo, rhs.base_shape(), rhs.state()); + } + // lhs_rhs_or_output: 0 lhs, 1 rhs, 2 output.
auto get_partitions_for_dims = [&](const HloSharding& sharding, diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc index a70fdd9b828..4c1fb336439 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc @@ -6538,6 +6538,35 @@ ENTRY entry { op::Shape("c64[1,1,3]"))); } +TEST_F(SpmdPartitioningTest, DotInputsAreIdentical) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %parameter.1 = f32[4000,4000]{1,0} parameter(0), + sharding={devices=[2,4]0,1,2,3,4,5,6,7} + ROOT %convolution = f32[4000,4000]{1,0} convolution( + f32[4000,4000]{1,0} %parameter.1, f32[4000,4000]{1,0} %parameter.1), + dim_labels=bf_io->bf, sharding={devices=[2,4]0,1,2,3,4,5,6,7} +} + +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto param = AllOf(op::Parameter(), op::Shape("f32[2000, 1000]")); + auto resharded_lhs = + AllOf(op::AllReduce(op::DynamicUpdateSlice(_, param, _, _)), + op::Shape("f32[2000, 4000]")); + auto resharded_rhs = + AllOf(op::AllReduce(op::DynamicUpdateSlice(_, op::Copy(param), _, _)), + op::Shape("f32[4000, 1000]")); + EXPECT_THAT(root, AllOf(op::Convolution(resharded_lhs, resharded_rhs), + op::Shape("f32[2000, 1000]"))); +} + } // namespace } // namespace spmd } // namespace xla diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index 3fe69d22e9c..18223406da8 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -242,7 +242,8 @@ TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) { } StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape( x->shape(), transpose_y->shape(), /*feature_group_count=*/1, - /*batch_group_count=*/1, window, dnums); + /*batch_group_count=*/1, window, dnums, + /*preferred_element_type=*/absl::nullopt); EXPECT_IS_OK(conv_shape); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( conv_shape.ValueOrDie(), x, transpose_y, @@ -298,7 +299,8 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) { } StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape( x->shape(), transpose_y->shape(), /*feature_group_count=*/1, - /*batch_group_count=*/1, window, dnums); + /*batch_group_count=*/1, window, dnums, + /*preferred_element_type=*/absl::nullopt); EXPECT_IS_OK(conv_shape); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( conv_shape.ValueOrDie(), x, transpose_y, @@ -359,7 +361,8 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { } StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape( transpose_x->shape(), y->shape(), /*feature_group_count=*/1, - /*batch_group_count=*/1, window, dnums); + /*batch_group_count=*/1, window, dnums, + /*preferred_element_type=*/absl::nullopt); EXPECT_IS_OK(conv_shape); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( conv_shape.ValueOrDie(), transpose_x, y, @@ -426,7 +429,8 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) { } StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape( transpose_x->shape(), y->shape(), /*feature_group_count=*/1, - /*batch_group_count=*/1, window, dnums); + 
/*batch_group_count=*/1, window, dnums, + /*preferred_element_type=*/absl::nullopt); EXPECT_IS_OK(conv_shape); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( conv_shape.ValueOrDie(), transpose_x, y, diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index cb0edfb6be6..e84a2591707 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -52,6 +52,33 @@ namespace xla { using absl::StrAppend; using absl::StrCat; +namespace { +// An array indexed by PrimitiveType that gives +// the size in bytes of each element of that primitive type, or 0 +// if the type has no fixed element size (TUPLE, OPAQUE_TYPE, TOKEN) +constexpr uint8 primitive_byte_size[PrimitiveType_ARRAYSIZE] = { + 0, // PRIMITIVE_TYPE_INVALID = 0, + sizeof(int8), // PRED = 1 + sizeof(int8), // S8 = 2 + sizeof(int16), // S16 = 3 + sizeof(int32), // S32 = 4 + sizeof(int64), // S64 = 5 + sizeof(uint8), // U8 = 6 + sizeof(uint16), // U16 = 7 + sizeof(uint32), // U32 = 8 + sizeof(uint64), // U64 = 9 + sizeof(float) / 2, // F16 = 10 + sizeof(float), // F32 = 11 + sizeof(double), // F64 = 12 + 0, // TUPLE = 13 + 0, // OPAQUE_TYPE = 14 + sizeof(complex64), // C64 = 15 + sizeof(float) / 2, // BF16 = 16 + 0, // TOKEN = 17 + sizeof(complex128) // C128 = 18 +}; +} // namespace + string ShapeIndex::ToString() const { return ShapeIndexView(*this).ToString(); } string ShapeIndexView::ToString() const { @@ -175,6 +202,42 @@ StatusOr<Shape> MakeShapeWithLayoutInternal( return accum; } +/* static */ bool ShapeUtil::FillNewShape(PrimitiveType element_type, + absl::Span<const int64> dimensions, + Shape* shape) { + const int eint = static_cast<int>(element_type); + int64 dense_shape_size = ((eint >= 0 && eint < PrimitiveType_ARRAYSIZE) + ? primitive_byte_size[eint] + : 0); // Out of range: force a failure + if (dense_shape_size <= 0) { + return false; + } + + // Verify that array-based lookup is consistent with public API.
+ DCHECK_EQ(dense_shape_size, ByteSizeOfPrimitiveType(element_type)) + << element_type; + + shape->set_element_type(element_type); + const int ndims = dimensions.size(); + auto layout = shape->mutable_layout(); + layout->set_format(DENSE); + auto* minor_to_major = layout->mutable_minor_to_major(); + for (int i = 0; i < ndims; i++) { + const int64 d = dimensions[i]; + if (d < 0) { + return false; + } + dense_shape_size = MultiplyWithoutOverflow(dense_shape_size, d); + if (dense_shape_size < 0) { + return false; + } + + shape->add_dimensions(d); + minor_to_major->push_back(ndims - 1 - i); + } + return true; +} + /* static */ ProgramShape ShapeUtil::MakeProgramShape( std::initializer_list<Shape> parameters, Shape result) { ProgramShape program_shape; @@ -187,7 +250,9 @@ StatusOr<Shape> MakeShapeWithLayoutInternal( /* static */ Shape ShapeUtil::MakeShape(PrimitiveType element_type, absl::Span<const int64> dimensions) { - return MakeValidatedShape(element_type, dimensions).ValueOrDie(); + Shape shape; + CHECK(FillNewShape(element_type, dimensions, &shape)); + return shape; } /* static */ Shape ShapeUtil::MakeScalarShape(PrimitiveType element_type) { @@ -210,18 +275,31 @@ StatusOr<Shape> MakeShapeWithLayoutInternal( /* static */ StatusOr<Shape> ShapeUtil::MakeValidatedShape( PrimitiveType element_type, absl::Span<const int64> dimensions) { - CHECK(IsArrayPrimitiveType(element_type)) << element_type; - Shape result; - TF_RETURN_IF_ERROR(PopulateShape(element_type, dimensions, &result)); - return result; + Shape shape; + if (!FillNewShape(element_type, dimensions, &shape)) { + return InvalidArgument("invalid shape type=%d, dims=[%s]", + static_cast<int>(element_type), + absl::StrJoin(dimensions, ",")); + } + return shape; } /* static */ StatusOr<Shape> ShapeUtil::MakeValidatedShape( PrimitiveType element_type, absl::Span<const int64> dimensions, const std::vector<bool>& dynamic_dimensions) { - TF_ASSIGN_OR_RETURN(Shape shape, - MakeValidatedShape(element_type, dimensions)); - for (int i = 0; i < dynamic_dimensions.size(); ++i) { + if (dynamic_dimensions.size() != dimensions.size()) { + return InvalidArgument( + "dynamic dimensions size %d did not match number of dimensions %d", + dynamic_dimensions.size(), dimensions.size()); + } + + Shape shape; + if (!FillNewShape(element_type, dimensions, &shape)) { + return InvalidArgument("invalid shape type=%d, dims=[%s]", + static_cast<int>(element_type), + absl::StrJoin(dimensions, ",")); + } + for (int i = 0, n = dimensions.size(); i < n; i++) { shape.set_dynamic_dimension(i, dynamic_dimensions[i]); } return shape; diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index c1a6a2c8b1d..ff47ab6ea80 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -792,6 +792,11 @@ class ShapeUtil { static bool CanUpcastIntegral(const Shape& from, const Shape& to); private: + // Fills *shape. Returns true on success. + // REQUIRES: *shape is empty. + static bool FillNewShape(PrimitiveType element_type, + absl::Span<const int64> dimensions, Shape* shape); + // Validates the shape size is sane. This makes sure it's safe to do // calculations in int64 without overflowing. 
static Status ValidateShapeSize(const Shape& shape); diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 4e2030667ee..1a944d01941 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test_benchmark.h" namespace xla { namespace { @@ -827,5 +828,19 @@ TEST(AlignmentTest, EXPECT_FALSE(aligned_shape); } +void BM_MakeShape(::testing::benchmark::State& state) { + for (auto s : state) { + ShapeUtil::MakeShape(F32, {2}); + } +} +BENCHMARK(BM_MakeShape); + +void BM_MakeValidatedShape(::testing::benchmark::State& state) { + for (auto s : state) { + ShapeUtil::MakeValidatedShape(F32, {2}).ValueOrDie(); + } +} +BENCHMARK(BM_MakeValidatedShape); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc index fe27a8c6963..69916f6abc5 100644 --- a/tensorflow/compiler/xla/tests/concat_test.cc +++ b/tensorflow/compiler/xla/tests/concat_test.cc @@ -554,7 +554,7 @@ ENTRY jit_broken.874 { abs.129 = f32[4]{0} abs(subtract.126) constant.130 = f32[] constant(inf) broadcast.131 = f32[4]{0} broadcast(constant.130), dimensions={} - compare.132 = pred[4]{0} compare(abs.129, broadcast.131), direction=EQ, type=UNSIGNED + compare.132 = pred[4]{0} compare(abs.129, broadcast.131), direction=EQ not.133 = pred[4]{0} not(compare.132) and.134 = pred[4]{0} and(not.128, not.133) add.135 = f32[4]{0} add(add.124, add.89) @@ -577,7 +577,7 @@ ENTRY jit_broken.874 { abs.219 = f32[4]{0} abs(subtract.216) constant.220 = f32[] constant(inf) broadcast.221 = f32[4]{0} broadcast(constant.220), dimensions={} - compare.222 = pred[4]{0} compare(abs.219, broadcast.221), direction=EQ, type=UNSIGNED + compare.222 = pred[4]{0} compare(abs.219, broadcast.221), direction=EQ not.223 = pred[4]{0} not(compare.222) and.224 = pred[4]{0} and(not.218, not.223) add.225 = f32[4]{0} add(add.214, add.179) @@ -600,7 +600,7 @@ ENTRY jit_broken.874 { abs.309 = f32[4]{0} abs(subtract.306) constant.310 = f32[] constant(inf) broadcast.311 = f32[4]{0} broadcast(constant.310), dimensions={} - compare.312 = pred[4]{0} compare(abs.309, broadcast.311), direction=EQ, type=UNSIGNED + compare.312 = pred[4]{0} compare(abs.309, broadcast.311), direction=EQ not.313 = pred[4]{0} not(compare.312) and.314 = pred[4]{0} and(not.308, not.313) add.315 = f32[4]{0} add(add.304, add.269) @@ -623,7 +623,7 @@ ENTRY jit_broken.874 { abs.399 = f32[4]{0} abs(subtract.396) constant.400 = f32[] constant(inf) broadcast.401 = f32[4]{0} broadcast(constant.400), dimensions={} - compare.402 = pred[4]{0} compare(abs.399, broadcast.401), direction=EQ, type=UNSIGNED + compare.402 = pred[4]{0} compare(abs.399, broadcast.401), direction=EQ not.403 = pred[4]{0} not(compare.402) and.404 = pred[4]{0} and(not.398, not.403) add.405 = f32[4]{0} add(add.394, add.359) @@ -646,7 +646,7 @@ ENTRY jit_broken.874 { abs.489 = f32[4]{0} abs(subtract.486) constant.490 = f32[] constant(inf) broadcast.491 = f32[4]{0} broadcast(constant.490), dimensions={} - compare.492 = pred[4]{0} compare(abs.489, broadcast.491), direction=EQ, type=UNSIGNED + compare.492 = pred[4]{0} compare(abs.489, broadcast.491), direction=EQ not.493 = pred[4]{0} not(compare.492) and.494 = pred[4]{0} and(not.488, not.493) add.495 = f32[4]{0} 
add(add.484, add.449) @@ -669,7 +669,7 @@ ENTRY jit_broken.874 { abs.579 = f32[4]{0} abs(subtract.576) constant.580 = f32[] constant(inf) broadcast.581 = f32[4]{0} broadcast(constant.580), dimensions={} - compare.582 = pred[4]{0} compare(abs.579, broadcast.581), direction=EQ, type=UNSIGNED + compare.582 = pred[4]{0} compare(abs.579, broadcast.581), direction=EQ not.583 = pred[4]{0} not(compare.582) and.584 = pred[4]{0} and(not.578, not.583) add.585 = f32[4]{0} add(add.574, add.539) @@ -692,7 +692,7 @@ ENTRY jit_broken.874 { abs.669 = f32[4]{0} abs(subtract.666) constant.670 = f32[] constant(inf) broadcast.671 = f32[4]{0} broadcast(constant.670), dimensions={} - compare.672 = pred[4]{0} compare(abs.669, broadcast.671), direction=EQ, type=UNSIGNED + compare.672 = pred[4]{0} compare(abs.669, broadcast.671), direction=EQ not.673 = pred[4]{0} not(compare.672) and.674 = pred[4]{0} and(not.668, not.673) add.675 = f32[4]{0} add(add.664, add.629) @@ -715,7 +715,7 @@ ENTRY jit_broken.874 { abs.759 = f32[4]{0} abs(subtract.756) constant.760 = f32[] constant(inf) broadcast.761 = f32[4]{0} broadcast(constant.760), dimensions={} - compare.762 = pred[4]{0} compare(abs.759, broadcast.761), direction=EQ, type=UNSIGNED + compare.762 = pred[4]{0} compare(abs.759, broadcast.761), direction=EQ not.763 = pred[4]{0} not(compare.762) and.764 = pred[4]{0} and(not.758, not.763) add.765 = f32[4]{0} add(add.754, add.719) @@ -738,7 +738,7 @@ ENTRY jit_broken.874 { abs.849 = f32[4]{0} abs(subtract.846) constant.850 = f32[] constant(inf) broadcast.851 = f32[4]{0} broadcast(constant.850), dimensions={} - compare.852 = pred[4]{0} compare(abs.849, broadcast.851), direction=EQ, type=UNSIGNED + compare.852 = pred[4]{0} compare(abs.849, broadcast.851), direction=EQ not.853 = pred[4]{0} not(compare.852) and.854 = pred[4]{0} and(not.848, not.853) add.855 = f32[4]{0} add(add.844, add.809) diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc index 3e9a3ec2314..aa377b57fe6 100644 --- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc +++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc @@ -73,6 +73,8 @@ int main(int argc, char** argv) { triple_string = "x86_64-pc-windows-msvc19"; } else if (target_cpu == "ppc") { triple_string = "ppc64le-ibm-linux-gnu"; + } else if (target_cpu == "s390x") { + triple_string = "systemz-none-linux-gnu"; } else if (target_cpu == "local") { triple_string = llvm::sys::getDefaultTargetTriple(); } else { diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc index 6e7deda13f0..f39bd269ef4 100644 --- a/tensorflow/compiler/xla/util.cc +++ b/tensorflow/compiler/xla/util.cc @@ -367,15 +367,15 @@ string SanitizeFileName(string file_name) { // precision, Numerische Mathematik, vol. 18, pp. 224–242, 1971. std::pair<float, float> SplitF64ToF32(double x) { const float x_f32 = static_cast<float>(x); - // Early return if x is an infinity or NaN. - if (!std::isfinite(x)) { - return std::make_pair(x_f32, 0.0f); - } - // Only values within the range of F32 are supported, unless it is infinity. - // Small values with large negative exponents would be rounded to zero. + // Early return if x is an infinity or NaN. if (!std::isfinite(x_f32)) { - LOG(WARNING) << "Out of range F64 constant detected: " << x; + // Only values within the range of F32 are supported, unless it is infinity. + // Small values with large negative exponents would be rounded to zero. 
+ if (std::isfinite(x)) { + LOG(WARNING) << "Out of range F64 constant detected: " << x; + } + return std::make_pair(x_f32, 0.0f); } // The high float is simply the double rounded to the nearest float. Because diff --git a/tensorflow/compiler/xla/util_test.cc b/tensorflow/compiler/xla/util_test.cc index 5477dfba18d..6f60e241d92 100644 --- a/tensorflow/compiler/xla/util_test.cc +++ b/tensorflow/compiler/xla/util_test.cc @@ -126,5 +126,13 @@ TEST(UtilTest, RoundTripFpToString) { "-nan"); } +TEST(UtilTest, SplitF64ToF32) { + // Overflowing the F32 exponent in SplitF64ToF32 should result in a pair of + // [∞,0]. + EXPECT_EQ(SplitF64ToF32(std::numeric_limits<double>::max()).first, + std::numeric_limits<float>::infinity()); + EXPECT_EQ(SplitF64ToF32(std::numeric_limits<double>::max()).second, 0.0f); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index eade7c2426d..01de56bf85d 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -85,6 +85,7 @@ enum PrimitiveType { // Next = 19 } // LINT.ThenChange( +// https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc, // https://www.tensorflow.org/code/tensorflow/compiler/xla/tools/driver.cc // ) diff --git a/tensorflow/core/api_def/BUILD b/tensorflow/core/api_def/BUILD index b1ced361927..4cab7b79034 100644 --- a/tensorflow/core/api_def/BUILD +++ b/tensorflow/core/api_def/BUILD @@ -6,7 +6,6 @@ # :python_api_def # :java_api_def -load("//tensorflow:tensorflow.bzl", "filegroup") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//tensorflow:tensorflow.bzl", @@ -27,9 +26,9 @@ package( licenses = ["notice"], # Apache 2.0 ) -filegroup( +alias( name = "base_api_def", - srcs = glob(["base_api/*"]), + actual = "//tensorflow/core/api_def/base_api:base_api_def", visibility = ["//tensorflow:internal"], ) diff --git a/tensorflow/core/api_def/base_api/BUILD b/tensorflow/core/api_def/base_api/BUILD new file mode 100644 index 00000000000..22cc342b1a1 --- /dev/null +++ b/tensorflow/core/api_def/base_api/BUILD @@ -0,0 +1,21 @@ +# Description: +# Expose TensorFlow base api. + +load("//tensorflow:tensorflow.bzl", "filegroup") + +package( + licenses = ["notice"], # Apache 2.0 +) + +filegroup( + name = "base_api_def", + srcs = glob( + [ + "*", + ], + exclude = [ + "BUILD", + ], + ), + visibility = ["//tensorflow:internal"], +) diff --git a/tensorflow/core/common_runtime/bfc_allocator.cc b/tensorflow/core/common_runtime/bfc_allocator.cc index 11f28655f05..06af68bdb65 100644 --- a/tensorflow/core/common_runtime/bfc_allocator.cc +++ b/tensorflow/core/common_runtime/bfc_allocator.cc @@ -125,7 +125,8 @@ bool BFCAllocator::Extend(size_t alignment, size_t rounded_bytes) { // Try allocating. size_t bytes = std::min(curr_region_allocation_bytes_, available_bytes); - void* mem_addr = sub_allocator_->Alloc(alignment, bytes); + size_t bytes_received; + void* mem_addr = sub_allocator_->Alloc(alignment, bytes, &bytes_received); if (mem_addr == nullptr && !started_backpedal_) { // Only backpedal once. 
started_backpedal_ = true; @@ -136,7 +137,7 @@ bool BFCAllocator::Extend(size_t alignment, size_t rounded_bytes) { while (mem_addr == nullptr) { bytes = RoundedBytes(bytes * kBackpedalFactor); if (bytes < rounded_bytes) break; - mem_addr = sub_allocator_->Alloc(alignment, bytes); + mem_addr = sub_allocator_->Alloc(alignment, bytes, &bytes_received); } } @@ -158,7 +159,7 @@ bool BFCAllocator::Extend(size_t alignment, size_t rounded_bytes) { VLOG(1) << "Allocated memory at " << mem_addr << " to " << static_cast<void*>(static_cast<char*>(mem_addr) + bytes); - region_manager_.AddAllocationRegion(mem_addr, bytes); + region_manager_.AddAllocationRegion(mem_addr, bytes_received); // Create one large chunk for the whole memory space that will // be chunked later. diff --git a/tensorflow/core/common_runtime/device/device_host_allocator.h b/tensorflow/core/common_runtime/device/device_host_allocator.h index 539dab18266..9d11705eadf 100644 --- a/tensorflow/core/common_runtime/device/device_host_allocator.h +++ b/tensorflow/core/common_runtime/device/device_host_allocator.h @@ -36,8 +36,10 @@ class DeviceHostAllocator : public SubAllocator { } ~DeviceHostAllocator() override {} - void* Alloc(size_t alignment, size_t num_bytes) override { + void* Alloc(size_t alignment, size_t num_bytes, + size_t* bytes_received) override { void* ptr = nullptr; + *bytes_received = num_bytes; if (num_bytes > 0) { ptr = stream_exec_->HostMemoryAllocate(num_bytes); if (ptr == nullptr) { diff --git a/tensorflow/core/common_runtime/device/device_mem_allocator.h b/tensorflow/core/common_runtime/device/device_mem_allocator.h index 52803102562..16a14c7114b 100644 --- a/tensorflow/core/common_runtime/device/device_mem_allocator.h +++ b/tensorflow/core/common_runtime/device/device_mem_allocator.h @@ -41,8 +41,10 @@ class DeviceMemAllocator : public SubAllocator { } ~DeviceMemAllocator() override {} - void* Alloc(size_t alignment, size_t num_bytes) override { + void* Alloc(size_t alignment, size_t num_bytes, + size_t* bytes_received) override { void* ptr = nullptr; + *bytes_received = num_bytes; if (num_bytes > 0) { if (use_unified_memory_) { ptr = stream_exec_->UnifiedMemoryAllocate(num_bytes); diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index a02ef9188af..1922cdf0937 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -90,6 +90,8 @@ EagerContext::EagerContext( log_device_placement_(opts.config.log_device_placement()), allow_soft_placement_(opts.config.allow_soft_placement()), num_active_steps_(0), + step_container_(std::make_unique<ScopedStepContainer>( + 0, [this](const string& name) { ClearResourceContainer(name); })), default_executor_(async), log_memory_(LogMemory::IsEnabled()), env_(opts.env), @@ -385,6 +387,11 @@ void EagerContext::ClearCachesAndDefaultExecutor() { for (auto& entry : registered_functions_) { entry.second->cached_kernel_keys->clear(); } + { + mutex_lock ml(metadata_mu_); + step_container_.reset(new ScopedStepContainer( + 0, [this](const string& name) { ClearResourceContainer(name); })); + } } void EagerContext::SetThreadLocalDevicePlacementPolicy( @@ -600,29 +607,20 @@ void EagerContext::ListDevices( void EagerContext::StartStep() { mutex_lock ml(metadata_mu_); num_active_steps_++; - if (step_container_ == nullptr) { - step_container_.reset( - new ScopedStepContainer(0, [this](const string& name) { - auto local_devices = local_device_mgr()->ListDevices(); - for (Device* 
device : local_devices) { - device->resource_manager()->Cleanup(name).IgnoreError(); - } - })); - } } void EagerContext::EndStep() { mutex_lock ml(metadata_mu_); num_active_steps_--; if (num_active_steps_ == 0) { - step_container_.reset(); + // TODO(b/139809335): This does not properly clean up remote resources + // Clean up the previous step container and create a new one. + step_container_.reset(new ScopedStepContainer( + 0, [this](const string& name) { ClearResourceContainer(name); })); } } ScopedStepContainer* EagerContext::StepContainer() { - if (num_active_steps_.load() == 0) { - return nullptr; - } mutex_lock ml(metadata_mu_); return step_container_.get(); } @@ -988,6 +986,15 @@ Status EagerContext::CPUDeviceOnTask(const Device* device, return FindDeviceFromName(cpu_device_name.c_str(), cpu_device); } +void EagerContext::ClearResourceContainer(const string& name) { + // TODO(b/139809335): This does not properly clean up remote resources + auto local_devices = local_device_mgr()->ListDevices(); + for (Device* device : local_devices) { + // Only ignore container not found errors. + device->resource_manager()->Cleanup(name).IgnoreError(); + } +} + namespace { Status GetTaskName(Device* d, string* task_name) { string ignored; diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index 1e9516d5a69..dc833a71650 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -506,6 +506,8 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { void ResetClusterFLR(DistributedFunctionLibraryRuntime* cluster_flr); + void ClearResourceContainer(const string& name); + template <typename T> struct OwnedOrUnownedHelper { public: diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc index d2327ff9592..3b73f018c0d 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc @@ -285,13 +285,7 @@ Status KernelAndDeviceOp::Run( params.runner = get_runner(); - params.step_container = - step_container == nullptr ? &step_container_ : step_container; - auto step_container_cleanup = gtl::MakeCleanup([step_container, this] { - if (step_container == nullptr) { - this->step_container_.CleanUp(); - } - }); + params.step_container = step_container; params.collective_executor = collective_executor_ ? collective_executor_->get() : nullptr; @@ -392,8 +386,7 @@ void KernelAndDeviceFunc::RunAsync( opts->cancellation_manager = local_cm.get(); } opts->allow_dead_tensors = true; - opts->step_container = - step_container == nullptr ? &step_container_ : step_container; + opts->step_container = step_container; opts->collective_executor = collective_executor_ ? 
collective_executor_->get() : nullptr; @@ -406,9 +399,6 @@ void KernelAndDeviceFunc::RunAsync( [opts, rendezvous, local_cm, step_container, this, done = std::move(done)](const Status& s) { rendezvous->Unref(); - if (step_container == nullptr) { - this->step_container_.CleanUp(); - } done(s); }); } diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.h b/tensorflow/core/common_runtime/eager/kernel_and_device.h index f48452aa46b..f0d018ee093 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.h +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.h @@ -201,10 +201,7 @@ class KernelAndDeviceOp final : public KernelAndDevice { : KernelAndDevice(flr, runner, std::move(collective_executor), host_cpu_device), rendezvous_(rendezvous), - log_memory_(log_memory), - step_container_(0, [this](const string& name) { - device_->resource_manager()->Cleanup(name).IgnoreError(); - }) {} + log_memory_(log_memory) {} ~KernelAndDeviceOp() override {} @@ -252,7 +249,6 @@ class KernelAndDeviceOp final : public KernelAndDevice { Rendezvous* const rendezvous_; checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_; const bool log_memory_; - ScopedStepContainer step_container_; }; // Represents a multi-device function. Functions can also be run using @@ -286,15 +282,7 @@ class KernelAndDeviceFunc : public KernelAndDevice { std::move(input_resource_dtypes_and_shapes)), name_(name), rendezvous_creator_(std::move(rendezvous_creator)), - get_op_id_(std::move(get_op_id)), - step_container_(0, [this](const string& name) { - // TODO(b/139809335): This does not properly clean up remote resources - const std::vector<Device*> devices = - pflr_->device_mgr()->ListDevices(); - for (Device* device : devices) { - device->resource_manager()->Cleanup(name).IgnoreError(); - } - }) {} + get_op_id_(std::move(get_op_id)) {} ~KernelAndDeviceFunc() override; @@ -362,8 +350,6 @@ class KernelAndDeviceFunc : public KernelAndDevice { std::function<Rendezvous*(const int64)> rendezvous_creator_; std::function<int64()> get_op_id_; - - ScopedStepContainer step_container_; }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/pool_allocator.cc b/tensorflow/core/common_runtime/pool_allocator.cc index 6b40fcc4c70..6c66069b275 100644 --- a/tensorflow/core/common_runtime/pool_allocator.cc +++ b/tensorflow/core/common_runtime/pool_allocator.cc @@ -127,8 +127,9 @@ void* PoolAllocator::AllocateRaw(size_t alignment, size_t num_bytes) { delete pr; return PrepareChunk(r, alignment, num_bytes); } else { - void* ptr = allocator_->Alloc(kPoolAlignment, num_bytes); - return PrepareChunk(ptr, alignment, num_bytes); + size_t bytes_received; + void* ptr = allocator_->Alloc(kPoolAlignment, num_bytes, &bytes_received); + return PrepareChunk(ptr, alignment, bytes_received); } } @@ -256,8 +257,10 @@ void PoolAllocator::EvictOne() { } } -void* BasicCPUAllocator::Alloc(size_t alignment, size_t num_bytes) { +void* BasicCPUAllocator::Alloc(size_t alignment, size_t num_bytes, + size_t* bytes_received) { void* ptr = nullptr; + *bytes_received = num_bytes; if (num_bytes > 0) { if (numa_node_ == port::kNUMANoAffinity) { ptr = port::AlignedMalloc(num_bytes, static_cast<int>(alignment)); diff --git a/tensorflow/core/common_runtime/pool_allocator.h b/tensorflow/core/common_runtime/pool_allocator.h index 7c896cd4261..da1a3796830 100644 --- a/tensorflow/core/common_runtime/pool_allocator.h +++ b/tensorflow/core/common_runtime/pool_allocator.h @@ -22,6 +22,7 @@ limitations under the License. 
#include <map> #include <memory> #include <vector> + #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/platform/logging.h" @@ -154,7 +155,8 @@ class BasicCPUAllocator : public SubAllocator { ~BasicCPUAllocator() override {} - void* Alloc(size_t alignment, size_t num_bytes) override; + void* Alloc(size_t alignment, size_t num_bytes, + size_t* bytes_received) override; void Free(void* ptr, size_t num_bytes) override; diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index 50f3b52e4c6..4a8f37eca1f 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -1426,6 +1426,10 @@ void ProcessFunctionLibraryRuntime::Run( InternalArgs* comp_args) -> Status { // "Index"s of _Arg nodes are unique when all arguments are local Tensors. for (const auto& it : comp_data.arg_indices) { + if (it.index >= args.size()) { + return errors::InvalidArgument( + "index ", it.index, " is out of range [0, ", args.size(), ")"); + } if (it.sub_index >= 0) { const Tensor& t = args[it.index]; if (t.dtype() != DT_RESOURCE) { diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD index d386dc6ec6b..eb93402b1a4 100644 --- a/tensorflow/core/data/service/BUILD +++ b/tensorflow/core/data/service/BUILD @@ -393,7 +393,7 @@ cc_library( hdrs = [ "test_util.h", ], - data = glob(["testdata/*.pbtxt"]), + data = ["//tensorflow/core/data/service/testdata"], deps = [ "//tensorflow/core:framework", "//tensorflow/core:lib", diff --git a/tensorflow/core/data/service/data_service.cc b/tensorflow/core/data/service/data_service.cc index a119331c56a..680edcce33a 100644 --- a/tensorflow/core/data/service/data_service.cc +++ b/tensorflow/core/data/service/data_service.cc @@ -148,34 +148,16 @@ Status DataServiceDispatcherClient::RegisterDataset(GraphDef dataset, return Status::OK(); } -Status DataServiceDispatcherClient::CreateJob(int64 dataset_id, - ProcessingMode processing_mode, - int64& job_client_id) { - TF_RETURN_IF_ERROR(EnsureInitialized()); - CreateJobRequest req; - req.set_dataset_id(dataset_id); - req.set_processing_mode(ProcessingModeDef(processing_mode)); - CreateJobResponse resp; - grpc::ClientContext client_ctx; - grpc::Status status = stub_->CreateJob(&client_ctx, req, &resp); - if (!status.ok()) { - return grpc_util::WrapError( - absl::StrCat("Failed to create job for dataset with id ", dataset_id), - status); - } - job_client_id = resp.job_client_id(); - return Status::OK(); -} - Status DataServiceDispatcherClient::GetOrCreateJob( int64 dataset_id, ProcessingMode processing_mode, - const std::string& job_name, int job_name_index, int64& job_client_id) { + const absl::optional<JobKey>& job_key, int64& job_client_id) { TF_RETURN_IF_ERROR(EnsureInitialized()); GetOrCreateJobRequest req; req.set_dataset_id(dataset_id); req.set_processing_mode(ProcessingModeDef(processing_mode)); - req.set_job_name(job_name); - req.set_job_name_index(job_name_index); + if (job_key.has_value()) { + *req.mutable_job_key() = job_key.value(); + } GetOrCreateJobResponse resp; grpc::ClientContext client_ctx; grpc::Status status = stub_->GetOrCreateJob(&client_ctx, req, &resp); diff --git a/tensorflow/core/data/service/data_service.h b/tensorflow/core/data/service/data_service.h index f10abae8acf..935e549b303 100644 --- a/tensorflow/core/data/service/data_service.h +++ 
b/tensorflow/core/data/service/data_service.h @@ -100,16 +100,11 @@ class DataServiceDispatcherClient : public DataServiceClientBase { // dataset id in `dataset_id`. Status RegisterDataset(GraphDef dataset, int64& dataset_id); - // Creates a new tf.data service job for the specified dataset. The id for the - // created job will be stored in `job_client_id`. - Status CreateJob(int64 dataset_id, ProcessingMode processing_mode, - int64& job_client_id); - // Gets the job id for the job represented by the tuple // (job_name, job_name_index), and stores the id in `job_client_id`. If the // job doesn't exist yet, it will be created. Status GetOrCreateJob(int64 dataset_id, ProcessingMode processing_mode, - const std::string& job_name, int job_name_index, + const absl::optional<JobKey>& job_key, int64& job_client_id); // Releases a job client id, indicating that the id will no longer be used to diff --git a/tensorflow/core/data/service/dispatcher.proto b/tensorflow/core/data/service/dispatcher.proto index be423b23eb5..3592efc48a8 100644 --- a/tensorflow/core/data/service/dispatcher.proto +++ b/tensorflow/core/data/service/dispatcher.proto @@ -57,29 +57,23 @@ message GetOrRegisterDatasetResponse { int64 dataset_id = 1; } -message CreateJobRequest { - // The id of the dataset to create a job for. - int64 dataset_id = 1; - // A mode controlling how the tf.data service produces data for the job. - ProcessingModeDef processing_mode = 2; -} - -message CreateJobResponse { - // An id for the client that will read from the job. When the client is done - // with the job, they should call ReleaseJobClient with this id. - int64 job_client_id = 1; +message JobKey { + // A name for the job. + string job_name = 1; + // An index for the job. Multiple jobs can be created for the same name, if + // they have different indices. + int64 job_name_index = 2; } message GetOrCreateJobRequest { + reserved 3, 4; // The id of the dataset to create a job for. int64 dataset_id = 1; // A mode controlling how the tf.data service produces data for the job. ProcessingModeDef processing_mode = 2; - // A name for the job. - string job_name = 3; - // An index for the job. Multiple jobs can be created for the same name, if - // they have different indices. - int64 job_name_index = 4; + // Optional job key identifying a shared job. If not set, the RPC will always + // create a new job. + JobKey job_key = 5; } message GetOrCreateJobResponse { @@ -144,9 +138,6 @@ service DispatcherService { // Gets a job if it already exists, otherwise creates it. rpc GetOrCreateJob(GetOrCreateJobRequest) returns (GetOrCreateJobResponse); - // Creates a job for reading from the tf.data service. - rpc CreateJob(CreateJobRequest) returns (CreateJobResponse); - // Releases a job client so that a job may eventually be cleaned up. 
rpc ReleaseJobClient(ReleaseJobClientRequest) returns (ReleaseJobClientResponse); diff --git a/tensorflow/core/data/service/dispatcher_impl.cc b/tensorflow/core/data/service/dispatcher_impl.cc index 5e865242fd1..1e0ee62d27c 100644 --- a/tensorflow/core/data/service/dispatcher_impl.cc +++ b/tensorflow/core/data/service/dispatcher_impl.cc @@ -404,55 +404,36 @@ Status DataServiceDispatcherImpl::RegisterDataset(uint64 fingerprint, return Apply(update); } -Status DataServiceDispatcherImpl::CreateJob(const CreateJobRequest* request, - CreateJobResponse* response) { - TF_RETURN_IF_ERROR(CheckStarted()); - VLOG(3) << "Received create job request for dataset id " - << request->dataset_id(); - ProcessingMode processing_mode = ProcessingMode(request->processing_mode()); - std::shared_ptr<const Job> job; - std::vector<std::shared_ptr<const Task>> tasks; - { - mutex_lock l(mu_); - TF_RETURN_IF_ERROR(CreateJob(request->dataset_id(), processing_mode, - absl::optional<NamedJobKey>(), job)); - int64 job_client_id; - TF_RETURN_IF_ERROR(AcquireJobClientId(job, job_client_id)); - response->set_job_client_id(job_client_id); - TF_RETURN_IF_ERROR(CreateTasksForJob(job, tasks)); - } - TF_RETURN_IF_ERROR(AssignTasks(tasks)); - - VLOG(3) << "Creating job " << job->job_id << " for dataset " - << request->dataset_id(); - return Status::OK(); -} - Status DataServiceDispatcherImpl::GetOrCreateJob( const GetOrCreateJobRequest* request, GetOrCreateJobResponse* response) { TF_RETURN_IF_ERROR(CheckStarted()); - VLOG(3) << "Received get or create job request for dataset id " - << request->dataset_id() << " with name " << request->job_name() - << " and index " << request->job_name_index(); - NamedJobKey key(request->job_name(), request->job_name_index()); + VLOG(3) << "GetOrCreateJob(" << request->DebugString() << ")"; + absl::optional<NamedJobKey> key; + if (request->has_job_key()) { + key.emplace(request->job_key().job_name(), + request->job_key().job_name_index()); + } ProcessingMode requested_processing_mode = ProcessingMode(request->processing_mode()); std::shared_ptr<const Job> job; std::vector<std::shared_ptr<const Task>> tasks; { mutex_lock l(mu_); - Status s = state_.NamedJobByKey(key, job); - if (s.ok()) { - TF_RETURN_IF_ERROR(ValidateMatchingJob(job, requested_processing_mode, - request->dataset_id())); - int64 job_client_id; - TF_RETURN_IF_ERROR(AcquireJobClientId(job, job_client_id)); - response->set_job_client_id(job_client_id); - VLOG(3) << "Found existing job for name=" << key.name - << ", index=" << key.index << ". job_id: " << job->job_id; - return Status::OK(); - } else if (!errors::IsNotFound(s)) { - return s; + if (key.has_value()) { + Status s = state_.NamedJobByKey(key.value(), job); + if (s.ok()) { + TF_RETURN_IF_ERROR(ValidateMatchingJob(job, requested_processing_mode, + request->dataset_id())); + int64 job_client_id; + TF_RETURN_IF_ERROR(AcquireJobClientId(job, job_client_id)); + response->set_job_client_id(job_client_id); + VLOG(3) << "Found existing job for name=" << key.value().name + << ", index=" << key.value().index + << ". 
job_id: " << job->job_id; + return Status::OK(); + } else if (!errors::IsNotFound(s)) { + return s; + } } TF_RETURN_IF_ERROR( CreateJob(request->dataset_id(), requested_processing_mode, key, job)); @@ -462,8 +443,8 @@ Status DataServiceDispatcherImpl::GetOrCreateJob( TF_RETURN_IF_ERROR(CreateTasksForJob(job, tasks)); } TF_RETURN_IF_ERROR(AssignTasks(tasks)); - VLOG(3) << "Created job " << job->job_id << " for dataset " - << request->dataset_id() << " and name " << request->job_name(); + VLOG(3) << "Created job " << job->job_id << " for CreateJob(" + << request->DebugString() << ")"; return Status::OK(); } diff --git a/tensorflow/core/data/service/dispatcher_impl.h b/tensorflow/core/data/service/dispatcher_impl.h index e8cd3954d59..31c8b874ef9 100644 --- a/tensorflow/core/data/service/dispatcher_impl.h +++ b/tensorflow/core/data/service/dispatcher_impl.h @@ -68,8 +68,6 @@ class DataServiceDispatcherImpl { /// Client-facing API. Status GetOrRegisterDataset(const GetOrRegisterDatasetRequest* request, GetOrRegisterDatasetResponse* response); - Status CreateJob(const CreateJobRequest* request, - CreateJobResponse* response); Status GetOrCreateJob(const GetOrCreateJobRequest* request, GetOrCreateJobResponse* response); Status ReleaseJobClient(const ReleaseJobClientRequest* request, diff --git a/tensorflow/core/data/service/grpc_dispatcher_impl.cc b/tensorflow/core/data/service/grpc_dispatcher_impl.cc index 456fdfcc583..5d8b2c79a3e 100644 --- a/tensorflow/core/data/service/grpc_dispatcher_impl.cc +++ b/tensorflow/core/data/service/grpc_dispatcher_impl.cc @@ -45,7 +45,6 @@ HANDLER(WorkerUpdate); HANDLER(GetDatasetDef); HANDLER(GetSplit); HANDLER(GetOrRegisterDataset); -HANDLER(CreateJob); HANDLER(ReleaseJobClient); HANDLER(GetOrCreateJob); HANDLER(GetTasks); diff --git a/tensorflow/core/data/service/grpc_dispatcher_impl.h b/tensorflow/core/data/service/grpc_dispatcher_impl.h index ec6ffbb2d3f..3faa7bc36de 100644 --- a/tensorflow/core/data/service/grpc_dispatcher_impl.h +++ b/tensorflow/core/data/service/grpc_dispatcher_impl.h @@ -44,7 +44,6 @@ class GrpcDispatcherImpl : public DispatcherService::Service { HANDLER(GetDatasetDef); HANDLER(GetSplit); HANDLER(GetOrRegisterDataset); - HANDLER(CreateJob); HANDLER(ReleaseJobClient); HANDLER(GetOrCreateJob); HANDLER(GetTasks); diff --git a/tensorflow/core/data/service/testdata/BUILD b/tensorflow/core/data/service/testdata/BUILD new file mode 100644 index 00000000000..b32f47b91eb --- /dev/null +++ b/tensorflow/core/data/service/testdata/BUILD @@ -0,0 +1,14 @@ +# Description: +# Data service test data. 
+ +load("//tensorflow:tensorflow.bzl", "filegroup") + +package( + licenses = ["notice"], # Apache 2.0 +) + +filegroup( + name = "testdata", + srcs = glob(["*.pbtxt"]), + visibility = ["//tensorflow/core/data/service:__pkg__"], +) diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD index d615b36dc42..dd675d7e8d3 100644 --- a/tensorflow/core/framework/BUILD +++ b/tensorflow/core/framework/BUILD @@ -1251,6 +1251,31 @@ cc_library( ], ) +cc_library( + name = "extension_type_variant", + srcs = ["extension_type_variant.cc"], + hdrs = ["extension_type_variant.h"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/types:span", + ], +) + +tf_cc_test( + name = "extension_type_variant_test", + size = "small", + srcs = ["extension_type_variant_test.cc"], + deps = [ + ":extension_type_variant", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + # All framewrok protos are self-contained, i.e. they only import other # protos from the same package, so we can build the protos here and then # link them from core:protos_all without circular dependencies. diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h index f7402f7b293..2899a59c548 100644 --- a/tensorflow/core/framework/allocator.h +++ b/tensorflow/core/framework/allocator.h @@ -439,7 +439,11 @@ class SubAllocator { const std::vector<Visitor>& free_visitors); virtual ~SubAllocator() {} - virtual void* Alloc(size_t alignment, size_t num_bytes) = 0; + // Allocates at least num_bytes. Returns actual number of bytes allocated in + // bytes_received. The caller can safely use the full bytes_received sized + // buffer following the returend pointer. + virtual void* Alloc(size_t alignment, size_t num_bytes, + size_t* bytes_received) = 0; virtual void Free(void* ptr, size_t num_bytes) = 0; protected: diff --git a/tensorflow/core/framework/cpu_allocator_impl.cc b/tensorflow/core/framework/cpu_allocator_impl.cc index 511cfce8ab5..567454a3295 100644 --- a/tensorflow/core/framework/cpu_allocator_impl.cc +++ b/tensorflow/core/framework/cpu_allocator_impl.cc @@ -156,7 +156,9 @@ class CPUAllocatorFactory : public AllocatorFactory { explicit CPUSubAllocator(CPUAllocator* cpu_allocator) : SubAllocator({}, {}), cpu_allocator_(cpu_allocator) {} - void* Alloc(size_t alignment, size_t num_bytes) override { + void* Alloc(size_t alignment, size_t num_bytes, + size_t* bytes_received) override { + *bytes_received = num_bytes; return cpu_allocator_->AllocateRaw(alignment, num_bytes); } diff --git a/tensorflow/core/framework/extension_type_variant.cc b/tensorflow/core/framework/extension_type_variant.cc new file mode 100644 index 00000000000..bd6a6f9d42f --- /dev/null +++ b/tensorflow/core/framework/extension_type_variant.cc @@ -0,0 +1,63 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/extension_type_variant.h" + +#include "tensorflow/core/framework/variant_op_registry.h" +#include "tensorflow/core/platform/errors.h" + +namespace tensorflow { + +constexpr const char ExtensionTypeVariant::kTypeName[]; + +void ExtensionTypeVariant::Encode(VariantTensorData* data) const { + data->set_type_name(TypeName()); + metadata_.type_spec_proto().SerializeToString(&data->metadata_string()); + for (const Tensor& tensor : flat_components_) { + data->add_tensor(tensor); + } +} + +bool ExtensionTypeVariant::Decode(const VariantTensorData& data) { + if (!metadata_.mutable_type_spec_proto()->ParseFromString( + data.metadata_string())) { + return false; + } + flat_components_ = data.tensors(); + return true; +} + +string ExtensionTypeVariant::DebugString() const { + string type_spec; + ::tensorflow::protobuf::TextFormat::Printer printer; + printer.SetSingleLineMode(true); + printer.PrintToString(metadata_.type_spec_proto(), &type_spec); + string result("<ExtensionTypeVariant type_spec={"); + result.append(type_spec.empty() ? "none" : type_spec); + result.append("}, components=["); + for (const auto& tensor : flat_components_) { + if (&tensor != &flat_components_[0]) { + result.append(", "); + } + result.append(tensor.DebugString()); + } + result.append("]>"); + return result; +} + +REGISTER_UNARY_VARIANT_DECODE_FUNCTION(ExtensionTypeVariant, + ExtensionTypeVariant::kTypeName); + +} // namespace tensorflow diff --git a/tensorflow/core/framework/extension_type_variant.h b/tensorflow/core/framework/extension_type_variant.h new file mode 100644 index 00000000000..d50abb5ffad --- /dev/null +++ b/tensorflow/core/framework/extension_type_variant.h @@ -0,0 +1,96 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_EXTENSION_TYPE_VARIANT_H_ +#define TENSORFLOW_CORE_FRAMEWORK_EXTENSION_TYPE_VARIANT_H_ + +#include <vector> + +#include "absl/types/span.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/variant_tensor_data.h" +#include "tensorflow/core/protobuf/extension_type_variant.pb.h" +#include "tensorflow/core/protobuf/struct.pb.h" + +namespace tensorflow { + +// Encoding for a `tf.ExtensionType` value that can be saved as a Variant. +// +// `tf.ExtensionType` (also known as `CompositeTensor`) is a Python base class +// used to define Python types that are supported by TensorFlow APIs. Example +// ExtensionTypes include `tf.RaggedTensor` and `tf.SparseTensor`. +// +// `ExtensionTypeVariant` decomposes the `ExtensionType` value into two +// parts: +// +// * `components`: A list of Tensors, which encodes the value's dynamic +// data -- i.e., data that may change for different executions of a graph.
+// * `type_spec_proto`: A serialized TypeSpec, which encodes the value's +// static data -- i.e., data that is the same for all executions of a graph. +// +// ExtensionTypeVariant can be stored in a Tensor with dtype=DT_VARIANT. +// Typically, extension type values are encoded with a scalar tensor containing +// a single ExtensionTypeVariant value. +class ExtensionTypeVariant { + public: + ExtensionTypeVariant(const TypeSpecProto& type_spec_proto, + absl::Span<Tensor> flat_components) + : flat_components_(flat_components.begin(), flat_components.end()) { + *metadata_.mutable_type_spec_proto() = type_spec_proto; + } + + // This type is default-constructible, copyable, assignable, and movable. + ExtensionTypeVariant() = default; + ExtensionTypeVariant(const ExtensionTypeVariant& other) = default; + ExtensionTypeVariant& operator=(ExtensionTypeVariant&& other) = default; + ExtensionTypeVariant& operator=(const ExtensionTypeVariant& other) = default; + + // Returns the list of Tensor components that encode this value's dynamic + // data. + absl::Span<const Tensor> flat_components() const { + return absl::MakeConstSpan(flat_components_); + } + + // Returns the serialized TypeSpec that encodes the value's static data. + TypeSpecProto type_spec_proto() const { return metadata_.type_spec_proto(); } + + // Variant methods. + string TypeName() const { return kTypeName; } + + // Updates `VariantTensorData` with an encoding for this value. + void Encode(VariantTensorData* data) const; + + // Updates this value to match the encoding in a given `VariantTensorData`. + bool Decode(const VariantTensorData& data); + + // Returns a string summary for this value. + string DebugString() const; + + // Name of this type (used for variant serialization). + static constexpr const char kTypeName[] = "ExtensionTypeVariant"; + + private: + // Tensor components for this value. + std::vector<Tensor> flat_components_; + + // TypeSpec for this value. ExtensionTypeVariantMetadata is a thin wrapper + // around a TypeSpecProto, which is used to retain flexibility to change the + // variant encoding. + ExtensionTypeVariantMetadata metadata_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_EXTENSION_TYPE_VARIANT_H_ diff --git a/tensorflow/core/framework/extension_type_variant_test.cc b/tensorflow/core/framework/extension_type_variant_test.cc new file mode 100644 index 00000000000..cd30b320bad --- /dev/null +++ b/tensorflow/core/framework/extension_type_variant_test.cc @@ -0,0 +1,99 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/extension_type_variant.h" + +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +// TypeSpecProto for a 2D Ragged Tensor. +constexpr const char* k2DRaggedTensorSpec = R"( +type_spec_class: RAGGED_TENSOR_SPEC +type_state: { + tuple_value: { + values: [ + {tensor_shape_value: {dim: [{size: -1}, {size: -1}]}}, # shape + {tensor_dtype_value: DT_INT32}, # dtype + {int64_value: 1}, # ragged_rank + {tensor_dtype_value: DT_INT64} # row_splits_dtype + ] + } +} +)"; + +// Returns an ExtensionTypeVariant encoding for a 2D ragged tensor with +// the specified values and row_splits. +ExtensionTypeVariant Make2DRaggedTensor(const std::vector<int32>& values, + const std::vector<int64>& splits) { + TypeSpecProto type_spec; + EXPECT_TRUE( + protobuf::TextFormat::ParseFromString(k2DRaggedTensorSpec, &type_spec)); + std::vector<Tensor> components; + components.push_back(test::AsTensor<int32>(values)); + components.push_back(test::AsTensor<int64>(splits)); + ExtensionTypeVariant v(type_spec, absl::MakeSpan(components)); + return v; +} + +TEST(ExtensionTypeVariantTest, EncodeAndDecodeRagged) { + ExtensionTypeVariant v = Make2DRaggedTensor( + /* values = */ {5, 5, 3, 4, 1, 8}, + /* splits = */ {0, 2, 3, 6}); + Tensor t(DT_VARIANT, {}); + + t.flat<Variant>()(0) = v; // Encode to variant. + auto* decoded = t.flat<Variant>()(0).get<ExtensionTypeVariant>(); + + EXPECT_EQ(v.type_spec_proto().SerializeAsString(), + decoded->type_spec_proto().SerializeAsString()); + EXPECT_EQ(v.flat_components().size(), 2); + test::ExpectTensorEqual<int32>(v.flat_components()[0], + decoded->flat_components()[0]); + test::ExpectTensorEqual<int64>(v.flat_components()[1], + decoded->flat_components()[1]); +} + +TEST(ExtensionTypeVariantTest, DebugStringForDefaultConstructed) { + ExtensionTypeVariant v; + EXPECT_EQ(v.DebugString(), + "<ExtensionTypeVariant type_spec={none}, components=[]>"); +} + +TEST(ExtensionTypeVariantTest, DebugStringForRagged) { + ExtensionTypeVariant v = Make2DRaggedTensor( + /* values = */ {5, 5, 3, 4, 1}, + /* splits = */ {0, 2, 3, 5}); + EXPECT_EQ(v.DebugString(), + "<ExtensionTypeVariant type_spec={type_spec_class: " + "RAGGED_TENSOR_SPEC type_state { tuple_value { values { " + "tensor_shape_value { dim { size: -1 } dim { size: -1 } } } " + "values { tensor_dtype_value: DT_INT32 } values " + "{ int64_value: 1 } values { tensor_dtype_value: DT_INT64 } } } }, " + "components=[Tensor<type: int32 shape: [5] values: 5 5 3...>, " + "Tensor<type: int64 shape: [4] values: 0 2 3...>]>"); +} + +TEST(ExtensionTypeVariantTest, TypeName) { + ExtensionTypeVariant v; + EXPECT_EQ(v.TypeName(), "ExtensionTypeVariant"); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h index 961e6df005b..6061d9737ec 100644 --- a/tensorflow/core/framework/resource_mgr.h +++ b/tensorflow/core/framework/resource_mgr.h @@ -94,14 +94,16 @@ class ScopedStepContainer { // prefix: optional string prefix to disambiguate step containers. 
ScopedStepContainer(const int64 step_id, std::function<void(const string&)> cleanup) - : container_(strings::StrCat("__per_step_", step_id)), + : step_id_(step_id), + container_(strings::StrCat("__per_step_", step_id)), cleanup_(cleanup), dirty_(false) {} ScopedStepContainer(const int64 step_id, std::function<void(const string&)> cleanup, const std::string& prefix) - : container_(strings::StrCat("__", prefix, "_per_step_", step_id)), + : step_id_(step_id), + container_(strings::StrCat("__", prefix, "_per_step_", step_id)), cleanup_(cleanup), dirty_(false) {} @@ -141,8 +143,10 @@ class ScopedStepContainer { template <typename T> Status LookupOrCreate(ResourceMgr* rm, const std::string& name, T** resource, std::function<Status(T**)> creator) TF_MUST_USE_RESULT; + int64 StepId() const { return step_id_; } private: + const int64 step_id_; const std::string container_; const std::function<void(const string&)> cleanup_; mutex mu_; diff --git a/tensorflow/core/framework/types.proto b/tensorflow/core/framework/types.proto index 01b598591e4..e5f33036dcd 100644 --- a/tensorflow/core/framework/types.proto +++ b/tensorflow/core/framework/types.proto @@ -84,4 +84,6 @@ enum SpecializedType { ST_INVALID = 0; // "tensorflow::TensorList" in the variant type registry. ST_TENSOR_LIST = 1; + // "tensorflow::data::Optional" in the variant type registry. + ST_OPTIONAL = 2; } diff --git a/tensorflow/core/grappler/clusters/BUILD b/tensorflow/core/grappler/clusters/BUILD index 8ac60a0916e..c49e3937cb9 100644 --- a/tensorflow/core/grappler/clusters/BUILD +++ b/tensorflow/core/grappler/clusters/BUILD @@ -149,6 +149,7 @@ tf_cc_test( tags = [ "no_cuda_on_cpu_tap", "no_gpu", + "no_mac", # b/174055645 "nomsan", # TODO(b/160921160): broken by NOAUTOROLLBACK CL ], deps = [ diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD index 79b3405712f..8bda3f60913 100644 --- a/tensorflow/core/grappler/costs/BUILD +++ b/tensorflow/core/grappler/costs/BUILD @@ -16,12 +16,9 @@ package( licenses = ["notice"], # Apache 2.0 ) -filegroup( +alias( name = "graph_properties_testdata", - srcs = glob([ - "graph_properties_testdata/*.pbtxt", - "graph_properties_testdata/*.pbtxt.html", - ]), + actual = "//tensorflow/core/grappler/costs/graph_properties_testdata:graph_properties_testdata", visibility = ["//visibility:public"], ) diff --git a/tensorflow/core/grappler/costs/graph_properties_testdata/BUILD b/tensorflow/core/grappler/costs/graph_properties_testdata/BUILD new file mode 100644 index 00000000000..3ba0979ca6c --- /dev/null +++ b/tensorflow/core/grappler/costs/graph_properties_testdata/BUILD @@ -0,0 +1,17 @@ +# Description: +# Graph properties test data. 
+ +load("//tensorflow:tensorflow.bzl", "filegroup") + +package( + licenses = ["notice"], # Apache 2.0 +) + +filegroup( + name = "graph_properties_testdata", + srcs = glob([ + "*.pbtxt", + "*.pbtxt.html", + ]), + visibility = ["//visibility:public"], +) diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index 14cfe09944f..b7d1d7eb597 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -41,6 +41,7 @@ const char kAttrSrcDevice[] = "send_device"; const char kAttrDstDevice[] = "recv_device"; const char kAttrTensorName[] = "tensor_name"; const char kChannelDevice[] = "Channel"; +const char kStreaming[] = "_streaming"; namespace { @@ -109,14 +110,10 @@ void UpdateDeviceAnnotationState(const NodeDef* node, (execution_count > 1 && node->attr().count(kOutputSame) == 0) ? 1 : 0; } -bool IsStreamingNode(const NodeDef& node) { - return node.attr().contains("_streaming"); -} - bool IsStreamingPort(const NodeDef& node, const int port) { - if (!IsStreamingNode(node)) return false; + if (!node.attr().contains(kStreaming)) return false; - auto& attr_list = node.attr().at("_streaming").list(); + auto& attr_list = node.attr().at(kStreaming).list(); bool is_streaming_port = false; if (port >= 0 && port < attr_list.b().size()) { is_streaming_port = attr_list.b(port); @@ -662,10 +659,12 @@ std::pair<const NodeDef*, const NodeDef*> SchedulerState::CreateSendRecv( auto input_node_port_num = NodePosition(input_name); string src_name; + bool control_input = false; if (input_node_port_num >= 0) { src_name = absl::StrCat(from->name(), "_", input_node_port_num); } else { src_name = absl::StrCat(from->name(), "_minus1"); + control_input = true; } // _Send op. @@ -695,16 +694,24 @@ std::pair<const NodeDef*, const NodeDef*> SchedulerState::CreateSendRecv( recv->add_input(send->name()); recv->set_device(DeviceName(to)); auto& recv_attr = *(recv->mutable_attr()); - if (from->attr().contains("_streaming")) { - *(recv_attr["_streaming"].mutable_list()) = - from->attr().at("_streaming").list(); - } recv_attr[kAttrInputSrc].set_s(input_name); if (input_node->attr().count(kAttrTensorName)) { recv_attr[kAttrTensorName].set_s( input_node->attr().at(kAttrTensorName).s()); } + // Propagate the streaming attribute to the send/recv nodes. + if (from->attr().contains(kStreaming) && !control_input) { + if (input_node_port_num >= from->attr().at(kStreaming).list().b_size()) { + LOG(ERROR) + << from->name() + << " port index larger than length of _streaming attribute list."; + } else if (from->attr().at(kStreaming).list().b(input_node_port_num)) { + send_attr[kStreaming].mutable_list()->add_b(true); + recv_attr[kStreaming].mutable_list()->add_b(true); + } + } + // NodeState for _Send op. auto& send_node_state = GetNodeStateOrCreateIt(send); send_node_state.device_name = send->device(); // Set Channel device. 
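The propagation above keys off the `kStreaming` ("_streaming") attribute on the producer node. A minimal sketch of how a node advertises a streaming output port, mirroring the scheduler test updated below (illustrative only; assumes the NodeDef proto and the kStreaming constant from virtual_scheduler.h are in scope):

  NodeDef node;
  node.set_name("z");
  node.set_op("Identity");
  // Mark output port 0 as streaming. CreateSendRecv copies this bit onto the
  // generated _Send/_Recv pair for data (non-control) inputs whose port index
  // falls within the attribute's bool list.
  (*node.mutable_attr())[kStreaming].mutable_list()->add_b(true);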
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h index 04f1e571ae5..308659aeac6 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.h +++ b/tensorflow/core/grappler/costs/virtual_scheduler.h @@ -37,6 +37,7 @@ ABSL_CONST_INIT extern const char kAttrSrcDevice[]; ABSL_CONST_INIT extern const char kAttrDstDevice[]; ABSL_CONST_INIT extern const char kAttrTensorName[]; ABSL_CONST_INIT extern const char kChannelDevice[]; +ABSL_CONST_INIT extern const char kStreaming[]; struct NodeState { // A node (i.e., an op) takes a set of input:port pairs and produces @@ -438,7 +439,7 @@ class SchedulerState { // Auxiliary data structures for constructing NodeState and DeviceState. std::unique_ptr<GraphProperties> graph_properties_; // Initialized in Init(). Cluster* cluster_; // Not owned. - const GrapplerItem* grappler_item_; // Not owned. + const GrapplerItem* grappler_item_; // Not owned. bool use_static_shapes_; bool initialized_; bool track_mem_usage_snapshot_; diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc index 6e502490796..2da15512181 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc @@ -2547,7 +2547,7 @@ TEST_F(VirtualSchedulerTest, MemoryUsageForStreamingOps) { node.set_device(kCPU1); } if (node.name() == "z" || node.name() == "w") - (*node.mutable_attr())["_streaming"].mutable_list()->add_b(true); + (*node.mutable_attr())[kStreaming].mutable_list()->add_b(true); } InitScheduler(); diff --git a/tensorflow/core/grappler/optimizers/data/auto_shard.cc b/tensorflow/core/grappler/optimizers/data/auto_shard.cc index 1288f9695b9..ae315e97ccb 100644 --- a/tensorflow/core/grappler/optimizers/data/auto_shard.cc +++ b/tensorflow/core/grappler/optimizers/data/auto_shard.cc @@ -71,7 +71,7 @@ constexpr std::array<const char*, 2> kMultipleInputsDatasetOps = { "ZipDataset" }; -constexpr std::array<const char*, 26> kPassThroughOps = { +constexpr std::array<const char*, 28> kPassThroughOps = { "_Retval", "AssertNextDataset", "BatchDataset", @@ -83,12 +83,14 @@ constexpr std::array<const char*, 26> kPassThroughOps = { "Identity", "MapAndBatchDataset", "MapDataset", + "MaxIntraOpParallelismDataset", "ModelDataset", "OptimizeDataset", "PaddedBatchDataset", "ParallelMapDataset", "ParseExampleDataset", "PrefetchDataset", + "PrivateThreadPoolDataset", "ReduceDataset", "RebatchDataset", "RepeatDataset", diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 36fee01c034..62488211e7d 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -83,6 +83,7 @@ package_group( packages = [ "//tensorflow/...", "//tensorflow_text/...", + "//third_party/google_research/google_research/tf3d/...", ], ) @@ -635,9 +636,12 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/common_runtime:device_mgr", + "//tensorflow/core/kernels/batching_util:adaptive_shared_batch_scheduler", "//tensorflow/core/kernels/batching_util:batch_resource_base", "//tensorflow/core/kernels/batching_util:concat_split_util", "//tensorflow/core/kernels/batching_util:periodic_function_dynamic", + "//tensorflow/core/platform:numbers", "@com_google_absl//absl/strings", ], alwayslink = 1, diff --git a/tensorflow/core/kernels/batch_kernels.cc b/tensorflow/core/kernels/batch_kernels.cc index 
5f742c37f35..4cbd214e04f 100644 --- a/tensorflow/core/kernels/batch_kernels.cc +++ b/tensorflow/core/kernels/batch_kernels.cc @@ -14,12 +14,15 @@ limitations under the License. ==============================================================================*/ #include "absl/strings/str_cat.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/framework/device.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_util.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h" #include "tensorflow/core/kernels/batching_util/batch_resource_base.h" #include "tensorflow/core/kernels/batching_util/concat_split_util.h" #include "tensorflow/core/kernels/batching_util/periodic_function.h" @@ -29,8 +32,15 @@ limitations under the License. #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/numbers.h" namespace tensorflow { +namespace { +constexpr int64 kMinInflightBatchesLimit = 16; +constexpr double kInitialInflightBatchesLimit = 64; +constexpr int64 kBatchesToAverageOver = 10; +constexpr int64 kMaxInflightBatchesLimit = 128; +} // namespace auto* batch_op_split_usage = monitoring::Gauge<string, 1>::New( "/tensorflow/serving/batching/enable_large_batch_splitting", @@ -52,6 +62,14 @@ void RecordBatchSplitUsage( } } +void RecordBatchParamNumBatchThreads(int64 num_batch_threads, + const string& model_name) { + static auto* cell = monitoring::Gauge<int64, 1>::New( + "/tensorflow/serving/batching/num_batch_threads", + "Tracks the number of batch threads of a model.", "model_name"); + cell->GetCell(model_name)->Set(num_batch_threads); +} + const string& GetModelName(OpKernelContext* ctx) { static string* kModelNameUnset = new string("model_name_unset"); if (!ctx->session_metadata()) return *kModelNameUnset; @@ -86,6 +104,24 @@ class BatchResource : public serving::BatchResourceBase { return Status::OK(); } + static Status Create( + AdaptiveBatcherT::Options adaptive_shared_batch_scheduler_options, + int32 max_batch_size, int32 batch_timeout_micros, + int32 max_enqueued_batches, const std::vector<int32>& allowed_batch_sizes, + FunctionLibraryRuntime::Handle fhandle, + std::unique_ptr<BatchResource>* resource) { + std::shared_ptr<AdaptiveBatcherT> batcher; + TF_RETURN_IF_ERROR(AdaptiveBatcherT::Create( + adaptive_shared_batch_scheduler_options, &batcher)); + + resource->reset(new BatchResource( + fhandle, std::move(batcher), + GetAdaptiveBatcherQueueOptions(max_batch_size, batch_timeout_micros, + max_enqueued_batches, true), + allowed_batch_sizes)); + return Status::OK(); + } + string DebugString() const final { return "BatchResource"; } private: @@ -99,6 +135,16 @@ class BatchResource : public serving::BatchResourceBase { std::move(allowed_batch_sizes)), fhandle_(fhandle) {} + BatchResource(FunctionLibraryRuntime::Handle fhandle, + std::shared_ptr<AdaptiveBatcherT> batcher, + const AdaptiveBatcherT::QueueOptions& batcher_queue_options, + std::vector<int32> allowed_batch_sizes) + : BatchResourceBase( + /*has_process_batch_function=*/fhandle != kInvalidHandle, + std::move(batcher), batcher_queue_options, + std::move(allowed_batch_sizes)), + fhandle_(fhandle) {} + void ProcessFuncBatchImpl( const 
BatchTask& last_task, absl::Span<const Tensor> inputs, std::vector<Tensor>* combined_outputs, @@ -134,11 +180,6 @@ class BatchFunctionKernel : public AsyncOpKernel { explicit BatchFunctionKernel(OpKernelConstruction* c) : AsyncOpKernel(c) { OP_REQUIRES_OK(c, c->GetAttr("container", &container_)); OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_)); - // If shared_name is not supplied, use name instead (prevent collisions by - // default). - if (shared_name_.empty()) { - shared_name_ = name(); - } OP_REQUIRES_OK(c, c->GetAttr("batching_queue", &batcher_queue_)); OP_REQUIRES_OK(c, c->GetAttr("num_batch_threads", &num_batch_threads_)); OP_REQUIRES_OK(c, c->GetAttr("max_batch_size", &max_batch_size_)); @@ -148,12 +189,30 @@ class BatchFunctionKernel : public AsyncOpKernel { c->GetAttr("max_enqueued_batches", &max_enqueued_batches_)); OP_REQUIRES_OK(c, c->GetAttr("allowed_batch_sizes", &allowed_batch_sizes_)); - auto lib = c->function_library(); - OP_REQUIRES(c, lib != nullptr, errors::Internal("No function library")); - NameAttrList func; - OP_REQUIRES_OK(c, c->GetAttr("f", &func)); - OP_REQUIRES_OK( - c, lib->Instantiate(func.name(), AttrSlice(&func.attr()), &fhandle_)); + OP_REQUIRES_OK(c, c->GetAttr("f", &func_)); + if (num_batch_threads_ <= 0) { + adaptive_batch_scheduler_options_ = + absl::make_optional(AdaptiveBatchSchedulerOptions{ + kMinInflightBatchesLimit, kInitialInflightBatchesLimit, + kBatchesToAverageOver}); + + // Use a shared shared pool across all models if adaptive shared batch + // scheduler is used. + // `shared_name_` and `container_` is used to look up an instantiated + // scheduler instance in `ComputeAsync`. + container_ = "__adapative_container"; + shared_name_ = "__adaptive_global_shared_thread_pool"; + // Use name to prevent collisions by default. + if (batcher_queue_.empty()) { + batcher_queue_ = name(); + } + } + + if (shared_name_.empty()) { + // If shared_name is not supplied, use name instead (prevent collisions by + // default). + shared_name_ = name(); + } if (c->HasAttr("enable_large_batch_splitting")) { OP_REQUIRES_OK(c, c->GetAttr("enable_large_batch_splitting", @@ -175,16 +234,63 @@ class BatchFunctionKernel : public AsyncOpKernel { ? absl::make_optional(enable_large_batch_splitting_) : absl::nullopt, GetModelName(c)); + // TODO(b/173255290): Add num_batch_threads_ parameter to TFRT batch kernel. + RecordBatchParamNumBatchThreads(num_batch_threads_, GetModelName(c)); + + std::function<Status(BatchResource**)> creator; + + FunctionLibraryRuntime::Handle handle; + OP_REQUIRES_OK_ASYNC(c, GetOrCreateFunctionHandle(c, &handle), done); + + if (adaptive_batch_scheduler_options_ != absl::nullopt) { + creator = [this, handle](BatchResource** r) { + serving::AdaptiveSharedBatchScheduler< + serving::BatchResourceBase::BatchTask>::Options + adaptive_shared_batch_scheduler_options; + adaptive_shared_batch_scheduler_options.thread_pool_name = + "adaptive_batch_threads"; + adaptive_shared_batch_scheduler_options.num_batch_threads = + kMaxInflightBatchesLimit; + // adaptive_shared_batch_scheduler_options.full_batch_scheduling_boost_micros + // is 0 (default value) intentionally, so tasks are scheduled in a FIFO + // way. + // Two rationales to use default value (zero) for + // `full_batch_scheduling_boost_micros` + // 1) In this way, tasks scheduling policy is FIFO. Compared with round + // robin (what shared batch scheduler does), FIFO ensures that model + // with low QPS (i.e., models enqueue fewer tasks in the shared queue) + // will be processed timely. 
+ // 2) If set, `full_batch_scheduling_boost_micros` should be of order + // the batch processing latency (which varies on a model basis). + // If a non-zero value is not set properly, it harms tail latency. + adaptive_shared_batch_scheduler_options.min_in_flight_batches_limit = + adaptive_batch_scheduler_options_->min_in_flight_batches_limit; + adaptive_shared_batch_scheduler_options + .initial_in_flight_batches_limit = + adaptive_batch_scheduler_options_->initial_in_flight_batches_limit; + adaptive_shared_batch_scheduler_options.batches_to_average_over = + adaptive_batch_scheduler_options_->batches_to_average_over; + std::unique_ptr<BatchResource> new_resource; + TF_RETURN_IF_ERROR(BatchResource::Create( + adaptive_shared_batch_scheduler_options, max_batch_size_, + batch_timeout_micros_, max_enqueued_batches_, allowed_batch_sizes_, + handle, &new_resource)); + *r = new_resource.release(); + return Status::OK(); + }; + } else { + creator = [this, handle](BatchResource** r) { + std::unique_ptr<BatchResource> new_resource; + TF_RETURN_IF_ERROR(BatchResource::Create( + num_batch_threads_, max_batch_size_, batch_timeout_micros_, + max_enqueued_batches_, allowed_batch_sizes_, handle, + enable_large_batch_splitting_, &new_resource)); + *r = new_resource.release(); + return Status::OK(); + }; + } + BatchResource* br; - std::function<Status(BatchResource**)> creator = [this](BatchResource** r) { - std::unique_ptr<BatchResource> new_resource; - TF_RETURN_IF_ERROR(BatchResource::Create( - num_batch_threads_, max_batch_size_, batch_timeout_micros_, - max_enqueued_batches_, allowed_batch_sizes_, fhandle_, - enable_large_batch_splitting_, &new_resource)); - *r = new_resource.release(); - return Status::OK(); - }; OP_REQUIRES_OK_ASYNC(c, c->resource_manager()->LookupOrCreate( container_, shared_name_, &br, creator), @@ -196,6 +302,75 @@ class BatchFunctionKernel : public AsyncOpKernel { // Assume br calls done, so nothing to do here. } + Status InstantiateFunction(OpKernelContext* c, + FunctionLibraryRuntime::Handle* handle) const { + // TODO(b/173748062): Merge this instantiation logic with PartitionedCall. + FunctionLibraryRuntime* lib = c->function_library(); + if (!lib) { + return errors::Internal("No function library"); + } + + FunctionLibraryRuntime::InstantiateOptions opts; + opts.target = lib->device() == nullptr ? 
"" : lib->device()->name(); + opts.is_multi_device_function = true; + + Device* cpu_device; + TF_RETURN_IF_ERROR(lib->device_mgr()->LookupDevice("CPU:0", &cpu_device)); + + const FunctionDef* fdef = + lib->GetFunctionLibraryDefinition()->Find(func_.name()); + if (!fdef) { + return errors::NotFound("Failed to find definition for function \"", + func_.name(), "\""); + } + OpInputList in_tensors; + TF_RETURN_IF_ERROR(c->input_list("in_tensors", &in_tensors)); + for (int i = 0; i < in_tensors.size(); i++) { + if (in_tensors[i].dtype() == DT_RESOURCE) { + return errors::InvalidArgument( + "BatchFunction cannot take resource inputs but input ", i, + " is a resource."); + } else { + // Currently, inputs are on CPU since they are concatenated on CPU + opts.input_devices.push_back(cpu_device->name()); + } + } + OpInputList captured_tensors; + TF_RETURN_IF_ERROR(c->input_list("captured_tensors", &captured_tensors)); + for (const Tensor& t : captured_tensors) { + if (t.dtype() == DT_RESOURCE) { + const ResourceHandle& rhandle = t.flat<ResourceHandle>()(0); + opts.input_devices.push_back(rhandle.device()); + } else { + opts.input_devices.push_back(cpu_device->name()); + } + } + const OpDef& signature = fdef->signature(); + for (int i = 0; i < signature.output_arg_size(); i++) { + // Currently, outputs must be on CPU since they are split on CPU. + opts.output_devices.push_back(cpu_device->name()); + } + if (opts.input_devices.size() != signature.input_arg_size()) { + return errors::InvalidArgument( + "Function takes ", signature.input_arg_size(), " argument(s) but ", + opts.input_devices.size(), " argument(s) were passed"); + } + return lib->Instantiate(func_.name(), AttrSlice(&func_.attr()), opts, + handle); + } + + Status GetOrCreateFunctionHandle(OpKernelContext* c, + FunctionLibraryRuntime::Handle* handle) { + mutex_lock ml(mu_); + if (!fhandle_) { + TF_RETURN_IF_ERROR(InstantiateFunction(c, handle)); + fhandle_ = *handle; + } else { + *handle = fhandle_.value(); + } + return Status::OK(); + } + // Validates 'allowed_batch_sizes_'. The entries must increase monotonically, // and the last one must equal 'max_batch_size_'. Status ValidateAllowedBatchSizes() const { @@ -231,13 +406,34 @@ class BatchFunctionKernel : public AsyncOpKernel { int32 batch_timeout_micros_; int32 max_enqueued_batches_; std::vector<int32> allowed_batch_sizes_; - FunctionLibraryRuntime::Handle fhandle_; + NameAttrList func_; + absl::optional<FunctionLibraryRuntime::Handle> fhandle_ TF_GUARDED_BY(mu_); bool enable_large_batch_splitting_; bool has_attribute_enable_large_batch_splitting_; + mutex mu_; + + // Parameters for adaptive batch scheduler only. + // Note 'num_batch_threads_' above is shared by two implementations of batch + // scheduler. + struct AdaptiveBatchSchedulerOptions { + int64 min_in_flight_batches_limit; + double initial_in_flight_batches_limit; + int64 batches_to_average_over; + }; + absl::optional<AdaptiveBatchSchedulerOptions> + adaptive_batch_scheduler_options_ = absl::nullopt; }; REGISTER_KERNEL_BUILDER(Name("BatchFunction").Device(DEVICE_CPU), BatchFunctionKernel); +// Currently all inputs and outputs are on the host. +// TODO(b/173748277): Accept inputs/outputs on the device. 
+REGISTER_KERNEL_BUILDER(Name("BatchFunction") + .Device(DEVICE_GPU) + .HostMemory("in_tensors") + .HostMemory("captured_tensors") + .HostMemory("out_tensors"), + BatchFunctionKernel); class BatchKernel : public AsyncOpKernel { public: @@ -282,8 +478,8 @@ class BatchKernel : public AsyncOpKernel { // Assume br calls done, so nothing to do here. } - // Validates 'allowed_batch_sizes_'. The entries must increase monotonically, - // and the last one must equal 'max_batch_size_'. + // Validates 'allowed_batch_sizes_'. The entries must increase + // monotonically, and the last one must equal 'max_batch_size_'. Status ValidateAllowedBatchSizes() const { if (allowed_batch_sizes_.empty()) { return Status::OK(); diff --git a/tensorflow/core/kernels/batching_util/BUILD b/tensorflow/core/kernels/batching_util/BUILD index 5bbfbddf0d4..75c270ff405 100644 --- a/tensorflow/core/kernels/batching_util/BUILD +++ b/tensorflow/core/kernels/batching_util/BUILD @@ -241,17 +241,19 @@ cc_library( srcs = ["batch_resource_base.cc"], hdrs = ["batch_resource_base.h"], deps = [ + ":adaptive_shared_batch_scheduler", ":batch_scheduler", ":concat_split_util", ":shared_batch_scheduler", + ":threadsafe_status", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core/kernels/batching_util:threadsafe_status", "//tensorflow/core/platform:status", "//tensorflow/core/platform:thread_annotations", "//tensorflow/core/profiler/lib:traceme", "//tensorflow/core/profiler/lib:traceme_encode", "//tensorflow/core/util:incremental_barrier", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.cc b/tensorflow/core/kernels/batching_util/batch_resource_base.cc index 81a16522c55..e4af643adc4 100644 --- a/tensorflow/core/kernels/batching_util/batch_resource_base.cc +++ b/tensorflow/core/kernels/batching_util/batch_resource_base.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor_util.h" #include "tensorflow/core/kernels/batching_util/concat_split_util.h" #include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/lib/monitoring/gauge.h" #include "tensorflow/core/lib/monitoring/percentile_sampler.h" #include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/profiler/lib/traceme_encode.h" @@ -74,6 +75,39 @@ void RecordBatchDelayMs(int64 batch_delay_ms, const string& model_name) { cell->GetCell(model_name)->Add(static_cast<double>(batch_delay_ms)); } +void RecordBatchParamBatchTimeoutMicros(int64 batch_timeout_micros, + const string& model_name) { + static auto* cell = monitoring::Gauge<int64, 1>::New( + "/tensorflow/serving/batching/batch_timeout_micros", + "Tracks how long a request can wait before being processed by a batch.", + "model_name"); + cell->GetCell(model_name)->Set(batch_timeout_micros); +} + +void RecordBatchParamMaxBatchSize(int64 max_batch_size, + const string& model_name) { + static auto* cell = monitoring::Gauge<int64, 1>::New( + "/tensorflow/serving/batching/max_batch_size", + "Tracks the maximum size of a batch.", "model_name"); + cell->GetCell(model_name)->Set(max_batch_size); +} + +void RecordBatchParamMaxEnqueuedBatches(int64 max_enqueued_batches, + const string& model_name) { + static auto* cell = monitoring::Gauge<int64, 1>::New( + "/tensorflow/serving/batching/max_enqueued_batches", + "Tracks the maximum number of enqueued batches.", "model_name"); + cell->GetCell(model_name)->Set(max_enqueued_batches); +} + +void RecordBatchParamAllowedBatchSizes(const string& allowed_batch_sizes, + const string& model_name) { + static auto* cell = monitoring::Gauge<string, 1>::New( + "/tensorflow/serving/batching/allowed_batch_sizes", + "Tracks the sizes that are allowed to form a batch.", "model_name"); + cell->GetCell(model_name)->Set(allowed_batch_sizes); +} + const string& GetModelName(OpKernelContext* ctx) { static string* kModelNameUnset = new string("model_name_unset"); if (!ctx->session_metadata()) return *kModelNameUnset; @@ -132,6 +166,14 @@ Status BatchResourceBase::RegisterInput( batch_components->inputs.push_back(tensor); } RecordInputBatchSize(tensors[0].shape().dim_size(0), GetModelName(context)); + RecordBatchParamBatchTimeoutMicros( + batcher_queue_options_.batch_timeout_micros, GetModelName(context)); + RecordBatchParamMaxBatchSize(batcher_queue_options_.max_execution_batch_size, + GetModelName(context)); + RecordBatchParamMaxEnqueuedBatches( + batcher_queue_options_.max_enqueued_batches, GetModelName(context)); + RecordBatchParamAllowedBatchSizes(allowed_batch_sizes_str_, + GetModelName(context)); OpInputList captured_tensors; const auto captured_status = context->input_list("captured_tensors", &captured_tensors); @@ -162,7 +204,6 @@ BatchResourceBase::GetBatcherQueueOptions( batcher_queue_options.input_batch_size_limit = max_batch_size; batcher_queue_options.max_enqueued_batches = max_enqueued_batches; batcher_queue_options.batch_timeout_micros = batch_timeout_micros; - // Support for splitting large batch is still in progress. 
batcher_queue_options.enable_large_batch_splitting = enable_large_batch_splitting; if (enable_large_batch_splitting) { @@ -185,6 +226,28 @@ BatchResourceBase::GetBatcherQueueOptions( return batcher_queue_options; } +/*static*/ BatchResourceBase::AdaptiveBatcherT::QueueOptions +BatchResourceBase::GetAdaptiveBatcherQueueOptions( + int32 max_batch_size, int32 batch_timeout_micros, + int32 max_enqueued_batches, bool enable_large_batch_splitting) { + AdaptiveBatcherT::QueueOptions batcher_queue_options; + batcher_queue_options.max_batch_size = max_batch_size; + batcher_queue_options.max_enqueued_batches = max_enqueued_batches; + batcher_queue_options.batch_timeout_micros = batch_timeout_micros; + + if (enable_large_batch_splitting) { + batcher_queue_options.split_input_task_func = + [](std::unique_ptr<BatchTask>* input_task, + int open_batch_remaining_slot, int max_batch_size, + std::vector<std::unique_ptr<BatchTask>>* output_tasks) -> Status { + return SplitInputTask(input_task, open_batch_remaining_slot, + max_batch_size, output_tasks); + }; + } + + return batcher_queue_options; +} + /*static*/ Status BatchResourceBase::ValidateBatch(const BatchT& batch) { for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) { const BatchResourceBase::BatchTask& task = batch.task(task_idx); diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.h b/tensorflow/core/kernels/batching_util/batch_resource_base.h index 89391f2defe..ea8c7729c89 100644 --- a/tensorflow/core/kernels/batching_util/batch_resource_base.h +++ b/tensorflow/core/kernels/batching_util/batch_resource_base.h @@ -18,10 +18,12 @@ limitations under the License. #include <map> +#include "absl/strings/str_join.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h" #include "tensorflow/core/kernels/batching_util/batch_scheduler.h" #include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h" #include "tensorflow/core/kernels/batching_util/threadsafe_status.h" @@ -48,7 +50,7 @@ class BatchResourceBase : public ResourceBase { const string& batcher_queue_name, AsyncOpKernel::DoneCallback done_callback); - protected: + public: // One task to be batched, corresponds to a `slice` of input from one batch-op // invocation. // @@ -106,6 +108,8 @@ class BatchResourceBase : public ResourceBase { // tensorflow::serving namespace, because some versions of compiler complain // about changing meaning of the symbols. 
using BatcherT = SharedBatchScheduler<BatchResourceBase::BatchTask>; + using AdaptiveBatcherT = + AdaptiveSharedBatchScheduler<BatchResourceBase::BatchTask>; using BatcherQueueT = BatchScheduler<BatchResourceBase::BatchTask>; using BatchT = Batch<BatchResourceBase::BatchTask>; @@ -116,6 +120,17 @@ class BatchResourceBase : public ResourceBase { : has_process_batch_function_(has_process_batch_function), batcher_(std::move(batcher)), batcher_queue_options_(batcher_queue_options), + allowed_batch_sizes_(std::move(allowed_batch_sizes)) { + allowed_batch_sizes_str_ = absl::StrJoin(allowed_batch_sizes_, ","); + } + + BatchResourceBase(bool has_process_batch_function, + std::shared_ptr<AdaptiveBatcherT> batcher, + const AdaptiveBatcherT::QueueOptions& batcher_queue_options, + std::vector<int32> allowed_batch_sizes) + : has_process_batch_function_(has_process_batch_function), + adaptive_batcher_(std::move(batcher)), + adaptive_batcher_queue_options_(batcher_queue_options), allowed_batch_sizes_(std::move(allowed_batch_sizes)) {} static BatcherT::QueueOptions GetBatcherQueueOptions( @@ -123,6 +138,10 @@ class BatchResourceBase : public ResourceBase { int32 max_enqueued_batches, const std::vector<int32>& allowed_batch_sizes, bool enable_large_batch_splitting); + static AdaptiveBatcherT::QueueOptions GetAdaptiveBatcherQueueOptions( + int32 max_batch_size, int32 batch_timeout_micros, + int32 max_enqueued_batches, bool enable_large_batch_splitting); + private: // Implementation of calling the process batch function. virtual void ProcessFuncBatchImpl( @@ -196,6 +215,10 @@ class BatchResourceBase : public ResourceBase { std::shared_ptr<BatcherT> batcher_; BatcherT::QueueOptions batcher_queue_options_; + // A batch scheduler, and options for creating queues. + std::shared_ptr<AdaptiveBatcherT> adaptive_batcher_; + AdaptiveBatcherT::QueueOptions adaptive_batcher_queue_options_; + // A collection of batcher queues, keyed on queue name. // TODO(olston): Garbage-collect unused queues (perhaps simply remove empty // ones (with a time delay?); it's okay if they get recreated later). @@ -204,6 +227,9 @@ class BatchResourceBase : public ResourceBase { TF_GUARDED_BY(batcher_queues_mu_); std::vector<int32> allowed_batch_sizes_; + // A concatenated string of <allowed_batch_sizes_>, separated by ",". This is + // used to record batching parameter. 
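// For example (illustrative values only): allowed_batch_sizes_ = {8, 16, 32}
// is joined by the absl::StrJoin call in the constructor above into the
// string "8,16,32".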
+ string allowed_batch_sizes_str_; }; } // namespace serving diff --git a/tensorflow/core/kernels/batching_util/concat_split_util.h b/tensorflow/core/kernels/batching_util/concat_split_util.h index 77c4463f118..914c793bc89 100644 --- a/tensorflow/core/kernels/batching_util/concat_split_util.h +++ b/tensorflow/core/kernels/batching_util/concat_split_util.h @@ -71,8 +71,10 @@ Status Concat(OpKernelContext* context, const gtl::ArraySlice<Tensor> inputs, TensorShape output_shape(input_shape); output_shape.set_dim(0, output_dim0); - TF_RETURN_IF_ERROR( - context->allocate_temp(DataTypeToEnum<T>::value, output_shape, output)); + AllocatorAttributes attr; + attr.set_on_host(true); + TF_RETURN_IF_ERROR(context->allocate_temp(DataTypeToEnum<T>::value, + output_shape, output, attr)); if (output->NumElements() > 0) { auto output_flat = output->shaped<T, 2>({1, output->NumElements()}); #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ @@ -167,8 +169,10 @@ Status SplitCPU(OpKernelContext* context, const Tensor& input, TensorShape output_shape = input.shape(); output_shape.set_dim(0, size); Tensor output; + AllocatorAttributes attr; + attr.set_on_host(true); TF_RETURN_IF_ERROR( - context->allocate_temp(input.dtype(), output_shape, &output)); + context->allocate_temp(input.dtype(), output_shape, &output, attr)); auto output_shaped = output.shaped<T, 2>({size, suffix_dim_size}); Eigen::DSizes<Eigen::DenseIndex, 2> slice_indices{ diff --git a/tensorflow/core/kernels/cwise_op_imag.cc b/tensorflow/core/kernels/cwise_op_imag.cc index bda9c19e3c2..fb76ec24df0 100644 --- a/tensorflow/core/kernels/cwise_op_imag.cc +++ b/tensorflow/core/kernels/cwise_op_imag.cc @@ -27,9 +27,12 @@ REGISTER_COMPLEX(CPU, float, complex64); REGISTER_COMPLEX(CPU, double, complex128); #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \ + !defined(MLIR_GENERATED_UNRANKED_GPU_KERNELS_ENABLED) REGISTER_COMPLEX(GPU, float, complex64); REGISTER_COMPLEX(GPU, double, complex128); #endif +#endif #undef REGISTER_COMPLEX } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_real.cc b/tensorflow/core/kernels/cwise_op_real.cc index 453f2801132..cb8484802f0 100644 --- a/tensorflow/core/kernels/cwise_op_real.cc +++ b/tensorflow/core/kernels/cwise_op_real.cc @@ -28,9 +28,12 @@ REGISTER_COMPLEX(CPU, float, complex64); REGISTER_COMPLEX(CPU, double, complex128); #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \ + !defined(MLIR_GENERATED_UNRANKED_GPU_KERNELS_ENABLED) REGISTER_COMPLEX(GPU, float, complex64); REGISTER_COMPLEX(GPU, double, complex128); #endif +#endif #undef REGISTER_COMPLEX } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_sin.cc b/tensorflow/core/kernels/cwise_op_sin.cc index d3e8f3b605c..3689f8b7399 100644 --- a/tensorflow/core/kernels/cwise_op_sin.cc +++ b/tensorflow/core/kernels/cwise_op_sin.cc @@ -20,7 +20,10 @@ REGISTER6(UnaryOp, CPU, "Sin", functor::sin, float, Eigen::half, bfloat16, double, complex64, complex128); #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \ + !defined(MLIR_GENERATED_UNRANKED_GPU_KERNELS_ENABLED) REGISTER3(UnaryOp, GPU, "Sin", functor::sin, float, Eigen::half, double); #endif +#endif } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc index 47f0f588991..93c2f57033e 100644 --- 
a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc @@ -38,6 +38,7 @@ limitations under the License. #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/snappy.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/protobuf/error_codes.pb.h" @@ -60,13 +61,8 @@ namespace data { /* static */ constexpr const char* const DataServiceDatasetOp::kOutputShapes; namespace { -// Once we've spent `kRetryTimeoutMicros` in `GetNextInternal`, we will wait for -// the current attempt to complete and perform no more retries. -const int64 kRetryTimeoutMicros = 1000LL * 1000 * 60 * 60; // 60 minutes. - // Default interval between task list refreshes. const int64 kDefaultTaskRefreshIntervalMs = 1000; // 1 second. - } // namespace // Dataset for reading data from the tf.data service non-deterministically. @@ -224,30 +220,23 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { &deregister_fn_)); dispatcher_ = absl::make_unique<DataServiceDispatcherClient>( dataset()->address_, dataset()->protocol_); - int64 deadline_micros = ctx->env()->NowMicros() + kRetryTimeoutMicros; - if (dataset()->job_name_.empty()) { - TF_RETURN_IF_ERROR(grpc_util::Retry( - [&]() { - return dispatcher_->CreateJob(dataset()->dataset_id_, - dataset()->processing_mode_, - job_client_id_); - }, - /*description=*/ - strings::StrCat("create job with dispatcher at ", - dataset()->address_), - deadline_micros)); - } else { - TF_RETURN_IF_ERROR(grpc_util::Retry( - [&]() { - return dispatcher_->GetOrCreateJob( - dataset()->dataset_id_, dataset()->processing_mode_, - dataset()->job_name_, iterator_index_, job_client_id_); - }, - /*description=*/ - strings::StrCat("get or create job with dispatcher at ", - dataset()->address_), - deadline_micros)); + int64 deadline_micros = kint64max; + absl::optional<JobKey> key; + if (!dataset()->job_name_.empty()) { + key.emplace(); + key.value().set_job_name(std::string(dataset()->job_name_)); + key.value().set_job_name_index(iterator_index_); } + TF_RETURN_IF_ERROR(grpc_util::Retry( + [&]() { + return dispatcher_->GetOrCreateJob(dataset()->dataset_id_, + dataset()->processing_mode_, key, + job_client_id_); + }, + /*description=*/ + strings::StrCat("get or create job with dispatcher at ", + dataset()->address_), + deadline_micros)); initialized_ = true; VLOG(1) << "Created data service job with id " << job_client_id_; return Status::OK(); @@ -496,8 +485,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { DCHECK(task_to_process != nullptr); VLOG(3) << "Processing task " << task_to_process->task_id; } - int64 deadline_micros = - Env::Default()->NowMicros() + kRetryTimeoutMicros; + int64 deadline_micros = kint64max; Status s = GetElement(task_to_process.get(), deadline_micros); if (!s.ok()) { mutex_lock l(mu_); diff --git a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc index bfa39a71bd9..53ff074438d 100644 --- a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc @@ -94,6 +94,14 @@ class DirectedInterleaveDatasetOp::Dataset : public DatasetBase { return kUnknownCardinality; } + Status InputDatasets(std::vector<const DatasetBase*>* inputs) const 
override { + inputs->push_back(selector_input_); + for (const auto& data_input : data_inputs_) { + inputs->push_back(data_input); + } + return Status::OK(); + } + Status CheckExternalState() const override { for (const auto& input : data_inputs_) { TF_RETURN_IF_ERROR(input->CheckExternalState()); diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc index 0606ca24da0..6c8622274d5 100644 --- a/tensorflow/core/kernels/data/optimize_dataset_op.cc +++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc @@ -85,7 +85,7 @@ void OptimizeDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, // clang-format off absl::flat_hash_map<string, uint64> live_experiments = { {"enable_gradient_descent", 0}, - {"map_parallelization", 50} + {"map_parallelization", 100} }; // clang-format on auto hash_func = [](const string& str) { return Hash64(str); }; diff --git a/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc b/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc index a71f2902559..5607d41b178 100644 --- a/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc +++ b/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc @@ -41,6 +41,8 @@ class FuzzParseTensor : public FuzzSession { // remainder of the fuzzer testing. Of course, this duplicates some work // but it's better than repeating the investigation whenever Autofuzz // detects another similar OOM. + // After adding `-fsanitize=null` to ASAN (cl/317376103), the memory + // footprint increased, so we lower the maximum threshold to 2^18. string as_string = string(reinterpret_cast<const char*>(data), size); TensorProto proto; if (!ParseProtoUnlimited(&proto, as_string)) { @@ -53,7 +55,7 @@ class FuzzParseTensor : public FuzzSession { } TensorShape shape(proto.tensor_shape()); const int64 num_elements = shape.num_elements(); - const int64 max_num_elements = 1 << 20; + const int64 max_num_elements = 1 << 18; if (num_elements > max_num_elements) { LOG(WARNING) << "Requiring a tensor with too many elements\n"; return; diff --git a/tensorflow/core/kernels/mlir_generated/BUILD b/tensorflow/core/kernels/mlir_generated/BUILD index c9789ea2421..b26716d19f2 100644 --- a/tensorflow/core/kernels/mlir_generated/BUILD +++ b/tensorflow/core/kernels/mlir_generated/BUILD @@ -52,10 +52,13 @@ filegroup( "unranked_op_gpu_cos.cc", "unranked_op_gpu_exp.cc", "unranked_op_gpu_floor.cc", + "unranked_op_gpu_imag.cc", "unranked_op_gpu_log.cc", "unranked_op_gpu_logical_not.cc", + "unranked_op_gpu_real.cc", "unranked_op_gpu_rsqrt.cc", "unranked_op_gpu_sign.cc", + "unranked_op_gpu_sin.cc", "unranked_op_gpu_sqrt.cc", "unranked_op_gpu_tanh.cc", ], @@ -104,10 +107,13 @@ tf_kernel_library( ":cos_unranked_kernels", ":exp_unranked_kernels", ":floor_unranked_kernels", + ":imag_unranked_kernels", ":log_unranked_kernels", ":logical_not_unranked_kernels", + ":real_unranked_kernels", ":rsqrt_unranked_kernels", ":sign_unranked_kernels", + ":sin_unranked_kernels", ":sqrt_unranked_kernels", ":tanh_unranked_kernels", ":unranked_op_gpu_base", @@ -116,19 +122,18 @@ tf_kernel_library( ), ) -# TODO(herhut): Uncomment once unranked kernels build again. 
-# tf_kernel_library( -# name = "cwise_binary_op", -# srcs = ["unranked_gpu_add.cc"], -# tags = [ -# "manual", -# ], -# deps = [ -# ":add_v2_unranked_kernels", -# ":unranked_op_gpu_base", -# "//third_party/eigen3", -# ], -# ) +tf_kernel_library( + name = "cwise_binary_op", + srcs = ["unranked_gpu_add.cc"], + tags = [ + "manual", + ], + deps = [ + ":add_v2_unranked_kernels", + ":unranked_op_gpu_base", + "//third_party/eigen3", + ], +) tf_kernel_library( name = "cwise_op", @@ -142,8 +147,7 @@ tf_kernel_library( ":cwise_unary_op", ]) + if_mlir_unranked_kernels_enabled( [ - # TODO(herhut): Uncomment once it builds again. - # ":cwise_binary_op", + ":cwise_binary_op", ], ), ) @@ -169,27 +173,27 @@ tf_cuda_cc_test( ], ) -# TODO(herhut): Uncomment once unranked kernels build again. -# tf_cuda_cc_test( -# name = "gpu_add_test", -# size = "small", -# srcs = if_mlir_generated_gpu_kernels_enabled(["gpu_add_test.cc"]), -# tags = tf_cuda_tests_tags() + [ -# "no_cuda_asan", # b/173033461 -# ], -# deps = [ -# "//tensorflow/core:framework", -# "//tensorflow/core:framework_internal", -# "//tensorflow/core:tensorflow", -# "//tensorflow/core:test", -# "//tensorflow/core:test_main", -# "//tensorflow/core:testlib", -# "//tensorflow/core/common_runtime:device", -# "//tensorflow/core/common_runtime:device_factory", -# "//tensorflow/core/kernels:cwise_op", -# "//tensorflow/core/kernels:ops_testutil", -# ], -# ) +tf_cuda_cc_test( + name = "gpu_add_test", + size = "small", + srcs = if_mlir_generated_gpu_kernels_enabled(["gpu_add_test.cc"]), + tags = tf_cuda_tests_tags() + [ + "no_cuda_asan", # b/173033461 + ], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:tensorflow", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/common_runtime:device", + "//tensorflow/core/common_runtime:device_factory", + "//tensorflow/core/kernels:cwise_op", + "//tensorflow/core/kernels:ops_testutil", + ], +) + # TODO(b/160731748): Re-enable when it works again. # gen_kernel_library( # name = "bias_add", @@ -238,6 +242,7 @@ gen_kernel_library( gen_kernel_library( name = "imag", + generate_unranked = True, tile_size = "256", types = [ "f32", @@ -269,6 +274,7 @@ gen_kernel_library( gen_kernel_library( name = "real", + generate_unranked = True, tile_size = "256", types = [ "f32", @@ -292,20 +298,19 @@ gen_kernel_library( unroll_factors = "4", ) -# TODO(herhut): Uncomment once it builds again. 
-# gen_kernel_library( -# name = "add_v2", -# generate_ranked = False, -# generate_unranked = True, -# tile_size = "256", -# types = [ -# "f16", -# "f32", -# "f64", -# "i64", -# ], -# unroll_factors = "4", -# ) +gen_kernel_library( + name = "add_v2", + generate_ranked = False, + generate_unranked = True, + tile_size = "256,1,1", + types = [ + "f16", + "f32", + "f64", + "i64", + ], + unroll_factors = "4", +) [ gen_kernel_library( diff --git a/tensorflow/core/kernels/mlir_generated/gpu_add_test.cc b/tensorflow/core/kernels/mlir_generated/gpu_add_test.cc index 42806ee807a..97e55971a02 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_add_test.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_add_test.cc @@ -64,6 +64,36 @@ class GpuAddTest : public OpsTestBase { test::ExpectEqual(expected_tensor, *GetOutput(0)); } + template <typename T, typename BaselineType = T> + void TestBroadcastingExpandAddOp() { + auto input_1 = {static_cast<T>(10)}; + auto input_2 = {static_cast<T>(1), static_cast<T>(2), static_cast<T>(3), + static_cast<T>(4), static_cast<T>(5), static_cast<T>(6)}; + std::vector<T> expected{ + static_cast<T>(11), static_cast<T>(12), static_cast<T>(13), + static_cast<T>(14), static_cast<T>(15), static_cast<T>(16), + }; + auto expected_shape = TensorShape({6}); + RunAndCompareAddOp<T, BaselineType>(input_1, TensorShape({1}), input_2, + TensorShape({6}), expected, + expected_shape); + } + + template <typename T, typename BaselineType = T> + void TestBroadcastingInDimAddOp() { + auto input_1 = {static_cast<T>(10), static_cast<T>(20), static_cast<T>(30)}; + auto input_2 = {static_cast<T>(1), static_cast<T>(2), static_cast<T>(3), + static_cast<T>(4), static_cast<T>(5), static_cast<T>(6)}; + std::vector<T> expected{ + static_cast<T>(11), static_cast<T>(22), static_cast<T>(33), + static_cast<T>(14), static_cast<T>(25), static_cast<T>(36), + }; + auto expected_shape = TensorShape({2, 3}); + RunAndCompareAddOp<T, BaselineType>(input_1, TensorShape({3}), input_2, + TensorShape({2, 3}), expected, + expected_shape); + } + template <typename T, typename BaselineType = T> void TestBroadcastingAddOp() { auto input_1 = {static_cast<T>(10), static_cast<T>(20)}; @@ -104,6 +134,52 @@ class GpuAddTest : public OpsTestBase { TensorShape{2, 3}); } + template <typename T, typename BaselineType = T> + void TestEqualShapesAddOp() { + auto input_1 = { + static_cast<T>(-std::numeric_limits<BaselineType>::infinity()), + static_cast<T>(-0.1), + static_cast<T>(-0.0), + static_cast<T>(0.0), + static_cast<T>(0.1), + static_cast<T>(std::numeric_limits<BaselineType>::infinity())}; + auto input_2 = { + static_cast<T>(-std::numeric_limits<BaselineType>::infinity()), + static_cast<T>(-0.1), + static_cast<T>(-0.0), + static_cast<T>(0.0), + static_cast<T>(0.1), + static_cast<T>(std::numeric_limits<BaselineType>::infinity())}; + std::vector<T> expected; + for (const T& inp : input_2) { + expected.push_back(static_cast<T>(static_cast<BaselineType>(inp) + + static_cast<BaselineType>(inp))); + } + RunAndCompareAddOp<T, BaselineType>(input_1, TensorShape{2, 3}, input_2, + TensorShape{2, 3}, expected, + TensorShape{2, 3}); + } + + template <typename T, typename BaselineType = T> + void TestOneIsScalarAddOp() { + auto input_1 = static_cast<T>(42); + auto input_2 = { + static_cast<T>(-std::numeric_limits<BaselineType>::infinity()), + static_cast<T>(-0.1), + static_cast<T>(-0.0), + static_cast<T>(0.0), + static_cast<T>(0.1), + static_cast<T>(std::numeric_limits<BaselineType>::infinity())}; + std::vector<T> expected; + for 
(const T& inp : input_2) { + expected.push_back(static_cast<T>(static_cast<BaselineType>(input_1) + + static_cast<BaselineType>(inp))); + } + RunAndCompareAddOp<T, BaselineType>({input_1}, TensorShape{}, input_2, + TensorShape{2, 3}, expected, + TensorShape{2, 3}); + } + template <typename T, typename RT = T> void TestIncompatibleShapes() { auto input_1 = {static_cast<T>(-0.1), static_cast<T>(-0.0), @@ -120,12 +196,51 @@ class GpuAddTest : public OpsTestBase { TEST_F(GpuAddTest, AddFloat) { RunAddOp<float>(); } TEST_F(GpuAddTest, AddDouble) { RunAddOp<double>(); } TEST_F(GpuAddTest, AddHalf) { RunAddOp<Eigen::half, float>(); } +TEST_F(GpuAddTest, AddInt64) { RunAddOp<int64, int64>(); } + +TEST_F(GpuAddTest, AddEqShapesFloat) { TestEqualShapesAddOp<float>(); } +TEST_F(GpuAddTest, AddEqShapesDouble) { TestEqualShapesAddOp<double>(); } +TEST_F(GpuAddTest, AddEqShapesHalf) { + TestEqualShapesAddOp<Eigen::half, float>(); +} +TEST_F(GpuAddTest, AddEqShapesInt64) { TestEqualShapesAddOp<int64>(); } + +TEST_F(GpuAddTest, AddScalarFloat) { TestOneIsScalarAddOp<float>(); } +TEST_F(GpuAddTest, AddScalarDouble) { TestOneIsScalarAddOp<double>(); } +TEST_F(GpuAddTest, AddScalarHalf) { + TestOneIsScalarAddOp<Eigen::half, float>(); +} +TEST_F(GpuAddTest, AddScalarInt64) { TestOneIsScalarAddOp<int64>(); } + +TEST_F(GpuAddTest, BCastExpandAddFloat) { + TestBroadcastingExpandAddOp<float>(); +} +TEST_F(GpuAddTest, BCastExpandAddDouble) { + TestBroadcastingExpandAddOp<double>(); +} +TEST_F(GpuAddTest, BCastExpandAddHalf) { + TestBroadcastingExpandAddOp<Eigen::half, float>(); +} +TEST_F(GpuAddTest, BCastExpandAddInt64) { + TestBroadcastingExpandAddOp<int64>(); +} + +TEST_F(GpuAddTest, BCastInDimAddFloat) { TestBroadcastingInDimAddOp<float>(); } +TEST_F(GpuAddTest, BCastInDimAddDouble) { + TestBroadcastingInDimAddOp<double>(); +} +TEST_F(GpuAddTest, BCastInDimAddHalf) { + TestBroadcastingInDimAddOp<Eigen::half, float>(); +} +TEST_F(GpuAddTest, BCastInDimAddInt64) { TestBroadcastingInDimAddOp<int64>(); } + TEST_F(GpuAddTest, BCastAddFloat) { TestBroadcastingAddOp<float>(); } TEST_F(GpuAddTest, BCastAddDouble) { TestBroadcastingAddOp<double>(); } TEST_F(GpuAddTest, BCastAddHalf) { TestBroadcastingAddOp<Eigen::half, float>(); } TEST_F(GpuAddTest, BCastAddInt64) { TestBroadcastingAddOp<int64>(); } + TEST_F(GpuAddTest, IncompatibleShapes) { TestIncompatibleShapes<float>(); } // TEST_F(GpuAddTest, AddV2Half) { RunAddOp<Eigen::half, float>(); } diff --git a/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_test.cc b/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_test.cc index a7541032ec3..a878a9d2431 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_test.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_test.cc @@ -14,10 +14,12 @@ limitations under the License. ==============================================================================*/ #include <cmath> +#include <complex> #include <functional> #include <initializer_list> #include <memory> #include <numeric> +#include <type_traits> #include <vector> #include "tensorflow/core/common_runtime/device.h" @@ -42,32 +44,39 @@ class GpuUnaryOpTest : public OpsTestBase { SetDevice(tensorflow::DEVICE_GPU, std::move(device_gpu)); } - template <typename T, typename RT = T> + // 'T' is the input type, 'RT' is the input type for the callback function, + // 'OutT' is the output type, 'ROutT' is the output type for the callback + // function. In most cases it is enough to just provide the input type, + // because all the types are the same. 
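// For example, the `tf.Imag` tests below instantiate this as
//   Run<std::complex<float>, const std::complex<float>&, float, float>(...)
// so a complex input type can map to a real (float) output type.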
+ template <typename T, typename RT = T, typename OutT = T, typename ROutT = RT> void Run(std::vector<int64> input_shape, std::vector<T> input, - const std::string op_name, RT (*expected_callback)(RT), + const std::string op_name, ROutT (*expected_callback)(RT), bool expect_equal = true) { assert(std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies<int64>()) == input.size() && "Expected input length to equal to shape's number of elements."); TensorShape shape(input_shape); - TF_ASSERT_OK(NodeDefBuilder("some_name", op_name) - .Input(FakeInput(DataTypeToEnum<T>::v())) - .Attr("T", DataTypeToEnum<T>::v()) - .Finalize(node_def())); + NodeDefBuilder builder("some_name", op_name); + builder.Input(FakeInput(DataTypeToEnum<T>::v())) + .Attr("T", DataTypeToEnum<T>::v()); + if (!std::is_same_v<OutT, T>) { + builder.Attr("Tout", DataTypeToEnum<OutT>::v()); + } + TF_ASSERT_OK(builder.Finalize(node_def())); TF_ASSERT_OK(InitOp()); AddInputFromArray<T>(shape, input); TF_ASSERT_OK(RunOpKernel()); - Tensor expected_tensor(allocator(), DataTypeToEnum<T>::value, shape); - std::vector<T> expected; + Tensor expected_tensor(allocator(), DataTypeToEnum<OutT>::value, shape); + std::vector<OutT> expected; expected.reserve(input.size()); for (const T& inp : input) { expected.push_back( - static_cast<T>(expected_callback(static_cast<RT>(inp)))); + static_cast<OutT>(expected_callback(static_cast<RT>(inp)))); } - test::FillValues<T>(&expected_tensor, expected); + test::FillValues<OutT>(&expected_tensor, expected); if (expect_equal) { test::ExpectEqual(expected_tensor, *GetOutput(0)); } else { @@ -85,6 +94,16 @@ class GpuUnaryOpTest : public OpsTestBase { 0.5, 0.7, 0.9, 9.0, 18.0}); } + template <typename T> + std::vector<std::complex<T>> DefaultComplexInput() { + auto input = DefaultInput<T>(); + std::vector<std::complex<T>> complex_input; + for (T value : input) { + complex_input.emplace_back(value, -value); + } + return complex_input; + } + template <typename T> std::vector<T> DefaultInputGreaterThanZero() { return InputAsVector<T>({18.0, 9.0, 1e-6, 1.0, 0.1, 1e-6, 0.1, 0.2, 0.3, @@ -109,52 +128,6 @@ class GpuUnaryOpTest : public OpsTestBase { } }; -/// Test `tf.Tanh`. - -TEST_F(GpuUnaryOpTest, TanhFloat) { - Run<float>(DefaultInputShape(), DefaultInput<float>(), - /*op_name=*/"Tanh", - /*expected_callback=*/std::tanh, - /*expect_equal=*/false); -} - -TEST_F(GpuUnaryOpTest, TanhDouble) { - Run<double>(DefaultInputShape(), DefaultInput<double>(), - /*op_name=*/"Tanh", - /*expected_callback=*/std::tanh, - /*expect_equal=*/false); -} - -TEST_F(GpuUnaryOpTest, TanhHalf) { - Run<Eigen::half, float>(DefaultInputShape(), DefaultInput<Eigen::half>(), - /*op_name=*/"Tanh", - /*expected_callback=*/std::tanh, - /*expect_equal=*/false); -} - -/// Test `tf.Ceil`. - -TEST_F(GpuUnaryOpTest, CeilFloat) { - Run<float>(DefaultInputShape(), DefaultInput<float>(), - /*op_name=*/"Ceil", - /*expected_callback=*/std::ceil, - /*expect_equal=*/true); -} - -TEST_F(GpuUnaryOpTest, CeilDouble) { - Run<double>(DefaultInputShape(), DefaultInput<double>(), - /*op_name=*/"Ceil", - /*expected_callback=*/std::ceil, - /*expect_equal=*/true); -} - -TEST_F(GpuUnaryOpTest, CeilHalf) { - Run<Eigen::half, float>(DefaultInputShape(), DefaultInput<Eigen::half>(), - /*op_name=*/"Ceil", - /*expected_callback=*/std::ceil, - /*expect_equal=*/true); -} - /// Test `tf.Abs`. TEST_F(GpuUnaryOpTest, AbsFloat) { @@ -214,6 +187,29 @@ TEST_F(GpuUnaryOpTest, AbsInt64) { /*expect_equal=*/true); } +/// Test `tf.Ceil`. 
+ +TEST_F(GpuUnaryOpTest, CeilFloat) { + Run<float>(DefaultInputShape(), DefaultInput<float>(), + /*op_name=*/"Ceil", + /*expected_callback=*/std::ceil, + /*expect_equal=*/true); +} + +TEST_F(GpuUnaryOpTest, CeilDouble) { + Run<double>(DefaultInputShape(), DefaultInput<double>(), + /*op_name=*/"Ceil", + /*expected_callback=*/std::ceil, + /*expect_equal=*/true); +} + +TEST_F(GpuUnaryOpTest, CeilHalf) { + Run<Eigen::half, float>(DefaultInputShape(), DefaultInput<Eigen::half>(), + /*op_name=*/"Ceil", + /*expected_callback=*/std::ceil, + /*expect_equal=*/true); +} + /// Test `tf.Cos`. TEST_F(GpuUnaryOpTest, CosFloat) { @@ -283,6 +279,24 @@ TEST_F(GpuUnaryOpTest, FloorHalf) { /*expect_equal=*/true); } +/// Test `tf.Imag`. + +TEST_F(GpuUnaryOpTest, ImagFloat) { + Run<std::complex<float>, const std::complex<float>&, float, float>( + DefaultInputShape(), DefaultComplexInput<float>(), + /*op_name=*/"Imag", + /*expected_callback=*/std::imag, + /*expect_equal=*/false); +} + +TEST_F(GpuUnaryOpTest, ImagDouble) { + Run<std::complex<double>, const std::complex<double>&, double, double>( + DefaultInputShape(), DefaultComplexInput<double>(), + /*op_name=*/"Imag", + /*expected_callback=*/std::imag, + /*expect_equal=*/false); +} + /// Test `tf.Log`. TEST_F(GpuUnaryOpTest, LogFloat) { @@ -338,6 +352,24 @@ TEST_F(GpuUnaryOpTest, NegHalf) { /*expect_equal=*/true); } +/// Test `tf.Real`. + +TEST_F(GpuUnaryOpTest, RealFloat) { + Run<std::complex<float>, const std::complex<float>&, float, float>( + DefaultInputShape(), DefaultComplexInput<float>(), + /*op_name=*/"Real", + /*expected_callback=*/std::real, + /*expect_equal=*/false); +} + +TEST_F(GpuUnaryOpTest, RealDouble) { + Run<std::complex<double>, const std::complex<double>&, double, double>( + DefaultInputShape(), DefaultComplexInput<double>(), + /*op_name=*/"Real", + /*expected_callback=*/std::real, + /*expect_equal=*/false); +} + /// Test `tf.Rsqrt`. /// Reference implementation. @@ -369,6 +401,39 @@ TEST_F(GpuUnaryOpTest, RsqrtHalf) { /*expect_equal=*/false); } +/// Test `tf.Sign`. + +// Reference implementation +template <typename T> +T expected_sign(T x) { + if (x == 0) return 0; + if (x < 0) return -1; + return 1; +} + +// TODO(b/162577610): Enable these tests when our generated kernels handle 0.0 +// and -0.0 correctly. +TEST_F(GpuUnaryOpTest, DISABLED_SignFloat) { + Run<float>(DefaultInputShape(), DefaultInput<float>(), + /*op_name=*/"Sign", + /*expected_callback=*/expected_sign, + /*expect_equal=*/true); +} + +TEST_F(GpuUnaryOpTest, DISABLED_SignDouble) { + Run<double>(DefaultInputShape(), DefaultInput<double>(), + /*op_name=*/"Sign", + /*expected_callback=*/expected_sign, + /*expect_equal=*/true); +} + +TEST_F(GpuUnaryOpTest, DISABLED_SignHalf) { + Run<Eigen::half, float>(DefaultInputShape(), DefaultInput<Eigen::half>(), + /*op_name=*/"Sign", + /*expected_callback=*/expected_sign, + /*expect_equal=*/true); +} + /// Test `tf.Sin`. TEST_F(GpuUnaryOpTest, SinFloat) { @@ -416,5 +481,28 @@ TEST_F(GpuUnaryOpTest, SqrtHalf) { /*expect_equal=*/false); } +/// Test `tf.Tanh`. 
+ +TEST_F(GpuUnaryOpTest, TanhFloat) { + Run<float>(DefaultInputShape(), DefaultInput<float>(), + /*op_name=*/"Tanh", + /*expected_callback=*/std::tanh, + /*expect_equal=*/false); +} + +TEST_F(GpuUnaryOpTest, TanhDouble) { + Run<double>(DefaultInputShape(), DefaultInput<double>(), + /*op_name=*/"Tanh", + /*expected_callback=*/std::tanh, + /*expect_equal=*/false); +} + +TEST_F(GpuUnaryOpTest, TanhHalf) { + Run<Eigen::half, float>(DefaultInputShape(), DefaultInput<Eigen::half>(), + /*op_name=*/"Tanh", + /*expected_callback=*/std::tanh, + /*expect_equal=*/false); +} + } // namespace } // end namespace tensorflow diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/angle.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/angle.mlir.tmpl index 8605741dfea..36ca1d8cf56 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/angle.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/angle.mlir.tmpl @@ -1,4 +1,5 @@ -func @Angle(%arg0: tensor<?xcomplex<elem_type>>) -> tensor<?xelem_type> { - %0 = "tf.Angle"(%arg0) : (tensor<?xcomplex<elem_type>>) -> tensor<?xelem_type> - return %0 : tensor<?xelem_type> +func @Angle_elem_type(%arg0: tensor<*xcomplex<elem_type>>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Angle"(%arg0) : (tensor<*xcomplex<elem_type>>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/asin.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/asin.mlir.tmpl index 741ca1b145c..0b5f35eb985 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/asin.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/asin.mlir.tmpl @@ -1,4 +1,5 @@ -func @Asin(%arg0: tensor<?xelem_type>) -> tensor<?xelem_type> { - %0 = "tf.Asin"(%arg0) : (tensor<?xelem_type>) -> tensor<?xelem_type> - return %0 : tensor<?xelem_type> +func @Asin_elem_type(%arg0: tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Asin"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/atan.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/atan.mlir.tmpl index 80e22f38dbe..6ee1349549a 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/atan.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/atan.mlir.tmpl @@ -1,4 +1,5 @@ -func @Atan(%arg0: tensor<?xelem_type>) -> tensor<?xelem_type> { - %0 = "tf.Atan"(%arg0) : (tensor<?xelem_type>) -> tensor<?xelem_type> - return %0 : tensor<?xelem_type> +func @Atan_elem_type(%arg0: tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Atan"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/bias_add.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/bias_add.mlir.tmpl index 210df66516c..78e5ffe204d 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/bias_add.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/bias_add.mlir.tmpl @@ -1,6 +1,6 @@ -func @BiasAdd(%arg0: tensor<?x?xelem_type>, %arg1: tensor<?xelem_type>) - -> tensor<?x?xelem_type> { +func @BiasAdd_elem_type(%arg0: tensor<*x*xelem_type>, %arg1: tensor<*xelem_type>) + -> tensor<*x*xelem_type> attributes {tf_entry, 
llvm.emit_c_interface} { %0 = "tf.BiasAdd"(%arg0, %arg1) - : (tensor<?x?xelem_type>, tensor<?xelem_type>) -> tensor<?x?xelem_type> - return %0 : tensor<?x?xelem_type> + : (tensor<*x*xelem_type>, tensor<*xelem_type>) -> tensor<*x*xelem_type> + return %0 : tensor<*x*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/conj.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/conj.mlir.tmpl index 963a0740c6f..5608e9b37f3 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/conj.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/conj.mlir.tmpl @@ -1,4 +1,5 @@ -func @Conj(%arg0: tensor<?xelem_type>) -> tensor<?xelem_type> { - %0 = "tf.Conj"(%arg0) : (tensor<?xelem_type>) -> tensor<?xelem_type> - return %0 : tensor<?xelem_type> +func @Conj_elem_type(%arg0: tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Conj"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/imag.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/imag.mlir.tmpl index c68c85a798c..d52487a8d98 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/imag.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/imag.mlir.tmpl @@ -1,4 +1,5 @@ -func @Imag(%arg0: tensor<?xcomplex<elem_type>>) -> tensor<?xelem_type> { - %0 = "tf.Imag"(%arg0) : (tensor<?xcomplex<elem_type>>) -> tensor<?xelem_type> - return %0 : tensor<?xelem_type> +func @Imag_elem_type(%arg0: tensor<*xcomplex<elem_type>>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Imag"(%arg0) : (tensor<*xcomplex<elem_type>>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/real.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/real.mlir.tmpl index 600fbe563b8..5ddbfb666f0 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/real.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/real.mlir.tmpl @@ -1,4 +1,5 @@ -func @Real(%arg0: tensor<?xcomplex<elem_type>>) -> tensor<?xelem_type> { - %0 = "tf.Real"(%arg0) : (tensor<?xcomplex<elem_type>>) -> tensor<?xelem_type> - return %0 : tensor<?xelem_type> +func @Real_elem_type(%arg0: tensor<*xcomplex<elem_type>>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Real"(%arg0) : (tensor<*xcomplex<elem_type>>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/sin.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/sin.mlir.tmpl index 0e739740f93..de3f1d8505f 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/sin.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/sin.mlir.tmpl @@ -1,4 +1,5 @@ -func @Sin(%arg0: tensor<?xelem_type>) -> tensor<?xelem_type> { - %0 = "tf.Sin"(%arg0) : (tensor<?xelem_type>) -> tensor<?xelem_type> - return %0 : tensor<?xelem_type> +func @Sin_elem_type(%arg0: tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Sin"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h b/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h index badb510ad74..3a939293347 100644 --- 
a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h +++ b/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h @@ -57,7 +57,7 @@ template <typename ElemType> template <typename ElemType> Tensor ConvertDescriptorToTensor( - ::UnrankedMemRefType<ElemType> unranked_descriptor, DataType tf_data_type, + ::UnrankedMemRefType<ElemType> unranked_descriptor, DataType TfDataType, Allocator* allocator) { void* base_ptr = static_cast<void**>(unranked_descriptor.descriptor)[0]; TensorShape result_shape; @@ -69,25 +69,26 @@ Tensor ConvertDescriptorToTensor( base_ptr, sizeof(ElemType) * result_shape.num_elements(), allocator); // Tensor takes ownership of the buffer. - Tensor tensor{tf_data_type, result_shape, buffer}; + Tensor tensor{TfDataType, result_shape, buffer}; // When Tensor is constructed, its ref-counter is incremented. We need to // decrement it back. buffer->Unref(); return tensor; } -template <DataType tf_data_type, typename data_type, typename Derived> +template <DataType TfDataType, typename OutputDataType, typename Kernel, + typename InputDataType = OutputDataType> class MlirUnrankedOp : public OpKernel { public: explicit MlirUnrankedOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} void Compute(OpKernelContext* ctx) override { - llvm::SmallVector<::UnrankedMemRefType<data_type>, 2> input_descs; + llvm::SmallVector<::UnrankedMemRefType<InputDataType>, 2> input_descs; for (int i = 0, end = ctx->num_inputs(); i < end; ++i) { input_descs.push_back( - std::move(ConvertTensorToDescriptor<data_type>(ctx->input(i)))); + std::move(ConvertTensorToDescriptor<InputDataType>(ctx->input(i)))); } - auto result_desc = Derived::Invoke(ctx, input_descs); + auto result_desc = Kernel::Invoke(ctx, input_descs); for (const auto& input_desc : input_descs) { free(input_desc.descriptor); } @@ -100,25 +101,38 @@ class MlirUnrankedOp : public OpKernel { for (int i = 0, end = ctx->num_inputs(); i < end; ++i) { const Tensor& input = ctx->input(i); if (input.data() == result_data_ptr) { - ctx->set_output(0, input); + // Run a bitcast in case the output type is different. 
+ Tensor output; + OP_REQUIRES_OK(ctx, + output.BitcastFrom(input, TfDataType, input.shape())); + ctx->set_output(0, output); free(result_desc.descriptor); return; } } tensorflow::AllocatorAttributes attrs; auto* allocator = ctx->get_allocator(attrs); - Tensor result_tensor = ConvertDescriptorToTensor<data_type>( - result_desc, tf_data_type, allocator); + Tensor result_tensor = ConvertDescriptorToTensor<OutputDataType>( + result_desc, TfDataType, allocator); free(result_desc.descriptor); ctx->set_output(0, result_tensor); } }; #define MLIR_FUNCTION(tf_op, mlir_type) _mlir_ciface_##tf_op##_##mlir_type + #define REGISTER_KERNEL(tf_op, mlir_type, data_type) \ REGISTER_KERNEL_BUILDER( \ Name(#tf_op).Device(DEVICE_GPU).TypeConstraint<data_type>("T"), \ MlirUnranked##tf_op##mlir_type##Op); + +#define REGISTER_COMPLEX_KERNEL(tf_op, mlir_type, data_type, input_data_type) \ + REGISTER_KERNEL_BUILDER(Name(#tf_op) \ + .Device(DEVICE_GPU) \ + .TypeConstraint<input_data_type>("T") \ + .TypeConstraint<data_type>("Tout"), \ + MlirUnranked##tf_op##mlir_type##Op); + #define REGISTER_KERNEL_NO_TYPE_CONSTRAINT(tf_op, mlir_type) \ REGISTER_KERNEL_BUILDER(Name(#tf_op).Device(DEVICE_GPU), \ MlirUnranked##tf_op##mlir_type##Op); @@ -158,21 +172,26 @@ class MlirUnrankedOp : public OpKernel { GENERATE_UNARY_KERNEL(tf_op, mlir_type, tf_data_type, data_type) \ REGISTER_KERNEL(tf_op, mlir_type, data_type) -#define GENERATE_UNARY_KERNEL(tf_op, mlir_type, tf_data_type, data_type) \ +#define GENERATE_UNARY_KERNEL(tf_op, mlir_type, tf_data_type, data_type) \ + GENERATE_UNARY_KERNEL2(tf_op, mlir_type, tf_data_type, data_type, data_type) + +#define GENERATE_UNARY_KERNEL2(tf_op, mlir_type, tf_data_type, data_type, \ + input_data_type) \ extern "C" ::UnrankedMemRefType<data_type> MLIR_FUNCTION(tf_op, mlir_type)( \ tensorflow::OpKernelContext * ctx, \ - const ::UnrankedMemRefType<data_type>* arg); \ + const ::UnrankedMemRefType<input_data_type>* arg); \ \ namespace { \ class MlirUnranked##tf_op##mlir_type##Op \ : public MlirUnrankedOp<tf_data_type, data_type, \ - MlirUnranked##tf_op##mlir_type##Op> { \ + MlirUnranked##tf_op##mlir_type##Op, \ + input_data_type> { \ public: \ using MlirUnrankedOp::MlirUnrankedOp; \ \ static ::UnrankedMemRefType<data_type> Invoke( \ OpKernelContext* ctx, \ - llvm::ArrayRef<::UnrankedMemRefType<data_type>> args) { \ + llvm::ArrayRef<::UnrankedMemRefType<input_data_type>> args) { \ return MLIR_FUNCTION(tf_op, mlir_type)(ctx, &args[0]); \ } \ }; \ diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_imag.cc b/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_imag.cc new file mode 100644 index 00000000000..dec53c9f5a0 --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_imag.cc @@ -0,0 +1,28 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include <complex> + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h" + +namespace tensorflow { + +GENERATE_UNARY_KERNEL2(Imag, f32, DT_FLOAT, float, std::complex<float>); +REGISTER_COMPLEX_KERNEL(Imag, f32, float, std::complex<float>); +GENERATE_UNARY_KERNEL2(Imag, f64, DT_DOUBLE, double, std::complex<double>); +REGISTER_COMPLEX_KERNEL(Imag, f64, double, std::complex<double>); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_real.cc b/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_real.cc new file mode 100644 index 00000000000..8567060fd62 --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_real.cc @@ -0,0 +1,28 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <complex> + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h" + +namespace tensorflow { + +GENERATE_UNARY_KERNEL2(Real, f32, DT_FLOAT, float, std::complex<float>); +REGISTER_COMPLEX_KERNEL(Real, f32, float, std::complex<float>); +GENERATE_UNARY_KERNEL2(Real, f64, DT_DOUBLE, double, std::complex<double>); +REGISTER_COMPLEX_KERNEL(Real, f64, double, std::complex<double>); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sign.cc b/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sign.cc index a29c53a2978..e49d640f7e2 100644 --- a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sign.cc +++ b/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sign.cc @@ -23,5 +23,6 @@ GENERATE_AND_REGISTER_UNARY_KERNEL(Sign, f32, DT_FLOAT, float); GENERATE_AND_REGISTER_UNARY_KERNEL(Sign, f64, DT_DOUBLE, double); GENERATE_AND_REGISTER_UNARY_KERNEL(Sign, i32, DT_INT32, int32); GENERATE_AND_REGISTER_UNARY_KERNEL(Sign, i64, DT_INT64, int64); +// TODO(b/162577610): Register the kernel for complex types and bfloat. } // namespace tensorflow diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sin.cc b/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sin.cc new file mode 100644 index 00000000000..2dd4a8dcda6 --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_sin.cc @@ -0,0 +1,25 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h" + +namespace tensorflow { + +GENERATE_AND_REGISTER_UNARY_KERNEL(Sin, f16, DT_HALF, Eigen::half); +GENERATE_AND_REGISTER_UNARY_KERNEL(Sin, f32, DT_FLOAT, float); +GENERATE_AND_REGISTER_UNARY_KERNEL(Sin, f64, DT_DOUBLE, double); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/segment_reduction_ops.h b/tensorflow/core/kernels/segment_reduction_ops.h index dc311ca9d3d..449b0cbf8ad 100644 --- a/tensorflow/core/kernels/segment_reduction_ops.h +++ b/tensorflow/core/kernels/segment_reduction_ops.h @@ -36,7 +36,8 @@ namespace functor { #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM typedef Eigen::GpuDevice GPUDevice; -// Functor for SegmentSumGPUOp. +// Functor for SegmentSumGPUOp & SegmentProdGPUOp & SegmentMaxGPUOp +// & SegmentMinGPUOp. // output_rows: the number of output segments (unique segment ids in // 'segment_ids'). // segment_ids_shape: shape of 'segment_ids' tensor. @@ -45,8 +46,9 @@ typedef Eigen::GpuDevice GPUDevice; // data_size: size of input data tensor. // data: input data tensor. // output: output reshaped to {output_rows, output.size/output_rows} -template <typename T, typename Index> -struct SegmentSumFunctor { +template <typename T, typename Index, typename InitialValueF, + typename ReductionF, typename AtomicReductionF> +struct SegmentReductionFunctor { void operator()(OpKernelContext* ctx, const GPUDevice& d, const Index output_rows, const TensorShape& segment_ids_shape, typename TTypes<Index>::ConstFlat segment_ids, @@ -66,9 +68,10 @@ struct UnsortedSegmentFunctor { }; #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -// reduction functors for the gpu + +// Atomic reduction functors for the gpu. template <typename T> -struct SumOpGpu { +struct AtomicSumOpGpu { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest, const T& value) { GpuAtomicAdd(dest, value); @@ -76,7 +79,7 @@ struct SumOpGpu { }; template <typename T> -struct ProdOpGpu { +struct AtomicProdOpGpu { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest, const T& value) { GpuAtomicMul(dest, value); @@ -84,7 +87,7 @@ struct ProdOpGpu { }; template <typename T> -struct MaxOpGpu { +struct AtomicMaxOpGpu { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest, const T& value) { GpuAtomicMax(dest, value); @@ -92,16 +95,49 @@ struct MaxOpGpu { }; template <typename T> -struct MinOpGpu { +struct AtomicMinOpGpu { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest, const T& value) { GpuAtomicMin(dest, value); } }; +// Non-atomic reduction functors for the gpu. 
+template <typename T> +struct NonAtomicSumOpGpu { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest, + const T& value) { + *dest += value; + } +}; + +template <typename T> +struct NonAtomicProdOpGpu { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest, + const T& value) { + *dest *= value; + } +}; + +template <typename T> +struct NonAtomicMaxOpGpu { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest, + const T& value) { + *dest = max(*dest, value); + } +}; + +template <typename T> +struct NonAtomicMinOpGpu { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest, + const T& value) { + *dest = min(*dest, value); + } +}; + #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM -// initial value functors +// Initial value functors. template <typename T> struct Zero { EIGEN_STRONG_INLINE T operator()() const { return T(0); } diff --git a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc index 418af1d6b6d..c6f9d8b9158 100644 --- a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc +++ b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc @@ -31,14 +31,14 @@ namespace tensorflow { using GPUDevice = Eigen::GpuDevice; -// SortedSegmentSumFunctor kernel reduces input data just as -// UnsortedSegmentSumCustomKernel does except that input data +// SortedSegmentReductionFunctor kernel reduces input data just as +// UnsortedSegmentReductionCustomKernel does except that input data // is partitioned along the outer reduction dimension. This is // because consecutive rows (elements in a row share the same // outer dimension index) in the flattened 2D input data likely // belong to the same segment in sorted segment sum operation. // Therefore such partitioning strategy has two advantages over -// the UnsortedSegmentSumFunctor kernel: +// the UnsortedSegmentReductionFunctor kernel: // 1. Each thread reduces across multiple rows before writing // answers to the global memory, we can therefore // write reduction results to global memory less often. @@ -51,18 +51,19 @@ using GPUDevice = Eigen::GpuDevice; // size OuterDimTileSize x 1. This strip runs across multiple // rows of input data and all reduction elements share one inner // dimension index. -template <typename T, typename Index, int OuterDimTileSize> -__global__ void SortedSegmentSumCustomKernel( +template <typename T, typename Index, int OuterDimTileSize, typename ReductionF, + typename AtomicReductionF> +__global__ void SortedSegmentReductionCustomKernel( const Index input_outer_dim_size, const Index inner_dim_size, const Index output_outer_dim_size, const Index* __restrict__ segment_ids, const T* __restrict__ input, T* __restrict__ output, - const Index total_stripe_count) { + const Index total_stripe_count, const T initial_value) { for (int stripe_index : GpuGridRangeX(total_stripe_count)) { const Index segment_offset = stripe_index % inner_dim_size; const Index input_outer_dim_index_base = stripe_index / inner_dim_size * Index(OuterDimTileSize); - T sum = T(0); + T reduce_res = initial_value; Index first_segment_id = segment_ids[input_outer_dim_index_base]; Index last_output_segment_id = output_outer_dim_size; @@ -72,24 +73,25 @@ __global__ void SortedSegmentSumCustomKernel( for (Index j = 0; j < actual_stripe_height; j++) { Index current_output_segment_id = segment_ids[input_outer_dim_index_base + j]; - // Decide whether to write result to global memory. - // Result is only written to global memory if we move - // to another segment. 
Otherwise we can keep accumulating - // locally. + // Decide whether to write result to global memory. Result is only written + // to global memory if we move to another segment. Otherwise we can keep + // accumulating locally. if (current_output_segment_id > last_output_segment_id) { const Index output_index = last_output_segment_id * inner_dim_size + segment_offset; - // decide whether to write result to global memory using atomic - // operations + // Decide whether to write result to global memory using atomic + // operations. if (last_output_segment_id == first_segment_id) { - GpuAtomicAdd(output + output_index, sum); + AtomicReductionF()(output + output_index, reduce_res); } else { - *(output + output_index) = sum; + ReductionF()(output + output_index, reduce_res); } - sum = T(0); + reduce_res = initial_value; } - sum += ldg(input + (input_outer_dim_index_base + j) * inner_dim_size + - segment_offset); + ReductionF()( + &reduce_res, + ldg(input + (input_outer_dim_index_base + j) * inner_dim_size + + segment_offset)); last_output_segment_id = current_output_segment_id; } // For the last result in a strip, always write using atomic operations @@ -97,7 +99,7 @@ __global__ void SortedSegmentSumCustomKernel( // the following strip. const Index output_index = last_output_segment_id * inner_dim_size + segment_offset; - GpuAtomicAdd(output + output_index, sum); + AtomicReductionF()(output + output_index, reduce_res); } } @@ -126,25 +128,30 @@ __global__ void UnsortedSegmentCustomKernel( namespace functor { -template <typename T, typename Index> -void SegmentSumFunctor<T, Index>::operator()( - OpKernelContext* ctx, const GPUDevice& d, const Index output_rows, - const TensorShape& segment_ids_shape, - typename TTypes<Index>::ConstFlat segment_ids, const Index data_size, - const T* data, typename TTypes<T, 2>::Tensor output) { +template <typename T, typename Index, typename InitialValueF, + typename ReductionF, typename AtomicReductionF> +void SegmentReductionFunctor< + T, Index, InitialValueF, ReductionF, + AtomicReductionF>::operator()(OpKernelContext* ctx, const GPUDevice& d, + const Index output_rows, + const TensorShape& segment_ids_shape, + typename TTypes<Index>::ConstFlat segment_ids, + const Index data_size, const T* data, + typename TTypes<T, 2>::Tensor output) { if (output.size() == 0) { return; } - // Set 'output' to zeros. + // Set 'output' to initial value. GpuLaunchConfig config = GetGpuLaunchConfig(output.size(), d); - TF_CHECK_OK(GpuLaunchKernel(SetZero<T>, config.block_count, + const T InitialValue = InitialValueF()(); + TF_CHECK_OK(GpuLaunchKernel(SetToValue<T>, config.block_count, config.thread_per_block, 0, d.stream(), - output.size(), output.data())); + output.size(), output.data(), InitialValue)); if (data_size == 0 || segment_ids_shape.num_elements() == 0) { return; } - // Launch kernel to compute sorted segment sum. + // Launch kernel to compute sorted segment reduction. // Notes: // *) 'input_total_size' is the total number of elements to process. // *) 'segment_ids.shape' is a prefix of data's shape. 
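(Illustrative note, not part of the patch: the kernel comments above describe how the sorted segment reduction is parameterized by an initial-value functor plus a non-atomic and an atomic reduction functor. The sketch below is a minimal, single-threaded host-side C++ illustration of that functor composition only. `SortedSegmentReduce`, `Zero`, and `NonAtomicSumOp` here are simplified stand-ins defined for this example — not the patch's `SegmentReductionFunctor`, `functor::Zero`, or `functor::NonAtomicSumOpGpu` — and it deliberately omits the striping and atomic writes the real CUDA kernel needs at strip boundaries. Swapping in product/min/max functors with `One`/`Highest`/`Lowest` initial values gives the other ops in the same way.)

// Minimal host-side sketch of the InitialValueF / ReductionF composition.
#include <cstdio>
#include <vector>

template <typename T>
struct Zero {  // Initial value for sum-like reductions.
  T operator()() const { return T(0); }
};

template <typename T>
struct NonAtomicSumOp {  // Stand-in for a non-atomic reduction functor.
  void operator()(T* dest, const T& value) const { *dest += value; }
};

// Reduces `data` into one value per segment. `segment_ids` must be sorted,
// have the same length as `data`, and `num_segments` is 1 + max(segment_ids).
template <typename T, typename InitialValueF, typename ReductionF>
std::vector<T> SortedSegmentReduce(const std::vector<T>& data,
                                   const std::vector<int>& segment_ids,
                                   int num_segments) {
  // Initialize every output slot with the reduction's identity element.
  std::vector<T> output(num_segments, InitialValueF()());
  for (size_t i = 0; i < data.size(); ++i) {
    // Accumulate each element into the slot of the segment it belongs to.
    ReductionF()(&output[segment_ids[i]], data[i]);
  }
  return output;
}

int main() {
  std::vector<float> data = {1, 2, 3, 4, 5};
  std::vector<int> segment_ids = {0, 0, 1, 1, 1};  // Sorted; two segments.
  auto sums = SortedSegmentReduce<float, Zero<float>, NonAtomicSumOp<float>>(
      data, segment_ids, /*num_segments=*/2);
  std::printf("segment sums: %g %g\n", sums[0], sums[1]);  // Prints 3 and 12.
  return 0;
}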
@@ -163,10 +170,12 @@ void SegmentSumFunctor<T, Index>::operator()( config = GetGpuLaunchConfig(total_stripe_count, d); TF_CHECK_OK(GpuLaunchKernel( - SortedSegmentSumCustomKernel<T, Index, OuterDimTileSize>, + SortedSegmentReductionCustomKernel<T, Index, OuterDimTileSize, ReductionF, + AtomicReductionF>, config.block_count, config.thread_per_block, 0, d.stream(), input_outer_dim_size, input_inner_dim_size, output_rows, - segment_ids.data(), data, output.data(), total_stripe_count)); + segment_ids.data(), data, output.data(), total_stripe_count, + InitialValue)); } template <typename T, typename Index, typename InitialValueF, @@ -207,8 +216,19 @@ struct UnsortedSegmentFunctor<GPUDevice, T, Index, InitialValueF, ReductionF> { } }; -#define DEFINE_SORTED_GPU_SPECS_INDEX(T, Index) \ - template struct SegmentSumFunctor<T, Index> +#define DEFINE_SORTED_GPU_SPECS_INDEX(T, Index) \ + template struct SegmentReductionFunctor<T, Index, functor::Zero<T>, \ + functor::NonAtomicSumOpGpu<T>, \ + functor::AtomicSumOpGpu<T>>; \ + template struct SegmentReductionFunctor<T, Index, functor::One<T>, \ + functor::NonAtomicProdOpGpu<T>, \ + functor::AtomicProdOpGpu<T>>; \ + template struct SegmentReductionFunctor<T, Index, functor::Highest<T>, \ + functor::NonAtomicMinOpGpu<T>, \ + functor::AtomicMinOpGpu<T>>; \ + template struct SegmentReductionFunctor<T, Index, functor::Lowest<T>, \ + functor::NonAtomicMaxOpGpu<T>, \ + functor::AtomicMaxOpGpu<T>>; #define DEFINE_SORTED_GPU_SPECS(T) \ DEFINE_SORTED_GPU_SPECS_INDEX(T, int32); \ @@ -218,16 +238,16 @@ TF_CALL_GPU_NUMBER_TYPES(DEFINE_SORTED_GPU_SPECS); #define DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX(T, Index) \ template struct UnsortedSegmentFunctor< \ - GPUDevice, T, Index, functor::Lowest<T>, functor::MaxOpGpu<T>>; \ + GPUDevice, T, Index, functor::Lowest<T>, functor::AtomicMaxOpGpu<T>>; \ template struct UnsortedSegmentFunctor< \ - GPUDevice, T, Index, functor::Highest<T>, functor::MinOpGpu<T>>; \ + GPUDevice, T, Index, functor::Highest<T>, functor::AtomicMinOpGpu<T>>; \ template struct UnsortedSegmentFunctor<GPUDevice, T, Index, functor::One<T>, \ - functor::ProdOpGpu<T>>; + functor::AtomicProdOpGpu<T>>; -// sum is the only op that supports all input types currently +// Sum is the only op that supports all input types currently. #define DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX(T, Index) \ template struct UnsortedSegmentFunctor< \ - GPUDevice, T, Index, functor::Zero<T>, functor::SumOpGpu<T>>; + GPUDevice, T, Index, functor::Zero<T>, functor::AtomicSumOpGpu<T>>; #define DEFINE_REAL_GPU_SPECS(T) \ DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX(T, int32); \ diff --git a/tensorflow/core/kernels/segment_reduction_ops_impl.h b/tensorflow/core/kernels/segment_reduction_ops_impl.h index 7cf15ef5b72..3dd142d75e9 100644 --- a/tensorflow/core/kernels/segment_reduction_ops_impl.h +++ b/tensorflow/core/kernels/segment_reduction_ops_impl.h @@ -206,24 +206,26 @@ class SegmentReductionOp : public OpKernel { }; #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -// SegmentSumGPUOp is a segment sum operator implemented for GPU only. -// TODO: This implementation of SegmentSumGPUOp is sometimes slower than +// SegmentReductionGPUOp is a segment reduction operator implemented for GPU +// only. +// TODO: This implementation of SegmentReductionGPUOp is sometimes slower than // its unsorted counterpart (mostly when problem size is small). // This is due to the following two main reasons and a cost-effective way // to resolve these problems is desirable. -// 1. 
Sorted segment sum requires a memory transfer from device to host in -// order to know the size of the output dimension whereas unsorted segment -// sum receives the size of the output dimension as an input parameter. -// 2. Sorted segment sum is essentially a tiled version of unsorted segment -// sum and therefore such optimization comes at an inherent cost. However -// such cost may not be justified when the problem size is small. When to -// use the tiled version or the untiled version depends on many factors -// including data alignments, ratio of calculation to memory traffic and -// obviously, the problem sizes. -template <class T, class Index> -class SegmentSumGPUOp : public AsyncOpKernel { +// 1. Sorted segment reduction requires a memory transfer from device to host +// in order to know the size of the output dimension whereas unsorted +// segment reduction receives the size of the output dimension as an input +// parameter. +// 2. Sorted segment reduction is essentially a tiled version of unsorted +// segment reduction and therefore such optimization comes at an inherent +// cost. However such cost may not be justified when the problem size is +// small. When to use the tiled version or the untiled version depends on +// many factors including data alignments, ratio of calculation to memory +// traffic and obviously, the problem sizes. +template <class T, class Index, class SegmentReductionFunctor> +class SegmentReductionGPUOp : public AsyncOpKernel { public: - explicit SegmentSumGPUOp(OpKernelConstruction* context) + explicit SegmentReductionGPUOp(OpKernelConstruction* context) : AsyncOpKernel(context) {} void ComputeAsync(OpKernelContext* context, DoneCallback done) override { @@ -265,11 +267,11 @@ class SegmentSumGPUOp : public AsyncOpKernel { ->ThenMemcpy(output_rows_host.mutable_data(), output_rows_device, sizeof(Index)) .ok(), - errors::Internal( - "SegmentSumGPUOp: failed to copy output_rows from device"), + errors::Internal(type_string() + + ": failed to copy output_rows from device"), done); - functor::SegmentSumFunctor<T, Index> functor_; + SegmentReductionFunctor functor_; auto create_and_check_output = [context, output_rows_host, &input, &segment_ids, &functor_, done]() { // Ensure that within the callback, the proper GPU settings are diff --git a/tensorflow/core/kernels/segment_reduction_ops_impl_1.cc b/tensorflow/core/kernels/segment_reduction_ops_impl_1.cc index f71a8dac462..97c0762c36f 100644 --- a/tensorflow/core/kernels/segment_reduction_ops_impl_1.cc +++ b/tensorflow/core/kernels/segment_reduction_ops_impl_1.cc @@ -113,17 +113,39 @@ REGISTER_COMPLEX_CPU_KERNELS_ALL(complex128); #undef REGISTER_COMPLEX_CPU_KERNELS_ALL #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#define REGISTER_GPU_SORTED_KERNELS(type, index_type) \ - REGISTER_KERNEL_BUILDER(Name("SegmentSum") \ - .Device(DEVICE_GPU) \ - .TypeConstraint<type>("T") \ - .TypeConstraint<index_type>("Tindices"), \ - SegmentSumGPUOp<type, index_type>) +#define REGISTER_GPU_KERNEL_SORTEDSEGMENT( \ + name, type, index_type, initial_value_functor, reduction_kernel_functor, \ + atomic_reduction_kernel_functor) \ + REGISTER_KERNEL_BUILDER( \ + Name(name) \ + .Device(DEVICE_GPU) \ + .TypeConstraint<type>("T") \ + .TypeConstraint<index_type>("Tindices"), \ + SegmentReductionGPUOp< \ + type, index_type, \ + functor::SegmentReductionFunctor< \ + type, index_type, initial_value_functor, \ + reduction_kernel_functor, atomic_reduction_kernel_functor> >) + +#define REGISTER_GPU_SORTED_KERNELS(type, index_type) \ + 
REGISTER_GPU_KERNEL_SORTEDSEGMENT( \ + "SegmentSum", type, index_type, functor::Zero<type>, \ + functor::NonAtomicSumOpGpu<type>, functor::AtomicSumOpGpu<type>); \ + REGISTER_GPU_KERNEL_SORTEDSEGMENT( \ + "SegmentProd", type, index_type, functor::One<type>, \ + functor::NonAtomicProdOpGpu<type>, functor::AtomicProdOpGpu<type>); \ + REGISTER_GPU_KERNEL_SORTEDSEGMENT( \ + "SegmentMin", type, index_type, functor::Highest<type>, \ + functor::NonAtomicMinOpGpu<type>, functor::AtomicMinOpGpu<type>); \ + REGISTER_GPU_KERNEL_SORTEDSEGMENT( \ + "SegmentMax", type, index_type, functor::Lowest<type>, \ + functor::NonAtomicMaxOpGpu<type>, functor::AtomicMaxOpGpu<type>); #define REGISTER_GPU_SORTED_KERNELS_ALL(type) \ REGISTER_GPU_SORTED_KERNELS(type, int32) TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SORTED_KERNELS_ALL); +#undef REGISTER_GPU_KERNEL_SORTEDSEGMENT #undef REGISTER_GPU_SORTED_KERNELS #undef REGISTER_GPU_SORTED_KERNELS_ALL #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/segment_reduction_ops_impl_2.cc b/tensorflow/core/kernels/segment_reduction_ops_impl_2.cc index f2164260b8f..21b4ddf821d 100644 --- a/tensorflow/core/kernels/segment_reduction_ops_impl_2.cc +++ b/tensorflow/core/kernels/segment_reduction_ops_impl_2.cc @@ -63,17 +63,39 @@ REGISTER_COMPLEX_CPU_KERNELS_ALL(complex128); #undef REGISTER_COMPLEX_CPU_KERNELS_ALL #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#define REGISTER_GPU_SORTED_KERNELS(type, index_type) \ - REGISTER_KERNEL_BUILDER(Name("SegmentSum") \ - .Device(DEVICE_GPU) \ - .TypeConstraint<type>("T") \ - .TypeConstraint<index_type>("Tindices"), \ - SegmentSumGPUOp<type, index_type>) +#define REGISTER_GPU_KERNEL_SORTEDSEGMENT( \ + name, type, index_type, initial_value_functor, reduction_kernel_functor, \ + atomic_reduction_kernel_functor) \ + REGISTER_KERNEL_BUILDER( \ + Name(name) \ + .Device(DEVICE_GPU) \ + .TypeConstraint<type>("T") \ + .TypeConstraint<index_type>("Tindices"), \ + SegmentReductionGPUOp< \ + type, index_type, \ + functor::SegmentReductionFunctor< \ + type, index_type, initial_value_functor, \ + reduction_kernel_functor, atomic_reduction_kernel_functor> >) + +#define REGISTER_GPU_SORTED_KERNELS(type, index_type) \ + REGISTER_GPU_KERNEL_SORTEDSEGMENT( \ + "SegmentSum", type, index_type, functor::Zero<type>, \ + functor::NonAtomicSumOpGpu<type>, functor::AtomicSumOpGpu<type>); \ + REGISTER_GPU_KERNEL_SORTEDSEGMENT( \ + "SegmentProd", type, index_type, functor::One<type>, \ + functor::NonAtomicProdOpGpu<type>, functor::AtomicProdOpGpu<type>); \ + REGISTER_GPU_KERNEL_SORTEDSEGMENT( \ + "SegmentMin", type, index_type, functor::Highest<type>, \ + functor::NonAtomicMinOpGpu<type>, functor::AtomicMinOpGpu<type>); \ + REGISTER_GPU_KERNEL_SORTEDSEGMENT( \ + "SegmentMax", type, index_type, functor::Lowest<type>, \ + functor::NonAtomicMaxOpGpu<type>, functor::AtomicMaxOpGpu<type>); #define REGISTER_GPU_SORTED_KERNELS_ALL(type) \ REGISTER_GPU_SORTED_KERNELS(type, int64); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SORTED_KERNELS_ALL); +#undef REGISTER_GPU_KERNEL_SORTEDSEGMENT #undef REGISTER_GPU_SORTED_KERNELS #undef REGISTER_GPU_SORTED_KERNELS_ALL #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/segment_reduction_ops_impl_3.cc b/tensorflow/core/kernels/segment_reduction_ops_impl_3.cc index eef5a532b29..c809a5ed1c1 100644 --- a/tensorflow/core/kernels/segment_reduction_ops_impl_3.cc +++ b/tensorflow/core/kernels/segment_reduction_ops_impl_3.cc @@ -88,18 +88,18 @@ REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex128); 
#define REGISTER_REAL_GPU_UNSORTED_KERNELS(type, index_type) \ REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMax", type, index_type, \ functor::Lowest<type>, \ - functor::MaxOpGpu<type>); \ + functor::AtomicMaxOpGpu<type>); \ REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMin", type, index_type, \ functor::Highest<type>, \ - functor::MinOpGpu<type>); \ + functor::AtomicMinOpGpu<type>); \ REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", type, index_type, \ functor::One<type>, \ - functor::ProdOpGpu<type>); + functor::AtomicProdOpGpu<type>); #define REGISTER_SUM_GPU_UNSORTED_KERNELS(type, index_type) \ REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", type, index_type, \ functor::Zero<type>, \ - functor::SumOpGpu<type>); + functor::AtomicSumOpGpu<type>); #define REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL(type) \ REGISTER_REAL_GPU_UNSORTED_KERNELS(type, int32) diff --git a/tensorflow/core/kernels/segment_reduction_ops_impl_4.cc b/tensorflow/core/kernels/segment_reduction_ops_impl_4.cc index cad6f8a5e08..c47e8d171e5 100644 --- a/tensorflow/core/kernels/segment_reduction_ops_impl_4.cc +++ b/tensorflow/core/kernels/segment_reduction_ops_impl_4.cc @@ -88,18 +88,18 @@ REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex128); #define REGISTER_REAL_GPU_UNSORTED_KERNELS(type, index_type) \ REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMax", type, index_type, \ functor::Lowest<type>, \ - functor::MaxOpGpu<type>); \ + functor::AtomicMaxOpGpu<type>); \ REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMin", type, index_type, \ functor::Highest<type>, \ - functor::MinOpGpu<type>); \ + functor::AtomicMinOpGpu<type>); \ REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", type, index_type, \ functor::One<type>, \ - functor::ProdOpGpu<type>); + functor::AtomicProdOpGpu<type>); #define REGISTER_SUM_GPU_UNSORTED_KERNELS(type, index_type) \ REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", type, index_type, \ functor::Zero<type>, \ - functor::SumOpGpu<type>); + functor::AtomicSumOpGpu<type>); #define REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL(type) \ REGISTER_REAL_GPU_UNSORTED_KERNELS(type, int64) diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc index f2f6889c900..3113de527e4 100644 --- a/tensorflow/core/kernels/training_ops.cc +++ b/tensorflow/core/kernels/training_ops.cc @@ -376,6 +376,61 @@ inline T FtrlCompute(const T& accum, const T& linear, const T& lr, const T& l1, } } +template <typename T, typename GradTy, typename GradeMaybeWithShrinkageTy, + typename AccumTy, typename LinearTy, typename VarTy> +void ComputeFtrl(GradTy grad, + GradeMaybeWithShrinkageTy grad_maybe_with_shrinkage, + AccumTy accum, LinearTy linear, VarTy var, T l1_scalar, + T l2_scalar, bool multiply_linear_by_lr, T lr_power_scalar, + T lr_scalar) { + auto new_accum = accum + grad.square(); + if (multiply_linear_by_lr) { + if (lr_power_scalar == static_cast<T>(-0.5)) { + linear += grad_maybe_with_shrinkage * lr_scalar - + (new_accum.sqrt() - accum.sqrt()) * var; + } else { + linear += + grad_maybe_with_shrinkage * lr_scalar - + (new_accum.pow(-lr_power_scalar) - accum.pow(-lr_power_scalar)) * var; + } + } else { + if (lr_power_scalar == static_cast<T>(-0.5)) { + linear += grad_maybe_with_shrinkage - + (new_accum.sqrt() - accum.sqrt()) / lr_scalar * var; + } else { + linear += grad_maybe_with_shrinkage - (new_accum.pow(-lr_power_scalar) - + accum.pow(-lr_power_scalar)) / + lr_scalar * var; + } + } + auto l1_reg_adjust = + (multiply_linear_by_lr ? 
linear.cwiseMin(l1_scalar * lr_scalar) + .cwiseMax(-l1_scalar * lr_scalar) + : linear.cwiseMin(l1_scalar).cwiseMax(-l1_scalar)); + auto x = l1_reg_adjust - linear; + if (multiply_linear_by_lr) { + if (lr_power_scalar == static_cast<T>(-0.5)) { + auto y = new_accum.sqrt() + + linear.constant(static_cast<T>(2) * l2_scalar * lr_scalar); + var = x / y; + } else { + auto y = new_accum.pow(-lr_power_scalar) + + linear.constant(static_cast<T>(2) * l2_scalar * lr_scalar); + var = x / y; + } + } else { + if (lr_power_scalar == static_cast<T>(-0.5)) { + auto y = new_accum.sqrt() / new_accum.constant(lr_scalar) + + linear.constant(static_cast<T>(2) * l2_scalar); + var = x / y; + } else { + auto y = new_accum.pow(-lr_power_scalar) / new_accum.constant(lr_scalar) + + linear.constant(static_cast<T>(2) * l2_scalar); + var = x / y; + } + } + accum += grad.square(); +} } // namespace template <typename T, typename Tindex, bool has_l2_shrinkage> @@ -417,70 +472,25 @@ struct SparseApplyFtrl<CPUDevice, T, Tindex, has_l2_shrinkage> { auto grad = grad_flat.template chip<0>(i); auto var = var_flat.template chip<0>(index); -// TODO(sanjoy): Remove this macro. -// Use a macro to implement the computation here due to the templating of the -// eigen tensor library. -#define COMPUTE_FTRL(grad, grad_maybe_with_shrinkage) \ - auto new_accum = accum + grad.square(); \ - if (multiply_linear_by_lr) { \ - if (lr_power_scalar == static_cast<T>(-0.5)) { \ - linear += grad_maybe_with_shrinkage * lr_scalar - \ - (new_accum.sqrt() - accum.sqrt()) * var; \ - } else { \ - linear += \ - grad_maybe_with_shrinkage * lr_scalar - \ - (new_accum.pow(-lr_power_scalar) - accum.pow(-lr_power_scalar)) * \ - var; \ - } \ - } else { \ - if (lr_power_scalar == static_cast<T>(-0.5)) { \ - linear += grad_maybe_with_shrinkage - \ - (new_accum.sqrt() - accum.sqrt()) / lr_scalar * var; \ - } else { \ - linear += grad_maybe_with_shrinkage - (new_accum.pow(-lr_power_scalar) - \ - accum.pow(-lr_power_scalar)) / \ - lr_scalar * var; \ - } \ - } \ - auto l1_reg_adjust = \ - (multiply_linear_by_lr \ - ? 
linear.cwiseMin(l1_scalar * lr_scalar) \ - .cwiseMax(-l1_scalar * lr_scalar) \ - : linear.cwiseMin(l1_scalar).cwiseMax(-l1_scalar)); \ - auto x = l1_reg_adjust - linear; \ - if (multiply_linear_by_lr) { \ - if (lr_power_scalar == static_cast<T>(-0.5)) { \ - auto y = new_accum.sqrt() + \ - linear.constant(static_cast<T>(2) * l2_scalar * lr_scalar); \ - var = x / y; \ - } else { \ - auto y = new_accum.pow(-lr_power_scalar) + \ - linear.constant(static_cast<T>(2) * l2_scalar * lr_scalar); \ - var = x / y; \ - } \ - } else { \ - if (lr_power_scalar == static_cast<T>(-0.5)) { \ - auto y = new_accum.sqrt() / new_accum.constant(lr_scalar) + \ - linear.constant(static_cast<T>(2) * l2_scalar); \ - var = x / y; \ - } else { \ - auto y = \ - new_accum.pow(-lr_power_scalar) / new_accum.constant(lr_scalar) + \ - linear.constant(static_cast<T>(2) * l2_scalar); \ - var = x / y; \ - } \ - } \ - accum += grad.square(); - if (has_l2_shrinkage) { auto grad_with_shrinkage = grad + static_cast<T>(2) * l2_shrinkage_scalar * var; - COMPUTE_FTRL(grad, grad_with_shrinkage); + ComputeFtrl(/*grad=*/grad, + /*grad_maybe_with_shrinkage=*/grad_with_shrinkage, + /*accum=*/accum, /*linear=*/linear, /*var=*/var, + /*l1_scalar=*/l1_scalar, /*l2_scalar=*/l2_scalar, + /*multiply_linear_by_lr=*/multiply_linear_by_lr, + /*lr_power_scalar=*/lr_power_scalar, + /*lr_scalar=*/lr_scalar); } else { - COMPUTE_FTRL(grad, grad); + ComputeFtrl(/*grad=*/grad, /*grad_maybe_with_shrinkage=*/grad, + /*accum=*/accum, /*linear=*/linear, /*var=*/var, + /*l1_scalar=*/l1_scalar, /*l2_scalar=*/l2_scalar, + /*multiply_linear_by_lr=*/multiply_linear_by_lr, + /*lr_power_scalar=*/lr_power_scalar, + /*lr_scalar=*/lr_scalar); } } -#undef COMPUTE_FTRL } else { const Tindex first_dim_size = accum_flat.size(); diff --git a/tensorflow/core/lib/bmp/BUILD b/tensorflow/core/lib/bmp/BUILD index 186c3a0753f..3e21a027265 100644 --- a/tensorflow/core/lib/bmp/BUILD +++ b/tensorflow/core/lib/bmp/BUILD @@ -1,24 +1,12 @@ # Description: -# bmp test data packages. - -load("//tensorflow:tensorflow.bzl", "filegroup") +# bmp test data package alias. package( licenses = ["notice"], # Apache 2.0 ) -filegroup( +alias( name = "bmp_testdata", - srcs = [ - # BMP data - "testdata/lena.bmp", - "testdata/rgb_small.bmp", - "testdata/rgb_small_255.bmp", - "testdata/rgba_small.bmp", - "testdata/rgba_small_255.bmp", - "testdata/grayscale_small.bmp", - "testdata/grayscale_small_3channels.bmp", - "testdata/grayscale_small_4channels.bmp", - ], + actual = "//tensorflow/core/lib/bmp/testdata:bmp_testdata", visibility = ["//visibility:public"], ) diff --git a/tensorflow/core/lib/bmp/testdata/BUILD b/tensorflow/core/lib/bmp/testdata/BUILD new file mode 100644 index 00000000000..da11c447add --- /dev/null +++ b/tensorflow/core/lib/bmp/testdata/BUILD @@ -0,0 +1,24 @@ +# Description: +# bmp test data packages. 
+ +load("//tensorflow:tensorflow.bzl", "filegroup") + +package( + licenses = ["notice"], # Apache 2.0 +) + +filegroup( + name = "bmp_testdata", + srcs = [ + # BMP data + "lena.bmp", + "rgb_small.bmp", + "rgb_small_255.bmp", + "rgba_small.bmp", + "rgba_small_255.bmp", + "grayscale_small.bmp", + "grayscale_small_3channels.bmp", + "grayscale_small_4channels.bmp", + ], + visibility = ["//visibility:public"], +) diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index c4bd397ee60..e00ebf7325a 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -853,7 +853,18 @@ REGISTER_OP("OptionalFromValue") .Input("components: Toutput_types") .Output("optional: variant") .Attr("Toutput_types: list(type) >= 1") - .SetShapeFn(shape_inference::ScalarShape); + .SetShapeFn([](shape_inference::InferenceContext* c) { + std::vector<DataType> dtypes; + TF_RETURN_IF_ERROR(c->GetAttr("Toutput_types", &dtypes)); + c->set_output(0, c->Scalar()); + std::vector<shape_inference::ShapeAndType> shapes_and_types; + shapes_and_types.reserve(c->num_inputs()); + for (int i = 0; i < c->num_inputs(); ++i) { + shapes_and_types.emplace_back(c->input(i), dtypes[i], ST_OPTIONAL); + } + c->set_output_handle_shapes_and_types(0, shapes_and_types); + return Status::OK(); + }); REGISTER_OP("OptionalNone") .Output("optional: variant") diff --git a/tensorflow/core/platform/cloud/curl_http_request.cc b/tensorflow/core/platform/cloud/curl_http_request.cc index 10df2e12038..e7a060e73cc 100644 --- a/tensorflow/core/platform/cloud/curl_http_request.cc +++ b/tensorflow/core/platform/cloud/curl_http_request.cc @@ -609,7 +609,7 @@ int CurlHttpRequest::ProgressCallback(void* this_object, curl_off_t dltotal, double starttransfer_time = -1; const auto starttransfer_time_status = that->libcurl_->curl_easy_getinfo( - that->curl_, CURLINFO_PRETRANSFER_TIME, &starttransfer_time); + that->curl_, CURLINFO_STARTTRANSFER_TIME, &starttransfer_time); LOG(ERROR) << "The transmission of request " << this_object << " (URI: " << that->uri_ << ") has been stuck at " diff --git a/tensorflow/core/platform/default/BUILD b/tensorflow/core/platform/default/BUILD index 74370988d07..01fd9d8bb3d 100644 --- a/tensorflow/core/platform/default/BUILD +++ b/tensorflow/core/platform/default/BUILD @@ -202,6 +202,7 @@ cc_library( "//tensorflow/core/platform", "//tensorflow/core/platform:env_time", "//tensorflow/core/platform:macros", + "//tensorflow/core/platform:mutex", "//tensorflow/core/platform:types", "@com_google_absl//absl/base", "@com_google_absl//absl/strings", diff --git a/tensorflow/core/platform/default/logging.cc b/tensorflow/core/platform/default/logging.cc index 43d8545e6fd..b19c2630f23 100644 --- a/tensorflow/core/platform/default/logging.cc +++ b/tensorflow/core/platform/default/logging.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/base/internal/sysinfo.h" #include "tensorflow/core/platform/env_time.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" #if defined(PLATFORM_POSIX_ANDROID) #include <android/log.h> @@ -31,30 +32,135 @@ limitations under the License. #include <string.h> #include <time.h> -#include <string> +#include <algorithm> +#include <queue> #include <unordered_map> namespace tensorflow { -void TFAddLogSink(TFLogSink* sink) { - // LogSink is not implemented. - // If necessary, one can add the log sink support as follows. - // 1. Define a global vector<TFLogSink> to keep track of all registered - // TFLogSink objects. 
Protect the global vector with mutex to make it - thread-safe. - // 2. Add/remove elements from the global vector<TFLogSink> in TFAddLogSink - // and TFRemoveLogSink function - // 3. Add logic in LogMessage::GenerateLogMessage() below to dispatch log - // messages to all the registered log sinks. -} - -void TFRemoveLogSink(TFLogSink* sink) { - // LogSink is not implemented. -} - namespace internal { namespace { +// This is an internal singleton class that manages the log sinks. It allows +// adding and removing the log sinks, as well as dispatching log messages +// to all the registered log sinks. +class TFLogSinks { + public: + // Gets the TFLogSinks instance. This is the entry point for using this class. + static TFLogSinks& Instance(); + + // Adds a log sink. The sink argument must not be a nullptr. TFLogSinks + // takes ownership of the pointer, the user must not free the pointer. + // The pointer will remain valid until the application terminates or + // until TFLogSinks::Remove is called for the same pointer value. + void Add(TFLogSink* sink); + + // Removes a log sink. This will also erase the sink object. The pointer + // to the sink becomes invalid after this call. + void Remove(TFLogSink* sink); + + // Gets the currently registered log sinks. + std::vector<TFLogSink*> GetSinks() const; + + // Sends a log message to all registered log sinks. + // + // If no log sinks are registered: + // + // NO_DEFAULT_LOGGER is defined: + // Up to 128 messages will be queued until a log sink is added. + // The queue will then be logged to the first added log sink. + // + // NO_DEFAULT_LOGGER is not defined: + // The messages will be logged using the default logger. The default logger + // will log to stderr on all platforms except for Android. On Android the + // default Android logger will be used.
+ void Send(const TFLogEntry& entry); + + private: + TFLogSinks(); + void SendToSink(TFLogSink& sink, const TFLogEntry& entry); + + std::queue<TFLogEntry> log_entry_queue_; + static const size_t kMaxLogEntryQueueSize = 128; + + mutable tensorflow::mutex mutex_; + std::vector<TFLogSink*> sinks_; +}; + +TFLogSinks::TFLogSinks() { +#ifndef NO_DEFAULT_LOGGER + static TFDefaultLogSink* default_sink = new TFDefaultLogSink(); + sinks_.emplace_back(default_sink); +#endif +} + +TFLogSinks& TFLogSinks::Instance() { + static TFLogSinks* instance = new TFLogSinks(); + return *instance; +} + +void TFLogSinks::Add(TFLogSink* sink) { + assert(sink != nullptr && "The sink must not be a nullptr"); + + tensorflow::mutex_lock lock(mutex_); + sinks_.emplace_back(sink); + + // If this is the only sink log all the queued up messages to this sink + if (sinks_.size() == 1) { + while (!log_entry_queue_.empty()) { + for (const auto& sink : sinks_) { + SendToSink(*sink, log_entry_queue_.front()); + } + log_entry_queue_.pop(); + } + } +} + +void TFLogSinks::Remove(TFLogSink* sink) { + assert(sink != nullptr && "The sink must not be a nullptr"); + + tensorflow::mutex_lock lock(mutex_); + auto it = std::find(sinks_.begin(), sinks_.end(), sink); + if (it != sinks_.end()) sinks_.erase(it); +} + +std::vector<TFLogSink*> TFLogSinks::GetSinks() const { + tensorflow::mutex_lock lock(mutex_); + return sinks_; +} + +void TFLogSinks::Send(const TFLogEntry& entry) { + tensorflow::mutex_lock lock(mutex_); + + // If we don't have any sinks registered, queue them up + if (sinks_.empty()) { + // If we've exceeded the maximum queue size, drop the oldest entries + while (log_entry_queue_.size() >= kMaxLogEntryQueueSize) { + log_entry_queue_.pop(); + } + log_entry_queue_.push(entry); + return; + } + + // If we have items in the queue, push them out first + while (!log_entry_queue_.empty()) { + for (const auto& sink : sinks_) { + SendToSink(*sink, log_entry_queue_.front()); + } + log_entry_queue_.pop(); + } + + // ... and now we can log the current log entry + for (const auto& sink : sinks_) { + SendToSink(*sink, entry); + } +} + +void TFLogSinks::SendToSink(TFLogSink& sink, const TFLogEntry& entry) { + sink.Send(entry); + sink.WaitTillSent(); +} + int ParseInteger(const char* str, size_t size) { // Ideally we would use env_var / safe_strto64, but it is // hard to use here without pulling in a lot of dependencies, @@ -197,70 +303,10 @@ LogMessage::~LogMessage() { } } -#if defined(PLATFORM_POSIX_ANDROID) void LogMessage::GenerateLogMessage() { - int android_log_level; - switch (severity_) { - case INFO: - android_log_level = ANDROID_LOG_INFO; - break; - case WARNING: - android_log_level = ANDROID_LOG_WARN; - break; - case ERROR: - android_log_level = ANDROID_LOG_ERROR; - break; - case FATAL: - android_log_level = ANDROID_LOG_FATAL; - break; - default: - if (severity_ < INFO) { - android_log_level = ANDROID_LOG_VERBOSE; - } else { - android_log_level = ANDROID_LOG_ERROR; - } - break; - } - - std::stringstream ss; - const char* const partial_name = strrchr(fname_, '/'); - ss << (partial_name != nullptr ? partial_name + 1 : fname_) << ":" << line_ - << " " << str(); - __android_log_write(android_log_level, "native", ss.str().c_str()); - - // Also log to stderr (for standalone Android apps). - fprintf(stderr, "native : %s\n", ss.str().c_str()); - - // Android logging at level FATAL does not terminate execution, so abort() - // is still required to stop the program. 
- if (severity_ == FATAL) { - abort(); - } + TFLogSinks::Instance().Send(TFLogEntry(severity_, fname_, line_, str())); } -#else - -void LogMessage::GenerateLogMessage() { - static bool log_thread_id = EmitThreadIdFromEnv(); - uint64 now_micros = EnvTime::NowMicros(); - time_t now_seconds = static_cast<time_t>(now_micros / 1000000); - int32 micros_remainder = static_cast<int32>(now_micros % 1000000); - const size_t time_buffer_size = 30; - char time_buffer[time_buffer_size]; - strftime(time_buffer, time_buffer_size, "%Y-%m-%d %H:%M:%S", - localtime(&now_seconds)); - const size_t tid_buffer_size = 10; - char tid_buffer[tid_buffer_size] = ""; - if (log_thread_id) { - snprintf(tid_buffer, sizeof(tid_buffer), " %7u", - absl::base_internal::GetTID()); - } - // TODO(jeff,sanjay): Replace this with something that logs through the env. - fprintf(stderr, "%s.%06d: %c%s %s:%d] %s\n", time_buffer, micros_remainder, - "IWEF"[severity_], tid_buffer, fname_, line_, str().c_str()); -} -#endif - int64 LogMessage::MinVLogLevel() { static int64 min_vlog_level = MinVLogLevelFromEnv(); return min_vlog_level; @@ -393,4 +439,104 @@ bool LogEveryNSecState::ShouldLog(double seconds) { } } // namespace internal + +void TFAddLogSink(TFLogSink* sink) { + internal::TFLogSinks::Instance().Add(sink); +} + +void TFRemoveLogSink(TFLogSink* sink) { + internal::TFLogSinks::Instance().Remove(sink); +} + +std::vector<TFLogSink*> TFGetLogSinks() { + return internal::TFLogSinks::Instance().GetSinks(); +} + +void TFDefaultLogSink::Send(const TFLogEntry& entry) { +#ifdef PLATFORM_POSIX_ANDROID + int android_log_level; + switch (entry.log_severity()) { + case absl::LogSeverity::kInfo: + android_log_level = ANDROID_LOG_INFO; + break; + case absl::LogSeverity::kWarning: + android_log_level = ANDROID_LOG_WARN; + break; + case absl::LogSeverity::kError: + android_log_level = ANDROID_LOG_ERROR; + break; + case absl::LogSeverity::kFatal: + android_log_level = ANDROID_LOG_FATAL; + break; + default: + if (entry.log_severity() < absl::LogSeverity::kInfo) { + android_log_level = ANDROID_LOG_VERBOSE; + } else { + android_log_level = ANDROID_LOG_ERROR; + } + break; + } + + std::stringstream ss; + const auto& fname = entry.FName(); + auto pos = fname.find("/"); + ss << (pos != std::string::npos ? fname.substr(pos + 1) : fname) << ":" + << entry.Line() << " " << entry.ToString(); + __android_log_write(android_log_level, "native", ss.str().c_str()); + + // Also log to stderr (for standalone Android apps). + std::cerr << "native : " << ss.str() << std::endl; + + // Android logging at level FATAL does not terminate execution, so abort() + // is still required to stop the program. 
+ if (entry.log_severity() == absl::LogSeverity::kFatal) { + abort(); + } +#else // PLATFORM_POSIX_ANDROID + static bool log_thread_id = internal::EmitThreadIdFromEnv(); + uint64 now_micros = EnvTime::NowMicros(); + time_t now_seconds = static_cast<time_t>(now_micros / 1000000); + int32 micros_remainder = static_cast<int32>(now_micros % 1000000); + const size_t time_buffer_size = 30; + char time_buffer[time_buffer_size]; + strftime(time_buffer, time_buffer_size, "%Y-%m-%d %H:%M:%S", + localtime(&now_seconds)); + const size_t tid_buffer_size = 10; + char tid_buffer[tid_buffer_size] = ""; + if (log_thread_id) { + snprintf(tid_buffer, sizeof(tid_buffer), " %7u", + absl::base_internal::GetTID()); + } + + char sev; + switch (entry.log_severity()) { + case absl::LogSeverity::kInfo: + sev = 'I'; + break; + + case absl::LogSeverity::kWarning: + sev = 'W'; + break; + + case absl::LogSeverity::kError: + sev = 'E'; + break; + + case absl::LogSeverity::kFatal: + sev = 'F'; + break; + + default: + assert(false && "Unknown logging severity"); + sev = '?'; + break; + } + + // TODO(jeff,sanjay): Replace this with something that logs through the env. + fprintf(stderr, "%s.%06d: %c%s %s:%d] %s\n", time_buffer, micros_remainder, + sev, tid_buffer, entry.FName().c_str(), entry.Line(), + entry.ToString().c_str()); +#endif // PLATFORM_POSIX_ANDROID +} + } // namespace tensorflow diff --git a/tensorflow/core/platform/default/logging.h b/tensorflow/core/platform/default/logging.h index f60deb43683..56dd19c8aa0 100644 --- a/tensorflow/core/platform/default/logging.h +++ b/tensorflow/core/platform/default/logging.h @@ -23,6 +23,7 @@ limitations under the License. #include <limits> #include <memory> #include <sstream> +#include <vector> #include "absl/base/log_severity.h" #include "absl/strings/string_view.h" @@ -477,15 +478,27 @@ class TFLogEntry { } public: - explicit TFLogEntry(int severity, absl::string_view log_line) - : severity_(AsAbslLogSeverity(severity)), log_line_(log_line) {} + explicit TFLogEntry(int severity, absl::string_view message) + : severity_(AsAbslLogSeverity(severity)), message_(message) {} + + explicit TFLogEntry(int severity, absl::string_view fname, int line, + absl::string_view message) + : severity_(AsAbslLogSeverity(severity)), + fname_(fname), + line_(line), + message_(message) {} absl::LogSeverity log_severity() const { return severity_; } - std::string ToString() const { return std::string(log_line_); } + std::string FName() const { return fname_; } + int Line() const { return line_; } + std::string ToString() const { return message_; } + absl::string_view text_message() const { return message_; } private: const absl::LogSeverity severity_; - const absl::string_view log_line_; + const std::string fname_; + int line_ = -1; + const std::string message_; }; class TFLogSink { @@ -513,10 +526,23 @@ class TFLogSink { virtual void WaitTillSent() {} }; +// This is the default log sink. This log sink is used if there are no other +// log sinks registered. To disable the default log sink, set the +// "no_default_logger" Bazel config setting to true or define a +// NO_DEFAULT_LOGGER preprocessor symbol. This log sink will always log to +// stderr. +class TFDefaultLogSink : public TFLogSink { + public: + void Send(const TFLogEntry& entry) override; +}; + // Add or remove a `LogSink` as a consumer of logging data. Thread-safe. void TFAddLogSink(TFLogSink* sink); void TFRemoveLogSink(TFLogSink* sink); +// Get all the log sinks. Thread-safe. 
+std::vector<TFLogSink*> TFGetLogSinks(); + } // namespace tensorflow #endif // TENSORFLOW_CORE_PLATFORM_DEFAULT_LOGGING_H_ diff --git a/tensorflow/core/platform/errors.h b/tensorflow/core/platform/errors.h index 55af45a4c24..fbd3b518699 100644 --- a/tensorflow/core/platform/errors.h +++ b/tensorflow/core/platform/errors.h @@ -44,7 +44,7 @@ namespace internal { // Eventually absl::strings will have native support for this and we will be // able to completely remove PrepareForStrCat(). template <typename T> -typename std::enable_if<!std::is_constructible<strings::AlphaNum, T>::value, +typename std::enable_if<!std::is_convertible<T, strings::AlphaNum>::value, std::string>::type PrepareForStrCat(const T& t) { std::stringstream ss; diff --git a/tensorflow/core/platform/logging_test.cc b/tensorflow/core/platform/logging_test.cc index 75da1a4e484..54890bdd2f0 100644 --- a/tensorflow/core/platform/logging_test.cc +++ b/tensorflow/core/platform/logging_test.cc @@ -14,6 +14,10 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/platform/logging.h" + +#include <sstream> +#include <vector> + #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -96,4 +100,33 @@ TEST(InternalLogString, Basic) { internal::LogString(__FILE__, __LINE__, INFO, "Hello there"); } +class TestSink : public TFLogSink { + public: + void Send(const TFLogEntry& entry) override { + ss_ << entry.text_message() << std::endl; + } + + std::string Get() const { return ss_.str(); } + + private: + std::stringstream ss_; +}; + +TEST(LogSinkTest, testLogSinks) { + const int sinks_initial_size = TFGetLogSinks().size(); + TestSink sink; + + TFAddLogSink(&sink); + + EXPECT_EQ(TFGetLogSinks().size(), sinks_initial_size + 1); + + LOG(INFO) << "Foo"; + LOG(INFO) << "Bar"; + EXPECT_EQ(sink.Get(), "Foo\nBar\n"); + + TFRemoveLogSink(&sink); + + EXPECT_EQ(TFGetLogSinks().size(), sinks_initial_size); +} + } // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc index e1dbaed1746..b33cd911afa 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc @@ -210,12 +210,9 @@ OpMetricsDb ConvertHostThreadsXPlaneToOpMetricsDb(const XPlane& host_trace) { return result; } -OpMetricsDb ConvertDeviceTraceXPlaneToOpMetricsDb( - const XPlane& device_trace, double peak_tera_flops_per_second, - double peak_hbm_bw_giga_bytes_per_second) { +OpMetricsDb ConvertDeviceTraceXPlaneToOpMetricsDb(const XPlane& device_trace) { OpMetricsDb result; - DeviceOpMetricsDbBuilder device_op_metrics_db_builder( - &result, peak_tera_flops_per_second, peak_hbm_bw_giga_bytes_per_second); + DeviceOpMetricsDbBuilder device_op_metrics_db_builder(&result); int64 first_op_offset_ps = kint64max; int64 last_op_offset_ps = 0; diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h index f2d7fc702fc..93ab7339a37 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h +++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h @@ -50,9 +50,7 @@ void ConsumeTfMetricsDbData(TfMetricsDbData src, OpMetricsDbCombiner* dst); OpMetricsDb ConvertHostThreadsXPlaneToOpMetricsDb(const XPlane& host_trace); -OpMetricsDb ConvertDeviceTraceXPlaneToOpMetricsDb( - const XPlane& device_trace, double peak_tera_flops_per_second, - double 
peak_hbm_bw_giga_bytes_per_second); +OpMetricsDb ConvertDeviceTraceXPlaneToOpMetricsDb(const XPlane& device_trace); } // namespace profiler } // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc index 7d6f23db041..7be6277ba58 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc @@ -132,9 +132,7 @@ TEST(ConvertXPlaneToOpMetricsDb, DeviceOpMetricsDb) { kKernel3DurationNs, /*on_device=*/true, kKernel3, &device_plane, &stream2); - OpMetricsDb op_metrics = ConvertDeviceTraceXPlaneToOpMetricsDb( - *xplane, /*peak_tera_flops_per_second=*/0, - /*peak_hbm_bw_giga_bytes_per_second=*/0); + OpMetricsDb op_metrics = ConvertDeviceTraceXPlaneToOpMetricsDb(*xplane); // kernel1, kernel2, kernel3, Idle. EXPECT_EQ(4, op_metrics.metrics_db_size()); diff --git a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc b/tensorflow/core/profiler/convert/xplane_to_op_stats.cc index 6eb67eab216..e6b84b6b873 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_stats.cc @@ -170,10 +170,8 @@ OpStats ConvertXSpaceToOpStats(const XSpace& space, if (!op_stats.has_perf_env()) { *op_stats.mutable_perf_env() = GetPerfEnvFromXPlane(*device_trace); } - const PerfEnv& perf_env = op_stats.perf_env(); - OpMetricsDb device_op_metrics_db = ConvertDeviceTraceXPlaneToOpMetricsDb( - *device_trace, perf_env.peak_tera_flops_per_second(), - perf_env.peak_hbm_bw_giga_bytes_per_second()); + OpMetricsDb device_op_metrics_db = + ConvertDeviceTraceXPlaneToOpMetricsDb(*device_trace); op_metrics_db_combiner.Combine(device_op_metrics_db); } if (options.generate_step_db) { diff --git a/tensorflow/core/profiler/protobuf/xplane.proto b/tensorflow/core/profiler/protobuf/xplane.proto index f57d7609891..3936c44227b 100644 --- a/tensorflow/core/profiler/protobuf/xplane.proto +++ b/tensorflow/core/profiler/protobuf/xplane.proto @@ -118,7 +118,7 @@ message XStat { // Metadata for an XEvent, corresponds to an event type and is shared by // all XEvents with the same metadata_id. -// Next ID: 6 +// Next ID: 7 message XEventMetadata { // XPlane.event_metadata map key. int64 id = 1; @@ -135,6 +135,9 @@ message XEventMetadata { // XStats that are constant for all XEvents with the same metadata_id. // Each of these XStats should have a different metadata_id. repeated XStat stats = 5; + + // XPlane.event_metadata map key for children events. 
+ repeated int64 child_id = 6; } // Metadata for an XStat, corresponds to a stat type and is shared by all diff --git a/tensorflow/core/profiler/rpc/client/BUILD b/tensorflow/core/profiler/rpc/client/BUILD index ea37a8852a0..9bf0179859d 100644 --- a/tensorflow/core/profiler/rpc/client/BUILD +++ b/tensorflow/core/profiler/rpc/client/BUILD @@ -117,6 +117,7 @@ cc_library( tf_cc_test( name = "profiler_client_test", srcs = ["profiler_client_test.cc"], + tags = ["notap"], # b/173824689 deps = [ ":profiler_client", ":profiler_client_impl", # for oss diff --git a/tensorflow/core/profiler/utils/op_utils.h b/tensorflow/core/profiler/utils/op_utils.h index 0aa606f0144..bb7c3103de5 100644 --- a/tensorflow/core/profiler/utils/op_utils.h +++ b/tensorflow/core/profiler/utils/op_utils.h @@ -49,12 +49,7 @@ class HostOpMetricsDbBuilder : public OpMetricsDbBuilder { class DeviceOpMetricsDbBuilder : public OpMetricsDbBuilder { public: - explicit DeviceOpMetricsDbBuilder(OpMetricsDb* db, - double peak_tera_flops_per_second, - double peak_hbm_bw_giga_bytes_per_second) - : OpMetricsDbBuilder(db), - peak_tera_flops_per_second_(peak_tera_flops_per_second), - peak_hbm_bw_giga_bytes_per_second_(peak_hbm_bw_giga_bytes_per_second) {} + explicit DeviceOpMetricsDbBuilder(OpMetricsDb* db) : OpMetricsDbBuilder(db) {} // A function that will be called when the end of an OP is // observed on a trace, where: @@ -78,12 +73,6 @@ class DeviceOpMetricsDbBuilder : public OpMetricsDbBuilder { uint64 children_time_ps, int64 flops, int64 bytes_accessed, const protobuf::RepeatedPtrField<OpMetrics::MemoryAccessed>& memory_accessed_breakdown = {}); - - protected: - // Peak performance of a TensorCore or a GPU in TFLOP/s. - double peak_tera_flops_per_second_; - // Peak memory bandwidth of a TensorCore or a GPU in GiBs/s. - double peak_hbm_bw_giga_bytes_per_second_; }; } // namespace profiler diff --git a/tensorflow/core/profiler/utils/xplane_visitor.h b/tensorflow/core/profiler/utils/xplane_visitor.h index e7ac97f3098..73dd8db577a 100644 --- a/tensorflow/core/profiler/utils/xplane_visitor.h +++ b/tensorflow/core/profiler/utils/xplane_visitor.h @@ -83,15 +83,15 @@ class XStatVisitor { template <class T> class XStatsOwner { public: - // REQUIRED: metadata and stats_owner cannot be nullptr. - XStatsOwner(const XPlaneVisitor* metadata, const T* stats_owner) - : stats_owner_(stats_owner), metadata_(metadata) {} + // REQUIRED: plane and stats_owner cannot be nullptr. + XStatsOwner(const XPlaneVisitor* plane, const T* stats_owner) + : plane_(plane), stats_owner_(stats_owner) {} - // For each plane level stats, call the specified lambda. + // For each stat, call the specified lambda. template <typename ForEachStatFunc> void ForEachStat(ForEachStatFunc&& for_each_stat) const { for (const XStat& stat : stats_owner_->stats()) { - for_each_stat(XStatVisitor(metadata_, &stat)); + for_each_stat(XStatVisitor(plane_, &stat)); } } @@ -100,12 +100,35 @@ class XStatsOwner { // Prefer ForEachStat above when multiple stat values are necessary. absl::optional<XStatVisitor> GetStat(int64 stat_type) const; + protected: + const XPlaneVisitor* plane() const { return plane_; } + const T* stats_owner() const { return stats_owner_; } + private: + const XPlaneVisitor* plane_; const T* stats_owner_; - const XPlaneVisitor* metadata_; }; -using XEventMetadataVisitor = XStatsOwner<XEventMetadata>; +class XEventMetadataVisitor : public XStatsOwner<XEventMetadata> { + public: + // REQUIRED: plane and metadata cannot be nullptr. 
+ XEventMetadataVisitor(const XPlaneVisitor* plane, + const XEventMetadata* metadata) + : XStatsOwner(plane, metadata) {} + + absl::string_view Name() const { return metadata()->name(); } + + bool HasDisplayName() const { return !metadata()->display_name().empty(); } + + absl::string_view DisplayName() const { return metadata()->display_name(); } + + // For each child event metadata, call the specified lambda. + template <typename ForEachChildFunc> + void ForEachChild(ForEachChildFunc&& for_each_child) const; + + private: + const XEventMetadata* metadata() const { return stats_owner(); } +}; class XEventVisitor : public XStatsOwner<XEvent> { public: @@ -113,10 +136,6 @@ class XEventVisitor : public XStatsOwner<XEvent> { XEventVisitor(const XPlaneVisitor* plane, const XLine* line, const XEvent* event); - XEventMetadataVisitor MetadataStats() const { - return XEventMetadataVisitor(plane_, metadata_); - } - int64 Id() const { return event_->metadata_id(); } absl::string_view Name() const { return metadata_->name(); } @@ -127,8 +146,6 @@ class XEventVisitor : public XStatsOwner<XEvent> { absl::string_view DisplayName() const { return metadata_->display_name(); } - absl::string_view Metadata() const { return metadata_->metadata(); } - double OffsetNs() const { return PicosToNanos(event_->offset_ps()); } int64 OffsetPs() const { return event_->offset_ps(); } @@ -157,6 +174,10 @@ class XEventVisitor : public XStatsOwner<XEvent> { const XEventMetadata* metadata() const { return metadata_; } + XEventMetadataVisitor Metadata() const { + return XEventMetadataVisitor(plane_, metadata_); + } + Timespan GetTimespan() const { return Timespan(TimestampPs(), DurationPs()); } private: @@ -262,17 +283,28 @@ class XPlaneVisitor : public XStatsOwner<XPlane> { template <class T> absl::optional<XStatVisitor> XStatsOwner<T>::GetStat(int64 stat_type) const { - const auto* stat_metadata = metadata_->GetStatMetadataByType(stat_type); + const auto* stat_metadata = plane_->GetStatMetadataByType(stat_type); if (stat_metadata != nullptr) { for (const XStat& stat : stats_owner_->stats()) { if (stat.metadata_id() == stat_metadata->id()) { - return XStatVisitor(metadata_, &stat, stat_metadata, stat_type); + return XStatVisitor(plane_, &stat, stat_metadata, stat_type); } } } return absl::nullopt; // type does not exist in this owner. 
} +template <typename ForEachChildFunc> +void XEventMetadataVisitor::ForEachChild( + ForEachChildFunc&& for_each_child) const { + for (int64 child_id : metadata()->child_id()) { + const auto* event_metadata = plane()->GetEventMetadata(child_id); + if (event_metadata != nullptr) { + for_each_child(XEventMetadataVisitor(plane(), event_metadata)); + } + } +} + } // namespace profiler } // namespace tensorflow diff --git a/tensorflow/core/protobuf/BUILD b/tensorflow/core/protobuf/BUILD index d1cbf872087..ef968563ba8 100644 --- a/tensorflow/core/protobuf/BUILD +++ b/tensorflow/core/protobuf/BUILD @@ -141,6 +141,7 @@ exports_files( "snapshot.proto", "service_config.proto", "debug_event.proto", + "extension_type_variant.proto", "meta_graph.proto", "named_tensor.proto", "remote_tensor_handle.proto", @@ -170,6 +171,7 @@ tf_proto_library( "snapshot.proto", "service_config.proto", "debug_event.proto", + "extension_type_variant.proto", "meta_graph.proto", "named_tensor.proto", "remote_tensor_handle.proto", diff --git a/tensorflow/core/protobuf/extension_type_variant.proto b/tensorflow/core/protobuf/extension_type_variant.proto new file mode 100644 index 00000000000..536db3b2435 --- /dev/null +++ b/tensorflow/core/protobuf/extension_type_variant.proto @@ -0,0 +1,14 @@ +syntax = "proto3"; + +package tensorflow; + +import "tensorflow/core/protobuf/struct.proto"; + +// Metadata for ExtensionTypeVariant, used when serializing as Variant. +// +// We define a new message here (rather than directly using TypeSpecProto for +// the metadata string) to retain flexibility to change the metadata encoding +// to support additional features. +message ExtensionTypeVariantMetadata { + TypeSpecProto type_spec_proto = 1; +} diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 20e650de6d1..41478958118 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 589 // Updated: 2020/11/18 +#define TF_GRAPH_DEF_VERSION 594 // Updated: 2020/11/23 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc index bd20924ff23..ca85e3fb076 100644 --- a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc +++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc @@ -653,14 +653,75 @@ struct ShardedInputInfo { std::vector<NodeOut> sharded_inputs; }; +// Adds pad node after split node to graph for uneven sharding tiled inputs. +// |graph| owns the returned Node* instance. +xla::StatusOr<Node*> CreatePadNode(const int padding, const int num_dims, + const int split_dim, DataType dtype, + Node* control_predecessor, Node* split_node, + const int split_index, Graph* graph) { + // Add paddings node. 
+ Status s; + NodeDef paddings_def; + paddings_def.set_name( + graph->NewName(absl::StrCat(split_node->name(), "/paddings"))); + paddings_def.set_op("Const"); + AddNodeAttr("dtype", DT_INT32, &paddings_def); + paddings_def.set_device(split_node->assigned_device_name()); + TensorProto sizes_tensor_proto; + sizes_tensor_proto.set_dtype(DT_INT32); + for (int i = 0; i < num_dims; ++i) { + sizes_tensor_proto.add_int_val(0); + if (i == split_dim) { + sizes_tensor_proto.add_int_val(padding); + } else { + sizes_tensor_proto.add_int_val(0); + } + } + TensorShape sizes_shape({num_dims, 2}); + sizes_shape.AsProto(sizes_tensor_proto.mutable_tensor_shape()); + AddNodeAttr("value", sizes_tensor_proto, &paddings_def); + Node* paddings_node = graph->AddNode(paddings_def, &s); + TF_RETURN_IF_ERROR(s); + + // Add Pad node. + NodeDef pad_def; + pad_def.set_name(graph->NewName( + absl::StrCat(split_node->name(), "/pad_shard_", split_index))); + pad_def.set_op("Pad"); + pad_def.set_device(split_node->assigned_device_name()); + AddNodeAttr("T", dtype, &pad_def); + AddNodeAttr("Tpaddings", DT_INT32, &pad_def); + pad_def.add_input(absl::StrCat(split_node->name(), ":", split_index)); + pad_def.add_input(absl::StrCat(paddings_node->name(), ":0")); + Node* pad_node = graph->AddNode(pad_def, &s); + pad_node->set_assigned_device_name(split_node->assigned_device_name()); + TF_RETURN_IF_ERROR(s); + // Add edges for pad node. + graph->AddEdge(split_node, split_index, pad_node, 0); + graph->AddEdge(paddings_node, 0, pad_node, 1); + graph->AddControlEdge(control_predecessor, pad_node); + return pad_node; +} + // Adds split node and split dimension node to graph for sharding tiled inputs. // |graph| owns the returned Node* instance. -xla::StatusOr<Node*> CreateSplitNode(int num_splits, int dim, - int orig_src_output, DataType dtype, +xla::StatusOr<Node*> CreateSplitNode(const int num_splits, const int dim, + const int num_dims, const int64 padding, + const int orig_src_output, DataType dtype, absl::string_view name_prefix, Node* control_predecessor, Node* orig_src, Graph* graph) { const std::string input_assigned_device = orig_src->assigned_device_name(); + Node* to_split_node = orig_src; + int to_split_index = orig_src_output; + if (padding > 0) { + TF_ASSIGN_OR_RETURN( + Node * pad_node, + CreatePadNode(padding, num_dims, dim, dtype, control_predecessor, + orig_src, orig_src_output, graph)); + to_split_node = pad_node; + to_split_index = 0; + } // Add a split dimension node. NodeDef split_dim_def; @@ -686,28 +747,48 @@ xla::StatusOr<Node*> CreateSplitNode(int num_splits, int dim, AddNodeAttr("num_split", num_splits, &split_def); AddNodeAttr("T", dtype, &split_def); split_def.add_input(absl::StrCat(split_dim_node->name(), ":0")); - split_def.add_input(absl::StrCat(orig_src->name(), ":", orig_src_output)); + split_def.add_input(absl::StrCat(to_split_node->name(), ":", to_split_index)); Node* split_node = graph->AddNode(split_def, &s); - split_node->set_assigned_device_name(input_assigned_device); TF_RETURN_IF_ERROR(s); + split_node->set_assigned_device_name(input_assigned_device); + + // If colocate the newly created split op to source node of input to TPU + // computation. 
+ split_node->AddAttr(kColocationAttrName, + std::vector<string>{absl::StrCat(kColocationGroupPrefix, + orig_src->name())}); + graph->AddEdge(split_dim_node, 0, split_node, 0); - graph->AddEdge(orig_src, orig_src_output, split_node, 1); + graph->AddEdge(to_split_node, to_split_index, split_node, 1); // Add a control dependency from `control_predecessor` to newly created // constant node. This ensures that newly added split/split dim // nodes are placed inside correct while loop frames when TPUExecute // node is inside a host training loop. graph->AddControlEdge(control_predecessor, split_dim_node); - return split_node; } +int64 GetPadding(const int split_dim, const int num_splits, + const PartialTensorShape& partial_tensor_shape) { + // If dim dimension is not defined, no uneven sharding support. + if (partial_tensor_shape.dim_size(split_dim) <= 0) { + return 0; + } + int64 per_split_size = tensorflow::MathUtil::CeilOfRatio<int64>( + partial_tensor_shape.dim_size(split_dim), num_splits); + int64 total_padding = + per_split_size * num_splits - partial_tensor_shape.dim_size(split_dim); + return total_padding; +} + // Creates a set of splits nodes that shards tiled input node in graph. xla::StatusOr<ShardedInputInfo> CreateOrGetSplitNodesForInputSharding( const xla::OpSharding& sharding, int orig_arg_num, DataType dtype, - int replica_id, int orig_src_output, Node* orig_src, - Node* control_predecessor, Graph* graph, + const PartialTensorShape& partial_tensor_shape, int replica_id, + int orig_src_output, Node* orig_src, Node* control_predecessor, + Graph* graph, std::map<ShardedInputIndex, ShardedInputInfo>* arg_index_to_sharded_input_map) { ShardedInputIndex input_index{replica_id, orig_arg_num}; @@ -738,6 +819,7 @@ xla::StatusOr<ShardedInputInfo> CreateOrGetSplitNodesForInputSharding( auto sharding_it = split_dimension_map.begin(); std::queue<Node*> split_nodes_for_dimension; + absl::flat_hash_map<Node*, int> node_to_split_dim; int split_dimension = sharding_it->first; int num_split = sharding_it->second; @@ -747,13 +829,17 @@ xla::StatusOr<ShardedInputInfo> CreateOrGetSplitNodesForInputSharding( // that split the input data at ith dimension. 
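To make the uneven-sharding arithmetic concrete, here is a standalone sketch that mirrors what `GetPadding` computes with `MathUtil::CeilOfRatio` (illustrative values only; `PaddingFor` is a made-up helper, not part of the patch). When the split dimension is not statically known, `GetPadding` simply returns 0 and no padding is inserted.

```cpp
#include <cstdint>
#include <iostream>

// Mirrors GetPadding: round the split dimension up to the next multiple of
// num_splits so every shard produced by the Split node has the same size.
int64_t PaddingFor(int64_t dim_size, int64_t num_splits) {
  const int64_t per_split = (dim_size + num_splits - 1) / num_splits;  // ceil
  return per_split * num_splits - dim_size;
}

int main() {
  // A dimension of size 10 split 4 ways: per-split size 3, so the Pad node
  // grows the input to 12 (padding of 2) before the Split, and the Slice
  // added by CreateSliceNode trims the concatenated output back to 10.
  std::cout << PaddingFor(10, 4) << "\n";  // prints 2
  std::cout << PaddingFor(12, 4) << "\n";  // evenly divisible, prints 0
  return 0;
}
```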
TF_ASSIGN_OR_RETURN( Node * root_split_node, - CreateSplitNode(num_split, split_dimension, orig_src_output, dtype, - absl::StrCat("sharded_input/replica_", replica_id, - "_dim_", split_dimension), - control_predecessor, orig_src, graph)); + CreateSplitNode( + num_split, split_dimension, partial_tensor_shape.dims(), + GetPadding(split_dimension, num_split, partial_tensor_shape), + orig_src_output, dtype, + absl::StrCat("sharded_input/replica_", replica_id, "_dim_", + split_dimension), + control_predecessor, orig_src, graph)); sharding_it++; split_nodes_for_dimension.emplace(root_split_node); + node_to_split_dim[root_split_node] = split_dimension; while (sharding_it != split_dimension_map.end()) { split_dimension = sharding_it->first; @@ -767,11 +853,15 @@ xla::StatusOr<ShardedInputInfo> CreateOrGetSplitNodesForInputSharding( ++src_output_index) { TF_ASSIGN_OR_RETURN( Node * split_node, - CreateSplitNode(num_split, split_dimension, src_output_index, dtype, - absl::StrCat("sharded_input/replica_", replica_id, - "_dim_", split_dimension), - control_predecessor, input_split_node, graph)); + CreateSplitNode( + num_split, split_dimension, partial_tensor_shape.dims(), + GetPadding(split_dimension, num_split, partial_tensor_shape), + src_output_index, dtype, + absl::StrCat("sharded_input/replica_", replica_id, "_dim_", + split_dimension), + control_predecessor, input_split_node, graph)); split_nodes_for_dimension.emplace(split_node); + node_to_split_dim[split_node] = split_dimension; } } sharding_it++; @@ -856,18 +946,82 @@ xla::StatusOr<Node*> CreateConcatNode(int dim, int num_splits, DataType dtype, return concat_node; } +// Adds slice node after concat node to graph for uneven sharding tiled inputs. +xla::StatusOr<Node*> CreateSliceNode(DataType dtype, + const PartialTensorShape& shape, + Node* concat_node, + const int concat_out_index, Graph* graph, + absl::string_view device) { + Status s; + // Add begin node for concat. + NodeDef begin_def; + begin_def.set_name( + graph->NewName(absl::StrCat(concat_node->name(), "/slice_begin"))); + begin_def.set_op("Const"); + AddNodeAttr("dtype", DT_INT32, &begin_def); + begin_def.set_device(std::string(device)); + TensorProto begin_tensor_proto; + begin_tensor_proto.set_dtype(DT_INT32); + for (int i = 0; i < shape.dims(); ++i) { + begin_tensor_proto.add_int_val(0); + } + TensorShape begin_shape({shape.dims()}); + begin_shape.AsProto(begin_tensor_proto.mutable_tensor_shape()); + AddNodeAttr("value", begin_tensor_proto, &begin_def); + Node* begin_node = graph->AddNode(begin_def, &s); + TF_RETURN_IF_ERROR(s); + + // Add size node. + NodeDef size_def; + size_def.set_name( + graph->NewName(absl::StrCat(concat_node->name(), "/slice_size"))); + size_def.set_op("Const"); + AddNodeAttr("dtype", DT_INT32, &size_def); + size_def.set_device(std::string(device)); + TensorProto sizes_tensor_proto; + sizes_tensor_proto.set_dtype(DT_INT32); + for (int i = 0; i < shape.dims(); ++i) { + sizes_tensor_proto.add_int_val(shape.dim_size(i)); + } + TensorShape sizes_shape({shape.dims()}); + sizes_shape.AsProto(sizes_tensor_proto.mutable_tensor_shape()); + AddNodeAttr("value", sizes_tensor_proto, &size_def); + Node* size_node = graph->AddNode(size_def, &s); + TF_RETURN_IF_ERROR(s); + + // Add Slice node. 
+ NodeDef slice_def; + slice_def.set_name( + graph->NewName(absl::StrCat(concat_node->name(), "/slice"))); + slice_def.set_op("Slice"); + slice_def.set_device(std::string(device)); + AddNodeAttr("T", dtype, &slice_def); + AddNodeAttr("Index", DT_INT32, &slice_def); + slice_def.add_input(absl::StrCat(concat_node->name(), ":", concat_out_index)); + slice_def.add_input(absl::StrCat(begin_node->name(), ":0")); + slice_def.add_input(absl::StrCat(size_node->name(), ":0")); + Node* slice_node = graph->AddNode(slice_def, &s); + TF_RETURN_IF_ERROR(s); + // Add edges for slice node. + graph->AddEdge(concat_node, concat_out_index, slice_node, 0); + graph->AddEdge(begin_node, 0, slice_node, 1); + graph->AddEdge(size_node, 0, slice_node, 2); + return slice_node; +} + // Creates a set of Concat nodes that aggregates sharded outputs from TPUExecute // nodes into a single output. Sharded outputs are concatenated along row major // order. That is, tiled output along 0th dimension will be concatenated last. xla::StatusOr<Node*> CreateConcatNodesForRetval( - const xla::OpSharding& sharding, DataType dtype, int replica_id, + const xla::OpSharding& sharding, DataType dtype, + const PartialTensorShape& inferred_shape, int replica_id, const std::vector<NodeOut>& orig_inputs, Graph* graph, absl::string_view device) { std::map<int, int> split_dimension_map; TF_RETURN_IF_ERROR(GetDimensionIndicesAndNumSplitsFromSharding( sharding, &split_dimension_map)); - std::vector<NodeOut> inputs_to_sharded_retval = orig_inputs; + bool has_paddings = false; for (auto it = split_dimension_map.rbegin(); it != split_dimension_map.rend(); it++) { @@ -891,12 +1045,21 @@ xla::StatusOr<Node*> CreateConcatNodesForRetval( dim, num_splits, dtype, absl::StrCat("sharded_output/replica_", replica_id, "_dim_", dim), inputs, graph, device)); + int64 paddings = GetPadding(dim, num_splits, inferred_shape); + has_paddings |= paddings > 0; new_concat_nodes.emplace_back(NodeOut{concat_node, 0}); } inputs_to_sharded_retval = new_concat_nodes; } TF_RET_CHECK(inputs_to_sharded_retval.size() == 1); + if (has_paddings) { + TF_ASSIGN_OR_RETURN(Node * slice_node, + CreateSliceNode(dtype, inferred_shape, + inputs_to_sharded_retval.at(0).node, + /*concat_out_index*/ 0, graph, device)); + return slice_node; + } return inputs_to_sharded_retval.at(0).node; } @@ -2759,7 +2922,8 @@ Status DistributedTPURewritePass::BuildExecuteNodes( TF_ASSIGN_OR_RETURN( ShardedInputInfo sharded_input_info, CreateOrGetSplitNodesForInputSharding( - sharding, orig_arg_num, dtype, replica, + sharding, orig_arg_num, dtype, + arg_shapes[orig_arg_num].handle_shape, replica, edge->src_output(), edge->src(), control_predecessor, graph, &input_index_to_sharded_inputs)); NodeOut split_node_and_index = @@ -2840,7 +3004,8 @@ Status DistributedTPURewritePass::BuildExecuteNodes( ShardedInputInfo sharded_input_info, CreateOrGetSplitNodesForInputSharding( sharding, orig_arg_num, - arg_shapes[orig_arg_num].handle_type, replica, + arg_shapes[orig_arg_num].handle_type, + arg_shapes[orig_arg_num].handle_shape, replica, var_data.index, var_data.node, control_predecessor, graph, &input_index_to_sharded_inputs)); NodeOut split_node_and_index = @@ -2927,8 +3092,9 @@ Status DistributedTPURewritePass::BuildExecuteNodes( DataType dtype = e->src()->output_type(e->src_output()); TF_ASSIGN_OR_RETURN( Node * concat_node, - CreateConcatNodesForRetval(sharding, dtype, replica, orig_inputs, - graph, /*device=*/"")); + CreateConcatNodesForRetval( + sharding, dtype, /*inferred_shape*/ PartialTensorShape(), + 
replica, orig_inputs, graph, /*device=*/"")); const Edge* edge = replicate_output_edges[output_num]; Node* dst = edge->dst(); @@ -3009,8 +3175,9 @@ Status DistributedTPURewritePass::BuildExecuteNodes( TF_ASSIGN_OR_RETURN( Node * concat_node, CreateConcatNodesForRetval( - sharding, arg_shapes[orig_arg_num].handle_type, replica, - orig_inputs, graph, device)); + sharding, arg_shapes[orig_arg_num].handle_type, + arg_shapes[orig_arg_num].handle_shape, replica, orig_inputs, + graph, device)); // Populate VariableWrite. VariableWrite& write = variable_writes->at(core_variable_writes[i]); write.value = concat_node; diff --git a/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc b/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc index 5ea92d007d9..f1f12a716da 100644 --- a/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc +++ b/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc @@ -329,10 +329,20 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) { server_address_output); }); size_t server_address_output_size; + + TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params params; + params.struct_size = + TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params_SIZE; + params.priv = nullptr; + params.tpu_host_config_size = tpu_host_config.size(); + params.tpu_host_config = tpu_host_config.data(); + params.server_address_output_size = &server_address_output_size; + params.server_address_output = &server_address_output; + params.status = status; + tpu::OpsApiFn() ->TpuConfigurationApi_CompilationCacheServerAddressFromConfigFn( - tpu_host_config.size(), tpu_host_config.data(), - &server_address_output_size, &server_address_output, status); + ¶ms); OP_REQUIRES_OK(ctx, StatusFromTF_Status(status)); std::string server_address(server_address_output, diff --git a/tensorflow/core/tpu/kernels/tpu_pod_state.cc b/tensorflow/core/tpu/kernels/tpu_pod_state.cc index 72993baf57b..bccfa085d5e 100644 --- a/tensorflow/core/tpu/kernels/tpu_pod_state.cc +++ b/tensorflow/core/tpu/kernels/tpu_pod_state.cc @@ -78,9 +78,16 @@ Status GetServerAddressAndPort(std::string* server_address, int* serving_port) { }); size_t server_address_output_size; *serving_port = -1; - tpu::OpsApiFn()->TpuConfigurationApi_GetServerAddressAndPortFn( - &server_address_output_size, &server_address_output, serving_port, - status); + + TpuConfigurationApi_GetServerAddressAndPort_Params params; + params.struct_size = TpuConfigurationApi_GetServerAddressAndPort_Params_SIZE; + params.priv = nullptr; + params.server_address_output_size = &server_address_output_size; + params.server_address_output = &server_address_output; + params.port_output = serving_port; + params.status = status; + + tpu::OpsApiFn()->TpuConfigurationApi_GetServerAddressAndPortFn(¶ms); TF_RETURN_IF_ERROR(StatusFromTF_Status(status)); *server_address = std::string(server_address_output, server_address_output_size); diff --git a/tensorflow/core/tpu/kernels/xla/segment_reduction_ops.cc b/tensorflow/core/tpu/kernels/xla/segment_reduction_ops.cc index fc15d71dfd8..52fa3bfa8f1 100644 --- a/tensorflow/core/tpu/kernels/xla/segment_reduction_ops.cc +++ b/tensorflow/core/tpu/kernels/xla/segment_reduction_ops.cc @@ -22,58 +22,6 @@ limitations under the License. namespace tensorflow { namespace { -// TODO(b/32945756): Add a scatter op in XLA and move this to a HLO optimization -// pass. Optimization for UnsortedSegmentSum on TPU: use k-hot matmul. This -// optimization requires: -// 1. data has dtype supported by TPU matmul and has rank of 1 or 2. 
-// 2. indices has rank of 1. -// 3. matmul op count is less than 800 billion. -// -// Example of calculating UnsortedSegmentSum by k-hot matmul: -// data shape [A, B] -// indices shape [A] -// num_segment N -// output shape [N, B] -// matmul op count N * A * B -// Step 1: create k-hot matrix -// k-hot matrix has shape of [A, N], where row i is responsible for -// collecting the sum of the i-th segment, concretely -// k-hot[i][j] = 1 if indices[i] = j -// Step 2: perform matmul -// the final result is obtained by multiplying k-hot matrix with data -// matrix, namely -// k-hot * data => result -// shape: [N, A] * [A, B] => [N, B] -xla::XlaOp KHotMatmul(XlaOpKernelContext* ctx, xla::XlaBuilder* builder, - const xla::XlaOp data, const xla::XlaOp indices, - int64 num_segments) { - DataType data_dtype = ctx->input_type(0); - xla::PrimitiveType indices_type = ctx->input_xla_type(1); - TensorShape data_shape = ctx->InputShape(0); - TensorShape indices_shape = ctx->InputShape(1); - xla::XlaOp linspace = xla::Iota(builder, indices_type, num_segments); - xla::XlaOp linspace_col = xla::Reshape(linspace, {num_segments, 1}); - TensorShape indices_row_shape = indices_shape; - indices_row_shape.InsertDim(0, 1); - xla::XlaOp indices_row = xla::Reshape(indices, indices_row_shape.dim_sizes()); - xla::XlaOp k_hot = xla::Eq(indices_row, linspace_col); - xla::XlaOp k_hot_with_data_dtype = - XlaHelpers::ConvertElementType(k_hot, data_dtype); - // F32 version of the KHotMatmul. It splits the F32 data into three - // BF16 partial data and run KHotMatmul for each of them. The final result - // is the summation of three BF16 results. - // Note that this still doesn't fully retain f32 precision. - // In particular, values smaller than 2^-111 may see loss of precision. - xla::PrecisionConfig precision_config; - if (data_dtype == DT_FLOAT) { - precision_config.add_operand_precision(xla::PrecisionConfig::HIGHEST); - } else { - CHECK_EQ(data_dtype, DT_BFLOAT16); - precision_config.add_operand_precision(xla::PrecisionConfig::DEFAULT); - } - precision_config.add_operand_precision(xla::PrecisionConfig::DEFAULT); - return xla::Dot(k_hot_with_data_dtype, data, &precision_config); -} class UnsortedSegmentSum : public XlaOpKernel { public: diff --git a/tensorflow/core/tpu/tpu_ops_c_api.h b/tensorflow/core/tpu/tpu_ops_c_api.h index 45dd2d1e88d..f49438bde85 100644 --- a/tensorflow/core/tpu/tpu_ops_c_api.h +++ b/tensorflow/core/tpu/tpu_ops_c_api.h @@ -240,14 +240,42 @@ TFTPU_CAPI_EXPORT void TpuConfigurationApi_TpuMemoryLimit(int64_t* memory_limit, TF_Status* status); TFTPU_CAPI_EXPORT void TpuConfigurationApi_RemoteCompilationCacheSizeInBytes( int64_t* cache_size_in_bytes); + +typedef struct TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params { + int32_t struct_size; + void* priv; + + size_t tpu_host_config_size; + const char* tpu_host_config; + + size_t* server_address_output_size; // out + char** server_address_output; // out + TF_Status* status; // out +} TpuConfigurationApi_CompilationCacheServerAddressFromConfig_Params; + +#define TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params_SIZE \ + (sizeof( \ + struct TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params)) + TFTPU_CAPI_EXPORT void TpuConfigurationApi_CompilationCacheServerAddressFromConfig( - size_t tpu_host_config_size, const char* tpu_host_config, - size_t* server_address_output_size, char** server_address_output, - TF_Status* status); + TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params* params); + +typedef struct 
TpuConfigurationApi_GetServerAddressAndPort_Params { + int32_t struct_size; + void* priv; + + size_t* server_address_output_size; // out + char** server_address_output; // out + int* port_output; // out + TF_Status* status; // out +} TpuConfigurationApi_GetServerAddressAndPort_Params; + +#define TpuConfigurationApi_GetServerAddressAndPort_Params_SIZE \ + (sizeof(struct TpuConfigurationApi_GetServerAddressAndPort_Params)) + TFTPU_CAPI_EXPORT void TpuConfigurationApi_GetServerAddressAndPort( - size_t* server_address_output_size, char** server_address_output, - int* port_output, TF_Status* status); + TpuConfigurationApi_GetServerAddressAndPort_Params* params); // Creates a new TPU program. TFTPU_CAPI_EXPORT XLA_TpuProgram* TpuProgram_New(); diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index b725be34844..5ac978ff0b8 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -69,6 +69,11 @@ FRAMEWORK_LIB_HDRS = [ "stderr_reporter.h", ] +exports_files( + FRAMEWORK_LIB_HDRS, + visibility = ["//tensorflow/lite/core/shims:__subpackages__"], +) + cc_library( name = "version", hdrs = ["version.h"], @@ -164,9 +169,7 @@ cc_library( cc_library( name = "builtin_op_data", - hdrs = [ - "builtin_op_data.h", - ], + hdrs = ["builtin_op_data.h"], deps = ["//tensorflow/lite/c:common"], ) @@ -225,16 +228,10 @@ cc_library( ], ) +# The library that implements the full C++ API. +# See also 'framework' below, which is the corresponding public target. cc_library( name = "framework_lib", - srcs = [ - "core/subgraph.cc", - "graph_info.cc", - "interpreter.cc", - "interpreter_builder.cc", - "model_builder.cc", - "optional_debug_tools.cc", - ], hdrs = FRAMEWORK_LIB_HDRS, compatible_with = get_compatible_with_portable(), copts = tflite_copts() + TFLITE_DEFAULT_COPTS, @@ -244,9 +241,99 @@ cc_library( deps = [ ":allocation", ":arena_planner", + ":cc_api", ":external_cpu_backend_context", ":graph_info", ":kernel_api", + ":macros", + ":memory_planner", + ":minimal_logging", + ":mutable_op_resolver", + ":optional_debug_tools", + ":shared_library", + ":simple_memory_arena", + ":stderr_reporter", + ":string", + ":type_to_tflitetype", + ":util", + ":version", + "//tensorflow/lite/c:common", + "//tensorflow/lite/core/api", + "//tensorflow/lite/core/api:verifier", + "//tensorflow/lite/delegates:status", + "//tensorflow/lite/experimental/resource", + "//tensorflow/lite/kernels/internal:compatibility", + "//tensorflow/lite/profiling:platform_profiler", + "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", + ], + alwayslink = 1, # Why?? TODO(b/161243354): eliminate this. +) + +# The public target for the full C++ API. +# The deps listed here, other than ":framework_lib", are the interface dependencies +# (dependencies required by the header files). +# TODO(ahentz): investigate dependency on gemm_support requiring usage of tf_copts. 
+cc_library( + name = "framework", + srcs = [], + hdrs = FRAMEWORK_LIB_HDRS, + compatible_with = get_compatible_with_portable(), + copts = tflite_copts() + TFLITE_DEFAULT_COPTS, + deps = [ + ":allocation", + ":arena_planner", + ":cc_api", + ":external_cpu_backend_context", + ":framework_lib", + ":graph_info", + ":memory_planner", + ":minimal_logging", + ":simple_memory_arena", + ":string", + ":type_to_tflitetype", + ":util", + ":version", + "//tensorflow/lite/c:common", + "//tensorflow/lite/core/api", + "//tensorflow/lite/core/api:verifier", + "//tensorflow/lite/experimental/resource", + "//tensorflow/lite/schema:schema_fbs", + ], +) + +# The key parts of the C++ API. This target defines the TF Lite classes for +# loading models and interpreting them. +cc_library( + name = "cc_api", + srcs = [ + "core/subgraph.cc", + "graph_info.cc", + "interpreter.cc", + "interpreter_builder.cc", + "model_builder.cc", + ], + hdrs = [ + "core/subgraph.h", + "graph_info.h", + "interpreter.h", + "interpreter_builder.h", + "model.h", + "model_builder.h", + ], + compatible_with = get_compatible_with_portable(), + copts = tflite_copts() + TFLITE_DEFAULT_COPTS, + visibility = [ + "//tensorflow/lite/core/shims:__subpackages__", + "//tensorflow/lite/kernels:__subpackages__", + ], + deps = [ + ":allocation", + ":arena_planner", + ":external_cpu_backend_context", + ":graph_info", + ":kernel_api", + ":macros", ":memory_planner", ":minimal_logging", ":mutable_op_resolver", @@ -267,34 +354,24 @@ cc_library( "//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/schema:schema_utils", ], - alwayslink = 1, + alwayslink = 1, # Why?? TODO(b/161243354): eliminate this. ) -# TODO(ahentz): investigate dependency on gemm_support requiring usage of tf_copts. cc_library( - name = "framework", + name = "optional_debug_tools", srcs = [ + "optional_debug_tools.cc", ], - hdrs = FRAMEWORK_LIB_HDRS, + hdrs = ["optional_debug_tools.h"], compatible_with = get_compatible_with_portable(), copts = tflite_copts() + TFLITE_DEFAULT_COPTS, + visibility = [ + "//visibility:public", + ], deps = [ - ":allocation", - ":arena_planner", - ":external_cpu_backend_context", - ":framework_lib", - ":graph_info", - ":memory_planner", - ":minimal_logging", - ":simple_memory_arena", - ":string", - ":type_to_tflitetype", - ":util", - ":version", + ":macros", + "//tensorflow/lite:cc_api", "//tensorflow/lite/c:common", - "//tensorflow/lite/core/api", - "//tensorflow/lite/core/api:verifier", - "//tensorflow/lite/experimental/resource", "//tensorflow/lite/schema:schema_fbs", ], ) @@ -448,9 +525,6 @@ cc_test( size = "small", srcs = ["string_util_test.cc"], features = ["-dynamic_link_test_srcs"], # see go/dynamic_link_test_srcs - tags = [ - "tflite_not_portable_ios", # TODO(b/117786830) - ], deps = [ ":framework", ":string_util", @@ -469,7 +543,7 @@ cc_test( ], features = ["-dynamic_link_test_srcs"], # see go/dynamic_link_test_srcs tags = [ - "tflite_not_portable_ios", # TODO(b/117786830) + "tflite_not_portable_ios", # TODO(b/173711739) "tflite_smoke_test", ], deps = [ @@ -497,9 +571,6 @@ cc_test( size = "small", srcs = ["graph_info_test.cc"], features = ["-dynamic_link_test_srcs"], # see go/dynamic_link_test_srcs - tags = [ - "tflite_not_portable_ios", # TODO(b/117786830) - ], deps = [ ":framework", "//tensorflow/lite/testing:util", @@ -513,9 +584,6 @@ cc_test( size = "small", srcs = ["simple_memory_arena_test.cc"], features = ["-dynamic_link_test_srcs"], # see go/dynamic_link_test_srcs - tags = [ - "tflite_not_portable_ios", # TODO(b/117786830) - ], deps = [ 
":simple_memory_arena", "//tensorflow/core:tflite_portable_logging", @@ -610,9 +678,6 @@ cc_test( size = "small", srcs = ["mutable_op_resolver_test.cc"], features = ["-dynamic_link_test_srcs"], # see go/dynamic_link_test_srcs - tags = [ - "tflite_not_portable_ios", # TODO(b/117786830) - ], deps = [ ":framework", "//tensorflow/lite/testing:util", @@ -647,9 +712,6 @@ cc_test( size = "small", srcs = ["util_test.cc"], features = ["-dynamic_link_test_srcs"], # see go/dynamic_link_test_srcs - tags = [ - "tflite_not_portable_ios", # TODO(b/117786830) - ], deps = [ ":util", "//tensorflow/lite/c:common", @@ -717,7 +779,7 @@ cc_test( size = "small", srcs = ["minimal_logging_test.cc"], tags = [ - "tflite_not_portable_ios", # TODO(b/117786830) + "tflite_not_portable_ios", # TODO(b/173711739) ], deps = [ ":minimal_logging", @@ -735,6 +797,7 @@ cc_library( cc_library( name = "macros", hdrs = ["core/macros.h"], + compatible_with = get_compatible_with_portable(), ) cc_library( diff --git a/tensorflow/lite/CMakeLists.txt b/tensorflow/lite/CMakeLists.txt index ebf823fa57d..3f2bb836320 100644 --- a/tensorflow/lite/CMakeLists.txt +++ b/tensorflow/lite/CMakeLists.txt @@ -27,6 +27,12 @@ # - Host Tools (i.e conversion / analysis tools etc.) cmake_minimum_required(VERSION 3.16) +if(NOT CMAKE_BUILD_TYPE) + message(STATUS "Setting build type to Release, for debug builds use" + "'-DCMAKE_BUILD_TYPE=Debug'.") + set(CMAKE_BUILD_TYPE "Release") +endif() + # Double colon in target name means ALIAS or IMPORTED target. cmake_policy(SET CMP0028 NEW) # Enable MACOSX_RPATH (@rpath) for built dynamic libraries. @@ -197,9 +203,14 @@ endif() if(TFLITE_ENABLE_GPU) find_package(opencl_headers REQUIRED) find_package(vulkan_headers REQUIRED) + # Android NDK already has OpenGL, EGL headers. + if(NOT "${CMAKE_SYSTEM_NAME}" STREQUAL "Android") + find_package(opengl_headers REQUIRED) + find_package(egl_headers REQUIRED) + endif() populate_tflite_source_vars( "delegates/gpu/cl" TFLITE_DELEGATES_GPU_CL_SRCS - FILTER "(_test|gl_interop|egl_sync)\\.(cc|h)$" + FILTER "(_test|gl_interop|gpu_api_delegate|egl_sync)\\.(cc|h)$" ) populate_tflite_source_vars( "delegates/gpu/cl/kernels" TFLITE_DELEGATES_GPU_CL_KERNELS_SRCS diff --git a/tensorflow/lite/c/BUILD b/tensorflow/lite/c/BUILD index e8db0dcf440..504655f7b28 100644 --- a/tensorflow/lite/c/BUILD +++ b/tensorflow/lite/c/BUILD @@ -68,7 +68,7 @@ cc_library( "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/kernels/internal:compatibility", ], - alwayslink = 1, + alwayslink = 1, # Why?? TODO(b/161243354): eliminate this. ) cc_library( @@ -83,7 +83,7 @@ cc_library( "//tensorflow/lite:framework", "//tensorflow/lite:kernel_api", ], - alwayslink = 1, + alwayslink = 1, # Why?? TODO(b/161243354): eliminate this. ) cc_test( @@ -127,8 +127,8 @@ cc_library( "common.h", ], compatible_with = get_compatible_with_portable(), - deps = ["//tensorflow/lite:builtin_ops"], - alwayslink = 1, + copts = tflite_copts(), + alwayslink = 1, # Why?? TODO(b/161243354): eliminate this. ) # For use with library targets that can't use relative paths. diff --git a/tensorflow/lite/core/shims/BUILD b/tensorflow/lite/core/shims/BUILD new file mode 100644 index 00000000000..05b2566eef3 --- /dev/null +++ b/tensorflow/lite/core/shims/BUILD @@ -0,0 +1,178 @@ +# Description: this package contains shim library targets that forward +# to the TF Lite C and C++ API targets. See README.md. 
+ +load("//tensorflow:tensorflow.bzl", "if_not_windows") +load("//tensorflow/lite:build_def.bzl", "tflite_copts") +load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite") +load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable") +load(":build_defs.bzl", "build_test") + +package( + default_visibility = ["//visibility:private"], + licenses = ["notice"], # Apache 2.0 +) + +TFLITE_DEFAULT_COPTS = if_not_windows([ + "-Wall", + "-Wno-comment", + "-Wno-extern-c-compat", +]) + +#------------------------------------------------------------------------------ +# C++ API + +FRAMEWORK_LIB_HDRS = [ + "//tensorflow/lite:allocation.h", + "//tensorflow/lite:context.h", + "//tensorflow/lite:context_util.h", + "//tensorflow/lite:core/macros.h", + "//tensorflow/lite:core/subgraph.h", + "//tensorflow/lite:error_reporter.h", + "//tensorflow/lite:graph_info.h", + "//tensorflow/lite:mutable_op_resolver.h", + "//tensorflow/lite:op_resolver.h", + "//tensorflow/lite:optional_debug_tools.h", + "//tensorflow/lite:stderr_reporter.h", +] + +CC_API_HDRS = [ + "cc/interpreter.h", + "cc/interpreter_builder.h", + "cc/model.h", + "cc/model_builder.h", +] + +cc_library( + name = "framework", + srcs = [], + hdrs = FRAMEWORK_LIB_HDRS + CC_API_HDRS, + compatible_with = get_compatible_with_portable(), + copts = tflite_copts() + TFLITE_DEFAULT_COPTS, + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/lite:allocation", + "//tensorflow/lite:arena_planner", + "//tensorflow/lite:external_cpu_backend_context", + "//tensorflow/lite:framework_lib", + "//tensorflow/lite:graph_info", + "//tensorflow/lite:kernel_api", + "//tensorflow/lite:memory_planner", + "//tensorflow/lite:minimal_logging", + "//tensorflow/lite:simple_memory_arena", + "//tensorflow/lite:string", + "//tensorflow/lite:type_to_tflitetype", + "//tensorflow/lite:util", + "//tensorflow/lite:version", + "//tensorflow/lite/c:common", + "//tensorflow/lite/core/api", + "//tensorflow/lite/core/api:verifier", + "//tensorflow/lite/experimental/resource", + "//tensorflow/lite/schema:schema_fbs", + ], +) + +cc_library( + name = "cc_api", + srcs = [], + hdrs = CC_API_HDRS, + compatible_with = get_compatible_with_portable(), + copts = tflite_copts() + TFLITE_DEFAULT_COPTS, + visibility = ["//tensorflow/lite:__pkg__"], + deps = [ + "//tensorflow/lite:allocation", + "//tensorflow/lite:arena_planner", + "//tensorflow/lite:cc_api", + "//tensorflow/lite:external_cpu_backend_context", + "//tensorflow/lite:graph_info", + "//tensorflow/lite:kernel_api", + "//tensorflow/lite:macros", + "//tensorflow/lite:memory_planner", + "//tensorflow/lite:minimal_logging", + "//tensorflow/lite:simple_memory_arena", + "//tensorflow/lite:string", + "//tensorflow/lite:type_to_tflitetype", + "//tensorflow/lite:util", + "//tensorflow/lite:version", + "//tensorflow/lite/c:common", + "//tensorflow/lite/core/api", + "//tensorflow/lite/core/api:verifier", + "//tensorflow/lite/experimental/resource", + "//tensorflow/lite/schema:schema_fbs", + ], +) + +cc_library( + name = "builtin_ops", + hdrs = [ + "builtin_op_kernels.h", + "fully_connected.h", + "register.h", + ], + compatible_with = get_compatible_with_portable(), + deps = [ + ":builtin_ops_impl", + "//tensorflow/lite:framework_lib", + "//tensorflow/lite/c:common", + ], +) + +build_test( + name = "cc_api_build_test", + targets = [ + ":cc_api", + ":framework", + ], +) + +#------------------------------------------------------------------------------ +# C API + +cc_library( + name = "c_api", + hdrs = ["c/c_api.h"], + 
compatible_with = get_compatible_with_portable(), + copts = TFLITE_DEFAULT_COPTS, + visibility = ["//visibility:public"], + deps = ["//tensorflow/lite/c:c_api"], +) + +cc_library( + name = "c_api_experimental", + hdrs = ["c/c_api_experimental.h"], + compatible_with = get_compatible_with_portable(), + copts = TFLITE_DEFAULT_COPTS, + visibility = ["//visibility:public"], + deps = ["//tensorflow/lite/c:c_api_experimental"], +) + +cc_library( + name = "common", + hdrs = ["c/common.h"], + compatible_with = get_compatible_with_portable(), + copts = TFLITE_DEFAULT_COPTS, + visibility = ["//visibility:public"], + deps = ["//tensorflow/lite/c:common"], +) + +cc_library( + name = "builtin_op_data", + hdrs = ["c/builtin_op_data.h"], + compatible_with = get_compatible_with_portable(), + copts = TFLITE_DEFAULT_COPTS, + visibility = ["//visibility:public"], + deps = ["//tensorflow/lite/c:common"], +) + +build_test( + name = "c_api_build_test", + targets = [ + ":builtin_op_data", + ":c_api", + ":c_api_experimental", + ":common", + ], +) + +#------------------------------------------------------------------------------ + +tflite_portable_test_suite() diff --git a/tensorflow/lite/core/shims/README.md b/tensorflow/lite/core/shims/README.md new file mode 100644 index 00000000000..2a95d5cb9c6 --- /dev/null +++ b/tensorflow/lite/core/shims/README.md @@ -0,0 +1,11 @@ +This directory contains shim header files that forward to the TF Lite +C API and to the key headers of the TF Lite C++ API. + +The intent is that the shims in this directory could be modified to optionally +redirect to a different implementation of those APIs (for example, +one built into the underlying operating system platform). + +These should be used as follows: #includes from .cc files that are +_implementing_ the shimmed TF Lite APIs should include the regular TF +Lite API headers. #includes from files that are _using_ the shimmed +APIs should include the shimmed headers. diff --git a/tensorflow/lite/core/shims/build_defs.bzl b/tensorflow/lite/core/shims/build_defs.bzl new file mode 100644 index 00000000000..330e2e9cc95 --- /dev/null +++ b/tensorflow/lite/core/shims/build_defs.bzl @@ -0,0 +1,22 @@ +"""A simple portable implementation of build_test.""" + +def build_test(name, targets, visibility = None): + """Generates a test that just verifies that the specified targets can be built.""" + + # Generate an sh_test rule that lists the specified targets as data, + # (thus forcing those targets to be built before the test can be run) + # and that runs a script which always succeeds. + native.sh_test( + name = name, + srcs = [name + ".sh"], + data = targets, + visibility = visibility, + ) + + # Generate the script which always succeeds. We just generate an empty script. + native.genrule( + name = name + "_gen_sh", + outs = [name + ".sh"], + cmd = "> $@", + visibility = ["//visibility:private"], + ) diff --git a/tensorflow/lite/core/shims/c/builtin_op_data.h b/tensorflow/lite/core/shims/c/builtin_op_data.h new file mode 100644 index 00000000000..f75b0144be7 --- /dev/null +++ b/tensorflow/lite/core/shims/c/builtin_op_data.h @@ -0,0 +1,20 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef PARTY_TENSORFLOW_LITE_CORE_SHIMS_C_BUILTIN_OP_DATA_H_ +#define PARTY_TENSORFLOW_LITE_CORE_SHIMS_C_BUILTIN_OP_DATA_H_ + +#include "tensorflow/lite/c/builtin_op_data.h" + +#endif // PARTY_TENSORFLOW_LITE_CORE_SHIMS_C_BUILTIN_OP_DATA_H_ diff --git a/tensorflow/lite/core/shims/c/c_api.h b/tensorflow/lite/core/shims/c/c_api.h new file mode 100644 index 00000000000..90e0147a479 --- /dev/null +++ b/tensorflow/lite/core/shims/c/c_api.h @@ -0,0 +1,20 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef PARTY_TENSORFLOW_LITE_CORE_SHIMS_C_C_API_H_ +#define PARTY_TENSORFLOW_LITE_CORE_SHIMS_C_C_API_H_ + +#include "tensorflow/lite/c/c_api.h" + +#endif // PARTY_TENSORFLOW_LITE_CORE_SHIMS_C_C_API_H_ diff --git a/tensorflow/lite/micro/xtensa_hifimini_staging/micro_time.cc b/tensorflow/lite/core/shims/c/c_api_experimental.h similarity index 68% rename from tensorflow/lite/micro/xtensa_hifimini_staging/micro_time.cc rename to tensorflow/lite/core/shims/c/c_api_experimental.h index 6f3844c1fe3..ceb1cb7e5bb 100644 --- a/tensorflow/lite/micro/xtensa_hifimini_staging/micro_time.cc +++ b/tensorflow/lite/core/shims/c/c_api_experimental.h @@ -12,17 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#ifndef PARTY_TENSORFLOW_LITE_CORE_SHIMS_C_C_API_EXPERIMENTAL_H_ +#define PARTY_TENSORFLOW_LITE_CORE_SHIMS_C_C_API_EXPERIMENTAL_H_ -// Xtensa implementation of micro_timer. -// To include this with make, add TAGS=xtensa-xpg. -#include "tensorflow/lite/micro/micro_time.h" +#include "tensorflow/lite/c/c_api_experimental.h" -#include <time.h> - -namespace tflite { - -int32_t ticks_per_second() { return CLOCKS_PER_SEC; } - -int32_t GetCurrentTimeTicks() { return clock(); } - -} // namespace tflite +#endif // PARTY_TENSORFLOW_LITE_CORE_SHIMS_C_C_API_EXPERIMENTAL_H_ diff --git a/tensorflow/lite/core/shims/c/common.h b/tensorflow/lite/core/shims/c/common.h new file mode 100644 index 00000000000..a531546ec9e --- /dev/null +++ b/tensorflow/lite/core/shims/c/common.h @@ -0,0 +1,20 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef PARTY_TENSORFLOW_LITE_CORE_SHIMS_C_COMMON_H_ +#define PARTY_TENSORFLOW_LITE_CORE_SHIMS_C_COMMON_H_ + +#include "tensorflow/lite/c/common.h" + +#endif // PARTY_TENSORFLOW_LITE_CORE_SHIMS_C_COMMON_H_ diff --git a/tensorflow/lite/core/shims/cc/interpreter.h b/tensorflow/lite/core/shims/cc/interpreter.h new file mode 100644 index 00000000000..20e8f10ce94 --- /dev/null +++ b/tensorflow/lite/core/shims/cc/interpreter.h @@ -0,0 +1,26 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_CORE_SHIMS_CC_INTERPRETER_H_ +#define TENSORFLOW_LITE_CORE_SHIMS_CC_INTERPRETER_H_ + +/// For documentation, see third_party/tensorflow/lite/interpreter.h. + +#include "tensorflow/lite/interpreter.h" + +namespace tflite_shims { +using Interpreter = ::tflite::Interpreter; +} // namespace tflite_shims + +#endif // TENSORFLOW_LITE_CORE_SHIMS_CC_INTERPRETER_H_ diff --git a/tensorflow/lite/core/shims/cc/interpreter_builder.h b/tensorflow/lite/core/shims/cc/interpreter_builder.h new file mode 100644 index 00000000000..891a5747f7e --- /dev/null +++ b/tensorflow/lite/core/shims/cc/interpreter_builder.h @@ -0,0 +1,26 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_CORE_SHIMS_CC_INTERPRETER_BUILDER_H_ +#define TENSORFLOW_LITE_CORE_SHIMS_CC_INTERPRETER_BUILDER_H_ + +/// For documentation, see third_party/tensorflow/lite/interpreter_builder.h. 
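As a sketch of how these forwarding headers are meant to be consumed, a hypothetical client file includes only the shim headers and uses the `tflite_shims` aliases, which under the current forwarding resolve to the regular `::tflite` types (the function name and the model-path handling below are illustrative, not part of this change):

```cpp
#include <memory>

#include "tensorflow/lite/core/shims/cc/interpreter.h"
#include "tensorflow/lite/core/shims/cc/kernels/register.h"
#include "tensorflow/lite/core/shims/cc/model.h"

// Hypothetical client of the shimmed C++ API: builds an interpreter from a
// flatbuffer file and allocates its tensors.
bool LoadAndAllocate(const char* model_path) {
  std::unique_ptr<tflite_shims::FlatBufferModel> model =
      tflite_shims::FlatBufferModel::BuildFromFile(model_path);
  if (model == nullptr) return false;

  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite_shims::Interpreter> interpreter;
  tflite_shims::InterpreterBuilder(*model, resolver)(&interpreter);
  return interpreter != nullptr && interpreter->AllocateTensors() == kTfLiteOk;
}
```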
+ +#include "tensorflow/lite/interpreter_builder.h" + +namespace tflite_shims { +using InterpreterBuilder = ::tflite::InterpreterBuilder; +} // namespace tflite_shims + +#endif // TENSORFLOW_LITE_CORE_SHIMS_CC_INTERPRETER_BUILDER_H_ diff --git a/tensorflow/lite/core/shims/cc/kernels/register.h b/tensorflow/lite/core/shims/cc/kernels/register.h new file mode 100644 index 00000000000..d5a52034fb7 --- /dev/null +++ b/tensorflow/lite/core/shims/cc/kernels/register.h @@ -0,0 +1,20 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_CORE_SHIMS_CC_KERNELS_REGISTER_H_ +#define TENSORFLOW_LITE_CORE_SHIMS_CC_KERNELS_REGISTER_H_ + +#include "tensorflow/lite/kernels/register.h" + +#endif // TENSORFLOW_LITE_CORE_SHIMS_CC_KERNELS_REGISTER_H_ diff --git a/tensorflow/lite/core/shims/cc/model.h b/tensorflow/lite/core/shims/cc/model.h new file mode 100644 index 00000000000..8a016319510 --- /dev/null +++ b/tensorflow/lite/core/shims/cc/model.h @@ -0,0 +1,23 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_CORE_SHIMS_CC_MODEL_H_ +#define TENSORFLOW_LITE_CORE_SHIMS_CC_MODEL_H_ + +/// For documentation, see third_party/tensorflow/lite/model.h. + +#include "tensorflow/lite/core/shims/cc/interpreter_builder.h" +#include "tensorflow/lite/core/shims/cc/model_builder.h" + +#endif // TENSORFLOW_LITE_CORE_SHIMS_CC_MODEL_H_ diff --git a/tensorflow/lite/core/shims/cc/model_builder.h b/tensorflow/lite/core/shims/cc/model_builder.h new file mode 100644 index 00000000000..3ee9c5a3016 --- /dev/null +++ b/tensorflow/lite/core/shims/cc/model_builder.h @@ -0,0 +1,26 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_LITE_CORE_SHIMS_CC_MODEL_BUILDER_H_ +#define TENSORFLOW_LITE_CORE_SHIMS_CC_MODEL_BUILDER_H_ + +/// For documentation, see third_party/tensorflow/lite/model_builder.h. + +#include "tensorflow/lite/model_builder.h" + +namespace tflite_shims { +using FlatBufferModel = ::tflite::FlatBufferModel; +} // namespace tflite_shims + +#endif // TENSORFLOW_LITE_CORE_SHIMS_CC_MODEL_BUILDER_H_ diff --git a/tensorflow/lite/delegates/BUILD b/tensorflow/lite/delegates/BUILD index 240de7fef94..7c68b25cf14 100644 --- a/tensorflow/lite/delegates/BUILD +++ b/tensorflow/lite/delegates/BUILD @@ -71,9 +71,6 @@ cc_test( size = "small", srcs = ["delegate_test.cc"], features = ["-dynamic_link_test_srcs"], # see go/dynamic_link_test_srcs - tags = [ - "tflite_not_portable_ios", # TODO(b/117786830) - ], deps = [ ":interpreter_utils", ":utils", diff --git a/tensorflow/lite/delegates/gpu/cl/BUILD b/tensorflow/lite/delegates/gpu/cl/BUILD index 58a36ca7a64..cece0d59768 100644 --- a/tensorflow/lite/delegates/gpu/cl/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/BUILD @@ -99,13 +99,13 @@ cc_library( deps = [ ":buffer", ":cl_context", - ":device_info", ":gpu_object", ":linear_storage", ":tensor", ":texture2d", "//tensorflow/lite/delegates/gpu/common:access_type", "//tensorflow/lite/delegates/gpu/common:data_type", + "//tensorflow/lite/delegates/gpu/common:gpu_info", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:types", "//tensorflow/lite/delegates/gpu/common:util", @@ -126,8 +126,8 @@ cc_test( deps = [ ":buffer", ":cl_arguments", - ":device_info", ":gpu_object", + "//tensorflow/lite/delegates/gpu/common:gpu_info", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", ], @@ -171,9 +171,9 @@ cc_library( srcs = ["cl_device.cc"], hdrs = ["cl_device.h"], deps = [ - ":device_info", ":opencl_wrapper", ":util", + "//tensorflow/lite/delegates/gpu/common:gpu_info", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:types", "@com_google_absl//absl/strings", @@ -247,8 +247,8 @@ cc_library( ":cl_kernel", ":program_cache", ":tensor", - "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation", "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", ], ) @@ -276,16 +276,6 @@ flatbuffer_cc_library( ], ) -cc_library( - name = "device_info", - srcs = ["device_info.cc"], - hdrs = ["device_info.h"], - deps = [ - "//tensorflow/lite/delegates/gpu/common:data_type", - "@com_google_absl//absl/strings", - ], -) - cc_library( name = "egl_sync", srcs = ["egl_sync.cc"], @@ -307,11 +297,11 @@ cc_library( ":cl_command_queue", ":cl_context", ":cl_device", - ":device_info", ":program_cache", ":tensor", ":util", "//tensorflow/lite/delegates/gpu/common:data_type", + "//tensorflow/lite/delegates/gpu/common:gpu_info", "//tensorflow/lite/delegates/gpu/common:precision", "//tensorflow/lite/delegates/gpu/common:shape", "//tensorflow/lite/delegates/gpu/common:status", @@ -406,7 +396,6 @@ cc_library( ":serialization_cc_fbs", ":storage_type_util", ":tensor", - "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation", "//tensorflow/lite/delegates/gpu/cl/selectors:operation_selector", "//tensorflow/lite/delegates/gpu/cl/selectors:special_selector", "//tensorflow/lite/delegates/gpu/common:data_type", @@ -424,6 +413,7 @@ cc_library( "//tensorflow/lite/delegates/gpu/common/task:arguments", 
"//tensorflow/lite/delegates/gpu/common/task:buffer_desc", "//tensorflow/lite/delegates/gpu/common/task:gpu_object_desc", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", "//tensorflow/lite/delegates/gpu/common/task:tensor_desc", "//tensorflow/lite/delegates/gpu/common/task:tensor_linear_desc", "//tensorflow/lite/delegates/gpu/common/task:texture2d_desc", @@ -509,8 +499,8 @@ cc_library( srcs = ["storage_type_util.cc"], hdrs = ["storage_type_util.h"], deps = [ - ":device_info", "//tensorflow/lite/delegates/gpu/common:data_type", + "//tensorflow/lite/delegates/gpu/common:gpu_info", "//tensorflow/lite/delegates/gpu/common:shape", "//tensorflow/lite/delegates/gpu/common:util", "//tensorflow/lite/delegates/gpu/common/task:tensor_desc", diff --git a/tensorflow/lite/delegates/gpu/cl/cl_arguments.h b/tensorflow/lite/delegates/gpu/cl/cl_arguments.h index 673b24f63e2..170c5538e02 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_arguments.h +++ b/tensorflow/lite/delegates/gpu/cl/cl_arguments.h @@ -21,8 +21,8 @@ limitations under the License. #include <vector> #include "tensorflow/lite/delegates/gpu/cl/cl_context.h" -#include "tensorflow/lite/delegates/gpu/cl/device_info.h" #include "tensorflow/lite/delegates/gpu/cl/gpu_object.h" +#include "tensorflow/lite/delegates/gpu/common/gpu_info.h" #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/common/task/arguments.h" diff --git a/tensorflow/lite/delegates/gpu/cl/cl_arguments_test.cc b/tensorflow/lite/delegates/gpu/cl/cl_arguments_test.cc index ddca3d4dc3a..69a490b9493 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_arguments_test.cc +++ b/tensorflow/lite/delegates/gpu/cl/cl_arguments_test.cc @@ -21,8 +21,8 @@ limitations under the License. #include <gtest/gtest.h> #include "absl/strings/match.h" #include "tensorflow/lite/delegates/gpu/cl/buffer.h" -#include "tensorflow/lite/delegates/gpu/cl/device_info.h" #include "tensorflow/lite/delegates/gpu/cl/gpu_object.h" +#include "tensorflow/lite/delegates/gpu/common/gpu_info.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/cl/cl_context.cc b/tensorflow/lite/delegates/gpu/cl/cl_context.cc index 32a5e43d799..e1cc54ba574 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_context.cc +++ b/tensorflow/lite/delegates/gpu/cl/cl_context.cc @@ -54,29 +54,29 @@ void AddSupportedImageFormats(cl_context context, GpuInfo* info) { auto supported_formats = GetSupportedImage2DFormats(context, CL_MEM_READ_WRITE); for (auto format : supported_formats) { - info->supports_r_f16_tex2d = - info->supports_r_f16_tex2d || + info->opencl_info.supports_r_f16_tex2d = + info->opencl_info.supports_r_f16_tex2d || IsEqualToImageFormat(format, DataType::FLOAT16, 1); - info->supports_rg_f16_tex2d = - info->supports_rg_f16_tex2d || + info->opencl_info.supports_rg_f16_tex2d = + info->opencl_info.supports_rg_f16_tex2d || IsEqualToImageFormat(format, DataType::FLOAT16, 2); - info->supports_rgb_f16_tex2d = - info->supports_rgb_f16_tex2d || + info->opencl_info.supports_rgb_f16_tex2d = + info->opencl_info.supports_rgb_f16_tex2d || IsEqualToImageFormat(format, DataType::FLOAT16, 3); - info->supports_rgba_f16_tex2d = - info->supports_rgba_f16_tex2d || + info->opencl_info.supports_rgba_f16_tex2d = + info->opencl_info.supports_rgba_f16_tex2d || IsEqualToImageFormat(format, DataType::FLOAT16, 4); - info->supports_r_f32_tex2d = - info->supports_r_f32_tex2d || + info->opencl_info.supports_r_f32_tex2d = + info->opencl_info.supports_r_f32_tex2d || 
IsEqualToImageFormat(format, DataType::FLOAT32, 1); - info->supports_rg_f32_tex2d = - info->supports_rg_f32_tex2d || + info->opencl_info.supports_rg_f32_tex2d = + info->opencl_info.supports_rg_f32_tex2d || IsEqualToImageFormat(format, DataType::FLOAT32, 2); - info->supports_rgb_f32_tex2d = - info->supports_rgb_f32_tex2d || + info->opencl_info.supports_rgb_f32_tex2d = + info->opencl_info.supports_rgb_f32_tex2d || IsEqualToImageFormat(format, DataType::FLOAT32, 3); - info->supports_rgba_f32_tex2d = - info->supports_rgba_f32_tex2d || + info->opencl_info.supports_rgba_f32_tex2d = + info->opencl_info.supports_rgba_f32_tex2d || IsEqualToImageFormat(format, DataType::FLOAT32, 4); } } @@ -148,7 +148,7 @@ absl::Status CreateCLGLContext(const CLDevice& device, cl_context_properties egl_context, cl_context_properties egl_display, CLContext* result) { - if (!device.SupportsExtension("cl_khr_gl_sharing")) { + if (!device.GetInfo().SupportsExtension("cl_khr_gl_sharing")) { return absl::UnavailableError("Device doesn't support CL-GL sharing."); } cl_context_properties platform = diff --git a/tensorflow/lite/delegates/gpu/cl/cl_device.cc b/tensorflow/lite/delegates/gpu/cl/cl_device.cc index 26206529235..1bd5db7b646 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_device.cc +++ b/tensorflow/lite/delegates/gpu/cl/cl_device.cc @@ -89,34 +89,34 @@ void GetDeviceWorkDimsSizes(cl_device_id id, int3* result) { result->z = limits[2]; } -OpenCLVersion ParseCLVersion(const std::string& version) { +OpenClVersion ParseCLVersion(const std::string& version) { const auto first_dot_pos = version.find_first_of('.'); if (first_dot_pos == std::string::npos) { - return OpenCLVersion::CL_1_0; + return OpenClVersion::kCl1_0; } const int major = version[first_dot_pos - 1] - '0'; const int minor = version[first_dot_pos + 1] - '0'; if (major == 1) { if (minor == 2) { - return OpenCLVersion::CL_1_2; + return OpenClVersion::kCl1_2; } else if (minor == 1) { - return OpenCLVersion::CL_1_1; + return OpenClVersion::kCl1_1; } else { - return OpenCLVersion::CL_1_0; + return OpenClVersion::kCl1_0; } } else if (major == 2) { if (minor == 2) { - return OpenCLVersion::CL_2_2; + return OpenClVersion::kCl2_2; } else if (minor == 1) { - return OpenCLVersion::CL_2_1; + return OpenClVersion::kCl2_1; } else { - return OpenCLVersion::CL_2_0; + return OpenClVersion::kCl2_0; } } else if (major == 3) { - return OpenCLVersion::CL_3_0; + return OpenClVersion::kCl3_0; } else { - return OpenCLVersion::CL_1_0; + return OpenClVersion::kCl1_0; } } @@ -162,87 +162,90 @@ GpuInfo GpuInfoFromDeviceID(cl_device_id id) { const auto vendor_name = GetDeviceInfo<std::string>(id, CL_DEVICE_VENDOR); const auto opencl_c_version = GetDeviceInfo<std::string>(id, CL_DEVICE_OPENCL_C_VERSION); - info.gpu_vendor = ParseVendor(device_name, vendor_name); + info.gpu_api = GpuApi::kOpenCl; + info.vendor = ParseVendor(device_name, vendor_name); if (info.IsAdreno()) { info.adreno_info = AdrenoInfo(opencl_c_version); } else if (info.IsMali()) { info.mali_info = MaliInfo(device_name); } - info.cl_version = ParseCLVersion(opencl_c_version); - info.extensions = + info.opencl_info.cl_version = ParseCLVersion(opencl_c_version); + info.opencl_info.extensions = absl::StrSplit(GetDeviceInfo<std::string>(id, CL_DEVICE_EXTENSIONS), ' '); - info.supports_fp16 = false; - info.supports_image3d_writes = false; - for (const auto& ext : info.extensions) { + info.opencl_info.supports_fp16 = false; + info.opencl_info.supports_image3d_writes = false; + for (const auto& ext : 
info.opencl_info.extensions) { if (ext == "cl_khr_fp16") { - info.supports_fp16 = true; + info.opencl_info.supports_fp16 = true; } if (ext == "cl_khr_3d_image_writes") { - info.supports_image3d_writes = true; + info.opencl_info.supports_image3d_writes = true; } } cl_device_fp_config f32_config = GetDeviceInfo<cl_device_fp_config>(id, CL_DEVICE_SINGLE_FP_CONFIG); - info.supports_fp32_rtn = f32_config & CL_FP_ROUND_TO_NEAREST; + info.opencl_info.supports_fp32_rtn = f32_config & CL_FP_ROUND_TO_NEAREST; - if (info.supports_fp16) { + if (info.opencl_info.supports_fp16) { cl_device_fp_config f16_config; auto status = GetDeviceInfo<cl_device_fp_config>( id, CL_DEVICE_HALF_FP_CONFIG, &f16_config); // AMD supports cl_khr_fp16 but CL_DEVICE_HALF_FP_CONFIG is empty. if (status.ok() && !info.IsAMD()) { - info.supports_fp16_rtn = f16_config & CL_FP_ROUND_TO_NEAREST; + info.opencl_info.supports_fp16_rtn = f16_config & CL_FP_ROUND_TO_NEAREST; } else { // happens on PowerVR f16_config = f32_config; - info.supports_fp16_rtn = info.supports_fp32_rtn; + info.opencl_info.supports_fp16_rtn = info.opencl_info.supports_fp32_rtn; } } else { - info.supports_fp16_rtn = false; + info.opencl_info.supports_fp16_rtn = false; } - if (info.IsPowerVR() && !info.supports_fp16) { + if (info.IsPowerVR() && !info.opencl_info.supports_fp16) { // PowerVR doesn't have full support of fp16 and so doesn't list this // extension. But it can support fp16 in MADs and as buffers/textures types, // so we will use it. - info.supports_fp16 = true; - info.supports_fp16_rtn = info.supports_fp32_rtn; + info.opencl_info.supports_fp16 = true; + info.opencl_info.supports_fp16_rtn = info.opencl_info.supports_fp32_rtn; } - if (!info.supports_image3d_writes && + if (!info.opencl_info.supports_image3d_writes && ((info.IsAdreno() && info.adreno_info.IsAdreno4xx()) || info.IsNvidia())) { // in local tests Adreno 430 can write in image 3d, at least on small sizes, // but it doesn't have cl_khr_3d_image_writes in list of available // extensions // The same for NVidia - info.supports_image3d_writes = true; + info.opencl_info.supports_image3d_writes = true; } - info.compute_units_count = + info.opencl_info.compute_units_count = GetDeviceInfo<cl_uint>(id, CL_DEVICE_MAX_COMPUTE_UNITS); - info.image2d_max_width = + info.opencl_info.image2d_max_width = GetDeviceInfo<size_t>(id, CL_DEVICE_IMAGE2D_MAX_WIDTH); - info.image2d_max_height = + info.opencl_info.image2d_max_height = GetDeviceInfo<size_t>(id, CL_DEVICE_IMAGE2D_MAX_HEIGHT); - info.buffer_max_size = + info.opencl_info.buffer_max_size = GetDeviceInfo<cl_ulong>(id, CL_DEVICE_MAX_MEM_ALLOC_SIZE); - if (info.cl_version >= OpenCLVersion::CL_1_2) { - info.image_buffer_max_size = + if (info.opencl_info.cl_version >= OpenClVersion::kCl1_2) { + info.opencl_info.image_buffer_max_size = GetDeviceInfo<size_t>(id, CL_DEVICE_IMAGE_MAX_BUFFER_SIZE); - info.image_array_max_layers = + info.opencl_info.image_array_max_layers = GetDeviceInfo<size_t>(id, CL_DEVICE_IMAGE_MAX_ARRAY_SIZE); } - info.image3d_max_width = + info.opencl_info.image3d_max_width = GetDeviceInfo<size_t>(id, CL_DEVICE_IMAGE3D_MAX_WIDTH); - info.image3d_max_height = + info.opencl_info.image3d_max_height = GetDeviceInfo<size_t>(id, CL_DEVICE_IMAGE2D_MAX_HEIGHT); - info.image3d_max_depth = + info.opencl_info.image3d_max_depth = GetDeviceInfo<size_t>(id, CL_DEVICE_IMAGE3D_MAX_DEPTH); int3 max_work_group_sizes; GetDeviceWorkDimsSizes(id, &max_work_group_sizes); - info.max_work_group_size_x = max_work_group_sizes.x; - info.max_work_group_size_y = 
max_work_group_sizes.y; - info.max_work_group_size_z = max_work_group_sizes.z; + info.opencl_info.max_work_group_size_x = max_work_group_sizes.x; + info.opencl_info.max_work_group_size_y = max_work_group_sizes.y; + info.opencl_info.max_work_group_size_z = max_work_group_sizes.z; + info.opencl_info.max_work_group_total_size = + GetDeviceInfo<size_t>(id, CL_DEVICE_MAX_WORK_GROUP_SIZE); if (info.IsIntel()) { if (info.SupportsExtension("cl_intel_required_subgroup_size")) { @@ -300,48 +303,10 @@ CLDevice& CLDevice::operator=(CLDevice&& device) { return *this; } -bool CLDevice::SupportsFP16() const { return info_.supports_fp16; } - -bool CLDevice::SupportsExtension(const std::string& extension) const { - return info_.SupportsExtension(extension); -} - -bool CLDevice::SupportsTextureArray() const { - return info_.SupportsTextureArray(); -} - -bool CLDevice::SupportsImageBuffer() const { - return info_.SupportsImageBuffer(); -} - -bool CLDevice::SupportsImage3D() const { return info_.SupportsImage3D(); } - -bool CLDevice::SupportsFP32RTN() const { return info_.supports_fp32_rtn; } - -bool CLDevice::SupportsFP16RTN() const { return info_.supports_fp16_rtn; } - std::string CLDevice::GetPlatformVersion() const { return GetPlatformInfo(platform_id_, CL_PLATFORM_VERSION); } -bool CLDevice::IsCL20OrHigher() const { return info_.IsCL20OrHigher(); } - -bool CLDevice::SupportsSubGroupWithSize(int sub_group_size) const { - return info_.SupportsSubGroupWithSize(sub_group_size); -} - -bool CLDevice::IsAdreno() const { return info_.IsAdreno(); } - -bool CLDevice::IsPowerVR() const { return info_.IsPowerVR(); } - -bool CLDevice::IsNvidia() const { return info_.IsNvidia(); } - -bool CLDevice::IsMali() const { return info_.IsMali(); } - -bool CLDevice::IsAMD() const { return info_.IsAMD(); } - -bool CLDevice::IsIntel() const { return info_.IsIntel(); } - void CLDevice::DisableOneLayerTextureArray() { info_.adreno_info.support_one_layer_texture_array = false; } diff --git a/tensorflow/lite/delegates/gpu/cl/cl_device.h b/tensorflow/lite/delegates/gpu/cl/cl_device.h index 3614e2211f1..b94704b40d6 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_device.h +++ b/tensorflow/lite/delegates/gpu/cl/cl_device.h @@ -19,9 +19,9 @@ limitations under the License. #include <string> #include <vector> -#include "tensorflow/lite/delegates/gpu/cl/device_info.h" #include "tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h" #include "tensorflow/lite/delegates/gpu/cl/util.h" +#include "tensorflow/lite/delegates/gpu/common/gpu_info.h" #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/common/types.h" @@ -46,24 +46,6 @@ class CLDevice { cl_platform_id platform() const { return platform_id_; } std::string GetPlatformVersion() const; - GpuVendor vendor() const { return info_.gpu_vendor; } - OpenCLVersion cl_version() const { return info_.cl_version; } - bool SupportsFP16() const; - bool SupportsTextureArray() const; - bool SupportsImageBuffer() const; - bool SupportsImage3D() const; - bool SupportsExtension(const std::string& extension) const; - bool SupportsFP32RTN() const; - bool SupportsFP16RTN() const; - bool IsCL20OrHigher() const; - bool SupportsSubGroupWithSize(int sub_group_size) const; - bool IsAdreno() const; - bool IsPowerVR() const; - bool IsNvidia() const; - bool IsMali() const; - bool IsAMD() const; - bool IsIntel() const; - // To track bug on some Adreno. 
b/131099086 void DisableOneLayerTextureArray(); diff --git a/tensorflow/lite/delegates/gpu/cl/cl_operation.h b/tensorflow/lite/delegates/gpu/cl/cl_operation.h index 3e4d40a5fa0..b357e2f087b 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_operation.h +++ b/tensorflow/lite/delegates/gpu/cl/cl_operation.h @@ -24,9 +24,9 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/cl/cl_context.h" #include "tensorflow/lite/delegates/gpu/cl/cl_device.h" #include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/cl/program_cache.h" #include "tensorflow/lite/delegates/gpu/cl/tensor.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/cl/device_info.cc b/tensorflow/lite/delegates/gpu/cl/device_info.cc deleted file mode 100644 index 94c3cce4a74..00000000000 --- a/tensorflow/lite/delegates/gpu/cl/device_info.cc +++ /dev/null @@ -1,376 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/delegates/gpu/cl/device_info.h" - -#include <algorithm> -#include <string> -#include <vector> - -#include "absl/strings/numbers.h" -#include "absl/strings/str_split.h" - -namespace tflite { -namespace gpu { -namespace cl { -namespace { -AdrenoGpu GetAdrenoGpuVersion(const std::string& device_name) { - const std::map<std::string, AdrenoGpu> kMapping = { - // Adreno 6xx series - {"685", AdrenoGpu::kAdreno685}, - {"680", AdrenoGpu::kAdreno680}, - {"675", AdrenoGpu::kAdreno675}, - {"650", AdrenoGpu::kAdreno650}, - {"640", AdrenoGpu::kAdreno640}, - {"630", AdrenoGpu::kAdreno630}, - {"620", AdrenoGpu::kAdreno620}, - {"616", AdrenoGpu::kAdreno618}, - {"616", AdrenoGpu::kAdreno616}, - {"615", AdrenoGpu::kAdreno615}, - {"612", AdrenoGpu::kAdreno612}, - {"610", AdrenoGpu::kAdreno610}, - {"605", AdrenoGpu::kAdreno605}, - // Adreno 5xx series - {"540", AdrenoGpu::kAdreno540}, - {"530", AdrenoGpu::kAdreno530}, - {"512", AdrenoGpu::kAdreno512}, - {"510", AdrenoGpu::kAdreno510}, - {"509", AdrenoGpu::kAdreno509}, - {"508", AdrenoGpu::kAdreno508}, - {"506", AdrenoGpu::kAdreno506}, - {"505", AdrenoGpu::kAdreno505}, - {"504", AdrenoGpu::kAdreno504}, - // Adreno 4xx series - {"430", AdrenoGpu::kAdreno430}, - {"420", AdrenoGpu::kAdreno420}, - {"418", AdrenoGpu::kAdreno418}, - {"405", AdrenoGpu::kAdreno405}, - // Adreno 3xx series - {"330", AdrenoGpu::kAdreno330}, - {"320", AdrenoGpu::kAdreno320}, - {"308", AdrenoGpu::kAdreno308}, - {"306", AdrenoGpu::kAdreno306}, - {"305", AdrenoGpu::kAdreno305}, - {"304", AdrenoGpu::kAdreno304}, - // Adreno 2xx series - {"225", AdrenoGpu::kAdreno225}, - {"220", AdrenoGpu::kAdreno220}, - {"205", AdrenoGpu::kAdreno205}, - {"203", AdrenoGpu::kAdreno203}, - {"200", AdrenoGpu::kAdreno200}, - // Adreno 1xx series - {"130", 
AdrenoGpu::kAdreno130}, - {"120", AdrenoGpu::kAdreno120}, - }; - - for (const auto& v : kMapping) { - if (device_name.find(v.first) != std::string::npos) { - return v.second; - } - } - return AdrenoGpu::kUnknown; -} - -MaliGpu GetMaliGpuVersion(const std::string& gpu_description) { - const std::map<std::string, MaliGpu> kMapping = { - {"t604", MaliGpu::kT604}, {"t622", MaliGpu::kT622}, - {"t624", MaliGpu::kT624}, {"t628", MaliGpu::kT628}, - {"t658", MaliGpu::kT658}, {"t678", MaliGpu::kT678}, - {"t720", MaliGpu::kT720}, {"t760", MaliGpu::kT760}, - {"t820", MaliGpu::kT820}, {"t830", MaliGpu::kT830}, - {"t860", MaliGpu::kT860}, {"t880", MaliGpu::kT880}, - {"g31", MaliGpu::kG31}, {"g51", MaliGpu::kG51}, - {"g71", MaliGpu::kG71}, {"g52", MaliGpu::kG52}, - {"g72", MaliGpu::kG72}, {"g76", MaliGpu::kG76}, - {"g57", MaliGpu::kG57}, {"g77", MaliGpu::kG77}, - {"g68", MaliGpu::kG68}, {"g78", MaliGpu::kG78}, - }; - for (const auto& v : kMapping) { - if (gpu_description.find(v.first) != std::string::npos) { - return v.second; - } - } - return MaliGpu::kUnknown; -} - -} // namespace - -std::string GpuVendorToString(GpuVendor v) { - switch (v) { - case GpuVendor::kApple: - return "Apple"; - case GpuVendor::kQualcomm: - return "Qualcomm"; - case GpuVendor::kMali: - return "Mali"; - case GpuVendor::kPowerVR: - return "PowerVR"; - case GpuVendor::kNvidia: - return "NVIDIA"; - case GpuVendor::kAMD: - return "AMD"; - case GpuVendor::kIntel: - return "Intel"; - case GpuVendor::kUnknown: - return "unknown vendor"; - } -} - -std::string OpenCLVersionToString(OpenCLVersion version) { - switch (version) { - case OpenCLVersion::CL_1_0: - return "1.0"; - case OpenCLVersion::CL_1_1: - return "1.1"; - case OpenCLVersion::CL_1_2: - return "1.2"; - case OpenCLVersion::CL_2_0: - return "2.0"; - case OpenCLVersion::CL_2_1: - return "2.1"; - case OpenCLVersion::CL_2_2: - return "2.2"; - case OpenCLVersion::CL_3_0: - return "3.0"; - } -} - -AdrenoInfo::AdrenoInfo(const std::string& device_version) - : adreno_gpu(GetAdrenoGpuVersion(device_version)) {} - -bool AdrenoInfo::IsAdreno1xx() const { - return adreno_gpu == AdrenoGpu::kAdreno120 || - adreno_gpu == AdrenoGpu::kAdreno130; -} - -bool AdrenoInfo::IsAdreno2xx() const { - return adreno_gpu == AdrenoGpu::kAdreno200 || - adreno_gpu == AdrenoGpu::kAdreno203 || - adreno_gpu == AdrenoGpu::kAdreno205 || - adreno_gpu == AdrenoGpu::kAdreno220 || - adreno_gpu == AdrenoGpu::kAdreno225; -} - -bool AdrenoInfo::IsAdreno3xx() const { - return adreno_gpu == AdrenoGpu::kAdreno304 || - adreno_gpu == AdrenoGpu::kAdreno305 || - adreno_gpu == AdrenoGpu::kAdreno306 || - adreno_gpu == AdrenoGpu::kAdreno308 || - adreno_gpu == AdrenoGpu::kAdreno320 || - adreno_gpu == AdrenoGpu::kAdreno330; -} - -bool AdrenoInfo::IsAdreno4xx() const { - return adreno_gpu == AdrenoGpu::kAdreno405 || - adreno_gpu == AdrenoGpu::kAdreno418 || - adreno_gpu == AdrenoGpu::kAdreno420 || - adreno_gpu == AdrenoGpu::kAdreno430; -} - -bool AdrenoInfo::IsAdreno5xx() const { - return adreno_gpu == AdrenoGpu::kAdreno504 || - adreno_gpu == AdrenoGpu::kAdreno505 || - adreno_gpu == AdrenoGpu::kAdreno506 || - adreno_gpu == AdrenoGpu::kAdreno508 || - adreno_gpu == AdrenoGpu::kAdreno509 || - adreno_gpu == AdrenoGpu::kAdreno510 || - adreno_gpu == AdrenoGpu::kAdreno512 || - adreno_gpu == AdrenoGpu::kAdreno530 || - adreno_gpu == AdrenoGpu::kAdreno540; -} - -bool AdrenoInfo::IsAdreno6xx() const { - return adreno_gpu == AdrenoGpu::kAdreno605 || - adreno_gpu == AdrenoGpu::kAdreno610 || - adreno_gpu == AdrenoGpu::kAdreno612 || - adreno_gpu 
== AdrenoGpu::kAdreno615 || - adreno_gpu == AdrenoGpu::kAdreno616 || - adreno_gpu == AdrenoGpu::kAdreno618 || - adreno_gpu == AdrenoGpu::kAdreno620 || - adreno_gpu == AdrenoGpu::kAdreno630 || - adreno_gpu == AdrenoGpu::kAdreno640 || - adreno_gpu == AdrenoGpu::kAdreno650 || - adreno_gpu == AdrenoGpu::kAdreno675 || - adreno_gpu == AdrenoGpu::kAdreno680 || - adreno_gpu == AdrenoGpu::kAdreno685; -} - -bool AdrenoInfo::IsAdreno6xxOrHigher() const { return IsAdreno6xx(); } - -int AdrenoInfo::GetMaximumWavesCount() const { - if (IsAdreno6xx()) { - if (adreno_gpu == AdrenoGpu::kAdreno640) { - return 30; - } else { - return 16; - } - } else { - // all other versions not supported - return 1; - } -} - -int AdrenoInfo::GetRegisterMemorySizePerComputeUnit() const { - if (IsAdreno6xx()) { - if (adreno_gpu == AdrenoGpu::kAdreno640) { - return 128 * 144 * 16; - } else if (adreno_gpu == AdrenoGpu::kAdreno650) { - return 128 * 64 * 16; - } else { - return 128 * 96 * 16; - } - } else { - // all other versions not supported - return 1; - } -} - -int AdrenoInfo::GetMaximumWavesCount(int register_footprint_per_tread, - bool full_wave) const { - const int register_usage_per_wave = - GetWaveSize(full_wave) * register_footprint_per_tread; - const int possible_waves_count = - GetRegisterMemorySizePerComputeUnit() / register_usage_per_wave; - return std::min(possible_waves_count, GetMaximumWavesCount()); -} - -int AdrenoInfo::GetWaveSize(bool full_wave) const { - if (IsAdreno6xx()) { - return full_wave ? 128 : 64; - } else if (IsAdreno5xx() || IsAdreno4xx()) { - return full_wave ? 64 : 32; - } else { - // all other versions not supported - return 1; - } -} - -MaliInfo::MaliInfo(const std::string& gpu_description) - : gpu_version(GetMaliGpuVersion(gpu_description)) {} - -bool MaliInfo::IsMaliT6xx() const { - return gpu_version == MaliGpu::kT604 || gpu_version == MaliGpu::kT622 || - gpu_version == MaliGpu::kT624 || gpu_version == MaliGpu::kT628 || - gpu_version == MaliGpu::kT658 || gpu_version == MaliGpu::kT678; -} - -bool MaliInfo::IsMaliT7xx() const { - return gpu_version == MaliGpu::kT720 || gpu_version == MaliGpu::kT760; -} - -bool MaliInfo::IsMaliT8xx() const { - return gpu_version == MaliGpu::kT820 || gpu_version == MaliGpu::kT830 || - gpu_version == MaliGpu::kT860 || gpu_version == MaliGpu::kT880; -} - -bool MaliInfo::IsMidgard() const { - return IsMaliT6xx() || IsMaliT7xx() || IsMaliT8xx(); -} - -bool MaliInfo::IsBifrostGen1() const { - return gpu_version == MaliGpu::kG31 || gpu_version == MaliGpu::kG51 || - gpu_version == MaliGpu::kG71; -} - -bool MaliInfo::IsBifrostGen2() const { - return gpu_version == MaliGpu::kG52 || gpu_version == MaliGpu::kG72; -} - -bool MaliInfo::IsBifrostGen3() const { return gpu_version == MaliGpu::kG76; } - -bool MaliInfo::IsBifrost() const { - return IsBifrostGen1() || IsBifrostGen2() || IsBifrostGen3(); -} - -bool MaliInfo::IsValhall() const { - return gpu_version == MaliGpu::kG57 || gpu_version == MaliGpu::kG77 || - gpu_version == MaliGpu::kG68 || gpu_version == MaliGpu::kG78; -} - -bool GpuInfo::SupportsTextureArray() const { - return cl_version >= OpenCLVersion::CL_1_2; -} - -bool GpuInfo::SupportsImageBuffer() const { - return cl_version >= OpenCLVersion::CL_1_2; -} - -bool GpuInfo::SupportsImage3D() const { - if (IsMali() && mali_info.IsMidgard()) { - // On Mali T880 read_imageh doesn't compile with image3d_t - return false; - } - return supports_image3d_writes; -} - -bool GpuInfo::SupportsFloatImage2D(DataType data_type, int channels) const { - if (channels == 1) { - 
return data_type == DataType::FLOAT32 ? supports_r_f32_tex2d - : supports_r_f16_tex2d; - } else if (channels == 2) { - return data_type == DataType::FLOAT32 ? supports_rg_f32_tex2d - : supports_rg_f16_tex2d; - } else if (channels == 3) { - return data_type == DataType::FLOAT32 ? supports_rgb_f32_tex2d - : supports_rgb_f16_tex2d; - } else if (channels == 4) { - return data_type == DataType::FLOAT32 ? supports_rgba_f32_tex2d - : supports_rgba_f16_tex2d; - } else { - return false; - } -} - -bool GpuInfo::SupportsExtension(const std::string& extension) const { - for (const auto& ext : extensions) { - if (ext == extension) { - return true; - } - } - return false; -} - -bool GpuInfo::IsCL20OrHigher() const { - return cl_version != OpenCLVersion::CL_1_0 && - cl_version != OpenCLVersion::CL_1_1 && - cl_version != OpenCLVersion::CL_1_2; -} - -bool GpuInfo::SupportsSubGroupWithSize(int sub_group_size) const { - for (auto subgroup_size : supported_subgroup_sizes) { - if (sub_group_size == subgroup_size) { - return true; - } - } - return false; -} - -bool GpuInfo::IsAdreno() const { return gpu_vendor == GpuVendor::kQualcomm; } - -bool GpuInfo::IsApple() const { return gpu_vendor == GpuVendor::kApple; } - -bool GpuInfo::IsMali() const { return gpu_vendor == GpuVendor::kMali; } - -bool GpuInfo::IsPowerVR() const { return gpu_vendor == GpuVendor::kPowerVR; } - -bool GpuInfo::IsNvidia() const { return gpu_vendor == GpuVendor::kNvidia; } - -bool GpuInfo::IsAMD() const { return gpu_vendor == GpuVendor::kAMD; } - -bool GpuInfo::IsIntel() const { return gpu_vendor == GpuVendor::kIntel; } - -} // namespace cl -} // namespace gpu -} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/device_info.h b/tensorflow/lite/delegates/gpu/cl/device_info.h deleted file mode 100644 index a2ba4ca2597..00000000000 --- a/tensorflow/lite/delegates/gpu/cl/device_info.h +++ /dev/null @@ -1,245 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_DEVICE_INFO_H_ -#define TENSORFLOW_LITE_DELEGATES_GPU_CL_DEVICE_INFO_H_ - -#include <string> -#include <vector> - -#include "tensorflow/lite/delegates/gpu/common/data_type.h" - -// for use only in device_info.cc, but keep here to make tests -int GetAdrenoGPUVersion(const std::string& gpu_version); - -namespace tflite { -namespace gpu { -namespace cl { - -enum class GpuVendor { - kApple, - kQualcomm, - kMali, - kPowerVR, - kNvidia, - kAMD, - kIntel, - kUnknown -}; - -std::string GpuVendorToString(GpuVendor v); - -enum class OpenCLVersion { - CL_1_0, - CL_1_1, - CL_1_2, - CL_2_0, - CL_2_1, - CL_2_2, - CL_3_0 -}; -std::string OpenCLVersionToString(OpenCLVersion version); - -enum class AdrenoGpu { - // Adreno 6xx series - kAdreno685, - kAdreno680, - kAdreno675, - kAdreno650, - kAdreno640, - kAdreno630, - kAdreno620, - kAdreno618, - kAdreno616, - kAdreno615, - kAdreno612, - kAdreno610, - kAdreno605, - // Adreno 5xx series - kAdreno540, - kAdreno530, - kAdreno512, - kAdreno510, - kAdreno509, - kAdreno508, - kAdreno506, - kAdreno505, - kAdreno504, - // Adreno 4xx series - kAdreno430, - kAdreno420, - kAdreno418, - kAdreno405, - // Adreno 3xx series - kAdreno330, - kAdreno320, - kAdreno308, - kAdreno306, - kAdreno305, - kAdreno304, - // Adreno 2xx series - kAdreno225, - kAdreno220, - kAdreno205, - kAdreno203, - kAdreno200, - // Adreno 1xx series - kAdreno130, - kAdreno120, - kUnknown -}; - -struct AdrenoInfo { - AdrenoInfo() = default; - explicit AdrenoInfo(const std::string& device_version); - - AdrenoGpu adreno_gpu; - - bool IsAdreno1xx() const; - bool IsAdreno2xx() const; - bool IsAdreno3xx() const; - bool IsAdreno4xx() const; - bool IsAdreno5xx() const; - bool IsAdreno6xx() const; - bool IsAdreno6xxOrHigher() const; - - // This function returns some not very documented physical parameter of - // Adreno6xx GPU. - // We obtained it using Snapdragon Profiler. - int GetMaximumWavesCount() const; - - // returns amount of register memory per CU(Compute Unit) in bytes. - int GetRegisterMemorySizePerComputeUnit() const; - - // returns maximum possible amount of waves based on register usage. - int GetMaximumWavesCount(int register_footprint_per_tread, - bool full_wave = true) const; - - int GetWaveSize(bool full_wave) const; - - // Not supported on some Adreno devices with specific driver version. 
- // b/131099086 - bool support_one_layer_texture_array = true; -}; - -enum class MaliGpu { - kUnknown, - kT604, - kT622, - kT624, - kT628, - kT658, - kT678, - kT720, - kT760, - kT820, - kT830, - kT860, - kT880, - kG31, - kG51, - kG71, - kG52, - kG72, - kG76, - kG57, - kG77, - kG68, - kG78, -}; - -struct MaliInfo { - MaliInfo() = default; - explicit MaliInfo(const std::string& gpu_description); - MaliGpu gpu_version; - - bool IsMaliT6xx() const; - bool IsMaliT7xx() const; - bool IsMaliT8xx() const; - bool IsMidgard() const; - bool IsBifrostGen1() const; - bool IsBifrostGen2() const; - bool IsBifrostGen3() const; - bool IsBifrost() const; - bool IsValhall() const; -}; - -struct GpuInfo { - GpuInfo() = default; - - bool IsAdreno() const; - bool IsApple() const; - bool IsMali() const; - bool IsPowerVR() const; - bool IsNvidia() const; - bool IsAMD() const; - bool IsIntel() const; - - bool SupportsTextureArray() const; - bool SupportsImageBuffer() const; - bool SupportsImage3D() const; - - bool SupportsFloatImage2D(DataType data_type, int channels) const; - - bool SupportsExtension(const std::string& extension) const; - bool IsCL20OrHigher() const; - bool SupportsSubGroupWithSize(int sub_group_size) const; - - std::vector<std::string> extensions; - bool supports_fp16; - bool supports_image3d_writes; - GpuVendor gpu_vendor; - OpenCLVersion cl_version; - int compute_units_count; - uint64_t buffer_max_size; - uint64_t image2d_max_width; - uint64_t image2d_max_height; - uint64_t image_buffer_max_size; - uint64_t image_array_max_layers; - uint64_t image3d_max_width; - uint64_t image3d_max_height; - uint64_t image3d_max_depth; - int max_work_group_size_x; - int max_work_group_size_y; - int max_work_group_size_z; - std::vector<int> supported_subgroup_sizes; - - // rtn is ROUND_TO_NEAREST - // with rtn precision is much better then with rtz (ROUND_TO_ZERO) - // Adreno 3xx supports only rtz, Adreno 4xx and more support rtn - // Mali from T6xx supports rtn - // PowerVR supports only rtz - bool supports_fp32_rtn; - bool supports_fp16_rtn; - - bool supports_r_f16_tex2d = false; - bool supports_rg_f16_tex2d = false; - bool supports_rgb_f16_tex2d = false; - bool supports_rgba_f16_tex2d = false; - - bool supports_r_f32_tex2d = false; - bool supports_rg_f32_tex2d = false; - bool supports_rgb_f32_tex2d = false; - bool supports_rgba_f32_tex2d = false; - - AdrenoInfo adreno_info; - MaliInfo mali_info; -}; - -} // namespace cl -} // namespace gpu -} // namespace tflite - -#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_DEVICE_INFO_H_ diff --git a/tensorflow/lite/delegates/gpu/cl/environment.cc b/tensorflow/lite/delegates/gpu/cl/environment.cc index 9b2fef288fe..275ea696e09 100644 --- a/tensorflow/lite/delegates/gpu/cl/environment.cc +++ b/tensorflow/lite/delegates/gpu/cl/environment.cc @@ -48,6 +48,39 @@ absl::Status CreateEnvironment(Environment* result, bool shared, return result->Init(); } +bool IsGpuSupportsStorageType(const GpuInfo& gpu_info, + TensorStorageType storage_type) { + switch (storage_type) { + case TensorStorageType::TEXTURE_2D: + return !gpu_info.IsAMD(); + case TensorStorageType::BUFFER: + return true; + case TensorStorageType::TEXTURE_ARRAY: + return !gpu_info.IsAMD() && gpu_info.SupportsTextureArray(); + case TensorStorageType::IMAGE_BUFFER: + return (gpu_info.IsAdreno() || gpu_info.IsAMD() || gpu_info.IsNvidia()) && + gpu_info.SupportsImageBuffer(); + case TensorStorageType::TEXTURE_3D: + return !gpu_info.IsAMD() && gpu_info.SupportsImage3D(); + case TensorStorageType::SINGLE_TEXTURE_2D: + return 
false; + case TensorStorageType::UNKNOWN: + return false; + } + return false; +} + +bool IsGpuSupportsPrecision(const GpuInfo& gpu_info, + CalculationsPrecision precision) { + switch (precision) { + case CalculationsPrecision::F32_F16: + case CalculationsPrecision::F16: + return gpu_info.SupportsFP16(); + case CalculationsPrecision::F32: + return true; + } +} + } // namespace Environment::Environment(CLDevice&& device, CLContext&& context, @@ -77,7 +110,8 @@ Environment& Environment::operator=(Environment&& environment) { } absl::Status Environment::Init() { - if (device().IsAdreno() && device().SupportsTextureArray()) { + if (device().GetInfo().IsAdreno() && + device().GetInfo().SupportsTextureArray()) { const auto& adreno_info = device().info_.adreno_info; // Some Adreno < 600 have bug with one layer texture array. b/131099086 // If we have one layer texture array and will write smt from kernel to this @@ -117,13 +151,7 @@ std::vector<CalculationsPrecision> Environment::GetSupportedPrecisions() const { } bool Environment::IsSupported(CalculationsPrecision precision) const { - switch (precision) { - case CalculationsPrecision::F32_F16: - case CalculationsPrecision::F16: - return device_.SupportsFP16(); - case CalculationsPrecision::F32: - return true; - } + return IsGpuSupportsPrecision(device_.GetInfo(), precision); } std::vector<TensorStorageType> Environment::GetSupportedStorages() const { @@ -153,24 +181,7 @@ Environment::GetSupportedStoragesWithHWZeroClampSupport() const { } bool Environment::IsSupported(TensorStorageType storage_type) const { - switch (storage_type) { - case TensorStorageType::TEXTURE_2D: - return !device_.IsAMD(); - case TensorStorageType::BUFFER: - return true; - case TensorStorageType::TEXTURE_ARRAY: - return !device_.IsAMD() && device_.SupportsTextureArray(); - case TensorStorageType::IMAGE_BUFFER: - return (device_.IsAdreno() || device_.IsAMD() || device_.IsNvidia()) && - device_.SupportsImageBuffer(); - case TensorStorageType::TEXTURE_3D: - return !device_.IsAMD() && device_.SupportsImage3D(); - case TensorStorageType::SINGLE_TEXTURE_2D: - return false; - case TensorStorageType::UNKNOWN: - return false; - } - return false; + return IsGpuSupportsStorageType(device_.GetInfo(), storage_type); } TensorStorageType GetFastestStorageType(const GpuInfo& gpu_info) { diff --git a/tensorflow/lite/delegates/gpu/cl/environment.h b/tensorflow/lite/delegates/gpu/cl/environment.h index 8917351841f..5138e59ee7e 100644 --- a/tensorflow/lite/delegates/gpu/cl/environment.h +++ b/tensorflow/lite/delegates/gpu/cl/environment.h @@ -19,9 +19,9 @@ limitations under the License. 
#include "tensorflow/lite/delegates/gpu/cl/cl_command_queue.h" #include "tensorflow/lite/delegates/gpu/cl/cl_context.h" #include "tensorflow/lite/delegates/gpu/cl/cl_device.h" -#include "tensorflow/lite/delegates/gpu/cl/device_info.h" #include "tensorflow/lite/delegates/gpu/cl/program_cache.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/gpu_info.h" #include "tensorflow/lite/delegates/gpu/common/precision.h" #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h" diff --git a/tensorflow/lite/delegates/gpu/cl/gl_interop.cc b/tensorflow/lite/delegates/gpu/cl/gl_interop.cc index 599e6766301..2d4e6c54b39 100644 --- a/tensorflow/lite/delegates/gpu/cl/gl_interop.cc +++ b/tensorflow/lite/delegates/gpu/cl/gl_interop.cc @@ -89,7 +89,7 @@ absl::Status CreateClEventFromEglSync(cl_context context, } bool IsClEventFromEglSyncSupported(const CLDevice& device) { - return device.SupportsExtension("cl_khr_egl_event"); + return device.GetInfo().SupportsExtension("cl_khr_egl_event"); } absl::Status CreateClMemoryFromGlBuffer(GLuint gl_ssbo_id, @@ -126,7 +126,7 @@ absl::Status CreateClMemoryFromGlTexture(GLenum texture_target, bool IsGlSharingSupported(const CLDevice& device) { return clCreateFromGLBuffer && clCreateFromGLTexture && - device.SupportsExtension("cl_khr_gl_sharing"); + device.GetInfo().SupportsExtension("cl_khr_gl_sharing"); } AcquiredGlObjects::~AcquiredGlObjects() { Release({}, nullptr).IgnoreError(); } diff --git a/tensorflow/lite/delegates/gpu/cl/inference_context.cc b/tensorflow/lite/delegates/gpu/cl/inference_context.cc index 332de066bca..3bfc455865d 100644 --- a/tensorflow/lite/delegates/gpu/cl/inference_context.cc +++ b/tensorflow/lite/delegates/gpu/cl/inference_context.cc @@ -27,7 +27,6 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "tensorflow/lite/delegates/gpu/cl/buffer.h" #include "tensorflow/lite/delegates/gpu/cl/cl_device.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h" #include "tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h" #include "tensorflow/lite/delegates/gpu/cl/storage_type_util.h" @@ -38,6 +37,7 @@ limitations under the License. 
#include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/precision.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h" #include "tensorflow/lite/delegates/gpu/common/transformations/add_bias.h" #include "tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op.h" @@ -163,14 +163,14 @@ absl::Status InferenceContext::InitFromGraph( ReserveGraphTensors(create_info, creation_context.GetGpuInfo(), graph); precision_ = create_info.precision; storage_type_ = create_info.storage_type; - if (env->device().IsMali()) { + if (env->device().GetInfo().IsMali()) { need_flush_ = true; need_manual_release_ = true; flush_periodically_ = true; flush_period_ = 24; } - if (env->device().IsPowerVR()) { + if (env->device().GetInfo().IsPowerVR()) { need_flush_ = true; } CopyInAndOutIds(graph); diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/BUILD b/tensorflow/lite/delegates/gpu/cl/kernels/BUILD index 402c30c247e..afdadd86133 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/kernels/BUILD @@ -13,11 +13,10 @@ cc_library( srcs = ["add.cc"], hdrs = ["add.h"], deps = [ - ":gpu_operation", - ":util", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:util", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", "@com_google_absl//absl/strings", ], ) @@ -79,13 +78,12 @@ cc_library( srcs = ["concat_xy.cc"], hdrs = ["concat_xy.h"], deps = [ - ":gpu_operation", - ":util", - ":work_group_picking", "//tensorflow/lite/delegates/gpu/cl:cl_kernel", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", + "//tensorflow/lite/delegates/gpu/common/task:work_group_picking", ], ) @@ -94,14 +92,13 @@ cc_library( srcs = ["concat_z.cc"], hdrs = ["concat_z.h"], deps = [ - ":gpu_operation", - ":util", - ":work_group_picking", "//tensorflow/lite/delegates/gpu/cl:cl_device", "//tensorflow/lite/delegates/gpu/cl:cl_kernel", "//tensorflow/lite/delegates/gpu/cl:tensor", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", + "//tensorflow/lite/delegates/gpu/common/task:work_group_picking", ], ) @@ -110,10 +107,6 @@ cc_library( srcs = ["conv_buffer_1x1.cc"], hdrs = ["conv_buffer_1x1.h"], deps = [ - ":conv_common", - ":gpu_operation", - ":util", - ":work_group_picking", "//tensorflow/lite/delegates/gpu/cl:buffer", "//tensorflow/lite/delegates/gpu/cl:cl_device", "//tensorflow/lite/delegates/gpu/cl:cl_kernel", @@ -127,14 +120,14 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:tensor", "//tensorflow/lite/delegates/gpu/common:types", "//tensorflow/lite/delegates/gpu/common:winograd_util", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", + "//tensorflow/lite/delegates/gpu/common/task:util", + "//tensorflow/lite/delegates/gpu/common/task:weights_conversion", + "//tensorflow/lite/delegates/gpu/common/task:weights_layout", + "//tensorflow/lite/delegates/gpu/common/task:work_group_picking", ], ) -cc_library( - name = "conv_common", - hdrs = ["conv_common.h"], -) - cc_test( name = "conv_buffer_1x1_test", srcs = 
["conv_buffer_1x1_test.cc"], @@ -157,9 +150,6 @@ cc_library( srcs = ["conv_constants.cc"], hdrs = ["conv_constants.h"], deps = [ - ":gpu_operation", - ":util", - ":work_group_picking", "//tensorflow/lite/delegates/gpu/cl:buffer", "//tensorflow/lite/delegates/gpu/cl:linear_storage", "//tensorflow/lite/delegates/gpu/cl:tensor", @@ -170,6 +160,9 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:tensor", "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", + "//tensorflow/lite/delegates/gpu/common/task:util", + "//tensorflow/lite/delegates/gpu/common/task:work_group_picking", ], ) @@ -195,10 +188,6 @@ cc_library( srcs = ["conv_powervr.cc"], hdrs = ["conv_powervr.h"], deps = [ - ":conv_common", - ":gpu_operation", - ":util", - ":work_group_picking", "//tensorflow/lite/delegates/gpu/cl:buffer", "//tensorflow/lite/delegates/gpu/cl:cl_device", "//tensorflow/lite/delegates/gpu/cl:linear_storage", @@ -212,6 +201,11 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:tensor", "//tensorflow/lite/delegates/gpu/common:types", "//tensorflow/lite/delegates/gpu/common:winograd_util", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", + "//tensorflow/lite/delegates/gpu/common/task:util", + "//tensorflow/lite/delegates/gpu/common/task:weights_conversion", + "//tensorflow/lite/delegates/gpu/common/task:weights_layout", + "//tensorflow/lite/delegates/gpu/common/task:work_group_picking", "@com_google_absl//absl/strings", ], ) @@ -238,14 +232,14 @@ cc_library( srcs = ["conv_weights_converter.cc"], hdrs = ["conv_weights_converter.h"], deps = [ - ":conv_common", - ":gpu_operation", - ":util", - ":work_group_picking", "//tensorflow/lite/delegates/gpu/cl:cl_command_queue", "//tensorflow/lite/delegates/gpu/cl:cl_kernel", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", + "//tensorflow/lite/delegates/gpu/common/task:util", + "//tensorflow/lite/delegates/gpu/common/task:weights_layout", + "//tensorflow/lite/delegates/gpu/common/task:work_group_picking", ], ) @@ -254,8 +248,6 @@ cc_library( srcs = ["converter.cc"], hdrs = ["converter.h"], deps = [ - ":util", - ":work_group_picking", "//tensorflow/lite/delegates/gpu:spi", "//tensorflow/lite/delegates/gpu/cl:cl_arguments", "//tensorflow/lite/delegates/gpu/cl:cl_command_queue", @@ -267,6 +259,8 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:util", "//tensorflow/lite/delegates/gpu/common/task:arguments", "//tensorflow/lite/delegates/gpu/common/task:tensor_desc", + "//tensorflow/lite/delegates/gpu/common/task:util", + "//tensorflow/lite/delegates/gpu/common/task:work_group_picking", ], ) @@ -275,10 +269,6 @@ cc_library( srcs = ["convolution_transposed.cc"], hdrs = ["convolution_transposed.h"], deps = [ - ":conv_common", - ":gpu_operation", - ":util", - ":work_group_picking", "//tensorflow/lite/delegates/gpu/cl:buffer", "//tensorflow/lite/delegates/gpu/cl:linear_storage", "//tensorflow/lite/delegates/gpu/cl:tensor", @@ -290,6 +280,10 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:tensor", "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", + "//tensorflow/lite/delegates/gpu/common/task:weights_conversion", + "//tensorflow/lite/delegates/gpu/common/task:weights_layout", + 
"//tensorflow/lite/delegates/gpu/common/task:work_group_picking", "@com_google_absl//absl/strings", ], ) @@ -316,9 +310,6 @@ cc_library( srcs = ["convolution_transposed_3x3.cc"], hdrs = ["convolution_transposed_3x3.h"], deps = [ - ":gpu_operation", - ":util", - ":work_group_picking", "//tensorflow/lite/delegates/gpu/cl:buffer", "//tensorflow/lite/delegates/gpu/cl:linear_storage", "//tensorflow/lite/delegates/gpu/cl:tensor", @@ -329,6 +320,10 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:tensor", "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", + "//tensorflow/lite/delegates/gpu/common/task:weights_conversion", + "//tensorflow/lite/delegates/gpu/common/task:weights_layout", + "//tensorflow/lite/delegates/gpu/common/task:work_group_picking", ], ) @@ -355,9 +350,6 @@ cc_library( srcs = ["convolution_transposed_3x3_thin.cc"], hdrs = ["convolution_transposed_3x3_thin.h"], deps = [ - ":gpu_operation", - ":util", - ":work_group_picking", "//tensorflow/lite/delegates/gpu/cl:buffer", "//tensorflow/lite/delegates/gpu/cl:linear_storage", "//tensorflow/lite/delegates/gpu/cl:tensor", @@ -368,6 +360,10 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:tensor", "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", + "//tensorflow/lite/delegates/gpu/common/task:weights_conversion", + "//tensorflow/lite/delegates/gpu/common/task:weights_layout", + "//tensorflow/lite/delegates/gpu/common/task:work_group_picking", ], ) @@ -393,9 +389,6 @@ cc_library( srcs = ["convolution_transposed_4x4.cc"], hdrs = ["convolution_transposed_4x4.h"], deps = [ - ":gpu_operation", - ":util", - ":work_group_picking", "//tensorflow/lite/delegates/gpu/cl:buffer", "//tensorflow/lite/delegates/gpu/cl:linear_storage", "//tensorflow/lite/delegates/gpu/cl:tensor", @@ -406,6 +399,10 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:tensor", "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", + "//tensorflow/lite/delegates/gpu/common/task:weights_conversion", + "//tensorflow/lite/delegates/gpu/common/task:weights_layout", + "//tensorflow/lite/delegates/gpu/common/task:work_group_picking", ], ) @@ -432,9 +429,6 @@ cc_library( srcs = ["convolution_transposed_thin.cc"], hdrs = ["convolution_transposed_thin.h"], deps = [ - ":gpu_operation", - ":util", - ":work_group_picking", "//tensorflow/lite/delegates/gpu/cl:buffer", "//tensorflow/lite/delegates/gpu/cl:tensor", "//tensorflow/lite/delegates/gpu/cl:texture2d", @@ -445,6 +439,8 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:tensor", "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", + "//tensorflow/lite/delegates/gpu/common/task:work_group_picking", ], ) @@ -470,9 +466,6 @@ cc_library( srcs = ["depthwise_conv.cc"], hdrs = ["depthwise_conv.h"], deps = [ - ":gpu_operation", - ":util", - ":work_group_picking", "//tensorflow/lite/delegates/gpu/cl:buffer", "//tensorflow/lite/delegates/gpu/cl:cl_device", "//tensorflow/lite/delegates/gpu/cl:linear_storage", @@ -485,6 +478,9 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:tensor", "//tensorflow/lite/delegates/gpu/common:types", + 
"//tensorflow/lite/delegates/gpu/common/task:gpu_operation", + "//tensorflow/lite/delegates/gpu/common/task:util", + "//tensorflow/lite/delegates/gpu/common/task:work_group_picking", ], ) @@ -510,9 +506,6 @@ cc_library( srcs = ["depthwise_conv_3x3.cc"], hdrs = ["depthwise_conv_3x3.h"], deps = [ - ":gpu_operation", - ":util", - ":work_group_picking", "//tensorflow/lite/delegates/gpu/cl:buffer", "//tensorflow/lite/delegates/gpu/cl:tensor", "//tensorflow/lite/delegates/gpu/cl:texture2d", @@ -523,6 +516,8 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:tensor", "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", + "//tensorflow/lite/delegates/gpu/common/task:work_group_picking", ], ) @@ -548,11 +543,10 @@ cc_library( srcs = ["elementwise.cc"], hdrs = ["elementwise.h"], deps = [ - ":gpu_operation", - ":util", "//tensorflow/lite/delegates/gpu/cl:storage_type_util", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", "@com_google_absl//absl/strings", ], ) @@ -579,11 +573,8 @@ cc_library( srcs = ["fully_connected.cc"], hdrs = ["fully_connected.h"], deps = [ - ":gpu_operation", - ":util", "//tensorflow/lite/delegates/gpu/cl:buffer", "//tensorflow/lite/delegates/gpu/cl:cl_kernel", - "//tensorflow/lite/delegates/gpu/cl:device_info", "//tensorflow/lite/delegates/gpu/cl:linear_storage", "//tensorflow/lite/delegates/gpu/cl:tensor", "//tensorflow/lite/delegates/gpu/cl:texture2d", @@ -593,6 +584,7 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:tensor", "//tensorflow/lite/delegates/gpu/common:types", "//tensorflow/lite/delegates/gpu/common:util", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", "@com_google_absl//absl/memory", ], ) @@ -608,52 +600,26 @@ cc_test( deps = [ ":cl_test", ":fully_connected", - ":gpu_operation", "//tensorflow/lite/delegates/gpu/cl:environment", "//tensorflow/lite/delegates/gpu/common:data_type", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:shape", "//tensorflow/lite/delegates/gpu/common:tensor", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", "@com_google_googletest//:gtest_main", ], ) -cc_library( - name = "gpu_operation", - srcs = ["gpu_operation.cc"], - hdrs = ["gpu_operation.h"], - deps = [ - ":util", - ":work_group_picking", - "//tensorflow/lite/delegates/gpu/cl:device_info", - "//tensorflow/lite/delegates/gpu/cl:serialization_cc_fbs", - "//tensorflow/lite/delegates/gpu/common:access_type", - "//tensorflow/lite/delegates/gpu/common:data_type", - "//tensorflow/lite/delegates/gpu/common:kernel_info", - "//tensorflow/lite/delegates/gpu/common:precision", - "//tensorflow/lite/delegates/gpu/common:status", - "//tensorflow/lite/delegates/gpu/common:types", - "//tensorflow/lite/delegates/gpu/common/task:arguments", - "//tensorflow/lite/delegates/gpu/common/task:buffer_desc", - "//tensorflow/lite/delegates/gpu/common/task:compiler_options", - "//tensorflow/lite/delegates/gpu/common/task:gpu_tensor", - "//tensorflow/lite/delegates/gpu/common/task:tensor_desc", - "//tensorflow/lite/delegates/gpu/common/task:tuning_type", - "@com_google_absl//absl/strings", - ], -) - cc_library( name = "lstm", srcs = ["lstm.cc"], hdrs = ["lstm.h"], deps = [ - ":gpu_operation", - ":util", - ":work_group_picking", "//tensorflow/lite/delegates/gpu/common:operations", 
"//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", + "//tensorflow/lite/delegates/gpu/common/task:work_group_picking", ], ) @@ -698,12 +664,11 @@ cc_library( srcs = ["max_unpooling.cc"], hdrs = ["max_unpooling.h"], deps = [ - ":gpu_operation", - ":util", - ":work_group_picking", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", + "//tensorflow/lite/delegates/gpu/common/task:work_group_picking", ], ) @@ -729,14 +694,12 @@ cc_library( srcs = ["mean_stddev_normalization.cc"], hdrs = ["mean_stddev_normalization.h"], deps = [ - ":gpu_operation", - ":util", - ":work_group_picking", "//tensorflow/lite/delegates/gpu/cl:cl_program", - "//tensorflow/lite/delegates/gpu/cl:device_info", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", + "//tensorflow/lite/delegates/gpu/common/task:work_group_picking", ], ) @@ -763,12 +726,11 @@ cc_library( srcs = ["padding.cc"], hdrs = ["padding.h"], deps = [ - ":gpu_operation", - ":util", - ":work_group_picking", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", + "//tensorflow/lite/delegates/gpu/common/task:work_group_picking", ], ) @@ -794,14 +756,14 @@ cc_library( srcs = ["pooling.cc"], hdrs = ["pooling.h"], deps = [ - ":gpu_operation", - ":util", - ":work_group_picking", "//tensorflow/lite/delegates/gpu/cl:cl_kernel", "//tensorflow/lite/delegates/gpu/cl:tensor", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", + "//tensorflow/lite/delegates/gpu/common/task:util", + "//tensorflow/lite/delegates/gpu/common/task:work_group_picking", ], ) @@ -827,8 +789,6 @@ cc_library( srcs = ["prelu.cc"], hdrs = ["prelu.h"], deps = [ - ":gpu_operation", - ":util", "//tensorflow/lite/delegates/gpu/cl:cl_context", "//tensorflow/lite/delegates/gpu/cl:cl_kernel", "//tensorflow/lite/delegates/gpu/cl:linear_storage", @@ -837,6 +797,7 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:tensor", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:variant", ], @@ -864,8 +825,6 @@ cc_library( srcs = ["quantize_and_dequantize.cc"], hdrs = ["quantize_and_dequantize.h"], deps = [ - ":gpu_operation", - ":util", "//tensorflow/lite/delegates/gpu/cl:cl_context", "//tensorflow/lite/delegates/gpu/cl:cl_kernel", "//tensorflow/lite/delegates/gpu/cl:linear_storage", @@ -873,6 +832,7 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:tensor", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:variant", ], @@ -901,14 +861,14 @@ cc_library( srcs = ["reduce.cc"], hdrs = ["reduce.h"], deps = [ - 
":gpu_operation", - ":util", - ":work_group_picking", "//tensorflow/lite/delegates/gpu/common:kernel_info", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:types", "//tensorflow/lite/delegates/gpu/common:util", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", + "//tensorflow/lite/delegates/gpu/common/task:util", + "//tensorflow/lite/delegates/gpu/common/task:work_group_picking", ], ) @@ -934,8 +894,8 @@ cc_library( srcs = ["relu.cc"], hdrs = ["relu.h"], deps = [ - ":gpu_operation", "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", "@com_google_absl//absl/strings", ], ) @@ -962,12 +922,11 @@ cc_library( srcs = ["reshape.cc"], hdrs = ["reshape.h"], deps = [ - ":gpu_operation", - ":util", - ":work_group_picking", "//tensorflow/lite/delegates/gpu/cl:cl_kernel", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", + "//tensorflow/lite/delegates/gpu/common/task:work_group_picking", ], ) @@ -993,13 +952,12 @@ cc_library( srcs = ["reshapex4.cc"], hdrs = ["reshapex4.h"], deps = [ - ":gpu_operation", - ":util", - ":work_group_picking", "//tensorflow/lite/delegates/gpu/cl:cl_command_queue", "//tensorflow/lite/delegates/gpu/cl:cl_kernel", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", + "//tensorflow/lite/delegates/gpu/common/task:work_group_picking", ], ) @@ -1025,13 +983,12 @@ cc_library( srcs = ["softmax.cc"], hdrs = ["softmax.h"], deps = [ - ":gpu_operation", - ":util", - ":work_group_picking", "//tensorflow/lite/delegates/gpu/cl:cl_kernel", "//tensorflow/lite/delegates/gpu/cl:tensor", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", + "//tensorflow/lite/delegates/gpu/common/task:work_group_picking", ], ) @@ -1057,11 +1014,11 @@ cc_library( srcs = ["softmax1x1.cc"], hdrs = ["softmax1x1.h"], deps = [ - ":gpu_operation", - ":util", "//tensorflow/lite/delegates/gpu/cl:cl_kernel", "//tensorflow/lite/delegates/gpu/cl:tensor", "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", + "//tensorflow/lite/delegates/gpu/common/task:util", ], ) @@ -1087,13 +1044,12 @@ cc_library( srcs = ["space_to_depth.cc"], hdrs = ["space_to_depth.h"], deps = [ - ":gpu_operation", - ":util", - ":work_group_picking", "//tensorflow/lite/delegates/gpu/cl:cl_kernel", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", + "//tensorflow/lite/delegates/gpu/common/task:work_group_picking", ], ) @@ -1119,11 +1075,10 @@ cc_library( srcs = ["strided_slice.cc"], hdrs = ["strided_slice.h"], deps = [ - ":gpu_operation", - ":util", - ":work_group_picking", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", + "//tensorflow/lite/delegates/gpu/common/task:work_group_picking", ], ) @@ -1149,11 +1104,10 @@ cc_library( srcs = ["transpose.cc"], hdrs = ["transpose.h"], deps = [ - ":gpu_operation", - ":util", - 
":work_group_picking", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", + "//tensorflow/lite/delegates/gpu/common/task:work_group_picking", "@com_google_absl//absl/strings", ], ) @@ -1180,12 +1134,11 @@ cc_library( srcs = ["resize.cc"], hdrs = ["resize.h"], deps = [ - ":gpu_operation", - ":util", - ":work_group_picking", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", + "//tensorflow/lite/delegates/gpu/common/task:work_group_picking", ], ) @@ -1206,31 +1159,11 @@ cc_test( ], ) -cc_library( - name = "util", - srcs = ["util.cc"], - hdrs = ["util.h"], - deps = [ - "//tensorflow/lite/delegates/gpu/cl:device_info", - "//tensorflow/lite/delegates/gpu/common:data_type", - "//tensorflow/lite/delegates/gpu/common:precision", - "//tensorflow/lite/delegates/gpu/common:shape", - "//tensorflow/lite/delegates/gpu/common:tensor", - "//tensorflow/lite/delegates/gpu/common:types", - "//tensorflow/lite/delegates/gpu/common:util", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - ], -) - cc_library( name = "winograd", srcs = ["winograd.cc"], hdrs = ["winograd.h"], deps = [ - ":gpu_operation", - ":util", - ":work_group_picking", "//tensorflow/lite/delegates/gpu/cl:cl_device", "//tensorflow/lite/delegates/gpu/cl:cl_kernel", "//tensorflow/lite/delegates/gpu/cl:linear_storage", @@ -1240,6 +1173,8 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:shape", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:winograd_util", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", + "//tensorflow/lite/delegates/gpu/common/task:work_group_picking", "@com_google_absl//absl/strings:str_format", ], ) @@ -1254,7 +1189,6 @@ cc_test( ], deps = [ ":cl_test", - ":util", ":winograd", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", @@ -1263,21 +1197,6 @@ cc_test( ], ) -cc_library( - name = "work_group_picking", - srcs = ["work_group_picking.cc"], - hdrs = ["work_group_picking.h"], - deps = [ - "//tensorflow/lite/delegates/gpu/cl:cl_kernel", - "//tensorflow/lite/delegates/gpu/cl:device_info", - "//tensorflow/lite/delegates/gpu/common:kernel_info", - "//tensorflow/lite/delegates/gpu/common:types", - "//tensorflow/lite/delegates/gpu/common:util", - "//tensorflow/lite/delegates/gpu/common:workgroup_selection", - "//tensorflow/lite/delegates/gpu/common/task:tuning_type", - ], -) - test_suite( name = "all_tests", tests = [ diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/add.cc b/tensorflow/lite/delegates/gpu/cl/kernels/add.cc index 1cb41e79d88..15714f07b23 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/add.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/add.cc @@ -18,7 +18,6 @@ limitations under the License. #include <string> #include "absl/strings/str_cat.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" #include "tensorflow/lite/delegates/gpu/common/util.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/add.h b/tensorflow/lite/delegates/gpu/cl/kernels/add.h index 0e9d7e0d333..3a10f71f6b4 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/add.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/add.h @@ -19,9 +19,9 @@ limitations under the License. 
#include <string> #include <vector> -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.cc b/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.cc index fa5b933db8a..faf893fa1f0 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.cc @@ -19,9 +19,8 @@ limitations under the License. #include <string> #include <vector> -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h" #include "tensorflow/lite/delegates/gpu/common/types.h" namespace tflite { @@ -69,7 +68,7 @@ std::string GetConcatKernelCode(const OperationDef& op_def, dst_coord += ", " + dst_coords[i]; } - std::string c = GetCommonDefines(op_def.precision); + std::string c; c += "__kernel void main_function(\n"; c += "$0) {\n"; if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.h b/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.h index 9dd3fcee52a..2ccdc2a26ad 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.h @@ -17,9 +17,9 @@ limitations under the License. #define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_CONCAT_XY_H_ #include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/types.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.cc b/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.cc index cefbf557ac1..6a264c6185c 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.cc @@ -17,9 +17,8 @@ limitations under the License. #include <string> -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h" #include "tensorflow/lite/delegates/gpu/common/types.h" namespace tflite { @@ -43,7 +42,7 @@ std::string GetConcatKernelCode(const OperationDef& op_def, tensor_names[i] = "src_tensor_" + std::to_string(i); } - std::string c = GetCommonDefines(op_def.precision); + std::string c; c += "__kernel void main_function(\n"; c += "$0) {\n"; c += " int X = get_global_id(0);\n"; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.h b/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.h index 16341af4187..227ce15d04d 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.h @@ -20,9 +20,9 @@ limitations under the License. 
#include "tensorflow/lite/delegates/gpu/cl/cl_device.h" #include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/cl/tensor.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/types.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.cc b/tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.cc index ed15df0bea7..70d79813829 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.cc @@ -20,9 +20,9 @@ limitations under the License. #include <utility> #include "tensorflow/lite/delegates/gpu/cl/cl_device.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/util.h" +#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h" namespace tflite { namespace gpu { @@ -139,7 +139,7 @@ ConvBuffer1x1::ConvParams GetBestParams(const GpuInfo& gpu_info, conv_params.element_size = 4; conv_params.block_size = int3(1, 1, 1); if (gpu_info.IsMali() && definition.precision == CalculationsPrecision::F16 && - gpu_info.compute_units_count <= 4) { + gpu_info.GetComputeUnitsCount() <= 4) { conv_params.block_size.x *= 2; } return conv_params; @@ -194,7 +194,7 @@ std::string ConvBuffer1x1::GenerateConvBuffer1x1( } AddDstTensor("dst_tensor", dst_desc); - std::string c = GetCommonDefines(op_def.precision); + std::string c; switch (op_def.precision) { case CalculationsPrecision::F32: c += "#define FLT8 float8\n"; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.h b/tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.h index d93e9a0460c..3687b963078 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.h @@ -18,9 +18,6 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/cl/buffer.h" #include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/conv_common.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" #include "tensorflow/lite/delegates/gpu/cl/linear_storage.h" #include "tensorflow/lite/delegates/gpu/cl/tensor.h" #include "tensorflow/lite/delegates/gpu/cl/util.h" @@ -28,6 +25,9 @@ limitations under the License. 
#include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" +#include "tensorflow/lite/delegates/gpu/common/task/weights_conversion.h" +#include "tensorflow/lite/delegates/gpu/common/task/weights_layout.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" #include "tensorflow/lite/delegates/gpu/common/types.h" #include "tensorflow/lite/delegates/gpu/common/winograd_util.h" @@ -52,9 +52,9 @@ class ConvBuffer1x1 : public GPUOperation { std::vector<int3>* work_groups) const override; int3 GetGridSize() const override; - ConvWeightsDescription GetConvWeightsDescription() const { - ConvWeightsDescription desc; - desc.layout = ConvWeightsLayout::kOHWIOGroupI4O4; + WeightsDescription GetWeightsDescription() const { + WeightsDescription desc; + desc.layout = WeightsLayout::kOHWIOGroupI4O4; desc.output_group_size = conv_params_.block_size.z; return desc; } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.cc b/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.cc index 99b1e041d1c..e272ad64f3a 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.cc @@ -18,8 +18,8 @@ limitations under the License. #include <string> #include <utility> -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" +#include "tensorflow/lite/delegates/gpu/common/task/util.h" +#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h" namespace tflite { namespace gpu { @@ -111,14 +111,13 @@ std::string GenerateConvolutionConstantCode(const OperationDef& op_def, } op->AddDstTensor("dst_tensor", dst_desc); - std::string c = GetCommonDefines(op_def.precision); - const int out_z = DivideRoundUp(weights_shape.o, 4); const std::string kOutZ = std::to_string(out_z); const int src_depth = DivideRoundUp(weights_shape.i, 4); const std::string postfixes[] = {".x", ".xy", ".xyz", ""}; + std::string c; c += "__kernel void main_function(\n"; c += "$0) {\n"; c += " int X = get_global_id(0);\n"; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.h b/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.h index 5a7cd248999..be75f06befb 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.h @@ -17,7 +17,6 @@ limitations under the License. #define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_CONV_CONSTANTS_H_ #include "tensorflow/lite/delegates/gpu/cl/buffer.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/cl/linear_storage.h" #include "tensorflow/lite/delegates/gpu/cl/tensor.h" #include "tensorflow/lite/delegates/gpu/cl/util.h" @@ -25,6 +24,7 @@ limitations under the License. 
#include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" #include "tensorflow/lite/delegates/gpu/common/types.h" diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.cc b/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.cc index 448a243a830..6d4395b2b18 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.cc @@ -20,11 +20,11 @@ limitations under the License. #include <utility> #include "absl/strings/substitute.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/util.h" +#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h" namespace tflite { namespace gpu { @@ -443,9 +443,9 @@ std::string ConvPowerVR::GenerateConv(const GpuInfo& gpu_info, const std::string weights_global_ptr = weights_space + " " + weights_data_type + "*"; - std::string c = GetCommonDefines(op_def.precision); + std::string c; if (use_simd_broadcast) { - if (gpu_info.cl_version == OpenCLVersion::CL_2_0) { + if (gpu_info.opencl_info.cl_version == OpenClVersion::kCl2_0) { c += "#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n"; } else if (gpu_info.SupportsExtension("cl_intel_subgroups")) { c += "#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n"; @@ -1045,7 +1045,7 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams( if (dst_shape) { int task_size = dst_shape->w * dst_shape->b * dst_shape->h * dst_depth; float task_size_per_cu = - static_cast<float>(task_size) / gpu_info.compute_units_count; + static_cast<float>(task_size) / gpu_info.GetComputeUnitsCount(); int block_size = conv_params.block_size.x * conv_params.block_size.y * conv_params.block_size.w; float threads_per_cu = task_size_per_cu / block_size; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h b/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h index 85289f63339..b1862e48a71 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h @@ -21,9 +21,6 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/cl/buffer.h" #include "tensorflow/lite/delegates/gpu/cl/cl_device.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/conv_common.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" #include "tensorflow/lite/delegates/gpu/cl/linear_storage.h" #include "tensorflow/lite/delegates/gpu/cl/tensor.h" #include "tensorflow/lite/delegates/gpu/cl/texture2d.h" @@ -32,6 +29,9 @@ limitations under the License. 
#include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" +#include "tensorflow/lite/delegates/gpu/common/task/weights_conversion.h" +#include "tensorflow/lite/delegates/gpu/common/task/weights_layout.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" #include "tensorflow/lite/delegates/gpu/common/types.h" #include "tensorflow/lite/delegates/gpu/common/winograd_util.h" @@ -50,9 +50,9 @@ class ConvPowerVR : public GPUOperation { absl::Status BindArguments(ArgumentsBinder* args) override; int3 GetGridSize() const override; - ConvWeightsDescription GetConvWeightsDescription() const { - ConvWeightsDescription desc; - desc.layout = ConvWeightsLayout::kOHWIOGroupI4O4; + WeightsDescription GetWeightsDescription() const { + WeightsDescription desc; + desc.layout = WeightsLayout::kOHWIOGroupI4O4; desc.output_group_size = conv_params_.block_size.w; return desc; } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_weights_converter.cc b/tensorflow/lite/delegates/gpu/cl/kernels/conv_weights_converter.cc index 521cbefd885..5d3f4e26e95 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_weights_converter.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_weights_converter.cc @@ -17,37 +17,35 @@ limitations under the License. #include <string> -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" +#include "tensorflow/lite/delegates/gpu/common/task/util.h" +#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h" namespace tflite { namespace gpu { namespace cl { ConverterToConvWeights::ConverterToConvWeights( - const OperationDef& definition, - const ConvWeightsDescription& conv_weights_desc) - : GPUOperation(definition), conv_weights_desc_(conv_weights_desc) { - code_ = GetConverterToConvWeightsCode(definition_, conv_weights_desc_); + const OperationDef& definition, const WeightsDescription& weights_desc) + : GPUOperation(definition), weights_desc_(weights_desc) { + code_ = GetConverterToConvWeightsCode(definition_, weights_desc_); } ConverterToConvWeights::ConverterToConvWeights( ConverterToConvWeights&& operation) : GPUOperation(std::move(operation)), - conv_weights_desc_(operation.conv_weights_desc_) {} + weights_desc_(std::move(operation.weights_desc_)) {} ConverterToConvWeights& ConverterToConvWeights::operator=( ConverterToConvWeights&& operation) { if (this != &operation) { - conv_weights_desc_ = operation.conv_weights_desc_; + weights_desc_ = std::move(operation.weights_desc_); GPUOperation::operator=(std::move(operation)); } return *this; } std::string ConverterToConvWeights::GetConverterToConvWeightsCode( - const OperationDef& op_def, - const ConvWeightsDescription& conv_weights_desc) { + const OperationDef& op_def, const WeightsDescription& conv_weights_desc) { AddSrcTensor("src_tensor", op_def.src_tensors[0]); AddDstTensor("dst_tensor", op_def.dst_tensors[0]); args_.AddFloat("mask_x"); @@ -55,7 +53,7 @@ std::string ConverterToConvWeights::GetConverterToConvWeightsCode( args_.AddFloat("mask_z"); args_.AddFloat("mask_w"); - std::string c = GetCommonDefines(op_def.precision); + std::string c; c += "__kernel void main_function(\n"; c += "$0) {\n"; c += " int GROUP_SIZE = " + @@ -120,16 +118,15 @@ absl::Status ConverterToConvWeights::BindArguments(ArgumentsBinder* args) { int3 
ConverterToConvWeights::GetGridSize() const { const int grid_x = DivideRoundUp( - AlignByN(src_[0]->Batch(), 4 * conv_weights_desc_.output_group_size), 4); + AlignByN(src_[0]->Batch(), 4 * weights_desc_.output_group_size), 4); const int grid_y = src_[0]->Slices(); const int grid_z = src_[0]->Width() * src_[0]->Height(); return int3(grid_x, grid_y, grid_z); } ConverterToConvWeights CreateConverterToConvWeights( - const OperationDef& definition, - const ConvWeightsDescription& conv_weights_desc) { - return ConverterToConvWeights(definition, conv_weights_desc); + const OperationDef& definition, const WeightsDescription& weights_desc) { + return ConverterToConvWeights(definition, weights_desc); } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_weights_converter.h b/tensorflow/lite/delegates/gpu/cl/kernels/conv_weights_converter.h index 3c7314ea6c9..3d5306886b3 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_weights_converter.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_weights_converter.h @@ -18,9 +18,9 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/cl/cl_command_queue.h" #include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/conv_common.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" +#include "tensorflow/lite/delegates/gpu/common/task/weights_layout.h" #include "tensorflow/lite/delegates/gpu/common/types.h" namespace tflite { @@ -30,7 +30,7 @@ namespace cl { class ConverterToConvWeights : public GPUOperation { public: ConverterToConvWeights(const OperationDef& definition, - const ConvWeightsDescription& conv_weights_desc); + const WeightsDescription& weights_desc); absl::Status BindArguments(ArgumentsBinder* args) override; int3 GetGridSize() const override; @@ -42,10 +42,9 @@ class ConverterToConvWeights : public GPUOperation { private: std::string GetConverterToConvWeightsCode( - const OperationDef& op_def, - const ConvWeightsDescription& conv_weights_desc); + const OperationDef& op_def, const WeightsDescription& weights_desc); - ConvWeightsDescription conv_weights_desc_; + WeightsDescription weights_desc_; }; // We expect src BHWC tensor and we assume that B is O, H = H, W = W, C is I @@ -53,8 +52,7 @@ class ConverterToConvWeights : public GPUOperation { // dst.b * dst.h * dst.w * dst.c = AlignByN(src.b, 4) * src.h * src.w // AlignByN(src.c, 4) ConverterToConvWeights CreateConverterToConvWeights( - const OperationDef& definition, - const ConvWeightsDescription& conv_weights_desc); + const OperationDef& definition, const WeightsDescription& weights_desc); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc b/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc index ab6f986ea97..fdab1890b69 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc @@ -22,13 +22,13 @@ limitations under the License. 
#include "tensorflow/lite/delegates/gpu/cl/cl_arguments.h" #include "tensorflow/lite/delegates/gpu/cl/cl_command_queue.h" #include "tensorflow/lite/delegates/gpu/cl/cl_errors.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" #include "tensorflow/lite/delegates/gpu/cl/tensor.h" #include "tensorflow/lite/delegates/gpu/cl/tensor_type_util.h" #include "tensorflow/lite/delegates/gpu/common/precision.h" #include "tensorflow/lite/delegates/gpu/common/task/arguments.h" #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h" +#include "tensorflow/lite/delegates/gpu/common/task/util.h" +#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h" #include "tensorflow/lite/delegates/gpu/common/util.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.cc b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.cc index 99df2689b95..f6c05270e8c 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.cc @@ -20,10 +20,9 @@ limitations under the License. #include <vector> #include "absl/strings/substitute.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h" namespace tflite { namespace gpu { @@ -132,7 +131,7 @@ std::string ConvolutionTransposed::GenerateConvolutionTransposedCode( const auto& src_def = op_def.src_tensors[0]; - std::string c = GetCommonDefines(op_def.precision); + std::string c; for (int s = 0; s < block_size.w; ++s) { const std::string f0 = diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.h b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.h index c25c916cc56..0373c4cc964 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.h @@ -20,9 +20,6 @@ limitations under the License. #include <vector> #include "tensorflow/lite/delegates/gpu/cl/buffer.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/conv_common.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" #include "tensorflow/lite/delegates/gpu/cl/linear_storage.h" #include "tensorflow/lite/delegates/gpu/cl/tensor.h" #include "tensorflow/lite/delegates/gpu/cl/texture2d.h" @@ -31,6 +28,9 @@ limitations under the License. 
#include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" +#include "tensorflow/lite/delegates/gpu/common/task/weights_conversion.h" +#include "tensorflow/lite/delegates/gpu/common/task/weights_layout.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" #include "tensorflow/lite/delegates/gpu/common/types.h" @@ -54,9 +54,9 @@ class ConvolutionTransposed : public GPUOperation { ConvolutionTransposed(const ConvolutionTransposed&) = delete; ConvolutionTransposed& operator=(const ConvolutionTransposed&) = delete; - ConvWeightsDescription GetConvWeightsDescription() const { - ConvWeightsDescription desc; - desc.layout = ConvWeightsLayout::kOHWIOGroupI4O4; + WeightsDescription GetWeightsDescription() const { + WeightsDescription desc; + desc.layout = WeightsLayout::kOHWIOGroupI4O4; desc.output_group_size = block_size_.w; return desc; } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.cc b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.cc index ca1be8ac5e0..1030891ab38 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.cc @@ -19,8 +19,7 @@ limitations under the License. #include <utility> #include <vector> -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" +#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h" namespace tflite { namespace gpu { @@ -80,6 +79,19 @@ std::string ConvolutionTransposed3x3::GenerateConvolutionTransposedCode( } AddDstTensor("dst_tensor", dst_desc); + if (op_def.src_tensors.size() == 2) { + // dynamic weights + BufferDescriptor desc; + desc.element_type = op_def.src_tensors[1].data_type; + desc.element_size = 4; + desc.memory_type = + weights_upload_type == + ConvolutionTransposed3x3::WeightsUploadType::CONSTANT_MEM + ? 
MemoryType::CONSTANT + : MemoryType::GLOBAL; + AddSrcBuffer("weights", desc); + } + args_.AddInt("filter_offset"); args_.AddInt("padding_x"); args_.AddInt("padding_y"); @@ -90,7 +102,7 @@ std::string ConvolutionTransposed3x3::GenerateConvolutionTransposedCode( weights_upload_type == ConvolutionTransposed3x3::WeightsUploadType::LOCAL_MEM_ASYNC; - std::string c = GetCommonDefines(op_def.precision); + std::string c; switch (op_def.precision) { case CalculationsPrecision::F32: case CalculationsPrecision::F16: @@ -350,6 +362,22 @@ int3 ConvolutionTransposed3x3::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } +std::vector<int> ConvolutionTransposed3x3::GetSpatialWeightsRemap() const { + const int padding_x_rem = abs(padding_.x) % 2; + const int padding_y_rem = abs(padding_.y) % 2; + + std::vector<int> remap; + if (padding_x_rem == 1 && padding_y_rem == 1) { + return std::vector<int>{4, 5, 3, 7, 1, 8, 6, 2, 0}; + } else if (padding_x_rem == 0 && padding_y_rem == 1) { + return std::vector<int>{5, 3, 4, 8, 6, 2, 0, 7, 1}; + } else if (padding_x_rem == 1 && padding_y_rem == 0) { + return std::vector<int>{7, 1, 8, 6, 2, 0, 4, 5, 3}; + } else { // padding_x_rem == 0 && padding_y_rem == 0 + return std::vector<int>{8, 6, 2, 0, 7, 1, 5, 3, 4}; + } +} + bool IsConvolutionTransposed3x3Supported( const OperationDef& definition, const ConvolutionTransposedAttributes& attr) { @@ -373,6 +401,21 @@ ConvolutionTransposed3x3 CreateConvolutionTransposed3x3( return result; } +ConvolutionTransposed3x3 CreateConvolutionTransposed3x3DynamicWeights( + const GpuInfo& gpu_info, const OperationDef& definition, + const ConvolutionTransposedAttributes& attr) { + const int2 padding = int2(attr.padding.prepended.w, attr.padding.prepended.h); + ConvolutionTransposed3x3 result(definition, gpu_info, padding); + + TensorLinearDescriptor desc; + desc.storage_type = LinearStorageType::TEXTURE_2D; + desc.element_type = definition.GetDataType(); + desc.UploadLinearData(attr.bias); + result.args_.AddObject( + "biases", absl::make_unique<TensorLinearDescriptor>(std::move(desc))); + return result; +} + } // namespace cl } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.h b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.h index 89abf70498a..49550a6cd55 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.h @@ -19,7 +19,6 @@ limitations under the License. #include <vector> #include "tensorflow/lite/delegates/gpu/cl/buffer.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/cl/linear_storage.h" #include "tensorflow/lite/delegates/gpu/cl/tensor.h" #include "tensorflow/lite/delegates/gpu/cl/util.h" @@ -27,6 +26,9 @@ limitations under the License. 
#include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" +#include "tensorflow/lite/delegates/gpu/common/task/weights_conversion.h" +#include "tensorflow/lite/delegates/gpu/common/task/weights_layout.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" #include "tensorflow/lite/delegates/gpu/common/types.h" @@ -50,6 +52,13 @@ class ConvolutionTransposed3x3 : public GPUOperation { ConvolutionTransposed3x3(const ConvolutionTransposed3x3&) = delete; ConvolutionTransposed3x3& operator=(const ConvolutionTransposed3x3&) = delete; + WeightsDescription GetWeightsDescription() const { + WeightsDescription desc; + desc.layout = WeightsLayout::kOICustomSSpatialI4O4; + desc.spatial_remap = GetSpatialWeightsRemap(); + return desc; + } + enum class WeightsUploadType { LOCAL_MEM_ASYNC, LOCAL_MEM_BY_THREADS, @@ -63,12 +72,14 @@ class ConvolutionTransposed3x3 : public GPUOperation { friend ConvolutionTransposed3x3 CreateConvolutionTransposed3x3( const GpuInfo& gpu_info, const OperationDef& definition, const ConvolutionTransposedAttributes& attr); + friend ConvolutionTransposed3x3 CreateConvolutionTransposed3x3DynamicWeights( + const GpuInfo& gpu_info, const OperationDef& definition, + const ConvolutionTransposedAttributes& attr); + template <DataType T> void UploadWeights(const tflite::gpu::Tensor<OHWI, T>& weights); - template <DataType S, typename T> - void RearrangeWeightsData(const tflite::gpu::Tensor<OHWI, S>& weights, - absl::Span<T> dst); + std::vector<int> GetSpatialWeightsRemap() const; std::string GenerateConvolutionTransposedCode( const OperationDef& op_def, @@ -104,71 +115,18 @@ void ConvolutionTransposed3x3::UploadWeights( if (f32_weights) { float4* ptr = reinterpret_cast<float4*>(desc.data.data()); - RearrangeWeightsData(weights, absl::MakeSpan(ptr, flt4_count)); + RearrangeWeightsToOICustomSpatialI4O4(weights, GetSpatialWeightsRemap(), + absl::MakeSpan(ptr, flt4_count)); } else { half4* ptr = reinterpret_cast<half4*>(desc.data.data()); - RearrangeWeightsData(weights, absl::MakeSpan(ptr, flt4_count)); + RearrangeWeightsToOICustomSpatialI4O4(weights, GetSpatialWeightsRemap(), + absl::MakeSpan(ptr, flt4_count)); } args_.AddObject("weights", absl::make_unique<BufferDescriptor>(std::move(desc))); } -template <DataType S, typename T> -void ConvolutionTransposed3x3::RearrangeWeightsData( - const tflite::gpu::Tensor<OHWI, S>& weights, absl::Span<T> dst) { - const int src_depth = DivideRoundUp(weights.shape.i, 4); - const int dst_depth = DivideRoundUp(weights.shape.o, 4); - const int kernel_x = 3; - const int kernel_y = 3; - - const int padding_x_rem = abs(padding_.x) % 2; - const int padding_y_rem = abs(padding_.y) % 2; - - // we are reorganizing weights to read them sequentially in kernel - std::vector<int> remap; - if (padding_x_rem == 1 && padding_y_rem == 1) { - remap = {4, 5, 3, 7, 1, 8, 6, 2, 0}; - } else if (padding_x_rem == 0 && padding_y_rem == 1) { - remap = {5, 3, 4, 8, 6, 2, 0, 7, 1}; - } else if (padding_x_rem == 1 && padding_y_rem == 0) { - remap = {7, 1, 8, 6, 2, 0, 4, 5, 3}; - } else { // padding_x_rem == 0 && padding_y_rem == 0 - remap = {8, 6, 2, 0, 7, 1, 5, 3, 4}; - } - - int counter = 0; - for (int d = 0; d < dst_depth; ++d) { - for (int s = 0; s < src_depth; ++s) { - for (int y = 0; y < kernel_y; ++y) { - for (int x = 0; x < kernel_x; ++x) { - const int kernel_index = remap[y * kernel_x + 
x]; - const int kernel_index_x = kernel_index % kernel_x; - const int kernel_index_y = kernel_index / kernel_x; - T filters[4]; - for (int j = 0; j < 4; ++j) { - for (int i = 0; i < 4; ++i) { - const int s_ch = s * 4 + i; - const int d_ch = d * 4 + j; - if (s_ch < weights.shape.i && d_ch < weights.shape.o) { - const int f_index = weights.shape.LinearIndex( - {d_ch, kernel_index_y, kernel_index_x, s_ch}); - filters[i][j] = weights.data[f_index]; - } else { - filters[i][j] = 0.0f; - } - } - } - dst[counter++] = filters[0]; - dst[counter++] = filters[1]; - dst[counter++] = filters[2]; - dst[counter++] = filters[3]; - } - } - } - } -} - bool IsConvolutionTransposed3x3Supported( const OperationDef& definition, const ConvolutionTransposedAttributes& attr); @@ -177,6 +135,10 @@ ConvolutionTransposed3x3 CreateConvolutionTransposed3x3( const GpuInfo& gpu_info, const OperationDef& definition, const ConvolutionTransposedAttributes& attr); +ConvolutionTransposed3x3 CreateConvolutionTransposed3x3DynamicWeights( + const GpuInfo& gpu_info, const OperationDef& definition, + const ConvolutionTransposedAttributes& attr); + } // namespace cl } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3_thin.cc b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3_thin.cc index 4cabbbee376..7a11e617627 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3_thin.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3_thin.cc @@ -19,8 +19,7 @@ limitations under the License. #include <utility> #include <vector> -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" +#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h" namespace tflite { namespace gpu { @@ -53,9 +52,18 @@ std::string ConvolutionTransposed3x3Thin::GenerateConvolutionTransposedCode( AddSrcTensor("src_tensor", src_desc); AddDstTensor("dst_tensor", op_def.dst_tensors[0]); + if (op_def.src_tensors.size() == 2) { + // dynamic weights + BufferDescriptor desc; + desc.element_type = op_def.src_tensors[1].data_type; + desc.element_size = 4; + desc.memory_type = MemoryType::CONSTANT; + AddSrcBuffer("weights", desc); + } + const auto src_tensor_type = op_def.src_tensors[0].storage_type; - std::string c = GetCommonDefines(op_def.precision); + std::string c; switch (op_def.precision) { case CalculationsPrecision::F32: @@ -160,8 +168,7 @@ std::string ConvolutionTransposed3x3Thin::GenerateConvolutionTransposedCode( for (int d = 0; d < dst_depth; ++d) { const std::string layer = std::to_string(d); c += " {\n"; - c += " FLT4 bias_val = args.weights.Read(" + - std::to_string(36 * filters_index + d) + ");\n"; + c += " FLT4 bias_val = args.biases.Read(" + layer + ");\n"; for (int y = 0; y < 2; ++y) { for (int x = 0; x < 2; ++x) { const std::string x_coord = "X + " + std::to_string(x); @@ -188,6 +195,10 @@ int3 ConvolutionTransposed3x3Thin::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } +std::vector<int> ConvolutionTransposed3x3Thin::GetSpatialWeightsRemap() const { + return std::vector<int>{4, 5, 3, 7, 1, 8, 6, 2, 0}; +} + bool IsConvolutionTransposed3x3ThinSupported( const ConvolutionTransposedAttributes& attr) { return attr.weights.shape.o <= 8 && attr.weights.shape.w == 3 && @@ -201,7 +212,28 @@ ConvolutionTransposed3x3Thin CreateConvolutionTransposed3x3Thin( const GpuInfo& gpu_info, const OperationDef& definition, const 
ConvolutionTransposedAttributes& attr) { ConvolutionTransposed3x3Thin result(definition, attr); - result.UploadData(attr.weights, attr.bias); + result.UploadWeights(attr.weights); + + TensorLinearDescriptor desc; + desc.storage_type = LinearStorageType::TEXTURE_2D; + desc.element_type = definition.GetDataType(); + desc.UploadLinearData(attr.bias); + result.args_.AddObject( + "biases", absl::make_unique<TensorLinearDescriptor>(std::move(desc))); + return result; +} + +ConvolutionTransposed3x3Thin CreateConvolutionTransposed3x3ThinDynamicWeights( + const GpuInfo& gpu_info, const OperationDef& definition, + const ConvolutionTransposedAttributes& attr) { + ConvolutionTransposed3x3Thin result(definition, attr); + + TensorLinearDescriptor desc; + desc.storage_type = LinearStorageType::TEXTURE_2D; + desc.element_type = definition.GetDataType(); + desc.UploadLinearData(attr.bias); + result.args_.AddObject( + "biases", absl::make_unique<TensorLinearDescriptor>(std::move(desc))); return result; } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3_thin.h b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3_thin.h index 8ff50f95f31..8080d757e9c 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3_thin.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3_thin.h @@ -19,7 +19,6 @@ limitations under the License. #include <vector> #include "tensorflow/lite/delegates/gpu/cl/buffer.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/cl/linear_storage.h" #include "tensorflow/lite/delegates/gpu/cl/tensor.h" #include "tensorflow/lite/delegates/gpu/cl/util.h" @@ -27,6 +26,9 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" +#include "tensorflow/lite/delegates/gpu/common/task/weights_conversion.h" +#include "tensorflow/lite/delegates/gpu/common/task/weights_layout.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" #include "tensorflow/lite/delegates/gpu/common/types.h" @@ -47,29 +49,38 @@ class ConvolutionTransposed3x3Thin : public GPUOperation { ConvolutionTransposed3x3Thin& operator=(const ConvolutionTransposed3x3Thin&) = delete; + WeightsDescription GetWeightsDescription() const { + WeightsDescription desc; + desc.layout = WeightsLayout::kOICustomSpatialI4O4; + desc.spatial_remap = GetSpatialWeightsRemap(); + return desc; + } + private: - friend ConvolutionTransposed3x3Thin CreateConvolutionTransposed3x3Thin( - const GpuInfo& gpu_info, const OperationDef& definition, - const ConvolutionTransposedAttributes& attr); explicit ConvolutionTransposed3x3Thin( const OperationDef& definition, const ConvolutionTransposedAttributes& attr); - template <DataType T> - void UploadData(const tflite::gpu::Tensor<OHWI, T>& weights, - const tflite::gpu::Tensor<Linear, T>& biases); - template <DataType S, typename T> - void RearrangeWeightsData(const tflite::gpu::Tensor<OHWI, S>& weights, - absl::Span<T> dst); + friend ConvolutionTransposed3x3Thin CreateConvolutionTransposed3x3Thin( + const GpuInfo& gpu_info, const OperationDef& definition, + const ConvolutionTransposedAttributes& attr); + friend ConvolutionTransposed3x3Thin + CreateConvolutionTransposed3x3ThinDynamicWeights( + const GpuInfo& gpu_info, const OperationDef& definition, + const 
ConvolutionTransposedAttributes& attr); + + template <DataType T> + void UploadWeights(const tflite::gpu::Tensor<OHWI, T>& weights); + + std::vector<int> GetSpatialWeightsRemap() const; std::string GenerateConvolutionTransposedCode(const OperationDef& op_def, int src_depth, int dst_depth); }; template <DataType T> -void ConvolutionTransposed3x3Thin::UploadData( - const tflite::gpu::Tensor<OHWI, T>& weights, - const tflite::gpu::Tensor<Linear, T>& biases) { +void ConvolutionTransposed3x3Thin::UploadWeights( + const tflite::gpu::Tensor<OHWI, T>& weights) { const int src_depth = DivideRoundUp(weights.shape.i, 4); const int dst_depth = DivideRoundUp(weights.shape.o, 4); const int kernel_x = 3; // This operation support only 3x3 kernel @@ -83,79 +94,23 @@ void ConvolutionTransposed3x3Thin::UploadData( desc.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16; desc.element_size = 4; desc.memory_type = MemoryType::CONSTANT; - desc.size = flt4_size * (flt4_count + dst_depth); + desc.size = flt4_size * flt4_count; desc.data.resize(desc.size); if (f32_weights) { float4* gpu_data = reinterpret_cast<float4*>(desc.data.data()); - RearrangeWeightsData(weights, absl::MakeSpan(gpu_data, flt4_count)); - for (int i = 0; i < dst_depth; ++i) { - float4 bias_value(0.0f); - for (int c = 0; c < 4; ++c) { - int ch = i * 4 + c; - bias_value[c] = ch < weights.shape.o ? biases.data[ch] : 0.0f; - } - gpu_data[flt4_count + i] = bias_value; - } + RearrangeWeightsToOICustomSpatialI4O4(weights, GetSpatialWeightsRemap(), + absl::MakeSpan(gpu_data, flt4_count)); } else { half4* gpu_data = reinterpret_cast<half4*>(desc.data.data()); - RearrangeWeightsData(weights, absl::MakeSpan(gpu_data, flt4_count)); - for (int i = 0; i < dst_depth; ++i) { - half4 bias_value(0.0f); - for (int c = 0; c < 4; ++c) { - int ch = i * 4 + c; - bias_value[c] = ch < weights.shape.o ? 
biases.data[ch] : 0.0f; - } - gpu_data[flt4_count + i] = bias_value; - } + RearrangeWeightsToOICustomSpatialI4O4(weights, GetSpatialWeightsRemap(), + absl::MakeSpan(gpu_data, flt4_count)); } args_.AddObject("weights", absl::make_unique<BufferDescriptor>(std::move(desc))); } -template <DataType S, typename T> -void ConvolutionTransposed3x3Thin::RearrangeWeightsData( - const tflite::gpu::Tensor<OHWI, S>& weights, absl::Span<T> dst) { - const int src_depth = DivideRoundUp(weights.shape.i, 4); - const int dst_depth = DivideRoundUp(weights.shape.o, 4); - const int kernel_x = 3; - const int kernel_y = 3; - - const int remap[9] = {4, 5, 3, 7, 1, 8, 6, 2, 0}; - - int counter = 0; - for (int s = 0; s < src_depth; ++s) { - for (int d = 0; d < dst_depth; ++d) { - for (int y = 0; y < kernel_y; ++y) { - for (int x = 0; x < kernel_x; ++x) { - const int kernel_index = remap[y * kernel_x + x]; - const int kernel_index_x = kernel_index % kernel_x; - const int kernel_index_y = kernel_index / kernel_x; - T filters[4]; - for (int j = 0; j < 4; ++j) { - for (int i = 0; i < 4; ++i) { - const int s_ch = s * 4 + i; - const int d_ch = d * 4 + j; - if (s_ch < weights.shape.i && d_ch < weights.shape.o) { - const int f_index = weights.shape.LinearIndex( - {d_ch, kernel_index_y, kernel_index_x, s_ch}); - filters[i][j] = weights.data[f_index]; - } else { - filters[i][j] = 0.0f; - } - } - } - dst[counter++] = filters[0]; - dst[counter++] = filters[1]; - dst[counter++] = filters[2]; - dst[counter++] = filters[3]; - } - } - } - } -} - bool IsConvolutionTransposed3x3ThinSupported( const ConvolutionTransposedAttributes& attr); @@ -163,6 +118,10 @@ ConvolutionTransposed3x3Thin CreateConvolutionTransposed3x3Thin( const GpuInfo& gpu_info, const OperationDef& definition, const ConvolutionTransposedAttributes& attr); +ConvolutionTransposed3x3Thin CreateConvolutionTransposed3x3ThinDynamicWeights( + const GpuInfo& gpu_info, const OperationDef& definition, + const ConvolutionTransposedAttributes& attr); + } // namespace cl } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.cc b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.cc index 5db4d03b0f4..61c5f1c32c4 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.cc @@ -19,30 +19,40 @@ limitations under the License. 
#include <utility> #include <vector> -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" +#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h" namespace tflite { namespace gpu { namespace cl { +namespace { +ConvolutionTransposed4x4::WeightsUploadType GetBestWeightsUploadType( + const GpuInfo& gpu_info) { + ConvolutionTransposed4x4::WeightsUploadType weights_upload_type = + ConvolutionTransposed4x4::WeightsUploadType::GLOBAL_MEM; + if (gpu_info.IsPowerVR()) { + weights_upload_type = + ConvolutionTransposed4x4::WeightsUploadType::LOCAL_MEM_ASYNC; + } else if (gpu_info.IsNvidia() || gpu_info.IsIntel()) { + weights_upload_type = + ConvolutionTransposed4x4::WeightsUploadType::LOCAL_MEM_BY_THREADS; + } else if (gpu_info.IsAMD()) { + weights_upload_type = + ConvolutionTransposed4x4::WeightsUploadType::CONSTANT_MEM; + } else { + weights_upload_type = + ConvolutionTransposed4x4::WeightsUploadType::GLOBAL_MEM; + } + return weights_upload_type; +} +} // namespace + ConvolutionTransposed4x4::ConvolutionTransposed4x4( - const OperationDef& definition, const GpuInfo& gpu_info, - const ConvolutionTransposedAttributes& attr) + const OperationDef& definition, const GpuInfo& gpu_info) : GPUOperation(definition) { work_group_size_ = int3(8, 4, 1); - WeightsUploadType weights_upload_type = WeightsUploadType::GLOBAL_MEM; - if (gpu_info.IsPowerVR()) { - weights_upload_type = WeightsUploadType::LOCAL_MEM_ASYNC; - } else if (gpu_info.IsNvidia() || gpu_info.IsIntel()) { - weights_upload_type = WeightsUploadType::LOCAL_MEM_BY_THREADS; - } else if (gpu_info.IsAMD()) { - weights_upload_type = WeightsUploadType::CONSTANT_MEM; - } else { - weights_upload_type = WeightsUploadType::GLOBAL_MEM; - } - code_ = GenerateConvolutionTransposedCode(definition_, weights_upload_type); - UploadWeights(attr.weights, weights_upload_type); + code_ = GenerateConvolutionTransposedCode(definition_, + GetBestWeightsUploadType(gpu_info)); if (definition_.precision == CalculationsPrecision::F16 && gpu_info.IsPowerVR()) { compiler_options_.push_back(CompilerOptions::kClPowervrFp16); @@ -76,6 +86,19 @@ std::string ConvolutionTransposed4x4::GenerateConvolutionTransposedCode( } AddDstTensor("dst_tensor", dst_desc); + if (op_def.src_tensors.size() == 2) { + // dynamic weights + BufferDescriptor desc; + desc.element_type = op_def.src_tensors[1].data_type; + desc.element_size = 4; + desc.memory_type = + weights_upload_type == + ConvolutionTransposed4x4::WeightsUploadType::CONSTANT_MEM + ? 
MemoryType::CONSTANT + : MemoryType::GLOBAL; + AddSrcBuffer("weights", desc); + } + args_.AddInt("filter_offset"); const bool need_local_mem = @@ -84,7 +107,7 @@ std::string ConvolutionTransposed4x4::GenerateConvolutionTransposedCode( weights_upload_type == ConvolutionTransposed4x4::WeightsUploadType::LOCAL_MEM_ASYNC; - std::string c = GetCommonDefines(op_def.precision); + std::string c; switch (op_def.precision) { case CalculationsPrecision::F32: case CalculationsPrecision::F16: @@ -323,6 +346,10 @@ int3 ConvolutionTransposed4x4::GetGridSize() const { return int3(grid_x, grid_y, grid_z); } +std::vector<int> ConvolutionTransposed4x4::GetSpatialWeightsRemap() const { + return std::vector<int>{10, 11, 14, 15, 8, 9, 12, 13, 2, 3, 6, 7, 0, 1, 4, 5}; +} + bool IsConvolutionTransposed4x4Supported( const OperationDef& definition, const ConvolutionTransposedAttributes& attr) { @@ -334,7 +361,22 @@ bool IsConvolutionTransposed4x4Supported( ConvolutionTransposed4x4 CreateConvolutionTransposed4x4( const GpuInfo& gpu_info, const OperationDef& definition, const ConvolutionTransposedAttributes& attr) { - ConvolutionTransposed4x4 result(definition, gpu_info, attr); + ConvolutionTransposed4x4 result(definition, gpu_info); + result.UploadWeights(attr.weights, GetBestWeightsUploadType(gpu_info)); + + TensorLinearDescriptor desc; + desc.storage_type = LinearStorageType::TEXTURE_2D; + desc.element_type = definition.GetDataType(); + desc.UploadLinearData(attr.bias); + result.args_.AddObject( + "biases", absl::make_unique<TensorLinearDescriptor>(std::move(desc))); + return result; +} + +ConvolutionTransposed4x4 CreateConvolutionTransposed4x4DynamicWeights( + const GpuInfo& gpu_info, const OperationDef& definition, + const ConvolutionTransposedAttributes& attr) { + ConvolutionTransposed4x4 result(definition, gpu_info); TensorLinearDescriptor desc; desc.storage_type = LinearStorageType::TEXTURE_2D; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.h b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.h index febbc575c33..57304adb01f 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.h @@ -19,7 +19,6 @@ limitations under the License. #include <vector> #include "tensorflow/lite/delegates/gpu/cl/buffer.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/cl/linear_storage.h" #include "tensorflow/lite/delegates/gpu/cl/tensor.h" #include "tensorflow/lite/delegates/gpu/cl/util.h" @@ -27,6 +26,9 @@ limitations under the License. 
#include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" +#include "tensorflow/lite/delegates/gpu/common/task/weights_conversion.h" +#include "tensorflow/lite/delegates/gpu/common/task/weights_layout.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" #include "tensorflow/lite/delegates/gpu/common/types.h" @@ -52,6 +54,13 @@ class ConvolutionTransposed4x4 : public GPUOperation { ConvolutionTransposed4x4(const ConvolutionTransposed4x4&) = delete; ConvolutionTransposed4x4& operator=(const ConvolutionTransposed4x4&) = delete; + WeightsDescription GetWeightsDescription() const { + WeightsDescription desc; + desc.layout = WeightsLayout::kOICustomSSpatialI4O4; + desc.spatial_remap = GetSpatialWeightsRemap(); + return desc; + } + enum class WeightsUploadType { LOCAL_MEM_ASYNC, LOCAL_MEM_BY_THREADS, @@ -61,18 +70,20 @@ class ConvolutionTransposed4x4 : public GPUOperation { private: ConvolutionTransposed4x4(const OperationDef& definition, - const GpuInfo& gpu_info, - const ConvolutionTransposedAttributes& attr); + const GpuInfo& gpu_info); + friend ConvolutionTransposed4x4 CreateConvolutionTransposed4x4( const GpuInfo& gpu_info, const OperationDef& definition, const ConvolutionTransposedAttributes& attr); + friend ConvolutionTransposed4x4 CreateConvolutionTransposed4x4DynamicWeights( + const GpuInfo& gpu_info, const OperationDef& definition, + const ConvolutionTransposedAttributes& attr); + template <DataType T> void UploadWeights(const tflite::gpu::Tensor<OHWI, T>& weights, WeightsUploadType weights_upload_type); - template <DataType S, typename T> - void RearrangeWeightsData(const tflite::gpu::Tensor<OHWI, S>& weights, - absl::Span<T> dst); + std::vector<int> GetSpatialWeightsRemap() const; std::string GenerateConvolutionTransposedCode( const OperationDef& op_def, WeightsUploadType weights_upload_type); @@ -104,58 +115,18 @@ void ConvolutionTransposed4x4::UploadWeights( if (f32_weights) { float4* ptr = reinterpret_cast<float4*>(desc.data.data()); - RearrangeWeightsData(weights, absl::MakeSpan(ptr, flt4_count)); + RearrangeWeightsToOICustomSpatialI4O4(weights, GetSpatialWeightsRemap(), + absl::MakeSpan(ptr, flt4_count)); } else { half4* ptr = reinterpret_cast<half4*>(desc.data.data()); - RearrangeWeightsData(weights, absl::MakeSpan(ptr, flt4_count)); + RearrangeWeightsToOICustomSpatialI4O4(weights, GetSpatialWeightsRemap(), + absl::MakeSpan(ptr, flt4_count)); } args_.AddObject("weights", absl::make_unique<BufferDescriptor>(std::move(desc))); } -template <DataType S, typename T> -void ConvolutionTransposed4x4::RearrangeWeightsData( - const tflite::gpu::Tensor<OHWI, S>& weights, absl::Span<T> dst) { - const int src_depth = DivideRoundUp(weights.shape.i, 4); - const int dst_depth = DivideRoundUp(weights.shape.o, 4); - const int kernel_x = 4; - const int kernel_y = 4; - - const int remap[16] = {10, 11, 14, 15, 8, 9, 12, 13, 2, 3, 6, 7, 0, 1, 4, 5}; - - int counter = 0; - for (int d = 0; d < dst_depth; ++d) { - for (int s = 0; s < src_depth; ++s) { - for (int y = 0; y < kernel_y; ++y) { - for (int x = 0; x < kernel_x; ++x) { - const int kernel_index = remap[y * kernel_x + x]; - const int kernel_index_x = kernel_index % kernel_x; - const int kernel_index_y = kernel_index / kernel_x; - T filters[4]; - for (int j = 0; j < 4; ++j) { - for (int i = 0; i < 4; ++i) { - const int s_ch = s * 4 + i; - const int d_ch = d 
* 4 + j; - if (s_ch < weights.shape.i && d_ch < weights.shape.o) { - const int f_index = weights.shape.LinearIndex( - {d_ch, kernel_index_y, kernel_index_x, s_ch}); - filters[i][j] = weights.data[f_index]; - } else { - filters[i][j] = 0.0f; - } - } - } - dst[counter++] = filters[0]; - dst[counter++] = filters[1]; - dst[counter++] = filters[2]; - dst[counter++] = filters[3]; - } - } - } - } -} - bool IsConvolutionTransposed4x4Supported( const OperationDef& definition, const ConvolutionTransposedAttributes& attr); @@ -164,6 +135,10 @@ ConvolutionTransposed4x4 CreateConvolutionTransposed4x4( const GpuInfo& gpu_info, const OperationDef& definition, const ConvolutionTransposedAttributes& attr); +ConvolutionTransposed4x4 CreateConvolutionTransposed4x4DynamicWeights( + const GpuInfo& gpu_info, const OperationDef& definition, + const ConvolutionTransposedAttributes& attr); + } // namespace cl } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_thin.cc b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_thin.cc index c21b0cc0fc1..9e75b0bca23 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_thin.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_thin.cc @@ -19,8 +19,7 @@ limitations under the License. #include <utility> #include <vector> -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" +#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h" namespace tflite { namespace gpu { @@ -76,7 +75,7 @@ std::string ConvolutionTransposedThin::GenerateConvolutionTransposedCode( break; } - std::string c = GetCommonDefines(op_def.precision); + std::string c; c += "__kernel void main_function(\n"; c += "$0) {\n"; if (op_def.IsBatchSupported()) { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_thin.h b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_thin.h index 8b57ac03e5f..b18ca6814b8 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_thin.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_thin.h @@ -19,7 +19,6 @@ limitations under the License. #include <vector> #include "tensorflow/lite/delegates/gpu/cl/buffer.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/cl/tensor.h" #include "tensorflow/lite/delegates/gpu/cl/texture2d.h" #include "tensorflow/lite/delegates/gpu/cl/util.h" @@ -27,6 +26,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" #include "tensorflow/lite/delegates/gpu/common/types.h" diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.cc b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.cc index 538183231b0..b1787b5a7ed 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.cc @@ -20,9 +20,9 @@ limitations under the License. 
#include <vector> #include "tensorflow/lite/delegates/gpu/cl/cl_device.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" #include "tensorflow/lite/delegates/gpu/cl/linear_storage.h" +#include "tensorflow/lite/delegates/gpu/common/task/util.h" +#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h" namespace tflite { namespace gpu { @@ -86,7 +86,7 @@ std::string GenerateDepthwiseConvolutionCode( } op->AddDstTensor("dst_tensor", dst_desc); - std::string c = GetCommonDefines(op_def.precision); + std::string c; c += "__kernel void main_function(\n"; c += "$0) {\n"; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.h b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.h index 5c708622ff9..1ca2b3ff8b8 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.h @@ -19,7 +19,6 @@ limitations under the License. #include <vector> #include "tensorflow/lite/delegates/gpu/cl/buffer.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/cl/linear_storage.h" #include "tensorflow/lite/delegates/gpu/cl/tensor.h" #include "tensorflow/lite/delegates/gpu/cl/texture2d.h" @@ -28,6 +27,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" #include "tensorflow/lite/delegates/gpu/common/types.h" diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3.cc b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3.cc index 6a5bc2c4700..ceaf59fff4d 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3.cc @@ -18,9 +18,8 @@ limitations under the License. #include <string> #include <utility> -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h" namespace tflite { namespace gpu { @@ -66,7 +65,7 @@ std::string DepthwiseConv3x3::GenerateDepthwiseConvCode( const bool manual_clamp = src_tensor_type == TensorStorageType::BUFFER || src_tensor_type == TensorStorageType::IMAGE_BUFFER; - std::string c = GetCommonDefines(op_def.precision); + std::string c; if (local_mem_uploads) { c += "__attribute__((reqd_work_group_size(8, 4, 1)))\n"; } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3.h b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3.h index b4f706a779c..4462ea87411 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3.h @@ -20,7 +20,6 @@ limitations under the License. #include <vector> #include "tensorflow/lite/delegates/gpu/cl/buffer.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/cl/tensor.h" #include "tensorflow/lite/delegates/gpu/cl/texture2d.h" #include "tensorflow/lite/delegates/gpu/cl/util.h" @@ -28,6 +27,7 @@ limitations under the License. 
#include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" #include "tensorflow/lite/delegates/gpu/common/types.h" diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.cc b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.cc index c3b41c6e786..5167057408d 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.cc @@ -19,7 +19,6 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/substitute.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" #include "tensorflow/lite/delegates/gpu/cl/storage_type_util.h" namespace tflite { @@ -83,12 +82,7 @@ std::string GetOneInputCode(const OperationType& op_type, result = "$0 *= $0;\n"; break; case OperationType::TANH: - if (precision != CalculationsPrecision::F32) { - result = "float4 t = native_exp(convert_float4($0 * 2.0h));\n"; - result += "$0 = convert_half4(native_divide(t - 1.0f, t + 1.0f));\n"; - } else { - result = "$0 = tanh($0);\n"; - } + result = "$0 = tanh($0);\n"; break; default: return "Unknown operation type;\n"; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.h b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.h index 572b731d908..22a1a8d4b92 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.h @@ -18,9 +18,9 @@ limitations under the License. #include <string> -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise_test.cc index 56d8b7d11c0..00f9a5260f9 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise_test.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise_test.cc @@ -388,7 +388,7 @@ TEST_F(OpenCLOperationTest, Square) { TEST_F(OpenCLOperationTest, Tanh) { TensorFloat32 src_tensor; src_tensor.shape = BHWC(1, 2, 1, 2); - src_tensor.data = {1.0f, 2.0f, 3.0f, 4.0f}; + src_tensor.data = {-50.0f, -0.1f, 0.1f, 50.0f}; for (auto storage : env_.GetSupportedStorages()) { for (auto precision : env_.GetSupportedPrecisions()) { @@ -407,8 +407,8 @@ TEST_F(OpenCLOperationTest, Tanh) { BHWC(1, 2, 1, 2), &dst_tensor)); EXPECT_THAT( dst_tensor.data, - Pointwise(FloatNear(eps), {std::tanh(1.0f), std::tanh(2.0f), - std::tanh(3.0f), std::tanh(4.0f)})); + Pointwise(FloatNear(eps), {std::tanh(-50.0f), std::tanh(-0.1f), + std::tanh(0.1f), std::tanh(50.0f)})); } } } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.cc b/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.cc index b5caef81b43..50b92381c53 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.cc @@ -20,12 +20,10 @@ limitations under the License. 
#include <vector> #include "absl/memory/memory.h" -#include "tensorflow/lite/delegates/gpu/cl/device_info.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" #include "tensorflow/lite/delegates/gpu/cl/linear_storage.h" #include "tensorflow/lite/delegates/gpu/cl/tensor.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/types.h" namespace tflite { @@ -83,7 +81,7 @@ std::string FullyConnected::GetFullyConnectedKernelCode( const bool weights_are_buffer = UseBufferForWeights(gpu_info); - std::string c = GetCommonDefines(op_def.precision); + std::string c; switch (op_def.precision) { case CalculationsPrecision::F32: c += "#define FLT16 float16\n"; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.h b/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.h index a508e741ef9..fda3dc6fa32 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.h @@ -25,12 +25,11 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/lite/delegates/gpu/cl/buffer.h" #include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h" -#include "tensorflow/lite/delegates/gpu/cl/device_info.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/cl/texture2d.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" #include "tensorflow/lite/delegates/gpu/common/types.h" #include "tensorflow/lite/delegates/gpu/common/util.h" diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected_test.cc index 3db7fb70406..87020324f8a 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected_test.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected_test.cc @@ -21,10 +21,10 @@ limitations under the License. #include <gtest/gtest.h> #include "tensorflow/lite/delegates/gpu/cl/environment.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/cl_test.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" using ::testing::ElementsAreArray; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/lstm.cc b/tensorflow/lite/delegates/gpu/cl/kernels/lstm.cc index 11d2e209428..d659d8bff02 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/lstm.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/lstm.cc @@ -17,15 +17,14 @@ limitations under the License. 
#include <string> -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" +#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h" namespace tflite { namespace gpu { namespace cl { namespace { std::string GetLSTMCode(const OperationDef& op_def, const GpuInfo& gpu_info) { - std::string c = GetCommonDefines(op_def.precision); + std::string c; c += "__kernel void main_function(\n"; c += "$0) {\n"; c += " int B = get_global_id(0);\n"; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/lstm.h b/tensorflow/lite/delegates/gpu/cl/kernels/lstm.h index fa0aa270158..bfb3cd25055 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/lstm.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/lstm.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_LSTM_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_LSTM_H_ -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/types.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.cc b/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.cc index 4905c5f9284..4692ac0834e 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.cc @@ -17,8 +17,7 @@ limitations under the License. #include <string> -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" +#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h" namespace tflite { namespace gpu { @@ -44,7 +43,7 @@ std::string GetMaxUnpoolingKernelCode(const OperationDef& op_def, } op->AddDstTensor("dst_tensor", dst_desc); - std::string c = GetCommonDefines(op_def.precision); + std::string c; c += "__kernel void main_function(\n"; c += "$0) {\n"; c += " int X = get_global_id(0);\n"; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.h b/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.h index c1b6cbf334b..fb05e2dc7a3 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_MAX_UNPOOLING_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_MAX_UNPOOLING_H_ -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/types.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization.cc b/tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization.cc index 5cdb81e5324..20edf489c2b 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization.cc @@ -18,9 +18,7 @@ limitations under the License. 
#include <string> #include "tensorflow/lite/delegates/gpu/cl/cl_program.h" -#include "tensorflow/lite/delegates/gpu/cl/device_info.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" +#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h" namespace tflite { namespace gpu { @@ -95,7 +93,7 @@ MeanStdDevNormalization::MeanStdDevNormalization(const OperationDef& definition, // For now, fix workgroup size to the biggest supported by the device, but not // larger than the number of tensor slices. int desired_work_group_size = - std::min(tensor_slices, gpu_info.max_work_group_size_x); + std::min(tensor_slices, gpu_info.GetMaxWorkGroupSizeForX()); if (gpu_info.IsMali()) { // Don't use more than 64 work items per work group on ARM Mali. They // implement local memory using the global memory, larger workgroups have @@ -136,9 +134,9 @@ MeanStdDevNormalization::MeanStdDevNormalization(const OperationDef& definition, work_group_size_.y = 1; // Required work_group_size_.z = 1; // Required code_ = GetNormalizationCode(); - if (gpu_info.cl_version >= OpenCLVersion::CL_3_0) { + if (gpu_info.IsCL30OrHigher()) { compiler_options_.push_back(CompilerOptions::kCl30); - } else if (gpu_info.cl_version >= OpenCLVersion::CL_2_0) { + } else if (gpu_info.IsCL20OrHigher()) { compiler_options_.push_back(CompilerOptions::kCl20); } } @@ -147,7 +145,7 @@ std::string MeanStdDevNormalization::GetNormalizationCode() { AddSrcTensor("src_tensor", definition_.src_tensors[0]); AddDstTensor("dst_tensor", definition_.dst_tensors[0]); - std::string c = GetCommonDefines(definition_.precision); + std::string c; c += GetVectorReduceCode(); c += GetReduceCode(); c += GetFilterCode(); diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization.h b/tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization.h index 6a4a1848394..defdf9d0c3a 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization.h @@ -16,10 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_LSTM_NORMALIZATION_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_LSTM_NORMALIZATION_H_ -#include "tensorflow/lite/delegates/gpu/cl/device_info.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/types.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/padding.cc b/tensorflow/lite/delegates/gpu/cl/kernels/padding.cc index 8012e601c0b..77edcbfef2d 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/padding.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/padding.cc @@ -17,9 +17,8 @@ limitations under the License. #include <string> -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h" namespace tflite { namespace gpu { @@ -36,7 +35,7 @@ std::string GetPaddingCode(const OperationDef& op_def, const std::string dst_batch = op_def.dst_tensors[0].HasAxis(Axis::BATCH) ? 
"B" : "0"; - std::string c = GetCommonDefines(op_def.precision); + std::string c; const std::string channels[] = {".x", ".y", ".z", ".w"}; if (attr.type == PaddingContentType::REFLECT) { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/padding.h b/tensorflow/lite/delegates/gpu/cl/kernels/padding.h index 81047162d20..be534017a28 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/padding.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/padding.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_PADDING_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_PADDING_H_ -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/types.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/pooling.cc b/tensorflow/lite/delegates/gpu/cl/kernels/pooling.cc index aa2069618c7..d687376e871 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/pooling.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/pooling.cc @@ -17,8 +17,8 @@ limitations under the License. #include <string> -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" +#include "tensorflow/lite/delegates/gpu/common/task/util.h" +#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h" namespace tflite { namespace gpu { @@ -72,7 +72,7 @@ std::string GetAveragePoolingKernelCode(const OperationDef& op_def, op_def.src_tensors[0].storage_type == TensorStorageType::BUFFER || op_def.src_tensors[0].storage_type == TensorStorageType::IMAGE_BUFFER; - std::string c = GetCommonDefines(op_def.precision); + std::string c; c += "__kernel void main_function(\n"; c += "$0) {\n"; c += " int X = get_global_id(0);\n"; @@ -193,7 +193,7 @@ std::string GetMaxPoolingKernelCode(const OperationDef& op_def, dst_coord += ", " + dst_coords[i]; } - std::string c = GetCommonDefines(op_def.precision); + std::string c; c += "__kernel void main_function(\n"; c += "$0) {\n"; c += " int X = get_global_id(0);\n"; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/pooling.h b/tensorflow/lite/delegates/gpu/cl/kernels/pooling.h index 81a0dfff4de..ade7584b924 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/pooling.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/pooling.h @@ -17,10 +17,10 @@ limitations under the License. #define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_POOLING_H_ #include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/cl/tensor.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/types.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/prelu.cc b/tensorflow/lite/delegates/gpu/cl/kernels/prelu.cc index 088dc5b027f..97bad10cac9 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/prelu.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/prelu.cc @@ -17,7 +17,6 @@ limitations under the License. 
#include "absl/strings/str_cat.h" #include "absl/types/variant.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" #include "tensorflow/lite/delegates/gpu/cl/storage_type_util.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/prelu.h b/tensorflow/lite/delegates/gpu/cl/kernels/prelu.h index 1e98d043eed..3ec5a7791f4 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/prelu.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/prelu.h @@ -20,11 +20,11 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/cl/cl_context.h" #include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/cl/linear_storage.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/quantize_and_dequantize.cc b/tensorflow/lite/delegates/gpu/cl/kernels/quantize_and_dequantize.cc index 1e08eb0ff52..267a56477c8 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/quantize_and_dequantize.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/quantize_and_dequantize.cc @@ -19,7 +19,6 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/types/variant.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/quantize_and_dequantize.h b/tensorflow/lite/delegates/gpu/cl/kernels/quantize_and_dequantize.h index 1e37e427af8..c84409d3f7a 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/quantize_and_dequantize.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/quantize_and_dequantize.h @@ -20,11 +20,11 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/cl/cl_context.h" #include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/cl/linear_storage.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/reduce.cc b/tensorflow/lite/delegates/gpu/cl/kernels/reduce.cc index 93345ce1d78..1ff3ec582e9 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/reduce.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/reduce.cc @@ -18,9 +18,9 @@ limitations under the License. 
#include <set> #include <string> -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/util.h" +#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h" #include "tensorflow/lite/delegates/gpu/common/util.h" namespace tflite { @@ -196,7 +196,7 @@ std::string Reduce::GetReduceKernelCode(const OperationDef& op_def, } }; - std::string c = GetCommonDefines(op_def.precision); + std::string c; const std::string wg_x = std::to_string(work_group_size.x); const std::string wg_y = std::to_string(work_group_size.y); const std::string wg_z = std::to_string(work_group_size.z); diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/reduce.h b/tensorflow/lite/delegates/gpu/cl/kernels/reduce.h index 8747b5dc832..a46d094b5ba 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/reduce.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/reduce.h @@ -19,9 +19,9 @@ limitations under the License. #include <map> #include <set> -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/kernel_info.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/types.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/relu.h b/tensorflow/lite/delegates/gpu/cl/kernels/relu.h index 1b4e3a81605..ba6676b27a2 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/relu.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/relu.h @@ -18,8 +18,8 @@ limitations under the License. #include <string> -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/reshape.cc b/tensorflow/lite/delegates/gpu/cl/kernels/reshape.cc index d965b6f0611..631e140cd36 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/reshape.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/reshape.cc @@ -17,15 +17,14 @@ limitations under the License. #include <string> -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" +#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h" namespace tflite { namespace gpu { namespace cl { namespace { std::string GetReshapeCode(const OperationDef& op_def) { - std::string c = GetCommonDefines(op_def.precision); + std::string c; c += "__kernel void main_function(\n"; c += "$0) {\n"; if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/reshape.h b/tensorflow/lite/delegates/gpu/cl/kernels/reshape.h index 59cc5c1560d..b02e7e1d10f 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/reshape.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/reshape.h @@ -17,8 +17,8 @@ limitations under the License. 
#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_RESHAPE_H_ #include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/types.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/reshapex4.cc b/tensorflow/lite/delegates/gpu/cl/kernels/reshapex4.cc index 78440e3c843..a040f300cba 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/reshapex4.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/reshapex4.cc @@ -17,8 +17,7 @@ limitations under the License. #include <string> -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" +#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h" namespace tflite { namespace gpu { @@ -26,7 +25,7 @@ namespace cl { namespace { std::string GetReshapeCode(const OperationDef& op_def) { - std::string c = GetCommonDefines(op_def.precision); + std::string c; c += "__kernel void main_function(\n"; c += "$0) {\n"; if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/reshapex4.h b/tensorflow/lite/delegates/gpu/cl/kernels/reshapex4.h index 2052d45b3e1..b725cf06267 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/reshapex4.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/reshapex4.h @@ -18,8 +18,8 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/cl/cl_command_queue.h" #include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/types.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/resize.cc b/tensorflow/lite/delegates/gpu/cl/kernels/resize.cc index cbf706f9bd2..1c69f84e9f1 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/resize.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/resize.cc @@ -15,9 +15,8 @@ limitations under the License. 
#include "tensorflow/lite/delegates/gpu/cl/kernels/resize.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h" namespace tflite { namespace gpu { @@ -56,7 +55,7 @@ std::string Resize::GetResizeCode(const OperationDef& op_def, args_.AddFloat("scale_factor_x"); args_.AddFloat("scale_factor_y"); - std::string c = GetCommonDefines(op_def.precision); + std::string c; c += "__kernel void main_function(\n"; c += "$0) {\n"; c += " int Y = get_global_id(1);\n"; @@ -191,7 +190,7 @@ std::string Resize3D::GetResize3DCode(const OperationDef& op_def, args_.AddFloat("scale_factor_y"); args_.AddFloat("scale_factor_z"); - std::string c = GetCommonDefines(op_def.precision); + std::string c; c += "__kernel void main_function(\n"; c += "$0) {\n"; c += " int Y = get_global_id(1);\n"; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/resize.h b/tensorflow/lite/delegates/gpu/cl/kernels/resize.h index 859d750b7e0..660243975ef 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/resize.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/resize.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_RESIZE_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_RESIZE_H_ -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/types.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/softmax.cc b/tensorflow/lite/delegates/gpu/cl/kernels/softmax.cc index 03a53d5716b..ac55e6831e0 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/softmax.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/softmax.cc @@ -17,16 +17,15 @@ limitations under the License. #include <string> -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h" namespace tflite { namespace gpu { namespace cl { namespace { std::string GetSoftmaxKernelCode(const OperationDef& op_def) { - std::string c = GetCommonDefines(op_def.precision); + std::string c; c += "__kernel void main_function(\n"; c += "$0) {\n"; c += " int X = get_global_id(0);\n"; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/softmax.h b/tensorflow/lite/delegates/gpu/cl/kernels/softmax.h index 4a9b5fd7c18..bfc9ecc3f16 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/softmax.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/softmax.h @@ -17,8 +17,8 @@ limitations under the License. 
#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_SOFTMAX_H_ #include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/cl/tensor.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/types.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.cc b/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.cc index d4d0442e61d..da04402b99f 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.cc @@ -17,8 +17,8 @@ limitations under the License. #include <string> -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/util.h" namespace tflite { namespace gpu { @@ -48,7 +48,7 @@ std::string Softmax1x1::GetSoftmaxKernelCode(const OperationDef& op_def) { args_.AddFloat("mask_w"); args_.AddInt("slices_x32"); - std::string c = GetCommonDefines(op_def.precision); + std::string c; c += "__kernel void main_function(\n"; c += "$0) {\n"; if (op_def.IsBatchSupported()) { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.h b/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.h index 2a50dee2d63..c670bdf3a6c 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.h @@ -17,8 +17,8 @@ limitations under the License. #define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_SOFTMAX1X1_H_ #include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/cl/tensor.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.cc b/tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.cc index f5323b48bae..3bf90092b60 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.cc @@ -19,15 +19,14 @@ limitations under the License. #include <utility> #include <vector> -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" +#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h" namespace tflite { namespace gpu { namespace cl { namespace { std::string GetSpaceToDepthCode(const OperationDef& op_def) { - std::string c = GetCommonDefines(op_def.precision); + std::string c; c += "__kernel void main_function(\n"; c += "$0) {\n"; if (op_def.IsBatchSupported()) { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.h b/tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.h index 08aca3054d6..0ef263acd38 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.h @@ -17,9 +17,9 @@ limitations under the License. 
#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_SPACE_TO_DEPTH_H_ #include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/types.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/special/BUILD b/tensorflow/lite/delegates/gpu/cl/kernels/special/BUILD index 92231338730..c359f0887a6 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/special/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/kernels/special/BUILD @@ -12,15 +12,14 @@ cc_library( "//tensorflow/lite/delegates/gpu/cl:cl_device", "//tensorflow/lite/delegates/gpu/cl:gpu_object", "//tensorflow/lite/delegates/gpu/cl:util", - "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation", - "//tensorflow/lite/delegates/gpu/cl/kernels:util", - "//tensorflow/lite/delegates/gpu/cl/kernels:work_group_picking", "//tensorflow/lite/delegates/gpu/common:data_type", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:shape", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:tensor", "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", + "//tensorflow/lite/delegates/gpu/common/task:work_group_picking", ], ) @@ -31,18 +30,16 @@ cc_library( deps = [ "//tensorflow/lite/delegates/gpu/cl:buffer", "//tensorflow/lite/delegates/gpu/cl:cl_kernel", - "//tensorflow/lite/delegates/gpu/cl:device_info", "//tensorflow/lite/delegates/gpu/cl:linear_storage", "//tensorflow/lite/delegates/gpu/cl:tensor", "//tensorflow/lite/delegates/gpu/cl:texture2d", - "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation", - "//tensorflow/lite/delegates/gpu/cl/kernels:util", "//tensorflow/lite/delegates/gpu/common:data_type", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:shape", "//tensorflow/lite/delegates/gpu/common:tensor", "//tensorflow/lite/delegates/gpu/common:types", "//tensorflow/lite/delegates/gpu/common:util", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", "@com_google_absl//absl/memory", ], ) diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/special/depthwise_conv_plus_1x1_conv.cc b/tensorflow/lite/delegates/gpu/cl/kernels/special/depthwise_conv_plus_1x1_conv.cc index 7875f95b23e..2fbaf3de080 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/special/depthwise_conv_plus_1x1_conv.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/special/depthwise_conv_plus_1x1_conv.cc @@ -20,8 +20,7 @@ limitations under the License. 
#include <vector> #include "tensorflow/lite/delegates/gpu/cl/cl_device.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" +#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h" namespace tflite { namespace gpu { @@ -128,7 +127,7 @@ std::string GenerateCode(const OperationDef& op_def, result->args_.AddInt("padding_y", -dw_attr.padding.prepended.h); result->args_.AddInt("dilation_y", dw_attr.dilations.h); - std::string c = GetCommonDefines(op_def.precision); + std::string c; c += "__kernel void main_function(\n"; c += "$0) {\n"; if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/special/depthwise_conv_plus_1x1_conv.h b/tensorflow/lite/delegates/gpu/cl/kernels/special/depthwise_conv_plus_1x1_conv.h index b87051104b7..4d1cc19f8b2 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/special/depthwise_conv_plus_1x1_conv.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/special/depthwise_conv_plus_1x1_conv.h @@ -20,12 +20,12 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/cl/buffer.h" #include "tensorflow/lite/delegates/gpu/cl/gpu_object.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/cl/util.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" #include "tensorflow/lite/delegates/gpu/common/types.h" diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/special/fc_fc_add.cc b/tensorflow/lite/delegates/gpu/cl/kernels/special/fc_fc_add.cc index f1c3ddb7045..7accb9276b2 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/special/fc_fc_add.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/special/fc_fc_add.cc @@ -20,12 +20,10 @@ limitations under the License. #include <vector> #include "absl/memory/memory.h" -#include "tensorflow/lite/delegates/gpu/cl/device_info.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" #include "tensorflow/lite/delegates/gpu/cl/linear_storage.h" #include "tensorflow/lite/delegates/gpu/cl/tensor.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/types.h" namespace tflite { @@ -82,7 +80,7 @@ std::string FCFCAdd::GetFCFCAddKernelCode(const OperationDef& op_def, const bool weights_are_buffer = UseBufferForWeights(gpu_info); - std::string c = GetCommonDefines(op_def.precision); + std::string c; switch (op_def.precision) { case CalculationsPrecision::F32: c += "#define FLT16 float16\n"; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/special/fc_fc_add.h b/tensorflow/lite/delegates/gpu/cl/kernels/special/fc_fc_add.h index 09e0548c663..37248e0480c 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/special/fc_fc_add.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/special/fc_fc_add.h @@ -25,12 +25,11 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "tensorflow/lite/delegates/gpu/cl/buffer.h" #include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h" -#include "tensorflow/lite/delegates/gpu/cl/device_info.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/cl/texture2d.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" #include "tensorflow/lite/delegates/gpu/common/types.h" #include "tensorflow/lite/delegates/gpu/common/util.h" diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.cc b/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.cc index 1f8f985f3ee..3fa2121e986 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.cc @@ -17,8 +17,7 @@ limitations under the License. #include <string> -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" +#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h" namespace tflite { namespace gpu { @@ -108,7 +107,7 @@ std::string StridedSlice::GetStridedSliceCode(const OperationDef& op_def, const std::string batch_id = op_def.dst_tensors[0].HasAxis(Axis::BATCH) ? "B" : "0"; - std::string c = GetCommonDefines(op_def.precision); + std::string c; c += "__kernel void main_function(\n"; c += "$0) {\n"; if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.h b/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.h index dddff2faf35..c2af72242ec 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_STRIDED_SLICE_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_STRIDED_SLICE_H_ -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/types.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/transpose.cc b/tensorflow/lite/delegates/gpu/cl/kernels/transpose.cc index 0182ec7d90c..a177ade3b12 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/transpose.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/transpose.cc @@ -18,8 +18,7 @@ limitations under the License. #include <string> #include "absl/strings/str_cat.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" +#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h" namespace tflite { namespace gpu { @@ -29,7 +28,7 @@ std::string GetTransposeCode(const OperationDef& op_def, const TransposeAttributes& attr) { const std::string batch_id = op_def.dst_tensors[0].HasAxis(Axis::BATCH) ? 
"B" : "0"; - std::string c = GetCommonDefines(op_def.precision); + std::string c; c += "__kernel void main_function(\n"; c += "$0) {\n"; if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/transpose.h b/tensorflow/lite/delegates/gpu/cl/kernels/transpose.h index 631d5dc08b3..d318e9600c5 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/transpose.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/transpose.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_TRANSPOSE_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_TRANSPOSE_H_ -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/types.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/util.cc b/tensorflow/lite/delegates/gpu/cl/kernels/util.cc deleted file mode 100644 index 2530c73571b..00000000000 --- a/tensorflow/lite/delegates/gpu/cl/kernels/util.cc +++ /dev/null @@ -1,201 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" - -#include <cfloat> -#include <cmath> -#include <string> -#include <vector> - -#include "absl/strings/str_cat.h" -#include "absl/strings/substitute.h" -#include "tensorflow/lite/delegates/gpu/common/data_type.h" -#include "tensorflow/lite/delegates/gpu/common/precision.h" - -namespace tflite { -namespace gpu { -namespace cl { - -std::string GetCommonDefines(CalculationsPrecision precision) { - std::string result; - - switch (precision) { - case CalculationsPrecision::F32: - result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n"; - result += "#define ACCUM_FLT4 float4\n"; - result += "#define FLT float\n"; - result += "#define FLT2 float2\n"; - result += "#define FLT3 float3\n"; - result += "#define FLT4 float4\n"; - result += "#define TO_FLT4 convert_float4\n"; - result += "#define TO_ACCUM_TYPE convert_float4\n"; - result += "#define TO_ACCUM_FLT convert_float\n"; - break; - case CalculationsPrecision::F16: - result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n"; - result += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"; - result += "#define ACCUM_FLT4 half4\n"; - result += "#define FLT half\n"; - result += "#define FLT2 half2\n"; - result += "#define FLT3 half3\n"; - result += "#define FLT4 half4\n"; - result += "#define TO_FLT4 convert_half4\n"; - result += "#define TO_ACCUM_TYPE convert_half4\n"; - result += "#define TO_ACCUM_FLT convert_half\n"; - break; - case CalculationsPrecision::F32_F16: - result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n"; - result += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"; - result += "#define ACCUM_FLT4 float4\n"; - 
result += "#define FLT half\n"; - result += "#define FLT2 half2\n"; - result += "#define FLT3 half3\n"; - result += "#define FLT4 half4\n"; - result += "#define TO_FLT4 convert_half4\n"; - result += "#define TO_ACCUM_TYPE convert_float4\n"; - result += "#define TO_ACCUM_FLT convert_float\n"; - break; - } - return result; -} - -std::string GetXStrideCorrected(const std::string& src_x, - const std::string& batch_size, - const std::string& stride_x, - const std::string& padding_x) { - // TODO(sorokin) check perf and optimize with floor() if needed - // int p0 = src_x / batch_size;\n"; - // int b0 = src_x % batch_size;\n"; - // return p0 * stride_x * batch_size + b0 + padding_x;\n"; - return absl::Substitute("((($0) / $1) * $2 * $1 + (($0) % $1) + $3)", src_x, - batch_size, stride_x, padding_x); -} - -std::string GetXStrideCorrectedV2(const std::string& src_x, - const std::string& batch_size, - const std::string& stride_x, - const std::string& padding_x) { - // int p0 = src_x / batch_size;\n"; - // int b0 = src_x % batch_size;\n"; - // return (p0 * stride_x + padding_x) * batch_size + b0;\n"; - return absl::Substitute("(((($0) / $1) * $2 + $3) * $1 + ($0) % $1)", src_x, - batch_size, stride_x, padding_x); -} - -float4 GetMaskForLastPlane(int channels) { - float4 mask = float4(0.0f); - const int reminder = channels % 4 == 0 ? 4 : channels % 4; - for (int i = 0; i < reminder; ++i) { - mask[i] = 1.0f; - } - return mask; -} - -int3 GetFirstSuitableWorkGroup(const std::vector<int3>& wgs, int max_wg_size) { - for (const auto& wg : wgs) { - const int wg_size = wg.x * wg.y * wg.z; - if (wg_size <= max_wg_size) { - return wg; - } - } - return {1, 1, 1}; -} - -int GetRecommendedBlockSizeForConv(const GpuInfo& gpu_info, - CalculationsPrecision precision, - int task_size) { - const float task_size_per_cu = - task_size / static_cast<float>(gpu_info.compute_units_count); - int block_size = 1; - float threshold_1 = FLT_MAX; - float threshold_2 = FLT_MAX; - float threshold_4 = FLT_MAX; - if (!gpu_info.IsMali()) { - return 1; - } - MaliInfo mali_info = gpu_info.mali_info; - switch (precision) { - case CalculationsPrecision::F16: - if (mali_info.IsBifrostGen1()) { - threshold_1 = 256.0f; - threshold_2 = 256.0f * 4.0f; - threshold_4 = 256.0f * 8.0f; - } else if (mali_info.IsBifrostGen2()) { - threshold_1 = 256.0f * 2.0f; - threshold_2 = 256.0f * 8.0f; - threshold_4 = 256.0f * 16.0f; - } else if (mali_info.IsBifrostGen3() || mali_info.IsValhall()) { - threshold_1 = 256.0f; - threshold_2 = 256.0f * 6.0f; - threshold_4 = 256.0f * 16.0f; - } else if (mali_info.IsMidgard()) { - threshold_1 = 256.0f * 4.0f; - threshold_2 = 256.0f * 16.0f; - } - break; - case CalculationsPrecision::F32_F16: - if (mali_info.IsBifrostGen1()) { - threshold_1 = 256.0f; - threshold_2 = 256.0f * 3.0f; - threshold_4 = 256.0f * 32.0f; - } else if (mali_info.IsBifrostGen2()) { - threshold_1 = 256.0f * 2.0f; - threshold_2 = 256.0f * 8.0f; - } else if (mali_info.IsBifrostGen3() || mali_info.IsValhall()) { - threshold_1 = 256.0f; - threshold_2 = 256.0f * 8.0f; - } else if (mali_info.IsMidgard()) { - threshold_1 = 256.0f * 4.0f; - } - break; - case CalculationsPrecision::F32: - if (mali_info.IsBifrostGen1()) { - threshold_1 = 256.0f; - threshold_2 = 256.0f * 4.0f; - } else if (mali_info.IsBifrostGen2()) { - threshold_1 = 128.0f; - threshold_2 = 256.0f * 4.0f; - } else if (mali_info.IsBifrostGen3() || mali_info.IsValhall()) { - threshold_1 = 256.0f; - threshold_2 = 256.0f * 12.0f; - } else if (mali_info.IsMidgard()) { - threshold_1 = 256.0f * 16.0f; 
- } - break; - } - if (task_size_per_cu <= threshold_1) { - block_size = 1; - } else if (task_size_per_cu <= threshold_2) { - block_size = 2; - } else if (task_size_per_cu <= threshold_4) { - block_size = 4; - } else { - block_size = 8; - } - return block_size; -} - -int3 GetWorkGroupsCount(const int3& grid_size, const int3& work_group_size) { - int3 work_groups_count; - work_groups_count.x = DivideRoundUp(grid_size.x, work_group_size.x); - work_groups_count.y = DivideRoundUp(grid_size.y, work_group_size.y); - work_groups_count.z = DivideRoundUp(grid_size.z, work_group_size.z); - return work_groups_count; -} - -} // namespace cl -} // namespace gpu -} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/winograd.cc b/tensorflow/lite/delegates/gpu/cl/kernels/winograd.cc index 47b66ededdd..dd380b49e32 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/winograd.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/winograd.cc @@ -20,11 +20,10 @@ limitations under the License. #include "absl/strings/str_format.h" #include "tensorflow/lite/delegates/gpu/cl/cl_device.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h" #include "tensorflow/lite/delegates/gpu/common/winograd_util.h" namespace tflite { @@ -59,7 +58,7 @@ Winograd4x4To36& Winograd4x4To36::operator=(Winograd4x4To36&& operation) { std::string Winograd4x4To36::GetWinograd4x4To36Code( const OperationDef& op_def) { - std::string c = GetCommonDefines(op_def.precision); + std::string c; const auto src_tensor_type = op_def.src_tensors[0].storage_type; const bool is_image_buffer = @@ -327,7 +326,7 @@ Winograd36To4x4& Winograd36To4x4::operator=(Winograd36To4x4&& operation) { std::string Winograd36To4x4::GetWinograd36To4x4Code( const OperationDef& op_def) { - std::string c = GetCommonDefines(op_def.precision); + std::string c; switch (op_def.precision) { case CalculationsPrecision::F32: diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/winograd.h b/tensorflow/lite/delegates/gpu/cl/kernels/winograd.h index 11430c99c0b..bdc3caf422d 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/winograd.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/winograd.h @@ -17,11 +17,11 @@ limitations under the License. #define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_WINOGRAD_H_ #include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/cl/linear_storage.h" #include "tensorflow/lite/delegates/gpu/cl/tensor.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/winograd_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/winograd_test.cc index 9da73ba9783..c858356d405 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/winograd_test.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/winograd_test.cc @@ -22,7 +22,6 @@ limitations under the License. 
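Illustrative sketch (not part of the patch): the removed GetWorkGroupsCount helper is plain ceiling division of the dispatch grid by the work-group size in each dimension. Int3 below is a local stand-in for the delegate's int3 type, used only to keep the example self-contained:

#include <cassert>

struct Int3 { int x, y, z; };

int DivideRoundUp(int numerator, int denominator) {
  return (numerator + denominator - 1) / denominator;
}

Int3 GetWorkGroupsCount(const Int3& grid_size, const Int3& work_group_size) {
  return {DivideRoundUp(grid_size.x, work_group_size.x),
          DivideRoundUp(grid_size.y, work_group_size.y),
          DivideRoundUp(grid_size.z, work_group_size.z)};
}

int main() {
  // A 100x60x9 grid dispatched with 8x4x4 work groups needs 13x15x3 groups.
  const Int3 count = GetWorkGroupsCount({100, 60, 9}, {8, 4, 4});
  assert(count.x == 13 && count.y == 15 && count.z == 3);
  return 0;
}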
#include <gmock/gmock.h> #include <gtest/gtest.h> #include "tensorflow/lite/delegates/gpu/cl/kernels/cl_test.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/common/winograd_util.h" @@ -78,9 +77,13 @@ TEST_F(OpenCLOperationTest, Winograd4x4To36) { for (auto precision : env_.GetSupportedPrecisions()) { float eps; if (precision == CalculationsPrecision::F32) { - eps = 1e-5f * (env_.device().SupportsFP32RTN() ? 1.0f : 4.0f); + eps = 1e-5f * (env_.device().GetInfo().opencl_info.supports_fp32_rtn + ? 1.0f + : 4.0f); } else { - eps = 1e-2f * (env_.device().SupportsFP16RTN() ? 1.0f : 4.0f); + eps = 1e-2f * (env_.device().GetInfo().opencl_info.supports_fp16_rtn + ? 1.0f + : 4.0f); } OperationDef op_def; op_def.precision = precision; @@ -151,9 +154,13 @@ TEST_F(OpenCLOperationTest, Winograd36To4x4) { for (auto precision : env_.GetSupportedPrecisions()) { float eps; if (precision == CalculationsPrecision::F32) { - eps = 1e-5f * (env_.device().SupportsFP32RTN() ? 1.0f : 4.0f); + eps = 1e-5f * (env_.device().GetInfo().opencl_info.supports_fp32_rtn + ? 1.0f + : 4.0f); } else { - eps = 1e-2f * (env_.device().SupportsFP16RTN() ? 1.0f : 4.0f); + eps = 1e-2f * (env_.device().GetInfo().opencl_info.supports_fp16_rtn + ? 1.0f + : 4.0f); } OperationDef op_def; op_def.precision = precision; diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/BUILD b/tensorflow/lite/delegates/gpu/cl/selectors/BUILD index f4a2f8f654a..976fa82b851 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/selectors/BUILD @@ -9,18 +9,18 @@ cc_library( hdrs = ["convolution_selector.h"], deps = [ "//tensorflow/lite/delegates/gpu/cl/kernels:conv_buffer_1x1", - "//tensorflow/lite/delegates/gpu/cl/kernels:conv_common", "//tensorflow/lite/delegates/gpu/cl/kernels:conv_constants", "//tensorflow/lite/delegates/gpu/cl/kernels:conv_powervr", "//tensorflow/lite/delegates/gpu/cl/kernels:conv_weights_converter", - "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation", - "//tensorflow/lite/delegates/gpu/cl/kernels:work_group_picking", "//tensorflow/lite/delegates/gpu/common:model_hints", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:shape", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:util", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", "//tensorflow/lite/delegates/gpu/common/task:tensor_desc", + "//tensorflow/lite/delegates/gpu/common/task:weights_layout", + "//tensorflow/lite/delegates/gpu/common/task:work_group_picking", "@com_google_absl//absl/memory", ], ) @@ -30,16 +30,16 @@ cc_library( srcs = ["convolution_transposed_selector.cc"], hdrs = ["convolution_transposed_selector.h"], deps = [ - "//tensorflow/lite/delegates/gpu/cl/kernels:conv_common", "//tensorflow/lite/delegates/gpu/cl/kernels:convolution_transposed", "//tensorflow/lite/delegates/gpu/cl/kernels:convolution_transposed_3x3", "//tensorflow/lite/delegates/gpu/cl/kernels:convolution_transposed_3x3_thin", "//tensorflow/lite/delegates/gpu/cl/kernels:convolution_transposed_4x4", "//tensorflow/lite/delegates/gpu/cl/kernels:convolution_transposed_thin", - "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", + 
"//tensorflow/lite/delegates/gpu/common/task:gpu_operation", "//tensorflow/lite/delegates/gpu/common/task:tensor_desc", + "//tensorflow/lite/delegates/gpu/common/task:weights_layout", "@com_google_absl//absl/memory", ], ) @@ -49,12 +49,11 @@ cc_library( hdrs = ["default_selector.h"], deps = [ ":subgraph", - "//tensorflow/lite/delegates/gpu/cl:device_info", - "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation", "//tensorflow/lite/delegates/gpu/cl/selectors/default:default_selector", "//tensorflow/lite/delegates/gpu/common:model", "//tensorflow/lite/delegates/gpu/common:model_hints", "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", "//tensorflow/lite/delegates/gpu/common/task:tensor_desc", ], ) @@ -67,9 +66,9 @@ cc_library( "//tensorflow/lite/delegates/gpu/cl:cl_device", "//tensorflow/lite/delegates/gpu/cl/kernels:depthwise_conv", "//tensorflow/lite/delegates/gpu/cl/kernels:depthwise_conv_3x3", - "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", "@com_google_absl//absl/memory", ], ) @@ -82,9 +81,9 @@ cc_library( "//tensorflow/lite/delegates/gpu/cl/kernels:conv_buffer_1x1", "//tensorflow/lite/delegates/gpu/cl/kernels:conv_powervr", "//tensorflow/lite/delegates/gpu/cl/kernels:fully_connected", - "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", "@com_google_absl//absl/memory", ], ) @@ -102,9 +101,7 @@ cc_library( ":subgraph", "//tensorflow/lite/delegates/gpu/cl:cl_device", "//tensorflow/lite/delegates/gpu/cl:storage_type_util", - "//tensorflow/lite/delegates/gpu/cl/kernels:conv_common", "//tensorflow/lite/delegates/gpu/cl/kernels:elementwise", - "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation", "//tensorflow/lite/delegates/gpu/cl/kernels:mean_stddev_normalization", "//tensorflow/lite/delegates/gpu/cl/kernels:transpose", "//tensorflow/lite/delegates/gpu/cl/selectors:default_selector", @@ -114,6 +111,8 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:shape", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:tensor", + "//tensorflow/lite/delegates/gpu/common:winograd_util", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", "//tensorflow/lite/delegates/gpu/common/task:tensor_desc", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:any", @@ -130,7 +129,6 @@ cc_library( "//tensorflow/lite/delegates/gpu/cl/kernels:concat_xy", "//tensorflow/lite/delegates/gpu/cl/kernels:concat_z", "//tensorflow/lite/delegates/gpu/cl/kernels:depthwise_conv", - "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation", "//tensorflow/lite/delegates/gpu/cl/kernels:lstm", "//tensorflow/lite/delegates/gpu/cl/kernels:max_unpooling", "//tensorflow/lite/delegates/gpu/cl/kernels:padding", @@ -151,6 +149,7 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:shape", "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", "@com_google_absl//absl/memory", ], ) @@ -162,7 +161,6 @@ cc_library( deps = [ ":subgraph", "//tensorflow/lite/delegates/gpu/cl:cl_device", - "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation", 
"//tensorflow/lite/delegates/gpu/cl/kernels/special:depthwise_conv_plus_1x1_conv", "//tensorflow/lite/delegates/gpu/cl/kernels/special:fc_fc_add", "//tensorflow/lite/delegates/gpu/common:data_type", @@ -171,6 +169,7 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:shape", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:tensor", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", "//tensorflow/lite/delegates/gpu/common/task:tensor_desc", "@com_google_absl//absl/types:any", ], @@ -181,8 +180,8 @@ cc_library( srcs = ["subgraph.cc"], hdrs = ["subgraph.h"], deps = [ - "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation", "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", "//tensorflow/lite/delegates/gpu/common/task:tensor_desc", ], ) diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.cc index b6b0131aeb9..5362004a903 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.cc +++ b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.cc @@ -20,8 +20,8 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/conv_weights_converter.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h" +#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h" #include "tensorflow/lite/delegates/gpu/common/util.h" namespace tflite { @@ -55,10 +55,10 @@ std::unique_ptr<GPUOperation> SelectConvolutionDynamicWeightsAdreno( const Convolution2DAttributes& attr, const BHWC& weights_shape, const BHWC& dst_shape, const GpuInfo& gpu_info, const OperationDef& op_def, ModelHints hints, - ConvWeightsDescription* weights_desc) { + WeightsDescription* weights_desc) { ConvPowerVR conv = CreateConvPowerVRDynamicWeights( gpu_info, op_def, attr, weights_shape, &dst_shape); - *weights_desc = conv.GetConvWeightsDescription(); + *weights_desc = conv.GetWeightsDescription(); return absl::make_unique<ConvPowerVR>(std::move(conv)); } @@ -113,17 +113,17 @@ std::unique_ptr<GPUOperation> SelectConvolutionDynamicWeightsMali( const Convolution2DAttributes& attr, const BHWC& weights_shape, const BHWC& dst_shape, const GpuInfo& gpu_info, const OperationDef& op_def, ModelHints hints, - ConvWeightsDescription* weights_desc) { + WeightsDescription* weights_desc) { if (op_def.src_tensors[0].storage_type == TensorStorageType::BUFFER && IsConvBuffer1x1Supported(op_def, weights_shape, attr)) { ConvBuffer1x1 conv = CreateConvBuffer1x1DynamicWeights( gpu_info, op_def, attr, weights_shape, &dst_shape); - *weights_desc = conv.GetConvWeightsDescription(); + *weights_desc = conv.GetWeightsDescription(); return absl::make_unique<ConvBuffer1x1>(std::move(conv)); } else { ConvPowerVR conv = CreateConvPowerVRDynamicWeights( gpu_info, op_def, attr, weights_shape, &dst_shape); - *weights_desc = conv.GetConvWeightsDescription(); + *weights_desc = conv.GetWeightsDescription(); return absl::make_unique<ConvPowerVR>(std::move(conv)); } } @@ -172,7 +172,7 @@ std::unique_ptr<GPUOperation> SelectConvolutionWithDynamicWeights( const Convolution2DAttributes& attr, const BHWC& weights_shape, const BHWC& dst_shape, const GpuInfo& gpu_info, const OperationDef& 
op_def, ModelHints hints, - ConvWeightsDescription* weights_desc) { + WeightsDescription* weights_desc) { if (gpu_info.IsAdreno()) { return SelectConvolutionDynamicWeightsAdreno(attr, weights_shape, dst_shape, gpu_info, op_def, hints, @@ -184,13 +184,13 @@ std::unique_ptr<GPUOperation> SelectConvolutionWithDynamicWeights( } else { ConvPowerVR conv = CreateConvPowerVRDynamicWeights( gpu_info, op_def, attr, weights_shape, &dst_shape); - *weights_desc = conv.GetConvWeightsDescription(); + *weights_desc = conv.GetWeightsDescription(); return absl::make_unique<ConvPowerVR>(std::move(conv)); } } std::unique_ptr<GPUOperation> SelectConverterToConvWeights( - const ConvWeightsDescription& weights_desc, const OperationDef& op_def, + const WeightsDescription& weights_desc, const OperationDef& op_def, ModelHints hints) { ConverterToConvWeights converter = ConverterToConvWeights(op_def, weights_desc); diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.h b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.h index b243b5c54ae..cef1e014217 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.h +++ b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.h @@ -18,12 +18,12 @@ limitations under the License. #include <memory> -#include "tensorflow/lite/delegates/gpu/cl/kernels/conv_common.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/model_hints.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" +#include "tensorflow/lite/delegates/gpu/common/task/weights_layout.h" namespace tflite { namespace gpu { @@ -40,10 +40,10 @@ std::unique_ptr<GPUOperation> SelectConvolutionForWinograd( std::unique_ptr<GPUOperation> SelectConvolutionWithDynamicWeights( const Convolution2DAttributes& attr, const BHWC& weights_shape, const BHWC& dst_shape, const GpuInfo& gpu_info, const OperationDef& op_def, - ModelHints hints, ConvWeightsDescription* weights_desc); + ModelHints hints, WeightsDescription* weights_desc); std::unique_ptr<GPUOperation> SelectConverterToConvWeights( - const ConvWeightsDescription& weights_desc, const OperationDef& op_def, + const WeightsDescription& weights_desc, const OperationDef& op_def, ModelHints hints); } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.cc index 6b3c0530c41..2aa56e2bb19 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.cc +++ b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.cc @@ -98,10 +98,10 @@ std::unique_ptr<GPUOperation> SelectConvolutionTransposed( std::unique_ptr<GPUOperation> SelectConvolutionTransposedWithDynamicWeights( const ConvolutionTransposedAttributes& attr, const GpuInfo& gpu_info, - const OperationDef& op_def, ConvWeightsDescription* weights_desc) { + const OperationDef& op_def, WeightsDescription* weights_desc) { ConvolutionTransposed conv = CreateConvolutionTransposedDynamicWeights(gpu_info, op_def, attr); - *weights_desc = conv.GetConvWeightsDescription(); + *weights_desc = conv.GetWeightsDescription(); return absl::make_unique<ConvolutionTransposed>(std::move(conv)); } diff --git 
a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.h b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.h index e00fe020015..4a2a6d9645f 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.h +++ b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.h @@ -18,10 +18,10 @@ limitations under the License. #include <memory> -#include "tensorflow/lite/delegates/gpu/cl/kernels/conv_common.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" +#include "tensorflow/lite/delegates/gpu/common/task/weights_layout.h" namespace tflite { namespace gpu { @@ -33,7 +33,7 @@ std::unique_ptr<GPUOperation> SelectConvolutionTransposed( std::unique_ptr<GPUOperation> SelectConvolutionTransposedWithDynamicWeights( const ConvolutionTransposedAttributes& attr, const GpuInfo& gpu_info, - const OperationDef& op_def, ConvWeightsDescription* weights_desc); + const OperationDef& op_def, WeightsDescription* weights_desc); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/default/BUILD b/tensorflow/lite/delegates/gpu/cl/selectors/default/BUILD index 9ab0993b078..71b3f3b68e5 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/default/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/selectors/default/BUILD @@ -7,12 +7,12 @@ cc_library( name = "default_selector", srcs = ["default_selector.cc"], deps = [ - "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation", "//tensorflow/lite/delegates/gpu/cl/selectors:subgraph", "//tensorflow/lite/delegates/gpu/common:model", "//tensorflow/lite/delegates/gpu/common:model_hints", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/default/default_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/default/default_selector.cc index 72eec316da6..a7d94fabf43 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/default/default_selector.cc +++ b/tensorflow/lite/delegates/gpu/cl/selectors/default/default_selector.cc @@ -16,12 +16,12 @@ limitations under the License. #include <memory> #include "absl/strings/str_cat.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h" #include "tensorflow/lite/delegates/gpu/common/model.h" #include "tensorflow/lite/delegates/gpu/common/model_hints.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/default_selector.h b/tensorflow/lite/delegates/gpu/cl/selectors/default_selector.h index f6cb9a33ada..1efa215e602 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/default_selector.h +++ b/tensorflow/lite/delegates/gpu/cl/selectors/default_selector.h @@ -18,12 +18,11 @@ limitations under the License. 
#include <memory> -#include "tensorflow/lite/delegates/gpu/cl/device_info.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h" #include "tensorflow/lite/delegates/gpu/common/model.h" #include "tensorflow/lite/delegates/gpu/common/model_hints.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.h b/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.h index 647bee97f3d..0c920984bc1 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.h +++ b/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.h @@ -18,9 +18,9 @@ limitations under the License. #include <memory> -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.h b/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.h index 5a2639f26f3..5b1563a9351 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.h +++ b/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.h @@ -18,9 +18,9 @@ limitations under the License. #include <memory> -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc index fd12390cfd4..6dc28b4d454 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc +++ b/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc @@ -34,28 +34,34 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" +#include "tensorflow/lite/delegates/gpu/common/winograd_util.h" namespace tflite { namespace gpu { namespace cl { namespace { -bool IsSuitableForWinograd4x4To6x6(const Convolution2DAttributes& attr, - const GpuInfo& gpu_info, - const BHWC& dst_shape) { +bool IsRecommendedForWinograd4x4To6x6(const Convolution2DAttributes& attr, + const GpuInfo& gpu_info, + const BHWC& dst_shape) { const int tiles_x = DivideRoundUp(dst_shape.w, 4); const int tiles_y = DivideRoundUp(dst_shape.h, 4); + const int total_tiles = tiles_x * tiles_y; const int src_depth = DivideRoundUp(attr.weights.shape.i, 4); const int dst_depth = DivideRoundUp(attr.weights.shape.o, 4); - const bool suitable_attributes = - attr.weights.shape.w == 3 && attr.weights.shape.h == 3 && - attr.dilations == HW(1, 1) && attr.strides == HW(1, 1); // Mali among other devices has smaller SIMD line size - const int min_depth = gpu_info.IsMali() ? 16 : 32; - const int min_hw = gpu_info.IsMali() ? 32 : 128; + int min_depth = gpu_info.IsMali() ? 
16 : 32; + const int min_tiles = gpu_info.IsMali() ? 32 : 128; + if (total_tiles >= min_tiles * 8) { + min_depth /= 4; + min_depth = std::max(min_depth, 8); + } else if (total_tiles >= min_tiles * 4) { + min_depth /= 2; + min_depth = std::max(min_depth, 8); + } const bool recommended_channels = dst_depth % 4 == 0 && src_depth >= min_depth && dst_depth >= min_depth; - const bool recommended_hw = tiles_x * tiles_y >= min_hw; - return suitable_attributes && recommended_channels && recommended_hw; + const bool recommended_hw = total_tiles >= min_tiles; + return recommended_channels && recommended_hw; } absl::Status WinogradFromNode(const GpuInfo& gpu_info, @@ -65,9 +71,12 @@ absl::Status WinogradFromNode(const GpuInfo& gpu_info, const BHWC& input_shape, const BHWC& output_shape, const Convolution2DAttributes& attr, GPUOperationsSubgraph* gpu_subgraph) { - if (!IsSuitableForWinograd4x4To6x6(attr, gpu_info, output_shape)) { + if (!IsSuitableForWinograd4x4To6x6(attr)) { return absl::UnimplementedError("No implementation for this case."); } + if (!IsRecommendedForWinograd4x4To6x6(attr, gpu_info, output_shape)) { + return absl::UnimplementedError("Not recommended for this case."); + } const int tiles_x = DivideRoundUp(output_shape.w, 4); const int tiles_y = DivideRoundUp(output_shape.h, 4); @@ -203,7 +212,7 @@ absl::Status GPUOperationFromNode(const GpuInfo& gpu_info, conv_op.output_ids = {static_cast<int>(outputs[0]->id)}; OperationDef conv_def = op_def; conv_def.src_tensors[1] = weights_desc; - ConvWeightsDescription conv_weights_desc; + WeightsDescription conv_weights_desc; conv_op.operation = SelectConvolutionWithDynamicWeights( attr, weights_shape, dst_shape, gpu_info, conv_def, hints, &conv_weights_desc); @@ -280,7 +289,7 @@ absl::Status GPUOperationFromNode(const GpuInfo& gpu_info, conv_op.output_ids = {static_cast<int>(outputs[0]->id)}; OperationDef conv_def = op_def; conv_def.src_tensors[1] = weights_desc; - ConvWeightsDescription conv_weights_desc; + WeightsDescription conv_weights_desc; conv_op.operation = SelectConvolutionWithDynamicWeights( attr, weights_shape, output_shape, gpu_info, conv_def, hints, &conv_weights_desc); @@ -329,7 +338,7 @@ absl::Status GPUOperationFromNode(const GpuInfo& gpu_info, conv_op.output_ids = {static_cast<int>(outputs[0]->id)}; OperationDef conv_def = op_def; conv_def.src_tensors[1] = weights_desc; - ConvWeightsDescription conv_weights_desc; + WeightsDescription conv_weights_desc; conv_op.operation = SelectConvolutionTransposedWithDynamicWeights( attr, gpu_info, conv_def, &conv_weights_desc); diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h b/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h index fa0c5cb1af1..b81bdaa0506 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h +++ b/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h @@ -18,11 +18,11 @@ limitations under the License. 
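Illustrative, simplified restatement (not part of the patch) of the new IsRecommendedForWinograd4x4To6x6 heuristic in operation_selector.cc above. The attribute and GpuInfo arguments are flattened into plain integers plus an is_mali flag purely for the worked example in main():

#include <algorithm>
#include <cassert>

int DivideRoundUp(int n, int d) { return (n + d - 1) / d; }

bool RecommendedForWinograd(bool is_mali, int dst_w, int dst_h, int src_ch,
                            int dst_ch) {
  const int total_tiles = DivideRoundUp(dst_w, 4) * DivideRoundUp(dst_h, 4);
  const int src_depth = DivideRoundUp(src_ch, 4);
  const int dst_depth = DivideRoundUp(dst_ch, 4);
  int min_depth = is_mali ? 16 : 32;        // Mali has a smaller SIMD line.
  const int min_tiles = is_mali ? 32 : 128;
  if (total_tiles >= min_tiles * 8) {
    min_depth = std::max(min_depth / 4, 8);  // Many tiles: relax depth demand.
  } else if (total_tiles >= min_tiles * 4) {
    min_depth = std::max(min_depth / 2, 8);
  }
  const bool recommended_channels =
      dst_depth % 4 == 0 && src_depth >= min_depth && dst_depth >= min_depth;
  return recommended_channels && total_tiles >= min_tiles;
}

int main() {
  // 128x128 output -> 32*32 = 1024 tiles >= 128*8, so min_depth drops to 8;
  // 128 input/output channels -> depth 32, satisfying both channel conditions.
  assert(RecommendedForWinograd(/*is_mali=*/false, 128, 128, 128, 128));
  // A small 16x16 output gives only 16 tiles, below min_tiles, so Winograd is
  // not recommended even though the channel counts are large enough.
  assert(!RecommendedForWinograd(/*is_mali=*/false, 16, 16, 128, 128));
  return 0;
}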
#include <memory> -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h" #include "tensorflow/lite/delegates/gpu/common/model.h" #include "tensorflow/lite/delegates/gpu/common/model_hints.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h index 971209f8f22..52c23102dd9 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h +++ b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h @@ -19,10 +19,10 @@ limitations under the License. #include <memory> #include "tensorflow/lite/delegates/gpu/cl/cl_device.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h b/tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h index 09ba31f87b0..aecd0a0a519 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h +++ b/tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h @@ -20,10 +20,10 @@ limitations under the License. #include <set> #include <vector> -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h" #include "tensorflow/lite/delegates/gpu/common/model.h" #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/subgraph.cc b/tensorflow/lite/delegates/gpu/cl/selectors/subgraph.cc index 3329bc5a83e..cd3c987ccaf 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/subgraph.cc +++ b/tensorflow/lite/delegates/gpu/cl/selectors/subgraph.cc @@ -17,8 +17,8 @@ limitations under the License. #include <memory> -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h b/tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h index f402358b2c5..f94e0c430a3 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h +++ b/tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h @@ -19,8 +19,8 @@ limitations under the License. 
#include <memory> #include <vector> -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/cl/serialization.cc b/tensorflow/lite/delegates/gpu/cl/serialization.cc index cc608177e1d..812782c60bc 100644 --- a/tensorflow/lite/delegates/gpu/cl/serialization.cc +++ b/tensorflow/lite/delegates/gpu/cl/serialization.cc @@ -19,13 +19,13 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/cl/gpu_object.h" #include "tensorflow/lite/delegates/gpu/cl/inference_context.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/cl/serialization_generated.h" #include "tensorflow/lite/delegates/gpu/common/model.h" #include "tensorflow/lite/delegates/gpu/common/precision.h" #include "tensorflow/lite/delegates/gpu/common/task/arguments.h" #include "tensorflow/lite/delegates/gpu/common/task/buffer_desc.h" #include "tensorflow/lite/delegates/gpu/common/task/gpu_object_desc.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h" #include "tensorflow/lite/delegates/gpu/common/task/tensor_linear_desc.h" #include "tensorflow/lite/delegates/gpu/common/task/texture2d_desc.h" @@ -187,10 +187,6 @@ Layout ToEnum(data::Layout type) { return Layout::UNKNOWN; } } -} // namespace - -namespace cl { -namespace { data::CalculationsPrecision ToFB(CalculationsPrecision type) { switch (type) { @@ -279,7 +275,6 @@ CompilerOptions ToEnum(data::CompilerOptions type) { } } // namespace -} // namespace cl flatbuffers::Offset<data::Int2> Encode( const int2& v, flatbuffers::FlatBufferBuilder* builder) { @@ -732,7 +727,6 @@ flatbuffers::Offset<data::Arguments> Encode( return arguments_builder.Finish(); } -namespace cl { flatbuffers::Offset<data::OperationDef> Encode( const OperationDef& def, flatbuffers::FlatBufferBuilder* builder) { std::vector<flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>> @@ -773,22 +767,6 @@ void Decode(const data::OperationDef* fb_def, OperationDef* def) { def->precision = ToEnum(fb_def->precision()); } -flatbuffers::Offset<data::TensorDescWithId> Encode( - const TensorDescriptor& desc, const ValueId& id, - flatbuffers::FlatBufferBuilder* builder) { - auto desc_fb = Encode(desc, builder); - data::TensorDescWithIdBuilder desc_builder(*builder); - desc_builder.add_desc(desc_fb); - desc_builder.add_id(id); - return desc_builder.Finish(); -} - -void Decode(const data::TensorDescWithId* fb_desc, TensorDescriptor* desc, - ValueId* id) { - Decode(fb_desc->desc(), desc); - *id = fb_desc->id(); -} - absl::Status Decode(const data::GPUOperation* fb_op, GPUOperation* op) { RETURN_IF_ERROR(Decode(fb_op->arguments(), &op->args_)); op->code_ = std::string(fb_op->code()->c_str(), fb_op->code()->size()); @@ -881,6 +859,24 @@ flatbuffers::Offset<data::GPUOperation> Encode( return op_builder.Finish(); } +namespace cl { + +flatbuffers::Offset<data::TensorDescWithId> Encode( + const TensorDescriptor& desc, const ValueId& id, + flatbuffers::FlatBufferBuilder* builder) { + auto desc_fb = Encode(desc, builder); + data::TensorDescWithIdBuilder desc_builder(*builder); + desc_builder.add_desc(desc_fb); + desc_builder.add_id(id); + return desc_builder.Finish(); +} + +void Decode(const data::TensorDescWithId* 
fb_desc, TensorDescriptor* desc, + ValueId* id) { + Decode(fb_desc->desc(), desc); + *id = fb_desc->id(); +} + flatbuffers::Offset<data::CLNode> Encode( const CLNode& node, flatbuffers::FlatBufferBuilder* builder) { auto op_fb = Encode(node.cl_operation.GetGpuOperation(), builder); diff --git a/tensorflow/lite/delegates/gpu/cl/serialization.fbs b/tensorflow/lite/delegates/gpu/cl/serialization.fbs index 9d5cf5ed783..67bd587162e 100644 --- a/tensorflow/lite/delegates/gpu/cl/serialization.fbs +++ b/tensorflow/lite/delegates/gpu/cl/serialization.fbs @@ -16,66 +16,13 @@ include "tensorflow/lite/delegates/gpu/common/task/serialization_base.fbs"; namespace tflite.gpu.cl.data; -enum CalculationsPrecision : byte { - F32 = 0, - F32_F16 = 1, - F16 = 2, -} - -enum TensorToGrid : byte { - CUSTOM = 0, - WB_TO_X_HD_TO_Y_S_TO_Z = 1, - WB_TO_X_HD_TO_Y_Z_IS_1 = 2, - WB_TO_X_H_TO_Y_D_TO_Z = 3, - B_TO_X_Y_IS_1_Z_IS_1 = 4, -} - -enum CompilerOptions : byte { - ADRENO_FULL_SIMD_LINE = 0, - ADRENO_MORE_WAVES = 1, - POWERVR_FP16 = 2, - CL_OPT_DISABLE = 3, - CL_2_0 = 4, - CL_3_0 = 5, -} - -table OperationDef { - precision:CalculationsPrecision; - src_tensors:[tflite.gpu.data.TensorDescriptor]; - dst_tensors:[tflite.gpu.data.TensorDescriptor]; -} - -table CompilerOption { - option:CompilerOptions; -} - -table GPUOperation { - arguments:tflite.gpu.data.Arguments; - code:string; - work_group_size:tflite.gpu.data.Int3; - compiler_options:[CompilerOption]; - tensor_to_grid:TensorToGrid; - elementwise:bool; - linkable:bool; - check_src_channels_size:bool; - definition:OperationDef; - grid_dimension:int32; - work_group_launch_order:tflite.gpu.data.Int3; - grid_size:tflite.gpu.data.Int3; - src_tensors_names:[string]; - dst_tensors_names:[string]; - work_groups_count:tflite.gpu.data.Int3; - linkable_count:int32; - elementwise_code:string; -} - table TensorDescWithId { desc:tflite.gpu.data.TensorDescriptor; id:int32; } table CLNode { - gpu_op:GPUOperation; + gpu_op:tflite.gpu.data.GPUOperation; input_ids:[int32]; output_ids:[int32]; name:string; @@ -91,7 +38,7 @@ table InferenceContext { flush_periodically:bool; flush_period:int32; need_manual_release:bool; - precision:CalculationsPrecision; + precision:tflite.gpu.data.CalculationsPrecision; storage_type:tflite.gpu.data.TensorStorageType; nodes:[CLNode]; tensors:[TensorDescWithId]; diff --git a/tensorflow/lite/delegates/gpu/cl/serialization_generated.h b/tensorflow/lite/delegates/gpu/cl/serialization_generated.h index a3bc04e12ca..d423b19e60a 100644 --- a/tensorflow/lite/delegates/gpu/cl/serialization_generated.h +++ b/tensorflow/lite/delegates/gpu/cl/serialization_generated.h @@ -27,15 +27,6 @@ namespace gpu { namespace cl { namespace data { -struct OperationDef; -struct OperationDefBuilder; - -struct CompilerOption; -struct CompilerOptionBuilder; - -struct GPUOperation; -struct GPUOperationBuilder; - struct TensorDescWithId; struct TensorDescWithIdBuilder; @@ -48,500 +39,6 @@ struct PairOfValueIdsBuilder; struct InferenceContext; struct InferenceContextBuilder; -enum class CalculationsPrecision : int8_t { - F32 = 0, - F32_F16 = 1, - F16 = 2, - MIN = F32, - MAX = F16 -}; - -inline const CalculationsPrecision (&EnumValuesCalculationsPrecision())[3] { - static const CalculationsPrecision values[] = { - CalculationsPrecision::F32, - CalculationsPrecision::F32_F16, - CalculationsPrecision::F16 - }; - return values; -} - -inline const char * const *EnumNamesCalculationsPrecision() { - static const char * const names[4] = { - "F32", - "F32_F16", - "F16", - nullptr - }; - return 
names; -} - -inline const char *EnumNameCalculationsPrecision(CalculationsPrecision e) { - if (flatbuffers::IsOutRange(e, CalculationsPrecision::F32, CalculationsPrecision::F16)) return ""; - const size_t index = static_cast<size_t>(e); - return EnumNamesCalculationsPrecision()[index]; -} - -enum class TensorToGrid : int8_t { - CUSTOM = 0, - WB_TO_X_HD_TO_Y_S_TO_Z = 1, - WB_TO_X_HD_TO_Y_Z_IS_1 = 2, - WB_TO_X_H_TO_Y_D_TO_Z = 3, - B_TO_X_Y_IS_1_Z_IS_1 = 4, - MIN = CUSTOM, - MAX = B_TO_X_Y_IS_1_Z_IS_1 -}; - -inline const TensorToGrid (&EnumValuesTensorToGrid())[5] { - static const TensorToGrid values[] = { - TensorToGrid::CUSTOM, - TensorToGrid::WB_TO_X_HD_TO_Y_S_TO_Z, - TensorToGrid::WB_TO_X_HD_TO_Y_Z_IS_1, - TensorToGrid::WB_TO_X_H_TO_Y_D_TO_Z, - TensorToGrid::B_TO_X_Y_IS_1_Z_IS_1 - }; - return values; -} - -inline const char * const *EnumNamesTensorToGrid() { - static const char * const names[6] = { - "CUSTOM", - "WB_TO_X_HD_TO_Y_S_TO_Z", - "WB_TO_X_HD_TO_Y_Z_IS_1", - "WB_TO_X_H_TO_Y_D_TO_Z", - "B_TO_X_Y_IS_1_Z_IS_1", - nullptr - }; - return names; -} - -inline const char *EnumNameTensorToGrid(TensorToGrid e) { - if (flatbuffers::IsOutRange(e, TensorToGrid::CUSTOM, TensorToGrid::B_TO_X_Y_IS_1_Z_IS_1)) return ""; - const size_t index = static_cast<size_t>(e); - return EnumNamesTensorToGrid()[index]; -} - -enum class CompilerOptions : int8_t { - ADRENO_FULL_SIMD_LINE = 0, - ADRENO_MORE_WAVES = 1, - POWERVR_FP16 = 2, - CL_OPT_DISABLE = 3, - CL_2_0 = 4, - CL_3_0 = 5, - MIN = ADRENO_FULL_SIMD_LINE, - MAX = CL_3_0 -}; - -inline const CompilerOptions (&EnumValuesCompilerOptions())[6] { - static const CompilerOptions values[] = { - CompilerOptions::ADRENO_FULL_SIMD_LINE, - CompilerOptions::ADRENO_MORE_WAVES, - CompilerOptions::POWERVR_FP16, - CompilerOptions::CL_OPT_DISABLE, - CompilerOptions::CL_2_0, - CompilerOptions::CL_3_0 - }; - return values; -} - -inline const char * const *EnumNamesCompilerOptions() { - static const char * const names[7] = { - "ADRENO_FULL_SIMD_LINE", - "ADRENO_MORE_WAVES", - "POWERVR_FP16", - "CL_OPT_DISABLE", - "CL_2_0", - "CL_3_0", - nullptr - }; - return names; -} - -inline const char *EnumNameCompilerOptions(CompilerOptions e) { - if (flatbuffers::IsOutRange(e, CompilerOptions::ADRENO_FULL_SIMD_LINE, CompilerOptions::CL_3_0)) return ""; - const size_t index = static_cast<size_t>(e); - return EnumNamesCompilerOptions()[index]; -} - -struct OperationDef FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef OperationDefBuilder Builder; - enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { - VT_PRECISION = 4, - VT_SRC_TENSORS = 6, - VT_DST_TENSORS = 8 - }; - tflite::gpu::cl::data::CalculationsPrecision precision() const { - return static_cast<tflite::gpu::cl::data::CalculationsPrecision>(GetField<int8_t>(VT_PRECISION, 0)); - } - const flatbuffers::Vector<flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>> *src_tensors() const { - return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>> *>(VT_SRC_TENSORS); - } - const flatbuffers::Vector<flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>> *dst_tensors() const { - return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>> *>(VT_DST_TENSORS); - } - bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyField<int8_t>(verifier, VT_PRECISION) && - VerifyOffset(verifier, VT_SRC_TENSORS) && - verifier.VerifyVector(src_tensors()) && - 
verifier.VerifyVectorOfTables(src_tensors()) && - VerifyOffset(verifier, VT_DST_TENSORS) && - verifier.VerifyVector(dst_tensors()) && - verifier.VerifyVectorOfTables(dst_tensors()) && - verifier.EndTable(); - } -}; - -struct OperationDefBuilder { - typedef OperationDef Table; - flatbuffers::FlatBufferBuilder &fbb_; - flatbuffers::uoffset_t start_; - void add_precision(tflite::gpu::cl::data::CalculationsPrecision precision) { - fbb_.AddElement<int8_t>(OperationDef::VT_PRECISION, static_cast<int8_t>(precision), 0); - } - void add_src_tensors(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>>> src_tensors) { - fbb_.AddOffset(OperationDef::VT_SRC_TENSORS, src_tensors); - } - void add_dst_tensors(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>>> dst_tensors) { - fbb_.AddOffset(OperationDef::VT_DST_TENSORS, dst_tensors); - } - explicit OperationDefBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { - start_ = fbb_.StartTable(); - } - flatbuffers::Offset<OperationDef> Finish() { - const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset<OperationDef>(end); - return o; - } -}; - -inline flatbuffers::Offset<OperationDef> CreateOperationDef( - flatbuffers::FlatBufferBuilder &_fbb, - tflite::gpu::cl::data::CalculationsPrecision precision = tflite::gpu::cl::data::CalculationsPrecision::F32, - flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>>> src_tensors = 0, - flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>>> dst_tensors = 0) { - OperationDefBuilder builder_(_fbb); - builder_.add_dst_tensors(dst_tensors); - builder_.add_src_tensors(src_tensors); - builder_.add_precision(precision); - return builder_.Finish(); -} - -inline flatbuffers::Offset<OperationDef> CreateOperationDefDirect( - flatbuffers::FlatBufferBuilder &_fbb, - tflite::gpu::cl::data::CalculationsPrecision precision = tflite::gpu::cl::data::CalculationsPrecision::F32, - const std::vector<flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>> *src_tensors = nullptr, - const std::vector<flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>> *dst_tensors = nullptr) { - auto src_tensors__ = src_tensors ? _fbb.CreateVector<flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>>(*src_tensors) : 0; - auto dst_tensors__ = dst_tensors ? 
_fbb.CreateVector<flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>>(*dst_tensors) : 0; - return tflite::gpu::cl::data::CreateOperationDef( - _fbb, - precision, - src_tensors__, - dst_tensors__); -} - -struct CompilerOption FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef CompilerOptionBuilder Builder; - enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { - VT_OPTION = 4 - }; - tflite::gpu::cl::data::CompilerOptions option() const { - return static_cast<tflite::gpu::cl::data::CompilerOptions>(GetField<int8_t>(VT_OPTION, 0)); - } - bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyField<int8_t>(verifier, VT_OPTION) && - verifier.EndTable(); - } -}; - -struct CompilerOptionBuilder { - typedef CompilerOption Table; - flatbuffers::FlatBufferBuilder &fbb_; - flatbuffers::uoffset_t start_; - void add_option(tflite::gpu::cl::data::CompilerOptions option) { - fbb_.AddElement<int8_t>(CompilerOption::VT_OPTION, static_cast<int8_t>(option), 0); - } - explicit CompilerOptionBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { - start_ = fbb_.StartTable(); - } - flatbuffers::Offset<CompilerOption> Finish() { - const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset<CompilerOption>(end); - return o; - } -}; - -inline flatbuffers::Offset<CompilerOption> CreateCompilerOption( - flatbuffers::FlatBufferBuilder &_fbb, - tflite::gpu::cl::data::CompilerOptions option = tflite::gpu::cl::data::CompilerOptions::ADRENO_FULL_SIMD_LINE) { - CompilerOptionBuilder builder_(_fbb); - builder_.add_option(option); - return builder_.Finish(); -} - -struct GPUOperation FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef GPUOperationBuilder Builder; - enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { - VT_ARGUMENTS = 4, - VT_CODE = 6, - VT_WORK_GROUP_SIZE = 8, - VT_COMPILER_OPTIONS = 10, - VT_TENSOR_TO_GRID = 12, - VT_ELEMENTWISE = 14, - VT_LINKABLE = 16, - VT_CHECK_SRC_CHANNELS_SIZE = 18, - VT_DEFINITION = 20, - VT_GRID_DIMENSION = 22, - VT_WORK_GROUP_LAUNCH_ORDER = 24, - VT_GRID_SIZE = 26, - VT_SRC_TENSORS_NAMES = 28, - VT_DST_TENSORS_NAMES = 30, - VT_WORK_GROUPS_COUNT = 32, - VT_LINKABLE_COUNT = 34, - VT_ELEMENTWISE_CODE = 36 - }; - const tflite::gpu::data::Arguments *arguments() const { - return GetPointer<const tflite::gpu::data::Arguments *>(VT_ARGUMENTS); - } - const flatbuffers::String *code() const { - return GetPointer<const flatbuffers::String *>(VT_CODE); - } - const tflite::gpu::data::Int3 *work_group_size() const { - return GetPointer<const tflite::gpu::data::Int3 *>(VT_WORK_GROUP_SIZE); - } - const flatbuffers::Vector<flatbuffers::Offset<tflite::gpu::cl::data::CompilerOption>> *compiler_options() const { - return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<tflite::gpu::cl::data::CompilerOption>> *>(VT_COMPILER_OPTIONS); - } - tflite::gpu::cl::data::TensorToGrid tensor_to_grid() const { - return static_cast<tflite::gpu::cl::data::TensorToGrid>(GetField<int8_t>(VT_TENSOR_TO_GRID, 0)); - } - bool elementwise() const { - return GetField<uint8_t>(VT_ELEMENTWISE, 0) != 0; - } - bool linkable() const { - return GetField<uint8_t>(VT_LINKABLE, 0) != 0; - } - bool check_src_channels_size() const { - return GetField<uint8_t>(VT_CHECK_SRC_CHANNELS_SIZE, 0) != 0; - } - const tflite::gpu::cl::data::OperationDef *definition() const { - return GetPointer<const tflite::gpu::cl::data::OperationDef *>(VT_DEFINITION); - } - int32_t grid_dimension() const { - return 
GetField<int32_t>(VT_GRID_DIMENSION, 0); - } - const tflite::gpu::data::Int3 *work_group_launch_order() const { - return GetPointer<const tflite::gpu::data::Int3 *>(VT_WORK_GROUP_LAUNCH_ORDER); - } - const tflite::gpu::data::Int3 *grid_size() const { - return GetPointer<const tflite::gpu::data::Int3 *>(VT_GRID_SIZE); - } - const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *src_tensors_names() const { - return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_SRC_TENSORS_NAMES); - } - const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *dst_tensors_names() const { - return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_DST_TENSORS_NAMES); - } - const tflite::gpu::data::Int3 *work_groups_count() const { - return GetPointer<const tflite::gpu::data::Int3 *>(VT_WORK_GROUPS_COUNT); - } - int32_t linkable_count() const { - return GetField<int32_t>(VT_LINKABLE_COUNT, 0); - } - const flatbuffers::String *elementwise_code() const { - return GetPointer<const flatbuffers::String *>(VT_ELEMENTWISE_CODE); - } - bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_ARGUMENTS) && - verifier.VerifyTable(arguments()) && - VerifyOffset(verifier, VT_CODE) && - verifier.VerifyString(code()) && - VerifyOffset(verifier, VT_WORK_GROUP_SIZE) && - verifier.VerifyTable(work_group_size()) && - VerifyOffset(verifier, VT_COMPILER_OPTIONS) && - verifier.VerifyVector(compiler_options()) && - verifier.VerifyVectorOfTables(compiler_options()) && - VerifyField<int8_t>(verifier, VT_TENSOR_TO_GRID) && - VerifyField<uint8_t>(verifier, VT_ELEMENTWISE) && - VerifyField<uint8_t>(verifier, VT_LINKABLE) && - VerifyField<uint8_t>(verifier, VT_CHECK_SRC_CHANNELS_SIZE) && - VerifyOffset(verifier, VT_DEFINITION) && - verifier.VerifyTable(definition()) && - VerifyField<int32_t>(verifier, VT_GRID_DIMENSION) && - VerifyOffset(verifier, VT_WORK_GROUP_LAUNCH_ORDER) && - verifier.VerifyTable(work_group_launch_order()) && - VerifyOffset(verifier, VT_GRID_SIZE) && - verifier.VerifyTable(grid_size()) && - VerifyOffset(verifier, VT_SRC_TENSORS_NAMES) && - verifier.VerifyVector(src_tensors_names()) && - verifier.VerifyVectorOfStrings(src_tensors_names()) && - VerifyOffset(verifier, VT_DST_TENSORS_NAMES) && - verifier.VerifyVector(dst_tensors_names()) && - verifier.VerifyVectorOfStrings(dst_tensors_names()) && - VerifyOffset(verifier, VT_WORK_GROUPS_COUNT) && - verifier.VerifyTable(work_groups_count()) && - VerifyField<int32_t>(verifier, VT_LINKABLE_COUNT) && - VerifyOffset(verifier, VT_ELEMENTWISE_CODE) && - verifier.VerifyString(elementwise_code()) && - verifier.EndTable(); - } -}; - -struct GPUOperationBuilder { - typedef GPUOperation Table; - flatbuffers::FlatBufferBuilder &fbb_; - flatbuffers::uoffset_t start_; - void add_arguments(flatbuffers::Offset<tflite::gpu::data::Arguments> arguments) { - fbb_.AddOffset(GPUOperation::VT_ARGUMENTS, arguments); - } - void add_code(flatbuffers::Offset<flatbuffers::String> code) { - fbb_.AddOffset(GPUOperation::VT_CODE, code); - } - void add_work_group_size(flatbuffers::Offset<tflite::gpu::data::Int3> work_group_size) { - fbb_.AddOffset(GPUOperation::VT_WORK_GROUP_SIZE, work_group_size); - } - void add_compiler_options(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<tflite::gpu::cl::data::CompilerOption>>> compiler_options) { - fbb_.AddOffset(GPUOperation::VT_COMPILER_OPTIONS, compiler_options); - } - void 
add_tensor_to_grid(tflite::gpu::cl::data::TensorToGrid tensor_to_grid) { - fbb_.AddElement<int8_t>(GPUOperation::VT_TENSOR_TO_GRID, static_cast<int8_t>(tensor_to_grid), 0); - } - void add_elementwise(bool elementwise) { - fbb_.AddElement<uint8_t>(GPUOperation::VT_ELEMENTWISE, static_cast<uint8_t>(elementwise), 0); - } - void add_linkable(bool linkable) { - fbb_.AddElement<uint8_t>(GPUOperation::VT_LINKABLE, static_cast<uint8_t>(linkable), 0); - } - void add_check_src_channels_size(bool check_src_channels_size) { - fbb_.AddElement<uint8_t>(GPUOperation::VT_CHECK_SRC_CHANNELS_SIZE, static_cast<uint8_t>(check_src_channels_size), 0); - } - void add_definition(flatbuffers::Offset<tflite::gpu::cl::data::OperationDef> definition) { - fbb_.AddOffset(GPUOperation::VT_DEFINITION, definition); - } - void add_grid_dimension(int32_t grid_dimension) { - fbb_.AddElement<int32_t>(GPUOperation::VT_GRID_DIMENSION, grid_dimension, 0); - } - void add_work_group_launch_order(flatbuffers::Offset<tflite::gpu::data::Int3> work_group_launch_order) { - fbb_.AddOffset(GPUOperation::VT_WORK_GROUP_LAUNCH_ORDER, work_group_launch_order); - } - void add_grid_size(flatbuffers::Offset<tflite::gpu::data::Int3> grid_size) { - fbb_.AddOffset(GPUOperation::VT_GRID_SIZE, grid_size); - } - void add_src_tensors_names(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> src_tensors_names) { - fbb_.AddOffset(GPUOperation::VT_SRC_TENSORS_NAMES, src_tensors_names); - } - void add_dst_tensors_names(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> dst_tensors_names) { - fbb_.AddOffset(GPUOperation::VT_DST_TENSORS_NAMES, dst_tensors_names); - } - void add_work_groups_count(flatbuffers::Offset<tflite::gpu::data::Int3> work_groups_count) { - fbb_.AddOffset(GPUOperation::VT_WORK_GROUPS_COUNT, work_groups_count); - } - void add_linkable_count(int32_t linkable_count) { - fbb_.AddElement<int32_t>(GPUOperation::VT_LINKABLE_COUNT, linkable_count, 0); - } - void add_elementwise_code(flatbuffers::Offset<flatbuffers::String> elementwise_code) { - fbb_.AddOffset(GPUOperation::VT_ELEMENTWISE_CODE, elementwise_code); - } - explicit GPUOperationBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { - start_ = fbb_.StartTable(); - } - flatbuffers::Offset<GPUOperation> Finish() { - const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset<GPUOperation>(end); - return o; - } -}; - -inline flatbuffers::Offset<GPUOperation> CreateGPUOperation( - flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset<tflite::gpu::data::Arguments> arguments = 0, - flatbuffers::Offset<flatbuffers::String> code = 0, - flatbuffers::Offset<tflite::gpu::data::Int3> work_group_size = 0, - flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<tflite::gpu::cl::data::CompilerOption>>> compiler_options = 0, - tflite::gpu::cl::data::TensorToGrid tensor_to_grid = tflite::gpu::cl::data::TensorToGrid::CUSTOM, - bool elementwise = false, - bool linkable = false, - bool check_src_channels_size = false, - flatbuffers::Offset<tflite::gpu::cl::data::OperationDef> definition = 0, - int32_t grid_dimension = 0, - flatbuffers::Offset<tflite::gpu::data::Int3> work_group_launch_order = 0, - flatbuffers::Offset<tflite::gpu::data::Int3> grid_size = 0, - flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> src_tensors_names = 0, - flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> dst_tensors_names = 0, - 
flatbuffers::Offset<tflite::gpu::data::Int3> work_groups_count = 0, - int32_t linkable_count = 0, - flatbuffers::Offset<flatbuffers::String> elementwise_code = 0) { - GPUOperationBuilder builder_(_fbb); - builder_.add_elementwise_code(elementwise_code); - builder_.add_linkable_count(linkable_count); - builder_.add_work_groups_count(work_groups_count); - builder_.add_dst_tensors_names(dst_tensors_names); - builder_.add_src_tensors_names(src_tensors_names); - builder_.add_grid_size(grid_size); - builder_.add_work_group_launch_order(work_group_launch_order); - builder_.add_grid_dimension(grid_dimension); - builder_.add_definition(definition); - builder_.add_compiler_options(compiler_options); - builder_.add_work_group_size(work_group_size); - builder_.add_code(code); - builder_.add_arguments(arguments); - builder_.add_check_src_channels_size(check_src_channels_size); - builder_.add_linkable(linkable); - builder_.add_elementwise(elementwise); - builder_.add_tensor_to_grid(tensor_to_grid); - return builder_.Finish(); -} - -inline flatbuffers::Offset<GPUOperation> CreateGPUOperationDirect( - flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset<tflite::gpu::data::Arguments> arguments = 0, - const char *code = nullptr, - flatbuffers::Offset<tflite::gpu::data::Int3> work_group_size = 0, - const std::vector<flatbuffers::Offset<tflite::gpu::cl::data::CompilerOption>> *compiler_options = nullptr, - tflite::gpu::cl::data::TensorToGrid tensor_to_grid = tflite::gpu::cl::data::TensorToGrid::CUSTOM, - bool elementwise = false, - bool linkable = false, - bool check_src_channels_size = false, - flatbuffers::Offset<tflite::gpu::cl::data::OperationDef> definition = 0, - int32_t grid_dimension = 0, - flatbuffers::Offset<tflite::gpu::data::Int3> work_group_launch_order = 0, - flatbuffers::Offset<tflite::gpu::data::Int3> grid_size = 0, - const std::vector<flatbuffers::Offset<flatbuffers::String>> *src_tensors_names = nullptr, - const std::vector<flatbuffers::Offset<flatbuffers::String>> *dst_tensors_names = nullptr, - flatbuffers::Offset<tflite::gpu::data::Int3> work_groups_count = 0, - int32_t linkable_count = 0, - const char *elementwise_code = nullptr) { - auto code__ = code ? _fbb.CreateString(code) : 0; - auto compiler_options__ = compiler_options ? _fbb.CreateVector<flatbuffers::Offset<tflite::gpu::cl::data::CompilerOption>>(*compiler_options) : 0; - auto src_tensors_names__ = src_tensors_names ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*src_tensors_names) : 0; - auto dst_tensors_names__ = dst_tensors_names ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*dst_tensors_names) : 0; - auto elementwise_code__ = elementwise_code ? 
_fbb.CreateString(elementwise_code) : 0; - return tflite::gpu::cl::data::CreateGPUOperation( - _fbb, - arguments, - code__, - work_group_size, - compiler_options__, - tensor_to_grid, - elementwise, - linkable, - check_src_channels_size, - definition, - grid_dimension, - work_group_launch_order, - grid_size, - src_tensors_names__, - dst_tensors_names__, - work_groups_count, - linkable_count, - elementwise_code__); -} - struct TensorDescWithId FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef TensorDescWithIdBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { @@ -602,8 +99,8 @@ struct CLNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_OUTPUT_IDS = 8, VT_NAME = 10 }; - const tflite::gpu::cl::data::GPUOperation *gpu_op() const { - return GetPointer<const tflite::gpu::cl::data::GPUOperation *>(VT_GPU_OP); + const tflite::gpu::data::GPUOperation *gpu_op() const { + return GetPointer<const tflite::gpu::data::GPUOperation *>(VT_GPU_OP); } const flatbuffers::Vector<int32_t> *input_ids() const { return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_INPUT_IDS); @@ -632,7 +129,7 @@ struct CLNodeBuilder { typedef CLNode Table; flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_gpu_op(flatbuffers::Offset<tflite::gpu::cl::data::GPUOperation> gpu_op) { + void add_gpu_op(flatbuffers::Offset<tflite::gpu::data::GPUOperation> gpu_op) { fbb_.AddOffset(CLNode::VT_GPU_OP, gpu_op); } void add_input_ids(flatbuffers::Offset<flatbuffers::Vector<int32_t>> input_ids) { @@ -657,7 +154,7 @@ struct CLNodeBuilder { inline flatbuffers::Offset<CLNode> CreateCLNode( flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset<tflite::gpu::cl::data::GPUOperation> gpu_op = 0, + flatbuffers::Offset<tflite::gpu::data::GPUOperation> gpu_op = 0, flatbuffers::Offset<flatbuffers::Vector<int32_t>> input_ids = 0, flatbuffers::Offset<flatbuffers::Vector<int32_t>> output_ids = 0, flatbuffers::Offset<flatbuffers::String> name = 0) { @@ -671,7 +168,7 @@ inline flatbuffers::Offset<CLNode> CreateCLNode( inline flatbuffers::Offset<CLNode> CreateCLNodeDirect( flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset<tflite::gpu::cl::data::GPUOperation> gpu_op = 0, + flatbuffers::Offset<tflite::gpu::data::GPUOperation> gpu_op = 0, const std::vector<int32_t> *input_ids = nullptr, const std::vector<int32_t> *output_ids = nullptr, const char *name = nullptr) { @@ -767,8 +264,9 @@ struct InferenceContext FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { bool need_manual_release() const { return GetField<uint8_t>(VT_NEED_MANUAL_RELEASE, 0) != 0; } - tflite::gpu::cl::data::CalculationsPrecision precision() const { - return static_cast<tflite::gpu::cl::data::CalculationsPrecision>(GetField<int8_t>(VT_PRECISION, 0)); + tflite::gpu::data::CalculationsPrecision precision() const { + return static_cast<tflite::gpu::data::CalculationsPrecision>( + GetField<int8_t>(VT_PRECISION, 0)); } tflite::gpu::data::TensorStorageType storage_type() const { return static_cast<tflite::gpu::data::TensorStorageType>(GetField<int8_t>(VT_STORAGE_TYPE, 0)); @@ -847,7 +345,7 @@ struct InferenceContextBuilder { void add_need_manual_release(bool need_manual_release) { fbb_.AddElement<uint8_t>(InferenceContext::VT_NEED_MANUAL_RELEASE, static_cast<uint8_t>(need_manual_release), 0); } - void add_precision(tflite::gpu::cl::data::CalculationsPrecision precision) { + void add_precision(tflite::gpu::data::CalculationsPrecision precision) { 
fbb_.AddElement<int8_t>(InferenceContext::VT_PRECISION, static_cast<int8_t>(precision), 0); } void add_storage_type(tflite::gpu::data::TensorStorageType storage_type) { @@ -895,8 +393,8 @@ inline flatbuffers::Offset<InferenceContext> CreateInferenceContext( flatbuffers::FlatBufferBuilder &_fbb, bool need_flush = false, bool flush_periodically = false, int32_t flush_period = 0, bool need_manual_release = false, - tflite::gpu::cl::data::CalculationsPrecision precision = - tflite::gpu::cl::data::CalculationsPrecision::F32, + tflite::gpu::data::CalculationsPrecision precision = + tflite::gpu::data::CalculationsPrecision::F32, tflite::gpu::data::TensorStorageType storage_type = tflite::gpu::data::TensorStorageType::UNKNOWN, flatbuffers::Offset< @@ -937,8 +435,8 @@ inline flatbuffers::Offset<InferenceContext> CreateInferenceContextDirect( flatbuffers::FlatBufferBuilder &_fbb, bool need_flush = false, bool flush_periodically = false, int32_t flush_period = 0, bool need_manual_release = false, - tflite::gpu::cl::data::CalculationsPrecision precision = - tflite::gpu::cl::data::CalculationsPrecision::F32, + tflite::gpu::data::CalculationsPrecision precision = + tflite::gpu::data::CalculationsPrecision::F32, tflite::gpu::data::TensorStorageType storage_type = tflite::gpu::data::TensorStorageType::UNKNOWN, const std::vector<flatbuffers::Offset<tflite::gpu::cl::data::CLNode>> diff --git a/tensorflow/lite/delegates/gpu/cl/storage_type_util.cc b/tensorflow/lite/delegates/gpu/cl/storage_type_util.cc index d579368926e..7ba81d138e2 100644 --- a/tensorflow/lite/delegates/gpu/cl/storage_type_util.cc +++ b/tensorflow/lite/delegates/gpu/cl/storage_type_util.cc @@ -33,37 +33,38 @@ bool CanCreateTensorWithShape(const GpuInfo& gpu_info, const BHWDC& shape, 4 * (descriptor.data_type == DataType::FLOAT32 ? 4 : 2); const int buffer_size = shape.b * shape.w * shape.h * shape.d * slices * flt4_size; - return buffer_size <= gpu_info.buffer_max_size; + return buffer_size <= gpu_info.GetMaxBufferSize(); } case TensorStorageType::IMAGE_BUFFER: return shape.b * shape.w * shape.h * shape.d * slices <= - gpu_info.image_buffer_max_size; + gpu_info.GetMaxImageBufferWidth(); case TensorStorageType::TEXTURE_3D: - if (gpu_info.cl_version < OpenCLVersion::CL_1_2 && slices == 1) { + if (gpu_info.opencl_info.cl_version < OpenClVersion::kCl1_2 && + slices == 1) { // clCreateImage3D (that used in CL 1.0/1.1) can not create image with // depth = 1 by specification; return false; } - return shape.w * shape.b <= gpu_info.image3d_max_width && - shape.h <= gpu_info.image3d_max_height && - slices * shape.d <= gpu_info.image3d_max_depth; + return shape.w * shape.b <= gpu_info.GetMaxImage3DWidth() && + shape.h <= gpu_info.GetMaxImage3DHeight() && + slices * shape.d <= gpu_info.GetMaxImage3DDepth(); case TensorStorageType::TEXTURE_ARRAY: // Bug on some Adreno. 
b/131099086 if (slices == 1 && gpu_info.IsAdreno() && !gpu_info.adreno_info.support_one_layer_texture_array) { return false; } - return shape.w * shape.b <= gpu_info.image2d_max_width && - shape.h <= gpu_info.image2d_max_height && - slices * shape.d <= gpu_info.image_array_max_layers; + return shape.w * shape.b <= gpu_info.GetMaxImage2DWidth() && + shape.h <= gpu_info.GetMaxImage2DHeight() && + slices * shape.d <= gpu_info.GetMaxImage2DArrayLayers(); case TensorStorageType::TEXTURE_2D: - return shape.w * shape.b * shape.d <= gpu_info.image2d_max_width && - shape.h * slices <= gpu_info.image2d_max_height; + return shape.w * shape.b * shape.d <= gpu_info.GetMaxImage2DWidth() && + shape.h * slices <= gpu_info.GetMaxImage2DHeight(); case TensorStorageType::SINGLE_TEXTURE_2D: return shape.c <= 4 && gpu_info.SupportsFloatImage2D(descriptor.data_type, shape.c) && - shape.w * shape.b * shape.d <= gpu_info.image2d_max_width && - shape.h <= gpu_info.image2d_max_height; + shape.w * shape.b * shape.d <= gpu_info.GetMaxImage2DWidth() && + shape.h <= gpu_info.GetMaxImage2DHeight(); default: return false; } diff --git a/tensorflow/lite/delegates/gpu/cl/storage_type_util.h b/tensorflow/lite/delegates/gpu/cl/storage_type_util.h index f30219156b4..b849d7a1087 100644 --- a/tensorflow/lite/delegates/gpu/cl/storage_type_util.h +++ b/tensorflow/lite/delegates/gpu/cl/storage_type_util.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_STORAGE_TYPE_UTIL_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_CL_STORAGE_TYPE_UTIL_H_ -#include "tensorflow/lite/delegates/gpu/cl/device_info.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/gpu_info.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h" diff --git a/tensorflow/lite/delegates/gpu/common/BUILD b/tensorflow/lite/delegates/gpu/common/BUILD index fe67d396539..ef43c76310b 100644 --- a/tensorflow/lite/delegates/gpu/common/BUILD +++ b/tensorflow/lite/delegates/gpu/common/BUILD @@ -40,6 +40,7 @@ cc_library( srcs = ["gpu_info.cc"], hdrs = ["gpu_info.h"], deps = [ + ":data_type", "@com_google_absl//absl/strings", ], ) @@ -386,11 +387,21 @@ cc_library( hdrs = ["winograd_util.h"], deps = [ ":data_type", + ":operations", ":shape", ":tensor", ], ) +cc_test( + name = "winograd_util_test", + srcs = ["winograd_util_test.cc"], + deps = [ + ":winograd_util", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "workgroup_selection", srcs = ["workgroup_selection.cc"], diff --git a/tensorflow/lite/delegates/gpu/common/gpu_info.cc b/tensorflow/lite/delegates/gpu/common/gpu_info.cc index 14f79df7801..2204f4a6448 100644 --- a/tensorflow/lite/delegates/gpu/common/gpu_info.cc +++ b/tensorflow/lite/delegates/gpu/common/gpu_info.cc @@ -358,6 +358,27 @@ void GetGpuInfoFromDeviceDescription(const std::string& gpu_description, } } +std::string OpenClVersionToString(OpenClVersion version) { + switch (version) { + case OpenClVersion::kCl1_0: + return "1.0"; + case OpenClVersion::kCl1_1: + return "1.1"; + case OpenClVersion::kCl1_2: + return "1.2"; + case OpenClVersion::kCl2_0: + return "2.0"; + case OpenClVersion::kCl2_1: + return "2.1"; + case OpenClVersion::kCl2_2: + return "2.2"; + case OpenClVersion::kCl3_0: + return "3.0"; + default: + return "Unknown OpenCL version"; + } +} + bool GpuInfo::IsAdreno() const { return vendor == GpuVendor::kQualcomm; } bool GpuInfo::IsApple() const { return vendor == 
GpuVendor::kApple; } @@ -373,11 +394,45 @@ bool GpuInfo::IsAMD() const { return vendor == GpuVendor::kAMD; } bool GpuInfo::IsIntel() const { return vendor == GpuVendor::kIntel; } bool GpuInfo::IsRoundToNearestSupported() const { + if (IsApiOpenCl()) { + return opencl_info.supports_fp16_rtn || opencl_info.supports_fp32_rtn; + } if (IsApple()) { return apple_info.IsRoundToNearestSupported(); - } else { - return true; } + return true; +} + +bool GpuInfo::SupportsFP16() const { + if (IsApiOpenCl()) { + return opencl_info.supports_fp16; + } + return true; +} + +bool GpuInfo::SupportsTextureArray() const { + if (IsApiOpenCl()) { + return opencl_info.cl_version >= OpenClVersion::kCl1_2; + } + return true; +} + +bool GpuInfo::SupportsImageBuffer() const { + if (IsApiOpenCl()) { + return opencl_info.cl_version >= OpenClVersion::kCl1_2; + } + return true; +} + +bool GpuInfo::SupportsImage3D() const { + if (IsApiOpenCl()) { + if (IsMali() && mali_info.IsMidgard()) { + // On Mali T880 read_imageh doesn't compile with image3d_t + return false; + } + return opencl_info.supports_image3d_writes; + } + return true; } bool GpuInfo::IsWaveSizeEqualTo32() const { @@ -385,36 +440,212 @@ bool GpuInfo::IsWaveSizeEqualTo32() const { supported_subgroup_sizes[0] == 32; } +bool GpuInfo::SupportsExtension(const std::string& extension) const { + const std::vector<std::string>* extensions = nullptr; + if (IsApiOpenGl()) { + extensions = &opengl_info.extensions; + } else if (IsApiVulkan()) { + extensions = &vulkan_info.extensions; + } else if (IsApiOpenCl()) { + extensions = &opencl_info.extensions; + } + if (!extensions) { + return false; + } + for (const auto& ext : *extensions) { + if (ext == extension) { + return true; + } + } + return false; +} + +bool GpuInfo::SupportsSubGroupWithSize(int sub_group_size) const { + for (auto subgroup_size : supported_subgroup_sizes) { + if (sub_group_size == subgroup_size) { + return true; + } + } + return false; +} + +bool GpuInfo::SupportsFloatImage2D(DataType data_type, int channels) const { + if (IsApiOpenCl()) { + if (channels == 1) { + return data_type == DataType::FLOAT32 ? opencl_info.supports_r_f32_tex2d + : opencl_info.supports_r_f16_tex2d; + } else if (channels == 2) { + return data_type == DataType::FLOAT32 ? opencl_info.supports_rg_f32_tex2d + : opencl_info.supports_rg_f16_tex2d; + } else if (channels == 3) { + return data_type == DataType::FLOAT32 + ? opencl_info.supports_rgb_f32_tex2d + : opencl_info.supports_rgb_f16_tex2d; + } else if (channels == 4) { + return data_type == DataType::FLOAT32 + ? 
opencl_info.supports_rgba_f32_tex2d + : opencl_info.supports_rgba_f16_tex2d; + } else { + return false; + } + } + return false; +} + int GpuInfo::GetComputeUnitsCount() const { + if (IsApiOpenCl()) { + return opencl_info.compute_units_count; + } if (IsApple()) { return apple_info.GetComputeUnitsCount(); - } else { - return 1; } + return 1; +} + +int GpuInfo::GetMaxWorkGroupSizeForX() const { + if (IsApiOpenGl()) { + return opengl_info.max_compute_work_group_size_x; + } + if (IsApiVulkan()) { + return vulkan_info.max_compute_work_group_size_x; + } + if (IsApiOpenCl()) { + return opencl_info.max_work_group_size_x; + } + return 256; +} + +int GpuInfo::GetMaxWorkGroupSizeForY() const { + if (IsApiOpenGl()) { + return opengl_info.max_compute_work_group_size_y; + } + if (IsApiVulkan()) { + return vulkan_info.max_compute_work_group_size_y; + } + if (IsApiOpenCl()) { + return opencl_info.max_work_group_size_y; + } + return 256; +} + +int GpuInfo::GetMaxWorkGroupSizeForZ() const { + if (IsApiOpenGl()) { + return opengl_info.max_compute_work_group_size_z; + } + if (IsApiVulkan()) { + return vulkan_info.max_compute_work_group_size_z; + } + if (IsApiOpenCl()) { + return opencl_info.max_work_group_size_z; + } + return 64; +} + +int GpuInfo::GetMaxWorkGroupTotalSize() const { + if (IsApiOpenGl()) { + return opengl_info.max_work_group_invocations; + } + if (IsApiVulkan()) { + return vulkan_info.max_compute_work_group_invocations; + } + if (IsApiOpenCl()) { + return opencl_info.max_work_group_total_size; + } + return 256; +} + +uint64_t GpuInfo::GetMaxImage2DWidth() const { + if (IsApiOpenGl()) { + return opengl_info.max_texture_size; + } + if (IsApiVulkan()) { + return vulkan_info.max_image_dimension_2d; + } + if (IsApiOpenCl()) { + return opencl_info.image2d_max_width; + } + return 2048; +} + +uint64_t GpuInfo::GetMaxImage2DHeight() const { + if (IsApiOpenGl()) { + return opengl_info.max_texture_size; + } + if (IsApiVulkan()) { + return vulkan_info.max_image_dimension_2d; + } + if (IsApiOpenCl()) { + return opencl_info.image2d_max_height; + } + return 2048; +} + +uint64_t GpuInfo::GetMaxImage2DArrayLayers() const { + if (IsApiOpenGl()) { + return opengl_info.max_array_texture_layers; + } + if (IsApiVulkan()) { + return vulkan_info.max_image_array_layers; + } + if (IsApiOpenCl()) { + return opencl_info.image_array_max_layers; + } + return 256; +} + +uint64_t GpuInfo::GetMaxImage3DWidth() const { + if (IsApiOpenCl()) { + return opencl_info.image3d_max_width; + } + return 256; +} + +uint64_t GpuInfo::GetMaxImage3DHeight() const { + if (IsApiOpenCl()) { + return opencl_info.image3d_max_height; + } + return 256; +} + +uint64_t GpuInfo::GetMaxImage3DDepth() const { + if (IsApiOpenCl()) { + return opencl_info.image3d_max_depth; + } + return 256; +} + +uint64_t GpuInfo::GetMaxBufferSize() const { + if (IsApiOpenCl()) { + return opencl_info.buffer_max_size; + } + return 128 * 1024 * 1024; +} + +uint64_t GpuInfo::GetMaxImageBufferWidth() const { + if (IsApiOpenCl()) { + return opencl_info.image_buffer_max_size; + } + return 64 * 1024; } int GpuInfo::GetMaxImageArguments() const { if (IsApiOpenGl()) { return opengl_info.max_image_units; - } else if (IsApiVulkan()) { - return vulkan_info.max_per_stage_descriptor_sampled_images; - } else if (IsApiMetal()) { - return 32; - } else if (IsApiOpenCl()) { - return 128; - } else { - return 1; } + if (IsApiVulkan()) { + return vulkan_info.max_per_stage_descriptor_sampled_images; + } + if (IsApiMetal()) { + return 32; + } + if (IsApiOpenCl()) { + return 128; + } + return 1; 
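// Illustrative sketch (not from this patch): the refactoring above replaces direct
// field reads (gpu_info.image2d_max_width, gpu_info.buffer_max_size, ...) with
// API-dispatching getters that fall back to conservative defaults on non-OpenCL
// backends. A minimal caller-side example; FitsTexture2D is a hypothetical helper
// name, and the limits/shape fields are the ones CanCreateTensorWithShape() uses
// earlier in this patch.
#include <cstdint>

#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"

namespace tflite {
namespace gpu {

// True if a TEXTURE_2D tensor of this shape fits the device limits
// (width carries W*B*D, height carries H*slices).
inline bool FitsTexture2D(const GpuInfo& gpu_info, const BHWDC& shape, int slices) {
  const uint64_t width = static_cast<uint64_t>(shape.w) * shape.b * shape.d;
  const uint64_t height = static_cast<uint64_t>(shape.h) * slices;
  return width <= gpu_info.GetMaxImage2DWidth() &&
         height <= gpu_info.GetMaxImage2DHeight();
}

}  // namespace gpu
}  // namespace tflite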
} bool GpuInfo::IsApiOpenGl() const { return gpu_api == GpuApi::kOpenGl; } -bool GpuInfo::IsApiVulkan() const { return gpu_api == GpuApi::kVulkan; } - -bool GpuInfo::IsApiMetal() const { return gpu_api == GpuApi::kMetal; } - -bool GpuInfo::IsApiOpenCl() const { return gpu_api == GpuApi::kOpenCl; } - bool GpuInfo::IsApiOpenGl31OrAbove() const { if (!IsApiOpenGl()) { return false; @@ -423,5 +654,29 @@ bool GpuInfo::IsApiOpenGl31OrAbove() const { opengl_info.major_version > 3; } +bool GpuInfo::IsApiVulkan() const { return gpu_api == GpuApi::kVulkan; } + +bool GpuInfo::IsApiMetal() const { return gpu_api == GpuApi::kMetal; } + +bool GpuInfo::IsApiOpenCl() const { return gpu_api == GpuApi::kOpenCl; } + +bool GpuInfo::IsCL20OrHigher() const { + if (!IsApiOpenCl()) { + return false; + } + return opencl_info.cl_version != OpenClVersion::kCl1_0 && + opencl_info.cl_version != OpenClVersion::kCl1_1 && + opencl_info.cl_version != OpenClVersion::kCl1_2; +} + +bool GpuInfo::IsCL30OrHigher() const { + if (!IsApiOpenCl()) { + return false; + } + return IsCL20OrHigher() && opencl_info.cl_version != OpenClVersion::kCl2_0 && + opencl_info.cl_version != OpenClVersion::kCl2_1 && + opencl_info.cl_version != OpenClVersion::kCl2_2; +} + } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/gpu_info.h b/tensorflow/lite/delegates/gpu/common/gpu_info.h index ae5d3a75c1b..cd61887fa83 100644 --- a/tensorflow/lite/delegates/gpu/common/gpu_info.h +++ b/tensorflow/lite/delegates/gpu/common/gpu_info.h @@ -19,6 +19,8 @@ limitations under the License. #include <string> #include <vector> +#include "tensorflow/lite/delegates/gpu/common/data_type.h" + namespace tflite { namespace gpu { @@ -208,6 +210,14 @@ struct OpenGlInfo { int max_image_units = 0; int max_ssbo_bindings = 0; int max_image_bindings = 0; + int max_work_group_invocations = 0; + int max_texture_size = 0; + int max_array_texture_layers = 0; + + std::vector<std::string> extensions; + int max_compute_work_group_size_x; + int max_compute_work_group_size_y; + int max_compute_work_group_size_z; }; struct VulkanInfo { @@ -218,6 +228,65 @@ struct VulkanInfo { uint32_t api_version_patch = -1; uint32_t max_per_stage_descriptor_sampled_images = 0; + uint32_t max_compute_work_group_invocations; + uint32_t max_image_dimension_2d; + uint32_t max_image_array_layers; + + std::vector<std::string> extensions; + int max_compute_work_group_size_x; + int max_compute_work_group_size_y; + int max_compute_work_group_size_z; +}; + +enum class OpenClVersion { + kCl1_0, + kCl1_1, + kCl1_2, + kCl2_0, + kCl2_1, + kCl2_2, + kCl3_0, + kUnknown, +}; +std::string OpenClVersionToString(OpenClVersion version); + +struct OpenClInfo { + OpenClVersion cl_version; + + std::vector<std::string> extensions; + bool supports_fp16; + bool supports_image3d_writes; + int compute_units_count; + uint64_t buffer_max_size; + uint64_t image2d_max_width; + uint64_t image2d_max_height; + uint64_t image_buffer_max_size; + uint64_t image_array_max_layers; + uint64_t image3d_max_width; + uint64_t image3d_max_height; + uint64_t image3d_max_depth; + int max_work_group_size_x; + int max_work_group_size_y; + int max_work_group_size_z; + int max_work_group_total_size; + + // rtn is ROUND_TO_NEAREST + // with rtn precision is much better then with rtz (ROUND_TO_ZERO) + // Adreno 3xx supports only rtz, Adreno 4xx and more support rtn + // Mali from T6xx supports rtn + // PowerVR supports only rtz + bool supports_fp32_rtn; + bool supports_fp16_rtn; + + bool supports_r_f16_tex2d = false; + 
bool supports_rg_f16_tex2d = false; + bool supports_rgb_f16_tex2d = false; + bool supports_rgba_f16_tex2d = false; + + bool supports_r_f32_tex2d = false; + bool supports_rg_f32_tex2d = false; + bool supports_rgb_f32_tex2d = false; + bool supports_rgba_f32_tex2d = false; }; struct GpuInfo { @@ -232,21 +301,42 @@ struct GpuInfo { // floating point rounding mode bool IsRoundToNearestSupported() const; + bool SupportsFP16() const; + + bool SupportsTextureArray() const; + bool SupportsImageBuffer() const; + bool SupportsImage3D() const; + // returns true if device have fixed wave size equal to 32 bool IsWaveSizeEqualTo32() const; + bool SupportsSubGroupWithSize(int sub_group_size) const; + + bool SupportsFloatImage2D(DataType data_type, int channels) const; + bool SupportsExtension(const std::string& extension) const; int GetComputeUnitsCount() const; int GetMaxImageArguments() const; + int GetMaxWorkGroupSizeForX() const; + int GetMaxWorkGroupSizeForY() const; + int GetMaxWorkGroupSizeForZ() const; + int GetMaxWorkGroupTotalSize() const; + + uint64_t GetMaxImage2DWidth() const; + uint64_t GetMaxImage2DHeight() const; + uint64_t GetMaxImage2DArrayLayers() const; + uint64_t GetMaxImage3DWidth() const; + uint64_t GetMaxImage3DHeight() const; + uint64_t GetMaxImage3DDepth() const; + uint64_t GetMaxBufferSize() const; + uint64_t GetMaxImageBufferWidth() const; + GpuVendor vendor = GpuVendor::kUnknown; GpuApi gpu_api = GpuApi::kUnknown; - std::vector<std::string> extensions; + // Temporary std::vector<int> max_work_group_size; - int max_work_group_invocations; - int max_texture_size = 0; - int max_array_texture_layers = 0; std::vector<int> supported_subgroup_sizes; @@ -265,7 +355,10 @@ struct GpuInfo { bool IsApiMetal() const; + OpenClInfo opencl_info; bool IsApiOpenCl() const; + bool IsCL20OrHigher() const; + bool IsCL30OrHigher() const; }; inline bool IsOpenGl31OrAbove(const GpuInfo& gpu_info) { diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc index ec39dbd5c5c..97eb0751328 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc @@ -858,7 +858,7 @@ class FullyConnectedOperationParser : public TFLiteOperationParser { absl::Status IsSupported(const TfLiteContext* context, const TfLiteNode* tflite_node, const TfLiteRegistration* registration) final { - RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 4)); + RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 9)); const TfLiteFullyConnectedParams* tf_options; RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); if (tf_options->weights_format != @@ -870,6 +870,15 @@ class FullyConnectedOperationParser : public TFLiteOperationParser { return absl::UnimplementedError( "FullyConnected doesn't support more than 2 runtime inputs."); } + if (tf_options->keep_num_dims == true) { + const auto* input = context->tensors + tflite_node->inputs->data[0]; + const auto* output = context->tensors + tflite_node->outputs->data[0]; + if (input->dims->size != output->dims->size) { + return absl::UnimplementedError( + "Input and output dimensions different and FullyConnected doesn't " + "support keep_num_dims."); + } + } // TODO(eignasheva): check input shape return absl::OkStatus(); } @@ -1508,10 +1517,6 @@ class ReduceOperationParser : public TFLiteOperationParser { return absl::UnimplementedError( "Reduce has unsupported tensor for axes."); } - if (tflite::NumElements(axes) != 1) { - return 
absl::UnimplementedError( - "Supported reduce in single dimensions only."); - } return absl::OkStatus(); } @@ -1526,13 +1531,15 @@ class ReduceOperationParser : public TFLiteOperationParser { const TfLiteReducerParams* tf_options; RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); - Tensor<Scalar, DataType::INT32> axes; - RETURN_IF_ERROR(reader->ReadTensor(1, &axes)); - const TfLiteTensor* input = reader->GetInputTensor(0); ReduceAttributes attr; - Axis axis; - RETURN_IF_ERROR(ExtractAxisFromIndex(*input, axes.data[0], &axis)); - attr.dims = {axis}; + Tensor<Linear, DataType::INT32> axes; + RETURN_IF_ERROR(reader->ReadTensor(1, &axes)); + const TfLiteTensor* output = reader->GetOutputTensor(0); + for (int i = 0; i < axes.data.size(); i++) { + Axis axis; + RETURN_IF_ERROR(ExtractAxisFromIndex(*output, axes.data[i], &axis)); + attr.dims.insert(axis); + } node->operation.attributes = attr; return absl::OkStatus(); } @@ -1652,7 +1659,6 @@ class Resize2DOperationParser : public TFLiteOperationParser { RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, /*outputs=*/1)); - RETURN_IF_ERROR(CheckOnlyUpsamplingIsSupported(context, tflite_node)); bool align_corners; RETURN_IF_ERROR(GetAlignCornersValue(tflite_node, &align_corners)); bool half_pixel_centers; @@ -1724,27 +1730,6 @@ class Resize2DOperationParser : public TFLiteOperationParser { return absl::OkStatus(); } - absl::Status CheckOnlyUpsamplingIsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node) { - const auto* input = context->tensors + tflite_node->inputs->data[0]; - const auto* output = context->tensors + tflite_node->outputs->data[0]; - - if (!input->dims || input->dims->size != 4) { - return absl::InvalidArgumentError("input.dims.size != 4"); - } - if (!output->dims || output->dims->size != 4) { - return absl::InvalidArgumentError("output.dims.size != 4"); - } - if (output->dims->data[1] < input->dims->data[1] || - output->dims->data[2] < input->dims->data[2]) { - return absl::InvalidArgumentError(absl::StrCat( - "Only upsampling is supported, received output h,w = ", - output->dims->data[1], ",", output->dims->data[2], - " input h,w = ", input->dims->data[1], ",", input->dims->data[2])); - } - return absl::OkStatus(); - } - SamplingType sampling_type_ = SamplingType::UNKNOWN; }; @@ -2612,18 +2597,10 @@ class MeanOperationParser : public TFLiteOperationParser { /*runtime_inputs=*/1, /*outputs=*/1)); - // Simple mechanism to check if MEAN is to be performed only on HW plane. 
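// Illustrative sketch (not from this patch): with the single-axis / HW-only
// restrictions removed above, Reduce and Mean now map every entry of the TFLite
// `axes` tensor to an Axis of the output tensor via ExtractAxisFromIndex(). For a
// 4-D BHWC tensor, TFLite axes {1, 2} resolve to {Axis::HEIGHT, Axis::WIDTH}, so
// the old "HW plane" mean becomes just one case of the general path. The helper
// below is a hypothetical stand-in that hard-codes the BHWC mapping (negative
// indices not handled); it is not the real ExtractAxisFromIndex().
#include <set>
#include <vector>

#include "tensorflow/lite/delegates/gpu/common/shape.h"

namespace tflite {
namespace gpu {

inline std::set<Axis> BhwcAxesFromTfLiteIndices(const std::vector<int>& tflite_axes) {
  static const Axis kBhwc[] = {Axis::BATCH, Axis::HEIGHT, Axis::WIDTH,
                               Axis::CHANNELS};
  std::set<Axis> dims;
  for (int index : tflite_axes) {
    dims.insert(kBhwc[index]);  // e.g. {1, 2} -> {HEIGHT, WIDTH}
  }
  return dims;
}

}  // namespace gpu
}  // namespace tflite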
auto* axes = &context->tensors[tflite_node->inputs->data[1]]; if (axes->allocation_type != kTfLiteMmapRo || axes->type != kTfLiteInt32) { return absl::UnimplementedError("Mean has unsupported tensor for axes"); } - auto* axes_data = axes->data.i32; - const bool is_hw_mean = tflite::NumElements(axes) == 2 && - ((axes_data[0] == 1 && axes_data[1] == 2) || - (axes_data[0] == 2 && axes_data[1] == 1)); - if (!is_hw_mean) { - return absl::UnimplementedError("Mean operation supports only HW plane"); - } return absl::OkStatus(); } @@ -2636,27 +2613,13 @@ class MeanOperationParser : public TFLiteOperationParser { RETURN_IF_ERROR(reader->AddOutputs(node)); MeanAttributes attr; - Tensor<Linear, DataType::INT32> channel; - RETURN_IF_ERROR(reader->ReadTensor(1, &channel)); - for (int i = 0; i < channel.data.size(); i++) { - std::string unsupported; - switch (channel.data[i]) { - case 1: - attr.dims.insert(Axis::HEIGHT); - break; - case 2: - attr.dims.insert(Axis::WIDTH); - break; - case 0: - unsupported = unsupported.empty() ? "batch" : unsupported; - ABSL_FALLTHROUGH_INTENDED; - case 3: - unsupported = unsupported.empty() ? "channels" : unsupported; - ABSL_FALLTHROUGH_INTENDED; - default: - return absl::UnimplementedError( - absl::StrCat("Unsupported mean dimension: ", unsupported)); - } + Tensor<Linear, DataType::INT32> axes; + RETURN_IF_ERROR(reader->ReadTensor(1, &axes)); + const TfLiteTensor* output = reader->GetOutputTensor(0); + for (int i = 0; i < axes.data.size(); i++) { + Axis axis; + RETURN_IF_ERROR(ExtractAxisFromIndex(*output, axes.data[i], &axis)); + attr.dims.insert(axis); } node->operation.attributes = attr; return absl::OkStatus(); diff --git a/tensorflow/lite/delegates/gpu/common/object_reader.h b/tensorflow/lite/delegates/gpu/common/object_reader.h index 3c7d7f6a859..71f4a529b33 100644 --- a/tensorflow/lite/delegates/gpu/common/object_reader.h +++ b/tensorflow/lite/delegates/gpu/common/object_reader.h @@ -71,6 +71,9 @@ class ObjectReader { } const TfLiteTensor* tflite_tensor = context_->tensors + tensor_idx; + if (tflite_tensor->sparsity != nullptr) { + return absl::InvalidArgumentError("Sparsity is not supported on GPU."); + } t->data.resize(NumElements(tflite_tensor)); RETURN_IF_ERROR(CreateVectorCopyData(*tflite_tensor, &t->data[0])); diff --git a/tensorflow/lite/delegates/gpu/common/task/BUILD b/tensorflow/lite/delegates/gpu/common/task/BUILD index b6800701989..04e5082e1f5 100644 --- a/tensorflow/lite/delegates/gpu/common/task/BUILD +++ b/tensorflow/lite/delegates/gpu/common/task/BUILD @@ -50,6 +50,30 @@ cc_library( ], ) +cc_library( + name = "gpu_operation", + srcs = ["gpu_operation.cc"], + hdrs = ["gpu_operation.h"], + deps = [ + ":serialization_base_cc_fbs", + "//tensorflow/lite/delegates/gpu/common:access_type", + "//tensorflow/lite/delegates/gpu/common:data_type", + "//tensorflow/lite/delegates/gpu/common:gpu_info", + "//tensorflow/lite/delegates/gpu/common:kernel_info", + "//tensorflow/lite/delegates/gpu/common:precision", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common/task:arguments", + "//tensorflow/lite/delegates/gpu/common/task:buffer_desc", + "//tensorflow/lite/delegates/gpu/common/task:compiler_options", + "//tensorflow/lite/delegates/gpu/common/task:gpu_tensor", + "//tensorflow/lite/delegates/gpu/common/task:tensor_desc", + "//tensorflow/lite/delegates/gpu/common/task:tuning_type", + "//tensorflow/lite/delegates/gpu/common/task:work_group_picking", + 
"@com_google_absl//absl/strings", + ], +) + cc_library( name = "gpu_tensor", hdrs = ["gpu_tensor.h"], @@ -116,5 +140,43 @@ cc_library( hdrs = ["util.h"], deps = [ ":gpu_object_desc", + "//tensorflow/lite/delegates/gpu/common:gpu_info", + "//tensorflow/lite/delegates/gpu/common:precision", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common:util", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "weights_conversion", + hdrs = ["weights_conversion.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:data_type", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:tensor", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common:util", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "weights_layout", + hdrs = ["weights_layout.h"], +) + +cc_library( + name = "work_group_picking", + srcs = ["work_group_picking.cc"], + hdrs = ["work_group_picking.h"], + deps = [ + ":tuning_type", + "//tensorflow/lite/delegates/gpu/cl:cl_kernel", + "//tensorflow/lite/delegates/gpu/common:gpu_info", + "//tensorflow/lite/delegates/gpu/common:kernel_info", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common:util", + "//tensorflow/lite/delegates/gpu/common:workgroup_selection", ], ) diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc b/tensorflow/lite/delegates/gpu/common/task/gpu_operation.cc similarity index 82% rename from tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc rename to tensorflow/lite/delegates/gpu/common/task/gpu_operation.cc index e9d5b626fd2..800576d0ff1 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc +++ b/tensorflow/lite/delegates/gpu/common/task/gpu_operation.cc @@ -13,22 +13,61 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" +#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" #include "absl/strings/substitute.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" #include "tensorflow/lite/delegates/gpu/common/access_type.h" +#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h" namespace tflite { namespace gpu { -namespace cl { namespace { +std::string GetCommonDefines(CalculationsPrecision precision) { + std::string result; + + switch (precision) { + case CalculationsPrecision::F32: + result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n"; + result += "#define ACCUM_FLT4 float4\n"; + result += "#define FLT float\n"; + result += "#define FLT2 float2\n"; + result += "#define FLT3 float3\n"; + result += "#define FLT4 float4\n"; + result += "#define TO_FLT4 convert_float4\n"; + result += "#define TO_ACCUM_TYPE convert_float4\n"; + result += "#define TO_ACCUM_FLT convert_float\n"; + break; + case CalculationsPrecision::F16: + result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n"; + result += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"; + result += "#define ACCUM_FLT4 half4\n"; + result += "#define FLT half\n"; + result += "#define FLT2 half2\n"; + result += "#define FLT3 half3\n"; + result += "#define FLT4 half4\n"; + result += "#define TO_FLT4 convert_half4\n"; + result += "#define TO_ACCUM_TYPE convert_half4\n"; + result += "#define TO_ACCUM_FLT convert_half\n"; + break; + case CalculationsPrecision::F32_F16: + result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n"; + result += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"; + result += "#define ACCUM_FLT4 float4\n"; + result += "#define FLT half\n"; + result += "#define FLT2 half2\n"; + result += "#define FLT3 half3\n"; + result += "#define FLT4 half4\n"; + result += "#define TO_FLT4 convert_half4\n"; + result += "#define TO_ACCUM_TYPE convert_float4\n"; + result += "#define TO_ACCUM_FLT convert_float\n"; + break; + } + return result; +} std::string GetElementWiseCode(const OperationDef& op_def, bool check_src_slices) { - std::string c = GetCommonDefines(op_def.precision); - + std::string c; c += "__kernel void main_function(\n"; c += "$0) {\n"; c += " int X = get_global_id(0);\n"; @@ -201,6 +240,7 @@ void GPUOperation::AssembleCode(const GpuInfo& gpu_info) { elementwise_code_ = "{\n" + code_ + "\n}\n" + elementwise_code_; code_ = GetElementWiseCode(definition_, check_src_channels_size_); } + code_ = GetCommonDefines(definition_.precision) + code_; } void GPUOperation::GetPossibleKernelWorkGroups( @@ -247,6 +287,5 @@ void GPUOperation::AddUniquePostfix(const std::string& unique_postfix) { } } -} // namespace cl } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h b/tensorflow/lite/delegates/gpu/common/task/gpu_operation.h similarity index 91% rename from tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h rename to tensorflow/lite/delegates/gpu/common/task/gpu_operation.h index c4d73225a27..f35682dad3a 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h +++ b/tensorflow/lite/delegates/gpu/common/task/gpu_operation.h @@ -13,15 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. 
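// Illustrative sketch (not from this patch): the change above moves the precision
// preamble out of the individual kernels; AssembleCode() now does
//   code_ = GetCommonDefines(definition_.precision) + code_;
// so operation code is written once against the FLT/ACCUM_FLT4 aliases. Under
// F32_F16 the preamble defines FLT4 as half4 but keeps ACCUM_FLT4/TO_ACCUM_TYPE as
// float, i.e. fp16 storage with fp32 accumulation. GetCommonDefines() is file-local
// to gpu_operation.cc, so this helper takes the defines string as a parameter; the
// kernel body is a made-up placeholder showing only the ordering.
#include <string>

namespace tflite {
namespace gpu {

inline std::string AssembleIllustrativeSource(const std::string& common_defines) {
  // Placeholder op code that relies on the aliases from the preamble.
  const std::string op_code =
      "__kernel void main_function() {\n"
      "  ACCUM_FLT4 acc = TO_ACCUM_TYPE((FLT4)(0.0f));\n"
      "}\n";
  // Same ordering AssembleCode() now produces: defines first, then the op code.
  return common_defines + op_code;
}

}  // namespace gpu
}  // namespace tflite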
==============================================================================*/ -#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_GPU_OPERATION_H_ -#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_GPU_OPERATION_H_ +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_GPU_OPERATION_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_GPU_OPERATION_H_ #include <string> #include <vector> -#include "tensorflow/lite/delegates/gpu/cl/device_info.h" -#include "tensorflow/lite/delegates/gpu/cl/serialization_generated.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/gpu_info.h" #include "tensorflow/lite/delegates/gpu/common/kernel_info.h" #include "tensorflow/lite/delegates/gpu/common/precision.h" #include "tensorflow/lite/delegates/gpu/common/status.h" @@ -29,6 +28,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/task/buffer_desc.h" #include "tensorflow/lite/delegates/gpu/common/task/compiler_options.h" #include "tensorflow/lite/delegates/gpu/common/task/gpu_tensor.h" +#include "tensorflow/lite/delegates/gpu/common/task/serialization_base_generated.h" #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h" #include "tensorflow/lite/delegates/gpu/common/task/tuning_type.h" #include "tensorflow/lite/delegates/gpu/common/types.h" @@ -36,6 +36,8 @@ limitations under the License. namespace tflite { namespace gpu { namespace cl { +class ClOperation; +} // kCustom: default value // GPUOperation::GetGridSize must be overloaded @@ -142,10 +144,11 @@ class GPUOperation { bool check_src_channels_size_ = false; protected: - friend class ClOperation; - friend flatbuffers::Offset<data::GPUOperation> Encode( + friend class cl::ClOperation; + friend flatbuffers::Offset<tflite::gpu::data::GPUOperation> Encode( const GPUOperation& op, flatbuffers::FlatBufferBuilder* builder); - friend absl::Status Decode(const data::GPUOperation* fb_op, GPUOperation* op); + friend absl::Status Decode(const tflite::gpu::data::GPUOperation* fb_op, + GPUOperation* op); virtual absl::Status BindArguments(ArgumentsBinder* args) { return absl::OkStatus(); @@ -168,8 +171,7 @@ class GPUOperation { std::string elementwise_code_; // temporary, used during op construction }; -} // namespace cl } // namespace gpu } // namespace tflite -#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_GPU_OPERATION_H_ +#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_GPU_OPERATION_H_ diff --git a/tensorflow/lite/delegates/gpu/common/task/serialization_base.fbs b/tensorflow/lite/delegates/gpu/common/task/serialization_base.fbs index 469a7dd9443..8d9434786a2 100644 --- a/tensorflow/lite/delegates/gpu/common/task/serialization_base.fbs +++ b/tensorflow/lite/delegates/gpu/common/task/serialization_base.fbs @@ -181,3 +181,56 @@ table Arguments { tensor_linear_objects:[TensorLinearDescriptorMapValue]; tensor_objects:[TensorDescriptorMapValue]; } + +enum CalculationsPrecision : byte { + F32 = 0, + F32_F16 = 1, + F16 = 2, +} + +enum TensorToGrid : byte { + CUSTOM = 0, + WB_TO_X_HD_TO_Y_S_TO_Z = 1, + WB_TO_X_HD_TO_Y_Z_IS_1 = 2, + WB_TO_X_H_TO_Y_D_TO_Z = 3, + B_TO_X_Y_IS_1_Z_IS_1 = 4, +} + +enum CompilerOptions : byte { + ADRENO_FULL_SIMD_LINE = 0, + ADRENO_MORE_WAVES = 1, + POWERVR_FP16 = 2, + CL_OPT_DISABLE = 3, + CL_2_0 = 4, + CL_3_0 = 5, +} + +table OperationDef { + precision:CalculationsPrecision; + src_tensors:[TensorDescriptor]; + dst_tensors:[TensorDescriptor]; +} + +table CompilerOption { + option:CompilerOptions; +} + +table GPUOperation { + 
arguments:Arguments; + code:string; + work_group_size:Int3; + compiler_options:[CompilerOption]; + tensor_to_grid:TensorToGrid; + elementwise:bool; + linkable:bool; + check_src_channels_size:bool; + definition:OperationDef; + grid_dimension:int32; + work_group_launch_order:Int3; + grid_size:Int3; + src_tensors_names:[string]; + dst_tensors_names:[string]; + work_groups_count:Int3; + linkable_count:int32; + elementwise_code:string; +} diff --git a/tensorflow/lite/delegates/gpu/common/task/serialization_base_generated.h b/tensorflow/lite/delegates/gpu/common/task/serialization_base_generated.h index 7f089f95082..e3c3a5c33df 100644 --- a/tensorflow/lite/delegates/gpu/common/task/serialization_base_generated.h +++ b/tensorflow/lite/delegates/gpu/common/task/serialization_base_generated.h @@ -78,6 +78,15 @@ struct TensorDescriptorMapValueBuilder; struct Arguments; struct ArgumentsBuilder; +struct OperationDef; +struct OperationDefBuilder; + +struct CompilerOption; +struct CompilerOptionBuilder; + +struct GPUOperation; +struct GPUOperationBuilder; + enum class AccessType : int8_t { READ = 0, WRITE = 1, @@ -291,6 +300,111 @@ inline const char *EnumNameLayout(Layout e) { return EnumNamesLayout()[index]; } +enum class CalculationsPrecision : int8_t { + F32 = 0, + F32_F16 = 1, + F16 = 2, + MIN = F32, + MAX = F16 +}; + +inline const CalculationsPrecision (&EnumValuesCalculationsPrecision())[3] { + static const CalculationsPrecision values[] = {CalculationsPrecision::F32, + CalculationsPrecision::F32_F16, + CalculationsPrecision::F16}; + return values; +} + +inline const char *const *EnumNamesCalculationsPrecision() { + static const char *const names[4] = {"F32", "F32_F16", "F16", nullptr}; + return names; +} + +inline const char *EnumNameCalculationsPrecision(CalculationsPrecision e) { + if (flatbuffers::IsOutRange(e, CalculationsPrecision::F32, + CalculationsPrecision::F16)) + return ""; + const size_t index = static_cast<size_t>(e); + return EnumNamesCalculationsPrecision()[index]; +} + +enum class TensorToGrid : int8_t { + CUSTOM = 0, + WB_TO_X_HD_TO_Y_S_TO_Z = 1, + WB_TO_X_HD_TO_Y_Z_IS_1 = 2, + WB_TO_X_H_TO_Y_D_TO_Z = 3, + B_TO_X_Y_IS_1_Z_IS_1 = 4, + MIN = CUSTOM, + MAX = B_TO_X_Y_IS_1_Z_IS_1 +}; + +inline const TensorToGrid (&EnumValuesTensorToGrid())[5] { + static const TensorToGrid values[] = { + TensorToGrid::CUSTOM, TensorToGrid::WB_TO_X_HD_TO_Y_S_TO_Z, + TensorToGrid::WB_TO_X_HD_TO_Y_Z_IS_1, TensorToGrid::WB_TO_X_H_TO_Y_D_TO_Z, + TensorToGrid::B_TO_X_Y_IS_1_Z_IS_1}; + return values; +} + +inline const char *const *EnumNamesTensorToGrid() { + static const char *const names[6] = {"CUSTOM", + "WB_TO_X_HD_TO_Y_S_TO_Z", + "WB_TO_X_HD_TO_Y_Z_IS_1", + "WB_TO_X_H_TO_Y_D_TO_Z", + "B_TO_X_Y_IS_1_Z_IS_1", + nullptr}; + return names; +} + +inline const char *EnumNameTensorToGrid(TensorToGrid e) { + if (flatbuffers::IsOutRange(e, TensorToGrid::CUSTOM, + TensorToGrid::B_TO_X_Y_IS_1_Z_IS_1)) + return ""; + const size_t index = static_cast<size_t>(e); + return EnumNamesTensorToGrid()[index]; +} + +enum class CompilerOptions : int8_t { + ADRENO_FULL_SIMD_LINE = 0, + ADRENO_MORE_WAVES = 1, + POWERVR_FP16 = 2, + CL_OPT_DISABLE = 3, + CL_2_0 = 4, + CL_3_0 = 5, + MIN = ADRENO_FULL_SIMD_LINE, + MAX = CL_3_0 +}; + +inline const CompilerOptions (&EnumValuesCompilerOptions())[6] { + static const CompilerOptions values[] = { + CompilerOptions::ADRENO_FULL_SIMD_LINE, + CompilerOptions::ADRENO_MORE_WAVES, + CompilerOptions::POWERVR_FP16, + CompilerOptions::CL_OPT_DISABLE, + CompilerOptions::CL_2_0, + 
CompilerOptions::CL_3_0}; + return values; +} + +inline const char *const *EnumNamesCompilerOptions() { + static const char *const names[7] = {"ADRENO_FULL_SIMD_LINE", + "ADRENO_MORE_WAVES", + "POWERVR_FP16", + "CL_OPT_DISABLE", + "CL_2_0", + "CL_3_0", + nullptr}; + return names; +} + +inline const char *EnumNameCompilerOptions(CompilerOptions e) { + if (flatbuffers::IsOutRange(e, CompilerOptions::ADRENO_FULL_SIMD_LINE, + CompilerOptions::CL_3_0)) + return ""; + const size_t index = static_cast<size_t>(e); + return EnumNamesCompilerOptions()[index]; +} + struct Int4 FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef Int4Builder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { @@ -1832,6 +1946,454 @@ inline flatbuffers::Offset<Arguments> CreateArgumentsDirect( tensor_objects__); } +struct OperationDef FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef OperationDefBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_PRECISION = 4, + VT_SRC_TENSORS = 6, + VT_DST_TENSORS = 8 + }; + tflite::gpu::data::CalculationsPrecision precision() const { + return static_cast<tflite::gpu::data::CalculationsPrecision>( + GetField<int8_t>(VT_PRECISION, 0)); + } + const flatbuffers::Vector< + flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>> + *src_tensors() const { + return GetPointer<const flatbuffers::Vector< + flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>> *>( + VT_SRC_TENSORS); + } + const flatbuffers::Vector< + flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>> + *dst_tensors() const { + return GetPointer<const flatbuffers::Vector< + flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>> *>( + VT_DST_TENSORS); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField<int8_t>(verifier, VT_PRECISION) && + VerifyOffset(verifier, VT_SRC_TENSORS) && + verifier.VerifyVector(src_tensors()) && + verifier.VerifyVectorOfTables(src_tensors()) && + VerifyOffset(verifier, VT_DST_TENSORS) && + verifier.VerifyVector(dst_tensors()) && + verifier.VerifyVectorOfTables(dst_tensors()) && verifier.EndTable(); + } +}; + +struct OperationDefBuilder { + typedef OperationDef Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_precision(tflite::gpu::data::CalculationsPrecision precision) { + fbb_.AddElement<int8_t>(OperationDef::VT_PRECISION, + static_cast<int8_t>(precision), 0); + } + void add_src_tensors( + flatbuffers::Offset<flatbuffers::Vector< + flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>>> + src_tensors) { + fbb_.AddOffset(OperationDef::VT_SRC_TENSORS, src_tensors); + } + void add_dst_tensors( + flatbuffers::Offset<flatbuffers::Vector< + flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>>> + dst_tensors) { + fbb_.AddOffset(OperationDef::VT_DST_TENSORS, dst_tensors); + } + explicit OperationDefBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset<OperationDef> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<OperationDef>(end); + return o; + } +}; + +inline flatbuffers::Offset<OperationDef> CreateOperationDef( + flatbuffers::FlatBufferBuilder &_fbb, + tflite::gpu::data::CalculationsPrecision precision = + tflite::gpu::data::CalculationsPrecision::F32, + flatbuffers::Offset<flatbuffers::Vector< + flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>>> + src_tensors = 0, + 
flatbuffers::Offset<flatbuffers::Vector< + flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>>> + dst_tensors = 0) { + OperationDefBuilder builder_(_fbb); + builder_.add_dst_tensors(dst_tensors); + builder_.add_src_tensors(src_tensors); + builder_.add_precision(precision); + return builder_.Finish(); +} + +inline flatbuffers::Offset<OperationDef> CreateOperationDefDirect( + flatbuffers::FlatBufferBuilder &_fbb, + tflite::gpu::data::CalculationsPrecision precision = + tflite::gpu::data::CalculationsPrecision::F32, + const std::vector<flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>> + *src_tensors = nullptr, + const std::vector<flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>> + *dst_tensors = nullptr) { + auto src_tensors__ = + src_tensors + ? _fbb.CreateVector< + flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>>( + *src_tensors) + : 0; + auto dst_tensors__ = + dst_tensors + ? _fbb.CreateVector< + flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>>( + *dst_tensors) + : 0; + return tflite::gpu::data::CreateOperationDef(_fbb, precision, src_tensors__, + dst_tensors__); +} + +struct CompilerOption FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef CompilerOptionBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_OPTION = 4 + }; + tflite::gpu::data::CompilerOptions option() const { + return static_cast<tflite::gpu::data::CompilerOptions>( + GetField<int8_t>(VT_OPTION, 0)); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField<int8_t>(verifier, VT_OPTION) && verifier.EndTable(); + } +}; + +struct CompilerOptionBuilder { + typedef CompilerOption Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_option(tflite::gpu::data::CompilerOptions option) { + fbb_.AddElement<int8_t>(CompilerOption::VT_OPTION, + static_cast<int8_t>(option), 0); + } + explicit CompilerOptionBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset<CompilerOption> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<CompilerOption>(end); + return o; + } +}; + +inline flatbuffers::Offset<CompilerOption> CreateCompilerOption( + flatbuffers::FlatBufferBuilder &_fbb, + tflite::gpu::data::CompilerOptions option = + tflite::gpu::data::CompilerOptions::ADRENO_FULL_SIMD_LINE) { + CompilerOptionBuilder builder_(_fbb); + builder_.add_option(option); + return builder_.Finish(); +} + +struct GPUOperation FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef GPUOperationBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_ARGUMENTS = 4, + VT_CODE = 6, + VT_WORK_GROUP_SIZE = 8, + VT_COMPILER_OPTIONS = 10, + VT_TENSOR_TO_GRID = 12, + VT_ELEMENTWISE = 14, + VT_LINKABLE = 16, + VT_CHECK_SRC_CHANNELS_SIZE = 18, + VT_DEFINITION = 20, + VT_GRID_DIMENSION = 22, + VT_WORK_GROUP_LAUNCH_ORDER = 24, + VT_GRID_SIZE = 26, + VT_SRC_TENSORS_NAMES = 28, + VT_DST_TENSORS_NAMES = 30, + VT_WORK_GROUPS_COUNT = 32, + VT_LINKABLE_COUNT = 34, + VT_ELEMENTWISE_CODE = 36 + }; + const tflite::gpu::data::Arguments *arguments() const { + return GetPointer<const tflite::gpu::data::Arguments *>(VT_ARGUMENTS); + } + const flatbuffers::String *code() const { + return GetPointer<const flatbuffers::String *>(VT_CODE); + } + const tflite::gpu::data::Int3 *work_group_size() const { + return GetPointer<const tflite::gpu::data::Int3 *>(VT_WORK_GROUP_SIZE); 
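// Illustrative sketch (not from this patch): the generated API above follows the
// usual FlatBuffers pattern: Create*Direct() serializes strings/vectors first and
// then delegates to Create*(), and a table is only usable after
// FlatBufferBuilder::Finish(). Minimal hypothetical round trip for OperationDef
// (tensor descriptor vectors omitted by passing the nullptr defaults):
#include "flatbuffers/flatbuffers.h"
#include "tensorflow/lite/delegates/gpu/common/task/serialization_base_generated.h"

inline void SerializeOperationDefExample() {
  flatbuffers::FlatBufferBuilder builder;
  auto def_offset = tflite::gpu::data::CreateOperationDefDirect(
      builder, tflite::gpu::data::CalculationsPrecision::F16,
      /*src_tensors=*/nullptr, /*dst_tensors=*/nullptr);
  builder.Finish(def_offset);
  // builder.GetBufferPointer() / builder.GetSize() now hold the finished buffer.
  const auto* def = flatbuffers::GetRoot<tflite::gpu::data::OperationDef>(
      builder.GetBufferPointer());
  (void)def;  // def->precision() == CalculationsPrecision::F16
}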
+ } + const flatbuffers::Vector< + flatbuffers::Offset<tflite::gpu::data::CompilerOption>> + *compiler_options() const { + return GetPointer<const flatbuffers::Vector< + flatbuffers::Offset<tflite::gpu::data::CompilerOption>> *>( + VT_COMPILER_OPTIONS); + } + tflite::gpu::data::TensorToGrid tensor_to_grid() const { + return static_cast<tflite::gpu::data::TensorToGrid>( + GetField<int8_t>(VT_TENSOR_TO_GRID, 0)); + } + bool elementwise() const { return GetField<uint8_t>(VT_ELEMENTWISE, 0) != 0; } + bool linkable() const { return GetField<uint8_t>(VT_LINKABLE, 0) != 0; } + bool check_src_channels_size() const { + return GetField<uint8_t>(VT_CHECK_SRC_CHANNELS_SIZE, 0) != 0; + } + const tflite::gpu::data::OperationDef *definition() const { + return GetPointer<const tflite::gpu::data::OperationDef *>(VT_DEFINITION); + } + int32_t grid_dimension() const { + return GetField<int32_t>(VT_GRID_DIMENSION, 0); + } + const tflite::gpu::data::Int3 *work_group_launch_order() const { + return GetPointer<const tflite::gpu::data::Int3 *>( + VT_WORK_GROUP_LAUNCH_ORDER); + } + const tflite::gpu::data::Int3 *grid_size() const { + return GetPointer<const tflite::gpu::data::Int3 *>(VT_GRID_SIZE); + } + const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> + *src_tensors_names() const { + return GetPointer< + const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>( + VT_SRC_TENSORS_NAMES); + } + const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> + *dst_tensors_names() const { + return GetPointer< + const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>( + VT_DST_TENSORS_NAMES); + } + const tflite::gpu::data::Int3 *work_groups_count() const { + return GetPointer<const tflite::gpu::data::Int3 *>(VT_WORK_GROUPS_COUNT); + } + int32_t linkable_count() const { + return GetField<int32_t>(VT_LINKABLE_COUNT, 0); + } + const flatbuffers::String *elementwise_code() const { + return GetPointer<const flatbuffers::String *>(VT_ELEMENTWISE_CODE); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_ARGUMENTS) && + verifier.VerifyTable(arguments()) && + VerifyOffset(verifier, VT_CODE) && verifier.VerifyString(code()) && + VerifyOffset(verifier, VT_WORK_GROUP_SIZE) && + verifier.VerifyTable(work_group_size()) && + VerifyOffset(verifier, VT_COMPILER_OPTIONS) && + verifier.VerifyVector(compiler_options()) && + verifier.VerifyVectorOfTables(compiler_options()) && + VerifyField<int8_t>(verifier, VT_TENSOR_TO_GRID) && + VerifyField<uint8_t>(verifier, VT_ELEMENTWISE) && + VerifyField<uint8_t>(verifier, VT_LINKABLE) && + VerifyField<uint8_t>(verifier, VT_CHECK_SRC_CHANNELS_SIZE) && + VerifyOffset(verifier, VT_DEFINITION) && + verifier.VerifyTable(definition()) && + VerifyField<int32_t>(verifier, VT_GRID_DIMENSION) && + VerifyOffset(verifier, VT_WORK_GROUP_LAUNCH_ORDER) && + verifier.VerifyTable(work_group_launch_order()) && + VerifyOffset(verifier, VT_GRID_SIZE) && + verifier.VerifyTable(grid_size()) && + VerifyOffset(verifier, VT_SRC_TENSORS_NAMES) && + verifier.VerifyVector(src_tensors_names()) && + verifier.VerifyVectorOfStrings(src_tensors_names()) && + VerifyOffset(verifier, VT_DST_TENSORS_NAMES) && + verifier.VerifyVector(dst_tensors_names()) && + verifier.VerifyVectorOfStrings(dst_tensors_names()) && + VerifyOffset(verifier, VT_WORK_GROUPS_COUNT) && + verifier.VerifyTable(work_groups_count()) && + VerifyField<int32_t>(verifier, VT_LINKABLE_COUNT) && + VerifyOffset(verifier, VT_ELEMENTWISE_CODE) && + 
verifier.VerifyString(elementwise_code()) && verifier.EndTable(); + } +}; + +struct GPUOperationBuilder { + typedef GPUOperation Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_arguments( + flatbuffers::Offset<tflite::gpu::data::Arguments> arguments) { + fbb_.AddOffset(GPUOperation::VT_ARGUMENTS, arguments); + } + void add_code(flatbuffers::Offset<flatbuffers::String> code) { + fbb_.AddOffset(GPUOperation::VT_CODE, code); + } + void add_work_group_size( + flatbuffers::Offset<tflite::gpu::data::Int3> work_group_size) { + fbb_.AddOffset(GPUOperation::VT_WORK_GROUP_SIZE, work_group_size); + } + void add_compiler_options( + flatbuffers::Offset<flatbuffers::Vector< + flatbuffers::Offset<tflite::gpu::data::CompilerOption>>> + compiler_options) { + fbb_.AddOffset(GPUOperation::VT_COMPILER_OPTIONS, compiler_options); + } + void add_tensor_to_grid(tflite::gpu::data::TensorToGrid tensor_to_grid) { + fbb_.AddElement<int8_t>(GPUOperation::VT_TENSOR_TO_GRID, + static_cast<int8_t>(tensor_to_grid), 0); + } + void add_elementwise(bool elementwise) { + fbb_.AddElement<uint8_t>(GPUOperation::VT_ELEMENTWISE, + static_cast<uint8_t>(elementwise), 0); + } + void add_linkable(bool linkable) { + fbb_.AddElement<uint8_t>(GPUOperation::VT_LINKABLE, + static_cast<uint8_t>(linkable), 0); + } + void add_check_src_channels_size(bool check_src_channels_size) { + fbb_.AddElement<uint8_t>(GPUOperation::VT_CHECK_SRC_CHANNELS_SIZE, + static_cast<uint8_t>(check_src_channels_size), 0); + } + void add_definition( + flatbuffers::Offset<tflite::gpu::data::OperationDef> definition) { + fbb_.AddOffset(GPUOperation::VT_DEFINITION, definition); + } + void add_grid_dimension(int32_t grid_dimension) { + fbb_.AddElement<int32_t>(GPUOperation::VT_GRID_DIMENSION, grid_dimension, + 0); + } + void add_work_group_launch_order( + flatbuffers::Offset<tflite::gpu::data::Int3> work_group_launch_order) { + fbb_.AddOffset(GPUOperation::VT_WORK_GROUP_LAUNCH_ORDER, + work_group_launch_order); + } + void add_grid_size(flatbuffers::Offset<tflite::gpu::data::Int3> grid_size) { + fbb_.AddOffset(GPUOperation::VT_GRID_SIZE, grid_size); + } + void add_src_tensors_names( + flatbuffers::Offset< + flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> + src_tensors_names) { + fbb_.AddOffset(GPUOperation::VT_SRC_TENSORS_NAMES, src_tensors_names); + } + void add_dst_tensors_names( + flatbuffers::Offset< + flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> + dst_tensors_names) { + fbb_.AddOffset(GPUOperation::VT_DST_TENSORS_NAMES, dst_tensors_names); + } + void add_work_groups_count( + flatbuffers::Offset<tflite::gpu::data::Int3> work_groups_count) { + fbb_.AddOffset(GPUOperation::VT_WORK_GROUPS_COUNT, work_groups_count); + } + void add_linkable_count(int32_t linkable_count) { + fbb_.AddElement<int32_t>(GPUOperation::VT_LINKABLE_COUNT, linkable_count, + 0); + } + void add_elementwise_code( + flatbuffers::Offset<flatbuffers::String> elementwise_code) { + fbb_.AddOffset(GPUOperation::VT_ELEMENTWISE_CODE, elementwise_code); + } + explicit GPUOperationBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset<GPUOperation> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<GPUOperation>(end); + return o; + } +}; + +inline flatbuffers::Offset<GPUOperation> CreateGPUOperation( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset<tflite::gpu::data::Arguments> arguments = 0, + 
flatbuffers::Offset<flatbuffers::String> code = 0, + flatbuffers::Offset<tflite::gpu::data::Int3> work_group_size = 0, + flatbuffers::Offset<flatbuffers::Vector< + flatbuffers::Offset<tflite::gpu::data::CompilerOption>>> + compiler_options = 0, + tflite::gpu::data::TensorToGrid tensor_to_grid = + tflite::gpu::data::TensorToGrid::CUSTOM, + bool elementwise = false, bool linkable = false, + bool check_src_channels_size = false, + flatbuffers::Offset<tflite::gpu::data::OperationDef> definition = 0, + int32_t grid_dimension = 0, + flatbuffers::Offset<tflite::gpu::data::Int3> work_group_launch_order = 0, + flatbuffers::Offset<tflite::gpu::data::Int3> grid_size = 0, + flatbuffers::Offset< + flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> + src_tensors_names = 0, + flatbuffers::Offset< + flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> + dst_tensors_names = 0, + flatbuffers::Offset<tflite::gpu::data::Int3> work_groups_count = 0, + int32_t linkable_count = 0, + flatbuffers::Offset<flatbuffers::String> elementwise_code = 0) { + GPUOperationBuilder builder_(_fbb); + builder_.add_elementwise_code(elementwise_code); + builder_.add_linkable_count(linkable_count); + builder_.add_work_groups_count(work_groups_count); + builder_.add_dst_tensors_names(dst_tensors_names); + builder_.add_src_tensors_names(src_tensors_names); + builder_.add_grid_size(grid_size); + builder_.add_work_group_launch_order(work_group_launch_order); + builder_.add_grid_dimension(grid_dimension); + builder_.add_definition(definition); + builder_.add_compiler_options(compiler_options); + builder_.add_work_group_size(work_group_size); + builder_.add_code(code); + builder_.add_arguments(arguments); + builder_.add_check_src_channels_size(check_src_channels_size); + builder_.add_linkable(linkable); + builder_.add_elementwise(elementwise); + builder_.add_tensor_to_grid(tensor_to_grid); + return builder_.Finish(); +} + +inline flatbuffers::Offset<GPUOperation> CreateGPUOperationDirect( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset<tflite::gpu::data::Arguments> arguments = 0, + const char *code = nullptr, + flatbuffers::Offset<tflite::gpu::data::Int3> work_group_size = 0, + const std::vector<flatbuffers::Offset<tflite::gpu::data::CompilerOption>> + *compiler_options = nullptr, + tflite::gpu::data::TensorToGrid tensor_to_grid = + tflite::gpu::data::TensorToGrid::CUSTOM, + bool elementwise = false, bool linkable = false, + bool check_src_channels_size = false, + flatbuffers::Offset<tflite::gpu::data::OperationDef> definition = 0, + int32_t grid_dimension = 0, + flatbuffers::Offset<tflite::gpu::data::Int3> work_group_launch_order = 0, + flatbuffers::Offset<tflite::gpu::data::Int3> grid_size = 0, + const std::vector<flatbuffers::Offset<flatbuffers::String>> + *src_tensors_names = nullptr, + const std::vector<flatbuffers::Offset<flatbuffers::String>> + *dst_tensors_names = nullptr, + flatbuffers::Offset<tflite::gpu::data::Int3> work_groups_count = 0, + int32_t linkable_count = 0, const char *elementwise_code = nullptr) { + auto code__ = code ? _fbb.CreateString(code) : 0; + auto compiler_options__ = + compiler_options + ? _fbb.CreateVector< + flatbuffers::Offset<tflite::gpu::data::CompilerOption>>( + *compiler_options) + : 0; + auto src_tensors_names__ = + src_tensors_names + ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>( + *src_tensors_names) + : 0; + auto dst_tensors_names__ = + dst_tensors_names + ? 
_fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>( + *dst_tensors_names) + : 0; + auto elementwise_code__ = + elementwise_code ? _fbb.CreateString(elementwise_code) : 0; + return tflite::gpu::data::CreateGPUOperation( + _fbb, arguments, code__, work_group_size, compiler_options__, + tensor_to_grid, elementwise, linkable, check_src_channels_size, + definition, grid_dimension, work_group_launch_order, grid_size, + src_tensors_names__, dst_tensors_names__, work_groups_count, + linkable_count, elementwise_code__); +} + } // namespace data } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/task/util.cc b/tensorflow/lite/delegates/gpu/common/task/util.cc index 4d2892f7e54..137815b9a90 100644 --- a/tensorflow/lite/delegates/gpu/common/task/util.cc +++ b/tensorflow/lite/delegates/gpu/common/task/util.cc @@ -15,6 +15,11 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/task/util.h" +#include <cfloat> + +#include "absl/strings/substitute.h" +#include "tensorflow/lite/delegates/gpu/common/util.h" + namespace tflite { namespace gpu { @@ -43,5 +48,118 @@ std::string MemoryTypeToMetalType(MemoryType type) { return ""; } +std::string GetXStrideCorrected(const std::string& src_x, + const std::string& batch_size, + const std::string& stride_x, + const std::string& padding_x) { + // int p0 = src_x / batch_size;\n"; + // int b0 = src_x % batch_size;\n"; + // return p0 * stride_x * batch_size + b0 + padding_x;\n"; + return absl::Substitute("((($0) / $1) * $2 * $1 + (($0) % $1) + $3)", src_x, + batch_size, stride_x, padding_x); +} + +std::string GetXStrideCorrectedV2(const std::string& src_x, + const std::string& batch_size, + const std::string& stride_x, + const std::string& padding_x) { + // int p0 = src_x / batch_size;\n"; + // int b0 = src_x % batch_size;\n"; + // return (p0 * stride_x + padding_x) * batch_size + b0;\n"; + return absl::Substitute("(((($0) / $1) * $2 + $3) * $1 + ($0) % $1)", src_x, + batch_size, stride_x, padding_x); +} + +float4 GetMaskForLastPlane(int channels) { + float4 mask = float4(0.0f); + const int reminder = channels % 4 == 0 ? 
4 : channels % 4; + for (int i = 0; i < reminder; ++i) { + mask[i] = 1.0f; + } + return mask; +} + +int GetRecommendedBlockSizeForConv(const GpuInfo& gpu_info, + CalculationsPrecision precision, + int task_size) { + const float task_size_per_cu = + task_size / static_cast<float>(gpu_info.GetComputeUnitsCount()); + int block_size = 1; + float threshold_1 = FLT_MAX; + float threshold_2 = FLT_MAX; + float threshold_4 = FLT_MAX; + if (!gpu_info.IsMali()) { + return 1; + } + MaliInfo mali_info = gpu_info.mali_info; + switch (precision) { + case CalculationsPrecision::F16: + if (mali_info.IsBifrostGen1()) { + threshold_1 = 256.0f; + threshold_2 = 256.0f * 4.0f; + threshold_4 = 256.0f * 8.0f; + } else if (mali_info.IsBifrostGen2()) { + threshold_1 = 256.0f * 2.0f; + threshold_2 = 256.0f * 8.0f; + threshold_4 = 256.0f * 16.0f; + } else if (mali_info.IsBifrostGen3() || mali_info.IsValhall()) { + threshold_1 = 256.0f; + threshold_2 = 256.0f * 6.0f; + threshold_4 = 256.0f * 16.0f; + } else if (mali_info.IsMidgard()) { + threshold_1 = 256.0f * 4.0f; + threshold_2 = 256.0f * 16.0f; + } + break; + case CalculationsPrecision::F32_F16: + if (mali_info.IsBifrostGen1()) { + threshold_1 = 256.0f; + threshold_2 = 256.0f * 3.0f; + threshold_4 = 256.0f * 32.0f; + } else if (mali_info.IsBifrostGen2()) { + threshold_1 = 256.0f * 2.0f; + threshold_2 = 256.0f * 8.0f; + } else if (mali_info.IsBifrostGen3() || mali_info.IsValhall()) { + threshold_1 = 256.0f; + threshold_2 = 256.0f * 8.0f; + } else if (mali_info.IsMidgard()) { + threshold_1 = 256.0f * 4.0f; + } + break; + case CalculationsPrecision::F32: + if (mali_info.IsBifrostGen1()) { + threshold_1 = 256.0f; + threshold_2 = 256.0f * 4.0f; + } else if (mali_info.IsBifrostGen2()) { + threshold_1 = 128.0f; + threshold_2 = 256.0f * 4.0f; + } else if (mali_info.IsBifrostGen3() || mali_info.IsValhall()) { + threshold_1 = 256.0f; + threshold_2 = 256.0f * 12.0f; + } else if (mali_info.IsMidgard()) { + threshold_1 = 256.0f * 16.0f; + } + break; + } + if (task_size_per_cu <= threshold_1) { + block_size = 1; + } else if (task_size_per_cu <= threshold_2) { + block_size = 2; + } else if (task_size_per_cu <= threshold_4) { + block_size = 4; + } else { + block_size = 8; + } + return block_size; +} + +int3 GetWorkGroupsCount(const int3& grid_size, const int3& work_group_size) { + int3 work_groups_count; + work_groups_count.x = DivideRoundUp(grid_size.x, work_group_size.x); + work_groups_count.y = DivideRoundUp(grid_size.y, work_group_size.y); + work_groups_count.z = DivideRoundUp(grid_size.z, work_group_size.z); + return work_groups_count; +} + } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/task/util.h b/tensorflow/lite/delegates/gpu/common/task/util.h index 3fb2c5abe84..135126b7d01 100644 --- a/tensorflow/lite/delegates/gpu/common/task/util.h +++ b/tensorflow/lite/delegates/gpu/common/task/util.h @@ -17,8 +17,12 @@ limitations under the License. 
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_UTIL_H_ #include <string> +#include <vector> +#include "tensorflow/lite/delegates/gpu/common/gpu_info.h" +#include "tensorflow/lite/delegates/gpu/common/precision.h" #include "tensorflow/lite/delegates/gpu/common/task/gpu_object_desc.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" namespace tflite { namespace gpu { @@ -27,6 +31,36 @@ std::string MemoryTypeToCLType(MemoryType type); std::string MemoryTypeToMetalType(MemoryType type); +// Calculates correct X coordinate when stride != 1 and batch != 1 for layouts +// with B after W (for example HWBC4) and WB stored in one axis of GPU +// resources. +std::string GetXStrideCorrected(const std::string& src_x, + const std::string& batch_size, + const std::string& stride_x, + const std::string& padding_x); + +// Calculates correct X coordinate when stride != 1 and batch != 1 for layouts +// with B after W (for example HWBC4) and WB stored in one axis of GPU +// resources. +std::string GetXStrideCorrectedV2(const std::string& src_x, + const std::string& batch_size, + const std::string& stride_x, + const std::string& padding_x); + +// Returns float4 mask for last plane(batch of 4 channels) +// assumes that plane size is 4; +// for example we have 7 channels, in our data structures we align it to 8 +// but 8s-channel will be empty, then last plane (batch of 4 channels) will +// have this mask (1, 1, 1, 0). +float4 GetMaskForLastPlane(int channels); + +// task_size as amount of FLT4 processed elements. +int GetRecommendedBlockSizeForConv(const GpuInfo& gpu_info, + CalculationsPrecision precision, + int task_size); + +int3 GetWorkGroupsCount(const int3& grid_size, const int3& work_group_size); + } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/util.h b/tensorflow/lite/delegates/gpu/common/task/weights_conversion.h similarity index 75% rename from tensorflow/lite/delegates/gpu/cl/kernels/util.h rename to tensorflow/lite/delegates/gpu/common/task/weights_conversion.h index 519f1f117b2..adbd8a107cd 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/util.h +++ b/tensorflow/lite/delegates/gpu/common/task/weights_conversion.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,16 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_UTIL_H_ -#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_UTIL_H_ +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_WEIGHTS_CONVERSION_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_WEIGHTS_CONVERSION_H_ #include <string> #include <vector> #include "absl/types/span.h" -#include "tensorflow/lite/delegates/gpu/cl/device_info.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" -#include "tensorflow/lite/delegates/gpu/common/precision.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" #include "tensorflow/lite/delegates/gpu/common/types.h" @@ -30,25 +28,6 @@ limitations under the License. 
namespace tflite { namespace gpu { -namespace cl { - -std::string GetCommonDefines(CalculationsPrecision precision); - -// Calculates correct X coordinate when stride != 1 and batch != 1 for layouts -// with B after W (for example HWBC4) and WB stored in one axis of GPU -// resources. -std::string GetXStrideCorrected(const std::string& src_x, - const std::string& batch_size, - const std::string& stride_x, - const std::string& padding_x); - -// Calculates correct X coordinate when stride != 1 and batch != 1 for layouts -// with B after W (for example HWBC4) and WB stored in one axis of GPU -// resources. -std::string GetXStrideCorrectedV2(const std::string& src_x, - const std::string& batch_size, - const std::string& stride_x, - const std::string& padding_x); template <DataType S, typename T> void RearrangeWeightsToOHWIOGroupI4O4( @@ -198,25 +177,43 @@ void RearrangeWeightsToI4DHWIOOGroupO4( } } -// Returns float4 mask for last plane(batch of 4 channels) -// assumes that plane size is 4; -// for example we have 7 channels, in our data structures we align it to 8 -// but 8s-channel will be empty, then last plane (batch of 4 channels) will -// have this mask (1, 1, 1, 0). -float4 GetMaskForLastPlane(int channels); +template <DataType S, typename T> +void RearrangeWeightsToOICustomSpatialI4O4( + const tflite::gpu::Tensor<OHWI, S>& weights, + const std::vector<int>& spatial_remap, absl::Span<T> dst) { + const int dst_slices = DivideRoundUp(weights.shape.o, 4); + const int src_slices = DivideRoundUp(weights.shape.i, 4); -// returns first work group from wgs that has size not bigger than max_wg_size -// if no suitable groups among wgs, returns {1, 1, 1} -int3 GetFirstSuitableWorkGroup(const std::vector<int3>& wgs, int max_wg_size); + int counter = 0; + for (int d = 0; d < dst_slices; ++d) { + for (int s = 0; s < src_slices; ++s) { + for (int y = 0; y < weights.shape.h; ++y) { + for (int x = 0; x < weights.shape.w; ++x) { + const int kernel_index = spatial_remap[y * weights.shape.w + x]; + const int kernel_index_x = kernel_index % weights.shape.w; + const int kernel_index_y = kernel_index / weights.shape.w; + for (int i = 0; i < 4; ++i) { + T filter; + for (int j = 0; j < 4; ++j) { + const int s_ch = s * 4 + i; + const int d_ch = d * 4 + j; + if (s_ch < weights.shape.i && d_ch < weights.shape.o) { + const int f_index = weights.shape.LinearIndex( + {d_ch, kernel_index_y, kernel_index_x, s_ch}); + filter[j] = weights.data[f_index]; + } else { + filter[j] = 0.0f; + } + } + dst[counter++] = filter; + } + } + } + } + } +} -// task_size as amount of FLT4 processed elements. 
-int GetRecommendedBlockSizeForConv(const GpuInfo& gpu_info, - CalculationsPrecision precision, - int task_size); - -int3 GetWorkGroupsCount(const int3& grid_size, const int3& work_group_size); -} // namespace cl } // namespace gpu } // namespace tflite -#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_UTIL_H_ +#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_WEIGHTS_CONVERSION_H_ diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_common.h b/tensorflow/lite/delegates/gpu/common/task/weights_layout.h similarity index 64% rename from tensorflow/lite/delegates/gpu/cl/kernels/conv_common.h rename to tensorflow/lite/delegates/gpu/common/task/weights_layout.h index f630c9d1f1c..cf22630ee27 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_common.h +++ b/tensorflow/lite/delegates/gpu/common/task/weights_layout.h @@ -13,25 +13,29 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_CONV_COMMON_H_ -#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_CONV_COMMON_H_ +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_WEIGHTS_LAYOUT_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_WEIGHTS_LAYOUT_H_ + +#include <vector> namespace tflite { namespace gpu { -namespace cl { -enum class ConvWeightsLayout { +enum class WeightsLayout { kUnknown, kOHWIOGroupI4O4, + kOICustomSSpatialI4O4, }; -struct ConvWeightsDescription { - ConvWeightsLayout layout; +struct WeightsDescription { + WeightsLayout layout; + // applicable with kOHWIOGroupI4O4 int output_group_size; + // applicable with kOICustomSSpatialI4O4 + std::vector<int> spatial_remap; }; -} // namespace cl } // namespace gpu } // namespace tflite -#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_CONV_COMMON_H_ +#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_WEIGHTS_LAYOUT_H_ diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.cc b/tensorflow/lite/delegates/gpu/common/task/work_group_picking.cc similarity index 90% rename from tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.cc rename to tensorflow/lite/delegates/gpu/common/task/work_group_picking.cc index 0b7ec8ed683..765bf1d1a0d 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.cc +++ b/tensorflow/lite/delegates/gpu/common/task/work_group_picking.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" +#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h" #include <algorithm> #include <limits> @@ -24,7 +24,6 @@ limitations under the License. 
namespace tflite { namespace gpu { -namespace cl { namespace { @@ -52,9 +51,9 @@ std::vector<int3> GenerateWorkGroupSizesXYMultipleOf( if (work_group_size_xy * z > kernel_info.max_work_group_size) { continue; } - if (x <= gpu_info.max_work_group_size_x && - y <= gpu_info.max_work_group_size_y && - z <= gpu_info.max_work_group_size_z) { + if (x <= gpu_info.GetMaxWorkGroupSizeForX() && + y <= gpu_info.GetMaxWorkGroupSizeForY() && + z <= gpu_info.GetMaxWorkGroupSizeForZ()) { work_groups.push_back({x, y, z}); } } @@ -78,9 +77,9 @@ std::vector<int3> GenerateWorkGroupSizesXMultipleOf( x += multiplier) { for (auto y : possible_y_sizes) { for (auto z : possible_z_sizes) { - if (x <= gpu_info.max_work_group_size_x && - y <= gpu_info.max_work_group_size_y && - z <= gpu_info.max_work_group_size_z && + if (x <= gpu_info.GetMaxWorkGroupSizeForX() && + y <= gpu_info.GetMaxWorkGroupSizeForY() && + z <= gpu_info.GetMaxWorkGroupSizeForZ() && x * y * z <= kernel_info.max_work_group_size) { work_groups.push_back({x, y, z}); } @@ -94,9 +93,9 @@ void GetWorkGroupsAlignedToGrid(const GpuInfo& gpu_info, const KernelInfo& kernel_info, const int3& grid, std::vector<int3>* work_groups) { int3 max_wg_size; - max_wg_size.x = gpu_info.max_work_group_size_x; - max_wg_size.y = gpu_info.max_work_group_size_y; - max_wg_size.z = gpu_info.max_work_group_size_z; + max_wg_size.x = gpu_info.GetMaxWorkGroupSizeForX(); + max_wg_size.y = gpu_info.GetMaxWorkGroupSizeForY(); + max_wg_size.z = gpu_info.GetMaxWorkGroupSizeForZ(); GenerateWorkGroupSizesAlignedToGrid( grid, max_wg_size, kernel_info.max_work_group_size, work_groups); } @@ -275,7 +274,7 @@ void GetPossibleWorkGroupsConv(TuningType tuning_type, const GpuInfo& gpu_info, if (gpu_info.IsAdreno()) { max_z_size = gpu_info.adreno_info.IsAdreno3xx() ? 16 : 64; } - max_z_size = std::min(max_z_size, gpu_info.max_work_group_size_z); + max_z_size = std::min(max_z_size, gpu_info.GetMaxWorkGroupSizeForZ()); work_groups->push_back( GetWorkGroupConv(grid, kernel_info.max_work_group_size, max_z_size)); return; @@ -290,6 +289,15 @@ void GetPossibleWorkGroupsConv(TuningType tuning_type, const GpuInfo& gpu_info, } } -} // namespace cl +int3 GetFirstSuitableWorkGroup(const std::vector<int3>& wgs, int max_wg_size) { + for (const auto& wg : wgs) { + const int wg_size = wg.x * wg.y * wg.z; + if (wg_size <= max_wg_size) { + return wg; + } + } + return {1, 1, 1}; +} + } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h b/tensorflow/lite/delegates/gpu/common/task/work_group_picking.h similarity index 81% rename from tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h rename to tensorflow/lite/delegates/gpu/common/task/work_group_picking.h index 8135aec8855..508bffc762d 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h +++ b/tensorflow/lite/delegates/gpu/common/task/work_group_picking.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_WORK_GROUP_PICKING_H_ -#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_WORK_GROUP_PICKING_H_ +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_WORK_GROUP_PICKING_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_WORK_GROUP_PICKING_H_ #include <vector> -#include "tensorflow/lite/delegates/gpu/cl/device_info.h" +#include "tensorflow/lite/delegates/gpu/common/gpu_info.h" #include "tensorflow/lite/delegates/gpu/common/kernel_info.h" #include "tensorflow/lite/delegates/gpu/common/task/tuning_type.h" #include "tensorflow/lite/delegates/gpu/common/types.h" @@ -26,7 +26,6 @@ limitations under the License. namespace tflite { namespace gpu { -namespace cl { // multiplier can be power of two only void GetPossibleWorkGroupsXYMultipleOf(int multiplier, const GpuInfo& gpu_info, @@ -56,8 +55,11 @@ void GetPossibleWorkGroupsConv(TuningType tuning_type, const GpuInfo& gpu_info, const KernelInfo& kernel_info, const int3& grid, std::vector<int3>* work_groups); -} // namespace cl +// returns first work group from wgs that has size not bigger than max_wg_size +// if no suitable groups among wgs, returns {1, 1, 1} +int3 GetFirstSuitableWorkGroup(const std::vector<int3>& wgs, int max_wg_size); + } // namespace gpu } // namespace tflite -#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_WORK_GROUP_PICKING_H_ +#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_WORK_GROUP_PICKING_H_ diff --git a/tensorflow/lite/delegates/gpu/common/winograd_util.cc b/tensorflow/lite/delegates/gpu/common/winograd_util.cc index 4b9581d0f39..7d7f96bbb1b 100644 --- a/tensorflow/lite/delegates/gpu/common/winograd_util.cc +++ b/tensorflow/lite/delegates/gpu/common/winograd_util.cc @@ -149,5 +149,10 @@ void RearrangeWeightsToWinograd4x4To6x6Weights( } } +bool IsSuitableForWinograd4x4To6x6(const Convolution2DAttributes& attr) { + return attr.weights.shape.w == 3 && attr.weights.shape.h == 3 && + attr.dilations == HW(1, 1) && attr.strides == HW(1, 1); +} + } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/winograd_util.h b/tensorflow/lite/delegates/gpu/common/winograd_util.h index e88ceacb490..55629696b58 100644 --- a/tensorflow/lite/delegates/gpu/common/winograd_util.h +++ b/tensorflow/lite/delegates/gpu/common/winograd_util.h @@ -19,6 +19,7 @@ limitations under the License. #include <vector> #include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" @@ -38,6 +39,8 @@ void RearrangeWeightsToWinograd4x4To6x6Weights( const Tensor<OHWI, DataType::FLOAT32>& src_weights, Tensor<OHWI, DataType::FLOAT32>* dst_weights); +bool IsSuitableForWinograd4x4To6x6(const Convolution2DAttributes& attr); + } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/winograd_util_test.cc b/tensorflow/lite/delegates/gpu/common/winograd_util_test.cc new file mode 100644 index 00000000000..81fb643d399 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/winograd_util_test.cc @@ -0,0 +1,45 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/common/winograd_util.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +namespace tflite { +namespace gpu { + +TEST(Winograd, CorrectAttributesFor4x4To6x6) { + Convolution2DAttributes attr; + attr.padding.prepended = HW(1, 2); + attr.padding.appended = HW(0, 1); + attr.strides = HW(1, 1); + attr.dilations = HW(1, 1); + attr.weights.shape = OHWI(1, 3, 3, 1); + EXPECT_TRUE(IsSuitableForWinograd4x4To6x6(attr)); +} + +TEST(Winograd, IncorrectAttributesFor4x4To6x6) { + Convolution2DAttributes attr; + attr.padding.prepended = HW(1, 2); + attr.padding.appended = HW(0, 1); + attr.strides = HW(1, 1); + attr.dilations = HW(1, 1); + attr.weights.shape = OHWI(1, 2, 3, 1); + EXPECT_FALSE(IsSuitableForWinograd4x4To6x6(attr)); +} + +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/workgroup_selection.cc b/tensorflow/lite/delegates/gpu/common/workgroup_selection.cc index 439eb0ade90..157a8992f71 100644 --- a/tensorflow/lite/delegates/gpu/common/workgroup_selection.cc +++ b/tensorflow/lite/delegates/gpu/common/workgroup_selection.cc @@ -189,15 +189,15 @@ template std::vector<uint3> GenerateWorkGroupSizes( template <typename T> void GenerateWorkGroupSizesAlignedToGrid(const T& grid, const T& max_work_group_size, - const int max_work_group_invocations, + const int max_work_group_total_size, std::vector<T>* work_groups) { auto alignment = WorkGroupSizeAlignment::PRECISE; *work_groups = GenerateWorkGroupSizes<T>( - grid, /*min_work_group_total_size = */ 32, max_work_group_invocations, + grid, /*min_work_group_total_size = */ 32, max_work_group_total_size, max_work_group_size, alignment, alignment, alignment); // If the grid parameter too small, method below cannot generate workgroups. 
if (work_groups->empty()) { - AddCornerCases(grid, max_work_group_invocations, max_work_group_size, + AddCornerCases(grid, max_work_group_total_size, max_work_group_size, alignment, alignment, alignment, work_groups); } } @@ -206,11 +206,11 @@ void GenerateWorkGroupSizesAlignedToGrid(const T& grid, template void GenerateWorkGroupSizesAlignedToGrid( const int3& grid, const int3& max_work_group_size, - const int max_work_group_invocations, std::vector<int3>* work_groups); + const int max_work_group_total_size, std::vector<int3>* work_groups); template void GenerateWorkGroupSizesAlignedToGrid( const uint3& grid, const uint3& max_work_group_size, - const int max_work_group_invocations, std::vector<uint3>* work_groups); + const int max_work_group_total_size, std::vector<uint3>* work_groups); } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/workgroup_selection.h b/tensorflow/lite/delegates/gpu/common/workgroup_selection.h index 67c51b45177..30e4d8dc33c 100644 --- a/tensorflow/lite/delegates/gpu/common/workgroup_selection.h +++ b/tensorflow/lite/delegates/gpu/common/workgroup_selection.h @@ -41,7 +41,7 @@ std::vector<T> GenerateWorkGroupSizes( template <typename T> void GenerateWorkGroupSizesAlignedToGrid(const T& grid, const T& max_work_group_size, - const int max_work_group_invocations, + const int max_work_group_total_size, std::vector<T>* work_groups); } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/gl/compiler.cc b/tensorflow/lite/delegates/gpu/gl/compiler.cc index 4b19082e4aa..20c93d6216a 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler.cc @@ -43,25 +43,27 @@ namespace gl { namespace { struct ExceedSizeChecker { - bool operator()(uint32_t v) const { return v > max_size; } + bool operator()(uint32_t v) const { return v > max_size.x; } bool operator()(const uint2& v) const { - return v.x > max_size || v.y > max_size; + return v.x > max_size.x || v.y > max_size.y; } bool operator()(const uint3& v) const { - return v.x > max_size || v.y > max_size || v.z > max_z_size; + return v.x > max_size.x || v.y > max_size.y || v.z > max_z_size; } - int max_size; + int2 max_size; int max_z_size; }; // Returns true if any size variable exceeds the given limit bool ExceedsMaxSize(const Object& object, const GpuInfo& gpu_info) { - return absl::visit(ExceedSizeChecker{gpu_info.max_texture_size, - gpu_info.max_array_texture_layers}, - object.size); + ExceedSizeChecker size_checker; + size_checker.max_size = + int2(gpu_info.GetMaxImage2DWidth(), gpu_info.GetMaxImage2DHeight()); + size_checker.max_z_size = gpu_info.GetMaxImage2DArrayLayers(); + return absl::visit(size_checker, object.size); } ObjectType ChooseFastestObjectType(const GpuInfo& gpu_info) { diff --git a/tensorflow/lite/delegates/gpu/gl/request_gpu_info.cc b/tensorflow/lite/delegates/gpu/gl/request_gpu_info.cc index 4f46a3586f5..598a27c0ea7 100644 --- a/tensorflow/lite/delegates/gpu/gl/request_gpu_info.cc +++ b/tensorflow/lite/delegates/gpu/gl/request_gpu_info.cc @@ -59,27 +59,33 @@ absl::Status RequestGpuInfo(GpuInfo* gpu_info) { GLint extensions_count; glGetIntegerv(GL_NUM_EXTENSIONS, &extensions_count); - info.extensions.resize(extensions_count); + info.opengl_info.extensions.resize(extensions_count); for (int i = 0; i < extensions_count; ++i) { - info.extensions[i] = std::string( + info.opengl_info.extensions[i] = std::string( reinterpret_cast<const char*>(glGetStringi(GL_EXTENSIONS, i))); } 
glGetIntegerv(GL_MAX_COMPUTE_SHADER_STORAGE_BLOCKS, &info.opengl_info.max_ssbo_bindings); glGetIntegerv(GL_MAX_COMPUTE_IMAGE_UNIFORMS, &info.opengl_info.max_image_bindings); - info.max_work_group_size.resize(3); glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_SIZE, 0, - &info.max_work_group_size[0]); + &info.opengl_info.max_compute_work_group_size_x); glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_SIZE, 1, - &info.max_work_group_size[1]); + &info.opengl_info.max_compute_work_group_size_y); glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_SIZE, 2, - &info.max_work_group_size[2]); + &info.opengl_info.max_compute_work_group_size_z); + info.max_work_group_size.push_back( + info.opengl_info.max_compute_work_group_size_x); + info.max_work_group_size.push_back( + info.opengl_info.max_compute_work_group_size_y); + info.max_work_group_size.push_back( + info.opengl_info.max_compute_work_group_size_z); glGetIntegerv(GL_MAX_COMPUTE_WORK_GROUP_INVOCATIONS, - &info.max_work_group_invocations); - glGetIntegerv(GL_MAX_TEXTURE_SIZE, &info.max_texture_size); + &info.opengl_info.max_work_group_invocations); + glGetIntegerv(GL_MAX_TEXTURE_SIZE, &info.opengl_info.max_texture_size); glGetIntegerv(GL_MAX_IMAGE_UNITS, &info.opengl_info.max_image_units); - glGetIntegerv(GL_MAX_ARRAY_TEXTURE_LAYERS, &info.max_array_texture_layers); + glGetIntegerv(GL_MAX_ARRAY_TEXTURE_LAYERS, + &info.opengl_info.max_array_texture_layers); RETURN_IF_ERROR(GetOpenGlErrors()); *gpu_info = info; return absl::OkStatus(); diff --git a/tensorflow/lite/delegates/gpu/gl/workgroups/calculator.cc b/tensorflow/lite/delegates/gpu/gl/workgroups/calculator.cc index e21538b22a5..54252dc4fc8 100644 --- a/tensorflow/lite/delegates/gpu/gl/workgroups/calculator.cc +++ b/tensorflow/lite/delegates/gpu/gl/workgroups/calculator.cc @@ -29,26 +29,26 @@ uint64_t CalculateProduct(const uint3& value) { } void MaybeShrinkWorkgroup(const GpuInfo& gpu_info, uint3* wg) { - while (wg->x > gpu_info.max_work_group_size[0]) { + while (wg->x > gpu_info.GetMaxWorkGroupSizeForX()) { wg->x /= 2; } - while (wg->y > gpu_info.max_work_group_size[1]) { + while (wg->y > gpu_info.GetMaxWorkGroupSizeForY()) { wg->y /= 2; } - while (wg->z > gpu_info.max_work_group_size[2]) { + while (wg->z > gpu_info.GetMaxWorkGroupSizeForZ()) { wg->z /= 2; } // Code below decreases amount of invocations per workgroup in a balanced way. // As example, workgroup size is x=16, y=8, z=8 (16x8x8 = 1024), but - // max_work_group_invocations = 512. We need to fit this limit and we can + // max_work_group_total_size = 512. We need to fit this limit and we can // reduce workgroup size in different ways, but we want to use the most // balanced way. So code below will find the maximal of three dimensions and // reduce it, so the whole workgroup is kept balanced by all dimensions. And // the final reduced workgroup will be x=8, y=8, z=8 for the given example. 
- while (CalculateProduct(*wg) > gpu_info.max_work_group_invocations) { + while (CalculateProduct(*wg) > gpu_info.GetMaxWorkGroupTotalSize()) { unsigned int* max = &wg->x; if (wg->y > *max) max = &wg->y; if (wg->z > *max) max = &wg->z; diff --git a/tensorflow/lite/delegates/gpu/metal/api.cc b/tensorflow/lite/delegates/gpu/metal/api.cc index 8062fb16d7c..bd6e9caa342 100644 --- a/tensorflow/lite/delegates/gpu/metal/api.cc +++ b/tensorflow/lite/delegates/gpu/metal/api.cc @@ -294,10 +294,14 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node, node_id, inputs[0], inputs[1], outputs[0], absl::any_cast<MaxUnpooling2DAttributes>(node->operation.attributes)); break; - case OperationType::MEAN: - *tasks = Mean(node_id, inputs[0], outputs[0], - absl::any_cast<MeanAttributes>(node->operation.attributes)); + case OperationType::MEAN: { + auto attr = absl::any_cast<MeanAttributes>(node->operation.attributes); + if (attr.dims != std::set<Axis>({Axis::HEIGHT, Axis::WIDTH})) { + return absl::UnimplementedError("Mean supports HW axis only in Metal"); + } + *tasks = Mean(node_id, inputs[0], outputs[0], attr); break; + } case OperationType::MUL: if (inputs.size() == 1) { if (node->operation.attributes.has_value()) { diff --git a/tensorflow/lite/experimental/examples/lstm/rnn.py b/tensorflow/lite/experimental/examples/lstm/rnn.py index 1f55538bda5..aed003982dd 100644 --- a/tensorflow/lite/experimental/examples/lstm/rnn.py +++ b/tensorflow/lite/experimental/examples/lstm/rnn.py @@ -34,11 +34,14 @@ from tensorflow.python.ops.rnn import _best_effort_input_batch_size from tensorflow.python.ops.rnn import _dynamic_rnn_loop from tensorflow.python.ops.rnn import _should_cache from tensorflow.python.ops.rnn import _transpose_batch_time +from tensorflow.python.util import deprecation from tensorflow.python.util import nest from tensorflow.python.util.tf_export import tf_export @tf_export(v1=["lite.experimental.nn.dynamic_rnn"]) +@deprecation.deprecated( + None, "Use `keras.layers.LSTM` instead.") def dynamic_rnn(cell, inputs, sequence_length=None, diff --git a/tensorflow/lite/experimental/examples/lstm/rnn_cell.py b/tensorflow/lite/experimental/examples/lstm/rnn_cell.py index 9736719c997..3c609af9b04 100644 --- a/tensorflow/lite/experimental/examples/lstm/rnn_cell.py +++ b/tensorflow/lite/experimental/examples/lstm/rnn_cell.py @@ -32,10 +32,13 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import deprecation from tensorflow.python.util.tf_export import tf_export @tf_export(v1=["lite.experimental.nn.TfLiteRNNCell"]) +@deprecation.deprecated( + None, "Use `keras.layers.RNN` instead for TF2.x.") class TfLiteRNNCell(rnn_cell_impl.LayerRNNCell): """The most basic RNN cell. @@ -159,6 +162,8 @@ class TfLiteRNNCell(rnn_cell_impl.LayerRNNCell): @tf_export(v1=["lite.experimental.nn.TFLiteLSTMCell"]) +@deprecation.deprecated( + None, "Use `keras.layers.LSTM` instead.") class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell): """Long short-term memory unit (LSTM) recurrent network cell. 
diff --git a/tensorflow/lite/experimental/tflite_api_dispatcher/BUILD b/tensorflow/lite/experimental/tflite_api_dispatcher/BUILD deleted file mode 100644 index bb64be61599..00000000000 --- a/tensorflow/lite/experimental/tflite_api_dispatcher/BUILD +++ /dev/null @@ -1,24 +0,0 @@ -load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable") - -package( - default_visibility = ["//tensorflow:internal"], - licenses = ["notice"], # Apache 2.0 -) - -cc_library( - name = "tflite_api_dispatcher", - hdrs = ["tflite_api_dispatcher.h"], - compatible_with = get_compatible_with_portable(), - deps = [ - "//tensorflow/lite:framework_lib", - ], -) - -cc_library( - name = "tflite_api_dispatcher_with_kernels", - hdrs = ["tflite_api_dispatcher.h"], - deps = [ - ":tflite_api_dispatcher", - "//tensorflow/lite:framework_lib", - ], -) diff --git a/tensorflow/lite/experimental/tflite_api_dispatcher/tflite_api_dispatcher.h b/tensorflow/lite/experimental/tflite_api_dispatcher/tflite_api_dispatcher.h deleted file mode 100644 index 55771ed9673..00000000000 --- a/tensorflow/lite/experimental/tflite_api_dispatcher/tflite_api_dispatcher.h +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -// The purpose of this file is to indirect how implementations of the TensorFlow -// Lite API are selected by providing a single namespace tflite_api_dispatcher. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_TFLITE_API_DISPATCHER_TFLITE_API_DISPATCHER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_TFLITE_API_DISPATCHER_TFLITE_API_DISPATCHER_H_ - -// Import the relevant interpreter and model files. -#include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/model.h" - -namespace tflite_api_dispatcher { - -// Use the correct interpreter. 
-using tflite::Interpreter; -using tflite::InterpreterBuilder; -using TfLiteModel = tflite::FlatBufferModel; -using TfLiteVerifier = tflite::TfLiteVerifier; - -} // namespace tflite_api_dispatcher - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_TFLITE_API_DISPATCHER_TFLITE_API_DISPATCHER_H_ diff --git a/tensorflow/lite/g3doc/_book.yaml b/tensorflow/lite/g3doc/_book.yaml index bf52cb3c23d..84a09824613 100644 --- a/tensorflow/lite/g3doc/_book.yaml +++ b/tensorflow/lite/g3doc/_book.yaml @@ -42,7 +42,7 @@ upper_tabs: - heading: "Text" - title: "Text classification with Model Maker" path: /lite/tutorials/model_maker_text_classification - - title: "Question Answer with Model Maker" + - title: "BERT Question Answer with Model Maker" path: /lite/tutorials/model_maker_question_answer - heading: "Microcontrollers" @@ -110,7 +110,7 @@ upper_tabs: - heading: "Run Inference with metadata" - title: "Overview" path: /lite/inference_with_metadata/overview - - title: "Generate model interfaces with codegen" + - title: "Generate model interfaces using metadata" path: /lite/inference_with_metadata/codegen - title: "Integrate models with Task Library" path: /lite/inference_with_metadata/task_library/overview @@ -215,7 +215,7 @@ upper_tabs: - title: "Style transfer" path: /lite/models/style_transfer/overview - heading: "Text" - - title: "Question and answer" + - title: "BERT Question Answer" path: /lite/models/bert_qa/overview - title: "Smart reply" path: /lite/models/smart_reply/overview diff --git a/tensorflow/lite/g3doc/convert/metadata.md b/tensorflow/lite/g3doc/convert/metadata.md index a6670f10aba..41a92b4b1ca 100644 --- a/tensorflow/lite/g3doc/convert/metadata.md +++ b/tensorflow/lite/g3doc/convert/metadata.md @@ -472,15 +472,39 @@ public QuantizationParams getoutputTensorQuantizationParams(int inputIndex); public int[] getoutputTensorShape(int inputIndex); ``` -You can also read associated files through their names with the -`getAssociatedFile` method: - -```java -public InputStream getAssociatedFile(String fileName); -``` - Though the [TensorFlow Lite model schema](https://github.com/tensorflow/tensorflow/blob/aa7ff6aa28977826e7acae379e82da22482b2bf2/tensorflow/lite/schema/schema.fbs#L1075) supports multiple subgraphs, the TFLite Interpreter currently only supports a single subgraph. Therefore, `MetadataExtractor` omits subgraph index as an input argument in its methods. + +## Read the associated files from models + +The TensorFlow Lite model with metadata and associated files is essentially a +zip file that can be unpacked with common zip tools to get the associated files. +For example, you can unzip +[mobilenet_v1_0.75_160_quantized](https://tfhub.dev/tensorflow/lite-model/mobilenet_v1_0.75_160_quantized/1/metadata/1) +and extract the label file in the model as follows: + +```sh +$ unzip mobilenet_v1_0.75_160_quantized_1_metadata_1.tflite +Archive: mobilenet_v1_0.75_160_quantized_1_metadata_1.tflite + extracting: labels.txt +``` + +You can also read associated files through the Metadata Extractor library. 
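For illustration, and independent of the Metadata Extractor library whose Java and C++ entry points are listed next, the packed archive can also be read programmatically with any standard zip module. The sketch below is not part of the official API surface; it assumes the model file from the `unzip` example above and a packed `labels.txt`.

```python
import zipfile

# A TFLite model with metadata and associated files is also a valid zip
# archive, so the packed files can be listed and read with the standard
# library alone.
MODEL_PATH = "mobilenet_v1_0.75_160_quantized_1_metadata_1.tflite"

with zipfile.ZipFile(MODEL_PATH) as model_zip:
    print(model_zip.namelist())  # e.g. ['labels.txt']
    with model_zip.open("labels.txt") as f:
        labels = [line.decode("utf-8").strip() for line in f]

print(labels[:5])
```

The library calls below achieve the same result without unpacking the archive by hand.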
+ +In Java, pass the file name into the `MetadataExtractor.getAssociatedFile` +method: + +```java +public InputStream getAssociatedFile(String fileName); +``` + +Similarly, in C++, this can be done with the +`ModelMetadataExtractor::GetAssociatedFile` method: + +```c++ +tflite::support::StatusOr<absl::string_view> GetAssociatedFile( + const std::string& filename) const; +``` diff --git a/tensorflow/lite/g3doc/guide/android.md b/tensorflow/lite/g3doc/guide/android.md index 420269de941..e36e8b1bcc5 100644 --- a/tensorflow/lite/g3doc/guide/android.md +++ b/tensorflow/lite/g3doc/guide/android.md @@ -41,6 +41,34 @@ as a starting point. The following sections contain some useful information for working with TensorFlow Lite on Android. +### Use Android Studio ML Model Binding + +Note: Requires [Android Studio 4.1](https://developer.android.com/studio) or +above. + +To import a TensorFlow Lite (TFLite) model: + +1. Right-click on the module in which you would like to use the TFLite model, or click on + `File`, then `New` > `Other` > `TensorFlow Lite Model` +  + +1. Select the location of your TFLite file. Note that the tooling will + configure the module's dependency on your behalf with ML Model Binding, and + all required dependencies will be automatically inserted into your Android module's + `build.gradle` file. + + Optional: Select the second checkbox for importing TensorFlow GPU if you + want to use [GPU acceleration](../performance/gpu). +  + +1. Click `Finish`. + +1. The following screen will appear after the import is successful. To start + using the model, select Kotlin or Java, copy and paste the code under the + `Sample Code` section. You can get back to this screen by double-clicking + the TFLite model under the `ml` directory in Android Studio. +  + ### Use the TensorFlow Lite Task Library TensorFlow Lite Task Library contains a set of powerful and easy-to-use diff --git a/tensorflow/lite/g3doc/guide/model_maker.md b/tensorflow/lite/g3doc/guide/model_maker.md index 956bd127bcf..7db650d97a0 100644 --- a/tensorflow/lite/g3doc/guide/model_maker.md +++ b/tensorflow/lite/g3doc/guide/model_maker.md @@ -15,7 +15,7 @@ Supported Tasks -------------------------------------------------------------------------------------------------------- | ------------ Image Classification [guide](https://www.tensorflow.org/lite/tutorials/model_maker_image_classification) | Classify images into predefined categories. Text Classification [guide](https://www.tensorflow.org/lite/tutorials/model_maker_text_classification) | Classify text into predefined categories. -Question Answer [guide](https://www.tensorflow.org/lite/tutorials/model_maker_question_answer) | Find the answer in a certain context for a given question. +BERT Question Answer [guide](https://www.tensorflow.org/lite/tutorials/model_maker_question_answer) | Find the answer in a certain context for a given question with BERT. 
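As a rough orientation for the BERT Question Answer row above, the flow described in the linked Model Maker guide can be sketched as follows. This is a hedged sketch, not the tutorial itself: the spec name `mobilebert_qa_squad`, the `DataLoader.from_squad` helper, and the SQuAD-style file names are assumptions that should be checked against the guide.

```python
# Illustrative sketch of the Model Maker BERT Question Answer flow; module and
# function names are assumptions and may differ between Model Maker releases.
from tflite_model_maker import model_spec
from tflite_model_maker import question_answer

spec = model_spec.get("mobilebert_qa_squad")  # assumed spec name

# SQuAD-format train/validation files (assumed names for this sketch).
train_data = question_answer.DataLoader.from_squad(
    "train-v1.1.json", spec, is_training=True)
validation_data = question_answer.DataLoader.from_squad(
    "dev-v1.1.json", spec, is_training=False)

model = question_answer.create(train_data, model_spec=spec)
model.evaluate(validation_data)

# Exports a model.tflite with metadata, ready for on-device inference.
model.export(export_dir=".")
```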
## End-to-End Example diff --git a/tensorflow/lite/g3doc/images/android/import_dialog.png b/tensorflow/lite/g3doc/images/android/import_dialog.png new file mode 100644 index 00000000000..adaaa596a52 Binary files /dev/null and b/tensorflow/lite/g3doc/images/android/import_dialog.png differ diff --git a/tensorflow/lite/g3doc/images/android/model_details.png b/tensorflow/lite/g3doc/images/android/model_details.png new file mode 100644 index 00000000000..952d8489e94 Binary files /dev/null and b/tensorflow/lite/g3doc/images/android/model_details.png differ diff --git a/tensorflow/lite/g3doc/images/android/right_click_menu.png b/tensorflow/lite/g3doc/images/android/right_click_menu.png new file mode 100644 index 00000000000..64ff7c46d19 Binary files /dev/null and b/tensorflow/lite/g3doc/images/android/right_click_menu.png differ diff --git a/tensorflow/lite/g3doc/inference_with_metadata/codegen.md b/tensorflow/lite/g3doc/inference_with_metadata/codegen.md index b447573da41..2123737c76a 100644 --- a/tensorflow/lite/g3doc/inference_with_metadata/codegen.md +++ b/tensorflow/lite/g3doc/inference_with_metadata/codegen.md @@ -1,4 +1,121 @@ -# Generate model interfaces with TensorFlow Lite code generator +# Generate model interfaces using metadata + +Using [TensorFlow Lite Metadata](../convert/metadata), developers can generate +wrapper code to enable integration on Android. For most developers, the +graphical interface of [Android Studio ML Model Binding](#mlbinding) is the +easiest to use. If you require more customisation or are using command line +tooling, the [TensorFlow Lite Codegen](#codegen) is also available. + +## Use Android Studio ML Model Binding {:#mlbinding} + +For TensorFlow Lite models enhanced with [metadata](../convert/metadata.md), +developers can use Android Studio ML Model Binding to automatically configure +settings for the project and generate wrapper classes based on the model +metadata. The wrapper code removes the need to interact directly with +`ByteBuffer`. Instead, developers can interact with the TensorFlow Lite model +with typed objects such as `Bitmap` and `Rect`. + +Note: Requires [Android Studio 4.1](https://developer.android.com/studio) or +above. + +### Import a TensorFlow Lite model in Android Studio + +1. Right-click on the module in which you would like to use the TFLite model, or click on + `File`, then `New` > `Other` > `TensorFlow Lite Model` +  + +1. Select the location of your TFLite file. Note that the tooling will + configure the module's dependency on your behalf with ML Model Binding, and + all required dependencies will be automatically inserted into your Android module's + `build.gradle` file. + + Optional: Select the second checkbox for importing TensorFlow GPU if you + want to use GPU acceleration. +  + +1. Click `Finish`. + +1. The following screen will appear after the import is successful. To start + using the model, select Kotlin or Java, copy and paste the code under the + `Sample Code` section. You can get back to this screen by double-clicking + the TFLite model under the `ml` directory in Android Studio. +  + +### Accelerating model inference {:#acceleration} + +ML Model Binding provides a way for developers to accelerate their code through +the use of delegates and by configuring the number of threads. + +Note: The TensorFlow Lite Interpreter must be created on the same thread as +where it is run. Otherwise, `TfLiteGpuDelegate Invoke: GpuDelegate must run on the same +thread where it was initialized.` may occur. + +Step 1. 
Check that the module's `build.gradle` file contains the following +dependency: + +```java + dependencies { + ... + // TFLite GPU delegate 2.3.0 or above is required. + implementation 'org.tensorflow:tensorflow-lite-gpu:2.3.0' + } +``` + +Step 2. Check whether the GPU on the device is compatible with the TensorFlow GPU +delegate; if it is not, run the model using multiple CPU threads: + +<div> + <devsite-selector> + <section> + <h3>Kotlin</h3> + <p><pre class="prettyprint lang-kotlin"> + import org.tensorflow.lite.gpu.CompatibilityList + import org.tensorflow.lite.gpu.GpuDelegate + + val compatList = CompatibilityList() + + val options = if(compatList.isDelegateSupportedOnThisDevice) { + // if the device has a supported GPU, add the GPU delegate + Model.Options.Builder().setDevice(Model.Device.GPU).build() + } else { + // if the GPU is not supported, run on 4 threads + Model.Options.Builder().setNumThreads(4).build() + } + + // Initialize the model as usual feeding in the options object + val myModel = MyModel.newInstance(context, options) + + // Run inference per sample code + </pre></p> + </section> + <section> + <h3>Java</h3> + <p><pre class="prettyprint lang-java"> + import org.tensorflow.lite.support.model.Model; + import org.tensorflow.lite.gpu.CompatibilityList; + import org.tensorflow.lite.gpu.GpuDelegate; + + // Initialize interpreter with GPU delegate + Model.Options options; + CompatibilityList compatList = new CompatibilityList(); + + if(compatList.isDelegateSupportedOnThisDevice()){ + // if the device has a supported GPU, add the GPU delegate + options = Model.Options.Builder().setDevice(Model.Device.GPU).build(); + } else { + // if the GPU is not supported, run on 4 threads + options = Model.Options.Builder().setNumThreads(4).build(); + } + + MyModel myModel = MyModel.newInstance(context, options); + + // Run inference per sample code + </pre></p> + </section> + </devsite-selector> +</div> + +## Generate model interfaces with TensorFlow Lite code generator {:#codegen} Note: TensorFlow Lite wrapper code generator currently only supports Android. @@ -14,7 +131,7 @@ under relevant fields in [metadata_schema.fbs](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/metadata/metadata_schema.fbs), to see how the codegen tool parses each field. -## Generate wrapper Code +### Generate wrapper Code You will need to install the following tooling in your terminal: @@ -45,9 +162,9 @@ from google.colab import files files.download('classify_wrapper.zip') ``` -## Using the generated code +### Using the generated code -### Step 1: Import the generated code +#### Step 1: Import the generated code Unzip the generated code if necessary into a directory structure. The root of the generated code is assumed to be `SRC_ROOT`. @@ -59,7 +176,7 @@ select `SRC_ROOT` Using the above example, the directory and the module imported would be called `classify_wrapper`. -### Step 2: Update the app's `build.gradle` file +#### Step 2: Update the app's `build.gradle` file In the app module that will be consuming the generated library module: @@ -77,7 +194,7 @@ Under the dependencies section, add the following: implementation project(":classify_wrapper") ``` -### Step 3: Using the model +#### Step 3: Using the model ```java // 1. 
Initialize the model @@ -103,7 +220,7 @@ if(null != myImageClassifier) { } ``` -## Accelerating model inference +### Accelerating model inference The generated code provides a way for developers to accelerate their code through the use of [delegates](../performance/delegates.md) and the number of @@ -127,7 +244,7 @@ try { } ``` -## Troubleshooting +### Troubleshooting If you get a 'java.io.FileNotFoundException: This file can not be opened as a file descriptor; it is probably compressed' error, insert the following lines @@ -138,16 +255,3 @@ aaptOptions { noCompress "tflite" } ``` - -## Generate code with Android Studio ML Model Binding - -[Android Studio ML Model Binding](https://developer.android.com/studio/preview/features#tensor-flow-lite-models) -allows you to directly import TensorFlow Lite models and use them in your -Android Studio projects. It generates easy-to-use classes so you can run your -model with less code and better type safety. See the -[introduction](https://developer.android.com/studio/preview/features#tensor-flow-lite-models) -for more details. - -Note: Code generated by the TensorFlow Lite Android code generator may include -some latest API or experimental features, which can be a super set of the one -generated by the Android Studio ML Model Binding. diff --git a/tensorflow/lite/g3doc/inference_with_metadata/overview.md b/tensorflow/lite/g3doc/inference_with_metadata/overview.md index 2bb628bddb9..736f1540c85 100644 --- a/tensorflow/lite/g3doc/inference_with_metadata/overview.md +++ b/tensorflow/lite/g3doc/inference_with_metadata/overview.md @@ -4,25 +4,32 @@ Inferencing [models with metadata](../convert/metadata.md) can be as easy as just a few lines of code. TensorFlow Lite metadata contains a rich description of what the model does and how to use the model. It can empower code generators to automatically generate the inference code for you, such as using the -[TensorFlow Lite Android code generator](codegen.md#generate-code-with-tensorflow-lite-android-code-generator) -and the -[Android Studio ML Binding feature](codegen.md#generate-code-with-android-studio-ml-model-binding). -It can also be used to configure your custom inference pipeline. +[Android Studio ML Binding feature](codegen.md#mlbinding) or +[TensorFlow Lite Android code generator](codegen.md#codegen). It can also be +used to configure your custom inference pipeline. ## Tools and libraries TensorFlow Lite provides varieties of tools and libraries to serve different tiers of deployment requirements as follows: -### Generate model interface with the TensorFlow Lite Code Generator +### Generate model interface with Android code generators -[TensorFlow Lite Code Generator](codegen.md) is an executable that generates -model interface automatically based on the metadata. It currently supports -Android with Java. The wrapper code removes the need to interact directly with -`ByteBuffer`. Instead, developers can interact with the TensorFlow Lite model -with typed objects such as `Bitmap` and `Rect`. Android Studio users can also -get access to the codegen feature through -[Android Studio ML Binding](codegen.md#generate-code-with-android-studio-ml-model-binding). +There are two ways to automatically generate the necessary Android wrapper code +for TensorFlow Lite model with metadata: + +1. [Android Studio ML Model Binding](codegen.md#mlbinding) is tooling available + within Android Studio to import TensorFlow Lite model through a graphical + interface. 
Android Studio will automatically configure settings for the + project and generate wrapper classes based on the model metadata. + +2. [TensorFlow Lite Code Generator](codegen.md#codegen) is an executable that + generates model interface automatically based on the metadata. It currently + supports Android with Java. The wrapper code removes the need to interact + directly with `ByteBuffer`. Instead, developers can interact with the + TensorFlow Lite model with typed objects such as `Bitmap` and `Rect`. + Android Studio users can also get access to the codegen feature through + [Android Studio ML Binding](codegen.md#mlbinding). ### Leverage out-of-box APIs with the TensorFlow Lite Task Library diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_question_answerer.md b/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_question_answerer.md index 995c1cf7478..168574b45bd 100644 --- a/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_question_answerer.md +++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_question_answerer.md @@ -18,7 +18,7 @@ documentation for the Question-Answer model The following models are compatible with the `BertNLClassifier` API. * Models created by - [TensorFlow Lite Model Maker for Question Answer](https://www.tensorflow.org/lite/tutorials/model_maker_question_answer). + [TensorFlow Lite Model Maker for BERT Question Answer](https://www.tensorflow.org/lite/tutorials/model_maker_question_answer). * The [pretrained BERT models on TensorFlow Hub](https://tfhub.dev/tensorflow/collections/lite/task-library/bert-question-answerer/1). diff --git a/tensorflow/lite/g3doc/models/bert_qa/overview.md b/tensorflow/lite/g3doc/models/bert_qa/overview.md index fc75cb7f33d..af042e44d07 100644 --- a/tensorflow/lite/g3doc/models/bert_qa/overview.md +++ b/tensorflow/lite/g3doc/models/bert_qa/overview.md @@ -1,4 +1,4 @@ -# Question and answer +# BERT Question and Answer Use a pre-trained model to answer questions based on the content of a given passage. diff --git a/tensorflow/lite/g3doc/performance/gpu.md b/tensorflow/lite/g3doc/performance/gpu.md index e992518baf1..581f2f2d843 100644 --- a/tensorflow/lite/g3doc/performance/gpu.md +++ b/tensorflow/lite/g3doc/performance/gpu.md @@ -158,6 +158,12 @@ Note: The TensorFlow Lite Interpreter must be created on the same thread as where it is run. Otherwise, `TfLiteGpuDelegate Invoke: GpuDelegate must run on the same thread where it was initialized.` may occur. +There are two ways to invoke model acceleration depending on if you are using +[Android Studio ML Model Binding](../inference_with_metadata/codegen#acceleration) +or TensorFlow Lite Interpreter. + +#### TensorFlow Lite Interpreter + Look at the demo to see how to add the delegate. In your application, add the AAR as above, import `org.tensorflow.lite.gpu.GpuDelegate` module, and use the`addDelegate` function to register the GPU delegate to the interpreter: diff --git a/tensorflow/lite/g3doc/performance/gpu_advanced.md b/tensorflow/lite/g3doc/performance/gpu_advanced.md index d23c87c8288..45fbaf4c9c8 100644 --- a/tensorflow/lite/g3doc/performance/gpu_advanced.md +++ b/tensorflow/lite/g3doc/performance/gpu_advanced.md @@ -65,7 +65,12 @@ allows the appropriate versions; for example, ADD v2. 
## Basic usage -### Android (Kotlin / Java) +There are two ways to invoke model acceleration in Android depending on if you +are using +[Android Studio ML Model Binding](../inference_with_metadata/codegen#acceleration) +or TensorFlow Lite Interpreter. + +### Android via TensorFlow Lite Interpreter Add the `tensorflow-lite-gpu` package alongside the existing `tensorflow-lite` package in the existing `dependencies` block. diff --git a/tensorflow/lite/g3doc/tutorials/model_maker_question_answer.ipynb b/tensorflow/lite/g3doc/tutorials/model_maker_question_answer.ipynb index 328f9d0cb70..04e351785b0 100644 --- a/tensorflow/lite/g3doc/tutorials/model_maker_question_answer.ipynb +++ b/tensorflow/lite/g3doc/tutorials/model_maker_question_answer.ipynb @@ -37,7 +37,7 @@ "id": "Gb7qyhNL1yWt" }, "source": [ - "# Question Answer with TensorFlow Lite Model Maker" + "# BERT Question Answer with TensorFlow Lite Model Maker" ] }, { @@ -79,7 +79,7 @@ "id": "UxEHFTk755qw" }, "source": [ - "# Introduction to Question Answer Task" + "# Introduction to BERT Question Answer Task" ] }, { diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java index cea4a4ae2c2..e14c38d1b08 100644 --- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java +++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java @@ -178,6 +178,11 @@ public final class Interpreter implements AutoCloseable { Boolean allowFp16PrecisionForFp32; Boolean allowBufferHandleOutput; Boolean allowCancellation; + + // TODO(b/171856982): update the comment when applying XNNPACK delegate by default is + // enabled for C++ TfLite library on Android platform. + // Note: the initial "null" value indicates default behavior which may mean XNNPACK + // delegate will be applied by default. Boolean useXNNPACK; final List<Delegate> delegates = new ArrayList<>(); } diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java index 3d439e4470c..1eaaafdff88 100644 --- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java +++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java @@ -83,9 +83,17 @@ final class NativeInterpreterWrapper implements AutoCloseable { allowBufferHandleOutput(interpreterHandle, options.allowBufferHandleOutput.booleanValue()); } applyDelegates(options); + + // Simply use "-1" to represent the default mode. + int applyXNNPACKMode = -1; if (options.useXNNPACK != null) { - useXNNPACK( - interpreterHandle, errorHandle, options.useXNNPACK.booleanValue(), options.numThreads); + applyXNNPACKMode = options.useXNNPACK.booleanValue() ? 1 : 0; + } + + // TODO(b/171856982): uncomment the following when applying XNNPACK delegate by default is + // enabled for C++ TfLite library on Android platform. 
+ if (applyXNNPACKMode == 1 /*|| applyXNNPACKMode == -1*/) { + useXNNPACK(interpreterHandle, errorHandle, applyXNNPACKMode, options.numThreads); } allocateTensors(interpreterHandle, errorHandle); this.isMemoryAllocated = true; @@ -459,7 +467,7 @@ final class NativeInterpreterWrapper implements AutoCloseable { private static native void allowBufferHandleOutput(long interpreterHandle, boolean allow); private static native void useXNNPACK( - long interpreterHandle, long errorHandle, boolean state, int numThreads); + long interpreterHandle, long errorHandle, int state, int numThreads); private static native long createErrorReporter(int size); diff --git a/tensorflow/lite/java/src/main/native/BUILD b/tensorflow/lite/java/src/main/native/BUILD index 75e32dfc2a9..8b6bad53206 100644 --- a/tensorflow/lite/java/src/main/native/BUILD +++ b/tensorflow/lite/java/src/main/native/BUILD @@ -32,7 +32,6 @@ cc_library( "//tensorflow/lite:util", "//tensorflow/lite/c:common", "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate_hdrs_only", - "//tensorflow/lite/experimental/tflite_api_dispatcher:tflite_api_dispatcher_with_kernels", "//tensorflow/lite/java/jni", ], alwayslink = 1, diff --git a/tensorflow/lite/java/src/main/native/builtin_ops_jni.cc b/tensorflow/lite/java/src/main/native/builtin_ops_jni.cc index aa654046ef4..eb17fdcf2a5 100644 --- a/tensorflow/lite/java/src/main/native/builtin_ops_jni.cc +++ b/tensorflow/lite/java/src/main/native/builtin_ops_jni.cc @@ -21,12 +21,14 @@ limitations under the License. namespace tflite { // The JNI code in interpreter_jni.cc expects a CreateOpResolver() function in -// the tflite namespace. This one instantiates a BuiltinOpResolver, with all the -// builtin ops. For smaller binary sizes users should avoid linking this in, and -// should provide a custom make CreateOpResolver() instead. +// the tflite namespace. This one instantiates a +// BuiltinOpResolverWithoutDefaultDelegates, with all the builtin ops but +// without applying any TfLite delegates by default (like the XNNPACK delegate). +// For smaller binary sizes users should avoid linking this in, and should +// provide a custom make CreateOpResolver() instead. std::unique_ptr<OpResolver> CreateOpResolver() { // NOLINT return std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver>( - new tflite::ops::builtin::BuiltinOpResolver()); + new tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates()); } } // namespace tflite diff --git a/tensorflow/lite/java/src/main/native/jni_utils.cc b/tensorflow/lite/java/src/main/native/jni_utils.cc index 0187d489ee8..1dc7d076fad 100644 --- a/tensorflow/lite/java/src/main/native/jni_utils.cc +++ b/tensorflow/lite/java/src/main/native/jni_utils.cc @@ -19,16 +19,15 @@ limitations under the License. #include <stdio.h> #include <stdlib.h> +namespace tflite { +namespace jni { + const char kIllegalArgumentException[] = "java/lang/IllegalArgumentException"; const char kIllegalStateException[] = "java/lang/IllegalStateException"; const char kNullPointerException[] = "java/lang/NullPointerException"; -const char kIndexOutOfBoundsException[] = "java/lang/IndexOutOfBoundsException"; const char kUnsupportedOperationException[] = "java/lang/UnsupportedOperationException"; -namespace tflite { -namespace jni { - void ThrowException(JNIEnv* env, const char* clazz, const char* fmt, ...) 
{ va_list args; va_start(args, fmt); diff --git a/tensorflow/lite/java/src/main/native/jni_utils.h b/tensorflow/lite/java/src/main/native/jni_utils.h index cb0cdf5b49f..62d1e43a961 100644 --- a/tensorflow/lite/java/src/main/native/jni_utils.h +++ b/tensorflow/lite/java/src/main/native/jni_utils.h @@ -20,23 +20,23 @@ limitations under the License. #include "tensorflow/lite/error_reporter.h" +namespace tflite { +namespace jni { + extern const char kIllegalArgumentException[]; extern const char kIllegalStateException[]; extern const char kNullPointerException[]; -extern const char kIndexOutOfBoundsException[]; extern const char kUnsupportedOperationException[]; -namespace tflite { -namespace jni { - void ThrowException(JNIEnv* env, const char* clazz, const char* fmt, ...); class BufferErrorReporter : public ErrorReporter { public: BufferErrorReporter(JNIEnv* env, int limit); - virtual ~BufferErrorReporter(); + ~BufferErrorReporter() override; int Report(const char* format, va_list args) override; const char* CachedErrorMessage(); + using ErrorReporter::Report; private: char* buffer_; diff --git a/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc index 959acfb205e..3551286b966 100644 --- a/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc +++ b/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc @@ -23,8 +23,10 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" -#include "tensorflow/lite/experimental/tflite_api_dispatcher/tflite_api_dispatcher.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/interpreter_builder.h" #include "tensorflow/lite/java/src/main/native/jni_utils.h" +#include "tensorflow/lite/model_builder.h" #include "tensorflow/lite/util.h" namespace tflite { @@ -37,29 +39,27 @@ using tflite::jni::ThrowException; namespace { -tflite_api_dispatcher::Interpreter* convertLongToInterpreter(JNIEnv* env, - jlong handle) { +tflite::Interpreter* convertLongToInterpreter(JNIEnv* env, jlong handle) { if (handle == 0) { - ThrowException(env, kIllegalArgumentException, + ThrowException(env, tflite::jni::kIllegalArgumentException, "Internal error: Invalid handle to Interpreter."); return nullptr; } - return reinterpret_cast<tflite_api_dispatcher::Interpreter*>(handle); + return reinterpret_cast<tflite::Interpreter*>(handle); } -tflite_api_dispatcher::TfLiteModel* convertLongToModel(JNIEnv* env, - jlong handle) { +tflite::FlatBufferModel* convertLongToModel(JNIEnv* env, jlong handle) { if (handle == 0) { - ThrowException(env, kIllegalArgumentException, + ThrowException(env, tflite::jni::kIllegalArgumentException, "Internal error: Invalid handle to model."); return nullptr; } - return reinterpret_cast<tflite_api_dispatcher::TfLiteModel*>(handle); + return reinterpret_cast<tflite::FlatBufferModel*>(handle); } BufferErrorReporter* convertLongToErrorReporter(JNIEnv* env, jlong handle) { if (handle == 0) { - ThrowException(env, kIllegalArgumentException, + ThrowException(env, tflite::jni::kIllegalArgumentException, "Internal error: Invalid handle to ErrorReporter."); return nullptr; } @@ -68,7 +68,7 @@ BufferErrorReporter* convertLongToErrorReporter(JNIEnv* env, jlong handle) { TfLiteDelegate* convertLongToDelegate(JNIEnv* env, jlong handle) { if (handle == 0) { - ThrowException(env, kIllegalArgumentException, + ThrowException(env, tflite::jni::kIllegalArgumentException, 
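The `using ErrorReporter::Report;` declaration added to `BufferErrorReporter` re-exposes the base class's convenience overloads, which would otherwise be hidden by the `va_list` override. A standalone illustration of the idiom, mirroring the shape of `tflite::ErrorReporter` rather than reproducing the actual class:

```cpp
// Overriding one overload in a derived class hides the base class's other
// overloads unless they are re-exposed with a using-declaration.
#include <cstdarg>
#include <cstdio>

struct Reporter {
  virtual ~Reporter() = default;
  virtual int Report(const char* format, va_list args) = 0;
  // Convenience overload implemented in terms of the va_list version.
  int Report(const char* format, ...) {
    va_list args;
    va_start(args, format);
    const int n = Report(format, args);
    va_end(args);
    return n;
  }
};

struct BufferReporter : Reporter {
  using Reporter::Report;  // without this, only the va_list overload is visible
  int Report(const char* format, va_list args) override {
    return vfprintf(stderr, format, args);
  }
};

int main() {
  BufferReporter r;
  r.Report("value = %d\n", 42);  // resolves to the variadic convenience overload
}
```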
"Internal error: Invalid handle to delegate."); return nullptr; } @@ -80,7 +80,7 @@ std::vector<int> convertJIntArrayToVector(JNIEnv* env, jintArray inputs) { std::vector<int> outputs(size, 0); jint* ptr = env->GetIntArrayElements(inputs, nullptr); if (ptr == nullptr) { - ThrowException(env, kIllegalArgumentException, + ThrowException(env, tflite::jni::kIllegalArgumentException, "Array has empty dimensions."); return {}; } @@ -130,7 +130,7 @@ bool AreDimsDifferent(JNIEnv* env, TfLiteTensor* tensor, jintArray dims) { int num_dims = static_cast<int>(env->GetArrayLength(dims)); jint* ptr = env->GetIntArrayElements(dims, nullptr); if (ptr == nullptr) { - ThrowException(env, kIllegalArgumentException, + ThrowException(env, tflite::jni::kIllegalArgumentException, "Empty dimensions of input array."); return true; } @@ -165,12 +165,11 @@ JNIEXPORT jobjectArray JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputNames(JNIEnv* env, jclass clazz, jlong handle) { - tflite_api_dispatcher::Interpreter* interpreter = - convertLongToInterpreter(env, handle); + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); if (interpreter == nullptr) return nullptr; jclass string_class = env->FindClass("java/lang/String"); if (string_class == nullptr) { - ThrowException(env, kUnsupportedOperationException, + ThrowException(env, tflite::jni::kUnsupportedOperationException, "Internal error: Can not find java/lang/String class to get " "input names."); return nullptr; @@ -188,8 +187,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputNames(JNIEnv* env, JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_allocateTensors( JNIEnv* env, jclass clazz, jlong handle, jlong error_handle) { - tflite_api_dispatcher::Interpreter* interpreter = - convertLongToInterpreter(env, handle); + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); if (interpreter == nullptr) return; BufferErrorReporter* error_reporter = convertLongToErrorReporter(env, error_handle); @@ -197,7 +195,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_allocateTensors( if (interpreter->AllocateTensors() != kTfLiteOk) { ThrowException( - env, kIllegalStateException, + env, tflite::jni::kIllegalStateException, "Internal error: Unexpected failure when preparing tensor allocations:" " %s", error_reporter->CachedErrorMessage()); @@ -207,8 +205,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_allocateTensors( JNIEXPORT jboolean JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_hasUnresolvedFlexOp( JNIEnv* env, jclass clazz, jlong handle) { - tflite_api_dispatcher::Interpreter* interpreter = - convertLongToInterpreter(env, handle); + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); if (interpreter == nullptr) return JNI_FALSE; // TODO(b/132995737): Remove this logic by caching whether an unresolved @@ -231,8 +228,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_hasUnresolvedFlexOp( JNIEXPORT jint JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensorIndex( JNIEnv* env, jclass clazz, jlong handle, jint input_index) { - tflite_api_dispatcher::Interpreter* interpreter = - convertLongToInterpreter(env, handle); + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); if (interpreter == nullptr) return 0; return interpreter->inputs()[input_index]; } @@ -240,8 +236,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensorIndex( JNIEXPORT jint JNICALL 
Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensorIndex( JNIEnv* env, jclass clazz, jlong handle, jint output_index) { - tflite_api_dispatcher::Interpreter* interpreter = - convertLongToInterpreter(env, handle); + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); if (interpreter == nullptr) return 0; return interpreter->outputs()[output_index]; } @@ -249,8 +244,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensorIndex( JNIEXPORT jint JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_getExecutionPlanLength( JNIEnv* env, jclass clazz, jlong handle) { - tflite_api_dispatcher::Interpreter* interpreter = - convertLongToInterpreter(env, handle); + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); if (interpreter == nullptr) return 0; return static_cast<jint>(interpreter->execution_plan().size()); } @@ -259,8 +253,7 @@ JNIEXPORT jint JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputCount(JNIEnv* env, jclass clazz, jlong handle) { - tflite_api_dispatcher::Interpreter* interpreter = - convertLongToInterpreter(env, handle); + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); if (interpreter == nullptr) return 0; return static_cast<jint>(interpreter->inputs().size()); } @@ -269,8 +262,7 @@ JNIEXPORT jint JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputCount(JNIEnv* env, jclass clazz, jlong handle) { - tflite_api_dispatcher::Interpreter* interpreter = - convertLongToInterpreter(env, handle); + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); if (interpreter == nullptr) return 0; return static_cast<jint>(interpreter->outputs().size()); } @@ -279,12 +271,11 @@ JNIEXPORT jobjectArray JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputNames(JNIEnv* env, jclass clazz, jlong handle) { - tflite_api_dispatcher::Interpreter* interpreter = - convertLongToInterpreter(env, handle); + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); if (interpreter == nullptr) return nullptr; jclass string_class = env->FindClass("java/lang/String"); if (string_class == nullptr) { - ThrowException(env, kUnsupportedOperationException, + ThrowException(env, tflite::jni::kUnsupportedOperationException, "Internal error: Can not find java/lang/String class to get " "output names."); return nullptr; @@ -302,8 +293,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputNames(JNIEnv* env, JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_allowFp16PrecisionForFp32( JNIEnv* env, jclass clazz, jlong handle, jboolean allow) { - tflite_api_dispatcher::Interpreter* interpreter = - convertLongToInterpreter(env, handle); + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); if (interpreter == nullptr) return; interpreter->SetAllowFp16PrecisionForFp32(static_cast<bool>(allow)); } @@ -311,23 +301,21 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_allowFp16PrecisionForFp32( JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_allowBufferHandleOutput( JNIEnv* env, jclass clazz, jlong handle, jboolean allow) { - tflite_api_dispatcher::Interpreter* interpreter = - convertLongToInterpreter(env, handle); + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); if (interpreter == nullptr) return; interpreter->SetAllowBufferHandleOutput(allow); } JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_useXNNPACK( - JNIEnv* env, 
jclass clazz, jlong handle, jlong error_handle, jboolean state, + JNIEnv* env, jclass clazz, jlong handle, jlong error_handle, jint state, jint num_threads) { // If not using xnnpack, simply don't apply the delegate. - if (!state) { + if (state == 0) { return; } - tflite_api_dispatcher::Interpreter* interpreter = - convertLongToInterpreter(env, handle); + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); if (interpreter == nullptr) { return; } @@ -355,8 +343,8 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useXNNPACK( if (num_threads > 0) { options.num_threads = num_threads; } - tflite_api_dispatcher::Interpreter::TfLiteDelegatePtr delegate( - xnnpack_create(&options), xnnpack_delete); + tflite::Interpreter::TfLiteDelegatePtr delegate(xnnpack_create(&options), + xnnpack_delete); auto delegation_status = interpreter->ModifyGraphWithDelegate(std::move(delegate)); // kTfLiteApplicationError occurs in cases where delegation fails but @@ -365,12 +353,19 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useXNNPACK( // TODO(b/166483905): Add support for multiple delegates when model allows. if (delegation_status != kTfLiteOk && delegation_status != kTfLiteApplicationError) { - ThrowException(env, kIllegalArgumentException, + ThrowException(env, tflite::jni::kIllegalArgumentException, "Internal error: Failed to apply XNNPACK delegate: %s", error_reporter->CachedErrorMessage()); } + } else if (state == -1) { + // Instead of throwing an exception, we tolerate the missing of such + // dependencies because we try to apply XNNPACK delegate by default. + TF_LITE_REPORT_ERROR( + error_reporter, + "WARNING: Missing necessary XNNPACK delegate dependencies to apply it " + "by default.\n"); } else { - ThrowException(env, kIllegalArgumentException, + ThrowException(env, tflite::jni::kIllegalArgumentException, "Failed to load XNNPACK delegate from current runtime. " "Have you added the necessary dependencies?"); } @@ -381,8 +376,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env, jclass clazz, jlong handle, jint num_threads) { - tflite_api_dispatcher::Interpreter* interpreter = - convertLongToInterpreter(env, handle); + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); if (interpreter == nullptr) return; interpreter->SetNumThreads(static_cast<int>(num_threads)); } @@ -396,7 +390,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createErrorReporter( } // Verifies whether the model is a flatbuffer file. 
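With the `state` parameter now an `int`, the JNI `useXNNPACK` treats `0` as "skip the delegate", `1` as an explicit user request (failure throws `IllegalArgumentException`), and `-1` as the default mode, where missing XNNPACK symbols only log a warning. Outside the weak-symbol lookup used above, applying the delegate directly looks roughly like the sketch below (it assumes a hard dependency on the standard `xnnpack_delegate.h` C API and trims error handling).

```cpp
// Sketch of applying the XNNPACK delegate with a direct dependency; the JNI
// code above resolves the same functions through weak symbols instead.
#include <utility>

#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
#include "tensorflow/lite/interpreter.h"

TfLiteStatus ApplyXnnpackDelegate(tflite::Interpreter* interpreter,
                                  int num_threads) {
  TfLiteXNNPackDelegateOptions options = TfLiteXNNPackDelegateOptionsDefault();
  if (num_threads > 0) options.num_threads = num_threads;
  tflite::Interpreter::TfLiteDelegatePtr delegate(
      TfLiteXNNPackDelegateCreate(&options), TfLiteXNNPackDelegateDelete);
  // kTfLiteApplicationError means delegation was rejected but the interpreter
  // is still usable, which is why the JNI layer does not treat it as fatal.
  return interpreter->ModifyGraphWithDelegate(std::move(delegate));
}
```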
-class JNIFlatBufferVerifier : public tflite_api_dispatcher::TfLiteVerifier { +class JNIFlatBufferVerifier : public tflite::TfLiteVerifier { public: bool Verify(const char* data, int length, tflite::ErrorReporter* reporter) override { @@ -416,13 +410,13 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createModel( if (error_reporter == nullptr) return 0; const char* path = env->GetStringUTFChars(model_file, nullptr); - std::unique_ptr<tflite_api_dispatcher::TfLiteVerifier> verifier; + std::unique_ptr<tflite::TfLiteVerifier> verifier; verifier.reset(new JNIFlatBufferVerifier()); - auto model = tflite_api_dispatcher::TfLiteModel::VerifyAndBuildFromFile( + auto model = tflite::FlatBufferModel::VerifyAndBuildFromFile( path, verifier.get(), error_reporter); if (!model) { - ThrowException(env, kIllegalArgumentException, + ThrowException(env, tflite::jni::kIllegalArgumentException, "Contents of %s does not encode a valid " "TensorFlow Lite model: %s", path, error_reporter->CachedErrorMessage()); @@ -443,15 +437,15 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer( static_cast<char*>(env->GetDirectBufferAddress(model_buffer)); jlong capacity = env->GetDirectBufferCapacity(model_buffer); if (!VerifyModel(buf, capacity)) { - ThrowException(env, kIllegalArgumentException, + ThrowException(env, tflite::jni::kIllegalArgumentException, "ByteBuffer is not a valid flatbuffer model"); return 0; } - auto model = tflite_api_dispatcher::TfLiteModel::BuildFromBuffer( + auto model = tflite::FlatBufferModel::BuildFromBuffer( buf, static_cast<size_t>(capacity), error_reporter); if (!model) { - ThrowException(env, kIllegalArgumentException, + ThrowException(env, tflite::jni::kIllegalArgumentException, "ByteBuffer does not encode a valid model: %s", error_reporter->CachedErrorMessage()); return 0; @@ -463,18 +457,17 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter( JNIEnv* env, jclass clazz, jlong model_handle, jlong error_handle, jint num_threads) { - tflite_api_dispatcher::TfLiteModel* model = - convertLongToModel(env, model_handle); + tflite::FlatBufferModel* model = convertLongToModel(env, model_handle); if (model == nullptr) return 0; BufferErrorReporter* error_reporter = convertLongToErrorReporter(env, error_handle); if (error_reporter == nullptr) return 0; auto resolver = ::tflite::CreateOpResolver(); - std::unique_ptr<tflite_api_dispatcher::Interpreter> interpreter; - TfLiteStatus status = tflite_api_dispatcher::InterpreterBuilder( - *model, *(resolver.get()))(&interpreter, static_cast<int>(num_threads)); + std::unique_ptr<tflite::Interpreter> interpreter; + TfLiteStatus status = tflite::InterpreterBuilder(*model, *(resolver.get()))( + &interpreter, static_cast<int>(num_threads)); if (status != kTfLiteOk) { - ThrowException(env, kIllegalArgumentException, + ThrowException(env, tflite::jni::kIllegalArgumentException, "Internal error: Cannot create interpreter: %s", error_reporter->CachedErrorMessage()); return 0; @@ -487,7 +480,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter( // Sets inputs, runs inference, and returns outputs as long handles. 
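With `tflite_api_dispatcher` removed, `createModel`/`createInterpreter` call `tflite::FlatBufferModel` and `tflite::InterpreterBuilder` directly. For reference, the equivalent flow in plain C++ looks like the sketch below (the model path is a placeholder).

```cpp
// Plain C++ equivalent of the createModel/createInterpreter JNI path.
#include <memory>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/interpreter_builder.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model_builder.h"

int main() {
  // The FlatBufferModel must stay alive for as long as the interpreter is used.
  auto model = tflite::FlatBufferModel::BuildFromFile("model.tflite");
  if (model == nullptr) return 1;

  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk ||
      interpreter->AllocateTensors() != kTfLiteOk) {
    return 1;
  }

  // ... write inputs via interpreter->typed_input_tensor<float>(0) ...
  return interpreter->Invoke() == kTfLiteOk ? 0 : 1;
}
```

Note that the `FlatBufferModel` must outlive the `Interpreter`, which is why the JNI layer keeps separate long-lived handles for both.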
JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_run( JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle) { - tflite_api_dispatcher::Interpreter* interpreter = + tflite::Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle); if (interpreter == nullptr) return; BufferErrorReporter* error_reporter = @@ -496,7 +489,7 @@ JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_run( if (interpreter->Invoke() != kTfLiteOk) { // TODO(b/168266570): Return InterruptedException. - ThrowException(env, kIllegalArgumentException, + ThrowException(env, tflite::jni::kIllegalArgumentException, "Internal error: Failed to run on the given Interpreter: %s", error_reporter->CachedErrorMessage()); return; @@ -506,12 +499,11 @@ JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_run( JNIEXPORT jint JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputDataType( JNIEnv* env, jclass clazz, jlong handle, jint output_idx) { - tflite_api_dispatcher::Interpreter* interpreter = - convertLongToInterpreter(env, handle); + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); if (interpreter == nullptr) return -1; const int idx = static_cast<int>(output_idx); if (output_idx < 0 || output_idx >= interpreter->outputs().size()) { - ThrowException(env, kIllegalArgumentException, + ThrowException(env, tflite::jni::kIllegalArgumentException, "Failed to get %d-th output out of %d outputs", output_idx, interpreter->outputs().size()); return -1; @@ -528,11 +520,11 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_resizeInput( BufferErrorReporter* error_reporter = convertLongToErrorReporter(env, error_handle); if (error_reporter == nullptr) return JNI_FALSE; - tflite_api_dispatcher::Interpreter* interpreter = + tflite::Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle); if (interpreter == nullptr) return JNI_FALSE; if (input_idx < 0 || input_idx >= interpreter->inputs().size()) { - ThrowException(env, kIllegalArgumentException, + ThrowException(env, tflite::jni::kIllegalArgumentException, "Input error: Can not resize %d-th input for a model having " "%d inputs.", input_idx, interpreter->inputs().size()); @@ -552,7 +544,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_resizeInput( tensor_idx, convertJIntArrayToVector(env, dims)); } if (status != kTfLiteOk) { - ThrowException(env, kIllegalArgumentException, + ThrowException(env, tflite::jni::kIllegalArgumentException, "Internal error: Failed to resize %d-th input: %s", input_idx, error_reporter->CachedErrorMessage()); return JNI_FALSE; @@ -565,7 +557,7 @@ JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_applyDelegate( JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle, jlong delegate_handle) { - tflite_api_dispatcher::Interpreter* interpreter = + tflite::Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle); if (interpreter == nullptr) return; @@ -578,7 +570,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_applyDelegate( TfLiteStatus status = interpreter->ModifyGraphWithDelegate(delegate); if (status != kTfLiteOk) { - ThrowException(env, kIllegalArgumentException, + ThrowException(env, tflite::jni::kIllegalArgumentException, "Internal error: Failed to apply delegate: %s", error_reporter->CachedErrorMessage()); } @@ -587,7 +579,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_applyDelegate( JNIEXPORT void JNICALL 
Java_org_tensorflow_lite_NativeInterpreterWrapper_resetVariableTensors( JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle) { - tflite_api_dispatcher::Interpreter* interpreter = + tflite::Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle); if (interpreter == nullptr) return; @@ -597,7 +589,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_resetVariableTensors( TfLiteStatus status = interpreter->ResetVariableTensors(); if (status != kTfLiteOk) { - ThrowException(env, kIllegalArgumentException, + ThrowException(env, tflite::jni::kIllegalArgumentException, "Internal error: Failed to reset variable tensors: %s", error_reporter->CachedErrorMessage()); } @@ -606,10 +598,10 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_resetVariableTensors( JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_createCancellationFlag( JNIEnv* env, jclass clazz, jlong interpreter_handle) { - tflite_api_dispatcher::Interpreter* interpreter = + tflite::Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle); if (interpreter == nullptr) { - ThrowException(env, kIllegalArgumentException, + ThrowException(env, tflite::jni::kIllegalArgumentException, "Internal error: Invalid handle to interpreter."); } std::atomic_bool* cancellation_flag = new std::atomic_bool(false); diff --git a/tensorflow/lite/java/src/main/native/tensor_jni.cc b/tensorflow/lite/java/src/main/native/tensor_jni.cc index 5dfca9ebe6c..24a13bbc3f6 100644 --- a/tensorflow/lite/java/src/main/native/tensor_jni.cc +++ b/tensorflow/lite/java/src/main/native/tensor_jni.cc @@ -20,7 +20,7 @@ limitations under the License. #include <string> #include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/experimental/tflite_api_dispatcher/tflite_api_dispatcher.h" +#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/java/src/main/native/jni_utils.h" #include "tensorflow/lite/string_util.h" @@ -39,21 +39,20 @@ static const char* kStringClassPath = "java/lang/String"; // invalidate all TfLiteTensor* handles during inference or allocation. 
class TensorHandle { public: - TensorHandle(tflite_api_dispatcher::Interpreter* interpreter, - int tensor_index) + TensorHandle(tflite::Interpreter* interpreter, int tensor_index) : interpreter_(interpreter), tensor_index_(tensor_index) {} TfLiteTensor* tensor() const { return interpreter_->tensor(tensor_index_); } int index() const { return tensor_index_; } private: - tflite_api_dispatcher::Interpreter* const interpreter_; + tflite::Interpreter* const interpreter_; const int tensor_index_; }; TfLiteTensor* GetTensorFromHandle(JNIEnv* env, jlong handle) { if (handle == 0) { - ThrowException(env, kIllegalArgumentException, + ThrowException(env, tflite::jni::kIllegalArgumentException, "Internal error: Invalid handle to TfLiteTensor."); return nullptr; } @@ -62,7 +61,7 @@ TfLiteTensor* GetTensorFromHandle(JNIEnv* env, jlong handle) { int GetTensorIndexFromHandle(JNIEnv* env, jlong handle) { if (handle == 0) { - ThrowException(env, kIllegalArgumentException, + ThrowException(env, tflite::jni::kIllegalArgumentException, "Internal error: Invalid handle to TfLiteTensor."); return -1; } @@ -110,7 +109,7 @@ size_t WriteOneDimensionalArray(JNIEnv* env, jobject object, TfLiteType type, const int num_elements = env->GetArrayLength(array); size_t to_copy = num_elements * ElementByteSize(type); if (to_copy > dst_size) { - ThrowException(env, kIllegalStateException, + ThrowException(env, tflite::jni::kIllegalStateException, "Internal error: cannot write Java array of %d bytes to " "Tensor of %d bytes", to_copy, dst_size); @@ -150,7 +149,7 @@ size_t WriteOneDimensionalArray(JNIEnv* env, jobject object, TfLiteType type, } default: { ThrowException( - env, kUnsupportedOperationException, + env, tflite::jni::kUnsupportedOperationException, "DataType error: TensorFlowLite currently supports float " "(32 bits), int (32 bits), byte (8 bits), bool (8 bits), and long " "(64 bits), support for other types (DataType %d in this " @@ -167,7 +166,7 @@ size_t ReadOneDimensionalArray(JNIEnv* env, TfLiteType data_type, const size_t size = len * ElementByteSize(data_type); if (size > src_size) { ThrowException( - env, kIllegalStateException, + env, tflite::jni::kIllegalStateException, "Internal error: cannot fill a Java array of %d bytes with a Tensor of " "%d bytes", size, src_size); @@ -205,7 +204,7 @@ size_t ReadOneDimensionalArray(JNIEnv* env, TfLiteType data_type, return size; } default: { - ThrowException(env, kIllegalStateException, + ThrowException(env, tflite::jni::kIllegalStateException, "DataType error: invalid DataType(%d)", data_type); } } @@ -345,7 +344,7 @@ void WriteScalar(JNIEnv* env, jobject src, TfLiteType type, void* dst, size_t src_size = ElementByteSize(type); if (src_size != dst_size) { ThrowException( - env, kIllegalStateException, + env, tflite::jni::kIllegalStateException, "Scalar (%d bytes) not compatible with allocated tensor (%d bytes)", src_size, dst_size); return; @@ -377,7 +376,8 @@ void WriteScalar(JNIEnv* env, jobject src, TfLiteType type, void* dst, return; } default: - ThrowException(env, kIllegalStateException, "Invalid DataType(%d)", type); + ThrowException(env, tflite::jni::kIllegalStateException, + "Invalid DataType(%d)", type); return; } } @@ -398,8 +398,8 @@ extern "C" { JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_Tensor_create( JNIEnv* env, jclass clazz, jlong interpreter_handle, jint tensor_index) { - tflite_api_dispatcher::Interpreter* interpreter = - reinterpret_cast<tflite_api_dispatcher::Interpreter*>(interpreter_handle); + tflite::Interpreter* interpreter = + 
reinterpret_cast<tflite::Interpreter*>(interpreter_handle); return reinterpret_cast<jlong>(new TensorHandle(interpreter, tensor_index)); } @@ -415,7 +415,7 @@ JNIEXPORT jobject JNICALL Java_org_tensorflow_lite_Tensor_buffer(JNIEnv* env, TfLiteTensor* tensor = GetTensorFromHandle(env, handle); if (tensor == nullptr) return nullptr; if (tensor->data.raw == nullptr) { - ThrowException(env, kIllegalArgumentException, + ThrowException(env, tflite::jni::kIllegalArgumentException, "Internal error: Tensor hasn't been allocated."); return nullptr; } @@ -430,13 +430,13 @@ JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_writeDirectBuffer( void* src_data_raw = env->GetDirectBufferAddress(src); if (!src_data_raw) { - ThrowException(env, kIllegalArgumentException, + ThrowException(env, tflite::jni::kIllegalArgumentException, "Input ByteBuffer is not a direct buffer"); return; } if (!tensor->data.data) { - ThrowException(env, kIllegalArgumentException, + ThrowException(env, tflite::jni::kIllegalArgumentException, "Internal error: Tensor hasn't been allocated."); return; } @@ -459,7 +459,7 @@ Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env, if (tensor == nullptr) return; int num_dims = tensor->dims->size; if (num_dims == 0) { - ThrowException(env, kIllegalArgumentException, + ThrowException(env, tflite::jni::kIllegalArgumentException, "Internal error: Cannot copy empty/scalar Tensors."); return; } @@ -481,12 +481,12 @@ Java_org_tensorflow_lite_Tensor_writeMultiDimensionalArray(JNIEnv* env, TfLiteTensor* tensor = GetTensorFromHandle(env, handle); if (tensor == nullptr) return; if (tensor->type != kTfLiteString && tensor->data.raw == nullptr) { - ThrowException(env, kIllegalArgumentException, + ThrowException(env, tflite::jni::kIllegalArgumentException, "Internal error: Target Tensor hasn't been allocated."); return; } if (tensor->dims->size == 0) { - ThrowException(env, kIllegalArgumentException, + ThrowException(env, tflite::jni::kIllegalArgumentException, "Internal error: Cannot copy empty/scalar Tensors."); return; } @@ -503,12 +503,12 @@ JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_writeScalar( TfLiteTensor* tensor = GetTensorFromHandle(env, handle); if (tensor == nullptr) return; if ((tensor->type != kTfLiteString) && (tensor->data.raw == nullptr)) { - ThrowException(env, kIllegalArgumentException, + ThrowException(env, tflite::jni::kIllegalArgumentException, "Internal error: Target Tensor hasn't been allocated."); return; } if ((tensor->dims->size != 0) && (tensor->dims->data[0] != 1)) { - ThrowException(env, kIllegalArgumentException, + ThrowException(env, tflite::jni::kIllegalArgumentException, "Internal error: Cannot write Java scalar to non-scalar " "Tensor."); return; @@ -533,7 +533,7 @@ JNIEXPORT jstring JNICALL Java_org_tensorflow_lite_Tensor_name(JNIEnv* env, jlong handle) { TfLiteTensor* tensor = GetTensorFromHandle(env, handle); if (tensor == nullptr) { - ThrowException(env, kIllegalArgumentException, + ThrowException(env, tflite::jni::kIllegalArgumentException, "Target Tensor doesn't exist."); return nullptr; } diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java index 7009f069bd8..f5b02173033 100644 --- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java +++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java @@ -484,6 +484,10 @@ public final class InterpreterTest { fail(); } catch 
(IllegalStateException e) { // Expected failure. + } catch (IllegalArgumentException e) { + // As we could apply some TfLite delegate by default, the flex ops preparation could fail if + // the flex delegate isn't applied first, in which this type of exception is thrown. + // Expected failure } } diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java index 19ea16e94eb..113805e83c5 100644 --- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java +++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java @@ -104,7 +104,15 @@ public final class NativeInterpreterWrapperTest { NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(MODEL_WITH_CUSTOM_OP_PATH); fail(); } catch (IllegalStateException e) { - assertThat(e).hasMessageThat().contains("Encountered unresolved custom op: Assign"); + assertThat(e) + .hasMessageThat() + .contains("preparing tensor allocations: Encountered unresolved custom op: Assign"); + } catch (IllegalArgumentException e) { + // As we could apply TfLite delegate by default, during which the prepration of this + // unresolved custom op could fail and this type of exception is thrown. + assertThat(e) + .hasMessageThat() + .containsMatch("Failed to apply .* delegate: Encountered unresolved custom op: Assign"); } } @@ -201,8 +209,20 @@ public final class NativeInterpreterWrapperTest { outputs.put(0, parsedOutputs); wrapper.run(inputs, outputs); long[] outputOneD = parsedOutputs[0][0][0]; - long[] expected = {-892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L, - -892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L}; + long[] expected = { + -892834092L, + 923423L, + 2123918239018L, + -892834092L, + 923423L, + 2123918239018L, + -892834092L, + 923423L, + 2123918239018L, + -892834092L, + 923423L, + 2123918239018L + }; assertThat(outputOneD).isEqualTo(expected); } } @@ -222,8 +242,20 @@ public final class NativeInterpreterWrapperTest { outputs.put(0, parsedOutputs); wrapper.run(inputs, outputs); byte[] outputOneD = parsedOutputs[0][0][0]; - byte[] expected = {(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0, - (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0}; + byte[] expected = { + (byte) 0xe0, + 0x4f, + (byte) 0xd0, + (byte) 0xe0, + 0x4f, + (byte) 0xd0, + (byte) 0xe0, + 0x4f, + (byte) 0xd0, + (byte) 0xe0, + 0x4f, + (byte) 0xd0 + }; assertThat(outputOneD).isEqualTo(expected); } } @@ -242,7 +274,7 @@ public final class NativeInterpreterWrapperTest { wrapper.run(inputs, outputs); String[] outputOneD = parsedOutputs[0][0][0]; String[] expected = { - "s1", "s22", "s333", "s1", "s22", "s333", "s1", "s22", "s333", "s1", "s22", "s333" + "s1", "s22", "s333", "s1", "s22", "s333", "s1", "s22", "s333", "s1", "s22", "s333" }; assertThat(outputOneD).isEqualTo(expected); } @@ -276,8 +308,8 @@ public final class NativeInterpreterWrapperTest { wrapper.run(inputs, outputs); String[] outputOneD = parsedOutputs[0][0][0]; String[] expected = { - "\uD800\uDC01", "s22", "\ud841\udf0e", "\uD800\uDC01", "s22", "\ud841\udf0e", - "\uD800\uDC01", "s22", "\ud841\udf0e", "\uD800\uDC01", "s22", "\ud841\udf0e" + "\uD800\uDC01", "s22", "\ud841\udf0e", "\uD800\uDC01", "s22", "\ud841\udf0e", + "\uD800\uDC01", "s22", "\ud841\udf0e", "\uD800\uDC01", "s22", "\ud841\udf0e" }; assertThat(outputOneD).isEqualTo(expected); } @@ 
-332,8 +364,8 @@ public final class NativeInterpreterWrapperTest { wrapper.run(inputs, outputs); byte[] outputOneD = parsedOutputs[0][0][0]; byte[] expected = { - (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0, - (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0 + (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0, + (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0 }; assertThat(outputOneD).isEqualTo(expected); } diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index 1965914f9bc..81f474879e9 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -220,6 +220,7 @@ cc_library( ], }), alwayslink = 1, + # TODO(b/161243354): add testonly=1? ) cc_library( @@ -787,7 +788,8 @@ cc_library( compatible_with = get_compatible_with_portable(), deps = [ ":builtin_op_kernels", - "//tensorflow/lite:framework_lib", + "//tensorflow/lite:cc_api", + "//tensorflow/lite:mutable_op_resolver", "//tensorflow/lite:tflite_with_xnnpack_optional", "//tensorflow/lite/c:common", "//tensorflow/lite/schema:schema_fbs", @@ -2366,4 +2368,13 @@ cc_test( ], ) +exports_files( + [ + "register.h", + "builtin_op_kernels.h", + "fully_connected.h", + ], + visibility = ["//tensorflow/lite/core/shims:__subpackages__"], +) + tflite_portable_test_suite_combined(combine_conditions = {"deps": [":test_main"]}) diff --git a/tensorflow/lite/kernels/internal/common.h b/tensorflow/lite/kernels/internal/common.h index 662a1864025..0c42f11ec05 100644 --- a/tensorflow/lite/kernels/internal/common.h +++ b/tensorflow/lite/kernels/internal/common.h @@ -186,6 +186,42 @@ inline int32_t MultiplyByQuantizedMultiplier(int64_t x, return result; } +#ifdef USE_NEON +// Round uses ARM's rounding shift right. +inline int32x4x4_t MultiplyByQuantizedMultiplier4Rows( + int32x4x4_t input_val, int32 quantized_multiplier, int shift) { + const int left_shift = std::max(shift, 0); + const int right_shift = std::min(shift, 0); + int32x4x4_t result; + + int32x4_t multiplier_dup = vdupq_n_s32(quantized_multiplier); + int32x4_t left_shift_dup = vdupq_n_s32(left_shift); + int32x4_t right_shift_dup = vdupq_n_s32(right_shift); + + result.val[0] = + vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[0], left_shift_dup), + multiplier_dup), + right_shift_dup); + + result.val[1] = + vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[1], left_shift_dup), + multiplier_dup), + right_shift_dup); + + result.val[2] = + vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[2], left_shift_dup), + multiplier_dup), + right_shift_dup); + + result.val[3] = + vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[3], left_shift_dup), + multiplier_dup), + right_shift_dup); + + return result; +} +#endif + template <typename T> int CountLeadingZeros(T integer_input) { static_assert(std::is_unsigned<T>::value, diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/mean.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/mean.h index b2ccef7e942..2535552444a 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/mean.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/mean.h @@ -89,8 +89,8 @@ inline void MeanImpl(const tflite::MeanParams& op_params, } } - temp_sum = optimized_ops::MultiplyByQuantizedMultiplier4Rows( - temp_sum, multiplier, shift); + temp_sum = + MultiplyByQuantizedMultiplier4Rows(temp_sum, multiplier, shift); temp_sum.val[0] = vaddq_s32(temp_sum.val[0], bias_dup); temp_sum.val[1] = vaddq_s32(temp_sum.val[1], bias_dup); diff --git 
a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc index 07ecdd1208b..2c2edcfbb39 100644 --- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc +++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc @@ -127,69 +127,6 @@ inline int32_t AccumulateNeonLane(const int32x4_t lane) { #endif } -// TODO(jaesung): Merge duplicated implementations in optimized_ops.h and -// neon_tensor_utils.cc. -inline int32x4x4_t MultiplyByQuantizedMultiplier4Rows( - int32x4x4_t input_val, int32 quantized_multiplier, int shift) { - using gemmlowp::RoundingDivideByPOT; - using gemmlowp::SaturatingRoundingDoublingHighMul; - const int left_shift = shift > 0 ? shift : 0; - const int right_shift = shift > 0 ? 0 : -shift; - int32x4x4_t result; - // The vector type support for SaturatingRoundingDoublingHighMulth in gemmlowp - // is limited to NEON. -#ifdef GEMMLOWP_NEON - const int32x4_t left_shifted_one_dup = vdupq_n_s32(1 << left_shift); - result.val[0] = - RoundingDivideByPOT(SaturatingRoundingDoublingHighMul( - vmulq_s32(input_val.val[0], left_shifted_one_dup), - quantized_multiplier), - right_shift); - result.val[1] = - RoundingDivideByPOT(SaturatingRoundingDoublingHighMul( - vmulq_s32(input_val.val[1], left_shifted_one_dup), - quantized_multiplier), - right_shift); - result.val[2] = - RoundingDivideByPOT(SaturatingRoundingDoublingHighMul( - vmulq_s32(input_val.val[2], left_shifted_one_dup), - quantized_multiplier), - right_shift); - result.val[3] = - RoundingDivideByPOT(SaturatingRoundingDoublingHighMul( - vmulq_s32(input_val.val[3], left_shifted_one_dup), - quantized_multiplier), - right_shift); -#else - for (int i = 0; i < 4; ++i) { - int32_t vals[4]; - vals[0] = RoundingDivideByPOT( - SaturatingRoundingDoublingHighMul( - vgetq_lane_s32(input_val.val[i], 0) * (1 << left_shift), - quantized_multiplier), - right_shift); - vals[1] = RoundingDivideByPOT( - SaturatingRoundingDoublingHighMul( - vgetq_lane_s32(input_val.val[i], 1) * (1 << left_shift), - quantized_multiplier), - right_shift); - vals[2] = RoundingDivideByPOT( - SaturatingRoundingDoublingHighMul( - vgetq_lane_s32(input_val.val[i], 2) * (1 << left_shift), - quantized_multiplier), - right_shift); - vals[3] = RoundingDivideByPOT( - SaturatingRoundingDoublingHighMul( - vgetq_lane_s32(input_val.val[i], 3) * (1 << left_shift), - quantized_multiplier), - right_shift); - - result.val[i] = vld1q_s32(reinterpret_cast<int32_t*>(&vals)); - } -#endif - return result; -} - inline int32x4x2_t MultiplyByQuantizedMultiplier2Rows( int32x4x2_t input_val, int32 quantized_multiplier, int shift) { using gemmlowp::RoundingDivideByPOT; diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h index bfa5eb3075d..3638a4e5874 100644 --- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h +++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h @@ -15,9 +15,6 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_ -// TODO(ghodrat): Remove this header file and the dependency to internal data -// structure. 
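The NEON `MultiplyByQuantizedMultiplier4Rows` duplicated here (and in `optimized_ops.h` below) is consolidated into `kernels/internal/common.h`. For readers unfamiliar with the operation it vectorizes, a scalar sketch follows; it assumes the pre-shifted value still fits in `int32` and omits the saturation corner case, so it is illustrative rather than the library implementation.

```cpp
// Scalar sketch of the fixed-point rescale the NEON helper performs per lane:
// multiply by a Q31 multiplier with rounding, then apply a rounding right shift.
#include <cstdint>

int32_t MultiplyByQuantizedMultiplierScalar(int32_t x,
                                            int32_t quantized_multiplier,
                                            int shift) {
  const int left_shift = shift > 0 ? shift : 0;
  const int right_shift = shift > 0 ? 0 : -shift;
  // Pre-shift left (assumes no overflow, as the vshlq_s32 path also does).
  const int32_t shifted = x * (1 << left_shift);
  // Rounding doubling high multiply, the scalar analogue of vqrdmulhq_s32:
  // the high 32 bits of (2 * shifted * multiplier), rounded to nearest.
  const int64_t prod = static_cast<int64_t>(shifted) * quantized_multiplier;
  const int64_t nudge =
      prod >= 0 ? (INT64_C(1) << 30) : 1 - (INT64_C(1) << 30);
  int32_t result = static_cast<int32_t>((prod + nudge) >> 31);
  // Rounding arithmetic shift right, the scalar analogue of vrshlq_s32 with a
  // negative shift amount.
  if (right_shift > 0) {
    result = static_cast<int32_t>(
        (static_cast<int64_t>(result) + (INT64_C(1) << (right_shift - 1))) >>
        right_shift);
  }
  return result;
}
```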
-#include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" #include "tensorflow/lite/kernels/internal/optimized/neon_check.h" diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h index 6cabea11ac4..cbe62516a52 100644 --- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h @@ -201,43 +201,6 @@ MatrixMap<Scalar> MapAsMatrixWithGivenNumberOfRows(Scalar* data, return MatrixMap<Scalar>(data, rows, cols); } -// TODO(renjieliu): Refactor this to merge with other -// MultiplyByQuantizedMultipler. -#ifdef USE_NEON -inline int32x4x4_t MultiplyByQuantizedMultiplier4Rows( - int32x4x4_t input_val, int32 quantized_multiplier, int32 shift) { - const int left_shift = std::max(shift, 0); - const int right_shift = std::min(shift, 0); - int32x4x4_t result; - - int32x4_t multiplier_dup = vdupq_n_s32(quantized_multiplier); - int32x4_t left_shift_dup = vdupq_n_s32(left_shift); - int32x4_t right_shift_dup = vdupq_n_s32(right_shift); - - result.val[0] = - vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[0], left_shift_dup), - multiplier_dup), - right_shift_dup); - - result.val[1] = - vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[1], left_shift_dup), - multiplier_dup), - right_shift_dup); - - result.val[2] = - vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[2], left_shift_dup), - multiplier_dup), - right_shift_dup); - - result.val[3] = - vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[3], left_shift_dup), - multiplier_dup), - right_shift_dup); - - return result; -} -#endif - template <typename ElementwiseF, typename ScalarBroadcastF, typename T> inline void BinaryBroadcastFiveFold(const ArithmeticParams& unswitched_params, const RuntimeShape& unswitched_input1_shape, @@ -418,7 +381,7 @@ inline void FullyConnected( const int32 output_activation_max = params.quantized_activation_max; TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2); TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1); - // TODO(benoitjacob): This really should be: + // TODO(b/62193649): This really should be: // const int batches = ArraySize(output_dims, 1); // but the current --variable_batch hack consists in overwriting the 3rd // dimension with the runtime batch size, as we don't keep track for each @@ -484,7 +447,7 @@ inline void FullyConnected( TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2); TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1); - // TODO(benoitjacob): This really should be: + // TODO(b/62193649): This really should be: // const int batches = ArraySize(output_dims, 1); // but the current --variable_batch hack consists in overwriting the 3rd // dimension with the runtime batch size, as we don't keep track for each @@ -862,7 +825,7 @@ inline void ShuffledFullyConnected( TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1); TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2); TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1); - // TODO(benoitjacob): This really should be: + // TODO(b/62193649): This really should be: // const int batches = ArraySize(output_dims, 1); // but the current --variable_batch hack consists in overwriting the 3rd // dimension with the runtime batch size, as we don't keep track for each diff --git a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h index 
77b4fab0e42..649e525cbf9 100644 --- a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h +++ b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h @@ -24,9 +24,6 @@ limitations under the License. // NEON_2_SSE translator library. If a native SSE version of a function is // implemented, replace the appropriate one to SSE_OR_PORTABLE. -// TODO(ghodrat): Remove this header file and the dependency to internal data -// structure. -#include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/internal/optimized/neon_check.h" #include "tensorflow/lite/kernels/internal/optimized/neon_tensor_utils_impl.h" diff --git a/tensorflow/lite/kernels/internal/reference/add.h b/tensorflow/lite/kernels/internal/reference/add.h index 5be7ab4dc0c..3da76d88b97 100644 --- a/tensorflow/lite/kernels/internal/reference/add.h +++ b/tensorflow/lite/kernels/internal/reference/add.h @@ -202,14 +202,6 @@ inline void Add(const ArithmeticParams& params, } } -// TODO(jiawen): We can implement BroadcastAdd on buffers of arbitrary -// dimensionality if the runtime code does a single loop over one dimension -// that handles broadcasting as the base case. The code generator would then -// generate max(D1, D2) nested for loops. -// TODO(benoitjacob): BroadcastAdd is intentionally duplicated from -// reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T> -// is no longer referenced in this file, move NdArrayDesc<T> from types.h to -// reference_ops.h. inline void BroadcastAdd4DSlow(const ArithmeticParams& params, const RuntimeShape& input1_shape, const float* input1_data, diff --git a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc index 338adf8c2ee..8f0f1e8543e 100644 --- a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc +++ b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc @@ -428,7 +428,7 @@ void PortableApplyLayerNorm(const int16_t* input, } int32_t mean = static_cast<int32_t>(static_cast<int64_t>(sum) * 1024 / n_input); - // TODO(jianlijianli): Avoids overflow but only works for POT n_input. + // TODO(b/173994730): Avoids overflow but only works for POT n_input. int32_t temp = kTwoToPower20 / n_input; int64_t variance = sum_sq * temp - static_cast<int64_t>(mean) * static_cast<int64_t>(mean); diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h index bbc156cc3be..fe7fde50aef 100644 --- a/tensorflow/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h @@ -179,7 +179,7 @@ inline void Elu(const RuntimeShape& input_shape, const float* input_data, const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; ++i) { const float val = input_data[i]; - output_data[i] = val < 0.0 ? std::exp(val) - 1 : val; + output_data[i] = val < 0.0f ? std::expm1(val) : val; } } @@ -319,10 +319,6 @@ inline void AddN(const RuntimeShape& input_shape, const size_t num_inputs, // dimensionality if the runtime code does a single loop over one dimension // that handles broadcasting as the base case. The code generator would then // generate max(D1, D2) nested for loops. -// TODO(benoitjacob): BroadcastMul is intentionally duplicated from -// reference_ops.h. 
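The `Elu` reference kernel above switches from `std::exp(val) - 1` to `std::expm1(val)`, which avoids catastrophic cancellation for inputs near zero. A standalone demonstration of the difference (not TFLite code):

```cpp
// For tiny |x|, exp(x) rounds to 1 in float, so exp(x) - 1 collapses to 0,
// while expm1(x) keeps the leading digits of the true result (~ x).
#include <cmath>
#include <cstdio>

int main() {
  const float x = -1e-10f;
  std::printf("exp(x) - 1 = %g\n", std::exp(x) - 1.0f);  // prints 0
  std::printf("expm1(x)   = %g\n", std::expm1(x));       // prints ~ -1e-10
}
```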
Once an optimized version is implemented and NdArrayDesc<T> -// is no longer referenced in this file, move NdArrayDesc<T> from types.h to -// reference_ops.h. inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params, const RuntimeShape& unswitched_input1_shape, const uint8* unswitched_input1_data, diff --git a/tensorflow/lite/kernels/internal/reference/sub.h b/tensorflow/lite/kernels/internal/reference/sub.h index b27f251de6c..5f4fd19e37d 100644 --- a/tensorflow/lite/kernels/internal/reference/sub.h +++ b/tensorflow/lite/kernels/internal/reference/sub.h @@ -65,10 +65,6 @@ inline void SubNonBroadcast(const ArithmeticParams& params, // dimensionality if the runtime code does a single loop over one dimension // that handles broadcasting as the base case. The code generator would then // generate max(D1, D2) nested for loops. -// TODO(b/151345101): BroadcastSub is intentionally duplicated from -// reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T> -// is no longer referenced in this file, move NdArrayDesc<T> from types.h to -// reference_ops.h. template <int N = 5> inline void BroadcastSubSlow(const ArithmeticParams& params, const RuntimeShape& input1_shape, diff --git a/tensorflow/lite/kernels/internal/tensor_utils_test.cc b/tensorflow/lite/kernels/internal/tensor_utils_test.cc index 1d0a4d50eb2..3e10548a106 100644 --- a/tensorflow/lite/kernels/internal/tensor_utils_test.cc +++ b/tensorflow/lite/kernels/internal/tensor_utils_test.cc @@ -30,6 +30,29 @@ limitations under the License. namespace tflite { namespace tensor_utils { +// Normally we should require bit-for-bit exact results. Unfortunately a bug +// in the Intel arm_neon_sse.h translation header that we use for x86 tests +// causes 1-bit inaccuracy in the vqrdmulh_n_s32 intrinsic, which causes +// off-by-1 errors. So we have to live with a +// few off-by-one errors for now, yet still ensure that no more than a small +// minority of values are wrong. +// This util is to compare the rounding results for integer-output. +template <typename T> +void CompareRoundingResults(int flat_size, const T* expected_result, + const T* real_result, int max_element_tolerance = 1, + int max_total_tolerance = 5) { + int max_diff = 0; + int64_t total_diff = 0; + for (int i = 0; i < flat_size; i++) { + int diff = static_cast<int>(std::abs(expected_result[i] - real_result[i])); + total_diff += diff; + max_diff = std::max(max_diff, diff); + } + + EXPECT_LE(max_diff, max_element_tolerance); + EXPECT_LE(total_diff, max_total_tolerance); +} + TEST(uKernels, FloorLog2Test) { for (int i = 1; i < 257; ++i) { EXPECT_EQ(::tflite::FloorLog2(i), @@ -1758,7 +1781,7 @@ TEST(uKernels, VectorBatchVectorCwiseProductAccumulateInteger) { const std::vector<int16_t> expected_output = { /* batch 0 */ - -35, 34, 32, 30, 27, 24, 20, 16, 11, -2, 10, 13, 16, 18, 19, 20, 21, 21, + -35, 34, 32, 30, 27, 24, 20, 16, 11, -1, 10, 13, 16, 18, 19, 20, 21, 21, 20, 0, 4, 8, 12, 17, 23, 29, 35, 42, 50, /* batch 1 */ 27, 24, 20, 18, 15, 14, 12, 12, 1, 2, 2, 6, 10, 15, 20, 26, 32, 39, 26, 9, @@ -1769,7 +1792,9 @@ TEST(uKernels, VectorBatchVectorCwiseProductAccumulateInteger) { /* batch 3 */ 17, 21, 14, 17, 18, 20, 20, 21, 20, 20, 18, -7, 13, 14, 13, 13, 11, 10, 7, 5, 26, 31, 37, 56, 63, 72, 80, 90, 99}; - EXPECT_THAT(batch_output, testing::ElementsAreArray(expected_output)); + // Only allow 1 element difference for the rounding result. 
+ CompareRoundingResults<int16_t>(4 * 29, expected_output.data(), + batch_output.data(), 1, 1); } TEST(uKernels, VectorBatchVectorCwiseProductAccumulateFloat) { diff --git a/tensorflow/lite/kernels/reshape.cc b/tensorflow/lite/kernels/reshape.cc index 2a21fa730bc..d764e1f81b2 100644 --- a/tensorflow/lite/kernels/reshape.cc +++ b/tensorflow/lite/kernels/reshape.cc @@ -47,7 +47,7 @@ TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) { // Tensorflow's Reshape allows one of the shape components to have the // special -1 value, meaning it will be calculated automatically based on the // input. Here we calculate what that dimension should be so that the number - // of output elements in the same as the number of input elements. + // of output elements is the same as the number of input elements. int num_input_elements = NumElements(input); int num_output_elements = 1; diff --git a/tensorflow/lite/micro/README.md b/tensorflow/lite/micro/README.md index ebde7dbc81f..613636cf29b 100644 --- a/tensorflow/lite/micro/README.md +++ b/tensorflow/lite/micro/README.md @@ -24,7 +24,7 @@ kilobytes of memory. To learn how to use the framework, visit the developer documentation at [tensorflow.org/lite/microcontrollers](https://www.tensorflow.org/lite/microcontrollers). -# Continuous Buils Status +# Continuous Build Status Build Type | Status | Artifacts ---------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------- diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD index 22a0d035471..ee6d55a801c 100644 --- a/tensorflow/lite/micro/kernels/BUILD +++ b/tensorflow/lite/micro/kernels/BUILD @@ -14,11 +14,6 @@ config_setting( define_values = {"tflm_build": "xtensa_hifimini"}, ) -config_setting( - name = "xtensa_hifimini_staging", - define_values = {"tflm_build": "xtensa_hifimini_staging"}, -) - package_group( name = "micro", packages = ["//tensorflow/lite/micro/..."], @@ -35,10 +30,7 @@ cc_library( "//conditions:default": [ ], ":xtensa_hifimini": [ - "xtensa_hifimini/fixedpoint_utils.h", - ], - ":xtensa_hifimini_staging": [ - "xtensa_hifimini/fixedpoint_utils.h", + "xtensa/fixedpoint_utils.h", ], }), copts = micro_copts(), @@ -58,10 +50,7 @@ cc_library( "fully_connected.cc", ], ":xtensa_hifimini": [ - "xtensa_hifimini/fully_connected.cc", - ], - ":xtensa_hifimini_staging": [ - "xtensa_hifimini_staging/fully_connected.cc", + "xtensa/fully_connected.cc", ], }), hdrs = ["fully_connected.h"], @@ -71,22 +60,14 @@ cc_library( ":micro", ], deps = [ - ":activation_utils", ":fixedpoint_utils", ":kernel_util", - ":micro_utils", "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:kernel_util", - "//tensorflow/lite/kernels:op_macros", - "//tensorflow/lite/kernels:padding", "//tensorflow/lite/kernels/internal:common", - "//tensorflow/lite/kernels/internal:compatibility", "//tensorflow/lite/kernels/internal:quantization_util", "//tensorflow/lite/kernels/internal:reference_base", "//tensorflow/lite/kernels/internal:tensor", - "//tensorflow/lite/kernels/internal:types", - "//tensorflow/lite/micro:memory_helpers", - "//tensorflow/lite/micro:micro_utils", ] + select({ "//conditions:default": [], ":xtensa_hifimini": [ @@ -141,23 +122,11 @@ cc_library( "svdf.cc", ], ":xtensa_hifimini": [ - "xtensa_hifimini/conv.cc", - "xtensa_hifimini/depthwise_conv.cc", - "xtensa_hifimini/quantize.cc", - "xtensa_hifimini/softmax.cc", - "xtensa_hifimini/svdf.cc", 
- ], - ":xtensa_hifimini_staging": [ - # TODO(b/144176795): finer granularity would help reduce the - # duplication of srcs in the BUILD rules (in this case conv.cc and - # depthwise_conv.cc). We are falling back to reference kernels in - # case the optimized kernels are not implemented to match the - # behavior that we get with the Makefiles. - "conv.cc", - "depthwise_conv.cc", - "xtensa_hifimini_staging/quantize.cc", - "xtensa_hifimini_staging/softmax.cc", - "xtensa_hifimini_staging/svdf.cc", + "xtensa/conv.cc", + "xtensa/depthwise_conv.cc", + "xtensa/quantize.cc", + "xtensa/softmax.cc", + "xtensa/svdf.cc", ], }), hdrs = ["micro_ops.h"], diff --git a/tensorflow/lite/micro/kernels/circular_buffer.cc b/tensorflow/lite/micro/kernels/circular_buffer.cc index f70203062a4..5ce8dbe14c8 100644 --- a/tensorflow/lite/micro/kernels/circular_buffer.cc +++ b/tensorflow/lite/micro/kernels/circular_buffer.cc @@ -65,45 +65,49 @@ struct OpData { int cycles_max; }; -// These constants represent constants specific to the music detect model. -// They exist until (b/132070898) is fixed. -constexpr int kMaxOpDataSize = 7; -int op_data_counter = 0; -OpData op_data_array[kMaxOpDataSize]; - } // namespace -void Free(TfLiteContext* context, void* buffer) { op_data_counter = 0; } +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); + return context->AllocatePersistentBuffer(context, sizeof(OpData)); +} TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input = GetInput(context, node, kInputTensor); - TF_LITE_ENSURE(context, input != nullptr); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - TF_LITE_ENSURE(context, output != nullptr); + + TFLITE_DCHECK(node->user_data != nullptr); + OpData* op_data = static_cast<OpData*>(node->user_data); TF_LITE_ENSURE(context, input != nullptr); TF_LITE_ENSURE(context, output != nullptr); - TF_LITE_ENSURE_EQ(context, 1, output->dims->data[0]); - TF_LITE_ENSURE_EQ(context, 1, input->dims->data[0]); + TF_LITE_ENSURE_EQ(context, input->dims->data[0], output->dims->data[0]); TF_LITE_ENSURE_EQ(context, 1, input->dims->data[1]); - TF_LITE_ENSURE_EQ(context, 1, output->dims->data[2]); - TF_LITE_ENSURE_EQ(context, 1, input->dims->data[2]); + TF_LITE_ENSURE_EQ(context, input->dims->data[2], output->dims->data[2]); TF_LITE_ENSURE_EQ(context, output->dims->data[3], input->dims->data[3]); TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type); - // The circular buffer custom operator currently only supports int8_t. + // The circular buffer custom operator currently only supports int8. TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt8); - // TODO(b/132070898): Use statically slotted OpData structures until a - // scratch memory API is ready. - TFLITE_DCHECK_LE(op_data_counter, kMaxOpDataSize); - OpData* op_data = &op_data_array[op_data_counter++]; - // The last circular buffer layer (length 5) simply accumulates outputs, and - // does not run periodically. + // The last circular buffer layer simply accumulates outputs, and does not run + // periodically. // TODO(b/150001379): Move this special case logic to the tflite flatbuffer. - if (output->dims->data[1] == 5) { + static int cb_prepare_count = 0; + cb_prepare_count++; + // These checks specifically work for the only two streaming models supported + // on TFLM. They use the shape of the output tensor along with the layer + // number to determine if the circular buffer period should be 1 or 2. 
+ + // These models are outlined int the following documents: + // https://docs.google.com/document/d/1lc_G2ZFhjiKFo02UHjBaljye1xsL0EkfybkaVELEE3Q/edit?usp=sharing + // https://docs.google.com/document/d/1pGc42PuWyrk-Jy1-9qeqtggvsmHr1ifz8Lmqfpr2rKA/edit?usp=sharing + if (output->dims->data[1] == 5 || output->dims->data[1] == 13 || + (cb_prepare_count == 5 && output->dims->data[2] == 2 && + output->dims->data[3] == 96)) { op_data->cycles_max = 1; + cb_prepare_count = 0; } else { op_data->cycles_max = 2; } @@ -127,10 +131,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, kOutputTensor); + TFLITE_DCHECK(node->user_data != nullptr); OpData* data = reinterpret_cast<OpData*>(node->user_data); int num_slots = output->dims->data[1]; - int depth = output->dims->data[3]; + int depth = output->dims->data[2] * output->dims->data[3]; if (input->type == kTfLiteInt8) { EvalInt8(tflite::micro::GetTensorData<int8_t>(input), num_slots, depth, @@ -148,12 +153,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return static_cast<TfLiteStatus>(kTfLiteAbort); } - // If prepare is ever called more than one time (for example, when testing the - // ambient model, the interpreter is created a few times), this op data - // counter needs to be reset so that future instances do not overrun this op - // data array. - op_data_counter = 0; - data->cycles_until_run = data->cycles_max; return kTfLiteOk; @@ -162,8 +161,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace circular_buffer TfLiteRegistration* Register_CIRCULAR_BUFFER() { - static TfLiteRegistration r = {/*init=*/nullptr, - /*free=*/circular_buffer::Free, + static TfLiteRegistration r = {/*init=*/circular_buffer::Init, + /*free=*/nullptr, /*prepare=*/circular_buffer::Prepare, /*invoke=*/circular_buffer::Eval, /*profiling_string=*/nullptr, diff --git a/tensorflow/lite/micro/kernels/xtensa/conv.cc b/tensorflow/lite/micro/kernels/xtensa/conv.cc index de9820b82d9..0af54c13bf6 100644 --- a/tensorflow/lite/micro/kernels/xtensa/conv.cc +++ b/tensorflow/lite/micro/kernels/xtensa/conv.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/padding.h" #include "tensorflow/lite/micro/kernels/kernel_util.h" -#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h" +#include "tensorflow/lite/micro/kernels/xtensa/fixedpoint_utils.h" namespace tflite { namespace { diff --git a/tensorflow/lite/micro/kernels/xtensa/depthwise_conv.cc b/tensorflow/lite/micro/kernels/xtensa/depthwise_conv.cc index 12410a94456..b0ecedcb8ea 100644 --- a/tensorflow/lite/micro/kernels/xtensa/depthwise_conv.cc +++ b/tensorflow/lite/micro/kernels/xtensa/depthwise_conv.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/padding.h" #include "tensorflow/lite/micro/kernels/kernel_util.h" -#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h" +#include "tensorflow/lite/micro/kernels/xtensa/fixedpoint_utils.h" namespace tflite { namespace { diff --git a/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc b/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc index 30a5b6a602a..165e243a6d7 100644 --- a/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc +++ b/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc @@ -25,7 +25,7 @@ limitations under the License. 
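The circular_buffer.cc change above switches the per-invocation slot size to `depth = output->dims->data[2] * output->dims->data[3]`, so each buffer slot now covers the whole inner 2-D block instead of only the last dimension. The snippet below is a hypothetical illustration of the shift-and-append step such a buffer performs; `EvalInt8` itself is not shown in this patch, so the real implementation may differ in detail.

```c++
#include <cstdint>
#include <cstring>

// Hypothetical helper (not from the patch): drop the oldest slot and append
// the newest frame. With output dims {1, 5, 2, 96} -- the shape implied by the
// checks in Prepare above -- this would be called with num_slots = 5 and
// depth = 2 * 96 = 192.
void ShiftAndAppendInt8(const int8_t* input, int num_slots, int depth,
                        int8_t* output) {
  // Move slots 1..num_slots-1 down by one slot, discarding slot 0.
  std::memmove(output, output + depth, (num_slots - 1) * depth);
  // Write the incoming frame into the last (newest) slot.
  std::memcpy(output + (num_slots - 1) * depth, input, depth);
}
```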
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/micro/kernels/kernel_util.h" -#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h" +#include "tensorflow/lite/micro/kernels/xtensa/fixedpoint_utils.h" namespace tflite { namespace { diff --git a/tensorflow/lite/micro/kernels/xtensa/quantize.cc b/tensorflow/lite/micro/kernels/xtensa/quantize.cc index b867e70d98b..05646a32ad7 100644 --- a/tensorflow/lite/micro/kernels/xtensa/quantize.cc +++ b/tensorflow/lite/micro/kernels/xtensa/quantize.cc @@ -19,10 +19,12 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/requantize.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/micro/kernels/kernel_util.h" -#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h" +#include "tensorflow/lite/micro/kernels/xtensa/fixedpoint_utils.h" +#include "tensorflow/lite/micro/micro_utils.h" namespace tflite { namespace { @@ -30,6 +32,12 @@ namespace { struct OpData { int32_t zero_point = 0; int scale_multiplier = 0; + + // Use 32-bit multiplier and scale for requantize version of this operator + // to preserve compatibility with reference op. + int32_t requantize_output_multiplier; + int requantize_output_shift; + int32_t input_zero_point = 0; }; void AffineQuantize(int scale_multiplier, const int32_t zero_point, @@ -116,6 +124,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { CreateQConstantForInt24(0, input->params.scale / output->params.scale); op_data->zero_point = output->params.zero_point; + op_data->input_zero_point = input->params.zero_point; + + double effective_scale = static_cast<double>(input->params.scale) / + static_cast<double>(output->params.scale); + QuantizeMultiplier(effective_scale, &op_data->requantize_output_multiplier, + &op_data->requantize_output_shift); return kTfLiteOk; } @@ -127,21 +141,25 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0); TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0); - tflite::QuantizationParams op_params; - op_params.zero_point = op_data->zero_point; - - if (input->type != kTfLiteInt16 && output->type != kTfLiteInt8) { + if (output->type == kTfLiteInt8 && input->type == kTfLiteInt16) { + AffineQuantize(op_data->scale_multiplier, op_data->zero_point, + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData<int16_t>(input), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData<int8_t>(output)); + } else if (output->type == kTfLiteInt32 && input->type == kTfLiteInt16) { + int size = ElementCount(*input->dims); + reference_ops::Requantize(tflite::micro::GetTensorData<int16_t>(input), + size, op_data->requantize_output_multiplier, + op_data->requantize_output_shift, + op_data->input_zero_point, op_data->zero_point, + tflite::micro::GetTensorData<int32_t>(output)); + } else { TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.", TfLiteTypeGetName(input->type), TfLiteTypeGetName(output->type)); return kTfLiteError; } - - AffineQuantize(op_data->scale_multiplier, op_data->zero_point, - tflite::micro::GetTensorShape(input), - tflite::micro::GetTensorData<int16_t>(input), - 
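The new int16-to-int32 path in xtensa/quantize.cc above precomputes `effective_scale = input_scale / output_scale` and splits it into a fixed-point multiplier and shift for `reference_ops::Requantize`. In plain floating-point terms, requantizing one element amounts to the following; this is a sketch of the intent only, not the fixed-point reference kernel.

```c++
#include <cmath>
#include <cstdint>

// Illustrative only: the fixed-point kernel reproduces this math using the
// precomputed multiplier/shift pair instead of a float multiply.
int32_t RequantizeOneElement(int16_t in, float input_scale,
                             int32_t input_zero_point, float output_scale,
                             int32_t output_zero_point) {
  const float effective_scale = input_scale / output_scale;
  const float rescaled =
      (static_cast<float>(in) - input_zero_point) * effective_scale;
  return static_cast<int32_t>(std::lround(rescaled)) + output_zero_point;
}
```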
tflite::micro::GetTensorShape(output), - tflite::micro::GetTensorData<int8_t>(output)); return kTfLiteOk; } diff --git a/tensorflow/lite/micro/kernels/xtensa/svdf.cc b/tensorflow/lite/micro/kernels/xtensa/svdf.cc index 28f8f1e1af0..5392e50245e 100644 --- a/tensorflow/lite/micro/kernels/xtensa/svdf.cc +++ b/tensorflow/lite/micro/kernels/xtensa/svdf.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/lite/kernels/op_macros.h" #include "tensorflow/lite/micro/kernels/activation_utils.h" #include "tensorflow/lite/micro/kernels/kernel_util.h" -#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h" +#include "tensorflow/lite/micro/kernels/xtensa/fixedpoint_utils.h" namespace tflite { namespace { diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/conv.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini/conv.cc deleted file mode 100644 index de9820b82d9..00000000000 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini/conv.cc +++ /dev/null @@ -1,456 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/kernels/internal/reference/conv.h" - -#include <xtensa/tie/xt_hifi2.h> - -#include "tensorflow/lite/c/builtin_op_data.h" -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/kernels/internal/common.h" -#include "tensorflow/lite/kernels/internal/quantization_util.h" -#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" -#include "tensorflow/lite/kernels/kernel_util.h" -#include "tensorflow/lite/kernels/padding.h" -#include "tensorflow/lite/micro/kernels/kernel_util.h" -#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h" - -namespace tflite { -namespace { - -constexpr int kInputTensor = 0; -constexpr int kFilterTensor = 1; -constexpr int kBiasTensor = 2; -constexpr int kOutputTensor = 0; - -// Conv is quantized along dimension 0: -// https://www.tensorflow.org/lite/performance/quantization_spec -constexpr int kConvQuantizedDimension = 0; - -struct OpData { - TfLitePaddingValues padding; - // The scaling factor from input to output (aka the 'real multiplier') can - // be represented as a fixed point multiplier plus a left shift. - int32_t output_multiplier; - int output_shift; - - // Cached tensor zero point values for quantized operations. - int32_t input_zero_point; - int32_t output_zero_point; - - // Per channel output multiplier and shift. - int32_t* per_channel_output_multiplier; - int32_t* per_channel_output_shift; - - // The range of the fused activation layer. For example for kNone and - // uint8_t these would be 0 and 255. 
- int32_t output_activation_min; - int32_t output_activation_max; -}; - -void ConvPerChannel(const ConvParams& params, const int32_t* output_multiplier, - const int32_t* output_shift, - const RuntimeShape& input_shape, const int8_t* input_data, - const RuntimeShape& filter_shape, const int8_t* filter_data, - const RuntimeShape& bias_shape, const int32_t* bias_data, - const RuntimeShape& output_shape, int8_t* output_data) { - const int stride_width = params.stride_width; - const int stride_height = params.stride_height; - const int dilation_width_factor = params.dilation_width_factor; - const int dilation_height_factor = params.dilation_height_factor; - const int pad_width = params.padding_values.width; - const int pad_height = params.padding_values.height; - const int32_t input_offset = params.input_offset; - const int32_t output_offset = params.output_offset; - const int32_t output_activation_min = params.quantized_activation_min; - const int32_t output_activation_max = params.quantized_activation_max; - - const int batches = input_shape.Dims(0); - - const int input_height = input_shape.Dims(1); - const int input_width = input_shape.Dims(2); - const int input_depth = input_shape.Dims(3); - - const int filter_height = filter_shape.Dims(1); - const int filter_width = filter_shape.Dims(2); - const int filter_depth = filter_shape.Dims(3); - - const int output_height = output_shape.Dims(1); - const int output_width = output_shape.Dims(2); - const int output_depth = output_shape.Dims(3); - - ae_p24x2s input_offset_24x2 = AE_MOVPA24(input_offset); - ae_q56s output_offset_56 = AE_CVTQ48A32S(output_offset); - ae_q56s output_activation_min_56 = AE_CVTQ48A32S(output_activation_min); - ae_q56s output_activation_max_56 = AE_CVTQ48A32S(output_activation_max); - - for (int batch = 0; batch < batches; ++batch) { - for (int out_y = 0; out_y < output_height; ++out_y) { - const int in_y_origin = (out_y * stride_height) - pad_height; - for (int out_x = 0; out_x < output_width; ++out_x) { - const int in_x_origin = (out_x * stride_width) - pad_width; - for (int out_channel = 0; out_channel < output_depth; ++out_channel) { - ae_q56s acc_56 = AE_ZEROQ56(); - - for (int filter_y = 0; filter_y < filter_height; ++filter_y) { - for (int filter_x = 0; filter_x < filter_width; filter_x += 2) { - const int in_x = in_x_origin + dilation_width_factor * filter_x; - const int in_y = in_y_origin + dilation_height_factor * filter_y; - const bool is_point_inside_image = - (in_x >= 0) && (in_x < input_width) && (in_y >= 0) && - (in_y < input_height); - if (is_point_inside_image) { - // Find current input index, minus 2 for Xtensa load - // alignments: - // TODO(b/147322595): Consider doing these offset calculations - // with intrinsics: - int input_idx = - ((batch * input_height + in_y) * input_width + in_x) * - input_depth * 2 - - 2; - const int8_t* input_vals_offset_ptr = input_data + input_idx; - for (int i = 0; i < input_depth; i += 2) { - // Load signed 2x 8bit values and right shift into 24bit - // alignment: - ae_p24x2s input_vals_24x2; - AE_LP8X2F_IU(input_vals_24x2, input_vals_offset_ptr, 2); - input_vals_24x2 = AE_P24X2S_SRAI(input_vals_24x2, 16); - - // Add input offset (24bit aligned): - input_vals_24x2 = - AE_P24S_ADDS_P24X2S(input_vals_24x2, input_offset_24x2); - - // Find current filter index, minus 2 for Xtensa load - // alignments: - int filter_idx = - ((out_channel * filter_height + filter_y) * filter_width + - filter_x) * - filter_depth + - i - 2; - const int8_t* filter_vals_offset_ptr = - filter_data + 
filter_idx; - - // Load signed 2x 8bit values and right shift into 24bit - // alignment: - ae_p24x2s filter_vals_24x2; - AE_LP8X2F_IU(filter_vals_24x2, filter_vals_offset_ptr, 2); - filter_vals_24x2 = AE_P24X2S_SRAI(filter_vals_24x2, 16); - - // Multiply and accumulate into 48bit bit space: - AE_MULAAP24S_HH_LL(acc_56, filter_vals_24x2, input_vals_24x2); - } - } - } - } - - // Left shift from 48bit alignment to 32bit: - acc_56 = AE_Q56S_SLAI(acc_56, 16); - - if (bias_data) { - // Load and add bias at 32bit alignment: - ae_q56s bias_56 = AE_CVTQ48A32S(bias_data[out_channel]); - acc_56 = AE_ADDQ56(acc_56, bias_56); - } - - // Shift from 32bit alignment to 24bit alignment and place back on - // the PR register: - acc_56 = AE_Q56S_SLAI(acc_56, 8); - ae_p24x2s acc_24x2 = AE_TRUNCP24Q48(acc_56); - - // Apply quantized multiplier and accumulate result at 48bit - // alignment. Convert the (unsigned) 32-bit multiplier down to a - // 24-bit multiplier. - acc_56 = MultiplyByQuantizedMultiplier( - acc_24x2, output_multiplier[out_channel] >> 8, - output_shift[out_channel]); - - // Add output offset, cap activation, and assign to the output: - acc_56 = AE_ADDQ56(acc_56, output_offset_56); - acc_56 = AE_MINQ56S(acc_56, output_activation_max_56); - acc_56 = AE_MAXQ56S(acc_56, output_activation_min_56); - - int output_idx = - ((batch * output_height + out_y) * output_width + out_x) * - output_depth + - out_channel; - output_data[output_idx] = static_cast<int8_t>(AE_TRUNCA32Q48(acc_56)); - } - } - } - } -} - -// TODO(b/154240772): Move shared code into common methods. -inline void Conv1x32Input32x32Filter( - const int input_offset, const int output_offset, - const int quantized_activation_min, const int quantized_activation_max, - const int32_t* output_multiplier, const int32_t* output_shift, - const RuntimeShape& input_shape, const int8_t* input_data, - const RuntimeShape& filter_shape, const int8_t* filter_data, - const RuntimeShape& bias_shape, const int32_t* bias_data, - const RuntimeShape& output_shape, int8_t* output_data) { - ae_p24x2s input_offset_24x2 = AE_MOVPA24(input_offset); - ae_q56s output_offset_56 = AE_CVTQ48A32S(output_offset); - ae_q56s output_activation_max_56 = AE_CVTQ48A32S(quantized_activation_max); - ae_q56s output_activation_min_56 = AE_CVTQ48A32S(quantized_activation_min); - - constexpr int kChannels = 32; - constexpr int kFilterDepth = 32; - for (int ch = 0; ch < kChannels; ch++) { - ae_q56s acc_56 = AE_ZEROQ56(); - const int8_t* input_vals_ptr = input_data - 2; - for (int i = 0; i < kFilterDepth; i += 2) { - // Load signed 2x 8bit values and right shift into 24bit - // alignment: - ae_p24x2s input_vals_24x2; - AE_LP8X2F_IU(input_vals_24x2, input_vals_ptr, 2); - input_vals_24x2 = AE_P24X2S_SRAI(input_vals_24x2, 16); - - // Add input offset (24bit aligned): - input_vals_24x2 = AE_P24S_ADDS_P24X2S(input_vals_24x2, input_offset_24x2); - // Find current filter index, minus 2 for Xtensa load - // alignments: - const int filter_idx = ch * kFilterDepth + i - 2; - const int8_t* filter_vals_offset_ptr = filter_data + filter_idx; - - // Load signed 2x 8bit values and right shift into 24bit - // alignment: - ae_p24x2s filter_vals_24x2; - AE_LP8X2F_IU(filter_vals_24x2, filter_vals_offset_ptr, 2); - filter_vals_24x2 = AE_P24X2S_SRAI(filter_vals_24x2, 16); - - // Multiply and accumulate into 48bit bit space: - AE_MULAAP24S_HH_LL(acc_56, filter_vals_24x2, input_vals_24x2); - } - // Left shift from 48bit alignment to 32bit: - acc_56 = AE_Q56S_SLAI(acc_56, 16); - if (bias_data) { - // Load and add 
bias at 32bit alignment: - ae_q56s bias_56 = AE_CVTQ48A32S(bias_data[ch]); - acc_56 = AE_ADDQ56(acc_56, bias_56); - } - - // Shift from 32bit alignment to 24bit alignment and place back on - // the PR register: - acc_56 = AE_Q56S_SLAI(acc_56, 8); - ae_p24x2s acc_24x2 = AE_TRUNCP24Q48(acc_56); - - // Apply quantized multiplier and accumulate result at 48bit alignment. - // Convert the (unsigned) 32-bit multiplier down to a 24-bit multiplier. - acc_56 = MultiplyByQuantizedMultiplier(acc_24x2, output_multiplier[ch] >> 8, - output_shift[ch]); - - // Add output offset, cap activation, and assign to the output: - acc_56 = AE_ADDQ56(acc_56, output_offset_56); - acc_56 = AE_MINQ56S(acc_56, output_activation_max_56); - acc_56 = AE_MAXQ56S(acc_56, output_activation_min_56); - - output_data[ch] = static_cast<int8_t>(AE_TRUNCA32Q48(acc_56)); - } -} - -TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node, - TfLiteConvParams* params, int width, int height, - int filter_width, int filter_height, int out_width, - int out_height, const TfLiteType data_type, - OpData* data) { - bool has_bias = node->inputs->size == 3; - // Check number of inputs/outputs - TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2); - TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); - - // Matching GetWindowedOutputSize in TensorFlow. - auto padding = params->padding; - data->padding = ComputePaddingHeightWidth( - params->stride_height, params->stride_width, - params->dilation_height_factor, params->dilation_width_factor, height, - width, filter_height, filter_width, padding, &out_height, &out_width); - - // Note that quantized inference requires that all tensors have their - // parameters set. This is usually done during quantized training. - if (data_type != kTfLiteFloat32) { - const TfLiteTensor* input = GetInput(context, node, kInputTensor); - const TfLiteTensor* filter = GetInput(context, node, kFilterTensor); - const TfLiteTensor* bias = - GetOptionalInputTensor(context, node, kBiasTensor); - TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - int output_channels = filter->dims->data[kConvQuantizedDimension]; - - return tflite::PopulateConvolutionQuantizationParams( - context, input, filter, bias, output, params->activation, - &data->output_multiplier, &data->output_shift, - &data->output_activation_min, &data->output_activation_max, - data->per_channel_output_multiplier, - reinterpret_cast<int*>(data->per_channel_output_shift), - output_channels); - } - return kTfLiteOk; -} - -void* Init(TfLiteContext* context, const char* buffer, size_t length) { - TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); - return context->AllocatePersistentBuffer(context, sizeof(OpData)); -} - -TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - TFLITE_DCHECK(node->user_data != nullptr); - TFLITE_DCHECK(node->builtin_data != nullptr); - auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data); - - TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - const TfLiteTensor* input = GetInput(context, node, kInputTensor); - const TfLiteTensor* filter = GetInput(context, node, kFilterTensor); - - auto* op_data = reinterpret_cast<OpData*>(node->user_data); - - int input_width = input->dims->data[2]; - int input_height = input->dims->data[1]; - int filter_width = filter->dims->data[2]; - int filter_height = filter->dims->data[1]; - int output_width = output->dims->data[2]; - int output_height = output->dims->data[1]; - - // Per channel quantization is only needed 
for int8_t inference. For other - // quantized types, only a single scale and zero point is needed. - const int num_channels = filter->dims->data[kConvQuantizedDimension]; - // Dynamically allocate per-channel quantization parameters. - op_data->per_channel_output_multiplier = - reinterpret_cast<int32_t*>(context->AllocatePersistentBuffer( - context, num_channels * sizeof(int32_t))); - op_data->per_channel_output_shift = - reinterpret_cast<int32_t*>(context->AllocatePersistentBuffer( - context, num_channels * sizeof(int32_t))); - op_data->input_zero_point = input->params.zero_point; - op_data->output_zero_point = output->params.zero_point; - // All per-channel quantized tensors need valid zero point and scale arrays. - if (input->type == kTfLiteInt8) { - TF_LITE_ENSURE_EQ(context, filter->quantization.type, - kTfLiteAffineQuantization); - - const auto* affine_quantization = - reinterpret_cast<TfLiteAffineQuantization*>( - filter->quantization.params); - TF_LITE_ENSURE(context, affine_quantization); - TF_LITE_ENSURE(context, affine_quantization->scale); - TF_LITE_ENSURE(context, affine_quantization->zero_point); - - TF_LITE_ENSURE(context, - affine_quantization->scale->size == 1 || - affine_quantization->scale->size == - filter->dims->data[kConvQuantizedDimension]); - TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size, - affine_quantization->zero_point->size); - } - - return CalculateOpData(context, node, params, input_width, input_height, - filter_width, filter_height, output_width, - output_height, input->type, op_data); -} - -void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node, - TfLiteConvParams* params, OpData* data, - const TfLiteEvalTensor* input, - const TfLiteEvalTensor* filter, - const TfLiteEvalTensor* bias, - TfLiteEvalTensor* output, - TfLiteEvalTensor* im2col) { - // TODO(b/154032858): Investigate removing extra copies. 
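For context on the per-channel multiplier and shift arrays allocated in the conv Prepare above: each output channel's effective scale is `input_scale * filter_scale[c] / output_scale`, decomposed into a normalized fixed-point multiplier plus a power-of-two shift. The sketch below shows a simplified version of that decomposition; TFLite's `QuantizeMultiplier` helper plays the same role with additional edge-case handling, so treat this as an approximation rather than the library routine.

```c++
#include <cmath>
#include <cstdint>
#include <vector>

// Sketch only: decompose each per-channel effective scale into a Q1.31
// multiplier and a shift, the form consumed by the quantized conv kernels.
void ComputePerChannelMultipliers(float input_scale, float output_scale,
                                  const std::vector<float>& filter_scales,
                                  std::vector<int32_t>* multipliers,
                                  std::vector<int32_t>* shifts) {
  multipliers->resize(filter_scales.size());
  shifts->resize(filter_scales.size());
  for (size_t c = 0; c < filter_scales.size(); ++c) {
    const double real_multiplier =
        static_cast<double>(input_scale) * filter_scales[c] / output_scale;
    int shift = 0;
    const double q = std::frexp(real_multiplier, &shift);  // real = q * 2^shift
    int64_t q_fixed = static_cast<int64_t>(std::round(q * (1LL << 31)));
    if (q_fixed == (1LL << 31)) {  // Rounding can push q up to exactly 1.0.
      q_fixed /= 2;
      ++shift;
    }
    (*multipliers)[c] = static_cast<int32_t>(q_fixed);
    (*shifts)[c] = shift;
  }
}
```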
- ConvParams op_params; - op_params.input_offset = -data->input_zero_point; - op_params.output_offset = data->output_zero_point; - op_params.stride_height = params->stride_height; - op_params.stride_width = params->stride_width; - op_params.dilation_height_factor = params->dilation_height_factor; - op_params.dilation_width_factor = params->dilation_width_factor; - op_params.padding_values.height = data->padding.height; - op_params.padding_values.width = data->padding.width; - op_params.quantized_activation_min = data->output_activation_min; - op_params.quantized_activation_max = data->output_activation_max; - - ConvPerChannel(op_params, data->per_channel_output_multiplier, - data->per_channel_output_shift, - tflite::micro::GetTensorShape(input), - tflite::micro::GetTensorData<int8_t>(input), - tflite::micro::GetTensorShape(filter), - tflite::micro::GetTensorData<int8_t>(filter), - tflite::micro::GetTensorShape(bias), - tflite::micro::GetTensorData<int32_t>(bias), - tflite::micro::GetTensorShape(output), - tflite::micro::GetTensorData<int8_t>(output)); -} - -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - TFLITE_DCHECK(node->user_data != nullptr); - TFLITE_DCHECK(node->builtin_data != nullptr); - auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data); - auto* op_data = reinterpret_cast<OpData*>(node->user_data); - - TfLiteEvalTensor* output = - tflite::micro::GetEvalOutput(context, node, kOutputTensor); - const TfLiteEvalTensor* input = - tflite::micro::GetEvalInput(context, node, kInputTensor); - const TfLiteEvalTensor* filter = - tflite::micro::GetEvalInput(context, node, kFilterTensor); - const TfLiteEvalTensor* bias = - (NumInputs(node) == 3) - ? tflite::micro::GetEvalInput(context, node, kBiasTensor) - : nullptr; - - int* input_dims = input->dims->data; - int* filter_dims = filter->dims->data; - if (input_dims[0] == 1 && input_dims[1] == 1 && input_dims[2] == 1 && - input_dims[3] == 32 && filter_dims[0] == 32 && filter_dims[1] == 1 && - filter_dims[2] == 1 && filter_dims[3] == 32) { - Conv1x32Input32x32Filter( - -op_data->input_zero_point, op_data->output_zero_point, - op_data->output_activation_min, op_data->output_activation_max, - op_data->per_channel_output_multiplier, - op_data->per_channel_output_shift, tflite::micro::GetTensorShape(input), - tflite::micro::GetTensorData<int8_t>(input), - tflite::micro::GetTensorShape(filter), - tflite::micro::GetTensorData<int8_t>(filter), - tflite::micro::GetTensorShape(bias), - tflite::micro::GetTensorData<int32_t>(bias), - tflite::micro::GetTensorShape(output), - tflite::micro::GetTensorData<int8_t>(output)); - return kTfLiteOk; - } - - switch (input->type) { - case kTfLiteInt8: - EvalQuantizedPerChannel(context, node, params, op_data, input, filter, - bias, output, nullptr); - break; - default: - TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.", - TfLiteTypeGetName(input->type), input->type); - return kTfLiteError; - } - return kTfLiteOk; -} -} // namespace - -TfLiteRegistration Register_CONV_2D() { - return {/*init=*/Init, - /*free=*/nullptr, - /*prepare=*/Prepare, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; -} - -} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/depthwise_conv.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini/depthwise_conv.cc deleted file mode 100644 index 12410a94456..00000000000 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini/depthwise_conv.cc +++ /dev/null @@ -1,503 +0,0 
@@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include <xtensa/tie/xt_hifi2.h> - -#include "tensorflow/lite/c/builtin_op_data.h" -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/kernels/internal/common.h" -#include "tensorflow/lite/kernels/internal/quantization_util.h" -#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h" -#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h" -#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" -#include "tensorflow/lite/kernels/kernel_util.h" -#include "tensorflow/lite/kernels/padding.h" -#include "tensorflow/lite/micro/kernels/kernel_util.h" -#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h" - -namespace tflite { -namespace { - -constexpr int kInputTensor = 0; -constexpr int kFilterTensor = 1; -constexpr int kBiasTensor = 2; -constexpr int kOutputTensor = 0; - -// Depthwise conv is quantized along dimension 3: -// https://www.tensorflow.org/lite/performance/quantization_spec -constexpr int kDepthwiseConvQuantizedDimension = 3; - -struct OpData { - TfLitePaddingValues padding; - // The scaling factor from input to output (aka the 'real multiplier') can - // be represented as a fixed point multiplier plus a left shift. - int32_t output_multiplier; - int output_shift; - - // Cached tensor zero point values for quantized operations. - int32_t input_zero_point; - int32_t output_zero_point; - - // Per channel output multiplier and shift. - // TODO(b/141139247): Allocate these dynamically when possible. - int32_t* per_channel_output_multiplier; - int32_t* per_channel_output_shift; - - // The range of the fused activation layer. For example for kNone and - // uint8_t these would be 0 and 255. - int32_t output_activation_min; - int32_t output_activation_max; -}; - -inline void DepthwiseConvPerChannel( - const DepthwiseParams& params, const int32_t* output_multiplier, - const int32_t* output_shift, const RuntimeShape& input_shape, - const int8_t* input_data, const RuntimeShape& filter_shape, - const int8_t* filter_data, const RuntimeShape& bias_shape, - const int32_t* bias_data, const RuntimeShape& output_shape, - int8_t* output_data) { - // TODO(b/154032858): Investigate removing extra copies. 
- const int stride_width = params.stride_width; - const int stride_height = params.stride_height; - const int dilation_width_factor = params.dilation_width_factor; - const int dilation_height_factor = params.dilation_height_factor; - const int pad_width = params.padding_values.width; - const int pad_height = params.padding_values.height; - const int depth_multiplier = params.depth_multiplier; - const int32_t input_offset = params.input_offset; - const int32_t output_offset = params.output_offset; - const int32_t output_activation_min = params.quantized_activation_min; - const int32_t output_activation_max = params.quantized_activation_max; - - const int batches = input_shape.Dims(0); - - const int input_height = input_shape.Dims(1); - const int input_width = input_shape.Dims(2); - const int input_depth = input_shape.Dims(3); - - const int filter_height = filter_shape.Dims(1); - const int filter_width = filter_shape.Dims(2); - const int filter_depth = filter_shape.Dims(3); - - const int output_height = output_shape.Dims(1); - const int output_width = output_shape.Dims(2); - const int output_depth = output_shape.Dims(3); - - ae_p24x2s input_offset_24x2 = AE_MOVPA24(input_offset); - ae_q56s output_offset_56 = AE_CVTQ48A32S(output_offset); - ae_q56s output_activation_min_56 = AE_CVTQ48A32S(output_activation_min); - ae_q56s output_activation_max_56 = AE_CVTQ48A32S(output_activation_max); - - for (int batch = 0; batch < batches; ++batch) { - for (int out_y = 0; out_y < output_height; ++out_y) { - const int in_y_origin = (out_y * stride_height) - pad_height; - for (int out_x = 0; out_x < output_width; ++out_x) { - const int in_x_origin = (out_x * stride_width) - pad_width; - for (int in_channel = 0; in_channel < input_depth; ++in_channel) { - for (int m = 0; m < depth_multiplier; ++m) { - const int output_channel = m + in_channel * depth_multiplier; - ae_q56s acc_56 = AE_ZEROQ56(); - for (int filter_y = 0; filter_y < filter_height; ++filter_y) { - const int in_y = in_y_origin + dilation_height_factor * filter_y; - for (int filter_x = 0; filter_x < filter_width; ++filter_x) { - const int in_x = in_x_origin + dilation_width_factor * filter_x; - // Zero padding by omitting the areas outside the image. - const bool is_point_inside_image = - (in_x >= 0) && (in_x < input_width) && (in_y >= 0) && - (in_y < input_height); - - if (is_point_inside_image) { - // Find current input index, minus 2 for Xtensa load - // alignments: - // TODO(b/147322595): Consider doing these offset calculations - // with intrinsics: - int input_idx = - ((batch * input_height + in_y) * input_width + in_x) * - input_depth + - (in_channel); - int32_t input_val = input_data[input_idx]; - - // Find current filter index, minus 2 for Xtensa load - // alignments: - int filter_idx = - ((filter_y)*filter_width + filter_x) * filter_depth + - (output_channel); - int32_t filter_val = filter_data[filter_idx]; - - // Load 8bit value as int32_t into a 24x24 register and right - // shift into 24bit space. Note: value is duplicated in the HH - // and LL register - but all calculations are done on the HH - // side. 
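The deleted depthwise kernel above iterates with `output_channel = m + in_channel * depth_multiplier`, the standard depthwise channel layout. A tiny worked example of that mapping, independent of the Xtensa intrinsics, may make the indexing easier to follow.

```c++
#include <cstdio>

// Depthwise channel mapping: with depth_multiplier = 2, input channel c feeds
// output channels 2*c and 2*c + 1.
int main() {
  const int input_depth = 3;
  const int depth_multiplier = 2;
  for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
    for (int m = 0; m < depth_multiplier; ++m) {
      const int output_channel = m + in_channel * depth_multiplier;
      std::printf("input channel %d, m %d -> output channel %d\n", in_channel,
                  m, output_channel);
    }
  }
  return 0;  // Prints output channels 0..5 for input channels 0..2.
}
```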
- ae_p24x2s input_val_24x2 = AE_MOVPA24(input_val); - - // Add input offset (24bit aligned): - input_val_24x2 = - AE_P24S_ADDS_P24X2S(input_val_24x2, input_offset_24x2); - - // Load filter 8bit value into 24bit alignment: - ae_p24x2s filter_val_24x2 = AE_MOVPA24(filter_val); - - // Multiply and accumulate the HH side of each 24x24 PR - // register: - AE_MULAS56P24S_HH(acc_56, filter_val_24x2, input_val_24x2); - } - } - } - - // Left shift from 48bit alignment to 32bit: - acc_56 = AE_Q56S_SLAI(acc_56, 16); - - if (bias_data) { - // Load and add bias at 32bit alignment: - ae_q56s bias_56 = AE_CVTQ48A32S(bias_data[output_channel]); - acc_56 = AE_ADDQ56(acc_56, bias_56); - } - - // Shift from 32bit alignment to 24bit alignment and place back on - // the PR register: - acc_56 = AE_Q56S_SLAI(acc_56, 8); - ae_p24x2s acc_24x2 = AE_TRUNCP24Q48(acc_56); - - // Apply quantized multiplier and accumulate result at 48bit - // alignment: - acc_56 = MultiplyByQuantizedMultiplier( - acc_24x2, output_multiplier[output_channel], - output_shift[output_channel]); - - // Add output offset, cap activation, and assign to the output: - acc_56 = AE_ADDQ56(acc_56, output_offset_56); - acc_56 = AE_MINQ56S(acc_56, output_activation_max_56); - acc_56 = AE_MAXQ56S(acc_56, output_activation_min_56); - - int output_idx = - ((batch * output_height + out_y) * output_width + out_x) * - output_depth + - output_channel; - output_data[output_idx] = - static_cast<int8_t>(AE_TRUNCA32Q48(acc_56)); - } - } - } - } - } -} - -constexpr int kConvolutionalKernelWidth = 4; -constexpr int kConvolutionalKernelDepth = 32; -inline void DepthwiseConv4x32MatchingInputAndFilter( - const int input_offset, const int output_offset, - const int quantized_activation_min, const int quantized_activation_max, - const int32_t* output_multiplier, const int32_t* output_shift, - const RuntimeShape& input_shape, const int8_t* input_data, - const RuntimeShape& filter_shape, const int8_t* filter_data, - const RuntimeShape& bias_shape, const int32_t* bias_data, - const RuntimeShape& output_shape, int8_t* output_data) { - // Convert the (unsigned) 32-bit multiplier down to a 24-bit multiplier. - const int32_t mult = output_multiplier[0] >> 8; - const int32_t shift = output_shift[0]; - ae_p24x2s input_offset_24x2 = AE_MOVPA24(input_offset); - ae_q56s output_offset_56 = AE_CVTQ48A32S(output_offset); - ae_q56s output_activation_min_56 = AE_CVTQ48A32S(quantized_activation_min); - ae_q56s output_activation_max_56 = AE_CVTQ48A32S(quantized_activation_max); - - const int num_blocks = - kConvolutionalKernelDepth / 2; // Based on the 24x2 register size. - const int stride_elements = - (kConvolutionalKernelDepth / kConvolutionalKernelWidth); - - const int8_t* input_0_ptr = (const int8_t*)(input_data - 2); - const int8_t* weight_0_ptr = (const int8_t*)(filter_data - 2); - // Apply the kernels in blocks of 4 for all the channels. - const int8_t* input_1_ptr = input_0_ptr + stride_elements * 4; - const int8_t* input_2_ptr = input_1_ptr + stride_elements * 4; - const int8_t* input_3_ptr = input_2_ptr + stride_elements * 4; - - const int8_t* weight_1_ptr = weight_0_ptr + stride_elements * 4; - const int8_t* weight_2_ptr = weight_1_ptr + stride_elements * 4; - const int8_t* weight_3_ptr = weight_2_ptr + stride_elements * 4; - - for (int i = 0; i < num_blocks; ++i) { - ae_q56s block_0_acc = AE_ZEROQ56(); - ae_q56s block_1_acc = AE_ZEROQ56(); - - // Load all the weights. 
- ae_p24x2s weight_0, weight_1, weight_2, weight_3; - AE_LP8X2F_IU(weight_0, weight_0_ptr, 2); - AE_LP8X2F_IU(weight_1, weight_1_ptr, 2); - AE_LP8X2F_IU(weight_2, weight_2_ptr, 2); - AE_LP8X2F_IU(weight_3, weight_3_ptr, 2); - - // Load all the inputs. - ae_p24x2s input_0, input_1, input_2, input_3; - AE_LP8X2F_IU(input_0, input_0_ptr, 2); - AE_LP8X2F_IU(input_1, input_1_ptr, 2); - AE_LP8X2F_IU(input_2, input_2_ptr, 2); - AE_LP8X2F_IU(input_3, input_3_ptr, 2); - - // Shift inputs to 8 bit alignment and add offsets. - input_0 = AE_P24X2S_SRAI(input_0, 16); - input_1 = AE_P24X2S_SRAI(input_1, 16); - input_2 = AE_P24X2S_SRAI(input_2, 16); - input_3 = AE_P24X2S_SRAI(input_3, 16); - - input_0 = AE_P24S_ADDS_P24X2S(input_0, input_offset_24x2); - input_1 = AE_P24S_ADDS_P24X2S(input_1, input_offset_24x2); - input_2 = AE_P24S_ADDS_P24X2S(input_2, input_offset_24x2); - input_3 = AE_P24S_ADDS_P24X2S(input_3, input_offset_24x2); - - // Do the multiplies across all channels. Resulting accumulators are 32bit - // aligned (24 bit aligned weights * 8 bit aligned inputs). - AE_MULAS56P24S_HH(block_0_acc, input_0, weight_0); - AE_MULAS56P24S_HH(block_0_acc, input_1, weight_1); - AE_MULAS56P24S_HH(block_0_acc, input_2, weight_2); - AE_MULAS56P24S_HH(block_0_acc, input_3, weight_3); - - AE_MULAS56P24S_LL(block_1_acc, input_0, weight_0); - AE_MULAS56P24S_LL(block_1_acc, input_1, weight_1); - AE_MULAS56P24S_LL(block_1_acc, input_2, weight_2); - AE_MULAS56P24S_LL(block_1_acc, input_3, weight_3); - - int ch_0 = i * 2; - int ch_1 = i * 2 + 1; - - // Load and add bias at 32bit alignment: - ae_q56s bias_56_0 = AE_CVTQ48A32S(bias_data[ch_0]); - ae_q56s bias_56_1 = AE_CVTQ48A32S(bias_data[ch_1]); - block_0_acc = AE_ADDQ56(block_0_acc, bias_56_0); - block_1_acc = AE_ADDQ56(block_1_acc, bias_56_1); - - // Shift from 32bit alignment to 24bit alignment and place back on - // the PR register: - block_0_acc = AE_Q56S_SLAI(block_0_acc, 8); - block_1_acc = AE_Q56S_SLAI(block_1_acc, 8); - ae_p24x2s acc_24x2_0 = AE_TRUNCP24Q48(block_0_acc); - ae_p24x2s acc_24x2_1 = AE_TRUNCP24Q48(block_1_acc); - - // Apply quantized multiplier and accumulate result at 48bit - // alignment: - block_0_acc = MultiplyByQuantizedMultiplier(acc_24x2_0, mult, shift); - // Apply quantized multiplier and accumulate result at 48bit - // alignment: - block_1_acc = MultiplyByQuantizedMultiplier(acc_24x2_1, mult, shift); - - // Add output offset, cap activation, and assign to the output: - block_0_acc = AE_ADDQ56(block_0_acc, output_offset_56); - block_1_acc = AE_ADDQ56(block_1_acc, output_offset_56); - block_0_acc = AE_MINQ56S(block_0_acc, output_activation_max_56); - block_1_acc = AE_MINQ56S(block_1_acc, output_activation_max_56); - block_0_acc = AE_MAXQ56S(block_0_acc, output_activation_min_56); - block_1_acc = AE_MAXQ56S(block_1_acc, output_activation_min_56); - - output_data[ch_0] = static_cast<int8_t>(AE_TRUNCA32Q48(block_0_acc)); - output_data[ch_1] = static_cast<int8_t>(AE_TRUNCA32Q48(block_1_acc)); - } -} - -TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node, - TfLiteDepthwiseConvParams* params, int width, - int height, int filter_width, int filter_height, - const TfLiteType data_type, OpData* data) { - bool has_bias = node->inputs->size == 3; - // Check number of inputs/outputs - TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2); - TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); - - int unused_output_height, unused_output_width; - data->padding = ComputePaddingHeightWidth( - params->stride_height, 
params->stride_width, 1, 1, height, width, - filter_height, filter_width, params->padding, &unused_output_height, - &unused_output_width); - - // Note that quantized inference requires that all tensors have their - // parameters set. This is usually done during quantized training. - if (data_type != kTfLiteFloat32) { - const TfLiteTensor* input = GetInput(context, node, kInputTensor); - const TfLiteTensor* filter = GetInput(context, node, kFilterTensor); - const TfLiteTensor* bias = - GetOptionalInputTensor(context, node, kBiasTensor); - TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - int num_channels = filter->dims->data[kDepthwiseConvQuantizedDimension]; - - // TODO(b/148610881): Consider calculating quantized params at int24 - // calculations: - TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams( - context, input, filter, bias, output, params->activation, - &data->output_multiplier, &data->output_shift, - &data->output_activation_min, &data->output_activation_max, - data->per_channel_output_multiplier, - reinterpret_cast<int*>(data->per_channel_output_shift), num_channels)); - } - return kTfLiteOk; -} - -void* Init(TfLiteContext* context, const char* buffer, size_t length) { - TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); - return context->AllocatePersistentBuffer(context, sizeof(OpData)); -} - -TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - TFLITE_DCHECK(node->user_data != nullptr); - TFLITE_DCHECK(node->builtin_data != nullptr); - auto* params = - reinterpret_cast<TfLiteDepthwiseConvParams*>(node->builtin_data); - - const TfLiteTensor* input = GetInput(context, node, kInputTensor); - const TfLiteTensor* filter = GetInput(context, node, kFilterTensor); - const TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - - auto* op_data = reinterpret_cast<OpData*>(node->user_data); - - const TfLiteType data_type = input->type; - int width = SizeOfDimension(input, 2); - int height = SizeOfDimension(input, 1); - int filter_width = SizeOfDimension(filter, 2); - int filter_height = SizeOfDimension(filter, 1); - - // Per channel quantization is only needed for int8_t inference. For other - // quantized types, only a single scale and zero point is needed. - const int num_channels = filter->dims->data[kDepthwiseConvQuantizedDimension]; - // Dynamically allocate per-channel quantization parameters. - op_data->per_channel_output_multiplier = - reinterpret_cast<int32_t*>(context->AllocatePersistentBuffer( - context, num_channels * sizeof(int32_t))); - op_data->per_channel_output_shift = - reinterpret_cast<int32_t*>(context->AllocatePersistentBuffer( - context, num_channels * sizeof(int32_t))); - - op_data->input_zero_point = input->params.zero_point; - op_data->output_zero_point = output->params.zero_point; - - // All per-channel quantized tensors need valid zero point and scale arrays. 
- if (input->type == kTfLiteInt8) { - TF_LITE_ENSURE_EQ(context, filter->quantization.type, - kTfLiteAffineQuantization); - - const auto* affine_quantization = - reinterpret_cast<TfLiteAffineQuantization*>( - filter->quantization.params); - TF_LITE_ENSURE(context, affine_quantization); - TF_LITE_ENSURE(context, affine_quantization->scale); - TF_LITE_ENSURE(context, affine_quantization->zero_point); - TF_LITE_ENSURE( - context, affine_quantization->scale->size == 1 || - affine_quantization->scale->size == - filter->dims->data[kDepthwiseConvQuantizedDimension]); - TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size, - affine_quantization->zero_point->size); - } - - return CalculateOpData(context, node, params, width, height, filter_width, - filter_height, data_type, op_data); -} - -void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node, - TfLiteDepthwiseConvParams* params, OpData* data, - const TfLiteEvalTensor* input, - const TfLiteEvalTensor* filter, - const TfLiteEvalTensor* bias, - TfLiteEvalTensor* output) { - DepthwiseParams op_params; - op_params.padding_type = PaddingType::kSame; - op_params.padding_values.width = data->padding.width; - op_params.padding_values.height = data->padding.height; - op_params.stride_width = params->stride_width; - op_params.stride_height = params->stride_height; - op_params.dilation_width_factor = params->dilation_width_factor; - op_params.dilation_height_factor = params->dilation_height_factor; - op_params.depth_multiplier = params->depth_multiplier; - op_params.input_offset = -data->input_zero_point; - op_params.weights_offset = 0; - op_params.output_offset = data->output_zero_point; - // TODO(b/130439627): Use calculated value for clamping. - op_params.quantized_activation_min = std::numeric_limits<int8_t>::min(); - op_params.quantized_activation_max = std::numeric_limits<int8_t>::max(); - - DepthwiseConvPerChannel(op_params, data->per_channel_output_multiplier, - data->per_channel_output_shift, - tflite::micro::GetTensorShape(input), - tflite::micro::GetTensorData<int8_t>(input), - tflite::micro::GetTensorShape(filter), - tflite::micro::GetTensorData<int8_t>(filter), - tflite::micro::GetTensorShape(bias), - tflite::micro::GetTensorData<int32_t>(bias), - tflite::micro::GetTensorShape(output), - tflite::micro::GetTensorData<int8_t>(output)); -} - -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - TFLITE_DCHECK(node->user_data != nullptr); - TFLITE_DCHECK(node->builtin_data != nullptr); - auto* params = - reinterpret_cast<TfLiteDepthwiseConvParams*>(node->builtin_data); - auto* op_data = reinterpret_cast<OpData*>(node->user_data); - - TfLiteEvalTensor* output = - tflite::micro::GetEvalOutput(context, node, kOutputTensor); - const TfLiteEvalTensor* input = - tflite::micro::GetEvalInput(context, node, kInputTensor); - const TfLiteEvalTensor* filter = - tflite::micro::GetEvalInput(context, node, kFilterTensor); - const TfLiteEvalTensor* bias = - (NumInputs(node) == 3) - ? tflite::micro::GetEvalInput(context, node, kBiasTensor) - : nullptr; - - // Handle special case for streaming model. 
- int* input_dims = input->dims->data; - int* filter_dims = filter->dims->data; - if (input_dims[0] == 1 && input_dims[1] == 4 && input_dims[2] == 1 && - input_dims[3] == 32 && filter_dims[0] == 1 && filter_dims[1] == 4 && - filter_dims[2] == 1 && filter_dims[3] == 32) { - DepthwiseConv4x32MatchingInputAndFilter( - -op_data->input_zero_point, op_data->output_zero_point, - std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max(), - op_data->per_channel_output_multiplier, - op_data->per_channel_output_shift, tflite::micro::GetTensorShape(input), - tflite::micro::GetTensorData<int8_t>(input), - tflite::micro::GetTensorShape(filter), - tflite::micro::GetTensorData<int8_t>(filter), - tflite::micro::GetTensorShape(bias), - tflite::micro::GetTensorData<int32_t>(bias), - tflite::micro::GetTensorShape(output), - tflite::micro::GetTensorData<int8_t>(output)); - return kTfLiteOk; - } - switch (input->type) { // Already know in/out types are same. - case kTfLiteInt8: - EvalQuantizedPerChannel(context, node, params, op_data, input, filter, - bias, output); - break; - default: - TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.", - TfLiteTypeGetName(input->type), input->type); - return kTfLiteError; - } - return kTfLiteOk; -} - -} // namespace - -TfLiteRegistration Register_DEPTHWISE_CONV_2D() { - return {/*init=*/Init, - /*free=*/nullptr, - /*prepare=*/Prepare, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; -} - -} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h b/tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h deleted file mode 100644 index a1d14df1352..00000000000 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h +++ /dev/null @@ -1,137 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_MICRO_KERNELS_XTENSA_HIFIMINI_FIXEDPOINT_UTILS_H_ -#define TENSORFLOW_LITE_MICRO_KERNELS_XTENSA_HIFIMINI_FIXEDPOINT_UTILS_H_ - -#include <xtensa/tie/xt_hifi2.h> - -#include <algorithm> -#include <cmath> -#include <cstdint> - -#include "tensorflow/lite/kernels/internal/compatibility.h" - -namespace tflite { - -// INT24 MIN/MAX -#define INT24_MIN -8388608 -#define INT24_MAX 8388607 - -// Multiply 24bit value by a quantized multiplier (w/ shift) and returns a 48bit -// aligned value in the QR register. -inline ae_q56s MultiplyByQuantizedMultiplier(ae_p24x2s x_24x2, - int32_t quantized_multiplier, - int shift) { - // A value with 1 sign bit, N integer bits and M fractional bits is - // represented as QN+1.M since the sign bit is included in the integer bits. 
- // - // The Q notation in this method explains the values represented in each - // variable, along with an implicit division since the quantized_multiplier - // represents a value between 0.5 and 1.0 (Q1.X-1 where X is the bit precision - // of the type). - // - // Load the quantized multiplier into the PR register. - // NOTE: This method assumes that this param has been calculated for 24bit - // space - not 32bits. - // Q32.0 / 2^23 -> Q24.0 / 2^23 representing a Q1.23 multiplier. - ae_p24x2s quantized_multiplier_24x2 = AE_MOVPA24(quantized_multiplier); - // Shift right by 23 - 16 bits minus the specified shift. This is because we - // keep 16 fractional bits until the end to perform rounding. Subtract shift - // since shift is a left shift, and the 23-16 is a right shift. - int shift_amount = 7 - shift; - - // Find the product of x and the quantized_multiplier. - // Q24.0 / 2^23 * Q24.0 = Q48.0 / 2^23 - // Q48.0 / 2^23 >> 7 = Q48.0 / 2^16 - ae_q56s result_56 = AE_MULP24S_HH(x_24x2, quantized_multiplier_24x2); - - // Shift right if shift amount is positive, left if shift amount is negative. - if (shift_amount >= 0) { - result_56 = AE_Q56S_SRA(result_56, shift_amount); - } else { - result_56 = AE_Q56S_SLA(result_56, -shift_amount); - } - - // Round off the bottom 16 bits. - // Q48.0 / 2^16 -> Q32.0 aligned to 48 bits. - result_56 = AE_ROUNDSQ32SYM(result_56); - return result_56; -} - -// Multiply 32bit value by a quantized multiplier (w/ shift) and returns a 48bit -// aligned value in the QR register. -inline ae_q56s MultiplyByQuantizedMultiplierResult48Bit( - int32_t x, int32_t quantized_multiplier, int shift) { - // Convert x into a 2x24bit PR register file. If x is outside the numerical - // limits of a 24bit integer, the "fractional" or lower 8bits are discarded. - // If x is within the range of a 24 bit integer, the "signed" or upper 8bits - // are discarded. - ae_p24x2s x_24x2; - if (x > INT24_MIN && x < INT24_MAX) { - x_24x2 = AE_MOVPA24(x); - } else { - x_24x2 = static_cast<ae_p24s>(*reinterpret_cast<ae_p24f*>(&x)); - shift += 8; - } - - return MultiplyByQuantizedMultiplier(x_24x2, quantized_multiplier, shift); -} - -// Calculate quantization params for 24bit runtimes. -inline void QuantizeMultiplierForInt24(float multiplier, - int32_t* quantized_multiplier, - int* shift) { - if (multiplier == 0.0f) { - *quantized_multiplier = 0; - *shift = 0; - return; - } - - // Special cased to 24bit: - const float q = std::frexp(multiplier, shift); - auto q_fixed = static_cast<int64_t>(std::round(q * (1 << 23))); - - TFLITE_CHECK(q_fixed <= (1 << 23)); - if (q_fixed == (1 << 23)) { - q_fixed /= 2; - ++*shift; - } - TFLITE_CHECK_LE(q_fixed, INT24_MAX); - - // Ensure shift does not exceed 24-bit range. - TFLITE_CHECK_LE(*shift, 23); - if (*shift < -23) { - *shift = 0; - q_fixed = 0; - } - *quantized_multiplier = static_cast<int32_t>(q_fixed); -} - -// Convert a floating point number to a Q representation for 24 bit integers. 
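The deleted fixedpoint_utils.h helpers above implement, with HiFi Mini intrinsics, the usual quantized-multiplier product: `QuantizeMultiplierForInt24` encodes a real factor as `q_fixed / 2^23 * 2^shift` with `q_fixed = round(q * 2^23)` and `q` in [0.5, 1), so applying it to a value x is approximately `round(x * q_fixed * 2^shift / 2^23)`. The portable 64-bit sketch below shows the same arithmetic without the intrinsic's symmetric rounding, so low-order bits can differ by one LSB.

```c++
#include <cstdint>

// Portable sketch of applying a 24-bit quantized multiplier. q_fixed is the
// Q1.23 value from QuantizeMultiplierForInt24 and shift its exponent; the
// Xtensa version performs the same product in the 56-bit accumulator.
int32_t ApplyQuantizedMultiplier24(int32_t x, int32_t q_fixed, int shift) {
  const int64_t product = static_cast<int64_t>(x) * q_fixed;  // x * q * 2^23
  const int total_shift = 23 - shift;  // Undo 2^23 scaling, then apply 2^shift.
  if (total_shift > 0) {
    const int64_t round = int64_t{1} << (total_shift - 1);
    return static_cast<int32_t>((product + round) >> total_shift);
  }
  return static_cast<int32_t>(product << -total_shift);
}
```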
-inline int CreateQConstantForInt24(int integer_bits, float f) { - const float min_bounds = static_cast<float>(INT24_MIN); - const float max_bounds = static_cast<float>(INT24_MAX); - - int fractional_bits = 23 - integer_bits; - float raw = std::round(f * static_cast<float>(1 << fractional_bits)); - raw = std::max(raw, min_bounds); - raw = std::min(raw, max_bounds); - return static_cast<int>(raw); -} - -} // namespace tflite - -#endif // TENSORFLOW_LITE_MICRO_KERNELS_XTENSA_HIFIMINI_FIXEDPOINT_UTILS_H_ diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/fully_connected.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini/fully_connected.cc deleted file mode 100644 index 30a5b6a602a..00000000000 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini/fully_connected.cc +++ /dev/null @@ -1,252 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/kernels/internal/reference/fully_connected.h" - -#include <xtensa/tie/xt_hifi2.h> - -#include "tensorflow/lite/c/builtin_op_data.h" -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/kernels/internal/common.h" -#include "tensorflow/lite/kernels/internal/quantization_util.h" -#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h" -#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" -#include "tensorflow/lite/kernels/kernel_util.h" -#include "tensorflow/lite/micro/kernels/kernel_util.h" -#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h" - -namespace tflite { -namespace { - -struct OpData { - // The scaling factor from input to output (aka the 'real multiplier') can - // be represented as a fixed point multiplier plus a left shift. - int32_t output_multiplier; - int output_shift; - - // Cached tensor zero point values for quantized operations. - int32_t input_zero_point; - int32_t filter_zero_point; - int32_t output_zero_point; - - // The range of the fused activation layer. For example for kNone and - // uint8_t these would be 0 and 255. - int32_t output_activation_min; - int32_t output_activation_max; - // The index of the temporary tensor where the quantized inputs are cached. - int input_quantized_index; -}; - -constexpr int kInputTensor = 0; -constexpr int kWeightsTensor = 1; -constexpr int kBiasTensor = 2; -constexpr int kOutputTensor = 0; - -void FullyConnected(const FullyConnectedParams& params, - const RuntimeShape& input_shape, const int8_t* input_data, - const RuntimeShape& filter_shape, const int8_t* filter_data, - const RuntimeShape& bias_shape, const int32_t* bias_data, - const RuntimeShape& output_shape, int8_t* output_data) { - // TODO(b/154032858): Investigate removing extra copies. 
- const int32_t input_offset = params.input_offset; - const int32_t filter_offset = params.weights_offset; - const int32_t output_offset = params.output_offset; - const int32_t output_multiplier = params.output_multiplier; - const int output_shift = params.output_shift; - const int32_t output_activation_min = params.quantized_activation_min; - const int32_t output_activation_max = params.quantized_activation_max; - - const int filter_dim_count = filter_shape.DimensionsCount(); - const int batches = output_shape.Dims(0); - const int output_depth = output_shape.Dims(1); - const int accum_depth = filter_shape.Dims(filter_dim_count - 1); - const int accum_depth_iters = accum_depth / 2; - - ae_p24x2s offsets_input_24x2 = AE_MOVPA24(input_offset); - ae_p24x2s offsets_filter_24x2 = AE_MOVPA24(filter_offset); - ae_q56s output_offset_56 = AE_CVTQ48A32S(output_offset); - ae_q56s output_activation_max_56 = AE_CVTQ48A32S(output_activation_max); - ae_q56s output_activation_min_56 = AE_CVTQ48A32S(output_activation_min); - - for (int b = 0; b < batches; ++b) { - for (int out_c = 0; out_c < output_depth; ++out_c) { - // Load intrinsics advance pointer before loading so backoff data pointers - // by two before loading: - const int8_t* input_ptr = (input_data + b * accum_depth) - 2; - const int8_t* filter_ptr = (filter_data + out_c * accum_depth) - 2; - - // Main accumulator register entry for loop: - ae_q56s sum_56 = AE_ZEROQ56(); - - for (int d = 0; d < accum_depth_iters; d++) { - // Load the signed 8bit values into the PR register: - ae_p24x2s input_24x2; - ae_p24x2s filter_24x2; - AE_LP8X2F_IU(input_24x2, input_ptr, 2); - AE_LP8X2F_IU(filter_24x2, filter_ptr, 2); - - // Right shift the signed 8bit values to expand to signed 24bit values: - input_24x2 = AE_P24X2S_SRAI(input_24x2, 16); - filter_24x2 = AE_P24X2S_SRAI(filter_24x2, 16); - - // Add offsets to data values (24 bit aligned): - input_24x2 = AE_P24S_ADDS_P24X2S(offsets_input_24x2, input_24x2); - filter_24x2 = AE_P24S_ADDS_P24X2S(offsets_filter_24x2, filter_24x2); - - // 24x2 signed integer dual MAC w/ addition into 56bit accumulator (48 - // bit aligned): - AE_MULAAP24S_HH_LL(sum_56, input_24x2, filter_24x2); - } - - // Left shift to get back into 32bit space (right padded to 48bit): - sum_56 = AE_Q56S_SLAI(sum_56, 16); - - // Add bias data if needed: - if (bias_data) { - ae_q56s bias_56 = AE_CVTQ48A32S(bias_data[out_c]); - sum_56 = AE_ADDQ56(sum_56, bias_56); - } - - // Shift left into 24bit space and place back on PR register: - sum_56 = AE_Q56S_SLAI(sum_56, 8); - ae_p24x2s sum_24x2 = AE_TRUNCP24Q48(sum_56); - - // MultiplyByQuantizedMultiplier returns a 48bit aligned value - sum_56 = MultiplyByQuantizedMultiplier(sum_24x2, output_multiplier, - output_shift); - - // Add output_offset and cap min/max values: - sum_56 = AE_ADDQ56(sum_56, output_offset_56); - sum_56 = AE_MINQ56S(sum_56, output_activation_max_56); - sum_56 = AE_MAXQ56S(sum_56, output_activation_min_56); - - output_data[out_c + output_depth * b] = - static_cast<int8_t>(AE_TRUNCA32Q48(sum_56)); - } - } -} - -TfLiteStatus CalculateOpData(TfLiteContext* context, - TfLiteFusedActivation activation, - TfLiteType data_type, const TfLiteTensor* input, - const TfLiteTensor* filter, - const TfLiteTensor* bias, TfLiteTensor* output, - OpData* data) { - double real_multiplier = 0.0; - TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( - context, input, filter, bias, output, &real_multiplier)); - QuantizeMultiplierForInt24(real_multiplier, &data->output_multiplier, - &data->output_shift); - 
return CalculateActivationRangeQuantized(context, activation, output, - &data->output_activation_min, - &data->output_activation_max); -} - -void* Init(TfLiteContext* context, const char* buffer, size_t length) { - TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); - return context->AllocatePersistentBuffer(context, sizeof(OpData)); -} - -TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - TFLITE_DCHECK(node->user_data != nullptr); - TFLITE_DCHECK(node->builtin_data != nullptr); - - OpData* data = static_cast<OpData*>(node->user_data); - const auto* params = - reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data); - - const TfLiteTensor* input = GetInput(context, node, kInputTensor); - const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor); - const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); - TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - - if (input->type != kTfLiteInt8) { - TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.", - TfLiteTypeGetName(input->type), input->type); - return kTfLiteError; - } - - data->input_zero_point = input->params.zero_point; - data->filter_zero_point = filter->params.zero_point; - data->output_zero_point = output->params.zero_point; - - return CalculateOpData(context, params->activation, input->type, input, - filter, bias, output, data); -} - -TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node, - const OpData& data, - const TfLiteEvalTensor* input, - const TfLiteEvalTensor* filter, - const TfLiteEvalTensor* bias, - TfLiteEvalTensor* output) { - // TODO(b/154032858): Investigate removing extra copies, and also passing by - // value. TODO(b/155656675): Consider passing OpData by value once it is also - // passed to the FullyConnected function. Until it is copied to a local - // op_param variable, we do not get any latency improvements from passing by - // value. - FullyConnectedParams op_params; - op_params.input_offset = -data.input_zero_point; - op_params.weights_offset = -data.filter_zero_point; - op_params.output_offset = data.output_zero_point; - op_params.output_multiplier = data.output_multiplier; - op_params.output_shift = data.output_shift; - op_params.quantized_activation_min = data.output_activation_min; - op_params.quantized_activation_max = data.output_activation_max; - - FullyConnected(op_params, tflite::micro::GetTensorShape(input), - tflite::micro::GetTensorData<int8_t>(input), - tflite::micro::GetTensorShape(filter), - tflite::micro::GetTensorData<int8_t>(filter), - tflite::micro::GetTensorShape(bias), - tflite::micro::GetTensorData<int32_t>(bias), - tflite::micro::GetTensorShape(output), - tflite::micro::GetTensorData<int8_t>(output)); - return kTfLiteOk; -} - -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - TFLITE_DCHECK(node->user_data != nullptr); - const OpData& data = *(static_cast<const OpData*>(node->user_data)); - - const TfLiteEvalTensor* input = - tflite::micro::GetEvalInput(context, node, kInputTensor); - const TfLiteEvalTensor* filter = - tflite::micro::GetEvalInput(context, node, kWeightsTensor); - const TfLiteEvalTensor* bias = - (NumInputs(node) == 3) - ? 
tflite::micro::GetEvalInput(context, node, kBiasTensor) - : nullptr; - - TfLiteEvalTensor* output = - tflite::micro::GetEvalOutput(context, node, kOutputTensor); - - return EvalQuantizedInt8(context, node, data, input, filter, bias, output); -} - -} // namespace - -TfLiteRegistration Register_FULLY_CONNECTED() { - return {/*init=*/Init, - /*free=*/nullptr, - /*prepare=*/Prepare, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; -} - -} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/quantize.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini/quantize.cc deleted file mode 100644 index b867e70d98b..00000000000 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini/quantize.cc +++ /dev/null @@ -1,161 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/kernels/internal/reference/quantize.h" - -#include <xtensa/tie/xt_hifi2.h> - -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/kernels/internal/quantization_util.h" -#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" -#include "tensorflow/lite/kernels/kernel_util.h" -#include "tensorflow/lite/micro/kernels/kernel_util.h" -#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h" - -namespace tflite { -namespace { - -struct OpData { - int32_t zero_point = 0; - int scale_multiplier = 0; -}; - -void AffineQuantize(int scale_multiplier, const int32_t zero_point, - const RuntimeShape& input_shape, const int16_t* input_data, - const RuntimeShape& output_shape, int8_t* output_data) { - const int flat_size = MatchingFlatSize(input_shape, output_shape); - ae_q56s min_val_56 = AE_CVTQ48A32S(INT16_MIN); - ae_q56s max_val_56 = AE_CVTQ48A32S(INT16_MAX); - ae_q56s zero_point_56 = AE_CVTQ48A32S(zero_point); - - const ae_p16x2s* input_data_ptr = (const ae_p16x2s*)(input_data - 2); - - ae_p24x2s scale_multiplier_24x2 = AE_MOVPA24(scale_multiplier); - - int iters = flat_size / 2; - for (int i = 0; i < iters; i++) { - // Load two 16bit pairs into the 2x24bit register PR: - // Values need to be right shifted 8 bits to align from upper 16bits to a - // 24bit value: - ae_p24x2s inputs_24x2; - AE_LP16X2F_IU(inputs_24x2, input_data_ptr, 4); - inputs_24x2 = AE_P24X2S_SRAI(inputs_24x2, 8); - - // Q0.23 * Q16.0 == Q16.23 - { - ae_q56s sum_56 = AE_MULP24S_HH(scale_multiplier_24x2, inputs_24x2); - - // Q16.23 -> Q16.0 - // Shift right only 7 bits (23 - 16). This truncated shift aligns the - // 16bit value at the truncation line for 32bit in the QR register. The - // lower 16 bits will be used for rounding in AE_ROUNDSQ32SYM. - sum_56 = AE_Q56S_SRAI(sum_56, 7); - - // Round and truncate 32 bits - sum_56 = AE_ROUNDSQ32SYM(sum_56); - - // Add offset (zero_point_56 is already aligned at 32bits. 
- sum_56 = AE_ADDQ56(sum_56, zero_point_56); - - // Saturate: - sum_56 = AE_MINQ56S(sum_56, max_val_56); - sum_56 = AE_MAXQ56S(sum_56, min_val_56); - - output_data[i * 2] = static_cast<int16_t>(AE_TRUNCA32Q48(sum_56)); - } - { - ae_q56s sum_56 = AE_MULP24S_LL(scale_multiplier_24x2, inputs_24x2); - - // Q16.23 -> Q16.0 - // Shift right only 7 bits (23 - 16). This truncated shift aligns the - // 16bit value at the truncation line for 32bit in the QR register. The - // lower 16 bits will be used for rounding in AE_ROUNDSQ32SYM. - sum_56 = AE_Q56S_SRAI(sum_56, 23 - 16); - - // Round and truncate 32 bits - sum_56 = AE_ROUNDSQ32SYM(sum_56); - - // Add offset (zero_point_56 is already aligned at 32bits. - sum_56 = AE_ADDQ56(sum_56, zero_point_56); - - // Saturate: - sum_56 = AE_MINQ56S(sum_56, max_val_56); - sum_56 = AE_MAXQ56S(sum_56, min_val_56); - - output_data[i * 2 + 1] = static_cast<int16_t>(AE_TRUNCA32Q48(sum_56)); - } - } -} - -void* Init(TfLiteContext* context, const char* buffer, size_t length) { - TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); - return context->AllocatePersistentBuffer(context, sizeof(OpData)); -} - -TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - TFLITE_DCHECK(node->user_data != nullptr); - auto* op_data = static_cast<OpData*>(node->user_data); - - TfLiteTensor* output = GetOutput(context, node, 0); - const TfLiteTensor* input = GetInput(context, node, 0); - - // TODO(b/155682734): Fix dangerous input/output scale ratio assumptions. - op_data->scale_multiplier = - CreateQConstantForInt24(0, input->params.scale / output->params.scale); - - op_data->zero_point = output->params.zero_point; - - return kTfLiteOk; -} - -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - TFLITE_DCHECK(node->user_data != nullptr); - auto* op_data = static_cast<OpData*>(node->user_data); - - const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0); - TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0); - - tflite::QuantizationParams op_params; - op_params.zero_point = op_data->zero_point; - - if (input->type != kTfLiteInt16 && output->type != kTfLiteInt8) { - TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.", - TfLiteTypeGetName(input->type), - TfLiteTypeGetName(output->type)); - return kTfLiteError; - } - - AffineQuantize(op_data->scale_multiplier, op_data->zero_point, - tflite::micro::GetTensorShape(input), - tflite::micro::GetTensorData<int16_t>(input), - tflite::micro::GetTensorShape(output), - tflite::micro::GetTensorData<int8_t>(output)); - return kTfLiteOk; -} - -} // namespace - -TfLiteRegistration Register_QUANTIZE() { - return {/*init=*/Init, - /*free=*/nullptr, - /*prepare=*/Prepare, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; -} - -} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/softmax.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini/softmax.cc deleted file mode 100644 index 75eb2838034..00000000000 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini/softmax.cc +++ /dev/null @@ -1,208 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/kernels/internal/reference/softmax.h" - -#include "tensorflow/lite/c/builtin_op_data.h" -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/kernels/internal/common.h" -#include "tensorflow/lite/kernels/internal/quantization_util.h" -#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" -#include "tensorflow/lite/kernels/kernel_util.h" -#include "tensorflow/lite/kernels/op_macros.h" -#include "tensorflow/lite/micro/kernels/kernel_util.h" - -namespace tflite { -namespace { - -struct OpData { - uint16_t* exp_lut; -}; - -// Number of unique int8_t and int16_t values. Used in exponent lookup table -// computation. -constexpr int kInt8Range = - std::numeric_limits<int8_t>::max() - std::numeric_limits<int8_t>::min() + 1; -constexpr int kInt16Range = std::numeric_limits<int16_t>::max() - - std::numeric_limits<int16_t>::min() + 1; -// Each 16-bit precalculated exponent is expressed as a Q0.16 fixedpoint -// value. We special-case e^0 since 1.0 requires 1 integer bit to -// express. -constexpr int kExpFractionalBits = 16; -// e^0 expressed as Q1.15 exceeds the int16_t range, so it must be handled -// specially. -constexpr int kMaxExponentValue = (1 << kExpFractionalBits); - -// Quantized softmax with int8_t input and int16_t output. -// Passing OpData by value does not have much savings in this op, but following -// that as a best practice, at least for the xtensa kernels. See b/155656675 for -// more details. -TfLiteStatus Softmax(OpData op_data, const RuntimeShape& input_shape, - const int8_t* input_data, const RuntimeShape& output_shape, - int16_t* output_data) { - // The last dimension is depth. Outer size is the total input size - // divided by depth. - const int trailing_dim = input_shape.DimensionsCount() - 1; - const int outer_size = - MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); - const int depth = - MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); - - for (int i = 0; i < outer_size; ++i) { - int8_t max_in_row = std::numeric_limits<int8_t>::min(); - for (int c = 0; c < depth; ++c) { - max_in_row = std::max(max_in_row, input_data[i * depth + c]); - } - - uint32_t sum_of_exps = 0; - for (int c = 0; c < depth; ++c) { - TFLITE_DCHECK(max_in_row >= input_data[i * depth + c]); - uint8_t input_diff = max_in_row - input_data[i * depth + c]; - - sum_of_exps += - input_diff == 0 ? kMaxExponentValue : op_data.exp_lut[input_diff]; - } - - // Ensure we cannot overflow the full_range_output value. We need to - // guarantee that kInt16Range * max(input_data) / sum_of_exps < kInt16Range. - TFLITE_DCHECK(sum_of_exps >= kMaxExponentValue); - - for (int c = 0; c < depth; ++c) { - uint8_t input_diff = max_in_row - input_data[i * depth + c]; - // Special case for diff == 0 - uint32_t unscaled_output = - input_diff == 0 ? 
kMaxExponentValue : op_data.exp_lut[input_diff]; - int64_t scaled_output = static_cast<int64_t>(unscaled_output) * - static_cast<int64_t>(kInt16Range); - int32_t full_range_output = - scaled_output / sum_of_exps + std::numeric_limits<int16_t>::min(); - // Round up if remainder exceeds half of the divider value. - uint32_t remainder = scaled_output % sum_of_exps; - if (remainder * 2 >= sum_of_exps) { - full_range_output++; - } - output_data[i * depth + c] = static_cast<int16_t>(std::max( - std::min(full_range_output, - static_cast<int32_t>(std::numeric_limits<int16_t>::max())), - static_cast<int32_t>(std::numeric_limits<int16_t>::min()))); - } - } - return kTfLiteOk; -} - -TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context, - const TfLiteTensor* input, - TfLiteTensor* output, - const TfLiteSoftmaxParams* params, - OpData* op_data) { - if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) { - if (input->type == kTfLiteUInt8) { - TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); - } else { - if (output->type == kTfLiteInt16) { - TF_LITE_ENSURE_EQ(context, output->params.zero_point, - std::numeric_limits<int16_t>::min()); - // NOTE: Current int16_t softmax output does not require symmetric - // scaling - // - so no need to verify scale here. - } else { - TF_LITE_ENSURE_EQ(context, output->params.zero_point, - std::numeric_limits<int8_t>::min()); - TF_LITE_ENSURE(context, output->params.scale == 1.f / 256); - } - } - - // Precompute e^(-x * input_scale * beta) for every possible int8_t input. - // This computation is used for every iteration of Softmax. We must compute - // using pre-scaled inputs to avoid introducing additional error, while - // restricting our input range to the int8_t range. This is valid since beta - // and input scale are constant for a given op in the graph. Skip index 0 - // since that is a special case which requires 1 integer bit instead of 0. - for (int i = 1; i <= kInt8Range; i++) { - float scaled_input = i * input->params.scale; - float exp_value = - std::exp((-scaled_input) * static_cast<float>(params->beta)); - - float exponent_scaled = - std::round(exp_value * static_cast<float>(1 << kExpFractionalBits)); - op_data->exp_lut[i] = static_cast<uint16_t>(exponent_scaled); - } - } - return kTfLiteOk; -} - -void* SoftmaxInit(TfLiteContext* context, const char* buffer, size_t length) { - TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); - return context->AllocatePersistentBuffer(context, sizeof(OpData)); -} - -TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) { - auto* params = static_cast<TfLiteSoftmaxParams*>(node->builtin_data); - - TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); - TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - const TfLiteTensor* input = GetInput(context, node, 0); - TfLiteTensor* output = GetOutput(context, node, 0); - TF_LITE_ENSURE(context, NumDimensions(input) >= 1); - - TFLITE_DCHECK(node->user_data != nullptr); - OpData* op_data = static_cast<OpData*>(node->user_data); - - // Allocate an array to precompute exponents over all int8_t inputs, applying - // the scale and beta before calculating exp. It is mandatory to apply beta - // and scale here, since each softmax op may have different beta and scale - // values. Beta and scale will remain constant for a given softmax op. 
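The lookup-table approach described in the comments above can be summarised with the scalar sketch below. It is illustrative only: LutSoftmaxRow is an assumed name, the table is rebuilt per call here purely for self-containment (the kernel precomputes it once in Prepare), and the zero-difference entry is special-cased as 2^16 exactly as in the code above.

// Illustrative sketch of the precomputed-exponent softmax for one row.
// Not part of the kernel above; names are assumed.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

void LutSoftmaxRow(const int8_t* input, int depth, float input_scale,
                   float beta, int16_t* output) {
  constexpr int kExpFractionalBits = 16;
  constexpr int kMaxExponentValue = 1 << kExpFractionalBits;  // e^0; does not fit uint16_t.
  // Precompute exp(-diff * input_scale * beta) in Q0.16 for every nonzero diff.
  uint16_t exp_lut[256] = {0};
  for (int i = 1; i < 256; ++i) {
    exp_lut[i] = static_cast<uint16_t>(
        std::round(std::exp(-i * input_scale * beta) * kMaxExponentValue));
  }

  const int8_t max_in_row = *std::max_element(input, input + depth);
  uint32_t sum_of_exps = 0;
  for (int c = 0; c < depth; ++c) {
    const uint8_t diff = static_cast<uint8_t>(max_in_row - input[c]);
    sum_of_exps += (diff == 0) ? kMaxExponentValue : exp_lut[diff];
  }

  constexpr int32_t kInt16Min = std::numeric_limits<int16_t>::min();
  constexpr int32_t kInt16Max = std::numeric_limits<int16_t>::max();
  constexpr int64_t kInt16Range = 65536;
  for (int c = 0; c < depth; ++c) {
    const uint8_t diff = static_cast<uint8_t>(max_in_row - input[c]);
    const uint32_t unscaled = (diff == 0) ? kMaxExponentValue : exp_lut[diff];
    // prob = unscaled / sum_of_exps, mapped onto [-32768, 32767] with rounding.
    const int64_t scaled = static_cast<int64_t>(unscaled) * kInt16Range;
    int32_t out = static_cast<int32_t>(scaled / sum_of_exps) + kInt16Min;
    if ((scaled % sum_of_exps) * 2 >= sum_of_exps) ++out;
    output[c] =
        static_cast<int16_t>(std::max(kInt16Min, std::min(kInt16Max, out)));
  }
}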
- op_data->exp_lut = static_cast<uint16_t*>(context->AllocatePersistentBuffer( - context, kInt8Range * sizeof(uint16_t))); - TF_LITE_ENSURE(context, op_data->exp_lut != nullptr); - - TF_LITE_ENSURE_STATUS( - CalculateSoftmaxOpData(context, input, output, params, op_data)); - - return kTfLiteOk; -} - -TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { - auto* op_data = static_cast<OpData*>(node->user_data); - - const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0); - TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0); - - if (input->type == kTfLiteInt8 && output->type == kTfLiteInt16) { - return Softmax(*op_data, tflite::micro::GetTensorShape(input), - tflite::micro::GetTensorData<int8_t>(input), - tflite::micro::GetTensorShape(output), - tflite::micro::GetTensorData<int16_t>(output)); - } else { - TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.", - TfLiteTypeGetName(input->type), input->type); - return kTfLiteError; - } -} - -} // namespace - -TfLiteRegistration Register_SOFTMAX() { - return {/*init=*/SoftmaxInit, - /*free=*/nullptr, - /*prepare=*/SoftmaxPrepare, - /*invoke=*/SoftmaxEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; -} - -} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/svdf.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini/svdf.cc deleted file mode 100644 index 28f8f1e1af0..00000000000 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini/svdf.cc +++ /dev/null @@ -1,420 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include <math.h> -#include <xtensa/tie/xt_hifi2.h> - -#include "tensorflow/lite/c/builtin_op_data.h" -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/kernels/internal/common.h" -#include "tensorflow/lite/kernels/internal/quantization_util.h" -#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" -#include "tensorflow/lite/kernels/kernel_util.h" -#include "tensorflow/lite/kernels/op_macros.h" -#include "tensorflow/lite/micro/kernels/activation_utils.h" -#include "tensorflow/lite/micro/kernels/kernel_util.h" -#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h" - -namespace tflite { -namespace { - -struct OpData { - int32_t effective_scale_1_a; - int32_t effective_scale_2_a; - // b versions of each scale are kept at int since the numbers are just the - // shift value - typically between [-32, 32]. - int effective_scale_1_b; - int effective_scale_2_b; - int scratch_tensor_index; - int scratch_output_tensor_index; - - // Cached tensor zero point values for quantized operations. - int input_zero_point; - int output_zero_point; -}; - -// Input tensors. 
-constexpr int kInputTensor = 0; -constexpr int kWeightsFeatureTensor = 1; -constexpr int kWeightsTimeTensor = 2; -constexpr int kBiasTensor = 3; -// This is a variable tensor, and will be modified by this op. -constexpr int kInputActivationStateTensor = 4; - -// Output tensor. -constexpr int kOutputTensor = 0; - -/** - * This version of SVDF is specific to TFLite Micro. It contains only a full - * integer receipe with optimizations for the Xtensa HiFiMini platform. - * - * Note: passing OpData by value might seem like an oversight but it helps - * reduce the latency. See b/155656675 for more details. - */ -void EvalIntegerSVDF(TfLiteContext* context, TfLiteNode* node, - const TfLiteEvalTensor* input_tensor, - const TfLiteEvalTensor* weights_feature_tensor, - const TfLiteEvalTensor* weights_time_tensor, - const TfLiteEvalTensor* bias_tensor, - const TfLiteSVDFParams* params, - TfLiteEvalTensor* activation_state_tensor, - TfLiteEvalTensor* output_tensor, OpData data) { - const int n_rank = params->rank; - const int n_batch = input_tensor->dims->data[0]; - const int n_input = input_tensor->dims->data[1]; - const int n_filter = weights_feature_tensor->dims->data[0]; - const int n_unit = n_filter / n_rank; - const int n_memory = weights_time_tensor->dims->data[1]; - - TFLITE_DCHECK(context != nullptr); - TFLITE_DCHECK(context->GetScratchBuffer != nullptr); - - int32_t* scratch_tensor = static_cast<int32_t*>( - context->GetScratchBuffer(context, data.scratch_tensor_index)); - TFLITE_DCHECK(scratch_tensor != nullptr); - int32_t* scratch_output_tensor = static_cast<int32_t*>( - context->GetScratchBuffer(context, data.scratch_output_tensor_index)); - TFLITE_DCHECK(scratch_output_tensor != nullptr); - - // Shift states. - int16_t* const state_ptr = - tflite::micro::GetTensorData<int16_t>(activation_state_tensor); - - // Left shift the activation_state. - { - int16_t* new_state_start = state_ptr; - const int16_t* old_state_start = state_ptr + 1; - const int16_t* old_state_end = state_ptr + n_batch * n_filter * n_memory; - while (old_state_start != old_state_end) { - *new_state_start++ = *old_state_start++; - } - } - - // Note: no need to clear the latest activation, matmul is not accumulative. - - // Feature matmul. 
- { - const int8_t* input = tflite::micro::GetTensorData<int8_t>(input_tensor); - const int8_t* weight_feature = - tflite::micro::GetTensorData<int8_t>(weights_feature_tensor); - int16_t* result_in_batch = state_ptr + (n_memory - 1); - - ae_q56s output_int16_max_56 = AE_CVTQ48A32S(INT16_MAX); - ae_q56s output_int16_min_56 = AE_CVTQ48A32S(INT16_MIN); - ae_p24x2s input_zp_24x2 = AE_MOVPA24(data.input_zero_point); - - for (int b = 0; b < n_batch; b++) { - const int8_t* weight_feature_ptr = weight_feature - 2; - - for (int r = 0; r < n_filter; r++) { - ae_q56s dot_prod_56 = AE_ZEROQ56(); - - const int8_t* input_batch_ptr = input + b * n_input; - const int8_t* offset_input_batch_ptr = input_batch_ptr - 2; - - int num_iters = n_input / 2; - for (int c = 0; c < num_iters; c++) { - // Load 2 sets of values: - ae_p24x2s weight_feature_ptr_24x2; - ae_p24x2s input_batch_ptr_24x2; - AE_LP8X2F_IU(weight_feature_ptr_24x2, weight_feature_ptr, 2); - AE_LP8X2F_IU(input_batch_ptr_24x2, offset_input_batch_ptr, 2); - - // Right shift the signed 8bit values to expand to signed 24bit - // values: - weight_feature_ptr_24x2 = AE_P24X2S_SRAI(weight_feature_ptr_24x2, 16); - input_batch_ptr_24x2 = AE_P24X2S_SRAI(input_batch_ptr_24x2, 16); - - // First subtract input_zp from input_batch_ptr_24x2: - input_batch_ptr_24x2 = - AE_SUBSP24S(input_batch_ptr_24x2, input_zp_24x2); - - // Multiply accum: - AE_MULAAP24S_HH_LL(dot_prod_56, weight_feature_ptr_24x2, - input_batch_ptr_24x2); - } - - // Left shift 48bit value into 24bit space and place on the PR register: - dot_prod_56 = AE_Q56S_SLAI(dot_prod_56, 24); - ae_p24x2s dot_prod_24x2 = AE_TRUNCP24Q48(dot_prod_56); - - dot_prod_56 = MultiplyByQuantizedMultiplier( - dot_prod_24x2, data.effective_scale_1_a, data.effective_scale_1_b); - - // Cap min/max and convert to int32_t: - dot_prod_56 = AE_MAXQ56S(dot_prod_56, output_int16_min_56); - dot_prod_56 = AE_MINQ56S(dot_prod_56, output_int16_max_56); - // Truncate immediately since the QR register is already 32 bit aligned: - // This assumes state is symmetrically quantized. Otherwise last bit of - // state should be initialized to its zero point and accumulate the - // dot_prod. - // Equivalent as the following: - // result_in_batch = zero point, which happens to be zero. - // result_in_batch += dot_prod_56. - *result_in_batch = AE_TRUNCA32Q48(dot_prod_56); - result_in_batch += n_memory; - } - } - } - - // Time. - { - for (int b = 0; b < n_batch; ++b) { - int32_t* scratch_ptr_batch = scratch_tensor + b * n_filter; - - // Perform batched vector dot product: - const int16_t* vector1_ptr = - tflite::micro::GetTensorData<int16_t>(weights_time_tensor); - const int16_t* vector2_ptr = state_ptr + b * n_memory * n_filter; - - const ae_p16x2s* offset_vector1 = - reinterpret_cast<const ae_p16x2s*>(vector1_ptr - 2); - const ae_p16x2s* offset_vector2 = - reinterpret_cast<const ae_p16x2s*>(vector2_ptr - 2); - - for (int i = 0; i < n_filter; i++) { - *scratch_ptr_batch = 0; - - ae_q56s sum_56 = AE_ZEROQ56(); - int num_iters = n_memory / 2; - for (int j = 0; j < num_iters; j++) { - ae_p24x2s vector1_24x2; - ae_p24x2s vector2_24x2; - AE_LP16X2F_IU(vector1_24x2, offset_vector1, 4); - AE_LP16X2F_IU(vector2_24x2, offset_vector2, 4); - AE_MULAAP24S_HH_LL(sum_56, vector1_24x2, vector2_24x2); - } - // Truncate directly since values are already 32bit aligned: - *scratch_ptr_batch = AE_TRUNCA32Q48(sum_56); - scratch_ptr_batch++; - } - } - } - - // Reduce, add bias, rescale, activation. - { - // Add bias. 
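The feature and time matmul stages above amount to the following scalar sketch. It is illustrative only: SvdfFeatureAndTime is an assumed name, and the plain double effective_scale_1 stands in for the effective_scale_1_a/effective_scale_1_b fixed-point pair used by MultiplyByQuantizedMultiplier.

// Illustrative scalar equivalent of the SVDF feature and time stages above.
// Not part of the kernel; names are assumed.
#include <algorithm>
#include <cmath>
#include <cstdint>

void SvdfFeatureAndTime(const int8_t* input, const int8_t* weight_feature,
                        const int16_t* weight_time, int16_t* state,
                        int32_t* scratch, int n_batch, int n_input,
                        int n_filter, int n_memory, int32_t input_zero_point,
                        double effective_scale_1) {
  // Feature matmul: project the new input frame into the newest state slot.
  for (int b = 0; b < n_batch; ++b) {
    for (int r = 0; r < n_filter; ++r) {
      int64_t dot = 0;
      for (int c = 0; c < n_input; ++c) {
        dot += static_cast<int32_t>(weight_feature[r * n_input + c]) *
               (static_cast<int32_t>(input[b * n_input + c]) - input_zero_point);
      }
      // Rescale into the int16 state domain and saturate.
      const int32_t rescaled =
          static_cast<int32_t>(std::lround(dot * effective_scale_1));
      state[(b * n_filter + r) * n_memory + (n_memory - 1)] =
          static_cast<int16_t>(std::max(-32768, std::min(32767, rescaled)));
    }
  }
  // Time matmul: dot product of each filter's state history with its time
  // weights, accumulated into the int32 scratch buffer.
  for (int b = 0; b < n_batch; ++b) {
    for (int r = 0; r < n_filter; ++r) {
      int64_t sum = 0;
      for (int t = 0; t < n_memory; ++t) {
        sum += static_cast<int32_t>(weight_time[r * n_memory + t]) *
               state[(b * n_filter + r) * n_memory + t];
      }
      scratch[b * n_filter + r] = static_cast<int32_t>(sum);
    }
  }
}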
- if (bias_tensor) { - // Vector batch assign: - const int32_t* bias_data = - tflite::micro::GetTensorData<int32_t>(bias_tensor); - for (int i = 0; i < n_batch; ++i) { - int32_t* output_ptr = scratch_output_tensor + i * n_unit; - const int32_t* bias_ptr = bias_data; - for (int j = 0; j < n_unit; ++j) { - *output_ptr++ = *bias_ptr++; - } - } - } else { - int32_t* output_ptr = scratch_output_tensor; - for (int i = 0; i < n_batch * n_unit; ++i) { - *output_ptr++ = 0; - } - } - - // Reduce. - for (int b = 0; b < n_batch; ++b) { - int32_t* output_temp_ptr = scratch_output_tensor + b * n_unit; - int32_t* scratch_ptr_batch = scratch_tensor + b * n_filter; - - // Reduction sum vector - for (int i = 0; i < n_unit; ++i) { - for (int j = 0; j < n_rank; ++j) { - output_temp_ptr[i] += *scratch_ptr_batch++; - } - } - } - - // Rescale. - ae_q56s output_int8_max_56 = AE_CVTQ48A32S(INT8_MAX); - ae_q56s output_int8_min_56 = AE_CVTQ48A32S(INT8_MIN); - ae_q56s output_zp_56 = AE_CVTQ48A32S(data.output_zero_point); - for (int i = 0; i < n_batch * n_unit; ++i) { - ae_q56s x_56 = MultiplyByQuantizedMultiplierResult48Bit( - scratch_output_tensor[i], data.effective_scale_2_a, - data.effective_scale_2_b); - // Add output adjustment: - x_56 = AE_ADDQ56(x_56, output_zp_56); - // Cap min/max and convert to int32_t (already aligned to 32bit): - x_56 = AE_MAXQ56S(x_56, output_int8_min_56); - x_56 = AE_MINQ56S(x_56, output_int8_max_56); - tflite::micro::GetTensorData<int8_t>(output_tensor)[i] = - static_cast<int8_t>(AE_TRUNCA32Q48(x_56)); - } - } -} - -void* Init(TfLiteContext* context, const char* buffer, size_t length) { - TFLITE_DCHECK(context != nullptr); - return context->AllocatePersistentBuffer(context, sizeof(OpData)); -} - -TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - TFLITE_DCHECK(node->builtin_data != nullptr); - const auto* params = static_cast<const TfLiteSVDFParams*>(node->builtin_data); - - // Validate Tensor Inputs (dtype depends on quantization): - // [0] = Input, {2, batch_size, input_size} - // [1] = Weights Feature, {2, num_filters, input_size} - // [2] = Weights Time, {2, num_filters, memory_size} - // [3] = Bias (optional), {1, num_units} - // [4] = Activation State (variable), - // {2, batch_size, memory_size * num_filters} - const TfLiteTensor* input = GetInput(context, node, kInputTensor); - const TfLiteTensor* weights_feature = - GetInput(context, node, kWeightsFeatureTensor); - const TfLiteTensor* weights_time = - GetInput(context, node, kWeightsTimeTensor); - const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); - const TfLiteTensor* activation_state = - GetInput(context, node, kInputActivationStateTensor); - - // Define input constants based on input tensor definition above: - const int rank = params->rank; - const int input_size = input->dims->data[1]; - const int batch_size = input->dims->data[0]; - // Ensure the input size is a multiple of two. This is necessary since - // optimized kernels access the memory in chunks of two, and all accesses - // must be aligned to 16 bits. - // TODO(b/153202598): Remove when padding is allowed in TFLite tensors. 
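Similarly, the reduce / rescale stage above can be summarised as follows. Again this is only a sketch with assumed names; the double effective_scale_2 stands in for the effective_scale_2_a/effective_scale_2_b fixed-point pair.

// Illustrative scalar equivalent of the SVDF reduce / rescale stage above.
// Not part of the kernel; names are assumed.
#include <algorithm>
#include <cmath>
#include <cstdint>

void SvdfReduceAndRescale(const int32_t* scratch, const int32_t* bias,
                          int n_batch, int n_unit, int n_rank,
                          double effective_scale_2, int32_t output_zero_point,
                          int8_t* output) {
  for (int b = 0; b < n_batch; ++b) {
    for (int u = 0; u < n_unit; ++u) {
      // Start from the bias (or zero) and sum the n_rank filter outputs
      // that feed this unit.
      int32_t acc = bias ? bias[u] : 0;
      for (int r = 0; r < n_rank; ++r) {
        acc += scratch[b * n_unit * n_rank + u * n_rank + r];
      }
      // Rescale to the int8 output domain, add the output zero point, clamp.
      int32_t out =
          static_cast<int32_t>(std::lround(acc * effective_scale_2)) +
          output_zero_point;
      output[b * n_unit + u] =
          static_cast<int8_t>(std::max(-128, std::min(127, out)));
    }
  }
}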
- TF_LITE_ENSURE_EQ(context, input_size % 2, 0); - - const int num_filters = weights_feature->dims->data[0]; - TF_LITE_ENSURE_EQ(context, num_filters % rank, 0); - const int num_units = num_filters / rank; - const int memory_size = weights_time->dims->data[1]; - - if (input->type != kTfLiteInt8) { - TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.", - TfLiteTypeGetName(input->type), input->type); - return kTfLiteError; - } - - // Validate Input Tensor: - TF_LITE_ENSURE(context, input->type == kTfLiteInt8); - TF_LITE_ENSURE_EQ(context, NumDimensions(input), 2); - - // Validate Tensor Output: - // [0] = float/int8_t, {2, batch_size, num_units} - TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); - TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - TF_LITE_ENSURE_EQ(context, NumDimensions(output), 2); - TF_LITE_ENSURE_EQ(context, output->dims->data[0], batch_size); - TF_LITE_ENSURE_EQ(context, output->dims->data[1], num_units); - - // Validate Weights Feature Input Tensor: - TF_LITE_ENSURE_EQ(context, NumDimensions(weights_feature), 2); - TF_LITE_ENSURE_EQ(context, weights_feature->dims->data[1], input_size); - - // Validate Weights Time Input Tensor: - TF_LITE_ENSURE_EQ(context, NumDimensions(weights_time), 2); - TF_LITE_ENSURE_EQ(context, weights_time->dims->data[0], num_filters); - TF_LITE_ENSURE_EQ(context, weights_time->dims->data[1], memory_size); - - // Validate Optional Bias Input Tensor: - if (bias != nullptr) { - TF_LITE_ENSURE_EQ(context, bias->dims->data[0], num_units); - TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32); - } - - // Validate Activation State Input Tensor: - TF_LITE_ENSURE_EQ(context, NumDimensions(activation_state), 2); - TF_LITE_ENSURE_EQ(context, activation_state->dims->data[0], batch_size); - TF_LITE_ENSURE_EQ(context, activation_state->dims->data[1], - memory_size * num_filters); - - TF_LITE_ENSURE_EQ(context, node->inputs->size, 5); - TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteInt8); - TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteInt16); - TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteInt16); - - // Validate output tensor: - TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8); - - const double effective_scale_1 = - static_cast<double>(input->params.scale * weights_feature->params.scale / - activation_state->params.scale); - const double effective_scale_2 = - static_cast<double>(activation_state->params.scale * - weights_time->params.scale / output->params.scale); - - TF_LITE_ENSURE_EQ(context, static_cast<double>(bias->params.scale), - static_cast<double>(activation_state->params.scale * - weights_time->params.scale)); - - TFLITE_DCHECK(node->user_data != nullptr); - OpData* data = static_cast<OpData*>(node->user_data); - - QuantizeMultiplierForInt24(effective_scale_1, &data->effective_scale_1_a, - &data->effective_scale_1_b); - QuantizeMultiplierForInt24(effective_scale_2, &data->effective_scale_2_a, - &data->effective_scale_2_b); - - data->input_zero_point = input->params.zero_point; - data->output_zero_point = output->params.zero_point; - - const TfLiteStatus scratch_status = context->RequestScratchBufferInArena( - context, batch_size * num_filters * sizeof(int32_t), - &(data->scratch_tensor_index)); - TF_LITE_ENSURE_OK(context, scratch_status); - const TfLiteStatus scratch_output_status = - context->RequestScratchBufferInArena( - context, batch_size * num_units * sizeof(int32_t), - &(data->scratch_output_tensor_index)); - TF_LITE_ENSURE_OK(context, scratch_output_status); - - return 
kTfLiteOk; -} - -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - auto* params = static_cast<TfLiteSVDFParams*>(node->builtin_data); - - const TfLiteEvalTensor* input = - tflite::micro::GetEvalInput(context, node, kInputTensor); - const TfLiteEvalTensor* weights_feature = - tflite::micro::GetEvalInput(context, node, kWeightsFeatureTensor); - const TfLiteEvalTensor* weights_time = - tflite::micro::GetEvalInput(context, node, kWeightsTimeTensor); - const TfLiteEvalTensor* bias = - (NumInputs(node) == 5) - ? tflite::micro::GetEvalInput(context, node, kBiasTensor) - : nullptr; - TfLiteEvalTensor* activation_state = tflite::micro::GetMutableEvalInput( - context, node, kInputActivationStateTensor); - TfLiteEvalTensor* output = - tflite::micro::GetEvalOutput(context, node, kOutputTensor); - - TFLITE_DCHECK(node->user_data != nullptr); - const OpData& data = *(static_cast<const OpData*>(node->user_data)); - - EvalIntegerSVDF(context, node, input, weights_feature, weights_time, bias, - params, activation_state, output, data); - return kTfLiteOk; -} - -} // namespace - -TfLiteRegistration Register_SVDF() { - return {/*init=*/Init, - /*free=*/nullptr, - /*prepare=*/Prepare, - /*invoke=*/Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; -} - -} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/fully_connected.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/fully_connected.cc deleted file mode 100644 index f9b49a2f1ae..00000000000 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/fully_connected.cc +++ /dev/null @@ -1,197 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2020 Cadence Design Systems, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining - * a copy of this software and associated documentation files (the - * "Software"), to use this Software with Cadence processor cores only and - * not with any other processors and platforms, subject to - * the following conditions: - * - * The above copyright notice and this permission notice shall be included - * in all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - ******************************************************************************/ - -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/lite/kernels/internal/reference/fully_connected.h" - -#include <xtensa/tie/xt_hifi2.h> - -#include "tensorflow/lite/c/builtin_op_data.h" -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/kernels/internal/common.h" -#include "tensorflow/lite/kernels/internal/quantization_util.h" -#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h" -#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" -#include "tensorflow/lite/kernels/kernel_util.h" -#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h" -#include "tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xtensa_tf_micro_common.h" -namespace tflite { -namespace ops { -namespace micro { - -namespace fully_connected { -namespace { - -struct OpData { - // The scaling factor from input to output (aka the 'real multiplier') can - // be represented as a fixed point multiplier plus a left shift. - int32_t output_multiplier; - int output_shift; - // The range of the fused activation layer. For example for kNone and - // uint8_t these would be 0 and 255. - int32_t output_activation_min; - int32_t output_activation_max; - // The index of the temporary tensor where the quantized inputs are cached. - int input_quantized_index; -}; - -constexpr int kInputTensor = 0; -constexpr int kWeightsTensor = 1; -constexpr int kBiasTensor = 2; -constexpr int kOutputTensor = 0; - -TfLiteStatus CalculateOpData(TfLiteContext* context, - TfLiteFusedActivation activation, - TfLiteType data_type, const TfLiteTensor* input, - const TfLiteTensor* filter, - const TfLiteTensor* bias, TfLiteTensor* output, - OpData* data) { - if (data_type != kTfLiteInt8) { - TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.", - TfLiteTypeGetName(data_type), data_type); - return kTfLiteError; - } - - double real_multiplier = 0.0; - TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( - context, input, filter, bias, output, &real_multiplier)); - xtensa::hifimini::QuantizeMultiplier( - real_multiplier, &data->output_multiplier, &data->output_shift); - return CalculateActivationRangeQuantized(context, activation, output, - &data->output_activation_min, - &data->output_activation_max); -} - -} // namespace - -void* Init(TfLiteContext* context, const char* buffer, size_t length) { - TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); - return context->AllocatePersistentBuffer(context, sizeof(OpData)); -} - -TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - TFLITE_DCHECK(node->user_data != nullptr); - TFLITE_DCHECK(node->builtin_data != nullptr); - - OpData* data = static_cast<OpData*>(node->user_data); - const auto* params = - reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data); - - const TfLiteTensor* input = GetInput(context, node, kInputTensor); - const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor); - const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); - TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - - return CalculateOpData(context, params->activation, input->type, input, - filter, bias, output, data); -} - -TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node, - const OpData& data, const TfLiteTensor* input, - const TfLiteTensor* filter, - const TfLiteTensor* bias, TfLiteTensor* output) { - // TODO(b/154032858): Investigate removing extra copies. 
- FullyConnectedParams op_params; - op_params.input_offset = -input->params.zero_point; - op_params.weights_offset = -filter->params.zero_point; - op_params.output_offset = output->params.zero_point; - op_params.output_multiplier = data.output_multiplier; - op_params.output_shift = data.output_shift; - op_params.quantized_activation_min = data.output_activation_min; - op_params.quantized_activation_max = data.output_activation_max; - - { - int ret, b, weight_depth, out_depth, batches; - int8_t* p_out = GetTensorData<int8_t>(output); - weight_depth = GetTensorShape(filter).Dims( - GetTensorShape(filter).DimensionsCount() - 1); - out_depth = GetTensorShape(output).Dims( - GetTensorShape(output).DimensionsCount() - 1); - batches = FlatSizeSkipDim(GetTensorShape(output), - GetTensorShape(output).DimensionsCount() - 1); - - // TODO: Use xa_nn_fully_connected_sym8xasym8s_asym8s? the kernel tests fail - // with it. - for (b = 0; b < batches; b++) { - ret = xa_nn_fully_connected_asym8sxasym8s_asym8s( - (GetTensorData<int8_t>(output) + b * out_depth), - GetTensorData<int8_t>(filter), - (GetTensorData<int8_t>(input) + b * weight_depth), - GetTensorData<int32_t>(bias), weight_depth, out_depth, - op_params.weights_offset, op_params.input_offset, - (op_params.output_multiplier << 8), op_params.output_shift, - op_params.output_offset); - CHECK_ERR_HIFI_NNLIB_KER( - ret, "xa_nn_fully_connected_sym8xasym8s_asym8s failed"); - } - ret = xa_nn_vec_activation_min_max_asym8s_asym8s( - p_out, p_out, data.output_activation_min, data.output_activation_max, - batches * out_depth); - CHECK_ERR_HIFI_NNLIB_KER( - ret, - "fully_connected: xa_nn_vec_activation_min_max_asym8s_asym8s failed"); - } - return kTfLiteOk; -} - -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - TFLITE_DCHECK(node->user_data != nullptr); - const OpData& data = *(static_cast<const OpData*>(node->user_data)); - - const TfLiteTensor* input = GetInput(context, node, kInputTensor); - const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor); - const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); - TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - - TFLITE_DCHECK(filter->type == kTfLiteInt8); - return EvalQuantizedInt8(context, node, data, input, filter, bias, output); -} - -} // namespace fully_connected - -TfLiteRegistration Register_FULLY_CONNECTED() { - return {/*init=*/fully_connected::Init, - /*free=*/nullptr, - /*prepare=*/fully_connected::Prepare, - /*invoke=*/fully_connected::Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; -} - -} // namespace micro -} // namespace ops -} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/quantize.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/quantize.cc deleted file mode 100644 index 13c19cc6f34..00000000000 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/quantize.cc +++ /dev/null @@ -1,172 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/kernels/internal/reference/quantize.h" - -#include <xtensa/tie/xt_hifi2.h> - -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/kernels/internal/quantization_util.h" -#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" -#include "tensorflow/lite/kernels/kernel_util.h" -#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h" - -namespace tflite { -namespace ops { -namespace micro { - -namespace xtensa { -namespace hifimini { - -void AffineQuantize(int scale_multiplier, - const tflite::QuantizationParams& op_params, - const RuntimeShape& input_shape, const int16_t* input_data, - const RuntimeShape& output_shape, int8_t* output_data) { - const int32_t zero_point = op_params.zero_point; - const int flat_size = MatchingFlatSize(input_shape, output_shape); - ae_q56s min_val_56 = AE_CVTQ48A32S(INT16_MIN); - ae_q56s max_val_56 = AE_CVTQ48A32S(INT16_MAX); - ae_q56s zero_point_56 = AE_CVTQ48A32S(zero_point); - - const ae_p16x2s* input_data_ptr = (const ae_p16x2s*)(input_data - 2); - - ae_p24x2s scale_multiplier_24x2 = AE_MOVPA24(scale_multiplier); - - int iters = flat_size / 2; - for (int i = 0; i < iters; i++) { - // Load two 16bit pairs into the 2x24bit register PR: - // Values need to be right shifted 8 bits to align from upper 16bits to a - // 24bit value: - ae_p24x2s inputs_24x2; - AE_LP16X2F_IU(inputs_24x2, input_data_ptr, 4); - inputs_24x2 = AE_P24X2S_SRAI(inputs_24x2, 8); - - // Q0.23 * Q16.0 == Q16.23 - { - ae_q56s sum_56 = AE_MULP24S_HH(scale_multiplier_24x2, inputs_24x2); - - // Q16.23 -> Q16.0 - // Shift right only 7 bits (23 - 16). This truncated shift aligns the - // 16bit value at the truncation line for 32bit in the QR register. The - // lower 16 bits will be used for rounding in AE_ROUNDSQ32SYM. - sum_56 = AE_Q56S_SRAI(sum_56, 7); - - // Round and truncate 32 bits - sum_56 = AE_ROUNDSQ32SYM(sum_56); - - // Add offset (zero_point_56 is already aligned at 32bits. - sum_56 = AE_ADDQ56(sum_56, zero_point_56); - - // Saturate: - sum_56 = AE_MINQ56S(sum_56, max_val_56); - sum_56 = AE_MAXQ56S(sum_56, min_val_56); - - output_data[i * 2] = static_cast<int16_t>(AE_TRUNCA32Q48(sum_56)); - } - { - ae_q56s sum_56 = AE_MULP24S_LL(scale_multiplier_24x2, inputs_24x2); - - // Q16.23 -> Q16.0 - // Shift right only 7 bits (23 - 16). This truncated shift aligns the - // 16bit value at the truncation line for 32bit in the QR register. The - // lower 16 bits will be used for rounding in AE_ROUNDSQ32SYM. - sum_56 = AE_Q56S_SRAI(sum_56, 23 - 16); - - // Round and truncate 32 bits - sum_56 = AE_ROUNDSQ32SYM(sum_56); - - // Add offset (zero_point_56 is already aligned at 32bits. 
- sum_56 = AE_ADDQ56(sum_56, zero_point_56); - - // Saturate: - sum_56 = AE_MINQ56S(sum_56, max_val_56); - sum_56 = AE_MAXQ56S(sum_56, min_val_56); - - output_data[i * 2 + 1] = static_cast<int16_t>(AE_TRUNCA32Q48(sum_56)); - } - } -} - -} // namespace hifimini -} // namespace xtensa - -namespace quantize { - -struct OpData { - int scale_multiplier = 0; -}; - -void* Init(TfLiteContext* context, const char* buffer, size_t length) { - TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); - return context->AllocatePersistentBuffer(context, sizeof(OpData)); -} - -TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - TFLITE_DCHECK(node->user_data != nullptr); - auto* op_data = static_cast<OpData*>(node->user_data); - - TfLiteTensor* output = GetOutput(context, node, 0); - const TfLiteTensor* input = GetInput(context, node, 0); - - // TODO(b/155682734): Fix dangerous input/output scale ratio assumptions. - op_data->scale_multiplier = xtensa::hifimini::CreateQConstantForInt24( - 0, input->params.scale / output->params.scale); - - return kTfLiteOk; -} - -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - TFLITE_DCHECK(node->user_data != nullptr); - auto* op_data = static_cast<OpData*>(node->user_data); - - const TfLiteTensor* input = GetInput(context, node, 0); - TfLiteTensor* output = GetOutput(context, node, 0); - - tflite::QuantizationParams op_params; - op_params.zero_point = output->params.zero_point; - - if (input->type != kTfLiteInt16 && output->type != kTfLiteInt8) { - TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.", - TfLiteTypeGetName(input->type), - TfLiteTypeGetName(output->type)); - return kTfLiteError; - } - - xtensa::hifimini::AffineQuantize( - op_data->scale_multiplier, op_params, GetTensorShape(input), - GetTensorData<int16_t>(input), GetTensorShape(output), - GetTensorData<int8_t>(output)); - return kTfLiteOk; -} - -} // namespace quantize - -// This Op (QUANTIZE) quantizes the input and produces quantized output. -// AffineQuantize takes scale and zero point and quantizes the float value to -// quantized output, in int8_t or uint8_t format. -TfLiteRegistration Register_QUANTIZE() { - return {/*init=*/quantize::Init, - /*free=*/nullptr, - /*prepare=*/quantize::Prepare, - /*invoke=*/quantize::Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; -} - -} // namespace micro -} // namespace ops -} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/softmax.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/softmax.cc deleted file mode 100644 index 3e5ef198928..00000000000 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/softmax.cc +++ /dev/null @@ -1,189 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2020 Cadence Design Systems, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining - * a copy of this software and associated documentation files (the - * "Software"), to use this Software with Cadence processor cores only and - * not with any other processors and platforms, subject to - * the following conditions: - * - * The above copyright notice and this permission notice shall be included - * in all copies or substantial portions of the Software. 
- * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - ******************************************************************************/ - -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/kernels/internal/reference/softmax.h" - -#include "tensorflow/lite/c/builtin_op_data.h" -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/kernels/internal/common.h" -#include "tensorflow/lite/kernels/internal/quantization_util.h" -#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" -#include "tensorflow/lite/kernels/kernel_util.h" -#include "tensorflow/lite/kernels/op_macros.h" -#include "tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xtensa_tf_micro_common.h" - -namespace tflite { -namespace ops { -namespace micro { -namespace activations { -namespace { - -struct OpData { - int32_t input_multiplier; - int32_t input_left_shift; - int32_t diff_min; - int scratch_tensor_index; -}; - -} // namespace - -TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context, - const TfLiteTensor* input, - TfLiteTensor* output, - const TfLiteSoftmaxParams* params, - OpData* op_data) { - if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) { - if (input->type == kTfLiteUInt8) { - TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); - } else { - if (output->type == kTfLiteInt16) { - TF_LITE_ENSURE_EQ(context, output->params.zero_point, - std::numeric_limits<int16_t>::min()); - // NOTE: Current int16_t softmax output does not require symmetric - // scaling - // - so no need to verify scale here. 
- } else { - TF_LITE_ENSURE_EQ(context, output->params.zero_point, - std::numeric_limits<int8_t>::min()); - TF_LITE_ENSURE(context, output->params.scale == 1.f / 256); - } - } - - static const int kScaledDiffIntegerBits = 5; - - int input_left_shift; - tflite::PreprocessSoftmaxScaling( - static_cast<double>(params->beta), - static_cast<double>(input->params.scale), kScaledDiffIntegerBits, - &op_data->input_multiplier, &input_left_shift); - op_data->input_left_shift = input_left_shift; - op_data->diff_min = - -1.0 * tflite::CalculateInputRadius(kScaledDiffIntegerBits, - op_data->input_left_shift); - } - return kTfLiteOk; -} - -void* SoftmaxInit(TfLiteContext* context, const char* buffer, size_t length) { - TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); - return context->AllocatePersistentBuffer(context, sizeof(OpData)); -} - -TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) { - auto* params = static_cast<TfLiteSoftmaxParams*>(node->builtin_data); - - TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); - TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - const TfLiteTensor* input = GetInput(context, node, 0); - TfLiteTensor* output = GetOutput(context, node, 0); - TF_LITE_ENSURE(context, NumDimensions(input) >= 1); - - TFLITE_DCHECK(node->user_data != nullptr); - OpData* op_data = static_cast<OpData*>(node->user_data); - - const RuntimeShape& input_shape = GetTensorShape(input); - const RuntimeShape& output_shape = GetTensorShape(output); - const int trailing_dim = input_shape.DimensionsCount() - 1; - const int depth = - MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); - int scratch_size = - xa_nn_get_softmax_scratch_size(PREC_SYM8S, PREC_SYM8S, depth); - - const TfLiteStatus scratch_status = context->RequestScratchBufferInArena( - context, scratch_size, &(op_data->scratch_tensor_index)); - TF_LITE_ENSURE_OK(context, scratch_status); - // Allocate an array to precompute exponents over all int8_t inputs, applying - // the scale and beta before calculating exp. It is mandatory to apply beta - // and scale here, since each softmax op may have different beta and scale - // values. Beta and scale will remain constant for a given softmax op. 
- - TF_LITE_ENSURE_STATUS( - CalculateSoftmaxOpData(context, input, output, params, op_data)); - - return kTfLiteOk; -} - -TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { - auto* op_data = static_cast<OpData*>(node->user_data); - - const TfLiteTensor* input = GetInput(context, node, 0); - TfLiteTensor* output = GetOutput(context, node, 0); - - if (input->type == kTfLiteInt8 && output->type == kTfLiteInt16) { - const RuntimeShape& input_shape = GetTensorShape(input); - const int8_t* input_data = GetTensorData<int8_t>(input); - const RuntimeShape& output_shape = GetTensorShape(output); - int16_t* output_data = GetTensorData<int16_t>(output); - const int trailing_dim = input_shape.DimensionsCount() - 1; - const int outer_size = - MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); - const int depth = - MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); - - void* p_scratch = static_cast<void*>( - context->GetScratchBuffer(context, op_data->scratch_tensor_index)); - TFLITE_DCHECK(p_scratch != nullptr); - - for (int i = 0; i < outer_size; ++i) { - int err = xa_nn_vec_softmax_asym8s_16( - &output_data[i * depth], &input_data[i * depth], op_data->diff_min, - op_data->input_left_shift, op_data->input_multiplier, depth, - p_scratch); - CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_vec_softmax_asym8s_16 failed"); - } - return kTfLiteOk; - } else { - TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.", - TfLiteTypeGetName(input->type), input->type); - return kTfLiteError; - } -} -} // namespace activations - -TfLiteRegistration Register_SOFTMAX() { - return {/*init=*/activations::SoftmaxInit, - /*free=*/nullptr, - /*prepare=*/activations::SoftmaxPrepare, - /*invoke=*/activations::SoftmaxEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; -} - -} // namespace micro -} // namespace ops -} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/svdf.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/svdf.cc deleted file mode 100644 index 05256f33306..00000000000 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/svdf.cc +++ /dev/null @@ -1,356 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2020 Cadence Design Systems, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining - * a copy of this software and associated documentation files (the - * "Software"), to use this Software with Cadence processor cores only and - * not with any other processors and platforms, subject to - * the following conditions: - * - * The above copyright notice and this permission notice shall be included - * in all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - ******************************************************************************/ - -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include <math.h> -#include <xtensa/tie/xt_hifi2.h> - -#include "tensorflow/lite/c/builtin_op_data.h" -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/kernels/internal/common.h" -#include "tensorflow/lite/kernels/internal/quantization_util.h" -#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" -#include "tensorflow/lite/kernels/kernel_util.h" -#include "tensorflow/lite/kernels/op_macros.h" -#include "tensorflow/lite/micro/kernels/activation_utils.h" -#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h" -#include "tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xtensa_tf_micro_common.h" - -namespace tflite { -namespace ops { -namespace micro { -namespace svdf { -namespace { - -struct OpData { - int32_t effective_scale_1_a; - int32_t effective_scale_2_a; - // b versions of each scale are kept at int since the numbers are just the - // shift value - typically between [-32, 32]. - int effective_scale_1_b; - int effective_scale_2_b; - int scratch_tensor_index; - int scratch_output_tensor_index; -}; - -// Input tensors. -constexpr int kInputTensor = 0; -constexpr int kWeightsFeatureTensor = 1; -constexpr int kWeightsTimeTensor = 2; -constexpr int kBiasTensor = 3; -// This is a variable tensor, and will be modified by this op. -constexpr int kInputActivationStateTensor = 4; - -// Output tensor. -constexpr int kOutputTensor = 0; - -/** - * This version of SVDF is specific to TFLite Micro. It contains only a full - * integer receipe with optimizations for the Xtensa HiFiMini platform. - * - * Note: passing OpData by value might seem like an oversight but it helps - * reduce the latency. See b/155656675 for more details. - */ -TfLiteStatus EvalIntegerSVDF(TfLiteContext* context, TfLiteNode* node, - const TfLiteTensor* input_tensor, - const TfLiteTensor* weights_feature_tensor, - const TfLiteTensor* weights_time_tensor, - const TfLiteTensor* bias_tensor, - const TfLiteSVDFParams* params, - TfLiteTensor* activation_state_tensor, - TfLiteTensor* output_tensor, OpData data, - int32_t input_zp, int32_t output_zp) { - const int n_rank = params->rank; - const int n_batch = input_tensor->dims->data[0]; - const int n_input = input_tensor->dims->data[1]; - const int n_filter = weights_feature_tensor->dims->data[0]; - const int n_unit = n_filter / n_rank; - const int n_memory = weights_time_tensor->dims->data[1]; - - TFLITE_DCHECK(context != nullptr); - TFLITE_DCHECK(context->GetScratchBuffer != nullptr); - - int32_t* scratch_tensor = static_cast<int32_t*>( - context->GetScratchBuffer(context, data.scratch_tensor_index)); - TFLITE_DCHECK(scratch_tensor != nullptr); - int32_t* scratch_output_tensor = static_cast<int32_t*>( - context->GetScratchBuffer(context, data.scratch_output_tensor_index)); - TFLITE_DCHECK(scratch_output_tensor != nullptr); - - // Shift states. 
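  // The activation state is a flat [n_batch * n_filter * n_memory] int16
  // buffer; each invocation drops the oldest sample per filter by moving the
  // whole buffer left one element, leaving slot (n_memory - 1) of each filter
  // free for the new feature-matmul result. The aligned HiFi2 load/store
  // loops below are, in effect, this scalar reference loop:
  //
  //   for (int i = 0; i < n_batch * n_filter * n_memory - 1; ++i) {
  //     state_ptr[i] = state_ptr[i + 1];
  //   }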
- int16_t* const state_ptr = GetTensorData<int16_t>(activation_state_tensor); - - // Left shift the activation_state. - - // 4-byte alignment check for state_ptr - if (((reinterpret_cast<int>(state_ptr)) & 0x3) == 0) { - // 4-bytes aligned processing - ae_p16x2s* new_state_start = (ae_p16x2s*)(state_ptr - 2); - const ae_p16x2s* old_state_start = (ae_p16x2s*)(state_ptr - 2); - int loopcnt = (n_batch * n_filter * n_memory) - 1; - ae_p24x2s dstate, dtmp, dout; - - AE_LP16X2F_IU(dtmp, old_state_start, 4); - AE_LP16X2F_IU(dstate, old_state_start, 4); - for (int i = 0; i < (loopcnt >> 1); i++) { - dout = AE_SELP24_LH(dtmp, dstate); - dtmp = dstate; - AE_LP16X2F_IU(dstate, old_state_start, 4); - AE_SP16X2F_IU(dout, new_state_start, 4); - } - if (loopcnt & 0x1) { - AE_SP16F_L_I(dtmp, (ae_p16s*)new_state_start, 4); - } - } else { - // 2-bytes aligned processing - ae_p16s* new_state_start = (ae_p16s*)(state_ptr - 1); - const ae_p16s* old_state_start = (ae_p16s*)(state_ptr); - int loopcnt = (n_batch * n_filter * n_memory) - 1; - ae_p24x2s dstate; - for (int i = 0; i < loopcnt; i++) { - AE_LP16F_IU(dstate, old_state_start, 2); - AE_SP16F_L_IU(dstate, new_state_start, 2); - } - } - // Note: no need to clear the latest activation, matmul is not accumulative. - - // Feature matmul. - { - int16_t* state = GetTensorData<int16_t>(activation_state_tensor); - const int8_t* input = GetTensorData<int8_t>(input_tensor); - const int8_t* weight_feature = - GetTensorData<int8_t>(weights_feature_tensor); - int16_t* result_in_batch = state + (n_memory - 1); - int err = 0; - - for (int b = 0; b < n_batch; b++) { - err = xa_nn_matXvec_out_stride_sym8sxasym8s_16( - &result_in_batch[b * n_filter * n_memory], weight_feature, - &input[b * n_input], NULL, n_filter, n_input, n_input, n_memory, - -input_zp, (data.effective_scale_1_a << 8), data.effective_scale_1_b); - CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_vec_matXvec_sym8sxasym8s_16 failed"); - } - } - - // Time. 
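  // Per batch b and unit u, the time step below conceptually computes
  //
  //   acc = bias[u]
  //       + sum_{j=0}^{n_rank * n_memory - 1}
  //             weights_time[u * n_rank * n_memory + j]
  //           * state[b * n_filter * n_memory + u * n_rank * n_memory + j]
  //   output[b * n_unit + u] = clamp_to_int8(requantize(acc) + output_zp)
  //
  // where requantize() applies effective_scale_2 (quantized multiplier
  // effective_scale_2_a, shift effective_scale_2_b). The call to
  // xa_nn_dot_prod_16x16_asym8s is expected to perform the n_unit dot
  // products of length n_memory * n_rank plus this requantization in one
  // pass; treat the exact rounding behavior as NNLib-defined.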
- { - for (int b = 0; b < n_batch; ++b) { - int8_t* output_ptr = GetTensorData<int8_t>(output_tensor) + b * n_unit; - - const int16_t* vector1_ptr = GetTensorData<int16_t>(weights_time_tensor); - const int16_t* vector2_ptr = - GetTensorData<int16_t>(activation_state_tensor) + - b * n_memory * n_filter; - int err = 0; - const int32_t* bias_ptr = GetTensorData<int32_t>(bias_tensor); - err = xa_nn_dot_prod_16x16_asym8s( - output_ptr, vector1_ptr, vector2_ptr, bias_ptr, n_memory * n_rank, - (data.effective_scale_2_a << 8), data.effective_scale_2_b, output_zp, - n_unit); - CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_dot_prod_16x16_asym8s failed"); - } - } - return kTfLiteOk; -} - -} // namespace - -void* Init(TfLiteContext* context, const char* buffer, size_t length) { - TFLITE_DCHECK(context != nullptr); - TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); - return context->AllocatePersistentBuffer(context, sizeof(OpData)); -} - -TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - TFLITE_DCHECK(node->builtin_data != nullptr); - const auto* params = static_cast<const TfLiteSVDFParams*>(node->builtin_data); - - // Validate Tensor Inputs (dtype depends on quantization): - // [0] = Input, {2, batch_size, input_size} - // [1] = Weights Feature, {2, num_filters, input_size} - // [2] = Weights Time, {2, num_filters, memory_size} - // [3] = Bias (optional), {1, num_units} - // [4] = Activation State (variable), - // {2, batch_size, memory_size * num_filters} - const TfLiteTensor* input = GetInput(context, node, kInputTensor); - const TfLiteTensor* weights_feature = - GetInput(context, node, kWeightsFeatureTensor); - const TfLiteTensor* weights_time = - GetInput(context, node, kWeightsTimeTensor); - const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); - const TfLiteTensor* activation_state = - GetInput(context, node, kInputActivationStateTensor); - - // Define input constants based on input tensor definition above: - const int rank = params->rank; - const int input_size = input->dims->data[1]; - const int batch_size = input->dims->data[0]; - // Ensure the input size is a multiple of two. This is necessary since - // optimized kernels access the memory in chunks of two, and all accesses - // must be aligned to 16 bits. - // TODO(b/153202598): Remove when padding is allowed in TFLite tensors. 
- TF_LITE_ENSURE_EQ(context, input_size % 2, 0); - - const int num_filters = weights_feature->dims->data[0]; - TF_LITE_ENSURE_EQ(context, num_filters % rank, 0); - const int num_units = num_filters / rank; - const int memory_size = weights_time->dims->data[1]; - - if (input->type != kTfLiteInt8) { - TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.", - TfLiteTypeGetName(input->type), input->type); - return kTfLiteError; - } - - // Validate Input Tensor: - TF_LITE_ENSURE(context, input->type == kTfLiteInt8); - TF_LITE_ENSURE_EQ(context, NumDimensions(input), 2); - - // Validate Tensor Output: - // [0] = float/int8_t, {2, batch_size, num_units} - TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); - TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - TF_LITE_ENSURE_EQ(context, NumDimensions(output), 2); - TF_LITE_ENSURE_EQ(context, output->dims->data[0], batch_size); - TF_LITE_ENSURE_EQ(context, output->dims->data[1], num_units); - - // Validate Weights Feature Input Tensor: - TF_LITE_ENSURE_EQ(context, NumDimensions(weights_feature), 2); - TF_LITE_ENSURE_EQ(context, weights_feature->dims->data[1], input_size); - - // Validate Weights Time Input Tensor: - TF_LITE_ENSURE_EQ(context, NumDimensions(weights_time), 2); - TF_LITE_ENSURE_EQ(context, weights_time->dims->data[0], num_filters); - TF_LITE_ENSURE_EQ(context, weights_time->dims->data[1], memory_size); - - // Validate Optional Bias Input Tensor: - if (bias != nullptr) { - TF_LITE_ENSURE_EQ(context, bias->dims->data[0], num_units); - TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32); - } - - // Validate Activation State Input Tensor: - TF_LITE_ENSURE_EQ(context, NumDimensions(activation_state), 2); - TF_LITE_ENSURE_EQ(context, activation_state->dims->data[0], batch_size); - TF_LITE_ENSURE_EQ(context, activation_state->dims->data[1], - memory_size * num_filters); - - TF_LITE_ENSURE_EQ(context, node->inputs->size, 5); - TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteInt8); - TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteInt16); - TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteInt16); - - // Validate output tensor: - TF_LITE_ENSURE_EQ(context, output->type, kTfLiteInt8); - - // Calculate effective scales. 
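  // The two scale factors computed here fold all per-tensor quantization
  // parameters into one real multiplier per matmul stage:
  //
  //   effective_scale_1 = s_input * s_weights_feature / s_activation_state
  //   effective_scale_2 = s_activation_state * s_weights_time / s_output
  //
  // QuantizeMultiplier then re-expresses each real value as an integer
  // mantissa (effective_scale_*_a) and a power-of-two exponent
  // (effective_scale_*_b), so Eval needs only integer multiplies and shifts;
  // the exact mantissa normalization is whatever the xtensa::hifimini helper
  // defines.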
- auto* input_params = - static_cast<TfLiteAffineQuantization*>(input->quantization.params); - auto* weights_feature_params = static_cast<TfLiteAffineQuantization*>( - weights_feature->quantization.params); - auto* state_params = static_cast<TfLiteAffineQuantization*>( - activation_state->quantization.params); - auto* weight_time_params = - static_cast<TfLiteAffineQuantization*>(weights_time->quantization.params); - auto* output_params = - static_cast<TfLiteAffineQuantization*>(output->quantization.params); - const float effective_scale_1 = input_params->scale->data[0] * - weights_feature_params->scale->data[0] / - state_params->scale->data[0]; - const float effective_scale_2 = state_params->scale->data[0] * - weight_time_params->scale->data[0] / - output_params->scale->data[0]; - - TFLITE_DCHECK(node->user_data != nullptr); - OpData* data = static_cast<OpData*>(node->user_data); - - xtensa::hifimini::QuantizeMultiplier(effective_scale_1, - &data->effective_scale_1_a, - &data->effective_scale_1_b); - xtensa::hifimini::QuantizeMultiplier(effective_scale_2, - &data->effective_scale_2_a, - &data->effective_scale_2_b); - - const TfLiteStatus scratch_status = context->RequestScratchBufferInArena( - context, batch_size * num_filters * sizeof(int32_t), - &(data->scratch_tensor_index)); - TF_LITE_ENSURE_OK(context, scratch_status); - const TfLiteStatus scratch_output_status = - context->RequestScratchBufferInArena( - context, batch_size * num_units * sizeof(int32_t), - &(data->scratch_output_tensor_index)); - TF_LITE_ENSURE_OK(context, scratch_output_status); - - return kTfLiteOk; -} - -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - auto* params = static_cast<TfLiteSVDFParams*>(node->builtin_data); - - const TfLiteTensor* input = GetInput(context, node, kInputTensor); - const TfLiteTensor* weights_feature = - GetInput(context, node, kWeightsFeatureTensor); - const TfLiteTensor* weights_time = - GetInput(context, node, kWeightsTimeTensor); - const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); - TfLiteTensor* activation_state = - GetVariableInput(context, node, kInputActivationStateTensor); - TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActRelu); - - TFLITE_DCHECK(node->user_data != nullptr); - const OpData& data = *(static_cast<const OpData*>(node->user_data)); - - return EvalIntegerSVDF(context, node, input, weights_feature, weights_time, - bias, params, activation_state, output, data, - input->params.zero_point, output->params.zero_point); -} - -} // namespace svdf - -TfLiteRegistration Register_SVDF() { - return {/*init=*/svdf::Init, - /*free=*/nullptr, - /*prepare=*/svdf::Prepare, - /*invoke=*/svdf::Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; -} - -} // namespace micro -} // namespace ops -} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/common/include/xa_api_defs.h b/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/common/include/xa_api_defs.h deleted file mode 100644 index a3eac676bbe..00000000000 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/common/include/xa_api_defs.h +++ /dev/null @@ -1,65 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2019-2020 Cadence Design Systems, Inc. 
- * - * Permission is hereby granted, free of charge, to any person obtaining - * a copy of this software and associated documentation files (the - * "Software"), to use this Software with Cadence processor cores only and - * not with any other processors and platforms, subject to - * the following conditions: - * - * The above copyright notice and this permission notice shall be included - * in all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - ******************************************************************************/ - -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef __XA_API_DEFS_H__ -#define __XA_API_DEFS_H__ - -/*****************************************************************************/ -/* Constant hash defines */ -/*****************************************************************************/ -/* A constant to let API copy small strings to buffers outside */ -#define XA_API_STR_LEN 30 -#define XA_APIVERSION_MAJOR 1 -#define XA_APIVERSION_MINOR 0 - -/* last compatible version */ -/* sometimes a new API version is just for a bugfix, or a added feature in */ -/* this case it is better to use a newer version even though a library was */ -/* made for an older version, library API can then be upgraded to newer API */ -/* version after checking for compatibility or by adding features */ -#define XA_LASTCOMP_APIVERSION_MAJOR 1 -#define XA_LASTCOMP_APIVERSION_MINOR 0 - -#define XA_STR(str) #str -#define XA_MAKE_VERSION_STR(maj, min) XA_STR(maj) "." XA_STR(min) -#define XA_APIVERSION \ - XA_MAKE_VERSION_STR(XA_APIVERSION_MAJOR, XA_APIVERSION_MINOR) - -#define XA_LAST_COMP_APIVERSION \ - XA_MAKE_VERSION_STR(XA_LASTCOMP_APIVERSION_MAJOR, \ - XA_LASTCOMP_APIVERSION_MINOR) - -#endif /* __XA_API_DEFS_H__ */ diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/common/include/xa_nnlib_common.h b/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/common/include/xa_nnlib_common.h deleted file mode 100644 index 71e668299e6..00000000000 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/common/include/xa_nnlib_common.h +++ /dev/null @@ -1,55 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2019-2020 Cadence Design Systems, Inc. 
- * - * Permission is hereby granted, free of charge, to any person obtaining - * a copy of this software and associated documentation files (the - * "Software"), to use this Software with Cadence processor cores only and - * not with any other processors and platforms, subject to - * the following conditions: - * - * The above copyright notice and this permission notice shall be included - * in all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - ******************************************************************************/ - -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef __XA_NNLIB_COMMON_H__ -#define __XA_NNLIB_COMMON_H__ - -#include <inttypes.h> -#include <stddef.h> -#include <xtensa/config/core-isa.h> -#include <xtensa/tie/xt_core.h> -#include <xtensa/tie/xt_hifi2.h> -#include <xtensa/tie/xt_misc.h> -#if XCHAL_HAVE_HIFI4_VFPU -#include <xtensa/tie/xt_FP.h> -#endif - -#include "tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/common/include/xa_nnlib_err_chk.h" -#include "tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/nnlib/xa_nnlib_kernels_api.h" -#include "tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/nnlib/xa_nnlib_standards.h" -#include "tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/xa_type_def.h" - -#endif /* __XA_NNLIB_COMMON_H__ */ diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/common/include/xa_nnlib_common_macros.h b/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/common/include/xa_nnlib_common_macros.h deleted file mode 100644 index d04752b3a12..00000000000 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/common/include/xa_nnlib_common_macros.h +++ /dev/null @@ -1,921 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2019-2020 Cadence Design Systems, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining - * a copy of this software and associated documentation files (the - * "Software"), to use this Software with Cadence processor cores only and - * not with any other processors and platforms, subject to - * the following conditions: - * - * The above copyright notice and this permission notice shall be included - * in all copies or substantial portions of the Software. 
- * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - ******************************************************************************/ - -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef __XA_NNLIB_COMMON_MACROS_H__ -#define __XA_NNLIB_COMMON_MACROS_H__ - -#ifndef NULL -#define NULL (void *)0 -#endif /* NULL */ - -#define ALIGNMENT 8 - -/* Macro for zero value */ -#define ZERO64 AE_MOVINT64_FROMINT32X2(AE_MOVDA32(0)) -#define ZERO16X4 AE_MOVDA16(0) -#define ZERO16 (0) -#define ZERO32 (0) - -/* Macro for 1 */ -#define ONE16X4 AE_MOVDA16(1) - -/* Value of ROW_UNROLL currently supported are 1,2,4,8 only */ -#ifndef ROW_UNROLL -#define ROW_UNROLL 8 -#endif -#define VEC_UNROLL 2 - -#define ACC_LSH_AFTER_FIRST_MATXVEC 0 - -/* Increment in bytes required for particular load - * instructions. */ -#define INCREMENT_IN_BYTES_FOR_WORD8 1 -#define INCREMENT_IN_BYTES_FOR_INT16 2 -#define INCREMENT_IN_BYTES_FOR_INT32 (INCREMENT_IN_BYTES_FOR_INT16 * 2) -#define INCREMENT_IN_BYTES_FOR_WORD8X4 (INCREMENT_IN_BYTES_FOR_WORD8 * 4) -#define INCREMENT_IN_BYTES_FOR_INT16X4 (INCREMENT_IN_BYTES_FOR_INT16 * 4) -#define INCREMENT_IN_BYTES_FOR_INT64 INCREMENT_IN_BYTES_FOR_INT16X4 -#define INCREMENT_IN_BYTES_FOR_FLOAT32 4 -#define INCREMENT_IN_BYTES_FOR_FLOAT32x2 (INCREMENT_IN_BYTES_FOR_FLOAT32 * 2) - -#define HF2_AE_ADDCIRC16X4_XC(ptr, offset) \ - ptr = ptr + offset; \ - if (ptr >= p_end) ptr = ptr - size; - -#define MULTIPLY_BY_QUANTIZED_MULTIPLIER(q_out, inp, out_multiplier, \ - left_shift, right_shift) \ - { \ - ae_q56s d1; \ - ae_p24x2s d_mul; \ - d_mul = AE_CVTP24A16X2_HL(out_multiplier, out_multiplier); \ - d1 = AE_CVTQ48A32S(inp); \ - d1 = AE_SLLAQ56(d1, left_shift); \ - q_out = AE_MULFQ32SP16U_L(d1, d_mul); \ - q_out = AE_SRAIQ56(q_out, 16); \ - AE_MULAFQ32SP16S_H(q_out, d1, d_mul); \ - q_out = AE_SRAAQ56(q_out, right_shift); \ - q_out = AE_ROUNDSQ32SYM(q_out); \ - } - -/* Limit effective bias_shift and acc_shift to [-63 ... 63] */ -#define LIMIT_VARIABLE(_var, _left_limit, _right_limit) \ - _var = _var > _right_limit ? _right_limit \ - : _var < _left_limit ? 
_left_limit : _var; - -#define LIMIT_ACC_LSH LIMIT_VARIABLE(acc_shift, -63, 63); - -#define LIMIT_BIAS_LSH LIMIT_VARIABLE(bias_shift, -63, 63); - -#define BW(_datatype) sizeof(_datatype) - -#define ADJUST_VAR_AxB(A, B) (((8 * (4 - (BW(A) + BW(B)))))) - -#define ADJUST_VAR_C(C) (((64 - (8 * BW(C))))) - -#define ADJUST_ACC_LSH_AxB_C(A, B, C) \ - acc_shift = acc_shift + 32; \ - LIMIT_ACC_LSH; - -#define ADJUST_BIAS_LSH_AxB(A, B) LIMIT_BIAS_LSH; - -#define ADJUST_ACC_LSH_AND_BIAS_LSH_AxB_C(A, B, C) \ - ADJUST_ACC_LSH_AxB_C(A, B, C); \ - ADJUST_BIAS_LSH_AxB(A, B); - -/* ==================================================================================================== - */ -#define SETUP_BIAS_f32 \ - xtfloat _xtfloat_bias = (xtfloat)0.0f; \ - xtfloat *_xtfloat_p_bias = (xtfloat *)p_bias; - -#define SETUP_BIAS_ASYM8b \ - WORD32 _WORD32_bias; \ - ae_int64 _ae_int64_sat_bias = ZERO64; \ - WORD32 *_WORD32_p_bias = (WORD32 *)p_bias; - -#define SETUP_BIAS_8b \ - WORD8 _WORD8_bias; \ - UWORD32 _UWORD32_bias; \ - ae_int64 _ae_int64_bias = ZERO64; \ - ae_int64 _ae_int64_sat_bias = ZERO64; \ - WORD8 *_WORD8_p_bias = (WORD8 *)p_bias; - -#define SETUP_BIAS_8b_BATCH \ - WORD8 _WORD8_bias; \ - WORD16 _WORD16_bias; \ - ae_int16 _ae_int16_bias = ZERO16; \ - ae_int16 *_ae_int16_p_bias = &_ae_int16_bias; \ - ae_int64 _ae_int64_sat_bias = ZERO64; \ - WORD8 *_WORD8_p_bias = (WORD8 *)p_bias; - -#define SETUP_BIAS_32b \ - ae_int32 _ae_int32_bias = ZERO32; \ - ae_int64 _ae_int64_sat_bias = ZERO64; \ - ae_int32 *_ae_int32_p_bias = (ae_int32 *)p_bias; - -#define SETUP_BIAS_16b \ - ae_int16 _ae_int16_bias = ZERO16; \ - ae_int64 _ae_int64_sat_bias = ZERO64; \ - ae_int16 *_ae_int16_p_bias = (ae_int16 *)p_bias; - -#define SETUP_BIAS_64b \ - ae_int64 _ae_int64_bias = ZERO64; \ - ae_int64 _ae_int64_sat_bias = ZERO64; \ - ae_int64 *_ae_int64_p_bias = (ae_int64 *)p_bias; - -#define SETUP_ACC_FOR_8bx8b(idx) SETUP_ACC_64b(idx) -#define SETUP_ACC_FOR_8bx16b(idx) SETUP_ACC_64b(idx) -#define SETUP_ACC_FOR_16bx8b(idx) SETUP_ACC_64b(idx) -#define SETUP_ACC_FOR_16bx16b(idx) SETUP_ACC_64b(idx) -#define SETUP_ACC_FOR_ASYM8bxASYM8b(idx) SETUP_ACC_64b(idx) - -/*------------------ time batching macros ----------------- */ - -#define SETUP_ACC_BATCH_ROW_FOR_16bx8b SETUP_ACC_BATCH_ROW_FOR_16bx16b -#define SETUP_ACC_BATCH_ROW_FOR_8bx16b SETUP_ACC_BATCH_ROW_FOR_16bx16b -#define SETUP_ACC_BATCH_ROW_FOR_8bx8b SETUP_ACC_BATCH_ROW_FOR_16bx16b -#define SETUP_ACC_BATCH_ROW_FOR_ASYM8bxASYM8b SETUP_ACC_BATCH_ROW_FOR_16bx16b - -#define SETUP_ACC_BATCH_FOR_16bx8b SETUP_ACC_BATCH_FOR_16bx16b -#define SETUP_ACC_BATCH_FOR_8bx16b SETUP_ACC_BATCH_FOR_16bx16b -#define SETUP_ACC_BATCH_FOR_8bx8b SETUP_ACC_BATCH_FOR_16bx16b -#define SETUP_ACC_BATCH_FOR_ASYM8bxASYM8b SETUP_ACC_BATCH_FOR_16bx16b - -#define SETUP_ACC_BATCH_ROW_FOR_16bx16b(idx_row) \ - SETUP_ACC_BATCH_VEC_UNROLL(idx_row); - -#define SETUP_ACC_BATCH_FOR_16bx16b(idx_row, idx_vec) \ - ae_int64 _ae_int64_acc_##idx_row##_##idx_vec = ZERO64; - -#define SETUP_ACC_BATCH_ROW_FOR_f32(idx_row) \ - SETUP_ACC_BATCH_VEC_UNROLL(idx_row); - -#define SETUP_ACC_BATCH_FOR_f32(idx_row, idx_vec) \ - xtfloatx2 _xtfloatx2_acc_##idx_row##_##idx_vec = (xtfloatx2)0.0f; \ - xtfloat _xtfloat_acc_##idx_row##_##idx_vec = (xtfloat)0.0f; \ - /*---------------------------------------------------------*/ - -#define SETUP_ACC_64b(idx) ae_int64 _ae_int64_acc_##idx = ZERO64; - -#define SETUP_VEC1_8b \ - ae_int16x4 _ae_int16x4_vec1 = ZERO16X4; \ - WORD8 *_WORD8_p_vec1 = (WORD8 *)p_vec1; - -#define SETUP_VEC2_8b \ - ae_int16x4 
_ae_int16x4_vec2 = ZERO16X4; \ - WORD8 *_WORD8_p_vec2 = (WORD8 *)p_vec2; - -#define SETUP_VEC1_16b \ - ae_int16x4 _ae_int16x4_vec1 = ZERO16X4; \ - ae_int16x4 *_ae_int16x4_p_vec1 = (ae_int16x4 *)p_vec1; - -#define SETUP_VEC2_16b \ - ae_int16x4 _ae_int16x4_vec2 = ZERO16X4; \ - ae_int16x4 *_ae_int16x4_p_vec2 = (ae_int16x4 *)p_vec2; - -#define SETUP_VEC1_ASYM8b SETUP_VEC1_8b -#define SETUP_VEC2_ASYM8b SETUP_VEC2_8b -/*------------------ time batching macros ----------------- */ - -#define SETUP_VEC_BATCH_8b(idx_vec) \ - ae_int16x4 _ae_int16x4_vec_batch_##idx_vec = ZERO16X4; \ - WORD8 *_WORD8_p_vec_batch_##idx_vec = (WORD8 *)(p_vec1[vec_itr + idx_vec]); - -#define SETUP_VEC_BATCH_16b(idx_vec) \ - ae_int16x4 _ae_int16x4_vec_batch_##idx_vec = ZERO16X4; \ - ae_int16x4 *_ae_int16x4_p_vec_batch_##idx_vec = \ - (ae_int16x4 *)(p_vec1[vec_itr + idx_vec]); - -#define SETUP_VEC_OFFSET_BATCH_16b(idx_vec) \ - ae_int16x4 _ae_int16x4_vec_batch_##idx_vec = ZERO16X4; \ - ae_int16x4 *_ae_int16x4_p_vec_batch_##idx_vec = \ - (ae_int16x4 *)(p_vec1 + (vec_itr + idx_vec) * vec_offset); - -#define SETUP_VEC_BATCH_f32(idx_vec) \ - xtfloatx2 _xtfloatx2_vec_batch_##idx_vec = (xtfloatx2)0.0f; \ - xtfloatx2 *_xtfloatx2_p_vec_batch_##idx_vec = \ - (xtfloatx2 *)(p_vec1[vec_itr + idx_vec]); - -#define SETUP_VEC_BATCH_ASYM8b SETUP_VEC_BATCH_8b -/*---------------------------------------------------------*/ - -#define SETUP_MAT1_8b(idx) \ - ae_int16x4 _ae_int16x4_mat1_##idx = ZERO16X4; \ - WORD8 *_WORD8_p_mat1_##idx = (WORD8 *)&p_mat1[(m_itr + idx) * row_stride1]; - -#define SETUP_MAT2_8b(idx) \ - ae_int16x4 _ae_int16x4_mat2_##idx = ZERO16X4; \ - WORD8 *_WORD8_p_mat2_##idx = (WORD8 *)&p_mat2[(m_itr + idx) * row_stride2]; - -#define SETUP_MAT1_16b(idx) \ - ae_int16x4 _ae_int16x4_mat1_##idx = ZERO16X4; \ - ae_int16x4 *_ae_int16x4_p_mat1_##idx = \ - (ae_int16x4 *)&p_mat1[(m_itr + idx) * row_stride1]; - -#define SETUP_MAT2_16b(idx) \ - ae_int16x4 _ae_int16x4_mat2_##idx = ZERO16X4; \ - ae_int16x4 *_ae_int16x4_p_mat2_##idx = \ - (ae_int16x4 *)&p_mat2[(m_itr + idx) * row_stride2]; - -#define SETUP_MAT1_f32(idx) \ - xtfloatx2 _xtfloatx2_mat1_##idx = (xtfloatx2)0.0f; \ - xtfloatx2 *_xtfloatx2_p_mat1_##idx = \ - (xtfloatx2 *)&p_mat1[(m_itr + idx) * row_stride1]; - -#define SETUP_MAT1_ASYM8b SETUP_MAT1_8b -#define SETUP_MAT2_ASYM8b SETUP_MAT2_8b -/* ====================================================================== */ - -#define LOAD_VEC1_8b \ - AE_L8X4F_IP(_ae_int16x4_vec1, _WORD8_p_vec1, INCREMENT_IN_BYTES_FOR_WORD8X4); - -#define LOAD_VEC2_8b \ - AE_L8X4F_IP(_ae_int16x4_vec2, _WORD8_p_vec2, INCREMENT_IN_BYTES_FOR_WORD8X4); - -#define LOAD_VEC1_16b \ - AE_L16X4_IP(_ae_int16x4_vec1, _ae_int16x4_p_vec1, \ - INCREMENT_IN_BYTES_FOR_INT16X4); - -#define LOAD_VEC2_16b \ - AE_L16X4_IP(_ae_int16x4_vec2, _ae_int16x4_p_vec2, \ - INCREMENT_IN_BYTES_FOR_INT16X4); - -#define LOAD_VEC1_ASYM8b \ - AE_L8X4F_IP(_ae_int16x4_vec1, _WORD8_p_vec1, \ - INCREMENT_IN_BYTES_FOR_WORD8X4); \ - _ae_int16x4_vec1 = AE_MOVF16X4_FROMF64( \ - AE_SRLI64(AE_MOVF64_FROMF16X4(_ae_int16x4_vec1), 8)); \ - _ae_int16x4_vec1 = AE_ADD16(_ae_int16x4_vec1, AE_MOVDA16(vec1_zero_bias)); - -#define LOAD_VEC2_ASYM8b \ - AE_L8X4F_IP(_ae_int16x4_vec2, _WORD8_p_vec2, \ - INCREMENT_IN_BYTES_FOR_WORD8X4); \ - _ae_int16x4_vec2 = AE_MOVF16X4_FROMF64( \ - AE_SRLI64(AE_MOVF64_FROMF16X4(_ae_int16x4_vec2), 8)); \ - _ae_int16x4_vec2 = AE_ADD16(_ae_int16x4_vec2, AE_MOVDA16(vec2_zero_bias)); \ -/*------------------ time batching macros ----------------- */ -#define LOAD_VEC_BATCH_f32(idx_vec) \ 
- XT_LSX2IP(_xtfloatx2_vec_batch_##idx_vec, _xtfloatx2_p_vec_batch_##idx_vec, \ - INCREMENT_IN_BYTES_FOR_FLOAT32x2); - -#define LOAD_VEC_BATCH_8b(idx_vec) \ - AE_L8X4F_IP(_ae_int16x4_vec_batch_##idx_vec, _WORD8_p_vec_batch_##idx_vec, \ - INCREMENT_IN_BYTES_FOR_WORD8X4); - -#define LOAD_VEC_BATCH_16b(idx_vec) \ - AE_L16X4_IP(_ae_int16x4_vec_batch_##idx_vec, \ - _ae_int16x4_p_vec_batch_##idx_vec, \ - INCREMENT_IN_BYTES_FOR_INT16X4); - -#define LOAD_VEC_BATCH_ASYM8b(idx_vec) \ - AE_L8X4F_IP(_ae_int16x4_vec_batch_##idx_vec, _WORD8_p_vec_batch_##idx_vec, \ - INCREMENT_IN_BYTES_FOR_WORD8X4); \ - _ae_int16x4_vec_batch_##idx_vec = AE_MOVF16X4_FROMF64( \ - AE_SRLI64(AE_MOVF64_FROMF16X4(_ae_int16x4_vec_batch_##idx_vec), 8)); \ - _ae_int16x4_vec_batch_##idx_vec = \ - AE_ADD16(_ae_int16x4_vec_batch_##idx_vec, AE_MOVDA16(vec1_zero_bias)); - -#define LOAD_BIAS_8b_FOR_8bx8b \ - _WORD8_bias = *_WORD8_p_bias++; \ - _WORD16_bias = _WORD8_bias; \ - *((WORD16 *)_ae_int16_p_bias) = _WORD16_bias; \ - _ae_int64_sat_bias = AE_SLAA64S(((ae_int64)_ae_int16_bias), bias_shift); - -#define LOAD_BIAS_16b_FOR_8bx16b \ - ae_int16_loadip(_ae_int16_bias, _ae_int16_p_bias, \ - INCREMENT_IN_BYTES_FOR_INT16); \ - _ae_int64_sat_bias = AE_SLAA64S(((ae_int64)_ae_int16_bias), bias_shift); - -#define LOAD_BIAS_16b_FOR_16bx8b LOAD_BIAS_16b_FOR_8bx16b - -#define LOAD_BIAS_16b_FOR_16bx16b \ - ae_int16_loadip(_ae_int16_bias, _ae_int16_p_bias, \ - INCREMENT_IN_BYTES_FOR_INT16); \ - _ae_int64_sat_bias = AE_SLAA64S(((ae_int64)_ae_int16_bias), bias_shift); - -#define LOAD_BIAS_f32 \ - XT_LSIP(_xtfloat_bias, _xtfloat_p_bias, INCREMENT_IN_BYTES_FOR_FLOAT32); - -#define LOAD_BIAS_ASYM8b \ - _WORD32_bias = *_WORD32_p_bias++; \ - _ae_int64_sat_bias = \ - AE_SRAI64(AE_MOVINT64_FROMINT32X2(AE_MOVDA32(_WORD32_bias)), 32); \ -/*---------------------------------------------------------*/ -#define LOAD_ROW_MAT1_8b(idx) \ - AE_L8X4F_IP(_ae_int16x4_mat1_##idx, _WORD8_p_mat1_##idx, \ - INCREMENT_IN_BYTES_FOR_WORD8X4); - -#define LOAD_ROW_MAT2_8b(idx) \ - AE_L8X4F_IP(_ae_int16x4_mat2_##idx, _WORD8_p_mat2_##idx, \ - INCREMENT_IN_BYTES_FOR_WORD8X4); - -#define LOAD_ROW_MAT1_16b(idx) \ - AE_L16X4_IP(_ae_int16x4_mat1_##idx, _ae_int16x4_p_mat1_##idx, \ - INCREMENT_IN_BYTES_FOR_INT16X4); - -#define LOAD_ROW_MAT2_16b(idx) \ - AE_L16X4_IP(_ae_int16x4_mat2_##idx, _ae_int16x4_p_mat2_##idx, \ - INCREMENT_IN_BYTES_FOR_INT16X4); - -#define LOAD_ROW_MAT1_f32(idx) \ - XT_LSX2IP(_xtfloatx2_mat1_##idx, _xtfloatx2_p_mat1_##idx, \ - INCREMENT_IN_BYTES_FOR_FLOAT32x2); - -#define LOAD_ROW_MAT1_ASYM8b(idx) \ - AE_L8X4F_IP(_ae_int16x4_mat1_##idx, _WORD8_p_mat1_##idx, \ - INCREMENT_IN_BYTES_FOR_WORD8X4); \ - _ae_int16x4_mat1_##idx = AE_MOVF16X4_FROMF64( \ - AE_SRLI64(AE_MOVF64_FROMF16X4(_ae_int16x4_mat1_##idx), 8)); \ - _ae_int16x4_mat1_##idx = \ - AE_ADD16(_ae_int16x4_mat1_##idx, AE_MOVDA16(mat1_zero_bias)); - -#define LOAD_ROW_MAT2_ASYM8b(idx) \ - AE_L8X4F_IP(_ae_int16x4_mat2_##idx, _WORD8_p_mat2_##idx, \ - INCREMENT_IN_BYTES_FOR_WORD8X4); \ - _ae_int16x4_mat2_##idx = AE_MOVF16X4_FROMF64( \ - AE_SRLI64(AE_MOVF64_FROMF16X4(_ae_int16x4_mat2_##idx), 8)); \ - _ae_int16x4_mat2_##idx = \ - AE_ADD16(_ae_int16x4_mat2_##idx, AE_MOVDA16(mat2_zero_bias)); - -#define KERNEL_MAT1_VEC1_8b_8b(idx) \ - LOAD_ROW_MAT1_8b(idx); \ - AE_MULAAAAQ16(_ae_int64_acc_##idx, _ae_int16x4_vec1, _ae_int16x4_mat1_##idx); - -#define KERNEL_MAT2_VEC2_8b_8b(idx) \ - LOAD_ROW_MAT2_8b(idx); \ - AE_MULAAAAQ16(_ae_int64_acc_##idx, _ae_int16x4_vec2, _ae_int16x4_mat2_##idx); - -#define KERNEL_MAT1_VEC1_16b_8b(idx) 
\ - LOAD_ROW_MAT1_16b(idx); \ - AE_MULAAAAQ16(_ae_int64_acc_##idx, _ae_int16x4_vec1, _ae_int16x4_mat1_##idx); - -#define KERNEL_MAT2_VEC2_16b_8b(idx) \ - LOAD_ROW_MAT2_16b(idx); \ - AE_MULAAAAQ16(_ae_int64_acc_##idx, _ae_int16x4_vec2, _ae_int16x4_mat2_##idx); - -#define KERNEL_MAT1_VEC1_8b_16b(idx) \ - LOAD_ROW_MAT1_8b(idx); \ - AE_MULAAAAQ16(_ae_int64_acc_##idx, _ae_int16x4_vec1, _ae_int16x4_mat1_##idx); - -#define KERNEL_MAT2_VEC2_8b_16b(idx) \ - LOAD_ROW_MAT2_8b(idx); \ - AE_MULAAAAQ16(_ae_int64_acc_##idx, _ae_int16x4_vec2, _ae_int16x4_mat2_##idx); - -#define KERNEL_MAT1_VEC1_16b_16b(idx) \ - LOAD_ROW_MAT1_16b(idx); \ - AE_MULAAAAQ16(_ae_int64_acc_##idx, _ae_int16x4_vec1, _ae_int16x4_mat1_##idx); - -#define KERNEL_MAT2_VEC2_16b_16b(idx) \ - LOAD_ROW_MAT2_16b(idx); \ - AE_MULAAAAQ16(_ae_int64_acc_##idx, _ae_int16x4_vec2, _ae_int16x4_mat2_##idx); - -#define KERNEL_MAT1_VEC1_ASYM8b_ASYM8b(idx) \ - LOAD_ROW_MAT1_ASYM8b(idx); \ - AE_MULAAAAQ16(_ae_int64_acc_##idx, _ae_int16x4_vec1, _ae_int16x4_mat1_##idx); - -#define KERNEL_MAT2_VEC2_ASYM8b_ASYM8b(idx) \ - LOAD_ROW_MAT2_ASYM8b(idx); \ - AE_MULAAAAQ16(_ae_int64_acc_##idx, _ae_int16x4_vec2, _ae_int16x4_mat2_##idx); - -/*------------------ time batching macros ----------------- */ - -#define KERNEL_MAT1_VEC_BATCH_ROW_8b_8b KERNEL_MAT1_VEC_BATCH_ROW_16b_16b -#define KERNEL_MAT1_VEC_BATCH_ROW_16b_8b KERNEL_MAT1_VEC_BATCH_ROW_16b_16b -#define KERNEL_MAT1_VEC_BATCH_ROW_8b_16b KERNEL_MAT1_VEC_BATCH_ROW_16b_16b -#define KERNEL_MAT1_VEC_BATCH_ROW_ASYM8b_ASYM8b \ - KERNEL_MAT1_VEC_BATCH_ROW_16b_16b -#define KERNEL_MAT1_VEC_BATCH_8b_8b KERNEL_MAT1_VEC_BATCH_16b_16b -#define KERNEL_MAT1_VEC_BATCH_16b_8b KERNEL_MAT1_VEC_BATCH_16b_16b -#define KERNEL_MAT1_VEC_BATCH_8b_16b KERNEL_MAT1_VEC_BATCH_16b_16b -#define KERNEL_MAT1_VEC_BATCH_ASYM8b_ASYM8b KERNEL_MAT1_VEC_BATCH_16b_16b - -#define KERNEL_MAT1_VEC_BATCH_ROW_16b_16b(idx_row) \ - KERNEL_MAT1_VEC_BATCH_VEC_UNROLL(idx_row); - -#define KERNEL_MAT1_VEC_BATCH_16b_16b(idx_row, idx_vec) \ - AE_MULAAAAQ16(_ae_int64_acc_##idx_row##_##idx_vec, \ - _ae_int16x4_vec_batch_##idx_vec, _ae_int16x4_mat1_##idx_row); - -#define KERNEL_MAT1_VEC_BATCH_ROW_f32(idx_row) \ - KERNEL_MAT1_VEC_BATCH_VEC_UNROLL(idx_row); - -#define KERNEL_MAT1_VEC_BATCH_f32(idx_row, idx_vec) \ - XT_MADD_SX2(_xtfloatx2_acc_##idx_row##_##idx_vec, \ - _xtfloatx2_vec_batch_##idx_vec, _xtfloatx2_mat1_##idx_row); - -/*---------------------------------------------------------*/ -#define ADD_BIAS_8b_ACC_FOR_8bx8b(idx) \ - /* Load 8b bias */ \ - _WORD8_bias = *_WORD8_p_bias++; \ - /* Copy 8-bits to unsigned 32-bits */ \ - _UWORD32_bias = _WORD8_bias; \ - /*Move unsigned 32 bit value to DR register*/ \ - _ae_int64_bias = AE_MOVINT64_FROMINT32X2((AE_MOVDA32X2(_UWORD32_bias, 0))); \ - _ae_int64_bias = AE_SRAA64(_ae_int64_bias, 32); \ - _ae_int64_sat_bias = AE_SLAA64S(_ae_int64_bias, bias_shift); \ - _ae_int64_acc_##idx = AE_SRAA64(_ae_int64_acc_##idx, 16); \ - _ae_int64_acc_##idx = AE_ADD64S(_ae_int64_acc_##idx, _ae_int64_sat_bias); - -#define ADD_BIAS_32b_ACC_FOR_8bx8b(idx) \ - ae_int32_loadip(_ae_int32_bias, _ae_int32_p_bias, \ - INCREMENT_IN_BYTES_FOR_INT32); \ - _ae_int64_sat_bias = AE_SLAA64S(((ae_int64)_ae_int32_bias), bias_shift); \ - _ae_int64_acc_##idx = AE_SRAA64(_ae_int64_acc_##idx, 16); \ - _ae_int64_acc_##idx = AE_ADD64S(_ae_int64_acc_##idx, _ae_int64_sat_bias); - -#define ADD_BIAS_16b_ACC_FOR_8bx16b(idx) \ - ae_int16_loadip(_ae_int16_bias, _ae_int16_p_bias, \ - INCREMENT_IN_BYTES_FOR_INT16); \ - /* Saturate 16b bias after shift to 64b */ \ - 
_ae_int64_sat_bias = AE_SLAA64S(((ae_int64)_ae_int16_bias), bias_shift); \ - _ae_int64_acc_##idx = AE_SRAA64(_ae_int64_acc_##idx, 8); \ - _ae_int64_acc_##idx = AE_ADD64S(_ae_int64_acc_##idx, _ae_int64_sat_bias); - -#define ADD_BIAS_16b_ACC_FOR_16bx8b ADD_BIAS_16b_ACC_FOR_8bx16b - -#define ADD_BIAS_64b_ACC_FOR_8bx16b(idx) \ - ae_int64_loadip(_ae_int64_bias, _ae_int64_p_bias, \ - INCREMENT_IN_BYTES_FOR_INT64); \ - /* Saturate 64b bias after shift to 64b */ \ - _ae_int64_sat_bias = AE_SLAA64S(((ae_int64)_ae_int64_bias), bias_shift); \ - _ae_int64_acc_##idx = AE_SRAA64(_ae_int64_acc_##idx, 8); \ - _ae_int64_acc_##idx = AE_ADD64S(_ae_int64_acc_##idx, _ae_int64_sat_bias); - -#define ADD_BIAS_16b_ACC_FOR_16bx16b(idx) \ - ae_int16_loadip(_ae_int16_bias, _ae_int16_p_bias, \ - INCREMENT_IN_BYTES_FOR_INT16); \ - /* Saturate 16b bias after shift to 64b */ \ - _ae_int64_sat_bias = AE_SLAA64S(((ae_int64)_ae_int16_bias), bias_shift); \ - _ae_int64_acc_##idx = AE_ADD64S(_ae_int64_acc_##idx, _ae_int64_sat_bias); - -#define ADD_BIAS_64b_ACC_FOR_16bx16b(idx) \ - ae_int64_loadip(_ae_int64_bias, _ae_int64_p_bias, \ - INCREMENT_IN_BYTES_FOR_INT64); \ - /* Saturate 64b bias after shift to 64b */ \ - _ae_int64_sat_bias = AE_SLAA64S(((ae_int64)_ae_int64_bias), bias_shift); \ - _ae_int64_acc_##idx = AE_ADD64S(_ae_int64_acc_##idx, _ae_int64_sat_bias); - -#define ADD_BIAS_ASYM8b_ACC_FOR_ASYM8bxASYM8b(idx) \ - /* Load 32b bias */ \ - _WORD32_bias = *_WORD32_p_bias++; \ - _ae_int64_sat_bias = \ - AE_SRAI64(AE_MOVINT64_FROMINT32X2(AE_MOVDA32(_WORD32_bias)), 32); \ - _ae_int64_acc_##idx = AE_ADD64S(_ae_int64_acc_##idx, _ae_int64_sat_bias); - -/*------------------ time batching macros ----------------- */ -#define ADD_BIAS_BATCH_ROW_8b_ACC_FOR_8bx8b(idx_row) \ - LOAD_BIAS_8b_FOR_8bx8b; \ - ADD_BIAS_BATCH_ACC_VEC_UNROLL(idx_row); - -#define ADD_BIAS_BATCH_ROW_16b_ACC_FOR_8bx16b(idx_row) \ - LOAD_BIAS_16b_FOR_8bx16b; \ - ADD_BIAS_BATCH_ACC_VEC_UNROLL(idx_row); - -#define ADD_BIAS_BATCH_ROW_16b_ACC_FOR_16bx8b(idx_row) \ - LOAD_BIAS_16b_FOR_16bx8b; \ - ADD_BIAS_BATCH_ACC_VEC_UNROLL(idx_row); - -#define ADD_BIAS_BATCH_ROW_16b_ACC_FOR_16bx16b(idx_row) \ - LOAD_BIAS_16b_FOR_16bx16b; \ - ADD_BIAS_BATCH_ACC_VEC_UNROLL(idx_row); - -#define ADD_BIAS_BATCH_ROW_ASYM8b_ACC_FOR_ASYM8bxASYM8b(idx_row) \ - LOAD_BIAS_ASYM8b ADD_BIAS_BATCH_ACC_VEC_UNROLL(idx_row); - -#define ADD_BIAS_BATCH_8b_ACC_FOR_8bx8b(idx_row, idx_vec) \ - _ae_int64_acc_##idx_row##_##idx_vec = \ - AE_SRAA64(_ae_int64_acc_##idx_row##_##idx_vec, 16); \ - _ae_int64_acc_##idx_row##_##idx_vec = \ - AE_ADD64S(_ae_int64_acc_##idx_row##_##idx_vec, _ae_int64_sat_bias); - -#define ADD_BIAS_BATCH_16b_ACC_FOR_8bx16b(idx_row, idx_vec) \ - _ae_int64_acc_##idx_row##_##idx_vec = \ - AE_SRAA64(_ae_int64_acc_##idx_row##_##idx_vec, 8); \ - _ae_int64_acc_##idx_row##_##idx_vec = \ - AE_ADD64S(_ae_int64_acc_##idx_row##_##idx_vec, _ae_int64_sat_bias); - -#define ADD_BIAS_BATCH_16b_ACC_FOR_16bx16b(idx_row, idx_vec) \ - _ae_int64_acc_##idx_row##_##idx_vec = \ - AE_ADD64S(_ae_int64_acc_##idx_row##_##idx_vec, _ae_int64_sat_bias); - -#define ADD_BIAS_BATCH_16b_ACC_FOR_16bx8b ADD_BIAS_BATCH_16b_ACC_FOR_8bx16b -#define ADD_BIAS_BATCH_ASYM8b_ACC_FOR_ASYM8bxASYM8b \ - ADD_BIAS_BATCH_16b_ACC_FOR_16bx16b - -#define ADD_BIAS_BATCH_ROW_ACC_FOR_f32(idx_row) \ - LOAD_BIAS_f32; \ - ADD_BIAS_BATCH_ACC_VEC_UNROLL(idx_row); - -#define ADD_BIAS_BATCH_ACC_FOR_f32(idx_row, idx_vec) \ - _xtfloat_acc_##idx_row##_##idx_vec = \ - XT_RADD_SX2(_xtfloatx2_acc_##idx_row##_##idx_vec); \ - 
_xtfloat_acc_##idx_row##_##idx_vec = \ - XT_ADD_S(_xtfloat_acc_##idx_row##_##idx_vec, _xtfloat_bias); - -#define STORE_ACC_8bx8b_AT_SCRATCH_32b(idx) \ - (*((ae_int32 *)p_scratch + m_itr + idx)) = \ - AE_ROUND32F64SSYM(AE_SLAA64S(_ae_int64_acc_##idx, acc_shift)); - -#define STORE_ACC_8bx8b_AT_OUT_8b(idx) \ - ae_int32 _ae_int32_tmp_var_##idx; \ - ae_f32x2 _ae_f32x2_tmp_var_##idx = AE_SLAA32S( \ - AE_ROUND32F64SSYM(AE_SLAA64S(_ae_int64_acc_##idx, acc_shift)), 24); \ - _ae_int32_tmp_var_##idx = AE_SLAA32S(_ae_f32x2_tmp_var_##idx, -24); \ - (*((WORD8 *)p_out + m_itr + idx)) = (*((UWORD32 *)&_ae_int32_tmp_var_##idx)); - -#define STORE_ACC_8bx8b_AT_OUT_16b(idx) \ - ae_int32 _ae_int32_tmp_var_##idx; \ - ae_f32x2 _ae_f32x2_tmp_var_##idx = AE_SLAA32S( \ - AE_ROUND32F64SSYM(AE_SLAA64S(_ae_int64_acc_##idx, acc_shift)), 16); \ - _ae_int32_tmp_var_##idx = AE_SLAA32S(_ae_f32x2_tmp_var_##idx, -16); \ - (*((WORD16 *)p_out + m_itr + idx)) = (*((UWORD32 *)&_ae_int32_tmp_var_##idx)); - -#define STORE_ACC_8bx8b_AT_OUT_32b(idx) \ - (*((ae_int32 *)p_out + m_itr + idx)) = \ - AE_ROUND32F64SSYM(AE_SLAA64S(_ae_int64_acc_##idx, acc_shift)); - -#define STORE_ACC_ASYM8bxASYM8b_AT_OUT_ASYM8b(idx) \ - _ae_int32x2_acc_##idx = AE_MIN32( \ - AE_MAX32(_ae_int32x2_acc_##idx, AE_MOVDA32(0)), AE_MOVDA32(255)); \ - (*((UWORD8 *)p_out + m_itr + idx)) = \ - (UWORD8)AE_MOVAD32_L(_ae_int32x2_acc_##idx); - -/* ==================================================================================================== - */ -#define STORE_ACC_8bx16b_AT_SCRATCH_32b(idx) \ - (*((ae_int32 *)p_scratch + m_itr + idx)) = \ - AE_ROUND32F64SSYM(AE_SLAA64S(_ae_int64_acc_##idx, acc_shift)); - -#define STORE_ACC_8bx16b_AT_OUT_16b(idx) \ - ae_int32 _ae_int32_tmp_var_##idx; \ - ae_f32x2 _ae_f32x2_tmp_var_##idx = AE_SLAA32S( \ - AE_ROUND32F64SSYM(AE_SLAA64S(_ae_int64_acc_##idx, acc_shift)), 16); \ - _ae_int32_tmp_var_##idx = AE_SLAA32S(_ae_f32x2_tmp_var_##idx, -16); \ - (*((WORD16 *)p_out + m_itr + idx)) = (*((UWORD32 *)&_ae_int32_tmp_var_##idx)); - -#define STORE_ACC_16bx8b_AT_OUT_16b STORE_ACC_8bx16b_AT_OUT_16b - -#define STORE_ACC_8bx16b_AT_OUT_32b(idx) \ - (*((ae_int32 *)p_out + m_itr + idx)) = \ - AE_ROUND32F64SSYM(AE_SLAA64S(_ae_int64_acc_##idx, acc_shift)); - -#define STORE_ACC_8bx16b_AT_OUT_64b(idx) \ - (*((ae_int64 *)p_out + m_itr + idx)) = \ - AE_SLAA64S(_ae_int64_acc_##idx, acc_shift); - -/* ==================================================================================================== - */ -#define STORE_ACC_16bx16b_AT_SCRATCH_32b(idx) \ - (*((ae_int32 *)p_scratch + m_itr + idx)) = \ - AE_ROUND32F64SSYM(AE_SLAA64S(_ae_int64_acc_##idx, acc_shift)); - -#define STORE_ACC_16bx16b_AT_OUT_16b(idx) \ - ae_int32 _ae_int32_tmp_var_##idx; \ - ae_f32x2 _ae_f32x2_tmp_var_##idx = AE_SLAA32S( \ - AE_ROUND32F64SSYM(AE_SLAA64S(_ae_int64_acc_##idx, acc_shift)), 16); \ - _ae_int32_tmp_var_##idx = AE_SLAA32S(_ae_f32x2_tmp_var_##idx, -16); \ - (*((WORD16 *)p_out + m_itr + idx)) = (*((UWORD32 *)&_ae_int32_tmp_var_##idx)); - -#define STORE_ACC_16bx16b_AT_OUT_32b(idx) \ - (*((ae_int32 *)p_out + m_itr + idx)) = \ - AE_ROUND32F64SSYM(AE_SLAA64S(_ae_int64_acc_##idx, acc_shift)); - -#define STORE_ACC_16bx16b_AT_OUT_64b(idx) \ - (*((ae_int64 *)p_out + m_itr + idx)) = \ - AE_SLAA64S(_ae_int64_acc_##idx, acc_shift); - -/*------------------ time batching macros ----------------- */ -#define STORE_ACC_BATCH_ROW_8bx8b_AT_OUT_32b(idx_row) \ - STORE_ACC_BATCH_VEC_UNROLL(idx_row); - -#define STORE_ACC_BATCH_ROW_8bx8b_AT_OUT_8b(idx_row) \ - 
STORE_ACC_BATCH_VEC_UNROLL(idx_row); - -#define STORE_ACC_BATCH_8bx8b_AT_OUT_32b(idx_row, idx_vec) \ - (*((ae_int32 *)p_out[vec_itr + idx_vec] + m_itr + idx_row)) = \ - AE_ROUND32F64SSYM( \ - AE_SLAA64S(_ae_int64_acc_##idx_row##_##idx_vec, acc_shift)); - -#define STORE_ACC_BATCH_8bx8b_AT_OUT_8b(idx_row, idx_vec) \ - ae_int32 _ae_int32_tmp_var_##idx_row##_##idx_vec; \ - ae_f32x2 _ae_f32x2_tmp_var_##idx_row##_##idx_vec = \ - AE_SLAA32S(AE_ROUND32F64SSYM(AE_SLAA64S( \ - _ae_int64_acc_##idx_row##_##idx_vec, acc_shift)), \ - 24); \ - _ae_int32_tmp_var_##idx_row##_##idx_vec = \ - AE_SLAA32S(_ae_f32x2_tmp_var_##idx_row##_##idx_vec, -24); \ - (*((WORD8 *)p_out[vec_itr + idx_vec] + m_itr + idx_row)) = \ - (*((UWORD32 *)&_ae_int32_tmp_var_##idx_row##_##idx_vec)); - -#define STORE_ACC_BATCH_ROW_8bx16b_AT_OUT_64b(idx_row) \ - STORE_ACC_BATCH_VEC_UNROLL(idx_row); - -#define STORE_ACC_BATCH_ROW_16bx8b_AT_OUT_16b \ - STORE_ACC_BATCH_ROW_8bx16b_AT_OUT_64b - -#define STORE_ACC_BATCH_ROW_8bx16b_AT_OUT_16b \ - STORE_ACC_BATCH_ROW_8bx16b_AT_OUT_64b - -#define STORE_ACC_BATCH_8bx16b_AT_OUT_64b(idx_row, idx_vec) \ - (*((ae_int64 *)p_out[vec_itr + idx_vec] + m_itr + idx_row)) = \ - AE_SLAA64S(_ae_int64_acc_##idx_row##_##idx_vec, acc_shift); - -#define STORE_ACC_BATCH_8bx16b_AT_OUT_16b(idx_row, idx_vec) \ - STORE_ACC_BATCH_16bx16b_AT_OUT_16b(idx_row, idx_vec); - -#define STORE_ACC_BATCH_ROW_16bx16b_AT_OUT_64b(idx_row) \ - STORE_ACC_BATCH_VEC_UNROLL(idx_row); - -#define STORE_ACC_BATCH_ROW_16bx16b_AT_OUT_16b \ - STORE_ACC_BATCH_ROW_16bx16b_AT_OUT_64b - -#define STORE_ACC_BATCH_16bx16b_AT_OUT_64b(idx_row, idx_vec) \ - (*((ae_int64 *)p_out[vec_itr + idx_vec] + m_itr + idx_row)) = \ - AE_SLAA64S(_ae_int64_acc_##idx_row##_##idx_vec, acc_shift); - -#define STORE_STRIDE_ACC_BATCH_16bx16b_AT_OUT_16b(idx_row, idx_vec) \ - ae_int32 _ae_int32_tmp_var_##idx_row##_##idx_vec; \ - ae_f32x2 _ae_f32x2_tmp_var_##idx_row##_##idx_vec = \ - AE_SLAA32S(AE_ROUND32F64SSYM(AE_SLAA64S( \ - _ae_int64_acc_##idx_row##_##idx_vec, acc_shift)), \ - 16); \ - _ae_int32_tmp_var_##idx_row##_##idx_vec = \ - AE_SLAA32S(_ae_f32x2_tmp_var_##idx_row##_##idx_vec, -16); \ - (*((WORD16 *)p_out + (vec_itr + idx_vec) * out_offset + \ - (m_itr + idx_row) * out_stride)) = \ - (*((UWORD32 *)&_ae_int32_tmp_var_##idx_row##_##idx_vec)); - -#define STORE_ACC_BATCH_ROW_AT_OUT_f32(idx_row) \ - STORE_ACC_BATCH_VEC_UNROLL(idx_row); - -#define STORE_ACC_BATCH_AT_OUT_f32(idx_row, idx_vec) \ - /*p_out value stored in a tmp pointer to make it inout for ISA */ \ - p_out_tmp = (p_out[vec_itr + idx_vec] + m_itr + idx_row); \ - XT_SSIP(_xtfloat_acc_##idx_row##_##idx_vec, p_out_tmp, 0); - -#define STORE_ACC_BATCH_ROW_ASYM8bxASYM8b_AT_OUT_ASYM8b(idx_row) \ - STORE_ACC_BATCH_VEC_UNROLL(idx_row); - -#define STORE_ACC_BATCH_ASYM8bxASYM8b_AT_OUT_ASYM8b(idx_row, idx_vec) \ - _ae_int32x2_acc_##idx_row##_##idx_vec = \ - AE_MIN32(AE_MAX32(_ae_int32x2_acc_##idx_row##_##idx_vec, AE_MOVDA32(0)), \ - AE_MOVDA32(255)); \ - (*((UWORD8 *)(p_out[vec_itr + idx_vec] + m_itr + idx_row))) = \ - (UWORD8)AE_MOVAD32_L(_ae_int32x2_acc_##idx_row##_##idx_vec); - -/*---------------------------------------------------------*/ -/* Specific macros needed for extra calculations involved - for ASYM8b */ - -/* This is written to match with Tensorflow */ -#define ADJUST_ACC_ASYM8b(idx) \ - /* Multiply accumulator with 'out_multiplier', same as Tensorflow */ \ - ae_int32x2 _ae_int32x2_acc_##idx = \ - AE_SLAA32(AE_MOVINT32X2_FROMINT64(_ae_int64_acc_##idx), left_shift); \ - _ae_int32x2_acc_##idx = \ - 
AE_MULFP32X2RAS(_ae_int32x2_acc_##idx, AE_MOVDA32(out_multiplier)); \ - /* Shift by out_shift, same as Tensorflow */ \ - _ae_int64_acc_##idx = \ - AE_SLAI64(AE_MOVINT64_FROMINT32X2(_ae_int32x2_acc_##idx), 32); \ - _ae_int64_acc_##idx = AE_SRAA64(_ae_int64_acc_##idx, right_shift); \ - _ae_int32x2_acc_##idx = AE_ROUND32F64SSYM(_ae_int64_acc_##idx); \ - /* Add output zero point */ \ - (_ae_int32x2_acc_##idx) = \ - AE_ADD32S(_ae_int32x2_acc_##idx, AE_MOVDA32(out_zero_bias)); - -/* For time batching */ -#define ADJUST_ACC_BATCH_ROW_ASYM8b(idx_row) \ - ADJUST_ACC_BATCH_VEC_UNROLL(idx_row); - -/* For time batching */ -#define ADJUST_ACC_BATCH_ASYM8b(idx_row, idx_vec) \ - /* Multiply accumulator with 'out_multiplier', same as Tensorflow */ \ - ae_int32x2 _ae_int32x2_acc_##idx_row##_##idx_vec = \ - AE_SLAA32(AE_MOVINT32X2_FROMINT64(_ae_int64_acc_##idx_row##_##idx_vec), \ - left_shift); \ - _ae_int32x2_acc_##idx_row##_##idx_vec = AE_MULFP32X2RAS( \ - _ae_int32x2_acc_##idx_row##_##idx_vec, AE_MOVDA32(out_multiplier)); \ - /* Shift by out_shift, same as Tensorflow */ \ - _ae_int64_acc_##idx_row##_##idx_vec = AE_SLAI64( \ - AE_MOVINT64_FROMINT32X2(_ae_int32x2_acc_##idx_row##_##idx_vec), 32); \ - _ae_int64_acc_##idx_row##_##idx_vec = \ - AE_SRAA64(_ae_int64_acc_##idx_row##_##idx_vec, right_shift); \ - _ae_int32x2_acc_##idx_row##_##idx_vec = \ - AE_ROUND32F64SSYM(_ae_int64_acc_##idx_row##_##idx_vec); \ - /* Add output zero point */ \ - (_ae_int32x2_acc_##idx_row##_##idx_vec) = AE_ADD32S( \ - _ae_int32x2_acc_##idx_row##_##idx_vec, AE_MOVDA32(out_zero_bias)); - -/*---------------------------------------------------------*/ -/* ==================================================================================================== - */ -#if (ROW_UNROLL == 1) -#define SETUP_ACC UNROLL_SETUP_ACC(0) -#define SETUP_MAT1 UNROLL_SETUP_MAT1(0) -#define SETUP_MAT2 UNROLL_SETUP_MAT2(0) -#define KERNEL_MAT1_VEC1 UNROLL_KERNEL_MAT1_VEC1(0) -#define KERNEL_MAT2_VEC2 UNROLL_KERNEL_MAT2_VEC2(0) -#define ADD_BIAS_ACC UNROLL_ADD_BIAS_ACC(0) -#define ADJUST_ACC UNROLL_ADJUST_ACC(0) -#define STORE_ACC UNROLL_STORE_ACC(0) - -#elif (ROW_UNROLL == 2) -#define SETUP_ACC UNROLL_SETUP_ACC(0) UNROLL_SETUP_ACC(1) -#define SETUP_MAT1 UNROLL_SETUP_MAT1(0) UNROLL_SETUP_MAT1(1) -#define SETUP_MAT2 UNROLL_SETUP_MAT2(0) UNROLL_SETUP_MAT2(1) -#define KERNEL_MAT1_VEC1 UNROLL_KERNEL_MAT1_VEC1(0) UNROLL_KERNEL_MAT1_VEC1(1) -#define KERNEL_MAT2_VEC2 UNROLL_KERNEL_MAT2_VEC2(0) UNROLL_KERNEL_MAT2_VEC2(1) -#define ADD_BIAS_ACC UNROLL_ADD_BIAS_ACC(0) UNROLL_ADD_BIAS_ACC(1) -#define ADJUST_ACC UNROLL_ADJUST_ACC(0) UNROLL_ADJUST_ACC(1) -#define STORE_ACC UNROLL_STORE_ACC(0) UNROLL_STORE_ACC(1) - -#elif (ROW_UNROLL == 4) -#define SETUP_ACC \ - UNROLL_SETUP_ACC(0) \ - UNROLL_SETUP_ACC(1) UNROLL_SETUP_ACC(2) UNROLL_SETUP_ACC(3) -#define SETUP_MAT1 \ - UNROLL_SETUP_MAT1(0) \ - UNROLL_SETUP_MAT1(1) UNROLL_SETUP_MAT1(2) UNROLL_SETUP_MAT1(3) -#define SETUP_MAT2 \ - UNROLL_SETUP_MAT2(0) \ - UNROLL_SETUP_MAT2(1) UNROLL_SETUP_MAT2(2) UNROLL_SETUP_MAT2(3) -#define KERNEL_MAT1_VEC1 \ - UNROLL_KERNEL_MAT1_VEC1(0) \ - UNROLL_KERNEL_MAT1_VEC1(1) \ - UNROLL_KERNEL_MAT1_VEC1(2) UNROLL_KERNEL_MAT1_VEC1(3) -#define KERNEL_MAT2_VEC2 \ - UNROLL_KERNEL_MAT2_VEC2(0) \ - UNROLL_KERNEL_MAT2_VEC2(1) \ - UNROLL_KERNEL_MAT2_VEC2(2) UNROLL_KERNEL_MAT2_VEC2(3) -#define ADD_BIAS_ACC \ - UNROLL_ADD_BIAS_ACC(0) \ - UNROLL_ADD_BIAS_ACC(1) UNROLL_ADD_BIAS_ACC(2) UNROLL_ADD_BIAS_ACC(3) -#define ADJUST_ACC \ - UNROLL_ADJUST_ACC(0) \ - UNROLL_ADJUST_ACC(1) UNROLL_ADJUST_ACC(2) 
UNROLL_ADJUST_ACC(3) -#define STORE_ACC \ - UNROLL_STORE_ACC(0) \ - UNROLL_STORE_ACC(1) UNROLL_STORE_ACC(2) UNROLL_STORE_ACC(3) - -#elif (ROW_UNROLL == 8) -#define SETUP_ACC \ - UNROLL_SETUP_ACC(0) \ - UNROLL_SETUP_ACC(1) \ - UNROLL_SETUP_ACC(2) \ - UNROLL_SETUP_ACC(3) \ - UNROLL_SETUP_ACC(4) \ - UNROLL_SETUP_ACC(5) UNROLL_SETUP_ACC(6) UNROLL_SETUP_ACC(7) -#define SETUP_MAT1 \ - UNROLL_SETUP_MAT1(0) \ - UNROLL_SETUP_MAT1(1) \ - UNROLL_SETUP_MAT1(2) \ - UNROLL_SETUP_MAT1(3) \ - UNROLL_SETUP_MAT1(4) \ - UNROLL_SETUP_MAT1(5) UNROLL_SETUP_MAT1(6) UNROLL_SETUP_MAT1(7) -#define SETUP_MAT2 \ - UNROLL_SETUP_MAT2(0) \ - UNROLL_SETUP_MAT2(1) \ - UNROLL_SETUP_MAT2(2) \ - UNROLL_SETUP_MAT2(3) \ - UNROLL_SETUP_MAT2(4) \ - UNROLL_SETUP_MAT2(5) UNROLL_SETUP_MAT2(6) UNROLL_SETUP_MAT2(7) -#define KERNEL_MAT1_VEC1 \ - UNROLL_KERNEL_MAT1_VEC1(0) \ - UNROLL_KERNEL_MAT1_VEC1(1) \ - UNROLL_KERNEL_MAT1_VEC1(2) \ - UNROLL_KERNEL_MAT1_VEC1(3) \ - UNROLL_KERNEL_MAT1_VEC1(4) \ - UNROLL_KERNEL_MAT1_VEC1(5) \ - UNROLL_KERNEL_MAT1_VEC1(6) UNROLL_KERNEL_MAT1_VEC1(7) -#define KERNEL_MAT2_VEC2 \ - UNROLL_KERNEL_MAT2_VEC2(0) \ - UNROLL_KERNEL_MAT2_VEC2(1) \ - UNROLL_KERNEL_MAT2_VEC2(2) \ - UNROLL_KERNEL_MAT2_VEC2(3) \ - UNROLL_KERNEL_MAT2_VEC2(4) \ - UNROLL_KERNEL_MAT2_VEC2(5) \ - UNROLL_KERNEL_MAT2_VEC2(6) UNROLL_KERNEL_MAT2_VEC2(7) -#define ADD_BIAS_ACC \ - UNROLL_ADD_BIAS_ACC(0) \ - UNROLL_ADD_BIAS_ACC(1) \ - UNROLL_ADD_BIAS_ACC(2) \ - UNROLL_ADD_BIAS_ACC(3) \ - UNROLL_ADD_BIAS_ACC(4) \ - UNROLL_ADD_BIAS_ACC(5) UNROLL_ADD_BIAS_ACC(6) UNROLL_ADD_BIAS_ACC(7) -#define ADJUST_ACC \ - UNROLL_ADJUST_ACC(0) \ - UNROLL_ADJUST_ACC(1) \ - UNROLL_ADJUST_ACC(2) \ - UNROLL_ADJUST_ACC(3) \ - UNROLL_ADJUST_ACC(4) \ - UNROLL_ADJUST_ACC(5) UNROLL_ADJUST_ACC(6) UNROLL_ADJUST_ACC(7) -#define STORE_ACC \ - UNROLL_STORE_ACC(0) \ - UNROLL_STORE_ACC(1) \ - UNROLL_STORE_ACC(2) \ - UNROLL_STORE_ACC(3) \ - UNROLL_STORE_ACC(4) \ - UNROLL_STORE_ACC(5) UNROLL_STORE_ACC(6) UNROLL_STORE_ACC(7) - -#endif /* (ROW_UNROLL == 1) */ - -#if (ROW_UNROLL == 4 && VEC_UNROLL == 2) - -#define SETUP_VEC_BATCH UNROLL_SETUP_VEC_BATCH(0) UNROLL_SETUP_VEC_BATCH(1) - -#define SETUP_ACC_BATCH \ - UNROLL_ROW_SETUP_ACC_BATCH(0) \ - UNROLL_ROW_SETUP_ACC_BATCH(1) \ - UNROLL_ROW_SETUP_ACC_BATCH(2) UNROLL_ROW_SETUP_ACC_BATCH(3) -#define SETUP_ACC_BATCH_VEC_UNROLL(idx_row) \ - UNROLL_SETUP_ACC_BATCH(idx_row, 0) UNROLL_SETUP_ACC_BATCH(idx_row, 1) -#define SETUP_ACC_BATCH_TAIL \ - UNROLL_SETUP_ACC_BATCH(0, 0) \ - UNROLL_SETUP_ACC_BATCH(1, 0) \ - UNROLL_SETUP_ACC_BATCH(2, 0) UNROLL_SETUP_ACC_BATCH(3, 0) - -#define LOAD_VEC_BATCH UNROLL_LOAD_VEC_BATCH(0) UNROLL_LOAD_VEC_BATCH(1) -#define LOAD_MAT1 \ - UNROLL_LOAD_ROW_MAT1(0) \ - UNROLL_LOAD_ROW_MAT1(1) UNROLL_LOAD_ROW_MAT1(2) UNROLL_LOAD_ROW_MAT1(3) - -#define KERNEL_MAT1_VEC_BATCH \ - UNROLL_ROW_KERNEL_MAT1_VEC_BATCH(0) \ - UNROLL_ROW_KERNEL_MAT1_VEC_BATCH(1) \ - UNROLL_ROW_KERNEL_MAT1_VEC_BATCH(2) UNROLL_ROW_KERNEL_MAT1_VEC_BATCH(3) -#define KERNEL_MAT1_VEC_BATCH_VEC_UNROLL(idx_row) \ - UNROLL_KERNEL_MAT1_VEC_BATCH(idx_row, 0) \ - UNROLL_KERNEL_MAT1_VEC_BATCH(idx_row, 1) -#define KERNEL_MAT1_VEC_BATCH_TAIL \ - UNROLL_KERNEL_MAT1_VEC_BATCH(0, 0) \ - UNROLL_KERNEL_MAT1_VEC_BATCH(1, 0) \ - UNROLL_KERNEL_MAT1_VEC_BATCH(2, 0) UNROLL_KERNEL_MAT1_VEC_BATCH(3, 0) - -#define ADD_BIAS_ACC_BATCH \ - UNROLL_ROW_ADD_BIAS_ACC(0) \ - UNROLL_ROW_ADD_BIAS_ACC(1) \ - UNROLL_ROW_ADD_BIAS_ACC(2) UNROLL_ROW_ADD_BIAS_ACC(3) -#define ADD_BIAS_BATCH_ACC_VEC_UNROLL(idx_row) \ - UNROLL_ADD_BIAS_ACC_BATCH(idx_row, 0) UNROLL_ADD_BIAS_ACC_BATCH(idx_row, 1) 
-#define ADD_BIAS_ACC_BATCH_TAIL \ - LOAD_BIAS UNROLL_ADD_BIAS_ACC_BATCH(0, 0) \ - LOAD_BIAS UNROLL_ADD_BIAS_ACC_BATCH(1, 0) \ - LOAD_BIAS UNROLL_ADD_BIAS_ACC_BATCH(2, 0) \ - LOAD_BIAS UNROLL_ADD_BIAS_ACC_BATCH(3, 0) - -#define STORE_ACC_BATCH \ - UNROLL_ROW_STORE_ACC(0) \ - UNROLL_ROW_STORE_ACC(1) UNROLL_ROW_STORE_ACC(2) UNROLL_ROW_STORE_ACC(3) -#define STORE_ACC_BATCH_VEC_UNROLL(idx_row) \ - UNROLL_STORE_ACC_BATCH(idx_row, 0) UNROLL_STORE_ACC_BATCH(idx_row, 1) -#define STORE_ACC_BATCH_TAIL \ - UNROLL_STORE_ACC_BATCH(0, 0) \ - UNROLL_STORE_ACC_BATCH(1, 0) \ - UNROLL_STORE_ACC_BATCH(2, 0) UNROLL_STORE_ACC_BATCH(3, 0) - -#define ADJUST_ACC_BATCH_TAIL \ - UNROLL_ADJUST_ACC_BATCH(0, 0) \ - UNROLL_ADJUST_ACC_BATCH(1, 0) \ - UNROLL_ADJUST_ACC_BATCH(2, 0) UNROLL_ADJUST_ACC_BATCH(3, 0) -#define ADJUST_ACC_BATCH \ - UNROLL_ROW_ADJUST_ACC(0) \ - UNROLL_ROW_ADJUST_ACC(1) UNROLL_ROW_ADJUST_ACC(2) UNROLL_ROW_ADJUST_ACC(3) -#define ADJUST_ACC_BATCH_VEC_UNROLL(idx_row) \ - UNROLL_ADJUST_ACC_BATCH(idx_row, 0) UNROLL_ADJUST_ACC_BATCH(idx_row, 1) - -#endif /* (ROW_UNROLL == 4 && VEC_UNROLL == 2)*/ - -#endif /* __XA_NNLIB_COMMON_MACROS_H__ */ diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/common/include/xa_nnlib_definitions.h b/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/common/include/xa_nnlib_definitions.h deleted file mode 100644 index 7199887f501..00000000000 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/common/include/xa_nnlib_definitions.h +++ /dev/null @@ -1,57 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2019-2020 Cadence Design Systems, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining - * a copy of this software and associated documentation files (the - * "Software"), to use this Software with Cadence processor cores only and - * not with any other processors and platforms, subject to - * the following conditions: - * - * The above copyright notice and this permission notice shall be included - * in all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - ******************************************************************************/ - -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef __XA_OPUS_CODEC_DEFINITIONS_H__ -#define __XA_OPUS_CODEC_DEFINITIONS_H__ - -#include "tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/common/include/xa_api_defs.h" - -/* Identification Strings */ -#define LIBNAME "HiFi Mini Neural Network Library" -#define LIBVERSION "0.6.0" - -#define LIB_APIVERSION_MAJOR 1 -#define LIB_APIVERSION_MINOR 0 - -#if LIB_APIVERSION_MAJOR != XA_APIVERSION_MAJOR || \ - LIB_APIVERSION_MINOR != XA_APIVERSION_MINOR -// #error "Version Mismatch" -#endif - -#define LIB_APIVERSION \ - XA_MAKE_VERSION_STR(LIB_APIVERSION_MAJOR, LIB_APIVERSION_MINOR) - -#endif /* __XA_OPUS_CODEC_DEFINITIONS_H__ */ diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/common/include/xa_nnlib_err_chk.h b/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/common/include/xa_nnlib_err_chk.h deleted file mode 100644 index 8508e54e515..00000000000 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/common/include/xa_nnlib_err_chk.h +++ /dev/null @@ -1,84 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2019-2020 Cadence Design Systems, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining - * a copy of this software and associated documentation files (the - * "Software"), to use this Software with Cadence processor cores only and - * not with any other processors and platforms, subject to - * the following conditions: - * - * The above copyright notice and this permission notice shall be included - * in all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - ******************************************************************************/ - -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef __XA_NNLIB_ERR_CHK_H__ -#define __XA_NNLIB_ERR_CHK_H__ - -#ifndef NULL -#define NULL (void *)0 -#endif /* NULL */ - -#ifndef DISABLE_ARG_CHK - -#define XA_NNLIB_ARG_CHK_PTR(_ptr, _err) \ - do { \ - if ((_ptr) == NULL) return (_err); \ - } while (0) - -#define XA_NNLIB_ARG_CHK_ALIGN(_ptr, _align, _err) \ - do { \ - if (((unsigned int)(_ptr) & ((_align)-1)) != 0) return (_err); \ - } while (0) - -#define XA_NNLIB_ARG_CHK_COND(_cond, _err) \ - do { \ - if ((_cond)) return (_err); \ - } while (0) - -#else /* DISABLE_ARG_CHK */ - -#define XA_NNLIB_ARG_CHK_PTR(_ptr, _err) -#define XA_NNLIB_ARG_CHK_ALIGN(_ptr, _align, _err) -#define XA_NNLIB_ARG_CHK_COND(_cond, _err) - -#endif /* DISABLE_ARG_CHK */ - -#define XA_NNLIB_CHK_PTR(_ptr, _err) \ - do { \ - if ((_ptr) == NULL) return (_err); \ - } while (0) - -#define XA_NNLIB_CHK_ALIGN(_ptr, _align, _err) \ - do { \ - if (((unsigned int)(_ptr) & ((_align)-1)) != 0) return (_err); \ - } while (0) - -#define XA_NNLIB_CHK_COND(_cond, _err) \ - do { \ - if ((_cond)) return (_err); \ - } while (0) - -#endif /* __XA_NNLIB_ERR_CHK_H__ */ diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/kernels/activations/hifi_mini/xa_nn_activations_asym8s_asym8s.c b/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/kernels/activations/hifi_mini/xa_nn_activations_asym8s_asym8s.c deleted file mode 100644 index 060b70696e0..00000000000 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/kernels/activations/hifi_mini/xa_nn_activations_asym8s_asym8s.c +++ /dev/null @@ -1,176 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2019-2020 Cadence Design Systems, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining - * a copy of this software and associated documentation files (the - * "Software"), to use this Software with Cadence processor cores only and - * not with any other processors and platforms, subject to - * the following conditions: - * - * The above copyright notice and this permission notice shall be included - * in all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - ******************************************************************************/ - -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "xa_nnlib_common.h" - -#define ALIGNMENT 8 /* 8 bytes alignment */ - -#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) - -#define LIMIT(out, inp, min, max) \ - { \ - out = min; \ - out = AE_MAXP24S(inp, min); \ - out = AE_MINP24S(out, max); \ - } - -#define STORE_8X2_FROM_24X2(out_ptr, val) \ - { \ - int o1, o2; \ - o1 = AE_MOVAP24S_H(val); \ - o2 = AE_MOVAP24S_L(val); \ - *out_ptr++ = (WORD8)o1; \ - *out_ptr++ = (WORD8)o2; \ - } - -/* - * inp: p_vec: 4 byte aligned input pointer - * out: p_out: no alignment needed for output pointer*/ -WORD32 xa_nn_vec_activation_min_max_asym8s_asym8s( - WORD8 *__restrict__ p_out, const WORD8 *__restrict__ p_vec, - int activation_min, int activation_max, WORD32 vec_length) { - int i; - ae_p24x2s x, y, min, max; - - /* NULL pointer checks */ - XA_NNLIB_ARG_CHK_PTR(p_out, -1); - XA_NNLIB_ARG_CHK_PTR(p_vec, -1); - - /* Basic Parameter checks */ - XA_NNLIB_ARG_CHK_COND((vec_length <= 0), -1); - - /* Basic Parameter checks */ - XA_NNLIB_ARG_CHK_COND((activation_max < activation_min), -1); - - WORD8 *p_o = p_out; - WORD8 *p_v = (WORD8 *)p_vec; - - min = AE_SRAIP24(AE_CVTP24A16(activation_min), 8); - max = AE_SRAIP24(AE_CVTP24A16(activation_max), 8); - - int pre_loop_count = 0; - // pre loop, active when input ptr is not 4 byte aligned - pre_loop_count = (int)((unsigned)ALIGN_PTR(p_v, 4) - (unsigned)p_v); - pre_loop_count = (pre_loop_count < vec_length) ? pre_loop_count : vec_length; - - vec_length = vec_length - pre_loop_count; - vec_length = (vec_length < 0) ? 0 : vec_length; - - for (i = 0; i < pre_loop_count; i++) { - int i1; - i1 = ((WORD8)*p_v++); - x = AE_MOVPA24(i1); - LIMIT(y, x, min, max) - i1 = AE_MOVAP24S_H(y); - *p_o++ = (WORD8)i1; - } - - if ((activation_max >= (int)127) && (activation_min <= (int)-128)) { - p_v = p_v - 2; - for (i = 0; i < (vec_length >> 1); i++) { - AE_LP8X2F_IU(x, (WORD8 *)p_v, 2 * sizeof(WORD8)); - y = AE_SRAIP24(x, 16); - - STORE_8X2_FROM_24X2(p_o, y) - } - if (vec_length & 1) { - p_v = p_v + 2; - int i1; - i1 = (WORD8)p_v[0]; - *p_o++ = (WORD8)i1; - } - } else if ((activation_max < (int)127) && (activation_min <= (int)-128)) { - p_v = p_v - 2; - for (i = 0; i < (vec_length >> 1); i++) { - AE_LP8X2F_IU(x, (WORD8 *)p_v, 2 * sizeof(WORD8)); - y = AE_SRAIP24(x, 16); - - y = AE_MINP24S(y, max); - - STORE_8X2_FROM_24X2(p_o, y) - } - if (vec_length & 1) { - p_v = p_v + 2; - int i1; - i1 = (WORD8)p_v[0]; - y = AE_MOVPA24(i1); - - y = AE_MINP24S(y, max); - - i1 = AE_MOVAP24S_H(y); - *p_o++ = (WORD8)i1; - } - } else if ((activation_max >= (int)127) && (activation_min > (int)-128)) { - p_v = p_v - 2; - for (i = 0; i < (vec_length >> 1); i++) { - AE_LP8X2F_IU(x, (WORD8 *)p_v, 2 * sizeof(WORD8)); - y = AE_SRAIP24(x, 16); - - y = AE_MAXP24S(y, min); - - STORE_8X2_FROM_24X2(p_o, y) - } - if (vec_length & 1) { - p_v = p_v + 2; - int i1; - i1 = (WORD8)p_v[0]; - y = AE_MOVPA24(i1); - - y = AE_MAXP24S(y, min); - - i1 = AE_MOVAP24S_H(y); - *p_o++ = (WORD8)i1; - } - } else { - p_v = p_v - 2; - for (i = 0; i < (vec_length >> 1); i++) { - AE_LP8X2F_IU(x, (WORD8 *)p_v, 2 * sizeof(WORD8)); - x = AE_SRAIP24(x, 16); - LIMIT(y, x, min, max) - STORE_8X2_FROM_24X2(p_o, y) - } - if (vec_length & 1) { - p_v = p_v + 2; - int i1; - i1 = (WORD8)p_v[0]; - x = AE_MOVPA24(i1); - LIMIT(y, x, min, max) - i1 = AE_MOVAP24S_H(y); - *p_o++ = (WORD8)i1; - } - } - return 0; -} diff --git 
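For reference, the activation kernel removed above is functionally an elementwise clamp of a signed 8-bit vector to [activation_min, activation_max]; the pre/main/post loop split and the AE_* intrinsics exist only to feed the HiFi Mini 2-way 24-bit datapath. A plain-C sketch of the same behaviour, with an illustrative function name that is not part of the library:

#include <stdint.h>

/* Clamp each signed 8-bit element to [activation_min, activation_max];
   returns -1 on bad arguments and 0 on success, mirroring the kernel above. */
static int vec_activation_min_max_ref(int8_t *out, const int8_t *in,
                                      int activation_min, int activation_max,
                                      int32_t vec_length) {
  if (out == NULL || in == NULL) return -1;
  if (vec_length <= 0 || activation_max < activation_min) return -1;
  for (int32_t i = 0; i < vec_length; ++i) {
    int v = in[i];
    if (v < activation_min) v = activation_min;
    if (v > activation_max) v = activation_max;
    out[i] = (int8_t)v;
  }
  return 0;
}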
a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/kernels/activations/hifi_mini/xa_nn_softmax_asym8_asym8.c b/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/kernels/activations/hifi_mini/xa_nn_softmax_asym8_asym8.c deleted file mode 100644 index 4f7dce839d3..00000000000 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/kernels/activations/hifi_mini/xa_nn_softmax_asym8_asym8.c +++ /dev/null @@ -1,1005 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2019-2020 Cadence Design Systems, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining - * a copy of this software and associated documentation files (the - * "Software"), to use this Software with Cadence processor cores only and - * not with any other processors and platforms, subject to - * the following conditions: - * - * The above copyright notice and this permission notice shall be included - * in all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - ******************************************************************************/ - -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "xa_nnlib_common.h" - -#define ALIGNMENT 8 /* 8 bytes alignment */ -#define ALIGNED_SIZE(x, bytes) (((x) + (bytes - 1)) & (~(bytes - 1))) -#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) - -#ifndef AE_LP8X2F_IU -#define AE_LP8X2F_IU(p_x, p_in, x) \ - AE_LP16F_IU(p_x, (ae_p16s *)p_in, x); \ - ae_p24x2s p_tmp1 = AE_SLLIP24(p_x, 8); \ - ae_p24x2s p_tmp2 = AE_ANDP48(p_x, AE_MOVPA24(0xFFFF0000)); \ - p_x = AE_SELP24_LL(p_tmp2, p_tmp1); - -#endif - -#define NSA64_T(y, x) \ - { \ - ae_q56s q_tmp = *(ae_q56s *)&x; \ - y = AE_NSAQ56S(q_tmp) + 8; \ - } - -#define MULFP32X2RAS_T(result, a, b) \ - { \ - ae_q56s q_a = AE_CVTQ48A32S(a); \ - ae_p24x2s p_b = AE_CVTP24A16X2_HL(b, b); \ - ae_q56s q_out = AE_MULFQ32SP16U_L(q_a, p_b); \ - q_out = AE_SRAIQ56(q_out, 16); \ - AE_MULAFQ32SP16S_H(q_out, q_a, p_b); \ - q_out = AE_ROUNDSQ32ASYM(q_out); \ - *(ae_q32s *)&result = q_out; \ - } - -#define MULFP32X2RS_T(result, a, b) \ - { \ - ae_q56s q_a = AE_CVTQ48A32S(a); \ - ae_p24x2s p_b = AE_CVTP24A16X2_HL(b, b); \ - ae_q56s q_out = AE_MULFQ32SP16U_L(q_a, p_b); \ - q_out = AE_SRAIQ56(q_out, 16); \ - AE_MULAFQ32SP16S_H(q_out, q_a, p_b); \ - q_out = AE_ROUNDSQ32SYM(q_out); \ - *(ae_q32s *)&result = q_out; \ - } -#define ADD32S_T(result, a, b) \ - { \ - ae_q56s q_a = AE_CVTQ48A32S(a); \ - ae_q56s q_b = AE_CVTQ48A32S(b); \ - ae_q56s q_out = AE_ADDSQ56S(q_a, q_b); \ - q_out = AE_SATQ48S(q_out); \ - *(ae_q32s *)&result = q_out; \ - } - -#define SUB32S_T(result, a, b) \ - { \ - ae_q56s q_a = AE_CVTQ48A32S(a); \ - ae_q56s q_b = AE_CVTQ48A32S(b); \ - ae_q56s q_out = AE_SUBSQ56S(q_a, q_b); \ - q_out = AE_SATQ48S(q_out); \ - *(ae_q32s *)&result = q_out; \ - } - -#define SLAI32S_T(result, a, b) \ - { \ - ae_q56s q_a = AE_CVTQ48A32S(a); \ - ae_q56s q_out = AE_SLLIQ56(q_a, b); \ - q_out = AE_SATQ48S(q_out); \ - *(ae_q32s *)&result = q_out; \ - } - -#define SRAA32RS_T(result, a, b) \ - { \ - ae_q56s q_a = AE_CVTQ48A32S(a); \ - ae_q56s q_out = AE_SLAASQ56S(q_a, (-b)); \ - q_out = AE_ROUNDSQ32ASYM(q_out); \ - *(ae_q32s *)&result = q_out; \ - } - -#define SRAI32R_T(result, a, b) \ - { \ - ae_q56s q_a = AE_CVTQ48A32S(a); \ - ae_q56s q_out = AE_SRAIQ56(q_a, b); \ - q_out = AE_ROUNDSQ32ASYM(q_out); \ - *(ae_q32s *)&result = q_out; \ - } - -static const int CONSTANT_TERM = (0x70f5a894); -static const int CONSTANT_1_OVER_3 = (0x2aaaaaab); -static const int CONSTANT_1_OVER_8 = (0x10000000); -static const int ONE_QUATER_Q26 = (0x1000000); // Q6.26 -static const int MASK = (0xffffff); -static const int Q31 = 0x7fffffff; -static const int constant_48_over_17 = 1515870810; -static const int constant_neg_32_over_17 = -1010580540; // Q29 -static const int F2_ONE = 0x20000000; - -static const int constant_neg_32_over_17_Q21 = -3947580; // Q21 -static const int constant_48_over_17_Q21 = 5921370; // Q21 - -static ae_p24x2s GetReciprocal(ae_q56s q_x, int x_integerbits, int *lsh) { - int headroom_plus_one; - ae_p24x2s p_x; - ae_q56s q_tmp; - ae_p24x2s p_half_den; - int i; - - headroom_plus_one = AE_NSAQ56S(q_x) + 8; - headroom_plus_one = headroom_plus_one - 31; - *lsh = x_integerbits - headroom_plus_one; - - q_x = (q_x << (headroom_plus_one + 15)); - p_half_den = AE_ROUNDSP24Q48SYM(q_x); - - q_tmp = AE_CVTQ48A32S(constant_48_over_17); - AE_MULAFP24S_LL(q_tmp, p_half_den, AE_MOVPA24(constant_neg_32_over_17_Q21)); - p_x = AE_ROUNDSP24Q48SYM(q_tmp); - - for (i = 0; i < 3; i++) { - q_tmp = AE_CVTQ48A32S(F2_ONE); - 
AE_MULSFP24S_LL(q_tmp, p_x, p_half_den); - ae_p24x2s p_one_minus_half_denominator_times_x = AE_ROUNDSP24Q48SYM(q_tmp); - - q_tmp = AE_MULFP24S_LL(p_x, p_one_minus_half_denominator_times_x); - ae_p24x2s p_m = AE_ROUNDSP24Q48SYM(q_tmp); - p_m = AE_SLLISP24S(p_m, 2); - p_x = AE_ADDSP24S(p_x, p_m); - } - - p_x = AE_SLLISP24S(p_x, 1); - - return p_x; -} - -static const int MASK_16BITS = (0xffff); -static const int ONE_QUATER_Q18 = (0x10000); // Q18 -static const int CONSTANT_1_OVER_8_Q23 = (0x100000); // Q23 -static const int CONSTANT_1_OVER_3_Q23 = (0x2aaaaa); // Q23 -static const int CONSTANT_TERM_Q23 = (0x70f5a8); // Q23 -static const int Q23 = 0x7fffff; - -#define GEMMLOWP_EXP_BARREL_SHIFTER_OPT_II(p_in_out, exponent, \ - FixedPointMultiplier, p_remainder) \ - { \ - ae_p24x2s p_out; \ - \ - ae_p24x2s p_zero = AE_ZEROP48(); \ - \ - ae_p24x2s p_scale = AE_MOVPA24(1 << (18 + exponent)); \ - ae_p24x2s p_mask = p_remainder & p_scale; \ - \ - ae_p24x2s p_FixedPointMultiplier = AE_MOVPA24(FixedPointMultiplier >> 8); \ - \ - ae_q56s q_tmp1 = AE_MULFP24S_HH(p_in_out, p_FixedPointMultiplier); \ - ae_q56s q_tmp2 = AE_MULFP24S_LL(p_in_out, p_FixedPointMultiplier); \ - ae_p24x2s p_t1 = AE_ROUNDSP24Q48SYM(q_tmp1); \ - ae_p24x2s p_t2 = AE_ROUNDSP24Q48SYM(q_tmp2); \ - p_out = AE_SELP24_LL(p_t1, p_t2); \ - \ - xtbool2 flag_le = AE_LTP24S(p_zero, p_mask); \ - AE_MOVTP24X2(p_in_out, p_out, flag_le); \ - } - -#define EXP_Q26_II(p_exp_y, p_inp_t) \ - { \ - ae_p24x2s p_x1_in, p_x2, p_x3, p_x4, p_x4_by_4, p_y1, p_y2, p_y3, p_y4, \ - p_y5, p_y6, p_y; \ - \ - p_x2 = p_inp_t & AE_MOVPA24(MASK_16BITS); \ - ae_p24x2s p_a_mod_quater_minus_q_1_by_4 = \ - p_x2 - AE_MOVPA24(ONE_QUATER_Q18); \ - ae_p24x2s p_x_in = p_a_mod_quater_minus_q_1_by_4 << 5; \ - ae_p24x2s p_remainder = p_a_mod_quater_minus_q_1_by_4 - p_inp_t; \ - \ - p_x1_in = AE_ADDSP24S(p_x_in, AE_MOVPA24(CONSTANT_1_OVER_8_Q23)); \ - \ - ae_q56s q_tmp1 = AE_MULFP24S_HH(p_x1_in, p_x1_in); \ - ae_q56s q_tmp2 = AE_MULFP24S_LL(p_x1_in, p_x1_in); \ - ae_p24x2s p_t1 = AE_ROUNDSP24Q48SYM(q_tmp1); \ - ae_p24x2s p_t2 = AE_ROUNDSP24Q48SYM(q_tmp2); \ - p_x2 = AE_SELP24_LL(p_t1, p_t2); \ - \ - q_tmp1 = AE_MULFP24S_HH(p_t1, p_x1_in); \ - q_tmp2 = AE_MULFP24S_LL(p_t2, p_x1_in); \ - p_t1 = AE_ROUNDSP24Q48SYM(q_tmp1); \ - p_t2 = AE_ROUNDSP24Q48SYM(q_tmp2); \ - p_x3 = AE_SELP24_LL(p_t1, p_t2); \ - \ - q_tmp1 = AE_MULFP24S_HH(p_x2, p_x2); \ - q_tmp2 = AE_MULFP24S_LL(p_x2, p_x2); \ - p_t1 = AE_ROUNDSP24Q48SYM(q_tmp1); \ - p_t2 = AE_ROUNDSP24Q48SYM(q_tmp2); \ - p_x4 = AE_SELP24_LL(p_t1, p_t2); \ - p_x4_by_4 = p_x4 >> 2; \ - \ - p_y1 = AE_ADDSP24S(p_x4_by_4, p_x3); \ - \ - ae_p24x2s p_const = AE_MOVPA24(CONSTANT_1_OVER_3_Q23); \ - q_tmp1 = AE_MULFP24S_HH(p_y1, p_const); \ - q_tmp2 = AE_MULFP24S_LL(p_y1, p_const); \ - p_t1 = AE_ROUNDSP24Q48SYM(q_tmp1); \ - p_t2 = AE_ROUNDSP24Q48SYM(q_tmp2); \ - p_y2 = AE_SELP24_LL(p_t1, p_t2); \ - \ - p_y3 = AE_ADDSP24S(p_y2, p_x2); \ - p_y4 = p_y3 >> 1; \ - \ - p_y5 = AE_ADDSP24S(p_x1_in, p_y4); /* ADD32S_T(y5, x1_in, y4); */ \ - \ - p_const = AE_MOVPA24(CONSTANT_TERM_Q23); \ - q_tmp1 = AE_MULFP24S_HH(p_y5, p_const); \ - q_tmp2 = AE_MULFP24S_LL(p_y5, p_const); \ - p_t1 = AE_ROUNDSP24Q48SYM(q_tmp1); \ - p_t2 = AE_ROUNDSP24Q48SYM(q_tmp2); \ - p_y6 = AE_SELP24_LL(p_t1, p_t2); \ - p_y = AE_ADDSP24S(p_y6, p_const); \ - \ - { \ - GEMMLOWP_EXP_BARREL_SHIFTER_OPT_II(p_y, -2, 1672461947, p_remainder); \ - GEMMLOWP_EXP_BARREL_SHIFTER_OPT_II(p_y, -1, 1302514674, p_remainder); \ - GEMMLOWP_EXP_BARREL_SHIFTER_OPT_II(p_y, 0, 790015084, p_remainder); \ - 
GEMMLOWP_EXP_BARREL_SHIFTER_OPT_II(p_y, 1, 290630308, p_remainder); \ - GEMMLOWP_EXP_BARREL_SHIFTER_OPT_II(p_y, 2, 39332535, p_remainder); \ - GEMMLOWP_EXP_BARREL_SHIFTER_OPT_II(p_y, 3, 720401, p_remainder); \ - GEMMLOWP_EXP_BARREL_SHIFTER_OPT_II(p_y, 4, 242, p_remainder); \ - } \ - p_exp_y = p_y; \ - p_const = AE_MOVPA24(Q23); \ - xtbool2 flag_eq = AE_EQP24(p_inp_t, AE_ZEROP48()); \ - AE_MOVTP24X2(p_exp_y, p_const, flag_eq); \ - } - -#define GEMMLOWP_EXP_BARREL_SHIFTER_OPT_I(p_in_out, exponent, \ - FixedPointMultiplier, p_remainder) \ - { \ - ae_p24x2s p_out; \ - \ - ae_p24x2s p_zero = AE_ZEROP48(); \ - \ - ae_p24x2s p_scale = AE_MOVPA24(1 << (18 + exponent)); \ - ae_p24x2s p_mask = p_remainder & p_scale; \ - \ - ae_p24x2s p_FixedPointMultiplier = AE_MOVPA24(FixedPointMultiplier >> 8); \ - \ - ae_q56s q_tmp1 = AE_MULFP24S_HH(p_in_out, p_FixedPointMultiplier); \ - p_out = AE_ROUNDSP24Q48SYM(q_tmp1); \ - \ - xtbool2 flag_le = AE_LTP24S(p_zero, p_mask); \ - AE_MOVTP24X2(p_in_out, p_out, flag_le); \ - } - -#define EXP_Q26_I(p_exp_y, p_inp_t) \ - { \ - ae_p24x2s p_x1_in, p_x2, p_x3, p_x4, p_x4_by_4, p_y1, p_y2, p_y3, p_y4, \ - p_y5, p_y6, p_y; \ - \ - p_x2 = p_inp_t & AE_MOVPA24(MASK_16BITS); \ - ae_p24x2s p_a_mod_quater_minus_q_1_by_4 = \ - p_x2 - AE_MOVPA24(ONE_QUATER_Q18); \ - ae_p24x2s p_x_in = p_a_mod_quater_minus_q_1_by_4 << 5; \ - ae_p24x2s p_remainder = p_a_mod_quater_minus_q_1_by_4 - p_inp_t; \ - \ - p_x1_in = AE_ADDSP24S(p_x_in, AE_MOVPA24(CONSTANT_1_OVER_8_Q23)); \ - \ - ae_q56s q_tmp1 = AE_MULFP24S_HH(p_x1_in, p_x1_in); \ - p_x2 = AE_ROUNDSP24Q48SYM(q_tmp1); \ - \ - q_tmp1 = AE_MULFP24S_HH(p_x2, p_x1_in); \ - p_x3 = AE_ROUNDSP24Q48SYM(q_tmp1); \ - \ - q_tmp1 = AE_MULFP24S_HH(p_x2, p_x2); \ - p_x4 = AE_ROUNDSP24Q48SYM(q_tmp1); \ - p_x4_by_4 = p_x4 >> 2; \ - \ - p_y1 = AE_ADDSP24S(p_x4_by_4, p_x3); \ - \ - ae_p24x2s p_const = AE_MOVPA24(CONSTANT_1_OVER_3_Q23); \ - q_tmp1 = AE_MULFP24S_HH(p_y1, p_const); \ - p_y2 = AE_ROUNDSP24Q48SYM(q_tmp1); \ - \ - p_y3 = AE_ADDSP24S(p_y2, p_x2); \ - p_y4 = p_y3 >> 1; \ - \ - p_y5 = AE_ADDSP24S(p_x1_in, p_y4); /* ADD32S_T(y5, x1_in, y4); */ \ - \ - p_const = AE_MOVPA24(CONSTANT_TERM_Q23); \ - q_tmp1 = AE_MULFP24S_HH(p_y5, p_const); \ - p_y6 = AE_ROUNDSP24Q48SYM(q_tmp1); \ - p_y = AE_ADDSP24S(p_y6, p_const); \ - \ - { \ - GEMMLOWP_EXP_BARREL_SHIFTER_OPT_I(p_y, -2, 1672461947, p_remainder); \ - GEMMLOWP_EXP_BARREL_SHIFTER_OPT_I(p_y, -1, 1302514674, p_remainder); \ - GEMMLOWP_EXP_BARREL_SHIFTER_OPT_I(p_y, 0, 790015084, p_remainder); \ - GEMMLOWP_EXP_BARREL_SHIFTER_OPT_I(p_y, 1, 290630308, p_remainder); \ - GEMMLOWP_EXP_BARREL_SHIFTER_OPT_I(p_y, 2, 39332535, p_remainder); \ - GEMMLOWP_EXP_BARREL_SHIFTER_OPT_I(p_y, 3, 720401, p_remainder); \ - GEMMLOWP_EXP_BARREL_SHIFTER_OPT_I(p_y, 4, 242, p_remainder); \ - } \ - p_exp_y = p_y; \ - p_const = AE_MOVPA24(Q23); \ - xtbool2 flag_eq = AE_EQP24(p_inp_t, AE_ZEROP48()); \ - AE_MOVTP24X2(p_exp_y, p_const, flag_eq); \ - } - -WORD32 xa_nn_vec_softmax_asym8u_8(UWORD8 *__restrict__ pOut, - const UWORD8 *__restrict__ pVec, - WORD32 diffmin, WORD32 input_beta_left_shift, - WORD32 input_beta_multiplier, - WORD32 vec_length, pVOID pScratch) { - /* NULL pointer checks */ - XA_NNLIB_ARG_CHK_PTR(pOut, -1); - XA_NNLIB_ARG_CHK_PTR(pVec, -1); - XA_NNLIB_ARG_CHK_PTR(pScratch, -1); - /* Pointer alignment checks */ - /* No alignment (1-byte) needed for any pointer */ - /* Basic Parameter checks */ - XA_NNLIB_ARG_CHK_COND((vec_length <= 0), -1); - XA_NNLIB_ARG_CHK_COND( - ((input_beta_left_shift < -31) || (input_beta_left_shift > 
31)), -1); - XA_NNLIB_ARG_CHK_COND((input_beta_multiplier < 0), -1); - - int i; - int shift_bits_reciprocal; - UWORD8 *p_in; - WORD32 *__restrict pExp = (WORD32 *)ALIGN_PTR(pScratch, ALIGNMENT); - ae_p24f *__restrict pTmpScratch = (ae_p24f *)pExp; - int max; - ae_p24x2s p_x; - ae_p24x2s p_max = AE_MOVPA24(0xFF800000); - ae_p24x2s p_recip_sum_exp; - int pre_loop_count; - int main_loop_count; - int post_loop_count; - - if (vec_length > 1) { - pre_loop_count = (int)pVec & 0x1; - main_loop_count = vec_length - pre_loop_count; - post_loop_count = (main_loop_count & 1); - main_loop_count = main_loop_count >> 1; - } else { - pre_loop_count = 0; - main_loop_count = 0; - post_loop_count = vec_length; - } - - /* Calculating Max */ - { - p_in = (UWORD8 *)pVec; - - if (pre_loop_count) { - p_x = AE_MOVPA24(*p_in++); - p_max = AE_MAXP24S(p_max, p_x); - } - - p_in -= 2; - for (i = 0; i < main_loop_count; i++) { - AE_LP8X2F_IU(p_x, p_in, 2 * sizeof(WORD8)); - p_x = AE_SRLIP24(p_x, 16); - p_max = AE_MAXP24S(p_max, p_x); - } - - if (post_loop_count) { - p_in += 2; - p_x = AE_MOVPA24(*p_in); - p_max = AE_MAXP24S(p_max, p_x); - } - p_max = AE_MAXP24S(p_max, AE_SELP24_LH(p_max, p_max)); - max = AE_MOVAP24S_L(p_max); - } - - /* Calculate exponents */ - { - ae_q56s q_sum_exp = AE_ZEROQ56(); - ae_p24x2s p_rem_x, p_y, p_exp_y; - ae_p24x2s p_zero = AE_ZEROP48(); - ae_p24x2s p_input_beta_multiplier = - AE_MOVPA24((input_beta_multiplier >> 8)); - ae_p24x2s p_diffmin = AE_MOVPA24(diffmin); - int input_beta_left_shift_for_24bit = input_beta_left_shift - 8; - - p_in = (UWORD8 *)pVec; - WUR_AE_SAR(input_beta_left_shift_for_24bit); - - if (pre_loop_count) { - p_x = AE_MOVPA24(*p_in++); - p_rem_x = p_x - p_max; - p_y = AE_SLLSSP24S(p_rem_x); - - ae_q56s q_dequantized_y1 = AE_MULFP24S_LL(p_y, p_input_beta_multiplier); - - ae_p24x2s p_dequantized_y1 = AE_ROUNDSP24Q48ASYM(q_dequantized_y1); - - EXP_Q26_I(p_exp_y, p_dequantized_y1) - - xtbool2 flag_cmp = AE_LTP24S(p_rem_x, p_diffmin); - AE_MOVTP24X2(p_exp_y, p_zero, flag_cmp); - - *pTmpScratch++ = p_exp_y; - - p_exp_y = p_exp_y >> 4; - - AE_MULAP24S_LL(q_sum_exp, p_exp_y, AE_MOVPA24(1)); - } - - p_in -= 2; - for (i = 0; i < main_loop_count; i++) { - AE_LP8X2F_IU(p_x, p_in, 2 * sizeof(WORD8)); - p_x = AE_SRLIP24(p_x, 16); - p_rem_x = p_x - p_max; - p_y = AE_SLLSSP24S(p_rem_x); - - ae_q56s q_dequantized_y1 = AE_MULFP24S_HH(p_y, p_input_beta_multiplier); - ae_q56s q_dequantized_y2 = AE_MULFP24S_LL(p_y, p_input_beta_multiplier); - - ae_p24x2s p_dequantized_y1 = AE_ROUNDSP24Q48ASYM(q_dequantized_y1); - ae_p24x2s p_dequantized_y2 = AE_ROUNDSP24Q48ASYM(q_dequantized_y2); - - ae_p24x2s p_dequantized = - AE_SELP24_LL(p_dequantized_y1, p_dequantized_y2); - - EXP_Q26_II(p_exp_y, p_dequantized) - - xtbool2 flag_cmp = AE_LTP24S(p_rem_x, p_diffmin); - AE_MOVTP24X2(p_exp_y, p_zero, flag_cmp); - - *pTmpScratch++ = AE_SELP24_HH(p_exp_y, p_exp_y); - *pTmpScratch++ = p_exp_y; /* store lower element */ - - p_exp_y = p_exp_y >> 4; - - AE_MULAAP24S_HH_LL(q_sum_exp, p_exp_y, AE_MOVPA24(1)); - } - if (post_loop_count) { - p_in += 2; - - p_x = AE_MOVPA24(*p_in); - p_rem_x = p_x - p_max; - p_y = AE_SLLSSP24S(p_rem_x); - - ae_q56s q_dequantized_y1 = AE_MULFP24S_LL(p_y, p_input_beta_multiplier); - - ae_p24x2s p_dequantized_y1 = AE_ROUNDSP24Q48ASYM(q_dequantized_y1); - - EXP_Q26_I(p_exp_y, p_dequantized_y1) - - xtbool2 flag_cmp = AE_LTP24S(p_rem_x, p_diffmin); - AE_MOVTP24X2(p_exp_y, p_zero, flag_cmp); - - *pTmpScratch = p_exp_y; - - p_exp_y = p_exp_y >> 4; - - AE_MULAP24S_LL(q_sum_exp, p_exp_y, 
AE_MOVPA24(1)); - } - p_recip_sum_exp = GetReciprocal(q_sum_exp, 12, &shift_bits_reciprocal); - } - - /* Calculate output */ - { - ae_p24x2s p_exp; - - int shift_val = -(shift_bits_reciprocal + 31 - 8 - 8); - - ae_p24x2s p_min = AE_ZEROP48(); - ae_p24x2s p_max = AE_MOVPA24(255); - - for (i = 0; i<vec_length >> 1; i++) { - int out; - - p_exp = *(ae_p24x2f *)&pExp[2 * i]; - - ae_q56s q_tmp1 = AE_MULFP24S_HH(p_exp, p_recip_sum_exp); - ae_q56s q_tmp2 = AE_MULFP24S_LL(p_exp, p_recip_sum_exp); - - q_tmp1 = AE_SLAASQ56S(q_tmp1, shift_val); - q_tmp2 = AE_SLAASQ56S(q_tmp2, shift_val); - - ae_p24x2s p_out1 = AE_ROUNDSP24Q48ASYM(q_tmp1); - ae_p24x2s p_out2 = AE_ROUNDSP24Q48ASYM(q_tmp2); - - ae_p24x2s p_out = AE_SELP24_LL(p_out1, p_out2); - - p_out = AE_MAXP24S(p_out, p_min); - p_out = AE_MINP24S(p_out, p_max); - - out = AE_MOVAP24S_H(p_out); - *pOut++ = (UWORD8)out; - - out = AE_MOVAP24S_L(p_out); - *pOut++ = (UWORD8)out; - } - - if (vec_length & 0x1) { - int out; - - p_exp = *(ae_p24f *)&pExp[vec_length - 1]; - - ae_q56s q_tmp1 = AE_MULFP24S_LL(p_exp, p_recip_sum_exp); - - q_tmp1 = AE_SLAASQ56S(q_tmp1, shift_val); - - ae_p24x2s p_out = AE_ROUNDSP24Q48ASYM(q_tmp1); - - p_out = AE_MAXP24S(p_out, p_min); - p_out = AE_MINP24S(p_out, p_max); - - out = AE_MOVAP24S_L(p_out); - *pOut++ = (UWORD8)out; - } - } - - return 0; -} - -WORD32 xa_nn_vec_softmax_asym8s_8(WORD8 *__restrict__ pOut, - const WORD8 *__restrict__ pVec, - WORD32 diffmin, WORD32 input_beta_left_shift, - WORD32 input_beta_multiplier, - WORD32 vec_length, pVOID pScratch) { - /* NULL pointer checks */ - XA_NNLIB_ARG_CHK_PTR(pOut, -1); - XA_NNLIB_ARG_CHK_PTR(pVec, -1); - XA_NNLIB_ARG_CHK_PTR(pScratch, -1); - /* Pointer alignment checks */ - /* No alignment (1-byte) needed for any pointer */ - /* Basic Parameter checks */ - XA_NNLIB_ARG_CHK_COND((vec_length <= 0), -1); - XA_NNLIB_ARG_CHK_COND( - ((input_beta_left_shift < -31) || (input_beta_left_shift > 31)), -1); - XA_NNLIB_ARG_CHK_COND((input_beta_multiplier < 0), -1); - - int i; - int shift_bits_reciprocal; - WORD8 *p_in; - WORD32 *__restrict pExp = (WORD32 *)ALIGN_PTR(pScratch, ALIGNMENT); - ae_p24x2s p_recip_sum_exp; - ae_p24x2s p_x; - ae_p24x2s p_max = AE_MOVPA24(0xFF800000); - - int pre_loop_count; - int main_loop_count; - int post_loop_count; - - if (vec_length > 1) { - pre_loop_count = (int)pVec & 0x1; - main_loop_count = vec_length - pre_loop_count; - post_loop_count = (main_loop_count & 1); - main_loop_count = main_loop_count >> 1; - } else { - pre_loop_count = 0; - main_loop_count = 0; - post_loop_count = vec_length; - } - - /* Calculating Max */ - { - p_in = (WORD8 *)pVec; - - if (pre_loop_count) { - p_x = AE_MOVPA24(*p_in++); - p_max = AE_MAXP24S(p_max, p_x); - } - - p_in -= 2; - for (i = 0; i < main_loop_count; i++) { - AE_LP8X2F_IU(p_x, p_in, 2 * sizeof(WORD8)); - p_max = AE_MAXP24S(p_max, p_x); - } - p_max = AE_SRAIP24(p_max, 16); - - if (post_loop_count) { - p_in += 2; - p_x = AE_MOVPA24(*p_in); - p_max = AE_MAXP24S(p_max, p_x); - } - p_max = AE_MAXP24S(p_max, AE_SELP24_LH(p_max, p_max)); - } - - /* Calculate exponents */ - { - ae_q56s q_sum_exp = AE_ZEROQ56(); - ae_p24x2s p_rem_x, p_y, p_exp_y; - ae_p24x2s p_zero = AE_ZEROP48(); - ae_p24x2s p_input_beta_multiplier = - AE_MOVPA24((input_beta_multiplier >> 8)); - ae_p24x2s p_diffmin = AE_MOVPA24(diffmin); - int input_beta_left_shift_for_24bit = input_beta_left_shift - 8; - - p_in = (WORD8 *)pVec; - WUR_AE_SAR(input_beta_left_shift_for_24bit); - - if (pre_loop_count) { - p_x = AE_MOVPA24(*p_in++); - p_rem_x = p_x - p_max; - p_y = 
AE_SLLSSP24S(p_rem_x); - - ae_q56s q_dequantized_y1 = AE_MULFP24S_LL(p_y, p_input_beta_multiplier); - - ae_p24x2s p_dequantized_y1 = AE_ROUNDSP24Q48ASYM(q_dequantized_y1); - - EXP_Q26_I(p_exp_y, p_dequantized_y1) - - xtbool2 flag_cmp = AE_LTP24S(p_rem_x, p_diffmin); - AE_MOVTP24X2(p_exp_y, p_zero, flag_cmp); - - *(ae_p24f *)&pExp[0] = p_exp_y; - - p_exp_y = p_exp_y >> 4; - - AE_MULAP24S_LL(q_sum_exp, p_exp_y, AE_MOVPA24(1)); - } - - p_in -= 2; - for (i = 0; i < main_loop_count; i++) { - AE_LP8X2F_IU(p_x, p_in, 2 * sizeof(WORD8)); - p_x = AE_SRAIP24(p_x, 16); - p_rem_x = p_x - p_max; - p_y = AE_SLLSSP24S(p_rem_x); - - ae_q56s q_dequantized_y1 = AE_MULFP24S_HH(p_y, p_input_beta_multiplier); - ae_q56s q_dequantized_y2 = AE_MULFP24S_LL(p_y, p_input_beta_multiplier); - - ae_p24x2s p_dequantized_y1 = AE_ROUNDSP24Q48ASYM(q_dequantized_y1); - ae_p24x2s p_dequantized_y2 = AE_ROUNDSP24Q48ASYM(q_dequantized_y2); - - ae_p24x2s p_dequantized = - AE_SELP24_LL(p_dequantized_y1, p_dequantized_y2); - - EXP_Q26_II(p_exp_y, p_dequantized) - - xtbool2 flag_cmp = AE_LTP24S(p_rem_x, p_diffmin); - AE_MOVTP24X2(p_exp_y, p_zero, flag_cmp); - - //*(ae_p24x2f *)&pExp[pre_loop_count + 2*i] = p_exp_y; - *(ae_p24f *)&pExp[pre_loop_count + 2 * i] = - AE_SELP24_HH(p_exp_y, p_exp_y); - *(ae_p24f *)&pExp[pre_loop_count + 2 * i + 1] = - AE_SELP24_LL(p_exp_y, p_exp_y); - //*(ae_p24f *)&pExp[0] = p_exp_y; - - p_exp_y = p_exp_y >> 4; - - AE_MULAAP24S_HH_LL(q_sum_exp, p_exp_y, AE_MOVPA24(1)); - } - - if (post_loop_count) { - p_in += 2; - - p_x = AE_MOVPA24(*p_in); - p_rem_x = p_x - p_max; - p_y = AE_SLLSSP24S(p_rem_x); - - ae_q56s q_dequantized_y1 = AE_MULFP24S_LL(p_y, p_input_beta_multiplier); - - ae_p24x2s p_dequantized_y1 = AE_ROUNDSP24Q48ASYM(q_dequantized_y1); - - EXP_Q26_I(p_exp_y, p_dequantized_y1) - - xtbool2 flag_cmp = AE_LTP24S(p_rem_x, p_diffmin); - AE_MOVTP24X2(p_exp_y, p_zero, flag_cmp); - - *(ae_p24f *)&pExp[vec_length - 1] = p_exp_y; - - p_exp_y = p_exp_y >> 4; - - AE_MULAP24S_LL(q_sum_exp, p_exp_y, AE_MOVPA24(1)); - } - - p_recip_sum_exp = GetReciprocal(q_sum_exp, 12, &shift_bits_reciprocal); - } - - /* Calculate output */ - pExp = (WORD32 *)ALIGN_PTR(pScratch, ALIGNMENT); - { - ae_p24x2s p_exp; - - int shift_val = -(shift_bits_reciprocal + 31 - 8 - 8); - - ae_p24x2s p_min = AE_MOVPA24(-128); - ae_p24x2s p_max = AE_MOVPA24(127); - - for (i = 0; i<vec_length >> 1; i++) { - int out; - - p_exp = *(ae_p24x2f *)&pExp[2 * i]; - - ae_q56s q_tmp1 = AE_MULFP24S_HH(p_exp, p_recip_sum_exp); - ae_q56s q_tmp2 = AE_MULFP24S_LL(p_exp, p_recip_sum_exp); - - q_tmp1 = AE_SLAASQ56S(q_tmp1, shift_val); - q_tmp2 = AE_SLAASQ56S(q_tmp2, shift_val); - - ae_p24x2s p_out1 = AE_ROUNDSP24Q48ASYM(q_tmp1); - ae_p24x2s p_out2 = AE_ROUNDSP24Q48ASYM(q_tmp2); - - ae_p24x2s p_out = AE_SELP24_LL(p_out1, p_out2); - - p_out = AE_SUBSP24S(p_out, AE_MOVPA24(128)); - p_out = AE_MAXP24S(p_out, p_min); - p_out = AE_MINP24S(p_out, p_max); - - out = AE_MOVAP24S_H(p_out); - *pOut++ = (WORD8)out; - - out = AE_MOVAP24S_L(p_out); - *pOut++ = (WORD8)out; - } - - if (vec_length & 0x1) { - int out; - - p_exp = *(ae_p24f *)&pExp[vec_length - 1]; - - ae_q56s q_tmp1 = AE_MULFP24S_LL(p_exp, p_recip_sum_exp); - - q_tmp1 = AE_SLAASQ56S(q_tmp1, shift_val); - - ae_p24x2s p_out = AE_ROUNDSP24Q48ASYM(q_tmp1); - - p_out = AE_SUBSP24S(p_out, AE_MOVPA24(128)); - p_out = AE_MAXP24S(p_out, p_min); - p_out = AE_MINP24S(p_out, p_max); - - out = AE_MOVAP24S_L(p_out); - *pOut++ = (WORD8)out; - } - } - - return 0; -} - -WORD32 xa_nn_vec_softmax_asym8s_16(WORD16 *__restrict__ pOut, - 
const WORD8 *__restrict__ pVec, - WORD32 diffmin, WORD32 input_beta_left_shift, - WORD32 input_beta_multiplier, - WORD32 vec_length, pVOID pScratch) { - /* NULL pointer checks */ - XA_NNLIB_ARG_CHK_PTR(pOut, -1); - XA_NNLIB_ARG_CHK_PTR(pVec, -1); - XA_NNLIB_ARG_CHK_PTR(pScratch, -1); - /* Pointer alignment checks */ - /* No alignment (1-byte) needed for any pointer */ - /* Basic Parameter checks */ - XA_NNLIB_ARG_CHK_COND((vec_length <= 0), -1); - XA_NNLIB_ARG_CHK_COND( - ((input_beta_left_shift < -31) || (input_beta_left_shift > 31)), -1); - XA_NNLIB_ARG_CHK_COND((input_beta_multiplier < 0), -1); - - int i; - int shift_bits_reciprocal; - WORD8 *p_in; - WORD32 *__restrict pExp = (WORD32 *)ALIGN_PTR(pScratch, ALIGNMENT); - ae_p24x2s p_recip_sum_exp; - ae_p24x2s p_x; - ae_p24x2s p_max = AE_MOVPA24(0xFF800000); - - int pre_loop_count; - int main_loop_count; - int post_loop_count; - - if (vec_length > 1) { - pre_loop_count = (int)pVec & 0x1; - main_loop_count = vec_length - pre_loop_count; - post_loop_count = (main_loop_count & 1); - main_loop_count = main_loop_count >> 1; - } else { - pre_loop_count = 0; - main_loop_count = 0; - post_loop_count = vec_length; - } - - /* Calculating Max */ - { - p_in = (WORD8 *)pVec; - - if (pre_loop_count) { - p_x = AE_MOVPA24(*p_in++); - p_max = AE_MAXP24S(p_max, p_x); - } - - p_in -= 2; - for (i = 0; i < main_loop_count; i++) { - AE_LP8X2F_IU(p_x, p_in, 2 * sizeof(WORD8)); - p_max = AE_MAXP24S(p_max, p_x); - } - p_max = AE_SRAIP24(p_max, 16); - - if (post_loop_count) { - p_in += 2; - p_x = AE_MOVPA24(*p_in); - p_max = AE_MAXP24S(p_max, p_x); - } - p_max = AE_MAXP24S(p_max, AE_SELP24_LH(p_max, p_max)); - } - - /* Calculate exponents */ - { - ae_q56s q_sum_exp = AE_ZEROQ56(); - ae_p24x2s p_rem_x, p_y, p_exp_y; - ae_p24x2s p_zero = AE_ZEROP48(); - ae_p24x2s p_input_beta_multiplier = - AE_MOVPA24((input_beta_multiplier >> 8)); - ae_p24x2s p_diffmin = AE_MOVPA24(diffmin); - int input_beta_left_shift_for_24bit = input_beta_left_shift - 8; - - p_in = (WORD8 *)pVec; - WUR_AE_SAR(input_beta_left_shift_for_24bit); - - if (pre_loop_count) { - p_x = AE_MOVPA24(*p_in++); - p_rem_x = p_x - p_max; - p_y = AE_SLLSSP24S(p_rem_x); - - ae_q56s q_dequantized_y1 = AE_MULFP24S_LL(p_y, p_input_beta_multiplier); - - ae_p24x2s p_dequantized_y1 = AE_ROUNDSP24Q48ASYM(q_dequantized_y1); - - EXP_Q26_I(p_exp_y, p_dequantized_y1) - - xtbool2 flag_cmp = AE_LTP24S(p_rem_x, p_diffmin); - AE_MOVTP24X2(p_exp_y, p_zero, flag_cmp); - - *(ae_p24f *)&pExp[0] = p_exp_y; - - p_exp_y = p_exp_y >> 4; - - AE_MULAP24S_LL(q_sum_exp, p_exp_y, AE_MOVPA24(1)); - } - - p_in -= 2; - for (i = 0; i < main_loop_count; i++) { - AE_LP8X2F_IU(p_x, p_in, 2 * sizeof(WORD8)); - p_x = AE_SRAIP24(p_x, 16); - p_rem_x = p_x - p_max; - p_y = AE_SLLSSP24S(p_rem_x); - - ae_q56s q_dequantized_y1 = AE_MULFP24S_HH(p_y, p_input_beta_multiplier); - ae_q56s q_dequantized_y2 = AE_MULFP24S_LL(p_y, p_input_beta_multiplier); - - ae_p24x2s p_dequantized_y1 = AE_ROUNDSP24Q48ASYM(q_dequantized_y1); - ae_p24x2s p_dequantized_y2 = AE_ROUNDSP24Q48ASYM(q_dequantized_y2); - - ae_p24x2s p_dequantized = - AE_SELP24_LL(p_dequantized_y1, p_dequantized_y2); - - EXP_Q26_II(p_exp_y, p_dequantized) - - xtbool2 flag_cmp = AE_LTP24S(p_rem_x, p_diffmin); - AE_MOVTP24X2(p_exp_y, p_zero, flag_cmp); - - *(ae_p24f *)&pExp[pre_loop_count + 2 * i] = - AE_SELP24_HH(p_exp_y, p_exp_y); - *(ae_p24f *)&pExp[pre_loop_count + 2 * i + 1] = - AE_SELP24_LL(p_exp_y, p_exp_y); - - p_exp_y = p_exp_y >> 4; - - AE_MULAAP24S_HH_LL(q_sum_exp, p_exp_y, AE_MOVPA24(1)); - } - - 
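/* For reference, a plain-C sketch of the fixed-point pipeline shared by the
 * three softmax variants in this kernel. Names are illustrative: beta_mul and
 * beta_shift stand for input_beta_multiplier and input_beta_left_shift, while
 * exp_q26() and recip_q() stand for the EXP_Q26_* macros and GetReciprocal()
 * defined earlier in the file.
 *
 *   max_v = max(in[0..n-1]);
 *   for (i = 0; i < n; i++) {
 *     diff  = in[i] - max_v;                          // always <= 0
 *     e[i]  = diff < diffmin ? 0
 *                            : exp_q26((diff << beta_shift) * beta_mul);
 *     sum  += e[i] >> 4;                              // headroom for the sum
 *   }
 *   recip = recip_q(sum, &shift);                     // Newton-Raphson 1/sum
 *   for (i = 0; i < n; i++)
 *     out[i] = clamp(round((e[i] * recip) >> shift) - out_offset, qmin, qmax);
 *
 * where out_offset and [qmin, qmax] depend on the variant (0 and [0, 255] for
 * asym8u_8, 128 and [-128, 127] for asym8s_8, 32768 and [-32768, 32767] for
 * asym8s_16).
 */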
if (post_loop_count) { - p_in += 2; - - p_x = AE_MOVPA24(*p_in); - p_rem_x = p_x - p_max; - p_y = AE_SLLSSP24S(p_rem_x); - - ae_q56s q_dequantized_y1 = AE_MULFP24S_LL(p_y, p_input_beta_multiplier); - - ae_p24x2s p_dequantized_y1 = AE_ROUNDSP24Q48ASYM(q_dequantized_y1); - - EXP_Q26_I(p_exp_y, p_dequantized_y1) - - xtbool2 flag_cmp = AE_LTP24S(p_rem_x, p_diffmin); - AE_MOVTP24X2(p_exp_y, p_zero, flag_cmp); - - *(ae_p24f *)&pExp[vec_length - 1] = p_exp_y; - - p_exp_y = p_exp_y >> 4; - - AE_MULAP24S_LL(q_sum_exp, p_exp_y, AE_MOVPA24(1)); - } - - p_recip_sum_exp = GetReciprocal(q_sum_exp, 12, &shift_bits_reciprocal); - } - - /* Calculate output */ - pExp = (WORD32 *)ALIGN_PTR(pScratch, ALIGNMENT); - { - ae_p24x2s p_exp; - - int shift_val = -(shift_bits_reciprocal + 31 - 8 - 16); - - ae_p24x2s p_min = AE_MOVPA24(-32768); - ae_p24x2s p_max = AE_MOVPA24(32767); - - for (i = 0; i<vec_length >> 1; i++) { - int out; - - p_exp = *(ae_p24x2f *)&pExp[2 * i]; - - ae_q56s q_tmp1 = AE_MULFP24S_HH(p_exp, p_recip_sum_exp); - ae_q56s q_tmp2 = AE_MULFP24S_LL(p_exp, p_recip_sum_exp); - - q_tmp1 = AE_SLAASQ56S(q_tmp1, shift_val); - q_tmp2 = AE_SLAASQ56S(q_tmp2, shift_val); - - ae_p24x2s p_out1 = AE_ROUNDSP24Q48ASYM(q_tmp1); - ae_p24x2s p_out2 = AE_ROUNDSP24Q48ASYM(q_tmp2); - - ae_p24x2s p_out = AE_SELP24_LL(p_out1, p_out2); - - p_out = AE_SUBSP24S(p_out, AE_MOVPA24(32768)); - p_out = AE_MAXP24S(p_out, p_min); - p_out = AE_MINP24S(p_out, p_max); - - out = AE_MOVAP24S_H(p_out); - *pOut++ = (WORD16)out; - - out = AE_MOVAP24S_L(p_out); - *pOut++ = (WORD16)out; - } - - if (vec_length & 0x1) { - int out; - - p_exp = *(ae_p24f *)&pExp[vec_length - 1]; - - ae_q56s q_tmp1 = AE_MULFP24S_LL(p_exp, p_recip_sum_exp); - - q_tmp1 = AE_SLAASQ56S(q_tmp1, shift_val); - - ae_p24x2s p_out = AE_ROUNDSP24Q48ASYM(q_tmp1); - - p_out = AE_SUBSP24S(p_out, AE_MOVPA24(32768)); - p_out = AE_MAXP24S(p_out, p_min); - p_out = AE_MINP24S(p_out, p_max); - - out = AE_MOVAP24S_L(p_out); - *pOut++ = (WORD16)out; - } - } - - return 0; -} - -int xa_nn_get_softmax_scratch_size(int inp_precision, int out_precision, - int length) { - int size_of_one_elm_in_bytes, total_bytes; - (void)out_precision; - - /* This function returns scratch size required by softmax implementation in - bytes scratch memory is needed to save exponents of inputs computed in the - function, every exponent is computed as 32 bit (4 bytes) number currently*/ - switch (inp_precision) { - case PREC_ASYM8U: - size_of_one_elm_in_bytes = 4; - break; - case PREC_SYM8S: - size_of_one_elm_in_bytes = 4; - break; - default: - size_of_one_elm_in_bytes = 4; - break; - } - - total_bytes = size_of_one_elm_in_bytes * length; - total_bytes = ALIGNED_SIZE(total_bytes, ALIGNMENT); - - return total_bytes; -} diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/kernels/basic/hifi_mini/xa_nn_dot_prod_16x16.c b/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/kernels/basic/hifi_mini/xa_nn_dot_prod_16x16.c deleted file mode 100644 index 80697ca7068..00000000000 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/kernels/basic/hifi_mini/xa_nn_dot_prod_16x16.c +++ /dev/null @@ -1,175 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2019-2020 Cadence Design Systems, Inc. 
- * - * Permission is hereby granted, free of charge, to any person obtaining - * a copy of this software and associated documentation files (the - * "Software"), to use this Software with Cadence processor cores only and - * not with any other processors and platforms, subject to - * the following conditions: - * - * The above copyright notice and this permission notice shall be included - * in all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - ******************************************************************************/ - -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xa_nnlib_common.h" -#include "xa_nnlib_common_macros.h" - -/*----------------------------Main function---------------------------------*/ -WORD32 xa_nn_dot_prod_16x16_asym8s( - WORD8 *__restrict__ p_out, /* pointer to output */ - const WORD16 *__restrict__ p_inp1_start, /* pointer to input1 */ - const WORD16 *__restrict__ p_inp2_start, /* pointer to input2 */ - const WORD32 *bias_ptr, WORD32 vec_length, WORD32 out_multiplier, - WORD32 out_shift, WORD32 out_zero_bias, WORD32 vec_count) { - /* NULL pointer checks */ - XA_NNLIB_ARG_CHK_PTR(p_out, -1); - XA_NNLIB_ARG_CHK_PTR(p_inp1_start, -1); - XA_NNLIB_ARG_CHK_PTR(p_inp2_start, -1); - /* Pointer alignment checks */ - XA_NNLIB_ARG_CHK_ALIGN(p_inp1_start, sizeof(WORD16), -1); - XA_NNLIB_ARG_CHK_ALIGN(p_inp2_start, sizeof(WORD16), -1); - /* Basic Parameter checks */ - XA_NNLIB_ARG_CHK_COND((vec_length <= 0), -1); - XA_NNLIB_ARG_CHK_COND((out_shift < -31 || out_shift > 31), -1); - XA_NNLIB_ARG_CHK_COND((out_zero_bias < -128 || out_zero_bias > 127), -1); - int left_shift, right_shift; - int loopcnt; - const WORD32 bias_buffer[2] = {0, 0}; - const WORD32 *p_bias_load; - WORD32 bias_address_increment = sizeof(WORD32); - - if (bias_ptr == NULL) { - p_bias_load = bias_buffer - 1; - bias_address_increment = 0; - } else { - p_bias_load = bias_ptr - 1; - } - - left_shift = out_shift < 0 ? 0 : out_shift; - right_shift = out_shift > 0 ? 
0 : -out_shift; - /* inp1 4-bytes aligned, inp2 4-bytes aligned and vec_length is multple of 2 - */ - if (((((unsigned)p_inp1_start) & 0x3) == 0) && - ((((unsigned)p_inp2_start) & 0x3) == 0) && ((vec_length & 0x1) == 0)) { - const ae_p16x2s *pt_inp1, *pt_inp2; - pt_inp1 = (const ae_p16x2s *)&p_inp1_start[-2]; - pt_inp2 = (const ae_p16x2s *)&p_inp2_start[-2]; - - ae_q56s output_int8_max_56 = AE_CVTQ48A32S(127); - ae_q56s output_int8_min_56 = AE_CVTQ48A32S(-128); - for (loopcnt = 0; loopcnt < vec_count; loopcnt++) { - ae_p24x2s dp_inp1, dp_inp2; - ae_q32s dq_out32; - ae_q56s dq_out; - int i; - - AE_LQ32F_XU(dq_out, (ae_q32s *)p_bias_load, bias_address_increment); - - for (i = 0; i < (vec_length >> 1); i++) { - AE_LP16X2F_IU(dp_inp1, pt_inp1, 4); - AE_LP16X2F_IU(dp_inp2, pt_inp2, 4); - AE_MULAAP24S_HH_LL(dq_out, dp_inp1, dp_inp2); - } - - dq_out32 = AE_SATQ48S(dq_out); - MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32), - out_multiplier, left_shift, right_shift); - dq_out = AE_ADDSQ56S(dq_out, AE_CVTQ48A32S(out_zero_bias)); - - dq_out = AE_MAXQ56S(dq_out, output_int8_min_56); - dq_out = AE_MINQ56S(dq_out, output_int8_max_56); - *p_out++ = (WORD8)AE_TRUNCA32Q48(dq_out); - } - } else { -#ifndef DISABLE_NNLIB_UNALIGNED_SUPPORT - for (loopcnt = 0; loopcnt < vec_count; loopcnt++) { - ae_p24x2s dp_inp1, dp_inp2; - ae_q32s dq_out32; - ae_q56s dq_out; - int i; - const WORD16 *p_inp1 = (WORD16 *)&p_inp1_start[loopcnt * vec_length]; - const WORD16 *p_inp2 = (WORD16 *)&p_inp2_start[loopcnt * vec_length]; - - AE_LQ32F_XU(dq_out, (ae_q32s *)p_bias_load, bias_address_increment); - - if (((((unsigned)p_inp1) & 3) != 0 && (((unsigned)p_inp2) & 3) != 0) || - ((((unsigned)p_inp1) & 3) == 0 && (((unsigned)p_inp2) & 3) == 0)) { - int pre_loop_count = ((int)(((unsigned)p_inp1) & 3)) >> 1; - if (pre_loop_count != 0) { - dp_inp1 = AE_CVTP24A16X2_LL(*p_inp1++, *p_inp2++); - AE_MULAP24S_HL(dq_out, dp_inp1, dp_inp1); - } - const ae_p16x2s *pt_inp1, *pt_inp2; - pt_inp1 = (const ae_p16x2s *)(p_inp1 - 2); - pt_inp2 = (const ae_p16x2s *)(p_inp2 - 2); - for (i = 0; i < (vec_length - pre_loop_count - 1); i += 2) { - AE_LP16X2F_IU(dp_inp1, pt_inp1, 4); - AE_LP16X2F_IU(dp_inp2, pt_inp2, 4); - AE_MULAAP24S_HH_LL(dq_out, dp_inp1, dp_inp2); - } - if ((vec_length - pre_loop_count) & 1) { - dp_inp1 = AE_CVTP24A16X2_LL(p_inp1[i], p_inp2[i]); - AE_MULAP24S_HL(dq_out, dp_inp1, dp_inp1); - } - } else { - /* One of the pointers in not aligned to 4 bytes, if it is p_inp1, swap - * them */ - if ((((unsigned)p_inp1) & 3) != 0) { - const WORD16 *p_tmp; - p_tmp = p_inp1; - p_inp1 = p_inp2; - p_inp2 = p_tmp; - } - const ae_p16x2s *pt_inp1 = (const ae_p16x2s *)(p_inp1 - 2); - const ae_p16s *pt_inp2 = (const ae_p16s *)(p_inp2 - 1); - for (i = 0; i < (vec_length - 1); i += 2) { - ae_p24x2s dp_t0, dp_t1; - AE_LP16X2F_IU(dp_inp1, pt_inp1, 4); - AE_LP16F_IU(dp_t0, pt_inp2, 2); - AE_LP16F_IU(dp_t1, pt_inp2, 2); - dp_inp2 = AE_SELP24_LL(dp_t0, dp_t1); - AE_MULAAP24S_HH_LL(dq_out, dp_inp1, dp_inp2); - } - if (vec_length & 1) { - dp_inp1 = AE_CVTP24A16X2_LL(p_inp1[i], p_inp2[i]); - AE_MULAP24S_HL(dq_out, dp_inp1, dp_inp1); - } - } - dq_out32 = AE_SATQ48S(dq_out); - MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32), - out_multiplier, left_shift, right_shift); - dq_out = AE_ADDSQ56S(dq_out, AE_CVTQ48A32S(out_zero_bias)); - WORD32 out_i32 = AE_TRUNCA32Q48(AE_SATQ48S(dq_out)); - out_i32 = out_i32 < -128 ? -128 : out_i32; - out_i32 = out_i32 > 127 ? 
127 : out_i32; - *p_out++ = (WORD8)out_i32; - } -#else - return 1; -#endif - } - return 0; -} diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/kernels/fc/hifi_mini/xa_nn_fully_connected.c b/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/kernels/fc/hifi_mini/xa_nn_fully_connected.c deleted file mode 100644 index 0a9325e81bf..00000000000 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/kernels/fc/hifi_mini/xa_nn_fully_connected.c +++ /dev/null @@ -1,142 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2019-2020 Cadence Design Systems, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining - * a copy of this software and associated documentation files (the - * "Software"), to use this Software with Cadence processor cores only and - * not with any other processors and platforms, subject to - * the following conditions: - * - * The above copyright notice and this permission notice shall be included - * in all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - ******************************************************************************/ - -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "xa_nnlib_err_chk.h" -#include "xa_nnlib_kernels_api.h" -#include "xa_type_def.h" - -WORD32 xa_nn_fully_connected_asym8uxasym8u_asym8u( - UWORD8 *__restrict__ p_out, const UWORD8 *__restrict__ p_weight, - const UWORD8 *__restrict__ p_inp, const WORD32 *__restrict__ p_bias, - WORD32 weight_depth, WORD32 out_depth, WORD32 input_zero_bias, - WORD32 weight_zero_bias, WORD32 out_multiplier, WORD32 out_shift, - WORD32 out_zero_bias) { - /* NULL pointer checks */ - XA_NNLIB_ARG_CHK_PTR(p_out, -1); - XA_NNLIB_ARG_CHK_PTR(p_weight, -1); - XA_NNLIB_ARG_CHK_PTR(p_inp, -1); - XA_NNLIB_ARG_CHK_PTR(p_bias, -1); - /* Pointer alignment checks */ - XA_NNLIB_ARG_CHK_ALIGN(p_bias, sizeof(WORD32), -1); - /* Basic Parameter checks */ - XA_NNLIB_ARG_CHK_COND((out_depth <= 0), -1); - XA_NNLIB_ARG_CHK_COND((input_zero_bias < -255 || input_zero_bias > 0), -1); - XA_NNLIB_ARG_CHK_COND((weight_zero_bias < -255 || weight_zero_bias > 0), -1); - XA_NNLIB_ARG_CHK_COND((out_shift < -31 || out_shift > 31), -1); - XA_NNLIB_ARG_CHK_COND((out_zero_bias < 0 || out_zero_bias > 255), -1); - - WORD32 ret = 0; - ret = xa_nn_matXvec_out_stride_asym8uxasym8u_asym8u( - p_out, p_weight, p_inp, p_bias, out_depth /* rows */ - , - weight_depth /* cols */ - , - weight_depth /* row_stride */ - , - 1 /* out_stride */ - , - weight_zero_bias, input_zero_bias, out_multiplier, out_shift, - out_zero_bias); - return ret; -} - -WORD32 xa_nn_fully_connected_sym8sxasym8s_asym8s( - WORD8 *__restrict__ p_out, const WORD8 *__restrict__ p_weight, - const WORD8 *__restrict__ p_inp, const WORD32 *__restrict__ p_bias, - WORD32 weight_depth, WORD32 out_depth, WORD32 input_zero_bias, - WORD32 out_multiplier, WORD32 out_shift, WORD32 out_zero_bias) { - /* NULL pointer checks */ - XA_NNLIB_ARG_CHK_PTR(p_out, -1); - XA_NNLIB_ARG_CHK_PTR(p_weight, -1); - XA_NNLIB_ARG_CHK_PTR(p_inp, -1); - XA_NNLIB_ARG_CHK_PTR(p_bias, -1); - /* Pointer alignment checks */ - XA_NNLIB_ARG_CHK_ALIGN(p_bias, sizeof(WORD32), -1); - /* Basic Parameter checks */ - XA_NNLIB_ARG_CHK_COND((out_depth <= 0), -1); - XA_NNLIB_ARG_CHK_COND((input_zero_bias < -127 || input_zero_bias > 128), -1); - XA_NNLIB_ARG_CHK_COND((out_shift < -31 || out_shift > 31), -1); - XA_NNLIB_ARG_CHK_COND((out_zero_bias < -128 || out_zero_bias > 127), -1); - - WORD32 ret = 0; - ret = xa_nn_matXvec_out_stride_sym8sxasym8s_asym8s( - p_out, p_weight, p_inp, p_bias, out_depth /* rows */ - , - weight_depth /* cols */ - , - weight_depth /* row_stride */ - , - 1 /* out_stride */ - , - input_zero_bias, out_multiplier, out_shift, out_zero_bias); - return ret; -} - -WORD32 xa_nn_fully_connected_asym8sxasym8s_asym8s( - WORD8 *__restrict__ p_out, const WORD8 *__restrict__ p_weight, - const WORD8 *__restrict__ p_inp, const WORD32 *__restrict__ p_bias, - WORD32 weight_depth, WORD32 out_depth, WORD32 weight_zero_bias, - WORD32 input_zero_bias, WORD32 out_multiplier, WORD32 out_shift, - WORD32 out_zero_bias) { - /* NULL pointer checks */ - XA_NNLIB_ARG_CHK_PTR(p_out, -1); - XA_NNLIB_ARG_CHK_PTR(p_weight, -1); - XA_NNLIB_ARG_CHK_PTR(p_inp, -1); - XA_NNLIB_ARG_CHK_PTR(p_bias, -1); - /* Pointer alignment checks */ - XA_NNLIB_ARG_CHK_ALIGN(p_bias, sizeof(WORD32), -1); - /* Basic Parameter checks */ - XA_NNLIB_ARG_CHK_COND((out_depth <= 0), -1); - XA_NNLIB_ARG_CHK_COND((weight_zero_bias < -127 || weight_zero_bias > 128), - -1); - XA_NNLIB_ARG_CHK_COND((input_zero_bias < -127 || input_zero_bias > 128), -1); - XA_NNLIB_ARG_CHK_COND((out_shift < 
-31 || out_shift > 31), -1); - XA_NNLIB_ARG_CHK_COND((out_zero_bias < -128 || out_zero_bias > 127), -1); - - WORD32 ret = 0; - ret = xa_nn_matXvec_out_stride_asym8sxasym8s_asym8s( - p_out, p_weight, p_inp, p_bias, out_depth /* rows */ - , - weight_depth /* cols */ - , - weight_depth /* row_stride */ - , - 1 /* out_stride */ - , - weight_zero_bias, input_zero_bias, out_multiplier, out_shift, - out_zero_bias); - return ret; -} diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/kernels/matXvec/hifi_mini/xa_nn_matXvec_sym8sxasym8s.c b/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/kernels/matXvec/hifi_mini/xa_nn_matXvec_sym8sxasym8s.c deleted file mode 100644 index 71af822e68b..00000000000 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/algo/kernels/matXvec/hifi_mini/xa_nn_matXvec_sym8sxasym8s.c +++ /dev/null @@ -1,1053 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2019-2020 Cadence Design Systems, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining - * a copy of this software and associated documentation files (the - * "Software"), to use this Software with Cadence processor cores only and - * not with any other processors and platforms, subject to - * the following conditions: - * - * The above copyright notice and this permission notice shall be included - * in all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - ******************************************************************************/ - -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xa_nnlib_common.h" -#include "xa_nnlib_common_macros.h" - -#define ADD_OUT_OFFSET_STORE_INT8(ptr, data, out_offset) \ - { \ - data = AE_ADDSQ56S(data, AE_CVTQ48A32S(out_offset)); \ - int out_i32 = AE_TRUNCA32Q48(AE_SATQ48S(data)); \ - out_i32 = out_i32 < -128 ? -128 : out_i32; \ - out_i32 = out_i32 > 127 ? 
127 : out_i32; \ - *(ptr) = (WORD8)out_i32; \ - } - -WORD32 xa_nn_matXvec_out_stride_sym8sxasym8s_asym8s( - WORD8 *__restrict__ p_out, const WORD8 *__restrict__ p_mat1, - const WORD8 *__restrict__ p_vec1, const WORD32 *__restrict__ p_bias, - WORD32 rows, WORD32 cols1, WORD32 row_stride1, WORD32 out_stride, - WORD32 vec1_zero_bias, WORD32 out_multiplier, WORD32 out_shift, - WORD32 out_zero_bias) { - /* NULL pointer checks */ - XA_NNLIB_ARG_CHK_PTR(p_out, -1); - XA_NNLIB_ARG_CHK_PTR(p_mat1, -1); - XA_NNLIB_ARG_CHK_PTR(p_vec1, -1); - /* Pointer alignment checks */ - XA_NNLIB_ARG_CHK_ALIGN(p_bias, sizeof(WORD32), -1); - /* Basic Parameter checks */ - XA_NNLIB_ARG_CHK_COND((rows <= 0), -1); - XA_NNLIB_ARG_CHK_COND((cols1 <= 0), -1); - XA_NNLIB_ARG_CHK_COND((row_stride1 < cols1), -1); - XA_NNLIB_ARG_CHK_COND((vec1_zero_bias < -127 || vec1_zero_bias > 128), -1); - XA_NNLIB_ARG_CHK_COND((out_shift < -31 || out_shift > 31), -1); - XA_NNLIB_ARG_CHK_COND((out_zero_bias < -128 || out_zero_bias > 127), -1); - - /* Iterators used in for loops */ - int m_itr, c_itr, i; - /* Assign initial value so this value will be used in trailing loop */ - m_itr = 0; - /* Shifts to match with Tensorflow */ - int left_shift, right_shift; - - left_shift = out_shift < 0 ? 0 : out_shift; - right_shift = out_shift > 0 ? 0 : -out_shift; - - const WORD8 *p_mat1_0, *p_mat1_1, *p_mat1_2, *p_mat1_3; - const WORD8 *p_vec1_0; - ae_p24x2s dp_mat1_0, dp_mat1_1, dp_mat1_2, dp_mat1_3, dp_vec1_0; - ae_p24x2s dp_vec1_zb; - ae_q56s dq_acc[4]; - ae_q56s dq_out32, dq_out; - - dp_vec1_zb = AE_MOVPA24(vec1_zero_bias); - if (((((unsigned)p_mat1) & 1) == 0) && ((((unsigned)p_vec1) & 1) == 0) && - ((row_stride1 & 1) == 0)) { - for (m_itr = 0; m_itr < (rows - 3); m_itr += 4) { - p_mat1_0 = &p_mat1[(m_itr + 0) * row_stride1 - 2]; - p_mat1_1 = &p_mat1[(m_itr + 1) * row_stride1 - 2]; - p_mat1_2 = &p_mat1[(m_itr + 2) * row_stride1 - 2]; - p_mat1_3 = &p_mat1[(m_itr + 3) * row_stride1 - 2]; - p_vec1_0 = p_vec1 - 2; - - dq_acc[0] = dq_acc[1] = dq_acc[2] = dq_acc[3] = AE_ZEROQ56(); - - /* AE_LP8X2F* instruction loads in upper 8 bits of P register, so shifting - vector right by 16 to get multiplication result in middle 32 bits of Q - register (lower 16 bits 0) */ - for (c_itr = 0; c_itr < (cols1 - 1); c_itr += 2) { - AE_LP8X2F_IU(dp_mat1_0, p_mat1_0, 2); - AE_LP8X2F_IU(dp_mat1_1, p_mat1_1, 2); - AE_LP8X2F_IU(dp_mat1_2, p_mat1_2, 2); - AE_LP8X2F_IU(dp_mat1_3, p_mat1_3, 2); - AE_LP8X2F_IU(dp_vec1_0, p_vec1_0, 2); - dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16); - dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb); - AE_MULAAP24S_HH_LL(dq_acc[0], dp_mat1_0, dp_vec1_0); - AE_MULAAP24S_HH_LL(dq_acc[1], dp_mat1_1, dp_vec1_0); - AE_MULAAP24S_HH_LL(dq_acc[2], dp_mat1_2, dp_vec1_0); - AE_MULAAP24S_HH_LL(dq_acc[3], dp_mat1_3, dp_vec1_0); - } - /* Pointers are aligned so can do 8X2 loads and ignore L parts of - * registers */ - if (cols1 & 1) { - AE_LP8X2F_IU(dp_mat1_0, p_mat1_0, 2); - AE_LP8X2F_IU(dp_mat1_1, p_mat1_1, 2); - AE_LP8X2F_IU(dp_mat1_2, p_mat1_2, 2); - AE_LP8X2F_IU(dp_mat1_3, p_mat1_3, 2); - AE_LP8X2F_IU(dp_vec1_0, p_vec1_0, 2); - dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16); - dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb); - AE_MULAP24S_HH(dq_acc[0], dp_mat1_0, dp_vec1_0); - AE_MULAP24S_HH(dq_acc[1], dp_mat1_1, dp_vec1_0); - AE_MULAP24S_HH(dq_acc[2], dp_mat1_2, dp_vec1_0); - AE_MULAP24S_HH(dq_acc[3], dp_mat1_3, dp_vec1_0); - } - - if (p_bias != NULL) { - for (i = 0; i < 4; i++) - dq_acc[i] = AE_ADDSQ56S(dq_acc[i], *(ae_q32s *)(&p_bias[m_itr + i])); - } - - for (i = 0; 
i < 4; i++) { - dq_out32 = AE_SATQ48S(dq_acc[i]); - MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32), - out_multiplier, left_shift, - right_shift); - ADD_OUT_OFFSET_STORE_INT8(&p_out[(m_itr + i) * out_stride], dq_out, - out_zero_bias); - } - } - for (; m_itr < rows; m_itr++) { - p_mat1_0 = &p_mat1[m_itr * row_stride1 - 2]; - p_vec1_0 = p_vec1 - 2; - - dq_acc[0] = AE_ZEROQ56(); - - for (c_itr = 0; c_itr < (cols1 - 1); c_itr += 2) { - AE_LP8X2F_IU(dp_mat1_0, p_mat1_0, 2); - AE_LP8X2F_IU(dp_vec1_0, p_vec1_0, 2); - dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16); - dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb); - AE_MULAAP24S_HH_LL(dq_acc[0], dp_mat1_0, dp_vec1_0); - } - /* Pointers are aligned so can do 8X2 loads and ignore L parts of - * registers */ - if (cols1 & 1) { - AE_LP8X2F_IU(dp_mat1_0, p_mat1_0, 2); - AE_LP8X2F_IU(dp_vec1_0, p_vec1_0, 2); - dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16); - dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb); - AE_MULAP24S_HH(dq_acc[0], dp_mat1_0, dp_vec1_0); - } - - if (p_bias != NULL) - dq_acc[0] = AE_ADDSQ56S(dq_acc[0], *(ae_q32s *)(&p_bias[m_itr])); - - dq_out32 = AE_SATQ48S(dq_acc[0]); - MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32), - out_multiplier, left_shift, right_shift); - ADD_OUT_OFFSET_STORE_INT8(&p_out[m_itr * out_stride], dq_out, - out_zero_bias); - } - } else { - if ((((unsigned)p_mat1) & 1) == 0) { - for (m_itr = 0; m_itr < (rows - 3); m_itr += 4) { - p_mat1_0 = &p_mat1[(m_itr + 0) * row_stride1 - 2]; - p_mat1_1 = &p_mat1[(m_itr + 1) * row_stride1]; - p_mat1_2 = &p_mat1[(m_itr + 2) * row_stride1 - 2]; - p_mat1_3 = &p_mat1[(m_itr + 3) * row_stride1]; - p_vec1_0 = p_vec1; - - dq_acc[0] = dq_acc[1] = dq_acc[2] = dq_acc[3] = AE_ZEROQ56(); - - /* Matrix elements are kept in upper 8 bits of P registers, vector - elements are kept in lower 8 bits of P registers, typecasting to UWORD8 - is to avoid extra extui instructions since signed 8-bit load in not - there in HiFiMini */ - for (c_itr = 0; c_itr < (cols1 - 1); c_itr += 2) { - AE_LP8X2F_IU(dp_mat1_0, p_mat1_0, 2); - dp_mat1_1 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_1[c_itr], - (UWORD8)p_mat1_1[c_itr + 1]); - AE_LP8X2F_IU(dp_mat1_2, p_mat1_2, 2); - dp_mat1_3 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_3[c_itr], - (UWORD8)p_mat1_3[c_itr + 1]); - dp_vec1_0 = AE_CVTP24A16X2_LL((UWORD8)p_vec1_0[c_itr], - (UWORD8)p_vec1_0[c_itr + 1]); - dp_mat1_1 = AE_SLLIP24(dp_mat1_1, 8); - dp_mat1_3 = AE_SLLIP24(dp_mat1_3, 8); - dp_vec1_0 = AE_SLLIP24(dp_vec1_0, 8); - dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16); - dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb); - AE_MULAAP24S_HH_LL(dq_acc[0], dp_mat1_0, dp_vec1_0); - AE_MULAAP24S_HH_LL(dq_acc[1], dp_mat1_1, dp_vec1_0); - AE_MULAAP24S_HH_LL(dq_acc[2], dp_mat1_2, dp_vec1_0); - AE_MULAAP24S_HH_LL(dq_acc[3], dp_mat1_3, dp_vec1_0); - } - if (cols1 & 1) { - ae_p24x2s dp_mat1_01, dp_mat1_23; - dp_mat1_01 = - AE_CVTP24A16X2_LL((UWORD8)p_mat1_0[2], (UWORD8)p_mat1_1[c_itr]); - dp_mat1_23 = - AE_CVTP24A16X2_LL((UWORD8)p_mat1_2[2], (UWORD8)p_mat1_3[c_itr]); - dp_vec1_0 = AE_MOVPA24(p_vec1_0[c_itr]); - dp_mat1_01 = AE_SLLIP24(dp_mat1_01, 8); - dp_mat1_23 = AE_SLLIP24(dp_mat1_23, 8); - dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb); - AE_MULAP24S_HH(dq_acc[0], dp_mat1_01, dp_vec1_0); - AE_MULAP24S_LL(dq_acc[1], dp_mat1_01, dp_vec1_0); - AE_MULAP24S_HH(dq_acc[2], dp_mat1_23, dp_vec1_0); - AE_MULAP24S_LL(dq_acc[3], dp_mat1_23, dp_vec1_0); - } - - if (p_bias != NULL) { - for (i = 0; i < 4; i++) - dq_acc[i] = - AE_ADDSQ56S(dq_acc[i], *(ae_q32s *)(&p_bias[m_itr + i])); - } - - for 
(i = 0; i < 4; i++) { - dq_out32 = AE_SATQ48S(dq_acc[i]); - MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32), - out_multiplier, left_shift, - right_shift); - ADD_OUT_OFFSET_STORE_INT8(&p_out[(m_itr + i) * out_stride], dq_out, - out_zero_bias); - } - } - } else { - for (m_itr = 0; m_itr < (rows - 3); m_itr += 4) { - p_mat1_0 = &p_mat1[(m_itr + 0) * row_stride1]; - p_mat1_1 = &p_mat1[(m_itr + 1) * row_stride1]; - p_mat1_2 = &p_mat1[(m_itr + 2) * row_stride1]; - p_mat1_3 = &p_mat1[(m_itr + 3) * row_stride1]; - p_vec1_0 = p_vec1; - - dq_acc[0] = dq_acc[1] = dq_acc[2] = dq_acc[3] = AE_ZEROQ56(); - - /* Matrix elements are kept in upper 8 bits of P registers, vector - elements are kept in lower 8 bits of P registers, typecasting to UWORD8 - is to avoid extra extui instructions since signed 8-bit load in not - there in HiFiMini */ - for (c_itr = 0; c_itr < (cols1 - 1); c_itr += 2) { - dp_mat1_0 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_0[c_itr], - (UWORD8)p_mat1_0[c_itr + 1]); - dp_mat1_1 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_1[c_itr], - (UWORD8)p_mat1_1[c_itr + 1]); - dp_mat1_2 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_2[c_itr], - (UWORD8)p_mat1_2[c_itr + 1]); - dp_mat1_3 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_3[c_itr], - (UWORD8)p_mat1_3[c_itr + 1]); - dp_vec1_0 = AE_CVTP24A16X2_LL((UWORD8)p_vec1_0[c_itr], - (UWORD8)p_vec1_0[c_itr + 1]); - dp_mat1_0 = AE_SLLIP24(dp_mat1_0, 8); - dp_mat1_1 = AE_SLLIP24(dp_mat1_1, 8); - dp_mat1_2 = AE_SLLIP24(dp_mat1_2, 8); - dp_mat1_3 = AE_SLLIP24(dp_mat1_3, 8); - dp_vec1_0 = AE_SLLIP24(dp_vec1_0, 8); - dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16); - dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb); - AE_MULAAP24S_HH_LL(dq_acc[0], dp_mat1_0, dp_vec1_0); - AE_MULAAP24S_HH_LL(dq_acc[1], dp_mat1_1, dp_vec1_0); - AE_MULAAP24S_HH_LL(dq_acc[2], dp_mat1_2, dp_vec1_0); - AE_MULAAP24S_HH_LL(dq_acc[3], dp_mat1_3, dp_vec1_0); - } - if (cols1 & 1) { - ae_p24x2s dp_mat1_01, dp_mat1_23; - dp_mat1_01 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_0[c_itr], - (UWORD8)p_mat1_1[c_itr]); - dp_mat1_23 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_2[c_itr], - (UWORD8)p_mat1_3[c_itr]); - dp_vec1_0 = AE_MOVPA24(p_vec1_0[c_itr]); - dp_mat1_01 = AE_SLLIP24(dp_mat1_01, 8); - dp_mat1_23 = AE_SLLIP24(dp_mat1_23, 8); - dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb); - AE_MULAP24S_HH(dq_acc[0], dp_mat1_01, dp_vec1_0); - AE_MULAP24S_LL(dq_acc[1], dp_mat1_01, dp_vec1_0); - AE_MULAP24S_HH(dq_acc[2], dp_mat1_23, dp_vec1_0); - AE_MULAP24S_LL(dq_acc[3], dp_mat1_23, dp_vec1_0); - } - - if (p_bias != NULL) { - for (i = 0; i < 4; i++) - dq_acc[i] = - AE_ADDSQ56S(dq_acc[i], *(ae_q32s *)(&p_bias[m_itr + i])); - } - - for (i = 0; i < 4; i++) { - dq_out32 = AE_SATQ48S(dq_acc[i]); - MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32), - out_multiplier, left_shift, - right_shift); - ADD_OUT_OFFSET_STORE_INT8(&p_out[(m_itr + i) * out_stride], dq_out, - out_zero_bias); - } - } - } - for (; m_itr < rows; m_itr++) { - p_mat1_0 = &p_mat1[m_itr * row_stride1]; - p_vec1_0 = p_vec1; - - dq_acc[0] = AE_ZEROQ56(); - - for (c_itr = 0; c_itr < (cols1 - 1); c_itr += 2) { - dp_mat1_0 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_0[c_itr], - (UWORD8)p_mat1_0[c_itr + 1]); - dp_vec1_0 = AE_CVTP24A16X2_LL((UWORD8)p_vec1_0[c_itr], - (UWORD8)p_vec1_0[c_itr + 1]); - dp_mat1_0 = AE_SLLIP24(dp_mat1_0, 8); - dp_vec1_0 = AE_SLLIP24(dp_vec1_0, 8); - dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16); - dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb); - AE_MULAAP24S_HH_LL(dq_acc[0], dp_mat1_0, dp_vec1_0); - } - if (cols1 & 1) { - dp_mat1_0 = AE_CVTP24A16(p_mat1_0[c_itr]); - 
dp_vec1_0 = AE_CVTP24A16(p_vec1_0[c_itr]); - dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, AE_CVTP24A16(vec1_zero_bias)); - AE_MULAP24S_LL(dq_acc[0], dp_mat1_0, dp_vec1_0); - } - - if (p_bias != NULL) - dq_acc[0] = AE_ADDSQ56S(dq_acc[0], *(ae_q32s *)(&p_bias[m_itr])); - - dq_out32 = AE_SATQ48S(dq_acc[0]); - MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32), - out_multiplier, left_shift, right_shift); - ADD_OUT_OFFSET_STORE_INT8(&p_out[m_itr * out_stride], dq_out, - out_zero_bias); - } - } - - return 0; -} - -WORD32 xa_nn_matXvec_out_stride_asym8sxasym8s_asym8s( - WORD8 *__restrict__ p_out, const WORD8 *__restrict__ p_mat1, - const WORD8 *__restrict__ p_vec1, const WORD32 *__restrict__ p_bias, - WORD32 rows, WORD32 cols1, WORD32 row_stride1, WORD32 out_stride, - WORD32 mat1_zero_bias, WORD32 vec1_zero_bias, WORD32 out_multiplier, - WORD32 out_shift, WORD32 out_zero_bias) { - /* NULL pointer checks */ - XA_NNLIB_ARG_CHK_PTR(p_out, -1); - XA_NNLIB_ARG_CHK_PTR(p_mat1, -1); - XA_NNLIB_ARG_CHK_PTR(p_vec1, -1); - /* Pointer alignment checks */ - XA_NNLIB_ARG_CHK_ALIGN(p_bias, sizeof(WORD32), -1); - /* Basic Parameter checks */ - XA_NNLIB_ARG_CHK_COND((rows <= 0), -1); - XA_NNLIB_ARG_CHK_COND((cols1 <= 0), -1); - XA_NNLIB_ARG_CHK_COND((row_stride1 < cols1), -1); - XA_NNLIB_ARG_CHK_COND((mat1_zero_bias < -127 || mat1_zero_bias > 128), -1); - XA_NNLIB_ARG_CHK_COND((vec1_zero_bias < -127 || vec1_zero_bias > 128), -1); - XA_NNLIB_ARG_CHK_COND((out_shift < -31 || out_shift > 31), -1); - XA_NNLIB_ARG_CHK_COND((out_zero_bias < -128 || out_zero_bias > 127), -1); - - /* Iterators used in for loops */ - int m_itr, c_itr, i; - /* Assign initial value so this value will be used in trailing loop */ - m_itr = 0; - /* Shifts to match with Tensorflow */ - int left_shift, right_shift; - - left_shift = out_shift < 0 ? 0 : out_shift; - right_shift = out_shift > 0 ? 0 : -out_shift; - - const WORD8 *p_mat1_0, *p_mat1_1, *p_mat1_2, *p_mat1_3; - const WORD8 *p_vec1_0; - ae_p24x2s dp_mat1_0, dp_mat1_1, dp_mat1_2, dp_mat1_3, dp_vec1_0; - ae_p24x2s dp_vec1_zb, dp_mat1_zb; - ae_q56s dq_acc_0, dq_acc_1, dq_acc_2, dq_acc_3; - ae_q56s dq_out32, dq_out; - - const WORD32 bias_buffer[1] = {0}; - const WORD32 *p_bias_load; - WORD32 bias_address_increment = sizeof(WORD32); - - dp_mat1_zb = AE_MOVPA24(mat1_zero_bias); - dp_vec1_zb = AE_MOVPA24(vec1_zero_bias); - - /* Check for alignment conditions */ - if (((((unsigned)p_mat1) & 1) == 0) && ((((unsigned)p_vec1) & 1) == 0) && - ((row_stride1 & 1) == 0) && ((cols1 & 1) == 0)) { - /* Calculate partial zero offset adjustment outside the loop */ - WORD32 zero_offset_adjustment; - - // Constant part of total zero bias - ae_q56s dq_zero_bias_sum = - AE_CVTQ48A32S(vec1_zero_bias * cols1 * mat1_zero_bias); - - WORD8 *p_inp = (WORD8 *)p_vec1 - 2; - for (i = 0; i < (cols1 >> 1); i++) { - /* Input vector is in MSB 8 bits, matrix zero bias in LSB 8 bits */ - AE_LP8X2F_IU(dp_vec1_0, p_inp, 2); - AE_MULAAP24S_HH_LL(dq_zero_bias_sum, dp_vec1_0, dp_mat1_zb); - } - /* Product is already aligned to bits 16 to 47 in QR register. */ - zero_offset_adjustment = AE_TRUNCA32Q48(dq_zero_bias_sum); - - /* If bias is not provided, use a dummy zero value from bias_buffer. 
*/ - if (p_bias == NULL) { - p_bias_load = bias_buffer - 1; - bias_address_increment = 0; - } else { - p_bias_load = p_bias - 1; - } - - for (m_itr = 0; m_itr < (rows - 3); m_itr += 4) { - p_mat1_0 = &p_mat1[(m_itr + 0) * row_stride1 - 2]; - p_mat1_1 = &p_mat1[(m_itr + 1) * row_stride1 - 2]; - p_mat1_2 = &p_mat1[(m_itr + 2) * row_stride1 - 2]; - p_mat1_3 = &p_mat1[(m_itr + 3) * row_stride1 - 2]; - p_vec1_0 = p_vec1 - 2; - - AE_LQ32F_XU(dq_acc_0, (ae_q32s *)p_bias_load, bias_address_increment); - AE_LQ32F_XU(dq_acc_1, (ae_q32s *)p_bias_load, bias_address_increment); - AE_LQ32F_XU(dq_acc_2, (ae_q32s *)p_bias_load, bias_address_increment); - AE_LQ32F_XU(dq_acc_3, (ae_q32s *)p_bias_load, bias_address_increment); - - dq_acc_0 = AE_ADDQ56(dq_acc_0, AE_CVTQ48A32S(zero_offset_adjustment)); - dq_acc_1 = AE_ADDQ56(dq_acc_1, AE_CVTQ48A32S(zero_offset_adjustment)); - dq_acc_2 = AE_ADDQ56(dq_acc_2, AE_CVTQ48A32S(zero_offset_adjustment)); - dq_acc_3 = AE_ADDQ56(dq_acc_3, AE_CVTQ48A32S(zero_offset_adjustment)); - - /* AE_LP8X2F* instruction loads in upper 8 bits of P register, so shifting - vector right by 16 to get multiplication result in middle 32 bits of Q - register (lower 16 bits 0) */ - for (c_itr = 0; c_itr < (cols1 - 1); c_itr += 2) { - AE_LP8X2F_IU(dp_mat1_0, p_mat1_0, 2); - AE_LP8X2F_IU(dp_mat1_1, p_mat1_1, 2); - AE_LP8X2F_IU(dp_mat1_2, p_mat1_2, 2); - AE_LP8X2F_IU(dp_mat1_3, p_mat1_3, 2); - AE_LP8X2F_IU(dp_vec1_0, p_vec1_0, 2); - - dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16); - dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb); - - AE_MULAAP24S_HH_LL(dq_acc_0, dp_mat1_0, dp_vec1_0); - AE_MULAAP24S_HH_LL(dq_acc_1, dp_mat1_1, dp_vec1_0); - AE_MULAAP24S_HH_LL(dq_acc_2, dp_mat1_2, dp_vec1_0); - AE_MULAAP24S_HH_LL(dq_acc_3, dp_mat1_3, dp_vec1_0); - } - - /* Pointers are aligned so can do 8X2 loads and ignore L parts of - * registers */ - if (cols1 & 1) { - AE_LP8X2F_IU(dp_mat1_0, p_mat1_0, 2); - AE_LP8X2F_IU(dp_mat1_1, p_mat1_1, 2); - AE_LP8X2F_IU(dp_mat1_2, p_mat1_2, 2); - AE_LP8X2F_IU(dp_mat1_3, p_mat1_3, 2); - AE_LP8X2F_IU(dp_vec1_0, p_vec1_0, 2); - - dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16); - dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb); - - AE_MULAP24S_HH(dq_acc_0, dp_mat1_0, dp_vec1_0); - AE_MULAP24S_HH(dq_acc_1, dp_mat1_1, dp_vec1_0); - AE_MULAP24S_HH(dq_acc_2, dp_mat1_2, dp_vec1_0); - AE_MULAP24S_HH(dq_acc_3, dp_mat1_3, dp_vec1_0); - } - - dq_out32 = AE_SATQ48S(dq_acc_0); - MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32), - out_multiplier, left_shift, right_shift); - ADD_OUT_OFFSET_STORE_INT8(&p_out[(m_itr + i) * out_stride], dq_out, - out_zero_bias); - - dq_out32 = AE_SATQ48S(dq_acc_1); - MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32), - out_multiplier, left_shift, right_shift); - ADD_OUT_OFFSET_STORE_INT8(&p_out[(m_itr + i) * out_stride], dq_out, - out_zero_bias); - - dq_out32 = AE_SATQ48S(dq_acc_2); - MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32), - out_multiplier, left_shift, right_shift); - ADD_OUT_OFFSET_STORE_INT8(&p_out[(m_itr + i) * out_stride], dq_out, - out_zero_bias); - - dq_out32 = AE_SATQ48S(dq_acc_3); - MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32), - out_multiplier, left_shift, right_shift); - ADD_OUT_OFFSET_STORE_INT8(&p_out[(m_itr + i) * out_stride], dq_out, - out_zero_bias); - } - for (; m_itr < rows; m_itr++) { - p_mat1_0 = &p_mat1[m_itr * row_stride1 - 2]; - p_vec1_0 = p_vec1 - 2; - - AE_LQ32F_XU(dq_acc_0, (ae_q32s *)p_bias_load, bias_address_increment); - dq_acc_0 = AE_ADDQ56(dq_acc_0, 
AE_CVTQ48A32S(zero_offset_adjustment)); - - for (c_itr = 0; c_itr < (cols1 - 1); c_itr += 2) { - AE_LP8X2F_IU(dp_mat1_0, p_mat1_0, 2); - AE_LP8X2F_IU(dp_vec1_0, p_vec1_0, 2); - dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16); - dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb); - - AE_MULAAP24S_HH_LL(dq_acc_0, dp_mat1_0, dp_vec1_0); - } - - dq_out32 = AE_SATQ48S(dq_acc_0); - MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32), - out_multiplier, left_shift, right_shift); - ADD_OUT_OFFSET_STORE_INT8(&p_out[m_itr * out_stride], dq_out, - out_zero_bias); - } - } else { -#ifndef DISABLE_NNLIB_UNALIGNED_SUPPORT - ae_q56s dq_acc[4]; - - if ((((unsigned)p_mat1) & 1) == 0) { - for (m_itr = 0; m_itr < (rows - 3); m_itr += 4) { - p_mat1_0 = &p_mat1[(m_itr + 0) * row_stride1 - 2]; - p_mat1_1 = &p_mat1[(m_itr + 1) * row_stride1]; - p_mat1_2 = &p_mat1[(m_itr + 2) * row_stride1 - 2]; - p_mat1_3 = &p_mat1[(m_itr + 3) * row_stride1]; - p_vec1_0 = p_vec1; - - dq_acc[0] = dq_acc[1] = dq_acc[2] = dq_acc[3] = AE_ZEROQ56(); - - /* Matrix elements are kept in upper 8 bits of P registers, vector - elements are kept in lower 8 bits of P registers, typecasting to UWORD8 - is to avoid extra extui instructions since signed 8-bit load in not - there in HiFiMini */ - for (c_itr = 0; c_itr < (cols1 - 1); c_itr += 2) { - AE_LP8X2F_IU(dp_mat1_0, p_mat1_0, 2); - dp_mat1_1 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_1[c_itr], - (UWORD8)p_mat1_1[c_itr + 1]); - AE_LP8X2F_IU(dp_mat1_2, p_mat1_2, 2); - dp_mat1_3 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_3[c_itr], - (UWORD8)p_mat1_3[c_itr + 1]); - dp_vec1_0 = AE_CVTP24A16X2_LL((UWORD8)p_vec1_0[c_itr], - (UWORD8)p_vec1_0[c_itr + 1]); - dp_mat1_1 = AE_SLLIP24(dp_mat1_1, 8); - dp_mat1_3 = AE_SLLIP24(dp_mat1_3, 8); - dp_vec1_0 = AE_SLLIP24(dp_vec1_0, 8); - dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16); - dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb); - - dp_mat1_0 = AE_SRAIP24(dp_mat1_0, 16); - dp_mat1_0 = AE_ADDSP24S(dp_mat1_0, dp_mat1_zb); - dp_mat1_1 = AE_SRAIP24(dp_mat1_1, 16); - dp_mat1_1 = AE_ADDSP24S(dp_mat1_1, dp_mat1_zb); - dp_mat1_2 = AE_SRAIP24(dp_mat1_2, 16); - dp_mat1_2 = AE_ADDSP24S(dp_mat1_2, dp_mat1_zb); - dp_mat1_3 = AE_SRAIP24(dp_mat1_3, 16); - dp_mat1_3 = AE_ADDSP24S(dp_mat1_3, dp_mat1_zb); - - AE_MULAAP24S_HH_LL(dq_acc[0], dp_mat1_0, dp_vec1_0); - AE_MULAAP24S_HH_LL(dq_acc[1], dp_mat1_1, dp_vec1_0); - AE_MULAAP24S_HH_LL(dq_acc[2], dp_mat1_2, dp_vec1_0); - AE_MULAAP24S_HH_LL(dq_acc[3], dp_mat1_3, dp_vec1_0); - } - if (cols1 & 1) { - ae_p24x2s dp_mat1_01, dp_mat1_23; - dp_mat1_01 = - AE_CVTP24A16X2_LL((UWORD8)p_mat1_0[2], (UWORD8)p_mat1_1[c_itr]); - dp_mat1_23 = - AE_CVTP24A16X2_LL((UWORD8)p_mat1_2[2], (UWORD8)p_mat1_3[c_itr]); - dp_vec1_0 = AE_MOVPA24(p_vec1_0[c_itr]); - dp_mat1_01 = AE_SLLIP24(dp_mat1_01, 8); - dp_mat1_23 = AE_SLLIP24(dp_mat1_23, 8); - dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb); - - dp_mat1_01 = AE_SRAIP24(dp_mat1_01, 16); - dp_mat1_01 = AE_ADDSP24S(dp_mat1_01, dp_mat1_zb); - dp_mat1_23 = AE_SRAIP24(dp_mat1_23, 16); - dp_mat1_23 = AE_ADDSP24S(dp_mat1_23, dp_mat1_zb); - - AE_MULAP24S_HH(dq_acc[0], dp_mat1_01, dp_vec1_0); - AE_MULAP24S_LL(dq_acc[1], dp_mat1_01, dp_vec1_0); - AE_MULAP24S_HH(dq_acc[2], dp_mat1_23, dp_vec1_0); - AE_MULAP24S_LL(dq_acc[3], dp_mat1_23, dp_vec1_0); - } - - dq_acc[0] = AE_SLLISQ56S(dq_acc[0], 16); - dq_acc[1] = AE_SLLISQ56S(dq_acc[1], 16); - dq_acc[2] = AE_SLLISQ56S(dq_acc[2], 16); - dq_acc[3] = AE_SLLISQ56S(dq_acc[3], 16); - - if (p_bias != NULL) { - for (i = 0; i < 4; i++) - dq_acc[i] = - AE_ADDSQ56S(dq_acc[i], *(ae_q32s *)(&p_bias[m_itr 
+ i])); - } - - for (i = 0; i < 4; i++) { - dq_out32 = AE_SATQ48S(dq_acc[i]); - MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32), - out_multiplier, left_shift, - right_shift); - ADD_OUT_OFFSET_STORE_INT8(&p_out[(m_itr + i) * out_stride], dq_out, - out_zero_bias); - } - } - } else { - for (m_itr = 0; m_itr < (rows - 3); m_itr += 4) { - p_mat1_0 = &p_mat1[(m_itr + 0) * row_stride1]; - p_mat1_1 = &p_mat1[(m_itr + 1) * row_stride1]; - p_mat1_2 = &p_mat1[(m_itr + 2) * row_stride1]; - p_mat1_3 = &p_mat1[(m_itr + 3) * row_stride1]; - p_vec1_0 = p_vec1; - - dq_acc[0] = dq_acc[1] = dq_acc[2] = dq_acc[3] = AE_ZEROQ56(); - - /* Matrix elements are kept in upper 8 bits of P registers, vector - elements are kept in lower 8 bits of P registers, typecasting to UWORD8 - is to avoid extra extui instructions since signed 8-bit load in not - there in HiFiMini */ - for (c_itr = 0; c_itr < (cols1 - 1); c_itr += 2) { - dp_mat1_0 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_0[c_itr], - (UWORD8)p_mat1_0[c_itr + 1]); - dp_mat1_1 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_1[c_itr], - (UWORD8)p_mat1_1[c_itr + 1]); - dp_mat1_2 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_2[c_itr], - (UWORD8)p_mat1_2[c_itr + 1]); - dp_mat1_3 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_3[c_itr], - (UWORD8)p_mat1_3[c_itr + 1]); - dp_vec1_0 = AE_CVTP24A16X2_LL((UWORD8)p_vec1_0[c_itr], - (UWORD8)p_vec1_0[c_itr + 1]); - dp_mat1_0 = AE_SLLIP24(dp_mat1_0, 8); - dp_mat1_1 = AE_SLLIP24(dp_mat1_1, 8); - dp_mat1_2 = AE_SLLIP24(dp_mat1_2, 8); - dp_mat1_3 = AE_SLLIP24(dp_mat1_3, 8); - dp_vec1_0 = AE_SLLIP24(dp_vec1_0, 8); - dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16); - dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb); - - dp_mat1_0 = AE_SRAIP24(dp_mat1_0, 16); - dp_mat1_0 = AE_ADDSP24S(dp_mat1_0, dp_mat1_zb); - dp_mat1_1 = AE_SRAIP24(dp_mat1_1, 16); - dp_mat1_1 = AE_ADDSP24S(dp_mat1_1, dp_mat1_zb); - dp_mat1_2 = AE_SRAIP24(dp_mat1_2, 16); - dp_mat1_2 = AE_ADDSP24S(dp_mat1_2, dp_mat1_zb); - dp_mat1_3 = AE_SRAIP24(dp_mat1_3, 16); - dp_mat1_3 = AE_ADDSP24S(dp_mat1_3, dp_mat1_zb); - - AE_MULAAP24S_HH_LL(dq_acc[0], dp_mat1_0, dp_vec1_0); - AE_MULAAP24S_HH_LL(dq_acc[1], dp_mat1_1, dp_vec1_0); - AE_MULAAP24S_HH_LL(dq_acc[2], dp_mat1_2, dp_vec1_0); - AE_MULAAP24S_HH_LL(dq_acc[3], dp_mat1_3, dp_vec1_0); - } - if (cols1 & 1) { - ae_p24x2s dp_mat1_01, dp_mat1_23; - dp_mat1_01 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_0[c_itr], - (UWORD8)p_mat1_1[c_itr]); - dp_mat1_23 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_2[c_itr], - (UWORD8)p_mat1_3[c_itr]); - dp_vec1_0 = AE_MOVPA24(p_vec1_0[c_itr]); - dp_mat1_01 = AE_SLLIP24(dp_mat1_01, 8); - dp_mat1_23 = AE_SLLIP24(dp_mat1_23, 8); - dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb); - - dp_mat1_01 = AE_SRAIP24(dp_mat1_01, 16); - dp_mat1_01 = AE_ADDSP24S(dp_mat1_01, dp_mat1_zb); - dp_mat1_23 = AE_SRAIP24(dp_mat1_23, 16); - dp_mat1_23 = AE_ADDSP24S(dp_mat1_23, dp_mat1_zb); - - AE_MULAP24S_HH(dq_acc[0], dp_mat1_01, dp_vec1_0); - AE_MULAP24S_LL(dq_acc[1], dp_mat1_01, dp_vec1_0); - AE_MULAP24S_HH(dq_acc[2], dp_mat1_23, dp_vec1_0); - AE_MULAP24S_LL(dq_acc[3], dp_mat1_23, dp_vec1_0); - } - - dq_acc[0] = AE_SLLISQ56S(dq_acc[0], 16); - dq_acc[1] = AE_SLLISQ56S(dq_acc[1], 16); - dq_acc[2] = AE_SLLISQ56S(dq_acc[2], 16); - dq_acc[3] = AE_SLLISQ56S(dq_acc[3], 16); - - if (p_bias != NULL) { - for (i = 0; i < 4; i++) - dq_acc[i] = - AE_ADDSQ56S(dq_acc[i], *(ae_q32s *)(&p_bias[m_itr + i])); - } - - for (i = 0; i < 4; i++) { - dq_out32 = AE_SATQ48S(dq_acc[i]); - MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32), - out_multiplier, left_shift, - right_shift); - 
ADD_OUT_OFFSET_STORE_INT8(&p_out[(m_itr + i) * out_stride], dq_out, - out_zero_bias); - } - } - } - for (; m_itr < rows; m_itr++) { - p_mat1_0 = &p_mat1[m_itr * row_stride1]; - p_vec1_0 = p_vec1; - - dq_acc[0] = AE_ZEROQ56(); - - for (c_itr = 0; c_itr < (cols1 - 1); c_itr += 2) { - dp_mat1_0 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_0[c_itr], - (UWORD8)p_mat1_0[c_itr + 1]); - dp_vec1_0 = AE_CVTP24A16X2_LL((UWORD8)p_vec1_0[c_itr], - (UWORD8)p_vec1_0[c_itr + 1]); - dp_mat1_0 = AE_SLLIP24(dp_mat1_0, 8); - dp_vec1_0 = AE_SLLIP24(dp_vec1_0, 8); - dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16); - dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb); - - dp_mat1_0 = AE_SRAIP24(dp_mat1_0, 16); - dp_mat1_0 = AE_ADDSP24S(dp_mat1_0, dp_mat1_zb); - - AE_MULAAP24S_HH_LL(dq_acc[0], dp_mat1_0, dp_vec1_0); - } - if (cols1 & 1) { - dp_mat1_0 = AE_CVTP24A16(p_mat1_0[c_itr]); - dp_vec1_0 = AE_CVTP24A16(p_vec1_0[c_itr]); - dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, AE_CVTP24A16(vec1_zero_bias)); - - dp_mat1_0 = AE_SRAIP24(dp_mat1_0, 16); - dp_mat1_0 = AE_ADDSP24S(dp_mat1_0, dp_mat1_zb); - - AE_MULAP24S_LL(dq_acc[0], dp_mat1_0, dp_vec1_0); - } - - dq_acc[0] = AE_SLLISQ56S(dq_acc[0], 16); - - if (p_bias != NULL) - dq_acc[0] = AE_ADDSQ56S(dq_acc[0], *(ae_q32s *)(&p_bias[m_itr])); - - dq_out32 = AE_SATQ48S(dq_acc[0]); - MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32), - out_multiplier, left_shift, right_shift); - ADD_OUT_OFFSET_STORE_INT8(&p_out[m_itr * out_stride], dq_out, - out_zero_bias); - } -#else - return 1; -#endif - } - - return 0; -} - -#define STORE_INT16(ptr, data) \ - { \ - int out_i32 = AE_TRUNCA32Q48(AE_SATQ48S(data)); \ - out_i32 = out_i32 < (int)0xffff8000L ? (int)0xffff8000L : out_i32; \ - out_i32 = out_i32 > (int)0x7fff ? (int)0x7fff : out_i32; \ - *(ptr) = (WORD16)out_i32; \ - } - -WORD32 xa_nn_matXvec_out_stride_sym8sxasym8s_16( - WORD16 *__restrict__ p_out, const WORD8 *__restrict__ p_mat1, - const WORD8 *__restrict__ p_vec1, const WORD32 *__restrict__ p_bias, - WORD32 rows, WORD32 cols1, WORD32 row_stride1, WORD32 out_stride, - WORD32 vec1_zero_bias, WORD32 out_multiplier, WORD32 out_shift) { - /* NULL pointer checks */ - XA_NNLIB_ARG_CHK_PTR(p_out, -1); - XA_NNLIB_ARG_CHK_PTR(p_mat1, -1); - XA_NNLIB_ARG_CHK_PTR(p_vec1, -1); - /* Pointer alignment checks */ - XA_NNLIB_ARG_CHK_ALIGN(p_out, sizeof(WORD16), -1); - XA_NNLIB_ARG_CHK_ALIGN(p_bias, sizeof(WORD32), -1); - /* Basic Parameter checks */ - XA_NNLIB_ARG_CHK_COND((rows <= 0), -1); - XA_NNLIB_ARG_CHK_COND((cols1 <= 0), -1); - XA_NNLIB_ARG_CHK_COND((row_stride1 < cols1), -1); - XA_NNLIB_ARG_CHK_COND((vec1_zero_bias < -127 || vec1_zero_bias > 128), -1); - XA_NNLIB_ARG_CHK_COND((out_shift < -31 || out_shift > 31), -1); - - /* Iterators used in for loops */ - int m_itr, c_itr, i; - /* Assign initial value so this value will be used in trailing loop */ - m_itr = 0; - /* Shifts to match with Tensorflow */ - int left_shift, right_shift; - - left_shift = out_shift < 0 ? 0 : out_shift; - right_shift = out_shift > 0 ? 
0 : -out_shift; - - const WORD8 *p_mat1_0, *p_mat1_1, *p_mat1_2, *p_mat1_3; - const WORD8 *p_vec1_0; - ae_p24x2s dp_mat1_0, dp_mat1_1, dp_mat1_2, dp_mat1_3, dp_vec1_0; - ae_p24x2s dp_vec1_zb; - ae_q56s dq_acc[4]; - ae_q56s dq_out32, dq_out; - - dp_vec1_zb = AE_MOVPA24(vec1_zero_bias); - if (((((unsigned)p_mat1) & 1) == 0) && ((((unsigned)p_vec1) & 1) == 0) && - ((row_stride1 & 1) == 0)) { - for (m_itr = 0; m_itr < (rows - 3); m_itr += 4) { - p_mat1_0 = &p_mat1[(m_itr + 0) * row_stride1 - 2]; - p_mat1_1 = &p_mat1[(m_itr + 1) * row_stride1 - 2]; - p_mat1_2 = &p_mat1[(m_itr + 2) * row_stride1 - 2]; - p_mat1_3 = &p_mat1[(m_itr + 3) * row_stride1 - 2]; - p_vec1_0 = p_vec1 - 2; - - dq_acc[0] = dq_acc[1] = dq_acc[2] = dq_acc[3] = AE_ZEROQ56(); - - /* AE_LP8X2F* instruction loads in upper 8 bits of P register, so shifting - vector right by 16 to get multiplication result in middle 32 bits of Q - register (lower 16 bits 0) */ - for (c_itr = 0; c_itr < (cols1 - 1); c_itr += 2) { - AE_LP8X2F_IU(dp_mat1_0, p_mat1_0, 2); - AE_LP8X2F_IU(dp_mat1_1, p_mat1_1, 2); - AE_LP8X2F_IU(dp_mat1_2, p_mat1_2, 2); - AE_LP8X2F_IU(dp_mat1_3, p_mat1_3, 2); - AE_LP8X2F_IU(dp_vec1_0, p_vec1_0, 2); - dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16); - dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb); - AE_MULAAP24S_HH_LL(dq_acc[0], dp_mat1_0, dp_vec1_0); - AE_MULAAP24S_HH_LL(dq_acc[1], dp_mat1_1, dp_vec1_0); - AE_MULAAP24S_HH_LL(dq_acc[2], dp_mat1_2, dp_vec1_0); - AE_MULAAP24S_HH_LL(dq_acc[3], dp_mat1_3, dp_vec1_0); - } - /* Pointers are aligned so can do 8X2 loads and ignore L parts of - * registers */ - if (cols1 & 1) { - AE_LP8X2F_IU(dp_mat1_0, p_mat1_0, 2); - AE_LP8X2F_IU(dp_mat1_1, p_mat1_1, 2); - AE_LP8X2F_IU(dp_mat1_2, p_mat1_2, 2); - AE_LP8X2F_IU(dp_mat1_3, p_mat1_3, 2); - AE_LP8X2F_IU(dp_vec1_0, p_vec1_0, 2); - dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16); - dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb); - AE_MULAP24S_HH(dq_acc[0], dp_mat1_0, dp_vec1_0); - AE_MULAP24S_HH(dq_acc[1], dp_mat1_1, dp_vec1_0); - AE_MULAP24S_HH(dq_acc[2], dp_mat1_2, dp_vec1_0); - AE_MULAP24S_HH(dq_acc[3], dp_mat1_3, dp_vec1_0); - } - - if (p_bias != NULL) { - for (i = 0; i < 4; i++) - dq_acc[i] = AE_ADDSQ56S(dq_acc[i], *(ae_q32s *)(&p_bias[m_itr + i])); - } - - for (i = 0; i < 4; i++) { - dq_out32 = AE_SATQ48S(dq_acc[i]); - MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32), - out_multiplier, left_shift, - right_shift); - STORE_INT16(&p_out[(m_itr + i) * out_stride], dq_out); - } - } - for (; m_itr < rows; m_itr++) { - p_mat1_0 = &p_mat1[m_itr * row_stride1 - 2]; - p_vec1_0 = p_vec1 - 2; - - dq_acc[0] = AE_ZEROQ56(); - - for (c_itr = 0; c_itr < (cols1 - 1); c_itr += 2) { - AE_LP8X2F_IU(dp_mat1_0, p_mat1_0, 2); - AE_LP8X2F_IU(dp_vec1_0, p_vec1_0, 2); - dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16); - dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb); - AE_MULAAP24S_HH_LL(dq_acc[0], dp_mat1_0, dp_vec1_0); - } - /* Pointers are aligned so can do 8X2 loads and ignore L parts of - * registers */ - if (cols1 & 1) { - AE_LP8X2F_IU(dp_mat1_0, p_mat1_0, 2); - AE_LP8X2F_IU(dp_vec1_0, p_vec1_0, 2); - dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16); - dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb); - AE_MULAP24S_HH(dq_acc[0], dp_mat1_0, dp_vec1_0); - } - - if (p_bias != NULL) - dq_acc[0] = AE_ADDSQ56S(dq_acc[0], *(ae_q32s *)(&p_bias[m_itr])); - - dq_out32 = AE_SATQ48S(dq_acc[0]); - MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32), - out_multiplier, left_shift, right_shift); - STORE_INT16(&p_out[m_itr * out_stride], dq_out); - } - } else { -#ifndef 
DISABLE_NNLIB_UNALIGNED_SUPPORT - if ((((unsigned)p_mat1) & 1) == 0) { - for (m_itr = 0; m_itr < (rows - 3); m_itr += 4) { - p_mat1_0 = &p_mat1[(m_itr + 0) * row_stride1 - 2]; - p_mat1_1 = &p_mat1[(m_itr + 1) * row_stride1]; - p_mat1_2 = &p_mat1[(m_itr + 2) * row_stride1 - 2]; - p_mat1_3 = &p_mat1[(m_itr + 3) * row_stride1]; - p_vec1_0 = p_vec1; - - dq_acc[0] = dq_acc[1] = dq_acc[2] = dq_acc[3] = AE_ZEROQ56(); - - /* Matrix elements are kept in upper 8 bits of P registers, vector - elements are kept in lower 8 bits of P registers, typecasting to UWORD8 - is to avoid extra extui instructions since signed 8-bit load in not - there in HiFiMini */ - for (c_itr = 0; c_itr < (cols1 - 1); c_itr += 2) { - AE_LP8X2F_IU(dp_mat1_0, p_mat1_0, 2); - dp_mat1_1 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_1[c_itr], - (UWORD8)p_mat1_1[c_itr + 1]); - AE_LP8X2F_IU(dp_mat1_2, p_mat1_2, 2); - dp_mat1_3 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_3[c_itr], - (UWORD8)p_mat1_3[c_itr + 1]); - dp_vec1_0 = AE_CVTP24A16X2_LL((UWORD8)p_vec1_0[c_itr], - (UWORD8)p_vec1_0[c_itr + 1]); - dp_mat1_1 = AE_SLLIP24(dp_mat1_1, 8); - dp_mat1_3 = AE_SLLIP24(dp_mat1_3, 8); - dp_vec1_0 = AE_SLLIP24(dp_vec1_0, 8); - dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16); - dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb); - AE_MULAAP24S_HH_LL(dq_acc[0], dp_mat1_0, dp_vec1_0); - AE_MULAAP24S_HH_LL(dq_acc[1], dp_mat1_1, dp_vec1_0); - AE_MULAAP24S_HH_LL(dq_acc[2], dp_mat1_2, dp_vec1_0); - AE_MULAAP24S_HH_LL(dq_acc[3], dp_mat1_3, dp_vec1_0); - } - if (cols1 & 1) { - ae_p24x2s dp_mat1_01, dp_mat1_23; - dp_mat1_01 = - AE_CVTP24A16X2_LL((UWORD8)p_mat1_0[2], (UWORD8)p_mat1_1[c_itr]); - dp_mat1_23 = - AE_CVTP24A16X2_LL((UWORD8)p_mat1_2[2], (UWORD8)p_mat1_3[c_itr]); - dp_vec1_0 = AE_MOVPA24(p_vec1_0[c_itr]); - dp_mat1_01 = AE_SLLIP24(dp_mat1_01, 8); - dp_mat1_23 = AE_SLLIP24(dp_mat1_23, 8); - dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb); - AE_MULAP24S_HH(dq_acc[0], dp_mat1_01, dp_vec1_0); - AE_MULAP24S_LL(dq_acc[1], dp_mat1_01, dp_vec1_0); - AE_MULAP24S_HH(dq_acc[2], dp_mat1_23, dp_vec1_0); - AE_MULAP24S_LL(dq_acc[3], dp_mat1_23, dp_vec1_0); - } - - if (p_bias != NULL) { - for (i = 0; i < 4; i++) - dq_acc[i] = - AE_ADDSQ56S(dq_acc[i], *(ae_q32s *)(&p_bias[m_itr + i])); - } - - for (i = 0; i < 4; i++) { - dq_out32 = AE_SATQ48S(dq_acc[i]); - MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32), - out_multiplier, left_shift, - right_shift); - STORE_INT16(&p_out[(m_itr + i) * out_stride], dq_out); - } - } - } else { - for (m_itr = 0; m_itr < (rows - 3); m_itr += 4) { - p_mat1_0 = &p_mat1[(m_itr + 0) * row_stride1]; - p_mat1_1 = &p_mat1[(m_itr + 1) * row_stride1]; - p_mat1_2 = &p_mat1[(m_itr + 2) * row_stride1]; - p_mat1_3 = &p_mat1[(m_itr + 3) * row_stride1]; - p_vec1_0 = p_vec1; - - dq_acc[0] = dq_acc[1] = dq_acc[2] = dq_acc[3] = AE_ZEROQ56(); - - /* Matrix elements are kept in upper 8 bits of P registers, vector - elements are kept in lower 8 bits of P registers, typecasting to UWORD8 - is to avoid extra extui instructions since signed 8-bit load in not - there in HiFiMini */ - for (c_itr = 0; c_itr < (cols1 - 1); c_itr += 2) { - dp_mat1_0 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_0[c_itr], - (UWORD8)p_mat1_0[c_itr + 1]); - dp_mat1_1 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_1[c_itr], - (UWORD8)p_mat1_1[c_itr + 1]); - dp_mat1_2 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_2[c_itr], - (UWORD8)p_mat1_2[c_itr + 1]); - dp_mat1_3 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_3[c_itr], - (UWORD8)p_mat1_3[c_itr + 1]); - dp_vec1_0 = AE_CVTP24A16X2_LL((UWORD8)p_vec1_0[c_itr], - (UWORD8)p_vec1_0[c_itr + 1]); - 
dp_mat1_0 = AE_SLLIP24(dp_mat1_0, 8); - dp_mat1_1 = AE_SLLIP24(dp_mat1_1, 8); - dp_mat1_2 = AE_SLLIP24(dp_mat1_2, 8); - dp_mat1_3 = AE_SLLIP24(dp_mat1_3, 8); - dp_vec1_0 = AE_SLLIP24(dp_vec1_0, 8); - dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16); - dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb); - AE_MULAAP24S_HH_LL(dq_acc[0], dp_mat1_0, dp_vec1_0); - AE_MULAAP24S_HH_LL(dq_acc[1], dp_mat1_1, dp_vec1_0); - AE_MULAAP24S_HH_LL(dq_acc[2], dp_mat1_2, dp_vec1_0); - AE_MULAAP24S_HH_LL(dq_acc[3], dp_mat1_3, dp_vec1_0); - } - if (cols1 & 1) { - ae_p24x2s dp_mat1_01, dp_mat1_23; - dp_mat1_01 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_0[c_itr], - (UWORD8)p_mat1_1[c_itr]); - dp_mat1_23 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_2[c_itr], - (UWORD8)p_mat1_3[c_itr]); - dp_vec1_0 = AE_MOVPA24(p_vec1_0[c_itr]); - dp_mat1_01 = AE_SLLIP24(dp_mat1_01, 8); - dp_mat1_23 = AE_SLLIP24(dp_mat1_23, 8); - dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb); - AE_MULAP24S_HH(dq_acc[0], dp_mat1_01, dp_vec1_0); - AE_MULAP24S_LL(dq_acc[1], dp_mat1_01, dp_vec1_0); - AE_MULAP24S_HH(dq_acc[2], dp_mat1_23, dp_vec1_0); - AE_MULAP24S_LL(dq_acc[3], dp_mat1_23, dp_vec1_0); - } - - if (p_bias != NULL) { - for (i = 0; i < 4; i++) - dq_acc[i] = - AE_ADDSQ56S(dq_acc[i], *(ae_q32s *)(&p_bias[m_itr + i])); - } - - for (i = 0; i < 4; i++) { - dq_out32 = AE_SATQ48S(dq_acc[i]); - MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32), - out_multiplier, left_shift, - right_shift); - STORE_INT16(&p_out[(m_itr + i) * out_stride], dq_out); - } - } - } - for (; m_itr < rows; m_itr++) { - p_mat1_0 = &p_mat1[m_itr * row_stride1]; - p_vec1_0 = p_vec1; - - dq_acc[0] = AE_ZEROQ56(); - - for (c_itr = 0; c_itr < (cols1 - 1); c_itr += 2) { - dp_mat1_0 = AE_CVTP24A16X2_LL((UWORD8)p_mat1_0[c_itr], - (UWORD8)p_mat1_0[c_itr + 1]); - dp_vec1_0 = AE_CVTP24A16X2_LL((UWORD8)p_vec1_0[c_itr], - (UWORD8)p_vec1_0[c_itr + 1]); - dp_mat1_0 = AE_SLLIP24(dp_mat1_0, 8); - dp_vec1_0 = AE_SLLIP24(dp_vec1_0, 8); - dp_vec1_0 = AE_SRAIP24(dp_vec1_0, 16); - dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, dp_vec1_zb); - AE_MULAAP24S_HH_LL(dq_acc[0], dp_mat1_0, dp_vec1_0); - } - if (cols1 & 1) { - dp_mat1_0 = AE_CVTP24A16(p_mat1_0[c_itr]); - dp_vec1_0 = AE_CVTP24A16(p_vec1_0[c_itr]); - dp_vec1_0 = AE_ADDSP24S(dp_vec1_0, AE_CVTP24A16(vec1_zero_bias)); - AE_MULAP24S_LL(dq_acc[0], dp_mat1_0, dp_vec1_0); - } - - if (p_bias != NULL) - dq_acc[0] = AE_ADDSQ56S(dq_acc[0], *(ae_q32s *)(&p_bias[m_itr])); - - dq_out32 = AE_SATQ48S(dq_acc[0]); - MULTIPLY_BY_QUANTIZED_MULTIPLIER(dq_out, AE_TRUNCA32Q48(dq_out32), - out_multiplier, left_shift, right_shift); - STORE_INT16(&p_out[m_itr * out_stride], dq_out); - } -#else - return 1; -#endif - } - - return 0; -} diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/nnlib/xa_nnlib_api.h b/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/nnlib/xa_nnlib_api.h deleted file mode 100644 index e499e1eb980..00000000000 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/nnlib/xa_nnlib_api.h +++ /dev/null @@ -1,43 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2019-2020 Cadence Design Systems, Inc. 
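For orientation only (this sketch is not part of the patch): the three deleted matXvec kernels above all compute the same row-major matrix-times-vector product with zero-point handling followed by TFLite-style requantization; they differ only in which operand carries a zero point and in the output width and clamp. Below is a minimal, unoptimized C reference of that contract. The helper names are invented for illustration, the rounding helper only approximates the saturating MULTIPLY_BY_QUANTIZED_MULTIPLIER macro, and the other two variants follow by setting mat_zb to 0 (symmetric int8 weights) or by clamping to int16 without an output zero point.

```c
#include <stdint.h>

/* Approximation of MULTIPLY_BY_QUANTIZED_MULTIPLIER: apply the positive part
 * of out_shift as a left shift, do a rounding high multiply by the Q31
 * multiplier, then apply the negative part as a rounding right shift.  The
 * left/right split mirrors the kernels above; saturation is omitted here. */
int32_t ref_requant(int64_t acc, int32_t out_multiplier, int32_t out_shift) {
  const int left_shift = out_shift < 0 ? 0 : out_shift;
  const int right_shift = out_shift > 0 ? 0 : -out_shift;
  acc = acc * ((int64_t)1 << left_shift);
  int64_t v = (acc * out_multiplier + ((int64_t)1 << 30)) >> 31;
  if (right_shift > 0) v = (v + ((int64_t)1 << (right_shift - 1))) >> right_shift;
  return (int32_t)v;
}

/* Reference contract of the asym8s x asym8s -> asym8s variant, as read from
 * the code above.  The optimized kernel folds mat_zb * sum_c(vec[c] + vec_zb)
 * into a per-call constant (its "zero_offset_adjustment") so the inner loop
 * never touches mat_zb; this is the unfactored equivalent. */
void ref_matxvec_asym8s(int8_t *out, const int8_t *mat, const int8_t *vec,
                        const int32_t *bias, int rows, int cols,
                        int row_stride, int out_stride, int32_t mat_zb,
                        int32_t vec_zb, int32_t out_multiplier,
                        int32_t out_shift, int32_t out_zb) {
  for (int m = 0; m < rows; ++m) {
    int64_t acc = bias ? bias[m] : 0;
    for (int c = 0; c < cols; ++c) {
      acc += (int64_t)(mat[m * row_stride + c] + mat_zb) * (vec[c] + vec_zb);
    }
    int32_t y = ref_requant(acc, out_multiplier, out_shift) + out_zb;
    if (y < -128) y = -128;  /* ADD_OUT_OFFSET_STORE_INT8 clamps to int8 */
    if (y > 127) y = 127;
    out[m * out_stride] = (int8_t)y;
  }
}
```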
- * - * Permission is hereby granted, free of charge, to any person obtaining - * a copy of this software and associated documentation files (the - * "Software"), to use this Software with Cadence processor cores only and - * not with any other processors and platforms, subject to - * the following conditions: - * - * The above copyright notice and this permission notice shall be included - * in all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - ******************************************************************************/ - -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef __XA_NNLIB_API_H__ -#define __XA_NNLIB_API_H__ - -#include "tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/nnlib/xa_nnlib_kernels_api.h" -#include "tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/xa_type_def.h" - -#endif /* __XA_NNLIB_API_H__ */ diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/nnlib/xa_nnlib_kernels_api.h b/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/nnlib/xa_nnlib_kernels_api.h deleted file mode 100644 index d3a5e2990c0..00000000000 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/nnlib/xa_nnlib_kernels_api.h +++ /dev/null @@ -1,300 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2019-2020 Cadence Design Systems, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining - * a copy of this software and associated documentation files (the - * "Software"), to use this Software with Cadence processor cores only and - * not with any other processors and platforms, subject to - * the following conditions: - * - * The above copyright notice and this permission notice shall be included - * in all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
- ******************************************************************************/ - -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef __XA_NNLIB_KERNELS_API_H__ -#define __XA_NNLIB_KERNELS_API_H__ - -/** - * @file xa_nnlib_kernels_api.h - * @brief This file gives the API definition for the HiFi NNLIB - * - * matXvec KERNELS API NAMING CONVENTION <br> - * <br> - * xa_nn_matXvec_<batch>_[m]x[n]_[p]_<activation>, where - * - <batch>: Optional 'batch' tag to indicate time batching routine - * - [m]: Matrix precision in bits - * - [n]: Vector (and bias for non-activation routines) precision in bits - * - [p]: Output precision in bits - * - <activation>: optional activation tag 'sigmoid' / 'tanh' - * - * These set of kernels perform dual matXvec followed by optional - * activation function. There are several variants based on the input, - * output precision and use of activation functions. - * - * Restriction, - * - All pointers (p_out, p_mat1, p_mat2, p_vec1, p_vec2, p_bias, p_scratch) - * must be SIMD (64-bit) aligned and should not overlap. - * - p_mat2, p_vec2 can be 'NULL', but other pointers cannot be 'NULL' - * - Variables cols1, cols2, row_stride1, row_stride2 must be multiple of 4 - * - * Usage of few critical variables, - * - acc_shift: - * -# In case of valid activation tag i.e. <activation>: shift to be - * applied on accumulator to match accumulator's Q format with activation - * function's input's Q format - * -# In case of bypass i.e. no activation tag: shift to be applied on - * accumulator. - * -# Positive value denotes left shift, and negative value denotes right - * shift. - * - bias_shift: shift which is to be applied on bias to match bias's - * Q format with accumulator's Q format. Positive value denotes left shift, - * and negative value denotes right shift. - * - bias_precision: This represents bias precision - * -# For 16x16, and 8x16 apis, valid values are '16' and '64' - * -# For 8x8 apis, valid values are '8' and '32' - * - * Output 8b, 16b, 32b of fixed point apis (only for bypass variants) is - * extracted from 64b accumulator with symmetric rounding. Output 64b of fixed - * point apis (only for bypass variants) is extracted from 64b accumulator. - * Output 8b, 16b of fixed point apis (only for activation variants) is - * symmetrically rounded. 
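As a concrete reading of the naming convention documented above (the decoded names here are illustrative, not quoted from the header): `xa_nn_matXvec_8x16_32` would be a bypass kernel taking an 8-bit matrix, a 16-bit vector and bias, and producing 32-bit output; `xa_nn_matXvec_16x16_16_tanh` would be the fused 16x16 variant with 16-bit output passed through tanh; and a `batch` tag ahead of the precisions would mark the (unimplemented) time-batching form.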
- * - * matXvec 16x16 Kernels, - * - Bypass kernels with 16, 32, 64 bit output: 3 - * - Fused kernel with 2 activation variants: 2 - * - Time batching kernel: 1 (Not implemented) - * - Total: 6 - * - * matXvec 8x16 Kernels, - * - Bypass kernels with 16, 32, 64 bit output: 3 - * - Fused kernel with 2 activation variants: 2 - * - Time batching kernel: 1 (Not implemented) - * - Total: 6 - * - * matXvec 8x8 Kernels, - * - Bypass kernels with 8, 16, 32 bit output: 3 - * - Fused kernel with 2 activation variants: 2 - * - Time batching kernel: 1 (Not implemented) - * - Total: 6 - * - * matXvec float32 x float32 Kernels, - * - Bypass kernels 32 bit output: 1 - * - Fused kernel with 2 activation variants: 2 - * - Time batching kernel: 1 (Not implemented) - * - Total: 4 - * - * ACTIVATION KERNELS API NAMING CONVENTION <br> - * <br> - * xa_nn_vec_[activation]_[n]_[p] for fixed point <br> - * xa_nn_vec_[activation]_f32_f32 for floating point, where - * - [activation]: One of activations - sigmoid/tanh/relu/relu1/relu6/softmax - * - [n]: Input precision in bits - * - [p]: Output precision in bits - * - * Possible values, - * - 'n' takes value '32', and expects input in Q6.25 format. - * - 'p' takes values '32' and '16', gives output in Q16.15 and Q0.15 formats - * respectively. - * - * There is WORD32 datatype variable 'threshold' for 'relu' related apis, which - * expects value in Q16.15 format. - * - * Restriction, - * - All pointers (p_out, p_vec) must be 32-bit aligned and should not overlap. - * - * activation 32_32 kernels, - * - Vector activation kernels: 6 - * - Total: 6 - * - * activation f32_f32 kernels, - * - Vector activation kernels: 6 - * - Total: 6 - * - * activation 32_16 kernels, - * - Vector activation kernels: 2 - * - Total: 2 - */ - -#include "tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/xa_type_def.h" - -#if defined(__cplusplus) -extern "C" { -#endif - -WORD32 xa_nn_conv2d_depthwise_getsize( - WORD32 input_height, WORD32 input_width, WORD32 input_channels, - WORD32 kernel_height, WORD32 kernel_width, WORD32 channels_multiplier, - WORD32 x_stride, WORD32 y_stride, WORD32 x_padding, WORD32 y_padding, - WORD32 output_height, WORD32 output_width, WORD32 circ_buf_precision, - WORD32 inp_data_format); - -WORD32 xa_nn_vec_activation_min_max_asym8u_asym8u( - UWORD8 *__restrict__ p_out, const UWORD8 *__restrict__ p_vec, - int activation_min, int activation_max, WORD32 vec_length); - -WORD32 xa_nn_vec_activation_min_max_asym8s_asym8s( - WORD8 *__restrict__ p_out, const WORD8 *__restrict__ p_vec, - int activation_min, int activation_max, WORD32 vec_length); - -WORD32 xa_nn_conv2d_std_getsize(WORD32 input_height, WORD32 input_channels, - WORD32 kernel_height, WORD32 kernel_width, - WORD32 y_stride, WORD32 y_padding, - WORD32 out_height, WORD32 input_precision); - -WORD32 xa_nn_conv2d_std_asym8uxasym8u( - UWORD8 *__restrict__ p_out, const UWORD8 *__restrict__ p_inp, - const UWORD8 *__restrict__ p_kernel, const WORD32 *__restrict__ p_bias, - WORD32 input_height, WORD32 input_width, WORD32 input_channels, - WORD32 kernel_height, WORD32 kernel_width, WORD32 out_channels, - WORD32 x_stride, WORD32 y_stride, WORD32 x_padding, WORD32 y_padding, - WORD32 out_height, WORD32 out_width, WORD32 input_zero_bias, - WORD32 kernel_zero_bias, WORD32 out_multiplier, WORD32 out_shift, - WORD32 out_zero_bias, WORD32 out_data_format, VOID *p_scratch); - -WORD32 xa_nn_conv2d_std_per_chan_sym8sxasym8s( - WORD8 *__restrict__ p_out, const WORD8 *__restrict__ p_inp, - const WORD8 *__restrict__ 
p_kernel, const WORD32 *__restrict__ p_bias, - WORD32 input_height, WORD32 input_width, WORD32 input_channels, - WORD32 kernel_height, WORD32 kernel_width, WORD32 out_channels, - WORD32 x_stride, WORD32 y_stride, WORD32 x_padding, WORD32 y_padding, - WORD32 out_height, WORD32 out_width, WORD32 input_zero_bias, - WORD32 *p_out_multiplier, WORD32 *p_out_shift, WORD32 out_zero_bias, - WORD32 out_data_format, VOID *p_scratch); - -WORD32 xa_nn_conv2d_depthwise_asym8uxasym8u( - pUWORD8 __restrict__ p_out, const UWORD8 *__restrict__ p_kernel, - const UWORD8 *__restrict__ p_inp, const WORD32 *__restrict__ p_bias, - WORD32 input_height, WORD32 input_width, WORD32 input_channels, - WORD32 kernel_height, WORD32 kernel_width, WORD32 channels_multiplier, - WORD32 x_stride, WORD32 y_stride, WORD32 x_padding, WORD32 y_padding, - WORD32 out_height, WORD32 out_width, WORD32 input_zero_bias, - WORD32 kernel_zero_bias, WORD32 out_multiplier, WORD32 out_shift, - WORD32 out_zero_bias, WORD32 inp_data_format, WORD32 out_data_format, - pVOID p_scratch); - -WORD32 xa_nn_conv2d_depthwise_per_chan_sym8sxasym8s( - WORD8 *__restrict__ p_out, const WORD8 *__restrict__ p_kernel, - const WORD8 *__restrict__ p_inp, const WORD32 *__restrict__ p_bias, - WORD32 input_height, WORD32 input_width, WORD32 input_channels, - WORD32 kernel_height, WORD32 kernel_width, WORD32 channels_multiplier, - WORD32 x_stride, WORD32 y_stride, WORD32 x_padding, WORD32 y_padding, - WORD32 out_height, WORD32 out_width, WORD32 input_zero_bias, - const WORD32 *p_out_multiplier, const WORD32 *p_out_shift, - WORD32 out_zero_bias, WORD32 inp_data_format, WORD32 out_data_format, - pVOID p_scratch); - -WORD32 xa_nn_fully_connected_asym8uxasym8u_asym8u( - pUWORD8 __restrict__ p_out, const UWORD8 *__restrict__ p_weight, - const UWORD8 *__restrict__ p_inp, const WORD32 *__restrict__ p_bias, - WORD32 weight_depth, WORD32 out_depth, WORD32 input_zero_bias, - WORD32 weight_zero_bias, WORD32 out_multiplier, WORD32 out_shift, - WORD32 out_zero_bias); - -WORD32 xa_nn_fully_connected_sym8sxasym8s_asym8s( - pWORD8 __restrict__ p_out, const WORD8 *__restrict__ p_weight, - const WORD8 *__restrict__ p_inp, const WORD32 *__restrict__ p_bias, - WORD32 weight_depth, WORD32 out_depth, WORD32 input_zero_bias, - WORD32 out_multiplier, WORD32 out_shift, WORD32 out_zero_bias); - -WORD32 xa_nn_fully_connected_asym8sxasym8s_asym8s( - WORD8 *__restrict__ p_out, const WORD8 *__restrict__ p_weight, - const WORD8 *__restrict__ p_inp, const WORD32 *__restrict__ p_bias, - WORD32 weight_depth, WORD32 out_depth, WORD32 weight_zero_bias, - WORD32 input_zero_bias, WORD32 out_multiplier, WORD32 out_shift, - WORD32 out_zero_bias); - -WORD32 xa_nn_vec_softmax_asym8u_8(UWORD8 *__restrict__ p_out, - const UWORD8 *__restrict__ p_vec, - WORD32 diffmin, WORD32 input_left_shift, - WORD32 input_multiplier, WORD32 vec_length, - pVOID p_scratch); - -WORD32 xa_nn_vec_softmax_asym8s_16(WORD16 *__restrict__ p_out, - const WORD8 *__restrict__ p_vec, - WORD32 diffmin, WORD32 input_left_shift, - WORD32 input_multiplier, WORD32 vec_length, - pVOID p_scratch); - -WORD32 xa_nn_vec_softmax_asym8s_8(WORD8 *__restrict__ p_out, - const WORD8 *__restrict__ p_vec, - WORD32 diffmin, WORD32 input_left_shift, - WORD32 input_multiplier, WORD32 vec_length, - pVOID p_scratch); - -int xa_nn_get_softmax_scratch_size(int inp_precision, int out_precision, - int length); - -WORD32 xa_nn_matXvec_out_stride_asym8uxasym8u_asym8u( - UWORD8 *__restrict__ p_out, const UWORD8 *__restrict__ p_mat1, - const UWORD8 *__restrict__ 
p_vec1, const WORD32 *__restrict__ p_bias, - WORD32 rows, WORD32 cols1, WORD32 row_stride1, WORD32 out_stride, - WORD32 mat1_zero_bias, WORD32 vec1_zero_bias, WORD32 out_multiplier, - WORD32 out_shift, WORD32 out_zero_bias); - -WORD32 xa_nn_matXvec_out_stride_sym8sxasym8s_asym8s( - WORD8 *__restrict__ p_out, const WORD8 *__restrict__ p_mat1, - const WORD8 *__restrict__ p_vec1, const WORD32 *__restrict__ p_bias, - WORD32 rows, WORD32 cols1, WORD32 row_stride1, WORD32 out_stride, - WORD32 vec1_zero_bias, WORD32 out_multiplier, WORD32 out_shift, - WORD32 out_zero_bias); - -WORD32 xa_nn_matXvec_out_stride_asym8sxasym8s_asym8s( - WORD8 *__restrict__ p_out, const WORD8 *__restrict__ p_mat1, - const WORD8 *__restrict__ p_vec1, const WORD32 *__restrict__ p_bias, - WORD32 rows, WORD32 cols1, WORD32 row_stride1, WORD32 out_stride, - WORD32 mat1_zero_bias, WORD32 vec1_zero_bias, WORD32 out_multiplier, - WORD32 out_shift, WORD32 out_zero_bias); - -WORD32 xa_nn_matXvec_out_stride_sym8sxasym8s_16( - WORD16 *__restrict__ p_out, const WORD8 *__restrict__ p_mat1, - const WORD8 *__restrict__ p_vec1, const WORD32 *__restrict__ p_bias, - WORD32 rows, WORD32 cols1, WORD32 row_stride1, WORD32 out_stride, - WORD32 vec1_zero_bias, WORD32 out_multiplier, WORD32 out_shift); - -WORD32 xa_nn_dot_prod_16x16_asym8s( - WORD8 *__restrict__ p_out, /* pointer to output */ - const WORD16 *__restrict__ p_inp1_start, /* pointer to input1 */ - const WORD16 *__restrict__ p_inp2_start, /* pointer to input2 */ - const WORD32 *bias_ptr, WORD32 vec_length, WORD32 out_multiplier, - WORD32 out_shift, WORD32 out_zero_bias, WORD32 vec_count); - -/* Mapping the functions names from previous naming convension for backward - * compatibility */ -#define xa_nn_vec_activation_min_max_asym8_asym8 \ - xa_nn_vec_activation_min_max_asym8u_asym8u -#define xa_nn_conv2d_std_asym8xasym8 xa_nn_conv2d_std_asym8uxasym8u -#define xa_nn_conv2d_depthwise_asym8xasym8 xa_nn_conv2d_depthwise_asym8uxasym8u -#define xa_nn_fully_connected_asym8xasym8_asym8 \ - xa_nn_fully_connected_asym8uxasym8u_asym8u -#define xa_nn_vec_softmax_asym8_asym8 xa_nn_vec_softmax_asym8u_asym8u -#define xa_nn_dot_prod_asym8xasym8_asym8 xa_nn_dot_prod_asym8uxasym8u_asym8u -#define xa_nn_matXvec_out_stride_asym8xasym8_asym8 \ - xa_nn_matXvec_out_stride_asym8uxasym8u_asym8u - -#if defined(__cplusplus) -} -#endif -#endif /* __XA_NNLIB_KERNELS_API_H__ */ diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/nnlib/xa_nnlib_standards.h b/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/nnlib/xa_nnlib_standards.h deleted file mode 100644 index 36ea75d1e25..00000000000 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/nnlib/xa_nnlib_standards.h +++ /dev/null @@ -1,170 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2019-2020 Cadence Design Systems, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining - * a copy of this software and associated documentation files (the - * "Software"), to use this Software with Cadence processor cores only and - * not with any other processors and platforms, subject to - * the following conditions: - * - * The above copyright notice and this permission notice shall be included - * in all copies or substantial portions of the Software. 
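To make the calling convention of the out-stride kernels concrete, here is a hedged usage sketch of `xa_nn_matXvec_out_stride_sym8sxasym8s_asym8s` as declared above: a 4x8 fully-connected style call with contiguous rows, one int8 output per row, and made-up quantization parameters. The include path reflects the staging header removed by this patch, and none of the buffer contents are meaningful.

```c
#include "tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/nnlib/xa_nnlib_kernels_api.h"

/* Hypothetical call: symmetric int8 weights, asymmetric int8 activations,
 * int8 outputs written with a stride of 1 (i.e. densely). */
void example_matxvec_call(void) {
  static const WORD8 weights[4 * 8] = {0}; /* sym8s weights, row-major  */
  static const WORD8 input[8] = {0};       /* asym8s input activations  */
  static const WORD32 bias[4] = {0};       /* one 32-bit bias per row   */
  WORD8 output[4];

  WORD32 ret = xa_nn_matXvec_out_stride_sym8sxasym8s_asym8s(
      output, weights, input, bias,
      /*rows=*/4, /*cols1=*/8, /*row_stride1=*/8, /*out_stride=*/1,
      /*vec1_zero_bias=*/0, /*out_multiplier=*/0x40000000, /*out_shift=*/-1,
      /*out_zero_bias=*/0);
  (void)ret; /* 0 on success, -1 if the argument checks above reject a value */
}
```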
- * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - ******************************************************************************/ - -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef __STANDARDS_H__ -#define __STANDARDS_H__ - -#if defined(__cplusplus) -extern "C" { -#endif - -typedef double flt64; -typedef char Int4; -typedef char Int8; -typedef int16_t Int16; -typedef int Int32; -typedef int Int24; -typedef int64_t Int64; -typedef int Bool; -typedef float Flt32; - -#ifdef MODEL_FLT64 -typedef double vect_t; -typedef double coeff_t; -typedef double accu_t; - -#elif MODEL_INT16 -typedef int16_t vect_t; -typedef int16_t coeff_t; -typedef signed char coeff8_t; -typedef int64_t accu_t; -typedef float coefff32_t; -#endif - -typedef struct xa_nnlib_opaque { - Int32 _; -} * xa_nnlib_handle_t; - -typedef enum _xa_nnlib_prec_t { - PREC_8 = 8, - PREC_16 = 16, - PREC_32 = 32, - PREC_F32 = -1, - PREC_F16 = -2, - PREC_ASYM8U = -3, - PREC_ASYM8S = -4, - PREC_SYM8S = -5 -} xa_nnlib_prec_t; - -typedef enum _xa_nnlib_shape_type_t { - SHAPE_UNKNOWN_T = 0, - SHAPE_VECTOR_T = 1, - SHAPE_MATRIX_T = 2, - SHAPE_CUBE_DWH_T = 3, - SHAPE_CUBE_WHD_T = 4 -} xa_nnlib_shape_type_t; - -typedef struct _xa_nnlib_shape_t { - xa_nnlib_shape_type_t shape_type; - Int32 n_shapes; - Int32 shape_offset; // Offest between current shape and next shape - union { - struct { - Int32 height; - Int32 height_offset; - Int32 width; - Int32 width_offset; - Int32 depth; - Int32 depth_offset; - } cube; - - struct { - Int32 length; - } vector; - struct { - Int32 rows; - Int32 row_offset; // Offset between current row and next row - Int32 cols; - } matrix; - } dim; -} xa_nnlib_shape_t; - -/*****************************************************************************/ -/* Constant hash defines */ -/*****************************************************************************/ -#define XA_NNLIB_NO_ERROR 0 -/* error handling 'AND' definition */ -#define XA_FATAL_ERROR 0x80000000 - -enum xa_error_severity { - xa_severity_nonfatal = 0, - xa_severity_fatal = (int)0xffffffff -}; - -enum xa_error_class { - xa_class_nnlib = 0, - xa_class_config = 1, - xa_class_execute = 2 -}; - -#define XA_NNLIB_GENERIC 0 - -#define XA_ERROR_CODE(severity, class, codec, index) \ - ((severity << 31) | (class << 12) | (codec << 7) | index) -#define XA_ERROR_SEVERITY(code) (((code)&XA_FATAL_ERROR) != 0) -#define XA_ERROR_CLASS(code) (((code) >> 12) & 0x0f) -#define XA_ERROR_CODEC(code) (((code) >> 7) & 
0x1f) -#define XA_ERROR_SUBCODE(code) (((code) >> 0) & 0x3f) - -/* Our convention is that only nnlib-class errors can be generic ones. */ - -/*****************************************************************************/ -/* Class 0: NNLib Errors */ -/*****************************************************************************/ -/* Non Fatal Errors */ -/* (none) */ -/* Fatal Errors */ -enum xa_error_fatal_nnlib_generic { - XA_NNLIB_FATAL_MEM_ALLOC = - XA_ERROR_CODE(xa_severity_fatal, xa_class_nnlib, XA_NNLIB_GENERIC, 0), - XA_NNLIB_FATAL_MEM_ALIGN = - XA_ERROR_CODE(xa_severity_fatal, xa_class_nnlib, XA_NNLIB_GENERIC, 1), - XA_NNLIB_FATAL_INVALID_SHAPE = - XA_ERROR_CODE(xa_severity_fatal, xa_class_nnlib, XA_NNLIB_GENERIC, 3) -}; - -/*****************************************************************************/ -/* NNLib Startup Functions */ -/*****************************************************************************/ -const Int8* xa_nnlib_get_lib_name_string(void); -const Int8* xa_nnlib_get_lib_version_string(void); -const Int8* xa_nnlib_get_lib_api_version_string(void); - -#if defined(__cplusplus) -} -#endif - -#endif /* __STANDARDS_H__ */ diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/xa_type_def.h b/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/xa_type_def.h deleted file mode 100644 index 13a7469bbf7..00000000000 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/xa_type_def.h +++ /dev/null @@ -1,108 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2019-2020 Cadence Design Systems, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining - * a copy of this software and associated documentation files (the - * "Software"), to use this Software with Cadence processor cores only and - * not with any other processors and platforms, subject to - * the following conditions: - * - * The above copyright notice and this permission notice shall be included - * in all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - ******************************************************************************/ - -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
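For reference, the error-code macros in the removed `xa_nnlib_standards.h` pack severity, class, codec and sub-code into one 32-bit word (bit 31, bits 15..12, bits 11..7 and bits 5..0 respectively). A small stand-alone decoding example, assuming a 32-bit two's-complement `int` so that `XA_NNLIB_FATAL_MEM_ALIGN` comes out as `0x80000001`:

```c
#include <stdio.h>

int main(void) {
  /* XA_ERROR_CODE(xa_severity_fatal, xa_class_nnlib, XA_NNLIB_GENERIC, 1) */
  unsigned int code = 0x80000001u;
  printf("fatal=%u class=%u codec=%u subcode=%u\n",
         (code >> 31) & 0x1u,  /* XA_ERROR_SEVERITY: non-zero means fatal */
         (code >> 12) & 0x0fu, /* XA_ERROR_CLASS:   0 = xa_class_nnlib    */
         (code >> 7) & 0x1fu,  /* XA_ERROR_CODEC:   0 = XA_NNLIB_GENERIC  */
         code & 0x3fu);        /* XA_ERROR_SUBCODE: 1 = MEM_ALIGN         */
  return 0;
}
```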
-==============================================================================*/ - -#ifndef __XA_TYPE_DEF_H__ -#define __XA_TYPE_DEF_H__ - -#include <stdint.h> - -/****************************************************************************/ -/* types type define prefix examples bytes */ -/************************ *********** ****** **************** ***** */ -typedef signed char WORD8; /* b WORD8 b_name 1 */ -typedef signed char* pWORD8; /* pb pWORD8 pb_nmae 1 */ -typedef unsigned char UWORD8; /* ub UWORD8 ub_count 1 */ -typedef unsigned char* pUWORD8; /* pub pUWORD8 pub_count 1 */ - -typedef int16_t WORD16; /* s WORD16 s_count 2 */ -typedef int16_t* pWORD16; /* ps pWORD16 ps_count 2 */ -typedef uint16_t UWORD16; /* us UWORD16 us_count 2 */ -typedef uint16_t* pUWORD16; /* pus pUWORD16 pus_count 2 */ - -typedef signed int WORD24; /* k WORD24 k_count 3 */ -typedef signed int* pWORD24; /* pk pWORD24 pk_count 3 */ -typedef unsigned int UWORD24; /* uk UWORD24 uk_count 3 */ -typedef unsigned int* pUWORD24; /* puk pUWORD24 puk_count 3 */ - -typedef signed int WORD32; /* i WORD32 i_count 4 */ -typedef signed int* pWORD32; /* pi pWORD32 pi_count 4 */ -typedef unsigned int UWORD32; /* ui UWORD32 ui_count 4 */ -typedef unsigned int* pUWORD32; /* pui pUWORD32 pui_count 4 */ - -typedef int64_t WORD40; /* m WORD40 m_count 5 */ -typedef int64_t* pWORD40; /* pm pWORD40 pm_count 5 */ -typedef uint64_t UWORD40; /* um UWORD40 um_count 5 */ -typedef uint64_t* pUWORD40; /* pum pUWORD40 pum_count 5 */ - -typedef int64_t WORD64; /* h WORD64 h_count 8 */ -typedef int64_t* pWORD64; /* ph pWORD64 ph_count 8 */ -typedef uint64_t UWORD64; /* uh UWORD64 uh_count 8 */ -typedef uint64_t* pUWORD64; /* puh pUWORD64 puh_count 8 */ - -typedef float FLOAT32; /* f FLOAT32 f_count 4 */ -typedef float* pFLOAT32; /* pf pFLOAT32 pf_count 4 */ -typedef double FLOAT64; /* d UFLOAT64 d_count 8 */ -typedef double* pFlOAT64; /* pd pFLOAT64 pd_count 8 */ - -typedef void VOID; /* v VOID v_flag 4 */ -typedef void* pVOID; /* pv pVOID pv_flag 4 */ - -/* variable size types: platform optimized implementation */ -typedef signed int BOOL; /* bool BOOL bool_true */ -typedef unsigned int UBOOL; /* ubool BOOL ubool_true */ -typedef signed int FLAG; /* flag FLAG flag_false */ -typedef unsigned int UFLAG; /* uflag FLAG uflag_false */ -typedef signed int LOOPIDX; /* lp LOOPIDX lp_index */ -typedef unsigned int ULOOPIDX; /* ulp SLOOPIDX ulp_index */ -typedef signed int WORD; /* lp LOOPIDX lp_index */ -typedef unsigned int UWORD; /* ulp SLOOPIDX ulp_index */ - -typedef LOOPIDX LOOPINDEX; /* lp LOOPIDX lp_index */ -typedef ULOOPIDX ULOOPINDEX; /* ulp SLOOPIDX ulp_index */ - -#define PLATFORM_INLINE __inline - -typedef struct xa_codec_opaque { - WORD32 _; -} * xa_codec_handle_t; - -typedef int XA_ERRORCODE; - -typedef XA_ERRORCODE xa_codec_func_t(xa_codec_handle_t p_xa_module_obj, - WORD32 i_cmd, WORD32 i_idx, - pVOID pv_value); - -#endif /* __XA_TYPE_DEF_H__ */ diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xtensa_tf_micro_common.h b/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xtensa_tf_micro_common.h deleted file mode 100644 index 81847b60444..00000000000 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xtensa_tf_micro_common.h +++ /dev/null @@ -1,88 +0,0 @@ -/****************************************************************************** - * Copyright (C) 2019 Cadence Design Systems, Inc. 
- * - * Permission is hereby granted, free of charge, to any person obtaining - * a copy of this software and associated documentation files (the - * "Software"), to use this Software with Cadence processor cores only and - * not with any other processors and platforms, subject to - * the following conditions: - * - * The above copyright notice and this permission notice shall be included - * in all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - ******************************************************************************/ - -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef __XTENSA_TF_MICRO_COMMON__ -#define __XTENSA_TF_MICRO_COMMON__ - -#if defined HIFI_NNLIB_OPT || defined HIFI_MINI_NNLIB_OPT -#include "tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/nnlib/xa_nnlib_api.h" -#include "tensorflow/lite/micro/kernels/xtensa_hifimini_staging/xa_nnlib/include/nnlib/xa_nnlib_standards.h" - -#define CHECK_ERR_HIFI_NNLIB_KER(ret, err_msg) \ - if (ret != 0) { \ - TF_LITE_KERNEL_LOG(context, err_msg); \ - return kTfLiteError; \ - } - -#ifndef XTENSA_NNLIB_MAX_SCRATCH_SIZE -#define XTENSA_NNLIB_MAX_SCRATCH_SIZE (70 * 1024) -#endif - -#define ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM \ - uint8_t xtensa_nnlib_scratch_buf[XTENSA_NNLIB_MAX_SCRATCH_SIZE]; - -#define MIN(a, b) (a) < (b) ? (a) : (b); -#define MAX(a, b) (a) > (b) ? 
(a) : (b); - -#define ACTIVATION_MIN_MAX(data_type, out, inp, min, max) \ - { \ - data_type temp = MAX(inp, min); \ - out = MIN(temp, max); \ - } - -#define ACTIVATION_MIN_MAX_F32(out, inp, min, max) \ - { \ - float temp = MAX(inp, min); \ - out = MIN(temp, max); \ - } - -#define ACTIVATION_MIN_MAX_ASYM8(out, inp, min, max) \ - { \ - int32_t temp = MAX((int32_t)inp, min); \ - out = (uint8_t)MIN(temp, max); \ - } - -#define ALIGNED_SIZE(x, bytes) (((x) + (bytes - 1)) & (~(bytes - 1))) -#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) - -#define PRINT_VAR(var) \ - printf("%s = %d\n", #var, var); \ - fflush(stdout); \ - fflush(stderr); - -#endif /* HIFI_NNLIB_OPT */ - -#endif /* __XTENSA_TF_MICRO_COMMON__ */ diff --git a/tensorflow/lite/micro/micro_interpreter_test.cc b/tensorflow/lite/micro/micro_interpreter_test.cc index 2dbfa8b56ac..3f4bb813d2d 100644 --- a/tensorflow/lite/micro/micro_interpreter_test.cc +++ b/tensorflow/lite/micro/micro_interpreter_test.cc @@ -494,4 +494,68 @@ TF_LITE_MICRO_TEST(TestInterpreterDoesNotAllocateUntilInvoke) { static_cast<size_t>(0)); } +TF_LITE_MICRO_TEST(TestInterpreterMultipleInputs) { + const tflite::Model* model = tflite::testing::GetSimpleMultipleInputsModel(); + TF_LITE_MICRO_EXPECT_NE(nullptr, model); + + tflite::AllOpsResolver op_resolver = tflite::testing::GetOpResolver(); + + constexpr size_t allocator_buffer_size = 2000; + uint8_t allocator_buffer[allocator_buffer_size]; + + // Create a new scope so that we can test the destructor. + { + tflite::MicroInterpreter interpreter(model, op_resolver, allocator_buffer, + allocator_buffer_size, + micro_test::reporter); + + TF_LITE_MICRO_EXPECT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + TF_LITE_MICRO_EXPECT_LE(interpreter.arena_used_bytes(), 928 + 100); + + TF_LITE_MICRO_EXPECT_EQ(static_cast<size_t>(3), interpreter.inputs_size()); + TF_LITE_MICRO_EXPECT_EQ(static_cast<size_t>(1), interpreter.outputs_size()); + TF_LITE_MICRO_EXPECT_EQ(static_cast<size_t>(4), interpreter.tensors_size()); + + TfLiteTensor* input = interpreter.input(0); + TF_LITE_MICRO_EXPECT_NE(nullptr, input); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt32, input->type); + TF_LITE_MICRO_EXPECT_EQ(1, input->dims->size); + TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[0]); + TF_LITE_MICRO_EXPECT_EQ(static_cast<size_t>(4), input->bytes); + TF_LITE_MICRO_EXPECT_NE(nullptr, input->data.i32); + input->data.i32[0] = 21; + + TfLiteTensor* input1 = interpreter.input(1); + TF_LITE_MICRO_EXPECT_NE(nullptr, input1); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt8, input1->type); + TF_LITE_MICRO_EXPECT_EQ(1, input1->dims->size); + TF_LITE_MICRO_EXPECT_EQ(1, input1->dims->data[0]); + TF_LITE_MICRO_EXPECT_EQ(static_cast<size_t>(1), input1->bytes); + TF_LITE_MICRO_EXPECT_NE(nullptr, input1->data.i32); + input1->data.i32[0] = 21; + + TfLiteTensor* input2 = interpreter.input(2); + TF_LITE_MICRO_EXPECT_NE(nullptr, input2); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt32, input2->type); + TF_LITE_MICRO_EXPECT_EQ(1, input2->dims->size); + TF_LITE_MICRO_EXPECT_EQ(1, input2->dims->data[0]); + TF_LITE_MICRO_EXPECT_EQ(static_cast<size_t>(4), input2->bytes); + TF_LITE_MICRO_EXPECT_NE(nullptr, input2->data.i32); + input2->data.i32[0] = 24; + + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, interpreter.Invoke()); + + TfLiteTensor* output = interpreter.output(0); + TF_LITE_MICRO_EXPECT_NE(nullptr, output); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt32, output->type); + TF_LITE_MICRO_EXPECT_EQ(1, output->dims->size); + TF_LITE_MICRO_EXPECT_EQ(1, output->dims->data[0]); + 
TF_LITE_MICRO_EXPECT_EQ(static_cast<size_t>(4), output->bytes); + TF_LITE_MICRO_EXPECT_NE(nullptr, output->data.i32); + TF_LITE_MICRO_EXPECT_EQ(66, output->data.i32[0]); + } + + TF_LITE_MICRO_EXPECT_EQ(tflite::testing::MultipleInputs::freed_, true); +} + TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/micro/test_helpers.cc b/tensorflow/lite/micro/test_helpers.cc index 897f3110036..f73073f63da 100644 --- a/tensorflow/lite/micro/test_helpers.cc +++ b/tensorflow/lite/micro/test_helpers.cc @@ -570,6 +570,74 @@ const Model* BuildComplexMockModel() { return model; } +const Model* BuildSimpleMultipleInputsModel() { + using flatbuffers::Offset; + flatbuffers::FlatBufferBuilder* builder = BuilderInstance(); + + constexpr size_t buffers_size = 1; + const Offset<Buffer> buffers[buffers_size] = { + CreateBuffer(*builder), + }; + constexpr size_t tensor_shape_size = 1; + const int32_t tensor_shape[tensor_shape_size] = {1}; + constexpr size_t tensors_size = 4; + const Offset<Tensor> tensors[tensors_size] = { + CreateTensor(*builder, + builder->CreateVector(tensor_shape, tensor_shape_size), + TensorType_INT32, 0, + builder->CreateString("test_input_tensor1"), 0, false), + CreateTensor(*builder, + builder->CreateVector(tensor_shape, tensor_shape_size), + TensorType_INT8, 0, + builder->CreateString("test_input_tensor2"), 0, false), + CreateTensor(*builder, + builder->CreateVector(tensor_shape, tensor_shape_size), + TensorType_INT32, 0, + builder->CreateString("test_input_tensor3"), 0, false), + CreateTensor(*builder, + builder->CreateVector(tensor_shape, tensor_shape_size), + TensorType_INT32, 0, + builder->CreateString("test_output_tensor"), 0, false), + }; + constexpr size_t inputs_size = 3; + const int32_t inputs[inputs_size] = {0, 1, 2}; + constexpr size_t outputs_size = 1; + const int32_t outputs[outputs_size] = {3}; + constexpr size_t operator_inputs_size = 3; + const int32_t operator_inputs[operator_inputs_size] = {0, 1, 2}; + constexpr size_t operator_outputs_size = 1; + const int32_t operator_outputs[operator_outputs_size] = {3}; + constexpr size_t operators_size = 1; + const Offset<Operator> operators[operators_size] = { + CreateOperator( + *builder, 0, + builder->CreateVector(operator_inputs, operator_inputs_size), + builder->CreateVector(operator_outputs, operator_outputs_size), + BuiltinOptions_NONE), + }; + constexpr size_t subgraphs_size = 1; + const Offset<SubGraph> subgraphs[subgraphs_size] = { + CreateSubGraph(*builder, builder->CreateVector(tensors, tensors_size), + builder->CreateVector(inputs, inputs_size), + builder->CreateVector(outputs, outputs_size), + builder->CreateVector(operators, operators_size), + builder->CreateString("test_subgraph"))}; + constexpr size_t operator_codes_size = 1; + const Offset<OperatorCode> operator_codes[operator_codes_size] = { + CreateOperatorCodeDirect(*builder, /*deprecated_builtin_code=*/0, + "multiple_inputs_op", + /*version=*/0, BuiltinOperator_CUSTOM)}; + const Offset<Model> model_offset = CreateModel( + *builder, 0, builder->CreateVector(operator_codes, operator_codes_size), + builder->CreateVector(subgraphs, subgraphs_size), + builder->CreateString("test_model"), + builder->CreateVector(buffers, buffers_size)); + FinishModelBuffer(*builder, model_offset); + void* model_pointer = builder->GetBufferPointer(); + const Model* model = flatbuffers::GetRoot<Model>(model_pointer); + return model; +} + } // namespace const TfLiteRegistration* SimpleStatefulOp::getRegistration() { @@ -704,12 +772,66 @@ TfLiteStatus MockCustom::Invoke(TfLiteContext* 
context, TfLiteNode* node) { bool MockCustom::freed_ = false; +const TfLiteRegistration* MultipleInputs::getRegistration() { + return GetMutableRegistration(); +} + +TfLiteRegistration* MultipleInputs::GetMutableRegistration() { + static TfLiteRegistration r; + r.init = Init; + r.prepare = Prepare; + r.invoke = Invoke; + r.free = Free; + return &r; +} + +void* MultipleInputs::Init(TfLiteContext* context, const char* buffer, + size_t length) { + // We don't support delegate in TFL micro. This is a weak check to test if + // context struct being zero-initialized. + TFLITE_DCHECK(context->ReplaceNodeSubsetsWithDelegateKernels == nullptr); + freed_ = false; + // Do nothing. + return nullptr; +} + +void MultipleInputs::Free(TfLiteContext* context, void* buffer) { + freed_ = true; +} + +TfLiteStatus MultipleInputs::Prepare(TfLiteContext* context, TfLiteNode* node) { + return kTfLiteOk; +} + +TfLiteStatus MultipleInputs::Invoke(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input; + TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input)); + const int32_t* input_data = input->data.i32; + const TfLiteTensor* input1; + TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &input1)); + const int32_t* input_data1 = input1->data.i32; + const TfLiteTensor* input2; + TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &input2)); + const int32_t* input_data2 = input2->data.i32; + + TfLiteTensor* output; + TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); + int32_t* output_data = output->data.i32; + output_data[0] = + 0; // Catch output tensor sharing memory with an input tensor + output_data[0] = input_data[0] + input_data1[0] + input_data2[0]; + return kTfLiteOk; +} + +bool MultipleInputs::freed_ = false; + AllOpsResolver GetOpResolver() { AllOpsResolver op_resolver; op_resolver.AddCustom("mock_custom", MockCustom::GetMutableRegistration()); op_resolver.AddCustom("simple_stateful_op", SimpleStatefulOp::GetMutableRegistration()); - + op_resolver.AddCustom("multiple_inputs_op", + MultipleInputs::GetMutableRegistration()); return op_resolver; } @@ -721,6 +843,14 @@ const Model* GetSimpleMockModel() { return model; } +const Model* GetSimpleMultipleInputsModel() { + static Model* model = nullptr; + if (!model) { + model = const_cast<Model*>(BuildSimpleMultipleInputsModel()); + } + return model; +} + const Model* GetComplexMockModel() { static Model* model = nullptr; if (!model) { diff --git a/tensorflow/lite/micro/test_helpers.h b/tensorflow/lite/micro/test_helpers.h index 1db0d81facc..4c8b7c20aa0 100644 --- a/tensorflow/lite/micro/test_helpers.h +++ b/tensorflow/lite/micro/test_helpers.h @@ -76,6 +76,20 @@ class MockCustom { static bool freed_; }; +// A simple operator with the purpose of testing multiple inputs. It returns +// the sum of the inputs. +class MultipleInputs { + public: + static const TfLiteRegistration* getRegistration(); + static TfLiteRegistration* GetMutableRegistration(); + static void* Init(TfLiteContext* context, const char* buffer, size_t length); + static void Free(TfLiteContext* context, void* buffer); + static TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node); + static TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node); + + static bool freed_; +}; + // Returns an Op Resolver that can be used in the testing code. AllOpsResolver GetOpResolver(); @@ -90,6 +104,10 @@ const Model* GetComplexMockModel(); // Returns a simple flatbuffer model with two branches. 
const Model* GetSimpleModelWithBranch(); +// Returns a simple example flatbuffer TensorFlow Lite model. Contains 3 inputs, +// 1 output Tensor, and 1 operator. +const Model* GetSimpleMultipleInputsModel(); + // Returns a simple flatbuffer model with offline planned tensors // @param[in] num_tensors Number of tensors in the model. // @param[in] metadata_buffer Metadata for offline planner. diff --git a/tensorflow/lite/micro/tools/ci_build/test_esp32.sh b/tensorflow/lite/micro/tools/ci_build/test_esp32.sh new file mode 100755 index 00000000000..8341e90924e --- /dev/null +++ b/tensorflow/lite/micro/tools/ci_build/test_esp32.sh @@ -0,0 +1,58 @@ +#!/usr/bin/env bash +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Tests the microcontroller code for esp32 platform + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT_DIR=${SCRIPT_DIR}/../../../../.. +cd "${ROOT_DIR}" +pwd + +source tensorflow/lite/micro/tools/ci_build/helper_functions.sh + +TARGET=esp + +# setup esp-idf and toolchains +readable_run git clone --recursive --single-branch --branch release/v4.2 https://github.com/espressif/esp-idf.git +readable_run export IDF_PATH="${ROOT_DIR}"/esp-idf +cd $IDF_PATH +readable_run ./install.sh +readable_run . ./export.sh +cd "${ROOT_DIR}" + +# clean all +readable_run make -f tensorflow/lite/micro/tools/make/Makefile clean + +# generate examples +readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile TARGET=${TARGET} generate_hello_world_esp_project +readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile TARGET=${TARGET} generate_person_detection_esp_project +readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile TARGET=${TARGET} generate_micro_speech_esp_project + +# build examples +cd "${ROOT_DIR}"/tensorflow/lite/micro/tools/make/gen/esp_xtensa-esp32/prj/hello_world/esp-idf +readable_run idf.py build + +cd "${ROOT_DIR}"/tensorflow/lite/micro/tools/make/gen/esp_xtensa-esp32/prj/person_detection/esp-idf +readable_run git clone https://github.com/espressif/esp32-camera.git components/esp32-camera +cd components/esp32-camera/ +readable_run git checkout eacd640b8d379883bff1251a1005ebf3cf1ed95c +cd ../../ +readable_run idf.py build + +cd "${ROOT_DIR}"/tensorflow/lite/micro/tools/make/gen/esp_xtensa-esp32/prj/micro_speech/esp-idf +readable_run idf.py build diff --git a/tensorflow/lite/micro/tools/ci_build/test_x86.sh b/tensorflow/lite/micro/tools/ci_build/test_x86.sh index 844dccbafb7..05d79802dcd 100755 --- a/tensorflow/lite/micro/tools/ci_build/test_x86.sh +++ b/tensorflow/lite/micro/tools/ci_build/test_x86.sh @@ -29,9 +29,12 @@ readable_run make -f tensorflow/lite/micro/tools/make/Makefile clean # TODO(b/143715361): downloading first to allow for parallel builds. 
readable_run make -f tensorflow/lite/micro/tools/make/Makefile third_party_downloads -# Next, build w/o TF_LITE_STATIC_MEMORY to catch additional build errors. +# Next, build w/o TF_LITE_STATIC_MEMORY to catch additional errors. +# TODO(b/160955687): We run the tests w/o TF_LITE_STATIC_MEMORY to make the +# internal and open source CI consistent. See b/160955687#comment7 for more +# details. readable_run make -f tensorflow/lite/micro/tools/make/Makefile clean -readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile BUILD_TYPE=no_tf_lite_static_memory build +readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile BUILD_TYPE=no_tf_lite_static_memory test # Next, make sure that the release build succeeds. readable_run make -f tensorflow/lite/micro/tools/make/Makefile clean diff --git a/tensorflow/lite/micro/tools/make/ext_libs/xtensa_hifimini.inc b/tensorflow/lite/micro/tools/make/ext_libs/xtensa_hifimini.inc deleted file mode 100644 index c972870cf02..00000000000 --- a/tensorflow/lite/micro/tools/make/ext_libs/xtensa_hifimini.inc +++ /dev/null @@ -1,4 +0,0 @@ -# Every optimized kernel implementation directory (i.e. -# micro/kernels/<optimized_kernel_dir>/ must have a corresponding -# micro/tools/make/ext_libs/<optimized_kernel_dir>.inc - diff --git a/tensorflow/lite/micro/tools/make/ext_libs/xtensa_hifimini_staging_nn_library.inc b/tensorflow/lite/micro/tools/make/ext_libs/xtensa_hifimini_staging_nn_library.inc deleted file mode 100644 index df7d3089c30..00000000000 --- a/tensorflow/lite/micro/tools/make/ext_libs/xtensa_hifimini_staging_nn_library.inc +++ /dev/null @@ -1,30 +0,0 @@ -ifneq ($(filter xtensa_hifimini_staging, $(ALL_TAGS)),) - - XTENSA_PATH = $(MAKEFILE_DIR)/../../kernels/xtensa_hifimini_staging - - ifneq (,$(filter xtensa_hifimini%, $(ALL_TAGS))) - - CCFLAGS += -DHIFI_MINI_NNLIB_OPT \ - -DDISABLE_NNLIB_UNALIGNED_SUPPORT \ - -DXTENSA_NNLIB_MAX_SCRATCH_SIZE=1024 - - CXXFLAGS += -DHIFI_MINI_NNLIB_OPT \ - -DDISABLE_NNLIB_UNALIGNED_SUPPORT \ - -DXTENSA_NNLIB_MAX_SCRATCH_SIZE=1024 - - MICROLITE_CC_SRCS += \ - $(XTENSA_PATH)/xa_nnlib/algo/kernels/activations/hifi_mini/xa_nn_activations_asym8s_asym8s.c \ - $(XTENSA_PATH)/xa_nnlib/algo/kernels/activations/hifi_mini/xa_nn_softmax_asym8_asym8.c \ - $(XTENSA_PATH)/xa_nnlib/algo/kernels/basic/hifi_mini/xa_nn_dot_prod_16x16.c \ - $(XTENSA_PATH)/xa_nnlib/algo/kernels/fc/hifi_mini/xa_nn_fully_connected.c \ - $(XTENSA_PATH)/xa_nnlib/algo/kernels/matXvec/hifi_mini/xa_nn_matXvec_sym8sxasym8s.c \ - - - INCLUDES += -I$(XTENSA_PATH)/xa_nnlib/algo/kernels/ \ - -I$(XTENSA_PATH)/xa_nnlib/include/nnlib/ \ - -I$(XTENSA_PATH)/xa_nnlib/include/ \ - -I$(XTENSA_PATH)/xa_nnlib/algo/common/include/ \ - - endif - -endif diff --git a/tensorflow/lite/micro/tools/make/targets/xtensa_hifimini_makefile.inc b/tensorflow/lite/micro/tools/make/targets/xtensa_hifimini_makefile.inc deleted file mode 100644 index 1587ebcd034..00000000000 --- a/tensorflow/lite/micro/tools/make/targets/xtensa_hifimini_makefile.inc +++ /dev/null @@ -1,58 +0,0 @@ -# Settings for Xtensa toolchain for the hifimini kernels. -# REQUIRED: -# Environment variables: -# - XTENSA_BASE must be set to location of -# the Xtensa developer tools installation directory. 
-# Command line arguments: -# - XTENSA_TOOLS_VERSION: For example: RI-2019.2-linux -# - XTENSA_CORE: The name of the Xtensa core to use -# For example: hifimini - -TARGET_ARCH := xtensa_hifimini - -ifndef XTENSA_BASE - $(error XTENSA_BASE is undefined) -endif - -ifndef XTENSA_TOOLS_VERSION - $(error XTENSA_TOOLS_VERSION is undefined) -endif - -ifndef XTENSA_CORE - $(error XTENSA_CORE is undefined) -endif - -PLATFORM_FLAGS = \ - -DTF_LITE_MCU_DEBUG_LOG \ - -DTF_LITE_USE_CTIME \ - --xtensa-core=$(XTENSA_CORE) \ - -mcoproc \ - -DXTENSA \ - -DMAX_RFFT_PWR=9 \ - -DMIN_RFFT_PWR=MAX_RFFT_PWR - - -export PATH := $(XTENSA_BASE)/tools/$(XTENSA_TOOLS_VERSION)/XtensaTools/bin:$(PATH) -TARGET_TOOLCHAIN_PREFIX := xt- -CXX_TOOL := clang++ -CC_TOOL := clang - -CXXFLAGS += $(PLATFORM_FLAGS) -CCFLAGS += $(PLATFORM_FLAGS) - -# TODO(b/150240249): Do not remove -fno-rtti once that works for the Xtensa toolchain. -CXXFLAGS := $(filter-out -fno-rtti, $(CXXFLAGS)) - -TEST_SCRIPT := tensorflow/lite/micro/testing/test_xtensa_hifimini_binary.sh - -# TODO(b/156962140): This manually maintained list of excluded examples is -# quite error prone. -EXCLUDED_EXAMPLE_TESTS := \ - tensorflow/lite/micro/examples/image_recognition_experimental/Makefile.inc \ - tensorflow/lite/micro/examples/magic_wand/Makefile.inc \ - tensorflow/lite/micro/examples/micro_speech/Makefile.inc \ - tensorflow/lite/micro/examples/network_tester/Makefile.inc \ - tensorflow/lite/micro/examples/person_detection/Makefile.inc \ - tensorflow/lite/micro/examples/person_detection_experimental/Makefile.inc -MICRO_LITE_EXAMPLE_TESTS := $(filter-out $(EXCLUDED_EXAMPLE_TESTS), $(MICRO_LITE_EXAMPLE_TESTS)) - diff --git a/tensorflow/lite/micro/tools/make/targets/xtensa_hifimini_staging_makefile.inc b/tensorflow/lite/micro/tools/make/targets/xtensa_hifimini_staging_makefile.inc deleted file mode 100644 index 557b8f6e9e6..00000000000 --- a/tensorflow/lite/micro/tools/make/targets/xtensa_hifimini_staging_makefile.inc +++ /dev/null @@ -1,62 +0,0 @@ -# Settings for Xtensa toolchain for the hifimini kernels. -# REQUIRED: -# Environment variables: -# - XTENSA_BASE must be set to location of -# the Xtensa developer tools installation directory. -# Command line arguments: -# - XTENSA_TOOLS_VERSION: For example: RI-2019.2-linux -# - XTENSA_CORE: The name of the Xtensa core to use -# For example: hifimini - -ifeq ($(TARGET), xtensa_hifimini_staging) - TARGET_ARCH := xtensa_hifimini_staging - - ifndef XTENSA_BASE - $(error XTENSA_BASE is undefined) - endif - - ifndef XTENSA_TOOLS_VERSION - $(error XTENSA_TOOLS_VERSION is undefined) - endif - - ifndef XTENSA_CORE - $(error XTENSA_CORE is undefined) - endif - - PLATFORM_ARGS = \ - -DTF_LITE_MCU_DEBUG_LOG \ - --xtensa-core=$(XTENSA_CORE) \ - -mcoproc \ - -DXTENSA -DMAX_RFFT_PWR=9 -DMIN_RFFT_PWR=MAX_RFFT_PWR \ - -fdata-sections \ - -ffunction-sections \ - -fno-exceptions \ - -fno-unwind-tables \ - -fno-use-cxa-atexit \ - -fmessage-length=0 \ - -fno-threadsafe-statics - - export PATH := $(XTENSA_BASE)/tools/$(XTENSA_TOOLS_VERSION)/XtensaTools/bin:$(PATH) - TARGET_TOOLCHAIN_PREFIX := xt- - CXX_TOOL := clang++ - CC_TOOL := clang - - CXXFLAGS += $(PLATFORM_ARGS) - CCFLAGS += $(PLATFORM_ARGS) - - LDFLAGS += -Wl,-gc-sections - - TEST_SCRIPT := tensorflow/lite/micro/testing/test_xtensa_hifimini_staging_binary.sh - - # TODO(b/156962140): This manually maintained list of excluded examples is - # quite error prone. 
- EXCLUDED_EXAMPLE_TESTS := \ - tensorflow/lite/micro/examples/image_recognition_experimental/Makefile.inc \ - tensorflow/lite/micro/examples/magic_wand/Makefile.inc \ - tensorflow/lite/micro/examples/micro_speech/Makefile.inc \ - tensorflow/lite/micro/examples/network_tester/Makefile.inc \ - tensorflow/lite/micro/examples/person_detection/Makefile.inc \ - tensorflow/lite/micro/examples/person_detection_experimental/Makefile.inc - MICRO_LITE_EXAMPLE_TESTS := $(filter-out $(EXCLUDED_EXAMPLE_TESTS), $(MICRO_LITE_EXAMPLE_TESTS)) - -endif diff --git a/tensorflow/lite/micro/xtensa_hifimini_staging/debug_log.cc b/tensorflow/lite/micro/xtensa_hifimini_staging/debug_log.cc deleted file mode 100644 index 45d9317478a..00000000000 --- a/tensorflow/lite/micro/xtensa_hifimini_staging/debug_log.cc +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Reference implementation of the DebugLog() function that's required for a -// platform to support the TensorFlow Lite for Microcontrollers library. This is -// the only function that's absolutely required to be available on a target -// device, since it's used for communicating test results back to the host so -// that we can verify the implementation is working correctly. -// It's designed to be as easy as possible to supply an implementation though. -// On platforms that have a POSIX stack or C library, it can be written as a -// single call to `fprintf(stderr, "%s", s)` to output a string to the error -// stream of the console, but if there's no OS or C library available, there's -// almost always an equivalent way to write out a string to some serial -// interface that can be used instead. For example on Arm M-series MCUs, calling -// the `bkpt #0xAB` assembler instruction will output the string in r1 to -// whatever debug serial connection is available. If you're running mbed, you -// can do the same by creating `Serial pc(USBTX, USBRX)` and then calling -// `pc.printf("%s", s)`. -// To add an equivalent function for your own platform, create your own -// implementation file, and place it in a subfolder with named after the OS -// you're targeting. For example, see the Cortex M bare metal version in -// tensorflow/lite/micro/bluepill/debug_log.cc or the mbed one on -// tensorflow/lite/micro/mbed/debug_log.cc. - -#include "tensorflow/lite/micro/debug_log.h" - -#ifndef TF_LITE_STRIP_ERROR_STRINGS -#include <cstdio> -#endif - -extern "C" void DebugLog(const char* s) { -#ifndef TF_LITE_STRIP_ERROR_STRINGS - // Reusing TF_LITE_STRIP_ERROR_STRINGS to disable DebugLog completely to get - // maximum reduction in binary size. This is because we have DebugLog calls - // via TF_LITE_CHECK that are not stubbed out by TF_LITE_REPORT_ERROR. 
- fprintf(stderr, "%s", s); -#endif -} diff --git a/tensorflow/lite/python/interpreter.py b/tensorflow/lite/python/interpreter.py index cd5a237b0ef..18ba58b27a5 100644 --- a/tensorflow/lite/python/interpreter.py +++ b/tensorflow/lite/python/interpreter.py @@ -460,7 +460,7 @@ class Interpreter(object): ] def get_tensor(self, tensor_index): - """Gets the value of the input tensor (get a copy). + """Gets the value of the output tensor (get a copy). If you wish to avoid the copy, use `tensor()`. This function cannot be used to read intermediate results. diff --git a/tensorflow/lite/python/interpreter_wrapper/BUILD b/tensorflow/lite/python/interpreter_wrapper/BUILD index 3055d37fa07..5e7ecd7f5fd 100644 --- a/tensorflow/lite/python/interpreter_wrapper/BUILD +++ b/tensorflow/lite/python/interpreter_wrapper/BUILD @@ -38,7 +38,6 @@ cc_library( "//tensorflow/lite:util", "//tensorflow/lite/c:common", "//tensorflow/lite/core/api", - "//tensorflow/lite/experimental/tflite_api_dispatcher", "//tensorflow/lite/kernels:builtin_ops", "//third_party/python_runtime:headers", # buildcleaner: keep "@com_google_absl//absl/memory", @@ -90,7 +89,6 @@ pybind_extension( "@pybind11", "//third_party/python_runtime:headers", "//tensorflow/lite:framework_lib", - "//tensorflow/lite/experimental/tflite_api_dispatcher", "//tensorflow/python:pybind11_lib", ] + select({ ":tflite_pip_with_flex": ["//tensorflow/lite/delegates/flex:delegate"], diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc index adfa760f147..9f97c79dd43 100644 --- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc +++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc @@ -66,8 +66,8 @@ namespace { using python_utils::PyDecrefDeleter; -std::unique_ptr<tflite_api_dispatcher::Interpreter> CreateInterpreter( - const tflite_api_dispatcher::TfLiteModel* model, +std::unique_ptr<Interpreter> CreateInterpreter( + const InterpreterWrapper::Model* model, const tflite::ops::builtin::BuiltinOpResolver& resolver) { if (!model) { return nullptr; @@ -75,9 +75,8 @@ std::unique_ptr<tflite_api_dispatcher::Interpreter> CreateInterpreter( ::tflite::python::ImportNumpy(); - std::unique_ptr<tflite_api_dispatcher::Interpreter> interpreter; - if (tflite_api_dispatcher::InterpreterBuilder( - *model, resolver)(&interpreter) != kTfLiteOk) { + std::unique_ptr<Interpreter> interpreter; + if (InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) { return nullptr; } return interpreter; @@ -167,7 +166,7 @@ bool RegisterCustomOpByName(const char* registerer_name, } // namespace InterpreterWrapper* InterpreterWrapper::CreateInterpreterWrapper( - std::unique_ptr<tflite_api_dispatcher::TfLiteModel> model, + std::unique_ptr<InterpreterWrapper::Model> model, std::unique_ptr<PythonErrorReporter> error_reporter, const std::vector<std::string>& registerers_by_name, const std::vector<std::function<void(uintptr_t)>>& registerers_by_func, @@ -198,10 +197,10 @@ InterpreterWrapper* InterpreterWrapper::CreateInterpreterWrapper( } InterpreterWrapper::InterpreterWrapper( - std::unique_ptr<tflite_api_dispatcher::TfLiteModel> model, + std::unique_ptr<InterpreterWrapper::Model> model, std::unique_ptr<PythonErrorReporter> error_reporter, std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver, - std::unique_ptr<tflite_api_dispatcher::Interpreter> interpreter) + std::unique_ptr<Interpreter> interpreter) : model_(std::move(model)), 
error_reporter_(std::move(error_reporter)), resolver_(std::move(resolver)), @@ -537,9 +536,8 @@ namespace { // Checks to see if a tensor access can succeed (returns nullptr on error). // Otherwise returns Py_None. -PyObject* CheckGetTensorArgs(tflite_api_dispatcher::Interpreter* interpreter_, - int tensor_index, TfLiteTensor** tensor, - int* type_num) { +PyObject* CheckGetTensorArgs(Interpreter* interpreter_, int tensor_index, + TfLiteTensor** tensor, int* type_num) { TFLITE_PY_ENSURE_VALID_INTERPRETER(); TFLITE_PY_TENSOR_BOUNDS_CHECK(tensor_index); @@ -665,9 +663,8 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile( const std::vector<std::function<void(uintptr_t)>>& registerers_by_func, std::string* error_msg) { std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter); - std::unique_ptr<tflite_api_dispatcher::TfLiteModel> model = - tflite_api_dispatcher::TfLiteModel::BuildFromFile(model_path, - error_reporter.get()); + std::unique_ptr<InterpreterWrapper::Model> model = + Model::BuildFromFile(model_path, error_reporter.get()); return CreateInterpreterWrapper(std::move(model), std::move(error_reporter), registerers_by_name, registerers_by_func, error_msg); @@ -690,9 +687,8 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer( if (python_utils::ConvertFromPyString(data, &buf, &length) == -1) { return nullptr; } - std::unique_ptr<tflite_api_dispatcher::TfLiteModel> model = - tflite_api_dispatcher::TfLiteModel::BuildFromBuffer(buf, length, - error_reporter.get()); + std::unique_ptr<InterpreterWrapper::Model> model = + Model::BuildFromBuffer(buf, length, error_reporter.get()); return CreateInterpreterWrapper(std::move(model), std::move(error_reporter), registerers_by_name, registerers_by_func, error_msg); diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h index 6b83d2d06db..8f03d0915b1 100644 --- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h +++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h @@ -27,7 +27,6 @@ limitations under the License. // automatically move <Python.h> before <locale>. #include <Python.h> -#include "tensorflow/lite/experimental/tflite_api_dispatcher/tflite_api_dispatcher.h" #include "tensorflow/lite/interpreter.h" struct TfLiteDelegate; @@ -48,6 +47,8 @@ class PythonErrorReporter; class InterpreterWrapper { public: + using Model = FlatBufferModel; + // SWIG caller takes ownership of pointer. static InterpreterWrapper* CreateWrapperCPPFromFile( const char* model_path, const std::vector<std::string>& registerers, @@ -105,26 +106,24 @@ class InterpreterWrapper { // Experimental and subject to change. // // Returns a pointer to the underlying interpreter. - tflite_api_dispatcher::Interpreter* interpreter() { - return interpreter_.get(); - } + Interpreter* interpreter() { return interpreter_.get(); } private: // Helper function to construct an `InterpreterWrapper` object. // It only returns InterpreterWrapper if it can construct an `Interpreter`. // Otherwise it returns `nullptr`. 
   static InterpreterWrapper* CreateInterpreterWrapper(
-      std::unique_ptr<tflite_api_dispatcher::TfLiteModel> model,
+      std::unique_ptr<Model> model,
       std::unique_ptr<PythonErrorReporter> error_reporter,
       const std::vector<std::string>& registerers_by_name,
       const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
       std::string* error_msg);
 
   InterpreterWrapper(
-      std::unique_ptr<tflite_api_dispatcher::TfLiteModel> model,
+      std::unique_ptr<Model> model,
       std::unique_ptr<PythonErrorReporter> error_reporter,
       std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver,
-      std::unique_ptr<tflite_api_dispatcher::Interpreter> interpreter);
+      std::unique_ptr<Interpreter> interpreter);
 
   // InterpreterWrapper is not copyable or assignable. We avoid the use of
   // InterpreterWrapper() = delete here for SWIG compatibility.
@@ -137,10 +136,10 @@ class InterpreterWrapper {
   // The public functions which creates `InterpreterWrapper` should ensure all
   // these member variables are initialized successfully. Otherwise it should
   // report the error and return `nullptr`.
-  const std::unique_ptr<tflite_api_dispatcher::TfLiteModel> model_;
+  const std::unique_ptr<Model> model_;
   const std::unique_ptr<PythonErrorReporter> error_reporter_;
   const std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver_;
-  const std::unique_ptr<tflite_api_dispatcher::Interpreter> interpreter_;
+  const std::unique_ptr<Interpreter> interpreter_;
 };
 
 }  // namespace interpreter_wrapper
diff --git a/tensorflow/lite/python/op_hint.py b/tensorflow/lite/python/op_hint.py
index 9d62c1b8a97..e8aff7b82da 100644
--- a/tensorflow/lite/python/op_hint.py
+++ b/tensorflow/lite/python/op_hint.py
@@ -89,11 +89,18 @@ from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes
 from tensorflow.python.framework.graph_util_impl import _extract_graph_summary
 from tensorflow.python.ops import array_ops as _array_ops
 from tensorflow.python.util import compat as _compat
+from tensorflow.python.util import deprecation as _deprecation
 from tensorflow.python.util.all_util import remove_undocumented
 from tensorflow.python.util.tf_export import tf_export as _tf_export
 
 
 @_tf_export(v1=["lite.OpHint"])
+@_deprecation.deprecated(
+    None,
+    "Please follow instructions under "
+    "https://www.tensorflow.org/lite/convert/operation_fusion for operation "
+    "fusion in tflite."
+)
 class OpHint(object):
   """A class that helps build tflite function invocations.
 
@@ -1302,6 +1309,12 @@ def is_ophint_converted(graph_def):
 
 
 @_tf_export(v1=["lite.experimental.convert_op_hints_to_stubs"])
+@_deprecation.deprecated(
+    None,
+    "Please follow instructions under "
+    "https://www.tensorflow.org/lite/convert/operation_fusion for operation "
+    "fusion in tflite."
+) def convert_op_hints_to_stubs(session=None, graph_def=None, write_callback=lambda graph_def, comments: None): diff --git a/tensorflow/lite/testing/BUILD b/tensorflow/lite/testing/BUILD index 74380bd9d4c..dbbb69b49f4 100644 --- a/tensorflow/lite/testing/BUILD +++ b/tensorflow/lite/testing/BUILD @@ -312,7 +312,7 @@ cc_library( hdrs = ["util.h"], deps = [ "//tensorflow/core/platform:logging", - "//tensorflow/lite:framework", + "//tensorflow/lite:error_reporter", "//tensorflow/lite:string", "//tensorflow/lite/core/api", ], diff --git a/tensorflow/lite/tools/BUILD b/tensorflow/lite/tools/BUILD index e2146910d52..aa852093794 100644 --- a/tensorflow/lite/tools/BUILD +++ b/tensorflow/lite/tools/BUILD @@ -258,7 +258,7 @@ cc_library( srcs = ["command_line_flags.cc"], hdrs = ["command_line_flags.h"], copts = tflite_copts(), - deps = [":logging"], + deps = ["//tensorflow/lite/tools:logging"], ) cc_test( @@ -296,8 +296,8 @@ tf_cc_binary( srcs = ["list_flex_ops_main.cc"], visibility = ["//visibility:public"], deps = [ - ":command_line_flags", ":list_flex_ops", + "//tensorflow/lite/tools:command_line_flags", "@com_google_absl//absl/strings", ], ) @@ -307,8 +307,8 @@ cc_library( srcs = ["list_flex_ops_main.cc"], visibility = ["//visibility:public"], deps = [ - ":command_line_flags", ":list_flex_ops", + "//tensorflow/lite/tools:command_line_flags", "@com_google_absl//absl/strings", ], ) @@ -329,8 +329,8 @@ tf_cc_binary( srcs = ["list_flex_ops_main.cc"], visibility = ["//visibility:public"], deps = [ - ":command_line_flags", ":list_flex_ops_no_kernel", + "//tensorflow/lite/tools:command_line_flags", "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/lite/tools/benchmark/BUILD b/tensorflow/lite/tools/benchmark/BUILD index 430bcf32766..eb3f37aef58 100644 --- a/tensorflow/lite/tools/benchmark/BUILD +++ b/tensorflow/lite/tools/benchmark/BUILD @@ -213,6 +213,7 @@ cc_library( ":benchmark_params", ":benchmark_utils", "//tensorflow/core/util:stats_calculator_portable", + "//tensorflow/lite:framework", "//tensorflow/lite/c:common", "//tensorflow/lite/profiling:memory_info", "//tensorflow/lite/profiling:time", diff --git a/tensorflow/lite/tools/benchmark/android/BUILD b/tensorflow/lite/tools/benchmark/android/BUILD index 09cf57e5baf..be9c379af50 100644 --- a/tensorflow/lite/tools/benchmark/android/BUILD +++ b/tensorflow/lite/tools/benchmark/android/BUILD @@ -22,6 +22,10 @@ android_binary( # can't be built. We need to prevent the build system from trying to # use the target in that case. tags = ["manual"], + deps = [ + ":hexagon_libs", + ":tensorflowlite_benchmark_native", + ], ) tflite_jni_binary( diff --git a/tensorflow/lite/tools/benchmark/experimental/firebase/android/BUILD b/tensorflow/lite/tools/benchmark/experimental/firebase/android/BUILD index 5b25cba0487..98591063962 100644 --- a/tensorflow/lite/tools/benchmark/experimental/firebase/android/BUILD +++ b/tensorflow/lite/tools/benchmark/experimental/firebase/android/BUILD @@ -23,6 +23,10 @@ android_binary( # can't be built. We need to prevent the build system from trying to # use the target in that case. 
     tags = ["manual"],
+    deps = [
+        ":hexagon_libs",
+        ":tensorflowlite_benchmark_firebase_native",
+    ],
 )
 
 tflite_jni_binary(
diff --git a/tensorflow/lite/tools/cmake/README.md b/tensorflow/lite/tools/cmake/README.md
index c48685a8c1e..b49162f0eba 100644
--- a/tensorflow/lite/tools/cmake/README.md
+++ b/tensorflow/lite/tools/cmake/README.md
@@ -41,6 +41,13 @@ cd tflite_build
 cmake ../tensorflow_src/tensorflow/lite
 ```
 
+It generates a release binary by default. If you need to produce debug builds,
+you need to provide the '-DCMAKE_BUILD_TYPE=Debug' option.
+
+```sh
+cmake ../tensorflow_src/tensorflow/lite -DCMAKE_BUILD_TYPE=Debug
+```
+
 If you want to configure Android build with GPU delegate support,
 
 ```sh
diff --git a/tensorflow/lite/tools/cmake/modules/Findegl_headers.cmake b/tensorflow/lite/tools/cmake/modules/Findegl_headers.cmake
new file mode 100644
index 00000000000..02cf736faea
--- /dev/null
+++ b/tensorflow/lite/tools/cmake/modules/Findegl_headers.cmake
@@ -0,0 +1,16 @@
+#
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+include(egl_headers)
diff --git a/tensorflow/lite/tools/cmake/modules/Findopengl_headers.cmake b/tensorflow/lite/tools/cmake/modules/Findopengl_headers.cmake
new file mode 100644
index 00000000000..7651549eb1c
--- /dev/null
+++ b/tensorflow/lite/tools/cmake/modules/Findopengl_headers.cmake
@@ -0,0 +1,16 @@
+#
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+include(opengl_headers)
diff --git a/tensorflow/lite/tools/cmake/modules/egl_headers.cmake b/tensorflow/lite/tools/cmake/modules/egl_headers.cmake
new file mode 100644
index 00000000000..f6a23dbee06
--- /dev/null
+++ b/tensorflow/lite/tools/cmake/modules/egl_headers.cmake
@@ -0,0 +1,39 @@
+#
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +if(TARGET egl_headers OR egl_headers_POPULATED) + return() +endif() + +include(FetchContent) + +OverridableFetchContent_Declare( + egl_headers + GIT_REPOSITORY https://github.com/KhronosGroup/EGL-Registry.git + GIT_TAG 649981109e263b737e7735933c90626c29a306f2 + GIT_PROGRESS TRUE + PREFIX "${CMAKE_BINARY_DIR}" + SOURCE_DIR "${CMAKE_BINARY_DIR}/egl_headers" +) + +OverridableFetchContent_GetProperties(egl_headers) +if(NOT egl_headers) + OverridableFetchContent_Populate(egl_headers) +endif() + +include_directories( + AFTER + "${egl_headers_SOURCE_DIR}/api" +) diff --git a/tensorflow/lite/tools/cmake/modules/opengl_headers.cmake b/tensorflow/lite/tools/cmake/modules/opengl_headers.cmake new file mode 100644 index 00000000000..c9db6b48306 --- /dev/null +++ b/tensorflow/lite/tools/cmake/modules/opengl_headers.cmake @@ -0,0 +1,39 @@ +# +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if(TARGET opengl_headers OR opengl_headers_POPULATED) + return() +endif() + +include(FetchContent) + +OverridableFetchContent_Declare( + opengl_headers + GIT_REPOSITORY https://github.com/KhronosGroup/OpenGL-Registry.git + GIT_TAG 0cb0880d91581d34f96899c86fc1bf35627b4b81 + GIT_PROGRESS TRUE + PREFIX "${CMAKE_BINARY_DIR}" + SOURCE_DIR "${CMAKE_BINARY_DIR}/opengl_headers" +) + +OverridableFetchContent_GetProperties(opengl_headers) +if(NOT opengl_headers) + OverridableFetchContent_Populate(opengl_headers) +endif() + +include_directories( + AFTER + "${opengl_headers_SOURCE_DIR}/api" +) diff --git a/tensorflow/lite/tools/evaluation/stages/BUILD b/tensorflow/lite/tools/evaluation/stages/BUILD index 1e621956ea8..4363aff01a9 100644 --- a/tensorflow/lite/tools/evaluation/stages/BUILD +++ b/tensorflow/lite/tools/evaluation/stages/BUILD @@ -180,6 +180,7 @@ cc_test( deps = [ ":inference_profiler_stage", "//tensorflow/lite/c:common", + "//tensorflow/lite/delegates/nnapi:nnapi_delegate", "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto", "@com_google_googletest//:gtest_main", diff --git a/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/BUILD b/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/BUILD index 34a2f9ce68c..cfb49a20a78 100644 --- a/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/BUILD +++ b/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/BUILD @@ -29,6 +29,7 @@ cc_library( "//tensorflow/lite/c:common", "//tensorflow/lite/tools:command_line_flags", "//tensorflow/lite/tools:logging", + "//tensorflow/lite/tools/evaluation:evaluation_stage", "//tensorflow/lite/tools/evaluation:utils", "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto", diff --git a/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/BUILD b/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/BUILD index c315e5ca4a6..941bbc0ff69 100644 
--- a/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/BUILD +++ b/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/BUILD @@ -17,6 +17,7 @@ cc_library( "//tensorflow/lite/c:common", "//tensorflow/lite/tools:command_line_flags", "//tensorflow/lite/tools:logging", + "//tensorflow/lite/tools/evaluation:evaluation_stage", "//tensorflow/lite/tools/evaluation:utils", "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto", diff --git a/tensorflow/lite/tools/evaluation/tasks/inference_diff/BUILD b/tensorflow/lite/tools/evaluation/tasks/inference_diff/BUILD index 8befc356d41..36606722caf 100644 --- a/tensorflow/lite/tools/evaluation/tasks/inference_diff/BUILD +++ b/tensorflow/lite/tools/evaluation/tasks/inference_diff/BUILD @@ -17,6 +17,7 @@ cc_library( "//tensorflow/lite/c:common", "//tensorflow/lite/tools:command_line_flags", "//tensorflow/lite/tools:logging", + "//tensorflow/lite/tools/evaluation:evaluation_stage", "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto", "//tensorflow/lite/tools/evaluation/stages:inference_profiler_stage", diff --git a/tensorflow/lite/tools/optimize/BUILD b/tensorflow/lite/tools/optimize/BUILD index 715afa642a3..c77c08eeb6b 100644 --- a/tensorflow/lite/tools/optimize/BUILD +++ b/tensorflow/lite/tools/optimize/BUILD @@ -43,13 +43,17 @@ tf_cc_test( "//tensorflow/lite/schema:schema_utils", "@com_google_absl//absl/memory", "@com_google_googletest//:gtest", + "@flatbuffers", ], ) cc_binary( name = "modify_model_interface_main", srcs = ["modify_model_interface_main.cc"], - deps = [":modify_model_interface"], + deps = [ + ":modify_model_interface", + ":quantize_model", + ], ) cc_library( @@ -61,6 +65,7 @@ cc_library( "//tensorflow/lite:framework", "//tensorflow/lite/core/api", "//tensorflow/lite/schema:schema_fbs", + "@flatbuffers", ], ) @@ -77,7 +82,9 @@ tf_cc_test( "//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/schema:schema_utils", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", + "@flatbuffers", ], ) @@ -102,7 +109,9 @@ tf_cc_test( "//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/schema:schema_utils", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", + "@flatbuffers", ], ) @@ -112,7 +121,11 @@ cc_library( hdrs = ["quantization_wrapper.h"], deps = [ ":quantization_wrapper_utils", - ":quantize_model", + "//tensorflow/lite:framework", + "//tensorflow/lite/core/api", + "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/tools/optimize:quantize_model", + "@flatbuffers", ], ) @@ -134,6 +147,7 @@ cc_library( "//tensorflow/lite/schema:schema_fbs", "//third_party/eigen3", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -151,6 +165,7 @@ cc_library( "//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/schema:schema_utils", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -169,6 +184,7 @@ tf_cc_test( "//tensorflow/lite:framework", "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", + "@flatbuffers", ], ) @@ -179,6 +195,7 @@ cc_library( compatible_with = get_compatible_with_cloud(), deps = [ "//tensorflow/lite:framework", + "//tensorflow/lite/kernels/internal:types", "//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/schema:schema_utils", ], @@ 
-206,8 +223,11 @@ tf_cc_test( "//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/schema:schema_utils", "//tensorflow/lite/testing:util", + "//third_party/eigen3", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", + "@flatbuffers", ], ) @@ -220,6 +240,7 @@ cc_library( ":quantization_utils", ":model_utils", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "@com_google_absl//absl/container:flat_hash_map", "@flatbuffers", "//tensorflow/lite:framework", @@ -238,10 +259,10 @@ tf_cc_test( "--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)", ], data = [ - ":testdata/custom_op.bin", - ":testdata/quantized_with_gather.bin", - ":testdata/single_conv_weights_min_0_max_plus_10.bin", - ":testdata/weight_shared_between_convs.bin", + "//tensorflow/lite/tools/optimize:testdata/custom_op.bin", + "//tensorflow/lite/tools/optimize:testdata/quantized_with_gather.bin", + "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin", + "//tensorflow/lite/tools/optimize:testdata/weight_shared_between_convs.bin", ], tags = [ "tflite_not_portable_android", @@ -266,8 +287,10 @@ cc_library( srcs = ["test_util.cc"], hdrs = ["test_util.h"], deps = [ + "//tensorflow/lite:framework", "//tensorflow/lite/core/api", "@com_google_googletest//:gtest", + "@flatbuffers", ], ) @@ -295,32 +318,32 @@ tf_cc_test( "--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)", ], data = [ - ":testdata/add_with_const_input.bin", - ":testdata/argmax.bin", - ":testdata/concat.bin", - ":testdata/fc.bin", - ":testdata/lstm_calibrated.bin", - ":testdata/lstm_calibrated2.bin", - ":testdata/lstm_quantized.bin", - ":testdata/lstm_quantized2.bin", - ":testdata/maximum.bin", - ":testdata/minimum.bin", - ":testdata/mixed.bin", - ":testdata/mixed16x8.bin", - ":testdata/multi_input_add_reshape.bin", - ":testdata/pack.bin", - ":testdata/single_avg_pool_min_minus_5_max_plus_5.bin", - ":testdata/single_conv_no_bias.bin", - ":testdata/single_conv_weights_min_0_max_plus_10.bin", - ":testdata/single_conv_weights_min_minus_127_max_plus_127.bin", - ":testdata/single_softmax_min_minus_5_max_plus_5.bin", - ":testdata/split.bin", - ":testdata/svdf_calibrated.bin", - ":testdata/svdf_quantized.bin", - ":testdata/transpose.bin", - ":testdata/unidirectional_sequence_lstm_calibrated.bin", - ":testdata/unidirectional_sequence_lstm_quantized.bin", - ":testdata/unpack.bin", + "//tensorflow/lite/tools/optimize:testdata/add_with_const_input.bin", + "//tensorflow/lite/tools/optimize:testdata/argmax.bin", + "//tensorflow/lite/tools/optimize:testdata/concat.bin", + "//tensorflow/lite/tools/optimize:testdata/fc.bin", + "//tensorflow/lite/tools/optimize:testdata/lstm_calibrated.bin", + "//tensorflow/lite/tools/optimize:testdata/lstm_calibrated2.bin", + "//tensorflow/lite/tools/optimize:testdata/lstm_quantized.bin", + "//tensorflow/lite/tools/optimize:testdata/lstm_quantized2.bin", + "//tensorflow/lite/tools/optimize:testdata/maximum.bin", + "//tensorflow/lite/tools/optimize:testdata/minimum.bin", + "//tensorflow/lite/tools/optimize:testdata/mixed.bin", + "//tensorflow/lite/tools/optimize:testdata/mixed16x8.bin", + "//tensorflow/lite/tools/optimize:testdata/multi_input_add_reshape.bin", + "//tensorflow/lite/tools/optimize:testdata/pack.bin", + "//tensorflow/lite/tools/optimize:testdata/single_avg_pool_min_minus_5_max_plus_5.bin", + 
"//tensorflow/lite/tools/optimize:testdata/single_conv_no_bias.bin", + "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin", + "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_minus_127_max_plus_127.bin", + "//tensorflow/lite/tools/optimize:testdata/single_softmax_min_minus_5_max_plus_5.bin", + "//tensorflow/lite/tools/optimize:testdata/split.bin", + "//tensorflow/lite/tools/optimize:testdata/svdf_calibrated.bin", + "//tensorflow/lite/tools/optimize:testdata/svdf_quantized.bin", + "//tensorflow/lite/tools/optimize:testdata/transpose.bin", + "//tensorflow/lite/tools/optimize:testdata/unidirectional_sequence_lstm_calibrated.bin", + "//tensorflow/lite/tools/optimize:testdata/unidirectional_sequence_lstm_quantized.bin", + "//tensorflow/lite/tools/optimize:testdata/unpack.bin", ], tags = [ "tflite_not_portable_android", diff --git a/tensorflow/lite/tools/optimize/calibration/BUILD b/tensorflow/lite/tools/optimize/calibration/BUILD index 64f15857bdb..bf4a9b86233 100644 --- a/tensorflow/lite/tools/optimize/calibration/BUILD +++ b/tensorflow/lite/tools/optimize/calibration/BUILD @@ -70,7 +70,9 @@ cc_library( "//tensorflow/lite/kernels:kernel_util", "//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/schema:schema_utils", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "@flatbuffers", ], ) @@ -99,6 +101,7 @@ tf_cc_test( "//tensorflow/lite/kernels:builtin_ops", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", ], ) @@ -125,6 +128,7 @@ cc_test( deps = [ ":logging_op_resolver", "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:builtin_ops", "@com_google_googletest//:gtest", ], ) @@ -139,6 +143,7 @@ cc_library( "//tensorflow/lite:framework", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -148,6 +153,7 @@ cc_library( hdrs = ["calibration_logger.h"], copts = tflite_copts(), deps = [ + "//tensorflow/lite:framework", "//tensorflow/lite:minimal_logging", "//tensorflow/lite/c:common", "//tensorflow/lite/core/api", diff --git a/tensorflow/lite/tools/optimize/operator_property.cc b/tensorflow/lite/tools/optimize/operator_property.cc index b46393b964f..2c993945ea1 100644 --- a/tensorflow/lite/tools/optimize/operator_property.cc +++ b/tensorflow/lite/tools/optimize/operator_property.cc @@ -67,6 +67,9 @@ const OpVariant GetOperatorVariant(const ModelT* model, int subgraph_index, } } // namespace +// Update operation defintions in TensorFlow Lite dialect accordingly when there +// are any needs on updating the kernel support level. 
+// LINT.IfChange OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index, int op_index) { OpVariant op_variant = GetOperatorVariant(model, subgraph_index, op_index); @@ -933,7 +936,6 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index, property.inputs = {{0, {}}}; property.outputs = {{0, {}}}; property.version = 2; - property.quantizable_int16 = false; break; case BuiltinOperator_TANH: { property.inputs = {{0, {}}}; @@ -1002,7 +1004,8 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index, property.quantizable_int16 = false; } return property; -} +} // NOLINT(readability/fn_size) +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/ir/tfl_ops.td) } // namespace operator_property } // namespace optimize diff --git a/tensorflow/lite/tools/optimize/python/BUILD b/tensorflow/lite/tools/optimize/python/BUILD index b25f3c710fe..34f57cbecaf 100644 --- a/tensorflow/lite/tools/optimize/python/BUILD +++ b/tensorflow/lite/tools/optimize/python/BUILD @@ -26,6 +26,7 @@ py_library( deps = [ ":_pywrap_modify_model_interface", ":modify_model_interface_constants", + "//tensorflow:tensorflow_py", "//tensorflow/lite/python:schema_py", ], ) diff --git a/tensorflow/lite/experimental/writer/BUILD b/tensorflow/lite/tools/serialization/BUILD similarity index 100% rename from tensorflow/lite/experimental/writer/BUILD rename to tensorflow/lite/tools/serialization/BUILD diff --git a/tensorflow/lite/experimental/writer/enum_mapping.h b/tensorflow/lite/tools/serialization/enum_mapping.h similarity index 96% rename from tensorflow/lite/experimental/writer/enum_mapping.h rename to tensorflow/lite/tools/serialization/enum_mapping.h index 688ee406125..a79e25d844e 100644 --- a/tensorflow/lite/experimental/writer/enum_mapping.h +++ b/tensorflow/lite/tools/serialization/enum_mapping.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_ +#ifndef TENSORFLOW_LITE_TOOLS_SERIALIZATION_ENUM_MAPPING_H_ +#define TENSORFLOW_LITE_TOOLS_SERIALIZATION_ENUM_MAPPING_H_ #include "tensorflow/lite/builtin_op_data.h" #include "tensorflow/lite/schema/reflection/schema_generated.h" @@ -147,4 +147,4 @@ inline CombinerType CombinerTypeToSchema(TfLiteCombinerType type) { // int } // namespace tflite -#endif // TENSORFLOW_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_ +#endif // TENSORFLOW_LITE_TOOLS_SERIALIZATION_ENUM_MAPPING_H_ diff --git a/tensorflow/lite/experimental/writer/option_writer_generator.cc b/tensorflow/lite/tools/serialization/option_writer_generator.cc similarity index 100% rename from tensorflow/lite/experimental/writer/option_writer_generator.cc rename to tensorflow/lite/tools/serialization/option_writer_generator.cc diff --git a/tensorflow/lite/experimental/writer/writer.cc b/tensorflow/lite/tools/serialization/writer.cc similarity index 96% rename from tensorflow/lite/experimental/writer/writer.cc rename to tensorflow/lite/tools/serialization/writer.cc index 3977c8e1003..fb816792b6a 100644 --- a/tensorflow/lite/experimental/writer/writer.cc +++ b/tensorflow/lite/tools/serialization/writer.cc @@ -20,9 +20,9 @@ limitations under the License. 
#include <iostream> -#include "tensorflow/lite/experimental/writer/writer_lib.h" #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model.h" +#include "tensorflow/lite/tools/serialization/writer_lib.h" int main(int argc, char* argv[]) { if (argc != 3) { diff --git a/tensorflow/lite/experimental/writer/writer_lib.cc b/tensorflow/lite/tools/serialization/writer_lib.cc similarity index 98% rename from tensorflow/lite/experimental/writer/writer_lib.cc rename to tensorflow/lite/tools/serialization/writer_lib.cc index 9f18fff76d5..0d831f5f9a0 100644 --- a/tensorflow/lite/experimental/writer/writer_lib.cc +++ b/tensorflow/lite/tools/serialization/writer_lib.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/writer/writer_lib.h" +#include "tensorflow/lite/tools/serialization/writer_lib.h" #include <cstdlib> #include <cstring> @@ -23,9 +23,9 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/context_util.h" #include "tensorflow/lite/core/subgraph.h" -#include "tensorflow/lite/experimental/writer/enum_mapping.h" #include "tensorflow/lite/schema/reflection/schema_generated.h" #include "tensorflow/lite/schema/schema_utils.h" +#include "tensorflow/lite/tools/serialization/enum_mapping.h" #include "tensorflow/lite/version.h" namespace tflite { @@ -34,7 +34,7 @@ std::pair<BuiltinOptions, flatbuffers::Offset<void>> CreateBuiltinUnion( flatbuffers::FlatBufferBuilder* fbb, enum BuiltinOperator op, void* builtin_op_data, const TfLiteNode& node) { switch (op) { -#include "tensorflow/lite/experimental/writer/option_writer_generated.h" +#include "tensorflow/lite/tools/serialization/option_writer_generated.h" } return std::make_pair(BuiltinOptions_NONE, flatbuffers::Offset<void>()); } diff --git a/tensorflow/lite/experimental/writer/writer_lib.h b/tensorflow/lite/tools/serialization/writer_lib.h similarity index 96% rename from tensorflow/lite/experimental/writer/writer_lib.h rename to tensorflow/lite/tools/serialization/writer_lib.h index f7816dcc33e..a18a3dd0958 100644 --- a/tensorflow/lite/experimental/writer/writer_lib.h +++ b/tensorflow/lite/tools/serialization/writer_lib.h @@ -24,8 +24,8 @@ limitations under the License. // // Build Interpreter however // // ... <omitted> // SubgraphWriter(&interpreter->primary_subgraph()).Write("output.tflite"); -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_ +#ifndef TENSORFLOW_LITE_TOOLS_SERIALIZATION_WRITER_LIB_H_ +#define TENSORFLOW_LITE_TOOLS_SERIALIZATION_WRITER_LIB_H_ #include <iostream> #include <unordered_map> @@ -33,8 +33,8 @@ limitations under the License. 
#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/context_util.h" #include "tensorflow/lite/core/subgraph.h" -#include "tensorflow/lite/experimental/writer/enum_mapping.h" #include "tensorflow/lite/schema/reflection/schema_generated.h" +#include "tensorflow/lite/tools/serialization/enum_mapping.h" #include "tensorflow/lite/version.h" namespace tflite { @@ -149,4 +149,4 @@ class SubgraphWriter { } // namespace tflite -#endif // TENSORFLOW_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_ +#endif // TENSORFLOW_LITE_TOOLS_SERIALIZATION_WRITER_LIB_H_ diff --git a/tensorflow/lite/experimental/writer/writer_lib_test.cc b/tensorflow/lite/tools/serialization/writer_lib_test.cc similarity index 99% rename from tensorflow/lite/experimental/writer/writer_lib_test.cc rename to tensorflow/lite/tools/serialization/writer_lib_test.cc index bf50d4944f1..189b4bc106f 100644 --- a/tensorflow/lite/experimental/writer/writer_lib_test.cc +++ b/tensorflow/lite/tools/serialization/writer_lib_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/writer/writer_lib.h" +#include "tensorflow/lite/tools/serialization/writer_lib.h" #include <numeric> #include <sstream> diff --git a/tensorflow/lite/experimental/writer/writer_test.cc b/tensorflow/lite/tools/serialization/writer_test.cc similarity index 97% rename from tensorflow/lite/experimental/writer/writer_test.cc rename to tensorflow/lite/tools/serialization/writer_test.cc index ac89b74291f..ccaab76776b 100644 --- a/tensorflow/lite/experimental/writer/writer_test.cc +++ b/tensorflow/lite/tools/serialization/writer_test.cc @@ -21,9 +21,9 @@ limitations under the License. 
#include <iostream> -#include "tensorflow/lite/experimental/writer/writer_lib.h" #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model.h" +#include "tensorflow/lite/tools/serialization/writer_lib.h" int main(int argc, char* argv[]) { if (argc != 2) { diff --git a/tensorflow/lite/tools/signature/BUILD b/tensorflow/lite/tools/signature/BUILD index 7fd83562258..05fc106d759 100644 --- a/tensorflow/lite/tools/signature/BUILD +++ b/tensorflow/lite/tools/signature/BUILD @@ -81,7 +81,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":_pywrap_signature_def_util_wrapper", - "//tensorflow/core/protobuf:meta_graph_pyclif_pb2", + "//tensorflow/core:protos_all_py", ], ) @@ -98,7 +98,7 @@ py_test( deps = [ ":signature_def_utils", "//tensorflow:tensorflow_py", - "//tensorflow/core/protobuf:meta_graph_pyclif_pb2", + "//tensorflow/core:protos_all_py", ], ) diff --git a/tensorflow/lite/tools/versioning/BUILD b/tensorflow/lite/tools/versioning/BUILD index 31fb903ce8d..06ac1968f52 100644 --- a/tensorflow/lite/tools/versioning/BUILD +++ b/tensorflow/lite/tools/versioning/BUILD @@ -43,6 +43,7 @@ tf_cc_test( ":versioning", "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_fbs_with_mutable", "@com_google_googletest//:gtest_main", ], ) diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files index 97ab4635cec..bf8bd7ab6bb 100644 --- a/tensorflow/opensource_only.files +++ b/tensorflow/opensource_only.files @@ -5,6 +5,16 @@ tensorflow/compat_template.__init__.py tensorflow/compat_template_v1.__init__.py tensorflow/compiler/mlir/glob_lit_test.bzl tensorflow/go/op/wrappers.go +tensorflow/lite/core/shims/BUILD +tensorflow/lite/core/shims/c/builtin_op_data.h +tensorflow/lite/core/shims/c/c_api.h +tensorflow/lite/core/shims/c/c_api_experimental.h +tensorflow/lite/core/shims/c/common.h +tensorflow/lite/core/shims/cc/interpreter.h +tensorflow/lite/core/shims/cc/interpreter_builder.h +tensorflow/lite/core/shims/cc/kernels/register.h +tensorflow/lite/core/shims/cc/model.h +tensorflow/lite/core/shims/cc/model_builder.h tensorflow/lite/delegates/gpu/cl/compiled_program_cache_generated.h tensorflow/lite/delegates/gpu/cl/serialization_generated.h tensorflow/lite/delegates/gpu/common/task/serialization_base_generated.h diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index e3b016e0201..31c205c638e 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -3246,7 +3246,7 @@ py_library( ], ) -py_test( +cuda_py_test( name = "batch_ops_test", size = "small", srcs = ["ops/batch_ops_test.py"], diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index fa894acf9bd..86252637df1 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -33,7 +33,7 @@ from tensorflow.python.util.tf_export import tf_export # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. 
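For context, the horizon date bumped below is what `forward_compatible()` checks compare against; a minimal sketch of that interaction (the gated branches are purely illustrative):

  from tensorflow.python.compat import compat

  # forward_compatible() compares the requested date against the horizon,
  # shifted by TF_FORWARD_COMPATIBILITY_DELTA_DAYS when that variable is set.
  if compat.forward_compatible(2020, 11, 23):
    pass  # horizon-gated code path (illustrative)
  else:
    pass  # older fallback behavior (illustrative)

  # Tests can temporarily move the horizon forward with the context manager:
  with compat.forward_compatibility_horizon(2020, 12, 1):
    assert compat.forward_compatible(2020, 11, 23)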
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 11, 18) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 11, 23) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None diff --git a/tensorflow/python/compiler/tensorrt/test/annotate_max_batch_sizes_test.py b/tensorflow/python/compiler/tensorrt/test/annotate_max_batch_sizes_test.py index 7eadb001708..8ef97107e33 100644 --- a/tensorflow/python/compiler/tensorrt/test/annotate_max_batch_sizes_test.py +++ b/tensorflow/python/compiler/tensorrt/test/annotate_max_batch_sizes_test.py @@ -59,19 +59,11 @@ class MaxBatchSizesTestBase(trt_test.TfTrtIntegrationTestBase): # engines. return (not run_params.dynamic_engine, 'test static engine only.') - def GetConversionParams(self, run_params): - """Returns a ConversionParams for test.""" - conversion_params = super(MaxBatchSizesTestBase, - self).GetConversionParams(run_params) - conversion_params._replace( - max_batch_size=min(self.max_batch_sizes), maximum_cached_engines=1) - rewrite_config_with_trt = self.GetTrtRewriterConfig( - run_params=run_params, - conversion_params=conversion_params, - use_implicit_batch=True, - disable_non_trt_optimizers=True) - return conversion_params._replace( - rewriter_config_template=rewrite_config_with_trt) + def GetMaxBatchSize(self, run_params): + """Returns the max_batch_size that the converter should use for tests.""" + if run_params.dynamic_engine: + return None + return min(self.max_batch_sizes) def ExpectedEnginesToBuild(self, run_params): """Checks that the expected engine is built. diff --git a/tensorflow/python/compiler/tensorrt/test/base_test.py b/tensorflow/python/compiler/tensorrt/test/base_test.py index 9d2d3abd4fb..b43749fd305 100644 --- a/tensorflow/python/compiler/tensorrt/test/base_test.py +++ b/tensorflow/python/compiler/tensorrt/test/base_test.py @@ -117,18 +117,11 @@ class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase): "TRTEngineOp_1": ["c2", "conv", "div", "weights"] } - def GetConversionParams(self, run_params): - """Return a ConversionParams for test.""" - conversion_params = super(SimpleMultiEnginesTest, - self).GetConversionParams(run_params) - rewrite_config_with_trt = self.GetTrtRewriterConfig( - run_params=run_params, - conversion_params=conversion_params, - # Disable layout optimizer, since it will convert BiasAdd with NHWC - # format to NCHW format under four dimentional input. - disable_non_trt_optimizers=True) - return conversion_params._replace( - rewriter_config_template=rewrite_config_with_trt) + def setUp(self): + super(trt_test.TfTrtIntegrationTestBase, self).setUp() + # Disable layout optimizer, since it will convert BiasAdd with NHWC + # format to NCHW format under four dimentional input. 
+ self.DisableNonTrtOptimizers() class SimpleMultiEnginesTest2(trt_test.TfTrtIntegrationTestBase): diff --git a/tensorflow/python/compiler/tensorrt/test/biasadd_matmul_test.py b/tensorflow/python/compiler/tensorrt/test/biasadd_matmul_test.py index 04445cc99aa..b229cff47dc 100644 --- a/tensorflow/python/compiler/tensorrt/test/biasadd_matmul_test.py +++ b/tensorflow/python/compiler/tensorrt/test/biasadd_matmul_test.py @@ -106,19 +106,18 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase): return self.BuildParams(self.GraphFn, dtypes.float32, [[4, 144]], [[4, 6680]]) - def GetConversionParams(self, run_params): - """Return a ConversionParams for test.""" - conversion_params = super(BiasaddMatMulTest, - self).GetConversionParams(run_params) - conversion_params._replace(max_batch_size=4, maximum_cached_engines=1) - rewrite_config_with_trt = self.GetTrtRewriterConfig( - run_params=run_params, - conversion_params=conversion_params, - # Disable layout optimizer, since it will convert BiasAdd with NHWC - # format to NCHW format under four dimensional input. - disable_non_trt_optimizers=True) - return conversion_params._replace( - rewriter_config_template=rewrite_config_with_trt) + def setUp(self): + super(trt_test.TfTrtIntegrationTestBase, self).setUp() + # Disable layout optimizer, since it will convert BiasAdd with NHWC + # format to NCHW format under four dimentional input. + self.DisableNonTrtOptimizers() + + def GetMaxBatchSize(self, run_params): + """Returns the max_batch_size that the converter should use for tests.""" + if run_params.dynamic_engine: + return None + + return 4 def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" diff --git a/tensorflow/python/compiler/tensorrt/test/combined_nms_test.py b/tensorflow/python/compiler/tensorrt/test/combined_nms_test.py index 23397d76cc3..6d45d358b82 100644 --- a/tensorflow/python/compiler/tensorrt/test/combined_nms_test.py +++ b/tensorflow/python/compiler/tensorrt/test/combined_nms_test.py @@ -91,13 +91,10 @@ class CombinedNmsTest(trt_test.TfTrtIntegrationTestBase): } def ShouldRunTest(self, run_params): - # There is no CombinedNonMaxSuppression op for GPU at the moment, so - # calibration will fail. - # TODO(laigd): fix this. + should_run, reason = super().ShouldRunTest(run_params) # Only run for TRT 5.1 and above. - return trt_test.IsTensorRTVersionGreaterEqual( - 5, 1) and not trt_test.IsQuantizationMode( - run_params.precision_mode), 'test >=TRT5.1 and non-INT8' + return should_run and trt_test.IsTensorRTVersionGreaterEqual( + 5, 1), reason + ' and >=TRT5.1' class CombinedNmsExecuteNativeSegmentTest(CombinedNmsTest): @@ -110,15 +107,17 @@ class CombinedNmsExecuteNativeSegmentTest(CombinedNmsTest): super().tearDown() os.environ['TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION'] = 'False' - def GetConversionParams(self, run_params): - conversion_param = super().GetConversionParams(run_params) + def GetMaxBatchSize(self, run_params): + """Returns the max_batch_size that the converter should use for tests.""" + if run_params.dynamic_engine: + return None + # Build the engine with the allowed max_batch_size less than the actual # max_batch_size, to fore the runtime to execute the native segment. This # is to test that combined_non_max_suppression, which doesn't have a TF GPU # implementation, can be executed natively even though the it is in the # the graph for the TRTEngineOp with a GPU as a default device. 
- return conversion_param._replace( - max_batch_size=conversion_param.max_batch_size - 1) + return super().GetMaxBatchSize(run_params) - 1 def ShouldRunTest(self, run_params): should_run, reason = super().ShouldRunTest(run_params) diff --git a/tensorflow/python/compiler/tensorrt/test/dynamic_input_shapes_test.py b/tensorflow/python/compiler/tensorrt/test/dynamic_input_shapes_test.py index 95dbe727ac3..2cf9abe8455 100644 --- a/tensorflow/python/compiler/tensorrt/test/dynamic_input_shapes_test.py +++ b/tensorflow/python/compiler/tensorrt/test/dynamic_input_shapes_test.py @@ -80,19 +80,11 @@ class DynamicInputShapesTest(trt_test.TfTrtIntegrationTestBase): input_dims=input_dims, expected_output_dims=expected_output_dims) - def GetConversionParams(self, run_params): - """Return a ConversionParams for test.""" - conversion_params = super(DynamicInputShapesTest, - self).GetConversionParams(run_params) - conversion_params._replace(maximum_cached_engines=10) - rewrite_config_with_trt = self.GetTrtRewriterConfig( - run_params=run_params, - conversion_params=conversion_params, - # Disable layout optimizer, since it will convert BiasAdd with NHWC - # format to NCHW format under four dimensional input. - disable_non_trt_optimizers=True) - return conversion_params._replace( - rewriter_config_template=rewrite_config_with_trt) + def setUp(self): + super(trt_test.TfTrtIntegrationTestBase, self).setUp() + # Disable layout optimizer, since it will convert BiasAdd with NHWC + # format to NCHW format under four dimentional input. + self.DisableNonTrtOptimizers() def ExpectedEnginesToBuild(self, run_params): return ["TRTEngineOp_0"] diff --git a/tensorflow/python/compiler/tensorrt/test/int32_test.py b/tensorflow/python/compiler/tensorrt/test/int32_test.py index ecc68656a60..638dbb5727a 100644 --- a/tensorflow/python/compiler/tensorrt/test/int32_test.py +++ b/tensorflow/python/compiler/tensorrt/test/int32_test.py @@ -46,19 +46,17 @@ class ExcludeUnsupportedInt32Test(trt_test.TfTrtIntegrationTestBase): def GetParams(self): return self.BuildParams(self.GraphFn, dtypes.int32, [[100, 4]], [[100, 10]]) - def GetConversionParams(self, run_params): - """Return a ConversionParams for test.""" - conversion_params = super(ExcludeUnsupportedInt32Test, - self).GetConversionParams(run_params) - conversion_params._replace(max_batch_size=100, maximum_cached_engines=1) - rewrite_config_with_trt = self.GetTrtRewriterConfig( - run_params=run_params, - conversion_params=conversion_params, - # Disable layout optimizer, since it will convert BiasAdd with NHWC - # format to NCHW format under four dimensional input. - disable_non_trt_optimizers=True) - return conversion_params._replace( - rewriter_config_template=rewrite_config_with_trt) + def setUp(self): + super(trt_test.TfTrtIntegrationTestBase, self).setUp() + # Disable layout optimizer, since it will convert BiasAdd with NHWC + # format to NCHW format under four dimentional input. 
+ self.DisableNonTrtOptimizers() + + def GetMaxBatchSize(self, run_params): + """Returns the max_batch_size that the converter should use for tests.""" + if run_params.dynamic_engine: + return None + return 100 def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" diff --git a/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py index 2265a19cf62..fc32e7e3b3d 100644 --- a/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py +++ b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py @@ -159,6 +159,8 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): def __init__(self, methodName="runTest"): # pylint: disable=invalid-name super(TfTrtIntegrationTestBase, self).__init__(methodName) self._trt_test_params = None + self._disable_non_trt_optimizers = False + self._use_implicit_batch = True def setUp(self): """Setup method.""" @@ -257,12 +259,33 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): input_dims=[input_shapes] + extra_inputs, expected_output_dims=[output_shapes] + extra_outputs) + def DisableNonTrtOptimizers(self): + self._disable_non_trt_optimizers = True + + def DisableImplicitBatchMode(self): + self._use_implicit_batch = False + def GetParams(self): - """Return a TfTrtIntegrationTestParams for test, implemented by subclass.""" + """Returns a TfTrtIntegrationTestParams for the test.""" raise NotImplementedError() def GetConversionParams(self, run_params): - """Return a TrtConversionParams for test.""" + """Returns a TrtConversionParams for test.""" + conversion_params = trt_convert.TrtConversionParams( + # We use the minimum of all the batch sizes, so when multiple different + # input shapes are provided it'll always create new engines in the + # cache, and we can therefore test the cache behavior. + max_workspace_size_bytes=1 << 25, + precision_mode=run_params.precision_mode, + minimum_segment_size=2, + maximum_cached_engines=1, + use_calibration=run_params.use_calibration) + return conversion_params + + def GetMaxBatchSize(self, run_params): + """Returns the max_batch_size that the converter should use for tests.""" + if run_params.dynamic_engine: + return None batch_list = [] for dims_list in self._GetParamsCached().input_dims: assert dims_list @@ -270,33 +293,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): input_batches = [dims[0] for dims in dims_list] assert max(input_batches) == min(input_batches) batch_list.append(input_batches[0]) - conversion_params = trt_convert.TrtConversionParams( - # We use the minimum of all the batch sizes, so when multiple different - # input shapes are provided it'll always create new engines in the - # cache, and we can therefore test the cache behavior. 
- rewriter_config_template=None, - max_workspace_size_bytes=1 << 25, - precision_mode=run_params.precision_mode, - minimum_segment_size=2, - is_dynamic_op=run_params.dynamic_engine, - maximum_cached_engines=1, - use_calibration=run_params.use_calibration, - max_batch_size=max(batch_list)) - return conversion_params - - def GetTrtRewriterConfig(self, - run_params, - conversion_params, - disable_non_trt_optimizers=False, - use_implicit_batch=True): - rewriter_config = trt_convert.get_tensorrt_rewriter_config( - conversion_params=conversion_params, - is_v2=run_params.is_v2, - disable_non_trt_optimizers=disable_non_trt_optimizers) - for optimizer in rewriter_config.custom_optimizers: - if optimizer.name == "TensorRTOptimizer": - optimizer.parameter_map["use_implicit_batch"].b = use_implicit_batch - return rewriter_config + return max(batch_list) def ShouldRunTest(self, run_params): """Whether to run the test.""" @@ -333,14 +330,15 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): def _GetConfigProto(self, run_params, graph_state): """Get config proto based on specific settings.""" conversion_params = self.GetConversionParams(run_params) + max_batch_size = self.GetMaxBatchSize(run_params) if graph_state == GraphState.INFERENCE and run_params.convert_online: - rewriter_cfg = trt_convert.get_tensorrt_rewriter_config(conversion_params) + rewriter_cfg = trt_convert.get_tensorrt_rewriter_config( + conversion_params, + is_dynamic_op=run_params.dynamic_engine, + max_batch_size=max_batch_size) graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg) else: graph_options = config_pb2.GraphOptions() - if conversion_params.rewriter_config_template is not None: - graph_options.rewrite_options.CopyFrom( - conversion_params.rewriter_config_template) config = config_pb2.ConfigProto( gpu_options=self._GetGPUOptions(), graph_options=graph_options) @@ -444,30 +442,37 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): config = self._GetConfigProto(run_params, GraphState.INFERENCE) return self._RunGraphV1(saved_model_dir, inputs_data, config, num_runs) - def _CreateConverter(self, run_params, saved_model_dir, session_config, - conversion_params): - """Return a TrtGraphConverter.""" + def _CreateConverter(self, run_params, saved_model_dir, conversion_params): + """Returns a TrtGraphConverter.""" if run_params.is_v2: - return trt_convert.TrtGraphConverterV2( + converter_v2 = trt_convert.TrtGraphConverterV2( input_saved_model_dir=saved_model_dir, conversion_params=conversion_params) - return trt_convert.TrtGraphConverter( + if self._disable_non_trt_optimizers: + converter_v2._test_only_disable_non_trt_optimizers = True # pylint: disable=protected-access + if not self._use_implicit_batch: + converter_v2._test_only_use_implicit_batch = False # pylint: disable=protected-access + return converter_v2 + + converter_v1 = trt_convert.TrtGraphConverter( input_saved_model_dir=saved_model_dir, - session_config=session_config, - max_batch_size=conversion_params.max_batch_size, + max_batch_size=self.GetMaxBatchSize(run_params), max_workspace_size_bytes=conversion_params.max_workspace_size_bytes, precision_mode=conversion_params.precision_mode, minimum_segment_size=conversion_params.minimum_segment_size, - is_dynamic_op=conversion_params.is_dynamic_op, + is_dynamic_op=run_params.dynamic_engine, maximum_cached_engines=conversion_params.maximum_cached_engines, use_calibration=conversion_params.use_calibration) + if self._disable_non_trt_optimizers: + 
converter_v1._test_only_disable_non_trt_optimizers = True # pylint: disable=protected-access + return converter_v1 def _GetCalibratedInferGraph(self, run_params, saved_model_dir, inputs_data): """Return trt converted graphdef in INT8 mode.""" conversion_params = self.GetConversionParams(run_params) logging.info(conversion_params) assert conversion_params.precision_mode == "INT8" - assert conversion_params.is_dynamic_op + assert run_params.dynamic_engine assert conversion_params.maximum_cached_engines == 1 assert conversion_params.use_calibration @@ -475,11 +480,8 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): # TODO(aaroey): fix this. assert len(inputs_data) == 1 - session_config = self._GetConfigProto(run_params, GraphState.CALIBRATE) - logging.info("Running calibration graph, config:\n%s", str(session_config)) - converter = self._CreateConverter(run_params, saved_model_dir, - session_config, conversion_params) + conversion_params) int8_gdef = converter.convert() self._VerifyGraphDef(run_params, saved_model_dir, int8_gdef, GraphState.CALIBRATE) @@ -498,15 +500,11 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): conversion_params = self.GetConversionParams(run_params) logging.info(conversion_params) - session_config = self._GetConfigProto(run_params, GraphState.INFERENCE) - logging.info("Creating TRT graph for inference, config\n%s", - str(session_config)) converter = self._CreateConverter(run_params, saved_model_dir, - session_config, conversion_params) + conversion_params) converter.convert() - if trt_convert.is_explicit_batch_mode_enabled( - conversion_params.rewriter_config_template): + if not self._use_implicit_batch: logging.info("Using build mode") def _BuildInputFn(): @@ -684,7 +682,8 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): original_gdef: GraphDef. The graph def before TensorRT conversion. converted_gdef: GraphDef. The graph def after TensorRT conversion. default_max_batch_size: The default maximum batch size to use if no node - inside a segment is annoted with a customized max batch size. + inside a segment is annoted with a customized max batch size. This value + is None when the graph is converted to TF-TRT with dynamic engines. expected_max_batch_sizes: Optional. A sequence of max batch sizes for all the engines. `None` if does not check enforce max batch sizes. 
""" @@ -780,7 +779,8 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): node_max_batch_size is None) logging.info("'{%s}'s max batch size is %d.", engine_name, engine_max_batch_size) - self.assertTrue(engine_max_batch_size == default_max_batch_size or + self.assertTrue(default_max_batch_size is None or + engine_max_batch_size == default_max_batch_size or not node_max_batch_size_all_none) self.assertCountEqual(expected_engines, tuple(name_to_engines_map.keys())) @@ -851,9 +851,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): original_gdef=original_gdef, converted_gdef=gdef_to_verify, expected_max_batch_sizes=self.ExpectedMaxBatchSizes(run_params), - default_max_batch_size=self.GetConversionParams( - run_params).max_batch_size, - ) + default_max_batch_size=self.GetMaxBatchSize(run_params)) def _VerifyGraphDefV2(self, run_params, original_gdef, gdef_to_verify, graph_state): @@ -886,14 +884,6 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): expected_engines = set(expected_engines.keys()) self.assertEqual(set(expected_engines), trt_op_names) - self._VerifyMaxBatchSizeAnnotations( - expected_engines=expected_engines, - original_gdef=original_gdef, - converted_gdef=gdef_to_verify, - expected_max_batch_sizes=self.ExpectedMaxBatchSizes(run_params), - default_max_batch_size=self.GetConversionParams( - run_params).max_batch_size, - ) def _VerifyGraphDef(self, run_params, original_gdef_or_saved_model_dir, gdef_or_saved_model_dir_to_verify, graph_state): diff --git a/tensorflow/python/compiler/tensorrt/test/trt_mode_test.py b/tensorflow/python/compiler/tensorrt/test/trt_mode_test.py index 5c7f261fa98..f5eb3c75653 100644 --- a/tensorflow/python/compiler/tensorrt/test/trt_mode_test.py +++ b/tensorflow/python/compiler/tensorrt/test/trt_mode_test.py @@ -59,25 +59,6 @@ class TrtModeTestBase(trt_test.TfTrtIntegrationTestBase): return self.BuildParams(self.GraphFn, dtypes.float32, [[1, 12, 5]], [[12, 5]]) - def GetConversionParams(self, - run_params, - max_batch_size=0, - implicit_batch=False): - """Return a TrtConversionParams for test.""" - - conversion_params = super(TrtModeTestBase, - self).GetConversionParams(run_params) - # If max_batch_size!=0, use the value for conversion_params. - if max_batch_size and implicit_batch: - conversion_params = conversion_params._replace( - max_batch_size=max_batch_size) - - rewriter_config = self.GetTrtRewriterConfig( - run_params=run_params, - conversion_params=conversion_params, - use_implicit_batch=implicit_batch) - return conversion_params._replace(rewriter_config_template=rewriter_config) - @classmethod def setUpClass(cls): if cls is TrtModeTestBase: @@ -87,12 +68,13 @@ class TrtModeTestBase(trt_test.TfTrtIntegrationTestBase): class ImplicitBatchTest(TrtModeTestBase): - def GetConversionParams(self, run_params): - """Return a TrtConversionParams for test using implicit batch mdoe.""" + def GetMaxBatchSize(self, run_params): + if run_params.dynamic_engine: + return None + # The first dimension of the input is squeezed and the batch size for the # rest OPs is 12. - return super(ImplicitBatchTest, - self).GetConversionParams(run_params, 12, True) + return 12 def ExpectedEnginesToBuild(self, run_params): """Check that the expected engine is built. 
@@ -123,11 +105,6 @@ class ExplicitBatchTest(TrtModeTestBase): extra_inputs=[], extra_outputs=[]) - def GetConversionParams(self, run_params): - """Return a TrtConversionParams for test that enables explicit batch.""" - return super(ExplicitBatchTest, self).GetConversionParams( - run_params, implicit_batch=False) - def ExpectedEnginesToBuild(self, run_params): """Check that the expected engine is built. @@ -146,6 +123,11 @@ class ExplicitBatchTest(TrtModeTestBase): return run_params.is_v2 and trt_test.IsTensorRTVersionGreaterEqual(6) and ( not run_params.use_calibration), "test v2, >=TRT6 and non-calibration" + def setUp(self): + super().setUp() + # Diable implicit batch mode for testing explicit batch mode. + self.DisableImplicitBatchMode() + class DynamicShapesTest(TrtModeTestBase): """Test with dynamic input shapes. @@ -169,10 +151,6 @@ class DynamicShapesTest(TrtModeTestBase): input_mask=[[False, False, False]], output_mask=[[False, False]]) - def GetConversionParams(self, run_params): - """Return a TrtConversionParams for test that enables explicit batch.""" - return super(DynamicShapesTest, self).GetConversionParams(run_params, False) - def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" return ["TRTEngineOp_0"] @@ -182,6 +160,9 @@ class DynamicShapesTest(TrtModeTestBase): return run_params.is_v2 and trt_test.IsTensorRTVersionGreaterEqual(6) and ( not run_params.use_calibration), "test v2 >=TRT6 and non-calibration" + def setUp(self): + super().setUp() + self.DisableImplicitBatchMode() if __name__ == "__main__": test.main() diff --git a/tensorflow/python/compiler/tensorrt/trt_convert.py b/tensorflow/python/compiler/tensorrt/trt_convert.py index 315da90598d..734b13aad1e 100644 --- a/tensorflow/python/compiler/tensorrt/trt_convert.py +++ b/tensorflow/python/compiler/tensorrt/trt_convert.py @@ -117,16 +117,12 @@ DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES = 1 << 30 @tf_export("experimental.tensorrt.ConversionParams", v1=[]) class TrtConversionParams( collections.namedtuple("TrtConversionParams", [ - "rewriter_config_template", "max_workspace_size_bytes", - "precision_mode", "minimum_segment_size", "is_dynamic_op", - "maximum_cached_engines", "use_calibration", "max_batch_size", - "allow_build_at_runtime" + "max_workspace_size_bytes", "precision_mode", "minimum_segment_size", + "maximum_cached_engines", "use_calibration", "allow_build_at_runtime" ])): """Parameters that are used for TF-TRT conversion. Fields: - rewriter_config_template: a template RewriterConfig proto used to create a - TRT-enabled RewriterConfig. If None, it will use a default one. max_workspace_size_bytes: the maximum GPU temporary memory which the TRT engine can use at execution time. This corresponds to the 'workspaceSize' parameter of nvinfer1::IBuilder::setMaxWorkspaceSize(). @@ -134,11 +130,6 @@ class TrtConversionParams( TrtPrecisionMode.supported_precision_modes(). minimum_segment_size: the minimum number of nodes required for a subgraph to be replaced by TRTEngineOp. - is_dynamic_op: whether to generate dynamic TRT ops which will build the - TRT network and engine at run time. i.e. Since TensorRT version < 6.0 - does not support dynamic dimensions other than the batch dimension, when - the TensorFlow graph has a non-batch dimension of dynamic size, we would - need to enable this option. This option should be set to True in TF 2.0. maximum_cached_engines: max number of cached TRT engines for dynamic TRT ops. Created TRT engines for a dynamic dimension are cached. 
This is the maximum number of engines that can be cached. If the number of cached @@ -154,31 +145,23 @@ class TrtConversionParams( will occur. Please note that accuracy may be negatively affected if there is a mismatch between which tensors TRT quantizes and which tensors were trained with fake quantization. - max_batch_size: max size for the input batch. This parameter is only - effective when use_implicit_batch is true. allow_build_at_runtime: whether to build TensorRT engines during runtime. If no TensorRT engine can be found in cache that can handle the given inputs during runtime, then a new TensorRT engine is built at runtime if - allow_build_at_runtime=True, and otherwise native TF is used. This - argument is only effective if is_dynamic_op=True. + allow_build_at_runtime=True, and otherwise native TF is used. """ def __new__(cls, - rewriter_config_template=None, max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES, precision_mode=TrtPrecisionMode.FP32, minimum_segment_size=3, - is_dynamic_op=True, maximum_cached_engines=1, use_calibration=True, - max_batch_size=1, allow_build_at_runtime=True): return super(TrtConversionParams, - cls).__new__(cls, rewriter_config_template, - max_workspace_size_bytes, precision_mode, - minimum_segment_size, is_dynamic_op, - maximum_cached_engines, use_calibration, - max_batch_size, allow_build_at_runtime) + cls).__new__(cls, max_workspace_size_bytes, precision_mode, + minimum_segment_size, maximum_cached_engines, + use_calibration, allow_build_at_runtime) DEFAULT_TRT_CONVERSION_PARAMS = TrtConversionParams() @@ -203,51 +186,6 @@ def _check_conversion_params(conversion_params, is_v2=False): ("precision mode '{}' is not supported." "It should be one of {}").format(conversion_params.precision_mode, supported_precision_modes)) - if is_v2: - # Static mode (building TRT engine without executing the op) is deprecated - # in TF 2.0. See TrtGraphConverterV2 for more details. - if not conversion_params.is_dynamic_op: - raise ValueError("Option is_dynamic_op=False is not supported in TF 2.0, " - "please set it to True instead.") - - if conversion_params.rewriter_config_template: - rewriter_cfg = conversion_params.rewriter_config_template - trt_optimizer = None - for optimizer in rewriter_cfg.custom_optimizers: - if optimizer.name == "TensorRTOptimizer": - if trt_optimizer: - raise ValueError( - "Found more than one TensorRTOptimizer in " - "rewriter_config_template while only one is allowed.") - trt_optimizer = optimizer - # If rewriter_config_template is set, it should include TensorRTOptimizer. - # It is possible to remove this requirement if needed. - if not trt_optimizer: - raise ValueError( - "Found no TensorRTOptimizer in rewriter_config_template.") - if not trt_optimizer.parameter_map: - raise ValueError("Found no parameter_map in TensorRTOptimizer.") - if ("precision_mode" in trt_optimizer.parameter_map.keys() and - trt_optimizer.parameter_map["precision_mode"].s not in map( - _to_bytes, supported_precision_modes)): - raise ValueError(("precision_mode '{}' is not supported. " - "It should be one of {}").format( - trt_optimizer.parameter_map["precision_mode"], - supported_precision_modes)) - if is_v2: - # Static mode (building TRT engine without executing the op) is not - # supported in TF 2.0. See TrtGraphConverterV2 for more details. 
- if ("is_dynamic_op" in trt_optimizer.parameter_map.keys() and - not trt_optimizer.parameter_map["is_dynamic_op"]): - raise ValueError("Option is_dynamic_op=False is not supported " - "in TF 2.0, please set it to True instead.") - if (conversion_params.allow_build_at_runtime and - not conversion_params.is_dynamic_op): - tf_logging.warn( - ("Building TensorRT engines at runtime is not supported " - "if is_dynamic_op=False, therefore assuming " - "allow_build_at_runtime=False. If building TensorRT engines " - "at runtime is desired, set is_dynamic_op=True.")) def _check_trt_version_compatibility(): @@ -290,15 +228,21 @@ def _check_trt_version_compatibility(): " minor/patch upgrades are backward compatible") -def get_tensorrt_rewriter_config(conversion_params, - is_v2=False, - disable_non_trt_optimizers=False): +def _get_tensorrt_rewriter_config(conversion_params, + is_dynamic_op=None, + max_batch_size=None, + is_v2=False, + disable_non_trt_optimizers=False, + use_implicit_batch=True): """Returns a RewriterConfig proto for TRT transformation. Args: conversion_params: a TrtConversionParams instance. + is_dynamic_op: whether to use dynamic engines. + max_batch_size: maximum batch size for static engines. is_v2: whether we're getting a RewriterConfig for TF 2.0. disable_non_trt_optimizers: Turn off all default Grappler optimizers. + use_implicit_batch: Whether to use implicit batch or explicit batch. Returns: A RewriterConfig proto which sets a TensorRTOptimizer to run Grappler. @@ -307,46 +251,49 @@ def get_tensorrt_rewriter_config(conversion_params, TypeError: if any of the parameters are of unexpected type. ValueError: if any of the parameters are of unexpected value. """ - if conversion_params.rewriter_config_template is not None and not isinstance( - conversion_params.rewriter_config_template, - rewriter_config_pb2.RewriterConfig): - raise TypeError( - "rewriter_config_template should be a RewriterConfig proto.") _check_conversion_params(conversion_params, is_v2=is_v2) + if is_v2 and is_dynamic_op is not None and not is_dynamic_op: + raise ValueError("is_dynamic_op is either None or True for TF2") + if not is_v2 and is_dynamic_op is None: + raise ValueError("is_dynamic_op can't be None for TF1") + if (is_dynamic_op is None or is_dynamic_op) and max_batch_size is not None: + raise ValueError("max_batch_size has to be None for TF2" + " or when is_dynamic_op == True in TF1") + if is_dynamic_op is not None and not is_dynamic_op and not isinstance( + max_batch_size, int): + raise ValueError( + "max_batch_size has to be an integer for is_dynamic_op==False in TF1") rewriter_config_with_trt = rewriter_config_pb2.RewriterConfig() - if conversion_params.rewriter_config_template is None: - if not disable_non_trt_optimizers: - # Layout optimizer may add Const nodes followed by Reshape nodes, thus we - # need to run constant folding again. - rewriter_config_with_trt.optimizers.extend( - ["constfold", "layout", "constfold"]) - rewriter_config_with_trt.meta_optimizer_iterations = ( - rewriter_config_pb2.RewriterConfig.ONE) - optimizer = rewriter_config_with_trt.custom_optimizers.add() - # Add a constfold optimizer to cleanup the unused Const nodes. - rewriter_config_with_trt.custom_optimizers.add().name = "constfold" + if not disable_non_trt_optimizers: + # Layout optimizer may add Const nodes followed by Reshape nodes, thus we + # need to run constant folding again. 
+ rewriter_config_with_trt.optimizers.extend( + ["constfold", "layout", "constfold"]) + rewriter_config_with_trt.meta_optimizer_iterations = ( + rewriter_config_pb2.RewriterConfig.ONE) + optimizer = rewriter_config_with_trt.custom_optimizers.add() + # Add a constfold optimizer to cleanup the unused Const nodes. + rewriter_config_with_trt.custom_optimizers.add().name = "constfold" - optimizer.name = "TensorRTOptimizer" - optimizer.parameter_map[ - "minimum_segment_size"].i = conversion_params.minimum_segment_size - optimizer.parameter_map["max_workspace_size_bytes"].i = ( - conversion_params.max_workspace_size_bytes) - optimizer.parameter_map["precision_mode"].s = _to_bytes( - conversion_params.precision_mode) - optimizer.parameter_map[ - "maximum_cached_engines"].i = conversion_params.maximum_cached_engines - optimizer.parameter_map[ - "use_calibration"].b = conversion_params.use_calibration - optimizer.parameter_map["is_dynamic_op"].b = conversion_params.is_dynamic_op - optimizer.parameter_map[ - "allow_build_at_runtime"].b = conversion_params.allow_build_at_runtime - optimizer.parameter_map[ - "max_batch_size"].i = conversion_params.max_batch_size - else: - rewriter_config_with_trt.CopyFrom( - conversion_params.rewriter_config_template) + optimizer.name = "TensorRTOptimizer" + optimizer.parameter_map[ + "minimum_segment_size"].i = conversion_params.minimum_segment_size + optimizer.parameter_map["max_workspace_size_bytes"].i = ( + conversion_params.max_workspace_size_bytes) + optimizer.parameter_map["precision_mode"].s = _to_bytes( + conversion_params.precision_mode) + optimizer.parameter_map[ + "maximum_cached_engines"].i = conversion_params.maximum_cached_engines + optimizer.parameter_map[ + "use_calibration"].b = conversion_params.use_calibration + optimizer.parameter_map["is_dynamic_op"].b = is_dynamic_op + optimizer.parameter_map[ + "allow_build_at_runtime"].b = conversion_params.allow_build_at_runtime + if max_batch_size is not None: + optimizer.parameter_map["max_batch_size"].i = max_batch_size + optimizer.parameter_map["use_implicit_batch"].b = use_implicit_batch # Disabling optimizers should happen after CopyFrom the template # otherwise the template can overwrite the disablement. @@ -371,6 +318,18 @@ def get_tensorrt_rewriter_config(conversion_params, return rewriter_config_with_trt +@deprecation.deprecated( + None, "You shouldn't need a rewriter_config with the current TF-TRT APIs.") +def get_tensorrt_rewriter_config(conversion_params, + is_dynamic_op=None, + max_batch_size=None, + is_v2=False, + disable_non_trt_optimizers=False): + return _get_tensorrt_rewriter_config(conversion_params, is_dynamic_op, + max_batch_size, is_v2, + disable_non_trt_optimizers) + + # Remove all scope prefixes in the node name. In TF 2.0, the same concrete # function can be initialized multiple times with different prefixes, and # this will result in the same TRTEngineOp being initialized multiple times @@ -384,17 +343,6 @@ def _get_canonical_engine_name(name): return name.split("/")[-1] -def is_explicit_batch_mode_enabled(rewriter_config): - """Checks whether explicit batch is enabled by the rewriter config.""" - if rewriter_config is None: - return False - for optimizer in rewriter_config.custom_optimizers: - if optimizer.name == "TensorRTOptimizer": - if "use_implicit_batch" in optimizer.parameter_map: - return not optimizer.parameter_map["use_implicit_batch"].b - return False - - class TrtGraphConverter(object): """A converter for TF-TRT transformation for TF 1.x GraphDef/SavedModels. 
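Under the new signature, `is_dynamic_op` and `max_batch_size` are validated together by the checks above; the sketch below (assuming a `TrtConversionParams` instance named `conversion_params` is in scope, with illustrative values) shows the accepted combinations:

  # TF2, or TF1 with dynamic engines: max_batch_size must stay None.
  cfg_dynamic = _get_tensorrt_rewriter_config(
      conversion_params, is_dynamic_op=None, max_batch_size=None, is_v2=True)

  # TF1 with static engines: an integer max_batch_size is required.
  cfg_static = _get_tensorrt_rewriter_config(
      conversion_params, is_dynamic_op=False, max_batch_size=8, is_v2=False)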
@@ -427,15 +375,12 @@ class TrtGraphConverter(object): ``` """ - @deprecation.deprecated_args(None, "Remove the use of this argument", - "session_config") def __init__(self, input_saved_model_dir=None, input_saved_model_tags=None, input_saved_model_signature_key=None, input_graph_def=None, nodes_denylist=None, - session_config=None, max_batch_size=1, max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES, precision_mode=TrtPrecisionMode.FP32, @@ -443,7 +388,7 @@ class TrtGraphConverter(object): is_dynamic_op=False, maximum_cached_engines=1, use_calibration=True): - """Initialize the converter. + """Initializes the converter. Args: input_saved_model_dir: the directory to load the SavedModel which contains @@ -454,11 +399,7 @@ class TrtGraphConverter(object): input_graph_def: a GraphDef object containing a model to be transformed. If set to None, the graph will be read from the SavedModel loaded from input_saved_model_dir. - nodes_denylist: list of node names to prevent the converter from - touching. - session_config: the ConfigProto used to create a Session. It's also used - as a template to create a TRT-enabled ConfigProto for conversion. If not - specified, a default ConfigProto will be used. + nodes_denylist: list of node names to prevent the converter from touching. max_batch_size: max size for the input batch. max_workspace_size_bytes: the maximum GPU temporary memory which the TRT engine can use at execution time. This corresponds to the @@ -510,7 +451,6 @@ class TrtGraphConverter(object): self._input_saved_model_signature_key = ( input_saved_model_signature_key or signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY) - self._session_config = session_config or config_pb2.ConfigProto() # For calibration usage. self._calibration_graph = None @@ -523,33 +463,39 @@ class TrtGraphConverter(object): "dynamic TRT ops only. Disregarding is_dynamic_op parameter.") is_dynamic_op = True - # TODO(laigd): - # - Verify in int8 mode that maximum_cached_engines is set properly. - # - If it fails to build the int8 engine it should return error. - rewriter_config_template = None - if (session_config and session_config.HasField("graph_options") and - session_config.graph_options.HasField("rewrite_options")): - rewriter_config_template = session_config.graph_options.rewrite_options + self._is_dynamic_op = is_dynamic_op + if is_dynamic_op: + self._max_batch_size = None + if max_batch_size is not None: + tf_logging.warn("When is_dynamic_op==True max_batch_size should be " + "None") + else: + if not isinstance(max_batch_size, int): + raise ValueError("When is_dynamic_op==False max_batch_size should be " + "an integer") + self._max_batch_size = max_batch_size self._conversion_params = TrtConversionParams( - rewriter_config_template=rewriter_config_template, max_workspace_size_bytes=max_workspace_size_bytes, precision_mode=precision_mode, minimum_segment_size=minimum_segment_size, - is_dynamic_op=is_dynamic_op, maximum_cached_engines=maximum_cached_engines, use_calibration=use_calibration, - max_batch_size=max_batch_size, allow_build_at_runtime=True) _check_conversion_params(self._conversion_params) + self._test_only_disable_non_trt_optimizers = False + def _run_conversion(self): """Run Grappler's OptimizeGraph() tool to convert the graph.""" # Create custom ConfigProto for Grappler. 
grappler_session_config = config_pb2.ConfigProto() - grappler_session_config.CopyFrom(self._session_config) - custom_rewriter_config = get_tensorrt_rewriter_config( - conversion_params=self._conversion_params) + custom_rewriter_config = _get_tensorrt_rewriter_config( + conversion_params=self._conversion_params, + is_dynamic_op=self._is_dynamic_op, + max_batch_size=self._max_batch_size, + disable_non_trt_optimizers=self._test_only_disable_non_trt_optimizers, + use_implicit_batch=True) grappler_session_config.graph_options.rewrite_options.CopyFrom( custom_rewriter_config) @@ -596,7 +542,7 @@ class TrtGraphConverter(object): def _convert_saved_model(self): """Convert the input SavedModel.""" graph = ops.Graph() - with session.Session(graph=graph, config=self._session_config) as sess: + with session.Session(graph=graph) as sess: input_meta_graph_def = loader.load(sess, self._input_saved_model_tags, self._input_saved_model_dir) input_signature_def = input_meta_graph_def.signature_def[ @@ -706,9 +652,13 @@ class TrtGraphConverter(object): return_elements=fetch_names, name="") + # Set allow_soft_placement=True to run the graph for calibration so that + # OPs supported by TensorRT but don't have a GPU implementation are allowed + # to execute on CPU. + calibrate_config = config_pb2.ConfigProto(allow_soft_placement=True) with session.Session( graph=self._calibration_graph, - config=self._session_config) as calibration_sess: + config=calibrate_config) as calibration_sess: for _ in range(num_runs): calibration_sess.run( fetches, feed_dict=feed_dict_fn() if feed_dict_fn else None) @@ -823,7 +773,7 @@ class TrtGraphConverter(object): self._collections_to_keep( self._grappler_meta_graph_def.collection_def)) # We don't use any specific converter here. - with session.Session(config=self._session_config) as sess: + with session.Session() as sess: saved_model_builder.add_meta_graph_and_variables( sess, self._input_saved_model_tags, @@ -884,11 +834,6 @@ class TrtGraphConverterV2(object): Currently this is not available on Windows platform. - Note that in V2, is_dynamic_op=False is not supported, meaning TRT engines - will be built only when the corresponding TRTEngineOp is executed. But we - still provide a way to avoid the cost of building TRT engines during inference - (see more below). - There are several ways to run the conversion: 1. FP32/FP16 precision @@ -997,8 +942,6 @@ class TrtGraphConverterV2(object): assert context.executing_eagerly() if conversion_params is None: conversion_params = TrtConversionParams() - elif conversion_params.rewriter_config_template is not None: - tf_logging.warn("the rewrite_config_template field will be deprecated.") _check_trt_version_compatibility() _check_conversion_params(conversion_params, is_v2=True) @@ -1010,22 +953,21 @@ class TrtGraphConverterV2(object): self._input_saved_model_signature_key = ( input_saved_model_signature_key or signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY) - self._rewriter_config = get_tensorrt_rewriter_config( - conversion_params=self._conversion_params, is_v2=True) self._need_calibration = ( conversion_params.precision_mode == TrtPrecisionMode.INT8 and conversion_params.use_calibration) - if (self._need_calibration and not conversion_params.is_dynamic_op): - raise ValueError("INT8 precision mode with calibration is not supported " - "with static TensorRT ops. 
Set is_dynamic_op to True.") - # rewriter_config is already validated - self._need_trt_profiles = is_explicit_batch_mode_enabled( - self._rewriter_config) self._converted = False self._build_called_once = False + # Fields to support TF-TRT testing and shouldn't be used for other purpose. + self._test_only_disable_non_trt_optimizers = False + self._test_only_use_implicit_batch = True + + def _need_trt_profiles(self): + return not self._test_only_use_implicit_batch + def _run_conversion(self, meta_graph_def): """Run Grappler's OptimizeGraph() tool to convert the graph. @@ -1036,8 +978,14 @@ class TrtGraphConverterV2(object): The optimized GraphDef. """ grappler_session_config = config_pb2.ConfigProto() + custom_rewriter_config = _get_tensorrt_rewriter_config( + conversion_params=self._conversion_params, + is_dynamic_op=True, + max_batch_size=None, + disable_non_trt_optimizers=self._test_only_disable_non_trt_optimizers, + use_implicit_batch=self._test_only_use_implicit_batch) grappler_session_config.graph_options.rewrite_options.CopyFrom( - self._rewriter_config) + custom_rewriter_config) return tf_optimizer.OptimizeGraph( grappler_session_config, meta_graph_def, graph_id=b"tf_graph") @@ -1167,7 +1115,7 @@ class TrtGraphConverterV2(object): def _set_profile_generation_mode(value, node): node.attr["_profile_generation_mode"].b = value - if self._need_trt_profiles: + if self._need_trt_profiles(): # Enable profile generation. self._for_each_trt_node(self._converted_graph_def, partial(_set_profile_generation_mode, True)) @@ -1187,7 +1135,7 @@ class TrtGraphConverterV2(object): first_input = inp func(*map(ops.convert_to_tensor, inp)) - if self._need_trt_profiles: + if self._need_trt_profiles(): # Disable profile generation. self._for_each_trt_node(self._converted_graph_def, partial(_set_profile_generation_mode, False)) @@ -1208,7 +1156,7 @@ class TrtGraphConverterV2(object): """ assert self._converted - if self._need_trt_profiles and not self._build_called_once: + if self._need_trt_profiles() and not self._build_called_once: raise NotImplementedError( "build() is not called . Explicit batch mode " "(use_implicit_batch=False) requires generating TensorRT optimization" @@ -1295,8 +1243,7 @@ def create_inference_graph( input_saved_model_dir=None, input_saved_model_tags=None, input_saved_model_signature_key=None, - output_saved_model_dir=None, - session_config=None): + output_saved_model_dir=None): """Python wrapper for the TRT transformation. Args: @@ -1327,9 +1274,6 @@ def create_inference_graph( returned GraphDef and save it to the specified directory. This option only works when the input graph is loaded from a SavedModel, i.e. when input_saved_model_dir is specified and input_graph_def is None. - session_config: the ConfigProto used to create a Session. It's also used as - a template to create a TRT-enabled ConfigProto for conversion. If not - specified, a default ConfigProto will be used. 
Returns: A GraphDef transformed from input_graph_def (or the SavedModel graph def @@ -1359,7 +1303,6 @@ def create_inference_graph( input_saved_model_signature_key=input_saved_model_signature_key, input_graph_def=input_graph_def, nodes_denylist=outputs, - session_config=session_config, max_batch_size=max_batch_size, max_workspace_size_bytes=max_workspace_size_bytes, precision_mode=precision_mode, diff --git a/tensorflow/python/compiler/tensorrt/trt_convert_test.py b/tensorflow/python/compiler/tensorrt/trt_convert_test.py index 0baae0bd3bf..d19f2d03e30 100644 --- a/tensorflow/python/compiler/tensorrt/trt_convert_test.py +++ b/tensorflow/python/compiler/tensorrt/trt_convert_test.py @@ -30,7 +30,6 @@ from tensorflow.compiler.tf2tensorrt._pywrap_py_utils import is_tensorrt_enabled from tensorflow.compiler.tf2tensorrt.utils.trt_engine_instance_pb2 import TRTEngineInstance # pylint: disable=g-importing-member from tensorflow.core.framework import graph_pb2 from tensorflow.core.protobuf import config_pb2 -from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.compiler.tensorrt import trt_convert from tensorflow.python.eager import def_function from tensorflow.python.framework import dtypes @@ -79,98 +78,6 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase): # test if we can access the TRTEngineInstance protobuf assert hasattr(TRTEngineInstance(), "serialized_engine") - def testGetTensorrtRewriterConfig(self): - """Test case for TrtGraphConverter.get_tensorrt_rewriter_config().""" - if not is_tensorrt_enabled(): - return - conversion_params = trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace( - max_batch_size=128, - max_workspace_size_bytes=1234, - precision_mode="INT8", - minimum_segment_size=10, - is_dynamic_op=True, - maximum_cached_engines=2) - rewriter_cfg = trt_convert.get_tensorrt_rewriter_config( - conversion_params=conversion_params, is_v2=True) - self.assertEqual(["constfold", "layout", "constfold"], - rewriter_cfg.optimizers) - self.assertEqual(rewriter_config_pb2.RewriterConfig.ONE, - rewriter_cfg.meta_optimizer_iterations) - trt_optimizer = None - for optimizer in rewriter_cfg.custom_optimizers: - if optimizer.name == "TensorRTOptimizer": - self.assertTrue(trt_optimizer is None) - trt_optimizer = optimizer - self.assertTrue(trt_optimizer is not None) - for key in [ - "minimum_segment_size", "max_batch_size", "is_dynamic_op", - "max_workspace_size_bytes", "precision_mode", "maximum_cached_engines" - ]: - self.assertTrue(key in trt_optimizer.parameter_map) - self.assertEqual(10, trt_optimizer.parameter_map["minimum_segment_size"].i) - self.assertEqual(128, trt_optimizer.parameter_map["max_batch_size"].i) - self.assertEqual(True, trt_optimizer.parameter_map["is_dynamic_op"].b) - self.assertEqual(1234, - trt_optimizer.parameter_map["max_workspace_size_bytes"].i) - self.assertEqual( - trt_convert._to_bytes("INT8"), - trt_optimizer.parameter_map["precision_mode"].s) - self.assertEqual(2, trt_optimizer.parameter_map["maximum_cached_engines"].i) - - def testGetTensorrtRewriterConfigTemplate(self): - """Test case for TrtGraphConverter.get_tensorrt_rewriter_config().""" - if not is_tensorrt_enabled(): - return - - rewriter_config_with_trt = rewriter_config_pb2.RewriterConfig() - rewriter_config_with_trt.optimizers.extend( - ["constfold", "layout", "constfold"]) - rewriter_config_with_trt.meta_optimizer_iterations = ( - rewriter_config_pb2.RewriterConfig.ONE) - optimizer = rewriter_config_with_trt.custom_optimizers.add() - 
rewriter_config_with_trt.custom_optimizers.add().name = "constfold" - optimizer.name = "TensorRTOptimizer" - optimizer.parameter_map["minimum_segment_size"].i = 10 - optimizer.parameter_map["max_batch_size"].i = 128 - optimizer.parameter_map["is_dynamic_op"].b = True - optimizer.parameter_map["max_workspace_size_bytes"].i = 1234 - optimizer.parameter_map["precision_mode"].s = trt_convert._to_bytes( - trt_convert.TrtPrecisionMode.INT8) - optimizer.parameter_map["maximum_cached_engines"].i = 2 - optimizer.parameter_map["use_calibration"].b = False - optimizer.parameter_map["use_implicit_batch"].b = True - - conversion_params = trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace( - rewriter_config_template=rewriter_config_with_trt) - rewriter_cfg = trt_convert.get_tensorrt_rewriter_config( - conversion_params=conversion_params) - self.assertEqual(["constfold", "layout", "constfold"], - rewriter_cfg.optimizers) - self.assertEqual(rewriter_config_pb2.RewriterConfig.ONE, - rewriter_cfg.meta_optimizer_iterations) - trt_optimizer = None - for optimizer in rewriter_cfg.custom_optimizers: - if optimizer.name == "TensorRTOptimizer": - self.assertIsNone(trt_optimizer) - trt_optimizer = optimizer - self.assertIsNotNone(trt_optimizer) - for key in [ - "minimum_segment_size", "max_batch_size", "is_dynamic_op", - "max_workspace_size_bytes", "precision_mode", "maximum_cached_engines" - ]: - self.assertIn(key, trt_optimizer.parameter_map) - self.assertEqual(10, trt_optimizer.parameter_map["minimum_segment_size"].i) - self.assertEqual(128, trt_optimizer.parameter_map["max_batch_size"].i) - self.assertEqual(True, trt_optimizer.parameter_map["is_dynamic_op"].b) - self.assertEqual(1234, - trt_optimizer.parameter_map["max_workspace_size_bytes"].i) - self.assertEqual( - trt_convert._to_bytes("INT8"), - trt_optimizer.parameter_map["precision_mode"].s) - self.assertEqual(2, trt_optimizer.parameter_map["maximum_cached_engines"].i) - self.assertEqual(False, trt_optimizer.parameter_map["use_calibration"].b) - self.assertEqual(True, trt_optimizer.parameter_map["use_implicit_batch"].b) - def _GetConfigProto(self, rewriter_config=None): """Get ConfigProto for session creation.""" config = config_pb2.ConfigProto( @@ -280,13 +187,20 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase): input_saved_model_dir = self.mkdtemp() self._WriteInputSavedModelForV1(input_saved_model_dir, device) + # Calibration requires dynamic_op. + if need_calibration: + is_dynamic_op = True + + # For dynamic_op, the converter requires the unused max_batch_size=None. 
+ if is_dynamic_op: + max_batch_size = None + converter = trt_convert.TrtGraphConverter( input_saved_model_dir=input_saved_model_dir, input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY, input_graph_def=None if input_saved_model_dir else self._GetGraphDefForV1(device), nodes_denylist=None if input_saved_model_dir else ["output"], - session_config=self._GetConfigProto(), max_batch_size=max_batch_size, max_workspace_size_bytes=TrtConvertTest._TRT_MAX_WORKSPACE_SIZE_BYTES, precision_mode=(trt_convert.TrtPrecisionMode.INT8 if need_calibration @@ -437,10 +351,13 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase): return conversion_params = trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace( - precision_mode=trt_convert.TrtPrecisionMode.FP32, is_dynamic_op=True) + precision_mode=trt_convert.TrtPrecisionMode.FP32) config = self._GetConfigProto( rewriter_config=trt_convert.get_tensorrt_rewriter_config( - conversion_params, is_v2=False)) + conversion_params, + is_dynamic_op=False, + max_batch_size=1, + is_v2=False)) with ops.Graph().as_default(): # Online conversion requires a frozen graph, so we reuse inp1 as the var @@ -463,7 +380,6 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase): input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY, max_workspace_size_bytes=10 << 20, # Use a smaller workspace. precision_mode=trt_convert.TrtPrecisionMode.FP32, - is_dynamic_op=True, maximum_cached_engines=2): return trt_convert.TrtGraphConverterV2( input_saved_model_dir=input_saved_model_dir, @@ -471,7 +387,6 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase): conversion_params=trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace( max_workspace_size_bytes=max_workspace_size_bytes, precision_mode=precision_mode, - is_dynamic_op=is_dynamic_op, maximum_cached_engines=maximum_cached_engines)) def _CheckTrtOps(self, concrete_func, check_fn=None): @@ -572,24 +487,6 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase): del root_with_trt gc.collect() # Force GC to destroy the TRT engine cache. - @test_util.run_v2_only - def testTrtGraphConverter_StaticConversionNotSupportedInV2(self): - """Test case for trt_convert.TrtGraphConverter() using static mode.""" - if not is_tensorrt_enabled(): - return - - # Create a model and save it. - input_saved_model_dir = self.mkdtemp() - root = self._GetModelForV2() - save.save(root, input_saved_model_dir, - {_SAVED_MODEL_SIGNATURE_KEY: root.run}) - - # Run TRT conversion. 
- with self.assertRaisesRegex( - ValueError, r"Option is_dynamic_op=False is not supported in TF 2.0, " - "please set it to True instead."): - self._CreateConverterV2(input_saved_model_dir, is_dynamic_op=False) - @test_util.run_v2_only def testTrtGraphConverter_Int8Conversion_v2(self): if not is_tensorrt_enabled(): diff --git a/tensorflow/python/compiler/tensorrt/trt_convert_windows.py b/tensorflow/python/compiler/tensorrt/trt_convert_windows.py index 782d22d3721..0180e389f3b 100644 --- a/tensorflow/python/compiler/tensorrt/trt_convert_windows.py +++ b/tensorflow/python/compiler/tensorrt/trt_convert_windows.py @@ -40,10 +40,15 @@ DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES = 1 << 30 @tf_export("experimental.tensorrt.ConversionParams", v1=[]) -class TrtConversionParams(collections.namedtuple("TrtConversionParams", [ - "rewriter_config_template", "max_workspace_size_bytes", "precision_mode", - "minimum_segment_size", "is_dynamic_op", "maximum_cached_engines", - "use_calibration", "max_batch_size"])): +class TrtConversionParams( + collections.namedtuple("TrtConversionParams", [ + "rewriter_config_template", + "max_workspace_size_bytes", + "precision_mode", + "minimum_segment_size", + "maximum_cached_engines", + "use_calibration", + ])): """Parameters that are used for TF-TRT conversion. Fields: @@ -56,11 +61,6 @@ class TrtConversionParams(collections.namedtuple("TrtConversionParams", [ TrtPrecisionMode.supported_precision_modes(). minimum_segment_size: the minimum number of nodes required for a subgraph to be replaced by TRTEngineOp. - is_dynamic_op: whether to generate dynamic TRT ops which will build the - TRT network and engine at run time. i.e. Since TensorRT version < 6.0 - does not support dynamic dimensions other than the batch dimension, when - the TensorFlow graph has a non-batch dimension of dynamic size, we would - need to enable this option. This option should be set to True in TF 2.0. maximum_cached_engines: max number of cached TRT engines for dynamic TRT ops. Created TRT engines for a dynamic dimension are cached. This is the maximum number of engines that can be cached. If the number of cached @@ -76,8 +76,6 @@ class TrtConversionParams(collections.namedtuple("TrtConversionParams", [ will occur. Please note that accuracy may be negatively affected if there is a mismatch between which tensors TRT quantizes and which tensors were trained with fake quantization. - max_batch_size: max size for the input batch. This parameter is only - effective when is_dynamic_op=False which is not supported in TF 2.0. 
""" def __new__(cls, @@ -87,8 +85,7 @@ class TrtConversionParams(collections.namedtuple("TrtConversionParams", [ minimum_segment_size=3, is_dynamic_op=True, maximum_cached_engines=1, - use_calibration=True, - max_batch_size=1): + use_calibration=True): raise NotImplementedError( "TensorRT integration is not available on Windows.") diff --git a/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py index d428baca9c0..86088488e05 100644 --- a/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py @@ -493,6 +493,36 @@ class AutoShardDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase, ] self.assertDatasetProduces(dataset, list(chunk(expected, 5))) + @combinations.generate(test_base.default_test_combinations()) + def testMaxIntraOpParallelism(self): + dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False) + dataset = dataset.flat_map(core_readers.TFRecordDataset) + dataset = dataset.batch(5) + dataset = dataset_ops._MaxIntraOpParallelismDataset(dataset, 1) + dataset = distribute._AutoShardDataset(dataset, 5, 0) + + expected = [ + b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension + for f in (0, 5) + for r in range(0, 10) + ] + self.assertDatasetProduces(dataset, list(chunk(expected, 5))) + + @combinations.generate(test_base.default_test_combinations()) + def testPrivateThreadpool(self): + dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False) + dataset = dataset.flat_map(core_readers.TFRecordDataset) + dataset = dataset.batch(5) + dataset = dataset_ops._PrivateThreadPoolDataset(dataset, 1) + dataset = distribute._AutoShardDataset(dataset, 5, 0) + + expected = [ + b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension + for f in (0, 5) + for r in range(0, 10) + ] + self.assertDatasetProduces(dataset, list(chunk(expected, 5))) + @combinations.generate(test_base.default_test_combinations()) def testMakeBatchedFeaturesDataset(self): files = 2 diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 20d77bd40ed..b2918076879 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -246,9 +246,6 @@ py_test( srcs = ["distribute_coordinator_test.py"], python_version = "PY3", srcs_version = "PY2AND3", - tags = [ - "no_oss_py2", # b/138443278 - ], deps = [ ":distribute_coordinator", "//tensorflow/core:protos_all_py", diff --git a/tensorflow/python/distribute/coordinator/cluster_coordinator.py b/tensorflow/python/distribute/coordinator/cluster_coordinator.py index ca330b88ed0..e26b829fa85 100644 --- a/tensorflow/python/distribute/coordinator/cluster_coordinator.py +++ b/tensorflow/python/distribute/coordinator/cluster_coordinator.py @@ -176,14 +176,15 @@ class RemoteValueImpl(RemoteValue): """ self._closure = closure self._type_spec = type_spec - self._value = None + self._tensors = None + self._fetched_numpys = None self._error = None self._status_available_event = threading.Event() self._status = _RemoteValueStatus.NOT_READY def _set_aborted(self): self._status = _RemoteValueStatus.ABORTED - self._value = None + self._tensors = None self._error = None # Wake up any waiting thread and clear the event. @@ -194,21 +195,21 @@ class RemoteValueImpl(RemoteValue): # TODO(yuefengz): we may need to rebuild its inputs as well. 
self._closure.execute_on(worker) - def _set_value(self, value): + def _set_tensors(self, tensors): self._status = _RemoteValueStatus.READY - self._value = value + self._tensors = tensors self._error = None self._status_available_event.set() def _set_error(self, exception): self._status = _RemoteValueStatus.READY - self._value = None + self._tensors = None self._error = exception self._status_available_event.set() - def _get_value(self): + def _get_tensors(self): self._status_available_event.wait() - return self._value + return self._tensors def _get_error(self): self._status_available_event.wait() @@ -222,10 +223,11 @@ class RemoteValueImpl(RemoteValue): "The corresponding function is aborted. Please reschedule the " "function.") if self._error is not None: - raise self._error # pylint: disable=raising-bad-type - else: - return nest.map_structure( - lambda x: x.numpy() if hasattr(x, "numpy") else x, self._value) + raise self._error + if self._fetched_numpys is None: + self._fetched_numpys = nest.map_structure( + lambda x: x.numpy() if hasattr(x, "numpy") else x, self._tensors) + return self._fetched_numpys class InputError(Exception): @@ -271,7 +273,7 @@ def _maybe_get_remote_value(val): raise AssertionError( "RemoteValue doesn't have a value because it has errors.") else: - return val._get_value() # pylint: disable=protected-access + return val._get_tensors() # pylint: disable=protected-access else: return val @@ -406,10 +408,10 @@ class Closure(object): with ops.device(worker.device_name): with context.executor_scope(worker.executor): with metric_utils.monitored_timer("closure_execution"): - output_value = self._function( + output_tensors = self._function( *nest.map_structure(_maybe_get_remote_value, replica_args), **nest.map_structure(_maybe_get_remote_value, replica_kwargs)) - self.output_remote_value._set_value(output_value) # pylint: disable=protected-access + self.output_remote_value._set_tensors(output_tensors) # pylint: disable=protected-access class _CoordinatedClosureQueue(object): diff --git a/tensorflow/python/distribute/coordinator/cluster_coordinator_mpr_test.py b/tensorflow/python/distribute/coordinator/cluster_coordinator_mpr_test.py index 8b3e95f1fea..4c5a8e20ca1 100644 --- a/tensorflow/python/distribute/coordinator/cluster_coordinator_mpr_test.py +++ b/tensorflow/python/distribute/coordinator/cluster_coordinator_mpr_test.py @@ -209,6 +209,50 @@ class ClusterCoordinatorMprTest(test.TestCase): any("_test_translate_ps_failure_error ends properly" in msg for msg in mpr.join().stdout)) + def test_numpy_fetched_after_worker_failure(self): + + def fn(first_fetch_occurred_event, worker_terminated_event): + os.environ["GRPC_FAIL_FAST"] = "use_caller" + + cluster_resolver = TFConfigClusterResolver() + if cluster_resolver.task_type != "chief": + utils.start_server(cluster_resolver, "grpc") + strategy = parameter_server_strategy_v2.ParameterServerStrategyV2( + cluster_resolver) + ps_coordinator = coordinator_lib.ClusterCoordinator(strategy) + + with strategy.scope(): + v = variables.Variable(initial_value=0, dtype=dtypes.int32) + + @def_function.function + def worker_fn(): + return v + 1, v - 1 + + remote_value = ps_coordinator.schedule(worker_fn) + logging.info("result (1st fetch): %r", remote_value.fetch()) + first_fetch_occurred_event.set() + worker_terminated_event.wait() + logging.info("result (2nd fetch): %r", remote_value.fetch()) + + manager = multi_process_runner.manager() + first_fetch_occurred_event = manager.Event() + worker_terminated_event = manager.Event() + mpr = 
multi_process_runner.MultiProcessRunner( + fn, + multi_worker_test_base.create_cluster_spec( + has_chief=True, num_workers=1, num_ps=1, has_eval=False), + args=(first_fetch_occurred_event, worker_terminated_event), + rpc_layer="grpc", + return_output=True, + use_dill_for_args=False) + + mpr.start() + first_fetch_occurred_event.wait() + mpr.terminate("worker", 0) + worker_terminated_event.set() + self.assertTrue( + any("result (2nd fetch)" in msg for msg in mpr.join().stdout)) + if __name__ == "__main__": v2_compat.enable_v2_behavior() diff --git a/tensorflow/python/distribute/coordinator/metric_utils_test.py b/tensorflow/python/distribute/coordinator/metric_utils_test.py index abd4221df4d..db223e3aeb8 100644 --- a/tensorflow/python/distribute/coordinator/metric_utils_test.py +++ b/tensorflow/python/distribute/coordinator/metric_utils_test.py @@ -58,7 +58,7 @@ class MetricUtilsTest(test.TestCase): result = cluster.schedule(func, args=None, kwargs=None) result = cluster.schedule(func, args=None, kwargs=None) cluster.join() - self.assertEqual(result._get_value().numpy(), 3) + self.assertEqual(result.fetch(), 3) # Tracing, closure execution, and remote_value fetching should be executed # exactly once for running this function. diff --git a/tensorflow/python/distribute/cross_device_ops_test.py b/tensorflow/python/distribute/cross_device_ops_test.py index 47af4e426df..ca17550626a 100644 --- a/tensorflow/python/distribute/cross_device_ops_test.py +++ b/tensorflow/python/distribute/cross_device_ops_test.py @@ -109,7 +109,7 @@ def enable_collective_ops(): # Recover default flag values. cross_device_ops_lib.CollectiveAllReduce._limited_nccl = True cross_device_utils.CollectiveReplicaLauncher._prefer_scoped_allocator = True - cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = False + cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = True cross_device_utils.CollectiveReplicaLauncher._prefer_ordering_token = False diff --git a/tensorflow/python/distribute/cross_device_utils.py b/tensorflow/python/distribute/cross_device_utils.py index 0993b054963..15bc9d0da7d 100644 --- a/tensorflow/python/distribute/cross_device_utils.py +++ b/tensorflow/python/distribute/cross_device_utils.py @@ -258,7 +258,7 @@ class CollectiveReplicaLauncher(object): """Launch collectives on one replica.""" _prefer_scoped_allocator = True - _prefer_collective_v2 = False + _prefer_collective_v2 = True _prefer_ordering_token = False def __init__(self, diff --git a/tensorflow/python/distribute/multi_process_lib.py b/tensorflow/python/distribute/multi_process_lib.py index 0cb88214cf2..14fe8a43d69 100644 --- a/tensorflow/python/distribute/multi_process_lib.py +++ b/tensorflow/python/distribute/multi_process_lib.py @@ -108,9 +108,10 @@ def _set_spawn_exe_path(): # /.../tensorflow/python/distribute/input_lib_test.py # and the binary is # /.../tensorflow/python/distribute/input_lib_test_multiworker_gpu - org_tensorflow_path = sys.argv[0][:sys.argv[0].rfind('/tensorflow')] + org_tensorflow_base = sys.argv[0][:sys.argv[0].rfind('/org_tensorflow')] binary = os.environ['TEST_TARGET'][2:].replace(':', '/', 1) - possible_path = os.path.join(org_tensorflow_path, binary) + possible_path = os.path.join(org_tensorflow_base, 'org_tensorflow', + binary) logging.info('Guessed test binary path: %s', possible_path) if os.access(possible_path, os.X_OK): path = possible_path diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index cb732d37089..89fdc0c82ad 100644 --- 
a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -375,6 +375,7 @@ cuda_py_test( srcs = ["backprop_test.py"], python_version = "PY3", tags = [ + "no_cuda_asan", # b/173825938 "no_windows", #TODO(b/139745667) "notsan", #TODO(b/139745667) ], diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index edb02b072b1..0063b7f155e 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -1844,6 +1844,24 @@ class JacobianTest(test.TestCase): self.assertAllClose(compute_jacobian(use_pfor=True), compute_jacobian(use_pfor=False)) + def test_cond_func_grad_jacobian(self): + + @def_function.function + def f(x): + y = control_flow_ops.cond(x > 0., lambda: x**3., lambda: x**2.) + return y + + with backprop.GradientTape(persistent=True) as tape: + x = constant_op.constant(1.) + tape.watch(x) + y = f(x) + grad = tape.gradient(y, x) + self.assertAllClose(3., grad) + jacobian = tape.jacobian(grad, x, experimental_use_pfor=False) + self.assertAllClose(6., jacobian) + jacobian_pfor = tape.jacobian(grad, x, experimental_use_pfor=True) + self.assertAllClose(6., jacobian_pfor) + @test_util.run_all_in_graph_and_eager_modes class BatchJacobianTest(test.TestCase, parameterized.TestCase): diff --git a/tensorflow/python/eager/backprop_util.py b/tensorflow/python/eager/backprop_util.py index e1c719d4a9d..96d709fe9c5 100644 --- a/tensorflow/python/eager/backprop_util.py +++ b/tensorflow/python/eager/backprop_util.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.core.framework import types_pb2 from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util @@ -38,8 +39,12 @@ def _DTypeFromTensor(tensor): and handle_data.is_set and handle_data.shape_and_type): first_type = handle_data.shape_and_type[0].dtype - if all(shape_and_type.dtype == first_type - for shape_and_type in handle_data.shape_and_type): + # Some variants have statically unknown dtypes; we can't make inferences + # about trainability, so we conservatively assume they're trainable + # (which may waste memory passing zeros around, but will be correct). 
+ if (first_type != types_pb2.DT_INVALID + and all(shape_and_type.dtype == first_type + for shape_and_type in handle_data.shape_and_type)): return first_type return dtype diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 8c226f6c681..d41f5d736c8 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -243,7 +243,7 @@ struct FastPathOpExecInfo { #if PY_MAJOR_VERSION >= 3 PARSE_VALUE(ParseIntValue, int, PyLong_Check, PyLong_AsLong) -PARSE_VALUE(ParseInt64Value, int64_t, PyLong_Check, PyLong_AsLong) +PARSE_VALUE(ParseInt64Value, int64_t, PyLong_Check, PyLong_AsLongLong) #else PARSE_VALUE(ParseIntValue, int, PyInt_Check, PyInt_AsLong) #endif diff --git a/tensorflow/python/eager/pywrap_tfe_test.py b/tensorflow/python/eager/pywrap_tfe_test.py index 324a314a540..af0f27cef09 100644 --- a/tensorflow/python/eager/pywrap_tfe_test.py +++ b/tensorflow/python/eager/pywrap_tfe_test.py @@ -368,5 +368,19 @@ class Tests(test.TestCase): self.assertNotRegex(full_exception_text, "_FallbackException") + def testIntAttrThatDoesNotFitIn32Bits(self): + # Tests bug where int attributes >= 2**31 raised an exception on platforms + # where sizeof(long) = 32 bits. + ctx = context.context() + ctx.ensure_initialized() + shape = constant_op.constant([10]) + minval = constant_op.constant(0) + maxval = constant_op.constant(10) + seed = 2**50 + pywrap_tfe.TFE_Py_FastPathExecute(ctx, "RandomUniformInt", None, + shape, minval, maxval, + "seed", seed) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/framework/tensor_spec.py b/tensorflow/python/framework/tensor_spec.py index b5275372447..19947121cd5 100644 --- a/tensorflow/python/framework/tensor_spec.py +++ b/tensorflow/python/framework/tensor_spec.py @@ -56,10 +56,6 @@ class DenseSpec(type_spec.TypeSpec): self._dtype = dtypes.as_dtype(dtype) self._name = name - @classmethod - def from_spec(cls, spec, name=None): - return cls(spec.shape, spec.dtype, name or spec.name) - @property def shape(self): """Returns the `TensorShape` that represents the shape of the tensor.""" @@ -141,8 +137,34 @@ class TensorSpec(DenseSpec, type_spec.BatchableTypeSpec): """ return super(TensorSpec, self).is_compatible_with(spec_or_tensor) + @classmethod + def from_spec(cls, spec, name=None): + """Returns a `TensorSpec` with the same shape and dtype as `spec`. + + >>> spec = tf.TensorSpec(shape=[8, 3], dtype=tf.int32, name="OriginalName") + >>> tf.TensorSpec.from_spec(spec, "NewName") + TensorSpec(shape=(8, 3), dtype=tf.int32, name='NewName') + + Args: + spec: The `TypeSpec` used to create the new `TensorSpec`. + name: The name for the new `TensorSpec`. Defaults to `spec.name`. + """ + return cls(spec.shape, spec.dtype, name or spec.name) + @classmethod def from_tensor(cls, tensor, name=None): + """Returns a `TensorSpec` that describes `tensor`. + + >>> tf.TensorSpec.from_tensor(tf.constant([1, 2, 3])) + TensorSpec(shape=(3,), dtype=tf.int32, name=None) + + Args: + tensor: The `tf.Tensor` that should be described. + name: A name for the `TensorSpec`. Defaults to `tensor.op.name`. + + Returns: + A `TensorSpec` that describes `tensor`. 
+ """ if isinstance(tensor, ops.EagerTensor): return TensorSpec(tensor.shape, tensor.dtype, name) elif isinstance(tensor, ops.Tensor): @@ -150,7 +172,10 @@ class TensorSpec(DenseSpec, type_spec.BatchableTypeSpec): else: raise ValueError("`tensor` should be a tf.Tensor") - value_type = property(lambda self: ops.Tensor) + @property + def value_type(self): + """The Python type for values that are compatible with this TypeSpec.""" + return ops.Tensor def _to_components(self, value): try: @@ -263,6 +288,21 @@ class BoundedTensorSpec(TensorSpec): @classmethod def from_spec(cls, spec): + """Returns a `TensorSpec` with the same shape and dtype as `spec`. + + If `spec` is a `BoundedTensorSpec`, then the new spec's bounds are set to + `spec.minimum` and `spec.maximum`; otherwise, the bounds are set to + `spec.dtype.min` and `spec.dtype.max`. + + >>> spec = tf.TensorSpec(shape=[8, 3], dtype=tf.int32, name="x") + >>> BoundedTensorSpec.from_spec(spec) + BoundedTensorSpec(shape=(8, 3), dtype=tf.int32, name='x', + minimum=array(-2147483648, dtype=int32), + maximum=array(2147483647, dtype=int32)) + + Args: + spec: The `TypeSpec` used to create the new `BoundedTensorSpec`. + """ dtype = dtypes.as_dtype(spec.dtype) minimum = getattr(spec, "minimum", dtype.min) maximum = getattr(spec, "maximum", dtype.max) diff --git a/tensorflow/python/framework/type_spec.py b/tensorflow/python/framework/type_spec.py index ebfce25d6db..fa48ca8f952 100644 --- a/tensorflow/python/framework/type_spec.py +++ b/tensorflow/python/framework/type_spec.py @@ -21,6 +21,7 @@ from __future__ import print_function import abc import collections +import re import numpy as np import six @@ -56,9 +57,18 @@ class TypeSpec(object): For example, `tf.function`'s `input_signature` argument accepts a list (or nested structure) of `TypeSpec`s. - Creating new subclasses of TypeSpec (outside of TensorFlow core) is not + Creating new subclasses of `TypeSpec` (outside of TensorFlow core) is not currently supported. In particular, we may make breaking changes to the private methods and properties defined by this base class. + + Example: + + >>> spec = tf.RaggedTensorSpec(shape=[None, None], dtype=tf.int32) + >>> @tf.function(input_signature=[spec]) + ... def double(x): + ... return x * 2 + >>> print(double(tf.ragged.constant([[1, 2], [3]]))) + <tf.RaggedTensor [[2, 4], [6]]> """ # === Subclassing === # @@ -611,3 +621,73 @@ def register_type_spec_from_value_converter(type_object, converter_fn, _pywrap_utils.RegisterType("TypeSpec", TypeSpec) + + +_TYPE_SPEC_TO_NAME = {} +_NAME_TO_TYPE_SPEC = {} + + +# Regular expression for valid TypeSpec names. +_REGISTERED_NAME_RE = re.compile(r"^(\w+\.)+\w+$") + + +# TODO(b/173744905) tf_export this as "tf.register_type_spec". (And add a +# usage example to the docstring, once the API is public.) +# +# TODO(b/173744905) Update this decorator to apply to ExtensionType rather than +# TypeSpec (once we do refactoring to move to_components/from_components from +# TypeSpec to ExtensionType). +def register(name): + """Decorator used to register a globally unique name for a TypeSpec subclass. + + Args: + name: The name of the type spec. Must be globally unique. Must have + the form `"{project_name}.{type_name}"`. E.g. `"my_project.MyTypeSpec"`. + + Returns: + A class decorator that registers the decorated class with the given name. 
+ """ + if not isinstance(name, str): + raise TypeError("Expected `name` to be a string; got %r" % (name,)) + if not _REGISTERED_NAME_RE.match(name): + raise ValueError( + "Registered name must have the form '{project_name}.{type_name}' " + "(e.g. 'my_project.MyTypeSpec'); got %r." % name) + + def decorator_fn(cls): + if not (isinstance(cls, type) and issubclass(cls, TypeSpec)): + raise TypeError("Expected `cls` to be a TypeSpec; got %r" % (cls,)) + if cls in _TYPE_SPEC_TO_NAME: + raise ValueError("Class %s.%s has already been registered with name %s." + % (cls.__module__, cls.__name__, + _TYPE_SPEC_TO_NAME[cls])) + if name in _NAME_TO_TYPE_SPEC: + raise ValueError("Name %s has already been registered for class %s.%s." + % (name, _NAME_TO_TYPE_SPEC[name].__module__, + _NAME_TO_TYPE_SPEC[name].__name__)) + _TYPE_SPEC_TO_NAME[cls] = name + _NAME_TO_TYPE_SPEC[name] = cls + return cls + + return decorator_fn + + +# TODO(edloper) tf_export this as "tf.get_type_spec_name" (or some similar name) +def get_name(cls): + """Returns the registered name for TypeSpec `cls`.""" + if not (isinstance(cls, type) and issubclass(cls, TypeSpec)): + raise TypeError("Expected `cls` to be a TypeSpec; got %r" % (cls,)) + if cls not in _TYPE_SPEC_TO_NAME: + raise ValueError("TypeSpec %s.%s has not been registered." % + (cls.__module__, cls.__name__)) + return _TYPE_SPEC_TO_NAME[cls] + + +# TODO(edloper) tf_export this as "tf.lookup_type_spec" (or some similar name) +def lookup(name): + """Returns the TypeSpec that has been registered with name `name`.""" + if not isinstance(name, str): + raise TypeError("Expected `name` to be a string; got %r" % (name,)) + if name not in _NAME_TO_TYPE_SPEC: + raise ValueError("No TypeSpec has been registered with name %r" % (name,)) + return _NAME_TO_TYPE_SPEC[name] diff --git a/tensorflow/python/framework/type_spec_test.py b/tensorflow/python/framework/type_spec_test.py index bcffd43ee6a..8007bf62dd2 100644 --- a/tensorflow/python/framework/type_spec_test.py +++ b/tensorflow/python/framework/type_spec_test.py @@ -47,6 +47,7 @@ class TwoTensors(object): self.color = color +@type_spec.register("tf.TwoTensorsSpec") class TwoTensorsSpec(type_spec.TypeSpec): """A TypeSpec for the TwoTensors value type.""" @@ -97,6 +98,7 @@ class TwoComposites(object): self.color = color +@type_spec.register("tf.TwoCompositesSpec") class TwoCompositesSpec(type_spec.TypeSpec): """A TypeSpec for the TwoTensors value type.""" @@ -349,5 +351,64 @@ class TypeSpecTest(test_util.TensorFlowTestCase, parameterized.TestCase): self.assertTrue(spec1.is_compatible_with(spec2)) self.assertFalse(spec1.is_compatible_with(spec3)) + def testRegistry(self): + self.assertEqual("tf.TwoCompositesSpec", + type_spec.get_name(TwoCompositesSpec)) + self.assertEqual("tf.TwoTensorsSpec", type_spec.get_name(TwoTensorsSpec)) + self.assertEqual(TwoCompositesSpec, + type_spec.lookup("tf.TwoCompositesSpec")) + self.assertEqual(TwoTensorsSpec, type_spec.lookup("tf.TwoTensorsSpec")) + + def testRegistryTypeErrors(self): + with self.assertRaisesRegex(TypeError, "Expected `name` to be a string"): + type_spec.register(None) + + with self.assertRaisesRegex(TypeError, "Expected `name` to be a string"): + type_spec.register(TwoTensorsSpec) + + with self.assertRaisesRegex(TypeError, "Expected `cls` to be a TypeSpec"): + type_spec.register("tf.foo")(None) + + with self.assertRaisesRegex(TypeError, "Expected `cls` to be a TypeSpec"): + type_spec.register("tf.foo")(ragged_tensor.RaggedTensor) + + def testRegistryDuplicateErrors(self): + with 
self.assertRaisesRegex( + ValueError, "Name tf.TwoCompositesSpec has already been registered " + "for class __main__.TwoCompositesSpec."): + + @type_spec.register("tf.TwoCompositesSpec") # pylint: disable=unused-variable + class NewTypeSpec(TwoCompositesSpec): + pass + + with self.assertRaisesRegex( + ValueError, "Class __main__.TwoCompositesSpec has already been " + "registered with name tf.TwoCompositesSpec"): + type_spec.register("tf.NewName")(TwoCompositesSpec) + + def testRegistryNameErrors(self): + for bad_name in ["foo", "", "hello world"]: + with self.assertRaises(ValueError): + type_spec.register(bad_name) + + def testRegistryLookupErrors(self): + with self.assertRaises(TypeError): + type_spec.lookup(None) + with self.assertRaisesRegex( + ValueError, "No TypeSpec has been registered with name 'foo.bar'"): + type_spec.lookup("foo.bar") + + def testRegistryGetNameErrors(self): + with self.assertRaises(TypeError): + type_spec.get_name(None) + + class Foo(TwoCompositesSpec): + pass + + with self.assertRaisesRegex( + ValueError, "TypeSpec __main__.Foo has not been registered."): + type_spec.get_name(Foo) + + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index c1da720c8ff..62dff46c44a 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -1678,14 +1678,13 @@ def zeros_like(x, dtype=None, name=None): Example: - + ```python from tensorflow.keras import backend as K kvar = K.variable(np.random.random((2,3))) kvar_zeros = K.zeros_like(kvar) K.eval(kvar_zeros) # array([[ 0., 0., 0.], [ 0., 0., 0.]], dtype=float32) - - + ``` """ return array_ops.zeros_like(x, dtype=dtype, name=name) @@ -4944,6 +4943,11 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1): # activations cache logits on the `output` Tensor. if hasattr(output, '_keras_logits'): output = output._keras_logits # pylint: disable=protected-access + if from_logits: + warnings.warn( + '"`categorical_crossentropy` received `from_logits=True`, but ' + 'the `output` argument was produced by a sigmoid or softmax ' + 'activation and thus does not represent logits. Was this intended?"') from_logits = True if from_logits: @@ -4999,6 +5003,11 @@ def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1): # activations cache logits on the `output` Tensor. if hasattr(output, '_keras_logits'): output = output._keras_logits # pylint: disable=protected-access + if from_logits: + warnings.warn( + '"`sparse_categorical_crossentropy` received `from_logits=True`, but ' + 'the `output` argument was produced by a sigmoid or softmax ' + 'activation and thus does not represent logits. Was this intended?"') from_logits = True elif (not from_logits and not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and @@ -5081,6 +5090,11 @@ def binary_crossentropy(target, output, from_logits=False): # activations cache logits on the `output` Tensor. if hasattr(output, '_keras_logits'): output = output._keras_logits # pylint: disable=protected-access + if from_logits: + warnings.warn( + '"`binary_crossentropy` received `from_logits=True`, but the `output`' + ' argument was produced by a sigmoid or softmax activation and thus ' + 'does not represent logits. 
Was this intended?"') from_logits = True if from_logits: diff --git a/tensorflow/python/keras/backend_test.py b/tensorflow/python/keras/backend_test.py index 8943afaf993..85131056de5 100644 --- a/tensorflow/python/keras/backend_test.py +++ b/tensorflow/python/keras/backend_test.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function import gc +import warnings from absl.testing import parameterized import numpy as np @@ -32,6 +33,7 @@ from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util +from tensorflow.python.keras import activations from tensorflow.python.keras import backend from tensorflow.python.keras import combinations from tensorflow.python.keras.engine import input_layer @@ -1735,6 +1737,45 @@ class BackendCrossEntropyLossesTest(test.TestCase, parameterized.TestCase): result = self.evaluate(backend.sparse_categorical_crossentropy(t, p)) self.assertArrayNear(result, [0.002, 0.0005, 0.17], 1e-3) + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) + def test_binary_crossentropy_from_logits_no_warnings(self): + t = backend.constant([[0, 1, 0]]) + logits = backend.constant([[8., 1., 1.]]) + with warnings.catch_warnings(record=True) as w: + self.evaluate(backend.binary_crossentropy(t, logits, from_logits=True)) + self.assertEmpty(w) + + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) + def test_binary_crossentropy_from_logits_with_sigmoid(self): + t = backend.constant([[0, 1, 0]]) + logits = backend.constant([[8., 1., 1.]]) + p = activations.sigmoid(logits) + with warnings.catch_warnings(record=True) as w: + self.evaluate(backend.binary_crossentropy(t, p, from_logits=True)) + self.assertLen(w, 1) + self.assertIn('received `from_logits=True`', str(w[0].message)) + + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) + def test_categorical_crossentropy_from_logits_with_softmax(self): + t = backend.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + logits = backend.constant([[8., 1., 1.], [0., 9., 1.], [2., 3., 5.]]) + p = activations.softmax(logits) + with warnings.catch_warnings(record=True) as w: + self.evaluate(backend.categorical_crossentropy(t, p, from_logits=True)) + self.assertLen(w, 1) + self.assertIn('received `from_logits=True`', str(w[0].message)) + + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) + def test_sparse_categorical_crossentropy_from_logits_with_softmax(self): + t = backend.constant([0, 1, 2]) + logits = backend.constant([[8., 1., 1.], [0., 9., 1.], [2., 3., 5.]]) + p = activations.softmax(logits) + with warnings.catch_warnings(record=True) as w: + self.evaluate( + backend.sparse_categorical_crossentropy(t, p, from_logits=True)) + self.assertLen(w, 1) + self.assertIn('received `from_logits=True`', str(w[0].message)) + @test_util.with_control_flow_v2 @combinations.generate(combinations.combine(mode=['graph', 'eager'])) diff --git a/tensorflow/python/keras/benchmarks/layer_benchmarks/layer_benchmarks_test.py b/tensorflow/python/keras/benchmarks/layer_benchmarks/layer_benchmarks_test.py index ebc343158c0..0fc90150127 100644 --- a/tensorflow/python/keras/benchmarks/layer_benchmarks/layer_benchmarks_test.py +++ b/tensorflow/python/keras/benchmarks/layer_benchmarks/layer_benchmarks_test.py @@ -30,6 +30,23 @@ def _get_benchmark_name(name): return name.split("__")[-1].split("_") +def _get_metadata(name): + return { + 
"model_name": "ideal_layers", + "parameters": name[1] + "_shape", + } + + +def _generate_benchmark_params(*params_list): + benchmark_params = [] + for params in params_list: + benchmark_params.extend( + [((param[0] + "_CPU",) + param[1:]) for param in params]) + benchmark_params.extend( + [((param[0] + "_GPU",) + param[1:]) for param in params]) + return benchmark_params + + def _layer_call_backward(layer, x): with tf.GradientTape() as tape: y = layer(x) @@ -46,7 +63,7 @@ class KerasLayerBenchmarks(six.with_metaclass( # the benchmark name. It must follow the convention of # "{layer_name}_{small|normal|large}_shape" to make it compatible with # `self.report_benchmark()` method. - _benchmark_parameters = [ + _benchmark_parameters = _generate_benchmark_params([ ("Conv2D_small_shape", tf.keras.layers.Conv2D, {"filters": 1, "kernel_size": 1, "activation": "relu"}, (1, 1, 1, 1), 10000), @@ -57,7 +74,7 @@ class KerasLayerBenchmarks(six.with_metaclass( {"units": 1}, (1, 1, 1), 10000), ("LSTM_normal_shape", tf.keras.layers.LSTM, {"units": 4}, (32, 10, 8), 10000), - ] + ]) def benchmark_layer_call(self, layer_cls, layer_args, input_shape, num_iters): layer = layer_cls(**layer_args) @@ -65,11 +82,8 @@ class KerasLayerBenchmarks(six.with_metaclass( fn = functools.partial(layer, x) name = _get_benchmark_name(self._get_name()) - metadata = { - "model_name": "ideal_layers", - "implementation": name[0] + ".layer.call", - "parameters": name[1] + "_shape" - } + metadata = {"implementation": name[0] + ".layer.call"} + metadata.update(_get_metadata(name)) self.run_report(fn, num_iters, metadata) def benchmark_layer_call_with_function( @@ -80,11 +94,8 @@ class KerasLayerBenchmarks(six.with_metaclass( fn = functools.partial(layer, x) name = _get_benchmark_name(self._get_name()) - metadata = { - "model_name": "ideal_layers", - "implementation": name[0] + ".layer.call.function", - "parameters": name[1] + "_shape" - } + metadata = {"implementation": name[0] + ".layer.call.function"} + metadata.update(_get_metadata(name)) self.run_report(fn, num_iters, metadata) def benchmark_layer_call_with_xla( @@ -96,11 +107,8 @@ class KerasLayerBenchmarks(six.with_metaclass( fn = functools.partial(layer, x) name = _get_benchmark_name(self._get_name()) - metadata = { - "model_name": "ideal_layers", - "implementation": name[0] + ".layer.call.xla", - "parameters": name[1] + "_shape" - } + metadata = {"implementation": name[0] + ".layer.call.xla"} + metadata.update(_get_metadata(name)) self.run_report(fn, num_iters, metadata) def benchmark_layer_call_backward( @@ -110,11 +118,8 @@ class KerasLayerBenchmarks(six.with_metaclass( fn = functools.partial(_layer_call_backward, layer, x) name = _get_benchmark_name(self._get_name()) - metadata = { - "model_name": "ideal_layers", - "implementation": name[0] + ".layer.call.backward", - "parameters": name[1] + "_shape" - } + metadata = {"implementation": name[0] + ".layer.call.backward"} + metadata.update(_get_metadata(name)) self.run_report(fn, num_iters, metadata) def benchmark_layer_call_backward_with_function( @@ -125,11 +130,8 @@ class KerasLayerBenchmarks(six.with_metaclass( fn = functools.partial(_layer_call_backward, layer, x) name = _get_benchmark_name(self._get_name()) - metadata = { - "model_name": "ideal_layers", - "implementation": name[0] + ".layer.call.backward.function", - "parameters": name[1] + "_shape" - } + metadata = {"implementation": name[0] + ".layer.call.backward.function"} + metadata.update(_get_metadata(name)) self.run_report(fn, num_iters, metadata) @@ -137,7 
+139,7 @@ class KerasLayerBenchmarksBackwardXLA(six.with_metaclass( benchmark.ParameterizedBenchmark, layer_benchmarks_test_base.LayerBenchmarksBase)): - _benchmark_parameters = [ + _benchmark_parameters = _generate_benchmark_params([ ("Conv2D_small_shape", tf.keras.layers.Conv2D, {"filters": 1, "kernel_size": 1, "activation": "relu"}, (1, 1, 1, 1), 10000), @@ -149,7 +151,7 @@ class KerasLayerBenchmarksBackwardXLA(six.with_metaclass( # {"units": 1}, (1, 1, 1), 10000), # ("LSTM_normal_shape", tf.keras.layers.LSTM, # {"units": 4}, (32, 10, 8), 10000), - ] + ]) def benchmark_layer_call_backward_with_xla( self, layer_cls, layer_args, input_shape, num_iters): @@ -160,11 +162,8 @@ class KerasLayerBenchmarksBackwardXLA(six.with_metaclass( fn = functools.partial(_layer_call_backward, layer, x) name = _get_benchmark_name(self._get_name()) - metadata = { - "model_name": "ideal_layers", - "implementation": name[0] + ".layer.call.backward.xla", - "parameters": name[1] + "_shape" - } + metadata = {"implementation": name[0] + ".layer.call.backward.xla"} + metadata.update(_get_metadata(name)) self.run_report(fn, num_iters, metadata) diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD index 43b5cf63677..34c0fc202db 100644 --- a/tensorflow/python/keras/distribute/BUILD +++ b/tensorflow/python/keras/distribute/BUILD @@ -328,6 +328,7 @@ py_library( "//tensorflow/python/distribute:parameter_server_strategy_v2", "//tensorflow/python/distribute:strategy_combinations", "//tensorflow/python/distribute:tpu_strategy", + "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", "//tensorflow/python/estimator:estimator_py", diff --git a/tensorflow/python/keras/distribute/ctl_correctness_test.py b/tensorflow/python/keras/distribute/ctl_correctness_test.py index 629bc2ac632..b51022f860e 100644 --- a/tensorflow/python/keras/distribute/ctl_correctness_test.py +++ b/tensorflow/python/keras/distribute/ctl_correctness_test.py @@ -33,7 +33,6 @@ from tensorflow.python.eager import def_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import random_seed from tensorflow.python.framework import test_combinations as combinations -from tensorflow.python.framework import test_util from tensorflow.python.keras.distribute import optimizer_combinations from tensorflow.python.keras.distribute import strategy_combinations from tensorflow.python.ops import math_ops @@ -251,9 +250,6 @@ class TestDistributionStrategyDnnCorrectness(test.TestCase, # TODO(anjs): Identify why this particular V1 optimizer needs a higher tol. 
if 'FtrlV1' in optimizer_fn._name and 'TPU' in type(distribution).__name__: self.skipTest('Reduced tolerance of the order of 1e-1 required.') - if ('CollectiveAllReduce' in type(distribution).__name__ and - test_util.is_xla_enabled()): - self.skipTest('XLA tests fail with MWMS.') self.dnn_correctness(distribution, optimizer_fn, iteration_type, inside_func, sync_batchnorm) diff --git a/tensorflow/python/keras/distribute/distribute_strategy_test.py b/tensorflow/python/keras/distribute/distribute_strategy_test.py index 69aa1c7e9eb..972eab4b999 100644 --- a/tensorflow/python/keras/distribute/distribute_strategy_test.py +++ b/tensorflow/python/keras/distribute/distribute_strategy_test.py @@ -40,6 +40,7 @@ from tensorflow.python.distribute import parameter_server_strategy_v2 from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import tpu_strategy +from tensorflow.python.distribute import values as ds_values_lib from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver from tensorflow.python.eager import backprop from tensorflow.python.eager import context @@ -2677,6 +2678,43 @@ class TestModelCapturesStrategy(test.TestCase, parameterized.TestCase): loss=keras.losses.MeanSquaredError(), metrics=[keras.metrics.BinaryAccuracy()]) + @ds_combinations.generate( + combinations.combine( + distribution=strategy_combinations.mirrored_strategy_with_one_cpu, + mode=['eager'])) + def test_optimizer(self, distribution): + temp_dir = os.path.join(self.get_temp_dir(), 'ckpt') + + def create_model(): + model = keras.models.Sequential([ + keras.layers.Dense(1), + ]) + model.compile(optimizer='adam', loss='mse') + model.build([None, 1]) # create weights. + self.assertEmpty(model.optimizer.weights) + return model + + model = create_model() + x = y = array_ops.ones(shape=(1, 1)) + model.fit(x=x, y=y, batch_size=1) + model.save_weights(temp_dir) + + with distribution.scope(): + model = create_model() + model.load_weights(temp_dir) + self.assertNotEmpty(model.optimizer.weights) + self.assertIsInstance(model.optimizer.weights[0], + ds_values_lib.DistributedVariable) + + with distribution.scope(): + model = create_model() + # create/restore slot variables outside of scope is fine. 
+ model.load_weights(temp_dir) + self.assertNotEmpty(model.optimizer.weights) + self.assertIsInstance(model.optimizer.weights[0], + ds_values_lib.DistributedVariable) + + if __name__ == '__main__': base_layer_utils.enable_v2_dtype_behavior() multi_process_runner.test_main() diff --git a/tensorflow/python/keras/engine/functional_test.py b/tensorflow/python/keras/engine/functional_test.py index 8427517f235..fea3ee16da7 100644 --- a/tensorflow/python/keras/engine/functional_test.py +++ b/tensorflow/python/keras/engine/functional_test.py @@ -2473,5 +2473,75 @@ class InputsOutputsErrorTest(keras_parameterized.TestCase): model({'1': np.zeros((3, 10)), '2': np.zeros((3, 6))}) +class FunctionalSubclassModel(training_lib.Model): + + def __init__(self, *args, **kwargs): + my_input = input_layer_lib.Input(shape=(16,)) + dense = layers.Dense(32, activation='relu') + output = dense(my_input) + outputs = {'output': output} + super().__init__(inputs=[my_input], outputs=outputs, *args, **kwargs) + + +class MixinClass(object): + + def __init__(self, foo, **kwargs): + self._foo = foo + super().__init__(**kwargs) + + def get_foo(self): + return self._foo + + +class SubclassedModel(training_lib.Model): + + def __init__(self, bar, **kwargs): + self._bar = bar + super().__init__(**kwargs) + + def get_bar(self): + return self._bar + + +class MultipleInheritanceModelTest(keras_parameterized.TestCase): + + def testFunctionalSubclass(self): + m = FunctionalSubclassModel() + # Some smoke test for the weights and output shape of the model + self.assertLen(m.weights, 2) + self.assertEqual(m.outputs[0].shape.as_list(), [None, 32]) + + def testFunctionalSubclassPreMixin(self): + class MixedFunctionalSubclassModel(MixinClass, FunctionalSubclassModel): + pass + + m = MixedFunctionalSubclassModel(foo='123') + self.assertTrue(m._is_graph_network) + self.assertLen(m.weights, 2) + self.assertEqual(m.outputs[0].shape.as_list(), [None, 32]) + self.assertEqual(m.get_foo(), '123') + + def testFunctionalSubclassPostMixin(self): + # Make sure the the mixin class is also init correct when the order changed. + + class MixedFunctionalSubclassModel(FunctionalSubclassModel, MixinClass): + pass + + m = MixedFunctionalSubclassModel(foo='123') + self.assertTrue(m._is_graph_network) + self.assertLen(m.weights, 2) + self.assertEqual(m.outputs[0].shape.as_list(), [None, 32]) + self.assertEqual(m.get_foo(), '123') + + def testSubclassModelPreMixin(self): + class MixedSubclassModel(MixinClass, SubclassedModel): + pass + + m = MixedSubclassModel(foo='123', bar='456') + self.assertFalse(m._is_graph_network) + self.assertEqual(m.get_foo(), '123') + self.assertEqual(m.get_bar(), '456') + + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 55f71e3a94c..3feb39172f4 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -115,6 +115,10 @@ def inject_functional_model_class(cls): from tensorflow.python.keras.engine import training_v1 # pylint: disable=g-import-not-at-top if cls == Model or cls == training_v1.Model: return functional.Functional + # In case there is any multiple inheritance, we stop injecting the + # class if keras model is not in its class hierarchy. 
+ if cls == object: + return object cls.__bases__ = tuple(inject_functional_model_class(base) for base in cls.__bases__) @@ -230,8 +234,33 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): from tensorflow.python.keras.engine import functional # pylint: disable=g-import-not-at-top if (is_functional_model_init_params(args, kwargs) and not isinstance(self, functional.Functional)): + # Filter the kwargs for multiple inheritance. + supported_kwargs = ['inputs', 'outputs', 'name', 'trainable', 'skip_init'] + model_kwargs = {k: kwargs[k] for k in kwargs if k in supported_kwargs} + other_kwargs = {k: kwargs[k] for k in kwargs if k not in supported_kwargs} inject_functional_model_class(self.__class__) - functional.Functional.__init__(self, *args, **kwargs) + functional.Functional.__init__(self, *args, **model_kwargs) + + # In case there is any multiple inheritance here, we need to call the + # __init__ for any class that appears after the Functional class. + clz_to_init = [] + found_functional_class = False + for clz in self.__class__.__bases__: + if issubclass(clz, functional.Functional): + found_functional_class = True + continue + if found_functional_class: + clz_to_init.append(clz) + + if clz_to_init: + for clz in clz_to_init: + clz.__init__(self, *args, **other_kwargs) + elif other_kwargs: + # In case there are unused kwargs, we should raise an error to user, in + # case they have a typo in the param name. + raise TypeError( + 'The following keyword arguments aren\'t supported: {}'.format( + other_kwargs)) return base_layer.keras_api_gauge.get_cell('Model subclass').set(True) diff --git a/tensorflow/python/keras/layers/advanced_activations.py b/tensorflow/python/keras/layers/advanced_activations.py index 456b6758dc6..06e369dab21 100644 --- a/tensorflow/python/keras/layers/advanced_activations.py +++ b/tensorflow/python/keras/layers/advanced_activations.py @@ -408,7 +408,7 @@ class ReLU(Layer): raise ValueError('threshold of Relu layer ' 'cannot be None. 
Required a float') - self.support_masking = True + self.supports_masking = True if max_value is not None: max_value = K.cast_to_floatx(max_value) self.max_value = max_value diff --git a/tensorflow/python/keras/layers/advanced_activations_test.py b/tensorflow/python/keras/layers/advanced_activations_test.py index 80eae8e72de..96a0c217cac 100644 --- a/tensorflow/python/keras/layers/advanced_activations_test.py +++ b/tensorflow/python/keras/layers/advanced_activations_test.py @@ -34,37 +34,44 @@ class AdvancedActivationsTest(keras_parameterized.TestCase): for alpha in [0., .5, -1.]: testing_utils.layer_test(keras.layers.LeakyReLU, kwargs={'alpha': alpha}, - input_shape=(2, 3, 4)) + input_shape=(2, 3, 4), + supports_masking=True) def test_prelu(self): testing_utils.layer_test(keras.layers.PReLU, kwargs={}, - input_shape=(2, 3, 4)) + input_shape=(2, 3, 4), + supports_masking=True) def test_prelu_share(self): testing_utils.layer_test(keras.layers.PReLU, kwargs={'shared_axes': 1}, - input_shape=(2, 3, 4)) + input_shape=(2, 3, 4), + supports_masking=True) def test_elu(self): for alpha in [0., .5, -1.]: testing_utils.layer_test(keras.layers.ELU, kwargs={'alpha': alpha}, - input_shape=(2, 3, 4)) + input_shape=(2, 3, 4), + supports_masking=True) def test_thresholded_relu(self): testing_utils.layer_test(keras.layers.ThresholdedReLU, kwargs={'theta': 0.5}, - input_shape=(2, 3, 4)) + input_shape=(2, 3, 4), + supports_masking=True) def test_softmax(self): testing_utils.layer_test(keras.layers.Softmax, kwargs={'axis': 1}, - input_shape=(2, 3, 4)) + input_shape=(2, 3, 4), + supports_masking=True) def test_relu(self): testing_utils.layer_test(keras.layers.ReLU, kwargs={'max_value': 10}, - input_shape=(2, 3, 4)) + input_shape=(2, 3, 4), + supports_masking=True) x = keras.backend.ones((3, 4)) if not context.executing_eagerly(): # Test that we use `leaky_relu` when appropriate in graph mode. @@ -80,7 +87,8 @@ class AdvancedActivationsTest(keras_parameterized.TestCase): ValueError, 'max_value of Relu layer cannot be negative value: -10'): testing_utils.layer_test(keras.layers.ReLU, kwargs={'max_value': -10}, - input_shape=(2, 3, 4)) + input_shape=(2, 3, 4), + supports_masking=True) with self.assertRaisesRegex( ValueError, 'negative_slope of Relu layer cannot be negative value: -2'): diff --git a/tensorflow/python/keras/layers/pooling.py b/tensorflow/python/keras/layers/pooling.py index 51dc5131a8a..6eb3e8e97ac 100644 --- a/tensorflow/python/keras/layers/pooling.py +++ b/tensorflow/python/keras/layers/pooling.py @@ -338,10 +338,11 @@ class MaxPooling2D(Pooling2D): window defined by `pool_size` for each dimension along the features axis. The window is shifted by `strides` in each dimension. 
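The corrected output-length formulas given just below can be sanity-checked with a couple of concrete sizes; a quick hedged example (the sizes are chosen for illustration, not taken from this patch):

```python
# Checks the corrected MaxPooling2D output-length formulas stated below;
# the concrete sizes are illustrative only.
import math


def valid_len(n, pool_size, strides):
  # Assumes n >= pool_size.
  return math.floor((n - pool_size) / strides) + 1


def same_len(n, strides):
  return math.floor((n - 1) / strides) + 1


# For a length-6 axis with pool_size=2 and strides=2:
assert valid_len(6, 2, 2) == 3   # old "(n - pool_size + 1) / strides" gave 2.5
assert same_len(6, 2) == 3
# For a length-5 axis with pool_size=2 and strides=2:
assert valid_len(5, 2, 2) == 2
assert same_len(5, 2) == 3       # old "n / strides" gave 2.5
```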
The resulting output when using "valid" padding option has a shape(number of rows or columns) of: - `output_shape = (input_shape - pool_size + 1) / strides)` + `output_shape = math.floor((input_shape - pool_size) / strides) + 1` + (when input_shape >= pool_size) The resulting output shape when using the "same" padding option is: - `output_shape = input_shape / strides` + `output_shape = math.floor((input_shape - 1) / strides) + 1` For example, for stride=(1,1) and padding="valid": diff --git a/tensorflow/python/keras/layers/preprocessing/index_lookup.py b/tensorflow/python/keras/layers/preprocessing/index_lookup.py index f3f689fda3a..8a1eb3dfbad 100644 --- a/tensorflow/python/keras/layers/preprocessing/index_lookup.py +++ b/tensorflow/python/keras/layers/preprocessing/index_lookup.py @@ -223,7 +223,7 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer): if not reset_state: raise ValueError("IndexLookup does not support streaming adapts.") super(IndexLookup, self).adapt(data, reset_state) - self.max_tokens = self._table_handler.vocab_size() + self.max_tokens = int(self._table_handler.vocab_size()) def get_vocabulary(self): if self._table_handler.vocab_size() == 0: @@ -239,7 +239,7 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer): return [x for _, x in sorted(zip(values, keys))] def vocab_size(self): - return self._table_handler.vocab_size() + return int(self._table_handler.vocab_size()) def get_config(self): config = { @@ -402,7 +402,7 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer): self._set_inverse_vocabulary(vocab) else: self._set_forward_vocabulary(vocab) - self.max_tokens = self._table_handler.vocab_size() + self.max_tokens = int(self._table_handler.vocab_size()) def _set_state_variables(self, updates): if not self.built: diff --git a/tensorflow/python/keras/layers/wrappers.py b/tensorflow/python/keras/layers/wrappers.py index c2accf24e58..4aaf159e11b 100644 --- a/tensorflow/python/keras/layers/wrappers.py +++ b/tensorflow/python/keras/layers/wrappers.py @@ -86,8 +86,8 @@ class Wrapper(Layer): class TimeDistributed(Wrapper): """This wrapper allows to apply a layer to every temporal slice of an input. - The input should be at least 3D, and the dimension of index one - will be considered to be the temporal dimension. + Every input should be at least 3D, and the dimension of index one of the + first input will be considered to be the temporal dimension. Consider a batch of 32 video samples, where each sample is a 128x128 RGB image with `channels_last` data format, across 10 timesteps. @@ -109,7 +109,8 @@ class TimeDistributed(Wrapper): layer: a `tf.keras.layers.Layer` instance. Call arguments: - inputs: Input tensor. + inputs: Input tensor of shape (batch, time, ...) or nested tensors, + and each of which has shape (batch, time, ...). training: Python boolean indicating whether the layer should behave in training mode or in inference mode. This argument is passed to the wrapped layer (only if the layer supports this argument). @@ -141,7 +142,6 @@ class TimeDistributed(Wrapper): The static shapes are replaced with the corresponding dynamic shapes of the tensor. 
- Arguments: init_tuple: a tuple, the first part of the output shape tensor: the tensor from which to get the (static and dynamic) shapes @@ -150,7 +150,6 @@ class TimeDistributed(Wrapper): the static shape of the tensor int_shape: an alternative static shape to take as the last part of the output shape - Returns: The new int_shape with the first part from init_tuple and the last part from either `int_shape` (if provided) @@ -160,6 +159,8 @@ class TimeDistributed(Wrapper): # replace all None in int_shape by K.shape if int_shape is None: int_shape = K.int_shape(tensor)[start_idx:] + if isinstance(int_shape, tensor_shape.TensorShape): + int_shape = int_shape.as_list() if not any(not s for s in int_shape): return init_tuple + tuple(int_shape) shape = K.shape(tensor) @@ -169,39 +170,56 @@ class TimeDistributed(Wrapper): int_shape[i] = shape[start_idx + i] return init_tuple + tuple(int_shape) + def _remove_timesteps(self, dims): + dims = dims.as_list() + return tensor_shape.TensorShape([dims[0]] + dims[2:]) + def build(self, input_shape): - input_shape = tensor_shape.TensorShape(input_shape).as_list() - if len(input_shape) < 3: + input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False) + input_dims = nest.flatten( + nest.map_structure(lambda x: x.ndims, input_shape)) + if any(dim < 3 for dim in input_dims): raise ValueError( '`TimeDistributed` Layer should be passed an `input_shape ` ' 'with at least 3 dimensions, received: ' + str(input_shape)) # Don't enforce the batch or time dimension. - self.input_spec = InputSpec(shape=[None, None] + input_shape[2:]) - child_input_shape = [input_shape[0]] + input_shape[2:] + self.input_spec = nest.map_structure( + lambda x: InputSpec(shape=[None, None] + x.as_list()[2:]), input_shape) + child_input_shape = nest.map_structure(self._remove_timesteps, input_shape) + child_input_shape = tf_utils.convert_shapes(child_input_shape) super(TimeDistributed, self).build(tuple(child_input_shape)) self.built = True def compute_output_shape(self, input_shape): - input_shape = tensor_shape.TensorShape(input_shape).as_list() - child_input_shape = tensor_shape.TensorShape([input_shape[0]] + - input_shape[2:]) + input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False) + + child_input_shape = nest.map_structure(self._remove_timesteps, input_shape) child_output_shape = self.layer.compute_output_shape(child_input_shape) - if not isinstance(child_output_shape, tensor_shape.TensorShape): - child_output_shape = tensor_shape.TensorShape(child_output_shape) - child_output_shape = child_output_shape.as_list() - timesteps = input_shape[1] - return tensor_shape.TensorShape([child_output_shape[0], timesteps] + - child_output_shape[1:]) + child_output_shape = tf_utils.convert_shapes( + child_output_shape, to_tuples=False) + timesteps = tf_utils.convert_shapes(input_shape) + timesteps = nest.flatten(timesteps)[1] + + def insert_timesteps(dims): + dims = dims.as_list() + return tensor_shape.TensorShape([dims[0], timesteps] + dims[1:]) + + return nest.map_structure(insert_timesteps, child_output_shape) def call(self, inputs, training=None, mask=None): kwargs = {} if generic_utils.has_arg(self.layer.call, 'training'): kwargs['training'] = training - input_shape = K.int_shape(inputs) - if input_shape[0] and not self._always_use_reshape: + input_shape = nest.map_structure( + lambda x: tensor_shape.TensorShape(K.int_shape(x)), inputs) + batch_size = tf_utils.convert_shapes(input_shape) + batch_size = nest.flatten(batch_size)[0] + if batch_size and not 
self._always_use_reshape:
       inputs, row_lengths = K.convert_inputs_if_ragged(inputs)
       is_ragged_input = row_lengths is not None
+      input_length = tf_utils.convert_shapes(input_shape)
+      input_length = nest.flatten(input_length)[1]

       # batch size matters, use rnn-based implementation
       def step(x, _):
@@ -212,27 +230,44 @@ class TimeDistributed(Wrapper):
           step,
           inputs,
           initial_states=[],
-          input_length=row_lengths[0] if is_ragged_input else input_shape[1],
+          input_length=row_lengths[0] if is_ragged_input else input_length,
           mask=mask,
           unroll=False)
-      y = K.maybe_convert_to_ragged(is_ragged_input, outputs, row_lengths)
+      # pylint: disable=g-long-lambda
+      y = nest.map_structure(
+          lambda output: K.maybe_convert_to_ragged(is_ragged_input, output,
+                                                   row_lengths), outputs)
     else:
       # No batch size specified, therefore the layer will be able
       # to process batches of any size.
       # We can go with reshape-based implementation for performance.
-      if isinstance(inputs, ragged_tensor.RaggedTensor):
-        y = self.layer(inputs.values, **kwargs)
-        y = ragged_tensor.RaggedTensor.from_row_lengths(
-            y,
-            inputs.nested_row_lengths()[0])
+      is_ragged_input = nest.map_structure(
+          lambda x: isinstance(x, ragged_tensor.RaggedTensor), inputs)
+      is_ragged_input = nest.flatten(is_ragged_input)
+      if all(is_ragged_input):
+        input_values = nest.map_structure(lambda x: x.values, inputs)
+        input_row_lengths = nest.map_structure(
+            lambda x: x.nested_row_lengths()[0], inputs)
+        y = self.layer(input_values, **kwargs)
+        y = nest.map_structure(ragged_tensor.RaggedTensor.from_row_lengths, y,
+                               input_row_lengths)
+      elif any(is_ragged_input):
+        raise ValueError('All inputs have to be either ragged or not, '
+                         'but not mixed. You passed: {}'.format(inputs))
       else:
-        input_length = input_shape[1]
+        input_length = tf_utils.convert_shapes(input_shape)
+        input_length = nest.flatten(input_length)[1]
         if not input_length:
-          input_length = array_ops.shape(inputs)[1]
-        inner_input_shape = self._get_shape_tuple((-1,), inputs, 2)
+          input_length = nest.map_structure(lambda x: array_ops.shape(x)[1],
+                                            inputs)
+          input_length = generic_utils.to_list(nest.flatten(input_length))[0]
+
+        inner_input_shape = nest.map_structure(
+            lambda x: self._get_shape_tuple((-1,), x, 2), inputs)
         # Shape: (num_samples * timesteps, ...). And track the
         # transformation in self._input_map.
-        inputs = array_ops.reshape(inputs, inner_input_shape)
+        inputs = nest.map_structure_up_to(inputs, array_ops.reshape, inputs,
+                                          inner_input_shape)
         # (num_samples * timesteps, ...)
         if generic_utils.has_arg(self.layer.call, 'mask') and mask is not None:
           inner_mask_shape = self._get_shape_tuple((-1,), mask, 2)
@@ -241,15 +276,19 @@ class TimeDistributed(Wrapper):
         y = self.layer(inputs, **kwargs)
         # Shape: (num_samples, timesteps, ...)
-        output_shape = self.compute_output_shape(input_shape).as_list()
-        output_shape = self._get_shape_tuple((-1, input_length), y, 1,
-                                             output_shape[2:])
-        y = array_ops.reshape(y, output_shape)
+        output_shape = self.compute_output_shape(input_shape)
+        # pylint: disable=g-long-lambda
+        output_shape = nest.map_structure(
+            lambda tensor, int_shape: self._get_shape_tuple(
+                (-1, input_length), tensor, 1, int_shape[2:]), y, output_shape)
+        y = nest.map_structure_up_to(y, array_ops.reshape, y, output_shape)
         if not context.executing_eagerly():
           # Set the static shape for the result since it might be lost during
           # array_ops reshape, eg, some `None` dim in the result could be
           # inferred.
- y.set_shape(self.compute_output_shape(input_shape)) + nest.map_structure_up_to( + y, lambda tensor, shape: tensor.set_shape(shape), y, + self.compute_output_shape(input_shape)) return y @@ -290,9 +329,15 @@ class TimeDistributed(Wrapper): """ # cases need to call the layer.compute_mask when input_mask is None: # Masking layer and Embedding layer with mask_zero - input_shape = K.int_shape(inputs) - if input_shape[0] and not self._always_use_reshape or isinstance( - inputs, ragged_tensor.RaggedTensor): + input_shape = nest.map_structure( + lambda x: tensor_shape.TensorShape(K.int_shape(x)), inputs) + input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False) + batch_size = tf_utils.convert_shapes(input_shape) + batch_size = nest.flatten(batch_size)[0] + is_ragged_input = nest.map_structure( + lambda x: isinstance(x, ragged_tensor.RaggedTensor), inputs) + is_ragged_input = generic_utils.to_list(nest.flatten(is_ragged_input)) + if batch_size and not self._always_use_reshape or any(is_ragged_input): # batch size matters, we currently do not handle mask explicitly, or if # the layer always uses reshape approach, or the input is a ragged tensor. return mask @@ -300,8 +345,10 @@ class TimeDistributed(Wrapper): if inner_mask is not None: inner_mask_shape = self._get_shape_tuple((-1,), mask, 2) inner_mask = K.reshape(inner_mask, inner_mask_shape) - inner_input_shape = self._get_shape_tuple((-1,), inputs, 2) - inner_inputs = array_ops.reshape(inputs, inner_input_shape) + inner_input_shape = nest.map_structure( + lambda tensor: self._get_shape_tuple((-1,), tensor, 2), inputs) + inner_inputs = nest.map_structure_up_to(inputs, array_ops.reshape, inputs, + inner_input_shape) output_mask = self.layer.compute_mask(inner_inputs, inner_mask) if output_mask is None: if mask is None: @@ -313,9 +360,11 @@ class TimeDistributed(Wrapper): output_mask = K.any(output_mask, axis=-1) else: # output_mask is not None. We need to reshape it - input_length = input_shape[1] + input_length = tf_utils.convert_shapes(input_shape) + input_length = nest.flatten(input_length)[1] if not input_length: - input_length = K.shape(inputs)[1] + input_length = nest.map_structure(lambda x: K.shape(x)[1], inputs) + input_length = nest.flatten(input_length)[0] output_mask_int_shape = K.int_shape(output_mask) if output_mask_int_shape is None: # if the output_mask does not have a static shape, @@ -323,6 +372,7 @@ class TimeDistributed(Wrapper): if mask is not None: output_mask_int_shape = K.int_shape(mask) else: + input_shape = generic_utils.to_list(nest.flatten(input_shape))[0] output_mask_int_shape = K.compute_output_shape(input_shape)[:-1] output_mask_shape = self._get_shape_tuple( (-1, input_length), output_mask, 1, output_mask_int_shape[1:]) diff --git a/tensorflow/python/keras/layers/wrappers_test.py b/tensorflow/python/keras/layers/wrappers_test.py index c60e950794f..d8c6d4ff220 100644 --- a/tensorflow/python/keras/layers/wrappers_test.py +++ b/tensorflow/python/keras/layers/wrappers_test.py @@ -471,6 +471,62 @@ class TimeDistributedTest(keras_parameterized.TestCase): # Make sure the batch dim is not lost after array_ops.reshape. 
self.assertListEqual(outputs.shape.as_list(), [1, None, 30, 30, 16]) + @keras_parameterized.run_all_keras_modes + def test_TimeDistributed_with_mimo(self): + dense_1 = keras.layers.Dense(8) + dense_2 = keras.layers.Dense(16) + + class TestLayer(keras.layers.Layer): + + def __init__(self): + super(TestLayer, self).__init__() + self.dense_1 = dense_1 + self.dense_2 = dense_2 + + def call(self, inputs): + return self.dense_1(inputs[0]), self.dense_2(inputs[1]) + + def compute_output_shape(self, input_shape): + output_shape_1 = self.dense_1.compute_output_shape(input_shape[0]) + output_shape_2 = self.dense_2.compute_output_shape(input_shape[1]) + return output_shape_1, output_shape_2 + + np.random.seed(100) + layer = TestLayer() + + data_1 = array_ops.constant([[[[1.0], [1.0]], [[2.0], [2.0]]], + [[[4.0], [4.0]], [[5.0], [5.0]]], + [[[7.0], [7.0]], [[8.0], [8.0]]]]) + + data_2 = array_ops.constant([[[[1.0], [1.0]], [[2.0], [2.0]]], + [[[4.0], [4.0]], [[5.0], [5.0]]], + [[[7.0], [7.0]], [[8.0], [8.0]]]]) + + x1 = keras.Input(shape=(None, 2, 1), dtype='float32') + x2 = keras.Input(shape=(None, 2, 1), dtype='float32') + y1, y2 = keras.layers.TimeDistributed(layer)([x1, x2]) + model_1 = keras.models.Model([x1, x2], [y1, y2]) + model_1.compile( + optimizer='rmsprop', + loss='mse', + run_eagerly=testing_utils.should_run_eagerly()) + output_1 = model_1.predict((data_1, data_2), steps=1) + + y1 = dense_1(x1) + y2 = dense_2(x2) + model_2 = keras.models.Model([x1, x2], [y1, y2]) + output_2 = model_2.predict((data_1, data_2), steps=1) + + self.assertAllClose(output_1, output_2) + + model_1.fit( + x=[np.random.random((10, 2, 2, 1)), + np.random.random((10, 2, 2, 1))], + y=[np.random.random((10, 2, 2, 8)), + np.random.random((10, 2, 2, 16))], + epochs=1, + batch_size=3) + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) class BidirectionalTest(test.TestCase, parameterized.TestCase): diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py index 1bd13d82834..a7d937344f4 100644 --- a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py +++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py @@ -855,8 +855,22 @@ class OptimizerV2(trackable.Trackable): """A list of names for this optimizer's slots.""" return self._slot_names - def add_slot(self, var, slot_name, initializer="zeros"): - """Add a new slot variable for `var`.""" + def add_slot(self, var, slot_name, initializer="zeros", shape=None): + """Add a new slot variable for `var`. + + A slot variable is an additional variable associated with `var` to train. + It is allocated and managed by optimizers, e.g. `Adam`. + + Args: + var: a `Variable` object. + slot_name: name of the slot variable. + initializer: initializer of the slot variable + shape: (Optional) shape of the slot variable. If not set, it will default + to the shape of `var`. + + Returns: + A slot variable. 
+ """ if slot_name not in self._slot_names: self._slot_names.append(slot_name) var_key = _var_key(var) @@ -865,26 +879,29 @@ class OptimizerV2(trackable.Trackable): if weight is None: if isinstance(initializer, six.string_types) or callable(initializer): initializer = initializers.get(initializer) + slot_shape = var.shape if shape is None else shape initial_value = functools.partial( - initializer, shape=var.shape, dtype=var.dtype) + initializer, shape=slot_shape, dtype=var.dtype) else: initial_value = initializer - strategy = distribute_ctx.get_strategy() - if not strategy.extended.variable_created_in_scope(var): - raise ValueError( - "Trying to create optimizer slot variable under the scope for " - "tf.distribute.Strategy ({}), which is different from the scope " - "used for the original variable ({}). Make sure the slot " - "variables are created under the same strategy scope. This may " - "happen if you're restoring from a checkpoint outside the scope" - .format(strategy, var)) - with strategy.extended.colocate_vars_with(var): - weight = tf_variables.Variable( - name="%s/%s" % (var._shared_name, slot_name), # pylint: disable=protected-access - dtype=var.dtype, - trainable=False, - initial_value=initial_value) + with self._distribution_strategy_scope(): + strategy = distribute_ctx.get_strategy() + if not strategy.extended.variable_created_in_scope(var): + raise ValueError( + "Trying to create optimizer slot variable under the scope for " + "tf.distribute.Strategy ({}), which is different from the scope " + "used for the original variable ({}). Make sure the slot " + "variables are created under the same strategy scope. This may " + "happen if you're restoring from a checkpoint outside the scope" + .format(strategy, var)) + + with strategy.extended.colocate_vars_with(var): + weight = tf_variables.Variable( + name="%s/%s" % (var._shared_name, slot_name), # pylint: disable=protected-access + dtype=var.dtype, + trainable=False, + initial_value=initial_value) backend.track_variable(weight) slot_dict[slot_name] = weight self._restore_slot_variable( @@ -1344,7 +1361,13 @@ class OptimizerV2(trackable.Trackable): # a slot variable if not for this case). Deferring is mostly harmless # (aside from double initialization), and makes variable creator scopes # behave the same way they do when graph building. - and not ops.get_default_graph()._variable_creator_stack): # pylint: disable=protected-access + # + # One notable case is with distribution strategy, which uses variable + # creator scope but always desires the `variable` and the slot to use + # the same scope, thus we can safely eagerly create/restore slot + # variables. + and (not ops.get_default_graph()._variable_creator_stack or # pylint: disable=protected-access + self._distribution_strategy)): initializer = trackable.CheckpointInitialValueCallable( checkpoint_position=slot_variable_position) slot_variable = self.add_slot( diff --git a/tensorflow/python/keras/testing_utils.py b/tensorflow/python/keras/testing_utils.py index fecf52e71b3..9045f0ec3de 100644 --- a/tensorflow/python/keras/testing_utils.py +++ b/tensorflow/python/keras/testing_utils.py @@ -100,7 +100,8 @@ def layer_test(layer_cls, validate_training=True, adapt_data=None, custom_objects=None, - test_harness=None): + test_harness=None, + supports_masking=None): """Test routine for a layer with a single input and single output. Arguments: @@ -122,6 +123,8 @@ def layer_test(layer_cls, in the layer class. This is helpful for testing custom layers. 
    test_harness: The Tensorflow test, if any, that this function is being
      called in.
+    supports_masking: Optional boolean to check the `supports_masking` property
+      of the layer. If None, the check will not be performed.

   Returns:
     The output data (Numpy array) returned by the layer, for additional
@@ -165,6 +168,13 @@ def layer_test(layer_cls,
   kwargs = kwargs or {}
   layer = layer_cls(**kwargs)

+  if (supports_masking is not None
+      and layer.supports_masking != supports_masking):
+    raise AssertionError(
+        'When testing layer %s, the `supports_masking` property is %r '
+        'but expected to be %r.\nFull kwargs: %s' %
+        (layer_cls.__name__, layer.supports_masking, supports_masking, kwargs))
+
   # Test adapt, if data was passed.
   if adapt_data is not None:
     layer.adapt(adapt_data)
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index b6f3b6b8c82..8fa5e2b33f0 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -987,11 +987,13 @@ tf_py_test(
     ],
 )

-tf_py_test(
+cuda_py_test(
     name = "segment_reduction_ops_test",
     size = "medium",
     srcs = ["segment_reduction_ops_test.py"],
     shard_count = 10,
+    # TODO (b/173835746): the test fails with XLA.
+    xla_enable_strict_auto_jit = False,
     deps = [
         "//tensorflow/python:client",
         "//tensorflow/python:client_testlib",
diff --git a/tensorflow/python/kernel_tests/array_ops/BUILD b/tensorflow/python/kernel_tests/array_ops/BUILD
index c10c44e1ef4..21e11d3ad62 100644
--- a/tensorflow/python/kernel_tests/array_ops/BUILD
+++ b/tensorflow/python/kernel_tests/array_ops/BUILD
@@ -23,6 +23,9 @@ cuda_py_test(
     name = "unstack_op_test",
     size = "small",
     srcs = ["unstack_op_test.py"],
+    tags = [
+        "no_cuda_asan",  # b/173806679
+    ],
     xla_tags = [
         "no_cuda_asan",  # times out
     ],
@@ -54,6 +57,7 @@ cuda_py_test(
     name = "gather_op_test",
     size = "medium",
     srcs = ["gather_op_test.py"],
+    tags = ["no_cuda_asan"],  # b/173806733
     deps = [
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
diff --git a/tensorflow/python/kernel_tests/collective_ops_test.py b/tensorflow/python/kernel_tests/collective_ops_test.py
index fdaf3213759..bb6c04019e9 100644
--- a/tensorflow/python/kernel_tests/collective_ops_test.py
+++ b/tensorflow/python/kernel_tests/collective_ops_test.py
@@ -595,6 +595,7 @@ class OpCancellationTest(test.TestCase, parameterized.TestCase):
             mode='eager'), device_combination))
   def testOpErrorNotAbortWithCollective(self, collective_op, device,
                                         communication):
+    self.skipTest('b/173733368: currently it may timeout on guitar.')
     # Do not abort v2 collective ops even if there're active collective ops at
     # the time of an op error. We rely cancellation to terminate active
     # collective ops.
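A minimal usage sketch of the new `supports_masking` check in `testing_utils.layer_test` (the layer, `alpha` value, and input shape here are illustrative only, mirroring the updated calls in `advanced_activations_test.py`; they are not part of this change):

    from tensorflow.python import keras
    from tensorflow.python.keras import testing_utils

    # With supports_masking=True, layer_test now also asserts that
    # LeakyReLU().supports_masking is True and raises an AssertionError if it
    # is not; leaving the argument as None skips the check entirely.
    testing_utils.layer_test(
        keras.layers.LeakyReLU,
        kwargs={'alpha': 0.3},
        input_shape=(2, 3, 4),
        supports_masking=True)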
diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py index 2011b3b4b45..7ccf4b89e67 100644 --- a/tensorflow/python/kernel_tests/cond_v2_test.py +++ b/tensorflow/python/kernel_tests/cond_v2_test.py @@ -33,6 +33,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import cond_v2 from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops @@ -1296,6 +1297,24 @@ class CondV2Test(test.TestCase): i = f(constant_op.constant(False)) self.assertEqual(self.evaluate(i), 2.0) + def testGradientOfMixedOptionals(self): + + @def_function.function + def f(c): + x = constant_op.constant(1., name="x") + + def then_branch(): + return x ** 2., gen_dataset_ops.optional_from_value( + [constant_op.constant(1)]) + + def else_branch(): + return x ** 3., gen_dataset_ops.optional_from_value( + [constant_op.constant(1.)]) + + y, _ = cond_v2.cond_v2(c, then_branch, else_branch) + return gradients_impl.gradients(y, x) + self.assertAllClose([2.], f(constant_op.constant(True))) + class CondV2CollectionTest(test.TestCase): diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py index 95483881629..81f105899f3 100644 --- a/tensorflow/python/kernel_tests/relu_op_test.py +++ b/tensorflow/python/kernel_tests/relu_op_test.py @@ -19,9 +19,7 @@ from __future__ import division from __future__ import print_function import numpy as np -from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.python import tf2 from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -29,7 +27,6 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import gradient_checker_v2 -from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variables @@ -117,45 +114,19 @@ class ReluTest(test.TestCase): order="F") err = gradient_checker_v2.max_error(*gradient_checker_v2.compute_gradient( nn_ops.relu, [x], delta=1.0 / 1024)) - self.assertLess(err, 1e-4) + self.assertLess(err, 1e-6) - # The gradient for fp16 is inaccurate due to the low-precision. - # We compare the fp16 analytical gradient against their fp32 counterpart. + # The gradient test for ReLU is a bit tricky as the derivative is not well + # defined at around zero and we want to avoid that in terms of input values. def testGradientFloat16(self): - - def grad(x): - with backprop.GradientTape() as tape: - tape.watch(x) - y = nn_ops.l2_loss(nn_ops.relu(x)) - return tape.gradient(y, x) - - def f(): - with test_util.use_gpu(): - # Randomly construct a 1D shape from [1, 40) - shape = random_ops.random_uniform([1], - minval=1, - maxval=40, - dtype=dtypes.int32) - x32 = random_ops.random_uniform(shape, minval=-1, maxval=1) - x16 = math_ops.cast(x32, dtype=dtypes.float16) - return grad(x32), grad(x16) - - # We're going to ensure that the fp16 and fp32 gradients - # are "close" to each other for ~100 random values. - # - # In TensorFlow 1.x, invoking f() (without eager execution enabled) - # would construct a graph. 
Instead of construct a graph with O(100) nodes, - # we construct a single graph to be executed ~100 times in a Session. - if not tf2.enabled(): - d32_tensor, d16_tensor = f() - with self.cached_session() as sess: - f = lambda: sess.run([d32_tensor, d16_tensor]) - - # Repeat the experiment for 100 times. All tensor shapes and its tensor - # values are randomly generated for each run. - for _ in xrange(100): - d32, d16 = f() - self.assertAllClose(d32, d16, atol=3e-4) + with self.cached_session(): + x = np.asarray( + [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], + dtype=np.float16, + order="F") + err = gradient_checker_v2.max_error( + *gradient_checker_v2.compute_gradient(nn_ops.relu, [x])) + self.assertLess(err, 1e-6) def testGradientFloat64(self): with self.cached_session(): @@ -165,7 +136,7 @@ class ReluTest(test.TestCase): order="F") err = gradient_checker_v2.max_error(*gradient_checker_v2.compute_gradient( nn_ops.relu, [x], delta=1.0 / 1024)) - self.assertLess(err, 1e-10) + self.assertLess(err, 1e-15) def testGradGradFloat32(self): with self.cached_session(): diff --git a/tensorflow/python/kernel_tests/summary_ops_test.py b/tensorflow/python/kernel_tests/summary_ops_test.py index e64e75bde0d..d5cd9d7bf43 100644 --- a/tensorflow/python/kernel_tests/summary_ops_test.py +++ b/tensorflow/python/kernel_tests/summary_ops_test.py @@ -1216,6 +1216,26 @@ class SummaryOpsTest(test_util.TensorFlowTestCase): # Reset to default state for other tests. summary_ops.set_step(None) + @test_util.run_v2_only + def testTrace_withProfiler(self): + + @def_function.function + def f(): + x = constant_op.constant(2) + y = constant_op.constant(3) + return x**y + + assert context.executing_eagerly() + logdir = self.get_temp_dir() + writer = summary_ops.create_file_writer(logdir) + summary_ops.trace_on(graph=True, profiler=True) + profiler_outdir = self.get_temp_dir() + with writer.as_default(): + f() + summary_ops.trace_export( + name='foo', step=1, profiler_outdir=profiler_outdir) + writer.close() + @test_util.run_v2_only def testGraph_graph(self): diff --git a/tensorflow/python/kernel_tests/template_mirrored_strategy_test.py b/tensorflow/python/kernel_tests/template_mirrored_strategy_test.py index df397d449c3..d18f66f74e3 100644 --- a/tensorflow/python/kernel_tests/template_mirrored_strategy_test.py +++ b/tensorflow/python/kernel_tests/template_mirrored_strategy_test.py @@ -19,6 +19,7 @@ from __future__ import print_function from tensorflow.python.distribute import distribution_strategy_context as ds_context from tensorflow.python.distribute import mirrored_strategy +from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import init_ops from tensorflow.python.ops import template @@ -29,28 +30,29 @@ from tensorflow.python.platform import test class TemplateMirroredStrategyTest(test.TestCase): - @test_util.run_deprecated_v1 @test_util.disable_tfrt("Strategy not supported yet.") def test_merge_call(self): - if not test.is_gpu_available(): - self.skipTest("No GPU available") + with ops.Graph().as_default(): + # The test is testing a v1 only function. 
+ if not test.is_gpu_available(): + self.skipTest("No GPU available") - def fn(): - var1 = variable_scope.get_variable( - "var1", shape=[], initializer=init_ops.constant_initializer(21.)) - ds_context.get_replica_context().merge_call(lambda _: ()) - var2 = variable_scope.get_variable( - "var2", shape=[], initializer=init_ops.constant_initializer(2.)) - return var1 * var2 + def fn(): + var1 = variable_scope.get_variable( + "var1", shape=[], initializer=init_ops.constant_initializer(21.)) + ds_context.get_replica_context().merge_call(lambda _: ()) + var2 = variable_scope.get_variable( + "var2", shape=[], initializer=init_ops.constant_initializer(2.)) + return var1 * var2 - temp = template.make_template("my_template", fn) + temp = template.make_template("my_template", fn) - strategy = mirrored_strategy.MirroredStrategy(["/cpu:0", "/gpu:0"]) - out = strategy.experimental_local_results( - strategy.run(temp)) + strategy = mirrored_strategy.MirroredStrategy(["/cpu:0", "/gpu:0"]) + out = strategy.experimental_local_results( + strategy.run(temp)) - self.evaluate(variables.global_variables_initializer()) - self.assertAllEqual([42., 42.], self.evaluate(out)) + self.evaluate(variables.global_variables_initializer()) + self.assertAllEqual([42., 42.], self.evaluate(out)) if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/while_v2_test.py b/tensorflow/python/kernel_tests/while_v2_test.py index f3902fb28f3..a161b4b1720 100644 --- a/tensorflow/python/kernel_tests/while_v2_test.py +++ b/tensorflow/python/kernel_tests/while_v2_test.py @@ -288,9 +288,8 @@ class WhileV2Test(test.TestCase, parameterized.TestCase): dx = _TapeFromGraphMode(x) theoretical, numerical = gradient_checker_v2.compute_gradient( target_function, [x]) - self.assertAllClose(numerical, theoretical, rtol=1e-3) - self.assertAllClose(array_ops.reshape(numerical, []), - dx, rtol=1e-3) + self.assertAllClose(numerical, theoretical, rtol=3e-3) + self.assertAllClose(array_ops.reshape(numerical, []), dx, rtol=3e-3) def testDeviceLabelsInherited(self): def _LoopBody(i, y): diff --git a/tensorflow/python/ops/batch_ops_test.py b/tensorflow/python/ops/batch_ops_test.py index 5749be96033..15c670e439d 100644 --- a/tensorflow/python/ops/batch_ops_test.py +++ b/tensorflow/python/ops/batch_ops_test.py @@ -25,12 +25,17 @@ import numpy as np from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import function +from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.framework.errors import InvalidArgumentError from tensorflow.python.ops import array_ops from tensorflow.python.ops import batch_ops from tensorflow.python.ops import gen_batch_ops +from tensorflow.python.ops import gen_functional_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import script_ops +from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -50,7 +55,7 @@ class BatchOpsTest(test.TestCase): """Tests that a single batched tensor executes together and only once.""" if context.executing_eagerly(): return - with self.cached_session() as sess: + with self.cached_session(use_gpu=True) as sess: inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) batched, index, _ = batch_ops.batch( [inp], num_batch_threads=1, max_batch_size=2, @@ -92,7 +97,7 @@ class BatchOpsTest(test.TestCase): """Test that batching with padding up to an 
allowed batch size works.""" if context.executing_eagerly(): return - with self.cached_session() as sess: + with self.cached_session(use_gpu=True) as sess: inp = array_ops.placeholder(dtype=dtypes.int32, shape=[2]) batched, index, _ = batch_ops.batch( [inp], num_batch_threads=1, max_batch_size=10, @@ -124,7 +129,7 @@ class BatchOpsTest(test.TestCase): """Tests that multiple batched tensors execute together.""" if context.executing_eagerly(): return - with self.cached_session() as sess: + with self.cached_session(use_gpu=True) as sess: inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) batched, _, _ = batch_ops.batch( @@ -165,7 +170,7 @@ class BatchOpsTest(test.TestCase): """Tests illegally feeding tensors with different dim0 sizes.""" if context.executing_eagerly(): return - with self.cached_session() as sess: + with self.cached_session(use_gpu=True) as sess: inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[2]) batched, index, _ = batch_ops.batch( @@ -181,7 +186,7 @@ class BatchOpsTest(test.TestCase): """Tests that batch and unbatch work together.""" if context.executing_eagerly(): return - with self.cached_session() as sess: + with self.cached_session(use_gpu=True) as sess: inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) batched, index, id_t = batch_ops.batch( [inp], num_batch_threads=1, max_batch_size=10, @@ -207,7 +212,7 @@ class BatchOpsTest(test.TestCase): """Tests that the batch_function decorator works.""" if context.executing_eagerly(): return - with self.cached_session() as sess: + with self.cached_session(use_gpu=True) as sess: # TODO(apassos): Removing this line causes test flakiness! Ideally should # be investigated. default_inp = array_ops.placeholder_with_default(2, shape=[]) # pylint: disable=unused-variable @@ -235,33 +240,62 @@ class BatchOpsTest(test.TestCase): """Tests that the batch_function decorator works.""" if context.executing_eagerly(): return - with self.cached_session() as sess: - captured_inp0 = array_ops.placeholder_with_default(2, shape=[]) - captured_inp1 = array_ops.placeholder_with_default(1, shape=[]) + with self.cached_session(use_gpu=True) as sess: + captured_inp0 = array_ops.placeholder_with_default(2., shape=[]) + captured_inp1 = resource_variable_ops.ResourceVariable(3.) + with ops.device("/cpu:0"): + captured_inp2 = resource_variable_ops.ResourceVariable(4.) 
@batch_ops.batch_function(1, 10, 100000) def computation(in_t): - return in_t + captured_inp0 - captured_inp1 + return in_t + captured_inp0 + captured_inp1 + captured_inp2 - inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) + inp = array_ops.placeholder(dtype=dtypes.float32, shape=[1]) result = computation(inp) thread_results = [] def worker(): thread_results.extend(sess.run([result], feed_dict={inp: [1]})) + sess.run(variables.global_variables_initializer()) worker_thread = threading.Thread(target=worker) worker_thread.start() main_results = sess.run([result], feed_dict={inp: [2]}) worker_thread.join() - self.assertEqual(thread_results[0], [2]) - self.assertEqual(main_results[0], [3]) + self.assertEqual(thread_results[0], [10]) + self.assertEqual(main_results[0], [11]) + + @test_util.disable_xla("DeviceIndex returns sentinel value with XLA") + def testBatchDecoratedGpu(self): + if context.executing_eagerly(): + return + with self.cached_session(use_gpu=True) as sess: + + @batch_ops.batch_function(1, 10, 100000) + def computation(in_t): + # index is 0 on CPU and 1 on GPU + index = gen_functional_ops.DeviceIndex(device_names=["CPU", "GPU"]) + return in_t + math_ops.cast(index, dtypes.float32) + + inp = array_ops.placeholder(dtype=dtypes.float32, shape=[1]) + result = computation(inp) + thread_results = [] + + def worker(): + thread_results.extend(sess.run([result], feed_dict={inp: [10.]})) + + worker_thread = threading.Thread(target=worker) + worker_thread.start() + main_results = sess.run([result], feed_dict={inp: [20.]}) + worker_thread.join() + self.assertEqual(thread_results[0], [10 + test_util.is_gpu_available()]) + self.assertEqual(main_results[0], [20 + test_util.is_gpu_available()]) def testBatchFunctionOp(self): """Tests that the batch_function op works.""" if context.executing_eagerly(): return - with self.cached_session() as sess: + with self.cached_session(use_gpu=True) as sess: @function.Defun(dtypes.int32) def computation(in_t): @@ -292,7 +326,7 @@ class BatchOpsTest(test.TestCase): """Tests that batch_function op works with captured input.""" if context.executing_eagerly(): return - with self.cached_session() as sess: + with self.cached_session(use_gpu=True) as sess: captured_inp0 = array_ops.placeholder_with_default(2, shape=[]) captured_inp1 = array_ops.placeholder_with_default(1, shape=[]) inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) @@ -328,7 +362,7 @@ class BatchOpsTest(test.TestCase): """Tests that batch_function op works with error in the inputs.""" if context.executing_eagerly(): return - with self.cached_session() as sess: + with self.cached_session(use_gpu=True) as sess: inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) @function.Defun(dtypes.int32, dtypes.int32) @@ -345,8 +379,9 @@ class BatchOpsTest(test.TestCase): captured_tensors=computation.captured_inputs, Tout=[o.type for o in computation.definition.signature.output_arg]) - with self.assertRaisesRegex(InvalidArgumentError, - ".*2 arguments.*but 1.*"): + with self.assertRaisesRegex( + InvalidArgumentError, + r"Function takes 2 argument\(s\) but 1 argument\(s\) were passed"): sess.run([result], feed_dict={inp: [2]}) def testBatchFunctionOpWithLargeBatchSplitted(self): @@ -354,7 +389,7 @@ class BatchOpsTest(test.TestCase): if context.executing_eagerly(): return - with self.cached_session() as sess: + with self.cached_session(use_gpu=True) as sess: @function.Defun(dtypes.int32) def computation(in_t): @@ -408,7 +443,7 @@ class BatchOpsTest(test.TestCase): """Tests that the 
batch_function decorator works.""" if context.executing_eagerly(): return - with self.cached_session() as sess: + with self.cached_session(use_gpu=True) as sess: @batch_ops.batch_function(1, 10, 100000) def computation(in_t): @@ -432,7 +467,7 @@ class BatchOpsTest(test.TestCase): """Tests that the unbatch timeout works.""" if context.executing_eagerly(): return - with self.cached_session() as sess: + with self.cached_session(use_gpu=True) as sess: inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) batched, index, id_t = batch_ops.batch( [inp], num_batch_threads=1, max_batch_size=2, diff --git a/tensorflow/python/ops/cond_v2.py b/tensorflow/python/ops/cond_v2.py index 282eda948db..92c09f7dcb7 100644 --- a/tensorflow/python/ops/cond_v2.py +++ b/tensorflow/python/ops/cond_v2.py @@ -25,6 +25,7 @@ from __future__ import print_function import collections +from tensorflow.core.framework import types_pb2 from tensorflow.python.eager import backprop_util from tensorflow.python.framework import auto_control_deps from tensorflow.python.framework import auto_control_deps_utils as acd @@ -829,14 +830,22 @@ def _copy_handle_data(external_tensors, *branch_graph_outputs): internal_handle_data.append(handle_data) else: # There is handle data, so we need to combine it. combined_shape = tensor_shape.TensorShape(None) + combined_dtype = None for handle_data in internal_handle_data: handle_shape = tensor_shape.TensorShape( handle_data.shape_and_type[0].shape) combined_shape = combined_shape.most_specific_compatible_shape( handle_shape) + if combined_dtype is None: + combined_dtype = handle_data.shape_and_type[0].dtype + elif handle_data.shape_and_type[0].dtype != combined_dtype: + # Variants from different branches have different dtypes. The + # combined variant has no static dtype. + combined_dtype = types_pb2.DT_INVALID combined_handle_data = internal_handle_data[0] combined_handle_data.shape_and_type[0].shape.CopyFrom( combined_shape.as_proto()) + combined_handle_data.shape_and_type[0].dtype = combined_dtype handle_data_util.set_handle_data(external, combined_handle_data) diff --git a/tensorflow/python/ops/gradient_checker_v2.py b/tensorflow/python/ops/gradient_checker_v2.py index 3ca0903c80c..ce5a4f76678 100644 --- a/tensorflow/python/ops/gradient_checker_v2.py +++ b/tensorflow/python/ops/gradient_checker_v2.py @@ -292,7 +292,7 @@ def _compute_gradient_list(f, xs, delta): @tf_export("test.compute_gradient", v1=[]) -def compute_gradient(f, x, delta=1e-3): +def compute_gradient(f, x, delta=None): """Computes the theoretical and numeric Jacobian of `f`. With y = f(x), computes the theoretical and numeric Jacobian dy/dx. @@ -329,6 +329,12 @@ def compute_gradient(f, x, delta=1e-3): raise ValueError( "`x` must be a list or tuple of values convertible to a Tensor " "(arguments to `f`), not a %s" % type(x)) + if delta is None: + # By default, we use a step size for the central finite difference + # approximation that is exactly representable as a binary floating + # point number, since this reduces the amount of noise due to rounding + # in the approximation of some functions. 
+ delta = 1.0 / 1024 return _compute_gradient_list(f, x, delta) diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index 28a5a9f8509..ac4ba701713 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -5137,6 +5137,7 @@ def non_max_suppression_padded(boxes, selected_indices = tf.slice( selected_indices_padded, tf.constant([0]), num_valid) selected_boxes = tf.gather(boxes, selected_indices) + ``` Args: boxes: a tensor of rank 2 or higher with a shape of [..., num_boxes, 4]. diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py index 145c2b0195c..63773ee0f95 100644 --- a/tensorflow/python/ops/lookup_ops.py +++ b/tensorflow/python/ops/lookup_ops.py @@ -458,7 +458,7 @@ class DatasetInitializer(TableInitializerBase): """ def __init__(self, dataset): - """Creates a table initializser from a `tf.data.Dataset`. + """Creates a table initializer from a `tf.data.Dataset`. Args: dataset: A `tf.data.Dataset` object that produces tuples of scalars. The @@ -633,8 +633,7 @@ class TextFileInitializer(TableInitializerBase): >>> init = tf.lookup.TextFileInitializer( ... filename=f.name, ... key_dtype=tf.string, key_index=tf.lookup.TextFileIndex.WHOLE_LINE, - ... value_dtype=tf.int64, value_index=tf.lookup.TextFileIndex.LINE_NUMBER, - ... delimiter=" ") + ... value_dtype=tf.int64, value_index=tf.lookup.TextFileIndex.LINE_NUMBER) >>> table = tf.lookup.StaticHashTable(init, -1) >>> table.lookup(tf.constant('palmer 30')).numpy() 2 diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py index f641687e990..b7650846d33 100644 --- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py +++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py @@ -45,6 +45,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_v2_toggles from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import functional_ops +from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import gen_list_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import gradient_checker_v2 @@ -1159,6 +1160,21 @@ class TensorListTest(PForTestCase): self._test_loop_fn(loop_fn, 2) +class OptionalTest(PForTestCase): + + def test_optional_from_value(self): + + def loop_fn(i): + o = gen_dataset_ops.optional_from_value( + [i, i + 1, constant_op.constant(3)]) + gen_dataset_ops.optional_none() + return gen_dataset_ops.optional_get_value( + o, [dtypes.int32, dtypes.int32, dtypes.int32], + [[], [], []]) + + self._test_loop_fn(loop_fn, 2) + + class StackTest(PForTestCase): @test_util.run_v1_only("b/122612051") diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py index 2b02d5e30d3..c9431fa8fa7 100644 --- a/tensorflow/python/ops/parallel_for/pfor.py +++ b/tensorflow/python/ops/parallel_for/pfor.py @@ -46,6 +46,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import custom_gradient from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import gen_image_ops from tensorflow.python.ops import gen_linalg_ops from tensorflow.python.ops import gen_list_ops @@ -83,22 +84,45 @@ def _variant_handle_data(t): handle_data = 
resource_variable_ops.get_eager_safe_handle_data(t)
   if not handle_data.is_set:
     return None
-  if len(handle_data.shape_and_type) != 1:
-    raise ValueError("Expected handle data of length 1, got {!r} of length {}"
-                     .format(handle_data, len(handle_data.shape_and_type)))
-  return handle_data.shape_and_type[0]
+  return handle_data.shape_and_type


-def _is_tensor_list(t):
-  """True if `t` is a TensorList, False if it isn't, None if unknown."""
+def _is_variant_with_internal_stacking(t):
+  """Identifies variant tensors which pfor always maintains as scalars.
+
+  For these, the pfor tensor is recorded as "stacked" if the contents of the
+  variant tensor (e.g. the elements of a TensorList) are all stacked.
+
+  Args:
+    t: A tensor to identify.
+  Returns:
+    True if `t` is a TensorList/Optional, False if it is not, None if unknown.
+  """
   if t.dtype != dtypes.variant:
     return False
-  shape_and_type = _variant_handle_data(t)
-  if shape_and_type is None:
-    # TODO(b/169968286): Identify all variant tensors (e.g. optionals) and we
-    # can make this an error instead of assuming TensorLists have handle data.
-    return None  # Presumed not a TensorList
-  return shape_and_type.specialized_type == types_pb2.ST_TENSOR_LIST
+  shapes_and_types = _variant_handle_data(t)
+  if shapes_and_types is None or not shapes_and_types:
+    # TODO(b/169968286): Identify all variant tensors (e.g. maps) and we can
+    # make this an error instead of assuming TensorLists have handle data.
+    return None  # Presumed not a TensorList/Optional
+  return (shapes_and_types[0].specialized_type == types_pb2.ST_TENSOR_LIST or
+          shapes_and_types[0].specialized_type == types_pb2.ST_OPTIONAL)
+
+
+def _parse_variant_shapes_and_types(t):
+  """Extracts shape and dtype information from a variant tensor `t`."""
+  shapes_and_types = _variant_handle_data(t)
+  if shapes_and_types is None or not shapes_and_types:
+    raise ValueError("Required handle data not set for {!r}".format(t))
+  if shapes_and_types[0].specialized_type == types_pb2.ST_TENSOR_LIST:
+    return shapes_and_types
+  else:
+    if shapes_and_types[0].specialized_type != types_pb2.ST_INVALID:
+      return shapes_and_types
+    else:
+      raise ValueError(
+          "Attempted to stack a variant-dtype tensor with no type set ({!r})"
+          .format(t))


 def _stack(t, length):
@@ -109,23 +133,19 @@ def _stack(t, length):
 # suitable since operations on stacked handles may expect a vectorized version
 # of the variant.
if t.dtype == dtypes.variant: - shape_and_type = _variant_handle_data(t) - if shape_and_type is None: - raise ValueError("Required handle data not set for {!r}".format(t)) - if shape_and_type.specialized_type == types_pb2.ST_TENSOR_LIST: + shapes_and_types = _parse_variant_shapes_and_types(t) + if shapes_and_types[0].specialized_type == types_pb2.ST_TENSOR_LIST: + if len(shapes_and_types) != 1: + raise ValueError( + "Expected handle data of length 1, got {!r} of length {}" + .format(shapes_and_types, len(shapes_and_types))) return wrap( - _stack_tensor_list(t, shape_and_type.dtype, length), + _stack_tensor_list(t, shapes_and_types[0].dtype, length), True) else: - if shape_and_type.specialized_type != types_pb2.ST_INVALID: - raise ValueError( - ("Attempted to stack an unhandled variant-dtype tensor of " - "type {!r} ({!r})").format( - shape_and_type.specialized_type, t)) - else: - raise ValueError( - "Attempted to stack a variant-dtype tensor with no type set ({!r})" - .format(t)) + raise ValueError( + ("Attempted to stack an unhandled variant-dtype tensor of " + "type {!r} ({!r})").format(shapes_and_types[0].specialized_type, t)) ones = array_ops.ones_like(array_ops.shape(t)) ones = array_ops.reshape(ones, [-1]) length = array_ops.reshape(length, [-1]) @@ -1629,7 +1649,7 @@ class PFor(object): else: batch_dim = tensor_shape.TensorShape(loop_len) output_shape = batch_dim.concatenate(output_shape) - if _is_tensor_list(new_output.t): + if _is_variant_with_internal_stacking(new_output.t): new_output.t.set_shape([]) else: new_output.t.set_shape(output_shape) @@ -3602,7 +3622,7 @@ def _stack_tensor_list_shape(shape, first_dim): def _tile_variant_with_length(t, length): """stacks `t` `length` times.""" - if _is_tensor_list(t): + if _is_variant_with_internal_stacking(t): # The content of TensorLists is vectorized, not the variant itself. return t original_tensor = t @@ -3622,16 +3642,41 @@ def _tile_variant(t, pfor_input): def _untile_variant(t): - if _is_tensor_list(t): + if _is_variant_with_internal_stacking(t): # The content of TensorLists is vectorized, not the variant itself. if not t.shape.is_compatible_with([]): raise AssertionError( - "Unexpectedly saw a TensorList with non-scalar shape: {!r}" - .format(t)) + ("Unexpectedly saw a vectorized variant (e.g. 
TensorList) with " + "non-scalar shape: {!r}").format(t)) return t return array_ops.gather(t, 0) +@RegisterPFor("OptionalFromValue") +def _convert_optional_from_value(pfor_input): + pfor_input.stack_inputs() + return wrap( + gen_dataset_ops.optional_from_value([x.t for x in pfor_input.inputs]), + True) + + +@RegisterPFor("OptionalGetValue") +def _convert_optional_get_value(pfor_input): + handle = pfor_input.stacked_input(0) + output_types = pfor_input.get_attr("output_types") + original_output_shapes = pfor_input.get_attr("output_shapes") + output_shapes = [] + for shape in original_output_shapes: + shape = tensor_shape.TensorShape(shape) + loop_len_shape = tensor_shape.TensorShape( + [tensor_util.constant_value(pfor_input.pfor.loop_len_vector)]) + shape = loop_len_shape.concatenate(shape) + output_shapes.append(shape.as_proto()) + results = gen_dataset_ops.optional_get_value(handle, output_types, + output_shapes) + return [wrap(t, True) for t in results] + + @RegisterPFor("TensorListReserve") def _convert_tensor_list_reserve(pfor_input): element_shape = pfor_input.unstacked_input(0) @@ -4275,7 +4320,7 @@ class WhileV2(object): shape = shape.merge_with(output_shapes[i]) pfor_input = self._pfor_input.input(i) if pfor_input.is_stacked: - if _is_tensor_list(pfor_input.t): + if _is_variant_with_internal_stacking(pfor_input.t): shape = tensor_shape.TensorShape([]).concatenate(shape) else: shape = tensor_shape.TensorShape([None]).concatenate(shape) diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py index 53f6a2b0492..8575cdf3da5 100644 --- a/tensorflow/python/ops/script_ops.py +++ b/tensorflow/python/ops/script_ops.py @@ -545,7 +545,31 @@ def py_func_common(func, inp, Tout, stateful=True, name=None): `tf.compat.v1.py_func()` and you must pin the created operation to a device in that server (e.g. using `with tf.device():`). + + Note: It produces tensors of unknown shape and rank as shape inference + does not work on arbitrary Python code. + If you need the shape, you need to set it based on statically + available information. + + E.g. + ```python + import tensorflow as tf + import numpy as np + def make_synthetic_data(i): + return np.cast[np.uint8](i) * np.ones([20,256,256,3], + dtype=np.float32) / 10. + + def preprocess_fn(i): + ones = tf.py_function(make_synthetic_data,[i],tf.float32) + ones.set_shape(tf.TensorShape([None, None, None, None])) + ones = tf.image.resize(ones, [224,224]) + return ones + + ds = tf.data.Dataset.range(10) + ds = ds.map(preprocess_fn) + ``` + Args: func: A Python function, which accepts `ndarray` objects as arguments and returns a list of `ndarray` objects (or a single `ndarray`). This function diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py index 3e3751e3ca6..5f6b6453e15 100644 --- a/tensorflow/python/ops/sparse_ops.py +++ b/tensorflow/python/ops/sparse_ops.py @@ -111,6 +111,15 @@ def from_dense(tensor, name=None): Only elements not equal to zero will be present in the result. The resulting `SparseTensor` has the same dtype and shape as the input. + >>> sp = tf.sparse.from_dense([0, 0, 3, 0, 1]) + >>> sp.shape.as_list() + [5] + >>> sp.values.numpy() + array([3, 1], dtype=int32) + >>> sp.indices.numpy() + array([[2], + [4]]) + Args: tensor: A dense `Tensor` to be converted to a `SparseTensor`. name: Optional name for the op. @@ -1602,23 +1611,24 @@ def sparse_tensor_to_dense(sp_input, name=None): """Converts a `SparseTensor` into a dense tensor. 
- This op is a convenience wrapper around `sparse_to_dense` for `SparseTensor`s. + For this sparse tensor with three non-empty values: - For example, if `sp_input` has shape `[3, 5]` and non-empty string values: + >>> sp_input = tf.SparseTensor( + ... dense_shape=[3, 5], + ... values=[7, 8, 9], + ... indices =[[0, 1], + ... [0, 3], + ... [2, 0]]) - [0, 1]: a - [0, 3]: b - [2, 0]: c + The output will be a dense `[3, 5]` tensor with values: - and `default_value` is `x`, then the output will be a dense `[3, 5]` - string tensor with values: + >>> tf.sparse.to_dense(sp_input).numpy() + array([[0, 7, 0, 8, 0], + [0, 0, 0, 0, 0], + [9, 0, 0, 0, 0]], dtype=int32) - [[x a x b x] - [x x x x x] - [c x x x x]] - - Indices must be without repeats. This is only - tested if `validate_indices` is `True`. + Note: Indices must be without repeats. This is only tested if + `validate_indices` is `True`. Args: sp_input: The input `SparseTensor`. diff --git a/tensorflow/python/ops/structured/BUILD b/tensorflow/python/ops/structured/BUILD index 33834f0e914..81e8b37dc7d 100644 --- a/tensorflow/python/ops/structured/BUILD +++ b/tensorflow/python/ops/structured/BUILD @@ -18,17 +18,60 @@ py_library( srcs_version = "PY2AND3", tags = ["nofixdeps"], deps = [ + ":structured_array_ops", ":structured_tensor", ], ) py_library( name = "structured_tensor", - srcs = ["structured_tensor.py"], + srcs = [ + "structured_array_ops.py", + "structured_tensor.py", + ], deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", + "//tensorflow/python:composite_tensor", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:tensor_shape", + "//tensorflow/python:tensor_spec", + "//tensorflow/python:type_spec", + "//tensorflow/python:util", + "//tensorflow/python/ops/ragged:ragged_factory_ops", "//tensorflow/python/ops/ragged:ragged_tensor", + "//tensorflow/python/ops/ragged:row_partition", + "//third_party/py/numpy", + ], +) + +py_library( + name = "structured_array_ops", + srcs = [ + "structured_array_ops.py", + ], + deps = [ + ":structured_tensor", + "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", + "//tensorflow/python:composite_tensor", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:tensor_spec", + "//tensorflow/python:type_spec", + "//tensorflow/python:util", + "//tensorflow/python/ops/ragged:ragged_factory_ops", + "//tensorflow/python/ops/ragged:ragged_tensor", + "//tensorflow/python/ops/ragged:row_partition", + "//third_party/py/numpy", ], ) @@ -37,13 +80,23 @@ py_test( srcs = ["structured_tensor_test.py"], python_version = "PY3", deps = [ + ":structured_array_ops", ":structured_tensor", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", "//tensorflow/python:sparse_tensor", "//tensorflow/python:tensor_shape", + "//tensorflow/python:tensor_spec", + "//tensorflow/python/eager:context", "//tensorflow/python/ops/ragged:ragged_factory_ops", "//tensorflow/python/ops/ragged:ragged_tensor", + "//tensorflow/python/ops/ragged:row_partition", + 
"//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/python/ops/structured/structured_array_ops.py b/tensorflow/python/ops/structured/structured_array_ops.py new file mode 100644 index 00000000000..dca8084575e --- /dev/null +++ b/tensorflow/python/ops/structured/structured_array_ops.py @@ -0,0 +1,157 @@ +# Lint as python3 +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""StructuredTensor array ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops.ragged.row_partition import RowPartition +from tensorflow.python.ops.structured.structured_tensor import StructuredTensor +from tensorflow.python.util import deprecation +from tensorflow.python.util import dispatch + + +@dispatch.dispatch_for_types(array_ops.expand_dims, StructuredTensor) +@deprecation.deprecated_args(None, 'Use the `axis` argument instead', 'dim') +def expand_dims(input, axis=None, name=None, dim=None): # pylint: disable=redefined-builtin + """Creates a StructuredTensor with a length 1 axis inserted at index `axis`. + + This is an implementation of tf.expand_dims for StructuredTensor. Note + that the `axis` must be less than or equal to rank. + + >>> st = StructuredTensor.from_pyval([[{"x": 1}, {"x": 2}], [{"x": 3}]]) + >>> tf.expand_dims(st, 0).to_pyval() + [[[{'x': 1}, {'x': 2}], [{'x': 3}]]] + >>> tf.expand_dims(st, 1).to_pyval() + [[[{'x': 1}, {'x': 2}]], [[{'x': 3}]]] + >>> tf.expand_dims(st, 2).to_pyval() + [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]] + >>> tf.expand_dims(st, -1).to_pyval() # -1 is the same as 2 + [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]] + + Args: + input: the original StructuredTensor. + axis: the axis to insert the dimension: `-(rank + 1) <= axis <= rank` + name: the name of the op. + dim: deprecated: use axis. + + Returns: + a new structured tensor with larger rank. + + Raises: + an error if `axis < -(rank + 1)` or `rank < axis`. + """ + axis = deprecation.deprecated_argument_lookup('axis', axis, 'dim', dim) + return _expand_dims_impl(input, axis, name=name) + + +@dispatch.dispatch_for_types(array_ops.expand_dims_v2, StructuredTensor) +def expand_dims_v2(input, axis, name=None): # pylint: disable=redefined-builtin + """Creates a StructuredTensor with a length 1 axis inserted at index `axis`. + + This is an implementation of tf.expand_dims for StructuredTensor. Note + that the `axis` must be less than or equal to rank. 
+ + >>> st = StructuredTensor.from_pyval([[{"x": 1}, {"x": 2}], [{"x": 3}]]) + >>> tf.expand_dims(st, 0).to_pyval() + [[[{'x': 1}, {'x': 2}], [{'x': 3}]]] + >>> tf.expand_dims(st, 1).to_pyval() + [[[{'x': 1}, {'x': 2}]], [[{'x': 3}]]] + >>> tf.expand_dims(st, 2).to_pyval() + [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]] + >>> tf.expand_dims(st, -1).to_pyval() # -1 is the same as 2 + [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]] + + Args: + input: the original StructuredTensor. + axis: the axis to insert the dimension: `-(rank + 1) <= axis <= rank` + name: the name of the op. + + Returns: + a new structured tensor with larger rank. + + Raises: + an error if `axis < -(rank + 1)` or `rank < axis`. + """ + return _expand_dims_impl(input, axis, name=name) + + +def _expand_dims_impl(st, axis, name=None): # pylint: disable=redefined-builtin + """Creates a StructuredTensor with a length 1 axis inserted at index `axis`. + + This is an implementation of tf.expand_dims for StructuredTensor. Note + that the `axis` must be less than or equal to rank. + + >>> st = StructuredTensor.from_pyval([[{"x": 1}, {"x": 2}], [{"x": 3}]]) + >>> tf.expand_dims(st, 0).to_pyval() + [[[{'x': 1}, {'x': 2}], [{'x': 3}]]] + >>> tf.expand_dims(st, 1).to_pyval() + [[[{'x': 1}, {'x': 2}]], [[{'x': 3}]]] + >>> tf.expand_dims(st, 2).to_pyval() + [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]] + >>> tf.expand_dims(st, -1).to_pyval() # -1 is the same as 2 + [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]] + + Args: + st: the original StructuredTensor. + axis: the axis to insert the dimension: `-(rank + 1) <= axis <= rank` + name: the name of the op. + + Returns: + a new structured tensor with larger rank. + + Raises: + an error if `axis < -(rank + 1)` or `rank < axis`. + """ + axis = array_ops.get_positive_axis( + axis, st.rank + 1, axis_name='axis', ndims_name='rank(st)') + with ops.name_scope(name, 'ExpandDims', [st, axis]): + new_fields = { + k: array_ops.expand_dims(v, axis) + for (k, v) in st._fields.items() + } + new_shape = st.shape[:axis] + (1,) + st.shape[axis:] + new_row_partitions = _expand_st_row_partitions(st, axis) + new_nrows = st.nrows() if (axis > 0) else 1 + return StructuredTensor.from_fields( + new_fields, + shape=new_shape, + row_partitions=new_row_partitions, + nrows=new_nrows) + + +def _expand_st_row_partitions(st, axis): + """Create the row_partitions for expand_dims.""" + if axis == 0: + if st.shape.rank == 0: + return () + nvals = st.nrows() + new_partition = RowPartition.from_uniform_row_length( + nvals, nvals, nrows=1, validate=False) + return (new_partition,) + st.row_partitions + elif axis == st.rank: + nvals = ( + st.row_partitions[axis - 2].nvals() if (axis - 2 >= 0) else st.nrows()) + return st.row_partitions + (RowPartition.from_uniform_row_length( + 1, nvals, nrows=nvals, validate=False),) + else: + nvals = ( + st.row_partitions[axis - 1].nrows() if (axis - 1 >= 0) else st.nrows()) + return st.row_partitions[:axis - 1] + (RowPartition.from_uniform_row_length( + 1, nvals, nrows=nvals, validate=False),) + st.row_partitions[axis - 1:] diff --git a/tensorflow/python/ops/structured/structured_tensor_test.py b/tensorflow/python/ops/structured/structured_tensor_test.py index 28acfbb3304..9e064bb9dcd 100644 --- a/tensorflow/python/ops/structured/structured_tensor_test.py +++ b/tensorflow/python/ops/structured/structured_tensor_test.py @@ -36,6 +36,10 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.ops.ragged import ragged_tensor from 
tensorflow.python.ops.ragged import row_partition + +# TODO(b/173144447): remove when structured_array_ops is included in init. +from tensorflow.python.ops.structured import structured_array_ops # pylint: disable=unused-import + from tensorflow.python.ops.structured import structured_tensor from tensorflow.python.ops.structured.structured_tensor import StructuredTensor from tensorflow.python.platform import googletest @@ -500,8 +504,6 @@ class StructuredTensorTest(test_util.TensorFlowTestCase, self.assertAllEqual(struct4.field_value("s"), struct2) def testPartitionOuterDims(self): - if not context.executing_eagerly(): - return # TESTING a = dict(x=1, y=[1, 2]) b = dict(x=2, y=[3, 4]) c = dict(x=3, y=[5, 6]) @@ -977,6 +979,114 @@ class StructuredTensorTest(test_util.TensorFlowTestCase, r"or equal to inner_axis \(1\)"): st.merge_dims(2, 1) + @parameterized.named_parameters([ + dict( + testcase_name="0D_0", + st={"x": 1}, + axis=0, + expected=[{"x": 1}]), + dict( + testcase_name="0D_minus_1", + st={"x": 1}, + axis=-1, + expected=[{"x": 1}]), + dict( + testcase_name="1D_0", + st=[{"x": [1, 3]}, {"x": [2, 7, 9]}], + axis=0, + expected=[[{"x": [1, 3]}, {"x": [2, 7, 9]}]]), + dict( + testcase_name="1D_1", + st=[{"x": [1]}, {"x": [2, 10]}], + axis=1, + expected=[[{"x": [1]}], [{"x": [2, 10]}]]), + dict( + testcase_name="2D_0", + st=[[{"x": [1]}, {"x": [2]}], [{"x": [3, 4]}]], + axis=0, + expected=[[[{"x": [1]}, {"x": [2]}], [{"x": [3, 4]}]]]), + dict( + testcase_name="2D_1", + st=[[{"x": 1}, {"x": 2}], [{"x": 3}]], + axis=1, + expected=[[[{"x": 1}, {"x": 2}]], [[{"x": 3}]]]), + dict( + testcase_name="2D_2", + st=[[{"x": [1]}, {"x": [2]}], [{"x": [3, 4]}]], + axis=2, + expected=[[[{"x": [1]}], [{"x": [2]}]], [[{"x": [3, 4]}]]]), + dict( + testcase_name="3D_0", + st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]], + axis=0, + expected=[[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], + [[{"x": [4, 5]}]]]]), + dict( + testcase_name="3D_minus_4", + st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]], + axis=-4, # same as zero + expected=[[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], + [[{"x": [4, 5]}]]]]), + dict( + testcase_name="3D_1", + st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]], + axis=1, + expected=[[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]]], + [[[{"x": [4, 5]}]]]]), + dict( + testcase_name="3D_minus_3", + st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]], + axis=-3, # same as 1 + expected=[[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]]], + [[[{"x": [4, 5]}]]]]), + dict( + testcase_name="3D_2", + st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]], + axis=2, + expected=[[[[{"x": [1]}, {"x": [2]}]], [[{"x": [3]}]]], + [[[{"x": [4, 5]}]]]]), + dict( + testcase_name="3D_minus_2", + st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]], + axis=-2, # same as 2 + expected=[[[[{"x": [1]}, {"x": [2]}]], [[{"x": [3]}]]], + [[[{"x": [4, 5]}]]]]), + dict( + testcase_name="3D_3", + st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]], + axis=3, + expected=[[[[{"x": [1]}], [{"x": [2]}]], [[{"x": [3]}]]], + [[[{"x": [4, 5]}]]]]), + ]) # pyformat: disable + def testExpandDims(self, st, axis, expected): + st = StructuredTensor.from_pyval(st) + result = array_ops.expand_dims(st, axis) + self.assertAllEqual(result, expected) + + def testExpandDimsAxisTooBig(self): + st = [[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]] + st = StructuredTensor.from_pyval(st) + with self.assertRaisesRegex(ValueError, + "axis=4 out of 
bounds: expected -4<=axis<4"): + array_ops.expand_dims(st, 4) + + def testExpandDimsAxisTooSmall(self): + st = [[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]] + st = StructuredTensor.from_pyval(st) + with self.assertRaisesRegex(ValueError, + "axis=-5 out of bounds: expected -4<=axis<4"): + array_ops.expand_dims(st, -5) + + def testExpandDimsScalar(self): + # Note that if we expand_dims for the final dimension and there are scalar + # fields, then the shape is (2, None, None, 1), whereas if it is constructed + # from pyval it is (2, None, None, None). + st = [[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]]] + st = StructuredTensor.from_pyval(st) + result = array_ops.expand_dims(st, 3) + expected_shape = tensor_shape.TensorShape([2, None, None, 1]) + self.assertEqual(repr(expected_shape), repr(result.shape)) + def testTupleFieldValue(self): st = StructuredTensor.from_pyval({"a": 5, "b": {"c": [1, 2, 3]}}) self.assertAllEqual(st.field_value(("a",)), 5) diff --git a/tensorflow/python/ops/summary_ops_v2.py b/tensorflow/python/ops/summary_ops_v2.py index 6be1ea04968..bd344fccca9 100644 --- a/tensorflow/python/ops/summary_ops_v2.py +++ b/tensorflow/python/ops/summary_ops_v2.py @@ -1385,4 +1385,7 @@ def trace_off(): context.context().disable_run_metadata() if profiler: - _profiler.stop() + try: + _profiler.stop() + except _profiler.ProfilerNotRunningError: + pass diff --git a/tensorflow/python/saved_model/README.md b/tensorflow/python/saved_model/README.md index fe69f3beb01..a5c0aa894f6 100644 --- a/tensorflow/python/saved_model/README.md +++ b/tensorflow/python/saved_model/README.md @@ -3,49 +3,37 @@ [TOC] ## Overview -This document describes SavedModel, the universal serialization format for + +SavedModel is the universal serialization format for [TensorFlow](https://www.tensorflow.org/) models. -SavedModel provides a language-neutral format to save machine-learned models +SavedModel provides a language-neutral format to save machine-learning models that is recoverable and hermetic. It enables higher-level systems and tools to produce, consume and transform TensorFlow models. -## Features +## Guides +* [Using the SavedModel Format](https://www.tensorflow.org/guide/saved_model) +* [Save and load Keras models](https://www.tensorflow.org/guide/keras/save_and_serialize) +* [Save and load with checkpointing in Keras](https://www.tensorflow.org/tutorials/keras/save_and_load) +* [Training checkpoints](https://www.tensorflow.org/guide/checkpoint) +* [Save and load a model using a distribution strategy](https://www.tensorflow.org/tutorials/distribute/save_and_load) -The following is a summary of the features in SavedModel: -* Multiple graphs sharing a single set of variables and assets can be added to a - single SavedModel. Each graph is associated with a specific set of tags to - allow identification during a load or restore operation. -* Support for `SignatureDefs` - * Graphs that are used for inference tasks typically have a set of inputs - and outputs. This is called a `Signature`. - * SavedModel uses [SignatureDefs](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/protobuf/meta_graph.proto) - to allow generic support for signatures that may need to be saved with the graphs. - * For commonly used SignatureDefs in the context of TensorFlow Serving, - please see documentation [here](https://github.com/tensorflow/serving/blob/master/tensorflow_serving/g3doc/signature_defs.md). -* Support for `Assets`. 
- * For cases where ops depend on external files for initialization, such as - vocabularies, SavedModel supports this via `assets`. - * Assets are copied to the SavedModel location and can be read when loading - a specific meta graph def. -* Support to clear devices before generating the SavedModel. +## [Public API](https://www.tensorflow.org/api_docs/python/tf/saved_model) +* [`tf.saved_model.save`](https://www.tensorflow.org/api_docs/python/tf/saved_model/save) +* [`tf.saved_model.load`](https://www.tensorflow.org/api_docs/python/tf/saved_model/load) +* [`tf.saved_model.SaveOptions`](https://www.tensorflow.org/api_docs/python/tf/saved_model/SaveOptions) +* [`tf.saved_model.LoadOptions`](https://www.tensorflow.org/api_docs/python/tf/saved_model/LoadOptions) +* [`tf.saved_model.Asset`](https://www.tensorflow.org/api_docs/python/tf/saved_model/Asset) +* [`tf.saved_model.contains_saved_model`](https://www.tensorflow.org/api_docs/python/tf/saved_model/contains_saved_model) -The following is a summary of features that are NOT supported in SavedModel. -Higher-level frameworks and tools that use SavedModel may provide these. +### Related Modules and Functions +* [`tf.keras.models.save_model`](https://www.tensorflow.org/api_docs/python/tf/keras/models/save_model) +* [`tf.keras.models.load_model`](https://www.tensorflow.org/api_docs/python/tf/keras/models/load_model) +* [`tf.train.Checkpoint`](https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint) -* Implicit versioning. -* Garbage collection. -* Atomic writes to the SavedModel location. -## Background -SavedModel manages and builds upon existing TensorFlow primitives such as -`TensorFlow Saver` and `MetaGraphDef`. Specifically, SavedModel wraps a [TensorFlow Saver](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python/training/saver.py). -The Saver is primarily used to generate the variable checkpoints. SavedModel -will replace the existing [TensorFlow Inference Model Format](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/session_bundle/README.md) -as the canonical way to export TensorFlow graphs for serving. - -## Components +## The SavedModel Format A SavedModel directory has the following structure: ``` @@ -57,72 +45,23 @@ variables/ saved_model.pb ``` -* SavedModel protocol buffer - * `saved_model.pb` or `saved_model.pbtxt` - * Includes the graph definitions as `MetaGraphDef` protocol buffers. -* Assets - * Subfolder called `assets`. - * Contains auxiliary files such as vocabularies, etc. -* Extra assets - * Subfolder where higher-level libraries and users can add their own assets - that co-exist with the model, but are not loaded by the graph. - * This subfolder is not managed by the SavedModel libraries. -* Variables - * Subfolder called `variables`. - * Includes output from the [TensorFlow Saver](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python/training/saver.py). - * `variables.data-?????-of-?????` - * `variables.index` +* SavedModel protocol buffer + * [`saved_model.pb`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/protobuf/saved_model.proto) + or `saved_model.pbtxt` + * Includes the graph definitions as `MetaGraphDef` protocol buffers. +* Assets + * Subfolder called `assets`. + * Contains auxiliary files such as vocabularies, etc. +* Extra assets + * Subfolder where higher-level libraries and users can add their own + assets that co-exist with the model, but are not loaded by the graph. 
+ * This subfolder is not managed by the SavedModel libraries. +* Variables + * Subfolder called `variables`. + * `variables.data-?????-of-?????` + * `variables.index` -## APIs -The APIs for building and loading a SavedModel are described in this section. - -### Builder -The SavedModel [builder](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/builder.py) -is implemented in Python. - -The `SavedModelBuilder` class provides functionality to save multiple meta graph -defs, associated variables and assets. - -To build a SavedModel, the first meta graph must be saved with variables. -Subsequent meta graphs will simply be saved with their graph definitions. If -assets need to be saved and written or copied to disk, they can be provided -when the meta graph def is added. If multiple meta graph defs are associated -with an asset of the same name, only the first version is retained. - -#### Tags -Each meta graph added to the SavedModel must be annotated with user specified -tags, which reflect the meta graph capabilities or use-cases. -More specifically, these tags typically annotate a meta graph with its -functionality (e.g. serving or training), and possibly hardware specific aspects -such as GPU. -In the SavedModel, the meta graph def whose tag-set exactly matches those -specified in the loader API, will be the one loaded by the loader. -If no meta graph def is found matching the specified tags, an error is returned. -For example, a loader with a requirement to serve on GPU hardware would be able -to load only meta graph annotated with tags='serve,gpu' by specifying this set -of tags in tensorflow::LoadSavedModel(...). - - -#### Usage -The typical usage of `builder` is as follows: - -~~~python -export_dir = ... -... -builder = tf.saved_model.builder.SavedModelBuilder(export_dir) -with tf.Session(graph=tf.Graph()) as sess: - ... - builder.add_meta_graph_and_variables(sess, - [tf.saved_model.tag_constants.TRAINING], - signature_def_map=foo_signatures, - assets_collection=foo_assets) -... -with tf.Session(graph=tf.Graph()) as sess: - ... - builder.add_meta_graph(["bar-tag", "baz-tag"]) -... -builder.save() -~~~ +--- #### Stripping Default valued attributes The SavedModelBuilder class allows users to control whether default-valued @@ -152,60 +91,3 @@ models regenerated with newer training binaries. TIP: If you care about forward compatibility, then set `strip_default_attrs` to `True` while using `SavedModelBuilder.add_meta_graph_and_variables` and `SavedModelBuilder.add_meta_graph`. - -### Loader -The SavedModel loader is implemented in C++ and Python. - -#### Python -The Python version of the SavedModel [loader](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/loader.py) -provides load and restore capability for a SavedModel. The `load` operation -requires the session in which to restore the graph definition and variables, the -tags used to identify the meta graph def to load and the location of the -SavedModel. Upon a load, the subset of variables and assets supplied as part of -the specific meta graph def, will be restored into the supplied session. - -~~~python -export_dir = ... -... -with tf.Session(graph=tf.Graph()) as sess: - tf.saved_model.loader.load(sess, [tag_constants.TRAINING], export_dir) - ... 
-~~~ - -#### C++ -The C++ version of the SavedModel [loader](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/saved_model/loader.h) -provides an API to load a SavedModel from a path, while allowing -`SessionOptions` and `RunOptions`. Similar to the Python version, the C++ -version requires the tags associated with the graph to be loaded, to be -specified. The loaded version of SavedModel is referred to as `SavedModelBundle` -and contains the meta graph def and the session within which it is loaded. - -~~~c++ -const string export_dir = ... -SavedModelBundle bundle; -... -LoadSavedModel(session_options, run_options, export_dir, {kSavedModelTagTrain}, - &bundle); -~~~ - -### Constants -SavedModel offers the flexibility to build and load TensorFlow graphs for a -variety of use-cases. For the set of most common expected use-cases, -SavedModel's APIs provide a set of constants in Python and C++ that are easy to -reuse and share across tools consistently. - -#### Tag constants -Sets of tags can be used to uniquely identify a `MetaGraphDef` saved in a -SavedModel. A subset of commonly used tags is specified in: - -* [Python](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/tag_constants.py) -* [C++](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/saved_model/tag_constants.h). - -#### Signature constants -SignatureDefs are used to define the signature of a computation supported in a -TensorFlow graph. Commonly used input keys, output keys and method names are -defined in: - -* [Python](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/signature_constants.py) -* [C++](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/saved_model/signature_constants.h). - diff --git a/tensorflow/python/tpu/session_support.py b/tensorflow/python/tpu/session_support.py index 6e85c417f54..24da6efb4a4 100644 --- a/tensorflow/python/tpu/session_support.py +++ b/tensorflow/python/tpu/session_support.py @@ -106,7 +106,7 @@ class WorkerHeartbeatManager(object): self._session.run(self._ops, {self._request_placeholder: message.SerializeToString()}) - def ping(self, request=None, timeout_in_ms=5000): + def ping(self, request=None, timeout_in_ms=60000): """Ping all workers, returning the parsed status results.""" if request is None: request = event_pb2.WorkerHeartbeatRequest() diff --git a/tensorflow/security/fuzzing/BUILD b/tensorflow/security/fuzzing/BUILD index a05b287e7ba..994191eb5d4 100644 --- a/tensorflow/security/fuzzing/BUILD +++ b/tensorflow/security/fuzzing/BUILD @@ -27,6 +27,25 @@ tf_fuzz_target( ], ) +tf_fuzz_target( + name = "parseURI_fuzz", + srcs = ["parseURI_fuzz.cc"], + deps = [ + "//tensorflow/core/platform:path", + "//tensorflow/core/platform:stringpiece", + "@com_google_absl//absl/strings", + ], +) + +tf_fuzz_target( + name = "cleanpath_fuzz", + srcs = ["cleanpath_fuzz.cc"], + deps = [ + "//tensorflow/core/platform:path", + "@com_google_absl//absl/strings", + ], +) + tf_fuzz_target( name = "consume_leading_digits_fuzz", srcs = ["consume_leading_digits_fuzz.cc"], diff --git a/tensorflow/security/fuzzing/cleanpath_fuzz.cc b/tensorflow/security/fuzzing/cleanpath_fuzz.cc new file mode 100644 index 00000000000..b535bb31fbf --- /dev/null +++ b/tensorflow/security/fuzzing/cleanpath_fuzz.cc @@ -0,0 +1,43 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <cstdint> +#include <cstdlib> +#include <iostream> +#include <regex> // NOLINT + +#include "absl/strings/match.h" +#include "tensorflow/core/platform/path.h" + +// This is a fuzzer for tensorflow::io::CleanPath. + +namespace { + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { + std::string input_path(reinterpret_cast<const char *>(data), size); + std::string clean_path = tensorflow::io::CleanPath(input_path); + + // Assert there are no '/./' directory changes. + assert(!absl::StrContains(clean_path, "/./")); + // Assert there are no duplicate '/'. + assert(!absl::StrContains(clean_path, "//")); + // Assert there are no higher up directories after entering a directory. + std::regex higher_up_directory("[^.]{1}/[.]{2}"); + assert(!std::regex_match(clean_path, higher_up_directory)); + + return 0; +} + +} // namespace diff --git a/tensorflow/security/fuzzing/parseURI_fuzz.cc b/tensorflow/security/fuzzing/parseURI_fuzz.cc new file mode 100644 index 00000000000..f2230c0c444 --- /dev/null +++ b/tensorflow/security/fuzzing/parseURI_fuzz.cc @@ -0,0 +1,46 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <cstdint> +#include <cstdlib> + +#include "absl/strings/match.h" +#include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/stringpiece.h" + +// This is a fuzzer for tensorflow::io::ParseURI. + +namespace { + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { + std::string uri(reinterpret_cast<const char *>(data), size); + tensorflow::StringPiece scheme, host, path; + tensorflow::io::ParseURI(uri, &scheme, &host, &path); + + // If the input has no scheme, ParseURI returns the whole input as the path.
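+ // Illustrative example (assumed behavior, not part of the original test):
+ // for an input such as "file://host/tmp/x", ParseURI yields scheme == "file",
+ // host == "host", and path == "/tmp/x"; inputs without "://" take the branch
+ // below, where the entire input comes back as the path.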
+ if (path == uri) { + assert(host == ""); + assert(scheme == ""); + } else { + assert(absl::StrContains(uri, host)); + assert(absl::StrContains(uri, scheme)); + assert(absl::StrContains(uri, path)); + assert(absl::StrContains(uri, "://")); + } + + return 0; +} + +} // namespace diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index c9419753437..fd7ce8cd515 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -119,6 +119,12 @@ def tf_portable_full_lite_protos(full, lite): "//conditions:default": full, }) +def if_no_default_logger(a): + return select({ + clean_dep("//tensorflow:no_default_logger"): a, + "//conditions:default": [], + }) + def if_android_x86(a): return select({ clean_dep("//tensorflow:android_x86"): a, @@ -332,6 +338,7 @@ def tf_copts( if_android_arm(["-mfpu=neon"]) + if_linux_x86_64(["-msse3"]) + if_ios_x86_64(["-msse4.1"]) + + if_no_default_logger(["-DNO_DEFAULT_LOGGER"]) + select({ clean_dep("//tensorflow:framework_shared_object"): [], "//conditions:default": ["-DTENSORFLOW_MONOLITHIC_BUILD"], diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adadelta.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adadelta.pbtxt index 88e4ecfbb62..0a297df71a9 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adadelta.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adadelta.pbtxt @@ -30,7 +30,7 @@ tf_class { } member_method { name: "add_slot" - argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], " + argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], " } member_method { name: "add_weight" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adagrad.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adagrad.pbtxt index 89e0718d5b6..1ec4e30e3f2 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adagrad.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adagrad.pbtxt @@ -30,7 +30,7 @@ tf_class { } member_method { name: "add_slot" - argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], " + argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], " } member_method { name: "add_weight" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adam.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adam.pbtxt index 29b1fba5aae..a6deef2da30 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adam.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adam.pbtxt @@ -30,7 +30,7 @@ tf_class { } member_method { name: "add_slot" - argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], " + argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], " } member_method { name: "add_weight" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adamax.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adamax.pbtxt index c481aa07ace..ccb9bebb195 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adamax.pbtxt +++ 
b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adamax.pbtxt @@ -30,7 +30,7 @@ tf_class { } member_method { name: "add_slot" - argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], " + argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], " } member_method { name: "add_weight" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-ftrl.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-ftrl.pbtxt index a2b9d310eb9..eab98ac6d06 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-ftrl.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-ftrl.pbtxt @@ -30,7 +30,7 @@ tf_class { } member_method { name: "add_slot" - argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], " + argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], " } member_method { name: "add_weight" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-nadam.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-nadam.pbtxt index 650ac77d6df..81cce2fbbec 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-nadam.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-nadam.pbtxt @@ -30,7 +30,7 @@ tf_class { } member_method { name: "add_slot" - argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], " + argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], " } member_method { name: "add_weight" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-optimizer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-optimizer.pbtxt index 50e3da3eda5..badb430ca1f 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-optimizer.pbtxt @@ -29,7 +29,7 @@ tf_class { } member_method { name: "add_slot" - argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], " + argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], " } member_method { name: "add_weight" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-r-m-sprop.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-r-m-sprop.pbtxt index ab8391e0465..bcc8c6019db 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-r-m-sprop.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-r-m-sprop.pbtxt @@ -30,7 +30,7 @@ tf_class { } member_method { name: "add_slot" - argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], " + argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], " } member_method { name: "add_weight" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-s-g-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-s-g-d.pbtxt index 2bad07d9998..db673eedc11 100644 
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-s-g-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-s-g-d.pbtxt @@ -30,7 +30,7 @@ tf_class { } member_method { name: "add_slot" - argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], " + argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], " } member_method { name: "add_weight" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.experimental.tensorrt.-conversion-params.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.experimental.tensorrt.-conversion-params.pbtxt index 5eed1aa7d0a..e3f7e3639da 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.experimental.tensorrt.-conversion-params.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.experimental.tensorrt.-conversion-params.pbtxt @@ -7,14 +7,6 @@ tf_class { name: "allow_build_at_runtime" mtype: "<type \'property\'>" } - member { - name: "is_dynamic_op" - mtype: "<type \'property\'>" - } - member { - name: "max_batch_size" - mtype: "<type \'property\'>" - } member { name: "max_workspace_size_bytes" mtype: "<type \'property\'>" @@ -31,10 +23,6 @@ tf_class { name: "precision_mode" mtype: "<type \'property\'>" } - member { - name: "rewriter_config_template" - mtype: "<type \'property\'>" - } member { name: "use_calibration" mtype: "<type \'property\'>" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adadelta.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adadelta.pbtxt index 88e4ecfbb62..0a297df71a9 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adadelta.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adadelta.pbtxt @@ -30,7 +30,7 @@ tf_class { } member_method { name: "add_slot" - argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], " + argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], " } member_method { name: "add_weight" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adagrad.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adagrad.pbtxt index 89e0718d5b6..1ec4e30e3f2 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adagrad.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adagrad.pbtxt @@ -30,7 +30,7 @@ tf_class { } member_method { name: "add_slot" - argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], " + argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], " } member_method { name: "add_weight" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adam.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adam.pbtxt index 29b1fba5aae..a6deef2da30 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adam.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adam.pbtxt @@ -30,7 +30,7 @@ tf_class { } member_method { name: "add_slot" - argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], " + argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], 
varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], " } member_method { name: "add_weight" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adamax.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adamax.pbtxt index c481aa07ace..ccb9bebb195 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adamax.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adamax.pbtxt @@ -30,7 +30,7 @@ tf_class { } member_method { name: "add_slot" - argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], " + argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], " } member_method { name: "add_weight" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-ftrl.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-ftrl.pbtxt index a2b9d310eb9..eab98ac6d06 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-ftrl.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-ftrl.pbtxt @@ -30,7 +30,7 @@ tf_class { } member_method { name: "add_slot" - argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], " + argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], " } member_method { name: "add_weight" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-nadam.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-nadam.pbtxt index 650ac77d6df..81cce2fbbec 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-nadam.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-nadam.pbtxt @@ -30,7 +30,7 @@ tf_class { } member_method { name: "add_slot" - argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], " + argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], " } member_method { name: "add_weight" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-optimizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-optimizer.pbtxt index 50e3da3eda5..badb430ca1f 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-optimizer.pbtxt @@ -29,7 +29,7 @@ tf_class { } member_method { name: "add_slot" - argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], " + argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], " } member_method { name: "add_weight" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-r-m-sprop.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-r-m-sprop.pbtxt index ab8391e0465..bcc8c6019db 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-r-m-sprop.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-r-m-sprop.pbtxt @@ -30,7 +30,7 @@ tf_class { } member_method { name: "add_slot" - argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], 
" + argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], " } member_method { name: "add_weight" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-s-g-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-s-g-d.pbtxt index 2bad07d9998..db673eedc11 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-s-g-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-s-g-d.pbtxt @@ -30,7 +30,7 @@ tf_class { } member_method { name: "add_slot" - argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], " + argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], " } member_method { name: "add_weight" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adadelta.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adadelta.pbtxt index 605cb27c36b..e4552623c02 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adadelta.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adadelta.pbtxt @@ -30,7 +30,7 @@ tf_class { } member_method { name: "add_slot" - argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], " + argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], " } member_method { name: "add_weight" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adagrad.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adagrad.pbtxt index a436583fdd6..8b97923bcbf 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adagrad.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adagrad.pbtxt @@ -30,7 +30,7 @@ tf_class { } member_method { name: "add_slot" - argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], " + argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], " } member_method { name: "add_weight" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adam.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adam.pbtxt index f874658cc25..d92c792df1b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adam.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adam.pbtxt @@ -30,7 +30,7 @@ tf_class { } member_method { name: "add_slot" - argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], " + argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], " } member_method { name: "add_weight" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adamax.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adamax.pbtxt index 6798187be77..e015f67cdcb 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adamax.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adamax.pbtxt @@ -30,7 +30,7 @@ tf_class { } member_method { name: "add_slot" - argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], " + argspec: 
"args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], " } member_method { name: "add_weight" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-ftrl.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-ftrl.pbtxt index 15efc6ada39..43252585d22 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-ftrl.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-ftrl.pbtxt @@ -30,7 +30,7 @@ tf_class { } member_method { name: "add_slot" - argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], " + argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], " } member_method { name: "add_weight" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-nadam.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-nadam.pbtxt index 00cf3e0e24e..a1866674ec5 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-nadam.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-nadam.pbtxt @@ -30,7 +30,7 @@ tf_class { } member_method { name: "add_slot" - argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], " + argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], " } member_method { name: "add_weight" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-optimizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-optimizer.pbtxt index 881d15c5306..77ae5d5ccff 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-optimizer.pbtxt @@ -29,7 +29,7 @@ tf_class { } member_method { name: "add_slot" - argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], " + argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], " } member_method { name: "add_weight" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-r-m-sprop.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-r-m-sprop.pbtxt index 661e9cb5a58..068c4b7cf40 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-r-m-sprop.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-r-m-sprop.pbtxt @@ -30,7 +30,7 @@ tf_class { } member_method { name: "add_slot" - argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], " + argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], " } member_method { name: "add_weight" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-s-g-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-s-g-d.pbtxt index a14c9a4ce57..e2defb8a290 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-s-g-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-s-g-d.pbtxt @@ -30,7 +30,7 @@ tf_class { } member_method { name: "add_slot" - argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], " + argspec: "args=[\'self\', \'var\', \'slot_name\', 
\'initializer\', \'shape\'], varargs=None, keywords=None, defaults=[\'zeros\', \'None\'], " } member_method { name: "add_weight" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.test.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.test.pbtxt index b23d3b9f01b..2d4729f1867 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.test.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.test.pbtxt @@ -18,7 +18,7 @@ tf_module { } member_method { name: "compute_gradient" - argspec: "args=[\'f\', \'x\', \'delta\'], varargs=None, keywords=None, defaults=[\'0.001\'], " + argspec: "args=[\'f\', \'x\', \'delta\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "create_local_cluster" diff --git a/tensorflow/tools/ci_build/a100/nightly.sh b/tensorflow/tools/ci_build/a100/nightly.sh index e7756da8089..f96429cbbba 100644 --- a/tensorflow/tools/ci_build/a100/nightly.sh +++ b/tensorflow/tools/ci_build/a100/nightly.sh @@ -20,8 +20,10 @@ cd tensorflow/tools/ci_build docker build -t gpu_test_container:latest -f \ Dockerfile.rbe.cuda11.0-cudnn8-ubuntu18.04-manylinux2010-multipython . +DEFAULT_BAZEL_TARGETS="//tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/mlir/tosa/... -//tensorflow/compiler/xrt/... //tensorflow/compiler/mlir/lite/... -//tensorflow/lite/micro/examples/... -//tensorflow/core/tpu/..." + docker run --rm \ --gpus all \ --shm-size=2g --ulimit memlock=-1 --ulimit stack=67108864 \ - gpu_test_container:latest bash -c "python3.7 -m pytest" + gpu_test_container:latest bash -c "cd tensorflow_src; bazel test --config=rbe_linux_cuda11.0_nvcc_py3.6 --config=tensorflow_testing_rbe_linux --test_tag_filters=gpu,-no_gpu,-nogpu,-benchmark-test,-no_oss,-oss_serial,-v1only,-no_gpu_presubmit,-no_cuda11 -- ${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/..." diff --git a/tensorflow/tools/def_file_filter/def_file_filter.py.tpl b/tensorflow/tools/def_file_filter/def_file_filter.py.tpl index 1049939c94b..2d1c5138866 100644 --- a/tensorflow/tools/def_file_filter/def_file_filter.py.tpl +++ b/tensorflow/tools/def_file_filter/def_file_filter.py.tpl @@ -168,9 +168,17 @@ def get_pybind_export_symbols(symbols_file, lib_paths_file): lib_paths_file: String that is the path to txt file that lists cc_library target execpaths for exporting symbols. """ - # cc_library target name is always in [target_name] format in - # `symbols_pybind.txt`. - section_header_filter = r"\[(\S+)\]" # e.g. `[cpp_python_util]` + # A cc_library target name must begin its own line, and it must begin with + # `//tensorflow`. It can then optionally have some number of directories, and + # it must end with a target name directly preceded by either a slash or a + # colon. A directory or target name is any combination of letters, numbers, + # underscores, and dashes. + # Examples of possible headers: + # `[//tensorflow/core/util/tensor_bundle]` + # `[//tensorflow/python:safe_ptr]` + # `[//tensorflow:target_name_v2_25]` + # `[//tensorflow/-/24/util_:port-5]` + section_header_filter = r"^\[\/\/(tensorflow(\/[\w-]+)*(:|\/)[\w-]+)\]" # Create a dict of target libs and their symbols to be exported and populate # it. 
(key = cc_library target, value = list of symbols) that we need to @@ -199,20 +207,31 @@ def get_pybind_export_symbols(symbols_file, lib_paths_file): symbols_all = [] for lib in lib_paths: if lib: - for cc_lib in symbols: # keys in symbols = cc_library target name - path_to_lib = cc_lib.split("/") - cc_target = path_to_lib[-1] - # if `len(path_to_lib)` is larger than 1, that means, we are given one - # or more parent directory of the target. e.g. `[foo/bar]` instead of - # just the target name `[bar]`. - if len(path_to_lib) > 1: - parent_dir = path_to_lib[0] + for cc_lib in symbols: # keys in symbols = cc_library target name + if cc_lib.count(":") == 1: + formatted_cc_lib = cc_lib.replace(":", "/") + elif cc_lib.count(":") == 0: + formatted_cc_lib = cc_lib else: - parent_dir = "" - if cc_target in lib and parent_dir in lib: - symbols_all.extend( - get_symbols(lib, "|".join(symbols[cc_lib]))) - + raise ValueError("Detected wrong format for symbols header in " "`symbols_pybind.txt`. Header must have 0 or 1 " "colon (e.g. `[//tensorflow/python:safe_ptr]` " "or `[//tensorflow/core/util/tensor_bundle]`) but " f"detected: {cc_lib}") + path_to_lib = formatted_cc_lib.split("/") + # `lib` is a bazel-out path, which means the actual path string + # we get here differs from the package path listed in + # `win_lib_files_for_exported_symbols` and `symbols_pybind.txt`. + # For example, the target `tensorflow/core:op_gen_lib` in + # `win_lib_files_for_exported_symbols` generates the bazel library path + # `bazel-out/x64_windows-opt/bin/tensorflow/core/framework/op_gen_lib.lib` + lib_and_cc_lib_match = True + for p in path_to_lib: + if p not in lib: + lib_and_cc_lib_match = False + break + if lib_and_cc_lib_match: + symbols_all.extend(get_symbols(lib, "|".join(symbols[cc_lib]))) return symbols_all def main(): diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt index a6082788413..361283c5c90 100644 --- a/tensorflow/tools/def_file_filter/symbols_pybind.txt +++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt @@ -1,4 +1,4 @@ -[cpp_python_util] # util tfe +[//tensorflow/python:cpp_python_util] # util tfe tensorflow::swig::IsSequence tensorflow::swig::IsSequenceOrComposite tensorflow::swig::IsCompositeTensor @@ -22,18 +22,18 @@ tensorflow::swig::RegisterType tensorflow::swig::IsEagerTensorSlow tensorflow::swig::GetRegisteredPyObject -[util_port] # util_port +[//tensorflow/core/util:port] # util_port tensorflow::IsGoogleCudaEnabled tensorflow::IsBuiltWithROCm tensorflow::IsBuiltWithNvcc tensorflow::GpuSupportsHalfMatMulAndConv tensorflow::IsMklEnabled -[stream_executor_pimpl] # stat_summarizer +[//tensorflow/stream_executor:stream_executor_pimpl] # stat_summarizer stream_executor::StreamExecutor::EnablePeerAccessTo stream_executor::StreamExecutor::CanEnablePeerAccessTo -[print_model_analysis] # tfprof +[//tensorflow/core/profiler/internal:print_model_analysis] # tfprof tensorflow::tfprof::NewProfiler tensorflow::tfprof::DeleteProfiler tensorflow::tfprof::AddStep @@ -43,24 +43,17 @@ tensorflow::tfprof::Profile tensorflow::tfprof::PrintModelAnalysis tensorflow::tfprof::SerializeToString -[graph_analyzer_tool] # graph_analyzer +[//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool] # graph_analyzer tensorflow::grappler::graph_analyzer::GraphAnalyzerTool -[bfloat16_lib] # bfloat16 +[//tensorflow/python:bfloat16_lib] # bfloat16 tensorflow::RegisterNumpyBfloat16 tensorflow::Bfloat16PyType -[events_writer] #
events_writer -tensorflow::EventsWriter::Init -tensorflow::EventsWriter::InitWithSuffix -tensorflow::EventsWriter::WriteSerializedEvent -tensorflow::EventsWriter::Flush -tensorflow::EventsWriter::Close - -[py_func_lib] # py_func +[//tensorflow/python:py_func_lib] # py_func tensorflow::InitializePyTrampoline -[framework_internal_impl] # op_def_registry, dtypes +[//tensorflow/core:framework_internal_impl] # op_def_registry, dtypes tensorflow::BaseType tensorflow::DataTypeString tensorflow::DataTypeIsComplex @@ -73,29 +66,26 @@ tensorflow::OpRegistry::Global tensorflow::OpRegistry::LookUpOpDef tensorflow::RemoveNonDeprecationDescriptionsFromOpDef -[lib_internal_impl] # device_lib +[//tensorflow/core:lib_internal_impl] # device_lib tensorflow::Status::code tensorflow::Status::error_message tensorflow::Status::ok() -[device] # device_lib, tfe, tf_session -tensorflow::Device::attributes - -[device_factory] # device_lib, tfe, tf_session +[//tensorflow/core/common_runtime:device_factory] # device_lib, tfe, tf_session tensorflow::DeviceFactory::AddDevices tensorflow::DeviceFactory::ListAllPhysicalDevices tensorflow::DeviceFactory::GetAnyDeviceDetails -[session_options] # device_lib, tfe, tf_session +[//tensorflow/core/common_runtime:session_options] # device_lib, tfe, tf_session tensorflow::SessionOptions::SessionOptions -[quantize_training] # quantize_training +[//tensorflow/core/common_runtime:quantize_training] # quantize_training tensorflow::DoQuantizeTrainingOnSerializedGraphDef -[session_state] # tf_session +[//tensorflow/core/common_runtime:session_state] # tf_session tensorflow::SessionState::kTensorHandleResourceTypeName -[server_lib] # server_lib +[//tensorflow/core/data/service:server_lib] # server_lib tensorflow::data::GrpcDataServerBase::Join tensorflow::data::GrpcDataServerBase::Start tensorflow::data::GrpcDataServerBase::Stop @@ -105,31 +95,25 @@ tensorflow::data::WorkerGrpcDataServer::NumTasks tensorflow::data::NewDispatchServer tensorflow::data::NewWorkerServer -[protos_all] # device_lib, dtypes -tensorflow::DataType_IsValid -tensorflow::ConfigProto::ConfigProto -tensorflow::ConfigProto::ParseFromString -tensorflow::DeviceAttributes::SerializeToString - -[py_exception_registry] # py_exception_registry +[//tensorflow/python:py_exception_registry] # py_exception_registry tensorflow::PyExceptionRegistry::Init tensorflow::PyExceptionRegistry::Lookup -[kernel_registry] # kernel_registry +[//tensorflow/python:kernel_registry] # kernel_registry tensorflow::swig::TryFindKernelClass -[toco_python_api] # toco_python_api +[//tensorflow/lite/toco/python:toco_python_api] # toco_python_api toco::TocoConvert toco::TocoGetPotentiallySupportedOps toco::MlirQuantizeModel toco::MlirSparsifyModel toco::RegisterCustomOpdefs -[transform_graph_lib] # transform_graph +[//tensorflow/tools/graph_transforms:transform_graph_lib] # transform_graph tensorflow::graph_transforms::TransformGraph tensorflow::graph_transforms::ParseTransformParameters -[checkpoint_reader] # py_checkpoint_reader +[//tensorflow/c:checkpoint_reader] # py_checkpoint_reader tensorflow::checkpoint::CheckpointReader tensorflow::checkpoint::CheckpointReader::Init tensorflow::checkpoint::CheckpointReader::DebugString @@ -138,21 +122,21 @@ tensorflow::checkpoint::CheckpointReader::GetVariableToDataTypeMap tensorflow::checkpoint::CheckpointReader::GetTensor tensorflow::checkpoint::CheckpointReader::HasTensor -[tensor_bundle] # py_checkpoint_reader +[//tensorflow/core/util/tensor_bundle] # py_checkpoint_reader 
tensorflow::BundleReader::BundleReader tensorflow::BundleReader::~BundleReader -[ndarray_tensor] # py_checkpoint_reader +[//tensorflow/python:ndarray_tensor] # py_checkpoint_reader tensorflow::TensorToNdarray -[safe_ptr] # py_checkpoint_reader +[//tensorflow/python:safe_ptr] # py_checkpoint_reader tensorflow::detail::PyDecrefDeleter tensorflow::make_safe -[python_op_gen] # python_op_gen +[//tensorflow/python:python_op_gen] # python_op_gen tensorflow::GetPythonWrappers -[pywrap_tfe_lib] # tfe +[//tensorflow/python/eager:pywrap_tfe_lib] # tfe tensorflow::TFE_TensorHandleCache tensorflow::TFE_TensorHandleCache::Clear EagerTensor_CheckExact @@ -210,17 +194,17 @@ tensorflow::MakeEagerContextThreadLocalData tensorflow::GetEagerContextThreadLocalData tensorflow::DestroyEagerContextThreadLocalData -[eager_executor] # tfe +[//tensorflow/core/common_runtime/eager:eager_executor] # tfe tensorflow::EagerExecutor::~EagerExecutor tensorflow::EagerContext::WaitForAndCloseRemoteContexts -[tf_status_helper] # tfe +[//tensorflow/c:tf_status_helper] # tfe tensorflow::Set_TF_Status_from_Status -[context] # tfe +[//tensorflow/core/common_runtime/eager:context] # tfe tensorflow::EagerContext::WaitForAndCloseRemoteContexts -[mlir] # mlir +[//tensorflow/compiler/mlir/python:mlir] # mlir tensorflow::ExperimentalRunPassPipeline tensorflow::ExperimentalConvertSavedModelV1ToMlirLite tensorflow::ExperimentalConvertSavedModelV1ToMlir @@ -228,13 +212,13 @@ tensorflow::ExperimentalConvertSavedModelToMlir tensorflow::ImportGraphDef tensorflow::ImportFunction -[op_gen_lib] # tf_session +[//tensorflow/core:op_gen_lib] # tf_session tensorflow::ApiDefMap::~ApiDefMap -[graph_constructor] # tf_session +[//tensorflow/core/common_runtime:graph_constructor] # tf_session tensorflow::ShapeRefiner::~ShapeRefiner -[python_api] # tf_session +[//tensorflow/c:python_api] # tf_session tensorflow::AddControlInput tensorflow::SetAttr tensorflow::ClearAttr @@ -247,11 +231,11 @@ tensorflow::GetHandleShapeAndType tensorflow::SetHandleShapeAndType tensorflow::AddWhileInputHack -[numpy_lib] # tf_session +[//tensorflow/python:numpy_lib] # tf_session tensorflow::ImportNumpy _tensorflow_numpy_api -[tf_session_helper] # tf_session +[//tensorflow/python:tf_session_helper] # tf_session tensorflow::TF_NewSessionRef tensorflow::TF_SessionMakeCallable tensorflow::TF_SessionRunCallable @@ -274,76 +258,76 @@ tensorflow::TF_GraphSetTensorShape_wrapper tensorflow::TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper tensorflow::TF_TryEvaluateConstant_wrapper -[grappler_item] # tf_item +[//tensorflow/core/grappler:grappler_item] # tf_item tensorflow::grappler::GrapplerItem::MainOpsFanin tensorflow::grappler::GrapplerItem::EnqueueOpsFanin -[graph_properties] # tf_item +[//tensorflow/core/grappler/costs:graph_properties] # tf_item tensorflow::grappler::GraphProperties::InferStatically tensorflow::grappler::GraphProperties::GetOutputProperties -[grappler_item_builder] # tf_item +[//tensorflow/core/grappler:grappler_item_builder] # tf_item tensorflow::grappler::GrapplerItemFromMetaGraphDef -[topological_sort] # tf_item +[//tensorflow/core/grappler/utils:topological_sort] # tf_item tensorflow::grappler::TopologicalSort -[clusters/utils] # tf_cluster tf_optimizer +[//tensorflow/core/grappler/clusters:utils] # tf_cluster tf_optimizer tensorflow::grappler::GetDeviceInfo -[costs/utils] # tf_optimizer tf_cluster +[//tensorflow/core/grappler/costs:utils] # tf_optimizer tf_cluster tensorflow::grappler::CostGraphToOpPerformanceData 
tensorflow::grappler::GetDeviceInfo -[meta_optimizer] # tf_optimizer +[//tensorflow/core/grappler/optimizers:meta_optimizer] # tf_optimizer tensorflow::grappler::MetaOptimizer::MetaOptimizer tensorflow::grappler::MetaOptimizer::Optimize tensorflow::grappler::MetaOptimizer::PrintResult -[clusters/cluster] # tf_cluster +[//tensorflow/core/grappler/clusters:cluster] # tf_cluster tensorflow::grappler::Cluster::AllowSoftPlacement tensorflow::grappler::Cluster::SetNumWarmupSteps tensorflow::grappler::Cluster::DisableDetailedStats tensorflow::grappler::Cluster::DetailedStatsEnabled -[single_machine] # tf_cluster +[//tensorflow/core/grappler/clusters:single_machine] # tf_cluster tensorflow::grappler::SingleMachine::SingleMachine -[op_level_cost_estimator] # tf_cluster +[//tensorflow/core/grappler/costs:op_level_cost_estimator] # tf_cluster tensorflow::grappler::OpLevelCostEstimator::OpLevelCostEstimator tensorflow::grappler::OpLevelCostEstimator::PredictCosts tensorflow::grappler::OpLevelCostEstimator::GetDeviceInfo -[virtual_cluster] # tf_cluster +[//tensorflow/core/grappler/clusters:virtual_cluster] # tf_cluster tensorflow::grappler::VirtualCluster::VirtualCluster -[graph_memory] # tf_cluster +[//tensorflow/core/grappler/costs:graph_memory] # tf_cluster tensorflow::grappler::GraphMemory::InferStatically tensorflow::grappler::GraphMemory::InferDynamically -[measuring_cost_estimator] # tf_cluster +[//tensorflow/core/grappler/costs:measuring_cost_estimator] # tf_cluster tensorflow::grappler::MeasuringCostEstimator::MeasuringCostEstimator tensorflow::grappler::MeasuringCostEstimator::Initialize tensorflow::grappler::MeasuringCostEstimator::PredictCosts -[devices] # tf_cluster +[//tensorflow/core/grappler:devices] # tf_cluster tensorflow::grappler::GetNumAvailableGPUs tensorflow::grappler::GetNumAvailableLogicalCPUCores -[traceme_recorder_impl] # profiler +[//tensorflow/core/profiler/internal:traceme_recorder_impl] # profiler tensorflow::profiler::TraceMeRecorder::Record -[profiler_session_impl] # profiler +[//tensorflow/core/profiler/lib:profiler_session_impl] # profiler tensorflow::ProfilerSession::Create tensorflow::ProfilerSession::CollectData tensorflow::ProfilerSession::Status tensorflow::ProfilerSession::~ProfilerSession -[profiler_server_impl] # profiler +[//tensorflow/core/profiler/rpc:profiler_server_impl] # profiler tensorflow::profiler::ProfilerServer::StartProfilerServer tensorflow::profiler::ProfilerServer::~ProfilerServer -[profiler_client_impl] # profiler +[//tensorflow/core/profiler/rpc/client:profiler_client_impl] # profiler tensorflow::profiler::ProfileGrpc tensorflow::profiler::NewSessionGrpc tensorflow::profiler::MonitorGrpc @@ -352,13 +336,13 @@ tensorflow::profiler::RemoteProfilerSession::GetServiceAddress tensorflow::profiler::RemoteProfilerSession::WaitForCompletion tensorflow::profiler::RemoteProfilerSession::~RemoteProfilerSession -[status_macros] # tfcompile +[//tensorflow/compiler/xla:status_macros] # tfcompile xla::status_macros::MakeErrorStream::Impl::Impl xla::status_macros::MakeErrorStream::Impl::~Impl xla::status_macros::MakeErrorStream::Impl::GetStatus xla::status_macros::MakeErrorStream::CheckNotDone -[hlo] # tfcompile +[//tensorflow/compiler/xla/service:hlo] # tfcompile xla::DfsHloVisitorBase::SetVisited xla::DfsHloVisitorBase<class xla::HloInstruction.*>::SetVisited xla::HloComputation::Accept @@ -368,42 +352,42 @@ xla::HloInstruction::ToString xla::HloInstruction::Accept xla::HloInstruction::Visit -[tfcompile_lib] # tfcompile 
+[//tensorflow/compiler/aot:tfcompile_lib] # tfcompile tensorflow::tfcompile::Main -[model_analyzer_lib] # model_analyzer +[//tensorflow/python:model_analyzer_lib] # model_analyzer tensorflow::grappler::ModelAnalyzer::GenerateReport tensorflow::grappler::ModelAnalyzer::ModelAnalyzer -[analytical_cost_estimator] # cost_analyzer +[//tensorflow/core/grappler/costs:analytical_cost_estimator] # cost_analyzer tensorflow::grappler::AnalyticalCostEstimator::Initialize tensorflow::grappler::AnalyticalCostEstimator::PredictCosts -[cost_analyzer_lib] # cost_analyzer +[//tensorflow/python:cost_analyzer_lib] # cost_analyzer tensorflow::grappler::CostAnalyzer::CostAnalyzer tensorflow::grappler::CostAnalyzer::GenerateReport -[flags] # tfe +[//tensorflow/compiler/jit:flags] # tfe tensorflow::IsXlaEnabled tensorflow::GetMlirCommonFlags tensorflow::GetXlaDeviceFlags -[tensor_float_32_utils] # tensor_float_32 +[//tensorflow/core/platform:tensor_float_32_utils] # tensor_float_32 tensorflow::enable_tensor_float_32_execution tensorflow::tensor_float_32_execution_enabled -[get_compiler_ir] # tfe +[//tensorflow/compiler/jit:get_compiler_ir] # tfe tensorflow::GetCompilerIr stream_executor::port::internal_statusor::Helper::Crash -[tensor_handle] # tfe +[//tensorflow/core/common_runtime/eager:tensor_handle] # tfe tensorflow::TensorHandle::Tensor -[python_api_dispatcher] # python_api_dispatcher +[//tensorflow/python:python_api_dispatcher] # python_api_dispatcher tensorflow::PythonAPIDispatcher -[python_tensor_converter] # python_tensor_converter +[//tensorflow/python:python_tensor_converter] # python_tensor_converter tensorflow::PythonTensorConverter -[python_api_info] # python_api_info +[//tensorflow/python:python_api_info] # python_api_info tensorflow::PythonAPIInfo diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 0d8ae1cc67b..f2029dc1c5c 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -203,11 +203,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): name = "eigen_archive", build_file = clean_dep("//third_party:eigen.BUILD"), patch_file = clean_dep("//third_party/eigen3:gpu_packet_math.patch"), - sha256 = "cce3143f6ed22dadff4ef0b43ce31c9632ace8e75bf41e401b6ec4668968d4c9", # SHARED_EIGEN_SHA - strip_prefix = "eigen-41d5d5334b8a4e364dfd88dcd91f6cd38834b8ed", + sha256 = "306f15c04fbd514b4adc3a327a2c6f63521ea6805cab75691fa30c30fea55193", # SHARED_EIGEN_SHA + strip_prefix = "eigen-fd1dcb6b45a2c797ad4c4d6cc7678ee70763b4ed", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/gitlab.com/libeigen/eigen/-/archive/41d5d5334b8a4e364dfd88dcd91f6cd38834b8ed/eigen-41d5d5334b8a4e364dfd88dcd91f6cd38834b8ed.tar.gz", - "https://gitlab.com/libeigen/eigen/-/archive/41d5d5334b8a4e364dfd88dcd91f6cd38834b8ed/eigen-41d5d5334b8a4e364dfd88dcd91f6cd38834b8ed.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/gitlab.com/libeigen/eigen/-/archive/fd1dcb6b45a2c797ad4c4d6cc7678ee70763b4ed/eigen-fd1dcb6b45a2c797ad4c4d6cc7678ee70763b4ed.tar.gz", + "https://gitlab.com/libeigen/eigen/-/archive/fd1dcb6b45a2c797ad4c4d6cc7678ee70763b4ed/eigen-fd1dcb6b45a2c797ad4c4d6cc7678ee70763b4ed.tar.gz", ], ) @@ -686,8 +686,8 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): ) # Check out LLVM and MLIR from llvm-project. 
- LLVM_COMMIT = "8c1e3cbebfe9ff14829a279b9d229d4fc3f190b1" - LLVM_SHA256 = "ac1e5de135b12b8008c7998e9e07a657524c2bfee79406273da5f18c356e3892" + LLVM_COMMIT = "76bd4444e36197465f1c72f4b6f1d59721012a59" + LLVM_SHA256 = "805b737a8ff996ea7216818b0969e7f60c3c9d7a9be82198902f53e2a36c44eb" LLVM_URLS = [ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD index a7eff42d00f..fa4ccb71ece 100644 --- a/third_party/mlir/BUILD +++ b/third_party/mlir/BUILD @@ -675,6 +675,7 @@ cc_library( hdrs = ["include/mlir/Dialect/Async/Passes.h"], includes = ["include"], deps = [ + ":Analysis", ":Async", ":AsyncPassIncGen", ":IR", @@ -3451,6 +3452,7 @@ cc_library( "include/mlir/ExecutionEngine/AsyncRuntime.h", ], includes = ["include"], + deps = ["@llvm-project//llvm:Support"], ) cc_library( diff --git a/third_party/nccl/archive.patch b/third_party/nccl/archive.patch index 9dfe432d60b..695fd718cc5 100644 --- a/third_party/nccl/archive.patch +++ b/third_party/nccl/archive.patch @@ -36,20 +36,20 @@ index 985274e..7ebb1e1 100644 @@ -10,12 +10,12 @@ #include <cuda_runtime.h> #include <cuda_fp16.h> - + -#define NCCL_MAJOR ${nccl:Major} -#define NCCL_MINOR ${nccl:Minor} -#define NCCL_PATCH ${nccl:Patch} -#define NCCL_SUFFIX "${nccl:Suffix}" +#define NCCL_MAJOR 2 +#define NCCL_MINOR 7 -+#define NCCL_PATCH 3 ++#define NCCL_PATCH 6 +#define NCCL_SUFFIX "" - + -#define NCCL_VERSION_CODE ${nccl:Version} +#define NCCL_VERSION_CODE 2703 #define NCCL_VERSION(X,Y,Z) ((X) * 1000 + (Y) * 100 + (Z)) - + #ifdef __cplusplus See https://github.com/NVIDIA/nccl/pull/322.patch From 410d341bd4569f60282576daa5c991717dbd560e Mon Sep 17 00:00:00 2001 @@ -127,7 +127,7 @@ index 550cfcd0c..8fea91950 100644 if (parent == NULL) { - if (path == NULL) NCCLCHECK(getPciPath(busId, &path)); + NCCLCHECK(getPciPath(busId, path)); - + // Save that for later in case next step is a CPU char numaIdStr[MAX_STR_LEN]; @@ -544,7 +546,6 @@ ncclResult_t ncclTopoGetXmlFromSys(struct ncclXmlNode* pciNode, struct ncclXml* @@ -137,7 +137,7 @@ index 550cfcd0c..8fea91950 100644 - free(path); return ncclSuccess; } - + @@ -644,8 +644,8 @@ ncclResult_t ncclTopoGetXmlFromGpu(struct ncclXmlNode* pciNode, nvmlDevice_t nvm // Remote NVLink device is not visible inside this VM. Assume NVSwitch. NCCLCHECK(xmlSetAttr(sub, "tclass", "0x068000")); @@ -169,7 +169,7 @@ index 8fea91950..42eb68a4b 100644 @@ -460,20 +460,21 @@ int checkBDFFormat(char* bdf) { return 1; } - + -ncclResult_t ncclTopoGetXmlFromSys(struct ncclXmlNode* pciNode, struct ncclXml* xml) { +ncclResult_t ncclTopoGetXmlNodeFromSys(struct ncclXmlNode* pciNode, + struct ncclXml* xml,