diff --git a/.gitignore b/.gitignore index 72cb418fe11..0cfe6fca30e 100644 --- a/.gitignore +++ b/.gitignore @@ -38,6 +38,7 @@ gradleBuild *.pbxproj *.xcworkspace /*.podspec +/tensorflow/lite/**/coreml/**/BUILD /tensorflow/lite/**/ios/BUILD /tensorflow/lite/**/objc/BUILD /tensorflow/lite/**/swift/BUILD diff --git a/configure.py b/configure.py index fcce0ccd061..3cc05041e18 100644 --- a/configure.py +++ b/configure.py @@ -58,6 +58,8 @@ NCCL_LIB_PATHS = [ # List of files to configure when building Bazel on Apple platforms. APPLE_BAZEL_FILES = [ + 'tensorflow/lite/experimental/delegates/coreml/BUILD', + 'tensorflow/lite/experimental/delegates/coreml/builders/BUILD', 'tensorflow/lite/experimental/ios/BUILD', 'tensorflow/lite/experimental/objc/BUILD', 'tensorflow/lite/experimental/swift/BUILD', diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 114787116df..2e18d4cf0b2 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -639,7 +639,7 @@ tf_cc_shared_object( "//tensorflow/cc/saved_model:loader_lite_impl", "//tensorflow/core:core_cpu_impl", "//tensorflow/core:framework_internal_impl", - "//tensorflow/core:gpu_runtime_impl", + "//tensorflow/core/common_runtime/gpu:gpu_runtime_impl", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl", "//tensorflow/core:lib_internal_impl", "//tensorflow/core/profiler:profiler_impl", diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index ab59ba829d7..ff7d56e592d 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -995,9 +995,7 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { return nullptr; } - tensorflow::Tensor tensor = tensorflow::TensorFromInterface(t); - t->Release(); - return tensorflow::TF_TensorFromTensor(tensor, &status->status); + return new TF_Tensor{t}; } void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) { diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index 72ddf166cbd..6e4ac19c3ce 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -580,12 +580,6 @@ void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) { }; } -void TFE_TensorHandleEnableImplicitMirroring(TFE_TensorHandle* h, - TF_Status* status) { - h->handle->EnableImplicitMirroring(); - status->status = tensorflow::Status::OK(); -} - void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name, TF_Buffer* buf, TF_Status* status) { tensorflow::EagerContext* context = diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h index 5f9190af79a..45d15960a9f 100644 --- a/tensorflow/c/eager/c_api_experimental.h +++ b/tensorflow/c/eager/c_api_experimental.h @@ -392,12 +392,6 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx, TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx, TF_Status* status); -// If the TensorHandle is copied to another device as part of an op execution, -// the copy is destroyed after the op has executed. Enabling implicit mirroring -// causes the copy to be held as a mirror for the lifetime of the TensorHandle. -TF_CAPI_EXPORT extern void TFE_TensorHandleEnableImplicitMirroring( - TFE_TensorHandle*, TF_Status*); - // This function will block till the operation that produces `h` has // completed. This is only valid on local TFE_TensorHandles. The pointer // returned will be on the device in which the TFE_TensorHandle resides (so e.g. diff --git a/tensorflow/c/eager/c_api_remote_test.cc b/tensorflow/c/eager/c_api_remote_test.cc index a0c4830e5ef..a084795eef6 100644 --- a/tensorflow/c/eager/c_api_remote_test.cc +++ b/tensorflow/c/eager/c_api_remote_test.cc @@ -168,8 +168,6 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) { auto* h1_task2 = TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_TensorHandleEnableImplicitMirroring(h1_task2, status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); // Handles are on task0 (local), and task2, but op is on task1. TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task2); diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index f939a4a3035..6c4877b2ea2 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -594,7 +594,6 @@ void ExecuteAdd(bool async, bool forward_input) { TFE_TensorHandle* n_gpu = TFE_TensorHandleCopyToDevice(n, ctx, gpu_device_name.c_str(), status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_TensorHandleEnableImplicitMirroring(n_gpu, status); TFE_DeleteTensorHandle(n); n = n_gpu; } diff --git a/tensorflow/c/eager/tensor_handle_interface.h b/tensorflow/c/eager/tensor_handle_interface.h index 2b604f660b1..1ca40daec41 100644 --- a/tensorflow/c/eager/tensor_handle_interface.h +++ b/tensorflow/c/eager/tensor_handle_interface.h @@ -59,14 +59,6 @@ class AbstractTensorHandleInterface { // Return a copy of the handle. virtual AbstractTensorHandleInterface* Copy() = 0; - // Maintain mirror tensors for any implicit copies to local devices. This - // setting is offered on a per tensor handle basis to avoid potential memory - // over utilization due to holding on to mirrors as well as the original - // tensor. Note this setting overrides the context mirroring policy whereby if - // the mirroring policy is MIRRORING_NONE, we will still continue to mirror - // this tensor. - virtual void EnableImplicitMirroring() = 0; - protected: virtual ~AbstractTensorHandleInterface() {} }; diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index d0e8a2b35d2..28d922f9e3c 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -118,8 +118,8 @@ cc_library( "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:gpu_init", "//tensorflow/core:lib", + "//tensorflow/core/common_runtime/gpu:gpu_init", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index 4d07b8d26c1..22f3640023d 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -72,6 +72,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tensorflow_test_passes", "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes", "//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo", + "//tensorflow/compiler/mlir/xla:buffer_assignment", "//tensorflow/compiler/mlir/xla:hlo", "//tensorflow/compiler/mlir/xla:hlo_legalize_to_lhlo", "//tensorflow/compiler/mlir/xla:lhlo", diff --git a/tensorflow/compiler/mlir/g3doc/_index.yaml b/tensorflow/compiler/mlir/g3doc/_index.yaml index affd0926af5..10ea5ac337d 100644 --- a/tensorflow/compiler/mlir/g3doc/_index.yaml +++ b/tensorflow/compiler/mlir/g3doc/_index.yaml @@ -1,6 +1,7 @@ book_path: /mlir/_book.yaml project_path: /mlir/_project.yaml -description: +description: An intermediate representation and compiler framework, MLIR unifies the + infrastructure for high-performance ML models in TensorFlow. landing_page: custom_css_path: /site-assets/css/style.css rows: diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 32a977416ae..789d06b8ac9 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -771,6 +771,7 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/lite/tools/optimize:quantize_weights", "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/types:span", "@llvm-project//llvm:support", "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:IR", diff --git a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h index 322aeadfa37..82d058964cb 100644 --- a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h +++ b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h @@ -36,7 +36,7 @@ struct PassConfig { form_clusters(false), unfold_batch_matmul(true), legalize_tf_while(true), - shape_inference(false) {} + shape_inference(true) {} // If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be // added, which produces TF Lite ops. diff --git a/tensorflow/compiler/mlir/lite/converter_gen.cc b/tensorflow/compiler/mlir/lite/converter_gen.cc index 83c95c03c8b..bc894d36e75 100644 --- a/tensorflow/compiler/mlir/lite/converter_gen.cc +++ b/tensorflow/compiler/mlir/lite/converter_gen.cc @@ -409,10 +409,14 @@ static void GenOperandResultVerifier(raw_ostream &os, os << " (void)v;\n" << " if (!(" << tgfmt(pred.getCondition(), &fctx.withSelf("v.getType()")) << ")) {\n" + << " if (failure_on_operand_type_mismatch) {\n" << formatv( " return op->emitOpError(\"{0} #\") << index " "<< \" must be {1}, but got \" << v.getType();\n", valueKind, desc) + << " } else {\n" + << " return ::mlir::LogicalResult::Failure;\n" + << " }\n" << " }\n" // if << " ++index;\n" << " }\n"; // for @@ -437,7 +441,8 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) { mlir::tblgen::FmtContext verify_ctx; os << "::mlir::LogicalResult " << op.getCppClassName() - << "::VerifyTflRuntimeTypes(::mlir::Operation *op) {\n"; + << "::VerifyTflRuntimeTypes(::mlir::Operation *op, bool " + "failure_on_operand_type_mismatch) {\n"; os << " auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n"; verify_ctx.withOp("top"); diff --git a/tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimators.h b/tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimators.h index 8581187be70..9de762629c2 100644 --- a/tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimators.h +++ b/tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimators.h @@ -70,6 +70,19 @@ class TFLiteCostEstimator { static bool IsSupported(mlir::Operation* op) { return true; } }; +// tfl.cos +template <> +class TFLiteCostEstimator { + public: + static double GetCost(mlir::Operation* op) { + llvm::errs() << "No defined cost function for op: " + << op->getName().getStringRef().str(); + return 0.0; + } + + static bool IsSupported(mlir::Operation* op) { return true; } +}; + // tfl.depthwise_conv_2d template <> class TFLiteCostEstimator { @@ -83,6 +96,32 @@ class TFLiteCostEstimator { static bool IsSupported(mlir::Operation* op) { return true; } }; +// tfl.div +template <> +class TFLiteCostEstimator { + public: + static double GetCost(mlir::Operation* op) { + llvm::errs() << "No defined cost function for op: " + << op->getName().getStringRef().str(); + return 0.0; + } + + static bool IsSupported(mlir::Operation* op) { return true; } +}; + +// tfl.exp +template <> +class TFLiteCostEstimator { + public: + static double GetCost(mlir::Operation* op) { + llvm::errs() << "No defined cost function for op: " + << op->getName().getStringRef().str(); + return 0.0; + } + + static bool IsSupported(mlir::Operation* op) { return true; } +}; + // tfl.fully_connected template <> class TFLiteCostEstimator { @@ -97,6 +136,19 @@ class TFLiteCostEstimator { static bool IsSupported(mlir::Operation* op) { return true; } }; +// tfl.hard_swish +template <> +class TFLiteCostEstimator { + public: + static double GetCost(mlir::Operation* op) { + llvm::errs() << "No defined cost function for op: " + << op->getName().getStringRef().str(); + return 0.0; + } + + static bool IsSupported(mlir::Operation* op) { return true; } +}; + // tfl.logistic template <> class TFLiteCostEstimator { diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index c36f4af9623..f9739bfa626 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -138,6 +138,8 @@ static StatusOr GetTFLiteType(Type type, return tflite::TensorType_FLOAT32; case mlir::StandardTypes::F16: return tflite::TensorType_FLOAT16; + case mlir::StandardTypes::F64: + return tflite::TensorType_FLOAT64; case mlir::TF::TensorFlowTypes::STRING: return tflite::TensorType_STRING; case mlir::TF::TensorFlowTypes::QUINT8: diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index 3ad625f6e08..89cd9f46884 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -353,6 +353,22 @@ StatusOr ConvertFloatBuffer( } return DenseElementsAttr::get(shaped_type, ArrayRef(values)); } + case 64: { + assert(bytes_len % 8 == 0); + size_t elem_count = bytes_len / 8; + std::vector values; + values.reserve(elem_count); + + const char* data = reinterpret_cast(buffer.data()); + + for (int i = 0; i < elem_count; i++) { + uint64_t bit_repr = + llvm::support::endian::readNext(data); + values.push_back(absl::bit_cast(bit_repr)); + } + return DenseElementsAttr::get(shaped_type, ArrayRef(values)); + } } return errors::InvalidArgument("unsupported bit width", elem_type.getWidth()); } diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td b/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td index db0bef39358..b20e81aefa9 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td @@ -86,7 +86,8 @@ def TFL_RuntimeVerification : OpInterface<"TflRuntimeVerifyOpInterface"> { let methods = [ StaticInterfaceMethod< [{Returns whether the op's operands/results are supported by runtime.}], - "LogicalResult", "VerifyTflRuntimeTypes", (ins "Operation*":$op) + "LogicalResult", "VerifyTflRuntimeTypes", + (ins "Operation*":$op, "bool":$failure_on_operand_type_mismatch) >, ]; } diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 7226f68cc90..484a5ea81c4 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -706,7 +706,10 @@ def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0> { } def TFL_CosOp: TFL_Op<"cos", [ - NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> { + NoSideEffect, + SameOperandsAndResultType, + NoQuantizableResult, + TFL_GpuTargetOp]> { let summary = "Cosine operator"; let description = [{ @@ -827,12 +830,12 @@ def TFL_GatherNdOp : TFL_Op<"gather_nd", [NoSideEffect]> { }]; let arguments = (ins - TFL_TensorOf<[F32, I8, I64, I32, TFL_Uint8]>:$params, + TFL_TensorOf<[F32, I8, I64, I32, TFL_Uint8, TFL_Str]>:$params, TFL_I32OrI64Tensor:$indices ); let results = (outs - TFL_TensorOf<[F32, I8, I64, I32, TFL_Uint8]>:$output + TFL_TensorOf<[F32, I8, I64, I32, TFL_Uint8, TFL_Str]>:$output ); } @@ -1108,7 +1111,10 @@ def TFL_NotEqualOp : TFL_Op<"not_equal", [ def TFL_DivOp : TFL_Op<"div", [ // TODO(fengliuai): NoQuantizableResult is only correct for int8 // quantization. update to handle Uint8 quantization. - ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> { + ResultsBroadcastableShape, + NoSideEffect, + NoQuantizableResult, + TFL_GpuTargetOp]> { let summary = "Division operator"; let description = [{ @@ -1187,7 +1193,9 @@ def TFL_EqualOp: TFL_Op<"equal", [Commutative, ResultsBroadcastableShape, let builders = [TFL_ComparisonBinaryBuilder]; } -def TFL_ExpOp: TFL_Op<"exp", [NoSideEffect, SameOperandsAndResultType]> { +def TFL_ExpOp: TFL_Op<"exp", [NoSideEffect, + SameOperandsAndResultType, + TFL_GpuTargetOp]> { let summary = "Natural exponentiation operator"; let description = [{ @@ -1369,7 +1377,8 @@ def TFL_GreaterOp : TFL_Op<"greater", [ } def TFL_HardSwishOp: TFL_Op<"hard_swish", [NoSideEffect, - SameOperandsAndResultShape]> { + SameOperandsAndResultShape, + TFL_GpuTargetOp]> { let summary = "Hardswish activation function."; let description = [{ Computes hard-swish activation function diff --git a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc index 0a3f0eb3518..1165561cb71 100644 --- a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc @@ -84,8 +84,14 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, TF_ASSIGN_OR_RETURN( auto module, ConvertGraphdefToMlir(input, debug_info, specs, &context)); + mlir::TFL::PassConfig pass_config(quant_specs); + bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops(); + pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops; + pass_config.lower_tensor_list_ops = true; + pass_config.shape_inference = false; + return internal::ConvertMLIRToTFLiteFlatBuffer(toco_flags, std::move(module), - quant_specs, result); + pass_config, result); } } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc index f8435d17c8d..681773a7e6b 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc @@ -43,8 +43,6 @@ namespace tensorflow { Status ConvertSavedModelToTFLiteFlatBuffer( const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags, - const string& saved_model_dir, bool saved_model_v1, - const string& saved_model_tags, const string& saved_model_exported_names, string* result) { mlir::MLIRContext context; mlir::TFL::QuantizationSpecs quant_specs; @@ -66,13 +64,28 @@ Status ConvertSavedModelToTFLiteFlatBuffer( // Register all custom ops, including user-specified custom ops. TF_RETURN_IF_ERROR(internal::RegisterAllCustomOps(toco_flags)); - const bool import_saved_model = !saved_model_v1; - TF_ASSIGN_OR_RETURN( - auto module, - ImportSavedModel(import_saved_model, saved_model_v1, saved_model_dir, - saved_model_tags, saved_model_exported_names, &context)); - return internal::ConvertMLIRToTFLiteFlatBuffer(toco_flags, std::move(module), - quant_specs, result); + auto& saved_model_tags = model_flags.saved_model_tags(); + auto& saved_model_exported_names = model_flags.saved_model_exported_names(); + std::unordered_set tags(saved_model_tags.begin(), + saved_model_tags.end()); + auto exported_names_in_vector = std::vector( + saved_model_exported_names.begin(), saved_model_exported_names.end()); + absl::Span exported_names(exported_names_in_vector); + + TF_ASSIGN_OR_RETURN(auto module, + ImportSavedModel(model_flags.saved_model_dir(), + model_flags.saved_model_version(), tags, + exported_names, &context)); + + mlir::TFL::PassConfig pass_config(quant_specs); + bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops(); + pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops; + pass_config.lower_tensor_list_ops = true; + pass_config.shape_inference = true; + + auto status = internal::ConvertMLIRToTFLiteFlatBuffer( + toco_flags, std::move(module), pass_config, result); + return status; } } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h index dea5603dad0..ed339ca64b9 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h @@ -28,8 +28,6 @@ namespace tensorflow { // status if it fails to convert the input. Status ConvertSavedModelToTFLiteFlatBuffer( const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags, - const string& saved_model_dir, bool saved_model_v1, - const string& saved_model_tags, const string& saved_model_exported_names, string* result); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc index ae342dd49ae..a17cdda2a39 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc @@ -105,6 +105,10 @@ DataType ConvertIODataTypeToDataType(toco::IODataType dtype) { switch (dtype) { case toco::IODataType::FLOAT: return DT_FLOAT; + case toco::IODataType::FLOAT16: + return DT_HALF; + case toco::IODataType::FLOAT64: + return DT_DOUBLE; case toco::IODataType::QUANTIZED_UINT8: return DT_QUINT8; case toco::IODataType::INT8: @@ -261,7 +265,7 @@ Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) { Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags, mlir::OwningModuleRef module, - mlir::TFL::QuantizationSpecs quant_specs, + const mlir::TFL::PassConfig& pass_config, string* result) { bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops(); bool emit_select_tf_ops = toco_flags.enable_select_tf_ops(); @@ -275,9 +279,6 @@ Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags, } mlir::PassManager pm(module->getContext()); - mlir::TFL::PassConfig pass_config(quant_specs); - pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops; - pass_config.lower_tensor_list_ops = true; tensorflow::AddTFToTFLConversionPasses(pass_config, &pm); // Convert back to outlined while format for export back to flatbuffer. @@ -288,7 +289,8 @@ Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags, auto status = ConvertTFExecutorToTFLOrFlatbuffer( module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops, - emit_select_tf_ops, emit_custom_ops, quant_specs, result, &pm); + emit_select_tf_ops, emit_custom_ops, pass_config.quant_specs, result, + &pm); if (toco_flags.has_dump_graphviz_dir()) { TF_RETURN_IF_ERROR(DumpOpGraphToFile( // rename once we enable the new converter feature flag. diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h index 96c2096e469..3ea36e5eb1d 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h @@ -47,7 +47,7 @@ Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags, // This will also run relevant passes as well. Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags, mlir::OwningModuleRef module, - mlir::TFL::QuantizationSpecs quant_specs, + const mlir::TFL::PassConfig& pass_config, string* result); // Give a warning for any unused flags that have been specified. diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc index 85a988a9bde..50e3771d467 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc @@ -57,8 +57,9 @@ QuantizeContext::QuantizeContext(FuncOp func, const DeviceTarget &spec) }); } -llvm::ArrayRef QuantizeContext::GetAllOps() { - llvm::SmallVector all_ops; +std::vector QuantizeContext::GetAllOps() { + std::vector all_ops; + all_ops.reserve(128); func_.walk([&](quant::QuantizeRegionOp op) { all_ops.push_back(op); }); return all_ops; } @@ -75,7 +76,7 @@ LogicalResult QuantizeContext::Handle( switch (spec->type) { case ScaleConstraintType::OutputInputFreeScale: { // no propagation. - *changed = false; + *changed |= false; break; } case ScaleConstraintType::CustomScale: { @@ -84,7 +85,20 @@ LogicalResult QuantizeContext::Handle( } break; } + case ScaleConstraintType::OutputInputSameScale: { + auto params = GetQuantParamsForSameScaleConstraint(op); + if (EmptyParams(params)) { + *changed |= false; + break; + } + // propagate this params to all the quantizable ports. + if (failed(PropagateQuantParams(op, params, new_items, changed))) { + return failure(); + } + break; + } default: { + // TODO(fengliuai): implement the other types. llvm_unreachable("no implementation."); return failure(); } @@ -154,6 +168,102 @@ void QuantizeContext::DumpStates(QuantizeRegionOp current_op) { }); } +// A heuristic to get quantization parameters satisfies the same scale +// constraints: +// - If there are immutable states, +// - use the single input, or, +// - use the single output, or, +// - use the first one in the collection, +// - use the single input if it is ready, or, +// - use the single output if it is ready, or, +// - use use the first ready one in the collection. +QuantParams QuantizeContext::GetQuantParamsForSameScaleConstraint( + Operation *op) { + // Two vector to collect Non-empty operands and results states. + std::vector mutable_states, immutable_states; + for (int i = 0, e = op->getNumOperands(); i != e; ++i) { + auto &state = states_manager_.GetOperandQuantState(op, i); + if (state.immutable) { + immutable_states.push_back(&state); + } else if (!state.IsEmpty()) { + mutable_states.push_back(&state); + } + } + + int immutable_operands_num = immutable_states.size(); + int mutable_operands_num = mutable_states.size(); + // Use the operand's state if it is immutable and it is the only one + // operand. + if (op->getNumOperands() == 1 && immutable_operands_num == 1) { + return immutable_states.front()->params; + } + + for (int i = 0, e = op->getNumResults(); i != e; ++i) { + auto &state = states_manager_.GetResultQuantState(op, i); + if (state.immutable) { + immutable_states.push_back(&state); + } else if (!state.IsEmpty()) { + mutable_states.push_back(&state); + } + } + + int immutable_results_num = immutable_states.size() - immutable_operands_num; + int mutable_results_num = mutable_states.size() - mutable_operands_num; + // Use the result's state if it is immutable and it is the only one result. + if (op->getNumResults() == 1 && immutable_results_num == 1) { + return immutable_states.back()->params; + } + + LLVM_DEBUG(llvm::dbgs() + << "Quantization parameters are not collected in an ideal place. " + "Has to fallback values which might introduce errors.\n"); + + // Use the first immutable state to quantize the rest operands and results. + if (!immutable_states.empty()) return immutable_states.front()->params; + + // If there are no immutable states, use the operand's state if it is the + // only one operand and has parameters propagated. + if (op->getNumOperands() == 1 && mutable_operands_num == 1) { + return mutable_states.front()->params; + } + + // If there are no immutable states, use the result's state if it is the + // only one result and has parameters propagated. + if (op->getNumResults() == 1 && mutable_results_num == 1) { + return mutable_states.back()->params; + } + + // Use the first propagated state to quantize the rest operands and results. + if (!mutable_states.empty()) return mutable_states.front()->params; + + // None operands/results have parameters propagated, skip this node for now. + return {}; +} + +LogicalResult QuantizeContext::PropagateQuantParams( + Operation *op, const QuantParams params, + quant::AdjacentOperations *new_items, bool *changed) { + // Use the final state to set all the operands' parameters. + for (int i = 0, e = op->getNumOperands(); i != e; ++i) { + auto ele = op->getOperand(i).getType().cast().getElementType(); + if (ele.isa() && SetOperandParams(op, i, params)) { + *changed |= true; + new_items->push_back(op->getOperand(i).getDefiningOp()); + } + } + + // Use the final state to set all the results' parameters. + for (int res = 0, e = op->getNumResults(); res != e; ++res) { + auto ele = op->getResult(res).getType().cast().getElementType(); + if (ele.isa() && SetResultParams(op, res, params)) { + auto users = op->getResult(res).getUsers(); + *changed |= !users.empty(); + new_items->append(users.begin(), users.end()); + } + } + return success(); +} + int QuantizeContext::StatesManager::InitializeState(quant::QuantizeRegionOp op, int index, bool as_result) { Attribute params_attr; diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_context.h b/tensorflow/compiler/mlir/lite/quantization/quantization_context.h index 35ed1feaaab..0d460fd9a50 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_context.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_context.h @@ -67,7 +67,7 @@ class QuantizeContext { QuantizeContext(FuncOp func, const DeviceTarget &spec); // Returns all the quant region ops. - ArrayRef GetAllOps(); + std::vector GetAllOps(); // For each quant region op, propagates its quantization parameters according // to the kernel specification and also returns the adjcent quant region ops @@ -107,6 +107,25 @@ class QuantizeContext { return states_manager_.GetOperandParams(op, index); } + // A heuristic to get quantization parameters satisfies the same scale + // constraints: + // - If there are immutable states, + // - use the single input, or, + // - use the single output, or, + // - use the first one in the collection, + // - use the single input if it is ready, or, + // - use the single output if it is ready, or, + // - use use the first ready one in the collection. + QuantParams GetQuantParamsForSameScaleConstraint(Operation *op); + + // Propagate `params` to all the quantizable port of the `op`. The adjcent + // ops, which have the parameters propagated to, are collected by `new_items`, + // so they can be added to the working queue. `changed` is set to true if + // there are any new elements being added to `new_items`. + LogicalResult PropagateQuantParams(Operation *op, const QuantParams params, + AdjacentOperations *new_items, + bool *changed); + private: class StatesManager { public: diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.cc b/tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.cc index e4bdafa89ff..b456af27fa5 100644 --- a/tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.cc +++ b/tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.cc @@ -28,6 +28,14 @@ namespace ph = std::placeholders; CpuDeviceTarget::CpuDeviceTarget(MLIRContext* ctx) : DeviceTarget(ctx) { RegisterKernel("generic.concat", {qi8_, qi8_, qi8_}, quant::ScaleConstraintType::OutputInputSameScale); + + // TODO(fengliuai): All the combinations are required to list. We need to + // improve this. + RegisterKernel("generic.reshape", {qi8_, any_}, + quant::ScaleConstraintType::OutputInputSameScale); + RegisterKernel("generic.reshape", {any_, qi8_}, + quant::ScaleConstraintType::OutputInputSameScale); + RegisterKernel("generic.mul", {qi8_, qi8_, qi8_}, quant::ScaleConstraintType::OutputInputFreeScale); RegisterKernel("generic.mul_add", {qi8_, qi8n_, any_, qi8_}, diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/cpu_kernel_fusion.cc b/tensorflow/compiler/mlir/lite/quantization/xla/cpu_kernel_fusion.cc index 478b9d54176..7deece117a5 100644 --- a/tensorflow/compiler/mlir/lite/quantization/xla/cpu_kernel_fusion.cc +++ b/tensorflow/compiler/mlir/lite/quantization/xla/cpu_kernel_fusion.cc @@ -176,7 +176,7 @@ llvm::SmallVector fuseOps(PatternRewriter* rewriter, auto* body = new Block(); region.body().push_back(body); - OpBuilder builder(body); + OpBuilder builder = OpBuilder::atBlockEnd(body); BlockAndValueMapping mapping; // Make block arguments and add it to the block value mapping. diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/propagate.cc b/tensorflow/compiler/mlir/lite/quantization/xla/propagate.cc index c4c5904209c..22dd4357416 100644 --- a/tensorflow/compiler/mlir/lite/quantization/xla/propagate.cc +++ b/tensorflow/compiler/mlir/lite/quantization/xla/propagate.cc @@ -69,7 +69,7 @@ void PropagateQuantPass::runOnFunction() { CpuDeviceTarget spec(&getContext()); quant::QuantizeContext ctx(func, spec); - std::vector work_list(ctx.GetAllOps()); + std::vector work_list = ctx.GetAllOps(); bool changed = false; while (!work_list.empty()) { quant::QuantizeRegionOp op = work_list.back(); diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/tests/region_propagation.mlir b/tensorflow/compiler/mlir/lite/quantization/xla/tests/region_propagation.mlir index 05ac48c9f39..a504be01827 100644 --- a/tensorflow/compiler/mlir/lite/quantization/xla/tests/region_propagation.mlir +++ b/tensorflow/compiler/mlir/lite/quantization/xla/tests/region_propagation.mlir @@ -52,3 +52,18 @@ func @mul_add_annotated(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tenso // CHECK: input_specs = [!quant.uniform, !quant.uniform:f32, 1.000000e+00:-128>, !quant.uniform] // CHECK-SAME: output_specs = [!quant.uniform] } + +// ----- + +// CHECK-LABEL: @same_scale_1_1 +func @same_scale_1_1(%arg0: tensor<1x7x7x64xf32>) -> (tensor<1x3136xf32>) { + %region = "quant.region"(%arg0) ( { + ^bb0(%arg1: tensor<1x7x7x64xf32>): // no predecessors + %r = "xla_hlo.reshape"(%arg1) : (tensor<1x7x7x64xf32>) -> (tensor<1x3136xf32>) + "quant.return"(%r) : (tensor<1x3136xf32>) -> () + }) {input_specs = [!quant.uniform], logical_kernel = "generic.reshape", output_specs = [f32]} : (tensor<1x7x7x64xf32>) -> tensor<1x3136xf32> + return %region : tensor<1x3136xf32> + +// CHECK: input_specs = [!quant.uniform] +// CHECK-SAME: output_specs = [!quant.uniform] +} diff --git a/tensorflow/compiler/mlir/lite/tests/dilated-conv.mlir b/tensorflow/compiler/mlir/lite/tests/dilated-conv.mlir index 5fe5fbfb3ee..6e2c1141f19 100644 --- a/tensorflow/compiler/mlir/lite/tests/dilated-conv.mlir +++ b/tensorflow/compiler/mlir/lite/tests/dilated-conv.mlir @@ -1,37 +1,53 @@ // RUN: tf-opt %s -tfl-identify-dilated-conv | FileCheck %s --dump-input-on-failure -func @testDilatedConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> { +func @testDilatedConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> { + %cst = constant dense<[2, 2]> : tensor<2xi32> + %cst_0 = constant dense<2> : tensor<2x2xi32> + %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32> + %1 = "tf.Conv2D"(%0, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32> + %2 = "tf.BatchToSpaceND"(%1, %cst, %cst_0) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> + return %2 : tensor<1x128x128x8xf32> + + // CHECK-LABEL: testDilatedConv + // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>) + // CHECK-NEXT: [[RESULT:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> + // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32> +} + +func @testDilatedConvWithNonConstantPadAndCrops(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32> %1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32> %2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> return %2 : tensor<1x128x128x8xf32> - // CHECK-LABEL: testDilatedConv + // CHECK-LABEL: testDilatedConvWithNonConstantPadAndCrops // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>) // CHECK-NEXT: [[RESULT:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32> } -func @testDilatedConvWithNonZeroSTBPadding(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> { +func @testDilatedConvWithNonZeroBasePadding(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> %cst_0 = constant dense<2> : tensor<2x2xi32> + %cst_1 = constant dense<1> : tensor<2x2xi32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32> - %1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32> - %2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> + %1 = "tf.Conv2D"(%0, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32> + %2 = "tf.BatchToSpaceND"(%1, %cst, %cst_1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> return %2 : tensor<1x128x128x8xf32> - // CHECK-LABEL: testDilatedConvWithNonZeroSTBPadding - // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>) + // CHECK-LABEL: testDilatedConvWithNonZeroBasePadding + // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>) // CHECK-NEXT: [[RESULT:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32> } -func @testDilatedConvWithNonTrivialDilations(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> { +func @testDilatedConvWithNonTrivialDilations(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> - %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32> - %1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", dilations = [1, 2, 2, 1], strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32> - %2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> + %cst_0 = constant dense<2> : tensor<2x2xi32> + %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32> + %1 = "tf.Conv2D"(%0, %arg1) {padding = "VALID", dilations = [1, 2, 2, 1], strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32> + %2 = "tf.BatchToSpaceND"(%1, %cst, %cst_0) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> return %2 : tensor<1x128x128x8xf32> // CHECK-LABEL: testDilatedConvWithNonTrivialDilations @@ -41,25 +57,27 @@ func @testDilatedConvWithNonTrivialDilations(%arg0: tensor<1x128x128x3xf32>, %ar // CHECK-NEXT: return [[RESULT]] } -func @testDilatedDepthWiseConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> { +func @testDilatedDepthWiseConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> - %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32> - %1 = "tf.DepthwiseConv2dNative"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32> - %2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> + %cst_0 = constant dense<2> : tensor<2x2xi32> + %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32> + %1 = "tf.DepthwiseConv2dNative"(%0, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32> + %2 = "tf.BatchToSpaceND"(%1, %cst, %cst_0) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> return %2 : tensor<1x128x128x8xf32> // CHECK-LABEL: testDilatedDepthWiseConv - // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>) + // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>) // CHECK-NEXT: [[RESULT:%.*]] = "tf.DepthwiseConv2dNative"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32> } func @testDilatedConvWithPad(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>, %arg3: tensor<8xf32>) -> tensor<1x128x128x8xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> - %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32> + %cst_0 = constant dense<2> : tensor<2x2xi32> + %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32> %1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32> %2 = "tf.Pad"(%1, %arg1) : (tensor<4x64x64x8xf32>, tensor<2x2xi32>) -> tensor<4x64x64x8xf32> - %3 = "tf.BatchToSpaceND"(%2, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> + %3 = "tf.BatchToSpaceND"(%2, %cst, %cst_0) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> %4 = "tf.BiasAdd"(%3, %arg3) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32> return %4 : tensor<1x128x128x8xf32> @@ -72,10 +90,11 @@ func @testDilatedConvWithPad(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi func @testDilatedDepthWiseConvWithPad(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>, %arg3: tensor<8xf32>) -> tensor<1x128x128x8xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> - %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32> + %cst_0 = constant dense<2> : tensor<2x2xi32> + %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32> %1 = "tf.DepthwiseConv2dNative"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32> %2 = "tf.Pad"(%1, %arg1) : (tensor<4x64x64x8xf32>, tensor<2x2xi32>) -> tensor<4x64x64x8xf32> - %3 = "tf.BatchToSpaceND"(%2, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> + %3 = "tf.BatchToSpaceND"(%2, %cst, %cst_0) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> %4 = "tf.BiasAdd"(%3, %arg3) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32> return %4 : tensor<1x128x128x8xf32> @@ -86,49 +105,52 @@ func @testDilatedDepthWiseConvWithPad(%arg0: tensor<1x128x128x3xf32>, %arg1: ten // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32> } -func @testDilatedConvWithBiasAdd(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>, %arg3: tensor<8xf32>) -> tensor<1x128x128x8xf32> { +func @testDilatedConvWithBiasAdd(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>, %arg2: tensor<8xf32>) -> tensor<1x128x128x8xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> - %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32> - %1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32> - %2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> - %3 = "tf.BiasAdd"(%2, %arg3) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32> + %cst_0 = constant dense<2> : tensor<2x2xi32> + %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32> + %1 = "tf.Conv2D"(%0, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32> + %2 = "tf.BatchToSpaceND"(%1, %cst, %cst_0) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> + %3 = "tf.BiasAdd"(%2, %arg2) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32> return %3 : tensor<1x128x128x8xf32> // CHECK-LABEL: testDilatedConvWithBiasAdd - // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>) + // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>) // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[CONV]], [[BIAS]]) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32> // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32> } -func @testDilatedDepthWiseConvWithBiasAdd(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>, %arg3: tensor<8xf32>) -> tensor<1x128x128x8xf32> { +func @testDilatedDepthWiseConvWithBiasAdd(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>, %arg2: tensor<8xf32>) -> tensor<1x128x128x8xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> - %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32> - %1 = "tf.DepthwiseConv2dNative"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32> - %2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> - %3 = "tf.BiasAdd"(%2, %arg3) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32> + %cst_0 = constant dense<2> : tensor<2x2xi32> + %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32> + %1 = "tf.DepthwiseConv2dNative"(%0, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32> + %2 = "tf.BatchToSpaceND"(%1, %cst, %cst_0) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> + %3 = "tf.BiasAdd"(%2, %arg2) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32> return %3 : tensor<1x128x128x8xf32> // CHECK-LABEL: testDilatedDepthWiseConvWithBiasAdd - // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>) + // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>) // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[CONV]], [[BIAS]]) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32> // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32> } -func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> { +func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: tensor<5x5x1x1xf32>, %arg2: tensor<128xf32>) -> tensor<1x128x128xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> %cst_0 = "tf.Const"() { value = dense<3> : tensor } : () -> tensor - %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32> + %cst_1 = constant dense<2> : tensor<2x2xi32> + %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32> %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor) -> tensor<4x68x68x1xf32> - %2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32> + %2 = "tf.Conv2D"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32> %3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32> - %4 = "tf.BatchToSpaceND"(%3, %cst, %arg1) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32> - %5 = "tf.BiasAdd"(%4, %arg3) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32> + %4 = "tf.BatchToSpaceND"(%3, %cst, %cst_1) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32> + %5 = "tf.BiasAdd"(%4, %arg2) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32> return %5 : tensor<1x128x128xf32> // CHECK-LABEL: testDilatedConvWithExpandSqueeze1 - // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>) + // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>) // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor} : () -> tensor // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor) -> tensor<1x128x128x1xf32> // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> @@ -137,19 +159,20 @@ func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: ten // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32> } -func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> { +func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: tensor<5x5x1x1xf32>, %arg2: tensor<128xf32>) -> tensor<1x128x128xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> %cst_0 = "tf.Const"() { value = dense<3> : tensor } : () -> tensor - %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32> + %cst_1 = constant dense<2> : tensor<2x2xi32> + %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32> %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor) -> tensor<4x68x68x1xf32> - %2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32> + %2 = "tf.DepthwiseConv2dNative"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32> %3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32> - %4 = "tf.BatchToSpaceND"(%3, %cst, %arg1) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32> - %5 = "tf.BiasAdd"(%4, %arg3) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32> + %4 = "tf.BatchToSpaceND"(%3, %cst, %cst_1) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32> + %5 = "tf.BiasAdd"(%4, %arg2) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32> return %5 : tensor<1x128x128xf32> // CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze1 - // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>) + // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>) // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor} : () -> tensor // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor) -> tensor<1x128x128x1xf32> // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> @@ -158,19 +181,20 @@ func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, % // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32> } -func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor) -> tensor<1x128x128xf32> { +func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: tensor<5x5x1x1xf32>, %arg2: tensor) -> tensor<1x128x128xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> %cst_0 = "tf.Const"() { value = dense<3> : tensor } : () -> tensor - %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?xf32> + %cst_1 = constant dense<2> : tensor<2x2xi32> + %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?xf32> %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x?x?xf32>, tensor) -> tensor<4x?x?x1xf32> - %2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x?x?x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x?x?x1xf32> + %2 = "tf.Conv2D"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x?x?x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x?x?x1xf32> %3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x?x?x1xf32>) -> tensor<4x?x?xf32> - %4 = "tf.BiasAdd"(%3, %arg3) : (tensor<4x?x?xf32>, tensor) -> tensor<4x?x?xf32> - %5 = "tf.BatchToSpaceND"(%4, %cst, %arg1) : (tensor<4x?x?xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32> + %4 = "tf.BiasAdd"(%3, %arg2) : (tensor<4x?x?xf32>, tensor) -> tensor<4x?x?xf32> + %5 = "tf.BatchToSpaceND"(%4, %cst, %cst_1) : (tensor<4x?x?xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32> return %5 : tensor<1x128x128xf32> // CHECK-LABEL: testDilatedConvWithExpandSqueeze2 - // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor) + // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor) // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor} : () -> tensor // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor) -> tensor<1x128x128x1xf32> // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> @@ -179,19 +203,20 @@ func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: ten // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32> } -func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor) -> tensor<1x128x128xf32> { +func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: tensor<5x5x1x1xf32>, %arg2: tensor) -> tensor<1x128x128xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> %cst_0 = "tf.Const"() { value = dense<3> : tensor } : () -> tensor - %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?xf32> + %cst_1 = constant dense<2> : tensor<2x2xi32> + %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?xf32> %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x?x?xf32>, tensor) -> tensor<4x?x?x1xf32> - %2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x?x?x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x?x?x1xf32> + %2 = "tf.DepthwiseConv2dNative"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x?x?x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x?x?x1xf32> %3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x?x?x1xf32>) -> tensor<4x?x?xf32> - %4 = "tf.BiasAdd"(%3, %arg3) : (tensor<4x?x?xf32>, tensor) -> tensor<4x?x?xf32> - %5 = "tf.BatchToSpaceND"(%4, %cst, %arg1) : (tensor<4x?x?xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32> + %4 = "tf.BiasAdd"(%3, %arg2) : (tensor<4x?x?xf32>, tensor) -> tensor<4x?x?xf32> + %5 = "tf.BatchToSpaceND"(%4, %cst, %cst_1) : (tensor<4x?x?xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32> return %5 : tensor<1x128x128xf32> // CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze2 - // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor) + // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor) // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor} : () -> tensor // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor) -> tensor<1x128x128x1xf32> // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> @@ -203,12 +228,13 @@ func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, % func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> %cst_0 = "tf.Const"() { value = dense<3> : tensor } : () -> tensor - %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32> + %cst_1 = constant dense<2> : tensor<2x2xi32> + %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32> %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor) -> tensor<4x68x68x1xf32> %2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32> %3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32> %4 = "tf.Pad"(%3, %arg1) : (tensor<4x64x64xf32>, tensor<2x2xi32>) -> tensor<4x64x64xf32> - %5 = "tf.BatchToSpaceND"(%4, %cst, %arg1) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32> + %5 = "tf.BatchToSpaceND"(%4, %cst, %cst_1) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32> %6 = "tf.BiasAdd"(%5, %arg3) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32> return %6 : tensor<1x128x128xf32> @@ -225,12 +251,13 @@ func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: ten func @testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> %cst_0 = "tf.Const"() { value = dense<3> : tensor } : () -> tensor - %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32> + %cst_1 = constant dense<2> : tensor<2x2xi32> + %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32> %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor) -> tensor<4x68x68x1xf32> %2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32> %3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32> %4 = "tf.Pad"(%3, %arg1) : (tensor<4x64x64xf32>, tensor<2x2xi32>) -> tensor<4x64x64xf32> - %5 = "tf.BatchToSpaceND"(%4, %cst, %arg1) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32> + %5 = "tf.BatchToSpaceND"(%4, %cst, %cst_1) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32> %6 = "tf.BiasAdd"(%5, %arg3) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32> return %6 : tensor<1x128x128xf32> @@ -244,14 +271,15 @@ func @testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, % // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32> } -func @testDilatedConvWithDifferentExpandSqueezeAxis(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128x1xf32> { +func @testDilatedConvWithDifferentExpandSqueezeAxis(%arg0: tensor<1x128x128xf32>, %arg1: tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> %cst_0 = "tf.Const"() { value = dense<3> : tensor } : () -> tensor - %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32> + %cst_1 = constant dense<2> : tensor<2x2xi32> + %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32> %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor) -> tensor<4x68x68x1xf32> - %2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32> + %2 = "tf.Conv2D"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32> %3 = "tf.Squeeze"(%2) {squeeze_dims = [2]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64x1xf32> - %4 = "tf.BatchToSpaceND"(%3, %cst, %arg1) : (tensor<4x64x64x1xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x1xf32> + %4 = "tf.BatchToSpaceND"(%3, %cst, %cst_1) : (tensor<4x64x64x1xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x1xf32> return %4 : tensor<1x128x128x1xf32> // CHECK-LABEL: testDilatedConvWithDifferentExpandSqueezeAxis diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/graph_with_placeholder_with_default.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/graph_with_placeholder_with_default.pbtxt index 33e1347a0b9..82e843517a3 100644 --- a/tensorflow/compiler/mlir/lite/tests/end2end/graph_with_placeholder_with_default.pbtxt +++ b/tensorflow/compiler/mlir/lite/tests/end2end/graph_with_placeholder_with_default.pbtxt @@ -142,7 +142,7 @@ versions { # CHECK-SAME: control_outputs = "" # CHECK-SAME: inputs = "unranked" # CHECK-SAME: outputs = "unranked,static,static_10" -# CHECK: [[VAL_1:%.*]] = constant dense<0> : tensor<10xi32> -# CHECK: [[VAL_2:%.*]] = constant dense<0> : tensor -# CHECK: return [[VAL_0]], [[VAL_2]], [[VAL_1]] : tensor<1x8x8x2xi32>, tensor, tensor<10xi32> +# CHECK: [[VAL_1:%.*]] = constant dense<0> : tensor +# CHECK: [[VAL_2:%.*]] = constant dense<0> : tensor<10xi32> +# CHECK: return [[VAL_0]], [[VAL_1]], [[VAL_2]] : tensor<1x8x8x2xi32>, tensor, tensor<10xi32> # CHECK: } diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/ophint_lstm.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/ophint_lstm.pbtxt index 05ec3b0b93e..1b42b60acf7 100644 --- a/tensorflow/compiler/mlir/lite/tests/end2end/ophint_lstm.pbtxt +++ b/tensorflow/compiler/mlir/lite/tests/end2end/ophint_lstm.pbtxt @@ -7788,35 +7788,35 @@ library { # CHECK-SAME: control_outputs = "" # CHECK-SAME: inputs = "INPUT" # CHECK-SAME: outputs = "OUTPUT" -# CHECK: [[VAL_1:%.*]] = constant dense<{{\[\[}}-0.400154352, 0.739109992, 0.201825857], [0.678572893, 0.32076478, 0.949867963], [-0.807729483, -5.324750e-01, 0.148033619]]> : tensor<3x3xf32> -# CHECK: [[VAL_2:%.*]] = constant dense<{{\[\[}}0.886177539, -0.606141329, -0.451275587], [0.325554609, 0.691527605, -0.676239967], [0.219799042, 0.626042128, -0.597596407]]> : tensor<3x3xf32> -# CHECK: [[VAL_3:%.*]] = constant dense<{{\[\[}}-0.493826151, -0.391061306, -0.349843264], [-0.0213134289, 0.558384657, -0.51513052], [0.427886248, 0.618100405, -0.187585592]]> : tensor<3x3xf32> -# CHECK: [[VAL_4:%.*]] = constant dense<{{\[\[}}0.444335222, -0.133341789, 0.839591503], [0.445418358, -0.571707964, 0.569707394], [0.465010405, -0.990037918, -0.632481337]]> : tensor<3x3xf32> -# CHECK: [[VAL_5:%.*]] = constant dense<{{\[\[}}-0.138204336, -0.10879755, -0.135128736], [0.94797182, -8.713360e-01, -0.792336463], [0.0339827538, -0.539326906, 8.906350e-01]]> : tensor<3x3xf32> -# CHECK: [[VAL_6:%.*]] = constant dense<{{\[\[}}0.513064623, -0.692989588, 0.547988653], [0.0653710365, 0.576977491, 0.966733217], [0.0130724907, 0.247342348, 0.317092657]]> : tensor<3x3xf32> -# CHECK: [[VAL_7:%.*]] = constant dense<{{\[\[}}0.230039358, -0.182297707, -0.352231741], [-0.805100203, -0.220300436, -0.669503212], [0.278807402, -0.201502323, -0.627609729]]> : tensor<3x3xf32> -# CHECK: [[VAL_8:%.*]] = constant dense<{{\[\[}}-0.207589626, -0.756766081, -0.853258133], [-0.269270182, 0.0468223095, -0.353052378], [-0.0702953338, 0.0725159645, -0.817753077]]> : tensor<3x3xf32> -# CHECK: [[VAL_9:%.*]] = constant dense<[0.171322107, -0.153412342, 0.591750383]> : tensor<3xf32> -# CHECK: [[VAL_10:%.*]] = constant dense<[-0.671292543, 0.411814928, 0.560465336]> : tensor<3xf32> -# CHECK: [[VAL_11:%.*]] = constant dense<[0.403919935, -0.882057666, -0.894463062]> : tensor<3xf32> -# CHECK: [[VAL_12:%.*]] = constant dense<{{\[\[}}-0.936182261, -0.935433864, 0.288229942], [-0.243383884, -0.628288031, -0.477061749], [-0.514976501, -0.903514862, 6.728170e-01]]> : tensor<3x3xf32> -# CHECK: [[VAL_13:%.*]] = constant dense<{{\[\[}}0.18183589, 0.616135359, -0.167827845], [0.734281301, 0.958347797, -0.878054618], [0.369523764, -0.969005823, -0.881014585]]> : tensor<3x3xf32> -# CHECK: [[VAL_14:%.*]] = constant dense<{{\[\[}}-5.087240e-01, -0.588907719, 0.471896172], [-0.508019447, -0.0157074928, -0.804120779], [-0.978842973, 0.00160336494, -0.978532075]]> : tensor<3x3xf32> -# CHECK: [[VAL_15:%.*]] = constant dense<{{\[\[}}-0.616786718, 0.892614365, 0.671324968], [-0.842380046, -0.358094931, 0.821366549], [0.790347338, 0.71222949, 0.0690443515]]> : tensor<3x3xf32> -# CHECK: [[VAL_16:%.*]] = constant dense<1.000000e+00> : tensor<3xf32> -# CHECK: [[VAL_17:%.*]] = constant dense<{{\[\[}}0.782244444, -0.0446639061, 0.848498106], [-0.579102755, -0.407756329, 0.442389727], [0.00566458702, 0.5984025, 0.629857302]]> : tensor<3x3xf32> -# CHECK: [[VAL_18:%.*]] = constant dense<{{\[\[}}0.891112089, -2.786560e-01, 0.966933965], [-0.789963722, 0.057955265, 0.217499971], [-0.698129416, -0.983400583, -0.834380626]]> : tensor<3x3xf32> -# CHECK: [[VAL_19:%.*]] = constant dense<{{\[\[}}-0.125753641, 0.32271719, 0.488939524], [0.36119318, 0.982266664, -0.448646784], [0.966353893, -0.767024993, 0.446366787]]> : tensor<3x3xf32> -# CHECK: [[VAL_20:%.*]] = constant dense<{{\[\[}}-0.856678485, -0.800494194, 0.716800689], [0.536404848, 0.541643381, -0.35657692], [-0.794646739, 0.137629032, 0.690013885]]> : tensor<3x3xf32> -# CHECK: [[VAL_21:%.*]] = constant dense<0.000000e+00> : tensor<3xf32> -# CHECK: [[VAL_22:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32> +# CHECK: [[VAL_1:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32> +# CHECK: [[VAL_2:%.*]] = constant dense<0.000000e+00> : tensor<3xf32> +# CHECK: [[VAL_3:%.*]] = constant dense<{{\[\[}}-0.856678485, -0.800494194, 0.716800689], [0.536404848, 0.541643381, -0.35657692], [-0.794646739, 0.137629032, 0.690013885]]> : tensor<3x3xf32> +# CHECK: [[VAL_4:%.*]] = constant dense<{{\[\[}}-0.125753641, 0.32271719, 0.488939524], [0.36119318, 0.982266664, -0.448646784], [0.966353893, -0.767024993, 0.446366787]]> : tensor<3x3xf32> +# CHECK: [[VAL_5:%.*]] = constant dense<{{\[\[}}0.891112089, -2.786560e-01, 0.966933965], [-0.789963722, 0.057955265, 0.217499971], [-0.698129416, -0.983400583, -0.834380626]]> : tensor<3x3xf32> +# CHECK: [[VAL_6:%.*]] = constant dense<{{\[\[}}0.782244444, -0.0446639061, 0.848498106], [-0.579102755, -0.407756329, 0.442389727], [0.00566458702, 0.5984025, 0.629857302]]> : tensor<3x3xf32> +# CHECK: [[VAL_7:%.*]] = constant dense<1.000000e+00> : tensor<3xf32> +# CHECK: [[VAL_8:%.*]] = constant dense<{{\[\[}}-0.616786718, 0.892614365, 0.671324968], [-0.842380046, -0.358094931, 0.821366549], [0.790347338, 0.71222949, 0.0690443515]]> : tensor<3x3xf32> +# CHECK: [[VAL_9:%.*]] = constant dense<{{\[\[}}-5.087240e-01, -0.588907719, 0.471896172], [-0.508019447, -0.0157074928, -0.804120779], [-0.978842973, 0.00160336494, -0.978532075]]> : tensor<3x3xf32> +# CHECK: [[VAL_10:%.*]] = constant dense<{{\[\[}}0.18183589, 0.616135359, -0.167827845], [0.734281301, 0.958347797, -0.878054618], [0.369523764, -0.969005823, -0.881014585]]> : tensor<3x3xf32> +# CHECK: [[VAL_11:%.*]] = constant dense<{{\[\[}}-0.936182261, -0.935433864, 0.288229942], [-0.243383884, -0.628288031, -0.477061749], [-0.514976501, -0.903514862, 6.728170e-01]]> : tensor<3x3xf32> +# CHECK: [[VAL_12:%.*]] = constant dense<{{\[}}0.403919935, -0.882057666, -0.894463062]> : tensor<3xf32> +# CHECK: [[VAL_13:%.*]] = constant dense<{{\[}}-0.671292543, 0.411814928, 0.560465336]> : tensor<3xf32> +# CHECK: [[VAL_14:%.*]] = constant dense<{{\[}}0.171322107, -0.153412342, 0.591750383]> : tensor<3xf32> +# CHECK: [[VAL_15:%.*]] = constant dense<{{\[\[}}-0.207589626, -0.756766081, -0.853258133], [-0.269270182, 0.0468223095, -0.353052378], [-0.0702953338, 0.0725159645, -0.817753077]]> : tensor<3x3xf32> +# CHECK: [[VAL_16:%.*]] = constant dense<{{\[\[}}0.230039358, -0.182297707, -0.352231741], [-0.805100203, -0.220300436, -0.669503212], [0.278807402, -0.201502323, -0.627609729]]> : tensor<3x3xf32> +# CHECK: [[VAL_17:%.*]] = constant dense<{{\[\[}}0.513064623, -0.692989588, 0.547988653], [0.0653710365, 0.576977491, 0.966733217], [0.0130724907, 0.247342348, 0.317092657]]> : tensor<3x3xf32> +# CHECK: [[VAL_18:%.*]] = constant dense<{{\[\[}}-0.138204336, -0.10879755, -0.135128736], [0.94797182, -8.713360e-01, -0.792336463], [0.0339827538, -0.539326906, 8.906350e-01]]> : tensor<3x3xf32> +# CHECK: [[VAL_19:%.*]] = constant dense<{{\[\[}}0.444335222, -0.133341789, 0.839591503], [0.445418358, -0.571707964, 0.569707394], [0.465010405, -0.990037918, -0.632481337]]> : tensor<3x3xf32> +# CHECK: [[VAL_20:%.*]] = constant dense<{{\[\[}}-0.493826151, -0.391061306, -0.349843264], [-0.0213134289, 0.558384657, -0.51513052], [0.427886248, 0.618100405, -0.187585592]]> : tensor<3x3xf32> +# CHECK: [[VAL_21:%.*]] = constant dense<{{\[\[}}0.886177539, -0.606141329, -0.451275587], [0.325554609, 0.691527605, -0.676239967], [0.219799042, 0.626042128, -0.597596407]]> : tensor<3x3xf32> +# CHECK: [[VAL_22:%.*]] = constant dense<{{\[\[}}-0.400154352, 0.739109992, 0.201825857], [0.678572893, 0.32076478, 0.949867963], [-0.807729483, -5.324750e-01, 0.148033619]]> : tensor<3x3xf32> # CHECK: [[VAL_23:%.*]] = constant unit -# CHECK: [[VAL_24:%.*]]:3 = "tfl.unpack"(%[[ARG_0]]) {axis = 1 : i32, num = 3 : i32} : (tensor<1x3x3xf32>) -> (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>) -# CHECK: [[VAL_25:%.*]] = "tfl.pack"([[VAL_24]]#0, [[VAL_24]]#1, [[VAL_24]]#2) {axis = 0 : i32, values_count = 3 : i32} : (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<3x1x3xf32> +# CHECK: [[UNPACK:%.*]]:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} : (tensor<1x3x3xf32>) -> (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>) +# CHECK: [[PACK:%.*]] = "tfl.pack"([[UNPACK]]#0, [[UNPACK]]#1, [[UNPACK]]#2) {axis = 0 : i32, values_count = 3 : i32} : (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<3x1x3xf32> +# CHECK: [[VAL_24:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32> +# CHECK: [[UNIDIRECTIONAL_SEQUENCE_LSTM_1:%.*]] = "tfl.unidirectional_sequence_lstm"([[PACK]], [[VAL_16]], [[VAL_17]], [[VAL_18]], [[VAL_15]], [[VAL_20]], [[VAL_21]], [[VAL_22]], [[VAL_19]], [[VAL_13]], [[VAL_14]], [[VAL_12]], [[VAL_2]], [[VAL_7]], [[VAL_2]], [[VAL_2]], [[VAL_23]], [[VAL_23]], [[VAL_1]], [[VAL_24]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_23]]) {fused_activation_function = "TANH", time_major = true} : (tensor<3x1x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, none, none, tensor<1x3xf32>, tensor<1x3xf32>, none, none, none, none) -> tensor<3x1x3xf32> +# CHECK: [[VAL_25:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32> # CHECK: [[VAL_26:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32> -# CHECK: [[VAL_27:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_25]], [[VAL_7]], [[VAL_6]], [[VAL_5]], [[VAL_8]], [[VAL_3]], [[VAL_2]], [[VAL_1]], [[VAL_4]], [[VAL_10]], [[VAL_9]], [[VAL_11]], [[VAL_21]], [[VAL_16]], [[VAL_21]], [[VAL_21]], [[VAL_23]], [[VAL_23]], [[VAL_22]], [[VAL_26]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_23]]) {fused_activation_function = "TANH", time_major = true} : (tensor<3x1x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, none, none, tensor<1x3xf32>, tensor<1x3xf32>, none, none, none, none) -> tensor<3x1x3xf32> -# CHECK: [[VAL_28:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32> -# CHECK: [[VAL_29:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32> -# CHECK: [[VAL_30:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_27]], [[VAL_19]], [[VAL_18]], [[VAL_17]], [[VAL_20]], [[VAL_14]], [[VAL_13]], [[VAL_12]], [[VAL_15]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_21]], [[VAL_16]], [[VAL_21]], [[VAL_21]], [[VAL_23]], [[VAL_23]], [[VAL_28]], [[VAL_29]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_23]]) {fused_activation_function = "TANH", time_major = true} : (tensor<3x1x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, none, none, none, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, none, none, tensor<1x3xf32>, tensor<1x3xf32>, none, none, none, none) -> tensor<3x1x3xf32> -# CHECK: [[VAL_31:%.*]]:3 = "tfl.unpack"([[VAL_30]]) {axis = 0 : i32, num = 3 : i32} : (tensor<3x1x3xf32>) -> (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>) -# CHECK: return [[VAL_31]]#2 : tensor<1x3xf32> +# CHECK: [[UNIDIRECTIONAL_SEQUENCE_LSTM_2:%.*]] = "tfl.unidirectional_sequence_lstm"([[UNIDIRECTIONAL_SEQUENCE_LSTM_1]], [[VAL_4]], [[VAL_5]], [[VAL_6]], [[VAL_3]], [[VAL_9]], [[VAL_10]], [[VAL_11]], [[VAL_8]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_2]], [[VAL_7]], [[VAL_2]], [[VAL_2]], [[VAL_23]], [[VAL_23]], [[VAL_25]], [[VAL_26]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_23]]) {fused_activation_function = "TANH", time_major = true} : (tensor<3x1x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, none, none, none, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, none, none, tensor<1x3xf32>, tensor<1x3xf32>, none, none, none, none) -> tensor<3x1x3xf32> +# CHECK: [[RESULT:%.*]]:3 = "tfl.unpack"([[UNIDIRECTIONAL_SEQUENCE_LSTM_2]]) {axis = 0 : i32, num = 3 : i32} : (tensor<3x1x3xf32>) -> (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>) +# CHECK: return [[RESULT]]#2 : tensor<1x3xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/constants.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/constants.mlir index a113c318d80..d834f99fa7e 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/constants.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/constants.mlir @@ -28,6 +28,13 @@ func @f32() -> tensor<4xf32> { return %0 : tensor<4xf32> } +func @f64() -> tensor<4xf64> { + // CHECK-LABEL: @f64 + // CHECK: value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf64> + %0 = "tfl.pseudo_const"() { value = dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf64> } : () -> tensor<4xf64> + return %0 : tensor<4xf64> +} + func @i8() -> tensor<4xi8> { // CHECK-LABEL: @i8 // CHECK: value = dense<[1, 2, 3, 4]> : tensor<4xi8> diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index 7db46f778fa..7e9b1bdb711 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -829,6 +829,14 @@ func @pack3Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2 : tensor<2x // CHECK: "tfl.pack"(%arg0, %arg1, %arg2) {axis = 1 : i32, values_count = 3 : i32} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32> } +func @packStringWithFlex(%arg0: tensor<2x!tf.string>, %arg1: tensor<2x!tf.string>) -> tensor<2x2x!tf.string> { + %0 = "tf.Pack"(%arg0, %arg1) : (tensor<2x!tf.string>, tensor<2x!tf.string>) -> tensor<2x2x!tf.string> + return %0 : tensor<2x2x!tf.string> + +// CHECK-LABEL: packStringWithFlex +// CHECK: "tf.Pack"(%arg0, %arg1) : (tensor<2x!tf.string>, tensor<2x!tf.string>) -> tensor<2x2x!tf.string> +} + func @packNegAxis(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2 : tensor<2xi32>) -> tensor<2x3xi32> { %0 = "tf.Pack"(%arg0, %arg1, %arg2) {axis = -1 : i64} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32> return %0 : tensor<2x3xi32> diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_f64.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_f64.mlir new file mode 100644 index 00000000000..4ba9ef75459 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_f64.mlir @@ -0,0 +1,66 @@ +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-select-tf-ops -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s + +func @main(tensor<4xf64>, tensor<4xf64>) -> tensor<4xf64> { +^bb0(%arg0: tensor<4xf64>, %arg1: tensor<4xf64>): +// CHECK: { +// CHECK-NEXT: version: 3, +// CHECK-NEXT: operator_codes: [ { +// CHECK-NEXT: builtin_code: CUSTOM, +// CHECK-NEXT: custom_code: "FlexAdd" +// CHECK-NEXT: } ], +// CHECK-NEXT: subgraphs: [ { +// CHECK-NEXT: tensors: [ { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: type: FLOAT64, +// CHECK-NEXT: buffer: 1, +// CHECK-NEXT: name: "arg0", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: type: FLOAT64, +// CHECK-NEXT: buffer: 2, +// CHECK-NEXT: name: "arg1", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: type: FLOAT64, +// CHECK-NEXT: buffer: 3, +// CHECK-NEXT: name: "add", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: } ], +// CHECK-NEXT: inputs: [ 0, 1 ], +// CHECK-NEXT: outputs: [ 2 ], +// CHECK-NEXT: operators: [ { +// CHECK-NEXT: inputs: [ 0, 1 ], +// CHECK-NEXT: outputs: [ 2 ], +// CHECK-NEXT: custom_options: [ 3, 65, 100, 100, 0, 20, 18, 3, 65, 100, 100, 26, 0, 26, 0, 42, 7, 10, 1, 84, 18, 2, 48, 2, 50, 0, 0, 2, 27, 23, 20, 20, 4, 40, 1 ] +// CHECK-NEXT: } ], +// CHECK-NEXT: name: "main" +// CHECK-NEXT: } ], +// CHECK-NEXT: description: "MLIR Converted.", +// CHECK-NEXT: buffers: [ { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: } ], +// CHECK-NEXT: metadata: [ { +// CHECK-NEXT: name: "min_runtime_version", +// CHECK-NEXT: buffer: 4 +// CHECK-NEXT: } ] +// CHECK-NEXT:} + + %0 = "tf.Add"(%arg0, %arg1) : (tensor<4xf64>, tensor<4xf64>) -> tensor<4xf64> loc("add") + return %0 : tensor<4xf64> +} diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc index 762bd8c8ed2..038adebabef 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc @@ -138,13 +138,24 @@ int main(int argc, char **argv) { // TODO(b/147435528): We need to test the e2e behavior once the graph freezing // inside mlir is done. if (import_saved_model_object_graph || import_saved_model_signature_defs) { + int saved_model_version; + if (import_saved_model_object_graph) { + saved_model_version = 2; + } else { + saved_model_version = 1; + } if (input_mlir) module = tensorflow::errors::InvalidArgument( "Importing saved model should not have input_mlir set"); - module = tensorflow::ImportSavedModel(import_saved_model_object_graph, - import_saved_model_signature_defs, - input_file_name, saved_model_tags, - saved_model_exported_names, &context); + + std::unordered_set tags = + absl::StrSplit(saved_model_tags, ','); + std::vector exported_names_vector = + absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty()); + absl::Span exported_names(exported_names_vector); + + module = tensorflow::ImportSavedModel(input_file_name, saved_model_version, + tags, exported_names, &context); } else { module = tensorflow::LoadFromGraphdefOrMlirSource( input_file_name, input_mlir, use_splatted_constant, custom_opdefs, @@ -197,11 +208,6 @@ int main(int argc, char **argv) { pass_config.lower_tensor_list_ops = lower_tensor_list_ops; pass_config.legalize_tf_while = convert_tf_while_to_tfl_while; - // Currently we only do shape inference for saved model import. - if (import_saved_model_object_graph || import_saved_model_signature_defs) { - pass_config.shape_inference = true; - } - tensorflow::AddTFToTFLConversionPasses(pass_config, &pm); // TODO(b/150901738): Move those into tf_tfl_translate.cc. // Convert back to outlined while format for export back to flatbuffer. diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index 7c0a91d6d4e..aacc1ad2fd6 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -160,25 +160,17 @@ Status ConvertTFExecutorToTFLOrFlatbuffer( } StatusOr ImportSavedModel( - bool import_saved_model, bool import_saved_model_v1, - const std::string& input_filename, const std::string& saved_model_tags, - const std::string& saved_model_exported_names, mlir::MLIRContext* context) { - if (import_saved_model) { - std::unordered_set tags = - absl::StrSplit(saved_model_tags, ','); - std::vector exported_names = - absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty()); - + const std::string& input_filename, const int saved_model_version, + const std::unordered_set& tags, + absl::Span exported_names, mlir::MLIRContext* context) { + if (saved_model_version == 2) { auto module = tensorflow::SavedModelObjectGraphToMlirImport( - input_filename, tags, absl::Span(exported_names), context); + input_filename, tags, exported_names, context); if (!module) return tensorflow::errors::InvalidArgument("fail to open input file"); return module; - } else if (import_saved_model_v1) { - std::unordered_set tags = - absl::StrSplit(saved_model_tags, ','); - + } else if (saved_model_version == 1) { auto module = tensorflow::SavedModelSignatureDefsToMlirImport( input_filename, tags, context); diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h index c93f8a6d416..d2c31a6b972 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_ +#include + +#include "absl/types/span.h" #include "llvm/Support/SourceMgr.h" #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project @@ -42,9 +45,9 @@ LoadFromGraphdefOrMlirSource( // Load Saved model (either v1 or v2) into MLIR. stream_executor::port::StatusOr ImportSavedModel( - bool import_saved_model, bool import_saved_model_v1, - const std::string& input_filename, const std::string& saved_model_tags, - const std::string& saved_model_exported_names, mlir::MLIRContext* context); + const std::string& input_filename, const int saved_model_version, + const std::unordered_set& tags, + absl::Span exported_names, mlir::MLIRContext* context); // Taking a MLIR module in TF executor dialect and a set of parameters, // applies a set of passes to convert the module to TF Lite dialect and diff --git a/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h b/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h index 68a1c617e34..b745be7753a 100644 --- a/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h +++ b/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h @@ -152,7 +152,6 @@ LogicalResult ConvertTFDilatedConvOp::matchAndRewrite( } // BatchToSpaceND + BiasAdd. - // TODO(b/149936532): Check the `crops` input, currently ignored. TF::BatchToSpaceNDOp bts_op; TF::BiasAddOp biasadd_op; bool final_op_is_bts = true; @@ -179,16 +178,50 @@ LogicalResult ConvertTFDilatedConvOp::matchAndRewrite( if (!dilations_attr.hasValue()) return failure(); op.setAttr("dilations", dilations_attr.getValue()); - // Padding is set to 'SAME' when `stb_op` has non-zero paddings. - // TODO(b/149936532): This assumption only holds when the input width & height - // is multiple of dilation width & height. We should fix it in order to - // support other use cases. + // TODO(b/149936532): Check that the input width & height are multiples of + // dilation rate. + // TF python library will rewrite dilated conv to + // "SpaceToBatch->Conv->BatchToSpace" pattern, and the Conv in the middle + // always has 'VALID' padding. The padding tensor in `SpaceToBatch` has two + // parts of contributions, one is to reduce padding of CONV from 'SAME' to + // 'VALID', and another is to make input shape multiples of dilation rate. The + // first part of padding, which is also called `base_padding` will be used + // here to determine if the original padding format is 'SAME' or 'VALID'. + // According to the following formula we will compute the `base_padding` if + // it's a constant. Basically, `paddings` tensor in `SpaceToBatch` and `crops` + // tensor in `BatchToSpace` must satisfy the following: + // paddings[i, 0] = base_paddings[i, 0]. + // 0 <= paddings[i, 1] - base_paddings[i, 1] < block_shape[i] + // (input_shape[i] + paddings[i, 0] + paddings[i, 1]) % block_shape[i] == 0. + // crops[i, 0] = 0. + // crops[i, 1] = paddings[i, 1] - base_paddings[i, 1]. + + // If `paddings` - `crops` != 0, this means that `base_paddings` != 0, which + // tells us the original padding is 'SAME' (with one caveat presented below). + // Here we need to reset the padding back to `SAME` if `base_padding` + // != 0. + // TODO(b/149936532): We might not simply rely on `paddings - crops != 0` to + // determine the original padding format. For example, users can build + // arbitrary valid examples of `STB->Conv->BTS` which doesn't represent a + // dilated conv, hence we shouldn't pattern match here. Instead, we need to + // check values of `paddings` and `crops` to make sure it really stands for + // a dilated conv. auto stb_paddings = stb_op.paddings(); - ElementsAttr stb_paddings_attr; - if (matchPattern(stb_paddings, m_Constant(&stb_paddings_attr))) { - if (llvm::any_of(stb_paddings_attr.getValues(), - [](IntegerAttr attr) { return attr.getInt() != 0; })) { - op.setAttr("padding", rewriter.getStringAttr("SAME")); + auto bts_crops = bts_op.crops(); + ElementsAttr stb_paddings_attr, bts_crops_attr; + if (matchPattern(stb_paddings, m_Constant(&stb_paddings_attr)) && + matchPattern(bts_crops, m_Constant(&bts_crops_attr))) { + if (stb_paddings_attr.getNumElements() != bts_crops_attr.getNumElements()) + return failure(); + // padding - crop. + auto paddings = stb_paddings_attr.getValues(); + auto crops = bts_crops_attr.getValues(); + for (auto it1 = paddings.begin(), it2 = crops.begin(); + it1 != paddings.end() && it2 != crops.end(); it1++, it2++) { + if ((*it1).getInt() != (*it2).getInt()) { + op.setAttr("padding", rewriter.getStringAttr("SAME")); + break; + } } } diff --git a/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc b/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc index 40b9c54450e..51b14d2013b 100644 --- a/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc +++ b/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc @@ -679,8 +679,8 @@ LogicalResult ConvertOphintToStub(StringRef stub_name, return success(); } -struct ExtractOphintPass : public ModulePass { - void runOnModule() override; +struct ExtractOphintPass : public OperationPass { + void runOnOperation() override; void Verify(); private: @@ -689,8 +689,8 @@ struct ExtractOphintPass : public ModulePass { // TODO(renjieliu): Current ophint extraction does not support inputs/outputs // cross functions, we need to do that. -void ExtractOphintPass::runOnModule() { - ModuleOp module = getModule(); +void ExtractOphintPass::runOnOperation() { + ModuleOp module = getOperation(); for (auto function : module.getOps()) { // Process block by block. for (auto& bb : function.getBody()) { @@ -710,7 +710,7 @@ void ExtractOphintPass::runOnModule() { ophint_composite_ops_count = ophint_composite_ops.size(); // Convert. - OpBuilder builder(&bb); + OpBuilder builder = OpBuilder::atBlockEnd(&bb); for (const auto& kv : ophint_composite_ops) { if (failed(ConvertOphintToStub(kv.getKey(), kv.getValue(), &builder, &module))) { @@ -724,9 +724,9 @@ void ExtractOphintPass::runOnModule() { } void ExtractOphintPass::Verify() { - ModuleOp module = getModule(); + ModuleOp module = getOperation(); int ophint_func_op_count = 0; - for (FuncOp func : getModule().getOps()) { + for (FuncOp func : getOperation().getOps()) { for (const NamedAttribute attr : func.getAttrs()) { if (attr.first == kTfLiteFunctionName) { ophint_func_op_count++; diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc index 0d9630a9793..299a8774db6 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc @@ -68,8 +68,9 @@ constexpr char kUnidirectionalSequenceLstm[] = "UnidirectionalSequenceLstm"; // | // | // OutputOp1 -struct LegalizeOphintFuncOpPass : public ModulePass { - void runOnModule() override; +struct LegalizeOphintFuncOpPass + : public OperationPass { + void runOnOperation() override; }; llvm::StringMap FindCompositeFuncOps(ModuleOp module) { @@ -256,8 +257,8 @@ LogicalResult ConvertCallOps(llvm::StringMap* composite_func_ops, return success(); } -void LegalizeOphintFuncOpPass::runOnModule() { - ModuleOp module = getModule(); +void LegalizeOphintFuncOpPass::runOnOperation() { + ModuleOp module = getOperation(); // Find all composite funcs, then for every call op inside every func op // within the module, we go ahead and replace the callop with the tflite diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index 4d40eec7a1b..98501aaa803 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -745,7 +745,8 @@ void LegalizeTF::runOnFunction() { Optional([](Operation* op) { auto tfl_op = dyn_cast_or_null(op); if (!tfl_op) return false; - return succeeded(tfl_op.VerifyTflRuntimeTypes(tfl_op.getOperation())); + return succeeded(tfl_op.VerifyTflRuntimeTypes( + tfl_op.getOperation(), /*failure_on_operand_type_mismatch=*/false)); })); // Keep trying to convert. // TODO(karimnosseir): This is similar to what apply greedy patterns does. diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc index 6d7713ad505..e85a85f26cb 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc @@ -31,11 +31,11 @@ namespace { // Legalize TF While to TFL While with calls to the original functions from the // cond and body regions. -struct LegalizeWhile : public ModulePass { +struct LegalizeWhile : public OperationPass { void RunOnFunction(FuncOp func); - void runOnModule() override { - for (auto op : getModule().getOps()) RunOnFunction(op); + void runOnOperation() override { + for (auto op : getOperation().getOps()) RunOnFunction(op); } }; diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc index 6ad9a6d2267..17d0f6743a1 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc @@ -82,8 +82,8 @@ class TensorListPatternRewriter : public PatternRewriter { /// Lower TensorList ops in functions for subsequent legalization. struct LowerStaticTensorListPass - : public ModulePass { - void runOnModule() override; + : public OperationPass { + void runOnOperation() override; // Apply type and op changes within a function. LogicalResult RewriteFunction(FuncOp func, @@ -878,14 +878,14 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction( return applyFullConversion(func, target, patterns); } -void LowerStaticTensorListPass::runOnModule() { +void LowerStaticTensorListPass::runOnOperation() { // TODO(haoliang): currently we process the `main` function first, and the // remaining functions may be processed in arbitrary order. However, this will // have a potential issue when one function taking a `DT_VARIANT` is processed // before the function that produces the `DT_VARIANT`. We need to carefully // order the functions to be processed. std::vector funcs_in_module; - for (auto func : getModule().getOps()) { + for (auto func : getOperation().getOps()) { // Always place the main function to be the first in the list. if (func.getName() == "main") { funcs_in_module.insert(funcs_in_module.begin(), func); diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc b/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc index cf12e036360..302194e1293 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc @@ -36,8 +36,8 @@ using FuncSet = llvm::SmallSet; // Module pass to optimize TensorFlow functional ops. struct OptimizeFunctionalOpsPass - : public ModulePass { - void runOnModule() override; + : public OperationPass { + void runOnOperation() override; }; // Updates function return type of the given functions to match the terminator @@ -180,13 +180,13 @@ static void EraseDeadFuncs(const FuncSet& candidate_funcs, ModuleOp module) { } } -void OptimizeFunctionalOpsPass::runOnModule() { +void OptimizeFunctionalOpsPass::runOnOperation() { OwningRewritePatternList patterns; FuncSet inlined_funcs; patterns.insert(&getContext(), &inlined_funcs); - ModuleOp module = getModule(); + ModuleOp module = getOperation(); applyPatternsGreedily(module, patterns); // Erase inlined functions that don't have any references. diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc index e8a2a2e75d8..c29e85a0f4d 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc @@ -94,14 +94,14 @@ class ConvertEmbeddedLookupFunc { // body with the corresponding fused TFLite op. The replacement need not always // be a fused op, though that is the primary use case. class PrepareCompositeFunctionsPass - : public ModulePass { + : public OperationPass { public: explicit PrepareCompositeFunctionsPass() {} private: void ConvertTFImplements(FuncOp func, StringAttr attr); void ConvertTFAPIImplements(FuncOp func, StringAttr attr, ModuleOp module); - void runOnModule() override; + void runOnOperation() override; }; void PrepareCompositeFunctionsPass::ConvertTFImplements(FuncOp func, @@ -189,8 +189,8 @@ void PrepareCompositeFunctionsPass::ConvertTFAPIImplements(FuncOp func, } } -void PrepareCompositeFunctionsPass::runOnModule() { - auto module = getModule(); +void PrepareCompositeFunctionsPass::runOnOperation() { + auto module = getOperation(); for (auto func : module.getOps()) { // We have two kinds of implements: // 1) tf._implements. diff --git a/tensorflow/compiler/mlir/lite/transforms/runtime_type_verify.cc b/tensorflow/compiler/mlir/lite/transforms/runtime_type_verify.cc index 92eb7023438..d103209ffd9 100644 --- a/tensorflow/compiler/mlir/lite/transforms/runtime_type_verify.cc +++ b/tensorflow/compiler/mlir/lite/transforms/runtime_type_verify.cc @@ -34,7 +34,9 @@ class RuntimeTypeVerifyPass : public mlir::FunctionPass { void RuntimeTypeVerifyPass::runOnFunction() { getFunction().walk([&](TflRuntimeVerifyOpInterface op) { - if (failed(op.VerifyTflRuntimeTypes(op.getOperation()))) + if (failed(op.VerifyTflRuntimeTypes( + op.getOperation(), + /*failure_on_operand_type_mismatch=*/true))) signalPassFailure(); }); } diff --git a/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc b/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc index a81f2147059..41adc21db35 100644 --- a/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc @@ -44,21 +44,22 @@ namespace { // The pass to trim functions before we legalize to TFL // dialect using the specified whitelist. -class TrimFunctionsPass : public mlir::ModulePass { +class TrimFunctionsPass + : public mlir::OperationPass { public: explicit TrimFunctionsPass() : trim_funcs_whitelist_(trim_funcs_whitelist) {} explicit TrimFunctionsPass(llvm::ArrayRef trim_funcs_whitelist) : trim_funcs_whitelist_(trim_funcs_whitelist) {} private: - void runOnModule() override; + void runOnOperation() override; bool TrimModule(); void Verify(); llvm::ArrayRef trim_funcs_whitelist_; }; -void TrimFunctionsPass::runOnModule() { +void TrimFunctionsPass::runOnOperation() { // trim the functions in the module using the trim_funcs_whitelist_ // by removing functions not in the whitelist. if (TrimModule()) { @@ -73,7 +74,7 @@ bool TrimFunctionsPass::TrimModule() { if (trim_funcs_whitelist_.empty()) return false; llvm::SmallVector funcs_to_trim; - for (auto func : getModule().getOps()) { + for (auto func : getOperation().getOps()) { if (llvm::is_contained(trim_funcs_whitelist_, func.getName())) { // If no main is specified in the whitelist, use the 1st func // in trim_funcs_whitelist as the main. @@ -102,12 +103,12 @@ bool TrimFunctionsPass::TrimModule() { void TrimFunctionsPass::Verify() { // TODO(ashwinm): Instead, we should make sure that references to all // SymbolRefAttrs of all ops are present. - SymbolTable symbol_table = SymbolTable(getModule()); + SymbolTable symbol_table = SymbolTable(getOperation()); llvm::SetVector reachable_funcs; - for (auto func : getModule().getOps()) { + for (auto func : getOperation().getOps()) { auto walk_result = func.walk([&](CallOp op) -> WalkResult { if (!symbol_table.lookup(op.getCallee())) - return getModule().emitError() + return getOperation().emitError() << func.getName() << " is not in the funcs whitelist"; return WalkResult::advance(); }); diff --git a/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc b/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc index a0675efcc6b..c2acb93fe78 100644 --- a/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc +++ b/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc @@ -37,12 +37,13 @@ namespace { // This pass outlines the cond/body region of the TFL WhileOp into functions and // replaces the regions with calls to these outlined functions. -class WhileOutlinePass : public mlir::ModulePass { +class WhileOutlinePass + : public mlir::OperationPass { public: explicit WhileOutlinePass() {} private: - void runOnModule() override; + void runOnOperation() override; // Outlines the regions of the WhileOp's cond and body and insert function // calls instead, @@ -130,7 +131,7 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) { // Create outline function from region. Optional pass extra arguments through // to yield. - SymbolTable symbol_table(getModule()); + SymbolTable symbol_table(getOperation()); auto create_outline_func = [&](StringRef name, Region& region, bool passthru_extra_args) { FunctionType type; @@ -234,8 +235,8 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) { op->erase(); } -void WhileOutlinePass::runOnModule() { - getModule().walk( +void WhileOutlinePass::runOnOperation() { + getOperation().walk( [&](mlir::TFL::WhileOp while_op) { OutlineWhile(while_op); }); } diff --git a/tensorflow/compiler/mlir/lite/utils/convert_type.cc b/tensorflow/compiler/mlir/lite/utils/convert_type.cc index fe2cdb2748d..ed6888d4874 100644 --- a/tensorflow/compiler/mlir/lite/utils/convert_type.cc +++ b/tensorflow/compiler/mlir/lite/utils/convert_type.cc @@ -32,10 +32,12 @@ namespace errors = tensorflow::errors; mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder) { switch (type) { - case tflite::TensorType_FLOAT32: - return builder.getF32Type(); case tflite::TensorType_FLOAT16: return builder.getF16Type(); + case tflite::TensorType_FLOAT32: + return builder.getF32Type(); + case tflite::TensorType_FLOAT64: + return builder.getF64Type(); case tflite::TensorType_INT32: return builder.getIntegerType(32); case tflite::TensorType_UINT8: @@ -65,6 +67,8 @@ tensorflow::DataType TflTypeToTfType(tflite::TensorType type) { return tensorflow::DT_HALF; case tflite::TensorType_FLOAT32: return tensorflow::DT_FLOAT; + case tflite::TensorType_FLOAT64: + return tensorflow::DT_DOUBLE; case tflite::TensorType_INT8: return tensorflow::DT_INT8; case tflite::TensorType_INT16: diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc index 8d670d96748..0ca4364f9cd 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc @@ -545,13 +545,44 @@ LogicalResult Verify(SwitchNOp switchn) { << "expect `num_outs` (" << num_outs.getInt() << ") results but got " << (switchn.getNumResults() - 1); + // Check that operand can be broadcasted to each output type. auto operand0_type = switchn.getOperand(0).getType(); - for (Value result : switchn.outputs()) - if (operand0_type != result.getType()) - return switchn.emitOpError() - << "type mismatch between data operand and result: " - << operand0_type << " vs " << result.getType(); + TensorType operand0_tensor_type = operand0_type.dyn_cast(); + if (!operand0_tensor_type) { + return switchn.emitOpError() + << "expects data operand to have tensor type but got " + << operand0_type; + } + for (Type output_type : switchn.getResultTypes()) { + if (output_type.isa()) break; + TensorType output_tensor_type = output_type.dyn_cast(); + if (!output_tensor_type) { + return switchn.emitOpError() + << "expects outputs to have tensor type but got " << output_type; + } + + // If the output type is a ref type, then the operand type should also be of + // the same ref type. However, if the output type is a non-ref type T, then + // the operand can be tensor of type T or T_REF. + bool is_output_ref = + output_tensor_type.getElementType().isa(); + if (is_output_ref && + !operand0_tensor_type.getElementType().isa()) { + return switchn.emitOpError() + << "expects same operand and output element type but got " + << operand0_tensor_type << " vs " << output_tensor_type; + } + Type broadcasted_type = OpTrait::util::getBroadcastedType( + DropRefType(DropTypeSubTypes(operand0_tensor_type)), + DropRefType(DropTypeSubTypes(output_tensor_type))); + if (!broadcasted_type) { + return switchn.emitOpError() + << "expects data operand to be broadcastable with all output types" + << " but got " << operand0_tensor_type << " vs " + << output_tensor_type; + } + } return success(); } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index cdeb10cf03a..7326192f418 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -5301,6 +5301,8 @@ tf.pow(x, y) ==> [[256, 65536], [9, 27]] ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + + let hasFolder = 1; } def TF_PreventGradientOp : TF_Op<"PreventGradient", [NoSideEffect, SameOperandsAndResultType]> { @@ -6006,6 +6008,30 @@ Resize `images` to `size` using nearest neighbor interpolation. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_ResourceApplyAdagradV2Op : TF_Op<"ResourceApplyAdagradV2", []> { + let summary = "Update '*var' according to the adagrad scheme."; + + let description = [{ +accum += grad * grad +var -= lr * grad * (1 / (sqrt(accum) + epsilon)) + }]; + + let arguments = (ins + TF_ResourceTensor:$var, + TF_ResourceTensor:$accum, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$epsilon, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad, + + DefaultValuedAttr:$use_locking, + DefaultValuedAttr:$update_slots + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>; +} + def TF_ResourceApplyAdamOp : TF_Op<"ResourceApplyAdam", []> { let summary = "Update '*var' according to the Adam algorithm."; @@ -7711,6 +7737,28 @@ shape of `StridedSlice`'s `input`. }]; } +def TF_StringFormatOp : TF_Op<"StringFormat", [NoSideEffect]> { + let summary = "Formats a string template using a list of tensors."; + + let description = [{ +Formats a string template using a list of tensors, pretty-printing tensor summaries. + }]; + + let arguments = (ins + Variadic:$inputs, + + DefaultValuedAttr:$strtemplate, + DefaultValuedAttr:$placeholder, + DefaultValuedAttr:$summarize + ); + + let results = (outs + TF_StrTensor:$output + ); + + TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>; +} + def TF_SubOp : TF_Op<"Sub", [NoSideEffect, ResultsBroadcastableShape]>, WithBroadcastableBinOpBuilder { let summary = "Returns x - y element-wise."; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index 842520927e2..92590af2aea 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -2153,6 +2153,27 @@ static LogicalResult VerifyPartitionedCall(OpClass op) { return success(); } +//===----------------------------------------------------------------------===// +// PowOp +//===----------------------------------------------------------------------===// + +OpFoldResult PowOp::fold(ArrayRef operands) { + auto constant_y = operands[1].dyn_cast_or_null(); + if (constant_y && constant_y.isSplat()) { + APFloat y_value = constant_y.getSplatValue(); + auto output_type = getType().cast(); + if (y_value.isZero() && output_type.hasStaticShape()) { + return DenseElementsAttr::get( + output_type, + FloatAttr::get(output_type.getElementType(), /*value=*/1.0)); + } + if (y_value.isExactlyValue(1.0)) { + return x(); + } + } + return {}; +} + //===----------------------------------------------------------------------===// // ReciprocalOp //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir index 411599053e5..2a34bbfacdc 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir @@ -18,6 +18,26 @@ func @testShape(tensor, tensor<1x32x32x16xf32>, tensor<*xf32>) -> (tensor<0 return %0, %1, %2 : tensor<0xi32>, tensor, tensor } +// CHECK-LABEL: func @testPow +// CHECK-SAME:(%[[ARG_0:.*]]: tensor<4xf32>, %[[ARG_1:.*]]: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) +func @testPow(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) { + + %cst_zero = constant dense<0.0> : tensor + %cst_one = constant dense<1.0> : tensor + + // CHECK-DAG: %[[RES_NO_FOLD:.*]] = "tf.Pow"(%arg0, %arg1) + %0 = "tf.Pow"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + + // CHECK-DAG: %[[POW_ZERO:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32> + %1 = "tf.Pow"(%arg0, %cst_zero) : (tensor<4xf32>, tensor) -> tensor<4xf32> + + // CHECK-NOT: "tf.Pow" + %2 = "tf.Pow"(%arg0, %cst_one) : (tensor<4xf32>, tensor) -> tensor<4xf32> + + // CHECK: return %[[RES_NO_FOLD]], %[[POW_ZERO]], %[[ARG_0]] + return %0, %1, %2 : tensor<4xf32>, tensor<4xf32>, tensor<4xf32> +} + // CHECK-LABEL: func @testShapeN func @testShapeN(%arg0: tensor, %arg1: tensor<1x32x32x16xf32>, %arg2: tensor<*xf32>) -> (tensor<0xi64>, tensor<4xi64>, tensor<4xi64>, tensor) { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir index d3178be9b1e..08eb773a54b 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir @@ -147,6 +147,37 @@ func @decompose_resource_apply_keras_momentum_nesterov(%arg0: tensor, %arg1 // ----- + +// Tests that composite tf.ResourceApplyAdagradV2 operation is decomposed. + +// CHECK-LABEL: func @decompose_resource_apply_adagradv2 +// CHECK-SAME: ([[LR:%.*]]: tensor, [[EPSILON:%.*]]: tensor, [[GRAD:%.*]]: tensor) +func @decompose_resource_apply_adagradv2(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> () { + +// CHECK: [[VAR_HANDLE:%.*]] = "tf.VarHandleOp"() +// CHECK: [[ACC_HANDLE:%.*]] = "tf.VarHandleOp"() +// CHECK: [[GRAD_SQUARE:%.*]] = "tf.Mul"([[GRAD]], [[GRAD]]) : (tensor, tensor) -> tensor +// CHECK: [[OLD_ACC:%.*]] = "tf.ReadVariableOp"([[ACC_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32> +// CHECK: [[NEW_ACC:%.*]] = "tf.AddV2"([[OLD_ACC]], [[GRAD_SQUARE]]) : (tensor<*xf32>, tensor) -> tensor<*xf32> +// CHECK: [[LR_MULTIPLY:%.*]] = "tf.Mul"([[LR]], [[GRAD]]) : (tensor, tensor) -> tensor +// CHECK: [[SQRT:%.*]] = "tf.Sqrt"([[NEW_ACC]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: [[DIVISOR:%.*]] = "tf.AddV2"([[SQRT]], [[EPSILON]]) : (tensor<*xf32>, tensor) -> tensor<*xf32> +// CHECK: [[VAR_DELTA:%.*]] = "tf.Div"([[LR_MULTIPLY]], [[DIVISOR]]) : (tensor, tensor<*xf32>) -> tensor<*xf32> +// CHECK: [[OLD_VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32> +// CHECK: [[NEW_VAR:%.*]] = "tf.Sub"(%9, %8) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> +// CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[NEW_VAR]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> () +// CHECK: "tf.AssignVariableOp"([[ACC_HANDLE]], [[NEW_ACC]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> () + + %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> + %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> + + "tf.ResourceApplyAdagradV2"(%0, %1, %arg0, %arg1, %arg2) {update_slots = true, use_locking = true} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor, tensor, tensor) -> () + + return +} + +// ----- + // Tests that composite tf.ResourceApplyAdam (non-Nesterov) operation is // decomposed. diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir index 99b8823f2bb..5c8041e0436 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir @@ -248,6 +248,40 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { return %0 : tensor } + // Check that supported tf_executor ops can receive data from ops on which + // shape inference has inferred the result types, without throwing any errors. + // CHECK-LABEL: func @supported_tf_executor_users + func @supported_tf_executor_users(%arg0: tensor<32x?x256x4xf32>, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> tensor { + %0 = tf_executor.graph { + %island:3 = tf_executor.island { + %dims = "tf.Const"() {value = dense<[32, -1, 4]> : tensor<3xi32>} : () -> tensor<3xi32> + %reshape = "tf.Reshape"(%arg0, %dims) : (tensor<32x?x256x4xf32>, tensor<3xi32>) -> tensor + %cast = "tf.Cast"(%arg2) : (tensor) -> tensor<*xi1> + tf_executor.yield %reshape, %cast : tensor, tensor<*xi1> + } + // CHECK: tf_executor.Merge + // CHECK-SAME: : (tensor<32x?x4xf32>, tensor) -> + // CHECK: tf_executor.Switch + // CHECK-SAME: : (tensor<32x?x4xf32>, tensor) -> + // CHECK: tf_executor.SwitchN + // CHECK-SAME: : tensor + // CHECK: tf_executor.Enter + // CHECK-SAME: : (tensor<32x?x4xf32>) -> + // CHECK: tf_executor.Exit + // CHECK-SAME: : tensor + // CHECK: tf_executor.LoopCond + // CHECK-SAME: : tensor<*xi1> + %merge:3 = "tf_executor.Merge"(%island#0, %arg1) : (tensor, tensor) -> (tensor, tensor, !tf_executor.control) + %switch:3 = "tf_executor.Switch"(%island#0, %arg2) : (tensor, tensor) -> (tensor, tensor, !tf_executor.control) + %switchn:3 = "tf_executor.SwitchN"(%island#0, %arg3) {num_outs = 2} : (tensor, tensor) -> (tensor, tensor, !tf_executor.control) + %enter:2 = "tf_executor.Enter"(%island#0) { frame_name = "frame"} : (tensor) -> (tensor, !tf_executor.control) + %exit:2 = "tf_executor.Exit"(%island#0) : (tensor) -> (tensor, !tf_executor.control) + %loop_cond:2 = "tf_executor.LoopCond" (%island#1) : (tensor<*xi1>) -> (tensor<*xi1>, !tf_executor.control) + tf_executor.fetch %enter#0 : tensor + } + return %0 : tensor + } + // CHECK-LABEL: func @fold_cast func @fold_cast(%arg0: tensor<*xf32>) -> tensor<*xf32> { // CHECK-NOT: Cast diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir index b9ec020ff59..db9db1518d7 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir @@ -7,7 +7,7 @@ func @invalid_type() -> !tf_executor.foobar // Check that tf_executor.graph does not accept any operand. func @graph_with_invalid_op(%arg0: tensor<*xf32>) { - "tf_executor.graph" (%arg0) : (tensor<*xf32>) -> () + "tf_executor.graph" (%arg0) ({}) : (tensor<*xf32>) -> () // expected-error@-1 {{'tf_executor.graph' op requires zero operands}} return } @@ -405,12 +405,49 @@ func @invalid_switchN(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<*xf32> // ----- -// Check that switchN result type matches the input type. -func @invalid_switchN(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<*xf32> { +// Check that data operands of SwitchN have tensor type +func @invalid_switchN(%arg0: i32, %arg1: tensor) -> tensor<*xi32> { + %result = tf_executor.graph { + %1:3 = "tf_executor.SwitchN"(%arg0, %arg1) {num_outs = 2} : (i32, tensor) -> (tensor<*xi32>, tensor, !tf_executor.control) +// expected-error@-1 {{'tf_executor.SwitchN' op expects data operand to have tensor type but got 'i32'}} + tf_executor.fetch %1#0 : tensor<*xi32> + } + return %result : tensor<*xi32> +} + +// ----- + +// Check that result of SwitchN has tensor type +func @invalid_switchN(%arg0: tensor<*xi32>, %arg1: tensor) -> i32 { + %result = tf_executor.graph { + %1:3 = "tf_executor.SwitchN"(%arg0, %arg1) {num_outs = 2} : (tensor<*xi32>, tensor) -> (i32, tensor, !tf_executor.control) +// expected-error@-1 {{'tf_executor.SwitchN' op expects outputs to have tensor type but got 'i32'}} + tf_executor.fetch %1#0 : i32 + } + return %result : i32 +} + +// ----- + +// Check that if any result is a ref type, then data operand needs to be ref too. +func @invalid_switchN(%arg0: tensor<4xf32>, %arg1: tensor) -> tensor<4x!tf.f32ref> { %fetches = tf_executor.graph { - %1:3 = "tf_executor.SwitchN"(%arg1, %arg0) {num_outs = 2} : (tensor<*xf32>, tensor) -> (tensor<*xf32>, i32, !tf_executor.control) -// expected-error@-1 {{'tf_executor.SwitchN' op type mismatch between data operand and result: 'tensor<*xf32>' vs 'i32'}} + %1:3 = "tf_executor.SwitchN"(%arg0, %arg1) {num_outs = 2} : (tensor<4xf32>, tensor) -> (tensor<4x!tf.f32ref>, tensor<4xf32>, !tf_executor.control) +// expected-error@-1 {{'tf_executor.SwitchN' op expects same operand and output element type but got 'tensor<4xf32>' vs 'tensor<4x!tf.f32ref>'}} + tf_executor.fetch %1#0 : tensor<4x!tf.f32ref> + } + return %fetches : tensor<4x!tf.f32ref> +} + +// ----- + +// Check that switchN data operand is broadcastable with all output types +func @invalid_switchN(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor<*xf32> { + %fetches = tf_executor.graph { + + %1:3 = "tf_executor.SwitchN"(%arg0, %arg1) {num_outs = 2} : (tensor<*xf32>, tensor) -> (tensor<*xf32>, tensor, !tf_executor.control) +// expected-error@-1 {{'tf_executor.SwitchN' op expects data operand to be broadcastable with all output types but got 'tensor<*xf32>' vs 'tensor'}} tf_executor.fetch %1#0 : tensor<*xf32> } @@ -472,6 +509,30 @@ func @invalid_merge(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor<*xf32> { // ----- +// Check that data operands of merge have tensor type +func @invalid_merge(%arg0: tensor<*xi32>, %arg1: i32) -> tensor<*xi32> { + %result = tf_executor.graph { + %value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<*xi32>, i32) -> (tensor<*xi32>, tensor, !tf_executor.control) +// expected-error@-1 {{'tf_executor.Merge' op expects data operands to have tensor type but got 'i32'}} + tf_executor.fetch %value : tensor<*xi32> + } + return %result : tensor<*xi32> +} + +// ----- + +// Check that result of merge has tensor type +func @invalid_merge(%arg0: tensor<*xi32>, %arg1: tensor) -> i32 { + %result = tf_executor.graph { + %value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<*xi32>, tensor) -> (i32, tensor, !tf_executor.control) +// expected-error@-1 {{'tf_executor.Merge' op result #0 must be tensor of any type values, but got 'i32'}} + tf_executor.fetch %value : i32 + } + return %result : i32 +} + +// ----- + // Check that merge data inputs are all the same type func @invalid_merge(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor<*xf32> { %result = tf_executor.graph { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/multi_arguments_results_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/multi_arguments_results_v1.py new file mode 100644 index 00000000000..8778f0048da --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/multi_arguments_results_v1.py @@ -0,0 +1,92 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# RUN: %p/multi_arguments_results_v1 | FileCheck -dump-input-on-failure %s + +# pylint: disable=missing-docstring,line-too-long +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow.compat.v1 as tf +from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1 +from tensorflow.python.ops import array_ops + +# Tests multiple inputs and outputs with index paths. + +# CHECK-LABEL: func @key( +# CHECK-SAME: %[[ARG0:.*]]: tensor<3x5xf32> {tf_saved_model.index_path = ["y"]} +# CHECK-SAME: %[[ARG1:.*]]: tensor<5x3xf32> {tf_saved_model.index_path = ["x"]} +# CHECK-SAME: tensor<3x3xf32> {tf_saved_model.index_path = ["t"]} +# CHECK-SAME: tensor<5x5xf32> {tf_saved_model.index_path = ["s"]} +# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["key"] +# CHECK-DAG: %[[MUL0:.*]] = "tf.MatMul"(%[[ARG1]], %[[ARG0]]) +# CHECK-DAG: %[[MUL1:.*]] = "tf.MatMul"(%[[ARG0]], %[[ARG1]]) +# CHECK: %[[IDENTITY:.*]]:2 = "tf.IdentityN"(%[[MUL1]], %[[MUL0]]) +# CHECK: return %[[IDENTITY]]#0, %[[IDENTITY]]#1 + +# CHECK-LABEL: func @key2( +# CHECK-SAME: %[[ARG1:.*]]: tensor<5x3xf32> {tf_saved_model.index_path = ["b"]} +# CHECK-SAME: %[[ARG0:.*]]: tensor<3x5xf32> {tf_saved_model.index_path = ["a"]} +# CHECK-SAME: tensor<5x5xf32> {tf_saved_model.index_path = ["d"]} +# CHECK-SAME: tensor<3x3xf32> {tf_saved_model.index_path = ["c"]} +# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["key2"] +# CHECK-DAG: %[[MUL1:.*]] = "tf.MatMul"(%[[ARG0]], %[[ARG1]]) +# CHECK-DAG: %[[MUL2:.*]] = "tf.MatMul"(%[[ARG1]], %[[ARG0]]) +# CHECK: %[[IDENTITY:.*]]:2 = "tf.IdentityN"(%[[MUL1]], %[[MUL2]]) +# CHECK: return %[[IDENTITY]]#1, %[[IDENTITY]]#0 + + +def Test(): + + x = tf.constant(1.0, shape=(5, 3)) + y = tf.constant(1.0, shape=(3, 5)) + + s = tf.matmul(x, y) + t = tf.matmul(y, x) + [t, s] = array_ops.identity_n([t, s]) + + tensor_info_x = tf.compat.v1.saved_model.utils.build_tensor_info(x) + tensor_info_y = tf.compat.v1.saved_model.utils.build_tensor_info(y) + tensor_info_s = tf.compat.v1.saved_model.utils.build_tensor_info(s) + tensor_info_t = tf.compat.v1.saved_model.utils.build_tensor_info(t) + + return { + 'key': (tf.compat.v1.saved_model.signature_def_utils.build_signature_def( + inputs={ + 'x': tensor_info_x, + 'y': tensor_info_y + }, + outputs={ + 's': tensor_info_s, + 't': tensor_info_t + }, + method_name='some_function')), + 'key2': (tf.compat.v1.saved_model.signature_def_utils.build_signature_def( + inputs={ + 'a': tensor_info_y, + 'b': tensor_info_x, + }, + outputs={ + 'c': tensor_info_t, + 'd': tensor_info_s, + }, + method_name='reverse_arguments')) + } + + +if __name__ == '__main__': + common_v1.set_tf_options() + common_v1.do_test(Test()) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/multi_arguments_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/multi_arguments_v1.py deleted file mode 100644 index 107c7a4aad7..00000000000 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/multi_arguments_v1.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -# RUN: %p/multi_arguments_v1 | FileCheck %s - -# pylint: disable=missing-docstring,line-too-long -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow.compat.v1 as tf -from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1 - -# Tests multiple inputs with index paths. -# CHECK: func {{@[a-zA-Z_0-9]+}}( -# CHECK-SAME: [[ARG0:%.*]]: tensor<5x3xf32> {tf_saved_model.index_path = ["x"]}, -# CHECK-SAME: [[ARG1:%.*]]: tensor<3x5xf32> {tf_saved_model.index_path = ["y"]}) -# CHECK-SAME: -> (tensor<5x5xf32> {tf_saved_model.index_path = ["s"]}, -# CHECK-SAME: tensor<3x3xf32> {tf_saved_model.index_path = ["t"]}) -# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["key"] - - -def Test(): - - x = tf.constant(1.0, shape=(5, 3)) - y = tf.constant(1.0, shape=(3, 5)) - - s = tf.matmul(x, y) - t = tf.matmul(y, x) - - tensor_info_x = tf.compat.v1.saved_model.utils.build_tensor_info(x) - tensor_info_y = tf.compat.v1.saved_model.utils.build_tensor_info(y) - tensor_info_s = tf.compat.v1.saved_model.utils.build_tensor_info(s) - tensor_info_t = tf.compat.v1.saved_model.utils.build_tensor_info(t) - - return { - 'key': (tf.compat.v1.saved_model.signature_def_utils.build_signature_def( - inputs={ - 'x': tensor_info_x, - 'y': tensor_info_y - }, - outputs={ - 's': tensor_info_s, - 't': tensor_info_t - }, - method_name='some_function')) - } - - -if __name__ == '__main__': - common_v1.set_tf_options() - common_v1.do_test(Test()) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc b/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc index d96f4c18a10..58b4901b548 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc @@ -39,8 +39,8 @@ constexpr char kMirroredVariableIndicesAttr[] = "_mirrored_variable_indices"; // Analyzes the inputs to LaunchFuncOps in the module, and annotates their // invoked functions whether each input has the same data across replicas. struct AnnotateParameterReplication - : public ModulePass { - void runOnModule() override; + : public OperationPass { + void runOnOperation() override; }; // Returns the first value in the chain of operands, which is not defined by a @@ -53,8 +53,8 @@ Value SkipIdentityAndReadVariable(Value v) { return v; } -void AnnotateParameterReplication::runOnModule() { - ModuleOp m = getModule(); +void AnnotateParameterReplication::runOnOperation() { + ModuleOp m = getOperation(); OpBuilder builder(m.getContext()); m.walk([&](tf_device::LaunchFuncOp launch_func) { auto replicate = launch_func.getParentOfType(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc index 1c9ace21efb..aee6e72e7d6 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc @@ -38,8 +38,9 @@ namespace { constexpr char kDeviceAttr[] = "device"; constexpr char kFuncAttr[] = "func"; -struct ClusterOutliningPass : public ModulePass { - void runOnModule() override; +struct ClusterOutliningPass + : public OperationPass { + void runOnOperation() override; }; void ReplaceLaunchReturnWithReturn(tf_device::ReturnOp launch_return_op, @@ -120,8 +121,8 @@ void OutlineLaunch(tf_device::LaunchOp launch_op, SymbolTable* symbol_table, launch_op.erase(); } -void ClusterOutliningPass::runOnModule() { - ModuleOp m = getModule(); +void ClusterOutliningPass::runOnOperation() { + ModuleOp m = getOperation(); SymbolTable symbol_table(m); OpBuilder builder(m.getContext()); m.walk([&](tf_device::LaunchOp launch) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td index bac7b9ba01c..e26f8b6a7bb 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td @@ -22,7 +22,7 @@ class GetScalarOfType : NativeCodeCall< "GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">; // Creates a tf.ReadVariable op that reads a resource `$2` that has the same -// element type as `$1`. The op created will use location of `$1`. +// element type as `$1`. The op created will use location of `$0`. def CreateTFReadVariableOp: NativeCodeCall< "$_builder.create(" " $0.getLoc()," @@ -118,6 +118,32 @@ def DecomposeResourceApplyKerasMomentumOpNesterov : ] >; +// Pattern to Decompose ResourceApplyAdagrad. +// This decomposition is only correct inside XLA as it ignores use_locking +// attribute. +// accum <- accum + grad * grad +// variable <- variable - lr * grad / (sqrt(accum) + epsilon) +def DecomposeResourceApplyAdagradV2 : + Pattern< + (TF_ResourceApplyAdagradV2Op:$src_op + $var_resource, $accum_resource, $lr, $epsilon, $grad, BoolAttr:$_, + ConstBoolAttrTrue:$update_slots), + [ + (TF_AddV2Op:$new_accum + (CreateTFReadVariableOp $src_op, $grad, $accum_resource), + (TF_MulOp $grad, $grad) + ), + (TF_AssignSubVariableOp + $var_resource, + (TF_DivOp + (TF_MulOp $lr, $grad), + (TF_AddV2Op (TF_SqrtOp $new_accum), $epsilon) + ) + ), + (TF_AssignVariableOp $accum_resource, $new_accum), + ] + >; + // Pattern to Decompose ResourceApplyAdam without Nesterov momentum. // This decomposition is only correct inside XLA as it ignores use_locking // attribute. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc index eb2aa16e25f..f7569917b41 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc @@ -303,7 +303,8 @@ void InsertDummyIslandForFetch(FetchOp fetch) { /*control=*/ControlType::get(fetch.getContext()), /*controlInputs=*/control_fetches); island.body().push_back(new Block); - OpBuilder(&island.GetBody()).create(fetch.getLoc(), data_fetches); + OpBuilder::atBlockEnd(&island.GetBody()) + .create(fetch.getLoc(), data_fetches); const int fetch_control_idx = data_fetches.size(); for (int i = 0, e = fetch.getNumOperands(); i < e; i++) { // The fetch could have multiple control operands (all at the end of its diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc index 85d9d994b30..71e5d291292 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc @@ -43,17 +43,17 @@ constexpr llvm::StringRef kNestedModule = "_tpu_v1_compat_outlined"; // Inlining the islands calling into the nested module that was outlined. // This is the end of the TPU bridge in V1 compatibility mode. struct TPUBridgeExecutorIslandInlining - : public ModulePass { - void runOnModule() override; + : public OperationPass { + void runOnOperation() override; }; -void TPUBridgeExecutorIslandInlining::runOnModule() { - SymbolTable symbol_table(getModule()); +void TPUBridgeExecutorIslandInlining::runOnOperation() { + SymbolTable symbol_table(getOperation()); Operation *nested_module = symbol_table.lookup(kNestedModule); if (!nested_module) return; InlinerInterface inliner(&getContext()); - auto walk_result = getModule().walk([&](TF::PartitionedCallOp call_op) { + auto walk_result = getOperation().walk([&](TF::PartitionedCallOp call_op) { if (!call_op.f().getRootReference().startswith(kNestedModule)) return WalkResult::advance(); // This is a call we need to inline! @@ -61,7 +61,7 @@ void TPUBridgeExecutorIslandInlining::runOnModule() { << "Found call to inline: " << *call_op.getOperation() << "\n"); FuncOp called_func = dyn_cast_or_null( - symbol_table.lookupSymbolIn(getModule(), call_op.f())); + symbol_table.lookupSymbolIn(getOperation(), call_op.f())); if (failed(inlineCall(inliner, cast(call_op.getOperation()), @@ -80,7 +80,7 @@ void TPUBridgeExecutorIslandInlining::runOnModule() { Block &nested_block = nested_module->getRegion(0).front(); for (FuncOp func_op : llvm::make_early_inc_range(nested_block.getOps())) { - if (!symbol_table.lookupSymbolIn(getModule(), func_op.getName())) { + if (!symbol_table.lookupSymbolIn(getOperation(), func_op.getName())) { nested_block.getOperations().remove(func_op.getOperation()); symbol_table.insert(func_op.getOperation()); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_island_coarsening.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_island_coarsening.cc index 54782116094..452ac076ac9 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_island_coarsening.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_island_coarsening.cc @@ -59,8 +59,8 @@ constexpr llvm::StringRef kTpuStatusAttr = "_tpu_compilation_status"; // TPU-annotated operations and intended to preserve backward compatibility with // TFv1. struct TpuV1BridgeExecutorIslandCoarsening - : public ModulePass { - void runOnModule() override; + : public OperationPass { + void runOnOperation() override; }; // Sort the Operations in the provided range to enforce dominance. @@ -226,7 +226,8 @@ LogicalResult MergeIsland(llvm::function_ref yield_operands.push_back(std::get<1>(result)); } } - OpBuilder(&island_body).create(new_island.getLoc(), yield_operands); + OpBuilder::atBlockEnd(&island_body) + .create(new_island.getLoc(), yield_operands); // remap results of the new islands to the user outside of the island. int current_result = 0; @@ -257,13 +258,13 @@ LogicalResult MergeIsland(llvm::function_ref first_op_after); } -void TpuV1BridgeExecutorIslandCoarsening::runOnModule() { - SymbolTable symbol_table(getModule()); +void TpuV1BridgeExecutorIslandCoarsening::runOnOperation() { + SymbolTable symbol_table(getOperation()); // Map tpu cluster names to the functions that contain operations for this // cluster. DenseMap> tpu_funcs; - for (FuncOp func_op : getModule().getOps()) { + for (FuncOp func_op : getOperation().getOps()) { func_op.walk([&](Operation* op) { StringAttr cluster_name = op->getAttrOfType(kTpuReplicateAttr); @@ -291,7 +292,7 @@ void TpuV1BridgeExecutorIslandCoarsening::runOnModule() { return false; }; - for (FuncOp func_op : getModule().getOps()) { + for (FuncOp func_op : getOperation().getOps()) { func_op.walk([&](GraphOp graph) { Block& graph_body = graph.GetBody(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc index b25cc23aac8..db13d6b3875 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc @@ -44,20 +44,20 @@ constexpr llvm::StringRef kOutlinedFuncPrefix = "_tpu_v1_compat_outlined_func"; // This is only intended for V1 compatibility mode where the bridge runs without // feed/fetches on session create/extend. struct TPUBridgeExecutorIslandOutlining - : public ModulePass { - void runOnModule() override; + : public OperationPass { + void runOnOperation() override; }; -void TPUBridgeExecutorIslandOutlining::runOnModule() { +void TPUBridgeExecutorIslandOutlining::runOnOperation() { MLIRContext *ctx = &getContext(); - SymbolTable symbol_table(getModule()); + SymbolTable symbol_table(getOperation()); if (Operation *nested_module = symbol_table.lookup(kNestedModule)) { nested_module->emitOpError("unexpected already present outlined module."); return signalPassFailure(); } - ModuleOp outlined_module = ModuleOp::create(getModule().getLoc()); - outlined_module.setAttrs(getModule().getAttrs()); + ModuleOp outlined_module = ModuleOp::create(getOperation().getLoc()); + outlined_module.setAttrs(getOperation().getAttrs()); outlined_module.setAttr(SymbolTable::getSymbolAttrName(), StringAttr::get(kNestedModule, ctx)); symbol_table.insert(outlined_module); @@ -66,7 +66,7 @@ void TPUBridgeExecutorIslandOutlining::runOnModule() { // Find every island that contains a TPUReplicateMetadata node and extract it // in a new module to run the V1 bridge there. SmallVector islands_to_outline; - getModule().walk([&](TF::TPUReplicateMetadataOp replicate_op) { + getOperation().walk([&](TF::TPUReplicateMetadataOp replicate_op) { auto island_op = cast(replicate_op.getParentOp()); if (!island_op || island_op.WrapsSingleOp()) return; islands_to_outline.push_back(island_op); @@ -123,7 +123,7 @@ void TPUBridgeExecutorIslandOutlining::runOnModule() { // The function is in place in the nested module, create a call and yield in // the original island. - OpBuilder builder(&island_op.GetBody()); + OpBuilder builder = OpBuilder::atBlockEnd(&island_op.GetBody()); auto call_op = builder.create( island_op.getLoc(), func_result_types, operands.getArrayRef(), builder.getSymbolRefAttr( diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc b/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc index 30444b88677..ad404182658 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc @@ -202,7 +202,7 @@ static void MatchSwitchFoldOps(tf_executor::SwitchOp switch_op, static LogicalResult FoldMergeNodes(FuncOp function, const DeadQueue& queue) { // Create builder for val_index of MergeOp. auto* block = &function.getBlocks().front(); - OpBuilder builder(block); + OpBuilder builder = OpBuilder::atBlockEnd(block); auto type = builder.getIntegerType(32); auto build_index = [&](Location loc, int value) { return builder.create(loc, type, diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc b/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc index 9ae3ffdaa7d..088080c603b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc @@ -41,12 +41,13 @@ namespace { // the IR is in correct form for inference backends (like lite) that do not // support resources/variables . Further, this contract also ensures that this // pass lowers from saved model to pure TF. Hence it fails, if it cannot lower. -struct FreezeGlobalTensorsPass : public ModulePass { - void runOnModule() override; +struct FreezeGlobalTensorsPass + : public OperationPass { + void runOnOperation() override; }; -void FreezeGlobalTensorsPass::runOnModule() { - auto module = getModule(); +void FreezeGlobalTensorsPass::runOnOperation() { + auto module = getOperation(); SymbolTable symbol_table(module); DenseSet frozen_global_tensors; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc b/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc index 237f08c6c41..0b03c522596 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc @@ -126,7 +126,7 @@ void LayoutAssignmentPass::runOnFunction() { mlir::Operation* op = layout_sensitive_interface.getOperation(); Location loc = op->getLoc(); - OpBuilder builder(op->getBlock()); + OpBuilder builder = OpBuilder::atBlockEnd(op->getBlock()); auto perm_attr = [&](Permutation permutation) -> DenseIntElementsAttr { auto perm_ty = RankedTensorType::get({4}, builder.getIntegerType(32)); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/mark_function_visibility.cc b/tensorflow/compiler/mlir/tensorflow/transforms/mark_function_visibility.cc index dd884fd09fd..a42e7ea8f71 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/mark_function_visibility.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/mark_function_visibility.cc @@ -74,11 +74,11 @@ LogicalResult MarkFunctionVisibilityUsingEntryFunctionSpecification( namespace { struct MarkFunctionVisibilityUsingEntryFunctionSpecificationPass - : public ModulePass< - MarkFunctionVisibilityUsingEntryFunctionSpecificationPass> { - void runOnModule() override { + : public OperationPass< + MarkFunctionVisibilityUsingEntryFunctionSpecificationPass, ModuleOp> { + void runOnOperation() override { if (failed(MarkFunctionVisibilityUsingEntryFunctionSpecification( - getModule()))) { + getOperation()))) { signalPassFailure(); } } @@ -110,9 +110,10 @@ static LogicalResult MarkFunctionVisibilityUsingSavedModelLinkage( namespace { struct MarkFunctionVisibilityUsingSavedModelLinkagePass - : public ModulePass { - void runOnModule() override { - if (failed(MarkFunctionVisibilityUsingSavedModelLinkage(getModule()))) { + : public OperationPass { + void runOnOperation() override { + if (failed(MarkFunctionVisibilityUsingSavedModelLinkage(getOperation()))) { signalPassFailure(); } } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc index 713bcff1a71..74b9df3fe9f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc @@ -41,8 +41,8 @@ namespace mlir { namespace tf_saved_model { namespace { struct OptimizeGlobalTensorsPass - : public ModulePass { - void runOnModule() override; + : public OperationPass { + void runOnOperation() override; }; // A global tensor is bound to arguments of multiple funcs. @@ -276,8 +276,8 @@ void EraseUnusedBoundInputs(ModuleOp module) { } } -void OptimizeGlobalTensorsPass::runOnModule() { - auto module = getModule(); +void OptimizeGlobalTensorsPass::runOnOperation() { + auto module = getOperation(); EraseUnusedBoundInputs(module); ResourceAnalyzer resource_analyzer(module); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc index 4d83c647f40..61644866886 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc @@ -258,13 +258,13 @@ LogicalResult PromoteResourcesToArguments(FuncOp function) { } class PromoteResourcesToArgsPass - : public ModulePass { + : public OperationPass { public: - void runOnModule() override; + void runOnOperation() override; }; -void PromoteResourcesToArgsPass::runOnModule() { - ModuleOp module = getModule(); +void PromoteResourcesToArgsPass::runOnOperation() { + ModuleOp module = getOperation(); FuncOp main_func = module.lookupSymbol("main"); if (!main_func) return; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc index d0abb8d844f..2ae62bfee10 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc @@ -53,8 +53,9 @@ constexpr char kFuncDeviceAttr[] = "tf.device"; // // This pass changes the module by adding "tf.device" attribute to function // arguments and adding "device" attribute to TF ops. -struct ResourceDeviceInference : public ModulePass { - void runOnModule() override; +struct ResourceDeviceInference + : public OperationPass { + void runOnOperation() override; }; // A class that records each resource's device assignment in a function. @@ -190,8 +191,8 @@ LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op, return failure(walk_res.wasInterrupted()); } -void ResourceDeviceInference::runOnModule() { - auto module = getModule(); +void ResourceDeviceInference::runOnOperation() { + auto module = getOperation(); llvm::SmallDenseMap per_function_results; llvm::SetVector worklist; module.walk([&](FuncOp func_op) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc index ed380c7b8bc..420367f72b8 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc @@ -131,8 +131,9 @@ namespace { // return %arg0 // } // -struct ResourceOpLiftingPass : public ModulePass { - void runOnModule() override; +struct ResourceOpLiftingPass + : public OperationPass { + void runOnOperation() override; }; // Removes identity nodes in the block. The device computation does not need @@ -1050,13 +1051,13 @@ LogicalResult HoistForFunctionalControlFlow( // Lifts resource operation from tf_device.launch_func ops nested in `op` // outside. Returns failure if there are remaining resource-type values that can // not be lifted. -void ResourceOpLiftingPass::runOnModule() { +void ResourceOpLiftingPass::runOnOperation() { llvm::SmallDenseMap lifted_partitioned_call_callees; - auto result = getModule().walk([&](FuncOp func_op) { + auto result = getOperation().walk([&](FuncOp func_op) { return func_op.walk([&](tf_device::LaunchOp launch_op) { if (failed(HoistForFunctionalControlFlow( - &launch_op.GetBody(), getModule(), + &launch_op.GetBody(), getOperation(), &lifted_partitioned_call_callees)) || failed(HoistResourceOpsFromLaunchOp(launch_op))) { return WalkResult::interrupt(); @@ -1070,12 +1071,12 @@ void ResourceOpLiftingPass::runOnModule() { } struct ResourceOpLiftingForMainFunctionPass - : public ModulePass { - void runOnModule() override; + : public OperationPass { + void runOnOperation() override; }; -void ResourceOpLiftingForMainFunctionPass::runOnModule() { - ModuleOp module = getModule(); +void ResourceOpLiftingForMainFunctionPass::runOnOperation() { + ModuleOp module = getOperation(); FuncOp main_func = module.lookupSymbol("main"); if (!main_func) { return; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index e01055916ce..d3a6adbbce6 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -111,7 +111,9 @@ bool IsSupportedNonTFOp(Operation* op) { return isa(op) || isa(op) || isa(op) || isa(op) || isa(op) || isa(op) || - isa(op); + isa(op) || isa(op) || + isa(op) || isa(op) || + isa(op) || isa(op); } // Inserts tf.Cast operation when changing the type of a result if the user is @@ -224,7 +226,8 @@ GetSubtypes(Type type) { return GetSubtypesHelper(type); } -// Makes result types match the operand types. Returns if anything is changed. +// Makes result types match the operand types (the i-th result type will +// match the i-th operand type). Returns true if anything is changed. bool PassThroughOperandTypes(OperandRange operands, ResultRange results) { bool changed = false; for (auto entry : llvm::zip(operands, results)) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc index e45504ce819..c90089ad9d5 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc @@ -47,9 +47,9 @@ namespace { // This transformation pass propagate shapes on the TensorFlow graph. // It is a ModulePass in order to be able to change function types. -struct ShapeInference : public ModulePass { - void runOnModule() override { - auto module = getModule(); +struct ShapeInference : public OperationPass { + void runOnOperation() override { + auto module = getOperation(); auto producer_or = tensorflow::GetTfGraphProducerVersion(module); if (!producer_or.ok()) { LLVM_DEBUG(llvm::dbgs() << producer_or.status().ToString();); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc index ac23ef7ce2b..6abf4893327 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc @@ -85,8 +85,8 @@ namespace cutil = TF::collection_ops_util; // // The pass also works across control flow and functional calls. struct StackOpsDecompositionPass - : public ModulePass { - void runOnModule() override; + : public OperationPass { + void runOnOperation() override; }; // Returns the type of the local variable for the stack size. @@ -551,8 +551,8 @@ LogicalResult DecomposeStackOps(Block* block, ModuleOp module) { &decomposed_partitioned_call_callees); } -void StackOpsDecompositionPass::runOnModule() { - auto module = getModule(); +void StackOpsDecompositionPass::runOnOperation() { + auto module = getOperation(); auto main = module.lookupSymbol("main"); if (!main) return; if (failed(DecomposeStackOps(&main.front(), module))) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc index b7efc5aa64b..f97d9306a43 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc @@ -68,8 +68,8 @@ using std::string; // shape. // struct TensorArrayOpsDecompositionPass - : public ModulePass { - void runOnModule() override; + : public OperationPass { + void runOnOperation() override; }; // Infers the element type and count for a TensorArraySplitV3Op. Requires @@ -873,8 +873,8 @@ LogicalResult DecomposeTensorArrayOps( return success(); } -void TensorArrayOpsDecompositionPass::runOnModule() { - auto module = getModule(); +void TensorArrayOpsDecompositionPass::runOnOperation() { + auto module = getOperation(); auto main = module.lookupSymbol("main"); if (!main) return; llvm::SmallDenseMap stats; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc index e1010f3b9bd..277146d5c42 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc @@ -62,8 +62,8 @@ namespace cutil = TF::collection_ops_util; // // The pass also works across control flow and functional calls. struct TensorListOpsDecompositionPass - : public ModulePass { - void runOnModule() override; + : public OperationPass { + void runOnOperation() override; }; // Updates func's type according to its current arguments and return values. @@ -671,8 +671,8 @@ LogicalResult DecomposeTensorListOps(Block* block, ModuleOp module) { &decomposed_partitioned_call_callees); } -void TensorListOpsDecompositionPass::runOnModule() { - auto module = getModule(); +void TensorListOpsDecompositionPass::runOnOperation() { + auto module = getOperation(); auto main = module.lookupSymbol("main"); if (!main) return; if (failed(DecomposeTensorListOps(&main.front(), module))) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc index 0ba01738532..cf45c8da5e9 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc @@ -40,20 +40,20 @@ namespace tensorflow { // Optimization Passes and convert back to MLIR. // Constraints: This pass expects that all operations in the MLIR module either // belong to 'tf' or '_tf' dialect. The output is in '_tf' dialect. -class GraphOptPass : public mlir::ModulePass { +class GraphOptPass : public mlir::OperationPass { public: explicit GraphOptPass(std::vector passes) : passes_(std::move(passes)) {} protected: - void runOnModule() override; + void runOnOperation() override; // The passes to run on the module. std::vector passes_; }; -void GraphOptPass::runOnModule() { - mlir::ModuleOp module_in = getModule(); +void GraphOptPass::runOnOperation() { + mlir::ModuleOp module_in = getOperation(); mlir::MLIRContext& ctx = getContext(); // Convert MLIR to Graph @@ -151,7 +151,7 @@ class GraphOptByNamePass : public GraphOptPass { : GraphOptPass(FindRegisteredPassesByName(pass_names)) {} private: - void runOnModule() override { + void runOnOperation() override { // Verify all passes requested were registered/found. for (auto pass_it : llvm::enumerate(passes_)) { if (pass_it.value() == nullptr) { @@ -160,7 +160,7 @@ class GraphOptByNamePass : public GraphOptPass { return signalPassFailure(); } } - return GraphOptPass::runOnModule(); + return GraphOptPass::runOnOperation(); } }; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_padding_mapper.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_padding_mapper.cc index 6013dfdf4ef..a54826c8f8e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_padding_mapper.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_padding_mapper.cc @@ -48,8 +48,9 @@ constexpr char kPaddingMapAttr[] = "padding_map"; // (user). namespace { -struct TPUDynamicPaddingMapper : public ModulePass { - void runOnModule() override; +struct TPUDynamicPaddingMapper + : public OperationPass { + void runOnOperation() override; }; // Creates a mapping from replicated input index (in `tf_device.replicate` op) @@ -190,8 +191,8 @@ LogicalResult RemapAndAssignPaddingMaps(tf_device::LaunchFuncOp launch_func, return success(); } -void TPUDynamicPaddingMapper::runOnModule() { - ModuleOp module = getModule(); +void TPUDynamicPaddingMapper::runOnOperation() { + ModuleOp module = getOperation(); SymbolTable symbol_table(module); module.walk([&](tf_device::LaunchFuncOp launch_func) { RemapAndAssignPaddingMaps(launch_func, &symbol_table); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc index 1a49350a4be..e735fa918bb 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc @@ -98,8 +98,8 @@ constexpr char kBadArrayAttrLengthMsg[] = // %4 = "tf.SomeOp"(%3) namespace { -struct TPURewritePass : public ModulePass { - void runOnModule() override; +struct TPURewritePass : public OperationPass { + void runOnOperation() override; }; // Creates a missing attribute error message. @@ -747,13 +747,13 @@ LogicalResult Rewrite( return success(); } -void TPURewritePass::runOnModule() { +void TPURewritePass::runOnOperation() { mlir::TF::RuntimeDevices devices; - if (failed(tensorflow::GetDevicesFromOp(getModule(), &devices))) + if (failed(tensorflow::GetDevicesFromOp(getOperation(), &devices))) return signalPassFailure(); OpBuilder builder(&getContext()); - auto result = getModule().walk([&](tf_device::LaunchFuncOp op) { + auto result = getOperation().walk([&](tf_device::LaunchFuncOp op) { if (failed(Rewrite(op, devices.device_names(), &builder))) return WalkResult::interrupt(); @@ -763,7 +763,7 @@ void TPURewritePass::runOnModule() { if (result.wasInterrupted()) return signalPassFailure(); // Eliminate TPUCompilationResultOp now that the rewrite is complete. - getModule().walk([&](TF::TPUCompilationResultOp op) { op.erase(); }); + getOperation().walk([&](TF::TPUCompilationResultOp op) { op.erase(); }); // TODO(b/139377366): Remove functions that are no longer needed. } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc index af01cf329b0..05c8e096f38 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc @@ -40,8 +40,8 @@ namespace { constexpr char kShardingAttr[] = "xla_hlo.sharding"; struct TPUShardingIdentificationPass - : public ModulePass { - void runOnModule() override; + : public OperationPass { + void runOnOperation() override; }; // XlaSharding op may be direct user of inputs but it may also be followed by @@ -176,9 +176,9 @@ void IdentifyXlaShardingForTPUComputation(Builder* builder, builder->getStrArrayAttr(sharding_for_rets)); } -void TPUShardingIdentificationPass::runOnModule() { - Builder builder(getModule().getContext()); - getModule().walk([&](tf_device::LaunchFuncOp launch_func) { +void TPUShardingIdentificationPass::runOnOperation() { + Builder builder(getOperation().getContext()); + getOperation().walk([&](tf_device::LaunchFuncOp launch_func) { IdentifyXlaShardingForTPUComputation(&builder, launch_func); }); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc index 8a7f0c55c3e..a58c28c50d1 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc @@ -116,8 +116,8 @@ std::string GetRandomStateVariableName() { // tf.TPUReshardVariablesOp(%rvar, %default_format, %rstate) // } struct TPUVariableRuntimeReformattingPass - : public ModulePass { - void runOnModule() override; + : public OperationPass { + void runOnOperation() override; }; // Returns the earlier value of which `v` is an identity. If `skipped` is @@ -318,7 +318,7 @@ TF::WhileOp AddStateVarsToWhileOp(TF::WhileOp while_op, FuncOp body, new_body_return_vals.push_back(inner_arg); new_while_operands.push_back(state_var.resource()); } - OpBuilder builder(&body.front()); + OpBuilder builder = OpBuilder::atBlockEnd(&body.front()); // Update return values. builder.create(body_return.getLoc(), new_body_return_vals); body_return.erase(); @@ -555,8 +555,8 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate, builder.create(while_op.getLoc(), ArrayRef{}); } -void TPUVariableRuntimeReformattingPass::runOnModule() { - auto module = getModule(); +void TPUVariableRuntimeReformattingPass::runOnOperation() { + auto module = getOperation(); module.walk([&](TF::WhileOp while_op) { auto body = llvm::cast(module.lookupSymbol(while_op.body())); tf_device::ReplicateOp replicate; diff --git a/tensorflow/compiler/mlir/tensorflow/translate/control_to_executor_dialect.cc b/tensorflow/compiler/mlir/tensorflow/translate/control_to_executor_dialect.cc index 6d3e35ac19b..9d66fc9d355 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/control_to_executor_dialect.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/control_to_executor_dialect.cc @@ -218,7 +218,7 @@ void ControlToExecutorDialectConversion::runOnFunction() { } // Create the operation inside the island - OpBuilder island_builder(&island.GetBody()); + OpBuilder island_builder = OpBuilder::atBlockEnd(&island.GetBody()); Operation *inner_op = island_builder.createOperation(result); inner_op->setAttrs(op.getAttrList()); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc b/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc index 7410074e300..40a359808cf 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc @@ -68,7 +68,7 @@ void ExecutorToControlDialectConversion::runOnFunction() { Block &body = getFunction().front(); auto graph = cast(body.front()); - OpBuilder builder(&body); + OpBuilder builder = OpBuilder::atBlockEnd(&body); SmallString<64> new_op_name; for (auto &op : llvm::make_early_inc_range(llvm::reverse(graph.GetBody()))) { LLVM_DEBUG(llvm::dbgs() << "Process: " << op.getName() << "\n"); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 155995a4f65..b5200f91840 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -1452,7 +1452,8 @@ mlir::Operation* ImporterBase::createOperation( result.location, types, control_operands, mlir::ArrayRef{}); island.body().push_back(new mlir::Block); - mlir::OpBuilder island_builder(&island.GetBody()); + mlir::OpBuilder island_builder = + mlir::OpBuilder::atBlockEnd(&island.GetBody()); // Create the operation inside the island now. mlir::Operation* inner_op; @@ -2928,12 +2929,11 @@ class SavedModelSignatureDefImporter { // Converts the SavedModel to the SavedModel dialect. Creates an MLIR function // for each signature. StatusOr ConvertSignatures(); - Status ConvertSignature( - const GraphDef& graphdef, const std::string& sig_def_key, - const std::map& inputs_sorted, - const std::map& outputs_sorted, - const GraphDebugInfo& debug_info, - const FunctionLibraryDefinition& flib_def); + Status ConvertSignature(const GraphDef& graphdef, + const std::string& sig_def_key, + const SignatureDef& signature_def, + const GraphDebugInfo& debug_info, + const FunctionLibraryDefinition& flib_def); // Creates GlobalTensorOp for each variable and moves each VarHandle op to // the enclosing function's arguments. @@ -2948,10 +2948,7 @@ class SavedModelSignatureDefImporter { const llvm::SmallVectorImpl& ops); GraphImportConfig::InputArrays ParseInputArrays( - const std::map& inputs); - - std::vector ParseOutputArrays( - const std::map& outputs); + const std::vector>& inputs); const SavedModelBundle& bundle_; mlir::OwningModuleRef module_; @@ -2979,14 +2976,8 @@ SavedModelSignatureDefImporter::ConvertSignatures() { continue; } - // protobuf::Map doesn't provide stable iteration order so use std::map - std::map inputs_sorted( - signature_def.inputs().begin(), signature_def.inputs().end()); - std::map outputs_sorted( - signature_def.outputs().begin(), signature_def.outputs().end()); - - TF_RETURN_IF_ERROR(ConvertSignature(graphdef, sig_def_key, inputs_sorted, - outputs_sorted, debug_info, flib_def)); + TF_RETURN_IF_ERROR(ConvertSignature(graphdef, sig_def_key, signature_def, + debug_info, flib_def)); } TF_RETURN_IF_ERROR(LiftVariables()); @@ -2999,13 +2990,26 @@ SavedModelSignatureDefImporter::ConvertSignatures() { Status SavedModelSignatureDefImporter::ConvertSignature( const GraphDef& graphdef, const std::string& sig_def_key, - const std::map& inputs_sorted, - const std::map& outputs_sorted, - const GraphDebugInfo& debug_info, + const SignatureDef& signature_def, const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def) { + // Create local vectors for the input and output and sort them to be + // deterministic. We don't want anyone to really depend on the order, client + // should lookup argument/result mapping by attribute name. + // To avoid accidentally depending on the order we use an unintuitive sorting. + std::vector> inputs( + signature_def.inputs().begin(), signature_def.inputs().end()); + llvm::sort(inputs, [](const auto& lhs, const auto& rhs) { + return lhs.first.size() < rhs.first.size() || lhs.first > rhs.first; + }); + std::vector> outputs( + signature_def.outputs().begin(), signature_def.outputs().end()); + llvm::sort(outputs, [](const auto& lhs, const auto& rhs) { + return lhs.first.size() < rhs.first.size() || lhs.first > rhs.first; + }); + GraphImportConfig specs; - specs.inputs = ParseInputArrays(inputs_sorted); - specs.outputs = ParseOutputArrays(outputs_sorted); + specs.inputs = ParseInputArrays(inputs); + for (auto& output : outputs) specs.outputs.push_back(output.second.name()); // Remove unused nodes and create sub-graphdef. GraphDef sub_graph_def; @@ -3041,11 +3045,11 @@ Status SavedModelSignatureDefImporter::ConvertSignature( builder.getStrArrayAttr({sig_def_key})); // Transfer input and output parameter names to index_path attributes. - for (auto input_and_idx : llvm::enumerate(inputs_sorted)) { + for (auto input_and_idx : llvm::enumerate(inputs)) { func_op.setArgAttr(input_and_idx.index(), "tf_saved_model.index_path", builder.getStrArrayAttr({input_and_idx.value().first})); } - for (auto output_and_idx : llvm::enumerate(outputs_sorted)) { + for (auto output_and_idx : llvm::enumerate(outputs)) { func_op.setResultAttr( output_and_idx.index(), "tf_saved_model.index_path", builder.getStrArrayAttr({output_and_idx.value().first})); @@ -3180,7 +3184,7 @@ Status SavedModelSignatureDefImporter::ReadVariablesFromSession( } GraphImportConfig::InputArrays SavedModelSignatureDefImporter::ParseInputArrays( - const std::map& inputs) { + const std::vector>& inputs) { GraphImportConfig::InputArrays results; for (const auto& iter : inputs) { const auto& tensor_info = iter.second; @@ -3192,28 +3196,12 @@ GraphImportConfig::InputArrays SavedModelSignatureDefImporter::ParseInputArrays( array_info.imported_dtype = tensor_info.dtype(); array_info.shape = tensor_info.tensor_shape(); - std::vector node_names = - absl::StrSplit(tensor_info.name(), ':'); - - results.insert(std::pair(node_names.at(0), + results.insert(std::pair(tensor_info.name(), std::move(array_info))); } return results; } -std::vector SavedModelSignatureDefImporter::ParseOutputArrays( - const std::map& outputs) { - std::vector results; - for (const auto& iter : outputs) { - const auto& tensor_info = iter.second; - - std::vector node_names = - absl::StrSplit(tensor_info.name(), ':'); - results.push_back(node_names.at(0)); - } - return results; -} - } // namespace Status UpgradeLegacyGraph(Graph* graph, FunctionLibraryDefinition* flib_def) { diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index 453576ba9ee..0feb633948d 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -328,6 +328,24 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "buffer_assignment", + srcs = ["transforms/buffer_assignment.cc"], + hdrs = ["transforms/buffer_assignment.h"], + deps = [ + ":hlo", + ":lhlo", + "@com_google_absl//absl/memory", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + ], + alwayslink = 1, +) + gentbl( name = "xla_legalize_to_standard_inc_gen", tbl_outs = [ diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc index 6238e8175c4..a49648b0b37 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc @@ -140,7 +140,7 @@ tensorflow::Status HloFunctionImporter::ImportInstructions( instruction_value_map_[hlo_parameter] = block->getArgument(i); } - mlir::OpBuilder builder(block); + mlir::OpBuilder builder = mlir::OpBuilder::atBlockEnd(block); for (auto instruction : computation->MakeInstructionPostOrder()) { TF_ASSIGN_OR_RETURN(auto new_operation, ImportInstruction(instruction, &builder)); @@ -523,6 +523,32 @@ StatusOr HloFunctionImporter::ImportInstruction( attributes.push_back(builder_->getNamedAttr("transpose_a", transpose_a)); MakeAndReturn(TriangularSolveOp); } + case HloOpcode::kReduceWindow: { + llvm::SmallVector sizes, strides, base_dilations, win_dilations; + llvm::SmallVector padding; + for (const auto& dim : instruction->window().dimensions()) { + sizes.push_back(dim.size()); + strides.push_back(dim.stride()); + base_dilations.push_back(dim.base_dilation()); + win_dilations.push_back(dim.window_dilation()); + padding.push_back(dim.padding_low()); + padding.push_back(dim.padding_high()); + } + attributes.push_back(builder_->getNamedAttr("window_dimensions", + ConvertDimensions(sizes))); + attributes.push_back( + builder_->getNamedAttr("window_strides", ConvertDimensions(strides))); + attributes.push_back(builder_->getNamedAttr( + "base_dilations", ConvertDimensions(base_dilations))); + attributes.push_back(builder_->getNamedAttr( + "window_dilations", ConvertDimensions(win_dilations))); + attributes.push_back(ConvertPadding(padding)); + auto reduce = func_builder->create( + loc, result_type, operands, attributes); + TF_RETURN_IF_ERROR( + ImportComputation(instruction->to_apply(), &reduce.body())); + return reduce.getOperation(); + } case HloOpcode::kMap: { auto op = func_builder->create( loc, result_type, operands, diff --git a/tensorflow/compiler/mlir/xla/tests/buffer-assignment.mlir b/tensorflow/compiler/mlir/xla/tests/buffer-assignment.mlir new file mode 100644 index 00000000000..2a1975384e5 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/buffer-assignment.mlir @@ -0,0 +1,131 @@ +// RUN: tf-opt -test-buffer-assignment -split-input-file %s | FileCheck %s -dump-input-on-failure + +// CHECK-LABEL: Testing : condBranch +func @condBranch(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{ + // CHECK: Alloc: cond_br + cond_br %cond, ^bb1, ^bb2 + ^bb1: + br ^exit(%arg0 : tensor<2xf32>) + ^bb2: + %1 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + br ^exit(%1 : tensor<2xf32>) + ^exit(%arg1: tensor<2xf32>): + return %arg1 : tensor<2xf32> + // CHECK-NEXT: Dealloc: return +} + +// ----- + +// CHECK-LABEL: Testing : criticalEdge +func @criticalEdge(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{ + // CHECK: Alloc: cond_br + cond_br %cond, ^bb1, ^exit(%arg0 : tensor<2xf32>) + ^bb1: + %0 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + br ^exit(%0 : tensor<2xf32>) + ^exit(%arg1: tensor<2xf32>): + return %arg1 : tensor<2xf32> + // CHECK-NEXT: Dealloc: return +} + +// ----- + +// CHECK-LABEL: Testing : invCriticalEdge +func @invCriticalEdge(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{ + // CHECK: Alloc: %0 = "xla_hlo.exp" + %0 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + cond_br %cond, ^bb1, ^exit(%arg0 : tensor<2xf32>) + ^bb1: + br ^exit(%0 : tensor<2xf32>) + ^exit(%arg1: tensor<2xf32>): + return %arg1 : tensor<2xf32> + // CHECK-NEXT: Dealloc: return +} + +// ----- + +// CHECK-LABEL: Testing : ifElse +func @ifElse(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{ + // CHECK: Alloc: %0 = "xla_hlo.exp"(%arg1) + %0 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + cond_br %cond, ^bb1(%arg0, %0: tensor<2xf32>, tensor<2xf32>), ^bb2(%0, %arg0: tensor<2xf32>, tensor<2xf32>) + ^bb1(%arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>): + br ^exit(%arg1, %arg2 : tensor<2xf32>, tensor<2xf32>) + ^bb2(%arg3 : tensor<2xf32>, %arg4 : tensor<2xf32>): + br ^exit(%arg3, %arg4 : tensor<2xf32>, tensor<2xf32>) + ^exit(%arg5 : tensor<2xf32>, %arg6 : tensor<2xf32>): + // CHECK-NEXT: Dealloc: %7 = "xla_hlo.exp"(%5) + // CHECK: Alloc: %7 = "xla_hlo.exp"(%5) + // CHECK-NEXT: Dealloc: return + %1 = "xla_hlo.exp"(%arg5) : (tensor<2xf32>) -> tensor<2xf32> + return %1 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: Testing : ifElseNoUsers +func @ifElseNoUsers(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{ + // CHECK: Alloc: %0 = "xla_hlo.exp"(%arg1) + %0 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + cond_br %cond, ^bb1(%arg0, %0: tensor<2xf32>, tensor<2xf32>), ^bb2(%0, %arg0: tensor<2xf32>, tensor<2xf32>) + ^bb1(%arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>): + br ^exit(%arg1, %arg2 : tensor<2xf32>, tensor<2xf32>) + ^bb2(%arg3 : tensor<2xf32>, %arg4 : tensor<2xf32>): + br ^exit(%arg3, %arg4 : tensor<2xf32>, tensor<2xf32>) + ^exit(%arg5 : tensor<2xf32>, %arg6 : tensor<2xf32>): + // CHECK-NEXT: return + return %arg0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: Testing : ifElseNested +func @ifElseNested(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{ + // CHECK: Alloc: %0 = "xla_hlo.exp"(%arg1) + %0 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + cond_br %cond, ^bb1(%arg0, %0: tensor<2xf32>, tensor<2xf32>), ^bb2(%0, %arg0: tensor<2xf32>, tensor<2xf32>) + ^bb1(%arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>): + br ^exit(%arg1, %arg2 : tensor<2xf32>, tensor<2xf32>) + ^bb2(%arg3 : tensor<2xf32>, %arg4 : tensor<2xf32>): + cond_br %cond, ^bb3(%arg3 : tensor<2xf32>), ^bb4(%arg4 : tensor<2xf32>) + ^bb3(%arg7 : tensor<2xf32>): + br ^exit(%arg7, %arg3 : tensor<2xf32>, tensor<2xf32>) + ^bb4(%arg8 : tensor<2xf32>): + br ^exit(%arg3, %arg8 : tensor<2xf32>, tensor<2xf32>) + ^exit(%arg5 : tensor<2xf32>, %arg6 : tensor<2xf32>): + // CHECK-NEXT: Dealloc: %9 = "xla_hlo.exp"(%7) + // CHECK: Alloc: %9 = "xla_hlo.exp"(%7) + // CHECK-NEXT: Dealloc: return + %1 = "xla_hlo.exp"(%arg5) : (tensor<2xf32>) -> tensor<2xf32> + return %1 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: Testing : redundantOperations +func @redundantOperations(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) { + // CHECK: Alloc: %0 = xla_hlo.maximum + // CHECK-NEXT: Dealloc: %1 = xla_hlo.add + %1 = "xla_hlo.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK: Alloc: %1 = xla_hlo.add + // CHECK-NEXT: Dealloc: %1 = xla_hlo.add + %2 = "xla_hlo.add"(%arg0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return +} + +// ----- + +// CHECK-LABEL: Testing : reduce +func @reduce(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> { + // CHECK: Alloc: %0 = xla_hlo.constant + // CHECK-NEXT: Dealloc: %1 = "xla_hlo.reduce"(%arg0, %0) + %0 = xla_hlo.constant dense<0.000000e+00> : tensor + // CHECK: Alloc: %1 = "xla_hlo.reduce"(%arg0, %0) + // CHECK: Dealloc: return + %2 = "xla_hlo.reduce"(%arg0, %0) ( { + ^bb0(%arg1: tensor, %arg2: tensor): + %4 = xla_hlo.add %arg1, %arg2 : tensor + "xla_hlo.return"(%4) : (tensor) -> () + }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf32>, tensor) -> tensor<4x8xf32> + return %2 : tensor<4x8xf32> +} diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-parallel-loops.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-parallel-loops.mlir index 1e375e142f7..ff4f1d940bf 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-parallel-loops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-parallel-loops.mlir @@ -31,11 +31,12 @@ func @reduce(%arg: memref<100x10x5xf32>, // CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 { // CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32): // CHECK: [[ELEM_BUF:%.*]] = alloc() : memref -// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref // CHECK: [[ACC_BUF:%.*]] = alloc() : memref +// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref +// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref // CHECK: store [[ACC]], [[ACC_BUF]][] : memref -// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_BUF]]) -// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_BUF]][] : memref +// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) +// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref // CHECK: loop.reduce.return [[ACC_RESULT]] : f32 // CHECK: } // CHECK: loop.yield @@ -71,11 +72,12 @@ func @reduce_no_outer_loop(%arg: memref<100xf32>, // CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 { // CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32): // CHECK: [[ELEM_BUF:%.*]] = alloc() : memref -// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref // CHECK: [[ACC_BUF:%.*]] = alloc() : memref +// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref +// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref // CHECK: store [[ACC]], [[ACC_BUF]][] : memref -// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_BUF]]) -// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_BUF]][] : memref +// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) +// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref // CHECK: loop.reduce.return [[ACC_RESULT]] // CHECK: } // CHECK: loop.yield @@ -114,11 +116,12 @@ func @dynamic_reduce(%arg: memref, // CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 { // CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32): // CHECK: [[ELEM_BUF:%.*]] = alloc() : memref -// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref // CHECK: [[ACC_BUF:%.*]] = alloc() : memref +// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref +// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref // CHECK: store [[ACC]], [[ACC_BUF]][] : memref -// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_BUF]]) -// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_BUF]][] : memref +// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) +// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref // CHECK: loop.reduce.return [[ACC_RESULT]] : f32 // CHECK: } // CHECK: loop.yield @@ -185,11 +188,12 @@ func @reduce_window(%arg: memref<112x112xf32>, // CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 { // CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32): // CHECK: [[ELEM_BUF:%.*]] = alloc() : memref -// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref // CHECK: [[ACC_BUF:%.*]] = alloc() : memref +// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref +// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref // CHECK: store [[ACC]], [[ACC_BUF]][] : memref -// CHECK: "xla_lhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_BUF]]) -// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_BUF]][] : memref +// CHECK: "xla_lhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) +// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref // CHECK: loop.reduce.return [[ACC_RESULT]] : f32 // CHECK: } // CHECK: loop.yield diff --git a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt index 54ba9704ac5..c8ef751b450 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt @@ -698,6 +698,24 @@ add { ROOT %tuple.6 = ((f32[], f32[]), f32[]) tuple(%reduce.1, %sub.5) } +// CHECK-LABEL: func @test_reduce_window +// CHECK-SAME: ([[ARG0:%.*]]: tensor<2x17x31x7xf32>, [[ARG1:%.*]]: tensor) +%test_reduce_window (Arg_0.1: f32[2,17,31,7], Arg_1.2: f32[]) -> f32[2,5,8,7] { + %Arg_0.1 = f32[2,17,31,7] parameter(0) + %Arg_1.2 = f32[] parameter(1) + + // CHECK: "xla_hlo.reduce_window"([[ARG0]], [[ARG1]]) ( { + // CHECK: xla_hlo.add {{.*}} : tensor + // CHECK: }) { + // CHECK-SAME: base_dilations = dense<1> : tensor<4xi64> + // CHECK-SAME: padding = dense<{{\[\[}}0, 0], [2, 0], [0, 2], [0, 0]]> : tensor<4x2xi64> + // CHECK-SAME: window_dilations = dense<[1, 2, 2, 1]> : tensor<4xi64> + // CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64> + // CHECK-SAME: window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64> + // CHECK_SAME: } + ROOT %reduce-window.1 = f32[2,5,8,7] reduce-window(f32[2,17,31,7] %Arg_0.1, f32[] %Arg_1.2), window={size=1x2x2x1 stride=1x4x4x1 pad=0_0x2_0x0_2x0_0 rhs_dilate=1x2x2x1}, to_apply=%reduce_helper.3 +} + // CHECK-LABEL: func @test_remainder // CHECK-SAME: ([[VAL_0:%.*]]: tensor<4xf32>, [[VAL_1:%.*]]: tensor<4xf32>) %test_remainder (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] { diff --git a/tensorflow/compiler/mlir/xla/transforms/buffer_assignment.cc b/tensorflow/compiler/mlir/xla/transforms/buffer_assignment.cc new file mode 100644 index 00000000000..3b40f4c8326 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/transforms/buffer_assignment.cc @@ -0,0 +1,501 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file implements logic for computing proper alloc and dealloc positions. +// The main class is the BufferAssignment class that realizes this analysis. +// In order to put allocations and deallocations at safe positions, it is +// significantly important to put them into the proper blocks. However, the +// liveness analysis does not pay attention to aliases, which can occur due to +// branches (and their associated block arguments) in general. For this purpose, +// BufferAssignment firstly finds all possible aliases for a single value (using +// the BufferAssignmentAliasAnalysis class). Consider the following example: +// +// ^bb0(%arg0): +// cond_br %cond, ^bb1, ^bb2 +// ^bb1: +// br ^exit(%arg0) +// ^bb2: +// %new_value = ... +// br ^exit(%new_value) +// ^exit(%arg1): +// return %arg1; +// +// Using liveness information on its own would cause us to place the allocs and +// deallocs in the wrong block. This is due to the fact that %new_value will not +// be liveOut of its block. Instead, we have to place the alloc for %new_value +// in bb0 and its associated dealloc in exit. Using the class +// BufferAssignmentAliasAnalysis, we will find out that %new_value has a +// potential alias %arg1. In order to find the dealloc position we have to find +// all potential aliases, iterate over their uses and find the common +// post-dominator block. In this block we can safely be sure that %new_value +// will die and can use liveness information to determine the exact operation +// after which we have to insert the dealloc. Finding the alloc position is +// highly similar and non- obvious. Again, we have to consider all potential +// aliases and find the common dominator block to place the alloc. +// +// TODO(dfki): +// The current implementation does not support loops. The only thing that +// is currently missing is a high-level loop analysis that allows us to move +// allocs and deallocs outside of the loop blocks. + +#include "tensorflow/compiler/mlir/xla/transforms/buffer_assignment.h" + +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "absl/memory/memory.h" + +namespace mlir { +namespace xla { +namespace { + +//===----------------------------------------------------------------------===// +// BufferAssignmentAliasAnalysis +//===----------------------------------------------------------------------===// + +/// A straight-forward alias analysis which ensures that all aliases of all +/// values will be determined. This is a requirement for the BufferAssignment +/// class since you need to determine safe positions to place alloc and +/// deallocs. +class BufferAssignmentAliasAnalysis { + public: + using ValueSetT = SmallPtrSet; + + public: + /// Constructs a new alias analysis using the op provided. + BufferAssignmentAliasAnalysis(Operation* op) { build(op->getRegions()); } + + /// Finds all immediate and indirect aliases this value could potentially + /// have. Note that the resulting set will also contain the value provided as + /// it is an alias of itself. + ValueSetT resolve(Value value) const { + ValueSetT result; + resolveRecursive(value, result); + return result; + } + + private: + /// Recursively determines alias information for the given value. It stores + /// all newly found potential aliases in the given result set. + void resolveRecursive(Value value, ValueSetT& result) const { + if (!result.insert(value).second) { + return; + } + auto it = aliases.find(value); + if (it == aliases.end()) return; + for (auto alias : it->second) { + resolveRecursive(alias, result); + } + } + + /// This function constructs a mapping from values to its immediate aliases. + /// It iterates over all blocks, gets their predecessors, determines the + /// values that will be passed to the corresponding block arguments and + /// inserts them into map. + void build(MutableArrayRef regions) { + for (Region& region : regions) { + for (Block& block : region) { + // Iterate over all predecessor and get the mapped values to their + // corresponding block arguments values. + for (auto pred : block.getPredecessors()) { + // Determine the current successor index of the current predecessor. + unsigned successorIndex = std::distance( + pred->getSuccessors().begin(), + llvm::find_if(pred->getSuccessors(), [&](Block* successor) { + return successor == █ + })); + // Get the terminator and the values that will be passed to our block. + if (auto branchInterface = + dyn_cast(pred->getTerminator())) { + // Query the branch op interace to get the successor operands. + auto successorOps = + branchInterface.getSuccessorOperands(successorIndex); + if (successorOps.hasValue()) { + // Build the actual mapping of values to their immediate aliases. + for (auto arg : block.getArguments()) { + Value predecessorArgValue = + successorOps.getValue()[arg.getArgNumber()]; + aliases[predecessorArgValue].insert(arg); + } + } + } + } + } + } + } + + /// Maps values to all immediate aliases this value can have. + llvm::DenseMap aliases; +}; + +//===----------------------------------------------------------------------===// +// BufferAssignmentPositions +//===----------------------------------------------------------------------===// + +/// Stores proper alloc and dealloc positions to place dialect-specific alloc +/// and dealloc operations. +struct BufferAssignmentPositions { + public: + BufferAssignmentPositions() + : allocPosition(nullptr), deallocPosition(nullptr) {} + + /// Creates a new positions tuple including alloc and dealloc positions. + BufferAssignmentPositions(Operation* allocPosition, + Operation* deallocPosition) + : allocPosition(allocPosition), deallocPosition(deallocPosition) {} + + /// Returns the alloc position before which the alloc operation has to be + /// inserted. + Operation* getAllocPosition() const { return allocPosition; } + + /// Returns the dealloc position after which the dealloc operation has to be + /// inserted. + Operation* getDeallocPosition() const { return deallocPosition; } + + private: + Operation* allocPosition; + Operation* deallocPosition; +}; + +//===----------------------------------------------------------------------===// +// BufferAssignmentAnalysis +//===----------------------------------------------------------------------===// + +// The main buffer assignment analysis used to place allocs and deallocs. +class BufferAssignmentAnalysis { + public: + using DeallocSetT = SmallPtrSet; + + public: + BufferAssignmentAnalysis(Operation* op) + : operation(op), + liveness(op), + dominators(op), + postDominators(op), + aliases(op) {} + + /// Computes the actual positions to place allocs and deallocs for the given + /// value. + BufferAssignmentPositions computeAllocAndDeallocPositions(Value value) const { + if (value.use_empty()) { + return BufferAssignmentPositions(value.getDefiningOp(), + value.getDefiningOp()); + } + // Get all possible aliases + auto possibleValues = aliases.resolve(value); + return BufferAssignmentPositions(getAllocPosition(value, possibleValues), + getDeallocPosition(value, possibleValues)); + } + + /// Finds all associated dealloc nodes for the alloc nodes using alias + /// information. + DeallocSetT findAssociatedDeallocs(AllocOp alloc) const { + DeallocSetT result; + auto possibleValues = aliases.resolve(alloc); + for (auto alias : possibleValues) { + for (auto user : alias.getUsers()) { + if (isa(user)) result.insert(user); + } + } + return result; + } + + /// Dumps the buffer assignment information to the given stream. + void print(raw_ostream& os) const { + os << "// ---- Buffer Assignment -----\n"; + + for (Region& region : operation->getRegions()) + for (Block& block : region) + for (Operation& operation : block) + for (Value result : operation.getResults()) { + BufferAssignmentPositions positions = + computeAllocAndDeallocPositions(result); + os << "Positions for "; + result.print(os); + os << "\n Alloc: "; + positions.getAllocPosition()->print(os); + os << "\n Dealloc: "; + positions.getDeallocPosition()->print(os); + os << "\n"; + } + } + + private: + /// Finds a proper placement block to store alloc/dealloc node according to + /// the algorithm described at the top of the file. It supports dominator and + /// post-dominator analyses via template arguments. + template + Block* findPlacementBlock(Value value, const AliasesT& aliases, + const DominatorT& doms) const { + assert(!value.isa() && "Cannot place a block argument"); + // Start with the current block the value is defined in. + Block* dom = value.getDefiningOp()->getBlock(); + // Iterate over all aliases and their uses to find a safe placement block + // according to the given dominator information. + for (auto alias : aliases) { + for (auto user : alias.getUsers()) { + // Move upwards in the dominator tree to find an appropriate + // dominator block that takes the current use into account. + dom = doms.findNearestCommonDominator(dom, user->getBlock()); + } + } + return dom; + } + + /// Finds a proper alloc positions according to the algorithm described at the + /// top of the file. + template + Operation* getAllocPosition(Value value, const AliasesT& aliases) const { + // Determine the actual block to place the alloc and get liveness + // information. + auto placementBlock = findPlacementBlock(value, aliases, dominators); + auto livenessInfo = liveness.getLiveness(placementBlock); + + // We have to ensure that the alloc will be before the first use of all + // aliases of the given value. We first assume that there are no uses in the + // placementBlock and that we can safely place the alloc before the + // terminator at the end of the block. + Operation* startOperation = placementBlock->getTerminator(); + // Iterate over all aliases and ensure that the startOperation will point to + // the first operation of all potential aliases in the placementBlock. + for (auto alias : aliases) { + auto aliasStartOperation = livenessInfo->getStartOperation(alias); + // Check whether the aliasStartOperation lies in the desired block and + // whether it is before the current startOperation. If yes, this will be + // the new startOperation. + if (aliasStartOperation->getBlock() == placementBlock && + aliasStartOperation->isBeforeInBlock(startOperation)) { + startOperation = aliasStartOperation; + } + } + // startOperation is the first operation before which we can safely store + // the alloc taking all potential aliases into account. + return startOperation; + } + + /// Finds a proper dealloc positions according to the algorithm described at + /// the top of the file. + template + Operation* getDeallocPosition(Value value, const AliasesT& aliases) const { + // Determine the actual block to place the dealloc and get liveness + // information. + auto placementBlock = findPlacementBlock(value, aliases, postDominators); + auto livenessInfo = liveness.getLiveness(placementBlock); + + // We have to ensure that the dealloc will be after the last use of all + // aliases of the given value. We first assume that there are no uses in the + // placementBlock and that we can safely place the dealloc at the beginning. + Operation* endOperation = &placementBlock->front(); + // Iterate over all aliases and ensure that the endOperation will point to + // the last operation of all potential aliases in the placementBlock. + for (auto alias : aliases) { + auto aliasEndOperation = + livenessInfo->getEndOperation(alias, endOperation); + // Check whether the aliasEndOperation lies in the desired block and + // whether it is behind the current endOperation. If yes, this will be the + // new endOperation. + if (aliasEndOperation->getBlock() == placementBlock && + endOperation->isBeforeInBlock(aliasEndOperation)) { + endOperation = aliasEndOperation; + } + } + // endOperation is the last operation behind which we can safely store the + // dealloc taking all potential aliases into account. + return endOperation; + } + + /// The operation this transformation was constructed from. + Operation* operation; + + /// The underlying liveness analysis to compute fine grained information about + /// alloc and dealloc positions. + Liveness liveness; + + /// The dominator analysis to place allocs in the appropriate blocks. + DominanceInfo dominators; + + /// The post dominator analysis to place deallocs in the appropriate blocks. + PostDominanceInfo postDominators; + + /// The internal alias analysis to ensure that allocs and deallocs take all + /// their potential aliases into account. + BufferAssignmentAliasAnalysis aliases; +}; + +//===----------------------------------------------------------------------===// +// BufferAssignmentPass +//===----------------------------------------------------------------------===// + +/// The actual buffer assignment pass that moves alloc and dealloc nodes into +/// the right positions. It uses the algorithm described at the top of the file. +// TODO(dfki): create a templated version that allows to match dialect-specific +// alloc/dealloc nodes and to insert dialect-specific dealloc node. +struct BufferAssignmentPass : mlir::FunctionPass { + void runOnFunction() override { + // Get required analysis information first. + auto& analysis = getAnalysis(); + + // Compute an initial placement of all nodes. + llvm::SmallDenseMap placements; + getFunction().walk([&](AllocOp alloc) { + placements[alloc] = analysis.computeAllocAndDeallocPositions(alloc); + }); + + // Move alloc (and dealloc - if any) nodes into the right places + // and insert dealloc nodes if necessary. + getFunction().walk([&](AllocOp alloc) { + // Find already associated dealloc nodes. + auto deallocs = analysis.findAssociatedDeallocs(alloc); + assert(deallocs.size() < 2 && + "Not supported number of associated dealloc operations"); + + // Move alloc node to the right place. + BufferAssignmentPositions& positions = placements[alloc]; + Operation* allocOperation = alloc.getOperation(); + allocOperation->moveBefore(positions.getAllocPosition()); + + // If there is an existing dealloc, move it to the right place. + if (deallocs.size()) { + Operation* nextOp = positions.getDeallocPosition()->getNextNode(); + if (!nextOp) + nextOp = &positions.getDeallocPosition()->getBlock()->back(); + (*deallocs.begin())->moveBefore(nextOp); + } else { + // If there is no dealloc node, insert one in the right place. + OpBuilder builder(alloc); + builder.setInsertionPointAfter(positions.getDeallocPosition()); + builder.create(allocOperation->getLoc(), alloc); + } + }); + }; +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// BufferAssignmentPlacer +//===----------------------------------------------------------------------===// + +/// Creates a new assignment placer. +BufferAssignmentPlacer::BufferAssignmentPlacer(Operation* op) + : operation(op), dominators(op) {} + +/// Computes the actual position to place allocs for the given value. +OpBuilder::InsertPoint BufferAssignmentPlacer::computeAllocPosition( + Value value) { + Operation* insertOp; + if (auto arg = value.dyn_cast()) { + // This is a block argument which has to be allocated in the scope + // of its associated terminator. + auto domNode = dominators.getNode(arg.getOwner()); + assert(domNode != nullptr && "Cannot find dominator info"); + auto idomNode = domNode->getIDom(); + assert(idomNode != nullptr && "There is no parent dominator"); + insertOp = idomNode->getBlock()->getTerminator(); + } else { + insertOp = value.getDefiningOp(); + } + OpBuilder opBuilder(insertOp); + return opBuilder.saveInsertionPoint(); +} + +//===----------------------------------------------------------------------===// +// FunctionAndBlockSignatureConverter +//===----------------------------------------------------------------------===// + +// Performs the actual signature rewriting step. +LogicalResult FunctionAndBlockSignatureConverter::matchAndRewrite( + FuncOp funcOp, ArrayRef operands, + ConversionPatternRewriter& rewriter) const { + auto toMemrefConverter = [&](Type t) -> Type { + if (auto tensorType = t.dyn_cast()) { + return MemRefType::get(tensorType.getShape(), + tensorType.getElementType()); + } + return t; + }; + // Converting tensor-type function arguments to memref-type. + auto funcType = funcOp.getType(); + TypeConverter::SignatureConversion conversion(funcType.getNumInputs()); + for (auto argType : llvm::enumerate(funcType.getInputs())) { + conversion.addInputs(argType.index(), toMemrefConverter(argType.value())); + } + for (auto resType : funcType.getResults()) { + conversion.addInputs(toMemrefConverter(resType)); + } + rewriter.updateRootInPlace(funcOp, [&] { + funcOp.setType( + rewriter.getFunctionType(conversion.getConvertedTypes(), llvm::None)); + rewriter.applySignatureConversion(&funcOp.getBody(), conversion); + }); + // Converting tensor-type block arugments of all blocks inside the + // function region to memref-type except for the entry block. + for (auto& block : funcOp.getBlocks()) { + if (block.isEntryBlock()) continue; + for (int i = 0, e = block.getNumArguments(); i < e; ++i) { + auto oldArg = block.getArgument(i); + auto newArg = + block.insertArgument(i, toMemrefConverter(oldArg.getType())); + oldArg.replaceAllUsesWith(newArg); + block.eraseArgument(i + 1); + } + } + return success(); +} + +// Adding functions whose arguments are memref type to the set of legal +// operations. +void FunctionAndBlockSignatureConverter::addDynamicallyLegalFuncOp( + ConversionTarget& target) { + target.addDynamicallyLegalOp([&](FuncOp op) { + auto inputs = op.getType().getInputs(); + return std::all_of(inputs.begin(), inputs.end(), + [](Type input) { return input.isa(); }); + }); +} + +//===----------------------------------------------------------------------===// +// Buffer assignment pass registrations +//===----------------------------------------------------------------------===// + +std::unique_ptr> createBufferAssignmentPass() { + return absl::make_unique(); +} + +static PassRegistration buffer_assignment_pass( + "buffer-assignment", + "Executes buffer assignment pass to automatically move alloc and dealloc " + "operations into their proper positions"); + +/// A simple pass to print debug/test information for the buffer assignment +/// analysis. +struct BufferAssignmentTestPass : mlir::FunctionPass { + void runOnFunction() override { + llvm::outs() << "Testing : " << getFunction().getName() << "\n"; + getAnalysis().print(llvm::outs()); + }; +}; + +std::unique_ptr> createBufferAssignmentTestPass() { + return absl::make_unique(); +} + +static PassRegistration buffer_assignment_test_pass( + "test-buffer-assignment", + "Outputs debug test information for the buffer assignment analysis"); + +} // namespace xla +} // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/transforms/buffer_assignment.h b/tensorflow/compiler/mlir/xla/transforms/buffer_assignment.h new file mode 100644 index 00000000000..d8b4c2554bb --- /dev/null +++ b/tensorflow/compiler/mlir/xla/transforms/buffer_assignment.h @@ -0,0 +1,140 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_BUFFER_ASSIGNMENT_H_ +#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_BUFFER_ASSIGNMENT_H_ + +#include "mlir/Analysis/Dominance.h" +#include "mlir/Analysis/Liveness.h" +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project + +namespace mlir { +namespace xla { + +/// Prepares a buffer assignment phase. It can place (user-defined) alloc +/// nodes. This simplifies the integration of the actual buffer-assignment +/// pass. Sample usage: +/// BufferAssignmentPlacer baHelper(regionOp); +/// -> determine alloc positions +/// auto allocPosition = baHelper.computeAllocPosition(value); +/// -> place alloc +/// allocBuilder.setInsertionPoint(positions.getAllocPosition()); +/// +/// alternatively: +/// -> place alloc +/// baHelper.insertAlloc(...); +/// Note: this class is intended to be used during legalization. In order +/// to move alloc and dealloc nodes into the right places you can use the +/// createBufferAssignmentPass() function. +class BufferAssignmentPlacer { + public: + /// Creates a new assignment builder. + explicit BufferAssignmentPlacer(Operation* op); + + /// Returns the operation this analysis was constructed from. + Operation* getOperation() const { return operation; } + + /// Computes the actual position to place allocs for the given value. + OpBuilder::InsertPoint computeAllocPosition(Value value); + + private: + /// The operation this analysis was constructed from. + Operation* operation; + + /// The dominator analysis to place allocs in the appropriate blocks. + DominanceInfo dominators; +}; + +/// Helper conversion pattern that encapsulates a BufferAssignmentPlacer +/// instance. +template +class BufferAssignmentOpConversionPattern + : public OpConversionPattern { + public: + explicit BufferAssignmentOpConversionPattern( + MLIRContext* context_, + xla::BufferAssignmentPlacer* bufferAssignment_ = nullptr, + PatternBenefit benefit_ = 1) + : OpConversionPattern(context_, benefit_), + bufferAssignment(bufferAssignment_) {} + + protected: + xla::BufferAssignmentPlacer* bufferAssignment; +}; + +// Converts only the tensor-type function and block arguments to memref-type. +class FunctionAndBlockSignatureConverter + : public BufferAssignmentOpConversionPattern { + public: + using BufferAssignmentOpConversionPattern< + FuncOp>::BufferAssignmentOpConversionPattern; + + // Adding functions whose arguments are memref type to the set of legal + // operations. + static void addDynamicallyLegalFuncOp(ConversionTarget& target); + + // Performs the actual signature rewriting step. + LogicalResult matchAndRewrite( + FuncOp funcOp, ArrayRef operands, + ConversionPatternRewriter& rewriter) const final; +}; + +// This pattern converter transforms a non-void ReturnOpSourceTy into a void +// return of type ReturnOpTargetTy. It uses a copy operation of type CopyOpTy to +// copy the results to the output buffer. +template +class NonVoidToVoidReturnOpConverter + : public BufferAssignmentOpConversionPattern { + public: + using BufferAssignmentOpConversionPattern< + ReturnOpSourceTy>::BufferAssignmentOpConversionPattern; + + // Performs the actual return-op conversion step. + LogicalResult matchAndRewrite( + ReturnOpSourceTy returnOp, ArrayRef operands, + ConversionPatternRewriter& rewriter) const final { + auto numReturnValues = returnOp.getNumOperands(); + auto funcOp = returnOp.template getParentOfType(); + auto numFuncArgs = funcOp.getNumArguments(); + auto loc = returnOp.getLoc(); + + // Find the corresponding output buffer for each operand. + for (auto operand : llvm::enumerate(operands)) { + auto returnArgNumber = numFuncArgs - numReturnValues + operand.index(); + auto dstBuffer = funcOp.getArgument(returnArgNumber); + if (dstBuffer == operand.value()) { + continue; + } + + // Insert the copy operation to copy before the return. + rewriter.setInsertionPoint( + returnOp.getOperation()->getBlock()->getTerminator()); + rewriter.create(loc, operand.value(), + funcOp.getArgument(returnArgNumber)); + } + // Insert the new target return operation. + rewriter.replaceOpWithNewOp(returnOp); + return success(); + } +}; + +} // namespace xla +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_BUFFER_ASSIGNMENT_H_ diff --git a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc index 51edaaf53bd..7215ffef6d3 100644 --- a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc +++ b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc @@ -324,8 +324,8 @@ class HloToLhloTensorStoreOpConverter : public ConversionPattern { // "xla_lhlo.terminator"() : () -> () // } -struct HloLegalizeToLhlo : public ModulePass { - void runOnModule() override { +struct HloLegalizeToLhlo : public OperationPass { + void runOnOperation() override { OwningRewritePatternList patterns; auto& context = getContext(); ConversionTarget target(context); @@ -344,7 +344,7 @@ struct HloLegalizeToLhlo : public ModulePass { [](Type input) { return input.isa(); }); }); - auto module = getModule(); + auto module = getOperation(); populateHLOToLHLOConversionPattern(module.getContext(), &patterns); // Do partial conversion so we can have unknown ops in tests. diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc index 8d57599d397..053deddcdfe 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc @@ -51,9 +51,10 @@ using mlir::PassRegistration; namespace mlir { namespace xla_hlo { namespace { -class LegalizeTFControlFlow : public ModulePass { +class LegalizeTFControlFlow + : public OperationPass { public: - void runOnModule() override; + void runOnOperation() override; }; } // namespace @@ -164,8 +165,8 @@ void LowerWhile(TF::WhileOp op, ModuleOp module) { } } // namespace -void LegalizeTFControlFlow::runOnModule() { - auto module = getModule(); +void LegalizeTFControlFlow::runOnOperation() { + auto module = getOperation(); module.walk([&](TF::WhileOp op) -> void { LowerWhile(op, module); }); module.walk([&](TF::IfOp op) -> void { LowerIf(op, module); }); diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc index 1250db08ee5..806fe5d6f61 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc @@ -29,48 +29,128 @@ namespace mlir { namespace xla_lhlo { namespace { +// Clones and adapts the code in `lhlo_block` that works on buffers and has a +// single output buffer to make it compatible with `operands` that have element +// types of the respective buffers. Returns the computed value. +// +// Example. For `operands` with (f32, i32) types and a block with LHLO ops and +// with signature: +// ^bb(%lhs: memref, %rhs: memref, %res: memref): +// +// +// inserts necessary alloc and store ops to compute and return result that has +// `i1` type. +Value ApplySingleResultLhloCode(Location loc, ValueRange operands, + Block* lhlo_block, OpBuilder* b) { + SmallVector arg_bufs; + for (auto arg_type : lhlo_block->getArgumentTypes()) { + arg_bufs.push_back(b->create(loc, arg_type.cast())); + } + for (auto operand : llvm::enumerate(operands)) { + b->create(loc, operand.value(), arg_bufs[operand.index()]); + } + // Clone the ops from `lhlo_block`. + BlockAndValueMapping mapping; + mapping.map(lhlo_block->getArguments(), arg_bufs); + for (auto& nested : lhlo_block->without_terminator()) { + auto clone = b->clone(nested, mapping); + mapping.map(nested.getResults(), clone->getResults()); + } + return b->create(loc, arg_bufs.back()); +} + // Converts a block with LHLO ops and with signature: // ^bb(%lhs: memref, %rhs: memref, %res: memref): // into a reduction operator of loop.reduce by doing buffer allocation for // scalar arguments and the result of `loop.reduce` to make it compatible with // LHLO ops. void ConvertToReductionOperator(Location loc, loop::ReduceOp reduce_op, - Block* lhlo_block, - ConversionPatternRewriter* rewriter) { + Block* lhlo_block, OpBuilder* b) { Block& loop_reduce_op_body = reduce_op.reductionOperator().front(); - rewriter->setInsertionPointToStart(&loop_reduce_op_body); - - // Allocate buffers to hold arguments of reduction operator block to stay - // compatible with the LHLO dialect ops in the reduction body. - Value elem_arg = lhlo_block->getArgument(0); - Value elem_buf = - rewriter->create(loc, elem_arg.getType().cast()); - rewriter->create(loc, loop_reduce_op_body.getArgument(0), elem_buf); - Value acc_arg = lhlo_block->getArgument(1); - Value acc_buf = - rewriter->create(loc, acc_arg.getType().cast()); - rewriter->create(loc, loop_reduce_op_body.getArgument(1), acc_buf); - - // Clone the ops from `xla_lhlo.reduce` into reduction operator block. - BlockAndValueMapping mapping; - mapping.map(lhlo_block->getArguments(), - ValueRange{elem_buf, acc_buf, acc_buf}); - for (auto& nested : lhlo_block->without_terminator()) { - auto clone = rewriter->clone(nested, mapping); - mapping.map(nested.getResults(), clone->getResults()); - } - Value acc_result = rewriter->create(loc, acc_buf); - rewriter->create(loc, acc_result); + OpBuilder::InsertionGuard guard(*b); + b->setInsertionPointToStart(&loop_reduce_op_body); + b->create( + loc, ApplySingleResultLhloCode(loc, loop_reduce_op_body.getArguments(), + lhlo_block, b)); } // Returns result of ConstantOp if `dim` is static, otherwise uses DimOp to // extract dimension at runtime. Value GetStaticOrDynamicDim(mlir::Location loc, Value shaped_value, - size_t dim_index, int64_t dim, - ConversionPatternRewriter* rewriter) { + size_t dim_index, int64_t dim, OpBuilder* b) { return dim == ShapedType::kDynamicSize - ? rewriter->create(loc, shaped_value, dim_index).getResult() - : rewriter->create(loc, dim); + ? b->create(loc, shaped_value, dim_index).getResult() + : b->create(loc, dim); +} + +struct MappedIvs { + // False if the mapped indices are in the padding area, true otherwise. + Value in_bounds; + // Mapped indices. + SmallVector ivs; +}; + +MappedIvs MapWindowIvsToInput(ReduceWindowOp op, ValueRange ivs, + ValueRange window_ivs, OpBuilder* b) { + MappedIvs mapped_ivs; + + if (!op.window_strides().hasValue()) { + op.emitOpError("No window strides specified."); + } + auto window_strides = op.window_strides().getValue(); + + if (!op.padding().hasValue()) { + op.emitOpError("No padding specified."); + } + auto padding = op.padding().getValue(); + + auto loc = op.getLoc(); + auto operand = op.operand(); + auto operand_shape = operand.getType().cast().getShape(); + + // `in_bounds` is false when the mapped indices are in the padding area. + mapped_ivs.in_bounds = b->create( + loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 1)); + for (unsigned i = 0, e = ivs.size(); i < e; ++i) { + auto stride = window_strides.getValue(i); + auto pad_low = padding.getValue({i, 0}); + + Value stride_val = b->create(loc, stride.getSExtValue()); + Value pad_low_val = b->create(loc, pad_low.getSExtValue()); + + Value center = b->create(loc, ivs[i], stride_val); + Value offset = b->create(loc, window_ivs[i], pad_low_val); + Value index = b->create(loc, center, offset); + Value upper_bound = + GetStaticOrDynamicDim(loc, operand, i, operand_shape[i], b); + // We must check whether 0 <= index_i < shape_i, as otherwise we are in + // the pad and then we have to use the neutral element for reduction. + // Equivalently, it can be computed as the unsigned comparison index_i < + // shape_i, since a negative value wraps to a large positive value. + mapped_ivs.in_bounds = b->create( + loc, mapped_ivs.in_bounds, + b->create(loc, CmpIPredicate::ult, index, upper_bound)); + mapped_ivs.ivs.push_back(index); + } + return mapped_ivs; +} + +// Returns loop::Parallel over a shaped value with static or dynamic shape. +loop::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value, + OpBuilder* b) { + Value zero = b->create(loc, 0); + Value one = b->create(loc, 1); + + ArrayRef shape = + shaped_value.getType().cast().getShape(); + SmallVector lower, upper, step; + for (auto dim : llvm::enumerate(shape)) { + upper.push_back( + GetStaticOrDynamicDim(loc, shaped_value, dim.index(), dim.value(), b)); + lower.push_back(zero); + step.push_back(one); + } + return b->create(loc, lower, upper, step); } // Converts `xla_lhlo.ReduceOp` into two loop::ParallelOp and a loop::ReduceOp. @@ -186,7 +266,7 @@ class ReduceOpConverter : public OpConversionPattern { SmallVector out_indices; if (outer != nullptr) { out_indices.reserve(outer.getNumLoops()); - for (auto& iv : outer.getInductionVars()) { + for (Value iv : outer.getInductionVars()) { out_indices.push_back(iv); } } else { @@ -198,12 +278,16 @@ class ReduceOpConverter : public OpConversionPattern { // Load the element to reduce. SmallVector indices; indices.reserve(operand_shape.size()); - Block::args_iterator outer_ivs_it = - outer ? outer.getInductionVars().begin() : nullptr; - Block::args_iterator inner_ivs_it = inner.getInductionVars().begin(); - for (unsigned i = 0, e = operand_shape.size(); i < e; ++i) { - indices.push_back(reducing_dims.count(i) ? *inner_ivs_it++ - : *outer_ivs_it++); + + if (outer) { + auto inner_ivs_it = inner.getInductionVars().begin(); + auto outer_ivs_it = outer.getInductionVars().begin(); + for (unsigned i = 0, e = operand_shape.size(); i < e; ++i) { + indices.push_back(reducing_dims.count(i) ? *inner_ivs_it++ + : *outer_ivs_it++); + } + } else { + indices = ValueRange(inner.getInductionVars()); } rewriter->setInsertionPointToStart(inner.getBody()); @@ -309,20 +393,11 @@ class ReduceWindowOpConverter // Create an outer parallel loop that spans the output of ReduceWindowOp. Value xla_output = xla_reduce_window_op.out(); - auto output_shape = xla_output.getType().cast().getShape(); - SmallVector parallel_lower, parallel_upper, parallel_step; - for (auto dim : llvm::enumerate(output_shape)) { - parallel_upper.push_back(GetStaticOrDynamicDim( - loc, xla_output, dim.index(), dim.value(), rewriter)); - parallel_lower.push_back(zero); - parallel_step.push_back(one); - } - auto output_loop = rewriter->create( - loc, parallel_lower, parallel_upper, parallel_step); + auto output_loop = MakeLoopOverShape(loc, xla_output, rewriter); // Create a nested loop that traverses the window. - rewriter->setInsertionPointToStart(output_loop.getBody()); SmallVector window_lower, window_upper, window_step; + rewriter->setInsertionPointToStart(output_loop.getBody()); for (const auto& window_dim : xla_reduce_window_op.window_dimensions()) { window_step.push_back(one); window_lower.push_back(zero); @@ -334,9 +409,8 @@ class ReduceWindowOpConverter Value reduction_result = *window_loop.getResults().begin(); auto output_ivs = output_loop.getInductionVars(); - rewriter->create( - loc, reduction_result, xla_output, - llvm::makeArrayRef(output_ivs.begin(), output_ivs.end())); + rewriter->create(loc, reduction_result, xla_output, + ValueRange{output_ivs}); return std::make_pair(output_loop, window_loop); } @@ -347,12 +421,6 @@ class ReduceWindowOpConverter rewriter->setInsertionPointToStart(window_loop.getBody()); auto loc = xla_reduce_window_op.getLoc(); - if (!xla_reduce_window_op.window_strides().hasValue()) { - xla_reduce_window_op.emitOpError("No window strides specified."); - } - if (!xla_reduce_window_op.padding().hasValue()) { - xla_reduce_window_op.emitOpError("No padding specified."); - } if (xla_reduce_window_op.base_dilations().hasValue() || xla_reduce_window_op.window_dilations().hasValue()) { xla_reduce_window_op.emitRemark( @@ -362,51 +430,18 @@ class ReduceWindowOpConverter Value xla_operand = xla_reduce_window_op.operand(); auto xla_operand_type = xla_operand.getType().cast(); - auto xla_operand_shape = xla_operand_type.getShape(); - auto output_ivs = llvm::to_vector<2>(output_loop.getInductionVars()); - auto window_ivs = llvm::to_vector<2>(window_loop.getInductionVars()); - auto window_strides = xla_reduce_window_op.window_strides().getValue(); - auto padding = xla_reduce_window_op.padding().getValue(); + MappedIvs mapped_ivs = MapWindowIvsToInput( + xla_reduce_window_op, output_loop.getInductionVars(), + window_loop.getInductionVars(), rewriter); - SmallVector operand_indices; - // `in_bounds` is false when the element in the reduce window is in the - // padding area, true otherwise. - Value in_bounds = rewriter->create( - loc, rewriter->getI1Type(), - rewriter->getIntegerAttr(rewriter->getI1Type(), 1)); - for (unsigned i = 0, e = output_loop.getNumLoops(); i < e; ++i) { - auto stride = window_strides.getValue(i); - auto pad_low = padding.getValue({i, 0}); - - Value stride_val = - rewriter->create(loc, stride.getSExtValue()); - Value pad_low_val = - rewriter->create(loc, pad_low.getSExtValue()); - - Value center = rewriter->create(loc, output_ivs[i], stride_val); - Value offset = rewriter->create(loc, window_ivs[i], pad_low_val); - Value index = rewriter->create(loc, center, offset); - operand_indices.push_back(index); - Value upper_bound = GetStaticOrDynamicDim(loc, xla_operand, i, - xla_operand_shape[i], rewriter); - // We must check whether 0 <= index_i < shape_i, as otherwise we are in - // the pad and then we have to use the neutral element for reduction. - // Equivalently, it can be computed as the unsigned comparison index_i < - // shape_i, since a negative value wraps to a large positive value. - in_bounds = rewriter->create( - loc, in_bounds, - rewriter->create(loc, CmpIPredicate::ult, index, - upper_bound)); - } - - auto elem_or_init = - rewriter->create(loc, xla_operand_type.getElementType(), - in_bounds, /*withElseRegion=*/true); + auto elem_or_init = rewriter->create( + loc, xla_operand_type.getElementType(), mapped_ivs.in_bounds, + /*withElseRegion=*/true); OpBuilder then_builder = elem_or_init.getThenBodyBuilder(); Value elem = then_builder.create( - loc, xla_reduce_window_op.operand(), operand_indices); + loc, xla_reduce_window_op.operand(), mapped_ivs.ivs); then_builder.create(loc, elem); OpBuilder else_builder = elem_or_init.getElseBodyBuilder(); @@ -423,8 +458,12 @@ struct LhloLegalizeToParallelLoops auto func = getFunction(); OwningRewritePatternList patterns; - patterns.insert( - func.getContext()); + // clang-format off + patterns.insert< + ReduceOpConverter, + ReduceWindowOpConverter + >(func.getContext()); + // clang-format on ConversionTarget target(getContext()); target.addLegalDialect createLhloCopyRemovalPass(); std::unique_ptr> createLegalizeLhloToParallelLoopsPass(); } // namespace xla_lhlo + +namespace xla { + +/// Moves alloc nodes (and their associated dealloc nodes - if any) into the +/// right positions. If there is no associated dealloc node for a given alloc +/// node, this pass will automatically insert a proper dealloc node in the right +/// place. The intended use case of this pass is to store SSA values into +/// buffers using load/store operations. For this purpose, you need to know +/// proper positions to place the required allocs and deallocs. +/// 1) Note that the function signatures and all types for which buffers should +/// be allocated need to be converted in advance. +/// 2) All required alloc nodes have the be inserted in advance. +/// 3) Note that the current implementation does not support loops. +/// Refer to the class mlir::xla::BufferAssignmentLegalizer for more +/// information. +std::unique_ptr> createBufferAssignmentPass(); + +} // namespace xla } // namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_PASSES_H_ diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 2868ecc61fd..706cd6e515a 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -1,8 +1,11 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test") -load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") # buildifier: disable=same-origin-load load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") -load("//tensorflow/compiler/tests:build_defs.bzl", "tf_xla_py_test") -load("//tensorflow/compiler/tests:build_defs.bzl", "generate_backend_suites") +load( + "//tensorflow/compiler/tests:build_defs.bzl", + "generate_backend_suites", + "tf_xla_py_test", +) load( "//tensorflow/core/platform:build_config_root.bzl", "tf_cuda_tests_tags", diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index a0aea950cde..6e1b87a0cf7 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -134,6 +134,38 @@ class UnaryOpsTest(xla_test.XLATestCase): np.array([[-1, 1]], dtype=dtype), expected=np.array([[-1, 1]], dtype=dtype)) + def testLog(self): + for dtype in self.float_types - {dtypes.bfloat16.as_numpy_dtype}: + tol = 1e-4 if dtype == np.float32 else 1e-9 + x = np.linspace(-np.e, np.e, num=1000, dtype=dtype) + self._assertOpOutputMatchesExpected( + math_ops.log, x, expected=np.log(x), atol=tol, rtol=tol) + + x = np.linspace(0., np.e * 1e-30, num=1000, dtype=dtype) + self._assertOpOutputMatchesExpected( + math_ops.log, x, expected=np.log(x), atol=tol, rtol=tol) + + x = np.linspace(0., np.pi * 1e30, num=1000, dtype=dtype) + self._assertOpOutputMatchesExpected( + math_ops.log, x, expected=np.log(x), atol=tol, rtol=tol) + + def testSin(self): + for dtype in self.float_types - {dtypes.bfloat16.as_numpy_dtype}: + tol = 1e-3 if dtype == np.float32 else 1e-10 + + x = np.linspace(-4 * np.pi, 4 * np.pi, num=1000, dtype=dtype) + self._assertOpOutputMatchesExpected( + math_ops.sin, x, expected=np.sin(x), rtol=tol, atol=tol) + + x = np.linspace(0., 2.71828e-30, num=1000, dtype=dtype) + self._assertOpOutputMatchesExpected( + math_ops.sin, x, expected=np.sin(x), rtol=tol, atol=tol) + + if dtype == np.float64: + x = np.linspace(0., 3.141592e8, num=1000, dtype=dtype) + self._assertOpOutputMatchesExpected( + math_ops.sin, x, expected=np.sin(x), rtol=1e-5, atol=1e-5) + def testFloatOps(self): for dtype in self.float_types: x = np.arange(-0.90, 0.90, 0.25) diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index 371a5804008..6291ea6cbda 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -68,7 +68,7 @@ tf_cuda_cc_test( "nomac", ], deps = [ - "//tensorflow/core:gpu_init", + "//tensorflow/core/common_runtime/gpu:gpu_init", "//tensorflow/core:lib", "//tensorflow/core:stream_executor", "//tensorflow/core:test", @@ -97,7 +97,6 @@ cc_library( "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@local_config_cuda//cuda:cuda_headers", - "//tensorflow/core:core_cpu_lib_no_ops", "//tensorflow/core:framework", "//tensorflow/core:gpu_headers_lib", "//tensorflow/core:lib", @@ -105,6 +104,7 @@ cc_library( "//tensorflow/core:lib_proto_parsing", "//tensorflow/core:stream_executor", "//tensorflow/core:stream_executor_headers_lib", + "//tensorflow/core/common_runtime:core_cpu_lib_no_ops", "//tensorflow/core/grappler/costs:graph_properties", "//tensorflow/stream_executor/lib", ] + if_tensorrt([":tensorrt_lib"]) + tf_custom_op_library_additional_deps(), diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc index c9d46251069..3e9a7954b03 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc @@ -617,6 +617,11 @@ std::pair GetDeviceAndAllocator(const ConversionParams& params, return std::make_pair(cuda_device_id, dev_allocator); } +int64 GetNextGraphSequenceNumber() { + static std::atomic graph_sequence_num; + return graph_sequence_num++; +} + // Entry function from optimization pass. Status ConvertAfterShapes(const ConversionParams& params) { // Sanity checks. @@ -666,10 +671,12 @@ Status ConvertAfterShapes(const ConversionParams& params) { std::vector engine_bytes_size; segment::SegmentNodesVector converted_segments; converted_segments.reserve(initial_segments.size()); + string engine_name_prefix = + StrCat("TRTEngineOp_", GetNextGraphSequenceNumber(), "_"); for (size_t t = 0; t < initial_segments.size(); t++) { auto& curr_segment = initial_segments.at(t); EngineInfo curr_engine; - curr_engine.engine_name = StrCat("TRTEngineOp_", t); + curr_engine.engine_name = StrCat(engine_name_prefix, t); Status status = GetEngineInfo(&graph, *params.graph_properties, curr_segment, node_map, reverse_topo_order, &curr_engine); diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc index 1646749ad9c..2cfefd27a67 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h" +#include // NOLINT + #include #include #include "tensorflow/cc/framework/ops.h" @@ -203,15 +205,22 @@ TEST_F(ConvertAfterShapesTest, DirectlyConnectedEngines) { GraphDef output_graph_def; TF_EXPECT_OK(RunConvertAfterShape(s, &output_graph_def)); + auto remove_graph_sequence_number = [](std::string node_name) { + const std::regex pattern("TRTEngineOp_[0-9]+_"); + return std::regex_replace(node_name, pattern, "TRTEngineOp_"); + }; int num_trt_ops = 0; for (const NodeDef& node : output_graph_def.node()) { - if (node.name() == "TRTEngineOp_1") { + std::string node_name = node.name(); + if (node.op() != "TRTEngineOp") continue; + node_name = remove_graph_sequence_number(node_name); + if (node_name == "TRTEngineOp_1") { EXPECT_EQ(1, node.input_size()); EXPECT_EQ("input", node.input(0)); ++num_trt_ops; - } else if (node.name() == "TRTEngineOp_0") { + } else if (node_name == "TRTEngineOp_0") { EXPECT_EQ(2, node.input_size()); - EXPECT_EQ("TRTEngineOp_1", node.input(0)); + EXPECT_EQ("TRTEngineOp_1", remove_graph_sequence_number(node.input(0))); EXPECT_EQ("reshape2", node.input(1)); ++num_trt_ops; } diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 5f1c2f28ba4..b49e23a9a2c 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -155,9 +155,11 @@ tf_kernel_library( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:padding", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", @@ -198,20 +200,14 @@ tf_kernel_library( "//tensorflow/core:stateful_random_ops_op_lib", "//tensorflow/core:stateless_random_ops_op_lib", "//tensorflow/core:training_ops_op_lib", - "//tensorflow/core/kernels:constant_op", - "//tensorflow/core/kernels:control_flow_ops", - "//tensorflow/core/kernels:cwise_op", - "//tensorflow/core/kernels:list_kernels", "//tensorflow/core/kernels:pooling_ops", - "//tensorflow/core/kernels:random_op", - "//tensorflow/core/kernels:resource_variable_ops", - "//tensorflow/core/kernels:sendrecv_ops", - "//tensorflow/core/kernels:sparse_to_dense_op", - "//tensorflow/core/kernels:stack_ops", "//tensorflow/core/kernels:stateful_random_ops", - "//tensorflow/core/kernels:training_ops", + "//tensorflow/stream_executor:stream_header", + "//tensorflow/stream_executor/lib", + "//third_party/eigen3", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", @@ -230,9 +226,9 @@ cc_library( deps = [ "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client/lib:arithmetic", @@ -251,15 +247,19 @@ cc_library( deps = [ "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/core:framework", - "//tensorflow/core:framework_bounds_check", + "//tensorflow/core:framework_lite", + "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:conv_grad_shape_utils", + "//tensorflow/core/platform:errors", + "//tensorflow/core/platform:status", + "//tensorflow/core/platform:stringpiece", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/types:span", ], ) @@ -269,17 +269,18 @@ cc_library( srcs = ["tensor_list_utils.cc"], hdrs = ["tensor_list_utils.h"], deps = [ - "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/stream_executor/lib", ], ) @@ -288,11 +289,14 @@ cc_library( srcs = ["if_while_utils.cc"], hdrs = ["if_while_utils.h"], deps = [ - "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/ops:xla_ops", - "//tensorflow/compiler/xla:status_macros", + "//tensorflow/core:core_cpu_lib", + "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/container:inlined_vector", ], ) @@ -303,17 +307,21 @@ tf_kernel_library( deps = [ ":if_while_utils", ":tensor_list_utils", - "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/ops:xla_ops", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/core:core_cpu_lib", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", - "@com_google_absl//absl/strings", + "//tensorflow/core/platform:errors", + "//tensorflow/core/platform:status", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/container:inlined_vector", ], ) @@ -323,15 +331,17 @@ tf_kernel_library( hdrs = ["if_op.h"], deps = [ ":if_while_utils", - "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/ops:xla_ops", - "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/core:core_cpu_lib", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/container:inlined_vector", ], ) @@ -341,15 +351,19 @@ tf_kernel_library( hdrs = ["case_op.h"], deps = [ ":if_while_utils", - "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/ops:xla_ops", - "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/core:core_cpu_lib", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc index f34b2ff11df..2eab811c29e 100644 --- a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc @@ -13,11 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h" +#include "tensorflow/compiler/tf2xla/xla_expression.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/arg_op.cc b/tensorflow/compiler/tf2xla/kernels/arg_op.cc index 3d9aceae8ec..0d3912cf637 100644 --- a/tensorflow/compiler/tf2xla/kernels/arg_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/arg_op.cc @@ -13,14 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" -#include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include + +#include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/xla_expression.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/core/framework/kernel_def_builder.h" -#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/assert_op.cc b/tensorflow/compiler/tf2xla/kernels/assert_op.cc index c40caa8fa10..da9dfd5fd6b 100644 --- a/tensorflow/compiler/tf2xla/kernels/assert_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/assert_op.cc @@ -15,7 +15,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc index f60509b3746..958cb9ac787 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc index fcc93eb0e8d..0b240718f96 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc @@ -13,6 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include +#include +#include +#include + // XLA implementation of BatchNorm operations. #include "tensorflow/compiler/tf2xla/kernels/relu_op.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -20,10 +26,20 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/tensor_format.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc index 6b675fa8a94..6f49b1a5986 100644 --- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -13,10 +13,25 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc index c022284fec6..3119e48e618 100644 --- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc @@ -16,11 +16,18 @@ limitations under the License. // XLA-specific Ops for broadcasting used in gradient // code. +#include + +#include "absl/container/inlined_vector.h" #include "absl/strings/str_join.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/bcast.h" diff --git a/tensorflow/compiler/tf2xla/kernels/beta_op.cc b/tensorflow/compiler/tf2xla/kernels/beta_op.cc index aa4a8cae118..248046076a0 100644 --- a/tensorflow/compiler/tf2xla/kernels/beta_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/beta_op.cc @@ -13,18 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/tf2xla/lib/broadcast.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/arithmetic.h" -#include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/loops.h" #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc index 33bdf9aec31..29ccabc3aa7 100644 --- a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc @@ -13,14 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include +#include +#include #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index 0ea851e9325..c87f6e6fdf7 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -15,20 +15,26 @@ limitations under the License. // Native XLA implementations of simple binary Ops +#include +#include +#include + +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" -#include "tensorflow/compiler/tf2xla/lib/broadcast.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/bcast.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc index d7a8e67dd33..9943d5bdea9 100644 --- a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc @@ -13,11 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/tf2xla/lib/broadcast.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc index 5078f8662bd..9bfc3a9d409 100644 --- a/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc @@ -14,12 +14,18 @@ limitations under the License. ==============================================================================*/ #include +#include #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/case_op.cc b/tensorflow/compiler/tf2xla/kernels/case_op.cc index 1b15c09f7e3..9b775d2b8df 100644 --- a/tensorflow/compiler/tf2xla/kernels/case_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/case_op.cc @@ -15,14 +15,36 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/case_op.h" +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/kernels/if_while_utils.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/side_effect_util.h" -#include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/tf2xla/xla_resource.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/case_op.h b/tensorflow/compiler/tf2xla/kernels/case_op.h index 4a61707864e..dee56c7639c 100644 --- a/tensorflow/compiler/tf2xla/kernels/case_op.h +++ b/tensorflow/compiler/tf2xla/kernels/case_op.h @@ -21,7 +21,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/macros.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/cast_op.cc b/tensorflow/compiler/tf2xla/kernels/cast_op.cc index d0b60d9b820..2e55330c522 100644 --- a/tensorflow/compiler/tf2xla/kernels/cast_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cast_op.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/tf2xla/lib/broadcast.h" #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" @@ -23,10 +24,15 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc index dad310911a0..4dacb21414c 100644 --- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc @@ -15,21 +15,26 @@ limitations under the License. // XLA implementations of Categorical op. +#include +#include + #include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/prng.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/check_numerics_op.cc b/tensorflow/compiler/tf2xla/kernels/check_numerics_op.cc index 6061e822d8d..8f892cb2e2a 100644 --- a/tensorflow/compiler/tf2xla/kernels/check_numerics_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/check_numerics_op.cc @@ -15,7 +15,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc index e6b30a38e03..e110f21df4d 100644 --- a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc @@ -13,10 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.pb.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc index 547fe48046e..9b6131905ab 100644 --- a/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc @@ -13,10 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc index 09c97de13eb..1bad21e9d46 100644 --- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc @@ -15,23 +15,18 @@ limitations under the License. // XLA-specific Concat Ops. -#include #include -#include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index ff6c54e47c6..5f214431892 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -13,14 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc index b60a13972a7..6ae07f7e898 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc @@ -17,28 +17,32 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h" +#include +#include +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" #include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" -#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/arithmetic.h" -#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/kernel_shape_util.h" -#include "tensorflow/core/framework/node_def_util.h" -#include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/conv_grad_shape_utils.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h index 09829fb2767..0a242851689 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h @@ -18,12 +18,17 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" +#include "tensorflow/stream_executor/lib/statusor.h" // This header exposes utilities for translating TensorFlow convolution ops into // XLA ops. diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index e9cd5d2744e..e793cafce85 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -15,27 +15,19 @@ limitations under the License. // XLA-specific Ops for 2D convolution. +#include + #include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h" #include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/core/framework/bounds_check.h" -#include "tensorflow/core/framework/node_def_util.h" -#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/ops_util.h" -#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/tensor_slice.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/util/padding.h" -#include "tensorflow/core/util/tensor_format.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/cross_op.cc b/tensorflow/compiler/tf2xla/kernels/cross_op.cc index db579a5b35d..6acca398c2d 100644 --- a/tensorflow/compiler/tf2xla/kernels/cross_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cross_op.cc @@ -13,10 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include + #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc index a709a20c28b..79ba4d0a7c8 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc @@ -17,16 +17,19 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" +#include +#include +#include + +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/tf2xla/lib/broadcast.h" -#include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/bcast.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h index 516ead4bfe8..63b27f265e6 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h @@ -18,10 +18,17 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_CWISE_OPS_H_ #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_CWISE_OPS_H_ +#include +#include + +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/bcast.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc b/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc index fb89742b139..0cffa92aa8f 100644 --- a/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc @@ -13,15 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include -#include -#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc index df8bee7f6d5..3c1997edffe 100644 --- a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc @@ -13,12 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include + +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/lib/data_format.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/tensor_format.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/dequantize_op.cc index 7ac38369eb4..b780319f3c7 100644 --- a/tensorflow/compiler/tf2xla/kernels/dequantize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dequantize_op.cc @@ -14,16 +14,18 @@ limitations under the License. ==============================================================================*/ #include +#include +#include -#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc index d22516555d4..4e6803363b8 100644 --- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -13,18 +13,25 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/util.h" -#include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" -#include "tensorflow/compiler/xla/client/lib/pooling.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc index bb2c0d9ddb8..b540db7ac6e 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc @@ -13,17 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include -#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" - -#include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" -#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc index b119997cf39..636e3e4c006 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -15,17 +15,25 @@ limitations under the License. // XLA-specific dynamic stitch Op. +#include +#include + +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/einsum_op.cc b/tensorflow/compiler/tf2xla/kernels/einsum_op.cc index 028f5fa5f53..250533a5192 100644 --- a/tensorflow/compiler/tf2xla/kernels/einsum_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/einsum_op.cc @@ -14,13 +14,18 @@ limitations under the License. ==============================================================================*/ #include +#include +#include -#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/elu_op.cc b/tensorflow/compiler/tf2xla/kernels/elu_op.cc index f66b81620fa..af4970fef07 100644 --- a/tensorflow/compiler/tf2xla/kernels/elu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/elu_op.cc @@ -17,12 +17,12 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/elu_op.h" -#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/core/framework/kernel_def_builder.h" -#include "tensorflow/core/framework/types.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/framework/op_kernel.h" namespace xla { XlaOp Elu(XlaOp x) { diff --git a/tensorflow/compiler/tf2xla/kernels/elu_op.h b/tensorflow/compiler/tf2xla/kernels/elu_op.h index 80f8b6bd45f..0d684fa6709 100644 --- a/tensorflow/compiler/tf2xla/kernels/elu_op.h +++ b/tensorflow/compiler/tf2xla/kernels/elu_op.h @@ -15,7 +15,6 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_ELU_OP_H_ #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_ELU_OP_H_ -#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" namespace xla { diff --git a/tensorflow/compiler/tf2xla/kernels/empty_op.cc b/tensorflow/compiler/tf2xla/kernels/empty_op.cc index 00d2ce7c12f..4c5377d8f36 100644 --- a/tensorflow/compiler/tf2xla/kernels/empty_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/empty_op.cc @@ -15,15 +15,19 @@ limitations under the License. // XLA-specific Empty Op. +#include + #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/core/framework/kernel_def_builder.h" -#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc index 63e3f185421..d8a9c385af8 100644 --- a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc @@ -13,19 +13,28 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include + #include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h" #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/kernel_shape_util.h" -#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/fake_param_op.cc b/tensorflow/compiler/tf2xla/kernels/fake_param_op.cc index ec3463bd58f..1f1f41b72a6 100644 --- a/tensorflow/compiler/tf2xla/kernels/fake_param_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/fake_param_op.cc @@ -14,12 +14,15 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/macros.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc index 96f066d117c..29ec43c7eb3 100644 --- a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc @@ -13,12 +13,26 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + +#include +#include + +#include "absl/container/inlined_vector.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/core/platform/macros.h" +#include "tensorflow/compiler/xla/layout.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc index e5e4e797cc5..b507e4445e5 100644 --- a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc @@ -15,21 +15,22 @@ limitations under the License. // XLA-specific Ops for FFT. +#include +#include + +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/framework/bounds_check.h" -#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/ops_util.h" -#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/tensor_slice.h" -#include "tensorflow/core/util/padding.h" -#include "tensorflow/core/util/tensor_format.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc index 5e489b16919..8b4114cab36 100644 --- a/tensorflow/compiler/tf2xla/kernels/fill_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc @@ -15,14 +15,16 @@ limitations under the License. // XLA-specific Fill Op. -#include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include + #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/core/framework/kernel_def_builder.h" -#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/function_ops.cc b/tensorflow/compiler/tf2xla/kernels/function_ops.cc index 516e3aeaa88..0d386321e29 100644 --- a/tensorflow/compiler/tf2xla/kernels/function_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/function_ops.cc @@ -13,12 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/core/framework/kernel_def_builder.h" -#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/errors.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index e8a3dab4bed..d1729a9827a 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -13,22 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include +#include "absl/container/inlined_vector.h" #include "absl/types/optional.h" -#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h index 7bd25230d46..49be7d96f5d 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h +++ b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h @@ -19,10 +19,11 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_GATHER_OP_HELPERS_H_ #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/util/bcast.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc b/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc index 19aa85f9d42..883051e4fa2 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc @@ -13,11 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/identity_op.cc b/tensorflow/compiler/tf2xla/kernels/identity_op.cc index 38d8056d3e5..7eb0faa23bc 100644 --- a/tensorflow/compiler/tf2xla/kernels/identity_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/identity_op.cc @@ -16,6 +16,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/macros.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index 2a059f78526..2f6125a1f92 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -15,14 +15,34 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/if_op.h" -#include "tensorflow/compiler/tf2xla/const_analysis.h" +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/tf2xla/kernels/if_while_utils.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/side_effect_util.h" -#include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/tf2xla/xla_resource.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.h b/tensorflow/compiler/tf2xla/kernels/if_op.h index 3ac1b344ef8..1347ebb8678 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.h +++ b/tensorflow/compiler/tf2xla/kernels/if_op.h @@ -16,8 +16,15 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_IF_OP_H_ #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_IF_OP_H_ +#include +#include + #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/macros.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/if_while_utils.cc b/tensorflow/compiler/tf2xla/kernels/if_while_utils.cc index 82d8eb892df..15dc22f1990 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_while_utils.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_while_utils.cc @@ -15,7 +15,25 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/if_while_utils.h" +#include +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_expression.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/if_while_utils.h b/tensorflow/compiler/tf2xla/kernels/if_while_utils.h index 631fedd25f7..e1530bdf890 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_while_utils.h +++ b/tensorflow/compiler/tf2xla/kernels/if_while_utils.h @@ -16,8 +16,15 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_IF_WHILE_UTILS_H_ #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_IF_WHILE_UTILS_H_ +#include +#include + +#include "absl/container/inlined_vector.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/platform/status.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index e7bf343cd70..97fd9225bd4 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -13,9 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" #include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" -#include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -25,10 +30,14 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/loops.h" #include "tensorflow/compiler/xla/client/lib/sorting.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index 8e53ca162f5..54a57e78d1f 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -14,21 +14,34 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/kernels/image_resize_ops.h" +#include +#include + +#include +#include +#include + #include "absl/strings/str_format.h" #include "absl/types/span.h" #include "tensorflow/compiler/jit/xla_activity.pb.h" #include "tensorflow/compiler/jit/xla_activity_listener.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/core/framework/kernel_def_builder.h" -#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/math/math_util.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.h b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.h index b50b905a6e1..fc2ea2ae533 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.h @@ -16,7 +16,8 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_IMAGE_RESIZE_OPS_H_ #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/in_topk_op.cc b/tensorflow/compiler/tf2xla/kernels/in_topk_op.cc index 246d3f6da94..0f4b7ccca95 100644 --- a/tensorflow/compiler/tf2xla/kernels/in_topk_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/in_topk_op.cc @@ -13,19 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/sorting.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc index 219dc738eaa..0426163a8fd 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc @@ -17,17 +17,20 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/index_ops.h" +#include + #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/core/framework/bounds_check.h" -#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { XlaArgMinMaxOp::XlaArgMinMaxOp(OpKernelConstruction* ctx, bool is_min) diff --git a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc index 7f25d34c3ef..9c902085b50 100644 --- a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc @@ -13,12 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include + #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc index e46f4e72dc9..90f051ee456 100644 --- a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc @@ -16,16 +16,20 @@ limitations under the License. // XLA-specific ListDiff Op. This only supports constant DT_INT32 and DT_INT64 // input. +#include #include +#include -#include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/core/framework/kernel_def_builder.h" -#include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/lower_upper_bound_ops.cc b/tensorflow/compiler/tf2xla/kernels/lower_upper_bound_ops.cc index 0eacf8812f1..a5c418d60bb 100644 --- a/tensorflow/compiler/tf2xla/kernels/lower_upper_bound_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/lower_upper_bound_ops.cc @@ -13,15 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/type_util.h" +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/comparison_util.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc index 987901d82b3..11935284ec1 100644 --- a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc @@ -16,8 +16,12 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc index a3fcb4d4b8f..f9ab72f890f 100644 --- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc @@ -15,11 +15,16 @@ limitations under the License. // XLA-specific MatMul Op. -#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include + #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc index 2dd0a710e47..305cc515231 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc @@ -13,12 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc b/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc index 57e961917cc..f70715c4aea 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc @@ -13,15 +13,30 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" +#include "third_party/eigen3/Eigen/Core" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_inverse_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_inverse_op.cc index 8c625b476f3..7c9a9624891 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_inverse_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_inverse_op.cc @@ -13,10 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/lib/qr.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_solve_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_solve_op.cc index 8a4e71068b8..ca60e68c721 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_solve_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_solve_op.cc @@ -13,11 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/lib/qr.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc index 5a719484e05..fcac690d32f 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc @@ -13,13 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include + +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/tf2xla/lib/broadcast.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/bcast.h" #include "tensorflow/core/util/matmul_bcast.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc index 656f9b898f3..bb280d946a2 100644 --- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc @@ -13,11 +13,24 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include + #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/error_codes.pb.h" #include "tensorflow/core/util/mirror_pad_mode.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/next_after_op.cc b/tensorflow/compiler/tf2xla/kernels/next_after_op.cc index 0801c52500f..a3529015be0 100644 --- a/tensorflow/compiler/tf2xla/kernels/next_after_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/next_after_op.cc @@ -14,14 +14,10 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/lib/broadcast.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/math.h" -#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/util/bcast.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/no_op.cc b/tensorflow/compiler/tf2xla/kernels/no_op.cc index da50b75251b..e658cf1ef89 100644 --- a/tensorflow/compiler/tf2xla/kernels/no_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/no_op.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc b/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc index aba54578d97..38a92f61341 100644 --- a/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc @@ -15,10 +15,15 @@ limitations under the License. // XLA implementation of OneHot operator. -#include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/pack_op.cc b/tensorflow/compiler/tf2xla/kernels/pack_op.cc index 6ca100a2f2b..2715ec7d631 100644 --- a/tensorflow/compiler/tf2xla/kernels/pack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pack_op.cc @@ -15,22 +15,15 @@ limitations under the License. // XLA Pack operator. -#include #include -#include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/pad_op.cc b/tensorflow/compiler/tf2xla/kernels/pad_op.cc index 36ea70ac392..117df8d4bdb 100644 --- a/tensorflow/compiler/tf2xla/kernels/pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc @@ -13,14 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/core/framework/kernel_def_builder.h" -#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index 5f5cae8f176..3e4e66537ea 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -15,6 +15,12 @@ limitations under the License. // XLA specific pooling ops. +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" @@ -23,17 +29,22 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/pooling.h" +#include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/kernels/pooling_ops_common.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" +#include "tensorflow/stream_executor/lib/statusor.h" +#include "tensorflow/stream_executor/stream_executor.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/qr_op.cc b/tensorflow/compiler/tf2xla/kernels/qr_op.cc index 66ec40a946b..b639235b118 100644 --- a/tensorflow/compiler/tf2xla/kernels/qr_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/qr_op.cc @@ -13,9 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/qr.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc index e235f291ddb..9b252f87ef0 100644 --- a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc @@ -13,19 +13,29 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include +#include +#include +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 1ccf0b4b125..003f3375f75 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -17,21 +17,33 @@ limitations under the License. // TODO(misard,phawkins): handle random number generator seeds/states correctly. // TODO(misard,phawkins): add tests. +#include + +#include +#include + +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/lib/random.h" -#include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/comparators.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/loops.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops_util.h b/tensorflow/compiler/tf2xla/kernels/random_ops_util.h index 9a6dc37e2c9..92d253a5e78 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops_util.h +++ b/tensorflow/compiler/tf2xla/kernels/random_ops_util.h @@ -16,10 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_RANDOM_OPS_UTIL_H_ #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_RANDOM_OPS_UTIL_H_ -#include - +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/core/framework/types.pb.h" namespace tensorflow { // Returns a tensor containing 'shape' random values uniformly distributed in diff --git a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc index 8bd8edc9497..79ca63428b4 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc @@ -13,16 +13,27 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/kernels/while_op.h" +#include +#include +#include +#include #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/core/framework/function.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index 4f63c0d1b66..a50cf47a9b0 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -16,13 +16,18 @@ limitations under the License. // XLA-specific reduction Ops. #include "tensorflow/compiler/tf2xla/kernels/reduction_ops.h" -#include "tensorflow/compiler/tf2xla/type_util.h" + +#include + #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h index af716eab798..bebe68d2a54 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h @@ -18,9 +18,14 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_REDUCTION_OPS_H_ #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_REDUCTION_OPS_H_ +#include + #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index 2ca2a85244b..d63d9507990 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -15,6 +15,12 @@ limitations under the License. // XLA-specific reduction Ops. +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/kernels/reduction_ops.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -23,7 +29,15 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc index 3b53bacb524..e7332adee5d 100644 --- a/tensorflow/compiler/tf2xla/kernels/relu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc @@ -17,10 +17,15 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/relu_op.h" +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/types.h" namespace xla { XlaOp Relu(XlaOp x) { return Max(ScalarLike(x, 0), x); } diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.h b/tensorflow/compiler/tf2xla/kernels/relu_op.h index 7e4a3833bc5..7f814a52573 100644 --- a/tensorflow/compiler/tf2xla/kernels/relu_op.h +++ b/tensorflow/compiler/tf2xla/kernels/relu_op.h @@ -15,7 +15,6 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_RELU_OP_H_ #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_RELU_OP_H_ -#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" namespace xla { diff --git a/tensorflow/compiler/tf2xla/kernels/replica_id_op.cc b/tensorflow/compiler/tf2xla/kernels/replica_id_op.cc index 46585a26769..59f7f79f4e0 100644 --- a/tensorflow/compiler/tf2xla/kernels/replica_id_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/replica_id_op.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/macros.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc index f9985d52603..fa72710b4cc 100644 --- a/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc @@ -13,29 +13,26 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" -#include "tensorflow/core/lib/math/math_util.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index bf9a9150ea6..099e777196a 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -15,16 +15,16 @@ limitations under the License. // XLA-specific reshape Op. -#include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include + +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc index 058938a46db..45b5cb42246 100644 --- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc @@ -13,14 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc index 2ceadaf79c5..7a56f1e3780 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc @@ -15,15 +15,17 @@ limitations under the License. // XLA-specific reverse Op. -#include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include + +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc index 4d73469fb18..8095f1e683e 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc @@ -13,14 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/roll_op.cc b/tensorflow/compiler/tf2xla/kernels/roll_op.cc index 5908dbebc86..1d9e11ab201 100644 --- a/tensorflow/compiler/tf2xla/kernels/roll_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/roll_op.cc @@ -13,10 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/slicing.h" -#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc index 8431724f438..a7cf86de869 100644 --- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -13,24 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include #include -#include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/partial_tensor_shape.h" -#include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc index ce4a46b45c8..0b3b1e68257 100644 --- a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc @@ -13,18 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/tf2xla/lib/scatter.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/primitive_util.h" -#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc index 97359f81eee..ba67574ccc0 100644 --- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc @@ -13,13 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/tf2xla/lib/scatter.h" #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/select_op.cc b/tensorflow/compiler/tf2xla/kernels/select_op.cc index 70e4f96c0da..7cb6bdbb57f 100644 --- a/tensorflow/compiler/tf2xla/kernels/select_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/select_op.cc @@ -13,17 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/tf2xla/lib/broadcast.h" -#include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/core/framework/bounds_check.h" -#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/bcast.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc index 84470b230d4..802cbba487a 100644 --- a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc @@ -13,15 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/framework/kernel_def_builder.h" -#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/macros.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc index e8149d3714f..ba4123d834e 100644 --- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc @@ -15,18 +15,27 @@ limitations under the License. // XLA-specific sequence and range Ops. -#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include +#include + +#include +#include + #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 5d2b08f424c..63e37897914 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -15,19 +15,33 @@ limitations under the License. // XLA-specific Shape Ops. +#include + +#include +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/tf2xla/kernels/shape_util.h" #include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/bounds_check.h" -#include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/shape_util.cc b/tensorflow/compiler/tf2xla/kernels/shape_util.cc index b18e3f965c4..b0bef7a9d80 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_util.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_util.cc @@ -17,7 +17,14 @@ limitations under the License. #include +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/shape_util.h b/tensorflow/compiler/tf2xla/kernels/shape_util.h index ca57be3d47b..5d9c4dae740 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_util.h +++ b/tensorflow/compiler/tf2xla/kernels/shape_util.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_SHAPE_UTIL_H_ #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_SHAPE_UTIL_H_ -#include - #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/status.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/sharding_op.cc b/tensorflow/compiler/tf2xla/kernels/sharding_op.cc index f1ede35236a..60a32fd3dfa 100644 --- a/tensorflow/compiler/tf2xla/kernels/sharding_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/sharding_op.cc @@ -14,11 +14,11 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/macros.h" namespace tensorflow { namespace { @@ -31,11 +31,12 @@ class ShardingOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::XlaOp input = ctx->Input(0); - auto shape = - TensorShapeToXLAShape(ctx->input_xla_type(0), ctx->InputShape(0)); + auto shape_or = ctx->InputXlaShape(0); + OP_REQUIRES_OK(ctx, shape_or.status()); + ctx->SetOutput( 0, xla::CustomCall(ctx->builder(), /*call_target_name=*/"Sharding", - {input}, shape)); + {input}, shape_or.ValueOrDie())); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc index 17d0b87edda..18d2536c1f6 100644 --- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -15,20 +15,16 @@ limitations under the License. // XLA-specific Slice Op. -#include "absl/types/span.h" -#include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include + #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/ops_util.h" -#include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/mem.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc index 4f65c625d4f..c1569834a53 100644 --- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc @@ -15,6 +15,14 @@ limitations under the License. // XLA-specific Ops for softmax. +#include + +#include +#include +#include +#include +#include + #include "absl/strings/match.h" #include "tensorflow/compiler/tf2xla/lib/broadcast.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -24,10 +32,12 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/util/bcast.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc index 8cfd9850519..fde2500772d 100644 --- a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc @@ -17,6 +17,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/comparators.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc index 52bed2670b4..86824526adb 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc @@ -13,10 +13,27 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc index b72f33e2b7c..3a095909495 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc @@ -13,12 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include + +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/lib/data_format.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/tensor_format.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc b/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc index ff7f0ac6255..0b3cc7a5e7a 100644 --- a/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc @@ -13,9 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/tf2xla/lib/scatter.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index 7a0e240400b..eb69cbb4ae9 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -15,16 +15,16 @@ limitations under the License. // XLA-specific Ops for split. -#include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include + +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index a93d137e965..a14443a96c0 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -15,24 +15,31 @@ limitations under the License. // XLA Stack operators. -#include +#include #include +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/compiler/tf2xla/xla_resource.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/partial_tensor_shape.h" -#include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/error_codes.pb.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc index 46d4b70606e..2d2bd85f120 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc @@ -15,24 +15,32 @@ limitations under the License. #include "tensorflow/core/kernels/stateful_random_ops.h" -#include +#include +#include +#include #include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h" #include "tensorflow/compiler/tf2xla/lib/random.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/prng.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/lib/math/math_util.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index 13c3dbe489e..5ba189552e0 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -13,24 +13,31 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include +#include +#include +#include -#include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/tf2xla/lib/random.h" #include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/prng.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/lib/math/math_util.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 9093175af75..1c4d6855e85 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -15,21 +15,28 @@ limitations under the License. #include "tensorflow/core/util/strided_slice_op.h" -#include "absl/types/span.h" +#include +#include + +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/tf2xla/literal_util.h" -#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/ops_util.h" -#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/mem.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index b98b98ce50a..0e649c22a35 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -15,27 +15,35 @@ limitations under the License. // XLA TensorArray operators. -#include +#include #include +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_resource.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/partial_tensor_shape.h" -#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/error_codes.pb.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc index 4af3d4233dd..bf710e359da 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc @@ -15,30 +15,30 @@ limitations under the License. // XLA TensorList operators. -#include #include +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/partial_tensor_shape.h" -#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc index 3db7bff0bc6..1a92dbb1956 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc @@ -15,16 +15,24 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" +#include + +#include "tensorflow/compiler/tf2xla/xla_expression.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/error_codes.pb.h" +#include "tensorflow/stream_executor/lib/statusor.h" // TensorList is represented by a tuple. // - The first part of the tuple is a buffer containing all the tensors, diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h index 7fac2d9dbab..6508aacb48d 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h @@ -18,7 +18,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc index e8804cae037..da414827e87 100644 --- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc @@ -16,21 +16,17 @@ limitations under the License. // XLA-specific Tile Op. #include + #include "absl/algorithm/container.h" -#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/lib/broadcast.h" -#include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/type_index.h" -#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc index 22cfd160088..574b96b7e14 100644 --- a/tensorflow/compiler/tf2xla/kernels/topk_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc @@ -17,9 +17,11 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/sorting.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/core/framework/kernel_def_builder.h" -#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index c288d613e29..7e86bf762d6 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -13,15 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" +#include + #include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/core/framework/kernel_def_builder.h" -#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc index 65569576d41..b35e7937f11 100644 --- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc @@ -18,15 +18,26 @@ limitations under the License. // handles all transposes, while Eigen needs a restricted DoTranspose // helper. +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/tf2xla/lib/scatter.h" -#include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_expression.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/bounds_check.h" -#include "tensorflow/core/framework/kernel_def_builder.h" -#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/tridiagonal_ops.cc b/tensorflow/compiler/tf2xla/kernels/tridiagonal_ops.cc index 7ce2dd060f1..fe9bd3c8f14 100644 --- a/tensorflow/compiler/tf2xla/kernels/tridiagonal_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tridiagonal_ops.cc @@ -13,13 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/lib/tridiagonal.h" #include "tensorflow/core/framework/node_def_util.h" -#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index 83a894e91fe..d09a7d71388 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -15,16 +15,14 @@ limitations under the License. // Native XLA implementations of simple unary Ops -#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" -#include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops_composition.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops_composition.cc index 3c992ee8407..089cefc4a5b 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops_composition.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops_composition.cc @@ -13,20 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include +#include +#include + #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" -#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" #include "tensorflow/compiler/tf2xla/kernels/elu_op.h" #include "tensorflow/compiler/tf2xla/kernels/relu_op.h" -#include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/errors.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc index 2d95f2f30a8..475fd056f27 100644 --- a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc @@ -15,22 +15,15 @@ limitations under the License. // XLA Unpack operator. -#include #include -#include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index 60424f85840..f39f2c1508c 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -13,17 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include +#include + #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/kernels/shape_util.h" #include "tensorflow/compiler/tf2xla/lib/scatter.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/slicing.h" +#include "tensorflow/compiler/tf2xla/xla_resource.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/core/framework/kernel_def_builder.h" -#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index 21568a196ba..4b4cc21c34f 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -15,24 +15,37 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/while_op.h" -#include "absl/strings/str_split.h" -#include "tensorflow/compiler/tf2xla/const_analysis.h" +#include +#include +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/tf2xla/kernels/if_while_utils.h" #include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/side_effect_util.h" -#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/tf2xla/xla_resource.h" +#include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/attr_value.pb.h" -#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.h b/tensorflow/compiler/tf2xla/kernels/while_op.h index bae187ca3ff..538101656f5 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.h +++ b/tensorflow/compiler/tf2xla/kernels/while_op.h @@ -16,8 +16,12 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_WHILE_OP_H_ #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_WHILE_OP_H_ +#include +#include + #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { @@ -62,8 +66,6 @@ class XlaWhileOp : public XlaOpKernel { // This is not supported by default now since it may cause HBM memory // overheads. bool propagate_compile_time_consts_ = false; - - TF_DISALLOW_COPY_AND_ASSIGN(XlaWhileOp); }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc index ad8e707e111..3a1f0a0332a 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc @@ -13,16 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include + #include "absl/algorithm/container.h" #include "absl/strings/str_join.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc index 7a8aec295a6..0e48f468a4d 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc @@ -13,13 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include +#include +#include + #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/xla_dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_dequantize_op.cc index a30b4861f6b..c1f6c7f2d41 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_dequantize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_dequantize_op.cc @@ -13,14 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include + #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/quantize.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc index 40b15b5579a..cea385ee7fc 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc @@ -13,13 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include + #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc index a3c2eef993c..13526880a41 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc @@ -13,14 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "absl/algorithm/container.h" #include "absl/strings/str_join.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc index 8b481d55a80..59aa8910a85 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc @@ -13,15 +13,26 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include +#include +#include + #include "absl/algorithm/container.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/core/framework/function.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/xla_select_and_scatter_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_select_and_scatter_op.cc index 7eaab3477c5..4ec8cfa03db 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_select_and_scatter_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_select_and_scatter_op.cc @@ -13,16 +13,28 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/kernels/while_op.h" +#include +#include +#include +#include #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/core/framework/function.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/xla_self_adjoint_eig_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_self_adjoint_eig_op.cc index 233ac8e7b45..669e71f4f52 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_self_adjoint_eig_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_self_adjoint_eig_op.cc @@ -13,10 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/bits.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/xla_svd_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_svd_op.cc index 8e9ed35783f..8195c9bdf06 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_svd_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_svd_op.cc @@ -13,11 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include +#include + #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/lib/svd.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/bits.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index a1c45a4bf30..a394de1a9e8 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -94,6 +94,15 @@ TensorShape XlaOpKernelContext::InputShape(absl::string_view name) { return GetInputTensorByName(name).shape(); } +xla::StatusOr XlaOpKernelContext::InputXlaShape(int index) { + return builder()->GetShape(Input(index)); +} + +xla::StatusOr XlaOpKernelContext::InputXlaShape( + absl::string_view name) { + return builder()->GetShape(Input(name)); +} + DataType XlaOpKernelContext::input_type(int index) const { DataType type = context_->input_dtype(index); if (type == DT_UINT8) { diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index d72dd3972d3..8a384399e19 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -99,6 +99,10 @@ class XlaOpKernelContext { // Returns input `name` as a XlaOp. xla::XlaOp Input(absl::string_view name); + // Returns the xla input shape for a given index. + xla::StatusOr InputXlaShape(int index); + xla::StatusOr InputXlaShape(absl::string_view name); + // Returns true if all inputs are the same shape, otherwise sets the // status to a non-OK value and returns false. // Usage: if (!context->ValidateInputsAreSameShape(this)) return; diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index 9808dd3d092..8604531889e 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -535,6 +535,16 @@ static void AllocateFlags() { bool_setter_for(&DebugOptions::set_xla_gpu_deterministic_reductions), flag_values->xla_gpu_deterministic_reductions(), "Always run deterministic reductions on GPU"), + tensorflow::Flag( + "xla_tpu_detect_nan", + bool_setter_for(&DebugOptions::set_xla_tpu_detect_nan), + flag_values->xla_tpu_detect_nan(), + "Trigger error on execution on TPU if a NAN value is detected"), + tensorflow::Flag( + "xla_tpu_detect_inf", + bool_setter_for(&DebugOptions::set_xla_tpu_detect_inf), + flag_values->xla_tpu_detect_inf(), + "Trigger error on execution on TPU if a INF value is detected"), }); ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", *flag_objects); } diff --git a/tensorflow/compiler/xla/g3doc/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md index 650659e4d46..495701eaac2 100644 --- a/tensorflow/compiler/xla/g3doc/operation_semantics.md +++ b/tensorflow/compiler/xla/g3doc/operation_semantics.md @@ -38,15 +38,18 @@ Performs a custom computation across replicas. `AllReduce(operand, computation, replica_group_ids, channel_id)` -| Arguments | Type | Semantics | -| ---------------- | -------------------- | -------------------------------- | -| `operand` | `XlaOp` | Array to reduce across replicas. | -| `computation` | `XlaComputation` | Reduction computation | -| `replica_groups` | vector of vectors of | Groups between which the | -: : `int64` : reductions are performed : -| `channel_id` | optional `int64` | Optional channel ID for | -: : : cross-module communication : +| Arguments | Type | Semantics | +| ---------------- | -------------------- | --------------------------------- | +| `operand` | `XlaOp` | Array or a non-empty tuple of | +: : : arrays to reduce across replicas. : +| `computation` | `XlaComputation` | Reduction computation | +| `replica_groups` | vector of vectors of | Groups between which the | +: : `int64` : reductions are performed : +| `channel_id` | optional `int64` | Optional channel ID for | +: : : cross-module communication : +- When `operand` is a tuple of arrays, the all-reduce is performed on each + element of the tuple. - `replica_groups` is a list of replica groups between which the reduction is performed (replica id for the current replica can be retrieved using [`ReplicaId`](#replicaid)). `replica_groups` must either be empty (in which @@ -60,7 +63,8 @@ Performs a custom computation across replicas. The output shape is the same as the input shape. For example, if there are two replicas and the operand has the value `[1.0, 2.5]` and `[3.0, 5.25]` respectively on the two replicas, then the output value from this op and -summation computation will be `[4.0, 7.75]` on both replicas. +summation computation will be `[4.0, 7.75]` on both replicas. If the input is a +tuple, the output is a tuple as well. Computing the result of `AllReduce` requires having one input from each replica, so if one replica executes a `AllReduce` node more times than another, then the diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 3c93ec96113..7a960eae165 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -87,6 +87,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", "//tensorflow/core:stream_executor", "@com_google_absl//absl/memory", "@com_google_absl//absl/synchronization", @@ -123,12 +124,17 @@ cc_library( hdrs = ["shared_device_buffer.h"], deps = [ ":event_pool", + ":local_device_state", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:transfer_manager", + "//tensorflow/core:lib", "//tensorflow/stream_executor:device_memory", "//tensorflow/stream_executor:device_memory_allocator", + "//tensorflow/stream_executor:event", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/synchronization", ], ) @@ -144,6 +150,8 @@ tf_cc_test( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/core:test_main", + "//tensorflow/stream_executor:device_memory", + "//tensorflow/stream_executor:device_memory_allocator", ], ) @@ -160,6 +168,7 @@ cc_library( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/core:lib", "//tensorflow/core:stream_executor", + "//tensorflow/stream_executor:event", "@com_google_absl//absl/memory", "@com_google_absl//absl/synchronization", ], @@ -171,6 +180,7 @@ cc_library( hdrs = ["local_client.h"], visibility = ["//tensorflow/compiler/xla:friends"], deps = [ + ":event_pool", ":local_device_state", ":shared_device_buffer", "//tensorflow/compiler/xla:executable_run_options", @@ -184,13 +194,18 @@ cc_library( "//tensorflow/compiler/xla/client:executable_build_options", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/python/distributed:protocol_proto_cc", "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options", "//tensorflow/core:allocator", "//tensorflow/core:lib", "//tensorflow/core/profiler/lib:traceme", + "//tensorflow/stream_executor:event", + "//tensorflow/stream_executor:stream", + "//tensorflow/stream_executor/lib", "@com_google_absl//absl/base", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -307,8 +322,8 @@ cc_library( "//tensorflow/compiler/xla/python/distributed:client", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla:util", - "//tensorflow/core:bfc_allocator", - "//tensorflow/core:gpu_mem_allocator", + "//tensorflow/core/common_runtime:bfc_allocator", + "//tensorflow/core/common_runtime/gpu:gpu_mem_allocator", "//tensorflow/stream_executor:tf_allocator_adapter", ] + if_cuda(["@local_config_nccl//:nccl"]), ) diff --git a/tensorflow/compiler/xla/python/cpu_device.cc b/tensorflow/compiler/xla/python/cpu_device.cc index 6b55eac0c08..404d9ca133d 100644 --- a/tensorflow/compiler/xla/python/cpu_device.cc +++ b/tensorflow/compiler/xla/python/cpu_device.cc @@ -42,7 +42,7 @@ StatusOr> GetCpuClient(bool asynchronous) { se::StreamExecutor* executor = client->backend().stream_executor(i).ValueOrDie(); auto device_state = absl::make_unique( - executor, client, /*synchronous_deallocation=*/true, asynchronous, + executor, client, LocalDeviceState::kSynchronous, asynchronous, /*allow_event_reuse=*/false); auto device = absl::make_unique(i, std::move(device_state)); devices.push_back(std::move(device)); diff --git a/tensorflow/compiler/xla/python/distributed/BUILD b/tensorflow/compiler/xla/python/distributed/BUILD index b38084c3395..5cada95390c 100644 --- a/tensorflow/compiler/xla/python/distributed/BUILD +++ b/tensorflow/compiler/xla/python/distributed/BUILD @@ -1,4 +1,5 @@ load("//tensorflow/core/platform:build_config.bzl", "tf_proto_library_cc") +load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency") load("//tensorflow:tensorflow.bzl", "tf_cc_test") licenses(["notice"]) @@ -24,11 +25,11 @@ cc_library( srcs = ["key_value_store.cc"], hdrs = ["key_value_store.h"], deps = [ - "//tensorflow:grpc++", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", + tf_grpc_cc_dependency(), ], ) @@ -73,11 +74,11 @@ cc_library( ":protocol", ":protocol_proto_cc", ":util", - "//tensorflow:grpc++", "//tensorflow/compiler/xla:statusor", "//tensorflow/core:lib", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", + tf_grpc_cc_dependency(), ], ) @@ -85,8 +86,8 @@ cc_library( name = "util", hdrs = ["util.h"], deps = [ - "//tensorflow:grpc++", "//tensorflow/compiler/xla:status", + tf_grpc_cc_dependency(), ], ) @@ -97,8 +98,8 @@ cc_library( deps = [ ":client", ":service", - "//tensorflow:grpc++", "//tensorflow/compiler/xla:statusor", + tf_grpc_cc_dependency(), ], ) @@ -109,7 +110,6 @@ tf_cc_test( ":client", ":protocol_proto_cc", ":service", - "//tensorflow:grpc++", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/service:cpu_plugin", @@ -118,5 +118,6 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "@com_google_absl//absl/time", + tf_grpc_cc_dependency(), ], ) diff --git a/tensorflow/compiler/xla/python/dlpack.cc b/tensorflow/compiler/xla/python/dlpack.cc index 4ac992011f1..103d2ba5a59 100644 --- a/tensorflow/compiler/xla/python/dlpack.cc +++ b/tensorflow/compiler/xla/python/dlpack.cc @@ -241,7 +241,16 @@ StatusOr DeviceForDLContext(const PyLocalClient& client, StatusOr BufferToDLPackManagedTensor(PyLocalBuffer* buffer) { auto pack = absl::make_unique(); - pack->buffer = buffer->DeviceBuffer(); + // Block on outstanding operations, so that it is safe to read or mutate the + // returned buffer. + StatusOr> buffer_or = + buffer->Release(/*wait_for_operations_to_complete=*/true); + if (!buffer_or.ok()) { + return InvalidArgument( + "Buffer synchronization failed converting to DLPack tensor: %s", + buffer_or.status().ToString()); + } + pack->buffer = buffer_or.ConsumeValueOrDie(); if (!pack->buffer) { return InvalidArgument( "Cannot convert deleted/invalid buffer to DLPack tensor."); @@ -281,8 +290,6 @@ StatusOr BufferToDLPackManagedTensor(PyLocalBuffer* buffer) { PyErr_Clear(); } }); - - TF_RETURN_IF_ERROR(buffer->BlockHostUntilReady()); return capsule; } @@ -330,9 +337,8 @@ StatusOr> DLPackManagedTensorToBuffer( absl::Span> definition_events; auto device_buffer = std::make_shared( /*allocator=*/nullptr, dlmt->dl_tensor.ctx.device_id, - std::initializer_list{buffer}, - /*children=*/std::vector>{}, - definition_events, std::move(on_delete_callback)); + std::initializer_list{buffer}, definition_events, + std::move(on_delete_callback)); // We have taken ownership of the array inside the capsule; make sure the // capsule it cannot be used again. diff --git a/tensorflow/compiler/xla/python/event_pool.cc b/tensorflow/compiler/xla/python/event_pool.cc index 4edb41fd41f..c7b52f523d9 100644 --- a/tensorflow/compiler/xla/python/event_pool.cc +++ b/tensorflow/compiler/xla/python/event_pool.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/python/event_pool.h" #include "absl/memory/memory.h" +#include "absl/synchronization/mutex.h" #include "tensorflow/compiler/xla/status_macros.h" namespace xla { @@ -27,7 +28,8 @@ EventPool::Handle::~Handle() { } } -EventPool::EventPool(bool allow_reuse) : allow_reuse_(allow_reuse) {} +EventPool::EventPool(bool allow_reuse) + : allow_reuse_(allow_reuse), next_sequence_number_(0) {} StatusOr EventPool::ThenAllocateAndRecordEvent( se::Stream* stream) { @@ -45,7 +47,11 @@ StatusOr EventPool::ThenAllocateAndRecordEvent( event.event_ = absl::make_unique(stream->parent()); TF_RET_CHECK(event.event_->Init()) << "Event initialization failed"; } - stream->ThenRecordEvent(event.event_.get()); + { + absl::MutexLock lock(&mu_); + stream->ThenRecordEvent(event.event_.get()); + event.sequence_number_ = next_sequence_number_++; + } return event; } diff --git a/tensorflow/compiler/xla/python/event_pool.h b/tensorflow/compiler/xla/python/event_pool.h index f858b5edef8..bda3fb6baff 100644 --- a/tensorflow/compiler/xla/python/event_pool.h +++ b/tensorflow/compiler/xla/python/event_pool.h @@ -38,13 +38,26 @@ class EventPool { Handle& operator=(const Handle&) = delete; Handle& operator=(Handle&&) = default; + // There is a total order on events handed out by the event pool. The most + // useful aspect of this total order is that two events returned by + // ThenAllocateAndRecordEvent on the same stream can be compared to see + // which was recorded earlier on that stream. + inline bool operator<(const Handle& rhs) const { + return sequence_number_ < rhs.sequence_number_; + } + inline bool operator>(const Handle& rhs) const { return rhs < *this; } + inline bool operator<=(const Handle& rhs) const { return !(*this > rhs); } + inline bool operator>=(const Handle& rhs) const { return !(*this < rhs); } + se::Event* event() const { return event_.get(); } + uint64 sequence_number() const { return sequence_number_; } private: friend class EventPool; EventPool* pool_ = nullptr; std::unique_ptr event_; + uint64 sequence_number_; }; // Initializes a new EventPool. If `allow_reuse` is true, then events will be @@ -69,6 +82,7 @@ class EventPool { absl::Mutex mu_; std::stack> free_events_ TF_GUARDED_BY(mu_); + uint64 next_sequence_number_ TF_GUARDED_BY(mu_); }; } // namespace xla diff --git a/tensorflow/compiler/xla/python/local_client.cc b/tensorflow/compiler/xla/python/local_client.cc index 128661ae8bd..1f346d4f4cc 100644 --- a/tensorflow/compiler/xla/python/local_client.cc +++ b/tensorflow/compiler/xla/python/local_client.cc @@ -20,7 +20,7 @@ limitations under the License. // // Computations and host-to-device transfers do not need to block the host // waiting for the operation to complete but instead return control to the host -// immediately. This allows Python logic to overlap with device-side +// immediately. This allows client logic to overlap with device-side // computation. // // For a good user experience, we must be careful only to enqueue operations @@ -54,50 +54,50 @@ limitations under the License. // // Synchronization between streams occurs via BufferDefinitionEvents that // describe when the contents of a logical buffer are known to be valid on -// a particular stream. +// a particular stream, and when a buffer's uses have all completed. // // Synchronous vs asynchronous deallocation: // ----------------------------------------- // -// In asynchronous deallocation mode (currently only enabled on TPU), the client -// need only keep buffers alive from its perspective until all operations that -// touch those buffers have been enqueued. -// The allocator and lower-level runtime is responsible for keeping buffers -// alive (if that is needed) from the perspective of the device until any -// device-side work actually completes. The client's use of the device allocator -// thereby corresponds to a view of the tail of the compute stream instead of -// its head. -// -// In synchronous deallocation mode the client is responsible for keeping -// buffers alive until all device-side activity that consumes those buffers has -// ceased. This is the case for CPU since HostExecutor performs allocation -// and deallocation eagerly. In this mode, the client's use of the device -// allocator is logically synchronized to the head of the compute stream, not -// the tail. +// See the comment on LocalDeviceState::AllocationModel for a discussion of the +// different allocation semantics on CPU, GPU, and TPU. #include "tensorflow/compiler/xla/python/local_client.h" +#include #include #include #include #include "absl/base/casts.h" +#include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" #include "absl/types/span.h" +#include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/python/distributed/protocol.pb.h" +#include "tensorflow/compiler/xla/python/event_pool.h" +#include "tensorflow/compiler/xla/python/local_device_state.h" #include "tensorflow/compiler/xla/python/shared_device_buffer.h" +#include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/lib/traceme.h" +#include "tensorflow/stream_executor/device_memory.h" +#include "tensorflow/stream_executor/device_memory_allocator.h" +#include "tensorflow/stream_executor/event.h" +#include "tensorflow/stream_executor/lib/statusor.h" +#include "tensorflow/stream_executor/stream.h" namespace xla { @@ -193,22 +193,231 @@ StatusOr PyLocalClient::GetDefaultDeviceAssignment( num_partitions); } +namespace { + +// Ensures that it is safe to deallocate any buffers that have been enqueued in +// an operation on stream. Called only in rare error cases that are triggered +// during enqueue. These cases generally correspond to resource exhaustion. +void StallStreamOnError(LocalDeviceState* local_device, se::Stream* stream) { + switch (local_device->allocation_model()) { + case LocalDeviceState::kAsynchronous: + // We can safely deallocate any dangling buffers immediately. NOTE: this + // assumes that any buffers enqueued on stream are local to stream's + // executor, and manual action may be needed if that condition is not met. + break; + + case LocalDeviceState::kComputeSynchronized: + // This will stall computation but that's ok in this very rare error + // case. + if (stream != local_device->compute_stream()) { + local_device->compute_stream()->ThenWaitFor(stream); + } + break; + + case LocalDeviceState::kSynchronous: + // This will stall the calling thread but that's ok in this very rare + // error case. If the stall fails just crash, since we have no other + // way to synchronize. + TF_CHECK_OK(stream->BlockHostUntilDone()); + break; + } +} + +// Does all necessary bookkeeping, after a buffer is successfully enqueued onto +// a stream, to ensure that the buffer will be kept alive until its use on that +// stream is complete. +// +// device_buffer: the buffer that was enqueued. +// buffer_local_device: the device the buffer was allocated on. +// stream_local_device: the device that manages usage_stream. +// event: an event that was recorded on usage_stream +// after the usage of device_buffer was enqueued. +// usage_stream: the stream the operation using device_buffer +// was enqueued on. +// prefer_to_retain_reference: relevant only for the compute synchronous +// allocation model. If true, retain a reference +// to device_buffer until after the operation +// completes. If false then the compute stream +// will have to be synchronized past event before +// device_buffer can be freed. +// +// prefer_to_retain_reference encodes a heuristic set by the caller for the +// compute synchronous model: +// +// Generally when a buffer is the destination of a copy to a device, it will +// subsequently be used on the device's compute stream before being freed. In +// that case, there is no need to retain a reference to the buffer. If the +// buffer is freed before being used on the compute stream, the free will be +// delayed until the host knows that event has completed, but this is expected +// to be uncommon. +// +// When a buffer is the source of a copy from a device, we need to either retain +// a reference to the buffer until the copy completes or serialize the compute +// stream behind the copy. It is often better to retain a reference since while +// that keeps memory alive longer, it avoids stalling the compute stream. +void RecordUsage(SharedDeviceBuffer::ScopedUsage device_buffer, + LocalDeviceState* buffer_local_device, + LocalDeviceState* stream_local_device, + std::shared_ptr event, + se::Stream* usage_stream, bool prefer_to_retain_reference) { + bool retain_buffer_until_completion = + // If the buffer wasn't allocated on the same device as the stream, always + // retain a reference. + (stream_local_device != buffer_local_device) || + // In the synchronous allocation model, always retain a reference. + (stream_local_device->allocation_model() == + LocalDeviceState::kSynchronous) || + // In the compute synchronous model, use the caller's heuristic. + (stream_local_device->allocation_model() == + LocalDeviceState::kComputeSynchronized && + prefer_to_retain_reference); + if (retain_buffer_until_completion) { + buffer_local_device->ThenRelease(usage_stream, + device_buffer.buffer_reference()); + } + device_buffer.Convert(usage_stream, event, retain_buffer_until_completion); +} + +// Allocates the device buffers for a buffer that will be used as the +// destination of a copy, either from the host or another device. copy_stream +// may be nullptr, e.g., when allocating a buffer for a cross-host copy. If the +// buffer is a tuple then the tuple tables are allocated, and all necessary +// synchronization for them is dealt with, before the buffer is returned. +// +// It is safe to delete the returned PyLocalBuffer without further +// synchronization if an error occurs before the buffer is used. +StatusOr> AllocateDestinationBuffer( + const Shape& on_host_shape, Device* device, LocalDeviceState* local_device, + se::Stream* copy_stream, PyLocalClient* client) { + if (on_host_shape.IsTuple() && on_host_shape.tuple_shapes_size() == 0) { + return InvalidArgument("Can't make a buffer from an empty tuple"); + } + + TransferManager* transfer_manager = + client->client()->backend().transfer_manager(); + TF_ASSIGN_OR_RETURN( + ScopedShapedBuffer dst_buffer, + transfer_manager->AllocateScopedShapedBuffer( + on_host_shape, client->allocator(), local_device->device_ordinal())); + if (local_device->allocation_model() == + LocalDeviceState::kComputeSynchronized) { + CHECK(copy_stream != nullptr); + copy_stream->ThenWaitFor(local_device->compute_stream()); + } else { + DCHECK(transfer_manager->CanShapedBufferBeAccessedNow( + local_device->compute_stream()->parent(), dst_buffer)); + } + Shape on_device_shape = dst_buffer.on_device_shape(); + + absl::InlinedVector, 2> + definition_events; + // We always have at least one definition event, for the copy completing to + // the device buffers. + definition_events.emplace_back(std::make_shared()); + se::Stream* tuple_table_stream = local_device->host_to_device_stream(); + if (on_device_shape.IsTuple()) { + // We also need to copy the tuple tables, so we'll have a second defintion + // event for that copy to complete. + if (tuple_table_stream != copy_stream) { + if (local_device->allocation_model() == + LocalDeviceState::kComputeSynchronized) { + tuple_table_stream->ThenWaitFor(local_device->compute_stream()); + } else { + DCHECK(transfer_manager->CanShapedBufferBeAccessedNow( + local_device->compute_stream()->parent(), dst_buffer)); + } + } + + TF_RETURN_IF_ERROR(transfer_manager->WriteTupleIndexTablesAsync( + tuple_table_stream, dst_buffer)); + // CAUTION: From this point onwards we need to be careful about returning + // from error cases because we have started a transfer and must not allow + // dst_buffer to be freed too soon in the non-async allocation models. + + definition_events.emplace_back(std::make_shared()); + StatusOr event_or = + local_device->event_pool().ThenAllocateAndRecordEvent( + tuple_table_stream); + if (!event_or.ok()) { + StallStreamOnError(local_device, tuple_table_stream); + return event_or.status(); + } + definition_events[1]->SetDefinitionEvent(event_or.ConsumeValueOrDie(), + tuple_table_stream); + } + std::shared_ptr dst_device_buffer = + SharedDeviceBuffer::FromScopedShapedBuffer(&dst_buffer, + definition_events); + if (on_device_shape.IsTuple()) { + // Add a usage hold for the tuple table write and immediately convert it to + // the appropriate form of synchronization. prefer_to_retain_reference=false + // means don't retain a memory reference until the transfer is complete when + // using the ComputeSynchronized allocation model. This is a heuristic + // because in the common case destination buffers will be used on the + // compute stream and therefore don't require any synchronization before + // being freed. If the buffer is allocated and never used, the free will + // take longer and this is assumed to be ok. + RecordUsage( + std::move(SharedDeviceBuffer::ScopedUsage().Acquire(dst_device_buffer)), + local_device, local_device, definition_events[1], tuple_table_stream, + /*prefer_to_retain_reference=*/false); + } + + return absl::make_unique(on_host_shape, on_device_shape, + std::move(dst_device_buffer), client, + device); +} + +// Adds necessary synchronization after a copy has been enqueued to a buffer. +// definition_event was added when the buffer was allocated, but has not yet +// had an event recorded. +Status AddDestinationBufferSynchronization( + LocalDeviceState* local_device, + SharedDeviceBuffer::ScopedUsage device_buffer, + std::shared_ptr definition_event, + se::Stream* copy_stream) { + StatusOr event_or = + local_device->event_pool().ThenAllocateAndRecordEvent(copy_stream); + if (!event_or.ok()) { + StallStreamOnError(local_device, copy_stream); + return event_or.status(); + } + definition_event->SetDefinitionEvent(event_or.ConsumeValueOrDie(), + copy_stream); + // prefer_to_retain_reference=false means don't retain a memory reference + // until the transfer is complete when using the ComputeSynchronized + // allocation model. This is a heuristic because in the common case + // destination buffers will be used on the compute stream and therefore don't + // require any synchronization before being freed. If the buffer is allocated + // and never used, the free will take longer and this is assumed to be ok. + RecordUsage(std::move(device_buffer), local_device, local_device, + definition_event, copy_stream, + /*prefer_to_retain_reference=*/false); + return Status::OK(); +} + +} // namespace + /* static */ StatusOr> PyLocalBuffer::FromHostBuffer( const void* data, const Shape& shape, bool force_copy, std::shared_ptr buffer_reference, PyLocalClient* client, Device* device) { - tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromLiterals"); - VLOG(2) << "PyLocalBuffer::FromLiterals: shape: " << shape.ToString() + tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromHostBuffer"); + VLOG(2) << "PyLocalBuffer::FromHostBuffer: shape: " << shape.ToString() << " device: " << device->DebugString(); + if (shape.IsTuple()) { + return InvalidArgument("Use FromHostLiteral to transfer a tuple"); + } TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, device->GetLocalDeviceState()); // If we are on the host platform and the input buffer is sufficiently - // aligned, we can simply point to the NumPy array's data without any further + // aligned, we can simply point to the input array's data without any further // copies. We require a 64-byte alignment because XLA may generate AVX512 - // code which requires it. Unfortunately NumPy's allocator doesn't align - // quite as aggressively, so there's a high chance this test will fail. + // code which requires it. If the client allocator doesn't align quite as + // aggressively, (e.g., NumPy doesn't) there's a high chance this test will + // fail. static constexpr int kMinimumAlignment = 64; if (!force_copy && ((absl::bit_cast(data) & (kMinimumAlignment - 1)) == 0) && @@ -222,57 +431,50 @@ StatusOr> PyLocalBuffer::FromHostBuffer( absl::Span> definition_events; auto device_buffer = std::make_shared( /*allocator=*/nullptr, local_device->device_ordinal(), - std::initializer_list{buffer}, - /*children=*/std::vector>{}, - definition_events, std::move(on_delete_callback)); + std::initializer_list{buffer}, definition_events, + std::move(on_delete_callback)); return absl::make_unique( shape, shape, std::move(device_buffer), client, device); } TransferManager* transfer_manager = client->client()->backend().transfer_manager(); - se::DeviceMemoryAllocator* allocator = client->allocator(); TF_ASSIGN_OR_RETURN(Shape compact_shape, transfer_manager->ChooseCompactLayoutForShape(shape)); TF_ASSIGN_OR_RETURN( - ScopedShapedBuffer scoped_buffer, - transfer_manager->AllocateScopedShapedBuffer( - compact_shape, allocator, local_device->device_ordinal())); + std::unique_ptr py_buffer, + AllocateDestinationBuffer(compact_shape, device, local_device, + local_device->host_to_device_stream(), client)); - // Make the host to device stream wait for the newly allocated buffer to be - // available on the compute stream. We schedule this wait synchronously; while - // not strictly necessary, we must not create stream dependency cycles, and - // adding the wait synchronously avoids any chance of any dependent - // computations that depend on this transfer being enqueued on the compute - // stream. - if (!transfer_manager->CanShapedBufferBeAccessedNow( - local_device->host_to_device_stream()->parent(), scoped_buffer)) { - local_device->host_to_device_stream()->ThenWaitFor( - local_device->compute_stream()); - } + SharedDeviceBuffer::ScopedUsage device_buffer( + py_buffer->GetBufferWithUsageHold()); + CHECK(device_buffer.IsValid()); - std::shared_ptr definition_event = - std::make_shared(); - std::shared_ptr device_buffer = - SharedDeviceBuffer::FromScopedShapedBuffer(&scoped_buffer, - {definition_event}); - Shape on_device_shape = scoped_buffer.on_device_shape(); - - auto transfer_h2d = [client, transfer_manager, local_device, device_buffer, - shape, compact_shape, on_device_shape, data, + // The host to device transfer is performed on a thread pool, mostly because + // it includes linearization that may be slow. + // TODO(misard) assess if it would be preferable to introduce a heuristic to + // put the transfer into the calling thread for small literals. + auto transfer_h2d = [client, transfer_manager, local_device, + device_buffer_ref{device_buffer.Release()}, data, shape, + compact_shape, + on_device_shape{py_buffer->on_device_shape()}, buffer_reference{std::move(buffer_reference)}]() { - // This function uses TF_CHECK_OK and ValueOrDie() since we have no way to - // report failures from a callback. However, the operations here are + SharedDeviceBuffer::ScopedUsage device_buffer; + device_buffer.Transfer(device_buffer_ref); + // This function uses TF_CHECK_OK and ValueOrDie() since we have no way + // to report failures from a callback. However, the operations here are // unlikely to fail and not recoverable even if we were to fail: DMAs to - // memory that has already been allocated, and a possible Event allocation. + // memory that has already been allocated, and a possible Event + // allocation. + ShapedBuffer buffer = device_buffer->AsShapedBuffer( compact_shape, on_device_shape, client->client()->platform()); - TF_CHECK_OK(transfer_manager->WriteTupleIndexTablesAsync( - local_device->host_to_device_stream(), buffer)); + std::shared_ptr staging_buffer; // If applicable on the backend, stage the transfer via host memory - // allocated via the host_memory_allocator. On GPU, this is pinned memory. + // allocated via the host_memory_allocator. On GPU, this is pinned + // memory. if (client->host_memory_allocator()) { int64 size = ShapeUtil::ByteSizeOf(shape); void* ptr = client->host_memory_allocator()->AllocateRaw( @@ -292,150 +494,72 @@ StatusOr> PyLocalBuffer::FromHostBuffer( local_device->host_to_device_stream(), literal, buffer)); } - EventPool::Handle event = - local_device->event_pool() - .ThenAllocateAndRecordEvent(local_device->host_to_device_stream()) - .ValueOrDie(); - - // Sets the buffer definition event. Note: this has the side effect of - // unblocking any host threads that may have been waiting to consume the - // buffer. - device_buffer->definition_events()[0]->SetDefinitionEvent( - std::move(event), local_device->host_to_device_stream()); - - if (local_device->synchronous_deallocation()) { - local_device->ThenRelease(local_device->host_to_device_stream(), - device_buffer); - } + std::shared_ptr event = + device_buffer->definition_events()[0]; + TF_CHECK_OK(AddDestinationBufferSynchronization( + local_device, std::move(device_buffer), event, + local_device->host_to_device_stream())); local_device->ThenRelease( local_device->host_to_device_stream(), std::make_pair(buffer_reference, std::move(staging_buffer))); }; client->h2d_transfer_pool()->Schedule(transfer_h2d); - return absl::make_unique( - compact_shape, std::move(on_device_shape), std::move(device_buffer), - client, device); + return py_buffer; } -/* static */ StatusOr> PyLocalBuffer::MakeTuple( - absl::Span buffers, PyLocalClient* client, - Device* device) { +/* static */ +StatusOr> PyLocalBuffer::FromHostLiteral( + const LiteralSlice& literal, PyLocalClient* client, Device* device) { + tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromHostLiteral"); + VLOG(2) << "PyLocalBuffer::FromHostLiteral: shape: " + << literal.shape().ToString() << " device: " << device->DebugString(); TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, device->GetLocalDeviceState()); - std::vector host_shapes; - std::vector device_shapes; - std::vector> device_buffers; - host_shapes.reserve(buffers.size()); - device_shapes.reserve(buffers.size()); - device_buffers.reserve(buffers.size()); - for (const PyLocalBuffer* buffer : buffers) { - if (buffer->device() != device) { - return InvalidArgument( - "Tuple elements must be on the same device; %s vs %s", - buffer->device()->DebugString(), device->DebugString()); - } - std::shared_ptr device_buffer = buffer->DeviceBuffer(); - if (!device_buffer) { - return InvalidArgument( - "Invalid buffer passed to MakeTuple() as argument %d.", - device_buffers.size()); - } - host_shapes.push_back(buffer->on_host_shape()); - device_shapes.push_back(buffer->on_device_shape()); - device_buffers.push_back(std::move(device_buffer)); - } - se::DeviceMemoryAllocator* allocator = client->allocator(); + TransferManager* transfer_manager = client->client()->backend().transfer_manager(); - - Shape on_host_shape = ShapeUtil::MakeTupleShape(host_shapes); - auto definition_event = std::make_shared(); TF_ASSIGN_OR_RETURN( - std::shared_ptr tuple_buffer, - SharedDeviceBuffer::MakeTuple( - device_buffers, on_host_shape, transfer_manager, allocator, - local_device->device_ordinal(), {definition_event})); - auto buffer = absl::make_unique( - std::move(on_host_shape), ShapeUtil::MakeTupleShape(device_shapes), - tuple_buffer, std::move(client), std::move(device)); + Shape compact_shape, + transfer_manager->ChooseCompactLayoutForShape(literal.shape())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr py_buffer, + AllocateDestinationBuffer(compact_shape, device, local_device, + local_device->host_to_device_stream(), client)); - // TODO(phawkins): extend TransferManager so we do not need to form a full - // ShapedBuffer just to write the root tuple index table. - TF_ASSIGN_OR_RETURN(ShapedBuffer shaped_buffer, buffer->AsShapedBuffer()); - if (!transfer_manager->CanShapedBufferBeAccessedNow( - local_device->host_to_device_stream()->parent(), shaped_buffer)) { - // Wait for the compute stream so that memory allocations are synchronized. - local_device->host_to_device_stream()->ThenWaitFor( - local_device->compute_stream()); - } - TF_RETURN_IF_ERROR(transfer_manager->WriteRootTupleIndexTable( - local_device->host_to_device_stream(), shaped_buffer)); + SharedDeviceBuffer::ScopedUsage device_buffer( + py_buffer->GetBufferWithUsageHold()); + CHECK(device_buffer.IsValid()); - TF_ASSIGN_OR_RETURN(EventPool::Handle event, - local_device->event_pool().ThenAllocateAndRecordEvent( - local_device->host_to_device_stream())); - definition_event->SetDefinitionEvent(std::move(event), - local_device->host_to_device_stream()); + // The host to device transfer is performed on a thread pool, mostly because + // it includes linearization that may be slow. + // TODO(misard) assess if it would be preferable to introduce a heuristic to + // put the transfer into the calling thread for small literals. + auto transfer_h2d = [client, transfer_manager, local_device, + device_buffer_ref{device_buffer.Release()}, literal, + compact_shape, + on_device_shape{py_buffer->on_device_shape()}]() { + SharedDeviceBuffer::ScopedUsage device_buffer; + device_buffer.Transfer(device_buffer_ref); + // This function uses TF_CHECK_OK and ValueOrDie() since we have no way + // to report failures from a callback. However, the operations here are + // unlikely to fail and not recoverable even if we were to fail: DMAs to + // memory that has already been allocated, and a possible Event + // allocation. - if (local_device->synchronous_deallocation()) { - local_device->ThenRelease(local_device->host_to_device_stream(), - std::move(tuple_buffer)); - } - return buffer; -} + ShapedBuffer buffer = device_buffer->AsShapedBuffer( + compact_shape, on_device_shape, client->client()->platform()); + TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync( + local_device->host_to_device_stream(), literal, buffer)); -StatusOr>> -MakeCrossHostReceiveBuffersHelper(absl::Span shapes, - PyLocalClient* client, Device* device) { - TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, - device->GetLocalDeviceState()); - TransferManager* transfer_manager = - client->client()->backend().transfer_manager(); - std::vector> buffers; - buffers.reserve(shapes.size()); - se::Stream* host_to_device_stream = local_device->host_to_device_stream(); - for (const auto& shape : shapes) { - TF_ASSIGN_OR_RETURN( - ScopedShapedBuffer scoped_buffer, - transfer_manager->AllocateScopedShapedBuffer( - shape, client->allocator(), local_device->device_ordinal())); - - if (!transfer_manager->CanShapedBufferBeAccessedNow( - local_device->compute_stream()->parent(), scoped_buffer)) { - return Unimplemented( - "Cross host receive not enabled unless deallocations are deferred"); - } - - absl::InlinedVector, 2> - definition_events; - - if (scoped_buffer.on_device_shape().IsTuple()) { - TF_CHECK_OK(transfer_manager->WriteTupleIndexTablesAsync( - host_to_device_stream, scoped_buffer)); - definition_events = {std::make_shared(), - std::make_shared()}; - TF_ASSIGN_OR_RETURN(EventPool::Handle event, - local_device->event_pool().ThenAllocateAndRecordEvent( - host_to_device_stream)); - definition_events[1]->SetDefinitionEvent(std::move(event), - host_to_device_stream); - } else { - definition_events = {std::make_shared()}; - } - - std::shared_ptr device_buffer = - SharedDeviceBuffer::FromScopedShapedBuffer(&scoped_buffer, - definition_events); - Shape on_device_shape = scoped_buffer.on_device_shape(); - - auto buffer = absl::make_unique( - shape, std::move(on_device_shape), std::move(device_buffer), client, - device); - - buffers.push_back(std::move(buffer)); - } - return buffers; + std::shared_ptr event = + device_buffer->definition_events()[0]; + TF_CHECK_OK(AddDestinationBufferSynchronization( + local_device, std::move(device_buffer), event, + local_device->host_to_device_stream())); + }; + client->h2d_transfer_pool()->Schedule(transfer_h2d); + return py_buffer; } /*static*/ void PyLocalBuffer::MakeCrossHostReceiveBuffers( @@ -446,14 +570,28 @@ MakeCrossHostReceiveBuffersHelper(absl::Span shapes, "shapes parameter empty in MakeCrossHostReceiveBuffers")); return; } - auto buffer_or = MakeCrossHostReceiveBuffersHelper(shapes, client, device); - if (!buffer_or.ok()) { - notifier(buffer_or.status()); + + auto local_device_or = device->GetLocalDeviceState(); + if (!local_device_or.ok()) { + notifier(local_device_or.status()); return; } + LocalDeviceState* local_device = local_device_or.ConsumeValueOrDie(); - client->EnqueueCrossHostReceive(buffer_or.ConsumeValueOrDie(), - std::move(notifier)); + std::vector> buffers; + buffers.reserve(shapes.size()); + for (const auto& shape : shapes) { + StatusOr> buffer_or = + AllocateDestinationBuffer(shape, device, local_device, + /*copy_stream=*/nullptr, client); + if (!buffer_or.ok()) { + notifier(buffer_or.status()); + return; + } + buffers.push_back(buffer_or.ConsumeValueOrDie()); + } + + client->EnqueueCrossHostReceive(std::move(buffers), std::move(notifier)); } PyLocalBuffer::PyLocalBuffer(Shape on_host_shape, Shape on_device_shape, @@ -465,67 +603,156 @@ PyLocalBuffer::PyLocalBuffer(Shape on_host_shape, Shape on_device_shape, device_(device), device_buffer_(std::move(device_buffer)) {} +PyLocalBuffer::~PyLocalBuffer() { Delete(); } + +StatusOr> PyLocalBuffer::Release( + bool wait_for_operations_to_complete) { + std::shared_ptr device_buffer; + { + absl::MutexLock lock(&mu_); + if (device_buffer_ == nullptr) { + return std::shared_ptr(); + } + host_value_ = nullptr; + std::swap(device_buffer_, device_buffer); + } + SharedDeviceBuffer::StreamAndEventContainer events = + device_buffer->LockUseAndTransferUsageEvents(); + LocalDeviceState* local_device_state = device_->local_device_state(); + if (wait_for_operations_to_complete) { + std::unique_ptr stream; + for (const auto& stream_and_event : events) { + if (!stream_and_event.event->IsComplete()) { + if (stream == nullptr) { + stream = local_device_state->BorrowStreamFromPool(); + } + stream_and_event.event->WaitForEventOnStream(stream.get()); + } + } + if (stream != nullptr) { + TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); + local_device_state->ReturnStreamToPool(std::move(stream)); + } + } else { + if (local_device_state->allocation_model() == + LocalDeviceState::kComputeSynchronized) { + std::unique_ptr block_stream; + for (const auto& stream_and_event : events) { + // We only need to do something for events that didn't already acquire a + // reference to the buffer, and also which the compute stream didn't + // already wait for. Based on our heuristics this rare case should only + // occur when a buffer was copied to a device and then never used there. + // In that case we get a new stream and use it to hold onto a reference + // to the buffer until the events are complete. + if (!stream_and_event.reference_held && + !stream_and_event.event->DefinedOn( + local_device_state->compute_stream()) && + !stream_and_event.event->IsComplete()) { + if (block_stream == nullptr) { + block_stream = local_device_state->BorrowStreamFromPool(); + } + stream_and_event.event->WaitForEventOnStream(block_stream.get()); + } + } + if (block_stream != nullptr) { + local_device_state->ThenExecuteOnCallbackThread( + block_stream.get(), + [device_buffer, block_stream_ptr{block_stream.release()}, + local_device_state]() { + local_device_state->ReturnStreamToPool( + std::unique_ptr(block_stream_ptr)); + }); + } + } + } + return device_buffer; +} + void PyLocalBuffer::Delete() { + // When wait_for_reads_to_complete is false, Release should never fail. + TF_CHECK_OK(Release(/*wait_for_operations_to_complete=*/false).status()); +} + +bool PyLocalBuffer::IsDeleted() { absl::MutexLock lock(&mu_); - device_buffer_ = nullptr; - host_value_ = nullptr; + return device_buffer_ == nullptr; } Status PyLocalBuffer::CopyToHostAsync() { - std::shared_ptr device_buffer; + if (IsEmptyTuple()) { + return InvalidArgument("CopyToHostAsync called on empty tuple"); + } + SharedDeviceBuffer::ScopedUsage device_buffer; std::shared_ptr host_value; + LocalDeviceState* local_device = device_->local_device_state(); + se::Stream* stream = local_device->GetDeviceToHostStream(); { absl::MutexLock lock(&mu_); - if (!device_buffer_) { + if (device_buffer_ == nullptr) { return InvalidArgument("CopyToHostAsync() called on invalid buffer."); } - device_buffer = device_buffer_; - if (host_value_) { // The host value has already been requested or is available. return Status::OK(); } host_value = host_value_ = std::make_shared(); + device_buffer.Acquire(device_buffer_); } - se::Stream* stream = device_->local_device_state()->GetDeviceToHostStream(); WaitForBufferDefinitionEventsOnStream(*device_buffer, stream); host_value->value = std::make_shared(on_host_shape_); - TF_ASSIGN_OR_RETURN(ShapedBuffer shaped_buffer, AsShapedBuffer()); + ShapedBuffer shaped_buffer = device_buffer->AsShapedBuffer( + on_host_shape_, on_device_shape_, client_->client()->platform()); client_->client()->backend().transfer_manager()->TransferLiteralFromDevice( stream, shaped_buffer, host_value->value.get(), [host_value](Status done_status) { host_value->status = done_status; host_value->ready.Notify(); }); + + auto usage_event = std::make_shared(); + StatusOr event_or = + local_device->event_pool().ThenAllocateAndRecordEvent(stream); + if (!event_or.ok()) { + // Allocating the event failed, so synchronize + // the host on the copy and then drop the device buffer hold. + StallStreamOnError(local_device, stream); + return event_or.status(); + } + usage_event->SetDefinitionEvent(event_or.ConsumeValueOrDie(), stream); + // When using the ComputeSynchronized allocation model, retain a reference to + // the device_buffer until the copy completes, to ensure that the buffer isn't + // deleted or donated while it is still in use. The choice of retaining a + // reference at the host is a heuristic; the alternative is to ensure, before + // freeing the buffer, that the compute stream is synchronized past the + // transfer, but it seems better to hold onto the buffer too long than to + // stall the compute stream, particularly since the overwhelmingly common + // use case of CopyToHostAsync will hold onto the reference long enough to + // read the buffer in a subsequent call to ToLiteral. + RecordUsage(std::move(device_buffer), local_device, local_device, usage_event, + stream, + /*prefer_to_retain_reference=*/true); return Status::OK(); } StatusOr> PyLocalBuffer::ToLiteral() { tensorflow::profiler::TraceMe traceme("PyLocalBuffer::ToLiteral"); - std::shared_ptr device_buffer = DeviceBuffer(); - if (!device_buffer) { - return InvalidArgument("ToLiteral() called on invalid buffer."); - } - TF_RETURN_IF_ERROR(CopyToHostAsync()); std::shared_ptr host_value; { absl::MutexLock lock(&mu_); host_value = host_value_; } + if (host_value == nullptr) { + return InvalidArgument("ToLiteral called on invalid buffer"); + } host_value->ready.WaitForNotification(); TF_RETURN_IF_ERROR(host_value->status); return host_value->value; } -std::shared_ptr PyLocalBuffer::DeviceBuffer() const { - absl::MutexLock lock(&mu_); - return device_buffer_; -} - StatusOr PyLocalBuffer::AsShapedBuffer() const { absl::MutexLock lock(&mu_); - if (!device_buffer_) { + if (device_buffer_ == nullptr) { return InvalidArgument( "Attempted to fetch value of invalid/deleted buffer."); } @@ -533,106 +760,133 @@ StatusOr PyLocalBuffer::AsShapedBuffer() const { client_->client()->platform()); } -StatusOr>> -PyLocalBuffer::DestructureTuple() const { - tensorflow::profiler::TraceMe traceme("PyLocalBuffer::DestructureTuple"); +SharedDeviceBuffer::ScopedUsage PyLocalBuffer::GetBufferWithUsageHold() { absl::MutexLock lock(&mu_); - if (!on_host_shape_.IsTuple()) { - return InvalidArgument( - "Attempted to destructure a PyLocalBuffer that did not have a tuple " - "shape; shape: %s", - ShapeUtil::HumanString(on_host_shape_)); + SharedDeviceBuffer::ScopedUsage usage; + return std::move(usage.Acquire(device_buffer_)); +} + +std::shared_ptr +PyLocalBuffer::GetBufferWithExternalReference() { + absl::MutexLock lock(&mu_); + if (device_buffer_ == nullptr) { + return nullptr; } - if (!device_buffer_) { - return InvalidArgument("Attempted to destructure a deleted buffer."); + device_buffer_->AddExternalReference(); + return device_buffer_; +} + +StatusOr, + std::shared_ptr>> +PyLocalBuffer::CopyToDeviceHelper( + Device* dst_device, LocalDeviceState* dst_local_device, + LocalDeviceState* transfer_local_device, se::Stream* transfer_stream, + std::shared_ptr src_device_buffer) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr py_buffer, + AllocateDestinationBuffer(on_host_shape_, dst_device, dst_local_device, + transfer_stream, client_)); + + TF_ASSIGN_OR_RETURN(ShapedBuffer src_buffer, AsShapedBuffer()); + + WaitForBufferDefinitionEventsOnStream(*src_device_buffer, transfer_stream); + + SharedDeviceBuffer::ScopedUsage dst_device_buffer = + py_buffer->GetBufferWithUsageHold(); + CHECK(dst_device_buffer.IsValid()); + ShapedBuffer dst_buffer = dst_device_buffer->AsShapedBuffer( + on_host_shape_, on_device_shape_, client_->client()->platform()); + + // Copy the leaf buffers. + StatusOr> copy_event_or = + [&]() -> StatusOr> { + for (const auto& leaf : src_buffer.buffers().leaves()) { + const ShapeIndex& index = leaf.first; + const se::DeviceMemoryBase& input_buffer = leaf.second; + const se::DeviceMemoryBase& output_buffer = dst_buffer.buffer(index); + TF_RET_CHECK(input_buffer.size() == output_buffer.size()) + << "input: " << input_buffer.size() + << " output: " << output_buffer.size(); + if (input_buffer.size() != 0) { + TF_RETURN_IF_ERROR(transfer_local_device->ThenMemcpyDeviceToDevice( + transfer_stream, dst_local_device->compute_stream(), input_buffer, + output_buffer)); + } + } + std::shared_ptr event = + dst_device_buffer->definition_events()[0]; + TF_RETURN_IF_ERROR(AddDestinationBufferSynchronization( + transfer_local_device, std::move(dst_device_buffer), event, + transfer_stream)); + return event; + }(); + if (!copy_event_or.ok()) { + StallStreamOnError(transfer_local_device, transfer_stream); + if (transfer_local_device == dst_local_device) { + // Some copies may have been enqueued before the error was returned, and + // StallStreamOnError only makes sure the destination device is ok, so + // make sure that the src buffer remains valid until after any transfers + // have completed. + device_->local_device_state()->ThenRelease(transfer_stream, + src_device_buffer); + } + return copy_event_or.status(); } - int num_children = ShapeUtil::TupleElementCount(on_host_shape_); - std::vector> results; - results.reserve(num_children); - for (int64 i = 0; i < num_children; ++i) { - results.push_back(absl::make_unique( - on_host_shape_.tuple_shapes(i), on_device_shape_.tuple_shapes(i), - device_buffer_->children().at(i), client_, device_)); - } - return results; + + return std::pair, + std::shared_ptr>( + std::move(py_buffer), copy_event_or.ConsumeValueOrDie()); } StatusOr> PyLocalBuffer::CopyToDevice( Device* dst_device) { tensorflow::profiler::TraceMe traceme("PyLocalBuffer::CopyToDevice"); - std::shared_ptr src_device_buffer = DeviceBuffer(); + if (dst_device == device_) { + return InvalidArgument( + "CopyToDevice cannot accept the same source and destination devices"); + } + TF_ASSIGN_OR_RETURN(LocalDeviceState * dst_local_device, dst_device->GetLocalDeviceState()); - - if (dst_device == device_) { - return absl::make_unique( - on_host_shape_, on_device_shape_, src_device_buffer, client_, device_); - } LocalDeviceState* transfer_local_device = client_->EnqueueD2DTransfersOnSrcStream() ? device_->local_device_state() : dst_local_device; + CHECK_EQ(dst_local_device->allocation_model(), + transfer_local_device->allocation_model()); se::Stream* transfer_stream = transfer_local_device->GetDeviceToDeviceStream(); - TransferManager* transfer_manager = - client_->client()->backend().transfer_manager(); - TF_ASSIGN_OR_RETURN(ScopedShapedBuffer dst_buffer, - transfer_manager->AllocateScopedShapedBuffer( - on_host_shape_, client_->allocator(), - dst_local_device->device_ordinal())); - if (!transfer_manager->CanShapedBufferBeAccessedNow( - dst_local_device->compute_stream()->parent(), dst_buffer)) { - transfer_stream->ThenWaitFor(dst_local_device->compute_stream()); - } - TF_ASSIGN_OR_RETURN(ShapedBuffer src_buffer, AsShapedBuffer()); - - WaitForBufferDefinitionEventsOnStream(*src_device_buffer, transfer_stream); - - // Copy the leaf buffers. - for (const auto& leaf : src_buffer.buffers().leaves()) { - const ShapeIndex& index = leaf.first; - const se::DeviceMemoryBase& input_buffer = leaf.second; - const se::DeviceMemoryBase& output_buffer = dst_buffer.buffer(index); - TF_RET_CHECK(input_buffer.size() == output_buffer.size()) - << "input: " << input_buffer.size() - << " output: " << output_buffer.size(); - if (input_buffer.size() != 0) { - TF_RETURN_IF_ERROR(transfer_local_device->ThenMemcpyDeviceToDevice( - transfer_stream, dst_local_device->compute_stream(), input_buffer, - output_buffer)); + SharedDeviceBuffer::ScopedUsage src_device_buffer; + { + absl::MutexLock lock(&mu_); + if (device_buffer_ == nullptr) { + return InvalidArgument("CopyToDevice called on invalid buffer"); } + src_device_buffer.Acquire(device_buffer_); } - // We hold on to the `src_device_buffer` until the transfer is finished. - transfer_local_device->ThenRelease(transfer_stream, - std::move(src_device_buffer)); - - // Write new tuple buffers. The destination buffers have different addresses, - // so we must construct tuple buffers from scratch instead of copying them. - if (dst_buffer.on_device_shape().IsTuple()) { - TF_RETURN_IF_ERROR(transfer_manager->WriteTupleIndexTablesAsync( - dst_local_device->host_to_device_stream(), dst_buffer)); - - // We need a single definition event, so make the device to device stream - // wait for the stream that wrote the tuple index tables on the destination - // device. - transfer_stream->ThenWaitFor(dst_local_device->host_to_device_stream()); + StatusOr, + std::shared_ptr>> + buffer_and_event_or = CopyToDeviceHelper( + dst_device, dst_local_device, transfer_local_device, transfer_stream, + src_device_buffer.buffer_reference()); + if (!buffer_and_event_or.ok()) { + return buffer_and_event_or.status(); } - auto definition_event = std::make_shared(); - TF_ASSIGN_OR_RETURN( - EventPool::Handle event, - transfer_local_device->event_pool().ThenAllocateAndRecordEvent( - transfer_stream)); - definition_event->SetDefinitionEvent(std::move(event), transfer_stream); + auto& [buffer, event] = buffer_and_event_or.ValueOrDie(); + // prefer_to_retain_reference=*/true means that, when using the + // ComputeSynchronized allocation model, retain a reference to the + // src_device_buffer until the copy completes. This is a heuristic; the + // alternative is to ensure, before freeing the buffer, that the compute + // stream is synchronized past the transfer, but it seems better to hold onto + // the buffer too long than to stall the compute stream. + RecordUsage(std::move(src_device_buffer), device_->local_device_state(), + transfer_local_device, event, transfer_stream, + /*prefer_to_retain_reference=*/true); - std::shared_ptr dst_device_buffer = - SharedDeviceBuffer::FromScopedShapedBuffer(&dst_buffer, - {definition_event}); - return absl::make_unique( - dst_buffer.on_host_shape(), dst_buffer.on_device_shape(), - std::move(dst_device_buffer), client_, dst_device); + return std::move(buffer); } Status PyLocalBuffer::CopyToRemoteDevice( @@ -642,20 +896,123 @@ Status PyLocalBuffer::CopyToRemoteDevice( Status PyLocalBuffer::BlockHostUntilReady() { tensorflow::profiler::TraceMe traceme("PyLocalBuffer::BlockHostUntilReady"); - std::shared_ptr device_buffer = DeviceBuffer(); - if (!device_buffer) { - return InvalidArgument("BlockHostUntilReady() called on invalid buffer."); + std::shared_ptr device_buffer; + { + absl::MutexLock lock(&mu_); + if (device_buffer_ == nullptr) { + return InvalidArgument("BlockHostUntilReady() called on invalid buffer."); + } + device_buffer = device_buffer_; + } + LocalDeviceState* local_device_state = device_->local_device_state(); + std::unique_ptr stream; + for (auto& event : device_buffer->definition_events()) { + if (!event->IsComplete()) { + if (stream == nullptr) { + stream = local_device_state->BorrowStreamFromPool(); + } + event->WaitForEventOnStream(stream.get()); + } + } + if (stream != nullptr) { + TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); + local_device_state->ReturnStreamToPool(std::move(stream)); + } + return Status::OK(); +} + +namespace { + +// Helper struct for the tuple that is transiently constructed to hold the +// arguments of an execution. +struct TupleHandle { + // The device buffer holding the root of the tuple table. + se::OwningDeviceMemory root_table; + // The ShapedBuffer describing the tuple. Does not own any of its buffers. + std::unique_ptr shaped_buffer; + // A definition event that has been recorded on the host_to_device stream + // after the tuple table transfer. + std::shared_ptr event; +}; + +// Makes a tuple from the arguments to an execution. +StatusOr MakeTupleHelper( + PyLocalClient* client, LocalDeviceState* local_device, + absl::Span shaped_buffers, int device_ordinal) { + std::vector host_shapes; + std::vector device_shapes; + host_shapes.reserve(shaped_buffers.size()); + device_shapes.reserve(shaped_buffers.size()); + for (const ShapedBuffer& buffer : shaped_buffers) { + host_shapes.push_back(buffer.on_host_shape()); + device_shapes.push_back(buffer.on_device_shape()); + } + Shape on_host_shape = ShapeUtil::MakeTupleShape(host_shapes); + Shape on_device_shape = ShapeUtil::MakeTupleShape(device_shapes); + + se::DeviceMemoryAllocator* allocator = client->allocator(); + TransferManager* transfer_manager = + client->client()->backend().transfer_manager(); + se::Stream* stream = local_device->host_to_device_stream(); + TF_ASSIGN_OR_RETURN( + se::OwningDeviceMemory root_table_memory, + allocator->Allocate( + device_ordinal, + transfer_manager->GetByteSizeRequirement(on_host_shape))); + + // tuple_buffer holds the device buffers for all the arguments and the root + // table, and does not own any of them. + auto tuple_buffer = absl::make_unique( + on_host_shape, on_device_shape, client->client()->platform(), + device_ordinal); + tuple_buffer->set_buffer(root_table_memory.cref(), {}); + for (int i = 0; i < shaped_buffers.size(); ++i) { + for (const auto& sub_buffer : shaped_buffers[i].buffers()) { + ShapeIndex index = sub_buffer.first; + index.push_front(i); + tuple_buffer->set_buffer(sub_buffer.second, index); + } } - // This code waits at least until the buffer is ready, but it may wait longer - // if there are other device to host transfers scheduled. If this proves to - // be an issue, we could either use a separate stream for this purpose, or - // poll for the buffer definition events. - se::Stream* stream = - client_->device_state(device_->local_device_state()->device_ordinal()) - .GetDeviceToHostStream(); - WaitForBufferDefinitionEventsOnStream(*device_buffer, stream); - return stream->BlockHostUntilDone(); + if (local_device->allocation_model() == + LocalDeviceState::kComputeSynchronized) { + stream->ThenWaitFor(local_device->compute_stream()); + } else { + DCHECK(transfer_manager->CanShapedBufferBeAccessedNow( + local_device->compute_stream()->parent(), *tuple_buffer)); + } + + TF_RETURN_IF_ERROR( + transfer_manager->WriteRootTupleIndexTable(stream, *tuple_buffer)); + StatusOr event_or = + local_device->event_pool().ThenAllocateAndRecordEvent(stream); + if (!event_or.ok()) { + StallStreamOnError(local_device, stream); + return event_or.status(); + } + + auto transfer_event = std::make_shared(); + transfer_event->SetDefinitionEvent(event_or.ConsumeValueOrDie(), stream); + return TupleHandle({std::move(root_table_memory), std::move(tuple_buffer), + std::move(transfer_event)}); +} + +// Converts a ScopedShapedBuffer returned from an execution into a +// PyLocalBuffer. +std::unique_ptr OutputBufferHelper( + ScopedShapedBuffer* result_buffer, + std::shared_ptr definition_event, + PyLocalClient* client, Device* device, LocalDeviceState* local_device) { + std::shared_ptr out_buffer = + SharedDeviceBuffer::FromScopedShapedBuffer(result_buffer, + {definition_event}); + RecordUsage(std::move(SharedDeviceBuffer::ScopedUsage().Acquire(out_buffer)), + local_device, local_device, definition_event, + local_device->compute_stream(), + /*prefer_to_retain_reference=*/false); + return absl::make_unique( + result_buffer->on_host_shape(), result_buffer->on_device_shape(), + std::move(out_buffer), client, device); } static Device* LookupDevice(const PyLocalClient& client, int device_id) { @@ -665,6 +1022,8 @@ static Device* LookupDevice(const PyLocalClient& client, int device_id) { return it->second; } +} // namespace + PyLocalExecutable::PyLocalExecutable( std::vector> executables, bool tuple_arguments, DeviceAssignment device_assignment, @@ -719,22 +1078,14 @@ const std::string& PyLocalExecutable::name() const { } } -StatusOr>> -PyLocalExecutable::ExecuteHelper( +// Enqueues a computation onto the compute stream. Each buffer returned in +// device_buffers has a usage hold added that must be dropped on error or +// converted on success. +StatusOr PyLocalExecutable::EnqueueExecution( absl::Span argument_handles, int replica, - int partition, const RunId& run_id, const ExecuteOptions& options) const { - const int device_id = (*device_assignment_)(replica, partition); - Device* device = LookupDevice(*client_, device_id); - - std::unique_ptr tuple_buffer; - std::vector tupled_arguments; - if (tuple_arguments_) { - TF_ASSIGN_OR_RETURN(tuple_buffer, PyLocalBuffer::MakeTuple( - argument_handles, client_, device)); - tupled_arguments = {tuple_buffer.get()}; - argument_handles = tupled_arguments; - } - CHECK_EQ(device->host_id(), client_->host_id()); + int partition, int executable_idx, const RunId& run_id, + const ExecuteOptions& options, Device* device, + std::vector* device_buffers) const { int device_ordinal = device->local_device_state()->device_ordinal(); tensorflow::profiler::TraceMe traceme([&] { return absl::StrCat("LocalExecutable::Execute#run_id=", run_id.ToInt(), @@ -744,16 +1095,16 @@ PyLocalExecutable::ExecuteHelper( << " mapped to device ordinal for execution: " << device_ordinal; absl::flat_hash_set events; - std::vector> device_buffers; std::vector argument_buffers; std::vector argument_buffer_ptrs; - device_buffers.reserve(argument_handles.size() + 1); + device_buffers->reserve(argument_handles.size()); argument_buffers.reserve(argument_handles.size()); argument_buffer_ptrs.reserve(argument_handles.size()); for (int i = 0; i < argument_handles.size(); ++i) { PyLocalBuffer* handle = argument_handles[i]; - std::shared_ptr device_buffer = handle->DeviceBuffer(); - if (!device_buffer) { + SharedDeviceBuffer::ScopedUsage device_buffer = + handle->GetBufferWithUsageHold(); + if (!device_buffer.IsValid()) { return InvalidArgument( "Deleted buffer passed to Execute() as argument %d to replica %d", i, replica); @@ -764,16 +1115,29 @@ PyLocalExecutable::ExecuteHelper( "device %s, but replica is assigned to device %s.", i, replica, handle->device()->DebugString(), device->DebugString()); } - TF_ASSIGN_OR_RETURN(ShapedBuffer shaped_buffer, handle->AsShapedBuffer()); + ShapedBuffer shaped_buffer = device_buffer->AsShapedBuffer( + handle->on_host_shape(), handle->on_device_shape(), + handle->client()->client()->platform()); argument_buffers.push_back(std::move(shaped_buffer)); argument_buffer_ptrs.push_back(&argument_buffers.back()); GetDeviceBufferDefinitionEvents(*device_buffer, &events); - device_buffers.push_back(std::move(device_buffer)); + device_buffers->push_back(std::move(device_buffer)); VLOG(4) << "Argument " << i << " buffer: " << argument_buffers.back().ToString(); } LocalDeviceState* device_state = &client_->device_state(device_ordinal); + TupleHandle tuple_handle; + if (tuple_arguments_) { + TF_ASSIGN_OR_RETURN(tuple_handle, + MakeTupleHelper(client_, device_state, argument_buffers, + device_ordinal)); + argument_buffer_ptrs = {tuple_handle.shaped_buffer.get()}; + events.insert(tuple_handle.event.get()); + // CAUTION: a copy has been enqueued into tuple_handle.root_table so it is + // important not to free the root_table on error without ensuring that + // necessary synchronization has been done. + } for (BufferDefinitionEvent* event : events) { event->WaitForEventOnStream(device_state->compute_stream()); @@ -790,56 +1154,118 @@ PyLocalExecutable::ExecuteHelper( run_options.set_rng_seed(device_state->GetNewPrngSeed()); run_options.set_gpu_executable_run_options(client_->gpu_run_options()); - // The choice of where we wait is arbitrary; the reason for the wait is pacing - // to avoid problems such as memory fragmentation and running ahead too far, - // not for correctness. Placing it before the executable launch allows the - // inputs for the next executable to be fetched even if the launch is delayed. + // The choice of where we wait is arbitrary; the reason for the wait is + // pacing to avoid problems such as memory fragmentation and running ahead + // too far, not for correctness. Placing it before the executable launch + // allows the inputs for the next executable to be fetched even if the + // launch is delayed. auto compute_reservation = std::make_shared( device_state->compute_semaphore().ScopedAcquire(1)); - // SPMD sharding produces a single executable for multiple partitions. - int executable_idx = executables_.size() > 1 ? partition : 0; - StatusOr result_buffer_or_status = executables_[executable_idx]->RunAsync(argument_buffer_ptrs, run_options); VLOG(1) << "Replica " << replica << " partition " << partition << " completed; ok=" << result_buffer_or_status.ok(); + + if (device_state->allocation_model() == LocalDeviceState::kSynchronous) { + // Free the root tuple table after execution has completed. + device_state->ThenExecuteOnCallbackThread( + device_state->compute_stream(), + [references{std::make_tuple(executables_[executable_idx], + compute_reservation, device_assignment_)}, + root_buffer{tuple_handle.root_table.Release()}, + allocator{client_->allocator()}, device_ordinal]() { + TF_CHECK_OK(allocator->Deallocate(device_ordinal, root_buffer)); + }); + + } else { + // The root tuple table can be freed as soon as the computation is + // enqueued. + device_state->ThenRelease( + device_state->compute_stream(), + std::make_tuple(executables_[executable_idx], compute_reservation, + device_assignment_)); + } + + return result_buffer_or_status; +} + +StatusOr>> +PyLocalExecutable::ExecuteHelper( + absl::Span argument_handles, int replica, + int partition, const RunId& run_id, const ExecuteOptions& options) const { + const int device_id = (*device_assignment_)(replica, partition); + Device* device = LookupDevice(*client_, device_id); + + CHECK_EQ(device->host_id(), client_->host_id()); + int device_ordinal = device->local_device_state()->device_ordinal(); + tensorflow::profiler::TraceMe traceme("LocalExecutable::Execute"); + VLOG(3) << "Replica " << replica << ", partition " << partition + << " mapped to device ordinal for execution: " << device_ordinal; + + // SPMD sharding produces a single executable for multiple partitions. + int executable_idx = executables_.size() > 1 ? partition : 0; + + std::vector device_buffers; + device_buffers.reserve(argument_handles.size()); + StatusOr result_buffer_or_status = + EnqueueExecution(argument_handles, replica, partition, executable_idx, + run_id, options, device, &device_buffers); + if (!result_buffer_or_status.ok()) { LOG(ERROR) << "Execution of replica " << replica << " failed: " << result_buffer_or_status.status(); return result_buffer_or_status.status(); } - ScopedShapedBuffer& result_buffer = result_buffer_or_status.ValueOrDie(); + ScopedShapedBuffer result_buffer = + result_buffer_or_status.ConsumeValueOrDie(); + LocalDeviceState* device_state = &client_->device_state(device_ordinal); + se::Stream* stream = device_state->compute_stream(); + StatusOr event_or = + device_state->event_pool().ThenAllocateAndRecordEvent(stream); + if (!event_or.ok()) { + StallStreamOnError(device_state, stream); + return event_or.status(); + } auto definition_event = std::make_shared(); - TF_ASSIGN_OR_RETURN(EventPool::Handle event, - device_state->event_pool().ThenAllocateAndRecordEvent( - device_state->compute_stream())); - definition_event->SetDefinitionEvent(std::move(event), - device_state->compute_stream()); - - std::shared_ptr out_buffer = - SharedDeviceBuffer::FromScopedShapedBuffer(&result_buffer, - {definition_event}); - - if (device_state->synchronous_deallocation()) { - device_buffers.push_back(out_buffer); - device_state->ThenRelease(device_state->compute_stream(), - std::move(device_buffers)); - } - - device_state->ThenRelease( - device_state->compute_stream(), - std::make_tuple(executables_[executable_idx], compute_reservation, - device_assignment_)); + definition_event->SetDefinitionEvent(event_or.ConsumeValueOrDie(), stream); std::vector> outputs; - outputs.push_back(absl::make_unique( - result_buffer.on_host_shape(), result_buffer.on_device_shape(), - std::move(out_buffer), client_, device)); if (options.untuple_result && result_buffer.on_host_shape().IsTuple()) { - TF_ASSIGN_OR_RETURN(outputs, outputs.front()->DestructureTuple()); + int tuple_count = result_buffer.on_host_shape().tuple_shapes_size(); + outputs.reserve(tuple_count); + // Take ownership of each of the output values, leaving only the root table + // in result_buffer. + for (int i = 0; i < tuple_count; ++i) { + ScopedShapedBuffer tuple_buffer = result_buffer.TakeSubTree({i}); + outputs.push_back(OutputBufferHelper(&tuple_buffer, definition_event, + client_, device, device_state)); + } + if (device_state->allocation_model() == LocalDeviceState::kSynchronous) { + // Don't release the root buffer until after execution completes. + ShapedBuffer root_buffer_holder = result_buffer.release(); + se::DeviceMemoryBase root_buffer = root_buffer_holder.root_buffer(); + device_state->ThenExecuteOnCallbackThread( + device_state->compute_stream(), + [root_buffer, allocator{client_->allocator()}, device_ordinal]() { + TF_CHECK_OK(allocator->Deallocate(device_ordinal, root_buffer)); + }); + } + } else { + outputs.push_back(OutputBufferHelper(&result_buffer, definition_event, + client_, device, device_state)); } + + for (SharedDeviceBuffer::ScopedUsage& b : device_buffers) { + // prefer_to_retain_reference=false because when using the + // ComputeSynchronized allocation model we don't need to retain a reference + // to the device_buffer during execution because by definition the compute + // stream is synchronized past the execution. + RecordUsage(std::move(b), device_state, device_state, definition_event, + stream, /*prefer_to_retain_reference=*/false); + } + return outputs; } @@ -931,17 +1357,17 @@ PyLocalExecutable::ExecuteOnLocalDevices( mu.AssertHeld(); return running == 0; }; - // If execution does not terminate within a reasonable amount of time, we - // may be stuck at a cross-replica barrier on-device. Terminate the + // If execution does not terminate within a reasonable amount of time, + // we may be stuck at a cross-replica barrier on-device. Terminate the // process since that's the only way we can escape this situation at the // moment (b/130629719). if (!mu.AwaitWithTimeout(absl::Condition(&done_running), absl::Seconds(10))) { LOG(FATAL) << "Replicated computation launch failed, but not all replicas " - "terminated. Aborting process to work around deadlock. Failure " - "message (there may have been multiple failures, see the " - "error log for all failures): \n\n" + "terminated. Aborting process to work around deadlock. " + "Failure message (there may have been multiple failures, see " + "the error log for all failures): \n\n" << first_failure_status.error_message(); } } diff --git a/tensorflow/compiler/xla/python/local_client.h b/tensorflow/compiler/xla/python/local_client.h index c9b50fbbbef..63d26782b20 100644 --- a/tensorflow/compiler/xla/python/local_client.h +++ b/tensorflow/compiler/xla/python/local_client.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/synchronization/notification.h" @@ -38,6 +39,7 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/thread_annotations.h" // API notes: // Despite having the name "PyLocalClient", it is intended that this API may @@ -87,7 +89,9 @@ class Device { const std::string platform_name_; }; +// Forward declaration. class PyLocalBuffer; + // Helper struct for cross host transfers, returned by the callback from a call // to PyLocalBuffer::MakeCrossHostReceiveBuffers. struct PyLocalCrossHostRecvBuffer { @@ -193,13 +197,12 @@ class PyLocalClient : public std::enable_shared_from_this { StatusOr DevicesToDeviceAssignment( absl::Span> devices); -// Holds a reference from Python to one or more device buffers. -// A PyLocalBuffer can be either valid or invalid. An invalid buffer is one that -// has never been initialized, or a buffer that has been deleted (e.g., by -// calling Delete). We allow PyLocalBuffer objects to outlive the underlying -// device buffers so we can decouple buffer lifetimes from the corresponding -// Python references if needed. -// Thread-safe. +// Holds a reference from Python to a tuple of device buffers. A PyLocalBuffer +// can be either valid or invalid. An invalid buffer is one that has never been +// initialized, or a buffer that has been deleted (e.g., by calling Delete). We +// allow PyLocalBuffer objects to outlive the underlying device buffers so we +// can decouple buffer lifetimes from the corresponding Python references if +// needed. Thread-safe. class PyLocalBuffer { public: // If `force_copy` is true, forces a copy of the input buffer on CPU. @@ -212,9 +215,8 @@ class PyLocalBuffer { std::shared_ptr buffer_reference, PyLocalClient* client, Device* device); - static StatusOr> MakeTuple( - absl::Span buffers, PyLocalClient* client, - Device* device); + static StatusOr> FromHostLiteral( + const LiteralSlice& literal, PyLocalClient* client, Device* device); // Asynchronously makes a vector of PyLocalBuffers that can be used to receive // cross host transfers using `client` on `device'. `shapes` must be the exact @@ -232,6 +234,7 @@ class PyLocalBuffer { PyLocalBuffer(Shape on_host_shape, Shape on_device_shape, std::shared_ptr device_buffer, PyLocalClient* client, Device* device); + ~PyLocalBuffer(); PyLocalBuffer(const PyLocalBuffer&) = delete; PyLocalBuffer(PyLocalBuffer&&) = delete; @@ -243,11 +246,13 @@ class PyLocalBuffer { Device* device() const { return device_; } const std::string& platform_name() const { return client_->platform_name(); } PyLocalClient* client() const { return client_; } + bool IsEmptyTuple() const { + return on_host_shape_.IsTuple() && on_host_shape_.tuple_shapes_size() == 0; + } - // Returns the buffer's value as a tuple DAG of Python arrays. If the value - // has previously been prefetched to the host, then returns the prefetched - // version, otherwise copies the buffer to the host. Blocks until the - // value is ready. + // Returns the buffer's value as an XLA Literal. If the value has previously + // been prefetched to the host, then returns the prefetched version, otherwise + // copies the buffer to the host. Blocks until the value is ready. StatusOr> ToLiteral(); // Initiates a copy of the buffer to the host. Does not block waiting for @@ -255,23 +260,55 @@ class PyLocalBuffer { // ToLiteral(). Status CopyToHostAsync(); - // Returns the associated device buffer. Returns a nullptr if the buffer is - // invalid. - std::shared_ptr DeviceBuffer() const; - - // Deletes the device memory associated with this buffer, leaving it in an - // invalid state. + // Drops the buffer's reference to its associated device memory, leaving the + // buffer in an invalid state. The memory will be freed lazily when all async + // operations using the buffer have completed, according to the allocation + // semantics of the underlying platform. Delete may briefly block if another + // thread is in the process of enqueuing an operation on this buffer, but it + // will never block for a stream operation to complete. If an external + // framework holds a reference to the SharedDeviceBuffer via + // GetBufferWithExternalReference, the memory will not be freed until the + // external framework drops the reference. void Delete(); - // Returns a view of the PyLocalBuffer DAG as a ShapedBuffer. The + // Similar to Delete, drops the buffer's reference to its associated device + // memory, leaving the buffer in an invalid state, but returns the + // SharedDeviceBuffer rather than freeing the device memory, so that another + // framework can take ownership of it. The buffer returned from Release may + // be safely dropped at any time even if it still has pending async + // operations. The client should call BlockHostUntilReady before calling + // Release with wait_for_operations_to_complete=false, to ensure that the host + // has synchronized past any outstanding write operations to the buffer. If + // wait_for_operations_to_complete=true the host will block until any + // potentially outstanding asynchronous operations have completed before + // returning, in which case it is safe to read or mutate the returned buffer. + StatusOr> Release( + bool wait_for_operations_to_complete); + + // True if and only if Delete or Release has previously been called. + bool IsDeleted(); + + // Returns a view of the PyLocalBuffer device memory as a ShapedBuffer. The // PyLocalBuffer retains ownership of the device buffers. StatusOr AsShapedBuffer() const; - // Destructures a tuple-valued PyLocalBuffer into its constituent elements. - StatusOr>> DestructureTuple() - const; + // Returns a 'usage hold' on the SharedDeviceBuffer holding the device + // buffers. The hold ensures that the device buffers can't be deleted until + // the hold is dropped or converted. GetBufferWithUsageHold is called + // before enqueueing the buffer on any stream operation, and the usage hold is + // dropped or converted after the enqueue is complete. The buffer in the hold + // is nullptr if Delete or Release has been called. + SharedDeviceBuffer::ScopedUsage GetBufferWithUsageHold(); - // Copies the buffer to device `dst_device`. + // Returns the SharedDeviceBuffer holding the device buffers, after adding an + // external reference ensuring that the device buffers can't be deleted until + // the reference is dropped. GetBufferWithExernalReference is called when an + // external framework wants to share the device buffers temporarily. Returns + // nullptr if Delete or Release has been called. + std::shared_ptr GetBufferWithExternalReference(); + + // Copies the buffer to device `dst_device`. Returns an error if the buffer is + // already on dst_device. StatusOr> CopyToDevice(Device* dst_device); // Copies the buffer to the remote device encoded in serialized_descriptor. @@ -290,13 +327,7 @@ class PyLocalBuffer { Status BlockHostUntilReady(); private: - PyLocalClient* const client_; - const Shape on_host_shape_; - const Shape on_device_shape_; - Device* const device_; - mutable absl::Mutex mu_; - std::shared_ptr device_buffer_ TF_GUARDED_BY(mu_); - + friend class PyLocalClient; // The cached value of the buffer on the host, produced either from a call to // CopyToHost or from a call to ToLiteral. Once a value has been fetched to // the host, it persists Delete() is called or the PyLocalBuffer is destroyed. @@ -307,6 +338,21 @@ class PyLocalBuffer { Status status; std::shared_ptr value; }; + + StatusOr, + std::shared_ptr>> + CopyToDeviceHelper(Device* dst_device, LocalDeviceState* dst_local_device, + LocalDeviceState* transfer_local_device, + se::Stream* transfer_stream, + std::shared_ptr src_device_buffer); + + PyLocalClient* const client_; + const Shape on_host_shape_; + const Shape on_device_shape_; + Device* const device_; + + mutable absl::Mutex mu_; + std::shared_ptr device_buffer_ TF_GUARDED_BY(mu_); std::shared_ptr host_value_ TF_GUARDED_BY(mu_); }; @@ -391,6 +437,11 @@ class PyLocalExecutable { const string& name() const; private: + StatusOr EnqueueExecution( + absl::Span argument_handles, int replica, + int partition, int executable_idx, const RunId& run_id, + const ExecuteOptions& options, Device* device, + std::vector* device_buffers) const; StatusOr>> ExecuteHelper( absl::Span argument_handles, int replica, int partition, const RunId& run_id, const ExecuteOptions& options) const; diff --git a/tensorflow/compiler/xla/python/local_device_state.cc b/tensorflow/compiler/xla/python/local_device_state.cc index 778cf316b34..6a96908cb12 100644 --- a/tensorflow/compiler/xla/python/local_device_state.cc +++ b/tensorflow/compiler/xla/python/local_device_state.cc @@ -19,16 +19,18 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/synchronization/mutex.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/stream.h" namespace xla { LocalDeviceState::LocalDeviceState(se::StreamExecutor* executor, LocalClient* client, - bool synchronous_deallocation, + AllocationModel allocation_model, bool asynchronous, bool allow_event_reuse) - : synchronous_deallocation_(synchronous_deallocation), + : allocation_model_(allocation_model), event_pool_(allow_event_reuse), compute_semaphore_(/*capacity=*/asynchronous ? 32 : 1), executor_(executor), @@ -116,6 +118,24 @@ se::Stream* LocalDeviceState::GetDeviceToDeviceStream() { return device_to_device_streams_.at(i).get(); } +std::unique_ptr LocalDeviceState::BorrowStreamFromPool() { + absl::MutexLock lock(&mu_); + if (usage_stream_pool_.empty()) { + auto stream = absl::make_unique(compute_stream_->parent()); + stream->Init(); + return stream; + } else { + std::unique_ptr stream = std::move(usage_stream_pool_.top()); + usage_stream_pool_.pop(); + return stream; + } +} + +void LocalDeviceState::ReturnStreamToPool(std::unique_ptr stream) { + absl::MutexLock lock(&mu_); + usage_stream_pool_.push(std::move(stream)); +} + int LocalDeviceState::GetNewPrngSeed() { absl::MutexLock lock(&mu_); int x = 0; diff --git a/tensorflow/compiler/xla/python/local_device_state.h b/tensorflow/compiler/xla/python/local_device_state.h index fa73c832c57..5cd2c0014a0 100644 --- a/tensorflow/compiler/xla/python/local_device_state.h +++ b/tensorflow/compiler/xla/python/local_device_state.h @@ -35,15 +35,61 @@ namespace xla { // for devices local to this host. class LocalDeviceState { public: - // If synchronous_deallocation is true, the host must not free buffers until - // compute/transfers that use those buffers have completed. For example, this - // typically is the case for the "platform" where compute/transfers are - // operations that take place on another thread. - // + // There are three different semantics used by memory allocators on different + // devices. + enum AllocationModel { + // kSynchronous is used by CPU devices. + // + // A buffer returned from the allocator can be used immediately. + // + // A buffer cannot be freed until after the last stream operation + // referencing the buffer has completed, so the client is responsible for + // keeping buffers alive until all device-side activity that consumes those + // buffers has completed. + // + // The client's use of the device allocator corresponds to a view of the + // tail of the last stream using a buffer. + kSynchronous, + + // kComputeSynchronous is used by GPU devices. + // + // A buffer returned from the allocator at time t can be used after the + // compute stream has finished executing the last computation enqueued + // before time t. + // + // A buffer b can be freed after: + // 1) The last use of b on the compute stream has been enqueued, and + // 2) For any non-compute stream s on which an operation o using b is + // enqueued, either: + // a) The host has been notified that o has completed, or + // b) The next operation to be enqueued on the compute stream is + // guaranteed to be started after o has completed. + // + // The client's use of the device allocator corresponds to a view of the + // tail of the compute stream. + kComputeSynchronized, + + // kAsynchronous is used by TPU devices. + // + // A buffer returned from the allocator can be used immediately. + // + // A buffer b can be freed as soon as the last stream operation using b has + // been enqueued. + // + // The allocator and lower-level runtime are responsible for keeping buffers + // alive (if that is needed) from the perspective of the device until any + // device-side work actually completes. + // + // The only exception is when a buffer is transferred between devices since + // only one of the device executors knows about the transfer, so the buffer + // must be manually kept alive from the perspective of the other executor. + kAsynchronous + }; + // If asynchronous is false, the host will synchronize to the device after // each execution or transfer. This is intended for debugging only. LocalDeviceState(se::StreamExecutor* executor, LocalClient* client, - bool synchronous_deallocation, bool asynchronous, + AllocationModel allocation_model, bool asynchronous, bool allow_event_reuse); virtual ~LocalDeviceState(); @@ -53,7 +99,7 @@ class LocalDeviceState { LocalClient* client() const { return client_; } - bool synchronous_deallocation() const { return synchronous_deallocation_; } + AllocationModel allocation_model() const { return allocation_model_; } EventPool& event_pool() { return event_pool_; } @@ -70,6 +116,13 @@ class LocalDeviceState { // fashion amongst the available streams. se::Stream* GetDeviceToDeviceStream(); + // Returns a stream from a pool. The stream is guaranteed not to have any + // currently outstanding work at its tail. + std::unique_ptr BorrowStreamFromPool(); + // Returns a stream to the pool. The caller must ensure the stream does not + // have any outstanding work at its tail. + void ReturnStreamToPool(std::unique_ptr stream); + // Enqueues a copy of `src_buffer` to `dst_buffer` onto `transfer_stream`. virtual Status ThenMemcpyDeviceToDevice(se::Stream* transfer_stream, se::Stream* dst_stream, @@ -109,7 +162,7 @@ class LocalDeviceState { private: Status SynchronizeAllActivity(); - bool synchronous_deallocation_; + AllocationModel allocation_model_; EventPool event_pool_; @@ -131,6 +184,7 @@ class LocalDeviceState { absl::Mutex mu_; int next_device_to_host_stream_ TF_GUARDED_BY(mu_) = 0; int next_device_to_device_stream_ TF_GUARDED_BY(mu_) = 0; + std::stack> usage_stream_pool_ TF_GUARDED_BY(mu_); std::random_device prng_seed_device_ TF_GUARDED_BY(mu_); std::mt19937 prng_seed_generator_ TF_GUARDED_BY(mu_); diff --git a/tensorflow/compiler/xla/python/nvidia_gpu_device.cc b/tensorflow/compiler/xla/python/nvidia_gpu_device.cc index 26ea727dee7..572b18a0abd 100644 --- a/tensorflow/compiler/xla/python/nvidia_gpu_device.cc +++ b/tensorflow/compiler/xla/python/nvidia_gpu_device.cc @@ -76,7 +76,8 @@ StatusOr>> BuildLocalDeviceStates( se::StreamExecutor* executor = xla_client->backend().stream_executor(i).ValueOrDie(); local_devices.push_back(absl::make_unique( - executor, xla_client, /*synchronous_deallocation=*/false, asynchronous, + executor, xla_client, LocalDeviceState::kComputeSynchronized, + asynchronous, /*allow_event_reuse=*/true)); } return std::move(local_devices); diff --git a/tensorflow/compiler/xla/python/shared_device_buffer.cc b/tensorflow/compiler/xla/python/shared_device_buffer.cc index 91f2b434a61..f8ba9fd73ce 100644 --- a/tensorflow/compiler/xla/python/shared_device_buffer.cc +++ b/tensorflow/compiler/xla/python/shared_device_buffer.cc @@ -18,8 +18,14 @@ limitations under the License. #include #include +#include "absl/synchronization/mutex.h" +#include "tensorflow/compiler/xla/python/local_device_state.h" +#include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/stream_executor/device_memory.h" #include "tensorflow/stream_executor/device_memory_allocator.h" +#include "tensorflow/stream_executor/event.h" +#include "tensorflow/stream_executor/stream.h" namespace xla { @@ -32,10 +38,16 @@ void BufferDefinitionEvent::SetDefinitionEvent(EventPool::Handle event, streams_defined_on_.push_back(stream); } -bool BufferDefinitionEvent::EventHasBeenRecorded() { +bool BufferDefinitionEvent::EventHasBeenRecorded() const { return event_.event() != nullptr; } +uint64 BufferDefinitionEvent::sequence_number() const { + absl::MutexLock lock(&mu_); + CHECK(EventHasBeenRecorded()); + return event_.sequence_number(); +} + void BufferDefinitionEvent::WaitForEventOnStream(se::Stream* stream) { absl::MutexLock lock(&mu_); @@ -56,41 +68,29 @@ void BufferDefinitionEvent::WaitForEventOnStream(se::Stream* stream) { streams_defined_on_.push_back(stream); } -static std::shared_ptr BufferFromScopedShapedBufferIterator( - const Shape& on_host_shape, const Shape& on_device_shape, - int device_ordinal, se::DeviceMemoryAllocator* allocator, - ShapeTree::iterator* iterator, - const ShapeTree::iterator& end, - absl::Span> - definition_events) { - std::vector buffers; - buffers.reserve(1); - std::vector> children; +bool BufferDefinitionEvent::DefinedOn(se::Stream* stream) { + absl::MutexLock lock(&mu_); - auto consume_buffer = [&]() { - CHECK(*iterator != end); - buffers.emplace_back((*iterator)->second, device_ordinal, allocator); - (*iterator)->second = se::DeviceMemoryBase(); - ++*iterator; - }; - if (on_host_shape.IsTuple()) { - consume_buffer(); - int num_children = ShapeUtil::TupleElementCount(on_device_shape); - children.reserve(num_children); - for (int i = 0; i < num_children; ++i) { - children.push_back(BufferFromScopedShapedBufferIterator( - on_host_shape.tuple_shapes(i), on_device_shape.tuple_shapes(i), - device_ordinal, allocator, iterator, end, definition_events)); - } - } else { - // An on-host array may be an on-device tuple. For example, a complex tensor - // may be represented as a (real, imag) pair. - ShapeUtil::ForEachSubshape( - on_device_shape, - [&](const Shape&, const ShapeIndex&) { consume_buffer(); }); - } - return std::make_shared( - absl::Span(buffers), children, definition_events); + // We cannot wait for an event until ThenRecordEvent has been called; on GPU + // newly created events are deemed to have already happened past. + mu_.Await( + absl::Condition(this, &BufferDefinitionEvent::EventHasBeenRecorded)); + + // The set of defined streams is expected to be very small indeed (usually + // 1-2), so a simple linear scan should be fast enough. + return std::find(streams_defined_on_.begin(), streams_defined_on_.end(), + stream) != streams_defined_on_.end(); +} + +bool BufferDefinitionEvent::IsComplete() { + absl::MutexLock lock(&mu_); + + // We cannot wait for an event until ThenRecordEvent has been called; on + // GPU newly created events are deemed to have already happened past. + mu_.Await( + absl::Condition(this, &BufferDefinitionEvent::EventHasBeenRecorded)); + + return event_.event()->PollForStatus() == se::Event::Status::kComplete; } /* static */ std::shared_ptr @@ -100,74 +100,23 @@ SharedDeviceBuffer::FromScopedShapedBuffer( definition_events) { ShapeTree::iterator iterator = shaped_buffer->buffers().begin(); - std::shared_ptr output = - BufferFromScopedShapedBufferIterator( - shaped_buffer->on_host_shape(), shaped_buffer->on_device_shape(), - shaped_buffer->device_ordinal(), shaped_buffer->memory_allocator(), - &iterator, shaped_buffer->buffers().end(), definition_events); + std::vector buffers; + buffers.reserve(1); + + ShapeUtil::ForEachSubshape( + shaped_buffer->on_device_shape(), [&](const Shape&, const ShapeIndex&) { + CHECK(iterator != shaped_buffer->buffers().end()); + buffers.push_back(iterator->second); + iterator->second = se::DeviceMemoryBase(); + ++iterator; + }); CHECK(iterator == shaped_buffer->buffers().end()); - return output; -} - -/* static */ StatusOr> -SharedDeviceBuffer::MakeTuple( - std::vector> children, - const Shape& on_host_shape, TransferManager* transfer_manager, - se::DeviceMemoryAllocator* allocator, int device_ordinal, - absl::Span> - definition_events) { - CHECK(on_host_shape.IsTuple() && - on_host_shape.tuple_shapes_size() == children.size()); - TF_ASSIGN_OR_RETURN( - se::OwningDeviceMemory device_memory, - allocator->Allocate( - device_ordinal, - transfer_manager->GetByteSizeRequirement(on_host_shape))); return std::make_shared( - allocator, device_ordinal, - std::initializer_list{device_memory.Release()}, - std::move(children), definition_events, + shaped_buffer->memory_allocator(), shaped_buffer->device_ordinal(), + absl::Span(buffers), definition_events, /*on_delete_callback=*/nullptr); } -/* static */ StatusOr> -SharedDeviceBuffer::MakeArray( - Shape on_device_shape, TransferManager* transfer_manager, - se::DeviceMemoryAllocator* allocator, int device_ordinal, - absl::Span> - definition_events) { - std::vector device_buffers; - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( - on_device_shape, [&](const Shape& subshape, const ShapeIndex&) -> Status { - TF_ASSIGN_OR_RETURN( - se::OwningDeviceMemory device_memory, - allocator->Allocate( - device_ordinal, - transfer_manager->GetByteSizeRequirement(subshape))); - device_buffers.push_back(std::move(device_memory)); - return Status::OK(); - })); - return std::make_shared( - absl::Span(device_buffers), - /*children=*/std::vector>{}, - definition_events); -} - -// Populates a buffer tree from a ShapeTree iterator. -static void PopulateShapedBufferFromBuffer( - const SharedDeviceBuffer& buffer, - ShapeTree::iterator* iterator, - const ShapeTree::iterator& end) { - for (const se::DeviceMemoryBase& buf : buffer.device_memory()) { - CHECK(*iterator != end); - (*iterator)->second = buf; - ++*iterator; - } - for (const auto& child : buffer.children()) { - PopulateShapedBufferFromBuffer(*child, iterator, end); - } -} - ShapedBuffer SharedDeviceBuffer::AsShapedBuffer(const Shape& on_host_shape, const Shape& on_device_shape, se::Platform* platform) const { @@ -175,8 +124,11 @@ ShapedBuffer SharedDeviceBuffer::AsShapedBuffer(const Shape& on_host_shape, device_ordinal_); ShapeTree::iterator iterator = shaped_buffer.buffers().begin(); - PopulateShapedBufferFromBuffer(*this, &iterator, - shaped_buffer.buffers().end()); + for (const se::DeviceMemoryBase& buf : device_memory_) { + CHECK(iterator != shaped_buffer.buffers().end()); + iterator->second = buf; + ++iterator; + } CHECK(iterator == shaped_buffer.buffers().end()); return shaped_buffer; } @@ -188,40 +140,58 @@ using MoveIterator = } // namespace +SharedDeviceBuffer::ScopedUsage::~ScopedUsage() { + if (parent_ != nullptr) { + parent_->DropUsageHold(); + } +} + +SharedDeviceBuffer::ScopedUsage& SharedDeviceBuffer::ScopedUsage::Acquire( + std::shared_ptr parent) { + CHECK(parent_ == nullptr); + if (parent != nullptr) { + parent_ = std::move(parent); + parent_->AddUsageHold(); + } + return *this; +} + +std::shared_ptr SharedDeviceBuffer::ScopedUsage::Release() { + return std::move(parent_); +} + +void SharedDeviceBuffer::ScopedUsage::Transfer( + std::shared_ptr parent) { + CHECK(parent_ == nullptr); + parent_ = parent; +} + +void SharedDeviceBuffer::ScopedUsage::Convert( + se::Stream* usage_stream, std::shared_ptr event, + bool reference_held) { + CHECK(parent_ != nullptr); + parent_->ConvertUsageHold(usage_stream, std::move(event), reference_held); + parent_ = nullptr; +} + SharedDeviceBuffer::SharedDeviceBuffer( se::DeviceMemoryAllocator* allocator, int device_ordinal, absl::Span device_memory, - std::vector> children, absl::Span> definition_events, std::function on_delete_callback) : allocator_(allocator), device_ordinal_(device_ordinal), device_memory_(device_memory.begin(), device_memory.end()), - children_(std::move(children)), definition_events_( std::move_iterator(definition_events.begin()), std::move_iterator(definition_events.end())), + in_use_(true), + usage_holds_(0), + external_references_(0), on_delete_callback_(std::move(on_delete_callback)) {} -SharedDeviceBuffer::SharedDeviceBuffer( - absl::Span device_memory, - std::vector> children, - absl::Span> definition_events) - : children_(std::move(children)), - definition_events_( - std::move_iterator(definition_events.begin()), - std::move_iterator(definition_events.end())) { - CHECK(!device_memory.empty()); - allocator_ = device_memory.front().allocator(); - device_ordinal_ = device_memory.front().device_ordinal(); - for (se::OwningDeviceMemory& buffer : device_memory) { - CHECK(buffer.allocator() == allocator_) << "Mismatched allocators"; - CHECK_EQ(buffer.device_ordinal(), device_ordinal_); - device_memory_.push_back(buffer.Release()); - } -} - SharedDeviceBuffer::~SharedDeviceBuffer() { + CHECK_EQ(external_references_, 0); if (allocator_) { for (const se::DeviceMemoryBase& buffer : device_memory_) { Status status = allocator_->Deallocate(device_ordinal_, buffer); @@ -235,15 +205,71 @@ SharedDeviceBuffer::~SharedDeviceBuffer() { } } +void SharedDeviceBuffer::AddUsageHold() { + absl::MutexLock lock(&mu_); + CHECK(in_use_); + ++usage_holds_; +} + +void SharedDeviceBuffer::DropUsageHold() { + absl::MutexLock lock(&mu_); + CHECK(in_use_); + CHECK_GT(usage_holds_, 0); + --usage_holds_; +} + +void SharedDeviceBuffer::AddExternalReference() { + absl::MutexLock lock(&mu_); + CHECK(in_use_); + ++external_references_; +} + +void SharedDeviceBuffer::DropExternalReference() { + absl::MutexLock lock(&mu_); + CHECK_GT(external_references_, 0); + --external_references_; +} + +void SharedDeviceBuffer::ConvertUsageHold( + se::Stream* usage_stream, std::shared_ptr event, + bool reference_held) { + absl::MutexLock lock(&mu_); + CHECK(in_use_); + CHECK_GT(usage_holds_, 0); + --usage_holds_; + + for (auto& existing : usage_events_) { + if (existing.stream == usage_stream) { + if (*existing.event < *event) { + existing.event = event; + existing.reference_held = reference_held; + } + return; + } + } + usage_events_.push_back({usage_stream, event, reference_held}); +} + +SharedDeviceBuffer::StreamAndEventContainer +SharedDeviceBuffer::LockUseAndTransferUsageEvents() { + auto holds_converted = [&]() { + mu_.AssertHeld(); + return usage_holds_ == 0; + }; + absl::MutexLock lock(&mu_); + CHECK(in_use_); + mu_.Await(absl::Condition(&holds_converted)); + CHECK(in_use_); + in_use_ = false; + return std::move(usage_events_); +} + void GetDeviceBufferDefinitionEvents( const SharedDeviceBuffer& buffer, absl::flat_hash_set* events) { for (const auto& e : buffer.definition_events()) { events->insert(e.get()); } - for (const auto& child : buffer.children()) { - GetDeviceBufferDefinitionEvents(*child, events); - } } void WaitForBufferDefinitionEventsOnStream(const SharedDeviceBuffer& buffer, diff --git a/tensorflow/compiler/xla/python/shared_device_buffer.h b/tensorflow/compiler/xla/python/shared_device_buffer.h index 3aa122c535d..8a647a691f7 100644 --- a/tensorflow/compiler/xla/python/shared_device_buffer.h +++ b/tensorflow/compiler/xla/python/shared_device_buffer.h @@ -16,13 +16,18 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_SHARED_DEVICE_BUFFER_H_ #define TENSORFLOW_COMPILER_XLA_PYTHON_SHARED_DEVICE_BUFFER_H_ +#include + #include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/python/event_pool.h" +#include "tensorflow/compiler/xla/python/local_device_state.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/stream_executor/device_memory.h" #include "tensorflow/stream_executor/device_memory_allocator.h" +#include "tensorflow/stream_executor/stream.h" namespace xla { @@ -30,10 +35,8 @@ namespace xla { // viewpoint of each of stream that may access it. // // Each logical buffer in an XLA computation may be defined (i.e., written to) -// at most once, although the same physical piece of memory may be reused for -// multiple logical buffers. We call the operation that writes the buffer's -// value on some stream (e.g., a transfer or compute kernel) the buffer's -// definition event. +// at most once. We call the operation that writes the buffer's value on some +// stream (e.g., a transfer or compute kernel) the buffer's definition event. // // After the operation that populates the value of a buffer has been enqueued on // 'stream', RecordOnStream(stream) should also be called to trigger the @@ -50,6 +53,9 @@ namespace xla { // The dependency logic caches the set of streams at the tail of which the // definition event is known to have occurred; waiting for the same event on the // same stream causes no additional waiting. +// +// TODO(misard) Rename this BufferSequencingEvent now that it is used for Usage +// events as well. class BufferDefinitionEvent { public: BufferDefinitionEvent() = default; @@ -65,51 +71,108 @@ class BufferDefinitionEvent { // called, blocks the calling thread until the event has been recorded. void WaitForEventOnStream(se::Stream* stream); + // Returns true if the event is known to have occurred by the tail of + // 'stream'. If RecordOnStream has not yet been called, blocks the calling + // thread until the event has been recorded. + bool DefinedOn(se::Stream* stream); + + // Returns true if the event is known by the host to have already occurred. If + // RecordOnStream has not yet been called, blocks the calling thread until the + // event has been recorded. + bool IsComplete(); + + // Compares the sequence numbers of two recorded events. It is illegal to call + // the comparison operators unless both events have been recorded. + inline bool operator<(const BufferDefinitionEvent& rhs) const { + return sequence_number() < rhs.sequence_number(); + } + inline bool operator>(const BufferDefinitionEvent& rhs) const { + return rhs < *this; + } + inline bool operator<=(const BufferDefinitionEvent& rhs) const { + return !(*this > rhs); + } + inline bool operator>=(const BufferDefinitionEvent& rhs) const { + return !(*this < rhs); + } + private: - bool EventHasBeenRecorded() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + bool EventHasBeenRecorded() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + uint64 sequence_number() const; // An event that is triggered when the content of one or more buffers is // ready. If this event is nullptr, it is assumed that the buffer's content is // always defined. EventPool::Handle event_; - absl::Mutex mu_; - + mutable absl::Mutex mu_; // A list of all streams for which the buffer's content is known to be defined // at the tail of the queue, i.e., for any newly enqueued command. absl::InlinedVector streams_defined_on_ TF_GUARDED_BY(mu_); }; -// Class that represents a node in a reference-counted DAG of device buffers. -// Unlike a ShapedBuffer, which owns none of its buffers, and -// ScopedShapedBuffer, which owns an entire buffer tree, the reference counting -// in a SharedDeviceBuffer DAG is done at the level of individual device -// buffers. Reference counting buffer individually is more convenient when -// manipulating on-device tuples where a tuple and its elements may have -// different lifetimes. +// Class that represents a tuple of device buffers. Like a ScopedShapedBuffer it +// owns all of the device memory in the tuple. It also tracks the definition and +// usage of the memory on streams, to allow for synchronized usage and deletion +// of memory under all of the allocation model semantics. class SharedDeviceBuffer { public: - // Converts a ScopedShapedBuffer into a Buffer tree. Takes ownership of the - // buffers of the shaped_buffer. + // Converts a ScopedShapedBuffer into a SharedDeviceBuffer. Takes ownership of + // the buffers of the shaped_buffer. static std::shared_ptr FromScopedShapedBuffer( ScopedShapedBuffer* shaped_buffer, absl::Span> definition_events); - // Makes a tuple buffer. Does not initialize the tuple table. - static StatusOr> MakeTuple( - std::vector> children, - const Shape& on_host_shape, TransferManager* transfer_manager, - se::DeviceMemoryAllocator* allocator, int device_ordinal, - absl::Span> - definition_events); + // Helper class to retain a "hold" on a SharedDeviceBuffer while it is being + // enqueued on a stream. If the enqueue completes successfully the hold + // should be released using a call to Convert. If the ScopedUsage is deleted + // without Convert being called, e.g., on error, the hold is dropped. + // Deletion of a buffer will block until all ScopedUsage objects referencing + // it are either deleted or have their Convert methods called. + class ScopedUsage { + public: + ScopedUsage() = default; + ~ScopedUsage(); + ScopedUsage(ScopedUsage&&) = default; + ScopedUsage(const ScopedUsage&) = delete; + ScopedUsage& operator=(const ScopedUsage&) = delete; - // Makes an uninitialized array buffer. - static StatusOr> MakeArray( - Shape on_device_shape, TransferManager* transfer_manager, - se::DeviceMemoryAllocator* allocator, int device_ordinal, - absl::Span> - definition_events); + ScopedUsage& Acquire(std::shared_ptr parent); + std::shared_ptr Release(); + void Transfer(std::shared_ptr parent); + + bool IsValid() { return parent_ != nullptr; } + SharedDeviceBuffer* operator->() const { return parent_.get(); } + const SharedDeviceBuffer& operator*() const { return *parent_; } + std::shared_ptr buffer_reference() const { + return parent_; + } + + // Converts the usage hold into a usage event. + // + // usage_stream: a stream that the buffer was used on. + // event: an event that has been recorded on usage_stream after + // the buffer was used. + // reference_held: true if and only if the caller has caused a memory + // reference to *this to stay live until after the host + // is sure that the usage (transfer or execution) has + // completed. + void Convert(se::Stream* usage_stream, + std::shared_ptr event, + bool reference_held); + + private: + std::shared_ptr parent_; + }; + + // Increments the count of external frameworks, e.g., Numpy, that the buffer + // is shared with. Operations that require exclusive access, such as update in + // place, will fail if any external references are held. + void AddExternalReference(); + + // Decrements the count of external frameworks that the buffer is shared with. + void DropExternalReference(); // Builds a ShapedBuffer view onto the buffers of 'tree'. We require but do // not verify that TransferManager::HostShapeToDeviceShape(on_host_shape) == @@ -118,9 +181,6 @@ class SharedDeviceBuffer { const Shape& on_device_shape, se::Platform* platform) const; - const std::vector>& children() const { - return children_; - } se::DeviceMemoryAllocator* allocator() const { return allocator_; } int device_ordinal() const { return device_ordinal_; } absl::InlinedVector& device_memory() { @@ -134,20 +194,57 @@ class SharedDeviceBuffer { return definition_events_; } - SharedDeviceBuffer() = default; + // Helper object to keep track of usage of the buffer on streams. + struct StreamAndEvent { + // A stream the buffer has been used on. + se::Stream* stream; + // An event that is later than the most recent usage of the buffer on + // stream. + std::shared_ptr event; + // True if and only if a reference to the buffer is kept live until after + // the host knows that event is complete. + bool reference_held; + }; + using StreamAndEventContainer = absl::InlinedVector; + // Returns the set of streams that the buffer was used on, and for each stream + // an event later than the last use of the buffer. After + // LockUseAndTransferUsageEvents is called it is illegal to use the buffer on + // any stream and, e.g. AddUsageHold will CHECK fail. + StreamAndEventContainer LockUseAndTransferUsageEvents(); + + SharedDeviceBuffer() + : in_use_(true), usage_holds_(0), external_references_(0) {} SharedDeviceBuffer(se::DeviceMemoryAllocator* allocator, int device_ordinal, absl::Span device_memory, - std::vector> children, absl::Span> definition_events, std::function on_delete_callback); - SharedDeviceBuffer(absl::Span device_memory, - std::vector> children, - absl::Span> - definition_events); ~SharedDeviceBuffer(); private: + friend class ScopedUsage; + + // Indicates that the buffer is going to be used on a stream. Deletion of + // the buffer will block until there are no remaining ScopedUsage objects. + void AddUsageHold(); + + // Indicates that a previous usage hold can be discarded, e.g., because of an + // error while an action was being enqueued on a stream. + void DropUsageHold(); + + // Indicates that a previous usage hold can be converted into a usage event. + // + // usage_stream: a stream that the buffer was used on. + // event: an event that has been recorded on usage_stream after the + // buffer was used. + // reference_held: true if and only if the caller has caused a memory + // reference to *this to stay live until after the host + // is sure that the usage (transfer or execution) has + // completed. + void ConvertUsageHold(se::Stream* usage_stream, + std::shared_ptr event, + bool reference_held); + // Are the buffers in device_memory_ owned? If so, which allocator and device // ordinal? May be nullptr, indicating the buffers are not owned. se::DeviceMemoryAllocator* allocator_; @@ -155,26 +252,38 @@ class SharedDeviceBuffer { // Each host-side buffer may have several buffers on-device. absl::InlinedVector device_memory_; - std::vector> children_; - // An event that is triggered when the content of one or more buffers is - // ready during multistream execution. May be nullptr, which is used in the + // Events that are triggered when the content of one or more buffers is ready + // during multistream execution. May be nullptr, which is used in the // single-stream execution case where events are not necessary for buffer - // event sequencing. + // event sequencing. All events must be triggered before the buffers can be + // used. absl::InlinedVector, 2> definition_events_; + absl::Mutex mu_; + // in_use_ starts out true, and is set to false when the buffer is released + // from its owning PyLocalBuffer. Once in_use_ is false, the buffer may no + // longer be used on any stream. + bool in_use_ TF_GUARDED_BY(mu_); + // Count of operations that are currently enqueuing the buffer onto a stream. + int usage_holds_ TF_GUARDED_BY(mu_); + // Set of streams that the buffer has ever been used on, see comment on + // StreamAndEvent. + StreamAndEventContainer usage_events_ TF_GUARDED_BY(mu_); + // Count of external frameworks that hold a reference to this buffer. + int external_references_ TF_GUARDED_BY(mu_); + // A callback to call when the SharedDeviceBuffer is about to be destroyed. std::function on_delete_callback_; }; -// Populates 'events' with the set of buffer definition events for all buffers -// in the buffer DAG rooted at 'buffer'. +// Populates 'events' with the set of buffer definition events for buffer. void GetDeviceBufferDefinitionEvents( const SharedDeviceBuffer& buffer, absl::flat_hash_set* events); -// Waits for all of the buffer definition events in a buffer DAG on 'stream'. +// Waits for all of the definition events in a buffer on 'stream'. void WaitForBufferDefinitionEventsOnStream(const SharedDeviceBuffer& buffer, se::Stream* stream); diff --git a/tensorflow/compiler/xla/python/shared_device_buffer_test.cc b/tensorflow/compiler/xla/python/shared_device_buffer_test.cc index 05842c52a0c..ddf02dcb2de 100644 --- a/tensorflow/compiler/xla/python/shared_device_buffer_test.cc +++ b/tensorflow/compiler/xla/python/shared_device_buffer_test.cc @@ -15,56 +15,37 @@ limitations under the License. #include "tensorflow/compiler/xla/python/shared_device_buffer.h" +#include + #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/stream_executor/device_memory_allocator.h" namespace xla { namespace { -TEST(SharedDeviceBufferTest, MakeArray) { - LocalClient* client = ClientLibrary::LocalClientOrDie(); - - Shape shape = ShapeUtil::MakeShape(F32, {3, 101, 4}); - TF_ASSERT_OK_AND_ASSIGN(auto buffer, - SharedDeviceBuffer::MakeArray( - shape, client->backend().transfer_manager(), - client->backend().memory_allocator(), 0, {})); - EXPECT_EQ(buffer->children().size(), 0); - EXPECT_EQ(buffer->device_ordinal(), 0); - EXPECT_EQ(buffer->allocator(), client->backend().memory_allocator()); - ASSERT_EQ(buffer->device_memory().size(), 1); - EXPECT_FALSE(buffer->device_memory()[0].is_null()); -} - -TEST(SharedDeviceBufferTest, MakeTuple) { - LocalClient* client = ClientLibrary::LocalClientOrDie(); - - Shape a_shape = ShapeUtil::MakeShape(F32, {3, 101, 4}); - Shape b_shape = ShapeUtil::MakeShape(S8, {77}); - Shape tuple_shape = ShapeUtil::MakeTupleShape({a_shape, b_shape}); - TF_ASSERT_OK_AND_ASSIGN(auto a_buffer, - SharedDeviceBuffer::MakeArray( - a_shape, client->backend().transfer_manager(), - client->backend().memory_allocator(), 0, {})); - TF_ASSERT_OK_AND_ASSIGN(auto b_buffer, - SharedDeviceBuffer::MakeArray( - b_shape, client->backend().transfer_manager(), - client->backend().memory_allocator(), 0, {})); - TF_ASSERT_OK_AND_ASSIGN(auto tuple_buffer, - SharedDeviceBuffer::MakeTuple( - {a_buffer, b_buffer}, tuple_shape, - client->backend().transfer_manager(), - client->backend().memory_allocator(), 0, {})); - ASSERT_EQ(tuple_buffer->children().size(), 2); - EXPECT_EQ(tuple_buffer->children()[0], a_buffer); - EXPECT_EQ(tuple_buffer->children()[1], b_buffer); - ASSERT_EQ(tuple_buffer->device_memory().size(), 1); - EXPECT_EQ(tuple_buffer->device_ordinal(), 0); - EXPECT_EQ(tuple_buffer->allocator(), client->backend().memory_allocator()); - EXPECT_FALSE(tuple_buffer->device_memory()[0].is_null()); +StatusOr> MakeArray(const Shape& shape, + LocalClient* client) { + std::vector device_buffers; + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + client->backend().transfer_manager()->HostShapeToDeviceShape(shape), + [&](const Shape& subshape, const ShapeIndex&) -> Status { + TF_ASSIGN_OR_RETURN( + se::OwningDeviceMemory device_memory, + client->backend().memory_allocator()->Allocate( + /*device_ordinal=*/0, + client->backend().transfer_manager()->GetByteSizeRequirement( + subshape))); + device_buffers.push_back(device_memory.Release()); + return Status::OK(); + })); + return std::make_shared( + client->backend().memory_allocator(), /*device_ordinal=*/0, + device_buffers, + absl::Span>(), nullptr); } TEST(SharedDeviceBufferTest, AsShapedBuffer) { @@ -72,56 +53,46 @@ TEST(SharedDeviceBufferTest, AsShapedBuffer) { Shape a_shape = ShapeUtil::MakeShape(F32, {3, 101, 4}); Shape b_shape = ShapeUtil::MakeShape(S8, {77}); - Shape ab_tuple_shape = ShapeUtil::MakeTupleShape({a_shape, b_shape}); Shape c_shape = ShapeUtil::MakeShape(S64, {}); - Shape abc_tuple_shape = ShapeUtil::MakeTupleShape({c_shape, ab_tuple_shape}); - TF_ASSERT_OK_AND_ASSIGN(auto a_buffer, - SharedDeviceBuffer::MakeArray( - a_shape, client->backend().transfer_manager(), - client->backend().memory_allocator(), 0, {})); - TF_ASSERT_OK_AND_ASSIGN(auto b_buffer, - SharedDeviceBuffer::MakeArray( - b_shape, client->backend().transfer_manager(), - client->backend().memory_allocator(), 0, {})); - TF_ASSERT_OK_AND_ASSIGN(auto ab_tuple_buffer, - SharedDeviceBuffer::MakeTuple( - {a_buffer, b_buffer}, ab_tuple_shape, - client->backend().transfer_manager(), - client->backend().memory_allocator(), 0, {})); - TF_ASSERT_OK_AND_ASSIGN(auto c_buffer, - SharedDeviceBuffer::MakeArray( - c_shape, client->backend().transfer_manager(), - client->backend().memory_allocator(), 0, {})); - TF_ASSERT_OK_AND_ASSIGN(auto abc_tuple_buffer, - SharedDeviceBuffer::MakeTuple( - {c_buffer, ab_tuple_buffer}, abc_tuple_shape, - client->backend().transfer_manager(), - client->backend().memory_allocator(), 0, {})); - Shape abc_tuple_device_shape = - client->backend().transfer_manager()->HostShapeToDeviceShape( - abc_tuple_shape); - - ShapedBuffer shaped_buffer = abc_tuple_buffer->AsShapedBuffer( - abc_tuple_shape, abc_tuple_device_shape, client->platform()); - EXPECT_EQ(shaped_buffer.on_host_shape(), abc_tuple_shape); - EXPECT_EQ(shaped_buffer.on_device_shape(), abc_tuple_device_shape); + TF_ASSERT_OK_AND_ASSIGN(auto a_buffer, MakeArray(a_shape, client)); + TF_ASSERT_OK_AND_ASSIGN(auto b_buffer, MakeArray(b_shape, client)); + TF_ASSERT_OK_AND_ASSIGN(auto c_buffer, MakeArray(c_shape, client)); ASSERT_EQ(a_buffer->device_memory().size(), 1); ASSERT_EQ(b_buffer->device_memory().size(), 1); ASSERT_EQ(c_buffer->device_memory().size(), 1); - ASSERT_EQ(ab_tuple_buffer->device_memory().size(), 1); - ASSERT_EQ(abc_tuple_buffer->device_memory().size(), 1); std::vector expected_buffer_sequence = { - abc_tuple_buffer->device_memory()[0], c_buffer->device_memory()[0], - ab_tuple_buffer->device_memory()[0], a_buffer->device_memory()[0], - b_buffer->device_memory()[0], - }; - auto it = shaped_buffer.buffers().begin(); + a_buffer->device_memory()[0], b_buffer->device_memory()[0], + c_buffer->device_memory()[0]}; + ShapedBuffer shaped_a = a_buffer->AsShapedBuffer( + a_shape, + client->backend().transfer_manager()->HostShapeToDeviceShape(a_shape), + client->platform()); + ShapedBuffer shaped_b = b_buffer->AsShapedBuffer( + b_shape, + client->backend().transfer_manager()->HostShapeToDeviceShape(b_shape), + client->platform()); + ShapedBuffer shaped_c = c_buffer->AsShapedBuffer( + c_shape, + client->backend().transfer_manager()->HostShapeToDeviceShape(c_shape), + client->platform()); auto expected_it = expected_buffer_sequence.begin(); - while (it != shaped_buffer.buffers().end()) { + for (auto it = shaped_a.buffers().begin(); it != shaped_a.buffers().end(); + ++it) { + ASSERT_TRUE(expected_it != expected_buffer_sequence.end()); + EXPECT_TRUE(expected_it->IsSameAs(it->second)); + ++expected_it; + } + for (auto it = shaped_b.buffers().begin(); it != shaped_b.buffers().end(); + ++it) { + ASSERT_TRUE(expected_it != expected_buffer_sequence.end()); + EXPECT_TRUE(expected_it->IsSameAs(it->second)); + ++expected_it; + } + for (auto it = shaped_c.buffers().begin(); it != shaped_c.buffers().end(); + ++it) { ASSERT_TRUE(expected_it != expected_buffer_sequence.end()); EXPECT_TRUE(expected_it->IsSameAs(it->second)); - ++it; ++expected_it; } EXPECT_TRUE(expected_it == expected_buffer_sequence.end()); @@ -140,17 +111,10 @@ TEST(SharedDeviceBufferTest, FromScopedShapedBuffer) { std::shared_ptr device_buffer = SharedDeviceBuffer::FromScopedShapedBuffer(&shaped_buffer, {}); - ASSERT_EQ(device_buffer->device_memory().size(), 1); - ASSERT_EQ(device_buffer->children().size(), 2); - - EXPECT_EQ(device_buffer->children()[0]->device_memory().size(), + EXPECT_EQ(device_buffer->device_memory().size(), ShapeUtil::SubshapeCount( client->backend().transfer_manager()->HostShapeToDeviceShape( - ShapeUtil::MakeShape(F32, {10, 3, 7})))); - EXPECT_EQ(device_buffer->children()[1]->device_memory().size(), - ShapeUtil::SubshapeCount( - client->backend().transfer_manager()->HostShapeToDeviceShape( - ShapeUtil::MakeShape(S64, {})))); + literal.shape()))); } } // namespace diff --git a/tensorflow/compiler/xla/python/tpu_driver/BUILD b/tensorflow/compiler/xla/python/tpu_driver/BUILD index 08da1c29832..4725becdedf 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/BUILD +++ b/tensorflow/compiler/xla/python/tpu_driver/BUILD @@ -1,4 +1,5 @@ load("//tensorflow/core/platform:build_config.bzl", "tf_proto_library_cc") +load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency") load( "//tensorflow/compiler/xla/python/tpu_driver:platform/external/tools.bzl", "external_deps", @@ -61,7 +62,6 @@ cc_library( hdrs = ["grpc_tpu_driver.h"], deps = [ ":tpu_driver", - "//tensorflow:grpc++", "//tensorflow/core/platform:logging", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:util", @@ -69,6 +69,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_proto_cc", "//tensorflow/compiler/xla/python/tpu_driver:tpu_service_proto_cc", "//tensorflow/compiler/xla/python/tpu_driver:tpu_driver_proto_cc", + tf_grpc_cc_dependency(), ] + external_deps(), alwayslink = 1, ) diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc index f062afc48a4..1089b3cc8e5 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc @@ -566,7 +566,7 @@ PyTpuExecutable::ExecuteResult PyTpuExecutable::ExecuteHelper( for (const auto& core_args : all_core_arguments) { for (const auto* handle : core_args) { - for (auto pending_event : handle->DeviceBuffer()->wait_for_use) { + for (const auto& pending_event : handle->DeviceBuffer()->wait_for_use) { ready_to_execute.push_back(pending_event.get()); } } diff --git a/tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.cc b/tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.cc index e01aab14108..7632f21d5b2 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.cc @@ -474,7 +474,7 @@ GrpcTpuStream::~GrpcTpuStream() { { // Mark all remaining events invalid. absl::MutexLock lock(&events_mutex_); - for (auto e : events_) { + for (const auto& e : events_) { if (!e.second.done) { LOG(ERROR) << "Resetting: " << e.first; UpdateEventStatus(e.first, xla::Status(tensorflow::error::Code::ABORTED, @@ -669,7 +669,7 @@ void GrpcTpuStream::StreamReaderFn() { StreamResponse resp; while (stream_->Read(&resp)) { VLOG(2) << "Received response: " << resp.DebugString(); - for (const StreamResponse::Entry entry : resp.entry()) { + for (const StreamResponse::Entry& entry : resp.entry()) { EventId event_id = EventId::FromInt(entry.operation_id()); VLOG(1) << "Received response for: " << event_id; diff --git a/tensorflow/compiler/xla/python/tpu_driver/recording_tpu_driver.cc b/tensorflow/compiler/xla/python/tpu_driver/recording_tpu_driver.cc index 655dbf67fea..da51380c104 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/recording_tpu_driver.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/recording_tpu_driver.cc @@ -522,7 +522,7 @@ xla::StatusOr> RegisterRecordingTpuDriver( std::string file; std::string worker; - for (auto config : configs) { + for (const auto& config : configs) { std::vector kv = absl::StrSplit(config, absl::MaxSplits('=', 1)); if (kv[0] == "file") { diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index 2affd4b30fa..1c94c4909e3 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -173,6 +173,7 @@ struct ExtraBufferInfo { int PyLocalBufferGetBuffer(PyObject* exporter, Py_buffer* view, int flags) { auto& buffer = py::reinterpret_borrow(exporter).cast(); + std::shared_ptr device_buffer; Status status = [&]() { // Py_buffer objects are POD C structures, so we don't need to hold the GIL. // Additionally we call BlockHostUntilReady() below, which may block. @@ -197,7 +198,7 @@ int PyLocalBufferGetBuffer(PyObject* exporter, Py_buffer* view, int flags) { if ((flags & PyBUF_WRITEABLE) == PyBUF_WRITEABLE) { return InvalidArgument("XLA buffers are read-only."); } - std::shared_ptr device_buffer = buffer.DeviceBuffer(); + device_buffer = buffer.GetBufferWithExternalReference(); if (!device_buffer) { return InvalidArgument("Deleted buffer used in buffer protocol."); } @@ -219,7 +220,7 @@ int PyLocalBufferGetBuffer(PyObject* exporter, Py_buffer* view, int flags) { view->buf = const_cast(device_buffer->device_memory().front().opaque()); auto extra = absl::make_unique(); - extra->device_buffer = std::move(device_buffer); + extra->device_buffer = device_buffer; view->itemsize = ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type()); view->len = ShapeUtil::ByteSizeOf(shape); view->readonly = 1; @@ -246,6 +247,9 @@ int PyLocalBufferGetBuffer(PyObject* exporter, Py_buffer* view, int flags) { return Status::OK(); }(); if (!status.ok()) { + if (device_buffer != nullptr) { + device_buffer->DropExternalReference(); + } PyErr_SetString(PyExc_BufferError, status.ToString().c_str()); return -1; } @@ -255,7 +259,9 @@ int PyLocalBufferGetBuffer(PyObject* exporter, Py_buffer* view, int flags) { } void PyLocalBufferReleaseBuffer(PyObject*, Py_buffer* buffer) { - delete static_cast(buffer->internal); + auto extra = static_cast(buffer->internal); + extra->device_buffer->DropExternalReference(); + delete extra; } PyBufferProcs PyLocalBufferProcs = []() { @@ -1009,9 +1015,7 @@ PYBIND11_MODULE(xla_extension, m) { }) .def("platform", &PyLocalBuffer::platform_name) .def("is_deleted", - [](const PyLocalBuffer& buffer) { - return buffer.DeviceBuffer() == nullptr; - }) + [](PyLocalBuffer* buffer) { return buffer->IsDeleted(); }) .def("unsafe_buffer_pointer", [](const PyLocalBuffer& buffer) -> StatusOr { TF_ASSIGN_OR_RETURN(ShapedBuffer shaped_buffer, @@ -1206,6 +1210,12 @@ PYBIND11_MODULE(xla_extension, m) { py::return_value_policy::reference, py::keep_alive<1, 0>()); py::class_(m, "XlaComputation") + .def(py::init([](const py::bytes& serialized_hlo_module_proto) + -> std::unique_ptr { + HloModuleProto proto; + proto.ParseFromString(serialized_hlo_module_proto); + return absl::make_unique(proto); + })) .def("GetProgramShape", &XlaComputation::GetProgramShape) .def("GetSerializedProto", &GetComputationSerializedProto) .def("GetHloText", &GetComputationHloText) diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 36d5da2841b..95b760965d8 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -381,6 +381,21 @@ class ComputationsWithConstantsTest(ComputationTest): self._ExecuteAndCompareClose(c, expected=[0.75]) +class ComputationFromProtoTest(absltest.TestCase): + """Test computation execution from HLO proto.""" + + def testExecuteFromProto(self): + # Build the HLO proto + b = xla_client.ComputationBuilder("computation") + b.Add(b.Constant(np.int8(1)), b.Constant(np.int8(2))) + serialized_proto = b.Build().GetSerializedProto() + + # Load and execute the proto + c = xla_client.Computation(xla_client._xla.XlaComputation(serialized_proto)) + ans, = xla_client.execute_with_python_values(c.Compile()) + np.testing.assert_equal(ans, np.int8(3)) + + class ParametersTest(ComputationTest): """Tests focusing on Parameter ops and argument-passing.""" diff --git a/tensorflow/compiler/xla/rpc/BUILD b/tensorflow/compiler/xla/rpc/BUILD index d288e0c181f..39fa6a1c267 100644 --- a/tensorflow/compiler/xla/rpc/BUILD +++ b/tensorflow/compiler/xla/rpc/BUILD @@ -1,5 +1,9 @@ -load("//tensorflow:tensorflow.bzl", "tf_cc_test") -load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency") +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_binary", + "tf_cc_test", +) load( "//tensorflow/core/platform:build_config.bzl", "tf_proto_library_cc", @@ -45,12 +49,12 @@ cc_library( srcs = ["grpc_service_main.cc"], deps = [ ":grpc_service", - "//tensorflow:grpc++", "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "@com_google_absl//absl/strings:str_format", + tf_grpc_cc_dependency(), ], ) @@ -70,7 +74,6 @@ tf_cc_test( ], deps = [ ":grpc_stub", - "//tensorflow:grpc++", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -79,6 +82,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "@com_google_absl//absl/strings:str_format", + tf_grpc_cc_dependency(), ], ) @@ -88,9 +92,9 @@ cc_library( hdrs = ["grpc_service.h"], deps = [ ":xla_service_proto_cc", - "//tensorflow:grpc++", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/core/distributed_runtime/rpc:grpc_util", + tf_grpc_cc_dependency(), ], ) diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 5faf58f0c22..255444fb53c 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -165,6 +165,7 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", ], diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 1f36d906e73..c8da3d3ccbe 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -4128,7 +4128,9 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { return ReplaceInstruction(transpose, operand); } - if (options_.is_layout_sensitive() && TransposeIsBitcast(transpose)) { + if (options_.is_layout_sensitive() && + options_.replace_transpose_with_bitcast() && + TransposeIsBitcast(transpose)) { ReplaceWithBitcast(transpose); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index 4251e7eb846..d3c276e9bc3 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -113,6 +113,14 @@ class AlgebraicSimplifierOptions { bool enable_reduce_of_reshape() const { return enable_reduce_of_reshape_; } + void set_replace_transpose_with_bitcast(bool replace_transpose_with_bitcast) { + replace_transpose_with_bitcast_ = replace_transpose_with_bitcast; + } + + bool replace_transpose_with_bitcast() const { + return replace_transpose_with_bitcast_; + } + private: // Metadata struct can be used to store any metadata information encapsulated // with the AlgebraicSimplierOptions that can be later used in an @@ -133,6 +141,7 @@ class AlgebraicSimplifierOptions { bool enable_conv_simplification_{true}; bool enable_window_reduce_to_reduce_replacement_{true}; bool enable_reduce_of_reshape_{true}; + bool replace_transpose_with_bitcast_{true}; int64 very_small_gather_size_{4}; Metadata metadata_; }; diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 31fa125b3e1..255edf78345 100755 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -2437,7 +2437,7 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) { AlgebraicSimplifier simplifier(options); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - // Verify that the reshape is replaced. + // Verify that the transpose is replaced. EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Bitcast(m::Parameter(0)))); } @@ -2464,10 +2464,17 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) { AlgebraicSimplifierOptions options; options.set_is_layout_sensitive(true); + // Don't replace transposes with bitcasts. + options.set_replace_transpose_with_bitcast(false); + AlgebraicSimplifier simplifier_no_replace(options); + ASSERT_FALSE(simplifier_no_replace.Run(m.get()).ValueOrDie()); + + // Replace transposes with bitcasts if possible. + options.set_replace_transpose_with_bitcast(true); AlgebraicSimplifier simplifier(options); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - // Verify that the reshape is replaced. + // Verify that the transpose is replaced. EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Bitcast(m::Parameter(0)))); } diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index 6331f02aa81..a0fe0eaa1d9 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/bfloat16_propagation.h" +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" @@ -22,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -203,6 +205,33 @@ void BFloat16Propagation::DetermineWhileComputationsPrecision( computations_visited_in_backward_pass_.insert(condition); } +void BFloat16Propagation::DetermineConditionalComputationsPrecision( + HloInstruction* cond) { + CHECK_EQ(cond->opcode(), HloOpcode::kConditional); + for (int64 i = 0; i < cond->branch_count(); ++i) { + auto branch = cond->branch_computation(i); + auto root = branch->root_instruction(); + ShapeUtil::ForEachSubshape( + root->shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (subshape.element_type() != F32) { + return; + } + if (OutputTypeAfterChange(cond, index) == BF16) { + AddToOrRemoveFromBF16ChangeSet(root, index, BF16); + VLOG(2) << "Conditional branch " << i << " root " + << root->ToString() << " at shape index " << index + << " changed to BF16 precision for conditional " + << cond->ToString(); + } + }); + auto insts = branch->MakeInstructionPostOrder(); + for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { + DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false); + } + computations_visited_in_backward_pass_.insert(branch); + } +} + bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, const ShapeIndex& index) const { // If the subshape isn't floating point then none of the users will be BF16. @@ -265,6 +294,14 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, return false; } continue; + } else if (use.instruction->opcode() == HloOpcode::kConditional) { + auto* cond_parameter = + use.instruction->branch_computation(use.operand_number - 1) + ->parameter_instruction(0); + if (OutputTypeAfterChange(cond_parameter, use.operand_index) != BF16) { + return false; + } + continue; } if (bfloat16_support_->EffectiveOperandPrecisionIsBF16( *use.instruction, use.operand_number)) { @@ -323,7 +360,6 @@ bool ShouldKeepPrecisionUnchanged(const HloInstruction* inst) { // assumptions for them. return inst->opcode() == HloOpcode::kCustomCall || // inst->opcode() == HloOpcode::kCall || // - inst->opcode() == HloOpcode::kConditional || // inst->opcode() == HloOpcode::kBitcastConvert || // inst->HasSideEffectNoRecurse(); } @@ -332,9 +368,10 @@ bool ShouldKeepPrecisionUnchanged(const HloInstruction* inst) { void BFloat16Propagation::DetermineInstructionPrecision(HloInstruction* hlo, bool skip_parameters) { - // We handle any fusion computation or while body/condition after the - // instruction is handled, because we need to know the output shape of a - // fusion or while before propagating inside its computations. + // We handle any fusion computation, while body/condition or conditional + // branches after the instruction is handled, because we need to know the + // output shape of a fusion or while before propagating inside its + // computations. bool postpone_processing_called_computations = false; auto cleaner = tensorflow::gtl::MakeCleanup( [this, hlo, &postpone_processing_called_computations] { @@ -343,6 +380,8 @@ void BFloat16Propagation::DetermineInstructionPrecision(HloInstruction* hlo, DetermineFusionComputationPrecision(hlo); } else if (hlo->opcode() == HloOpcode::kWhile) { DetermineWhileComputationsPrecision(hlo); + } else if (hlo->opcode() == HloOpcode::kConditional) { + DetermineConditionalComputationsPrecision(hlo); } } instructions_visited_in_backward_pass_.insert(hlo); @@ -355,6 +394,14 @@ void BFloat16Propagation::DetermineInstructionPrecision(HloInstruction* hlo, return; } + if (hlo->opcode() == HloOpcode::kConditional && + absl::c_any_of(hlo->branch_computations(), [&](const HloComputation* c) { + return caller_counts_[c] > 1; + })) { + postpone_processing_called_computations = true; + return; + } + // Prevent root instructions from having their output modified by recording // all F32 output values as needing to stay as F32. CHECK(hlo->parent() != nullptr); @@ -459,6 +506,12 @@ void BFloat16Propagation::AdjustCalledComputationParameters( adjust_computation(hlo->while_condition(), hlo->operands()); adjust_computation(hlo->while_body(), hlo->operands()); break; + case HloOpcode::kConditional: + for (int64 i = 0; i < hlo->branch_count(); ++i) { + adjust_computation(hlo->branch_computation(i), + {hlo->mutable_operand(i + 1)}); + } + break; default: break; } @@ -509,6 +562,11 @@ void BFloat16Propagation::AdjustCalledComputationRoot(HloInstruction* hlo) { case HloOpcode::kWhile: adjust_computation(hlo->while_body(), hlo); break; + case HloOpcode::kConditional: + for (auto* branch : hlo->branch_computations()) { + adjust_computation(branch, hlo); + } + break; default: break; } @@ -590,6 +648,11 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( } else if (hlo->opcode() == HloOpcode::kFusion) { ResolveInconsistencyOfAliasingBuffersHelper( hlo->fused_instructions_computation(), visited_computations); + } else if (hlo->opcode() == HloOpcode::kConditional) { + for (auto* branch : hlo->branch_computations()) { + ResolveInconsistencyOfAliasingBuffersHelper(branch, + visited_computations); + } } } // Now adjust parameters of called computations. diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.h b/tensorflow/compiler/xla/service/bfloat16_propagation.h index 5fcaa15c835..200599efab2 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.h +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.h @@ -110,6 +110,11 @@ class BFloat16Propagation : public HloModulePass { // Precondition: hlo->opcode() == kWhile void DetermineWhileComputationsPrecision(HloInstruction* while_hlo); + // Special handling in the opportunity-finding pass for conditional branches. + // + // Precondition: hlo->opcode() == kConditional + void DetermineConditionalComputationsPrecision(HloInstruction* cond); + // The set of HloInstructions that have been visited in the // opportunity-finding pass. absl::flat_hash_set diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index 048c0edc4a5..02d79025f1b 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -1046,4 +1046,114 @@ TEST_F(BFloat16PropagationTest, TupleDomainNoPropagation) { EXPECT_FALSE(OutputsBF16(param)); } +TEST_F(BFloat16PropagationTest, ConditionalSeparateBranchOperands) { + const string module_str = R"( +HloModule module + +true_branch { + true_param = f32[4096,4096] parameter(0) + ROOT max = f32[4096,4096] maximum(true_param, true_param) +} + +false_branch { + false_param = f32[4096,4096] parameter(0) + ROOT add = f32[4096,4096] add(false_param, false_param) +} + +ENTRY entry { + param0 = f32[4096,4096] parameter(0) + param1 = f32[4096,4096] parameter(1) + copy0 = f32[4096,4096] copy(param0) + copy1 = f32[4096,4096] copy(param1) + param2 = pred[] parameter(2) + conditional = f32[4096,4096] conditional(param2, copy0, copy1), + true_computation=true_branch, false_computation=false_branch + ROOT dot = f32[4096,4096] dot(conditional, conditional), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + EXPECT_TRUE(PropagatePrecision(module.get())); + + auto cond = FindInstruction(module.get(), "conditional"); + auto copy0 = FindInstruction(module.get(), "copy0"); + auto copy1 = FindInstruction(module.get(), "copy1"); + EXPECT_TRUE(OutputsBF16(cond)); + EXPECT_TRUE(OutputsBF16(copy0)); + EXPECT_FALSE(OutputsBF16(copy1)); +} + +TEST_F(BFloat16PropagationTest, ConditionalSharedBranchOperands) { + const string module_str = R"( +HloModule module + +true_branch { + true_param = f32[4096,4096] parameter(0) + ROOT max = f32[4096,4096] maximum(true_param, true_param) +} + +false_branch { + false_param = f32[4096,4096] parameter(0) + ROOT add = f32[4096,4096] add(false_param, false_param) +} + +ENTRY entry { + param0 = f32[4096,4096] parameter(0) + copy0 = f32[4096,4096] copy(param0) + param1 = pred[] parameter(1) + conditional = f32[4096,4096] conditional(param1, copy0, copy0), + true_computation=true_branch, false_computation=false_branch + ROOT dot = f32[4096,4096] dot(conditional, conditional), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + EXPECT_TRUE(PropagatePrecision(module.get())); + + auto cond = FindInstruction(module.get(), "conditional"); + auto copy0 = FindInstruction(module.get(), "copy0"); + EXPECT_TRUE(OutputsBF16(cond)); + EXPECT_FALSE(OutputsBF16(copy0)); +} + +TEST_F(BFloat16PropagationTest, ConditionalAliasingOutputs) { + const string module_str = R"( +HloModule module + +true_branch { + true_param = f32[4096,4096] parameter(0) + max = f32[4096,4096] maximum(true_param, true_param) + ROOT true_tuple = (f32[4096,4096], f32[4096,4096]) tuple(max, max) +} + +false_branch { + false_param = f32[4096,4096] parameter(0) + min = f32[4096,4096] minimum(false_param, false_param) + max2 = f32[4096,4096] maximum(false_param, false_param) + ROOT false_tuple = (f32[4096,4096], f32[4096,4096]) tuple(min, max2) +} + +ENTRY entry { + param0 = f32[4096,4096] parameter(0) + copy0 = f32[4096,4096] copy(param0) + param1 = pred[] parameter(1) + conditional = (f32[4096,4096], f32[4096,4096]) conditional(param1, copy0, copy0), + true_computation=true_branch, false_computation=false_branch + gte0 = f32[4096,4096] get-tuple-element(conditional), index=0 + gte1 = f32[4096,4096] get-tuple-element(conditional), index=1 + dot = f32[4096,4096] dot(gte0, gte1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT tuple = (f32[4096,4096], f32[4096,4096]) tuple(dot, gte1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + EXPECT_FALSE(PropagatePrecision(module.get())); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 8fbe29f417c..53d0d14f598 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -277,8 +277,8 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( { auto& pass = pipeline.AddPass>("simplification"); - pass.AddInvariantChecker(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/false); + pass.AddInvariantCheckerDebug(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); pass.AddPass(); pass.AddPass(); @@ -339,18 +339,18 @@ Status CpuCompiler::RunHloPassesAfterLayoutAssn( LLVMTargetMachineFeatures* target_machine_features) { HloPassPipeline pipeline("HLO passes after layout assignment"); // After layout assignment, use a layout-sensitive verifier. - auto& after_layout_assn = - pipeline.AddPass("after layout assignment"); - after_layout_assn.AddInvariantChecker( - /*layout_sensitive=*/true, - /*allow_mixed_precision=*/false); + + pipeline.AddPass("after layout assignment") + .AddInvariantCheckerDebug( + /*layout_sensitive=*/true, + /*allow_mixed_precision=*/false); // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. { auto& pass = pipeline.AddPass>( "simplification after layout assignment"); - pass.AddInvariantChecker( + pass.AddInvariantCheckerDebug( /*layout_sensitive=*/true, /*allow_mixed_precision=*/false, LayoutAssignment::InstructionCanChangeLayout); diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc index 94815e2fdbc..5f5d02f92b6 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc @@ -107,6 +107,8 @@ class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault { Status HandleScatter(HloInstruction* hlo) override; + Status HandleDomain(HloInstruction* hlo) override; + private: using DimensionConstraint = DynamicDimensionInference::DimensionConstraint; using OperandDynamicDimensionFn = std::functioncustom_call_target() != "SliceToDynamic" || + if ((hlo->custom_call_target() != "SliceToDynamic" && + hlo->custom_call_target() != "Sharding") || absl::StartsWith(hlo->custom_call_target(), "Resize")) { return Unimplemented( "CustomCall is not supported to have a dynamic dimension"); @@ -577,6 +580,10 @@ Status DynamicDimensionInferenceVisitor::PassThroughDynamicDimension( }); } +Status DynamicDimensionInferenceVisitor::HandleDomain(HloInstruction* hlo) { + return PassThroughDynamicDimension(hlo); +} + Status DynamicDimensionInferenceVisitor::HandleElementwiseUnary( HloInstruction* hlo) { return PassThroughDynamicDimension(hlo); diff --git a/tensorflow/compiler/xla/service/dynamic_padder.cc b/tensorflow/compiler/xla/service/dynamic_padder.cc index 88e8cfd38ff..0b176031e8d 100644 --- a/tensorflow/compiler/xla/service/dynamic_padder.cc +++ b/tensorflow/compiler/xla/service/dynamic_padder.cc @@ -112,6 +112,7 @@ StatusOr ChooseIdentityValue(HloInstruction* inst, case HloOpcode::kTranspose: case HloOpcode::kSort: case HloOpcode::kSlice: + case HloOpcode::kDomain: return nullptr; // Assume that custom calls created by the client are valid with padded // dynamic dimensions. diff --git a/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.cc b/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.cc index aba3032df45..75d39298aa3 100644 --- a/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.cc +++ b/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.cc @@ -122,8 +122,11 @@ void FusionNodeIndexingEvaluation::UpdateIndexingUsersOfOperands( operand = fusion_->operand(operand->parameter_number()); } // For simplicity we assume that all shape and layout changing - // operations invalidate index reuse. - if (Shape::Equal().IgnoreElementType()(operand->shape(), + // operations except Transposes invalidate index reuse. Transposes are + // special: although they are shape changing, we can reuse the + // multi-dimensional index for the operand by permuting it. + if (instruction->opcode() == HloOpcode::kTranspose || + Shape::Equal().IgnoreElementType()(operand->shape(), instruction->shape())) { // If the index is reused, it means the operand gets index values // from the same set of (indirect) users as 'instruction' itself. diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 0877ac2cfc7..61bc41283e1 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -1206,7 +1206,6 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:llvm_compiler", - "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:rng_bit_generator_expander", "//tensorflow/compiler/xla/service:rng_expander", "//tensorflow/compiler/xla/service:slice_sinker", diff --git a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc index b78748edb7e..974db02b1b3 100644 --- a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc @@ -76,8 +76,9 @@ Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization( // Convert convolutions into CustomCalls to MIOpen, then canonicalize them // (PadInsertion). HloPassPipeline pipeline("conv_canonicalization"); - pipeline.AddInvariantChecker(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/false); + pipeline.AddInvariantCheckerDebug( + /*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); pipeline.AddPass(); pipeline.AddPass(); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 87054d8322a..3225d47e531 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -86,7 +86,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" -#include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/rng_bit_generator_expander.h" #include "tensorflow/compiler/xla/service/rng_expander.h" #include "tensorflow/compiler/xla/service/slice_sinker.h" @@ -178,8 +177,9 @@ Status GpuCompiler::OptimizeHloModule( { auto& pass = pipeline.AddPass>("simplification"); - pass.AddInvariantChecker(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/false); + pass.AddInvariantCheckerDebug( + /*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); // If cudnn batchnorms are enabled, rewrite batchnorm HLOs to cudnn calls // where possible. Not every batchnorm op can be implemented as a call to @@ -205,6 +205,15 @@ Status GpuCompiler::OptimizeHloModule( pipeline.AddPass(); AlgebraicSimplifierOptions options; + // When transposes appear in a fusion node, we can easily adjust the + // multi-dimensional index to create the one needed for the operand. This + // is not as easy with bitcasts, because we don't have the information + // readily available which dimensions are permuted. In addition to that, + // if we have a transpose and a reshape next to each other, they will both + // be replaced by a bitcast, and we replace bitcast(bitcast) with one + // bitcast. This leads to having to linearize and then delinearize the + // index. + options.set_replace_transpose_with_bitcast(false); pass.AddPass(options); // AlgebraicSimplifier may add contracting dimensions to a dot. pass.AddPass(); @@ -217,7 +226,6 @@ Status GpuCompiler::OptimizeHloModule( // pass.AddPass(); pass.AddPass(); - pass.AddPass(); pass.AddPass(); pass.AddPass(); } @@ -279,7 +287,7 @@ Status GpuCompiler::OptimizeHloModule( fusion.AddPass(); /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after * fixing the ticket. */ - fusion.AddInvariantChecker( + fusion.AddInvariantCheckerDebug( /*layout_sensitive=*/true, /*allow_mixed_precision=*/false, LayoutAssignment::InstructionCanChangeLayout); @@ -306,6 +314,13 @@ Status GpuCompiler::OptimizeHloModule( /*combine_threshold_count=*/256); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } + { + // Now we allow to replace any transposes outside of fusions with bitcasts. + HloPassPipeline pipeline("final_algebraic_simplifier"); + AlgebraicSimplifierOptions options; + options.set_is_layout_sensitive(true); + pipeline.AddPass(options); + } return Status::OK(); } @@ -320,7 +335,7 @@ Status GpuCompiler::PrepareHloModuleForIrEmitting(HloModule* hlo_module) { HloPassPipeline pipeline("GPU-ir-emit-prepare"); /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after * fixing the ticket. */ - pipeline.AddInvariantChecker( + pipeline.AddInvariantCheckerDebug( /*layout_sensitive=*/true, /*allow_mixed_precision=*/false, LayoutAssignment::InstructionCanChangeLayout); @@ -359,7 +374,7 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment( HloPassPipeline pipeline("post-layout_assignment"); /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after * fixing the ticket. */ - pipeline.AddInvariantChecker( + pipeline.AddInvariantCheckerDebug( /*layout_sensitive=*/true, /*allow_mixed_precision=*/false, LayoutAssignment::InstructionCanChangeLayout); @@ -372,6 +387,15 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment( // duplicate or NOPs, so remove them with algebraic simplification and CSE. AlgebraicSimplifierOptions options; options.set_is_layout_sensitive(true); + // When transposes appear in a fusion node, we can easily adjust the + // multi-dimensional index to create the one needed for the operand. This + // is not as easy with bitcasts, because we don't have the information + // readily available which dimensions are permuted. In addition to that, + // if we have a transpose and a reshape next to each other, they will both + // be replaced by a bitcast, and we replace bitcast(bitcast) with one + // bitcast. This leads to having to linearize and then delinearize the + // index. + options.set_replace_transpose_with_bitcast(false); pipeline.AddPass>(options); if (RequireDeterminism() || diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 528a847b3ed..6c5fe891360 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -1363,7 +1363,7 @@ GetHloBufferSlices(const HloInstruction* hlo, // appear before any GTE instructions, because it's illegal to bitcast to a // tuple type. const HloInstruction* parent = instr; - while (parent->opcode() == HloOpcode::kBitcast) { + while (parent->IsEffectiveBitcast()) { parent = parent->operand(0); auto slice = buffer_assn.GetUniqueSlice(parent, {}); diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index 4f46e292210..4eb24375eb7 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -109,8 +109,9 @@ Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization( // Convert convolutions into CustomCalls to cudnn, then canonicalize them // (GpuConvPaddingLegalization). Also expand cuSolver calls. HloPassPipeline pipeline("conv_canonicalization"); - pipeline.AddInvariantChecker(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/false); + pipeline.AddInvariantCheckerDebug( + /*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); @@ -127,10 +128,19 @@ Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization( { auto& pass = pipeline.AddPass>( "algebraic_simplification_post_conv_rewriter"); - pass.AddInvariantChecker(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/false); + pass.AddInvariantCheckerDebug(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); AlgebraicSimplifierOptions options; + // When transposes appear in a fusion node, we can easily adjust the + // multi-dimensional index to create the one needed for the operand. This + // is not as easy with bitcasts, because we don't have the information + // readily available which dimensions are permuted. In addition to that, + // if we have a transpose and a reshape next to each other, they will both + // be replaced by a bitcast, and we replace bitcast(bitcast) with one + // bitcast. This leads to having to linearize and then delinearize the + // index. + options.set_replace_transpose_with_bitcast(false); options.set_cudnn_batchnorm_forward_training_metadata( kCudnnBatchNormForwardTrainingCallTarget); pass.AddPass(options); diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 2e089f34bac..94a4df43cf4 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -498,7 +498,10 @@ Status HloCostAnalysis::HandleBatchNormGrad(const HloInstruction*) { return Status::OK(); } -Status HloCostAnalysis::HandleTranspose(const HloInstruction*) { +Status HloCostAnalysis::HandleTranspose(const HloInstruction* transpose) { + if (transpose->IsEffectiveBitcast()) { + return HandleBitcast(transpose); + } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 484ed3eaa6c..22b74663087 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -2160,6 +2160,13 @@ Status HloInstruction::ReplaceAllUsesWithDifferentShape( return Status::OK(); } +bool HloInstruction::IsEffectiveBitcast() const { + return opcode_ == HloOpcode::kBitcast || + (opcode_ == HloOpcode::kTranspose && + ShapeUtil::TransposeIsBitcast(operand(0)->shape(), shape(), + dimensions())); +} + HloComputation* HloInstruction::to_apply() const { switch (opcode_) { case HloOpcode::kCall: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index fdeea10c496..98f2a20d505 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -1233,6 +1233,11 @@ class HloInstruction { const_cast(this)->LatestNonGteAncestor()); } + // Returns true whether this instruction is effectively a bitcast. Currently, + // this means it either is a bitcast, or it is a transpose that is effectively + // a bitcast. + bool IsEffectiveBitcast() const; + // Gets/sets the to_apply HloComputation for Call, Map, Reduce, etc. // The setter should only be called by HloModule or HloComputation methods. // diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h index 16fad113b0d..72549aaa681 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h @@ -70,6 +70,14 @@ class HloPassPipeline : public HloPassInterface { return *pass; } + // Add an invariant-checking pass to the pipeline on debug builds only. + template + void AddInvariantCheckerDebug(Args&&... args) { +#ifndef NDEBUG + AddInvariantChecker(std::forward(args)...); +#endif // NDEBUG + } + StatusOr Run(HloModule* module) override; StatusOr RunOnModuleGroup(HloModuleGroup* module_group) override; diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index 5045d7b0c13..7fbd01e1b21 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -257,8 +257,11 @@ bool FusedIrEmitter::IsFusedIrEmitterInefficient( } for (const auto* operand : instruction->operands()) { // For simplicity we assume that all shape and layout changing - // operations invalidate index reuse. - if (Shape::Equal().IgnoreElementType()(operand->shape(), + // operations except Transposes invalidate index reuse. Transposes are + // special: although they are shape changing, we can reuse the + // multi-dimensional index for the operand by permuting it. + if (instruction->opcode() == HloOpcode::kTranspose || + Shape::Equal().IgnoreElementType()(operand->shape(), instruction->shape())) { // If the index is reused, it means the operand gets index values // from the same set of (indirect) users as 'instruction' itself. diff --git a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc index c12418a0c49..c593938127a 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc @@ -186,7 +186,7 @@ Status HloDialectEmitter::HandleReduce(HloInstruction* instr) { reduceOp.body().push_back(block); HloDialectEmitter emitter(emission_context_, &reduceOp.body(), arguments); TF_ASSIGN_OR_RETURN(auto result, emitter.EmitComputation(*computation)); - OpBuilder body_builder(block); + OpBuilder body_builder = OpBuilder::atBlockEnd(block); body_builder.setInsertionPointToEnd(block); body_builder.create(getLocation(instr), ArrayRef{result}); diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 3255aa84685..4f9e3a4d083 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -2538,6 +2538,7 @@ xla_test( tags = [ "enable_for_xla_interpreter", "noasan", # sometimes times out, http://b/78650012 + "notsan", # sometimes times out, http://b/78650012 ], deps = [ ":test_macros_header", diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index 9a19427a96a..0fd5f191db0 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -64,6 +64,30 @@ ENTRY main { RunTest(hlo_text, &operand, &start_indices); } +XLA_TEST_F(GatherOperationTest, BatchDimInMiddle) { + // Reverse the middle dimension (dim 1). + const string hlo_text = R"( +HloModule BatchDimInMiddle + +ENTRY main { + operand = s32[3, 2, 3] parameter(0) + indices = s32[2] parameter(1) + ROOT gather = s32[3, 1, 2, 3] gather(operand, indices), + offset_dims={0, 1, 3}, + collapsed_slice_dims={}, + start_index_map={1}, + index_vector_dim=1, + slice_sizes={3, 1, 3} +} +)"; + Literal operand = + LiteralUtil::CreateR3({{{1, 2, 3}, {4, 5, 6}}, + {{7, 8, 9}, {10, 11, 12}}, + {{13, 14, 15}, {16, 17, 18}}}); + Literal start_indices = LiteralUtil::CreateR1({1, 0}); + RunTest(hlo_text, &operand, &start_indices); +} + XLA_TEST_F(GatherOperationTest, TensorFlowGatherV2) { const string hlo_text = R"( HloModule TensorFlowGatherV2 diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 3468c12d8c9..f8bd7a0750e 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -265,7 +265,11 @@ message DebugOptions { // Guarantee run-to-run determinism from reductions on XLA:GPU. bool xla_gpu_deterministic_reductions = 130; - // Next id: 135 + // Debug options that trigger execution errors when NaN or Inf are detected. + bool xla_tpu_detect_nan = 135; + bool xla_tpu_detect_inf = 136; + + // Next id: 137 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 44d68782458..7d62274e87f 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -68,6 +68,7 @@ load( "cc_header_only_library", "if_android", "if_chromiumos", + "if_cuda_or_rocm", "if_ios", "if_mobile", "if_not_windows", @@ -81,7 +82,6 @@ load( "tf_features_nomodules_if_mobile", "tf_gen_op_libs", "tf_genrule_cmd_append_to_srcs", - "tf_openmp_copts", "tf_opts_nortti_if_lite_protos", "tf_opts_nortti_if_mobile", "tf_portable_full_lite_protos", @@ -100,9 +100,6 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu") # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_cc_tests_gpu") -# buildifier: disable=same-origin-load -load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") - # buildifier: disable=same-origin-load # Placeholder: load("//tensorflow:tensorflow.bzl", "tf_portable_proto_lib") @@ -113,7 +110,6 @@ load("//tensorflow:tensorflow.bzl", "tf_monitoring_deps") load( "//tensorflow/core/platform:build_config.bzl", "tf_additional_all_protos", - "tf_additional_core_deps", "tf_additional_lib_deps", "tf_additional_test_deps", "tf_jspb_proto_library", @@ -122,9 +118,7 @@ load( "tf_portable_deps_no_runtime", "tf_proto_library", "tf_proto_library_cc", - "tf_protos_all", "tf_protos_all_impl", - "tf_protos_grappler", "tf_protos_grappler_impl", "tf_protos_profiler_impl", "tf_pyclif_proto_library", @@ -173,7 +167,10 @@ exports_files([ "ops/ops.pbtxt", ]) -package_group(name = "experimental_access") +package_group( + name = "experimental_access", + packages = ["//tensorflow/core/common_runtime/..."], +) # Authorized users go here. package_group(name = "friends") @@ -977,33 +974,17 @@ cc_library( alwayslink = 1, ) -tf_cuda_library( +alias( name = "core_cpu", - hdrs = [ - "common_runtime/device.h", - "common_runtime/device_factory.h", - "common_runtime/function.h", - "common_runtime/function_optimization_registry.h", - "common_runtime/optimization_registry.h", - "common_runtime/shape_refiner.h", - "//tensorflow/core/graph:core_cpu_headers", - "//tensorflow/core/public:session.h", - "//tensorflow/core/public:session_options.h", - ], + actual = "//tensorflow/core/common_runtime:core_cpu", visibility = ["//visibility:public"], - deps = [ - ":core_cpu_internal", - ], ) -cc_library( +alias( name = "core", + actual = + "//tensorflow/core/common_runtime:core", visibility = ["//visibility:public"], - deps = [ - ":core_cpu", - ":gpu_runtime", - ":sycl_runtime", - ], ) # This includes implementations of all kernels built into TensorFlow. @@ -1096,8 +1077,9 @@ cc_library( "//tensorflow/core/kernels:mkl_matmul_op", "//tensorflow/core/kernels:mkl_tfconv_op", "//tensorflow/core/kernels:mkl_tmp_bf16_ops", - ]) + if_cuda([ + ]) + if_cuda_or_rocm([ "//tensorflow/core/kernels:cudnn_rnn_kernels", + ]) + if_cuda([ "//tensorflow/core/grappler/optimizers:gpu_swapping_kernels", "//tensorflow/core/grappler/optimizers:gpu_swapping_ops", ]) + if_nccl([ @@ -1176,14 +1158,11 @@ cc_library( name = "testlib", testonly = 1, srcs = [ - "common_runtime/function_testlib.cc", - "common_runtime/kernel_benchmark_testlib.cc", + "//tensorflow/core/common_runtime:testlib_srcs", "//tensorflow/core/graph:testlib_srcs", ], hdrs = [ - "common_runtime/function_testlib.h", - "common_runtime/kernel_benchmark_testlib.h", - "common_runtime/test_collective_executor_mgr.h", + "//tensorflow/core/common_runtime:testlib_headers", "//tensorflow/core/graph:testlib_headers", # TODO(josh11b): Drop this once users are depending on # kernels:ops_testutil instead. @@ -1221,16 +1200,11 @@ cc_library( ], ) -cc_library( +alias( name = "testlib_ops", testonly = 1, - srcs = ["common_runtime/testlib_ops.cc"], - linkstatic = 1, # Seems to be needed since alwayslink is broken in bazel - deps = [ - ":framework", - ":lib", - ], - alwayslink = 1, + actual = + "//tensorflow/core/common_runtime:testlib_ops", ) # This is a link-only library to provide a DirectSession @@ -1241,7 +1215,7 @@ tf_cuda_library( linkstatic = 1, visibility = ["//visibility:public"], deps = [ - ":direct_session_internal", + "//tensorflow/core/common_runtime:direct_session_internal", ], alwayslink = 1, ) @@ -1306,6 +1280,7 @@ filegroup( # Sources for which we do not yet have granular targets. "//tensorflow/c/eager:srcs", "//tensorflow/c:srcs", + "//tensorflow/core/common_runtime:mobile_srcs_only_runtime", "//tensorflow/core/common_runtime/eager:srcs", "//tensorflow/core/framework:mobile_srcs_only_runtime", "//tensorflow/core/graph:mobile_srcs_only_runtime", @@ -1329,8 +1304,6 @@ filegroup( "//tensorflow/core/platform:mobile_srcs_only_runtime", ] + glob( [ - "common_runtime/**/*.cc", - "common_runtime/**/*.h", "lib/wav/*.cc", "lib/wav/*.h", ], @@ -1339,8 +1312,6 @@ filegroup( "**/*testutil*", "**/*testlib*", "**/*main.cc", - "common_runtime/gpu/**/*", - "common_runtime/gpu_device_factory.*", ], ), visibility = ["//visibility:public"], @@ -1602,46 +1573,23 @@ cc_library( # Libraries with GPU facilities that are useful for writing kernels. cc_library( name = "gpu_lib", - srcs = [ - "common_runtime/gpu/gpu_event_mgr.cc", - ], - hdrs = [ - "common_runtime/gpu/gpu_event_mgr.h", - ], - copts = tf_copts(), visibility = ["//visibility:public"], deps = [ - ":framework", - ":framework_internal", - ":lib", - ":lib_internal", - ":protos_all_cc", - ":stream_executor", + "//tensorflow/core/common_runtime/gpu:gpu_lib", ], ) -cc_library( +alias( name = "gpu_headers_lib", - hdrs = [ - "common_runtime/gpu/gpu_event_mgr.h", - ], + actual = + "//tensorflow/core/common_runtime/gpu:gpu_headers_lib", visibility = ["//visibility:public"], ) -cc_library( +alias( name = "cuda", + actual = "//tensorflow/core/common_runtime/gpu:cuda", visibility = ["//visibility:public"], - deps = [ - "//tensorflow/core/platform/default/build_config:cuda", - ], -) - -cc_library( - name = "rocm", - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/core/platform/default/build_config:rocm", - ], ) # ----------------------------------------------------------------------------- @@ -2298,14 +2246,6 @@ cc_header_only_library( ], ) -cc_header_only_library( - name = "core_cpu_headers_lib", - visibility = ["//visibility:public"], - deps = [ - ":core_cpu_lib", - ], -) - tf_cuda_library( name = "framework_internal_impl", srcs = [ @@ -2456,20 +2396,7 @@ tf_cuda_library( filegroup( name = "core_cpu_base_headers", srcs = [ - "common_runtime/device.h", - "common_runtime/device_factory.h", - "common_runtime/device_mgr.h", - "common_runtime/device_set.h", - "common_runtime/eval_const_tensor.h", - "common_runtime/function.h", - "common_runtime/graph_runner.h", - "common_runtime/metrics.h", - "common_runtime/process_function_library_runtime.h", - "common_runtime/scoped_allocator.h", - "common_runtime/scoped_allocator_mgr.h", - "common_runtime/shape_refiner.h", - "//tensorflow/core/framework:versions.h", - "//tensorflow/core/graph:graph_headers", + "//tensorflow/core/common_runtime:core_cpu_base_headers", ], ) @@ -2480,7 +2407,7 @@ tf_cuda_library( "//tensorflow/core/public:session.h", ], copts = tf_copts(), - deps = [":core_cpu_base_no_ops"] + if_static([ + deps = ["//tensorflow/core/common_runtime:core_cpu_base_no_ops"] + if_static([ ":function_ops_op_lib", ":functional_grad", ":functional_ops_op_lib", @@ -2490,288 +2417,22 @@ tf_cuda_library( alwayslink = 1, ) -tf_cuda_library( - name = "core_cpu_base_no_ops", - srcs = [ - "common_runtime/eval_const_tensor.cc", - "common_runtime/graph_optimizer.h", - "common_runtime/scoped_allocator.cc", - "common_runtime/scoped_allocator_mgr.cc", - "common_runtime/shape_refiner.cc", - "//tensorflow/core/graph:core_cpu_base_no_ops_srcs", - "//tensorflow/core/public:session_options.h", - "//tensorflow/core/public:version.h", - ], - hdrs = [ - ":core_cpu_base_headers", - "//tensorflow/core/public:session.h", - ], - copts = tf_copts(), - deps = [ - ":graph", - ":framework", - ":framework_internal", - ":lib", - ":lib_internal", - ":protos_all_cc", - "@com_google_absl//absl/container:flat_hash_set", - "//third_party/eigen3", - ] + if_static([ - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/strings", - ]), -) - -filegroup( - name = "core_cpu_lib_headers", - srcs = [ - ":core_cpu_base_headers", - "common_runtime/allocator_retry.h", - "common_runtime/shared_counter.h", - "common_runtime/base_collective_executor.h", - "common_runtime/bfc_allocator.h", - "common_runtime/hierarchical_tree_broadcaster.h", - "common_runtime/buf_rendezvous.h", - "common_runtime/build_graph_options.h", - "common_runtime/collective_executor_mgr.h", - "common_runtime/collective_param_resolver_local.h", - "common_runtime/collective_rma_local.h", - "common_runtime/collective_util.h", - "common_runtime/colocation_graph.h", - "common_runtime/constant_folding.h", - "common_runtime/copy_tensor.h", - "common_runtime/costmodel_manager.h", - "common_runtime/placer_inspection_required_ops_utils.h", - "common_runtime/debugger_state_interface.h", - "common_runtime/device_resolver_local.h", - "common_runtime/dma_helper.h", - "common_runtime/entry.h", - "common_runtime/executor.h", - "common_runtime/executor_factory.h", - "common_runtime/function_optimization_registry.h", - "common_runtime/graph_optimizer.h", - "common_runtime/graph_view.h", - "common_runtime/immutable_executor_state.h", - "common_runtime/input_colocation_exemption_registry.h", - "common_runtime/inspecting_placer.h", - "common_runtime/isolate_placer_inspection_required_ops_pass.h", - "common_runtime/local_device.h", - "common_runtime/lower_function_call_op.h", - "common_runtime/lower_if_op.h", - "common_runtime/lower_case_op.h", - "common_runtime/lower_functional_ops.h", - "common_runtime/lower_while_op.h", - "common_runtime/memory_types.h", - "common_runtime/mkl_cpu_allocator.h", - "common_runtime/optimization_registry.h", - "common_runtime/pending_counts.h", - "common_runtime/partitioning_utils.h", - "common_runtime/placer.h", - "common_runtime/process_util.h", - "common_runtime/propagator_state.h", - "common_runtime/profile_handler.h", - "common_runtime/renamed_device.h", - "common_runtime/rendezvous_mgr.h", - "common_runtime/rendezvous_util.h", - "common_runtime/replicate_per_replica_nodes.h", - "common_runtime/ring_reducer.h", - "common_runtime/ring_alg.h", - "common_runtime/ring_gatherer.h", - "common_runtime/session_factory.h", - "common_runtime/single_threaded_cpu_device.h", - "common_runtime/stats_publisher_interface.h", - "common_runtime/step_stats_collector.h", - "common_runtime/threadpool_device.h", - "common_runtime/process_state.h", - "common_runtime/pool_allocator.h", - "//tensorflow/core/graph:core_cpu_lib_headers", - ] + if_mkl(["//tensorflow/core/graph:mkl_graph_util_header"]), -) - -tf_cuda_library( +alias( name = "core_cpu_impl", - srcs = [ - "common_runtime/accumulate_n_optimizer.cc", - "common_runtime/base_collective_executor.cc", - "common_runtime/buf_rendezvous.cc", - "common_runtime/build_graph_options.cc", - "common_runtime/collective_executor_mgr.cc", - "common_runtime/collective_param_resolver_local.cc", - "common_runtime/collective_rma_local.cc", - "common_runtime/collective_util.cc", - "common_runtime/colocation_graph.cc", - "common_runtime/constant_folding.cc", - "common_runtime/copy_tensor.cc", - "common_runtime/costmodel_manager.cc", - "common_runtime/debugger_state_interface.cc", - "common_runtime/device.cc", - "common_runtime/device_factory.cc", - "common_runtime/device_mgr.cc", - "common_runtime/device_resolver_local.cc", - "common_runtime/device_set.cc", - "common_runtime/dynamic_device_mgr.cc", - "common_runtime/executor.cc", - "common_runtime/executor_factory.cc", - "common_runtime/function.cc", - "common_runtime/function_optimization_registry.cc", - "common_runtime/graph_optimizer.cc", - "common_runtime/graph_runner.cc", - "common_runtime/graph_view.cc", - "common_runtime/hierarchical_tree_broadcaster.cc", - "common_runtime/immutable_executor_state.cc", - "common_runtime/input_colocation_exemption_registry.cc", - "common_runtime/inspecting_placer.cc", - "common_runtime/isolate_placer_inspection_required_ops_pass.cc", - "common_runtime/local_device.cc", - "common_runtime/lower_case_op.cc", - "common_runtime/lower_function_call_op.cc", - "common_runtime/lower_functional_ops.cc", - "common_runtime/lower_if_op.cc", - "common_runtime/lower_while_op.cc", - "common_runtime/memory_types.cc", - "common_runtime/metrics.cc", - "common_runtime/mkl_cpu_allocator.cc", - "common_runtime/optimization_registry.cc", - "common_runtime/parallel_concat_optimizer.cc", - "common_runtime/partitioning_utils.cc", - "common_runtime/placer.cc", - "common_runtime/placer_inspection_required_ops_utils.cc", - "common_runtime/placer_inspection_required_ops_utils.h", - "common_runtime/pool_allocator.cc", - "common_runtime/process_function_library_runtime.cc", - "common_runtime/process_state.cc", - "common_runtime/process_util.cc", - "common_runtime/propagator_state.cc", - "common_runtime/renamed_device.cc", - "common_runtime/rendezvous_mgr.cc", - "common_runtime/rendezvous_util.cc", - "common_runtime/replicate_per_replica_nodes.cc", - "common_runtime/ring_alg.cc", - "common_runtime/ring_gatherer.cc", - "common_runtime/ring_reducer.cc", - "common_runtime/session.cc", - "common_runtime/session_factory.cc", - "common_runtime/session_options.cc", - "common_runtime/session_state.cc", - "common_runtime/single_threaded_cpu_device.cc", - "common_runtime/stats_publisher_interface.cc", - "common_runtime/step_stats_collector.cc", - "common_runtime/threadpool_device.cc", - "common_runtime/threadpool_device_factory.cc", - "//tensorflow/core/graph:core_cpu_impl_srcs", - "//tensorflow/core/public:session.h", - "//tensorflow/core/public:session_options.h", - ], - hdrs = [":core_cpu_lib_headers"], - copts = tf_copts() + tf_openmp_copts(), - deps = [ - ":bfc_allocator", - ":graph", - ":framework", - ":framework_internal", - ":lib", - ":lib_internal", - ":protos_all_cc", - "@com_google_absl//absl/base", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", - "@com_google_absl//absl/types:variant", - "//third_party/eigen3", - "//tensorflow/core/public:version", - "//tensorflow/core/grappler/utils:functions", - "//tensorflow/core/profiler/lib:annotated_traceme", - "//tensorflow/core/profiler/lib:scoped_annotation", - "//tensorflow/core/profiler/lib:traceme", - ] + mkl_deps(), - alwayslink = 1, + actual = + "//tensorflow/core/common_runtime:core_cpu_impl", ) -tf_cuda_library( +alias( name = "core_cpu_lib", - hdrs = [":core_cpu_lib_headers"], - deps = [ - ":core_cpu_base", - "//tensorflow/core/grappler:grappler_item", - ] + if_static([":core_cpu_impl"]) + tf_protos_all() + tf_protos_grappler(), + actual = + "//tensorflow/core/common_runtime:core_cpu_lib", ) -tf_cuda_library( - name = "core_cpu_lib_no_ops", - hdrs = [":core_cpu_lib_headers"], - deps = [ - ":core_cpu_base_no_ops", - "//tensorflow/core/grappler:grappler_item", - ] + tf_protos_all() + tf_protos_grappler(), -) - -tf_cuda_library( +alias( name = "core_cpu_internal", - srcs = [ - "common_runtime/graph_execution_state.cc", - ], - hdrs = [ - "common_runtime/graph_execution_state.h", - ":core_cpu_lib_headers", - ], - copts = tf_copts(), - deps = [ - ":framework", - ":graph", - ":lib", - ":protos_all_cc", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - "//tensorflow/core/grappler:grappler_item", - "//tensorflow/core/grappler/clusters:utils", - "//tensorflow/core/grappler/clusters:virtual_cluster", - "//tensorflow/core/grappler/optimizers:meta_optimizer", - "//third_party/eigen3", - ] + mkl_deps() + tf_additional_core_deps() + if_static([ - ":core_cpu_impl", - ":function_ops_op_lib", - ":functional_grad", - ":functional_ops_op_lib", - "//tensorflow/core/kernels:required", - ]), - alwayslink = 1, -) - -# This is redundant with the "core_cpu_*" targets above. It's useful for -# applications that want to depend on a minimal subset of TensorFlow (e.g. XLA). -cc_library( - name = "bfc_allocator", - srcs = [ - "common_runtime/allocator_retry.cc", - "common_runtime/allocator_retry.h", - "common_runtime/bfc_allocator.cc", - ], - hdrs = ["common_runtime/bfc_allocator.h"], - features = ["parse_headers"], - visibility = ["//visibility:public"], - deps = [ - ":lib", - ":lib_internal", - ":protos_all_cc", - ":shared_counter", - "//tensorflow/core/framework:allocator", - "//tensorflow/core/profiler/lib:traceme", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "shared_counter", - hdrs = ["common_runtime/shared_counter.h"], - features = ["parse_headers"], - visibility = ["//visibility:public"], - deps = [ - ":lib", - ], + actual = + "//tensorflow/core/common_runtime:core_cpu_internal", ) alias( @@ -2786,31 +2447,10 @@ alias( ], ) -tf_cuda_library( +alias( name = "direct_session_internal", - srcs = ["common_runtime/direct_session.cc"], - hdrs = [ - "common_runtime/direct_session.h", - "//tensorflow/core/util:lib_internal_public_hdrs", - ], - copts = tf_copts(), - deps = [ - ":core_cpu_internal", - ":framework", - ":framework_internal", - ":graph", - ":lib", - ":lib_experimental", - ":lib_internal", - ":protos_all_cc", - "//tensorflow/core/debug:debug_graph_utils", - "//tensorflow/core/kernels:function_ops", - "//tensorflow/core/profiler/lib:profiler_backends", - "//tensorflow/core/profiler/lib:profiler_session", - "//tensorflow/core/profiler/lib:traceme", - "@com_google_absl//absl/container:flat_hash_set", - ], - alwayslink = 1, + actual = + "//tensorflow/core/common_runtime:direct_session_internal", ) alias( @@ -2831,209 +2471,10 @@ tf_proto_library_cc( ], ) -cc_library( - name = "gpu_id", - hdrs = [ - "common_runtime/gpu/gpu_id.h", - "common_runtime/gpu/gpu_id_manager.h", - ], - deps = [ - ":lib", - ] + if_static([ - ":gpu_id_impl", - ]), -) - -cc_library( - name = "gpu_id_impl", - srcs = ["common_runtime/gpu/gpu_id_manager.cc"], - hdrs = [ - "common_runtime/gpu/gpu_id.h", - "common_runtime/gpu/gpu_id_manager.h", - ], - deps = [ - ":lib", - ], -) - -filegroup( - name = "gpu_runtime_headers", - srcs = [ - "common_runtime/gpu/gpu_bfc_allocator.h", - "common_runtime/gpu/gpu_cudamalloc_allocator.h", - "common_runtime/gpu/gpu_debug_allocator.h", - "common_runtime/gpu/gpu_device.h", - "common_runtime/gpu/gpu_host_allocator.h", - "common_runtime/gpu/gpu_id.h", - "common_runtime/gpu/gpu_id_manager.h", - "common_runtime/gpu/gpu_id_utils.h", - "common_runtime/gpu/gpu_init.h", - "common_runtime/gpu/gpu_managed_allocator.h", - "common_runtime/gpu/gpu_mem_allocator.h", - "common_runtime/gpu/gpu_process_state.h", - "common_runtime/gpu/gpu_stream_util.h", - "common_runtime/gpu/gpu_util.h", - "common_runtime/gpu_device_context.h", - ], - visibility = ["//visibility:private"], -) - -tf_cuda_library( - name = "gpu_runtime_impl", - srcs = [ - "common_runtime/gpu/gpu_cudamalloc_allocator.cc", - "common_runtime/gpu/gpu_debug_allocator.cc", - "common_runtime/gpu/gpu_device.cc", - "common_runtime/gpu/gpu_device_factory.cc", - "common_runtime/gpu/gpu_managed_allocator.cc", - "common_runtime/gpu/gpu_process_state.cc", - "common_runtime/gpu/gpu_stream_util.cc", - "common_runtime/gpu/gpu_util.cc", - "common_runtime/gpu/gpu_util_platform_specific.cc", - ], - hdrs = [":gpu_runtime_headers"], - copts = tf_copts(), - cuda_deps = [ - "@local_config_cuda//cuda:cudnn_header", - ], - deps = [ - ":core_cpu_impl", - ":core_cpu_lib", - ":framework", - ":framework_internal", - ":gpu_bfc_allocator", - ":gpu_id_impl", - ":gpu_init_impl", - ":gpu_lib", - ":graph", - ":lib", - ":lib_internal", - ":protos_all_cc", - ":stream_executor", - "//tensorflow/core/profiler/lib:annotated_traceme", - "//tensorflow/core/profiler/lib:scoped_annotation", - "//third_party/eigen3", - ], - alwayslink = 1, -) - -tf_cuda_library( +alias( name = "gpu_runtime", - hdrs = [":gpu_runtime_headers"], - linkstatic = 1, - deps = [ - ":core_cpu_lib", - ":framework", - ":framework_internal", - ":lib", - ":lib_internal", - ":protos_all_cc", - ":stream_executor", - "//third_party/eigen3", - ] + if_static([":gpu_runtime_impl"]), -) - -# This is redundant with the "gpu_runtime_*" targets above. It's useful for -# applications that want to depend on a minimal subset of TensorFlow (e.g. XLA). -tf_cuda_library( - name = "gpu_bfc_allocator", - srcs = [ - "common_runtime/gpu/gpu_bfc_allocator.cc", - ], - hdrs = ["common_runtime/gpu/gpu_bfc_allocator.h"], - features = ["parse_headers"], - visibility = ["//visibility:public"], - deps = [ - ":bfc_allocator", - ":gpu_mem_allocator", - ":lib", - ":lib_internal", - ":protos_all_cc", - ], -) - -tf_cuda_library( - name = "gpu_mem_allocator", - srcs = [ - "common_runtime/gpu/gpu_id.h", - ], - hdrs = [ - "common_runtime/gpu/gpu_host_allocator.h", - "common_runtime/gpu/gpu_mem_allocator.h", - ], - features = ["parse_headers"], - visibility = ["//visibility:public"], - deps = [ - ":lib", - ":lib_internal", - ":stream_executor", - "//tensorflow/core/framework:allocator", - ], -) - -tf_cuda_library( - name = "gpu_init", - hdrs = [ - "common_runtime/gpu/gpu_init.h", - ], - deps = [ - ":framework", - ":framework_internal", - ":lib", - ":lib_internal", - ":stream_executor", - ] + if_static( - [":gpu_init_impl"], - ), -) - -tf_cuda_library( - name = "gpu_init_impl", - srcs = [ - "common_runtime/gpu/gpu_init.cc", - ], - hdrs = [ - "common_runtime/gpu/gpu_init.h", - ], - copts = tf_copts(), - linkstatic = 1, - deps = [ - ":framework", - ":framework_internal", - ":lib", - ":lib_internal", - ":stream_executor", - ], - alwayslink = 1, -) - -cc_library( - name = "sycl_runtime", - srcs = if_not_windows([ - "common_runtime/sycl/sycl_allocator.cc", - "common_runtime/sycl/sycl_device.cc", - "common_runtime/sycl/sycl_device_context.cc", - "common_runtime/sycl/sycl_device_factory.cc", - ]), - hdrs = if_not_windows([ - "common_runtime/sycl/sycl_allocator.h", - "common_runtime/sycl/sycl_device.h", - "common_runtime/sycl/sycl_util.h", - "common_runtime/sycl/sycl_device_context.h", - ]), - copts = tf_copts(), - linkstatic = 0, - deps = [ - ":core_cpu", - ":core_cpu_internal", - ":framework", - ":framework_internal", - ":lib", - ":lib_internal", - "//third_party/eigen3", - "@local_config_sycl//sycl", - ], - alwayslink = 0, + actual = + "//tensorflow/core/common_runtime/gpu:gpu_runtime", ) # ----------------------------------------------------------------------------- @@ -3226,70 +2667,10 @@ test_suite( ], ) -tf_cc_test( - name = "common_runtime_placer_test", - size = "small", - srcs = [ - "common_runtime/placer_test.cc", - ], - linkopts = select({ - "//tensorflow:macos": ["-headerpad_max_install_names"], - "//conditions:default": [], - }), - linkstatic = tf_kernel_tests_linkstatic(), - tags = ["no_windows"], - deps = [ - ":core", - ":core_cpu", - ":core_cpu_internal", - ":direct_session_internal", - ":framework", - ":framework_internal", - ":lib", - ":lib_internal", - ":ops", - ":protos_all_cc", - ":test", - ":test_main", - ":testlib", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:cc_ops_internal", - "//tensorflow/cc:function_ops", - "//tensorflow/cc:ops", - "//tensorflow/cc:scope", - "//tensorflow/cc:sendrecv_ops", - "//tensorflow/cc:while_loop", - "//tensorflow/core/kernels:ops_util", - "//tensorflow/core/platform:regexp", - "//tensorflow/core/util:protos_test_cc", - "//third_party/eigen3", - "@com_google_absl//absl/base", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - ], -) - tf_cc_tests( name = "core_higher_level_tests", size = "small", srcs = [ - "common_runtime/buf_rendezvous_test.cc", - "common_runtime/collective_executor_mgr_test.cc", - "common_runtime/collective_rma_local_test.cc", - "common_runtime/device_mgr_test.cc", - "common_runtime/device_resolver_local_test.cc", - "common_runtime/device_set_test.cc", - "common_runtime/dynamic_device_mgr_test.cc", - "common_runtime/function_optimization_registration_test.cc", - "common_runtime/function_optimization_registry_no_pass_test.cc", - "common_runtime/function_optimization_registry_pass_failure_test.cc", - "common_runtime/function_optimization_registry_test.cc", - "common_runtime/isolate_placer_inspection_required_ops_pass_test.cc", - "common_runtime/optimization_registry_test.cc", - "common_runtime/pending_counts_test.cc", - "common_runtime/placer_inspection_required_ops_utils_test.cc", - "common_runtime/session_test.cc", - "common_runtime/threadpool_device_test.cc", "//tensorflow/core/example:feature_util_test.cc", "//tensorflow/core/graph:algorithm_test.cc", "//tensorflow/core/graph:control_flow_test.cc", @@ -3345,7 +2726,6 @@ tf_cc_tests( name = "higher_level_tests_needing_kernels", size = "small", srcs = [ - "common_runtime/collective_param_resolver_local_test.cc", "//tensorflow/core/graph:higher_level_tests_needing_kernels", ], linkopts = select({ @@ -3394,114 +2774,6 @@ tf_cc_test( ], ) -tf_cc_tests_gpu( - name = "ring_reducer_test", - size = "medium", - srcs = [ - "common_runtime/ring_reducer_test.cc", - ], - linkstatic = tf_kernel_tests_linkstatic(), - tags = ["no_cuda_on_cpu_tap"], - deps = [ - ":all_kernels", - ":core", - ":core_cpu", - ":core_cpu_internal", - ":direct_session_internal", - ":framework", - ":framework_internal", - ":gpu_runtime", - ":lib", - ":lib_internal", - ":ops", - ":protos_all_cc", - ":test", - ":test_main", - ":testlib", - "//tensorflow/core/util:protos_test_cc", - "@com_google_absl//absl/memory", - ], -) - -tf_cc_tests_gpu( - name = "ring_gatherer_test", - size = "medium", - srcs = [ - "common_runtime/ring_gatherer_test.cc", - ], - linkstatic = tf_kernel_tests_linkstatic(), - tags = ["no_cuda_on_cpu_tap"], - deps = [ - ":all_kernels", - ":core", - ":core_cpu", - ":core_cpu_internal", - ":direct_session_internal", - ":framework", - ":framework_internal", - ":gpu_runtime", - ":lib", - ":lib_internal", - ":ops", - ":protos_all_cc", - ":test", - ":test_main", - ":testlib", - "//tensorflow/core/util:protos_test_cc", - "@com_google_absl//absl/memory", - ], -) - -tf_cc_tests_gpu( - name = "hierarchical_tree_broadcaster_test", - size = "medium", - srcs = [ - "common_runtime/hierarchical_tree_broadcaster_test.cc", - ], - linkstatic = tf_kernel_tests_linkstatic(), - tags = ["no_cuda_on_cpu_tap"], - deps = [ - ":all_kernels", - ":core", - ":core_cpu", - ":core_cpu_internal", - ":direct_session_internal", - ":framework", - ":framework_internal", - ":gpu_runtime", - ":lib", - ":lib_internal", - ":ops", - ":protos_all_cc", - ":test", - ":test_main", - ":testlib", - "//tensorflow/core/util:protos_test_cc", - "@com_google_absl//absl/memory", - ], -) - -tf_cc_test_mkl( - name = "mkl_runtime_tests", - size = "small", - srcs = [ - "common_runtime/mkl_cpu_allocator_test.cc", - "common_runtime/mkl_threadpool_device_test.cc", - ], - linkstatic = 1, - deps = [ - ":core", - ":core_cpu", - ":core_cpu_internal", - ":framework", - ":framework_internal", - ":lib", - ":test", - ":test_main", - ":testlib", - ], -) - tf_cc_test_mkl( name = "mkl_related_tests", size = "small", @@ -3554,28 +2826,10 @@ tf_cc_test_mkl( ]), ) -tf_cc_tests_gpu( - name = "gpu_device_on_non_gpu_machine_test", - size = "small", - srcs = ["common_runtime/gpu/gpu_device_on_non_gpu_machine_test.cc"], - linkstatic = tf_kernel_tests_linkstatic(), - deps = [ - ":gpu_headers_lib", - ":gpu_id", - ":gpu_runtime", - ":test", - ], -) - tf_cc_tests_gpu( name = "gpu_related_tests", size = "small", - srcs = glob(["user_ops/**/*_test.cc"]) + [ - "common_runtime/gpu/gpu_bfc_allocator_test.cc", - "common_runtime/gpu/gpu_device_test.cc", - "common_runtime/gpu/gpu_id_manager_test.cc", - "common_runtime/gpu/pool_allocator_test.cc", - ], + srcs = glob(["user_ops/**/*_test.cc"]), linkstatic = tf_kernel_tests_linkstatic(), tags = tf_cuda_tests_tags(), deps = [ @@ -3597,81 +2851,6 @@ tf_cc_tests_gpu( ], ) -tf_cc_test_gpu( - name = "gpu_event_mgr_test", - srcs = ["common_runtime/gpu/gpu_event_mgr_test.cc"], - linkstatic = tf_kernel_tests_linkstatic(), - tags = tf_cuda_tests_tags(), - deps = [ - ":framework", - ":framework_internal", - ":lib", - ":lib_internal", - ":protos_all_cc", - ":test", - ":test_main", - ":testlib", - "//tensorflow/core/kernels:cwise_op", - ], -) - -tf_cuda_cc_test( - name = "gpu_device_unified_memory_test", - size = "small", - srcs = [ - "common_runtime/gpu/gpu_device_test.cc", - ], - linkstatic = tf_kernel_tests_linkstatic(), - # Runs test on a Guitar cluster that uses P100s to test unified memory - # allocations. - tags = tf_cuda_tests_tags() + [ - "guitar", - "multi_gpu", - ], - deps = [ - ":core_cpu", - ":core_cpu_internal", - ":direct_session", - ":framework", - ":framework_internal", - ":gpu_id", - ":lib", - ":lib_internal", - ":protos_all_cc", - ":test", - ":test_main", - ":testlib", - "//tensorflow/cc:cc_ops", - "//tensorflow/core/kernels:ops_util", - ], -) - -tf_cc_test_gpu( - name = "memory_types_test", - size = "small", - srcs = ["common_runtime/memory_types_test.cc"], - linkstatic = tf_kernel_tests_linkstatic(), - tags = tf_cuda_tests_tags(), - deps = [ - ":core", - ":core_cpu", - ":core_cpu_internal", - ":framework", - ":framework_internal", - ":gpu_runtime", - ":lib", - ":lib_internal", - ":ops", - ":protos_all_cc", - ":test", - ":test_main", - ":testlib", - "//tensorflow/cc:cc_ops", - "//tensorflow/core/kernels:cast_op", - "//third_party/eigen3", - ], -) - tf_cc_test_gpu( name = "variant_op_copy_test", size = "small", @@ -3701,140 +2880,6 @@ tf_cc_test_gpu( ], ) -tf_cc_test( - name = "common_runtime_constant_folding_test", - size = "small", - srcs = ["common_runtime/constant_folding_test.cc"], - linkstatic = tf_kernel_tests_linkstatic(), - tags = tf_cuda_tests_tags(), - deps = [ - ":core", - ":core_cpu", - ":core_cpu_internal", - ":direct_session_internal", - ":framework", - ":framework_internal", - ":gpu_runtime", - ":lib", - ":lib_internal", - ":ops", - ":protos_all_cc", - ":test", - ":test_main", - ":testlib", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:cc_ops_internal", - "//tensorflow/cc:sendrecv_ops", - "//tensorflow/core/kernels:bcast_ops", - "//tensorflow/core/kernels:cast_op", - "//tensorflow/core/kernels:concat_op", - "//tensorflow/core/kernels:cwise_op", - "//tensorflow/core/kernels:identity_op", - "//tensorflow/core/kernels:immutable_constant_op", - "//tensorflow/core/kernels:matmul_op", - "//tensorflow/core/kernels:topk_op", - "//third_party/eigen3", - ], -) - -tf_cc_test( - name = "common_runtime_shape_refiner_test", - size = "small", - srcs = [ - "common_runtime/shape_refiner_test.cc", - ], - linkstatic = tf_kernel_tests_linkstatic(), - deps = [ - ":core", - ":core_cpu", - ":core_cpu_internal", - ":framework", - ":framework_internal", - ":lib", - ":lib_internal", - ":ops", - ":protos_all_cc", - ":test", - ":test_main", - ":testlib", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:resource_variable_ops", - "//tensorflow/cc:scope", - "//tensorflow/core/kernels:array", - "//tensorflow/core/kernels:math", - "//tensorflow/core/kernels:resource_variable_ops", - "//third_party/eigen3", - ], -) - -tf_cuda_cc_test( - name = "common_runtime_process_function_library_runtime_test", - size = "small", - srcs = ["common_runtime/process_function_library_runtime_test.cc"], - linkstatic = tf_kernel_tests_linkstatic(), - tags = ["no_rocm"], - deps = [ - ":core_cpu", - ":core_cpu_internal", - ":framework", - ":framework_internal", - ":lib", - ":protos_all_cc", - ":test", - ":test_main", - ":testlib", - "//tensorflow/cc:function_ops", - "//tensorflow/core/kernels:cast_op", - "//tensorflow/core/kernels:cwise_op", - "//tensorflow/core/kernels:function_ops", - "//tensorflow/core/kernels:resource_variable_ops", - ], -) - -tf_cc_test( - name = "common_runtime_process_util_test", - size = "small", - srcs = ["common_runtime/process_util_test.cc"], - linkstatic = tf_kernel_tests_linkstatic(), - deps = [ - ":core_cpu_internal", - ":test", - ":test_main", - ], -) - -tf_cc_test( - name = "common_runtime_rendezvous_util_test", - size = "small", - srcs = ["common_runtime/rendezvous_util_test.cc"], - linkstatic = tf_kernel_tests_linkstatic(), - deps = [ - ":core_cpu_internal", - ":lib", - ":test", - ":test_main", - ], -) - -tf_cc_test( - name = "common_runtime_replicate_per_replica_nodes_test", - size = "small", - srcs = ["common_runtime/replicate_per_replica_nodes_test.cc"], - linkstatic = tf_kernel_tests_linkstatic(), - deps = [ - ":core_cpu_internal", - ":framework", - ":test", - ":test_main", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:cc_ops_internal", - "//tensorflow/cc:function_ops", - "//tensorflow/cc:ops", - "//tensorflow/cc:resource_variable_ops", - "@com_google_absl//absl/strings", - ], -) - tf_cc_test( name = "framework_run_handler_util_test", size = "small", @@ -3872,362 +2917,6 @@ tf_cc_test( ], ) -tf_cc_test( - name = "common_runtime_partitioning_utils_test", - size = "small", - srcs = ["common_runtime/partitioning_utils_test.cc"], - deps = [ - ":core_cpu", - ":core_cpu_internal", - ":framework", - ":lib", - ":ops", - ":test", - ":test_main", - ":testlib", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:cc_ops_internal", - "//tensorflow/cc:function_ops", - "//tensorflow/core/kernels:function_ops", - "//tensorflow/core/kernels:identity_op", - ], -) - -tf_cuda_cc_test( - name = "common_runtime_direct_session_test", - size = "small", - srcs = ["common_runtime/direct_session_test.cc"], - args = [] + if_cuda(["--heap_check=local"]), # The GPU tracer leaks memory - linkstatic = tf_kernel_tests_linkstatic(), - deps = [ - ":core_cpu", - ":core_cpu_internal", - ":direct_session_internal", - ":framework", - ":framework_internal", - ":lib", - ":lib_internal", - ":ops", - ":protos_all_cc", - ":test", - ":test_main", - ":testlib", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - "//third_party/eigen3", - "//tensorflow/cc:cc_ops", - "//tensorflow/core/kernels:collective_ops", - "//tensorflow/core/kernels:control_flow_ops", - "//tensorflow/core/kernels:cwise_op", - "//tensorflow/core/kernels:dense_update_ops", - "//tensorflow/core/kernels:fifo_queue_op", - "//tensorflow/core/kernels:function_ops", - "//tensorflow/core/kernels:identity_n_op", - "//tensorflow/core/kernels:identity_op", - "//tensorflow/core/kernels:matmul_op", - "//tensorflow/core/kernels:ops_util", - "//tensorflow/core/kernels:queue_ops", - "//tensorflow/core/kernels:session_ops", - "//tensorflow/core/kernels:variable_ops", - "//tensorflow/core/kernels/data:single_threaded_executor", - ] + if_cuda([":cuda"]), -) - -# This is identical to :common_runtime_direct_session_test with the addition of -# a dependency on alwayslink target //third_party/tensorflow/core/debug, which -# enables support for TensorFlow Debugger (tfdbg). -tf_cc_test( - name = "common_runtime_direct_session_with_debug_test", - size = "small", - srcs = ["common_runtime/direct_session_test.cc"], - linkstatic = tf_kernel_tests_linkstatic(), - deps = [ - ":core", - ":core_cpu", - ":core_cpu_internal", - ":direct_session_internal", - ":framework", - ":framework_internal", - ":lib", - ":lib_internal", - ":ops", - ":protos_all_cc", - ":test", - ":test_main", - ":testlib", - "@com_google_absl//absl/strings", - "//third_party/eigen3", - "@com_google_absl//absl/memory", - "//tensorflow/cc:cc_ops", - # Link with support for TensorFlow Debugger (tfdbg). - "//tensorflow/core/debug", - "//tensorflow/core/kernels:collective_ops", - "//tensorflow/core/kernels:control_flow_ops", - "//tensorflow/core/kernels:cwise_op", - "//tensorflow/core/kernels:dense_update_ops", - "//tensorflow/core/kernels:fifo_queue_op", - "//tensorflow/core/kernels:function_ops", - "//tensorflow/core/kernels:identity_op", - "//tensorflow/core/kernels:identity_n_op", - "//tensorflow/core/kernels:matmul_op", - "//tensorflow/core/kernels:ops_util", - "//tensorflow/core/kernels:queue_ops", - "//tensorflow/core/kernels:session_ops", - "//tensorflow/core/kernels:variable_ops", - ], -) - -tf_cc_test( - name = "common_runtime_direct_session_with_tracking_alloc_test", - size = "small", - srcs = ["common_runtime/direct_session_with_tracking_alloc_test.cc"], - args = ["--heap_check=local"], # The GPU tracer leaks memory - linkstatic = tf_kernel_tests_linkstatic(), - tags = ["no_gpu"], - deps = [ - ":core", - ":core_cpu", - ":core_cpu_internal", - ":direct_session_internal", - ":framework", - ":framework_internal", - ":lib", - ":lib_internal", - ":ops", - ":protos_all_cc", - ":test", - ":test_main", - ":testlib", - "//tensorflow/cc:cc_ops", - "//tensorflow/core/kernels:cwise_op", - "//tensorflow/core/kernels:dense_update_ops", - "//tensorflow/core/kernels:fifo_queue_op", - "//tensorflow/core/kernels:identity_op", - "//tensorflow/core/kernels:matmul_op", - "//tensorflow/core/kernels:ops_util", - "//tensorflow/core/kernels:queue_ops", - "//tensorflow/core/kernels:variable_ops", - "//third_party/eigen3", - ], -) - -tf_cc_test( - name = "common_runtime_graph_runner_test", - size = "small", - srcs = ["common_runtime/graph_runner_test.cc"], - linkstatic = tf_kernel_tests_linkstatic(), - deps = [ - ":array_ops_op_lib", - ":core", - ":core_cpu", - ":core_cpu_internal", - ":direct_session_internal", - ":framework", - ":framework_internal", - ":lib", - ":lib_internal", - ":ops", - ":protos_all_cc", - ":test", - ":test_main", - ":testlib", - "//third_party/eigen3", - "//tensorflow/c/kernels:bitcast_op_lib", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:scope", - "//tensorflow/core/kernels:cwise_op", - ] + if_mkl([":mkl_array_ops_op_lib"]), -) - -tf_cc_test( - name = "common_runtime_executor_test", - size = "small", - srcs = ["common_runtime/executor_test.cc"], - linkstatic = tf_kernel_tests_linkstatic(), - deps = [ - ":core", - ":core_cpu", - ":core_cpu_internal", - ":framework", - ":framework_internal", - ":lib", - ":lib_internal", - ":protos_all_cc", - ":test", - ":test_main", - ":testlib", - "//tensorflow/core/kernels:array", - "//tensorflow/core/kernels:control_flow_ops", - "//tensorflow/core/kernels:math", - "//tensorflow/core/kernels:random_ops", - "//tensorflow/core/kernels:state", - ], -) - -tf_cc_test( - name = "common_runtime_function_test", - size = "small", - srcs = ["common_runtime/function_test.cc"], - linkstatic = tf_kernel_tests_linkstatic(), - tags = [ - "manual", - "no_oss", - ], - deps = [ - ":core", - ":core_cpu", - ":core_cpu_internal", - ":direct_session_internal", - ":framework", - ":framework_internal", - ":lib", - ":lib_internal", - ":ops", - ":protos_all_cc", - ":test", - ":test_main", - ":testlib", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:cc_ops_internal", - "//tensorflow/cc:function_ops", - "//tensorflow/cc:functional_ops", - "//tensorflow/cc:sendrecv_ops", - "//tensorflow/core/kernels:cast_op", - "//tensorflow/core/kernels:cwise_op", - "//tensorflow/core/kernels:function_ops", - "//tensorflow/core/kernels:matmul_op", - "//tensorflow/core/kernels:partitioned_function_ops", - "//tensorflow/core/kernels:random_ops", - "//tensorflow/core/kernels:shape_ops", - "//third_party/eigen3", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - ], -) - -tf_cc_test( - name = "common_runtime_function_threadpool_test", - size = "small", - srcs = ["common_runtime/function_threadpool_test.cc"], - linkstatic = tf_kernel_tests_linkstatic(), - deps = [ - ":core", - ":core_cpu", - ":core_cpu_internal", - ":direct_session_internal", - ":framework", - ":framework_internal", - ":lib", - ":lib_internal", - ":ops", - ":protos_all_cc", - ":test", - ":test_main", - ":testlib", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:cc_ops_internal", - "//tensorflow/cc:function_ops", - "//tensorflow/cc:functional_ops", - "//tensorflow/core/kernels:cast_op", - "//tensorflow/core/kernels:cwise_op", - "//tensorflow/core/kernels:function_ops", - "//tensorflow/core/kernels:matmul_op", - "//tensorflow/core/kernels:random_ops", - "//tensorflow/core/kernels:shape_ops", - "//third_party/eigen3", - ], -) - -tf_cc_test( - name = "common_runtime_scoped_allocator_mgr_test", - size = "small", - srcs = ["common_runtime/scoped_allocator_mgr_test.cc"], - linkstatic = tf_kernel_tests_linkstatic(), - deps = [ - ":core_cpu", - ":core_cpu_internal", - ":framework", - ":lib", - ":test", - ":test_main", - ], -) - -tf_cc_test_gpu( - name = "gpu_allocator_retry_test", - size = "medium", - srcs = ["common_runtime/gpu/gpu_allocator_retry_test.cc"], - linkstatic = tf_kernel_tests_linkstatic(), - tags = tf_cuda_tests_tags(), - deps = [ - ":core_cpu", - ":core_cpu_internal", - ":direct_session", - ":framework", - ":framework_internal", - ":gpu_runtime", - ":lib", - ":lib_internal", - ":protos_all_cc", - ":test", - ":test_main", - ":testlib", - "//tensorflow/cc:cc_ops", - ], -) - -tf_cc_test_gpu( - name = "gpu_debug_allocator_test", - size = "medium", - srcs = ["common_runtime/gpu/gpu_debug_allocator_test.cc"], - args = ["--gtest_death_test_style=threadsafe"], - linkstatic = tf_kernel_tests_linkstatic(), - tags = tf_cuda_tests_tags(), - deps = [ - ":core_cpu", - ":core_cpu_internal", - ":direct_session", - ":framework", - ":framework_internal", - ":gpu_id", - ":gpu_runtime", - ":lib", - ":lib_internal", - ":protos_all_cc", - ":test", - ":test_main", - ":testlib", - "//tensorflow/cc:cc_ops", - "//tensorflow/core/kernels:ops_util", - ], -) - -tf_cc_test_gpu( - name = "gpu_stream_util_test", - size = "small", - srcs = ["common_runtime/gpu/gpu_stream_util_test.cc"], - linkstatic = tf_kernel_tests_linkstatic(), - tags = tf_cuda_tests_tags() + ["nomac"], - deps = [ - ":core_cpu", - ":core_cpu_internal", - ":direct_session", - ":framework", - ":framework_internal", - ":gpu_runtime", - ":lib", - ":lib_internal", - ":protos_all_cc", - ":test", - ":test_main", - ":testlib", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:sendrecv_ops", - "//tensorflow/core/kernels:matmul_op", - "//tensorflow/core/kernels:ops_util", - ], -) - tf_cc_test( name = "framework_op_segment_test", size = "small", @@ -4379,138 +3068,6 @@ tf_cc_test( ], ) -tf_cc_test( - name = "common_runtime_input_colocation_exemption_registry_test", - size = "small", - srcs = ["common_runtime/input_colocation_exemption_registry_test.cc"], - deps = [ - ":core_cpu", - ":core_cpu_internal", - ":test", - ":test_main", - ":testlib", - ], -) - -tf_cc_test( - name = "common_runtime_lower_function_call_test", - size = "small", - srcs = ["common_runtime/lower_function_call_op_test.cc"], - deps = [ - ":all_kernels", - ":core_cpu", - ":core_cpu_internal", - ":direct_session", - ":framework", - ":framework_internal", - ":lib", - ":test", - ":test_main", - ":testlib", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:cc_ops_internal", - "//tensorflow/cc:client_session", - "//tensorflow/cc:function_ops", - "//tensorflow/cc:ops", - "//tensorflow/cc:resource_variable_ops", - ], -) - -tf_cc_test( - name = "common_runtime_lower_if_op_test", - size = "small", - srcs = ["common_runtime/lower_if_op_test.cc"], - deps = [ - ":all_kernels", - ":core_cpu", - ":core_cpu_internal", - ":direct_session", - ":framework", - ":framework_internal", - ":lib", - ":test", - ":test_main", - ":testlib", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:cc_ops_internal", - "//tensorflow/cc:client_session", - "//tensorflow/cc:function_ops", - "//tensorflow/cc:ops", - "//tensorflow/cc:resource_variable_ops", - ], -) - -tf_cc_test( - name = "common_runtime_lower_case_op_test", - size = "small", - srcs = ["common_runtime/lower_case_op_test.cc"], - deps = [ - ":all_kernels", - ":core_cpu", - ":core_cpu_internal", - ":direct_session", - ":framework", - ":framework_internal", - ":lib", - ":test", - ":test_main", - ":testlib", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:cc_ops_internal", - "//tensorflow/cc:client_session", - "//tensorflow/cc:function_ops", - "//tensorflow/cc:ops", - "//tensorflow/cc:resource_variable_ops", - ], -) - -tf_cc_test( - name = "common_runtime_lower_while_op_test", - size = "small", - srcs = ["common_runtime/lower_while_op_test.cc"], - deps = [ - ":all_kernels", - ":core_cpu", - ":core_cpu_internal", - ":direct_session", - ":framework", - ":framework_internal", - ":lib", - ":test", - ":test_main", - ":testlib", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:cc_ops_internal", - "//tensorflow/cc:client_session", - "//tensorflow/cc:function_ops", - "//tensorflow/cc:ops", - "@com_google_absl//absl/algorithm:container", - ], -) - -tf_cc_test( - name = "common_runtime_lower_functional_ops_test", - size = "small", - srcs = ["common_runtime/lower_functional_ops_test.cc"], - deps = [ - ":all_kernels", - ":core_cpu", - ":core_cpu_internal", - ":direct_session", - ":framework", - ":framework_internal", - ":lib", - ":test", - ":test_main", - ":testlib", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:cc_ops_internal", - "//tensorflow/cc:client_session", - "//tensorflow/cc:function_ops", - "//tensorflow/cc:ops", - ], -) - # Test data filegroup( name = "image_testdata", diff --git a/tensorflow/core/api_def/base_api/api_def_SparseMatrixOrderingAMD.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseMatrixOrderingAMD.pbtxt index 9e7842c0f68..e8effc11814 100644 --- a/tensorflow/core/api_def/base_api/api_def_SparseMatrixOrderingAMD.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SparseMatrixOrderingAMD.pbtxt @@ -45,7 +45,7 @@ Usage example: with tf.Session() as sess: # Define (COO format) SparseTensor over Numpy array. - a_st = tf.SparseTensor(a_indices, a_values, a_dense_shape) + a_st = tf.sparse.SparseTensor(a_indices, a_values, a_dense_shape) # Convert SparseTensors to CSR SparseMatrix. a_sm = sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix( diff --git a/tensorflow/core/api_def/base_api/api_def_SparseMatrixSparseCholesky.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseMatrixSparseCholesky.pbtxt index f7cdd3574ac..ddebddeb57a 100644 --- a/tensorflow/core/api_def/base_api/api_def_SparseMatrixSparseCholesky.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SparseMatrixSparseCholesky.pbtxt @@ -58,7 +58,7 @@ Usage example: with tf.Session() as sess: # Define (COO format) SparseTensor over Numpy array. - a_st = tf.SparseTensor(a_indices, a_values, a_dense_shape) + a_st = tf.sparse.SparseTensor(a_indices, a_values, a_dense_shape) # Convert SparseTensors to CSR SparseMatrix. a_sm = sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix( diff --git a/tensorflow/core/api_def/base_api/api_def_SparseMatrixSparseMatMul.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseMatrixSparseMatMul.pbtxt index f84b3948be4..78eb9aeb512 100644 --- a/tensorflow/core/api_def/base_api/api_def_SparseMatrixSparseMatMul.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SparseMatrixSparseMatMul.pbtxt @@ -71,8 +71,8 @@ Usage example: with tf.Session() as sess: # Define (COO format) Sparse Tensors over Numpy arrays - a_st = tf.SparseTensor(a_indices, a_values, a_dense_shape) - b_st = tf.SparseTensor(b_indices, b_values, b_dense_shape) + a_st = tf.sparse.SparseTensor(a_indices, a_values, a_dense_shape) + b_st = tf.sparse.SparseTensor(b_indices, b_values, b_dense_shape) # Convert SparseTensors to CSR SparseMatrix a_sm = sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix( diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD new file mode 100644 index 00000000000..bbfed7f8f5b --- /dev/null +++ b/tensorflow/core/common_runtime/BUILD @@ -0,0 +1,1306 @@ +load( + "//tensorflow:tensorflow.bzl", + "cc_header_only_library", + "tf_cc_test", + "tf_cc_test_mkl", + "tf_cc_tests", + "tf_copts", + "tf_cuda_library", + "tf_openmp_copts", +) + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu") + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "tf_cc_tests_gpu") + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") + +# For platform specific build config +load( + "//tensorflow/core/platform:build_config.bzl", + "tf_additional_core_deps", + "tf_kernel_tests_linkstatic", + "tf_protos_all", + "tf_protos_grappler", +) +load( + "//tensorflow/core/platform:rules_cc.bzl", + "cc_library", +) +load( + "//tensorflow/core/platform:build_config_root.bzl", + "if_static", + "tf_cuda_tests_tags", +) +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") +load( + "//third_party/mkl:build_defs.bzl", + "if_mkl", + "mkl_deps", +) + +package( + default_visibility = [ + "//tensorflow:internal", + "//tensorflow_models:__subpackages__", + ], + licenses = ["notice"], # Apache 2.0 +) + +tf_cuda_library( + name = "core_cpu", + hdrs = [ + "device.h", + "device_factory.h", + "function.h", + "function_optimization_registry.h", + "optimization_registry.h", + "shape_refiner.h", + "//tensorflow/core/graph:core_cpu_headers", + "//tensorflow/core/public:session.h", + "//tensorflow/core/public:session_options.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":core_cpu_internal", + ], +) + +cc_header_only_library( + name = "core_cpu_headers_lib", + visibility = ["//visibility:public"], + deps = [ + ":core_cpu_lib", + ], +) + +cc_library( + name = "core", + visibility = ["//visibility:public"], + deps = [ + ":core_cpu", + "//tensorflow/core/common_runtime/gpu:gpu_runtime", + "//tensorflow/core/common_runtime/sycl:sycl_runtime", + ], +) + +filegroup( + name = "testlib_srcs", + srcs = [ + "function_testlib.cc", + "kernel_benchmark_testlib.cc", + ], +) + +filegroup( + name = "testlib_headers", + srcs = [ + "function_testlib.h", + "kernel_benchmark_testlib.h", + "test_collective_executor_mgr.h", + ], +) + +cc_library( + name = "testlib_ops", + testonly = 1, + srcs = ["testlib_ops.cc"], + linkstatic = 1, # Seems to be needed since alwayslink is broken in bazel + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], + alwayslink = 1, +) + +# ----------------------------------------------------------------------------- +# Public Android targets + +# Sources required to build the TensorFlow framework with runtime on +# mobile platforms without granular targets. It is assumed that the source +# files in tensorflow/core:mobile_srcs_no_runtime have been compiled +# separately and are linked in as a dependency. +filegroup( + name = "mobile_srcs_only_runtime", + srcs = [ + ] + glob( + [ + "**/*.cc", + "**/*.h", + ], + exclude = [ + "**/*test.*", + "**/*testutil*", + "**/*testlib*", + "**/*main.cc", + "gpu/**/*", + "gpu_device_factory.*", + ], + ), +) + +filegroup( + name = "core_cpu_base_headers", + srcs = [ + "device.h", + "device_factory.h", + "device_mgr.h", + "device_set.h", + "eval_const_tensor.h", + "function.h", + "graph_runner.h", + "metrics.h", + "process_function_library_runtime.h", + "scoped_allocator.h", + "scoped_allocator_mgr.h", + "shape_refiner.h", + "//tensorflow/core/framework:versions.h", + "//tensorflow/core/graph:graph_headers", + ], +) + +tf_cuda_library( + name = "core_cpu_base_no_ops", + srcs = [ + "eval_const_tensor.cc", + "graph_optimizer.h", + "scoped_allocator.cc", + "scoped_allocator_mgr.cc", + "shape_refiner.cc", + "//tensorflow/core/graph:core_cpu_base_no_ops_srcs", + "//tensorflow/core/public:session_options.h", + "//tensorflow/core/public:version.h", + ], + hdrs = [ + ":core_cpu_base_headers", + "//tensorflow/core/public:session.h", + ], + copts = tf_copts(), + deps = [ + "//tensorflow/core:graph", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:flat_hash_set", + "//third_party/eigen3", + ] + if_static([ + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + ]), +) + +filegroup( + name = "core_cpu_lib_headers", + srcs = [ + ":core_cpu_base_headers", + "allocator_retry.h", + "shared_counter.h", + "base_collective_executor.h", + "bfc_allocator.h", + "hierarchical_tree_broadcaster.h", + "buf_rendezvous.h", + "build_graph_options.h", + "collective_executor_mgr.h", + "collective_param_resolver_local.h", + "collective_rma_local.h", + "collective_util.h", + "colocation_graph.h", + "constant_folding.h", + "copy_tensor.h", + "costmodel_manager.h", + "placer_inspection_required_ops_utils.h", + "debugger_state_interface.h", + "device_resolver_local.h", + "dma_helper.h", + "entry.h", + "executor.h", + "executor_factory.h", + "function_optimization_registry.h", + "graph_optimizer.h", + "graph_view.h", + "immutable_executor_state.h", + "input_colocation_exemption_registry.h", + "isolate_placer_inspection_required_ops_pass.h", + "local_device.h", + "lower_function_call_op.h", + "lower_if_op.h", + "lower_case_op.h", + "lower_functional_ops.h", + "lower_while_op.h", + "memory_types.h", + "mkl_cpu_allocator.h", + "optimization_registry.h", + "pending_counts.h", + "partitioning_utils.h", + "placer.h", + "process_util.h", + "inspecting_placer.h", + "profile_handler.h", + "propagator_debug_utils.h", + "propagator_state.h", + "renamed_device.h", + "rendezvous_mgr.h", + "rendezvous_util.h", + "replicate_per_replica_nodes.h", + "ring_reducer.h", + "ring_alg.h", + "ring_gatherer.h", + "session_factory.h", + "simple_propagator_state.h", + "single_threaded_cpu_device.h", + "stats_publisher_interface.h", + "step_stats_collector.h", + "threadpool_device.h", + "process_state.h", + "pool_allocator.h", + "//tensorflow/core/graph:core_cpu_lib_headers", + ] + if_mkl(["//tensorflow/core/graph:mkl_graph_util_header"]), +) + +tf_cuda_library( + name = "core_cpu_impl", + srcs = [ + "accumulate_n_optimizer.cc", + "base_collective_executor.cc", + "buf_rendezvous.cc", + "build_graph_options.cc", + "collective_executor_mgr.cc", + "collective_param_resolver_local.cc", + "collective_rma_local.cc", + "collective_util.cc", + "colocation_graph.cc", + "constant_folding.cc", + "copy_tensor.cc", + "costmodel_manager.cc", + "debugger_state_interface.cc", + "device.cc", + "device_factory.cc", + "device_mgr.cc", + "device_resolver_local.cc", + "device_set.cc", + "dynamic_device_mgr.cc", + "executor.cc", + "executor_factory.cc", + "function.cc", + "function_optimization_registry.cc", + "graph_optimizer.cc", + "graph_runner.cc", + "graph_view.cc", + "hierarchical_tree_broadcaster.cc", + "immutable_executor_state.cc", + "input_colocation_exemption_registry.cc", + "inspecting_placer.cc", + "isolate_placer_inspection_required_ops_pass.cc", + "local_device.cc", + "lower_case_op.cc", + "lower_function_call_op.cc", + "lower_functional_ops.cc", + "lower_if_op.cc", + "lower_while_op.cc", + "memory_types.cc", + "metrics.cc", + "mkl_cpu_allocator.cc", + "optimization_registry.cc", + "parallel_concat_optimizer.cc", + "partitioning_utils.cc", + "placer.cc", + "placer_inspection_required_ops_utils.cc", + "placer_inspection_required_ops_utils.h", + "pool_allocator.cc", + "process_function_library_runtime.cc", + "process_state.cc", + "process_util.cc", + "propagator_debug_utils.cc", + "propagator_state.cc", + "renamed_device.cc", + "rendezvous_mgr.cc", + "rendezvous_util.cc", + "replicate_per_replica_nodes.cc", + "ring_alg.cc", + "ring_gatherer.cc", + "ring_reducer.cc", + "session.cc", + "session_factory.cc", + "session_options.cc", + "session_state.cc", + "simple_propagator_state.cc", + "single_threaded_cpu_device.cc", + "stats_publisher_interface.cc", + "step_stats_collector.cc", + "threadpool_device.cc", + "threadpool_device_factory.cc", + "//tensorflow/core/graph:core_cpu_impl_srcs", + "//tensorflow/core/public:session.h", + "//tensorflow/core/public:session_options.h", + ], + hdrs = [":core_cpu_lib_headers"], + copts = tf_copts() + tf_openmp_copts(), + deps = [ + ":bfc_allocator", + "//tensorflow/core:graph", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/base", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:variant", + "//third_party/eigen3", + "//tensorflow/core/public:version", + "//tensorflow/core/grappler/utils:functions", + "//tensorflow/core/profiler/lib:annotated_traceme", + "//tensorflow/core/profiler/lib:scoped_annotation", + "//tensorflow/core/profiler/lib:traceme", + ] + mkl_deps(), + alwayslink = 1, +) + +tf_cuda_library( + name = "core_cpu_lib", + hdrs = [":core_cpu_lib_headers"], + deps = [ + "//tensorflow/core:core_cpu_base", + "//tensorflow/core/grappler:grappler_item", + ] + if_static([":core_cpu_impl"]) + tf_protos_all() + tf_protos_grappler(), +) + +tf_cuda_library( + name = "core_cpu_lib_no_ops", + hdrs = [":core_cpu_lib_headers"], + deps = [ + ":core_cpu_base_no_ops", + "//tensorflow/core/grappler:grappler_item", + ] + tf_protos_all() + tf_protos_grappler(), +) + +tf_cuda_library( + name = "core_cpu_internal", + srcs = [ + "graph_execution_state.cc", + ], + hdrs = [ + "graph_execution_state.h", + ":core_cpu_lib_headers", + ], + copts = tf_copts(), + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler/clusters:utils", + "//tensorflow/core/grappler/clusters:virtual_cluster", + "//tensorflow/core/grappler/optimizers:meta_optimizer", + "//third_party/eigen3", + ] + mkl_deps() + tf_additional_core_deps() + if_static([ + ":core_cpu_impl", + "//tensorflow/core:function_ops_op_lib", + "//tensorflow/core:functional_grad", + "//tensorflow/core:functional_ops_op_lib", + "//tensorflow/core/kernels:required", + ]), + alwayslink = 1, +) + +# This is redundant with the "core_cpu_*" targets above. It's useful for +# applications that want to depend on a minimal subset of TensorFlow (e.g. XLA). +cc_library( + name = "bfc_allocator", + srcs = [ + "allocator_retry.cc", + "allocator_retry.h", + "bfc_allocator.cc", + ], + hdrs = ["bfc_allocator.h"], + features = ["parse_headers"], + visibility = ["//visibility:public"], + deps = [ + ":shared_counter", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/framework:allocator", + "//tensorflow/core/profiler/lib:traceme", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "shared_counter", + hdrs = ["shared_counter.h"], + features = ["parse_headers"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:lib", + ], +) + +tf_cuda_library( + name = "direct_session_internal", + srcs = ["direct_session.cc"], + hdrs = [ + "direct_session.h", + ], + copts = tf_copts(), + deps = [ + ":core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:lib_experimental", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/debug:debug_graph_utils", + "//tensorflow/core/kernels:function_ops", + "//tensorflow/core/profiler/lib:profiler_backends", + "//tensorflow/core/profiler/lib:profiler_session", + "//tensorflow/core/profiler/lib:traceme", + "@com_google_absl//absl/container:flat_hash_set", + ], + alwayslink = 1, +) + +filegroup( + name = "gpu_runtime_headers", + srcs = [ + "gpu_device_context.h", + ], +) + +# ----------------------------------------------------------------------------- +# Tests + +tf_cc_test( + name = "placer_test", + size = "small", + srcs = [ + "placer_test.cc", + ], + linkopts = select({ + "//tensorflow:macos": ["-headerpad_max_install_names"], + "//conditions:default": [], + }), + linkstatic = tf_kernel_tests_linkstatic(), + tags = ["no_windows"], + deps = [ + ":core", + ":core_cpu", + ":core_cpu_internal", + ":direct_session_internal", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", + "//tensorflow/cc:sendrecv_ops", + "//tensorflow/cc:while_loop", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:ops_util", + "//tensorflow/core/platform:regexp", + "//tensorflow/core/util:protos_test_cc", + "//third_party/eigen3", + "@com_google_absl//absl/base", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_tests( + name = "core_higher_level_tests", + size = "small", + srcs = [ + "buf_rendezvous_test.cc", + "collective_executor_mgr_test.cc", + "collective_rma_local_test.cc", + "device_mgr_test.cc", + "device_resolver_local_test.cc", + "device_set_test.cc", + "dynamic_device_mgr_test.cc", + "function_optimization_registration_test.cc", + "function_optimization_registry_no_pass_test.cc", + "function_optimization_registry_pass_failure_test.cc", + "function_optimization_registry_test.cc", + "isolate_placer_inspection_required_ops_pass_test.cc", + "optimization_registry_test.cc", + "pending_counts_test.cc", + "placer_inspection_required_ops_utils_test.cc", + "session_test.cc", + "threadpool_device_test.cc", + ], + create_named_test_suite = True, + linkopts = select({ + "//tensorflow:macos": ["-headerpad_max_install_names"], + "//conditions:default": [], + }), + linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":core", + ":core_cpu", + ":core_cpu_internal", + ":direct_session_internal", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", + "//tensorflow/cc:sendrecv_ops", + "//tensorflow/cc:while_loop", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:ops_util", + "//tensorflow/core/platform:regexp", + "//tensorflow/core/util:protos_test_cc", + "//third_party/eigen3", + "@com_google_absl//absl/base", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_tests( + name = "higher_level_tests_needing_kernels", + size = "small", + srcs = [ + "collective_param_resolver_local_test.cc", + ], + linkopts = select({ + "//tensorflow:macos": ["-headerpad_max_install_names"], + "//conditions:default": [], + }), + linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":core", + ":core_cpu", + ":core_cpu_internal", + ":direct_session_internal", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:scope", + "//tensorflow/cc:sendrecv_ops", + "//tensorflow/core:all_kernels", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:ops_util", + "//tensorflow/core/util:protos_test_cc", + "//third_party/eigen3", + ], +) + +tf_cc_tests_gpu( + name = "ring_reducer_test", + size = "medium", + srcs = [ + "ring_reducer_test.cc", + ], + linkstatic = tf_kernel_tests_linkstatic(), + tags = ["no_cuda_on_cpu_tap"], + deps = [ + ":core", + ":core_cpu", + ":core_cpu_internal", + ":direct_session_internal", + "//tensorflow/core:all_kernels", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:gpu_runtime", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/util:protos_test_cc", + "@com_google_absl//absl/memory", + ], +) + +tf_cc_tests_gpu( + name = "ring_gatherer_test", + size = "medium", + srcs = [ + "ring_gatherer_test.cc", + ], + linkstatic = tf_kernel_tests_linkstatic(), + tags = ["no_cuda_on_cpu_tap"], + deps = [ + ":core", + ":core_cpu", + ":core_cpu_internal", + ":direct_session_internal", + "//tensorflow/core:all_kernels", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/common_runtime/gpu:gpu_runtime", + "//tensorflow/core/util:protos_test_cc", + "@com_google_absl//absl/memory", + ], +) + +tf_cc_tests_gpu( + name = "hierarchical_tree_broadcaster_test", + size = "medium", + srcs = [ + "hierarchical_tree_broadcaster_test.cc", + ], + linkstatic = tf_kernel_tests_linkstatic(), + tags = ["no_cuda_on_cpu_tap"], + deps = [ + ":core", + ":core_cpu", + ":core_cpu_internal", + ":direct_session_internal", + "//tensorflow/core:all_kernels", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/common_runtime/gpu:gpu_runtime", + "//tensorflow/core/util:protos_test_cc", + "@com_google_absl//absl/memory", + ], +) + +tf_cc_test_mkl( + name = "mkl_runtime_tests", + size = "small", + srcs = [ + "mkl_cpu_allocator_test.cc", + "mkl_threadpool_device_test.cc", + ], + linkstatic = 1, + deps = [ + ":core", + ":core_cpu", + ":core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +tf_cc_test_gpu( + name = "memory_types_test", + size = "small", + srcs = ["memory_types_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags(), + deps = [ + ":core", + ":core_cpu", + ":core_cpu_internal", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/common_runtime/gpu:gpu_runtime", + "//tensorflow/core/kernels:cast_op", + "//third_party/eigen3", + ], +) + +tf_cc_test( + name = "constant_folding_test", + size = "small", + srcs = ["constant_folding_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags(), + deps = [ + ":core", + ":core_cpu", + ":core_cpu_internal", + ":direct_session_internal", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:sendrecv_ops", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/common_runtime/gpu:gpu_runtime", + "//tensorflow/core/kernels:bcast_ops", + "//tensorflow/core/kernels:cast_op", + "//tensorflow/core/kernels:concat_op", + "//tensorflow/core/kernels:cwise_op", + "//tensorflow/core/kernels:identity_op", + "//tensorflow/core/kernels:immutable_constant_op", + "//tensorflow/core/kernels:matmul_op", + "//tensorflow/core/kernels:topk_op", + "//third_party/eigen3", + ], +) + +tf_cc_test( + name = "shape_refiner_test", + size = "small", + srcs = [ + "shape_refiner_test.cc", + ], + linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":core", + ":core_cpu", + ":core_cpu_internal", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:resource_variable_ops", + "//tensorflow/cc:scope", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:array", + "//tensorflow/core/kernels:math", + "//tensorflow/core/kernels:resource_variable_ops", + "//third_party/eigen3", + ], +) + +tf_cuda_cc_test( + name = "process_function_library_runtime_test", + size = "small", + srcs = ["process_function_library_runtime_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), + tags = ["no_rocm"], + deps = [ + ":core_cpu", + ":core_cpu_internal", + "//tensorflow/cc:function_ops", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:cast_op", + "//tensorflow/core/kernels:cwise_op", + "//tensorflow/core/kernels:function_ops", + "//tensorflow/core/kernels:resource_variable_ops", + ], +) + +tf_cc_test( + name = "process_util_test", + size = "small", + srcs = ["process_util_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":core_cpu_internal", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "rendezvous_util_test", + size = "small", + srcs = ["rendezvous_util_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":core_cpu_internal", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "replicate_per_replica_nodes_test", + size = "small", + srcs = ["replicate_per_replica_nodes_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":core_cpu_internal", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:resource_variable_ops", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "partitioning_utils_test", + size = "small", + srcs = ["partitioning_utils_test.cc"], + deps = [ + ":core_cpu", + ":core_cpu_internal", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:function_ops", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:ops", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:function_ops", + "//tensorflow/core/kernels:identity_op", + ], +) + +tf_cuda_cc_test( + name = "direct_session_test", + size = "small", + srcs = ["direct_session_test.cc"], + args = [] + if_cuda(["--heap_check=local"]), # The GPU tracer leaks memory + linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":core_cpu", + ":core_cpu_internal", + ":direct_session_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "//third_party/eigen3", + "//tensorflow/cc:cc_ops", + "//tensorflow/core/kernels:collective_ops", + "//tensorflow/core/kernels:control_flow_ops", + "//tensorflow/core/kernels:cwise_op", + "//tensorflow/core/kernels:dense_update_ops", + "//tensorflow/core/kernels:fifo_queue_op", + "//tensorflow/core/kernels:function_ops", + "//tensorflow/core/kernels:identity_n_op", + "//tensorflow/core/kernels:identity_op", + "//tensorflow/core/kernels:matmul_op", + "//tensorflow/core/kernels:ops_util", + "//tensorflow/core/kernels:queue_ops", + "//tensorflow/core/kernels:session_ops", + "//tensorflow/core/kernels:variable_ops", + "//tensorflow/core/kernels/data:single_threaded_executor", + ] + if_cuda(["//tensorflow/core/common_runtime/gpu:cuda"]), +) + +# This is identical to :common_runtime_direct_session_test with the addition of +# a dependency on alwayslink target //third_party/tensorflow/core/debug, which +# enables support for TensorFlow Debugger (tfdbg). +tf_cc_test( + name = "direct_session_with_debug_test", + size = "small", + srcs = ["direct_session_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":core", + ":core_cpu", + ":core_cpu_internal", + ":direct_session_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "@com_google_absl//absl/strings", + "//third_party/eigen3", + "@com_google_absl//absl/memory", + "//tensorflow/cc:cc_ops", + # Link with support for TensorFlow Debugger (tfdbg). + "//tensorflow/core/debug", + "//tensorflow/core/kernels:collective_ops", + "//tensorflow/core/kernels:control_flow_ops", + "//tensorflow/core/kernels:cwise_op", + "//tensorflow/core/kernels:dense_update_ops", + "//tensorflow/core/kernels:fifo_queue_op", + "//tensorflow/core/kernels:function_ops", + "//tensorflow/core/kernels:identity_op", + "//tensorflow/core/kernels:identity_n_op", + "//tensorflow/core/kernels:matmul_op", + "//tensorflow/core/kernels:ops_util", + "//tensorflow/core/kernels:queue_ops", + "//tensorflow/core/kernels:session_ops", + "//tensorflow/core/kernels:variable_ops", + ], +) + +tf_cc_test( + name = "direct_session_with_tracking_alloc_test", + size = "small", + srcs = ["direct_session_with_tracking_alloc_test.cc"], + args = ["--heap_check=local"], # The GPU tracer leaks memory + linkstatic = tf_kernel_tests_linkstatic(), + tags = ["no_gpu"], + deps = [ + ":core", + ":core_cpu", + ":core_cpu_internal", + ":direct_session_internal", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:cwise_op", + "//tensorflow/core/kernels:dense_update_ops", + "//tensorflow/core/kernels:fifo_queue_op", + "//tensorflow/core/kernels:identity_op", + "//tensorflow/core/kernels:matmul_op", + "//tensorflow/core/kernels:ops_util", + "//tensorflow/core/kernels:queue_ops", + "//tensorflow/core/kernels:variable_ops", + "//third_party/eigen3", + ], +) + +tf_cc_test( + name = "graph_runner_test", + size = "small", + srcs = ["graph_runner_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + "//tensorflow/core:array_ops_op_lib", + ":core", + ":core_cpu", + ":core_cpu_internal", + ":direct_session_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//third_party/eigen3", + "//tensorflow/c/kernels:bitcast_op_lib", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:scope", + "//tensorflow/core/kernels:cwise_op", + ] + if_mkl([":mkl_array_ops_op_lib"]), +) + +tf_cc_test( + name = "executor_test", + size = "small", + srcs = ["executor_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":core", + ":core_cpu", + ":core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:array", + "//tensorflow/core/kernels:control_flow_ops", + "//tensorflow/core/kernels:math", + "//tensorflow/core/kernels:random_ops", + "//tensorflow/core/kernels:state", + ], +) + +tf_cc_test( + name = "function_test", + size = "small", + srcs = ["function_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), + tags = [ + "manual", + "no_oss", + ], + deps = [ + ":core", + ":core_cpu", + ":core_cpu_internal", + ":direct_session_internal", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:functional_ops", + "//tensorflow/cc:sendrecv_ops", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:cast_op", + "//tensorflow/core/kernels:cwise_op", + "//tensorflow/core/kernels:function_ops", + "//tensorflow/core/kernels:matmul_op", + "//tensorflow/core/kernels:partitioned_function_ops", + "//tensorflow/core/kernels:random_ops", + "//tensorflow/core/kernels:shape_ops", + "//third_party/eigen3", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "function_threadpool_test", + size = "small", + srcs = ["function_threadpool_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":core", + ":core_cpu", + ":core_cpu_internal", + ":direct_session_internal", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:functional_ops", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:cast_op", + "//tensorflow/core/kernels:cwise_op", + "//tensorflow/core/kernels:function_ops", + "//tensorflow/core/kernels:matmul_op", + "//tensorflow/core/kernels:random_ops", + "//tensorflow/core/kernels:shape_ops", + "//third_party/eigen3", + ], +) + +tf_cc_test( + name = "scoped_allocator_mgr_test", + size = "small", + srcs = ["scoped_allocator_mgr_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":core_cpu", + ":core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "input_colocation_exemption_registry_test", + size = "small", + srcs = ["input_colocation_exemption_registry_test.cc"], + deps = [ + ":core_cpu", + ":core_cpu_internal", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +tf_cc_test( + name = "lower_function_call_test", + size = "small", + srcs = ["lower_function_call_op_test.cc"], + deps = [ + ":core_cpu", + ":core_cpu_internal", + ":direct_session_internal", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:client_session", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:resource_variable_ops", + "//tensorflow/core:all_kernels", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +tf_cc_test( + name = "lower_if_op_test", + size = "small", + srcs = ["lower_if_op_test.cc"], + deps = [ + ":core_cpu", + ":core_cpu_internal", + ":direct_session_internal", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:client_session", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:resource_variable_ops", + "//tensorflow/core:all_kernels", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +tf_cc_test( + name = "lower_case_op_test", + size = "small", + srcs = ["lower_case_op_test.cc"], + deps = [ + ":core_cpu", + ":core_cpu_internal", + ":direct_session_internal", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:client_session", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:resource_variable_ops", + "//tensorflow/core:all_kernels", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +tf_cc_test( + name = "lower_while_op_test", + size = "small", + srcs = ["lower_while_op_test.cc"], + deps = [ + ":core_cpu", + ":core_cpu_internal", + ":direct_session_internal", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:client_session", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:ops", + "//tensorflow/core:all_kernels", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "@com_google_absl//absl/algorithm:container", + ], +) + +tf_cc_test( + name = "lower_functional_ops_test", + size = "small", + srcs = ["lower_functional_ops_test.cc"], + deps = [ + ":core_cpu", + ":core_cpu_internal", + ":direct_session_internal", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:client_session", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:ops", + "//tensorflow/core:all_kernels", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD index 680beac3e60..c639e23062d 100644 --- a/tensorflow/core/common_runtime/eager/BUILD +++ b/tensorflow/core/common_runtime/eager/BUILD @@ -30,6 +30,7 @@ tf_cuda_library( ":eager_operation", ":execute", ":tensor_handle", + "//tensorflow/c:c_api_internal", "//tensorflow/c:tf_tensor_internal", ], alwayslink = 1, diff --git a/tensorflow/core/common_runtime/eager/core.cc b/tensorflow/core/common_runtime/eager/core.cc index d47c9e1c7e0..20500ac210e 100644 --- a/tensorflow/core/common_runtime/eager/core.cc +++ b/tensorflow/core/common_runtime/eager/core.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/tf_tensor_internal.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/eager_operation.h" @@ -61,8 +62,12 @@ AbstractTensorInterface* TensorHandle::Resolve(Status* status) { h_cpu->Unref(); return nullptr; } - auto* retval = new TensorInterface(*t); + // TODO(b/153052876): Change TF_TensorFromTensor to just return an + // AbstractTensorInterface + TF_Tensor* tf_tensor = TF_TensorFromTensor(*t, status); + AbstractTensorInterface* retval = tf_tensor->tensor; h_cpu->Unref(); + delete tf_tensor; return retval; } else { tensorflow::Tensor tensor; @@ -78,15 +83,19 @@ AbstractTensorInterface* TensorHandle::Resolve(Status* status) { } else { *status = CopyToDevice(*ctx_, ctx_->HostCPU(), &tensor); if (!status->ok()) return nullptr; - if (ImplicitMirroring()) { - *status = AddEmptyLocalMirror(nullptr); - if (!status->ok()) return nullptr; - tensorflow::Tensor mirror = tensor; - *status = SetTensor(std::move(mirror), nullptr); - if (!status->ok()) return nullptr; - } + + *status = AddEmptyLocalMirror(nullptr); + if (!status->ok()) return nullptr; + tensorflow::Tensor mirror = tensor; + *status = SetTensor(std::move(mirror), nullptr); + if (!status->ok()) return nullptr; } - return new TensorInterface(std::move(tensor)); + // TODO(b/153052876): Change TF_TensorFromTensor to just return an + // AbstractTensorInterface + TF_Tensor* tf_tensor = TF_TensorFromTensor(tensor, status); + AbstractTensorInterface* retval = tf_tensor->tensor; + delete tf_tensor; + return retval; } } diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index b304fa77883..f1c90119bda 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -157,9 +157,9 @@ Status CopyInputToExpectedDevice(EagerContext* ctx, EagerOperation* op, " to ", expected_input_device->name()); }, profiler::TraceMeLevel::kInfo); - Status status = EagerCopyToDevice( - handle, ctx, &op->Executor(), expected_input_device, - handle->ImplicitMirroring() || ctx->MirrorTensors(), &result_handle); + Status status = + EagerCopyToDevice(handle, ctx, &op->Executor(), expected_input_device, + /* mirror= */ true, &result_handle); activity.Stop(); if (!status.ok()) { return errors::Internal("Failed copying input tensor from ", @@ -416,7 +416,7 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals, TensorHandle* handle = nullptr; TF_RETURN_IF_ERROR(EagerCopyToDevice( input, &ctx, &executor, device == nullptr ? ctx.HostCPU() : device, - input->ImplicitMirroring() || ctx.MirrorTensors(), &handle)); + /* mirror= */ true, &handle)); op->UpdateInput(i, handle); // Unref handle since it has a ref as an input now handle->Unref(); diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc index 385828a0426..2cbb978b5ee 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.cc +++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc @@ -137,7 +137,6 @@ TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device, op_device_(op_device), resource_device_(resource_device), ctx_(ctx), - implicit_mirroring_(true), data_(absl::in_place_type, std::move(t)) { DVLOG(3) << "Creating Local TensorHandle: " << this << " device: " << VariantDeviceDebugString(device_) @@ -152,7 +151,6 @@ TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device, resource_device_( GetResourceDevice(t.flat()(0), ctx)), ctx_(ctx), - implicit_mirroring_(true), resource_handle_info_( {t.flat()(0).dtypes_and_shapes(), t.flat()(0).allowed_devices()}), @@ -169,7 +167,6 @@ TensorHandle::TensorHandle(tensorflow::Tensor&& t, CustomDevice* d, op_device_(nullptr), resource_device_(nullptr), ctx_(ctx), - implicit_mirroring_(true), data_(absl::in_place_type, std::move(t)) { // TODO(allenl): Figure out a better op_device story for custom devices, // since always setting it to CPU=nullptr doesn't make much sense. @@ -193,7 +190,6 @@ TensorHandle::TensorHandle(Device* d, Device* op_device, op_device_(op_device), resource_device_(resource_device), ctx_(ctx), - implicit_mirroring_(true), data_(absl::in_place_type) { DVLOG(3) << "Creating empty Local TensorHandle: " << this << " device: " << VariantDeviceDebugString(device_); @@ -215,7 +211,6 @@ TensorHandle::TensorHandle(int64 op_id, int32 output_num, op_device_(d), resource_device_(dtype == DT_RESOURCE ? d : nullptr), ctx_(ctx), - implicit_mirroring_(true), data_(absl::in_place_type, op_id, output_num, remote_task, ctx) { DVLOG(3) << "Creating Unshaped Remote TensorHandle: " << this @@ -238,7 +233,6 @@ TensorHandle::TensorHandle(int64 op_id, int32 output_num, op_device_(d), resource_device_(dtype == DT_RESOURCE ? d : nullptr), ctx_(ctx), - implicit_mirroring_(true), data_(absl::in_place_type, op_id, output_num, ctx->GetContextViewId()) { DVLOG(3) << "Creating Lazy Remote TensorHandle: " << this @@ -305,7 +299,13 @@ Status TensorHandle::TensorFromDevice(const Device* d, Status TensorHandle::TensorValue(const Device* d, tensorflow::TensorValue* t) { DVLOG(3) << "TensorValue on TensorHandle: " << this << " device: " << d; - if (d == absl::get(device_)) { + if (VariantDeviceIsCustom(device_)) { + return errors::Internal( + "TensorHandle::TensorValue not supported for custom devices yet. " + "Handle device: ", + VariantDeviceDebugString(device_), + ", requested device: ", d != nullptr ? d->name() : "(nil)"); + } else if (d == absl::get(device_)) { if (IsRemote()) { return errors::Internal("Invalid TensorValue call on remote handle: ", this); diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h index 2eb72d65022..9309b4fcccd 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.h +++ b/tensorflow/core/common_runtime/eager/tensor_handle.h @@ -115,8 +115,6 @@ class TensorHandle : public AbstractTensorHandleInterface, AbstractTensorHandleInterface* Copy() override; - void EnableImplicitMirroring() override { implicit_mirroring_ = true; } - // Return the Tensor from the default device. Status Tensor(const tensorflow::Tensor** t) const; // Return the Tensor from the specified device which could be either the @@ -207,7 +205,6 @@ class TensorHandle : public AbstractTensorHandleInterface, const tensorflow::DataType dtype; bool IsRemote() const; - bool ImplicitMirroring() const { return implicit_mirroring_; } string DebugString() const; @@ -276,7 +273,6 @@ class TensorHandle : public AbstractTensorHandleInterface, // Does not need synchronization because it can be accessed only after // WaitReady() has returned. At that point, is_poisoned_ is immutable. Status is_poisoned_; - bool implicit_mirroring_; // If this TensorHandle 1) is a local tensor, and 2) is a resource handle or // refers to a remote resource handle, we store data types, shapes and allowed diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index 39f396d2286..f1177e8cba4 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/pending_counts.h" #include "tensorflow/core/common_runtime/propagator_state.h" #include "tensorflow/core/common_runtime/renamed_device.h" +#include "tensorflow/core/common_runtime/simple_propagator_state.h" #include "tensorflow/core/common_runtime/step_stats_collector.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/cancellation.h" @@ -315,6 +316,8 @@ class ExecutorState { // nodes in 'ready' into 'inline_ready'. // // This method will clear `*ready` before returning. + // + // REQUIRES: `!ready->empty()`. void ScheduleReady(TaggedNodeSeq* ready, TaggedNodeReadyQueue* inline_ready); // Clean up when this executor is done. @@ -372,6 +375,10 @@ class ExecutorState { mutex mu_; Status status_ TF_GUARDED_BY(mu_); + + // A flag that is set on error after the propagator state has been + // dumped for diagnostic purposes. + bool dumped_on_error_ TF_GUARDED_BY(mu_) = false; }; template @@ -625,7 +632,14 @@ template void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) { profiler::TraceMe activity( - [&] { return absl::StrCat("ExecutorState::Process#id=", step_id_, "#"); }, + [&] { + // NOTE: This tracing uses the iteration number from the first tagged + // node that executes during this call to `Process()`. In principle, + // subsequent nodes could have different values of `iter_num` that + // will not be traced. + return absl::StrCat("ExecutorState::Process#id=", step_id_, + ",iter_num=", tagged_node.get_iter_num(), "#"); + }, 2); WithContext wc(context_); TaggedNodeSeq ready; @@ -918,7 +932,11 @@ Status ExecutorState::ProcessOutputs( // add better optional debugging support. if (vlog_ && VLOG_IS_ON(1)) { LOG(WARNING) << this << " Compute status: " << s; - propagator_.DumpState(); + mutex_lock l(mu_); + if (!dumped_on_error_) { + propagator_.DumpState(); + dumped_on_error_ = true; + } } if (s.code() == error::RESOURCE_EXHAUSTED) { if (stats_collector_) { @@ -1006,73 +1024,80 @@ template bool ExecutorState::NodeDone( const Status& s, TaggedNodeSeq* ready, NodeExecStatsInterface* stats, TaggedNodeReadyQueue* inline_ready) { - nodestats::SetAllEnd(stats); if (stats) { - if (stats_collector_) { - stats->Done(immutable_state_.params().device->name()); - } else { - delete stats; - } + nodestats::SetAllEnd(stats); + DCHECK_NE(stats_collector_, nullptr); + stats->Done(immutable_state_.params().device->name()); } - bool abort_run = false; - if (!s.ok()) { - // Some error happened. This thread of computation is done. - mutex_lock l(mu_); - if (status_.ok()) { - abort_run = true; + if (TF_PREDICT_TRUE(s.ok())) { + const size_t ready_size = ready->size(); + if (ready_size == 0) { + return num_outstanding_ops_.fetch_sub(1) == 1; + } else { + // NOTE: Avoid touching the atomic counter if only one node becomes ready. + if (ready_size > 1) { + num_outstanding_ops_.fetch_add(ready_size - 1, + std::memory_order_relaxed); + } - // If execution has been cancelled, mark any new errors as being derived. - // This ensures any errors triggered by cancellation are marked as - // derived. - if (cancellation_manager_ && cancellation_manager_->IsCancelled()) { - status_ = StatusGroup::MakeDerived(s); - } else { - status_ = s; + // Schedule the ready nodes in 'ready'. + ScheduleReady(ready, inline_ready); + + return false; + } + } else { + bool abort_run = false; + + // Some error happened. This thread of computation is done. + { + mutex_lock l(mu_); + if (status_.ok()) { + // If this is the first node to fail in this run, we are responsible for + // aborting all other execution in the step. + abort_run = true; + + // If execution has been cancelled, mark any new errors as being + // derived. This ensures any errors triggered by cancellation are marked + // as derived. + if (cancellation_manager_ && cancellation_manager_->IsCancelled()) { + status_ = StatusGroup::MakeDerived(s); + } else { + status_ = s; + } } } - } - if (abort_run) { - TRACEPRINTF("StartAbort: %s", s.ToString().c_str()); - if (cancellation_manager_) { - // only log when the abort happens during the actual run time. - auto device_name = immutable_state_.params().device->name(); - // Use VLOG instead of LOG(warning) because error status is expected when - // the executor is run under the grappler optimization phase or when - // iterating through a tf.data input pipeline. - VLOG(1) << "[" << device_name << "] Executor start aborting: " << s; + + if (abort_run) { + TRACEPRINTF("StartAbort: %s", s.ToString().c_str()); + if (cancellation_manager_) { + // Only log when the abort happens during the actual run time. + // Use VLOG instead of LOG(warning) because error status is expected + // when the executor is run under the grappler optimization phase or + // when iterating through a tf.data input pipeline. + VLOG(1) << "[" << immutable_state_.params().device->name() + << "] Executor start aborting: " << s; + } + + if (rendezvous_) { + rendezvous_->StartAbort(s); + } + if (collective_executor_) { + collective_executor_->StartAbort(s); + } + if (cancellation_manager_) { + cancellation_manager_->StartCancel(); + } } - if (rendezvous_) { - rendezvous_->StartAbort(s); - } - if (collective_executor_) { - collective_executor_->StartAbort(s); - } - if (cancellation_manager_) { - cancellation_manager_->StartCancel(); - } + return num_outstanding_ops_.fetch_sub(1) == 1; } - - bool completed = false; - const size_t ready_size = ready->size(); - if (ready_size == 0 || !s.ok()) { - completed = (num_outstanding_ops_.fetch_sub(1) == 1); - } else if (ready_size > 1) { - num_outstanding_ops_.fetch_add(ready_size - 1, std::memory_order_relaxed); - } - - // Schedule the ready nodes in 'ready'. - if (s.ok()) { - ScheduleReady(ready, inline_ready); - } - return completed; } template void ExecutorState::ScheduleReady( TaggedNodeSeq* ready, TaggedNodeReadyQueue* inline_ready) { - if (ready->empty()) return; + DCHECK(!ready->empty()); int64 scheduled_nsec = 0; if (stats_collector_) { @@ -1249,8 +1274,14 @@ void ExecutorState::Finish() { } void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) { - (new ExecutorState(args, immutable_state_, &kernel_stats_)) - ->RunAsync(std::move(done)); + if (immutable_state_.requires_control_flow_support()) { + (new ExecutorState(args, immutable_state_, &kernel_stats_)) + ->RunAsync(std::move(done)); + } else { + (new ExecutorState(args, immutable_state_, + &kernel_stats_)) + ->RunAsync(std::move(done)); + } } } // namespace diff --git a/tensorflow/core/common_runtime/executor_test.cc b/tensorflow/core/common_runtime/executor_test.cc index 74febf43287..fe62a8459f1 100644 --- a/tensorflow/core/common_runtime/executor_test.cc +++ b/tensorflow/core/common_runtime/executor_test.cc @@ -413,6 +413,14 @@ TEST_F(ExecutorTest, RecvInvalidRefDtype) { rendez->Unref(); } +TEST_F(ExecutorTest, NoInputTensors) { + // Create a graph where none of the nodes have input tensors. + auto g = absl::make_unique(OpRegistry::Global()); + test::graph::Constant(g.get(), V(1.0)); + Create(std::move(g)); + TF_ASSERT_OK(Run(rendez_)); +} + // Create a graph that is 'depth' deep. At each level, fan-in and fan-out a // maximum of 'width' nodes. All nodes are no-ops and all dependencies are // control dependencies. diff --git a/tensorflow/core/common_runtime/gpu/BUILD b/tensorflow/core/common_runtime/gpu/BUILD new file mode 100644 index 00000000000..07919117051 --- /dev/null +++ b/tensorflow/core/common_runtime/gpu/BUILD @@ -0,0 +1,428 @@ +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", + "tf_copts", + "tf_cuda_library", +) + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu") + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "tf_cc_tests_gpu") + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") + +# For platform specific build config +load( + "//tensorflow/core/platform:build_config.bzl", + "tf_kernel_tests_linkstatic", +) +load( + "//tensorflow/core/platform:rules_cc.bzl", + "cc_library", +) +load( + "//tensorflow/core/platform:build_config_root.bzl", + "if_static", + "tf_cuda_tests_tags", +) + +package( + default_visibility = [ + "//tensorflow:internal", + "//tensorflow_models:__subpackages__", + ], + licenses = ["notice"], # Apache 2.0 +) + +# ----------------------------------------------------------------------------- +# Libraries with GPU facilities that are useful for writing kernels. +cc_library( + name = "gpu_lib", + srcs = [ + "gpu_event_mgr.cc", + ], + hdrs = [ + "gpu_event_mgr.h", + ], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:stream_executor", + ], +) + +cc_library( + name = "gpu_headers_lib", + hdrs = [ + "gpu_event_mgr.h", + ], +) + +cc_library( + name = "cuda", + deps = [ + "//tensorflow/core/platform/default/build_config:cuda", + ], +) + +cc_library( + name = "rocm", + deps = [ + "//tensorflow/core/platform/default/build_config:rocm", + ], +) + +cc_library( + name = "gpu_id", + hdrs = [ + "gpu_id.h", + "gpu_id_manager.h", + ], + deps = [ + "//tensorflow/core:lib", + ] + if_static([ + ":gpu_id_impl", + ]), +) + +cc_library( + name = "gpu_id_impl", + srcs = ["gpu_id_manager.cc"], + hdrs = [ + "gpu_id.h", + "gpu_id_manager.h", + ], + deps = [ + "//tensorflow/core:lib", + ], +) + +filegroup( + name = "gpu_runtime_headers", + srcs = [ + "gpu_bfc_allocator.h", + "gpu_cudamalloc_allocator.h", + "gpu_debug_allocator.h", + "gpu_device.h", + "gpu_event_mgr.h", + "gpu_host_allocator.h", + "gpu_id.h", + "gpu_id_manager.h", + "gpu_id_utils.h", + "gpu_init.h", + "gpu_managed_allocator.h", + "gpu_mem_allocator.h", + "gpu_process_state.h", + "gpu_stream_util.h", + "gpu_util.h", + "//tensorflow/core/common_runtime:gpu_runtime_headers", + ], + visibility = ["//visibility:private"], +) + +tf_cuda_library( + name = "gpu_runtime_impl", + srcs = [ + "gpu_cudamalloc_allocator.cc", + "gpu_debug_allocator.cc", + "gpu_device.cc", + "gpu_device_factory.cc", + "gpu_managed_allocator.cc", + "gpu_process_state.cc", + "gpu_stream_util.cc", + "gpu_util.cc", + "gpu_util_platform_specific.cc", + ], + hdrs = [":gpu_runtime_headers"], + copts = tf_copts(), + cuda_deps = [ + "@local_config_cuda//cuda:cudnn_header", + ], + deps = [ + ":gpu_bfc_allocator", + ":gpu_id_impl", + ":gpu_init_impl", + ":gpu_lib", + "//tensorflow/core:core_cpu_impl", + "//tensorflow/core:core_cpu_lib", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:stream_executor", + "//tensorflow/core/profiler/lib:annotated_traceme", + "//tensorflow/core/profiler/lib:scoped_annotation", + "//third_party/eigen3", + ], + alwayslink = 1, +) + +tf_cuda_library( + name = "gpu_runtime", + hdrs = [":gpu_runtime_headers"], + linkstatic = 1, + deps = [ + "//tensorflow/core:core_cpu_lib", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:stream_executor", + "//third_party/eigen3", + ] + if_static([":gpu_runtime_impl"]), +) + +# This is redundant with the "gpu_runtime_*" targets above. It's useful for +# applications that want to depend on a minimal subset of TensorFlow (e.g. XLA). +tf_cuda_library( + name = "gpu_bfc_allocator", + srcs = [ + "gpu_bfc_allocator.cc", + ], + hdrs = ["gpu_bfc_allocator.h"], + features = ["parse_headers"], + visibility = ["//visibility:public"], + deps = [ + ":gpu_mem_allocator", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/common_runtime:bfc_allocator", + ], +) + +tf_cuda_library( + name = "gpu_mem_allocator", + srcs = [ + "gpu_id.h", + ], + hdrs = [ + "gpu_host_allocator.h", + "gpu_mem_allocator.h", + ], + features = ["parse_headers"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:stream_executor", + "//tensorflow/core/framework:allocator", + ], +) + +tf_cuda_library( + name = "gpu_init", + hdrs = [ + "gpu_init.h", + ], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:stream_executor", + ] + if_static( + [":gpu_init_impl"], + ), +) + +tf_cuda_library( + name = "gpu_init_impl", + srcs = [ + "gpu_init.cc", + ], + hdrs = [ + "gpu_init.h", + ], + copts = tf_copts(), + linkstatic = 1, + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:stream_executor", + ], + alwayslink = 1, +) + +# ----------------------------------------------------------------------------- +# Tests + +tf_cc_test( + name = "gpu_device_on_non_gpu_machine_test", + size = "small", + srcs = ["gpu_device_on_non_gpu_machine_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":gpu_headers_lib", + ":gpu_id", + ":gpu_runtime", + "//tensorflow/core:test", + ], +) + +tf_cc_tests_gpu( + name = "gpu_related_tests", + size = "small", + srcs = [ + "gpu_bfc_allocator_test.cc", + "gpu_device_test.cc", + "gpu_id_manager_test.cc", + "pool_allocator_test.cc", + ], + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags(), + deps = [ + ":gpu_id", + ":gpu_runtime", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/common_runtime:core_cpu", + "//tensorflow/core/common_runtime:core_cpu_internal", + "//tensorflow/core/common_runtime:direct_session_internal", + "//tensorflow/core/kernels:ops_util", + ], +) + +tf_cc_test_gpu( + name = "gpu_event_mgr_test", + srcs = ["gpu_event_mgr_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags(), + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:cwise_op", + ], +) + +tf_cuda_cc_test( + name = "gpu_device_unified_memory_test", + size = "small", + srcs = [ + "gpu_device_test.cc", + ], + linkstatic = tf_kernel_tests_linkstatic(), + # Runs test on a Guitar cluster that uses P100s to test unified memory + # allocations. + tags = tf_cuda_tests_tags() + [ + "guitar", + "multi_gpu", + ], + deps = [ + ":gpu_id", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/common_runtime:core_cpu", + "//tensorflow/core/common_runtime:core_cpu_internal", + "//tensorflow/core/common_runtime:direct_session_internal", + "//tensorflow/core/kernels:ops_util", + ], +) + +tf_cc_test_gpu( + name = "gpu_allocator_retry_test", + size = "medium", + srcs = ["gpu_allocator_retry_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags(), + deps = [ + ":gpu_runtime", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/common_runtime:core_cpu", + "//tensorflow/core/common_runtime:core_cpu_internal", + "//tensorflow/core/common_runtime:direct_session_internal", + ], +) + +tf_cc_test_gpu( + name = "gpu_debug_allocator_test", + size = "medium", + srcs = ["gpu_debug_allocator_test.cc"], + args = ["--gtest_death_test_style=threadsafe"], + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags(), + deps = [ + ":gpu_id", + ":gpu_runtime", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/common_runtime:core_cpu", + "//tensorflow/core/common_runtime:core_cpu_internal", + "//tensorflow/core/common_runtime:direct_session_internal", + "//tensorflow/core/kernels:ops_util", + ], +) + +tf_cc_test_gpu( + name = "gpu_stream_util_test", + size = "small", + srcs = ["gpu_stream_util_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags() + ["nomac"], + deps = [ + ":gpu_runtime", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:sendrecv_ops", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/common_runtime:core_cpu", + "//tensorflow/core/common_runtime:core_cpu_internal", + "//tensorflow/core/common_runtime:direct_session_internal", + "//tensorflow/core/kernels:matmul_op", + "//tensorflow/core/kernels:ops_util", + ], +) diff --git a/tensorflow/core/common_runtime/graph_view.h b/tensorflow/core/common_runtime/graph_view.h index b0bc0f4b6de..6d31555ed9a 100644 --- a/tensorflow/core/common_runtime/graph_view.h +++ b/tensorflow/core/common_runtime/graph_view.h @@ -211,6 +211,8 @@ class GraphView { Status SetAllocAttrs(const Graph* g, const Device* device); void SetScopedAllocatorAttrs(const std::vector& sa_nodes); + // Returns a mutable pointer to the `NodeItem` with the given `id` if it + // exists in the graph, or `nullptr` if it does not. NodeItem* node(int32 id) const { DCHECK_GE(id, 0); DCHECK_LT(id, num_nodes_); @@ -220,6 +222,17 @@ class GraphView { : reinterpret_cast(space_ + node_offsets_[id])); } + // Returns the `NodeItem` with the given `id`. + // + // REQUIRES: `id` must be the ID of a valid node in the graph. + const NodeItem& node_ref(int32 id) const { + DCHECK_GE(id, 0); + DCHECK_LT(id, num_nodes_); + uint32 offset = node_offsets_[id]; + DCHECK_NE(offset, kuint32max); + return *reinterpret_cast(space_ + node_offsets_[id]); + } + int32 num_nodes() const { return num_nodes_; } private: diff --git a/tensorflow/core/common_runtime/immutable_executor_state.cc b/tensorflow/core/common_runtime/immutable_executor_state.cc index 97c17aa287d..2f6d985b9cc 100644 --- a/tensorflow/core/common_runtime/immutable_executor_state.cc +++ b/tensorflow/core/common_runtime/immutable_executor_state.cc @@ -17,6 +17,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/core/common_runtime/metrics.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/graph/edgeset.h" #include "tensorflow/core/graph/graph.h" @@ -88,13 +89,33 @@ Status ImmutableExecutorState::Initialize(const Graph& graph) { EnsureFrameInfo(it)->nodes = absl::make_unique>(); } + root_frame_info_ = frame_info_[""]; pending_ids_.resize(gview_.num_nodes()); // Preprocess every node in the graph to create an instance of op // kernel for each node. + requires_control_flow_ = false; for (const Node* n : graph.nodes()) { if (IsSink(n)) continue; + if (IsSwitch(n) || IsMerge(n) || IsEnter(n) || IsExit(n)) { + requires_control_flow_ = true; + } else if (IsRecv(n)) { + // A Recv node from a different device may produce dead tensors from + // non-local control-flow nodes. + // + // TODO(mrry): Track whether control flow was present in the + // pre-partitioned graph, and enable the caller (e.g. + // `DirectSession`) to relax this constraint. + string send_device; + string recv_device; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "send_device", &send_device)); + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "recv_device", &recv_device)); + if (send_device != recv_device) { + requires_control_flow_ = true; + } + } + const int id = n->id(); const string& frame_name = cf_info.frame_names[id]; FrameInfo* frame_info = EnsureFrameInfo(frame_name); @@ -302,10 +323,17 @@ void ImmutableExecutorState::InitializePending(const Graph* graph, const ControlFlowInfo& cf_info) { for (auto& it : cf_info.unique_frame_names) { FrameInfo* finfo = EnsureFrameInfo(it); - DCHECK_EQ(finfo->pending_counts, nullptr); + DCHECK_EQ(finfo->pending_counts.get(), nullptr); finfo->pending_counts = absl::make_unique(finfo->pending_counts_layout); } + + if (!requires_control_flow_) { + atomic_pending_counts_.reset(new std::atomic[gview_.num_nodes()]); + std::fill(atomic_pending_counts_.get(), + atomic_pending_counts_.get() + gview_.num_nodes(), 0); + } + for (const Node* n : graph->nodes()) { if (IsSink(n)) continue; const int id = n->id(); @@ -314,6 +342,9 @@ void ImmutableExecutorState::InitializePending(const Graph* graph, GetMaxPendingCounts(n, &max_pending, &max_dead); auto& counts = EnsureFrameInfo(name)->pending_counts; counts->set_initial_count(pending_ids_[id], max_pending); + if (!requires_control_flow_) { + atomic_pending_counts_[id] = max_pending; + } } } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/immutable_executor_state.h b/tensorflow/core/common_runtime/immutable_executor_state.h index c9c23e55a21..9a2987cfaae 100644 --- a/tensorflow/core/common_runtime/immutable_executor_state.h +++ b/tensorflow/core/common_runtime/immutable_executor_state.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_IMMUTABLE_EXECUTOR_STATE_H_ #define TENSORFLOW_CORE_COMMON_RUNTIME_IMMUTABLE_EXECUTOR_STATE_H_ +#include #include #include #include @@ -91,6 +92,25 @@ class ImmutableExecutorState { } } + const FrameInfo& get_root_frame_info() const { return *root_frame_info_; } + + bool requires_control_flow_support() const { return requires_control_flow_; } + + // Copies the pending counts for nodes in this graph to the given array. + // + // This method provides a more efficient way of initializing + // `SimplePropagatorState` than individually accessing the pending counts from + // `get_root_frame_info().counts`. + // + // REQUIRES: `!requires_control_flow_support && len(dest) == + // graph_view().num_nodes()`. + void copy_pending_counts(std::atomic* dest) const { + DCHECK(!requires_control_flow_); + memcpy(dest, atomic_pending_counts_.get(), + graph_view().num_nodes() * sizeof(std::atomic)); + std::atomic_thread_fence(std::memory_order_release); + } + private: struct ControlFlowInfo { gtl::FlatSet unique_frame_names; @@ -106,6 +126,7 @@ class ImmutableExecutorState { // Owned. LocalExecutorParams params_; GraphView gview_; + bool requires_control_flow_; std::vector pending_ids_; // Root nodes (with no in edges) that should form the initial ready queue @@ -115,6 +136,11 @@ class ImmutableExecutorState { // TODO(yuanbyu): We could cache it along with the graph so to avoid // the overhead of constructing it for each executor instance. gtl::FlatMap frame_info_; + const FrameInfo* root_frame_info_; // Not owned. + + // If `requires_control_flow_` is false, this points to an array of initial + // pending counts for the nodes in the graph, indexed by node ID. + std::unique_ptr[]> atomic_pending_counts_; // Shallow copies of the constant tensors used in the graph. std::vector const_tensors_; diff --git a/tensorflow/core/common_runtime/propagator_debug_utils.cc b/tensorflow/core/common_runtime/propagator_debug_utils.cc new file mode 100644 index 00000000000..27f9da7ea52 --- /dev/null +++ b/tensorflow/core/common_runtime/propagator_debug_utils.cc @@ -0,0 +1,95 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/common_runtime/propagator_debug_utils.h" + +#include + +#include "tensorflow/core/common_runtime/entry.h" +#include "tensorflow/core/common_runtime/immutable_executor_state.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +// 1-D, 0 element tensor. +static const Tensor* const kEmptyTensor = new Tensor; + + +const Tensor* GetTensorValueForDump(const Entry& input) { + switch (input.state) { + case Entry::State::NO_VALUE: + return kEmptyTensor; + case Entry::State::HAS_VALUE: + return input.val.get(); + case Entry::State::HAS_CONST_TENSOR: + return input.const_tensor; + case Entry::State::HAS_REF_TENSOR: + return input.ref_tensor.tensor; + } +} + +void DumpPendingNodeState(const ImmutableExecutorState& immutable_state, + const int node_id, const Entry* input_vector, + const bool show_nodes_with_no_ready_inputs) { + const NodeItem& node_item = immutable_state.graph_view().node_ref(node_id); + const int input_base = node_item.input_start; + if (!show_nodes_with_no_ready_inputs) { + bool has_ready_input = false; + for (int i = 0; i < node_item.num_inputs; ++i) { + const Entry& input = input_vector[input_base + i]; + const Tensor* tensor = GetTensorValueForDump(input); + if (tensor && tensor->IsInitialized()) { + has_ready_input = true; + break; + } + } + if (!has_ready_input) { + return; + } + } + LOG(WARNING) << " Pending Node: " << node_item.DebugString(); + for (int i = 0; i < node_item.num_inputs; ++i) { + const Entry& input = input_vector[input_base + i]; + const Tensor* tensor = GetTensorValueForDump(input); + if (tensor->IsInitialized()) { + LOG(WARNING) << " Input " << i << ": " + << strings::StrCat( + "Tensordtype()), + " shape: ", tensor->shape().DebugString(), ">"); + } else { + LOG(WARNING) << " Input " << i << ": not present"; + } + } +} + +void DumpActiveNodeState(const ImmutableExecutorState& immutable_state, + const int node_id, const Entry* input_vector) { + const NodeItem& node_item = immutable_state.graph_view().node_ref(node_id); + LOG(WARNING) << " Active Node: " << node_item.DebugString(); + const int input_base = node_item.input_start; + for (int i = 0; i < node_item.num_inputs; ++i) { + const Entry& input = input_vector[input_base + i]; + const Tensor* tensor = GetTensorValueForDump(input); + if (tensor->IsInitialized()) { + LOG(WARNING) << " Input " << i << ": " + << strings::StrCat( + "Tensordtype()), + " shape: ", tensor->shape().DebugString(), ">"); + } else { + LOG(WARNING) << " Input " << i << ": not present"; + } + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/propagator_debug_utils.h b/tensorflow/core/common_runtime/propagator_debug_utils.h new file mode 100644 index 00000000000..8f1204998ff --- /dev/null +++ b/tensorflow/core/common_runtime/propagator_debug_utils.h @@ -0,0 +1,40 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_DEBUG_UTILS_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_DEBUG_UTILS_H_ + +namespace tensorflow { + +struct Entry; +class ImmutableExecutorState; +class Tensor; + +// Returns a pointer to the tensor in `input` if one exists, or `nullptr`. +const Tensor* GetTensorValueForDump(const Entry& input); + +// Writes a LOG(WARNING) message describing the state of the pending node +// `node_id` in the graph described by `immutable_state`. +void DumpPendingNodeState(const ImmutableExecutorState& immutable_state, + const int node_id, const Entry* input_vector, + const bool show_nodes_with_no_ready_inputs); + +// Writes a LOG(WARNING) message describing the state of the active node +// `node_id` in the graph described by `immutable_state`. +void DumpActiveNodeState(const ImmutableExecutorState& immutable_state, + const int node_id, const Entry* input_vector); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_DEBUG_UTILS_H_ diff --git a/tensorflow/core/common_runtime/propagator_state.cc b/tensorflow/core/common_runtime/propagator_state.cc index e2827a8eb1f..a4e311cbc6b 100644 --- a/tensorflow/core/common_runtime/propagator_state.cc +++ b/tensorflow/core/common_runtime/propagator_state.cc @@ -16,17 +16,12 @@ limitations under the License. #include "tensorflow/core/common_runtime/propagator_state.h" #include "tensorflow/core/common_runtime/graph_view.h" +#include "tensorflow/core/common_runtime/propagator_debug_utils.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/profiler/lib/traceme.h" namespace tensorflow { -// 1-D, 0 element tensor. -static const Tensor* const kEmptyTensor = new Tensor; - -typedef gtl::InlinedVector TensorValueVec; -typedef gtl::InlinedVector AllocatorAttributeVec; - PropagatorState::PropagatorState(const ImmutableExecutorState& immutable_state, int64 step_id) : immutable_state_(immutable_state), @@ -57,7 +52,7 @@ void PropagatorState::ActivateRoots(gtl::ArraySlice roots, TaggedNodeSeq* ready) { for (const NodeItem* item : roots) { DCHECK_EQ(item->num_inputs, 0); - ready->push_back(TaggedNode{item, root_frame_, 0, false}); + ready->emplace_back(item, root_frame_, 0, false); } mutex_lock l(root_frame_->mu); root_frame_->GetIteration(0)->outstanding_ops = ready->size(); @@ -173,72 +168,6 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node, } } -const Tensor* PropagatorState::GetTensorValueForDump(const Entry& input) { - switch (input.state) { - case Entry::State::NO_VALUE: - return kEmptyTensor; - case Entry::State::HAS_VALUE: - return input.val.get(); - case Entry::State::HAS_CONST_TENSOR: - return input.const_tensor; - case Entry::State::HAS_REF_TENSOR: - return input.ref_tensor.tensor; - } -} - -void PropagatorState::DumpPendingNodeState( - const int node_id, const Entry* input_vector, - const bool show_nodes_with_no_ready_inputs) { - const NodeItem& node_item = *immutable_state_.graph_view().node(node_id); - const int input_base = node_item.input_start; - if (!show_nodes_with_no_ready_inputs) { - bool has_ready_input = false; - for (int i = 0; i < node_item.num_inputs; ++i) { - const Entry& input = input_vector[input_base + i]; - const Tensor* tensor = GetTensorValueForDump(input); - if (tensor->IsInitialized()) { - has_ready_input = true; - break; - } - } - if (!has_ready_input) { - return; - } - } - LOG(WARNING) << " Pending Node: " << node_item.DebugString(); - for (int i = 0; i < node_item.num_inputs; ++i) { - const Entry& input = input_vector[input_base + i]; - const Tensor* tensor = GetTensorValueForDump(input); - if (tensor->IsInitialized()) { - LOG(WARNING) << " Input " << i << ": " - << strings::StrCat( - "Tensordtype()), - " shape: ", tensor->shape().DebugString(), ">"); - } else { - LOG(WARNING) << " Input " << i << ": not present"; - } - } -} - -void PropagatorState::DumpActiveNodeState(const int node_id, - const Entry* input_vector) { - const NodeItem& node_item = *immutable_state_.graph_view().node(node_id); - LOG(WARNING) << " Active Node: " << node_item.DebugString(); - const int input_base = node_item.input_start; - for (int i = 0; i < node_item.num_inputs; ++i) { - const Entry& input = input_vector[input_base + i]; - const Tensor* tensor = GetTensorValueForDump(input); - if (tensor->IsInitialized()) { - LOG(WARNING) << " Input " << i << ": " - << strings::StrCat( - "Tensordtype()), - " shape: ", tensor->shape().DebugString(), ">"); - } else { - LOG(WARNING) << " Input " << i << ": not present"; - } - } -} - void PropagatorState::DumpIterationState(const FrameState* frame, IterationState* iteration) { const std::vector* nodes = frame->nodes; @@ -248,7 +177,8 @@ void PropagatorState::DumpIterationState(const FrameState* frame, immutable_state_.pending_ids()[node->node_id]; if (iteration->node_state(pending_id) == PendingCounts::PENDING_NOTREADY || iteration->node_state(pending_id) == PendingCounts::PENDING_READY) { - DumpPendingNodeState(node->node_id, iteration->input_tensors, false); + DumpPendingNodeState(immutable_state_, node->node_id, + iteration->input_tensors, false); } } // Then the active nodes. @@ -256,7 +186,8 @@ void PropagatorState::DumpIterationState(const FrameState* frame, PendingCounts::Handle pending_id = immutable_state_.pending_ids()[node->node_id]; if (iteration->node_state(pending_id) == PendingCounts::STARTED) { - DumpActiveNodeState(node->node_id, iteration->input_tensors); + DumpActiveNodeState(immutable_state_, node->node_id, + iteration->input_tensors); } } // Show all input tensors in use. @@ -279,14 +210,11 @@ void PropagatorState::DumpIterationState(const FrameState* frame, void PropagatorState::DumpState() { mutex_lock l(mu_); - if (!dumped_on_error_) { - LOG(WARNING) << "Dumping state"; - for (auto& frame : outstanding_frames_) { - LOG(WARNING) << frame.first; - FrameState* frame_state = frame.second; - frame_state->DumpIterationState(this); - } - dumped_on_error_ = true; + LOG(WARNING) << "Dumping state"; + for (auto& frame : outstanding_frames_) { + LOG(WARNING) << frame.first; + FrameState* frame_state = frame.second; + frame_state->DumpIterationState(this); } } @@ -378,7 +306,7 @@ void PropagatorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) { for (const EdgeInfo& e : item->output_edges()) { const NodeItem& dst_item = - *immutable_state_.graph_view().node(e.dst_id); + immutable_state_.graph_view().node_ref(e.dst_id); const auto dst_pending_id = immutable_state_.pending_ids()[e.dst_id]; bool dst_dead = true; @@ -398,7 +326,7 @@ void PropagatorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) { for (const ControlEdgeInfo& e : item->output_control_edges()) { const NodeItem& dst_item = - *immutable_state_.graph_view().node(e.dst_id); + immutable_state_.graph_view().node_ref(e.dst_id); const auto dst_pending_id = immutable_state_.pending_ids()[e.dst_id]; bool dst_dead; @@ -464,17 +392,17 @@ void PropagatorState::FrameState::ActivateNodesFastPath(const NodeItem* item, // // NOTE(mrry): Use a macro here instead of a lambda, because this method is // performance-critical and we need to ensure that the code is inlined. -#define MAYBE_ADD_TO_READY(dst_id, adjust_result) \ - do { \ - if (!adjust_result.any_pending) { \ - const NodeItem* dst_item = gview.node(dst_id); \ - TaggedNode& t = ready->emplace_back(); \ - t.node_item = dst_item; \ - t.input_frame = this; \ - t.input_iter = iter; \ - t.is_dead = adjust_result.any_dead; \ - iter_state->outstanding_ops++; \ - } \ +#define MAYBE_ADD_TO_READY(dst_id, adjust_result) \ + do { \ + if (!adjust_result.any_pending) { \ + const NodeItem* dst_item = &gview.node_ref(dst_id); \ + TaggedNode& t = ready->emplace_back(); \ + t.node_item = dst_item; \ + t.input_frame = this; \ + t.input_iter = iter; \ + t.is_dead = adjust_result.any_dead; \ + iter_state->outstanding_ops++; \ + } \ } while (0); Entry* input_tensors = iter_state->input_tensors; @@ -534,7 +462,7 @@ void PropagatorState::FrameState::ActivateNodesSlowPath(const NodeItem* item, for (const EdgeInfo& e : item->output_edges()) { const int dst_id = e.dst_id; - const NodeItem* dst_item = gview.node(dst_id); + const NodeItem* dst_item = &gview.node_ref(dst_id); const PendingCounts::Handle dst_pending_id = immutable_state.pending_ids()[dst_id]; const int src_slot = e.output_slot; @@ -596,7 +524,7 @@ void PropagatorState::FrameState::ActivateNodesSlowPath(const NodeItem* item, for (const ControlEdgeInfo& e : item->output_control_edges()) { const int dst_id = e.dst_id; - const NodeItem* dst_item = gview.node(dst_id); + const NodeItem* dst_item = &gview.node_ref(dst_id); const PendingCounts::Handle dst_pending_id = immutable_state.pending_ids()[dst_id]; diff --git a/tensorflow/core/common_runtime/propagator_state.h b/tensorflow/core/common_runtime/propagator_state.h index d82d3bf7261..6d5abd02afa 100644 --- a/tensorflow/core/common_runtime/propagator_state.h +++ b/tensorflow/core/common_runtime/propagator_state.h @@ -74,6 +74,7 @@ class PropagatorState { const NodeItem& get_node_item() const { return *node_item; } bool get_is_dead() const { return is_dead; } + int64 get_iter_num() const { return input_iter; } }; // A drop-in replacement for std::deque. We typically don't @@ -428,26 +429,15 @@ class PropagatorState { void CleanupFramesIterations(FrameState* frame, int64 iter, TaggedNodeSeq* ready); - // Provide debugging output about an outstanding node in the executor. - void DumpPendingNodeState(const int node_id, const Entry* input_vector, - bool show_nodes_with_no_ready_inputs); - void DumpActiveNodeState(const int node_id, const Entry* input_vector); - // Provide debugging output about an outstanding iteration in the executor. void DumpIterationState(const FrameState* frame, IterationState* iteration); - const Tensor* GetTensorValueForDump(const Entry& input); - const ImmutableExecutorState& immutable_state_; const int64 step_id_; const bool vlog_; mutex mu_; - // A flag that is set on error after the frame state has been - // dumped for diagnostic purposes. - bool dumped_on_error_ TF_GUARDED_BY(mu_) = false; - // The root frame in which the execution of this step is started. FrameState* root_frame_; diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc index 8d396298e01..d6ab1e30a55 100644 --- a/tensorflow/core/common_runtime/shape_refiner.cc +++ b/tensorflow/core/common_runtime/shape_refiner.cc @@ -85,8 +85,8 @@ Status InferShapesForFunctionSubNode(const Node* node, ShapeRefiner* refiner, // TODO(b/134547156): TEMPORARY WORKAROUND. If input shape handle is not set // in outer context, set _Arg node output shape to unknown. if (outer_context->input(index).SameHandle(ShapeHandle())) { - LOG(WARNING) << "Function instantiation has undefined input shape at " - << "index: " << index << " in the outer inference context."; + VLOG(1) << "Function instantiation has undefined input shape at " + << "index: " << index << " in the outer inference context."; node_context->set_output(0, node_context->UnknownShape()); } else { node_context->set_output(0, outer_context->input(index)); diff --git a/tensorflow/core/common_runtime/simple_propagator_state.cc b/tensorflow/core/common_runtime/simple_propagator_state.cc new file mode 100644 index 00000000000..bf6172bf3cf --- /dev/null +++ b/tensorflow/core/common_runtime/simple_propagator_state.cc @@ -0,0 +1,138 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/common_runtime/simple_propagator_state.h" + +#include + +#include "tensorflow/core/common_runtime/propagator_debug_utils.h" +#include "tensorflow/core/profiler/lib/traceme.h" + +namespace tensorflow { + +SimplePropagatorState::SimplePropagatorState( + const ImmutableExecutorState& immutable_state, int64 step_id) + : SimplePropagatorState(immutable_state, step_id, + immutable_state.get_root_frame_info()) {} + +SimplePropagatorState::SimplePropagatorState( + const ImmutableExecutorState& immutable_state, int64 step_id, + const ImmutableExecutorState::FrameInfo& finfo) + : immutable_state_(immutable_state), + step_id_(step_id), + vlog_(VLOG_IS_ON(1)), + input_tensors_(finfo.total_inputs), + pending_( + new std::atomic[immutable_state.graph_view().num_nodes()]), + active_(vlog_ ? new std::vector( + immutable_state.graph_view().num_nodes()) + : nullptr), + nodes_(finfo.nodes.get()) { + immutable_state_.copy_pending_counts(pending_.get()); +} + +SimplePropagatorState::~SimplePropagatorState() {} + +void SimplePropagatorState::ActivateRoots( + gtl::ArraySlice roots, TaggedNodeSeq* ready) { + for (const NodeItem* item : roots) { + DCHECK_EQ(item->num_inputs, 0); + ready->push_back(TaggedNode{item}); + } +} + +void SimplePropagatorState::PropagateOutputs(const TaggedNode& tagged_node, + EntryVector* outputs, + TaggedNodeSeq* ready) { + profiler::TraceMe activity( + [&]() { + return strings::StrCat( + "ExecutorPropagateOutputs#", "id=", step_id_, + ",kernel_name=", tagged_node.node_item->kernel->name_view(), + ",num_output_edges=", tagged_node.node_item->num_output_edges, + ",num_output_control_edges=", + tagged_node.node_item->num_output_control_edges, "#"); + }, + profiler::GetTFTraceMeLevel(/*is_expensive=*/false)); + + // Propagates outputs along out edges, and puts newly ready nodes + // into the ready queue. + DCHECK(ready->empty()); + + const GraphView& gview = immutable_state_.graph_view(); + const NodeItem* item = tagged_node.node_item; + + for (const EdgeInfo& e : item->output_edges()) { + const int dst_id = e.dst_id; + const int src_slot = e.output_slot; + const int dst_loc = e.input_slot; + + // NOTE(mrry): The write to `input_tensors_[dst_loc]` must happen before + // the pending count update, or else one thread might conclude that the + // count has dropped to zero before another thread finishes updating the + // input. + if (e.is_last) { + input_tensors_[dst_loc] = std::move((*outputs)[src_slot]); + } else { + input_tensors_[dst_loc] = (*outputs)[src_slot]; + } + + int32 previous_num_pending = + pending_[dst_id].fetch_sub(1, std::memory_order_release); + if (previous_num_pending == 1) ready->emplace_back(&gview.node_ref(dst_id)); + } + + for (const ControlEdgeInfo& e : item->output_control_edges()) { + const int dst_id = e.dst_id; + + int32 previous_num_pending = + pending_[dst_id].fetch_sub(1, std::memory_order_release); + if (previous_num_pending == 1) ready->emplace_back(&gview.node_ref(dst_id)); + } +} + +void SimplePropagatorState::DumpState() { + mutex_lock l(mu_); + // Dump any waiting nodes that are holding on to tensors. + for (const NodeItem* node : *nodes_) { + if (pending_[node->node_id]) { + DumpPendingNodeState(immutable_state_, node->node_id, + input_tensors_.data(), false); + } + } + // Then the active nodes. + for (const NodeItem* node : *nodes_) { + if ((*active_)[node->node_id]) { + DumpActiveNodeState(immutable_state_, node->node_id, + input_tensors_.data()); + } + } + // Show all input tensors in use. + size_t total_bytes = 0; + for (size_t i = 0; i < input_tensors_.size(); ++i) { + const Entry& input = input_tensors_[i]; + const Tensor* tensor = GetTensorValueForDump(input); + if (tensor && tensor->IsInitialized()) { + LOG(WARNING) << " Input " << i << ": " + << strings::StrCat( + "Tensordtype()), + " shape: ", tensor->shape().DebugString(), + ", bytes: ", tensor->TotalBytes(), ">"); + total_bytes += tensor->TotalBytes(); + } + } + LOG(WARNING) << " Total bytes " << total_bytes; +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/simple_propagator_state.h b/tensorflow/core/common_runtime/simple_propagator_state.h new file mode 100644 index 00000000000..1aee4c7ff2f --- /dev/null +++ b/tensorflow/core/common_runtime/simple_propagator_state.h @@ -0,0 +1,188 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_PROPAGATOR_STATE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_PROPAGATOR_STATE_H_ + +#include + +#include "tensorflow/core/common_runtime/entry.h" +#include "tensorflow/core/common_runtime/immutable_executor_state.h" +#include "tensorflow/core/common_runtime/pending_counts.h" +#include "tensorflow/core/framework/control_flow.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Represents the ephemeral "edge state" associated with one invocation of +// `Executor::Run()`. +// +// NOTE: `SimplePropagatorState` does not support "v1-style" control flow, +// including "dead tensors", "Switch" and "Merge" nodes, and cycles in the +// graph. Use `PropagatorState` for graphs with those features. +// `SimplePropagatorState` *does* support "v2-style" or "functional" control +// flow. +// +// `SimplePropagatorState` is responsible for propagating values along dataflow +// edges in a TensorFlow graph and determining which nodes are runnable. The +// executor primarily updates `SimplePropagatorState` by calling +// `PropagateOutputs()` after processing a node, and `SimplePropagatorState` +// dispatches `TaggedNode`s by adding them to a `TaggedNodeSeq`. +class SimplePropagatorState { + public: + SimplePropagatorState(const ImmutableExecutorState& immutable_state, + int64 step_id); + ~SimplePropagatorState(); + + // A `TaggedNode` corresponds to a single invocation of a node's kernel, + // and it is created when the kernel becomes runnable. + struct TaggedNode { + const NodeItem* node_item; + + explicit TaggedNode(const NodeItem* node_item) : node_item(node_item) {} + + const NodeItem& get_node_item() const { return *node_item; } + + bool get_is_dead() const { return false; } + int64 get_iter_num() const { return 0; } + }; + + // A drop-in replacement for std::deque. We typically don't + // have that many nodes in the ready queue, so we just use a vector and + // don't free up memory from the queue as we consume nodes. + // TODO(mrry): Extract this and share it with the version in + // `PropagatorState`. The correct constants might be different, since + // sizeof(TaggedNode) is smaller in this version. + class TaggedNodeReadyQueue { + public: + TaggedNodeReadyQueue() : front_index_(0) {} + + void push_back(const TaggedNode& node) { ready_.push_back(node); } + TaggedNode front() const { + DCHECK_LT(front_index_, ready_.size()); + return ready_[front_index_]; + } + void pop_front() { + DCHECK_LT(front_index_, ready_.size()); + front_index_++; + if ((front_index_ == ready_.size()) || (front_index_ > kSpillThreshold)) { + if (front_index_ == ready_.size()) { + ready_.clear(); + } else { + // Lots of unused entries at beginning of vector: move everything + // down to start of vector. + ready_.erase(ready_.begin(), ready_.begin() + front_index_); + } + front_index_ = 0; + } + } + bool empty() const { return ready_.empty(); } + + private: + // TODO(b/152925936): Re-evaluate these constants with current usage + // patterns. + static constexpr int kSpillThreshold = 16384; + gtl::InlinedVector ready_; + int front_index_; + }; + + // TODO(b/152925936): Re-evaluate this constant with current usage patterns. + typedef gtl::InlinedVector TaggedNodeSeq; + + // Creates and adds a `TaggedNode` for each node in `roots` to `*ready`. + void ActivateRoots(gtl::ArraySlice roots, + TaggedNodeSeq* ready); + + // After processing the outputs, propagates the outputs to their dsts. + // Contents of *outputs are left in an indeterminate state after + // returning from this method. + void PropagateOutputs(const TaggedNode& tagged_node, EntryVector* outputs, + TaggedNodeSeq* ready); + + // Returns an array of `Entry` objects corresponding to the inputs of + // `tagged_node`. + Entry* GetInputTensors(const TaggedNode& tagged_node) { +#if defined(THREAD_SANITIZER) || defined(DEBUG) + // NOTE: This read of `pending_[...]` works around a limitation in TSAN. + // To avoid false positive data race reports, we need to perform an atomic + // object access that will establish the happens-before relation between + // the write to input_tensors_ in `PropagateOutputs()` and the read in + // `PrepareInputs()`. + CHECK_EQ(pending_[tagged_node.node_item->node_id], 0); +#endif // defined(THREAD_SANITIZER) || defined(DEBUG) + return input_tensors_.data() + tagged_node.node_item->input_start; + } + + FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const { + return {0, 0}; + } + + // Provide debugging output of the state of the executor. + void DumpState(); + + // For debugging/logging only. + void MaybeMarkStarted(const TaggedNode& tagged_node) { + // TODO(misard) Replace with a finer-grain enabling flag once we add better + // optional debugging support. + if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) { + mutex_lock l(mu_); + (*active_)[tagged_node.node_item->node_id] = true; + } + } + void MaybeMarkCompleted(const TaggedNode& tagged_node) { + // TODO(misard) Replace with a finer-grain enabling flag once we add better + // optional debugging support. + if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) { + mutex_lock l(mu_); + (*active_)[tagged_node.node_item->node_id] = false; + } + } + + private: + SimplePropagatorState(const ImmutableExecutorState& immutable_state_, + int64 step_id, + const ImmutableExecutorState::FrameInfo& finfo); + + const ImmutableExecutorState& immutable_state_; + const int64 step_id_; + const bool vlog_; + + // The i-th node's j-th input is stored at + // `input_tensors[impl_->nodes[i].input_start + j]`. + // + // NOTE: No need to protect input_tensors[i] by any locks because it + // is resized once. Each element of input_tensors is written once by the + // source node of an edge and is cleared by the destination of the same + // edge. The destination node always runs after the source node, so there + // is never concurrent access to the same entry. + std::vector input_tensors_; + + std::unique_ptr[]> pending_; + + // If `vlog_` is true, this stores a bit vector of active nodes, indexed by + // node ID. + mutex mu_; + std::unique_ptr> active_ TF_GUARDED_BY(mu_); + + const std::vector* const nodes_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_PROPAGATOR_STATE_H_ diff --git a/tensorflow/core/common_runtime/sycl/BUILD b/tensorflow/core/common_runtime/sycl/BUILD new file mode 100644 index 00000000000..426903197df --- /dev/null +++ b/tensorflow/core/common_runtime/sycl/BUILD @@ -0,0 +1,46 @@ +load( + "//tensorflow:tensorflow.bzl", + "if_not_windows", + "tf_copts", +) +load( + "//tensorflow/core/platform:rules_cc.bzl", + "cc_library", +) + +package( + default_visibility = [ + "//tensorflow:internal", + ], + features = ["-parse_headers"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "sycl_runtime", + srcs = if_not_windows([ + "sycl_allocator.cc", + "sycl_device.cc", + "sycl_device_context.cc", + "sycl_device_factory.cc", + ]), + hdrs = if_not_windows([ + "sycl_allocator.h", + "sycl_device.h", + "sycl_util.h", + "sycl_device_context.h", + ]), + copts = tf_copts(), + linkstatic = 0, + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core/common_runtime:core_cpu", + "//tensorflow/core/common_runtime:core_cpu_internal", + "//third_party/eigen3", + "@local_config_sycl//sycl", + ], + alwayslink = 0, +) diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD index c0bd88a5e3f..799dd8ea9f6 100644 --- a/tensorflow/core/data/service/BUILD +++ b/tensorflow/core/data/service/BUILD @@ -195,6 +195,18 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "test_cluster", + testonly = True, + srcs = ["test_cluster.cc"], + hdrs = ["test_cluster.h"], + deps = [ + ":server_lib", + "//tensorflow/core/platform:errors", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "test_util", testonly = True, @@ -286,6 +298,7 @@ tf_cc_test( ":master_cc_grpc_proto", ":master_proto_cc", ":server_lib", + ":test_cluster", ":test_util", ":worker_cc_grpc_proto", ":worker_proto_cc", diff --git a/tensorflow/core/data/service/data_service_test.cc b/tensorflow/core/data/service/data_service_test.cc index 0eb3ca55c05..de52595021d 100644 --- a/tensorflow/core/data/service/data_service_test.cc +++ b/tensorflow/core/data/service/data_service_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/data/service/master.grpc.pb.h" #include "tensorflow/core/data/service/master.pb.h" #include "tensorflow/core/data/service/server_lib.h" +#include "tensorflow/core/data/service/test_cluster.h" #include "tensorflow/core/data/service/test_util.h" #include "tensorflow/core/data/service/worker.grpc.pb.h" #include "tensorflow/core/data/service/worker.pb.h" @@ -33,58 +34,6 @@ namespace tensorflow { namespace data { namespace { -const char kProtocol[] = "grpc+local"; - -// Parse the address from a string in the form "://
". -Status AddressFromTarget(const std::string& target, std::string* address) { - std::vector parts = absl::StrSplit(target, "://"); - if (parts.size() != 2) { - return errors::InvalidArgument("target ", target, " split into ", - parts.size(), " parts, not 2"); - } - *address = parts[1]; - return Status::OK(); -} - -class TestCluster { - public: - explicit TestCluster(int num_workers) : num_workers_(num_workers) {} - - Status Initialize() { - TF_RETURN_IF_ERROR(NewMasterServer(/*port=*/0, kProtocol, &master_)); - TF_RETURN_IF_ERROR(master_->Start()); - TF_RETURN_IF_ERROR(AddressFromTarget(master_->Target(), &master_address_)); - workers_.reserve(num_workers_); - worker_addresses_.reserve(num_workers_); - for (int i = 0; i < num_workers_; ++i) { - TF_RETURN_IF_ERROR(AddWorker()); - } - return Status::OK(); - } - - Status AddWorker() { - workers_.emplace_back(); - TF_RETURN_IF_ERROR(NewWorkerServer(/*port=*/0, kProtocol, master_address_, - &workers_.back())); - TF_RETURN_IF_ERROR(workers_.back()->Start()); - worker_addresses_.emplace_back(); - TF_RETURN_IF_ERROR(AddressFromTarget(workers_.back()->Target(), - &worker_addresses_.back())); - return Status::OK(); - } - - std::string MasterAddress() { return master_address_; } - - std::string WorkerAddress(int index) { return worker_addresses_[index]; } - - private: - int num_workers_; - std::unique_ptr master_; - std::string master_address_; - std::vector> workers_; - std::vector worker_addresses_; -}; - Status RegisterDataset(MasterService::Stub* master_stub, const GraphDef& dataset_graph, int64* dataset_id) { grpc_impl::ClientContext ctx; diff --git a/tensorflow/core/data/service/test_cluster.cc b/tensorflow/core/data/service/test_cluster.cc new file mode 100644 index 00000000000..bfa337b3dce --- /dev/null +++ b/tensorflow/core/data/service/test_cluster.cc @@ -0,0 +1,80 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/data/service/test_cluster.h" + +#include "absl/strings/str_split.h" +#include "tensorflow/core/data/service/server_lib.h" +#include "tensorflow/core/platform/errors.h" + +namespace tensorflow { +namespace data { + +namespace { +const char kProtocol[] = "grpc+local"; + +// Parse the address from a string in the form "://
". +Status AddressFromTarget(absl::string_view target, std::string* address) { + std::vector parts = absl::StrSplit(target, "://"); + if (parts.size() != 2) { + return errors::InvalidArgument("target ", target, " split into ", + parts.size(), " parts, not 2"); + } + *address = parts[1]; + return Status::OK(); +} +} // namespace + +TestCluster::TestCluster(int num_workers) : num_workers_(num_workers) {} + +Status TestCluster::Initialize() { + if (initialized_) { + return errors::FailedPrecondition( + "Test cluster has already been initialized."); + } + initialized_ = true; + TF_RETURN_IF_ERROR(NewMasterServer(/*port=*/0, kProtocol, &master_)); + TF_RETURN_IF_ERROR(master_->Start()); + TF_RETURN_IF_ERROR(AddressFromTarget(master_->Target(), &master_address_)); + workers_.reserve(num_workers_); + worker_addresses_.reserve(num_workers_); + for (int i = 0; i < num_workers_; ++i) { + TF_RETURN_IF_ERROR(AddWorker()); + } + return Status::OK(); +} + +Status TestCluster::AddWorker() { + std::unique_ptr worker; + TF_RETURN_IF_ERROR( + NewWorkerServer(/*port=*/0, kProtocol, master_address_, &worker)); + TF_RETURN_IF_ERROR(worker->Start()); + std::string address; + TF_RETURN_IF_ERROR(AddressFromTarget(worker->Target(), &address)); + workers_.push_back(std::move(worker)); + worker_addresses_.push_back(address); + return Status::OK(); +} + +std::string TestCluster::MasterAddress() { return master_address_; } + +std::string TestCluster::WorkerAddress(int index) { + DCHECK_GE(index, 0); + DCHECK_LT(index, num_workers_); + return worker_addresses_[index]; +} + +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/data/service/test_cluster.h b/tensorflow/core/data/service/test_cluster.h new file mode 100644 index 00000000000..6aa75f4b86a --- /dev/null +++ b/tensorflow/core/data/service/test_cluster.h @@ -0,0 +1,54 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DATA_SERVICE_TEST_CLUSTER_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_TEST_CLUSTER_H_ + +#include "tensorflow/core/data/service/server_lib.h" + +namespace tensorflow { +namespace data { + +// Helper class for unit testing a tf.data service cluster. +class TestCluster { + public: + // Creates a new test cluster with a master and `num_workers` workers. + explicit TestCluster(int num_workers); + + // Initializes the test cluster. This must be called before interacting with + // the cluster. Initialize should be called only once. + Status Initialize(); + // Adds a new worker to the cluster. + Status AddWorker(); + // Returns the master address in the form "hostname:port". + std::string MasterAddress(); + // Returns the address of the worker at the specified index, in the form + // "hostname:port". The index must be non-negative and less than the number of + // workers in the cluster. + std::string WorkerAddress(int index); + + private: + bool initialized_ = false; + int num_workers_; + std::unique_ptr master_; + std::string master_address_; + std::vector> workers_; + std::vector worker_addresses_; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_TEST_CLUSTER_H_ diff --git a/tensorflow/core/debug/BUILD b/tensorflow/core/debug/BUILD index d9dfbc16677..ca3118c51e0 100644 --- a/tensorflow/core/debug/BUILD +++ b/tensorflow/core/debug/BUILD @@ -12,6 +12,7 @@ # a watch state. # ":debug_node_key" - Defines a struct used for tracking tensors. +load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency") load( "//tensorflow:tensorflow.bzl", "check_deps", @@ -123,7 +124,6 @@ tf_cuda_library( ":debug_node_key", ":debug_service_proto_cc", ":debugger_event_metadata_proto_cc", - "//tensorflow:grpc++", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:graph", @@ -131,6 +131,7 @@ tf_cuda_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/strings", + tf_grpc_cc_dependency(), ], alwayslink = 1, ) @@ -146,11 +147,11 @@ tf_cuda_library( ":debug_io_utils", ":debug_service_proto_cc", ":debugger_event_metadata_proto_cc", - "//tensorflow:grpc++", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + tf_grpc_cc_dependency(), ], alwayslink = 1, ) diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD index 626fb8fe19e..5dd2db26512 100644 --- a/tensorflow/core/distributed_runtime/BUILD +++ b/tensorflow/core/distributed_runtime/BUILD @@ -4,7 +4,8 @@ # processes. load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_copts", "tf_cuda_library") -load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") # buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency") # buildifier: disable=same-origin-load # For platform specific build config load( @@ -652,7 +653,6 @@ tf_cuda_cc_test( ":master", ":remote_device", ":worker_interface", - "//tensorflow:grpc++", "//tensorflow/core:control_flow_ops_op_lib", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", @@ -676,6 +676,7 @@ tf_cuda_cc_test( "//tensorflow/core/kernels:dense_update_ops", "//tensorflow/core/kernels:identity_op", "//tensorflow/core/kernels:variable_ops", + tf_grpc_cc_dependency(), ], ) @@ -693,7 +694,6 @@ tf_cuda_cc_test( ":master", ":remote_device", ":worker_interface", - "//tensorflow:grpc++", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -709,6 +709,7 @@ tf_cuda_cc_test( "//tensorflow/core/distributed_runtime/rpc:grpc_testlib", "//tensorflow/core/distributed_runtime/rpc:grpc_util", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", + tf_grpc_cc_dependency(), ], ) diff --git a/tensorflow/core/distributed_runtime/eager/BUILD b/tensorflow/core/distributed_runtime/eager/BUILD index 7a299a3620e..2d434934bf2 100644 --- a/tensorflow/core/distributed_runtime/eager/BUILD +++ b/tensorflow/core/distributed_runtime/eager/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency") load( "//tensorflow:tensorflow.bzl", "tf_cc_test", @@ -99,7 +100,6 @@ cc_library( ":cluster_function_library_runtime", ":remote_mgr", ":remote_tensor_handle", - "//tensorflow:grpc++", "//tensorflow/c:c_api_internal", "//tensorflow/c:tf_status_helper", "//tensorflow/core:core_cpu_internal", @@ -125,6 +125,7 @@ cc_library( "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:optional", + tf_grpc_cc_dependency(), ], ) diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc index cf28e2680d8..d87012de104 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc @@ -367,11 +367,6 @@ Status EagerServiceImpl::ExecuteOp(const Operation& operation, { profiler::TraceMe activity("EagerService:RemoteTensorHandleInternal", profiler::TraceMeLevel::kVerbose); - if (!operation.op_inputs().empty() && !operation.inputs().empty()) { - return errors::InvalidArgument( - "Both operation.inputs and operation.op_inputs are specified in the " - "same request."); - } for (const auto& input : operation.op_inputs()) { tensorflow::TensorHandle* handle; if (input.has_remote_handle()) { @@ -393,17 +388,6 @@ Status EagerServiceImpl::ExecuteOp(const Operation& operation, // Unref handle since it has a ref as an input now. handle->Unref(); } - // TODO(b/150963957): Remove this once the migration from operation.inputs - // to operation.op_inputs completes. - for (const auto& remote_handle : operation.inputs()) { - tensorflow::TensorHandle* handle; - TF_RETURN_IF_ERROR( - eager_context->RemoteMgr()->DeserializeRemoteTensorHandle( - remote_handle, &handle)); - op->AddInput(handle); - // Unref handle since it has a ref as an input now. - handle->Unref(); - } } for (const auto& attr : operation.attrs()) { diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD index 6aff1e85465..96e1a63e5a6 100644 --- a/tensorflow/core/distributed_runtime/rpc/BUILD +++ b/tensorflow/core/distributed_runtime/rpc/BUILD @@ -8,8 +8,10 @@ load( "tf_cc_test", "tf_cuda_library", ) -load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") -load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_tests") +load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") # buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_tests") # buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "tf_grpc_dependency") # buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency") # buildifier: disable=same-origin-load # For platform specific build config load( @@ -42,12 +44,12 @@ cc_library( hdrs = ["grpc_util.h"], linkopts = if_windows(["-DEFAULTLIB:ws2_32.lib"]), deps = [ - "//tensorflow:grpc", - "//tensorflow:grpc++", "//tensorflow/core:lib", # Required to be able to overload TensorResponse parsing. "//tensorflow/core/distributed_runtime:tensor_coding", "//tensorflow/core:lib_internal", + tf_grpc_dependency(), + tf_grpc_cc_dependency(), ], ) @@ -57,8 +59,8 @@ cc_library( hdrs = ["grpc_client_cq_tag.h"], deps = [ ":grpc_util", - "//tensorflow:grpc++", "//tensorflow/core:lib", + tf_grpc_cc_dependency(), ], ) @@ -69,12 +71,12 @@ cc_library( deps = [ ":grpc_client_cq_tag", ":grpc_util", - "//tensorflow:grpc++", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core/distributed_runtime:call_options", "//tensorflow/core/distributed_runtime:tensor_coding", "@com_google_absl//absl/strings:str_format", + tf_grpc_cc_dependency(), ], ) @@ -87,7 +89,6 @@ cc_library( ":grpc_state", ":grpc_util", ":grpc_worker_service_impl", - "//tensorflow:grpc++", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", @@ -96,6 +97,7 @@ cc_library( "//tensorflow/core/distributed_runtime:tensor_coding", "//tensorflow/core/distributed_runtime:worker_cache_logger", "//tensorflow/core/distributed_runtime:worker_interface", + tf_grpc_cc_dependency(), ], ) @@ -105,12 +107,12 @@ cc_library( hdrs = ["grpc_channel.h"], deps = [ ":grpc_util", - "//tensorflow:grpc++", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/strings", + tf_grpc_cc_dependency(), ], ) @@ -119,7 +121,6 @@ cc_library( srcs = ["grpc_tensor_coding.cc"], hdrs = ["grpc_tensor_coding.h"], deps = [ - "//tensorflow:grpc++", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", @@ -127,6 +128,7 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core:worker_proto_cc", "@com_google_absl//absl/flags:flag", + tf_grpc_cc_dependency(), ], ) @@ -135,9 +137,9 @@ cc_library( srcs = [], hdrs = ["grpc_call.h"], deps = [ - "//tensorflow:grpc++", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + tf_grpc_cc_dependency(), ], ) @@ -188,7 +190,6 @@ tf_cuda_library( ":grpc_tensor_coding", ":grpc_util", ":grpc_worker_service_impl", - "//tensorflow:grpc++", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -202,6 +203,7 @@ tf_cuda_library( "//tensorflow/core/distributed_runtime:worker_env", "//tensorflow/core/distributed_runtime:worker_session", "@com_google_absl//absl/container:flat_hash_map", + tf_grpc_cc_dependency(), ], ) @@ -211,9 +213,9 @@ cc_library( hdrs = ["grpc_worker_service_impl.h"], deps = [ ":grpc_util", - "//tensorflow:grpc++", "//tensorflow/core:worker_proto_cc", "//tensorflow/core/distributed_runtime:tensor_coding", + tf_grpc_cc_dependency(), ], ) @@ -244,12 +246,12 @@ cc_library( ":grpc_call", ":grpc_master_service_impl", ":grpc_util", - "//tensorflow:grpc++", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:master_proto_cc", "//tensorflow/core/distributed_runtime:master", "//tensorflow/core/profiler/lib:traceme", + tf_grpc_cc_dependency(), ], alwayslink = 1, ) @@ -259,8 +261,8 @@ cc_library( srcs = ["grpc_master_service_impl.cc"], hdrs = ["grpc_master_service_impl.h"], deps = [ - "//tensorflow:grpc++", "//tensorflow/core:master_proto_cc", + tf_grpc_cc_dependency(), ], ) @@ -293,8 +295,6 @@ cc_library( ":grpc_worker_cache", ":grpc_worker_service", ":rpc_rendezvous_mgr", - "//tensorflow:grpc", - "//tensorflow:grpc++", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -315,6 +315,8 @@ cc_library( "//tensorflow/core/distributed_runtime:worker_cache_wrapper", "//tensorflow/core/distributed_runtime:worker_env", "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_service_impl", + tf_grpc_dependency(), + tf_grpc_cc_dependency(), ], alwayslink = 1, ) @@ -335,7 +337,6 @@ tf_cc_binary( ], deps = [ ":grpc_server_lib", - "//tensorflow:grpc++", "//tensorflow/core:core_cpu", "//tensorflow/core:data_flow_ops_op_lib", "//tensorflow/core:framework_internal", @@ -348,6 +349,7 @@ tf_cc_binary( "//tensorflow/core:sendrecv_ops_op_lib", "//tensorflow/core/distributed_runtime:server_lib", "//tensorflow/core/kernels:data_flow", + tf_grpc_cc_dependency(), ], ) @@ -360,7 +362,6 @@ cc_library( ], deps = [ ":grpc_server_lib", - "//tensorflow:grpc++", "//tensorflow/core:array_ops_op_lib", "//tensorflow/core:bitwise_ops_op_lib", "//tensorflow/core:core_cpu", @@ -378,6 +379,7 @@ cc_library( "//tensorflow/core/kernels:matmul_op", "//tensorflow/core/kernels:reduction_ops", "//tensorflow/core/kernels:variable_ops", + tf_grpc_cc_dependency(), ], alwayslink = 1, ) @@ -473,7 +475,6 @@ tf_cc_test( deps = [ ":grpc_tensor_coding", ":grpc_testlib", - "//tensorflow:grpc++", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -483,6 +484,7 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/core:testlib", "//tensorflow/core:worker_proto_cc", + tf_grpc_cc_dependency(), ], ) @@ -492,11 +494,11 @@ tf_cc_test( srcs = ["grpc_util_test.cc"], deps = [ ":grpc_util", - "//tensorflow:grpc", - "//tensorflow:grpc++", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:worker_proto_cc", + tf_grpc_dependency(), + tf_grpc_cc_dependency(), ], ) diff --git a/tensorflow/core/distributed_runtime/rpc/eager/BUILD b/tensorflow/core/distributed_runtime/rpc/eager/BUILD index 1ac8e683f07..d7251029d10 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/BUILD +++ b/tensorflow/core/distributed_runtime/rpc/eager/BUILD @@ -1,3 +1,5 @@ +load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency") + package( default_visibility = [ "//tensorflow:internal", @@ -12,9 +14,9 @@ cc_library( srcs = ["grpc_eager_service.h"], hdrs = ["grpc_eager_service.h"], deps = [ - "//tensorflow:grpc++", "//tensorflow/core:eager_service_proto_cc", "//tensorflow/stream_executor/platform", + tf_grpc_cc_dependency(), ], ) @@ -24,7 +26,6 @@ cc_library( hdrs = ["grpc_eager_client.h"], deps = [ ":grpc_eager_service", - "//tensorflow:grpc++", "//tensorflow/core:eager_service_proto_cc", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", @@ -33,6 +34,7 @@ cc_library( "//tensorflow/core/distributed_runtime/rpc:grpc_client_cq_tag", "//tensorflow/core/distributed_runtime/rpc:grpc_state", "//tensorflow/core/distributed_runtime/rpc:grpc_util", + tf_grpc_cc_dependency(), ], ) @@ -42,7 +44,6 @@ cc_library( hdrs = ["grpc_eager_service_impl.h"], deps = [ ":grpc_eager_service", - "//tensorflow:grpc++", "//tensorflow/core:eager_service_proto_cc", "//tensorflow/core:framework", "//tensorflow/core:ptr_util", @@ -52,5 +53,6 @@ cc_library( "//tensorflow/core/distributed_runtime/rpc:grpc_channel", "//tensorflow/core/distributed_runtime/rpc:grpc_util", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", + tf_grpc_cc_dependency(), ], ) diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD index 0471d9bbc60..0d3d2a27d73 100644 --- a/tensorflow/core/framework/BUILD +++ b/tensorflow/core/framework/BUILD @@ -76,7 +76,10 @@ exports_files( "tracking_allocator.h", "versions.h", ], - visibility = ["//tensorflow/core:__pkg__"], + visibility = [ + "//tensorflow/core:__pkg__", + "//tensorflow/core/common_runtime:__pkg__", + ], ) # List of exported test source files that do not yet have local build rules. @@ -626,7 +629,7 @@ cc_library( ], visibility = [ "//tensorflow/core:__pkg__", - "//tensorflow/core/tfrt_delegate:__pkg__", + "//tensorflow/core/runtime_fallback:__pkg__", ], deps = [ ":bounds_check", diff --git a/tensorflow/core/grappler/BUILD b/tensorflow/core/grappler/BUILD index 618cc9a990f..48c733c0987 100644 --- a/tensorflow/core/grappler/BUILD +++ b/tensorflow/core/grappler/BUILD @@ -40,6 +40,8 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -69,7 +71,7 @@ tf_cuda_library( srcs = ["devices.cc"], hdrs = ["devices.h"], cuda_deps = [ - "//tensorflow/core:gpu_init", + "//tensorflow/core/common_runtime/gpu:gpu_init", "//tensorflow/core:stream_executor", ], visibility = ["//visibility:public"], diff --git a/tensorflow/core/grappler/clusters/BUILD b/tensorflow/core/grappler/clusters/BUILD index ab1c36010e9..08c0179fc52 100644 --- a/tensorflow/core/grappler/clusters/BUILD +++ b/tensorflow/core/grappler/clusters/BUILD @@ -28,7 +28,7 @@ tf_cuda_library( deps = [ "//third_party/eigen3", "//tensorflow/core:framework", - "//tensorflow/core:gpu_id", + "//tensorflow/core/common_runtime/gpu:gpu_id", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", ] + select({ @@ -44,11 +44,11 @@ tf_cc_test( tags = tf_cuda_tests_tags(), deps = [ ":utils", - "//tensorflow/core:gpu_id", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/common_runtime/gpu:gpu_id", ], ) @@ -128,12 +128,12 @@ cc_library( ":utils", "//tensorflow/cc:coordinator", "//tensorflow/cc:queue_runner", - "//tensorflow/core:core_cpu", - "//tensorflow/core:core_cpu_lib", - "//tensorflow/core:direct_session", "//tensorflow/core:framework", - "//tensorflow/core:gpu_id", "//tensorflow/core:lib", + "//tensorflow/core/common_runtime:core_cpu", + "//tensorflow/core/common_runtime:core_cpu_lib", + "//tensorflow/core/common_runtime:direct_session_internal", + "//tensorflow/core/common_runtime/gpu:gpu_id", "//tensorflow/core/grappler:utils", "//tensorflow/core/kernels:ops_util", ], diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD index 9104cea896d..443b1918a0f 100644 --- a/tensorflow/core/grappler/costs/BUILD +++ b/tensorflow/core/grappler/costs/BUILD @@ -66,6 +66,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":utils", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/types:optional", "//tensorflow/core/grappler/utils:functions", "//tensorflow/core/grappler/utils:topological_sort", @@ -161,11 +162,12 @@ tf_cuda_library( deps = [ ":cost_estimator", "//third_party/eigen3", + "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "//tensorflow/core:framework", "//tensorflow/core:graph", - "//tensorflow/core:gpu_id", + "//tensorflow/core/common_runtime/gpu:gpu_id", "//tensorflow/core:lib", "//tensorflow/core:lib_proto_parsing", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 1f9bae4f852..be987f2d151 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -201,7 +201,7 @@ class DisjointSet { private: Processor processor_; - std::unordered_map, CompareHandle> + absl::flat_hash_map, CompareHandle> nodes_; }; @@ -297,9 +297,9 @@ bool HasAnyUnknownDimensions(const TensorShapeProto& proto) { // This really should be done in an external debugging tool void VerboseLogUnknownDimensionSources( const GraphDef& graph, - const std::unordered_map>& + const absl::flat_hash_map>& input_properties_map, - const std::unordered_map>& + const absl::flat_hash_map>& output_properties_map) { if (!VLOG_IS_ON(2)) { return; @@ -497,9 +497,9 @@ class TopoQueue { } }; - const std::unordered_map TopoOrder( + const absl::flat_hash_map TopoOrder( const std::vector& topo_order) const { - std::unordered_map map; + absl::flat_hash_map map; map.reserve(topo_order.size()); for (int i = 0; i < topo_order.size(); ++i) { map.emplace(topo_order[i], i); @@ -507,7 +507,7 @@ class TopoQueue { return map; } - const std::unordered_map topo_order_; + const absl::flat_hash_map topo_order_; std::set queue_; }; @@ -599,7 +599,7 @@ class SymbolicShapeRefiner { public: explicit SymbolicShapeRefiner( const GraphView& graph, - const std::unordered_map>& fed_ports, + const absl::flat_hash_map>& fed_ports, const bool aggressive_shape_inference) : graph_(graph), function_library_(OpRegistry::Global(), graph.graph()->library()), @@ -1917,20 +1917,20 @@ class SymbolicShapeRefiner { const GraphView& graph_; int graph_def_version_; - std::unordered_map node_to_context_; - std::unordered_map unknown_shapes_; - std::unordered_map unknown_dims_; + absl::flat_hash_map node_to_context_; + absl::flat_hash_map unknown_shapes_; + absl::flat_hash_map unknown_dims_; // Store function instantiations only for valid function. If function // instantiation failed it will have an `absl::nullopt`. - std::unordered_map> + absl::flat_hash_map> fun_to_grappler_function_item_; FunctionLibraryDefinition function_library_; - const std::unordered_map>& fed_ports_; - // Store TensorProtos for tensor value propagation. Note that we use list, not - // vector, as we use pointers to the TensorProtos in this container. Vector - // may resize and copy the objects into a new buffer, then the existing + const absl::flat_hash_map>& fed_ports_; + // Store TensorProtos for tensor value propagation. Note that we use deque, + // not vector, as we use pointers to the TensorProtos in this container. + // Vector may resize and copy the objects into a new buffer, then the existing // pointers become dangling pointers. - std::list const_tensors_to_propagate_; + std::deque const_tensors_to_propagate_; // For more aggressive shape and value inference. bool aggressive_shape_inference_; @@ -2093,7 +2093,7 @@ Status GraphProperties::UpdateEnter(SymbolicShapeRefiner* shape_refiner, Status GraphProperties::UpdateShapes( SymbolicShapeRefiner* shape_refiner, - const std::unordered_map& resource_handles, + const absl::flat_hash_map& resource_handles, const NodeDef* n, bool* new_shapes) const { if (IsEnter(*n)) { // The Enter shape function always forwards an UnknownShape, so do the right @@ -2122,7 +2122,7 @@ Status GraphProperties::UpdateShapes( // Propagates the shapes in the transitive fan-out of . Status GraphProperties::PropagateShapes( SymbolicShapeRefiner* shape_refiner, TopoQueue* new_shapes, - const std::unordered_map& resource_handles, + const absl::flat_hash_map& resource_handles, int num_loops) const { // Limit the number of iterations to prevent infinite loops in the presence of // incorrect shape functions. The algorithm should converge in at most @@ -2221,7 +2221,7 @@ Status GraphProperties::UpdateQueue(const NodeDef* queue_node, Status GraphProperties::UpdateEnqueue( const NodeDef* enqueue_node, - const std::unordered_map& resource_handles, + const absl::flat_hash_map& resource_handles, SymbolicShapeRefiner* shape_refiner, bool* new_shapes) { auto ctx = shape_refiner->GetNodeContext(enqueue_node); if (!ctx) { @@ -2272,7 +2272,7 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds, bool include_output_tensor_values) { FunctionLibraryDefinition function_library(OpRegistry::Global(), item_.graph.library()); - std::unordered_map> fed_ports; + absl::flat_hash_map> fed_ports; if (!assume_valid_feeds) { for (const auto& feed : item_.feed) { SafeTensorId tensor_id = ParseTensorName(feed.first); @@ -2284,13 +2284,13 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds, // List the resources and the nodes using them. Also collect the Merge nodes, // fed nodes, and primary inputs. - std::unordered_map, - std::unordered_set>> + absl::flat_hash_map, + absl::flat_hash_set>> resources; - std::unordered_set merge_nodes; - std::unordered_set fed_nodes; - std::unordered_set primary_inputs; + absl::flat_hash_set merge_nodes; + absl::flat_hash_set fed_nodes; + absl::flat_hash_set primary_inputs; int num_loops = 0; for (const NodeDef& node : item_.graph.node()) { if (IsQueue(node)) { @@ -2327,7 +2327,7 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds, } } - std::unordered_map resource_handles; + absl::flat_hash_map resource_handles; std::vector extra_deps; for (const auto& resource : resources) { for (const NodeDef* src : resource.second.first) { diff --git a/tensorflow/core/grappler/costs/graph_properties.h b/tensorflow/core/grappler/costs/graph_properties.h index 37c41a3dba5..3f7487a064b 100644 --- a/tensorflow/core/grappler/costs/graph_properties.h +++ b/tensorflow/core/grappler/costs/graph_properties.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/costs/op_performance_data.pb.h" @@ -168,7 +169,7 @@ class GraphProperties { // queue, and schedule the reprocessing of the queue if needed. static Status UpdateEnqueue( const NodeDef* enqueue_node, - const std::unordered_map& + const absl::flat_hash_map& resource_handles, SymbolicShapeRefiner* shape_refiner, bool* new_shapes); @@ -187,22 +188,22 @@ class GraphProperties { // Update the shapes for node 'n'. If output shapes for n have changed, // enqueue its fanout in 'new_shapes'. Status UpdateShapes(SymbolicShapeRefiner* shape_refiner, - const std::unordered_map& + const absl::flat_hash_map& resource_handles, const NodeDef* n, bool* new_shapes) const; // Propagate the shapes for the nodes enqueued in new_shapes and their // transitive fanout until a fixed point is reached. Status PropagateShapes( SymbolicShapeRefiner* shape_refiner, TopoQueue* new_shapes, - const std::unordered_map& + const absl::flat_hash_map& resource_handles, int num_loops) const; // Data members const GrapplerItem& item_; - std::unordered_map> + absl::flat_hash_map> input_properties_; - std::unordered_map> + absl::flat_hash_map> output_properties_; const std::vector missing_properties_; diff --git a/tensorflow/core/grappler/graph_analyzer/BUILD b/tensorflow/core/grappler/graph_analyzer/BUILD index 1377424a6fb..da27c10bb18 100644 --- a/tensorflow/core/grappler/graph_analyzer/BUILD +++ b/tensorflow/core/grappler/graph_analyzer/BUILD @@ -62,7 +62,6 @@ cc_library( deps = [ ":graph_analyzer_lib", "//tensorflow/core:framework", - "//tensorflow/core:tensorflow", "//tensorflow/core/grappler:op_types", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 56b7754355c..ae854ad85f1 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -164,14 +164,11 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", - "//tensorflow/core/grappler:graph_topology_view", "//tensorflow/core/grappler:graph_view", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/utils:functions", - "//tensorflow/core/grappler/utils:topological_sort", - "//tensorflow/core/grappler/utils:traversal", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -896,7 +893,6 @@ tf_cc_test_mkl( srcs = ["mkl_remapper_test.cc"], tags = [ "no_mac", - "no_oss", ], deps = [ ":remapper", @@ -1107,9 +1103,7 @@ cc_library( "//tensorflow/core/grappler:graph_view", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:op_types", - "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/costs:graph_properties", - "//tensorflow/core/grappler/utils:frame", "//tensorflow/core/grappler/utils:symbolic_shapes", "//tensorflow/core/grappler/utils:topological_sort", "//tensorflow/core/grappler/utils:tpu", diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index fc4bc0357bd..cec6d7cce7f 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -1764,7 +1764,7 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { // Update consumers of node to take new_input as input instead. void UpdateConsumers(NodeDef* node, const string& new_input) { const string& node_name = node->name(); - const std::set consumers = ctx().node_map->GetOutputs(node_name); + const auto consumers = ctx().node_map->GetOutputs(node_name); for (NodeDef* consumer : consumers) { for (int i = 0; i < consumer->input_size(); ++i) { if (consumer->input(i) == node_name && @@ -2910,7 +2910,7 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage { void UpdateConsumers(NodeDef* node, const string& new_input) { const string& node_name = node->name(); - const std::set consumers = ctx().node_map->GetOutputs(node_name); + const auto consumers = ctx().node_map->GetOutputs(node_name); for (NodeDef* consumer : consumers) { for (int i = 0; i < consumer->input_size(); ++i) { if (consumer->input(i) == node_name && @@ -3561,12 +3561,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { // consumers of `node` are already redirected to `simplified_tensor`. // Re-push the consumers into `nodes_to_simplify` for further // optimizations. - const std::set outputs = node_map_->GetOutputs(node->name()); - std::vector consumers(outputs.begin(), outputs.end()); - std::sort(consumers.begin(), consumers.end(), - [](const NodeDef* n1, const NodeDef* n2) { - return n1->name() < n2->name(); - }); + const std::vector consumers = + node_map_->GetOutputsOrderedByNodeName(node->name()); for (NodeDef* consumer : consumers) { // Update `consumer`'s use of `node` to `input`'s operand. for (int i = 0; i < consumer->input_size(); ++i) { diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc b/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc index fe213d8aafb..55f83eb7a76 100644 --- a/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc +++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc @@ -863,11 +863,11 @@ Status ValidateLists(const gtl::FlatSet& white_list, std::vector> lists{white_list, black_list, gray_list, clear_list}; std::multiset counts; - for (auto list : lists) { + for (const auto& list : lists) { counts.insert(list.begin(), list.end()); } bool duplicates = false; - for (auto s : counts) { + for (const auto& s : counts) { if (counts.count(s) > 1) { duplicates = true; LOG(ERROR) << "Op present in multiple lists: " << s; @@ -1054,20 +1054,20 @@ Status AutoMixedPrecisionImpl::PrintDebugLogs(bool preop, size_t timestamp) { strings::StrCat("paintbuckets", suffix, ".txt")); f.open(fname.c_str(), std::fstream::out); f << "WhiteList:\n"; - for (auto x : + for (const auto& x : AutoMixedPrecisionLists::WhiteList(cuda_version_, cudnn_version_)) { f << x << "\n"; } f << "\nBlackList:\n"; - for (auto x : AutoMixedPrecisionLists::BlackList()) { + for (const auto& x : AutoMixedPrecisionLists::BlackList()) { f << x << "\n"; } f << "\nGrayList:\n"; - for (auto x : AutoMixedPrecisionLists::GrayList()) { + for (const auto& x : AutoMixedPrecisionLists::GrayList()) { f << x << "\n"; } f << "\nClearList:\n"; - for (auto x : AutoMixedPrecisionLists::ClearList()) { + for (const auto& x : AutoMixedPrecisionLists::ClearList()) { f << x << "\n"; } f.close(); diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h b/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h index 7e7f487fa37..d3d13e2edc0 100644 --- a/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h +++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h @@ -27,10 +27,10 @@ class AutoMixedPrecisionLists { private: static void UpdateList(gtl::FlatSet* list, const string& to_add, const string& to_remove) { - for (auto x : str_util::Split(to_add, ",")) { + for (const auto& x : str_util::Split(to_add, ",")) { list->insert(x); } - for (auto x : str_util::Split(to_remove, ",")) { + for (const auto& x : str_util::Split(to_remove, ",")) { list->erase(x); } } diff --git a/tensorflow/core/grappler/optimizers/common_subgraph_elimination.cc b/tensorflow/core/grappler/optimizers/common_subgraph_elimination.cc index 2b36296a273..af323e913a7 100644 --- a/tensorflow/core/grappler/optimizers/common_subgraph_elimination.cc +++ b/tensorflow/core/grappler/optimizers/common_subgraph_elimination.cc @@ -217,8 +217,8 @@ Status CommonSubgraphElimination::DedupComputations(GraphDef* optimized_graph) { if (rep == node) { continue; } - const std::set& tmp = node_map.GetOutputs(node->name()); - std::vector fanouts(tmp.begin(), tmp.end()); + // Make a copy since we mutate the set below. + const auto fanouts = node_map.GetOutputs(node->name()); for (NodeDef* fanout : fanouts) { // Update consumers of node. bool updated_fanout = false; diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 24758409386..66fca58e907 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -253,7 +253,7 @@ bool ConstantFolding::ForwardInputs(NodeDef* node, } } - const std::set& tmp = node_map_->GetOutputs(node->name()); + const auto& tmp = node_map_->GetOutputs(node->name()); const std::vector consumers(tmp.begin(), tmp.end()); bool updated_graph = false; for (int input_idx : inputs_to_forward) { @@ -691,7 +691,7 @@ Status ConstantFolding::MaterializeBroadcastGradientArgs( } // We make a copy here since we might mutate the set. - const std::set outputs = node_map_->GetOutputs(node.name()); + const auto outputs = node_map_->GetOutputs(node.name()); for (NodeDef* output : outputs) { for (int k = 0; k < output->input_size(); ++k) { int port; @@ -1053,7 +1053,7 @@ bool ConstantFolding::MaybeFoldable(const NodeDef& node, op.find("Reader") != string::npos) { return false; } - if (op.find("Quantized") != string::npos || op.find("Sparse") == 0) { + if (op.find("Quantized") != string::npos || absl::StartsWith(op, "Sparse")) { return false; } @@ -1594,13 +1594,8 @@ Status ConstantFolding::FoldGraph( } // We need to record a copy of output nodes before FoldNode() modifies it. // We also need to ensure that the fanout is sorted deterministically. - const std::set& outputs = node_map_->GetOutputs(node->name()); - std::vector fanout(outputs.begin(), outputs.end()); - std::sort(fanout.begin(), fanout.end(), - [](const NodeDef* n1, const NodeDef* n2) { - return n1->name() < n2->name(); - }); - + std::vector fanout = + node_map_->GetOutputsOrderedByNodeName(node->name()); bool result_too_large = false; Status s = FoldNode(node, output, &result_too_large); processed_nodes.insert(node->name()); @@ -2449,12 +2444,8 @@ bool ConstantFolding::SimplifySwitch(GraphDef* optimized_graph, NodeDef* node) { SetTensorValue(DT_BOOL, false, &false_t).ok()) { // Copy the set of consumers of the switch as they will be manipulated // below. - const auto& consumer_set = node_map_->GetOutputs(node->name()); - std::vector consumers(consumer_set.begin(), consumer_set.end()); - std::sort(consumers.begin(), consumers.end(), - [](const NodeDef* n1, const NodeDef* n2) { - return n1->name() < n2->name(); - }); + std::vector consumers = + node_map_->GetOutputsOrderedByNodeName(node->name()); // Create constant false & true nodes. NodeDef tmp_false_node; tmp_false_node.set_name(OptimizedNodeName(*node, "_const_false")); diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc index fed60036137..58ef14e3d3d 100644 --- a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc @@ -136,7 +136,7 @@ int DependencyOptimizer::NumEdgesIfBypassed( // multi-input identity_n with input/output control dependencies will likely // increase number of edges after optimization. int num_edges_if_bypassed(0); - for (string input_node_name : node.input()) { + for (const string& input_node_name : node.input()) { if (IsControlInput(input_node_name)) { num_edges_if_bypassed += num_outputs; } else { @@ -233,7 +233,7 @@ void DependencyOptimizer::OptimizeNode(int node_idx, // Constant nodes with no input control dependency are always executed early, // so we can prune all their output control dependencies. if (IsConstant(*node) && node->input_size() == 0) { - const std::set output_nodes = node_map_->GetOutputs(node_name); + const auto output_nodes = node_map_->GetOutputs(node_name); for (NodeDef* fanout : output_nodes) { bool optimize_fanout = false; bool data_connection = false; diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc index d60fb2042a7..867433dcff5 100644 --- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc @@ -466,7 +466,8 @@ void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level, // meaning it either begins with or contains the name scope. // Defaults to "gradients/" which will match any node names that begins // with "gradients/" or contains "/gradients/". - return node.name().find(recomputation_targets_name_scope) == 0 || + return absl::StartsWith(node.name(), + recomputation_targets_name_scope) || node.name().find("/" + recomputation_targets_name_scope) != -1; }; diff --git a/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc b/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc index 87841316fc1..83eca92e51c 100644 --- a/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc +++ b/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc @@ -25,152 +25,122 @@ limitations under the License. namespace tensorflow { namespace grappler { -class MklRemapperTest : public GrapplerTest {}; +class MklRemapperTest : public GrapplerTest { + protected: + void FuseConv2DWithBiasAndAddN(const string& data_format, bool has_relu) { + using ::tensorflow::ops::Placeholder; -TEST_F(MklRemapperTest, FuseConv2DWithBiasAndAddN) { - using ::tensorflow::ops::Placeholder; + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto input_shape = (data_format == "NHWC") + ? ops::Placeholder::Shape({8, 32, 32, 3}) + : ops::Placeholder::Shape({8, 3, 32, 32}); + auto input_shape_addn = (data_format == "NHWC") + ? ops::Placeholder::Shape({8, 32, 32, 128}) + : ops::Placeholder::Shape({8, 128, 32, 32}); + auto filter_shape = ops::Placeholder::Shape({1, 1, 3, 128}); + auto bias_shape = ops::Placeholder::Shape({128}); - auto input_shape = ops::Placeholder::Shape({8, 32, 32, 3}); - auto input_shape_addn = ops::Placeholder::Shape({8, 32, 32, 128}); - auto filter_shape = ops::Placeholder::Shape({1, 1, 3, 128}); - auto bias_shape = ops::Placeholder::Shape({128}); + auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape); + auto input_addn = + Placeholder(s.WithOpName("input_addn"), DT_FLOAT, input_shape_addn); + auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape); + auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape); - auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape); - auto input_addn = - Placeholder(s.WithOpName("input_addn"), DT_FLOAT, input_shape_addn); - auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape); - auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape); - - std::vector strides = {1, 1, 1, 1}; - auto conv = ops::Conv2D(s.WithOpName("conv"), input, filter, strides, "SAME"); - auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias); - auto addn = ops::AddN(s.WithOpName("addn"), - std::initializer_list{input_addn, bias_add}); - auto fetch = ops::Identity(s.WithOpName("fetch"), addn); - - auto input_t = GenerateRandomTensor({8, 32, 32, 3}); - auto input_addn_t = GenerateRandomTensor({8, 32, 32, 128}); - auto filter_t = GenerateRandomTensor({1, 1, 3, 128}); - auto bias_t = GenerateRandomTensor({128}); - - GrapplerItem item; - item.fetch = {"fetch"}; - item.feed = {{"input", input_t}, - {"filter", filter_t}, - {"bias", bias_t}, - {"input_addn", input_addn_t}}; - TF_CHECK_OK(s.ToGraphDef(&item.graph)); - - // Place all nodes on CPU. - for (int i = 0; i < item.graph.node_size(); ++i) { - item.graph.mutable_node(i)->set_device("/device:CPU:0"); - } - - Remapper optimizer(RewriterConfig::ON); - GraphDef output; - TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output)); - - int found = 0; - for (const NodeDef& node : output.node()) { - if (node.name() == "addn") { - EXPECT_EQ("_FusedConv2D", node.op()); - EXPECT_EQ("input", node.input(0)); - EXPECT_EQ("filter", node.input(1)); - - EXPECT_EQ(2, node.attr().at("num_args").i()); - EXPECT_EQ("bias", node.input(2)); - EXPECT_EQ("input_addn", node.input(3)); - - const auto fused_ops = node.attr().at("fused_ops").list().s(); - EXPECT_EQ(2, fused_ops.size()); - EXPECT_EQ("BiasAdd", fused_ops[0]); - EXPECT_EQ("Add", fused_ops[1]); - found++; + std::vector strides = {1, 1, 1, 1}; + auto conv = + ops::Conv2D(s.WithOpName("conv"), input, filter, strides, "SAME", + ops::Conv2D::Attrs().DataFormat(data_format)); + auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias, + ops::BiasAdd::Attrs().DataFormat(data_format)); + auto addn = ops::AddN(s.WithOpName("addn"), + std::initializer_list{input_addn, bias_add}); + if (has_relu) { + auto relu = ops::Relu(s.WithOpName("relu"), addn); + ops::Identity(s.WithOpName("fetch"), relu); + } else { + ops::Identity(s.WithOpName("fetch"), addn); } - } - EXPECT_EQ(1, found); + auto input_tensor = GenerateRandomTensor( + TensorShape(input_shape.shape_.dim_sizes())); + auto input_addn_tensor = GenerateRandomTensor( + TensorShape(input_shape_addn.shape_.dim_sizes())); + auto filter_tensor = GenerateRandomTensor( + TensorShape(filter_shape.shape_.dim_sizes())); + auto bias_tensor = GenerateRandomTensor( + TensorShape(bias_shape.shape_.dim_sizes())); - auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); - auto tensors = EvaluateNodes(output, item.fetch, item.feed); - EXPECT_EQ(1, tensors_expected.size()); - EXPECT_EQ(1, tensors.size()); - test::ExpectTensorNear(tensors_expected[0], tensors[0], 1e-6); + GrapplerItem item; + item.fetch = {"fetch"}; + item.feed = {{"input", input_tensor}, + {"filter", filter_tensor}, + {"bias", bias_tensor}, + {"input_addn", input_addn_tensor}}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + // Place all nodes on CPU. + for (int i = 0; i < item.graph.node_size(); ++i) { + item.graph.mutable_node(i)->set_device("/device:CPU:0"); + } + + Remapper optimizer(RewriterConfig::ON); + GraphDef output; + TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output)); + + int found = 0; + for (const NodeDef& node : output.node()) { + auto fetch_node_name = has_relu ? "relu" : "addn"; + if (node.name() == fetch_node_name) { + EXPECT_EQ("_FusedConv2D", node.op()); + EXPECT_EQ("input", node.input(0)); + EXPECT_EQ("filter", node.input(1)); + + EXPECT_EQ(2, node.attr().at("num_args").i()); + EXPECT_EQ("bias", node.input(2)); + EXPECT_EQ("input_addn", node.input(3)); + + const auto fused_ops = node.attr().at("fused_ops").list().s(); + if (has_relu) { + EXPECT_EQ(3, fused_ops.size()); + EXPECT_EQ("BiasAdd", fused_ops[0]); + EXPECT_EQ("Add", fused_ops[1]); + EXPECT_EQ("Relu", fused_ops[2]); + } else { + EXPECT_EQ(2, fused_ops.size()); + EXPECT_EQ("BiasAdd", fused_ops[0]); + EXPECT_EQ("Add", fused_ops[1]); + } + found++; + } + } + EXPECT_EQ(1, found); + + auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); + auto tensors = EvaluateNodes(output, item.fetch, item.feed); + EXPECT_EQ(1, tensors_expected.size()); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorNear(tensors_expected[0], tensors[0], 1e-6); + } +}; + +TEST_F(MklRemapperTest, FuseConv2DWithBiasAndAddN_NHWC_WithoutRelu) { + const bool kShouldFuseRelu = false; + FuseConv2DWithBiasAndAddN("NHWC", kShouldFuseRelu); } -TEST_F(MklRemapperTest, FuseConv2DWithBiasAndAddNRelu) { - using ::tensorflow::ops::Placeholder; +TEST_F(MklRemapperTest, FuseConv2DWithBiasAndAddN_NHWC_WithRelu) { + const bool kShouldFuseRelu = true; + FuseConv2DWithBiasAndAddN("NHWC", kShouldFuseRelu); +} - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); +TEST_F(MklRemapperTest, FuseConv2DWithBiasAndAddN_NCHW_WithoutRelu) { + const bool kShouldFuseRelu = false; + FuseConv2DWithBiasAndAddN("NCHW", kShouldFuseRelu); +} - auto input_shape = ops::Placeholder::Shape({8, 32, 32, 3}); - auto input_shape_addn = ops::Placeholder::Shape({8, 32, 32, 128}); - auto filter_shape = ops::Placeholder::Shape({1, 1, 3, 128}); - auto bias_shape = ops::Placeholder::Shape({128}); - - auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape); - auto input_addn = - Placeholder(s.WithOpName("input_addn"), DT_FLOAT, input_shape_addn); - auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape); - auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape); - - std::vector strides = {1, 1, 1, 1}; - auto conv = ops::Conv2D(s.WithOpName("conv"), input, filter, strides, "SAME"); - auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias); - auto addn = ops::AddN(s.WithOpName("addn"), - std::initializer_list{input_addn, bias_add}); - auto relu = ops::Relu(s.WithOpName("relu"), addn); - auto fetch = ops::Identity(s.WithOpName("fetch"), relu); - - auto input_t = GenerateRandomTensor({8, 32, 32, 3}); - auto input_addn_t = GenerateRandomTensor({8, 32, 32, 128}); - auto filter_t = GenerateRandomTensor({1, 1, 3, 128}); - auto bias_t = GenerateRandomTensor({128}); - - GrapplerItem item; - item.fetch = {"fetch"}; - item.feed = {{"input", input_t}, - {"filter", filter_t}, - {"bias", bias_t}, - {"input_addn", input_addn_t}}; - TF_CHECK_OK(s.ToGraphDef(&item.graph)); - - // Place all nodes on CPU. - for (int i = 0; i < item.graph.node_size(); ++i) { - item.graph.mutable_node(i)->set_device("/device:CPU:0"); - } - - Remapper optimizer(RewriterConfig::ON); - GraphDef output; - TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output)); - - int found = 0; - for (const NodeDef& node : output.node()) { - if (node.name() == "relu") { - EXPECT_EQ("_FusedConv2D", node.op()); - EXPECT_EQ("input", node.input(0)); - EXPECT_EQ("filter", node.input(1)); - - EXPECT_EQ(2, node.attr().at("num_args").i()); - EXPECT_EQ("bias", node.input(2)); - EXPECT_EQ("input_addn", node.input(3)); - - const auto fused_ops = node.attr().at("fused_ops").list().s(); - EXPECT_EQ(3, fused_ops.size()); - EXPECT_EQ("BiasAdd", fused_ops[0]); - EXPECT_EQ("Add", fused_ops[1]); - EXPECT_EQ("Relu", fused_ops[2]); - found++; - } - } - EXPECT_EQ(1, found); - - auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); - auto tensors = EvaluateNodes(output, item.fetch, item.feed); - EXPECT_EQ(1, tensors_expected.size()); - EXPECT_EQ(1, tensors.size()); - test::ExpectTensorNear(tensors_expected[0], tensors[0], 1e-6); +TEST_F(MklRemapperTest, FuseConv2DWithBiasAndAddN_NCHW_WithRelu) { + const bool kShouldFuseRelu = true; + FuseConv2DWithBiasAndAddN("NCHW", kShouldFuseRelu); } } // namespace grappler diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc index 5b41ad38089..7d14e087d98 100644 --- a/tensorflow/core/grappler/optimizers/remapper.cc +++ b/tensorflow/core/grappler/optimizers/remapper.cc @@ -240,7 +240,11 @@ bool IsGpuCompatibleDataType(const NodeDef* contraction, bool IsCpuCompatibleDataFormat(const NodeDef* conv2d) { DCHECK(IsConv2D(*conv2d)) << "Expected Conv2D op"; const string& data_format = conv2d->attr().at(kDataFormat).s(); +#ifndef INTEL_MKL return data_format == "NHWC"; +#else + return data_format == "NHWC" || data_format == "NCHW"; +#endif // !INTEL_MKL } bool IsGpuCompatibleDataFormat(const NodeDef* conv2d) { @@ -1662,7 +1666,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item, } TF_RETURN_IF_ERROR(mutation->Apply()); - *optimized_graph = mutable_item.graph; + *optimized_graph = std::move(mutable_item.graph); return Status::OK(); } diff --git a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc index 5d65067b036..358cc79826b 100644 --- a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc @@ -708,7 +708,7 @@ class UnaryElementwiseRewriter : public ScopedAllocatorOptimizer::Rewriter { NodeDef* old_op = ops[op_idx]; // Copy the output node set since we'll be modifying the version // maintained by NodeMap in the loop. - std::set output_nodes = node_map->GetOutputs(old_op->name()); + auto output_nodes = node_map->GetOutputs(old_op->name()); VLOG(3) << "old_op " << old_op->name() << " had " << output_nodes.size() << " outputs. Moving them to the ScopedAllocatorSplit node."; if (VLOG_IS_ON(2)) { @@ -971,7 +971,7 @@ class Tree { public: Tree(const string& edge, int depth) : edge_(edge), depth_(depth) {} ~Tree() { - for (auto it : subtrees_) delete it.second; + for (const auto& it : subtrees_) delete it.second; } Tree* GetSubTree(const string& edge) { @@ -996,7 +996,7 @@ class Tree { // on any non-OK Status. Status ApplyToAll(Tree* tree, const std::function& func) { Status s; - for (auto it : tree->subtrees_) { + for (const auto& it : tree->subtrees_) { s = ApplyToAll(it.second, func); if (!s.ok()) return s; } diff --git a/tensorflow/core/grappler/optimizers/shape_optimizer.cc b/tensorflow/core/grappler/optimizers/shape_optimizer.cc index 4bcb4dfc791..69de1cde4ca 100644 --- a/tensorflow/core/grappler/optimizers/shape_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/shape_optimizer.cc @@ -66,121 +66,125 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, } *optimized_graph = item.graph; - MutableGraphView graph(optimized_graph); GraphProperties properties(item); bool inferred_properties = false; - - // The product of all the dimensions in a tensor shape can be expressed more - // simply as the size of the tensor. - for (auto& node : *optimized_graph->mutable_node()) { - if (!IsShape(node)) { - continue; - } - for (MutableGraphView::InputPort fanout : - graph.GetFanout(MutableGraphView::OutputPort(&node, 0))) { - if (fanout.node->op() != "Prod") { + { + MutableGraphView graph(optimized_graph); + // The product of all the dimensions in a tensor shape can be expressed more + // simply as the size of the tensor. + for (auto& node : *optimized_graph->mutable_node()) { + if (!IsShape(node)) { continue; } - if (fanout.node->attr().count("keep_dims") != 0 && - fanout.node->attr().at("keep_dims").b()) { - // Keeping the reduced dimensions won't result in a scalar, so we can't - // rewrite the whole expression directly as a Size operation. - continue; - } - const MutableGraphView::OutputPort reduce_indices = - graph.GetRegularFanin(MutableGraphView::InputPort(fanout.node, 1)); - if (!inferred_properties) { - // Infer properties lazily in case they are not needed. - TF_RETURN_IF_ERROR( - properties.InferStatically(/*assume_valid_feeds=*/false, - /*aggressive_shape_inference=*/false, - /*include_tensor_values=*/false)); - inferred_properties = true; - } - const auto& prop = - properties.GetOutputProperties(reduce_indices.node->name()); - if (prop.size() <= reduce_indices.port_id) { - continue; - } - const TensorShapeProto& reduction_indices_shape = - prop[reduce_indices.port_id].shape(); - if (NumCoefficients(reduction_indices_shape) == 1) { - const auto& input_props = properties.GetInputProperties(node.name()); - if (input_props.size() != 1) { + for (MutableGraphView::InputPort fanout : + graph.GetFanout(MutableGraphView::OutputPort(&node, 0))) { + if (fanout.node->op() != "Prod") { continue; } - // Rewrite the reduction of the shape dimensions as a Size operation. - NodeDef size_node(*fanout.node); - const DataType type = input_props[0].dtype(); - size_node.set_op("Size"); - size_node.set_input(0, node.input(0)); - size_node.set_input(1, AsControlDependency(node)); - size_node.mutable_attr()->erase("Tidx"); - size_node.mutable_attr()->erase("keep_dims"); - (*size_node.mutable_attr())["out_type"] = fanout.node->attr().at("T"); - (*size_node.mutable_attr())["T"].set_type(type); - - // The corresponding Size kernel might not exist on the device where - // Prod was placed, so assign the Size kernel to the same device as the - // input. - size_node.set_device(node.device()); - - // In the unlikely even that "Size" is not registered on the input - // device, skip the optimization. - Status s = IsKernelRegisteredForNode(size_node); - if (!s.ok()) { + if (fanout.node->attr().count("keep_dims") != 0 && + fanout.node->attr().at("keep_dims").b()) { + // Keeping the reduced dimensions won't result in a scalar, so we + // can't rewrite the whole expression directly as a Size operation. continue; } + const MutableGraphView::OutputPort reduce_indices = + graph.GetRegularFanin(MutableGraphView::InputPort(fanout.node, 1)); + if (!inferred_properties) { + // Infer properties lazily in case they are not needed. + TF_RETURN_IF_ERROR( + properties.InferStatically(/*assume_valid_feeds=*/false, + /*aggressive_shape_inference=*/false, + /*include_tensor_values=*/false)); + inferred_properties = true; + } + const auto& prop = + properties.GetOutputProperties(reduce_indices.node->name()); + if (prop.size() <= reduce_indices.port_id) { + continue; + } + const TensorShapeProto& reduction_indices_shape = + prop[reduce_indices.port_id].shape(); + if (NumCoefficients(reduction_indices_shape) == 1) { + const auto& input_props = properties.GetInputProperties(node.name()); + if (input_props.size() != 1) { + continue; + } + // Rewrite the reduction of the shape dimensions as a Size operation. + NodeDef size_node(*fanout.node); + const DataType type = input_props[0].dtype(); + size_node.set_op("Size"); + size_node.set_input(0, node.input(0)); + size_node.set_input(1, AsControlDependency(node)); + size_node.mutable_attr()->erase("Tidx"); + size_node.mutable_attr()->erase("keep_dims"); + (*size_node.mutable_attr())["out_type"] = fanout.node->attr().at("T"); + (*size_node.mutable_attr())["T"].set_type(type); - fanout.node->Swap(&size_node); + // The corresponding Size kernel might not exist on the device where + // Prod was placed, so assign the Size kernel to the same device as + // the input. + size_node.set_device(node.device()); + + // In the unlikely even that "Size" is not registered on the input + // device, skip the optimization. + Status s = IsKernelRegisteredForNode(size_node); + if (!s.ok()) { + continue; + } + + fanout.node->Swap(&size_node); + } } } } - for (auto& node : *optimized_graph->mutable_node()) { - // Try to convert the ratio of 2 symbolic tensor sizes into a constant. This - // is possible whenever the symbolic dimensions in the numerator and - // denominator cancel each other. - if (node.op() == "Div") { - const MutableGraphView::OutputPort input1 = - graph.GetRegularFanin(MutableGraphView::InputPort(&node, 0)); - const MutableGraphView::OutputPort input2 = - graph.GetRegularFanin(MutableGraphView::InputPort(&node, 1)); - if (input1.node == nullptr || input2.node == nullptr) continue; - if (!IsSize(*input1.node) || !IsSize(*input2.node)) { - continue; - } - if (!inferred_properties) { - // Infer properties lazily in case they are not needed. - TF_RETURN_IF_ERROR( - properties.InferStatically(/*assume_valid_feeds=*/false, - /*aggressive_shape_inference=*/false, - /*include_tensor_values=*/false)); - inferred_properties = true; - } - const auto& prop1 = properties.GetInputProperties(input1.node->name()); - const auto& prop2 = properties.GetInputProperties(input2.node->name()); - if (prop1.size() != 1 || prop2.size() != 1) { - continue; - } - const TensorShapeProto& shape1 = prop1[0].shape(); - const TensorShapeProto& shape2 = prop2[0].shape(); - int64 result = ComputeSizeRatio(shape1, shape2); - if (result >= 0) { - // Replace div with constant. - node.set_op("Const"); - DataType dtype = node.attr().at("T").type(); - node.mutable_attr()->erase("T"); - (*node.mutable_attr())["dtype"].set_type(dtype); - TensorProto* t = (*node.mutable_attr())["value"].mutable_tensor(); - t->set_dtype(dtype); - *t->mutable_tensor_shape() = TensorShapeProto(); - if (dtype == DT_INT32) { - t->add_int_val(result); - } else { - t->add_int64_val(result); + { + MutableGraphView graph(optimized_graph); + for (auto& node : *optimized_graph->mutable_node()) { + // Try to convert the ratio of 2 symbolic tensor sizes into a constant. + // This is possible whenever the symbolic dimensions in the numerator and + // denominator cancel each other. + if (node.op() == "Div") { + const MutableGraphView::OutputPort input1 = + graph.GetRegularFanin(MutableGraphView::InputPort(&node, 0)); + const MutableGraphView::OutputPort input2 = + graph.GetRegularFanin(MutableGraphView::InputPort(&node, 1)); + if (input1.node == nullptr || input2.node == nullptr) continue; + if (!IsSize(*input1.node) || !IsSize(*input2.node)) { + continue; + } + if (!inferred_properties) { + // Infer properties lazily in case they are not needed. + TF_RETURN_IF_ERROR( + properties.InferStatically(/*assume_valid_feeds=*/false, + /*aggressive_shape_inference=*/false, + /*include_tensor_values=*/false)); + inferred_properties = true; + } + const auto& prop1 = properties.GetInputProperties(input1.node->name()); + const auto& prop2 = properties.GetInputProperties(input2.node->name()); + if (prop1.size() != 1 || prop2.size() != 1) { + continue; + } + const TensorShapeProto& shape1 = prop1[0].shape(); + const TensorShapeProto& shape2 = prop2[0].shape(); + int64 result = ComputeSizeRatio(shape1, shape2); + if (result >= 0) { + // Replace div with constant. + node.set_op("Const"); + DataType dtype = node.attr().at("T").type(); + node.mutable_attr()->erase("T"); + (*node.mutable_attr())["dtype"].set_type(dtype); + TensorProto* t = (*node.mutable_attr())["value"].mutable_tensor(); + t->set_dtype(dtype); + *t->mutable_tensor_shape() = TensorShapeProto(); + if (dtype == DT_INT32) { + t->add_int_val(result); + } else { + t->add_int64_val(result); + } + node.set_input(0, AsControlDependency(node.input(0))); + node.set_input(1, AsControlDependency(node.input(1))); } - node.set_input(0, AsControlDependency(node.input(0))); - node.set_input(1, AsControlDependency(node.input(1))); } } } diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc index c5c3fcd9665..cd6b4855583 100644 --- a/tensorflow/core/grappler/utils.cc +++ b/tensorflow/core/grappler/utils.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "tensorflow/core/framework/attr_value.pb.h" @@ -92,76 +93,6 @@ NodeMap::NodeMap(GraphDef* graph) { } } -void NodeMap::RemoveNode(const string& name) { - nodes_.erase(NodeName(name)); - outputs_.erase(NodeName(name)); -} - -NodeDef* NodeMap::GetNode(const string& name) const { - const string node_name = NodeName(name); - auto it = nodes_.find(node_name); - if (it == nodes_.end()) { - VLOG(1) << "Node could not be found: " << name; - return nullptr; - } - return it->second; -} - -bool NodeMap::NodeExists(const string& name) const { - const string node_name = NodeName(name); - return nodes_.find(node_name) != nodes_.end(); -} - -const std::set& NodeMap::GetOutputs(const string& node_name) const { - auto it = outputs_.find(node_name); - if (it == outputs_.end()) { - return empty_set_; - } - return it->second; -} - -void NodeMap::AddNode(const string& node_name, NodeDef* node) { - auto ret = nodes_.emplace(node_name, CHECK_NOTNULL(node)); - CHECK(ret.second) << "Pair (" << node_name << "," << node - << ") is not inserted because the same key already exists."; -} - -void NodeMap::AddOutput(const string& node_name, const string& output_name) { - auto output_node = nodes_[NodeName(output_name)]; - CHECK(output_node) << "Output node " << output_name - << " is missing in NodeMap."; - outputs_[node_name].insert(output_node); -} - -void NodeMap::RemoveOutput(const string& node_name, const string& output_name) { - outputs_[node_name].erase(nodes_[NodeName(output_name)]); -} - -void NodeMap::UpdateInput(const string& node_name, const string& old_input_name, - const string& new_input_name) { - RemoveOutput(NodeName(old_input_name), node_name); - AddOutput(NodeName(new_input_name), node_name); -} - -void NodeMap::RemoveInputs(const string& node_name) { - auto node = nodes_[node_name]; - for (const auto& input : node->input()) { - RemoveOutput(NodeName(input), node->name()); - } -} - -void NodeMap::RemoveOutputs(const string& node_name) { - outputs_.erase(node_name); -} - -void NodeMap::UpdateOutput(const string& node_name, - const string& old_output_name, - const string& new_output_name) { - std::set& outputs = outputs_[node_name]; - outputs.erase(nodes_[NodeName(old_output_name)]); - outputs.insert(nodes_[NodeName(new_output_name)]); -} - string TensorIdToString(const TensorId& tensor_id) { return tensor_id.index() == 0 ? string(tensor_id.node()) : tensor_id.ToString(); @@ -436,7 +367,7 @@ void PermuteNodesInPlace(GraphDef* graph, std::vector* permutation, } void DedupControlInputs(NodeDef* node) { - std::unordered_set inputs; + absl::flat_hash_set inputs; int pos = 0; while (pos < node->input_size()) { const string& input = node->input(pos); diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h index a50c6f71fee..61cf533d0d7 100644 --- a/tensorflow/core/grappler/utils.h +++ b/tensorflow/core/grappler/utils.h @@ -18,11 +18,10 @@ limitations under the License. #include #include -#include -#include #include #include +#include "absl/container/flat_hash_set.h" #include "absl/container/node_hash_map.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -42,84 +41,7 @@ limitations under the License. namespace tensorflow { namespace grappler { -// A utility class to lookup a node and its outputs by node name. -class NodeMap { - public: - // Note: The NodeMap will store pointers to nodes in graph, which may become - // invalid if graph is changed. - explicit NodeMap(GraphDef* graph); - NodeDef* GetNode(const string& name) const; - bool NodeExists(const string& name) const; - const std::set& GetOutputs(const string& node_name) const; - // This method doesn't record the outputs of the added node; the outputs need - // to be explicitly added by the AddOutput method. - void AddNode(const string& name, NodeDef* node); - void RemoveNode(const string& name); - void UpdateInput(const string& node_name, const string& old_input_name, - const string& new_input_name); - void AddOutput(const string& node_name, const string& output_name); - void RemoveInputs(const string& node_name); - void RemoveOutput(const string& node_name, const string& output_name); - void RemoveOutputs(const string& node_name); - void UpdateOutput(const string& node_name, const string& old_output_name, - const string& new_output_name); - - private: - const std::set empty_set_; - absl::node_hash_map nodes_; - absl::node_hash_map> outputs_; -}; - -// A vector with a set. The set stores the same elements as the vector, and -// quickly answers whether a value is in the vector. Duplicated elements are not -// allowed for now. -template > -class SetVector { - public: - // Returns false if value already existed in the set, true otherwise. - bool PushBack(const T& value) { - if (!set_.insert(value).second) { - return false; - } - vector_.push_back(value); - return true; - } - - T PopBack() { - T back = vector_.back(); - set_.erase(back); - vector_.pop_back(); - return back; - } - - bool Exists(const T& value) const { return set_.find(value) != set_.end(); } - - bool Empty() const { return vector_.empty(); } - - void Reserve(int64 size) { vector_.reserve(size); } - - private: - gtl::FlatSet set_; - std::vector vector_; -}; - -// Returns formatted string from TensorId specific to grappler. Specifically, -// for the 0 port (first output), only the node name is returned. -string TensorIdToString(const TensorId& tensor_id); - -// Returns formatted string from SafeTensorId specific to grappler. -// Specifically, for the 0 port (first output), only the node name is returned. -string SafeTensorIdToString(const SafeTensorId& tensor_id); - -// True iff 'name' refers to a control inputs, i.e. a node name prefixed with -// the ^ character. -bool IsControlInput(const string& name); - -// True iff tensor index refers to a control input. -bool IsControlInput(const TensorId& tensor_id); - -// True iff 'name1' and 'name2' refer to the same input. -bool IsSameInput(const string& name1, const string& name2); +// Utilities for manipulating node name and input strings. // Returns the trailing position number (or zero if no number is present) if // NodeName(input_name) is equal to node_name. Returns -1 for control inputs. @@ -176,6 +98,162 @@ inline int NodePosition(const string& name) { return position; } +// A utility class to lookup a node and its outputs by node name. +class NodeMap { + public: + // Note: The NodeMap will store pointers to nodes in graph, which may become + // invalid if graph is changed. + explicit NodeMap(GraphDef* graph); + + // Get unordered list of fanouts from node. Notice, that the order is + // non-deterministic. + const absl::flat_hash_set& GetOutputs( + const string& node_name) const { + auto it = outputs_.find(node_name); + if (it == outputs_.end()) { + return empty_set_; + } + return it->second; + } + + // Get fanouts ordered by name. + std::vector GetOutputsOrderedByNodeName( + const string& node_name) const { + std::vector result; + auto it = outputs_.find(node_name); + if (it != outputs_.end()) { + const absl::flat_hash_set& outputs = it->second; + result.reserve(outputs.size()); + result.assign(outputs.begin(), outputs.end()); + std::sort(result.begin(), result.end(), + [](const NodeDef* n1, const NodeDef* n2) { + return n1->name() < n2->name(); + }); + } + return result; + } + + // This method doesn't record the outputs of the added node; the outputs need + // to be explicitly added by the AddOutput method. + void AddNode(const string& node_name, NodeDef* node) { + DCHECK(node != nullptr); + auto ret = nodes_.emplace(node_name, node); + DCHECK(ret.second) + << "Pair (" << node_name << "," << node + << ") is not inserted because the same key already exists."; + } + + void RemoveNode(const string& name) { + nodes_.erase(NodeName(name)); + outputs_.erase(NodeName(name)); + } + + NodeDef* GetNode(const string& name) const { + const string node_name = NodeName(name); + auto it = nodes_.find(node_name); + if (it == nodes_.end()) { + VLOG(1) << "Node could not be found: " << name; + return nullptr; + } + return it->second; + } + + bool NodeExists(const string& name) const { + const string node_name = NodeName(name); + return nodes_.find(node_name) != nodes_.end(); + } + + void AddOutput(const string& node_name, const string& output_name) { + auto output_node = nodes_[NodeName(output_name)]; + DCHECK(output_node) << "Output node " << output_name + << " is missing in NodeMap."; + outputs_[node_name].insert(output_node); + } + + void RemoveOutput(const string& node_name, const string& output_name) { + outputs_[node_name].erase(nodes_[NodeName(output_name)]); + } + + void UpdateInput(const string& node_name, const string& old_input_name, + const string& new_input_name) { + RemoveOutput(NodeName(old_input_name), node_name); + AddOutput(NodeName(new_input_name), node_name); + } + + void RemoveInputs(const string& node_name) { + auto node = nodes_[node_name]; + for (const auto& input : node->input()) { + RemoveOutput(NodeName(input), node->name()); + } + } + + void RemoveOutputs(const string& node_name) { outputs_.erase(node_name); } + + void UpdateOutput(const string& node_name, const string& old_output_name, + const string& new_output_name) { + absl::flat_hash_set& outputs = outputs_[node_name]; + outputs.erase(nodes_[NodeName(old_output_name)]); + outputs.insert(nodes_[NodeName(new_output_name)]); + } + + private: + const absl::flat_hash_set empty_set_; + absl::node_hash_map nodes_; + absl::node_hash_map> outputs_; +}; + +// A vector with a set. The set stores the same elements as the vector, and +// quickly answers whether a value is in the vector. Duplicated elements are not +// allowed for now. +template > +class SetVector { + public: + // Returns false if value already existed in the set, true otherwise. + bool PushBack(const T& value) { + if (!set_.insert(value).second) { + return false; + } + vector_.push_back(value); + return true; + } + + T PopBack() { + T back = vector_.back(); + set_.erase(back); + vector_.pop_back(); + return back; + } + + bool Exists(const T& value) const { return set_.find(value) != set_.end(); } + + bool Empty() const { return vector_.empty(); } + + void Reserve(int64 size) { vector_.reserve(size); } + + private: + gtl::FlatSet set_; + std::vector vector_; +}; + +// Returns formatted string from TensorId specific to grappler. Specifically, +// for the 0 port (first output), only the node name is returned. +string TensorIdToString(const TensorId& tensor_id); + +// Returns formatted string from SafeTensorId specific to grappler. +// Specifically, for the 0 port (first output), only the node name is returned. +string SafeTensorIdToString(const SafeTensorId& tensor_id); + +// True iff 'name' refers to a control inputs, i.e. a node name prefixed with +// the ^ character. +bool IsControlInput(const string& name); + +// True iff tensor index refers to a control input. +bool IsControlInput(const TensorId& tensor_id); + +// True iff 'name1' and 'name2' refer to the same input. +bool IsSameInput(const string& name1, const string& name2); + + // Add a prefix to a node name with a custom delimiter. string AddPrefixToNodeName(const string& name, const string& prefix, const string& delimiter); diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD index 127bf465b3f..8e04b573770 100644 --- a/tensorflow/core/grappler/utils/BUILD +++ b/tensorflow/core/grappler/utils/BUILD @@ -188,12 +188,12 @@ cc_library( hdrs = ["functions.h"], visibility = ["//visibility:public"], deps = [ - "//tensorflow/core:core_cpu_base_no_ops", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/common_runtime:core_cpu_base_no_ops", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", diff --git a/tensorflow/core/grappler/utils/canonicalizer.cc b/tensorflow/core/grappler/utils/canonicalizer.cc index a30d97b0f3d..9ec22d39849 100644 --- a/tensorflow/core/grappler/utils/canonicalizer.cc +++ b/tensorflow/core/grappler/utils/canonicalizer.cc @@ -58,7 +58,9 @@ void CompressConstants(GraphDef* graph) { if ((IsConstant(*node) || IsHostConstant(*node)) && HasNodeAttr(*node, "value")) { AttrValue& attr_val = (*node->mutable_attr())["value"]; - tensor::CompressTensorProtoInPlace(attr_val.mutable_tensor()); + if (attr_val.has_tensor()) { + tensor::CompressTensorProtoInPlace(attr_val.mutable_tensor()); + } } } } diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 0477d260e10..ea4d7685705 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -982,7 +982,7 @@ ARRAY_DEPS = [ "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//third_party/eigen3", -] + if_sycl(["//tensorflow/core:sycl_runtime"]) +] + if_sycl(["//tensorflow/core/common_runtime/sycl:sycl_runtime"]) tf_kernel_library( name = "immutable_constant_op", @@ -1314,7 +1314,23 @@ tf_kernel_library( tf_kernel_library( name = "tile_ops", - srcs = ["tile_functor_cpu.cc"], + srcs = [ + "tile_functor_cpu.h", + "tile_functor_cpu_bfloat16.cc", + "tile_functor_cpu_bool.cc", + "tile_functor_cpu_complex128.cc", + "tile_functor_cpu_complex64.cc", + "tile_functor_cpu_double.cc", + "tile_functor_cpu_float.cc", + "tile_functor_cpu_half.cc", + "tile_functor_cpu_int16.cc", + "tile_functor_cpu_int32.cc", + "tile_functor_cpu_int64.cc", + "tile_functor_cpu_int8.cc", + "tile_functor_cpu_tstring.cc", + "tile_functor_cpu_uint8.cc", + "tile_functor_sycl.cc", + ], hdrs = ["tile_functor.h"], gpu_srcs = [ "tile_functor.h", @@ -5608,7 +5624,7 @@ STATE_DEPS = [ "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", -] + if_sycl(["//tensorflow/core:sycl_runtime"]) +] + if_sycl(["//tensorflow/core/common_runtime/sycl:sycl_runtime"]) tf_kernel_library( name = "count_up_to_op", @@ -6553,6 +6569,7 @@ filegroup( "data_format_ops.h", "depthtospace_op.h", "depthwise_conv_op.h", + "extract_image_patches_op.h", "fake_quant_ops_functor.h", "fused_batch_norm_op.h", "gemm_functors.h", @@ -6736,6 +6753,7 @@ filegroup( "decode_bmp_op.cc", "depthtospace_op.cc", "dynamic_stitch_op.cc", + "extract_image_patches_op.cc", "fft_ops.cc", "in_topk_op.cc", "in_topk_op.h", @@ -6821,7 +6839,20 @@ filegroup( "summary_op.cc", "tensor_array.cc", "tensor_array_ops.cc", - "tile_functor_cpu.cc", + "tile_functor_cpu.h", + "tile_functor_cpu_bfloat16.cc", + "tile_functor_cpu_bool.cc", + "tile_functor_cpu_complex128.cc", + "tile_functor_cpu_complex64.cc", + "tile_functor_cpu_double.cc", + "tile_functor_cpu_float.cc", + "tile_functor_cpu_half.cc", + "tile_functor_cpu_int16.cc", + "tile_functor_cpu_int32.cc", + "tile_functor_cpu_int64.cc", + "tile_functor_cpu_int8.cc", + "tile_functor_cpu_tstring.cc", + "tile_functor_cpu_uint8.cc", "tile_ops.cc", "tile_ops_cpu_impl_1.cc", "tile_ops_cpu_impl_2.cc", diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD index 4cd52ad6188..bbad9278ac1 100644 --- a/tensorflow/core/kernels/data/experimental/BUILD +++ b/tensorflow/core/kernels/data/experimental/BUILD @@ -121,6 +121,34 @@ tf_kernel_library( ], ) +tf_kernel_library( + name = "data_service_dataset_op", + srcs = ["data_service_dataset_op.cc"], + hdrs = ["data_service_dataset_op.h"], + deps = [ + "//tensorflow/core:experimental_dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/data/service:compression_utils", + "//tensorflow/core/data/service:credentials_factory", + "//tensorflow/core/data/service:grpc_util", + "//tensorflow/core/data/service:master_cc_grpc_proto", + "//tensorflow/core/data/service:master_proto_cc", + "//tensorflow/core/data/service:worker_cc_grpc_proto", + "//tensorflow/core/distributed_runtime/rpc:grpc_util", + "//tensorflow/core/kernels/data:dataset_utils", + "//tensorflow/core/kernels/data:name_utils", + "//tensorflow/core/kernels/data:serialization_utils", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + tf_grpc_cc_dependency(), + ], +) + tf_kernel_library( name = "data_service_ops", srcs = ["data_service_ops.cc"], diff --git a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc new file mode 100644 index 00000000000..3e9edbf3349 --- /dev/null +++ b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc @@ -0,0 +1,498 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/kernels/data/experimental/data_service_dataset_op.h" + +#include +#include +#include + +#include "grpcpp/create_channel.h" +#include "grpcpp/impl/codegen/server_context.h" +#include "grpcpp/security/credentials.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/core/data/service/compression_utils.h" +#include "tensorflow/core/data/service/credentials_factory.h" +#include "tensorflow/core/data/service/grpc_util.h" +#include "tensorflow/core/data/service/master.grpc.pb.h" +#include "tensorflow/core/data/service/master.pb.h" +#include "tensorflow/core/data/service/worker.grpc.pb.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/model.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/data/dataset_utils.h" +#include "tensorflow/core/kernels/data/name_utils.h" +#include "tensorflow/core/kernels/data/serialization_utils.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/snappy.h" +#include "tensorflow/core/protobuf/error_codes.pb.h" + +namespace tensorflow { +namespace data { + +/* static */ constexpr const char* const DataServiceDatasetOp::kDatasetType; +/* static */ constexpr const char* const DataServiceDatasetOp::kAddress; +/* static */ constexpr const char* const DataServiceDatasetOp::kEpochId; +/* static */ constexpr const char* const DataServiceDatasetOp::kProtocol; +/* static */ constexpr const char* const + DataServiceDatasetOp::kMaxOutstandingRequests; +/* static */ constexpr const char* const DataServiceDatasetOp::kOutputTypes; +/* static */ constexpr const char* const DataServiceDatasetOp::kOutputShapes; + +// Once we've spent `kRetryTimeoutMicros` in `GetNextInternal`, we will wait for +// the current attempt to complete and perform no more retries. +const int64 kRetryTimeoutMicros = 1000LL * 1000 * 60 * 60; // 60 minutes. + +// How often to refresh the task list. +const int64 kRefreshTasksIntervalMicros = 1000LL * 1000 * 60; // 60 seconds. + +// Dataset for reading data from the tf.data service non-deterministically. +// +// This dataset interleaves dataset elements produced by multiple tf.data +// workers. We periodically query the tf.data master to determine which workers +// to read from (in case workers are added or removed). +class DataServiceDatasetOp::Dataset : public DatasetBase { + public: + Dataset(OpKernelContext* ctx, const std::string& address, + const std::string& protocol, const int64 max_outstanding_requests, + const DataTypeVector& output_types, + const std::vector& output_shapes) + : DatasetBase(DatasetContext(ctx)), + address_(address), + protocol_(protocol), + max_outstanding_requests_(max_outstanding_requests), + output_types_(output_types), + output_shapes_(output_shapes) {} + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return absl::make_unique(Iterator::Params{ + this, name_utils::IteratorPrefix(kDatasetType, prefix)}); + } + + const DataTypeVector& output_dtypes() const override { return output_types_; } + + const std::vector& output_shapes() const override { + return output_shapes_; + } + + string DebugString() const override { + return name_utils::DatasetDebugString(kDatasetType); + } + + Status CheckExternalState() const override { + return Status( + error::FAILED_PRECONDITION, + strings::StrCat(DebugString(), " does not yet support serialization.")); + } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + Node* address; + TF_RETURN_IF_ERROR(b->AddScalar(address_, &address)); + + Node* protocol; + TF_RETURN_IF_ERROR(b->AddScalar(protocol_, &protocol)); + + Node* max_outstanding_requests; + TF_RETURN_IF_ERROR( + b->AddScalar(max_outstanding_requests_, &max_outstanding_requests)); + + TF_RETURN_IF_ERROR(b->AddDataset( + this, {address, protocol, max_outstanding_requests}, {}, output)); + return Status::OK(); + } + + private: + class Iterator : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params) {} + + ~Iterator() override { + mutex_lock l(mu_); + cancelled_ = true; + cv_.notify_all(); + // Thread destructors will block until the threads finish, no need to wait + // here. + } + + Status Initialize(IteratorContext* ctx) override { + VLOG(3) << "Connecting to " << dataset()->address_ + << " in data service dataset op"; + if (ctx->epoch_id() == IteratorContext::kNoEpochId) { + // TODO(aaudibert): add instructions for passing an epoch id after we + // add a Python API. + return errors::FailedPrecondition( + "Expected an epoch id, but none found."); + } + epoch_id_ = ctx->epoch_id(); + TF_RETURN_IF_ERROR(CredentialsFactory::CreateClientCredentials( + dataset()->protocol_, &credentials_)); + return Status::OK(); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + VLOG(3) << "Calling GetNext in data service dataset op"; + mutex_lock l(mu_); + if (!task_thread_manager_ && !cancelled_) { + task_thread_manager_ = ctx->StartThread( + "task-thread-manager", [this, ctx]() { TaskThreadManager(ctx); }); + } + + // tasks_.empty() indicates that we haven't yet received tasks from the + // master, so we should wait. + while (results_.empty() && + (tasks_.empty() || num_unfinished_tasks_ > 0) && !cancelled_) { + cv_.wait(l); + } + if (cancelled_) { + return errors::Cancelled("Data service iterator was cancelled"); + } + if (results_.empty()) { + *end_of_sequence = true; + return Status::OK(); + } + DCHECK(!results_.empty()); + out_tensors->swap(results_.front()); + results_.pop(); + cv_.notify_all(); + + return Status::OK(); + } + + protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); + } + + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { + return errors::Unimplemented("SaveInternal is not yet supported"); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + return errors::Unimplemented("RestoreInternal is not yet supported"); + } + + private: + typedef struct TaskThread { + int64 task_id; + // Cached address of the worker for task `task_id`. + std::string address; + std::unique_ptr worker_stub; + std::unique_ptr thread; + bool end_of_sequence = false; + } TaskThread; + + // Periodically refresh the task list. + // Maintain one thread fetching elements for each task. + // TODO(aaudibert): Instead of polling, have master send updates when + // the list of tasks changes. + void TaskThreadManager(IteratorContext* ctx) { + VLOG(3) << "Starting task handler manager"; + auto channel = ::grpc::CreateChannel(dataset()->address_, credentials_); + std::unique_ptr master_stub = + MasterService::NewStub(channel); + + uint64 next_check = Env::Default()->NowMicros(); + while (true) { + { + mutex_lock l(mu_); + // All units are microseconds. + while (!cancelled_ && Env::Default()->NowMicros() < next_check) { + int64 remaining_time = next_check - Env::Default()->NowMicros(); + VLOG(3) << "Task manager waiting for " << remaining_time << "us"; + cv_.wait_for(l, std::chrono::microseconds(remaining_time)); + } + if (cancelled_) { + return; + } + } + UpdateTaskThreads(master_stub.get(), ctx); + next_check = Env::Default()->NowMicros() + kRefreshTasksIntervalMicros; + } + } + + void UpdateTaskThreads(MasterService::Stub* master_stub, + IteratorContext* ctx) LOCKS_EXCLUDED(mu_) { + VLOG(3) << "Updating task handler threads"; + GetTasksResponse resp; + GetTasksRequest req; + req.set_epoch_id(epoch_id_); + grpc::ClientContext client_ctx; + grpc::Status s = master_stub->GetTasks(&client_ctx, req, &resp); + if (!s.ok()) { + LOG(INFO) << "Failed to get task info for epoch id " << epoch_id_ + << ": " << s.error_message() << "(" << s.error_code() << ")"; + return; + } + absl::flat_hash_set task_ids; + mutex_lock l(mu_); + for (auto& task : resp.task_info()) { + task_ids.insert(task.id()); + if (task_threads_.contains(task.id())) { + continue; + } + tasks_[task.id()] = task; + task_threads_[task.id()] = absl::make_unique(); + TaskThread* task_handler = task_threads_[task.id()].get(); + task_handler->task_id = task.id(); + num_unfinished_tasks_++; + task_handler->thread = ctx->StartThread( + "tf-data-service-task_handler", + [this, task_handler]() { RunTaskThread(task_handler); }); + } + // Mark deleted tasks and clean up finished task threads. + for (auto it = task_threads_.begin(); it != task_threads_.end();) { + TaskThread* task_thread = it->second.get(); + if (task_thread->end_of_sequence) { + task_threads_.erase(it++); + continue; + } + if (!task_ids.contains(task_thread->task_id)) { + task_thread->end_of_sequence = true; + } + ++it; + } + if (dataset()->max_outstanding_requests_ == model::kAutotune) { + // Adjust max_outstanding_requests to account for newly added tasks. + max_outstanding_requests_ = task_threads_.size(); + } + } + + void RunTaskThread(TaskThread* task_handler) { + auto cleanup = gtl::MakeCleanup([this]() { + mutex_lock l(mu_); + outstanding_requests_--; + num_unfinished_tasks_--; + cv_.notify_all(); + }); + { + mutex_lock l(mu_); + outstanding_requests_++; + task_handler->address = tasks_[task_handler->task_id].worker_address(); + } + VLOG(3) << "Starting task handler thread for task " + << task_handler->task_id << " with worker address " + << task_handler->address; + while (true) { + if (!task_handler->worker_stub) { + Status s = CreateWorkerStub(task_handler->address, + &task_handler->worker_stub); + if (!s.ok()) { + LOG(WARNING) << "Failed to create a worker stub for " + << task_handler->address << ": " << s; + } + } + { + mutex_lock l(mu_); + if (task_handler->end_of_sequence) { + return; + } + outstanding_requests_--; + while (!cancelled_ && results_.size() + outstanding_requests_ >= + max_outstanding_requests_) { + VLOG(3) << "Task thread for task " << task_handler->task_id + << " waiting. results_.size()=" << results_.size() + << " outstanding_requests_=" << outstanding_requests_; + cv_.wait(l); + } + outstanding_requests_++; + if (cancelled_) { + return; + } + } + // TODO(aaudibert): add backoff and max retries. + int64 deadline_micros = + Env::Default()->NowMicros() + kRetryTimeoutMicros; + Status s = FetchElement(task_handler, deadline_micros); + if (!s.ok()) { + LOG(WARNING) << "Failed to fetch element from worker at " + << task_handler->address << ": " << s; + } + } + } + + Status FetchElement(TaskThread* task_handler, int64 deadline_micros) { + VLOG(3) << "Fetchng an element for task id " << task_handler->task_id; + GetElementResponse resp; + TF_RETURN_IF_ERROR( + GetElementWithDeadline(task_handler, &resp, deadline_micros)); + std::vector element; + if (!resp.end_of_sequence()) { + TF_RETURN_IF_ERROR( + service_util::Uncompress(resp.compressed_element(), &element)); + } + mutex_lock l(mu_); + if (resp.end_of_sequence()) { + task_handler->end_of_sequence = true; + return Status::OK(); + } + results_.push(std::move(element)); + cv_.notify_all(); + VLOG(3) << "Fetched an element for task id " << task_handler->task_id; + return Status::OK(); + } + + Status CreateWorkerStub(const std::string& worker_address, + std::unique_ptr* stub) { + ::grpc::ChannelArguments args; + args.SetMaxReceiveMessageSize(-1); + std::shared_ptr<::grpc::ChannelCredentials> credentials; + TF_RETURN_IF_ERROR(CredentialsFactory::CreateClientCredentials( + dataset()->protocol_, &credentials)); + auto channel = + ::grpc::CreateCustomChannel(worker_address, credentials, args); + *stub = WorkerService::NewStub(channel); + return Status::OK(); + } + + Status GetElementWithDeadline(TaskThread* task_handler, + GetElementResponse* resp, + int64 deadline_micros) { + return RetryWithDeadline( + [task_handler, resp] { + GetElementRequest req; + req.set_task_id(task_handler->task_id); + grpc::ClientContext client_ctx; + grpc::Status s = + task_handler->worker_stub->GetElement(&client_ctx, req, resp); + if (s.ok()) { + return Status::OK(); + } + return grpc_util::WrapError("Failed to fetch an element", s); + }, + deadline_micros); + } + + static bool ShouldRetryError(error::Code error_code) { + // Retry all errors that could indicate preemption. + return error_code == error::Code::UNAVAILABLE || + error_code == error::Code::CANCELLED || + error_code == error::Code::ABORTED; + } + + static Status RetryWithDeadline(const std::function& call, + int64 deadline_micros) { + Status s; + for (int num_retries = 0;; ++num_retries) { + s = call(); + if (s.ok() || !ShouldRetryError(s.code())) { + return s; + } + const int64 now_micros = EnvTime::NowMicros(); + if (now_micros > deadline_micros) { + return s; + } + const int64 deadline_with_backoff_micros = + now_micros + ::tensorflow::ComputeBackoffMicroseconds(num_retries); + // Wait for a short period of time before retrying the RPC. If our + // backoff would put us past the RPC deadline, we truncate it to ensure + // our RPC starts before the deadline. + const auto backoff_until = + (deadline_micros > deadline_with_backoff_micros) + ? deadline_with_backoff_micros + : deadline_micros; + Env::Default()->SleepForMicroseconds(backoff_until - now_micros); + } + } + + mutex mu_; + // TODO(aaudibert): split this into a couple cvs for different conditions + // so that we can use notify_one and avoid unnecessary wakeups. + condition_variable cv_ TF_GUARDED_BY(mu_); + bool cancelled_ TF_GUARDED_BY(mu_) = false; + + int64 outstanding_requests_ TF_GUARDED_BY(mu_) = 0; + // max_outstanding_requests controls how many elements may be held in memory + // at the same time. This count includes both in-progress requests for + // elements as well as completed requests which haven't yet been produced. + int64 max_outstanding_requests_ TF_GUARDED_BY(mu_); + std::queue> results_ TF_GUARDED_BY(mu_); + + // Set once in Initialize(). + int64 epoch_id_; + std::shared_ptr<::grpc::ChannelCredentials> credentials_; + int64 num_unfinished_tasks_ TF_GUARDED_BY(mu_) = 0; + // Map from task id to task info. + absl::flat_hash_map tasks_ TF_GUARDED_BY(mu_); + + // Must come second to last so that task threads are joined before + // destroying other fields. + absl::flat_hash_map> task_threads_ + TF_GUARDED_BY(mu_); + // Must be ordered last so that the thread is joined before destroying other + // fields. + std::unique_ptr task_thread_manager_ GUARDED_BY(mu_); + }; + + const tstring address_; + const tstring protocol_; + const int64 max_outstanding_requests_; + const DataTypeVector output_types_; + const std::vector output_shapes_; +}; + +DataServiceDatasetOp::DataServiceDatasetOp(OpKernelConstruction* ctx) + : DatasetOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_)); +} + +void DataServiceDatasetOp::MakeDataset(OpKernelContext* ctx, + DatasetBase** output) { + tstring address; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kAddress, &address)); + OP_REQUIRES(ctx, !address.empty(), + errors::InvalidArgument(kAddress, " must be non-empty.")); + + tstring protocol; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kProtocol, &protocol)); + OP_REQUIRES(ctx, !protocol.empty(), + errors::InvalidArgument(kProtocol, " must be non-empty.")); + + int64 max_outstanding_requests; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kMaxOutstandingRequests, + &max_outstanding_requests)); + OP_REQUIRES( + ctx, + max_outstanding_requests == model::kAutotune || + max_outstanding_requests > 0, + errors::InvalidArgument(kMaxOutstandingRequests, " must be positive or ", + model::kAutotune)); + + *output = new Dataset(ctx, address, protocol, max_outstanding_requests, + output_types_, output_shapes_); +} + +REGISTER_KERNEL_BUILDER(Name("DataServiceDataset").Device(DEVICE_CPU), + DataServiceDatasetOp); + +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.h b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.h new file mode 100644 index 00000000000..b0356fd53bf --- /dev/null +++ b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.h @@ -0,0 +1,50 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DATA_SERVICE_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DATA_SERVICE_DATASET_OP_H_ + +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { + +// Creates a dataset for reading from the tf.data service. +class DataServiceDatasetOp : public DatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "DataService"; + static constexpr const char* const kAddress = "address"; + static constexpr const char* const kEpochId = "epoch_id"; + static constexpr const char* const kProtocol = "protocol"; + static constexpr const char* const kMaxOutstandingRequests = + "max_outstanding_requests"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + + explicit DataServiceDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override; + + private: + class Dataset; + + DataTypeVector output_types_; + std::vector output_shapes_; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DATA_SERVICE_DATASET_OP_H_ diff --git a/tensorflow/core/kernels/mkl_avgpooling_op.cc b/tensorflow/core/kernels/mkl_avgpooling_op.cc index 0751fff4c26..47b999e0e11 100644 --- a/tensorflow/core/kernels/mkl_avgpooling_op.cc +++ b/tensorflow/core/kernels/mkl_avgpooling_op.cc @@ -233,6 +233,13 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase { : memory::desc(orig_input_dims_mkl_order, MklDnnType(), this->data_format_mkldnn_); + // Get diff_dst memory::desc. + memory::desc diff_dst_md = + grad_mkl_shape.IsMklTensor() + ? grad_mkl_shape.GetMklLayout() + : memory::desc(diff_dst_dims, MklDnnType(), + this->data_format_mkldnn_); + // Pass prop_kind::forward_training to create a forward primitive // that is used in the backward pass. #ifdef ENABLE_MKLDNN_V1 @@ -241,7 +248,8 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase { orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims, strides, padding_left, padding_right, ALGORITHM::pooling_avg_exclude_padding, prop_kind::forward_training, - static_cast(this->data_format_mkldnn_), src_md); + static_cast(this->data_format_mkldnn_), src_md, + diff_dst_md); #else MklPoolingParams bwdParams( orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims, @@ -256,12 +264,6 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase { this->AllocateOutputTensor(context, *(pooling_bwd->GetPoolingBwdPd()), orig_input_dims_mkl_order, this->tensor_format_mkldnn_, &output_tensor); - // Get diff_dst memory::desc. - memory::desc diff_dst_md = - grad_mkl_shape.IsMklTensor() - ? grad_mkl_shape.GetMklLayout() - : memory::desc(diff_dst_dims, MklDnnType(), - this->data_format_mkldnn_); // TODO(nammbash): Refactor (lines 249-262) common code for // max & avg pooling into superclass or common utils function. diff --git a/tensorflow/core/kernels/mkl_batch_matmul_op.cc b/tensorflow/core/kernels/mkl_batch_matmul_op.cc index e4be4f8a341..37888656020 100644 --- a/tensorflow/core/kernels/mkl_batch_matmul_op.cc +++ b/tensorflow/core/kernels/mkl_batch_matmul_op.cc @@ -148,35 +148,56 @@ class BatchMatMulMkl : public OpKernel { std::vector ldb_array(batch_size, adj_y_ ? K : N); std::vector ldc_array(batch_size, N); std::vector group_size(1, batch_size); - std::vector a_array; - std::vector b_array; - std::vector c_array; - a_array.reserve(batch_size); - b_array.reserve(batch_size); - c_array.reserve(batch_size); - if (!bcast.IsBroadcastingRequired()) { - for (int64 i = 0; i < batch_size; i++) { - a_array.push_back(&lhs_reshaped(i, 0, 0)); - b_array.push_back(&rhs_reshaped(i, 0, 0)); - c_array.push_back(&out_reshaped(i, 0, 0)); - } + if (std::is_same::value) { + // DNNL bfloat16 API requires a, b, and c as pointers to tensors + // represented as flat-byte array. + const Scalar* a = nullptr; + const Scalar* b = nullptr; + OP_REQUIRES(ctx, !bcast.IsBroadcastingRequired(), + errors::Unimplemented("Broadcasting is not supported for " + "BFloat16 _MklBatchMatMul yet.")); + a = &lhs_reshaped(0, 0, 0); + b = &rhs_reshaped(0, 0, 0); + Scalar* c = &out_reshaped(0, 0, 0); + // TODO(nhasabni): Use appropriate cast instead of passing addresses of + // a,b and c. + MklCblasGemmBatch(CblasRowMajor, adj_x_, adj_y_, m_array, n_array, + k_array, &a, lda_array, &b, ldb_array, &c, ldc_array, 1, + group_size); } else { - // Broadcasting is needed, so get the mapping from flattened output batch - // indices to x's and y's flattened batch indices. - const std::vector& a_batch_indices = bcast.x_batch_indices(); - const std::vector& b_batch_indices = bcast.y_batch_indices(); + std::vector a_array; + std::vector b_array; + std::vector c_array; + a_array.reserve(batch_size); + b_array.reserve(batch_size); + c_array.reserve(batch_size); - for (int64 i = 0; i < batch_size; i++) { - a_array.push_back(&lhs_reshaped(a_batch_indices[i], 0, 0)); - b_array.push_back(&rhs_reshaped(b_batch_indices[i], 0, 0)); - c_array.push_back(&out_reshaped(i, 0, 0)); + if (!bcast.IsBroadcastingRequired()) { + for (int64 i = 0; i < batch_size; i++) { + a_array.push_back(&lhs_reshaped(i, 0, 0)); + b_array.push_back(&rhs_reshaped(i, 0, 0)); + c_array.push_back(&out_reshaped(i, 0, 0)); + } + } else { + // Broadcasting is needed, so get the mapping from flattened output + // batch indices to x's and y's flattened batch indices. + const std::vector& a_batch_indices = bcast.x_batch_indices(); + const std::vector& b_batch_indices = bcast.y_batch_indices(); + + for (int64 i = 0; i < batch_size; i++) { + a_array.push_back(&lhs_reshaped(a_batch_indices[i], 0, 0)); + b_array.push_back(&rhs_reshaped(b_batch_indices[i], 0, 0)); + c_array.push_back(&out_reshaped(i, 0, 0)); + } } - } - MklCblasGemmBatch(CblasRowMajor, adj_x_, adj_y_, m_array, n_array, k_array, - &a_array[0], lda_array, &b_array[0], ldb_array, - &c_array[0], ldc_array, 1, group_size); + // MKL CBLAS API requires a, b, and c as array of pointers, where each + // pointer is to 2D matrix. + MklCblasGemmBatch(CblasRowMajor, adj_x_, adj_y_, m_array, n_array, + k_array, &a_array[0], lda_array, &b_array[0], ldb_array, + &c_array[0], ldc_array, 1, group_size); + } } private: @@ -269,10 +290,11 @@ class BatchMatMulMkl : public OpKernel { std::vector TransB_Array(group_size[0], TransB); std::vector alpha_Array(group_size[0], 1.0); std::vector beta_Array(group_size[0], 0.0); + // TODO(nhasabni): Remove *A when we pass a, b, and c correctly. + // MKLDNN API does not require lda, ldb, and ldc. dnnl_gemm_batch(TransA_Array, TransB_Array, M_Array, N_Array, - K_Array, alpha_Array, A_Array, lda_Array, B_Array, - ldb_Array, beta_Array, C_Array, ldc_Array, - group_count, group_size); + K_Array, alpha_Array, *A_Array, *B_Array, + beta_Array, *C_Array, group_count, group_size); } #endif // ENABLE_MKLDNN_V1 && ENABLE_INTEL_MKL_BFLOAT16 }; @@ -302,10 +324,10 @@ TF_CALL_double(REGISTER_BATCH_MATMUL_MKL_V2); TF_CALL_complex64(REGISTER_BATCH_MATMUL_MKL_V2); TF_CALL_complex128(REGISTER_BATCH_MATMUL_MKL_V2); -#if defined(ENABLE_INTEL_MKLDNN_V1) && defined(ENABLE_INTEL_MKL_BFLOAT16) +#if defined(ENABLE_MKLDNN_V1) && defined(ENABLE_INTEL_MKL_BFLOAT16) TF_CALL_bfloat16(REGISTER_BATCH_MATMUL_MKL); TF_CALL_bfloat16(REGISTER_BATCH_MATMUL_MKL_V2); -#endif // ENABLE_INTEL_MKLDNN_V1 && ENABLE_INTEL_MKL_BFLOAT16 +#endif // ENABLE_MKLDNN_V1 && ENABLE_INTEL_MKL_BFLOAT16 #endif // ENABLE_MKL } // end namespace tensorflow diff --git a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc index 40e4825c0fa..115b3597964 100644 --- a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc +++ b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc @@ -1008,7 +1008,7 @@ class MklFusedBatchNormOp : public OpKernel { tf_shape_scale, mkl_shape_saved_mean); DCHECK(*saved_mean_tensor); - // Set NAN mean value in case of empty input tensor + // Set 0 mean value in case of empty input tensor auto saved_mean_data = (*saved_mean_tensor)->flat().data(); std::fill_n(saved_mean_data, num_elements, static_cast(0)); @@ -1019,7 +1019,7 @@ class MklFusedBatchNormOp : public OpKernel { mkl_shape_saved_variance); DCHECK(*saved_variance_tensor); - // Set NAN variance value in case of empty input tensor + // Set 0 variance value in case of empty input tensor auto saved_variance_data = (*saved_variance_tensor)->flat().data(); std::fill_n(saved_variance_data, num_elements, static_cast(0)); @@ -1346,16 +1346,12 @@ class MklFusedBatchNormGradOp : public OpKernel { mkl_shape_p.SetMklTensor(false); AllocateOutputSetMklShape(context, kP1Index, &p1_tensor, TensorShape({}), mkl_shape_p); -#ifndef ENABLE_MKLDNN_V1 std::fill_n(p1_tensor->flat().data(), p1_tensor->shape().num_elements(), static_cast(0)); -#endif // !ENABLE_MKLDNN_V1 AllocateOutputSetMklShape(context, kP2Index, &p2_tensor, TensorShape({}), mkl_shape_p); -#ifndef ENABLE_MKLDNN_V1 std::fill_n(p2_tensor->flat().data(), p2_tensor->shape().num_elements(), static_cast(0)); -#endif // !ENABLE_MKLDNN_V1 } memory::dims GetMeanVarianceDims() { return memory::dims({1, depth_}); } diff --git a/tensorflow/core/kernels/mkl_matmul_op.cc b/tensorflow/core/kernels/mkl_matmul_op.cc index 98ee577f807..3a7c864d10e 100644 --- a/tensorflow/core/kernels/mkl_matmul_op.cc +++ b/tensorflow/core/kernels/mkl_matmul_op.cc @@ -181,17 +181,18 @@ class MklMatMulOp : public OpKernel { const int ldc) { const float alpha = 1.0f; const float beta = 0.0f; - const char* const ftrans[] = {"N", "T", "C"}; const int index_transa = transa ? 1 : 0; const int index_transb = transb ? 1 : 0; -#ifdef ENABLE_MKLDNN_V1 - dnnl_gemm(transa ? CblasTrans : CblasNoTrans, - transb ? CblasTrans : CblasNoTrans, m, n, k, alpha, a, - lda, b, ldb, beta, c, ldc); -#else Tensor c_float; OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, {m, n}, &c_float)); +#ifdef ENABLE_MKLDNN_V1 + const char ftrans[] = {'N', 'T', 'C'}; + dnnl_gemm(ftrans[index_transa], ftrans[index_transb], m, n, k, + alpha, a, lda, b, ldb, beta, + c_float.flat().data(), ldc); +#else + const char* const ftrans[] = {"N", "T", "C"}; // MKL-DNN only supports the Fortran API and requires column major while // Tensorflow uses row major so we reverse the order of A and B. @@ -200,9 +201,8 @@ class MklMatMulOp : public OpKernel { reinterpret_cast(b), &ldb, reinterpret_cast(a), &lda, &beta, c_float.flat().data(), &ldc); - - FloatToBFloat16(c_float.flat().data(), c, c_float.NumElements()); #endif // ENABLE_MKLDNN_V1 + FloatToBFloat16(c_float.flat().data(), c, c_float.NumElements()); } #endif // ENABLE_INTEL_MKL_BFLOAT16 diff --git a/tensorflow/core/kernels/mkl_matmul_op_fused.cc b/tensorflow/core/kernels/mkl_matmul_op_fused.cc index 18f6667fd1e..99a2cfc214b 100644 --- a/tensorflow/core/kernels/mkl_matmul_op_fused.cc +++ b/tensorflow/core/kernels/mkl_matmul_op_fused.cc @@ -243,6 +243,7 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase { .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ MklFusedMatMulOp); TF_CALL_float(REGISTER_FUSEDMATMUL_MKL_SUPPORTED_KERNELS_TYPES); +TF_CALL_bfloat16(REGISTER_FUSEDMATMUL_MKL_SUPPORTED_KERNELS_TYPES); } // namespace tensorflow diff --git a/tensorflow/core/kernels/mkl_matmul_ops_common.h b/tensorflow/core/kernels/mkl_matmul_ops_common.h index ffe50dd4022..ab816ce73fa 100644 --- a/tensorflow/core/kernels/mkl_matmul_ops_common.h +++ b/tensorflow/core/kernels/mkl_matmul_ops_common.h @@ -548,16 +548,14 @@ template void dnnl_gemm_batch(const std::vector& transa, const std::vector& transb, const std::vector& m, const std::vector& n, const std::vector& k, - const std::vector& alpha, const T** a, - const std::vector& lda, const T** b, - const std::vector& ldb, - const std::vector& beta, T** c, - const std::vector& ldc, const int group_count, + const std::vector& alpha, const T* a, const T* b, + const std::vector& beta, T* c, + const int group_count, const std::vector& group_size) { // Current BatchMatMul support in Tensorflow is narrower than the one offered // by MKL and MKL-DNN. Current BatchMatMul support in Tensorflow uses only 1 // group of size equal to batch_size, and all MatMul parameters (m, n, k, - // lda, ldb, ldc, alpha, beta) within that group are same. + // alpha, beta) within that group are same. DCHECK(group_size.size() == 1); DCHECK(transa.size() == group_size[0]); DCHECK(transb.size() == group_size[0]); @@ -566,9 +564,6 @@ void dnnl_gemm_batch(const std::vector& transa, DCHECK(m.size() == group_size[0]); DCHECK(n.size() == group_size[0]); DCHECK(k.size() == group_size[0]); - DCHECK(lda.size() == group_size[0]); - DCHECK(ldb.size() == group_size[0]); - DCHECK(ldc.size() == group_size[0]); for (int64_t idx = 0; idx < group_size[0]; idx++) DCHECK(transa[0] == transa[idx]); for (int64_t idx = 0; idx < group_size[0]; idx++) @@ -580,21 +575,24 @@ void dnnl_gemm_batch(const std::vector& transa, for (int64_t idx = 0; idx < group_size[0]; idx++) DCHECK(m[0] == m[idx]); for (int64_t idx = 0; idx < group_size[0]; idx++) DCHECK(n[0] == n[idx]); for (int64_t idx = 0; idx < group_size[0]; idx++) DCHECK(k[0] == k[idx]); - for (int64_t idx = 0; idx < group_size[0]; idx++) DCHECK(lda[0] == lda[idx]); - for (int64_t idx = 0; idx < group_size[0]; idx++) DCHECK(ldb[0] == ldb[idx]); - for (int64_t idx = 0; idx < group_size[0]; idx++) DCHECK(ldc[0] == ldc[idx]); using dims = mkldnn::memory::dims; // Prepare strides based on the transa and transb flags: transposed // matrices have strides swapped BatchMatMul in MKL-DNN supports 3D metrices // so far. That is why strides are 3D also. - dims a_strides = transa[0] ? dims{lda[0], 1, 1} : dims{1, 1, lda[0]}; - dims b_strides = transb[0] ? dims{ldb[0], 1, 1} : dims{1, 1, ldb[0]}; - dims c_strides = dims{ldc[0], 1, 1}; + dims a_sizes = dims{group_size[0], m[0], k[0]}; + dims b_sizes = dims{group_size[0], k[0], n[0]}; + dims c_sizes = dims{group_size[0], m[0], n[0]}; + dims a_strides = + !transa[0] ? dims{m[0] * k[0], k[0], 1} : dims{k[0] * m[0], 1, m[0]}; + dims b_strides = + !transb[0] ? dims{k[0] * n[0], n[0], 1} : dims{n[0] * k[0], 1, k[0]}; + dims c_strides = dims{m[0] * n[0], n[0], 1}; + // Prepare memory descriptors - memory::desc a_md({group_size[0], m[0], k[0]}, MklDnnType(), a_strides); - memory::desc b_md({group_size[0], k[0], n[0]}, MklDnnType(), b_strides); - memory::desc c_md({group_size[0], m[0], n[0]}, MklDnnType(), c_strides); + memory::desc a_md(a_sizes, MklDnnType(), a_strides); + memory::desc b_md(b_sizes, MklDnnType(), b_strides); + memory::desc c_md(c_sizes, MklDnnType(), c_strides); // Create attributes (to handle alpha and beta if necessary) mkldnn::primitive_attr attr; if (alpha[0] != 1.f) attr.set_output_scales(/* mask */ 0, {alpha[0]}); @@ -610,7 +608,7 @@ void dnnl_gemm_batch(const std::vector& transa, template void dnnl_gemm(char transa, char transb, int64_t m, int64_t n, int64_t k, float alpha, const T* a, int64_t lda, const T* b, int64_t ldb, - float beta, T* c, int64_t ldc) { + float beta, float* c, int64_t ldc) { using dims = mkldnn::memory::dims; // Prepare strides based on the transa and transb flags: transposed // matrices have strides swapped @@ -619,7 +617,7 @@ void dnnl_gemm(char transa, char transb, int64_t m, int64_t n, int64_t k, // Prepare memory descriptors memory::desc a_md({m, k}, MklDnnType(), a_strides); memory::desc b_md({k, n}, MklDnnType(), b_strides); - memory::desc c_md({m, n}, MklDnnType(), {ldc, 1}); + memory::desc c_md({m, n}, MklDnnType(), {ldc, 1}); // Create attributes (to handle alpha and beta if necessary) mkldnn::primitive_attr attr; if (alpha != 1.f) attr.set_output_scales(/* mask */ 0, {alpha}); diff --git a/tensorflow/core/kernels/mkl_maxpooling_op.cc b/tensorflow/core/kernels/mkl_maxpooling_op.cc index 098ea049246..dbccb35b88b 100644 --- a/tensorflow/core/kernels/mkl_maxpooling_op.cc +++ b/tensorflow/core/kernels/mkl_maxpooling_op.cc @@ -297,13 +297,21 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase { : memory::desc(orig_input_dims_mkl_order, MklDnnType(), this->data_format_mkldnn_); + // Get diff_dst memory descriptor. + memory::desc diff_dst_md = + grad_mkl_shape.IsMklTensor() + ? grad_mkl_shape.GetMklLayout() + : memory::desc(diff_dst_dims, MklDnnType(), + this->data_format_mkldnn_); + #ifdef ENABLE_MKLDNN_V1 // TODO(DNNL): Find out what should be used for src_md.data.format. MklPoolingParams bwdParams( orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims, strides, padding_left, padding_right, ALGORITHM::pooling_max, prop_kind::forward_training, - static_cast(this->data_format_mkldnn_), src_md); + static_cast(this->data_format_mkldnn_), src_md, + diff_dst_md); #else MklPoolingParams bwdParams( orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims, @@ -320,13 +328,6 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase { orig_input_dims_mkl_order, this->tensor_format_mkldnn_, &output_tensor); - // Get diff_dst memory descriptor. - memory::desc diff_dst_md = - grad_mkl_shape.IsMklTensor() - ? grad_mkl_shape.GetMklLayout() - : memory::desc(diff_dst_dims, MklDnnType(), - this->data_format_mkldnn_); - // Check if diff_dst needs to be reordered. std::shared_ptr pooling_bwd_pd = pooling_bwd->GetPoolingBwdPd(); diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.cc b/tensorflow/core/kernels/mkl_pooling_ops_common.cc index 438721f85fd..5bd9c17f95e 100644 --- a/tensorflow/core/kernels/mkl_pooling_ops_common.cc +++ b/tensorflow/core/kernels/mkl_pooling_ops_common.cc @@ -172,8 +172,12 @@ void MklPoolingBwdPrimitive::Setup(const MklPoolingParams& bwdParams) { // Create memory descriptor. context_.diff_src_md.reset(new memory::desc( {bwdParams.src_dims}, MklDnnType(), MEMORY_FORMAT::any)); +#ifndef ENABLE_MKLDNN_V1 context_.diff_dst_md.reset(new memory::desc( {bwdParams.dst_dims}, MklDnnType(), bwdParams.src_format)); +#else + context_.diff_dst_md.reset(new memory::desc(bwdParams.diff_dst_md.data)); +#endif // !ENABLE_MKLDNN_V1 #ifndef ENABLE_MKLDNN_V1 context_.bwd_desc.reset(new pooling_backward::desc( diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.h b/tensorflow/core/kernels/mkl_pooling_ops_common.h index ff51282ecc6..54f4dc8503e 100644 --- a/tensorflow/core/kernels/mkl_pooling_ops_common.h +++ b/tensorflow/core/kernels/mkl_pooling_ops_common.h @@ -49,12 +49,20 @@ struct MklPoolingParams { mkldnn::prop_kind prop_kind; MEMORY_FORMAT src_format; memory::desc src_md; +#ifdef ENABLE_MKLDNN_V1 + memory::desc diff_dst_md; +#endif // ENABLE_MKLDNN_V1 MklPoolingParams(memory::dims src_dims, memory::dims dst_dims, memory::dims filter_dims, memory::dims strides, memory::dims padding_left, memory::dims padding_right, mkldnn::algorithm alg_kind, mkldnn::prop_kind prop_kind, +#ifdef ENABLE_MKLDNN_V1 + MEMORY_FORMAT src_format, memory::desc src_md, + memory::desc diff_dst_md = memory::desc()) +#else MEMORY_FORMAT src_format, memory::desc src_md) +#endif // ENABLE_MKLDNN_V1 : src_dims(src_dims), dst_dims(dst_dims), filter_dims(filter_dims), @@ -64,7 +72,14 @@ struct MklPoolingParams { alg_kind(alg_kind), prop_kind(prop_kind), src_format(src_format), - src_md(src_md) {} +#ifdef ENABLE_MKLDNN_V1 + src_md(src_md), + diff_dst_md(diff_dst_md) { + } +#else + src_md(src_md) { + } +#endif // ENABLE_MKLDNN_V1 }; template diff --git a/tensorflow/core/kernels/mkl_tmp_bf16_ops.cc b/tensorflow/core/kernels/mkl_tmp_bf16_ops.cc index 29aaf3f363f..7f45979a57e 100644 --- a/tensorflow/core/kernels/mkl_tmp_bf16_ops.cc +++ b/tensorflow/core/kernels/mkl_tmp_bf16_ops.cc @@ -52,7 +52,11 @@ namespace tensorflow { .Device(DEVICE_CPU) \ .TypeConstraint("T") \ .TypeConstraint("U"), \ - NoOp); + NoOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("_FusedMatMul").Device(DEVICE_CPU).TypeConstraint("T"), NoOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("BatchMatMulV2").Device(DEVICE_CPU).TypeConstraint("T"), NoOp); TF_CALL_bfloat16(REGISTER_CPU); #undef REGISTER_CPU diff --git a/tensorflow/core/kernels/tile_functor_cpu.cc b/tensorflow/core/kernels/tile_functor_cpu.h similarity index 60% rename from tensorflow/core/kernels/tile_functor_cpu.cc rename to tensorflow/core/kernels/tile_functor_cpu.h index 2a5fb3f62d6..5b005e4a8b4 100644 --- a/tensorflow/core/kernels/tile_functor_cpu.cc +++ b/tensorflow/core/kernels/tile_functor_cpu.h @@ -12,17 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_TILE_FUNCTOR_CPU_H_ +#define TENSORFLOW_CORE_KERNELS_TILE_FUNCTOR_CPU_H_ #define EIGEN_USE_THREADS -#include "tensorflow/core/framework/attr_value.pb.h" -#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/kernels/tile_functor.h" namespace tensorflow { namespace internal { -namespace { template void TileSimpleImpl(const Device& d, Tensor* out, const Tensor& in) { @@ -44,8 +43,6 @@ void TileSimpleImpl(const Device& d, Tensor* out, const Tensor& in) { } } -} // namespace - template void TileSimple(const Eigen::ThreadPoolDevice& d, Tensor* out, const Tensor& in) { @@ -59,50 +56,6 @@ void TileSimple(const Eigen::SyclDevice& d, Tensor* out, const Tensor& in) { #endif } // namespace internal - -namespace functor { - -typedef Eigen::ThreadPoolDevice CPUDevice; - -// Register functors used for Tile functor. -#define DEFINE_TYPE(T) \ - template struct Tile; \ - template struct Tile; - -TF_CALL_bool(DEFINE_TYPE); -TF_CALL_float(DEFINE_TYPE); -TF_CALL_bfloat16(DEFINE_TYPE); -TF_CALL_double(DEFINE_TYPE); -TF_CALL_uint8(DEFINE_TYPE); -TF_CALL_int8(DEFINE_TYPE); -TF_CALL_int32(DEFINE_TYPE); -TF_CALL_int16(DEFINE_TYPE); -TF_CALL_int64(DEFINE_TYPE); -TF_CALL_half(DEFINE_TYPE); -TF_CALL_complex64(DEFINE_TYPE); -TF_CALL_complex128(DEFINE_TYPE); -TF_CALL_tstring(DEFINE_TYPE); - -#undef DEFINE_TYPE - -#ifdef TENSORFLOW_USE_SYCL -typedef Eigen::SyclDevice SYCLDevice; - -#define DEFINE_TYPE(T) \ - template struct Tile; \ - template struct Tile; - -TF_CALL_bool(DEFINE_TYPE); -TF_CALL_float(DEFINE_TYPE); -TF_CALL_bfloat16(DEFINE_TYPE); -TF_CALL_double(DEFINE_TYPE); -TF_CALL_uint8(DEFINE_TYPE); -TF_CALL_int32(DEFINE_TYPE); -TF_CALL_int16(DEFINE_TYPE); -TF_CALL_int64(DEFINE_TYPE); - -#undef DEFINE_TYPE -#endif // TENSORFLOW_USE_SYCL - -} // end namespace functor } // end namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_TILE_FUNCTOR_CPU_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/profiler/test_instrumented_library.h b/tensorflow/core/kernels/tile_functor_cpu_bfloat16.cc similarity index 59% rename from tensorflow/lite/experimental/ruy/ruy/profiler/test_instrumented_library.h rename to tensorflow/core/kernels/tile_functor_cpu_bfloat16.cc index d6a80a09042..aaac5ada99f 100644 --- a/tensorflow/lite/experimental/ruy/ruy/profiler/test_instrumented_library.h +++ b/tensorflow/core/kernels/tile_functor_cpu_bfloat16.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 Google LLC. All Rights Reserved. +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,11 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TEST_INSTRUMENTED_LIBRARY_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TEST_INSTRUMENTED_LIBRARY_H_ +#define EIGEN_USE_THREADS -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "tensorflow/core/kernels/tile_functor_cpu.h" -void MergeSort(int size, int* data); +namespace tensorflow { +namespace functor { -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TEST_INSTRUMENTED_LIBRARY_H_ +typedef Eigen::ThreadPoolDevice CPUDevice; + +template struct Tile; +template struct Tile; + +} // end namespace functor +} // end namespace tensorflow diff --git a/tensorflow/lite/experimental/ruy/ruy/detect_arm.h b/tensorflow/core/kernels/tile_functor_cpu_bool.cc similarity index 56% rename from tensorflow/lite/experimental/ruy/ruy/detect_arm.h rename to tensorflow/core/kernels/tile_functor_cpu_bool.cc index 9a1542d3cce..3ef6e6a9f72 100644 --- a/tensorflow/lite/experimental/ruy/ruy/detect_arm.h +++ b/tensorflow/core/kernels/tile_functor_cpu_bool.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,17 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// Temporary dotprod-detection code until we can rely on getauxval. +#define EIGEN_USE_THREADS -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DETECT_ARM_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DETECT_ARM_H_ +#include "tensorflow/core/kernels/tile_functor_cpu.h" -namespace ruy { +namespace tensorflow { +namespace functor { -// On A64, returns true if the dotprod extension is present. -// On other architectures, returns false unconditionally. -bool DetectDotprod(); +typedef Eigen::ThreadPoolDevice CPUDevice; -} // namespace ruy +template struct Tile; +template struct Tile; -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DETECT_ARM_H_ +} // end namespace functor +} // end namespace tensorflow diff --git a/tensorflow/lite/experimental/ruy/ruy/have_built_path_for.h b/tensorflow/core/kernels/tile_functor_cpu_complex128.cc similarity index 53% rename from tensorflow/lite/experimental/ruy/ruy/have_built_path_for.h rename to tensorflow/core/kernels/tile_functor_cpu_complex128.cc index 08651facb7e..542b3e21062 100644 --- a/tensorflow/lite/experimental/ruy/ruy/have_built_path_for.h +++ b/tensorflow/core/kernels/tile_functor_cpu_complex128.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,20 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_HAVE_BUILT_PATH_FOR_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_HAVE_BUILT_PATH_FOR_H_ +#define EIGEN_USE_THREADS -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" +#include "tensorflow/core/kernels/tile_functor_cpu.h" -namespace ruy { +namespace tensorflow { +namespace functor { -#if RUY_PLATFORM(X86) -bool HaveBuiltPathForSse42(); -bool HaveBuiltPathForAvx2(); -bool HaveBuiltPathForAvx512(); -bool HaveBuiltPathForAvxVnni(); -#endif // RUY_PLATFORM(X86) +typedef Eigen::ThreadPoolDevice CPUDevice; -} // namespace ruy +template struct Tile; +template struct Tile; -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_HAVE_BUILT_PATH_FOR_H_ +} // end namespace functor +} // end namespace tensorflow diff --git a/tensorflow/lite/experimental/ruy/ruy/kernel.h b/tensorflow/core/kernels/tile_functor_cpu_complex64.cc similarity index 51% rename from tensorflow/lite/experimental/ruy/ruy/kernel.h rename to tensorflow/core/kernels/tile_functor_cpu_complex64.cc index dd9a60b8d09..d97e98d2fdd 100644 --- a/tensorflow/lite/experimental/ruy/ruy/kernel.h +++ b/tensorflow/core/kernels/tile_functor_cpu_complex64.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,19 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_H_ +#define EIGEN_USE_THREADS -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" +#include "tensorflow/core/kernels/tile_functor_cpu.h" -// IWYU pragma: begin_exports -#if RUY_PLATFORM(NEON) -#include "tensorflow/lite/experimental/ruy/ruy/kernel_arm.h" -#elif RUY_PLATFORM(X86) -#include "tensorflow/lite/experimental/ruy/ruy/kernel_x86.h" -#else -#include "tensorflow/lite/experimental/ruy/ruy/kernel_common.h" -#endif -// IWYU pragma: end_exports +namespace tensorflow { +namespace functor { -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_H_ +typedef Eigen::ThreadPoolDevice CPUDevice; + +template struct Tile; +template struct Tile; + +} // end namespace functor +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/tile_functor_cpu_double.cc b/tensorflow/core/kernels/tile_functor_cpu_double.cc new file mode 100644 index 00000000000..1fb9618257a --- /dev/null +++ b/tensorflow/core/kernels/tile_functor_cpu_double.cc @@ -0,0 +1,29 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/tile_functor_cpu.h" + +namespace tensorflow { +namespace functor { + +typedef Eigen::ThreadPoolDevice CPUDevice; + +template struct Tile; +template struct Tile; + +} // end namespace functor +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/tile_functor_cpu_float.cc b/tensorflow/core/kernels/tile_functor_cpu_float.cc new file mode 100644 index 00000000000..047004eb4e1 --- /dev/null +++ b/tensorflow/core/kernels/tile_functor_cpu_float.cc @@ -0,0 +1,29 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/tile_functor_cpu.h" + +namespace tensorflow { +namespace functor { + +typedef Eigen::ThreadPoolDevice CPUDevice; + +template struct Tile; +template struct Tile; + +} // end namespace functor +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/tile_functor_cpu_half.cc b/tensorflow/core/kernels/tile_functor_cpu_half.cc new file mode 100644 index 00000000000..0b63039c942 --- /dev/null +++ b/tensorflow/core/kernels/tile_functor_cpu_half.cc @@ -0,0 +1,29 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/tile_functor_cpu.h" + +namespace tensorflow { +namespace functor { + +typedef Eigen::ThreadPoolDevice CPUDevice; + +template struct Tile; +template struct Tile; + +} // end namespace functor +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/tile_functor_cpu_int16.cc b/tensorflow/core/kernels/tile_functor_cpu_int16.cc new file mode 100644 index 00000000000..6787601845e --- /dev/null +++ b/tensorflow/core/kernels/tile_functor_cpu_int16.cc @@ -0,0 +1,29 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/tile_functor_cpu.h" + +namespace tensorflow { +namespace functor { + +typedef Eigen::ThreadPoolDevice CPUDevice; + +template struct Tile; +template struct Tile; + +} // end namespace functor +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/tile_functor_cpu_int32.cc b/tensorflow/core/kernels/tile_functor_cpu_int32.cc new file mode 100644 index 00000000000..82d66d8a8c3 --- /dev/null +++ b/tensorflow/core/kernels/tile_functor_cpu_int32.cc @@ -0,0 +1,29 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/tile_functor_cpu.h" + +namespace tensorflow { +namespace functor { + +typedef Eigen::ThreadPoolDevice CPUDevice; + +template struct Tile; +template struct Tile; + +} // end namespace functor +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/tile_functor_cpu_int64.cc b/tensorflow/core/kernels/tile_functor_cpu_int64.cc new file mode 100644 index 00000000000..1427f240d87 --- /dev/null +++ b/tensorflow/core/kernels/tile_functor_cpu_int64.cc @@ -0,0 +1,29 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/tile_functor_cpu.h" + +namespace tensorflow { +namespace functor { + +typedef Eigen::ThreadPoolDevice CPUDevice; + +template struct Tile; +template struct Tile; + +} // end namespace functor +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/tile_functor_cpu_int8.cc b/tensorflow/core/kernels/tile_functor_cpu_int8.cc new file mode 100644 index 00000000000..e6cf0047abf --- /dev/null +++ b/tensorflow/core/kernels/tile_functor_cpu_int8.cc @@ -0,0 +1,29 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/tile_functor_cpu.h" + +namespace tensorflow { +namespace functor { + +typedef Eigen::ThreadPoolDevice CPUDevice; + +template struct Tile; +template struct Tile; + +} // end namespace functor +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/tile_functor_cpu_tstring.cc b/tensorflow/core/kernels/tile_functor_cpu_tstring.cc new file mode 100644 index 00000000000..3ac5ad40e3e --- /dev/null +++ b/tensorflow/core/kernels/tile_functor_cpu_tstring.cc @@ -0,0 +1,29 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/tile_functor_cpu.h" + +namespace tensorflow { +namespace functor { + +typedef Eigen::ThreadPoolDevice CPUDevice; + +template struct Tile; +template struct Tile; + +} // end namespace functor +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/tile_functor_cpu_uint8.cc b/tensorflow/core/kernels/tile_functor_cpu_uint8.cc new file mode 100644 index 00000000000..f6099a7965f --- /dev/null +++ b/tensorflow/core/kernels/tile_functor_cpu_uint8.cc @@ -0,0 +1,29 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/tile_functor_cpu.h" + +namespace tensorflow { +namespace functor { + +typedef Eigen::ThreadPoolDevice CPUDevice; + +template struct Tile; +template struct Tile; + +} // end namespace functor +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/tile_functor_sycl.cc b/tensorflow/core/kernels/tile_functor_sycl.cc new file mode 100644 index 00000000000..21574250773 --- /dev/null +++ b/tensorflow/core/kernels/tile_functor_sycl.cc @@ -0,0 +1,42 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/tile_functor_cpu.h" + +namespace tensorflow { +namespace functor { + +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; + +#define DEFINE_TYPE(T) \ + template struct Tile; \ + template struct Tile; + +TF_CALL_bool(DEFINE_TYPE); +TF_CALL_float(DEFINE_TYPE); +TF_CALL_bfloat16(DEFINE_TYPE); +TF_CALL_double(DEFINE_TYPE); +TF_CALL_uint8(DEFINE_TYPE); +TF_CALL_int32(DEFINE_TYPE); +TF_CALL_int16(DEFINE_TYPE); +TF_CALL_int64(DEFINE_TYPE); + +#undef DEFINE_TYPE +#endif // TENSORFLOW_USE_SYCL + +} // end namespace functor +} // end namespace tensorflow diff --git a/tensorflow/core/nccl/BUILD b/tensorflow/core/nccl/BUILD index b21936167d2..3acf7579f62 100644 --- a/tensorflow/core/nccl/BUILD +++ b/tensorflow/core/nccl/BUILD @@ -72,6 +72,6 @@ tf_cuda_cc_test( "//tensorflow/core:cuda", ]) + if_rocm([ "@local_config_rocm//rocm:rccl", - "//tensorflow/core:rocm", + "//tensorflow/core/common_runtime/gpu:rocm", ]), ) diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc index b5181b1edd3..9fd12a20cad 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc @@ -153,7 +153,7 @@ absl::flat_hash_map CollectTfOpsFromHostThreadsXPlane( // user-inserted TraceMe's have "unknown" type. We don't count them in // Tf-stats. TfOp tf_op = ParseTfOpFullname(metadata.name()); - if (!IsUnknownOp(tf_op.type)) { + if (tf_op.category != Category::kUnknown) { tf_ops.try_emplace(metadata.id(), tf_op); } } @@ -214,7 +214,7 @@ OpMetricsDb ConvertDeviceTraceXPlaneToOpMetricsDb( if (tf_op_fullname.empty()) return; TfOp tf_op = ParseTfOpFullname(tf_op_fullname); TfOpRoofLineCostEstimator::OpRoofLineStats costs; - if (tf_op.type != kUnknownOp) { + if (tf_op.category != Category::kUnknown) { costs = op_level_cost_estimator.Predict(event); } device_op_metrics_db_builder.EnterOp( diff --git a/tensorflow/core/profiler/internal/cpu/BUILD b/tensorflow/core/profiler/internal/cpu/BUILD index d81e5d82be5..1001fec7535 100644 --- a/tensorflow/core/profiler/internal/cpu/BUILD +++ b/tensorflow/core/profiler/internal/cpu/BUILD @@ -78,3 +78,20 @@ cc_library( ], alwayslink = True, ) + +cc_library( + name = "metadata_collector", + srcs = ["metadata_collector.cc"], + deps = [ + "//tensorflow/compiler/xla/service/gpu:gpu_debug_info_manager", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core/profiler/internal:profiler_factory", + "//tensorflow/core/profiler/internal:profiler_interface", + "//tensorflow/core/profiler/protobuf:xplane_proto_cc", + "//tensorflow/core/profiler/utils:xplane_builder", + "//tensorflow/core/profiler/utils:xplane_schema", + "//tensorflow/core/profiler/utils:xplane_utils", + ], + alwayslink = True, +) diff --git a/tensorflow/core/profiler/internal/cpu/metadata_collector.cc b/tensorflow/core/profiler/internal/cpu/metadata_collector.cc new file mode 100644 index 00000000000..fbcfaa26e73 --- /dev/null +++ b/tensorflow/core/profiler/internal/cpu/metadata_collector.cc @@ -0,0 +1,100 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/service/gpu/gpu_debug_info_manager.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/profiler/internal/profiler_factory.h" +#include "tensorflow/core/profiler/internal/profiler_interface.h" +#include "tensorflow/core/profiler/protobuf/xplane.pb.h" +#include "tensorflow/core/profiler/utils/xplane_builder.h" +#include "tensorflow/core/profiler/utils/xplane_schema.h" +#include "tensorflow/core/profiler/utils/xplane_utils.h" + +namespace tensorflow { +namespace profiler { +namespace { + +// MetadataCollector collect miscellaneous metadata for xprof, e.g. HLO protos +// from XLA runtime etc. +// +// Thread-safety: This class is go/thread-compatible. +class MetadataCollector : public ProfilerInterface { + public: + MetadataCollector() = default; + + Status Start() override { + if (!trace_active_) { + xla::gpu::GpuDebugInfoManager::Get()->StartTracing(); + trace_active_ = true; + } + return Status::OK(); + } + + Status Stop() override { + if (trace_active_) { + xla::gpu::GpuDebugInfoManager::Get()->StopTracing(&debug_info_); + trace_active_ = false; + } + return Status::OK(); + } + + Status CollectData(RunMetadata* run_metadata) override { + return Status::OK(); // legacy session is not supported. + } + + Status CollectData(XSpace* space) override { + if (!debug_info_.empty()) { + XPlane* plane = GetOrCreatePlane(space, kMetadataPlane); + plane->set_id(kMetadataPlaneId); + XPlaneBuilder xplane(plane); + for (auto& p : debug_info_) { + std::string hlo_proto; + p.hlo_proto->SerializeToString(&hlo_proto); + p.hlo_proto.reset(); + xplane.AddStatValue(*xplane.GetOrCreateStatMetadata(kHloProto), + std::move(hlo_proto), /*is_bytes=*/true); + } + debug_info_.clear(); + } + return Status::OK(); + } + + DeviceType GetDeviceType() override { return DeviceType::kCpu; } + + private: + std::vector debug_info_; + bool trace_active_ = false; + + TF_DISALLOW_COPY_AND_ASSIGN(MetadataCollector); +}; + +std::unique_ptr CreatMetadataCollector( + const profiler::ProfilerOptions& options) { + return options.enable_hlo_proto ? absl::make_unique() + : nullptr; +} + +} // namespace + +auto register_metadata_collector_factory = [] { + RegisterProfilerFactory(&CreatMetadataCollector); + return 0; +}(); + +} // namespace profiler +} // namespace tensorflow diff --git a/tensorflow/core/profiler/internal/gpu/device_tracer_test.cc b/tensorflow/core/profiler/internal/gpu/device_tracer_test.cc index 24f8d8771fb..2e422160a59 100644 --- a/tensorflow/core/profiler/internal/gpu/device_tracer_test.cc +++ b/tensorflow/core/profiler/internal/gpu/device_tracer_test.cc @@ -72,28 +72,25 @@ class DeviceTracerTest : public ::testing::Test { Tensor a_tensor(DT_FLOAT, TensorShape({2, 2})); test::FillValues(&a_tensor, a_values); - Node* a = test::graph::Constant(&graph, a_tensor); - a->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0"); + Node* a = test::graph::HostConstant(&graph, a_tensor); Tensor x_tensor(DT_FLOAT, TensorShape({2, 1})); test::FillValues(&x_tensor, {1, 1}); - Node* x = test::graph::Constant(&graph, x_tensor); - x->set_assigned_device_name("/job:localhost/replica:0/task:0/device:GPU:0"); + Node* x = test::graph::HostConstant(&graph, x_tensor); x_ = x->name(); // y = A * x Node* y = test::graph::Matmul(&graph, a, x, false, false); - y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:GPU:0"); + y->set_assigned_device_name("/device:GPU:0"); y_ = y->name(); // Use an Identity op to force a memcpy to CPU and back to GPU. Node* i = test::graph::Identity(&graph, y); - i->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0"); + i->set_assigned_device_name("/cpu:0"); Node* y_neg = test::graph::Unary(&graph, "Neg", i); y_neg_ = y_neg->name(); - y_neg->set_assigned_device_name( - "/job:localhost/replica:0/task:0/device:GPU:0"); + y_neg->set_assigned_device_name("/device:GPU:0"); test::graph::ToGraphDef(&graph, &def_); } @@ -278,6 +275,9 @@ TEST_F(DeviceTracerTest, TraceToXSpace) { FindPlaneWithName(space, strings::StrCat(kGpuPlanePrefix, 0)); ASSERT_NE(device_plane, nullptr); // Check if device plane is serialized. EXPECT_EQ(device_plane->id(), kGpuPlaneBaseId); + // one for MemcpyH2D, one for MemcpyD2H, two for Matmul (one from Eigen, one + // from cudnn). + EXPECT_EQ(device_plane->event_metadata_size(), 4); // Check if device capacity is serialized. XPlaneVisitor plane = CreateTfXPlaneVisitor(device_plane); EXPECT_NE(plane.GetStats(kDevCapClockRateKHz), nullptr); @@ -288,12 +288,15 @@ TEST_F(DeviceTracerTest, TraceToXSpace) { EXPECT_NE(plane.GetStats(kDevCapComputeCapMinor), nullptr); // Check if the device events timestamps are set. + int total_events = 0; plane.ForEachLine([&](const tensorflow::profiler::XLineVisitor& line) { line.ForEachEvent([&](const tensorflow::profiler::XEventVisitor& event) { EXPECT_GT(event.TimestampNs(), 0); EXPECT_GT(event.DurationNs(), 0); + ++total_events; }); }); + EXPECT_EQ(total_events, 5); } } // namespace diff --git a/tensorflow/core/profiler/internal/profiler_interface.h b/tensorflow/core/profiler/internal/profiler_interface.h index 74aa56514b7..c42c278f847 100644 --- a/tensorflow/core/profiler/internal/profiler_interface.h +++ b/tensorflow/core/profiler/internal/profiler_interface.h @@ -55,6 +55,9 @@ struct ProfilerOptions { // Whether to enable python function calls tracer. bool enable_python_tracer = false; + + // Whether to capture HLO protos from XLA runtime. + bool enable_hlo_proto = true; }; // Interface for tensorflow profiler plugins. diff --git a/tensorflow/core/profiler/lib/BUILD b/tensorflow/core/profiler/lib/BUILD index 7ccbb81a281..4bb1d92c0cb 100644 --- a/tensorflow/core/profiler/lib/BUILD +++ b/tensorflow/core/profiler/lib/BUILD @@ -69,6 +69,7 @@ tf_cuda_library( visibility = ["//tensorflow:internal"], deps = [ "//tensorflow/core/profiler/internal/cpu:host_tracer", + "//tensorflow/core/profiler/internal/cpu:metadata_collector", ], alwayslink = True, ) diff --git a/tensorflow/core/profiler/rpc/BUILD b/tensorflow/core/profiler/rpc/BUILD index 3ecec434ade..d8af53fe8f9 100644 --- a/tensorflow/core/profiler/rpc/BUILD +++ b/tensorflow/core/profiler/rpc/BUILD @@ -1,4 +1,5 @@ -load("//tensorflow:tensorflow.bzl", "tf_external_workspace_visible") +load("//tensorflow:tensorflow.bzl", "tf_external_workspace_visible") # buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency") # buildifier: disable=same-origin-load package( licenses = ["notice"], # Apache 2.0 @@ -13,7 +14,6 @@ cc_library( ["//tensorflow_serving/model_servers:__pkg__"], ), deps = [ - "//tensorflow:grpc++", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/profiler:profiler_service_proto_cc", @@ -22,6 +22,7 @@ cc_library( "//tensorflow/core/profiler/protobuf:xplane_proto_cc", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", + tf_grpc_cc_dependency(), ], ) @@ -35,10 +36,10 @@ cc_library( ], deps = [ ":profiler_service_impl", - "//tensorflow:grpc++", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/profiler:profiler_service_proto_cc", "@com_google_absl//absl/strings", + tf_grpc_cc_dependency(), ], ) diff --git a/tensorflow/core/profiler/rpc/client/BUILD b/tensorflow/core/profiler/rpc/client/BUILD index 370aa00a602..43ebb35230c 100644 --- a/tensorflow/core/profiler/rpc/client/BUILD +++ b/tensorflow/core/profiler/rpc/client/BUILD @@ -1,3 +1,5 @@ +load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency") + package( licenses = ["notice"], # Apache 2.0 ) @@ -9,12 +11,12 @@ cc_library( visibility = ["//tensorflow/python/profiler/internal:__pkg__"], deps = [ ":save_profile", - "//tensorflow:grpc++", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/profiler:profiler_analysis_proto_cc", "//tensorflow/core/profiler:profiler_service_proto_cc", "@com_google_absl//absl/strings", + tf_grpc_cc_dependency(), ], ) diff --git a/tensorflow/core/profiler/utils/derived_timeline.cc b/tensorflow/core/profiler/utils/derived_timeline.cc index f61926a1850..70b12560bb8 100644 --- a/tensorflow/core/profiler/utils/derived_timeline.cc +++ b/tensorflow/core/profiler/utils/derived_timeline.cc @@ -44,10 +44,8 @@ class DerivedXLineBuilder { public: DerivedXLineBuilder(XPlaneBuilder* plane, int64 line_id, absl::string_view name, int64 timestamp_ns, - std::vector dependent_lines, - bool try_expand) - : line_(plane->GetOrCreateLine(line_id)), - try_expand_(try_expand) { + std::vector dependent_lines) + : line_(plane->GetOrCreateLine(line_id)) { line_.SetName(name); line_.SetTimestampNs(timestamp_ns); dependent_lines_ = std::move(dependent_lines); @@ -71,12 +69,12 @@ class DerivedXLineBuilder { } private: - // If the last event of the given level has the same metadata and try_expand_ - // is true, expands it to include the time until the given event's (offset_ps - // + duration_ps). Otherwise, adds a new event and clears last_event_by_level_ - // for the levels below the given level and all levels of the dependent lines. - // Clearing last_event_by_level_ prevents a nested event from growing larger - // than the parent event(s). + // If the last event of the given level has the same metadata, expands it to + // include the time until the given event's (offset_ps + duration_ps). + // Otherwise, adds a new event and clears last_event_by_level_ for the levels + // below the given level and all levels of the dependent lines. Clearing + // last_event_by_level_ prevents a nested event from growing larger than the + // parent event(s). void ExpandOrAddLevelEvent(const XEvent& event, int level) { int64 offset_ps = event.offset_ps(); int64 duration_ps = event.duration_ps(); @@ -84,8 +82,7 @@ class DerivedXLineBuilder { // If last_event is not nullptr, its offset must be less than or equal to // the given event's offset. DCHECK(!last_event || last_event->OffsetPs() <= offset_ps); - if (try_expand_ && last_event && - last_event->MetadataId() == event.metadata_id()) { + if (last_event && last_event->MetadataId() == event.metadata_id()) { // If last_event is not nullptr and metadata is same, merge the given // event into last_event. last_event->SetDurationPs((offset_ps + duration_ps) - @@ -108,7 +105,6 @@ class DerivedXLineBuilder { XLineBuilder line_; absl::flat_hash_map> last_event_by_level_; std::vector dependent_lines_; - bool try_expand_; }; const absl::string_view kDerivedLineSteps = "Steps"; @@ -147,7 +143,8 @@ void ProcessTfOpEvent(const XEventVisitor& event, plane_builder->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kGroupId)) ->id(); TfOp tf_op = ParseTfOpFullname(tf_op_full_name); - if (tf_op.is_tf_op) { + Category category = tf_op.category; + if (category == Category::kTensorFlow || category == Category::kJax) { std::vector name_scope_event_per_level; for (const auto& tf_name_scope : ParseTfNameScopes(tf_op)) { name_scope_event_per_level.push_back(CreateXEvent( @@ -184,19 +181,18 @@ void DeriveEventsFromAnnotations(const SymbolResolver& symbol_resolver, XPlaneBuilder plane(device_trace); DerivedXLineBuilder tf_ops(&plane, kThreadIdTfOp, kDerivedLineTensorFlowOps, - start_timestamp_ns, {}, /*try_expand=*/true); - DerivedXLineBuilder tf_name_scope( - &plane, kThreadIdTfNameScope, kDerivedLineTensorFlowNameScope, - start_timestamp_ns, {&tf_ops}, /*try_expand=*/true); + start_timestamp_ns, {}); + DerivedXLineBuilder tf_name_scope(&plane, kThreadIdTfNameScope, + kDerivedLineTensorFlowNameScope, + start_timestamp_ns, {&tf_ops}); DerivedXLineBuilder hlo_ops(&plane, kThreadIdHloOp, kDerivedLineXlaOps, - start_timestamp_ns, {}, /*try_expand=*/true); - DerivedXLineBuilder hlo_modules( - &plane, kThreadIdHloModule, kDerivedLineXlaModules, start_timestamp_ns, - {&tf_ops, &tf_name_scope, &hlo_ops}, /*try_expand=*/false); + start_timestamp_ns, {}); + DerivedXLineBuilder hlo_modules(&plane, kThreadIdHloModule, + kDerivedLineXlaModules, start_timestamp_ns, + {&tf_ops, &tf_name_scope, &hlo_ops}); DerivedXLineBuilder steps(&plane, kThreadIdStepInfo, kDerivedLineSteps, start_timestamp_ns, - {&tf_ops, &tf_name_scope, &hlo_ops}, - /*try_expand=*/true); + {&tf_ops, &tf_name_scope, &hlo_ops, &hlo_modules}); int64 group_id_stat_metadata_id = plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kGroupId))->id(); int64 step_name_stat_metadata_id = diff --git a/tensorflow/core/profiler/utils/derived_timeline_test.cc b/tensorflow/core/profiler/utils/derived_timeline_test.cc index a4962c641a0..b92f4c5f801 100644 --- a/tensorflow/core/profiler/utils/derived_timeline_test.cc +++ b/tensorflow/core/profiler/utils/derived_timeline_test.cc @@ -36,7 +36,7 @@ TEST(DerivedTimelineTest, EmptySpaceTest) { EXPECT_EQ(space.planes_size(), 0); } -// Checks that HLO module events are not expanded. +// Checks that HLO module events are expanded. TEST(DerivedTimelineTest, HloModuleNameTest) { const absl::string_view kHloModuleName = "hlo_module"; const absl::string_view kKernelDetails = "kernel_details"; @@ -69,7 +69,7 @@ TEST(DerivedTimelineTest, HloModuleNameTest) { plane_visitor.ForEachLine([&](const XLineVisitor& line_visitor) { if (line_visitor.Id() == 0) return; EXPECT_EQ(line_visitor.Id(), kThreadIdHloModule); - EXPECT_EQ(line_visitor.NumEvents(), 2); + EXPECT_EQ(line_visitor.NumEvents(), 1); line_visitor.ForEachEvent([&](const XEventVisitor& event_visitor) { EXPECT_EQ(event_visitor.Name(), kHloModuleName); }); diff --git a/tensorflow/core/profiler/utils/tf_op_utils.cc b/tensorflow/core/profiler/utils/tf_op_utils.cc index 8a9556fb4cd..5a4204440a3 100644 --- a/tensorflow/core/profiler/utils/tf_op_utils.cc +++ b/tensorflow/core/profiler/utils/tf_op_utils.cc @@ -47,14 +47,16 @@ TfOp ParseTfOpFullname(absl::string_view tf_op_fullname) { // JAX op types have only lowercase letters and underscores. static const LazyRE2 kJaxOpTypeRegEx = {"[a-z_]*"}; - TfOp tf_op = {tf_op_fullname, kUnknownOp, /*is_tf_op=*/false}; + TfOp tf_op = {Category::kUnknown, tf_op_fullname, kUnknownOp}; std::vector parts = absl::StrSplit(tf_op_fullname, absl::MaxSplits(':', 1)); if (parts.size() != 2) { // GPU-related Ops that need to be tracked. if (absl::StartsWithIgnoreCase(tf_op_fullname, "MEMCPYHToD")) { + tf_op.category = Category::kMemcpyHToD; tf_op.type = kMemcpyHToDOp; } else if (absl::StartsWithIgnoreCase(tf_op_fullname, "MEMCPYDToH")) { + tf_op.category = Category::kMemcpyDToH; tf_op.type = kMemcpyDToHOp; } // TODO(ckluk): Include the corresponding Ops on TPU. @@ -62,12 +64,13 @@ TfOp ParseTfOpFullname(absl::string_view tf_op_fullname) { // Dataset Op names (e.g., Iterator::Batch::Map::TFRecord) do not follow the // format of TF Op names. But we still want to capture them for // input-pipeline analysis. + tf_op.category = Category::kTfData; tf_op.type = kDatasetOp; } else if (RE2::FullMatch(parts[1], *kTfOpTypeRegEx) && RE2::FullMatch(parts[0], *kTfOpNameRegEx)) { // TensorFlow - tf_op = {parts[0], parts[1], /*is_tf_op=*/true}; + tf_op = {Category::kTensorFlow, parts[0], parts[1]}; } else if (RE2::FullMatch(parts[1], *kJaxOpTypeRegEx)) { // JAX - tf_op = {parts[0], parts[1], /*is_tf_op=*/false}; + tf_op = {Category::kJax, parts[0], parts[1]}; } return tf_op; } @@ -81,10 +84,10 @@ std::vector ParseTfNameScopes(const TfOp& tf_op) { std::string TfOpEventName(const TfOp& tf_op) { std::string event_name; - if (tf_op.type == kUnknownOp) { + if (tf_op.category == Category::kUnknown) { // Some TraceMe names contain trailing whitespace, remove it. event_name = std::string(absl::StripTrailingAsciiWhitespace(tf_op.name)); - } else if (tf_op.type == kDatasetOp) { + } else if (tf_op.category == Category::kTfData) { std::vector op_parts = absl::StrSplit(tf_op.name, kSeparator); event_name = absl::StrCat(kIterator, kSeparator, op_parts.back()); diff --git a/tensorflow/core/profiler/utils/tf_op_utils.h b/tensorflow/core/profiler/utils/tf_op_utils.h index 4647dbbcc40..d1ac69e2976 100644 --- a/tensorflow/core/profiler/utils/tf_op_utils.h +++ b/tensorflow/core/profiler/utils/tf_op_utils.h @@ -31,11 +31,20 @@ ABSL_CONST_INIT extern const absl::string_view kDatasetOp; ABSL_CONST_INIT extern const absl::string_view kMemcpyHToDOp; ABSL_CONST_INIT extern const absl::string_view kMemcpyDToHOp; +enum class Category { + kTensorFlow, + kJax, + kTfData, + kMemcpyHToD, + kMemcpyDToH, + kUnknown, +}; + // Breaks a TensorFlow op fullname into name and type. struct TfOp { + Category category; absl::string_view name; absl::string_view type; - bool is_tf_op; }; TfOp ParseTfOpFullname(absl::string_view tf_op_fullname); @@ -48,11 +57,6 @@ std::vector ParseTfNameScopes(const TfOp& tf_op); std::string TfOpEventName(const TfOp& tf_op); std::string TfOpEventName(absl::string_view tf_op_fullname); -// Returns true if the given name is not a TensorFlow op. -inline bool IsUnknownOp(absl::string_view tf_op_type) { - return tf_op_type == kUnknownOp; -} - // Returns true if the given name is a TensorFlow Dataset Op. inline bool IsDatasetOp(absl::string_view tf_op_type) { return tf_op_type == kDatasetOp; diff --git a/tensorflow/core/profiler/utils/tf_op_utils_test.cc b/tensorflow/core/profiler/utils/tf_op_utils_test.cc index ff62c822e65..fa5169557d1 100644 --- a/tensorflow/core/profiler/utils/tf_op_utils_test.cc +++ b/tensorflow/core/profiler/utils/tf_op_utils_test.cc @@ -24,6 +24,7 @@ namespace { TEST(TfOpUtilsTest, TfOpTest) { const absl::string_view kName = "OpName:OpType"; TfOp tf_op = ParseTfOpFullname(kName); + EXPECT_EQ(tf_op.category, Category::kTensorFlow); EXPECT_EQ(tf_op.name, "OpName"); EXPECT_EQ(tf_op.type, "OpType"); EXPECT_EQ(TfOpEventName(kName), "OpType"); // type only @@ -32,6 +33,7 @@ TEST(TfOpUtilsTest, TfOpTest) { TEST(TfOpUtilsTest, InternalTfOpTest) { const absl::string_view kName = "OpName:_InternalOpType"; TfOp tf_op = ParseTfOpFullname(kName); + EXPECT_EQ(tf_op.category, Category::kTensorFlow); EXPECT_EQ(tf_op.name, "OpName"); EXPECT_EQ(tf_op.type, "_InternalOpType"); EXPECT_EQ(TfOpEventName(kName), "_InternalOpType"); // type only @@ -40,6 +42,7 @@ TEST(TfOpUtilsTest, InternalTfOpTest) { TEST(TfOpUtilsTest, TfOpWithPathTest) { const absl::string_view kName = "path/to/name:OpType"; TfOp tf_op = ParseTfOpFullname(kName); + EXPECT_EQ(tf_op.category, Category::kTensorFlow); EXPECT_EQ(tf_op.name, "path/to/name"); EXPECT_EQ(tf_op.type, "OpType"); EXPECT_EQ(TfOpEventName(kName), "OpType"); // type only @@ -48,24 +51,27 @@ TEST(TfOpUtilsTest, TfOpWithPathTest) { TEST(TfOpUtilsTest, ShortDatasetOpTest) { const absl::string_view kName = "Iterator::Batch"; TfOp tf_op = ParseTfOpFullname(kName); + EXPECT_EQ(tf_op.category, Category::kTfData); EXPECT_EQ(tf_op.name, kName); - EXPECT_TRUE(IsDatasetOp(tf_op.type)); + EXPECT_EQ(tf_op.type, kDatasetOp); EXPECT_EQ(TfOpEventName(kName), kName); } TEST(TfOpUtilsTest, LongDatasetOpTest) { const absl::string_view kName = "Iterator::Batch::Map::TfRecord"; TfOp tf_op = ParseTfOpFullname(kName); + EXPECT_EQ(tf_op.category, Category::kTfData); EXPECT_EQ(tf_op.name, kName); - EXPECT_TRUE(IsDatasetOp(tf_op.type)); + EXPECT_EQ(tf_op.type, kDatasetOp); EXPECT_EQ(TfOpEventName(kName), "Iterator::TfRecord"); // shorter name } TEST(TfOpUtilsTest, TraceMeTest) { const absl::string_view kName = "MyTraceMe"; TfOp tf_op = ParseTfOpFullname(kName); + EXPECT_EQ(tf_op.category, Category::kUnknown); EXPECT_EQ(tf_op.name, kName); - EXPECT_TRUE(IsUnknownOp(tf_op.type)); + EXPECT_EQ(tf_op.type, kUnknownOp); EXPECT_EQ(TfOpEventName(kName), kName); } @@ -73,16 +79,18 @@ TEST(TfOpUtilsTest, TraceMeWithColonTest) { // "12345" is not a valid op type. const absl::string_view kName = "RunStep/Server:54635"; TfOp tf_op = ParseTfOpFullname(kName); + EXPECT_EQ(tf_op.category, Category::kUnknown); EXPECT_EQ(tf_op.name, kName); - EXPECT_TRUE(IsUnknownOp(tf_op.type)); + EXPECT_EQ(tf_op.type, kUnknownOp); EXPECT_EQ(TfOpEventName(kName), kName); } TEST(TfOpUtilsTest, TraceMeWithDoubleColonTest) { const absl::string_view kName = "XLA::StartProgram"; TfOp tf_op = ParseTfOpFullname(kName); + EXPECT_EQ(tf_op.category, Category::kUnknown); EXPECT_EQ(tf_op.name, kName); - EXPECT_TRUE(IsUnknownOp(tf_op.type)); + EXPECT_EQ(tf_op.type, kUnknownOp); EXPECT_EQ(TfOpEventName(kName), kName); } @@ -90,11 +98,39 @@ TEST(TfOpUtilsTest, TraceMeWithTrailingWhitespaceTest) { const absl::string_view kName = "SessionRun "; const absl::string_view kNameTrimmed = "SessionRun"; TfOp tf_op = ParseTfOpFullname(kName); + EXPECT_EQ(tf_op.category, Category::kUnknown); EXPECT_EQ(tf_op.name, kName); - EXPECT_TRUE(IsUnknownOp(tf_op.type)); + EXPECT_EQ(tf_op.type, kUnknownOp); EXPECT_EQ(TfOpEventName(kName), kNameTrimmed); } +TEST(TfOpUtilsTest, MemcpyHToDTest) { + const absl::string_view kName = "MemcpyHToD"; + TfOp tf_op = ParseTfOpFullname(kName); + EXPECT_EQ(tf_op.category, Category::kMemcpyHToD); + EXPECT_EQ(tf_op.name, kName); + EXPECT_EQ(tf_op.type, kMemcpyHToDOp); + EXPECT_EQ(TfOpEventName(kName), kName); +} + +TEST(TfOpUtilsTest, MemcpyDToHTest) { + const absl::string_view kName = "MemcpyDToH"; + TfOp tf_op = ParseTfOpFullname(kName); + EXPECT_EQ(tf_op.category, Category::kMemcpyDToH); + EXPECT_EQ(tf_op.name, kName); + EXPECT_EQ(tf_op.type, kMemcpyDToHOp); + EXPECT_EQ(TfOpEventName(kName), kName); +} + +TEST(TfOpUtilsTest, JaxOpTest) { + const absl::string_view kName = "op_name:op_type"; + TfOp tf_op = ParseTfOpFullname(kName); + EXPECT_EQ(tf_op.category, Category::kJax); + EXPECT_EQ(tf_op.name, "op_name"); + EXPECT_EQ(tf_op.type, "op_type"); + EXPECT_EQ(TfOpEventName(kName), "op_type"); +} + } // namespace } // namespace profiler } // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/time_utils.h b/tensorflow/core/profiler/utils/time_utils.h index 802af4e8260..0a2518b90ff 100644 --- a/tensorflow/core/profiler/utils/time_utils.h +++ b/tensorflow/core/profiler/utils/time_utils.h @@ -16,24 +16,23 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_UTILS_TIME_UTILS_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_TIME_UTILS_H_ -#include - #include "tensorflow/core/platform/types.h" namespace tensorflow { namespace profiler { // Converts among different time units. -inline double PicosToMillis(uint64 ps) { return ps / 1E9; } -inline double PicosToSecond(uint64 ps) { return ps / 1E12; } -inline double PicosToMicros(uint64 ps) { return ps / 1E6; } inline double PicosToNanos(uint64 ps) { return ps / 1E3; } -inline uint64 NanosToPicos(double ns) { return std::llround(ns * 1E3); } +inline double PicosToMicros(uint64 ps) { return ps / 1E6; } +inline double PicosToMillis(uint64 ps) { return ps / 1E9; } +inline double PicosToSeconds(uint64 ps) { return ps / 1E12; } +inline uint64 NanosToPicos(uint64 ns) { return ns * 1000; } +inline double NanosToMicros(uint64 ns) { return ns / 1E3; } inline double MicrosToMillis(double us) { return us / 1E3; } -inline uint64 MillisToPicos(double ms) { return std::llround(ms * 1E9); } -inline double MilliToSecond(double ms) { return ms / 1E3; } -inline uint64 MilliToNanos(uint64 ms) { return ms * 1E6; } -inline uint64 SecondsToNanos(uint64 s) { return s * 1E9; } +inline uint64 MillisToPicos(uint64 ms) { return ms * 1000000000; } +inline uint64 MillisToNanos(uint64 ms) { return ms * 1000000; } +inline double MillisToSeconds(uint64 ms) { return ms / 1E3; } +inline uint64 SecondsToNanos(double s) { return s * 1E9; } } // namespace profiler } // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/xplane_schema.cc b/tensorflow/core/profiler/utils/xplane_schema.cc index 0c0924cfb38..da0ba034dbe 100644 --- a/tensorflow/core/profiler/utils/xplane_schema.cc +++ b/tensorflow/core/profiler/utils/xplane_schema.cc @@ -25,10 +25,12 @@ namespace profiler { const absl::string_view kHostThreads = "/host:CPU"; const absl::string_view kGpuPlanePrefix = "/device:GPU:"; const absl::string_view kCuptiDriverApiPlaneName = "/host:CUPTI"; +const absl::string_view kMetadataPlane = "/host:metadata"; const int32 kHostPlaneId = 49; const int32 kGpuPlaneBaseId = 0; const int32 kCuptiDriverApiPlaneId = 50; +const int32 kMetadataPlaneId = 51; namespace { @@ -149,6 +151,7 @@ const StatTypeMap& GetStatTypeMap() { // XLA metadata map related. {"SELF_DURATION_PS", kSelfDurationPs}, {"MIN_DURATION_PS", kMinDurationPs}, + {"Hlo Proto", kHloProto}, // Device capability related. {"clock_rate", kDevCapClockRateKHz}, {"core_count", kDevCapCoreCount}, diff --git a/tensorflow/core/profiler/utils/xplane_schema.h b/tensorflow/core/profiler/utils/xplane_schema.h index 9e6eaab1036..bce6c5ecc8f 100644 --- a/tensorflow/core/profiler/utils/xplane_schema.h +++ b/tensorflow/core/profiler/utils/xplane_schema.h @@ -31,6 +31,8 @@ ABSL_CONST_INIT extern const absl::string_view kHostThreads; ABSL_CONST_INIT extern const absl::string_view kGpuPlanePrefix; // Name of XPlane that contains CUPTI driver API generated events. ABSL_CONST_INIT extern const absl::string_view kCuptiDriverApiPlaneName; +// Name of XPlane that contains profile metadata such as XLA debug info. +ABSL_CONST_INIT extern const absl::string_view kMetadataPlane; // Id of XPlane that contains TraceMe events. ABSL_CONST_INIT extern const int32 kHostPlaneId; @@ -39,6 +41,8 @@ ABSL_CONST_INIT extern const int32 kGpuPlaneBaseId; // Id of XPlane that contains CUPTI driver API generated events which happens // on CPU host threads, e.g. Kernel launch. ABSL_CONST_INIT extern const int32 kCuptiDriverApiPlaneId; +// Id of XPlane that contains profile metadata such as XLA debug info. +ABSL_CONST_INIT extern const int32 kMetadataPlaneId; // Interesting event types (i.e., TraceMe names). enum HostEventType { @@ -140,6 +144,7 @@ enum StatType { // XLA metadata map related. kSelfDurationPs, kMinDurationPs, + kHloProto, // Device capability related. kDevCapClockRateKHz, kDevCapCoreCount, diff --git a/tensorflow/core/protobuf/eager_service.proto b/tensorflow/core/protobuf/eager_service.proto index d57ca22b0d2..7be7199f10c 100644 --- a/tensorflow/core/protobuf/eager_service.proto +++ b/tensorflow/core/protobuf/eager_service.proto @@ -23,8 +23,6 @@ message Operation { // future. int64 id = 1; string name = 2; - // TODO(b/150963957): Deprecate this. - repeated RemoteTensorHandle inputs = 3; message Input { oneof item { @@ -52,6 +50,8 @@ message Operation { int64 func_step_id = 8; // Indicates whether the op is a function. bool is_function = 9; + + reserved 3; } message QueueItem { diff --git a/tensorflow/examples/saved_model/integration_tests/export_simple_text_embedding.py b/tensorflow/examples/saved_model/integration_tests/export_simple_text_embedding.py index a90d90d4373..891a8f1c7e2 100644 --- a/tensorflow/examples/saved_model/integration_tests/export_simple_text_embedding.py +++ b/tensorflow/examples/saved_model/integration_tests/export_simple_text_embedding.py @@ -87,7 +87,8 @@ class TextEmbeddingModel(tf.train.Checkpoint): return tf.nn.safe_embedding_lookup_sparse( embedding_weights=self.embeddings, - sparse_ids=tf.SparseTensor(token_ids, token_values, token_dense_shape), + sparse_ids=tf.sparse.SparseTensor(token_ids, token_values, + token_dense_shape), sparse_weights=None, combiner="sqrtn") diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index c6ee0e29c1d..d73c267cf37 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -20741,7 +20741,7 @@ func Print(scope *Scope, input tf.Output, data []tf.Output, optional ...PrintAtt // // with tf.Session() as sess: // # Define (COO format) SparseTensor over Numpy array. -// a_st = tf.SparseTensor(a_indices, a_values, a_dense_shape) +// a_st = tf.sparse.SparseTensor(a_indices, a_values, a_dense_shape) // // # Convert SparseTensors to CSR SparseMatrix. // a_sm = sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix( @@ -32098,8 +32098,8 @@ func SparseMatrixSparseMatMulAdjointB(value bool) SparseMatrixSparseMatMulAttr { // // with tf.Session() as sess: // # Define (COO format) Sparse Tensors over Numpy arrays -// a_st = tf.SparseTensor(a_indices, a_values, a_dense_shape) -// b_st = tf.SparseTensor(b_indices, b_values, b_dense_shape) +// a_st = tf.sparse.SparseTensor(a_indices, a_values, a_dense_shape) +// b_st = tf.sparse.SparseTensor(b_indices, b_values, b_dense_shape) // // # Convert SparseTensors to CSR SparseMatrix // a_sm = sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix( @@ -37615,7 +37615,7 @@ func RecvTPUEmbeddingActivations(scope *Scope, num_outputs int64, config string) // // with tf.Session() as sess: // # Define (COO format) SparseTensor over Numpy array. -// a_st = tf.SparseTensor(a_indices, a_values, a_dense_shape) +// a_st = tf.sparse.SparseTensor(a_indices, a_values, a_dense_shape) // // # Convert SparseTensors to CSR SparseMatrix. // a_sm = sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix( diff --git a/tensorflow/lite/c/builtin_op_data.h b/tensorflow/lite/c/builtin_op_data.h index d7848f9bcc0..1bd7959c4db 100644 --- a/tensorflow/lite/c/builtin_op_data.h +++ b/tensorflow/lite/c/builtin_op_data.h @@ -124,21 +124,33 @@ typedef struct { typedef struct { int rank; TfLiteFusedActivation activation; + + // Parameter for SVDF version 4. + bool asymmetric_quantize_inputs; } TfLiteSVDFParams; typedef struct { TfLiteFusedActivation activation; + + // Parameter for RNN version 3. + bool asymmetric_quantize_inputs; } TfLiteRNNParams; typedef struct { bool time_major; TfLiteFusedActivation activation; + + // Parameter for Sequence RNN version 3. + bool asymmetric_quantize_inputs; } TfLiteSequenceRNNParams; typedef struct { bool time_major; TfLiteFusedActivation activation; bool merge_outputs; + + // Parameter for Bidirectional RNN verison 3. + bool asymmetric_quantize_inputs; } TfLiteBidirectionalSequenceRNNParams; typedef enum { @@ -158,6 +170,11 @@ typedef struct { // tensors are the same. Furthermore, all but the last dimension of the input // and output shapes will be equal. bool keep_num_dims; + + // Parameters for FullyConnected version 7 or above. + // If set to true and the weights are quantized, then non constant inputs + // are quantized at evaluation time with asymmetric quantization. + bool asymmetric_quantize_inputs; } TfLiteFullyConnectedParams; typedef enum { @@ -228,6 +245,9 @@ typedef struct { // Parameters for LSTM version 2. // kTfLiteLSTMBasicKernel is only supported in version 2 or above. TfLiteLSTMKernelType kernel_type; + + // Parameters for LSTM version 4. + bool asymmetric_quantize_inputs; } TfLiteLSTMParams; typedef struct { @@ -238,6 +258,9 @@ typedef struct { // If set to true then the first dimension is time, otherwise batch. bool time_major; + + // Parameter for unidirectional sequence RNN version 3. + bool asymmetric_quantize_inputs; } TfLiteUnidirectionalSequenceLSTMParams; typedef struct { @@ -253,6 +276,10 @@ typedef struct { // Parameters supported by version 2: // If set to true then the first dimension is time, otherwise batch. bool time_major; + + // Parameters supported by version 4: + // If set to true, then hybrid ops use asymmetric quantization for inputs. + bool asymmetric_quantize_inputs; } TfLiteBidirectionalSequenceLSTMParams; typedef struct { diff --git a/tensorflow/lite/c/common.c b/tensorflow/lite/c/common.c index a9d92b223ca..f70a60002dd 100644 --- a/tensorflow/lite/c/common.c +++ b/tensorflow/lite/c/common.c @@ -209,6 +209,8 @@ const char* TfLiteTypeGetName(TfLiteType type) { return "STRING"; case kTfLiteFloat16: return "FLOAT16"; + case kTfLiteFloat64: + return "FLOAT64"; } return "Unknown type"; } diff --git a/tensorflow/lite/c/common.h b/tensorflow/lite/c/common.h index 10280df05b3..39ec547198e 100644 --- a/tensorflow/lite/c/common.h +++ b/tensorflow/lite/c/common.h @@ -236,6 +236,7 @@ typedef enum { kTfLiteComplex64 = 8, kTfLiteInt8 = 9, kTfLiteFloat16 = 10, + kTfLiteFloat64 = 11, } TfLiteType; // Return the name of a given type, for error reporting purposes. diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc index 83b4159cce0..878dbef29bb 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc @@ -91,11 +91,14 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type, ErrorReporter* error_reporter) { *type = kTfLiteNoType; switch (tensor_type) { + case TensorType_FLOAT16: + *type = kTfLiteFloat16; + break; case TensorType_FLOAT32: *type = kTfLiteFloat32; break; - case TensorType_FLOAT16: - *type = kTfLiteFloat16; + case TensorType_FLOAT64: + *type = kTfLiteFloat64; break; case TensorType_INT16: *type = kTfLiteInt16; @@ -269,6 +272,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, params->rank = svdf_params->rank(); params->activation = parse_activation(svdf_params->fused_activation_function()); + params->asymmetric_quantize_inputs = + svdf_params->asymmetric_quantize_inputs(); } *builtin_data = reinterpret_cast(params.release()); break; @@ -280,6 +285,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, params->activation = parse_activation(sequence_rnn_params->fused_activation_function()); params->time_major = sequence_rnn_params->time_major(); + params->asymmetric_quantize_inputs = + sequence_rnn_params->asymmetric_quantize_inputs(); } *builtin_data = reinterpret_cast(params.release()); break; @@ -293,6 +300,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, bidi_sequence_rnn_params->fused_activation_function()); params->time_major = bidi_sequence_rnn_params->time_major(); params->merge_outputs = bidi_sequence_rnn_params->merge_outputs(); + params->asymmetric_quantize_inputs = + bidi_sequence_rnn_params->asymmetric_quantize_inputs(); } *builtin_data = reinterpret_cast(params.release()); break; @@ -302,6 +311,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, if (const auto* rnn_params = op->builtin_options_as_RNNOptions()) { params->activation = parse_activation(rnn_params->fused_activation_function()); + params->asymmetric_quantize_inputs = + rnn_params->asymmetric_quantize_inputs(); } *builtin_data = reinterpret_cast(params.release()); break; @@ -323,6 +334,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, params->activation = parse_activation( fully_connected_params->fused_activation_function()); params->keep_num_dims = fully_connected_params->keep_num_dims(); + params->asymmetric_quantize_inputs = + fully_connected_params->asymmetric_quantize_inputs(); switch (fully_connected_params->weights_format()) { case FullyConnectedOptionsWeightsFormat_DEFAULT: params->weights_format = kTfLiteFullyConnectedWeightsFormatDefault; @@ -440,6 +453,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, lstm_params->kernel_type()); return kTfLiteError; } + params->asymmetric_quantize_inputs = + lstm_params->asymmetric_quantize_inputs(); } else { TF_LITE_REPORT_ERROR(error_reporter, "No valid LSTM builtin options exist"); @@ -458,6 +473,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, params->cell_clip = seq_lstm_params->cell_clip(); params->proj_clip = seq_lstm_params->proj_clip(); params->time_major = seq_lstm_params->time_major(); + params->asymmetric_quantize_inputs = + seq_lstm_params->asymmetric_quantize_inputs(); } *builtin_data = reinterpret_cast(params.release()); break; @@ -473,6 +490,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, params->proj_clip = bidi_lstm_params->proj_clip(); params->merge_outputs = bidi_lstm_params->merge_outputs(); params->time_major = bidi_lstm_params->time_major(); + params->asymmetric_quantize_inputs = + bidi_lstm_params->asymmetric_quantize_inputs(); } *builtin_data = reinterpret_cast(params.release()); break; diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc index d057b2adc6e..4dfea10cf2d 100644 --- a/tensorflow/lite/core/subgraph.cc +++ b/tensorflow/lite/core/subgraph.cc @@ -941,6 +941,22 @@ TfLiteStatus Subgraph::Invoke() { TfLiteStatus Subgraph::ResizeTensor(TfLiteContext* context, TfLiteTensor* tensor, TfLiteIntArray* new_size) { + // If the dimensions don't change, avoiding + // unnecessary (re)allocations. + // + // Note that it's required to check `tensor->data.raw != nullptr`. Otherwise + // the subgraph won't allocate memory for a dynamic tensor when its size + // is equal to the original tensor size. + if (tensor->data.raw != nullptr && + EqualArrayAndTfLiteIntArray(tensor->dims, new_size->size, + new_size->data)) { + // A number of clients assume |new_size| remains valid upon success, so + // swap it in as the new (but logically identical) tensor dims. + TfLiteIntArrayFree(tensor->dims); + tensor->dims = new_size; + return kTfLiteOk; + } + // Note here that context->impl_ is recovering the this pointer for an // instance of Interpreter to call into the member function ResizeTensorImpl // (this function is static). diff --git a/tensorflow/lite/delegates/flex/util.cc b/tensorflow/lite/delegates/flex/util.cc index 4279f4ae397..750de7397fa 100644 --- a/tensorflow/lite/delegates/flex/util.cc +++ b/tensorflow/lite/delegates/flex/util.cc @@ -62,6 +62,8 @@ TF_DataType GetTensorFlowDataType(TfLiteType type) { return TF_FLOAT; case kTfLiteFloat16: return TF_HALF; + case kTfLiteFloat64: + return TF_DOUBLE; case kTfLiteInt16: return TF_INT16; case kTfLiteInt32: diff --git a/tensorflow/lite/delegates/flex/whitelisted_flex_ops.cc b/tensorflow/lite/delegates/flex/whitelisted_flex_ops.cc index d40bd332965..639adf72fcf 100644 --- a/tensorflow/lite/delegates/flex/whitelisted_flex_ops.cc +++ b/tensorflow/lite/delegates/flex/whitelisted_flex_ops.cc @@ -117,6 +117,7 @@ bool IsWhitelistedFlexOp(const std::string& tensorflow_op_name) { "Exit", "Exp", "ExpandDims", + "ExtractImagePatches", "FakeQuantWithMinMaxArgs", "FakeQuantWithMinMaxArgsGradient", "FakeQuantWithMinMaxVars", diff --git a/tensorflow/lite/delegates/gpu/cl/gpu_api_delegate.cc b/tensorflow/lite/delegates/gpu/cl/gpu_api_delegate.cc index 0e2d046eba2..082d6c12985 100644 --- a/tensorflow/lite/delegates/gpu/cl/gpu_api_delegate.cc +++ b/tensorflow/lite/delegates/gpu/cl/gpu_api_delegate.cc @@ -90,7 +90,7 @@ class Delegate { absl::Status Prepare(TfLiteContext* context, const TfLiteDelegateParams* delegate_params) { // Extract TFLite delegate execution plan from the context and convert it - // into FlowGraph32. + // into GraphFloat32. GraphFloat32 graph; RETURN_IF_ERROR(BuildModel(context, delegate_params, &graph)); diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.cc b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.cc index 2573f2d7422..4c5e20abde3 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.cc @@ -73,7 +73,7 @@ std::string GetSrcValue(const TensorCodeGenerator& src_tensor, return c; } -std::string GenerateDepthWiseConvolutionCode( +std::string GenerateDepthwiseConvolutionCode( const OperationDef& op_def, bool stride_correction, const LinearStorage& biases, int channel_multiplier, bool weights_are_buffer, @@ -179,7 +179,7 @@ std::string GenerateDepthWiseConvolutionCode( } } // namespace -DepthWiseConvolution::DepthWiseConvolution( +DepthwiseConvolution::DepthwiseConvolution( const OperationDef& definition, const DepthwiseConvolution2DAttributes& attr, bool weights_are_buffer) : GPUOperation(definition), @@ -191,7 +191,7 @@ DepthWiseConvolution::DepthWiseConvolution( channel_multiplier_(attr.weights.shape.o), work_group_size_(8, 8, 1) {} -DepthWiseConvolution::DepthWiseConvolution(DepthWiseConvolution&& operation) +DepthwiseConvolution::DepthwiseConvolution(DepthwiseConvolution&& operation) : GPUOperation(std::move(operation)), weights_are_buffer_(operation.weights_are_buffer_), weights_tex2d_(std::move(operation.weights_tex2d_)), @@ -206,8 +206,8 @@ DepthWiseConvolution::DepthWiseConvolution(DepthWiseConvolution&& operation) kernel_(std::move(operation.kernel_)), work_group_size_(operation.work_group_size_) {} -DepthWiseConvolution& DepthWiseConvolution::operator=( - DepthWiseConvolution&& operation) { +DepthwiseConvolution& DepthwiseConvolution::operator=( + DepthwiseConvolution&& operation) { if (this != &operation) { std::swap(weights_are_buffer_, operation.weights_are_buffer_); weights_tex2d_ = std::move(operation.weights_tex2d_); @@ -226,11 +226,11 @@ DepthWiseConvolution& DepthWiseConvolution::operator=( return *this; } -absl::Status DepthWiseConvolution::Compile( +absl::Status DepthwiseConvolution::Compile( const CreationContext& creation_context) { const bool stride_correction = definition_.IsBatchSupported() && stride_.x != 1; - const auto code = GenerateDepthWiseConvolutionCode( + const auto code = GenerateDepthwiseConvolutionCode( definition_, stride_correction, biases_, channel_multiplier_, weights_are_buffer_, linked_operations_, *creation_context.device); return creation_context.cache->GetOrCreateCLKernel( @@ -238,7 +238,7 @@ absl::Status DepthWiseConvolution::Compile( *creation_context.device, &kernel_); } -absl::Status DepthWiseConvolution::BindArguments() { +absl::Status DepthwiseConvolution::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_)); @@ -259,29 +259,29 @@ absl::Status DepthWiseConvolution::BindArguments() { return absl::OkStatus(); } -int3 DepthWiseConvolution::GetGridSize() const { +int3 DepthwiseConvolution::GetGridSize() const { const int grid_x = dst_[0]->Width() * dst_[0]->Batch(); const int grid_y = dst_[0]->Height(); const int grid_z = dst_[0]->Slices(); return int3(grid_x, grid_y, grid_z); } -absl::Status DepthWiseConvolution::Tune(const TuningParameters& params) { +absl::Status DepthwiseConvolution::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -absl::Status DepthWiseConvolution::AddToQueue(CLCommandQueue* queue) { +absl::Status DepthwiseConvolution::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } -absl::Status CreateDepthWiseConvolution( +absl::Status CreateDepthwiseConvolution( const CreationContext& creation_context, const OperationDef& definition, const DepthwiseConvolution2DAttributes& attr, - DepthWiseConvolution* result) { + DepthwiseConvolution* result) { bool weights_are_buffer = creation_context.device->IsMali(); - *result = DepthWiseConvolution(definition, attr, weights_are_buffer); + *result = DepthwiseConvolution(definition, attr, weights_are_buffer); RETURN_IF_ERROR( result->UploadWeights(attr.weights, creation_context.context)); LinearStorageCreateInfo create_info; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.h b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.h index 1c1c55c1989..9d3e33630f8 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.h @@ -35,26 +35,26 @@ namespace tflite { namespace gpu { namespace cl { -class DepthWiseConvolution : public GPUOperation { +class DepthwiseConvolution : public GPUOperation { public: - DepthWiseConvolution() = default; + DepthwiseConvolution() = default; absl::Status AddToQueue(CLCommandQueue* queue) override; absl::Status Tune(const TuningParameters& params) override; absl::Status Compile(const CreationContext& creation_context) override; // Move only - DepthWiseConvolution(DepthWiseConvolution&& operation); - DepthWiseConvolution& operator=(DepthWiseConvolution&& operation); - DepthWiseConvolution(const DepthWiseConvolution&) = delete; - DepthWiseConvolution& operator=(const DepthWiseConvolution&) = delete; + DepthwiseConvolution(DepthwiseConvolution&& operation); + DepthwiseConvolution& operator=(DepthwiseConvolution&& operation); + DepthwiseConvolution(const DepthwiseConvolution&) = delete; + DepthwiseConvolution& operator=(const DepthwiseConvolution&) = delete; private: - friend absl::Status CreateDepthWiseConvolution( + friend absl::Status CreateDepthwiseConvolution( const CreationContext& creation_context, const OperationDef& definition, const DepthwiseConvolution2DAttributes& attr, - DepthWiseConvolution* result); - DepthWiseConvolution(const OperationDef& definition, + DepthwiseConvolution* result); + DepthwiseConvolution(const OperationDef& definition, const DepthwiseConvolution2DAttributes& attr, bool weights_are_buffer); template @@ -86,7 +86,7 @@ class DepthWiseConvolution : public GPUOperation { }; template -absl::Status DepthWiseConvolution::UploadWeights( +absl::Status DepthwiseConvolution::UploadWeights( const tflite::gpu::Tensor& weights, CLContext* context) { const int dst_channels = weights.shape.i * weights.shape.o; const int dst_depth = IntegralDivideRoundUp(dst_channels, 4); @@ -134,7 +134,7 @@ absl::Status DepthWiseConvolution::UploadWeights( } template -void DepthWiseConvolution::RearrangeWeightsData( +void DepthwiseConvolution::RearrangeWeightsData( const tflite::gpu::Tensor& weights, absl::Span dst) { const int dst_channels = weights.shape.i * weights.shape.o; const int dst_depth = IntegralDivideRoundUp(dst_channels, 4); @@ -162,9 +162,9 @@ void DepthWiseConvolution::RearrangeWeightsData( } } -absl::Status CreateDepthWiseConvolution( +absl::Status CreateDepthwiseConvolution( const CreationContext& creation_context, const OperationDef& definition, - const DepthwiseConvolution2DAttributes& attr, DepthWiseConvolution* result); + const DepthwiseConvolution2DAttributes& attr, DepthwiseConvolution* result); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3d.cc b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3d.cc index 5f1d529fba2..f9926a9f466 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3d.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3d.cc @@ -79,7 +79,7 @@ std::string GetSrcValue(const TensorCodeGenerator& src_tensor, return c; } -std::string GenerateDepthWiseConvolution3DCode( +std::string GenerateDepthwiseConvolution3DCode( const OperationDef& op_def, bool stride_correction, const LinearStorage& biases, int channel_multiplier, bool weights_are_buffer, @@ -208,7 +208,7 @@ std::string GenerateDepthWiseConvolution3DCode( } } // namespace -DepthWiseConvolution3D::DepthWiseConvolution3D( +DepthwiseConvolution3D::DepthwiseConvolution3D( const OperationDef& definition, const DepthwiseConvolution3DAttributes& attr, const CLDevice& device) : GPUOperation(definition), @@ -222,8 +222,8 @@ DepthWiseConvolution3D::DepthWiseConvolution3D( channel_multiplier_(attr.weights.shape.o), work_group_size_(8, 8, 1) {} -DepthWiseConvolution3D::DepthWiseConvolution3D( - DepthWiseConvolution3D&& operation) +DepthwiseConvolution3D::DepthwiseConvolution3D( + DepthwiseConvolution3D&& operation) : GPUOperation(std::move(operation)), weights_tex2d_(std::move(operation.weights_tex2d_)), weights_buf_(std::move(operation.weights_buf_)), @@ -237,8 +237,8 @@ DepthWiseConvolution3D::DepthWiseConvolution3D( kernel_(std::move(operation.kernel_)), work_group_size_(operation.work_group_size_) {} -DepthWiseConvolution3D& DepthWiseConvolution3D::operator=( - DepthWiseConvolution3D&& operation) { +DepthwiseConvolution3D& DepthwiseConvolution3D::operator=( + DepthwiseConvolution3D&& operation) { if (this != &operation) { weights_tex2d_ = std::move(operation.weights_tex2d_); weights_buf_ = std::move(operation.weights_buf_); @@ -256,11 +256,11 @@ DepthWiseConvolution3D& DepthWiseConvolution3D::operator=( return *this; } -absl::Status DepthWiseConvolution3D::Compile( +absl::Status DepthwiseConvolution3D::Compile( const CreationContext& creation_context) { const bool stride_correction = definition_.IsBatchSupported() && stride_.x != 1; - const auto code = GenerateDepthWiseConvolution3DCode( + const auto code = GenerateDepthwiseConvolution3DCode( definition_, stride_correction, biases_, channel_multiplier_, weights_are_buffer_, linked_operations_, *creation_context.device); return creation_context.cache->GetOrCreateCLKernel( @@ -268,7 +268,7 @@ absl::Status DepthWiseConvolution3D::Compile( *creation_context.device, &kernel_); } -absl::Status DepthWiseConvolution3D::BindArguments() { +absl::Status DepthwiseConvolution3D::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); if (weights_are_buffer_) { @@ -298,28 +298,28 @@ absl::Status DepthWiseConvolution3D::BindArguments() { return absl::OkStatus(); } -int3 DepthWiseConvolution3D::GetGridSize() const { +int3 DepthwiseConvolution3D::GetGridSize() const { const int grid_x = dst_[0]->Width() * dst_[0]->Batch(); const int grid_y = dst_[0]->Height(); const int grid_z = dst_[0]->Slices() * dst_[0]->Depth(); return int3(grid_x, grid_y, grid_z); } -absl::Status DepthWiseConvolution3D::Tune(const TuningParameters& params) { +absl::Status DepthwiseConvolution3D::Tune(const TuningParameters& params) { RETURN_IF_ERROR(BindArguments()); return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -absl::Status DepthWiseConvolution3D::AddToQueue(CLCommandQueue* queue) { +absl::Status DepthwiseConvolution3D::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } -absl::Status CreateDepthWiseConvolution3D( +absl::Status CreateDepthwiseConvolution3D( const CreationContext& creation_context, const OperationDef& definition, const DepthwiseConvolution3DAttributes& attr, - DepthWiseConvolution3D* result) { - *result = DepthWiseConvolution3D(definition, attr, *creation_context.device); + DepthwiseConvolution3D* result) { + *result = DepthwiseConvolution3D(definition, attr, *creation_context.device); RETURN_IF_ERROR( result->UploadWeights(attr.weights, creation_context.context)); LinearStorageCreateInfo create_info; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3d.h b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3d.h index 1d80d5ddca0..53e38a3e154 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3d.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3d.h @@ -35,26 +35,26 @@ namespace tflite { namespace gpu { namespace cl { -class DepthWiseConvolution3D : public GPUOperation { +class DepthwiseConvolution3D : public GPUOperation { public: - DepthWiseConvolution3D() = default; + DepthwiseConvolution3D() = default; absl::Status AddToQueue(CLCommandQueue* queue) override; absl::Status Tune(const TuningParameters& params) override; absl::Status Compile(const CreationContext& creation_context) override; // Move only - DepthWiseConvolution3D(DepthWiseConvolution3D&& operation); - DepthWiseConvolution3D& operator=(DepthWiseConvolution3D&& operation); - DepthWiseConvolution3D(const DepthWiseConvolution3D&) = delete; - DepthWiseConvolution3D& operator=(const DepthWiseConvolution3D&) = delete; + DepthwiseConvolution3D(DepthwiseConvolution3D&& operation); + DepthwiseConvolution3D& operator=(DepthwiseConvolution3D&& operation); + DepthwiseConvolution3D(const DepthwiseConvolution3D&) = delete; + DepthwiseConvolution3D& operator=(const DepthwiseConvolution3D&) = delete; private: - friend absl::Status CreateDepthWiseConvolution3D( + friend absl::Status CreateDepthwiseConvolution3D( const CreationContext& creation_context, const OperationDef& definition, const DepthwiseConvolution3DAttributes& attr, - DepthWiseConvolution3D* result); - DepthWiseConvolution3D(const OperationDef& definition, + DepthwiseConvolution3D* result); + DepthwiseConvolution3D(const OperationDef& definition, const DepthwiseConvolution3DAttributes& attr, const CLDevice& device); template @@ -85,7 +85,7 @@ class DepthWiseConvolution3D : public GPUOperation { }; template -absl::Status DepthWiseConvolution3D::UploadWeights( +absl::Status DepthwiseConvolution3D::UploadWeights( const tflite::gpu::Tensor& weights, CLContext* context) { const int dst_channels = weights.shape.i * weights.shape.o; const int dst_slices = IntegralDivideRoundUp(dst_channels, 4); @@ -127,7 +127,7 @@ absl::Status DepthWiseConvolution3D::UploadWeights( } template -void DepthWiseConvolution3D::RearrangeWeightsData( +void DepthwiseConvolution3D::RearrangeWeightsData( const tflite::gpu::Tensor& weights, absl::Span dst) { const int dst_channels = weights.shape.i * weights.shape.o; const int dst_slices = IntegralDivideRoundUp(dst_channels, 4); @@ -158,10 +158,10 @@ void DepthWiseConvolution3D::RearrangeWeightsData( } } -absl::Status CreateDepthWiseConvolution3D( +absl::Status CreateDepthwiseConvolution3D( const CreationContext& creation_context, const OperationDef& definition, const DepthwiseConvolution3DAttributes& attr, - DepthWiseConvolution3D* result); + DepthwiseConvolution3D* result); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3.cc b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3.cc index e4868be7ffc..348229e69f7 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3.cc @@ -28,7 +28,7 @@ namespace gpu { namespace cl { namespace { -std::string GenerateDepthWiseConvCode( +std::string GenerateDepthwiseConvCode( const OperationDef& op_def, const std::vector& linked_operations, const CLDevice& device, bool weights_are_buffer, bool local_mem_uploads) { @@ -266,14 +266,14 @@ std::string GenerateDepthWiseConvCode( } // namespace -DepthWiseConv3x3::DepthWiseConv3x3(const OperationDef& definition, +DepthwiseConv3x3::DepthwiseConv3x3(const OperationDef& definition, bool weights_are_buffer, bool local_mem_uploads) : GPUOperation(definition), weights_are_buffer_(weights_are_buffer), local_mem_uploads_(local_mem_uploads) {} -DepthWiseConv3x3::DepthWiseConv3x3(DepthWiseConv3x3&& operation) +DepthwiseConv3x3::DepthwiseConv3x3(DepthwiseConv3x3&& operation) : GPUOperation(std::move(operation)), weights_are_buffer_(operation.weights_are_buffer_), local_mem_uploads_(operation.local_mem_uploads_), @@ -283,7 +283,7 @@ DepthWiseConv3x3::DepthWiseConv3x3(DepthWiseConv3x3&& operation) kernel_(std::move(operation.kernel_)), work_group_size_(operation.work_group_size_) {} -DepthWiseConv3x3& DepthWiseConv3x3::operator=(DepthWiseConv3x3&& operation) { +DepthwiseConv3x3& DepthwiseConv3x3::operator=(DepthwiseConv3x3&& operation) { if (this != &operation) { std::swap(weights_are_buffer_, operation.weights_are_buffer_); std::swap(local_mem_uploads_, operation.local_mem_uploads_); @@ -297,9 +297,9 @@ DepthWiseConv3x3& DepthWiseConv3x3::operator=(DepthWiseConv3x3&& operation) { return *this; } -absl::Status DepthWiseConv3x3::Compile( +absl::Status DepthwiseConv3x3::Compile( const CreationContext& creation_context) { - std::string code = GenerateDepthWiseConvCode( + std::string code = GenerateDepthwiseConvCode( definition_, linked_operations_, *creation_context.device, weights_are_buffer_, local_mem_uploads_); std::vector options; @@ -312,7 +312,7 @@ absl::Status DepthWiseConv3x3::Compile( *creation_context.device, &kernel_); } -absl::Status DepthWiseConv3x3::BindArguments() { +absl::Status DepthwiseConv3x3::BindArguments() { kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_)); @@ -322,14 +322,14 @@ absl::Status DepthWiseConv3x3::BindArguments() { return absl::OkStatus(); } -int3 DepthWiseConv3x3::GetGridSize() const { +int3 DepthwiseConv3x3::GetGridSize() const { const int grid_x = IntegralDivideRoundUp(dst_[0]->Width(), 2); const int grid_y = IntegralDivideRoundUp(dst_[0]->Height(), 2); const int grid_z = dst_[0]->Slices(); return int3(grid_x, grid_y, grid_z); } -absl::Status DepthWiseConv3x3::Tune(const TuningParameters& params) { +absl::Status DepthwiseConv3x3::Tune(const TuningParameters& params) { if (local_mem_uploads_) { return absl::OkStatus(); } @@ -337,12 +337,12 @@ absl::Status DepthWiseConv3x3::Tune(const TuningParameters& params) { return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_); } -absl::Status DepthWiseConv3x3::AddToQueue(CLCommandQueue* queue) { +absl::Status DepthwiseConv3x3::AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(BindArguments()); return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_); } -bool IsDepthWiseConv3x3Supported(const DepthwiseConvolution2DAttributes& attr) { +bool IsDepthwiseConv3x3Supported(const DepthwiseConvolution2DAttributes& attr) { return attr.weights.shape.o == 1 && attr.dilations.w == 1 && attr.dilations.h == 1 && attr.weights.shape.w == 3 && attr.weights.shape.h == 3 && attr.strides.w == 1 && @@ -351,18 +351,18 @@ bool IsDepthWiseConv3x3Supported(const DepthwiseConvolution2DAttributes& attr) { attr.padding.appended.h == 1; } -absl::Status CreateDepthWiseConv3x3( +absl::Status CreateDepthwiseConv3x3( const CreationContext& creation_context, const OperationDef& definition, - const DepthwiseConvolution2DAttributes& attr, DepthWiseConv3x3* result) { - if (!IsDepthWiseConv3x3Supported(attr)) { + const DepthwiseConvolution2DAttributes& attr, DepthwiseConv3x3* result) { + if (!IsDepthwiseConv3x3Supported(attr)) { return absl::InvalidArgumentError( - "DepthWiseConv3x3 doesn't support this attributes"); + "DepthwiseConv3x3 doesn't support this attributes"); } bool weights_are_buffer = creation_context.device->IsPowerVR() || creation_context.device->IsMali(); bool local_mem_uploads = weights_are_buffer && creation_context.device->IsPowerVR(); - *result = DepthWiseConv3x3(definition, weights_are_buffer, local_mem_uploads); + *result = DepthwiseConv3x3(definition, weights_are_buffer, local_mem_uploads); return result->UploadWeightsAndBiases(attr.weights, attr.bias, creation_context.context); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3.h b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3.h index 769903adcb2..ac7c316df8b 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3.h @@ -35,31 +35,31 @@ namespace tflite { namespace gpu { namespace cl { -class DepthWiseConv3x3 : public GPUOperation { +class DepthwiseConv3x3 : public GPUOperation { public: - DepthWiseConv3x3() = default; + DepthwiseConv3x3() = default; absl::Status AddToQueue(CLCommandQueue* queue) override; absl::Status Tune(const TuningParameters& params) override; absl::Status Compile(const CreationContext& creation_context) override; // Move only - DepthWiseConv3x3(DepthWiseConv3x3&& operation); - DepthWiseConv3x3& operator=(DepthWiseConv3x3&& operation); - DepthWiseConv3x3(const DepthWiseConv3x3&) = delete; - DepthWiseConv3x3& operator=(const DepthWiseConv3x3&) = delete; + DepthwiseConv3x3(DepthwiseConv3x3&& operation); + DepthwiseConv3x3& operator=(DepthwiseConv3x3&& operation); + DepthwiseConv3x3(const DepthwiseConv3x3&) = delete; + DepthwiseConv3x3& operator=(const DepthwiseConv3x3&) = delete; private: - explicit DepthWiseConv3x3(const OperationDef& definition, + explicit DepthwiseConv3x3(const OperationDef& definition, bool weights_are_buffer, bool local_mem_uploads); template absl::Status UploadWeightsAndBiases( const tflite::gpu::Tensor& weights, const tflite::gpu::Tensor& biases, CLContext* context); - friend absl::Status CreateDepthWiseConv3x3( + friend absl::Status CreateDepthwiseConv3x3( const CreationContext& creation_context, const OperationDef& definition, - const DepthwiseConvolution2DAttributes& attr, DepthWiseConv3x3* result); + const DepthwiseConvolution2DAttributes& attr, DepthwiseConv3x3* result); template void RearrangeWeightsAndBiasesData( @@ -80,7 +80,7 @@ class DepthWiseConv3x3 : public GPUOperation { }; template -absl::Status DepthWiseConv3x3::UploadWeightsAndBiases( +absl::Status DepthwiseConv3x3::UploadWeightsAndBiases( const tflite::gpu::Tensor& weights, const tflite::gpu::Tensor& biases, CLContext* context) { const int src_depth = IntegralDivideRoundUp(weights.shape.i, 4); @@ -126,7 +126,7 @@ absl::Status DepthWiseConv3x3::UploadWeightsAndBiases( } template -void DepthWiseConv3x3::RearrangeWeightsAndBiasesData( +void DepthwiseConv3x3::RearrangeWeightsAndBiasesData( const tflite::gpu::Tensor& weights, const tflite::gpu::Tensor& biases, absl::Span dst) { const int src_depth = IntegralDivideRoundUp(weights.shape.i, 4); @@ -158,11 +158,11 @@ void DepthWiseConv3x3::RearrangeWeightsAndBiasesData( } } -bool IsDepthWiseConv3x3Supported(const DepthwiseConvolution2DAttributes& attr); +bool IsDepthwiseConv3x3Supported(const DepthwiseConvolution2DAttributes& attr); -absl::Status CreateDepthWiseConv3x3( +absl::Status CreateDepthwiseConv3x3( const CreationContext& creation_context, const OperationDef& definition, - const DepthwiseConvolution2DAttributes& attr, DepthWiseConv3x3* result); + const DepthwiseConvolution2DAttributes& attr, DepthwiseConv3x3* result); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3_test.cc index 6b33cdf90f2..a88b05bb8b3 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3_test.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_3x3_test.cc @@ -31,7 +31,7 @@ namespace gpu { namespace cl { namespace { -TEST_F(OpenCLOperationTest, DepthWiseConv3x3SimpleWeights) { +TEST_F(OpenCLOperationTest, DepthwiseConv3x3SimpleWeights) { TensorFloat32 src_tensor; src_tensor.shape = BHWC(1, 2, 2, 2); src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; @@ -56,9 +56,9 @@ TEST_F(OpenCLOperationTest, DepthWiseConv3x3SimpleWeights) { op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); TensorFloat32 dst_tensor; - DepthWiseConv3x3 operation; + DepthwiseConv3x3 operation; ASSERT_OK( - CreateDepthWiseConv3x3(creation_context_, op_def, attr, &operation)); + CreateDepthwiseConv3x3(creation_context_, op_def, attr, &operation)); ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation, BHWC(1, 2, 2, 2), &dst_tensor)); EXPECT_THAT(dst_tensor.data, @@ -68,7 +68,7 @@ TEST_F(OpenCLOperationTest, DepthWiseConv3x3SimpleWeights) { } } -TEST_F(OpenCLOperationTest, DepthWiseConv3x3) { +TEST_F(OpenCLOperationTest, DepthwiseConv3x3) { TensorFloat32 src_tensor; src_tensor.shape = BHWC(1, 2, 2, 2); src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; @@ -93,9 +93,9 @@ TEST_F(OpenCLOperationTest, DepthWiseConv3x3) { op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); TensorFloat32 dst_tensor; - DepthWiseConv3x3 operation; + DepthwiseConv3x3 operation; ASSERT_OK( - CreateDepthWiseConv3x3(creation_context_, op_def, attr, &operation)); + CreateDepthwiseConv3x3(creation_context_, op_def, attr, &operation)); ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation, BHWC(1, 2, 2, 2), &dst_tensor)); EXPECT_THAT(dst_tensor.data, diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_test.cc index e69b3d99309..ac010e7d572 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_test.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_test.cc @@ -31,7 +31,7 @@ namespace gpu { namespace cl { namespace { -TEST_F(OpenCLOperationTest, DepthWiseConvSimpleWeights) { +TEST_F(OpenCLOperationTest, DepthwiseConvSimpleWeights) { TensorFloat32 src_tensor; src_tensor.shape = BHWC(1, 2, 2, 2); src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; @@ -55,8 +55,8 @@ TEST_F(OpenCLOperationTest, DepthWiseConvSimpleWeights) { op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); TensorFloat32 dst_tensor; - DepthWiseConvolution operation; - ASSERT_OK(CreateDepthWiseConvolution(creation_context_, op_def, attr, + DepthwiseConvolution operation; + ASSERT_OK(CreateDepthwiseConvolution(creation_context_, op_def, attr, &operation)); ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation, BHWC(1, 2, 2, 2), &dst_tensor)); @@ -67,7 +67,7 @@ TEST_F(OpenCLOperationTest, DepthWiseConvSimpleWeights) { } } -TEST_F(OpenCLOperationTest, DepthWiseConvNoMultiplier) { +TEST_F(OpenCLOperationTest, DepthwiseConvNoMultiplier) { TensorFloat32 src_tensor; src_tensor.shape = BHWC(1, 2, 2, 2); src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; @@ -91,8 +91,8 @@ TEST_F(OpenCLOperationTest, DepthWiseConvNoMultiplier) { op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); TensorFloat32 dst_tensor; - DepthWiseConvolution operation; - ASSERT_OK(CreateDepthWiseConvolution(creation_context_, op_def, attr, + DepthwiseConvolution operation; + ASSERT_OK(CreateDepthwiseConvolution(creation_context_, op_def, attr, &operation)); ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation, BHWC(1, 2, 2, 2), &dst_tensor)); @@ -103,7 +103,7 @@ TEST_F(OpenCLOperationTest, DepthWiseConvNoMultiplier) { } } -TEST_F(OpenCLOperationTest, DepthWiseConvMultiplier2) { +TEST_F(OpenCLOperationTest, DepthwiseConvMultiplier2) { TensorFloat32 src_tensor; src_tensor.shape = BHWC(1, 2, 2, 2); src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; @@ -128,8 +128,8 @@ TEST_F(OpenCLOperationTest, DepthWiseConvMultiplier2) { op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); TensorFloat32 dst_tensor; - DepthWiseConvolution operation; - ASSERT_OK(CreateDepthWiseConvolution(creation_context_, op_def, attr, + DepthwiseConvolution operation; + ASSERT_OK(CreateDepthwiseConvolution(creation_context_, op_def, attr, &operation)); ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation, BHWC(1, 2, 2, 4), &dst_tensor)); diff --git a/tensorflow/lite/delegates/gpu/cl/precision.h b/tensorflow/lite/delegates/gpu/cl/precision.h index f25db33673d..10afcd661c1 100644 --- a/tensorflow/lite/delegates/gpu/cl/precision.h +++ b/tensorflow/lite/delegates/gpu/cl/precision.h @@ -28,7 +28,7 @@ enum class CalculationsPrecision { F32, F32_F16, F16 }; // F32 - all data and all math ops in F32 // F16 - all data and all math ops in F16 // F32_F16 - as F16, but some operations (Convolution, -// DepthWiseConvolution, FullyConnected, ConvolutionTransposed) +// DepthwiseConvolution, FullyConnected, ConvolutionTransposed) // have accumulator in F32 and usually it calculates 4 mads in F16, sum them, // than converts this partial sum to F32 and add to accumulator. diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.cc index 72f31154b4b..9ae87c6ba07 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.cc +++ b/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.cc @@ -30,16 +30,16 @@ absl::Status SelectDWConvolutionAdreno( const DepthwiseConvolution2DAttributes& attr, const CreationContext& creation_context, const OperationDef& op_def, std::unique_ptr* ptr) { - if (!op_def.IsBatchSupported() && IsDepthWiseConv3x3Supported(attr)) { - DepthWiseConv3x3 dw_conv; + if (!op_def.IsBatchSupported() && IsDepthwiseConv3x3Supported(attr)) { + DepthwiseConv3x3 dw_conv; RETURN_IF_ERROR( - CreateDepthWiseConv3x3(creation_context, op_def, attr, &dw_conv)); - *ptr = absl::make_unique(std::move(dw_conv)); + CreateDepthwiseConv3x3(creation_context, op_def, attr, &dw_conv)); + *ptr = absl::make_unique(std::move(dw_conv)); } else { - DepthWiseConvolution dw_conv; + DepthwiseConvolution dw_conv; RETURN_IF_ERROR( - CreateDepthWiseConvolution(creation_context, op_def, attr, &dw_conv)); - *ptr = absl::make_unique(std::move(dw_conv)); + CreateDepthwiseConvolution(creation_context, op_def, attr, &dw_conv)); + *ptr = absl::make_unique(std::move(dw_conv)); } return absl::OkStatus(); } @@ -48,16 +48,16 @@ absl::Status SelectDWConvolutionPowerVR( const DepthwiseConvolution2DAttributes& attr, const CreationContext& creation_context, const OperationDef& op_def, std::unique_ptr* ptr) { - if (!op_def.IsBatchSupported() && IsDepthWiseConv3x3Supported(attr)) { - DepthWiseConv3x3 dw_conv; + if (!op_def.IsBatchSupported() && IsDepthwiseConv3x3Supported(attr)) { + DepthwiseConv3x3 dw_conv; RETURN_IF_ERROR( - CreateDepthWiseConv3x3(creation_context, op_def, attr, &dw_conv)); - *ptr = absl::make_unique(std::move(dw_conv)); + CreateDepthwiseConv3x3(creation_context, op_def, attr, &dw_conv)); + *ptr = absl::make_unique(std::move(dw_conv)); } else { - DepthWiseConvolution dw_conv; + DepthwiseConvolution dw_conv; RETURN_IF_ERROR( - CreateDepthWiseConvolution(creation_context, op_def, attr, &dw_conv)); - *ptr = absl::make_unique(std::move(dw_conv)); + CreateDepthwiseConvolution(creation_context, op_def, attr, &dw_conv)); + *ptr = absl::make_unique(std::move(dw_conv)); } return absl::OkStatus(); } @@ -70,18 +70,18 @@ absl::Status SelectDWConvolutionMali( bool buffer_type = storage_type == TensorStorageType::BUFFER || storage_type == TensorStorageType::IMAGE_BUFFER; MaliInfo mali_info = creation_context.device->GetInfo().mali_info; - if (IsDepthWiseConv3x3Supported(attr) && !mali_info.IsMidgard() && + if (IsDepthwiseConv3x3Supported(attr) && !mali_info.IsMidgard() && !buffer_type && !op_def.IsBatchSupported() && op_def.precision != CalculationsPrecision::F32) { - DepthWiseConv3x3 dw_conv; + DepthwiseConv3x3 dw_conv; RETURN_IF_ERROR( - CreateDepthWiseConv3x3(creation_context, op_def, attr, &dw_conv)); - *ptr = absl::make_unique(std::move(dw_conv)); + CreateDepthwiseConv3x3(creation_context, op_def, attr, &dw_conv)); + *ptr = absl::make_unique(std::move(dw_conv)); } else { - DepthWiseConvolution dw_conv; + DepthwiseConvolution dw_conv; RETURN_IF_ERROR( - CreateDepthWiseConvolution(creation_context, op_def, attr, &dw_conv)); - *ptr = absl::make_unique(std::move(dw_conv)); + CreateDepthwiseConvolution(creation_context, op_def, attr, &dw_conv)); + *ptr = absl::make_unique(std::move(dw_conv)); } return absl::OkStatus(); } diff --git a/tensorflow/lite/delegates/gpu/delegate.cc b/tensorflow/lite/delegates/gpu/delegate.cc index 81974a1db68..7bfc977f7af 100644 --- a/tensorflow/lite/delegates/gpu/delegate.cc +++ b/tensorflow/lite/delegates/gpu/delegate.cc @@ -98,7 +98,7 @@ class DelegateKernel { thread_id_prepare_ = std::this_thread::get_id(); // Extract TFLite delegate execution plan from the context and convert it - // into FlowGraph32. + // into GraphFloat32. GraphFloat32 graph; RETURN_IF_ERROR(BuildFinalModel(context, delegate_params, &graph)); diff --git a/tensorflow/lite/delegates/gpu/gl_delegate.cc b/tensorflow/lite/delegates/gpu/gl_delegate.cc index 5ebefb4a6eb..45e791b9d45 100644 --- a/tensorflow/lite/delegates/gpu/gl_delegate.cc +++ b/tensorflow/lite/delegates/gpu/gl_delegate.cc @@ -130,7 +130,7 @@ class Delegate { absl::Status Prepare(TfLiteContext* context, const TfLiteDelegateParams* delegate_params) { // Extract TFLite delegate execution plan from the context and convert it - // into FlowGraph32. + // into GraphFloat32. GraphFloat32 graph; RETURN_IF_ERROR(BuildModel(context, delegate_params, &graph)); diff --git a/tensorflow/lite/delegates/gpu/metal/api.cc b/tensorflow/lite/delegates/gpu/metal/api.cc index d9c8a369592..4246b1678c8 100644 --- a/tensorflow/lite/delegates/gpu/metal/api.cc +++ b/tensorflow/lite/delegates/gpu/metal/api.cc @@ -129,12 +129,12 @@ std::vector SelectReshape( } } -std::vector SelectSoftmax(const GraphFloat32& graph, - int id, ValueId input_id, - ValueId output_id) { +std::vector SelectSoftmax( + const GraphFloat32& graph, int id, ValueId input_id, ValueId output_id, + const DeviceInfo& device_info) { const auto src_shape = graph.FindInputs(id)[0]->tensor.shape; if (src_shape.w == 1 && src_shape.h == 1) { - return Softmax1x1(id, input_id, output_id, src_shape.c); + return Softmax1x1(id, input_id, output_id, device_info, src_shape.c); } else { return Softmax(id, input_id, output_id, src_shape.c); } @@ -146,6 +146,28 @@ std::vector SelectSpaceToDepth( return SpaceToDepth(id, input_id, output_id, attr); } +std::vector SelectWinograd4x4To36( + int id, ValueId input_id, ValueId output_id, + const Winograd4x4To36Attributes& attr, const DeviceInfo& device_info, + const metal::RuntimeOptions& options) { + if (device_info.IsAppleGPU()) { + return Winograd4x4To36(id, input_id, output_id, attr); + } else { + return Winograd4x4To36TileX6(id, input_id, output_id, attr, options); + } +} + +std::vector SelectWinograd36To4x4( + int id, ValueId input_id, ValueId output_id, + const Winograd36To4x4Attributes& attr, const DeviceInfo& device_info, + const metal::RuntimeOptions& options) { + if (device_info.IsAppleGPU()) { + return Winograd36To4x4(id, input_id, output_id, options, attr); + } else { + return Winograd36To4x4Tile4x1(id, input_id, output_id, options, attr); + } +} + bool IsSuitableForWinograd4x4To6x6(const Convolution2DAttributes& attr, const BHWC& dst_shape) { const int tiles_x = IntegralDivideRoundUp(dst_shape.w, 4); @@ -217,8 +239,8 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node, wino_up_attr.padding = attr.padding; (*last_node_id) += 1; int value_id = *last_value_id + 1; - *tasks = - Winograd4x4To36(*last_node_id, inputs[0], value_id, wino_up_attr); + *tasks = SelectWinograd4x4To36(*last_node_id, inputs[0], value_id, + wino_up_attr, device_info, options); BHWC conv_shape{dst_shape.b, 36, tiles_x * tiles_y, dst_shape.c}; (*last_node_id) += 1; @@ -231,8 +253,8 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node, wino_down_attr.output_shape = dst_shape; wino_down_attr.biases = attr.bias; (*last_node_id) += 1; - auto t2 = Winograd36To4x4(*last_node_id, value_id + 1, outputs[0], - options, wino_down_attr); + auto t2 = SelectWinograd36To4x4(*last_node_id, value_id + 1, outputs[0], + wino_down_attr, device_info, options); tasks->insert(tasks->end(), t2.begin(), t2.end()); (*last_value_id) += 2; } else { @@ -334,7 +356,8 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node, return absl::UnimplementedError( "Softmax supports only CHANNELS dimension"); } - *tasks = SelectSoftmax(graph, node_id, inputs[0], outputs[0]); + *tasks = + SelectSoftmax(graph, node_id, inputs[0], outputs[0], device_info); break; } case OperationType::SPACE_TO_DEPTH: diff --git a/tensorflow/lite/delegates/gpu/metal/compute_task.mm b/tensorflow/lite/delegates/gpu/metal/compute_task.mm index d3e3466ca6f..88be8676651 100644 --- a/tensorflow/lite/delegates/gpu/metal/compute_task.mm +++ b/tensorflow/lite/delegates/gpu/metal/compute_task.mm @@ -111,7 +111,7 @@ using ::tflite::gpu::ValueId; @"TO_ACCUM2_TYPE" : toAccumulatorType2, @"TO_ACCUM3_TYPE" : toAccumulatorType3, @"TO_ACCUM4_TYPE" : toAccumulatorType4, - @"BARRIER" : barrier, + @"SIMDGROUP_BARRIER" : barrier, }; NSString* code = [NSString stringWithCString:desc->shader_source.c_str() diff --git a/tensorflow/lite/delegates/gpu/metal/environment.h b/tensorflow/lite/delegates/gpu/metal/environment.h index 732dbe1d18b..14c8860dee2 100644 --- a/tensorflow/lite/delegates/gpu/metal/environment.h +++ b/tensorflow/lite/delegates/gpu/metal/environment.h @@ -57,6 +57,9 @@ struct AppleGPUInfo { // floating point rounding mode bool IsRoundToNearestSupported() const; + // returns true if device have fixed wave size equal to 32 + bool IsWaveSizeEqualTo32() const; + int GetComputeUnitsCount() const; }; @@ -75,6 +78,9 @@ struct DeviceInfo { // floating point rounding mode bool IsRoundToNearestSupported() const; + // returns true if device have fixed wave size equal to 32 + bool IsWaveSizeEqualTo32() const; + int GetComputeUnitsCount() const; }; diff --git a/tensorflow/lite/delegates/gpu/metal/environment.mm b/tensorflow/lite/delegates/gpu/metal/environment.mm index 78376b70c8c..f08a9beef47 100644 --- a/tensorflow/lite/delegates/gpu/metal/environment.mm +++ b/tensorflow/lite/delegates/gpu/metal/environment.mm @@ -78,6 +78,10 @@ bool AppleGPUInfo::IsRoundToNearestSupported() const { return IsBionic(); } +bool AppleGPUInfo::IsWaveSizeEqualTo32() const { + return true; +} + int AppleGPUInfo::GetComputeUnitsCount() const { switch (gpu_type) { case AppleGPU::kA7: @@ -135,6 +139,14 @@ bool DeviceInfo::IsRoundToNearestSupported() const { } } +bool DeviceInfo::IsWaveSizeEqualTo32() const { + if (vendor == Vendor::kApple) { + return apple_info.IsWaveSizeEqualTo32(); + } else { + return false; + } +} + int DeviceInfo::GetComputeUnitsCount() const { if (vendor == Vendor::kApple) { return apple_info.GetComputeUnitsCount(); diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD index f5ac216a9b1..a1052b8adf4 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD +++ b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD @@ -707,6 +707,7 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:types", "//tensorflow/lite/delegates/gpu/common:util", "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor", + "//tensorflow/lite/delegates/gpu/metal:environment", "//tensorflow/lite/delegates/gpu/metal:runtime_options", "@com_google_absl//absl/strings", ], diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/conv.cc b/tensorflow/lite/delegates/gpu/metal/kernels/conv.cc index 8f63fab7cf5..f9ff87e75e2 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/conv.cc +++ b/tensorflow/lite/delegates/gpu/metal/kernels/conv.cc @@ -397,11 +397,11 @@ kernel void ComputeFunction( const int total_work_items = params.work_group_size.x * params.work_group_size.y * params.work_group_size.z; - c += " BARRIER(mem_flags::mem_none);\n"; + c += " SIMDGROUP_BARRIER(mem_flags::mem_none);\n"; c += GenerateUploadByThreads("weights_cache", "tmp", /*global_offset_name*/ "", "tid", total_work_items, local_mem_size); - c += " BARRIER(mem_flags::mem_threadgroup);\n"; + c += " SIMDGROUP_BARRIER(mem_flags::mem_threadgroup);\n"; } else if (use_simd_broadcast) { int parts = local_mem_size / simd_size; int reminder = local_mem_size % simd_size; @@ -920,7 +920,7 @@ ConvParams GetConvParamsForIntel(const Convolution2DAttributes& attr, const int dst_slices = IntegralDivideRoundUp(dst_shape.c, 4); const int src_slices = IntegralDivideRoundUp(attr.weights.shape.i, 4); ConvParams params; - params.weights_upload_type = WeightsUploadType::PRIVATE_MEM_SIMD16_BROADCAST; + params.weights_upload_type = WeightsUploadType::PRIVATE_MEM_SIMD8_BROADCAST; params.x_kernel_is_1 = IsKernelXIs1(attr); params.y_kernel_is_1 = IsKernelYIs1(attr); params.src_depth_loop_size = 1; @@ -1132,8 +1132,7 @@ std::vector ConvolutionWino4x4To6x6( } } else if (device_info.IsIntelGPU()) { params.weight_layout = WeightsInnerBlockLayout::I4O4; - params.weights_upload_type = - WeightsUploadType::PRIVATE_MEM_SIMD16_BROADCAST; + params.weights_upload_type = WeightsUploadType::PRIVATE_MEM_SIMD8_BROADCAST; params.work_group_size = int3(16, 1, 1); params.block_size = int3(1, 1, 4); } else if (device_info.IsAMDGPU()) { diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/conv_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/conv_test.mm index bb8121288a9..0291cd7e856 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/conv_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/conv_test.mm @@ -272,8 +272,6 @@ using ::tflite::gpu::metal::SingleOpModel; attr.padding.appended.w - 2; int new_height = src_shape.h + attr.padding.prepended.h + attr.padding.appended.h - 2; - std::cout << dst_shape.w << " vs " << new_width << std::endl; - std::cout << dst_shape.h << " vs " << new_height << std::endl; BHWC conv_shape; conv_shape.b = dst_shape.b; conv_shape.h = 36; diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.cc b/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.cc index 283b03ce707..331f3cc051e 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.cc +++ b/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.cc @@ -37,8 +37,14 @@ namespace gpu { namespace metal { namespace { -std::string GetFullyConnectedCode(bool shared_memory, int src_channels, - int dst_channels) { +std::string GetFullyConnectedCode(const DeviceInfo& device_info, + int src_channels, int dst_channels) { + bool shared_memory = + device_info.IsAppleGPU() && + device_info.apple_info.IsLocalMemoryPreferredOverGlobal(); + const std::string barrier = device_info.IsWaveSizeEqualTo32() + ? "SIMDGROUP_BARRIER" + : "threadgroup_barrier"; const int src_depth = IntegralDivideRoundUp(src_channels, 4); std::stringstream code; code << R"( @@ -67,11 +73,11 @@ std::string GetFullyConnectedCode(bool shared_memory, int src_channels, for (int j = 0; j < $0; ++j) { local_vector[tid_index] = j * 32 + tid_index >= params.src_depth ? FLT4(0.0f) : vector[j * 32 + tid_index]; - BARRIER(mem_flags::mem_threadgroup); + $1(mem_flags::mem_threadgroup); for (uint i = 0, counter = j * 32 + tid.y * 8; i < 8; ++i, ++counter) { summa += dot(local_vector[tid.y * 8 + i], matrix[counter * params.dst_channels + ugid.x]); } - BARRIER(mem_flags::mem_none); + $1(mem_flags::mem_none); } )"; } else { @@ -92,19 +98,19 @@ std::string GetFullyConnectedCode(bool shared_memory, int src_channels, threadgroup float temp[8][4]; temp[tid.x][tid.y] = summa; - BARRIER(mem_flags::mem_threadgroup); + $1(mem_flags::mem_threadgroup); if (tid.y == 0) { summa += temp[tid.x][1]; summa += temp[tid.x][2]; summa += temp[tid.x][3]; temp[tid.x][0] = summa; } - BARRIER(mem_flags::mem_threadgroup); + $1(mem_flags::mem_threadgroup); if (tid.y == 0 && tid.x % 4 == 0 && ugid.x < params.out_channels) { const int linear_index = ugid.x / 4; FLT4 value = FLT4(temp[tid.x][0], temp[tid.x + 1][0], temp[tid.x + 2][0], temp[tid.x + 3][0]) + biases[linear_index]; - uint3 gid = uint3(1u, 1u, uint(linear_index)); + uint3 gid = uint3(0u, 0u, uint(linear_index)); $$2 result[linear_index] = value; } @@ -113,7 +119,7 @@ std::string GetFullyConnectedCode(bool shared_memory, int src_channels, const int src_depth_sub_groups = shared_memory ? IntegralDivideRoundUp(src_depth, 32) : IntegralDivideRoundUp(src_depth, 4); - return absl::Substitute(code.str(), src_depth_sub_groups); + return absl::Substitute(code.str(), src_depth_sub_groups, barrier); } } // namespace @@ -124,9 +130,8 @@ std::vector FullyConnected( auto desc = std::make_shared(); desc->id = id; desc->is_linkable = false; - bool shared = device_info.apple_info.IsLocalMemoryPreferredOverGlobal(); - desc->shader_source = - GetFullyConnectedCode(shared, attr.weights.shape.i, attr.weights.shape.o); + desc->shader_source = GetFullyConnectedCode(device_info, attr.weights.shape.i, + attr.weights.shape.o); desc->input_buffers = { {input_id, "device FLT4* const vector"}, @@ -138,8 +143,11 @@ std::vector FullyConnected( return CalculateOutputShape(buffers.find(input_id)->second, attr); }}; + bool shared_memory = + device_info.IsAppleGPU() && + device_info.apple_info.IsLocalMemoryPreferredOverGlobal(); const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4); - const int src_depth_aligned = AlignByN(src_depth, shared ? 32 : 4); + const int src_depth_aligned = AlignByN(src_depth, shared_memory ? 32 : 4); const int dst_channels_aligned = AlignByN(attr.weights.shape.o, 8); int counter = 0; diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/softmax.cc b/tensorflow/lite/delegates/gpu/metal/kernels/softmax.cc index 3b4fbea4aef..0ed2e0650e1 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/softmax.cc +++ b/tensorflow/lite/delegates/gpu/metal/kernels/softmax.cc @@ -25,13 +25,17 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/types.h" #include "tensorflow/lite/delegates/gpu/common/util.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" +#include "tensorflow/lite/delegates/gpu/metal/environment.h" #include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" namespace tflite { namespace gpu { namespace metal { namespace { -std::string GetSoftmax1x1Code() { +std::string GetSoftmax1x1Code(const DeviceInfo& device_info) { + const std::string barrier = device_info.IsWaveSizeEqualTo32() + ? "SIMDGROUP_BARRIER" + : "threadgroup_barrier"; std::string code = R"( #include using namespace metal; @@ -63,7 +67,9 @@ kernel void ComputeFunction($1 threadgroup float4 tmp[8]; threadgroup float* tmpx1 = (threadgroup float*)tmp; tmpx1[tid] = sum; - BARRIER(mem_flags::mem_threadgroup); +)"; + code += " " + barrier + "(mem_flags::mem_threadgroup);\n"; + code += R"( if (tid == 0) { sum = dot(float4(1.0f), tmp[0]); sum += dot(float4(1.0f), tmp[1]); @@ -75,7 +81,9 @@ kernel void ComputeFunction($1 sum += dot(float4(1.0f), tmp[7]); tmpx1[0] = 1.0 / sum; } - BARRIER(mem_flags::mem_threadgroup); +)"; + code += " " + barrier + "(mem_flags::mem_threadgroup);\n"; + code += R"( sum = tmpx1[0]; offset = 0; @@ -171,11 +179,12 @@ std::vector Softmax(int id, ValueId input_id, std::vector Softmax1x1(int id, ValueId input_id, ValueId output_id, + const DeviceInfo& device_info, int channels_count) { auto desc = std::make_shared(); desc->id = id; desc->is_linkable = false; - desc->shader_source = GetSoftmax1x1Code(); + desc->shader_source = GetSoftmax1x1Code(device_info); desc->input_buffers = { {input_id, "device FLT4* const src_buffer"}, diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/softmax.h b/tensorflow/lite/delegates/gpu/metal/kernels/softmax.h index 24fa38e8f57..2745d1f0c3e 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/softmax.h +++ b/tensorflow/lite/delegates/gpu/metal/kernels/softmax.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/model.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" +#include "tensorflow/lite/delegates/gpu/metal/environment.h" #include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" namespace tflite { @@ -35,6 +36,7 @@ std::vector Softmax(int id, ValueId input_id, // We have this case in MobilenetV1/V2. std::vector Softmax1x1(int id, ValueId input_id, ValueId output_id, + const DeviceInfo& device_info, int channels_count); } // namespace metal diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.cc b/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.cc index 1b6e6963fb5..56630c5d2af 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.cc +++ b/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.cc @@ -275,7 +275,18 @@ std::string GetDeconvolutionShared(const ConvolutionTransposedAttributes& attr, src_local_size_x, src_local_size_y, workgroup_x, workgroup_y); } -std::string GetDeconvolution4x4(const int2& block_size, bool use_local_mem) { +std::string GetDeconvolution4x4(const int2& block_size, + const DeviceInfo& device_info) { + bool use_local_mem = false; + if (device_info.IsAppleGPU() && device_info.apple_info.IsBionic()) { + use_local_mem = true; + } + if (device_info.IsIntelGPU()) { + use_local_mem = true; + } + const std::string barrier = device_info.IsWaveSizeEqualTo32() + ? "SIMDGROUP_BARRIER" + : "threadgroup_barrier"; std::string c = R"( #include using namespace metal; @@ -349,7 +360,7 @@ std::string GetDeconvolution4x4(const int2& block_size, bool use_local_mem) { } c += " for (int s = 0; s < params.src_size.z; ++s) {\n"; if (use_local_mem) { - c += " BARRIER(mem_flags::mem_none);\n"; + c += " " + barrier + "(mem_flags::mem_none);\n"; c += " weights_cache[local_id] = filters[f_offset + local_id];\n"; c += " weights_cache[local_id + 32] = filters[f_offset + local_id + " "32];\n"; @@ -365,7 +376,7 @@ std::string GetDeconvolution4x4(const int2& block_size, bool use_local_mem) { } c += " f_offset += 64;\n"; if (use_local_mem) { - c += " BARRIER(mem_flags::mem_threadgroup);\n"; + c += " " + barrier + "(mem_flags::mem_threadgroup);\n"; } for (int i = 0; i < 16; ++i) { const int result_sub_pixel_id = i % 4; @@ -595,12 +606,20 @@ std::vector ConvolutionTransposed4x4( desc->id = id; desc->is_linkable = false; - const bool recommended_2x = - device_info.apple_info.IsBionic() && - options.storage_precision == RuntimeOptions::Precision::FP16; - const bool use_local_mem = !device_info.apple_info.IsBionic(); + bool recommended_2x = false; + if (device_info.IsAppleGPU()) { + if (device_info.apple_info.IsBionic() && + options.storage_precision == RuntimeOptions::Precision::FP16) { + recommended_2x = true; + } + } else { + if (options.storage_precision == RuntimeOptions::Precision::FP16) { + recommended_2x = true; + } + } + const int2 block_size(recommended_2x ? 2 : 1, 1); - desc->shader_source = GetDeconvolution4x4(block_size, use_local_mem); + desc->shader_source = GetDeconvolution4x4(block_size, device_info); desc->input_buffers = { {input_id, "device FLT4* const src_buffer"}, diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/winograd.cc b/tensorflow/lite/delegates/gpu/metal/kernels/winograd.cc index f1c9d75e62a..6d68e9e6704 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/winograd.cc +++ b/tensorflow/lite/delegates/gpu/metal/kernels/winograd.cc @@ -129,6 +129,134 @@ kernel void ComputeFunction($1 return c; } +std::string GetKernelWinograd4x4To36TileX6() { + std::string c = R"( +#include +using namespace metal; + +struct uniforms { + int4 src_size; + int4 dst_size; + int2 padding; + int2 tiles; +}; +)"; + auto bt_mat = BtMatrixForWinograd4x4To6x6(); + c += "constant FLT Bt[36] = {\n"; + for (int y = 0; y < 6; ++y) { + c += "\t"; + for (int x = 0; x < 6; ++x) { + c += absl::StrFormat("%.10f", bt_mat[y * 6 + x]) + "f, "; + } + c += "\n"; + } + c += "};\n"; + c += R"( + +$0 + +kernel void ComputeFunction($1 + uint3 global_ids[[thread_position_in_grid]]) +{ + int DST_X = global_ids.x; + int DST_Y = global_ids.y; + int DST_Z = global_ids.z; + if (DST_X >= U.tiles.y || DST_Y >= 6 || DST_Z >= U.dst_size.z) { + return; + } + int tile_x = (DST_X % U.tiles.x) * 4; + int tile_y = (DST_X / U.tiles.x) * 4; + FLT4 I0, I1, I2, I3, I4, I5; + FLT bt_ar[6]; + FLT4 t0 = bt_arr[DST_Y * 2 + 0]; + FLT4 t1 = bt_arr[DST_Y * 2 + 1]; + DST_Y *= 6; + bt_ar[0] = t0.x; + bt_ar[1] = t0.y; + bt_ar[2] = t0.z; + bt_ar[3] = t0.w; + bt_ar[4] = t1.x; + bt_ar[5] = t1.y; +)"; + auto read_src = [&](const std::string& src, const std::string& xs) { + c += " FLT4 " + src + " = src_buffer[src_a_" + xs + " + offset] * m" + + xs + "_x;\n"; + }; + for (int x = 0; x < 6; ++x) { + const std::string xs = std::to_string(x); + c += " int xc" + xs + " = tile_x + U.padding.x + " + xs + ";\n"; + c += " FLT m" + xs + "_x = xc" + xs + " >= 0 && xc" + xs + + " < U.src_size.x;\n"; + c += " bool inx" + xs + " = (xc" + xs + " >= 0 && xc" + xs + + " < U.src_size.x);\n"; + c += " xc" + xs + " = clamp(xc" + xs + ", 0, U.src_size.x - 1);\n"; + c += " int src_a_" + xs + " = DST_Z * U.src_size.x * U.src_size.y + xc" + + xs + ";\n"; + } + c += " {\n"; + c += " int yc = tile_y + U.padding.y;\n"; + c += " bool iny = (yc >= 0 && yc < U.src_size.y);\n"; + c += " yc = clamp(yc, 0, U.src_size.y - 1);\n"; + c += " int offset = yc * U.src_size.x;\n"; + c += " FLT bt = bt_ar[0] * FLT(iny);\n"; + for (int x = 0; x < 6; ++x) { + const std::string xs = std::to_string(x); + const std::string src = "src" + xs; + read_src(src, xs); + c += " I" + xs + " = bt * " + src + ";\n"; + } + c += " }\n"; + for (int y = 1; y < 6; ++y) { + const std::string ys = std::to_string(y); + c += " {\n"; + c += " int yc = tile_y + U.padding.y + (" + ys + ");\n"; + c += " bool iny = (yc >= 0 && yc < U.src_size.y);\n"; + c += " yc = clamp(yc, 0, U.src_size.y - 1);\n"; + c += " int offset = yc * U.src_size.x;\n"; + c += " FLT bt = bt_ar[" + ys + "] * FLT(iny);\n"; + for (int x = 0; x < 6; ++x) { + const std::string xs = std::to_string(x); + const std::string src = "src" + xs; + read_src(src, xs); + c += " I" + xs + " += bt * " + src + ";\n"; + } + c += " }\n"; + } + c += R"( + { + FLT4 r0 = I0 + Bt[2] * I2 + Bt[4] * I4; + dst_buffer[(DST_Z * U.dst_size.y + DST_Y) * U.dst_size.x + DST_X] = r0; + DST_Y++; + } + { + FLT4 r0 = Bt[7] * I1 + Bt[8] * I2 + Bt[9] * I3 + Bt[10] * I4; + dst_buffer[(DST_Z * U.dst_size.y + DST_Y) * U.dst_size.x + DST_X] = r0; + DST_Y++; + } + { + FLT4 r0 = Bt[13] * I1 + Bt[14] * I2 + Bt[15] * I3 + Bt[16] * I4; + dst_buffer[(DST_Z * U.dst_size.y + DST_Y) * U.dst_size.x + DST_X] = r0; + DST_Y++; + } + { + FLT4 r0 = Bt[19] * I1 + Bt[20] * I2 + Bt[21] * I3 + Bt[22] * I4; + dst_buffer[(DST_Z * U.dst_size.y + DST_Y) * U.dst_size.x + DST_X] = r0; + DST_Y++; + } + { + FLT4 r0 = Bt[25] * I1 + Bt[26] * I2 + Bt[27] * I3 + Bt[28] * I4; + dst_buffer[(DST_Z * U.dst_size.y + DST_Y) * U.dst_size.x + DST_X] = r0; + DST_Y++; + } + { + FLT4 r0 = Bt[31] * I1 + Bt[33] * I3 + I5; + dst_buffer[(DST_Z * U.dst_size.y + DST_Y) * U.dst_size.x + DST_X] = r0; + } +} +)"; + return c; +} + std::string GetKernelWinograd36To4x4() { std::string c; c += R"( @@ -222,6 +350,117 @@ kernel void ComputeFunction($1 )"; return c; } + +std::string GetKernelWinograd36To4x4Tile4x1() { + std::string c; + c += R"( +#include +using namespace metal; + +struct uniforms { + int4 src_size; + int4 dst_size; + int4 tiles; +}; +)"; + auto at_mat = AtMatrixForWinograd4x4To6x6(); + c += "constant FLT At[24] = {\n"; + for (int y = 0; y < 4; ++y) { + c += "\t"; + for (int x = 0; x < 6; ++x) { + c += absl::StrFormat("%.10f", at_mat[y * 6 + x]) + "f, "; + } + c += "\n"; + } + c += "};\n"; + c += R"( + +$0 + +kernel void ComputeFunction($1 + uint3 global_ids[[thread_position_in_grid]]) +{ + int tile_id = global_ids.x; + int DST_Y = global_ids.y; + int DST_Z = global_ids.z; + int tile_x = (tile_id % U.tiles.x) * 4; + int tile_y = (tile_id / U.tiles.x) * 4 + DST_Y; + if (tile_x >= U.dst_size.x || tile_y >= U.dst_size.y || DST_Z >= U.dst_size.z) { + return; + } + FLT4 I0, I1, I2, I3, I4, I5; + FLT at_ar[6]; + FLT4 t00 = at_arr[DST_Y * 2 + 0]; + FLT4 t01 = at_arr[DST_Y * 2 + 1]; + at_ar[0] = t00.x; + at_ar[1] = t00.y; + at_ar[2] = t00.z; + at_ar[3] = t00.w; + at_ar[4] = t01.x; + at_ar[5] = t01.y; + int src_adress = DST_Z * U.src_size.y * U.src_size.x + tile_id; + { + FLT at = at_ar[0]; +)"; + for (int x = 0; x < 6; ++x) { + const std::string yc = std::to_string(x); + const std::string src = "src" + std::to_string(x); + c += " FLT4 " + src + " = src_buffer[src_adress + U.src_size.x * " + yc + + "];\n"; + c += " I" + std::to_string(x) + " = at * " + src + ";\n"; + } + c += " }\n"; + for (int y = 1; y < 6; ++y) { + c += " {\n"; + c += " FLT at = at_ar[" + std::to_string(y) + "];\n"; + for (int x = 0; x < 6; ++x) { + const std::string yc = std::to_string(y * 6 + x); + const std::string src = "src" + std::to_string(x); + c += " FLT4 " + src + " = src_buffer[src_adress + U.src_size.x * " + + yc + "];\n"; + c += " I" + std::to_string(x) + " += at * " + src + ";\n"; + } + c += " }\n"; + } + c += R"( + FLT4 t0 = I1 + I2; + FLT4 t1 = I3 + I4; + FLT4 bias_val = biases[DST_Z]; + int dst_adress = (DST_Z * U.dst_size.y + tile_y) * U.dst_size.x + tile_x; + if (tile_x < U.dst_size.x) { + FLT4 value = I0 + t0 + t1 + bias_val; + uint3 gid = uint3(tile_x, tile_y, global_ids.z); + int linear_index = dst_adress; + $2; + dst_buffer[linear_index] = value; + } + FLT4 t2 = I1 - I2; + FLT4 t3 = I3 - I4; + if (tile_x + 1 < U.dst_size.x) { + FLT4 value = t2 * At[7] + t3 * At[9] + bias_val; + uint3 gid = uint3(tile_x + 1, tile_y, global_ids.z); + int linear_index = dst_adress + 1; + $2; + dst_buffer[linear_index] = value; + } + if (tile_x + 2 < U.dst_size.x) { + FLT4 value = t0 * At[13] + t1 * At[15] + bias_val; + uint3 gid = uint3(tile_x + 2, tile_y, global_ids.z); + int linear_index = dst_adress + 2; + $2; + dst_buffer[linear_index] = value; + } + if (tile_x + 3 < U.dst_size.x) { + FLT4 value = t2 * At[19] + t3 * At[21] + I5 + bias_val; + uint3 gid = uint3(tile_x + 3, tile_y, global_ids.z); + int linear_index = dst_adress + 3; + $2; + dst_buffer[linear_index] = value; + } +} +)"; + return c; +} } // namespace std::vector Winograd4x4To36( @@ -301,6 +540,94 @@ std::vector Winograd4x4To36( return {desc}; } +std::vector Winograd4x4To36TileX6( + int id, ValueId input_id, ValueId output_id, + const Winograd4x4To36Attributes& attr, const RuntimeOptions& options) { + auto desc = std::make_shared(); + desc->id = id; + desc->is_linkable = false; + desc->shader_source = GetKernelWinograd4x4To36TileX6(); + + desc->input_buffers = { + {input_id, "device FLT4* const src_buffer"}, + }; + + desc->output_buffer = { + output_id, "device FLT4* dst_buffer", + [input_id, attr](const std::map& buffers) { + const auto src_shape = buffers.find(input_id)->second; + int new_width = src_shape.w + attr.padding.prepended.w + + attr.padding.appended.w - 2; + int new_height = src_shape.h + attr.padding.prepended.h + + attr.padding.appended.h - 2; + BHWC dst_shape; + dst_shape.b = src_shape.b; + dst_shape.h = 36; + dst_shape.w = IntegralDivideRoundUp(new_width, 4) * + IntegralDivideRoundUp(new_height, 4); + dst_shape.c = src_shape.c; + return dst_shape; + }}; + + std::vector bt_aligned(6 * 8); + auto bt_mat = BtMatrixForWinograd4x4To6x6(); + for (int y = 0; y < 6; ++y) { + for (int x = 0; x < 6; ++x) { + bt_aligned[y * 8 + x] = bt_mat[y * 6 + x]; + } + bt_aligned[y * 8 + 6] = 0.0f; + bt_aligned[y * 8 + 7] = 0.0f; + } + + desc->immutable_buffers = { + {"device FLT4* const bt_arr", + GetByteBufferConverted(bt_aligned, options.storage_precision)}, + }; + + desc->uniform_buffers = { + {"constant uniforms& U", + [input_id, output_id, attr](const std::map& buffers) { + const auto& src_shape = buffers.find(input_id)->second; + const auto& dst_shape = buffers.find(output_id)->second; + int new_width = src_shape.w + attr.padding.prepended.w + + attr.padding.appended.w - 2; + int new_height = src_shape.h + attr.padding.prepended.h + + attr.padding.appended.h - 2; + int tiles_x = IntegralDivideRoundUp(new_width, 4); + int tiles_y = IntegralDivideRoundUp(new_height, 4); + std::vector sizes = { + src_shape.w, + src_shape.h, + IntegralDivideRoundUp(src_shape.c, 4), + 0, + dst_shape.w, + dst_shape.h, + IntegralDivideRoundUp(dst_shape.c, 4), + 0, + -attr.padding.prepended.w, + -attr.padding.prepended.h, + tiles_x, + tiles_x * tiles_y, + }; + return GetByteBuffer(sizes); + }}, + }; + + desc->resize_function = [output_id, + attr](const std::map& buffers) { + const uint3 groups_size{4, 6, 1}; + const auto& dst_shape = buffers.find(output_id)->second; + int grid_x = dst_shape.w; + int grid_y = 6; + int grid_z = IntegralDivideRoundUp(dst_shape.c, 4); + int groups_x = IntegralDivideRoundUp(grid_x, groups_size.x); + int groups_y = IntegralDivideRoundUp(grid_y, groups_size.y); + int groups_z = IntegralDivideRoundUp(grid_z, groups_size.z); + return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z}); + }; + return {desc}; +} + std::vector Winograd36To4x4( int id, ValueId input_id, ValueId output_id, const RuntimeOptions& options, const Winograd36To4x4Attributes& attr) { @@ -359,6 +686,90 @@ std::vector Winograd36To4x4( return {desc}; } +std::vector Winograd36To4x4Tile4x1( + int id, ValueId input_id, ValueId output_id, const RuntimeOptions& options, + const Winograd36To4x4Attributes& attr) { + auto desc = std::make_shared(); + desc->id = id; + desc->is_linkable = false; + desc->shader_source = GetKernelWinograd36To4x4Tile4x1(); + + desc->input_buffers = { + {input_id, "device FLT4* const src_buffer"}, + }; + + desc->output_buffer = { + output_id, "device FLT4* dst_buffer", + [input_id, attr](const std::map& buffers) { + const auto src_shape = buffers.find(input_id)->second; + BHWC dst_shape; + dst_shape.b = src_shape.b; + dst_shape.h = attr.output_shape.h; + dst_shape.w = attr.output_shape.w; + dst_shape.c = src_shape.c; + return dst_shape; + }}; + + std::vector at_aligned(4 * 8); + auto at_mat = AtMatrixForWinograd4x4To6x6(); + for (int y = 0; y < 4; ++y) { + for (int x = 0; x < 6; ++x) { + at_aligned[y * 8 + x] = at_mat[y * 6 + x]; + } + at_aligned[y * 8 + 6] = 0.0f; + at_aligned[y * 8 + 7] = 0.0f; + } + + desc->immutable_buffers = { + {"device FLT4* const biases", + GetByteBufferConvertedResized(attr.biases.data, + options.storage_precision, + AlignByN(attr.output_shape.c, 4))}, + {"device FLT4* const at_arr", + GetByteBufferConverted(at_aligned, options.storage_precision)}, + }; + + desc->uniform_buffers = { + {"constant uniforms& U", + [input_id, output_id](const std::map& buffers) { + const auto& src_shape = buffers.find(input_id)->second; + const auto& dst_shape = buffers.find(output_id)->second; + const int tiles_x = IntegralDivideRoundUp(dst_shape.w, 4); + const int tiles_y = IntegralDivideRoundUp(dst_shape.h, 4); + std::vector sizes = { + src_shape.w, + src_shape.h, + IntegralDivideRoundUp(src_shape.c, 4), + 0, + dst_shape.w, + dst_shape.h, + IntegralDivideRoundUp(dst_shape.c, 4), + 0, + tiles_x, + tiles_y, + 0, + 0, + }; + return GetByteBuffer(sizes); + }}, + }; + + desc->resize_function = [output_id](const std::map& buffers) { + const uint3 groups_size{8, 4, 1}; + const auto& dst_shape = buffers.find(output_id)->second; + const int tiles_x = IntegralDivideRoundUp(dst_shape.w, 4); + const int tiles_y = IntegralDivideRoundUp(dst_shape.h, 4); + int grid_x = tiles_x * tiles_y; + int grid_y = 4; + int grid_z = IntegralDivideRoundUp(dst_shape.c, 4); + int groups_x = IntegralDivideRoundUp(grid_x, groups_size.x); + int groups_y = IntegralDivideRoundUp(grid_y, groups_size.y); + int groups_z = IntegralDivideRoundUp(grid_z, groups_size.z); + return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z}); + }; + return {desc}; +} + } // namespace metal } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/winograd.h b/tensorflow/lite/delegates/gpu/metal/kernels/winograd.h index 26c18538fd9..e231e1eb1cc 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/winograd.h +++ b/tensorflow/lite/delegates/gpu/metal/kernels/winograd.h @@ -33,6 +33,10 @@ std::vector Winograd4x4To36( int id, ValueId input_id, ValueId output_id, const Winograd4x4To36Attributes& attr); +std::vector Winograd4x4To36TileX6( + int id, ValueId input_id, ValueId output_id, + const Winograd4x4To36Attributes& attr, const RuntimeOptions& options); + struct Winograd36To4x4Attributes { BHWC output_shape; tflite::gpu::Tensor biases; @@ -42,6 +46,10 @@ std::vector Winograd36To4x4( int id, ValueId input_id, ValueId output_id, const RuntimeOptions& options, const Winograd36To4x4Attributes& attr); +std::vector Winograd36To4x4Tile4x1( + int id, ValueId input_id, ValueId output_id, const RuntimeOptions& options, + const Winograd36To4x4Attributes& attr); + } // namespace metal } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/winograd_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/winograd_test.mm index 67290730062..95f17ccc1c4 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/winograd_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/winograd_test.mm @@ -100,6 +100,68 @@ using ::tflite::gpu::metal::CompareVectors; XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } +- (void)testWinograd4x4To36TileX6 { + TensorFloat32 src_tensor; + src_tensor.shape = BHWC(1, 4, 4, 1); + src_tensor.data.resize(16); + for (int i = 0; i < 16; ++i) { + src_tensor.data[i] = sin(i); + } + + TensorFloat32 dst_tensor; + dst_tensor.shape = BHWC(1, 36, 1, 1); + dst_tensor.data.resize(36, 0.0f); + auto b_t = tflite::gpu::BtMatrixForWinograd4x4To6x6(); + + // Bt * Src * B + // 1: temp = Src * B + std::vector temp(36, 0.0f); + for (int y = 0; y < 6; ++y) { + for (int x = 0; x < 6; ++x) { + float sum = 0.0f; + for (int i = 0; i < 6; ++i) { + if (y < 1 || y > 4 || i < 1 || i > 4) continue; + const int index = src_tensor.shape.LinearIndex({0, y - 1, i - 1, 0}); + sum += src_tensor.data[index] * b_t[x * 6 + i]; + } + temp[y * 6 + x] = sum; + } + } + // 2: dst_tensor = Bt * temp + for (int y = 0; y < 6; ++y) { + for (int x = 0; x < 6; ++x) { + float sum = 0.0f; + for (int i = 0; i < 6; ++i) { + sum += b_t[y * 6 + i] * temp[i * 6 + x]; + } + const int index = dst_tensor.shape.LinearIndex({0, y * 6 + x, 0, 0}); + dst_tensor.data[index] = sum; + } + } + + tflite::gpu::metal::RuntimeOptions options; + options.storage_precision = tflite::gpu::metal::RuntimeOptions::Precision::FP32; + options.accumulator_precision = tflite::gpu::metal::RuntimeOptions::Precision::FP32; + + tflite::gpu::metal::Winograd4x4To36Attributes attr; + attr.padding.prepended = tflite::gpu::HW(1, 1); + attr.padding.appended = tflite::gpu::HW(1, 1); + auto tasks = tflite::gpu::metal::Winograd4x4To36TileX6(0, 0, 1, attr, options); + + std::map inputs; + inputs[0] = src_tensor; + std::map outputs; + outputs[1].shape = BHWC(1, 36, 1, 1); + outputs[1].data.resize(36, 0.0f); + + id device = MTLCreateSystemDefaultDevice(); + auto status = RunGraph(tasks, device, inputs, &outputs); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + + status = CompareVectors(dst_tensor.data, outputs[1].data, 1e-6f); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); +} + - (void)testWinograd36To4x4 { TensorFloat32 src_tensor; src_tensor.shape = BHWC(1, 36, 1, 1); @@ -163,4 +225,67 @@ using ::tflite::gpu::metal::CompareVectors; XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); } +- (void)testWinograd36To4x4Tile4x1 { + TensorFloat32 src_tensor; + src_tensor.shape = BHWC(1, 36, 1, 1); + src_tensor.data.resize(36); + for (int i = 0; i < 36; ++i) { + src_tensor.data[i] = sin(i); + } + + TensorFloat32 dst_tensor; + dst_tensor.shape = BHWC(1, 4, 4, 1); + dst_tensor.data.resize(16, 0.0f); + auto a_t = tflite::gpu::AtMatrixForWinograd4x4To6x6(); + + // At * Src * A + // 1: temp = Src * A + std::vector temp(24, 0.0f); + for (int y = 0; y < 6; ++y) { + for (int x = 0; x < 4; ++x) { + float sum = 0.0f; + for (int i = 0; i < 6; ++i) { + const int index = src_tensor.shape.LinearIndex({0, y * 6 + i, 0, 0}); + sum += src_tensor.data[index] * a_t[x * 6 + i]; + } + temp[y * 4 + x] = sum; + } + } + // 2: dst_tensor = At * temp + for (int y = 0; y < 4; ++y) { + for (int x = 0; x < 4; ++x) { + float sum = 0.0f; + for (int i = 0; i < 6; ++i) { + sum += a_t[y * 6 + i] * temp[i * 4 + x]; + } + const int index = dst_tensor.shape.LinearIndex({0, y, x, 0}); + dst_tensor.data[index] = sum; + } + } + + tflite::gpu::metal::Winograd36To4x4Attributes attr; + attr.output_shape = BHWC(1, 4, 4, 1); + attr.biases.shape = tflite::gpu::Linear(1); + attr.biases.data.resize(1, 0.0f); + + tflite::gpu::metal::RuntimeOptions options; + options.storage_precision = tflite::gpu::metal::RuntimeOptions::Precision::FP32; + options.accumulator_precision = tflite::gpu::metal::RuntimeOptions::Precision::FP32; + + auto tasks = tflite::gpu::metal::Winograd36To4x4Tile4x1(0, 0, 1, options, attr); + + std::map inputs; + inputs[0] = src_tensor; + std::map outputs; + outputs[1].shape = BHWC(1, 4, 4, 1); + outputs[1].data.resize(16, 0.0f); + + id device = MTLCreateSystemDefaultDevice(); + auto status = RunGraph(tasks, device, inputs, &outputs); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); + + status = CompareVectors(dst_tensor.data, outputs[1].data, 1e-6f); + XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str()); +} + @end diff --git a/tensorflow/lite/delegates/gpu/metal_delegate.mm b/tensorflow/lite/delegates/gpu/metal_delegate.mm index 797a2c4e4c9..16fce886a17 100644 --- a/tensorflow/lite/delegates/gpu/metal_delegate.mm +++ b/tensorflow/lite/delegates/gpu/metal_delegate.mm @@ -226,7 +226,7 @@ class Delegate { } absl::Status Prepare(TfLiteContext* context, const TfLiteDelegateParams* delegate_params) { - // Extract TFLite delegate execution plan from the context and convert it into FlowGraph32. + // Extract TFLite delegate execution plan from the context and convert it into GraphFloat32. GraphFloat32 graph; RETURN_IF_ERROR(BuildModel(context, delegate_params, &graph)); diff --git a/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc b/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc index 1e4eae3d2a7..d00150bec40 100644 --- a/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc +++ b/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc @@ -129,6 +129,7 @@ ConcatenationOpTest/FourInputsQuantizedMixedRange,29 ConcatenationOpTest/FourInputsQuantizedMixedRangeClampingLogic,29 # conv_test +-ConvolutionOpTest/ConvolutionOpTest.SimplePerTensorTest/.+ ConvolutionOpTest/ConvolutionOpTest.SimpleTestFloatWithDilation/.+,29 ConvolutionOpTest/ConvolutionOpTest.SimpleTestLargeIrregularQuantized/.+,29 ConvolutionOpTest/ConvolutionOpTest.SimpleTestQuantizedOutputMultiplierGreaterThan1/.+,29 @@ -150,6 +151,7 @@ DepthToSpaceOpModel/UInt8 DepthToSpaceOpModel/int8 # div_test +-FloatDivOpTest/WithBroadcast5D FloatDivOpTest/.+ # elementwise_test @@ -364,6 +366,8 @@ TopKV2OpTest/TopKV2OpTest/.+/0,29 TransposeTest/.+ # transpose_conv_test +-TransposeConvOpTest/TransposeConvOpTest.SimpleTestQuantizedPerChannelSingleChannel/0 +-TransposeConvOpTest/TransposeConvOpTest.TestQuantizedPerChannelMultiChannel/0 # Const tensor only TransposeConvOpTest/TransposeConvOpTest/.+/0,29 diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc index 25a09943394..2e993fe820c 100644 --- a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc @@ -1042,6 +1042,7 @@ class NNAPIOpBuilder { int32_t nn_type = 0; float scale = 0.0f; int32_t zeroPoint = 0; + ANeuralNetworksSymmPerChannelQuantParams ann_perchannel_params; TfLiteTensor* tensor = &context_->tensors[tensor_index]; TfLiteType tensor_type = tensor->type; if (hybrid_op && (tensor_type == kTfLiteUInt8)) { @@ -1067,14 +1068,37 @@ class NNAPIOpBuilder { : ANEURALNETWORKS_TENSOR_QUANT8_SYMM; scale = tensor->params.scale; zeroPoint = tensor->params.zero_point; - if (need_int8_conversion) { - zeroPoint += 128; - operand_mapping_->add_type_conversion(tensor_index, kTfLiteUInt8); + if (tensor->quantization.type == kTfLiteAffineQuantization) { + TfLiteAffineQuantization* quantization_params = + static_cast( + tensor->quantization.params); + if (quantization_params->scale->size > 1) { + // Set up per-channel quantization. + ann_perchannel_params = { + .channelDim = static_cast( + quantization_params->quantized_dimension), + .scaleCount = + static_cast(quantization_params->scale->size), + .scales = quantization_params->scale->data, + }; + nn_type = ANEURALNETWORKS_TENSOR_QUANT8_SYMM_PER_CHANNEL; + scale = 0.0f; + zeroPoint = 0; + } else if (quantization_params->scale->size == 1) { + scale = quantization_params->scale->data[0]; + zeroPoint = quantization_params->zero_point->data[0]; + } } - if (scale == 0) { - // TENSOR_QUANT8_ASYMM and ANEURALNETWORKS_TENSOR_QUANT8_ASYMM - // with zero scale are not valid in NNAPI. - scale = 1; + if (nn_type != ANEURALNETWORKS_TENSOR_QUANT8_SYMM_PER_CHANNEL) { + if (need_int8_conversion) { + zeroPoint += 128; + operand_mapping_->add_type_conversion(tensor_index, kTfLiteUInt8); + } + if (scale == 0) { + // TENSOR_QUANT8_ASYMM and ANEURALNETWORKS_TENSOR_QUANT8_ASYMM + // with zero scale are not valid in NNAPI. + scale = 1; + } } break; case kTfLiteInt32: @@ -1107,26 +1131,6 @@ class NNAPIOpBuilder { // if the tensor_rank is 0, the dimension ptr must be nullptr. tensor_dims = nullptr; } - ANeuralNetworksSymmPerChannelQuantParams ann_perchannel_params; - if (tensor_type == kTfLiteInt8 || tensor_type == kTfLiteUInt8) { - if (tensor->quantization.type == kTfLiteAffineQuantization) { - TfLiteAffineQuantization* quantization_params = - static_cast(tensor->quantization.params); - if (quantization_params->scale->size > 1) { - // Set up per-channel quantization. - ann_perchannel_params = { - .channelDim = static_cast( - quantization_params->quantized_dimension), - .scaleCount = - static_cast(quantization_params->scale->size), - .scales = quantization_params->scale->data, - }; - nn_type = ANEURALNETWORKS_TENSOR_QUANT8_SYMM_PER_CHANNEL; - scale = 0.0f; - zeroPoint = 0; - } - } - } ANeuralNetworksOperandType operand_type{nn_type, tensor_rank, tensor_dims, scale, zeroPoint}; @@ -1903,6 +1907,13 @@ bool NNAPIDelegateKernel::Validate( ExpectOpVersion(version, 1, &val_ctx); ExpectMinAndroidSdkVersion(android_sdk_version, kMinSdkVersionForNNAPI12, &val_ctx); + Expect((node->inputs->size > 1) && + (context->tensors[node->inputs->data[0]].allocation_type == + kTfLiteMmapRo) && + (context->tensors[node->inputs->data[1]].allocation_type == + kTfLiteMmapRo), + NNAPIValidationFailureType::kInputTensorShouldHaveConstantShape, + "Dynamically-sized tensors not supported.", &val_ctx); } break; case kTfLiteBuiltinSqrt: { ExpectOpVersion(version, 1, &val_ctx); @@ -2577,10 +2588,24 @@ TfLiteStatus NNAPIDelegateKernel::Map( case kTfLiteBuiltinTransposeConv: { const bool hybrid_op = IsHybridOperator( mapping_args.context, kTfLiteBuiltinTransposeConv, mapping_args.node); - mapping_args.builder->AddTensorInput( - mapping_args.node->inputs->data[/*kDataInputTensor*/ 2], hybrid_op); - mapping_args.builder->AddTensorInput( - mapping_args.node->inputs->data[/*kWeightsTensor*/ 1], hybrid_op); + int input_tensor_flags = 0; + const int input_tensor_id = + mapping_args.node->inputs->data[/*kDataInputTensor*/ 2]; + const int weight_tensor_id = + mapping_args.node->inputs->data[/*kWeightsTensor*/ 1]; + if (context->tensors[input_tensor_id].type == kTfLiteInt8) { + const auto& weights_tensor = context->tensors[weight_tensor_id]; + if ((weights_tensor.type == kTfLiteInt8 || + weights_tensor.type == kTfLiteUInt8) && + weights_tensor.quantization.type == kTfLiteAffineQuantization) { + input_tensor_flags |= NN_TENSOR_FLAG_SCALAR_AS_TENSOR; + } + } + + mapping_args.builder->AddTensorInput(input_tensor_id, hybrid_op, + input_tensor_flags); + mapping_args.builder->AddTensorInput(weight_tensor_id, hybrid_op, + input_tensor_flags); // NNAPI requires a bias tensor, so we allocate a new tensor to fill // it with zeroes. It is deleted with other tensors in the context @@ -3508,6 +3533,8 @@ TfLiteStatus NNAPIDelegateKernel::AddOpsAndTensors(TfLiteContext* context, if (need_int8_conversion && (input_pos == 0 || reg->builtin_code == kTfLiteBuiltinFullyConnected || + reg->builtin_code == kTfLiteBuiltinConv2d || + reg->builtin_code == kTfLiteBuiltinDepthwiseConv2d || reg->builtin_code == kTfLiteBuiltinAdd || reg->builtin_code == kTfLiteBuiltinMul || reg->builtin_code == kTfLiteBuiltinSub || @@ -4260,8 +4287,10 @@ TfLiteStatus StatefulNnApiDelegate::DoPrepare(TfLiteContext* context, int num_partitions; TfLiteDelegateParams* params_array; - if (is_accelerator_specified) { - // Filtering out nodes not supported by target accelerators + if (is_accelerator_specified && + nnapi->android_sdk_version >= kMinSdkVersionForNNAPI12) { + // Filtering out nodes not supported by target accelerators. + // Cannot query supported operation before NNAPI 1.2 TF_LITE_ENSURE_STATUS(GetNodesSupportedByAccelerator( context, delegate, nnapi, supported_nodes, &nodes_to_delegate, &num_partitions, ¶ms_array, nnapi_errno)); diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate_device_selection_test.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate_device_selection_test.cc index d6183e63013..4f80b95ac3d 100644 --- a/tensorflow/lite/delegates/nnapi/nnapi_delegate_device_selection_test.cc +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_device_selection_test.cc @@ -546,6 +546,26 @@ TEST_F(UnsupportedOperationOnDeviceTest, ShouldCacheModelCompilation) { EXPECT_EQ(should_cache_model_compilation_model_create_count, 1); } +TEST_F(UnsupportedOperationOnDeviceTest, + ShouldNotApplySupportedOperationsFilterBeforeAndroidSdk29) { + nnapi_mock_->SetAndroidSdkVersion(28, /*set_unsupported_ops_to_null=*/true); + nnapi_mock_->ModelCreateReturns<0>(); + AddSubOpsAcceleratedModel m( + {TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {}}, + ActivationFunctionType_NONE, nnapi_mock_->GetNnApi(), + /*accelerator_name=*/"test-device"); + std::vector input1{-2.0, 0.2, 0.7, 0.9}; + std::vector input2{0.1, 0.2, 0.3, 0.5}; + m.PopulateTensor(m.input1(), input1); + m.PopulateTensor(m.input2(), input2); + m.PopulateTensor(m.input3(), input2); + m.Invoke(); + + // Delegation succeded without failures and all nodes have been delegated. + ASSERT_EQ(m.CountOpsExecutedByCpuKernel(), 0); +} + // Model with a chain of no-op (add with zero operations) // interleaved with no-op custom nodes. class LongIdentityModel : public MultiOpModel, public AcceleratedModel { diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate_mock_test.h b/tensorflow/lite/delegates/nnapi/nnapi_delegate_mock_test.h index c0dc06ab62b..fa7ff9dd1f1 100644 --- a/tensorflow/lite/delegates/nnapi/nnapi_delegate_mock_test.h +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_mock_test.h @@ -52,6 +52,7 @@ class NnApiMock : public ::tflite::nnapi::NnApiHandler { nnapi_->ASharedMemory_create = [](const char* name, size_t size) -> int { return open("/dev/zero", O_RDWR); }; + nnapi_->ANeuralNetworksEvent_free = [](ANeuralNetworksEvent* event) {}; ModelCreateReturns(); AddOperandReturns(); @@ -68,6 +69,8 @@ class NnApiMock : public ::tflite::nnapi::NnApiHandler { ExecutionSetInputFromMemoryReturns(); ExecutionSetOutputFromMemoryReturns(); ExecutionComputeReturns(); + ExecutionStartComputeReturns(); + EventWaitReturns(); SetNnapiSupportedDevice("test-device", android_sdk_version); } diff --git a/tensorflow/lite/examples/label_image/BUILD b/tensorflow/lite/examples/label_image/BUILD index b3dd0764330..a1d134e5b6a 100644 --- a/tensorflow/lite/examples/label_image/BUILD +++ b/tensorflow/lite/examples/label_image/BUILD @@ -33,6 +33,7 @@ cc_binary( "//tensorflow/lite:framework", "//tensorflow/lite:string_util", "//tensorflow/lite/delegates/nnapi:nnapi_delegate", + "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate", "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/profiling:profiler", "//tensorflow/lite/tools/evaluation:utils", diff --git a/tensorflow/lite/examples/label_image/README.md b/tensorflow/lite/examples/label_image/README.md index f38b03ecfe2..09e9e77b86a 100644 --- a/tensorflow/lite/examples/label_image/README.md +++ b/tensorflow/lite/examples/label_image/README.md @@ -105,7 +105,7 @@ Run the model with NNAPI delegate (`-a 1`), `adb shell then you should see something like the followings: `Loaded model /data/local/tmp/mobilenet_v1_1.0_224.tflite resolved reporter INFO: Initialized TensorFlow Lite runtime. INFO: Created TensorFlow Lite delegate for NNAPI. -Applied NNAPI delegate.invoked average time: 10.348 ms 0.905401: 653 military +Applied NNAPI delegate. invoked average time:10.348 ms 0.905401: 653 military uniform 0.0379589: 907 Windsor tie 0.00735866: 466 bulletproof vest 0.00605307: 458 bow tie 0.00422573: 514 cornet` @@ -125,4 +125,13 @@ average time: 8.307 ms 0.729412: 653 military uniform 0.0980392: 907 Windsor tie 0.0313726: 466 bulletproof vest 0.0313726: 458 bow tie 0.0117647: 700 panpipe ``` +Run the model with the XNNPACK delegate (`-x 1`), `adb shell +"/data/local/tmp/label_image \ -m /data/local/tmp/mobilenet_v1_1.0_224.tflite \ +-i /data/local/tmp/grace_hopper.bmp \ -l /data/local/tmp/labels.txt -x 1"` then +you should see something like the followings: `Loaded model +/data/local/tmp/mobilenet_v1_1.0_224.tflite resolved reporter INFO: Initialized +TensorFlow Lite runtime. Applied XNNPACK delegate.invoked average time: 11.0237 +ms 0.90707: 653 military uniform 0.0372418: 907 Windsor tie 0.0073376: 466 +bulletproof vest 0.00592856: 458 bow tie 0.00414093: 514 cornet` + See the `label_image.cc` source code for other command line options. diff --git a/tensorflow/lite/examples/label_image/label_image.cc b/tensorflow/lite/examples/label_image/label_image.cc index b493fafa839..fe3d4cf9f09 100644 --- a/tensorflow/lite/examples/label_image/label_image.cc +++ b/tensorflow/lite/examples/label_image/label_image.cc @@ -37,6 +37,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h" +#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" #include "tensorflow/lite/examples/label_image/bitmap_helpers.h" #include "tensorflow/lite/examples/label_image/get_top_n.h" #include "tensorflow/lite/kernels/register.h" @@ -101,6 +102,15 @@ TfLiteDelegatePtrMap GetDelegates(Settings* s) { } } + if (s->xnnpack_delegate) { + auto delegate = evaluation::CreateXNNPACKDelegate(s->number_of_threads); + if (!delegate) { + LOG(INFO) << "XNNPACK acceleration is unsupported on this platform."; + } else { + delegates.emplace("XNNPACK", std::move(delegate)); + } + } + return delegates; } @@ -360,6 +370,7 @@ void display_usage() { << "--threads, -t: number of threads\n" << "--verbose, -v: [0|1] print more information\n" << "--warmup_runs, -w: number of warmup runs\n" + << "--xnnpack_delegate, -x: xnnpack delegate\n" << "\n"; } @@ -386,13 +397,14 @@ int Main(int argc, char** argv) { {"warmup_runs", required_argument, nullptr, 'w'}, {"gl_backend", required_argument, nullptr, 'g'}, {"hexagon_delegate", required_argument, nullptr, 'j'}, + {"xnnpack_delegate", required_argument, nullptr, 'x'}, {nullptr, 0, nullptr, 0}}; /* getopt_long stores the option index here. */ int option_index = 0; c = getopt_long(argc, argv, - "a:b:c:d:e:f:g:i:j:l:m:p:r:s:t:v:w:", long_options, + "a:b:c:d:e:f:g:i:j:l:m:p:r:s:t:v:w:x:", long_options, &option_index); /* Detect the end of the options. */ @@ -460,6 +472,9 @@ int Main(int argc, char** argv) { s.number_of_warmup_runs = strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn) break; + case 'x': + s.xnnpack_delegate = optarg; + break; case 'h': case '?': /* getopt_long already printed an error message. */ diff --git a/tensorflow/lite/examples/label_image/label_image.h b/tensorflow/lite/examples/label_image/label_image.h index 110340c6ddf..737231e567f 100644 --- a/tensorflow/lite/examples/label_image/label_image.h +++ b/tensorflow/lite/examples/label_image/label_image.h @@ -31,6 +31,7 @@ struct Settings { bool allow_fp16 = false; bool gl_backend = false; bool hexagon_delegate = false; + bool xnnpack_delegate = false; int loop_count = 1; float input_mean = 127.5f; float input_std = 127.5f; diff --git a/tensorflow/lite/experimental/delegates/coreml/BUILD.apple b/tensorflow/lite/experimental/delegates/coreml/BUILD.apple new file mode 100644 index 00000000000..92aa96d5c50 --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/BUILD.apple @@ -0,0 +1,77 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["coreml_delegate.h"]) + +objc_library( + name = "coreml_executor", + srcs = ["coreml_executor.mm"], + hdrs = ["coreml_executor.h"], + sdk_frameworks = [ + "Foundation", + "UIKit", + "CoreML", + ], + deps = [ + ":mlmodel_proto_cc", + ], +) + +cc_library( + name = "mlmodel_proto_cc", + deps = [ + "@coremltools//:mlmodel_cc_proto", + ], +) + +objc_library( + name = "coreml_delegate", + srcs = ["coreml_delegate.mm"], + hdrs = ["coreml_delegate.h"], + deps = [ + ":coreml_delegate_kernel", + ":mlmodel_proto_cc", + "//tensorflow/lite:kernel_api", + "//tensorflow/lite:minimal_logging", + "//tensorflow/lite/c:common", + "//tensorflow/lite/experimental/delegates/coreml/builders:op_builder", + ], +) + +objc_library( + name = "coreml_delegate_kernel", + srcs = [ + "coreml_delegate_kernel.mm", + ], + hdrs = [ + "coreml_delegate_kernel.h", + ], + deps = [ + ":coreml_executor", + ":mlmodel_proto_cc", + "//tensorflow/lite:kernel_api", + "//tensorflow/lite/c:common", + "//tensorflow/lite/experimental/delegates/coreml/builders:op_builder", + "//tensorflow/lite/kernels:kernel_util", + "//tensorflow/lite/kernels/internal:optimized_base", + "//tensorflow/lite/kernels/internal:types", + ], +) diff --git a/tensorflow/lite/experimental/delegates/coreml/README.md b/tensorflow/lite/experimental/delegates/coreml/README.md new file mode 100644 index 00000000000..fa2e2a8d68a --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/README.md @@ -0,0 +1,160 @@ +# Tensorflow Lite Core ML Delegate + +TensorFlow Lite Core ML Delegate enables running TensorFlow Lite models on +[Core ML framework](https://developer.apple.com/documentation/coreml), +which results in faster model inference on iOS devices. + +[TOC] + +## Supported iOS versions and processors + +* iOS 12 and later. In the older iOS versions, Core ML delegate will + automatically fallback to CPU. +* When running on iPhone Xs and later, it will use Neural Engine for faster + inference. + +## Update code to use Core ML delegate + +### Swift + +Initialize TensorFlow Lite interpreter with Core ML delegate. + +```swift +let coreMlDelegate = CoreMLDelegate() +let interpreter = try Interpreter(modelPath: modelPath, + delegates: [coreMLDelegate]) +``` + +### Objective-C++ + +#### Interpreter initialization + +Include `coreml_delegate.h`. + +```objectivec++ +#include "tensorflow/lite/experimental/delegates/coreml/coreml_delegate.h" +``` + +Modify code following interpreter initialization to apply delegate. + +```objectivec++ +// initializer interpreter with model. +tflite::InterpreterBuilder(*model, resolver)(&interpreter); + +// Add following section to use Core ML delegate. +TfLiteCoreMlDelegateOptions options = {}; +delegate = TfLiteCoreMlDelegateCreate(&options); +interpreter->ModifyGraphWithDelegate(delegate); + +// ... +``` + +#### Disposal + +Add this code to the section where you dispose of the delegate (e.g. `dealloc` +of class). + +```objectivec++ +TfLiteCoreMlDelegateDelete(delegate); +``` + +## Supported ops + +Following ops are supported by the Core ML delegate. + +* Add + * Only certain shapes are broadcastable. In Core ML tensor layout, + following tensor shapes are broadcastable. `[B, C, H, W]`, `[B, C, 1, + 1]`, `[B, 1, H, W]`, `[B, 1, 1, 1]`. +* AveragePool2D +* Concat +* Conv2D + * Weights and bias should be constant. +* DepthwiseConv2D + * Weights and bias should be constant. +* Hardswish +* Logistic (aka Sigmoid) +* MaxPool2D +* Mul + * Only certain shapes are broadcastable. In Core ML tensor layout, + following tensor shapes are broadcastable. `[B, C, H, W]`, `[B, C, 1, + 1]`, `[B, 1, H, W]`, `[B, 1, 1, 1]`. +* Relu +* ReluN1To1 +* Relu6 +* Reshape +* ResizeBilinear +* SoftMax +* Tanh + +## FAQ + +* Does Core ML delegate support fallback to CPU if a graph contains unsupported + ops? + * Yes. +* Does Core ML delegate work on iOS Simulator? + * Yes. The library includes x86 and x86_64 targets so it can run on + a simulator, but you will not see performance boost over CPU. +* Does TensorFlow Lite and Core ML delegate support macOS? + * TensorFlow Lite is only tested on iOS but not macOS. +* Are custom TF Lite ops supported? + * No, CoreML delegate does not support custom ops and they will fallback to + CPU. + +## Appendix + +### Core ML delegate Swift API + +```swift +/// A delegate that uses the `Core ML` framework for performing +/// TensorFlow Lite graph operations. +/// +/// - Important: This is an experimental interface that is subject to change. +public final class CoreMLDelegate: Delegate { + /// The configuration options for the `CoreMLDelegate`. + public let options: Options + + // Conformance to the `Delegate` protocol. + public private(set) var cDelegate: CDelegate + + * /// Creates a new instance configured with the given `options`. + /// + /// - Parameters: + /// - options: Configurations for the delegate. The default is a new instance of + /// `CoreMLDelegate.Options` with the default configuration values. + public init(options: Options = Options()) { + self.options = options + var delegateOptions = TfLiteCoreMlDelegateOptions() + cDelegate = TfLiteCoreMlDelegateCreate(&delegateOptions) + } + + deinit { + TfLiteCoreMlDelegateDelete(cDelegate) + } +} + +extension CoreMLDelegate { + /// Options for configuring the `CoreMLDelegate`. + public struct Options: Equatable, Hashable { + /// Creates a new instance with the default values. + public init() {} + } +} +``` + +### Core ML delegate C++ API + +```c++ +typedef struct { + // We have dummy for now as we can't have empty struct in C. + char dummy; +} TfLiteCoreMlDelegateOptions; + +// Return a delegate that uses CoreML for ops execution. +// Must outlive the interpreter. +TfLiteDelegate* TfLiteCoreMlDelegateCreate( + const TfLiteCoreMlDelegateOptions* options); + +// Do any needed cleanup and delete 'delegate'. +void TfLiteCoreMlDelegateDelete(TfLiteDelegate* delegate); +``` diff --git a/tensorflow/lite/experimental/delegates/coreml/builders/BUILD.apple b/tensorflow/lite/experimental/delegates/coreml/builders/BUILD.apple new file mode 100644 index 00000000000..210ad8996dd --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/builders/BUILD.apple @@ -0,0 +1,79 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +package( + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "op_builder", + srcs = glob(["*_builder.cc"]), + hdrs = glob(["*_builder.h"]), + deps = [ + ":op_factory", + ":op_validator", + ":util", + "//tensorflow/lite:kernel_api", + "//tensorflow/lite/c:common", + "//tensorflow/lite/kernels:kernel_util", + "//tensorflow/lite/kernels/internal:optimized_base", + "//tensorflow/lite/kernels/internal:tensor", + "//tensorflow/lite/kernels/internal:types", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_protobuf//:protobuf_headers", + "@coremltools//:mlmodel_cc_proto", + ], +) + +cc_library( + name = "op_factory", + hdrs = ["op_factory.h"], + deps = [ + "//tensorflow/lite/c:common", + ], +) + +cc_library( + name = "op_validator", + hdrs = ["op_validator.h"], + deps = [ + "//tensorflow/lite/c:common", + ], +) + +cc_library( + name = "util", + srcs = ["util.cc"], + hdrs = ["util.h"], + deps = [ + ":op_validator", + "//tensorflow/lite/c:common", + "//tensorflow/lite/kernels:kernel_util", + ], +) + +cc_test( + name = "util_test", + srcs = ["util_test.cc"], + deps = [ + ":util", + "//tensorflow/lite/c:common", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tensorflow/lite/experimental/delegates/coreml/builders/activation_layer_builder.cc b/tensorflow/lite/experimental/delegates/coreml/builders/activation_layer_builder.cc new file mode 100644 index 00000000000..ec032d8421e --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/builders/activation_layer_builder.cc @@ -0,0 +1,141 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/experimental/delegates/coreml/builders/activation_layer_builder.h" + +#include "tensorflow/lite/builtin_ops.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/experimental/delegates/coreml/builders/op_factory.h" +#include "tensorflow/lite/experimental/delegates/coreml/builders/threshold_layer_builder.h" + +namespace tflite { +namespace delegates { +namespace coreml { + +const char* ActivationLayerBuilder::DebugName() { + if (!str_debug_name_[0]) + GetDebugName("ActivationLayerBuilder", node_id_, str_debug_name_); + return str_debug_name_; +} + +CoreML::Specification::NeuralNetworkLayer* ActivationLayerBuilder::Build() { + layer_->set_name(DebugName()); + switch (activation_) { + // ActNone is used for sclalar multiplication (linear activation) + case kTfLiteActNone: + layer_->mutable_activation()->mutable_linear()->set_alpha(alpha_); + break; + case kTfLiteActRelu: + layer_->mutable_activation()->mutable_relu(); + break; + // Relu1 and Relu6 layers are fully composed in PopulateSubgraph(). + case kTfLiteActRelu1: // clip(-1, 1) + layer_->mutable_unary()->set_alpha(-1); + layer_->mutable_unary()->set_type( + CoreML::Specification::UnaryFunctionLayerParams::THRESHOLD); + break; + case kTfLiteActRelu6: // clip(0, 6) + layer_->mutable_activation()->mutable_relu(); + break; + case kTfLiteActTanh: + layer_->mutable_activation()->mutable_tanh(); + break; + case kTfLiteActSigmoid: + layer_->mutable_activation()->mutable_sigmoid(); + break; + // TODO(taeheej): signbit is not implemented. + default: + fprintf(stderr, "Activation %d is not supported.\n", activation_); + break; + } + return layer_.release(); +} + +TfLiteStatus ActivationLayerBuilder::PopulateSubgraph(TfLiteContext* context) { + if (!(activation_ == kTfLiteActRelu6 || activation_ == kTfLiteActRelu1)) { + builder_output_ = AddOutput(); + return kTfLiteOk; + } + + // Relu1: Threshold(-1) -> Threshold(-1) with scale: -1 -> Negation + // Relu6: ReLU -> Threshold(-6) with scale: -1 -> Negation + const int relu_threshold = activation_ == kTfLiteActRelu6 ? 6 : 1; + ThresholdLayerBuilder* threshold_builder = + reinterpret_cast( + graph_builder_->AddBuilder(CreateThresholdLayerBuilder, nullptr)); + + threshold_builder->SetAlpha(-relu_threshold); + threshold_builder->SetScale(-1); + + threshold_builder->AddInput(AddOutput()); + + ActivationLayerBuilder* negation_builder = + reinterpret_cast( + graph_builder_->AddBuilder(CreateActivationLayerBuilder, nullptr)); + negation_builder->SetActivation(kTfLiteActNone); + negation_builder->SetAlpha(-1); + + negation_builder->AddInput(threshold_builder->AddOutput()); + builder_output_ = negation_builder->AddOutput(); + return kTfLiteOk; +} + +TfLiteStatus ActivationLayerBuilder::RegisterInputs( + const TfLiteIntArray* inputs, TfLiteContext* context) { + if (inputs->size != 1) { + TF_LITE_KERNEL_LOG(context, "Activation: Wrong # of inputs!."); + return kTfLiteError; + } + AddInput(inputs->data[0]); + return kTfLiteOk; +} + +TfLiteStatus ActivationLayerBuilder::RegisterOutputs( + const TfLiteIntArray* outputs, TfLiteContext* context) { + if (outputs->size != 1) { + TF_LITE_KERNEL_LOG(context, "Activation: Wrong # of outputs!."); + return kTfLiteError; + } + graph_builder_->AddTensorWithID(outputs->data[0], GetOutput(context)); + return kTfLiteOk; +} + +OpBuilder* CreateActivationLayerBuilder(GraphBuilder* graph_builder) { + return new ActivationLayerBuilder(graph_builder); +} + +OpBuilder* CreateLogisticOpBuilder(GraphBuilder* graph_builder) { + return new ActivationLayerBuilder(graph_builder, kTfLiteActSigmoid); +} + +OpBuilder* CreateReluOpBuilder(GraphBuilder* graph_builder) { + return new ActivationLayerBuilder(graph_builder, kTfLiteActRelu); +} + +OpBuilder* CreateReluN1To1OpBuilder(GraphBuilder* graph_builder) { + return new ActivationLayerBuilder(graph_builder, kTfLiteActRelu1); +} + +OpBuilder* CreateRelu6OpBuilder(GraphBuilder* graph_builder) { + return new ActivationLayerBuilder(graph_builder, kTfLiteActRelu6); +} + +OpBuilder* CreateTanhOpBuilder(GraphBuilder* graph_builder) { + return new ActivationLayerBuilder(graph_builder, kTfLiteActTanh); +} + +} // namespace coreml +} // namespace delegates +} // namespace tflite diff --git a/tensorflow/lite/experimental/delegates/coreml/builders/activation_layer_builder.h b/tensorflow/lite/experimental/delegates/coreml/builders/activation_layer_builder.h new file mode 100644 index 00000000000..b22b454894b --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/builders/activation_layer_builder.h @@ -0,0 +1,60 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_ACTIVATION_LAYER_BUILDER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_ACTIVATION_LAYER_BUILDER_H_ + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/experimental/delegates/coreml/builders/op_builder.h" + +namespace tflite { +namespace delegates { +namespace coreml { + +class ActivationLayerBuilder : public OpBuilder { + public: + explicit ActivationLayerBuilder(GraphBuilder* graph_builder) + : OpBuilder(graph_builder) {} + + explicit ActivationLayerBuilder(GraphBuilder* graph_builder, + TfLiteFusedActivation activation) + : OpBuilder(graph_builder), activation_(activation) {} + + const char* DebugName() override; + + CoreML::Specification::NeuralNetworkLayer* Build() override; + + void SetActivation(TfLiteFusedActivation activation) { + activation_ = activation; + } + + void SetAlpha(float alpha) { alpha_ = alpha; } + + TfLiteStatus PopulateSubgraph(TfLiteContext* context) override; + + TfLiteStatus RegisterInputs(const TfLiteIntArray* inputs, + TfLiteContext* context) override; + + TfLiteStatus RegisterOutputs(const TfLiteIntArray* outputs, + TfLiteContext* context) override; + + private: + TfLiteFusedActivation activation_; + float alpha_ = 1.0f; +}; + +} // namespace coreml +} // namespace delegates +} // namespace tflite +#endif // TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_ACTIVATION_LAYER_BUILDER_H_ diff --git a/tensorflow/lite/experimental/delegates/coreml/builders/add_op_builder.cc b/tensorflow/lite/experimental/delegates/coreml/builders/add_op_builder.cc new file mode 100644 index 00000000000..d381b8a8e6c --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/builders/add_op_builder.cc @@ -0,0 +1,109 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/experimental/delegates/coreml/builders/add_op_builder.h" + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/experimental/delegates/coreml/builders/activation_layer_builder.h" +#include "tensorflow/lite/experimental/delegates/coreml/builders/op_factory.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" + +namespace tflite { +namespace delegates { +namespace coreml { +const char* AddOpBuilder::DebugName() { + if (!str_debug_name_[0]) + GetDebugName("AddOpBuilder", node_id_, str_debug_name_); + return str_debug_name_; +} + +CoreML::Specification::NeuralNetworkLayer* AddOpBuilder::Build() { + if (layer_ == nullptr) { + layer_.reset(new CoreML::Specification::NeuralNetworkLayer); + } + layer_->set_name(DebugName()); + layer_->mutable_add(); + if (alpha_ != 0.0f) { + layer_->mutable_add()->set_alpha(alpha_); + } + + return layer_.release(); +} + +TfLiteStatus AddOpBuilder::PopulateSubgraph(TfLiteContext* context) { + TfLiteAddParams* params = reinterpret_cast(builtin_data_); + + TfLiteFusedActivation activation = params->activation; + if (activation == kTfLiteActNone) { + builder_output_ = AddOutput(); + } else { + ActivationLayerBuilder* activation_builder = + reinterpret_cast( + graph_builder_->AddBuilder(CreateActivationLayerBuilder, nullptr)); + activation_builder->SetActivation(activation); + activation_builder->AddInput(AddOutput()); + activation_builder->PopulateSubgraph(context); + builder_output_ = activation_builder->GetOutput(context); + } + return kTfLiteOk; +} + +TfLiteStatus AddOpBuilder::RegisterInputs(const TfLiteIntArray* inputs, + TfLiteContext* context) { + // TODO(taeheej): support 1 input case if necessary. TFL add needs 2 inputs. + if (inputs->size != 2) { + TF_LITE_KERNEL_LOG(context, "Wrong # of inputs to add!."); + return kTfLiteError; + } + const auto* input_0 = &context->tensors[inputs->data[0]]; + const auto* input_1 = &context->tensors[inputs->data[1]]; + // store constant, scalar value into MultiplyLayerParams directly. + if (IsConstantTensor(input_0) && NumElements(input_0) == 1) { + AddInput(inputs->data[1]); + SetAlpha(GetTensorData(input_0)[0]); + } else if (IsConstantTensor(input_1) && NumElements(input_1) == 1) { + AddInput(inputs->data[0]); + SetAlpha(GetTensorData(input_1)[0]); + } else { + AddInput(inputs->data[0]); + AddInput(inputs->data[1]); + } + return kTfLiteOk; +} + +TfLiteStatus AddOpBuilder::RegisterOutputs(const TfLiteIntArray* outputs, + TfLiteContext* context) { + if (outputs->size != 1) { + TF_LITE_KERNEL_LOG(context, "Wrong # of outputs to add!."); + return kTfLiteError; + } + TensorID output_tensor = GetOutput(context); + if (output_tensor.NodeID() == -1) { + TF_LITE_KERNEL_LOG(context, "Failed to build output tensor."); + return kTfLiteError; + } + graph_builder_->AddTensorWithID(outputs->data[0], output_tensor); + return kTfLiteOk; +} + +void AddOpBuilder::SetAlpha(float alpha) { alpha_ = alpha; } + +OpBuilder* CreateAddOpBuilder(GraphBuilder* graph_builder) { + return new AddOpBuilder(graph_builder); +} +} // namespace coreml +} // namespace delegates +} // namespace tflite diff --git a/tensorflow/lite/experimental/delegates/coreml/builders/add_op_builder.h b/tensorflow/lite/experimental/delegates/coreml/builders/add_op_builder.h new file mode 100644 index 00000000000..17e1f9a9827 --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/builders/add_op_builder.h @@ -0,0 +1,51 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_ADD_OP_BUILDER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_ADD_OP_BUILDER_H_ + +#include "tensorflow/lite/experimental/delegates/coreml/builders/op_builder.h" + +namespace tflite { +namespace delegates { +namespace coreml { +// Builder for Add op in CoreML. +class AddOpBuilder : public OpBuilder { + public: + explicit AddOpBuilder(GraphBuilder* graph_builder) + : OpBuilder(graph_builder) {} + const char* DebugName() override; + + CoreML::Specification::NeuralNetworkLayer* Build() override; + + TfLiteStatus PopulateSubgraph(TfLiteContext* context) override; + + TfLiteStatus RegisterInputs(const TfLiteIntArray* inputs, + TfLiteContext* context) override; + + TfLiteStatus RegisterOutputs(const TfLiteIntArray* outputs, + TfLiteContext* context) override; + + void SetAlpha(float alpha); + + private: + // Used for unary add + float alpha_ = 0.0f; +}; + +} // namespace coreml +} // namespace delegates +} // namespace tflite + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_ADD_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/delegates/coreml/builders/concatenation_op_builder.cc b/tensorflow/lite/experimental/delegates/coreml/builders/concatenation_op_builder.cc new file mode 100644 index 00000000000..1a61d9fd997 --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/builders/concatenation_op_builder.cc @@ -0,0 +1,83 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/experimental/delegates/coreml/builders/concatenation_op_builder.h" + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/experimental/delegates/coreml/builders/op_validator.h" + +namespace tflite { +namespace delegates { +namespace coreml { + +CoreML::Specification::NeuralNetworkLayer* ConcatenationOpBuilder::Build() { + if (layer_ == nullptr) { + layer_.reset(new CoreML::Specification::NeuralNetworkLayer); + } + layer_->set_name(DebugName()); + layer_->mutable_concat()->set_sequenceconcat(false); + return layer_.release(); +} + +TfLiteStatus ConcatenationOpBuilder::RegisterInputs( + const TfLiteIntArray* inputs, TfLiteContext* context) { + if (inputs->size < 2) { + TF_LITE_KERNEL_LOG( + context, "ConcatenationOpBuidler: at least 2 inputs are required."); + return kTfLiteError; + } + for (int i = 0; i < inputs->size; ++i) { + AddInput(inputs->data[i]); + } + return kTfLiteOk; +} + +TfLiteStatus ConcatenationOpBuilder::RegisterOutputs( + const TfLiteIntArray* outputs, TfLiteContext* context) { + if (outputs->size != 1) { + TF_LITE_KERNEL_LOG(context, "Wrong # of outputs to Concat!."); + return kTfLiteError; + } + graph_builder_->AddTensorWithID(outputs->data[0], GetOutput(context)); + return kTfLiteOk; +} + +OpBuilder* CreateConcatenationOpBuilder(GraphBuilder* graph_builder) { + return new ConcatenationOpBuilder(graph_builder); +} + +bool IsConcatenationOpSupported(const TfLiteRegistration* registration, + const TfLiteNode* node, + TfLiteContext* context) { + if (node->builtin_data == nullptr) return false; + auto params = + reinterpret_cast(node->builtin_data); + int input_dims = context->tensors[node->inputs->data[0]].dims->size; + + // Not supported in TfLite kernel. + if (params->activation != kTfLiteActNone) return false; + if (node->inputs->size < 2) return false; + + // Only supports concatenation by channel. Core ML concatenation supports + // concatenation by channel and by sequence (axis -5) only. + // TODO(b/145642128): support stack layer here with Core ML 3 support. + if (input_dims == 3) return params->axis == 2 || params->axis == -1; + if (input_dims == 4) return params->axis == 3 || params->axis == -1; + return false; +} + +} // namespace coreml +} // namespace delegates +} // namespace tflite diff --git a/tensorflow/lite/experimental/delegates/coreml/builders/concatenation_op_builder.h b/tensorflow/lite/experimental/delegates/coreml/builders/concatenation_op_builder.h new file mode 100644 index 00000000000..a61bec114fa --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/builders/concatenation_op_builder.h @@ -0,0 +1,48 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_CONCATENATION_OP_BUILDER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_CONCATENATION_OP_BUILDER_H_ + +#include "tensorflow/lite/experimental/delegates/coreml/builders/op_builder.h" + +namespace tflite { +namespace delegates { +namespace coreml { + +class ConcatenationOpBuilder : public OpBuilder { + public: + explicit ConcatenationOpBuilder(GraphBuilder* graph_builder) + : OpBuilder(graph_builder) {} + + const char* DebugName() override { + if (!str_debug_name_[0]) + GetDebugName("ConcatOpBuilder", node_id_, str_debug_name_); + return str_debug_name_; + } + + CoreML::Specification::NeuralNetworkLayer* Build() override; + + TfLiteStatus RegisterInputs(const TfLiteIntArray* inputs, + TfLiteContext* context) override; + + TfLiteStatus RegisterOutputs(const TfLiteIntArray* outputs, + TfLiteContext* context) override; +}; + +} // namespace coreml +} // namespace delegates +} // namespace tflite + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_CONCATENATION_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/delegates/coreml/builders/convolution_op_builder.cc b/tensorflow/lite/experimental/delegates/coreml/builders/convolution_op_builder.cc new file mode 100644 index 00000000000..b6a859d1dff --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/builders/convolution_op_builder.cc @@ -0,0 +1,347 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/experimental/delegates/coreml/builders/convolution_op_builder.h" + +#include "google/protobuf/repeated_field.h" +#include "external/coremltools/mlmodel/format/NeuralNetwork.pb.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/experimental/delegates/coreml/builders/activation_layer_builder.h" +#include "tensorflow/lite/experimental/delegates/coreml/builders/op_factory.h" +#include "tensorflow/lite/experimental/delegates/coreml/builders/op_validator.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/kernel_util.h" + +namespace tflite { +namespace delegates { +namespace coreml { +const char* ConvolutionOpBuilder::DebugName() { + if (!str_debug_name_[0]) + GetDebugName("ConvolutionOpBuilder", node_id_, str_debug_name_); + return str_debug_name_; +} + +void ConvolutionOpBuilder::SetWeights(TfLiteTensor* weights) { + weights_ = weights; +} + +void ConvolutionOpBuilder::SetBias(TfLiteTensor* bias) { bias_ = bias; } + +void ConvolutionOpBuilder::SetOutputShape(TfLiteTensor* output_shape) { + output_shape_ = output_shape; +} + +CoreML::Specification::NeuralNetworkLayer* ConvolutionOpBuilder::Build() { + if (layer_ == nullptr) { + layer_.reset(new CoreML::Specification::NeuralNetworkLayer); + } + layer_->set_name(DebugName()); + + int stride_height = 0; + int stride_width = 0; + int dilation_height = 0; + int dilation_width = 0; + TfLitePadding padding; + + switch (conv_type_) { + case ConvolutionType::kConv: { + const auto* conv_params = + reinterpret_cast(builtin_data_); + stride_height = conv_params->stride_height; + stride_width = conv_params->stride_width; + dilation_height = conv_params->dilation_height_factor; + dilation_width = conv_params->dilation_width_factor; + padding = conv_params->padding; + + layer_->mutable_convolution()->set_ngroups(1); + break; + } + case ConvolutionType::kDepthwiseConv: { + const auto* depthwise_conv_params = + reinterpret_cast(builtin_data_); + stride_height = depthwise_conv_params->stride_height; + stride_width = depthwise_conv_params->stride_width; + dilation_height = depthwise_conv_params->dilation_height_factor; + dilation_width = depthwise_conv_params->dilation_width_factor; + padding = depthwise_conv_params->padding; + + // n_groups = kernel_channel / depth_multiplier + layer_->mutable_convolution()->set_ngroups( + weights_->dims->data[3] / depthwise_conv_params->depth_multiplier); + break; + } + case ConvolutionType::kTransposeConv: { + const auto* transpose_conv_params = + reinterpret_cast(builtin_data_); + const int height_index = 1; + const int width_index = 2; + + stride_height = transpose_conv_params->stride_height; + stride_width = transpose_conv_params->stride_width; + padding = transpose_conv_params->padding; + + layer_->mutable_convolution()->mutable_outputshape()->Add( + GetTensorData(output_shape_)[height_index]); + layer_->mutable_convolution()->mutable_outputshape()->Add( + GetTensorData(output_shape_)[width_index]); + break; + } + } + + // If not set, it will default to (1,1) + if (stride_height) { + layer_->mutable_convolution()->add_stride(stride_height); + layer_->mutable_convolution()->add_stride(stride_width); + } + + if (dilation_height) { + layer_->mutable_convolution()->add_dilationfactor(dilation_height); + layer_->mutable_convolution()->add_dilationfactor(dilation_width); + } + + switch (padding) { + case kTfLitePaddingSame: + layer_->mutable_convolution()->mutable_same(); + break; + case kTfLitePaddingValid: + layer_->mutable_convolution()->mutable_valid(); + break; + case kTfLitePaddingUnknown: + fprintf(stderr, "Padding is unknown.\n"); + break; + } + + FillCoreMLWeights(); + FillCoreMLBias(); + + return layer_.release(); +} + +void ConvolutionOpBuilder::FillCoreMLWeights() { + if (conv_type_ == ConvolutionType::kDepthwiseConv) { + layer_->mutable_convolution()->set_kernelchannels(1); + layer_->mutable_convolution()->set_outputchannels(weights_->dims->data[3]); + } else { + layer_->mutable_convolution()->set_kernelchannels(weights_->dims->data[3]); + layer_->mutable_convolution()->set_outputchannels(weights_->dims->data[0]); + } + layer_->mutable_convolution()->add_kernelsize(weights_->dims->data[1]); + layer_->mutable_convolution()->add_kernelsize(weights_->dims->data[2]); + + TransposeKernelWeights(); // Should be called after CoreML shape is set. +} + +void ConvolutionOpBuilder::TransposeKernelWeights() { + RuntimeShape tfl_shape(4, weights_->dims->data); + // CoreML kernel has shape of (C_out, C_in, H, W) + RuntimeShape coreml_shape( + {static_cast(layer_->convolution().outputchannels()), + static_cast(layer_->convolution().kernelchannels()), + static_cast(layer_->convolution().kernelsize()[0]), + static_cast(layer_->convolution().kernelsize()[1])}); + + TransposeParams params; + + if (conv_type_ == ConvolutionType::kDepthwiseConv) { + // DepthwiseConv2D: TFL kernel has shape of (1, H, W, C_out), + // and CoreML kernel has shape of (C_out, 1, H, W) + params = {/*perm_count=*/4, /*perm=*/{3, 0, 1, 2}}; + } else { + // Conv2D and TransposeConv: TFL kernel has shape of (C_out, H, W, C_in), + // and CoreML kernel has shape of (C_out, C_in, H, W) + params = {/*perm_count=*/4, /*perm=*/{0, 3, 1, 2}}; + } + + if (conv_type_ == ConvolutionType::kTransposeConv) { + layer_->mutable_convolution()->set_isdeconvolution(true); + } + + auto* coreml_weights = + layer_->mutable_convolution()->mutable_weights()->mutable_floatvalue(); + coreml_weights->Resize(NumElements(weights_), 0); + + optimized_ops::Transpose(params, tfl_shape, weights_->data.f, + coreml_shape, coreml_weights->mutable_data()); +} + +void ConvolutionOpBuilder::FillCoreMLBias() { + if (bias_ != nullptr) { + layer_->mutable_convolution()->set_hasbias(true); + std::copy(bias_->data.f, bias_->data.f + NumElements(bias_->dims), + google::protobuf::RepeatedFieldBackInserter(layer_->mutable_convolution() + ->mutable_bias() + ->mutable_floatvalue())); + } +} + +TfLiteStatus ConvolutionOpBuilder::PopulateSubgraph(TfLiteContext* context) { + TfLiteFusedActivation activation; + switch (conv_type_) { + case ConvolutionType::kConv: { + const auto* conv_params = + reinterpret_cast(builtin_data_); + activation = conv_params->activation; + break; + } + case ConvolutionType::kDepthwiseConv: { + const auto* depthwise_conv_params = + reinterpret_cast(builtin_data_); + activation = depthwise_conv_params->activation; + break; + } + case ConvolutionType::kTransposeConv: { + activation = kTfLiteActNone; + break; + } + } + + if (activation == kTfLiteActNone) { + builder_output_ = AddOutput(); + } else { + ActivationLayerBuilder* activation_builder = + reinterpret_cast( + graph_builder_->AddBuilder(CreateActivationLayerBuilder, nullptr)); + activation_builder->SetActivation(activation); + activation_builder->AddInput(AddOutput()); + activation_builder->PopulateSubgraph(context); + builder_output_ = activation_builder->GetOutput(context); + } + return kTfLiteOk; +} + +TfLiteStatus ConvolutionOpBuilder::RegisterInputs(const TfLiteIntArray* inputs, + TfLiteContext* context) { + if (conv_type_ == ConvolutionType::kTransposeConv) { + if (inputs->size != 3) { + TF_LITE_KERNEL_LOG(context, + "Transpose Conv should have 3 inputs, %d given.", + inputs->size); + return kTfLiteError; + } + AddInput(inputs->data[2]); + SetOutputShape(&context->tensors[inputs->data[0]]); + } else { + if (inputs->size != 2 && inputs->size != 3) { + TF_LITE_KERNEL_LOG(context, + "Convolution and depthwise convolution should have 2 " + "or 3 inputs, %d given.", + inputs->size); + return kTfLiteError; + } + AddInput(inputs->data[0]); + if (inputs->size > 2) { + SetBias(&context->tensors[inputs->data[2]]); + } + } + SetWeights(&context->tensors[inputs->data[1]]); + return kTfLiteOk; +} + +TfLiteStatus ConvolutionOpBuilder::RegisterOutputs( + const TfLiteIntArray* outputs, TfLiteContext* context) { + if (outputs->size != 1) { + TF_LITE_KERNEL_LOG(context, "Wrong # of outputs!."); + return kTfLiteError; + } + TensorID output_tensor = GetOutput(context); + if (output_tensor.NodeID() == -1) { + TF_LITE_KERNEL_LOG(context, "Failed to build output tensor."); + return kTfLiteError; + } + graph_builder_->AddTensorWithID(outputs->data[0], output_tensor); + return kTfLiteOk; +} + +OpBuilder* CreateConvolutionOpBuilder(GraphBuilder* graph_builder) { + return new ConvolutionOpBuilder(graph_builder, ConvolutionType::kConv); +} + +OpBuilder* CreateDepthwiseConvolutionOpBuilder(GraphBuilder* graph_builder) { + return new ConvolutionOpBuilder(graph_builder, + ConvolutionType::kDepthwiseConv); +} + +OpBuilder* CreateTransposeConvolutionOpBuilder(GraphBuilder* graph_builder) { + return new ConvolutionOpBuilder(graph_builder, + ConvolutionType::kTransposeConv); +} + +bool IsConvolutionOpSupported(const TfLiteRegistration* registration, + const TfLiteNode* node, TfLiteContext* context) { + if (node->builtin_data == nullptr) return false; + + TfLiteFusedActivation activation; + + if (registration->builtin_code == kTfLiteBuiltinConv2d) { + const auto* conv_params = + reinterpret_cast(node->builtin_data); + activation = conv_params->activation; + } else if (registration->builtin_code == kTfLiteBuiltinDepthwiseConv2d) { + const auto* depthwise_conv_params = + reinterpret_cast(node->builtin_data); + activation = depthwise_conv_params->activation; + } else if (registration->builtin_code == kTfLiteBuiltinTransposeConv) { + activation = kTfLiteActNone; + } else { + TF_LITE_KERNEL_LOG( + context, + "Invalid op: op must be Conv2D, DepthwiseConv2D or TransposeConv."); + return false; + } + + if (activation == kTfLiteActSignBit) { + return false; + } + + const int kOutputShapeTensor = 0; // Only used for TransposeConv + const int kWeightTensor = 1; + const int kBiasTensor = 2; // Only used for non-TransposeConv + const TfLiteTensor* weights = GetInput(context, node, kWeightTensor); + const int max_kernel_size = 16384; + if (!IsConstantTensor(weights)) { + return false; + } + if (weights->dims->data[1] > max_kernel_size || + weights->dims->data[2] > max_kernel_size) { + return false; + } + if (registration->builtin_code == kTfLiteBuiltinTransposeConv) { + if (!IsConstantTensor(GetInput(context, node, kOutputShapeTensor))) { + return false; + } + } else { + if (node->inputs->size >= kBiasTensor && + !IsConstantTensor(GetInput(context, node, kBiasTensor))) { + return false; + } + } + + return true; +} + +bool IsDepthwiseConvolutionOpSupported(const TfLiteRegistration* registration, + const TfLiteNode* node, + TfLiteContext* context) { + return IsConvolutionOpSupported(registration, node, context); +} + +bool IsTransposeConvolutionOpSupported(const TfLiteRegistration* registration, + const TfLiteNode* node, + TfLiteContext* context) { + return IsConvolutionOpSupported(registration, node, context); +} + +} // namespace coreml +} // namespace delegates +} // namespace tflite diff --git a/tensorflow/lite/experimental/delegates/coreml/builders/convolution_op_builder.h b/tensorflow/lite/experimental/delegates/coreml/builders/convolution_op_builder.h new file mode 100644 index 00000000000..0e2e8ee35aa --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/builders/convolution_op_builder.h @@ -0,0 +1,85 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_CONVOLUTION_OP_BUILDER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_CONVOLUTION_OP_BUILDER_H_ + +#include "tensorflow/lite/builtin_ops.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/experimental/delegates/coreml/builders/op_builder.h" + +namespace tflite { +namespace delegates { +namespace coreml { + +enum class ConvolutionType { kConv, kDepthwiseConv, kTransposeConv }; + +// Layer that provides convolution and depthwise convolution. +class ConvolutionOpBuilder : public OpBuilder { + public: + explicit ConvolutionOpBuilder(GraphBuilder* graph_builder, + ConvolutionType conv_type) + : OpBuilder(graph_builder), conv_type_(conv_type) {} + + const char* DebugName() override; + + CoreML::Specification::NeuralNetworkLayer* Build() override; + + TfLiteStatus PopulateSubgraph(TfLiteContext* context) override; + + void SetOutputChannels(uint64_t output_channels); + + void SetNGroups(uint64_t n_groups); + + void SetWeights(TfLiteTensor* weights); + + void SetBias(TfLiteTensor* bias); + + void SetOutputShape(TfLiteTensor* output_shape); + + void SetParams(void* builtin_data); + + TfLiteStatus RegisterInputs(const TfLiteIntArray* inputs, + TfLiteContext* context) override; + + TfLiteStatus RegisterOutputs(const TfLiteIntArray* outputs, + TfLiteContext* context) override; + + private: + void FillCoreMLWeights(); + void FillCoreMLBias(); + + // Transpose TFLite kernel weights to CoreML kernel weights. + // Should be called after setting CoreML's kernel shapes. + void TransposeKernelWeights(); + + uint64_t output_channels_; + uint64_t n_groups_ = 1; + + ConvolutionType conv_type_; + + // using default dilation_factor (1, 1) + // CoreML ConvolutionLayerParams.isDeconvolution == false + TfLiteTensor* weights_ = nullptr; + TfLiteTensor* bias_ = nullptr; + // Only used for TransposeConv. + TfLiteTensor* output_shape_ = nullptr; +}; + +} // namespace coreml +} // namespace delegates +} // namespace tflite + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_CONVOLUTION_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/delegates/coreml/builders/hardswish_op_builder.cc b/tensorflow/lite/experimental/delegates/coreml/builders/hardswish_op_builder.cc new file mode 100644 index 00000000000..1c9de179f46 --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/builders/hardswish_op_builder.cc @@ -0,0 +1,89 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/experimental/delegates/coreml/builders/hardswish_op_builder.h" + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/experimental/delegates/coreml/builders/add_op_builder.h" +#include "tensorflow/lite/experimental/delegates/coreml/builders/mul_op_builder.h" +#include "tensorflow/lite/experimental/delegates/coreml/builders/op_factory.h" + +namespace tflite { +namespace delegates { +namespace coreml { +const char* HardSwishOpBuilder::DebugName() { + if (!str_debug_name_[0]) + GetDebugName("HardSwishOpBuilder", node_id_, str_debug_name_); + return str_debug_name_; +} + +CoreML::Specification::NeuralNetworkLayer* HardSwishOpBuilder::Build() { + layer_->set_name(DebugName()); + layer_->mutable_multiply()->set_alpha(1.0f / 6.0f); + + return layer_.release(); +} + +TfLiteStatus HardSwishOpBuilder::PopulateSubgraph(TfLiteContext* context) { + // hswish(x) = (x/6) * ReLU6(x+3). main layer_ contains the first part, x/6. + // ReLU6(x +3) constructed as add op with fused ReLU6 activation. + AddOpBuilder* add_builder = reinterpret_cast( + graph_builder_->AddBuilder(CreateAddOpBuilder, nullptr)); + TfLiteAddParams add_param{kTfLiteActRelu6}; + add_builder->SetBuiltinData(&add_param); + add_builder->SetAlpha(3.0f); + add_builder->AddInput(layer_->input(0)); + add_builder->PopulateSubgraph(context); + + // multiplies (x/6) from main layer_ and ReLU6(x+3) from the above code. + MulOpBuilder* mul_builder = reinterpret_cast( + graph_builder_->AddBuilder(CreateMulOpBuilder, nullptr)); + mul_builder->AddInput(AddOutput()); + mul_builder->AddInput(add_builder->GetOutput(context)); + builder_output_ = mul_builder->AddOutput(); + return kTfLiteOk; +} + +TfLiteStatus HardSwishOpBuilder::RegisterInputs(const TfLiteIntArray* inputs, + TfLiteContext* context) { + if (inputs->size != 1) { + TF_LITE_KERNEL_LOG(context, "Wrong # of inputs to hardswish!."); + return kTfLiteError; + } + AddInput(inputs->data[0]); + return kTfLiteOk; +} + +TfLiteStatus HardSwishOpBuilder::RegisterOutputs(const TfLiteIntArray* outputs, + TfLiteContext* context) { + if (outputs->size != 1) { + TF_LITE_KERNEL_LOG(context, "Wrong # of outputs to hardswish!."); + return kTfLiteError; + } + TensorID output_tensor = GetOutput(context); + if (output_tensor.NodeID() == -1) { + TF_LITE_KERNEL_LOG(context, "Failed to build output tensor."); + return kTfLiteError; + } + graph_builder_->AddTensorWithID(outputs->data[0], output_tensor); + return kTfLiteOk; +} + +OpBuilder* CreateHardSwishOpBuilder(GraphBuilder* graph_builder) { + return new HardSwishOpBuilder(graph_builder); +} +} // namespace coreml +} // namespace delegates +} // namespace tflite diff --git a/tensorflow/lite/experimental/delegates/coreml/builders/hardswish_op_builder.h b/tensorflow/lite/experimental/delegates/coreml/builders/hardswish_op_builder.h new file mode 100644 index 00000000000..d86c9f9b0de --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/builders/hardswish_op_builder.h @@ -0,0 +1,45 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_HARDSWISH_OP_BUILDER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_HARDSWISH_OP_BUILDER_H_ + +#include "tensorflow/lite/experimental/delegates/coreml/builders/op_builder.h" + +namespace tflite { +namespace delegates { +namespace coreml { +// hswish(x) = x * ReLU6(x + 3) / 6 +class HardSwishOpBuilder : public OpBuilder { + public: + explicit HardSwishOpBuilder(GraphBuilder* graph_builder) + : OpBuilder(graph_builder) {} + const char* DebugName() override; + + CoreML::Specification::NeuralNetworkLayer* Build() override; + + TfLiteStatus PopulateSubgraph(TfLiteContext* context) override; + + TfLiteStatus RegisterInputs(const TfLiteIntArray* inputs, + TfLiteContext* context) override; + + TfLiteStatus RegisterOutputs(const TfLiteIntArray* outputs, + TfLiteContext* context) override; +}; + +} // namespace coreml +} // namespace delegates +} // namespace tflite + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_HARDSWISH_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/delegates/coreml/builders/mul_op_builder.cc b/tensorflow/lite/experimental/delegates/coreml/builders/mul_op_builder.cc new file mode 100644 index 00000000000..2ff85545301 --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/builders/mul_op_builder.cc @@ -0,0 +1,116 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/experimental/delegates/coreml/builders/mul_op_builder.h" + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/experimental/delegates/coreml/builders/activation_layer_builder.h" +#include "tensorflow/lite/experimental/delegates/coreml/builders/op_factory.h" +#include "tensorflow/lite/experimental/delegates/coreml/builders/op_validator.h" +#include "tensorflow/lite/experimental/delegates/coreml/builders/util.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/internal/types.h" +#include "tensorflow/lite/kernels/kernel_util.h" + +namespace tflite { +namespace delegates { +namespace coreml { +const char* MulOpBuilder::DebugName() { + if (!str_debug_name_[0]) + GetDebugName("MulOpBuilder", node_id_, str_debug_name_); + return str_debug_name_; +} + +CoreML::Specification::NeuralNetworkLayer* MulOpBuilder::Build() { + if (layer_ == nullptr) { + layer_.reset(new CoreML::Specification::NeuralNetworkLayer); + } + // MultiplyLayerParams only has limited broadcasting support. For example: + // [B, 1, 1, 1], [B, C, 1, 1], [B, 1, H, W], [B, C, H, W]. other shapes + // will make broadcasting fail. + layer_->set_name(DebugName()); + layer_->mutable_multiply(); + if (alpha_ != 1.0f) { + layer_->mutable_multiply()->set_alpha(alpha_); + } + + return layer_.release(); +} + +TfLiteStatus MulOpBuilder::PopulateSubgraph(TfLiteContext* context) { + TfLiteMulParams* params = reinterpret_cast(builtin_data_); + + TfLiteFusedActivation activation = params->activation; + if (activation == kTfLiteActNone) { + builder_output_ = AddOutput(); + } else { + ActivationLayerBuilder* activation_builder = + reinterpret_cast( + graph_builder_->AddBuilder(CreateActivationLayerBuilder, nullptr)); + activation_builder->SetActivation(activation); + activation_builder->AddInput(AddOutput()); + activation_builder->PopulateSubgraph(context); + builder_output_ = activation_builder->GetOutput(context); + } + return kTfLiteOk; +} + +TfLiteStatus MulOpBuilder::RegisterInputs(const TfLiteIntArray* inputs, + TfLiteContext* context) { + // TFL MUL op always has 2 inputs. + if (inputs->size != 2) { + TF_LITE_KERNEL_LOG(context, "Wrong # of inputs to mul!."); + return kTfLiteError; + } + const auto* input_0 = &context->tensors[inputs->data[0]]; + const auto* input_1 = &context->tensors[inputs->data[1]]; + // store constant, scalar value into MultiplyLayerParams directly. + if (IsConstantTensor(input_0) && NumElements(input_0) == 1) { + AddInput(inputs->data[1]); + SetAlpha(GetTensorData(input_0)[0]); + } else if (IsConstantTensor(input_1) && NumElements(input_1) == 1) { + AddInput(inputs->data[0]); + SetAlpha(GetTensorData(input_1)[0]); + } else { + AddInput(inputs->data[0]); + AddInput(inputs->data[1]); + } + return kTfLiteOk; +} + +TfLiteStatus MulOpBuilder::RegisterOutputs(const TfLiteIntArray* outputs, + TfLiteContext* context) { + if (outputs->size != 1) { + TF_LITE_KERNEL_LOG(context, "Wrong # of outputs to mul!."); + return kTfLiteError; + } + TensorID output_tensor = GetOutput(context); + if (output_tensor.NodeID() == -1) { + TF_LITE_KERNEL_LOG(context, "Failed to build output tensor."); + return kTfLiteError; + } + graph_builder_->AddTensorWithID(outputs->data[0], output_tensor); + return kTfLiteOk; +} + +void MulOpBuilder::SetAlpha(float alpha) { alpha_ = alpha; } + +OpBuilder* CreateMulOpBuilder(GraphBuilder* graph_builder) { + return new MulOpBuilder(graph_builder); +} + +} // namespace coreml +} // namespace delegates +} // namespace tflite diff --git a/tensorflow/lite/experimental/delegates/coreml/builders/mul_op_builder.h b/tensorflow/lite/experimental/delegates/coreml/builders/mul_op_builder.h new file mode 100644 index 00000000000..d0d54712369 --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/builders/mul_op_builder.h @@ -0,0 +1,51 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_MUL_OP_BUILDER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_MUL_OP_BUILDER_H_ + +#include "tensorflow/lite/experimental/delegates/coreml/builders/op_builder.h" + +namespace tflite { +namespace delegates { +namespace coreml { +// Builder for Mul op in CoreML. +class MulOpBuilder : public OpBuilder { + public: + explicit MulOpBuilder(GraphBuilder* graph_builder) + : OpBuilder(graph_builder) {} + const char* DebugName() override; + + CoreML::Specification::NeuralNetworkLayer* Build() override; + + TfLiteStatus PopulateSubgraph(TfLiteContext* context) override; + + TfLiteStatus RegisterInputs(const TfLiteIntArray* inputs, + TfLiteContext* context) override; + + TfLiteStatus RegisterOutputs(const TfLiteIntArray* outputs, + TfLiteContext* context) override; + + void SetAlpha(float alpha); + + private: + // Used for unary mul + float alpha_ = 1.0f; +}; + +} // namespace coreml +} // namespace delegates +} // namespace tflite + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_MUL_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/delegates/coreml/builders/op_builder.cc b/tensorflow/lite/experimental/delegates/coreml/builders/op_builder.cc new file mode 100644 index 00000000000..489e126e55f --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/builders/op_builder.cc @@ -0,0 +1,162 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/experimental/delegates/coreml/builders/op_builder.h" + +#include "external/coremltools/mlmodel/format/NeuralNetwork.pb.h" +#include "tensorflow/lite/builtin_ops.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/experimental/delegates/coreml/builders/op_factory.h" +#include "tensorflow/lite/kernels/kernel_util.h" + +namespace tflite { +namespace delegates { +namespace coreml { +OpBuilder* GraphBuilder::AddBuilder(int builtin_code, const TfLiteNode* node) { + // Follow the ordering of TfLiteBuiltinOperator enum. + switch (builtin_code) { + case kTfLiteBuiltinAdd: + return AddBuilder(CreateAddOpBuilder, node); + case kTfLiteBuiltinAveragePool2d: + return AddBuilder(CreateAveragePool2dOpBuilder, node); + case kTfLiteBuiltinConcatenation: + return AddBuilder(CreateConcatenationOpBuilder, node); + case kTfLiteBuiltinConv2d: + return AddBuilder(CreateConvolutionOpBuilder, node); + case kTfLiteBuiltinDepthwiseConv2d: + return AddBuilder(CreateDepthwiseConvolutionOpBuilder, node); + case kTfLiteBuiltinLogistic: + return AddBuilder(CreateLogisticOpBuilder, node); + case kTfLiteBuiltinMaxPool2d: + return AddBuilder(CreateMaxPool2dOpBuilder, node); + case kTfLiteBuiltinMul: + return AddBuilder(CreateMulOpBuilder, node); + case kTfLiteBuiltinRelu: + return AddBuilder(CreateReluOpBuilder, node); + case kTfLiteBuiltinReluN1To1: + return AddBuilder(CreateReluN1To1OpBuilder, node); + case kTfLiteBuiltinRelu6: + return AddBuilder(CreateRelu6OpBuilder, node); + case kTfLiteBuiltinReshape: + return AddBuilder(CreateReshapeOpBuilder, node); + case kTfLiteBuiltinResizeBilinear: + return AddBuilder(CreateResizeBilinearOpBuilder, node); + case kTfLiteBuiltinSoftmax: + return AddBuilder(CreateSoftmaxOpBuilder, node); + case kTfLiteBuiltinTanh: + return AddBuilder(CreateTanhOpBuilder, node); + case kTfLiteBuiltinTransposeConv: + return AddBuilder(CreateTransposeConvolutionOpBuilder, node); + case kTfLiteBuiltinHardSwish: + return AddBuilder(CreateHardSwishOpBuilder, node); + default: + return nullptr; + } +} + +OpBuilder* GraphBuilder::AddBuilder( + const std::function& builder, + const TfLiteNode* node) { + if (builder == nullptr) { + fprintf(stderr, "builder should be set.\n"); + return nullptr; + } + OpBuilder* op = builder(this); + + builders_.emplace_back(op); + op->SetNodeID(builders_.size()); + if (node != nullptr) { + op->SetBuiltinData(node->builtin_data); + op->SetTfLiteNode(node); + } + return builders_.back().get(); +} + +CoreML::Specification::Model* GraphBuilder::BuildModel() { + CoreML::Specification::Model* model = new CoreML::Specification::Model(); + auto* neural_network = model->mutable_neuralnetwork(); + for (auto& builder : builders_) { + CoreML::Specification::NeuralNetworkLayer* layer = builder->Build(); + if (layer == nullptr) { + fprintf(stderr, "Null layer returned from builder: %s\n", + builder->DebugName()); + continue; + } + neural_network->mutable_layers()->AddAllocated(layer); + } + return model; +} + +void GraphBuilder::AddTensorWithID(int tf_tensor_id, + const TensorID& tensor_id) { + if (tensors_.size() <= tf_tensor_id) { + tensors_.resize(tf_tensor_id + 1); + used_tensor_.resize(tf_tensor_id + 1); + } + tensors_[tf_tensor_id] = tensor_id; +} + +std::string GraphBuilder::GetTensorName(int tensor_id) { + return GetTensorID(tensor_id).ToString(); +} + +const TensorID GraphBuilder::GetTensorID(int tensor_id) { + if (!HasTensor(tensor_id)) { + // TODO(karimnosseir): Double check if this happened, if we are + // adding in execution order it shouldn't happen. + fprintf(stderr, "index out of range...!!! Requested index %d , size %d\n", + tensor_id, static_cast(tensors_.size())); + // Return invalid ID. + return TensorID(-1, -1); + } + used_tensor_[tensor_id] = true; + return tensors_[tensor_id]; +} + +bool GraphBuilder::HasTensor(int tflite_tensor_index) { + if (tensors_.size() <= tflite_tensor_index) { + return false; + } + return tensors_[tflite_tensor_index].NodeID() != -1; +} + +bool GraphBuilder::IsTensorUsed(int tflite_tensor_index) { + if (!HasTensor(tflite_tensor_index)) return false; + return used_tensor_[tflite_tensor_index]; +} + +void OpBuilder::AddInput(const std::string& input_name) { + if (layer_ == nullptr) { + layer_.reset(new CoreML::Specification::NeuralNetworkLayer); + } + *layer_->mutable_input()->Add() = input_name; +} + +void OpBuilder::AddInput(const TensorID& input_id) { + AddInput(input_id.ToString()); +} + +void OpBuilder::AddInput(int tf_input_id) { + AddInput(graph_builder_->GetTensorName(tf_input_id)); +} + +TensorID OpBuilder::AddOutput() { + auto tensor_id = TensorID(GetID(), num_outputs_++); + *layer_->mutable_output()->Add() = tensor_id.ToString(); + return tensor_id; +} + +} // namespace coreml +} // namespace delegates +} // namespace tflite diff --git a/tensorflow/lite/experimental/delegates/coreml/builders/op_builder.h b/tensorflow/lite/experimental/delegates/coreml/builders/op_builder.h new file mode 100644 index 00000000000..5367ae20d2f --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/builders/op_builder.h @@ -0,0 +1,181 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_OP_BUILDER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_OP_BUILDER_H_ + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "external/coremltools/mlmodel/format/Model.pb.h" +#include "external/coremltools/mlmodel/format/NeuralNetwork.pb.h" +#include "tensorflow/lite/c/common.h" + +namespace tflite { +namespace delegates { +namespace coreml { +class OpBuilder; + +// A class represents an ID in the coreML graph. +// A node is represented by a pair (node_id, and output_index) +// API is experimental and subject to change. +class TensorID { + public: + TensorID() {} + TensorID(int node, int output_id) : node_(node), output_id_(output_id) {} + + std::string ToString() const { return absl::StrCat(node_, "__", output_id_); } + + int NodeID() const { return node_; } + + int OutputID() const { return output_id_; } + + private: + int node_ = -1; + int output_id_ = -1; +}; + +// Builder for the whole graph. +// All op builders should be added using AddBuilder +// and then BuildModel should be called to return the CoreML generated. +// +// API is experimental and subject to change. +class GraphBuilder { + public: + // Returns pointer to the created builder. Ownership still belongs + // to the GraphBuilder. + OpBuilder* AddBuilder(int builtin_code, const TfLiteNode* node); + + // Returns pointer to the created builder with op builder function provided. + OpBuilder* AddBuilder(const std::function& builder, + const TfLiteNode* node); + + // Builds Model instance and returns it. + CoreML::Specification::Model* BuildModel(); + + // Returns string representing tensor 'tensor_id' in coreML. + // tensor_id should have been added before calling this method. + std::string GetTensorName(int tensor_id); + + // Returns Core ML Tensor ID for TFL 'tensor_id'. + // tensor_id should have been added before calling this method. + const TensorID GetTensorID(int tensor_id); + + void AddTensorWithID(int tf_tensor_id, const TensorID& tensor_id); + + // Return true if this tensor was added before to the graph. + bool HasTensor(int tflite_tensor_index); + // Return if this tensor is used in the graph (not as data). + // This information is used to mark constant tensors that are used as input. + bool IsTensorUsed(int tflite_tensor_index); + + private: + std::vector> builders_; + // Index in the vector is the tflite_tensor_index, the value + // is the ID in the coreml graph. + std::vector tensors_; + std::vector used_tensor_; +}; + +// Interface for all op layers +// API is experimental and subject to change. +class OpBuilder { + public: + explicit OpBuilder(GraphBuilder* graph_builder) + : graph_builder_(graph_builder) {} + virtual ~OpBuilder() {} + + // Returns the Layer this builder responsible for. + // Ownership is transferred to caller. + virtual CoreML::Specification::NeuralNetworkLayer* Build() { + layer_->set_name(DebugName()); + return layer_.release(); + } + + virtual TfLiteStatus RegisterInputs(const TfLiteIntArray* inputs, + TfLiteContext* context) { + return kTfLiteOk; + } + + virtual TfLiteStatus RegisterOutputs(const TfLiteIntArray* outputs, + TfLiteContext* context) { + return kTfLiteOk; + } + + // Adds additional required OpBuilders, and populate builder_output_ with + // Actual output that corresponds to output tensor of TFL Node. + // Clients need to override this in cases where the nodes can be used for + // composing other ops. For example, Relu6 in TfLite can be converted to + // Relu -> Threshold -> Neg. + // TODO(b/147211734): have this called automatically when necessary. + virtual TfLiteStatus PopulateSubgraph(TfLiteContext* context) { + builder_output_ = AddOutput(); + return kTfLiteOk; + } + + virtual const char* DebugName() = 0; + + void SetBuiltinData(void* builtin_data) { builtin_data_ = builtin_data; } + + void SetNodeID(int id) { node_id_ = id; } + + void SetTfLiteNode(const TfLiteNode* node) { tflite_node_ = node; } + + int GetID() const { return node_id_; } + + TensorID AddOutput(); + + // To be used by clients that needs the output of the node. + virtual TensorID GetOutput(TfLiteContext* context) { + if (builder_output_.NodeID() != -1) { + return builder_output_; + } + // builder_output_ is not set when PopulateSubgraph is not called. + builder_output_ = AddOutput(); + return builder_output_; + } + + // Adds input with tensor name. + void AddInput(const std::string& input_name); + + // Adds input with CoreML tensor ID. + void AddInput(const TensorID& input_id); + + // Adds input with TF Lite tensor ID. + // TODO(taeheej): cleanup AddInput use cases and used tensor tracking. + void AddInput(int tf_input_id); + + protected: + // Helper to print op instance name. + void GetDebugName(const char* name, int id, char* debug_name) { + // TODO(karimnosseir): Move away from absl, probably adding overhead + // on binary size ?. + absl::SNPrintF(debug_name, 100 * sizeof(char), "%s_%d", name, id); + } + + GraphBuilder* graph_builder_ = nullptr; + // Data needed by this node. + void* builtin_data_ = nullptr; + int node_id_ = -1; + int num_outputs_ = 0; + const TfLiteNode* tflite_node_ = nullptr; + TensorID builder_output_; + char str_debug_name_[100] = {0}; + std::unique_ptr layer_; +}; + +} // namespace coreml +} // namespace delegates +} // namespace tflite + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/delegates/coreml/builders/op_factory.h b/tensorflow/lite/experimental/delegates/coreml/builders/op_factory.h new file mode 100644 index 00000000000..898c1a96bd2 --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/builders/op_factory.h @@ -0,0 +1,51 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_OP_FACTORY_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_OP_FACTORY_H_ + +#include "tensorflow/lite/c/builtin_op_data.h" + +namespace tflite { +namespace delegates { +namespace coreml { +class GraphBuilder; +class OpBuilder; + +OpBuilder* CreateAddOpBuilder(GraphBuilder* graph_builder); +OpBuilder* CreateAveragePool2dOpBuilder(GraphBuilder* graph_builder); +OpBuilder* CreateConcatenationOpBuilder(GraphBuilder* graph_builder); +OpBuilder* CreateConvolutionOpBuilder(GraphBuilder* graph_builder); +OpBuilder* CreateDepthwiseConvolutionOpBuilder(GraphBuilder* graph_builder); +OpBuilder* CreateHardSwishOpBuilder(GraphBuilder* graph_builder); +OpBuilder* CreateLogisticOpBuilder(GraphBuilder* graph_builder); +OpBuilder* CreateMaxPool2dOpBuilder(GraphBuilder* graph_builder); +OpBuilder* CreateMulOpBuilder(GraphBuilder* graph_builder); +OpBuilder* CreateReluOpBuilder(GraphBuilder* graph_builder); +OpBuilder* CreateReluN1To1OpBuilder(GraphBuilder* graph_builder); +OpBuilder* CreateRelu6OpBuilder(GraphBuilder* graph_builder); +OpBuilder* CreateReshapeOpBuilder(GraphBuilder* graph_builder); +OpBuilder* CreateResizeBilinearOpBuilder(GraphBuilder* graph_builder); +OpBuilder* CreateSoftmaxOpBuilder(GraphBuilder* graph_builder); +OpBuilder* CreateTanhOpBuilder(GraphBuilder* graph_builder); +OpBuilder* CreateTransposeConvolutionOpBuilder(GraphBuilder* graph_builder); + +OpBuilder* CreateActivationLayerBuilder(GraphBuilder* graph_builder); +OpBuilder* CreateThresholdLayerBuilder(GraphBuilder* graph_builder); + +} // namespace coreml +} // namespace delegates +} // namespace tflite + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_OP_FACTORY_H_ diff --git a/tensorflow/lite/experimental/delegates/coreml/builders/op_validator.h b/tensorflow/lite/experimental/delegates/coreml/builders/op_validator.h new file mode 100644 index 00000000000..0d47b8f2d86 --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/builders/op_validator.h @@ -0,0 +1,43 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_OP_VALIDATOR_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_OP_VALIDATOR_H_ + +#include "tensorflow/lite/c/builtin_op_data.h" + +namespace tflite { +namespace delegates { +namespace coreml { +// Follow the ordering of TfLiteBuiltinOperator enum. +bool IsConcatenationOpSupported(const TfLiteRegistration* registration, + const TfLiteNode* node, TfLiteContext* context); +bool IsConvolutionOpSupported(const TfLiteRegistration* registration, + const TfLiteNode* node, TfLiteContext* context); +bool IsDepthwiseConvolutionOpSupported(const TfLiteRegistration* registration, + const TfLiteNode* node, + TfLiteContext* context); +bool IsReshapeOpSupported(const TfLiteRegistration* registration, + const TfLiteNode* node, TfLiteContext* context); +bool IsResizeBilinearOpSupported(const TfLiteRegistration* registration, + const TfLiteNode* node, + TfLiteContext* context); +bool IsTransposeConvolutionOpSupported(const TfLiteRegistration* registration, + const TfLiteNode* node, + TfLiteContext* context); +} // namespace coreml +} // namespace delegates +} // namespace tflite + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_OP_VALIDATOR_H_ diff --git a/tensorflow/lite/experimental/delegates/coreml/builders/pooling_layer_builder.cc b/tensorflow/lite/experimental/delegates/coreml/builders/pooling_layer_builder.cc new file mode 100644 index 00000000000..8859639b1fb --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/builders/pooling_layer_builder.cc @@ -0,0 +1,120 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/experimental/delegates/coreml/builders/pooling_layer_builder.h" + +#include "tensorflow/lite/builtin_ops.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/experimental/delegates/coreml/builders/op_factory.h" + +namespace tflite { +namespace delegates { +namespace coreml { + +const char* PoolingLayerBuilder::DebugName() { + if (str_debug_name_[0]) return str_debug_name_; + switch (pooling_type_) { + case kTfLiteBuiltinAveragePool2d: + GetDebugName("PoolingLayerBuilder (AVERAGE)", node_id_, str_debug_name_); + break; + + case kTfLiteBuiltinMaxPool2d: + GetDebugName("PoolingLayerBuilder (MAX)", node_id_, str_debug_name_); + break; + case kTfLiteBuiltinL2Pool2d: + GetDebugName("PoolingLayerBuilder (L2, unsupported)", + node_id_, str_debug_name_); + break; + default: + GetDebugName("PoolingLayerBuilder (ERROR)", node_id_, str_debug_name_); + } + return str_debug_name_; +} + +CoreML::Specification::NeuralNetworkLayer* PoolingLayerBuilder::Build() { + if (layer_ == nullptr) { + layer_.reset(new CoreML::Specification::NeuralNetworkLayer); + } + layer_->set_name(DebugName()); + const TfLitePoolParams* params = + reinterpret_cast(builtin_data_); + auto* pooling_params = layer_->mutable_pooling(); + pooling_params->mutable_stride()->Add(params->stride_height); + pooling_params->mutable_stride()->Add(params->stride_width); + pooling_params->mutable_kernelsize()->Add(params->filter_height); + pooling_params->mutable_kernelsize()->Add(params->filter_width); + + if (params->padding == kTfLitePaddingSame) { + pooling_params->mutable_same(); + } else { + pooling_params->mutable_valid(); + } + + switch (pooling_type_) { + case kTfLiteBuiltinAveragePool2d: + pooling_params->set_type( + CoreML::Specification::PoolingLayerParams::AVERAGE); + pooling_params->set_avgpoolexcludepadding(true); + break; + case kTfLiteBuiltinMaxPool2d: + pooling_params->set_type(CoreML::Specification::PoolingLayerParams::MAX); + break; + case kTfLiteBuiltinL2Pool2d: + // TODO(b/145873272) implement L2 pooling + // NOLINTNEXTLINE: minimize absl usage + fprintf(stderr, "L2 pooling is not supported yet.\n"); + return nullptr; + default: + // NOLINTNEXTLINE: minimize absl usage + fprintf(stderr, "Unexpected pooling type.\n"); // Should not reach here. + return nullptr; + } + + // TODO(b/145582958): Add padding values. + // TODO(b/145582958): Handle fused activation function. + return layer_.release(); +} + +TfLiteStatus PoolingLayerBuilder::RegisterInputs(const TfLiteIntArray* inputs, + TfLiteContext* context) { + if (inputs->size != 1) { + TF_LITE_KERNEL_LOG(context, "Wrong # of inputs to Pooling!."); + return kTfLiteError; + } + AddInput(inputs->data[0]); + return kTfLiteOk; +} + +TfLiteStatus PoolingLayerBuilder::RegisterOutputs(const TfLiteIntArray* outputs, + TfLiteContext* context) { + if (outputs->size != 1) { + TF_LITE_KERNEL_LOG(context, "Wrong # of outputs to Pooling!."); + return kTfLiteError; + } + graph_builder_->AddTensorWithID(outputs->data[0], GetOutput(context)); + return kTfLiteOk; +} + +OpBuilder* CreateAveragePool2dOpBuilder(GraphBuilder* graph_builder) { + return new PoolingLayerBuilder(graph_builder, kTfLiteBuiltinAveragePool2d); +} + +OpBuilder* CreateMaxPool2dOpBuilder(GraphBuilder* graph_builder) { + return new PoolingLayerBuilder(graph_builder, kTfLiteBuiltinMaxPool2d); +} + +} // namespace coreml +} // namespace delegates +} // namespace tflite diff --git a/tensorflow/lite/experimental/delegates/coreml/builders/pooling_layer_builder.h b/tensorflow/lite/experimental/delegates/coreml/builders/pooling_layer_builder.h new file mode 100644 index 00000000000..1c9be64c23a --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/builders/pooling_layer_builder.h @@ -0,0 +1,50 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_POOLING_LAYER_BUILDER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_POOLING_LAYER_BUILDER_H_ + +#include "tensorflow/lite/builtin_ops.h" +#include "tensorflow/lite/experimental/delegates/coreml/builders/op_builder.h" + +namespace tflite { +namespace delegates { +namespace coreml { + +class PoolingLayerBuilder : public OpBuilder { + public: + explicit PoolingLayerBuilder(GraphBuilder* graph_builder, + TfLiteBuiltinOperator pooling_type) + : OpBuilder(graph_builder), pooling_type_(pooling_type) {} + + const char* DebugName() override; + + CoreML::Specification::NeuralNetworkLayer* Build() override; + + TfLiteStatus RegisterInputs(const TfLiteIntArray* inputs, + TfLiteContext* context) override; + + TfLiteStatus RegisterOutputs(const TfLiteIntArray* outputs, + TfLiteContext* context) override; + + private: + // Should be one of pooling types. + TfLiteBuiltinOperator pooling_type_; +}; + +} // namespace coreml +} // namespace delegates +} // namespace tflite + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_POOLING_LAYER_BUILDER_H_ diff --git a/tensorflow/lite/experimental/delegates/coreml/builders/reshape_op_builder.cc b/tensorflow/lite/experimental/delegates/coreml/builders/reshape_op_builder.cc new file mode 100644 index 00000000000..33040e2e070 --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/builders/reshape_op_builder.cc @@ -0,0 +1,142 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/experimental/delegates/coreml/builders/reshape_op_builder.h" + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/experimental/delegates/coreml/builders/op_builder.h" +#include "tensorflow/lite/experimental/delegates/coreml/builders/op_factory.h" +#include "tensorflow/lite/experimental/delegates/coreml/builders/op_validator.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" + +namespace tflite { +namespace delegates { +namespace coreml { + +const char* ReshapeOpBuilder::DebugName() { + if (!str_debug_name_[0]) { + GetDebugName("ReshapeOpBuilder", node_id_, str_debug_name_); + } + return str_debug_name_; +} + +CoreML::Specification::NeuralNetworkLayer* ReshapeOpBuilder::Build() { + if (layer_ == nullptr) { + layer_.reset(new CoreML::Specification::NeuralNetworkLayer); + } + layer_->set_name(DebugName()); + for (int dim : shape_) { + layer_->mutable_reshape()->add_targetshape(dim); + } + if (need_transpose_) + layer_->mutable_reshape()->set_mode( + CoreML::Specification::ReshapeLayerParams::CHANNEL_LAST); + return layer_.release(); +} + +void ReshapeOpBuilder::SetShapeFromTensor(const TfLiteTensor* output_shape, + const TfLiteIntArray* input_shape) { + TfLiteIntArray* shape = TfLiteIntArrayCreate(output_shape->dims->data[0]); + std::memcpy(shape->data, GetTensorData(output_shape), + shape->size * sizeof(int)); + + SetShapeFromIntArray(shape, input_shape); + TfLiteIntArrayFree(shape); +} + +void ReshapeOpBuilder::SetShapeFromIntArray(const TfLiteIntArray* output_shape, + const TfLiteIntArray* input_shape) { + // ignore first dimension (batch) + std::copy(output_shape->data + 1, output_shape->data + output_shape->size, + std::back_inserter(shape_)); + + int64_t reshape_size = 1; + int negative_index = -1; + for (int i = 0; i < shape_.size(); ++i) { + if (shape_[i] == -1) { + negative_index = i; + } else { + reshape_size *= shape_[i]; + } + } + if (negative_index >= 0) { + int64_t input_size = NumElements(input_shape); + shape_[negative_index] = input_size / reshape_size; + } + + if (shape_.size() == 2) { + shape_ = {shape_[1], 1, shape_[0]}; + } else if (shape_.size() == 3) { + shape_ = {shape_[2], shape_[0], shape_[1]}; + } + // When channel dimension is changed, reshape should be done with HWC layout. + if (shape_[0] != input_shape->data[input_shape->size - 1]) { + need_transpose_ = true; + } +} + +TfLiteStatus ReshapeOpBuilder::RegisterInputs(const TfLiteIntArray* inputs, + TfLiteContext* context) { + AddInput(inputs->data[0]); + + if (inputs->size == 2) { + SetShapeFromTensor(&context->tensors[inputs->data[1]], + context->tensors[inputs->data[0]].dims); + } else { + const auto* params = reinterpret_cast(builtin_data_); + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(params->num_dimensions); + std::memcpy(output_shape->data, params->shape, + params->num_dimensions * sizeof(int)); + + SetShapeFromIntArray(output_shape, context->tensors[inputs->data[0]].dims); + TfLiteIntArrayFree(output_shape); + } + return kTfLiteOk; +} + +TfLiteStatus ReshapeOpBuilder::RegisterOutputs(const TfLiteIntArray* outputs, + TfLiteContext* context) { + graph_builder_->AddTensorWithID(outputs->data[0], GetOutput(context)); + return kTfLiteOk; +} + +bool IsReshapeOpSupported(const TfLiteRegistration* registration, + const TfLiteNode* node, TfLiteContext* context) { + if (node->inputs->size == 1) { + const auto* params = + reinterpret_cast(node->builtin_data); + return params->num_dimensions == 3 || params->num_dimensions == 4; + } + + const int kShapeTensor = 1; + const auto* shape = GetInput(context, node, kShapeTensor); + if (shape->allocation_type != kTfLiteMmapRo) { + TF_LITE_KERNEL_LOG(context, "Reshape has non-const shape."); + return false; + } + const bool is_shape_tensor = + shape->dims->size == 1 && shape->type == kTfLiteInt32; + return is_shape_tensor && + (shape->dims->data[0] == 3 || shape->dims->data[0] == 4); +} + +OpBuilder* CreateReshapeOpBuilder(GraphBuilder* graph_builder) { + return new ReshapeOpBuilder(graph_builder); +} + +} // namespace coreml +} // namespace delegates +} // namespace tflite diff --git a/tensorflow/lite/experimental/delegates/coreml/builders/reshape_op_builder.h b/tensorflow/lite/experimental/delegates/coreml/builders/reshape_op_builder.h new file mode 100644 index 00000000000..0a00a112a60 --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/builders/reshape_op_builder.h @@ -0,0 +1,56 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_RESHAPE_OP_BUILDER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_RESHAPE_OP_BUILDER_H_ + +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/experimental/delegates/coreml/builders/op_builder.h" + +namespace tflite { +namespace delegates { +namespace coreml { +// Builder for Reshape op in CoreML. +class ReshapeOpBuilder : public OpBuilder { + public: + explicit ReshapeOpBuilder(GraphBuilder* graph_builder) + : OpBuilder(graph_builder) {} + const char* DebugName() override; + + CoreML::Specification::NeuralNetworkLayer* Build() override; + + TfLiteStatus RegisterInputs(const TfLiteIntArray* inputs, + TfLiteContext* context) override; + TfLiteStatus RegisterOutputs(const TfLiteIntArray* outputs, + TfLiteContext* context) override; + + // Sets output shape of the Core ML reshape layer, given output shape and + // the input tensor's shape. + void SetShapeFromTensor(const TfLiteTensor* output_shape, + const TfLiteIntArray* input_shape); + void SetShapeFromIntArray(const TfLiteIntArray* output_shape, + const TfLiteIntArray* input_shape); + + private: + std::vector shape_; + // When channel dimension is changed, reshape should be done with HWC layout, + // thus transpose is required. (set with ReshapeLayerParams.mode) + bool need_transpose_ = false; +}; + +} // namespace coreml +} // namespace delegates +} // namespace tflite + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_RESHAPE_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/delegates/coreml/builders/resize_bilinear_op_builder.cc b/tensorflow/lite/experimental/delegates/coreml/builders/resize_bilinear_op_builder.cc new file mode 100644 index 00000000000..9b9933e932b --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/builders/resize_bilinear_op_builder.cc @@ -0,0 +1,106 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/experimental/delegates/coreml/builders/resize_bilinear_op_builder.h" + +#include + +#include "tensorflow/lite/builtin_ops.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/experimental/delegates/coreml/builders/op_factory.h" +#include "tensorflow/lite/experimental/delegates/coreml/builders/op_validator.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" + +namespace tflite { +namespace delegates { +namespace coreml { + +const char* ResizeBilinearOpBuilder::DebugName() { + if (str_debug_name_[0]) return str_debug_name_; + GetDebugName("ResizeBilinearOpBuilder", node_id_, str_debug_name_); + return str_debug_name_; +} + +CoreML::Specification::NeuralNetworkLayer* ResizeBilinearOpBuilder::Build() { + if (layer_ == nullptr) { + layer_.reset(new CoreML::Specification::NeuralNetworkLayer); + } + layer_->set_name(DebugName()); + const TfLiteResizeBilinearParams* params = + reinterpret_cast(builtin_data_); + + layer_->mutable_resizebilinear()->mutable_targetsize()->Add(height_); + layer_->mutable_resizebilinear()->mutable_targetsize()->Add(width_); + + // align_corners makes last sampling position to be aligned with last index of + // input. This is the same behavior as STRICT_ALIGN_ENDPOINTS_MODE in Core ML + // sampling mode. When not set, the sampling positions are the same as + // UPSAMPLE_MODE. (indices are in [0, (input_size-1)/output_size]) + if (params->align_corners) { + layer_->mutable_resizebilinear()->mutable_mode()->set_samplingmethod( + CoreML::Specification::SamplingMode::STRICT_ALIGN_ENDPOINTS_MODE); + } else { + layer_->mutable_resizebilinear()->mutable_mode()->set_samplingmethod( + CoreML::Specification::SamplingMode::UPSAMPLE_MODE); + } + return layer_.release(); +} + +TfLiteStatus ResizeBilinearOpBuilder::RegisterInputs( + const TfLiteIntArray* inputs, TfLiteContext* context) { + if (inputs->size != 2) { + TF_LITE_KERNEL_LOG(context, "Wrong # of inputs to ResizeBilinear!."); + return kTfLiteError; + } + AddInput(inputs->data[0]); + TfLiteTensor* size = &context->tensors[inputs->data[1]]; + height_ = GetTensorData(size)[0]; + width_ = GetTensorData(size)[1]; + return kTfLiteOk; +} + +TfLiteStatus ResizeBilinearOpBuilder::RegisterOutputs( + const TfLiteIntArray* outputs, TfLiteContext* context) { + if (outputs->size != 1) { + TF_LITE_KERNEL_LOG(context, "Wrong # of outputs to ResizeBilinear!."); + return kTfLiteError; + } + graph_builder_->AddTensorWithID(outputs->data[0], GetOutput(context)); + return kTfLiteOk; +} + +OpBuilder* CreateResizeBilinearOpBuilder(GraphBuilder* graph_builder) { + return new ResizeBilinearOpBuilder(graph_builder); +} + +bool IsResizeBilinearOpSupported(const TfLiteRegistration* registration, + const TfLiteNode* node, + TfLiteContext* context) { + if (node->builtin_data == nullptr) { + return false; + } + const int kOutputSize = 1; + if (!IsConstantTensor(GetInput(context, node, kOutputSize))) { + TF_LITE_KERNEL_LOG(context, + "Output size of ResizeBilinear should be constant."); + return false; + } + return true; +} + +} // namespace coreml +} // namespace delegates +} // namespace tflite diff --git a/tensorflow/lite/experimental/delegates/coreml/builders/resize_bilinear_op_builder.h b/tensorflow/lite/experimental/delegates/coreml/builders/resize_bilinear_op_builder.h new file mode 100644 index 00000000000..f89258b397b --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/builders/resize_bilinear_op_builder.h @@ -0,0 +1,49 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_RESIZE_BILINEAR_OP_BUILDER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_RESIZE_BILINEAR_OP_BUILDER_H_ + +#include "tensorflow/lite/builtin_ops.h" +#include "tensorflow/lite/experimental/delegates/coreml/builders/op_builder.h" + +namespace tflite { +namespace delegates { +namespace coreml { + +class ResizeBilinearOpBuilder : public OpBuilder { + public: + explicit ResizeBilinearOpBuilder(GraphBuilder* graph_builder) + : OpBuilder(graph_builder) {} + + const char* DebugName() override; + + CoreML::Specification::NeuralNetworkLayer* Build() override; + + TfLiteStatus RegisterInputs(const TfLiteIntArray* inputs, + TfLiteContext* context) override; + + TfLiteStatus RegisterOutputs(const TfLiteIntArray* outputs, + TfLiteContext* context) override; + + private: + int height_; + int width_; +}; + +} // namespace coreml +} // namespace delegates +} // namespace tflite + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_RESIZE_BILINEAR_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/delegates/coreml/builders/softmax_op_builder.cc b/tensorflow/lite/experimental/delegates/coreml/builders/softmax_op_builder.cc new file mode 100644 index 00000000000..1bd40e94d13 --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/builders/softmax_op_builder.cc @@ -0,0 +1,68 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/experimental/delegates/coreml/builders/softmax_op_builder.h" + +#include "tensorflow/lite/c/common.h" + +namespace tflite { +namespace delegates { +namespace coreml { +const char* SoftmaxOpBuilder::DebugName() { + if (!str_debug_name_[0]) + GetDebugName("SoftmaxOpBuilder", node_id_, str_debug_name_); + return str_debug_name_; +} + +CoreML::Specification::NeuralNetworkLayer* SoftmaxOpBuilder::Build() { + if (layer_ == nullptr) { + layer_.reset(new CoreML::Specification::NeuralNetworkLayer); + } + layer_->set_name(DebugName()); + layer_->mutable_softmax(); + + return layer_.release(); +} + +TfLiteStatus SoftmaxOpBuilder::RegisterInputs(const TfLiteIntArray* inputs, + TfLiteContext* context) { + if (inputs->size != 1) { + TF_LITE_KERNEL_LOG(context, "Wrong # of inputs to softmax!."); + return kTfLiteError; + } + AddInput(inputs->data[0]); + return kTfLiteOk; +} + +TfLiteStatus SoftmaxOpBuilder::RegisterOutputs(const TfLiteIntArray* outputs, + TfLiteContext* context) { + if (outputs->size != 1) { + TF_LITE_KERNEL_LOG(context, "Wrong # of outputs to softmax!."); + return kTfLiteError; + } + TensorID output_tensor = GetOutput(context); + if (output_tensor.NodeID() == -1) { + TF_LITE_KERNEL_LOG(context, "Failed to build output tensor."); + return kTfLiteError; + } + graph_builder_->AddTensorWithID(outputs->data[0], output_tensor); + return kTfLiteOk; +} + +OpBuilder* CreateSoftmaxOpBuilder(GraphBuilder* graph_builder) { + return new SoftmaxOpBuilder(graph_builder); +} +} // namespace coreml +} // namespace delegates +} // namespace tflite diff --git a/tensorflow/lite/experimental/delegates/coreml/builders/softmax_op_builder.h b/tensorflow/lite/experimental/delegates/coreml/builders/softmax_op_builder.h new file mode 100644 index 00000000000..f5028e4960e --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/builders/softmax_op_builder.h @@ -0,0 +1,43 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_SOFTMAX_OP_BUILDER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_SOFTMAX_OP_BUILDER_H_ + +#include "tensorflow/lite/experimental/delegates/coreml/builders/op_builder.h" + +namespace tflite { +namespace delegates { +namespace coreml { +// Builder for Softmax op in CoreML. +class SoftmaxOpBuilder : public OpBuilder { + public: + explicit SoftmaxOpBuilder(GraphBuilder* graph_builder) + : OpBuilder(graph_builder) {} + const char* DebugName() override; + + CoreML::Specification::NeuralNetworkLayer* Build() override; + + TfLiteStatus RegisterInputs(const TfLiteIntArray* inputs, + TfLiteContext* context) override; + + TfLiteStatus RegisterOutputs(const TfLiteIntArray* outputs, + TfLiteContext* context) override; +}; + +} // namespace coreml +} // namespace delegates +} // namespace tflite + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_SOFTMAX_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/delegates/coreml/builders/test_util.h b/tensorflow/lite/experimental/delegates/coreml/builders/test_util.h new file mode 100644 index 00000000000..0866465822c --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/builders/test_util.h @@ -0,0 +1,54 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_TEST_UTIL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_TEST_UTIL_H_ + +#include "tensorflow/lite/experimental/delegates/coreml/coreml_delegate.h" +#include "tensorflow/lite/kernels/test_util.h" + +#import + +namespace tflite { +namespace delegates { +namespace coreml { +class SingleOpModelWithCoreMlDelegate : public tflite::SingleOpModel { + public: + SingleOpModelWithCoreMlDelegate() : delegate_(nullptr, [](TfLiteDelegate*) {}) {} + + static const char kDelegateName[]; + + void ApplyDelegateAndInvoke(); + + tflite::Interpreter* interpreter() { return interpreter_.get(); } + + protected: + using SingleOpModel::builder_; + + private: + tflite::Interpreter::TfLiteDelegatePtr delegate_; + TfLiteCoreMlDelegateOptions params_ = {0}; +}; + +} // namespace coreml +} // namespace delegates +} // namespace tflite + +@interface BaseOpTest : XCTestCase +@property tflite::delegates::coreml::SingleOpModelWithCoreMlDelegate* model; +- (void)validateInterpreter:(tflite::Interpreter*)interpreter; +- (void)invokeAndValidate; +@end + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_TEST_UTIL_H_ diff --git a/tensorflow/lite/experimental/delegates/coreml/builders/test_util.mm b/tensorflow/lite/experimental/delegates/coreml/builders/test_util.mm new file mode 100644 index 00000000000..e576e4e911b --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/builders/test_util.mm @@ -0,0 +1,59 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/experimental/delegates/coreml/builders/test_util.h" + +namespace tflite { +namespace delegates { +namespace coreml { + +const char SingleOpModelWithCoreMlDelegate::kDelegateName[] = "TfLiteCoreMlDelegate"; + +void SingleOpModelWithCoreMlDelegate::ApplyDelegateAndInvoke() { + auto* delegate_ptr = TfLiteCoreMlDelegateCreate(¶ms_); + ASSERT_TRUE(delegate_ptr != nullptr); + delegate_ = tflite::Interpreter::TfLiteDelegatePtr( + delegate_ptr, [](TfLiteDelegate* delegate) { TfLiteCoreMlDelegateDelete(delegate); }); + // Add delegate. + // TODO(karimnosseir): This doesn't actually make the test fail, switch to something else. + ASSERT_TRUE(interpreter_->ModifyGraphWithDelegate(delegate_.get()) != kTfLiteError); + + Invoke(); +} + +} // namespace coreml +} // namespace delegates +} // namespace tflite + +@implementation BaseOpTest +- (void)validateInterpreter:(tflite::Interpreter*)interpreter { + // Make sure we have valid interpreter. + XCTAssertTrue(interpreter != nullptr); + // Make sure graph has one Op which is the delegate node. + XCTAssertEqual(interpreter->execution_plan().size(), 1); + const int node_index = interpreter->execution_plan()[0]; + const auto* node_and_reg = interpreter->node_and_registration(node_index); + XCTAssertTrue(node_and_reg != nullptr); + XCTAssertTrue(node_and_reg->second.custom_name != nullptr); + XCTAssertTrue( + node_and_reg->second.custom_name == + std::string(tflite::delegates::coreml::SingleOpModelWithCoreMlDelegate::kDelegateName)); +} + +- (void)invokeAndValidate { + _model->ApplyDelegateAndInvoke(); + [self validateInterpreter:_model->interpreter()]; +} + +@end diff --git a/tensorflow/lite/experimental/delegates/coreml/builders/threshold_layer_builder.cc b/tensorflow/lite/experimental/delegates/coreml/builders/threshold_layer_builder.cc new file mode 100644 index 00000000000..d1dfded1d1b --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/builders/threshold_layer_builder.cc @@ -0,0 +1,67 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/experimental/delegates/coreml/builders/threshold_layer_builder.h" + +#include "tensorflow/lite/c/common.h" + +namespace tflite { +namespace delegates { +namespace coreml { + +const char* ThresholdLayerBuilder::DebugName() { + if (!str_debug_name_[0]) + GetDebugName("ThresholdLayerBuilder", node_id_, str_debug_name_); + return str_debug_name_; +} + +CoreML::Specification::NeuralNetworkLayer* ThresholdLayerBuilder::Build() { + if (layer_ == nullptr) { + layer_.reset(new CoreML::Specification::NeuralNetworkLayer); + } + layer_->set_name(DebugName()); + layer_->mutable_unary()->set_alpha(alpha_); + layer_->mutable_unary()->set_scale(scale_); + layer_->mutable_unary()->set_type( + CoreML::Specification::UnaryFunctionLayerParams::THRESHOLD); + return layer_.release(); +} + +TfLiteStatus ThresholdLayerBuilder::RegisterInputs(const TfLiteIntArray* inputs, + TfLiteContext* context) { + if (inputs->size != 1) { + TF_LITE_KERNEL_LOG(context, "Threshold: Wrong # of inputs!."); + return kTfLiteError; + } + AddInput(inputs->data[0]); + return kTfLiteOk; +} + +TfLiteStatus ThresholdLayerBuilder::RegisterOutputs( + const TfLiteIntArray* outputs, TfLiteContext* context) { + if (outputs->size != 1) { + TF_LITE_KERNEL_LOG(context, "Threshold: Wrong # of outputs!."); + return kTfLiteError; + } + graph_builder_->AddTensorWithID(outputs->data[0], GetOutput(context)); + return kTfLiteOk; +} + +OpBuilder* CreateThresholdLayerBuilder(GraphBuilder* graph_builder) { + return new ThresholdLayerBuilder(graph_builder); +} + +} // namespace coreml +} // namespace delegates +} // namespace tflite diff --git a/tensorflow/lite/experimental/delegates/coreml/builders/threshold_layer_builder.h b/tensorflow/lite/experimental/delegates/coreml/builders/threshold_layer_builder.h new file mode 100644 index 00000000000..4b12cb7c3b8 --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/builders/threshold_layer_builder.h @@ -0,0 +1,56 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_THRESHOLD_LAYER_BUILDER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_THRESHOLD_LAYER_BUILDER_H_ + +#include "tensorflow/lite/experimental/delegates/coreml/builders/op_builder.h" + +namespace tflite { +namespace delegates { +namespace coreml { + +// Layer that provides threshold operation. Depending on scale, this can be used +// as max (scale > 0) or min (scale < 0), in combination with another negative +// linear activation layer) operation. +// TODO(karimnosseir): Generalize to other unary operators. +class ThresholdLayerBuilder : public OpBuilder { + public: + explicit ThresholdLayerBuilder(GraphBuilder* graph_builder) + : OpBuilder(graph_builder) {} + + const char* DebugName() override; + + CoreML::Specification::NeuralNetworkLayer* Build() override; + + void SetAlpha(float alpha) { alpha_ = alpha; } + + void SetScale(float scale) { scale_ = scale; } + + TfLiteStatus RegisterInputs(const TfLiteIntArray* inputs, + TfLiteContext* context) override; + + TfLiteStatus RegisterOutputs(const TfLiteIntArray* outputs, + TfLiteContext* context) override; + + private: + float alpha_ = 0.0f; + float scale_ = 1.0f; +}; + +} // namespace coreml +} // namespace delegates +} // namespace tflite + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_THRESHOLD_LAYER_BUILDER_H_ diff --git a/tensorflow/lite/experimental/delegates/coreml/builders/util.cc b/tensorflow/lite/experimental/delegates/coreml/builders/util.cc new file mode 100644 index 00000000000..acaf4ab4bd4 --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/builders/util.cc @@ -0,0 +1,87 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/experimental/delegates/coreml/builders/util.h" + +#include + +#include "tensorflow/lite/experimental/delegates/coreml/builders/op_validator.h" +#include "tensorflow/lite/kernels/kernel_util.h" + +namespace tflite { +namespace delegates { +namespace coreml { +namespace { +void Get4DShape(const TfLiteTensor* tensor, std::vector* shape) { + const int rank = tensor->dims->size; + shape->resize(4); + for (int i = 0; i < 4 - rank; i++) { + (*shape)[i] = 1; + } + for (int i = 4 - rank; i < 4; ++i) { + (*shape)[i] = tensor->dims->data[i - (4 - rank)]; + } +} +} // namespace + +// Determines if two tensor shapes are broadcastable. See comment of +// IsBinaryOpSupported for more info. +bool IsBroadcastable(const TfLiteTensor* input_0, const TfLiteTensor* input_1) { + std::vector shape_0; + std::vector shape_1; + Get4DShape(input_0, &shape_0); + Get4DShape(input_1, &shape_1); + const int B_0 = shape_0[0]; + const int B_1 = shape_1[0]; + const int H_0 = shape_0[1]; + const int H_1 = shape_1[1]; + const int W_0 = shape_0[2]; + const int W_1 = shape_1[2]; + const int C_0 = shape_0[3]; + const int C_1 = shape_1[3]; + + // TFL tensor has [B, H, W, C] format. + // comparing B: shape[0], (H, W): (shape[1], shape[2]), C: shape[3]. + + // When B is different, it's not supported unless + // one of the tensor is size 1 constant tensor. + if (B_0 != B_1) { + if (!((IsConstantTensor(input_0) && NumElements(input_0) == 1) || + (IsConstantTensor(input_1) && NumElements(input_1) == 1))) + return false; + } + + // When (H, W) are different, one of the (H, W) should be (1, 1). + if (H_0 != H_1 || W_0 != W_1) { + if (!((H_0 == 1 && W_0 == 1) || (H_1 == 1 && W_1 == 1))) { + return false; + } + } + + // When C is different, one of the C should be 1. + if (C_0 != C_1) { + if (C_0 != 1 && C_1 != 1) return false; + } + return true; +} + +bool IsBinaryOpSupported(const TfLiteRegistration* registration, + const TfLiteNode* node, TfLiteContext* context) { + return IsBroadcastable(GetInput(context, node, 0), + GetInput(context, node, 1)); +} + +} // namespace coreml +} // namespace delegates +} // namespace tflite diff --git a/tensorflow/lite/experimental/delegates/coreml/builders/util.h b/tensorflow/lite/experimental/delegates/coreml/builders/util.h new file mode 100644 index 00000000000..0c54cbe6eea --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/builders/util.h @@ -0,0 +1,36 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_UTIL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_UTIL_H_ + +#include "tensorflow/lite/c/common.h" + +namespace tflite { +namespace delegates { +namespace coreml { + +// Checks if Binary ops have supported broadcastable shapes. +// Core ml arithmetic ops - Add and Mul support broadcasts among +// [B, 1, 1, 1], [B, C, 1, 1], [B, 1, H, W], [B, C, H, W]. +// other shapes should be rejected. Unless it is a constant tensor of size 1, +// which will be added as data. + +bool IsBinaryOpSupported(const TfLiteRegistration* registration, + const TfLiteNode* node, TfLiteContext* context); + +} // namespace coreml +} // namespace delegates +} // namespace tflite +#endif // TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_UTIL_H_ diff --git a/tensorflow/lite/experimental/delegates/coreml/builders/util_test.cc b/tensorflow/lite/experimental/delegates/coreml/builders/util_test.cc new file mode 100644 index 00000000000..929bc4a2282 --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/builders/util_test.cc @@ -0,0 +1,110 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/experimental/delegates/coreml/builders/util.h" + +#include + +#include +#include +#include "tensorflow/lite/c/common.h" + +using tflite::delegates::coreml::IsBinaryOpSupported; + +namespace { + +class IsBinaryOpSupportedTest : public testing::Test { + protected: + void SetUp() override { + const int input_size = 2; + tensors_.resize(input_size); + context_.tensors = tensors_.data(); + node_.inputs = TfLiteIntArrayCreate(input_size); + for (int i = 0; i < input_size; ++i) { + node_.inputs->data[i] = i; + } + + for (auto& tensor : tensors_) { + tensor.allocation_type = kTfLiteArenaRw; + tensor.dims = nullptr; + } + } + + void TearDown() override { + FreeInputShapes(); + TfLiteIntArrayFree(node_.inputs); + } + + void SetInputShapes(const std::vector>& shapes) { + for (int i = 0; i < tensors_.size(); ++i) { + tensors_[i].dims = TfLiteIntArrayCreate(shapes[i].size()); + std::copy(shapes[i].begin(), shapes[i].end(), tensors_[i].dims->data); + } + } + + void FreeInputShapes() { + for (auto& tensor : tensors_) { + if (tensor.dims != nullptr) { + TfLiteIntArrayFree(tensor.dims); + tensor.dims = nullptr; + } + } + } + + TfLiteContext context_; + TfLiteNode node_; + std::vector tensors_; +}; + +TEST_F(IsBinaryOpSupportedTest, BroadcastTest) { + std::vector base_shape = {2, 2, 3, 4}; + std::vector> shapes = { + {2, 2, 3, 4}, {2, 1, 1, 4}, {2, 2, 3, 1}, {2, 1, 1, 1}}; + std::vector inputs(2); + for (const auto& shape : shapes) { + SetInputShapes({base_shape, shape}); + ASSERT_TRUE(IsBinaryOpSupported(nullptr, &node_, &context_)); + FreeInputShapes(); + } +} + +TEST_F(IsBinaryOpSupportedTest, LessThan4DTest) { + std::vector base_shape = {1, 2, 3, 4}; + std::vector> shapes = {{4}, {2, 3, 1}, {1, 1, 1, 1}}; + for (const auto& shape : shapes) { + SetInputShapes({base_shape, shape}); + ASSERT_TRUE(IsBinaryOpSupported(nullptr, &node_, &context_)); + FreeInputShapes(); + } +} + +TEST_F(IsBinaryOpSupportedTest, ConstScalarTest) { + std::vector base_shape = {2, 2, 3, 4}; + tensors_[1].allocation_type = kTfLiteMmapRo; + SetInputShapes({base_shape, {1}}); + ASSERT_TRUE(IsBinaryOpSupported(nullptr, &node_, &context_)); + FreeInputShapes(); +} + +TEST_F(IsBinaryOpSupportedTest, NotSupportedBroadcastTest) { + std::vector base_shape = {2, 2, 3, 4}; + std::vector> shapes = { + {2, 2, 1, 4}, {2, 1, 2, 4}, {1, 2, 3, 1}, {1, 1, 1, 1}}; + for (const auto& shape : shapes) { + SetInputShapes({base_shape, shape}); + ASSERT_FALSE(IsBinaryOpSupported(nullptr, &node_, &context_)); + FreeInputShapes(); + } +} +} // namespace diff --git a/tensorflow/lite/experimental/delegates/coreml/coreml_delegate.h b/tensorflow/lite/experimental/delegates/coreml/coreml_delegate.h new file mode 100644 index 00000000000..4e47b5739de --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/coreml_delegate.h @@ -0,0 +1,40 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_COREML_DELEGATE_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_COREML_DELEGATE_H_ + +#include "tensorflow/lite/c/common.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus +typedef struct { + // TODO(karimnosseir): Remove when other fields are added. + // We have dummy for now as we can't have empty struct in C. + char dummy; +} TfLiteCoreMlDelegateOptions; + +// Return a delegate that uses CoreML for ops execution. +// Must outlive the interpreter. +TfLiteDelegate* TfLiteCoreMlDelegateCreate( + const TfLiteCoreMlDelegateOptions* options); + +// Do any needed cleanup and delete 'delegate'. +void TfLiteCoreMlDelegateDelete(TfLiteDelegate* delegate); +#ifdef __cplusplus +} +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_COREML_DELEGATE_H_ diff --git a/tensorflow/lite/experimental/delegates/coreml/coreml_delegate.mm b/tensorflow/lite/experimental/delegates/coreml/coreml_delegate.mm new file mode 100644 index 00000000000..4fd861e6163 --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/coreml_delegate.mm @@ -0,0 +1,245 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/experimental/delegates/coreml/coreml_delegate.h" + +#include + +#include "tensorflow/lite/builtin_ops.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/context_util.h" +#include "tensorflow/lite/experimental/delegates/coreml/builders/op_validator.h" +#include "tensorflow/lite/experimental/delegates/coreml/builders/util.h" +#include "tensorflow/lite/experimental/delegates/coreml/coreml_delegate_kernel.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/minimal_logging.h" + +namespace tflite { +namespace { +using delegates::coreml::CoreMlDelegateKernel; + +bool IsNodeSupportedByDelegate(const TfLiteRegistration* registration, const TfLiteNode* node, + TfLiteContext* context) { + if (@available(iOS 11.0, *)) { + } else { + return false; + } + + // For most ops, only version 1 is supported. + if (registration->version > 1) { + switch (registration->builtin_code) { + case kTfLiteBuiltinDepthwiseConv2d: + if (registration->version > 2) return false; + break; + default: + return false; + } + } + + // The model should not be full-integer quantized. For ops supported by Core ML delegate, + // Testing if the first input is float is sufficient to filter full-integer quantized ops. + int input_tensor_index = 0; + // TransposeConv input: (output_shape, filters, input) + if (registration->builtin_code == kTfLiteBuiltinTransposeConv) { + input_tensor_index = 2; + } + if (GetInput(context, node, input_tensor_index)->type != kTfLiteFloat32) { + return false; + } + + // TODO(b/149179044): Add extra validation if this is not sufficient. + + // TODO(karimnossier): Refactor this function. + // TODO(karimnosseir): Add + // 1) Checks for versioning. + // 2) Checks for input constraints. + // Follow the ordering of TfLiteBuiltinOperator enum. + switch (registration->builtin_code) { + case kTfLiteBuiltinAdd: { + return node->builtin_data != nullptr && + delegates::coreml::IsBinaryOpSupported(registration, node, context); + } + case kTfLiteBuiltinAveragePool2d: { + const auto* params = reinterpret_cast(node->builtin_data); + return params != nullptr && params->activation == kTfLiteActNone; + } + case kTfLiteBuiltinConcatenation: { + return delegates::coreml::IsConcatenationOpSupported(registration, node, context); + } + case kTfLiteBuiltinConv2d: { + return delegates::coreml::IsConvolutionOpSupported(registration, node, context); + } + case kTfLiteBuiltinDepthwiseConv2d: { + return delegates::coreml::IsDepthwiseConvolutionOpSupported(registration, node, context); + } + case kTfLiteBuiltinHardSwish: { + return true; + } + case kTfLiteBuiltinLogistic: { + return true; + } + case kTfLiteBuiltinMaxPool2d: { + const auto* params = reinterpret_cast(node->builtin_data); + return params != nullptr && params->activation == kTfLiteActNone; + } + case kTfLiteBuiltinMul: { + return node->builtin_data != nullptr && + delegates::coreml::IsBinaryOpSupported(registration, node, context); + } + case kTfLiteBuiltinRelu: { + return true; + } + case kTfLiteBuiltinReluN1To1: { + return true; + } + case kTfLiteBuiltinRelu6: { + return true; + } + case kTfLiteBuiltinReshape: { + return delegates::coreml::IsReshapeOpSupported(registration, node, context); + } + case kTfLiteBuiltinResizeBilinear: { + return delegates::coreml::IsResizeBilinearOpSupported(registration, node, context); + } + case kTfLiteBuiltinSoftmax: { + // Only supports when beta is 1.0 for now. + const auto* softmax_params = reinterpret_cast(node->builtin_data); + return softmax_params != nullptr && softmax_params->beta == 1.0; + } + case kTfLiteBuiltinTanh: { + return true; + } + case kTfLiteBuiltinTransposeConv: { + return delegates::coreml::IsTransposeConvolutionOpSupported(registration, node, context); + } + default: + return false; + } + return false; +} + +TfLiteRegistration GetCoreMlKernelRegistration() { + // This is the registration for the Delegate Node that gets added to + // the TFLite graph instead of the subGraph it replaces it. + // It is treated as an OP node. But in our case + // Init will initialize the delegate + // Invoke will run the delegate graph. + // Prepare for prearing the delegate. + // Free for any cleaning needed by the delegate. + TfLiteRegistration kernel_registration; + kernel_registration.builtin_code = kTfLiteBuiltinDelegate; + kernel_registration.custom_name = "TfLiteCoreMlDelegate"; + kernel_registration.free = [](TfLiteContext* context, void* buffer) -> void { + delete reinterpret_cast(buffer); + }; + kernel_registration.init = [](TfLiteContext* context, const char* buffer, + size_t length) -> void* { + const TfLiteDelegateParams* params = reinterpret_cast(buffer); + CoreMlDelegateKernel* coreml_kernel = new CoreMlDelegateKernel(); + if (coreml_kernel->Init(context, params) != kTfLiteOk) { + delete coreml_kernel; + return nullptr; + } + return coreml_kernel; + }; + kernel_registration.invoke = [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus { + CoreMlDelegateKernel* kernel = reinterpret_cast(node->user_data); + if (!kernel) { + TF_LITE_KERNEL_LOG(context, "CoreMl Kernel was not initialized"); + return kTfLiteError; + } + return kernel->Invoke(context, node); + }; + kernel_registration.prepare = [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus { + CoreMlDelegateKernel* kernel = reinterpret_cast(node->user_data); + if (kernel == nullptr) { + TF_LITE_KERNEL_LOG(context, "CoreMl Kernel was not initialized"); + return kTfLiteError; + } + return kernel->Prepare(context, node); + }; + + return kernel_registration; +} + +TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) { + // Reserve 1 element, since we need first element to be size, will be updated + // later. + std::vector supported_nodes(1); + TfLiteIntArray* plan; + TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan)); + TfLiteNode* node; + TfLiteRegistration* registration; + + for (int node_index : TfLiteIntArrayView(plan)) { + TF_LITE_ENSURE_STATUS( + context->GetNodeAndRegistration(context, node_index, &node, ®istration)); + if (IsNodeSupportedByDelegate(registration, node, context)) { + supported_nodes.push_back(node_index); + } + } + // Set first element to the number of nodes to replace. + supported_nodes[0] = supported_nodes.size() - 1; + TfLiteRegistration coreml_kernel_registration = GetCoreMlKernelRegistration(); + TFLITE_LOG_PROD(tflite::TFLITE_LOG_INFO, "CoreML delegate: %d nodes delegated out of %d nodes.\n", + supported_nodes[0], plan->size); + + return context->ReplaceNodeSubsetsWithDelegateKernels( + context, coreml_kernel_registration, + reinterpret_cast(supported_nodes.data()), delegate); +} + +class CoreMlDelegate : public TfLiteDelegate { + public: + explicit CoreMlDelegate(const TfLiteCoreMlDelegateOptions* params) + : params_(params != nullptr ? *params : TfLiteCoreMlDelegateOptions()) {} + + TfLiteCoreMlDelegateOptions* params() { return ¶ms_; } + + bool VerifyDelegate() { return true; } + + private: + TfLiteCoreMlDelegateOptions params_; +}; + +TfLiteDelegate* CreateCoreMlDelegate(const TfLiteCoreMlDelegateOptions* options) { + TfLiteDelegate* delegate = new CoreMlDelegate(options); + if (!static_cast(delegate)->VerifyDelegate()) { + delete delegate; + return nullptr; + } + + delegate->data_ = static_cast(delegate)->params(); + delegate->flags = kTfLiteDelegateFlagsNone; + delegate->Prepare = &DelegatePrepare; + delegate->CopyFromBufferHandle = nullptr; + delegate->CopyToBufferHandle = nullptr; + delegate->FreeBufferHandle = nullptr; + + return delegate; +} +} // namespace +} // namespace tflite + +TfLiteDelegate* TfLiteCoreMlDelegateCreate(const TfLiteCoreMlDelegateOptions* options) { + if (@available(iOS 11.0, *)) { + return tflite::CreateCoreMlDelegate(options); + } else { + NSLog(@"Core ML delegate is not supported in this iOS version. " + "Minimum required iOS version is 11.0."); + return nullptr; + } +} + +void TfLiteCoreMlDelegateDelete(TfLiteDelegate* delegate) { delete delegate; } diff --git a/tensorflow/lite/experimental/delegates/coreml/coreml_delegate_kernel.h b/tensorflow/lite/experimental/delegates/coreml/coreml_delegate_kernel.h new file mode 100644 index 00000000000..04053ea81c1 --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/coreml_delegate_kernel.h @@ -0,0 +1,69 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_COREML_DELEGATE_KERNEL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_COREML_DELEGATE_KERNEL_H_ + +#include "external/coremltools/mlmodel/format/Model.pb.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/experimental/delegates/coreml/builders/op_builder.h" +#import "tensorflow/lite/experimental/delegates/coreml/coreml_executor.h" + +namespace tflite { +namespace delegates { +namespace coreml { + +// Represents a subgraph in TFLite that will be delegated to CoreML. +// It is abstracted as a single kernel node in the main TFLite graph and +// implements Init/Prepare/Invoke as TFLite kernel nodes. +class CoreMlDelegateKernel { + public: + // Initialize the delegated graph and add required nodes. + TfLiteStatus Init(TfLiteContext* context, const TfLiteDelegateParams* params); + + // Any preparation work needed for the delegated graph. + TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node); + + // Allocates delegated tensordefs for graph I/O & execute it. + TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node); + + ~CoreMlDelegateKernel(); + + private: + // Builds the ML Model protocol buffer + TfLiteStatus BuildModel(TfLiteContext* context, + const TfLiteDelegateParams* params); + + // Adds the output tensors to the model generated. + void AddOutputTensors(const TfLiteIntArray* output_tensors, + TfLiteContext* context); + + // Adds the input tensors to the model generated. + void AddInputTensors(const TfLiteIntArray* output_tensors, + TfLiteContext* context); + + std::unique_ptr builder_; + std::unique_ptr model_; + ::CoreMlExecutor* executor_; + + std::vector input_tensor_ids_; + std::vector inputs_; + std::vector outputs_; +}; + +} // namespace coreml +} // namespace delegates +} // namespace tflite + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_COREML_DELEGATE_KERNEL_H_ diff --git a/tensorflow/lite/experimental/delegates/coreml/coreml_delegate_kernel.mm b/tensorflow/lite/experimental/delegates/coreml/coreml_delegate_kernel.mm new file mode 100644 index 00000000000..a36837bcc44 --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/coreml_delegate_kernel.mm @@ -0,0 +1,231 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/experimental/delegates/coreml/coreml_delegate_kernel.h" + +#include "tensorflow/lite/context_util.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/types.h" +#include "tensorflow/lite/kernels/kernel_util.h" + +#import "tensorflow/lite/experimental/delegates/coreml/coreml_executor.h" + +namespace tflite { +namespace delegates { +namespace coreml { +namespace { +// TODO(karimnosseir): Move to util library +TfLiteStatus GetDims(int* batch_size, int* height_size, int* width_size, int* depth_size, + const TfLiteIntArray* dims) { + if (dims == nullptr || dims->size > 4) { + return kTfLiteError; + } + int* dim[] = {batch_size, height_size, width_size, depth_size}; + for (int i = 0; i < 4; ++i) *(dim[i]) = 1; + for (int i = 4 - dims->size; i < 4; ++i) { + *dim[i] = dims->data[i - (4 - dims->size)]; + } + return kTfLiteOk; +} + +void TransposeToCHW(const float* hwc, float* chw, const TfLiteIntArray* hwc_dims) { + int batch_size, height_size, width_size, depth_size; + GetDims(&batch_size, &height_size, &width_size, &depth_size, hwc_dims); + RuntimeShape hwc_shape({batch_size, height_size, width_size, depth_size}); + RuntimeShape chw_shape({batch_size, depth_size, height_size, width_size}); + TransposeParams params = {/*perm_count=*/4, /*perm=*/{0, 3, 1, 2}}; + optimized_ops::Transpose(params, hwc_shape, hwc, chw_shape, chw); +} + +void TransposeToHWC(const float* chw, float* hwc, const TfLiteIntArray* hwc_dims) { + int batch_size, height_size, width_size, depth_size; + GetDims(&batch_size, &height_size, &width_size, &depth_size, hwc_dims); + RuntimeShape hwc_shape({batch_size, height_size, width_size, depth_size}); + RuntimeShape chw_shape({batch_size, depth_size, height_size, width_size}); + TransposeParams params = {/*perm_count=*/4, /*perm=*/{0, 2, 3, 1}}; + optimized_ops::Transpose(params, chw_shape, chw, hwc_shape, hwc); +} +} // namespace + +TfLiteStatus CoreMlDelegateKernel::Init(TfLiteContext* context, + const TfLiteDelegateParams* delegate_params) { + if (@available(iOS 11.0, *)) { + executor_ = [[::CoreMlExecutor alloc] init]; + TF_LITE_ENSURE_STATUS(BuildModel(context, delegate_params)); + // Serialize the model protocol buffer and compile it. + if (model_ == nullptr) { + TF_LITE_KERNEL_LOG(context, "Failed to createModel"); + return kTfLiteError; + } + NSURL* model_url = [executor_ saveModel:model_.get()]; + model_.reset(); + if (![executor_ build:model_url]) { + TF_LITE_KERNEL_LOG(context, "Failed to Compile and save Model."); + return kTfLiteError; + } + return kTfLiteOk; + } else { + TF_LITE_KERNEL_LOG(context, "Minimum required iOS version is 11.0."); + return kTfLiteError; + } +} + +void CoreMlDelegateKernel::AddInputTensors(const TfLiteIntArray* input_tensors, + TfLiteContext* context) { + int num_inputs = 0; + for (int i = 0; i < input_tensors->size; ++i) { + const int tensor_id = input_tensors->data[i]; + const auto& tensor = context->tensors[tensor_id]; + builder_->AddTensorWithID(tensor_id, delegates::coreml::TensorID(0, num_inputs++)); + } +} + +void CoreMlDelegateKernel::AddOutputTensors(const TfLiteIntArray* output_tensors, + TfLiteContext* context) { + auto* model_description = model_->mutable_description(); + for (int i = 0; i < output_tensors->size; ++i) { + const int tensor_id = output_tensors->data[i]; + const auto& tensor = context->tensors[tensor_id]; + + auto* output = model_description->mutable_output()->Add(); + output->set_name(builder_->GetTensorName(tensor_id)); + auto* multi_array = output->mutable_type()->mutable_multiarraytype(); + int batch_size, height_size, width_size, depth_size; + GetDims(&batch_size, &height_size, &width_size, &depth_size, tensor.dims); + multi_array->set_datatype(CoreML::Specification::ArrayFeatureType::FLOAT32); + multi_array->mutable_shape()->Add(depth_size); + multi_array->mutable_shape()->Add(height_size); + multi_array->mutable_shape()->Add(width_size); + } +} + +TfLiteStatus CoreMlDelegateKernel::BuildModel(TfLiteContext* context, + const TfLiteDelegateParams* delegate_params) { + TfLiteNode* node; + TfLiteRegistration* reg; + builder_.reset(new delegates::coreml::GraphBuilder()); + // Add Inputs + AddInputTensors(delegate_params->input_tensors, context); + // Build all ops. + for (int node_index : TfLiteIntArrayView(delegate_params->nodes_to_replace)) { + TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration(context, node_index, &node, ®)); + auto* op_builder = builder_->AddBuilder(reg->builtin_code, node); + if (op_builder == nullptr) { + TF_LITE_KERNEL_LOG(context, "Failed to build node %d.", node_index); + return kTfLiteError; + } + if (op_builder->RegisterInputs(node->inputs, context) != kTfLiteOk) { + TF_LITE_KERNEL_LOG(context, "Failed to add inputs for node %d.", node_index); + return kTfLiteError; + } + if (op_builder->PopulateSubgraph(context) != kTfLiteOk) { + TF_LITE_KERNEL_LOG(context, "Failed to add sub-builders for node %d.", node_index); + return kTfLiteError; + } + if (op_builder->RegisterOutputs(node->outputs, context) != kTfLiteOk) { + TF_LITE_KERNEL_LOG(context, "Failed to add outputs for node %d.", node_index); + return kTfLiteError; + } + } + model_.reset(builder_->BuildModel()); + if (model_ == nullptr) { + TF_LITE_KERNEL_LOG(context, "Failed to build Model"); + return kTfLiteError; + } + AddOutputTensors(delegate_params->output_tensors, context); + // TODO(karimnosseir): Set correct version ? + model_->set_specificationversion(1); + auto* model_description = model_->mutable_description(); + for (int i = 0; i < delegate_params->input_tensors->size; ++i) { + const int tensor_id = delegate_params->input_tensors->data[i]; + if (builder_->IsTensorUsed(tensor_id)) { + const auto& tensor = context->tensors[tensor_id]; + auto* input = model_description->mutable_input()->Add(); + input->set_name(builder_->GetTensorName(tensor_id)); + // TODO(karimnosseir): Handle different types ? + auto* multi_array = input->mutable_type()->mutable_multiarraytype(); + int batch_size, height_size, width_size, depth_size; + GetDims(&batch_size, &height_size, &width_size, &depth_size, tensor.dims); + multi_array->set_datatype(CoreML::Specification::ArrayFeatureType::FLOAT32); + multi_array->mutable_shape()->Add(depth_size); + multi_array->mutable_shape()->Add(height_size); + multi_array->mutable_shape()->Add(width_size); + } + } + + return kTfLiteOk; +} + +TfLiteStatus CoreMlDelegateKernel::Prepare(TfLiteContext* context, TfLiteNode* node) { + for (int tensor_index : TfLiteIntArrayView(node->inputs)) { + if (builder_->IsTensorUsed(tensor_index)) { + input_tensor_ids_.push_back(tensor_index); + } + } + + inputs_.reserve(input_tensor_ids_.size()); + for (int tensor_index : input_tensor_ids_) { + TfLiteTensor* tensor = &context->tensors[tensor_index]; + const int input_size = NumElements(tensor); + int batch_size, height_size, width_size, depth_size; + GetDims(&batch_size, &height_size, &width_size, &depth_size, tensor->dims); + + inputs_.push_back({std::vector(input_size), + builder_->GetTensorName(tensor_index), + {depth_size, height_size, width_size}}); + } + + outputs_.reserve(node->outputs->size); + for (int tensor_index : TfLiteIntArrayView(node->outputs)) { + TfLiteTensor* tensor = &context->tensors[tensor_index]; + const int output_size = NumElements(tensor); + + outputs_.push_back({std::vector(output_size), builder_->GetTensorName(tensor_index)}); + } + + return kTfLiteOk; +} + +TfLiteStatus CoreMlDelegateKernel::Invoke(TfLiteContext* context, TfLiteNode* node) { + if (@available(iOS 11.0, *)) { + TfLiteIntArrayView node_inputs(node->inputs); + for (int i = 0; i < input_tensor_ids_.size(); ++i) { + const int tensor_id = input_tensor_ids_[i]; + TfLiteTensor* tensor = &context->tensors[tensor_id]; + // Transpose input to CHW. + // TODO(b/143992544): try adding transpose op for inputs. + TransposeToCHW(tensor->data.f, inputs_[i].data.data(), tensor->dims); + } + + if (![executor_ invokeWithInputs:inputs_ outputs:outputs_]) { + return kTfLiteError; + } + for (int i = 0; i < node->outputs->size; ++i) { + TfLiteTensor* output_tensor = GetOutput(context, node, i); + TransposeToHWC(outputs_[i].data.data(), output_tensor->data.f, output_tensor->dims); + } + return kTfLiteOk; + } else { + TF_LITE_KERNEL_LOG(context, "Minimum required iOS version is 11.0."); + return kTfLiteError; + } +} + +CoreMlDelegateKernel::~CoreMlDelegateKernel() { + [executor_ cleanup]; +} + +} // namespace coreml +} // namespace delegates +} // namespace tflite diff --git a/tensorflow/lite/experimental/delegates/coreml/coreml_executor.h b/tensorflow/lite/experimental/delegates/coreml/coreml_executor.h new file mode 100644 index 00000000000..edec3020cbc --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/coreml_executor.h @@ -0,0 +1,48 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#import +#import + +#include +#include + +#include "external/coremltools/mlmodel/format/Model.pb.h" + +// Data for input/output tensors. +struct TensorData { + std::vector data; + const std::string name; + std::vector shape; // only required for input tensor. +}; + +// Responsible for: +// - Compiling and constructing MLModel from a serialized MlModel +// protocol buffer. +// - Invoking predictions on the built model. +// Usage: Construct object, call Build() and Invoke() for inference. +@interface CoreMlExecutor : NSObject + +- (bool)invokeWithInputs:(const std::vector&)inputs + outputs:(const std::vector&)outputs API_AVAILABLE(ios(11)); + +- (NSURL*)saveModel:(CoreML::Specification::Model*)model API_AVAILABLE(ios(11)); +- (bool)build:(NSURL*)modelUrl API_AVAILABLE(ios(11)); + +- (bool)cleanup; + +@property MLModel* model API_AVAILABLE(ios(11)); +@property NSString* mlModelFilePath; +@property NSString* compiledModelFilePath; +@end diff --git a/tensorflow/lite/experimental/delegates/coreml/coreml_executor.mm b/tensorflow/lite/experimental/delegates/coreml/coreml_executor.mm new file mode 100644 index 00000000000..2091c0d7ca0 --- /dev/null +++ b/tensorflow/lite/experimental/delegates/coreml/coreml_executor.mm @@ -0,0 +1,186 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#import "tensorflow/lite/experimental/delegates/coreml/coreml_executor.h" + +#import +#import + +#include +#include + +namespace { +// Returns NSURL for a temporary file. +NSURL* createTemporaryFile() { + // Get temporary directory. + NSURL* temporaryDirectoryURL = [NSURL fileURLWithPath:NSTemporaryDirectory() isDirectory:YES]; + // Generate a Unique file name to use. + NSString* temporaryFilename = [[NSProcessInfo processInfo] globallyUniqueString]; + // Create URL to that file. + NSURL* temporaryFileURL = [temporaryDirectoryURL URLByAppendingPathComponent:temporaryFilename]; + + return temporaryFileURL; +} +} // namespace + +@interface MultiArrayFeatureProvider : NSObject { + const std::vector* _inputs; + NSSet* _featureNames; +} + +- (instancetype)initWithInputs:(const std::vector*)inputs; +- (MLFeatureValue*)featureValueForName:(NSString*)featureName API_AVAILABLE(ios(11)); +- (NSSet*)featureNames; + +@end + +@implementation MultiArrayFeatureProvider + +- (instancetype)initWithInputs:(const std::vector*)inputs { + self = [super init]; + _inputs = inputs; + for (auto& input : *_inputs) { + if (input.name.empty()) { + return nil; + } + } + return self; +} + +- (NSSet*)featureNames { + if (_featureNames == nil) { + NSMutableArray* names = [[NSMutableArray alloc] init]; + for (auto& input : *_inputs) { + [names addObject:[NSString stringWithCString:input.name.c_str() + encoding:[NSString defaultCStringEncoding]]]; + } + _featureNames = [NSSet setWithArray:names]; + } + return _featureNames; +} + +- (MLFeatureValue*)featureValueForName:(NSString*)featureName { + for (auto& input : *_inputs) { + if ([featureName cStringUsingEncoding:NSUTF8StringEncoding] == input.name) { + // TODO(b/141492326): Update shape handling for higher ranks + NSArray* shape = @[ @(input.shape[0]), @(input.shape[1]), @(input.shape[2]) ]; + NSArray* strides = @[ @(input.shape[1] * input.shape[2]), @(input.shape[2]), @1 ]; + NSError* error = nil; + MLMultiArray* mlArray = [[MLMultiArray alloc] initWithDataPointer:(float*)input.data.data() + shape:shape + dataType:MLMultiArrayDataTypeFloat32 + strides:strides + deallocator:(^(void* bytes){ + })error:&error]; + if (error != nil) { + NSLog(@"Failed to create MLMultiArray for feature %@ error: %@", featureName, + [error localizedDescription]); + return nil; + } + auto* mlFeatureValue = [MLFeatureValue featureValueWithMultiArray:mlArray]; + return mlFeatureValue; + } + } + + NSLog(@"Feature %@ not found", featureName); + return nil; +} +@end + +@implementation CoreMlExecutor +- (bool)invokeWithInputs:(const std::vector&)inputs + outputs:(const std::vector&)outputs { + if (_model == nil) { + return NO; + } + NSError* error = nil; + MultiArrayFeatureProvider* inputFeature = + [[MultiArrayFeatureProvider alloc] initWithInputs:&inputs]; + if (inputFeature == nil) { + NSLog(@"inputFeature is not initialized."); + return NO; + } + MLPredictionOptions* options = [[MLPredictionOptions alloc] init]; + id outputFeature = [_model predictionFromFeatures:inputFeature + options:options + error:&error]; + if (error != nil) { + NSLog(@"Error executing model: %@", [error localizedDescription]); + return NO; + } + NSSet* outputFeatureNames = [outputFeature featureNames]; + for (auto& output : outputs) { + NSString* outputName = [NSString stringWithCString:output.name.c_str() + encoding:[NSString defaultCStringEncoding]]; + MLFeatureValue* outputValue = + [outputFeature featureValueForName:[outputFeatureNames member:outputName]]; + auto* data = [outputValue multiArrayValue]; + float* outputData = (float*)data.dataPointer; + if (outputData == nullptr) { + return NO; + } + memcpy((float*)output.data.data(), outputData, output.data.size() * sizeof(output.data[0])); + } + return YES; +} + +- (bool)cleanup { + NSError* error = nil; + [[NSFileManager defaultManager] removeItemAtPath:_mlModelFilePath error:&error]; + if (error != nil) { + NSLog(@"Failed cleaning up model: %@", [error localizedDescription]); + return NO; + } + [[NSFileManager defaultManager] removeItemAtPath:_compiledModelFilePath error:&error]; + if (error != nil) { + NSLog(@"Failed cleaning up compiled model: %@", [error localizedDescription]); + return NO; + } + return YES; +} + +- (NSURL*)saveModel:(CoreML::Specification::Model*)model { + NSURL* modelUrl = createTemporaryFile(); + NSString* modelPath = [modelUrl path]; + // Flush data to file. + // TODO(karimnosseir): Can we mmap this instead of actual writing it to phone ? + std::ofstream file_stream([modelPath UTF8String], std::ios::out | std::ios::binary); + model->SerializeToOstream(&file_stream); + return modelUrl; +} + +- (bool)build:(NSURL*)modelUrl { + NSError* error = nil; + NSURL* compileUrl = [MLModel compileModelAtURL:modelUrl error:&error]; + if (error != nil) { + NSLog(@"Error compiling model %@", [error localizedDescription]); + return NO; + } + _mlModelFilePath = [modelUrl path]; + _compiledModelFilePath = [compileUrl path]; + + if (@available(iOS 12.0, *)) { + MLModelConfiguration* config = [MLModelConfiguration alloc]; + config.computeUnits = MLComputeUnitsAll; + _model = [MLModel modelWithContentsOfURL:compileUrl configuration:config error:&error]; + } else { + _model = [MLModel modelWithContentsOfURL:compileUrl error:&error]; + } + if (error != NULL) { + NSLog(@"Error Creating MLModel %@", [error localizedDescription]); + return NO; + } + return YES; +} +@end diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/BUILD b/tensorflow/lite/experimental/delegates/hexagon/builders/BUILD index bf83176a764..4a49b457b20 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/builders/BUILD +++ b/tensorflow/lite/experimental/delegates/hexagon/builders/BUILD @@ -13,6 +13,7 @@ cc_library( "arithmetic_builder.cc", "concat_builder.cc", "conv_2d_builder.cc", + "conv_2d_helpers.cc", "l2_normalization_builder.cc", "matmul_builder.cc", "neg_op_builder.cc", diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/conv_2d_builder.cc b/tensorflow/lite/experimental/delegates/hexagon/builders/conv_2d_builder.cc index 85957706d57..28f56f3045b 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/builders/conv_2d_builder.cc +++ b/tensorflow/lite/experimental/delegates/hexagon/builders/conv_2d_builder.cc @@ -128,13 +128,16 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs, int input_batch_size, input_height_size, input_width_size, input_depth_size; GetDims(&input_batch_size, &input_height_size, &input_width_size, &input_depth_size, data_tensor.dims); - TF_LITE_ENSURE_STATUS(ComputeMinAndMaxQuantValues( - data_tensor, &data_min_, &data_max_, std::numeric_limits::min(), - std::numeric_limits::max())); + float data_min = 0; + float data_max = 0; + TF_LITE_ENSURE_STATUS( + ComputeMinAndMaxQuantValues(data_tensor, &data_min, &data_max)); auto* data_min_const = graph_builder_->AddConstNodeWithData( - quant_bound_shape.data(), (char*)&data_min_, sizeof(data_min_)); + quant_bound_shape.data(), reinterpret_cast(&data_min), + sizeof(data_min)); auto* data_max_const = graph_builder_->AddConstNodeWithData( - quant_bound_shape.data(), (char*)&data_max_, sizeof(data_max_)); + quant_bound_shape.data(), reinterpret_cast(&data_max), + sizeof(data_max)); // Gather information about the Convolution operations. TfLitePadding padding_type = kTfLitePaddingUnknown; @@ -168,65 +171,8 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs, } // Weights tensor - const auto& weights_tensor = context->tensors[inputs->data[1]]; - if (weights_tensor.allocation_type != kTfLiteMmapRo) { - context->ReportError( - context, "Weights tensor doesn't have correct allocation type: %s", - weights_tensor.name); - return kTfLiteError; - } - int weights_batch_size, weights_height_size, weights_width_size, - weights_depth_size; - // Hexagon lib expects the weight tensor in HWCN, TFLite uses NHWC. - // Transpose NHWC -> HWCN - GetDims(&weights_batch_size, &weights_height_size, &weights_width_size, - &weights_depth_size, weights_tensor.dims); - OpBuilder* const_weights_node = nullptr; - if (op_node_.op_type == OP_Supernode_8x8p32to8) { - // Hexagon lib expects the weight tensor in HWCN, TFLite uses NHWC. - // Transpose NHWC -> HWCN - weight_shape_ = {weights_height_size, weights_width_size, - weights_depth_size, weights_batch_size}; - RuntimeShape nhwc_shape({weights_batch_size, weights_height_size, - weights_width_size, weights_depth_size}); - RuntimeShape hwcn_shape({weights_height_size, weights_width_size, - weights_depth_size, weights_batch_size}); - std::vector hwcn(NumElements(&weights_tensor)); - TransposeParams transpose_params; - transpose_params.perm_count = 4; - transpose_params.perm[0] = 1; - transpose_params.perm[1] = 2; - transpose_params.perm[2] = 3; - transpose_params.perm[3] = 0; - optimized_ops::Transpose(transpose_params, nhwc_shape, - weights_tensor.data.uint8, hwcn_shape, - hwcn.data()); - const_weights_node = graph_builder_->AddConstNodeWithData( - weight_shape_.data(), (char*)hwcn.data(), - hwcn.size() * sizeof(hwcn[0])); - } else if (op_node_.op_type == OP_DepthwiseSupernode_8x8p32to8) { - // Hexagon treats depthwise conv like tf.nn.depthwise_conv2d, where the - // expected filter shape is [fh,fw,din,dmul]. - // The data itself will remain the same, since TFLite's representation is - // just a 'flattening' of Hexagon's version. - const int channel_multiplier = weights_depth_size / input_depth_size; - weight_shape_ = {weights_height_size, weights_width_size, input_depth_size, - channel_multiplier}; - const_weights_node = graph_builder_->AddConstNodeWithData( - weight_shape_.data(), weights_tensor.data.raw, - NumElements(&weights_tensor) * sizeof(weights_tensor.data.uint8[0])); - } - // Quantization params for Weights tensor. TF_LITE_ENSURE_STATUS( - ComputeMinAndMaxQuantValues(weights_tensor, &weights_min_, &weights_max_, - std::numeric_limits::min(), - std::numeric_limits::max())); - auto* weights_min_const = graph_builder_->AddConstNodeWithData( - quant_bound_shape.data(), (char*)&weights_min_, sizeof(weights_min_)); - auto* weights_max_const = graph_builder_->AddConstNodeWithData( - quant_bound_shape.data(), (char*)&weights_max_, sizeof(weights_max_)); - graph_builder_->AddTensorWithID(inputs->data[1], const_weights_node->GetID(), - 0); + InitializeWeightsNodes(inputs, outputs, context, input_depth_size)); // Stride node. static int dummy = 0; @@ -242,14 +188,14 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs, // Output bounds. // TODO(b/129276536): Add support for other activations here. Current // implementation assumes None/Relu. + float output_min = 0; + float output_max = 0; TF_LITE_ENSURE_STATUS(ComputeMinAndMaxQuantValues( - context->tensors[outputs->data[0]], &output_min_, &output_max_, - std::numeric_limits::min(), - std::numeric_limits::max())); + context->tensors[outputs->data[0]], &output_min, &output_max)); // These denote the bounds fed to Hexagon's Conv mechanism, which will be // different from the TFLite tensor bounds if there is a RELU activation. - float conv_output_min = output_min_; - float conv_output_max = output_max_; + float conv_output_min = output_min; + float conv_output_max = output_max; if (activation == kTfLiteActRelu6) { conv_output_min = 0; conv_output_max = 6; @@ -267,16 +213,7 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs, sizeof(conv_output_max)); // Bias node. - const auto& bias_tensor = context->tensors[inputs->data[2]]; - auto* bias_data_node = - graph_builder_->AddConstNodeWithData(inputs->data[2], bias_tensor); - TF_LITE_ENSURE_STATUS(ComputeMinAndMaxQuantValues( - bias_tensor, &bias_min_, &bias_max_, std::numeric_limits::min(), - std::numeric_limits::max())); - auto* bias_min_const = graph_builder_->AddConstNodeWithData( - quant_bound_shape.data(), (char*)&bias_min_, sizeof(bias_min_)); - auto* bias_max_const = graph_builder_->AddConstNodeWithData( - quant_bound_shape.data(), (char*)&bias_max_, sizeof(bias_max_)); + TF_LITE_ENSURE_STATUS(InitializeBiasNodes(inputs, outputs, context)); // TODO(b/143759564): Simplify this method when depth_multiplier support needs // generalizing. @@ -290,7 +227,7 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs, return kTfLiteError; } - TensorID output, output_min, output_max; + TensorID output_tensor, output_min_tensor, output_max_tensor; if (is_dilated_depthwise_conv) { // For dilated Depthwise Conv, we convert this node into SpaceToBatchND, and // then chain Supernode & BatchToSpaceND after it. @@ -298,9 +235,9 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs, GetDims(&input_batch_size, &input_height_size, &input_width_size, &input_depth_size, data_tensor.dims); ComputeSpaceToBatchParams( - input_height_size, input_width_size, weights_height_size, - weights_width_size, dilation_factors_h_w_, padding_type, - &space_to_batch_paddings_, &batch_to_space_crops_); + input_height_size, input_width_size, weight_shape_[0], weight_shape_[1], + dilation_factors_h_w_, padding_type, &space_to_batch_paddings_, + &batch_to_space_crops_); auto* dilation_factors_const = graph_builder_->AddConstNodeWithData( dilation_factors_shape.data(), (char*)dilation_factors_h_w_.data(), dilation_factors_h_w_.size() * sizeof(stride_height)); @@ -332,17 +269,20 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs, auto* conv_op = graph_builder_->AddNode(GetTFLiteNodeID()); conv_op->SetOpType(OP_DepthwiseSupernode_8x8p32to8); conv_op->AddInput(space_to_batch_op_out); - conv_op->AddInput(TensorID(const_weights_node->GetID(), 0)); + conv_op->AddInput(TensorID(weights_data_node_->GetID(), 0)); conv_op->AddInput(TensorID(data_min_const->GetID(), 0)); conv_op->AddInput(TensorID(data_max_const->GetID(), 0)); - conv_op->AddInput(TensorID(weights_min_const->GetID(), 0)); - conv_op->AddInput(TensorID(weights_max_const->GetID(), 0)); + conv_op->AddInput(TensorID(weights_min_node_->GetID(), 0)); + conv_op->AddInput(TensorID(weights_max_node_->GetID(), 0)); conv_op->AddInput(TensorID(stride_node->GetID(), 0)); - conv_op->AddInput(TensorID(bias_data_node->GetID(), 0)); - conv_op->AddInput(TensorID(bias_min_const->GetID(), 0)); - conv_op->AddInput(TensorID(bias_max_const->GetID(), 0)); + conv_op->AddInput(TensorID(bias_data_node_->GetID(), 0)); + conv_op->AddInput(TensorID(bias_min_node_->GetID(), 0)); + conv_op->AddInput(TensorID(bias_max_node_->GetID(), 0)); conv_op->AddInput(TensorID(conv_output_min_const->GetID(), 0)); conv_op->AddInput(TensorID(conv_output_max_const->GetID(), 0)); + if (channel_scales_node_ != nullptr) { + conv_op->AddInput(TensorID(channel_scales_node_->GetID(), 0)); + } // The padding is handled by the SpaceToBatch/BatchToSpace ops surrounding // this node. Hence, this op's padding remains VALID only. // tf.nn.with_space_to_batch's docs state the following pattern: @@ -374,12 +314,14 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs, batch_to_space_op->AddInput(TensorID(crops_const->GetID(), 0)); batch_to_space_op->AddInput(TensorID(conv_output_min_const->GetID(), 0)); batch_to_space_op->AddInput(TensorID(conv_output_max_const->GetID(), 0)); - output = + output_tensor = batch_to_space_op->AddOutput(sizeof(uint8_t), 4, {output_batch_size, output_height_size, output_width_size, output_depth_size}); - output_min = batch_to_space_op->AddOutput(sizeof(float), 4, {1, 1, 1, 1}); - output_max = batch_to_space_op->AddOutput(sizeof(float), 4, {1, 1, 1, 1}); + output_min_tensor = + batch_to_space_op->AddOutput(sizeof(float), 4, {1, 1, 1, 1}); + output_max_tensor = + batch_to_space_op->AddOutput(sizeof(float), 4, {1, 1, 1, 1}); } else { // Standard case. // Padding type. @@ -390,38 +332,41 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs, } // Inputs AddInput(graph_builder_->GetHexagonTensorId(inputs->data[0])); - AddInput(TensorID(const_weights_node->GetID(), 0)); + AddInput(TensorID(weights_data_node_->GetID(), 0)); AddInput(TensorID(data_min_const->GetID(), 0)); AddInput(TensorID(data_max_const->GetID(), 0)); - AddInput(TensorID(weights_min_const->GetID(), 0)); - AddInput(TensorID(weights_max_const->GetID(), 0)); + AddInput(TensorID(weights_min_node_->GetID(), 0)); + AddInput(TensorID(weights_max_node_->GetID(), 0)); AddInput(TensorID(stride_node->GetID(), 0)); - AddInput(TensorID(bias_data_node->GetID(), 0)); - AddInput(TensorID(bias_min_const->GetID(), 0)); - AddInput(TensorID(bias_max_const->GetID(), 0)); + AddInput(TensorID(bias_data_node_->GetID(), 0)); + AddInput(TensorID(bias_min_node_->GetID(), 0)); + AddInput(TensorID(bias_max_node_->GetID(), 0)); AddInput(TensorID(conv_output_min_const->GetID(), 0)); AddInput(TensorID(conv_output_max_const->GetID(), 0)); + if (channel_scales_node_ != nullptr) { + AddInput(TensorID(channel_scales_node_->GetID(), 0)); + } // Outputs - output = AddOutput(sizeof(uint8_t), 4, - {output_batch_size, output_height_size, - output_width_size, output_depth_size}); - output_min = AddOutput(sizeof(float), 4, {1, 1, 1, 1}); - output_max = AddOutput(sizeof(float), 4, {1, 1, 1, 1}); + output_tensor = AddOutput(sizeof(uint8_t), 4, + {output_batch_size, output_height_size, + output_width_size, output_depth_size}); + output_min_tensor = AddOutput(sizeof(float), 4, {1, 1, 1, 1}); + output_max_tensor = AddOutput(sizeof(float), 4, {1, 1, 1, 1}); } // Requantize if activation was not None. if (activation != kTfLiteActNone) { auto* requantized_min_const = graph_builder_->AddConstNodeWithData( - quant_bound_shape.data(), reinterpret_cast(&output_min_), - sizeof(output_min_)); + quant_bound_shape.data(), reinterpret_cast(&output_min), + sizeof(output_min)); auto* requantized_max_const = graph_builder_->AddConstNodeWithData( - quant_bound_shape.data(), reinterpret_cast(&output_max_), - sizeof(output_max_)); + quant_bound_shape.data(), reinterpret_cast(&output_max), + sizeof(output_max)); auto* requantize_op = graph_builder_->AddNode(GetTFLiteNodeID()); requantize_op->SetOpType(OP_Requantize_8to8); - requantize_op->AddInput(output); - requantize_op->AddInput(output_min); - requantize_op->AddInput(output_max); + requantize_op->AddInput(output_tensor); + requantize_op->AddInput(output_min_tensor); + requantize_op->AddInput(output_max_tensor); requantize_op->AddInput(TensorID(requantized_min_const->GetID(), 0)); requantize_op->AddInput(TensorID(requantized_max_const->GetID(), 0)); node_output_ = @@ -431,7 +376,7 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs, requantize_op->AddOutput(sizeof(float), 4, {1, 1, 1, 1}); requantize_op->AddOutput(sizeof(float), 4, {1, 1, 1, 1}); } else { - node_output_ = output; + node_output_ = output_tensor; } return kTfLiteOk; diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/conv_2d_builder.h b/tensorflow/lite/experimental/delegates/hexagon/builders/conv_2d_builder.h index b66e410d3bb..f67f017299c 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/builders/conv_2d_builder.h +++ b/tensorflow/lite/experimental/delegates/hexagon/builders/conv_2d_builder.h @@ -34,15 +34,44 @@ class Conv2dOpBuilder : public OpBuilder { TfLiteStatus RegisterOutputs(const TfLiteIntArray* outputs, TfLiteContext* context) override; - ~Conv2dOpBuilder(); + ~Conv2dOpBuilder() override; private: + TfLiteStatus ProcessPerChannelQuantizedWeights(const TfLiteIntArray* inputs, + const TfLiteIntArray* outputs, + TfLiteContext* context, + float* weights_min, + float* weights_max); + + TfLiteStatus InitializeWeightsNodes(const TfLiteIntArray* inputs, + const TfLiteIntArray* outputs, + TfLiteContext* context, + const int input_depth); + + TfLiteStatus ProcessPerChannelQuantizedBias(const TfLiteIntArray* inputs, + const TfLiteIntArray* outputs, + TfLiteContext* context, + float* bias_min, float* bias_max); + + TfLiteStatus InitializeBiasNodes(const TfLiteIntArray* inputs, + const TfLiteIntArray* outputs, + TfLiteContext* context); + TensorID node_output_; std::vector transposed_weights_; std::vector stride_shape_; std::vector weight_shape_; - float data_min_, data_max_, weights_min_, weights_max_, bias_min_, bias_max_, - output_min_, output_max_; + OpBuilder* weights_data_node_ = nullptr; + OpBuilder* weights_min_node_ = nullptr; + OpBuilder* weights_max_node_ = nullptr; + OpBuilder* bias_data_node_ = nullptr; + OpBuilder* bias_min_node_ = nullptr; + OpBuilder* bias_max_node_ = nullptr; + + // Non-null only if node has per-channel quantized weights/biases. + OpBuilder* channel_scales_node_ = nullptr; + float* scales_data_ = nullptr; + int num_scale_values_ = 1; // Only used for dilated Depthwise Conv. std::vector dilation_factors_h_w_; diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/conv_2d_helpers.cc b/tensorflow/lite/experimental/delegates/hexagon/builders/conv_2d_helpers.cc new file mode 100644 index 00000000000..6cb5ddaa86f --- /dev/null +++ b/tensorflow/lite/experimental/delegates/hexagon/builders/conv_2d_helpers.cc @@ -0,0 +1,269 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include +#include +#include +#include + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/experimental/delegates/hexagon/builders/conv_2d_builder.h" +#include "tensorflow/lite/experimental/delegates/hexagon/hexagon_nn/hexagon_nn.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/kernel_util.h" + +namespace tflite { +namespace delegates { +namespace hexagon { +namespace { + +constexpr uint8_t k8BitSignFlipConstant = 0x80; +// 1/1024 ~ 0.0009766 is a restriction set by Hexagon's kernels. +// TODO(b/151103818): Figure out a way to retrieve this constant reliably. +constexpr float kHexagonMinRelativeScale = 0.0009766f; + +} // namespace + +TfLiteStatus Conv2dOpBuilder::ProcessPerChannelQuantizedWeights( + const TfLiteIntArray* inputs, const TfLiteIntArray* outputs, + TfLiteContext* context, float* weights_min, float* weights_max) { + const auto& weights_tensor = context->tensors[inputs->data[1]]; + TfLiteAffineQuantization* weights_quant_params = + reinterpret_cast( + weights_tensor.quantization.params); + + // Retrieve channel scales. + num_scale_values_ = weights_quant_params->scale->size; + // Normalize the scales as expected by Hexagon. + scales_data_ = weights_quant_params->scale->data; + std::vector normalized_scales; + normalized_scales.reserve(num_scale_values_); + float scale_max = 0.0; + for (int i = 0; i < num_scale_values_; ++i) { + normalized_scales.push_back(scales_data_[i]); + if (scales_data_[i] > scale_max) { + scale_max = scales_data_[i]; + } + } + if (scale_max == 0.0) { + TF_LITE_KERNEL_LOG(context, "Scale max is zero for: %s", + weights_tensor.name); + return kTfLiteError; + } + for (int i = 0; i < num_scale_values_; ++i) { + normalized_scales[i] = + std::max(normalized_scales[i] / scale_max, kHexagonMinRelativeScale); + } + // Add node for channel scales data. + const std::vector scales_shape = {1, 1, 1, num_scale_values_}; + channel_scales_node_ = graph_builder_->AddConstNodeWithData( + scales_shape.data(), reinterpret_cast(normalized_scales.data()), + normalized_scales.size() * sizeof(normalized_scales[0])); + *weights_min = -128 * scale_max; + *weights_max = 127 * scale_max; + return kTfLiteOk; +} + +TfLiteStatus Conv2dOpBuilder::InitializeWeightsNodes( + const TfLiteIntArray* inputs, const TfLiteIntArray* outputs, + TfLiteContext* context, const int input_depth) { + const std::vector quant_bound_shape = {1, 1, 1, 1}; + + const auto& weights_tensor = context->tensors[inputs->data[1]]; + if (weights_tensor.allocation_type != kTfLiteMmapRo) { + TF_LITE_KERNEL_LOG( + context, "Weights tensor doesn't have correct allocation type: %s", + weights_tensor.name); + return kTfLiteError; + } + int weights_batch_size, weights_height_size, weights_width_size, + weights_depth_size; + // Hexagon lib expects the weight tensor in HWCN, TFLite uses NHWC. + // Transpose NHWC -> HWCN + GetDims(&weights_batch_size, &weights_height_size, &weights_width_size, + &weights_depth_size, weights_tensor.dims); + + // Weights tensor could be int8 even for per-tensor quantization. + // Therefore, we look at the number of scale values to check if it is + // per-channel quantized. + TfLiteAffineQuantization* weights_quant_params = + reinterpret_cast( + weights_tensor.quantization.params); + const bool is_per_channel_quant = weights_quant_params->scale->size > 1; + + // WEIGHTS DATA. + if (op_node_.op_type == OP_Supernode_8x8p32to8) { + // Hexagon lib expects the weight tensor in HWCN, TFLite uses NHWC. + // Transpose NHWC -> HWCN + weight_shape_ = {weights_height_size, weights_width_size, + weights_depth_size, weights_batch_size}; + RuntimeShape nhwc_shape({weights_batch_size, weights_height_size, + weights_width_size, weights_depth_size}); + RuntimeShape hwcn_shape({weights_height_size, weights_width_size, + weights_depth_size, weights_batch_size}); + std::vector hwcn(NumElements(&weights_tensor)); + TransposeParams transpose_params; + transpose_params.perm_count = 4; + transpose_params.perm[0] = 1; + transpose_params.perm[1] = 2; + transpose_params.perm[2] = 3; + transpose_params.perm[3] = 0; + // TODO(b/151103818): Try merging Transpose & bit flip. + if (weights_tensor.type == kTfLiteInt8) { + optimized_ops::Transpose(transpose_params, nhwc_shape, + weights_tensor.data.int8, hwcn_shape, + reinterpret_cast(hwcn.data())); + // Flip bits on the weight values so that the int8 values are treated + // as uint8. + for (int i = 0; i < hwcn.size(); ++i) { + hwcn[i] = hwcn[i] ^ k8BitSignFlipConstant; + } + } else { + optimized_ops::Transpose(transpose_params, nhwc_shape, + weights_tensor.data.uint8, hwcn_shape, + hwcn.data()); + } + weights_data_node_ = graph_builder_->AddConstNodeWithData( + weight_shape_.data(), reinterpret_cast(hwcn.data()), + hwcn.size() * sizeof(hwcn[0])); + } else if (op_node_.op_type == OP_DepthwiseSupernode_8x8p32to8) { + // Hexagon treats depthwise conv like tf.nn.depthwise_conv2d, where the + // expected filter shape is [fh,fw,din,dmul]. + // The data itself will remain the same, since TFLite's representation is + // just a 'flattening' of Hexagon's version. + const int channel_multiplier = weights_depth_size / input_depth; + weight_shape_ = {weights_height_size, weights_width_size, input_depth, + channel_multiplier}; + + if (weights_tensor.type == kTfLiteInt8) { + // Flip bits on the weight values so that the int8 values are treated + // as uint8. + std::vector converted_data(NumElements(&weights_tensor)); + for (int i = 0; i < converted_data.size(); ++i) { + converted_data[i] = weights_tensor.data.int8[i] ^ k8BitSignFlipConstant; + } + weights_data_node_ = graph_builder_->AddConstNodeWithData( + weight_shape_.data(), reinterpret_cast(converted_data.data()), + converted_data.size() * sizeof(converted_data[0])); + } else { + weights_data_node_ = graph_builder_->AddConstNodeWithData( + weight_shape_.data(), weights_tensor.data.raw, + NumElements(&weights_tensor) * sizeof(weights_tensor.data.uint8[0])); + } + } + graph_builder_->AddTensorWithID(inputs->data[1], weights_data_node_->GetID(), + 0); + + // WEIGHTS QUANTIZATION. + float weights_min = 0; + float weights_max = 0; + if (is_per_channel_quant) { + ProcessPerChannelQuantizedWeights(inputs, outputs, context, &weights_min, + &weights_max); + } else { + TF_LITE_ENSURE_STATUS(ComputeMinAndMaxQuantValues( + weights_tensor, &weights_min, &weights_max)); + } + weights_min_node_ = graph_builder_->AddConstNodeWithData( + quant_bound_shape.data(), reinterpret_cast(&weights_min), + sizeof(weights_min)); + weights_max_node_ = graph_builder_->AddConstNodeWithData( + quant_bound_shape.data(), reinterpret_cast(&weights_max), + sizeof(weights_max)); + + return kTfLiteOk; +} + +TfLiteStatus Conv2dOpBuilder::ProcessPerChannelQuantizedBias( + const TfLiteIntArray* inputs, const TfLiteIntArray* outputs, + TfLiteContext* context, float* bias_min, float* bias_max) { + const auto& bias_tensor = context->tensors[inputs->data[2]]; + + const TfLiteAffineQuantization* input_quant_params = + static_cast( + context->tensors[inputs->data[0]].quantization.params); + const float input_scale = input_quant_params->scale->data[0]; + // Now dequantize bias values to float first, to adjust for the + // normalization of channel scales. + int32_t* bias_data = bias_tensor.data.i32; + const int bias_size = NumElements(&bias_tensor); + if (bias_size != num_scale_values_) { + TF_LITE_KERNEL_LOG( + context, "Bias/channel scales number mismatch for bias tensor: %s", + bias_tensor.name); + return kTfLiteError; + } + std::vector dequantized_bias; + dequantized_bias.reserve(bias_size); + for (int i = 0; i < bias_size; ++i) { + const float dequantized_value = + bias_data[i] * input_scale * scales_data_[i]; + const float abs_dequantized_value = std::abs(dequantized_value); + if (abs_dequantized_value > *bias_max) { + *bias_max = abs_dequantized_value; + } + dequantized_bias.push_back(dequantized_value); + } + *bias_max = *bias_max * 8; + *bias_min = -1 * *bias_max; + // Now requantize the bias values to the new min/max values. + std::vector preprocessed_bias_data; + preprocessed_bias_data.reserve(num_scale_values_); + for (int i = 0; i < bias_size; ++i) { + preprocessed_bias_data.push_back(static_cast( + std::round(std::pow(2, 31) * (dequantized_bias[i] / *bias_max)))); + } + // Add nodes for bias. + const std::vector bias_shape = {1, 1, 1, bias_size}; + bias_data_node_ = graph_builder_->AddConstNodeWithData( + bias_shape.data(), reinterpret_cast(preprocessed_bias_data.data()), + preprocessed_bias_data.size() * sizeof(preprocessed_bias_data[0])); + return kTfLiteOk; +} + +TfLiteStatus Conv2dOpBuilder::InitializeBiasNodes(const TfLiteIntArray* inputs, + const TfLiteIntArray* outputs, + TfLiteContext* context) { + const std::vector quant_bound_shape = {1, 1, 1, 1}; + + const auto& bias_tensor = context->tensors[inputs->data[2]]; + + float bias_min = 0; + float bias_max = 0; + if (channel_scales_node_ != nullptr) { + ProcessPerChannelQuantizedBias(inputs, outputs, context, &bias_min, + &bias_max); + } else { + bias_data_node_ = + graph_builder_->AddConstNodeWithData(inputs->data[2], bias_tensor); + TF_LITE_ENSURE_STATUS( + ComputeMinAndMaxQuantValues(bias_tensor, &bias_min, &bias_max)); + } + + bias_min_node_ = graph_builder_->AddConstNodeWithData( + quant_bound_shape.data(), reinterpret_cast(&bias_min), + sizeof(bias_min)); + bias_max_node_ = graph_builder_->AddConstNodeWithData( + quant_bound_shape.data(), reinterpret_cast(&bias_max), + sizeof(bias_max)); + + return kTfLiteOk; +} + +} // namespace hexagon +} // namespace delegates +} // namespace tflite diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/op_builder.h b/tensorflow/lite/experimental/delegates/hexagon/builders/op_builder.h index 7c39d013d59..0278964f6de 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/builders/op_builder.h +++ b/tensorflow/lite/experimental/delegates/hexagon/builders/op_builder.h @@ -134,6 +134,10 @@ class OpBuilder { return ComputeMinAndMaxQuantValues(tensor, min, max, std::numeric_limits::min(), std::numeric_limits::max()); + } else if (tensor.type == kTfLiteInt32) { + return ComputeMinAndMaxQuantValues(tensor, min, max, + std::numeric_limits::min(), + std::numeric_limits::max()); } return kTfLiteError; } @@ -151,10 +155,6 @@ class OpBuilder { } const TfLiteAffineQuantization* params = static_cast(quant.params); - if (params->quantized_dimension != 0) { - printf("Quantized dimensions not 0 for tensor: %s\n", tensor.name); - return kTfLiteError; - } float scale = params->scale->data[0]; float zero_point = static_cast(params->zero_point->data[0]); *min = scale * (static_cast(min_value) - zero_point); diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/tests/conv_test.cc b/tensorflow/lite/experimental/delegates/hexagon/builders/tests/conv_test.cc index ba4b57001fb..f204713304d 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/builders/tests/conv_test.cc +++ b/tensorflow/lite/experimental/delegates/hexagon/builders/tests/conv_test.cc @@ -42,10 +42,34 @@ class QuantizedConvolutionOpModel : public SingleOpModelWithHexagon { if (type == BuiltinOperator_DEPTHWISE_CONV_2D) { bias_size = GetShape(filter_)[3]; } - // per tensor quantization. - auto bias_scale = GetScale(input_) * GetScale(filter_); - TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale}; - bias_ = AddInput(bias); + if (filter.per_channel_quantization) { + // per channel quantization. + std::vector bias_scale( + filter.per_channel_quantization_scales.size()); + std::vector bias_zero_points( + filter.per_channel_quantization_scales.size()); + for (size_t i = 0; i < filter.per_channel_quantization_scales.size(); + ++i) { + bias_scale[i] = input.scale * filter.per_channel_quantization_scales[i]; + bias_zero_points[i] = 0; + } + TensorData bias{TensorType_INT32, + {bias_size}, + /*min=*/0, + /*max=*/0, + /*scale=*/0, + /*zero_point=*/0, + true, + /*per_channel_quantization_scales=*/bias_scale, + /*per_channel_quantization_offsets=*/bias_zero_points, + /*channel_index==*/0}; + bias_ = AddInput(bias); + } else { + // per tensor quantization. + auto bias_scale = GetScale(input_) * GetScale(filter_); + TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale}; + bias_ = AddInput(bias); + } output_ = AddOutput(output); @@ -88,9 +112,22 @@ class QuantizedConvolutionOpModel : public SingleOpModelWithHexagon { QuantizeAndPopulate(bias_, data); } + template std::vector GetDequantizedOutput() { - return Dequantize(ExtractVector(output_), - GetScale(output_), GetZeroPoint(output_)); + return Dequantize(ExtractVector(output_), GetScale(output_), + GetZeroPoint(output_)); + } + + void SetInt8Input(std::initializer_list data) { + QuantizeAndPopulate(input_, data); + } + + void SetPerChannelQuantizedFilter(std::initializer_list data) { + PerChannelSymmetricQuantizeAndPopulate(filter_, data); + } + + void SetPerChannelQuantizedBias(std::initializer_list data) { + PerChannelQuantizeBias(bias_, data); } protected: @@ -100,6 +137,168 @@ class QuantizedConvolutionOpModel : public SingleOpModelWithHexagon { int output_; }; +// CONVOLUTION TESTS + +TEST(QuantizedConvolutionOpModel, SimpleConvTestNoActivation) { + QuantizedConvolutionOpModel m( + BuiltinOperator_CONV_2D, {TensorType_UINT8, {2, 2, 4, 1}, -63.5, 64}, + {TensorType_UINT8, {3, 2, 2, 1}, -63.5, 64}, + {TensorType_UINT8, {}, -127, 128}, Padding_VALID, /**dilation_factor**/ 1, + /**stride**/ 2); + m.SetInput({ + // First batch + 1, 1, 1, 1, // row = 1 + 2, 2, 2, 2, // row = 2 + // Second batch + 1, 2, 3, 4, // row = 1 + 1, 2, 3, 4, // row = 2 + }); + m.SetFilter({ + 1, 2, 3, 4, // first 2x2 filter + -1, 1, -1, 1, // second 2x2 filter + -1, -1, 1, 1, // third 2x2 filter + }); + m.SetBias({1, 2, 3}); + + m.ApplyDelegateAndInvoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 18, 2, 5, // first batch, left + 18, 2, 5, // first batch, right + 17, 4, 3, // second batch, left + 37, 4, 3, // second batch, right + }, + 1e-5))); +} + +TEST(QuantizedConvolutionOpModel, SimpleConvTestReLU6Activation) { + QuantizedConvolutionOpModel m( + BuiltinOperator_CONV_2D, {TensorType_UINT8, {2, 2, 4, 1}, -63.5, 64}, + {TensorType_UINT8, {3, 2, 2, 1}, -63.5, 64}, + {TensorType_UINT8, {}, -127, 128}, Padding_VALID, /**dilation_factor**/ 1, + /**stride**/ 2, ActivationFunctionType_RELU6); + m.SetInput({ + // First batch + 1, 1, 1, 1, // row = 1 + 2, 2, 2, 2, // row = 2 + // Second batch + 1, 2, 3, 4, // row = 1 + 1, 2, 3, 4, // row = 2 + }); + m.SetFilter({ + 1, 2, 3, 4, // first 2x2 filter + -1, 1, -1, 1, // second 2x2 filter + -1, -1, 1, 1, // third 2x2 filter + }); + m.SetBias({1, 2, 3}); + + m.ApplyDelegateAndInvoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 6, 2, 5, // first batch, left + 6, 2, 5, // first batch, right + 6, 4, 3, // second batch, left + 6, 4, 3, // second batch, right + }, + 1e-5))); +} + +TEST(QuantizedConvolutionOpModel, SimplePerTensor_Int8) { + QuantizedConvolutionOpModel m( + BuiltinOperator_CONV_2D, + {TensorType_INT8, {1, 2, 3, 2}, -63.5, 64, 0.5, -1}, + {TensorType_INT8, + // [2 * 2 * 2 * 2] as [output_channel, y, x, input_channel] + {2, 2, 2, 2}, + 0, + 0, + 0, + 0, + /*per_channel_quantization=*/true, + /*per_channel_quantization_scales=*/{1}, + /*per_channel_quantization_offsets=*/{0}, + /*channel_index=*/0}, + {TensorType_INT8, {}, -63.5, 64, 0.5, -1}, Padding_VALID); + m.SetInt8Input({ + // [1 * 2 * 3 * 2] as [batch, y, x, input_channel] + 3, 2, // batch = 0, y = 0, x = 0 + 1, -1, // batch = 0, y = 0, x = 1 + -2, -3, // batch = 0, y = 0, x = 2 + 4, 3, // batch = 0, y = 1, x = 0 + 2, -2, // batch = 0, y = 1, x = 1 + -3, -4, // batch = 0, y = 1, x = 2 + }); + m.SetPerChannelQuantizedFilter( + // [2 * 2 * 2 * 2] as [output_channel, y, x,input_channel] + { + 1, 2, // out channel = 0, y = 0, x = 0 + 3, 4, // out channel = 0, y = 0, x = 1 + 3, 4, // out channel = 0, y = 1, x = 0 + 5, 6, // out channel = 0, y = 1, x = 1 + 7, 8, // out channel = 1, y = 0, x = 0 + 5, 6, // out channel = 1, y = 0, x = 1 + 3, 4, // out channel = 1, y = 1, x = 0 + 1, 2, // out channel = 1, y = 1, x = 1 + }); + m.SetPerChannelQuantizedBias({3, -2}); + + m.ApplyDelegateAndInvoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({31, 56, -57, -44}, 1e-5))); +} + +TEST(QuantizedConvolutionOpModel, SimplePerChannel_Int8) { + QuantizedConvolutionOpModel m( + BuiltinOperator_CONV_2D, + {TensorType_INT8, {1, 2, 3, 2}, -63.5, 64, 0.5, -1}, + {TensorType_INT8, + // [2 * 2 * 2 * 2] as [output_channel, y, x, input_channel] + {2, 2, 2, 2}, + 0, + 0, + 0, + 0, + /*per_channel_quantization=*/true, + /*per_channel_quantization_scales=*/{1, 2}, + /*per_channel_quantization_offsets=*/{0, 0}, + /*channel_index=*/0}, + {TensorType_INT8, {}, -63.5, 64, 0.5, -1}, Padding_VALID); + m.SetInt8Input({ + // [1 * 2 * 3 * 2] as [batch, y, x, input_channel] + 3, 2, // batch = 0, y = 0, x = 0 + 1, -1, // batch = 0, y = 0, x = 1 + -2, -3, // batch = 0, y = 0, x = 2 + 4, 3, // batch = 0, y = 1, x = 0 + 2, -2, // batch = 0, y = 1, x = 1 + -3, -4, // batch = 0, y = 1, x = 2 + }); + m.SetPerChannelQuantizedFilter( + // [2 * 2 * 2 * 2] as [output_channel, y, x, input_channel] + { + 1, 2, // out channel = 0, y = 0, x = 0 + 3, 4, // out channel = 0, y = 0, x = 1 + 3, 4, // out channel = 0, y = 1, x = 0 + 5, 6, // out channel = 0, y = 1, x = 1 + 7, 8, // out channel = 1, y = 0, x = 0 + 5, 6, // out channel = 1, y = 0, x = 1 + 3, 4, // out channel = 1, y = 1, x = 0 + 1, 2, // out channel = 1, y = 1, x = 1 + }); + m.SetPerChannelQuantizedBias({3, -2}); + + m.ApplyDelegateAndInvoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({31, 64, -57, -46}, 0.6f))); +} + +// DEPTHWISE CONVOLUTION TESTS + TEST(QuantizedConvolutionOpModel, SimpleDilatedDepthwiseConvTestPaddingValid) { const int depth = 1; const int image_width = 9; @@ -155,7 +354,7 @@ TEST(QuantizedConvolutionOpModel, SimpleDilatedDepthwiseConvTestPaddingValid) { // | 5 | 5 | 5 | // | 5 | 5 | 5 | // | 5 | 5 | 5 | - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5, 5})); } @@ -182,80 +381,13 @@ TEST(QuantizedConvolutionOpModel, DepthwiseConv5x5) { // Reference output. m.Invoke(); - auto reference_output = m.GetDequantizedOutput(); + auto reference_output = m.GetDequantizedOutput(); m.ApplyDelegateAndInvoke(); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear(reference_output, 1e-5))); } -TEST(QuantizedConvolutionOpModel, SimpleConvTestNoActivation) { - QuantizedConvolutionOpModel m( - BuiltinOperator_CONV_2D, {TensorType_UINT8, {2, 2, 4, 1}, -63.5, 64}, - {TensorType_UINT8, {3, 2, 2, 1}, -63.5, 64}, - {TensorType_UINT8, {}, -127, 128}, Padding_VALID, /**dilation_factor**/ 1, - /**stride**/ 2); - m.SetInput({ - // First batch - 1, 1, 1, 1, // row = 1 - 2, 2, 2, 2, // row = 2 - // Second batch - 1, 2, 3, 4, // row = 1 - 1, 2, 3, 4, // row = 2 - }); - m.SetFilter({ - 1, 2, 3, 4, // first 2x2 filter - -1, 1, -1, 1, // second 2x2 filter - -1, -1, 1, 1, // third 2x2 filter - }); - m.SetBias({1, 2, 3}); - - m.ApplyDelegateAndInvoke(); - - EXPECT_THAT(m.GetDequantizedOutput(), - ElementsAreArray(ArrayFloatNear( - { - 18, 2, 5, // first batch, left - 18, 2, 5, // first batch, right - 17, 4, 3, // second batch, left - 37, 4, 3, // second batch, right - }, - 1e-5))); -} - -TEST(QuantizedConvolutionOpModel, SimpleConvTestReLU6Activation) { - QuantizedConvolutionOpModel m( - BuiltinOperator_CONV_2D, {TensorType_UINT8, {2, 2, 4, 1}, -63.5, 64}, - {TensorType_UINT8, {3, 2, 2, 1}, -63.5, 64}, - {TensorType_UINT8, {}, -127, 128}, Padding_VALID, /**dilation_factor**/ 1, - /**stride**/ 2, ActivationFunctionType_RELU6); - m.SetInput({ - // First batch - 1, 1, 1, 1, // row = 1 - 2, 2, 2, 2, // row = 2 - // Second batch - 1, 2, 3, 4, // row = 1 - 1, 2, 3, 4, // row = 2 - }); - m.SetFilter({ - 1, 2, 3, 4, // first 2x2 filter - -1, 1, -1, 1, // second 2x2 filter - -1, -1, 1, 1, // third 2x2 filter - }); - m.SetBias({1, 2, 3}); - - m.ApplyDelegateAndInvoke(); - - EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( - { - 6, 2, 5, // first batch, left - 6, 2, 5, // first batch, right - 6, 4, 3, // second batch, left - 6, 4, 3, // second batch, right - }, - 1e-5))); -} - // Depthwise Conv with multiplier > 1 but input depth==1 should resolve into a // Conv op. TEST(QuantizedConvolutionOpModel, DepthwiseConvWithMultiplier_InputDepth1) { @@ -289,10 +421,10 @@ TEST(QuantizedConvolutionOpModel, DepthwiseConvWithMultiplier_InputDepth1) { // Reference output. m.Invoke(); - auto reference_output = m.GetDequantizedOutput(); + auto reference_output = m.GetDequantizedOutput(); m.ApplyDelegateAndInvoke(); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear(reference_output, 1e-5))); } @@ -331,11 +463,176 @@ TEST(QuantizedConvolutionOpModel, // Reference output. m.Invoke(); - auto reference_output = m.GetDequantizedOutput(); + auto reference_output = m.GetDequantizedOutput(); m.ApplyDelegateAndInvoke(); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear(reference_output, 1e-5))); } +TEST(QuantizedConvolutionOpModel, DepthwiseConvSimplePerTensor_Int8) { + QuantizedConvolutionOpModel m( + BuiltinOperator_DEPTHWISE_CONV_2D, + {TensorType_INT8, {1, 2, 3, 1}, -63.5, 64, 0.5, -1}, + {TensorType_INT8, + // [1 * 2 * 2 * 4] as [input_channel, y, x, output_channel] + {1, 2, 2, 4}, + 0, + 0, + 0, + 0, + /*per_channel_quantization=*/true, + /*per_channel_quantization_scales=*/{1}, + /*per_channel_quantization_offsets=*/{0}, + /*channel_index=*/3}, + {TensorType_INT8, {}, -63.5, 64, 0.5, -1}, Padding_VALID); + m.SetInt8Input({ + // [1 * 2 * 3 * 1] as [batch, y, x, input_channel] + 3, // batch = 0, y = 0, x = 0 + 1, // batch = 0, y = 0, x = 1 + -2, // batch = 0, y = 0, x = 2 + 4, // batch = 0, y = 1, x = 0 + 2, // batch = 0, y = 1, x = 1 + -3, // batch = 0, y = 1, x = 2 + }); + m.SetPerChannelQuantizedFilter({ + // [1 * 2 * 2 * 4] as [input_channel, y, x, output_channel] + // depth multiplier = 2 + 1, 2, 3, 4, // y = 0, x = 0 + 3, 4, 5, 6, // y = 0, x = 1 + 7, 8, 5, 6, // y = 1, x = 0 + 3, 4, 1, 2, // y = 1, x = 1 + }); + m.SetPerChannelQuantizedBias({3, -2, 4, 6}); + + m.ApplyDelegateAndInvoke(); + + EXPECT_THAT( + m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({43, 48, 40, 52, 3, -4, 4, 4}, 0.6f))); +} + +TEST(QuantizedConvolutionOpModel, DepthwiseConvSimplePerAxis_Int8) { + QuantizedConvolutionOpModel m( + BuiltinOperator_DEPTHWISE_CONV_2D, + {TensorType_INT8, {1, 2, 3, 1}, -63.5, 64, 0.5, -1}, + {TensorType_INT8, + // [1 * 2 * 2 * 4] as [input_channel, y, x, output_channel] + {1, 2, 2, 4}, + 0, + 0, + 0, + 0, + /*per_channel_quantization=*/true, + /*per_channel_quantization_scales=*/{0.1, 2, 3, 0.4}, + /*per_channel_quantization_offsets=*/{0, 0, 0, 0}, + /*channel_index=*/3}, + {TensorType_INT8, {}, -63.5, 64, 0.5, -1}, Padding_VALID); + m.SetInt8Input({ + // [1 * 2 * 3 * 1] as [batch, y, x, input_channel] + 3, // batch = 0, y = 0, x = 0 + 1, // batch = 0, y = 0, x = 1 + -2, // batch = 0, y = 0, x = 2 + 4, // batch = 0, y = 1, x = 0 + 2, // batch = 0, y = 1, x = 1 + -4, // batch = 0, y = 1, x = 2 + }); + m.SetPerChannelQuantizedFilter({ + // [1 * 2 * 2 * 4] as [input_channel, y, x, output_channel] + // depth multiplier = 2 + 1, 2, 3, 4, // y = 0, x = 0 + 3, 4, 5, 6, // y = 0, x = 1 + 7, 8, 5, 6, // y = 1, x = 0 + 3, 4, 1, 2, // y = 1, x = 1 + }); + m.SetPerChannelQuantizedBias({3, -2, 4, 6}); + + m.ApplyDelegateAndInvoke(); + + EXPECT_THAT( + m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({43, 48, 42, 52, 0, -8, 6, 2}, 0.6f))); +} + +TEST(QuantizedConvolutionOpModel, DepthwiseConvPerChannel_3x3Filter) { + QuantizedConvolutionOpModel m( + BuiltinOperator_DEPTHWISE_CONV_2D, + {TensorType_INT8, {1, 3, 3, 8}, -63.5, 64, 0.5, -1}, + {TensorType_INT8, + // [1 * 3 * 3 * 8] as [input_channel, y, x, output_channel] + {1, 3, 3, 8}, + 0, + 0, + 0, + 0, + /*per_channel_quantization=*/true, + /*per_channel_quantization_scales=*/ + {0.1, 0.2, 0.3, 0.4, 0.4, 0.3, 0.2, 0.1}, + /*per_channel_quantization_offsets=*/{0, 0, 0, 0, 0, 0, 0, 0}, + /*channel_index=*/3}, + {TensorType_INT8, {}, -63.5, 64, 0.5, -1}, Padding_VALID); + m.SetInt8Input({// array of 9 x 8 => [1, 3, 3, 8] + 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, + 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, + 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, + 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0}); + m.SetPerChannelQuantizedFilter( + {// array of 9 x 8 => [1, 3, 3, 8] + 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, + 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, + 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8}); + m.SetPerChannelQuantizedBias({0, 0, 0, 0, 0, 0, 0, 0}); + + m.ApplyDelegateAndInvoke(); + + EXPECT_THAT( + m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({9, 18, 0, 0, 47, 54, 0, 0}, 0.6f))); +} + +TEST(QuantizedConvolutionOpModel, + DepthwiseConvPerChannel_3x3FilterPaddingSame) { + QuantizedConvolutionOpModel m( + BuiltinOperator_DEPTHWISE_CONV_2D, + {TensorType_INT8, {1, 3, 3, 8}, -63.5, 64, 0.5, -1}, + {TensorType_INT8, + // [1 * 3 * 3 * 8] as [input_channel, y, x, output_channel] + {1, 3, 3, 8}, + 0, + 0, + 0, + 0, + /*per_channel_quantization=*/true, + /*per_channel_quantization_scales=*/ + {0.1, 0.2, 0.3, 0.4, 0.4, 0.3, 0.2, 0.1}, + /*per_channel_quantization_offsets=*/{0, 0, 0, 0, 0, 0, 0, 0}, + /*channel_index=*/3}, + {TensorType_INT8, {}, -63.5, 64, 0.5, -1}, Padding_SAME); + m.SetInt8Input({// array of 9 x 8 => [1, 3, 3, 8] + 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, + 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, + 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, + 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0}); + m.SetPerChannelQuantizedFilter( + {// array of 9 x 8 => [1, 3, 3, 8] + 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, + 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, + 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8}); + m.SetPerChannelQuantizedBias({0, 0, 0, 0, 0, 0, 0, 0}); + + m.ApplyDelegateAndInvoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + // array of 9 x 8 => [1, 3, 3, 8] + 4, 8, 0, 0, 21, 24, 0, 0, 6, 12, 0, 0, 31.5, 36, 0, 0, + 4, 8, 0, 0, 21, 24, 0, 0, 6, 12, 0, 0, 31.5, 36, 0, 0, + 9, 18, 0, 0, 47, 54, 0, 0, 6, 12, 0, 0, 31.5, 36, 0, 0, + 4, 8, 0, 0, 21, 24, 0, 0, 6, 12, 0, 0, 31.5, 36, 0, 0, + 4, 8, 0, 0, 21, 24, 0, 0, + }, + 0.6f))); +} + } // namespace tflite diff --git a/tensorflow/lite/experimental/delegates/hexagon/utils.cc b/tensorflow/lite/experimental/delegates/hexagon/utils.cc index 508f6657e61..4c4862f53da 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/utils.cc +++ b/tensorflow/lite/experimental/delegates/hexagon/utils.cc @@ -64,11 +64,13 @@ TfLiteStatus Get4DShape(unsigned int* batch_size, unsigned int* height_size, bool CheckOpVersion(const TfLiteRegistration* registration) { switch (registration->builtin_code) { case kTfLiteBuiltinAveragePool2d: - case kTfLiteBuiltinDepthwiseConv2d: case kTfLiteBuiltinSoftmax: return registration->version <= 2; case kTfLiteBuiltinRelu: return registration->version >= 2; + case kTfLiteBuiltinConv2d: + case kTfLiteBuiltinDepthwiseConv2d: + return registration->version <= 3; default: return registration->version == 1; } @@ -182,7 +184,9 @@ bool IsNodeSupportedByHexagon(const TfLiteRegistration* registration, } case kTfLiteBuiltinConv2d: { if (!InputsWithCorrectTypes(node, context, - {kTfLiteUInt8, kTfLiteUInt8, kTfLiteInt32})) + {kTfLiteUInt8, kTfLiteUInt8, kTfLiteInt32}) && + !InputsWithCorrectTypes(node, context, + {kTfLiteInt8, kTfLiteInt8, kTfLiteInt32})) return false; const TfLiteConvParams* conv_params = reinterpret_cast(node->builtin_data); @@ -194,7 +198,9 @@ bool IsNodeSupportedByHexagon(const TfLiteRegistration* registration, } case kTfLiteBuiltinDepthwiseConv2d: { if (!InputsWithCorrectTypes(node, context, - {kTfLiteUInt8, kTfLiteUInt8, kTfLiteInt32})) + {kTfLiteUInt8, kTfLiteUInt8, kTfLiteInt32}) && + !InputsWithCorrectTypes(node, context, + {kTfLiteInt8, kTfLiteInt8, kTfLiteInt32})) return false; // Check dilation. diff --git a/tensorflow/lite/experimental/ios/BUILD.apple b/tensorflow/lite/experimental/ios/BUILD.apple index 5aa662376e4..faa3f12971c 100644 --- a/tensorflow/lite/experimental/ios/BUILD.apple +++ b/tensorflow/lite/experimental/ios/BUILD.apple @@ -11,8 +11,27 @@ package( licenses = ["notice"], # Apache 2.0 ) +genrule( + name = "strip_coreml_include_hdr", + srcs = ["//tensorflow/lite/experimental/delegates/coreml:coreml_delegate.h"], + outs = ["coreml_delegate.h"], + cmd = """ + sed 's/#include \".*common.h"/#include \"common.h\"/' \ + "$(location //tensorflow/lite/experimental/delegates/coreml:coreml_delegate.h)" \ + > "$@" + """, +) + TFL_LIBRARY_HDRS = [ "//tensorflow/lite/delegates/gpu:metal_delegate.h", + "//tensorflow/lite/experimental/delegates/coreml:coreml_delegate.h", + "//tensorflow/lite/c:c_api.h", + "//tensorflow/lite/c:common.h", +] + +TFL_FRAMEWORK_HDRS = [ + "//tensorflow/lite/delegates/gpu:metal_delegate.h", + ":coreml_delegate.h", "//tensorflow/lite/c:c_api.h", "//tensorflow/lite/c:common.h", ] @@ -20,7 +39,7 @@ TFL_LIBRARY_HDRS = [ # bazel build -c opt --config=ios_fat //tensorflow/lite/experimental/ios:TensorFlowLiteC_framework ios_static_framework( name = "TensorFlowLiteC_framework", - hdrs = TFL_LIBRARY_HDRS, + hdrs = TFL_FRAMEWORK_HDRS, bundle_name = "TensorFlowLiteC", minimum_os_version = TFL_MINIMUM_OS_VERSION, deps = [ @@ -34,6 +53,7 @@ objc_library( module_name = "TensorFlowLiteC", weak_sdk_frameworks = [ "Metal", + "CoreML", ], deps = [ ":tensorflow_lite_c", @@ -75,6 +95,7 @@ cc_library( deps = [ "//tensorflow/lite/c:c_api", "//tensorflow/lite/delegates/gpu:metal_delegate", + "//tensorflow/lite/experimental/delegates/coreml:coreml_delegate", ], ) diff --git a/tensorflow/lite/experimental/ios/TensorFlowLiteC.podspec.template b/tensorflow/lite/experimental/ios/TensorFlowLiteC.podspec.template index ec5ed33670f..d69c479282b 100644 --- a/tensorflow/lite/experimental/ios/TensorFlowLiteC.podspec.template +++ b/tensorflow/lite/experimental/ios/TensorFlowLiteC.podspec.template @@ -20,4 +20,5 @@ Pod::Spec.new do |s| s.module_name = 'TensorFlowLiteC' s.library = 'c++' s.vendored_frameworks = 'Frameworks/TensorFlowLiteC.framework' + s.weak_frameworks = 'CoreML' end diff --git a/tensorflow/lite/experimental/ios/TensorFlowLiteSelectTfOps.podspec.template b/tensorflow/lite/experimental/ios/TensorFlowLiteSelectTfOps.podspec.template index 7a91e4a08ce..11229075c03 100644 --- a/tensorflow/lite/experimental/ios/TensorFlowLiteSelectTfOps.podspec.template +++ b/tensorflow/lite/experimental/ios/TensorFlowLiteSelectTfOps.podspec.template @@ -18,4 +18,5 @@ Pod::Spec.new do |s| s.module_name = 'TensorFlowLiteSelectTfOps' s.library = 'c++' s.vendored_frameworks = 'Frameworks/TensorFlowLiteSelectTfOps.framework' + s.weak_frameworks = 'CoreML' end diff --git a/tensorflow/lite/experimental/objc/apis/TFLTensor.h b/tensorflow/lite/experimental/objc/apis/TFLTensor.h index fd781bd5723..ced6c55add8 100644 --- a/tensorflow/lite/experimental/objc/apis/TFLTensor.h +++ b/tensorflow/lite/experimental/objc/apis/TFLTensor.h @@ -49,6 +49,9 @@ typedef NS_ENUM(NSUInteger, TFLTensorDataType) { /** 8-bit signed integer. */ TFLTensorDataTypeInt8, + + /** 64-bit double precision floating point. */ + TFLTensorDataTypeFloat64, }; /** diff --git a/tensorflow/lite/experimental/objc/sources/TFLInterpreter.mm b/tensorflow/lite/experimental/objc/sources/TFLInterpreter.mm index e0cca1076f6..94031ee5428 100644 --- a/tensorflow/lite/experimental/objc/sources/TFLInterpreter.mm +++ b/tensorflow/lite/experimental/objc/sources/TFLInterpreter.mm @@ -373,6 +373,8 @@ static void TFLInterpreterErrorReporter(void *user_data, const char *format, va_ return TFLTensorDataTypeFloat32; case kTfLiteFloat16: return TFLTensorDataTypeFloat16; + case kTfLiteFloat64: + return TFLTensorDataTypeFloat64; case kTfLiteInt32: return TFLTensorDataTypeInt32; case kTfLiteUInt8: diff --git a/tensorflow/lite/experimental/ruy/CONTRIBUTING.md b/tensorflow/lite/experimental/ruy/CONTRIBUTING.md deleted file mode 100644 index 654a071648d..00000000000 --- a/tensorflow/lite/experimental/ruy/CONTRIBUTING.md +++ /dev/null @@ -1,28 +0,0 @@ -# How to Contribute - -We'd love to accept your patches and contributions to this project. There are -just a few small guidelines you need to follow. - -## Contributor License Agreement - -Contributions to this project must be accompanied by a Contributor License -Agreement. You (or your employer) retain the copyright to your contribution; -this simply gives us permission to use and redistribute your contributions as -part of the project. Head over to to see -your current agreements on file or to sign a new one. - -You generally only need to submit a CLA once, so if you've already submitted one -(even if it was for a different project), you probably don't need to do it -again. - -## Code reviews - -All submissions, including submissions by project members, require review. We -use GitHub pull requests for this purpose. Consult -[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more -information on using pull requests. - -## Community Guidelines - -This project follows [Google's Open Source Community -Guidelines](https://opensource.google/conduct/). diff --git a/tensorflow/lite/experimental/ruy/README.md b/tensorflow/lite/experimental/ruy/README.md deleted file mode 100644 index 09b85927d09..00000000000 --- a/tensorflow/lite/experimental/ruy/README.md +++ /dev/null @@ -1,24 +0,0 @@ -# The ruy matrix multiplication library - -This is not an officially supported Google product. - -ruy is a matrix multiplication library. Its focus is to cover the matrix -multiplication needs of neural network inference engines. Its initial user has -been TensorFlow Lite, where it is used by default on the ARM CPU architecture. - -ruy supports both floating-point and 8bit-integer-quantized matrices. - -## Efficiency - -ruy is designed to achieve maximal performance not just on very large sizes, as -is the focus of many established libraries, but on whatever are the actual sizes -and shapes of matrices most critical in current TensorFlow Lite applications. -This often means quite small sizes, e.g. 100x100 or even 50x50, and all sorts of -rectangular shapes. - -ruy is currently only optimized for the ARM architectures (both 64-bit and -32-bit code). Optimization for the Intel x86 architecture is in progress. - -ruy is currently optimized only for the following combination of storage orders: -LHS = row-major, RHS = column-major, destination = column-major. All other -combinations of storage orders fall back to slow reference code at the moment. diff --git a/tensorflow/lite/experimental/ruy/WORKSPACE b/tensorflow/lite/experimental/ruy/WORKSPACE deleted file mode 100644 index 8364d8047b1..00000000000 --- a/tensorflow/lite/experimental/ruy/WORKSPACE +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Workspace file for the Ruy project. - -workspace(name = "com_google_ruy") diff --git a/tensorflow/lite/experimental/ruy/ruy/BUILD b/tensorflow/lite/experimental/ruy/ruy/BUILD deleted file mode 100644 index c808c3ec063..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/BUILD +++ /dev/null @@ -1,954 +0,0 @@ -# Ruy is not BLAS - -load(":build_defs.bzl", "ruy_copts_avx2", "ruy_copts_avxvnni", "ruy_copts_base", "ruy_copts_skylake", "ruy_copts_sse42") -load(":ruy_test_ext.bzl", "ruy_test_ext_defines", "ruy_test_ext_deps") -load(":ruy_test.bzl", "ruy_benchmark", "ruy_test") - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -config_setting( - name = "windows", - values = {"cpu": "x64_windows"}, -) - -config_setting( - name = "armeabi-v7a", - values = {"cpu": "armeabi-v7a"}, -) - -config_setting( - name = "x86_64", - values = {"cpu": "k8"}, -) - -config_setting( - name = "optimized", - values = { - "compilation_mode": "opt", - }, - visibility = ["//visibility:public"], -) - -cc_library( - name = "platform", - hdrs = ["platform.h"], - copts = ruy_copts_base(), -) - -cc_library( - name = "check_macros", - hdrs = ["check_macros.h"], - copts = ruy_copts_base(), -) - -cc_test( - name = "check_macros_test", - srcs = ["check_macros_test.cc"], - copts = ruy_copts_base(), - deps = [ - ":check_macros", - "@com_google_googletest//:gtest", - ], -) - -cc_library( - name = "opt_set", - hdrs = ["opt_set.h"], - copts = ruy_copts_base(), -) - -cc_library( - name = "time", - hdrs = ["time.h"], - copts = ruy_copts_base(), -) - -cc_library( - name = "wait", - srcs = ["wait.cc"], - hdrs = ["wait.h"], - copts = ruy_copts_base(), - deps = [":time"], -) - -cc_test( - name = "wait_test", - srcs = ["wait_test.cc"], - deps = [ - ":platform", - ":wait", - "@com_google_googletest//:gtest", - ], -) - -cc_library( - name = "size_util", - hdrs = ["size_util.h"], - copts = ruy_copts_base(), - deps = [":check_macros"], -) - -cc_test( - name = "size_util_test", - srcs = ["size_util_test.cc"], - deps = [ - ":size_util", - "@com_google_googletest//:gtest", - ], -) - -cc_library( - name = "tune", - srcs = [ - "tune.cc", - ], - hdrs = [ - "tune.h", - ], - copts = ruy_copts_base(), - deps = [ - ":opt_set", - ":platform", - ":time", - ], -) - -cc_library( - name = "prepacked_cache", - srcs = [ - "prepacked_cache.cc", - ], - hdrs = [ - "prepacked_cache.h", - ], - copts = ruy_copts_base(), - deps = [ - ":allocator", - ":matrix", - ":opt_set", - ":platform", - ":time", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", - ], -) - -cc_test( - name = "tune_test", - srcs = ["tune_test.cc"], - deps = [ - ":tune", - "@com_google_googletest//:gtest", - ], -) - -cc_test( - name = "prepacked_cache_test", - srcs = ["prepacked_cache_test.cc"], - deps = [ - ":prepacked_cache", - ":ruy", - ":time", - "@com_google_googletest//:gtest", - ], -) - -cc_binary( - name = "tune_tool", - srcs = ["tune_tool.cc"], - deps = [ - ":tune", - ], -) - -cc_library( - name = "allocator", - srcs = [ - "allocator.cc", - ], - hdrs = [ - "allocator.h", - ], - copts = ruy_copts_base(), - deps = [ - ":check_macros", - ":size_util", - ], -) - -cc_test( - name = "allocator_test", - srcs = ["allocator_test.cc"], - deps = [ - ":allocator", - "@com_google_googletest//:gtest", - ], -) - -cc_library( - name = "side_pair", - hdrs = ["side_pair.h"], - copts = ruy_copts_base(), - deps = [":check_macros"], -) - -cc_library( - name = "block_map", - srcs = [ - "block_map.cc", - ], - hdrs = [ - "block_map.h", - ], - copts = ruy_copts_base(), - deps = [ - ":check_macros", - ":opt_set", - ":path", - ":side_pair", - ":size_util", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", - ], -) - -cc_test( - name = "block_map_test", - srcs = ["block_map_test.cc"], - deps = [ - ":block_map", - ":cpu_cache_size", - ":path", - ":side_pair", - "@com_google_googletest//:gtest", - ], -) - -cc_library( - name = "blocking_counter", - srcs = [ - "blocking_counter.cc", - ], - hdrs = [ - "blocking_counter.h", - ], - copts = ruy_copts_base(), - deps = [ - ":check_macros", - ":wait", - ], -) - -cc_library( - name = "thread_pool", - srcs = [ - "thread_pool.cc", - ], - hdrs = [ - "thread_pool.h", - ], - copts = ruy_copts_base(), - deps = [ - ":blocking_counter", - ":check_macros", - ":wait", - ], -) - -cc_library( - name = "detect_arm", - srcs = [ - "detect_arm.cc", - ], - hdrs = [ - "detect_arm.h", - ], - copts = ruy_copts_base(), -) - -cc_library( - name = "detect_x86", - srcs = [ - "detect_x86.cc", - ], - hdrs = [ - "detect_x86.h", - ], - copts = ruy_copts_base(), - deps = [ - ":platform", - ], -) - -cc_library( - name = "path", - hdrs = ["path.h"], - copts = ruy_copts_base(), - deps = [ - ":platform", - ":size_util", - ], -) - -cc_library( - name = "cpu_cache_size", - hdrs = ["cpu_cache_size.h"], - copts = ruy_copts_base(), - deps = [ - ":path", - ":platform", - ], -) - -cc_library( - name = "trace", - srcs = [ - "trace.cc", - ], - hdrs = [ - "trace.h", - ], - copts = ruy_copts_base(), - deps = [ - ":block_map", - ":check_macros", - ":side_pair", - ":time", - ], -) - -cc_library( - name = "matrix", - hdrs = ["matrix.h"], - copts = ruy_copts_base(), - deps = [":check_macros"], -) - -cc_library( - name = "spec", - hdrs = ["spec.h"], - copts = ruy_copts_base(), - deps = [ - ":cpu_cache_size", - ":matrix", - ], -) - -cc_library( - name = "internal_matrix", - hdrs = ["internal_matrix.h"], - copts = ruy_copts_base(), - deps = [ - ":check_macros", - ":common", - ":matrix", - ":size_util", - ], -) - -cc_library( - name = "common", - hdrs = [ - "common.h", - ], - copts = ruy_copts_base(), - deps = [ - ":check_macros", - ":matrix", - ":opt_set", - ":path", - ":platform", - ], -) - -cc_library( - name = "kernel_common", - hdrs = [ - "kernel.h", - "kernel_arm.h", - "kernel_common.h", - "kernel_x86.h", - ], - copts = ruy_copts_base(), - deps = [ - ":check_macros", - ":common", - ":internal_matrix", - ":matrix", - ":opt_set", - ":path", - ":platform", - ":side_pair", - ":size_util", - ":spec", - ":tune", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", - ], -) - -cc_library( - name = "pack_common", - hdrs = [ - "pack.h", - "pack_arm.h", - "pack_common.h", - "pack_x86.h", - ], - copts = ruy_copts_base(), - deps = [ - ":check_macros", - ":common", - ":internal_matrix", - ":matrix", - ":opt_set", - ":path", - ":platform", - ":tune", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", - ], -) - -cc_library( - name = "kernel_arm", - srcs = [ - "kernel_arm32.cc", - "kernel_arm64.cc", - ], - copts = ruy_copts_base(), - deps = [ - ":common", - ":kernel_common", - ":opt_set", - ":platform", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", - ], -) - -cc_library( - name = "pack_arm", - srcs = [ - "pack_arm.cc", - ], - copts = ruy_copts_base(), - deps = [ - ":common", - ":opt_set", - ":pack_common", - ":platform", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", - ], -) - -# AVX-512 compilation units. -# -# These must use the same compiler options. -RUY_COPTS_BUILT_FOR_AVX512 = ruy_copts_base() + ruy_copts_skylake() - -cc_library( - name = "kernel_avx512", - srcs = [ - "kernel_avx512.cc", - ], - copts = RUY_COPTS_BUILT_FOR_AVX512, - deps = [ - ":check_macros", - ":kernel_common", - ":opt_set", - ":platform", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", - ], -) - -cc_library( - name = "pack_avx512", - srcs = [ - "pack_avx512.cc", - ], - copts = RUY_COPTS_BUILT_FOR_AVX512, - deps = [ - ":check_macros", - ":matrix", - ":opt_set", - ":pack_common", - ":path", - ":platform", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", - ], -) - -cc_library( - name = "have_built_path_for_avx512", - srcs = [ - "have_built_path_for_avx512.cc", - ], - hdrs = [ - "have_built_path_for.h", - ], - copts = RUY_COPTS_BUILT_FOR_AVX512, - deps = [ - ":opt_set", - ":platform", - ], -) -# End: AVX-512 compilation units. - -# AVX2 compilation units. -# -# These must use the same compiler options. -RUY_COPTS_BUILT_FOR_AVX2 = ruy_copts_base() + ruy_copts_avx2() - -cc_library( - name = "kernel_avx2", - srcs = [ - "kernel_avx2.cc", - ], - copts = RUY_COPTS_BUILT_FOR_AVX2, - deps = [ - ":check_macros", - ":kernel_common", - ":opt_set", - ":platform", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", - ], -) - -cc_library( - name = "pack_avx2", - srcs = [ - "pack_avx2.cc", - ], - copts = RUY_COPTS_BUILT_FOR_AVX2, - deps = [ - ":check_macros", - ":matrix", - ":opt_set", - ":pack_common", - ":path", - ":platform", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", - ], -) - -cc_library( - name = "have_built_path_for_avx2", - srcs = [ - "have_built_path_for_avx2.cc", - ], - hdrs = [ - "have_built_path_for.h", - ], - copts = RUY_COPTS_BUILT_FOR_AVX2, - deps = [ - ":opt_set", - ":platform", - ], -) -# End: AVX2 compilation units. - -# SSE42 compilation units. -# -# TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -# Optimization is not finished. In particular the dimensions of the kernel -# blocks can be changed as desired. -# -# These must use the same compiler options. -RUY_COPTS_BUILT_FOR_SSE42 = ruy_copts_base() + ruy_copts_sse42() - -cc_library( - name = "kernel_sse42", - srcs = [ - "kernel_sse42.cc", - ], - copts = RUY_COPTS_BUILT_FOR_SSE42, - deps = [ - ":check_macros", - ":kernel_common", - ":opt_set", - ":platform", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", - ], -) - -cc_library( - name = "pack_sse42", - srcs = [ - "pack_sse42.cc", - ], - copts = RUY_COPTS_BUILT_FOR_SSE42, - deps = [ - ":check_macros", - ":matrix", - ":opt_set", - ":pack_common", - ":path", - ":platform", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", - ], -) - -cc_library( - name = "have_built_path_for_sse42", - srcs = [ - "have_built_path_for_sse42.cc", - ], - hdrs = [ - "have_built_path_for.h", - ], - copts = RUY_COPTS_BUILT_FOR_SSE42, - deps = [ - ":opt_set", - ":platform", - ], -) -# End: SSE42 compilation units. - -# AVX-VNNI compilation units. -# -# TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -# Optimization is not finished. In particular the dimensions of the kernel -# blocks can be changed as desired. -# -# These must use the same compiler options. -RUY_COPTS_BUILT_FOR_AVX_VNNI = ruy_copts_base() + ruy_copts_avxvnni() - -cc_library( - name = "kernel_avxvnni", - srcs = [ - "kernel_avxvnni.cc", - ], - copts = RUY_COPTS_BUILT_FOR_AVX_VNNI, - deps = [ - ":check_macros", - ":kernel_common", - ":opt_set", - ":platform", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", - ], -) - -cc_library( - name = "pack_avxvnni", - srcs = [ - "pack_avxvnni.cc", - ], - copts = RUY_COPTS_BUILT_FOR_AVX_VNNI, - deps = [ - ":check_macros", - ":matrix", - ":opt_set", - ":pack_common", - ":path", - ":platform", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", - ], -) - -cc_library( - name = "have_built_path_for_avxvnni", - srcs = [ - "have_built_path_for_avxvnni.cc", - ], - hdrs = [ - "have_built_path_for.h", - ], - copts = RUY_COPTS_BUILT_FOR_AVX_VNNI, - deps = [ - ":opt_set", - ":platform", - ], -) -# End: AVX-VNNI compilation units. - -cc_library( - name = "kernel", - hdrs = [ - "kernel.h", - "kernel_common.h", - ], - copts = ruy_copts_base(), - deps = [ - ":check_macros", - ":common", - ":internal_matrix", - ":kernel_arm", # fixdeps: keep - ":kernel_avx2", # fixdeps: keep - ":kernel_avx512", # fixdeps: keep - ":kernel_avxvnni", # fixdeps: keep - ":kernel_common", - ":kernel_sse42", # fixdeps: keep - ":matrix", - ":opt_set", - ":path", - ":platform", - ":side_pair", - ":size_util", - ":spec", - ":tune", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", - ], -) - -cc_library( - name = "pack", - hdrs = [ - "pack.h", - "pack_common.h", - ], - copts = ruy_copts_base(), - deps = [ - ":check_macros", - ":common", - ":internal_matrix", - ":matrix", - ":opt_set", - ":pack_arm", # fixdeps: keep - ":pack_avx2", # fixdeps: keep - ":pack_avx512", # fixdeps: keep - ":pack_avxvnni", # fixdeps: keep - ":pack_common", - ":pack_sse42", # fixdeps: keep - ":path", - ":platform", - ":tune", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", - ], -) - -cc_library( - name = "have_built_path_for", - hdrs = [ - "have_built_path_for.h", - ], - deps = [ - ":have_built_path_for_avx2", - ":have_built_path_for_avx512", - ":have_built_path_for_avxvnni", - ":have_built_path_for_sse42", - ":platform", - ], -) - -cc_library( - name = "context", - srcs = [ - "context.cc", - ], - hdrs = [ - "context.h", - ], - copts = ruy_copts_base(), - deps = [ - ":allocator", - ":check_macros", - ":detect_arm", - ":detect_x86", - ":have_built_path_for", - ":path", - ":platform", - ":prepacked_cache", - ":thread_pool", - ":trace", - ":tune", - ], -) - -cc_test( - name = "context_test", - srcs = ["context_test.cc"], - deps = [ - ":context", - ":path", - ":platform", - "@com_google_googletest//:gtest", - ], -) - -cc_library( - name = "trmul_params", - hdrs = ["trmul_params.h"], - copts = ruy_copts_base(), - deps = [ - ":internal_matrix", - ":side_pair", - ":tune", - ], -) - -cc_library( - name = "trmul", - srcs = ["trmul.cc"], - hdrs = ["trmul.h"], - copts = ruy_copts_base(), - deps = [ - ":allocator", - ":block_map", - ":check_macros", - ":common", - ":context", - ":internal_matrix", - ":matrix", - ":opt_set", - ":side_pair", - ":size_util", - ":spec", - ":thread_pool", - ":trace", - ":trmul_params", - ":tune", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", - ], -) - -# The main library. -cc_library( - name = "ruy", - srcs = [ - "dispatch.h", - "prepack.h", - ], - hdrs = [ - "ruy.h", - "ruy_advanced.h", - ], - copts = ruy_copts_base(), - deps = [ - ":check_macros", - ":common", - ":context", - ":internal_matrix", - ":kernel", - ":matrix", - ":opt_set", - ":pack", - ":path", - ":prepacked_cache", - ":side_pair", - ":size_util", - ":spec", - ":trmul", - ":trmul_params", - ":tune", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", - ], -) - -# Usage examples. -cc_binary( - name = "example", - srcs = ["example.cc"], - deps = [":ruy"], -) - -# Usage examples of the advanced API. -cc_binary( - name = "example_advanced", - srcs = ["example_advanced.cc"], - deps = [":ruy"], -) - -# Small library to query PMU counters, for benchmark only -cc_library( - name = "pmu", - testonly = True, - srcs = ["pmu.cc"], - hdrs = ["pmu.h"], - copts = ruy_copts_base(), - deps = [":check_macros"], -) - -# Testing framework. -cc_library( - name = "test_lib", - testonly = True, - hdrs = ["test.h"], - copts = ruy_copts_base(), - # need defines, not copts, because it's controlling a header, test.h - defines = ruy_test_ext_defines(), - linkopts = select({ - ":windows": [], - "//conditions:default": ["-lm"], - }), - deps = [ - ":matrix", - ":pmu", - ":ruy", - ":spec", - ":time", - "@com_google_googletest//:gtest", - ":platform", - "//tensorflow/lite/experimental/ruy/ruy/profiler:profiler", - ] + ruy_test_ext_deps(), -) - -ruy_benchmark( - name = "benchmark", - srcs = ["benchmark.cc"], - copts = ruy_copts_base(), - lhs_rhs_accum_dst = [ - ("f32", "f32", "f32", "f32"), - ("u8", "u8", "i32", "u8"), - ("i8", "i8", "i32", "u8"), - ("i8", "i8", "i32", "i8"), - ("u8", "u8", "i32", "i16"), - ("i8", "i8", "i32", "i32"), - ], - deps = [ - "//tensorflow/lite/experimental/ruy/ruy:test_lib", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", - ], -) - -ruy_test( - name = "test_fast", - srcs = ["test_fast.cc"], - copts = ruy_copts_base(), - lhs_rhs_accum_dst = [ - ("f32", "f32", "f32", "f32"), - ("f64", "f32", "f64", "f32"), - ("f32", "f64", "f64", "f64"), - ("u8", "u8", "i32", "u8"), - ("i8", "i8", "i32", "i8"), - ("i8", "u8", "i32", "i8"), - ("u8", "u8", "i32", "i16"), - ("i8", "i8", "i32", "i32"), - ("i8", "u8", "i32", "i32"), - ], - deps = [ - "//tensorflow/lite/experimental/ruy/ruy:test_lib", - "@com_google_googletest//:gtest_main", - ], -) - -ruy_test( - name = "test_slow", - srcs = ["test_slow.cc"], - copts = ruy_copts_base(), - lhs_rhs_accum_dst = [ - ("f32", "f32", "f32", "f32"), - ("u8", "u8", "i32", "u8"), - ("i8", "i8", "i32", "i8"), - ("u8", "u8", "i32", "i16"), - ("i8", "i8", "i32", "i32"), - ], - tags = ["slow"], - deps = [ - "//tensorflow/lite/experimental/ruy/ruy:test_lib", - "@com_google_googletest//:gtest_main", - ], -) - -ruy_test( - name = "test_special_specs", - srcs = ["test_special_specs.cc"], - copts = ruy_copts_base(), - lhs_rhs_accum_dst = [ - ("f32", "f32", "f32", "f32"), - ("u8", "u8", "i32", "u8"), - ("u8", "u8", "i32", "i16"), - ], - deps = [ - "//tensorflow/lite/experimental/ruy/ruy:test_lib", - "@com_google_googletest//:gtest_main", - ], -) diff --git a/tensorflow/lite/experimental/ruy/ruy/allocator.cc b/tensorflow/lite/experimental/ruy/ruy/allocator.cc deleted file mode 100644 index 2c507561f2f..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/allocator.cc +++ /dev/null @@ -1,51 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/allocator.h" - -#include -#include - -#ifdef _WIN32 -#include -#endif - -namespace ruy { - -namespace detail { - -void *SystemAlignedAlloc(std::ptrdiff_t num_bytes) { -#ifdef _WIN32 - return _aligned_malloc(num_bytes, kMinimumBlockAlignment); -#else - void *ptr; - if (posix_memalign(&ptr, kMinimumBlockAlignment, num_bytes)) { - return nullptr; - } - return ptr; -#endif -} - -void SystemAlignedFree(void *ptr) { -#ifdef _WIN32 - _aligned_free(ptr); -#else - free(ptr); -#endif -} - -} // namespace detail - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/allocator.h b/tensorflow/lite/experimental/ruy/ruy/allocator.h deleted file mode 100644 index 56aa0eef8f9..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/allocator.h +++ /dev/null @@ -1,185 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_ALLOCATOR_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_ALLOCATOR_H_ - -#include -#include -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/size_util.h" - -namespace ruy { - -namespace detail { - -inline void* VoidPtrAdd(void* p, std::ptrdiff_t offset) { - RUY_DCHECK(p); - std::uintptr_t addr = reinterpret_cast(p) + offset; - return reinterpret_cast(addr); -} - -// Minimum alignment for blocks. -// -// Considerations: -// - This needs to be at least the alignment of any usual data type. -// - It's useful that this is at least the size of a cache line to limit -// possible cache side effects (if only on performance behavior). -// - It's useful that this is at least the size of SIMD registers, as -// some SIMD instruction sets have at least performance behavior -// differences (e.g. NEON) or even different requirements (e.g. SSE) -// based on that. -// - It's useful that this is at least the size of an "exclusive reservation -// granule" on ARM, meaning that if we use this Allocator to allocate -// an atomic variable, there will be no side effects from other things -// contending for exclusive/atomic memory accesses to it. While the -// ARM reference manual mentions that this granule size may be as large -// as 2048 bytes, in practice we observe it to be 64 bytes. It can -// be queried cheaply, at runtime, from userspace, if needed. -static constexpr std::ptrdiff_t kMinimumBlockAlignment = 64; - -// Primitive allocation functions obtaining aligned memory from the -// operating system. -void* SystemAlignedAlloc(std::ptrdiff_t num_bytes); -void SystemAlignedFree(void* ptr); - -// Specialized allocator designed to converge to a steady-state where all -// allocations are bump-ptr allocations from an already-allocated buffer. -// -// To support these constraints, this allocator only supports two -// operations. -// - AllocateAlignedBytes: allocates a pointer to storage of a specified -// size, which must be aligned to kMinimumBlockAlignment. -// - FreeAll: frees all previous allocations (but retains the internal -// buffer to minimize future calls into the system allocator). -// -// This class is specialized for supporting just those two operations -// under this specific steady-state usage pattern. Extending this class -// with new allocation interfaces that don't fit that pattern is probably not -// the right choice. Instead, build a new class on top of -// SystemAlignedAlloc/SystemAlignedFree. -// -// All operations happen on aligned blocks for simplicity. -class AlignedAllocator { - public: - void operator=(const AlignedAllocator&) = delete; - ~AlignedAllocator() { - FreeAll(); - SystemAlignedFree(ptr_); - } - - void* AllocateAlignedBytes(std::ptrdiff_t num_bytes) { - RUY_DCHECK_GT(num_bytes, 0); - RUY_DCHECK((num_bytes & (kMinimumBlockAlignment - 1)) == 0); - if (void* p = AllocateFast(num_bytes)) { - return p; - } - return AllocateSlow(num_bytes); - } - - void FreeAll() { - current_ = 0; - if (fallback_blocks_.empty()) { - return; - } - - // No rounding-up of the size means linear instead of logarithmic - // bound on the number of allocation in some worst-case calling patterns. - // This is considered worth it because minimizing memory usage is important - // and actual calling patterns in applications that we care about still - // reach the no-further-allocations steady state in a small finite number - // of iterations. - std::ptrdiff_t new_size = size_ + fallback_blocks_total_size_; - SystemAlignedFree(ptr_); - ptr_ = SystemAlignedAlloc(new_size); - size_ = new_size; - - for (void* p : fallback_blocks_) { - SystemAlignedFree(p); - } - fallback_blocks_.clear(); - fallback_blocks_total_size_ = 0; - } - - private: - void* AllocateFast(std::ptrdiff_t num_bytes) { - if (current_ + num_bytes > size_) { - return nullptr; - } - void* ret = VoidPtrAdd(ptr_, current_); - current_ += num_bytes; - return ret; - } - - void* AllocateSlow(std::ptrdiff_t num_bytes) { - void* p = SystemAlignedAlloc(num_bytes); - fallback_blocks_total_size_ += num_bytes; - fallback_blocks_.push_back(p); - return p; - } - - // Theory of operation: - // - // - ptr_, current_, and size_ implement a basic bump-ptr allocator. - // - // - in AllocateAlignedBytes, the fast path is just a bump-ptr - // allocation. If our bump-ptr allocator doesn't have enough space for an - // allocation, then we allocate a block from the system allocator to - // service the allocation request. We save that block in fallback_blocks_ - // and track the total size of the fallback blocks in - // fallback_blocks_total_size_. - // - // - in FreeAll, the fast path just resets the bump-ptr allocator. If - // there are any fallback blocks, we free them and reallocate the - // bump-ptr allocator's buffer so that the next sequence of allocations - // will hopefully not need any fallback blocks. - void* ptr_ = nullptr; - std::ptrdiff_t current_ = 0; - std::ptrdiff_t size_ = 0; - std::vector fallback_blocks_; - std::ptrdiff_t fallback_blocks_total_size_ = 0; -}; - -} // namespace detail - -// The main Allocator class, with a convenient interface for allocating a -// typed buffer. -class Allocator { - public: - void* AllocateBytes(std::ptrdiff_t num_bytes) { - if (num_bytes == 0) { - return nullptr; - } - return aligned.AllocateAlignedBytes( - round_up_pot(num_bytes, detail::kMinimumBlockAlignment)); - } - template - void Allocate(std::ptrdiff_t count, Pointer* out) { - using T = typename std::pointer_traits::element_type; - *out = static_cast(AllocateBytes(count * sizeof(T))); - } - - void FreeAll() { aligned.FreeAll(); } - - private: - detail::AlignedAllocator aligned; -}; - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_ALLOCATOR_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/allocator_test.cc b/tensorflow/lite/experimental/ruy/ruy/allocator_test.cc deleted file mode 100644 index 1584b86b4cc..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/allocator_test.cc +++ /dev/null @@ -1,103 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/allocator.h" - -#include - -namespace ruy { -namespace { - -TEST(AllocatorTest, ReturnsValidMemory) { - Allocator allocator; - int *p; - allocator.Allocate(1, &p); - ASSERT_NE(p, nullptr); - - // If this is bogus memory, ASan will cause this test to fail. - *p = 42; - - allocator.FreeAll(); -} - -TEST(AllocatorTest, NoLeak) { - Allocator allocator; - // Allocate and free some ridiculously large total amount of memory, so - // that a leak will hopefully cause some sort of resource exhaustion. - // - // Despite the large number of allocations, this test is actually quite - // fast, since our fast-path allocation logic is very fast. - constexpr int kNumAllocations = 100 * 1024; - constexpr int kAllocationSize = 1024 * 1024; - for (int i = 0; i < kNumAllocations; i++) { - char *p; - allocator.Allocate(kAllocationSize, &p); - allocator.FreeAll(); - } -} - -TEST(AllocatorTest, IncreasingSizes) { - Allocator allocator; - // Allocate sizes that increase by small amounts across FreeAll calls. - for (int i = 1; i < 100 * 1024; i++) { - char *p; - allocator.Allocate(i, &p); - allocator.FreeAll(); - } -} - -TEST(AllocatorTest, ManySmallAllocations) { - Allocator allocator; - // Allocate many small allocations between FreeAll calls. - for (int i = 0; i < 10 * 1024; i += 100) { - for (int j = 0; j < i; j++) { - char *p; - allocator.Allocate(1, &p); - } - allocator.FreeAll(); - } -} - -TEST(AllocatorTest, DestructorHandlesMainBumpPtr) { - // This is a white-box test. - Allocator allocator; - allocator.AllocateBytes(1); - allocator.FreeAll(); - // After the call to FreeAll, the allocator will consolidate all of the memory - // into the main bump-ptr allocator's block, which we then expect to be freed - // in the destructor. - // - // We have no test assertions -- we primarily expect that this trigger a leak - // checker and cause the test to fail. -} - -TEST(AllocatorTest, DestructorHandlesFallbackBlocks) { - // This is a white-box test. - Allocator allocator; - // Since we just created the allocator, this will allocate a fallback block, - // which we then expect to be freed in the destructor. - // - // We have no test assertions -- we primarily expect that this trigger a leak - // checker and cause the test to fail. - allocator.AllocateBytes(1); -} - -} // namespace -} // namespace ruy - -int main(int argc, char **argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/tensorflow/lite/experimental/ruy/ruy/benchmark.cc b/tensorflow/lite/experimental/ruy/ruy/benchmark.cc deleted file mode 100644 index 406345cec06..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/benchmark.cc +++ /dev/null @@ -1,196 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/test.h" - -namespace ruy { - -using LhsScalar = RUY_TEST_LHSSCALAR; -using RhsScalar = RUY_TEST_RHSSCALAR; -using AccumScalar = RUY_TEST_ACCUMSCALAR; -using DstScalar = RUY_TEST_DSTSCALAR; -using TestSetType = - TestSet>; - -struct BenchmarkShape { - int rows; - int depth; - int cols; - int symm_lhs; - int symm_rhs; -}; - -template -std::vector>> BenchmarkRCC( - const BenchmarkShape& shape) { - TestSetType test_set; - test_set.rows = shape.rows; - test_set.depth = shape.depth; - test_set.cols = shape.cols; - test_set.lhs_order = Order::kRowMajor; - test_set.rhs_order = Order::kColMajor; - test_set.dst_order = Order::kColMajor; - test_set.layout_style = LayoutStyle::kPackedLinear; - test_set.benchmark = true; - const int asymmetry_lhs = shape.symm_lhs ? 0 : 1; - const int asymmetry_rhs = shape.symm_rhs ? 0 : 1; - test_set.lhs_zero_point = SymmetricZeroPoint() + asymmetry_lhs; - test_set.rhs_zero_point = SymmetricZeroPoint() + asymmetry_rhs; - test_set.use_specified_zero_points = true; - test_set.perchannel = GetBoolEnvVarOrFalse("PERCHANNEL"); - test_set.benchmark_prepack_lhs = GetBoolEnvVarOrFalse("PREPACK_LHS"); - test_set.benchmark_prepack_rhs = GetBoolEnvVarOrFalse("PREPACK_RHS"); - test_set.Run(); - return std::move(test_set.results); -} - -std::vector ParseCommaSeparatedInts( - const std::string& comma_separated_ints) { - std::vector result; - for (std::size_t pos = 0; pos < comma_separated_ints.size();) { - std::size_t delim_pos = comma_separated_ints.find(',', pos); - if (delim_pos == std::string::npos) { - delim_pos = comma_separated_ints.size(); - } - result.push_back( - std::stoi(comma_separated_ints.substr(pos, delim_pos - pos))); - pos = delim_pos + 1; - } - return result; -} - -void Benchmark() { - const bool symm_lhs = std::is_floating_point::value || - GetBoolEnvVarOrFalse("SYMM_LHS"); - const bool symm_rhs = std::is_floating_point::value || - GetBoolEnvVarOrFalse("SYMM_RHS"); - const bool benchmark_cubic = GetBoolEnvVarOrFalse("RUY_BENCHMARK_CUBIC") || - GetBoolEnvVarOrFalse("RUY_BENCHMARK_CUBIC_LIST"); - const int explicit_rows = GetIntEnvVarOrZero("ROWS"); - const int explicit_cols = GetIntEnvVarOrZero("COLS"); - const int explicit_depth = GetIntEnvVarOrZero("DEPTH"); - - std::vector shapes; - - if (benchmark_cubic) { - std::vector sizes; - const char* benchmark_cubic_list_env = getenv("RUY_BENCHMARK_CUBIC_LIST"); - if (benchmark_cubic_list_env) { - sizes = ParseCommaSeparatedInts(benchmark_cubic_list_env); - } else { - // Often 8 is used for this multiplier, but to check teeny sizes one can - // use 1. - static constexpr int cubic_size_multiplier = 8; - for (int i = 2 * cubic_size_multiplier; - i <= (512 * cubic_size_multiplier); i *= 2) { - sizes.push_back(i); - if (i < (512 * cubic_size_multiplier)) { - sizes.push_back(i * 3 / 2); - } - } - } - for (int i : sizes) { - BenchmarkShape shape; - // Even in cubic mode, one may still override an individual dimension - // to allow testing a batch of rectangular sizes. - shape.rows = explicit_rows ? explicit_rows : i; - shape.cols = explicit_cols ? explicit_cols : i; - shape.depth = explicit_depth ? explicit_depth : i; - shape.symm_lhs = symm_lhs; - shape.symm_rhs = symm_rhs; - shapes.push_back(shape); - } - } else { - BenchmarkShape shape; - shape.rows = explicit_rows; - shape.cols = explicit_cols; - shape.depth = explicit_depth; - if (!shape.rows || !shape.depth || !shape.cols) { - fprintf(stderr, - "Please specify positive sizes with these env vars: ROWS, DEPTH, " - "COLS.\n"); - exit(1); - } - shape.symm_lhs = symm_lhs; - shape.symm_rhs = symm_rhs; - shapes.push_back(shape); - } - - for (int i = 0; i < shapes.size(); i++) { - const auto& shape = shapes[i]; - const auto& results = BenchmarkRCC(shape); - if (i == 0) { - if (benchmark_cubic) { - printf("size"); - for (const auto& result : results) { - if (results.size() > 1) { - printf(",%s:Gop/s", PathName(*result).c_str()); - } else { - printf(",Gop/s"); - } - if (GetBoolEnvVarOrFalse("RUY_BENCHMARK_PMU")) { - printf( - ",l1_refill,l2_refill,l3_refill,l1tlb_refill,l2tlb_refill," - "mispred,frontend_stall,backend_stall"); - } - } - printf("\n"); - } else { - printf("path,shape,Gop/s\n"); - } - fflush(stdout); - } - if (benchmark_cubic) { - printf("%d", shape.rows); - for (const auto& result : results) { - printf(",%.4g", 2.0e-9 * shape.rows * shape.cols * shape.depth / - result->latency); - if (GetBoolEnvVarOrFalse("RUY_BENCHMARK_PMU")) { - printf(",%.3g,%.3g,%.3g,%.3g,%.3g,%.3g,%.3g,%.3g", - result->l1_refill_rate, result->l2_refill_rate, - result->l3_refill_rate, result->l1tlb_refill_rate, - result->l2tlb_refill_rate, result->mispred_rate, - result->frontend_stall_rate, result->backend_stall_rate); - } - } - printf("\n"); - fflush(stdout); - } else { - for (const auto& result : results) { - printf( - "%s,%dx%dx%d,%.4g", PathName(*result).c_str(), shape.rows, - shape.depth, shape.cols, - 2.0e-9 * shape.rows * shape.cols * shape.depth / result->latency); - if (GetBoolEnvVarOrFalse("RUY_BENCHMARK_PMU")) { - printf(",%.3g,%.3g,%.3g,%.3g,%.3g,%.3g,%.3g,%.3g", - result->l1_refill_rate, result->l2_refill_rate, - result->l3_refill_rate, result->l1tlb_refill_rate, - result->l2tlb_refill_rate, result->mispred_rate, - result->frontend_stall_rate, result->backend_stall_rate); - } - printf("\n"); - } - fflush(stdout); - } - } -} - -} // namespace ruy - -int main() { ruy::Benchmark(); } diff --git a/tensorflow/lite/experimental/ruy/ruy/block_map.cc b/tensorflow/lite/experimental/ruy/ruy/block_map.cc deleted file mode 100644 index 32781d82ad3..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/block_map.cc +++ /dev/null @@ -1,486 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/block_map.h" - -#include -#include - -#ifdef RUY_MAKEBLOCKMAP_DEBUG -#include -#include -#include -#endif - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" -#include "tensorflow/lite/experimental/ruy/ruy/size_util.h" - -namespace ruy { - -namespace { - -void DecodeTraversalLinear(int size_log2, std::uint32_t square_index, - SidePair* local_pos) { - (*local_pos)[Side::kLhs] = square_index & ((1 << size_log2) - 1); - (*local_pos)[Side::kRhs] = square_index >> size_log2; -} - -void DecodeTraversalFractalZ(std::uint32_t square_index, - SidePair* local_pos) { - const std::uint32_t n1 = square_index; - const std::uint32_t n2 = (n1 & 0x99999999u) | ((n1 & 0x44444444u) >> 1) | - ((n1 & 0x22222222u) << 1); - const std::uint32_t n4 = (n2 & 0xc3c3c3c3u) | ((n2 & 0x30303030u) >> 2) | - ((n2 & 0x0c0c0c0cu) << 2); - const std::uint32_t n8 = (n4 & 0xf00ff00fu) | ((n4 & 0x0f000f00u) >> 4) | - ((n4 & 0x00f000f0u) << 4); - const std::uint32_t n16 = (n8 & 0xff0000ffu) | ((n8 & 0x00ff0000u) >> 8) | - ((n8 & 0x0000ff00u) << 8); - (*local_pos)[Side::kLhs] = n16 & 0xffff; - (*local_pos)[Side::kRhs] = n16 >> 16; -} - -void DecodeTraversalFractalU(std::uint32_t square_index, - SidePair* local_pos) { - DecodeTraversalFractalZ(square_index, local_pos); - // Change fractal z-order to u-order - (*local_pos)[Side::kLhs] ^= (*local_pos)[Side::kRhs]; -} - -// Code inspired by the sample code in -// https://en.wikipedia.org/wiki/Hilbert_curve -// The main optimization is to avoid hard-to-predict conditional branches -// based on the bits of the square_index parameter. -void DecodeTraversalFractalHilbert(int size_log2, std::uint32_t square_index, - SidePair* local_pos) { - std::uint32_t t = square_index; - std::uint32_t x = 0; - std::uint32_t y = 0; - // Easy-to-predict for loop, the number of iterations is the same for - // an entire GEMM. - for (int sb = 0; sb < size_log2; sb++) { - std::uint32_t s = 1 << sb; - bool rx = t & 2; - bool ry = (t & 1) ^ rx; - std::uint32_t tmp = rx ? (s - 1 - x) : x; - x = ry ? x : rx ? (s - 1 - y) : y; - y = ry ? (y + s) : tmp; - x = rx ? (x + s) : x; - t >>= 2; - } - (*local_pos)[Side::kLhs] = y; - (*local_pos)[Side::kRhs] = x; -} - -} // end anonymous namespace - -void GetBlockByIndex(const BlockMap& block_map, int index, - SidePair* block) { - profiler::ScopeLabel label("GetBlockByIndex"); - const std::uint32_t index_u32 = index; - - const std::uint32_t num_blocks_per_local_curve = - 1u << (2 * block_map.num_blocks_base_log2); - const std::uint32_t square_index = - index_u32 & (num_blocks_per_local_curve - 1); - - const int size_log2 = block_map.num_blocks_base_log2; - SidePair local_pos; - switch (block_map.traversal_order) { - case BlockMapTraversalOrder::kFractalZ: - DecodeTraversalFractalZ(square_index, &local_pos); - break; - case BlockMapTraversalOrder::kFractalU: - DecodeTraversalFractalU(square_index, &local_pos); - break; - case BlockMapTraversalOrder::kFractalHilbert: - DecodeTraversalFractalHilbert(size_log2, square_index, &local_pos); - break; - default: - RUY_DCHECK(block_map.traversal_order == BlockMapTraversalOrder::kLinear); - DecodeTraversalLinear(size_log2, square_index, &local_pos); - break; - } - - const std::uint32_t rectangular_index = - index_u32 >> 2 * block_map.num_blocks_base_log2; - for (Side side : {Side::kLhs, Side::kRhs}) { - const std::uint32_t mask = (1u << block_map.rectangularness_log2[side]) - 1; - const int rectangular_offset = (rectangular_index & mask) - << block_map.num_blocks_base_log2; - (*block)[side] = local_pos[side] + rectangular_offset; - } -} - -BlockMapTraversalOrder GetTraversalOrder(int rows, int cols, int depth, - int lhs_scalar_size, - int rhs_scalar_size, - int local_data_cache_size, - int shared_data_cache_size) { - const int kFractalOptSets = - RUY_OPT_FRACTAL_Z | RUY_OPT_FRACTAL_U | RUY_OPT_FRACTAL_HILBERT; - const int working_set_size = - (lhs_scalar_size * rows + rhs_scalar_size * cols) * depth; - if (RUY_OPT_ENABLED(kFractalOptSets) && - (working_set_size > local_data_cache_size)) { - if (RUY_OPT_ENABLED(RUY_OPT_FRACTAL_HILBERT) && - (working_set_size > shared_data_cache_size)) { - return BlockMapTraversalOrder::kFractalHilbert; - } else if (RUY_OPT_ENABLED(RUY_OPT_FRACTAL_U)) { - return BlockMapTraversalOrder::kFractalU; - } else { - return BlockMapTraversalOrder::kFractalZ; - } - } else { - return BlockMapTraversalOrder::kLinear; - } -} - -namespace { - -int floor_log2_quotient(int num, int denom) { - if (num <= denom) { - return 0; - } - int log2_quotient = floor_log2(num) - ceil_log2(denom); - if ((denom << (log2_quotient + 1)) <= num) { - log2_quotient++; - } - return log2_quotient; -} - -// Computes the rectangularness of the matrix shape (rows, cols). This is -// essentially just the log2 of the quotient (rows / cols). The kernel_rows and -// kernel_cols only get into the picture for clamping bounds but don't affect -// the generic computation. -void GetRectangularness(int rows, int cols, int kernel_rows, int kernel_cols, - int* rows_rectangularness_log2, - int* cols_rectangularness_log2) { - *rows_rectangularness_log2 = 0; - *cols_rectangularness_log2 = 0; - - // In GEMV-ish cases, that is when kernel blocks are as narrow as the kernel - // itself, we risk having too small kernel blocks for good kernel - // amortization. We avoid that by limiting recangularness so that kernel - // blocks are not too tiny at least in that dimension. Specifically, we try to - // have at least (2^min_kernel_inner_loop_runs_log2) kernels fitting in each - // kernel block along the large dimension. - const int min_kernel_inner_loop_runs_log2 = 3; - if (rows > cols) { - int cols_of_kernel_inner_loop_runs_log2 = - ceil_log2(cols) - pot_log2(kernel_cols); - int min_rows_of_kernel_inner_loop_runs_log2 = - std::max(0, min_kernel_inner_loop_runs_log2 - - cols_of_kernel_inner_loop_runs_log2); - *rows_rectangularness_log2 = - std::min(floor_log2_quotient(rows, cols), - std::max(0, floor_log2(rows) - pot_log2(kernel_rows) - - min_rows_of_kernel_inner_loop_runs_log2)); - // Sanity check that we did not over-estimate rows_rectangularness_log2. - RUY_DCHECK_GE(rows >> *rows_rectangularness_log2, cols); - } else if (cols > rows) { - int rows_of_kernel_inner_loop_runs_log2 = - ceil_log2(rows) - pot_log2(kernel_rows); - int min_cols_of_kernel_inner_loop_runs_log2 = - std::max(0, min_kernel_inner_loop_runs_log2 - - rows_of_kernel_inner_loop_runs_log2); - *cols_rectangularness_log2 = - std::min(floor_log2_quotient(cols, rows), - std::max(0, floor_log2(cols) - pot_log2(kernel_cols) - - min_cols_of_kernel_inner_loop_runs_log2)); - // Sanity check that we did not over-estimate cols_rectangularness_log2. - RUY_DCHECK_GE(cols >> *cols_rectangularness_log2, rows); - } - RUY_DCHECK(!*rows_rectangularness_log2 || !*cols_rectangularness_log2); -} - -// Computes a 'multithreading score'. When multithreading, we need there to -// be at least as many tiles as there are threads, and hopefully -// substantially more than that, so we benefit from ruy's ability to -// dispatch fine-grained workloads to threads. -int GetMultithreadingScore(int block_size_log2, int rows, int cols, - int tentative_thread_count) { - const int num_full_blocks_of_rows = rows >> block_size_log2; - const int num_full_blocks_of_cols = cols >> block_size_log2; - const int candidate_num_full_blocks_log2 = floor_log2( - std::max(1, num_full_blocks_of_rows * num_full_blocks_of_cols)); - - // The values here have been tuned on ARM Cortex-A55. - // We expect this to have to be tuned differently for other CPUs. - if (tentative_thread_count == 1) { - return 0; - } else { - const int blocks_per_thread_log2 = - candidate_num_full_blocks_log2 - ceil_log2(tentative_thread_count); - if (blocks_per_thread_log2 < 0) { - return -64; - } else if (blocks_per_thread_log2 == 0) { - return -16; - } else if (blocks_per_thread_log2 == 1) { - return -8; - } else if (blocks_per_thread_log2 == 2) { - return 0; - } else if (blocks_per_thread_log2 == 3) { - return 8; - } else { - return 16; - } - } -} - -// Computes a 'cache locality score'. -int GetCacheLocalityScore(int block_size_log2, int rows, int cols, int depth, - int kernel_rows_log2, int kernel_cols_log2, - int lhs_scalar_size, int rhs_scalar_size, Path path, - int local_data_cache_size) { - // In the narrow case (e.g. matrix*vector), each byte of the big operand - // matrix (either LHS or RHS) is traversed only once, so any notion of data - // locality is irrelevant. Ignore the 'cache locality score' by forcing it to - // be 0 in that case. - if (rows <= (1 << kernel_rows_log2) || cols <= (1 << kernel_cols_log2)) { - return 0; - } - const int block_rows = std::min(1 << block_size_log2, rows); - const int block_cols = std::min(1 << block_size_log2, cols); - const int total_read_bytes = - (lhs_scalar_size * block_rows + rhs_scalar_size * block_cols) * depth; - const int total_read_bytes_log2 = ceil_log2(total_read_bytes); - const int nonlocality_log2 = - total_read_bytes_log2 - floor_log2(local_data_cache_size); - // The values here have been tuned on ARM Cortex-A55. - // We expect this to have to be tuned differently for other CPUs. - if (nonlocality_log2 < -1) { - return 64; - } else if (nonlocality_log2 == -1) { - return 56; - } else if (nonlocality_log2 == 0) { - return 48; - } else if (nonlocality_log2 == 1) { - return 32; - } else if (nonlocality_log2 == 2) { - return 16; - } else if (nonlocality_log2 == 3) { - return 0; - } else { - return -64; - } -} - -// Compute a 'kernel amortization score'. This is the notion that very small -// tiles result in more overhead outside of kernels, more complex memory -// access patterns and less benefits from ruy's fat kernels, so we reward -// larger blocks more than smaller ones. -int GetKernelAmortizationScore(int block_size_log2, int rows, int cols, - int kernel_rows_log2, int kernel_cols_log2) { - const int block_rows = std::min(1 << block_size_log2, rows); - const int block_cols = std::min(1 << block_size_log2, cols); - const int kernels_per_block_log2 = - floor_log2(block_rows * block_cols) - kernel_rows_log2 - kernel_cols_log2; - RUY_DCHECK_GE(kernels_per_block_log2, 0); - // The values here have been tuned on ARM Cortex-A55. - // We expect this to have to be tuned differently for other CPUs. - if (kernels_per_block_log2 == 0) { - return 0; - } else if (kernels_per_block_log2 == 1) { - return 8; - } else if (kernels_per_block_log2 == 2) { - return 16; - } else if (kernels_per_block_log2 == 3) { - return 24; - } else if (kernels_per_block_log2 == 4) { - return 32; - } else if (kernels_per_block_log2 == 5) { - return 40; - } else if (kernels_per_block_log2 == 6) { - return 48; - } else if (kernels_per_block_log2 == 7) { - return 56; - } else { - return 64; - } -} - -} // namespace - -void MakeBlockMap(int rows, int cols, int depth, int kernel_rows, - int kernel_cols, int lhs_scalar_size, int rhs_scalar_size, - int tentative_thread_count, Path path, - int local_data_cache_size, int shared_data_cache_size, - BlockMap* block_map) { - profiler::ScopeLabel label("MakeBlockMap"); - -#ifdef RUY_MAKEBLOCKMAP_DEBUG -#if RUY_MAKEBLOCKMAP_DEBUG >= 2 - static constexpr bool debug_everytime = true; -#else - static constexpr bool debug_everytime = false; -#endif - static bool firsttime = true; - if (firsttime || debug_everytime) { - fprintf(stderr, - "MakeBlockMap(rows=%d, cols=%d, depth=%d, kernel_rows=%d, " - "kernel_cols=%d, lhs_scalar_size=%d, rhs_scalar_size=%d, " - "tentative_thread_count=%d)\n", - rows, cols, depth, kernel_rows, kernel_cols, lhs_scalar_size, - rhs_scalar_size, tentative_thread_count); - } -#endif - - RUY_DCHECK_GE(rows, kernel_rows); - RUY_DCHECK_GE(cols, kernel_cols); - RUY_DCHECK_EQ(rows % kernel_rows, 0); - RUY_DCHECK_EQ(cols % kernel_cols, 0); - - block_map->traversal_order = - GetTraversalOrder(rows, cols, depth, lhs_scalar_size, rhs_scalar_size, - local_data_cache_size, shared_data_cache_size); - - int rows_rectangularness_log2 = 0; - int cols_rectangularness_log2 = 0; - GetRectangularness(rows, cols, kernel_rows, kernel_cols, - &rows_rectangularness_log2, &cols_rectangularness_log2); - - const int kernel_rows_log2 = pot_log2(kernel_rows); - const int kernel_cols_log2 = pot_log2(kernel_cols); - const int kernel_size_log2 = std::max(kernel_cols_log2, kernel_rows_log2); - - const int size = std::min(rows, cols); - const int size_log2 = std::max(kernel_size_log2, floor_log2(size)); - - RUY_DCHECK_GE(size_log2, kernel_size_log2); - - // We are going to try candidate values for block_size_log2 ranging from - // kernel_size_log2 to (kernel_size_log2 + kMaxKernelsPerBlockLog2). - // For each of them we will compute a 'score' by adding individual scores - // for a few different considerations, all of which is entirely empirical. - // The values (and possibly the logic) around here are all subject to tuning - // based on benchmarks on different hardware. The current values are based - // on benchmarking on Qualcomm S855 (big and little cores), arm64, - // kNeonDotprod, 8bit quantized path. Don't read too much into it, go ahead - // and tune this as needed to achieve good performance elsewhere. Use - // the unit test, block_map_test, to encode values that should be preserved - // on specific architectures. Use RUY_MAKEBLOCKMAP_DEBUG to help tuning this. - static constexpr int kMaxKernelsPerBlockLog2 = 6; - const int max_block_size_log2 = - std::min(size_log2, kernel_size_log2 + kMaxKernelsPerBlockLog2); - int best_score = std::numeric_limits::min(); - int best_score_block_size_log2 = -1; - for (int block_size_log2 = kernel_size_log2; - block_size_log2 <= max_block_size_log2; block_size_log2++) { - const int multithreading_score = GetMultithreadingScore( - block_size_log2, rows, cols, tentative_thread_count); - const int cache_locality_score = GetCacheLocalityScore( - block_size_log2, rows, cols, depth, kernel_rows_log2, kernel_cols_log2, - lhs_scalar_size, rhs_scalar_size, path, local_data_cache_size); - const int kernel_amortization_score = GetKernelAmortizationScore( - block_size_log2, rows, cols, kernel_rows_log2, kernel_cols_log2); - const int score = - multithreading_score + cache_locality_score + kernel_amortization_score; -#ifdef RUY_MAKEBLOCKMAP_DEBUG - if (firsttime || debug_everytime) { - fprintf(stderr, - "block_size_log2=%d: score=%d multithreading_score=%d " - "cache_locality_score=%d kernel_amortization_score=%d\n", - block_size_log2, score, multithreading_score, - cache_locality_score, kernel_amortization_score); - } -#endif - if (score >= best_score) { - best_score = score; - best_score_block_size_log2 = block_size_log2; - } - } - -#ifdef RUY_MAKEBLOCKMAP_DEBUG - if (firsttime || debug_everytime) { - fprintf(stderr, "best_score_block_size_log2=%d\n", - best_score_block_size_log2); - } - - static const char* explicit_block_size_log2_env = - getenv("RUY_MAKEBLOCKMAP_EXPLICIT_BLOCK_SIZE_LOG2"); - if (explicit_block_size_log2_env) { - best_score_block_size_log2 = std::stoi(explicit_block_size_log2_env); - if (firsttime || debug_everytime) { - fprintf(stderr, "Overridden best_score_block_size_log2=%d\n", - best_score_block_size_log2); - } - } - firsttime = false; -#endif - - int num_blocks_base_log2 = size_log2 - best_score_block_size_log2; - RUY_DCHECK_GE(num_blocks_base_log2, 0); - - const int num_blocks_of_rows_log2 = - num_blocks_base_log2 + rows_rectangularness_log2; - const int num_blocks_of_cols_log2 = - num_blocks_base_log2 + cols_rectangularness_log2; - - const int smallr = - round_down_pot(rows >> num_blocks_of_rows_log2, kernel_rows); - const int smallc = - round_down_pot(cols >> num_blocks_of_cols_log2, kernel_cols); - const int missr = - round_up_pot(rows - (smallr << num_blocks_of_rows_log2), kernel_rows) >> - pot_log2(kernel_rows); - const int missc = - round_up_pot(cols - (smallc << num_blocks_of_cols_log2), kernel_cols) >> - pot_log2(kernel_cols); - - block_map->dims[Side::kLhs] = rows; - block_map->dims[Side::kRhs] = cols; - block_map->kernel_dims[Side::kLhs] = kernel_rows; - block_map->kernel_dims[Side::kRhs] = kernel_cols; - block_map->num_blocks_base_log2 = num_blocks_base_log2; - block_map->rectangularness_log2[Side::kLhs] = rows_rectangularness_log2; - block_map->rectangularness_log2[Side::kRhs] = cols_rectangularness_log2; - block_map->small_block_dims[Side::kLhs] = smallr; - block_map->small_block_dims[Side::kRhs] = smallc; - block_map->large_blocks[Side::kLhs] = missr; - block_map->large_blocks[Side::kRhs] = missc; - // Done last: NumBlocks needs some of the block_map fields to be already set. - block_map->thread_count = - std::min(tentative_thread_count, NumBlocks(*block_map)); -} - -void GetBlockMatrixCoords(Side side, const BlockMap& block_map, int block, - int* start, int* end) { - profiler::ScopeLabel label("GetBlockMatrixCoords"); - *start = block * block_map.small_block_dims[side] + - std::min(block, block_map.large_blocks[side]) * - block_map.kernel_dims[side]; - *end = - *start + block_map.small_block_dims[side] + - (block < block_map.large_blocks[side] ? block_map.kernel_dims[side] : 0); - - RUY_DCHECK_EQ(0, *start % block_map.kernel_dims[side]); - RUY_DCHECK_EQ(0, *end % block_map.kernel_dims[side]); - RUY_DCHECK_LE(*end, block_map.dims[side]); - RUY_DCHECK_LT(*start, *end); - RUY_DCHECK_GE(*start, 0); -} - -void GetBlockMatrixCoords(const BlockMap& block_map, const SidePair& block, - SidePair* start, SidePair* end) { - for (Side side : {Side::kLhs, Side::kRhs}) { - GetBlockMatrixCoords(side, block_map, block[side], &(*start)[side], - &(*end)[side]); - } -} - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/block_map.h b/tensorflow/lite/experimental/ruy/ruy/block_map.h deleted file mode 100644 index 0fa4c9d5d60..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/block_map.h +++ /dev/null @@ -1,161 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_BLOCK_MAP_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_BLOCK_MAP_H_ - -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/side_pair.h" - -namespace ruy { - -enum class BlockMapTraversalOrder { - // Plain old row-by-row or column-by-column traversal. - kLinear, - // Fractal Z-order curve, https://en.wikipedia.org/wiki/Z-order_curve - kFractalZ, - // Variant of Z-order doing a U instead of a Z. - kFractalU, - // Hilbert curve, https://en.wikipedia.org/wiki/Hilbert_curve - kFractalHilbert -}; - -// A BlockMap describes a tiling of a matrix, typically the destination matrix -// of a matrix multiplication computation. As is standard in matrix -// multiplication, a tile is called a "block". -// -// Ruy subdivides work by blocks of the destination matrix: each thread fully -// computes a block at once, then moves on to another block; each block is -// produced by a single thread. -// -// This ensures that the workloads for each block are mutually independent, -// which reduces synchronization requirements. -// -// Typically, a matrix multiplication will early on create a BlockMap by -// calling MakeBlockMap. It will then query the number of blocks in that -// BlockMap by calling NumBlocks. It will then create a single atomic integer -// counter indexing these blocks, called the 'index', and will distribute -// work to its N threads by ensuring that each thread works on disjoint sets -// of index values. For a given index value, the thread will call -// GetBlockByIndex to get the corresponding block, then GetBlockMatrixCoords -// to find the actual row and column numbers of this block. -// -// There are two nested levels of subdivision. On a local level, the matrix is -// tiled into a square NxN grid where N is a power of two, specifically: -// N = 2^num_blocks_base_log2. -// -// At a larger scale, around these blocks, there may be one further -// level of subdivision, in only one dimension: either along rows or along -// columns. That is used to handle arbitrarily rectangular matrices. The -// aforementioned high-level block grid is square, so it does not readily fit -// well very rectangular matrices. -// -// Taking together these two nested levels of subdivision, the effective -// tiling is by -// 2^(num_blocks_base_log2 + rows_rectangularness_log2) -// blocks in the row dimension, and by -// 2^(num_blocks_base_log2 + cols_rectangularness_log2) -// blocks in the column dimension. See NumBlocksOfRows, NumBlocksOfCols. -// -// Either rows_rectangularness_log2 or cols_rectangularness_log2 must be zero. -// -// Finally, this BlockMap is designed to operate under alignment constraints: -// two fields, kernel_rows and kernel_cols, describe the requested alignment -// of the effective grid in both dimensions. The idea is to feed matrix -// multiplication kernels with tiles that fit their width as much as possible. -// Of course, if rows (resp. cols) is not a multiple of kernel_rows (resp. -// kernel_cols) then some tile will have to have unaligned size. BlockMap -// will only allow that to happen in the last position along each axis, so -// as to minimize the overhead incurred onto the matrix multiplication kernels. -struct BlockMap { - // The number of threads to use (to distribute the blocks to). - int thread_count; - // The order in which to traverse the matrix of which this BlockMap represents - // a tiling (hereafter "the matrix"). - BlockMapTraversalOrder traversal_order; - // The dimensions of the block_map, that is, of the destination - // matrix rounded up to next multiples of kernel_dims. - SidePair dims; - // Log2 of the minimum number of subdivisions of the grid along either axis. - int num_blocks_base_log2; - // Log2 of the additional subdivision of the rows/columns axis. - SidePair rectangularness_log2; - // Requested alignment of the subdivisions of the grid along the rows/columns - // axis. - SidePair kernel_dims; - // Internal helper. Minimum number of rows/columns in each block. - SidePair small_block_dims; - // Internal helper. Number of blocks along each dimension that need to have - // their size in that dimension be given by (small_block_dims + kernel_dims) - // instead of just small_block_dims. - SidePair large_blocks; -}; - -// Returns the traversal order to be used for the given matrix multiplication -// parameters. -BlockMapTraversalOrder GetTraversalOrder(int rows, int cols, int depth, - int lhs_scalar_size, - int rhs_scalar_size, - int local_data_cache_size, - int shared_data_cache_size); - -// Create a BlockMap suitable for tiling the destination matrix in a -// matrix multiplication with the given parameters. -void MakeBlockMap(int rows, int cols, int depth, int kernel_rows, - int kernel_cols, int lhs_scalar_size, int rhs_scalar_size, - int tentative_thread_count, Path path, - int local_data_cache_size, int shared_data_cache_size, - BlockMap* block_map); - -// Maps an integer index to a block position in the grid. -void GetBlockByIndex(const BlockMap& block_map, int index, - SidePair* block); - -// Given a block position in the grid, returns its actual -// position in the matrix that the BlockMap refers to in the dimension -// referred to by `side`: along rows if side==kLhs, along columns if -// side==kRhs. -void GetBlockMatrixCoords(Side side, const BlockMap& block_map, int block, - int* start, int* end); - -// Given a block position in the grid, returns its actual -// position in the matrix that the BlockMap refers to in terms of -// actual row/column indices. -void GetBlockMatrixCoords(const BlockMap& block_map, const SidePair& block, - SidePair* start, SidePair* end); - -// Returns the number of grid subdivisions along the rows dimension (if -// side == kLhs) or columns dimension (if side == kRhs). -inline int NumBlocksPerSide(Side side, const BlockMap& block_map) { - return 1 << (block_map.num_blocks_base_log2 + - block_map.rectangularness_log2[side]); -} - -// Returns the overall number of blocks in -// the BlockMap. The valid index values to pass to GetBlockByIndex are the -// integers from 0 to N-1 where N is the value returned here. -// -// Note that it is always true that -// NumBlocks == NumBlocksOfRows * NumBlocksOfCols -// because either rows_rectangularness_log2 or cols_rectangularness_log2 is 0. -inline int NumBlocks(const BlockMap& block_map) { - return 1 << (2 * block_map.num_blocks_base_log2 + - block_map.rectangularness_log2[Side::kLhs] + - block_map.rectangularness_log2[Side::kRhs]); -} - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_BLOCK_MAP_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/block_map_test.cc b/tensorflow/lite/experimental/ruy/ruy/block_map_test.cc deleted file mode 100644 index cdd7ee0e01f..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/block_map_test.cc +++ /dev/null @@ -1,263 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/block_map.h" - -#include -#include -#include -#include -#include - -#include -#include "tensorflow/lite/experimental/ruy/ruy/cpu_cache_size.h" -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/side_pair.h" - -namespace ruy { -namespace { - -#if RUY_PLATFORM(NEON_64) - -// Unless otherwise specified, these tests have been tuned on ARM Cortex-A55. -void MakeBlockMapTuningTest(int rows, int cols, int depth, int kernel_rows, - int kernel_cols, int lhs_scalar_size, - int rhs_scalar_size, int tentative_thread_count, - Path path, int expected_num_blocks_base_log2, - int expected_rectangularness_log2) { - BlockMap block_map; - MakeBlockMap(rows, cols, depth, kernel_rows, kernel_cols, lhs_scalar_size, - rhs_scalar_size, tentative_thread_count, path, - LocalDataCacheSize(path), SharedDataCacheSize(path), &block_map); - EXPECT_EQ(block_map.num_blocks_base_log2, expected_num_blocks_base_log2); - EXPECT_EQ(std::min(block_map.rectangularness_log2[Side::kLhs], - block_map.rectangularness_log2[Side::kRhs]), - 0); - EXPECT_EQ(std::max(block_map.rectangularness_log2[Side::kLhs], - block_map.rectangularness_log2[Side::kRhs]), - expected_rectangularness_log2); -} - -TEST(BlockMapTest, MakeBlockMapTuningTest8bitCubicShapesOneThreadNeonDotprod) { - MakeBlockMapTuningTest(32, 32, 32, 8, 8, 1, 1, /* tentative_thread_count */ 1, - Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 0, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(48, 48, 48, 8, 8, 1, 1, /* tentative_thread_count */ 1, - Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 0, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(64, 64, 64, 8, 8, 1, 1, /* tentative_thread_count */ 1, - Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 0, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(96, 96, 96, 8, 8, 1, 1, /* tentative_thread_count */ 1, - Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 0, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(128, 128, 128, 8, 8, 1, 1, - /* tentative_thread_count */ 1, Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 0, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(192, 192, 192, 8, 8, 1, 1, - /* tentative_thread_count */ 1, Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 0, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(256, 256, 256, 8, 8, 1, 1, - /* tentative_thread_count */ 1, Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 1, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(384, 384, 384, 8, 8, 1, 1, - /* tentative_thread_count */ 1, Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 1, - /* expected_rectangularness_log2 */ 0); -} - -TEST(BlockMapTest, - MakeBlockMapTuningTest8bitCubicShapesFourThreadsNeonDotprod) { - MakeBlockMapTuningTest(32, 32, 32, 8, 8, 1, 1, /* tentative_thread_count */ 4, - Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 1, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(48, 48, 48, 8, 8, 1, 1, /* tentative_thread_count */ 4, - Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 1, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(64, 64, 64, 8, 8, 1, 1, /* tentative_thread_count */ 4, - Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 1, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(96, 96, 96, 8, 8, 1, 1, /* tentative_thread_count */ 4, - Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 1, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(128, 128, 128, 8, 8, 1, 1, - /* tentative_thread_count */ 4, Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 1, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(192, 192, 192, 8, 8, 1, 1, - /* tentative_thread_count */ 4, Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 1, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(256, 256, 256, 8, 8, 1, 1, - /* tentative_thread_count */ 4, Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 2, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(384, 384, 384, 8, 8, 1, 1, - /* tentative_thread_count */ 4, Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 2, - /* expected_rectangularness_log2 */ 0); -} - -TEST(BlockMapTest, MakeBlockMapTuningTest32bit) { - MakeBlockMapTuningTest(256, 256, 256, 8, 8, 4, 4, - /* tentative_thread_count */ 4, Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 3, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(4096, 4096, 4096, 8, 8, 4, 4, - /* tentative_thread_count */ 4, Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 7, - /* expected_rectangularness_log2 */ 0); -} - -TEST(BlockMapTest, MakeBlockMapTuningTestRectangular) { - MakeBlockMapTuningTest(256, 16, 256, 8, 8, 1, 1, - /* tentative_thread_count */ 1, Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 0, - /* expected_rectangularness_log2 */ 3); - MakeBlockMapTuningTest(24, 2400, 256, 8, 8, 1, 1, - /* tentative_thread_count */ 1, Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 0, - /* expected_rectangularness_log2 */ 6); -} - -#endif - -int L1Distance(const SidePair& a, const SidePair& b) { - return std::abs(a[Side::kLhs] - b[Side::kLhs]) + - std::abs(a[Side::kRhs] - b[Side::kRhs]); -} - -void GetBlockByIndexSquareTest(int num_blocks_base_log2, - BlockMapTraversalOrder traversal_order) { - // Arbitrary, does not affect this test. 3 is just a typical value. - constexpr int kKernelSizeLog2 = 3; - - const int size_log2 = num_blocks_base_log2 + kKernelSizeLog2; - BlockMap block_map; - block_map.thread_count = 1; - block_map.traversal_order = traversal_order; - block_map.num_blocks_base_log2 = num_blocks_base_log2; - for (Side side : {Side::kLhs, Side::kRhs}) { - block_map.dims[side] = 1 << size_log2; - block_map.rectangularness_log2[side] = 0; - block_map.kernel_dims[side] = 1 << kKernelSizeLog2; - block_map.small_block_dims[side] = block_map.kernel_dims[side]; - block_map.large_blocks[side] = 0; - } - - const int num_blocks_per_side = 1 << num_blocks_base_log2; - const int num_blocks = num_blocks_per_side * num_blocks_per_side; - EXPECT_EQ(num_blocks, NumBlocks(block_map)); - - // Perform a full traversal of all blocks, as if computing a whole matrix - // multiplication. - // - // Used to record how many times each block was hit by the traversal. - std::vector block_hit_counts(num_blocks); - // Here we guard an assumption that all traversal orders start at (0, 0). - SidePair previous_block_coords(0, 0); - // Sum of L1 norm of the coordinate change at every step of the traversal. - std::int64_t total_l1_distance = 0; - // Number of jumps i.e. traversal steps with a L1 norm greater than 1. - int discontinuity_count = 0; - for (int block_index = 0; block_index < num_blocks; block_index++) { - SidePair block_coords; - GetBlockByIndex(block_map, block_index, &block_coords); - ++block_hit_counts[block_coords[Side::kLhs] + - num_blocks_per_side * block_coords[Side::kRhs]]; - int distance = L1Distance(block_coords, previous_block_coords); - total_l1_distance += distance; - discontinuity_count += (distance > 1); - previous_block_coords = block_coords; - } - - // Verify that each block was traversed exactly once. - for (int l = 0; l < num_blocks_per_side; l++) { - for (int r = 0; r < num_blocks_per_side; r++) { - EXPECT_EQ(block_hit_counts[l + num_blocks_per_side * r], 1); - } - } - - // Verify that the discontinuity_count and total_l1_distance are as expected - // for the given traversal_order. - switch (traversal_order) { - case BlockMapTraversalOrder::kFractalHilbert: - // No discontinuity at all with this space-filling continuous curve! - EXPECT_EQ(discontinuity_count, 0); - // Therefore, total_l1_distance has to be the number of blocks minus one. - EXPECT_EQ(total_l1_distance, num_blocks - 1); - break; - case BlockMapTraversalOrder::kLinear: - EXPECT_EQ(discontinuity_count, num_blocks_per_side - 1); - EXPECT_EQ(total_l1_distance, - 2 * num_blocks_per_side * (num_blocks_per_side - 1)); - break; - case BlockMapTraversalOrder::kFractalZ: - EXPECT_EQ(discontinuity_count, num_blocks > 1 ? (num_blocks / 2 - 1) : 0); - EXPECT_EQ(total_l1_distance, - 2 * num_blocks_per_side * (num_blocks_per_side - 1)); - break; - case BlockMapTraversalOrder::kFractalU: { - if (num_blocks_base_log2 == 0) { - EXPECT_EQ(discontinuity_count, 0); - EXPECT_EQ(total_l1_distance, 0); - } else { - int expected_discontinuity_count = 0; - int expected_total_l1_distance = 3; - for (int i = 2; i <= num_blocks_base_log2; i++) { - expected_discontinuity_count = 4 * expected_discontinuity_count + 2; - expected_total_l1_distance = - 4 * expected_total_l1_distance + (1 << (i + 1)) - 1; - } - EXPECT_EQ(discontinuity_count, expected_discontinuity_count); - EXPECT_EQ(total_l1_distance, expected_total_l1_distance); - } - break; - } - default: - abort(); - } -} - -TEST(BlockMapTest, GetBlockByIndexSquare) { - for (int num_blocks_base_log2 = 0; num_blocks_base_log2 <= 10; - num_blocks_base_log2++) { - for (BlockMapTraversalOrder traversal_order : - {BlockMapTraversalOrder::kLinear, BlockMapTraversalOrder::kFractalZ, - BlockMapTraversalOrder::kFractalU, - BlockMapTraversalOrder::kFractalHilbert}) { - GetBlockByIndexSquareTest(num_blocks_base_log2, traversal_order); - } - } -} - -} // namespace -} // namespace ruy - -int main(int argc, char **argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/tensorflow/lite/experimental/ruy/ruy/blocking_counter.cc b/tensorflow/lite/experimental/ruy/ruy/blocking_counter.cc deleted file mode 100644 index d313ffce51b..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/blocking_counter.cc +++ /dev/null @@ -1,49 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/blocking_counter.h" - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/wait.h" - -namespace ruy { - -void BlockingCounter::Reset(int initial_count) { - int old_count_value = count_.load(std::memory_order_relaxed); - RUY_DCHECK_EQ(old_count_value, 0); - (void)old_count_value; - count_.store(initial_count, std::memory_order_release); -} - -bool BlockingCounter::DecrementCount() { - int old_count_value = count_.fetch_sub(1, std::memory_order_acq_rel); - RUY_DCHECK_GT(old_count_value, 0); - int count_value = old_count_value - 1; - bool hit_zero = (count_value == 0); - if (hit_zero) { - std::lock_guard lock(count_mutex_); - count_cond_.notify_all(); - } - return hit_zero; -} - -void BlockingCounter::Wait() { - const auto& condition = [this]() { - return count_.load(std::memory_order_acquire) == 0; - }; - ruy::Wait(condition, &count_cond_, &count_mutex_); -} - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/blocking_counter.h b/tensorflow/lite/experimental/ruy/ruy/blocking_counter.h deleted file mode 100644 index 878f0e7219e..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/blocking_counter.h +++ /dev/null @@ -1,62 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_BLOCKING_COUNTER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_BLOCKING_COUNTER_H_ - -#include -#include // NOLINT(build/c++11) // IWYU pragma: keep -#include // NOLINT(build/c++11) // IWYU pragma: keep - -namespace ruy { - -// A BlockingCounter lets one thread to wait for N events to occur. -// This is how the master thread waits for all the worker threads -// to have finished working. -// The waiting is done using a naive spinlock waiting for the atomic -// count_ to hit the value 0. This is acceptable because in our usage -// pattern, BlockingCounter is used only to synchronize threads after -// short-lived tasks (performing parts of the same GEMM). It is not used -// for synchronizing longer waits (resuming work on the next GEMM). -class BlockingCounter { - public: - BlockingCounter() : count_(0) {} - - // Sets/resets the counter; initial_count is the number of - // decrementing events that the Wait() call will be waiting for. - void Reset(int initial_count); - - // Decrements the counter; if the counter hits zero, signals - // the threads that were waiting for that, and returns true. - // Otherwise (if the decremented count is still nonzero), - // returns false. - bool DecrementCount(); - - // Waits for the N other threads (N having been set by Reset()) - // to hit the BlockingCounter. - void Wait(); - - private: - std::atomic count_; - - // The condition variable and mutex allowing to passively wait for count_ - // to reach the value zero, in the case of longer waits. - std::condition_variable count_cond_; - std::mutex count_mutex_; -}; - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_BLOCKING_COUNTER_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/build_defs.bzl b/tensorflow/lite/experimental/ruy/ruy/build_defs.bzl deleted file mode 100644 index 9bccccf6316..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/build_defs.bzl +++ /dev/null @@ -1,40 +0,0 @@ -"""Build definitions for Ruy.""" - -# 1. Enable -mfpu=neon unconditionally on ARM32. If it turns out that we need to support -# ARM32 without NEON then we'll implement runtime detection and dispatch at that point. -# 2. Explicitly pass -O3 on optimization configs where just "-c opt" means "optimize for code size". - -def ruy_copts_base(): - return select({ - ":armeabi-v7a": [ - "-mfpu=neon", - ], - "//conditions:default": [], - }) + select({ - ":optimized": ["-O3"], - "//conditions:default": [], - }) - -# Used for targets that are compiled with extra features that are skipped at runtime if unavailable. -def ruy_copts_skylake(): - return [] - -# Used for targets that are compiled with extra features that are skipped at runtime if unavailable. -def ruy_copts_avx2(): - return [] - -# TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -# Optimization is not finished. In particular the dimensions of the kernel -# blocks can be changed as desired. -# -# Used for targets that are compiled with extra features that are skipped at runtime if unavailable. -def ruy_copts_sse42(): - return [] - -# TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -# Optimization is not finished. In particular the dimensions of the kernel -# blocks can be changed as desired. -# -# Used for targets that are compiled with extra features that are skipped at runtime if unavailable. -def ruy_copts_avxvnni(): - return [] diff --git a/tensorflow/lite/experimental/ruy/ruy/check_macros.h b/tensorflow/lite/experimental/ruy/ruy/check_macros.h deleted file mode 100644 index 773f37d99f2..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/check_macros.h +++ /dev/null @@ -1,138 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CHECK_MACROS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CHECK_MACROS_H_ - -#include -#include -#include - -namespace ruy { -namespace check_macros { - -constexpr int kValueBufSize = 32; - -template -struct ToString { - static void Run(const T& value, char* buf) { - snprintf(buf, kValueBufSize, "(?)"); - } -}; - -template <> -struct ToString { - static void Run(float value, char* buf) { - snprintf(buf, kValueBufSize, "%.9g", static_cast(value)); - } -}; - -template <> -struct ToString { - static void Run(double value, char* buf) { - snprintf(buf, kValueBufSize, "%.16g", value); - } -}; - -template -struct ToString::value>::type> { - static void Run(const T& value, char* buf) { - snprintf(buf, kValueBufSize, "%lld", static_cast(value)); - } -}; - -template -struct ToString { - static void Run(T* value, char* buf) { - snprintf(buf, kValueBufSize, "%p", value); - } -}; - -template -struct ToString::value>::type> { - static void Run(const T& value, char* buf) { - snprintf(buf, kValueBufSize, "(enum value %d)", static_cast(value)); - } -}; - -inline void Failure(const char* file, int line, const char* macro, - const char* condition) { - fprintf(stderr, "%s:%d: %s condition not satisfied: %s\n", file, line, macro, - condition); - abort(); -} - -template -inline void Failure(const char* file, int line, const char* macro, - const char* lhs, const LhsType& lhs_value, const char* op, - const char* rhs, const RhsType& rhs_value) { - char lhs_value_buf[kValueBufSize]; - ToString::Run(lhs_value, lhs_value_buf); - char rhs_value_buf[kValueBufSize]; - ToString::Run(rhs_value, rhs_value_buf); - fprintf(stderr, - "%s:%d: %s condition not satisfied: [ %s %s %s ] with values [ " - "%s %s %s ].\n", - file, line, macro, lhs, op, rhs, lhs_value_buf, op, rhs_value_buf); - abort(); -} - -#define RUY_CHECK_IMPL(macro, condition) \ - do { \ - if (!(condition)) { \ - ruy::check_macros::Failure(__FILE__, __LINE__, #macro, #condition); \ - } \ - } while (false) - -#define RUY_CHECK_OP_IMPL(macro, lhs, op, rhs) \ - do { \ - const auto& lhs_value = (lhs); \ - const auto& rhs_value = (rhs); \ - if (!(lhs_value op rhs_value)) { \ - ruy::check_macros::Failure(__FILE__, __LINE__, #macro, #lhs, lhs_value, \ - #op, #rhs, rhs_value); \ - } \ - } while (false) - -#define RUY_CHECK(condition) RUY_CHECK_IMPL(RUY_CHECK, condition) -#define RUY_CHECK_EQ(x, y) RUY_CHECK_OP_IMPL(RUY_CHECK_EQ, x, ==, y) -#define RUY_CHECK_NE(x, y) RUY_CHECK_OP_IMPL(RUY_CHECK_NE, x, !=, y) -#define RUY_CHECK_GE(x, y) RUY_CHECK_OP_IMPL(RUY_CHECK_GE, x, >=, y) -#define RUY_CHECK_GT(x, y) RUY_CHECK_OP_IMPL(RUY_CHECK_GT, x, >, y) -#define RUY_CHECK_LE(x, y) RUY_CHECK_OP_IMPL(RUY_CHECK_LE, x, <=, y) -#define RUY_CHECK_LT(x, y) RUY_CHECK_OP_IMPL(RUY_CHECK_LT, x, <, y) - -#ifdef NDEBUG -#define RUY_DCHECK(condition) -#define RUY_DCHECK_EQ(x, y) -#define RUY_DCHECK_NE(x, y) -#define RUY_DCHECK_GE(x, y) -#define RUY_DCHECK_GT(x, y) -#define RUY_DCHECK_LE(x, y) -#define RUY_DCHECK_LT(x, y) -#else -#define RUY_DCHECK(condition) RUY_CHECK(condition) -#define RUY_DCHECK_EQ(x, y) RUY_CHECK_EQ(x, y) -#define RUY_DCHECK_NE(x, y) RUY_CHECK_NE(x, y) -#define RUY_DCHECK_GE(x, y) RUY_CHECK_GE(x, y) -#define RUY_DCHECK_GT(x, y) RUY_CHECK_GT(x, y) -#define RUY_DCHECK_LE(x, y) RUY_CHECK_LE(x, y) -#define RUY_DCHECK_LT(x, y) RUY_CHECK_LT(x, y) -#endif - -} // end namespace check_macros -} // end namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CHECK_MACROS_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/check_macros_test.cc b/tensorflow/lite/experimental/ruy/ruy/check_macros_test.cc deleted file mode 100644 index 1a2a5a238f2..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/check_macros_test.cc +++ /dev/null @@ -1,153 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" - -#include - -namespace { - -#define TEST_CONDITION_FOR_FAMILY(family, vacuously_succeeds, condition) \ - do { \ - if (vacuously_succeeds || (condition)) { \ - RUY_##family(condition); \ - } \ - } while (false) - -#define TEST_COMPARISON_FOR_FAMILY(family, vacuously_succeeds, op_name, x, op, \ - y) \ - do { \ - if (vacuously_succeeds || ((x)op(y))) { \ - RUY_##family##_##op_name(x, y); \ - } \ - } while (false) - -#ifdef NDEBUG -#define TEST_CONDITION(condition) \ - do { \ - TEST_CONDITION_FOR_FAMILY(CHECK, false, condition); \ - } while (false) -#define TEST_COMPARISON(op_name, x, op, y) \ - do { \ - TEST_COMPARISON_FOR_FAMILY(CHECK, false, op_name, x, op, y); \ - } while (false) -#else -#define TEST_CONDITION(condition) \ - do { \ - TEST_CONDITION_FOR_FAMILY(CHECK, false, condition); \ - TEST_CONDITION_FOR_FAMILY(DCHECK, false, condition); \ - } while (false) -#define TEST_COMPARISON(op_name, x, op, y) \ - do { \ - TEST_COMPARISON_FOR_FAMILY(CHECK, false, op_name, x, op, y); \ - TEST_COMPARISON_FOR_FAMILY(DCHECK, false, op_name, x, op, y); \ - } while (false) - -#endif - -template -void TestEqualityComparisons(const LhsType& lhs, const RhsType& rhs) { - RUY_CHECK_EQ(lhs, lhs); - TEST_COMPARISON(EQ, lhs, ==, lhs); - RUY_CHECK_EQ(lhs, lhs); - RUY_CHECK_EQ(lhs, lhs); - if (lhs == rhs) { - RUY_CHECK_EQ(lhs, rhs); - } - if (lhs != rhs) { - RUY_CHECK_NE(lhs, rhs); - } -} - -template -void TestComparisons(const LhsType& lhs, const RhsType& rhs) { - TestEqualityComparisons(lhs, rhs); - if (lhs > rhs) { - RUY_CHECK_GT(lhs, rhs); - } - if (lhs >= rhs) { - RUY_CHECK_GE(lhs, rhs); - } - if (lhs < rhs) { - RUY_CHECK_LT(lhs, rhs); - } - if (lhs <= rhs) { - RUY_CHECK_LE(lhs, rhs); - } -} - -TEST(CheckMacrosTest, IntInt) { - TestComparisons(0, 0); - TestComparisons(0, 1); - TestComparisons(1, -1); - TestComparisons(-1, 0); - TestComparisons(123, -456); - TestComparisons(std::numeric_limits::min(), - std::numeric_limits::max()); - TestComparisons(123, std::numeric_limits::max()); - TestComparisons(123, std::numeric_limits::min()); -} - -TEST(CheckMacrosTest, Uint8Uint8) { - TestComparisons(0, 0); - TestComparisons(255, 0); - TestComparisons(0, 255); - TestComparisons(12, 34); -} - -TEST(CheckMacrosTest, Uint8Int) { - TestComparisons(0, std::numeric_limits::min()); - TestComparisons(255, std::numeric_limits::min()); - TestComparisons(0, std::numeric_limits::max()); - TestComparisons(255, std::numeric_limits::max()); -} - -TEST(CheckMacrosTest, FloatFloat) { - TestComparisons(0.f, 0.f); - TestComparisons(0.f, 1.f); - TestComparisons(1.f, -1.f); - TestComparisons(-1.f, 0.f); - TestComparisons(123.f, -456.f); - TestComparisons(std::numeric_limits::lowest(), - std::numeric_limits::max()); - TestComparisons(123.f, std::numeric_limits::max()); - TestComparisons(123.f, std::numeric_limits::lowest()); -} - -TEST(CheckMacrosTest, IntFloat) { - TestComparisons(0, 0.f); - TestComparisons(0, 1.f); - TestComparisons(1, -1.f); - TestComparisons(-1, 0.f); - TestComparisons(123, -456.f); - TestComparisons(std::numeric_limits::lowest(), - std::numeric_limits::max()); - TestComparisons(123, std::numeric_limits::max()); - TestComparisons(123, std::numeric_limits::lowest()); -} - -TEST(CheckMacrosTest, EnumClass) { - enum class SomeEnumClass { kA, kB, kC }; - TestEqualityComparisons(SomeEnumClass::kA, SomeEnumClass::kA); - TestEqualityComparisons(SomeEnumClass::kA, SomeEnumClass::kB); - TestEqualityComparisons(SomeEnumClass::kC, SomeEnumClass::kB); -} - -} // namespace - -int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/tensorflow/lite/experimental/ruy/ruy/common.h b/tensorflow/lite/experimental/ruy/ruy/common.h deleted file mode 100644 index e52a6ba6976..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/common.h +++ /dev/null @@ -1,73 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Miscellaneous helpers internal library. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_COMMON_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_COMMON_H_ - -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" - -#if RUY_OPT_ENABLED(RUY_OPT_PREFETCH_LOAD) -#define RUY_PREFETCH_LOAD(X) X -#else -#define RUY_PREFETCH_LOAD(X) -#endif - -#if RUY_OPT_ENABLED(RUY_OPT_PREFETCH_STORE) -#define RUY_PREFETCH_STORE(X) X -#else -#define RUY_PREFETCH_STORE(X) -#endif - -#define RUY_STR(s) RUY_STR_UNEXPANDED(s) -#define RUY_STR_UNEXPANDED(s) #s - -namespace ruy { - -// Helper for type-erasing a pointer. -// -// Often inside Ruy, a template parameter holds type information statically, but -// we would like to have a function signature that doesn't depend on the -// template parameters, so that we can dispatch indirectly across multiple -// implementations. This helper is at the core of such type-erasure. -// -// The opposite of this operation is just `static_cast(void_ptr)`. -template -void* ToVoidPtr(T* p) { - return const_cast(static_cast(p)); -} - -template -Scalar SymmetricZeroPoint() { - if (std::is_floating_point::value) { - return 0; - } - if (std::is_signed::value) { - return 0; - } - return std::numeric_limits::max() / 2 + 1; -} - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_COMMON_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/context.cc b/tensorflow/lite/experimental/ruy/ruy/context.cc deleted file mode 100644 index e0d4701645f..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/context.cc +++ /dev/null @@ -1,109 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/context.h" - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/detect_arm.h" -#include "tensorflow/lite/experimental/ruy/ruy/detect_x86.h" -#include "tensorflow/lite/experimental/ruy/ruy/have_built_path_for.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" - -namespace ruy { - -void Context::SetRuntimeEnabledPaths(Path paths) { - runtime_enabled_paths_ = paths; -} - -Path Context::GetRuntimeEnabledPaths() { - // This function should always return the same value on a given machine. - // When runtime_enabled_paths_ has its initial value kNone, it performs - // some platform detection to resolve it to specific Path values. - - // Fast path: already resolved. - if (runtime_enabled_paths_ != Path::kNone) { - return runtime_enabled_paths_; - } - - // Need to resolve now. Start by considering all paths enabled. - runtime_enabled_paths_ = kAllPaths; - - // This mechanism is intended to be used for testing and benchmarking. For - // example, one can set RUY_FORCE_DISABLE_PATHS to Path::kAvx512 in order to - // evaluate AVX2 performance on an AVX-512 machine. -#ifdef RUY_FORCE_DISABLE_PATHS - runtime_enabled_paths_ = runtime_enabled_paths_ & ~(RUY_FORCE_DISABLE_PATHS); -#endif - -#if RUY_PLATFORM(ARM) - // Now selectively disable paths that aren't supported on this machine. - if ((runtime_enabled_paths_ & Path::kNeonDotprod) != Path::kNone) { - if (!DetectDotprod()) { - runtime_enabled_paths_ = runtime_enabled_paths_ & ~Path::kNeonDotprod; - // Sanity check. - RUY_DCHECK((runtime_enabled_paths_ & Path::kNeonDotprod) == Path::kNone); - } - } -#endif // RUY_PLATFORM(ARM) - -#if RUY_PLATFORM(X86) - // TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / - // placeholder. Optimization is not finished. In particular the dimensions of - // the kernel blocks can be changed as desired. - // - if ((runtime_enabled_paths_ & Path::kSse42) != Path::kNone) { - if (!(HaveBuiltPathForSse42() && DetectCpuSse42())) { - runtime_enabled_paths_ = runtime_enabled_paths_ & ~Path::kSse42; - // Sanity check. - RUY_DCHECK((runtime_enabled_paths_ & Path::kSse42) == Path::kNone); - } - } - - if ((runtime_enabled_paths_ & Path::kAvx2) != Path::kNone) { - if (!(HaveBuiltPathForAvx2() && DetectCpuAvx2())) { - runtime_enabled_paths_ = runtime_enabled_paths_ & ~Path::kAvx2; - // Sanity check. - RUY_DCHECK((runtime_enabled_paths_ & Path::kAvx2) == Path::kNone); - } - } - - if ((runtime_enabled_paths_ & Path::kAvx512) != Path::kNone) { - if (!(HaveBuiltPathForAvx512() && DetectCpuAvx512())) { - runtime_enabled_paths_ = runtime_enabled_paths_ & ~Path::kAvx512; - // Sanity check. - RUY_DCHECK((runtime_enabled_paths_ & Path::kAvx512) == Path::kNone); - } - } - - // TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / - // placeholder. Optimization is not finished. In particular the dimensions of - // the kernel blocks can be changed as desired. - // - if ((runtime_enabled_paths_ & Path::kAvxVnni) != Path::kNone) { - if (!(HaveBuiltPathForAvxVnni() && DetectCpuAvxVnni())) { - runtime_enabled_paths_ = runtime_enabled_paths_ & ~Path::kAvxVnni; - // Sanity check. - RUY_DCHECK((runtime_enabled_paths_ & Path::kAvxVnni) == Path::kNone); - } - } -#endif // RUY_PLATFORM(X86) - - // Sanity check. We can't possibly have disabled all paths, as some paths - // are universally available (kReference, kStandardCpp). - RUY_DCHECK_NE(runtime_enabled_paths_, Path::kNone); - return runtime_enabled_paths_; -} - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/context.h b/tensorflow/lite/experimental/ruy/ruy/context.h deleted file mode 100644 index a2d05a9ba5c..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/context.h +++ /dev/null @@ -1,109 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CONTEXT_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CONTEXT_H_ - -#include -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/allocator.h" -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/prepacked_cache.h" -#include "tensorflow/lite/experimental/ruy/ruy/thread_pool.h" -#include "tensorflow/lite/experimental/ruy/ruy/trace.h" -#include "tensorflow/lite/experimental/ruy/ruy/tune.h" - -namespace ruy { - -// The state private to each Ruy thread. -struct PerThreadState { - // Each thread may be running on a different microarchitecture. For example, - // some threads may be on big cores, while others are on little cores. Thus, - // it's best for the tuning to be per-thread. - TuningResolver tuning_resolver; - // Each thread has its own local allocator. - Allocator allocator; -}; - -// A Context holds runtime information used by Ruy. It holds runtime resources -// such as the workers thread pool and the allocator (which holds buffers for -// temporary data), as well as runtime options controlling which Paths are -// enabled (typically based on which instruction sets are detected) and how -// many threads to use. -struct Context final { - Path last_taken_path = Path::kNone; - Tuning explicit_tuning = Tuning::kAuto; - // TODO(benoitjacob) rename that thread_pool. Current name is gemmlowp legacy. - ThreadPool workers_pool; - int max_num_threads = 1; - // State for each thread in the thread pool. Entry 0 is the main thread. - std::vector> per_thread_states; - TracingContext tracing; - CachePolicy cache_policy = CachePolicy::kNoCache; - - Allocator* GetMainAllocator() { - if (!main_allocator_) { - main_allocator_.reset(new Allocator); - } - return main_allocator_.get(); - } - - PrepackedCache* GetPrepackedCache() { - if (!prepacked_cache_) { - prepacked_cache_.reset(new PrepackedCache); - } - return prepacked_cache_.get(); - } - - void ClearPrepackedCache() { prepacked_cache_ = nullptr; } - - void EnsureNPerThreadStates(int thread_count) { - while (per_thread_states.size() < static_cast(thread_count)) { - per_thread_states.emplace_back(new PerThreadState); - } - } - - Tuning GetMainThreadTuning() { - EnsureNPerThreadStates(1); - TuningResolver* tuning_resolver = &per_thread_states[0]->tuning_resolver; - tuning_resolver->SetTuning(explicit_tuning); - return tuning_resolver->Resolve(); - } - - template - Path GetPathToTake() { - last_taken_path = - GetMostSignificantPath(CompiledPaths & GetRuntimeEnabledPaths()); - return last_taken_path; - } - - void SetRuntimeEnabledPaths(Path paths); - Path GetRuntimeEnabledPaths(); - - private: - // Allocator for main thread work before invoking the threadpool. - // Our simple Allocator does not allow reserving/allocating more blocks - // while it's already in committed state, so the main thread needs both - // this allocator, and its per-thread allocator. - std::unique_ptr main_allocator_; - std::unique_ptr prepacked_cache_; - Path runtime_enabled_paths_ = Path::kNone; -}; - -} // end namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CONTEXT_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/context_test.cc b/tensorflow/lite/experimental/ruy/ruy/context_test.cc deleted file mode 100644 index bddbfcf8c55..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/context_test.cc +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/context.h" - -#include -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" - -namespace ruy { -namespace { - -TEST(ContextTest, EnabledPathsGeneral) { - ruy::Context ruy_context; - const auto ruy_paths = ruy_context.GetRuntimeEnabledPaths(); - const auto ruy_paths_repeat = ruy_context.GetRuntimeEnabledPaths(); - ASSERT_EQ(ruy_paths, ruy_paths_repeat); - EXPECT_NE(ruy_paths, Path::kNone); - EXPECT_EQ(ruy_paths & Path::kReference, Path::kReference); - EXPECT_EQ(ruy_paths & Path::kStandardCpp, Path::kStandardCpp); -} - -#if RUY_PLATFORM(X86) -TEST(ContextTest, EnabledPathsX86) { - ruy::Context ruy_context; - ruy_context.SetRuntimeEnabledPaths(Path::kSse42 | Path::kAvx2 | - Path::kAvx512 | Path::kAvxVnni); - const auto ruy_paths = ruy_context.GetRuntimeEnabledPaths(); - EXPECT_EQ(ruy_paths & Path::kReference, Path::kNone); - EXPECT_EQ(ruy_paths & Path::kStandardCpp, Path::kNone); -} -#endif // RUY_PLATFORM(X86) - -#if RUY_PLATFORM(ARM) -TEST(ContextTest, EnabledPathsArm) { - ruy::Context ruy_context; - ruy_context.SetRuntimeEnabledPaths(Path::kNeon | Path::kNeonDotprod); - const auto ruy_paths = ruy_context.GetRuntimeEnabledPaths(); - EXPECT_EQ(ruy_paths & Path::kReference, Path::kNone); - EXPECT_EQ(ruy_paths & Path::kStandardCpp, Path::kNone); - EXPECT_EQ(ruy_paths & Path::kNeon, Path::kNeon); -} -#endif // RUY_PLATFORM(ARM) - -} // namespace -} // namespace ruy - -int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/tensorflow/lite/experimental/ruy/ruy/cpu_cache_size.h b/tensorflow/lite/experimental/ruy/ruy/cpu_cache_size.h deleted file mode 100644 index 95ed35ec097..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/cpu_cache_size.h +++ /dev/null @@ -1,81 +0,0 @@ -/* Copyright 2020 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CPU_CACHE_SIZE_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CPU_CACHE_SIZE_H_ - -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" - -namespace ruy { - -// LocalDataCacheSize returns a sane default size for each CPU core's local -// data cache, i.e. the largest data cache that is local to that CPU core, not -// shared with other cores. That allows coarse tuning of code that aims for -// most of its memory accesses to hit such a typically fast data cache. -// -// SharedDataCacheSize returns a sane default size of the total data cache -// accessible to each CPU, including any shared cache. -// -// For example, if we design tune this code for a ARM Cortex-A55 with a local L1 -// cache of 32k, a local L2 cache of 128k and a shared L3 cache of 1M, -// LocalDataCacheSize should return 128k and SharedDataCacheSize -// should return 1M. -// -// Ideally these values would be queried at runtime, and we should probably -// do that on x86, but that is hard to do on ARM. -#if RUY_PLATFORM(ARM_64) -inline int LocalDataCacheSize() { return 1 << 15; } -inline int SharedDataCacheSize() { return 1 << 19; } -#elif RUY_PLATFORM(ARM_32) -inline int LocalDataCacheSize() { return 1 << 14; } -inline int SharedDataCacheSize() { return 1 << 18; } -#elif RUY_PLATFORM(X86) -inline int LocalDataCacheSize() { return 1 << 17; } -inline int SharedDataCacheSize() { return 1 << 21; } -#else -inline int LocalDataCacheSize() { return 1 << 14; } -inline int SharedDataCacheSize() { return 1 << 18; } -#endif -// Variants taking a Path argument which acts -// as a hint telling whether we're targeting more or less recent/powerful CPUs. -inline int LocalDataCacheSize(Path path) { -#if RUY_PLATFORM(ARM_64) - if (path == Path::kNeonDotprod) { - // At the moment, the smallest CPU with dotprod is probably Cortex-A55 with - // 128k L2 local cache. - return 1 << 17; - } -#else - (void)path; -#endif - return LocalDataCacheSize(); -} -inline int SharedDataCacheSize(Path path) { -#if RUY_PLATFORM(ARM_64) - if (path == Path::kNeonDotprod) { - // At the moment, the smallest CPU with dotprod is probably Cortex-A55 with - // 1M L3 shared cache. - return 1 << 20; - } -#else - (void)path; -#endif - return SharedDataCacheSize(); -} - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CPU_CACHE_SIZE_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/detect_arm.cc b/tensorflow/lite/experimental/ruy/ruy/detect_arm.cc deleted file mode 100644 index 8f6d2c9f9fe..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/detect_arm.cc +++ /dev/null @@ -1,73 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -/* Detection of dotprod instructions on ARM. - * The current Linux-specific code relies on sufficiently new Linux kernels: - * At least Linux 4.15 in general; on Android, at least Linux 4.14.111 thanks to - * a late backport. This was backported just before the Android 10 release, so - * this is leaving out pre-release Android 10 builds as well as earlier Android - * versions. - * - * It is possible to detect instructions in other ways that don't rely on - * an OS-provided feature identification mechanism: - * - * (A) We used to have a SIGILL-handler-based method that worked at least - * on Linux. Its downsides were (1) crashes on a few devices where - * signal handler installation didn't work as intended; (2) additional - * complexity to generalize to other Unix-ish operating systems including - * iOS; (3) source code complexity and fragility of anything installing - * and restoring signal handlers; (4) confusing behavior under a debugger. - * - * (B) We also experimented with a fork-ing approach where a subprocess - * tries the instruction. Compared to (A), this is much simpler and more - * reliable and portable, but also much higher latency on Android where - * an uncaught signal typically causes a 100 ms latency. - * - * Should there be interest in either technique again in the future, - * code implementing both (A) and (B) can be found in earlier revisions of this - * file - in actual code for (A) and in a comment for (B). - */ - -#include "tensorflow/lite/experimental/ruy/ruy/detect_arm.h" - -#if defined __linux__ && defined __aarch64__ -#include -#endif - -namespace ruy { - -namespace { - -#if defined __linux__ && defined __aarch64__ -bool DetectDotprodByLinuxAuxvMethod() { - // This is the value of HWCAP_ASIMDDP in sufficiently recent Linux headers, - // however we need to support building against older headers for the time - // being. - const int kLocalHwcapAsimddp = 1 << 20; - return getauxval(AT_HWCAP) & kLocalHwcapAsimddp; -} -#endif - -} // namespace - -bool DetectDotprod() { -#if defined __linux__ && defined __aarch64__ - return DetectDotprodByLinuxAuxvMethod(); -#endif - - return false; -} - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/detect_x86.cc b/tensorflow/lite/experimental/ruy/ruy/detect_x86.cc deleted file mode 100644 index 113a73c09e3..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/detect_x86.cc +++ /dev/null @@ -1,101 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/detect_x86.h" - -#include - -#if RUY_PLATFORM(X86) && RUY_PLATFORM(X86_ENHANCEMENTS) -#include // IWYU pragma: keep - -#endif - -namespace ruy { -#if RUY_PLATFORM(X86) && RUY_PLATFORM(X86_ENHANCEMENTS) - -namespace { - -// See Intel docs, such as http://goo.gl/c6IkGX. -inline void RunCpuid(std::uint32_t eax, std::uint32_t ecx, - std::uint32_t abcd[4]) { - std::uint32_t ebx, edx; -#if defined(__i386__) && defined(__PIC__) - /* in case of PIC under 32-bit EBX cannot be clobbered */ - asm volatile("movl %%ebx, %%edi \n\t cpuid \n\t xchgl %%ebx, %%edi" - : "=D"(ebx), -#else - asm volatile("cpuid" - : "+b"(ebx), -#endif - "+a"(eax), "+c"(ecx), "=d"(edx)); - abcd[0] = eax; - abcd[1] = ebx; - abcd[2] = ecx; - abcd[3] = edx; -} - -} // namespace - -bool DetectCpuSse42() { - std::uint32_t abcd[4]; - - constexpr std::uint32_t kEcxSse42 = 1u << 20; - RunCpuid(1, 0, abcd); - const bool has_sse4_2_base = (abcd[2] & kEcxSse42) == kEcxSse42; - -#ifdef RUY_ENABLE_AMD_CPUID_CHECKS - constexpr std::uint32_t kEcxAbm = 1u << 5; - RunCpuid(0x80000001, 0, abcd); - const bool has_extras = (abcd[2] & kEcxAbm) == kEcxAbm; -#else - constexpr std::uint32_t kEcxPopcnt = 1u << 23; - RunCpuid(1, 0, abcd); - const bool has_extras = (abcd[2] & kEcxPopcnt) == kEcxPopcnt; -#endif - - return has_sse4_2_base && has_extras; -} - -bool DetectCpuAvx2() { - constexpr std::uint32_t kEbxAvx2 = 1u << 5; - constexpr std::uint32_t kEcxFma = 1u << 12; - - std::uint32_t abcd[4]; - - RunCpuid(7, 0, abcd); - const bool has_avx2 = (abcd[1] & kEbxAvx2) == kEbxAvx2; - RunCpuid(1, 0, abcd); - const bool has_fma = (abcd[2] & kEcxFma) == kEcxFma; - - return has_avx2 && has_fma; -} - -bool DetectCpuAvx512() { - constexpr std::uint32_t kEbxAvx512F = 1u << 16; - constexpr std::uint32_t kEbxAvx512Dq = 1u << 17; - constexpr std::uint32_t kEbxAvx512Cd = 1u << 28; - constexpr std::uint32_t kEbxAvx512Bw = 1u << 30; - constexpr std::uint32_t kEbxAvx512Vl = 1u << 31; - - constexpr std::uint32_t kEbxAvx512Mask = - kEbxAvx512F | kEbxAvx512Dq | kEbxAvx512Cd | kEbxAvx512Bw | kEbxAvx512Vl; - std::uint32_t abcd[4]; - RunCpuid(7, 0, abcd); - - return (abcd[1] & kEbxAvx512Mask) == kEbxAvx512Mask; -} - -#endif -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/detect_x86.h b/tensorflow/lite/experimental/ruy/ruy/detect_x86.h deleted file mode 100644 index 185dabe06a5..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/detect_x86.h +++ /dev/null @@ -1,49 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DETECT_X86_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DETECT_X86_H_ - -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" - -namespace ruy { - -#if RUY_PLATFORM(X86) -#if RUY_PLATFORM(X86_ENHANCEMENTS) - -// This also checks ABM support, which implies LZCNT and POPCNT. -bool DetectCpuSse42(); -bool DetectCpuAvx2(); -bool DetectCpuAvx512(); -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// TODO(b/146646451): Introduce and activate. -inline bool DetectCpuAvxVnni() { return false; } - -#else // RUY_PLATFORM(X86_ENHANCEMENTS) - -inline bool DetectCpuSse42() { return false; } -inline bool DetectCpuAvx2() { return false; } -inline bool DetectCpuAvx512() { return false; } -inline bool DetectCpuAvxVnni() { return false; } - -#endif // !RUY_PLATFORM(X86_ENHANCEMENTS) -#endif // RUY_PLATFORM(X86) - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DETECT_X86_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/dispatch.h b/tensorflow/lite/experimental/ruy/ruy/dispatch.h deleted file mode 100644 index d1e97e29b9c..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/dispatch.h +++ /dev/null @@ -1,482 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file implements the translation between Ruy's entry point (ruy::Mul) and -// the internal implementation of matrix multiplication. -// -// The primary elements of this dispatch are: -// - pick suitable gemm kernel and packing routines for the user-specified -// CompiledPaths based on the current CPU. -// - decide on the structure of the packed matrices needed by the internal -// implementation (see pack.h for more information on packing). -// - translate the Mul operation into TrMul (see trmul.h for why that is -// useful). This is done by changing the matrix Layout -- no matrix data is -// actually moved. -// -// This file is also factored to serve as a building block for the advanced API -// as well. -// -// This file also performs some checking of invariants to catch user errors. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DISPATCH_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DISPATCH_H_ - -#include -#include -#include // IWYU pragma: keep -#include - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/common.h" -#include "tensorflow/lite/experimental/ruy/ruy/context.h" -#include "tensorflow/lite/experimental/ruy/ruy/internal_matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/kernel.h" -#include "tensorflow/lite/experimental/ruy/ruy/kernel_common.h" -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/pack.h" -#include "tensorflow/lite/experimental/ruy/ruy/pack_common.h" -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" -#include "tensorflow/lite/experimental/ruy/ruy/side_pair.h" -#include "tensorflow/lite/experimental/ruy/ruy/size_util.h" -#include "tensorflow/lite/experimental/ruy/ruy/spec.h" -#include "tensorflow/lite/experimental/ruy/ruy/trmul.h" -#include "tensorflow/lite/experimental/ruy/ruy/trmul_params.h" - -namespace ruy { - -// If the Spec's LayoutSupport covers only some special cases, -// this function enforces that the matrix multiplication at hand falls into -// that special case. -template -void EnforceLayoutSupport(const Layout& lhs_layout, const Layout& rhs_layout, - const Layout& dst_layout) { - if (Spec::kLayoutSupport == LayoutSupport::kRCC) { - RUY_DCHECK(IsRowMajor(lhs_layout)); - RUY_DCHECK(IsColMajor(rhs_layout)); - RUY_DCHECK(IsColMajor(dst_layout)); - } -} - -template -bool IsSymmetricZeroPoint(Scalar zero_point) { - return zero_point == SymmetricZeroPoint(); -} - -template -void CheckZeroPoint(Scalar zero_point) { - if (std::is_floating_point::value || - Spec::kZeroPointSupport == ZeroPointSupport::kSymmetric) { - RUY_DCHECK(IsSymmetricZeroPoint(zero_point)); - } -} - -template -void EnforceZeroPointSupport(LhsScalar lhs_zero_point, RhsScalar rhs_zero_point, - DstScalar dst_zero_point) { - // If the Spec's ZeroPointSupport covers only some special cases, - // this function enforces that the matrix multiplication at hand falls into - // that special case. - CheckZeroPoint(lhs_zero_point); - CheckZeroPoint(rhs_zero_point); - CheckZeroPoint(dst_zero_point); - - // Guard against the case when both LHS and RHS zero_point's are equal to - // the minimum representable value. In that case, padding with zero_point - // values will generate the bad case for fast int8 kernels on NEON - // (pre-dotprod) which attempt to multiply-accumulate two pairs of int8 - // into a int16: this is safe except in the bad case -128*-128 + -128*-128. - // See b/131609283. This only affects the kNeon path but we ban this for all - // paths in order for ruy to have the same supported parameter space - // on all paths. - RUY_DCHECK(lhs_zero_point != std::numeric_limits::lowest() || - rhs_zero_point != std::numeric_limits::lowest()); -} - -template -void EnforceDstSpecSupport(const Spec& spec, DstScalar dst_zero_point) { - static_assert(std::is_same::value, ""); - if (!std::is_same::value) return; - - // If user is looking for the raw accumulator, zero_point and all the other - // dequantize fields don't make sense and should not be set. - RUY_DCHECK_EQ(dst_zero_point, 0); - RUY_DCHECK_EQ(spec.clamp_max, std::numeric_limits::max()); - RUY_DCHECK_EQ(spec.clamp_min, std::numeric_limits::min()); - RUY_DCHECK_EQ(spec.multiplier_fixedpoint, 0); - RUY_DCHECK_EQ(spec.multiplier_exponent, 0); - RUY_DCHECK_EQ(spec.multiplier_fixedpoint_perchannel, nullptr); - RUY_DCHECK_EQ(spec.multiplier_exponent_perchannel, nullptr); -} - -inline bool IsColMajorTrMul(const TrMulParams& params) { - return IsColMajor(params.src[Side::kLhs].layout) && - IsColMajor(params.src[Side::kRhs].layout) && - IsColMajor(params.dst.layout); -} - -inline void CreatePackedLayout(const Layout& src, const Type& scalar, - const KernelLayout& kernel_layout, - PackedLayout* packed) { - packed->order = Order::kColMajor; - packed->rows = round_up_pot(src.rows, kernel_layout.rows); - packed->cols = round_up_pot(src.cols, kernel_layout.cols); - packed->kernel = kernel_layout; - int inner_size = packed->rows; - if (RUY_OPT_ENABLED(RUY_OPT_AVOID_ALIASING)) { - packed->stride = - (inner_size * scalar.size) % 1024 ? inner_size : inner_size + 64; - } else { - packed->stride = inner_size; - } -} - -template -void CreatePackedMatrix(Side side, const KernelLayout& kernel_layout, - TrMulParams* params) { - // Ruy always uses 32-bit signed accumulators for quantized - // matrix multiplication, so we would like to always use std::int32_t - // unconditionally for SumsType. - // However, for floating point types, we still need a reasonable type here to - // avoid tripping assertions elsewhere in the code. - using SumsType = - typename std::conditional::value, Scalar, - std::int32_t>::type; - - const DMatrix& src = params->src[side]; - PMatrix* packed = ¶ms->packed[side]; - packed->data_type = Type::Create(); - packed->sums_type = Type::Create(); - CreatePackedLayout(src.layout, packed->data_type, kernel_layout, - &packed->layout); - packed->zero_point = Pack(src.zero_point); -} - -template -void PopulateTrMulParams(TrMulParams* params) { - static_assert((ThePath & Path::kReference) == Path::kNone, - "Path::kReference should not do TrMul"); - // The optimized code paths don't handle the full generality of Ruy's API. - // Fall back to Path::kStandardCpp if necessary. - bool fallback_to_standard_cpp = false; - if (ThePath != Path::kStandardCpp) { - // The optimized code paths currently only handle the case of all matrices - // being column major. - if (!IsColMajorTrMul(*params)) { - fallback_to_standard_cpp = true; - } - } - - if (fallback_to_standard_cpp) { - PopulateTrMulParams(params); - return; - } - - using PackedLhsScalar = PackedType; - using PackedRhsScalar = PackedType; - using Kernel = - Kernel; - using LhsKernelLayout = typename Kernel::LhsLayout; - using RhsKernelLayout = typename Kernel::RhsLayout; - - params->path = ThePath; - - params->local_data_cache_size = Spec::local_data_cache_size(); - params->shared_data_cache_size = Spec::shared_data_cache_size(); - - CreatePackedMatrix( - Side::kLhs, ToKernelLayout(), params); - CreatePackedMatrix( - Side::kRhs, ToKernelLayout(), params); - params->run_pack[Side::kLhs] = - &RunPack; - params->run_pack[Side::kRhs] = - &RunPack; - params->run_kernel = - &RunKernel; - - return; -} - -// PopulateTrMulParamsAllCompiledPaths calls into one of multiple -// instantiations of PopulateTrMulParams. For each bit that is set in -// CompiledPaths, it statically instantiates PopulateTrMulParams with a Path -// corresponding to that single bit. The call to PopulateTrMulParams is -// guarded by a runtime check that it is in fact the dynamically selected path. -// -// PopulateTrMulParamsAllCompiledPaths is implemented with template -// metaprogramming by mutual recursion between PathSearchCountdown and -// PathSearchCompiledPaths. -// -// PopulateTrMulParamsAllCompiledPaths is logically implementing the following -// computation: -// -// template -// void PopulateTrMulParamsAllCompiledPaths(Path the_path, -// TrMulParams* params) { -// for (int bit = 8 * sizeof(Path) - 1; bit != -1; bit--) { // [1] -// Path current_path = static_cast(1 << bit); -// if ((CompiledPaths & current_path) != Path::kNone) { // [2] -// if (current_path == the_path) { // [3] -// PopulateTrMulParams(the_path, params); -// return; -// } -// } -// } -// } -// -// -// -// [1] - Done by the main definition of PathSearchCountdown. The `bit--` is -// done in the recursion of PathSearchOnlyCompiledPaths. -// [2] - Done by PathSearchOnlyCompiledPaths's partial template -// specialization on InCompiledPaths. This is the check which necessitates -// doing the whole computation at C++ compile time. -// [3] - Done by the `if` in the main definition of -// PathSearchOnlyCompiledPaths. -// -// The template metaprogramming is necessary because: -// - In `PopulateTrMulParams`, current_path must be a C++ -// compile-time constant. -// - PopulateTrMulParamsAllCompiledPaths must not instantiate -// inner loops for paths that are not in CompiledPaths, since that can result in -// bogus instantiations which cause a compile time failure. -template -struct PathSearchCountdown; - -template -struct PathSearchOnlyCompiledPaths { - static constexpr Path kCurrentPath = static_cast(1 << BitNumber); - static void Search(Path the_path, TrMulParams* params) { - if (kCurrentPath == the_path) { - PopulateTrMulParams( - params); - return; - } - PathSearchCountdown::Search(the_path, params); - } -}; - -// Skip this iteration if CompiledPaths doesn't contain the specified path. -template -struct PathSearchOnlyCompiledPaths { - static void Search(Path the_path, TrMulParams* params) { - PathSearchCountdown::Search(the_path, params); - } -}; - -template -struct PathSearchCountdown { - static constexpr Path kCurrentPath = static_cast(1 << BitNumber); - static void Search(Path the_path, TrMulParams* params) { - PathSearchOnlyCompiledPaths< - CompiledPaths, (CompiledPaths & kCurrentPath) != Path::kNone, BitNumber, - LhsScalar, RhsScalar, DstScalar, Spec>::Search(the_path, params); - } -}; - -// Termination of the countdown. If the counter reaches -1, then we haven't -// found the specified path. -template -struct PathSearchCountdown { - static void Search(Path the_path, TrMulParams* params) { RUY_DCHECK(false); } -}; - -template -void PopulateTrMulParamsAllCompiledPaths(Path the_path, TrMulParams* params) { - return PathSearchCountdown::Search(the_path, - params); -} - -template -void CreateTrMulParams(const Matrix& lhs, - const Matrix& rhs, const Spec& spec, - Context* context, Matrix* dst, Path the_path, - TrMulParams* params) { - // Fill in the fields we already know. - params->src[Side::kLhs] = ToDMatrix(lhs); - params->src[Side::kRhs] = ToDMatrix(rhs); - params->dst = ToDMatrix(*dst); - params->spec = ToVoidPtr(&spec); - - // Create inner loops and packed matrices based on the Path. - PopulateTrMulParamsAllCompiledPaths(the_path, params); -} - -template -void ReferenceMul(const Matrix& lhs, const Matrix& rhs, - const Spec& spec, Matrix* dst) { - profiler::ScopeLabel label("ReferenceMul"); - for (int i = 0; i < lhs.layout.rows; i++) { - for (int j = 0; j < rhs.layout.cols; j++) { - using AccumScalar = typename Spec::AccumScalar; - AccumScalar accum = 0; - for (int k = 0; k < lhs.layout.cols; k++) { - AccumScalar lhs_val = Element(lhs, i, k); - AccumScalar rhs_val = Element(rhs, k, j); - accum += (lhs_val - lhs.zero_point) * (rhs_val - rhs.zero_point); - } - if (spec.bias) { - accum += spec.bias[i]; - } - ApplyMultiplier(spec, i, &accum); - accum += dst->zero_point; - accum = std::min(accum, spec.clamp_max); - accum = std::max(accum, spec.clamp_min); - *ElementPtr(dst, i, j) = static_cast(accum); - } - } -} - -// Compile-time dispatch to ReferenceMul. This allows us to statically ensure -// that there is no call to ReferenceMul in the user's binary. -template -struct CompileTimeEnabledReferenceMul { - template - static void Run(const Matrix& lhs, const Matrix& rhs, - const Spec& spec, Matrix* dst) { - ReferenceMul(lhs, rhs, spec, dst); - } -}; - -// When this partial specialization is chosen, it ensures that ReferenceMul -// is never compiled. -template <> -struct CompileTimeEnabledReferenceMul { - template - static void Run(const Matrix& lhs, const Matrix& rhs, - const Spec& spec, Matrix* dst) { - RUY_DCHECK(false); - } -}; - -inline void HandlePrepackedCaching(TrMulParams* params, - const SidePair& cacheable, - Context* context) { - if (context->cache_policy == CachePolicy::kNoCache) { - return; - } - - if (context->cache_policy == CachePolicy::kCacheLHSOnNarrowMul) { - // TODO(b/149304278) Cache on dst.cols <= selected kernel width. - if (!cacheable[Side::kLhs] || params->dst.layout.cols > 4) { - return; - } - PrepackedCache* prepacked_cache = context->GetPrepackedCache(); - auto cache_key = std::make_pair(reinterpret_cast(params->run_kernel), - params->src[Side::kLhs].data); - auto it = prepacked_cache->FindAndUpdate(cache_key); - if (it != prepacked_cache->cend()) { - params->packed[Side::kLhs].data = it->second.first.data; - params->packed[Side::kLhs].sums = it->second.first.sums; - params->is_prepacked[Side::kLhs] = true; - return; - } - - // Allocate the prepacked matrix. - PrepackedMatrix prepacked_lhs; - prepacked_lhs.data_size = DataSize(params->packed[Side::kLhs]); - prepacked_lhs.sums_size = SumsSize(params->packed[Side::kLhs]); - prepacked_cache->AllocatePrepackedMatrix(&prepacked_lhs); - params->packed[Side::kLhs].data = prepacked_lhs.data; - params->packed[Side::kLhs].sums = prepacked_lhs.sums; - params->is_prepacked[Side::kLhs] = true; - Tuning tuning = context->GetMainThreadTuning(); - params->RunPack(Side::kLhs, tuning, 0, - params->packed[Side::kLhs].layout.cols); - prepacked_cache->Insert(cache_key, prepacked_lhs); - return; - } -} - -template -void DispatchMul(const Matrix& lhs, const Matrix& rhs, - const Spec& spec, Context* context, Matrix* dst) { - static_assert(CompiledPaths != Path::kNone, "Must compile at least one Path"); - static_assert((CompiledPaths & ~kAllPaths) == Path::kNone, - "CompiledPaths must be a subset of ruy::kAllPaths"); - - profiler::ScopeLabel mul_label("Mul"); - profiler::ScopeLabel shape_specific_label("matmul shape: %dx%dx%d", - lhs.layout.rows, lhs.layout.cols, - rhs.layout.cols); - - EnforceLayoutSupport(lhs.layout, rhs.layout, dst->layout); - EnforceZeroPointSupport(lhs.zero_point, rhs.zero_point, - dst->zero_point); - EnforceDstSpecSupport(spec, dst->zero_point); - - // This should be a constant, for a given machine and CompiledPaths. - // There is a back door to override it for testing, but in production it will - // always be the "best" Path. I.e. the one with the newest SIMD instructions - // available on the present machine, and avoiding Path::kReference unless - // no other path is compiled. - // - // Unfortunately, it is not a *static* constant, since it depends on runtime - // detection of the available SIMD instructions. - Path the_path = context->GetPathToTake(); - - // Production code should probably never execute Path::kReference. - // Path::kReference implements a Mul, not a TrMul like the rest of Ruy, so if - // that's what we need to do, then get it out of the way before going down the - // TrMul path. - if (the_path == Path::kReference) { - constexpr bool ReferenceMulIsEnabled = - (CompiledPaths & Path::kReference) != Path::kNone; - CompileTimeEnabledReferenceMul::Run(lhs, rhs, spec, - dst); - return; - } - - // As described in the comment at the top of this file, Ruy internally - // converts Mul into TrMul. We handle that here. - // - // This is Ruy's main code path. - constexpr Path TrMulCompiledPaths = CompiledPaths & ~Path::kReference; - Matrix transposed_lhs(lhs); - Transpose(&transposed_lhs); - TrMulParams params; - CreateTrMulParams(transposed_lhs, rhs, spec, context, dst, - the_path, ¶ms); - SidePair cacheable(lhs.cacheable, rhs.cacheable); - HandlePrepackedCaching(¶ms, cacheable, context); - TrMul(¶ms, context); -} - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DISPATCH_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/example.cc b/tensorflow/lite/experimental/ruy/ruy/example.cc deleted file mode 100644 index 5d31d6c2e3e..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/example.cc +++ /dev/null @@ -1,136 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/ruy.h" - -void ExampleMulFloat(ruy::Context *context) { - const float lhs_data[] = {1, 2, 3, 4}; - const float rhs_data[] = {1, 2, 3, 4}; - float dst_data[4]; - - ruy::Matrix lhs; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout); - lhs.data = lhs_data; - ruy::Matrix rhs; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &rhs.layout); - rhs.data = rhs_data; - ruy::Matrix dst; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &dst.layout); - dst.data = dst_data; - - ruy::BasicSpec spec; - ruy::Mul(lhs, rhs, spec, context, &dst); - - std::cout << "Example Mul, float:\n"; - std::cout << "LHS:\n" << lhs; - std::cout << "RHS:\n" << rhs; - std::cout << "Result:\n" << dst << "\n"; -} - -void ExampleMulFloatWithBiasAddAndClamp(ruy::Context *context) { - const float lhs_data[] = {1, 2, 3, 4}; - const float rhs_data[] = {1, 2, 3, 4}; - const float bias_data[] = {1, 0}; - float dst_data[4]; - - ruy::Matrix lhs; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout); - lhs.data = lhs_data; - ruy::Matrix rhs; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &rhs.layout); - rhs.data = rhs_data; - ruy::Matrix dst; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &dst.layout); - dst.data = dst_data; - - ruy::BasicSpec spec; - spec.bias = bias_data; - spec.clamp_min = 0; - spec.clamp_max = 15; - ruy::Mul(lhs, rhs, spec, context, &dst); - - std::cout << "Example Mul, float with bias addition and clamp:\n"; - std::cout << "LHS:\n" << lhs; - std::cout << "RHS:\n" << rhs; - std::cout << "Result:\n" << dst << "\n"; -} - -void ExampleMulUint8AsymmetricQuantized(ruy::Context *context) { - const std::uint8_t lhs_data[] = {124, 125, 126, 127}; - const std::uint8_t rhs_data[] = {129, 130, 131, 132}; - std::uint8_t dst_data[4]; - - ruy::Matrix lhs; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout); - lhs.data = lhs_data; - lhs.zero_point = 125; - ruy::Matrix rhs; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &rhs.layout); - rhs.data = rhs_data; - rhs.zero_point = 132; - ruy::Matrix dst; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &dst.layout); - dst.data = dst_data; - dst.zero_point = 129; - - ruy::BasicSpec spec; - spec.multiplier_fixedpoint = 1 << 30; - - spec.multiplier_exponent = 0; - ruy::Mul(lhs, rhs, spec, context, &dst); - - std::cout << "Example Mul, uint8 quantized with asymmetric zero points:\n"; - std::cout << "LHS:\n" << lhs; - std::cout << "RHS:\n" << rhs; - std::cout << "Result:\n" << dst << "\n"; -} -void ExampleMulInt8PerChannelQuantized(ruy::Context *context) { - const std::int8_t lhs_data[] = {1, 2, 3, 4}; - const std::int8_t rhs_data[] = {1, 2, 3, 4}; - const std::int32_t multiplier_data[] = {3 << 28, 5 << 28}; - const int exponent_data[] = {1, -2}; - std::int8_t dst_data[4]; - - ruy::Matrix lhs; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout); - lhs.data = lhs_data; - ruy::Matrix rhs; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &rhs.layout); - rhs.data = rhs_data; - ruy::Matrix dst; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &dst.layout); - dst.data = dst_data; - - ruy::BasicSpec spec; - spec.multiplier_fixedpoint_perchannel = multiplier_data; - spec.multiplier_exponent_perchannel = exponent_data; - ruy::Mul(lhs, rhs, spec, context, &dst); - - std::cout << "Example Mul, int8 quantized with per-channel multipliers\n"; - std::cout << "LHS:\n" << lhs; - std::cout << "RHS:\n" << rhs; - std::cout << "Result:\n" << dst << "\n"; -} - -int main() { - ruy::Context context; - ExampleMulFloat(&context); - ExampleMulFloatWithBiasAddAndClamp(&context); - ExampleMulUint8AsymmetricQuantized(&context); - ExampleMulInt8PerChannelQuantized(&context); -} diff --git a/tensorflow/lite/experimental/ruy/ruy/example_advanced.cc b/tensorflow/lite/experimental/ruy/ruy/example_advanced.cc deleted file mode 100644 index 9e1dd17f86d..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/example_advanced.cc +++ /dev/null @@ -1,83 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/ruy_advanced.h" - -// Simple allocator for allocating pre-packed matrices. -class SimpleAllocator { - public: - void* AllocateBytes(std::size_t num_bytes) { - char* p = new char[num_bytes]; - buffers_.emplace_back(p); - return static_cast(p); - } - - private: - std::vector> buffers_; -}; - -void ExamplePrepack(ruy::Context* context) { - const float lhs_data[] = {1, 2, 3, 4}; - const float rhs_data[] = {1, 2, 3, 4}; - float dst_data[4]; - - // Set up the matrix layouts and spec. - ruy::Matrix lhs; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout); - ruy::Matrix rhs; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &rhs.layout); - ruy::Matrix dst; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &dst.layout); - ruy::BasicSpec spec; - - SimpleAllocator allocator; - auto alloc_fn = [&allocator](std::size_t num_bytes) -> void* { - return allocator.AllocateBytes(num_bytes); - }; - - // In this example, we pre-pack only the RHS, but either will work. - // Note that we only need to set the data pointer for the matrix we are - // pre-packing. - ruy::PrepackedMatrix prepacked_rhs; - rhs.data = rhs_data; - ruy::PrePackForMul(lhs, rhs, spec, context, &dst, - /*prepacked_lhs=*/nullptr, &prepacked_rhs, - alloc_fn); - - // No data will be read from the RHS input matrix when using a pre-packed RHS. - rhs.data = nullptr; - lhs.data = lhs_data; - dst.data = dst_data; - ruy::MulWithPrepacked(lhs, rhs, spec, context, &dst, - /*prepacked_lhs=*/nullptr, - &prepacked_rhs); - rhs.data = rhs_data; - - // Print out the results. - std::cout << "Example Mul with pre-packing RHS, float:\n"; - std::cout << "LHS:\n" << lhs; - std::cout << "RHS:\n" << rhs; - std::cout << "Result:\n" << dst << "\n"; -} - -int main() { - ruy::Context context; - ExamplePrepack(&context); -} diff --git a/tensorflow/lite/experimental/ruy/ruy/have_built_path_for_avx2.cc b/tensorflow/lite/experimental/ruy/ruy/have_built_path_for_avx2.cc deleted file mode 100644 index a9bcfbbbcfb..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/have_built_path_for_avx2.cc +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/have_built_path_for.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" - -namespace ruy { - -#if RUY_PLATFORM(X86) -// IMPORTANT: -// These patterns must match those in the pack and kernel cc files. -#if !(RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM)) - -bool HaveBuiltPathForAvx2() { return false; } - -#else // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -bool HaveBuiltPathForAvx2() { return true; } - -#endif // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM) -#endif // RUY_PLATFORM(X86) - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/have_built_path_for_avx512.cc b/tensorflow/lite/experimental/ruy/ruy/have_built_path_for_avx512.cc deleted file mode 100644 index 2b42cba26c9..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/have_built_path_for_avx512.cc +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/have_built_path_for.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" - -namespace ruy { - -#if RUY_PLATFORM(X86) -// IMPORTANT: -// These patterns must match those in the pack and kernel cc files. -#if !(RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)) - -bool HaveBuiltPathForAvx512() { return false; } - -#else // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -bool HaveBuiltPathForAvx512() { return true; } - -#endif // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM) -#endif // RUY_PLATFORM(X86) - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/have_built_path_for_avxvnni.cc b/tensorflow/lite/experimental/ruy/ruy/have_built_path_for_avxvnni.cc deleted file mode 100644 index 42f9cb668df..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/have_built_path_for_avxvnni.cc +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/have_built_path_for.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" - -namespace ruy { - -#if RUY_PLATFORM(X86) -// IMPORTANT: -// These patterns must match those in the pack and kernel cc files. -#if !(RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM)) - -bool HaveBuiltPathForAvxVnni() { return false; } - -#else // RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -bool HaveBuiltPathForAvxVnni() { return true; } - -#endif // RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM) -#endif // RUY_PLATFORM(X86) - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/have_built_path_for_sse42.cc b/tensorflow/lite/experimental/ruy/ruy/have_built_path_for_sse42.cc deleted file mode 100644 index e7470f54520..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/have_built_path_for_sse42.cc +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/have_built_path_for.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" - -namespace ruy { - -#if RUY_PLATFORM(X86) -// IMPORTANT: -// These patterns must match those in the pack and kernel cc files. -#if !(RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM)) - -bool HaveBuiltPathForSse42() { return false; } - -#else // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -bool HaveBuiltPathForSse42() { return true; } - -#endif // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM) -#endif // RUY_PLATFORM(X86) - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/internal_matrix.h b/tensorflow/lite/experimental/ruy/ruy/internal_matrix.h deleted file mode 100644 index cf10adf084d..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/internal_matrix.h +++ /dev/null @@ -1,388 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Internal types and helpers for matrices. -// -// Ruy has a couple slightly different notions of matrices, besides the -// Matrix class that we expose to the user-facing API. -// -// TODO(silvasean): Put parts of this architecture description somewhere more -// prominent. -// -// The 4 main matrix types are: -// - Matrix: This is a user-facing type on Ruy's external API boundary. It is -// also used internally. -// - DMatrix: This is a type-erased version of Matrix. "D" = "dynamic". -// - PMatrix: This represents a packed matrix, which requires tracking kernel -// layout and row/column sums for quantization. It is type-erased. -// - PackedMatrix: This is a statically typed variant of PMatrix for -// convenience inside typed routines. -// -// Note that Matrix is *not* implemented in terms of the internal types. It -// is an independent, simple, and user-facing type. -// -// The use of type-erasure might seem surprising for a library like Ruy with a -// heavily-templated entry point, but it is motivated by the desire for most of -// Ruy's "middle-end" to be non-templated. Ruy can be thought of as having 3 -// main parts: -// - "front-end" (dispatch.h) - this is the highly templated ruy::Mul entry -// point, along with routines that select RunKernel and RunPack implementations -// statically based on those template parameters. -// - "back-end" (kernel.h, pack.h)- this consists of the implementations of -// RunKernel and RunPack, often in assembly code, which are the building blocks -// that Ruy calls to perform matrix multiplication. These are templated so that -// only the requested types/Path's are actually emitted by the compiler. -// - "middle-end" (trmul.h) - this is the part of Ruy that orchestrates the -// calls to the "back-end" optimized building blocks. This layer has to deal -// with issues like cache locality and low-overhead multi-threading. -// -// There is a desire for the "middle-end" to be non-templated in order to -// simplify the implementation and reduce code-size. We type-erase when going -// from the "front-end" to the "middle-end", and un-type-erase going from the -// "middle-end" to the "back-end". The un-type-erasure is possible because the -// "front-end" is responsible for instantiating the needed "back-end" templates, -// and thus the static type information is still present. -// -// Each layer of Ruy uses matrix types: -// - "front-end": Matrix -// - "middle-end": DMatrix, PMatrix -// - "back-end": Matrix, PackedMatrix -// -// The use of separate types for packed matrices is not essential, but makes it -// obvious at a glance whether a matrix is a packed matrix or not. We would -// reconsider this decision if there was significant duplication between packed -// and unpacked matrices, but that doesn't seem to be the case at the moment. -// -// Another goal is to keep the user-facing Matrix as simple and -// understandable as possible. Ideally, a user should be able to read the struct -// definition for Matrix and see a very simple definition with no internal -// details like sums and kernel block layout. -// -// To present another structured view of our various matrix types, here's a -// table: -// Plain matrices Packed matrices -// +---------------------------------- -// Templated | Matrix PackedMatrix -// Type-erased | DMatrix PMatrix -// -// -// There is 1 additional matrix type not mentioned above, due to its low -// importance: -// - PrepackedMatrix: This is a user-facing version of PMatrix. It has the bare -// minimum of fields needed for representing the raw data and sums buffers of a -// packed matrix for the "advanced" explicit pre-packing API. This type plays no -// role in Ruy's internals and can generally by ignored. The only reason it -// exists is so that PMatrix is not exposed to users -- we prefer to keep the -// internal matrix types hidden, even from "advanced" users. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_INTERNAL_MATRIX_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_INTERNAL_MATRIX_H_ - -#include -#include -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/common.h" -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/size_util.h" - -namespace ruy { - -// KernelLayout describes small-scale block structure in a packed matrix layout. -// It's a runtime (as opposed to compile-time-constant) version of the -// FixedKernelLayout struct used to declare kernel layouts. -// -// This is is sometimes known as "tiling" in other contexts. -// -// For example, consider a packed matrix in column-major format with a -// column-major KernelLayout. The matrix logically has a shape of -// `[cols, rows]`. However, the matrix is laid out as though it were a 4D array -// of shape `[cols / kcols, rows / krows, kcols, krows]`. -// -// Note that in the case of kcols=1, krows=1, this degenerates to -// `[cols, rows, 1, 1]` which is equivalent to having no small-scale block -// structure. -struct KernelLayout { - Order order = Order::kColMajor; - std::uint8_t rows = 1; - std::uint8_t cols = 1; -}; - -// A packed matrix has a small-scale block structure that is not present in in -// the input matrices. This block structure is necessary for the kernels to -// process data efficiently. -// -// This struct is very similar to Layout, but has the extra KernelLayout field. -struct PackedLayout { - std::int32_t rows = 0; - std::int32_t cols = 0; - // Stride is the offset between two adjacent matrix elements - // in the non-contiguous direction. - std::int32_t stride = 0; - Order order = Order::kColMajor; - // Small scale layout shuffling, potentially departing from - // linear row-major or column-major storage. See KernelLayout. - KernelLayout kernel; -}; - -// Dynamic representation for a type. -// -// The most important field in this struct is the size, which Ruy uses to know -// how much memory to allocate without having to be templated on a type. -// Signed-ness and floating-point-ness are mainly present as debugging checks. -// -// Note: Ruy does not use this struct to to dynamically dispatch between -// different typed implementations. As described in the comment at the top of -// this file, Ruy's "front-end", which is templated, instantiates all the -// necessary "back-end" routines with complete static knowledge of all the -// types. -struct Type { - template - static Type Create() { - Type ret; - ret.is_signed = std::is_signed::value; - ret.is_floating_point = std::is_floating_point::value; - ret.size = sizeof(T); - return ret; - } - - template - void AssertIs() const { - RUY_DCHECK_EQ(is_signed, Create().is_signed); - RUY_DCHECK_EQ(is_floating_point, Create().is_floating_point); - RUY_DCHECK_EQ(size, Create().size); - } - - bool is_signed = false; - bool is_floating_point = false; - std::uint8_t size = 0; -}; - -// Type-erased matrix. -struct DMatrix { - Type data_type; - void* data = nullptr; - Layout layout; - std::int32_t zero_point = 0; -}; - -// Type-erased packed matrix. -struct PMatrix { - Type data_type; - void* data = nullptr; - Type sums_type; - void* sums = nullptr; - PackedLayout layout; - std::int32_t zero_point = 0; -}; - -// Convenient typed helper for packed matrices. -template -struct PackedMatrix { - // The row/column sums needed for quantized matrix multiplication when - // the opposite operand of the multiplication uses a non-symmetric zero - // point. - // This member is only relevant for packed matrices. - // Additionally, Ruy always uses 32-bit signed accumulators for quantized - // matrix multiplication. - // For floating point types, there is no quantization, so this pointer - // will always be null. We still need code referencing it to compile - // though, even if it is always branched around. Hence we use Scalar* - // itself as the type in that case. - using SumsType = - typename std::conditional::value, Scalar, - std::int32_t>::type; - - Scalar* data = nullptr; - SumsType* sums = nullptr; - PackedLayout layout; - std::int32_t zero_point = 0; -}; - -template -DMatrix ToDMatrix(const Matrix& matrix) { - DMatrix ret; - ret.data_type = Type::Create(); - ret.data = ToVoidPtr(matrix.data.get()); - ret.layout = matrix.layout; - ret.zero_point = matrix.zero_point; - return ret; -} - -template -Matrix ToMatrix(const DMatrix& dmatrix) { - dmatrix.data_type.AssertIs(); - Matrix ret; - ret.data = static_cast(dmatrix.data); - ret.layout = dmatrix.layout; - ret.zero_point = dmatrix.zero_point; - return ret; -} - -template -PackedMatrix ToPackedMatrix(const PMatrix& pmatrix) { - using SumsType = typename PackedMatrix::SumsType; - pmatrix.data_type.AssertIs(); - pmatrix.sums_type.AssertIs(); - PackedMatrix ret; - ret.data = static_cast(pmatrix.data); - ret.sums = static_cast(pmatrix.sums); - ret.layout = pmatrix.layout; - ret.zero_point = pmatrix.zero_point; - return ret; -} - -// Helpers for Layout / PackedLayout. - -inline bool IsPacked(const Layout& layout) { - if (layout.order == Order::kColMajor) { - return layout.stride == layout.rows; - } else { - return layout.stride == layout.cols; - } -} - -inline bool IsRowMajor(const Layout& layout) { - return layout.order == Order::kRowMajor; -} - -template -inline bool IsColMajor(const LayoutOrPackedLayout& layout) { - return layout.order == Order::kColMajor; -} - -template -inline int FlatSize(const LayoutOrPackedLayout& layout) { - const int outerdim = - layout.order == Order::kColMajor ? layout.cols : layout.rows; - return layout.stride * outerdim; -} - -// TODO(b/130417400) add a unit test -inline int Offset(const Layout& layout, int row, int col) { - // TODO(benoitjacob) - should check this but this make the _slow tests take - // 5x longer. Find a mitigation like in Eigen with an 'internal' variant - // bypassing the check? - // RUY_DCHECK_GE(row, 0); - // RUY_DCHECK_GE(col, 0); - // RUY_DCHECK_LT(row, layout.rows); - // RUY_DCHECK_LT(col, layout.cols); - int row_stride = layout.order == Order::kColMajor ? 1 : layout.stride; - int col_stride = layout.order == Order::kRowMajor ? 1 : layout.stride; - return row * row_stride + col * col_stride; -} - -// TODO(b/130417400) add a unit test -inline int Offset(const PackedLayout& layout, int row, int col) { - RUY_DCHECK(is_pot(layout.kernel.rows)); - RUY_DCHECK(is_pot(layout.kernel.cols)); - int row_outer = row & ~(layout.kernel.rows - 1); - int col_outer = col & ~(layout.kernel.cols - 1); - int row_stride_outer = - layout.order == Order::kColMajor ? layout.kernel.cols : layout.stride; - int col_stride_outer = - layout.order == Order::kRowMajor ? layout.kernel.rows : layout.stride; - int offset_outer = - row_outer * row_stride_outer + col_outer * col_stride_outer; - int row_inner = row - row_outer; - int col_inner = col - col_outer; - int row_stride_inner = - layout.kernel.order == Order::kColMajor ? 1 : layout.kernel.cols; - int col_stride_inner = - layout.kernel.order == Order::kRowMajor ? 1 : layout.kernel.rows; - int offset_inner = - row_inner * row_stride_inner + col_inner * col_stride_inner; - return offset_outer + offset_inner; -} - -// Helpers for Matrix. - -template -const Scalar* ElementPtr(const Matrix& mat, int row, int col) { - return mat.data.get() + Offset(mat.layout, row, col); -} - -template -Scalar* ElementPtr(Matrix* mat, int row, int col) { - return mat->data.get() + Offset(mat->layout, row, col); -} - -template -Scalar Element(const Matrix& mat, int row, int col) { - return *ElementPtr(mat, row, col); -} - -// Helpers for PackedMatrix. -// Duplicated from Matrix, but the duplication seems acceptable. - -template -const Scalar* ElementPtr(const PackedMatrix& mat, int row, int col) { - return mat.data + Offset(mat.layout, row, col); -} - -template -Scalar* ElementPtr(PackedMatrix* mat, int row, int col) { - return mat->data + Offset(mat->layout, row, col); -} - -template -Scalar Element(const PackedMatrix& mat, int row, int col) { - return *ElementPtr(mat, row, col); -} - -// Helpers for PMatrix. - -inline std::size_t DataSize(const PMatrix& packed) { - return FlatSize(packed.layout) * packed.data_type.size; -} - -inline std::size_t SumsSize(const PMatrix& packed) { - // Packed matrices are only relevant for Ruy's TrMul implementations. For - // TrMul, the number of sums is always equal to the number of columns. - return packed.layout.cols * packed.sums_type.size; -} - -// Transpose helpers. - -inline void Transpose(Order* order) { - *order = *order == Order::kColMajor ? Order::kRowMajor : Order::kColMajor; -} - -inline void Transpose(Layout* layout) { - Transpose(&layout->order); - std::swap(layout->rows, layout->cols); -} - -template -inline void Transpose(Matrix* matrix) { - Transpose(&matrix->layout); -} - -// Helpers for KernelLayout. - -template -KernelLayout ToKernelLayout() { - KernelLayout ret; - ret.order = FixedKernelLayout::kOrder; - ret.rows = FixedKernelLayout::kRows; - ret.cols = FixedKernelLayout::kCols; - return ret; -} - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_INTERNAL_MATRIX_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/kernel_arm.h b/tensorflow/lite/experimental/ruy/ruy/kernel_arm.h deleted file mode 100644 index 760f0f0b4b5..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/kernel_arm.h +++ /dev/null @@ -1,211 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_ARM_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_ARM_H_ - -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/common.h" -#include "tensorflow/lite/experimental/ruy/ruy/internal_matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/kernel_common.h" -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" -#include "tensorflow/lite/experimental/ruy/ruy/side_pair.h" -#include "tensorflow/lite/experimental/ruy/ruy/size_util.h" -#include "tensorflow/lite/experimental/ruy/ruy/spec.h" -#include "tensorflow/lite/experimental/ruy/ruy/tune.h" - -namespace ruy { - -#if RUY_PLATFORM(NEON) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -#if RUY_PLATFORM(NEON_64) -void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 4>& params); -void Kernel8bitNeonOutOfOrder1Col(const KernelParams8bit<4, 4>& params); -#elif RUY_PLATFORM(NEON_32) -void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 2>& params); -void Kernel8bitNeonOutOfOrder1Col(const KernelParams8bit<4, 2>& params); -#endif -void Kernel8bitNeonInOrder(const KernelParams8bit<4, 4>& params); -void Kernel8bitNeonDotprodOutOfOrder(const KernelParams8bit<8, 8>& params); -void Kernel8bitNeonDotprodOutOfOrder1Col(const KernelParams8bit<8, 8>& params); -void Kernel8bitNeonDotprodInOrder(const KernelParams8bit<8, 8>& params); - -#if RUY_PLATFORM(NEON_64) -template -struct Kernel> { - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - Tuning tuning = Tuning::kAuto; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, - const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, - int start_col, int end_row, int end_col, - Matrix* dst) const { - KernelParams8bit params; - MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col, - dst, ¶ms); - if (dst->layout.cols == 1) { - Kernel8bitNeonOutOfOrder1Col(params); - return; - } - if (__builtin_expect(tuning == Tuning::kInOrder, true)) { - Kernel8bitNeonInOrder(params); - } else { - Kernel8bitNeonOutOfOrder(params); - } - } -}; -#endif - -#if RUY_PLATFORM(NEON_32) -template -struct Kernel> { - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - Tuning tuning = Tuning::kAuto; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, - const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, - int start_col, int end_row, int end_col, - Matrix* dst) const { - KernelParams8bit params; - MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col, - dst, ¶ms); - if (dst->layout.cols == 1) { - Kernel8bitNeonOutOfOrder1Col(params); - return; - } - Kernel8bitNeonOutOfOrder(params); - } -}; -#endif - -#if RUY_PLATFORM(NEON_64) -template -struct Kernel> { - Tuning tuning = Tuning::kAuto; - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, - const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, - int start_col, int end_row, int end_col, - Matrix* dst) const { - KernelParams8bit params; - MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col, - dst, ¶ms); - if (dst->layout.cols == 1) { - Kernel8bitNeonDotprodOutOfOrder1Col(params); - } else if (__builtin_expect(tuning == Tuning::kInOrder, true)) { - Kernel8bitNeonDotprodInOrder(params); - } else { - Kernel8bitNeonDotprodOutOfOrder(params); - } - } -}; -#endif - -void KernelFloatNeonOutOfOrder(const KernelParamsFloat<8, 8>& params); -void KernelFloatNeonInOrder(const KernelParamsFloat<8, 8>& params); -void KernelFloat32NeonOutOfOrder(const KernelParamsFloat<8, 4>& params); -void KernelFloatNeonDotprodInOrder(const KernelParamsFloat<8, 8>& params); - -#if RUY_PLATFORM(NEON_64) -// A Float kernel for ARM64 Neon. -template <> -struct Kernel> { - Tuning tuning = Tuning::kAuto; - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, int start_col, - int end_row, int end_col, Matrix* dst) const { - KernelParamsFloat params; - MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row, - end_col, dst, ¶ms); - if (__builtin_expect(tuning == Tuning::kInOrder, true)) { - KernelFloatNeonInOrder(params); - } else { - KernelFloatNeonOutOfOrder(params); - } - } -}; -#endif - -#if RUY_PLATFORM(NEON_32) -// A Float kernel for ARM32 Neon. -template <> -struct Kernel> { - Tuning tuning = Tuning::kAuto; - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, int start_col, - int end_row, int end_col, Matrix* dst) const { - KernelParamsFloat<8, 4> params; - - MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row, - end_col, dst, ¶ms); - - KernelFloat32NeonOutOfOrder(params); - } -}; -#endif - -// While the dotprod NEON extension does not concern floating-point arithmetic, -// its presence allows us to distinguish, in the in-order tuning case, between -// A53 and A55r1. TODO: should this be folded into tuning? -template <> -struct Kernel> { - Tuning tuning = Tuning::kAuto; - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - using Base = - Kernel>; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, int start_col, - int end_row, int end_col, Matrix* dst) const { - KernelParamsFloat params; - MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row, - end_col, dst, ¶ms); - if (__builtin_expect(tuning == Tuning::kInOrder, true)) { - KernelFloatNeonDotprodInOrder(params); - } else { - KernelFloatNeonOutOfOrder(params); - } - } -}; - -#endif // RUY_PLATFORM(NEON) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_ARM_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/kernel_arm32.cc b/tensorflow/lite/experimental/ruy/ruy/kernel_arm32.cc deleted file mode 100644 index 673f2616f02..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/kernel_arm32.cc +++ /dev/null @@ -1,2499 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/kernel.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" - -namespace ruy { - -#if RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -#define RUY_ASM_LABEL_STORE_UINT8 91 -#define RUY_ASM_LABEL_STORE_INT8 92 -#define RUY_ASM_LABEL_STORE_INT16 93 -#define RUY_ASM_LABEL_STORE_INT32 94 -#define RUY_ASM_LABEL_AFTER_STORE 99 - -#define RUY_OFFSET_LHS_BASE_PTR 0 -#define RUY_OFFSET_RHS_BASE_PTR 4 -#define RUY_OFFSET_DST_BASE_PTR 8 -#define RUY_OFFSET_BIAS 12 -#define RUY_OFFSET_START_ROW 16 -#define RUY_OFFSET_START_COL 20 -#define RUY_OFFSET_LAST_ROW 24 -#define RUY_OFFSET_LAST_COL 28 -#define RUY_OFFSET_DST_ROWS 32 -#define RUY_OFFSET_DST_COLS 36 -#define RUY_OFFSET_LHS_STRIDE 40 -#define RUY_OFFSET_RHS_STRIDE 44 -#define RUY_OFFSET_DST_STRIDE 48 -#define RUY_OFFSET_DEPTH 52 -#define RUY_OFFSET_CLAMP_MIN 56 -#define RUY_OFFSET_CLAMP_MAX 60 -#define RUY_OFFSET_FLAGS 64 - -#define RUY_STACK_OFFSET_SIZE 96 -#define RUY_STACK_OFFSET_DST_COL_PTR 0 -#define RUY_STACK_OFFSET_DST_PTR 16 -#define RUY_STACK_OFFSET_ROW 32 -#define RUY_STACK_OFFSET_COL 48 -#define RUY_STACK_OFFSET_LHS_COL_PTR 64 -#define RUY_STACK_OFFSET_RHS_COL_PTR 80 - -template -void CheckOffsetsInKernelParamsFloat32(const Params&) { - static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, ""); - static_assert(offsetof(Params, rhs_base_ptr) == RUY_OFFSET_RHS_BASE_PTR, ""); - static_assert(offsetof(Params, dst_base_ptr) == RUY_OFFSET_DST_BASE_PTR, ""); - static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, ""); - static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, ""); - static_assert(offsetof(Params, start_col) == RUY_OFFSET_START_COL, ""); - static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, ""); - static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, ""); - static_assert(offsetof(Params, dst_rows) == RUY_OFFSET_DST_ROWS, ""); - static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, ""); - static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, ""); - static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, ""); - static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, ""); - static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, ""); - static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, ""); - static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, ""); -} - -// Float kernel for ARM32 out-of-order cores. -// Just like Float 64 version, except accumulate in to 8x4 block to only -// use 16 128-bit NEON registers. This is a "first pass" kernel and not -// tuned. It is meant to run on out-of-order CPUs like the Krait 400 or A9. -void KernelFloat32NeonOutOfOrder(const KernelParamsFloat<8, 4>& params) { - CheckOffsetsInKernelParamsFloat32(params); - profiler::ScopeLabel label( - "Kernel (kNeon, optimized for out-of-order cores)"); - - const float* lhs_ptr = params.lhs_base_ptr; - const float* rhs_ptr = params.rhs_base_ptr; - // In ARM32 NEON, there are 16 128-bit "q" registers. These registers are - // each composed of two 64-bit "d" registers. The asm kernel below has the - // following NEON register allocation: - // Registers q3 -- q10 are accumulators. During accumulation, - // q0 -- q2 (d0 -- d5) are used to load data from LHS and RHS. q0 and q1 - // are used to load a 8x1 block of LHS, and q2 is used to load a 1x4 block - // of RHS, like this: - - // Register layout in "q" registers: - // RHS 1x4 block - // /--------------------------\ - // |q2.s[0] ... q2.s[3] | - // \--------------------------/ - // LHS 8x1 block - // /---------------------\ /--------------------- \ - // | q0.s[0] | | q3.s[0] ... q9.s[0] | - // | ... | | ... ... | - // | q0.s[3] | | q3.s[3] q9.s[3] | - // | q1.s[0] | | q4.s[0] q10.s[0] | - // | ... | | ... ... ... | - // | q1.s[3] | | q4.s[3] .. q10.s[3] | - // \---------------------/ \--------------------------/ - // accumulators 8x4 block - // q11, q14, q15 currently unused. q12 and q13 are used to load - // parameters used for the post-accumulation part of the kernel. - // For completeness, here is the register layout in "d" registers: - // RHS 1x4 block - // /--------------------------\ - // |d4[0] ... d5[1] | - // \--------------------------/ - // LHS 8x1 block - // /---------------------\ /--------------------------\ - // | d0[0] | | d6[0] ... d18[0] | - // | ... | | ... ... | - // | d1[1] | | d7[1] d19[1] | - // | d2[0] | | d8[0] d20[0] | - // | ... | | ... ... ... | - // | d3[1] | | d9[1] ... d21[1] | - // \---------------------/ \--------------------------/ - // accumulators 8x4 block - asm volatile( -#define RUY_MAKE_ZERO(reg) "vmov.f32 " #reg ", #0.0\n" - - // clang-format off - - // Load the first 32 bytes of LHS and RHS data. - // Load q0, q1 - "vld1.32 {d0, d1}, [%[lhs_ptr]]!\n" - "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n" - RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") - // Load q2 - "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n" - RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") - - "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" - "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n" - "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n" - "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" - // Clear accumulators. - RUY_MAKE_ZERO(q3) - RUY_MAKE_ZERO(q4) - RUY_MAKE_ZERO(q5) - RUY_MAKE_ZERO(q6) - RUY_MAKE_ZERO(q7) - RUY_MAKE_ZERO(q8) - RUY_MAKE_ZERO(q9) - RUY_MAKE_ZERO(q10) - - // r1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 1. - "mov r1, #1\n" - - // Main loop of the whole GEMM, over rows and columns of the - // destination matrix. - "1:\n" - - // Accumulation loop - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" - "cmp r1, r2\n" - "beq 79f\n" - - "2:\n" - - "vmla.f32 q3, q0, d4[0]\n" - "vmla.f32 q5, q0, d4[1]\n" - "vmla.f32 q7, q0, d5[0]\n" - "vmla.f32 q9, q0, d5[1]\n" - "vld1.32 {d0, d1}, [%[lhs_ptr]]!\n" // Reload LHS - - "vmla.f32 q4, q1, d4[0]\n" - "vmla.f32 q6, q1, d4[1]\n" - "vmla.f32 q8, q1, d5[0]\n" - "vmla.f32 q10, q1, d5[1]\n" - "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS - RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") - "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n" // Reload RHS - RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") - - "add r1, r1, #1\n" - "cmp r1, r2\n" - - "blt 2b\n" - - "79:\n" - - // End of the inner loop on depth. Now perform the remaining - // multiply-adds of the last level of depth, for which the LHS - // and RHS data is already loaded. - - "vmla.f32 q3, q0, d4[0]\n" - "vmla.f32 q5, q0, d4[1]\n" - "vmla.f32 q7, q0, d5[0]\n" - "vmla.f32 q9, q0, d5[1]\n" - - "vmla.f32 q4, q1, d4[0]\n" - "vmla.f32 q6, q1, d4[1]\n" - "vmla.f32 q8, q1, d5[0]\n" - "vmla.f32 q10, q1, d5[1]\n" - - // End of accumulation. The registers q3 -- q10 contain the final - // float32 accumulator values of the current 8x8 destination block. - // We now have to compute the final values from these accumulators - // and advance to the next 8x8 block. We intertwine - // these two aspects whenever possible for optimal pipelining, both - // at the data flow level (prefetch data for next block as early as - // possible) and instruction pipelining level (some of the next-block - // work can dual-issue with some of the final work on the current - // block). - - // Logic to advance to the next block in preparation for the next - // iteration of the main loop. For now, we only want to compute - // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are - // not yet ready to update the values of row and col, as we still need - // the current values for the rest of the work on the current block. - - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "cmp r1, r3\n" // Have we finished the last row? - - "bge 4f\n" // If finished last row, go to 4 - // Not finished last row: then advance to next row. - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - "add r4, r4, r1, lsl #3\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - "b 5f\n" - "4:\n" // Finished last row... - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - // Go back to first row - "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - // Now we need to advance to the next column. If we already - // finished the last column, then in principle we are done, however - // we can't just return here, as we need to allow the end work of the - // current block to complete. The good news is that at this point it - // doesn't matter what data we load for the next column, since - // we will exit from the main loop below before actually storing - // anything computed from that data. - "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "cmp r8, r4\n" // Have we finished the last column? - "bge 5f\n" // If yes, just carry on without updating the column pointer. - // Not finished last column: then advance to next column. - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" - "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" - "add r10, r10, r1, lsl #2\n" - "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" - "5:\n" - - // Set the LHS and RHS data pointers to the start of the columns just - // computed. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - "mov %[lhs_ptr], r4\n" - "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" - "mov %[rhs_ptr], r5\n" - - // Load some parameters needed for the end work on current block. - "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" - - // Offset these base pointers as needed given the current row, col. - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "add r5, r1, r8, lsl #2\n" - - "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" - "it ne\n" - "movne r1, r5\n" - - // Load 8 bias values. - "vld1.32 {d24, d25, d26, d27}, [r1]\n" - - // Now that we know what LHS and RHS data the next iteration of the - // main loop will need to load, we start loading the first 32 bytes of - // each of LHS and RHS, into q0 -- q2, as we don't need q0 -- q2 anymore - // in the rest of the work on the current block. - // Load q0, q1 - "vld1.32 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" - RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") - // Load q2 - "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n" - RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") - - // Perform the bias-addition (per the above, we have just folded into - // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) - "vadd.f32 q3, q3, q12\n" - "vadd.f32 q4, q4, q13\n" - "vadd.f32 q5, q5, q12\n" - "vadd.f32 q6, q6, q13\n" - "vadd.f32 q7, q7, q12\n" - "vadd.f32 q8, q8, q13\n" - "vadd.f32 q9, q9, q12\n" - "vadd.f32 q10, q10, q13\n" - - // Load the clamp_min, clamp_max bounds - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "vdup.32 q12, r2\n" // clamp_min - "vdup.32 q13, r3\n" // clamp_max - - // Apply the clamp_min bound - "vmax.f32 q3, q3, q12\n" - "vmax.f32 q4, q4, q12\n" - "vmax.f32 q5, q5, q12\n" - "vmax.f32 q6, q6, q12\n" - "vmax.f32 q7, q7, q12\n" - "vmax.f32 q8, q8, q12\n" - "vmax.f32 q9, q9, q12\n" - "vmax.f32 q10, q10, q12\n" - - // Apply the clamp_max bound - "vmin.f32 q3, q3, q13\n" - "vmin.f32 q4, q4, q13\n" - "vmin.f32 q5, q5, q13\n" - "vmin.f32 q6, q6, q13\n" - "vmin.f32 q7, q7, q13\n" - "vmin.f32 q8, q8, q13\n" - "vmin.f32 q9, q9, q13\n" - "vmin.f32 q10, q10, q13\n" - - // Compute how much of the 8x4 block of destination values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x4, there are some 8x8 blocks along the boundaries that do - // not fit entirely. - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "sub r1, r1, r8\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "sub r2, r2, r4\n" - "mov r3, #8\n" - "mov r5, #4\n" - "cmp r1, #8\n" - // Compute r1 = how many rows of the 8x4 block fit - "it gt\n" - "movgt r1, r3\n" - "cmp r2, #4\n" - // Compute r2 = how many cols of the 8x4 block fit - "it gt\n" - "movgt r2, r5\n" - - // Test if r1==8 && r2 == 4, i.e. if all of the 8x4 block fits. - "cmp r1, r3\n" - "it eq\n" - "cmpeq r2, r5\n" - // Yes, all of the 8x4 block fits, go to fast path. - "beq 30f\n" - // Not all of the 8x4 block fits. - // Set (r3 address, r4 stride) to write to dst_tmp_buf - "mov r3, %[dst_tmp_buf]\n" - "mov r4, #32\n" - "b 31f\n" - "30:\n" - // Yes, all of the 8x4 block fits. - // Set (r3 address, r4 stride) to write directly to destination matrix. - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "mov r4, r5\n" - "31:\n" - - // Write our float values to the destination described by - // (r3 address, r4 stride) - "vst1.32 {d6, d7, d8, d9}, [r3]\n" - "add r3, r3, r4\n" - RUY_MAKE_ZERO(q3) - RUY_MAKE_ZERO(q4) - "vst1.32 {d10, d11, d12, d13}, [r3]\n" - "add r3, r3, r4\n" - RUY_MAKE_ZERO(q5) - RUY_MAKE_ZERO(q6) - "vst1.32 {d14, d15, d16, d17}, [r3]\n" - "add r3, r3, r4\n" - RUY_MAKE_ZERO(q7) - RUY_MAKE_ZERO(q8) - "vst1.32 {d18, d19, d20, d21}, [r3]\n" - "add r3, r3, r4\n" - RUY_MAKE_ZERO(q9) - RUY_MAKE_ZERO(q10) - - // If all of the 8x4 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 41f\n" - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "mov r3, %[dst_tmp_buf]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "mov r6, #0\n" - "50:\n" - "mov r5, #0\n" - "51:\n" - "ldr r10, [r3, r5, lsl #2]\n" - "str r10, [r4, r5, lsl #2]\n" - "add r5, r5, #1\n" - "cmp r5, r1\n" - "blt 51b\n" - "add r6, r6, #1\n" - "add r3, r3, #32\n" - "add r4, r4, r8\n" - // r2 = how many cols of the 8x4 block fit - "cmp r6, r2\n" - "blt 50b\n" - "41:\n" - // Load dst_ptr, increment, and write back. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "add r4, r4, #32\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - // Reload some params --- we had used r3, r5, r10 for a few other things - // since the last time we had loaded them. - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - - // Move to the next block of the destination matrix, for the next iter - // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already - // been updated earlier. - // Have we reached the end row? - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "cmp r8, r3\n" - - "beq 20f\n" // yes, end row. - // Not end row. Move to the next row. - "add r8, r8, #8\n" - // Store new value of row - "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - - "b 21f\n" - "20:\n" - // Was already at end row. - // Move back to first row. - "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - // Move to the next column. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "add r4, r4, #4\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - - "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" - // Increment dst_col_ptr by 4 * dst_stride (i.e. 4 columns) - "add r1, r1, r8, lsl #2\n" - // Store dst_col_ptr - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" - // Store dst_ptr - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "21:\n" - - // Main loop exit condition: have we hit the end column? - "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "cmp r8, r4\n" - - // r1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 1. - "mov r1, #1\n" - - "ble 1b\n" - - // Restore stack pointer. - "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" - - // clang-format on - : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr) - : [ params ] "r"(¶ms), [dst_tmp_buf] "r"(params.dst_tmp_buf) - // Clobber list must specify q registers (and not their constituent - // d registers). There is a (currently unexplained) slowdown if - // d registers are listed in the clobbers list. - : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc", - "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", - "q9", "q10", "q12", "q13"); -} - -#undef RUY_MAKE_ZERO -#undef RUY_STACK_OFFSET_SIZE -#undef RUY_STACK_OFFSET_DST_COL_PTR -#undef RUY_STACK_OFFSET_DST_PTR -#undef RUY_STACK_OFFSET_ROW -#undef RUY_STACK_OFFSET_COL -#undef RUY_STACK_OFFSET_LHS_COL_PTR -#undef RUY_STACK_OFFSET_RHS_COL_PTR - -#undef RUY_OFFSET_LHS_BASE_PTR -#undef RUY_OFFSET_RHS_BASE_PTR -#undef RUY_OFFSET_DST_BASE_PTR -#undef RUY_OFFSET_BIAS -#undef RUY_OFFSET_START_ROW -#undef RUY_OFFSET_START_COL -#undef RUY_OFFSET_LAST_ROW -#undef RUY_OFFSET_LAST_COL -#undef RUY_OFFSET_DST_ROWS -#undef RUY_OFFSET_DST_COLS -#undef RUY_OFFSET_LHS_STRIDE -#undef RUY_OFFSET_RHS_STRIDE -#undef RUY_OFFSET_DST_STRIDE -#undef RUY_OFFSET_DEPTH -#undef RUY_OFFSET_CLAMP_MIN -#undef RUY_OFFSET_CLAMP_MAX -#undef RUY_OFFSET_FLAGS - -#define RUY_OFFSET_BIAS 0 -#define RUY_OFFSET_LHS_SUMS 4 -#define RUY_OFFSET_RHS_SUMS 8 -#define RUY_OFFSET_LHS_BASE_PTR 12 -#define RUY_OFFSET_MULTIPLIER_FIXEDPOINT 16 -#define RUY_OFFSET_MULTIPLIER_EXPONENT 20 -#define RUY_OFFSET_RHS_BASE_PTR 24 -#define RUY_OFFSET_DST_BASE_PTR 28 -#define RUY_OFFSET_LHS_ZERO_POINT 32 -#define RUY_OFFSET_RHS_ZERO_POINT 36 -#define RUY_OFFSET_DST_ZERO_POINT 40 -#define RUY_OFFSET_PROD_ZP_DEPTH 44 -#define RUY_OFFSET_START_ROW 48 -#define RUY_OFFSET_START_COL 52 -#define RUY_OFFSET_LAST_ROW 56 -#define RUY_OFFSET_LAST_COL 60 -#define RUY_OFFSET_DST_ROWS 64 -#define RUY_OFFSET_DST_COLS 68 -#define RUY_OFFSET_LHS_STRIDE 72 -#define RUY_OFFSET_RHS_STRIDE 76 -#define RUY_OFFSET_DST_STRIDE 80 -#define RUY_OFFSET_DEPTH 84 -#define RUY_OFFSET_CLAMP_MIN 88 -#define RUY_OFFSET_CLAMP_MAX 92 -#define RUY_OFFSET_FLAGS 96 -#define RUY_OFFSET_DST_TYPE_ID 97 - -#define RUY_STACK_OFFSET_SIZE 96 -#define RUY_STACK_OFFSET_DST_COL_PTR 0 -#define RUY_STACK_OFFSET_DST_PTR 16 -#define RUY_STACK_OFFSET_ROW 32 -#define RUY_STACK_OFFSET_COL 48 -#define RUY_STACK_OFFSET_LHS_COL_PTR 64 -#define RUY_STACK_OFFSET_RHS_COL_PTR 80 - -template -void CheckOffsetsInKernelParams8bit(const Params&) { - static_assert(offsetof(Params, lhs_zero_point) == RUY_OFFSET_LHS_ZERO_POINT, - ""); - static_assert(offsetof(Params, rhs_zero_point) == RUY_OFFSET_RHS_ZERO_POINT, - ""); - static_assert(offsetof(Params, dst_zero_point) == RUY_OFFSET_DST_ZERO_POINT, - ""); - static_assert(offsetof(Params, prod_zp_depth) == RUY_OFFSET_PROD_ZP_DEPTH, - ""); - static_assert(offsetof(Params, multiplier_fixedpoint) == - RUY_OFFSET_MULTIPLIER_FIXEDPOINT, - ""); - static_assert( - offsetof(Params, multiplier_exponent) == RUY_OFFSET_MULTIPLIER_EXPONENT, - ""); - static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, ""); - static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, ""); - static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, ""); - static_assert(offsetof(Params, lhs_sums) == RUY_OFFSET_LHS_SUMS, ""); - static_assert(offsetof(Params, rhs_sums) == RUY_OFFSET_RHS_SUMS, ""); - static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, ""); - static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, ""); - static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, ""); - static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, ""); - static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, ""); - static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, ""); - static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, ""); - static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, ""); - static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, ""); -} - -// Fast-int8 kernel, ported from ARM 64 version. -// Relevant target CPUs for this kernel include Krait 400 and A9, -// since these are 32-bit, out-of-order CPUs. -void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 2>& params) { - profiler::ScopeLabel label( - "Kernel (kNeon, optimized for out-of-order cores)"); - - CheckOffsetsInKernelParams8bit(params); - - const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; - const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; - - // The asm kernel below has the following NEON register allocation: - // - // q6 - q13 are 128-bit (4x32b) accumulators. - // During accumulation, d0 -- d7 are used to load int8 data from LHS and - // d8 -- d11 from RHS: - // int8 RHS 16x2 block - // /-----------------------------\ - // |d8.b[0-7] ..... d10.b[0-7]| - // | ... ... | - // |d9.b[0-7] ..... d11.b[0-7]| - // \-----------------------------/ - // int8 LHS 4x16 block - // /------------------------\ /-----------------------------\ - // |d0.b[0-7] ... d1.b[0-7] | | q6 ..... q10 | - // |d2.b[0-7] ... d3.b[0-7] | | q7 ..... q11 | - // (Reload d0, d1, d2, d3) - // |d0.b[0-7] ... d1.b[0-7] | | q8 ..... q12 | - // |d2.b[0-7] ... d3.b[0-7] | | q9 ..... q13 | - // \------------------------/ \-----------------------------/ - // 128-bit accumulators 4x2 block - // - // No attempt had been made so far at implementing the RUY_OPT_MAX_STREAMING - // optimization for this kernel. - asm volatile( -#define RUY_MAKE_ZERO(reg) "vmov.i32 " #reg ", #0x00000000\n" - - // clang-format off - - // Load the first 64 bytes of LHS and RHS data. - "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n" - // Clear accumulators. - RUY_MAKE_ZERO(q6) - "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n" - RUY_MAKE_ZERO(q8) - RUY_MAKE_ZERO(q9) - RUY_MAKE_ZERO(q10) - "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n" - RUY_MAKE_ZERO(q11) - "vld1.8 {d10, d11}, [%[rhs_ptr]]!\n" - - "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" - RUY_MAKE_ZERO(q12) - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" - RUY_MAKE_ZERO(q13) - "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - RUY_MAKE_ZERO(q14) - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n" - RUY_MAKE_ZERO(q15) - "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n" - "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" - - - // r1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 16. - "mov r1, #16\n" - - // Main loop of the whole GEMM, over rows and columns of the - // destination matrix. - "1:\n" - - // r1 is how many levels of depth we have already loaded - // data for, r10 is the total depth. - "ldr r10, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" - "cmp r1, r10\n" - "beq 79f\n" - - "2:\n" - - // Mult, mult-acc in to q14, q15, q2, q3 - "vmull.s8 q14, d0, d8\n" - "vmull.s8 q2, d0, d10\n" - - "vmull.s8 q15, d2, d8\n" - "vmull.s8 q3, d2, d10\n" - - "vmlal.s8 q14, d1, d9\n" - "vmlal.s8 q2, d1, d11\n" - "vmlal.s8 q15, d3, d9\n" - "vmlal.s8 q3, d3, d11\n" - "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS - - // Then pairwise accumulate in to q6, q7, q10, q11 - "vpadal.s16 q6, q14\n" - "vpadal.s16 q7, q15\n" - "vpadal.s16 q10, q2\n" - "vpadal.s16 q11, q3\n" - - // Mult, mult-acc in to q14, q15, q2, q3 - "vmull.s8 q14, d0, d8\n" - "vmull.s8 q2, d0, d10\n" - - "vmull.s8 q15, d2, d8\n" - "vmull.s8 q3, d2, d10\n" - - "vmlal.s8 q14, d1, d9\n" - "vmlal.s8 q2, d1, d11\n" - "vmlal.s8 q15, d3, d9\n" - "vmlal.s8 q3, d3, d11\n" - "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS - - // Then pairwise accumulate in to q8, q9, q12, q13 - "vpadal.s16 q8, q14\n" - "vld1.8 {d8, d9, d10, d11}, [%[rhs_ptr]]!\n" - "vpadal.s16 q9, q15\n" - "vpadal.s16 q12, q2\n" - "vpadal.s16 q13, q3\n" - - // Prefetch the next 64 bytes of LHS and RHS data. - RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") - RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") - - // Each iteration of this loop advances by 16 levels of depth. - "add r1, r1, #16\n" - - // Loop termination condition - "cmp r1, r10\n" - - "blt 2b\n" - - "79:\n" - - // Mult, mult-acc in to q14, q15, q2, q3 - "vmull.s8 q14, d0, d8\n" - "vmull.s8 q2, d0, d10\n" - - "vmull.s8 q15, d2, d8\n" - "vmull.s8 q3, d2, d10\n" - - "vmlal.s8 q14, d1, d9\n" - "vmlal.s8 q2, d1, d11\n" - "vmlal.s8 q15, d3, d9\n" - "vmlal.s8 q3, d3, d11\n" - "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS - - // Then pairwise accumulate in to q6, q7, q10, q11 - "vpadal.s16 q6, q14\n" - "vpadal.s16 q7, q15\n" - "vpadal.s16 q10, q2\n" - "vpadal.s16 q11, q3\n" - - // Mult, mult-acc in to q14, q15, q2, q3 - "vmull.s8 q14, d0, d8\n" - "vmull.s8 q2, d0, d10\n" - - "vmull.s8 q15, d2, d8\n" - "vmull.s8 q3, d2, d10\n" - - "vmlal.s8 q14, d1, d9\n" - "vmlal.s8 q2, d1, d11\n" - "vmlal.s8 q15, d3, d9\n" - "vmlal.s8 q3, d3, d11\n" - - // Then pairwise accumulate in to q8, q9, q12, q13 - "vpadal.s16 q8, q14\n" - "vpadal.s16 q9, q15\n" - "vpadal.s16 q12, q2\n" - "vpadal.s16 q13, q3\n" - - - // All accumulation over depth done. q6 - q13 contain the 4x32b - // accumulators for the 4x2 final matrix. - // We now have to compute the final 8-bit values from these int32 - // accumulators, and advance to the next 4x2 block. We intertwine - // these two aspects whenever possible for optimal pipelining, both - // at the data flow level (prefetch data for next block as early as - // possible) and instruction pipelining level (some of the next-block - // work can dual-issue with some of the final work on the current - // block). - - // q6-q13 now contain 4 x 32b - "vpadd.i32 d0, d12, d13\n" - "vpadd.i32 d1, d14, d15\n" - "vpadd.i32 d2, d16, d17\n" - "vpadd.i32 d3, d18, d19\n" - "vpadd.i32 d4, d20, d21\n" - "vpadd.i32 d5, d22, d23\n" - "vpadd.i32 d6, d24, d25\n" - "vpadd.i32 d7, d26, d27\n" - - // d0-d7 each contain 2 x 32b accumulators. - // Need to add pairwise to get 1 x 32b for each of the 4x2 entries - // of destination, (Four 'd' registers total) - "vpadd.i32 d28, d0, d1\n" - "vpadd.i32 d29, d2, d3\n" - "vpadd.i32 d30, d4, d5\n" - "vpadd.i32 d31, d6, d7\n" - - //Now d28 - d31 have the 1 x 32b accumulators for the 4x2 entries - - // Logic to advance to the next block in preparation for the next - // iteration of the main loop. For now, we only want to compute - // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are - // not yet ready to update the values of row and col, as we still need - // the current values for the rest of the work on the current block. - - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "cmp r1, r3\n" // Have we finished the last row? - - "bge 4f\n" // If finished last row, go to 4 - // Not finished last row: then advance to next row. - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - "add r4, r4, r1, lsl #2\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - "b 5f\n" - "4:\n" // Finished last row... - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - // Go back to first row - "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - - // Now we need to advance to the next column. If we already - // finished the last column, then in principle we are done, however - // we can't just return here, as we need to allow the end work of the - // current block to complete. The good news is that at this point it - // doesn't matter what data we load for the next column, since - // we will exit from the main loop below before actually storing - // anything computed from that data. - - "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "cmp r8, r4\n" // Have we finished the last column? - "bge 5f\n" // If yes, just carry on without updating the column pointer. - // Not finished last column: then advance to next column. - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" - "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" - "add r10, r10, r1, lsl #1\n" - "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" - "5:\n" - - // Set the LHS and RHS data pointers to the start of the columns just - // computed. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - "mov %[lhs_ptr], r4\n" - "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" - "mov %[rhs_ptr], r5\n" - - // Now we load: bias data, LHS sums data, RHS sums data. - - // First, load the base pointers from the params. - "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" - - // Offset these base pointers as needed given the current row, col. - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "add r5, r1, r8, lsl #2\n" - - "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" - "it ne\n" - "movne r1, r5\n" - - // Load 4 bias values. - "vld1.32 {d24, d25}, [r1]\n" - - // Now that we know what LHS and RHS data the next iteration of the - // main loop will need to load, we start loading the first 32 bytes of - // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore - // in the rest of the work on the current block. - "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" - RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") - "vld1.8 {d8, d9, d10, d11}, [%[rhs_ptr]]!\n" - RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") - - // Add to the bias values the product - // (depth * lhs_zero_point * rhs_zero_point), - // See the term NZ1Z2 in equation (7) in - // https://arxiv.org/pdf/1712.05877.pdf - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" - "vdup.32 q9, r3\n" - "vadd.i32 q12, q12, q9\n" - - // Perform the bias-addition (per the above, we have just folded into - // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) - "vadd.i32 q14, q14, q12\n" - "vadd.i32 q15, q15, q12\n" - - // LHS/RHS zero points - // Has RHS sums - "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" - "beq 401f\n" - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - // Offset by current col * number of bytes per value - "add r3, r3, r4, lsl #2\n" - "vld1.32 { d12 }, [r3]\n" - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" - "vdup.32 q10, r5\n" // create lhs_zero_point_vec - // Subtract rhs_sums * lhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "vmls.i32 q14, q10, d12[0]\n" - "vmls.i32 q15, q10, d12[1]\n" - "401:\n" - - // Has LHS sums - "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" - "beq 402f\n" - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - // Offset by current row * number of bytes per value - "add r2, r2, r4, lsl #2\n" - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" - - // Load 4 lhs_sums values. - "vld1.32 {d22, d23}, [r2]\n" - "vdup.32 d13, r5\n" // rhs_zero_point - - // Compute lhs_sums * rhs_zero_point. - "vmul.i32 q11, q11, d13[1]\n" - // Subtract lhs_sums * rhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "vsub.s32 q14, q14, q11\n" - "vsub.s32 q15, q15, q11\n" - - // If the destination is int32, it means the user asks for the raw - // accumulators, no need for us to downquantize the value. - "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n" - "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" - - "402:\n" - - // At this point we have computed the final int32 values. Now we - // start down-quantizing them to obtain the final 8bit values from them. - - // As part of this down-quantization, our int32 values will be - // multiplied by a multiplier that has a fixed-point component and an - // exponent component. - - //Load the exponent part of the multiplier. - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" - "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "add r5, r1, r4, lsl #2\n" - "it ne\n" - "movne r1, r5\n" - - "vld1.32 {q10}, [r1]\n" - - RUY_MAKE_ZERO(q8) - "vmax.s32 q12, q10, q8\n" - - "vshl.s32 q14, q14, q12\n" - "vshl.s32 q15, q15, q12\n" - - "vmin.s32 q12, q10, q8\n" - - // Load fixed point part of the multiplier - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" - // r6 has flags, r4 has row - "add r5, r1, r4, lsl #2\n" - "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "it ne\n" - "movne r1, r5\n" - "vld1.32 {q10}, [r1]\n" // multiplier_fixedpoint - - // Apply the fixed-point part of the multiplier. - "vqrdmulh.s32 q14, q14, q10\n" - "vqrdmulh.s32 q15, q15, q10\n" - - // We have some rounding division-by-power-of-two to do. This should - // always use "round to nearest". We allow for some - // freedom in how ties are broken, to strike a good compromise of - // performance on given hardware vs. perfect agreement of results - // across hardware. - // - // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation - // defined tie-breaks to help performance. On NEON, this means that we - // can just use the NEON rounding instructions, such as srshl. They - // happen to be breaking ties upward. - // - // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict - // break-ties-away-from zero, as described in Appendix B of - // https://arxiv.org/pdf/1712.05877.pdf - // When we wrote that, we thought that that would be better unbiased - // than the NEON upwards tie-breaks, and we had observed some - // improvement on some model. However, that is only more unbiased for - // data centered at zero, which was likely the case in that model, - // but is not always the case. If we wanted something more consistently - // unbiased then we should try breaking ties toward-nearest-even. -#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) - // Fix up values to be right-shifted, so that the (round to nearest, - // break ties upward) behavior of srshl applied to these fixed-up - // values, produces the same result as the desired (round to nearest, - // break ties away from zero) behavior on the original values. - "vand q8, q14, q12\n" - "vand q9, q15, q12\n" - "vshr.s32 q8, q8, #31\n" - "vshr.s32 q9, q9, #31\n" - "vqadd.s32 q14, q14, q8\n" - "vqadd.s34 q15, q15, q9\n" - -#endif - // At this point we have reduced the problem of correctly implementing - // rounding divide-by-power-of-two, to what the SRSHL instruction can - // do. - "vrshl.s32 q14, q14, q12\n" - "vrshl.s32 q15, q15, q12\n" - - "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n" - "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" - "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" - - // Store uint8 values: - RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" - - // Cast-and-saturate from int32 to int16 - // After this, all values for output are in q14. - "vqmovn.s32 d28, q14\n" - "vqmovn.s32 d29, q15\n" - - // At this point, d12 -- d26, d30, d31 aren't used anymore for the - // current block, so we can start clearing these accumulators for the - // next block (next iteration of the main loop). - RUY_MAKE_ZERO(q6) - RUY_MAKE_ZERO(q7) - RUY_MAKE_ZERO(q8) - RUY_MAKE_ZERO(q9) - RUY_MAKE_ZERO(q10) - RUY_MAKE_ZERO(q11) - RUY_MAKE_ZERO(q12) - RUY_MAKE_ZERO(q13) - RUY_MAKE_ZERO(q15) - - // Load the destination zero point into each of the 8 16-bit slots - // in a q register. - "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" - "vdup.16 q13, r4\n" // dst_zero_point - - // Add the destination zero point - "vadd.i16 q14, q14, q13\n" - - // Cast-and-saturate from int16 to uint8 - // Now all 8 1-byte values are in d30. - "vqmovun.s16 d30, q14\n" - - // Load the clamp_min, clamp_max bounds - "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "vdup.8 d28, r2\n" // clamp_min - "vdup.8 d29, r3\n" // clamp_max - - // Apply the clamp_min bound - "vmax.u8 d30, d30, d28\n" - // Apply the clamp_max bound - "vmin.u8 d30, d30, d29\n" - - // Compute how much of the 4x2 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x2, there are some 4x2 blocks along the boundaries that do - // not fit entirely. - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "sub r1, r1, r8\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "sub r2, r2, r4\n" - "mov r3, #4\n" - "mov r5, #2\n" - "cmp r1, #4\n" - // Compute r1 = how many rows of the 4x2 block fit - "it gt\n" - "movgt r1, r3\n" - - "cmp r2, #2\n" - // Compute r2 = how many cols of the 4x2 block fit - "it gt\n" - "movgt r2, r5\n" - - // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits. - "cmp r1, r3\n" - "it eq\n" - "cmpeq r2, r5\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - // Yes, all of the 4x2 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x2 block fits. - // Store to dst_tmp_buf - // Set r3 address to write to dst_tmp_buf. - "mov r3, %[dst_tmp_buf]\n" - "vst1.8 {d30}, [r3]\n" - - // Slow loop copying from dst_tmp_buf to dst. - "mov r6, #0\n" - "50:\n" - "mov r8, #0\n" - "51:\n" - "ldrb r10, [r3, r8]\n" - "strb r10, [r4, r8]\n" - "add r8, r8, #1\n" - "cmp r8, r1\n" - "blt 51b\n" - "add r6, r6, #1\n" - "add r3, r3, #4\n" - "add r4, r4, r5\n" - "cmp r6, r2\n" - "blt 50b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x2 block fits. - // r3 address, r5 stride - "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "mov r4, r3\n" - "mov r6, #1\n" - - "vst1.32 {d30[0]}, [r3]\n" - "add r4, r4, r5\n" - "mov r3, r4\n" - "vst1.32 {d30[1]}, [r3]\n" - - "31:\n" - - // Load dst_ptr, increment, and write back. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "add r4, r4, #4\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - - RUY_MAKE_ZERO(q13) - RUY_MAKE_ZERO(q14) - RUY_MAKE_ZERO(q15) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - // Store int8 values: - RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" - - // Cast-and-saturate from int32 to int16 - // After this, all values for output are in q14. - "vqmovn.s32 d28, q14\n" - "vqmovn.s32 d29, q15\n" - - // At this point, d12 -- d26, d30, d31 aren't used anymore for the - // current block, so we can start clearing these accumulators for the - // next block (next iteration of the main loop). - RUY_MAKE_ZERO(q6) - RUY_MAKE_ZERO(q7) - RUY_MAKE_ZERO(q8) - RUY_MAKE_ZERO(q9) - RUY_MAKE_ZERO(q10) - RUY_MAKE_ZERO(q11) - RUY_MAKE_ZERO(q12) - RUY_MAKE_ZERO(q13) - RUY_MAKE_ZERO(q15) - - // Load the destination zero point into each of the 8 16-bit slots - // in a q register. - "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" - "vdup.16 q13, r4\n" // dst_zero_point - - // Add the destination zero point - "vadd.i16 q14, q14, q13\n" - - // Cast-and-saturate from int16 to int8 - // Now all 8 1-byte values are in d30. - "vqmovn.s16 d30, q14\n" - - // Load the clamp_min, clamp_max bounds - "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "vdup.8 d28, r2\n" // clamp_min - "vdup.8 d29, r3\n" // clamp_max - - // Apply the clamp_min bound - "vmax.s8 d30, d30, d28\n" - // Apply the clamp_max bound - "vmin.s8 d30, d30, d29\n" - - // Compute how much of the 4x2 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x2, there are some 4x2 blocks along the boundaries that do - // not fit entirely. - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "sub r1, r1, r8\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "sub r2, r2, r4\n" - "mov r3, #4\n" - "mov r5, #2\n" - "cmp r1, #4\n" - // Compute r1 = how many rows of the 4x2 block fit - "it gt\n" - "movgt r1, r3\n" - - "cmp r2, #2\n" - // Compute r2 = how many cols of the 4x2 block fit - "it gt\n" - "movgt r2, r5\n" - - // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits. - "cmp r1, r3\n" - "it eq\n" - "cmpeq r2, r5\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - // Yes, all of the 4x2 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x2 block fits. - // Store to dst_tmp_buf - // Set r3 address to write to dst_tmp_buf. - "mov r3, %[dst_tmp_buf]\n" - "vst1.8 {d30}, [r3]\n" - - // Slow loop copying from dst_tmp_buf to dst. - "mov r6, #0\n" - "50:\n" - "mov r8, #0\n" - "51:\n" - "ldrb r10, [r3, r8]\n" - "strb r10, [r4, r8]\n" - "add r8, r8, #1\n" - "cmp r8, r1\n" - "blt 51b\n" - "add r6, r6, #1\n" - "add r3, r3, #4\n" - "add r4, r4, r5\n" - "cmp r6, r2\n" - "blt 50b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x2 block fits. - // r3 address, r5 stride - "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "mov r4, r3\n" - "mov r6, #1\n" - - "vst1.32 {d30[0]}, [r3]\n" - "add r4, r4, r5\n" - "mov r3, r4\n" - "vst1.32 {d30[1]}, [r3]\n" - - "31:\n" - - // Load dst_ptr, increment, and write back. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "add r4, r4, #4\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - - RUY_MAKE_ZERO(q13) - RUY_MAKE_ZERO(q14) - RUY_MAKE_ZERO(q15) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" - - // Load the destination zero point into each of the 4 32-bit slots - // in a q register. - "ldrsh r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" - "vdup.32 q13, r4\n" // dst_zero_point - // Add the destination zero point - "vadd.s32 q14, q14, q13\n" - "vadd.s32 q15, q15, q13\n" - - // Cast-and-saturate from int32 to int16 - // After this, all values for output are in q14. - "vqmovn.s32 d28, q14\n" - "vqmovn.s32 d29, q15\n" - - // At this point, v18 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(q6) - RUY_MAKE_ZERO(q7) - RUY_MAKE_ZERO(q8) - RUY_MAKE_ZERO(q9) - RUY_MAKE_ZERO(q10) - RUY_MAKE_ZERO(q11) - RUY_MAKE_ZERO(q15) - - // Load the clamp_min, clamp_max bounds - "ldrh r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrh r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "vdup.16 q12, r2\n" // clamp_min - "vdup.16 q13, r3\n" // clamp_max - - // Apply the clamp_min bound - "vmax.s16 q14, q14, q12\n" - // Apply the clamp_max bound - "vmin.s16 q14, q14, q13\n" - - RUY_MAKE_ZERO(q12) - RUY_MAKE_ZERO(q13) - - // Compute how much of the 4x2 block of destination 16-bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x2, there are some 4x2 blocks along the boundaries that do - // not fit entirely. - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "sub r1, r1, r8\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "sub r2, r2, r4\n" - "mov r3, #4\n" - "mov r5, #2\n" - "cmp r1, #4\n" - // Compute r1 = how many rows of the 4x2 block fit - "it gt\n" - "movgt r1, r3\n" - - "cmp r2, #2\n" - // Compute r2 = how many cols of the 4x2 block fit - "it gt\n" - "movgt r2, r5\n" - - // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits. - "cmp r1, r3\n" - "it eq\n" - "cmpeq r2, r5\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - // Yes, all of the 4x2 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x2 block fits. - // Store to dst_tmp_buf - // Set r3 address to write to dst_tmp_buf. - "mov r3, %[dst_tmp_buf]\n" - "vst1.16 {q14}, [r3]\n" - - // Slow loop copying from dst_tmp_buf to dst. - "mov r6, #0\n" - "50:\n" - "mov r8, #0\n" - "51:\n" - // Shift of offset register for half-word loads not allowed in A32, - // so we shift, load/store, then shift back r8. - "lsl r8, r8, #1\n" - "ldrh r10, [r3, r8]\n" - "strh r10, [r4, r8]\n" - "lsr r8, r8, #1\n" - "add r8, r8, #1\n" - "cmp r8, r1\n" - "blt 51b\n" - "add r6, r6, #1\n" - "add r3, r3, #8\n" - "add r4, r4, r5\n" - "cmp r6, r2\n" - "blt 50b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x2 block fits. - // r3 address, r5 stride - "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "mov r4, r3\n" - "mov r6, #2\n" - - "vst1.16 {d28[0]}, [r3], r6\n" - "add r4, r4, r5\n" - "vst1.16 {d28[1]}, [r3], r6\n" - "vst1.16 {d28[2]}, [r3], r6\n" - "vst1.16 {d28[3]}, [r3], r6\n" - "mov r3, r4\n" - "vst1.16 {d29[0]}, [r3], r6\n" - "vst1.16 {d29[1]}, [r3], r6\n" - "vst1.16 {d29[2]}, [r3], r6\n" - "vst1.16 {d29[3]}, [r3], r6\n" - "31:\n" - - // Load dst_ptr, increment, and write back. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "add r4, r4, #8\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - - RUY_MAKE_ZERO(q14) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" - - // Since the store type is the same as the accum type, no need for - // downcast. There's also no need for clamp by min/max. - - // At this point, v20 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - // Clear accumulators. - RUY_MAKE_ZERO(q6) - RUY_MAKE_ZERO(q7) - RUY_MAKE_ZERO(q8) - RUY_MAKE_ZERO(q9) - RUY_MAKE_ZERO(q10) - RUY_MAKE_ZERO(q11) - RUY_MAKE_ZERO(q12) - RUY_MAKE_ZERO(q13) - - // Compute how much of the 4x2 block of destination 32 bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x2, there are some 4x4 blocks along the boundaries that do - // not fit entirely. - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "sub r1, r1, r8\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "sub r2, r2, r4\n" - "mov r3, #4\n" - "mov r5, #2\n" - "cmp r1, #4\n" - // Compute r1 = how many rows of the 4x2 block fit - "it gt\n" - "movgt r1, r3\n" - - "cmp r2, #2\n" - // Compute r2 = how many cols of the 4x2 block fit - "it gt\n" - "movgt r2, r5\n" - - // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits. - "cmp r1, r3\n" - "it eq\n" - "cmpeq r2, r5\n" - // Yes, all of the 4x2 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x2 block fits. - // Set (r3 address, r4 stride) to write to dst_tmp_buf - "mov r3, %[dst_tmp_buf]\n" - "mov r4, #16\n" - "b 31f\n" - - "30:\n" - // Yes, all of the 4x2 block fits. - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - // r3 address, r4 stride - "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "mov r4, r5\n" - - "31:\n" - - "vst1.32 {d28, d29}, [r3]\n" - "add r3, r3, r4\n" - "vst1.32 {d30, d31}, [r3]\n" - - // If all of the 4x2 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 41f\n" - // Not all of the 4x2 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "mov r3, %[dst_tmp_buf]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "mov r6, #0\n" - "50:\n" - "mov r5, #0\n" - "51:\n" - "ldr r10, [r3, r5, lsl #2]\n" - "str r10, [r4, r5, lsl #2]\n" - "add r5, r5, #1\n" - "cmp r5, r1\n" - "blt 51b\n" - "add r6, r6, #1\n" - "add r3, r3, #16\n" - "add r4, r4, r8\n" - // r2 = how many cols of the 8x4 block fit - "cmp r6, r2\n" - "blt 50b\n" - - "41:\n" - // Load dst_ptr, increment, and write back. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "add r4, r4, #16\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - - RUY_MAKE_ZERO(q10) - RUY_MAKE_ZERO(q11) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" - - // Reload some params --- we had used x5 -- x7 for a few other things - // since the last time we had loaded them. - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - - // Move to the next block of the destination matrix, for the next iter - // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already - // been updated earlier. - // Have we reached the end row? - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "cmp r8, r3\n" - - "beq 20f\n" // yes, end row. - // Not end row. Move to the next row. - "add r8, r8, #4\n" - // Store new value of row - "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - - "b 21f\n" - "20:\n" - // Was already at end row. - // Move back to first row. - "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - // Move to the next column. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "add r4, r4, #2\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - - "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" - // Increment dst_col_ptr by 2 * dst_stride (i.e. 2 columns) - "add r1, r1, r8, lsl #1\n" - // Store dst_col_ptr - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" - // Store dst_ptr - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "21:\n" - - // Main loop exit condition: have we hit the end column? - "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "cmp r8, r4\n" - - // w1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 16. - "mov r1, #16\n" - - "ble 1b\n" - - // Restore stack pointer. - "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" - - // clang-format on - - : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr) - : [ params ] "r"(¶ms), [dst_tmp_buf] "r"(params.dst_tmp_buf) - : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc", - // Clobber list must specify q registers (and not their constituent - // d registers). There is a (currently unexplained) slowdown if - // d registers are listed in the clobbers list. - "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", - "q9", "q10", "q12", "q13", "q14", "q15"); -} - -// Fast-int8 true "GEMV" kernel (RHS has 1 column). We assume the RHS -// is still packed as if it has two columns -void Kernel8bitNeonOutOfOrder1Col(const KernelParams8bit<4, 2>& params) { - profiler::ScopeLabel label( - "Kernel (kNeon, optimized for out-of-order cores)"); - - CheckOffsetsInKernelParams8bit(params); - - const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; - const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; - - // The asm kernel below has the following NEON register allocation: - // - // q6 - q13 are 128-bit (4x32b) accumulators. - // During accumulation, d0 -- d7 are used to load int8 data from LHS and - // d8 -- d11 from RHS: - // int8 RHS 16x1 block - // /------------\ - // | d8.b[0] | - // | ... | - // | d8.b[7] | - // | d9.b[0] | - // | ... | - // | d9.b[7] | - // \------------/ - // int8 LHS 4x16 block - // /-----------------------------------------\ /------------\ - // |d0.b[0] ... d0.b[7] d1.b[0] ... d1.b[7] | | q6 | - // |d2.b[0] ... d2.b[7] d3.b[0] ... d3.b[7] | | q7 | - // |d4.b[0] ... d4.b[7] d5.b[0] ... d5.b[7] | | q8 | - // |d6.b[0] ... d6.b[7] d7.b[0] ... d7.b[7] | | q9 | - // \-----------------------------------------/ \------------/ - // 128-bit accumulators 4x1 block - // - // No attempt had been made so far at implementing the RUY_OPT_MAX_STREAMING - // optimization for this kernel. - asm volatile( -#define RUY_MAKE_ZERO(reg) "vmov.i32 " #reg ", #0x00000000\n" - - // clang-format off - - // Load the first 64 bytes of LHS and RHS data. - "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n" - "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n" - "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n" - "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n" - "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n" - // Skip the other column and advance the pointer. - "add %[rhs_ptr], %[rhs_ptr], #16\n" - - "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" - "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n" - "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n" - "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" - - // Clear accumulators. - RUY_MAKE_ZERO(q6) - RUY_MAKE_ZERO(q7) - RUY_MAKE_ZERO(q8) - RUY_MAKE_ZERO(q9) - RUY_MAKE_ZERO(q10) - RUY_MAKE_ZERO(q11) - RUY_MAKE_ZERO(q12) - RUY_MAKE_ZERO(q13) - RUY_MAKE_ZERO(q14) - RUY_MAKE_ZERO(q15) - - // r1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 16. - "mov r1, #16\n" - - // Main loop of the whole GEMM, over rows and columns of the - // destination matrix. - "1:\n" - - // r1 is how many levels of depth we have already loaded - // data for, r10 is the total depth. - "ldr r10, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" - "cmp r1, r10\n" - "beq 79f\n" - - "2:\n" - - // Mult, mult-acc in to q14, q15 - "vmull.s8 q14, d0, d8\n" - "vmull.s8 q15, d2, d8\n" - "vmlal.s8 q14, d1, d9\n" - "vmlal.s8 q15, d3, d9\n" - - // Then pairwise accumulate in to q6, q7 - "vpadal.s16 q6, q14\n" - "vpadal.s16 q7, q15\n" - - // Mult, mult-acc in to q14, q15 - "vmull.s8 q14, d4, d8\n" - "vmull.s8 q15, d6, d8\n" - "vmlal.s8 q14, d5, d9\n" - "vmlal.s8 q15, d7, d9\n" - - // Then pairwise accumulate in to q8, q9 - "vpadal.s16 q8, q14\n" - "vpadal.s16 q9, q15\n" - - - // Load the next 64 bytes of LHS and RHS data. - "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n" - "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n" - "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n" - "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n" - RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") - "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n" - // Skip the other column and advance the pointer. - "add %[rhs_ptr], %[rhs_ptr], #16\n" - RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") - - // Each iteration of this loop advances by 16 levels of depth. - "add r1, r1, #16\n" - - // Loop termination condition - "cmp r1, r10\n" - - "blt 2b\n" - - "79:\n" - - // Mult, mult-acc in to q14, q15 - "vmull.s8 q14, d0, d8\n" - "vmull.s8 q15, d2, d8\n" - "vmlal.s8 q14, d1, d9\n" - "vmlal.s8 q15, d3, d9\n" - - // Then pairwise accumulate in to q6, q7 - "vpadal.s16 q6, q14\n" - "vpadal.s16 q7, q15\n" - - // Mult, mult-acc in to q14, q15 - "vmull.s8 q14, d4, d8\n" - "vmull.s8 q15, d6, d8\n" - "vmlal.s8 q14, d5, d9\n" - "vmlal.s8 q15, d7, d9\n" - - // Then pairwise accumulate in to q8, q9 - "vpadal.s16 q8, q14\n" - "vpadal.s16 q9, q15\n" - - // All accumulation over depth done. q6 - q9 contain the 4x32b - // accumulators for the 4x1 final matrix. - // We now have to compute the final 8-bit values from these int32 - // accumulators, and advance to the next 4x2 block. We intertwine - // these two aspects whenever possible for optimal pipelining, both - // at the data flow level (prefetch data for next block as early as - // possible) and instruction pipelining level (some of the next-block - // work can dual-issue with some of the final work on the current - // block). - - // q6-q9 now contain 4 x 32b - "vpadd.i32 d0, d12, d13\n" - "vpadd.i32 d1, d14, d15\n" - "vpadd.i32 d2, d16, d17\n" - "vpadd.i32 d3, d18, d19\n" - - // d0-d4 each contain 2 x 32b accumulators. - // Need to add pairwise to get 1 x 32b for each of the 4x1 entries - // of destination, (Four 'd' registers total) - "vpadd.i32 d28, d0, d1\n" - "vpadd.i32 d29, d2, d3\n" - - // Now d28,d29 have the 1 x 32b accumulators for the 4x1 entries. - - // Logic to advance to the next block in preparation for the next - // iteration of the main loop. For now, we only want to compute - // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are - // not yet ready to update the values of row and col, as we still need - // the current values for the rest of the work on the current block. - - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "cmp r1, r3\n" // Have we finished the last row? - - "bge 4f\n" // If finished last row, go to 4 - // Not finished last row: then advance to next row. - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - "add r4, r4, r1, lsl #2\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - "b 5f\n" - "4:\n" // Finished last row... - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - // Go back to first row - "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - - // Now we need to advance to the next column. If we already - // finished the last column, then in principle we are done, however - // we can't just return here, as we need to allow the end work of the - // current block to complete. The good news is that at this point it - // doesn't matter what data we load for the next column, since - // we will exit from the main loop below before actually storing - // anything computed from that data. - - "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "cmp r8, r4\n" // Have we finished the last column? - "bge 5f\n" // If yes, just carry on without updating the column pointer. - // Not finished last column: then advance to next column. - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" - "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" - "add r10, r10, r1, lsl #1\n" - "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" - "5:\n" - - // Set the LHS and RHS data pointers to the start of the columns just - // computed. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - "mov %[lhs_ptr], r4\n" - "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" - "mov %[rhs_ptr], r5\n" - - // Now we load: bias data, LHS sums data, RHS sums data. - - // First, load the base pointers from the params. - "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" - - // Offset these base pointers as needed given the current row, col. - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "add r5, r1, r8, lsl #2\n" - - "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" - "it ne\n" - "movne r1, r5\n" - - // Load 4 bias values. - "vld1.32 {d24, d25}, [r1]\n" - - // Now that we know what LHS and RHS data the next iteration of the - // main loop will need to load, we start loading the first 32 bytes of - // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore - // in the rest of the work on the current block. - "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n" - "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n" - "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n" - "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n" - RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") - "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n" - // Skip the other column and advance the pointer. - "add %[rhs_ptr], %[rhs_ptr], #16\n" - RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") - - // Add to the bias values the product - // (depth * lhs_zero_point * rhs_zero_point), - // See the term NZ1Z2 in equation (7) in - // https://arxiv.org/pdf/1712.05877.pdf - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" - "vdup.32 q9, r3\n" - "vadd.i32 q12, q12, q9\n" - - // Perform the bias-addition (per the above, we have just folded into - // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) - "vadd.i32 q14, q14, q12\n" - - // LHS/RHS zero points - // Has RHS sums - "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" - "beq 401f\n" - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - // Offset by current col * number of bytes per value - "add r3, r3, r4, lsl #2\n" - "vld1.32 { d12 }, [r3]\n" - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" - "vdup.32 q10, r5\n" // create lhs_zero_point_vec - // Subtract rhs_sums * lhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "vmls.i32 q14, q10, d12[0]\n" - "401:\n" - - // Has LHS sums - "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" - "beq 402f\n" - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - // Offset by current row * number of bytes per value - "add r2, r2, r4, lsl #2\n" - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" - - // Load 4 lhs_sums values. - "vld1.32 {d22, d23}, [r2]\n" - "vdup.32 d13, r5\n" // rhs_zero_point - - // Compute lhs_sums * rhs_zero_point. - "vmul.i32 q11, q11, d13[1]\n" - // Subtract lhs_sums * rhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "vsub.s32 q14, q14, q11\n" - - // If the destination is int32, it means the user asks for the raw - // accumulators, no need for us to downquantize the value. - "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n" - "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" - - "402:\n" - - // At this point we have computed the final int32 values. Now we - // start down-quantizing them to obtain the final 8bit values from them. - - // As part of this down-quantization, our int32 values will be - // multiplied by a multiplier that has a fixed-point component and an - // exponent component. - - //Load the exponent part of the multiplier. - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" - "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "add r5, r1, r4, lsl #2\n" - "it ne\n" - "movne r1, r5\n" - - "vld1.32 {q10}, [r1]\n" - - RUY_MAKE_ZERO(q8) - "vmax.s32 q12, q10, q8\n" - - "vshl.s32 q14, q14, q12\n" - - "vmin.s32 q12, q10, q8\n" - - // Load fixed point part of the multiplier - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" - // r6 has flags, r4 has row - "add r5, r1, r4, lsl #2\n" - "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "it ne\n" - "movne r1, r5\n" - "vld1.32 {q10}, [r1]\n" // multiplier_fixedpoint - - // Apply the fixed-point part of the multiplier. - "vqrdmulh.s32 q14, q14, q10\n" - - // We have some rounding division-by-power-of-two to do. This should - // always use "round to nearest". We allow for some - // freedom in how ties are broken, to strike a good compromise of - // performance on given hardware vs. perfect agreement of results - // across hardware. - // - // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation - // defined tie-breaks to help performance. On NEON, this means that we - // can just use the NEON rounding instructions, such as srshl. They - // happen to be breaking ties upward. - // - // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict - // break-ties-away-from zero, as described in Appendix B of - // https://arxiv.org/pdf/1712.05877.pdf - // When we wrote that, we thought that that would be better unbiased - // than the NEON upwards tie-breaks, and we had observed some - // improvement on some model. However, that is only more unbiased for - // data centered at zero, which was likely the case in that model, - // but is not always the case. If we wanted something more consistently - // unbiased then we should try breaking ties toward-nearest-even. -#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) - // Fix up values to be right-shifted, so that the (round to nearest, - // break ties upward) behavior of srshl applied to these fixed-up - // values, produces the same result as the desired (round to nearest, - // break ties away from zero) behavior on the original values. - "vand q8, q14, q12\n" - "vshr.s32 q8, q8, #31\n" - "vqadd.s32 q14, q14, q8\n" - -#endif - // At this point we have reduced the problem of correctly implementing - // rounding divide-by-power-of-two, to what the SRSHL instruction can - // do. - "vrshl.s32 q14, q14, q12\n" - - "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n" - "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" - "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" - - // Store uint8 values: - RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" - - // Cast-and-saturate from int32 to int16 - // After this, all values for output are in d28. - "vqmovn.s32 d28, q14\n" - - // At this point, d12 -- d26, d29, d30, d31 aren't used anymore for the - // current block, so we can start clearing these accumulators for the - // next block (next iteration of the main loop). - RUY_MAKE_ZERO(q6) - RUY_MAKE_ZERO(q7) - RUY_MAKE_ZERO(q8) - RUY_MAKE_ZERO(q9) - RUY_MAKE_ZERO(q10) - RUY_MAKE_ZERO(q11) - RUY_MAKE_ZERO(q12) - RUY_MAKE_ZERO(q13) - RUY_MAKE_ZERO(q15) - - // Load the destination zero point into each of the 8 16-bit slots - // in a q register. - "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" - "vdup.16 q13, r4\n" // dst_zero_point - - // Add the destination zero point - "vadd.i16 q14, q14, q13\n" - - // Cast-and-saturate from int16 to uint8 - "vqmovun.s16 d30, q14\n" - // At this point, we only need 4 8-bit values in the lower half - // of d30. - - - // Load the clamp_min, clamp_max bounds - "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "vdup.8 d28, r2\n" // clamp_min - "vdup.8 d29, r3\n" // clamp_max - - // Apply the clamp_min bound - "vmax.u8 d30, d30, d28\n" - // Apply the clamp_max bound - "vmin.u8 d30, d30, d29\n" - - // Compute how much of the 4x1 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x1, there are some 4x1 blocks along the boundaries that do - // not fit entirely. - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "sub r1, r1, r8\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "sub r2, r2, r4\n" - "mov r3, #4\n" - "mov r5, #2\n" - "cmp r1, #4\n" - // Compute r1 = how many rows of the 4x1 block fit - "it gt\n" - "movgt r1, r3\n" - - // Test if r1==4, i.e. if all of the 4x1 block fits. - "cmp r1, r3\n" - - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - // Yes, all of the 4x1 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x1 block fits. - // Store to dst_tmp_buf - // Set r3 address to write to dst_tmp_buf. - "mov r3, %[dst_tmp_buf]\n" - "vst1.8 {d30}, [r3]\n" - - // Slow loop copying from dst_tmp_buf to dst. - "50:\n" - "mov r8, #0\n" - "51:\n" - "ldrb r10, [r3, r8]\n" - "strb r10, [r4, r8]\n" - "add r8, r8, #1\n" - "cmp r8, r1\n" - "blt 51b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x1 block fits. - // r3 address, r5 stride - "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "mov r4, r3\n" - "mov r6, #1\n" - - "vst1.8 {d30[0]}, [r3], r6\n" - "vst1.8 {d30[1]}, [r3], r6\n" - "vst1.8 {d30[2]}, [r3], r6\n" - "vst1.8 {d30[3]}, [r3], r6\n" - "31:\n" - - // Load dst_ptr, increment, and write back. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "add r4, r4, #4\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - - RUY_MAKE_ZERO(q13) - RUY_MAKE_ZERO(q14) - RUY_MAKE_ZERO(q15) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - // Store int8 values: - RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" - - // Cast-and-saturate from int32 to int16 - // After this, all values for output are in d28. - "vqmovn.s32 d28, q14\n" - - // At this point, d12 -- d26, d29, d30, d31 aren't used anymore for the - // current block, so we can start clearing these accumulators for the - // next block (next iteration of the main loop). - RUY_MAKE_ZERO(q6) - RUY_MAKE_ZERO(q7) - RUY_MAKE_ZERO(q8) - RUY_MAKE_ZERO(q9) - RUY_MAKE_ZERO(q10) - RUY_MAKE_ZERO(q11) - RUY_MAKE_ZERO(q12) - RUY_MAKE_ZERO(q13) - RUY_MAKE_ZERO(q15) - - // Load the destination zero point into each of the 8 16-bit slots - // in a q register. - "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" - "vdup.16 q13, r4\n" // dst_zero_point - - // Add the destination zero point - "vadd.i16 q14, q14, q13\n" - - // Cast-and-saturate from int16 to int8 - "vqmovn.s16 d30, q14\n" - // At this point, we only need 4 8-bit values in the lower half - // of d30. - - // Load the clamp_min, clamp_max bounds - "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "vdup.8 d28, r2\n" // clamp_min - "vdup.8 d29, r3\n" // clamp_max - - // Apply the clamp_min bound - "vmax.s8 d30, d30, d28\n" - // Apply the clamp_max bound - "vmin.s8 d30, d30, d29\n" - - // Compute how much of the 4x1 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x2, there are some 4x2 blocks along the boundaries that do - // not fit entirely. - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "sub r1, r1, r8\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "sub r2, r2, r4\n" - "mov r3, #4\n" - "mov r5, #2\n" - "cmp r1, #4\n" - // Compute r1 = how many rows of the 4x2 block fit - "it gt\n" - "movgt r1, r3\n" - - // Test if r1==4 i.e. if all of the 4x1 block fits. - "cmp r1, r3\n" - - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - // Yes, all of the 4x2 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x2 block fits. - // Store to dst_tmp_buf - // Set r3 address to write to dst_tmp_buf. - "mov r3, %[dst_tmp_buf]\n" - "vst1.8 {d30}, [r3]\n" - - // Slow loop copying from dst_tmp_buf to dst. - "50:\n" - "mov r8, #0\n" - "51:\n" - "ldrb r10, [r3, r8]\n" - "strb r10, [r4, r8]\n" - "add r8, r8, #1\n" - "cmp r8, r1\n" - "blt 51b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x1 block fits. - // r3 address, r5 stride - "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "mov r4, r3\n" - "mov r6, #1\n" - - "vst1.8 {d30[0]}, [r3], r6\n" - "vst1.8 {d30[1]}, [r3], r6\n" - "vst1.8 {d30[2]}, [r3], r6\n" - "vst1.8 {d30[3]}, [r3], r6\n" - "31:\n" - - // Load dst_ptr, increment, and write back. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "add r4, r4, #4\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - - RUY_MAKE_ZERO(q13) - RUY_MAKE_ZERO(q14) - RUY_MAKE_ZERO(q15) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" - - // Load the destination zero point into each of the 4 32-bit slots - // in a q register. - "ldrsh r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" - "vdup.32 q13, r4\n" // dst_zero_point - // Add the destination zero point - "vadd.s32 q14, q14, q13\n" - //"vadd.s32 q15, q15, q13\n" - - // Cast-and-saturate from int32 to int16 - // After this, all values for output are in d28. - "vqmovn.s32 d28, q14\n" - - // At this point, d12 -- d26, d29, d30, d31 aren't used anymore for the - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(q6) - RUY_MAKE_ZERO(q7) - RUY_MAKE_ZERO(q8) - RUY_MAKE_ZERO(q9) - RUY_MAKE_ZERO(q10) - RUY_MAKE_ZERO(q11) - RUY_MAKE_ZERO(q15) - - // Load the clamp_min, clamp_max bounds - "ldrh r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrh r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "vdup.16 d24, r2\n" // clamp_min - "vdup.16 d26, r3\n" // clamp_max - - // Apply the clamp_min bound - "vmax.s16 d28, d28, d24\n" - // Apply the clamp_max bound - "vmin.s16 d28, d28, d26\n" - - RUY_MAKE_ZERO(q12) - RUY_MAKE_ZERO(q13) - - // Compute how much of the 4x1 block of destination 16-bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x1, there are some 4x1 blocks along the boundaries that do - // not fit entirely. - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "sub r1, r1, r8\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "sub r2, r2, r4\n" - "mov r3, #4\n" - "mov r5, #2\n" - "cmp r1, #4\n" - // Compute r1 = how many rows of the 4x1 block fit - "it gt\n" - "movgt r1, r3\n" - - // Test if r1==4, i.e. if all of the 4x1 block fits. - "cmp r1, r3\n" - - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - // Yes, all of the 4x1 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x1 block fits. - // Store to dst_tmp_buf - // Set r3 address to write to dst_tmp_buf. - "mov r3, %[dst_tmp_buf]\n" - "vst1.16 {d28}, [r3]\n" - - // Slow loop copying from dst_tmp_buf to dst. - "50:\n" - "mov r8, #0\n" - "51:\n" - // Shift of offset register for half-word loads not allowed in A32, - // so we shift, load/store, then shift back r8. - "lsl r8, r8, #1\n" - "ldrh r10, [r3, r8]\n" - "strh r10, [r4, r8]\n" - "lsr r8, r8, #1\n" - "add r8, r8, #1\n" - "cmp r8, r1\n" - "blt 51b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x1 block fits. - // r3 address, r5 stride - "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "mov r4, r3\n" - "mov r6, #2\n" - - "vst1.16 {d28[0]}, [r3], r6\n" - "vst1.16 {d28[1]}, [r3], r6\n" - "vst1.16 {d28[2]}, [r3], r6\n" - "vst1.16 {d28[3]}, [r3], r6\n" - "31:\n" - - // Load dst_ptr, increment, and write back. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "add r4, r4, #8\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - - RUY_MAKE_ZERO(q14) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" - - // Since the store type is the same as the accum type, no need for - // downcast. There's also no need for clamp by min/max. - - // At this point, v20 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - // Clear accumulators. - RUY_MAKE_ZERO(q6) - RUY_MAKE_ZERO(q7) - RUY_MAKE_ZERO(q8) - RUY_MAKE_ZERO(q9) - RUY_MAKE_ZERO(q10) - RUY_MAKE_ZERO(q11) - RUY_MAKE_ZERO(q12) - RUY_MAKE_ZERO(q13) - - // Compute how much of the 4x1 block of destination 32 bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x2, there are some 4x4 blocks along the boundaries that do - // not fit entirely. - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "sub r1, r1, r8\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "sub r2, r2, r4\n" - "mov r3, #4\n" - "mov r5, #2\n" - "cmp r1, #4\n" - // Compute r1 = how many rows of the 4x2 block fit - "it gt\n" - "movgt r1, r3\n" - - // Test if r1==4, i.e. if all of the 4x1 block fits. - "cmp r1, r3\n" - - // Yes, all of the 4x1 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x1 block fits. - // Set (r3 address, r4 stride) to write to dst_tmp_buf - "mov r3, %[dst_tmp_buf]\n" - "mov r4, #16\n" - "b 31f\n" - - "30:\n" - // Yes, all of the 4x1 block fits. - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - // r3 address, r4 stride - "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "mov r4, r5\n" - - "31:\n" - - "vst1.32 {d28, d29}, [r3]\n" - - // If all of the 4x1 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 41f\n" - // Not all of the 4x1 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "mov r3, %[dst_tmp_buf]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "50:\n" - "mov r5, #0\n" - "51:\n" - "ldr r10, [r3, r5, lsl #2]\n" - "str r10, [r4, r5, lsl #2]\n" - "add r5, r5, #1\n" - "cmp r5, r1\n" - "blt 51b\n" - - "41:\n" - // Load dst_ptr, increment, and write back. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "add r4, r4, #16\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - - RUY_MAKE_ZERO(q10) - RUY_MAKE_ZERO(q11) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" - - // Reload some params --- we had used x5 -- x7 for a few other things - // since the last time we had loaded them. - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - - // Move to the next block of the destination matrix, for the next iter - // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already - // been updated earlier. - // Have we reached the end row? - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "cmp r8, r3\n" - - "beq 20f\n" // yes, end row. - // Not end row. Move to the next row. - "add r8, r8, #4\n" - // Store new value of row - "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - - "b 21f\n" - "20:\n" - // Was already at end row. - // Move back to first row. - "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - // Move to the next column. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "add r4, r4, #2\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - - "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" - // Increment dst_col_ptr by dst_stride (i.e. 1 column) - "add r1, r1, r8\n" - // Store dst_col_ptr - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" - // Store dst_ptr - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "21:\n" - - // Main loop exit condition: have we hit the end column? - "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "cmp r8, r4\n" - - // w1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 16. - "mov r1, #16\n" - - "ble 1b\n" - - // Restore stack pointer. - "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" - - // clang-format on - - : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr) - : [ params ] "r"(¶ms), [dst_tmp_buf] "r"(params.dst_tmp_buf) - : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc", - // Clobber list must specify q registers (and not their constituent - // d registers). There is a (currently unexplained) slowdown if - // d registers are listed in the clobbers list. - "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", - "q9", "q10", "q12", "q13", "q14", "q15"); -} - -#undef RUY_OFFSET_BIAS -#undef RUY_OFFSET_LHS_SUMS -#undef RUY_OFFSET_RHS_SUMS -#undef RUY_OFFSET_LHS_BASE_PTR -#undef RUY_OFFSET_MULTIPLIER_FIXEDPOINT -#undef RUY_OFFSET_MULTIPLIER_EXPONENT -#undef RUY_OFFSET_RHS_BASE_PTR -#undef RUY_OFFSET_DST_BASE_PTR -#undef RUY_OFFSET_LHS_ZERO_POINT -#undef RUY_OFFSET_RHS_ZERO_POINT -#undef RUY_OFFSET_DST_ZERO_POINT -#undef RUY_OFFSET_PROD_ZP_DEPTH -#undef RUY_OFFSET_START_ROW -#undef RUY_OFFSET_START_COL -#undef RUY_OFFSET_LAST_ROW -#undef RUY_OFFSET_LAST_COL -#undef RUY_OFFSET_DST_ROWS -#undef RUY_OFFSET_DST_COLS -#undef RUY_OFFSET_LHS_STRIDE -#undef RUY_OFFSET_RHS_STRIDE -#undef RUY_OFFSET_DST_STRIDE -#undef RUY_OFFSET_DEPTH -#undef RUY_OFFSET_CLAMP_MIN -#undef RUY_OFFSET_CLAMP_MAX -#undef RUY_OFFSET_FLAGS -#undef RUY_OFFSET_DST_TYPE_ID - -#undef RUY_STACK_OFFSET_SIZE -#undef RUY_STACK_OFFSET_DST_COL_PTR -#undef RUY_STACK_OFFSET_DST_PTR -#undef RUY_STACK_OFFSET_ROW -#undef RUY_STACK_OFFSET_COL -#undef RUY_STACK_OFFSET_LHS_COL_PTR -#undef RUY_STACK_OFFSET_RHS_COL_PTR - -#endif // RUY_PLATFORM(NEON_32) && (RUY_OPT_ENABLED(RUY_OPT_ASM) -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/kernel_arm64.cc b/tensorflow/lite/experimental/ruy/ruy/kernel_arm64.cc deleted file mode 100644 index eff9d2c8a09..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/kernel_arm64.cc +++ /dev/null @@ -1,7835 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include - -#include "tensorflow/lite/experimental/ruy/ruy/common.h" -#include "tensorflow/lite/experimental/ruy/ruy/kernel.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" - -namespace ruy { - -#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -#define RUY_ASM_LABEL_STORE_UINT8 91 -#define RUY_ASM_LABEL_STORE_INT8 92 -#define RUY_ASM_LABEL_STORE_INT16 93 -#define RUY_ASM_LABEL_STORE_INT32 94 -#define RUY_ASM_LABEL_AFTER_STORE 99 - -#define RUY_OFFSET_BIAS 0 -#define RUY_OFFSET_LHS_SUMS 8 -#define RUY_OFFSET_RHS_SUMS 16 -#define RUY_OFFSET_LHS_BASE_PTR 24 -#define RUY_OFFSET_MULTIPLIER_FIXEDPOINT 32 -#define RUY_OFFSET_MULTIPLIER_EXPONENT 40 -#define RUY_OFFSET_RHS_BASE_PTR 48 -#define RUY_OFFSET_DST_BASE_PTR 56 -#define RUY_OFFSET_LHS_ZERO_POINT 64 -#define RUY_OFFSET_RHS_ZERO_POINT 68 -#define RUY_OFFSET_DST_ZERO_POINT 72 -#define RUY_OFFSET_PROD_ZP_DEPTH 76 -#define RUY_OFFSET_START_ROW 80 -#define RUY_OFFSET_START_COL 84 -#define RUY_OFFSET_LAST_ROW 88 -#define RUY_OFFSET_LAST_COL 92 -#define RUY_OFFSET_DST_ROWS 96 -#define RUY_OFFSET_DST_COLS 100 -#define RUY_OFFSET_LHS_STRIDE 104 -#define RUY_OFFSET_RHS_STRIDE 108 -#define RUY_OFFSET_DST_STRIDE 112 -#define RUY_OFFSET_DEPTH 116 -#define RUY_OFFSET_CLAMP_MIN 120 -#define RUY_OFFSET_CLAMP_MAX 124 -#define RUY_OFFSET_FLAGS 128 - -template -void CheckOffsetsInKernelParams8bit(const Params&) { - static_assert(offsetof(Params, lhs_zero_point) == RUY_OFFSET_LHS_ZERO_POINT, - ""); - static_assert(offsetof(Params, rhs_zero_point) == RUY_OFFSET_RHS_ZERO_POINT, - ""); - static_assert(offsetof(Params, dst_zero_point) == RUY_OFFSET_DST_ZERO_POINT, - ""); - static_assert(offsetof(Params, prod_zp_depth) == RUY_OFFSET_PROD_ZP_DEPTH, - ""); - static_assert(offsetof(Params, multiplier_fixedpoint) == - RUY_OFFSET_MULTIPLIER_FIXEDPOINT, - ""); - static_assert( - offsetof(Params, multiplier_exponent) == RUY_OFFSET_MULTIPLIER_EXPONENT, - ""); - static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, ""); - static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, ""); - static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, ""); - static_assert(offsetof(Params, lhs_sums) == RUY_OFFSET_LHS_SUMS, ""); - static_assert(offsetof(Params, rhs_sums) == RUY_OFFSET_RHS_SUMS, ""); - static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, ""); - static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, ""); - static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, ""); - static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, ""); - static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, ""); - static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, ""); - static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, ""); - static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, ""); - static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, ""); -} - -// Fast-int8-trick kernel, similar to this production gemmlowp kernel: -// NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits -// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L2296 -// -// Relevant target CPUs for this kernel include ARM Cortex-A73 and Cortex-A75, -// since these are 64-bit, out-of-order and without dotprod support. -void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 4>& params) { - profiler::ScopeLabel label( - "Kernel (kNeon, optimized for out-of-order cores)"); - - CheckOffsetsInKernelParams8bit(params); - - const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; - const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; - void* dst_col_ptr = params.dst_base_ptr; - void* dst_ptr = dst_col_ptr; - int row = params.start_row; - int col = params.start_col; - - // The asm kernel below has the following NEON register allocation: - // - // v16 -- v31 are int32 accumulators. - // During accumulation, v0 -- v3 are used to load int8 data from LHS and - // v4 -- v7 from RHS: - // - // int8 RHS 16x4 block - // /-----------------------------------------\ - // |v4.b[0] ... v7.b[0] | - // | ... ... | - // |v4.b[15] ... v7.b[15] | - // \-----------------------------------------/ - // int8 LHS 4x16 block - // /---------------------\ /-----------------------------------------\ - // |v0.b[0] ... v0.b[15] | |v16.4s ... v28.4s | - // |v1.b[0] ... v1.b[15] | |v17.4s ... v29.4s | - // |v2.b[0] ... v2.b[15] | |v18.4s ... v30.4s | - // |v3.b[0] ... v3.b[15] | |v19.4s ... v31.4s | - // \---------------------/ \-----------------------------------------/ - // int32 accumulators 4x4 block - // - // No attempt had been made so far at implementing the RUY_OPT_MAX_STREAMING - // optimization for this kernel. - asm volatile( -#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" - - // clang-format off - - // Load some parameters into registers. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" - "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" - "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" - - // Load the first 64 bytes of LHS and RHS data. - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v5.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v6.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v7.16b}, [%[rhs_ptr]], #16\n" - - // Clear accumulators. - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - RUY_MAKE_ZERO(v20) - RUY_MAKE_ZERO(v21) - RUY_MAKE_ZERO(v22) - RUY_MAKE_ZERO(v23) - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // w1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 16. - "mov w1, #16\n" - - // Perform the first few multiply-adds on the data that we have already - // loaded. - "smull v8.8h, v0.8b, v4.8b\n" - "smull v9.8h, v1.8b, v4.8b\n" - "smull v10.8h, v2.8b, v4.8b\n" - "smull v11.8h, v3.8b, v4.8b\n" - "smull v12.8h, v0.8b, v5.8b\n" - "smull v13.8h, v1.8b, v5.8b\n" - "smull v14.8h, v2.8b, v5.8b\n" - "smull v15.8h, v3.8b, v5.8b\n" - - // Multiply-accumulate second-half, again into the same - // 16bit local accumulator registers. This is where we - // take advantage of having int8 instead of uint8 and therefore - // being able to accumulate two products into int16. - "smlal2 v8.8h, v0.16b, v4.16b\n" - "smlal2 v9.8h, v1.16b, v4.16b\n" - "smlal2 v10.8h, v2.16b, v4.16b\n" - "smlal2 v11.8h, v3.16b, v4.16b\n" - "smlal2 v12.8h, v0.16b, v5.16b\n" - "smlal2 v13.8h, v1.16b, v5.16b\n" - "smlal2 v14.8h, v2.16b, v5.16b\n" - "smlal2 v15.8h, v3.16b, v5.16b\n" - - - // Main loop of the whole GEMM, over rows and columns of the - // destination matrix. - "1:\n" - - // Reminder - w1 is how many levels of depth we have already loaded - // data for, w12 is the total depth. - "cmp w1, w12\n" - "beq 79f\n" - - "2:\n" - - // Some multiplications and 16-bit accumulation were already done above, - // so we start right away in the middle. - "sadalp v16.4s, v8.8h\n" - "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" - "smull v8.8h, v0.8b, v6.8b\n" - "sadalp v17.4s, v9.8h\n" - "ld1 {v5.16b}, [%[rhs_ptr]], #16\n" - "smull v9.8h, v1.8b, v6.8b\n" - "sadalp v18.4s, v10.8h\n" - "smull v10.8h, v2.8b, v6.8b\n" - "sadalp v19.4s, v11.8h\n" - "smull v11.8h, v3.8b, v6.8b\n" - "sadalp v20.4s, v12.8h\n" - "smull v12.8h, v0.8b, v7.8b\n" - "sadalp v21.4s, v13.8h\n" - "smull v13.8h, v1.8b, v7.8b\n" - "sadalp v22.4s, v14.8h\n" - "smull v14.8h, v2.8b, v7.8b\n" - "sadalp v23.4s, v15.8h\n" - "smull v15.8h, v3.8b, v7.8b\n" - - // Multiply-accumulate second-half, again into the same - // 16bit local accumulator registers. This is where we - // take advantage of having int8 instead of uint8 and therefore - // being able to accumulate two products into int16. - "smlal2 v8.8h, v0.16b, v6.16b\n" - "smlal2 v9.8h, v1.16b, v6.16b\n" - "smlal2 v10.8h, v2.16b, v6.16b\n" - "smlal2 v11.8h, v3.16b, v6.16b\n" - - "ld1 {v6.16b}, [%[rhs_ptr]], #16\n" - - "smlal2 v12.8h, v0.16b, v7.16b\n" - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - "smlal2 v13.8h, v1.16b, v7.16b\n" - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - "smlal2 v14.8h, v2.16b, v7.16b\n" - "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" - "smlal2 v15.8h, v3.16b, v7.16b\n" - "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" - - "sadalp v24.4s, v8.8h\n" - "smull v8.8h, v0.8b, v4.8b\n" - "sadalp v25.4s, v9.8h\n" - "ld1 {v7.16b}, [%[rhs_ptr]], #16\n" - "smull v9.8h, v1.8b, v4.8b\n" - "sadalp v26.4s, v10.8h\n" - "smull v10.8h, v2.8b, v4.8b\n" - "sadalp v27.4s, v11.8h\n" - "smull v11.8h, v3.8b, v4.8b\n" - "sadalp v28.4s, v12.8h\n" - "smull v12.8h, v0.8b, v5.8b\n" - "sadalp v29.4s, v13.8h\n" - "smull v13.8h, v1.8b, v5.8b\n" - "sadalp v30.4s, v14.8h\n" - "smull v14.8h, v2.8b, v5.8b\n" - "sadalp v31.4s, v15.8h\n" - "smull v15.8h, v3.8b, v5.8b\n" - - // Multiply-accumulate second-half, again into the same - // 16bit local accumulator registers. This is where we - // take advantage of having int8 instead of uint8 and therefore - // being able to accumulate two products into int16. - "smlal2 v8.8h, v0.16b, v4.16b\n" - "smlal2 v9.8h, v1.16b, v4.16b\n" - "smlal2 v10.8h, v2.16b, v4.16b\n" - "smlal2 v11.8h, v3.16b, v4.16b\n" - - "smlal2 v12.8h, v0.16b, v5.16b\n" - "smlal2 v13.8h, v1.16b, v5.16b\n" - "smlal2 v14.8h, v2.16b, v5.16b\n" - "smlal2 v15.8h, v3.16b, v5.16b\n" - - - - // Each iteration of this loop advances by 16 levels of depth. - "add w1, w1, #16\n" - - // Loop termination condition - "cmp w1, w12\n" - - "blt 2b\n" - - "79:\n" - - "sadalp v16.4s, v8.8h\n" - "smull v8.8h, v0.8b, v6.8b\n" - "sadalp v17.4s, v9.8h\n" - "smull v9.8h, v1.8b, v6.8b\n" - "sadalp v18.4s, v10.8h\n" - "smull v10.8h, v2.8b, v6.8b\n" - "sadalp v19.4s, v11.8h\n" - "smull v11.8h, v3.8b, v6.8b\n" - "sadalp v20.4s, v12.8h\n" - "smull v12.8h, v0.8b, v7.8b\n" - "sadalp v21.4s, v13.8h\n" - "smull v13.8h, v1.8b, v7.8b\n" - "sadalp v22.4s, v14.8h\n" - "smull v14.8h, v2.8b, v7.8b\n" - "sadalp v23.4s, v15.8h\n" - "smull v15.8h, v3.8b, v7.8b\n" - - // Multiply-accumulate second-half, again into the same - // 16bit local accumulator registers. This is where we - // take advantage of having int8 instead of uint8 and therefore - // being able to accumulate two products into int16. - "smlal2 v8.8h, v0.16b, v6.16b\n" - "smlal2 v9.8h, v1.16b, v6.16b\n" - "smlal2 v10.8h, v2.16b, v6.16b\n" - "smlal2 v11.8h, v3.16b, v6.16b\n" - - "smlal2 v12.8h, v0.16b, v7.16b\n" - "smlal2 v13.8h, v1.16b, v7.16b\n" - "smlal2 v14.8h, v2.16b, v7.16b\n" - "smlal2 v15.8h, v3.16b, v7.16b\n" - - "sadalp v24.4s, v8.8h\n" - "sadalp v25.4s, v9.8h\n" - "sadalp v26.4s, v10.8h\n" - "sadalp v27.4s, v11.8h\n" - "sadalp v28.4s, v12.8h\n" - "sadalp v29.4s, v13.8h\n" - "sadalp v30.4s, v14.8h\n" - "sadalp v31.4s, v15.8h\n" - - // End of accumulation. The registers v16 -- v31 contain the final - // int32 accumulator values of the current 4x4 destination block. - // We now have to compute the final 8-bit values from these int32 - // accumulators, and advance to the next 4x4 block. We intertwine - // these two aspects whenever possible for optimal pipelining, both - // at the data flow level (prefetch data for next block as early as - // possible) and instruction pipelining level (some of the next-block - // work can dual-issue with some of the final work on the current - // block). - - // Reduce 32bit accumulators horizontally. - "addp v16.4s, v16.4s, v17.4s\n" - "addp v18.4s, v18.4s, v19.4s\n" - "addp v20.4s, v20.4s, v21.4s\n" - "addp v22.4s, v22.4s, v23.4s\n" - "addp v24.4s, v24.4s, v25.4s\n" - "addp v26.4s, v26.4s, v27.4s\n" - "addp v28.4s, v28.4s, v29.4s\n" - "addp v30.4s, v30.4s, v31.4s\n" - - // Reduce 32bit accumulators horizontally, second pass - // (each pass adds pairwise. we need to add 4-wise). - "addp v16.4s, v16.4s, v18.4s\n" - "addp v17.4s, v20.4s, v22.4s\n" - "addp v18.4s, v24.4s, v26.4s\n" - "addp v19.4s, v28.4s, v30.4s\n" - - // Logic to advance to the next block in preparation for the next - // iteration of the main loop. For now, we only want to compute - // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are - // not yet ready to update the values of row and col, as we still need - // the current values for the rest of the work on the current block. - - "cmp %w[row], w7\n" // Have we finished the last row? - "bge 4f\n" // If finished last row, go to 4 - // Not finished last row: then advance to next row. - "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #2\n" - "b 5f\n" - "4:\n" // Finished last row... - "mov %[lhs_col_ptr], x5\n" // Go back to first row - // Now we need to advance to the next column. If we already - // finished the last column, then in principle we are done, however - // we can't just return here, as we need to allow the end work of the - // current block to complete. The good news is that at this point it - // doesn't matter what data we load for the next column, since - // we will exit from the main loop below before actually storing - // anything computed from that data. - "cmp %w[col], w8\n" // Have we finished the last column? - "bge 5f\n" // If yes, just carry on without updating the column pointer. - // Not finished last column: then advance to next column. - "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #2\n" - "5:\n" - - // Set the LHS and RHS data pointers to the start of the columns just - // computed. - "mov %[lhs_ptr], %[lhs_col_ptr]\n" - "mov %[rhs_ptr], %[rhs_col_ptr]\n" - - // Load some parameters needed for the end work on current block. - RUY_MAKE_ZERO(v8) - "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" - "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" - "ins v13.h[4], w4\n" // dst_zero_point - "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" - "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "dup v9.4s, w3\n" // create prod_zp_depth_vec - "add x5, x4, %x[row], lsl #2\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "csel x4, x4, x5, eq\n" - - "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint - - // Now we load: bias data, LHS sums data, RHS sums data. - - // First, load the base pointers from the params. - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" - - "add x5, x1, %x[row], lsl #2\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" - "csel x1, x1, x5, eq\n" - - // Load 4 bias values. - "ld1 {v14.4s}, [x1]\n" - - // Now that we know what LHS and RHS data the next iteration of the - // main loop will need to load, we start loading the first 32 bytes of - // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore - // in the rest of the work on the current block. - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v5.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v6.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v7.16b}, [%[rhs_ptr]], #16\n" - - // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), - // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "add v14.4s, v14.4s, v9.4s\n" - - // Perform the bias-addition (per the above, we have just folded into - // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) - "add v16.4s, v16.4s, v14.4s\n" - "add v17.4s, v17.4s, v14.4s\n" - "add v18.4s, v18.4s, v14.4s\n" - "add v19.4s, v19.4s, v14.4s\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" - "beq 401f\n" - "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" - "add x3, x3, %x[col], lsl #2\n" - "ld1 {v14.4s}, [x3]\n" - "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" - "dup v10.4s, w5\n" // create lhs_zero_point_vec - // Subtract rhs_sums * lhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "mls v16.4s, v10.4s, v14.s[0]\n" - "mls v17.4s, v10.4s, v14.s[1]\n" - "mls v18.4s, v10.4s, v14.s[2]\n" - "mls v19.4s, v10.4s, v14.s[3]\n" - "401:\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" - "beq 402f\n" - "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" - "add x2, x2, %x[row], lsl #2\n" - "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" - // Load 4 lhs_sums values. - "ld1 {v11.4s}, [x2]\n" - "ins v13.s[1], w5\n" // rhs_zero_point - // Compute lhs_sums * rhs_zero_point. - "mul v11.4s, v11.4s, v13.s[1]\n" - // Subtract lhs_sums * rhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "sub v16.4s, v16.4s, v11.4s\n" - "sub v17.4s, v17.4s, v11.4s\n" - "sub v18.4s, v18.4s, v11.4s\n" - "sub v19.4s, v19.4s, v11.4s\n" - - // If the destination is int32, it means the user asks for the raw - // accumulators, no need for us to downquantize the value. - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" - - "402:\n" - - // At this point we have computed the final int32 values. Now we - // start down-quantizing them to obtain the final 8bit values from them. - - // As part of this down-quantization, our int32 values will be - // multiplied by a multiplier that has a fixed-point component and an - // exponent component. - - //Load the exponent part of the multiplier. - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "add x5, x1, %x[row], lsl #2\n" - "csel x1, x1, x5, eq\n" - - "ld1 {v14.4s}, [x1]\n" - - "smax v12.4s, v14.4s, v8.4s\n" - - "sshl v16.4s, v16.4s, v12.4s\n" - "sshl v17.4s, v17.4s, v12.4s\n" - "sshl v18.4s, v18.4s, v12.4s\n" - "sshl v19.4s, v19.4s, v12.4s\n" - - "smin v12.4s, v14.4s, v8.4s\n" - - // Apply the fixed-point part of the multiplier. - "sqrdmulh v16.4s, v16.4s, v15.4s\n" - "sqrdmulh v17.4s, v17.4s, v15.4s\n" - "sqrdmulh v18.4s, v18.4s, v15.4s\n" - "sqrdmulh v19.4s, v19.4s, v15.4s\n" - - // We have some rounding division-by-power-of-two to do. This should - // always use "round to nearest". We allow for some - // freedom in how ties are broken, to strike a good compromise of - // performance on given hardware vs. perfect agreement of results - // across hardware. - // - // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation - // defined tie-breaks to help performance. On NEON, this means that we - // can just use the NEON rounding instructions, such as srshl. They - // happen to be breaking ties upward. - // - // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict - // break-ties-away-from zero, as described in Appendix B of - // https://arxiv.org/pdf/1712.05877.pdf - // When we wrote that, we thought that that would be better unbiased - // than the NEON upwards tie-breaks, and we had observed some - // improvement on some model. However, that is only more unbiased for - // data centered at zero, which was likely the case in that model, - // but is not always the case. If we wanted something more consistently - // unbiased then we should try breaking ties toward-nearest-even. -#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) - // Fix up values to be right-shifted, so that the (round to nearest, - // break ties upward) behavior of srshl applied to these fixed-up - // values, produces the same result as the desired (round to nearest, - // break ties away from zero) behavior on the original values. - "and v8.16b, v16.16b, v12.16b\n" - "and v9.16b, v17.16b, v12.16b\n" - "and v14.16b, v18.16b, v12.16b\n" - "and v15.16b, v19.16b, v12.16b\n" - "sshr v8.4s, v8.4s, #31\n" - "sshr v9.4s, v9.4s, #31\n" - "sshr v14.4s, v14.4s, #31\n" - "sshr v15.4s, v15.4s, #31\n" - "sqadd v16.4s, v16.4s, v8.4s\n" - "sqadd v17.4s, v17.4s, v9.4s\n" - "sqadd v18.4s, v18.4s, v14.4s\n" - "sqadd v19.4s, v19.4s, v15.4s\n" -#endif - // At this point we have reduced the problem of correctly implementing - // rounding divide-by-power-of-two, to what the SRSHL instruction can - // do. - "srshl v16.4s, v16.4s, v12.4s\n" - "srshl v17.4s, v17.4s, v12.4s\n" - "srshl v18.4s, v18.4s, v12.4s\n" - "srshl v19.4s, v19.4s, v12.4s\n" - - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" - - // Cast-and-saturate from int32 to int16 - "sqxtn v16.4h, v16.4s\n" - "sqxtn2 v16.8h, v17.4s\n" - "sqxtn v17.4h, v18.4s\n" - "sqxtn2 v17.8h, v19.4s\n" - - // At this point, v18 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - RUY_MAKE_ZERO(v20) - RUY_MAKE_ZERO(v21) - RUY_MAKE_ZERO(v22) - RUY_MAKE_ZERO(v23) - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // Add the destination zero point - "dup v14.8h, v13.h[4]\n" - "add v16.8h, v16.8h, v14.8h\n" - "add v17.8h, v17.8h, v14.8h\n" - - // Cast-and-saturate from int16 to uint8 - "sqxtun v16.8b, v16.8h\n" - "sqxtun2 v16.16b, v17.8h\n" - - // Load the clamp_min, clamp_max bounds - "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.16b, w2\n" // clamp_min - "dup v15.16b, w3\n" // clamp_max - - // Apply the clamp_min bound - "umax v16.16b, v16.16b, v14.16b\n" - // Apply the clamp_max bound - "umin v16.16b, v16.16b, v15.16b\n" - - // Compute how much of the 4x4 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x4, there are some 4x4 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #4\n" - "cmp w1, #4\n" - // Compute w1 = how many rows of the 4x4 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #4\n" - // Compute w2 = how many cols of the 4x4 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - "mov x4, %[dst_ptr]\n" - // Yes, all of the 4x4 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x4 block fits. - // Store to dst_tmp_buf - "st1 {v16.16b}, [%[dst_tmp_buf]]\n" - // Slow loop copying from dst_tmp_buf to dst. - "mov x3, %[dst_tmp_buf]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldrb w7, [x3, w5, uxtw]\n" - "strb w7, [x4, w5, uxtw]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "add w6, w6, #1\n" - "add x3, x3, #4\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 50b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x4 block fits. - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[0], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[1], [x3], #1\n" - "st1 {v16.b}[2], [x3], #1\n" - "st1 {v16.b}[3], [x3], #1\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[4], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[5], [x3], #1\n" - "st1 {v16.b}[6], [x3], #1\n" - "st1 {v16.b}[7], [x3], #1\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[8], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[9], [x3], #1\n" - "st1 {v16.b}[10], [x3], #1\n" - "st1 {v16.b}[11], [x3], #1\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[12], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[13], [x3], #1\n" - "st1 {v16.b}[14], [x3], #1\n" - "st1 {v16.b}[15], [x3], #1\n" - "31:\n" - - "add %[dst_ptr], %[dst_ptr], #4\n" - - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" - - // Cast-and-saturate from int32 to int16 - "sqxtn v16.4h, v16.4s\n" - "sqxtn2 v16.8h, v17.4s\n" - "sqxtn v17.4h, v18.4s\n" - "sqxtn2 v17.8h, v19.4s\n" - - // At this point, v18 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - RUY_MAKE_ZERO(v20) - RUY_MAKE_ZERO(v21) - RUY_MAKE_ZERO(v22) - RUY_MAKE_ZERO(v23) - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // Add the destination zero point - "dup v14.8h, v13.h[4]\n" - "add v16.8h, v16.8h, v14.8h\n" - "add v17.8h, v17.8h, v14.8h\n" - - // Cast-and-saturate from int16 to int8 - "sqxtn v16.8b, v16.8h\n" - "sqxtn2 v16.16b, v17.8h\n" - - // Load the clamp_min, clamp_max bounds - "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.16b, w2\n" // clamp_min - "dup v15.16b, w3\n" // clamp_max - - // Apply the clamp_min bound - "smax v16.16b, v16.16b, v14.16b\n" - // Apply the clamp_max bound - "smin v16.16b, v16.16b, v15.16b\n" - - // Compute how much of the 4x4 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x4, there are some 4x4 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #4\n" - "cmp w1, #4\n" - // Compute w1 = how many rows of the 4x4 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #4\n" - // Compute w2 = how many cols of the 4x4 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - "mov x4, %[dst_ptr]\n" - // Yes, all of the 4x4 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x4 block fits. - // Store to dst_tmp_buf - "st1 {v16.16b}, [%[dst_tmp_buf]]\n" - // Slow loop copying from dst_tmp_buf to dst. - "mov x3, %[dst_tmp_buf]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldrb w7, [x3, w5, uxtw]\n" - "strb w7, [x4, w5, uxtw]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "add w6, w6, #1\n" - "add x3, x3, #4\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 50b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x4 block fits. - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[0], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[1], [x3], #1\n" - "st1 {v16.b}[2], [x3], #1\n" - "st1 {v16.b}[3], [x3], #1\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[4], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[5], [x3], #1\n" - "st1 {v16.b}[6], [x3], #1\n" - "st1 {v16.b}[7], [x3], #1\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[8], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[9], [x3], #1\n" - "st1 {v16.b}[10], [x3], #1\n" - "st1 {v16.b}[11], [x3], #1\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[12], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[13], [x3], #1\n" - "st1 {v16.b}[14], [x3], #1\n" - "st1 {v16.b}[15], [x3], #1\n" - "31:\n" - - "add %[dst_ptr], %[dst_ptr], #4\n" - - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" - - // Add the destination zero point - "dup v14.4h, v13.h[4]\n" - "saddw v16.4s, v16.4s, v14.4h\n" - "saddw v17.4s, v17.4s, v14.4h\n" - "saddw v18.4s, v18.4s, v14.4h\n" - "saddw v19.4s, v19.4s, v14.4h\n" - - // Cast-and-saturate from int32 to int16 - "sqxtn v16.4h, v16.4s\n" - "sqxtn2 v16.8h, v17.4s\n" - "sqxtn v17.4h, v18.4s\n" - "sqxtn2 v17.8h, v19.4s\n" - - // At this point, v18 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - RUY_MAKE_ZERO(v20) - RUY_MAKE_ZERO(v21) - RUY_MAKE_ZERO(v22) - RUY_MAKE_ZERO(v23) - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // Load the clamp_min, clamp_max bounds - "ldrh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.8h, w2\n" // clamp_min - "dup v15.8h, w3\n" // clamp_max - - // Apply the clamp_min bound - "smax v16.8h, v16.8h, v14.8h\n" - "smax v17.8h, v17.8h, v14.8h\n" - // Apply the clamp_max bound - "smin v16.8h, v16.8h, v15.8h\n" - "smin v17.8h, v17.8h, v15.8h\n" - - // Compute how much of the 4x4 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x4, there are some 4x4 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #4\n" - "cmp w1, #4\n" - // Compute w1 = how many rows of the 4x4 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #4\n" - // Compute w2 = how many cols of the 4x4 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==4 && w2 == 4, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - "mov x4, %[dst_ptr]\n" - // Yes, all of the 4x4 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x4 block fits. - // Store to dst_tmp_buf - "str q16, [%[dst_tmp_buf], #0]\n" - "str q17, [%[dst_tmp_buf], #16]\n" - // Slow loop copying from dst_tmp_buf to dst. - "mov x3, %[dst_tmp_buf]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldrh w7, [x3, x5, lsl #1]\n" - "strh w7, [x4, x5, lsl #1]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "add w6, w6, #1\n" - "add x3, x3, #8\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 50b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x4 block fits. - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.h}[0], [x3], #2\n" - "add x4, x4, x11\n" - "st1 {v16.h}[1], [x3], #2\n" - "st1 {v16.h}[2], [x3], #2\n" - "st1 {v16.h}[3], [x3], #2\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.h}[4], [x3], #2\n" - "add x4, x4, x11\n" - "st1 {v16.h}[5], [x3], #2\n" - "st1 {v16.h}[6], [x3], #2\n" - "st1 {v16.h}[7], [x3], #2\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v17.h}[0], [x3], #2\n" - "add x4, x4, x11\n" - "st1 {v17.h}[1], [x3], #2\n" - "st1 {v17.h}[2], [x3], #2\n" - "st1 {v17.h}[3], [x3], #2\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v17.h}[4], [x3], #2\n" - "add x4, x4, x11\n" - "st1 {v17.h}[5], [x3], #2\n" - "st1 {v17.h}[6], [x3], #2\n" - "st1 {v17.h}[7], [x3], #2\n" - "31:\n" - - "add %[dst_ptr], %[dst_ptr], #8\n" - - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" - - // Since the store type is the same as the accum type, no need for - // downcast. There's also no need for clamp by min/max. - - // At this point, v20 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(v20) - RUY_MAKE_ZERO(v21) - RUY_MAKE_ZERO(v22) - RUY_MAKE_ZERO(v23) - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // Compute how much of the 4x4 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x4, there are some 4x4 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #4\n" - "cmp w1, #4\n" - // Compute w1 = how many rows of the 4x4 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #4\n" - // Compute w2 = how many cols of the 4x4 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==4 && w2 == 4, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - "mov x4, %[dst_ptr]\n" - // Yes, all of the 4x4 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x4 block fits. - // Store to dst_tmp_buf - "str q16, [%[dst_tmp_buf], #0]\n" - "str q17, [%[dst_tmp_buf], #16]\n" - "str q18, [%[dst_tmp_buf], #32]\n" - "str q19, [%[dst_tmp_buf], #48]\n" - // Slow loop copying from dst_tmp_buf to dst. - "mov x3, %[dst_tmp_buf]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldr w7, [x3, x5, lsl #2]\n" - "str w7, [x4, x5, lsl #2]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "add w6, w6, #1\n" - "add x3, x3, #16\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 50b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x4 block fits. - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.s}[0], [x3], #4\n" - "add x4, x4, x11\n" - "st1 {v16.s}[1], [x3], #4\n" - "st1 {v16.s}[2], [x3], #4\n" - "st1 {v16.s}[3], [x3], #4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v17.s}[0], [x3], #4\n" - "add x4, x4, x11\n" - "st1 {v17.s}[1], [x3], #4\n" - "st1 {v17.s}[2], [x3], #4\n" - "st1 {v17.s}[3], [x3], #4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v18.s}[0], [x3], #4\n" - "add x4, x4, x11\n" - "st1 {v18.s}[1], [x3], #4\n" - "st1 {v18.s}[2], [x3], #4\n" - "st1 {v18.s}[3], [x3], #4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v19.s}[0], [x3], #4\n" - "add x4, x4, x11\n" - "st1 {v19.s}[1], [x3], #4\n" - "st1 {v19.s}[2], [x3], #4\n" - "st1 {v19.s}[3], [x3], #4\n" - "31:\n" - - "add %[dst_ptr], %[dst_ptr], #16\n" - - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - - RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" - - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - "smull v8.8h, v0.8b, v4.8b\n" - "smull v9.8h, v1.8b, v4.8b\n" - "smull v10.8h, v2.8b, v4.8b\n" - "smull v11.8h, v3.8b, v4.8b\n" - "smull v12.8h, v0.8b, v5.8b\n" - "smull v13.8h, v1.8b, v5.8b\n" - "smull v14.8h, v2.8b, v5.8b\n" - "smull v15.8h, v3.8b, v5.8b\n" - "smlal2 v8.8h, v0.16b, v4.16b\n" - "smlal2 v9.8h, v1.16b, v4.16b\n" - "smlal2 v10.8h, v2.16b, v4.16b\n" - "smlal2 v11.8h, v3.16b, v4.16b\n" - "smlal2 v12.8h, v0.16b, v5.16b\n" - "smlal2 v13.8h, v1.16b, v5.16b\n" - "smlal2 v14.8h, v2.16b, v5.16b\n" - "smlal2 v15.8h, v3.16b, v5.16b\n" - - // Reload some params --- we had used x5 -- x7 for a few other things - // since the last time we had loaded them. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - - // Move to the next block of the destination matrix, for the next iter - // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already - // been updated earlier. - // Have we reached the end row? - "cmp %w[row], w7\n" - "beq 20f\n" // yes, end row. - // Not end row. Move to the next row. - "add %w[row], %w[row], #4\n" - "b 21f\n" - "20:\n" - // Was already at end row. - "mov %w[row], w6\n" // Move back to first row. - "add %w[col], %w[col], #4\n" // Move to the next column. - "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #2\n" - "mov %[dst_ptr], %[dst_col_ptr]\n" - "21:\n" - - // Main loop exit condition: have we hit the end column? - "cmp %w[col], w8\n" - - // w1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 4. - "mov w1, #16\n" - - "ble 1b\n" - - // clang-format on - - : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), - [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), - [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) - : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), - [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf), - [dst_type_id] "r"(params.dst_type_id) - : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", - "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", - "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", - "v26", "v27", "v28", "v29", "v30", "v31"); -} - -// Similar to existing Kernel8bitNeonOutOfOrder but specialized for the case of -// RHS cols == 1. -// Relevant target CPUs for this kernel include ARM Cortex-A73 and Cortex-A75, -// since these are 64-bit, out-of-order and without dotprod support. -void Kernel8bitNeonOutOfOrder1Col(const KernelParams8bit<4, 4>& params) { - profiler::ScopeLabel label( - "Kernel (kNeon, optimized for out-of-order cores)"); - - CheckOffsetsInKernelParams8bit(params); - - const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; - const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; - void* dst_col_ptr = params.dst_base_ptr; - void* dst_ptr = dst_col_ptr; - int row = params.start_row; - int col = params.start_col; - - // The asm kernel below has the following NEON register allocation: - // - // v16 -- v19 are int32 accumulators. - // During accumulation, v0 -- v3 are used to load int8 data from LHS and - // v4 from RHS: - // - // int8 RHS 16x1 block - // /-----------\ - // |v4.b[0] | - // | ... | - // |v4.b[15] | - // \-----------/ - // int8 LHS 4x16 block - // /---------------------\ /-----------\ - // |v0.b[0] ... v0.b[15] | |v16.4s | - // |v1.b[0] ... v1.b[15] | |v17.4s | - // |v2.b[0] ... v2.b[15] | |v18.4s | - // |v3.b[0] ... v3.b[15] | |v19.4s | - // \---------------------/ \-----------/ - // int32 accumulators 4x1 block - // - // No attempt had been made so far at implementing the RUY_OPT_MAX_STREAMING - // optimization for this kernel. - asm volatile( -#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" - - // clang-format off - - // Load some parameters into registers. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" - "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" - "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" - - // Load the first 64 bytes of LHS and RHS data. - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" - "add %[rhs_ptr], %[rhs_ptr], #48\n" - - // Clear accumulators. - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - - // w1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 16. - "mov w1, #16\n" - - // Perform the first few multiply-adds on the data that we have already - // loaded. - "smull v8.8h, v0.8b, v4.8b\n" - "smull v9.8h, v1.8b, v4.8b\n" - "smull v10.8h, v2.8b, v4.8b\n" - "smull v11.8h, v3.8b, v4.8b\n" - - // Multiply-accumulate second-half, again into the same - // 16bit local accumulator registers. This is where we - // take advantage of having int8 instead of uint8 and therefore - // being able to accumulate two products into int16. - "smlal2 v8.8h, v0.16b, v4.16b\n" - "smlal2 v9.8h, v1.16b, v4.16b\n" - "smlal2 v10.8h, v2.16b, v4.16b\n" - "smlal2 v11.8h, v3.16b, v4.16b\n" - - // Main loop of the whole GEMM, over rows and columns of the - // destination matrix. - "1:\n" - - // Reminder - w1 is how many levels of depth we have already loaded - // data for, w12 is the total depth. - "cmp w1, w12\n" - "beq 79f\n" - - "2:\n" - - // Some multiplications and 16-bit accumulation were already done above, - // so we start right away in the middle. - "sadalp v16.4s, v8.8h\n" - "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" - "add %[rhs_ptr], %[rhs_ptr], #48\n" - "sadalp v17.4s, v9.8h\n" - "sadalp v18.4s, v10.8h\n" - "sadalp v19.4s, v11.8h\n" - - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" - - "smull v8.8h, v0.8b, v4.8b\n" - "smull v9.8h, v1.8b, v4.8b\n" - "smull v10.8h, v2.8b, v4.8b\n" - "smull v11.8h, v3.8b, v4.8b\n" - - // Multiply-accumulate second-half, again into the same - // 16bit local accumulator registers. This is where we - // take advantage of having int8 instead of uint8 and therefore - // being able to accumulate two products into int16. - "smlal2 v8.8h, v0.16b, v4.16b\n" - "smlal2 v9.8h, v1.16b, v4.16b\n" - "smlal2 v10.8h, v2.16b, v4.16b\n" - "smlal2 v11.8h, v3.16b, v4.16b\n" - - // Each iteration of this loop advances by 16 levels of depth. - "add w1, w1, #16\n" - - // Loop termination condition - "cmp w1, w12\n" - - "blt 2b\n" - - "79:\n" - - "sadalp v16.4s, v8.8h\n" - "sadalp v17.4s, v9.8h\n" - "sadalp v18.4s, v10.8h\n" - "sadalp v19.4s, v11.8h\n" - - // End of accumulation. The registers v16 -- v19 contain the final - // int32 accumulator values of the current 4x1 destination block. - // We now have to compute the final 8-bit values from these int32 - // accumulators, and advance to the next 4x1 block. We intertwine - // these two aspects whenever possible for optimal pipelining, both - // at the data flow level (prefetch data for next block as early as - // possible) and instruction pipelining level (some of the next-block - // work can dual-issue with some of the final work on the current - // block). - - // Reduce 32bit accumulators horizontally. - "addp v16.4s, v16.4s, v17.4s\n" - "addp v18.4s, v18.4s, v19.4s\n" - - // Reduce 32bit accumulators horizontally, second pass - // (each pass adds pairwise. we need to add 4-wise). - "addp v16.4s, v16.4s, v18.4s\n" - - // Logic to advance to the next block in preparation for the next - // iteration of the main loop. For now, we only want to compute - // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are - // not yet ready to update the values of row and col, as we still need - // the current values for the rest of the work on the current block. - - "cmp %w[row], w7\n" // Have we finished the last row? - "bge 4f\n" // If finished last row, go to 4 - // Not finished last row: then advance to next row. - "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #2\n" - "b 5f\n" - "4:\n" // Finished last row... - "mov %[lhs_col_ptr], x5\n" // Go back to first row - // Now we need to advance to the next column. If we already - // finished the last column, then in principle we are done, however - // we can't just return here, as we need to allow the end work of the - // current block to complete. The good news is that at this point it - // doesn't matter what data we load for the next column, since - // we will exit from the main loop below before actually storing - // anything computed from that data. - "cmp %w[col], w8\n" // Have we finished the last column? - "bge 5f\n" // If yes, just carry on without updating the column pointer. - // Not finished last column: then advance to next column. - // (still multiply column stride by 4 due to packing) - "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #2\n" - "5:\n" - - // Set the LHS and RHS data pointers to the start of the columns just - // computed. - "mov %[lhs_ptr], %[lhs_col_ptr]\n" - "mov %[rhs_ptr], %[rhs_col_ptr]\n" - - // Load some parameters needed for the end work on current block. - RUY_MAKE_ZERO(v8) - "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" - "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" - "ins v13.h[4], w4\n" // dst_zero_point - "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" - "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "dup v9.4s, w3\n" // create prod_zp_depth_vec - "add x5, x4, %x[row], lsl #2\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "csel x4, x4, x5, eq\n" - - "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint - - // Now we load: bias data, LHS sums data, RHS sums data. - - // First, load the base pointers from the params. - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" - - "add x5, x1, %x[row], lsl #2\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" - "csel x1, x1, x5, eq\n" - - // Load 4 bias values. - "ld1 {v14.4s}, [x1]\n" - - // Now that we know what LHS and RHS data the next iteration of the - // main loop will need to load, we start loading the first 32 bytes of - // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore - // in the rest of the work on the current block. - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" - "add %[rhs_ptr], %[rhs_ptr], #48\n" - - // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), - // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "add v14.4s, v14.4s, v9.4s\n" - - // Perform the bias-addition (per the above, we have just folded into - // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) - // (all four 32-bit accumulators are in v16 at this point) - "add v16.4s, v16.4s, v14.4s\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" - "beq 401f\n" - "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" - "add x3, x3, %x[col], lsl #2\n" - "ld1 {v14.4s}, [x3]\n" - "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" - "dup v10.4s, w5\n" // create lhs_zero_point_vec - // Subtract rhs_sums * lhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "mls v16.4s, v10.4s, v14.s[0]\n" - "401:\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" - "beq 402f\n" - "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" - "add x2, x2, %x[row], lsl #2\n" - "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" - // Load 4 lhs_sums values. - "ld1 {v11.4s}, [x2]\n" - "ins v13.s[1], w5\n" // rhs_zero_point - // Compute lhs_sums * rhs_zero_point. - "mul v11.4s, v11.4s, v13.s[1]\n" - // Subtract lhs_sums * rhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "sub v16.4s, v16.4s, v11.4s\n" - - // If the destination is int32, it means the user asks for the raw - // accumulators, no need for us to downquantize the value. - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" - - "402:\n" - - // At this point we have computed the final int32 values. Now we - // start down-quantizing them to obtain the final 8bit values from them. - - // As part of this down-quantization, our int32 values will be - // multiplied by a multiplier that has a fixed-point component and an - // exponent component. - - //Load the exponent part of the multiplier. - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "add x5, x1, %x[row], lsl #2\n" - "csel x1, x1, x5, eq\n" - - "ld1 {v14.4s}, [x1]\n" - - "smax v12.4s, v14.4s, v8.4s\n" - - "sshl v16.4s, v16.4s, v12.4s\n" - - "smin v12.4s, v14.4s, v8.4s\n" - - // Apply the fixed-point part of the multiplier. - "sqrdmulh v16.4s, v16.4s, v15.4s\n" - - // We have some rounding division-by-power-of-two to do. This should - // always use "round to nearest". We allow for some - // freedom in how ties are broken, to strike a good compromise of - // performance on given hardware vs. perfect agreement of results - // across hardware. - // - // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation - // defined tie-breaks to help performance. On NEON, this means that we - // can just use the NEON rounding instructions, such as srshl. They - // happen to be breaking ties upward. - // - // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict - // break-ties-away-from zero, as described in Appendix B of - // https://arxiv.org/pdf/1712.05877.pdf - // When we wrote that, we thought that that would be better unbiased - // than the NEON upwards tie-breaks, and we had observed some - // improvement on some model. However, that is only more unbiased for - // data centered at zero, which was likely the case in that model, - // but is not always the case. If we wanted something more consistently - // unbiased then we should try breaking ties toward-nearest-even. -#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) - // Fix up values to be right-shifted, so that the (round to nearest, - // break ties upward) behavior of srshl applied to these fixed-up - // values, produces the same result as the desired (round to nearest, - // break ties away from zero) behavior on the original values. - "and v8.16b, v16.16b, v12.16b\n" - "sshr v8.4s, v8.4s, #31\n" - "sqadd v16.4s, v16.4s, v8.4s\n" -#endif - // At this point we have reduced the problem of correctly implementing - // rounding divide-by-power-of-two, to what the SRSHL instruction can - // do. - "srshl v16.4s, v16.4s, v12.4s\n" - - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" - - // Cast-and-saturate from int32 to int16 - // After this instruction, all data is in lower half (64-bits) of v16 - "sqxtn v16.4h, v16.4s\n" - - // At this point, v18 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - - // Add the destination zero point - "dup v14.8h, v13.h[4]\n" - "add v16.8h, v16.8h, v14.8h\n" - - // Cast-and-saturate from int16 to uint8 - // Now all data is in the first 32-bits of v16 - "sqxtun v16.8b, v16.8h\n" - - // Load the clamp_min, clamp_max bounds - "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.16b, w2\n" // clamp_min - "dup v15.16b, w3\n" // clamp_max - - // Apply the clamp_min bound - "umax v16.16b, v16.16b, v14.16b\n" - // Apply the clamp_max bound - "umin v16.16b, v16.16b, v15.16b\n" - - // Compute how much of the 4x1 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x1, there are some 4x1 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "mov w3, #4\n" - "cmp w1, #4\n" - // Compute w1 = how many rows of the 4x1 block fit - "csel w1, w1, w3, le\n" - - // Test if w1==4, i.e. if all of the 4x1 block fits. - "cmp w1, w3\n" - - "mov x4, %[dst_ptr]\n" - // Yes, all of the 4x1 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x1 block fits. - // Store to dst_tmp_buf - "st1 {v16.16b}, [%[dst_tmp_buf]]\n" - // Slow loop copying from dst_tmp_buf to dst. - "mov x3, %[dst_tmp_buf]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldrb w7, [x3, w5, uxtw]\n" - "strb w7, [x4, w5, uxtw]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x1 block fits. - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[0], [x3], #1\n" - "st1 {v16.b}[1], [x3], #1\n" - "st1 {v16.b}[2], [x3], #1\n" - "st1 {v16.b}[3], [x3], #1\n" - "31:\n" - - "add %[dst_ptr], %[dst_ptr], #4\n" - - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" - - // Cast-and-saturate from int32 to int16 - // After this, all values for output are in the lower half (64 bits) of v16. - "sqxtn v16.4h, v16.4s\n" - - // At this point, v18 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - - // Add the destination zero point - "dup v14.8h, v13.h[4]\n" - "add v16.8h, v16.8h, v14.8h\n" - - // Cast-and-saturate from int16 to int8 - "sqxtn v16.8b, v16.8h\n" - // At this point, we only need 4 lowest 8-bit values in v16. - - // Load the clamp_min, clamp_max bounds - "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.16b, w2\n" // clamp_min - "dup v15.16b, w3\n" // clamp_max - - // Apply the clamp_min bound - "smax v16.16b, v16.16b, v14.16b\n" - // Apply the clamp_max bound - "smin v16.16b, v16.16b, v15.16b\n" - - // Compute how much of the 4x4 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x4, there are some 4x4 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #4\n" - "cmp w1, #4\n" - // Compute w1 = how many rows of the 4x1 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #4\n" - - // Test if w1==4, i.e. if all of the 4x1 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - "mov x4, %[dst_ptr]\n" - // Yes, all of the 4x1 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x4 block fits. - // Store to dst_tmp_buf - "st1 {v16.16b}, [%[dst_tmp_buf]]\n" - // Slow loop copying from dst_tmp_buf to dst. - "mov x3, %[dst_tmp_buf]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldrb w7, [x3, w5, uxtw]\n" - "strb w7, [x4, w5, uxtw]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x4 block fits. - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[0], [x3], #1\n" - "st1 {v16.b}[1], [x3], #1\n" - "st1 {v16.b}[2], [x3], #1\n" - "st1 {v16.b}[3], [x3], #1\n" - "31:\n" - - "add %[dst_ptr], %[dst_ptr], #4\n" - - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" - - // Add the destination zero point - "dup v14.4h, v13.h[4]\n" - "saddw v16.4s, v16.4s, v14.4h\n" - - // Cast-and-saturate from int32 to int16 - // After this instruction, all data is in lower half of v16. - "sqxtn v16.4h, v16.4s\n" - - // At this point, v18 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - - // Load the clamp_min, clamp_max bounds - "ldrh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.8h, w2\n" // clamp_min - "dup v15.8h, w3\n" // clamp_max - - // Apply the clamp_min bound - "smax v16.8h, v16.8h, v14.8h\n" - // Apply the clamp_max bound - "smin v16.8h, v16.8h, v15.8h\n" - - // Compute how much of the 4x4 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x4, there are some 4x4 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #4\n" - "cmp w1, #4\n" - // Compute w1 = how many rows of the 4x4 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #4\n" - - // Test if w1==4 && w2 == 4, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "mov x4, %[dst_ptr]\n" - // Yes, all of the 4x4 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x4 block fits. - // Store to dst_tmp_buf - "str q16, [%[dst_tmp_buf], #0]\n" - // Slow loop copying from dst_tmp_buf to dst. - "mov x3, %[dst_tmp_buf]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldrh w7, [x3, x5, lsl #1]\n" - "strh w7, [x4, x5, lsl #1]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "blt 50b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x4 block fits. - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.h}[0], [x3], #2\n" - "st1 {v16.h}[1], [x3], #2\n" - "st1 {v16.h}[2], [x3], #2\n" - "st1 {v16.h}[3], [x3], #2\n" - "31:\n" - - "add %[dst_ptr], %[dst_ptr], #8\n" - - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" - - // Since the store type is the same as the accum type, no need for - // downcast. There's also no need for clamp by min/max. - - // Compute how much of the 4x4 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x4, there are some 4x4 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #4\n" - "cmp w1, #4\n" - // Compute w1 = how many rows of the 4x4 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #4\n" - - // Test if w1==4 i.e. if all of the 4x1 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - "mov x4, %[dst_ptr]\n" - // Yes, all of the 4x1 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x4 block fits. - // Store to dst_tmp_buf - "str q16, [%[dst_tmp_buf], #0]\n" - // Slow loop copying from dst_tmp_buf to dst. - "mov x3, %[dst_tmp_buf]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldr w7, [x3, x5, lsl #2]\n" - "str w7, [x4, x5, lsl #2]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x4 block fits. - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.s}[0], [x3], #4\n" - "st1 {v16.s}[1], [x3], #4\n" - "st1 {v16.s}[2], [x3], #4\n" - "st1 {v16.s}[3], [x3], #4\n" - "31:\n" - - "add %[dst_ptr], %[dst_ptr], #16\n" - - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - - RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" - - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - "smull v8.8h, v0.8b, v4.8b\n" - "smull v9.8h, v1.8b, v4.8b\n" - "smull v10.8h, v2.8b, v4.8b\n" - "smull v11.8h, v3.8b, v4.8b\n" - "smlal2 v8.8h, v0.16b, v4.16b\n" - "smlal2 v9.8h, v1.16b, v4.16b\n" - "smlal2 v10.8h, v2.16b, v4.16b\n" - "smlal2 v11.8h, v3.16b, v4.16b\n" - - // Reload some params --- we had used x5 -- x7 for a few other things - // since the last time we had loaded them. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - - // Move to the next block of the destination matrix, for the next iter - // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already - // been updated earlier. - // Have we reached the end row? - "cmp %w[row], w7\n" - "beq 20f\n" // yes, end row. - // Not end row. Move to the next row. - "add %w[row], %w[row], #4\n" - "b 21f\n" - "20:\n" - // Was already at end row. - "mov %w[row], w6\n" // Move back to first row. - "add %w[col], %w[col], #4\n" // Move to the next column. - "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #2\n" - "mov %[dst_ptr], %[dst_col_ptr]\n" - "21:\n" - - // Main loop exit condition: have we hit the end column? - "cmp %w[col], w8\n" - - // w1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 16. - "mov w1, #16\n" - - "ble 1b\n" - - // clang-format on - - : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), - [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), - [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) - : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), - [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf), - [dst_type_id] "r"(params.dst_type_id) - : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", - "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", - "v13", "v14", "v15", "v16", "v17", "v18", "v19"); -} - -// Variant of the above Kernel8bitNeonOutOfOrder, tuned for in-order CPUs. -// Specifically here, the relevant in-order CPUs are ARM Cortex-A53 and -// the original Cortex-A55, since these are 64-bit and do not support dotprod. -// -// While this kernel does not have a direct equivalent in gemmlowp, it was -// developed based on insights that David Mansell at ARM shared with their -// contribution of gemmlowp kernels tuned for Cortex-A53, with very helpful -// comments. Specifically, see this comment about tuning for Cortex-A53: -// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4215 -void Kernel8bitNeonInOrder(const KernelParams8bit<4, 4>& params) { - profiler::ScopeLabel label("Kernel (kNeon, optimized for in-order cores)"); - - CheckOffsetsInKernelParams8bit(params); - - const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; - const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; - void* dst_col_ptr = params.dst_base_ptr; - void* dst_ptr = dst_col_ptr; - int row = params.start_row; - int col = params.start_col; - - // The asm kernel below has the following NEON register allocation: - // - // v16 -- v31 are int32 accumulators. - // During accumulation, v0 -- v3 are used to load int8 data from LHS and - // v4 -- v7 from RHS: - // - // int8 RHS 16x4 block - // /-----------------------------------------\ - // |v4.b[0] ... v7.b[0] | - // | ... ... | - // |v4.b[15] ... v7.b[15] | - // \-----------------------------------------/ - // int8 LHS 4x16 block - // /---------------------\ /-----------------------------------------\ - // |v0.b[0] ... v0.b[15] | |v16.4s ... v28.4s | - // |v1.b[0] ... v1.b[15] | |v17.4s ... v29.4s | - // |v2.b[0] ... v2.b[15] | |v18.4s ... v30.4s | - // |v3.b[0] ... v3.b[15] | |v19.4s ... v31.4s | - // \---------------------/ \-----------------------------------------/ - // int32 accumulators 4x4 block - asm volatile( -#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" - - // clang-format off - - // Load some parameters into registers. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - RUY_MAKE_ZERO(v16) - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - RUY_MAKE_ZERO(v17) - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - RUY_MAKE_ZERO(v18) - "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - RUY_MAKE_ZERO(v19) - "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" - RUY_MAKE_ZERO(v20) - "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" - RUY_MAKE_ZERO(v21) - "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - RUY_MAKE_ZERO(v22) - "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" - RUY_MAKE_ZERO(v23) - - // Load the first 64 bytes of LHS and RHS data. - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - RUY_MAKE_ZERO(v24) - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - RUY_MAKE_ZERO(v25) - "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" - RUY_MAKE_ZERO(v26) - "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" - RUY_MAKE_ZERO(v27) - "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" - RUY_MAKE_ZERO(v28) - "ld1 {v5.16b}, [%[rhs_ptr]], #16\n" - RUY_MAKE_ZERO(v29) - "ld1 {v6.16b}, [%[rhs_ptr]], #16\n" - RUY_MAKE_ZERO(v30) - "ld1 {v7.16b}, [%[rhs_ptr]], #16\n" - RUY_MAKE_ZERO(v31) - - - // w1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 16. - "mov w1, #16\n" - - // Perform the first few multiply-adds on the data that we have already - // loaded. - "smull v8.8h, v0.8b, v4.8b\n" - "smull v9.8h, v1.8b, v4.8b\n" - "smull v10.8h, v2.8b, v4.8b\n" - "smull v11.8h, v3.8b, v4.8b\n" - "smull v12.8h, v0.8b, v5.8b\n" - "smull v13.8h, v1.8b, v5.8b\n" - "smull v14.8h, v2.8b, v5.8b\n" - "smull v15.8h, v3.8b, v5.8b\n" - - // Multiply-accumulate second-half, again into the same - // 16bit local accumulator registers. This is where we - // take advantage of having int8 instead of uint8 and therefore - // being able to accumulate two products into int16. - "smlal2 v8.8h, v0.16b, v4.16b\n" - "smlal2 v9.8h, v1.16b, v4.16b\n" - "smlal2 v10.8h, v2.16b, v4.16b\n" - "smlal2 v11.8h, v3.16b, v4.16b\n" - "smlal2 v12.8h, v0.16b, v5.16b\n" - "smlal2 v13.8h, v1.16b, v5.16b\n" - "smlal2 v14.8h, v2.16b, v5.16b\n" - "smlal2 v15.8h, v3.16b, v5.16b\n" - - - // Main loop of the whole GEMM, over rows and columns of the - // destination matrix. - "1:\n" - - // Reminder - w1 is how many levels of depth we have already loaded - // data for, w12 is the total depth. - "cmp w1, w12\n" - "beq 79f\n" - - "2:\n" - - // Some multiplications and 16-bit accumulation were already done above, - // so we start right away in the middle. - "sadalp v16.4s, v8.8h\n" - "ldr d4, [%[rhs_ptr], #0]\n" - "smull v8.8h, v0.8b, v6.8b\n" - "ldr x7, [%[rhs_ptr], #8]\n" - "sadalp v17.4s, v9.8h\n" - "ldr d5, [%[rhs_ptr], #16]\n" - "smull v9.8h, v1.8b, v6.8b\n" - "ldr x8, [%[rhs_ptr], #24]\n" - "sadalp v18.4s, v10.8h\n" - "smull v10.8h, v2.8b, v6.8b\n" - "sadalp v19.4s, v11.8h\n" - "add %[lhs_ptr], %[lhs_ptr], #64\n" - "smull v11.8h, v3.8b, v6.8b\n" - "add %[rhs_ptr], %[rhs_ptr], #64\n" - "sadalp v20.4s, v12.8h\n" - // Each iteration of this loop advances by 16 levels of depth. - "add w1, w1, #16\n" - "smull v12.8h, v0.8b, v7.8b\n" - // Loop termination condition - "cmp w1, w12\n" - "sadalp v21.4s, v13.8h\n" - "ldr x3, [%[lhs_ptr], #-56]\n" - "smull v13.8h, v1.8b, v7.8b\n" - "ldr x4, [%[lhs_ptr], #-40]\n" - "sadalp v22.4s, v14.8h\n" - "ldr x5, [%[lhs_ptr], #-24]\n" - "smull v14.8h, v2.8b, v7.8b\n" - "ldr x6, [%[lhs_ptr], #-8]\n" - "sadalp v23.4s, v15.8h\n" - "smull v15.8h, v3.8b, v7.8b\n" - - // Multiply-accumulate second-half, again into the same - // 16bit local accumulator registers. This is where we - // take advantage of having int8 instead of uint8 and therefore - // being able to accumulate two products into int16. - "smlal2 v8.8h, v0.16b, v6.16b\n" - "smlal2 v9.8h, v1.16b, v6.16b\n" - "smlal2 v10.8h, v2.16b, v6.16b\n" - "ldr x9, [%[rhs_ptr], #-24]\n" - "smlal2 v11.8h, v3.16b, v6.16b\n" - "ldr d6, [%[rhs_ptr], #-32]\n" - "smlal2 v12.8h, v0.16b, v7.16b\n" - "ldr d0, [%[lhs_ptr], #-64]\n" - "smlal2 v13.8h, v1.16b, v7.16b\n" - "ldr d1, [%[lhs_ptr], #-48]\n" - "smlal2 v14.8h, v2.16b, v7.16b\n" - "ins v4.d[1], x7\n" - "smlal2 v15.8h, v3.16b, v7.16b\n" - "ins v5.d[1], x8\n" - - "ldr d2, [%[lhs_ptr], #-32]\n" - "ins v0.d[1], x3\n" - "sadalp v24.4s, v8.8h\n" - "ldr d3, [%[lhs_ptr], #-16]\n" - "ins v1.d[1], x4\n" - "smull v8.8h, v0.8b, v4.8b\n" - "ins v2.d[1], x5\n" - "sadalp v25.4s, v9.8h\n" - "ins v3.d[1], x6\n" - "smull v9.8h, v1.8b, v4.8b\n" - "ldr d7, [%[rhs_ptr], #-16]\n" - "sadalp v26.4s, v10.8h\n" - "ldr x10, [%[rhs_ptr], #-8]\n" - "smull v10.8h, v2.8b, v4.8b\n" - "sadalp v27.4s, v11.8h\n" - "smull v11.8h, v3.8b, v4.8b\n" - "sadalp v28.4s, v12.8h\n" - "smull v12.8h, v0.8b, v5.8b\n" - "sadalp v29.4s, v13.8h\n" - "smull v13.8h, v1.8b, v5.8b\n" - "sadalp v30.4s, v14.8h\n" - "smull v14.8h, v2.8b, v5.8b\n" - "sadalp v31.4s, v15.8h\n" - "smull v15.8h, v3.8b, v5.8b\n" - - // Multiply-accumulate second-half, again into the same - // 16bit local accumulator registers. This is where we - // take advantage of having int8 instead of uint8 and therefore - // being able to accumulate two products into int16. - "smlal2 v8.8h, v0.16b, v4.16b\n" - "smlal2 v9.8h, v1.16b, v4.16b\n" - "smlal2 v10.8h, v2.16b, v4.16b\n" - "smlal2 v11.8h, v3.16b, v4.16b\n" - - "smlal2 v12.8h, v0.16b, v5.16b\n" - "smlal2 v13.8h, v1.16b, v5.16b\n" - "ins v6.d[1], x9\n" - "smlal2 v14.8h, v2.16b, v5.16b\n" - "ins v7.d[1], x10\n" - "smlal2 v15.8h, v3.16b, v5.16b\n" - - "blt 2b\n" - - "79:\n" - - "sadalp v16.4s, v8.8h\n" - "smull v8.8h, v0.8b, v6.8b\n" - "sadalp v17.4s, v9.8h\n" - "smull v9.8h, v1.8b, v6.8b\n" - "sadalp v18.4s, v10.8h\n" - "smull v10.8h, v2.8b, v6.8b\n" - "sadalp v19.4s, v11.8h\n" - "smull v11.8h, v3.8b, v6.8b\n" - "sadalp v20.4s, v12.8h\n" - "smull v12.8h, v0.8b, v7.8b\n" - "sadalp v21.4s, v13.8h\n" - "smull v13.8h, v1.8b, v7.8b\n" - "sadalp v22.4s, v14.8h\n" - "smull v14.8h, v2.8b, v7.8b\n" - "sadalp v23.4s, v15.8h\n" - "smull v15.8h, v3.8b, v7.8b\n" - - // Multiply-accumulate second-half, again into the same - // 16bit local accumulator registers. This is where we - // take advantage of having int8 instead of uint8 and therefore - // being able to accumulate two products into int16. - "smlal2 v8.8h, v0.16b, v6.16b\n" - "smlal2 v9.8h, v1.16b, v6.16b\n" - "smlal2 v10.8h, v2.16b, v6.16b\n" - "smlal2 v11.8h, v3.16b, v6.16b\n" - - "smlal2 v12.8h, v0.16b, v7.16b\n" - "smlal2 v13.8h, v1.16b, v7.16b\n" - "smlal2 v14.8h, v2.16b, v7.16b\n" - "smlal2 v15.8h, v3.16b, v7.16b\n" - - "sadalp v24.4s, v8.8h\n" - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "sadalp v25.4s, v9.8h\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "sadalp v26.4s, v10.8h\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - "sadalp v27.4s, v11.8h\n" - "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "sadalp v28.4s, v12.8h\n" - "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" - "sadalp v29.4s, v13.8h\n" - "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" - "sadalp v30.4s, v14.8h\n" - "sadalp v31.4s, v15.8h\n" - - // End of accumulation. The registers v16 -- v31 contain the final - // int32 accumulator values of the current 4x4 destination block. - // We now have to compute the final 8-bit values from these int32 - // accumulators, and advance to the next 4x4 block. We intertwine - // these two aspects whenever possible for optimal pipelining, both - // at the data flow level (prefetch data for next block as early as - // possible) and instruction pipelining level (some of the next-block - // work can dual-issue with some of the final work on the current - // block). - - // Reduce 32bit accumulators horizontally. - "addp v16.4s, v16.4s, v17.4s\n" - "addp v18.4s, v18.4s, v19.4s\n" - "addp v20.4s, v20.4s, v21.4s\n" - "addp v22.4s, v22.4s, v23.4s\n" - "addp v24.4s, v24.4s, v25.4s\n" - "addp v26.4s, v26.4s, v27.4s\n" - "addp v28.4s, v28.4s, v29.4s\n" - "addp v30.4s, v30.4s, v31.4s\n" - - // Reduce 32bit accumulators horizontally, second pass - // (each pass adds pairwise. we need to add 4-wise). - "addp v16.4s, v16.4s, v18.4s\n" - "addp v17.4s, v20.4s, v22.4s\n" - "addp v18.4s, v24.4s, v26.4s\n" - "addp v19.4s, v28.4s, v30.4s\n" - - // Logic to advance to the next block in preparation for the next - // iteration of the main loop. For now, we only want to compute - // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are - // not yet ready to update the values of row and col, as we still need - // the current values for the rest of the work on the current block. - - "cmp %w[row], w7\n" // Have we finished the last row? - "bge 4f\n" // If finished last row, go to 4 - // Not finished last row: then advance to next row. - "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #2\n" - "b 5f\n" - "4:\n" // Finished last row... - "mov %[lhs_col_ptr], x5\n" // Go back to first row - // Now we need to advance to the next column. If we already - // finished the last column, then in principle we are done, however - // we can't just return here, as we need to allow the end work of the - // current block to complete. The good news is that at this point it - // doesn't matter what data we load for the next column, since - // we will exit from the main loop below before actually storing - // anything computed from that data. - "cmp %w[col], w8\n" // Have we finished the last column? - "bge 5f\n" // If yes, just carry on without updating the column pointer. - // Not finished last column: then advance to next column. - "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #2\n" - "5:\n" - - // Set the LHS and RHS data pointers to the start of the columns just - // computed. - "mov %[lhs_ptr], %[lhs_col_ptr]\n" - "mov %[rhs_ptr], %[rhs_col_ptr]\n" - - // Load some parameters needed for the end work on current block. - RUY_MAKE_ZERO(v8) - "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" - "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" - "ins v13.h[4], w4\n" // dst_zero_point - "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" - "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "dup v9.4s, w3\n" // create prod_zp_depth_vec - "add x5, x4, %x[row], lsl #2\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "csel x4, x4, x5, eq\n" - - "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint - - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" - "add x5, x1, %x[row], lsl #2\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" - "csel x1, x1, x5, eq\n" - - // Load 4 bias values. - "ld1 {v14.4s}, [x1]\n" - - // Now that we know what LHS and RHS data the next iteration of the - // main loop will need to load, we start loading the first 32 bytes of - // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore - // in the rest of the work on the current block. - - // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), - // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "add v14.4s, v14.4s, v9.4s\n" - "ldr d0, [%[lhs_ptr], #0]\n" - - // Perform the bias-addition (per the above, we have just folded into - // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) - "add v16.4s, v16.4s, v14.4s\n" - "ldr d1, [%[lhs_ptr], #16]\n" - "add v17.4s, v17.4s, v14.4s\n" - "ldr d2, [%[lhs_ptr], #32]\n" - "add v18.4s, v18.4s, v14.4s\n" - "ldr d3, [%[lhs_ptr], #48]\n" - "add v19.4s, v19.4s, v14.4s\n" - "ldr d4, [%[rhs_ptr], #0]\n" - "ldr d5, [%[rhs_ptr], #16]\n" - "ldr d6, [%[rhs_ptr], #32]\n" - "ldr d7, [%[rhs_ptr], #48]\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" - "beq 401f\n" - "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" - "add x3, x3, %x[col], lsl #2\n" - "ld1 {v14.4s}, [x3]\n" - "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" - "dup v10.4s, w5\n" // create lhs_zero_point_vec - // Subtract rhs_sums * lhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "mls v16.4s, v10.4s, v14.s[0]\n" - "mls v17.4s, v10.4s, v14.s[1]\n" - "mls v18.4s, v10.4s, v14.s[2]\n" - "mls v19.4s, v10.4s, v14.s[3]\n" - "401:\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" - "beq 402f\n" - "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" - "add x2, x2, %x[row], lsl #2\n" - "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" - // Load 4 lhs_sums values. - "ld1 {v11.4s}, [x2]\n" - "ins v13.s[1], w5\n" // rhs_zero_point - // Compute lhs_sums * rhs_zero_point. - "mul v11.4s, v11.4s, v13.s[1]\n" - // Subtract lhs_sums * rhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "sub v16.4s, v16.4s, v11.4s\n" - "sub v17.4s, v17.4s, v11.4s\n" - "sub v18.4s, v18.4s, v11.4s\n" - "sub v19.4s, v19.4s, v11.4s\n" - - // If the destination is int32, it means the user asks for the raw - // accumulators, no need for us to downquantize the value. - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" - - "402:\n" - - // At this point we have computed the final int32 values. Now we - // start down-quantizing them to obtain the final 8bit values from them. - - // As part of this down-quantization, our int32 values will be - // multiplied by a multiplier that has a fixed-point component and an - // exponent component. - - - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "add x5, x1, %x[row], lsl #2\n" - "csel x1, x1, x5, eq\n" - - "ld1 {v14.4s}, [x1]\n" - - "smax v12.4s, v14.4s, v8.4s\n" - "ldr x1, [%[lhs_ptr], #8]\n" - - "sshl v16.4s, v16.4s, v12.4s\n" - "ldr x2, [%[lhs_ptr], #24]\n" - "sshl v17.4s, v17.4s, v12.4s\n" - "ldr x3, [%[lhs_ptr], #40]\n" - "sshl v18.4s, v18.4s, v12.4s\n" - "ldr x4, [%[lhs_ptr], #56]\n" - "sshl v19.4s, v19.4s, v12.4s\n" - - "smin v12.4s, v14.4s, v8.4s\n" - - // Apply the fixed-point part of the multiplier. - "ins v0.d[1], x1\n" - "ldr x1, [%[rhs_ptr], #8]\n" - "sqrdmulh v16.4s, v16.4s, v15.4s\n" - "ins v1.d[1], x2\n" - "ldr x2, [%[rhs_ptr], #24]\n" - "sqrdmulh v17.4s, v17.4s, v15.4s\n" - "ins v2.d[1], x3\n" - "ldr x3, [%[rhs_ptr], #40]\n" - "sqrdmulh v18.4s, v18.4s, v15.4s\n" - "ins v3.d[1], x4\n" - "ldr x4, [%[rhs_ptr], #56]\n" - "sqrdmulh v19.4s, v19.4s, v15.4s\n" - - // We have some rounding division-by-power-of-two to do. This should - // always use "round to nearest". We allow for some - // freedom in how ties are broken, to strike a good compromise of - // performance on given hardware vs. perfect agreement of results - // across hardware. - // - // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation - // defined tie-breaks to help performance. On NEON, this means that we - // can just use the NEON rounding instructions, such as srshl. They - // happen to be breaking ties upward. - // - // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict - // break-ties-away-from zero, as described in Appendix B of - // https://arxiv.org/pdf/1712.05877.pdf - // When we wrote that, we thought that that would be better unbiased - // than the NEON upwards tie-breaks, and we had observed some - // improvement on some model. However, that is only more unbiased for - // data centered at zero, which was likely the case in that model, - // but is not always the case. If we wanted something more consistently - // unbiased then we should try breaking ties toward-nearest-even. -#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) - // Fix up values to be right-shifted, so that the (round to nearest, - // break ties upward) behavior of srshl applied to these fixed-up - // values, produces the same result as the desired (round to nearest, - // break ties away from zero) behavior on the original values. - "and v8.16b, v16.16b, v12.16b\n" - "and v9.16b, v17.16b, v12.16b\n" - "and v14.16b, v18.16b, v12.16b\n" - "and v15.16b, v19.16b, v12.16b\n" - "sshr v8.4s, v8.4s, #31\n" - "sshr v9.4s, v9.4s, #31\n" - "sshr v14.4s, v14.4s, #31\n" - "sshr v15.4s, v15.4s, #31\n" - "sqadd v16.4s, v16.4s, v8.4s\n" - "sqadd v17.4s, v17.4s, v9.4s\n" - "sqadd v18.4s, v18.4s, v14.4s\n" - "sqadd v19.4s, v19.4s, v15.4s\n" -#endif - // At this point we have reduced the problem of correctly implementing - // rounding divide-by-power-of-two, to what the SRSHL instruction can - // do. - "srshl v16.4s, v16.4s, v12.4s\n" - "srshl v17.4s, v17.4s, v12.4s\n" - "srshl v18.4s, v18.4s, v12.4s\n" - "srshl v19.4s, v19.4s, v12.4s\n" - - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" - - "ins v4.d[1], x1\n" - "sqxtn v16.4h, v16.4s\n" - "ins v5.d[1], x2\n" - "sqxtn2 v16.8h, v17.4s\n" - "ins v6.d[1], x3\n" - "sqxtn v17.4h, v18.4s\n" - "ins v7.d[1], x4\n" - RUY_MAKE_ZERO(v18) - "sqxtn2 v17.8h, v19.4s\n" - - // At this point, v18 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(v19) - - // Add the destination zero point - "add %[lhs_ptr], %[lhs_ptr], #64\n" - "dup v14.8h, v13.h[4]\n" - RUY_MAKE_ZERO(v20) - "add %[rhs_ptr], %[rhs_ptr], #64\n" - "add v16.8h, v16.8h, v14.8h\n" - RUY_MAKE_ZERO(v21) - "add v17.8h, v17.8h, v14.8h\n" - RUY_MAKE_ZERO(v22) - - // Cast-and-saturate from int16 to uint8 - "sqxtun v16.8b, v16.8h\n" - RUY_MAKE_ZERO(v23) - "sqxtun2 v16.16b, v17.8h\n" - RUY_MAKE_ZERO(v24) - - // Load the clamp_min, clamp_max bounds - "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - RUY_MAKE_ZERO(v25) - "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - RUY_MAKE_ZERO(v26) - "dup v14.16b, w2\n" // clamp_min - RUY_MAKE_ZERO(v27) - "dup v15.16b, w3\n" // clamp_max - RUY_MAKE_ZERO(v28) - - // Apply the clamp_min bound - "umax v16.16b, v16.16b, v14.16b\n" - RUY_MAKE_ZERO(v29) - // Apply the clamp_max bound - "umin v16.16b, v16.16b, v15.16b\n" - RUY_MAKE_ZERO(v30) - - // Compute how much of the 4x4 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x4, there are some 4x4 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - RUY_MAKE_ZERO(v31) - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #4\n" - "cmp w1, #4\n" - // Compute w1 = how many rows of the 4x4 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #4\n" - // Compute w2 = how many cols of the 4x4 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==4 && w2 == 4, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - "mov x4, %[dst_ptr]\n" - // Yes, all of the 4x4 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x4 block fits. - // Store to dst_tmp_buf - "st1 {v16.16b}, [%[dst_tmp_buf]]\n" - // Slow loop copying from dst_tmp_buf to dst. - "mov x3, %[dst_tmp_buf]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldrb w7, [x3, w5, uxtw]\n" - "strb w7, [x4, w5, uxtw]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "add w6, w6, #1\n" - "add x3, x3, #4\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 50b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x4 block fits. - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[0], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[1], [x3], #1\n" - "st1 {v16.b}[2], [x3], #1\n" - "st1 {v16.b}[3], [x3], #1\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[4], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[5], [x3], #1\n" - "st1 {v16.b}[6], [x3], #1\n" - "st1 {v16.b}[7], [x3], #1\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[8], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[9], [x3], #1\n" - "st1 {v16.b}[10], [x3], #1\n" - "st1 {v16.b}[11], [x3], #1\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[12], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[13], [x3], #1\n" - "st1 {v16.b}[14], [x3], #1\n" - "st1 {v16.b}[15], [x3], #1\n" - "31:\n" - - "add %[dst_ptr], %[dst_ptr], #4\n" - - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" - - "ins v4.d[1], x1\n" - "sqxtn v16.4h, v16.4s\n" - "ins v5.d[1], x2\n" - "sqxtn2 v16.8h, v17.4s\n" - "ins v6.d[1], x3\n" - "sqxtn v17.4h, v18.4s\n" - "ins v7.d[1], x4\n" - RUY_MAKE_ZERO(v18) - "sqxtn2 v17.8h, v19.4s\n" - - // At this point, v18 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(v19) - - // Add the destination zero point - "add %[lhs_ptr], %[lhs_ptr], #64\n" - "dup v14.8h, v13.h[4]\n" - RUY_MAKE_ZERO(v20) - "add %[rhs_ptr], %[rhs_ptr], #64\n" - "add v16.8h, v16.8h, v14.8h\n" - RUY_MAKE_ZERO(v21) - "add v17.8h, v17.8h, v14.8h\n" - RUY_MAKE_ZERO(v22) - - // Cast-and-saturate from int16 to uint8 - "sqxtn v16.8b, v16.8h\n" - RUY_MAKE_ZERO(v23) - "sqxtn2 v16.16b, v17.8h\n" - RUY_MAKE_ZERO(v24) - - // Load the clamp_min, clamp_max bounds - "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - RUY_MAKE_ZERO(v25) - "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - RUY_MAKE_ZERO(v26) - "dup v14.16b, w2\n" // clamp_min - RUY_MAKE_ZERO(v27) - "dup v15.16b, w3\n" // clamp_max - RUY_MAKE_ZERO(v28) - - // Apply the clamp_min bound - "smax v16.16b, v16.16b, v14.16b\n" - RUY_MAKE_ZERO(v29) - // Apply the clamp_max bound - "smin v16.16b, v16.16b, v15.16b\n" - RUY_MAKE_ZERO(v30) - - // Compute how much of the 4x4 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x4, there are some 4x4 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - RUY_MAKE_ZERO(v31) - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #4\n" - "cmp w1, #4\n" - // Compute w1 = how many rows of the 4x4 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #4\n" - // Compute w2 = how many cols of the 4x4 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==4 && w2 == 4, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - "mov x4, %[dst_ptr]\n" - // Yes, all of the 4x4 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x4 block fits. - // Store to dst_tmp_buf - "st1 {v16.16b}, [%[dst_tmp_buf]]\n" - // Slow loop copying from dst_tmp_buf to dst. - "mov x3, %[dst_tmp_buf]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldrb w7, [x3, w5, uxtw]\n" - "strb w7, [x4, w5, uxtw]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "add w6, w6, #1\n" - "add x3, x3, #4\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 50b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x4 block fits. - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[0], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[1], [x3], #1\n" - "st1 {v16.b}[2], [x3], #1\n" - "st1 {v16.b}[3], [x3], #1\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[4], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[5], [x3], #1\n" - "st1 {v16.b}[6], [x3], #1\n" - "st1 {v16.b}[7], [x3], #1\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[8], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[9], [x3], #1\n" - "st1 {v16.b}[10], [x3], #1\n" - "st1 {v16.b}[11], [x3], #1\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[12], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[13], [x3], #1\n" - "st1 {v16.b}[14], [x3], #1\n" - "st1 {v16.b}[15], [x3], #1\n" - "31:\n" - - "add %[dst_ptr], %[dst_ptr], #4\n" - - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" - - // Add the destination zero point - "dup v14.4h, v13.h[4]\n" - "saddw v16.4s, v16.4s, v14.4h\n" - "saddw v17.4s, v17.4s, v14.4h\n" - "saddw v18.4s, v18.4s, v14.4h\n" - "saddw v19.4s, v19.4s, v14.4h\n" - - // Cast-and-saturate from int32 to int16 - "ins v4.d[1], x1\n" - "sqxtn v16.4h, v16.4s\n" - "ins v5.d[1], x2\n" - "sqxtn2 v16.8h, v17.4s\n" - "ins v6.d[1], x3\n" - "sqxtn v17.4h, v18.4s\n" - "ins v7.d[1], x4\n" - RUY_MAKE_ZERO(v18) - "sqxtn2 v17.8h, v19.4s\n" - - // At this point, v18 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(v19) - - "add %[lhs_ptr], %[lhs_ptr], #64\n" - RUY_MAKE_ZERO(v20) - "add %[rhs_ptr], %[rhs_ptr], #64\n" - RUY_MAKE_ZERO(v21) - RUY_MAKE_ZERO(v22) - - RUY_MAKE_ZERO(v23) - RUY_MAKE_ZERO(v24) - - // Load the clamp_min, clamp_max bounds - "ldrh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - RUY_MAKE_ZERO(v25) - "ldrh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - RUY_MAKE_ZERO(v26) - "dup v14.8h, w2\n" // clamp_min - RUY_MAKE_ZERO(v27) - "dup v15.8h, w3\n" // clamp_max - RUY_MAKE_ZERO(v28) - - // Apply the clamp_min bound - "smax v16.8h, v16.8h, v14.8h\n" - "smax v17.8h, v17.8h, v14.8h\n" - RUY_MAKE_ZERO(v29) - // Apply the clamp_max bound - "smin v16.8h, v16.8h, v15.8h\n" - "smin v17.8h, v17.8h, v15.8h\n" - RUY_MAKE_ZERO(v30) - - // Compute how much of the 4x4 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x4, there are some 4x4 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - RUY_MAKE_ZERO(v31) - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #4\n" - "cmp w1, #4\n" - // Compute w1 = how many rows of the 4x4 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #4\n" - // Compute w2 = how many cols of the 4x4 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - "mov x4, %[dst_ptr]\n" - // Yes, all of the 4x4 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x4 block fits. - // Store to dst_tmp_buf - "str q16, [%[dst_tmp_buf], #0]\n" - "str q17, [%[dst_tmp_buf], #16]\n" - // Slow loop copying from dst_tmp_buf to dst. - "mov x3, %[dst_tmp_buf]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldrh w7, [x3, x5, lsl #1]\n" - "strh w7, [x4, x5, lsl #1]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "add w6, w6, #1\n" - "add x3, x3, #8\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 50b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x4 block fits. - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.h}[0], [x3], #2\n" - "add x4, x4, x11\n" - "st1 {v16.h}[1], [x3], #2\n" - "st1 {v16.h}[2], [x3], #2\n" - "st1 {v16.h}[3], [x3], #2\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.h}[4], [x3], #2\n" - "add x4, x4, x11\n" - "st1 {v16.h}[5], [x3], #2\n" - "st1 {v16.h}[6], [x3], #2\n" - "st1 {v16.h}[7], [x3], #2\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v17.h}[0], [x3], #2\n" - "add x4, x4, x11\n" - "st1 {v17.h}[1], [x3], #2\n" - "st1 {v17.h}[2], [x3], #2\n" - "st1 {v17.h}[3], [x3], #2\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v17.h}[4], [x3], #2\n" - "add x4, x4, x11\n" - "st1 {v17.h}[5], [x3], #2\n" - "st1 {v17.h}[6], [x3], #2\n" - "st1 {v17.h}[7], [x3], #2\n" - "31:\n" - - "add %[dst_ptr], %[dst_ptr], #8\n" - - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" - - "ldr x1, [%[lhs_ptr], #8]\n" - "ldr x2, [%[lhs_ptr], #24]\n" - "ldr x3, [%[lhs_ptr], #40]\n" - "ldr x4, [%[lhs_ptr], #56]\n" - - "ins v0.d[1], x1\n" - "ldr x1, [%[rhs_ptr], #8]\n" - "ins v1.d[1], x2\n" - "ldr x2, [%[rhs_ptr], #24]\n" - "ins v2.d[1], x3\n" - "ldr x3, [%[rhs_ptr], #40]\n" - "ins v3.d[1], x4\n" - "ldr x4, [%[rhs_ptr], #56]\n" - "ins v4.d[1], x1\n" - "ins v5.d[1], x2\n" - "ins v6.d[1], x3\n" - "ins v7.d[1], x4\n" - - // Since the store type is the same as the accum type, no need for - // downcast. There's also no need for clamp by min/max. - - // At this point, v20 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - - RUY_MAKE_ZERO(v20) - "add %[lhs_ptr], %[lhs_ptr], #64\n" - RUY_MAKE_ZERO(v21) - "add %[rhs_ptr], %[rhs_ptr], #64\n" - RUY_MAKE_ZERO(v22) - - RUY_MAKE_ZERO(v23) - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - - // Compute how much of the 4x4 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x4, there are some 4x4 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - RUY_MAKE_ZERO(v31) - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #4\n" - "cmp w1, #4\n" - // Compute w1 = how many rows of the 4x4 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #4\n" - // Compute w2 = how many cols of the 4x4 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==4 && w2 == 4, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - "mov x4, %[dst_ptr]\n" - // Yes, all of the 4x4 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x4 block fits. - // Store to dst_tmp_buf - "str q16, [%[dst_tmp_buf], #0]\n" - "str q17, [%[dst_tmp_buf], #16]\n" - "str q18, [%[dst_tmp_buf], #32]\n" - "str q19, [%[dst_tmp_buf], #48]\n" - // Slow loop copying from dst_tmp_buf to dst. - "mov x3, %[dst_tmp_buf]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldr w7, [x3, x5, lsl #2]\n" - "str w7, [x4, x5, lsl #2]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "add w6, w6, #1\n" - "add x3, x3, #16\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 50b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x4 block fits. - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.s}[0], [x3], #4\n" - "add x4, x4, x11\n" - "st1 {v16.s}[1], [x3], #4\n" - "st1 {v16.s}[2], [x3], #4\n" - "st1 {v16.s}[3], [x3], #4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v17.s}[0], [x3], #4\n" - "add x4, x4, x11\n" - "st1 {v17.s}[1], [x3], #4\n" - "st1 {v17.s}[2], [x3], #4\n" - "st1 {v17.s}[3], [x3], #4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v18.s}[0], [x3], #4\n" - "add x4, x4, x11\n" - "st1 {v18.s}[1], [x3], #4\n" - "st1 {v18.s}[2], [x3], #4\n" - "st1 {v18.s}[3], [x3], #4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v19.s}[0], [x3], #4\n" - "add x4, x4, x11\n" - "st1 {v19.s}[1], [x3], #4\n" - "st1 {v19.s}[2], [x3], #4\n" - "st1 {v19.s}[3], [x3], #4\n" - "31:\n" - - "add %[dst_ptr], %[dst_ptr], #16\n" - - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - - RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" - - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - "smull v8.8h, v0.8b, v4.8b\n" - "smull v9.8h, v1.8b, v4.8b\n" - "smull v10.8h, v2.8b, v4.8b\n" - // Reload some params --- we had used x5 -- x7 for a few other things - // since the last time we had loaded them. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "smull v11.8h, v3.8b, v4.8b\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "smull v12.8h, v0.8b, v5.8b\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - "smull v13.8h, v1.8b, v5.8b\n" - "smull v14.8h, v2.8b, v5.8b\n" - "smull v15.8h, v3.8b, v5.8b\n" - // Move to the next block of the destination matrix, for the next iter - // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already - // been updated earlier. - // Have we reached the end row? - "cmp %w[row], w7\n" - "smlal2 v8.8h, v0.16b, v4.16b\n" - "smlal2 v9.8h, v1.16b, v4.16b\n" - "smlal2 v10.8h, v2.16b, v4.16b\n" - "smlal2 v11.8h, v3.16b, v4.16b\n" - "smlal2 v12.8h, v0.16b, v5.16b\n" - "smlal2 v13.8h, v1.16b, v5.16b\n" - "smlal2 v14.8h, v2.16b, v5.16b\n" - "smlal2 v15.8h, v3.16b, v5.16b\n" - - - "beq 20f\n" // yes, end row. - // Not end row. Move to the next row. - "add %w[row], %w[row], #4\n" - "b 21f\n" - "20:\n" - // Was already at end row. - "mov %w[row], w6\n" // Move back to first row. - "add %w[col], %w[col], #4\n" // Move to the next column. - "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #2\n" - "mov %[dst_ptr], %[dst_col_ptr]\n" - "21:\n" - - // Main loop exit condition: have we hit the end column? - "cmp %w[col], w8\n" - - // w1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 4. - "mov w1, #16\n" - - "ble 1b\n" - - // clang-format on - - : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), - [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), - [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) - : [ params ] "r"(¶ms),[dst_rows] "r"(params.dst_rows), - [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf), - [dst_type_id] "r"(params.dst_type_id) - : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", - "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", - "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", - "v26", "v27", "v28", "v29", "v30", "v31"); -} - -// Kernel taking advantage of the optional dotprod instruction. -// This is very similar to (and directly inspired by) this gemmlowp kernel -// which was contributed by David Mansell at ARM: -// NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct -// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L3391 -// -// Besides the ruy-ification, the main difference here is that we use a 8x8 -// instead of 12x8 width, so as to stick to power-of-two widths. This slightly -// narrower kernel layout is still wide enough to achieve high performance -// although we haven't actually performed a real comparison to know exactly -// how this compares to ARM's aforementioned kernel. -// -// Relevant target CPUs for this kernel include ARM Cortex-A76, -// since these are 64-bit, out-of-order and with dotprod support. -void Kernel8bitNeonDotprodOutOfOrder(const KernelParams8bit<8, 8>& params) { - profiler::ScopeLabel label( - "Kernel (kNeonDotprod, optimized for out-of-order cores)"); - - CheckOffsetsInKernelParams8bit(params); - - const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; - const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; - void* dst_col_ptr = params.dst_base_ptr; - void* dst_ptr = dst_col_ptr; - int row = params.start_row; - int col = params.start_col; - - // The asm kernel below has the following NEON register allocation: - // - // v16 -- v31 are int32 accumulators. - // During accumulation, v0 -- v15 are used to load int8 data from LHS and - // RHS. At least v0 and v1 are used to load a 8x4 block of LHS, and v2 and - // v3 are used to load a 4x8 block of RHS, like this: - // - // int8 RHS 4x8 block - // /-----------------------------------------\ - // |v2.b[0] ... v2.b[12] v3.b[0] ... v3.b[12]| - // | ... ... | - // |v2.b[3] ... v2.b[15] v3.b[3] ... v3.b[15]| - // \-----------------------------------------/ - // int8 LHS 8x4 block - // /---------------------\ /-----------------------------------------\ - // |v0.b[0] ... v0.b[3] | |v16.s[0] ... v30.s[0]| - // | ... ... | | ... ... | - // |v0.b[12] ... v0.b[15]| |v16.s[3] ... v30.s[3]| - // |v1.b[0] ... v1.b[3] | |v17.s[0] ... v31.s[0]| - // | ... ... | | ... ... | - // |v1.b[12] ... v1.b[15]| |v17.s[3] ... v31.s[3]| - // \---------------------/ \-----------------------------------------/ - // int32 accumulators 8x8 block - // - // In the RUY_OPT_MAX_STREAMING part of the kernel, this elementary step - // is repeated 4 times, using 4x more registers for LHS and RHS, so that - // is where instead of using v0 -- v3 for LHS and RHS, we use v0 -- v15. - // - // Outside of the RUY_OPT_MAX_STREAMING part of the kernel, v4 -- v7 are - // unused, and v8 -- v15 are used for loading parameters used for the - // post-accumulation part of the kernel. - asm volatile( -#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" - - // clang-format off - - // Load some parameters into registers. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" - "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" - "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" - - // Load the first 32 bytes of LHS and RHS data. - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" - - // Clear accumulators. - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - RUY_MAKE_ZERO(v20) - RUY_MAKE_ZERO(v21) - RUY_MAKE_ZERO(v22) - RUY_MAKE_ZERO(v23) - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // w1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 4. - "mov w1, #4\n" - - // Perform the first few multiply-adds on the data that we have already - // loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" - ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" - ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" - - // Main loop of the whole GEMM, over rows and columns of the - // destination matrix. - "1:\n" - - // Optional, maximally-streaming, partial-unrolling (4x unrolled) - // optimization of the kernel inner loop (over depth). For more - // comments, see the non-unrolled loop below after the #endif. -#if RUY_OPT_ENABLED(RUY_OPT_MAX_STREAMING) - "cmp w12, #32\n" - "blt 78f\n" - - "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v5.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v6.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v7.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v8.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v9.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v10.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v11.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v12.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v13.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v14.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v15.16b}, [%[rhs_ptr]], #16\n" - "mov w1, #16\n" - - "and w3, w12, #-16\n" - "81:\n" - "add w1, w1, #16\n" - - ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n" - ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n" - ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n" - ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n" - "ldr q0, [%[lhs_ptr], #0]\n" - ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" - ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n" - ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n" - ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n" - "ldr q2, [%[rhs_ptr], #0]\n" - ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n" - ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n" - ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n" - ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n" - "ldr q1, [%[lhs_ptr], #16]\n" - - ".word 0x4f87e098 // sdot v24.4s, v4.16b, v7.4b[0]\n" - ".word 0x4fa7e09a // sdot v26.4s, v4.16b, v7.4b[1]\n" - "ldr q3, [%[rhs_ptr], #16]\n" - ".word 0x4f87e89c // sdot v28.4s, v4.16b, v7.4b[2]\n" - ".word 0x4fa7e89e // sdot v30.4s, v4.16b, v7.4b[3]\n" - ".word 0x4f86e0b1 // sdot v17.4s, v5.16b, v6.4b[0]\n" - ".word 0x4fa6e0b3 // sdot v19.4s, v5.16b, v6.4b[1]\n" - ".word 0x4f86e8b5 // sdot v21.4s, v5.16b, v6.4b[2]\n" - ".word 0x4fa6e8b7 // sdot v23.4s, v5.16b, v6.4b[3]\n" - ".word 0x4f87e0b9 // sdot v25.4s, v5.16b, v7.4b[0]\n" - ".word 0x4fa7e0bb // sdot v27.4s, v5.16b, v7.4b[1]\n" - ".word 0x4f87e8bd // sdot v29.4s, v5.16b, v7.4b[2]\n" - ".word 0x4fa7e8bf // sdot v31.4s, v5.16b, v7.4b[3]\n" - "ldr q5, [%[lhs_ptr], #48]\n" - ".word 0x4f86e090 // sdot v16.4s, v4.16b, v6.4b[0]\n" - ".word 0x4fa6e092 // sdot v18.4s, v4.16b, v6.4b[1]\n" - "ldr q7, [%[rhs_ptr], #48]\n" - ".word 0x4f86e894 // sdot v20.4s, v4.16b, v6.4b[2]\n" - ".word 0x4fa6e896 // sdot v22.4s, v4.16b, v6.4b[3]\n" - "ldr q4, [%[lhs_ptr], #32]\n" - - ".word 0x4f8be118 // sdot v24.4s, v8.16b, v11.4b[0]\n" - ".word 0x4fabe11a // sdot v26.4s, v8.16b, v11.4b[1]\n" - "ldr q6, [%[rhs_ptr], #32]\n" - ".word 0x4f8be91c // sdot v28.4s, v8.16b, v11.4b[2]\n" - ".word 0x4fabe91e // sdot v30.4s, v8.16b, v11.4b[3]\n" - ".word 0x4f8ae131 // sdot v17.4s, v9.16b, v10.4b[0]\n" - ".word 0x4faae133 // sdot v19.4s, v9.16b, v10.4b[1]\n" - ".word 0x4f8ae935 // sdot v21.4s, v9.16b, v10.4b[2]\n" - ".word 0x4faae937 // sdot v23.4s, v9.16b, v10.4b[3]\n" - ".word 0x4f8be139 // sdot v25.4s, v9.16b, v11.4b[0]\n" - ".word 0x4fabe13b // sdot v27.4s, v9.16b, v11.4b[1]\n" - ".word 0x4f8be93d // sdot v29.4s, v9.16b, v11.4b[2]\n" - ".word 0x4fabe93f // sdot v31.4s, v9.16b, v11.4b[3]\n" - "ldr q9, [%[lhs_ptr], #80]\n" - ".word 0x4f8ae110 // sdot v16.4s, v8.16b, v10.4b[0]\n" - ".word 0x4faae112 // sdot v18.4s, v8.16b, v10.4b[1]\n" - "ldr q11, [%[rhs_ptr], #80]\n" - ".word 0x4f8ae914 // sdot v20.4s, v8.16b, v10.4b[2]\n" - ".word 0x4faae916 // sdot v22.4s, v8.16b, v10.4b[3]\n" - "ldr q8, [%[lhs_ptr], #64]\n" - - ".word 0x4f8fe198 // sdot v24.4s, v12.16b, v15.4b[0]\n" - ".word 0x4fafe19a // sdot v26.4s, v12.16b, v15.4b[1]\n" - "ldr q10, [%[rhs_ptr], #64]\n" - ".word 0x4f8fe99c // sdot v28.4s, v12.16b, v15.4b[2]\n" - ".word 0x4fafe99e // sdot v30.4s, v12.16b, v15.4b[3]\n" - "add %[lhs_ptr], %[lhs_ptr], #128\n" - ".word 0x4f8ee1b1 // sdot v17.4s, v13.16b, v14.4b[0]\n" - ".word 0x4faee1b3 // sdot v19.4s, v13.16b, v14.4b[1]\n" - "add %[rhs_ptr], %[rhs_ptr], #128\n" - ".word 0x4f8ee9b5 // sdot v21.4s, v13.16b, v14.4b[2]\n" - ".word 0x4faee9b7 // sdot v23.4s, v13.16b, v14.4b[3]\n" - ".word 0x4f8fe1b9 // sdot v25.4s, v13.16b, v15.4b[0]\n" - ".word 0x4fafe1bb // sdot v27.4s, v13.16b, v15.4b[1]\n" - "cmp w1, w3\n" - ".word 0x4f8fe9bd // sdot v29.4s, v13.16b, v15.4b[2]\n" - ".word 0x4fafe9bf // sdot v31.4s, v13.16b, v15.4b[3]\n" - "ldr q13, [%[lhs_ptr], #-16]\n" - ".word 0x4f8ee190 // sdot v16.4s, v12.16b, v14.4b[0]\n" - ".word 0x4faee192 // sdot v18.4s, v12.16b, v14.4b[1]\n" - "ldr q15, [%[rhs_ptr], #-16]\n" - ".word 0x4f8ee994 // sdot v20.4s, v12.16b, v14.4b[2]\n" - ".word 0x4faee996 // sdot v22.4s, v12.16b, v14.4b[3]\n" - "ldr q12, [%[lhs_ptr], #-32]\n" - - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" - "ldr q14, [%[rhs_ptr], #-32]\n" - ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" - ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" - - "blt 81b\n" - - ".word 0x4f87e098 // sdot v24.4s, v4.16b, v7.4b[0]\n" - ".word 0x4fa7e09a // sdot v26.4s, v4.16b, v7.4b[1]\n" - ".word 0x4f87e89c // sdot v28.4s, v4.16b, v7.4b[2]\n" - ".word 0x4fa7e89e // sdot v30.4s, v4.16b, v7.4b[3]\n" - ".word 0x4f86e0b1 // sdot v17.4s, v5.16b, v6.4b[0]\n" - ".word 0x4fa6e0b3 // sdot v19.4s, v5.16b, v6.4b[1]\n" - ".word 0x4f86e8b5 // sdot v21.4s, v5.16b, v6.4b[2]\n" - ".word 0x4fa6e8b7 // sdot v23.4s, v5.16b, v6.4b[3]\n" - ".word 0x4f87e0b9 // sdot v25.4s, v5.16b, v7.4b[0]\n" - ".word 0x4fa7e0bb // sdot v27.4s, v5.16b, v7.4b[1]\n" - ".word 0x4f87e8bd // sdot v29.4s, v5.16b, v7.4b[2]\n" - ".word 0x4fa7e8bf // sdot v31.4s, v5.16b, v7.4b[3]\n" - ".word 0x4f86e090 // sdot v16.4s, v4.16b, v6.4b[0]\n" - ".word 0x4fa6e092 // sdot v18.4s, v4.16b, v6.4b[1]\n" - ".word 0x4f86e894 // sdot v20.4s, v4.16b, v6.4b[2]\n" - ".word 0x4fa6e896 // sdot v22.4s, v4.16b, v6.4b[3]\n" - - ".word 0x4f8be118 // sdot v24.4s, v8.16b, v11.4b[0]\n" - ".word 0x4fabe11a // sdot v26.4s, v8.16b, v11.4b[1]\n" - ".word 0x4f8be91c // sdot v28.4s, v8.16b, v11.4b[2]\n" - ".word 0x4fabe91e // sdot v30.4s, v8.16b, v11.4b[3]\n" - ".word 0x4f8ae131 // sdot v17.4s, v9.16b, v10.4b[0]\n" - ".word 0x4faae133 // sdot v19.4s, v9.16b, v10.4b[1]\n" - ".word 0x4f8ae935 // sdot v21.4s, v9.16b, v10.4b[2]\n" - ".word 0x4faae937 // sdot v23.4s, v9.16b, v10.4b[3]\n" - ".word 0x4f8be139 // sdot v25.4s, v9.16b, v11.4b[0]\n" - ".word 0x4fabe13b // sdot v27.4s, v9.16b, v11.4b[1]\n" - ".word 0x4f8be93d // sdot v29.4s, v9.16b, v11.4b[2]\n" - ".word 0x4fabe93f // sdot v31.4s, v9.16b, v11.4b[3]\n" - ".word 0x4f8ae110 // sdot v16.4s, v8.16b, v10.4b[0]\n" - ".word 0x4faae112 // sdot v18.4s, v8.16b, v10.4b[1]\n" - ".word 0x4f8ae914 // sdot v20.4s, v8.16b, v10.4b[2]\n" - ".word 0x4faae916 // sdot v22.4s, v8.16b, v10.4b[3]\n" - - ".word 0x4f8fe198 // sdot v24.4s, v12.16b, v15.4b[0]\n" - ".word 0x4fafe19a // sdot v26.4s, v12.16b, v15.4b[1]\n" - ".word 0x4f8fe99c // sdot v28.4s, v12.16b, v15.4b[2]\n" - ".word 0x4fafe99e // sdot v30.4s, v12.16b, v15.4b[3]\n" - ".word 0x4f8ee1b1 // sdot v17.4s, v13.16b, v14.4b[0]\n" - ".word 0x4faee1b3 // sdot v19.4s, v13.16b, v14.4b[1]\n" - ".word 0x4f8ee9b5 // sdot v21.4s, v13.16b, v14.4b[2]\n" - ".word 0x4faee9b7 // sdot v23.4s, v13.16b, v14.4b[3]\n" - ".word 0x4f8fe1b9 // sdot v25.4s, v13.16b, v15.4b[0]\n" - ".word 0x4fafe1bb // sdot v27.4s, v13.16b, v15.4b[1]\n" - ".word 0x4f8fe9bd // sdot v29.4s, v13.16b, v15.4b[2]\n" - ".word 0x4fafe9bf // sdot v31.4s, v13.16b, v15.4b[3]\n" - ".word 0x4f8ee190 // sdot v16.4s, v12.16b, v14.4b[0]\n" - ".word 0x4faee192 // sdot v18.4s, v12.16b, v14.4b[1]\n" - ".word 0x4f8ee994 // sdot v20.4s, v12.16b, v14.4b[2]\n" - ".word 0x4faee996 // sdot v22.4s, v12.16b, v14.4b[3]\n" - - "78:\n" - -#endif // #if RUY_OPT_ENABLED(RUY_OPT_MAX_STREAMING) - - // Ordinary kernel inner loop (over depth), the simpler loop that the - // above was an equivalent 4x-partially-unrolled version of. - - // Reminder - w1 is how many levels of depth we have already loaded - // data for, w12 is the total depth. - "cmp w1, w12\n" - "beq 79f\n" - - "2:\n" - - // Because of the data that we have already loaded, we can start the - // loop body right away with some multiply-adds. - ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n" - ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n" - // Each iteration of this loop advances by 4 levels of depth. - "add w1, w1, #4\n" - ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n" - ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n" - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" - ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n" - // Loop termination condition. - "cmp w1, w12\n" - ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n" - ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n" - "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" - ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n" - ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n" - ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n" - ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n" - "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" - ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" - ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - - "blt 2b\n" - - "79:\n" - // End of the inner loop on depth. Now perform the remaining - // multiply-adds of the last 4 levels of depth, for which the LHS - // and RHS data is already loaded. - - ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n" - ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n" - ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n" - ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n" - ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" - ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n" - ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n" - ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n" - ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n" - ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n" - ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n" - ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n" - - // End of accumulation. The registers v16 -- v31 contain the final - // int32 accumulator values of the current 8x8 destination block. - // We now have to compute the final 8-bit values from these int32 - // accumulators, and advance to the next 8x8 block. We intertwine - // these two aspects whenever possible for optimal pipelining, both - // at the data flow level (prefetch data for next block as early as - // possible) and instruction pipelining level (some of the next-block - // work can dual-issue with some of the final work on the current - // block). - - // Logic to advance to the next block in preparation for the next - // iteration of the main loop. For now, we only want to compute - // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are - // not yet ready to update the values of row and col, as we still need - // the current values for the rest of the work on the current block. - - "cmp %w[row], w7\n" // Have we finished the last row? - "bge 4f\n" // If finished last row, go to 4 - // Not finished last row: then advance to next row. - "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" - "b 5f\n" - "4:\n" // Finished last row... - "mov %[lhs_col_ptr], x5\n" // Go back to first row - // Now we need to advance to the next column. If we already - // finished the last column, then in principle we are done, however - // we can't just return here, as we need to allow the end work of the - // current block to complete. The good news is that at this point it - // doesn't matter what data we load for the next column, since - // we will exit from the main loop below before actually storing - // anything computed from that data. - "cmp %w[col], w8\n" // Have we finished the last column? - "bge 5f\n" // If yes, just carry on without updating the column pointer. - // Not finished last column: then advance to next column. - "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" - "5:\n" - - // Set the LHS and RHS data pointers to the start of the columns just - // computed. - "mov %[lhs_ptr], %[lhs_col_ptr]\n" - "mov %[rhs_ptr], %[rhs_col_ptr]\n" - - // Load some parameters needed for the end work on current block. - RUY_MAKE_ZERO(v8) - "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" - "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" - "ins v13.h[4], w4\n" // dst_zero_point - "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" - "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "dup v9.4s, w3\n" // create prod_zp_depth_vec - "add x5, x4, %x[row], lsl #2\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "csel x4, x4, x5, eq\n" - - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" - "add x5, x1, %x[row], lsl #2\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" - "csel x1, x1, x5, eq\n" - - // Load 8 bias values. - "ld1 {v14.4s}, [x1], #16\n" - "ld1 {v15.4s}, [x1]\n" - - // Now that we know what LHS and RHS data the next iteration of the - // main loop will need to load, we start loading the first 32 bytes of - // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore - // in the rest of the work on the current block. - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" - - // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), - // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "add v14.4s, v14.4s, v9.4s\n" - "add v15.4s, v15.4s, v9.4s\n" - - // Perform the bias-addition (per the above, we have just folded into - // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) - "add v16.4s, v16.4s, v14.4s\n" - "add v17.4s, v17.4s, v15.4s\n" - "add v18.4s, v18.4s, v14.4s\n" - "add v19.4s, v19.4s, v15.4s\n" - "add v20.4s, v20.4s, v14.4s\n" - "add v21.4s, v21.4s, v15.4s\n" - "add v22.4s, v22.4s, v14.4s\n" - "add v23.4s, v23.4s, v15.4s\n" - "add v24.4s, v24.4s, v14.4s\n" - "add v25.4s, v25.4s, v15.4s\n" - "add v26.4s, v26.4s, v14.4s\n" - "add v27.4s, v27.4s, v15.4s\n" - "add v28.4s, v28.4s, v14.4s\n" - "add v29.4s, v29.4s, v15.4s\n" - "add v30.4s, v30.4s, v14.4s\n" - "add v31.4s, v31.4s, v15.4s\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" - "beq 401f\n" - "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" - "add x3, x3, %x[col], lsl #2\n" - "ld1 {v14.4s}, [x3], #16\n" - "ld1 {v15.4s}, [x3]\n" - "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" - "dup v10.4s, w5\n" // create lhs_zero_point_vec - // Subtract rhs_sums * lhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "mls v16.4s, v10.4s, v14.s[0]\n" - "mls v17.4s, v10.4s, v14.s[0]\n" - "mls v18.4s, v10.4s, v14.s[1]\n" - "mls v19.4s, v10.4s, v14.s[1]\n" - "mls v20.4s, v10.4s, v14.s[2]\n" - "mls v21.4s, v10.4s, v14.s[2]\n" - "mls v22.4s, v10.4s, v14.s[3]\n" - "mls v23.4s, v10.4s, v14.s[3]\n" - "mls v24.4s, v10.4s, v15.s[0]\n" - "mls v25.4s, v10.4s, v15.s[0]\n" - "mls v26.4s, v10.4s, v15.s[1]\n" - "mls v27.4s, v10.4s, v15.s[1]\n" - "mls v28.4s, v10.4s, v15.s[2]\n" - "mls v29.4s, v10.4s, v15.s[2]\n" - "mls v30.4s, v10.4s, v15.s[3]\n" - "mls v31.4s, v10.4s, v15.s[3]\n" - "401:\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" - "beq 402f\n" - "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" - "add x2, x2, %x[row], lsl #2\n" - "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" - // Load 4 lhs_sums values. - "ld1 {v11.4s}, [x2], #16\n" - "ld1 {v12.4s}, [x2]\n" - "ins v13.s[1], w5\n" // rhs_zero_point - // Compute lhs_sums * rhs_zero_point. - "mul v11.4s, v11.4s, v13.s[1]\n" - "mul v12.4s, v12.4s, v13.s[1]\n" - // Subtract lhs_sums * rhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "sub v16.4s, v16.4s, v11.4s\n" - "sub v17.4s, v17.4s, v12.4s\n" - "sub v18.4s, v18.4s, v11.4s\n" - "sub v19.4s, v19.4s, v12.4s\n" - "sub v20.4s, v20.4s, v11.4s\n" - "sub v21.4s, v21.4s, v12.4s\n" - "sub v22.4s, v22.4s, v11.4s\n" - "sub v23.4s, v23.4s, v12.4s\n" - "sub v24.4s, v24.4s, v11.4s\n" - "sub v25.4s, v25.4s, v12.4s\n" - "sub v26.4s, v26.4s, v11.4s\n" - "sub v27.4s, v27.4s, v12.4s\n" - "sub v28.4s, v28.4s, v11.4s\n" - "sub v29.4s, v29.4s, v12.4s\n" - "sub v30.4s, v30.4s, v11.4s\n" - "sub v31.4s, v31.4s, v12.4s\n" - - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" - - "402:\n" - - // At this point we have computed the final int32 values. Now we - // start down-quantizing them to obtain the final 8bit values from them. - - // As part of this down-quantization, our int32 values will be - // multiplied by a multiplier that has a fixed-point component and an - // exponent component. - - //Load the exponent part of the multiplier. - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "add x5, x1, %x[row], lsl #2\n" - "csel x1, x1, x5, eq\n" - - "ldr q9, [x1]\n" - "ldr q10, [x1, #16]\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_NEEDS_LEFT_SHIFT) "\n" - "beq 403f\n" - "smax v11.4s, v9.4s, v8.4s\n" - "smax v12.4s, v10.4s, v8.4s\n" - "sshl v16.4s, v16.4s, v11.4s\n" - "sshl v17.4s, v17.4s, v12.4s\n" - "sshl v18.4s, v18.4s, v11.4s\n" - "sshl v19.4s, v19.4s, v12.4s\n" - "sshl v20.4s, v20.4s, v11.4s\n" - "sshl v21.4s, v21.4s, v12.4s\n" - "sshl v22.4s, v22.4s, v11.4s\n" - "sshl v23.4s, v23.4s, v12.4s\n" - "sshl v24.4s, v24.4s, v11.4s\n" - "sshl v25.4s, v25.4s, v12.4s\n" - "sshl v26.4s, v26.4s, v11.4s\n" - "sshl v27.4s, v27.4s, v12.4s\n" - "sshl v28.4s, v28.4s, v11.4s\n" - "sshl v29.4s, v29.4s, v12.4s\n" - "sshl v30.4s, v30.4s, v11.4s\n" - "sshl v31.4s, v31.4s, v12.4s\n" - "403:\n" - - "ldr q14, [x4]\n" // multiplier_fixedpoint - "ldr q15, [x4, #16]\n" // multiplier_fixedpoint - - "smin v11.4s, v9.4s, v8.4s\n" - "smin v12.4s, v10.4s, v8.4s\n" - - // Apply the fixed-point part of the multiplier. - "sqrdmulh v16.4s, v16.4s, v14.4s\n" - "sqrdmulh v17.4s, v17.4s, v15.4s\n" - "sqrdmulh v18.4s, v18.4s, v14.4s\n" - "sqrdmulh v19.4s, v19.4s, v15.4s\n" - "sqrdmulh v20.4s, v20.4s, v14.4s\n" - "sqrdmulh v21.4s, v21.4s, v15.4s\n" - "sqrdmulh v22.4s, v22.4s, v14.4s\n" - "sqrdmulh v23.4s, v23.4s, v15.4s\n" - "sqrdmulh v24.4s, v24.4s, v14.4s\n" - "sqrdmulh v25.4s, v25.4s, v15.4s\n" - "sqrdmulh v26.4s, v26.4s, v14.4s\n" - "sqrdmulh v27.4s, v27.4s, v15.4s\n" - "sqrdmulh v28.4s, v28.4s, v14.4s\n" - "sqrdmulh v29.4s, v29.4s, v15.4s\n" - "sqrdmulh v30.4s, v30.4s, v14.4s\n" - "sqrdmulh v31.4s, v31.4s, v15.4s\n" - - // We have some rounding division-by-power-of-two to do. This should - // always use "round to nearest". We allow for some - // freedom in how ties are broken, to strike a good compromise of - // performance on given hardware vs. perfect agreement of results - // across hardware. - // - // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation - // defined tie-breaks to help performance. On NEON, this means that we - // can just use the NEON rounding instructions, such as srshl. They - // happen to be breaking ties upward. - // - // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict - // break-ties-away-from zero, as described in Appendix B of - // https://arxiv.org/pdf/1712.05877.pdf - // When we wrote that, we thought that that would be better unbiased - // than the NEON upwards tie-breaks, and we had observed some - // improvement on some model. However, that is only more unbiased for - // data centered at zero, which was likely the case in that model, - // but is not always the case. If we wanted something more consistently - // unbiased then we should try breaking ties toward-nearest-even. -#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) - // Fix up values to be right-shifted, so that the (round to nearest, - // break ties upward) behavior of srshl applied to these fixed-up - // values, produces the same result as the desired (round to nearest, - // break ties away from zero) behavior on the original values. - "and v8.16b, v16.16b, v11.16b\n" - "and v9.16b, v17.16b, v12.16b\n" - "and v14.16b, v18.16b, v11.16b\n" - "and v15.16b, v19.16b, v12.16b\n" - "sshr v8.4s, v8.4s, #31\n" - "sshr v9.4s, v9.4s, #31\n" - "sshr v14.4s, v14.4s, #31\n" - "sshr v15.4s, v15.4s, #31\n" - "sqadd v16.4s, v16.4s, v8.4s\n" - "sqadd v17.4s, v17.4s, v9.4s\n" - "sqadd v18.4s, v18.4s, v14.4s\n" - "sqadd v19.4s, v19.4s, v15.4s\n" - "and v8.16b, v20.16b, v11.16b\n" - "and v9.16b, v21.16b, v12.16b\n" - "and v14.16b, v22.16b, v11.16b\n" - "and v15.16b, v23.16b, v12.16b\n" - "sshr v8.4s, v8.4s, #31\n" - "sshr v9.4s, v9.4s, #31\n" - "sshr v14.4s, v14.4s, #31\n" - "sshr v15.4s, v15.4s, #31\n" - "sqadd v20.4s, v20.4s, v8.4s\n" - "sqadd v21.4s, v21.4s, v9.4s\n" - "sqadd v22.4s, v22.4s, v14.4s\n" - "sqadd v23.4s, v23.4s, v15.4s\n" - "and v8.16b, v24.16b, v11.16b\n" - "and v9.16b, v25.16b, v12.16b\n" - "and v14.16b, v26.16b, v11.16b\n" - "and v15.16b, v27.16b, v12.16b\n" - "sshr v8.4s, v8.4s, #31\n" - "sshr v9.4s, v9.4s, #31\n" - "sshr v14.4s, v14.4s, #31\n" - "sshr v15.4s, v15.4s, #31\n" - "sqadd v24.4s, v24.4s, v8.4s\n" - "sqadd v25.4s, v25.4s, v9.4s\n" - "sqadd v26.4s, v26.4s, v14.4s\n" - "sqadd v27.4s, v27.4s, v15.4s\n" - "and v8.16b, v28.16b, v11.16b\n" - "and v9.16b, v29.16b, v12.16b\n" - "and v14.16b, v30.16b, v11.16b\n" - "and v15.16b, v31.16b, v12.16b\n" - "sshr v8.4s, v8.4s, #31\n" - "sshr v9.4s, v9.4s, #31\n" - "sshr v14.4s, v14.4s, #31\n" - "sshr v15.4s, v15.4s, #31\n" - "sqadd v28.4s, v28.4s, v8.4s\n" - "sqadd v29.4s, v29.4s, v9.4s\n" - "sqadd v30.4s, v30.4s, v14.4s\n" - "sqadd v31.4s, v31.4s, v15.4s\n" -#endif - // At this point we have reduced the problem of correctly implementing - // rounding divide-by-power-of-two, to what the SRSHL instruction can - // do. - "srshl v16.4s, v16.4s, v11.4s\n" - "srshl v17.4s, v17.4s, v12.4s\n" - "srshl v18.4s, v18.4s, v11.4s\n" - "srshl v19.4s, v19.4s, v12.4s\n" - "srshl v20.4s, v20.4s, v11.4s\n" - "srshl v21.4s, v21.4s, v12.4s\n" - "srshl v22.4s, v22.4s, v11.4s\n" - "srshl v23.4s, v23.4s, v12.4s\n" - "srshl v24.4s, v24.4s, v11.4s\n" - "srshl v25.4s, v25.4s, v12.4s\n" - "srshl v26.4s, v26.4s, v11.4s\n" - "srshl v27.4s, v27.4s, v12.4s\n" - "srshl v28.4s, v28.4s, v11.4s\n" - "srshl v29.4s, v29.4s, v12.4s\n" - "srshl v30.4s, v30.4s, v11.4s\n" - "srshl v31.4s, v31.4s, v12.4s\n" - - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" - - // Cast-and-saturate from int32 to int16 - "sqxtn v16.4h, v16.4s\n" - "sqxtn2 v16.8h, v17.4s\n" - "sqxtn v17.4h, v18.4s\n" - "sqxtn2 v17.8h, v19.4s\n" - "sqxtn v18.4h, v20.4s\n" - "sqxtn2 v18.8h, v21.4s\n" - "sqxtn v19.4h, v22.4s\n" - "sqxtn2 v19.8h, v23.4s\n" - "sqxtn v20.4h, v24.4s\n" - "sqxtn2 v20.8h, v25.4s\n" - "sqxtn v21.4h, v26.4s\n" - "sqxtn2 v21.8h, v27.4s\n" - "sqxtn v22.4h, v28.4s\n" - "sqxtn2 v22.8h, v29.4s\n" - "sqxtn v23.4h, v30.4s\n" - "sqxtn2 v23.8h, v31.4s\n" - - // At this point, v24 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // Add the destination zero point - "dup v14.8h, v13.h[4]\n" - "add v16.8h, v16.8h, v14.8h\n" - "add v17.8h, v17.8h, v14.8h\n" - "add v18.8h, v18.8h, v14.8h\n" - "add v19.8h, v19.8h, v14.8h\n" - "add v20.8h, v20.8h, v14.8h\n" - "add v21.8h, v21.8h, v14.8h\n" - "add v22.8h, v22.8h, v14.8h\n" - "add v23.8h, v23.8h, v14.8h\n" - - // Cast-and-saturate from int16 to uint8 - "sqxtun v16.8b, v16.8h\n" - "sqxtun2 v16.16b, v17.8h\n" - "sqxtun v17.8b, v18.8h\n" - "sqxtun2 v17.16b, v19.8h\n" - "sqxtun v18.8b, v20.8h\n" - "sqxtun2 v18.16b, v21.8h\n" - "sqxtun v19.8b, v22.8h\n" - "sqxtun2 v19.16b, v23.8h\n" - - // Load the clamp_min, clamp_max bounds - "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.16b, w2\n" // clamp_min - "dup v15.16b, w3\n" // clamp_max - - // Apply the clamp_min bound - "umax v16.16b, v16.16b, v14.16b\n" - "umax v17.16b, v17.16b, v14.16b\n" - "umax v18.16b, v18.16b, v14.16b\n" - "umax v19.16b, v19.16b, v14.16b\n" - - // Apply the clamp_max bound - "umin v16.16b, v16.16b, v15.16b\n" - "umin v17.16b, v17.16b, v15.16b\n" - "umin v18.16b, v18.16b, v15.16b\n" - "umin v19.16b, v19.16b, v15.16b\n" - - // Make it so that all of the final 8bit values are stored in the - // first 64bits of 128bit NEON registers, so they can be stored - // by 64bit st1 store instructions with byte alignment. - "dup d20, v16.d[1]\n" - "dup d21, v17.d[1]\n" - "dup d22, v18.d[1]\n" - "dup d23, v19.d[1]\n" - - // Compute how much of the 8x8 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x8, there are some 8x8 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #8\n" - "cmp w1, #8\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #8\n" - // Compute w2 = how many cols of the 8x8 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - // Yes, all of the 8x8 block fits, go to fast path. - "beq 30f\n" - // Not all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "mov x4, #8\n" - "b 31f\n" - "30:\n" - // Yes, all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write directly to destination matrix. - "mov x3, %[dst_ptr]\n" - "mov x4, x11\n" - "31:\n" - - // Write our 8bit values to the destination described by - // (x3 address, x4 stride). - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v16.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v16) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v20.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v20) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v17.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v17) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v21.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v21) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v18.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v18) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v22.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v22) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v19.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v19) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v23.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v23) - - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" - ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" - ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 41f\n" - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldrb w7, [x3, w5, uxtw]\n" - "strb w7, [x4, w5, uxtw]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "add w6, w6, #1\n" - "add x3, x3, #8\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 50b\n" - "41:\n" - "add %[dst_ptr], %[dst_ptr], #8\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" - - // Cast-and-saturate from int32 to int16 - "sqxtn v16.4h, v16.4s\n" - "sqxtn2 v16.8h, v17.4s\n" - "sqxtn v17.4h, v18.4s\n" - "sqxtn2 v17.8h, v19.4s\n" - "sqxtn v18.4h, v20.4s\n" - "sqxtn2 v18.8h, v21.4s\n" - "sqxtn v19.4h, v22.4s\n" - "sqxtn2 v19.8h, v23.4s\n" - "sqxtn v20.4h, v24.4s\n" - "sqxtn2 v20.8h, v25.4s\n" - "sqxtn v21.4h, v26.4s\n" - "sqxtn2 v21.8h, v27.4s\n" - "sqxtn v22.4h, v28.4s\n" - "sqxtn2 v22.8h, v29.4s\n" - "sqxtn v23.4h, v30.4s\n" - "sqxtn2 v23.8h, v31.4s\n" - - // At this point, v24 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // Add the destination zero point - "dup v14.8h, v13.h[4]\n" - "add v16.8h, v16.8h, v14.8h\n" - "add v17.8h, v17.8h, v14.8h\n" - "add v18.8h, v18.8h, v14.8h\n" - "add v19.8h, v19.8h, v14.8h\n" - "add v20.8h, v20.8h, v14.8h\n" - "add v21.8h, v21.8h, v14.8h\n" - "add v22.8h, v22.8h, v14.8h\n" - "add v23.8h, v23.8h, v14.8h\n" - - // Cast-and-saturate from int16 to uint8 - "sqxtn v16.8b, v16.8h\n" - "sqxtn2 v16.16b, v17.8h\n" - "sqxtn v17.8b, v18.8h\n" - "sqxtn2 v17.16b, v19.8h\n" - "sqxtn v18.8b, v20.8h\n" - "sqxtn2 v18.16b, v21.8h\n" - "sqxtn v19.8b, v22.8h\n" - "sqxtn2 v19.16b, v23.8h\n" - - // Load the clamp_min, clamp_max bounds - "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.16b, w2\n" // clamp_min - "dup v15.16b, w3\n" // clamp_max - - // Apply the clamp_min bound - "smax v16.16b, v16.16b, v14.16b\n" - "smax v17.16b, v17.16b, v14.16b\n" - "smax v18.16b, v18.16b, v14.16b\n" - "smax v19.16b, v19.16b, v14.16b\n" - - // Apply the clamp_max bound - "smin v16.16b, v16.16b, v15.16b\n" - "smin v17.16b, v17.16b, v15.16b\n" - "smin v18.16b, v18.16b, v15.16b\n" - "smin v19.16b, v19.16b, v15.16b\n" - - // Make it so that all of the final 8bit values are stored in the - // first 64bits of 128bit NEON registers, so they can be stored - // by 64bit st1 store instructions with byte alignment. - "dup d20, v16.d[1]\n" - "dup d21, v17.d[1]\n" - "dup d22, v18.d[1]\n" - "dup d23, v19.d[1]\n" - - // Compute how much of the 8x8 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x8, there are some 8x8 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #8\n" - "cmp w1, #8\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #8\n" - // Compute w2 = how many cols of the 8x8 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - // Yes, all of the 8x8 block fits, go to fast path. - "beq 130f\n" - // Not all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "mov x4, #8\n" - "b 131f\n" - "130:\n" - // Yes, all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write directly to destination matrix. - "mov x3, %[dst_ptr]\n" - "mov x4, x11\n" - "131:\n" - - // Write our 8bit values to the destination described by - // (x3 address, x4 stride). - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v16.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v16) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v20.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v20) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v17.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v17) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v21.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v21) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v18.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v18) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v22.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v22) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v19.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v19) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v23.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v23) - - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" - ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" - ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 141f\n" - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "150:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "151:\n" - "ldrb w7, [x3, w5, uxtw]\n" - "strb w7, [x4, w5, uxtw]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 151b\n" - "add w6, w6, #1\n" - "add x3, x3, #8\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 150b\n" - "141:\n" - "add %[dst_ptr], %[dst_ptr], #8\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" - - // Add the destination zero point - "dup v14.8h, v13.h[4]\n" - "saddw v16.4s, v16.4s, v14.4h\n" - "saddw v17.4s, v17.4s, v14.4h\n" - "saddw v18.4s, v18.4s, v14.4h\n" - "saddw v19.4s, v19.4s, v14.4h\n" - "saddw v20.4s, v20.4s, v14.4h\n" - "saddw v21.4s, v21.4s, v14.4h\n" - "saddw v22.4s, v22.4s, v14.4h\n" - "saddw v23.4s, v23.4s, v14.4h\n" - "saddw v24.4s, v24.4s, v14.4h\n" - "saddw v25.4s, v25.4s, v14.4h\n" - "saddw v26.4s, v26.4s, v14.4h\n" - "saddw v27.4s, v27.4s, v14.4h\n" - "saddw v28.4s, v28.4s, v14.4h\n" - "saddw v29.4s, v29.4s, v14.4h\n" - "saddw v30.4s, v30.4s, v14.4h\n" - "saddw v31.4s, v31.4s, v14.4h\n" - - // Cast-and-saturate from int32 to int16 - "sqxtn v16.4h, v16.4s\n" - "sqxtn2 v16.8h, v17.4s\n" - "sqxtn v17.4h, v18.4s\n" - "sqxtn2 v17.8h, v19.4s\n" - "sqxtn v18.4h, v20.4s\n" - "sqxtn2 v18.8h, v21.4s\n" - "sqxtn v19.4h, v22.4s\n" - "sqxtn2 v19.8h, v23.4s\n" - "sqxtn v20.4h, v24.4s\n" - "sqxtn2 v20.8h, v25.4s\n" - "sqxtn v21.4h, v26.4s\n" - "sqxtn2 v21.8h, v27.4s\n" - "sqxtn v22.4h, v28.4s\n" - "sqxtn2 v22.8h, v29.4s\n" - "sqxtn v23.4h, v30.4s\n" - "sqxtn2 v23.8h, v31.4s\n" - - // At this point, v24 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // Load the clamp_min, clamp_max bounds - "ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.8h, w2\n" // clamp_min - "dup v15.8h, w3\n" // clamp_max - - // Apply the clamp_min bound - "smax v16.8h, v16.8h, v14.8h\n" - "smax v17.8h, v17.8h, v14.8h\n" - "smax v18.8h, v18.8h, v14.8h\n" - "smax v19.8h, v19.8h, v14.8h\n" - "smax v20.8h, v20.8h, v14.8h\n" - "smax v21.8h, v21.8h, v14.8h\n" - "smax v22.8h, v22.8h, v14.8h\n" - "smax v23.8h, v23.8h, v14.8h\n" - // Apply the clamp_max bound - "smin v16.8h, v16.8h, v15.8h\n" - "smin v17.8h, v17.8h, v15.8h\n" - "smin v18.8h, v18.8h, v15.8h\n" - "smin v19.8h, v19.8h, v15.8h\n" - "smin v20.8h, v20.8h, v15.8h\n" - "smin v21.8h, v21.8h, v15.8h\n" - "smin v22.8h, v22.8h, v15.8h\n" - "smin v23.8h, v23.8h, v15.8h\n" - - // Compute how much of the 8x8 block of destination 16bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x8, there are some 8x8 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #8\n" - "cmp w1, #8\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #8\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - // Yes, all of the 8x8 block fits, go to fast path. - "beq 230f\n" - // Not all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "mov x4, #16\n" - "b 231f\n" - "230:\n" - // Yes, all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write directly to destination matrix. - "mov x3, %[dst_ptr]\n" - "mov x4, x11\n" - "231:\n" - - // Write our 16bit values to the destination described by - // (x3 address, x4 stride). - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v16.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v16) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v17.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v17) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v18.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v18) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v19.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v19) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v20.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v20) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v21.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v21) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v22.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v22) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v23.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v23) - - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" - ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" - ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 241f\n" - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "250:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "251:\n" - "ldrsh w7, [x3, x5, lsl #1]\n" - "strh w7, [x4, x5, lsl #1]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 251b\n" - "add w6, w6, #1\n" - "add x3, x3, #16\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 250b\n" - "241:\n" - "add %[dst_ptr], %[dst_ptr], #16\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" - - // Since the store type is the same as the accum type, no need for - // downcast. There's also no need for clamp by min/max. - - // Compute how much of the 8x8 block of destination 32it values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x8, there are some 8x8 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #8\n" - "cmp w1, #8\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #8\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - // Yes, all of the 8x8 block fits, go to fast path. - "beq 330f\n" - // Not all of the 8x8 block fits. - // Write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "st1 {v16.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v16) - "st1 {v17.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v17) - "st1 {v18.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v18) - "st1 {v19.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v19) - "st1 {v20.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v20) - "st1 {v21.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v21) - "st1 {v22.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v22) - "st1 {v23.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v23) - "st1 {v24.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v24) - "st1 {v25.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v25) - "st1 {v26.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v26) - "st1 {v27.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v27) - "st1 {v28.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v28) - "st1 {v29.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v29) - "st1 {v30.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v30) - "st1 {v31.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v31) - - "b 331f\n" - - "330:\n" - // Yes, all of the 8x8 block fits. - "mov x4, %[dst_ptr]\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.4s, v17.4s}, [x3], #32\n" - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - "add x4, x4, x11\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v18.4s, v19.4s}, [x3], #32\n" - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - "add x4, x4, x11\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v20.4s, v21.4s}, [x3], #32\n" - RUY_MAKE_ZERO(v20) - RUY_MAKE_ZERO(v21) - "add x4, x4, x11\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v22.4s, v23.4s}, [x3], #32\n" - RUY_MAKE_ZERO(v22) - RUY_MAKE_ZERO(v23) - "add x4, x4, x11\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v24.4s, v25.4s}, [x3], #32\n" - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - "add x4, x4, x11\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v26.4s, v27.4s}, [x3], #32\n" - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - "add x4, x4, x11\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v28.4s, v29.4s}, [x3], #32\n" - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - "add x4, x4, x11\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v30.4s, v31.4s}, [x3], #32\n" - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - "331:\n" - - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" - ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" - ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 341f\n" - - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "350:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "351:\n" - "ldr w7, [x3, x5, lsl #2]\n" - "str w7, [x4, x5, lsl #2]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 351b\n" - "add w6, w6, #1\n" - "add x3, x3, #32\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 350b\n" - "341:\n" - "add %[dst_ptr], %[dst_ptr], #32\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" - - // Reload some params --- we had used x5 -- x7 for a few other things - // since the last time we had loaded them. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - - // Move to the next block of the destination matrix, for the next iter - // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already - // been updated earlier. - // Have we reached the end row? - "cmp %w[row], w7\n" - "beq 20f\n" // yes, end row. - // Not end row. Move to the next row. - "add %w[row], %w[row], #8\n" - "b 21f\n" - "20:\n" - // Was already at end row. - "mov %w[row], w6\n" // Move back to first row. - "add %w[col], %w[col], #8\n" // Move to the next column. - "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" - "mov %[dst_ptr], %[dst_col_ptr]\n" - "21:\n" - - // Main loop exit condition: have we hit the end column? - "cmp %w[col], w8\n" - - // w1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 4. - "mov w1, #4\n" - - "ble 1b\n" - - // clang-format on - - : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), - [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), - [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) - : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), - [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf), - [dst_type_id] "r"(params.dst_type_id) - : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", - "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", - "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", - "v26", "v27", "v28", "v29", "v30", "v31"); -} - -// Similar to the above 8-bit dotprod kernel, but specialized for the case of -// RHS cols == 1. -// Relevant target CPUs for this kernel include ARM Cortex-A76, -// since these are 64-bit, out-of-order and with dotprod support. -void Kernel8bitNeonDotprodOutOfOrder1Col(const KernelParams8bit<8, 8>& params) { - profiler::ScopeLabel label( - "Kernel (kNeonDotprod, optimized for out-of-order cores)"); - - CheckOffsetsInKernelParams8bit(params); - - const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; - const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; - void* dst_col_ptr = params.dst_base_ptr; - void* dst_ptr = dst_col_ptr; - int row = params.start_row; - int col = params.start_col; - - // The asm kernel below has the following NEON register allocation: - // - // v16 -- v31 are int32 accumulators. - // During accumulation, v0 -- v15 are used to load int8 data from LHS and - // RHS. At least v0 and v1 are used to load a 8x4 block of LHS, and v2 and - // v3 are used to load a 4x8 block of RHS, like this: - // - // int8 RHS 4x1 block - // /-------\ - // |v2.b[0]| - // | ... | - // |v2.b[3]| - // \-------/ - // int8 LHS 8x4 block - // /---------------------\ /--------\ - // |v0.b[0] ... v0.b[3] | |v16.s[0]| - // | ... ... | | ... | - // |v0.b[12] ... v0.b[15]| |v16.s[3]| - // |v1.b[0] ... v1.b[3] | |v17.s[0]| - // | ... ... | | ... | - // |v1.b[12] ... v1.b[15]| |v17.s[3]| - // \---------------------/ \--------/ - // int32 accumulators 8x1 block - // - // In the RUY_OPT_MAX_STREAMING part of the kernel, this elementary step - // is repeated 4 times, using 4x more registers for LHS and RHS, so that - // is where instead of using v0 -- v3 for LHS and RHS, we use v0 -- v15. - // - // Outside of the RUY_OPT_MAX_STREAMING part of the kernel, v4 -- v7 are - // unused, and v8 -- v15 are used for loading parameters used for the - // post-accumulation part of the kernel. - asm volatile( -#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" - - // clang-format off - - // Load some parameters into registers. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" - "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" - "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" - - // Load the first 32 bytes of LHS and RHS data. - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v2.8b}, [%[rhs_ptr]]\n" - "add %[rhs_ptr], %[rhs_ptr], #32\n" - - // Clear accumulators. - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - // w1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 4. - "mov w1, #4\n" - - // Perform the first few multiply-adds on the data that we have already - // loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - - // Main loop of the whole GEMM, over rows and columns of the - // destination matrix. - "1:\n" - - // Ordinary kernel inner loop (over depth), the simpler loop that the - // above was an equivalent 4x-partially-unrolled version of. - - // Reminder - w1 is how many levels of depth we have already loaded - // data for, w12 is the total depth. - "cmp w1, w12\n" - "beq 79f\n" - - "2:\n" - - // Because of the data that we have already loaded, we can start the - // loop body right away with some multiply-adds. - // Each iteration of this loop advances by 4 levels of depth. - "add w1, w1, #4\n" - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" - // Loop termination condition. - "cmp w1, w12\n" - "ld1 {v2.8b}, [%[rhs_ptr]]\n" - "add %[rhs_ptr], %[rhs_ptr], #32\n" - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - - "blt 2b\n" - - "79:\n" - // End of the inner loop on depth. Now perform the remaining - // multiply-adds of the last 4 levels of depth, for which the LHS - // and RHS data is already loaded. - - ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" - - // End of accumulation. The registers v16 -- v31 contain the final - // int32 accumulator values of the current 8x8 destination block. - // We now have to compute the final 8-bit values from these int32 - // accumulators, and advance to the next 8x8 block. We intertwine - // these two aspects whenever possible for optimal pipelining, both - // at the data flow level (prefetch data for next block as early as - // possible) and instruction pipelining level (some of the next-block - // work can dual-issue with some of the final work on the current - // block). - - // Logic to advance to the next block in preparation for the next - // iteration of the main loop. For now, we only want to compute - // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are - // not yet ready to update the values of row and col, as we still need - // the current values for the rest of the work on the current block. - - "cmp %w[row], w7\n" // Have we finished the last row? - "bge 4f\n" // If finished last row, go to 4 - // Not finished last row: then advance to next row. - "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" - "b 5f\n" - "4:\n" // Finished last row... - "mov %[lhs_col_ptr], x5\n" // Go back to first row - // Now we need to advance to the next column. If we already - // finished the last column, then in principle we are done, however - // we can't just return here, as we need to allow the end work of the - // current block to complete. The good news is that at this point it - // doesn't matter what data we load for the next column, since - // we will exit from the main loop below before actually storing - // anything computed from that data. - "cmp %w[col], w8\n" // Have we finished the last column? - "bge 5f\n" // If yes, just carry on without updating the column pointer. - // Not finished last column: then advance to next column. - "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" - "5:\n" - - // Set the LHS and RHS data pointers to the start of the columns just - // computed. - "mov %[lhs_ptr], %[lhs_col_ptr]\n" - "mov %[rhs_ptr], %[rhs_col_ptr]\n" - - // Load some parameters needed for the end work on current block. - RUY_MAKE_ZERO(v8) - "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" - "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" - "ins v13.h[4], w4\n" // dst_zero_point - "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" - "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "dup v9.4s, w3\n" // create prod_zp_depth_vec - "add x5, x4, %x[row], lsl #2\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "csel x4, x4, x5, eq\n" - - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" - "add x5, x1, %x[row], lsl #2\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" - "csel x1, x1, x5, eq\n" - - // Load 8 bias values. - "ld1 {v14.4s}, [x1], #16\n" - "ld1 {v15.4s}, [x1]\n" - - // Now that we know what LHS and RHS data the next iteration of the - // main loop will need to load, we start loading the first 32 bytes of - // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore - // in the rest of the work on the current block. - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v2.8b}, [%[rhs_ptr]]\n" - "add %[rhs_ptr], %[rhs_ptr], #32\n" - - // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), - // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "add v14.4s, v14.4s, v9.4s\n" - "add v15.4s, v15.4s, v9.4s\n" - - // Perform the bias-addition (per the above, we have just folded into - // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) - "add v16.4s, v16.4s, v14.4s\n" - "add v17.4s, v17.4s, v15.4s\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" - "beq 401f\n" - "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" - "add x3, x3, %x[col], lsl #2\n" - "ld1 {v14.4s}, [x3], #16\n" - "ld1 {v15.4s}, [x3]\n" - "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" - "dup v10.4s, w5\n" // create lhs_zero_point_vec - // Subtract rhs_sums * lhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "mls v16.4s, v10.4s, v14.s[0]\n" - "mls v17.4s, v10.4s, v14.s[0]\n" - "401:\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" - "beq 402f\n" - "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" - "add x2, x2, %x[row], lsl #2\n" - "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" - // Load 4 lhs_sums values. - "ld1 {v11.4s}, [x2], #16\n" - "ld1 {v12.4s}, [x2]\n" - "ins v13.s[1], w5\n" // rhs_zero_point - // Compute lhs_sums * rhs_zero_point. - "mul v11.4s, v11.4s, v13.s[1]\n" - "mul v12.4s, v12.4s, v13.s[1]\n" - // Subtract lhs_sums * rhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "sub v16.4s, v16.4s, v11.4s\n" - "sub v17.4s, v17.4s, v12.4s\n" - - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" - - "402:\n" - - // At this point we have computed the final int32 values. Now we - // start down-quantizing them to obtain the final 8bit values from them. - - // As part of this down-quantization, our int32 values will be - // multiplied by a multiplier that has a fixed-point component and an - // exponent component. - - //Load the exponent part of the multiplier. - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "add x5, x1, %x[row], lsl #2\n" - "csel x1, x1, x5, eq\n" - - "ldr q9, [x1]\n" - "ldr q10, [x1, #16]\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_NEEDS_LEFT_SHIFT) "\n" - "beq 403f\n" - "smax v11.4s, v9.4s, v8.4s\n" - "smax v12.4s, v10.4s, v8.4s\n" - "sshl v16.4s, v16.4s, v11.4s\n" - "sshl v17.4s, v17.4s, v12.4s\n" - "403:\n" - - "ldr q14, [x4]\n" // multiplier_fixedpoint - "ldr q15, [x4, #16]\n" // multiplier_fixedpoint - - "smin v11.4s, v9.4s, v8.4s\n" - "smin v12.4s, v10.4s, v8.4s\n" - - // Apply the fixed-point part of the multiplier. - "sqrdmulh v16.4s, v16.4s, v14.4s\n" - "sqrdmulh v17.4s, v17.4s, v15.4s\n" - - // We have some rounding division-by-power-of-two to do. This should - // always use "round to nearest". We allow for some - // freedom in how ties are broken, to strike a good compromise of - // performance on given hardware vs. perfect agreement of results - // across hardware. - // - // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation - // defined tie-breaks to help performance. On NEON, this means that we - // can just use the NEON rounding instructions, such as srshl. They - // happen to be breaking ties upward. - // - // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict - // break-ties-away-from zero, as described in Appendix B of - // https://arxiv.org/pdf/1712.05877.pdf - // When we wrote that, we thought that that would be better unbiased - // than the NEON upwards tie-breaks, and we had observed some - // improvement on some model. However, that is only more unbiased for - // data centered at zero, which was likely the case in that model, - // but is not always the case. If we wanted something more consistently - // unbiased then we should try breaking ties toward-nearest-even. -#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) - // Fix up values to be right-shifted, so that the (round to nearest, - // break ties upward) behavior of srshl applied to these fixed-up - // values, produces the same result as the desired (round to nearest, - // break ties away from zero) behavior on the original values. - "and v8.16b, v16.16b, v11.16b\n" - "and v9.16b, v17.16b, v12.16b\n" - "sshr v8.4s, v8.4s, #31\n" - "sshr v9.4s, v9.4s, #31\n" - "sqadd v16.4s, v16.4s, v8.4s\n" - "sqadd v17.4s, v17.4s, v9.4s\n" - -#endif - // At this point we have reduced the problem of correctly implementing - // rounding divide-by-power-of-two, to what the SRSHL instruction can - // do. - "srshl v16.4s, v16.4s, v11.4s\n" - "srshl v17.4s, v17.4s, v12.4s\n" - - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" - - // Cast-and-saturate from int32 to int16 - "sqxtn v16.4h, v16.4s\n" - "sqxtn2 v16.8h, v17.4s\n" - // All data in v16 at this point. - - // Add the destination zero point - "dup v14.8h, v13.h[4]\n" - "add v16.8h, v16.8h, v14.8h\n" - - // Cast-and-saturate from int16 to uint8, leaving all data in the - // lower half of v16. - "sqxtun v16.8b, v16.8h\n" - - // Load the clamp_min, clamp_max bounds - "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.16b, w2\n" // clamp_min - "dup v15.16b, w3\n" // clamp_max - - // Apply the clamp_min bound - "umax v16.16b, v16.16b, v14.16b\n" - - // Apply the clamp_max bound - "umin v16.16b, v16.16b, v15.16b\n" - - // Make it so that all of the final 8bit values are stored in the - // first 64bits of 128bit NEON registers, so they can be stored - // by 64bit st1 store instructions with byte alignment. - "dup d20, v16.d[1]\n" - - // Compute how much of the 8x1 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x1, there are some 8x1 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #8\n" - "cmp w1, #8\n" - // Compute w1 = how many rows of the 8x1 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #8\n" - - // Test if w1==8, i.e. if all of the 8x1 block fits. - "cmp w1, w3\n" - // Yes, all of the 8x1 block fits, go to fast path. - "beq 30f\n" - // Not all of the 8x1 block fits. - // Set (x3 address, x4 stride) to write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "mov x4, #8\n" - "b 31f\n" - "30:\n" - // Yes, all of the 8x1 block fits. - // Set (x3 address, x4 stride) to write directly to destination matrix. - "mov x3, %[dst_ptr]\n" - "mov x4, x11\n" - "31:\n" - - // Write our 8bit values to the destination - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "st1 {v16.8b}, [x3]\n" - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 41f\n" - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldrb w7, [x3, w5, uxtw]\n" - "strb w7, [x4, w5, uxtw]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "41:\n" - "add %[dst_ptr], %[dst_ptr], #8\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" - - // Cast-and-saturate from int32 to int16 - "sqxtn v16.4h, v16.4s\n" - "sqxtn2 v16.8h, v17.4s\n" - - - // Add the destination zero point - "dup v14.8h, v13.h[4]\n" - "add v16.8h, v16.8h, v14.8h\n" - - // Cast-and-saturate from int16 to uint8 - "sqxtn v16.8b, v16.8h\n" - - // Load the clamp_min, clamp_max bounds - "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.16b, w2\n" // clamp_min - "dup v15.16b, w3\n" // clamp_max - - // Apply the clamp_min bound - "smax v16.16b, v16.16b, v14.16b\n" - - // Apply the clamp_max bound - "smin v16.16b, v16.16b, v15.16b\n" - - // Make it so that all of the final 8bit values are stored in the - // first 64bits of 128bit NEON registers, so they can be stored - // by 64bit st1 store instructions with byte alignment. - "dup d20, v16.d[1]\n" - - // Compute how much of the 8x1 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x8, there are some 8x8 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #8\n" - "cmp w1, #8\n" - // Compute w1 = how many rows of the 8x1 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #8\n" - - // Test if w1==8, i.e. if all of the 8x1 block fits. - "cmp w1, w3\n" - // Yes, all of the 8x1 block fits, go to fast path. - "beq 130f\n" - // Not all of the 8x1 block fits. - // Set (x3 address, x4 stride) to write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "mov x4, #8\n" - "b 131f\n" - "130:\n" - // Yes, all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write directly to destination matrix. - "mov x3, %[dst_ptr]\n" - "mov x4, x11\n" - "131:\n" - - // Write our 8bit values to the destination - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "st1 {v16.8b}, [x3]\n" - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 141f\n" - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "150:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "151:\n" - "ldrb w7, [x3, w5, uxtw]\n" - "strb w7, [x4, w5, uxtw]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 151b\n" - "141:\n" - "add %[dst_ptr], %[dst_ptr], #8\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" - - // Add the destination zero point - "dup v14.8h, v13.h[4]\n" - "saddw v16.4s, v16.4s, v14.4h\n" - "saddw v17.4s, v17.4s, v14.4h\n" - - // Cast-and-saturate from int32 to int16 - "sqxtn v16.4h, v16.4s\n" - "sqxtn2 v16.8h, v17.4s\n" - - // Load the clamp_min, clamp_max bounds - "ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.8h, w2\n" // clamp_min - "dup v15.8h, w3\n" // clamp_max - - // Apply the clamp_min bound - "smax v16.8h, v16.8h, v14.8h\n" - // Apply the clamp_max bound - "smin v16.8h, v16.8h, v15.8h\n" - - // Compute how much of the 8x1 block of destination 16bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x8, there are some 8x1 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #8\n" - "cmp w1, #8\n" - // Compute w1 = how many rows of the 8x1 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #8\n" - - // Test if w1==8, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - // Yes, all of the 8x1 block fits, go to fast path. - "beq 230f\n" - // Not all of the 8x1 block fits. - // Set (x3 address, x4 stride) to write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "mov x4, #16\n" - "b 231f\n" - "230:\n" - // Yes, all of the 8x1 block fits. - // Set (x3 address, x4 stride) to write directly to destination matrix. - "mov x3, %[dst_ptr]\n" - "mov x4, x11\n" - "231:\n" - - // Write our 16bit values to the destination - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "st1 {v16.8h}, [x3]\n" - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - - // If all of the 8x1 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 241f\n" - // Not all of the 8x1 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "250:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "251:\n" - "ldrsh w7, [x3, x5, lsl #1]\n" - "strh w7, [x4, x5, lsl #1]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 251b\n" - "241:\n" - "add %[dst_ptr], %[dst_ptr], #16\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" - - // Since the store type is the same as the accum type, no need for - // downcast. There's also no need for clamp by min/max. - - // Compute how much of the 8x1 block of destination 32 bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x1, there are some 8x1 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #8\n" - "cmp w1, #8\n" - // Compute w1 = how many rows of the 8x1 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #8\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==8, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - // Yes, all of the 8x1 block fits, go to fast path. - "beq 330f\n" - // Not all of the 8x1 block fits. - // Set (x3 address, x4 stride) to write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "mov x4, #16\n" - - // Write our 32bit values to the destination described by - // (x3 address, x4 stride). - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v16.4s}, [x3], x4\n" - RUY_MAKE_ZERO(v16) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v17.4s}, [x3], x4\n" - RUY_MAKE_ZERO(v17) - - "b 331f\n" - - "330:\n" - // Yes, all of the 8x1 block fits. - // Set (x3 address, x4 stride) to write directly to destination matrix. - "mov x4, %[dst_ptr]\n" - "mov x3, x4\n" - - // Write our 32bit values to the destination described by - // (x3 address, x4 stride). - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v16.4s, v17.4s}, [x3], #32\n" - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - "331:\n" - - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 341f\n" - - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "350:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "mov w5, #0\n" - "351:\n" - "ldr w7, [x3, x5, lsl #2]\n" - "str w7, [x4, x5, lsl #2]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 351b\n" - "341:\n" - "add %[dst_ptr], %[dst_ptr], #32\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" - - // Reload some params --- we had used x5 -- x7 for a few other things - // since the last time we had loaded them. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - - // Move to the next block of the destination matrix, for the next iter - // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already - // been updated earlier. - // Have we reached the end row? - "cmp %w[row], w7\n" - "beq 20f\n" // yes, end row. - // Not end row. Move to the next row. - "add %w[row], %w[row], #8\n" - "b 21f\n" - "20:\n" - // Was already at end row. - "mov %w[row], w6\n" // Move back to first row. - "add %w[col], %w[col], #8\n" // Move to the next column. - "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" - "mov %[dst_ptr], %[dst_col_ptr]\n" - "21:\n" - - // Main loop exit condition: have we hit the end column? - "cmp %w[col], w8\n" - - // w1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 4. - "mov w1, #4\n" - - "ble 1b\n" - - // clang-format on - - : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), - [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), - [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) - : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), - [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf), - [dst_type_id] "r"(params.dst_type_id) - : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", - "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", - "v13", "v14", "v15", "v16", "v17"); -} - -// Variant of the above Kernel8bitNeonDotprodOutOfOrder, tuned for in-order -// CPUs. Specifically here, the relevant in-order CPUs are ARM Cortex-A55r1, -// since these are 64-bit and support dotprod. -// -// While this kernel does not have a direct equivalent in gemmlowp, it was -// developed based on insights that David Mansell at ARM shared with their -// contribution of gemmlowp kernels tuned for Cortex-A55r1, with very helpful -// comments. Specifically, see this comment about tuning for Cortex-A55r1: -// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4412 -void Kernel8bitNeonDotprodInOrder(const KernelParams8bit<8, 8>& params) { - profiler::ScopeLabel label( - "Kernel (kNeonDotprod, optimized for in-order cores)"); - - CheckOffsetsInKernelParams8bit(params); - - const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; - const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; - void* dst_col_ptr = params.dst_base_ptr; - void* dst_ptr = dst_col_ptr; - int row = params.start_row; - int col = params.start_col; - - // The asm kernel below has the following NEON register allocation: - // - // v16 -- v31 are int32 accumulators. - // During accumulation, v0 -- v3 are used to load int8 data from LHS and - // RHS. - // - // int8 RHS 4x8 block - // /-----------------------------------------\ - // |v2.b[0] ... v2.b[12] v3.b[0] ... v3.b[12]| - // | ... ... | - // |v2.b[3] ... v2.b[15] v3.b[3] ... v3.b[15]| - // \-----------------------------------------/ - // int8 LHS 8x4 block - // /---------------------\ /-----------------------------------------\ - // |v0.b[0] ... v0.b[3] | |v16.s[0] ... v30.s[0]| - // | ... ... | | ... ... | - // |v0.b[12] ... v0.b[15]| |v16.s[3] ... v30.s[3]| - // |v1.b[0] ... v1.b[3] | |v17.s[0] ... v31.s[0]| - // | ... ... | | ... ... | - // |v1.b[12] ... v1.b[15]| |v17.s[3] ... v31.s[3]| - // \---------------------/ \-----------------------------------------/ - // int32 accumulators 8x8 block - // - // There is no RUY_OPT_MAX_STREAMING 4x-unrolled part in this kernel because - // we did not observe a benefit of such partial unrolling on in-order CPUs. - // - // v4 -- v7 are unused, and v8 -- v15 are used for loading parameters used for - // the post-accumulation part of the kernel. - asm volatile( -#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" - - // clang-format off - - // Load some parameters into registers. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - RUY_MAKE_ZERO(v16) - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - RUY_MAKE_ZERO(v17) - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - RUY_MAKE_ZERO(v18) - "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - RUY_MAKE_ZERO(v19) - "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" - RUY_MAKE_ZERO(v20) - "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" - RUY_MAKE_ZERO(v21) - "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - RUY_MAKE_ZERO(v22) - "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" - - // Load the first 32 bytes of LHS and RHS data. - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" - - // Clear accumulators. - RUY_MAKE_ZERO(v23) - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - // Perform the first few multiply-adds on the data that we have already - // loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - RUY_MAKE_ZERO(v28) - ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" - RUY_MAKE_ZERO(v29) - ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" - RUY_MAKE_ZERO(v30) - ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" - RUY_MAKE_ZERO(v31) - - - "1:\n" - - "add x5, %[lhs_ptr], x12, lsl #3\n" - "sub x5, x5, #32\n" - "cmp %[lhs_ptr], x5\n" - - "beq 79f\n" - - // Main accumulation loop - "2:\n" - ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n" - "ldr x1, [%[lhs_ptr], #8]\n" - ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n" - "ldr x3, [%[rhs_ptr], #8]\n" - ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n" - "ldr x4, [%[rhs_ptr], #24]\n" - ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n" - "ldr d0, [%[lhs_ptr], #0]\n" - ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" - "ins v0.d[1], x1\n" - ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n" - "ldr x2, [%[lhs_ptr], #24]\n" - ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n" - "add %[lhs_ptr], %[lhs_ptr], #32\n" - ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n" - "ldr d2, [%[rhs_ptr], #0]\n" - ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n" - "ins v2.d[1], x3\n" - ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n" - "cmp %[lhs_ptr], x5\n" - ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n" - "add %[rhs_ptr], %[rhs_ptr], #32\n" - ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n" - "ldr d3, [%[rhs_ptr], #-16]\n" - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - "ldr d1, [%[lhs_ptr], #-16]\n" - ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" - "ins v3.d[1], x4\n" - ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" - "ins v1.d[1], x2\n" - ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" - "blt 2b\n" - - // Last accumulation steps, nothing left to load. - "79:\n" - ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n" - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n" - "cmp %w[row], w7\n" // Have we finished the last row? - ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n" - ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n" - ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" - ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n" - ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n" - ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n" - ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n" - ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n" - ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n" - ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n" - - // End of accumulation. The registers v16 -- v31 contain the final - // int32 accumulator values of the current 8x8 destination block. - // We now have to compute the final 8-bit values from these int32 - // accumulators, and advance to the next 8x8 block. We intertwine - // these two aspects whenever possible for optimal pipelining, both - // at the data flow level (prefetch data for next block as early as - // possible) and instruction pipelining level (some of the next-block - // work can dual-issue with some of the final work on the current - // block). - - // Logic to advance to the next block in preparation for the next - // iteration of the main loop. For now, we only want to compute - // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are - // not yet ready to update the values of row and col, as we still need - // the current values for the rest of the work on the current block. - - "bge 4f\n" // If finished last row, go to 4 - // Not finished last row: then advance to next row. - "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" - "b 5f\n" - "4:\n" // Finished last row... - "mov %[lhs_col_ptr], x5\n" // Go back to first row - // Now we need to advance to the next column. If we already - // finished the last column, then in principle we are done, however - // we can't just return here, as we need to allow the end work of the - // current block to complete. The good news is that at this point it - // doesn't matter what data we load for the next column, since - // we will exit from the main loop below before actually storing - // anything computed from that data. - "cmp %w[col], w8\n" // Have we finished the last column? - "bge 5f\n" // If yes, just carry on without updating the column pointer. - // Not finished last column: then advance to next column. - "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" - "5:\n" - - // Set the LHS and RHS data pointers to the start of the columns just - // computed. - "mov %[lhs_ptr], %[lhs_col_ptr]\n" - // Load some parameters needed for the end work on current block. - RUY_MAKE_ZERO(v8) - "mov %[rhs_ptr], %[rhs_col_ptr]\n" - "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" - "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" - "ins v13.h[4], w4\n" // dst_zero_point - "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" - "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "dup v9.4s, w3\n" // create prod_zp_depth_vec - "add x5, x4, %x[row], lsl #2\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "csel x4, x4, x5, eq\n" - - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" - "add x5, x1, %x[row], lsl #2\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" - "csel x1, x1, x5, eq\n" - - // Load 8 bias values. - "ld1 {v14.2s}, [x1], #8\n" - "ldr x5, [x1], #8\n" - "ins v14.d[1], x5\n" - "ld1 {v15.2s}, [x1], #8\n" - "ldr x5, [x1], #8\n" - "ins v15.d[1], x5\n" - - // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), - // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "add v14.4s, v14.4s, v9.4s\n" - "add v15.4s, v15.4s, v9.4s\n" - // Perform the bias-addition (per the above, we have just folded into - // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) - "add v16.4s, v16.4s, v14.4s\n" - "add v17.4s, v17.4s, v15.4s\n" - "add v18.4s, v18.4s, v14.4s\n" - "add v19.4s, v19.4s, v15.4s\n" - "add v20.4s, v20.4s, v14.4s\n" - "add v21.4s, v21.4s, v15.4s\n" - "add v22.4s, v22.4s, v14.4s\n" - "add v23.4s, v23.4s, v15.4s\n" - "add v24.4s, v24.4s, v14.4s\n" - "add v25.4s, v25.4s, v15.4s\n" - "add v26.4s, v26.4s, v14.4s\n" - "add v27.4s, v27.4s, v15.4s\n" - "add v28.4s, v28.4s, v14.4s\n" - "add v29.4s, v29.4s, v15.4s\n" - "add v30.4s, v30.4s, v14.4s\n" - "add v31.4s, v31.4s, v15.4s\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" - "beq 401f\n" - "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" - "add x3, x3, %x[col], lsl #2\n" - "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" - "dup v10.4s, w5\n" // create lhs_zero_point_vec - // Load 8 rhs_sums values. - "ld1 {v14.2s}, [x3], #8\n" - "ldr x7, [x3], #8\n" - "ld1 {v15.2s}, [x3], #8\n" - "ins v14.d[1], x7\n" - "ldr x7, [x3], #8\n" - "ins v15.d[1], x7\n" - // Subtract rhs_sums * lhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "mls v16.4s, v10.4s, v14.s[0]\n" - "mls v17.4s, v10.4s, v14.s[0]\n" - "mls v18.4s, v10.4s, v14.s[1]\n" - "mls v19.4s, v10.4s, v14.s[1]\n" - "mls v20.4s, v10.4s, v14.s[2]\n" - "mls v21.4s, v10.4s, v14.s[2]\n" - "mls v22.4s, v10.4s, v14.s[3]\n" - "mls v23.4s, v10.4s, v14.s[3]\n" - "mls v24.4s, v10.4s, v15.s[0]\n" - "mls v25.4s, v10.4s, v15.s[0]\n" - "mls v26.4s, v10.4s, v15.s[1]\n" - "mls v27.4s, v10.4s, v15.s[1]\n" - "mls v28.4s, v10.4s, v15.s[2]\n" - "mls v29.4s, v10.4s, v15.s[2]\n" - "mls v30.4s, v10.4s, v15.s[3]\n" - "mls v31.4s, v10.4s, v15.s[3]\n" - "401:\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" - "beq 402f\n" - "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" - "add x2, x2, %x[row], lsl #2\n" - "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" - "ins v13.s[1], w5\n" // rhs_zero_point - // Load 8 lhs_sums values. - "ld1 {v11.2s}, [x2], #8\n" - "ldr x6, [x2], #8\n" - "ins v11.d[1], x6\n" - "ld1 {v12.2s}, [x2], #8\n" - "ldr x6, [x2], #8\n" - "ins v12.d[1], x6\n" - // Compute lhs_sums * rhs_zero_point. - "mul v11.4s, v11.4s, v13.s[1]\n" - "mul v12.4s, v12.4s, v13.s[1]\n" - // Subtract lhs_sums * rhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "sub v16.4s, v16.4s, v11.4s\n" - "sub v17.4s, v17.4s, v12.4s\n" - "sub v18.4s, v18.4s, v11.4s\n" - "sub v19.4s, v19.4s, v12.4s\n" - "sub v20.4s, v20.4s, v11.4s\n" - "sub v21.4s, v21.4s, v12.4s\n" - "sub v22.4s, v22.4s, v11.4s\n" - "sub v23.4s, v23.4s, v12.4s\n" - "sub v24.4s, v24.4s, v11.4s\n" - "sub v25.4s, v25.4s, v12.4s\n" - "sub v26.4s, v26.4s, v11.4s\n" - "sub v27.4s, v27.4s, v12.4s\n" - "sub v28.4s, v28.4s, v11.4s\n" - "sub v29.4s, v29.4s, v12.4s\n" - "sub v30.4s, v30.4s, v11.4s\n" - "sub v31.4s, v31.4s, v12.4s\n" - - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" - - "402:\n" - - // At this point we have computed the final int32 values. Now we - // start down-quantizing them to obtain the final 8bit values from them. - - // As part of this down-quantization, our int32 values will be - // multiplied by a multiplier that has a fixed-point component and an - // exponent component. - - //Load the exponent part of the multiplier. - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" - "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "add x5, x1, %x[row], lsl #2\n" - "csel x1, x1, x5, eq\n" - - "ldr q9, [x1]\n" - "ldr q10, [x1, #16]\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_NEEDS_LEFT_SHIFT) "\n" - "beq 403f\n" - "smax v11.4s, v9.4s, v8.4s\n" - "smax v12.4s, v10.4s, v8.4s\n" - "sshl v16.4s, v16.4s, v11.4s\n" - "sshl v17.4s, v17.4s, v12.4s\n" - "sshl v18.4s, v18.4s, v11.4s\n" - "sshl v19.4s, v19.4s, v12.4s\n" - "sshl v20.4s, v20.4s, v11.4s\n" - "sshl v21.4s, v21.4s, v12.4s\n" - "sshl v22.4s, v22.4s, v11.4s\n" - "sshl v23.4s, v23.4s, v12.4s\n" - "sshl v24.4s, v24.4s, v11.4s\n" - "sshl v25.4s, v25.4s, v12.4s\n" - "sshl v26.4s, v26.4s, v11.4s\n" - "sshl v27.4s, v27.4s, v12.4s\n" - "sshl v28.4s, v28.4s, v11.4s\n" - "sshl v29.4s, v29.4s, v12.4s\n" - "sshl v30.4s, v30.4s, v11.4s\n" - "sshl v31.4s, v31.4s, v12.4s\n" - "403:\n" - - "ldr q14, [x4]\n" // multiplier_fixedpoint - "ldr q15, [x4, #16]\n" // multiplier_fixedpoint - - "smin v11.4s, v9.4s, v8.4s\n" - "smin v12.4s, v10.4s, v8.4s\n" - - // Apply the fixed-point part of the multiplier. - // - // ... and, interleaved into that: - // Now that we know what LHS and RHS data the next iteration of the - // main loop will need to load, we start loading the first 32 bytes of - // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore - // in the rest of the work on the current block. - "ld1 {v0.8b}, [%[lhs_ptr]], #8\n" - "sqrdmulh v16.4s, v16.4s, v14.4s\n" - "ldr x1, [%[lhs_ptr]], #8\n" - "sqrdmulh v17.4s, v17.4s, v15.4s\n" - "ld1 {v1.8b}, [%[lhs_ptr]], #8\n" - "sqrdmulh v18.4s, v18.4s, v14.4s\n" - "ldr x2, [%[lhs_ptr]], #8\n" - "sqrdmulh v19.4s, v19.4s, v15.4s\n" - "ld1 {v2.8b}, [%[rhs_ptr]], #8\n" - "sqrdmulh v20.4s, v20.4s, v14.4s\n" - "ldr x5, [%[rhs_ptr]], #8\n" - "sqrdmulh v21.4s, v21.4s, v15.4s\n" - "ld1 {v3.8b}, [%[rhs_ptr]], #8\n" - "sqrdmulh v22.4s, v22.4s, v14.4s\n" - "ldr x6, [%[rhs_ptr]], #8\n" - "sqrdmulh v23.4s, v23.4s, v15.4s\n" - "sqrdmulh v24.4s, v24.4s, v14.4s\n" - "sqrdmulh v25.4s, v25.4s, v15.4s\n" - "sqrdmulh v26.4s, v26.4s, v14.4s\n" - "sqrdmulh v27.4s, v27.4s, v15.4s\n" - "sqrdmulh v28.4s, v28.4s, v14.4s\n" - "sqrdmulh v29.4s, v29.4s, v15.4s\n" - "sqrdmulh v30.4s, v30.4s, v14.4s\n" - "sqrdmulh v31.4s, v31.4s, v15.4s\n" - - // We have some rounding division-by-power-of-two to do. This should - // always use "round to nearest". We allow for some - // freedom in how ties are broken, to strike a good compromise of - // performance on given hardware vs. perfect agreement of results - // across hardware. - // - // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation - // defined tie-breaks to help performance. On NEON, this means that we - // can just use the NEON rounding instructions, such as srshl. They - // happen to be breaking ties upward. - // - // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict - // break-ties-away-from zero, as described in Appendix B of - // https://arxiv.org/pdf/1712.05877.pdf - // When we wrote that, we thought that that would be better unbiased - // than the NEON upwards tie-breaks, and we had observed some - // improvement on some model. However, that is only more unbiased for - // data centered at zero, which was likely the case in that model, - // but is not always the case. If we wanted something more consistently - // unbiased then we should try breaking ties toward-nearest-even. -#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) - // Fix up values to be right-shifted, so that the (round to nearest, - // break ties upward) behavior of srshl applied to these fixed-up - // values, produces the same result as the desired (round to nearest, - // break ties away from zero) behavior on the original values. - "and v8.16b, v16.16b, v11.16b\n" - "and v9.16b, v17.16b, v12.16b\n" - "and v14.16b, v18.16b, v11.16b\n" - "and v15.16b, v19.16b, v12.16b\n" - "sshr v8.4s, v8.4s, #31\n" - "sshr v9.4s, v9.4s, #31\n" - "sshr v14.4s, v14.4s, #31\n" - "sshr v15.4s, v15.4s, #31\n" - "sqadd v16.4s, v16.4s, v8.4s\n" - "sqadd v17.4s, v17.4s, v9.4s\n" - "sqadd v18.4s, v18.4s, v14.4s\n" - "sqadd v19.4s, v19.4s, v15.4s\n" - "and v8.16b, v20.16b, v11.16b\n" - "and v9.16b, v21.16b, v12.16b\n" - "and v14.16b, v22.16b, v11.16b\n" - "and v15.16b, v23.16b, v12.16b\n" - "sshr v8.4s, v8.4s, #31\n" - "sshr v9.4s, v9.4s, #31\n" - "sshr v14.4s, v14.4s, #31\n" - "sshr v15.4s, v15.4s, #31\n" - "sqadd v20.4s, v20.4s, v8.4s\n" - "sqadd v21.4s, v21.4s, v9.4s\n" - "sqadd v22.4s, v22.4s, v14.4s\n" - "sqadd v23.4s, v23.4s, v15.4s\n" - "and v8.16b, v24.16b, v11.16b\n" - "and v9.16b, v25.16b, v12.16b\n" - "and v14.16b, v26.16b, v11.16b\n" - "and v15.16b, v27.16b, v12.16b\n" - "sshr v8.4s, v8.4s, #31\n" - "sshr v9.4s, v9.4s, #31\n" - "sshr v14.4s, v14.4s, #31\n" - "sshr v15.4s, v15.4s, #31\n" - "sqadd v24.4s, v24.4s, v8.4s\n" - "sqadd v25.4s, v25.4s, v9.4s\n" - "sqadd v26.4s, v26.4s, v14.4s\n" - "sqadd v27.4s, v27.4s, v15.4s\n" - "and v8.16b, v28.16b, v11.16b\n" - "and v9.16b, v29.16b, v12.16b\n" - "and v14.16b, v30.16b, v11.16b\n" - "and v15.16b, v31.16b, v12.16b\n" - "sshr v8.4s, v8.4s, #31\n" - "sshr v9.4s, v9.4s, #31\n" - "sshr v14.4s, v14.4s, #31\n" - "sshr v15.4s, v15.4s, #31\n" - "sqadd v28.4s, v28.4s, v8.4s\n" - "sqadd v29.4s, v29.4s, v9.4s\n" - "sqadd v30.4s, v30.4s, v14.4s\n" - "sqadd v31.4s, v31.4s, v15.4s\n" -#endif - // At this point we have reduced the problem of correctly implementing - // rounding divide-by-power-of-two, to what the SRSHL instruction can - // do. - "srshl v16.4s, v16.4s, v11.4s\n" - "srshl v17.4s, v17.4s, v12.4s\n" - "srshl v18.4s, v18.4s, v11.4s\n" - "srshl v19.4s, v19.4s, v12.4s\n" - "srshl v20.4s, v20.4s, v11.4s\n" - "srshl v21.4s, v21.4s, v12.4s\n" - "srshl v22.4s, v22.4s, v11.4s\n" - "srshl v23.4s, v23.4s, v12.4s\n" - "srshl v24.4s, v24.4s, v11.4s\n" - "srshl v25.4s, v25.4s, v12.4s\n" - "srshl v26.4s, v26.4s, v11.4s\n" - "srshl v27.4s, v27.4s, v12.4s\n" - "ins v0.d[1], x1\n" - "srshl v28.4s, v28.4s, v11.4s\n" - "ins v1.d[1], x2\n" - "srshl v29.4s, v29.4s, v12.4s\n" - "ins v2.d[1], x5\n" - "srshl v30.4s, v30.4s, v11.4s\n" - "ins v3.d[1], x6\n" - "srshl v31.4s, v31.4s, v12.4s\n" - - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" - - // Cast-and-saturate from int32 to int16 - "sqxtn v16.4h, v16.4s\n" - "sqxtn2 v16.8h, v17.4s\n" - "sqxtn v17.4h, v18.4s\n" - "sqxtn2 v17.8h, v19.4s\n" - "sqxtn v18.4h, v20.4s\n" - "sqxtn2 v18.8h, v21.4s\n" - "sqxtn v19.4h, v22.4s\n" - "sqxtn2 v19.8h, v23.4s\n" - "sqxtn v20.4h, v24.4s\n" - "sqxtn2 v20.8h, v25.4s\n" - "sqxtn v21.4h, v26.4s\n" - "sqxtn2 v21.8h, v27.4s\n" - "sqxtn v22.4h, v28.4s\n" - "sqxtn2 v22.8h, v29.4s\n" - "sqxtn v23.4h, v30.4s\n" - "sqxtn2 v23.8h, v31.4s\n" - - // Destination zero_point - "dup v14.8h, v13.h[4]\n" - // At this point, v24 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // Add the destination zero point - "add v16.8h, v16.8h, v14.8h\n" - "add v17.8h, v17.8h, v14.8h\n" - "add v18.8h, v18.8h, v14.8h\n" - "add v19.8h, v19.8h, v14.8h\n" - "add v20.8h, v20.8h, v14.8h\n" - "add v21.8h, v21.8h, v14.8h\n" - "add v22.8h, v22.8h, v14.8h\n" - "add v23.8h, v23.8h, v14.8h\n" - - // Load the clamp_min, clamp_max bounds - "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - // Cast-and-saturate from int16 to uint8 - "sqxtun v16.8b, v16.8h\n" - "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "sqxtun2 v16.16b, v17.8h\n" - "sqxtun v17.8b, v18.8h\n" - "sqxtun2 v17.16b, v19.8h\n" - "sqxtun v18.8b, v20.8h\n" - "sqxtun2 v18.16b, v21.8h\n" - "sqxtun v19.8b, v22.8h\n" - "sqxtun2 v19.16b, v23.8h\n" - - "dup v14.16b, w2\n" // clamp_min - "dup v15.16b, w3\n" // clamp_max - - // Compute how much of the 8x8 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x8, there are some 8x8 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - // Apply the clamp_min bound - "umax v16.16b, v16.16b, v14.16b\n" - "sub w2, %w[dst_cols], %w[col]\n" - "umax v17.16b, v17.16b, v14.16b\n" - "mov w3, #8\n" - "umax v18.16b, v18.16b, v14.16b\n" - "cmp w1, #8\n" - "umax v19.16b, v19.16b, v14.16b\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w1, w1, w3, le\n" - // Apply the clamp_max bound - "umin v16.16b, v16.16b, v15.16b\n" - "cmp w2, #8\n" - "umin v17.16b, v17.16b, v15.16b\n" - // Compute w2 = how many cols of the 8x8 block fit - "csel w2, w2, w3, le\n" - "umin v18.16b, v18.16b, v15.16b\n" - "umin v19.16b, v19.16b, v15.16b\n" - - // Make it so that all of the final 8bit values are stored in the - // first 64bits of 128bit NEON registers, so they can be stored - // by 64bit st1 store instructions with byte alignment. - "dup d20, v16.d[1]\n" - "dup d21, v17.d[1]\n" - "dup d22, v18.d[1]\n" - "dup d23, v19.d[1]\n" - - // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - // Yes, all of the 8x8 block fits, go to fast path. - "beq 30f\n" - // Not all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "mov x4, #8\n" - "b 31f\n" - "30:\n" - // Yes, all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write directly to destination matrix. - "mov x3, %[dst_ptr]\n" - "mov x4, x11\n" - "31:\n" - - // Write our 8bit values to the destination described by - // (x3 address, x4 stride). - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v16.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v16) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v20.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v20) - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v17.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v17) - ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v21.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v21) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v18.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v18) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v22.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v22) - ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v19.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v19) - ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v23.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v23) - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 41f\n" - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldrb w7, [x3, w5, uxtw]\n" - "strb w7, [x4, w5, uxtw]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "add w6, w6, #1\n" - "add x3, x3, #8\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 50b\n" - "41:\n" - "add %[dst_ptr], %[dst_ptr], #8\n" - - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" - - // Cast-and-saturate from int32 to int16 - "sqxtn v16.4h, v16.4s\n" - "sqxtn2 v16.8h, v17.4s\n" - "sqxtn v17.4h, v18.4s\n" - "sqxtn2 v17.8h, v19.4s\n" - "sqxtn v18.4h, v20.4s\n" - "sqxtn2 v18.8h, v21.4s\n" - "sqxtn v19.4h, v22.4s\n" - "sqxtn2 v19.8h, v23.4s\n" - "sqxtn v20.4h, v24.4s\n" - "sqxtn2 v20.8h, v25.4s\n" - "sqxtn v21.4h, v26.4s\n" - "sqxtn2 v21.8h, v27.4s\n" - "sqxtn v22.4h, v28.4s\n" - "sqxtn2 v22.8h, v29.4s\n" - "sqxtn v23.4h, v30.4s\n" - "sqxtn2 v23.8h, v31.4s\n" - - // Destination zero_point - "dup v14.8h, v13.h[4]\n" - // At this point, v24 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // Add the destination zero point - "add v16.8h, v16.8h, v14.8h\n" - "add v17.8h, v17.8h, v14.8h\n" - "add v18.8h, v18.8h, v14.8h\n" - "add v19.8h, v19.8h, v14.8h\n" - "add v20.8h, v20.8h, v14.8h\n" - "add v21.8h, v21.8h, v14.8h\n" - "add v22.8h, v22.8h, v14.8h\n" - "add v23.8h, v23.8h, v14.8h\n" - - // Load the clamp_min, clamp_max bounds - "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - // Cast-and-saturate from int16 to uint8 - "sqxtn v16.8b, v16.8h\n" - "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "sqxtn2 v16.16b, v17.8h\n" - "sqxtn v17.8b, v18.8h\n" - "sqxtn2 v17.16b, v19.8h\n" - "sqxtn v18.8b, v20.8h\n" - "sqxtn2 v18.16b, v21.8h\n" - "sqxtn v19.8b, v22.8h\n" - "sqxtn2 v19.16b, v23.8h\n" - - "dup v14.16b, w2\n" // clamp_min - "dup v15.16b, w3\n" // clamp_max - - // Compute how much of the 8x8 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x8, there are some 8x8 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - // Apply the clamp_min bound - "smax v16.16b, v16.16b, v14.16b\n" - "sub w2, %w[dst_cols], %w[col]\n" - "smax v17.16b, v17.16b, v14.16b\n" - "mov w3, #8\n" - "smax v18.16b, v18.16b, v14.16b\n" - "cmp w1, #8\n" - "smax v19.16b, v19.16b, v14.16b\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w1, w1, w3, le\n" - // Apply the clamp_max bound - "smin v16.16b, v16.16b, v15.16b\n" - "cmp w2, #8\n" - "smin v17.16b, v17.16b, v15.16b\n" - // Compute w2 = how many cols of the 8x8 block fit - "csel w2, w2, w3, le\n" - "smin v18.16b, v18.16b, v15.16b\n" - "smin v19.16b, v19.16b, v15.16b\n" - - // Make it so that all of the final 8bit values are stored in the - // first 64bits of 128bit NEON registers, so they can be stored - // by 64bit st1 store instructions with byte alignment. - "dup d20, v16.d[1]\n" - "dup d21, v17.d[1]\n" - "dup d22, v18.d[1]\n" - "dup d23, v19.d[1]\n" - - // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - // Yes, all of the 8x8 block fits, go to fast path. - "beq 130f\n" - // Not all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "mov x4, #8\n" - "b 131f\n" - "130:\n" - // Yes, all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write directly to destination matrix. - "mov x3, %[dst_ptr]\n" - "mov x4, x11\n" - "131:\n" - - // Write our 8bit values to the destination described by - // (x3 address, x4 stride). - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v16.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v16) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v20.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v20) - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v17.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v17) - ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v21.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v21) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v18.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v18) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v22.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v22) - ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v19.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v19) - ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v23.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v23) - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 141f\n" - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "150:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "151:\n" - "ldrb w7, [x3, w5, uxtw]\n" - "strb w7, [x4, w5, uxtw]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 151b\n" - "add w6, w6, #1\n" - "add x3, x3, #8\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 150b\n" - "141:\n" - "add %[dst_ptr], %[dst_ptr], #8\n" - - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" - - // Add the destination zero point - "dup v14.8h, v13.h[4]\n" - "saddw v16.4s, v16.4s, v14.4h\n" - "saddw v17.4s, v17.4s, v14.4h\n" - "saddw v18.4s, v18.4s, v14.4h\n" - "saddw v19.4s, v19.4s, v14.4h\n" - "saddw v20.4s, v20.4s, v14.4h\n" - "saddw v21.4s, v21.4s, v14.4h\n" - "saddw v22.4s, v22.4s, v14.4h\n" - "saddw v23.4s, v23.4s, v14.4h\n" - "saddw v24.4s, v24.4s, v14.4h\n" - "saddw v25.4s, v25.4s, v14.4h\n" - "saddw v26.4s, v26.4s, v14.4h\n" - "saddw v27.4s, v27.4s, v14.4h\n" - "saddw v28.4s, v28.4s, v14.4h\n" - "saddw v29.4s, v29.4s, v14.4h\n" - "saddw v30.4s, v30.4s, v14.4h\n" - "saddw v31.4s, v31.4s, v14.4h\n" - - // Cast-and-saturate from int32 to int16 - "sqxtn v16.4h, v16.4s\n" - "sqxtn2 v16.8h, v17.4s\n" - "sqxtn v17.4h, v18.4s\n" - "sqxtn2 v17.8h, v19.4s\n" - "sqxtn v18.4h, v20.4s\n" - "sqxtn2 v18.8h, v21.4s\n" - "sqxtn v19.4h, v22.4s\n" - "sqxtn2 v19.8h, v23.4s\n" - "sqxtn v20.4h, v24.4s\n" - "sqxtn2 v20.8h, v25.4s\n" - "sqxtn v21.4h, v26.4s\n" - "sqxtn2 v21.8h, v27.4s\n" - "sqxtn v22.4h, v28.4s\n" - "sqxtn2 v22.8h, v29.4s\n" - "sqxtn v23.4h, v30.4s\n" - "sqxtn2 v23.8h, v31.4s\n" - - // At this point, v24 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // Load the clamp_min, clamp_max bounds - "ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.8h, w2\n" // clamp_min - "dup v15.8h, w3\n" // clamp_max - - // Apply the clamp_min bound - "smax v16.8h, v16.8h, v14.8h\n" - "smax v17.8h, v17.8h, v14.8h\n" - "smax v18.8h, v18.8h, v14.8h\n" - "smax v19.8h, v19.8h, v14.8h\n" - "smax v20.8h, v20.8h, v14.8h\n" - "smax v21.8h, v21.8h, v14.8h\n" - "smax v22.8h, v22.8h, v14.8h\n" - "smax v23.8h, v23.8h, v14.8h\n" - // Apply the clamp_max bound - "smin v16.8h, v16.8h, v15.8h\n" - "smin v17.8h, v17.8h, v15.8h\n" - "smin v18.8h, v18.8h, v15.8h\n" - "smin v19.8h, v19.8h, v15.8h\n" - "smin v20.8h, v20.8h, v15.8h\n" - "smin v21.8h, v21.8h, v15.8h\n" - "smin v22.8h, v22.8h, v15.8h\n" - "smin v23.8h, v23.8h, v15.8h\n" - - // Compute how much of the 8x8 block of destination 16bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x8, there are some 8x8 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #8\n" - "cmp w1, #8\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #8\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - // Yes, all of the 8x8 block fits, go to fast path. - "beq 230f\n" - // Not all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "mov x4, #16\n" - "b 231f\n" - "230:\n" - // Yes, all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write directly to destination matrix. - "mov x3, %[dst_ptr]\n" - "mov x4, x11\n" - "231:\n" - - // Write our 8bit values to the destination described by - // (x3 address, x4 stride). - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v16.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v16) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v17.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v17) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v18.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v18) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v19.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v19) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v20.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v20) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v21.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v21) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v22.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v22) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v23.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v23) - - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" - ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" - ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 241f\n" - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "250:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "251:\n" - "ldrsh w7, [x3, x5, lsl #1]\n" - "strh w7, [x4, x5, lsl #1]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 251b\n" - "add w6, w6, #1\n" - "add x3, x3, #16\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 250b\n" - "241:\n" - "add %[dst_ptr], %[dst_ptr], #16\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" - - "ld1 {v0.8b}, [%[lhs_ptr]], #8\n" - "ldr x1, [%[lhs_ptr]], #8\n" - "ld1 {v1.8b}, [%[lhs_ptr]], #8\n" - "ldr x2, [%[lhs_ptr]], #8\n" - "ld1 {v2.8b}, [%[rhs_ptr]], #8\n" - "ldr x5, [%[rhs_ptr]], #8\n" - "ld1 {v3.8b}, [%[rhs_ptr]], #8\n" - "ldr x6, [%[rhs_ptr]], #8\n" - "ins v0.d[1], x1\n" - "ins v1.d[1], x2\n" - "ins v2.d[1], x5\n" - "ins v3.d[1], x6\n" - - // Since the store type is the same as the accum type, no need for - // downcast. There's also no need for clamp by min/max. - - // Compute how much of the 8x8 block of destination 32it values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x8, there are some 8x8 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #8\n" - "cmp w1, #8\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #8\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - // Yes, all of the 8x8 block fits, go to fast path. - "beq 330f\n" - // Not all of the 8x8 block fits. - // Write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "st1 {v16.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v16) - "st1 {v17.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v17) - "st1 {v18.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v18) - "st1 {v19.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v19) - "st1 {v20.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v20) - "st1 {v21.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v21) - "st1 {v22.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v22) - "st1 {v23.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v23) - "st1 {v24.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v24) - "st1 {v25.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v25) - "st1 {v26.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v26) - "st1 {v27.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v27) - "st1 {v28.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v28) - "st1 {v29.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v29) - "st1 {v30.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v30) - "st1 {v31.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v31) - - "b 331f\n" - - "330:\n" - // Yes, all of the 8x8 block fits. - "mov x4, %[dst_ptr]\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "st1 {v16.4s, v17.4s}, [x4], x11\n" - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "st1 {v18.4s, v19.4s}, [x4], x11\n" - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "st1 {v20.4s, v21.4s}, [x4], x11\n" - RUY_MAKE_ZERO(v20) - RUY_MAKE_ZERO(v21) - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "st1 {v22.4s, v23.4s}, [x4], x11\n" - RUY_MAKE_ZERO(v22) - RUY_MAKE_ZERO(v23) - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "st1 {v24.4s, v25.4s}, [x4], x11\n" - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "st1 {v26.4s, v27.4s}, [x4], x11\n" - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "st1 {v28.4s, v29.4s}, [x4], x11\n" - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "st1 {v30.4s, v31.4s}, [x4], x11\n" - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - "331:\n" - - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" - ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" - ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 341f\n" - - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "350:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "351:\n" - "ldr w7, [x3, x5, lsl #2]\n" - "str w7, [x4, x5, lsl #2]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 351b\n" - "add w6, w6, #1\n" - "add x3, x3, #32\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 350b\n" - "341:\n" - "add %[dst_ptr], %[dst_ptr], #32\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" - - // Reload some params --- we had used x5 -- x7 for a few other things - // since the last time we had loaded them. - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - - // Move to the next block of the destination matrix, for the next iter - // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already - // been updated earlier. - // Have we reached the end row? - "cmp %w[row], w7\n" - "beq 20f\n" // yes, end row. - // Not end row. Move to the next row. - "add %w[row], %w[row], #8\n" - "b 21f\n" - "20:\n" - // Was already at end row. - "mov %w[row], w6\n" // Move back to first row. - "add %w[col], %w[col], #8\n" // Move to the next column. - "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" - "mov %[dst_ptr], %[dst_col_ptr]\n" - "21:\n" - - // Main loop exit condition: have we hit the end column? - "cmp %w[col], w8\n" - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ble 1b\n" - - // clang-format on - - : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), - [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), - [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) - : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), - [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf), - [dst_type_id] "r"(params.dst_type_id) - : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", - "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", - "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", - "v26", "v27", "v28", "v29", "v30", "v31"); -} -#undef RUY_OFFSET_BIAS -#undef RUY_OFFSET_LHS_SUMS -#undef RUY_OFFSET_RHS_SUMS -#undef RUY_OFFSET_LHS_BASE_PTR -#undef RUY_OFFSET_MULTIPLIER_FIXEDPOINT -#undef RUY_OFFSET_MULTIPLIER_EXPONENT -#undef RUY_OFFSET_RHS_BASE_PTR -#undef RUY_OFFSET_DST_BASE_PTR -#undef RUY_OFFSET_LHS_ZERO_POINT -#undef RUY_OFFSET_RHS_ZERO_POINT -#undef RUY_OFFSET_DST_ZERO_POINT -#undef RUY_OFFSET_PROD_ZP_DEPTH -#undef RUY_OFFSET_START_ROW -#undef RUY_OFFSET_START_COL -#undef RUY_OFFSET_LAST_ROW -#undef RUY_OFFSET_LAST_COL -#undef RUY_OFFSET_DST_ROWS -#undef RUY_OFFSET_DST_COLS -#undef RUY_OFFSET_LHS_STRIDE -#undef RUY_OFFSET_RHS_STRIDE -#undef RUY_OFFSET_DST_STRIDE -#undef RUY_OFFSET_DEPTH -#undef RUY_OFFSET_CLAMP_MIN -#undef RUY_OFFSET_CLAMP_MAX -#undef RUY_OFFSET_FLAGS - -#define RUY_OFFSET_LHS_BASE_PTR 0 -#define RUY_OFFSET_RHS_BASE_PTR 8 -#define RUY_OFFSET_DST_BASE_PTR 16 -#define RUY_OFFSET_BIAS 24 -#define RUY_OFFSET_START_ROW 32 -#define RUY_OFFSET_START_COL 36 -#define RUY_OFFSET_LAST_ROW 40 -#define RUY_OFFSET_LAST_COL 44 -#define RUY_OFFSET_LHS_STRIDE 56 -#define RUY_OFFSET_RHS_STRIDE 60 -#define RUY_OFFSET_DST_STRIDE 64 -#define RUY_OFFSET_DEPTH 68 -#define RUY_OFFSET_CLAMP_MIN 72 -#define RUY_OFFSET_CLAMP_MAX 76 -#define RUY_OFFSET_FLAGS 80 - -template -void CheckOffsetsInKernelParamsFloat(const Params&) { - static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, ""); - static_assert(offsetof(Params, rhs_base_ptr) == RUY_OFFSET_RHS_BASE_PTR, ""); - static_assert(offsetof(Params, dst_base_ptr) == RUY_OFFSET_DST_BASE_PTR, ""); - static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, ""); - static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, ""); - static_assert(offsetof(Params, start_col) == RUY_OFFSET_START_COL, ""); - static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, ""); - static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, ""); - static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, ""); - static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, ""); - static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, ""); - static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, ""); - static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, ""); - static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, ""); - static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, ""); -} - -// Just a plain float kernel; good enough for out-of-order cores. -// The closest to it in the gemmlowp collection would be -// NEON_64bit_GEMM_Float32_WithScalar, -// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L3925 -// -// Besides ruy-ification, the main nuance here is that we stick to a 8x8 -// width instead of the wider 12x8 that the register space permits and that -// the aforementioned gemmlowp kernel uses. Ruy likes powers of two for now -// and we don't have evidence that going beyond 8x8 is needed. -void KernelFloatNeonOutOfOrder(const KernelParamsFloat<8, 8>& params) { - CheckOffsetsInKernelParamsFloat(params); - profiler::ScopeLabel label( - "Kernel (kNeon, optimized for out-of-order cores)"); - - const float* lhs_col_ptr = params.lhs_base_ptr; - const float* rhs_col_ptr = params.rhs_base_ptr; - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr; - float* dst_col_ptr = params.dst_base_ptr; - float* dst_ptr = dst_col_ptr; - int row = params.start_row; - int col = params.start_col; - - // The asm kernel below has the following NEON register allocation: - // - // v16 -- v31 are accumulators. - // During accumulation, v0 -- v15 are used to load data from LHS and RHS. - // At least v0 and v1 are used to load a 8x1 block of LHS, and v2 and - // v3 are used to load a 1x8 block of RHS, like this: - // - // RHS 1x8 block - // /-----------------------------------------\ - // |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]| - // \-----------------------------------------/ - // LHS 8x1 block - // /---------------------\ /-----------------------------------------\ - // | v0.s[0] | |v16.s[0] ... v30.s[0]| - // | ... | | ... ... | - // | v0.s[3] | |v16.s[3] ... v30.s[3]| - // | v1.s[0] | |v17.s[0] ... v31.s[0]| - // | ... | | ... ... | - // | v1.s[3] | |v17.s[3] ... v31.s[3]| - // \---------------------/ \-----------------------------------------/ - // accumulators 8x8 block - // - // In the RUY_OPT_MAX_STREAMING part of the kernel, this elementary step - // is repeated 4 times, using 4x more registers for LHS and RHS, so that - // is where instead of using v0 -- v3 for LHS and RHS, we use v0 -- v15. - // - // Outside of the RUY_OPT_MAX_STREAMING part of the kernel, v4 -- v7 are - // unused, and v8 -- v15 are used for floading parameters used for the - // post-accumulation part of the kernel. - asm volatile( -#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" - - // clang-format off - - // Load some parameters into registers. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" - "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" - "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" - - // Load the first 32 bytes of LHS and RHS data. - "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" - "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" - - // Clear accumulators. - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - RUY_MAKE_ZERO(v20) - RUY_MAKE_ZERO(v21) - RUY_MAKE_ZERO(v22) - RUY_MAKE_ZERO(v23) - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // w1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 1. - "mov w1, #1\n" - - // Main loop of the whole GEMM, over rows and columns of the - // destination matrix. - "1:\n" - - "fmla v16.4s, v0.4s, v2.s[0]\n" - "fmla v18.4s, v0.4s, v2.s[1]\n" - "fmla v20.4s, v0.4s, v2.s[2]\n" - "fmla v22.4s, v0.4s, v2.s[3]\n" - -#if RUY_OPT_ENABLED(RUY_OPT_MAX_STREAMING) - "cmp w12, #8\n" - "blt 78f\n" - "and w2, w12, #-4\n" - - "ld1 {v4.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v5.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v6.4s}, [%[rhs_ptr]], #16\n" - "ld1 {v7.4s}, [%[rhs_ptr]], #16\n" - - "ld1 {v8.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v9.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v10.4s}, [%[rhs_ptr]], #16\n" - "ld1 {v11.4s}, [%[rhs_ptr]], #16\n" - - "ld1 {v12.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v13.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v14.4s}, [%[rhs_ptr]], #16\n" - "ld1 {v15.4s}, [%[rhs_ptr]], #16\n" - "mov w1, #4\n" - - "80:\n" - - "add %[lhs_ptr], %[lhs_ptr], #128\n" - "add %[rhs_ptr], %[rhs_ptr], #128\n" - - "fmla v24.4s, v0.4s, v3.s[0]\n" - "fmla v26.4s, v0.4s, v3.s[1]\n" - "fmla v28.4s, v0.4s, v3.s[2]\n" - "fmla v30.4s, v0.4s, v3.s[3]\n" - "ldr q0, [%[lhs_ptr], #-128]\n" - "fmla v25.4s, v1.4s, v3.s[0]\n" - "fmla v27.4s, v1.4s, v3.s[1]\n" - "fmla v29.4s, v1.4s, v3.s[2]\n" - "fmla v31.4s, v1.4s, v3.s[3]\n" - "ldr q3, [%[rhs_ptr], #-112]\n" - "fmla v17.4s, v1.4s, v2.s[0]\n" - "fmla v19.4s, v1.4s, v2.s[1]\n" - "fmla v21.4s, v1.4s, v2.s[2]\n" - "fmla v23.4s, v1.4s, v2.s[3]\n" - "ldr q1, [%[lhs_ptr], #-112]\n" - "fmla v16.4s, v4.4s, v6.s[0]\n" - "fmla v18.4s, v4.4s, v6.s[1]\n" - "ldr q2, [%[rhs_ptr], #-128]\n" - "fmla v20.4s, v4.4s, v6.s[2]\n" - "fmla v22.4s, v4.4s, v6.s[3]\n" - - "fmla v24.4s, v4.4s, v7.s[0]\n" - "fmla v26.4s, v4.4s, v7.s[1]\n" - "fmla v28.4s, v4.4s, v7.s[2]\n" - "fmla v30.4s, v4.4s, v7.s[3]\n" - "ldr q4, [%[lhs_ptr], #-96]\n" - "fmla v25.4s, v5.4s, v7.s[0]\n" - "fmla v27.4s, v5.4s, v7.s[1]\n" - "fmla v29.4s, v5.4s, v7.s[2]\n" - "fmla v31.4s, v5.4s, v7.s[3]\n" - "ldr q7, [%[rhs_ptr], #-80]\n" - "fmla v17.4s, v5.4s, v6.s[0]\n" - "fmla v19.4s, v5.4s, v6.s[1]\n" - "fmla v21.4s, v5.4s, v6.s[2]\n" - "fmla v23.4s, v5.4s, v6.s[3]\n" - "ldr q5, [%[lhs_ptr], #-80]\n" - "fmla v16.4s, v8.4s, v10.s[0]\n" - "fmla v18.4s, v8.4s, v10.s[1]\n" - "ldr q6, [%[rhs_ptr], #-96]\n" - "fmla v20.4s, v8.4s, v10.s[2]\n" - "fmla v22.4s, v8.4s, v10.s[3]\n" - - "fmla v24.4s, v8.4s, v11.s[0]\n" - "fmla v26.4s, v8.4s, v11.s[1]\n" - "fmla v28.4s, v8.4s, v11.s[2]\n" - "fmla v30.4s, v8.4s, v11.s[3]\n" - "ldr q8, [%[lhs_ptr], #-64]\n" - "fmla v25.4s, v9.4s, v11.s[0]\n" - "fmla v27.4s, v9.4s, v11.s[1]\n" - "fmla v29.4s, v9.4s, v11.s[2]\n" - "fmla v31.4s, v9.4s, v11.s[3]\n" - "ldr q11, [%[rhs_ptr], #-48]\n" - "fmla v17.4s, v9.4s, v10.s[0]\n" - "fmla v19.4s, v9.4s, v10.s[1]\n" - "fmla v21.4s, v9.4s, v10.s[2]\n" - "fmla v23.4s, v9.4s, v10.s[3]\n" - "ldr q9, [%[lhs_ptr], #-48]\n" - "fmla v16.4s, v12.4s, v14.s[0]\n" - "fmla v18.4s, v12.4s, v14.s[1]\n" - "ldr q10, [%[rhs_ptr], #-64]\n" - "fmla v20.4s, v12.4s, v14.s[2]\n" - "fmla v22.4s, v12.4s, v14.s[3]\n" - - "fmla v24.4s, v12.4s, v15.s[0]\n" - "fmla v26.4s, v12.4s, v15.s[1]\n" - "fmla v28.4s, v12.4s, v15.s[2]\n" - "fmla v30.4s, v12.4s, v15.s[3]\n" - "ldr q12, [%[lhs_ptr], #-32]\n" - "fmla v25.4s, v13.4s, v15.s[0]\n" - "fmla v27.4s, v13.4s, v15.s[1]\n" - "fmla v29.4s, v13.4s, v15.s[2]\n" - "fmla v31.4s, v13.4s, v15.s[3]\n" - "ldr q15, [%[rhs_ptr], #-16]\n" - "fmla v17.4s, v13.4s, v14.s[0]\n" - "fmla v19.4s, v13.4s, v14.s[1]\n" - "fmla v21.4s, v13.4s, v14.s[2]\n" - "fmla v23.4s, v13.4s, v14.s[3]\n" - "ldr q13, [%[lhs_ptr], #-16]\n" - "fmla v16.4s, v0.4s, v2.s[0]\n" - "fmla v18.4s, v0.4s, v2.s[1]\n" - "ldr q14, [%[rhs_ptr], #-32]\n" - "fmla v20.4s, v0.4s, v2.s[2]\n" - "fmla v22.4s, v0.4s, v2.s[3]\n" - - "add w1, w1, #4\n" - "cmp w1, w2\n" - "blt 80b\n" - - "fmla v16.4s, v4.4s, v6.s[0]\n" - "fmla v18.4s, v4.4s, v6.s[1]\n" - "fmla v20.4s, v4.4s, v6.s[2]\n" - "fmla v22.4s, v4.4s, v6.s[3]\n" - "fmla v24.4s, v4.4s, v7.s[0]\n" - "fmla v26.4s, v4.4s, v7.s[1]\n" - "fmla v28.4s, v4.4s, v7.s[2]\n" - "fmla v30.4s, v4.4s, v7.s[3]\n" - "fmla v25.4s, v5.4s, v7.s[0]\n" - "fmla v27.4s, v5.4s, v7.s[1]\n" - "fmla v29.4s, v5.4s, v7.s[2]\n" - "fmla v31.4s, v5.4s, v7.s[3]\n" - "fmla v17.4s, v5.4s, v6.s[0]\n" - "fmla v19.4s, v5.4s, v6.s[1]\n" - "fmla v21.4s, v5.4s, v6.s[2]\n" - "fmla v23.4s, v5.4s, v6.s[3]\n" - - "fmla v16.4s, v8.4s, v10.s[0]\n" - "fmla v18.4s, v8.4s, v10.s[1]\n" - "fmla v20.4s, v8.4s, v10.s[2]\n" - "fmla v22.4s, v8.4s, v10.s[3]\n" - "fmla v24.4s, v8.4s, v11.s[0]\n" - "fmla v26.4s, v8.4s, v11.s[1]\n" - "fmla v28.4s, v8.4s, v11.s[2]\n" - "fmla v30.4s, v8.4s, v11.s[3]\n" - "fmla v25.4s, v9.4s, v11.s[0]\n" - "fmla v27.4s, v9.4s, v11.s[1]\n" - "fmla v29.4s, v9.4s, v11.s[2]\n" - "fmla v31.4s, v9.4s, v11.s[3]\n" - "fmla v17.4s, v9.4s, v10.s[0]\n" - "fmla v19.4s, v9.4s, v10.s[1]\n" - "fmla v21.4s, v9.4s, v10.s[2]\n" - "fmla v23.4s, v9.4s, v10.s[3]\n" - - "fmla v16.4s, v12.4s, v14.s[0]\n" - "fmla v18.4s, v12.4s, v14.s[1]\n" - "fmla v20.4s, v12.4s, v14.s[2]\n" - "fmla v22.4s, v12.4s, v14.s[3]\n" - "fmla v24.4s, v12.4s, v15.s[0]\n" - "fmla v26.4s, v12.4s, v15.s[1]\n" - "fmla v28.4s, v12.4s, v15.s[2]\n" - "fmla v30.4s, v12.4s, v15.s[3]\n" - "fmla v25.4s, v13.4s, v15.s[0]\n" - "fmla v27.4s, v13.4s, v15.s[1]\n" - "fmla v29.4s, v13.4s, v15.s[2]\n" - "fmla v31.4s, v13.4s, v15.s[3]\n" - "fmla v17.4s, v13.4s, v14.s[0]\n" - "fmla v19.4s, v13.4s, v14.s[1]\n" - "fmla v21.4s, v13.4s, v14.s[2]\n" - "fmla v23.4s, v13.4s, v14.s[3]\n" - - "78:\n" -#endif - - // Accumulation loop - "cmp w1, w12\n" - "beq 79f\n" - - "2:\n" - "fmla v24.4s, v0.4s, v3.s[0]\n" - "fmla v26.4s, v0.4s, v3.s[1]\n" - "ld1 {v4.4s}, [%[rhs_ptr]], #16\n" - "fmla v28.4s, v0.4s, v3.s[2]\n" - "fmla v30.4s, v0.4s, v3.s[3]\n" - "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" - "fmla v25.4s, v1.4s, v3.s[0]\n" - "fmla v27.4s, v1.4s, v3.s[1]\n" - "add w1, w1, #1\n" - "fmla v29.4s, v1.4s, v3.s[2]\n" - "fmla v31.4s, v1.4s, v3.s[3]\n" - "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" - "fmla v17.4s, v1.4s, v2.s[0]\n" - "fmla v19.4s, v1.4s, v2.s[1]\n" - "cmp w1, w12\n" - "fmla v21.4s, v1.4s, v2.s[2]\n" - "fmla v23.4s, v1.4s, v2.s[3]\n" - "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" - "fmla v16.4s, v0.4s, v4.s[0]\n" - "fmla v18.4s, v0.4s, v4.s[1]\n" - "mov v2.16b, v4.16b\n" - "fmla v20.4s, v0.4s, v4.s[2]\n" - "fmla v22.4s, v0.4s, v4.s[3]\n" - "blt 2b\n" - - "79:\n" - - // End of the inner loop on depth. Now perform the remaining - // multiply-adds of the last level of depth, for which the LHS - // and RHS data is already loaded. - - "fmla v24.4s, v0.4s, v3.s[0]\n" - "fmla v26.4s, v0.4s, v3.s[1]\n" - "fmla v28.4s, v0.4s, v3.s[2]\n" - "fmla v30.4s, v0.4s, v3.s[3]\n" - "fmla v25.4s, v1.4s, v3.s[0]\n" - "fmla v27.4s, v1.4s, v3.s[1]\n" - "fmla v29.4s, v1.4s, v3.s[2]\n" - "fmla v31.4s, v1.4s, v3.s[3]\n" - "fmla v17.4s, v1.4s, v2.s[0]\n" - "fmla v19.4s, v1.4s, v2.s[1]\n" - "fmla v21.4s, v1.4s, v2.s[2]\n" - "fmla v23.4s, v1.4s, v2.s[3]\n" - - // End of accumulation. The registers v16 -- v31 contain the final - // int32 accumulator values of the current 8x8 destination block. - // We now have to compute the final 8-bit values from these int32 - // accumulators, and advance to the next 8x8 block. We intertwine - // these two aspects whenever possible for optimal pipelining, both - // at the data flow level (prefetch data for next block as early as - // possible) and instruction pipelining level (some of the next-block - // work can dual-issue with some of the final work on the current - // block). - - // Logic to advance to the next block in preparation for the next - // iteration of the main loop. For now, we only want to compute - // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are - // not yet ready to update the values of row and col, as we still need - // the current values for the rest of the work on the current block. - - "cmp %w[row], w7\n" // Have we finished the last row? - "bge 4f\n" // If finished last row, go to 4 - // Not finished last row: then advance to next row. - "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" - "b 5f\n" - "4:\n" // Finished last row... - "mov %[lhs_col_ptr], x5\n" // Go back to first row - // Now we need to advance to the next column. If we already - // finished the last column, then in principle we are done, however - // we can't just return here, as we need to allow the end work of the - // current block to complete. The good news is that at this point it - // doesn't matter what data we load for the next column, since - // we will exit from the main loop below before actually storing - // anything computed from that data. - "cmp %w[col], w8\n" // Have we finished the last column? - "bge 5f\n" // If yes, just carry on without updating the column pointer. - // Not finished last column: then advance to next column. - "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" - "5:\n" - - // Set the LHS and RHS data pointers to the start of the columns just - // computed. - "mov %[lhs_ptr], %[lhs_col_ptr]\n" - "mov %[rhs_ptr], %[rhs_col_ptr]\n" - - // Load some parameters needed for the end work on current block. - "ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" - - // Offset these base pointers as needed given the current row, col. - "add x5, x1, %x[row], lsl #2\n" - - "tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" - "csel x1, x1, x5, eq\n" - - // Load 8 bias values. - "ld1 {v14.4s}, [x1], #16\n" - "ld1 {v15.4s}, [x1]\n" - - // Now that we know what LHS and RHS data the next iteration of the - // main loop will need to load, we start loading the first 32 bytes of - // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore - // in the rest of the work on the current block. - "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" - "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" - - // Perform the bias-addition (per the above, we have just folded into - // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) - "fadd v16.4s, v16.4s, v14.4s\n" - "fadd v17.4s, v17.4s, v15.4s\n" - "fadd v18.4s, v18.4s, v14.4s\n" - "fadd v19.4s, v19.4s, v15.4s\n" - "fadd v20.4s, v20.4s, v14.4s\n" - "fadd v21.4s, v21.4s, v15.4s\n" - "fadd v22.4s, v22.4s, v14.4s\n" - "fadd v23.4s, v23.4s, v15.4s\n" - "fadd v24.4s, v24.4s, v14.4s\n" - "fadd v25.4s, v25.4s, v15.4s\n" - "fadd v26.4s, v26.4s, v14.4s\n" - "fadd v27.4s, v27.4s, v15.4s\n" - "fadd v28.4s, v28.4s, v14.4s\n" - "fadd v29.4s, v29.4s, v15.4s\n" - "fadd v30.4s, v30.4s, v14.4s\n" - "fadd v31.4s, v31.4s, v15.4s\n" - - // Load the clamp_min, clamp_max bounds - "ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.4s, w2\n" // clamp_min - "dup v15.4s, w3\n" // clamp_max - - // Apply the clamp_min bound - "fmax v16.4s, v16.4s, v14.4s\n" - "fmax v17.4s, v17.4s, v14.4s\n" - "fmax v18.4s, v18.4s, v14.4s\n" - "fmax v19.4s, v19.4s, v14.4s\n" - "fmax v20.4s, v20.4s, v14.4s\n" - "fmax v21.4s, v21.4s, v14.4s\n" - "fmax v22.4s, v22.4s, v14.4s\n" - "fmax v23.4s, v23.4s, v14.4s\n" - "fmax v24.4s, v24.4s, v14.4s\n" - "fmax v25.4s, v25.4s, v14.4s\n" - "fmax v26.4s, v26.4s, v14.4s\n" - "fmax v27.4s, v27.4s, v14.4s\n" - "fmax v28.4s, v28.4s, v14.4s\n" - "fmax v29.4s, v29.4s, v14.4s\n" - "fmax v30.4s, v30.4s, v14.4s\n" - "fmax v31.4s, v31.4s, v14.4s\n" - - // Apply the clamp_max bound - "fmin v16.4s, v16.4s, v15.4s\n" - "fmin v17.4s, v17.4s, v15.4s\n" - "fmin v18.4s, v18.4s, v15.4s\n" - "fmin v19.4s, v19.4s, v15.4s\n" - "fmin v20.4s, v20.4s, v15.4s\n" - "fmin v21.4s, v21.4s, v15.4s\n" - "fmin v22.4s, v22.4s, v15.4s\n" - "fmin v23.4s, v23.4s, v15.4s\n" - "fmin v24.4s, v24.4s, v15.4s\n" - "fmin v25.4s, v25.4s, v15.4s\n" - "fmin v26.4s, v26.4s, v15.4s\n" - "fmin v27.4s, v27.4s, v15.4s\n" - "fmin v28.4s, v28.4s, v15.4s\n" - "fmin v29.4s, v29.4s, v15.4s\n" - "fmin v30.4s, v30.4s, v15.4s\n" - "fmin v31.4s, v31.4s, v15.4s\n" - - // Compute how much of the 8x8 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x8, there are some 8x8 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #8\n" - "cmp w1, #8\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #8\n" - // Compute w2 = how many cols of the 8x8 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - // Yes, all of the 8x8 block fits, go to fast path. - "beq 30f\n" - // Not all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "mov x4, #32\n" - "b 31f\n" - "30:\n" - // Yes, all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write directly to destination matrix. - "mov x3, %[dst_ptr]\n" - "mov x4, x11\n" - "31:\n" - - // Write our 8bit values to the destination described by - // (x3 address, x4 stride). - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "str q16, [x3, #0]\n" - "str q17, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - "str q18, [x3, #0]\n" - "str q19, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - "str q20, [x3, #0]\n" - "str q21, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v20) - RUY_MAKE_ZERO(v21) - "str q22, [x3, #0]\n" - "str q23, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v22) - RUY_MAKE_ZERO(v23) - "str q24, [x3, #0]\n" - "str q25, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - "str q26, [x3, #0]\n" - "str q27, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - "str q28, [x3, #0]\n" - "str q29, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - "str q30, [x3, #0]\n" - "str q31, [x3, #16]\n" - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 41f\n" - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldr w7, [x3, x5, lsl #2]\n" - "str w7, [x4, x5, lsl #2]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "add w6, w6, #1\n" - "add x3, x3, #32\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 50b\n" - "41:\n" - "add %[dst_ptr], %[dst_ptr], #32\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - // Reload some params --- we had used x5 -- x7 for a few other things - // since the last time we had loaded them. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - - // Move to the next block of the destination matrix, for the next iter - // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already - // been updated earlier. - // Have we reached the end row? - "cmp %w[row], w7\n" - "beq 20f\n" // yes, end row. - // Not end row. Move to the next row. - "add %w[row], %w[row], #8\n" - "b 21f\n" - "20:\n" - // Was already at end row. - "mov %w[row], w6\n" // Move back to first row. - "add %w[col], %w[col], #8\n" // Move to the next column. - "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" - "mov %[dst_ptr], %[dst_col_ptr]\n" - "21:\n" - - // Main loop exit condition: have we hit the end column? - "cmp %w[col], w8\n" - - // w1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 1. - "mov w1, #1\n" - - "ble 1b\n" - - // clang-format on - - : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), - [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), - [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) - : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), - [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf) - : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", - "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", - "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", - "v26", "v27", "v28", "v29", "v30", "v31"); -} - -// Variant of KernelFloatNeonOutOfOrder tuned for in-order CPUs that do not -// support dotprod (while dotprod by itself is not relevant to floating-point, -// this additional bit of information that we have about the target happens to -// be useful here). -// -// So a typical target CPU here would be ARM Cortex-A53 or the original -// Cortex-A55. -// -// This kernel is similar to and inspired by gemmlowp's -// NEON_64bit_GEMM_Float32_WithScalar_A53. -// which was contributed by David Mansell with very helpful -// comments. Specifically, see this comment about tuning for Cortex-A53: -// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4215 -void KernelFloatNeonInOrder(const KernelParamsFloat<8, 8>& params) { - profiler::ScopeLabel label("Kernel (kNeon, optimized for in-order cores)"); - - CheckOffsetsInKernelParamsFloat(params); - - const float* lhs_col_ptr = params.lhs_base_ptr; - const float* rhs_col_ptr = params.rhs_base_ptr; - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr; - float* dst_col_ptr = params.dst_base_ptr; - float* dst_ptr = dst_col_ptr; - int row = params.start_row; - int col = params.start_col; - - // The asm kernel below has the following NEON register allocation: - // - // v16 -- v31 are accumulators. - // During accumulation, v0 -- v3 are used to load data from LHS and RHS. - // - // RHS 1x8 block - // /-----------------------------------------\ - // |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]| - // \-----------------------------------------/ - // LHS 8x1 block - // /---------------------\ /-----------------------------------------\ - // | v0.s[0] | |v16.s[0] ... v30.s[0]| - // | ... | | ... ... | - // | v0.s[3] | |v16.s[3] ... v30.s[3]| - // | v1.s[0] | |v17.s[0] ... v31.s[0]| - // | ... | | ... ... | - // | v1.s[3] | |v17.s[3] ... v31.s[3]| - // \---------------------/ \-----------------------------------------/ - // accumulators 8x8 block - // - // There is no RUY_OPT_MAX_STREAMING 4x-unrolled part in this kernel because - // we did not observe a benefit of such partial unrolling on in-order CPUs. - // - // v4 -- v7 are unused, and v8 -- v15 are used for floading parameters used - // for the post-accumulation part of the kernel. - asm volatile( -#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" - - // clang-format off - - // Load some parameters into registers. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" - "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" - "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" - - - // Clear accumulators. - RUY_MAKE_ZERO(v16) - // Load the first 32 bytes of LHS and RHS data. - "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" - RUY_MAKE_ZERO(v17) - "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" - RUY_MAKE_ZERO(v18) - "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" - RUY_MAKE_ZERO(v19) - "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" - RUY_MAKE_ZERO(v20) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #64]\n") - RUY_MAKE_ZERO(v21) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #64]\n") - RUY_MAKE_ZERO(v22) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #128]\n") - RUY_MAKE_ZERO(v23) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #128]\n") - RUY_MAKE_ZERO(v24) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #192]\n") - RUY_MAKE_ZERO(v25) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #192]\n") - RUY_MAKE_ZERO(v26) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n") - RUY_MAKE_ZERO(v27) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n") - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // w1 is the number of levels of depth that remain to load - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently depth - 1. - "sub w1, w12, #1\n" - - // Main loop of the whole GEMM, over rows and columns of the - // destination matrix. - "1:\n" - - "cmp w1, #0\n" - "fmla v16.4s, v0.4s, v2.s[0]\n" - "fmla v18.4s, v0.4s, v2.s[1]\n" - "fmla v20.4s, v0.4s, v2.s[2]\n" - "fmla v22.4s, v0.4s, v2.s[3]\n" - - // Accumulation loop - "beq 79f\n" - - "2:\n" - - "fmla v24.4s, v0.4s, v3.s[0]\n" - "ldr x2, [%[lhs_ptr], #8]\n" - "fmla v26.4s, v0.4s, v3.s[1]\n" - "ldr x3, [%[lhs_ptr], #24]\n" - "fmla v28.4s, v0.4s, v3.s[2]\n" - "ldr x5, [%[rhs_ptr], #24]\n" - "fmla v30.4s, v0.4s, v3.s[3]\n" - "ldr x4, [%[rhs_ptr], #8]\n" - "fmla v25.4s, v1.4s, v3.s[0]\n" - "subs w1, w1, #1\n" - "ldr d0, [%[lhs_ptr]], #32\n" - "fmla v27.4s, v1.4s, v3.s[1]\n" - "fmla v29.4s, v1.4s, v3.s[2]\n" - "fmla v31.4s, v1.4s, v3.s[3]\n" - "ins v0.d[1], x2\n" - "ldr d3, [%[rhs_ptr], #16]\n" - "fmla v17.4s, v1.4s, v2.s[0]\n" - "fmla v19.4s, v1.4s, v2.s[1]\n" - "ins v3.d[1], x5\n" - "ldr d4, [%[rhs_ptr]], #32\n" - "fmla v21.4s, v1.4s, v2.s[2]\n" - "fmla v23.4s, v1.4s, v2.s[3]\n" - "fmla v16.4s, v0.4s, v4.s[0]\n" - "ins v4.d[1], x4\n" - "ldr d1, [%[lhs_ptr], #-16]\n" - "fmla v18.4s, v0.4s, v4.s[1]\n" - "fmla v20.4s, v0.4s, v4.s[2]\n" - "ins v1.d[1], x3\n" - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n") - "mov v2.16b, v4.16b\n" - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n") - "fmla v22.4s, v0.4s, v4.s[3]\n" - "bne 2b\n" - - "79:\n" - - // End of the inner loop on depth. Now perform the remaining - // multiply-adds of the last level of depth, for which the LHS - // and RHS data is already loaded. - - "fmla v24.4s, v0.4s, v3.s[0]\n" - "fmla v26.4s, v0.4s, v3.s[1]\n" - "fmla v28.4s, v0.4s, v3.s[2]\n" - "fmla v30.4s, v0.4s, v3.s[3]\n" - "fmla v25.4s, v1.4s, v3.s[0]\n" - "fmla v27.4s, v1.4s, v3.s[1]\n" - "fmla v29.4s, v1.4s, v3.s[2]\n" - "fmla v31.4s, v1.4s, v3.s[3]\n" - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "fmla v17.4s, v1.4s, v2.s[0]\n" - "fmla v19.4s, v1.4s, v2.s[1]\n" - "fmla v21.4s, v1.4s, v2.s[2]\n" - "fmla v23.4s, v1.4s, v2.s[3]\n" - - // End of accumulation. The registers v16 -- v31 contain the final - // int32 accumulator values of the current 8x8 destination block. - // We now have to compute the final 8-bit values from these int32 - // accumulators, and advance to the next 8x8 block. We intertwine - // these two aspects whenever possible for optimal pipelining, both - // at the data flow level (prefetch data for next block as early as - // possible) and instruction pipelining level (some of the next-block - // work can dual-issue with some of the final work on the current - // block). - - // Logic to advance to the next block in preparation for the next - // iteration of the main loop. For now, we only want to compute - // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are - // not yet ready to update the values of row and col, as we still need - // the current values for the rest of the work on the current block. - - "cmp %w[row], w7\n" // Have we finished the last row? - "bge 4f\n" // If finished last row, go to 4 - // Not finished last row: then advance to next row. - "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" - "b 5f\n" - "4:\n" // Finished last row... - "mov %[lhs_col_ptr], x5\n" // Go back to first row - // Now we need to advance to the next column. If we already - // finished the last column, then in principle we are done, however - // we can't just return here, as we need to allow the end work of the - // current block to complete. The good news is that at this point it - // doesn't matter what data we load for the next column, since - // we will exit from the main loop below before actually storing - // anything computed from that data. - "cmp %w[col], w8\n" // Have we finished the last column? - "bge 5f\n" // If yes, just carry on without updating the column pointer. - // Not finished last column: then advance to next column. - "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" - "5:\n" - - // Set the LHS and RHS data pointers to the start of the columns just - // computed. - "mov %[lhs_ptr], %[lhs_col_ptr]\n" - "mov %[rhs_ptr], %[rhs_col_ptr]\n" - - // Load some parameters needed for the end work on current block. - "ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" - - // Offset these base pointers as needed given the current row, col. - "add x5, x1, %x[row], lsl #2\n" - - "tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" - "csel x1, x1, x5, eq\n" - - // Load 8 bias values. - "ld1 {v14.4s}, [x1], #16\n" - "ld1 {v15.4s}, [x1]\n" - - // Now that we know what LHS and RHS data the next iteration of the - // main loop will need to load, we start loading the first 32 bytes of - // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore - // in the rest of the work on the current block. - "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" - "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" - - // Perform the bias-addition (per the above, we have just folded into - // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) - "fadd v16.4s, v16.4s, v14.4s\n" - "fadd v17.4s, v17.4s, v15.4s\n" - "fadd v18.4s, v18.4s, v14.4s\n" - "fadd v19.4s, v19.4s, v15.4s\n" - "fadd v20.4s, v20.4s, v14.4s\n" - "fadd v21.4s, v21.4s, v15.4s\n" - "fadd v22.4s, v22.4s, v14.4s\n" - "fadd v23.4s, v23.4s, v15.4s\n" - "fadd v24.4s, v24.4s, v14.4s\n" - "fadd v25.4s, v25.4s, v15.4s\n" - "fadd v26.4s, v26.4s, v14.4s\n" - "fadd v27.4s, v27.4s, v15.4s\n" - "fadd v28.4s, v28.4s, v14.4s\n" - "fadd v29.4s, v29.4s, v15.4s\n" - "fadd v30.4s, v30.4s, v14.4s\n" - "fadd v31.4s, v31.4s, v15.4s\n" - - // Load the clamp_min, clamp_max bounds - "ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.4s, w2\n" // clamp_min - "dup v15.4s, w3\n" // clamp_max - - // Apply the clamp_min bound - "fmax v16.4s, v16.4s, v14.4s\n" - "fmax v17.4s, v17.4s, v14.4s\n" - "fmax v18.4s, v18.4s, v14.4s\n" - "fmax v19.4s, v19.4s, v14.4s\n" - "fmax v20.4s, v20.4s, v14.4s\n" - "fmax v21.4s, v21.4s, v14.4s\n" - "fmax v22.4s, v22.4s, v14.4s\n" - "fmax v23.4s, v23.4s, v14.4s\n" - "fmax v24.4s, v24.4s, v14.4s\n" - "fmax v25.4s, v25.4s, v14.4s\n" - "fmax v26.4s, v26.4s, v14.4s\n" - "fmax v27.4s, v27.4s, v14.4s\n" - "fmax v28.4s, v28.4s, v14.4s\n" - "fmax v29.4s, v29.4s, v14.4s\n" - "fmax v30.4s, v30.4s, v14.4s\n" - "fmax v31.4s, v31.4s, v14.4s\n" - - // Apply the clamp_max bound - "fmin v16.4s, v16.4s, v15.4s\n" - "fmin v17.4s, v17.4s, v15.4s\n" - "fmin v18.4s, v18.4s, v15.4s\n" - "fmin v19.4s, v19.4s, v15.4s\n" - "fmin v20.4s, v20.4s, v15.4s\n" - "fmin v21.4s, v21.4s, v15.4s\n" - "fmin v22.4s, v22.4s, v15.4s\n" - "fmin v23.4s, v23.4s, v15.4s\n" - "fmin v24.4s, v24.4s, v15.4s\n" - "fmin v25.4s, v25.4s, v15.4s\n" - "fmin v26.4s, v26.4s, v15.4s\n" - "fmin v27.4s, v27.4s, v15.4s\n" - "fmin v28.4s, v28.4s, v15.4s\n" - "fmin v29.4s, v29.4s, v15.4s\n" - "fmin v30.4s, v30.4s, v15.4s\n" - "fmin v31.4s, v31.4s, v15.4s\n" - - // Compute how much of the 8x8 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x8, there are some 8x8 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #8\n" - "cmp w1, #8\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #8\n" - // Compute w2 = how many cols of the 8x8 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - // Yes, all of the 8x8 block fits, go to fast path. - "beq 30f\n" - // Not all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "mov x4, #32\n" - "b 31f\n" - "30:\n" - // Yes, all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write directly to destination matrix. - "mov x3, %[dst_ptr]\n" - "mov x4, x11\n" - "31:\n" - - // Write our 8bit values to the destination described by - // (x3 address, x4 stride). - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "str q16, [x3, #0]\n" - "str q17, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - "str q18, [x3, #0]\n" - "str q19, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - "str q20, [x3, #0]\n" - "str q21, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v20) - RUY_MAKE_ZERO(v21) - "str q22, [x3, #0]\n" - "str q23, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v22) - RUY_MAKE_ZERO(v23) - "str q24, [x3, #0]\n" - "str q25, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - "str q26, [x3, #0]\n" - "str q27, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - "str q28, [x3, #0]\n" - "str q29, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - "str q30, [x3, #0]\n" - "str q31, [x3, #16]\n" - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 41f\n" - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldr w7, [x3, x5, lsl #2]\n" - "str w7, [x4, x5, lsl #2]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "add w6, w6, #1\n" - "add x3, x3, #32\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 50b\n" - "41:\n" - "add %[dst_ptr], %[dst_ptr], #32\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - // Reload some params --- we had used x5 -- x7 for a few other things - // since the last time we had loaded them. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - - // Move to the next block of the destination matrix, for the next iter - // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already - // been updated earlier. - // Have we reached the end row? - "cmp %w[row], w7\n" - "beq 20f\n" // yes, end row. - // Not end row. Move to the next row. - "add %w[row], %w[row], #8\n" - "b 21f\n" - "20:\n" - // Was already at end row. - "mov %w[row], w6\n" // Move back to first row. - "add %w[col], %w[col], #8\n" // Move to the next column. - "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" - "mov %[dst_ptr], %[dst_col_ptr]\n" - "21:\n" - - // Main loop exit condition: have we hit the end column? - "cmp %w[col], w8\n" - - // w1 is the number of levels of depth that remain to load - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently depth - 1. - "sub w1, w12, #1\n" - - "ble 1b\n" - - // clang-format on - - : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), - [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), - [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) - : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), - [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf) - : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", - "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", - "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", - "v26", "v27", "v28", "v29", "v30", "v31"); -} - -// Variant of KernelFloatNeonInOrder tuned for in-order CPUs that do -// support dotprod (while dotprod by itself is not relevant to floating-point, -// this additional bit of information that we have about the target happens to -// be useful here). -// -// So a typical target CPU here would be ARM Cortex-A55r1. -// -// This kernel is similar to and inspired by gemmlowp's -// NEON_64bit_GEMM_Float32_WithScalar_A55r1. -// which was contributed by David Mansell with very helpful -// comments. Specifically, see this comment about tuning for Cortex-A55r1: -// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4412 -void KernelFloatNeonDotprodInOrder(const KernelParamsFloat<8, 8>& params) { - profiler::ScopeLabel label( - "Kernel (kNeonDotprod, optimized for in-order cores)"); - - CheckOffsetsInKernelParamsFloat(params); - - const float* lhs_col_ptr = params.lhs_base_ptr; - const float* rhs_col_ptr = params.rhs_base_ptr; - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr; - float* dst_col_ptr = params.dst_base_ptr; - float* dst_ptr = dst_col_ptr; - int row = params.start_row; - int col = params.start_col; - - // The asm kernel below has the following NEON register allocation: - // - // v16 -- v31 are accumulators. - // During accumulation, v0 -- v3 are used to load data from LHS and RHS. - // - // RHS 1x8 block - // /-----------------------------------------\ - // |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]| - // \-----------------------------------------/ - // LHS 8x1 block - // /---------------------\ /-----------------------------------------\ - // | v0.s[0] | |v16.s[0] ... v30.s[0]| - // | ... | | ... ... | - // | v0.s[3] | |v16.s[3] ... v30.s[3]| - // | v1.s[0] | |v17.s[0] ... v31.s[0]| - // | ... | | ... ... | - // | v1.s[3] | |v17.s[3] ... v31.s[3]| - // \---------------------/ \-----------------------------------------/ - // accumulators 8x8 block - // - // There is no RUY_OPT_MAX_STREAMING 4x-unrolled part in this kernel because - // we did not observe a benefit of such partial unrolling on in-order CPUs. - // - // v4 -- v7 are unused, and v8 -- v15 are used for floading parameters used - // for the post-accumulation part of the kernel. - asm volatile( -#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" - - // clang-format off - - // Load some parameters into registers. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" - "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" - "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" - - - // Clear accumulators. - RUY_MAKE_ZERO(v16) - // Load the first 32 bytes of LHS and RHS data. - "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" - RUY_MAKE_ZERO(v17) - "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" - RUY_MAKE_ZERO(v18) - "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" - RUY_MAKE_ZERO(v19) - "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" - RUY_MAKE_ZERO(v20) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #64]\n") - RUY_MAKE_ZERO(v21) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #64]\n") - RUY_MAKE_ZERO(v22) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #128]\n") - RUY_MAKE_ZERO(v23) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #128]\n") - RUY_MAKE_ZERO(v24) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #192]\n") - RUY_MAKE_ZERO(v25) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #192]\n") - RUY_MAKE_ZERO(v26) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n") - RUY_MAKE_ZERO(v27) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n") - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // w1 is the number of levels of depth that remain to load - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently depth - 1. - "sub w1, w12, #1\n" - - // Main loop of the whole GEMM, over rows and columns of the - // destination matrix. - "1:\n" - - "cmp w1, #0\n" - "fmla v16.4s, v0.4s, v2.s[0]\n" - "fmla v18.4s, v0.4s, v2.s[1]\n" - "fmla v20.4s, v0.4s, v2.s[2]\n" - "fmla v22.4s, v0.4s, v2.s[3]\n" - - // Accumulation loop - "beq 79f\n" - - "2:\n" - - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n") - "fmla v24.4s, v0.4s, v3.s[0]\n" - "ldr x2, [%[lhs_ptr], #8]\n" - "fmla v26.4s, v0.4s, v3.s[1]\n" - "ldr x3, [%[lhs_ptr], #24]\n" - "fmla v28.4s, v0.4s, v3.s[2]\n" - "ldr x5, [%[rhs_ptr], #24]\n" - "fmla v30.4s, v0.4s, v3.s[3]\n" - "ldr d0, [%[lhs_ptr]], #32\n" - "fmla v25.4s, v1.4s, v3.s[0]\n" - "ldr x4, [%[rhs_ptr], #8]\n" - "fmla v27.4s, v1.4s, v3.s[1]\n" - "subs w1, w1, #1\n" - "fmla v29.4s, v1.4s, v3.s[2]\n" - "ins v0.d[1], x2\n" - "fmla v31.4s, v1.4s, v3.s[3]\n" - "ldr d3, [%[rhs_ptr], #16]\n" - "fmla v17.4s, v1.4s, v2.s[0]\n" - "ins v3.d[1], x5\n" - "fmla v19.4s, v1.4s, v2.s[1]\n" - "ldr d4, [%[rhs_ptr]], #32\n" - "fmla v21.4s, v1.4s, v2.s[2]\n" - "ins v4.d[1], x4\n" - "fmla v23.4s, v1.4s, v2.s[3]\n" - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n") - "fmla v16.4s, v0.4s, v4.s[0]\n" - "ldr d1, [%[lhs_ptr], #-16]\n" - "fmla v18.4s, v0.4s, v4.s[1]\n" - "ins v1.d[1], x3\n" - "fmla v20.4s, v0.4s, v4.s[2]\n" - "mov v2.16b, v4.16b\n" - "fmla v22.4s, v0.4s, v4.s[3]\n" - "bne 2b\n" - - "79:\n" - - // End of the inner loop on depth. Now perform the remaining - // multiply-adds of the last level of depth, for which the LHS - // and RHS data is already loaded. - - "fmla v24.4s, v0.4s, v3.s[0]\n" - "fmla v26.4s, v0.4s, v3.s[1]\n" - "fmla v28.4s, v0.4s, v3.s[2]\n" - "fmla v30.4s, v0.4s, v3.s[3]\n" - "fmla v25.4s, v1.4s, v3.s[0]\n" - "fmla v27.4s, v1.4s, v3.s[1]\n" - "fmla v29.4s, v1.4s, v3.s[2]\n" - "fmla v31.4s, v1.4s, v3.s[3]\n" - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "fmla v17.4s, v1.4s, v2.s[0]\n" - "fmla v19.4s, v1.4s, v2.s[1]\n" - "fmla v21.4s, v1.4s, v2.s[2]\n" - "fmla v23.4s, v1.4s, v2.s[3]\n" - - // End of accumulation. The registers v16 -- v31 contain the final - // int32 accumulator values of the current 8x8 destination block. - // We now have to compute the final 8-bit values from these int32 - // accumulators, and advance to the next 8x8 block. We intertwine - // these two aspects whenever possible for optimal pipelining, both - // at the data flow level (prefetch data for next block as early as - // possible) and instruction pipelining level (some of the next-block - // work can dual-issue with some of the final work on the current - // block). - - // Logic to advance to the next block in preparation for the next - // iteration of the main loop. For now, we only want to compute - // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are - // not yet ready to update the values of row and col, as we still need - // the current values for the rest of the work on the current block. - - "cmp %w[row], w7\n" // Have we finished the last row? - "bge 4f\n" // If finished last row, go to 4 - // Not finished last row: then advance to next row. - "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" - "b 5f\n" - "4:\n" // Finished last row... - "mov %[lhs_col_ptr], x5\n" // Go back to first row - // Now we need to advance to the next column. If we already - // finished the last column, then in principle we are done, however - // we can't just return here, as we need to allow the end work of the - // current block to complete. The good news is that at this point it - // doesn't matter what data we load for the next column, since - // we will exit from the main loop below before actually storing - // anything computed from that data. - "cmp %w[col], w8\n" // Have we finished the last column? - "bge 5f\n" // If yes, just carry on without updating the column pointer. - // Not finished last column: then advance to next column. - "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" - "5:\n" - - // Set the LHS and RHS data pointers to the start of the columns just - // computed. - "mov %[lhs_ptr], %[lhs_col_ptr]\n" - "mov %[rhs_ptr], %[rhs_col_ptr]\n" - - // Load some parameters needed for the end work on current block. - "ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" - - // Offset these base pointers as needed given the current row, col. - "add x5, x1, %x[row], lsl #2\n" - - "tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" - "csel x1, x1, x5, eq\n" - - // Load 8 bias values. - "ld1 {v14.4s}, [x1], #16\n" - "ld1 {v15.4s}, [x1]\n" - - // Now that we know what LHS and RHS data the next iteration of the - // main loop will need to load, we start loading the first 32 bytes of - // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore - // in the rest of the work on the current block. - "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" - "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" - - // Perform the bias-addition (per the above, we have just folded into - // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) - "fadd v16.4s, v16.4s, v14.4s\n" - "fadd v17.4s, v17.4s, v15.4s\n" - "fadd v18.4s, v18.4s, v14.4s\n" - "fadd v19.4s, v19.4s, v15.4s\n" - "fadd v20.4s, v20.4s, v14.4s\n" - "fadd v21.4s, v21.4s, v15.4s\n" - "fadd v22.4s, v22.4s, v14.4s\n" - "fadd v23.4s, v23.4s, v15.4s\n" - "fadd v24.4s, v24.4s, v14.4s\n" - "fadd v25.4s, v25.4s, v15.4s\n" - "fadd v26.4s, v26.4s, v14.4s\n" - "fadd v27.4s, v27.4s, v15.4s\n" - "fadd v28.4s, v28.4s, v14.4s\n" - "fadd v29.4s, v29.4s, v15.4s\n" - "fadd v30.4s, v30.4s, v14.4s\n" - "fadd v31.4s, v31.4s, v15.4s\n" - - // Load the clamp_min, clamp_max bounds - "ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.4s, w2\n" // clamp_min - "dup v15.4s, w3\n" // clamp_max - - // Apply the clamp_min bound - "fmax v16.4s, v16.4s, v14.4s\n" - "fmax v17.4s, v17.4s, v14.4s\n" - "fmax v18.4s, v18.4s, v14.4s\n" - "fmax v19.4s, v19.4s, v14.4s\n" - "fmax v20.4s, v20.4s, v14.4s\n" - "fmax v21.4s, v21.4s, v14.4s\n" - "fmax v22.4s, v22.4s, v14.4s\n" - "fmax v23.4s, v23.4s, v14.4s\n" - "fmax v24.4s, v24.4s, v14.4s\n" - "fmax v25.4s, v25.4s, v14.4s\n" - "fmax v26.4s, v26.4s, v14.4s\n" - "fmax v27.4s, v27.4s, v14.4s\n" - "fmax v28.4s, v28.4s, v14.4s\n" - "fmax v29.4s, v29.4s, v14.4s\n" - "fmax v30.4s, v30.4s, v14.4s\n" - "fmax v31.4s, v31.4s, v14.4s\n" - - // Apply the clamp_max bound - "fmin v16.4s, v16.4s, v15.4s\n" - "fmin v17.4s, v17.4s, v15.4s\n" - "fmin v18.4s, v18.4s, v15.4s\n" - "fmin v19.4s, v19.4s, v15.4s\n" - "fmin v20.4s, v20.4s, v15.4s\n" - "fmin v21.4s, v21.4s, v15.4s\n" - "fmin v22.4s, v22.4s, v15.4s\n" - "fmin v23.4s, v23.4s, v15.4s\n" - "fmin v24.4s, v24.4s, v15.4s\n" - "fmin v25.4s, v25.4s, v15.4s\n" - "fmin v26.4s, v26.4s, v15.4s\n" - "fmin v27.4s, v27.4s, v15.4s\n" - "fmin v28.4s, v28.4s, v15.4s\n" - "fmin v29.4s, v29.4s, v15.4s\n" - "fmin v30.4s, v30.4s, v15.4s\n" - "fmin v31.4s, v31.4s, v15.4s\n" - - // Compute how much of the 8x8 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x8, there are some 8x8 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #8\n" - "cmp w1, #8\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #8\n" - // Compute w2 = how many cols of the 8x8 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - // Yes, all of the 8x8 block fits, go to fast path. - "beq 30f\n" - // Not all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "mov x4, #32\n" - "b 31f\n" - "30:\n" - // Yes, all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write directly to destination matrix. - "mov x3, %[dst_ptr]\n" - "mov x4, x11\n" - "31:\n" - - // Write our 8bit values to the destination described by - // (x3 address, x4 stride). - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "str q16, [x3, #0]\n" - "str q17, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - "str q18, [x3, #0]\n" - "str q19, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - "str q20, [x3, #0]\n" - "str q21, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v20) - RUY_MAKE_ZERO(v21) - "str q22, [x3, #0]\n" - "str q23, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v22) - RUY_MAKE_ZERO(v23) - "str q24, [x3, #0]\n" - "str q25, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - "str q26, [x3, #0]\n" - "str q27, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - "str q28, [x3, #0]\n" - "str q29, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - "str q30, [x3, #0]\n" - "str q31, [x3, #16]\n" - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 41f\n" - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldr w7, [x3, x5, lsl #2]\n" - "str w7, [x4, x5, lsl #2]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "add w6, w6, #1\n" - "add x3, x3, #32\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 50b\n" - "41:\n" - "add %[dst_ptr], %[dst_ptr], #32\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - // Reload some params --- we had used x5 -- x7 for a few other things - // since the last time we had loaded them. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - - // Move to the next block of the destination matrix, for the next iter - // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already - // been updated earlier. - // Have we reached the end row? - "cmp %w[row], w7\n" - "beq 20f\n" // yes, end row. - // Not end row. Move to the next row. - "add %w[row], %w[row], #8\n" - "b 21f\n" - "20:\n" - // Was already at end row. - "mov %w[row], w6\n" // Move back to first row. - "add %w[col], %w[col], #8\n" // Move to the next column. - "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" - "mov %[dst_ptr], %[dst_col_ptr]\n" - "21:\n" - - // Main loop exit condition: have we hit the end column? - "cmp %w[col], w8\n" - - // w1 is the number of levels of depth that remain to load - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently depth - 1. - "sub w1, w12, #1\n" - - "ble 1b\n" - - // clang-format on - - : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), - [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), - [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) - : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), - [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf) - : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", - "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", - "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", - "v26", "v27", "v28", "v29", "v30", "v31"); -} -#undef RUY_OFFSET_BIAS -#undef RUY_OFFSET_FLAGS -#undef RUY_OFFSET_LHS_BASE_PTR -#undef RUY_OFFSET_CLAMP_MIN -#undef RUY_OFFSET_CLAMP_MAX -#undef RUY_OFFSET_START_ROW -#undef RUY_OFFSET_LAST_ROW -#undef RUY_OFFSET_LAST_COL -#undef RUY_OFFSET_LHS_STRIDE -#undef RUY_OFFSET_RHS_STRIDE -#undef RUY_OFFSET_DST_STRIDE -#undef RUY_OFFSET_DEPTH -#undef RUY_OFFSET_START_COL -#undef RUY_OFFSET_RHS_BASE_PTR -#undef RUY_OFFSET_DST_BASE_PTR - -#endif // RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/kernel_avx2.cc b/tensorflow/lite/experimental/ruy/ruy/kernel_avx2.cc deleted file mode 100644 index 1113469fd28..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/kernel_avx2.cc +++ /dev/null @@ -1,1664 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/kernel.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" - -#if RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM) -#include // IWYU pragma: keep -#endif - -namespace ruy { - -#if !(RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM)) - -void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -#else // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -static constexpr int kAvx8bitBlockSize = 8; -static constexpr int kAvx8bitInnerSize = 4; - -namespace { -namespace intrin_utils { - -inline __m256 mm256_n_loadu_epi32(int n, const std::int32_t* src) { - switch (n) { - case 0: - return _mm256_setzero_si256(); - case 1: - return _mm256_setr_m128(_mm_setr_epi32(src[0], 0, 0, 0), - _mm_setzero_si128()); - case 2: - return _mm256_setr_m128(_mm_setr_epi32(src[0], src[1], 0, 0), - _mm_setzero_si128()); - case 3: - return _mm256_setr_m128(_mm_setr_epi32(src[0], src[1], src[2], 0), - _mm_setzero_si128()); - case 4: - return _mm256_castsi128_si256( - _mm_loadu_si128(reinterpret_cast<__m128i const*>(src))); - case 5: - return _mm256_setr_epi32(src[0], src[1], src[2], src[3], src[4], 0, 0, 0); - case 6: - return _mm256_setr_epi32(src[0], src[1], src[2], src[3], src[4], src[5], - 0, 0); - case 7: - return _mm256_setr_epi32(src[0], src[1], src[2], src[3], src[4], src[5], - src[6], 0); - case 8: - return _mm256_loadu_si256(reinterpret_cast<__m256i const*>(src)); - default: - RUY_DCHECK_LT(n, 9); - return _mm256_setzero_si256(); - } -} - -inline void mm256_n_storeu_cvtepi32_epi8(std::uint8_t* dst, int residual_rows, - const __m256 v) { - // Select bytes 0, 4, 8, 12 within each lane, effectively truncating. - const __m256i repack_perm = _mm256_set1_epi32(0x0c080400); - __m256i shuffled_v; - if (residual_rows > 1) { - // This selects 0, 4, 8, 12, 0, 4, 8, 12, ..., but we only use the first 4 - // in each 128-bit lane. - shuffled_v = _mm256_shuffle_epi8(v, repack_perm); - } - switch (residual_rows) { - case 0: - break; - case 1: - dst[0] = _mm256_extract_epi8(v, 0); - break; - case 2: - _mm_storeu_si16(dst, _mm256_extracti128_si256(shuffled_v, 0)); - break; - case 3: { - __m128i trailing_packed = _mm256_extracti128_si256(shuffled_v, 0); - _mm_storeu_si16(dst, trailing_packed); - dst[2] = _mm_extract_epi8(trailing_packed, 2); - break; - } - case 4: - _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); - break; - case 5: - _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); - dst[4] = _mm256_extract_epi8(shuffled_v, 16); - break; - case 6: - _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); - _mm_storeu_si16(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); - break; - case 7: { - _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); - __m128i trailing_packed = _mm256_extracti128_si256(shuffled_v, 1); - _mm_storeu_si16(dst + 4, trailing_packed); - dst[6] = _mm_extract_epi8(trailing_packed, 2); - break; - } - case 8: - _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); - _mm_storeu_si32(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); - break; - default: - RUY_DCHECK_LE(residual_rows, 8); - break; - } -} - -inline void mm256_storeu_cvtepi32_epi8(std::uint8_t* dst, const __m256 v) { - // Select bytes 0, 4, 8, 12 within each lane, effectively truncating. - const __m256i repack_perm = _mm256_set1_epi32(0x0c080400); - const __m256i shuffled_v = _mm256_shuffle_epi8(v, repack_perm); - _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); - _mm_storeu_si32(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); -} - -inline void mm256_n_storeu_cvtepi32_epi8(std::int8_t* dst, int residual_rows, - const __m256 v) { - intrin_utils::mm256_n_storeu_cvtepi32_epi8( - reinterpret_cast(dst), residual_rows, v); -} - -inline void mm256_storeu_cvtepi32_epi8(std::int8_t* dst, const __m256 v) { - // Select bytes 0, 4, 8, 12 within each lane, effectively truncating. - const __m256i repack_perm = _mm256_set1_epi32(0x0c080400); - const __m256i shuffled_v = _mm256_shuffle_epi8(v, repack_perm); - _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); - _mm_storeu_si32(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); -} - -inline void mm256_n_storeu_cvtepi32_epi16(std::int16_t* dst, int residual_rows, - const __m256 v) { - // Select bytes 0, 1, 4, 5, 8, 9, 12, 13 within each lane, effectively - // truncating each 16-bit integer. - const __m256i repack_perm = _mm256_set1_epi64x(0x0d0c090805040100); - __m256i shuffled_v; - __m128i shuffled_v_low; - if (residual_rows > 1) { - shuffled_v = _mm256_shuffle_epi8(v, repack_perm); - shuffled_v_low = _mm256_extracti128_si256(shuffled_v, 0); - } else { - shuffled_v_low = _mm256_extracti128_si256(v, 0); - } - switch (residual_rows) { - case 0: - break; - case 1: - _mm_storeu_si16(dst, shuffled_v_low); - break; - case 2: - _mm_storeu_si32(dst, shuffled_v_low); - break; - case 3: { - _mm_storeu_si32(dst, shuffled_v_low); - dst[2] = _mm_extract_epi16(shuffled_v_low, 2); - break; - } - case 4: - _mm_storeu_si64(dst, shuffled_v_low); - break; - case 5: - _mm_storeu_si64(dst, shuffled_v_low); - dst[4] = _mm256_extract_epi16(shuffled_v, 8); - break; - case 6: - _mm_storeu_si64(dst, shuffled_v_low); - _mm_storeu_si32(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); - break; - case 7: { - _mm_storeu_si64(dst, shuffled_v_low); - __m128i trailing_packed = _mm256_extracti128_si256(shuffled_v, 1); - _mm_storeu_si32(dst + 4, trailing_packed); - dst[6] = _mm_extract_epi16(trailing_packed, 2); - break; - } - case 8: - _mm_storeu_si64(dst, _mm256_extracti128_si256(shuffled_v, 0)); - _mm_storeu_si64(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); - break; - default: - RUY_DCHECK_LE(residual_rows, 8); - break; - } -} - -inline void mm256_storeu_cvtepi32_epi16(std::int16_t* dst, const __m256 v) { - // Select bytes 0, 1, 4, 5, 8, 9, 12, 13 within each lane, effectively - // truncating each 16-bit integer. - const __m256i repack_perm = _mm256_set1_epi64x(0x0d0c090805040100); - const __m256i shuffled_v = _mm256_shuffle_epi8(v, repack_perm); - _mm_storeu_si64(dst, _mm256_extracti128_si256(shuffled_v, 0)); - _mm_storeu_si64(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); -} - -inline void mm256_n_storeu_epi32(std::int32_t* dst, int residual_rows, - const __m256 v) { - const __m128i v_low = _mm256_extracti128_si256(v, 0); - switch (residual_rows) { - case 0: - break; - case 1: - _mm_storeu_si32(dst, v_low); - break; - case 2: - _mm_storeu_si64(dst, v_low); - break; - case 3: { - __m128i trailing_packed = v_low; - _mm_storeu_si64(dst, trailing_packed); - dst[2] = _mm_extract_epi32(trailing_packed, 2); - break; - } - case 4: - _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low); - break; - case 5: - _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low); - dst[4] = _mm256_extract_epi32(v, 4); - break; - case 6: - _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low); - _mm_storeu_si64(dst + 4, _mm256_extracti128_si256(v, 1)); - break; - case 7: { - _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low); - __m128i trailing_packed = _mm256_extracti128_si256(v, 1); - _mm_storeu_si64(dst + 4, trailing_packed); - dst[6] = _mm_extract_epi32(trailing_packed, 2); - break; - } - case 8: - _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v); - break; - default: - RUY_DCHECK_LE(residual_rows, 8); - break; - } -} - -inline void mm256_storeu_epi32(std::int32_t* dst, const __m256 v) { - _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v); -} - -inline float mm256_get1_ps(const __m256 a, int i) { - __m256i ai = _mm256_castps_si256(a); - int float_val_as_int; - switch (i) { - case 0: - float_val_as_int = _mm256_extract_epi32(ai, 0); - break; - case 1: - float_val_as_int = _mm256_extract_epi32(ai, 1); - break; - case 2: - float_val_as_int = _mm256_extract_epi32(ai, 2); - break; - case 3: - float_val_as_int = _mm256_extract_epi32(ai, 3); - break; - case 4: - float_val_as_int = _mm256_extract_epi32(ai, 4); - break; - case 5: - float_val_as_int = _mm256_extract_epi32(ai, 5); - break; - case 6: - float_val_as_int = _mm256_extract_epi32(ai, 6); - break; - case 7: - float_val_as_int = _mm256_extract_epi32(ai, 7); - break; - default: - RUY_DCHECK_LT(i, 8); - return .0f; - } - return reinterpret_cast(float_val_as_int); -} - -inline __m256 mm256_n_loadu_ps(int i, const float* src) { - switch (i) { - case 0: - return _mm256_setzero_ps(); - case 1: - return _mm256_setr_m128(_mm_setr_ps(src[0], .0f, .0f, .0f), - _mm_setzero_ps()); - case 2: - return _mm256_setr_m128(_mm_setr_ps(src[0], src[1], .0f, .0f), - _mm_setzero_ps()); - case 3: - return _mm256_setr_m128(_mm_setr_ps(src[0], src[1], src[2], .0f), - _mm_setzero_ps()); - case 4: - return _mm256_setr_m128(_mm_setr_ps(src[0], src[1], src[2], src[3]), - _mm_setzero_ps()); - case 5: - return _mm256_setr_ps(src[0], src[1], src[2], src[3], src[4], .0f, .0f, - .0f); - case 6: - return _mm256_setr_ps(src[0], src[1], src[2], src[3], src[4], src[5], .0f, - .0f); - case 7: - return _mm256_setr_ps(src[0], src[1], src[2], src[3], src[4], src[5], - src[6], .0f); - case 8: - return _mm256_loadu_ps(src); - default: - RUY_DCHECK_LT(i, 9); - return _mm256_setzero_ps(); - } -} - -inline void mm256_n_storeu_ps(float* dst, int residual_rows, const __m256 v) { - for (int i = 0; i < residual_rows; ++i) { - dst[i] = intrin_utils::mm256_get1_ps(v, i); - } -} -} // namespace intrin_utils -} // namespace - -void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) { - profiler::ScopeLabel label("Kernel kAvx2 8-bit"); - const std::int8_t splitter_idx_data[32] = { - 0, 1, 4, 5, 8, 9, 12, 13, // - 2, 3, 6, 7, 10, 11, 14, 15, // - 0, 1, 4, 5, 8, 9, 12, 13, // - 2, 3, 6, 7, 10, 11, 14, 15 // - }; - - std::int32_t dst_stride; - if ((params.dst_type_id == DstTypeId::kValue) || - (params.dst_type_id == DstTypeId::kValue)) { - dst_stride = params.dst_stride; - } else if (params.dst_type_id == DstTypeId::kValue) { - dst_stride = params.dst_stride / sizeof(std::int16_t); - } else if (params.dst_type_id == DstTypeId::kValue) { - dst_stride = params.dst_stride / sizeof(std::int32_t); - } else { - RUY_DCHECK(false); - } - - int bias_ptr_block_increment = - params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0; - - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; - void* dst_col_ptr = params.dst_base_ptr; - const std::int32_t* bias_col_ptr = params.bias; - if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { - bias_col_ptr += params.start_row; - } - - for (int col = params.start_col; col <= params.last_col; - col += kAvx8bitBlockSize) { - const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - void* dst_ptr = dst_col_ptr; - const std::int32_t* bias_ptr = bias_col_ptr; - - const std::int32_t lhs_zero_point = params.lhs_zero_point; - const bool has_rhs_sums_offsets = - (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point; - std::int32_t rhs_sums_offsets[8]; - if (has_rhs_sums_offsets) { - const __m256i rhs_sums_offset_v = _mm256_mullo_epi32( - _mm256_set1_epi32(lhs_zero_point), - _mm256_loadu_si256( - reinterpret_cast<__m256i const*>(¶ms.rhs_sums[col]))); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets), - rhs_sums_offset_v); - } - - for (int row = params.start_row; row <= params.last_row; - row += kAvx8bitBlockSize) { - const int residual_rows = - std::min(params.dst_rows - row, kAvx8bitBlockSize); - const int residual_cols = - std::min(params.dst_cols - col, kAvx8bitBlockSize); - - const __m256i splitter_idx = _mm256_loadu_si256( - reinterpret_cast<__m256i const*>(splitter_idx_data)); - - __m256i accum_data_v0; - __m256i accum_data_v1; - __m256i accum_data_v2; - __m256i accum_data_v3; - __m256i accum_data_v4; - __m256i accum_data_v5; - __m256i accum_data_v6; - __m256i accum_data_v7; - - // Initialize with bias. - __m256i initial_accum_data = - intrin_utils::mm256_n_loadu_epi32(residual_rows, bias_ptr); - bias_ptr += bias_ptr_block_increment; - - // Adjustments common across columns. - const std::int32_t rhs_zero_point = params.rhs_zero_point; - if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) { - const __m256i lhs_sums_offset = _mm256_mullo_epi32( - _mm256_set1_epi32(rhs_zero_point), - _mm256_loadu_si256( - reinterpret_cast<__m256i const*>(¶ms.lhs_sums[row]))); - initial_accum_data = - _mm256_sub_epi32(initial_accum_data, lhs_sums_offset); - } - const std::int32_t prod_zp_depth = params.prod_zp_depth; - if (prod_zp_depth) { - initial_accum_data = _mm256_add_epi32(initial_accum_data, - _mm256_set1_epi32(prod_zp_depth)); - } - - // Adjustments differing across columns. - if (has_rhs_sums_offsets) { - accum_data_v0 = _mm256_sub_epi32( - initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[0])); - accum_data_v1 = _mm256_sub_epi32( - initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[1])); - accum_data_v2 = _mm256_sub_epi32( - initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[2])); - accum_data_v3 = _mm256_sub_epi32( - initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[3])); - accum_data_v4 = _mm256_sub_epi32( - initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[4])); - accum_data_v5 = _mm256_sub_epi32( - initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[5])); - accum_data_v6 = _mm256_sub_epi32( - initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[6])); - accum_data_v7 = _mm256_sub_epi32( - initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[7])); - } else { - accum_data_v0 = initial_accum_data; - accum_data_v1 = initial_accum_data; - accum_data_v2 = initial_accum_data; - accum_data_v3 = initial_accum_data; - accum_data_v4 = initial_accum_data; - accum_data_v5 = initial_accum_data; - accum_data_v6 = initial_accum_data; - accum_data_v7 = initial_accum_data; - } - - const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; - for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) { - const __m256i lhs_data = - _mm256_load_si256(reinterpret_cast(lhs_ptr)); - const __m256i rhs_data_8bit = - _mm256_load_si256(reinterpret_cast(rhs_ptr)); - - // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. - std::int32_t rhs_data[16]; - const __m128i rhs_data_bottom_lane = - _mm256_castsi256_si128(rhs_data_8bit); - const __m128i rhs_data_top_lane = - _mm256_extracti128_si256(rhs_data_8bit, 1); - const __m256i rhs_16_bit_dup_low = - _mm256_cvtepi8_epi16(rhs_data_bottom_lane); - const __m256i rhs_16_bit_dup_high = - _mm256_cvtepi8_epi16(rhs_data_top_lane); - // Now that we have cast the RHS data, we store it so that each value - // can be separately loaded in the accumulation loop. - _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data), - rhs_16_bit_dup_low); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data + 8), - rhs_16_bit_dup_high); - - // NOTE: There may be opportunities for permuting the data in the - // packing code instead of here. - const __m256i lhs_data_split = - _mm256_shuffle_epi8(lhs_data, splitter_idx); - const __m256i lhs_data_split_expand_bottom = - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 0)); - const __m256i lhs_data_split_expand_top = - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 1)); - - // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit. - const __m256i lhs_16_bit_low = _mm256_permute2x128_si256( - lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20); - // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit. - const __m256i lhs_16_bit_high = _mm256_permute2x128_si256( - lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31); - // Accumulate for column 0. - { - const std::int32_t low_rhs_value = rhs_data[0]; - const std::int32_t high_rhs_value = rhs_data[1]; - - const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); - const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); - - accum_data_v0 = _mm256_add_epi32( - accum_data_v0, - _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_data_v0 = _mm256_add_epi32( - accum_data_v0, - _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - } - // Accumulate for column 1. - { - const std::int32_t low_rhs_value = rhs_data[2]; - const std::int32_t high_rhs_value = rhs_data[3]; - - const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); - const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); - - accum_data_v1 = _mm256_add_epi32( - accum_data_v1, - _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_data_v1 = _mm256_add_epi32( - accum_data_v1, - _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - } - // Accumulate for column 2. - { - const std::int32_t low_rhs_value = rhs_data[4]; - const std::int32_t high_rhs_value = rhs_data[5]; - - const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); - const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); - - accum_data_v2 = _mm256_add_epi32( - accum_data_v2, - _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_data_v2 = _mm256_add_epi32( - accum_data_v2, - _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - } - // Accumulate for column 3. - { - const std::int32_t low_rhs_value = rhs_data[6]; - const std::int32_t high_rhs_value = rhs_data[7]; - - const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); - const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); - - accum_data_v3 = _mm256_add_epi32( - accum_data_v3, - _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_data_v3 = _mm256_add_epi32( - accum_data_v3, - _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - } - // Accumulate for column 4. - { - const std::int32_t low_rhs_value = rhs_data[8]; - const std::int32_t high_rhs_value = rhs_data[9]; - - const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); - const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); - - accum_data_v4 = _mm256_add_epi32( - accum_data_v4, - _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_data_v4 = _mm256_add_epi32( - accum_data_v4, - _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - } - // Accumulate for column 5. - { - const std::int32_t low_rhs_value = rhs_data[10]; - const std::int32_t high_rhs_value = rhs_data[11]; - - const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); - const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); - - accum_data_v5 = _mm256_add_epi32( - accum_data_v5, - _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_data_v5 = _mm256_add_epi32( - accum_data_v5, - _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - } - // Accumulate for column 6. - { - const std::int32_t low_rhs_value = rhs_data[12]; - const std::int32_t high_rhs_value = rhs_data[13]; - - const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); - const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); - - accum_data_v6 = _mm256_add_epi32( - accum_data_v6, - _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_data_v6 = _mm256_add_epi32( - accum_data_v6, - _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - } - // Accumulate for column 7. - { - const std::int32_t low_rhs_value = rhs_data[14]; - const std::int32_t high_rhs_value = rhs_data[15]; - - const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); - const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); - - accum_data_v7 = _mm256_add_epi32( - accum_data_v7, - _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_data_v7 = _mm256_add_epi32( - accum_data_v7, - _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - } - - lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; - rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; - } - - if (params.dst_type_id != DstTypeId::kValue) { - __m256i m_vector; - __m256i e_vector; - // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. - if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) { - m_vector = intrin_utils::mm256_n_loadu_epi32( - residual_rows, ¶ms.multiplier_fixedpoint[row]); - e_vector = intrin_utils::mm256_n_loadu_epi32( - residual_rows, ¶ms.multiplier_exponent[row]); - } else { - // These arrays have size LhsCols, and are pre-filled. - m_vector = _mm256_set1_epi32(params.multiplier_fixedpoint[0]); - e_vector = _mm256_set1_epi32(params.multiplier_exponent[0]); - } - - const __m256i m_64bit_low = - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 0)); - const __m256i m_64bit_high = - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 1)); - - const __m256i zero_vector = _mm256_setzero_si256(); - const __m256i left_shift = _mm256_max_epi32(e_vector, zero_vector); - const __m256i neg_e_vector = _mm256_sub_epi32(zero_vector, e_vector); - const __m256i right_shift = _mm256_max_epi32(neg_e_vector, zero_vector); - const __m256i final_right_shift = - _mm256_add_epi32(right_shift, _mm256_set1_epi32(31)); - const __m256i final_right_shift_low = _mm256_cvtepi32_epi64( - _mm256_extracti128_si256(final_right_shift, 0)); - const __m256i final_right_shift_high = _mm256_cvtepi32_epi64( - _mm256_extracti128_si256(final_right_shift, 1)); - // Really we want 0x100000000, but use half to avoid overflowing. - const __m256i convert_to_signed_halved = - _mm256_srlv_epi32(_mm256_set1_epi32(0x80000000), right_shift); - const __m256i convert_to_unsigned_64 = - _mm256_set1_epi64x(0x8000000000000000); - - __m256i post_scaling_offset = _mm256_add_epi32( - convert_to_signed_halved, convert_to_signed_halved); - - const __m256i offset_vector = - _mm256_slli_epi64(_mm256_set1_epi64x(1), 30); - // Really these should be shifted by neg_e_vector, but tests pass when - // using right_shift. - const __m256i offset_vector_low = _mm256_add_epi64( - _mm256_sllv_epi64(offset_vector, - _mm256_cvtepi32_epi64( - _mm256_extracti128_si256(right_shift, 0))), - convert_to_unsigned_64); - const __m256i offset_vector_high = _mm256_add_epi64( - _mm256_sllv_epi64(offset_vector, - _mm256_cvtepi32_epi64( - _mm256_extracti128_si256(right_shift, 1))), - convert_to_unsigned_64); - - if (params.dst_zero_point) { - const __m256i dst_zero_point = - _mm256_set1_epi32(params.dst_zero_point); - // The post-scaling offset is subtracted later, so this has the effect - // of adding the zero point. - post_scaling_offset = - _mm256_sub_epi32(post_scaling_offset, dst_zero_point); - } - -#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) - RUY_DCHECK(false); -#endif - const __m256i repack_perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7); - - // We cannot do - // - // scaled_v_low = - // _mm256_srav_epi64(scaled_v_low, final_right_shift_low); - // scaled_v_high = - // _mm256_srav_epi64(scaled_v_high, final_right_shift_high); - // - // since this instruction is not in AVX2. Instead we use - // _mm256_srlv_epi64, but this is an unsigned shift, so we applied - // offsets before (convert_to_unsigned_64) and after - // (convert_to_signed_halved). - // - // The overall process is, for 64-bit scaled accumulator: - // unsigned_accum = signed_accum + 1 << 63; - // unsigned_accum = (unsigned_accum >> right_shift) >> 31; - // signed_accum = unsigned_accum - ((1 << 32) >> right_shift) / 2 * 2; - - // There are various ways to repack the results, in the absence of - // _mm256_cvtepi64_epi32() or anything like it. - // A. - // accum_data_v[j] = - // _mm256_set_epi32(_mm256_extract_epi32(scaled_v_high, 6), - // _mm256_extract_epi32(scaled_v_high, 4), - // _mm256_extract_epi32(scaled_v_high, 2), - // _mm256_extract_epi32(scaled_v_high, 0), - // _mm256_extract_epi32(scaled_v_low, 6), - // _mm256_extract_epi32(scaled_v_low, 4), - // _mm256_extract_epi32(scaled_v_low, 2), - // _mm256_extract_epi32(scaled_v_low, 0)); - // B. - // scaled_v_low = _mm256_shuffle_epi32(scaled_v_low, 0xd8); - // scaled_v_high = _mm256_shuffle_epi32(scaled_v_high, 0xd8); - // accum_data_v[j] = - // _mm256_set_epi64x(_mm256_extract_epi64(scaled_v_high, 2), - // _mm256_extract_epi64(scaled_v_high, 0), - // _mm256_extract_epi64(scaled_v_low, 2), - // _mm256_extract_epi64(scaled_v_low, 0)); - // C. - // scaled_v_low = - // _mm256_permutevar8x32_epi32(scaled_v_low, repack_perm); - // scaled_v_high = - // _mm256_permutevar8x32_epi32(scaled_v_high, repack_perm); - // accum_data_v[j] = - // _mm256_permute2x128_si256(scaled_v_low, scaled_v_high, 0x20); - // - // However, we choose the following because it uses two lighter - // instructions. The permutation does have a longer latency, but this - // loop can be unrolled. - // D. - // scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); - // __m256i results = - // _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); - // results = _mm256_permutevar8x32_epi32(results, repack_perm); - // accum_data_v[j] = _mm256_sub_epi32(results, post_scaling_offset); - { - __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v0, left_shift); - // Apply the fixed-point part of the multiplier. - __m256i scaled_v_low = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), - m_64bit_low); - __m256i scaled_v_high = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), - m_64bit_high); - - scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); - - scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); - __m256i results = - _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); - results = _mm256_permutevar8x32_epi32(results, repack_perm); - - accum_data_v0 = _mm256_sub_epi32(results, post_scaling_offset); - } - { - __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v1, left_shift); - // Apply the fixed-point part of the multiplier. - __m256i scaled_v_low = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), - m_64bit_low); - __m256i scaled_v_high = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), - m_64bit_high); - - scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); - - scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); - __m256i results = - _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); - results = _mm256_permutevar8x32_epi32(results, repack_perm); - - accum_data_v1 = _mm256_sub_epi32(results, post_scaling_offset); - } - { - __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v2, left_shift); - // Apply the fixed-point part of the multiplier. - __m256i scaled_v_low = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), - m_64bit_low); - __m256i scaled_v_high = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), - m_64bit_high); - - scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); - - scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); - __m256i results = - _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); - results = _mm256_permutevar8x32_epi32(results, repack_perm); - - accum_data_v2 = _mm256_sub_epi32(results, post_scaling_offset); - } - { - __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v3, left_shift); - // Apply the fixed-point part of the multiplier. - __m256i scaled_v_low = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), - m_64bit_low); - __m256i scaled_v_high = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), - m_64bit_high); - - scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); - - scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); - __m256i results = - _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); - results = _mm256_permutevar8x32_epi32(results, repack_perm); - - accum_data_v3 = _mm256_sub_epi32(results, post_scaling_offset); - } - { - __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v4, left_shift); - // Apply the fixed-point part of the multiplier. - __m256i scaled_v_low = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), - m_64bit_low); - __m256i scaled_v_high = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), - m_64bit_high); - - scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); - - scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); - __m256i results = - _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); - results = _mm256_permutevar8x32_epi32(results, repack_perm); - - accum_data_v4 = _mm256_sub_epi32(results, post_scaling_offset); - } - { - __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v5, left_shift); - // Apply the fixed-point part of the multiplier. - __m256i scaled_v_low = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), - m_64bit_low); - __m256i scaled_v_high = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), - m_64bit_high); - - scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); - - scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); - __m256i results = - _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); - results = _mm256_permutevar8x32_epi32(results, repack_perm); - - accum_data_v5 = _mm256_sub_epi32(results, post_scaling_offset); - } - { - __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v6, left_shift); - // Apply the fixed-point part of the multiplier. - __m256i scaled_v_low = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), - m_64bit_low); - __m256i scaled_v_high = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), - m_64bit_high); - - scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); - - scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); - __m256i results = - _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); - results = _mm256_permutevar8x32_epi32(results, repack_perm); - - accum_data_v6 = _mm256_sub_epi32(results, post_scaling_offset); - } - { - __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v7, left_shift); - // Apply the fixed-point part of the multiplier. - __m256i scaled_v_low = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), - m_64bit_low); - __m256i scaled_v_high = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), - m_64bit_high); - - scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); - - scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); - __m256i results = - _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); - results = _mm256_permutevar8x32_epi32(results, repack_perm); - - accum_data_v7 = _mm256_sub_epi32(results, post_scaling_offset); - } - } - const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max); - const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min); - const bool store_full_block = (residual_rows == kAvx8bitBlockSize) && - (residual_cols == kAvx8bitBlockSize); - - __m256i accum_data_v[kAvx8bitBlockSize]; - if (!store_full_block) { - accum_data_v[0] = accum_data_v0; - accum_data_v[1] = accum_data_v1; - accum_data_v[2] = accum_data_v2; - accum_data_v[3] = accum_data_v3; - accum_data_v[4] = accum_data_v4; - accum_data_v[5] = accum_data_v5; - accum_data_v[6] = accum_data_v6; - accum_data_v[7] = accum_data_v7; - } - - if (params.dst_type_id == DstTypeId::kValue) { - std::int8_t* tmp_ptr = static_cast(dst_ptr); - if (store_full_block) { - accum_data_v0 = _mm256_min_epi32(accum_data_v0, clamp_max_v); - accum_data_v0 = _mm256_max_epi32(accum_data_v0, clamp_min_v); - accum_data_v1 = _mm256_min_epi32(accum_data_v1, clamp_max_v); - accum_data_v1 = _mm256_max_epi32(accum_data_v1, clamp_min_v); - accum_data_v2 = _mm256_min_epi32(accum_data_v2, clamp_max_v); - accum_data_v2 = _mm256_max_epi32(accum_data_v2, clamp_min_v); - accum_data_v3 = _mm256_min_epi32(accum_data_v3, clamp_max_v); - accum_data_v3 = _mm256_max_epi32(accum_data_v3, clamp_min_v); - accum_data_v4 = _mm256_min_epi32(accum_data_v4, clamp_max_v); - accum_data_v4 = _mm256_max_epi32(accum_data_v4, clamp_min_v); - accum_data_v5 = _mm256_min_epi32(accum_data_v5, clamp_max_v); - accum_data_v5 = _mm256_max_epi32(accum_data_v5, clamp_min_v); - accum_data_v6 = _mm256_min_epi32(accum_data_v6, clamp_max_v); - accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v); - accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v); - accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[0 * dst_stride], - accum_data_v0); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[1 * dst_stride], - accum_data_v1); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[2 * dst_stride], - accum_data_v2); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[3 * dst_stride], - accum_data_v3); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[4 * dst_stride], - accum_data_v4); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[5 * dst_stride], - accum_data_v5); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[6 * dst_stride], - accum_data_v6); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[7 * dst_stride], - accum_data_v7); - } else { - for (int j = 0; j < residual_cols; ++j) { - __m256 result = accum_data_v[j]; - result = _mm256_min_epi32(result, clamp_max_v); - result = _mm256_max_epi32(result, clamp_min_v); - intrin_utils::mm256_n_storeu_cvtepi32_epi8(tmp_ptr, residual_rows, - result); - tmp_ptr += dst_stride; - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else if (params.dst_type_id == DstTypeId::kValue) { - std::uint8_t* tmp_ptr = static_cast(dst_ptr); - if (store_full_block) { - accum_data_v0 = _mm256_min_epi32(accum_data_v0, clamp_max_v); - accum_data_v0 = _mm256_max_epi32(accum_data_v0, clamp_min_v); - accum_data_v1 = _mm256_min_epi32(accum_data_v1, clamp_max_v); - accum_data_v1 = _mm256_max_epi32(accum_data_v1, clamp_min_v); - accum_data_v2 = _mm256_min_epi32(accum_data_v2, clamp_max_v); - accum_data_v2 = _mm256_max_epi32(accum_data_v2, clamp_min_v); - accum_data_v3 = _mm256_min_epi32(accum_data_v3, clamp_max_v); - accum_data_v3 = _mm256_max_epi32(accum_data_v3, clamp_min_v); - accum_data_v4 = _mm256_min_epi32(accum_data_v4, clamp_max_v); - accum_data_v4 = _mm256_max_epi32(accum_data_v4, clamp_min_v); - accum_data_v5 = _mm256_min_epi32(accum_data_v5, clamp_max_v); - accum_data_v5 = _mm256_max_epi32(accum_data_v5, clamp_min_v); - accum_data_v6 = _mm256_min_epi32(accum_data_v6, clamp_max_v); - accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v); - accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v); - accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[0], accum_data_v0); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[dst_stride], - accum_data_v1); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[2 * dst_stride], - accum_data_v2); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[3 * dst_stride], - accum_data_v3); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[4 * dst_stride], - accum_data_v4); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[5 * dst_stride], - accum_data_v5); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[6 * dst_stride], - accum_data_v6); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[7 * dst_stride], - accum_data_v7); - } else { - for (int j = 0; j < residual_cols; ++j) { - __m256 result = accum_data_v[j]; - result = _mm256_min_epi32(result, clamp_max_v); - result = _mm256_max_epi32(result, clamp_min_v); - intrin_utils::mm256_n_storeu_cvtepi32_epi8(tmp_ptr, residual_rows, - result); - tmp_ptr += dst_stride; - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else if (params.dst_type_id == DstTypeId::kValue) { - std::int16_t* tmp_ptr = static_cast(dst_ptr); - if (store_full_block) { - accum_data_v0 = _mm256_min_epi32(accum_data_v0, clamp_max_v); - accum_data_v0 = _mm256_max_epi32(accum_data_v0, clamp_min_v); - accum_data_v1 = _mm256_min_epi32(accum_data_v1, clamp_max_v); - accum_data_v1 = _mm256_max_epi32(accum_data_v1, clamp_min_v); - accum_data_v2 = _mm256_min_epi32(accum_data_v2, clamp_max_v); - accum_data_v2 = _mm256_max_epi32(accum_data_v2, clamp_min_v); - accum_data_v3 = _mm256_min_epi32(accum_data_v3, clamp_max_v); - accum_data_v3 = _mm256_max_epi32(accum_data_v3, clamp_min_v); - accum_data_v4 = _mm256_min_epi32(accum_data_v4, clamp_max_v); - accum_data_v4 = _mm256_max_epi32(accum_data_v4, clamp_min_v); - accum_data_v5 = _mm256_min_epi32(accum_data_v5, clamp_max_v); - accum_data_v5 = _mm256_max_epi32(accum_data_v5, clamp_min_v); - accum_data_v6 = _mm256_min_epi32(accum_data_v6, clamp_max_v); - accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v); - accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v); - accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v); - intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[0], accum_data_v0); - intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[dst_stride], - accum_data_v1); - intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[2 * dst_stride], - accum_data_v2); - intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[3 * dst_stride], - accum_data_v3); - intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[4 * dst_stride], - accum_data_v4); - intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[5 * dst_stride], - accum_data_v5); - intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[6 * dst_stride], - accum_data_v6); - intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[7 * dst_stride], - accum_data_v7); - } else { - for (int j = 0; j < residual_cols; ++j) { - __m256 result = accum_data_v[j]; - result = _mm256_min_epi32(result, clamp_max_v); - result = _mm256_max_epi32(result, clamp_min_v); - intrin_utils::mm256_n_storeu_cvtepi32_epi16(tmp_ptr, residual_rows, - result); - tmp_ptr += dst_stride; - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else if (params.dst_type_id == DstTypeId::kValue) { - if (store_full_block) { - std::int32_t* tmp_ptr = static_cast(dst_ptr); - intrin_utils::mm256_storeu_epi32(&tmp_ptr[0], accum_data_v0); - intrin_utils::mm256_storeu_epi32(&tmp_ptr[dst_stride], accum_data_v1); - intrin_utils::mm256_storeu_epi32(&tmp_ptr[2 * dst_stride], - accum_data_v2); - intrin_utils::mm256_storeu_epi32(&tmp_ptr[3 * dst_stride], - accum_data_v3); - intrin_utils::mm256_storeu_epi32(&tmp_ptr[4 * dst_stride], - accum_data_v4); - intrin_utils::mm256_storeu_epi32(&tmp_ptr[5 * dst_stride], - accum_data_v5); - intrin_utils::mm256_storeu_epi32(&tmp_ptr[6 * dst_stride], - accum_data_v6); - intrin_utils::mm256_storeu_epi32(&tmp_ptr[7 * dst_stride], - accum_data_v7); - } else { - std::int32_t* dst_block_ptr = static_cast(dst_ptr); - for (int j = 0; j < residual_cols; ++j) { - intrin_utils::mm256_n_storeu_epi32(dst_block_ptr, residual_rows, - accum_data_v[j]); - dst_block_ptr += dst_stride; - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else { - RUY_DCHECK(false); - } - - lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride; - } // End row-block loop. - - dst_col_ptr = static_cast(static_cast(dst_col_ptr) + - kAvx8bitBlockSize * params.dst_stride); - rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride; - } // End col-block loop. -} // NOLINT(readability/fn_size) - -void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params) { - profiler::ScopeLabel label("Kernel kAvx2 8-bit GEMV"); - - RUY_DCHECK_EQ(params.dst_cols, 1); - RUY_DCHECK_EQ(params.last_col, 0); - RUY_DCHECK_EQ(params.start_col, 0); - - const std::int8_t splitter_idx_data[32] = { - 0, 1, 4, 5, 8, 9, 12, 13, // - 2, 3, 6, 7, 10, 11, 14, 15, // - 0, 1, 4, 5, 8, 9, 12, 13, // - 2, 3, 6, 7, 10, 11, 14, 15 // - }; - - int bias_ptr_block_increment = - params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0; - - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; - void* dst_col_ptr = params.dst_base_ptr; - const std::int32_t* bias_col_ptr = params.bias; - if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { - bias_col_ptr += params.start_row; - } - - const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - void* dst_ptr = dst_col_ptr; - const std::int32_t* bias_ptr = bias_col_ptr; - - const std::int32_t lhs_zero_point = params.lhs_zero_point; - const bool has_rhs_sums_offsets = - (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point; - std::int32_t rhs_sums_offsets[8]; - if (has_rhs_sums_offsets) { - const __m256i rhs_sums_offset_v = _mm256_mullo_epi32( - _mm256_set1_epi32(lhs_zero_point), - _mm256_loadu_si256( - reinterpret_cast<__m256i const*>(¶ms.rhs_sums[0]))); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets), - rhs_sums_offset_v); - } - - for (int row = params.start_row; row <= params.last_row; - row += kAvx8bitBlockSize) { - const int residual_rows = - std::min(params.dst_rows - row, kAvx8bitBlockSize); - - const __m256i splitter_idx = - _mm256_loadu_si256(reinterpret_cast<__m256i const*>(splitter_idx_data)); - - __m256i accum_data_v0; - - // Initialize with bias. - __m256i initial_accum_data = - intrin_utils::mm256_n_loadu_epi32(residual_rows, bias_ptr); - bias_ptr += bias_ptr_block_increment; - - // Adjustments common across columns. - const std::int32_t rhs_zero_point = params.rhs_zero_point; - if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) { - const __m256i lhs_sums_offset = _mm256_mullo_epi32( - _mm256_set1_epi32(rhs_zero_point), - _mm256_loadu_si256( - reinterpret_cast<__m256i const*>(¶ms.lhs_sums[row]))); - initial_accum_data = - _mm256_sub_epi32(initial_accum_data, lhs_sums_offset); - } - const std::int32_t prod_zp_depth = params.prod_zp_depth; - if (prod_zp_depth) { - initial_accum_data = _mm256_add_epi32(initial_accum_data, - _mm256_set1_epi32(prod_zp_depth)); - } - - // Adjustments differing across columns. - if (has_rhs_sums_offsets) { - accum_data_v0 = _mm256_sub_epi32(initial_accum_data, - _mm256_set1_epi32(rhs_sums_offsets[0])); - } else { - accum_data_v0 = initial_accum_data; - } - - const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; - for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) { - const __m256i lhs_data = - _mm256_load_si256(reinterpret_cast(lhs_ptr)); - const __m128i rhs_data_8bit = _mm_loadu_si32(rhs_ptr); - - // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. - // For simplicity we load 4x the data that we need and process twice the - // data that we need and store only the data we need. - std::int32_t rhs_data[2]; - const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit); - // Now that we have cast the RHS data, we store it so that each value - // can be separately loaded in the accumulation loop. - _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup); - - // NOTE: There may be opportunities for permuting the data in the packing - // code instead of here. - const __m256i lhs_data_split = - _mm256_shuffle_epi8(lhs_data, splitter_idx); - const __m256i lhs_data_split_expand_bottom = - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 0)); - const __m256i lhs_data_split_expand_top = - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 1)); - - // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit. - const __m256i lhs_16_bit_low = _mm256_permute2x128_si256( - lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20); - // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit. - const __m256i lhs_16_bit_high = _mm256_permute2x128_si256( - lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31); - // Accumulate for column 0. - const std::int32_t low_rhs_value = rhs_data[0]; - const std::int32_t high_rhs_value = rhs_data[1]; - - const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); - const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); - - accum_data_v0 = _mm256_add_epi32( - accum_data_v0, _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_data_v0 = _mm256_add_epi32( - accum_data_v0, - _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - - lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; - rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; - } - - if (params.dst_type_id != DstTypeId::kValue) { - __m256i m_vector; - __m256i e_vector; - // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. - if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) { - m_vector = intrin_utils::mm256_n_loadu_epi32( - residual_rows, ¶ms.multiplier_fixedpoint[row]); - e_vector = intrin_utils::mm256_n_loadu_epi32( - residual_rows, ¶ms.multiplier_exponent[row]); - } else { - // These arrays have size LhsCols, and are pre-filled. - m_vector = _mm256_set1_epi32(params.multiplier_fixedpoint[0]); - e_vector = _mm256_set1_epi32(params.multiplier_exponent[0]); - } - - const __m256i m_64bit_low = - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 0)); - const __m256i m_64bit_high = - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 1)); - - const __m256i zero_vector = _mm256_setzero_si256(); - const __m256i left_shift = _mm256_max_epi32(e_vector, zero_vector); - const __m256i neg_e_vector = _mm256_sub_epi32(zero_vector, e_vector); - const __m256i right_shift = _mm256_max_epi32(neg_e_vector, zero_vector); - const __m256i final_right_shift = - _mm256_add_epi32(right_shift, _mm256_set1_epi32(31)); - const __m256i final_right_shift_low = - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(final_right_shift, 0)); - const __m256i final_right_shift_high = - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(final_right_shift, 1)); - // Really we want 0x100000000, but use half to avoid overflowing. - const __m256i convert_to_signed_halved = - _mm256_srlv_epi32(_mm256_set1_epi32(0x80000000), right_shift); - const __m256i convert_to_unsigned_64 = - _mm256_set1_epi64x(0x8000000000000000); - - __m256i post_scaling_offset = - _mm256_add_epi32(convert_to_signed_halved, convert_to_signed_halved); - - const __m256i offset_vector = - _mm256_slli_epi64(_mm256_set1_epi64x(1), 30); - // Really these should be shifted by neg_e_vector, but tests pass when - // using right_shift. - const __m256i offset_vector_low = _mm256_add_epi64( - _mm256_sllv_epi64( - offset_vector, - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(right_shift, 0))), - convert_to_unsigned_64); - const __m256i offset_vector_high = _mm256_add_epi64( - _mm256_sllv_epi64( - offset_vector, - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(right_shift, 1))), - convert_to_unsigned_64); - - if (params.dst_zero_point) { - const __m256i dst_zero_point = _mm256_set1_epi32(params.dst_zero_point); - // The post-scaling offset is subtracted later, so this has the effect - // of adding the zero point. - post_scaling_offset = - _mm256_sub_epi32(post_scaling_offset, dst_zero_point); - } - -#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) - RUY_DCHECK(false); -#endif - const __m256i repack_perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7); - - // See GEMM version for details of this process. - { - __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v0, left_shift); - // Apply the fixed-point part of the multiplier. - __m256i scaled_v_low = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), - m_64bit_low); - __m256i scaled_v_high = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), - m_64bit_high); - - scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); - - scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); - __m256i results = _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); - results = _mm256_permutevar8x32_epi32(results, repack_perm); - - accum_data_v0 = _mm256_sub_epi32(results, post_scaling_offset); - } - } - const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max); - const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min); - - if (params.dst_type_id == DstTypeId::kValue) { - std::int8_t* tmp_ptr = static_cast(dst_ptr); - __m256 result = accum_data_v0; - result = _mm256_min_epi32(result, clamp_max_v); - result = _mm256_max_epi32(result, clamp_min_v); - intrin_utils::mm256_n_storeu_cvtepi32_epi8(tmp_ptr, residual_rows, - result); - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else if (params.dst_type_id == DstTypeId::kValue) { - std::uint8_t* tmp_ptr = static_cast(dst_ptr); - __m256 result = accum_data_v0; - result = _mm256_min_epi32(result, clamp_max_v); - result = _mm256_max_epi32(result, clamp_min_v); - intrin_utils::mm256_n_storeu_cvtepi32_epi8(tmp_ptr, residual_rows, - result); - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else if (params.dst_type_id == DstTypeId::kValue) { - std::int16_t* tmp_ptr = static_cast(dst_ptr); - __m256 result = accum_data_v0; - result = _mm256_min_epi32(result, clamp_max_v); - result = _mm256_max_epi32(result, clamp_min_v); - intrin_utils::mm256_n_storeu_cvtepi32_epi16(tmp_ptr, residual_rows, - result); - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else if (params.dst_type_id == DstTypeId::kValue) { - std::int32_t* dst_block_ptr = static_cast(dst_ptr); - intrin_utils::mm256_n_storeu_epi32(dst_block_ptr, residual_rows, - accum_data_v0); - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else { - RUY_DCHECK(false); - } - - lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride; - } // End row-block loop. - - dst_col_ptr = static_cast(static_cast(dst_col_ptr) + - kAvx8bitBlockSize * params.dst_stride); - rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride; -} // NOLINT(readability/fn_size) - -void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params) { - profiler::ScopeLabel label("Kernel kAvx2 float"); - - // As parameters are defined, we need to scale by sizeof(float). - const std::int64_t lhs_stride = params.lhs_stride >> 2; - const std::int64_t dst_stride = params.dst_stride >> 2; - const std::int64_t rhs_stride = params.rhs_stride >> 2; - // - int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0; - // AVX2 float block size = 8. - const int end_row = std::min(params.dst_rows, params.last_row + 8); - const int end_col = std::min(params.dst_cols, params.last_col + 8); - // - const float* adj_rhs_col_ptr = - params.rhs_base_ptr - params.start_col * rhs_stride; - float* adj_dst_col_ptr = - params.dst_base_ptr - params.start_col * dst_stride - params.start_row; - const float* adj_lhs_col_ptr = - params.lhs_base_ptr - params.start_row * lhs_stride; - const float* bias_col_ptr = params.bias; - - const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max); - const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min); - - int col = params.start_col; - // Loop through cols by float block size, leaving incomplete remainder - for (; col <= end_col - 8; col += 8) { - __m256 accum_data_v[8]; - - const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride; - float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride; - - for (int row = params.start_row; row < end_row; row += 8) { - const int residual_rows = std::min(end_row - row, 8); - - const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; - float* dst_ptr = dst_col_ptr + row; - const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; - - // Initialize with bias. - const __m256 initial_accum_data = - intrin_utils::mm256_n_loadu_ps(residual_rows, bias_ptr); - - for (int j = 0; j < 8; ++j) { - accum_data_v[j] = initial_accum_data; - } - - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr; - for (int d = 0; d < params.depth; ++d) { - const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr); - // In this version RHS values are loaded individually rather than first - // loading together and then extract with broadcasting. This is because - // AVX flavours and instrinsics and compilers in combination do not - // handle this pattern of extraction very well. - const float* rhs_data = rhs_ptr; - - for (int j = 0; j < 8; ++j) { - const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[j]); - accum_data_v[j] = - _mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]); - } - lhs_ptr += 8; - rhs_ptr += 8; - } - - if (residual_rows == 8) { - for (int j = 0; j < 8; ++j) { - float* block_ptr = dst_ptr + j * dst_stride; - accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v); - accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v); - _mm256_storeu_ps(block_ptr, accum_data_v[j]); - } - } else { - for (int j = 0; j < 8; ++j) { - float* block_ptr = dst_ptr + j * dst_stride; - accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v); - accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v); - intrin_utils::mm256_n_storeu_ps(block_ptr, residual_rows, - accum_data_v[j]); - } - } - } // End row-block loop. - } // End col-block loop. - - if (col < end_col) { - // Remaining cols in [0, float block size). - RUY_DCHECK_GE(end_col - col, 0); - RUY_DCHECK_LT(end_col - col, 8); - - __m256 accum_data_v[8]; - - const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride; - float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride; - const int residual_cols = std::min(end_col - col, 8); - - for (int row = params.start_row; row < end_row; row += 8) { - const int residual_rows = std::min(end_row - row, 8); - - const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; - float* dst_ptr = dst_col_ptr + row; - const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; - - // Initialize with bias. - const __m256 initial_accum_data = - intrin_utils::mm256_n_loadu_ps(residual_rows, bias_ptr); - - for (int j = 0; j < 8; ++j) { - accum_data_v[j] = initial_accum_data; - } - - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr; - for (int d = 0; d < params.depth; ++d) { - const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr); - const float* rhs_data = rhs_ptr; - - for (int j = 0; j < 8; ++j) { - const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[j]); - accum_data_v[j] = - _mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]); - } - lhs_ptr += 8; - rhs_ptr += 8; - } - - for (int j = 0; j < residual_cols; ++j) { - float* block_ptr = dst_ptr + j * dst_stride; - accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v); - accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v); - intrin_utils::mm256_n_storeu_ps(block_ptr, residual_rows, - accum_data_v[j]); - } - } // End row-block loop. - } // End col-block terminal conditional. -} - -void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params) { - profiler::ScopeLabel label("Kernel kAvx2 float GEMV"); - - RUY_DCHECK_EQ(params.dst_cols, 1); - RUY_DCHECK_EQ(params.last_col, 0); - RUY_DCHECK_EQ(params.start_col, 0); - - // As parameters are defined, we need to scale by sizeof(float). - const std::int64_t lhs_stride = params.lhs_stride >> 2; - // - int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0; - // AVX2 float block size = 8. - const int end_row = std::min(params.dst_rows, params.last_row + 8); - - float* adj_dst_col_ptr = params.dst_base_ptr - params.start_row; - const float* adj_lhs_col_ptr = - params.lhs_base_ptr - params.start_row * lhs_stride; - const float* bias_col_ptr = params.bias; - - const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max); - const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min); - - __m256 accum_data_v; - - const float* rhs_col_ptr = params.rhs_base_ptr; - float* dst_col_ptr = adj_dst_col_ptr; - - int row = params.start_row; - for (; row <= end_row - 8; row += 8) { - const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; - float* dst_ptr = dst_col_ptr + row; - const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; - - // Initialize with bias. - accum_data_v = _mm256_loadu_ps(bias_ptr); - - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr; - int d = 0; - for (; d <= params.depth - 4; d += 4) { - const __m256 lhs_data_0 = _mm256_loadu_ps(lhs_ptr); - const __m256 dup_rhs_element_0 = _mm256_set1_ps(rhs_ptr[0]); - accum_data_v = - _mm256_fmadd_ps(lhs_data_0, dup_rhs_element_0, accum_data_v); - const __m256 dup_rhs_element_1 = _mm256_set1_ps(rhs_ptr[8]); - const __m256 lhs_data_1 = _mm256_loadu_ps(lhs_ptr + 8); - accum_data_v = - _mm256_fmadd_ps(lhs_data_1, dup_rhs_element_1, accum_data_v); - - const __m256 lhs_data_2 = _mm256_loadu_ps(lhs_ptr + 16); - const __m256 dup_rhs_element_2 = _mm256_set1_ps(rhs_ptr[16]); - accum_data_v = - _mm256_fmadd_ps(lhs_data_2, dup_rhs_element_2, accum_data_v); - const __m256 dup_rhs_element_3 = _mm256_set1_ps(rhs_ptr[24]); - const __m256 lhs_data_3 = _mm256_loadu_ps(lhs_ptr + 24); - accum_data_v = - _mm256_fmadd_ps(lhs_data_3, dup_rhs_element_3, accum_data_v); - lhs_ptr += 32; // Loaded 8 * 4 floats. - rhs_ptr += 32; - } - for (; d < params.depth; ++d) { - const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr); - const float* rhs_data = rhs_ptr; - - const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[0]); - accum_data_v = _mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v); - lhs_ptr += 8; - rhs_ptr += 8; - } - - accum_data_v = _mm256_min_ps(accum_data_v, clamp_max_v); - accum_data_v = _mm256_max_ps(accum_data_v, clamp_min_v); - _mm256_storeu_ps(dst_ptr, accum_data_v); - } // End row-block loop. - - if (row < end_row) { - const int residual_rows = end_row - row; - RUY_CHECK_GE(residual_rows, 1); - RUY_CHECK_LT(residual_rows, 8); - - const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; - float* dst_ptr = dst_col_ptr + row; - const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; - - // Initialize with bias. - accum_data_v = intrin_utils::mm256_n_loadu_ps(residual_rows, bias_ptr); - - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr; - for (int d = 0; d < params.depth; ++d) { - const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr); - const float* rhs_data = rhs_ptr; - - const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[0]); - accum_data_v = _mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v); - lhs_ptr += 8; - rhs_ptr += 8; - } - - accum_data_v = _mm256_min_ps(accum_data_v, clamp_max_v); - accum_data_v = _mm256_max_ps(accum_data_v, clamp_min_v); - intrin_utils::mm256_n_storeu_ps(dst_ptr, residual_rows, accum_data_v); - } // End handling of residual rows. -} - -#endif // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/kernel_avx512.cc b/tensorflow/lite/experimental/ruy/ruy/kernel_avx512.cc deleted file mode 100644 index e51876fcc02..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/kernel_avx512.cc +++ /dev/null @@ -1,1820 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/kernel.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" - -#if RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM) -#include // IWYU pragma: keep -#endif - -namespace ruy { - -#if !(RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)) - -void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& params) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -#else // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) { - profiler::ScopeLabel label("Kernel kAvx512 8-bit"); - - std::int32_t dst_stride; - if ((params.dst_type_id == DstTypeId::kValue) || - (params.dst_type_id == DstTypeId::kValue)) { - dst_stride = params.dst_stride; - } else if (params.dst_type_id == DstTypeId::kValue) { - dst_stride = params.dst_stride / sizeof(std::int16_t); - } else if (params.dst_type_id == DstTypeId::kValue) { - dst_stride = params.dst_stride / sizeof(std::int32_t); - } else { - RUY_DCHECK(false); - } - - int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 16 : 0; - - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; - void* dst_col_ptr = params.dst_base_ptr; - const std::int32_t* bias_col_ptr = params.bias; - if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { - bias_col_ptr += params.start_row; - } - - for (int col = params.start_col; col <= params.last_col; col += 16) { - const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - void* dst_ptr = dst_col_ptr; - const std::int32_t* bias_ptr = bias_col_ptr; - - const std::int32_t lhs_zero_point = params.lhs_zero_point; - const bool has_rhs_sums_offsets = - (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point; - std::int32_t rhs_sums_offsets[16]; - if (has_rhs_sums_offsets) { - const __m512i rhs_sums_offset_v = - _mm512_mullo_epi32(_mm512_set1_epi32(lhs_zero_point), - _mm512_loadu_epi32(¶ms.rhs_sums[col])); - _mm512_storeu_si512(reinterpret_cast<__m512i*>(rhs_sums_offsets), - rhs_sums_offset_v); - } - - for (int row = params.start_row; row <= params.last_row; row += 16) { - const int residual_rows = std::min(params.dst_rows - row, 16); - const int residual_cols = std::min(params.dst_cols - col, 16); - - __m512i accum_data_v0; - __m512i accum_data_v1; - __m512i accum_data_v2; - __m512i accum_data_v3; - __m512i accum_data_v4; - __m512i accum_data_v5; - __m512i accum_data_v6; - __m512i accum_data_v7; - __m512i accum_data_v8; - __m512i accum_data_v9; - __m512i accum_data_va; - __m512i accum_data_vb; - __m512i accum_data_vc; - __m512i accum_data_vd; - __m512i accum_data_ve; - __m512i accum_data_vf; - - // Initialize with bias. - const __mmask16 row_mask = - (static_cast(1) << residual_rows) - 1; - __m512i initial_accum_data = _mm512_maskz_loadu_epi32(row_mask, bias_ptr); - bias_ptr += bias_ptr_block_increment; - - const std::int32_t rhs_zero_point = params.rhs_zero_point; - if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) { - const __m512i lhs_sums_offset = - _mm512_mullo_epi32(_mm512_set1_epi32(rhs_zero_point), - _mm512_loadu_epi32(¶ms.lhs_sums[row])); - initial_accum_data = - _mm512_sub_epi32(initial_accum_data, lhs_sums_offset); - } - - const std::int32_t prod_zp_depth = params.prod_zp_depth; - if (prod_zp_depth != 0) { - initial_accum_data = _mm512_add_epi32(initial_accum_data, - _mm512_set1_epi32(prod_zp_depth)); - } - - // Adjustments differing across columns. - if (has_rhs_sums_offsets) { - accum_data_v0 = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[0])); - accum_data_v1 = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[1])); - accum_data_v2 = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[2])); - accum_data_v3 = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[3])); - accum_data_v4 = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[4])); - accum_data_v5 = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[5])); - accum_data_v6 = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[6])); - accum_data_v7 = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[7])); - accum_data_v8 = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[8])); - accum_data_v9 = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[9])); - accum_data_va = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[10])); - accum_data_vb = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[11])); - accum_data_vc = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[12])); - accum_data_vd = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[13])); - accum_data_ve = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[14])); - accum_data_vf = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[15])); - } else { - accum_data_v0 = initial_accum_data; - accum_data_v1 = initial_accum_data; - accum_data_v2 = initial_accum_data; - accum_data_v3 = initial_accum_data; - accum_data_v4 = initial_accum_data; - accum_data_v5 = initial_accum_data; - accum_data_v6 = initial_accum_data; - accum_data_v7 = initial_accum_data; - accum_data_v8 = initial_accum_data; - accum_data_v9 = initial_accum_data; - accum_data_va = initial_accum_data; - accum_data_vb = initial_accum_data; - accum_data_vc = initial_accum_data; - accum_data_vd = initial_accum_data; - accum_data_ve = initial_accum_data; - accum_data_vf = initial_accum_data; - } - - const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; - for (int d = 0; d < params.depth; d += 4) { - const __m512i lhs_data = _mm512_loadu_epi8(lhs_ptr); - __m512i rhs_data_8bit = _mm512_loadu_epi8(rhs_ptr); - - // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. - std::int32_t rhs_data[32]; - const __m256i rhs_data_bottom_lane = - _mm512_castsi512_si256(rhs_data_8bit); - const __m256i rhs_data_top_lane = - _mm512_extracti32x8_epi32(rhs_data_8bit, 1); - const __m512i rhs_16_bit_dup_low = - _mm512_cvtepi8_epi16(rhs_data_bottom_lane); - const __m512i rhs_16_bit_dup_high = - _mm512_cvtepi8_epi16(rhs_data_top_lane); - // Now that we have cast the RHS data, we store it so that each value - // can be separately loaded in the accumulation loop. - _mm512_storeu_si512(reinterpret_cast<__m256i*>(rhs_data), - rhs_16_bit_dup_low); - _mm512_storeu_si512(reinterpret_cast<__m256i*>(rhs_data + 16), - rhs_16_bit_dup_high); - - // Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit. - const __m512i lhs_16_bit_low = - _mm512_cvtepi8_epi16(_mm512_cvtepi32_epi16(lhs_data)); - // Take bytes 2, 3, 6, 7, 10, 11, ... and expand to 16-bit. - const __m512i lhs_16_bit_high = _mm512_cvtepi8_epi16( - _mm512_cvtepi32_epi16(_mm512_srli_epi32(lhs_data, 16))); - - // Process column 0. - { - __m512i accum_v = accum_data_v0; - constexpr int index = 0; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_v0 = accum_v; - } - // Process column 1. - { - __m512i accum_v = accum_data_v1; - constexpr int index = 2; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_v1 = accum_v; - } - // Process column 2. - { - __m512i accum_v = accum_data_v2; - constexpr int index = 4; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_v2 = accum_v; - } - // Process column 3. - { - __m512i accum_v = accum_data_v3; - constexpr int index = 6; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_v3 = accum_v; - } - // Process column 4. - { - __m512i accum_v = accum_data_v4; - constexpr int index = 8; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_v4 = accum_v; - } - // Process column 5. - { - __m512i accum_v = accum_data_v5; - constexpr int index = 10; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_v5 = accum_v; - } - // Process column 6. - { - __m512i accum_v = accum_data_v6; - constexpr int index = 12; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_v6 = accum_v; - } - // Process column 7. - { - __m512i accum_v = accum_data_v7; - constexpr int index = 14; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_v7 = accum_v; - } - // Process column 8. - { - __m512i accum_v = accum_data_v8; - constexpr int index = 16; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_v8 = accum_v; - } - // Process column 9. - { - __m512i accum_v = accum_data_v9; - constexpr int index = 18; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_v9 = accum_v; - } - // Process column 10. - { - __m512i accum_v = accum_data_va; - constexpr int index = 20; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_va = accum_v; - } - // Process column 11. - { - __m512i accum_v = accum_data_vb; - constexpr int index = 22; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_vb = accum_v; - } - // Process column 12. - { - __m512i accum_v = accum_data_vc; - constexpr int index = 24; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_vc = accum_v; - } - // Process column 13. - { - __m512i accum_v = accum_data_vd; - constexpr int index = 26; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_vd = accum_v; - } - // Process column 14. - { - __m512i accum_v = accum_data_ve; - constexpr int index = 28; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_ve = accum_v; - } - // Process column 15. - { - __m512i accum_v = accum_data_vf; - constexpr int index = 30; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_vf = accum_v; - } - - lhs_ptr += 16 * 4; - rhs_ptr += 16 * 4; - } - - if (params.dst_type_id != DstTypeId::kValue) { - __m512i m_vector; - __m512i e_vector; - // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. - if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) { - m_vector = _mm512_maskz_loadu_epi32( - row_mask, ¶ms.multiplier_fixedpoint[row]); - e_vector = _mm512_maskz_loadu_epi32(row_mask, - ¶ms.multiplier_exponent[row]); - } else { - // These arrays have size LhsCols, and are pre-filled. - m_vector = _mm512_set1_epi32(params.multiplier_fixedpoint[0]); - e_vector = _mm512_set1_epi32(params.multiplier_exponent[0]); - } - - const __m512i m_64bit_low = - _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 0)); - const __m512i m_64bit_high = - _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 1)); - - const __m512i zero_vector = _mm512_setzero_epi32(); - const __m512i left_shift = _mm512_max_epi32(e_vector, zero_vector); - const __m512i neg_e_vector = _mm512_sub_epi32(zero_vector, e_vector); - const __m512i right_shift = _mm512_max_epi32(neg_e_vector, zero_vector); - const __m512i final_right_shift = - _mm512_add_epi32(right_shift, _mm512_set1_epi32(31)); - const __m512i final_right_shift_low = _mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(final_right_shift, 0)); - const __m512i final_right_shift_high = _mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(final_right_shift, 1)); - - const __m512i offset_vector = - _mm512_slli_epi64(_mm512_set1_epi64(1), 30); - // Really these should be shifted by neg_e_vector, but tests pass when - // using right_shift. - const __m512i offset_vector_low = _mm512_sllv_epi64( - offset_vector, - _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0))); - const __m512i offset_vector_high = _mm512_sllv_epi64( - offset_vector, - _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1))); - - // Shift and round column 0. - { - accum_data_v0 = _mm512_sllv_epi32(accum_data_v0, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v0, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v0, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_v0 = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_v0 = _mm512_inserti32x8( - accum_data_v0, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 1. - { - accum_data_v1 = _mm512_sllv_epi32(accum_data_v1, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v1, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v1, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_v1 = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_v1 = _mm512_inserti32x8( - accum_data_v1, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 2. - { - accum_data_v2 = _mm512_sllv_epi32(accum_data_v2, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v2, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v2, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_v2 = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_v2 = _mm512_inserti32x8( - accum_data_v2, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 3. - { - accum_data_v3 = _mm512_sllv_epi32(accum_data_v3, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v3, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v3, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_v3 = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_v3 = _mm512_inserti32x8( - accum_data_v3, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 4. - { - accum_data_v4 = _mm512_sllv_epi32(accum_data_v4, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v4, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v4, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_v4 = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_v4 = _mm512_inserti32x8( - accum_data_v4, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 5. - { - accum_data_v5 = _mm512_sllv_epi32(accum_data_v5, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v5, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v5, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_v5 = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_v5 = _mm512_inserti32x8( - accum_data_v5, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 6. - { - accum_data_v6 = _mm512_sllv_epi32(accum_data_v6, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v6, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v6, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_v6 = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_v6 = _mm512_inserti32x8( - accum_data_v6, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 7. - { - accum_data_v7 = _mm512_sllv_epi32(accum_data_v7, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v7, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v7, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_v7 = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_v7 = _mm512_inserti32x8( - accum_data_v7, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 8. - { - accum_data_v8 = _mm512_sllv_epi32(accum_data_v8, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v8, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v8, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_v8 = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_v8 = _mm512_inserti32x8( - accum_data_v8, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 9. - { - accum_data_v9 = _mm512_sllv_epi32(accum_data_v9, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v9, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v9, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_v9 = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_v9 = _mm512_inserti32x8( - accum_data_v9, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 10. - { - accum_data_va = _mm512_sllv_epi32(accum_data_va, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_va, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_va, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_va = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_va = _mm512_inserti32x8( - accum_data_va, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 11. - { - accum_data_vb = _mm512_sllv_epi32(accum_data_vb, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_vb, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_vb, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_vb = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_vb = _mm512_inserti32x8( - accum_data_vb, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 12. - { - accum_data_vc = _mm512_sllv_epi32(accum_data_vc, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_vc, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_vc, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_vc = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_vc = _mm512_inserti32x8( - accum_data_vc, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 13. - { - accum_data_vd = _mm512_sllv_epi32(accum_data_vd, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_vd, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_vd, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_vd = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_vd = _mm512_inserti32x8( - accum_data_vd, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 14. - { - accum_data_ve = _mm512_sllv_epi32(accum_data_ve, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_ve, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_ve, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_ve = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_ve = _mm512_inserti32x8( - accum_data_ve, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 15. - { - accum_data_vf = _mm512_sllv_epi32(accum_data_vf, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_vf, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_vf, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_vf = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_vf = _mm512_inserti32x8( - accum_data_vf, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } -#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) - RUY_DCHECK(false); -#endif - - if (params.dst_zero_point != 0) { - __m512i dst_zero_point = _mm512_set1_epi32(params.dst_zero_point); - accum_data_v0 = _mm512_add_epi32(accum_data_v0, dst_zero_point); - accum_data_v1 = _mm512_add_epi32(accum_data_v1, dst_zero_point); - accum_data_v2 = _mm512_add_epi32(accum_data_v2, dst_zero_point); - accum_data_v3 = _mm512_add_epi32(accum_data_v3, dst_zero_point); - accum_data_v4 = _mm512_add_epi32(accum_data_v4, dst_zero_point); - accum_data_v5 = _mm512_add_epi32(accum_data_v5, dst_zero_point); - accum_data_v6 = _mm512_add_epi32(accum_data_v6, dst_zero_point); - accum_data_v7 = _mm512_add_epi32(accum_data_v7, dst_zero_point); - accum_data_v8 = _mm512_add_epi32(accum_data_v8, dst_zero_point); - accum_data_v9 = _mm512_add_epi32(accum_data_v9, dst_zero_point); - accum_data_va = _mm512_add_epi32(accum_data_va, dst_zero_point); - accum_data_vb = _mm512_add_epi32(accum_data_vb, dst_zero_point); - accum_data_vc = _mm512_add_epi32(accum_data_vc, dst_zero_point); - accum_data_vd = _mm512_add_epi32(accum_data_vd, dst_zero_point); - accum_data_ve = _mm512_add_epi32(accum_data_ve, dst_zero_point); - accum_data_vf = _mm512_add_epi32(accum_data_vf, dst_zero_point); - } - } - - const __m512i clamp_max_v = _mm512_set1_epi32(params.clamp_max); - const __m512i clamp_min_v = _mm512_set1_epi32(params.clamp_min); - - const bool store_full_block = - (residual_rows == 16) && (residual_cols == 16); - - __m512i accum_data_v[16]; - - // In most cases we would make this conditional on (!store_full_block) and - // unwind the clamp-and-store loop, but the benefit appears small. - { - accum_data_v[0] = accum_data_v0; - accum_data_v[1] = accum_data_v1; - accum_data_v[2] = accum_data_v2; - accum_data_v[3] = accum_data_v3; - accum_data_v[4] = accum_data_v4; - accum_data_v[5] = accum_data_v5; - accum_data_v[6] = accum_data_v6; - accum_data_v[7] = accum_data_v7; - accum_data_v[8] = accum_data_v8; - accum_data_v[9] = accum_data_v9; - accum_data_v[10] = accum_data_va; - accum_data_v[11] = accum_data_vb; - accum_data_v[12] = accum_data_vc; - accum_data_v[13] = accum_data_vd; - accum_data_v[14] = accum_data_ve; - accum_data_v[15] = accum_data_vf; - } - - if (params.dst_type_id == DstTypeId::kValue) { - std::int8_t* tmp_ptr = static_cast(dst_ptr); - const int block_col_offset = dst_stride; - if (store_full_block) { - for (int j = 0; j < 16; ++j) { - __m512i result = accum_data_v[j]; - result = _mm512_min_epi32(result, clamp_max_v); - result = _mm512_max_epi32(result, clamp_min_v); - _mm_storeu_epi8(tmp_ptr + j * block_col_offset, - _mm512_cvtepi32_epi8(result)); - } - } else { - for (int j = 0; j < residual_cols; ++j) { - __m512i result = accum_data_v[j]; - result = _mm512_min_epi32(result, clamp_max_v); - result = _mm512_max_epi32(result, clamp_min_v); - _mm_mask_storeu_epi8(tmp_ptr + j * block_col_offset, row_mask, - _mm512_cvtepi32_epi8(result)); - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + 16); - } else if (params.dst_type_id == DstTypeId::kValue) { - std::uint8_t* tmp_ptr = static_cast(dst_ptr); - const int block_col_offset = dst_stride; - if (store_full_block) { - for (int j = 0; j < residual_cols; ++j) { - __m512i result = accum_data_v[j]; - result = _mm512_min_epi32(result, clamp_max_v); - result = _mm512_max_epi32(result, clamp_min_v); - _mm_storeu_epi8(tmp_ptr + j * block_col_offset, - _mm512_cvtepi32_epi8(result)); - } - } else { - for (int j = 0; j < residual_cols; ++j) { - __m512i result = accum_data_v[j]; - result = _mm512_min_epi32(result, clamp_max_v); - result = _mm512_max_epi32(result, clamp_min_v); - _mm_mask_storeu_epi8(tmp_ptr + j * block_col_offset, row_mask, - _mm512_cvtepi32_epi8(result)); - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + 16); - } else if (params.dst_type_id == DstTypeId::kValue) { - std::int16_t* tmp_ptr = static_cast(dst_ptr); - const int block_col_offset = dst_stride; - if (store_full_block) { - for (int j = 0; j < 16; ++j) { - __m512i result = accum_data_v[j]; - result = _mm512_min_epi32(result, clamp_max_v); - result = _mm512_max_epi32(result, clamp_min_v); - _mm256_storeu_epi16(tmp_ptr + j * block_col_offset, - _mm512_cvtepi32_epi16(result)); - } - } else { - for (int j = 0; j < residual_cols; ++j) { - __m512i result = accum_data_v[j]; - result = _mm512_min_epi32(result, clamp_max_v); - result = _mm512_max_epi32(result, clamp_min_v); - _mm256_mask_storeu_epi16(tmp_ptr + j * block_col_offset, row_mask, - _mm512_cvtepi32_epi16(result)); - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + 16); - } else if (params.dst_type_id == DstTypeId::kValue) { - if (store_full_block) { - std::int32_t* tmp_ptr = static_cast(dst_ptr); - for (int j = 0; j < 16; ++j) { - _mm512_storeu_epi32(tmp_ptr + j * dst_stride, accum_data_v[j]); - } - } else { - std::int32_t* tmp_ptr = static_cast(dst_ptr); - for (int j = 0; j < residual_cols; ++j) { - _mm512_mask_storeu_epi32(tmp_ptr + j * dst_stride, row_mask, - accum_data_v[j]); - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + 16); - } else { - RUY_DCHECK(false); - } - - lhs_col_ptr += 16 * params.lhs_stride; - } // End row-block loop. - - dst_col_ptr = static_cast(static_cast(dst_col_ptr) + - 16 * params.dst_stride); - rhs_col_ptr += 16 * params.rhs_stride; - } // End col-block loop. -} // NOLINT(readability/fn_size) - -void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) { - profiler::ScopeLabel label("Kernel kAvx512 8-bit GEMV"); - - RUY_DCHECK_EQ(params.dst_cols, 1); - RUY_DCHECK_EQ(params.last_col, 0); - RUY_DCHECK_EQ(params.start_col, 0); - - std::int32_t dst_stride; - if ((params.dst_type_id == DstTypeId::kValue) || - (params.dst_type_id == DstTypeId::kValue)) { - dst_stride = params.dst_stride; - } else if (params.dst_type_id == DstTypeId::kValue) { - dst_stride = params.dst_stride / sizeof(std::int16_t); - } else if (params.dst_type_id == DstTypeId::kValue) { - dst_stride = params.dst_stride / sizeof(std::int32_t); - } else { - RUY_DCHECK(false); - } - - int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 16 : 0; - - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; - void* dst_col_ptr = params.dst_base_ptr; - const std::int32_t* bias_col_ptr = params.bias; - if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { - bias_col_ptr += params.start_row; - } - - const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - void* dst_ptr = dst_col_ptr; - const std::int32_t* bias_ptr = bias_col_ptr; - - const std::int32_t lhs_zero_point = params.lhs_zero_point; - const bool has_rhs_sums_offsets = - (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point; - std::int32_t rhs_sums_offsets[16]; - if (has_rhs_sums_offsets) { - const __m512i rhs_sums_offset_v = - _mm512_mullo_epi32(_mm512_set1_epi32(lhs_zero_point), - _mm512_loadu_epi32(¶ms.rhs_sums[0])); - _mm512_storeu_si512(reinterpret_cast<__m512i*>(rhs_sums_offsets), - rhs_sums_offset_v); - } - - for (int row = params.start_row; row <= params.last_row; row += 16) { - const int residual_rows = std::min(params.dst_rows - row, 16); - - __m512i accum_data_v0; - - // Initialize with bias. - const __mmask16 row_mask = - (static_cast(1) << residual_rows) - 1; - __m512i initial_accum_data = _mm512_maskz_loadu_epi32(row_mask, bias_ptr); - bias_ptr += bias_ptr_block_increment; - - const std::int32_t rhs_zero_point = params.rhs_zero_point; - if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) { - const __m512i lhs_sums_offset = - _mm512_mullo_epi32(_mm512_set1_epi32(rhs_zero_point), - _mm512_loadu_epi32(¶ms.lhs_sums[row])); - initial_accum_data = - _mm512_sub_epi32(initial_accum_data, lhs_sums_offset); - } - - const std::int32_t prod_zp_depth = params.prod_zp_depth; - if (prod_zp_depth != 0) { - initial_accum_data = _mm512_add_epi32(initial_accum_data, - _mm512_set1_epi32(prod_zp_depth)); - } - - // Adjustments differing across columns. - if (has_rhs_sums_offsets) { - accum_data_v0 = _mm512_sub_epi32(initial_accum_data, - _mm512_set1_epi32(rhs_sums_offsets[0])); - } else { - accum_data_v0 = initial_accum_data; - } - - const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; - for (int d = 0; d < params.depth; d += 4) { - const __m512i lhs_data = _mm512_loadu_epi8(lhs_ptr); - const __m128i rhs_data_8bit = _mm_loadu_epi8(rhs_ptr); - - // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. - // For simplicity we load 4x the data that we need and process twice the - // data that we need and store only the data we need. - std::int32_t rhs_data[2]; - const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit); - // Now that we have cast the RHS data, we store it so that each value - // can be separately loaded in the accumulation loop. - _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup); - - // Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit. - const __m512i lhs_16_bit_low = - _mm512_cvtepi8_epi16(_mm512_cvtepi32_epi16(lhs_data)); - // Take bytes 2, 3, 6, 7, 10, 11, ... and expand to 16-bit. - const __m512i lhs_16_bit_high = _mm512_cvtepi8_epi16( - _mm512_cvtepi32_epi16(_mm512_srli_epi32(lhs_data, 16))); - - // Process column 0. - __m512i accum_v = accum_data_v0; - constexpr int index = 0; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_v0 = accum_v; - - lhs_ptr += 16 * 4; - rhs_ptr += 16 * 4; - } - - if (params.dst_type_id != DstTypeId::kValue) { - __m512i m_vector; - __m512i e_vector; - // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. - if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) { - m_vector = _mm512_maskz_loadu_epi32(row_mask, - ¶ms.multiplier_fixedpoint[row]); - e_vector = _mm512_maskz_loadu_epi32(row_mask, - ¶ms.multiplier_exponent[row]); - } else { - // These arrays have size LhsCols, and are pre-filled. - m_vector = _mm512_set1_epi32(params.multiplier_fixedpoint[0]); - e_vector = _mm512_set1_epi32(params.multiplier_exponent[0]); - } - - const __m512i m_64bit_low = - _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 0)); - const __m512i m_64bit_high = - _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 1)); - - const __m512i zero_vector = _mm512_setzero_epi32(); - const __m512i left_shift = _mm512_max_epi32(e_vector, zero_vector); - const __m512i neg_e_vector = _mm512_sub_epi32(zero_vector, e_vector); - const __m512i right_shift = _mm512_max_epi32(neg_e_vector, zero_vector); - const __m512i final_right_shift = - _mm512_add_epi32(right_shift, _mm512_set1_epi32(31)); - const __m512i final_right_shift_low = _mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(final_right_shift, 0)); - const __m512i final_right_shift_high = _mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(final_right_shift, 1)); - - const __m512i offset_vector = _mm512_slli_epi64(_mm512_set1_epi64(1), 30); - // Really these should be shifted by neg_e_vector, but tests pass when - // using right_shift. - const __m512i offset_vector_low = _mm512_sllv_epi64( - offset_vector, - _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0))); - const __m512i offset_vector_high = _mm512_sllv_epi64( - offset_vector, - _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1))); - - // Shift and round column 0. - accum_data_v0 = _mm512_sllv_epi32(accum_data_v0, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = _mm512_mul_epi32( - _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum_data_v0, 0)), - m_64bit_low); - __m512i scaled_v_high = _mm512_mul_epi32( - _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum_data_v0, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_v0 = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_v0 = _mm512_inserti32x8( - accum_data_v0, _mm512_cvtepi64_epi32(scaled_v_high), 1); -#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) - RUY_DCHECK(false); -#endif - - if (params.dst_zero_point != 0) { - __m512i dst_zero_point = _mm512_set1_epi32(params.dst_zero_point); - accum_data_v0 = _mm512_add_epi32(accum_data_v0, dst_zero_point); - } - } - - const __m512i clamp_max_v = _mm512_set1_epi32(params.clamp_max); - const __m512i clamp_min_v = _mm512_set1_epi32(params.clamp_min); - - if (params.dst_type_id == DstTypeId::kValue) { - std::int8_t* tmp_ptr = static_cast(dst_ptr); - __m512i result = accum_data_v0; - result = _mm512_min_epi32(result, clamp_max_v); - result = _mm512_max_epi32(result, clamp_min_v); - _mm_mask_storeu_epi8(tmp_ptr, row_mask, _mm512_cvtepi32_epi8(result)); - dst_ptr = static_cast(static_cast(dst_ptr) + 16); - } else if (params.dst_type_id == DstTypeId::kValue) { - std::uint8_t* tmp_ptr = static_cast(dst_ptr); - __m512i result = accum_data_v0; - result = _mm512_min_epi32(result, clamp_max_v); - result = _mm512_max_epi32(result, clamp_min_v); - _mm_mask_storeu_epi8(tmp_ptr, row_mask, _mm512_cvtepi32_epi8(result)); - dst_ptr = static_cast(static_cast(dst_ptr) + 16); - } else if (params.dst_type_id == DstTypeId::kValue) { - std::int16_t* tmp_ptr = static_cast(dst_ptr); - __m512i result = accum_data_v0; - result = _mm512_min_epi32(result, clamp_max_v); - result = _mm512_max_epi32(result, clamp_min_v); - _mm256_mask_storeu_epi16(tmp_ptr, row_mask, - _mm512_cvtepi32_epi16(result)); - dst_ptr = static_cast(static_cast(dst_ptr) + 16); - } else if (params.dst_type_id == DstTypeId::kValue) { - std::int32_t* tmp_ptr = static_cast(dst_ptr); - _mm512_mask_storeu_epi32(tmp_ptr, row_mask, accum_data_v0); - dst_ptr = static_cast(static_cast(dst_ptr) + 16); - } else { - RUY_DCHECK(false); - } - - lhs_col_ptr += 16 * params.lhs_stride; - } // End row-block loop. -} // NOLINT(readability/fn_size) - -void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) { - profiler::ScopeLabel label("Kernel kAvx512 float"); - - // As parameters are defined, we need to scale by sizeof(float). - const std::int64_t lhs_stride = params.lhs_stride >> 2; - const std::int64_t dst_stride = params.dst_stride >> 2; - const std::int64_t rhs_stride = params.rhs_stride >> 2; - - int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0; - const int end_row = std::min(params.dst_rows, params.last_row + 16); - const int end_col = std::min(params.dst_cols, params.last_col + 16); - - const float* adj_rhs_col_ptr = - params.rhs_base_ptr - params.start_col * rhs_stride; - float* adj_dst_col_ptr = - params.dst_base_ptr - params.start_col * dst_stride - params.start_row; - const float* adj_lhs_col_ptr = - params.lhs_base_ptr - params.start_row * lhs_stride; - const float* bias_col_ptr = params.bias; - - const __m512 clamp_max_v = _mm512_set1_ps(params.clamp_max); - const __m512 clamp_min_v = _mm512_set1_ps(params.clamp_min); - - int col = params.start_col; - for (; col <= end_col - 16; col += 16) { - const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride; - float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride; - - int row = params.start_row; - for (; row <= end_row - 16; row += 16) { - const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; - float* dst_ptr = dst_col_ptr + row; - const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; - - // Initialize with bias. - const __m512 initial_accum_data = _mm512_loadu_ps(bias_ptr); - - // Process block in two halves, split by columns. - { - constexpr int mmm = 0; - - __m512 accum_data_v0 = initial_accum_data; - __m512 accum_data_v1 = initial_accum_data; - __m512 accum_data_v2 = initial_accum_data; - __m512 accum_data_v3 = initial_accum_data; - __m512 accum_data_v4 = initial_accum_data; - __m512 accum_data_v5 = initial_accum_data; - __m512 accum_data_v6 = initial_accum_data; - __m512 accum_data_v7 = initial_accum_data; - - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr + 8 * mmm; - for (int d = 0; d < (params.depth - 1); ++d) { - const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); - // In this version RHS values are loaded individually rather than - // first loading together and then extract with broadcasting. This is - // because AVX flavours and instrinsics and compilers in combination - // do not handle this pattern of extraction very well. - const float* rhs_data = rhs_ptr; - lhs_ptr += 16; - rhs_ptr += 16; - - { - const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]); - accum_data_v0 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); - const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]); - accum_data_v1 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); - const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]); - accum_data_v2 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); - const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]); - accum_data_v3 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); - const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]); - accum_data_v4 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); - const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]); - accum_data_v5 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); - const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]); - accum_data_v6 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); - const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]); - accum_data_v7 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); - } - } - { - const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); - const float* rhs_data = rhs_ptr; - { - const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]); - accum_data_v0 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); - const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]); - accum_data_v1 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); - const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]); - accum_data_v2 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); - const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]); - accum_data_v3 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); - const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]); - accum_data_v4 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); - const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]); - accum_data_v5 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); - const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]); - accum_data_v6 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); - const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]); - accum_data_v7 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); - } - { - float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride; - accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v); - accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v); - _mm512_storeu_ps(block_ptr + 0 * dst_stride, accum_data_v0); - accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v); - accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v); - _mm512_storeu_ps(block_ptr + 1 * dst_stride, accum_data_v1); - accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v); - accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v); - _mm512_storeu_ps(block_ptr + 2 * dst_stride, accum_data_v2); - accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v); - accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v); - _mm512_storeu_ps(block_ptr + 3 * dst_stride, accum_data_v3); - accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v); - accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v); - _mm512_storeu_ps(block_ptr + 4 * dst_stride, accum_data_v4); - accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v); - accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v); - _mm512_storeu_ps(block_ptr + 5 * dst_stride, accum_data_v5); - accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v); - accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v); - _mm512_storeu_ps(block_ptr + 6 * dst_stride, accum_data_v6); - accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v); - accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v); - _mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7); - } - } - } // Inner half-block loop, unrolled, first iteration. - { - constexpr int mmm = 1; - - __m512 accum_data_v0 = initial_accum_data; - __m512 accum_data_v1 = initial_accum_data; - __m512 accum_data_v2 = initial_accum_data; - __m512 accum_data_v3 = initial_accum_data; - __m512 accum_data_v4 = initial_accum_data; - __m512 accum_data_v5 = initial_accum_data; - __m512 accum_data_v6 = initial_accum_data; - __m512 accum_data_v7 = initial_accum_data; - - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr + 8 * mmm; - for (int d = 0; d < (params.depth - 1); ++d) { - const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); - const float* rhs_data = rhs_ptr; - lhs_ptr += 16; - rhs_ptr += 16; - { - const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]); - accum_data_v0 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); - const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]); - accum_data_v1 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); - const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]); - accum_data_v2 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); - const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]); - accum_data_v3 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); - const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]); - accum_data_v4 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); - const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]); - accum_data_v5 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); - const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]); - accum_data_v6 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); - const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]); - accum_data_v7 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); - } - } - { - const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); - const float* rhs_data = rhs_ptr; - { - const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]); - accum_data_v0 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); - const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]); - accum_data_v1 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); - const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]); - accum_data_v2 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); - const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]); - accum_data_v3 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); - const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]); - accum_data_v4 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); - const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]); - accum_data_v5 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); - const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]); - accum_data_v6 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); - const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]); - accum_data_v7 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); - } - { - float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride; - accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v); - accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v); - _mm512_storeu_ps(block_ptr + 0 * dst_stride, accum_data_v0); - accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v); - accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v); - _mm512_storeu_ps(block_ptr + 1 * dst_stride, accum_data_v1); - accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v); - accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v); - _mm512_storeu_ps(block_ptr + 2 * dst_stride, accum_data_v2); - accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v); - accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v); - _mm512_storeu_ps(block_ptr + 3 * dst_stride, accum_data_v3); - accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v); - accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v); - _mm512_storeu_ps(block_ptr + 4 * dst_stride, accum_data_v4); - accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v); - accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v); - _mm512_storeu_ps(block_ptr + 5 * dst_stride, accum_data_v5); - accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v); - accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v); - _mm512_storeu_ps(block_ptr + 6 * dst_stride, accum_data_v6); - accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v); - accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v); - _mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7); - } - } - } // Inner half-block loop, unrolled, second iteration. - } // End row-block loop. - - // The unrolling within this conditional may be somewhat pointless. It - // depends on the kinds of models. - if (row < end_row) { - const int residual_rows = end_row - row; - - const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; - float* dst_ptr = dst_col_ptr + row; - const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; - - // Initialize with bias. - const __mmask16 row_mask = - (static_cast(1) << residual_rows) - 1; - const __m512 initial_accum_data = - _mm512_maskz_loadu_ps(row_mask, bias_ptr); - - // Process block in two halves, split by columns. - for (int mmm = 0; mmm < 2; ++mmm) { - __m512 accum_data_v0 = initial_accum_data; - __m512 accum_data_v1 = initial_accum_data; - __m512 accum_data_v2 = initial_accum_data; - __m512 accum_data_v3 = initial_accum_data; - __m512 accum_data_v4 = initial_accum_data; - __m512 accum_data_v5 = initial_accum_data; - __m512 accum_data_v6 = initial_accum_data; - __m512 accum_data_v7 = initial_accum_data; - - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr + 8 * mmm; - for (int d = 0; d < (params.depth - 1); ++d) { - const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); - const float* rhs_data = rhs_ptr; - lhs_ptr += 16; - rhs_ptr += 16; - { - const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]); - accum_data_v0 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); - const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]); - accum_data_v1 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); - const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]); - accum_data_v2 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); - const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]); - accum_data_v3 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); - const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]); - accum_data_v4 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); - const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]); - accum_data_v5 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); - const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]); - accum_data_v6 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); - const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]); - accum_data_v7 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); - } - } - { - const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); - const float* rhs_data = rhs_ptr; - { - const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]); - accum_data_v0 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); - const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]); - accum_data_v1 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); - const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]); - accum_data_v2 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); - const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]); - accum_data_v3 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); - const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]); - accum_data_v4 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); - const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]); - accum_data_v5 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); - const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]); - accum_data_v6 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); - const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]); - accum_data_v7 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); - } - { - float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride; - accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v); - accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v); - _mm512_mask_storeu_ps(block_ptr + 0 * dst_stride, row_mask, - accum_data_v0); - accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v); - accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v); - _mm512_mask_storeu_ps(block_ptr + 1 * dst_stride, row_mask, - accum_data_v1); - accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v); - accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v); - _mm512_mask_storeu_ps(block_ptr + 2 * dst_stride, row_mask, - accum_data_v2); - accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v); - accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v); - _mm512_mask_storeu_ps(block_ptr + 3 * dst_stride, row_mask, - accum_data_v3); - accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v); - accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v); - _mm512_mask_storeu_ps(block_ptr + 4 * dst_stride, row_mask, - accum_data_v4); - accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v); - accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v); - _mm512_mask_storeu_ps(block_ptr + 5 * dst_stride, row_mask, - accum_data_v5); - accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v); - accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v); - _mm512_mask_storeu_ps(block_ptr + 6 * dst_stride, row_mask, - accum_data_v6); - accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v); - accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v); - _mm512_mask_storeu_ps(block_ptr + 7 * dst_stride, row_mask, - accum_data_v7); - } - } - } // Inner half-block loop. - } // Residual rows, main col-block loop. - } // End col-block loop. - - if (col < end_col) { - RUY_DCHECK_GE(end_col - col, 0); - RUY_DCHECK_LT(end_col - col, 16); - - __m512 accum_data_v[8]; - - const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride; - float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride; - - for (int row = params.start_row; row < end_row; row += 16) { - const int residual_rows = std::min(end_row - row, 16); - - const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; - float* dst_ptr = dst_col_ptr + row; - const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; - - // Initialize with bias. - const __mmask16 row_mask = - (static_cast(1) << residual_rows) - 1; - const __m512 initial_accum_data = - _mm512_maskz_loadu_ps(row_mask, bias_ptr); - - // Process block in two halves, split by columns. - for (int mmm = 0; mmm < 2; ++mmm) { - for (int j = 0; j < 8; ++j) { - accum_data_v[j] = initial_accum_data; - } - - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr + 8 * mmm; - for (int d = 0; d < params.depth; ++d) { - const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); - const float* rhs_data = rhs_ptr; - - for (int j = 0; j < 8; ++j) { - const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data[j]); - accum_data_v[j] = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]); - } - lhs_ptr += 16; - rhs_ptr += 16; - } - - const int residual_cols = std::min(end_col - col - 8 * mmm, 8); - - if (residual_rows == 16) { - if (residual_cols == 8) { - for (int j = 0; j < 8; ++j) { - float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride; - accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v); - accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v); - _mm512_storeu_ps(block_ptr, accum_data_v[j]); - } - } else { - for (int j = 0; j < residual_cols; ++j) { - float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride; - accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v); - accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v); - _mm512_storeu_ps(block_ptr, accum_data_v[j]); - } - } - } else { - for (int j = 0; j < residual_cols; ++j) { - float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride; - accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v); - accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v); - _mm512_mask_storeu_ps(block_ptr, row_mask, accum_data_v[j]); - } - } - } // Inner half-block loop. - } // End row-block loop. - } // Residual cols. -} - -void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& params) { - profiler::ScopeLabel label("Kernel kAvx512 float GEMV"); - - RUY_DCHECK_EQ(params.dst_cols, 1); - RUY_DCHECK_EQ(params.last_col, 0); - RUY_DCHECK_EQ(params.start_col, 0); - - // As parameters are defined, we need to scale by sizeof(float). - const std::int64_t lhs_stride = params.lhs_stride >> 2; - - int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0; - const int end_row = std::min(params.dst_rows, params.last_row + 16); - - float* adj_dst_col_ptr = params.dst_base_ptr - params.start_row; - const float* adj_lhs_col_ptr = - params.lhs_base_ptr - params.start_row * lhs_stride; - const float* bias_col_ptr = params.bias; - - const __m512 clamp_max_v = _mm512_set1_ps(params.clamp_max); - const __m512 clamp_min_v = _mm512_set1_ps(params.clamp_min); - - __m512 accum_data_v; - - const float* rhs_col_ptr = params.rhs_base_ptr; - float* dst_col_ptr = adj_dst_col_ptr; - - int row = params.start_row; - for (; row <= end_row - 16; row += 16) { - const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; - float* dst_ptr = dst_col_ptr + row; - const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; - - // Initialize with bias. - accum_data_v = _mm512_loadu_ps(bias_ptr); - - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr; - for (int d = 0; d < params.depth; ++d) { - const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); - const float rhs_data = *rhs_ptr; - - const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data); - accum_data_v = _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v); - lhs_ptr += 16; - rhs_ptr += 16; - } - - accum_data_v = _mm512_min_ps(accum_data_v, clamp_max_v); - accum_data_v = _mm512_max_ps(accum_data_v, clamp_min_v); - _mm512_storeu_ps(dst_ptr, accum_data_v); - } // End row-block loop. - - if (row < end_row) { - const int residual_rows = end_row - row; - RUY_CHECK_GE(residual_rows, 1); - RUY_CHECK_LT(residual_rows, 16); - - const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; - float* dst_ptr = dst_col_ptr + row; - const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; - - // Initialize with bias. - const __mmask16 row_mask = - (static_cast(1) << residual_rows) - 1; - accum_data_v = _mm512_maskz_loadu_ps(row_mask, bias_ptr); - - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr; - for (int d = 0; d < params.depth; ++d) { - const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); - const float rhs_data = *rhs_ptr; - - const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data); - accum_data_v = _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v); - lhs_ptr += 16; - rhs_ptr += 16; - } - - accum_data_v = _mm512_min_ps(accum_data_v, clamp_max_v); - accum_data_v = _mm512_max_ps(accum_data_v, clamp_min_v); - _mm512_mask_storeu_ps(dst_ptr, row_mask, accum_data_v); - } // End handling of residual rows. -} - -#endif // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/kernel_avxvnni.cc b/tensorflow/lite/experimental/ruy/ruy/kernel_avxvnni.cc deleted file mode 100644 index c868c00957b..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/kernel_avxvnni.cc +++ /dev/null @@ -1,435 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/kernel.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" - -#if RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM) -#include // IWYU pragma: keep -#endif - -namespace ruy { - -#if !(RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM)) - -void Kernel8bitAvxVnni(const KernelParams8bit<16, 16>& params) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -void KernelFloatAvxVnni(const KernelParamsFloat<16, 16>& params) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -#else // RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -static constexpr int kAvxFloatBlockSize = 16; -static constexpr int kAvx8bitBlockSize = 16; -static constexpr int kAvx8bitInnerSize = 4; - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// When removing this comment, update profiling label below. -void Kernel8bitAvxVnni(const KernelParams8bit<16, 16>& params) { - profiler::ScopeLabel label("Kernel kAvxVnni 8-bit (UNFINISHED)"); - - std::int32_t accum_data[kAvx8bitBlockSize][kAvx8bitBlockSize]; - - int bias_ptr_block_increment = - params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0; - - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; - void* dst_col_ptr = params.dst_base_ptr; - const std::int32_t* bias_col_ptr = params.bias; - if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { - bias_col_ptr += params.start_row; - } - - for (int col = params.start_col; col <= params.last_col; - col += kAvx8bitBlockSize) { - const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - void* dst_ptr = dst_col_ptr; - const std::int32_t* bias_ptr = bias_col_ptr; - - for (int row = params.start_row; row <= params.last_row; - row += kAvx8bitBlockSize) { - const int residual_rows = - std::min(params.dst_rows - row, kAvx8bitBlockSize); - const int residual_cols = - std::min(params.dst_cols - col, kAvx8bitBlockSize); - - // Initialize with bias. - std::int32_t initial_accum_data[kAvx8bitBlockSize]; - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - initial_accum_data[i] = 0; - } - for (int i = 0; i < residual_rows; ++i) { - initial_accum_data[i] = bias_ptr[i]; - } - - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] = initial_accum_data[i]; - } - } - bias_ptr += bias_ptr_block_increment; - - std::int8_t lhs_data[kAvx8bitBlockSize][kAvx8bitInnerSize]; - std::int8_t rhs_data[kAvx8bitBlockSize][kAvx8bitInnerSize]; - const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; - for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - for (int x = 0; x < kAvx8bitInnerSize; ++x) { - lhs_data[i][x] = lhs_ptr[i * kAvx8bitInnerSize + x]; - rhs_data[i][x] = rhs_ptr[i * kAvx8bitInnerSize + x]; - } - } - - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - for (int x = 0; x < kAvx8bitInnerSize; ++x) { - accum_data[j][i] += lhs_data[i][x] * rhs_data[j][x]; - } - } - } - - lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; - rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; - } - - if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && params.rhs_zero_point) { - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] -= - params.rhs_zero_point * params.lhs_sums[row + i]; - } - } - } - if ((params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && params.lhs_zero_point) { - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] -= - params.lhs_zero_point * params.rhs_sums[col + j]; - } - } - } - if (params.lhs_zero_point && params.rhs_zero_point) { - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] += params.prod_zp_depth; - } - } - } - - if (params.dst_type_id != DstTypeId::kValue) { - std::int32_t m_vector[kAvx8bitBlockSize]; - std::int32_t e_vector[kAvx8bitBlockSize]; - // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. - if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) { - int i = 0; - for (; i < residual_rows; ++i) { - m_vector[i] = params.multiplier_fixedpoint[row + i]; - e_vector[i] = params.multiplier_exponent[row + i]; - } - for (; i < kAvx8bitBlockSize; ++i) { - m_vector[i] = m_vector[0]; - e_vector[i] = e_vector[0]; - } - } else { - // These arrays have size LhsCols, and are pre-filled. - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - m_vector[i] = params.multiplier_fixedpoint[i]; - e_vector[i] = params.multiplier_exponent[i]; - } - } - - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] = MultiplyByQuantizedMultiplier( - accum_data[j][i], m_vector[i], e_vector[i]); - } - } - - if (params.dst_zero_point) { - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] += params.dst_zero_point; - } - } - } - - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] = - std::min(accum_data[j][i], params.clamp_max); - accum_data[j][i] = - std::max(accum_data[j][i], params.clamp_min); - } - } - } - - const bool store_full_block = (residual_rows == kAvx8bitBlockSize) && - (residual_cols == kAvx8bitBlockSize); - - if (params.dst_type_id == DstTypeId::kValue) { - std::int8_t* tmp_ptr = - store_full_block - ? static_cast(dst_ptr) - : const_cast( - reinterpret_cast(params.dst_tmp_buf)); - const int block_col_offset = - store_full_block ? params.dst_stride / sizeof(std::int8_t) - : kAvx8bitBlockSize; - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - tmp_ptr[i] = accum_data[j][i]; - } - tmp_ptr += block_col_offset; - } - - if (!store_full_block) { - const std::int8_t* block_ptr = - reinterpret_cast(params.dst_tmp_buf); - for (int j = 0; j < residual_cols; ++j) { - for (int i = 0; i < residual_rows; ++i) { - static_cast( - dst_ptr)[j * params.dst_stride / sizeof(std::int8_t) + i] = - block_ptr[i]; - } - block_ptr += kAvx8bitBlockSize; - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else if (params.dst_type_id == DstTypeId::kValue) { - std::uint8_t* tmp_ptr = store_full_block - ? static_cast(dst_ptr) - : const_cast( - reinterpret_cast( - params.dst_tmp_buf)); - const int block_col_offset = - store_full_block ? params.dst_stride : kAvx8bitBlockSize; - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - tmp_ptr[i] = accum_data[j][i]; - } - tmp_ptr += block_col_offset; - } - - if (!store_full_block) { - const std::uint8_t* block_ptr = - reinterpret_cast(params.dst_tmp_buf); - for (int j = 0; j < residual_cols; ++j) { - for (int i = 0; i < residual_rows; ++i) { - static_cast( - dst_ptr)[j * params.dst_stride / sizeof(std::uint8_t) + i] = - block_ptr[i]; - } - block_ptr += kAvx8bitBlockSize; - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else if (params.dst_type_id == DstTypeId::kValue) { - if (store_full_block) { - std::int16_t* tmp_ptr = static_cast(dst_ptr); - const int block_col_offset = params.dst_stride / sizeof(std::int16_t); - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - tmp_ptr[i] = accum_data[j][i]; - } - tmp_ptr += block_col_offset; - } - } else { - std::int16_t* tmp_ptr = const_cast( - reinterpret_cast(params.dst_tmp_buf)); - const int block_col_offset = kAvx8bitBlockSize; - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - tmp_ptr[i] = accum_data[j][i]; - } - tmp_ptr += block_col_offset; - } - const std::int16_t* block_ptr = - reinterpret_cast(params.dst_tmp_buf); - std::int16_t* dst_block_ptr = static_cast(dst_ptr); - for (int j = 0; j < residual_cols; ++j) { - for (int i = 0; i < residual_rows; ++i) { - dst_block_ptr[i] = block_ptr[i]; - } - dst_block_ptr += params.dst_stride / sizeof(std::int16_t); - block_ptr += kAvx8bitBlockSize; - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else if (params.dst_type_id == DstTypeId::kValue) { - if (store_full_block) { - std::int32_t* tmp_ptr = static_cast(dst_ptr); - const int block_col_offset = params.dst_stride / sizeof(std::int32_t); - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - tmp_ptr[i] = accum_data[j][i]; - } - tmp_ptr += block_col_offset; - } - } else { - std::int32_t* dst_block_ptr = static_cast(dst_ptr); - for (int j = 0; j < residual_cols; ++j) { - for (int i = 0; i < residual_rows; ++i) { - dst_block_ptr[i] = accum_data[j][i]; - } - dst_block_ptr += params.dst_stride / sizeof(std::int32_t); - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else { - RUY_DCHECK(false); - } - - lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride; - } // End row-block loop. - - dst_col_ptr = static_cast(static_cast(dst_col_ptr) + - kAvx8bitBlockSize * params.dst_stride); - rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride; - } // End col-block loop. -} // NOLINT(readability/fn_size) - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// When removing this comment, update profiling label below. -void KernelFloatAvxVnni(const KernelParamsFloat<16, 16>& params) { - profiler::ScopeLabel label("Kernel kAvxVnni float (UNFINISHED)"); - - float lhs_data[kAvxFloatBlockSize]; - float rhs_data[kAvxFloatBlockSize]; - float accum_data[kAvxFloatBlockSize][kAvxFloatBlockSize]; - int bias_ptr_block_increment = - params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvxFloatBlockSize : 0; - - const float* rhs_col_ptr = params.rhs_base_ptr; - float* dst_col_ptr = params.dst_base_ptr; - const float* bias_col_ptr = params.bias; - if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { - bias_col_ptr += params.start_row; - } - - for (int col = params.start_col; col <= params.last_col; - col += kAvxFloatBlockSize) { - const float* lhs_col_ptr = params.lhs_base_ptr; - float* dst_ptr = dst_col_ptr; - const float* bias_ptr = bias_col_ptr; - - for (int row = params.start_row; row <= params.last_row; - row += kAvxFloatBlockSize) { - const int residual_rows = - std::min(params.dst_rows - row, kAvxFloatBlockSize); - const int residual_cols = - std::min(params.dst_cols - col, kAvxFloatBlockSize); - - // Initialize with bias. - float initial_accum_data[kAvxFloatBlockSize]; - for (int i = 0; i < kAvxFloatBlockSize; ++i) { - initial_accum_data[i] = 0.0f; - } - for (int i = 0; i < residual_rows; ++i) { - initial_accum_data[i] = bias_ptr[i]; - } - for (int j = 0; j < kAvxFloatBlockSize; ++j) { - for (int i = 0; i < kAvxFloatBlockSize; ++i) { - accum_data[j][i] = initial_accum_data[i]; - } - } - bias_ptr += bias_ptr_block_increment; - - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr; - for (int d = 0; d < params.depth; ++d) { - for (int i = 0; i < kAvxFloatBlockSize; ++i) { - lhs_data[i] = lhs_ptr[i]; - rhs_data[i] = rhs_ptr[i]; - } - - for (int j = 0; j < kAvxFloatBlockSize; ++j) { - for (int i = 0; i < kAvxFloatBlockSize; ++i) { - accum_data[j][i] += lhs_data[i] * rhs_data[j]; - } - } - - lhs_ptr += kAvxFloatBlockSize; - rhs_ptr += kAvxFloatBlockSize; - } - - for (int j = 0; j < kAvxFloatBlockSize; ++j) { - for (int i = 0; i < kAvxFloatBlockSize; ++i) { - accum_data[j][i] = - std::min(accum_data[j][i], params.clamp_max); - accum_data[j][i] = - std::max(accum_data[j][i], params.clamp_min); - } - } - - const bool store_full_block = (residual_rows == kAvxFloatBlockSize) && - (residual_cols == kAvxFloatBlockSize); - - { - float* block_ptr = - store_full_block ? dst_ptr : const_cast(params.dst_tmp_buf); - const int block_col_offset = store_full_block - ? params.dst_stride / sizeof(float) - : kAvxFloatBlockSize; - for (int j = 0; j < kAvxFloatBlockSize; ++j) { - for (int i = 0; i < kAvxFloatBlockSize; ++i) { - block_ptr[i] = accum_data[j][i]; - } - block_ptr += block_col_offset; - } - } - if (!store_full_block) { - const float* block_ptr = params.dst_tmp_buf; - for (int j = 0; j < residual_cols; ++j) { - for (int i = 0; i < residual_rows; ++i) { - dst_ptr[j * params.dst_stride / sizeof(float) + i] = block_ptr[i]; - } - block_ptr += kAvxFloatBlockSize; - } - } - - lhs_col_ptr += kAvxFloatBlockSize * params.lhs_stride / sizeof(float); - dst_ptr += kAvxFloatBlockSize; - } // End row-block loop. - - dst_col_ptr += kAvxFloatBlockSize * params.dst_stride / sizeof(float); - rhs_col_ptr += kAvxFloatBlockSize * params.rhs_stride / sizeof(float); - } // End col-block loop. -} - -#endif // RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/kernel_common.h b/tensorflow/lite/experimental/ruy/ruy/kernel_common.h deleted file mode 100644 index c1721b81869..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/kernel_common.h +++ /dev/null @@ -1,481 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_COMMON_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_COMMON_H_ - -#include -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/common.h" -#include "tensorflow/lite/experimental/ruy/ruy/internal_matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" -#include "tensorflow/lite/experimental/ruy/ruy/side_pair.h" -#include "tensorflow/lite/experimental/ruy/ruy/size_util.h" -#include "tensorflow/lite/experimental/ruy/ruy/spec.h" -#include "tensorflow/lite/experimental/ruy/ruy/tune.h" - -namespace ruy { - -template -struct Kernel {}; - -template -void RunKernelTyped(Tuning tuning, const PackedMatrix& lhs, - const PackedMatrix& rhs, const Spec& spec, - int start_row, int start_col, int end_row, int end_col, - Matrix* dst) { - using Kernel = Kernel; - Kernel kernel(tuning); -#if !defined(NDEBUG) || !RUY_OPT_ENABLED(RUY_OPT_FAT_KERNEL) - using LhsLayout = typename Kernel::LhsLayout; - using RhsLayout = typename Kernel::RhsLayout; -#endif - // end_row and end_col may be larger than dst dimensions. - // that is because kernels write directly to the destination matrix, whose - // dimensions may not be a multiple of the kernel dimensions, and we try to - // keep this annoyance localized as an implementation detail in kernels, - // by allowing to pass rounded-up values down as far as possible. - // These assertions encode the contract. - RUY_DCHECK_LE(0, start_row); - RUY_DCHECK_LE(start_row, end_row); - RUY_DCHECK_LT(end_row, dst->layout.rows + LhsLayout::kCols); - RUY_DCHECK_EQ((end_row - start_row) % LhsLayout::kCols, 0); - RUY_DCHECK_LE(0, start_col); - RUY_DCHECK_LE(start_col, end_col); - RUY_DCHECK_LT(end_col, dst->layout.cols + RhsLayout::kCols); - RUY_DCHECK_EQ((end_col - start_col) % RhsLayout::kCols, 0); -#if RUY_OPT_ENABLED(RUY_OPT_FAT_KERNEL) - kernel.Run(lhs, rhs, spec, start_row, start_col, end_row, end_col, dst); -#else - for (int col = start_col; col < end_col; col += RhsLayout::kCols) { - int block_end_col = std::min(col + RhsLayout::kCols, end_col); - for (int row = start_row; row < end_row; row += LhsLayout::kCols) { - int block_end_row = std::min(row + LhsLayout::kCols, end_row); - kernel.Run(lhs, rhs, spec, row, col, block_end_row, block_end_col, dst); - } - } -#endif -} - -// Main entry point for kernels. -template -void RunKernel(Tuning tuning, const SidePair& src, void* spec, - const SidePair& start, const SidePair& end, - DMatrix* dst) { - Matrix mdst = ToMatrix(*dst); - RunKernelTyped( - tuning, ToPackedMatrix(src[Side::kLhs]), - ToPackedMatrix(src[Side::kRhs]), - *static_cast(spec), start[Side::kLhs], start[Side::kRhs], - end[Side::kLhs], end[Side::kRhs], &mdst); -} - -// Copied from gemmlowp/fixedpoint. -inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a, - std::int32_t b) { - bool overflow = a == b && a == std::numeric_limits::min(); - std::int64_t a_64(a); - std::int64_t b_64(b); - std::int64_t ab_64 = a_64 * b_64; - std::int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30)); - std::int32_t ab_x2_high32 = - static_cast((ab_64 + nudge) / (1ll << 31)); - return overflow ? std::numeric_limits::max() : ab_x2_high32; -} - -inline std::int32_t RoundingDivideByPOT(std::int32_t numerator, int exponent) { - std::int32_t sign = numerator >= 0 ? 1 : -1; - std::int32_t abs_numerator = std::abs(numerator); - std::int32_t mask = (1LL << exponent) - 1; - std::int32_t remainder = abs_numerator & mask; - std::int32_t threshold = mask >> 1; - std::int32_t abs_result = - (abs_numerator >> exponent) + (remainder > threshold ? 1 : 0); - return sign * abs_result; -} - -// Copied from TF Lite code. -inline std::int32_t MultiplyByQuantizedMultiplier( - std::int32_t x, std::int32_t quantized_multiplier, int shift) { - int left_shift = shift > 0 ? shift : 0; - int right_shift = shift > 0 ? 0 : -shift; - return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul( - x * (1 << left_shift), quantized_multiplier), - right_shift); -} - -// Helper to apply a fixed-point multiplier. Only 'applicable' if AccumScalar -// is int32 (i.e. in all cases except floating-point) and if the destination is -// not int32 (i.e. unless the user wants to get raw accumulators). -template ::value && - !std::is_same::value> -struct ApplyMultiplierImpl {}; - -// Specialization in non-applicable case: do nothing, just check that values -// are default. -template -struct ApplyMultiplierImpl { - using AccumScalar = typename Spec::AccumScalar; - using DstScalar = typename Spec::DstScalar; - static void Run(const Spec& spec, int row, AccumScalar* accum) { - RUY_DCHECK_EQ(spec.multiplier_fixedpoint, 0); - RUY_DCHECK_EQ(spec.multiplier_exponent, 0); - } -}; - -template -struct ApplyMultiplierImpl { - using AccumScalar = typename Spec::AccumScalar; - using DstScalar = typename Spec::DstScalar; - static void Run(const Spec& spec, int row, AccumScalar* accum) { - AccumScalar m = spec.multiplier_fixedpoint_perchannel - ? spec.multiplier_fixedpoint_perchannel[row] - : spec.multiplier_fixedpoint; - int e = spec.multiplier_exponent_perchannel - ? spec.multiplier_exponent_perchannel[row] - : spec.multiplier_exponent; - *accum = MultiplyByQuantizedMultiplier(*accum, m, e); - } -}; - -template -void ApplyMultiplier(const Spec& spec, int row, - typename Spec::AccumScalar* accum) { - ApplyMultiplierImpl::Run(spec, row, accum); -} - -template -struct Kernel { - using AccumScalar = typename Spec::AccumScalar; - using LhsLayout = typename Spec::StandardCppKernelLhsLayout; - using RhsLayout = typename Spec::StandardCppKernelRhsLayout; - explicit Kernel(Tuning) {} - void Run(const PackedMatrix& lhs, - const PackedMatrix& rhs, const Spec& spec, int start_row, - int start_col, int end_row, int end_col, - Matrix* dst) const { - // See the comment in RunKernelTyped. end_row may be larger than - // dst->layout.rows. It's the responsibility of the kernel to avoid - // overrunning dst boundaries, which we do here by computing - // clamped_end_row. - int clamped_end_row = std::min(end_row, dst->layout.rows); - int clamped_end_col = std::min(end_col, dst->layout.cols); - RUY_DCHECK_LE(0, start_row); - RUY_DCHECK_LE(start_row, clamped_end_row); - RUY_DCHECK_LE(clamped_end_row, dst->layout.rows); - RUY_DCHECK_LE(clamped_end_row, end_row); - RUY_DCHECK_LE(end_row - clamped_end_row, LhsLayout::kCols); - RUY_DCHECK_LE(0, start_col); - RUY_DCHECK_LE(start_col, clamped_end_col); - RUY_DCHECK_LE(clamped_end_col, dst->layout.cols); - RUY_DCHECK_LE(clamped_end_col, end_col); - RUY_DCHECK_LE(end_col - clamped_end_col, RhsLayout::kCols); - profiler::ScopeLabel label("Kernel (Standard Cpp)"); - const int depth = lhs.layout.rows; - for (int i = start_row; i < clamped_end_row; i++) { - for (int j = start_col; j < clamped_end_col; j++) { - using AccumScalar = typename Spec::AccumScalar; - AccumScalar accum = 0; - for (int k = 0; k < depth; k++) { - AccumScalar lhs_val = Element(lhs, k, i); - AccumScalar rhs_val = Element(rhs, k, j); - accum += lhs_val * rhs_val; - } - if (spec.bias) { - accum += spec.bias[i]; - } - if (lhs.zero_point) { - accum -= lhs.zero_point * rhs.sums[j]; - } - if (rhs.zero_point) { - accum -= rhs.zero_point * lhs.sums[i]; - } - if (lhs.zero_point && rhs.zero_point) { - accum += lhs.zero_point * rhs.zero_point * depth; - } - ApplyMultiplier(spec, i, &accum); - accum += dst->zero_point; - accum = std::min(accum, spec.clamp_max); - accum = std::max(accum, spec.clamp_min); - *ElementPtr(dst, i, j) = static_cast(accum); - } - } - } -}; - -#define RUY_INHERIT_KERNEL(PARENT, CHILD) \ - template \ - struct Kernel \ - : Kernel { \ - explicit Kernel(Tuning tuning) \ - : Kernel(tuning) {} \ - }; - -#if RUY_PLATFORM(NEON) -RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kNeon) -RUY_INHERIT_KERNEL(Path::kNeon, Path::kNeonDotprod) -#elif RUY_PLATFORM(X86) -RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kSse42) -RUY_INHERIT_KERNEL(Path::kSse42, Path::kAvx2) -RUY_INHERIT_KERNEL(Path::kAvx2, Path::kAvx512) -RUY_INHERIT_KERNEL(Path::kAvx512, Path::kAvxVnni) -#endif - -// KernelParams are shared across 32-bit and 64-bit NEON code, and x86 code. -// -// In other cases, we still define (empty) versions, so that dummy kernels -// can use the classes in function signatures. -#if ((RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) && \ - RUY_OPT_ENABLED(RUY_OPT_ASM)) || \ - RUY_PLATFORM(X86) - -#define RUY_ASM_FLAG_HAS_BIAS 0x1 -#define RUY_ASM_FLAG_HAS_LHS_SUMS 0x2 -#define RUY_ASM_FLAG_HAS_RHS_SUMS 0x4 -#define RUY_ASM_FLAG_HAS_PERCHANNEL 0x8 -#define RUY_ASM_FLAG_NEEDS_LEFT_SHIFT 0x10 - -#define RUY_ASM_TYPE_ID_UINT8 1 -#define RUY_ASM_TYPE_ID_INT8 2 -#define RUY_ASM_TYPE_ID_INT16 3 -#define RUY_ASM_TYPE_ID_INT32 4 - -template -struct DstTypeId {}; - -template <> -struct DstTypeId { - static constexpr int kValue = RUY_ASM_TYPE_ID_UINT8; -}; - -template <> -struct DstTypeId { - static constexpr int kValue = RUY_ASM_TYPE_ID_INT8; -}; - -template <> -struct DstTypeId { - static constexpr int kValue = RUY_ASM_TYPE_ID_INT16; -}; - -template <> -struct DstTypeId { - static constexpr int kValue = RUY_ASM_TYPE_ID_INT32; -}; - -template -struct KernelParams8bit { - static constexpr int kMaxDstTypeSize = 4; - - const std::int32_t* bias; - const std::int32_t* lhs_sums; - const std::int32_t* rhs_sums; - const std::int8_t* lhs_base_ptr; - const std::int32_t* multiplier_fixedpoint; - const std::int32_t* multiplier_exponent; - const std::int8_t* rhs_base_ptr; - void* dst_base_ptr; - std::int32_t lhs_zero_point; - std::int32_t rhs_zero_point; - std::int32_t dst_zero_point; - std::int32_t prod_zp_depth; - std::int32_t start_row; - std::int32_t start_col; - std::int32_t last_row; - std::int32_t last_col; - std::int32_t dst_rows; - std::int32_t dst_cols; - std::int32_t lhs_stride; - std::int32_t rhs_stride; - std::int32_t dst_stride; - std::int32_t depth; - std::int32_t clamp_min; - std::int32_t clamp_max; - std::uint8_t flags; - std::uint8_t dst_type_id; - const std::int32_t zero_data[LhsCols] = {0}; - std::uint8_t dst_tmp_buf[LhsCols * RhsCols * kMaxDstTypeSize]; - std::int32_t multiplier_fixedpoint_buf[LhsCols]; - std::int32_t multiplier_exponent_buf[LhsCols]; -}; - -template -void MakeKernelParams8bit(const PackedMatrix& lhs, - const PackedMatrix& rhs, - const BasicSpec& spec, - int start_row, int start_col, int end_row, - int end_col, Matrix* dst, - KernelParams8bit* params) { - using Params = KernelParams8bit; - - static_assert(sizeof(DstScalar) <= Params::kMaxDstTypeSize, ""); - - const int depth = lhs.layout.rows; - RUY_DCHECK_EQ(start_row % LhsCols, 0); - RUY_DCHECK_EQ(start_col % RhsCols, 0); - RUY_DCHECK_EQ(end_row % LhsCols, 0); - RUY_DCHECK_EQ(end_col % RhsCols, 0); - - params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride; - params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride; - params->flags = 0; - params->bias = params->zero_data; - if (spec.bias) { - params->bias = spec.bias; - params->flags |= RUY_ASM_FLAG_HAS_BIAS; - } - if (lhs.sums) { - params->lhs_sums = lhs.sums; - params->flags |= RUY_ASM_FLAG_HAS_LHS_SUMS; - } - if (rhs.sums) { - params->rhs_sums = rhs.sums; - params->flags |= RUY_ASM_FLAG_HAS_RHS_SUMS; - } - params->start_row = start_row; - params->start_col = start_col; - params->last_row = end_row - LhsCols; - params->last_col = end_col - RhsCols; - params->lhs_stride = lhs.layout.stride; - params->rhs_stride = rhs.layout.stride; - params->dst_stride = sizeof(DstScalar) * dst->layout.stride; - params->lhs_zero_point = lhs.zero_point; - params->rhs_zero_point = rhs.zero_point; - params->dst_zero_point = dst->zero_point; - params->depth = depth; - params->prod_zp_depth = lhs.zero_point * rhs.zero_point * depth; - if (spec.multiplier_fixedpoint_perchannel) { - params->flags |= RUY_ASM_FLAG_NEEDS_LEFT_SHIFT; - params->flags |= RUY_ASM_FLAG_HAS_PERCHANNEL; - params->multiplier_fixedpoint = spec.multiplier_fixedpoint_perchannel; - params->multiplier_exponent = spec.multiplier_exponent_perchannel; - } else { - if (spec.multiplier_exponent > 0) { - params->flags |= RUY_ASM_FLAG_NEEDS_LEFT_SHIFT; - } - params->multiplier_fixedpoint = params->multiplier_fixedpoint_buf; - params->multiplier_exponent = params->multiplier_exponent_buf; - for (int i = 0; i < LhsCols; i++) { - params->multiplier_fixedpoint_buf[i] = spec.multiplier_fixedpoint; - params->multiplier_exponent_buf[i] = spec.multiplier_exponent; - } - } - params->clamp_min = spec.clamp_min; - params->clamp_max = spec.clamp_max; - params->dst_rows = dst->layout.rows; - params->dst_cols = dst->layout.cols; - - RUY_DCHECK_LT(params->last_row, params->dst_rows); - RUY_DCHECK_LT(params->last_col, params->dst_cols); - - params->dst_type_id = DstTypeId::kValue; - params->dst_base_ptr = - dst->data.get() + start_col * dst->layout.stride + start_row; -} - -template -struct KernelParamsFloat { - const float* lhs_base_ptr; - const float* rhs_base_ptr; - float* dst_base_ptr; - const float* bias; - std::int32_t start_row; - std::int32_t start_col; - std::int32_t last_row; - std::int32_t last_col; - std::int32_t dst_rows; - std::int32_t dst_cols; - std::int32_t lhs_stride; - std::int32_t rhs_stride; - std::int32_t dst_stride; - std::int32_t depth; - float clamp_min; - float clamp_max; - std::uint8_t flags; - const float zero_data[LhsCols] = {0}; - float dst_tmp_buf[LhsCols * RhsCols]; -}; - -template -inline void MakeKernelParamsFloat(const PackedMatrix& lhs, - const PackedMatrix& rhs, - const BasicSpec& spec, - int start_row, int start_col, int end_row, - int end_col, Matrix* dst, - KernelParamsFloat* params) { - const int depth = lhs.layout.rows; - RUY_DCHECK_EQ(start_row % LhsCols, 0); - RUY_DCHECK_EQ(start_col % RhsCols, 0); - RUY_DCHECK_EQ(end_row % LhsCols, 0); - RUY_DCHECK_EQ(end_col % RhsCols, 0); - - params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride; - params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride; - params->dst_base_ptr = - dst->data.get() + start_col * dst->layout.stride + start_row; - - std::uint8_t flags = 0; - params->bias = params->zero_data; - if (spec.bias) { - params->bias = spec.bias; - flags |= RUY_ASM_FLAG_HAS_BIAS; - } - params->flags = flags; - params->start_row = start_row; - params->start_col = start_col; - params->last_row = end_row - LhsCols; - params->last_col = end_col - RhsCols; - params->lhs_stride = sizeof(float) * lhs.layout.stride; - params->rhs_stride = sizeof(float) * rhs.layout.stride; - params->dst_stride = sizeof(float) * dst->layout.stride; - params->depth = depth; - params->clamp_min = spec.clamp_min; - params->clamp_max = spec.clamp_max; - params->dst_rows = dst->layout.rows; - params->dst_cols = dst->layout.cols; - - RUY_DCHECK_LT(params->last_row, params->dst_rows); - RUY_DCHECK_LT(params->last_col, params->dst_cols); -} - -#else // ((RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) && - // RUY_OPT_ENABLED(RUY_OPT_ASM)) || RUY_PLATFORM(X86) - -template -struct KernelParams8bit {}; - -template -struct KernelParamsFloat {}; - -#endif // ((RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) && - // RUY_OPT_ENABLED(RUY_OPT_ASM)) || RUY_PLATFORM(X86) - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_COMMON_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/kernel_sse42.cc b/tensorflow/lite/experimental/ruy/ruy/kernel_sse42.cc deleted file mode 100644 index 46a6d045e6a..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/kernel_sse42.cc +++ /dev/null @@ -1,428 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/kernel.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" - -#if RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM) -#include // IWYU pragma: keep -#endif - -namespace ruy { - -#if !(RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM)) - -void Kernel8bitSse42(const KernelParams8bit<8, 8>& params) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -void KernelFloatSse42(const KernelParamsFloat<8, 8>& params) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -#else // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -static constexpr int kAvxFloatBlockSize = 8; -static constexpr int kAvx8bitBlockSize = 8; -static constexpr int kAvx8bitInnerSize = 4; - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// When removing this comment, update profiling label below. -void Kernel8bitSse42(const KernelParams8bit<8, 8>& params) { - profiler::ScopeLabel label("Kernel kSse42 8-bit (UNFINISHED)"); - std::int32_t accum_data[kAvx8bitBlockSize][kAvx8bitBlockSize]; - - int bias_ptr_block_increment = - params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0; - - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; - void* dst_col_ptr = params.dst_base_ptr; - const std::int32_t* bias_col_ptr = params.bias; - if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { - bias_col_ptr += params.start_row; - } - - for (int col = params.start_col; col <= params.last_col; - col += kAvx8bitBlockSize) { - const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - void* dst_ptr = dst_col_ptr; - const std::int32_t* bias_ptr = bias_col_ptr; - - for (int row = params.start_row; row <= params.last_row; - row += kAvx8bitBlockSize) { - const int residual_rows = - std::min(params.dst_rows - row, kAvx8bitBlockSize); - const int residual_cols = - std::min(params.dst_cols - col, kAvx8bitBlockSize); - - // Initialize with bias. - std::int32_t initial_accum_data[kAvx8bitBlockSize]; - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - initial_accum_data[i] = 0; - } - for (int i = 0; i < residual_rows; ++i) { - initial_accum_data[i] = bias_ptr[i]; - } - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] = initial_accum_data[i]; - } - } - bias_ptr += bias_ptr_block_increment; - - std::int8_t lhs_data[kAvx8bitBlockSize][kAvx8bitInnerSize]; - std::int8_t rhs_data[kAvx8bitBlockSize][kAvx8bitInnerSize]; - const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; - for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - for (int x = 0; x < kAvx8bitInnerSize; ++x) { - lhs_data[i][x] = lhs_ptr[i * kAvx8bitInnerSize + x]; - rhs_data[i][x] = rhs_ptr[i * kAvx8bitInnerSize + x]; - } - } - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - for (int x = 0; x < kAvx8bitInnerSize; ++x) { - accum_data[j][i] += lhs_data[i][x] * rhs_data[j][x]; - } - } - } - lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; - rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; - } - - if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && params.rhs_zero_point) { - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] -= - params.rhs_zero_point * params.lhs_sums[row + i]; - } - } - } - if ((params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && params.lhs_zero_point) { - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] -= - params.lhs_zero_point * params.rhs_sums[col + j]; - } - } - } - if (params.lhs_zero_point && params.rhs_zero_point) { - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] += params.prod_zp_depth; - } - } - } - - if (params.dst_type_id != DstTypeId::kValue) { - std::int32_t m_vector[kAvx8bitBlockSize]; - std::int32_t e_vector[kAvx8bitBlockSize]; - // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. - if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) { - int i = 0; - for (; i < residual_rows; ++i) { - m_vector[i] = params.multiplier_fixedpoint[row + i]; - e_vector[i] = params.multiplier_exponent[row + i]; - } - for (; i < kAvx8bitBlockSize; ++i) { - m_vector[i] = m_vector[0]; - e_vector[i] = e_vector[0]; - } - } else { - // These arrays have size LhsCols, and are pre-filled. - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - m_vector[i] = params.multiplier_fixedpoint[i]; - e_vector[i] = params.multiplier_exponent[i]; - } - } - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] = MultiplyByQuantizedMultiplier( - accum_data[j][i], m_vector[i], e_vector[i]); - } - } - - if (params.dst_zero_point) { - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] += params.dst_zero_point; - } - } - } - - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] = - std::min(accum_data[j][i], params.clamp_max); - accum_data[j][i] = - std::max(accum_data[j][i], params.clamp_min); - } - } - } - - const bool store_full_block = (residual_rows == kAvx8bitBlockSize) && - (residual_cols == kAvx8bitBlockSize); - - if (params.dst_type_id == DstTypeId::kValue) { - std::int8_t* tmp_ptr = - store_full_block - ? static_cast(dst_ptr) - : const_cast( - reinterpret_cast(params.dst_tmp_buf)); - const int block_col_offset = - store_full_block ? params.dst_stride / sizeof(std::int8_t) - : kAvx8bitBlockSize; - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - tmp_ptr[i] = accum_data[j][i]; - } - tmp_ptr += block_col_offset; - } - - if (!store_full_block) { - const std::int8_t* block_ptr = - reinterpret_cast(params.dst_tmp_buf); - for (int j = 0; j < residual_cols; ++j) { - for (int i = 0; i < residual_rows; ++i) { - static_cast( - dst_ptr)[j * params.dst_stride / sizeof(std::int8_t) + i] = - block_ptr[i]; - } - block_ptr += kAvx8bitBlockSize; - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else if (params.dst_type_id == DstTypeId::kValue) { - std::uint8_t* tmp_ptr = store_full_block - ? static_cast(dst_ptr) - : const_cast( - reinterpret_cast( - params.dst_tmp_buf)); - const int block_col_offset = - store_full_block ? params.dst_stride : kAvx8bitBlockSize; - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - tmp_ptr[i] = accum_data[j][i]; - } - tmp_ptr += block_col_offset; - } - - if (!store_full_block) { - const std::uint8_t* block_ptr = - reinterpret_cast(params.dst_tmp_buf); - for (int j = 0; j < residual_cols; ++j) { - for (int i = 0; i < residual_rows; ++i) { - static_cast( - dst_ptr)[j * params.dst_stride / sizeof(std::uint8_t) + i] = - block_ptr[i]; - } - block_ptr += kAvx8bitBlockSize; - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else if (params.dst_type_id == DstTypeId::kValue) { - if (store_full_block) { - std::int16_t* tmp_ptr = static_cast(dst_ptr); - const int block_col_offset = params.dst_stride / sizeof(std::int16_t); - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - tmp_ptr[i] = accum_data[j][i]; - } - tmp_ptr += block_col_offset; - } - } else { - std::int16_t* tmp_ptr = const_cast( - reinterpret_cast(params.dst_tmp_buf)); - const int block_col_offset = kAvx8bitBlockSize; - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - tmp_ptr[i] = accum_data[j][i]; - } - tmp_ptr += block_col_offset; - } - const std::int16_t* block_ptr = - reinterpret_cast(params.dst_tmp_buf); - std::int16_t* dst_block_ptr = static_cast(dst_ptr); - for (int j = 0; j < residual_cols; ++j) { - for (int i = 0; i < residual_rows; ++i) { - dst_block_ptr[i] = block_ptr[i]; - } - dst_block_ptr += params.dst_stride / sizeof(std::int16_t); - block_ptr += kAvx8bitBlockSize; - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else if (params.dst_type_id == DstTypeId::kValue) { - if (store_full_block) { - std::int32_t* tmp_ptr = static_cast(dst_ptr); - const int block_col_offset = params.dst_stride / sizeof(std::int32_t); - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - tmp_ptr[i] = accum_data[j][i]; - } - tmp_ptr += block_col_offset; - } - } else { - std::int32_t* dst_block_ptr = static_cast(dst_ptr); - for (int j = 0; j < residual_cols; ++j) { - for (int i = 0; i < residual_rows; ++i) { - dst_block_ptr[i] = accum_data[j][i]; - } - dst_block_ptr += params.dst_stride / sizeof(std::int32_t); - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else { - RUY_DCHECK(false); - } - - lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride; - } // End row-block loop. - - dst_col_ptr = static_cast(static_cast(dst_col_ptr) + - kAvx8bitBlockSize * params.dst_stride); - rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride; - } // End col-block loop. -} // NOLINT(readability/fn_size) - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// When removing this comment, update profiling label below. -void KernelFloatSse42(const KernelParamsFloat<8, 8>& params) { - profiler::ScopeLabel label("Kernel kSse42 float (UNFINISHED)"); - - float lhs_data[kAvxFloatBlockSize]; - float rhs_data[kAvxFloatBlockSize]; - float accum_data[kAvxFloatBlockSize][kAvxFloatBlockSize]; - int bias_ptr_block_increment = - params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvxFloatBlockSize : 0; - - const float* rhs_col_ptr = params.rhs_base_ptr; - float* dst_col_ptr = params.dst_base_ptr; - const float* bias_col_ptr = params.bias; - if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { - bias_col_ptr += params.start_row; - } - - for (int col = params.start_col; col <= params.last_col; - col += kAvxFloatBlockSize) { - const float* lhs_col_ptr = params.lhs_base_ptr; - float* dst_ptr = dst_col_ptr; - const float* bias_ptr = bias_col_ptr; - - for (int row = params.start_row; row <= params.last_row; - row += kAvxFloatBlockSize) { - const int residual_rows = - std::min(params.dst_rows - row, kAvxFloatBlockSize); - const int residual_cols = - std::min(params.dst_cols - col, kAvxFloatBlockSize); - - // Initialize with bias. - float initial_accum_data[kAvxFloatBlockSize]; - for (int i = 0; i < kAvxFloatBlockSize; ++i) { - initial_accum_data[i] = 0.0f; - } - for (int i = 0; i < residual_rows; ++i) { - initial_accum_data[i] = bias_ptr[i]; - } - for (int j = 0; j < kAvxFloatBlockSize; ++j) { - for (int i = 0; i < kAvxFloatBlockSize; ++i) { - accum_data[j][i] = initial_accum_data[i]; - } - } - bias_ptr += bias_ptr_block_increment; - - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr; - for (int d = 0; d < params.depth; ++d) { - for (int i = 0; i < kAvxFloatBlockSize; ++i) { - lhs_data[i] = lhs_ptr[i]; - rhs_data[i] = rhs_ptr[i]; - } - for (int j = 0; j < kAvxFloatBlockSize; ++j) { - for (int i = 0; i < kAvxFloatBlockSize; ++i) { - accum_data[j][i] += lhs_data[i] * rhs_data[j]; - } - } - lhs_ptr += kAvxFloatBlockSize; - rhs_ptr += kAvxFloatBlockSize; - } - - for (int j = 0; j < kAvxFloatBlockSize; ++j) { - for (int i = 0; i < kAvxFloatBlockSize; ++i) { - accum_data[j][i] = - std::min(accum_data[j][i], params.clamp_max); - accum_data[j][i] = - std::max(accum_data[j][i], params.clamp_min); - } - } - - const bool store_full_block = (residual_rows == kAvxFloatBlockSize) && - (residual_cols == kAvxFloatBlockSize); - - { - float* block_ptr = - store_full_block ? dst_ptr : const_cast(params.dst_tmp_buf); - const int block_col_offset = store_full_block - ? params.dst_stride / sizeof(float) - : kAvxFloatBlockSize; - for (int j = 0; j < kAvxFloatBlockSize; ++j) { - for (int i = 0; i < kAvxFloatBlockSize; ++i) { - block_ptr[i] = accum_data[j][i]; - } - block_ptr += block_col_offset; - } - } - if (!store_full_block) { - const float* block_ptr = params.dst_tmp_buf; - for (int j = 0; j < residual_cols; ++j) { - for (int i = 0; i < residual_rows; ++i) { - dst_ptr[j * params.dst_stride / sizeof(float) + i] = block_ptr[i]; - } - block_ptr += kAvxFloatBlockSize; - } - } - - lhs_col_ptr += kAvxFloatBlockSize * params.lhs_stride / sizeof(float); - dst_ptr += kAvxFloatBlockSize; - } // End row-block loop. - - dst_col_ptr += kAvxFloatBlockSize * params.dst_stride / sizeof(float); - rhs_col_ptr += kAvxFloatBlockSize * params.rhs_stride / sizeof(float); - } // End col-block loop. -} - -#endif // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/kernel_x86.h b/tensorflow/lite/experimental/ruy/ruy/kernel_x86.h deleted file mode 100644 index f79f70ab88c..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/kernel_x86.h +++ /dev/null @@ -1,222 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_X86_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_X86_H_ - -#include - -#include "tensorflow/lite/experimental/ruy/ruy/common.h" -#include "tensorflow/lite/experimental/ruy/ruy/internal_matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/kernel_common.h" -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/ruy/spec.h" -#include "tensorflow/lite/experimental/ruy/ruy/tune.h" - -namespace ruy { - -#if RUY_PLATFORM(X86) -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -void Kernel8bitSse42(const KernelParams8bit<8, 8>& params); - -template -struct Kernel> { - Tuning tuning = Tuning::kAuto; - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, - const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, - int start_col, int end_row, int end_col, - Matrix* dst) const { - KernelParams8bit params; - MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col, - dst, ¶ms); - Kernel8bitSse42(params); - } -}; - -void KernelFloatSse42(const KernelParamsFloat<8, 8>& params); - -template <> -struct Kernel> { - Tuning tuning = Tuning::kAuto; - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, int start_col, - int end_row, int end_col, Matrix* dst) const { - KernelParamsFloat params; - MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row, - end_col, dst, ¶ms); - KernelFloatSse42(params); - } -}; - -void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params); -void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params); - -template -struct Kernel> { - Tuning tuning = Tuning::kAuto; - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, - const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, - int start_col, int end_row, int end_col, - Matrix* dst) const { - KernelParams8bit params; - MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col, - dst, ¶ms); - if (dst->layout.cols == 1) { - Kernel8bitAvx512SingleCol(params); - } else { - Kernel8bitAvx512(params); - } - } -}; - -void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params); -void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& param); - -template <> -struct Kernel> { - Tuning tuning = Tuning::kAuto; - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, int start_col, - int end_row, int end_col, Matrix* dst) const { - KernelParamsFloat params; - MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row, - end_col, dst, ¶ms); - if (dst->layout.cols == 1) { - KernelFloatAvx512SingleCol(params); - } else { - KernelFloatAvx512(params); - } - } -}; - -void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params); -void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params); - -template -struct Kernel> { - Tuning tuning = Tuning::kAuto; - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, - const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, - int start_col, int end_row, int end_col, - Matrix* dst) const { - KernelParams8bit params; - MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col, - dst, ¶ms); - if (dst->layout.cols == 1) { - Kernel8bitAvx2SingleCol(params); - } else { - Kernel8bitAvx2(params); - } - } -}; - -void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params); -void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params); - -template <> -struct Kernel> { - Tuning tuning = Tuning::kAuto; - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, int start_col, - int end_row, int end_col, Matrix* dst) const { - KernelParamsFloat params; - MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row, - end_col, dst, ¶ms); - if (dst->layout.cols == 1) { - KernelFloatAvx2SingleCol(params); - } else { - KernelFloatAvx2(params); - } - } -}; - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -void Kernel8bitAvxVnni(const KernelParams8bit<16, 16>& params); - -template -struct Kernel> { - Tuning tuning = Tuning::kAuto; - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, - const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, - int start_col, int end_row, int end_col, - Matrix* dst) const { - KernelParams8bit params; - MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col, - dst, ¶ms); - Kernel8bitAvxVnni(params); - } -}; - -void KernelFloatAvxVnni(const KernelParamsFloat<16, 16>& params); - -template <> -struct Kernel> { - Tuning tuning = Tuning::kAuto; - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, int start_col, - int end_row, int end_col, Matrix* dst) const { - KernelParamsFloat params; - MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row, - end_col, dst, ¶ms); - KernelFloatAvxVnni(params); - } -}; - -#endif // RUY_PLATFORM(X86) - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_X86_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/matrix.h b/tensorflow/lite/experimental/ruy/ruy/matrix.h deleted file mode 100644 index a76f32136c6..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/matrix.h +++ /dev/null @@ -1,182 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_MATRIX_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_MATRIX_H_ - -#include -#include // IWYU pragma: keep -#include - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" - -namespace ruy { - -// Layout storage order. Here and elsewhere, 'col' is short for 'column'. -// 'column-major' means that each column is contiguous in memory. -enum class Order : std::uint8_t { kColMajor, kRowMajor }; - -// Describes the shape and storage layout of a matrix. -struct Layout final { - std::int32_t rows = 0; - std::int32_t cols = 0; - // Stride is the offset between two adjacent matrix elements - // in the non-contiguous direction. - std::int32_t stride = 0; - Order order = Order::kColMajor; -}; - -namespace detail { - -// Thin wrapper around a pointer that tracks its constness dynamically. -// -// This is our take on the C++ problem of enforcing constness of data -// wrapped in a containers class: it's not worth the hassle of trying to -// make it fully work at compile-time. -// Instead, we only enforce constness at runtime, and to make it -// zero-overhead, we only enforce it in debug builds. -template -class ConstCheckingPtr final { - public: - using element_type = T; - - // Convenience methods. Most `set` calls go through these. - ConstCheckingPtr& operator=(T* ptr) { - set(ptr); - return *this; - } - ConstCheckingPtr& operator=(const T* ptr) { - set(ptr); - return *this; - } - ConstCheckingPtr& operator=(std::nullptr_t) { - set(static_cast(nullptr)); - return *this; - } - - // Core accessors. These encapsulate the main logic: - // - for `set`, the constness of the argument determines whether internal - // pointer should be tracked as const/mutable. - // - for `get`, the constness of `this` determines whether the call - // counts as a const or mutable use of the internal pointer. - void set(T* ptr) { - ptr_ = ptr; - set_mutable(true); - } - void set(const T* ptr) { - ptr_ = ptr; - set_mutable(false); - } - T* get() /* NOT const */ { - assert_mutable(); - return const_cast(ptr_); - } - const T* get() const { return ptr_; } - - private: - static_assert(!std::is_const::value, ""); - const T* ptr_ = nullptr; -#ifndef NDEBUG - bool is_mutable_ = true; - void set_mutable(bool val) { is_mutable_ = val; } - void assert_mutable() { RUY_DCHECK(is_mutable_); } -#else - void set_mutable(bool) {} - void assert_mutable() {} -#endif -}; - -} // namespace detail - -// A Matrix is really what Eigen and gemmlowp would have called a 'matrix map': -// it merely wraps existing data as a matrix. It doesn't own any buffer. -// Scalar may be any floating-point or integral type. When integral, it may be -// signed or unsigned. -template -struct Matrix final { - Matrix& operator=(const Matrix& other) { - data = other.data; - cacheable = other.cacheable; - layout = other.layout; - zero_point = other.zero_point; - return *this; - } - - // The underlying buffer wrapped by this matrix. - detail::ConstCheckingPtr data; - // The shape and data layout of this matrix. - Layout layout; - // The zero_point, i.e. which Scalar value is to be interpreted as zero. - // When Scalar is floating-point, this must be 0. - Scalar zero_point = 0; - // Clients of Ruy must set this flag to enable any caching behavior. Doesn't - // impact numerical results, but caching can impact observable metrics like - // latency, memory usage, power, etc. - bool cacheable = false; -}; - -inline void MakeSimpleLayout(int rows, int cols, Order order, Layout* layout) { - layout->rows = rows; - layout->cols = cols; - layout->order = order; - layout->stride = order == Order::kColMajor ? rows : cols; -} - -// Opaque data structure representing a pre-packed matrix, as obtained from -// Ruy's advanced API. -struct PrepackedMatrix { - void* data = nullptr; - std::size_t data_size = 0; - void* sums = nullptr; - std::size_t sums_size = 0; -}; - -template -StreamType& operator<<(StreamType& stream, const Matrix& mat) { - for (int row = 0; row < mat.layout.rows; row++) { - for (int col = 0; col < mat.layout.cols; col++) { - stream << static_cast(Element(mat, row, col)) << " "; - } - stream << "\n"; - } - return stream; -} - -// Compile-time version of KernelLayout, used to declare kernel layouts in a -// way that can be consumed by compile-time logic. -// See how partial specializations of Kernel use it to declare their layouts. -// The only reason why this is currently part of the public API is to -// allow testing various layouts for the Path::kStandardCpp kernel, as a -// testing-only feature. See Spec::StandardCppKernelLhsLayout. -template -struct FixedKernelLayout { - static constexpr Order kOrder = tOrder; - static constexpr int kRows = tRows; - static constexpr int kCols = tCols; -}; - -#if (__cplusplus < 201703L) -// A static constexpr data member is automatically inline and should not require -// redeclaration without an initializer. This is actually deprecated from C++17 -// onwards. Clang with -O0 without this can fail to link. -template -constexpr int FixedKernelLayout::kCols; -template -constexpr int FixedKernelLayout::kRows; -#endif - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_MATRIX_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/opt_set.h b/tensorflow/lite/experimental/ruy/ruy/opt_set.h deleted file mode 100644 index fef0107ed01..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/opt_set.h +++ /dev/null @@ -1,51 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_OPT_SET_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_OPT_SET_H_ - -// RUY_OPT_SET is a compile-time API that Ruy provides for enabling/disabling -// certain optimizations. It should be used by defining that macro on the -// compiler command line. -// -// Each bit in RUY_OPT_SET controls a particular optimization done in Ruy. -#define RUY_OPT_INTRINSICS 0x1 -#define RUY_OPT_ASM 0x2 -#define RUY_OPT_TUNING 0x4 -#define RUY_OPT_FAT_KERNEL 0x8 -#define RUY_OPT_NATIVE_ROUNDING 0x10 -#define RUY_OPT_AVOID_ALIASING 0x20 -#define RUY_OPT_MAX_STREAMING 0x40 -#define RUY_OPT_PACK_AHEAD 0x80 -#define RUY_OPT_PREFETCH_LOAD 0x100 -#define RUY_OPT_PREFETCH_STORE 0x200 -#define RUY_OPT_FRACTAL_Z 0x400 -#define RUY_OPT_FRACTAL_U 0x800 -#define RUY_OPT_FRACTAL_HILBERT 0x1000 - -#if !defined(RUY_OPT_SET) -#ifdef RUY_OPTIMIZE_FOR_MATMUL_BENCHMARK -// Load prefetching is detrimental in matrix multiplication benchmarks. -// Store prefetching is not. -#define RUY_OPT_SET (~RUY_OPT_PREFETCH_LOAD) -#else -// Default to all optimizations. -#define RUY_OPT_SET (~0) -#endif -#endif - -#define RUY_OPT_ENABLED(ruy_opt) ((RUY_OPT_SET & ruy_opt) != 0) - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_OPT_SET_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/pack.h b/tensorflow/lite/experimental/ruy/ruy/pack.h deleted file mode 100644 index 96040aa1039..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/pack.h +++ /dev/null @@ -1,98 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// # What is "packing"? -// -// Before feeding data to the gemm kernels (the parts of Ruy that do lots -// of multiply-add operations), Ruy first performs a data transformation (which -// we call "packing") on the input matrices. This transformation has two main -// goals: -// - rearrange data into blocks that are a convenient size/layout for the gemm -// kernels to consume. This helps make the memory access pattern of the gemm -// kernel simpler and more contiguous, and puts the data in a layout most -// convenient for specific arithmetic instructions in the gemm kernel. -// - compute row/column sums needed for handling quantization with non-symmetric -// zero points. -// -// # Simplified algorithmic analysis of packing -// -// Packing is a relatively simple transformation which does a small constant -// amount of work on each element of an input matrix, and hence for an NxM -// matrix performs O(N*M) work. If N and M are of the same order, then this is -// O(N^2) work. -// -// A NxKxM matrix multiplication requires N*K*M multiply-accumulate operations. -// Note that if N, K, and M are all the same order, then the number of -// multiply-accumulate operations is O(N^3). -// -// Thus, the O(N^2) cost of packing is small compared to the O(N^3) work, in the -// case of all dimensions being roughly the same order. -// -// # Packing cost can be significant -// -// When matrix * matrix multiplications begin to look more like matrix * vector -// multiplications, packing cost can become significant. We sometimes call these -// cases "gemv-like". -// -// Continuing the algorithmic analysis above, if we consider a case where an -// NxKxM matrix multiplication has either N = O(1) or M = O(1), then the -// situation is different. In this case, the multiply-accumulate work is only -// quadratic, so the quadratic cost of packing can be come significant. -// -// Another way to say this is that the cost of packing an input matrix (either -// the LHS or RHS) is amortized across the non-depth dimension of the opposite -// input matrix. Thus, when the LHS has very few rows or the RHS has very few -// columns, the cost of packing the opposite input matrix can become -// significant. -// -// As a rough rule of thumb, the cost of packing starts to become significant -// when either N or M is below 32 (and other dimensions are hundreds), with very -// significant packing costs at 8 or below. This varies by data type, Path, and -// tuning, so these numbers are only rough guides. -// -// One practical use case that is affected by this is inference of -// fully connected neural network layers with a low batch size. The weight -// matrix (which is a constant for inference) is the one affected by significant -// packing cost. -// -// Ruy provides an API in ruy_advanced.h for advanced users to pre-pack -// input matrices that are affected by significant packing costs. -// -// # Implementation notes -// -// Ruy's packing routines always operate on a range of columns and can be -// applied to either the LHS or RHS. This is possible because Ruy internally -// implements a TrMul, so the accumulation along depth is done along columns of -// both the LHS and RHS (whereas for a normal Mul the accumulation along depth -// for the LHS is along rows). As another example, we are always computing -// column sums for quantization (and never row sums, since the LHS is -// transposed). - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_H_ - -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" - -// IWYU pragma: begin_exports -#if RUY_PLATFORM(NEON) -#include "tensorflow/lite/experimental/ruy/ruy/pack_arm.h" -#elif RUY_PLATFORM(X86) -#include "tensorflow/lite/experimental/ruy/ruy/pack_x86.h" -#else -#include "tensorflow/lite/experimental/ruy/ruy/pack_common.h" -#endif -// IWYU pragma: end_exports - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/pack_arm.cc b/tensorflow/lite/experimental/ruy/ruy/pack_arm.cc deleted file mode 100644 index 52b55a57cc6..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/pack_arm.cc +++ /dev/null @@ -1,1936 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include - -#include "tensorflow/lite/experimental/ruy/ruy/common.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/pack.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" - -namespace ruy { - -#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -void Pack8bitNeonOutOfOrder(const void* src_ptr0, const void* src_ptr1, - const void* src_ptr2, const void* src_ptr3, - int src_inc0, int src_inc1, int src_inc2, - int src_inc3, int src_rows, int src_zero_point, - std::int8_t* packed_ptr, int start_col, int end_col, - std::int32_t* sums_ptr, int input_xor) { - profiler::ScopeLabel label("Pack (kNeon, optimized for out-of-order cores)"); - asm volatile( - // clang-format off - "dup v26.16b, %w[input_xor]\n" - "mov w1, #0\n" - "dup v28.4s, wzr\n" - "dup v29.4s, wzr\n" - "dup v30.4s, wzr\n" - "dup v31.4s, wzr\n" - - "and w2, %w[rows], #-16\n" - "cmp w1, w2\n" - "beq 3f\n" - - "add w1, w1, #16\n" - "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" - "cmp w1, w2\n" - "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" - "beq 2f\n" - - "1:\n" - - "add w1, w1, #16\n" - "eor v4.16b, v0.16b, v26.16b\n" - "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" - "eor v5.16b, v1.16b, v26.16b\n" - "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" - "eor v6.16b, v2.16b, v26.16b\n" - "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" - "eor v7.16b, v3.16b, v26.16b\n" - "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" - - "saddlp v16.8h, v4.16b\n" - "str q4, [%[packed_ptr], #0]\n" - "saddlp v17.8h, v5.16b\n" - "str q5, [%[packed_ptr], #16]\n" - "saddlp v18.8h, v6.16b\n" - "str q6, [%[packed_ptr], #32]\n" - "saddlp v19.8h, v7.16b\n" - "str q7, [%[packed_ptr], #48]\n" - "sadalp v28.4s, v16.8h\n" - "cmp w1, w2\n" - "sadalp v29.4s, v17.8h\n" - "add %[packed_ptr], %[packed_ptr], #64\n" - "sadalp v30.4s, v18.8h\n" - "sadalp v31.4s, v19.8h\n" - - "bne 1b\n" - - "2:\n" - - "eor v4.16b, v0.16b, v26.16b\n" - "eor v5.16b, v1.16b, v26.16b\n" - "eor v6.16b, v2.16b, v26.16b\n" - "eor v7.16b, v3.16b, v26.16b\n" - - "saddlp v16.8h, v4.16b\n" - "str q4, [%[packed_ptr], #0]\n" - "saddlp v17.8h, v5.16b\n" - "str q5, [%[packed_ptr], #16]\n" - "saddlp v18.8h, v6.16b\n" - "str q6, [%[packed_ptr], #32]\n" - "saddlp v19.8h, v7.16b\n" - "str q7, [%[packed_ptr], #48]\n" - "sadalp v28.4s, v16.8h\n" - "sadalp v29.4s, v17.8h\n" - "sadalp v30.4s, v18.8h\n" - "sadalp v31.4s, v19.8h\n" - - "add %[packed_ptr], %[packed_ptr], #64\n" - - "3:\n" - - "ands w2, %w[rows], #15\n" - "beq 4f\n" - "dup v0.16b, %w[src_zero_point]\n" - "dup v1.16b, %w[src_zero_point]\n" - "dup v2.16b, %w[src_zero_point]\n" - "dup v3.16b, %w[src_zero_point]\n" -#define RUY_LOAD_ONE_ROW(R) \ - "cmp w2, #" #R "\n" \ - "beq 5f\n" \ - "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \ - "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \ - "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \ - "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n" - - RUY_LOAD_ONE_ROW(0) - RUY_LOAD_ONE_ROW(1) - RUY_LOAD_ONE_ROW(2) - RUY_LOAD_ONE_ROW(3) - RUY_LOAD_ONE_ROW(4) - RUY_LOAD_ONE_ROW(5) - RUY_LOAD_ONE_ROW(6) - RUY_LOAD_ONE_ROW(7) - RUY_LOAD_ONE_ROW(8) - RUY_LOAD_ONE_ROW(9) - RUY_LOAD_ONE_ROW(10) - RUY_LOAD_ONE_ROW(11) - RUY_LOAD_ONE_ROW(12) - RUY_LOAD_ONE_ROW(13) - RUY_LOAD_ONE_ROW(14) - RUY_LOAD_ONE_ROW(15) -#undef RUY_LOAD_ONE_ROW - "5:\n" - - "eor v4.16b, v0.16b, v26.16b\n" - "eor v5.16b, v1.16b, v26.16b\n" - "eor v6.16b, v2.16b, v26.16b\n" - "eor v7.16b, v3.16b, v26.16b\n" - - "saddlp v16.8h, v4.16b\n" - "saddlp v17.8h, v5.16b\n" - "saddlp v18.8h, v6.16b\n" - "saddlp v19.8h, v7.16b\n" - "sadalp v28.4s, v16.8h\n" - "sadalp v29.4s, v17.8h\n" - "sadalp v30.4s, v18.8h\n" - "sadalp v31.4s, v19.8h\n" - - "str q4, [%[packed_ptr], #0]\n" - "str q5, [%[packed_ptr], #16]\n" - "str q6, [%[packed_ptr], #32]\n" - "str q7, [%[packed_ptr], #48]\n" - "add %[packed_ptr], %[packed_ptr], #64\n" - - "4:\n" - - "addp v28.4s, v28.4s, v29.4s\n" - "addp v30.4s, v30.4s, v31.4s\n" - "addp v28.4s, v28.4s, v30.4s\n" - - "cmp %[sums_ptr], #0\n" - "beq 6f\n" - "st1 {v28.4s}, [%[sums_ptr]], #16\n" - "6:\n" - // clang-format on - - : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1), - [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3), - [ packed_ptr ] "+r"(packed_ptr), [ sums_ptr ] "+r"(sums_ptr) - : [ src_inc0 ] "r"(static_cast(src_inc0)), - [ src_inc1 ] "r"(static_cast(src_inc1)), - [ src_inc2 ] "r"(static_cast(src_inc2)), - [ src_inc3 ] "r"(static_cast(src_inc3)), - [ rows ] "r"(src_rows), [ src_zero_point ] "r"(src_zero_point), - [ input_xor ] "r"(input_xor) - : "cc", "memory", "x1", "x2", "v0", "v1", "v2", "v3", "v4", "v5", "v6", - "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", - "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", - "v27", "v28", "v29", "v30", "v31"); -} -#endif - -#if RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -#define RUY_OFFSET_SRC_PTR0 0 -#define RUY_OFFSET_SRC_PTR1 4 -#define RUY_OFFSET_SRC_PTR2 8 -#define RUY_OFFSET_SRC_PTR3 12 -#define RUY_OFFSET_SUMS_PTR 16 -#define RUY_OFFSET_PACKED_PTR 20 -#define RUY_OFFSET_SRC_INC0 24 -#define RUY_OFFSET_SRC_INC1 28 -#define RUY_OFFSET_SRC_INC2 32 -#define RUY_OFFSET_SRC_INC3 36 -#define RUY_OFFSET_SRC_ROWS 40 -#define RUY_OFFSET_SRC_ZERO_POINT 44 -#define RUY_OFFSET_INPUT_XOR 48 - -template -void CheckOffsetsInPackParams8bit(const Params&) { - static_assert(offsetof(Params, src_ptr0) == RUY_OFFSET_SRC_PTR0, ""); - static_assert(offsetof(Params, src_ptr1) == RUY_OFFSET_SRC_PTR1, ""); - static_assert(offsetof(Params, src_ptr2) == RUY_OFFSET_SRC_PTR2, ""); - static_assert(offsetof(Params, src_ptr3) == RUY_OFFSET_SRC_PTR3, ""); - static_assert(offsetof(Params, sums_ptr) == RUY_OFFSET_SUMS_PTR, ""); - static_assert(offsetof(Params, packed_ptr) == RUY_OFFSET_PACKED_PTR, ""); - static_assert(offsetof(Params, src_inc0) == RUY_OFFSET_SRC_INC0, ""); - static_assert(offsetof(Params, src_inc1) == RUY_OFFSET_SRC_INC1, ""); - static_assert(offsetof(Params, src_inc2) == RUY_OFFSET_SRC_INC2, ""); - static_assert(offsetof(Params, src_inc3) == RUY_OFFSET_SRC_INC3, ""); - static_assert(offsetof(Params, src_rows) == RUY_OFFSET_SRC_ROWS, ""); - static_assert(offsetof(Params, src_zero_point) == RUY_OFFSET_SRC_ZERO_POINT, - ""); - static_assert(offsetof(Params, input_xor) == RUY_OFFSET_INPUT_XOR, ""); -} - -// Packing code for out-of-order ARMv7 CPUs like the Krait 400 or A9. -// No attempt made at making this code efficient on in-order cores yet. -void Pack8bitNeonOutOfOrder4Cols(const PackParams8bit& params) { - CheckOffsetsInPackParams8bit(params); - profiler::ScopeLabel label("Pack (kNeon, optimized for out-of-order cores)"); - const void* src_ptr0 = params.src_ptr0; - const void* src_ptr1 = params.src_ptr1; - const void* src_ptr2 = params.src_ptr2; - const void* src_ptr3 = params.src_ptr3; - const int src_inc0 = params.src_inc0; - const int src_inc1 = params.src_inc1; - const int src_inc2 = params.src_inc2; - const int src_inc3 = params.src_inc3; - const std::int8_t* packed_ptr = params.packed_ptr; - - asm volatile( - // clang-format off - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_INPUT_XOR) "]\n" - "vdup.8 q11, r2\n" - "mov r1, #0\n" - // Zero-out the accumulators - "vmov.i32 q12, #0\n" - "vmov.i32 q13, #0\n" - "vmov.i32 q14, #0\n" - "vmov.i32 q15, #0\n" - - // Round down src_rows to nearest multiple of 16. - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n" - "and r2, r3, #-16\n" - "cmp r1, r2\n" - "beq 3f\n" - - "1:\n" - "add r1, r1, #16\n" - /* Load q0 */ - "vld1.8 {d0, d1}, [%[src_ptr0]]\n" - "add %[src_ptr0], %[src_ptr0], %[src_inc0]\n" - RUY_PREFETCH_LOAD("pld [%[src_ptr0]]\n") - - /* Load q1 */ - "vld1.8 {d2, d3}, [%[src_ptr1]]\n" - "add %[src_ptr1], %[src_ptr1], %[src_inc1]\n" - RUY_PREFETCH_LOAD("pld [%[src_ptr1]]\n") - - "veor.8 q4, q0, q11\n" - "veor.8 q5, q1, q11\n" - - // Pairwise add in to 16b accumulators. - "vpaddl.s8 q8, q4\n" - "vpaddl.s8 q9, q5\n" - - "vst1.32 {q4}, [%[packed_ptr]]!\n" - "vst1.32 {q5}, [%[packed_ptr]]!\n" - - // Pairwise add accumulate into 32b accumulators. - // q12 and q13 contain 4x32b accumulators - "vpadal.s16 q12, q8\n" - "vpadal.s16 q13, q9\n" - - // Now do the same for src_ptr2 and src_ptr3. - "vld1.8 {d0, d1}, [%[src_ptr2]]\n" - "add %[src_ptr2], %[src_ptr2], %[src_inc2]\n" - RUY_PREFETCH_LOAD("pld [%[src_ptr2]]\n") - - "vld1.8 {d2, d3}, [%[src_ptr3]]\n" - "add %[src_ptr3], %[src_ptr3], %[src_inc3]\n" - RUY_PREFETCH_LOAD("pld [%[src_ptr3]]\n") - - "veor.8 q4, q0, q11\n" - "veor.8 q5, q1, q11\n" - - "vpaddl.s8 q8, q4\n" - "vpaddl.s8 q9, q5\n" - - "vst1.32 {q4}, [%[packed_ptr]]!\n" - "vst1.32 {q5}, [%[packed_ptr]]!\n" - - // Pairwise add accumulate into 32b accumulators. - // q14 and q15 contain 4x32b accumulators - "vpadal.s16 q14, q8\n" - "vpadal.s16 q15, q9\n" - - "cmp r1, r2\n" - "bne 1b\n" - - "3:\n" - - // Now pack the last (num_rows % 16) rows. - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n" - "ands r2, r3, #15\n" - "beq 4f\n" - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ZERO_POINT) "]\n" - "vdup.8 q0, r3\n" - "vdup.8 q1, r3\n" - -// First, read/accumulate/write for src_ptr0 and src_ptr1. -#define RUY_LOAD_ONE_ROW1(I, R) \ - "cmp r2, #" #I "\n" \ - "beq 5f\n" \ - "vld1.8 { d0[" #R "]}, [%[src_ptr0]]!\n" \ - "vld1.8 { d2[" #R "]}, [%[src_ptr1]]!\n" \ - - RUY_LOAD_ONE_ROW1(0, 0) - RUY_LOAD_ONE_ROW1(1, 1) - RUY_LOAD_ONE_ROW1(2, 2) - RUY_LOAD_ONE_ROW1(3, 3) - RUY_LOAD_ONE_ROW1(4, 4) - RUY_LOAD_ONE_ROW1(5, 5) - RUY_LOAD_ONE_ROW1(6, 6) - RUY_LOAD_ONE_ROW1(7, 7) -#undef RUY_LOAD_ONE_ROW1 - -#define RUY_LOAD_ONE_ROW2(I, R) \ - "cmp r2, #" #I "\n" \ - "beq 5f\n" \ - "vld1.8 { d1[" #R "]}, [%[src_ptr0]]!\n" \ - "vld1.8 { d3[" #R "]}, [%[src_ptr1]]!\n" \ - - RUY_LOAD_ONE_ROW2(8, 0) - RUY_LOAD_ONE_ROW2(9, 1) - RUY_LOAD_ONE_ROW2(10, 2) - RUY_LOAD_ONE_ROW2(11, 3) - RUY_LOAD_ONE_ROW2(12, 4) - RUY_LOAD_ONE_ROW2(13, 5) - RUY_LOAD_ONE_ROW2(14, 6) - RUY_LOAD_ONE_ROW2(15, 7) -#undef RUY_LOAD_ONE_ROW2 - - "5:\n" - - "veor.16 q4, q0, q11\n" - "veor.16 q5, q1, q11\n" - - "vpaddl.s8 q8, q4\n" - "vpaddl.s8 q9, q5\n" - - // Pairwise add accumulate to 4x32b accumulators. - "vpadal.s16 q12, q8\n" - "vpadal.s16 q13, q9\n" - - "vst1.32 {q4}, [%[packed_ptr]]!\n" - "vst1.32 {q5}, [%[packed_ptr]]!\n" - - // Reset to src_zero for src_ptr2 and src_ptr3. - "vdup.8 q0, r3\n" - "vdup.8 q1, r3\n" - -// Next, read/accumulate/write for src_ptr2 and src_ptr3. -#define RUY_LOAD_ONE_ROW1(I, R) \ - "cmp r2, #" #I "\n" \ - "beq 5f\n" \ - "vld1.8 { d0[" #R "]}, [%[src_ptr2]]!\n" \ - "vld1.8 { d2[" #R "]}, [%[src_ptr3]]!\n" \ - - RUY_LOAD_ONE_ROW1(0, 0) - RUY_LOAD_ONE_ROW1(1, 1) - RUY_LOAD_ONE_ROW1(2, 2) - RUY_LOAD_ONE_ROW1(3, 3) - RUY_LOAD_ONE_ROW1(4, 4) - RUY_LOAD_ONE_ROW1(5, 5) - RUY_LOAD_ONE_ROW1(6, 6) - RUY_LOAD_ONE_ROW1(7, 7) -#undef RUY_LOAD_ONE_ROW1 - -#define RUY_LOAD_ONE_ROW2(I, R) \ - "cmp r2, #" #I "\n" \ - "beq 5f\n" \ - "vld1.8 { d1[" #R "]}, [%[src_ptr2]]!\n" \ - "vld1.8 { d3[" #R "]}, [%[src_ptr3]]!\n" \ - - RUY_LOAD_ONE_ROW2(8, 0) - RUY_LOAD_ONE_ROW2(9, 1) - RUY_LOAD_ONE_ROW2(10, 2) - RUY_LOAD_ONE_ROW2(11, 3) - RUY_LOAD_ONE_ROW2(12, 4) - RUY_LOAD_ONE_ROW2(13, 5) - RUY_LOAD_ONE_ROW2(14, 6) - RUY_LOAD_ONE_ROW2(15, 7) -#undef RUY_LOAD_ONE_ROW2 - - "5:\n" - - "veor.16 q4, q0, q11\n" - "veor.16 q5, q1, q11\n" - - "vpaddl.s8 q8, q4\n" - "vpaddl.s8 q9, q5\n" - - // Pairwise add accumulate to 4x32b accumulators. - "vpadal.s16 q14, q8\n" - "vpadal.s16 q15, q9\n" - - "vst1.32 {q4}, [%[packed_ptr]]!\n" - "vst1.32 {q5}, [%[packed_ptr]]!\n" - - "4:\n" - // Pairwise add 32-bit accumulators - "vpadd.i32 d24, d24, d25\n" - "vpadd.i32 d26, d26, d27\n" - "vpadd.i32 d28, d28, d29\n" - "vpadd.i32 d30, d30, d31\n" - // Final 32-bit values per row - "vpadd.i32 d25, d24, d26\n" - "vpadd.i32 d27, d28, d30\n" - - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SUMS_PTR) "]\n" - "cmp r3, #0\n" - "beq 6f\n" - "vst1.32 {d25}, [r3]!\n" - "vst1.32 {d27}, [r3]!\n" - "6:\n" - // clang-format on - - : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1), - [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3) - : [ src_inc0 ] "r"(src_inc0), [ src_inc1 ] "r"(src_inc1), - [ src_inc2 ] "r"(src_inc2), [ src_inc3 ] "r"(src_inc3), - [ packed_ptr ] "r"(packed_ptr), [ params ] "r"(¶ms) - : "cc", "memory", "r1", "r2", "r3", "q0", "q1", "q2", "q3", - "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13"); -} - -// Packing code for out-of-order ARMv7 CPUs like the Krait 400 or A9. -// No attempt made at making this code efficient on in-order cores yet. -// This version differs from the above in that we only handle two columns -// at a time. -void Pack8bitNeonOutOfOrder2Cols(const PackParams8bit& params) { - CheckOffsetsInPackParams8bit(params); - profiler::ScopeLabel label("Pack (kNeon, optimized for out-of-order cores)"); - const void* src_ptr0 = params.src_ptr0; - const void* src_ptr1 = params.src_ptr1; - const int src_inc0 = params.src_inc0; - const int src_inc1 = params.src_inc1; - const std::int8_t* packed_ptr = params.packed_ptr; - - asm volatile( - // clang-format off - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_INPUT_XOR) "]\n" - "vdup.8 q11, r2\n" - "mov r1, #0\n" - // Zero-out the accumulators - "vmov.i32 q12, #0\n" - "vmov.i32 q13, #0\n" - - // Round down src_rows to nearest multiple of 16. - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n" - "and r2, r3, #-16\n" - "cmp r1, r2\n" - "beq 3f\n" - - "1:\n" - "add r1, r1, #16\n" - /* Load q0 */ - "vld1.8 {d0, d1}, [%[src_ptr0]]\n" - "add %[src_ptr0], %[src_ptr0], %[src_inc0]\n" - - /* Load q1 */ - "vld1.8 {d2, d3}, [%[src_ptr1]]\n" - "add %[src_ptr1], %[src_ptr1], %[src_inc1]\n" - - "veor.8 q4, q0, q11\n" - "veor.8 q5, q1, q11\n" - - // Pairwise add in to 16b accumulators. - "vpaddl.s8 q8, q4\n" - "vpaddl.s8 q9, q5\n" - - "vst1.32 {q4}, [%[packed_ptr]]!\n" - "vst1.32 {q5}, [%[packed_ptr]]!\n" - - // Pairwise add accumulate into 32b accumulators. - // q12 and q13 contain 4x32b accumulators - "vpadal.s16 q12, q8\n" - "vpadal.s16 q13, q9\n" - - "cmp r1, r2\n" - - "bne 1b\n" - - "3:\n" - - // Now pack the last (num_rows % 16) rows. - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n" - "ands r2, r3, #15\n" - "beq 4f\n" - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ZERO_POINT) "]\n" - "vdup.8 q0, r3\n" - "vdup.8 q1, r3\n" - -// Read/accumulate/write for src_ptr0 and src_ptr1. -#define RUY_LOAD_ONE_ROW1(I, R) \ - "cmp r2, #" #I "\n" \ - "beq 5f\n" \ - "vld1.8 { d0[" #R "]}, [%[src_ptr0]]!\n" \ - "vld1.8 { d2[" #R "]}, [%[src_ptr1]]!\n" \ - - RUY_LOAD_ONE_ROW1(0, 0) - RUY_LOAD_ONE_ROW1(1, 1) - RUY_LOAD_ONE_ROW1(2, 2) - RUY_LOAD_ONE_ROW1(3, 3) - RUY_LOAD_ONE_ROW1(4, 4) - RUY_LOAD_ONE_ROW1(5, 5) - RUY_LOAD_ONE_ROW1(6, 6) - RUY_LOAD_ONE_ROW1(7, 7) -#undef RUY_LOAD_ONE_ROW1 - -#define RUY_LOAD_ONE_ROW2(I, R) \ - "cmp r2, #" #I "\n" \ - "beq 5f\n" \ - "vld1.8 { d1[" #R "]}, [%[src_ptr0]]!\n" \ - "vld1.8 { d3[" #R "]}, [%[src_ptr1]]!\n" \ - - RUY_LOAD_ONE_ROW2(8, 0) - RUY_LOAD_ONE_ROW2(9, 1) - RUY_LOAD_ONE_ROW2(10, 2) - RUY_LOAD_ONE_ROW2(11, 3) - RUY_LOAD_ONE_ROW2(12, 4) - RUY_LOAD_ONE_ROW2(13, 5) - RUY_LOAD_ONE_ROW2(14, 6) - RUY_LOAD_ONE_ROW2(15, 7) -#undef RUY_LOAD_ONE_ROW2 - - "5:\n" - - "veor.16 q4, q0, q11\n" - "veor.16 q5, q1, q11\n" - - "vpaddl.s8 q8, q4\n" - "vpaddl.s8 q9, q5\n" - - - // Pairwise add accumulate to 4x32b accumulators. - "vpadal.s16 q12, q8\n" - "vpadal.s16 q13, q9\n" - - "vst1.32 {q4}, [%[packed_ptr]]!\n" - "vst1.32 {q5}, [%[packed_ptr]]!\n" - - "4:\n" - - // Pairwise add 32-bit accumulators - "vpadd.i32 d24, d24, d25\n" - "vpadd.i32 d26, d26, d27\n" - // Final 32-bit values per row - "vpadd.i32 d25, d24, d26\n" - - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SUMS_PTR) "]\n" - "cmp r3, #0\n" - "beq 6f\n" - "vst1.32 {d25}, [r3]!\n" - "6:\n" - // clang-format on - - : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1) - : [ src_inc0 ] "r"(src_inc0), [ src_inc1 ] "r"(src_inc1), - [ packed_ptr ] "r"(packed_ptr), [ params ] "r"(¶ms) - : "cc", "memory", "r1", "r2", "r3", "q0", "q1", "q2", "q3", - "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13"); -} - -#undef RUY_OFFSET_SRC_PTR0 -#undef RUY_OFFSET_SRC_PTR1 -#undef RUY_OFFSET_SRC_PTR2 -#undef RUY_OFFSET_SRC_PTR32 -#undef RUY_OFFSET_SUMS_PTR -#undef RUY_OFFSET_PACKED_PTR0 -#undef RUY_OFFSET_SRC_INC0 -#undef RUY_OFFSET_SRC_INC1 -#undef RUY_OFFSET_SRC_INC2 -#undef RUY_OFFSET_SRC_INC3 -#undef RUY_OFFSET_SRC_ROWS -#undef RUY_OFFSET_SRC_ZERO_POINT -#undef RUY_OFFSET_INPUT_XOR - -#endif // RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -void Pack8bitNeonInOrder(const void* src_ptr0, const void* src_ptr1, - const void* src_ptr2, const void* src_ptr3, - int src_inc0, int src_inc1, int src_inc2, int src_inc3, - int src_rows, int src_zero_point, - std::int8_t* packed_ptr, int start_col, int end_col, - std::int32_t* sums_ptr, int input_xor) { - profiler::ScopeLabel label("Pack (kNeon, optimized for in-order cores)"); - asm volatile( - // clang-format off - "dup v26.16b, %w[input_xor]\n" - "mov w1, #0\n" - "dup v28.4s, wzr\n" - "dup v29.4s, wzr\n" - "dup v30.4s, wzr\n" - "dup v31.4s, wzr\n" - - "and w2, %w[rows], #-16\n" - "cmp w1, w2\n" - "beq 3f\n" - "ldr x10, [%[src_ptr0], #8]\n" - "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n" - "ldr x11, [%[src_ptr1], #8]\n" - "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n" - "ldr x12, [%[src_ptr2], #8]\n" - "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n" - "ldr x13, [%[src_ptr3], #8]\n" - "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #64]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #64]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #64]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #64]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #128]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #128]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #128]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #128]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #192]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #192]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #192]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #192]\n") - "add w1, w1, #16\n" - "cmp w1, w2\n" - - "beq 2f\n" - - "1:\n" - "add w1, w1, #16\n" - "ins v0.d[1], x10\n" - "ldr x10, [%[src_ptr0], #8]\n" - "ins v1.d[1], x11\n" - "ldr x11, [%[src_ptr1], #8]\n" - "ins v2.d[1], x12\n" - "ldr x12, [%[src_ptr2], #8]\n" - "ins v3.d[1], x13\n" - "ldr x13, [%[src_ptr3], #8]\n" - "eor v4.16b, v0.16b, v26.16b\n" - "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n" - "eor v5.16b, v1.16b, v26.16b\n" - "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n" - "eor v6.16b, v2.16b, v26.16b\n" - "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n" - "eor v7.16b, v3.16b, v26.16b\n" - "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n" - "saddlp v16.8h, v4.16b\n" - "str q4, [%[packed_ptr], #0]\n" - "saddlp v17.8h, v5.16b\n" - "str q5, [%[packed_ptr], #16]\n" - "saddlp v18.8h, v6.16b\n" - "str q6, [%[packed_ptr], #32]\n" - "saddlp v19.8h, v7.16b\n" - "str q7, [%[packed_ptr], #48]\n" - "sadalp v28.4s, v16.8h\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #240]\n") - "cmp w1, w2\n" - "sadalp v29.4s, v17.8h\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #240]\n") - "add %[packed_ptr], %[packed_ptr], #64\n" - "sadalp v30.4s, v18.8h\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #240]\n") - "sadalp v31.4s, v19.8h\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #240]\n") - - "bne 1b\n" - - "2:\n" - "ins v0.d[1], x10\n" - "ins v1.d[1], x11\n" - "ins v2.d[1], x12\n" - "ins v3.d[1], x13\n" - "eor v4.16b, v0.16b, v26.16b\n" - "eor v5.16b, v1.16b, v26.16b\n" - "eor v6.16b, v2.16b, v26.16b\n" - "eor v7.16b, v3.16b, v26.16b\n" - - "saddlp v16.8h, v4.16b\n" - "str q4, [%[packed_ptr], #0]\n" - "saddlp v17.8h, v5.16b\n" - "str q5, [%[packed_ptr], #16]\n" - "saddlp v18.8h, v6.16b\n" - "str q6, [%[packed_ptr], #32]\n" - "saddlp v19.8h, v7.16b\n" - "str q7, [%[packed_ptr], #48]\n" - "sadalp v28.4s, v16.8h\n" - "sadalp v29.4s, v17.8h\n" - "sadalp v30.4s, v18.8h\n" - "sadalp v31.4s, v19.8h\n" - - "add %[packed_ptr], %[packed_ptr], #64\n" - - "3:\n" - - "ands w2, %w[rows], #15\n" - "beq 4f\n" - "dup v0.16b, %w[src_zero_point]\n" - "dup v1.16b, %w[src_zero_point]\n" - "dup v2.16b, %w[src_zero_point]\n" - "dup v3.16b, %w[src_zero_point]\n" -#define RUY_LOAD_ONE_ROW(R) \ - "cmp w2, #" #R "\n" \ - "beq 5f\n" \ - "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \ - "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \ - "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \ - "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n" - - RUY_LOAD_ONE_ROW(0) - RUY_LOAD_ONE_ROW(1) - RUY_LOAD_ONE_ROW(2) - RUY_LOAD_ONE_ROW(3) - RUY_LOAD_ONE_ROW(4) - RUY_LOAD_ONE_ROW(5) - RUY_LOAD_ONE_ROW(6) - RUY_LOAD_ONE_ROW(7) - RUY_LOAD_ONE_ROW(8) - RUY_LOAD_ONE_ROW(9) - RUY_LOAD_ONE_ROW(10) - RUY_LOAD_ONE_ROW(11) - RUY_LOAD_ONE_ROW(12) - RUY_LOAD_ONE_ROW(13) - RUY_LOAD_ONE_ROW(14) - RUY_LOAD_ONE_ROW(15) -#undef RUY_LOAD_ONE_ROW - "5:\n" - - "eor v4.16b, v0.16b, v26.16b\n" - "eor v5.16b, v1.16b, v26.16b\n" - "eor v6.16b, v2.16b, v26.16b\n" - "eor v7.16b, v3.16b, v26.16b\n" - - "saddlp v16.8h, v4.16b\n" - "saddlp v17.8h, v5.16b\n" - "saddlp v18.8h, v6.16b\n" - "saddlp v19.8h, v7.16b\n" - "sadalp v28.4s, v16.8h\n" - "sadalp v29.4s, v17.8h\n" - "sadalp v30.4s, v18.8h\n" - "sadalp v31.4s, v19.8h\n" - - "str q4, [%[packed_ptr], #0]\n" - "str q5, [%[packed_ptr], #16]\n" - "str q6, [%[packed_ptr], #32]\n" - "str q7, [%[packed_ptr], #48]\n" - "add %[packed_ptr], %[packed_ptr], #64\n" - - "4:\n" - - "addp v28.4s, v28.4s, v29.4s\n" - "addp v30.4s, v30.4s, v31.4s\n" - "addp v28.4s, v28.4s, v30.4s\n" - - "cmp %[sums_ptr], #0\n" - "beq 6f\n" - "st1 {v28.4s}, [%[sums_ptr]], #16\n" - "6:\n" - // clang-format on - - : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1), - [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3), - [ packed_ptr ] "+r"(packed_ptr), [ sums_ptr ] "+r"(sums_ptr) - : [ src_inc0 ] "r"(static_cast(src_inc0)), [ src_inc1 ] "r"(static_cast(src_inc1)), - [ src_inc2 ] "r"(static_cast(src_inc2)), [ src_inc3 ] "r"(static_cast(src_inc3)), - [ rows ] "r"(src_rows), - [ src_zero_point ] "r"(src_zero_point), - [input_xor] "r"(input_xor) - : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", - "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", - "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", - "v25", "v26", "v27", "v28", "v29", "v30", "v31"); -} - -void Pack8bitNeonDotprodInOrder(const void* src_ptr0, const void* src_ptr1, - const void* src_ptr2, const void* src_ptr3, - int src_inc0, int src_inc1, int src_inc2, - int src_inc3, int src_rows, int src_zero_point, - std::int8_t* packed_ptr, int start_col, - int end_col, std::int32_t* sums_ptr, - int input_xor) { - profiler::ScopeLabel label( - "Pack (kNeonDotprod, optimized for in-order cores)"); - asm volatile( - // clang-format off - "dup v26.16b, %w[input_xor]\n" - "mov w1, #1\n" - "dup v27.16b, w1\n" - "mov w1, #0\n" - "dup v28.4s, wzr\n" - "dup v29.4s, wzr\n" - "dup v30.4s, wzr\n" - "dup v31.4s, wzr\n" - - "and w2, %w[rows], #-16\n" - "cmp w1, w2\n" - "beq 3f\n" - "ldr x10, [%[src_ptr0], #8]\n" - "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n" - "ldr x11, [%[src_ptr1], #8]\n" - "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n" - "ldr x12, [%[src_ptr2], #8]\n" - "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n" - "ldr x13, [%[src_ptr3], #8]\n" - "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #64]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #64]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #64]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #64]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #128]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #128]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #128]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #128]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #192]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #192]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #192]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #192]\n") - "add w1, w1, #16\n" - "cmp w1, w2\n" - - "beq 2f\n" - - "1:\n" - "add w1, w1, #16\n" - "ins v0.d[1], x10\n" - "ldr x10, [%[src_ptr0], #8]\n" - "ins v1.d[1], x11\n" - "ldr x11, [%[src_ptr1], #8]\n" - "ins v2.d[1], x12\n" - "ldr x12, [%[src_ptr2], #8]\n" - "ins v3.d[1], x13\n" - "ldr x13, [%[src_ptr3], #8]\n" - - "eor v4.16b, v0.16b, v26.16b\n" - "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n" - "eor v5.16b, v1.16b, v26.16b\n" - "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n" - "eor v6.16b, v2.16b, v26.16b\n" - "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n" - "eor v7.16b, v3.16b, v26.16b\n" - "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n" - - "trn1 v16.4s, v4.4s, v5.4s\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #240]\n") - "trn2 v17.4s, v4.4s, v5.4s\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #240]\n") - "trn1 v18.4s, v6.4s, v7.4s\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #240]\n") - "trn2 v19.4s, v6.4s, v7.4s\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #240]\n") - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - "cmp w1, w2\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - "str q20, [%[packed_ptr], #0]\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - "str q21, [%[packed_ptr], #32]\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - "str q22, [%[packed_ptr], #64]\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - "str q23, [%[packed_ptr], #96]\n" - - "add %[packed_ptr], %[packed_ptr], #128\n" - - "bne 1b\n" - - "2:\n" - "ins v0.d[1], x10\n" - "ins v1.d[1], x11\n" - "ins v2.d[1], x12\n" - "ins v3.d[1], x13\n" - "eor v0.16b, v0.16b, v26.16b\n" - "eor v1.16b, v1.16b, v26.16b\n" - "eor v2.16b, v2.16b, v26.16b\n" - "eor v3.16b, v3.16b, v26.16b\n" - - "trn1 v16.4s, v0.4s, v1.4s\n" - "trn2 v17.4s, v0.4s, v1.4s\n" - "trn1 v18.4s, v2.4s, v3.4s\n" - "trn2 v19.4s, v2.4s, v3.4s\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - "str q20, [%[packed_ptr], #0]\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - "str q21, [%[packed_ptr], #32]\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - "str q22, [%[packed_ptr], #64]\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "3:\n" - - "ands w2, %w[rows], #15\n" - "beq 4f\n" - "dup v0.16b, %w[src_zero_point]\n" - "dup v1.16b, %w[src_zero_point]\n" - "dup v2.16b, %w[src_zero_point]\n" - "dup v3.16b, %w[src_zero_point]\n" -#define RUY_LOAD_ONE_ROW(R) \ - "cmp w2, #" #R "\n" \ - "beq 5f\n" \ - "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \ - "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \ - "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \ - "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n" - - RUY_LOAD_ONE_ROW(0) - RUY_LOAD_ONE_ROW(1) - RUY_LOAD_ONE_ROW(2) - RUY_LOAD_ONE_ROW(3) - RUY_LOAD_ONE_ROW(4) - RUY_LOAD_ONE_ROW(5) - RUY_LOAD_ONE_ROW(6) - RUY_LOAD_ONE_ROW(7) - RUY_LOAD_ONE_ROW(8) - RUY_LOAD_ONE_ROW(9) - RUY_LOAD_ONE_ROW(10) - RUY_LOAD_ONE_ROW(11) - RUY_LOAD_ONE_ROW(12) - RUY_LOAD_ONE_ROW(13) - RUY_LOAD_ONE_ROW(14) - RUY_LOAD_ONE_ROW(15) -#undef RUY_LOAD_ONE_ROW - "5:\n" - - "eor v0.16b, v0.16b, v26.16b\n" - "eor v1.16b, v1.16b, v26.16b\n" - "eor v2.16b, v2.16b, v26.16b\n" - "eor v3.16b, v3.16b, v26.16b\n" - - "trn1 v16.4s, v0.4s, v1.4s\n" - "trn2 v17.4s, v0.4s, v1.4s\n" - "trn1 v18.4s, v2.4s, v3.4s\n" - "trn2 v19.4s, v2.4s, v3.4s\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - "str q20, [%[packed_ptr], #0]\n" - "cmp w2, #4\n" - "ble 4f\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - "str q21, [%[packed_ptr], #32]\n" - "cmp w2, #8\n" - "ble 4f\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - "str q22, [%[packed_ptr], #64]\n" - "cmp w2, #12\n" - "ble 4f\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "4:\n" - - "add v28.4s, v28.4s, v29.4s\n" - "add v30.4s, v30.4s, v31.4s\n" - "add v28.4s, v28.4s, v30.4s\n" - - "cmp %[sums_ptr], #0\n" - "beq 6f\n" - "st1 {v28.4s}, [%[sums_ptr]], #16\n" - "6:\n" - // clang-format on - - : [ src_ptr0 ] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1), [src_ptr2] "+r"(src_ptr2), - [src_ptr3] "+r"(src_ptr3), [packed_ptr] "+r"(packed_ptr), [sums_ptr] "+r"(sums_ptr) - : [ src_inc0 ] "r"(static_cast(src_inc0)), [ src_inc1 ] "r"(static_cast(src_inc1)), - [ src_inc2 ] "r"(static_cast(src_inc2)), [ src_inc3 ] "r"(static_cast(src_inc3)), - [rows] "r"(src_rows), - [src_zero_point] "r"(static_cast(src_zero_point)), - [input_xor] "r"(input_xor) - : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", - "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", - "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); -} - -void Pack8bitNeonDotprodOutOfOrder(const void* src_ptr0, const void* src_ptr1, - const void* src_ptr2, const void* src_ptr3, - int src_inc0, int src_inc1, int src_inc2, - int src_inc3, int src_rows, - int src_zero_point, std::int8_t* packed_ptr, - int start_col, int end_col, - std::int32_t* sums_ptr, int input_xor) { - profiler::ScopeLabel label( - "Pack (kNeonDotprod, optimized for out-of-order cores)"); - asm volatile( - // clang-format off - "dup v26.16b, %w[input_xor]\n" - "mov w1, #1\n" - "dup v27.16b, w1\n" - "mov w1, #0\n" - "dup v28.4s, wzr\n" - "dup v29.4s, wzr\n" - "dup v30.4s, wzr\n" - "dup v31.4s, wzr\n" - -#if RUY_OPT_ENABLED(RUY_OPT_MAX_STREAMING) - "and w2, %w[rows], #-64\n" - "cmp w1, w2\n" - "beq 9f\n" - - "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" - "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" - "ld1 {v4.16b}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v5.16b}, [%[src_ptr1]], %[src_inc1]\n" - "ld1 {v6.16b}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v7.16b}, [%[src_ptr3]], %[src_inc3]\n" - "ld1 {v8.16b}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v9.16b}, [%[src_ptr1]], %[src_inc1]\n" - "ld1 {v10.16b}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v11.16b}, [%[src_ptr3]], %[src_inc3]\n" - "ld1 {v12.16b}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v13.16b}, [%[src_ptr1]], %[src_inc1]\n" - "ld1 {v14.16b}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v15.16b}, [%[src_ptr3]], %[src_inc3]\n" - "add w1, w1, #64\n" - "cmp w1, w2\n" - "beq 8f\n" - - "7:\n" - "eor v0.16b, v0.16b, v26.16b\n" - "eor v1.16b, v1.16b, v26.16b\n" - "eor v2.16b, v2.16b, v26.16b\n" - "eor v3.16b, v3.16b, v26.16b\n" - - "trn1 v16.4s, v0.4s, v1.4s\n" - "trn2 v17.4s, v0.4s, v1.4s\n" - "trn1 v18.4s, v2.4s, v3.4s\n" - "trn2 v19.4s, v2.4s, v3.4s\n" - - "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" - "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" - "add w1, w1, #16\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - - "str q20, [%[packed_ptr], #0]\n" - "str q21, [%[packed_ptr], #32]\n" - "str q22, [%[packed_ptr], #64]\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "eor v4.16b, v4.16b, v26.16b\n" - "eor v5.16b, v5.16b, v26.16b\n" - "eor v6.16b, v6.16b, v26.16b\n" - "eor v7.16b, v7.16b, v26.16b\n" - - "trn1 v16.4s, v4.4s, v5.4s\n" - "trn2 v17.4s, v4.4s, v5.4s\n" - "trn1 v18.4s, v6.4s, v7.4s\n" - "trn2 v19.4s, v6.4s, v7.4s\n" - - "ld1 {v4.16b}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v5.16b}, [%[src_ptr1]], %[src_inc1]\n" - "ld1 {v6.16b}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v7.16b}, [%[src_ptr3]], %[src_inc3]\n" - "add w1, w1, #16\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - - "str q20, [%[packed_ptr], #0]\n" - "str q21, [%[packed_ptr], #32]\n" - "str q22, [%[packed_ptr], #64]\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "eor v8.16b, v8.16b, v26.16b\n" - "eor v9.16b, v9.16b, v26.16b\n" - "eor v10.16b, v10.16b, v26.16b\n" - "eor v11.16b, v11.16b, v26.16b\n" - - "trn1 v16.4s, v8.4s, v9.4s\n" - "trn2 v17.4s, v8.4s, v9.4s\n" - "trn1 v18.4s, v10.4s, v11.4s\n" - "trn2 v19.4s, v10.4s, v11.4s\n" - - "ld1 {v8.16b}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v9.16b}, [%[src_ptr1]], %[src_inc1]\n" - "ld1 {v10.16b}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v11.16b}, [%[src_ptr3]], %[src_inc3]\n" - "add w1, w1, #16\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - - "str q20, [%[packed_ptr], #0]\n" - "str q21, [%[packed_ptr], #32]\n" - "str q22, [%[packed_ptr], #64]\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "eor v12.16b, v12.16b, v26.16b\n" - "eor v13.16b, v13.16b, v26.16b\n" - "eor v14.16b, v14.16b, v26.16b\n" - "eor v15.16b, v15.16b, v26.16b\n" - - "trn1 v16.4s, v12.4s, v13.4s\n" - "trn2 v17.4s, v12.4s, v13.4s\n" - "trn1 v18.4s, v14.4s, v15.4s\n" - "trn2 v19.4s, v14.4s, v15.4s\n" - - "ld1 {v12.16b}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v13.16b}, [%[src_ptr1]], %[src_inc1]\n" - "ld1 {v14.16b}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v15.16b}, [%[src_ptr3]], %[src_inc3]\n" - "add w1, w1, #16\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - - "str q20, [%[packed_ptr], #0]\n" - "str q21, [%[packed_ptr], #32]\n" - "str q22, [%[packed_ptr], #64]\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "cmp w1, w2\n" - "bne 7b\n" - - "8:\n" - - "eor v0.16b, v0.16b, v26.16b\n" - "eor v1.16b, v1.16b, v26.16b\n" - "eor v2.16b, v2.16b, v26.16b\n" - "eor v3.16b, v3.16b, v26.16b\n" - - "trn1 v16.4s, v0.4s, v1.4s\n" - "trn2 v17.4s, v0.4s, v1.4s\n" - "trn1 v18.4s, v2.4s, v3.4s\n" - "trn2 v19.4s, v2.4s, v3.4s\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - - "str q20, [%[packed_ptr], #0]\n" - "str q21, [%[packed_ptr], #32]\n" - "str q22, [%[packed_ptr], #64]\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "eor v4.16b, v4.16b, v26.16b\n" - "eor v5.16b, v5.16b, v26.16b\n" - "eor v6.16b, v6.16b, v26.16b\n" - "eor v7.16b, v7.16b, v26.16b\n" - - "trn1 v16.4s, v4.4s, v5.4s\n" - "trn2 v17.4s, v4.4s, v5.4s\n" - "trn1 v18.4s, v6.4s, v7.4s\n" - "trn2 v19.4s, v6.4s, v7.4s\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - - "str q20, [%[packed_ptr], #0]\n" - "str q21, [%[packed_ptr], #32]\n" - "str q22, [%[packed_ptr], #64]\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "eor v8.16b, v8.16b, v26.16b\n" - "eor v9.16b, v9.16b, v26.16b\n" - "eor v10.16b, v10.16b, v26.16b\n" - "eor v11.16b, v11.16b, v26.16b\n" - - "trn1 v16.4s, v8.4s, v9.4s\n" - "trn2 v17.4s, v8.4s, v9.4s\n" - "trn1 v18.4s, v10.4s, v11.4s\n" - "trn2 v19.4s, v10.4s, v11.4s\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - - "str q20, [%[packed_ptr], #0]\n" - "str q21, [%[packed_ptr], #32]\n" - "str q22, [%[packed_ptr], #64]\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "eor v12.16b, v12.16b, v26.16b\n" - "eor v13.16b, v13.16b, v26.16b\n" - "eor v14.16b, v14.16b, v26.16b\n" - "eor v15.16b, v15.16b, v26.16b\n" - - "trn1 v16.4s, v12.4s, v13.4s\n" - "trn2 v17.4s, v12.4s, v13.4s\n" - "trn1 v18.4s, v14.4s, v15.4s\n" - "trn2 v19.4s, v14.4s, v15.4s\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - - "str q20, [%[packed_ptr], #0]\n" - "str q21, [%[packed_ptr], #32]\n" - "str q22, [%[packed_ptr], #64]\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "9:\n" -#endif // #if RUY_OPT_ENABLED(RUY_OPT_MAX_STREAMING) - "and w2, %w[rows], #-16\n" - "cmp w1, w2\n" - "beq 3f\n" - - "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" - "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" - "add w1, w1, #16\n" - "cmp w1, w2\n" - "beq 2f\n" - - "1:\n" - - "eor v0.16b, v0.16b, v26.16b\n" - "eor v1.16b, v1.16b, v26.16b\n" - "eor v2.16b, v2.16b, v26.16b\n" - "eor v3.16b, v3.16b, v26.16b\n" - - "trn1 v16.4s, v0.4s, v1.4s\n" - "trn2 v17.4s, v0.4s, v1.4s\n" - "trn1 v18.4s, v2.4s, v3.4s\n" - "trn2 v19.4s, v2.4s, v3.4s\n" - - "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" - "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" - "add w1, w1, #16\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - - "str q20, [%[packed_ptr], #0]\n" - "str q21, [%[packed_ptr], #32]\n" - "str q22, [%[packed_ptr], #64]\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "cmp w1, w2\n" - "bne 1b\n" - - "2:\n" - - "eor v0.16b, v0.16b, v26.16b\n" - "eor v1.16b, v1.16b, v26.16b\n" - "eor v2.16b, v2.16b, v26.16b\n" - "eor v3.16b, v3.16b, v26.16b\n" - - "trn1 v16.4s, v0.4s, v1.4s\n" - "trn2 v17.4s, v0.4s, v1.4s\n" - "trn1 v18.4s, v2.4s, v3.4s\n" - "trn2 v19.4s, v2.4s, v3.4s\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - - "str q20, [%[packed_ptr], #0]\n" - "str q21, [%[packed_ptr], #32]\n" - "str q22, [%[packed_ptr], #64]\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "3:\n" - - "ands w2, %w[rows], #15\n" - "beq 4f\n" - "dup v0.16b, %w[src_zero_point]\n" - "dup v1.16b, %w[src_zero_point]\n" - "dup v2.16b, %w[src_zero_point]\n" - "dup v3.16b, %w[src_zero_point]\n" -#define RUY_LOAD_ONE_ROW(R) \ - "cmp w2, #" #R "\n" \ - "beq 5f\n" \ - "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \ - "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \ - "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \ - "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n" - - RUY_LOAD_ONE_ROW(0) - RUY_LOAD_ONE_ROW(1) - RUY_LOAD_ONE_ROW(2) - RUY_LOAD_ONE_ROW(3) - RUY_LOAD_ONE_ROW(4) - RUY_LOAD_ONE_ROW(5) - RUY_LOAD_ONE_ROW(6) - RUY_LOAD_ONE_ROW(7) - RUY_LOAD_ONE_ROW(8) - RUY_LOAD_ONE_ROW(9) - RUY_LOAD_ONE_ROW(10) - RUY_LOAD_ONE_ROW(11) - RUY_LOAD_ONE_ROW(12) - RUY_LOAD_ONE_ROW(13) - RUY_LOAD_ONE_ROW(14) - RUY_LOAD_ONE_ROW(15) -#undef RUY_LOAD_ONE_ROW - "5:\n" - - "eor v0.16b, v0.16b, v26.16b\n" - "eor v1.16b, v1.16b, v26.16b\n" - "eor v2.16b, v2.16b, v26.16b\n" - "eor v3.16b, v3.16b, v26.16b\n" - - "trn1 v16.4s, v0.4s, v1.4s\n" - "trn2 v17.4s, v0.4s, v1.4s\n" - "trn1 v18.4s, v2.4s, v3.4s\n" - "trn2 v19.4s, v2.4s, v3.4s\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - "str q20, [%[packed_ptr], #0]\n" - "cmp w2, #4\n" - "ble 4f\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - "str q21, [%[packed_ptr], #32]\n" - "cmp w2, #8\n" - "ble 4f\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - "str q22, [%[packed_ptr], #64]\n" - "cmp w2, #12\n" - "ble 4f\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "4:\n" - - "add v28.4s, v28.4s, v29.4s\n" - "add v30.4s, v30.4s, v31.4s\n" - "add v28.4s, v28.4s, v30.4s\n" - - "cmp %[sums_ptr], #0\n" - "beq 6f\n" - "st1 {v28.4s}, [%[sums_ptr]], #16\n" - "6:\n" - // clang-format on - - : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1), - [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3), - [ packed_ptr ] "+r"(packed_ptr), [ sums_ptr ] "+r"(sums_ptr) - : [ src_inc0 ] "r"(static_cast(src_inc0)), - [ src_inc1 ] "r"(static_cast(src_inc1)), - [ src_inc2 ] "r"(static_cast(src_inc2)), - [ src_inc3 ] "r"(static_cast(src_inc3)), - [ rows ] "r"(src_rows), - [ src_zero_point ] "r"(static_cast(src_zero_point)), - [ input_xor ] "r"(input_xor) - : "cc", "memory", "x1", "x2", "v0", "v1", "v2", "v3", "v4", "v5", "v6", - "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", - "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", - "v27", "v28", "v29", "v30", "v31"); -} - -#endif // RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) -void PackFloatNeonOutOfOrder(const float* src_ptr0, const float* src_ptr1, - const float* src_ptr2, const float* src_ptr3, - int src_inc0, int src_inc1, int src_inc2, - int src_inc3, int src_rows, int src_zero_point, - float* packed_ptr, int start_col, int end_col) { - profiler::ScopeLabel label("Pack (kNeon, optimized for out-of-order cores)"); - asm volatile( - // clang-format off - "mov w1, #0\n" - - "and w2, %w[rows], #-4\n" - "cmp w1, w2\n" - "beq 3f\n" - "ld1 {v0.4s}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v1.4s}, [%[src_ptr1]], %[src_inc1]\n" - "ld1 {v2.4s}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v3.4s}, [%[src_ptr3]], %[src_inc3]\n" - "add w1, w1, #4\n" - "cmp w1, w2\n" - - "beq 2f\n" - - "1:\n" - "add w1, w1, #4\n" - - "trn1 v16.4s, v0.4s, v1.4s\n" - "trn2 v17.4s, v0.4s, v1.4s\n" - "trn1 v18.4s, v2.4s, v3.4s\n" - "trn2 v19.4s, v2.4s, v3.4s\n" - - "ld1 {v0.4s}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v1.4s}, [%[src_ptr1]], %[src_inc1]\n" - "ld1 {v2.4s}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v3.4s}, [%[src_ptr3]], %[src_inc3]\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - "cmp w1, w2\n" - - "str q20, [%[packed_ptr], #0]\n" - "str q21, [%[packed_ptr], #32]\n" - "str q22, [%[packed_ptr], #64]\n" - "str q23, [%[packed_ptr], #96]\n" - - "add %[packed_ptr], %[packed_ptr], #128\n" - - "bne 1b\n" - - "2:\n" - - "trn1 v16.4s, v0.4s, v1.4s\n" - "trn2 v17.4s, v0.4s, v1.4s\n" - "trn1 v18.4s, v2.4s, v3.4s\n" - "trn2 v19.4s, v2.4s, v3.4s\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - "str q20, [%[packed_ptr], #0]\n" - "str q21, [%[packed_ptr], #32]\n" - "str q22, [%[packed_ptr], #64]\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "3:\n" - - "ands w2, %w[rows], #3\n" - "beq 4f\n" - "dup v0.16b, wzr\n" - "dup v1.16b, wzr\n" - "dup v2.16b, wzr\n" - "dup v3.16b, wzr\n" -#define RUY_LOAD_ONE_ROW(R) \ - "cmp w2, #" #R "\n" \ - "beq 5f\n" \ - "ld1 { v0.s }[" #R "], [%[src_ptr0]], #4\n" \ - "ld1 { v1.s }[" #R "], [%[src_ptr1]], #4\n" \ - "ld1 { v2.s }[" #R "], [%[src_ptr2]], #4\n" \ - "ld1 { v3.s }[" #R "], [%[src_ptr3]], #4\n" - - RUY_LOAD_ONE_ROW(0) - RUY_LOAD_ONE_ROW(1) - RUY_LOAD_ONE_ROW(2) - RUY_LOAD_ONE_ROW(3) -#undef RUY_LOAD_ONE_ROW - "5:\n" - - "trn1 v16.4s, v0.4s, v1.4s\n" - "trn2 v17.4s, v0.4s, v1.4s\n" - "trn1 v18.4s, v2.4s, v3.4s\n" - "trn2 v19.4s, v2.4s, v3.4s\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - "mov x1, #32\n" - -#define RUY_STORE_ONE_ROW(ROW, REGISTER) \ - "cmp w2, #" #ROW "\n" \ - "beq 4f\n" \ - "st1 {" #REGISTER ".4s}, [%[packed_ptr]], x1\n" - - RUY_STORE_ONE_ROW(0, v20) - RUY_STORE_ONE_ROW(1, v21) - RUY_STORE_ONE_ROW(2, v22) - RUY_STORE_ONE_ROW(3, v23) - -#undef RUY_STORE_ONE_ROW - - "4:\n" - - // clang-format on - - : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1), - [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3), - [ packed_ptr ] "+r"(packed_ptr) - : [ src_inc0 ] "r"(static_cast(src_inc0)), - [ src_inc1 ] "r"(static_cast(src_inc1)), - [ src_inc2 ] "r"(static_cast(src_inc2)), - [ src_inc3 ] "r"(static_cast(src_inc3)), - [ rows ] "r"(src_rows) - : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", - "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", - "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", - "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); -} -#endif - -#if RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM) -void PackFloatNeonOutOfOrder(const float* src_ptr0, const float* src_ptr1, - const float* src_ptr2, const float* src_ptr3, - int src_inc, int src_rows, int src_zero_point, - float* packed_ptr, int start_col, int end_col, - int output_stride) { - profiler::ScopeLabel label("Pack (kNeon, optimized for out-of-order cores)"); - asm volatile( - // clang-format off - "mov r1, #0\n" - "and r2, %[rows], #-4\n" - "cmp r1, r2\n" - "beq 3f\n" -#define RUY_LOAD_FOUR_BY_FOUR() \ - /* Load q0 */ \ - "vld1.32 {d0, d1}, [%[src_ptr0]]\n" \ - /* if src_inc0 != 0, add 16 to src_ptr0 */ \ - "and r3, %[src_inc], #1\n" \ - "add %[src_ptr0], %[src_ptr0], r3, lsl #4\n"\ - /* Load q1 */ \ - "vld1.32 {d2, d3}, [%[src_ptr1]]\n" \ - /* if src_inc1 != 0, add 16 to src_ptr0 */ \ - "and r3, %[src_inc], #2\n" \ - "add %[src_ptr1], %[src_ptr1], r3, lsl #3\n"\ - /* Load q2 */ \ - "vld1.32 {d4, d5}, [%[src_ptr2]]\n" \ - /* if src_inc2 != 0, add 16 to src_ptr0 */ \ - "and r3, %[src_inc], #4\n" \ - "add %[src_ptr2], %[src_ptr2], r3, lsl #2\n"\ - /* Load q3 */ \ - "vld1.32 {d6, d7}, [%[src_ptr3]]\n" \ - /* if src_inc3 != 0, add 16 to src_ptr0 */ \ - "and r3, %[src_inc], #8\n" \ - "add %[src_ptr3], %[src_ptr3], r3, lsl #1\n"\ - - RUY_LOAD_FOUR_BY_FOUR() - "add r1, r1, #4\n" - "cmp r1, r2\n" - - "beq 2f\n" - - "1:\n" - "add r1, r1, #4\n" - - // Transpose 4x4 matrix. - "vzip.32 q0, q1\n" - "vzip.32 q2, q3\n" - - "vtrn.32 q0, q2\n" - "vtrn.32 q1, q3\n" - - "vzip.32 q0, q2\n" - "vzip.32 q1, q3\n" - - "vmov q8, q0\n" - "vmov q9, q1\n" - "vmov q10, q2\n" - "vmov q11, q3\n" - - RUY_LOAD_FOUR_BY_FOUR() -#undef RUY_LOAD_FOUR_BY_FOUR - -#define RUY_STORE_FOUR_BY_FOUR() \ - /* Store q8, q10, q9, q11 */ \ - /* q8 = d16, d17 */ \ - "vst1.32 {d16, d17}, [%[packed_ptr]]\n" \ - /* q10 = d20, d21 */ \ - "add %[packed_ptr], %[packed_ptr], %[stride]\n" \ - "vst1.32 {d20, d21}, [%[packed_ptr]]\n" \ - /* q9 = d18, d19 */ \ - "add %[packed_ptr], %[packed_ptr], %[stride]\n" \ - "vst1.32 {d18, d19}, [%[packed_ptr]]\n" \ - /* q11 = d22, d23 */ \ - "add %[packed_ptr], %[packed_ptr], %[stride]\n" \ - "vst1.32 {d22, d23}, [%[packed_ptr]]\n" \ - "add %[packed_ptr], %[packed_ptr], %[stride]\n" \ - - RUY_STORE_FOUR_BY_FOUR() - "cmp r1, r2\n" - - "bne 1b\n" - - "2:\n" - - // Transpose 4x4 matrix. - "vzip.32 q0, q1\n" - "vzip.32 q2, q3\n" - - "vtrn.32 q0, q2\n" - "vtrn.32 q1, q3\n" - - "vzip.32 q0, q2\n" - "vzip.32 q1, q3\n" - - "vmov q8, q0\n" - "vmov q9, q1\n" - "vmov q10, q2\n" - "vmov q11, q3\n" - - RUY_STORE_FOUR_BY_FOUR() -#undef RUY_STORE_FOUR_BY_FOUR - "3:\n" - - "ands r2, %[rows], #3\n" - "beq 4f\n" - "mov r0, #0\n" - // Zero out q0 - q3 - "vdup.32 q0, r0\n" - "vdup.32 q1, r0\n" - "vdup.32 q2, r0\n" - "vdup.32 q3, r0\n" -#define RUY_LOAD_ONE_ROW_FIRST_HALF(R, I) \ - "cmp r2, #" #R "\n" \ - "beq 5f\n" \ - "vld1.32 { d0[" #I "] }, [%[src_ptr0]]!\n" \ - "vld1.32 { d2[" #I "] }, [%[src_ptr1]]!\n" \ - "vld1.32 { d4[" #I "] }, [%[src_ptr2]]!\n" \ - "vld1.32 { d6[" #I "] }, [%[src_ptr3]]!\n" - -#define RUY_LOAD_ONE_ROW_SECOND_HALF(R, I) \ - "vld1.32 { d1[" #I "] }, [%[src_ptr0]]!\n" \ - "vld1.32 { d3[" #I "] }, [%[src_ptr1]]!\n" \ - "vld1.32 { d5[" #I "] }, [%[src_ptr2]]!\n" \ - "vld1.32 { d7[" #I "] }, [%[src_ptr3]]!\n" - - RUY_LOAD_ONE_ROW_FIRST_HALF(0, 0) - RUY_LOAD_ONE_ROW_FIRST_HALF(1, 1) - RUY_LOAD_ONE_ROW_SECOND_HALF(2, 0) - RUY_LOAD_ONE_ROW_SECOND_HALF(3, 1) -#undef RUY_LOAD_ONE_ROW_SECOND_HALF -#undef RUY_LOAD_ONE_ROW_FIRST_HALF - "5:\n" - - // Transpose 4x4 matrix. - "vzip.32 q0, q1\n" - "vzip.32 q2, q3\n" - - "vtrn.32 q0, q2\n" - "vtrn.32 q1, q3\n" - - "vzip.32 q0, q2\n" - "vzip.32 q1, q3\n" - - "vmov q8, q0\n" - "vmov q9, q1\n" - "vmov q10, q2\n" - "vmov q11, q3\n" - - "mov r1, #32\n" - -#define RUY_STORE_ONE_ROW(ROW, REGISTER) \ - "cmp r2, #" #ROW "\n" \ - "beq 4f\n" \ - "vst1.32 {" #REGISTER "}, [%[packed_ptr]]\n" \ - "add %[packed_ptr], %[packed_ptr], %[stride]\n" - - // Store q8 - RUY_STORE_ONE_ROW(0, q8) - // Store q10 - RUY_STORE_ONE_ROW(1, q10) - // Store q9 - RUY_STORE_ONE_ROW(2, q9) - // Store q11 - RUY_STORE_ONE_ROW(3, q11) - -#undef RUY_STORE_ONE_ROW - - "4:\n" - - // clang-format on - : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1), - [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3), - [ packed_ptr ] "+r"(packed_ptr) - : [ src_inc ] "r"(static_cast(src_inc)), - [ rows ] "r"(src_rows), [ stride ] "r"(output_stride) - : "cc", "memory", "r0", "r1", "r2", "r3", "q0", "q1", "q2", "q3", - "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11"); -} - -#endif // (RUY_PLATFORM(NEON_32) - -#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) -void PackFloatNeonInOrder(const float* src_ptr0, const float* src_ptr1, - const float* src_ptr2, const float* src_ptr3, - int src_inc0, int src_inc1, int src_inc2, - int src_inc3, int src_rows, int src_zero_point, - float* packed_ptr, int start_col, int end_col) { - profiler::ScopeLabel label("Pack (kNeon, optimized for in-order cores)"); - - asm volatile( - // clang-format off - "mov w1, #0\n" - - "and w2, %w[rows], #-4\n" - "cmp w1, w2\n" - "beq 3f\n" - "ld1 {v0.4s}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v1.4s}, [%[src_ptr1]], %[src_inc1]\n" - "ld1 {v2.4s}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v3.4s}, [%[src_ptr3]], %[src_inc3]\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #64]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #64]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #64]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #64]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #128]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #128]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #128]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #128]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #192]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #192]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #192]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #192]\n") - "add w1, w1, #4\n" - "cmp w1, w2\n" - - "beq 2f\n" - - "1:\n" - "add w1, w1, #4\n" - - "ldr x10, [%[src_ptr0], #8]\n" - "trn1 v16.4s, v0.4s, v1.4s\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #240]\n") - "ldr x11, [%[src_ptr1], #8]\n" - "trn2 v17.4s, v0.4s, v1.4s\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #240]\n") - "ldr x12, [%[src_ptr2], #8]\n" - "trn1 v18.4s, v2.4s, v3.4s\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #240]\n") - "ldr x13, [%[src_ptr3], #8]\n" - "trn2 v19.4s, v2.4s, v3.4s\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #240]\n") - - "ld1 {v0.2s}, [%[src_ptr0]], %[src_inc0]\n" - "trn1 v20.2d, v16.2d, v18.2d\n" - "ld1 {v1.2s}, [%[src_ptr1]], %[src_inc1]\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "ld1 {v2.2s}, [%[src_ptr2]], %[src_inc2]\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "ld1 {v3.2s}, [%[src_ptr3]], %[src_inc3]\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - "cmp w1, w2\n" - - "ins v0.d[1], x10\n" - "str q20, [%[packed_ptr], #0]\n" - "ins v1.d[1], x11\n" - "str q21, [%[packed_ptr], #32]\n" - "ins v2.d[1], x12\n" - "str q22, [%[packed_ptr], #64]\n" - "ins v3.d[1], x13\n" - "str q23, [%[packed_ptr], #96]\n" - - "add %[packed_ptr], %[packed_ptr], #128\n" - - "bne 1b\n" - - "2:\n" - - "trn1 v16.4s, v0.4s, v1.4s\n" - "trn2 v17.4s, v0.4s, v1.4s\n" - "trn1 v18.4s, v2.4s, v3.4s\n" - "trn2 v19.4s, v2.4s, v3.4s\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - "str q20, [%[packed_ptr], #0]\n" - "str q21, [%[packed_ptr], #32]\n" - "str q22, [%[packed_ptr], #64]\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "3:\n" - - "ands w2, %w[rows], #3\n" - "beq 4f\n" - "dup v0.16b, wzr\n" - "dup v1.16b, wzr\n" - "dup v2.16b, wzr\n" - "dup v3.16b, wzr\n" -#define RUY_LOAD_ONE_ROW(R) \ - "cmp w2, #" #R "\n" \ - "beq 5f\n" \ - "ld1 { v0.s }[" #R "], [%[src_ptr0]], #4\n" \ - "ld1 { v1.s }[" #R "], [%[src_ptr1]], #4\n" \ - "ld1 { v2.s }[" #R "], [%[src_ptr2]], #4\n" \ - "ld1 { v3.s }[" #R "], [%[src_ptr3]], #4\n" - - RUY_LOAD_ONE_ROW(0) - RUY_LOAD_ONE_ROW(1) - RUY_LOAD_ONE_ROW(2) - RUY_LOAD_ONE_ROW(3) -#undef RUY_LOAD_ONE_ROW - "5:\n" - - "trn1 v16.4s, v0.4s, v1.4s\n" - "trn2 v17.4s, v0.4s, v1.4s\n" - "trn1 v18.4s, v2.4s, v3.4s\n" - "trn2 v19.4s, v2.4s, v3.4s\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - "mov x1, #32\n" - -#define RUY_STORE_ONE_ROW(ROW, REGISTER) \ - "cmp w2, #" #ROW "\n" \ - "beq 4f\n" \ - "st1 {" #REGISTER ".4s}, [%[packed_ptr]], x1\n" - - RUY_STORE_ONE_ROW(0, v20) - RUY_STORE_ONE_ROW(1, v21) - RUY_STORE_ONE_ROW(2, v22) - RUY_STORE_ONE_ROW(3, v23) - -#undef RUY_STORE_ONE_ROW - - "4:\n" - - // clang-format on - - : [ src_ptr0 ] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1), [src_ptr2] "+r"(src_ptr2), - [src_ptr3] "+r"(src_ptr3), [packed_ptr] "+r"(packed_ptr) - : [ src_inc0 ] "r"(static_cast(src_inc0)), [src_inc1] "r"(static_cast(src_inc1)), [src_inc2] "r"(static_cast(src_inc2)), - [src_inc3] "r"(static_cast(src_inc3)), [rows] "r"(src_rows) - : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", - "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", - "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); -} -#endif // RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/pack_arm.h b/tensorflow/lite/experimental/ruy/ruy/pack_arm.h deleted file mode 100644 index f4691d66fcb..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/pack_arm.h +++ /dev/null @@ -1,497 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// # What is "packing"? -// -// Before feeding data to the gemm kernels (the parts of Ruy that do lots -// of multiply-add operations), Ruy first performs a data transformation (which -// we call "packing") on the input matrices. This transformation has two main -// goals: -// - rearrange data into blocks that are a convenient size/layout for the gemm -// kernels to consume. This helps make the memory access pattern of the gemm -// kernel simpler and more contiguous, and puts the data in a layout most -// convenient for specific arithmetic instructions in the gemm kernel. -// - compute row/column sums needed for handling quantization with non-symmetric -// zero points. -// -// # Simplified algorithmic analysis of packing -// -// Packing is a relatively simple transformation which does a small constant -// amount of work on each element of an input matrix, and hence for an NxM -// matrix performs O(N*M) work. If N and M are of the same order, then this is -// O(N^2) work. -// -// A NxKxM matrix multiplication requires N*K*M multiply-accumulate operations. -// Note that if N, K, and M are all the same order, then the number of -// multiply-accumulate operations is O(N^3). -// -// Thus, the O(N^2) cost of packing is small compared to the O(N^3) work, in the -// case of all dimensions being roughly the same order. -// -// # Packing cost can be significant -// -// When matrix * matrix multiplications begin to look more like matrix * vector -// multiplications, packing cost can become significant. We sometimes call these -// cases "gemv-like". -// -// Continuing the algorithmic analysis above, if we consider a case where an -// NxKxM matrix multiplication has either N = O(1) or M = O(1), then the -// situation is different. In this case, the multiply-accumulate work is only -// quadratic, so the quadratic cost of packing can be come significant. -// -// Another way to say this is that the cost of packing an input matrix (either -// the LHS or RHS) is amortized across the non-depth dimension of the opposite -// input matrix. Thus, when the LHS has very few rows or the RHS has very few -// columns, the cost of packing the opposite input matrix can become -// significant. -// -// As a rough rule of thumb, the cost of packing starts to become significant -// when either N or M is below 32 (and other dimensions are hundreds), with very -// significant packing costs at 8 or below. This varies by data type, Path, and -// tuning, so these numbers are only rough guides. -// -// One practical use case that is affected by this is inference of -// fully connected neural network layers with a low batch size. The weight -// matrix (which is a constant for inference) is the one affected by significant -// packing cost. -// -// Ruy provides an API in ruy_advanced.h for advanced users to pre-pack -// input matrices that are affected by significant packing costs. -// -// # Implementation notes -// -// Ruy's packing routines always operate on a range of columns and can be -// applied to either the LHS or RHS. This is possible because Ruy internally -// implements a TrMul, so the accumulation along depth is done along columns of -// both the LHS and RHS (whereas for a normal Mul the accumulation along depth -// for the LHS is along rows). As another example, we are always computing -// column sums for quantization (and never row sums, since the LHS is -// transposed). - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_ARM_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_ARM_H_ - -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/common.h" -#include "tensorflow/lite/experimental/ruy/ruy/internal_matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/pack_common.h" -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" -#include "tensorflow/lite/experimental/ruy/ruy/tune.h" - -namespace ruy { - -#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) -void Pack8bitNeonOutOfOrder(const void* src_ptr0, const void* src_ptr1, - const void* src_ptr2, const void* src_ptr3, - int src_inc0, int src_inc1, int src_inc2, - int src_inc3, int src_rows, int src_zero_point, - std::int8_t* packed_ptr, int start_col, int end_col, - std::int32_t* sums_ptr, int input_xor); -void Pack8bitNeonInOrder(const void* src_ptr0, const void* src_ptr1, - const void* src_ptr2, const void* src_ptr3, - int src_inc0, int src_inc1, int src_inc2, int src_inc3, - int src_rows, int src_zero_point, - std::int8_t* packed_ptr, int start_col, int end_col, - std::int32_t* sums_ptr, int input_xor); -void Pack8bitNeonDotprodOutOfOrder(const void* src_ptr0, const void* src_ptr1, - const void* src_ptr2, const void* src_ptr3, - int src_inc0, int src_inc1, int src_inc2, - int src_inc3, int src_rows, - int src_zero_point, std::int8_t* packed_ptr, - int start_col, int end_col, - std::int32_t* sums_ptr, int input_xor); -void Pack8bitNeonDotprodInOrder(const void* src_ptr0, const void* src_ptr1, - const void* src_ptr2, const void* src_ptr3, - int src_inc0, int src_inc1, int src_inc2, - int src_inc3, int src_rows, int src_zero_point, - std::int8_t* packed_ptr, int start_col, - int end_col, std::int32_t* sums_ptr, - int input_xor); - -#elif RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM) -void Pack8bitNeonOutOfOrder4Cols(const PackParams8bit& params); -void Pack8bitNeonOutOfOrder2Cols(const PackParams8bit& params); -#endif // (RUY_PLATFORM(NEON_64)&& RUY_OPT_ENABLED(RUY_OPT_ASM) - -#if (RUY_PLATFORM(NEON_32) || RUY_PLATFORM(NEON_64)) && \ - RUY_OPT_ENABLED(RUY_OPT_ASM) - -template -struct PackImpl, Scalar, - std::int8_t, std::int32_t> { - static_assert(std::is_same::value || - std::is_same::value, - ""); - static constexpr int kInputXor = - std::is_same::value ? 0 : 0x80; - - static void Run(Tuning tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - RUY_DCHECK(IsColMajor(src_matrix.layout)); - RUY_DCHECK(IsColMajor(packed_matrix->layout)); - RUY_DCHECK_EQ(start_col % 4, 0); - std::int32_t* sums = packed_matrix->sums; - Scalar zerobuf[16]; - memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf)); - for (int block_col = start_col; block_col < end_col; block_col += 4) { - int src_stride = src_matrix.layout.stride; - const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col; - const Scalar* src_ptr1 = src_ptr0 + src_stride; - const Scalar* src_ptr2 = src_ptr1 + src_stride; - const Scalar* src_ptr3 = src_ptr2 + src_stride; - int src_inc0 = 16; - int src_inc1 = 16; - int src_inc2 = 16; - int src_inc3 = 16; - if (block_col >= src_matrix.layout.cols - 3) { - if (block_col >= src_matrix.layout.cols - 0) { - src_ptr0 = zerobuf; - src_inc0 = 0; - } - if (block_col >= src_matrix.layout.cols - 1) { - src_ptr1 = zerobuf; - src_inc1 = 0; - } - if (block_col >= src_matrix.layout.cols - 2) { - src_ptr2 = zerobuf; - src_inc2 = 0; - } - if (block_col >= src_matrix.layout.cols - 3) { - src_ptr3 = zerobuf; - src_inc3 = 0; - } - } - std::int8_t* packed_ptr = - packed_matrix->data + packed_matrix->layout.stride * block_col; - std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; -#if RUY_PLATFORM(NEON_64) - if (__builtin_expect(tuning == Tuning::kInOrder, true)) { - Pack8bitNeonInOrder( - src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1, - src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point, - packed_ptr, start_col, end_col, sums_ptr, kInputXor); - } else { - Pack8bitNeonOutOfOrder( - src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1, - src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point, - packed_ptr, start_col, end_col, sums_ptr, kInputXor); - } -#else - // We have a more limited set of general purpose registers in ARMv7, so - // we use the "params" struct technique from the kernel code to save - // registers. - PackParams8bit params; - MakePackParams8bit(src_ptr0, src_ptr1, src_ptr2, src_ptr3, sums_ptr, - packed_ptr, src_inc0, src_inc1, src_inc2, src_inc3, - src_matrix.layout.rows, src_matrix.zero_point, - kInputXor, ¶ms); - Pack8bitNeonOutOfOrder4Cols(params); -#endif // RUY_PLATFORM(NEON_64) - } - } -}; - -#endif // (RUY_PLATFORM(NEON_32) || RUY_PLATFORM(NEON_64)) && - // RUY_OPT_ENABLED(RUY_OPT_ASM) - -#if RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM) -// The 32-bit float kernel is 4 rows X 2 columns, so we need an additional -// partial specialization for the RHS, which has a FixedKernelLayout with 2 -// columns. -template -struct PackImpl, Scalar, - std::int8_t, std::int32_t> { - static_assert(std::is_same::value || - std::is_same::value, - ""); - static constexpr int kInputXor = - std::is_same::value ? 0 : 0x80; - static void Run(Tuning tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - RUY_DCHECK(IsColMajor(src_matrix.layout)); - RUY_DCHECK(IsColMajor(packed_matrix->layout)); - RUY_DCHECK_EQ(start_col % 2, 0); - std::int32_t* sums = packed_matrix->sums; - Scalar zerobuf[16]; - memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf)); - for (int block_col = start_col; block_col < end_col; block_col += 2) { - int src_stride = src_matrix.layout.stride; - const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col; - const Scalar* src_ptr1 = src_ptr0 + src_stride; - int src_inc0 = 16; - int src_inc1 = 16; - if (block_col >= src_matrix.layout.cols - 2) { - if (block_col >= src_matrix.layout.cols - 0) { - src_ptr0 = zerobuf; - src_inc0 = 0; - } - if (block_col >= src_matrix.layout.cols - 1) { - src_ptr1 = zerobuf; - src_inc1 = 0; - } - } - std::int8_t* packed_ptr = - packed_matrix->data + packed_matrix->layout.stride * block_col; - std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; - PackParams8bit params; - MakePackParams8bit(src_ptr0, src_ptr1, nullptr, nullptr, sums_ptr, - packed_ptr, src_inc0, src_inc1, -1, -1, - src_matrix.layout.rows, src_matrix.zero_point, - kInputXor, ¶ms); - Pack8bitNeonOutOfOrder2Cols(params); - } - } -}; -#endif // (RUY_PLATFORM(NEON_32)) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) -template -struct PackImpl, - Scalar, std::int8_t, std::int32_t> { - static_assert(std::is_same::value || - std::is_same::value, - ""); - static constexpr int kInputXor = - std::is_same::value ? 0 : 0x80; - - static void Run(Tuning tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - RUY_DCHECK(IsColMajor(src_matrix.layout)); - RUY_DCHECK(IsColMajor(packed_matrix->layout)); - RUY_DCHECK_EQ(start_col % 8, 0); - std::int32_t* sums = packed_matrix->sums; - Scalar zerobuf[16]; - memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf)); - for (int block_col = start_col; block_col < end_col; block_col += 4) { - int src_stride = src_matrix.layout.stride; - const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col; - const Scalar* src_ptr1 = src_ptr0 + src_stride; - const Scalar* src_ptr2 = src_ptr1 + src_stride; - const Scalar* src_ptr3 = src_ptr2 + src_stride; - std::int64_t src_inc0 = 16; - std::int64_t src_inc1 = 16; - std::int64_t src_inc2 = 16; - std::int64_t src_inc3 = 16; - if (block_col >= src_matrix.layout.cols - 3) { - if (block_col >= src_matrix.layout.cols - 0) { - src_ptr0 = zerobuf; - src_inc0 = 0; - } - if (block_col >= src_matrix.layout.cols - 1) { - src_ptr1 = zerobuf; - src_inc1 = 0; - } - if (block_col >= src_matrix.layout.cols - 2) { - src_ptr2 = zerobuf; - src_inc2 = 0; - } - if (block_col >= src_matrix.layout.cols - 3) { - src_ptr3 = zerobuf; - src_inc3 = 0; - } - } - std::int8_t* packed_ptr = - packed_matrix->data + - packed_matrix->layout.stride * (block_col & ~7) + - ((block_col & 4) * 4); - std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; - if (__builtin_expect(tuning == Tuning::kInOrder, true)) { - Pack8bitNeonDotprodInOrder( - src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1, - src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point, - packed_ptr, start_col, end_col, sums_ptr, kInputXor); - } else { - Pack8bitNeonDotprodOutOfOrder( - src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1, - src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point, - packed_ptr, start_col, end_col, sums_ptr, kInputXor); - } - } - } -}; -#endif // (RUY_PLATFORM(NEON_64)&& RUY_OPT_ENABLED(RUY_OPT_ASM) - -#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) -void PackFloatNeonOutOfOrder(const float* src_ptr0, const float* src_ptr1, - const float* src_ptr2, const float* src_ptr3, - int src_inc0, int src_inc1, int src_inc2, - int src_inc3, int src_rows, int src_zero_point, - float* packed_ptr, int start_col, int end_col); -void PackFloatNeonInOrder(const float* src_ptr0, const float* src_ptr1, - const float* src_ptr2, const float* src_ptr3, - int src_inc0, int src_inc1, int src_inc2, - int src_inc3, int src_rows, int src_zero_point, - float* packed_ptr, int start_col, int end_col); - -#elif RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM) -void PackFloatNeonOutOfOrder(const float* src_ptr0, const float* src_ptr1, - const float* src_ptr2, const float* src_ptr3, - int src_inc, int src_rows, int src_zero_point, - float* packed_ptr, int start_col, int end_col, - int stride); -#endif // (RUY_PLATFORM(NEON_64)&& RUY_OPT_ENABLED(RUY_OPT_ASM) - -#if (RUY_PLATFORM(NEON_32) || RUY_PLATFORM(NEON_64)) && \ - RUY_OPT_ENABLED(RUY_OPT_ASM) - -template <> -struct PackImpl, float, - float, float> { - static void Run(Tuning tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - RUY_DCHECK(IsColMajor(src_matrix.layout)); - RUY_DCHECK(IsColMajor(packed_matrix->layout)); - RUY_DCHECK_EQ(start_col % 8, 0); - const float zerobuf[4] = {0}; - for (int block_col = start_col; block_col < end_col; block_col += 4) { - int src_stride = src_matrix.layout.stride; - const float* src_ptr0 = src_matrix.data.get() + src_stride * block_col; - const float* src_ptr1 = src_ptr0 + src_stride; - const float* src_ptr2 = src_ptr1 + src_stride; - const float* src_ptr3 = src_ptr2 + src_stride; - std::int64_t src_inc0 = 16; - std::int64_t src_inc1 = 16; - std::int64_t src_inc2 = 16; - std::int64_t src_inc3 = 16; - if (block_col >= src_matrix.layout.cols - 3) { - if (block_col >= src_matrix.layout.cols - 0) { - src_ptr0 = zerobuf; - src_inc0 = 0; - } - if (block_col >= src_matrix.layout.cols - 1) { - src_ptr1 = zerobuf; - src_inc1 = 0; - } - if (block_col >= src_matrix.layout.cols - 2) { - src_ptr2 = zerobuf; - src_inc2 = 0; - } - if (block_col >= src_matrix.layout.cols - 3) { - src_ptr3 = zerobuf; - src_inc3 = 0; - } - } - float* packed_ptr = packed_matrix->data + - packed_matrix->layout.stride * (block_col & ~7) + - ((block_col & 4)); -#if RUY_PLATFORM(NEON_64) - if (__builtin_expect(tuning == Tuning::kInOrder, true)) { - PackFloatNeonInOrder(src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, - src_inc1, src_inc2, src_inc3, - src_matrix.layout.rows, src_matrix.zero_point, - packed_ptr, start_col, end_col); - } else { - PackFloatNeonOutOfOrder(src_ptr0, src_ptr1, src_ptr2, src_ptr3, - src_inc0, src_inc1, src_inc2, src_inc3, - src_matrix.layout.rows, src_matrix.zero_point, - packed_ptr, start_col, end_col); - } -#else - // Encode each of src_inc0, ..., src_inc3 in lowest 4 bits of src_inc - // to save on registers (we have fewer general purpose registers in - // 32-bit ARM than in 64-bit ARM). For the 64-bit case, we pass four - // values that are each either 16 or 0 and use them directly. For the - // 32-bit case, bits 0, 1, 2, and 3 are used to determine if we should - // use the value 16 (bit is set) or 0 (bit is not set) for the - // respective increment value. - std::int64_t src_inc = 0; - src_inc += src_inc0 == 16 ? 1 : 0; - src_inc += src_inc1 == 16 ? 2 : 0; - src_inc += src_inc2 == 16 ? 4 : 0; - src_inc += src_inc3 == 16 ? 8 : 0; - const int kOutputStride = 32; - PackFloatNeonOutOfOrder(src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc, - src_matrix.layout.rows, src_matrix.zero_point, - packed_ptr, start_col, end_col, kOutputStride); -#endif // RUY_PLATFORM(NEON_64) - } - } -}; - -#if RUY_PLATFORM(NEON_32) -// The 32-bit float kernel is 8 rows X 4 columns, so we need an additional -// specialization for a FixedKernelLayout with 4 columns. -template <> -struct PackImpl, float, - float, float> { - static void Run(Tuning tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - RUY_DCHECK(IsColMajor(src_matrix.layout)); - RUY_DCHECK(IsColMajor(packed_matrix->layout)); - RUY_DCHECK_EQ(start_col % 4, 0); - const float zerobuf[4] = {0}; - for (int block_col = start_col; block_col < end_col; block_col += 4) { - int src_stride = src_matrix.layout.stride; - const float* src_ptr0 = src_matrix.data.get() + src_stride * block_col; - const float* src_ptr1 = src_ptr0 + src_stride; - const float* src_ptr2 = src_ptr1 + src_stride; - const float* src_ptr3 = src_ptr2 + src_stride; - std::int64_t src_inc0 = 16; - std::int64_t src_inc1 = 16; - std::int64_t src_inc2 = 16; - std::int64_t src_inc3 = 16; - if (block_col >= src_matrix.layout.cols - 3) { - if (block_col >= src_matrix.layout.cols - 0) { - src_ptr0 = zerobuf; - src_inc0 = 0; - } - if (block_col >= src_matrix.layout.cols - 1) { - src_ptr1 = zerobuf; - src_inc1 = 0; - } - if (block_col >= src_matrix.layout.cols - 2) { - src_ptr2 = zerobuf; - src_inc2 = 0; - } - if (block_col >= src_matrix.layout.cols - 3) { - src_ptr3 = zerobuf; - src_inc3 = 0; - } - } - float* packed_ptr = - packed_matrix->data + packed_matrix->layout.stride * (block_col); - // Encode each of src_inc0, ..., src_inc1 in lowest 4 bits of scrc_inc - // to save registers. - std::int64_t src_inc = 0; - src_inc += src_inc0 == 16 ? 1 : 0; - src_inc += src_inc1 == 16 ? 2 : 0; - src_inc += src_inc2 == 16 ? 4 : 0; - src_inc += src_inc3 == 16 ? 8 : 0; - const int kOutputStride = 16; - PackFloatNeonOutOfOrder(src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc, - src_matrix.layout.rows, src_matrix.zero_point, - packed_ptr, start_col, end_col, kOutputStride); - } - } -}; -#endif // (RUY_PLATFORM(NEON_32)) -#endif // (RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) && \ - // RUY_OPT_ENABLED(RUY_OPT_ASM) - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_ARM_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/pack_avx2.cc b/tensorflow/lite/experimental/ruy/ruy/pack_avx2.cc deleted file mode 100644 index 3575943e50e..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/pack_avx2.cc +++ /dev/null @@ -1,816 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/pack.h" -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" - -#if RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS) -#include // IWYU pragma: keep -#endif - -namespace ruy { - -#if !(RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM)) - -void Pack8bitAvx2(const std::int8_t* src_ptr, std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, std::int8_t* packed_ptr, - std::int32_t* sums_ptr) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -void PackFloatAvx2(const float* src_ptr, const float* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, float* packed_ptr) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -#else // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -// The first int8_t template parameter is arbitrary: this routine is common to -// all 8-bit source matrix types. -using PackImpl8bitAvx2 = - PackImpl, - std::int8_t, std::int8_t, std::int32_t>; - -using PackImplFloatAvx2 = - PackImpl, float, - float, float>; - -namespace { - -inline __m256i MaskLoadu(int available_src_rows, std::int8_t zero_point, - const std::int8_t* addr) { - RUY_DCHECK_LT(available_src_rows, 32); - __m256i padded_data; - - if (available_src_rows >= 16) { - __m128i load_hi = _mm_set1_epi8(zero_point); - __m128i load_lo = _mm_loadu_si128(reinterpret_cast(addr)); - memcpy(&load_hi, addr + 16, available_src_rows - 16); - padded_data = _mm256_set_m128i(load_hi, load_lo); - } else { - __m128i load_hi = _mm_set1_epi8(zero_point); - __m128i load_lo = load_hi; - memcpy(&load_lo, addr, available_src_rows); - padded_data = _mm256_set_m128i(load_hi, load_lo); - } - return padded_data; -} - -inline void Pack8bitAvx2Packer(const std::int8_t* src_ptr, - std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, - std::int8_t* packed_ptr, std::int32_t* sums_ptr, - std::int8_t* trailing_buf) { - using Layout = PackImpl8bitAvx2::Layout; - RUY_DCHECK_EQ(Layout::kCols, 8); - RUY_DCHECK_EQ(Layout::kRows, 4); - // Each Layout::Rows is 4 contiguous input, contiguous packed elements. - // We process 8 of these chunks at a time, padding short input chunks. - constexpr int kNumRowChunks = 8; - constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows; - - const std::int8_t* src_ptr0 = src_ptr; - const std::int8_t* src_ptr1 = src_ptr0 + src_stride; - const std::int8_t* src_ptr2 = src_ptr1 + src_stride; - const std::int8_t* src_ptr3 = src_ptr2 + src_stride; - const std::int8_t* src_ptr4 = src_ptr3 + src_stride; - const std::int8_t* src_ptr5 = src_ptr4 + src_stride; - const std::int8_t* src_ptr6 = src_ptr5 + src_stride; - const std::int8_t* src_ptr7 = src_ptr6 + src_stride; - std::int64_t src_inc0 = kNumChunkedSrcRows; - std::int64_t src_inc1 = kNumChunkedSrcRows; - std::int64_t src_inc2 = kNumChunkedSrcRows; - std::int64_t src_inc3 = kNumChunkedSrcRows; - std::int64_t src_inc4 = kNumChunkedSrcRows; - std::int64_t src_inc5 = kNumChunkedSrcRows; - std::int64_t src_inc6 = kNumChunkedSrcRows; - std::int64_t src_inc7 = kNumChunkedSrcRows; - // Handle cases where source does not have Layout::kCols (8) columns. - if (remaining_src_cols < 8) { - if (remaining_src_cols <= 0) { - src_ptr0 = zerobuf; - src_inc0 = 0; - } - if (remaining_src_cols <= 1) { - src_ptr1 = zerobuf; - src_inc1 = 0; - } - if (remaining_src_cols <= 2) { - src_ptr2 = zerobuf; - src_inc2 = 0; - } - if (remaining_src_cols <= 3) { - src_ptr3 = zerobuf; - src_inc3 = 0; - } - if (remaining_src_cols <= 4) { - src_ptr4 = zerobuf; - src_inc4 = 0; - } - if (remaining_src_cols <= 5) { - src_ptr5 = zerobuf; - src_inc5 = 0; - } - if (remaining_src_cols <= 6) { - src_ptr6 = zerobuf; - src_inc6 = 0; - } - src_ptr7 = zerobuf; - src_inc7 = 0; - } - - const std::int8_t zero_point = zerobuf[0]; - - if (sums_ptr) { - // i: Layout::kCols. - for (int i = 0; i < 8; ++i) { - sums_ptr[i] = 0; - } - } - std::int32_t sums_adjustment = 0; - const __m256i ones_16bit = _mm256_set1_epi16(1); - __m256i sums_4x2_32bit_lo = _mm256_set1_epi32(0); - __m256i sums_4x2_32bit_hi = _mm256_set1_epi32(0); - - // The overall packing effectively pads the source rows to - // (src_rows + 63) & ~63. The iteration over k may skip when m=1, and then we - // only pack for (src_rows + 31) & ~31. When there is an incomplete - // destination block, this is stored into trailing_buf instead of packed_ptr. - for (int k = 0; k < src_rows; k += kNumChunkedSrcRows) { - // Available source rows. - // If this is less than 0 (for m=1), we skip, having filled trailing - // buffer for m=0. Also, if source rows is zero on m=1, then we filled - // exactly to the end of the column in the packed buffer. - const int available_src_rows = src_rows - k; - // Effectively, - // available rows = std::max(0, std::min(8, src_rows - k)); - // treat each case separately. - if (available_src_rows >= kNumChunkedSrcRows) { - if (sums_ptr) { - __m256i t0, t1, t2, t3, t4, t5, t6, t7; - __m256i r0, r1, r2, r3, r4, r5, r6, r7; - const __m256i input_xor_v = _mm256_set1_epi8(input_xor); - - t0 = _mm256_loadu_si256(reinterpret_cast(src_ptr0)); - t4 = _mm256_loadu_si256(reinterpret_cast(src_ptr4)); - t1 = _mm256_loadu_si256(reinterpret_cast(src_ptr1)); - t5 = _mm256_loadu_si256(reinterpret_cast(src_ptr5)); - t2 = _mm256_loadu_si256(reinterpret_cast(src_ptr2)); - t6 = _mm256_loadu_si256(reinterpret_cast(src_ptr6)); - t3 = _mm256_loadu_si256(reinterpret_cast(src_ptr3)); - t7 = _mm256_loadu_si256(reinterpret_cast(src_ptr7)); - - r0 = _mm256_unpacklo_epi32(t0, t1); - r4 = _mm256_unpacklo_epi32(t4, t5); - r2 = _mm256_unpackhi_epi32(t0, t1); - r6 = _mm256_unpackhi_epi32(t4, t5); - r1 = _mm256_unpacklo_epi32(t2, t3); - r5 = _mm256_unpacklo_epi32(t6, t7); - r3 = _mm256_unpackhi_epi32(t2, t3); - r7 = _mm256_unpackhi_epi32(t6, t7); - - t0 = _mm256_unpacklo_epi64(r0, r1); - t4 = _mm256_unpacklo_epi64(r4, r5); - t2 = _mm256_unpackhi_epi64(r0, r1); - t6 = _mm256_unpackhi_epi64(r4, r5); - t1 = _mm256_unpacklo_epi64(r2, r3); - t5 = _mm256_unpacklo_epi64(r6, r7); - t3 = _mm256_unpackhi_epi64(r2, r3); - t7 = _mm256_unpackhi_epi64(r6, r7); - - // The preceding sets of rearrangement operations interleaved by 4 bytes - // and then by 8 bytes *within* lanes. The following set interleave by - // 16 bytes (128-bit), operating *between* AVX lanes. For instance (t0, - // t4) are interleaved to create (r0, r1). This complexity follows from - // the way that AVX is centered around MM 128-bit lanes. - r0 = _mm256_permute2x128_si256(t0, t4, 0x20); - r4 = _mm256_permute2x128_si256(t1, t5, 0x20); - r1 = _mm256_permute2x128_si256(t0, t4, 0x31); - r5 = _mm256_permute2x128_si256(t1, t5, 0x31); - r2 = _mm256_permute2x128_si256(t2, t6, 0x20); - r6 = _mm256_permute2x128_si256(t3, t7, 0x20); - r3 = _mm256_permute2x128_si256(t2, t6, 0x31); - r7 = _mm256_permute2x128_si256(t3, t7, 0x31); - - r0 = _mm256_xor_si256(r0, input_xor_v); - r1 = _mm256_xor_si256(r1, input_xor_v); - r2 = _mm256_xor_si256(r2, input_xor_v); - r3 = _mm256_xor_si256(r3, input_xor_v); - r4 = _mm256_xor_si256(r4, input_xor_v); - r5 = _mm256_xor_si256(r5, input_xor_v); - r6 = _mm256_xor_si256(r6, input_xor_v); - r7 = _mm256_xor_si256(r7, input_xor_v); - - __m256i sums_4x4_16bit_lo; - sums_4x4_16bit_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r0)); - sums_4x4_16bit_lo = - _mm256_add_epi16(sums_4x4_16bit_lo, - _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r1))); - sums_4x4_16bit_lo = - _mm256_add_epi16(sums_4x4_16bit_lo, - _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r2))); - sums_4x4_16bit_lo = - _mm256_add_epi16(sums_4x4_16bit_lo, - _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r3))); - sums_4x4_16bit_lo = - _mm256_add_epi16(sums_4x4_16bit_lo, - _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r4))); - sums_4x4_16bit_lo = - _mm256_add_epi16(sums_4x4_16bit_lo, - _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r5))); - sums_4x4_16bit_lo = - _mm256_add_epi16(sums_4x4_16bit_lo, - _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r6))); - sums_4x4_16bit_lo = - _mm256_add_epi16(sums_4x4_16bit_lo, - _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r7))); - - // The sums have been performed across columns, and now we have 4x16-bit - // sums packed together. We use madd for pairwise 32-bit sums. - const __m256i sums_4x2_32bit_lo_new = - _mm256_madd_epi16(sums_4x4_16bit_lo, ones_16bit); - sums_4x2_32bit_lo = - _mm256_add_epi32(sums_4x2_32bit_lo, sums_4x2_32bit_lo_new); - - __m256i sums_4x4_16bit_hi; - sums_4x4_16bit_hi = - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r0, 1)); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r1, 1))); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r2, 1))); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r3, 1))); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r4, 1))); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r5, 1))); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r6, 1))); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r7, 1))); - - const __m256i sums_4x2_32bit_hi_new = - _mm256_madd_epi16(sums_4x4_16bit_hi, ones_16bit); - sums_4x2_32bit_hi = - _mm256_add_epi32(sums_4x2_32bit_hi, sums_4x2_32bit_hi_new); - - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 0 * 8 * 4), - r0); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 2 * 8 * 4), - r4); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 4 * 8 * 4), - r1); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 6 * 8 * 4), - r5); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 1 * 8 * 4), - r2); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 3 * 8 * 4), - r6); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 5 * 8 * 4), - r3); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 7 * 8 * 4), - r7); - } else { - __m256i t0, t1, t2, t3, t4, t5, t6, t7; - __m256i r0, r1, r2, r3, r4, r5, r6, r7; - const __m256i input_xor_v = _mm256_set1_epi8(input_xor); - - t0 = _mm256_loadu_si256(reinterpret_cast(src_ptr0)); - t4 = _mm256_loadu_si256(reinterpret_cast(src_ptr4)); - t1 = _mm256_loadu_si256(reinterpret_cast(src_ptr1)); - t5 = _mm256_loadu_si256(reinterpret_cast(src_ptr5)); - t2 = _mm256_loadu_si256(reinterpret_cast(src_ptr2)); - t6 = _mm256_loadu_si256(reinterpret_cast(src_ptr6)); - t3 = _mm256_loadu_si256(reinterpret_cast(src_ptr3)); - t7 = _mm256_loadu_si256(reinterpret_cast(src_ptr7)); - - r0 = _mm256_unpacklo_epi32(t0, t1); - r4 = _mm256_unpacklo_epi32(t4, t5); - r2 = _mm256_unpackhi_epi32(t0, t1); - r6 = _mm256_unpackhi_epi32(t4, t5); - r1 = _mm256_unpacklo_epi32(t2, t3); - r5 = _mm256_unpacklo_epi32(t6, t7); - r3 = _mm256_unpackhi_epi32(t2, t3); - r7 = _mm256_unpackhi_epi32(t6, t7); - - t0 = _mm256_unpacklo_epi64(r0, r1); - t4 = _mm256_unpacklo_epi64(r4, r5); - t2 = _mm256_unpackhi_epi64(r0, r1); - t6 = _mm256_unpackhi_epi64(r4, r5); - t1 = _mm256_unpacklo_epi64(r2, r3); - t5 = _mm256_unpacklo_epi64(r6, r7); - t3 = _mm256_unpackhi_epi64(r2, r3); - t7 = _mm256_unpackhi_epi64(r6, r7); - - // The preceding sets of rearrangement operations interleaved by 4 bytes - // and then by 8 bytes *within* lanes. The following set interleave by - // 16 bytes (128-bit), operating *between* AVX lanes. For instance (t0, - // t4) are interleaved to create (r0, r1). This complexity follows from - // the way that AVX is centered around MM 128-bit lanes. - r0 = _mm256_permute2x128_si256(t0, t4, 0x20); - r4 = _mm256_permute2x128_si256(t1, t5, 0x20); - r1 = _mm256_permute2x128_si256(t0, t4, 0x31); - r5 = _mm256_permute2x128_si256(t1, t5, 0x31); - r2 = _mm256_permute2x128_si256(t2, t6, 0x20); - r6 = _mm256_permute2x128_si256(t3, t7, 0x20); - r3 = _mm256_permute2x128_si256(t2, t6, 0x31); - r7 = _mm256_permute2x128_si256(t3, t7, 0x31); - - r0 = _mm256_xor_si256(r0, input_xor_v); - r1 = _mm256_xor_si256(r1, input_xor_v); - r2 = _mm256_xor_si256(r2, input_xor_v); - r3 = _mm256_xor_si256(r3, input_xor_v); - r4 = _mm256_xor_si256(r4, input_xor_v); - r5 = _mm256_xor_si256(r5, input_xor_v); - r6 = _mm256_xor_si256(r6, input_xor_v); - r7 = _mm256_xor_si256(r7, input_xor_v); - - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 0 * 8 * 4), - r0); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 2 * 8 * 4), - r4); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 4 * 8 * 4), - r1); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 6 * 8 * 4), - r5); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 1 * 8 * 4), - r2); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 3 * 8 * 4), - r6); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 5 * 8 * 4), - r3); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 7 * 8 * 4), - r7); - } - } else if (available_src_rows > 0) { - RUY_DCHECK_LT(available_src_rows, kNumChunkedSrcRows); - // We do not care what goes into the trailing buffer, but we want - // in_data[...] ^ input_xor == 0 for irrelevant values in the summation. - // - // We compensate for padding-with-zero_point by initializing the - // summations with the compensating offset, effectively - // ((input_xor ^ input_xor) - (zero_point ^ input_xor)) * - // 4 * (8 - ((available_src_rows + 3) >> 2)). - // - // Note that (zero_point ^ input_xor) is performed in 8-bits and then - // cast. - sums_adjustment += - -(zero_point ^ input_xor) * 4 * (8 - ((available_src_rows + 3) >> 2)); - - __m256i t0, t1, t2, t3, t4, t5, t6, t7; - __m256i r0, r1, r2, r3, r4, r5, r6, r7; - const __m256i input_xor_v = _mm256_set1_epi8(input_xor); - - t0 = MaskLoadu(available_src_rows, zero_point, src_ptr0); - t4 = MaskLoadu(available_src_rows, zero_point, src_ptr4); - t1 = MaskLoadu(available_src_rows, zero_point, src_ptr1); - t5 = MaskLoadu(available_src_rows, zero_point, src_ptr5); - t2 = MaskLoadu(available_src_rows, zero_point, src_ptr2); - t6 = MaskLoadu(available_src_rows, zero_point, src_ptr6); - t3 = MaskLoadu(available_src_rows, zero_point, src_ptr3); - t7 = MaskLoadu(available_src_rows, zero_point, src_ptr7); - - r0 = _mm256_unpacklo_epi32(t0, t1); - r4 = _mm256_unpacklo_epi32(t4, t5); - r2 = _mm256_unpackhi_epi32(t0, t1); - r6 = _mm256_unpackhi_epi32(t4, t5); - r1 = _mm256_unpacklo_epi32(t2, t3); - r5 = _mm256_unpacklo_epi32(t6, t7); - r3 = _mm256_unpackhi_epi32(t2, t3); - r7 = _mm256_unpackhi_epi32(t6, t7); - - t0 = _mm256_unpacklo_epi64(r0, r1); - t4 = _mm256_unpacklo_epi64(r4, r5); - t2 = _mm256_unpackhi_epi64(r0, r1); - t6 = _mm256_unpackhi_epi64(r4, r5); - t1 = _mm256_unpacklo_epi64(r2, r3); - t5 = _mm256_unpacklo_epi64(r6, r7); - t3 = _mm256_unpackhi_epi64(r2, r3); - t7 = _mm256_unpackhi_epi64(r6, r7); - - // The preceding sets of rearrangement operations interleaved by 4 bytes - // and then by 8 bytes *within* lanes. The following set interleave by - // 16 bytes (128-bit), operating *between* AVX lanes. For instance (t0, - // t4) are interleaved to create (r0, r1). This complexity follows from - // the way that AVX is centered around MM 128-bit lanes. - r0 = _mm256_permute2x128_si256(t0, t4, 0x20); - r4 = _mm256_permute2x128_si256(t1, t5, 0x20); - r1 = _mm256_permute2x128_si256(t0, t4, 0x31); - r5 = _mm256_permute2x128_si256(t1, t5, 0x31); - r2 = _mm256_permute2x128_si256(t2, t6, 0x20); - r6 = _mm256_permute2x128_si256(t3, t7, 0x20); - r3 = _mm256_permute2x128_si256(t2, t6, 0x31); - r7 = _mm256_permute2x128_si256(t3, t7, 0x31); - - r0 = _mm256_xor_si256(r0, input_xor_v); - r1 = _mm256_xor_si256(r1, input_xor_v); - r2 = _mm256_xor_si256(r2, input_xor_v); - r3 = _mm256_xor_si256(r3, input_xor_v); - r4 = _mm256_xor_si256(r4, input_xor_v); - r5 = _mm256_xor_si256(r5, input_xor_v); - r6 = _mm256_xor_si256(r6, input_xor_v); - r7 = _mm256_xor_si256(r7, input_xor_v); - - __m256i sums_4x4_16bit_lo; - sums_4x4_16bit_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r0)); - sums_4x4_16bit_lo = _mm256_add_epi16( - sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r1))); - sums_4x4_16bit_lo = _mm256_add_epi16( - sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r2))); - sums_4x4_16bit_lo = _mm256_add_epi16( - sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r3))); - sums_4x4_16bit_lo = _mm256_add_epi16( - sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r4))); - sums_4x4_16bit_lo = _mm256_add_epi16( - sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r5))); - sums_4x4_16bit_lo = _mm256_add_epi16( - sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r6))); - sums_4x4_16bit_lo = _mm256_add_epi16( - sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r7))); - - // The sums have been performed across columns, and now we have 4x16-bit - // sums packed together. We use madd for pairwise 32-bit sums. - const __m256i sums_4x2_32bit_lo_new = - _mm256_madd_epi16(sums_4x4_16bit_lo, ones_16bit); - sums_4x2_32bit_lo = - _mm256_add_epi32(sums_4x2_32bit_lo, sums_4x2_32bit_lo_new); - - __m256i sums_4x4_16bit_hi; - sums_4x4_16bit_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r0, 1)); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r1, 1))); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r2, 1))); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r3, 1))); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r4, 1))); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r5, 1))); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r6, 1))); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r7, 1))); - - const __m256i sums_4x2_32bit_hi_new = - _mm256_madd_epi16(sums_4x4_16bit_hi, ones_16bit); - sums_4x2_32bit_hi = - _mm256_add_epi32(sums_4x2_32bit_hi, sums_4x2_32bit_hi_new); - - _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 0 * 8 * 4), - r0); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 2 * 8 * 4), - r4); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 4 * 8 * 4), - r1); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 6 * 8 * 4), - r5); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 1 * 8 * 4), - r2); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 3 * 8 * 4), - r6); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 5 * 8 * 4), - r3); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 7 * 8 * 4), - r7); - } - - packed_ptr += 8 * kNumChunkedSrcRows; - src_ptr0 += src_inc0; - src_ptr1 += src_inc1; - src_ptr2 += src_inc2; - src_ptr3 += src_inc3; - src_ptr4 += src_inc4; - src_ptr5 += src_inc5; - src_ptr6 += src_inc6; - src_ptr7 += src_inc7; - } - - if (sums_ptr) { - const __m256i sums_adjustment_v = _mm256_set1_epi32(sums_adjustment); - - __m256i sums = - _mm256_loadu_si256(reinterpret_cast(sums_ptr)); - const __m256i idx = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0); - - // We earlier used madd for pairwise 32-bit sums, and now we deinterlace the - // neighbours, finshing up by adding them to the stored accumulated sums. - const __m256i sums_2x4_32bit_lo = - _mm256_permutevar8x32_epi32(sums_4x2_32bit_lo, idx); - const __m256i sums_2x4_32bit_hi = - _mm256_permutevar8x32_epi32(sums_4x2_32bit_hi, idx); - const __m256i sums_2x4_32bit_a = - _mm256_permute2x128_si256(sums_2x4_32bit_lo, sums_2x4_32bit_hi, 0x20); - const __m256i sums_2x4_32bit_b = - _mm256_permute2x128_si256(sums_2x4_32bit_lo, sums_2x4_32bit_hi, 0x31); - sums = _mm256_add_epi32(sums, sums_adjustment_v); - sums = _mm256_add_epi32(sums, sums_2x4_32bit_a); - sums = _mm256_add_epi32(sums, sums_2x4_32bit_b); - - _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums_ptr), sums); - } -} - -inline __m256 Mm256UnpackloPsx2(const __m256 a, const __m256 b) { - return _mm256_castpd_ps( - _mm256_unpacklo_pd(_mm256_castps_pd(a), _mm256_castps_pd(b))); -} - -inline __m256 Mm256UnpackhiPsx2(const __m256 a, const __m256 b) { - return _mm256_castpd_ps( - _mm256_unpackhi_pd(_mm256_castps_pd(a), _mm256_castps_pd(b))); -} - -inline void PackFloatAvx2Packer(const float* src_ptr, const float* zerobuf, - int src_stride, int remaining_src_cols, - int src_rows, float* packed_ptr, - float* trailing_buf) { - RUY_DCHECK_EQ(PackImplFloatAvx2::Layout::kCols, 8); - RUY_DCHECK_EQ(PackImplFloatAvx2::Layout::kRows, 1); - - // This packing amounts to transposition of 8x8 blocks. - static constexpr int kPackCols = 8; // Source cols packed together. - static constexpr int kPackRows = 8; // Short input is padded. - - const float* src_ptr0 = src_ptr; - const float* src_ptr1 = src_ptr0 + src_stride; - const float* src_ptr2 = src_ptr1 + src_stride; - const float* src_ptr3 = src_ptr2 + src_stride; - const float* src_ptr4 = src_ptr3 + src_stride; - const float* src_ptr5 = src_ptr4 + src_stride; - const float* src_ptr6 = src_ptr5 + src_stride; - const float* src_ptr7 = src_ptr6 + src_stride; - std::int64_t src_inc0 = 8; - std::int64_t src_inc1 = 8; - std::int64_t src_inc2 = 8; - std::int64_t src_inc3 = 8; - std::int64_t src_inc4 = 8; - std::int64_t src_inc5 = 8; - std::int64_t src_inc6 = 8; - std::int64_t src_inc7 = 8; - // Handle cases where source does not have kPackDim (8) columns. - if (remaining_src_cols < kPackCols) { - if (remaining_src_cols <= 0) { - src_ptr0 = zerobuf; - src_inc0 = 0; - } - if (remaining_src_cols <= 1) { - src_ptr1 = zerobuf; - src_inc1 = 0; - } - if (remaining_src_cols <= 2) { - src_ptr2 = zerobuf; - src_inc2 = 0; - } - if (remaining_src_cols <= 3) { - src_ptr3 = zerobuf; - src_inc3 = 0; - } - if (remaining_src_cols <= 4) { - src_ptr4 = zerobuf; - src_inc4 = 0; - } - if (remaining_src_cols <= 5) { - src_ptr5 = zerobuf; - src_inc5 = 0; - } - if (remaining_src_cols <= 6) { - src_ptr6 = zerobuf; - src_inc6 = 0; - } - src_ptr7 = zerobuf; - src_inc7 = 0; - } - - for (int k = 0; k < src_rows; k += kPackRows) { - const int available_src_rows = src_rows - k; - // Effectively, - // available_src_rows = std::max(0, std::min(kPackDim, src_rows - k)); - // but treat each case separately. - if (available_src_rows >= kPackRows) { - __m256 t0, t1, t2, t3, t4, t5, t6, t7; - __m256 r0, r1, r2, r3, r4, r5, r6, r7; - - t0 = _mm256_loadu_ps(src_ptr0); - t4 = _mm256_loadu_ps(src_ptr4); - t1 = _mm256_loadu_ps(src_ptr1); - t5 = _mm256_loadu_ps(src_ptr5); - t2 = _mm256_loadu_ps(src_ptr2); - t6 = _mm256_loadu_ps(src_ptr6); - t3 = _mm256_loadu_ps(src_ptr3); - t7 = _mm256_loadu_ps(src_ptr7); - - r0 = _mm256_unpacklo_ps(t0, t1); - r4 = _mm256_unpacklo_ps(t4, t5); - r2 = _mm256_unpackhi_ps(t0, t1); - r6 = _mm256_unpackhi_ps(t4, t5); - r1 = _mm256_unpacklo_ps(t2, t3); - r5 = _mm256_unpacklo_ps(t6, t7); - r3 = _mm256_unpackhi_ps(t2, t3); - r7 = _mm256_unpackhi_ps(t6, t7); - - t0 = Mm256UnpackloPsx2(r0, r1); - t4 = Mm256UnpackloPsx2(r4, r5); - t2 = Mm256UnpackhiPsx2(r0, r1); - t6 = Mm256UnpackhiPsx2(r4, r5); - t1 = Mm256UnpackloPsx2(r2, r3); - t5 = Mm256UnpackloPsx2(r6, r7); - t3 = Mm256UnpackhiPsx2(r2, r3); - t7 = Mm256UnpackhiPsx2(r6, r7); - - // The preceding sets of rearrangement operations interleaved by 4 bytes - // and then by 8 bytes *within* lanes. The following set interleave by 16 - // bytes (128-bit), operating *between* AVX lanes. For instance (t0, t4) - // are interleaved to create (r0, r1). This complexity follows from the - // way that AVX is centered around MM 128-bit lanes. - r0 = _mm256_permute2f128_ps(t0, t4, 0x20); - r4 = _mm256_permute2f128_ps(t1, t5, 0x20); - r1 = _mm256_permute2f128_ps(t0, t4, 0x31); - r5 = _mm256_permute2f128_ps(t1, t5, 0x31); - r2 = _mm256_permute2f128_ps(t2, t6, 0x20); - r6 = _mm256_permute2f128_ps(t3, t7, 0x20); - r3 = _mm256_permute2f128_ps(t2, t6, 0x31); - r7 = _mm256_permute2f128_ps(t3, t7, 0x31); - - _mm256_storeu_ps(packed_ptr + 0 * 8, r0); - _mm256_storeu_ps(packed_ptr + 2 * 8, r4); - _mm256_storeu_ps(packed_ptr + 4 * 8, r1); - _mm256_storeu_ps(packed_ptr + 6 * 8, r5); - _mm256_storeu_ps(packed_ptr + 1 * 8, r2); - _mm256_storeu_ps(packed_ptr + 3 * 8, r6); - _mm256_storeu_ps(packed_ptr + 5 * 8, r3); - _mm256_storeu_ps(packed_ptr + 7 * 8, r7); - } else if (available_src_rows > 0) { - const __m256i series = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0); - const __m256i row_mask_v = - _mm256_cmpgt_epi32(_mm256_set1_epi32(available_src_rows), series); - - __m256 t0, t1, t2, t3, t4, t5, t6, t7; - __m256 r0, r1, r2, r3, r4, r5, r6, r7; - - t0 = _mm256_maskload_ps(src_ptr0, row_mask_v); - t4 = _mm256_maskload_ps(src_ptr4, row_mask_v); - t1 = _mm256_maskload_ps(src_ptr1, row_mask_v); - t5 = _mm256_maskload_ps(src_ptr5, row_mask_v); - t2 = _mm256_maskload_ps(src_ptr2, row_mask_v); - t6 = _mm256_maskload_ps(src_ptr6, row_mask_v); - t3 = _mm256_maskload_ps(src_ptr3, row_mask_v); - t7 = _mm256_maskload_ps(src_ptr7, row_mask_v); - - r0 = _mm256_unpacklo_ps(t0, t1); - r4 = _mm256_unpacklo_ps(t4, t5); - r2 = _mm256_unpackhi_ps(t0, t1); - r6 = _mm256_unpackhi_ps(t4, t5); - r1 = _mm256_unpacklo_ps(t2, t3); - r5 = _mm256_unpacklo_ps(t6, t7); - r3 = _mm256_unpackhi_ps(t2, t3); - r7 = _mm256_unpackhi_ps(t6, t7); - - t0 = Mm256UnpackloPsx2(r0, r1); - t4 = Mm256UnpackloPsx2(r4, r5); - t2 = Mm256UnpackhiPsx2(r0, r1); - t6 = Mm256UnpackhiPsx2(r4, r5); - t1 = Mm256UnpackloPsx2(r2, r3); - t5 = Mm256UnpackloPsx2(r6, r7); - t3 = Mm256UnpackhiPsx2(r2, r3); - t7 = Mm256UnpackhiPsx2(r6, r7); - - // The preceding sets of rearrangement operations interleaved by 4 bytes - // and then by 8 bytes *within* lanes. The following set interleave by 16 - // bytes (128-bit), operating *between* AVX lanes. For instance (t0, t4) - // are interleaved to create (r0, r1). This complexity follows from the - // way that AVX is centered around MM 128-bit lanes. - r0 = _mm256_permute2f128_ps(t0, t4, 0x20); - r4 = _mm256_permute2f128_ps(t1, t5, 0x20); - r1 = _mm256_permute2f128_ps(t0, t4, 0x31); - r5 = _mm256_permute2f128_ps(t1, t5, 0x31); - r2 = _mm256_permute2f128_ps(t2, t6, 0x20); - r6 = _mm256_permute2f128_ps(t3, t7, 0x20); - r3 = _mm256_permute2f128_ps(t2, t6, 0x31); - // r7 no longer needed. - - _mm256_storeu_ps(trailing_buf + 0 * 8, r0); - _mm256_storeu_ps(trailing_buf + 2 * 8, r4); - _mm256_storeu_ps(trailing_buf + 4 * 8, r1); - _mm256_storeu_ps(trailing_buf + 6 * 8, r5); - _mm256_storeu_ps(trailing_buf + 1 * 8, r2); - _mm256_storeu_ps(trailing_buf + 3 * 8, r6); - _mm256_storeu_ps(trailing_buf + 5 * 8, r3); - // No store to (trailing_buf + 7 * 8), space not allocated. - } - - packed_ptr += kPackRows * kPackCols; - src_ptr0 += src_inc0; - src_ptr1 += src_inc1; - src_ptr2 += src_inc2; - src_ptr3 += src_inc3; - src_ptr4 += src_inc4; - src_ptr5 += src_inc5; - src_ptr6 += src_inc6; - src_ptr7 += src_inc7; - } -} - -} // namespace. - -void Pack8bitAvx2(const std::int8_t* src_ptr, std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, std::int8_t* packed_ptr, - std::int32_t* sums_ptr) { - profiler::ScopeLabel label("Pack kAvx2 8bit"); - - using Layout = PackImpl8bitAvx2::Layout; - RUY_DCHECK_EQ(Layout::kCols, 8); - RUY_DCHECK_EQ(Layout::kRows, 4); - - // Each Layout::Rows is 4 contiguous input, contiguous packed elements. - // We process 8 of these chunks at a time, padding short input chunks. - static constexpr int kNumRowChunks = 8; // Short input is padded. - - // Each packed block is 4*8, and there are normally 8. The trailing block is - // only slightly shorter. - constexpr int kTrailingBufSize = - kNumRowChunks * Layout::kCols * Layout::kRows; - std::int8_t trailing_buf[kTrailingBufSize]; - memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t)); - - Pack8bitAvx2Packer(src_ptr, input_xor, zerobuf, src_stride, - remaining_src_cols, src_rows, packed_ptr, sums_ptr, - trailing_buf); - - constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1; - const bool trailing_data = (src_rows & kChunkedRowMask) > 0; - // If the number of source rows is not a multiple of kChunkedRowMask, there - // will be data in the trailing buffer, - if (trailing_data > 0) { - const int non_trailing_rows = src_rows & ~kChunkedRowMask; - // Destination "rows" are padded to next highest multiple of Layout::kRows. - const int dst_rows = (src_rows + 3) & ~3; - const int trailing_rows = dst_rows - non_trailing_rows; - memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf, - Layout::kCols * trailing_rows * sizeof(std::int8_t)); - } -} - -void PackFloatAvx2(const float* src_ptr, const float* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, float* packed_ptr) { - profiler::ScopeLabel label("Pack kAvx2 float"); - static constexpr int kPackCols = 8; // Source cols packed together. - static constexpr int kPackRows = 8; // Short input is padded. - float trailing_buf[(kPackRows - 1) * kPackCols]; - if (remaining_src_cols < 8) { - memset(trailing_buf, 0, sizeof(trailing_buf)); - } - PackFloatAvx2Packer(src_ptr, zerobuf, src_stride, remaining_src_cols, - src_rows, packed_ptr, trailing_buf); - - const int trailing_rows = src_rows & (kPackRows - 1); - if (trailing_rows > 0) { - const int non_trailing_rows = src_rows & ~(kPackRows - 1); - memcpy(packed_ptr + kPackCols * non_trailing_rows, trailing_buf, - kPackCols * trailing_rows * sizeof(float)); - } -} - -#endif // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS) - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/pack_avx512.cc b/tensorflow/lite/experimental/ruy/ruy/pack_avx512.cc deleted file mode 100644 index d5636572eed..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/pack_avx512.cc +++ /dev/null @@ -1,693 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/pack.h" -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" - -#if RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS) -#include // IWYU pragma: keep -#endif - -namespace ruy { - -#if !(RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)) - -void Pack8bitAvx512(const std::int8_t* src_ptr, std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, - std::int8_t* packed_ptr, std::int32_t* sums_ptr) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -void PackFloatAvx512(const float* src_ptr, const float* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, float* packed_ptr) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -#else // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -// The first int8_t template parameter is arbitrary: this routine is common to -// all 8-bit source matrix types. -using PackImpl8bitAvx512 = - PackImpl, - std::int8_t, std::int8_t, std::int32_t>; - -namespace { - -inline void ZeroHalf8bitAvx512(int src_rows, std::int8_t packed_zero_point, - std::int8_t* packed_ptr) { - using Layout = PackImpl8bitAvx512::Layout; - static constexpr int kHalfLayoutCols = - PackImpl8bitAvx512::kHalfLayoutCols; // Half the number of cols in a - // block. - RUY_DCHECK_EQ(kHalfLayoutCols, 8); - RUY_DCHECK_EQ(Layout::kCols, 16); - RUY_DCHECK_EQ(Layout::kRows, 4); - - const int non_trailing_blocks = (src_rows & ~31) >> 2; - // This routine fills half blocks, and typically fills the second halves. - // Thus packed_ptr is already offset by 8 * 4. - for (int k = 0; k < non_trailing_blocks; ++k) { - for (int j = 0; j < (kHalfLayoutCols * Layout::kRows); ++j) { - packed_ptr[Layout::kCols * Layout::kRows * k + j] = packed_zero_point; - } - } -} - -inline __m512i LoaduTwo(const std::int8_t* addr_lo, - const std::int8_t* addr_hi) { - __m512i lower_filled = _mm512_castsi256_si512(_mm256_loadu_epi8(addr_lo)); - return _mm512_inserti32x8(lower_filled, _mm256_loadu_epi8(addr_hi), 1); -} - -inline __m512i MaskLoaduTwo(__mmask32 row_mask, const __m256i default_value_v, - const std::int8_t* addr_lo, - const std::int8_t* addr_hi) { - const __m512i lower_filled = _mm512_castsi256_si512( - _mm256_mask_loadu_epi8(default_value_v, row_mask, addr_lo)); - return _mm512_inserti32x8( - lower_filled, _mm256_mask_loadu_epi8(default_value_v, row_mask, addr_hi), - 1); -} - -inline void HalfPack8bitAvx512(const std::int8_t* src_ptr, - std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, - std::int8_t* packed_ptr, std::int32_t* sums_ptr, - std::int8_t* trailing_buf) { - using Layout = PackImpl8bitAvx512::Layout; - RUY_DCHECK_EQ(Layout::kCols, 16); - RUY_DCHECK_EQ(Layout::kRows, 4); - // Each Layout::Rows is 4 contiguous input, contiguous packed elements. - // We process 8 of these chunks at a time, padding short input chunks. - constexpr int kNumRowChunks = 8; - constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows; - - const std::int8_t* src_ptr0 = src_ptr; - const std::int8_t* src_ptr1 = src_ptr0 + src_stride; - const std::int8_t* src_ptr2 = src_ptr1 + src_stride; - const std::int8_t* src_ptr3 = src_ptr2 + src_stride; - const std::int8_t* src_ptr4 = src_ptr3 + src_stride; - const std::int8_t* src_ptr5 = src_ptr4 + src_stride; - const std::int8_t* src_ptr6 = src_ptr5 + src_stride; - const std::int8_t* src_ptr7 = src_ptr6 + src_stride; - std::int64_t src_inc0 = kNumChunkedSrcRows; - std::int64_t src_inc1 = kNumChunkedSrcRows; - std::int64_t src_inc2 = kNumChunkedSrcRows; - std::int64_t src_inc3 = kNumChunkedSrcRows; - std::int64_t src_inc4 = kNumChunkedSrcRows; - std::int64_t src_inc5 = kNumChunkedSrcRows; - std::int64_t src_inc6 = kNumChunkedSrcRows; - std::int64_t src_inc7 = kNumChunkedSrcRows; - // Handle cases where source does not have kHalfLayoutCols (8) columns. - if (remaining_src_cols < 8) { - if (remaining_src_cols <= 0) { - src_ptr0 = zerobuf; - src_inc0 = 0; - } - if (remaining_src_cols <= 1) { - src_ptr1 = zerobuf; - src_inc1 = 0; - } - if (remaining_src_cols <= 2) { - src_ptr2 = zerobuf; - src_inc2 = 0; - } - if (remaining_src_cols <= 3) { - src_ptr3 = zerobuf; - src_inc3 = 0; - } - if (remaining_src_cols <= 4) { - src_ptr4 = zerobuf; - src_inc4 = 0; - } - if (remaining_src_cols <= 5) { - src_ptr5 = zerobuf; - src_inc5 = 0; - } - if (remaining_src_cols <= 6) { - src_ptr6 = zerobuf; - src_inc6 = 0; - } - src_ptr7 = zerobuf; - src_inc7 = 0; - } - - const std::int8_t zero_point = zerobuf[0]; - - if (sums_ptr) { - // i: kHalfLayoutCols. - for (int i = 0; i < 8; ++i) { - sums_ptr[i] = 0; - } - } - std::int32_t sums_adjustment = 0; - const __m512i ones_16bit = _mm512_set1_epi16(1); - __m512i sums_8x2_32bit = _mm512_set1_epi32(0); - - // The overall packing effectively pads the source rows to - // (src_rows + 63) & ~63. The iteration over k may skip when m=1, and then we - // only pack for (src_rows + 31) & ~31. When there is an incomplete - // destination block, this is stored into trailing_buf instead of packed_ptr. - for (int k = 0; k < src_rows; k += 2 * kNumChunkedSrcRows) { - // m: {0, 1} for 2 chunks of rows. - for (int m = 0; m < 2; ++m) { - // Available source rows. - // If this is less than 0 (for m=1), we skip, having filled trailing - // buffer for m=0. Also, if source rows is zero on m=1, then we filled - // exactly to the end of the column in the packed buffer. - const int available_src_rows = src_rows - k - m * kNumChunkedSrcRows; - // Effectively, - // available rows = std::max(0, std::min(8, src_rows - k - 8 * 4 * m)); - // treat each case separately. - if (available_src_rows >= kNumChunkedSrcRows) { - // i: chunks, s: Layout::Rows. - if (sums_ptr) { - __m512i t0, t1, t2, t3; - __m512i r0, r1, r2, r3; - const __m512i input_xor_v = _mm512_set1_epi8(input_xor); - - t0 = LoaduTwo(src_ptr0, src_ptr4); - t1 = LoaduTwo(src_ptr1, src_ptr5); - t2 = LoaduTwo(src_ptr2, src_ptr6); - t3 = LoaduTwo(src_ptr3, src_ptr7); - - r0 = _mm512_unpacklo_epi32(t0, t1); - r2 = _mm512_unpackhi_epi32(t0, t1); - r1 = _mm512_unpacklo_epi32(t2, t3); - r3 = _mm512_unpackhi_epi32(t2, t3); - - t0 = _mm512_unpacklo_epi64(r0, r1); - t2 = _mm512_unpackhi_epi64(r0, r1); - t1 = _mm512_unpacklo_epi64(r2, r3); - t3 = _mm512_unpackhi_epi64(r2, r3); - - r0 = _mm512_shuffle_i32x4(t0, t1, 0x88); - r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd); - r2 = _mm512_shuffle_i32x4(t2, t3, 0x88); - r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd); - - r0 = _mm512_xor_si512(r0, input_xor_v); - r1 = _mm512_xor_si512(r1, input_xor_v); - r2 = _mm512_xor_si512(r2, input_xor_v); - r3 = _mm512_xor_si512(r3, input_xor_v); - - const __m256i r0_0 = _mm512_castsi512_si256(r0); - const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1); - const __m256i r1_0 = _mm512_castsi512_si256(r1); - const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1); - const __m256i r2_0 = _mm512_castsi512_si256(r2); - const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1); - const __m256i r3_0 = _mm512_castsi512_si256(r3); - const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1); - - __m512i sums_8x4_16bit; - sums_8x4_16bit = _mm512_cvtepi8_epi16(r0_0); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r0_1)); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_0)); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_1)); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_0)); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_1)); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_0)); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_1)); - // The sums have been performed across columns, and now we have - // 4x16-bit sums packed together. We use madd for pairwise 32-bit - // sums. - const __m512i sums_8x2_32bit_new = - _mm512_madd_epi16(sums_8x4_16bit, ones_16bit); - sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit, sums_8x2_32bit_new); - - _mm256_storeu_epi8(packed_ptr + 0 * 16 * 4, r0_0); - _mm256_storeu_epi8(packed_ptr + 2 * 16 * 4, r0_1); - _mm256_storeu_epi8(packed_ptr + 4 * 16 * 4, r1_0); - _mm256_storeu_epi8(packed_ptr + 6 * 16 * 4, r1_1); - _mm256_storeu_epi8(packed_ptr + 1 * 16 * 4, r2_0); - _mm256_storeu_epi8(packed_ptr + 3 * 16 * 4, r2_1); - _mm256_storeu_epi8(packed_ptr + 5 * 16 * 4, r3_0); - _mm256_storeu_epi8(packed_ptr + 7 * 16 * 4, r3_1); - } else { - __m512i t0, t1, t2, t3; - __m512i r0, r1, r2, r3; - const __m512i input_xor_v = _mm512_set1_epi8(input_xor); - - t0 = LoaduTwo(src_ptr0, src_ptr4); - t1 = LoaduTwo(src_ptr1, src_ptr5); - t2 = LoaduTwo(src_ptr2, src_ptr6); - t3 = LoaduTwo(src_ptr3, src_ptr7); - - r0 = _mm512_unpacklo_epi32(t0, t1); - r2 = _mm512_unpackhi_epi32(t0, t1); - r1 = _mm512_unpacklo_epi32(t2, t3); - r3 = _mm512_unpackhi_epi32(t2, t3); - - t0 = _mm512_unpacklo_epi64(r0, r1); - t2 = _mm512_unpackhi_epi64(r0, r1); - t1 = _mm512_unpacklo_epi64(r2, r3); - t3 = _mm512_unpackhi_epi64(r2, r3); - - r0 = _mm512_shuffle_i32x4(t0, t1, 0x88); - r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd); - r2 = _mm512_shuffle_i32x4(t2, t3, 0x88); - r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd); - - r0 = _mm512_xor_si512(r0, input_xor_v); - r1 = _mm512_xor_si512(r1, input_xor_v); - r2 = _mm512_xor_si512(r2, input_xor_v); - r3 = _mm512_xor_si512(r3, input_xor_v); - - const __m256i r0_0 = _mm512_castsi512_si256(r0); - const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1); - const __m256i r1_0 = _mm512_castsi512_si256(r1); - const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1); - const __m256i r2_0 = _mm512_castsi512_si256(r2); - const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1); - const __m256i r3_0 = _mm512_castsi512_si256(r3); - const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1); - _mm256_storeu_epi8(packed_ptr + 0 * 16 * 4, r0_0); - _mm256_storeu_epi8(packed_ptr + 2 * 16 * 4, r0_1); - _mm256_storeu_epi8(packed_ptr + 4 * 16 * 4, r1_0); - _mm256_storeu_epi8(packed_ptr + 6 * 16 * 4, r1_1); - _mm256_storeu_epi8(packed_ptr + 1 * 16 * 4, r2_0); - _mm256_storeu_epi8(packed_ptr + 3 * 16 * 4, r2_1); - _mm256_storeu_epi8(packed_ptr + 5 * 16 * 4, r3_0); - _mm256_storeu_epi8(packed_ptr + 7 * 16 * 4, r3_1); - } - } else if (available_src_rows > 0) { - RUY_DCHECK_LT(available_src_rows >> 2, kNumChunkedSrcRows); - const __mmask32 row_mask = - (static_cast(1) << available_src_rows) - 1; - - // We do not care what goes into the trailing buffer, but we want - // in_data[...] ^ input_xor == 0 for irrelevant values in the summation. - // - // We compensate for padding-with-zero_point by initializing the - // summations with the compensating offset, effectively - // ((input_xor ^ input_xor) - (zero_point ^ input_xor)) * - // 4 * (8 - ((available_src_rows + 3) >> 2)). - // - // Note that (zero_point ^ input_xor) is performed in 8-bits and then - // cast. - sums_adjustment += -(zero_point ^ input_xor) * 4 * - (8 - ((available_src_rows + 3) >> 2)); - - __m512i t0, t1, t2, t3; - __m512i r0, r1, r2, r3; - const __m512i input_xor_v = _mm512_set1_epi8(input_xor); - const __m256i zero_point_v = _mm256_set1_epi8(zero_point); - - t0 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr0, src_ptr4); - t1 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr1, src_ptr5); - t2 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr2, src_ptr6); - t3 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr3, src_ptr7); - - r0 = _mm512_unpacklo_epi32(t0, t1); - r2 = _mm512_unpackhi_epi32(t0, t1); - r1 = _mm512_unpacklo_epi32(t2, t3); - r3 = _mm512_unpackhi_epi32(t2, t3); - - t0 = _mm512_unpacklo_epi64(r0, r1); - t2 = _mm512_unpackhi_epi64(r0, r1); - t1 = _mm512_unpacklo_epi64(r2, r3); - t3 = _mm512_unpackhi_epi64(r2, r3); - - r0 = _mm512_shuffle_i32x4(t0, t1, 0x88); - r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd); - r2 = _mm512_shuffle_i32x4(t2, t3, 0x88); - r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd); - - r0 = _mm512_xor_si512(r0, input_xor_v); - r1 = _mm512_xor_si512(r1, input_xor_v); - r2 = _mm512_xor_si512(r2, input_xor_v); - r3 = _mm512_xor_si512(r3, input_xor_v); - - const __m256i r0_0 = _mm512_castsi512_si256(r0); - const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1); - const __m256i r1_0 = _mm512_castsi512_si256(r1); - const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1); - const __m256i r2_0 = _mm512_castsi512_si256(r2); - const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1); - const __m256i r3_0 = _mm512_castsi512_si256(r3); - const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1); - - __m512i sums_8x4_16bit; - sums_8x4_16bit = _mm512_cvtepi8_epi16(r0_0); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r0_1)); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_0)); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_1)); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_0)); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_1)); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_0)); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_1)); - // The sums have been performed across columns, and now we have - // 4x16-bit sums packed together. We use madd for pairwise 32-bit - // sums. - const __m512i sums_8x2_32bit_new = - _mm512_madd_epi16(sums_8x4_16bit, ones_16bit); - sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit, sums_8x2_32bit_new); - - _mm256_storeu_epi8(trailing_buf + 0 * 16 * 4, r0_0); - _mm256_storeu_epi8(trailing_buf + 2 * 16 * 4, r0_1); - _mm256_storeu_epi8(trailing_buf + 4 * 16 * 4, r1_0); - _mm256_storeu_epi8(trailing_buf + 6 * 16 * 4, r1_1); - _mm256_storeu_epi8(trailing_buf + 1 * 16 * 4, r2_0); - _mm256_storeu_epi8(trailing_buf + 3 * 16 * 4, r2_1); - _mm256_storeu_epi8(trailing_buf + 5 * 16 * 4, r3_0); - _mm256_storeu_epi8(trailing_buf + 7 * 16 * 4, r3_1); - } - - packed_ptr += 16 * kNumChunkedSrcRows; - src_ptr0 += src_inc0; - src_ptr1 += src_inc1; - src_ptr2 += src_inc2; - src_ptr3 += src_inc3; - src_ptr4 += src_inc4; - src_ptr5 += src_inc5; - src_ptr6 += src_inc6; - src_ptr7 += src_inc7; - } - } - - if (sums_ptr) { - const __m256i sums_adjustment_v = _mm256_set1_epi32(sums_adjustment); - - __m256i sums = _mm256_loadu_epi32(sums_ptr); - const __m512i idx = - _mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0); - - // We earlier used madd for pairwise 32-bit sums, and now we deinterlace the - // neighbours, finshing up by adding them to the stored accumulated sums. - const __m512i sums_2x8_32bit = - _mm512_permutexvar_epi32(idx, sums_8x2_32bit); - sums = _mm256_add_epi32(sums, sums_adjustment_v); - sums = _mm256_add_epi32(sums, _mm512_castsi512_si256(sums_2x8_32bit)); - sums = _mm256_add_epi32(sums, _mm512_extracti32x8_epi32(sums_2x8_32bit, 1)); - - _mm256_storeu_epi32(sums_ptr, sums); - } -} - -inline __m512 LoaduTwo(const float* addr_lo, const float* addr_hi) { - const __m512 lower_filled = _mm512_castps256_ps512(_mm256_loadu_ps(addr_lo)); - return _mm512_insertf32x8(lower_filled, _mm256_loadu_ps(addr_hi), 1); -} - -inline __m512 MaskLoaduTwo(__mmask8 row_mask, const float* addr_lo, - const float* addr_hi) { - const __m512 lower_filled = - _mm512_castps256_ps512(_mm256_maskz_loadu_ps(row_mask, addr_lo)); - return _mm512_insertf32x8(lower_filled, - _mm256_maskz_loadu_ps(row_mask, addr_hi), 1); -} - -inline __m512 Mm512UnpackloPsx2(const __m512 a, const __m512 b) { - return _mm512_castpd_ps( - _mm512_unpacklo_pd(_mm512_castps_pd(a), _mm512_castps_pd(b))); -} - -inline __m512 Mm512UnpackhiPsx2(const __m512 a, const __m512 b) { - return _mm512_castpd_ps( - _mm512_unpackhi_pd(_mm512_castps_pd(a), _mm512_castps_pd(b))); -} - -inline void HalfPackFloatAvx512(const float* src_ptr, const float* zerobuf, - int src_stride, int remaining_src_cols, - int src_rows, float* packed_ptr, - float* trailing_buf) { - const float* src_ptr0 = src_ptr; - const float* src_ptr1 = src_ptr0 + src_stride; - const float* src_ptr2 = src_ptr1 + src_stride; - const float* src_ptr3 = src_ptr2 + src_stride; - const float* src_ptr4 = src_ptr3 + src_stride; - const float* src_ptr5 = src_ptr4 + src_stride; - const float* src_ptr6 = src_ptr5 + src_stride; - const float* src_ptr7 = src_ptr6 + src_stride; - std::int64_t src_inc0 = 8; - std::int64_t src_inc1 = 8; - std::int64_t src_inc2 = 8; - std::int64_t src_inc3 = 8; - std::int64_t src_inc4 = 8; - std::int64_t src_inc5 = 8; - std::int64_t src_inc6 = 8; - std::int64_t src_inc7 = 8; - if (remaining_src_cols < 8) { - if (remaining_src_cols <= 0) { - src_ptr0 = zerobuf; - src_inc0 = 0; - } - if (remaining_src_cols <= 1) { - src_ptr1 = zerobuf; - src_inc1 = 0; - } - if (remaining_src_cols <= 2) { - src_ptr2 = zerobuf; - src_inc2 = 0; - } - if (remaining_src_cols <= 3) { - src_ptr3 = zerobuf; - src_inc3 = 0; - } - if (remaining_src_cols <= 4) { - src_ptr4 = zerobuf; - src_inc4 = 0; - } - if (remaining_src_cols <= 5) { - src_ptr5 = zerobuf; - src_inc5 = 0; - } - if (remaining_src_cols <= 6) { - src_ptr6 = zerobuf; - src_inc6 = 0; - } - src_ptr7 = zerobuf; - src_inc7 = 0; - } - - for (int k = 0; k < src_rows; k += 16) { - for (int m = 0; m < 2; ++m) { - const int available_src_rows = src_rows - k - 8 * m; - // Effectively, - // available_src_rows = std::max(0, std::min(8, src_rows - k - 8 * m)); - // but treat each case separately. - if (available_src_rows > 7) { - __m512 t0, t1, t2, t3; - __m512 r0, r1, r2, r3; - - t0 = LoaduTwo(src_ptr0, src_ptr4); - t1 = LoaduTwo(src_ptr1, src_ptr5); - t2 = LoaduTwo(src_ptr2, src_ptr6); - t3 = LoaduTwo(src_ptr3, src_ptr7); - - r0 = _mm512_unpacklo_ps(t0, t1); - r2 = _mm512_unpackhi_ps(t0, t1); - r1 = _mm512_unpacklo_ps(t2, t3); - r3 = _mm512_unpackhi_ps(t2, t3); - - t0 = Mm512UnpackloPsx2(r0, r1); - t2 = Mm512UnpackhiPsx2(r0, r1); - t1 = Mm512UnpackloPsx2(r2, r3); - t3 = Mm512UnpackhiPsx2(r2, r3); - - r0 = _mm512_shuffle_f32x4(t0, t1, 0x88); - r1 = _mm512_shuffle_f32x4(t0, t1, 0xdd); - r2 = _mm512_shuffle_f32x4(t2, t3, 0x88); - r3 = _mm512_shuffle_f32x4(t2, t3, 0xdd); - - _mm256_storeu_ps(packed_ptr + 0 * 16, _mm512_castps512_ps256(r0)); - _mm256_storeu_ps(packed_ptr + 2 * 16, _mm512_extractf32x8_ps(r0, 1)); - _mm256_storeu_ps(packed_ptr + 4 * 16, _mm512_castps512_ps256(r1)); - _mm256_storeu_ps(packed_ptr + 6 * 16, _mm512_extractf32x8_ps(r1, 1)); - _mm256_storeu_ps(packed_ptr + 1 * 16, _mm512_castps512_ps256(r2)); - _mm256_storeu_ps(packed_ptr + 3 * 16, _mm512_extractf32x8_ps(r2, 1)); - _mm256_storeu_ps(packed_ptr + 5 * 16, _mm512_castps512_ps256(r3)); - _mm256_storeu_ps(packed_ptr + 7 * 16, _mm512_extractf32x8_ps(r3, 1)); - } else if (available_src_rows > 0) { - const __mmask8 row_mask = - (static_cast(1) << available_src_rows) - 1; - - __m512 t0, t1, t2, t3; - __m512 r0, r1, r2, r3; - - t0 = MaskLoaduTwo(row_mask, src_ptr0, src_ptr4); - t1 = MaskLoaduTwo(row_mask, src_ptr1, src_ptr5); - t2 = MaskLoaduTwo(row_mask, src_ptr2, src_ptr6); - t3 = MaskLoaduTwo(row_mask, src_ptr3, src_ptr7); - - r0 = _mm512_unpacklo_ps(t0, t1); - r2 = _mm512_unpackhi_ps(t0, t1); - r1 = _mm512_unpacklo_ps(t2, t3); - r3 = _mm512_unpackhi_ps(t2, t3); - - t0 = Mm512UnpackloPsx2(r0, r1); - t2 = Mm512UnpackhiPsx2(r0, r1); - t1 = Mm512UnpackloPsx2(r2, r3); - t3 = Mm512UnpackhiPsx2(r2, r3); - - r0 = _mm512_shuffle_f32x4(t0, t1, 0x88); - r1 = _mm512_shuffle_f32x4(t0, t1, 0xdd); - r2 = _mm512_shuffle_f32x4(t2, t3, 0x88); - r3 = _mm512_shuffle_f32x4(t2, t3, 0xdd); - - _mm256_storeu_ps(trailing_buf + 0 * 16, _mm512_castps512_ps256(r0)); - _mm256_storeu_ps(trailing_buf + 2 * 16, _mm512_extractf32x8_ps(r0, 1)); - _mm256_storeu_ps(trailing_buf + 4 * 16, _mm512_castps512_ps256(r1)); - _mm256_storeu_ps(trailing_buf + 6 * 16, _mm512_extractf32x8_ps(r1, 1)); - _mm256_storeu_ps(trailing_buf + 1 * 16, _mm512_castps512_ps256(r2)); - _mm256_storeu_ps(trailing_buf + 3 * 16, _mm512_extractf32x8_ps(r2, 1)); - _mm256_storeu_ps(trailing_buf + 5 * 16, _mm512_castps512_ps256(r3)); - // Do not store _mm512_extractf32x8_ps(r3, 1). - } - - packed_ptr += 16 * 8; - src_ptr0 += src_inc0; - src_ptr1 += src_inc1; - src_ptr2 += src_inc2; - src_ptr3 += src_inc3; - src_ptr4 += src_inc4; - src_ptr5 += src_inc5; - src_ptr6 += src_inc6; - src_ptr7 += src_inc7; - } - } -} - -inline void ZeroHalfFloatAvx512(int src_rows, float* packed_ptr) { - const int non_trailing_rows = src_rows & ~7; - for (int k = 0; k < non_trailing_rows; ++k) { - for (int j = 0; j < 8; ++j) { - packed_ptr[j] = 0.0f; - } - packed_ptr += 16; - } -} - -} // namespace. - -void Pack8bitAvx512(const std::int8_t* src_ptr, std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, - std::int8_t* packed_ptr, std::int32_t* sums_ptr) { - profiler::ScopeLabel label("Pack kAvx512 8bit"); - - using Layout = PackImpl8bitAvx512::Layout; - constexpr int kHalfBlockOffset = 32; - RUY_DCHECK_EQ(kHalfBlockOffset * 2, Layout::kRows * Layout::kCols); - static constexpr int kHalfLayoutCols = - PackImpl8bitAvx512::kHalfLayoutCols; // Half the number of cols in a - // block. - RUY_DCHECK_EQ(kHalfLayoutCols, 8); - RUY_DCHECK_EQ(Layout::kCols, 16); - RUY_DCHECK_EQ(Layout::kRows, 4); - - // Each Layout::Rows is 4 contiguous input, contiguous packed elements. - // We process 8 of these chunks at a time, padding short input chunks. - constexpr int kNumRowChunks = 8; - - // Each packed block is 4*16, and there are normally 8. The trailing block is - // only slightly shorter. - constexpr int kTrailingBufSize = - kNumRowChunks * Layout::kCols * Layout::kRows; - std::int8_t trailing_buf[kTrailingBufSize]; - memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t)); - - std::int32_t* second_sums_ptr = - sums_ptr ? sums_ptr + kHalfLayoutCols : nullptr; - if (remaining_src_cols > kHalfLayoutCols) { - HalfPack8bitAvx512(src_ptr, input_xor, zerobuf, src_stride, - remaining_src_cols, src_rows, packed_ptr, sums_ptr, - trailing_buf); - HalfPack8bitAvx512(src_ptr + src_stride * kHalfLayoutCols, input_xor, - zerobuf, src_stride, - remaining_src_cols - kHalfLayoutCols, src_rows, - packed_ptr + kHalfBlockOffset, second_sums_ptr, - trailing_buf + kHalfBlockOffset); - } else { - HalfPack8bitAvx512(src_ptr, input_xor, zerobuf, src_stride, - remaining_src_cols, src_rows, packed_ptr, sums_ptr, - trailing_buf); - ZeroHalf8bitAvx512(src_rows, zerobuf[0] ^ input_xor, - packed_ptr + kHalfBlockOffset); - // The kernel may not need the second half-blocks sums to be set. - if (second_sums_ptr) { - for (int i = 0; i < kHalfLayoutCols; ++i) { - second_sums_ptr[i] = (zerobuf[0] ^ input_xor) * ((src_rows + 3) & ~3); - } - } - } - constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1; - const bool trailing_data = (src_rows & kChunkedRowMask) > 0; - // If the number of source rows is not a multiple of kChunkedRowMask, there - // will be data in the trailing buffer, - if (trailing_data > 0) { - const int non_trailing_rows = src_rows & ~kChunkedRowMask; - // Destination "rows" are padded to next highest multiple of Layout::kRows. - const int dst_rows = (src_rows + 3) & ~3; - const int trailing_rows = dst_rows - non_trailing_rows; - memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf, - Layout::kCols * trailing_rows * sizeof(std::int8_t)); - } -} - -void PackFloatAvx512(const float* src_ptr, const float* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, float* packed_ptr) { - profiler::ScopeLabel label("Pack kAvx512 float"); - float trailing_buf[7 * 16]; - if (remaining_src_cols > 8) { - HalfPackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols, - src_rows, packed_ptr, trailing_buf); - HalfPackFloatAvx512(src_ptr + src_stride * 8, zerobuf, src_stride, - remaining_src_cols - 8, src_rows, packed_ptr + 8, - trailing_buf + 8); - } else { - memset(trailing_buf, 0, sizeof(trailing_buf)); - HalfPackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols, - src_rows, packed_ptr, trailing_buf); - ZeroHalfFloatAvx512(src_rows, packed_ptr + 8); - } - const int trailing_rows = src_rows & 7; - if (trailing_rows > 0) { - const int non_trailing_rows = src_rows & ~7; - memcpy(packed_ptr + 16 * non_trailing_rows, trailing_buf, - 16 * trailing_rows * sizeof(float)); - } -} - -#endif // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS) - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/pack_avxvnni.cc b/tensorflow/lite/experimental/ruy/ruy/pack_avxvnni.cc deleted file mode 100644 index 49b4a1f978c..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/pack_avxvnni.cc +++ /dev/null @@ -1,478 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/pack.h" -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" - -#if RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS) -#include // IWYU pragma: keep -#endif - -namespace ruy { - -#if !(RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM)) - -void Pack8bitAvxVnni(const std::int8_t* src_ptr, std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, - std::int8_t* packed_ptr, std::int32_t* sums_ptr) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -void PackFloatAvxVnni(const float* src_ptr, const float* zerobuf, - int src_stride, int remaining_src_cols, int src_rows, - float* packed_ptr) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -#else // RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -// The first int8_t template parameter is arbitrary: this routine is common to -// all 8-bit source matrix types. -using PackImpl8bitAvxVnni = - PackImpl, - std::int8_t, std::int8_t, std::int32_t>; - -namespace { - -inline void ZeroHalf8bitAvxVnni(int src_rows, std::int8_t packed_zero_point, - std::int8_t* packed_ptr) { - const int non_trailing_blocks = (src_rows & ~31) >> 2; - // This routine fills half blocks, and typically fills the second halves. Thus - // packed_ptr is already offset by 8*4. - for (int k = 0; k < non_trailing_blocks; ++k) { - for (int j = 0; j < (8 * 4); ++j) { - packed_ptr[16 * 4 * k + j] = packed_zero_point; - } - } -} - -inline void HalfPack8bitAvxVnni(const std::int8_t* src_ptr, - std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, - std::int8_t* packed_ptr, std::int32_t* sums_ptr, - std::int8_t* trailing_buf) { - std::int8_t in_data[8][8][4]; - - const std::int8_t* src_ptr0 = src_ptr; - const std::int8_t* src_ptr1 = src_ptr0 + src_stride; - const std::int8_t* src_ptr2 = src_ptr1 + src_stride; - const std::int8_t* src_ptr3 = src_ptr2 + src_stride; - const std::int8_t* src_ptr4 = src_ptr3 + src_stride; - const std::int8_t* src_ptr5 = src_ptr4 + src_stride; - const std::int8_t* src_ptr6 = src_ptr5 + src_stride; - const std::int8_t* src_ptr7 = src_ptr6 + src_stride; - std::int64_t src_inc0 = 8 * 4; - std::int64_t src_inc1 = 8 * 4; - std::int64_t src_inc2 = 8 * 4; - std::int64_t src_inc3 = 8 * 4; - std::int64_t src_inc4 = 8 * 4; - std::int64_t src_inc5 = 8 * 4; - std::int64_t src_inc6 = 8 * 4; - std::int64_t src_inc7 = 8 * 4; - if (remaining_src_cols < 8) { - if (remaining_src_cols <= 0) { - src_ptr0 = zerobuf; - src_inc0 = 0; - } - if (remaining_src_cols <= 1) { - src_ptr1 = zerobuf; - src_inc1 = 0; - } - if (remaining_src_cols <= 2) { - src_ptr2 = zerobuf; - src_inc2 = 0; - } - if (remaining_src_cols <= 3) { - src_ptr3 = zerobuf; - src_inc3 = 0; - } - if (remaining_src_cols <= 4) { - src_ptr4 = zerobuf; - src_inc4 = 0; - } - if (remaining_src_cols <= 5) { - src_ptr5 = zerobuf; - src_inc5 = 0; - } - if (remaining_src_cols <= 6) { - src_ptr6 = zerobuf; - src_inc6 = 0; - } - src_ptr7 = zerobuf; - src_inc7 = 0; - } - - const std::int8_t zero_point = zerobuf[0]; - - if (sums_ptr) { - for (int i = 0; i < 8; ++i) { - sums_ptr[i] = 0; - } - } - - // The overall packing effectively pads the source rows to - // (src_rows + 63) & ~63. The iteration over k may skip when m=1, and then we - // only pack for (src_rows + 31) & ~31. When there is an incomplete - // destination block, this is stored into trailing_buf instead of packed_ptr. - for (int k = 0; k < src_rows; k += 16 * 4) { - for (int m = 0; m < 2; ++m) { - // Available source rows. - // If this is less than 0 (for m=1), we skip, having filled trailing - // buffer for m=0. Also, if source rows is zero on m=1, then we filled - // exactly to the end of the column in the packed buffer. - const int packed_rows = src_rows - k - 8 * m * 4; - // Effectively, - // packed_rows = std::max(0, std::min(8, src_rows - k - 8 * m)); - // but treat each case separately. - if (packed_rows >= (8 * 4)) { - for (int i = 0; i < 8; ++i) { - for (int s = 0; s < 4; ++s) { - in_data[0][i][s] = src_ptr0[i * 4 + s]; - in_data[1][i][s] = src_ptr1[i * 4 + s]; - in_data[2][i][s] = src_ptr2[i * 4 + s]; - in_data[3][i][s] = src_ptr3[i * 4 + s]; - in_data[4][i][s] = src_ptr4[i * 4 + s]; - in_data[5][i][s] = src_ptr5[i * 4 + s]; - in_data[6][i][s] = src_ptr6[i * 4 + s]; - in_data[7][i][s] = src_ptr7[i * 4 + s]; - } - } - for (int i = 0; i < 8; ++i) { - for (int j = 0; j < 8; ++j) { - for (int s = 0; s < 4; ++s) { - packed_ptr[(16 * i + j) * 4 + s] = - static_cast(in_data[j][i][s] ^ input_xor); - } - if (sums_ptr) { - for (int s = 0; s < 4; ++s) { - sums_ptr[j] += in_data[j][i][s] ^ input_xor; - } - } - } - } - } else if (packed_rows > 0) { - RUY_DCHECK_LT(packed_rows >> 2, 8); - int i = 0; - for (; i < (packed_rows >> 2); ++i) { - for (int s = 0; s < 4; ++s) { - in_data[0][i][s] = src_ptr0[i * 4 + s]; - in_data[1][i][s] = src_ptr1[i * 4 + s]; - in_data[2][i][s] = src_ptr2[i * 4 + s]; - in_data[3][i][s] = src_ptr3[i * 4 + s]; - in_data[4][i][s] = src_ptr4[i * 4 + s]; - in_data[5][i][s] = src_ptr5[i * 4 + s]; - in_data[6][i][s] = src_ptr6[i * 4 + s]; - in_data[7][i][s] = src_ptr7[i * 4 + s]; - } - } - if (i < ((packed_rows + 3) >> 2)) { - int s = 0; - for (; s < (packed_rows & 3); ++s) { - in_data[0][i][s] = src_ptr0[i * 4 + s]; - in_data[1][i][s] = src_ptr1[i * 4 + s]; - in_data[2][i][s] = src_ptr2[i * 4 + s]; - in_data[3][i][s] = src_ptr3[i * 4 + s]; - in_data[4][i][s] = src_ptr4[i * 4 + s]; - in_data[5][i][s] = src_ptr5[i * 4 + s]; - in_data[6][i][s] = src_ptr6[i * 4 + s]; - in_data[7][i][s] = src_ptr7[i * 4 + s]; - } - RUY_DCHECK_LE(s, 4); - for (; s < 4; ++s) { - for (int j = 0; j < 8; ++j) { - in_data[j][i][s] = zero_point; - } - } - ++i; - } - // We do not care what goes into the trailing buffer, but we want - // in_data[...] ^ input_xor == 0 for irrelevant values in the summation. - // - // It might prove better in optimized code to pad uniformly with - // zero_point, and compensate by initializing the summations with the - // compensating offset, effectively - // ((input_xor - zero_point) ^ input_xor) * - // 4 * (8 - ((packed_rows + 3) >> 2)). - for (; i < 8; ++i) { - for (int s = 0; s < 4; ++s) { - for (int j = 0; j < 8; ++j) { - in_data[j][i][s] = input_xor; - } - } - } - // We loop through [0, 8) rather than [0, (packed_rows + 3) >> 2), since - // that emulates what we might do in fully-optimized code. - if (sums_ptr) { - for (int i = 0; i < 8; ++i) { - for (int j = 0; j < 8; ++j) { - for (int s = 0; s < 4; ++s) { - trailing_buf[(16 * i + j) * 4 + s] = - static_cast(in_data[j][i][s] ^ input_xor); - sums_ptr[j] += in_data[j][i][s] ^ input_xor; - } - } - } - } else { - for (int i = 0; i < 8; ++i) { - for (int j = 0; j < 8; ++j) { - for (int s = 0; s < 4; ++s) { - trailing_buf[(16 * i + j) * 4 + s] = - static_cast(in_data[j][i][s] ^ input_xor); - } - } - } - } - } - - packed_ptr += 16 * 8 * 4; - src_ptr0 += src_inc0; - src_ptr1 += src_inc1; - src_ptr2 += src_inc2; - src_ptr3 += src_inc3; - src_ptr4 += src_inc4; - src_ptr5 += src_inc5; - src_ptr6 += src_inc6; - src_ptr7 += src_inc7; - } - } -} - -inline void HalfPackFloatAvxVnni(const float* src_ptr, const float* zerobuf, - int src_stride, int remaining_src_cols, - int src_rows, float* packed_ptr, - float* trailing_buf) { - float in_data[8][8]; - - const float* src_ptr0 = src_ptr; - const float* src_ptr1 = src_ptr0 + src_stride; - const float* src_ptr2 = src_ptr1 + src_stride; - const float* src_ptr3 = src_ptr2 + src_stride; - const float* src_ptr4 = src_ptr3 + src_stride; - const float* src_ptr5 = src_ptr4 + src_stride; - const float* src_ptr6 = src_ptr5 + src_stride; - const float* src_ptr7 = src_ptr6 + src_stride; - std::int64_t src_inc0 = 8; - std::int64_t src_inc1 = 8; - std::int64_t src_inc2 = 8; - std::int64_t src_inc3 = 8; - std::int64_t src_inc4 = 8; - std::int64_t src_inc5 = 8; - std::int64_t src_inc6 = 8; - std::int64_t src_inc7 = 8; - if (remaining_src_cols < 8) { - if (remaining_src_cols <= 0) { - src_ptr0 = zerobuf; - src_inc0 = 0; - } - if (remaining_src_cols <= 1) { - src_ptr1 = zerobuf; - src_inc1 = 0; - } - if (remaining_src_cols <= 2) { - src_ptr2 = zerobuf; - src_inc2 = 0; - } - if (remaining_src_cols <= 3) { - src_ptr3 = zerobuf; - src_inc3 = 0; - } - if (remaining_src_cols <= 4) { - src_ptr4 = zerobuf; - src_inc4 = 0; - } - if (remaining_src_cols <= 5) { - src_ptr5 = zerobuf; - src_inc5 = 0; - } - if (remaining_src_cols <= 6) { - src_ptr6 = zerobuf; - src_inc6 = 0; - } - src_ptr7 = zerobuf; - src_inc7 = 0; - } - - for (int k = 0; k < src_rows; k += 16) { - for (int m = 0; m < 2; ++m) { - const int packed_rows = src_rows - k - 8 * m; - // Effectively, - // packed_rows = std::max(0, std::min(8, src_rows - k - 8 * m)); - // but treat each case separately. - if (packed_rows > 7) { - for (int i = 0; i < 8; ++i) { - in_data[0][i] = src_ptr0[i]; - in_data[1][i] = src_ptr1[i]; - in_data[2][i] = src_ptr2[i]; - in_data[3][i] = src_ptr3[i]; - in_data[4][i] = src_ptr4[i]; - in_data[5][i] = src_ptr5[i]; - in_data[6][i] = src_ptr6[i]; - in_data[7][i] = src_ptr7[i]; - } - for (int i = 0; i < 8; ++i) { - for (int j = 0; j < 8; ++j) { - packed_ptr[16 * i + j] = in_data[j][i]; - } - } - } else if (packed_rows > 0) { - for (int i = 0; i < packed_rows; ++i) { - in_data[0][i] = src_ptr0[i]; - in_data[1][i] = src_ptr1[i]; - in_data[2][i] = src_ptr2[i]; - in_data[3][i] = src_ptr3[i]; - in_data[4][i] = src_ptr4[i]; - in_data[5][i] = src_ptr5[i]; - in_data[6][i] = src_ptr6[i]; - in_data[7][i] = src_ptr7[i]; - } - for (int i = packed_rows; i < 8; ++i) { - in_data[0][i] = 0.0f; - in_data[1][i] = 0.0f; - in_data[2][i] = 0.0f; - in_data[3][i] = 0.0f; - in_data[4][i] = 0.0f; - in_data[5][i] = 0.0f; - in_data[6][i] = 0.0f; - in_data[7][i] = 0.0f; - } - // We loop through [0, 7) rather than [0, packed_rows), since that - // emulates what we might do in fully-optimized code. - for (int i = 0; i < 7; ++i) { - for (int j = 0; j < 8; ++j) { - trailing_buf[16 * i + j] = in_data[j][i]; - } - } - } - - packed_ptr += 16 * 8; - src_ptr0 += src_inc0; - src_ptr1 += src_inc1; - src_ptr2 += src_inc2; - src_ptr3 += src_inc3; - src_ptr4 += src_inc4; - src_ptr5 += src_inc5; - src_ptr6 += src_inc6; - src_ptr7 += src_inc7; - } - } -} - -inline void ZeroHalfFloatAvxVnni(int src_rows, float* packed_ptr) { - const int non_trailing_rows = src_rows & ~7; - for (int k = 0; k < non_trailing_rows; ++k) { - for (int j = 0; j < 8; ++j) { - packed_ptr[j] = 0.0f; - } - packed_ptr += 16; - } -} - -} // namespace. - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// When removing this comment, update profiling label below. -void Pack8bitAvxVnni(const std::int8_t* src_ptr, std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, - std::int8_t* packed_ptr, std::int32_t* sums_ptr) { - profiler::ScopeLabel label("Pack kAvxVnni 8bit (UNFINISHED)"); - - // Each packed block is 4*16, and there are normally 8. The trailing block is - // only slightly shorter. - std::int8_t trailing_buf[8 * 16 * 4]; - memset(trailing_buf, 0, 8 * 16 * 4 * sizeof(std::int8_t)); - - std::int32_t* second_sums_ptr = sums_ptr ? sums_ptr + 8 : nullptr; - if (remaining_src_cols > 8) { - HalfPack8bitAvxVnni(src_ptr, input_xor, zerobuf, src_stride, - remaining_src_cols, src_rows, packed_ptr, sums_ptr, - trailing_buf); - HalfPack8bitAvxVnni(src_ptr + src_stride * 8, input_xor, zerobuf, - src_stride, remaining_src_cols - 8, src_rows, - packed_ptr + 8 * 4, second_sums_ptr, - trailing_buf + 8 * 4); - } else { - HalfPack8bitAvxVnni(src_ptr, input_xor, zerobuf, src_stride, - remaining_src_cols, src_rows, packed_ptr, sums_ptr, - trailing_buf); - ZeroHalf8bitAvxVnni(src_rows, zerobuf[0] ^ input_xor, packed_ptr + 8 * 4); - // The kernel may not need the second half-blocks sums to be set. - if (second_sums_ptr) { - for (int i = 0; i < 8; ++i) { - second_sums_ptr[i] = (zerobuf[0] ^ input_xor) * ((src_rows + 3) & ~3); - } - } - } - const bool trailing_data = (src_rows & 31) > 0; - // If the number of source rows is not a multiple of 32, there will be data in - // the trailing buffer, - if (trailing_data > 0) { - const int non_trailing_rows = src_rows & ~31; - // Destination "rows" are padded to next highest multiple of 4. - const int dst_rows = (src_rows + 3) & ~3; - const int trailing_rows = dst_rows - non_trailing_rows; - memcpy(packed_ptr + 16 * non_trailing_rows, trailing_buf, - 16 * trailing_rows * sizeof(std::int8_t)); - } -} - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// When removing this comment, update profiling label below. -void PackFloatAvxVnni(const float* src_ptr, const float* zerobuf, - int src_stride, int remaining_src_cols, int src_rows, - float* packed_ptr) { - profiler::ScopeLabel label("Pack kAvxVnni float (UNFINISHED)"); - float trailing_buf[7 * 16]; - if (remaining_src_cols > 8) { - HalfPackFloatAvxVnni(src_ptr, zerobuf, src_stride, remaining_src_cols, - src_rows, packed_ptr, trailing_buf); - HalfPackFloatAvxVnni(src_ptr + src_stride * 8, zerobuf, src_stride, - remaining_src_cols - 8, src_rows, packed_ptr + 8, - trailing_buf + 8); - } else { - memset(trailing_buf, 0, sizeof(trailing_buf)); - HalfPackFloatAvxVnni(src_ptr, zerobuf, src_stride, remaining_src_cols, - src_rows, packed_ptr, trailing_buf); - ZeroHalfFloatAvxVnni(src_rows, packed_ptr + 8); - } - const int trailing_rows = src_rows & 7; - if (trailing_rows > 0) { - const int non_trailing_rows = src_rows & ~7; - memcpy(packed_ptr + 16 * non_trailing_rows, trailing_buf, - 16 * trailing_rows * sizeof(float)); - } -} - -#endif // RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS) - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/pack_common.h b/tensorflow/lite/experimental/ruy/ruy/pack_common.h deleted file mode 100644 index 91d47af8a5f..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/pack_common.h +++ /dev/null @@ -1,246 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// # What is "packing"? -// -// Before feeding data to the gemm kernels (the parts of Ruy that do lots -// of multiply-add operations), Ruy first performs a data transformation (which -// we call "packing") on the input matrices. This transformation has two main -// goals: -// - rearrange data into blocks that are a convenient size/layout for the gemm -// kernels to consume. This helps make the memory access pattern of the gemm -// kernel simpler and more contiguous, and puts the data in a layout most -// convenient for specific arithmetic instructions in the gemm kernel. -// - compute row/column sums needed for handling quantization with non-symmetric -// zero points. -// -// # Simplified algorithmic analysis of packing -// -// Packing is a relatively simple transformation which does a small constant -// amount of work on each element of an input matrix, and hence for an NxM -// matrix performs O(N*M) work. If N and M are of the same order, then this is -// O(N^2) work. -// -// A NxKxM matrix multiplication requires N*K*M multiply-accumulate operations. -// Note that if N, K, and M are all the same order, then the number of -// multiply-accumulate operations is O(N^3). -// -// Thus, the O(N^2) cost of packing is small compared to the O(N^3) work, in the -// case of all dimensions being roughly the same order. -// -// # Packing cost can be significant -// -// When matrix * matrix multiplications begin to look more like matrix * vector -// multiplications, packing cost can become significant. We sometimes call these -// cases "gemv-like". -// -// Continuing the algorithmic analysis above, if we consider a case where an -// NxKxM matrix multiplication has either N = O(1) or M = O(1), then the -// situation is different. In this case, the multiply-accumulate work is only -// quadratic, so the quadratic cost of packing can be come significant. -// -// Another way to say this is that the cost of packing an input matrix (either -// the LHS or RHS) is amortized across the non-depth dimension of the opposite -// input matrix. Thus, when the LHS has very few rows or the RHS has very few -// columns, the cost of packing the opposite input matrix can become -// significant. -// -// As a rough rule of thumb, the cost of packing starts to become significant -// when either N or M is below 32 (and other dimensions are hundreds), with very -// significant packing costs at 8 or below. This varies by data type, Path, and -// tuning, so these numbers are only rough guides. -// -// One practical use case that is affected by this is inference of -// fully connected neural network layers with a low batch size. The weight -// matrix (which is a constant for inference) is the one affected by significant -// packing cost. -// -// Ruy provides an API in ruy_advanced.h for advanced users to pre-pack -// input matrices that are affected by significant packing costs. -// -// # Implementation notes -// -// Ruy's packing routines always operate on a range of columns and can be -// applied to either the LHS or RHS. This is possible because Ruy internally -// implements a TrMul, so the accumulation along depth is done along columns of -// both the LHS and RHS (whereas for a normal Mul the accumulation along depth -// for the LHS is along rows). As another example, we are always computing -// column sums for quantization (and never row sums, since the LHS is -// transposed). - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_COMMON_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_COMMON_H_ - -#include - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/common.h" -#include "tensorflow/lite/experimental/ruy/ruy/internal_matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" -#include "tensorflow/lite/experimental/ruy/ruy/tune.h" - -namespace ruy { - -template -struct PackedTypeImpl { - using Type = Scalar; -}; - -#if RUY_PLATFORM(NEON_32) -struct PackParams8bit { - const void* src_ptr0; - const void* src_ptr1; - const void* src_ptr2; - const void* src_ptr3; - const std::int32_t* sums_ptr; - const std::int8_t* packed_ptr; - int src_inc0; - int src_inc1; - int src_inc2; - int src_inc3; - int src_rows; - int src_zero_point; - int input_xor; -}; - -inline void MakePackParams8bit(const void* src_ptr0, const void* src_ptr1, - const void* src_ptr2, const void* src_ptr3, - const std::int32_t* sums_ptr, - const std::int8_t* packed_ptr, int src_inc0, - int src_inc1, int src_inc2, int src_inc3, - int src_rows, int src_zero_point, int input_xor, - PackParams8bit* params) { - params->src_ptr0 = src_ptr0; - params->src_ptr1 = src_ptr1; - params->src_ptr2 = src_ptr2; - params->src_ptr3 = src_ptr3; - params->sums_ptr = sums_ptr; - params->packed_ptr = packed_ptr; - params->src_inc0 = src_inc0; - params->src_inc1 = src_inc1; - params->src_inc2 = src_inc2; - params->src_inc3 = src_inc3; - params->src_rows = src_rows; - params->src_zero_point = src_zero_point; - params->input_xor = input_xor; -} -#endif - -#if RUY_PLATFORM(NEON) -template <> -struct PackedTypeImpl { - using Type = std::int8_t; -}; -template <> -struct PackedTypeImpl { - using Type = std::int8_t; -}; -#elif RUY_PLATFORM(X86) -template <> -struct PackedTypeImpl { - using Type = std::int8_t; -}; -template <> -struct PackedTypeImpl { - using Type = std::int8_t; -}; -template <> -struct PackedTypeImpl { - using Type = std::int8_t; -}; -template <> -struct PackedTypeImpl { - using Type = std::int8_t; -}; -#endif - -template -using PackedType = typename PackedTypeImpl::Type; - -template -PackedScalar Pack(Scalar x) { - return x - SymmetricZeroPoint() + SymmetricZeroPoint(); -} - -template -struct PackImpl {}; - -#define RUY_INHERIT_PACK(PARENT, CHILD) \ - template \ - struct PackImpl \ - : PackImpl { \ - }; - -template -struct PackImpl { - static void Run(Tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - profiler::ScopeLabel label("Pack (generic)"); - RUY_DCHECK_EQ((end_col - start_col) % FixedKernelLayout::kCols, 0); - SumsType* sums = packed_matrix->sums; - for (int col = start_col; col < end_col; col++) { - SumsType accum = 0; - for (int row = 0; row < packed_matrix->layout.rows; row++) { - PackedScalar packed_val; - if (col < src_matrix.layout.cols && row < src_matrix.layout.rows) { - packed_val = Pack(Element(src_matrix, row, col)); - } else { - packed_val = packed_matrix->zero_point; - } - accum += packed_val; - *ElementPtr(packed_matrix, row, col) = packed_val; - } - if (sums) { - sums[col] = accum; - } - } - } -}; - -#if RUY_PLATFORM(NEON) -RUY_INHERIT_PACK(Path::kStandardCpp, Path::kNeon) -RUY_INHERIT_PACK(Path::kNeon, Path::kNeonDotprod) -#elif RUY_PLATFORM(X86) -RUY_INHERIT_PACK(Path::kStandardCpp, Path::kSse42) -RUY_INHERIT_PACK(Path::kSse42, Path::kAvx2) -RUY_INHERIT_PACK(Path::kAvx2, Path::kAvx512) -RUY_INHERIT_PACK(Path::kAvx512, Path::kAvxVnni) -#endif - -// Main entry point for packing. -template -void RunPack(Tuning tuning, const DMatrix& src_matrix, PMatrix* packed_matrix, - int start_col, int end_col) { - using SumsType = typename PackedMatrix::SumsType; - Matrix src = ToMatrix(src_matrix); - PackedMatrix packed = - ToPackedMatrix(*packed_matrix); - PackImpl::Run( - tuning, src, &packed, start_col, end_col); -} - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_COMMON_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/pack_sse42.cc b/tensorflow/lite/experimental/ruy/ruy/pack_sse42.cc deleted file mode 100644 index ecd1cf83c6d..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/pack_sse42.cc +++ /dev/null @@ -1,471 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/pack.h" -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" - -#if RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS) -#include // IWYU pragma: keep -#endif - -namespace ruy { - -#if !(RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM)) - -void Pack8bitSse42(const std::int8_t* src_ptr, std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, - std::int8_t* packed_ptr, std::int32_t* sums_ptr) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -void PackFloatSse42(const float* src_ptr, const float* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, float* packed_ptr) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -#else // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -// The first int8_t template parameter is arbitrary: this routine is common to -// all 8-bit source matrix types. -using PackImpl8bitSse42 = - PackImpl, - std::int8_t, std::int8_t, std::int32_t>; - -using PackImplFloatSse42 = - PackImpl, float, - float, float>; - -namespace { - -inline void Pack8bitSse42Packer(const std::int8_t* src_ptr, - std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, - std::int8_t* packed_ptr, std::int32_t* sums_ptr, - std::int8_t* trailing_buf) { - using Layout = PackImpl8bitSse42::Layout; - RUY_DCHECK_EQ(Layout::kCols, 8); - RUY_DCHECK_EQ(Layout::kRows, 4); - // Each Layout::Rows is 4 contiguous input, contiguous packed elements. - // We process 8 of these chunks at a time, padding short input chunks. - constexpr int kNumRowChunks = 8; - constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows; - - std::int8_t in_data[Layout::kCols][kNumRowChunks][Layout::kRows]; - - const std::int8_t* src_ptr0 = src_ptr; - const std::int8_t* src_ptr1 = src_ptr0 + src_stride; - const std::int8_t* src_ptr2 = src_ptr1 + src_stride; - const std::int8_t* src_ptr3 = src_ptr2 + src_stride; - const std::int8_t* src_ptr4 = src_ptr3 + src_stride; - const std::int8_t* src_ptr5 = src_ptr4 + src_stride; - const std::int8_t* src_ptr6 = src_ptr5 + src_stride; - const std::int8_t* src_ptr7 = src_ptr6 + src_stride; - std::int64_t src_inc0 = kNumChunkedSrcRows; - std::int64_t src_inc1 = kNumChunkedSrcRows; - std::int64_t src_inc2 = kNumChunkedSrcRows; - std::int64_t src_inc3 = kNumChunkedSrcRows; - std::int64_t src_inc4 = kNumChunkedSrcRows; - std::int64_t src_inc5 = kNumChunkedSrcRows; - std::int64_t src_inc6 = kNumChunkedSrcRows; - std::int64_t src_inc7 = kNumChunkedSrcRows; - // Handle cases where source does not have Layout::kCols (8) columns. - if (remaining_src_cols < 8) { - if (remaining_src_cols <= 0) { - src_ptr0 = zerobuf; - src_inc0 = 0; - } - if (remaining_src_cols <= 1) { - src_ptr1 = zerobuf; - src_inc1 = 0; - } - if (remaining_src_cols <= 2) { - src_ptr2 = zerobuf; - src_inc2 = 0; - } - if (remaining_src_cols <= 3) { - src_ptr3 = zerobuf; - src_inc3 = 0; - } - if (remaining_src_cols <= 4) { - src_ptr4 = zerobuf; - src_inc4 = 0; - } - if (remaining_src_cols <= 5) { - src_ptr5 = zerobuf; - src_inc5 = 0; - } - if (remaining_src_cols <= 6) { - src_ptr6 = zerobuf; - src_inc6 = 0; - } - src_ptr7 = zerobuf; - src_inc7 = 0; - } - - const std::int8_t zero_point = zerobuf[0]; - - if (sums_ptr) { - // i: Layout::kCols. - for (int i = 0; i < 8; ++i) { - sums_ptr[i] = 0; - } - } - - // The overall packing effectively pads the source rows to - // (src_rows + 63) & ~63. The iteration over k may skip when m=1, and then we - // only pack for (src_rows + 31) & ~31. When there is an incomplete - // destination block, this is stored into trailing_buf instead of packed_ptr. - for (int k = 0; k < src_rows; k += kNumChunkedSrcRows) { - // Available source rows. - // If this is less than 0 (for m=1), we skip, having filled trailing - // buffer for m=0. Also, if source rows is zero on m=1, then we filled - // exactly to the end of the column in the packed buffer. - const int available_src_rows = src_rows - k; - // Effectively, - // available rows = std::max(0, std::min(8, src_rows - k)); - // treat each case separately. - if (available_src_rows >= kNumChunkedSrcRows) { - // i: chunks, s: Layout::Rows. - for (int i = 0; i < 8; ++i) { - for (int s = 0; s < 4; ++s) { - in_data[0][i][s] = src_ptr0[i * 4 + s]; - in_data[1][i][s] = src_ptr1[i * 4 + s]; - in_data[2][i][s] = src_ptr2[i * 4 + s]; - in_data[3][i][s] = src_ptr3[i * 4 + s]; - in_data[4][i][s] = src_ptr4[i * 4 + s]; - in_data[5][i][s] = src_ptr5[i * 4 + s]; - in_data[6][i][s] = src_ptr6[i * 4 + s]; - in_data[7][i][s] = src_ptr7[i * 4 + s]; - } - } - // i: chunks, j: Layout::kCols, s: Layout::Rows. - for (int i = 0; i < 8; ++i) { - for (int j = 0; j < 8; ++j) { - for (int s = 0; s < 4; ++s) { - // 8 * 4 * i is offset for each block, that is - // (Layout::kCols * Layout::kRows * i) - packed_ptr[(8 * i + j) * 4 + s] = in_data[j][i][s] ^ input_xor; - } - if (sums_ptr) { - for (int s = 0; s < 4; ++s) { - sums_ptr[j] += in_data[j][i][s] ^ input_xor; - } - } - } - } - } else if (available_src_rows > 0) { - RUY_DCHECK_LT(available_src_rows, kNumChunkedSrcRows); - int i = 0; - // Consume chunks of 4 rows that are complete. - for (; i < (available_src_rows >> 2); ++i) { - for (int s = 0; s < 4; ++s) { - in_data[0][i][s] = src_ptr0[i * 4 + s]; - in_data[1][i][s] = src_ptr1[i * 4 + s]; - in_data[2][i][s] = src_ptr2[i * 4 + s]; - in_data[3][i][s] = src_ptr3[i * 4 + s]; - in_data[4][i][s] = src_ptr4[i * 4 + s]; - in_data[5][i][s] = src_ptr5[i * 4 + s]; - in_data[6][i][s] = src_ptr6[i * 4 + s]; - in_data[7][i][s] = src_ptr7[i * 4 + s]; - } - } - // Consume any incomplete chunk. - if (i < ((available_src_rows + 3) >> 2)) { - int s = 0; - for (; s < (available_src_rows & 3); ++s) { - in_data[0][i][s] = src_ptr0[i * 4 + s]; - in_data[1][i][s] = src_ptr1[i * 4 + s]; - in_data[2][i][s] = src_ptr2[i * 4 + s]; - in_data[3][i][s] = src_ptr3[i * 4 + s]; - in_data[4][i][s] = src_ptr4[i * 4 + s]; - in_data[5][i][s] = src_ptr5[i * 4 + s]; - in_data[6][i][s] = src_ptr6[i * 4 + s]; - in_data[7][i][s] = src_ptr7[i * 4 + s]; - } - RUY_DCHECK_LE(s, 4); - for (; s < 4; ++s) { - // j: Layout::kCols. - for (int j = 0; j < 8; ++j) { - in_data[j][i][s] = zero_point; - } - } - ++i; - } - // We do not care what goes into the trailing buffer, but we want - // in_data[...] ^ input_xor == 0 for irrelevant values in the summation. - // - // It might prove better in optimized code to pad uniformly with - // zero_point, and compensate by initializing the summations with the - // compensating offset, effectively - // ((input_xor - zero_point) ^ input_xor) * - // 4 * (8 - ((available_src_rows + 3) >> 2)). - for (; i < 8; ++i) { - for (int s = 0; s < 4; ++s) { - for (int j = 0; j < 8; ++j) { - in_data[j][i][s] = input_xor; - } - } - } - // We loop through [0, 8) rather than - // [0, (available_src_rows + 3) >> 2), since that emulates what we might - // do in fully-optimized code. - // - // i: chunks, j: Layout::kCols, s: Layout::Rows. - if (sums_ptr) { - for (int i = 0; i < 8; ++i) { - for (int j = 0; j < 8; ++j) { - for (int s = 0; s < 4; ++s) { - trailing_buf[(8 * i + j) * 4 + s] = in_data[j][i][s] ^ input_xor; - sums_ptr[j] = sums_ptr[j] + (in_data[j][i][s] ^ input_xor); - } - } - } - } else { - for (int i = 0; i < 8; ++i) { - for (int j = 0; j < 8; ++j) { - for (int s = 0; s < 4; ++s) { - trailing_buf[(8 * i + j) * 4 + s] = in_data[j][i][s] ^ input_xor; - } - } - } - } - } - - packed_ptr += 8 * kNumChunkedSrcRows; - src_ptr0 += src_inc0; - src_ptr1 += src_inc1; - src_ptr2 += src_inc2; - src_ptr3 += src_inc3; - src_ptr4 += src_inc4; - src_ptr5 += src_inc5; - src_ptr6 += src_inc6; - src_ptr7 += src_inc7; - } -} - -inline void PackFloatSse42Packer(const float* src_ptr, const float* zerobuf, - int src_stride, int remaining_src_cols, - int src_rows, float* packed_ptr, - float* trailing_buf) { - using Layout = PackImplFloatSse42::Layout; - RUY_DCHECK_EQ(Layout::kCols, 8); - RUY_DCHECK_EQ(Layout::kRows, 1); - - // This packing amounts to tranposition of 8x8 blocks. - static constexpr int kPackCols = 8; // Source cols packed together. - static constexpr int kPackRows = 8; // Short input is padded. - - float in_data[kPackCols][kPackRows]; - - const float* src_ptr0 = src_ptr; - const float* src_ptr1 = src_ptr0 + src_stride; - const float* src_ptr2 = src_ptr1 + src_stride; - const float* src_ptr3 = src_ptr2 + src_stride; - const float* src_ptr4 = src_ptr3 + src_stride; - const float* src_ptr5 = src_ptr4 + src_stride; - const float* src_ptr6 = src_ptr5 + src_stride; - const float* src_ptr7 = src_ptr6 + src_stride; - std::int64_t src_inc0 = 8; - std::int64_t src_inc1 = 8; - std::int64_t src_inc2 = 8; - std::int64_t src_inc3 = 8; - std::int64_t src_inc4 = 8; - std::int64_t src_inc5 = 8; - std::int64_t src_inc6 = 8; - std::int64_t src_inc7 = 8; - // Handle cases where source does not have kPackDim (8) columns. - if (remaining_src_cols < kPackCols) { - if (remaining_src_cols <= 0) { - src_ptr0 = zerobuf; - src_inc0 = 0; - } - if (remaining_src_cols <= 1) { - src_ptr1 = zerobuf; - src_inc1 = 0; - } - if (remaining_src_cols <= 2) { - src_ptr2 = zerobuf; - src_inc2 = 0; - } - if (remaining_src_cols <= 3) { - src_ptr3 = zerobuf; - src_inc3 = 0; - } - if (remaining_src_cols <= 4) { - src_ptr4 = zerobuf; - src_inc4 = 0; - } - if (remaining_src_cols <= 5) { - src_ptr5 = zerobuf; - src_inc5 = 0; - } - if (remaining_src_cols <= 6) { - src_ptr6 = zerobuf; - src_inc6 = 0; - } - src_ptr7 = zerobuf; - src_inc7 = 0; - } - - for (int k = 0; k < src_rows; k += kPackRows) { - const int available_src_rows = src_rows - k; - // Effectively, - // available_src_rows = std::max(0, std::min(kPackDim, src_rows - k)); - // but treat each case separately. - if (available_src_rows >= kPackRows) { - for (int i = 0; i < 8; ++i) { - in_data[0][i] = src_ptr0[i]; - in_data[1][i] = src_ptr1[i]; - in_data[2][i] = src_ptr2[i]; - in_data[3][i] = src_ptr3[i]; - in_data[4][i] = src_ptr4[i]; - in_data[5][i] = src_ptr5[i]; - in_data[6][i] = src_ptr6[i]; - in_data[7][i] = src_ptr7[i]; - } - for (int i = 0; i < 8; ++i) { - for (int j = 0; j < 8; ++j) { - packed_ptr[8 * i + j] = in_data[j][i]; - } - } - } else if (available_src_rows > 0) { - for (int i = 0; i < available_src_rows; ++i) { - in_data[0][i] = src_ptr0[i]; - in_data[1][i] = src_ptr1[i]; - in_data[2][i] = src_ptr2[i]; - in_data[3][i] = src_ptr3[i]; - in_data[4][i] = src_ptr4[i]; - in_data[5][i] = src_ptr5[i]; - in_data[6][i] = src_ptr6[i]; - in_data[7][i] = src_ptr7[i]; - } - for (int i = available_src_rows; i < kPackRows; ++i) { - in_data[0][i] = 0.0f; - in_data[1][i] = 0.0f; - in_data[2][i] = 0.0f; - in_data[3][i] = 0.0f; - in_data[4][i] = 0.0f; - in_data[5][i] = 0.0f; - in_data[6][i] = 0.0f; - in_data[7][i] = 0.0f; - } - // We loop through [0, 7) rather than [0, packed_rows), since that - // emulates what we might do in fully-optimized code. - // i: (kPackRows - 1), j: kPackCols. - for (int i = 0; i < 7; ++i) { - for (int j = 0; j < 8; ++j) { - trailing_buf[kPackRows * i + j] = in_data[j][i]; - } - } - } - - packed_ptr += kPackRows * kPackCols; - src_ptr0 += src_inc0; - src_ptr1 += src_inc1; - src_ptr2 += src_inc2; - src_ptr3 += src_inc3; - src_ptr4 += src_inc4; - src_ptr5 += src_inc5; - src_ptr6 += src_inc6; - src_ptr7 += src_inc7; - } -} - -} // namespace. - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// When removing this comment, update profiling label below. -void Pack8bitSse42(const std::int8_t* src_ptr, std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, - std::int8_t* packed_ptr, std::int32_t* sums_ptr) { - profiler::ScopeLabel label("Pack kSse42 8bit (UNFINISHED)"); - - using Layout = PackImpl8bitSse42::Layout; - RUY_DCHECK_EQ(Layout::kCols, 8); - RUY_DCHECK_EQ(Layout::kRows, 4); - - // Each Layout::Rows is 4 contiguous input, contiguous packed elements. - // We process 8 of these chunks at a time, padding short input chunks. - static constexpr int kNumRowChunks = 8; // Short input is padded. - - // Each packed block is 4*8, and there are normally 8. The trailing block is - // only slightly shorter. - constexpr int kTrailingBufSize = - kNumRowChunks * Layout::kCols * Layout::kRows; - std::int8_t trailing_buf[kTrailingBufSize]; - memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t)); - - Pack8bitSse42Packer(src_ptr, input_xor, zerobuf, src_stride, - remaining_src_cols, src_rows, packed_ptr, sums_ptr, - trailing_buf); - - constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1; - const bool trailing_data = (src_rows & kChunkedRowMask) > 0; - // If the number of source rows is not a multiple of kChunkedRowMask, there - // will be data in the trailing buffer, - if (trailing_data > 0) { - const int non_trailing_rows = src_rows & ~kChunkedRowMask; - // Destination "rows" are padded to next highest multiple of Layout::kRows. - const int dst_rows = (src_rows + 3) & ~3; - const int trailing_rows = dst_rows - non_trailing_rows; - memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf, - Layout::kCols * trailing_rows * sizeof(std::int8_t)); - } -} - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// When removing this comment, update profiling label below. -void PackFloatSse42(const float* src_ptr, const float* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, float* packed_ptr) { - profiler::ScopeLabel label("Pack kSse42 float (UNFINISHED)"); - static constexpr int kPackCols = 8; // Source cols packed together. - static constexpr int kPackRows = 8; // Short input is padded. - float trailing_buf[(kPackRows - 1) * kPackCols]; - if (remaining_src_cols < 8) { - memset(trailing_buf, 0, sizeof(trailing_buf)); - } - PackFloatSse42Packer(src_ptr, zerobuf, src_stride, remaining_src_cols, - src_rows, packed_ptr, trailing_buf); - - const int trailing_rows = src_rows & (kPackRows - 1); - if (trailing_rows > 0) { - const int non_trailing_rows = src_rows & ~(kPackRows - 1); - memcpy(packed_ptr + kPackCols * non_trailing_rows, trailing_buf, - kPackCols * trailing_rows * sizeof(float)); - } -} - -#endif // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS) - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/pack_x86.h b/tensorflow/lite/experimental/ruy/ruy/pack_x86.h deleted file mode 100644 index 8bdc88e5763..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/pack_x86.h +++ /dev/null @@ -1,461 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// # What is "packing"? -// -// Before feeding data to the gemm kernels (the parts of Ruy that do lots -// of multiply-add operations), Ruy first performs a data transformation (which -// we call "packing") on the input matrices. This transformation has two main -// goals: -// - rearrange data into blocks that are a convenient size/layout for the gemm -// kernels to consume. This helps make the memory access pattern of the gemm -// kernel simpler and more contiguous, and puts the data in a layout most -// convenient for specific arithmetic instructions in the gemm kernel. -// - compute row/column sums needed for handling quantization with non-symmetric -// zero points. -// -// # Simplified algorithmic analysis of packing -// -// Packing is a relatively simple transformation which does a small constant -// amount of work on each element of an input matrix, and hence for an NxM -// matrix performs O(N*M) work. If N and M are of the same order, then this is -// O(N^2) work. -// -// A NxKxM matrix multiplication requires N*K*M multiply-accumulate operations. -// Note that if N, K, and M are all the same order, then the number of -// multiply-accumulate operations is O(N^3). -// -// Thus, the O(N^2) cost of packing is small compared to the O(N^3) work, in the -// case of all dimensions being roughly the same order. -// -// # Packing cost can be significant -// -// When matrix * matrix multiplications begin to look more like matrix * vector -// multiplications, packing cost can become significant. We sometimes call these -// cases "gemv-like". -// -// Continuing the algorithmic analysis above, if we consider a case where an -// NxKxM matrix multiplication has either N = O(1) or M = O(1), then the -// situation is different. In this case, the multiply-accumulate work is only -// quadratic, so the quadratic cost of packing can be come significant. -// -// Another way to say this is that the cost of packing an input matrix (either -// the LHS or RHS) is amortized across the non-depth dimension of the opposite -// input matrix. Thus, when the LHS has very few rows or the RHS has very few -// columns, the cost of packing the opposite input matrix can become -// significant. -// -// As a rough rule of thumb, the cost of packing starts to become significant -// when either N or M is below 32 (and other dimensions are hundreds), with very -// significant packing costs at 8 or below. This varies by data type, Path, and -// tuning, so these numbers are only rough guides. -// -// One practical use case that is affected by this is inference of -// fully connected neural network layers with a low batch size. The weight -// matrix (which is a constant for inference) is the one affected by significant -// packing cost. -// -// Ruy provides an API in ruy_advanced.h for advanced users to pre-pack -// input matrices that are affected by significant packing costs. -// -// # Implementation notes -// -// Ruy's packing routines always operate on a range of columns and can be -// applied to either the LHS or RHS. This is possible because Ruy internally -// implements a TrMul, so the accumulation along depth is done along columns of -// both the LHS and RHS (whereas for a normal Mul the accumulation along depth -// for the LHS is along rows). As another example, we are always computing -// column sums for quantization (and never row sums, since the LHS is -// transposed). - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_X86_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_X86_H_ - -#include -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/common.h" -#include "tensorflow/lite/experimental/ruy/ruy/internal_matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/pack_common.h" -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" -#include "tensorflow/lite/experimental/ruy/ruy/tune.h" - -namespace ruy { - -#if RUY_PLATFORM(X86) -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// Note that source and zero buffers can be uint8 type, but in the packing -// function are reinterpreted as int8, and are XOR-ed with input_xor. -void Pack8bitSse42(const std::int8_t* src_ptr, std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, - std::int8_t* packed_ptr, std::int32_t* sums_ptr); - -template -struct PackImpl, Scalar, - std::int8_t, std::int32_t> { - static_assert(std::is_same::value || - std::is_same::value, - ""); - using Layout = FixedKernelLayout; - static constexpr std::int8_t kInputXor = - std::is_same::value ? 0 : 0x80; - - static void Run(Tuning tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - profiler::ScopeLabel label("Pack (SSE 4.2 8-bit)"); - - RUY_DCHECK(IsColMajor(src_matrix.layout)); - RUY_DCHECK(IsColMajor(packed_matrix->layout)); - RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); - RUY_DCHECK_EQ(start_col % Layout::kCols, 0); - std::int32_t* sums = packed_matrix->sums; - Scalar zerobuf[Layout::kCols * Layout::kRows]; - memset(zerobuf, packed_matrix->zero_point ^ kInputXor, - Layout::kCols * Layout::kRows * sizeof(Scalar)); - for (int block_col = start_col; block_col < end_col; - block_col += Layout::kCols) { - std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; - int src_stride = src_matrix.layout.stride; - const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col; - int remaining_src_cols = src_matrix.layout.cols - block_col; - - static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. - std::int8_t* packed_ptr = - packed_matrix->data + - packed_matrix->layout.stride * (block_col & block_col_mask); - Pack8bitSse42(reinterpret_cast(src_ptr), kInputXor, - reinterpret_cast(zerobuf), src_stride, - remaining_src_cols, src_matrix.layout.rows, packed_ptr, - sums_ptr); - } - } -}; - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -void PackFloatSse42(const float* src_ptr, const float* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, float* packed_ptr); - -template <> -struct PackImpl, float, - float, float> { - using Layout = FixedKernelLayout; - static void Run(Tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - profiler::ScopeLabel label("Pack (SSE 4.2 float)"); - - RUY_DCHECK(IsColMajor(src_matrix.layout)); - RUY_DCHECK(IsColMajor(packed_matrix->layout)); - RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); - RUY_DCHECK_EQ(start_col % Layout::kCols, 0); - const float zerobuf[Layout::kCols] = { - 0.0f}; // Remainder default inits to 0.0f. - for (int block_col = start_col; block_col < end_col; - block_col += Layout::kCols) { - int src_stride = src_matrix.layout.stride; - const float* src_ptr = src_matrix.data.get() + src_stride * block_col; - int remaining_src_cols = src_matrix.layout.cols - block_col; - - static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. - float* packed_ptr = - packed_matrix->data + - packed_matrix->layout.stride * (block_col & block_col_mask); - PackFloatSse42(src_ptr, zerobuf, src_stride, remaining_src_cols, - src_matrix.layout.rows, packed_ptr); - } - } -}; - -// Note that source and zero buffers can be uint8 type, but in the packing -// function are reinterpreted as int8, and are XOR-ed with input_xor. -void Pack8bitAvx2(const std::int8_t* src_ptr, std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, std::int8_t* packed_ptr, - std::int32_t* sums_ptr); - -template -struct PackImpl, Scalar, - std::int8_t, std::int32_t> { - static_assert(std::is_same::value || - std::is_same::value, - ""); - using Layout = FixedKernelLayout; - static constexpr std::int8_t kInputXor = - std::is_same::value ? 0 : 0x80; - - static void Run(Tuning tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - profiler::ScopeLabel label("Pack (AVX2 8-bit)"); - - RUY_DCHECK(IsColMajor(src_matrix.layout)); - RUY_DCHECK(IsColMajor(packed_matrix->layout)); - RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); - RUY_DCHECK_EQ(start_col % Layout::kCols, 0); - std::int32_t* sums = packed_matrix->sums; - Scalar zerobuf[Layout::kCols * Layout::kRows]; - memset(zerobuf, packed_matrix->zero_point ^ kInputXor, - Layout::kCols * Layout::kRows * sizeof(Scalar)); - for (int block_col = start_col; block_col < end_col; - block_col += Layout::kCols) { - std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; - int src_stride = src_matrix.layout.stride; - const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col; - int remaining_src_cols = src_matrix.layout.cols - block_col; - - static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. - std::int8_t* packed_ptr = - packed_matrix->data + - packed_matrix->layout.stride * (block_col & block_col_mask); - Pack8bitAvx2(reinterpret_cast(src_ptr), kInputXor, - reinterpret_cast(zerobuf), src_stride, - remaining_src_cols, src_matrix.layout.rows, packed_ptr, - sums_ptr); - } - } -}; - -void PackFloatAvx2(const float* src_ptr, const float* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, float* packed_ptr); - -template <> -struct PackImpl, float, - float, float> { - using Layout = FixedKernelLayout; - static void Run(Tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - profiler::ScopeLabel label("Pack (AVX2 float)"); - - RUY_DCHECK(IsColMajor(src_matrix.layout)); - RUY_DCHECK(IsColMajor(packed_matrix->layout)); - RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); - RUY_DCHECK_EQ(start_col % Layout::kCols, 0); - const float zerobuf[Layout::kCols] = { - 0.0f}; // Remainder default inits to 0.0f. - for (int block_col = start_col; block_col < end_col; - block_col += Layout::kCols) { - int src_stride = src_matrix.layout.stride; - const float* src_ptr = src_matrix.data.get() + src_stride * block_col; - int remaining_src_cols = src_matrix.layout.cols - block_col; - - static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. - float* packed_ptr = - packed_matrix->data + - packed_matrix->layout.stride * (block_col & block_col_mask); - PackFloatAvx2(src_ptr, zerobuf, src_stride, remaining_src_cols, - src_matrix.layout.rows, packed_ptr); - } - } -}; - -// Note that source and zero buffers can be uint8 type, but in the packing -// function are reinterpreted as int8, and are XOR-ed with input_xor. -void Pack8bitAvx512(const std::int8_t* src_ptr, std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, - std::int8_t* packed_ptr, std::int32_t* sums_ptr); - -template -struct PackImpl, - Scalar, std::int8_t, std::int32_t> { - static_assert(std::is_same::value || - std::is_same::value, - ""); - using Layout = FixedKernelLayout; - static constexpr int kHalfLayoutCols = - 8; // Half the number of cols in a block. - static constexpr std::int8_t kInputXor = - std::is_same::value ? 0 : 0x80; - - static void Run(Tuning tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - profiler::ScopeLabel label("Pack (AVX-512 8-bit)"); - - RUY_DCHECK(IsColMajor(src_matrix.layout)); - RUY_DCHECK(IsColMajor(packed_matrix->layout)); - RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); - RUY_DCHECK_EQ(start_col % Layout::kCols, 0); - RUY_DCHECK_EQ(kHalfLayoutCols * 2, Layout::kCols); - std::int32_t* sums = packed_matrix->sums; - Scalar zerobuf[kHalfLayoutCols * Layout::kRows]; - memset(zerobuf, packed_matrix->zero_point ^ kInputXor, - kHalfLayoutCols * Layout::kRows * sizeof(Scalar)); - for (int block_col = start_col; block_col < end_col; - block_col += Layout::kCols) { - std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; - int src_stride = src_matrix.layout.stride; - const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col; - int remaining_src_cols = src_matrix.layout.cols - block_col; - - static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. - std::int8_t* packed_ptr = - packed_matrix->data + - packed_matrix->layout.stride * (block_col & block_col_mask); - Pack8bitAvx512(reinterpret_cast(src_ptr), kInputXor, - reinterpret_cast(zerobuf), src_stride, - remaining_src_cols, src_matrix.layout.rows, packed_ptr, - sums_ptr); - } - } -}; - -void PackFloatAvx512(const float* src_ptr, const float* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, float* packed_ptr); - -template <> -struct PackImpl, - float, float, float> { - static void Run(Tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - profiler::ScopeLabel label("Pack (AVX-512 float)"); - using Layout = FixedKernelLayout; - RUY_DCHECK(IsColMajor(src_matrix.layout)); - RUY_DCHECK(IsColMajor(packed_matrix->layout)); - RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); - RUY_DCHECK_EQ(start_col % Layout::kCols, 0); - const float zerobuf[Layout::kCols] = { - 0.0f}; // Remainder default inits to 0.0f. - for (int block_col = start_col; block_col < end_col; - block_col += Layout::kCols) { - int src_stride = src_matrix.layout.stride; - const float* src_ptr = src_matrix.data.get() + src_stride * block_col; - int remaining_src_cols = src_matrix.layout.cols - block_col; - - static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. - float* packed_ptr = - packed_matrix->data + - packed_matrix->layout.stride * (block_col & block_col_mask); - PackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols, - src_matrix.layout.rows, packed_ptr); - } - } -}; - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// Note that source and zero buffers can be uint8 type, but in the packing -// function are reinterpreted as int8, and are XOR-ed with input_xor. -void Pack8bitAvxVnni(const std::int8_t* src_ptr, std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, - std::int8_t* packed_ptr, std::int32_t* sums_ptr); - -template -struct PackImpl, - Scalar, std::int8_t, std::int32_t> { - static_assert(std::is_same::value || - std::is_same::value, - ""); - using Layout = FixedKernelLayout; - static constexpr int kHalfLayoutCols = - 8; // Half the number of cols in a block. - static constexpr std::int8_t kInputXor = - std::is_same::value ? 0 : 0x80; - - static void Run(Tuning tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - profiler::ScopeLabel label("Pack (AVX-512 8-bit)"); - - RUY_DCHECK(IsColMajor(src_matrix.layout)); - RUY_DCHECK(IsColMajor(packed_matrix->layout)); - RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); - RUY_DCHECK_EQ(start_col % Layout::kCols, 0); - RUY_DCHECK_EQ(kHalfLayoutCols * 2, Layout::kCols); - std::int32_t* sums = packed_matrix->sums; - Scalar zerobuf[kHalfLayoutCols * Layout::kRows]; - memset(zerobuf, packed_matrix->zero_point ^ kInputXor, - kHalfLayoutCols * Layout::kRows * sizeof(Scalar)); - for (int block_col = start_col; block_col < end_col; - block_col += Layout::kCols) { - std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; - int src_stride = src_matrix.layout.stride; - const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col; - int remaining_src_cols = src_matrix.layout.cols - block_col; - - static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. - std::int8_t* packed_ptr = - packed_matrix->data + - packed_matrix->layout.stride * (block_col & block_col_mask); - Pack8bitAvxVnni(reinterpret_cast(src_ptr), kInputXor, - reinterpret_cast(zerobuf), src_stride, - remaining_src_cols, src_matrix.layout.rows, packed_ptr, - sums_ptr); - } - } -}; - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -void PackFloatAvxVnni(const float* src_ptr, const float* zerobuf, - int src_stride, int remaining_src_cols, int src_rows, - float* packed_ptr); - -template <> -struct PackImpl, - float, float, float> { - static void Run(Tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - profiler::ScopeLabel label("Pack (AVX-512 float)"); - - using Layout = FixedKernelLayout; - RUY_DCHECK(IsColMajor(src_matrix.layout)); - RUY_DCHECK(IsColMajor(packed_matrix->layout)); - RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); - RUY_DCHECK_EQ(start_col % Layout::kCols, 0); - const float zerobuf[Layout::kCols] = { - 0.0f}; // Remainder default inits to 0.0f. - for (int block_col = start_col; block_col < end_col; - block_col += Layout::kCols) { - int src_stride = src_matrix.layout.stride; - const float* src_ptr = src_matrix.data.get() + src_stride * block_col; - int remaining_src_cols = src_matrix.layout.cols - block_col; - - static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. - float* packed_ptr = - packed_matrix->data + - packed_matrix->layout.stride * (block_col & block_col_mask); - PackFloatAvxVnni(src_ptr, zerobuf, src_stride, remaining_src_cols, - src_matrix.layout.rows, packed_ptr); - } - } -}; -#endif // RUY_PLATFORM(X86) - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_X86_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/path.h b/tensorflow/lite/experimental/ruy/ruy/path.h deleted file mode 100644 index 5973b8040a7..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/path.h +++ /dev/null @@ -1,162 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PATH_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PATH_H_ - -#include - -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/ruy/size_util.h" - -namespace ruy { - -// A Path is a choice of implementation path, e.g. between reference code -// and optimized code, or between different optimized code paths using different -// instruction sets. -// -// It's important that any symbol that depends on such implementation -// details, is somehow templatized in such a Path, so that different Path values -// yield different symbols, so we never have the situation where a symbols has -// multiple inequivalent definitions based on which code paths are compiled. -// That would be a violation of the ODR (One Definition Rule) which is Undefined -// Behavior, and one of the most serious issues plaguing both Eigen and -// gemmlowp. -// -// This enum is actually a bit-field: aside from kNone, all other values are -// powers of two, thus are one bit each. We define bit-wise operators below -// for this enum. Some places in Ruy accept a Path bit-field where multiple -// Paths may be selected, while some other places require a single Path (i.e. -// just one of the enum values here). Typically, user-facing parts of Ruy -// accept arbitrary bit-fields, allowing the user to compile support for -// multiple paths and to inform Ruy of all the paths that are to be enabled -// at runtime; then, typically in dispatch.h, we internally pick one -// specific path and from there on, internal Ruy code deals with only one -// path. -// -// When a user selects a set of compiled paths, Ruy internally dispatches to the -// "best" one, which typically means the newest optimized instructions for a -// given base architecture (such as ARM). Higher values of this enum correspond -// to "better" code paths within a given base architecture for which Ruy has -// optimized code paths. -// -// Values are reused across architectures. -// Rationale: Scale better to N architectures, it is good to have small values -// both for the compile-time logic to select paths, and when manually spelling -// out Path values, such as when invoking a test or benchmark. -enum class Path : std::uint8_t { - // This is a special null value, representing the absence of any path. - kNone = 0, - // Reference multiplication code. - // The main purpose of this path is to have a very simple standalone Mul - // implementation to check against. - // This path bypasses almost all of Ruy's internal implementation details. - // - // This is intended for testing/development. - kReference = 0x1, - // Standard C++ implementation of Ruy's architecture-specific parts. - // Unlike Path::kReference, this path exercises most of Ruy's internal logic. - // - // This is intended for testing/development. - kStandardCpp = 0x2, - -#if RUY_PLATFORM(ARM) - // ARM architectures. - // - // Optimized path using a widely available subset of ARM NEON instructions. - kNeon = 0x4, - // Optimized path making use of ARM NEON dot product instructions that are - // available on newer ARM cores. - kNeonDotprod = 0x8, -#endif // RUY_PLATFORM(ARM) - -#if RUY_PLATFORM(X86) - // x86 architectures. - // - // TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / - // placeholder. - // Optimization is not finished. In particular the dimensions of the kernel - // blocks can be changed as desired. - // - // Optimized for SSE 4.2. - kSse42 = 0x4, - // Optimized for AVX2. - kAvx2 = 0x8, - // Optimized for AVX-512. - kAvx512 = 0x10, - // TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / - // placeholder. - // Optimization is not finished. In particular the dimensions of the kernel - // blocks can be changed as desired. - // - // Optimized for AVX-VNNI. - kAvxVnni = 0x20, -#endif // RUY_PLATFORM(X86) -}; - -inline constexpr Path operator|(Path p, Path q) { - return static_cast(static_cast(p) | - static_cast(q)); -} - -inline constexpr Path operator&(Path p, Path q) { - return static_cast(static_cast(p) & - static_cast(q)); -} - -inline constexpr Path operator^(Path p, Path q) { - return static_cast(static_cast(p) ^ - static_cast(q)); -} - -inline constexpr Path operator~(Path p) { - return static_cast(~static_cast(p)); -} - -inline Path GetMostSignificantPath(Path path_mask) { - return static_cast(round_down_pot(static_cast(path_mask))); -} - -// ruy::kAllPaths represents all Path's that make sense to on a given -// base architecture. -#ifdef __linux__ -#if RUY_PLATFORM(NEON_64) -constexpr Path kAllPaths = - Path::kReference | Path::kStandardCpp | Path::kNeon | Path::kNeonDotprod; -#elif RUY_PLATFORM(NEON_32) -constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp | Path::kNeon; -#elif RUY_PLATFORM(X86) -constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp | - Path::kSse42 | Path::kAvx2 | Path::kAvx512 | - Path::kAvxVnni; -#else -constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp; -#endif -#else // __linux__ -// We don't know how to do runtime dotprod detection outside of linux for now. -#if RUY_PLATFORM(NEON) -constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp | Path::kNeon; -#elif RUY_PLATFORM(X86) -constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp | - Path::kSse42 | Path::kAvx2 | Path::kAvx512 | - Path::kAvxVnni; -#else -constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp; -#endif -#endif // __linux__ - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PATH_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/platform.h b/tensorflow/lite/experimental/ruy/ruy/platform.h deleted file mode 100644 index d6e86e6a792..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/platform.h +++ /dev/null @@ -1,156 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PLATFORM_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PLATFORM_H_ - -#ifdef __ANDROID_NDK__ -#include -#endif - -#define RUY_PLATFORM(X) ((RUY_DONOTUSEDIRECTLY_##X) != 0) - -// Architecture-level platform detection. -// -// Ruy requires these to be mutually exclusive. - -// Detect x86. -#if defined(__x86_64__) || defined(__i386__) || defined(__i386) || \ - defined(__x86__) || defined(__X86__) || defined(_X86_) || \ - defined(_M_IX86) || defined(_M_X64) -#define RUY_DONOTUSEDIRECTLY_X86 1 -#else -#define RUY_DONOTUSEDIRECTLY_X86 0 -#endif - -// Detect ARM 32-bit. -#ifdef __arm__ -#define RUY_DONOTUSEDIRECTLY_ARM_32 1 -#else -#define RUY_DONOTUSEDIRECTLY_ARM_32 0 -#endif - -// Detect ARM 64-bit. -#ifdef __aarch64__ -#define RUY_DONOTUSEDIRECTLY_ARM_64 1 -#else -#define RUY_DONOTUSEDIRECTLY_ARM_64 0 -#endif - -// Combined ARM. -#define RUY_DONOTUSEDIRECTLY_ARM \ - (RUY_DONOTUSEDIRECTLY_ARM_64 || RUY_DONOTUSEDIRECTLY_ARM_32) - -// Feature and capability platform detection. -// -// These are mostly sub-selections of architectures. - -// Detect NEON. Explicitly avoid emulation, or anything like it, on x86. -#if (defined(__ARM_NEON) || defined(__ARM_NEON__)) && !RUY_PLATFORM(X86) -#define RUY_DONOTUSEDIRECTLY_NEON 1 -#else -#define RUY_DONOTUSEDIRECTLY_NEON 0 -#endif - -// Define ARM 32-bit NEON. -#define RUY_DONOTUSEDIRECTLY_NEON_32 \ - (RUY_DONOTUSEDIRECTLY_NEON && RUY_DONOTUSEDIRECTLY_ARM_32) - -// Define ARM 64-bit NEON. -// Note: NEON is implied by ARM64, so this define is redundant. -// It still allows some conveyance of intent. -#define RUY_DONOTUSEDIRECTLY_NEON_64 \ - (RUY_DONOTUSEDIRECTLY_NEON && RUY_DONOTUSEDIRECTLY_ARM_64) - -// Disable X86 enhancements on __APPLE__ because b/138922878, see comment #8, we -// may only need to disable this on XCode <= 10.2. -// -// Disable when not using Clang-Linux, because too many user issues arise from -// compilation variations. -// -// NOTE: Consider guarding by !defined(__APPLE__) when removing Linux-only -// restriction. -// -// __EMSCRIPTEN__ is checked because the runtime Path resolution can use asm. -// -// The Android NDK logic excludes earlier and very broken versions of intrinsics -// headers. -#if defined(RUY_FORCE_ENABLE_X86_ENHANCEMENTS) || \ - (defined(__clang__) && (__clang_major__ >= 8) && defined(__linux__) && \ - !defined(__EMSCRIPTEN__) && \ - (!defined(__ANDROID_NDK__) || \ - (defined(__NDK_MAJOR__) && (__NDK_MAJOR__ >= 20)))) -#define RUY_DONOTUSEDIRECTLY_X86_ENHANCEMENTS 1 -#else -#define RUY_DONOTUSEDIRECTLY_X86_ENHANCEMENTS 0 -#endif - -// These CPU capabilities will all be true when Skylake, etc, are enabled during -// compilation. -#if RUY_PLATFORM(X86_ENHANCEMENTS) && RUY_PLATFORM(X86) && \ - defined(__AVX512F__) && defined(__AVX512DQ__) && defined(__AVX512CD__) && \ - defined(__AVX512BW__) && defined(__AVX512VL__) -#define RUY_DONOTUSEDIRECTLY_AVX512 1 -#else -#define RUY_DONOTUSEDIRECTLY_AVX512 0 -#endif - -#if RUY_PLATFORM(X86_ENHANCEMENTS) && RUY_PLATFORM(X86) && defined(__AVX2__) -#define RUY_DONOTUSEDIRECTLY_AVX2 1 -#else -#define RUY_DONOTUSEDIRECTLY_AVX2 0 -#endif - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// Note does not check for LZCNT or POPCNT. -#if defined(RUY_ENABLE_SSE_ENHANCEMENTS) && RUY_PLATFORM(X86_ENHANCEMENTS) && \ - RUY_PLATFORM(X86) && defined(__SSE4_2__) && defined(__FMA__) -#define RUY_DONOTUSEDIRECTLY_SSE42 1 -#else -#define RUY_DONOTUSEDIRECTLY_SSE42 0 -#endif - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// Note that defined(__AVX512VBMI2__) can be false for compilation with -// -march=cascadelake. -// TODO(b/146646451) Check if we should also gate on defined(__AVX512VBMI2__). -#if defined(RUY_ENABLE_VNNI_ENHANCEMENTS) && RUY_PLATFORM(AVX512) && \ - defined(__AVX512VNNI__) -#define RUY_DONOTUSEDIRECTLY_AVX_VNNI 1 -#else -#define RUY_DONOTUSEDIRECTLY_AVX_VNNI 0 -#endif - -// Detect APPLE. -#ifdef __APPLE__ -#define RUY_DONOTUSEDIRECTLY_APPLE 1 -#else -#define RUY_DONOTUSEDIRECTLY_APPLE 0 -#endif - -// Detect Emscripten, typically Wasm. -#ifdef __EMSCRIPTEN__ -#define RUY_DONOTUSEDIRECTLY_EMSCRIPTEN 1 -#else -#define RUY_DONOTUSEDIRECTLY_EMSCRIPTEN 0 -#endif - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PLATFORM_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/pmu.cc b/tensorflow/lite/experimental/ruy/ruy/pmu.cc deleted file mode 100644 index 6405aa15e6a..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/pmu.cc +++ /dev/null @@ -1,281 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/pmu.h" - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" - -#ifdef __linux__ -#include -#include -#include -#include -#include - -#include -#endif - -#include -#include -#include -#include - -namespace ruy { - -// Linux-specific. Not ARM-specific. -#ifdef __linux__ -class PerfEvent { - public: - PerfEvent(std::uint32_t type, std::uint64_t config) { - perf_event_attr pe; - memset(&pe, 0, sizeof(pe)); - pe.size = sizeof(pe); - pe.type = type; - pe.config = config; - pe.disabled = 1; - pe.exclude_kernel = 1; - pe.exclude_hv = 1; - pe.inherit = 1; - fd_ = syscall(__NR_perf_event_open, &pe, 0, -1, -1, 0); - if (fd_ == -1) { - fprintf(stderr, "perf_event_open failed for config 0x%lx\n", - static_cast(config)); - // abort(); - } - } - - ~PerfEvent() { - RUY_CHECK(!started_); - close(fd_); - } - - void Start() { - RUY_CHECK(!started_); - started_ = true; - ioctl(fd_, PERF_EVENT_IOC_RESET, 0); - ioctl(fd_, PERF_EVENT_IOC_ENABLE, 0); - count_at_start_ = Read(); - } - - void Stop() { - RUY_CHECK(started_); - started_ = false; - ioctl(fd_, PERF_EVENT_IOC_DISABLE, 0); - count_at_stop_ = Read(); - } - - std::int64_t Count() const { - RUY_CHECK(!started_); - return count_at_stop_ - count_at_start_; - } - - private: - std::int64_t Read() const { - std::int64_t count; - RUY_CHECK_NE(read(fd_, &count, sizeof(count)), -1); - return count; - } - std::int64_t count_at_start_ = -1; - std::int64_t count_at_stop_ = -1; - bool started_ = false; - int fd_ = -1; -}; -#else -// Placeholder implementation to at least compile outside of linux. -#define PERF_TYPE_RAW 0 -class PerfEvent { - public: - PerfEvent(std::uint32_t, std::uint64_t) {} - ~PerfEvent() {} - void Start() {} - void Stop() {} - std::int64_t Count() const { return 0; } -}; -#endif - -// ARM-specific. Query ARM PMU counters as Linux perf events using -// PERF_TYPE_RAW. -namespace arm_pmuv3 { - -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wunused-const-variable" - -// These event numbers are listed in the ARMv8 architecture reference manual. -constexpr std::uint16_t L1I_CACHE_REFILL = 0x01; -constexpr std::uint16_t L1I_TLB_REFILL = 0x02; -constexpr std::uint16_t L1D_CACHE_REFILL = 0x03; -constexpr std::uint16_t L1D_CACHE = 0x04; -constexpr std::uint16_t L1D_TLB_REFILL = 0x05; -constexpr std::uint16_t LD_RETIRED = 0x06; -constexpr std::uint16_t ST_RETIRED = 0x07; -constexpr std::uint16_t INST_RETIRED = 0x08; -constexpr std::uint16_t EXC_TAKEN = 0x09; -constexpr std::uint16_t EXC_RETURN = 0x0A; -constexpr std::uint16_t CID_WRITE_RETIRED = 0x0B; -constexpr std::uint16_t PC_WRITE_RETIRED = 0x0C; -constexpr std::uint16_t BR_IMMED_RETIRED = 0x0D; -constexpr std::uint16_t BR_RETURN_RETIRED = 0x0E; -constexpr std::uint16_t UNALIGNED_LDST_RETIRED = 0x0F; -constexpr std::uint16_t BR_MIS_PRED = 0x10; -constexpr std::uint16_t CPU_CYCLES = 0x11; -constexpr std::uint16_t BR_PRED = 0x12; -constexpr std::uint16_t MEM_ACCESS = 0x13; -constexpr std::uint16_t L1I_CACHE = 0x14; -constexpr std::uint16_t L1D_CACHE_WB = 0x15; -constexpr std::uint16_t L2D_CACHE = 0x16; -constexpr std::uint16_t L2D_CACHE_REFILL = 0x17; -constexpr std::uint16_t L2D_CACHE_WB = 0x18; -constexpr std::uint16_t BUS_ACCESS = 0x19; -constexpr std::uint16_t MEMORY_ERROR = 0x1A; -constexpr std::uint16_t INST_SPEC = 0x1B; -constexpr std::uint16_t TTBR_WRITE_RETIRED = 0x1C; -constexpr std::uint16_t BUS_CYCLES = 0x1D; -constexpr std::uint16_t CHAIN = 0x1E; -constexpr std::uint16_t L1D_CACHE_ALLOCATE = 0x1F; -constexpr std::uint16_t L2D_CACHE_ALLOCATE = 0x20; -constexpr std::uint16_t BR_RETIRED = 0x21; -constexpr std::uint16_t BR_MIS_PRED_RETIRED = 0x22; -constexpr std::uint16_t STALL_FRONTEND = 0x23; -constexpr std::uint16_t STALL_BACKEND = 0x24; -constexpr std::uint16_t L1D_TLB = 0x25; -constexpr std::uint16_t L1I_TLB = 0x26; -constexpr std::uint16_t L2I_CACHE = 0x27; -constexpr std::uint16_t L2I_CACHE_REFILL = 0x28; -constexpr std::uint16_t L3D_CACHE_ALLOCATE = 0x29; -constexpr std::uint16_t L3D_CACHE_REFILL = 0x2A; -constexpr std::uint16_t L3D_CACHE = 0x2B; -constexpr std::uint16_t L3D_CACHE_WB = 0x2C; -constexpr std::uint16_t L2D_TLB_REFILL = 0x2D; -constexpr std::uint16_t L2I_TLB_REFILL = 0x2E; -constexpr std::uint16_t L2D_TLB = 0x2F; -constexpr std::uint16_t L2I_TLB = 0x30; -constexpr std::uint16_t LL_CACHE = 0x32; -constexpr std::uint16_t LL_CACHE_MISS = 0x33; -constexpr std::uint16_t DTLB_WALK = 0x34; -constexpr std::uint16_t LL_CACHE_RD = 0x36; -constexpr std::uint16_t LL_CACHE_MISS_RD = 0x37; - -// Additional implementation-defined events found by googling around. -constexpr std::uint16_t L1D_CACHE_RD = 0x40; -constexpr std::uint16_t L1D_CACHE_REFILL_RD = 0x42; -constexpr std::uint16_t L1D_TLB_REFILL_RD = 0x4C; -constexpr std::uint16_t L1D_TLB_RD = 0x4E; -constexpr std::uint16_t L2D_CACHE_RD = 0x50; -constexpr std::uint16_t L2D_CACHE_REFILL_RD = 0x52; -constexpr std::uint16_t BUS_ACCESS_RD = 0x60; -constexpr std::uint16_t MEM_ACCESS_RD = 0x66; -constexpr std::uint16_t L3D_CACHE_RD = 0xA0; -constexpr std::uint16_t L3D_CACHE_REFILL_RD = 0xA2; - -#pragma GCC diagnostic pop - -}; // namespace arm_pmuv3 - -class PmuEventsPrivate { - public: - PmuEventsPrivate() - : l1d_cache_refill(PERF_TYPE_RAW, arm_pmuv3::L1D_CACHE_REFILL), - l2d_cache_refill(PERF_TYPE_RAW, arm_pmuv3::L2D_CACHE_REFILL), - l3d_cache_refill(PERF_TYPE_RAW, arm_pmuv3::L3D_CACHE_REFILL), - ll_cache_miss(PERF_TYPE_RAW, arm_pmuv3::LL_CACHE_MISS), - l1d_tlb_refill(PERF_TYPE_RAW, arm_pmuv3::L1D_TLB_REFILL), - l2d_tlb_refill(PERF_TYPE_RAW, arm_pmuv3::L2D_TLB_REFILL), - stall_frontend(PERF_TYPE_RAW, arm_pmuv3::STALL_FRONTEND), - stall_backend(PERF_TYPE_RAW, arm_pmuv3::STALL_BACKEND), - br_mis_pred(PERF_TYPE_RAW, arm_pmuv3::BR_MIS_PRED) {} - - private: - friend class PmuEvents; - PerfEvent l1d_cache_refill; - PerfEvent l2d_cache_refill; - PerfEvent l3d_cache_refill; - PerfEvent ll_cache_miss; - PerfEvent l1d_tlb_refill; - PerfEvent l2d_tlb_refill; - PerfEvent stall_frontend; - PerfEvent stall_backend; - PerfEvent br_mis_pred; -}; - -PmuEvents::PmuEvents() : priv(new PmuEventsPrivate) {} -PmuEvents::~PmuEvents() { delete priv; } - -void PmuEvents::StartRecording() { - priv->l1d_cache_refill.Start(); - priv->l2d_cache_refill.Start(); - priv->l3d_cache_refill.Start(); - priv->ll_cache_miss.Start(); - priv->l1d_tlb_refill.Start(); - priv->l2d_tlb_refill.Start(); - priv->stall_frontend.Start(); - priv->stall_backend.Start(); - priv->br_mis_pred.Start(); -} - -void PmuEvents::StopRecording() { - priv->l1d_cache_refill.Stop(); - priv->l2d_cache_refill.Stop(); - priv->l3d_cache_refill.Stop(); - priv->ll_cache_miss.Stop(); - priv->l1d_tlb_refill.Stop(); - priv->l2d_tlb_refill.Stop(); - priv->stall_frontend.Stop(); - priv->stall_backend.Stop(); - priv->br_mis_pred.Stop(); -} - -float PmuEvents::BranchMispredictionCount() const { - return static_cast(priv->br_mis_pred.Count()); -} - -float PmuEvents::FrontendStallCount() const { - return static_cast(priv->stall_frontend.Count()); -} - -float PmuEvents::BackendStallCount() const { - return static_cast(priv->stall_backend.Count()); -} - -float PmuEvents::L1RefillCount() const { - return static_cast(priv->l1d_cache_refill.Count()); -} - -float PmuEvents::L2RefillCount() const { - return static_cast(priv->l2d_cache_refill.Count()); -} - -float PmuEvents::L3RefillCount() const { - // Important: this was discovered in the context of the above experiments, - // which also tested the _RD variants of these counters. So it's possible that - // it's just not needed here with the default (non _RD) counters. - // - // Some CPUs implement LL_CACHE_MISS[_RD], some implement - // L3D_CACHE_REFILL[_RD]. It seems that either one of these two counters is - // zero, or they roughly both agree with each other. Therefore, taking the max - // of them is a reasonable way to get something more portable across various - // CPUs. - return static_cast( - std::max(priv->l3d_cache_refill.Count(), priv->ll_cache_miss.Count())); -} - -float PmuEvents::L1TLBRefillCount() const { - return static_cast(priv->l1d_tlb_refill.Count()); -} - -float PmuEvents::L2TLBRefillCount() const { - return static_cast(priv->l2d_tlb_refill.Count()); -} - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/pmu.h b/tensorflow/lite/experimental/ruy/ruy/pmu.h deleted file mode 100644 index 721c1d5f1cc..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/pmu.h +++ /dev/null @@ -1,44 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PMU_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PMU_H_ - -namespace ruy { - -class PmuEventsPrivate; - -class PmuEvents { - public: - PmuEvents(); - ~PmuEvents(); - void StartRecording(); - void StopRecording(); - float L1RefillCount() const; - float L2RefillCount() const; - float L3RefillCount() const; - float BranchMispredictionCount() const; - float FrontendStallCount() const; - float BackendStallCount() const; - float L1TLBRefillCount() const; - float L2TLBRefillCount() const; - - private: - PmuEventsPrivate* priv = nullptr; -}; - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PMU_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/prepack.h b/tensorflow/lite/experimental/ruy/ruy/prepack.h deleted file mode 100644 index 794b8df7b4d..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/prepack.h +++ /dev/null @@ -1,108 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Implementation of low-level pre-packing API. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PREPACK_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PREPACK_H_ - -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/context.h" -#include "tensorflow/lite/experimental/ruy/ruy/dispatch.h" -#include "tensorflow/lite/experimental/ruy/ruy/internal_matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" -#include "tensorflow/lite/experimental/ruy/ruy/side_pair.h" -#include "tensorflow/lite/experimental/ruy/ruy/spec.h" -#include "tensorflow/lite/experimental/ruy/ruy/trmul.h" -#include "tensorflow/lite/experimental/ruy/ruy/trmul_params.h" -#include "tensorflow/lite/experimental/ruy/ruy/tune.h" - -namespace ruy { - -template -void PrePackForMulInternal(const Matrix& lhs, - const Matrix& rhs, const Spec& spec, - Context* context, Matrix* dst, - SidePair prepacked, - std::function alloc_fn) { - profiler::ScopeLabel label("PrePackForMul"); - Path the_path = context->GetPathToTake(); - RUY_CHECK_NE(the_path, Path::kReference); - constexpr Path TrMulCompiledPaths = CompiledPaths & ~Path::kReference; - Matrix transposed_lhs(lhs); - Transpose(&transposed_lhs); - TrMulParams params; - CreateTrMulParams(transposed_lhs, rhs, spec, context, dst, - the_path, ¶ms); - - const SidePair origin{0, 0}; - const SidePair rounded_dims{params.packed[Side::kLhs].layout.cols, - params.packed[Side::kRhs].layout.cols}; - - Tuning tuning = context->GetMainThreadTuning(); - for (Side side : {Side::kLhs, Side::kRhs}) { - if (prepacked[side]) { - prepacked[side]->data_size = DataSize(params.packed[side]); - prepacked[side]->sums_size = SumsSize(params.packed[side]); - prepacked[side]->data = alloc_fn(prepacked[side]->data_size); - prepacked[side]->sums = alloc_fn(prepacked[side]->sums_size); - params.packed[side].data = prepacked[side]->data; - params.packed[side].sums = prepacked[side]->sums; - params.RunPack(side, tuning, origin[side], rounded_dims[side]); - } - } -} - -template -void MulWithPrepackedInternal(const Matrix& lhs, - const Matrix& rhs, const Spec& spec, - Context* context, Matrix* dst, - SidePair prepacked) { - profiler::ScopeLabel label("MulWithPrepacked"); - - EnforceLayoutSupport(lhs.layout, rhs.layout, dst->layout); - EnforceZeroPointSupport(lhs.zero_point, rhs.zero_point, - dst->zero_point); - - Path the_path = context->GetPathToTake(); - RUY_CHECK_NE(the_path, Path::kReference); - constexpr Path TrMulCompiledPaths = CompiledPaths & ~Path::kReference; - Matrix transposed_lhs(lhs); - Transpose(&transposed_lhs); - TrMulParams params; - CreateTrMulParams(transposed_lhs, rhs, spec, context, dst, - the_path, ¶ms); - - for (Side side : {Side::kLhs, Side::kRhs}) { - if (prepacked[side]) { - params.packed[side].data = prepacked[side]->data; - params.packed[side].sums = prepacked[side]->sums; - params.is_prepacked[side] = true; - } - } - - TrMul(¶ms, context); -} - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PREPACK_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/prepacked_cache.cc b/tensorflow/lite/experimental/ruy/ruy/prepacked_cache.cc deleted file mode 100644 index da683020169..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/prepacked_cache.cc +++ /dev/null @@ -1,86 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/prepacked_cache.h" - -#include - -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" - -namespace ruy { - -using CacheIterator = PrepackedCache::CacheIterator; - -// Looks for an entry with `key`. If found, update its time stamp. -CacheIterator PrepackedCache::FindAndUpdate(const CacheKey &key) { - auto itr = cache_.find(key); - // If found, update with new access time for this entry. - if (itr != cache_.end()) { - const TimePoint time = CacheNow(); - itr->second.second = time; - } - // std::move() is required in the MSVC STL when NDEBUG is not set, and has no - // effect in libc++. - return std::move(itr); // NOLINT -} - -void PrepackedCache::Insert(const CacheKey &key, - const PrepackedMatrix &matrix) { - // Calculate size of this new item. - const size_t size_bytes = matrix.data_size + matrix.sums_size; - - // While we are above the threshold of ejection, eject the LRU entry. - while (!cache_.empty() && - ((TotalSize() + size_bytes) > ejection_threshold_)) { - EjectOne(); - } - DoInsert(key, matrix); - cache_size_ += matrix.data_size + matrix.sums_size; -} - -void PrepackedCache::EjectOne() { - TimePoint oldest_time = CacheNow(); - auto oldest = cache_.begin(); - { - profiler::ScopeLabel label("PepackedCacheEjection"); - for (auto itr = cache_.begin(); itr != cache_.end(); ++itr) { - if (itr->second.second < oldest_time) { - oldest_time = itr->second.second; - oldest = itr; - } - } - } - PrepackedMatrix &pmatrix = oldest->second.first; - cache_size_ -= pmatrix.data_size; - cache_size_ -= pmatrix.sums_size; - allocator_.Free(pmatrix.data); - allocator_.Free(pmatrix.sums); - cache_.erase(oldest); -} - -void PrepackedCache::AllocatePrepackedMatrix(PrepackedMatrix *pmatrix) { - pmatrix->data = allocator_.Alloc(pmatrix->data_size); - pmatrix->sums = allocator_.Alloc(pmatrix->sums_size); -} - -void PrepackedCache::DoInsert(const CacheKey &key, - const PrepackedMatrix &matrix) { - const TimePoint t = CacheNow(); - const MatrixWithTimeStamp mts({matrix, t}); - cache_.insert({key, mts}); -} - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/prepacked_cache.h b/tensorflow/lite/experimental/ruy/ruy/prepacked_cache.h deleted file mode 100644 index f2ee15559c7..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/prepacked_cache.h +++ /dev/null @@ -1,130 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PREPACKED_CACHE_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PREPACKED_CACHE_H_ - -#include -#include -#include -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/allocator.h" -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/time.h" - -namespace ruy { - -namespace detail { - -// Tracks a set of blocks allocated from the underlying system allocator. -class SystemBlockAllocator { - public: - void *Alloc(std::ptrdiff_t num_bytes) { - void *p = detail::SystemAlignedAlloc(num_bytes); - blocks_.push_back(p); - return p; - } - - void Free(void *block) { - for (auto it = blocks_.begin(); it != blocks_.end(); ++it) { - if (*it == block) { - detail::SystemAlignedFree(block); - blocks_.erase(it); - return; - } - } - RUY_DCHECK(false); // Trying to free pointer we did not allocate. - } - - ~SystemBlockAllocator() { - for (void *block : blocks_) { - detail::SystemAlignedFree(block); - } - } - - private: - std::vector blocks_; -}; - -} // namespace detail - -enum CachePolicy { kNoCache, kCacheLHSOnNarrowMul }; - -// "Low effort" Least Recently Used Cache for Prepacked Matrices -// A cache mechanism for prepacked matrices that ejects oldest entries. -// The implementation is "low effort" in the following ways: -// - we just linearly search for the oldest entry when doing an ejection -// - the ejection policy is very simple: if the new size would be above the -// . threshold, we will eject entries until the size is below the threshold. -// Current use cases (RNNs with GEMV operations) indicate that ejection is rare -// and memory constraints are tight, so we devote no additional storage to the -// LRU mechanism and accept O(n) search to eject oldest entry. In practice, -// the number of total entries has not been shown to be large. -// This class is not thread safe. In Ruy, memory allocation for packed matrices -// is done in a single threaded context and the actual packing activity may -// be done in a multi-threaded context. -class PrepackedCache { - public: - static constexpr int kDefaultEjectionThresholdBytes = 1 << 28; - - using CacheKey = std::pair; - - using MatrixWithTimeStamp = std::pair; - - using CacheIterator = std::map::const_iterator; - - using AlignedAllocator = detail::AlignedAllocator; - - explicit PrepackedCache( - int32_t ejection_threshold = kDefaultEjectionThresholdBytes) - : ejection_threshold_(ejection_threshold), cache_size_(0) {} - - // Looks for an entry with `key`. If found, update its time stamp. - CacheIterator FindAndUpdate(const CacheKey &key); - - // Returns end iterator for internal cache. The iterator type is appropriate - // to use with `FindAndUpdate`. - CacheIterator cend() const { return cache_.end(); } - - // Returns the total size (in bytes) of data held in this cache. - int TotalSize() const { return cache_size_; } - - // All calls to get current TimePoints go through here. - // TODO(b/145625614) Profile timestamps on relevant models to see if - // this level of granularity is sufficient. CoarseNow is cheap so - // it would be nice to keep it. - TimePoint CacheNow() const { return CoarseNow(); } - - // Performs the memory allocation for the `data` and `sums` members of a - // PrepackedMatrix. - void AllocatePrepackedMatrix(PrepackedMatrix *pmatrix); - - // Adds the PrepackedMatrix to the cache, possibly ejecting other values. - void Insert(const CacheKey &key, const PrepackedMatrix &matrix); - - private: - void EjectOne(); - void DoInsert(const CacheKey &key, const PrepackedMatrix &matrix); - detail::SystemBlockAllocator allocator_; - std::map cache_; - const int32_t ejection_threshold_; - size_t cache_size_; -}; - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PREPACKED_CACHE_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/prepacked_cache_test.cc b/tensorflow/lite/experimental/ruy/ruy/prepacked_cache_test.cc deleted file mode 100644 index 453190a3b88..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/prepacked_cache_test.cc +++ /dev/null @@ -1,210 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/prepacked_cache.h" - -#include // NOLINT(build/c++11) - -#include -#include "tensorflow/lite/experimental/ruy/ruy/ruy.h" -#include "tensorflow/lite/experimental/ruy/ruy/time.h" - -namespace ruy { -namespace { - -TEST(PrepackedCacheTest, TestCacheEjection) { - // Create the cache. - PrepackedCache prepacked_cache(32); - // Allocate the prepacked matrix. - PrepackedMatrix mat1; - mat1.data_size = 16; - mat1.sums_size = 8; - prepacked_cache.AllocatePrepackedMatrix(&mat1); - auto cache_key1 = std::make_pair(nullptr, mat1.data); - prepacked_cache.Insert(cache_key1, mat1); - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - // Get a time point after the insertion into the cache. - TimePoint current = CoarseNow(); - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - PrepackedCache::CacheIterator itr = prepacked_cache.FindAndUpdate(cache_key1); - EXPECT_NE(itr, prepacked_cache.cend()); - // By finding mat1, we updated its timestamp. Verify that `current` is older - // than the time stamp now associated with mat1. - EXPECT_LT(current, itr->second.second); - PrepackedMatrix mat2; - mat2.data_size = 8; - mat2.sums_size = 4; - prepacked_cache.AllocatePrepackedMatrix(&mat2); - - auto cache_key2 = std::make_pair(nullptr, mat2.data); - prepacked_cache.Insert(cache_key2, mat2); - // The cache size was exceeded by inserting mat2. Ensure that mat1 was - // ejected. - EXPECT_EQ(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend()); -} - -TEST(PrepackedCacheTest, TestCacheBasic) { - // Create the cache. - PrepackedCache prepacked_cache(48); - // Allocate the prepacked matrix. - PrepackedMatrix mat1; - mat1.data_size = 16; - mat1.sums_size = 8; - prepacked_cache.AllocatePrepackedMatrix(&mat1); - - auto cache_key1 = std::make_pair(nullptr, mat1.data); - prepacked_cache.Insert(cache_key1, mat1); - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend()); - - PrepackedMatrix mat2; - mat2.data_size = 8; - mat2.sums_size = 4; - prepacked_cache.AllocatePrepackedMatrix(&mat2); - - auto cache_key2 = std::make_pair(nullptr, mat2.data); - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - prepacked_cache.Insert(cache_key2, mat2); - // The cache size was not exceeded by inserting mat2. Ensure that mat1 was not - // ejected. - EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend()); -} - -TEST(PrepackedCacheTest, TestCacheEjection2) { - // Create the cache. - PrepackedCache prepacked_cache(73); - // Allocate the prepacked matrix 1. - PrepackedMatrix mat1; - mat1.data_size = 16; - mat1.sums_size = 8; - prepacked_cache.AllocatePrepackedMatrix(&mat1); - auto cache_key1 = std::make_pair(nullptr, mat1.data); - prepacked_cache.Insert(cache_key1, mat1); - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - - // Allocate the prepacked matrix 2. - PrepackedMatrix mat2; - mat2.data_size = 16; - mat2.sums_size = 8; - prepacked_cache.AllocatePrepackedMatrix(&mat2); - auto cache_key2 = std::make_pair(nullptr, mat2.data); - prepacked_cache.Insert(cache_key2, mat2); - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - - // Allocate the prepacked matrix 3. - PrepackedMatrix mat31; - mat31.data_size = 16; - mat31.sums_size = 8; - prepacked_cache.AllocatePrepackedMatrix(&mat31); - auto cache_key3 = std::make_pair(nullptr, mat31.data); - prepacked_cache.Insert(cache_key3, mat31); - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - - // The next insertion will cause the cache size to go over the ejection - // threshold. Touch matrix 1 and matrix 3 to make matrix 2 the oldest - EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend()); - EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key3), prepacked_cache.cend()); - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - - // Allocate the prepacked matrix 4. - PrepackedMatrix mat4; - mat4.data_size = 16; - mat4.sums_size = 8; - prepacked_cache.AllocatePrepackedMatrix(&mat4); - auto cache_key4 = std::make_pair(nullptr, mat4.data); - prepacked_cache.Insert(cache_key4, mat4); - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - - // Ensure that mat2 was ejected, but mat1, mat3, and mat4 were not. - EXPECT_EQ(prepacked_cache.FindAndUpdate(cache_key2), prepacked_cache.cend()); - EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key3), prepacked_cache.cend()); - EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend()); - EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key4), prepacked_cache.cend()); -} - -TEST(PrepackedCacheTest, TestCacheOnCacheable) { - // Create context and set the cache policy - ruy::Context context; - context.cache_policy = ruy::kCacheLHSOnNarrowMul; - PrepackedCache* cache = context.GetPrepackedCache(); - EXPECT_EQ(cache->TotalSize(), 0); - - const float lhs_data[] = {1, 2, 3, 4}; - const float rhs_data[] = {1, 2}; - float dst_data[4]; - - ruy::Matrix lhs; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout); - lhs.data = lhs_data; - ruy::Matrix rhs; - ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &rhs.layout); - rhs.data = rhs_data; - ruy::Matrix dst; - ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &dst.layout); - dst.data = dst_data; - - ruy::BasicSpec spec; - // Perform the multiplication and confirm no caching occurred. - ruy::Mul(lhs, rhs, spec, &context, &dst); - EXPECT_EQ(cache->TotalSize(), 0); - - // Set cacheable for the LHS, repeat the multiplication, and see - // that caching did occur. - lhs.cacheable = true; - ruy::Mul(lhs, rhs, spec, &context, &dst); - EXPECT_NE(cache->TotalSize(), 0); -} - -TEST(PrepackedCacheTest, TestClearCache) { - // Create context and set the cache policy - ruy::Context context; - context.cache_policy = ruy::kCacheLHSOnNarrowMul; - PrepackedCache* cache = context.GetPrepackedCache(); - EXPECT_EQ(cache->TotalSize(), 0); - - const float lhs_data[] = {1, 2, 3, 4}; - const float rhs_data[] = {1, 2}; - float dst_data[4]; - - ruy::Matrix lhs; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout); - lhs.data = lhs_data; - ruy::Matrix rhs; - ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &rhs.layout); - rhs.data = rhs_data; - ruy::Matrix dst; - ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &dst.layout); - dst.data = dst_data; - - ruy::BasicSpec spec; - // Set cacheable for the LHS and see that caching occurs. - lhs.cacheable = true; - ruy::Mul(lhs, rhs, spec, &context, &dst); - EXPECT_NE(cache->TotalSize(), 0); - - // Clear the cache via the Context. - context.ClearPrepackedCache(); - // Verify that the cache is now empty. - cache = context.GetPrepackedCache(); - EXPECT_EQ(cache->TotalSize(), 0); -} - -} // namespace -} // namespace ruy - -int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/tensorflow/lite/experimental/ruy/ruy/profiler/BUILD b/tensorflow/lite/experimental/ruy/ruy/profiler/BUILD deleted file mode 100644 index 5e9d9bd3bae..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/profiler/BUILD +++ /dev/null @@ -1,60 +0,0 @@ -# A minimalistic profiler sampling pseudo-stacks - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -config_setting( - name = "ruy_profiler", - define_values = {"ruy_profiler": "true"}, -) - -# Used to build TFLite Micro RUY dependency for embedded targets outside of the -# RUY source tree. -filegroup( - name = "ruy_instrumentation_header", - srcs = ["instrumentation.h"], - visibility = ["//visibility:public"], -) - -cc_library( - name = "instrumentation", - srcs = ["instrumentation.cc"], - hdrs = ["instrumentation.h"], - defines = select({ - ":ruy_profiler": ["RUY_PROFILER"], - "//conditions:default": [], - }), -) - -cc_library( - name = "profiler", - srcs = [ - "profiler.cc", - "treeview.cc", - ], - hdrs = [ - "profiler.h", - "treeview.h", - ], - deps = [":instrumentation"], -) - -cc_library( - name = "test_instrumented_library", - testonly = True, - srcs = ["test_instrumented_library.cc"], - hdrs = ["test_instrumented_library.h"], - deps = [":instrumentation"], -) - -cc_test( - name = "test", - srcs = ["test.cc"], - deps = [ - ":profiler", - ":test_instrumented_library", - "@com_google_googletest//:gtest", - ], -) diff --git a/tensorflow/lite/experimental/ruy/ruy/profiler/README.md b/tensorflow/lite/experimental/ruy/ruy/profiler/README.md deleted file mode 100644 index 8d7902566b3..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/profiler/README.md +++ /dev/null @@ -1,149 +0,0 @@ -# A minimalistic profiler sampling pseudo-stacks - -## Overview - -The present directory is the "ruy profiler". As a time profiler, it allows to -measure where code is spending time. - -Contrary to most typical profilers, what it samples is not real call stacks, but -"pseudo-stacks" which are just simple data structures constructed from within -the program being profiled. Using this profiler requires manually instrumenting -code to construct such pseudo-stack information. - -Another unusual characteristic of this profiler is that it uses only the C++11 -standard library. It does not use any non-portable feature, in particular it -does not rely on signal handlers. The sampling is performed by a thread, the -"profiler thread". - -A discussion of pros/cons of this approach is appended below. - -## How to use this profiler - -### How to instrument code - -An example of instrumented code is given in `test_instrumented_library.cc`. - -Code is instrumented by constructing `ScopeLabel` objects. These are RAII -helpers, ensuring that the thread pseudo-stack contains the label during their -lifetime. In the most common use case, one would construct such an object at the -start of a function, so that its scope is the function scope and it allows to -measure how much time is spent in this function. - -```c++ -#include "ruy/profiler/instrumentation.h" - -... - -void SomeFunction() { - ruy::profiling::ScopeLabel function_label("SomeFunction"); - ... do something ... -} -``` - -A `ScopeLabel` may however have any scope, for instance: - -```c++ -if (some_case) { - ruy::profiling::ScopeLabel extra_work_label("Some more work"); - ... do some more work ... -} -``` - -The string passed to the `ScopeLabel` constructor must be just a pointer to a -literal string (a `char*` pointer). The profiler will assume that these pointers -stay valid until the profile is finalized. - -However, that literal string may be a `printf` format string, and labels may -have up to 4 parameters, of type `int`. For example: - -```c++ -void SomeFunction(int size) { - ruy::profiling::ScopeLabel function_label("SomeFunction (size=%d)", size); - -``` - -### How to run the profiler - -Profiling instrumentation is a no-op unless the preprocessor token -`RUY_PROFILER` is defined, so defining it is the first step when actually -profiling. When building with Bazel, the preferred way to enable that is to pass -this flag on the Bazel command line: - -``` ---define=ruy_profiler=true -``` - -To actually profile a code scope, it is enough to construct a `ScopeProfile` -object, also a RAII helper. It will start the profiler on construction, and on -destruction it will terminate the profiler and report the profile treeview on -standard output by default. Example: - -```c++ -void SomeProfiledBenchmark() { - ruy::profiling::ScopeProfile profile; - - CallSomeInstrumentedCode(); -} -``` - -An example is provided by the `:test` target in the present directory. Run it -with `--define=ruy_profiler=true` as explained above: - -``` -bazel run -c opt \ - --define=ruy_profiler=true \ - //tensorflow/lite/experimental/ruy/profiler:test -``` - -The default behavior dumping the treeview on standard output may be overridden -by passing a pointer to a `TreeView` object to the `ScopeProfile` constructor. -This causes the tree-view to be stored in that `TreeView` object, where it may -be accessed an manipulated using the functions declared in `treeview.h`. The -aforementioned `:test` provides examples for doing so. - -## Advantages and inconvenients - -Compared to a traditional profiler, e.g. Linux's "perf", the present kind of -profiler has the following inconvenients: - -* Requires manual instrumentation of code being profiled. -* Substantial overhead, modifying the performance characteristics of the code - being measured. -* Questionable accuracy. - -But also the following advantages: - -* Profiling can be driven from within a benchmark program, allowing the entire - profiling procedure to be a single command line. -* Not relying on symbol information removes removes exposure to toolchain - details and means less hassle in some build environments, especially - embedded/mobile (single command line to run and profile, no symbols files - required). -* Fully portable (all of this is standard C++11). -* Fully testable (see `:test`). Profiling becomes just another feature of the - code like any other. -* Customized instrumentation can result in easier to read treeviews (only - relevant functions, and custom labels may be more readable than function - names). -* Parametrized/formatted labels allow to do things that aren't possible with - call-stack-sampling profilers. For example, break down a profile where much - time is being spent in matrix multiplications, by the various matrix - multiplication shapes involved. - -The philosophy underlying this profiler is that software performance depends on -software engineers profiling often, and a key factor limiting that in practice -is the difficulty or cumbersome aspects of profiling with more serious profilers -such as Linux's "perf", especially in embedded/mobile development: multiple -command lines are involved to copy symbol files to devices, retrieve profile -data from the device, etc. In that context, it is useful to make profiling as -easy as benchmarking, even on embedded targets, even if the price to pay for -that is lower accuracy, higher overhead, and some intrusive instrumentation -requirement. - -Another key aspect determining what profiling approach is suitable for a given -context, is whether one already has a-priori knowledge of where much of the time -is likely being spent. When one has such a-priori knowledge, it is feasible to -instrument the known possibly-critical code as per the present approach. On the -other hand, in situations where one doesn't have such a-priori knowledge, a real -profiler such as Linux's "perf" allows to right away get a profile of real -stacks, from just symbol information generated by the toolchain. diff --git a/tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.cc b/tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.cc deleted file mode 100644 index b7c330c04bd..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.cc +++ /dev/null @@ -1,130 +0,0 @@ -/* Copyright 2020 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" - -#ifdef RUY_PROFILER - -namespace ruy { -namespace profiler { - -void Label::operator=(const Label& other) { - format_ = other.format_; - args_count_ = other.args_count_; - for (int i = 0; i < args_count_; i++) { - args_[i] = other.args_[i]; - } -} - -bool Label::operator==(const Label& other) const { - if (std::string(format_) != std::string(other.format_)) { - return false; - } - if (args_count_ != other.args_count_) { - return false; - } - for (int i = 0; i < args_count_; i++) { - if (args_[i] != other.args_[i]) { - return false; - } - } - return true; -} - -std::string Label::Formatted() const { - static constexpr int kBufSize = 256; - char buf[kBufSize]; - if (args_count_ == 0) { - return format_; - } - if (args_count_ == 1) { - snprintf(buf, kBufSize, format_, args_[0]); - } else if (args_count_ == 2) { - snprintf(buf, kBufSize, format_, args_[0], args_[1]); - } else if (args_count_ == 3) { - snprintf(buf, kBufSize, format_, args_[0], args_[1], args_[2]); - } else if (args_count_ == 4) { - snprintf(buf, kBufSize, format_, args_[0], args_[1], args_[2], args_[3]); - } else { - abort(); - } - return buf; -} - -namespace detail { - -std::mutex* GlobalsMutex() { - static std::mutex mutex; - return &mutex; -} - -bool& GlobalIsProfilerRunning() { - static bool b; - return b; -} - -std::vector* GlobalAllThreadStacks() { - static std::vector all_stacks; - return &all_stacks; -} - -ThreadStack* ThreadLocalThreadStack() { - thread_local static ThreadStack thread_stack; - return &thread_stack; -} - -ThreadStack::ThreadStack() { - std::lock_guard lock(*GlobalsMutex()); - static std::uint32_t global_next_thread_stack_id = 0; - stack_.id = global_next_thread_stack_id++; - GlobalAllThreadStacks()->push_back(this); -} - -ThreadStack::~ThreadStack() { - std::lock_guard lock(*GlobalsMutex()); - std::vector* all_stacks = GlobalAllThreadStacks(); - for (auto it = all_stacks->begin(); it != all_stacks->end(); ++it) { - if (*it == this) { - all_stacks->erase(it); - return; - } - } -} -int GetBufferSize(const Stack& stack) { - return sizeof(stack.id) + sizeof(stack.size) + - stack.size * sizeof(stack.labels[0]); -} - -void CopyToBuffer(const Stack& stack, char* dst) { - memcpy(dst, &stack.id, sizeof(stack.id)); - dst += sizeof(stack.id); - memcpy(dst, &stack.size, sizeof(stack.size)); - dst += sizeof(stack.size); - memcpy(dst, stack.labels, stack.size * sizeof(stack.labels[0])); -} - -void ReadFromBuffer(const char* src, Stack* stack) { - memcpy(&stack->id, src, sizeof(stack->id)); - src += sizeof(stack->id); - memcpy(&stack->size, src, sizeof(stack->size)); - src += sizeof(stack->size); - memcpy(stack->labels, src, stack->size * sizeof(stack->labels[0])); -} - -} // namespace detail -} // namespace profiler -} // namespace ruy - -#endif diff --git a/tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h b/tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h deleted file mode 100644 index a9046d465af..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h +++ /dev/null @@ -1,203 +0,0 @@ -/* Copyright 2020 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_INSTRUMENTATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_INSTRUMENTATION_H_ - -#ifdef RUY_PROFILER -#include -#include -#include -#endif - -namespace ruy { -namespace profiler { - -#ifdef RUY_PROFILER - -// A label is how a code scope is annotated to appear in profiles. -// The stacks that are sampled by the profiler are stacks of such labels. -// A label consists of a literal string, plus optional integer arguments. -class Label { - public: - Label() {} - template - explicit Label(Args... args) { - Set(args...); - } - void Set(const char* format) { - format_ = format; - args_count_ = 0; - } - template - void Set(const char* format, Args... args) { - format_ = format; - args_count_ = sizeof...(args); - SetArgs(0, args...); - } - - void operator=(const Label& other); - - bool operator==(const Label& other) const; - - std::string Formatted() const; - const char* format() const { return format_; } - - private: - void SetArgs(int position, int arg0) { args_[position] = arg0; } - - template - void SetArgs(int position, int arg0, Args... args) { - SetArgs(position, arg0); - SetArgs(position + 1, args...); - } - - static constexpr int kMaxArgs = 4; - const char* format_ = nullptr; - int args_count_ = 0; - int args_[kMaxArgs]; -}; - -namespace detail { - -// Forward-declaration, see class ThreadStack below. -class ThreadStack; - -bool& GlobalIsProfilerRunning(); - -// Returns the global vector of pointers to all stacks, there being one stack -// per thread executing instrumented code. -std::vector* GlobalAllThreadStacks(); - -// Returns the mutex to be locked around any access to GlobalAllThreadStacks(). -std::mutex* GlobalsMutex(); - -// Returns the thread-local stack, specific to the current thread. -ThreadStack* ThreadLocalThreadStack(); - -// This 'stack' is what may be more appropriately called a 'pseudostack': -// It contains Label entries that are 'manually' entered by instrumentation -// code. It's unrelated to real call stacks. -struct Stack { - std::uint32_t id = 0; - static constexpr int kMaxSize = 64; - int size = 0; - Label labels[kMaxSize]; -}; - -// Returns the buffer byte size required by CopyToSample. -int GetBufferSize(const Stack& stack); - -// Copies this Stack into a byte buffer, called a 'sample'. -void CopyToBuffer(const Stack& stack, char* dst); - -// Populates this Stack from an existing sample buffer, typically -// produced by CopyToSample. -void ReadFromBuffer(const char* src, Stack* stack); - -// ThreadStack is meant to be used as a thread-local singleton, assigning to -// each thread a Stack object holding its pseudo-stack of profile labels, -// plus a mutex allowing to synchronize accesses to this pseudo-stack between -// this thread and a possible profiler thread sampling it. -class ThreadStack { - public: - ThreadStack(); - ~ThreadStack(); - - const Stack& stack() const { return stack_; } - - // Returns the mutex to lock around any access to this stack. Each stack is - // accessed by potentially two threads: the thread that it belongs to - // (which calls Push and Pop) and the profiler thread during profiling - // (which calls CopyToSample). - std::mutex& Mutex() const { return mutex_; } - - // Pushes a new label on the top of this Stack. - template - void Push(Args... args) { - // This mutex locking is needed to guard against race conditions as both - // the current thread and the profiler thread may be concurrently accessing - // this stack. In addition to that, this mutex locking also serves the other - // purpose of acting as a barrier (of compiler code reordering, of runtime - // CPU instruction reordering, and of memory access reordering), which - // gives a measure of correctness to this profiler. The downside is some - // latency. As this lock will be uncontended most of the times, the cost - // should be roughly that of an sequentially-consistent atomic access, - // comparable to an access to the level of CPU data cache that is shared - // among all cores, typically 60 cycles on current ARM CPUs, plus side - // effects from barrier instructions. - std::lock_guard lock(mutex_); - // Avoid overrunning the stack, even in 'release' builds. This profiling - // instrumentation code should not ship in release builds anyway, the - // overhead of this check is negligible, and overrunning a stack array would - // be bad. - if (stack_.size >= Stack::kMaxSize) { - abort(); - } - stack_.labels[stack_.size++].Set(args...); - } - - // Pops the top-most label from this Stack. - void Pop() { - // See the comment in Push about this lock. While it would be tempting to - // try to remove this lock and just atomically decrement size_ with a - // store-release, that would not necessarily be a substitute for all of the - // purposes that this lock serves, or if it was done carefully to serve all - // of the same purposes, then that wouldn't be faster than this (mostly - // uncontended) lock. - std::lock_guard lock(mutex_); - stack_.size--; - } - - private: - mutable std::mutex mutex_; - Stack stack_; -}; - -} // namespace detail - -// RAII user-facing way to construct Labels associated with their life scope -// and get them pushed to / popped from the current thread stack. -class ScopeLabel { - public: - template - ScopeLabel(Args... args) : thread_stack_(detail::ThreadLocalThreadStack()) { - thread_stack_->Push(args...); - } - - ~ScopeLabel() { thread_stack_->Pop(); } - - private: - detail::ThreadStack* thread_stack_; -}; - -#else // no RUY_PROFILER - -class ScopeLabel { - public: - template - explicit ScopeLabel(Args...) {} - - // This destructor is needed to consistently silence clang's -Wunused-variable - // which seems to trigger semi-randomly. - ~ScopeLabel() {} -}; - -#endif - -} // namespace profiler -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_INSTRUMENTATION_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/profiler/profiler.cc b/tensorflow/lite/experimental/ruy/ruy/profiler/profiler.cc deleted file mode 100644 index c5ff598ee2b..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/profiler/profiler.cc +++ /dev/null @@ -1,109 +0,0 @@ -/* Copyright 2020 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/profiler/profiler.h" - -#ifdef RUY_PROFILER -#include -#include // NOLINT -#include -#include -#include // NOLINT -#include -#endif - -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/treeview.h" - -namespace ruy { -namespace profiler { - -#ifdef RUY_PROFILER - -ScopeProfile::ScopeProfile() { Start(); } -ScopeProfile::ScopeProfile(bool enable) { - if (enable) { - Start(); - } -} -ScopeProfile::~ScopeProfile() { - if (!thread_) { - return; - } - finishing_.store(true); - thread_->join(); - Finish(); -} - -void ScopeProfile::Start() { - { - std::lock_guard lock(*detail::GlobalsMutex()); - if (detail::GlobalIsProfilerRunning()) { - fprintf(stderr, "FATAL: profiler already running!\n"); - abort(); - } - detail::GlobalIsProfilerRunning() = true; - } - finishing_ = false; - thread_.reset(new std::thread(&ScopeProfile::ThreadFunc, this)); -} - -void ScopeProfile::ThreadFunc() { - while (!finishing_.load()) { - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - std::lock_guard lock(*detail::GlobalsMutex()); - auto* thread_stacks = detail::GlobalAllThreadStacks(); - for (detail::ThreadStack* thread_stack : *thread_stacks) { - Sample(*thread_stack); - } - } -} - -void ScopeProfile::Sample(const detail::ThreadStack& thread_stack) { - std::lock_guard lock(thread_stack.Mutex()); - // Drop empty stacks. - // This ensures that profiles aren't polluted by uninteresting threads. - if (thread_stack.stack().size == 0) { - return; - } - int sample_size = detail::GetBufferSize(thread_stack.stack()); - int old_buf_size = samples_buf_.size(); - samples_buf_.resize(old_buf_size + sample_size); - detail::CopyToBuffer(thread_stack.stack(), - samples_buf_.data() + old_buf_size); -} - -void ScopeProfile::Finish() { - { - std::lock_guard lock(*detail::GlobalsMutex()); - if (!detail::GlobalIsProfilerRunning()) { - fprintf(stderr, "FATAL: profiler is not running!\n"); - abort(); - } - detail::GlobalIsProfilerRunning() = false; - } - if (user_treeview_) { - user_treeview_->Populate(samples_buf_); - } else { - TreeView treeview; - treeview.Populate(samples_buf_); - Print(treeview); - } -} - -#endif // RUY_PROFILER - -} // namespace profiler -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/profiler/profiler.h b/tensorflow/lite/experimental/ruy/ruy/profiler/profiler.h deleted file mode 100644 index 19ef0deba0c..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/profiler/profiler.h +++ /dev/null @@ -1,106 +0,0 @@ -/* Copyright 2020 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_PROFILER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_PROFILER_H_ - -#include - -#ifdef RUY_PROFILER -#include -#include -#include -#include -#endif - -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/treeview.h" - -namespace ruy { -namespace profiler { - -#ifdef RUY_PROFILER - -// RAII user-facing way to create a profiler and let it profile a code scope, -// and print out an ASCII/MarkDown treeview upon leaving the scope. -class ScopeProfile { - public: - // Default constructor, unconditionally profiling. - ScopeProfile(); - - // Constructor allowing to choose at runtime whether to profile. - explicit ScopeProfile(bool enable); - - // Destructor. It's where the profile is reported. - ~ScopeProfile(); - - // See treeview_ member. - void SetUserTreeView(TreeView* treeview) { user_treeview_ = treeview; } - - private: - void Start(); - - // Thread entry point function for the profiler thread. This thread is - // created on construction. - void ThreadFunc(); - - // Record a stack as a sample. - void Sample(const detail::ThreadStack& stack); - - // Finalize the profile. Called on destruction. - // If user_treeview_ is non-null, it will receive the treeview. - // Otherwise the treeview will just be printed. - void Finish(); - - // Buffer where samples are recorded during profiling. - std::vector samples_buf_; - - // Used to synchronize thread termination. - std::atomic finishing_; - - // Underlying profiler thread, which will perform the sampling. - // This profiler approach relies on a thread rather than on signals. - std::unique_ptr thread_; - - // TreeView to populate upon destruction. If left null (the default), - // a temporary treeview will be used and dumped on stdout. The user - // may override that by passing their own TreeView object for other - // output options or to directly inspect the TreeView. - TreeView* user_treeview_ = nullptr; -}; - -#else // no RUY_PROFILER - -struct ScopeProfile { - ScopeProfile() { -#ifdef GEMMLOWP_PROFILING - fprintf( - stderr, - "\n\n\n**********\n\nWARNING:\n\nLooks like you defined " - "GEMMLOWP_PROFILING, but this code has been ported to the new ruy " - "profiler replacing the old gemmlowp profiler. You should now be " - "defining RUY_PROFILER and not GEMMLOWP_PROFILING. When building using " - "Bazel, just pass --define=ruy_profiler=true.\n\n**********\n\n\n"); -#endif - } - explicit ScopeProfile(bool) {} -}; - -#endif - -} // namespace profiler -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_PROFILER_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/profiler/test.cc b/tensorflow/lite/experimental/ruy/ruy/profiler/test.cc deleted file mode 100644 index feab967c87c..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/profiler/test.cc +++ /dev/null @@ -1,167 +0,0 @@ -/* Copyright 2020 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include - -#include -#include "tensorflow/lite/experimental/ruy/ruy/profiler/profiler.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/test_instrumented_library.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/treeview.h" - -namespace ruy { -namespace profiler { -namespace { - -void DoSomeMergeSort(int size) { - std::vector data(size); - - std::default_random_engine engine; - for (auto& val : data) { - val = engine(); - } - - MergeSort(size, data.data()); -} - -// The purpose of this basic test is to cover the basic path that will be taken -// by a majority of users, not inspecting treeviews but just implicitly printing -// them on stdout, and to have this test enabled even when RUY_PROFILER is not -// defined, so that we have coverage for the non-RUY_PROFILER case. -TEST(ProfilerTest, MergeSortSingleThreadBasicTestEvenWithoutProfiler) { - { - ScopeProfile profile; - DoSomeMergeSort(1 << 20); - } -} - -#ifdef RUY_PROFILER - -TEST(ProfilerTest, MergeSortSingleThread) { - TreeView treeview; - { - ScopeProfile profile; - profile.SetUserTreeView(&treeview); - DoSomeMergeSort(1 << 20); - } - Print(treeview); - EXPECT_EQ(treeview.thread_roots().size(), 1); - const auto& thread_root = *treeview.thread_roots().begin()->second; - EXPECT_EQ(DepthOfTreeBelow(thread_root), 22); - EXPECT_GE( - WeightBelowNodeMatchingUnformatted(thread_root, "Merging sorted halves"), - 0.1 * thread_root.weight); - EXPECT_GE(WeightBelowNodeMatchingFormatted( - thread_root, "MergeSortRecurse (level=20, size=1)"), - 0.01 * thread_root.weight); - - TreeView treeview_collapsed; - CollapseNodesMatchingUnformatted(treeview, 5, "MergeSort (size=%d)", - &treeview_collapsed); - Print(treeview_collapsed); - const auto& collapsed_thread_root = - *treeview_collapsed.thread_roots().begin()->second; - EXPECT_EQ(DepthOfTreeBelow(collapsed_thread_root), 6); - EXPECT_EQ( - WeightBelowNodeMatchingUnformatted(thread_root, "MergeSort (size=%d)"), - WeightBelowNodeMatchingUnformatted(collapsed_thread_root, - "MergeSort (size=%d)")); -} - -TEST(ProfilerTest, MemcpyFourThreads) { - TreeView treeview; - { - ScopeProfile profile; - profile.SetUserTreeView(&treeview); - std::vector> threads; - for (int i = 0; i < 4; i++) { - threads.emplace_back(new std::thread([i]() { - ScopeLabel thread_label("worker thread #%d", i); - std::this_thread::sleep_for(std::chrono::milliseconds(1000)); - ScopeLabel some_more_work_label("some more work"); - std::this_thread::sleep_for(std::chrono::milliseconds(1000)); - })); - } - for (int i = 0; i < 4; i++) { - threads[i]->join(); - } - } - Print(treeview); - // Since we cleared GlobalAllThreadStacks and the current thread hasn't - // created any ScopeLabel, only the 4 worker threads should be recorded. - EXPECT_EQ(treeview.thread_roots().size(), 4); - for (const auto& thread_root : treeview.thread_roots()) { - const TreeView::Node& root_node = *thread_root.second; - // The root node may have 1 or 2 children depending on whether there is - // an "[other]" child. - EXPECT_GE(root_node.children.size(), 1); - EXPECT_LE(root_node.children.size(), 2); - const TreeView::Node& child_node = *root_node.children[0]; - EXPECT_EQ(child_node.label.format(), "worker thread #%d"); - // There must be 2 children, since roughly half the time will be in - // "some more work" leaving the other half in "[other]". - EXPECT_EQ(child_node.children.size(), 2); - const TreeView::Node& child_child_node = *child_node.children[0]; - // Since we sample every millisecond and the threads run for >= 2000 - // milliseconds, the "thread func" label should get roughly 2000 samples. - // Not very rigorous, as we're depending on the profiler thread getting - // scheduled, so to avoid this test being flaky, we use a much more - // conservative value of 500, one quarter of that normal value 2000. - EXPECT_GE(child_node.weight, 500); - // Likewise, allow up to four times more than the normal value 2000. - EXPECT_LE(child_node.weight, 8000); - // Roughly half of time should be spent under the "some more work" label. - float some_more_work_percentage = - 100.f * child_child_node.weight / child_node.weight; - EXPECT_GE(some_more_work_percentage, 40.0f); - EXPECT_LE(some_more_work_percentage, 60.0f); - } -} - -TEST(ProfilerTest, OneThreadAfterAnother) { - TreeView treeview; - { - ScopeProfile profile; - profile.SetUserTreeView(&treeview); - { - std::thread thread([]() { - ScopeLabel thread_label("thread 0"); - std::this_thread::sleep_for(std::chrono::milliseconds(1000)); - }); - thread.join(); - } - { - std::thread thread([]() { - ScopeLabel thread_label("thread 1"); - std::this_thread::sleep_for(std::chrono::milliseconds(1000)); - }); - thread.join(); - } - } - Print(treeview); - EXPECT_EQ(treeview.thread_roots().size(), 2); -} - -#endif // RUY_PROFILER - -} // namespace -} // namespace profiler -} // namespace ruy - -int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/tensorflow/lite/experimental/ruy/ruy/profiler/test_instrumented_library.cc b/tensorflow/lite/experimental/ruy/ruy/profiler/test_instrumented_library.cc deleted file mode 100644 index e9b5929c9b7..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/profiler/test_instrumented_library.cc +++ /dev/null @@ -1,59 +0,0 @@ -/* Copyright 2020 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include - -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" - -namespace { - -void MergeSortRecurse(int level, int size, int* data, int* workspace) { - ruy::profiler::ScopeLabel function_label( - "MergeSortRecurse (level=%d, size=%d)", level, size); - if (size <= 1) { - return; - } - int half_size = size / 2; - MergeSortRecurse(level + 1, half_size, data, workspace); - MergeSortRecurse(level + 1, size - half_size, data + half_size, - workspace + half_size); - - ruy::profiler::ScopeLabel merging_sorted_halves_label( - "Merging sorted halves"); - int dst_index = 0; - int left_index = 0; - int right_index = half_size; - while (dst_index < size) { - int val; - if (left_index < half_size && - ((right_index >= size) || data[left_index] < data[right_index])) { - val = data[left_index++]; - } else { - val = data[right_index++]; - } - workspace[dst_index++] = val; - } - for (int i = 0; i < size; i++) { - data[i] = workspace[i]; - } -} - -} // namespace - -void MergeSort(int size, int* data) { - ruy::profiler::ScopeLabel function_label("MergeSort (size=%d)", size); - std::vector workspace(size); - MergeSortRecurse(0, size, data, workspace.data()); -} diff --git a/tensorflow/lite/experimental/ruy/ruy/profiler/treeview.cc b/tensorflow/lite/experimental/ruy/ruy/profiler/treeview.cc deleted file mode 100644 index 256d2a1106c..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/profiler/treeview.cc +++ /dev/null @@ -1,248 +0,0 @@ -/* Copyright 2020 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifdef RUY_PROFILER - -#include "tensorflow/lite/experimental/ruy/ruy/profiler/treeview.h" - -#include -#include -#include -#include -#include - -namespace ruy { -namespace profiler { - -namespace { - -void SortNode(TreeView::Node* node) { - using NodePtr = std::unique_ptr; - std::sort(node->children.begin(), node->children.end(), - [](const NodePtr& n1, const NodePtr& n2) { - return n1->weight > n2->weight; - }); - for (const auto& child : node->children) { - SortNode(child.get()); - } -} - -// Records a stack i.e. a sample in a treeview, by incrementing the weights -// of matching existing nodes and/or by creating new nodes as needed, -// recursively, below the given node. -void AddStack(const detail::Stack& stack, TreeView::Node* node, int level) { - node->weight++; - if (stack.size == level) { - return; - } - TreeView::Node* child_to_add_to = nullptr; - for (const auto& child : node->children) { - if (child->label == stack.labels[level]) { - child_to_add_to = child.get(); - break; - } - } - if (!child_to_add_to) { - child_to_add_to = node->children.emplace_back(new TreeView::Node).get(); - child_to_add_to->label = stack.labels[level]; - } - AddStack(stack, child_to_add_to, level + 1); -} - -// Recursively populates the treeview below the given node with 'other' -// entries documenting for each node the difference between its weight and the -// sum of its children's weight. -void AddOther(TreeView::Node* node) { - int top_level_children_weight = 0; - for (const auto& child : node->children) { - AddOther(child.get()); - top_level_children_weight += child->weight; - } - if (top_level_children_weight != 0 && - top_level_children_weight != node->weight) { - const auto& new_child = node->children.emplace_back(new TreeView::Node); - new_child->label = Label("[other]"); - new_child->weight = node->weight - top_level_children_weight; - } -} - -} // namespace - -void TreeView::Populate(const std::vector& samples_buf_) { - thread_roots_.clear(); - // Populate the treeview with regular nodes coming from samples. - const char* buf_ptr = samples_buf_.data(); - const char* const buf_ptr_end = buf_ptr + samples_buf_.size(); - while (buf_ptr < buf_ptr_end) { - detail::Stack stack; - detail::ReadFromBuffer(buf_ptr, &stack); - // Empty stacks should have been dropped during sampling. - assert(stack.size > 0); - buf_ptr += GetBufferSize(stack); - const int id = stack.id; - if (!thread_roots_[id]) { - thread_roots_[id].reset(new Node); - } - AddStack(stack, thread_roots_[id].get(), 0); - } - // Populate the treeview with additional 'other' nodes, sort, and set - // root labels. - for (const auto& thread_root : thread_roots_) { - std::uint32_t id = thread_root.first; - Node* root = thread_root.second.get(); - AddOther(root); - SortNode(root); - root->label.Set("Thread %x (%d samples)", id, root->weight); - } -} - -// Recursively prints the treeview below the given node. The 'root' node -// argument is only needed to compute weights ratios, with the root ratio -// as denominator. -void PrintTreeBelow(const TreeView::Node& node, const TreeView::Node& root, - int level) { - if (&node == &root) { - printf("%s\n\n", node.label.Formatted().c_str()); - } else { - for (int i = 1; i < level; i++) { - printf(" "); - } - printf("* %.2f%% %s\n", 100.0f * node.weight / root.weight, - node.label.Formatted().c_str()); - } - for (const auto& child : node.children) { - PrintTreeBelow(*child, root, level + 1); - } -} - -void Print(const TreeView& treeview) { - printf("\n"); - printf("Profile (%d threads):\n\n", - static_cast(treeview.thread_roots().size())); - for (const auto& thread_root : treeview.thread_roots()) { - const TreeView::Node& root = *thread_root.second; - PrintTreeBelow(root, root, 0); - printf("\n"); - } -} - -int DepthOfTreeBelow(const TreeView::Node& node) { - if (node.children.empty()) { - return 0; - } else { - int max_child_depth = 0; - for (const auto& child : node.children) { - max_child_depth = std::max(max_child_depth, DepthOfTreeBelow(*child)); - } - return 1 + max_child_depth; - } -} - -int WeightBelowNodeMatchingFunction( - const TreeView::Node& node, - const std::function& match) { - int weight = 0; - if (match(node.label)) { - weight += node.weight; - } - for (const auto& child : node.children) { - weight += WeightBelowNodeMatchingFunction(*child, match); - } - return weight; -} - -int WeightBelowNodeMatchingUnformatted(const TreeView::Node& node, - const std::string& format) { - return WeightBelowNodeMatchingFunction( - node, [&format](const Label& label) { return label.format() == format; }); -} - -int WeightBelowNodeMatchingFormatted(const TreeView::Node& node, - const std::string& formatted) { - return WeightBelowNodeMatchingFunction( - node, [&formatted](const Label& label) { - return label.Formatted() == formatted; - }); -} - -void CollapseNode(const TreeView::Node& node_in, int depth, - TreeView::Node* node_out) { - node_out->label = node_in.label; - node_out->weight = node_in.weight; - node_out->children.clear(); - if (depth > 0) { - for (const auto& child_in : node_in.children) { - auto* child_out = new TreeView::Node; - node_out->children.emplace_back(child_out); - CollapseNode(*child_in, depth - 1, child_out); - } - } -} - -void CollapseSubnodesMatchingFunction( - const TreeView::Node& node_in, int depth, - const std::function& match, TreeView::Node* node_out) { - if (match(node_in.label)) { - CollapseNode(node_in, depth, node_out); - } else { - node_out->label = node_in.label; - node_out->weight = node_in.weight; - node_out->children.clear(); - - for (const auto& child_in : node_in.children) { - auto* child_out = new TreeView::Node; - node_out->children.emplace_back(child_out); - CollapseSubnodesMatchingFunction(*child_in, depth, match, child_out); - } - } -} - -void CollapseNodesMatchingFunction( - const TreeView& treeview_in, int depth, - const std::function& match, TreeView* treeview_out) { - treeview_out->mutable_thread_roots()->clear(); - for (const auto& thread_root_in : treeview_in.thread_roots()) { - std::uint32_t id = thread_root_in.first; - const auto& root_in = *thread_root_in.second; - auto* root_out = new TreeView::Node; - treeview_out->mutable_thread_roots()->emplace(id, root_out); - CollapseSubnodesMatchingFunction(root_in, depth, match, root_out); - } -} - -void CollapseNodesMatchingUnformatted(const TreeView& treeview_in, int depth, - const std::string& format, - TreeView* treeview_out) { - CollapseNodesMatchingFunction( - treeview_in, depth, - [&format](const Label& label) { return label.format() == format; }, - treeview_out); -} - -void CollapseNodesMatchingFormatted(const TreeView& treeview_in, int depth, - const std::string& formatted, - TreeView* treeview_out) { - CollapseNodesMatchingFunction( - treeview_in, depth, - [&formatted](const Label& label) { - return label.Formatted() == formatted; - }, - treeview_out); -} - -} // namespace profiler -} // namespace ruy - -#endif // RUY_PROFILER diff --git a/tensorflow/lite/experimental/ruy/ruy/profiler/treeview.h b/tensorflow/lite/experimental/ruy/ruy/profiler/treeview.h deleted file mode 100644 index 7f48af5ece0..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/profiler/treeview.h +++ /dev/null @@ -1,130 +0,0 @@ -/* Copyright 2020 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TREEVIEW_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TREEVIEW_H_ - -#ifdef RUY_PROFILER - -#include -#include -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" - -namespace ruy { -namespace profiler { - -// A tree view of a profile. -class TreeView { - public: - struct Node { - std::vector> children; - Label label; - int weight = 0; - }; - - void Populate(const std::vector& samples_buf_); - - // Intentionally an *ordered* map so that threads are enumerated - // in an order that's consistent and typically putting the 'main thread' - // first. - using ThreadRootsMap = std::map>; - - const ThreadRootsMap& thread_roots() const { return thread_roots_; } - ThreadRootsMap* mutable_thread_roots() { return &thread_roots_; } - - private: - ThreadRootsMap thread_roots_; -}; - -/* Below are API functions for manipulating and printing treeviews. */ - -// Prints the treeview to stdout. -void Print(const TreeView& treeview); - -// Prints the treeview below the given node on stdout. -void PrintTreeBelow(const TreeView::Node& node); - -// Returns the tree depth below the given node. -int DepthOfTreeBelow(const TreeView::Node& node); - -// Returns the sum of weights of nodes below the given node and filtered by -// the `match` predicate. -int WeightBelowNodeMatchingFunction( - const TreeView::Node& node, const std::function& match); - -// Returns the sum of weights of nodes below the given node and whose -// unformatted label (i.e. raw format string) matches the given `format` string. -// -// This allows to aggregate nodes whose labels differ only by parameter values. -int WeightBelowNodeMatchingUnformatted(const TreeView::Node& node, - const std::string& format); - -// Returns the sum of weights of nodes below the given node and whose formatted -// label matches the `formatted` string. -// -// In the case of nodes with parametrized labels, this allows to count only -// nodes with specific parameter values. For that purpose, one may also instead -// use WeightBelowNodeMatchingFunction directly, with a `match` predicate -// comparing raw integer parameter values directly, instead of going through -// formatted strings. -int WeightBelowNodeMatchingFormatted(const TreeView::Node& node, - const std::string& formatted); - -// Produces a `node_out` that is a copy of `node_in` but with tree depth below -// it clamped at `depth`, with further subtrees aggregated into single leaf -// nodes. -void CollapseNode(const TreeView::Node& node_in, int depth, - TreeView::Node* node_out); - -// Calls CollapseNode with the given `depth` on every subnode filtered by the -// `match` predicate. Note that this does NOT limit the tree depth below -// `node_out` to `depth`, since each collapsed node below `node_out` may be -// arbitrarily far below it and `depth` is only used as the collapsing depth -// at that point. -void CollapseSubnodesMatchingFunction( - const TreeView::Node& node_in, int depth, - const std::function& match, TreeView::Node* node_out); - -// Calls CollapseNode with the given `depth` on every node filtered by the -// `match` predicate. Note that this does NOT limit the tree depth below -// `node_out` to `depth`, since each collapsed node below `node_out` may be -// arbitrarily far below it and `depth` is only used as the collapsing depth -// at that point. -void CollapseNodesMatchingFunction( - const TreeView& treeview_in, int depth, - const std::function& match, TreeView* treeview_out); - -// Special case of CollapseNodesMatchingFunction matching unformatted labels, -// i.e. raw format strings. -// See the comment on WeightBelowNodeMatchingUnformatted. -void CollapseNodesMatchingUnformatted(const TreeView& treeview_in, int depth, - const std::string& format, - TreeView* treeview_out); - -// Special case of CollapseNodesMatchingFunction matching formatted labels. -// See the comment on WeightBelowNodeMatchingFormatted. -void CollapseNodesMatchingFormatted(const TreeView& treeview_in, int depth, - const std::string& formatted, - TreeView* treeview_out); - -} // namespace profiler -} // namespace ruy - -#endif // RUY_PROFILER - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TREEVIEW_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/ruy.h b/tensorflow/lite/experimental/ruy/ruy/ruy.h deleted file mode 100644 index 783c410cf82..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/ruy.h +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This is the only Ruy header that users should #include. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_RUY_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_RUY_H_ - -#include "tensorflow/lite/experimental/ruy/ruy/context.h" -#include "tensorflow/lite/experimental/ruy/ruy/dispatch.h" -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/spec.h" - -namespace ruy { - -// Performs a multiplication of matrices. This is Ruy's only API entry point. -// Should be self-explanatory given the above documentation for each of Matrix, -// Spec and Context. -template -void Mul(const Matrix& lhs, const Matrix& rhs, - const Spec& spec, Context* context, Matrix* dst) { - DispatchMul( - lhs, rhs, spec, context, dst); -} - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_RUY_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/ruy_advanced.h b/tensorflow/lite/experimental/ruy/ruy/ruy_advanced.h deleted file mode 100644 index 0b24636ef06..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/ruy_advanced.h +++ /dev/null @@ -1,69 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_RUY_ADVANCED_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_RUY_ADVANCED_H_ - -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/context.h" -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/prepack.h" -#include "tensorflow/lite/experimental/ruy/ruy/side_pair.h" - -namespace ruy { - -// Low-level, explicit pre-packing API. -// -// The cost of packing an input matrix (either the LHS or RHS) is amortized -// across the non-depth dimension of the opposite input matrix. Thus, when the -// LHS has very few rows or the RHS has very few columns, the cost of packing -// the opposite input matrix can become significant. See pack.h for further -// information on packing. -// -// This file provides an API allowing a user to explicitly pack a matrix and -// reuse the pre-packed matrix, avoiding that cost. -// -// See example_prepack.cc for example usage. - -template -void PrePackForMul(const Matrix& lhs, const Matrix& rhs, - const Spec& spec, Context* context, Matrix* dst, - PrepackedMatrix* prepacked_lhs, - PrepackedMatrix* prepacked_rhs, - std::function alloc_fn) { - SidePair prepacked(prepacked_lhs, prepacked_rhs); - PrePackForMulInternal(lhs, rhs, spec, context, dst, prepacked, - alloc_fn); -} - -template -void MulWithPrepacked(const Matrix& lhs, - const Matrix& rhs, const Spec& spec, - Context* context, Matrix* dst, - PrepackedMatrix* prepacked_lhs, - PrepackedMatrix* prepacked_rhs) { - SidePair prepacked(prepacked_lhs, prepacked_rhs); - MulWithPrepackedInternal(lhs, rhs, spec, context, dst, - prepacked); -} - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_RUY_ADVANCED_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/ruy_test.bzl b/tensorflow/lite/experimental/ruy/ruy/ruy_test.bzl deleted file mode 100644 index ef7e8b1bb79..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/ruy_test.bzl +++ /dev/null @@ -1,34 +0,0 @@ -# Provides the ruy_test macro for type-parametrized tests. -"""ruy_test is a macro for building a test with multiple paths corresponding to tuples of types for LHS, RHS, accumulator and destination.""" - -def ruy_test(name, srcs, lhs_rhs_accum_dst, copts, tags = [], deps = None): - for (lhs, rhs, accum, dst) in lhs_rhs_accum_dst: - native.cc_test( - name = "%s_%s_%s_%s_%s" % (name, lhs, rhs, accum, dst), - srcs = srcs, - copts = copts + [ - "-DRUY_TEST_LHSSCALAR=%s" % lhs, - "-DRUY_TEST_RHSSCALAR=%s" % rhs, - "-DRUY_TEST_ACCUMSCALAR=%s" % accum, - "-DRUY_TEST_DSTSCALAR=%s" % dst, - ], - deps = deps, - tags = tags, - ) - -def ruy_benchmark(name, srcs, lhs_rhs_accum_dst, copts, deps = None): - tags = ["req_dep=//third_party/gemmlowp:profiler"] - for (lhs, rhs, accum, dst) in lhs_rhs_accum_dst: - native.cc_binary( - name = "%s_%s_%s_%s_%s" % (name, lhs, rhs, accum, dst), - testonly = True, - srcs = srcs, - copts = copts + [ - "-DRUY_TEST_LHSSCALAR=%s" % lhs, - "-DRUY_TEST_RHSSCALAR=%s" % rhs, - "-DRUY_TEST_ACCUMSCALAR=%s" % accum, - "-DRUY_TEST_DSTSCALAR=%s" % dst, - ], - deps = deps, - tags = tags, - ) diff --git a/tensorflow/lite/experimental/ruy/ruy/ruy_test_ext.bzl b/tensorflow/lite/experimental/ruy/ruy/ruy_test_ext.bzl deleted file mode 100644 index 5701fffa0f7..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/ruy_test_ext.bzl +++ /dev/null @@ -1,7 +0,0 @@ -"""Allows to specialize the ruy BUILD to availability of external libraries""" - -def ruy_test_ext_defines(): - return [] - -def ruy_test_ext_deps(): - return [] diff --git a/tensorflow/lite/experimental/ruy/ruy/side_pair.h b/tensorflow/lite/experimental/ruy/ruy/side_pair.h deleted file mode 100644 index a3210e27a53..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/side_pair.h +++ /dev/null @@ -1,64 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SIDE_PAIR_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SIDE_PAIR_H_ - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" - -namespace ruy { - -// Enumeration of the sides, i.e. the operands 'slots', in a matrix -// multiplication. The numerical values of these enumeration constants matter -// because these will be used as indices into the array underlying a SidePair. -enum class Side { - // Left-hand side - kLhs = 0, - // Right-hand side - kRhs = 1 -}; - -// SidePair is a pair container where the two elements are indexed by a Side -// enum. -template -class SidePair final { - public: - SidePair() {} - SidePair(const T& a, const T& b) : elem_{a, b} {} - const T& operator[](Side side) const { - const int index = static_cast(side); - // Technically this check is vacuous, since other values would be - // out-of-range for enum Side. - RUY_DCHECK(index == 0 || index == 1); - return elem_[index]; - } - - T& operator[](Side side) { - const int index = static_cast(side); - // Technically this check is vacuous, since other values would be - // out-of-range for enum Side. - RUY_DCHECK(index == 0 || index == 1); - return elem_[index]; - } - - private: - static_assert(static_cast(Side::kLhs) == 0, ""); - static_assert(static_cast(Side::kRhs) == 1, ""); - T elem_[2]; -}; - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SIDE_PAIR_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/size_util.h b/tensorflow/lite/experimental/ruy/ruy/size_util.h deleted file mode 100644 index 56dd095de85..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/size_util.h +++ /dev/null @@ -1,93 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SIZE_UTIL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SIZE_UTIL_H_ - -#include - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" - -#ifdef _WIN32 -#include -#endif - -namespace ruy { - -template -inline Integer floor_log2(Integer n) { - static_assert(std::is_integral::value, ""); - static_assert(std::is_signed::value, ""); - static_assert(sizeof(Integer) == 4 || sizeof(Integer) == 8, ""); - - RUY_DCHECK_GE(n, 1); -#ifdef _WIN32 - unsigned long result; // NOLINT[runtime/int] - if (sizeof(Integer) == 4) { - _BitScanReverse(&result, n); - } else { - _BitScanReverse64(&result, n); - } - return result; -#else - if (sizeof(Integer) == 4) { - return 31 - __builtin_clz(n); - } else { - return 63 - __builtin_clzll(n); - } -#endif -} - -template -Integer ceil_log2(Integer n) { - RUY_DCHECK_GE(n, 1); - return n == 1 ? 0 : floor_log2(n - 1) + 1; -} - -template -bool is_pot(Integer value) { - return (value > 0) && ((value & (value - 1)) == 0); -} - -template -Integer pot_log2(Integer n) { - RUY_DCHECK(is_pot(n)); - return floor_log2(n); -} - -template -Integer round_down_pot(Integer value) { - return static_cast(1) << floor_log2(value); -} - -template -Integer round_up_pot(Integer value) { - return static_cast(1) << ceil_log2(value); -} - -template -Integer round_down_pot(Integer value, Modulo modulo) { - RUY_DCHECK_EQ(modulo & (modulo - 1), 0); - return value & ~(modulo - 1); -} - -template -Integer round_up_pot(Integer value, Modulo modulo) { - return round_down_pot(value + modulo - 1, modulo); -} - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SIZE_UTIL_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/size_util_test.cc b/tensorflow/lite/experimental/ruy/ruy/size_util_test.cc deleted file mode 100644 index 442c31958cc..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/size_util_test.cc +++ /dev/null @@ -1,101 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/size_util.h" - -#include -#include -#include - -#include - -namespace ruy { -namespace { - -template -void SizeUtilTestValue(Integer value) { - if (value == 0) { - return; - } - - EXPECT_LE(0, floor_log2(value)); - EXPECT_LE(floor_log2(value), ceil_log2(value)); - EXPECT_LE(ceil_log2(value), 8 * sizeof(Integer)); - - if (is_pot(value)) { - EXPECT_EQ(floor_log2(value), ceil_log2(value)); - EXPECT_EQ(floor_log2(value), pot_log2(value)); - } else { - EXPECT_EQ(floor_log2(value) + 1, ceil_log2(value)); - } - EXPECT_EQ(value >> floor_log2(value), 1); - EXPECT_EQ(round_down_pot(value), static_cast(1) - << floor_log2(value)); - EXPECT_LE(round_down_pot(value), value); - EXPECT_GE(round_down_pot(value), value >> 1); - EXPECT_TRUE(is_pot(round_down_pot(value))); - - if (ceil_log2(value) < 8 * sizeof(Integer) - 1) { - EXPECT_EQ(value >> ceil_log2(value), is_pot(value) ? 1 : 0); - EXPECT_EQ(round_up_pot(value), static_cast(1) << ceil_log2(value)); - EXPECT_GE(round_up_pot(value), value); - EXPECT_LE(round_up_pot(value) >> 1, value); - EXPECT_TRUE(is_pot(round_up_pot(value))); - } - - for (std::uint8_t modulo : {1, 2, 8, 32, 128}) { - EXPECT_GE(value, round_down_pot(value, modulo)); - EXPECT_EQ(round_down_pot(value, modulo) % modulo, 0); - - if (value <= std::numeric_limits::max() - modulo) { - EXPECT_LE(value, round_up_pot(value, modulo)); - EXPECT_EQ(round_up_pot(value, modulo) % modulo, 0); - } - } -} - -template -void SizeUtilTest() { - for (int exponent = 0; exponent < 8 * sizeof(Integer) - 1; exponent++) { - const Integer pot = static_cast(1) << exponent; - SizeUtilTestValue(pot - 1); - SizeUtilTestValue(pot); - SizeUtilTestValue(pot + 1); - SizeUtilTestValue(pot + 12); - SizeUtilTestValue(pot + 123); - } - SizeUtilTestValue(std::numeric_limits::max() - 1); - SizeUtilTestValue(std::numeric_limits::max()); -} - -TEST(SizeUtilTest, Int) { SizeUtilTest(); } - -TEST(SizeUtilTest, Long) { SizeUtilTest(); } // NOLINT - -TEST(SizeUtilTest, LongLong) { SizeUtilTest(); } // NOLINT - -TEST(SizeUtilTest, Int32) { SizeUtilTest(); } - -TEST(SizeUtilTest, Int64) { SizeUtilTest(); } - -TEST(SizeUtilTest, Ptrdiff) { SizeUtilTest(); } - -} // namespace -} // namespace ruy - -int main(int argc, char **argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/tensorflow/lite/experimental/ruy/ruy/spec.h b/tensorflow/lite/experimental/ruy/ruy/spec.h deleted file mode 100644 index 584d90ea047..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/spec.h +++ /dev/null @@ -1,118 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SPEC_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SPEC_H_ - -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/cpu_cache_size.h" -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" - -namespace ruy { - -// Our 'general' loop structure (the default) involves multi-threading and -// complicated loops aiming to optimize cache-friendliness. One may opt out of -// this and pick the 'simple' loop structure instead, which only performs well -// for small matrix sizes and only allows using one thread, in exchange for -// smaller code size. -enum class LoopStructure { kGeneral, kSimple, kAuto }; - -// In general we allow zero_point's to have any Scalar value. This is called -// 'asymmetric' quantization. We do take advantage of the optimization -// opportunities when zero_points happen at runtime to be 'symmetric' (e.g. the -// int8 value 0 or the uint8 value 128), but we still generate code to handle -// the general asymmetric case. By choosing kSymmetric here, one opts out of -// this and supports only the symmetric case, in exchange for smaller code size. -enum class ZeroPointSupport { kGeneral, kSymmetric }; - -// In general we allow all Layout's, even if we may use slow paths for some -// kinds of layouts. By choosing kRCC, one may opt out of this and -// only keep support for the simplest and most efficient combination of -// Layout's, in exchange for smaller code size. The case covered by -// kRCC is where the storage orders are exactly the following: -// - LHS is RowMajor -// - RHS is ColMajor -// - Destination is ColMajor -enum class LayoutSupport { kGeneral, kRCC }; - -// A Spec describes all about a matrix multiplication operation that isn't -// encoded in the LHS, RHS and destination matrices. Some of that information -// is encoded as compile-time constants and types (for instance, the choice -// of accumulator type, AccumScalar). Some of that information is encoded as -// runtime values (for instance, the optional bias vector). -template -struct BasicSpec { - // Accumulator type. The type of accumulators used to compute the dot-products - // before being ultimately casted to the destination type. - using AccumScalar = tAccumScalar; - // The destination scalar type. - using DstScalar = tDstScalar; - // The bias vector data, if not null. - const AccumScalar* bias = nullptr; - // Only for non-floating-point cases. The fixed-point part (i.e. the mantissa) - // of the multiplier by which accumulators are multiplied before being casted - // to the destination type. - AccumScalar multiplier_fixedpoint = 0; - // Only for non-floating-point cases. The exponent part of the aforementioned - // multiplier. - int multiplier_exponent = 0; - // Per-channel variant of multiplier_fixedpoint. If not nullptr, this must - // point to a buffer of as many values as there are rows in the destination - // matrix. Each row of the destination matrix will use the corresponding - // buffer element instead of multiplier_fixedpoint. - const AccumScalar* multiplier_fixedpoint_perchannel = nullptr; - // Per-channel variant of multiplier_exponent. If not nullptr, this must - // point to a buffer of as many values as there are rows in the destination - // matrix. Each row of the destination matrix will use the corresponding - // buffer element instead of multiplier_exponent. - // - // Either none or both of multiplier_exponent_perchannel and - // multiplier_fixedpoint_perchannel must be nullptr. - const int* multiplier_exponent_perchannel = nullptr; - // min clamp bound of destination values. - DstScalar clamp_min = std::is_floating_point::value - ? -std::numeric_limits::infinity() - : std::numeric_limits::lowest(); - // max clamp bound of destination values. - DstScalar clamp_max = std::is_floating_point::value - ? std::numeric_limits::infinity() - : std::numeric_limits::max(); - // See above enum LoopStructure - static constexpr LoopStructure kLoopStructure = LoopStructure::kAuto; - // See above enum LayoutSupport - static constexpr LayoutSupport kLayoutSupport = LayoutSupport::kGeneral; - // See above enum ZeroPointSupport - static constexpr ZeroPointSupport kZeroPointSupport = - ZeroPointSupport::kGeneral; - // Testing-only, not meant to be used by actual users: - // Used for testing of various kernel layouts. - using StandardCppKernelLhsLayout = FixedKernelLayout; - using StandardCppKernelRhsLayout = FixedKernelLayout; - // Returns (a reasonable estimate of) the local CPU cache size. - // See ruy::LocalDataCacheSize() which returns some coarse, sane default for - // each CPU architecture. - // This may be overridden, either to provide more accurate/runtime values, - // or to test with other values to let testcases have more coverage. - static int local_data_cache_size() { return LocalDataCacheSize(); } - // Same as local_data_cache_size but for the total data cache size accessible - // to each CPU core. See ruy::SharedDataCacheSize(). - static int shared_data_cache_size() { return SharedDataCacheSize(); } -}; - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SPEC_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/test.h b/tensorflow/lite/experimental/ruy/ruy/test.h deleted file mode 100644 index 305b5a844fa..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/test.h +++ /dev/null @@ -1,2125 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TEST_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TEST_H_ - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include // IWYU pragma: export -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" // IWYU pragma: export -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/ruy/pmu.h" -#include "tensorflow/lite/experimental/ruy/ruy/ruy.h" -#include "tensorflow/lite/experimental/ruy/ruy/ruy_advanced.h" -#include "tensorflow/lite/experimental/ruy/ruy/spec.h" // IWYU pragma: export -#include "tensorflow/lite/experimental/ruy/ruy/time.h" - -#ifdef RUY_TEST_EXTERNAL_PATHS -#define EIGEN_USE_THREADS -#define EIGEN_USE_CUSTOM_THREAD_POOL -#include "third_party/eigen3/Eigen/Core" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "public/gemmlowp.h" -#include "third_party/lapack/blas.h" -#endif - -#ifdef RUY_PROFILER -#include "tensorflow/lite/experimental/ruy/ruy/profiler/profiler.h" -#endif - -namespace ruy { - -const float kClampRatio = 0.1f; - -enum class ExternalPath { kNone, kGemmlowp, kEigen, kEigenTensor, kOpenBlas }; - -inline std::vector* CoveredPaths() { - static std::vector covered_paths; - return &covered_paths; -} - -inline const char* PathName(Path path) { -#define RUY_PATHNAME_CASE(NAME) \ - case Path::NAME: \ - return #NAME; - switch (path) { - RUY_PATHNAME_CASE(kReference) - RUY_PATHNAME_CASE(kStandardCpp) -#if RUY_PLATFORM(NEON) - RUY_PATHNAME_CASE(kNeon) - RUY_PATHNAME_CASE(kNeonDotprod) -#elif RUY_PLATFORM(X86) - RUY_PATHNAME_CASE(kSse42) - RUY_PATHNAME_CASE(kAvx2) - RUY_PATHNAME_CASE(kAvx512) - RUY_PATHNAME_CASE(kAvxVnni) -#endif - default: - RUY_CHECK(false); - return nullptr; - } -#undef RUY_PATHNAME_CASE -} - -inline const char* TuningName(Tuning tuning) { -#define RUY_SUBPATHNAME_CASE(NAME) \ - case Tuning::NAME: \ - return #NAME; - switch (tuning) { - RUY_SUBPATHNAME_CASE(kInOrder) - RUY_SUBPATHNAME_CASE(kOutOfOrder) - default: - RUY_CHECK(false); - return nullptr; - } -#undef RUY_SUBPATHNAME_CASE -} - -inline const char* PathName(ExternalPath path) { -#define RUY_PATHNAME_CASE(NAME) \ - case ExternalPath::NAME: \ - return #NAME; - switch (path) { - RUY_PATHNAME_CASE(kGemmlowp) - RUY_PATHNAME_CASE(kEigen) - RUY_PATHNAME_CASE(kEigenTensor) - RUY_PATHNAME_CASE(kOpenBlas) - default: - RUY_CHECK(false); - return nullptr; - } -#undef RUY_PATHNAME_CASE -} - -inline std::ostream& operator<<(std::ostream& stream, Path path) { - return stream << PathName(path); -} - -inline std::ostream& operator<<(std::ostream& stream, - ExternalPath external_path) { - return stream << PathName(external_path); -} - -template -std::string Join(const ContainerType& container) { - if (container.empty()) { - return ""; - } - std::ostringstream stream; - auto it = container.begin(); - stream << *it++; - for (; it != container.end(); ++it) { - stream << ", "; - stream << *it; - } - return stream.str(); -} - -struct LogCoveredPathsOnDestruction final { - ~LogCoveredPathsOnDestruction() { - std::cerr << "Covered paths: " << Join(*CoveredPaths()) << std::endl; - - // When testing on ARM64 ChromiumOS emulator, make sure that we covered - // the dotprod path. We're getting such coverage at the moment thanks to - // using a sufficiently recent emulator, and we don't want to regress that. -#if RUY_PLATFORM(ARM_64) && defined RUY_TESTING_ON_CHROMIUMOS - bool found_dotprod = false; - for (const std::string& covered_path : *CoveredPaths()) { - if (covered_path == "kNeonDotprod") { - found_dotprod = true; - } - } - if (!found_dotprod) { - std::cerr - << "Error: we haven't tested the kNeonDotprod path as we should " - "have. At the moment, this is required on ChromiumOS as this is " - "what we run emulator tests in, that currently supports " - "dot-product " - "instructions, and we care very much about not regressing that. " - "If this test was run in an emulator, please upgrade to a newer " - "emulator version. If this test was run on an actual device, and " - "you need to be able to run ruy tests on devices not supporting " - "dot-product instructions, get in touch with us.\n" - << std::endl; - abort(); - } -#endif - } - static void Singleton() { static LogCoveredPathsOnDestruction singleton; } -}; - -enum class RandomRange { - kGeneral, - kAvoidMinValue, - kOffCenterAvoidMinValue, - kReasonableSrcZeroPoint, - kReasonableDstZeroPoint, - kBias -}; - -template ::value> -struct RandomRangeBounds {}; - -template -struct RandomRangeBounds { - static Scalar GetMinBound(RandomRange range) { - switch (range) { - case RandomRange::kGeneral: - return -1; - case RandomRange::kAvoidMinValue: - return -1; - case RandomRange::kOffCenterAvoidMinValue: - return -1; - case RandomRange::kReasonableSrcZeroPoint: - return 0; - case RandomRange::kReasonableDstZeroPoint: - return 0; - case RandomRange::kBias: - return -1; - default: - RUY_CHECK(false); - return 0; - } - } - static Scalar GetMaxBound(RandomRange range) { - switch (range) { - case RandomRange::kGeneral: - return 1; - case RandomRange::kAvoidMinValue: - return 1; - case RandomRange::kOffCenterAvoidMinValue: - return 1; - case RandomRange::kReasonableSrcZeroPoint: - return 0; - case RandomRange::kReasonableDstZeroPoint: - return 0; - case RandomRange::kBias: - return 1; - default: - RUY_CHECK(false); - return 0; - } - } -}; - -template -Scalar WeightedSum(Scalar s1, float weight1, Scalar s2, float weight2) { - float sum = s1 * weight1 + s2 * weight2; - float clamped = std::min( - std::numeric_limits::max(), - std::max(std::numeric_limits::lowest(), sum)); - return static_cast(clamped); -} - -template -Scalar Parametrized(float param) { - return WeightedSum(std::numeric_limits::max(), param, - std::numeric_limits::lowest(), 1 - param); -} - -template -struct RandomRangeBounds { - static Scalar GetMinBound(RandomRange range) { - static constexpr double offcenteredness = - 0.02; // Shift lower limit by about 5 for range of 255. - switch (range) { - case RandomRange::kGeneral: - return std::numeric_limits::lowest(); - case RandomRange::kAvoidMinValue: - return 1 + std::numeric_limits::lowest(); - case RandomRange::kOffCenterAvoidMinValue: - return 1 + std::numeric_limits::lowest() + - static_cast( - offcenteredness * std::numeric_limits::max() - - offcenteredness * - (std::numeric_limits::lowest() + 1)); - case RandomRange::kReasonableSrcZeroPoint: - return std::numeric_limits::lowest(); - case RandomRange::kReasonableDstZeroPoint: - return Parametrized(0.4); - case RandomRange::kBias: - return std::is_same::value - ? static_cast(-10000) - : 0; - default: - RUY_CHECK(false); - return 0; - } - } - static Scalar GetMaxBound(RandomRange range) { - switch (range) { - case RandomRange::kGeneral: - return std::numeric_limits::max(); - case RandomRange::kAvoidMinValue: - return std::numeric_limits::max(); - case RandomRange::kOffCenterAvoidMinValue: - return std::numeric_limits::max(); - case RandomRange::kReasonableSrcZeroPoint: - return std::numeric_limits::max(); - case RandomRange::kReasonableDstZeroPoint: - return Parametrized(0.6); - case RandomRange::kBias: - return std::is_same::value - ? static_cast(10000) - : 0; - default: - RUY_CHECK(false); - return 0; - } - } -}; - -inline std::default_random_engine& global_random_engine() { - static std::default_random_engine engine; - return engine; -} - -template -struct UniformRandomDistribution { - UniformRandomDistribution(RandomRange range) - : dist(RandomRangeBounds::GetMinBound(range), - RandomRangeBounds::GetMaxBound(range)) {} - Scalar Get() { return dist(global_random_engine()); } - // std::uniform_int_distribution is specified not to support char types, - // only short and wider types. MSVC actually generates an error on - // std::uniform_int_distribution. - using StdDistType = typename std::conditional< - std::is_floating_point::value, - std::uniform_real_distribution, - std::uniform_int_distribution>::type; - StdDistType dist; -}; - -template -void MakeRandomScalar(UniformRandomDistribution* uniform_dist, - Scalar* dst) { - *dst = uniform_dist->Get(); -} - -template -void MakeRandomVector(UniformRandomDistribution* uniform_dist, int size, - std::vector* dst) { - dst->resize(size); - for (auto& x : *dst) { - MakeRandomScalar(uniform_dist, &x); - } -} - -template -void MakeRandomScalar(RandomRange range, Scalar* dst) { - UniformRandomDistribution dist(range); - *dst = dist.Get(); - if (range == RandomRange::kReasonableDstZeroPoint || - range == RandomRange::kReasonableSrcZeroPoint) { - if (global_random_engine()() & 1) { - *dst = SymmetricZeroPoint(); - } - } -} - -template -void MakeRandomVector(RandomRange range, int size, std::vector* dst) { - UniformRandomDistribution dist(range); - dst->resize(size); - for (auto& x : *dst) { - MakeRandomScalar(&dist, &x); - } -} - -enum class LayoutStyle { kPackedLinear, kLinear }; - -inline void MakeLayout(int rows, int cols, Order order, - LayoutStyle layout_style, Layout* layout) { - layout->rows = rows; - layout->cols = cols; - layout->order = order; - - const int packed_stride = order == Order::kColMajor ? rows : cols; - - RUY_CHECK(layout_style == LayoutStyle::kPackedLinear || - layout_style == LayoutStyle::kLinear); - if (layout_style == LayoutStyle::kPackedLinear) { - layout->stride = packed_stride; - } else { - layout->stride = packed_stride + 1; - } -} - -template -struct StorageMatrix { - StorageMatrix() = default; - StorageMatrix(const StorageMatrix&) = delete; - void operator=(const StorageMatrix&) = delete; - std::vector data; - Matrix matrix; -}; - -template -void VerifyConsistentFields(const StorageMatrix& storage_matrix) { - if (storage_matrix.data.empty()) { - RUY_CHECK_EQ(storage_matrix.matrix.data.get(), nullptr); - RUY_CHECK_EQ(storage_matrix.matrix.layout.rows, 0); - RUY_CHECK_EQ(storage_matrix.matrix.layout.cols, 0); - } else { - RUY_CHECK_EQ(storage_matrix.matrix.data.get(), storage_matrix.data.data()); - RUY_CHECK_EQ(FlatSize(storage_matrix.matrix.layout), - storage_matrix.data.size()); - } -} - -template -void MakeRandom(int rows, int cols, Order order, Scalar zero_point, - LayoutStyle layout_style, RandomRange range, - StorageMatrix* storage_matrix) { - MakeLayout(rows, cols, order, layout_style, &storage_matrix->matrix.layout); - storage_matrix->matrix.zero_point = zero_point; - UniformRandomDistribution data_dist(range); - MakeRandomVector(&data_dist, FlatSize(storage_matrix->matrix.layout), - &storage_matrix->data); - storage_matrix->matrix.data = storage_matrix->data.data(); - VerifyConsistentFields(*storage_matrix); -} - -template -struct TestResult { - void operator=(const TestResult&) = delete; - void operator=(const TestResult&&) = delete; - StorageMatrix storage_matrix; - Path path = Path::kNone; - Tuning tuning = Tuning::kAuto; - ExternalPath external_path = ExternalPath::kNone; - float latency; - float l1_refill_rate; - float l2_refill_rate; - float l3_refill_rate; - float l1tlb_refill_rate; - float l2tlb_refill_rate; - float mispred_rate; - float frontend_stall_rate; - float backend_stall_rate; - - // Per-path data for pre-packing. - // This is not used by external paths or by Path::kReference. - Allocator allocator; - PrepackedMatrix prepacked_lhs; - PrepackedMatrix prepacked_rhs; - bool use_prepacked_lhs = false; - bool use_prepacked_rhs = false; -}; - -template -std::string PathName(const TestResult& result) { - std::string pathname; - if (result.path != Path::kNone) { - pathname.assign(PathName(result.path)); - } else if (result.external_path != ExternalPath::kNone) { - pathname.assign(PathName(result.external_path)); - } else { - RUY_CHECK(false); - } - if (result.tuning != Tuning::kAuto) { - pathname.append("/"); - pathname.append(TuningName(result.tuning)); - } - return pathname; -} - -enum class ExpectedOutcome { kSuccess, kDeath }; - -template -struct TestSet final { - using LhsScalar = tLhsScalar; - using RhsScalar = tRhsScalar; - using AccumScalar = typename SpecType::AccumScalar; - using DstScalar = typename SpecType::DstScalar; - using Spec = SpecType; - using TestResultType = TestResult; - - void Run() { - MakeZeroPoints(); - MakeLhsRhs(); - MakeSpec(); - MakeOtherParams(); - MakeResultPaths(); - MakePrepackedMatrices(); - Eval(); - Verify(); - } - - private: - void MakeZeroPoints(); - void MakeLhsRhs(); - void MakeSpec(); - void MakeResultPaths(); - void MakePrepackedMatrices(); - void MakeOtherParams(); - void EvalAndVerify(); - void Eval(); - void Verify(); - - void EvalResult(TestResultType* result); - void EvalRuy(TestResultType* result); - void DoMul(TestResultType* result); - void Benchmark(TestResultType* result); - void VerifyTestResults() const; - - public: - enum class LifeStage { - kInitial, - kHasZeroPoints, - kHasLhsRhs, - kHasSpec, - kHasOtherParams, - kHasResultPaths, - kHasPrepackedMatrices, - kEvaluated, - kFinal - }; - - ~TestSet() { - RUY_CHECK_EQ(life_stage, LifeStage::kFinal); - LogCoveredPathsOnDestruction::Singleton(); - } - - LifeStage life_stage = LifeStage::kInitial; - - int rows = 0; - int cols = 0; - int depth = 0; - Order lhs_order = Order::kRowMajor; - Order rhs_order = Order::kColMajor; - Order dst_order = Order::kColMajor; - LayoutStyle layout_style = LayoutStyle::kPackedLinear; - ExpectedOutcome expected_outcome = ExpectedOutcome::kSuccess; - - bool use_specified_zero_points = false; - LhsScalar lhs_zero_point = 0; - RhsScalar rhs_zero_point = 0; - DstScalar dst_zero_point = 0; - - std::vector per_channel_multiplier_fixedpoint; - std::vector per_channel_multiplier_exponent; - - StorageMatrix lhs; - StorageMatrix rhs; - Spec spec; - std::vector bias_data; - std::vector> results; - - std::vector paths; - std::vector external_paths; - - bool benchmark = false; - bool perchannel = false; - int max_num_threads = 0; - bool benchmark_prepack_lhs = false; - bool benchmark_prepack_rhs = false; -}; - -inline PmuEvents& GlobalPmuEvents() { - static PmuEvents pmu; - return pmu; -} - -inline Context& GlobalContext() { - // Ensure that GlobalPmuEvents is constructed before we create any context. - // This ensures that pmu counters are opened before we create any worker - // thread, which is necessary to count events from worker threads. - GlobalPmuEvents(); - - static Context context; - return context; -} - -#if defined(__has_feature) -#if __has_feature(thread_sanitizer) -#define RUY_TSAN -#endif -#if __has_feature(address_sanitizer) -#define RUY_ASAN -#endif -#endif // defined(__has_feature) - -template -void TestSet::DoMul(TestResultType* result) { - Context* context = &GlobalContext(); - - if (!result->use_prepacked_lhs && !result->use_prepacked_rhs) { - Mul(lhs.matrix, rhs.matrix, spec, context, - &result->storage_matrix.matrix); - return; - } - - // If we prepacked an input matrix, null out its data pointer to check - // that we don't access any data through it. - Matrix null_data_lhs = lhs.matrix; - Matrix null_data_rhs = rhs.matrix; - if (result->use_prepacked_lhs) { - null_data_lhs.data = nullptr; - } - if (result->use_prepacked_rhs) { - null_data_rhs.data = nullptr; - } - - // Do the multiplication with pre-packed matrices. - PrepackedMatrix* prepacked_lhs_ptr = - result->use_prepacked_lhs ? &result->prepacked_lhs : nullptr; - PrepackedMatrix* prepacked_rhs_ptr = - result->use_prepacked_rhs ? &result->prepacked_rhs : nullptr; - MulWithPrepacked(null_data_lhs, null_data_rhs, spec, context, - &result->storage_matrix.matrix, prepacked_lhs_ptr, - prepacked_rhs_ptr); -} - -// When building for WAsm, ASSERT_DEATH is not defined. -#ifdef ASSERT_DEATH -#define RUY_ASSERT_DEATH(CONDITION, MESSAGE) ASSERT_DEATH(CONDITION, MESSAGE) -#else -#define RUY_ASSERT_DEATH(CONDITION, MESSAGE) -#endif - -template -void TestSet::EvalRuy(TestResultType* result) { - GlobalContext().explicit_tuning = result->tuning; - if (max_num_threads) { - GlobalContext().max_num_threads = max_num_threads; - } else if (benchmark) { - GlobalContext().max_num_threads = 1; - } else { - GlobalContext().max_num_threads = 1 + global_random_engine()() % 8; - } - GlobalContext().SetRuntimeEnabledPaths(result->path); - if (expected_outcome == ExpectedOutcome::kSuccess) { - DoMul(result); - RUY_CHECK_EQ(GlobalContext().last_taken_path, result->path); - } else if (expected_outcome == ExpectedOutcome::kDeath) { - // TODO(benoitjacob) TSan and ASan seem to be breaking ASSERT_DEATH. - // Report a bug? -#if (!defined NDEBUG) && (!defined RUY_ASAN) && (!defined RUY_TSAN) - RUY_ASSERT_DEATH(DoMul(result), ""); -#endif - } else { - RUY_CHECK(false); - } - GlobalContext().explicit_tuning = Tuning::kAuto; - GlobalContext().max_num_threads = 1; -} - -#ifdef RUY_TEST_EXTERNAL_PATHS - -template -void WrapGemmlowp(const Matrix& src, - gemmlowp::MatrixMap* dst) { - RUY_CHECK(src.layout.order == (tOrder == gemmlowp::MapOrder::ColMajor - ? Order::kColMajor - : Order::kRowMajor)); - *dst = gemmlowp::MatrixMap( - src.data.get(), src.layout.rows, src.layout.cols, src.layout.stride); -} - -template -void WrapGemmlowpMutable(Matrix* src, - gemmlowp::MatrixMap* dst) { - RUY_CHECK(src->layout.order == (tOrder == gemmlowp::MapOrder::ColMajor - ? Order::kColMajor - : Order::kRowMajor)); - *dst = gemmlowp::MatrixMap( - src->data.get(), src->layout.rows, src->layout.cols, src->layout.stride); -} - -template -struct GemmlowpOrder {}; - -template <> -struct GemmlowpOrder { - static constexpr gemmlowp::MapOrder kValue = gemmlowp::MapOrder::ColMajor; -}; - -template <> -struct GemmlowpOrder { - static constexpr gemmlowp::MapOrder kValue = gemmlowp::MapOrder::RowMajor; -}; - -inline gemmlowp::GemmContext& GlobalGemmlowpContext() { - static gemmlowp::GemmContext context; - return context; -} - -template -void EvalGemmlowp(const Matrix& lhs, const Matrix& rhs, - const Spec& spec, int max_num_threads, - Matrix* dst) { - static constexpr gemmlowp::MapOrder kGemmlowpLhsOrder = - GemmlowpOrder::kValue; - static constexpr gemmlowp::MapOrder kGemmlowpRhsOrder = - GemmlowpOrder::kValue; - static constexpr gemmlowp::MapOrder kGemmlowpDstOrder = - GemmlowpOrder::kValue; - gemmlowp::MatrixMap gemmlowp_lhs; - gemmlowp::MatrixMap gemmlowp_rhs; - gemmlowp::MatrixMap gemmlowp_dst; - WrapGemmlowp(lhs, &gemmlowp_lhs); - WrapGemmlowp(rhs, &gemmlowp_rhs); - WrapGemmlowpMutable(dst, &gemmlowp_dst); - - gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent quantize_down_stage; - quantize_down_stage.result_offset_after_shift = dst->zero_point; - quantize_down_stage.result_fixedpoint_multiplier = spec.multiplier_fixedpoint; - quantize_down_stage.result_exponent = spec.multiplier_exponent; - gemmlowp::OutputStageScaleInt32ByFixedPointAndExponentPC< - gemmlowp::VectorShape::Col> - quantize_down_stage_pc; - quantize_down_stage_pc.result_offset_after_shift = dst->zero_point; - using ColVectorMap = - gemmlowp::VectorMap; - quantize_down_stage_pc.result_fixedpoint_multiplier = - ColVectorMap(spec.multiplier_fixedpoint_perchannel, lhs.layout.rows); - quantize_down_stage_pc.result_exponent = - ColVectorMap(spec.multiplier_exponent_perchannel, lhs.layout.rows); - - gemmlowp::OutputStageClamp clamp_stage; - clamp_stage.min = spec.clamp_min; - clamp_stage.max = spec.clamp_max; - using OutputStageSaturatingCast = typename std::conditional< - std::is_same::value, - gemmlowp::OutputStageSaturatingCastToUint8, - gemmlowp::OutputStageSaturatingCastToInt16>::type; - OutputStageSaturatingCast saturating_cast_stage; - - GlobalGemmlowpContext().set_max_num_threads(max_num_threads ? max_num_threads - : 1); - if (spec.bias) { - using ColVectorMap = - gemmlowp::VectorMap; - gemmlowp::OutputStageBiasAddition bias_add_stage; - bias_add_stage.bias_vector = ColVectorMap(spec.bias, dst->layout.rows); -#ifndef GEMMLOWP_SSE4 // gemmlowp perchannel stuff does not build on SSE - if (spec.multiplier_exponent_perchannel) { - const auto& output_pipeline = - std::make_tuple(bias_add_stage, quantize_down_stage_pc, clamp_stage, - saturating_cast_stage); - gemmlowp::GemmWithOutputPipeline< - LhsScalar, DstScalar, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>( - &GlobalGemmlowpContext(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst, - -lhs.zero_point, -rhs.zero_point, output_pipeline); - } else // NOLINT[readability/braces] -#endif - { - const auto& output_pipeline = - std::make_tuple(bias_add_stage, quantize_down_stage, clamp_stage, - saturating_cast_stage); - gemmlowp::GemmWithOutputPipeline< - LhsScalar, DstScalar, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>( - &GlobalGemmlowpContext(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst, - -lhs.zero_point, -rhs.zero_point, output_pipeline); - } - } else { -#ifndef GEMMLOWP_SSE4 // gemmlowp perchannel stuff does not build on SSE - if (spec.multiplier_exponent_perchannel) { - const auto& output_pipeline = std::make_tuple( - quantize_down_stage_pc, clamp_stage, saturating_cast_stage); - gemmlowp::GemmWithOutputPipeline< - LhsScalar, DstScalar, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>( - &GlobalGemmlowpContext(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst, - -lhs.zero_point, -rhs.zero_point, output_pipeline); - } else // NOLINT[readability/braces] -#endif - { - const auto& output_pipeline = std::make_tuple( - quantize_down_stage, clamp_stage, saturating_cast_stage); - gemmlowp::GemmWithOutputPipeline< - LhsScalar, DstScalar, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>( - &GlobalGemmlowpContext(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst, - -lhs.zero_point, -rhs.zero_point, output_pipeline); - } - } -} - -inline constexpr int Mash(Order LhsOrder, Order RhsOrder, Order DstOrder) { - return (LhsOrder == Order::kRowMajor ? 4 : 0) + - (RhsOrder == Order::kRowMajor ? 2 : 0) + - (DstOrder == Order::kRowMajor ? 1 : 0); -} - -template -void EvalGemmlowp(const Matrix& lhs, const Matrix& rhs, - const Spec& spec, int max_num_threads, - Matrix* dst) { - int index = Mash(lhs.layout.order, rhs.layout.order, dst->layout.order); - switch (index) { -#define EVALGEMMLOWP_CASE3(LHS, RHS, DST) \ - case Mash(LHS, RHS, DST): \ - return EvalGemmlowp(lhs, rhs, spec, max_num_threads, dst); -#define EVALGEMMLOWP_CASE2(LHS, RHS) \ - EVALGEMMLOWP_CASE3(LHS, RHS, Order::kColMajor) \ - EVALGEMMLOWP_CASE3(LHS, RHS, Order::kRowMajor) -#define EVALGEMMLOWP_CASE1(LHS) \ - EVALGEMMLOWP_CASE2(LHS, Order::kColMajor) \ - EVALGEMMLOWP_CASE2(LHS, Order::kRowMajor) - - EVALGEMMLOWP_CASE1(Order::kColMajor) - EVALGEMMLOWP_CASE1(Order::kRowMajor) - -#undef EVALGEMMLOWP_CASE1 -#undef EVALGEMMLOWP_CASE2 -#undef EVALGEMMLOWP_CASE3 - - default: - RUY_CHECK(false); - } -} - -template -struct EigenOrder {}; - -template <> -struct EigenOrder { - static constexpr int kValue = Eigen::ColMajor; -}; - -template <> -struct EigenOrder { - static constexpr int kValue = Eigen::RowMajor; -}; - -template -void EvalEigen(const Matrix& lhs, const Matrix& rhs, - const Spec& spec, int max_num_threads, Matrix* dst) { - RUY_CHECK_EQ(lhs.zero_point, 0); - RUY_CHECK_EQ(rhs.zero_point, 0); - RUY_CHECK_EQ(dst->zero_point, 0); - RUY_CHECK_EQ(spec.multiplier_fixedpoint, 0); - RUY_CHECK_EQ(spec.multiplier_exponent, 0); - - static constexpr int kEigenLhsOrder = EigenOrder::kValue; - static constexpr int kEigenRhsOrder = EigenOrder::kValue; - static constexpr int kEigenDstOrder = EigenOrder::kValue; - - using EigenLhsType = typename Eigen::Matrix:: - template StridedConstMapType>::type; - using EigenRhsType = typename Eigen::Matrix:: - template StridedConstMapType>::type; - using EigenDstType = typename Eigen::Matrix:: - template StridedMapType>::type; - using EigenBiasType = - typename Eigen::Matrix::ConstMapType; - - EigenLhsType eigen_lhs(lhs.data.get(), lhs.layout.rows, lhs.layout.cols, - Eigen::OuterStride(lhs.layout.stride)); - EigenRhsType eigen_rhs(rhs.data.get(), rhs.layout.rows, rhs.layout.cols, - Eigen::OuterStride(rhs.layout.stride)); - EigenDstType eigen_dst( - dst->data.get(), dst->layout.rows, dst->layout.cols, - Eigen::OuterStride(dst->layout.stride)); - Eigen::setNbThreads(max_num_threads ? max_num_threads : 1); - - if (spec.bias) { - EigenBiasType eigen_bias(spec.bias, dst->layout.rows); - if (spec.clamp_max == std::numeric_limits::infinity() && - spec.clamp_min == -std::numeric_limits::infinity()) { - eigen_dst.noalias() = (eigen_lhs * eigen_rhs).colwise() + eigen_bias; - } else { - eigen_dst.noalias() = ((eigen_lhs * eigen_rhs).colwise() + eigen_bias) - .cwiseMin(spec.clamp_max) - .cwiseMax(spec.clamp_min); - } - } else { - if (spec.clamp_max == std::numeric_limits::infinity() && - spec.clamp_min == -std::numeric_limits::infinity()) { - eigen_dst.noalias() = eigen_lhs * eigen_rhs; - } else { - eigen_dst.noalias() = (eigen_lhs * eigen_rhs) - .cwiseMin(spec.clamp_max) - .cwiseMax(spec.clamp_min); - } - } -} - -template -void EvalEigen(const Matrix& lhs, const Matrix& rhs, - const Spec& spec, int max_num_threads, Matrix* dst) { - int index = Mash(lhs.layout.order, rhs.layout.order, dst->layout.order); - switch (index) { -#define EVALEIGEN_CASE3(LHS, RHS, DST) \ - case Mash(LHS, RHS, DST): \ - return EvalEigen(lhs, rhs, spec, max_num_threads, dst); -#define EVALEIGEN_CASE2(LHS, RHS) \ - EVALEIGEN_CASE3(LHS, RHS, Order::kColMajor) \ - EVALEIGEN_CASE3(LHS, RHS, Order::kRowMajor) -#define EVALEIGEN_CASE1(LHS) \ - EVALEIGEN_CASE2(LHS, Order::kColMajor) \ - EVALEIGEN_CASE2(LHS, Order::kRowMajor) - - EVALEIGEN_CASE1(Order::kColMajor) - EVALEIGEN_CASE1(Order::kRowMajor) - -#undef EVALEIGEN_CASE1 -#undef EVALEIGEN_CASE2 -#undef EVALEIGEN_CASE3 - - default: - RUY_CHECK(false); - } -} - -template -void EvalEigenTensor(const Matrix& lhs, const Matrix& rhs, - const Spec& spec, int max_num_threads, - Matrix* dst) { - RUY_CHECK_EQ(lhs.zero_point, 0); - RUY_CHECK_EQ(rhs.zero_point, 0); - RUY_CHECK_EQ(dst->zero_point, 0); - RUY_CHECK_EQ(spec.multiplier_fixedpoint, 0); - RUY_CHECK_EQ(spec.multiplier_exponent, 0); - - // Eigen::TensorMap only supports packed layouts - RUY_CHECK(IsPacked(lhs.layout)); - RUY_CHECK(IsPacked(rhs.layout)); - RUY_CHECK(IsPacked(dst->layout)); - - using TensorLhsType = - Eigen::TensorMap>; - using TensorRhsType = - Eigen::TensorMap>; - using TensorDstType = - Eigen::TensorMap>; - using TensorBiasType = - Eigen::TensorMap>; - - const bool tr = DstOrder == Order::kRowMajor; - const auto& contract_lhs = tr ? rhs : lhs; - const auto& contract_rhs = tr ? lhs : rhs; - - TensorLhsType tensor_lhs( - contract_lhs.data.get(), - LhsOrder == Order::kColMajor ? contract_lhs.layout.rows - : contract_lhs.layout.cols, - LhsOrder == Order::kColMajor ? contract_lhs.layout.cols - : contract_lhs.layout.rows); - TensorRhsType tensor_rhs( - contract_rhs.data.get(), - RhsOrder == Order::kColMajor ? contract_rhs.layout.rows - : contract_rhs.layout.cols, - RhsOrder == Order::kColMajor ? contract_rhs.layout.cols - : contract_rhs.layout.rows); - TensorDstType tensor_dst( - dst->data.get(), - DstOrder == Order::kColMajor ? dst->layout.rows : dst->layout.cols, - DstOrder == Order::kColMajor ? dst->layout.cols : dst->layout.rows); - using DimPair = - typename Eigen::Tensor::DimensionPair; - Eigen::array contract_dims( - {DimPair((LhsOrder == Order::kColMajor) ? 1 : 0, - (RhsOrder == Order::kColMajor) ? 0 : 1)}); - Eigen::array shuffle(DstOrder == Order::kColMajor ? 0 : 1, - DstOrder == Order::kColMajor ? 1 : 0); - static Eigen::ThreadPool pool(max_num_threads ? max_num_threads : 1); - static Eigen::ThreadPoolDevice device(&pool, pool.NumThreads()); - if (spec.bias) { - TensorBiasType tensor_bias(spec.bias, dst->layout.rows); - Eigen::array bias_2d_shape(tr ? 1 : dst->layout.rows, - tr ? dst->layout.rows : 1); - Eigen::array bcast(tr ? dst->layout.cols : 1, - tr ? 1 : dst->layout.cols); - if (spec.clamp_max == std::numeric_limits::infinity() && - spec.clamp_min == -std::numeric_limits::infinity()) { - tensor_dst.device(device) = - tensor_lhs.contract(tensor_rhs, contract_dims); - } else { - tensor_dst.device(device) = - (tensor_lhs.contract(tensor_rhs, contract_dims) + - tensor_bias.reshape(bias_2d_shape).broadcast(bcast)) - .cwiseMin(spec.clamp_max) - .cwiseMax(spec.clamp_min); - } - } else { - if (spec.clamp_max == std::numeric_limits::infinity() && - spec.clamp_min == -std::numeric_limits::infinity()) { - tensor_dst.device(device) = - tensor_lhs.contract(tensor_rhs, contract_dims); - } else { - tensor_dst.device(device) = tensor_lhs.contract(tensor_rhs, contract_dims) - .cwiseMin(spec.clamp_max) - .cwiseMax(spec.clamp_min); - } - } -} - -template -void EvalEigenTensor(const Matrix& lhs, const Matrix& rhs, - const Spec& spec, int max_num_threads, - Matrix* dst) { - int index = Mash(lhs.layout.order, rhs.layout.order, dst->layout.order); - switch (index) { -#define EVALEIGENTENSOR_CASE3(LHS, RHS, DST) \ - case Mash(LHS, RHS, DST): \ - return EvalEigenTensor(lhs, rhs, spec, max_num_threads, dst); -#define EVALEIGENTENSOR_CASE2(LHS, RHS) \ - EVALEIGENTENSOR_CASE3(LHS, RHS, Order::kColMajor) \ - EVALEIGENTENSOR_CASE3(LHS, RHS, Order::kRowMajor) -#define EVALEIGENTENSOR_CASE1(LHS) \ - EVALEIGENTENSOR_CASE2(LHS, Order::kColMajor) \ - EVALEIGENTENSOR_CASE2(LHS, Order::kRowMajor) - - EVALEIGENTENSOR_CASE1(Order::kColMajor) - EVALEIGENTENSOR_CASE1(Order::kRowMajor) - -#undef EVALEIGENTENSOR_CASE1 -#undef EVALEIGENTENSOR_CASE2 -#undef EVALEIGENTENSOR_CASE3 - - default: - RUY_CHECK(false); - } -} - -template -struct GenericBlasGemm {}; - -template <> -struct GenericBlasGemm { - static void Run(char* transa, char* transb, lapack::integer* m, - lapack::integer* n, lapack::integer* k, - lapack::doublereal* alpha, lapack::doublereal* a, - lapack::integer* lda, lapack::doublereal* b, - lapack::integer* ldb, lapack::doublereal* beta, - lapack::doublereal* c, lapack::integer* ldc) { - dgemm_(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - } -}; - -template <> -struct GenericBlasGemm { - static void Run(char* transa, char* transb, lapack::integer* m, - lapack::integer* n, lapack::integer* k, lapack::real* alpha, - lapack::real* a, lapack::integer* lda, lapack::real* b, - lapack::integer* ldb, lapack::real* beta, lapack::real* c, - lapack::integer* ldc) { - sgemm_(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - } -}; - -template -void EvalOpenBlas(const Matrix& lhs, const Matrix& rhs, - const Spec& spec, int max_num_threads, Matrix* dst) { - RUY_CHECK_EQ(lhs.zero_point, 0); - RUY_CHECK_EQ(rhs.zero_point, 0); - RUY_CHECK_EQ(dst->zero_point, 0); - RUY_CHECK_EQ(spec.multiplier_fixedpoint, 0); - RUY_CHECK_EQ(spec.multiplier_exponent, 0); - - Matrix gemm_lhs; - Matrix gemm_rhs; - Matrix gemm_dst; - gemm_dst = *dst; - - // Use Transpose to reduce to the all-column-major case. - // Notice that ruy::Matrix merely holds a pointer, does not own data, - // so Transpose is cheap -- no actual matrix data is being transposed here. - if (dst->layout.order == Order::kColMajor) { - gemm_lhs = lhs; - gemm_rhs = rhs; - } else { - gemm_lhs = rhs; - gemm_rhs = lhs; - Transpose(&gemm_lhs); - Transpose(&gemm_rhs); - Transpose(&gemm_dst); - } - bool transposed_lhs = false; - bool transposed_rhs = false; - - if (gemm_lhs.layout.order == Order::kRowMajor) { - Transpose(&gemm_lhs); - transposed_lhs = true; - } - if (gemm_rhs.layout.order == Order::kRowMajor) { - Transpose(&gemm_rhs); - transposed_rhs = true; - } - - RUY_CHECK_EQ(gemm_lhs.layout.order, Order::kColMajor); - RUY_CHECK_EQ(gemm_rhs.layout.order, Order::kColMajor); - RUY_CHECK_EQ(gemm_dst.layout.order, Order::kColMajor); - - char transa = transposed_lhs ? 'T' : 'N'; - char transb = transposed_rhs ? 'T' : 'N'; - int m = gemm_lhs.layout.rows; - int n = gemm_rhs.layout.cols; - int k = gemm_lhs.layout.cols; - float alpha = 1; - Scalar* a = gemm_lhs.data.get(); - int lda = gemm_lhs.layout.stride; - Scalar* b = gemm_rhs.data.get(); - int ldb = gemm_rhs.layout.stride; - float beta = 0; - Scalar* c = gemm_dst.data.get(); - int ldc = gemm_dst.layout.stride; - GenericBlasGemm::Run(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, - &ldb, &beta, c, &ldc); - - // BLAS does not allow us to express the bias-addition and clamping, so - // we use Eigen for that. - - using EigenDstType = - typename Eigen::Matrix:: - template StridedMapType>::type; - using EigenBiasType = - typename Eigen::Matrix::ConstMapType; - - EigenDstType eigen_dst( - gemm_dst.data.get(), gemm_dst.layout.rows, gemm_dst.layout.cols, - Eigen::OuterStride(gemm_dst.layout.stride)); - Eigen::setNbThreads(max_num_threads ? max_num_threads : 1); - - if (spec.bias) { - EigenBiasType eigen_bias(spec.bias, dst->layout.rows); - if (spec.clamp_max == std::numeric_limits::infinity() && - spec.clamp_min == -std::numeric_limits::infinity()) { - eigen_dst.noalias() = eigen_dst.colwise() + eigen_bias; - } else { - eigen_dst.noalias() = (eigen_dst.colwise() + eigen_bias) - .cwiseMin(spec.clamp_max) - .cwiseMax(spec.clamp_min); - } - } else { - if (spec.clamp_max == std::numeric_limits::infinity() && - spec.clamp_min == -std::numeric_limits::infinity()) { - } else { - eigen_dst.noalias() = - eigen_dst.cwiseMin(spec.clamp_max).cwiseMax(spec.clamp_min); - } - } -} - -template -struct SupportsGemmlowp { - static constexpr bool kValue = - std::is_same::value && - std::is_same::value; -}; - -template -struct UsesSingleScalarType { - static constexpr bool kValue = - std::is_same::value && - std::is_same::value && - std::is_same::value; -}; - -template ::value, - bool EnableGemmlowp = SupportsGemmlowp::kValue, - bool SingleScalarType = UsesSingleScalarType::kValue> -struct EvalExternalPathImpl { - using DstScalar = typename TestSetType::DstScalar; - static void Run(TestSetType*, TestResult*) { RUY_CHECK(false); } -}; - -template -struct EvalExternalPathImpl { - using DstScalar = typename TestSetType::DstScalar; - static void Run(TestSetType* test_set, TestResult* test_result) { - if (test_result->external_path == ExternalPath::kEigen) { - EvalEigen(test_set->lhs.matrix, test_set->rhs.matrix, test_set->spec, - test_set->max_num_threads, &test_result->storage_matrix.matrix); - } else if (test_result->external_path == ExternalPath::kEigenTensor) { - EvalEigenTensor(test_set->lhs.matrix, test_set->rhs.matrix, - test_set->spec, test_set->max_num_threads, - &test_result->storage_matrix.matrix); - } else if (test_result->external_path == ExternalPath::kOpenBlas) { - EvalOpenBlas(test_set->lhs.matrix, test_set->rhs.matrix, test_set->spec, - test_set->max_num_threads, - &test_result->storage_matrix.matrix); - } else { - RUY_CHECK(false); - } - } -}; - -template -struct EvalExternalPathImpl { - using DstScalar = typename TestSetType::DstScalar; - static void Run(TestSetType* test_set, TestResult* test_result) { - if (test_result->external_path == ExternalPath::kGemmlowp) { - EvalGemmlowp(test_set->lhs.matrix, test_set->rhs.matrix, test_set->spec, - test_set->max_num_threads, - &test_result->storage_matrix.matrix); - } else { - RUY_CHECK(false); - } - } -}; - -template -void EvalExternalPath( - TestSetType* test_set, - TestResult* test_result) { - EvalExternalPathImpl::Run(test_set, test_result); -} - -#endif // RUY_TEST_EXTERNAL_PATHS - -template -bool Agree(const Matrix& matrix1, const Matrix& matrix2, - int depth) { - RUY_CHECK_EQ(matrix1.layout.rows, matrix2.layout.rows); - RUY_CHECK_EQ(matrix1.layout.cols, matrix2.layout.cols); - RUY_CHECK_EQ(matrix1.zero_point, matrix2.zero_point); - const int size = matrix1.layout.rows * matrix1.layout.cols; - double tolerated_max_diff = 0; - double tolerated_mean_diff = 0; - if (std::is_floating_point::value) { - // TODO: replace hardcoded 100 by something more sensible, probably - // roughly sqrt(depth) based on central limit theorem. - double max_abs_val = 0; - for (int row = 0; row < matrix1.layout.rows; row++) { - for (int col = 0; col < matrix1.layout.cols; col++) { - max_abs_val = - std::max(max_abs_val, - std::abs(static_cast(Element(matrix1, row, col)))); - max_abs_val = - std::max(max_abs_val, - std::abs(static_cast(Element(matrix2, row, col)))); - } - } - tolerated_max_diff = max_abs_val * std::numeric_limits::epsilon() * - 64 * std::sqrt(static_cast(depth)); - tolerated_mean_diff = tolerated_max_diff / std::sqrt(size); - } else if (RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING)) { - tolerated_max_diff = 1; - // totally empirical - tolerated_mean_diff = std::min(1.0, 2.0 * std::pow(size, -0.2)); - } - double sum_diff = 0; - for (int row = 0; row < matrix1.layout.rows; row++) { - for (int col = 0; col < matrix1.layout.cols; col++) { - double elem1 = Element(matrix1, row, col); - double elem2 = Element(matrix2, row, col); - double diff = elem1 - elem2; - - sum_diff += diff; - // Test (std::abs(diff) > tolerated_max_diff), but also true if diff is - // NaN. - if (!(std::abs(diff) <= tolerated_max_diff)) { - return false; - } - } - } - double mean_diff = sum_diff / size; - if (std::abs(mean_diff) > tolerated_mean_diff) { - return false; - } - return true; -} - -template -bool Agree(const StorageMatrix& storage_matrix1, - const StorageMatrix& storage_matrix2, int depth) { - VerifyConsistentFields(storage_matrix1); - VerifyConsistentFields(storage_matrix2); - return Agree(storage_matrix1.matrix, storage_matrix2.matrix, depth); -} - -template -bool Agree(const TestResult& result1, const TestResult& result2, - int depth) { - return Agree(result1.storage_matrix, result2.storage_matrix, depth); -} - -struct Stats { - double median; - double mean; - double min; - double max; -}; - -inline std::string StatsAsString(const Stats& stats) { - char buf[256]; - snprintf(buf, sizeof(buf), "(median = %g, mean = %g, min = %g, max = %g)", - stats.median, stats.mean, stats.min, stats.max); - return std::string(buf); -} - -template -void GetMatrixStats(const Matrix& matrix, Stats* stats) { - double min = std::numeric_limits::infinity(); - double max = -std::numeric_limits::infinity(); - double sum = 0; - std::vector allvals; - for (int row = 0; row < matrix.layout.rows; row++) { - for (int col = 0; col < matrix.layout.cols; col++) { - double val = Element(matrix, row, col); - min = std::min(min, val); - max = std::max(max, val); - sum += val; - allvals.push_back(val); - } - } - std::sort(allvals.begin(), allvals.end()); - stats->min = min; - stats->max = max; - stats->mean = sum / allvals.size(); - stats->median = allvals[allvals.size() / 2]; -} - -struct ErrorAnalysis { - Stats stats_good; - Stats stats_bad; - // The below is to help document departure from bit exactness. It's probably - // not going to be relevant to floating-point. - std::set error_rows; - std::set error_cols; - int row_of_first_error = 0; - int col_of_first_error = 0; - double first_error_good_value = 0; - double first_error_bad_value = 0; -}; - -template -void AnalyzeTestError(const TestSetType& test_set, int first_bad_result_index, - ErrorAnalysis* error_analysis) { - const auto& good_matrix = test_set.results[0]->storage_matrix.matrix; - const auto& bad_matrix = - test_set.results[first_bad_result_index]->storage_matrix.matrix; - GetMatrixStats(good_matrix, &error_analysis->stats_good); - GetMatrixStats(bad_matrix, &error_analysis->stats_bad); - bool found_first_error = false; - for (int row = 0; row < good_matrix.layout.rows; row++) { - for (int col = 0; col < good_matrix.layout.cols; col++) { - if (Element(good_matrix, row, col) != Element(bad_matrix, row, col)) { - if (!found_first_error) { - found_first_error = true; - error_analysis->row_of_first_error = row; - error_analysis->col_of_first_error = col; - error_analysis->first_error_good_value = - Element(good_matrix, row, col); - error_analysis->first_error_bad_value = Element(bad_matrix, row, col); - } - error_analysis->error_rows.insert(row); - error_analysis->error_cols.insert(col); - } - } - } -} - -template -void ComputeReasonableMultiplier( - const Matrix& lhs, - const Matrix& rhs, double* multiplier) { - using LhsScalar = typename TestSetType::LhsScalar; - using RhsScalar = typename TestSetType::RhsScalar; - using DstScalar = typename TestSetType::DstScalar; - if (std::is_floating_point::value || - std::is_same::value) { - *multiplier = 0; - return; - } - *multiplier = static_cast(std::numeric_limits::max()) / - (static_cast(lhs.layout.cols) * - std::numeric_limits::max() * - std::numeric_limits::max()); -} - -inline void QuantizeMultiplier(double multiplier_double, - std::int32_t* multiplier_fixedpoint, - int* multiplier_exponent) { - RUY_CHECK_GT(multiplier_double, 0); - if (multiplier_double == 0.) { - *multiplier_fixedpoint = 0; - *multiplier_exponent = 0; - return; - } - const double q = std::frexp(multiplier_double, multiplier_exponent); - auto q_fixed = static_cast(std::round(q * (1ll << 31))); - RUY_CHECK_LE(q_fixed, (1ll << 31)); - if (q_fixed == (1ll << 31)) { - q_fixed /= 2; - ++*multiplier_exponent; - } - RUY_CHECK_LE(q_fixed, std::numeric_limits::max()); - *multiplier_fixedpoint = static_cast(q_fixed); -} - -template -void SwitchMultiplierToPerChannel(TestSetType* test_set) { - test_set->per_channel_multiplier_fixedpoint.resize(test_set->rows); - test_set->per_channel_multiplier_exponent.resize(test_set->rows); - for (int i = 0; i < test_set->rows; i++) { - // multipliers typically range in [2^30 ; 2^31 - 1]. - // Values in [0, 2^30 - 1] are normally unused, but harmless. - // Thus a good way to randomize multipliers is to subtract from them - // a random value smaller than 2^30 but still significant compared to it. - std::int32_t nudged_multiplier = test_set->spec.multiplier_fixedpoint - - (global_random_engine()() % (1 << 26)); - int nudged_exponent = - test_set->spec.multiplier_exponent - 1 + (global_random_engine()() % 4); - test_set->per_channel_multiplier_fixedpoint[i] = nudged_multiplier; - test_set->per_channel_multiplier_exponent[i] = nudged_exponent; - } - test_set->spec.multiplier_fixedpoint_perchannel = - test_set->per_channel_multiplier_fixedpoint.data(); - test_set->spec.multiplier_exponent_perchannel = - test_set->per_channel_multiplier_exponent.data(); - test_set->spec.multiplier_fixedpoint = 0; - test_set->spec.multiplier_exponent = 0; -} - -template < - typename TestSetType, - bool IsApplicable = - std::is_same::value && - !std::is_same::value> -struct MakeSpecMultiplierFieldsImpl {}; - -template -struct MakeSpecMultiplierFieldsImpl { - static void Run(TestSetType* test_set) { - double multiplier; - ComputeReasonableMultiplier(test_set->lhs.matrix, - test_set->rhs.matrix, &multiplier); - QuantizeMultiplier(multiplier, &test_set->spec.multiplier_fixedpoint, - &test_set->spec.multiplier_exponent); - if (!test_set->benchmark) { - test_set->perchannel = global_random_engine()() & 1; - } - if (test_set->perchannel) { - SwitchMultiplierToPerChannel(test_set); - } - } -}; - -template -struct MakeSpecMultiplierFieldsImpl { - static void Run(TestSetType* test_set) { - test_set->spec.multiplier_fixedpoint = 0; - test_set->spec.multiplier_exponent = 0; - } -}; - -template -void MakeSpecClampFields(Spec* spec) { - using AccumScalar = typename Spec::AccumScalar; - using DstScalar = typename Spec::DstScalar; - - if (std::is_same::value) { - // Returning raw accumulators, clamping is not supported. - spec->clamp_min = std::numeric_limits::lowest(); - spec->clamp_max = std::numeric_limits::max(); - return; - } - - if (getenv("BENCHMARK_ONLY_MATMUL")) { - if (std::is_floating_point::value) { - spec->clamp_min = -std::numeric_limits::infinity(); - spec->clamp_max = std::numeric_limits::infinity(); - } else { - spec->clamp_min = std::numeric_limits::lowest(); - spec->clamp_max = std::numeric_limits::max(); - } - return; - } - - spec->clamp_min = std::numeric_limits::lowest() + 1; - spec->clamp_max = std::numeric_limits::max() - 1; -} - -template -void TestSet::MakeZeroPoints() { - RUY_CHECK_EQ(life_stage, LifeStage::kInitial); - if (!benchmark && !use_specified_zero_points) { - MakeRandomScalar(RandomRange::kReasonableSrcZeroPoint, &lhs_zero_point); - MakeRandomScalar(RandomRange::kReasonableSrcZeroPoint, &rhs_zero_point); - // If destination is std::int32_t, no dst_zero_point is necessary. - if (std::is_same::value) { - dst_zero_point = 0; - } else { - MakeRandomScalar(RandomRange::kReasonableDstZeroPoint, &dst_zero_point); - } - } - life_stage = LifeStage::kHasZeroPoints; -} - -template -void TestSet::MakeLhsRhs() { - RUY_CHECK_EQ(life_stage, LifeStage::kHasZeroPoints); - MakeRandom(rows, depth, lhs_order, lhs_zero_point, layout_style, - RandomRange::kOffCenterAvoidMinValue, &lhs); - MakeRandom(depth, cols, rhs_order, rhs_zero_point, layout_style, - RandomRange::kGeneral, &rhs); - life_stage = LifeStage::kHasLhsRhs; -} - -template -void TestSet::MakeSpec() { - RUY_CHECK_EQ(life_stage, LifeStage::kHasLhsRhs); - - if (!getenv("BENCHMARK_ONLY_MATMUL") && - (benchmark || (global_random_engine()() & 1))) { - MakeRandomVector(RandomRange::kBias, rows, &bias_data); - spec.bias = bias_data.data(); - } - if (lhs.matrix.zero_point == std::numeric_limits::lowest() && - rhs.matrix.zero_point == std::numeric_limits::lowest()) { - lhs.matrix.zero_point += 1; - } - MakeSpecMultiplierFieldsImpl::Run(this); - MakeSpecClampFields(&spec); - life_stage = LifeStage::kHasSpec; -} - -inline int GetIntEnvVarOrZero(const char* name) { - const char* val = getenv(name); - if (!val) { - return 0; - } - return std::stoi(val); -} - -inline float GetFloatEnvVarOrZero(const char* name) { - const char* val = getenv(name); - if (!val) { - return 0; - } - return std::stof(val); -} - -inline int GetHexIntEnvVarOrZero(const char* name) { - const char* val = getenv(name); - if (!val) { - return 0; - } - return std::stoi(val, nullptr, 16); -} - -inline bool GetBoolEnvVarOrFalse(const char* name) { - return static_cast(GetIntEnvVarOrZero(name)); -} - -template -void TestSet::MakeOtherParams() { - RUY_CHECK_EQ(life_stage, LifeStage::kHasSpec); - if (max_num_threads == 0) { - max_num_threads = GetIntEnvVarOrZero("THREADS"); - } - life_stage = LifeStage::kHasOtherParams; -} - -inline std::vector PathsBitfieldAsVector(Path paths_bitfield) { - std::vector result; - std::uint32_t remaining_paths = static_cast(paths_bitfield); - std::uint32_t test_bit = 1; - while (remaining_paths) { - if (remaining_paths & test_bit) { - result.push_back(static_cast(test_bit)); - } - remaining_paths &= ~test_bit; - test_bit <<= 1; - } - return result; -} - -inline std::vector EnumerateTuningsForPath(Path path, bool benchmark) { - if (benchmark) { - return {Tuning::kAuto}; - } -#if RUY_PLATFORM(ARM) - if (path == Path::kNeon || path == Path::kNeonDotprod) { - return {Tuning::kInOrder, Tuning::kOutOfOrder, Tuning::kAuto}; - } -#endif - return {Tuning::kAuto}; -} - -template -void TestSet::MakePrepackedMatrices() { - RUY_CHECK_EQ(life_stage, LifeStage::kHasResultPaths); - - // Prepacked matrices are Path-dependent, so create them for each test result. - for (auto& result : results) { - // If this result uses an external path, then skip this entirely. - if (result->path == Path::kNone) { - continue; - } - // Pre-packing doesn't make sense for Path::kReference. - // TODO(silvasean): Make Path::kReference an ExternalPath? - if (result->path == Path::kReference) { - continue; - } - - // Determine whether we should create/use prepacked matrices. - if (benchmark) { - // For benchmarking, do as requested. - result->use_prepacked_lhs = benchmark_prepack_lhs; - result->use_prepacked_rhs = benchmark_prepack_rhs; - } else { - // When testing, randomly pre-pack sometimes. But don't do it too often. - result->use_prepacked_lhs = (global_random_engine()() & 7) == 0; - result->use_prepacked_rhs = (global_random_engine()() & 7) == 0; - } - - // Create the pre-packed matrices. - PrepackedMatrix* prepacked_lhs_ptr = - result->use_prepacked_lhs ? &result->prepacked_lhs : nullptr; - PrepackedMatrix* prepacked_rhs_ptr = - result->use_prepacked_rhs ? &result->prepacked_rhs : nullptr; - auto alloc_fn = [&result](std::size_t num_bytes) { - return result->allocator.AllocateBytes(num_bytes); - }; - // Use a dst with a null data pointer to check that the pre-packing - // invocation doesn't write into it. - Matrix null_data_dst = result->storage_matrix.matrix; - null_data_dst.data = nullptr; - GlobalContext().SetRuntimeEnabledPaths(result->path); - PrePackForMul(lhs.matrix, rhs.matrix, spec, &GlobalContext(), - &null_data_dst, prepacked_lhs_ptr, - prepacked_rhs_ptr, alloc_fn); - RUY_CHECK_EQ(GlobalContext().last_taken_path, result->path); - } - - life_stage = LifeStage::kHasPrepackedMatrices; -} - -template -void TestSet::MakeResultPaths() { - RUY_CHECK_EQ(life_stage, LifeStage::kHasOtherParams); - - Path paths_bitfield = static_cast(GetHexIntEnvVarOrZero("PATHS")); - - if (paths_bitfield == Path::kNone) { - // Use a dummy Context just to perform the resolution of specific runtime - // enabled paths. - Context context; - paths_bitfield = context.GetRuntimeEnabledPaths(); - } - - // Trim bits that don't correspond to a compiled path, - // to allow specifying e.g. ffff to mean 'all paths' regardless of whether all - // those bits exist as actual paths. - paths_bitfield = paths_bitfield & kAllPaths; - RUY_CHECK_NE(paths_bitfield, Path::kNone); - paths = PathsBitfieldAsVector(paths_bitfield); - -#ifdef RUY_TEST_EXTERNAL_PATHS - - using TestSetType = TestSet; - - if (!GetBoolEnvVarOrFalse("NOEXT")) { - if (SupportsGemmlowp::kValue) { -#ifdef GEMMLOWP_SSE4 - const bool gemmlowp_supported = !spec.multiplier_fixedpoint_perchannel; -#else - const bool gemmlowp_supported = true; -#endif - if (gemmlowp_supported) { - external_paths.push_back(ExternalPath::kGemmlowp); - } - } - if (UsesSingleScalarType::kValue && - std::is_floating_point::value) { - external_paths.push_back(ExternalPath::kEigen); - if (layout_style == LayoutStyle::kPackedLinear) { - external_paths.push_back(ExternalPath::kEigenTensor); - } -// We link against a generic BLAS target that only maps to OpenBLAS on specific -// architectures. -#if RUY_PLATFORM(ARM_32) || RUY_PLATFORM(ARM_64) - // OpenBLAS multi-threading is disabled, so avoid mixing single-threaded - // and multi-threaded benchmark results. - if (max_num_threads == 1 && !getenv("NO_OPENBLAS")) { - external_paths.push_back(ExternalPath::kOpenBlas); - } -#endif - } - } - -#endif // RUY_TEST_EXTERNAL_PATHS - - for (Path path : paths) { - for (Tuning tuning : EnumerateTuningsForPath(path, benchmark)) { - results.emplace_back(new TestResultType); - TestResultType& result = *results.back(); - result.path = path; - result.tuning = tuning; - MakeRandom(rows, cols, dst_order, dst_zero_point, layout_style, - RandomRange::kGeneral, &result.storage_matrix); - } - } - - for (ExternalPath external_path : external_paths) { - results.emplace_back(new TestResultType); - TestResultType& result = *results.back(); - result.external_path = external_path; - MakeRandom(rows, cols, dst_order, dst_zero_point, layout_style, - RandomRange::kGeneral, &result.storage_matrix); - } - - life_stage = LifeStage::kHasResultPaths; -} - -template -void TestSet::EvalResult( - TestResult* result) { - RUY_CHECK(result->path != Path::kNone || - result->external_path != ExternalPath::kNone); - if (result->path != Path::kNone) { - EvalRuy(result); - } else { -#ifdef RUY_TEST_EXTERNAL_PATHS - using TestSetType = TestSet; - EvalExternalPath(this, result); -#endif - } - const std::string& pathname = PathName(*result); - if (std::find(CoveredPaths()->begin(), CoveredPaths()->end(), pathname) == - CoveredPaths()->end()) { - CoveredPaths()->push_back(pathname); - } -} - -using f32 = float; -using f64 = double; -using u8 = std::uint8_t; -using i8 = std::int8_t; -using u16 = std::uint16_t; -using i16 = std::int16_t; -using u32 = std::uint32_t; -using i32 = std::int32_t; -using u64 = std::uint64_t; -using i64 = std::int64_t; - -template -const char* TypeName() { - return nullptr; -} - -#define RUY_TYPENAME(TYPE) \ - template <> \ - const char* TypeName() { \ - return #TYPE; \ - } - -RUY_TYPENAME(f32) -RUY_TYPENAME(f64) -RUY_TYPENAME(u8) -RUY_TYPENAME(i8) -RUY_TYPENAME(u16) -RUY_TYPENAME(i16) -RUY_TYPENAME(u32) -RUY_TYPENAME(i32) -RUY_TYPENAME(u64) -RUY_TYPENAME(i64) - -#undef RUY_TYPENAME - -template -const char* SymmetryName(const Matrix& matrix) { - if (matrix.zero_point == SymmetricZeroPoint()) { - return "symm"; - } else { - return "asymm"; - } -} - -template -int StorageSize(const Matrix& matrix) { - return sizeof(Scalar) * FlatSize(matrix.layout); -} - -// Helper that replicates a buffer and gives out pointers to the replicas. -// This is useful when one wants to traverse data so that it is cold in cache. -// By having a sufficiently large value of num_repeats, one can ensure that the -// working set covered by the replicas is greater than the cache size. -template -class RepeatedBuffer { - public: - RepeatedBuffer() = default; - void Init(const T* elems, std::size_t num_elems, int num_repeats) { - buffers_.clear(); - allocator_.FreeAll(); - for (int i = 0; i < num_repeats; i++) { - T* p; - allocator_.Allocate(num_elems, &p); - memcpy(p, elems, num_elems * sizeof(T)); - buffers_.push_back(p); - } - } - T* Next() { - T* ret = buffers_[current_]; - current_ = (current_ + 1) % buffers_.size(); - return ret; - } - - private: - Allocator allocator_; - std::vector buffers_; - int current_ = 0; -}; - -template -void TestSet::Benchmark( - TestResult* result) { - using DstScalar = typename SpecType::DstScalar; - - const bool cold = getenv("RUY_BENCHMARK_COLD"); - LhsScalar* orig_lhs_data = lhs.matrix.data.get(); - RhsScalar* orig_rhs_data = rhs.matrix.data.get(); - DstScalar* orig_dst_data = result->storage_matrix.matrix.data.get(); - void* orig_prepacked_lhs_data = result->prepacked_lhs.data; - void* orig_prepacked_rhs_data = result->prepacked_rhs.data; - - int num_matmul_sets = 0; - - RepeatedBuffer cold_lhs; - RepeatedBuffer cold_rhs; - RepeatedBuffer cold_dst; - RepeatedBuffer cold_prepacked_lhs; - RepeatedBuffer cold_prepacked_rhs; - - if (cold) { - const int kWorkingSetSize = 100 << 20; - const int each_matmul_set_size = StorageSize(lhs.matrix) + - StorageSize(rhs.matrix) + - StorageSize(result->storage_matrix.matrix); - num_matmul_sets = - (kWorkingSetSize + each_matmul_set_size - 1) / each_matmul_set_size; - - cold_lhs.Init(lhs.matrix.data.get(), FlatSize(lhs.matrix.layout), - num_matmul_sets); - cold_rhs.Init(rhs.matrix.data.get(), FlatSize(rhs.matrix.layout), - num_matmul_sets); - cold_dst.Init(result->storage_matrix.matrix.data.get(), - FlatSize(result->storage_matrix.matrix.layout), - num_matmul_sets); - if (benchmark_prepack_lhs) { - cold_prepacked_lhs.Init(static_cast(result->prepacked_lhs.data), - result->prepacked_lhs.data_size, num_matmul_sets); - } - if (benchmark_prepack_rhs) { - cold_prepacked_rhs.Init(static_cast(result->prepacked_rhs.data), - result->prepacked_rhs.data_size, num_matmul_sets); - } - } - const bool record_pmu = GetBoolEnvVarOrFalse("RUY_BENCHMARK_PMU"); - int repeats = GetIntEnvVarOrZero("RUY_BENCHMARK_REPEATS"); - if (!repeats) { - repeats = 4; - } - float benchmark_min_secs = GetFloatEnvVarOrZero("RUY_BENCHMARK_MIN_SECS"); - if (!benchmark_min_secs) { - benchmark_min_secs = 0.5; - } -#ifdef RUY_PROFILER - { - const char* lhstype = TypeName(); - const char* lhssymm = SymmetryName(lhs.matrix); - const char* rhstype = TypeName(); - const char* rhssymm = SymmetryName(rhs.matrix); - - printf("Profiling path=%s shape=(%dx%dx%d) lhs=(%s,%s) rhs=(%s,%s)\n", - PathName(*result).c_str(), rows, depth, cols, lhstype, lhssymm, - rhstype, rhssymm); - ruy::profiler::ScopeProfile profile; -#endif - - float latency = std::numeric_limits::infinity(); - float l1_refill_rate = std::numeric_limits::infinity(); - float l2_refill_rate = std::numeric_limits::infinity(); - float l3_refill_rate = std::numeric_limits::infinity(); - float l1tlb_refill_rate = std::numeric_limits::infinity(); - float l2tlb_refill_rate = std::numeric_limits::infinity(); - float mispred_rate = std::numeric_limits::infinity(); - float frontend_stall_rate = std::numeric_limits::infinity(); - float backend_stall_rate = std::numeric_limits::infinity(); - - for (int repeat = 0; repeat < repeats; repeat++) { - auto& pmu_events = GlobalPmuEvents(); - if (record_pmu) { - pmu_events.StartRecording(); - } - TimePoint time_start = Now(); - TimePoint t = time_start; - int iters = 0; - int iters_at_a_time = 1; - while (ToFloatSeconds(t - time_start) < benchmark_min_secs) { - for (int i = 0; i < iters_at_a_time; i++) { - if (cold) { - lhs.matrix.data = cold_lhs.Next(); - rhs.matrix.data = cold_rhs.Next(); - result->storage_matrix.matrix.data = cold_dst.Next(); - if (benchmark_prepack_lhs) { - result->prepacked_lhs.data = cold_prepacked_lhs.Next(); - } - if (benchmark_prepack_rhs) { - result->prepacked_rhs.data = cold_prepacked_rhs.Next(); - } - } - EvalResult(result); - iters++; - } - iters_at_a_time *= 2; - t = Now(); - } - latency = std::min( - latency, static_cast(ToFloatSeconds(t - time_start) / iters)); - if (record_pmu) { - pmu_events.StopRecording(); - const float normalization_factor = - 1.0f / (static_cast(iters) * rows * cols * depth); - l1_refill_rate = std::min( - l1_refill_rate, pmu_events.L1RefillCount() * normalization_factor); - l2_refill_rate = std::min( - l2_refill_rate, pmu_events.L2RefillCount() * normalization_factor); - l3_refill_rate = std::min( - l3_refill_rate, pmu_events.L3RefillCount() * normalization_factor); - l1tlb_refill_rate = - std::min(l1tlb_refill_rate, - pmu_events.L1TLBRefillCount() * normalization_factor); - l2tlb_refill_rate = - std::min(l2tlb_refill_rate, - pmu_events.L2TLBRefillCount() * normalization_factor); - mispred_rate = - std::min(mispred_rate, pmu_events.BranchMispredictionCount() * - normalization_factor); - frontend_stall_rate = - std::min(frontend_stall_rate, - pmu_events.FrontendStallCount() * normalization_factor); - backend_stall_rate = - std::min(backend_stall_rate, - pmu_events.BackendStallCount() * normalization_factor); - } - } - result->latency = latency; - if (record_pmu) { - result->l1_refill_rate = l1_refill_rate; - result->l2_refill_rate = l2_refill_rate; - result->l3_refill_rate = l3_refill_rate; - result->l1tlb_refill_rate = l1tlb_refill_rate; - result->l2tlb_refill_rate = l2tlb_refill_rate; - result->mispred_rate = mispred_rate; - result->frontend_stall_rate = frontend_stall_rate; - result->backend_stall_rate = backend_stall_rate; - } - -#ifdef RUY_PROFILER - } - fflush(stdout); -#endif - - if (cold) { - lhs.matrix.data = orig_lhs_data; - rhs.matrix.data = orig_rhs_data; - memcpy(orig_dst_data, result->storage_matrix.matrix.data.get(), - StorageSize(result->storage_matrix.matrix)); - result->storage_matrix.matrix.data = orig_dst_data; - result->prepacked_lhs.data = orig_prepacked_lhs_data; - result->prepacked_rhs.data = orig_prepacked_rhs_data; - } -} - -template -void TestSet::Eval() { - RUY_CHECK_EQ(life_stage, LifeStage::kHasPrepackedMatrices); - for (auto& result : results) { - if (benchmark) { - Benchmark(result.get()); - } else { - EvalResult(result.get()); - } - } - life_stage = LifeStage::kEvaluated; -} - -template -std::string DumpRegion(const Matrix& matrix, int center_row, - int center_col) { - static constexpr int kRadius = 20; - int first_row = std::max(0, center_row - kRadius); - int last_row = std::min(matrix.layout.rows - 1, center_row + kRadius); - int first_col = std::max(0, center_col - kRadius); - int last_col = std::min(matrix.layout.cols - 1, center_col + kRadius); - std::ostringstream stream; - for (int row = first_row; row <= last_row; row++) { - for (int col = first_col; col <= last_col; col++) { - stream << static_cast(Element(matrix, row, col)) << " "; - } - stream << "\n"; - } - return stream.str(); -} - -template -void TestSet::VerifyTestResults() const { - const int depth = lhs.matrix.layout.cols; - for (int i = 0; i < results.size() - 1; i++) { - if (!Agree(*results[i], *results[i + 1], depth)) { - std::string paths_in_agreement; - paths_in_agreement.append(PathName(*results[0])); - for (int j = 1; j <= i; j++) { - paths_in_agreement.append(", "); - paths_in_agreement.append(PathName(*results[j])); - } - ErrorAnalysis error_analysis; - AnalyzeTestError(*this, i + 1, &error_analysis); - std::cerr << "Error: path (" << PathName(*results[i + 1]) - << ") disagrees with the other paths (" << paths_in_agreement - << "), which agree with each other." << std::endl; - std::cerr << "Shape: rows = " << rows << ", cols = " << cols - << ", depth = " << depth << std::endl; - std::cerr << "Stats of the good result matrix: " - << StatsAsString(error_analysis.stats_good) << std::endl; - std::cerr << "Stats of the bad result matrix: " - << StatsAsString(error_analysis.stats_bad) << std::endl; - if (error_analysis.error_rows.size() < rows) { - std::cerr << "Rows containing errors: " - << Join(error_analysis.error_rows) << std::endl; - } else { - std::cerr << "Errors found in ALL rows." << std::endl; - } - if (error_analysis.error_cols.size() < cols) { - std::cerr << "Cols containing errors: " - << Join(error_analysis.error_cols) << std::endl; - } else { - std::cerr << "Errors found in ALL cols." << std::endl; - } - std::cerr << "The first error occurs at row " - << error_analysis.row_of_first_error << ", col " - << error_analysis.col_of_first_error << std::endl; - std::cerr << "Good value: " << error_analysis.first_error_good_value - << std::endl; - std::cerr << "Bad value : " << error_analysis.first_error_bad_value - << std::endl; - std::cerr << "Region of Good result matrix around first error:\n\n" - << DumpRegion(results[0]->storage_matrix.matrix, - error_analysis.row_of_first_error, - error_analysis.col_of_first_error) - << std::endl; - std::cerr << "Region of Bad result matrix around first error:\n\n" - << DumpRegion(results[i + 1]->storage_matrix.matrix, - error_analysis.row_of_first_error, - error_analysis.col_of_first_error) - << std::endl; - RUY_CHECK(false); - } - } -} - -template -void TestSet::Verify() { - RUY_CHECK_EQ(life_stage, LifeStage::kEvaluated); - if (expected_outcome == ExpectedOutcome::kSuccess) { - VerifyTestResults(); - } - life_stage = LifeStage::kFinal; -} - -template -void TestRCC(int rows, int depth, int cols, ExpectedOutcome expected_outcome) { - TestSetType test_set; - test_set.rows = rows; - test_set.depth = depth; - test_set.cols = cols; - test_set.lhs_order = Order::kRowMajor; - test_set.rhs_order = Order::kColMajor; - test_set.dst_order = Order::kColMajor; - test_set.layout_style = LayoutStyle::kPackedLinear; - test_set.expected_outcome = expected_outcome; - test_set.Run(); -} - -template -void TestRCC(int rows, int depth, int cols) { - TestRCC(rows, depth, cols, ExpectedOutcome::kSuccess); -} - -template -void TestNonRCC(int rows, int depth, int cols, - ExpectedOutcome expected_outcome) { - TestSetType test_set; - test_set.rows = rows; - test_set.depth = depth; - test_set.cols = cols; - test_set.lhs_order = Order::kColMajor; - test_set.rhs_order = Order::kColMajor; - test_set.dst_order = Order::kColMajor; - test_set.layout_style = LayoutStyle::kPackedLinear; - test_set.expected_outcome = expected_outcome; - test_set.Run(); -} - -template -void TestLinearAllOrders(int rows, int depth, int cols, - ExpectedOutcome expected_outcome) { - const std::vector orders{Order::kColMajor, Order::kRowMajor}; - - for (Order lhs_order : orders) { - for (Order rhs_order : orders) { - for (Order dst_order : orders) { - TestSetType test_set; - test_set.rows = rows; - test_set.depth = depth; - test_set.cols = cols; - test_set.lhs_order = lhs_order; - test_set.rhs_order = rhs_order; - test_set.dst_order = dst_order; - test_set.layout_style = LayoutStyle::kLinear; - test_set.expected_outcome = expected_outcome; - test_set.Run(); - } - } - } -} - -template -void TestLinearAllOrders(int rows, int depth, int cols) { - TestLinearAllOrders(rows, depth, cols, - ExpectedOutcome::kSuccess); -} - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TEST_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/test_fast.cc b/tensorflow/lite/experimental/ruy/ruy/test_fast.cc deleted file mode 100644 index 6b7026530ac..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/test_fast.cc +++ /dev/null @@ -1,110 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This test contains cheap test cases, completes in a few seconds. - -#include - -#include "tensorflow/lite/experimental/ruy/ruy/test.h" - -namespace ruy { - -using LhsScalar = RUY_TEST_LHSSCALAR; -using RhsScalar = RUY_TEST_RHSSCALAR; -using AccumScalar = RUY_TEST_ACCUMSCALAR; -using DstScalar = RUY_TEST_DSTSCALAR; - -using TestSetType = - TestSet>; - -TEST(RuyTest, TestSquareMuls) { - const std::vector sizes{ - // small sizes - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - 9, - 10, - // multiplies of 16 - 16, - 32, - 48, - 64, - // pot-minus-1 sizes - 15, - 31, - 63, - // pot-plus-1 sizes - 17, - 33, - 65, - }; - - for (int size : sizes) { - TestRCC(size, size, size); - TestLinearAllOrders(size, size, size); - } -} - -TEST(RuyTest, TestMiscMuls) { - const int shapes[][3] = { - {2, 3, 4}, {7, 6, 5}, {12, 23, 6}, {19, 3, 11}, {3, 10, 17}, - {30, 21, 43}, {7, 57, 9}, {49, 69, 71}, {38, 111, 29}, {87, 98, 76}, - {16, 96, 16}, {16, 88, 16}, {16, 84, 16}, {16, 92, 16}, {16, 82, 16}, - {16, 81, 16}, {16, 95, 16}, {3, 128, 5}}; - for (const auto& shape : shapes) { - TestLinearAllOrders(shape[0], shape[1], shape[2]); - } -} - -TEST(RuyTest, TestDeepMuls) { - // TODO(b/137649322): clarify what's the max allowed matrix size. - TestRCC(1, 32767, 1); - TestLinearAllOrders(5, 5001, 4); - TestLinearAllOrders(9, 1025, 10); -} - -TEST(RuyTest, TestShallowMuls) { - TestLinearAllOrders(101, 1, 103); - TestLinearAllOrders(71, 2, 53); - TestLinearAllOrders(51, 3, 73); - TestLinearAllOrders(51, 4, 43); -} - -TEST(RuyTest, TestNarrowMuls) { - for (int width : {1, 2, 3, 4, 5, 8}) { - TestLinearAllOrders(width, 12, 13); - TestLinearAllOrders(15, 19, width); - TestLinearAllOrders(width, 123, 137); - TestLinearAllOrders(158, 119, width); - } -} - -TEST(RuyTest, TestGEMV) { - for (int size = 1; size < 1024; size *= 2) { - for (int depth = 1; depth < 500; depth += 47) { - TestLinearAllOrders(size, depth, 1); - } - } - TestLinearAllOrders(5, 5001, 1); - TestLinearAllOrders(8193, 17, 1); -} - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/test_slow.cc b/tensorflow/lite/experimental/ruy/ruy/test_slow.cc deleted file mode 100644 index 7e7292cd503..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/test_slow.cc +++ /dev/null @@ -1,71 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This test contains more expensive test cases. - -#include "tensorflow/lite/experimental/ruy/ruy/test.h" - -namespace ruy { - -using LhsScalar = RUY_TEST_LHSSCALAR; -using RhsScalar = RUY_TEST_RHSSCALAR; -using AccumScalar = RUY_TEST_ACCUMSCALAR; -using DstScalar = RUY_TEST_DSTSCALAR; - -using TestSetType = - TestSet>; - -TEST(RuyTest, TestBigNarrowMuls) { - for (int width : {1, 2, 3, 4, 5, 8}) { - TestRCC(width, 401, 601); - TestRCC(587, 443, width); - } - TestRCC(7, 45984, - 5); // Large enough to trigger row-sum overflows. - TestRCC(512, 256, 16); -} - -TEST(RuyTest, TestBigShallowMuls) { - TestLinearAllOrders(501, 1, 321); - TestLinearAllOrders(301, 5, 403); - TestLinearAllOrders(256, 32, 512); -} - -TEST(RuyTest, TestBigMuls) { - TestRCC(225, 303, 199); - TestLinearAllOrders(256, 192, 128); -} - -TEST(RuyTest, TestBigPowerOfTwoDepthWithAvoidAliasing) { - // Important to test some power-of-two depths: that's when the - // RUY_AVOID_ALIASING optimization kicks in and makes packed matrices - // strided, exposing bugs in kernels mixing up size and stride. - // Moreover, it's important that the test matrices be sufficiently wide - // that they will result in multiple blocks, exposing bugs in the - // computation of the base address of each block. - TestLinearAllOrders(70, 1024, 80); - TestLinearAllOrders(60, 2048, 70); - TestLinearAllOrders(40, 4096, 50); -} - -TEST(RuyTest, TestGEMV) { - for (int size = 1025; size <= 1409; size += 384) { - for (int depth = 350; depth < 500; depth += 47) { - TestLinearAllOrders(size, depth, 1); - } - } -} - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/test_special_specs.cc b/tensorflow/lite/experimental/ruy/ruy/test_special_specs.cc deleted file mode 100644 index 6f5a88c833a..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/test_special_specs.cc +++ /dev/null @@ -1,163 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This test covers non-basic specs. - -#include "tensorflow/lite/experimental/ruy/ruy/test.h" - -namespace ruy { - -template -struct LoopStructureSpec : BasicSpec { - static constexpr LoopStructure kLoopStructure = tLoopStructure; -}; - -template -struct ZeroPointSupportSpec : BasicSpec { - static constexpr ZeroPointSupport kZeroPointSupport = tZeroPointSupport; -}; - -template -struct RCCSpec : BasicSpec { - static constexpr LayoutSupport kLayoutSupport = LayoutSupport::kRCC; -}; - -template -struct StandardCppKernelLayoutSpec : BasicSpec { - using StandardCppKernelLhsLayout = LhsKernelLayout; - using StandardCppKernelRhsLayout = RhsKernelLayout; - static int local_data_cache_size() { return 1; } - static int shared_data_cache_size() { return 1; } -}; - -using LhsScalar = RUY_TEST_LHSSCALAR; -using RhsScalar = RUY_TEST_RHSSCALAR; -using AccumScalar = RUY_TEST_ACCUMSCALAR; -using DstScalar = RUY_TEST_DSTSCALAR; - -template -void TestLoopStructure() { - using SpecType = LoopStructureSpec; - using TestSetType = TestSet; - for (int size = 1; size < 10; size++) { - TestLinearAllOrders(size, size, size); - } - TestLinearAllOrders(3, 5, 78); - TestLinearAllOrders(19, 91, 7); - TestLinearAllOrders(71, 26, 44); - TestLinearAllOrders(81, 93, 72); -} - -TEST(TestSpecialSpecs, LoopStructure) { - static_assert(BasicSpec::kLoopStructure == - LoopStructure::kAuto, - ""); - static_assert(BasicSpec::kLoopStructure == LoopStructure::kAuto, - ""); - TestLoopStructure(); - TestLoopStructure(); -} - -template -void TestZeroPointSupport(LhsScalar lhs_zero_point, RhsScalar rhs_zero_point, - DstScalar dst_zero_point, - ExpectedOutcome expected_outcome) { - using SpecType = - ZeroPointSupportSpec; - using TestSetType = TestSet; - TestSetType test_set; - test_set.rows = 11; - test_set.depth = 12; - test_set.cols = 13; - test_set.lhs_order = Order::kRowMajor; - test_set.rhs_order = Order::kColMajor; - test_set.dst_order = Order::kColMajor; - test_set.layout_style = LayoutStyle::kPackedLinear; - test_set.expected_outcome = expected_outcome; - test_set.lhs_zero_point = lhs_zero_point; - test_set.rhs_zero_point = rhs_zero_point; - test_set.dst_zero_point = dst_zero_point; - test_set.use_specified_zero_points = true; - test_set.Run(); -} - -TEST(TestSpecialSpecs, ZeroPointSupport) { - // Sanity check - RUY_CHECK_EQ(SymmetricZeroPoint(), 128); - RUY_CHECK_EQ(SymmetricZeroPoint(), 0); - - if (std::is_floating_point::value) { - return; - } - - TestZeroPointSupport( - SymmetricZeroPoint(), SymmetricZeroPoint(), - SymmetricZeroPoint(), ExpectedOutcome::kSuccess); - TestZeroPointSupport( - SymmetricZeroPoint() - 1, SymmetricZeroPoint(), - SymmetricZeroPoint(), ExpectedOutcome::kSuccess); - TestZeroPointSupport( - SymmetricZeroPoint(), SymmetricZeroPoint(), - SymmetricZeroPoint(), ExpectedOutcome::kSuccess); - TestZeroPointSupport( - SymmetricZeroPoint() + 1, SymmetricZeroPoint(), - SymmetricZeroPoint(), ExpectedOutcome::kDeath); - TestZeroPointSupport( - SymmetricZeroPoint(), SymmetricZeroPoint() + 1, - SymmetricZeroPoint(), ExpectedOutcome::kDeath); - TestZeroPointSupport( - SymmetricZeroPoint(), SymmetricZeroPoint(), - SymmetricZeroPoint() - 1, ExpectedOutcome::kDeath); -} - -TEST(TestSpecialSpecs, RCC) { - using RCCSpec = RCCSpec; - using RCCTestSet = TestSet; - TestRCC(81, 93, 72); - TestNonRCC(81, 93, 72, ExpectedOutcome::kDeath); -} - -template -void TestStandardCppKernelLayout() { - using SpecType = - StandardCppKernelLayoutSpec; - using TestSetType = TestSet; - for (int size = 1; size < 10; size++) { - TestLinearAllOrders(size, size, size); - } - TestLinearAllOrders(87, 34, 56); - TestLinearAllOrders(123, 234, 78); -} - -TEST(TestSpecialSpecs, StandardCppKernelLayoutTrivial1x1) { - TestStandardCppKernelLayout, - FixedKernelLayout>(); -} - -TEST(TestSpecialSpecs, StandardCppKernelLayoutSquare4x4) { - TestStandardCppKernelLayout, - FixedKernelLayout>(); -} - -TEST(TestSpecialSpecs, StandardCppKernelLayoutRectangular4x8) { - TestStandardCppKernelLayout, - FixedKernelLayout>(); -} - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/thread_pool.cc b/tensorflow/lite/experimental/ruy/ruy/thread_pool.cc deleted file mode 100644 index eb86a1fbf38..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/thread_pool.cc +++ /dev/null @@ -1,200 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/thread_pool.h" - -#include -#include // NOLINT(build/c++11) -#include // NOLINT(build/c++11) -#include -#include -#include -#include // NOLINT(build/c++11) -#include // NOLINT(build/c++11) - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/wait.h" - -namespace ruy { - -// A worker thread. -class Thread { - public: - enum class State { - Startup, // The initial state before the thread main loop runs. - Ready, // Is not working, has not yet received new work to do. - HasWork, // Has work to do. - ExitAsSoonAsPossible // Should exit at earliest convenience. - }; - - explicit Thread(BlockingCounter* counter_to_decrement_when_ready) - : task_(nullptr), - state_(State::Startup), - counter_to_decrement_when_ready_(counter_to_decrement_when_ready) { - thread_.reset(new std::thread(ThreadFunc, this)); - } - - ~Thread() { - ChangeState(State::ExitAsSoonAsPossible); - thread_->join(); - } - - // Changes State; may be called from either the worker thread - // or the master thread; however, not all state transitions are legal, - // which is guarded by assertions. - // - // The Task argument is to be used only with new_state==HasWork. - // It specifies the Task being handed to this Thread. - void ChangeState(State new_state, Task* task = nullptr) { - state_mutex_.lock(); - State old_state = state_.load(std::memory_order_relaxed); - RUY_DCHECK_NE(old_state, new_state); - switch (old_state) { - case State::Startup: - RUY_DCHECK_EQ(new_state, State::Ready); - break; - case State::Ready: - RUY_DCHECK(new_state == State::HasWork || - new_state == State::ExitAsSoonAsPossible); - break; - case State::HasWork: - RUY_DCHECK(new_state == State::Ready || - new_state == State::ExitAsSoonAsPossible); - break; - default: - abort(); - } - switch (new_state) { - case State::Ready: - if (task_) { - // Doing work is part of reverting to 'ready' state. - task_->Run(); - task_ = nullptr; - } - break; - case State::HasWork: - RUY_DCHECK(!task_); - task_ = task; - break; - default: - break; - } - state_.store(new_state, std::memory_order_relaxed); - state_cond_.notify_all(); - state_mutex_.unlock(); - if (new_state == State::Ready) { - counter_to_decrement_when_ready_->DecrementCount(); - } - } - - static void ThreadFunc(Thread* arg) { arg->ThreadFuncImpl(); } - - // Called by the master thead to give this thread work to do. - void StartWork(Task* task) { ChangeState(State::HasWork, task); } - - private: - // Thread entry point. - void ThreadFuncImpl() { - ChangeState(State::Ready); - - // Thread main loop - while (true) { - // In the 'Ready' state, we have nothing to do but to wait until - // we switch to another state. - const auto& condition = [this]() { - return state_.load(std::memory_order_acquire) != State::Ready; - }; - Wait(condition, &state_cond_, &state_mutex_); - - // Act on new state. - switch (state_.load(std::memory_order_acquire)) { - case State::HasWork: - // Got work to do! So do it, and then revert to 'Ready' state. - ChangeState(State::Ready); - break; - case State::ExitAsSoonAsPossible: - return; - default: - abort(); - } - } - } - - // The underlying thread. - std::unique_ptr thread_; - - // The task to be worked on. - Task* task_; - - // The condition variable and mutex guarding state changes. - std::condition_variable state_cond_; - std::mutex state_mutex_; - - // The state enum tells if we're currently working, waiting for work, etc. - // Its concurrent accesses by the thread and main threads are guarded by - // state_mutex_, and can thus use memory_order_relaxed. This still needs - // to be a std::atomic because we use WaitForVariableChange. - std::atomic state_; - - // pointer to the master's thread BlockingCounter object, to notify the - // master thread of when this thread switches to the 'Ready' state. - BlockingCounter* const counter_to_decrement_when_ready_; -}; - -void ThreadPool::ExecuteImpl(int task_count, int stride, Task* tasks) { - RUY_DCHECK_GE(task_count, 1); - - // Case of 1 thread: just run the single task on the current thread. - if (task_count == 1) { - (tasks + 0)->Run(); - return; - } - - // Task #0 will be run on the current thread. - CreateThreads(task_count - 1); - counter_to_decrement_when_ready_.Reset(task_count - 1); - for (int i = 1; i < task_count; i++) { - auto task_address = reinterpret_cast(tasks) + i * stride; - threads_[i - 1]->StartWork(reinterpret_cast(task_address)); - } - - // Execute task #0 immediately on the current thread. - (tasks + 0)->Run(); - - // Wait for the threads submitted above to finish. - counter_to_decrement_when_ready_.Wait(); -} - -// Ensures that the pool has at least the given count of threads. -// If any new thread has to be created, this function waits for it to -// be ready. -void ThreadPool::CreateThreads(int threads_count) { - if (threads_.size() >= threads_count) { - return; - } - counter_to_decrement_when_ready_.Reset(threads_count - threads_.size()); - while (threads_.size() < threads_count) { - threads_.push_back(new Thread(&counter_to_decrement_when_ready_)); - } - counter_to_decrement_when_ready_.Wait(); -} - -ThreadPool::~ThreadPool() { - for (auto w : threads_) { - delete w; - } -} - -} // end namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/thread_pool.h b/tensorflow/lite/experimental/ruy/ruy/thread_pool.h deleted file mode 100644 index 5504bd80614..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/thread_pool.h +++ /dev/null @@ -1,102 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file is a fork of gemmlowp's multi_thread_gemm.h, under Apache 2.0 -// license. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_THREAD_POOL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_THREAD_POOL_H_ - -#include - -#include "tensorflow/lite/experimental/ruy/ruy/blocking_counter.h" - -namespace ruy { - -// A workload for a thread. -struct Task { - virtual ~Task() {} - virtual void Run() = 0; -}; - -class Thread; - -// A simple pool of threads, that only allows the very -// specific parallelization pattern that we use here: -// One thread, which we call the 'main thread', calls Execute, distributing -// a Task each to N threads, being N-1 'worker threads' and the main thread -// itself. After the main thread has completed its own Task, it waits for -// the worker threads to have all completed. That is the only synchronization -// performed by this ThreadPool. -// -// In particular, there is a naive 1:1 mapping of Tasks to threads. -// This ThreadPool considers it outside of its own scope to try to work -// with fewer threads than there are Tasks. The idea is that such N:M mappings -// of tasks to threads can be implemented as a higher-level feature on top of -// the present low-level 1:1 threadpool. For example, a user might have a -// Task subclass referencing a shared atomic counter indexing into a vector of -// finer-granularity subtasks. Different threads would then concurrently -// increment this atomic counter, getting each their own subtasks to work on. -// That approach is the one used in ruy's multi-thread matrix multiplication -// implementation --- see ruy's TrMulTask. -class ThreadPool { - public: - ThreadPool() {} - - ~ThreadPool(); - - // Executes task_count tasks on task_count threads. - // Grows the threadpool as needed to have at least (task_count-1) threads. - // The 0-th task is run on the thread on which Execute is called: that - // is by definition what we call the "main thread". Synchronization of all - // threads is performed before this function returns. - // - // As explained in the class comment, there is a 1:1 mapping of tasks to - // threads. If you need something smarter than that, for instance if you - // want to run an unbounded number of tasks on a bounded number of threads, - // then you need something higher-level than this ThreadPool, that can - // be layered on top of it by appropriately subclassing Tasks. - // - // TaskType must be a subclass of ruy::Task. That is implicitly guarded by - // the static_cast in this inline implementation. - template - void Execute(int task_count, TaskType* tasks) { - ExecuteImpl(task_count, sizeof(TaskType), static_cast(tasks)); - } - - private: - // Ensures that the pool has at least the given count of threads. - // If any new thread has to be created, this function waits for it to - // be ready. - void CreateThreads(int threads_count); - - // Non-templatized implementation of the public Execute method. - // See the inline implementation of Execute for how this is used. - void ExecuteImpl(int task_count, int stride, Task* tasks); - - // copy construction disallowed - ThreadPool(const ThreadPool&) = delete; - - // The threads in this pool. They are owned by the pool: - // the pool creates threads and destroys them in its destructor. - std::vector threads_; - - // The BlockingCounter used to wait for the threads. - BlockingCounter counter_to_decrement_when_ready_; -}; - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_THREAD_POOL_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/time.h b/tensorflow/lite/experimental/ruy/ruy/time.h deleted file mode 100644 index 9dba75eb4c5..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/time.h +++ /dev/null @@ -1,81 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TIME_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TIME_H_ - -#include // NOLINT(build/c++11) -#include // IWYU pragma: keep -#include // NOLINT(build/c++11) - -#ifdef __linux__ -#include -// IWYU pragma: no_include - -#include -#endif - -namespace ruy { - -using InternalDefaultClock = std::chrono::steady_clock; - -using TimePoint = InternalDefaultClock::time_point; -using Duration = InternalDefaultClock::duration; - -template -Duration DurationFromSeconds(RepresentationType representation) { - return std::chrono::duration_cast( - std::chrono::duration(representation)); -} - -template -Duration DurationFromMilliseconds(RepresentationType representation) { - return std::chrono::duration_cast( - std::chrono::duration(representation)); -} - -template -Duration DurationFromNanoseconds(RepresentationType representation) { - return std::chrono::duration_cast( - std::chrono::duration(representation)); -} - -inline float ToFloatSeconds(const Duration& duration) { - return std::chrono::duration_cast>(duration) - .count(); -} - -inline std::int64_t ToInt64Nanoseconds(const Duration& duration) { - return std::chrono::duration_cast< - std::chrono::duration>(duration) - .count(); -} - -inline TimePoint Now() { return InternalDefaultClock::now(); } - -inline TimePoint CoarseNow() { -#ifdef __linux__ - timespec t; - clock_gettime(CLOCK_MONOTONIC_COARSE, &t); - return TimePoint( - DurationFromNanoseconds(1000000000LL * t.tv_sec + t.tv_nsec)); -#else - return Now(); -#endif -} - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TIME_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/trace.cc b/tensorflow/lite/experimental/ruy/ruy/trace.cc deleted file mode 100644 index 806f6ec2cf2..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/trace.cc +++ /dev/null @@ -1,325 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/trace.h" - -#include -#include // IWYU pragma: keep -#include -#include -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/side_pair.h" -#include "tensorflow/lite/experimental/ruy/ruy/time.h" - -namespace ruy { - -#ifdef RUY_TRACE - -enum class TraceEvent : std::uint8_t { - kNone, - kThreadStart, - kThreadLoopStart, - kThreadEnd, - kBlockReserved, - kBlockPackedLhs, - kBlockPackedRhs, - kBlockFinished -}; - -struct TraceEntry { - TimePoint time_point; - TraceEvent event; - // ruy-internal thread id i.e. contiguous index into array of threads, - // with 0 designating the main thread. - std::uint16_t thread_id = 0; - // Additional parameters whose meaning depends on the 'event' type. - std::uint32_t params[1]; -}; - -struct Trace { - BlockMap block_map; - // During recording, to avoid having to use locks or atomics, we let - // each thread append to its own specific vector. - std::vector> thread_specific_entries; - // Global vector of entries into which we coalesce thread_specific_entries - // after recording is finished, when dumping a trace. See - // AggregateThreadSpecificEntries. - std::vector entries; - TimePoint time_start; - TimePoint time_execute; - TimePoint time_end; -}; - -namespace { - -// Coalesce Trace::thread_specific_entries into Trace::entries. -void AggregateThreadSpecificEntries(Trace* trace) { - RUY_CHECK(trace->entries.empty()); - for (auto& thread_specific_entries_vector : trace->thread_specific_entries) { - for (const TraceEntry& entry : thread_specific_entries_vector) { - trace->entries.push_back(entry); - } - thread_specific_entries_vector.clear(); - } -} - -// Sort Trace::entries by ascending time. In case of equal timepoints, -// sort by some semi-arbitrary ordering of event types. -void Sort(Trace* trace) { - std::sort(std::begin(trace->entries), std::end(trace->entries), - [](const TraceEntry& a, const TraceEntry& b) -> bool { - return a.time_point < b.time_point || - (a.time_point == b.time_point && - static_cast(a.event) < static_cast(b.event)); - }); -} - -// Dump a trace. Assumes that AggregateThreadSpecificEntries and Sort have -// already been called on it. -// -// On some architectures long long ints are not same as std::int64_t, and -// time is printed as %lld, so static_casts are necessary. -void Dump(const Trace& trace) { - const char* trace_filename = getenv("RUY_TRACE_FILE"); - FILE* trace_file = trace_filename ? fopen(trace_filename, "w") : stderr; - if (!trace_file) { - fprintf(stderr, "Failed to open %s for write, errno=%d\n", trace_filename, - errno); - RUY_CHECK(false); - } - fprintf(trace_file, "thread_count:%d\n", trace.block_map.thread_count); - fprintf(trace_file, "rows:%d\n", trace.block_map.dims[Side::kLhs]); - fprintf(trace_file, "cols:%d\n", trace.block_map.dims[Side::kRhs]); - fprintf(trace_file, "Execute: %lld\n", - static_cast( - ToInt64Nanoseconds(trace.time_execute - trace.time_start))); - for (const TraceEntry& entry : trace.entries) { - long long int time = static_cast( - ToInt64Nanoseconds(entry.time_point - trace.time_start)); - switch (entry.event) { - case TraceEvent::kThreadStart: - fprintf(trace_file, "ThreadStart: %lld, %d\n", time, entry.thread_id); - break; - case TraceEvent::kThreadLoopStart: - fprintf(trace_file, "ThreadLoopStart: %lld, %d\n", time, - entry.thread_id); - break; - case TraceEvent::kThreadEnd: - fprintf(trace_file, "ThreadEnd: %lld, %d\n", time, entry.thread_id); - break; - case TraceEvent::kBlockReserved: { - std::uint32_t block_id = entry.params[0]; - SidePair block; - GetBlockByIndex(trace.block_map, block_id, &block); - SidePair start, end; - GetBlockMatrixCoords(trace.block_map, block, &start, &end); - fprintf(trace_file, - "BlockReserved: %lld, %d, %d, %d, %d, %d, %d, %d, %d\n", time, - entry.thread_id, block_id, block[Side::kLhs], block[Side::kRhs], - start[Side::kLhs], start[Side::kRhs], end[Side::kLhs], - end[Side::kRhs]); - break; - } - case TraceEvent::kBlockPackedLhs: { - std::uint32_t block = entry.params[0]; - int start, end; - GetBlockMatrixCoords(Side::kLhs, trace.block_map, block, &start, &end); - fprintf(trace_file, "BlockPackedLhs: %lld, %d, %d, %d, %d\n", time, - entry.thread_id, block, start, end); - break; - } - case TraceEvent::kBlockPackedRhs: { - std::uint32_t block = entry.params[0]; - int start, end; - GetBlockMatrixCoords(Side::kRhs, trace.block_map, block, &start, &end); - fprintf(trace_file, "BlockPackedRhs: %lld, %d, %d, %d, %d\n", time, - entry.thread_id, block, start, end); - break; - } - case TraceEvent::kBlockFinished: { - std::uint32_t block_id = entry.params[0]; - SidePair block; - GetBlockByIndex(trace.block_map, block_id, &block); - fprintf(trace_file, "BlockFinished: %lld, %d, %d, %d, %d\n", time, - entry.thread_id, block_id, block[Side::kLhs], - block[Side::kRhs]); - break; - } - default: - RUY_CHECK(false); - } - } - fprintf(trace_file, "End: %lld\n", - static_cast( - ToInt64Nanoseconds(trace.time_end - trace.time_start))); - if (trace_filename) { - fclose(trace_file); - } -} - -} // anonymous namespace - -// Get a Trace object to record to, or null of tracing is not enabled. -Trace* NewTraceOrNull(TracingContext* tracing, int rows, int depth, int cols) { - if (!tracing->initialized) { - tracing->initialized = true; - tracing->enabled = getenv("RUY_TRACE"); - if (!tracing->enabled) { - return nullptr; - } - if (getenv("RUY_TRACE_FILTER_ROWS")) { - tracing->filter_shape_rows = std::stoi(getenv("RUY_TRACE_FILTER_ROWS")); - } - if (getenv("RUY_TRACE_FILTER_DEPTH")) { - tracing->filter_shape_depth = std::stoi(getenv("RUY_TRACE_FILTER_DEPTH")); - } - if (getenv("RUY_TRACE_FILTER_COLS")) { - tracing->filter_shape_cols = std::stoi(getenv("RUY_TRACE_FILTER_COLS")); - } - } - if (!tracing->enabled) { - return nullptr; - } - if (tracing->filter_shape_rows && rows != tracing->filter_shape_rows) { - return nullptr; - } - if (tracing->filter_shape_depth && depth != tracing->filter_shape_depth) { - return nullptr; - } - if (tracing->filter_shape_cols && cols != tracing->filter_shape_cols) { - return nullptr; - } - // Delete any existing trace. - delete tracing->trace; - // Create a new one. - tracing->trace = new Trace; - return tracing->trace; -} - -// The trace recorded on a context is finalized and dumped by -// this TracingContext destructor. -// -// The idea of dumping on context destructor is that typically one wants to -// run many matrix multiplications, e.g. to hit a steady state in terms of -// performance characteristics, but only trace the last repetition of the -// workload, when that steady state was attained. -TracingContext::~TracingContext() { - if (trace) { - AggregateThreadSpecificEntries(trace); - Sort(trace); - Dump(*trace); - } - delete trace; -} - -void TraceRecordStart(Trace* trace) { - if (trace) { - trace->time_start = Now(); - } -} - -void TraceRecordExecute(const BlockMap& block_map, Trace* trace) { - if (trace) { - trace->time_execute = Now(); - trace->block_map = block_map; - trace->thread_specific_entries.resize(block_map.thread_count); - for (int thread = 0; thread < block_map.thread_count; thread++) { - trace->thread_specific_entries[thread].clear(); - // Reserve some large size to avoid frequent heap allocations - // affecting the recorded timings. - trace->thread_specific_entries[thread].reserve(16384); - } - } -} - -void TraceRecordEnd(Trace* trace) { - if (trace) { - trace->time_end = Now(); - } -} - -void TraceRecordThreadStart(std::uint32_t thread_id, Trace* trace) { - if (trace) { - TraceEntry entry; - entry.event = TraceEvent::kThreadStart; - entry.time_point = Now(); - entry.thread_id = thread_id; - trace->thread_specific_entries[thread_id].push_back(entry); - } -} - -void TraceRecordThreadLoopStart(std::uint32_t thread_id, Trace* trace) { - if (trace) { - TraceEntry entry; - entry.event = TraceEvent::kThreadLoopStart; - entry.time_point = Now(); - entry.thread_id = thread_id; - trace->thread_specific_entries[thread_id].push_back(entry); - } -} - -void TraceRecordBlockReserved(std::uint32_t thread_id, std::uint32_t block_id, - Trace* trace) { - if (trace) { - TraceEntry entry; - entry.event = TraceEvent::kBlockReserved; - entry.time_point = Now(); - entry.thread_id = thread_id; - entry.params[0] = block_id; - trace->thread_specific_entries[thread_id].push_back(entry); - } -} - -void TraceRecordBlockPacked(std::uint32_t thread_id, Side side, int block, - Trace* trace) { - if (trace) { - TraceEntry entry; - entry.event = side == Side::kLhs ? TraceEvent::kBlockPackedLhs - : TraceEvent::kBlockPackedRhs; - entry.time_point = Now(); - entry.thread_id = thread_id; - entry.params[0] = block; - trace->thread_specific_entries[thread_id].push_back(entry); - } -} - -void TraceRecordBlockFinished(std::uint32_t thread_id, std::uint32_t block_id, - Trace* trace) { - if (trace) { - TraceEntry entry; - entry.event = TraceEvent::kBlockFinished; - entry.time_point = Now(); - entry.thread_id = thread_id; - entry.params[0] = block_id; - trace->thread_specific_entries[thread_id].push_back(entry); - } -} - -void TraceRecordThreadEnd(std::uint32_t thread_id, Trace* trace) { - if (trace) { - TraceEntry entry; - entry.event = TraceEvent::kThreadEnd; - entry.time_point = Now(); - entry.thread_id = thread_id; - trace->thread_specific_entries[thread_id].push_back(entry); - } -} - -#endif - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/trace.h b/tensorflow/lite/experimental/ruy/ruy/trace.h deleted file mode 100644 index 6680438c124..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/trace.h +++ /dev/null @@ -1,73 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRACE_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRACE_H_ - -#include - -#include "tensorflow/lite/experimental/ruy/ruy/block_map.h" -#include "tensorflow/lite/experimental/ruy/ruy/side_pair.h" - -namespace ruy { - -struct Trace; - -#ifdef RUY_TRACE - -struct TracingContext { - bool initialized = false; - bool enabled = false; - int filter_shape_rows = 0; - int filter_shape_cols = 0; - int filter_shape_depth = 0; - Trace* trace = nullptr; - ~TracingContext(); -}; - -Trace* NewTraceOrNull(TracingContext* context, int rows, int depth, int cols); -void TraceRecordThreadStart(std::uint32_t thread_id, Trace* trace); -void TraceRecordThreadLoopStart(std::uint32_t thread_id, Trace* trace); -void TraceRecordBlockReserved(std::uint32_t thread_id, std::uint32_t block_id, - Trace* trace); -void TraceRecordBlockPacked(std::uint32_t thread_id, Side side, int block, - Trace* trace); -void TraceRecordBlockFinished(std::uint32_t thread_id, std::uint32_t block_id, - Trace* trace); -void TraceRecordThreadEnd(std::uint32_t thread_id, Trace* trace); -void TraceRecordStart(Trace* trace); -void TraceRecordExecute(const BlockMap& block_map, Trace* trace); -void TraceRecordEnd(Trace* trace); - -#else - -struct TracingContext {}; - -inline Trace* NewTraceOrNull(TracingContext*, int, int, int) { return nullptr; } -inline void TraceRecordThreadStart(std::uint32_t, Trace*) {} -inline void TraceRecordThreadLoopStart(std::uint32_t, Trace*) {} -inline void TraceRecordBlockReserved(std::uint32_t, std::uint32_t, Trace*) {} -inline void TraceRecordBlockPacked(std::uint32_t, Side, int, Trace*) {} -inline void TraceRecordBlockFinished(std::uint32_t, std::uint32_t, Trace*) {} -inline void TraceRecordThreadEnd(std::uint32_t, Trace*) {} -inline void TraceRecordStart(Trace*) {} -inline void TraceRecordExecute(const BlockMap&, Trace*) {} -inline void TraceRecordEnd(Trace*) {} - -#endif - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRACE_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/trmul.cc b/tensorflow/lite/experimental/ruy/ruy/trmul.cc deleted file mode 100644 index c3e15a9d628..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/trmul.cc +++ /dev/null @@ -1,401 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/trmul.h" - -#include -#include -#include -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/allocator.h" -#include "tensorflow/lite/experimental/ruy/ruy/block_map.h" -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/common.h" -#include "tensorflow/lite/experimental/ruy/ruy/internal_matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" -#include "tensorflow/lite/experimental/ruy/ruy/side_pair.h" -#include "tensorflow/lite/experimental/ruy/ruy/size_util.h" -#include "tensorflow/lite/experimental/ruy/ruy/spec.h" -#include "tensorflow/lite/experimental/ruy/ruy/thread_pool.h" -#include "tensorflow/lite/experimental/ruy/ruy/trace.h" -#include "tensorflow/lite/experimental/ruy/ruy/tune.h" - -namespace ruy { - -namespace { - -enum class PackingStatus : std::uint8_t { kNotStarted, kInProgress, kFinished }; - -struct TrMulTask final : Task { - TrMulTask(TrMulParams* params_, const BlockMap& block_map_, - std::atomic* atomic_block_id_, int thread_id_, - bool need_atomics_, - SidePair*> packing_status_, - TuningResolver* tuning_resolver_, Allocator* local_allocator_, - Trace* trace_) - : params(params_), - block_map(block_map_), - atomic_block_id(atomic_block_id_), - thread_id(thread_id_), - need_atomics(need_atomics_), - packing_status(packing_status_), - tuning_resolver(tuning_resolver_), - local_allocator(local_allocator_), - trace(trace_), - local_packed{nullptr, nullptr} {} - - void Run() override { - TraceRecordThreadStart(thread_id, trace); - - for (Side side : {Side::kLhs, Side::kRhs}) { - if (!params->is_prepacked[side]) { - const int size = NumBlocksPerSide(side, block_map); - local_allocator->Allocate(size, &local_packed[side]); - memset(local_packed[side], 0, size * sizeof(bool)); - } - } - - const int num_blocks = NumBlocks(block_map); - - const Tuning tuning = tuning_resolver->Resolve(); - - TraceRecordThreadLoopStart(thread_id, trace); - - SidePair block; - SidePair start; - SidePair end; - - // Each thread starts by initially reserving the block whose id - // is the thread id. - int block_id = thread_id; - TraceRecordBlockReserved(thread_id, block_id, trace); - - while (block_id < num_blocks) { - // Reserve the next block to handle. In order to hide the latency - // (typically comparable to an access to the level of data cache that - // is shared among CPU cores, e.g. 60 cycles on an ARM CPU as of 2019) - // of this atomic operation, we structure this code so as to avoid - // immediately depending on the `next_n` result. - const int next_block_id = - atomic_block_id->fetch_add(1, std::memory_order_relaxed); - TraceRecordBlockReserved(thread_id, next_block_id, trace); - // Get coordinates of the current block to handle, in "block space". - GetBlockByIndex(block_map, block_id, &block); - // Get coordinates of the current block to handle, in matrix space. - GetBlockMatrixCoords(block_map, block, &start, &end); - // Maybe pack the current LHS/RHS block, if not already packed. - EnsurePacked(block, start, end, tuning); - // Actually do matrix multiplication work - params->RunKernel(tuning, start, end); - TraceRecordBlockFinished(thread_id, block_id, trace); - // Move on to the next block as obtained by the atomic increment - // at the start of this while loop iteration. - block_id = next_block_id; - } - - local_allocator->FreeAll(); - - TraceRecordThreadEnd(thread_id, trace); - } - - private: - // Tries to pack a block, without blocking. - // If the block was already packed, returns true. - // If the block was not started packing, packs it and returns true. - // If the block was being packed by another thread, returns false. - bool TryPack(Side side, int block, int start, int end, Tuning tuning) { - if (params->is_prepacked[side]) { - return true; - } - if (!local_packed[side][block]) { - if (need_atomics) { - // Explanation of this compare_exchange_strong operation: - // This atomically performs all of the following: - // 1. Read `status` with "acquire" memory order. - // * That this read uses "acquire" is because both memory orders - // specified have "acquire" as their read-component. - // 2. Compare (bitwise) with `exchanged_status`. - // 3. If equal, stores the value kInProgress to `status` with "release" - // memory order, and returns true, so we take this 'if' branch. - // * That this store uses "release" is because of the _rel part in - // memory_order_acq_rel passed as the first memory order argument. - // 4. If not equal, stores the loaded value of `status` to - // `exchanged_status` with "relaxed" semantics, and returns false, - // so we take the 'else' branch. - // * That this store uses "relaxed" is because the second memory - // order argument, memory_order_acquire, implies no particular - // store semantics. "relaxed" is acceptable here because this - // stores to a local stack variable. - // - // Rationale for compare_exchange_strong as opposed to - // compare_exchange_weak: - // The spurious-failure case with compare_exchange_weak will actually - // happen a lot here, because the atomic 'status' bytes are stored - // contiguously in arrays and neighboring values will be accessed - // by multiple threads concurrently. On a typical ARM CPU, an exclusives - // reservation granule is 64 bytes, so a lot of false-sharing may - // happen. Using compare_exchange_weak would thus result in often having - // TryPack return 'false' when it could instead have done the packing - // work and returned 'true'. Heuristically, that is not a good thing. - // Moreover, this changes the TryPack contract, loosening it and making - // it harder for the caller to reason about. Finally, the overhead of - // atomic operations is mitigated by the enclosing check on - // local_packed, so maybe the overhead of compare_exchange_strong isn't - // such a problem. But we don't really know for sure, that would be - // interesting to experiment more with. - PackingStatus exchanged_status = PackingStatus::kNotStarted; - std::atomic& status = packing_status[side][block]; - if (status.compare_exchange_strong( - exchanged_status, PackingStatus::kInProgress, - std::memory_order_acq_rel, std::memory_order_acquire)) { - // In this branch, the status was kNotStarted and we just atomically - // changed it to kInProgress as we are about to handle the packing - // ourselves. - params->RunPack(side, tuning, start, end); - TraceRecordBlockPacked(thread_id, side, block, trace); - status.store(PackingStatus::kFinished, std::memory_order_release); - } else if (exchanged_status == PackingStatus::kInProgress) { - // Another thread is currently packing this block. - return false; - } - RUY_DCHECK(status.load(std::memory_order_acquire) == - PackingStatus::kFinished); - } else { - // Single-threaded case: no need for expensive atomics, local_packed - // is the truth already. - params->RunPack(side, tuning, start, end); - TraceRecordBlockPacked(thread_id, side, block, trace); - } - local_packed[side][block] = true; - } - return true; - } - - // Ensures that both the LHS and RHS blocks required by the specified block - // are packed. In the event that they are already being packed on another - // threads, this function may perform the packing of some other block while - // waiting for that other thread to finish packing the requested block. - void EnsurePacked(const SidePair& block, const SidePair& start, - const SidePair& end, Tuning tuning) { -#if RUY_OPT_ENABLED(RUY_OPT_PACK_AHEAD) - SidePair next_runahead_block{block[Side::kLhs] + 1, - block[Side::kRhs] + 1}; - Side next_runahead_side = Side::kLhs; -#endif - while (true) { - bool both_sides_packed = true; - for (Side side : {Side::kLhs, Side::kRhs}) { - both_sides_packed &= - TryPack(side, block[side], start[side], end[side], tuning); - } - if (both_sides_packed) { - break; - } -#if RUY_OPT_ENABLED(RUY_OPT_PACK_AHEAD) - const Side runahead_side = next_runahead_side; - const int runahead_block = next_runahead_block[runahead_side]; - next_runahead_side = - next_runahead_side == Side::kLhs ? Side::kRhs : Side::kLhs; - if (runahead_block >= NumBlocksPerSide(runahead_side, block_map)) { - continue; - } - int runahead_block_start, runahead_block_end; - GetBlockMatrixCoords(runahead_side, block_map, runahead_block, - &runahead_block_start, &runahead_block_end); - TryPack(runahead_side, runahead_block, runahead_block_start, - runahead_block_end, tuning); - next_runahead_block[runahead_side] = runahead_block + 1; -#endif - } - } - - TrMulParams* params; - const BlockMap& block_map; - std::atomic* atomic_block_id; - int thread_id; - bool need_atomics; - SidePair*> packing_status; - TuningResolver* tuning_resolver; - Allocator* local_allocator; - Trace* trace; - - // Local indicators of packedness to avoid the overhead of atomic ops. - SidePair local_packed; -}; - -void AllocatePMatrix(Allocator* allocator, PMatrix* packed) { - packed->data = allocator->AllocateBytes(DataSize(*packed)); - packed->sums = allocator->AllocateBytes(SumsSize(*packed)); -} - -int GetThreadCount(Context* context, int rows, int cols, int depth) { -#if RUY_PLATFORM(EMSCRIPTEN) - // b/139927184, std::thread constructor raises exception - return 1; -#endif - // Empirically determined rule for reasonable number of - // threads to use. This is proportional to the number of arithmetic ops - // in this Mul (product of the 3 sizes). - static constexpr int kDivisorLog2 = 15; - const int guess_log2 = std::max( - 0, ceil_log2(rows) + ceil_log2(cols) + ceil_log2(depth) - kDivisorLog2); - return std::min(1 << guess_log2, context->max_num_threads); -} - -LoopStructure GetLoopStructure(int tentative_thread_count, int rows, int cols, - int depth, int lhs_scalar_size, - int rhs_scalar_size, int local_data_cache_size, - int shared_data_cache_size) { - if (tentative_thread_count == 1) { - const BlockMapTraversalOrder traversal_order = - GetTraversalOrder(rows, cols, depth, lhs_scalar_size, rhs_scalar_size, - local_data_cache_size, shared_data_cache_size); - // If we are in the GEMV case or the block_map would be using linear - // traversal anyway, use the simple loop. - if ((cols == 1) || traversal_order == BlockMapTraversalOrder::kLinear) { - return LoopStructure::kSimple; - } - } - return LoopStructure::kGeneral; -} - -} // namespace - -void TrMul(TrMulParams* params, Context* context) { - profiler::ScopeLabel label( - "TrMul (Path=0x%x, max_num_threads=%d, is_prepacked=(%d,%d))", - static_cast(params->path), context->max_num_threads, - params->is_prepacked[Side::kLhs], params->is_prepacked[Side::kRhs]); - - PMatrix& packed_lhs = params->packed[Side::kLhs]; - PMatrix& packed_rhs = params->packed[Side::kRhs]; - DMatrix& lhs = params->src[Side::kLhs]; - DMatrix& rhs = params->src[Side::kRhs]; - - const int rows = lhs.layout.cols; - const int cols = rhs.layout.cols; - const int depth = lhs.layout.rows; - - const int tentative_thread_count = GetThreadCount(context, rows, cols, depth); - const auto loop_structure = GetLoopStructure( - tentative_thread_count, rows, cols, depth, lhs.data_type.size, - rhs.data_type.size, params->local_data_cache_size, - params->shared_data_cache_size); - Allocator* allocator = context->GetMainAllocator(); - - // Allocate packed matrices - for (Side side : {Side::kLhs, Side::kRhs}) { - if (!params->is_prepacked[side]) { - AllocatePMatrix(allocator, ¶ms->packed[side]); - } - } - - // Case of running this TrMul as a simple loop. - // This is a good place to start reading this function: all the rest - // of this function is just an optimized, but functionally equivalent, - // version of that. - if (loop_structure == LoopStructure::kSimple) { - profiler::ScopeLabel label_simple("TrMulImpl, simple loop"); - Tuning tuning = context->GetMainThreadTuning(); - - const SidePair origin{0, 0}; - const SidePair rounded_dims{packed_lhs.layout.cols, - packed_rhs.layout.cols}; - for (Side side : {Side::kLhs, Side::kRhs}) { - if (!params->is_prepacked[side]) { - params->RunPack(side, tuning, origin[side], rounded_dims[side]); - } - } - params->RunKernel(tuning, origin, rounded_dims); - - allocator->FreeAll(); - return; - } - - profiler::ScopeLabel label_general("TrMulImpl, general case"); - - auto* trace = NewTraceOrNull(&context->tracing, rows, depth, cols); - TraceRecordStart(trace); - - // Initialize block map. - BlockMap block_map; - MakeBlockMap(packed_lhs.layout.cols, packed_rhs.layout.cols, depth, - packed_lhs.layout.kernel.cols, packed_rhs.layout.kernel.cols, - packed_lhs.data_type.size, packed_rhs.data_type.size, - tentative_thread_count, params->path, - params->local_data_cache_size, params->shared_data_cache_size, - &block_map); - - // Initialize per-thread state. - const int thread_count = block_map.thread_count; - const bool need_atomics = thread_count > 1; - context->EnsureNPerThreadStates(thread_count); - for (auto& per_thread_state : context->per_thread_states) { - per_thread_state->tuning_resolver.SetTuning(context->explicit_tuning); - } - - // In the need_atomics case, allocate and initialize atomic values tracking - // the packing status of blocks. - SidePair*> packing_status{nullptr, nullptr}; - if (need_atomics) { - for (Side side : {Side::kLhs, Side::kRhs}) { - if (!params->is_prepacked[side]) { - const int size = NumBlocksPerSide(side, block_map); - allocator->Allocate(size, &packing_status[side]); - for (int i = 0; i < size; i++) { - packing_status[side][i].store(PackingStatus::kNotStarted, - std::memory_order_relaxed); - } - } - } - } - - // Create the atomic block id, allocate it using Allocator so that - // we get the alignment ensuring that it sits alone in its exclusives - // reservation granule. - std::atomic* atomic_block_id; - allocator->Allocate(1, &atomic_block_id); - - // Create task objects. - TrMulTask* tasks; - allocator->Allocate(thread_count, &tasks); - - atomic_block_id->store(thread_count); - - for (int i = 0; i < thread_count; i++) { - new (tasks + i) TrMulTask(params, block_map, atomic_block_id, i, - need_atomics, packing_status, - &context->per_thread_states[i]->tuning_resolver, - &context->per_thread_states[i]->allocator, trace); - } - - // Do the computation. - TraceRecordExecute(block_map, trace); - context->workers_pool.Execute(thread_count, tasks); - - // Finish up. - for (int i = 0; i < thread_count; i++) { - tasks[i].~TrMulTask(); - } - - allocator->FreeAll(); - TraceRecordEnd(trace); -} - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/trmul.h b/tensorflow/lite/experimental/ruy/ruy/trmul.h deleted file mode 100644 index 9786b7f6180..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/trmul.h +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// As a matrix multiplication library, Ruy offers a Mul entry point, performing -// matrix multiplication. For implementation purposes, it is much nicer to -// be dealing with the transpose-and-multiply operation, doing -// Destination = Transpose(LHS) * RHS -// Indeed, the latter is performing dot-products between the *columns* of LHS -// and the columns of RHS, whereas a plain matrix multiplication is performing -// dot-products between the *rows* of LHS and the columns of RHS. -// That is why TrMul is nicer to implement, allowing for a more symmetric -// treatment of LHS and RHS. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRMUL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRMUL_H_ - -#include "tensorflow/lite/experimental/ruy/ruy/context.h" -#include "tensorflow/lite/experimental/ruy/ruy/trmul_params.h" - -namespace ruy { - -void TrMul(TrMulParams* params, Context* context); - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRMUL_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/trmul_params.h b/tensorflow/lite/experimental/ruy/ruy/trmul_params.h deleted file mode 100644 index c694f16b938..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/trmul_params.h +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRMUL_PARAMS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRMUL_PARAMS_H_ - -#include "tensorflow/lite/experimental/ruy/ruy/internal_matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/side_pair.h" -#include "tensorflow/lite/experimental/ruy/ruy/tune.h" - -namespace ruy { - -using RunKernelFn = void(Tuning, const SidePair&, void*, - const SidePair&, const SidePair&, DMatrix*); - -using RunPackFn = void(Tuning, const DMatrix&, PMatrix*, int, int); - -// Type-erased data needed for implementing TrMul. -struct TrMulParams { - TrMulParams() : run_pack{nullptr, nullptr}, is_prepacked{false, false} {} - // Helper functions for invoking the function pointers. - void RunPack(Side side, Tuning tuning, int start, int end) { - run_pack[side](tuning, src[side], &packed[side], start, end); - } - void RunKernel(Tuning tuning, const SidePair& start, - const SidePair& end) { - run_kernel(tuning, packed, spec, start, end, &dst); - } - - // path id, can be useful info for some fine-tuning, e.g. to guess reasonable - // cache sizes when not runtime-detectable. - Path path; - - // See Spec::local_data_cache_size(). - int local_data_cache_size = 0; - // See Spec::shared_data_cache_size(). - int shared_data_cache_size = 0; - - // Function pointers to type-erased entry points for kernels and packers. - SidePair run_pack; - RunKernelFn* run_kernel = nullptr; - - // Matrices and packed matrices. - SidePair src; - DMatrix dst; - SidePair packed; - SidePair is_prepacked; - - // Type-erased Spec. - void* spec = nullptr; -}; - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRMUL_PARAMS_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/tune.cc b/tensorflow/lite/experimental/ruy/ruy/tune.cc deleted file mode 100644 index 63fa0338d6d..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/tune.cc +++ /dev/null @@ -1,161 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/tune.h" - -#include -#include - -namespace ruy { - -#ifdef RUY_IMPLEMENT_TUNING - -namespace { - -void PoorlyOrderedKernel(int iters) { - asm volatile( - "mov w0, %w[iters]\n" - "1:\n" - "subs w0, w0, #1\n" - "mul v0.4s, v0.4s, v0.4s\n" - "mul v0.4s, v0.4s, v0.4s\n" - "mul v0.4s, v0.4s, v0.4s\n" - "mul v0.4s, v0.4s, v0.4s\n" - "mul v1.4s, v1.4s, v1.4s\n" - "mul v1.4s, v1.4s, v1.4s\n" - "mul v1.4s, v1.4s, v1.4s\n" - "mul v1.4s, v1.4s, v1.4s\n" - "mul v2.4s, v2.4s, v2.4s\n" - "mul v2.4s, v2.4s, v2.4s\n" - "mul v2.4s, v2.4s, v2.4s\n" - "mul v2.4s, v2.4s, v2.4s\n" - "mul v3.4s, v3.4s, v3.4s\n" - "mul v3.4s, v3.4s, v3.4s\n" - "mul v3.4s, v3.4s, v3.4s\n" - "mul v3.4s, v3.4s, v3.4s\n" - "bne 1b\n" ::[iters] "r"(iters) - : "cc", "x0", "v0", "v1", "v2", "v3"); -} - -void NicelyOrderedKernel(int iters) { - asm volatile( - "mov w0, %w[iters]\n" - "1:\n" - "subs w0, w0, #1\n" - "mul v0.4s, v0.4s, v0.4s\n" - "mul v1.4s, v1.4s, v1.4s\n" - "mul v2.4s, v2.4s, v2.4s\n" - "mul v3.4s, v3.4s, v3.4s\n" - "mul v0.4s, v0.4s, v0.4s\n" - "mul v1.4s, v1.4s, v1.4s\n" - "mul v2.4s, v2.4s, v2.4s\n" - "mul v3.4s, v3.4s, v3.4s\n" - "mul v0.4s, v0.4s, v0.4s\n" - "mul v1.4s, v1.4s, v1.4s\n" - "mul v2.4s, v2.4s, v2.4s\n" - "mul v3.4s, v3.4s, v3.4s\n" - "mul v0.4s, v0.4s, v0.4s\n" - "mul v1.4s, v1.4s, v1.4s\n" - "mul v2.4s, v2.4s, v2.4s\n" - "mul v3.4s, v3.4s, v3.4s\n" - "bne 1b\n" ::[iters] "r"(iters) - : "cc", "x0", "v0", "v1", "v2", "v3"); -} - -} // namespace - -float TuningResolver::EvalRatio() { - // With the current settings, 400 iterations and 4 repeats, this test has - // a latency of roughly 80 microseconds on a Cortex-A53 at 1.4 GHz. - static constexpr int kLoopIters = 400; - static constexpr int kRepeats = 4; - - Duration timing_poorly_ordered = Duration::max(); - Duration timing_nicely_ordered = Duration::max(); - - for (int r = 0; r < kRepeats; r++) { - TimePoint t0 = Now(); - PoorlyOrderedKernel(kLoopIters); - TimePoint t1 = Now(); - NicelyOrderedKernel(kLoopIters); - TimePoint t2 = Now(); - timing_poorly_ordered = std::min(timing_poorly_ordered, t1 - t0); - timing_nicely_ordered = std::min(timing_nicely_ordered, t2 - t1); - } - - return ToFloatSeconds(timing_nicely_ordered) / - ToFloatSeconds(timing_poorly_ordered); -} - -float TuningResolver::ThresholdRatio() { - // Empirically (see :tune_tool) determined threshold to distinguish in-order - // Cortex-A53/A55 cores from out-of-order Cortex-A57/A73/A75/A76 cores. Based - // on these experimental results, which were obtained with much lower - // (kLoopIters=1000, kRepeats=1) so as to make them resilient to noise, we - // have: - // - // CPU core type | in/out of order | observed ratio - // --------------+-----------------+----------------------------------------- - // Cortex-A53 | in-order | 0.32 -- 0.329 - // Cortex-A55 | in-order | 0.319 -- 0.325 - // Cortex-A55r1 | in-order | 0.319 -- 0.325 - // Cortex-A57 | out-of-order | 0.99 -- 1.01 - // Cortex-A73 | out-of-order | 0.922 -- 0.927 - // Cortex-A75 | out-of-order | 0.921 -- 0.93 - // Cortex-A76 | out-of-order | 1 - // Kryo (pixel1) | out-of-order | 0.73 -- 0.76 - // - // Thus the allowable range for the threshold is [0.35 .. 0.70]. - // We pick a value closer to the upper bound because really any out-of-order - // CPU should by definition produce a ratio close to 1. - return 0.65f; -} - -Tuning TuningResolver::ResolveNow() { - const bool is_probably_inorder = EvalRatio() < ThresholdRatio(); - return is_probably_inorder ? Tuning::kInOrder : Tuning::kOutOfOrder; -} - -#else // not defined RUY_IMPLEMENT_TUNING - -float TuningResolver::EvalRatio() { return 0; } -float TuningResolver::ThresholdRatio() { return 0; } - -Tuning TuningResolver::ResolveNow() { return Tuning::kOutOfOrder; } - -#endif - -TuningResolver::TuningResolver() - : expiry_duration_(DurationFromMilliseconds(250)) {} - -Tuning TuningResolver::Resolve() { -#ifdef RUY_IMPLEMENT_TUNING - if (unresolved_tuning_ != Tuning::kAuto) { - return unresolved_tuning_; - } - TimePoint new_timepoint = CoarseNow(); - if (last_resolved_tuning_ != Tuning::kAuto && - (new_timepoint - last_resolved_timepoint_) < expiry_duration_) { - return last_resolved_tuning_; - } - last_resolved_timepoint_ = new_timepoint; - last_resolved_tuning_ = ResolveNow(); - return last_resolved_tuning_; -#else - return Tuning::kOutOfOrder; -#endif -} - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/tune.h b/tensorflow/lite/experimental/ruy/ruy/tune.h deleted file mode 100644 index 3471604e37a..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/tune.h +++ /dev/null @@ -1,163 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Library doing minimal CPU detection to decide what to tune asm code for. -// -// # Tuning vs Path -// -// Tunings are merely local variations of optimized code paths, that are -// drop-in replacements for each other --- the input and output data layouts -// are identical. By contrast, what ruy calls a Path dictates its own -// data layouts. For example, Path::kNeonDotprod will use different -// layouts compared to Path::kNeon; but within each, different tunings -// will share that same layout. -// -// # Tuning is for now only based on 1 bit: OutOfOrder / InOrder -// -// In practice, each of our asm code paths only needs one bit information to -// decide on tuning: whether the CPU is out-of-order or in-order. -// That is because out-of-order CPUs are by definition relatively insensitive -// to small-scale asm details (which is what "tuning" is about); and for each -// asm code path, there tends to be one main in-order CPU architecture that -// we focus our tuning effort on. Examples: -// * For Path::kNeon, the main in-order CPU is Cortex-A53/A55 (pre-dotprod) -// * For Path::kNeonDotprod, the main in-order CPU is Cortex-A55r1 (dotprod) -// -// Because having tuned code paths is a compromise of efficiency gains -// versus implementation effort and code size, we are happy to stop at just this -// single bit of information, OutOfOrder/InOrder, at least in the current CPU -// landscape. This could change in the future. -// -// # Implementation notes and alternatives. -// -// The current implementation uses a nano-benchmark, see tune.cc. -// That is why it's quite expensive, making caching / -// statefulness necessary (see TuningResolver class comment). -// -// An interesting alternative, which was explained to us by Marat Dukhan -// (maratek@) after this was implemented, would be to use the -// getcpu(2) system call on Linux. This returns a -// numeric CPU identifier that could be mapped to a OutOfOrder/InOrder -// classification given additional information about the CPU. Such -// additional information could be obtained by the cpuinfo library, -// https://github.com/pytorch/cpuinfo -// which obtains this information mainly from parsing /proc/cpuinfo. -// Pros: -// * Would remove the need for the relatively expensive nano-benchmark -// (dozens of microseconds, which have to be reevaluated again several -// times per second). -// * Would conceivably be more reliable. -// Cons: -// * Linux-specific. -// * Modest binary size increase (Marat mentioned the cpuinfo lib is 20k). -// * Won't support exactly 100% of devices (nonstandard /proc/cpuinfo etc). -// -// We could also have both: -// * Maybe by trying getcpu first if supported, then falling back to a -// nano-benchmark. -// * Maybe using getcpu in conjunction with the nano-benchmark to cache -// per-CPU-id nano-benchmark results. -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TUNE_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TUNE_H_ - -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/ruy/time.h" - -// Tuning only implemented on NEON_64 at the moment (see assembly code -// in the nano-benchmark) and not on Apple (some Apple CPUs produce incorrect -// results on in-order-tuned kernels combining ARM and NEON load instructions -// and NEON `ins` instructions). -// -// When tuning is not implemented, we simply always use Tuning::kOutOfOrder. -#if RUY_OPT_ENABLED(RUY_OPT_TUNING) && RUY_PLATFORM(NEON_64) && \ - !RUY_PLATFORM(APPLE) -#define RUY_IMPLEMENT_TUNING -#endif - -namespace ruy { - -enum class Tuning { - // kAuto means please use auto-detection. It's the default in the - // user-visible parts (see Context). It's meant to be resolved to an - // actual tuning at some point by means of TuningResolver. - kAuto, - // Target an out-order CPU. Example: ARM Cortex-A75. - kOutOfOrder, - // Target an in-order CPU. Example: ARM Cortex-A55. - kInOrder -}; - -// Why a TuningResolver class? -// -// Ideally, this Library would offer a single function, -// Tuning GetCurrentCPUTuning(); -// -// However, determining information about the current CPU is not necessarily, -// cheap, so we currently cache that and only invalidate/reevaluate after -// a fixed amount of time. This need to store state is why this library -// has to expose a class, TuningResolver, not just a function. -class TuningResolver { - public: - TuningResolver(); - - // Allows the user to specify an explicit Tuning value, bypassing auto - // detection; or to specify Tuning::kAuto, reverting to auto detection. - void SetTuning(Tuning tuning) { unresolved_tuning_ = tuning; } - - // Get an actual tuning --- that is the function that this class wanted to be. - Tuning Resolve(); - - private: - TuningResolver(const TuningResolver&) = delete; - - // TuningTool is a demo/tool used to tweak the tuning implementation to - // specific devices. It needs to access some finer granularity information - // than just the Tuning returned by Resolve. Nothing else should need - // access to that. - friend class TuneTool; - // Actually runs a nano-benchmark, producing a real number called 'ratio' - // whose meaning is generally opaque / implementation defined. Typically, - // this would be the ratio between the latencies of two different - // pieces of asm code differing only by the ordering of instructions, - // revealing whether the CPU cares about such ordering details. - // An implementation may just return a dummy value if it is not based on - // such nanobenchmarking / ratio evaluation. - float EvalRatio(); - // Empirically determined threshold on ratio values delineating - // out-of-order (ratios closer to 1) from in-order (ratios farther from 1). - // An implementation may just return a dummy value if it is not based on - // such nanobenchmarking / ratio evaluation. - float ThresholdRatio(); - // Perform the tuning resolution now. That may typically use EvalRatio and - // ThresholdRatio, but an implementation may use a different approach instead. - Tuning ResolveNow(); - - // The tuning as specified by the user, before actual resolution happens - // i.e. before querying any specifics of the current CPU. - // The default value kAuto means try to auto-detect. Other values mean - // bypass auto-detect, use explicit value instead. See SetTuning(). - Tuning unresolved_tuning_ = Tuning::kAuto; - // Cached last resolved tuning. - Tuning last_resolved_tuning_ = Tuning::kAuto; - // Timepoint of cached last resolved tuning, for invalidation purposes. - TimePoint last_resolved_timepoint_; - // Cached last resolved tunings that are older than this age are invalid. - const Duration expiry_duration_; -}; - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TUNE_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/tune_test.cc b/tensorflow/lite/experimental/ruy/ruy/tune_test.cc deleted file mode 100644 index 0b00e645195..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/tune_test.cc +++ /dev/null @@ -1,53 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/tune.h" - -#include // NOLINT(build/c++11) -#include // NOLINT(build/c++11) - -#include - -namespace ruy { -namespace { - -TEST(TuneTest, TuneTest) { - TuningResolver tuning_resolver; - ASSERT_FALSE(tuning_resolver.Resolve() == Tuning::kAuto); - // 1 second is likely higher than TuningResolver's internal cache expiry, - // exercising the logic invalidating earlier tuning resolutions. - std::this_thread::sleep_for(std::chrono::seconds(1)); - ASSERT_FALSE(tuning_resolver.Resolve() == Tuning::kAuto); - - tuning_resolver.SetTuning(Tuning::kAuto); - -#ifdef RUY_IMPLEMENT_TUNING - for (auto tuning : {Tuning::kOutOfOrder, Tuning::kInOrder}) { - tuning_resolver.SetTuning(tuning); - ASSERT_TRUE(tuning_resolver.Resolve() == tuning); - // See above comment about 1 second. - std::this_thread::sleep_for(std::chrono::seconds(1)); - ASSERT_TRUE(tuning_resolver.Resolve() == tuning); - } -#endif -} - -} // namespace -} // namespace ruy - -int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/tensorflow/lite/experimental/ruy/ruy/tune_tool.cc b/tensorflow/lite/experimental/ruy/ruy/tune_tool.cc deleted file mode 100644 index 04cfa6d6b89..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/tune_tool.cc +++ /dev/null @@ -1,56 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Self-contained tool used to tune the tune code --- see the -// threshold ratios used in tune.cc. - -#include // NOLINT(build/c++11) -#include -#include // NOLINT(build/c++11) - -#include "tensorflow/lite/experimental/ruy/ruy/tune.h" - -#ifdef _WIN32 -#define getpid() 0 -#else -#include -#endif - -namespace ruy { - -class TuneTool { - public: - static void Query(float* eval, float* threshold) { - TuningResolver resolver; - *eval = resolver.EvalRatio(); - *threshold = resolver.ThresholdRatio(); - } -}; - -} // namespace ruy - -int main() { - // Infinite loop: the user can hit Ctrl-C - while (true) { - float eval; - float threshold; - ruy::TuneTool::Query(&eval, &threshold); - printf("[%d] eval=%.3f %c threshold=%.3f ==> probably %s...\n", getpid(), - eval, eval < threshold ? '<' : '>', threshold, - eval < threshold ? "in-order" : "out-of-order"); - fflush(stdout); - std::this_thread::sleep_for(std::chrono::seconds(1)); - } -} diff --git a/tensorflow/lite/experimental/ruy/ruy/wait.cc b/tensorflow/lite/experimental/ruy/ruy/wait.cc deleted file mode 100644 index 7d91b6ebce6..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/wait.cc +++ /dev/null @@ -1,69 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/wait.h" - -#include // NOLINT(build/c++11) - -namespace ruy { - -void Wait(const std::function& condition, const Duration& spin_duration, - std::condition_variable* condvar, std::mutex* mutex) { - // First, trivial case where the `condition` is already true; - if (condition()) { - return; - } - - // Then try busy-waiting. - const TimePoint wait_start = Now(); - while (Now() - wait_start < spin_duration) { - if (condition()) { - return; - } - } - - // Finally, do real passive waiting. - std::unique_lock lock(*mutex); - condvar->wait(lock, condition); -} - -void Wait(const std::function& condition, - std::condition_variable* condvar, std::mutex* mutex) { - // This value was empirically derived with some microbenchmark, we don't have - // high confidence in it. - // - // TODO(b/135595069): make this value configurable at runtime. - // I almost wanted to file another bug to ask for experimenting in a more - // principled way to tune this value better, but this would have to be tuned - // on real end-to-end applications and we'd expect different applications to - // require different tunings. So the more important point is the need for - // this to be controllable by the application. - // - // That this value means that we may be sleeping substantially longer - // than a scheduler timeslice's duration is not necessarily surprising. The - // idea is to pick up quickly new work after having finished the previous - // workload. When it's new work within the same GEMM as the previous work, the - // time interval that we might be busy-waiting is very small, so for that - // purpose it would be more than enough to sleep for 1 ms. - // That is all what we would observe on a GEMM benchmark. However, in a real - // application, after having finished a GEMM, we might do unrelated work for - // a little while, then start on a new GEMM. In that case the wait interval - // may be a little longer. There may also not be another GEMM for a long time, - // in which case we'll end up passively waiting below. - const Duration spin_duration = DurationFromMilliseconds(2); - Wait(condition, spin_duration, condvar, mutex); -} - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/wait.h b/tensorflow/lite/experimental/ruy/ruy/wait.h deleted file mode 100644 index a3cd26282af..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/wait.h +++ /dev/null @@ -1,73 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_WAIT_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_WAIT_H_ - -#include // NOLINT(build/c++11) -#include -#include // NOLINT(build/c++11) - -#include "tensorflow/lite/experimental/ruy/ruy/time.h" - -namespace ruy { - -// Waits until some evaluation of `condition` has returned true. -// -// There is no guarantee that calling `condition` again after this function -// has returned would still return true. The only -// contract is that at some point during the execution of that function, -// `condition` has returned true. -// -// First does some spin-waiting for the specified `spin_duration`, -// then falls back to passive waiting for the given condvar, guarded -// by the given mutex. At this point it will try to acquire the mutex lock, -// around the waiting on the condition variable. -// Therefore, this function expects that the calling thread hasn't already -// locked the mutex before calling it. -// This function will always release the mutex lock before returning. -// -// The idea of doing some initial spin-waiting is to help get -// better and more consistent multithreading benefits for small GEMM sizes. -// Spin-waiting help ensuring that if we need to wake up soon after having -// started waiting, then we can wake up quickly (as opposed to, say, -// having to wait to be scheduled again by the OS). On the other hand, -// we must still eventually revert to passive waiting for longer waits -// (e.g. worker threads having finished a GEMM and waiting until the next GEMM) -// so as to avoid permanently spinning. -// -// In situations where other threads might have more useful things to do with -// these CPU cores than our spin-waiting, it may be best to reduce the value -// of `spin_duration`. Setting it to zero disables the spin-waiting entirely. -// -// There is a risk that the std::function used here might use a heap allocation -// to store its context. The expected usage pattern is that these functions' -// contexts will consist of a single pointer value (typically capturing only -// [this]), and that in this case the std::function implementation will use -// inline storage, avoiding a heap allocation. However, we can't effectively -// guard that assumption, and that's not a big concern anyway because the -// latency of a small heap allocation is probably low compared to the intrinsic -// latency of what this Wait function does. -void Wait(const std::function& condition, const Duration& spin_duration, - std::condition_variable* condvar, std::mutex* mutex); - -// Convenience overload using a default `spin_duration`. -// TODO(benoitjacob): let this be controlled from the ruy API. -void Wait(const std::function& condition, - std::condition_variable* condvar, std::mutex* mutex); - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_WAIT_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/wait_test.cc b/tensorflow/lite/experimental/ruy/ruy/wait_test.cc deleted file mode 100644 index b1b7558583d..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/wait_test.cc +++ /dev/null @@ -1,117 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/wait.h" - -#include -#include // NOLINT(build/c++11) -#include // NOLINT(build/c++11) -#include // NOLINT(build/c++11) - -#include -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" - -namespace ruy { -namespace { - -// Thread taking a `value` atomic counter and incrementing it until it equals -// `end_value`, then notifying the condition variable as long as -// `value == end_value`. If `end_value` is increased, it will then resume -// incrementing `value`, etc. Terminates if `end_value == -1`. -class ThreadCountingUpToValue { - public: - ThreadCountingUpToValue(const std::atomic& end_value, - std::atomic* value, - std::condition_variable* condvar, std::mutex* mutex) - : end_value_(end_value), - value_(value), - condvar_(condvar), - mutex_(mutex) {} - void operator()() { - // end_value_==-1 is how the master thread will tell us it's OK to terminate - while (end_value_.load() != -1) { - // wait until end_value is set to a higher value - while (value_->load() == end_value_.load()) { - } - // increment value as long as it's lower than end_value - while (value_->fetch_add(1) < end_value_.load() - 1) { - } - // when value has reached end_value, notify the master thread. - while (value_->load() == end_value_.load()) { - std::lock_guard lock(*mutex_); - condvar_->notify_all(); - } - } - } - - private: - const std::atomic& end_value_; - std::atomic* value_; - std::condition_variable* condvar_; - std::mutex* mutex_; -}; - -void WaitTest(const Duration& spin_duration, const Duration& delay) { -#if RUY_PLATFORM(EMSCRIPTEN) - // b/139927184, std::thread constructor raises exception - return; -#endif - std::condition_variable condvar; - std::mutex mutex; - std::atomic value(0); - std::atomic end_value(0); - ThreadCountingUpToValue thread_callable(end_value, &value, &condvar, &mutex); - std::thread thread(thread_callable); - std::this_thread::sleep_for(delay); - for (int i = 1; i < 10; i++) { - end_value.store(1000 * i); - const auto& condition = [&value, &end_value]() { - return value.load() == end_value.load(); - }; - ruy::Wait(condition, spin_duration, &condvar, &mutex); - EXPECT_EQ(value.load(), end_value.load()); - } - end_value.store(-1); - thread.join(); -} - -TEST(WaitTest, WaitTestNoSpin) { - WaitTest(DurationFromSeconds(0), DurationFromSeconds(0)); -} - -TEST(WaitTest, WaitTestSpinOneMicrosecond) { - WaitTest(DurationFromSeconds(1e-6), DurationFromSeconds(0)); -} - -TEST(WaitTest, WaitTestSpinOneMillisecond) { - WaitTest(DurationFromSeconds(1e-3), DurationFromSeconds(0)); -} - -TEST(WaitTest, WaitTestSpinOneSecond) { - WaitTest(DurationFromSeconds(1), DurationFromSeconds(0)); -} - -// Testcase to consistently reproduce the hang in b/139062384. -TEST(WaitTest, WaitTestNoSpinWithDelayBug139062384) { - WaitTest(DurationFromSeconds(0), DurationFromSeconds(1)); -} - -} // namespace -} // namespace ruy - -int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/image/ImageConversions.java b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/image/ImageConversions.java index cfaabda4731..b2b7a339a75 100644 --- a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/image/ImageConversions.java +++ b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/image/ImageConversions.java @@ -16,6 +16,7 @@ limitations under the License. package org.tensorflow.lite.support.image; import android.graphics.Bitmap; +import android.graphics.Color; import java.util.Arrays; import org.tensorflow.lite.DataType; import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; @@ -30,32 +31,31 @@ class ImageConversions { /** * Converts an Image in a TensorBuffer to a Bitmap, whose memory is already allocated. * - * Notice: We only support ARGB_8888 at this point. + *

Notice: We only support ARGB_8888 at this point. * * @param buffer The TensorBuffer object representing the image. It should be an UInt8 buffer with - * 3 dimensions: width, height, channel. Size of each dimension should be positive and the size of - * channels should be 3 (representing R, G, B). + * 3 dimensions: width, height, channel. Size of each dimension should be positive and the + * size of channels should be 3 (representing R, G, B). An optional 4th dimension "batch" is + * acceptable, and dimensions look like: batch, width, height, channel. In this case, size of + * batches should be 1. * @param bitmap The destination of the conversion. Needs to be created in advance, needs to be - * mutable, and needs to have the same width and height with the buffer. + * mutable, and needs to have the same width and height with the buffer. * @throws IllegalArgumentException 1) if the {@code buffer} is not uint8 (e.g. a float buffer), - * or has an invalid shape. 2) if the {@code bitmap} is not mutable. 3) if the {@code bitmap} has - * different height or width with the buffer. + * or has an invalid shape. 2) if the {@code bitmap} is not mutable. 3) if the {@code bitmap} + * has different height or width with the buffer. */ static void convertTensorBufferToBitmap(TensorBuffer buffer, Bitmap bitmap) { if (buffer.getDataType() != DataType.UINT8) { // We will add support to FLOAT format conversion in the future, as it may need other configs. - throw new UnsupportedOperationException(String.format( - "Converting TensorBuffer of type %s to Bitmap is not supported yet.", - buffer.getDataType())); + throw new UnsupportedOperationException( + String.format( + "Converting TensorBuffer of type %s to ARGB_8888 Bitmap is not supported yet.", + buffer.getDataType())); } int[] shape = buffer.getShape(); - if (shape.length != 3 || shape[0] <= 0 || shape[1] <= 0 || shape[2] != 3) { - throw new IllegalArgumentException(String.format( - "Buffer shape %s is not valid. 3D TensorBuffer with shape [w, h, 3] is required", - Arrays.toString(shape))); - } - int h = shape[0]; - int w = shape[1]; + TensorImage.checkImageTensorShape(shape); + int h = shape[shape.length - 3]; + int w = shape[shape.length - 2]; if (bitmap.getWidth() != w || bitmap.getHeight() != h) { throw new IllegalArgumentException(String.format( "Given bitmap has different width or height %s with the expected ones %s.", @@ -69,10 +69,10 @@ class ImageConversions { int[] intValues = new int[w * h]; int[] rgbValues = buffer.getIntArray(); for (int i = 0, j = 0; i < intValues.length; i++) { - byte r = (byte) rgbValues[j++]; - byte g = (byte) rgbValues[j++]; - byte b = (byte) rgbValues[j++]; - intValues[i] = ((r << 16) | (g << 8) | b); + int r = rgbValues[j++]; + int g = rgbValues[j++]; + int b = rgbValues[j++]; + intValues[i] = Color.rgb(r, g, b); } bitmap.setPixels(intValues, 0, w, 0, 0, w, h); } diff --git a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java index d72b1b8e02b..2d57749b7c7 100644 --- a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java +++ b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java @@ -130,12 +130,10 @@ public class TensorImage { * will be applied. * * @param pixels The RGB pixels representing the image. - * @param shape The shape of the image, should have 3 dims and the last dim should be 3. + * @param shape The shape of the image, should either in form (h, w, 3), or in form (1, h, w, 3). */ public void load(@NonNull float[] pixels, @NonNull int[] shape) { - SupportPreconditions.checkArgument( - shape.length == 3 && shape[2] == 3, - "Only supports image shape in (h, w, c), and channels representing R, G, B in order."); + checkImageTensorShape(shape); TensorBuffer buffer = TensorBuffer.createDynamic(getDataType()); buffer.loadArray(pixels, shape); load(buffer); @@ -148,12 +146,10 @@ public class TensorImage { * into [0, 255]. * * @param pixels The RGB pixels representing the image. - * @param shape The shape of the image, should have 3 dims and the last dim should be 3. + * @param shape The shape of the image, should either in form (h, w, 3), or in form (1, h, w, 3). */ public void load(@NonNull int[] pixels, @NonNull int[] shape) { - SupportPreconditions.checkArgument( - shape.length == 3 && shape[2] == 3, - "Only supports image shape in (h, w, c), and channels representing R, G, B in order."); + checkImageTensorShape(shape); TensorBuffer buffer = TensorBuffer.createDynamic(getDataType()); buffer.loadArray(pixels, shape); load(buffer); @@ -162,9 +158,10 @@ public class TensorImage { /** * Loads a TensorBuffer containing pixel values. The color layout should be RGB. * - * @param buffer The TensorBuffer to load. + * @param buffer The TensorBuffer to load. Its shape should be either (h, w, 3) or (1, h, w, 3). */ public void load(TensorBuffer buffer) { + checkImageTensorShape(buffer.getShape()); container.set(buffer); } @@ -222,6 +219,17 @@ public class TensorImage { return container.getDataType(); } + // Requires tensor shape [h, w, 3] or [1, h, w, 3]. + static void checkImageTensorShape(int[] shape) { + SupportPreconditions.checkArgument( + (shape.length == 3 || (shape.length == 4 && shape[0] == 1)) + && shape[shape.length - 3] > 0 + && shape[shape.length - 2] > 0 + && shape[shape.length - 1] == 3, + "Only supports image shape in (h, w, c) or (1, h, w, c), and channels representing R, G, B" + + " in order."); + } + // Handles RGB image data storage strategy of TensorBuffer. private static class ImageContainer { @@ -274,8 +282,8 @@ public class TensorImage { // Create a new bitmap and reallocate memory for it. if (bitmapImage == null || bitmapImage.getAllocationByteCount() < requiredAllocation) { int[] shape = bufferImage.getShape(); - int h = shape[0]; - int w = shape[1]; + int h = shape[shape.length - 3]; + int w = shape[shape.length - 2]; bitmapImage = Bitmap.createBitmap(w, h, Config.ARGB_8888); } ImageConversions.convertTensorBufferToBitmap(bufferImage, bitmapImage); diff --git a/tensorflow/lite/experimental/support/metadata/metadata.py b/tensorflow/lite/experimental/support/metadata/metadata.py index 1b5380352b8..25ca57bb4cc 100644 --- a/tensorflow/lite/experimental/support/metadata/metadata.py +++ b/tensorflow/lite/experimental/support/metadata/metadata.py @@ -97,8 +97,9 @@ class MetadataPopulator(object): Raises: IOError: File not found. + ValueError: the model does not have the expected flatbuffer identifer. """ - _assert_exist(model_file) + _assert_model_file_identifier(model_file) self._model_file = model_file self._metadata_buf = None self._associated_files = set() @@ -115,6 +116,7 @@ class MetadataPopulator(object): Raises: IOError: File not found. + ValueError: the model does not have the expected flatbuffer identifer. """ return cls(model_file) @@ -129,6 +131,9 @@ class MetadataPopulator(object): Returns: A MetadataPopulator(_MetadataPopulatorWithBuffer) object. + + Raises: + ValueError: the model does not have the expected flatbuffer identifer. """ return _MetadataPopulatorWithBuffer(model_buf) @@ -211,12 +216,13 @@ class MetadataPopulator(object): metadata_buf: metadata buffer (in bytearray) to be populated. Raises: - ValueError: - The metadata to be populated is empty. + ValueError: The metadata to be populated is empty. + ValueError: The metadata does not have the expected flatbuffer identifer. """ if not metadata_buf: raise ValueError("The metadata to be populated is empty.") + _assert_metadata_buffer_identifier(metadata_buf) self._metadata_buf = metadata_buf def load_metadata_file(self, metadata_file): @@ -226,8 +232,8 @@ class MetadataPopulator(object): metadata_file: path to the metadata file to be populated. Raises: - IOError: - File not found. + IOError: File not found. + ValueError: The metadata does not have the expected flatbuffer identifer. """ _assert_exist(metadata_file) with open(metadata_file, "rb") as f: @@ -391,6 +397,7 @@ class _MetadataPopulatorWithBuffer(MetadataPopulator): Raises: ValueError: model_buf is empty. + ValueError: model_buf does not have the expected flatbuffer identifer. """ if not model_buf: raise ValueError("model_buf cannot be empty.") @@ -423,6 +430,8 @@ class MetadataDisplayer(object): metadata_file: valid path to the metadata file. associated_file_list: list of associate files in the model file. """ + _assert_model_file_identifier(model_file) + _assert_metadata_file_identifier(metadata_file) self._model_file = model_file self._metadata_file = metadata_file self._associated_file_list = associated_file_list @@ -553,3 +562,32 @@ def _assert_exist(filename): """Checks if a file exists.""" if not os.path.exists(filename): raise IOError("File, '{0}', does not exist.".format(filename)) + + +def _assert_model_file_identifier(model_file): + """Checks if a model file has the expected TFLite schema identifier.""" + _assert_exist(model_file) + with open(model_file, "rb") as f: + model_buf = f.read() + + if not _schema_fb.Model.ModelBufferHasIdentifier(model_buf, 0): + raise ValueError( + "The model provided does not have the expected identifier, and " + "may not be a valid TFLite model.") + + +def _assert_metadata_file_identifier(metadata_file): + """Checks if a metadata file has the expected Metadata schema identifier.""" + _assert_exist(metadata_file) + with open(metadata_file, "rb") as f: + metadata_buf = f.read() + _assert_metadata_buffer_identifier(metadata_buf) + + +def _assert_metadata_buffer_identifier(metadata_buf): + """Checks if a metadata buffer has the expected Metadata schema identifier.""" + if not _metadata_fb.ModelMetadata.ModelMetadataBufferHasIdentifier( + metadata_buf, 0): + raise ValueError( + "The metadata buffer does not have the expected identifier, and may not" + " be a valid TFLite Metadata.") diff --git a/tensorflow/lite/experimental/support/metadata/metadata_test.py b/tensorflow/lite/experimental/support/metadata/metadata_test.py index 30f6a73e070..81b3eef62f9 100644 --- a/tensorflow/lite/experimental/support/metadata/metadata_test.py +++ b/tensorflow/lite/experimental/support/metadata/metadata_test.py @@ -102,6 +102,39 @@ class MetadataTest(test_util.TensorFlowTestCase): f.write(b.Output()) return metadata_file + def _create_model_buffer_with_wrong_identifier(self): + wrong_identifier = b"widn" + model = _schema_fb.ModelT() + model_builder = flatbuffers.Builder(0) + model_builder.Finish(model.Pack(model_builder), wrong_identifier) + return model_builder.Output() + + def _create_metadata_buffer_with_wrong_identifier(self): + # Creates a metadata with wrong identifier + wrong_identifier = b"widn" + metadata = _metadata_fb.ModelMetadataT() + metadata_builder = flatbuffers.Builder(0) + metadata_builder.Finish(metadata.Pack(metadata_builder), wrong_identifier) + return metadata_builder.Output() + + def _populate_metadata_with_identifier(self, model_buf, metadata_buf, + identifier): + # For testing purposes only. MetadataPopulator cannot populate metadata with + # wrong identifiers. + model = _schema_fb.ModelT.InitFromObj( + _schema_fb.Model.GetRootAsModel(model_buf, 0)) + buffer_field = _schema_fb.BufferT() + buffer_field.data = metadata_buf + model.buffers = [buffer_field] + # Creates a new metadata field. + metadata_field = _schema_fb.MetadataT() + metadata_field.name = _metadata.MetadataPopulator.METADATA_FIELD_NAME + metadata_field.buffer = len(model.buffers) - 1 + model.metadata = [metadata_field] + b = flatbuffers.Builder(0) + b.Finish(model.Pack(b), identifier) + return b.Output() + class MetadataPopulatorTest(MetadataTest): @@ -126,6 +159,14 @@ class MetadataPopulatorTest(MetadataTest): _metadata.MetadataPopulator.with_model_buffer(self._invalid_model_buf) self.assertEqual("model_buf cannot be empty.", str(error.exception)) + def testToModelBufferWithWrongIdentifier(self): + model_buf = self._create_model_buffer_with_wrong_identifier() + with self.assertRaises(ValueError) as error: + _metadata.MetadataPopulator.with_model_buffer(model_buf) + self.assertEqual( + "The model provided does not have the expected identifier, and " + "may not be a valid TFLite model.", str(error.exception)) + def testSinglePopulateAssociatedFile(self): populator = _metadata.MetadataPopulator.with_model_buffer( self._empty_model_buf) @@ -228,6 +269,15 @@ class MetadataPopulatorTest(MetadataTest): "not been loaded into the populator.").format( os.path.basename(self._file2)), str(error.exception)) + def testPopulateMetadataBufferWithWrongIdentifier(self): + metadata_buf = self._create_metadata_buffer_with_wrong_identifier() + populator = _metadata.MetadataPopulator.with_model_file(self._model_file) + with self.assertRaises(ValueError) as error: + populator.load_metadata_buffer(metadata_buf) + self.assertEqual( + "The metadata buffer does not have the expected identifier, and may not" + " be a valid TFLite Metadata.", str(error.exception)) + def _assert_golden_metadata(self, model_file): with open(model_file, "rb") as f: model_buf_from_file = f.read() @@ -332,6 +382,34 @@ class MetadataDisplayerTest(MetadataTest): populator.populate() return model_file + def test_load_model_buffer_metadataBufferWithWrongIdentifier_throwsException( + self): + model_buf = self._create_model_buffer_with_wrong_identifier() + metadata_buf = self._create_metadata_buffer_with_wrong_identifier() + model_buf = self._populate_metadata_with_identifier( + model_buf, metadata_buf, + _metadata.MetadataPopulator.TFLITE_FILE_IDENTIFIER) + with self.assertRaises(ValueError) as error: + _metadata.MetadataDisplayer.with_model_buffer(model_buf) + self.assertEqual( + "The metadata buffer does not have the expected identifier, and may not" + " be a valid TFLite Metadata.", str(error.exception)) + + def test_load_model_buffer_modelBufferWithWrongIdentifier_throwsException( + self): + model_buf = self._create_model_buffer_with_wrong_identifier() + metadata_file = self._create_metadata_file() + wrong_identifier = b"widn" + with open(metadata_file, "rb") as f: + metadata_buf = bytearray(f.read()) + model_buf = self._populate_metadata_with_identifier(model_buf, metadata_buf, + wrong_identifier) + with self.assertRaises(ValueError) as error: + _metadata.MetadataDisplayer.with_model_buffer(model_buf) + self.assertEqual( + "The model provided does not have the expected identifier, and " + "may not be a valid TFLite model.", str(error.exception)) + def test_load_model_file_invalidModelFile_throwsException(self): with self.assertRaises(IOError) as error: _metadata.MetadataDisplayer.with_model_file(self._invalid_file) diff --git a/tensorflow/lite/experimental/swift/Sources/CoreMLDelegate.swift b/tensorflow/lite/experimental/swift/Sources/CoreMLDelegate.swift new file mode 100644 index 00000000000..21e0276578c --- /dev/null +++ b/tensorflow/lite/experimental/swift/Sources/CoreMLDelegate.swift @@ -0,0 +1,50 @@ +// Copyright 2020 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import TensorFlowLiteC + +/// A delegate that uses the `Core ML` framework for performing TensorFlow Lite graph operations. +/// +/// - Important: This is an experimental interface that is subject to change. +public final class CoreMLDelegate: Delegate { + /// The configuration options for the `CoreMLDelegate`. + public let options: Options + + // Conformance to the `Delegate` protocol. + public private(set) var cDelegate: CDelegate + + /// Creates a new instance configured with the given `options`. + /// + /// - Parameters: + /// - options: Configurations for the delegate. The default is a new instance of + /// `CoreMLDelegate.Options` with the default configuration values. + public init(options: Options = Options()) { + self.options = options + var delegateOptions = TfLiteCoreMlDelegateOptions() + cDelegate = TfLiteCoreMlDelegateCreate(&delegateOptions) + } + + deinit { + TfLiteCoreMlDelegateDelete(cDelegate) + } +} + +extension CoreMLDelegate { + /// Options for configuring the `CoreMLDelegate`. + // TODO(b/143931022): Add preferred device support. + public struct Options: Equatable, Hashable { + /// Creates a new instance with the default values. + public init() {} + } +} diff --git a/tensorflow/lite/experimental/swift/Sources/Tensor.swift b/tensorflow/lite/experimental/swift/Sources/Tensor.swift index 457c0eb2dac..5b1a78183f8 100644 --- a/tensorflow/lite/experimental/swift/Sources/Tensor.swift +++ b/tensorflow/lite/experimental/swift/Sources/Tensor.swift @@ -73,6 +73,8 @@ extension Tensor { case float16 /// A 32-bit single precision floating point. case float32 + /// A 64-bit double precision floating point. + case float64 /// Creates a new instance from the given `TfLiteType` or `nil` if the data type is unsupported /// or could not be determined because there was an error. @@ -94,6 +96,8 @@ extension Tensor { self = .float16 case kTfLiteFloat32: self = .float32 + case kTfLiteFloat64: + self = .float64 case kTfLiteNoType: fallthrough default: diff --git a/tensorflow/lite/experimental/writer/enum_mapping.h b/tensorflow/lite/experimental/writer/enum_mapping.h index 77f7b26cbc2..b78d610c4c5 100644 --- a/tensorflow/lite/experimental/writer/enum_mapping.h +++ b/tensorflow/lite/experimental/writer/enum_mapping.h @@ -64,6 +64,8 @@ inline TensorType TfLiteTypeToSchemaType(TfLiteType type) { return TensorType_FLOAT32; case kTfLiteFloat16: return TensorType_FLOAT16; + case kTfLiteFloat64: + return TensorType_FLOAT64; case kTfLiteInt32: return TensorType_INT32; case kTfLiteUInt8: diff --git a/tensorflow/lite/g3doc/_book.yaml b/tensorflow/lite/g3doc/_book.yaml index a3980aee451..1a769615eef 100644 --- a/tensorflow/lite/g3doc/_book.yaml +++ b/tensorflow/lite/g3doc/_book.yaml @@ -127,6 +127,9 @@ upper_tabs: - title: "Hexagon delegate" path: /lite/performance/hexagon_delegate status: experimental + - title: "Core ML delegate" + path: /lite/performance/coreml_delegate + status: experimental - heading: "Optimize a model" - title: "Overview" diff --git a/tensorflow/lite/g3doc/convert/quantization.md b/tensorflow/lite/g3doc/convert/quantization.md index 9dfc7a2c20c..099921cf6b3 100644 --- a/tensorflow/lite/g3doc/convert/quantization.md +++ b/tensorflow/lite/g3doc/convert/quantization.md @@ -56,16 +56,23 @@ Convert the graph: converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(saved_model_dir) converter.inference_type = tf.lite.constants.QUANTIZED_UINT8 input_arrays = converter.get_input_arrays() -converter.quantized_input_stats = {input_arrays[0] : (0., 1.)} # mean, std_dev +converter.quantized_input_stats = {input_arrays[0] : (0., 1.)} # mean_value, std_dev tflite_model = converter.convert() ``` -For fully integer models, the inputs are uint8. The `mean` and `std_dev values` -specify how those uint8 values map to the float input values used while training -the model. +For fully integer models, the inputs are uint8. When the `inference_type` is set +to `QUANTIZED_UINT8` as above, the real_input_value is standardised using the +[standard-score](https://en.wikipedia.org/wiki/Standard_score) as follows: + +real_input_value = (quantized_input_value - mean_value) / std_dev_value + +The `mean_value` and `std_dev values` specify how those uint8 values map to the +float input values used while training the model. For more details, please see +the +[TFLiteConverter](https://www.tensorflow.org/api_docs/python/tf/compat/v1/lite/TFLiteConverter) `mean` is the integer value from 0 to 255 that maps to floating point 0.0f. -`std_dev` is 255 / (float_max - float_min) +`std_dev` is 255 / (float_max - float_min). For most users, we recommend using post-training quantization. We are working on new tools for post-training and during training quantization that we hope will diff --git a/tensorflow/lite/g3doc/models/style_transfer/overview.ipynb b/tensorflow/lite/g3doc/models/style_transfer/overview.ipynb index 0dbbcbb6ccb..fea744dffeb 100644 --- a/tensorflow/lite/g3doc/models/style_transfer/overview.ipynb +++ b/tensorflow/lite/g3doc/models/style_transfer/overview.ipynb @@ -143,19 +143,6 @@ "Import dependencies." ] }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "qZhQt7ObAHsc" - }, - "outputs": [], - "source": [ - "from __future__ import absolute_import, division, print_function, unicode_literals" - ] - }, { "cell_type": "code", "execution_count": 0, @@ -166,12 +153,8 @@ }, "outputs": [], "source": [ - "try:\n", - " # %tensorflow_version only exists in Colab.\n", - " import tensorflow.compat.v2 as tf\n", - "except Exception:\n", - " pass\n", - "tf.enable_v2_behavior()" + "import tensorflow as tf\n", + "print(tf.__version__)" ] }, { @@ -220,7 +203,7 @@ "style_path = tf.keras.utils.get_file('style23.jpg','https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/style23.jpg')\n", "\n", "style_predict_path = tf.keras.utils.get_file('style_predict.tflite', 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/arbitrary_style_transfer/style_predict_quantized_256.tflite')\n", - "style_transform_path = tf.keras.utils.get_file('style_transform.tflite', 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/arbitrary_style_transfer/style_transfer_quantized_dynamic.tflite')" + "style_transform_path = tf.keras.utils.get_file('style_transform.tflite', 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/arbitrary_style_transfer/style_transfer_quantized_384.tflite')" ] }, { @@ -234,7 +217,7 @@ "\n", "* The content image and the style image must be RGB images with pixel values being float32 numbers between [0..1].\n", "* The style image size must be (1, 256, 256, 3). We central crop the image and resize it.\n", - "* The content image can be any size. However, as we trained the model using square-cropped data, cropping the content image to a square results in better stylized image." + "* The content image must be (1, 384, 384, 3). We central crop the image and resize it." ] }, { @@ -256,37 +239,27 @@ "\n", " return img\n", "\n", - "# Function to pre-process style image input.\n", - "def preprocess_style_image(style_image):\n", + "# Function to pre-process by resizing an central cropping it.\n", + "def preprocess_image(image, target_dim):\n", " # Resize the image so that the shorter dimension becomes 256px.\n", - " target_dim = 256\n", - " shape = tf.cast(tf.shape(style_image)[1:-1], tf.float32)\n", + " shape = tf.cast(tf.shape(image)[1:-1], tf.float32)\n", " short_dim = min(shape)\n", " scale = target_dim / short_dim\n", " new_shape = tf.cast(shape * scale, tf.int32)\n", - " style_image = tf.image.resize(style_image, new_shape)\n", + " image = tf.image.resize(image, new_shape)\n", "\n", " # Central crop the image.\n", - " style_image = tf.image.resize_with_crop_or_pad(style_image, target_dim, target_dim)\n", + " image = tf.image.resize_with_crop_or_pad(image, target_dim, target_dim)\n", "\n", - " return style_image\n", - "\n", - "# Function to pre-process content image input.\n", - "def preprocess_content_image(content_image):\n", - " # Central crop the image.\n", - " shape = tf.shape(content_image)[1:-1]\n", - " short_dim = min(shape)\n", - " content_image = tf.image.resize_with_crop_or_pad(content_image, short_dim, short_dim)\n", - "\n", - " return content_image\n", + " return image\n", "\n", "# Load the input images.\n", "content_image = load_img(content_path)\n", "style_image = load_img(style_path)\n", "\n", "# Preprocess the input images.\n", - "preprocessed_content_image = preprocess_content_image(content_image)\n", - "preprocessed_style_image = preprocess_style_image(style_image)\n", + "preprocessed_content_image = preprocess_image(content_image, 384)\n", + "preprocessed_style_image = preprocess_image(style_image, 256)\n", "\n", "print('Style Image Shape:', preprocessed_style_image.shape)\n", "print('Content Image Shape:', preprocessed_content_image.shape)" @@ -407,8 +380,6 @@ "\n", " # Set model input.\n", " input_details = interpreter.get_input_details()\n", - " interpreter.resize_tensor_input(input_details[0][\"index\"],\n", - " preprocessed_content_image.shape)\n", " interpreter.allocate_tensors()\n", "\n", " # Set model inputs.\n", @@ -454,7 +425,7 @@ "source": [ "# Calculate style bottleneck of the content image.\n", "style_bottleneck_content = run_style_predict(\n", - " preprocess_style_image(content_image)\n", + " preprocess_image(content_image, 256)\n", " )" ] }, @@ -501,7 +472,7 @@ "\u003ctd\u003ePixel 3 (Android 10) \u003c/td\u003e \u003ctd\u003e142ms\u003c/td\u003e\u003ctd\u003e14ms*\u003c/td\u003e\u003c/tr\u003e\n", "\u003ctr\u003e\u003ctd\u003ePixel 4 (Android 10) \u003c/td\u003e \u003ctd\u003e5.2ms\u003c/td\u003e\u003ctd\u003e6.7ms*\u003c/td\u003e\u003c/tr\u003e\n", "\u003ctr\u003e\u003ctd\u003eiPhone XS (iOS 12.4.1) \u003c/td\u003e \u003ctd\u003e\u003c/td\u003e\u003ctd\u003e10.7ms**\u003c/td\u003e\u003c/tr\u003e\n", - "\u003ctr\u003e \u003ctd rowspan = 3\u003e \u003ca href=\"https://storage.googleapis.com/download.tensorflow.org/models/tflite/arbitrary_style_transfer/style_transfer_quantized_dynamic.tflite\"\u003eStyle transform model\u003c/a\u003e \u003c/td\u003e \n", + "\u003ctr\u003e \u003ctd rowspan = 3\u003e \u003ca href=\"https://storage.googleapis.com/download.tensorflow.org/models/tflite/arbitrary_style_transfer/style_transfer_quantized_384.tflite\"\u003eStyle transform model\u003c/a\u003e \u003c/td\u003e \n", "\u003ctd rowspan = 3\u003e0.2s Mb\u003c/td\u003e\n", "\u003ctd\u003ePixel 3 (Android 10) \u003c/td\u003e \u003ctd\u003e\u003c/td\u003e\u003ctd\u003e540ms*\u003c/td\u003e\u003c/tr\u003e\n", "\u003ctr\u003e\u003ctd\u003ePixel 4 (Android 10) \u003c/td\u003e \u003ctd\u003e\u003c/td\u003e\u003ctd\u003e405ms*\u003c/td\u003e\u003c/tr\u003e\n", @@ -516,13 +487,17 @@ "metadata": { "colab": { "collapsed_sections": [], + "last_runtime": { + "build_target": "//learning/brain/python/client:colab_notebook", + "kind": "private" + }, "name": "Artistic Style Transfer with TensorFlow Lite.ipynb", "private_outputs": true, "provenance": [] }, "kernelspec": { - "display_name": "Python 3", - "name": "python3" + "display_name": "Python 2", + "name": "python2" } }, "nbformat": 4, diff --git a/tensorflow/lite/g3doc/performance/coreml_delegate.md b/tensorflow/lite/g3doc/performance/coreml_delegate.md new file mode 100644 index 00000000000..9106bb93f70 --- /dev/null +++ b/tensorflow/lite/g3doc/performance/coreml_delegate.md @@ -0,0 +1,189 @@ +# Tensorflow Lite Core ML Delegate + +TensorFlow Lite Core ML Delegate enables running TensorFlow Lite models on +[Core ML framework](https://developer.apple.com/documentation/coreml), +which results in faster model inference on iOS devices. + +Note: This delegate is in experimental (beta) phase. + +Note: Core ML delegate is using Core ML version 2.1. + +**Supported iOS versions and devices:** + +* iOS 12 and later. In the older iOS versions, Core ML delegate will + automatically fallback to CPU. +* When running on iPhone Xs and later, it will use Neural Engine for faster + inference. + +**Supported models** + +The Core ML delegate currently supports float32 models. + +## Trying the Core ML delegate on your own model + +The Core ML delegate is already included in nightly release of TensorFlow lite +CocoaPods. To use Core ML delegate, change your TensorFlow lite pod +(`TensorflowLiteC` for C++ API, and `TensorFlowLiteSwift` for Swift) version to +`0.0.1-nightly` in your `Podfile`. + +``` +target 'YourProjectName' + # pod 'TensorFlowLiteSwift' + pod 'TensorFlowLiteSwift', '~> 0.0.1-nightly' +``` + +Note: After updating `Podfile`, you should run `pod cache clean` and `pod +update` to reflect changes. + +### Swift + +Initialize TensorFlow Lite interpreter with the Core ML delegate. + +```swift +let coreMlDelegate = CoreMLDelegate() +let interpreter = try Interpreter(modelPath: modelPath, + delegates: [coreMlDelegate]) +``` + +### Objective-C++ + +The Core ML delegate uses C++ API for Objective-C++ codes. + +#### Step 1. Include `coreml_delegate.h`. + +```objectivec++ +#include "tensorflow/lite/experimental/delegates/coreml/coreml_delegate.h" +``` + +#### Step 2. Create a delegate and initialize a TensorFlow Lite Interpreter + +After initializing the interpreter, call `interpreter->ModifyGraphWithDelegate` +with initialized Core ML delegate to apply the delegate. + +```objectivec++ +// initializer interpreter with model. +tflite::InterpreterBuilder(*model, resolver)(&interpreter); + +// Add following section to use the Core ML delegate. +TfLiteCoreMlDelegateOptions options = {}; +delegate = TfLiteCoreMlDelegateCreate(&options); +interpreter->ModifyGraphWithDelegate(delegate); + +// ... +``` + +#### Step 3. Dispose the delegate when it is no longer used. + +Add this code to the section where you dispose of the delegate (e.g. `dealloc` +of class). + +```objectivec++ +TfLiteCoreMlDelegateDelete(delegate); +``` + +## Supported ops + +Following ops are supported by the Core ML delegate. + +* Add + * Only certain shapes are broadcastable. In Core ML tensor layout, + following tensor shapes are broadcastable. `[B, C, H, W]`, `[B, C, 1, + 1]`, `[B, 1, H, W]`, `[B, 1, 1, 1]`. +* AveragePool2D +* Concat +* Conv2D + * Weights and bias should be constant. +* DepthwiseConv2D + * Weights and bias should be constant. +* Hardswish +* Logistic (aka Sigmoid) +* MaxPool2D +* Mul + * Only certain shapes are broadcastable. In Core ML tensor layout, + following tensor shapes are broadcastable. `[B, C, H, W]`, `[B, C, 1, + 1]`, `[B, 1, H, W]`, `[B, 1, 1, 1]`. +* Relu +* ReluN1To1 +* Relu6 +* Reshape +* ResizeBilinear +* SoftMax +* Tanh + +## Feedback + +For issues, please create a +[GitHub](https://github.com/tensorflow/tensorflow/issues/new?template=50-other-issues.md) +issue with all the necessary details to reproduce. + +## FAQ + +* Does CoreML delegate support fallback to CPU if a graph contains unsupported + ops? + * Yes +* Does CoreML delegate work on iOS Simulator? + * Yes. The library includes x86 and x86_64 targets so it can run on + a simulator, but you will not see performance boost over CPU. +* Does TensorFlow Lite and CoreML delegate support MacOS? + * TensorFlow Lite is only tested on iOS but not MacOS. +* Is custom TF Lite ops supported? + * No, CoreML delegate does not support custom ops and they will fallback to + CPU. + +## APIs + +### Core ML delegate Swift API + +```swift +/// A delegate that uses the `Core ML` framework for performing +/// TensorFlow Lite graph operations. +/// +/// - Important: This is an experimental interface that is subject to change. +public final class CoreMLDelegate: Delegate { + /// The configuration options for the `CoreMLDelegate`. + public let options: Options + + // Conformance to the `Delegate` protocol. + public private(set) var cDelegate: CDelegate + + * /// Creates a new instance configured with the given `options`. + /// + /// - Parameters: + /// - options: Configurations for the delegate. The default is a new instance of + /// `CoreMLDelegate.Options` with the default configuration values. + public init(options: Options = Options()) { + self.options = options + var delegateOptions = TfLiteCoreMlDelegateOptions() + cDelegate = TfLiteCoreMlDelegateCreate(&delegateOptions) + } + + deinit { + TfLiteCoreMlDelegateDelete(cDelegate) + } +} + +extension CoreMLDelegate { + /// Options for configuring the `CoreMLDelegate`. + public struct Options: Equatable, Hashable { + /// Creates a new instance with the default values. + public init() {} + } +} +``` + +### Core ML delegate C++ API + +```c++ +typedef struct { + // We have dummy for now as we can't have empty struct in C. + char dummy; +} TfLiteCoreMlDelegateOptions; + +// Return a delegate that uses CoreML for ops execution. +// Must outlive the interpreter. +TfLiteDelegate* TfLiteCoreMlDelegateCreate( + const TfLiteCoreMlDelegateOptions* options); + +// Do any needed cleanup and delete 'delegate'. +void TfLiteCoreMlDelegateDelete(TfLiteDelegate* delegate); +``` diff --git a/tensorflow/lite/g3doc/performance/delegates.md b/tensorflow/lite/g3doc/performance/delegates.md index 4f383b52e1f..4e6d7d09c73 100644 --- a/tensorflow/lite/g3doc/performance/delegates.md +++ b/tensorflow/lite/g3doc/performance/delegates.md @@ -1,17 +1,18 @@ # TensorFlow Lite delegates Note: Delegate API is still experimental and is subject to change. - ## What is a TensorFlow Lite delegate? -A TensorFlow Lite delegate is a way to delegate part or all of graph execution to another executor. - +A TensorFlow Lite delegate is a way to delegate part or all of graph execution +to another executor. ## Why should I use delegates? -Running inference on compute-heavy machine learning models on mobile devices is resource demanding due to the devices' limited processing and power. +Running inference on compute-heavy machine learning models on mobile devices is +resource demanding due to the devices' limited processing and power. -Instead of relying on the CPU, some devices have hardware accelerators, such as GPU or DSP, that allows for better performance and higher energy efficiency. +Instead of relying on the CPU, some devices have hardware accelerators, such as +GPU or DSP, that allows for better performance and higher energy efficiency. ## Using the built-in delegates @@ -33,6 +34,12 @@ TensorFlow Lite provides the following delegates for hardware acceleration: can be used on devices older version of Android OS that does not fully support NNAPI. See [TensorFlow Lite Hexagon delegate](hexagon_delegate.md) for more detail. +* **Core ML delegate for newer iPhones and iPads** - For newer iPhones and + iPads where Neural Engine is available, you can use Core ML delegate to + accelerate inference for 32-bit float based models. Neural Engine is + available Apple mobile devices with A12 SoC or higher. For an overview of + the Core ML delegate and step-by-step instructions, see + [TensorFlow Lite Core ML delegate](coreml_delegate.md). ## How do delegates work? @@ -40,16 +47,25 @@ Let's say we have a simple model graph such as the following: ![Original graph](../images/performance/tflite_delegate_graph_1.png "Original Graph") -If a delegate was provided for specific operations, then TensorFlow Lite will split the graph into multiple subgraphs where each subgraph will be handled by a delegate. +If a delegate was provided for specific operations, then TensorFlow Lite will +split the graph into multiple subgraphs where each subgraph will be handled by a +delegate. -Let's assume that there is a delegate "MyDelegate," which has a faster implementation for Conv2D and Mean operations. The resulting main graph will be updated to look like below. +Let's assume that there is a delegate "MyDelegate," which has a faster +implementation for Conv2D and Mean operations. The resulting main graph will be +updated to look like below. ![Graph with delegate](../images/performance/tflite_delegate_graph_2.png "Graph with delegate") -Each subgraph that is handled by a delegate will be replaced with a node that evaluates the subgraph on its invoked call. - -Depending on the model, the final graph can end up with one node, which means that all of the graphs were delegated or multiple nodes handled the subgraphs. In general, you don't want to have multiple subgraphs handled by the delegate, since each time you switch from delegate to the main graph, there is an overhead for passing the results from the subgraph to the main graph. It's not always safe to share memory. +Each subgraph that is handled by a delegate will be replaced with a node that +evaluates the subgraph on its invoked call. +Depending on the model, the final graph can end up with one node, which means +that all of the graphs were delegated or multiple nodes handled the subgraphs. +In general, you don't want to have multiple subgraphs handled by the delegate, +since each time you switch from delegate to the main graph, there is an overhead +for passing the results from the subgraph to the main graph. It's not always +safe to share memory. ## How to add a delegate @@ -57,12 +73,15 @@ _Note that the API used below is experimental and is subject to change._ Based on the previous section, to add a delegate, we need to do the following: +1. Define a kernel node that is responsible for evaluating the delegate + subgraph +1. Create an instance of + [TfLiteDelegate](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/c/common.h#L611), + which is responsible for registering the kernel node and claiming the nodes + that the delegate can execute - -1. Define a kernel node that is responsible for evaluating the delegate subgraph -1. Create an instance of [TfLiteDelegate](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/c/common.h#L611), which is responsible for registering the kernel node and claiming the nodes that the delegate can execute - -To see it in code, let's define a delegate and call it "MyDelegate," which can execute Conv2D and Mean operations faster. +To see it in code, let's define a delegate and call it "MyDelegate," which can +execute Conv2D and Mean operations faster. ``` // This is where the execution of the operations or whole graph happens. diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java index 6aeb06355b4..efcdc0e4c65 100644 --- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java +++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java @@ -314,6 +314,32 @@ public final class Interpreter implements AutoCloseable { wrapper.run(inputs, outputs); } + /** + * Expicitly updates allocations for all tensors, if necessary. + * + *

This will propagate shapes and memory allocations for all dependent tensors using the input + * tensor shape(s) as given. + * + *

Note: This call is *purely optional*. Tensor allocation will occur automatically during + * execution if any input tensors have been resized. This call is most useful in determining the + * shapes for any output tensors before executing the graph, e.g., + *

{@code
+   * interpreter.resizeInput(0, new int[]{1, 4, 4, 3}));
+   * interpreter.allocateTensors();
+   * FloatBuffer input = FloatBuffer.allocate(interpreter.getInputTensor(0),numElements());
+   * // Populate inputs...
+   * FloatBuffer output = FloatBuffer.allocate(interpreter.getOutputTensor(0).numElements());
+   * interpreter.run(input, output)
+   * // Process outputs...
+   * }
+ * + * @throws IllegalStateException if the graph's tensors could not be successfully allocated. + */ + public void allocateTensors() { + checkNotClosed(); + wrapper.allocateTensors(); + } + /** * Resizes idx-th input of the native model to the given dims. * @@ -373,6 +399,13 @@ public final class Interpreter implements AutoCloseable { /** * Gets the Tensor associated with the provdied output index. * + *

Note: Output tensor details (e.g., shape) may not be fully populated until after inference + * is executed. If you need updated details *before* running inference (e.g., after resizing an + * input tensor, which may invalidate output tensor shapes), use {@link #allocateTensors()} to + * explicitly trigger allocation and shape propagation. Note that, for graphs with output shapes + * that are dependent on input *values*, the output shape may not be fully determined until + * running inference. + * * @throws IllegalArgumentException if {@code outputIndex} is negtive or is not smaller than the * number of model outputs. */ diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java index ca21ec5c7ea..73fe506f131 100644 --- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java +++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java @@ -175,6 +175,8 @@ final class NativeInterpreterWrapper implements AutoCloseable { /** Resizes dimensions of a specific input. */ void resizeInput(int idx, int[] dims) { if (resizeInput(interpreterHandle, errorHandle, idx, dims)) { + // Tensor allocation is deferred until either an explicit `allocateTensors()` call or + // `invoke()` avoiding redundant allocations if multiple tensors are simultaneosly resized. isMemoryAllocated = false; if (inputTensors[idx] != null) { inputTensors[idx].refreshShape(); @@ -185,6 +187,23 @@ final class NativeInterpreterWrapper implements AutoCloseable { private static native boolean resizeInput( long interpreterHandle, long errorHandle, int inputIdx, int[] dims); + /** Triggers explicit allocation of tensors. */ + void allocateTensors() { + if (isMemoryAllocated) { + return; + } + + isMemoryAllocated = true; + allocateTensors(interpreterHandle, errorHandle); + for (int i = 0; i < outputTensors.length; ++i) { + if (outputTensors[i] != null) { + outputTensors[i].refreshShape(); + } + } + } + + private static native long allocateTensors(long interpreterHandle, long errorHandle); + void setUseNNAPI(boolean useNNAPI) { useNNAPI(interpreterHandle, useNNAPI); } @@ -385,8 +404,6 @@ final class NativeInterpreterWrapper implements AutoCloseable { // List of owned delegates that must be closed when the interpreter is closed. private final List ownedDelegates = new ArrayList<>(); - private static native long allocateTensors(long interpreterHandle, long errorHandle); - private static native boolean hasUnresolvedFlexOp(long interpreterHandle); private static native int getInputTensorIndex(long interpreterHandle, int inputIdx); diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java index 8b18e1764ce..b38f1ad771d 100644 --- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java +++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java @@ -222,6 +222,32 @@ public final class InterpreterTest { } } + @Test + public void testAllocateTensors() { + try (Interpreter interpreter = new Interpreter(MODEL_BUFFER)) { + // Redundant allocateTensors() should have no effect. + interpreter.allocateTensors(); + + // allocateTensors() should propagate resizes. + int[] inputDims = {1}; + assertThat(interpreter.getOutputTensor(0).shape()).isNotEqualTo(inputDims); + interpreter.resizeInput(0, inputDims); + assertThat(interpreter.getOutputTensor(0).shape()).isNotEqualTo(inputDims); + interpreter.allocateTensors(); + assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(inputDims); + + // Additional redundant calls should have no effect. + interpreter.allocateTensors(); + assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(inputDims); + + // Execution should succeed as expected. + ByteBuffer input = ByteBuffer.allocateDirect(4).order(ByteOrder.nativeOrder()); + ByteBuffer output = ByteBuffer.allocateDirect(4).order(ByteOrder.nativeOrder()); + interpreter.run(input, output); + assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(inputDims); + } + } + @Test public void testUnknownDims() { try (Interpreter interpreter = new Interpreter(UNKNOWN_DIMS_MODEL_PATH_BUFFER)) { diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index 28eefb2895f..a4d188f34da 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -274,7 +274,7 @@ cc_library( # For now this unconditionally depends on both ruy and gemmlowp. # See the comment inside class CpuBackendContext on the # gemmlowp_context_ and ruy_context_ members. - "//tensorflow/lite/experimental/ruy/ruy:context", + "@ruy//ruy:context", "@gemmlowp", "//tensorflow/lite:external_cpu_backend_context", ], @@ -295,8 +295,8 @@ cc_library( # We only need to depend on gemmlowp when tflite_with_ruy # is false, but putting these dependencies in a select() seems to # defeat copybara's rewriting rules. - "//tensorflow/lite/experimental/ruy/ruy:context", - "//tensorflow/lite/experimental/ruy/ruy:thread_pool", + "@ruy//ruy:context", + "@ruy//ruy:thread_pool", "@gemmlowp", ], ) @@ -334,9 +334,9 @@ cc_library( ":cpu_backend_threadpool", # Depend on ruy regardless of `tflite_with_ruy`. See the comment in # cpu_backend_gemm.h about why ruy is the generic path. - "//tensorflow/lite/experimental/ruy/ruy", - "//tensorflow/lite/experimental/ruy/ruy:path", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", + "@ruy//ruy", + "@ruy//ruy:path", + "@ruy//ruy/profiler:instrumentation", # We only need to depend on gemmlowp and Eigen when tflite_with_ruy # is false, but putting these dependencies in a select() seems to # defeat copybara's rewriting rules. @@ -355,7 +355,7 @@ cc_test( "@com_google_googletest//:gtest", # ruy's reference path provides the reference implementation # that this test compares against. - "//tensorflow/lite/experimental/ruy/ruy", + "@ruy//ruy", ], ) @@ -596,11 +596,11 @@ cc_library( "//tensorflow/lite:context", "//tensorflow/lite/c:common", "//tensorflow/lite/experimental/kernels:hashtable_op_kernels", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", "//tensorflow/lite/kernels/internal:kernel_utils", "//tensorflow/lite/kernels/internal:tensor", "//third_party/fft2d:fft2d_headers", "@fft2d", + "@ruy//ruy/profiler:instrumentation", ], ) @@ -613,13 +613,13 @@ cc_library( ":cpu_backend_context", ":op_macros", "//tensorflow/lite/c:common", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", "//tensorflow/lite/kernels/internal:common", "//tensorflow/lite/kernels/internal:compatibility", "//tensorflow/lite/kernels/internal:kernel_utils", "//tensorflow/lite/kernels/internal:quantization_util", "//tensorflow/lite/kernels/internal:tensor", "//tensorflow/lite/kernels/internal:tensor_utils", + "@ruy//ruy/profiler:instrumentation", ], ) diff --git a/tensorflow/lite/kernels/basic_rnn.cc b/tensorflow/lite/kernels/basic_rnn.cc index f21b8a910dd..920e8cd223a 100644 --- a/tensorflow/lite/kernels/basic_rnn.cc +++ b/tensorflow/lite/kernels/basic_rnn.cc @@ -26,6 +26,15 @@ namespace ops { namespace builtin { namespace rnn { +namespace { + +struct OpData { + int scratch_tensor_index; + bool compute_row_sums = false; +}; + +} // namespace + constexpr int kInputTensor = 0; constexpr int kWeightsTensor = 1; constexpr int kRecurrentWeightsTensor = 2; @@ -36,13 +45,14 @@ constexpr int kHiddenStateTensor = 4; constexpr int kOutputTensor = 0; void* Init(TfLiteContext* context, const char* buffer, size_t length) { - auto* scratch_tensor_index = new int; - context->AddTensors(context, /*tensors_to_add=*/3, scratch_tensor_index); - return scratch_tensor_index; + auto* op_data = new OpData(); + context->AddTensors(context, /*tensors_to_add=*/6, + &op_data->scratch_tensor_index); + return op_data; } void Free(TfLiteContext* context, void* buffer) { - delete reinterpret_cast(buffer); + delete reinterpret_cast(buffer); } TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { @@ -89,10 +99,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Allocate temporary tensors to store quantized values of input and // hidden_state tensors. if (is_hybrid) { - int* scratch_tensor_index = reinterpret_cast(node->user_data); + auto* op_data = reinterpret_cast(node->user_data); + op_data->compute_row_sums = true; TfLiteIntArrayFree(node->temporaries); - node->temporaries = TfLiteIntArrayCreate(3); - node->temporaries->data[0] = *scratch_tensor_index; + node->temporaries = TfLiteIntArrayCreate(6); + node->temporaries->data[0] = op_data->scratch_tensor_index; TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0); input_quantized->type = input_weights->type; input_quantized->allocation_type = kTfLiteArenaRw; @@ -101,7 +112,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized, input_quantized_size)); } - node->temporaries->data[1] = *scratch_tensor_index + 1; + node->temporaries->data[1] = op_data->scratch_tensor_index + 1; TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, /*index=*/1); hidden_state_quantized->type = input_weights->type; @@ -114,7 +125,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { context->ResizeTensor(context, hidden_state_quantized, hidden_state_quantized_size)); } - node->temporaries->data[2] = *scratch_tensor_index + 2; + node->temporaries->data[2] = op_data->scratch_tensor_index + 2; TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2); scaling_factors->type = kTfLiteFloat32; scaling_factors->allocation_type = kTfLiteArenaRw; @@ -125,8 +136,43 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, scaling_factors_size)); } + node->temporaries->data[3] = op_data->scratch_tensor_index + 3; + TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/3); + accum_scratch->type = kTfLiteInt32; + accum_scratch->allocation_type = kTfLiteArenaRw; + int accum_scratch_dims[2] = {num_units, batch_size}; + if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2, + accum_scratch_dims)) { + TfLiteIntArray* accum_scratch_size = TfLiteIntArrayCreate(2); + accum_scratch_size->data[0] = accum_scratch_dims[0]; + accum_scratch_size->data[1] = accum_scratch_dims[1]; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, accum_scratch, + accum_scratch_size)); + } + node->temporaries->data[4] = op_data->scratch_tensor_index + 4; + TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/4); + zero_points->type = kTfLiteInt32; + zero_points->allocation_type = kTfLiteArenaRw; + int zero_points_dims[1] = {batch_size}; + if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) { + TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1); + zero_points_size->data[0] = batch_size; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points, + zero_points_size)); + } + node->temporaries->data[5] = op_data->scratch_tensor_index + 5; + TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/5); + row_sums->type = kTfLiteInt32; + row_sums->allocation_type = kTfLiteArenaRwPersistent; + int row_sums_dims[2] = {2, num_units}; + if (!TfLiteIntArrayEqualsArray(row_sums->dims, 2, row_sums_dims)) { + TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(2); + row_sums_size->data[0] = row_sums_dims[0]; + row_sums_size->data[1] = row_sums_dims[1]; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, row_sums, row_sums_size)); + } } - return kTfLiteOk; } @@ -165,7 +211,9 @@ TfLiteStatus EvalHybrid(const TfLiteTensor* input, TfLiteTensor* input_scratch, TfLiteTensor* hidden_state_scratch, TfLiteTensor* scaling_factors, - TfLiteTensor* hidden_state, TfLiteTensor* output) { + TfLiteTensor* hidden_state, TfLiteTensor* output, + TfLiteTensor* zero_points, TfLiteTensor* accum_scratch, + TfLiteTensor* row_sums, bool* compute_row_sums) { const int batch_size = input->dims->data[0]; const int num_units = input_weights->dims->data[0]; const int input_size = input->dims->data[1]; @@ -190,26 +238,34 @@ TfLiteStatus EvalHybrid(const TfLiteTensor* input, int8_t* quantized_hidden_state_ptr = GetTensorData(hidden_state_scratch); float* scaling_factors_ptr = GetTensorData(scaling_factors); - + int32_t* accum_scratch_ptr = GetTensorData(accum_scratch); + int32_t* zero_points_ptr = nullptr; + int32_t* row_sums_ptr = nullptr; + if (params->asymmetric_quantize_inputs) { + zero_points_ptr = GetTensorData(zero_points); + row_sums_ptr = GetTensorData(row_sums); + } kernel_utils::RnnBatchStep( input_ptr_batch, input_weights_ptr, input_weights_scale, recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size, num_units, batch_size, output_batch_leading_dim, params->activation, quantized_input_ptr, quantized_hidden_state_ptr, scaling_factors_ptr, - hidden_state_ptr_batch, output_ptr_batch); + hidden_state_ptr_batch, output_ptr_batch, + params->asymmetric_quantize_inputs, zero_points_ptr, accum_scratch_ptr, + row_sums_ptr, compute_row_sums); return kTfLiteOk; } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); - + auto* op_data = reinterpret_cast(node->user_data); const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor); const TfLiteTensor* recurrent_weights = GetInput(context, node, kRecurrentWeightsTensor); const TfLiteTensor* bias = GetInput(context, node, kBiasTensor); TfLiteTensor* hidden_state = - GetVariableInput(context, node, kHiddenStateTensor); + &context->tensors[node->inputs->data[kHiddenStateTensor]]; TfLiteTensor* output = GetOutput(context, node, kOutputTensor); // We already checked that weight types are consistent, so branch on one. @@ -223,9 +279,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* input_quantized = GetTemporary(context, node, 0); TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1); TfLiteTensor* scaling_factors = GetTemporary(context, node, 2); + TfLiteTensor* accum_scratch = GetTemporary(context, node, 3); + TfLiteTensor* zero_points = GetTemporary(context, node, 4); + TfLiteTensor* row_sums = GetTemporary(context, node, 5); return EvalHybrid(input, input_weights, recurrent_weights, bias, params, input_quantized, hidden_state_quantized, - scaling_factors, hidden_state, output); + scaling_factors, hidden_state, output, zero_points, + accum_scratch, row_sums, &op_data->compute_row_sums); } default: context->ReportError(context, "Type %d not currently supported.", diff --git a/tensorflow/lite/kernels/basic_rnn_test.cc b/tensorflow/lite/kernels/basic_rnn_test.cc index b9c251ce044..f7cbaa5a814 100644 --- a/tensorflow/lite/kernels/basic_rnn_test.cc +++ b/tensorflow/lite/kernels/basic_rnn_test.cc @@ -175,7 +175,8 @@ class RNNOpModel : public SingleOpModel { public: RNNOpModel(int batches, int units, int size, const TensorType& weights = TensorType_FLOAT32, - const TensorType& recurrent_weights = TensorType_FLOAT32) + const TensorType& recurrent_weights = TensorType_FLOAT32, + bool asymmetric_quantize_inputs = false) : batches_(batches), units_(units), input_size_(size) { input_ = AddInput(TensorType_FLOAT32); weights_ = AddInput(weights); @@ -183,9 +184,10 @@ class RNNOpModel : public SingleOpModel { bias_ = AddInput(TensorType_FLOAT32); hidden_state_ = AddInput(TensorType_FLOAT32, true); output_ = AddOutput(TensorType_FLOAT32); - SetBuiltinOp( - BuiltinOperator_RNN, BuiltinOptions_RNNOptions, - CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union()); + SetBuiltinOp(BuiltinOperator_RNN, BuiltinOptions_RNNOptions, + CreateRNNOptions(builder_, ActivationFunctionType_RELU, + asymmetric_quantize_inputs) + .Union()); BuildInterpreter({{batches_, input_size_}, // input tensor {units_, input_size_}, // weights tensor {units_, units_}, // recurrent weights tensor @@ -233,8 +235,10 @@ class RNNOpModel : public SingleOpModel { // The hybrid model has quantized weights and recurrent_weights. class HybridRNNOpModel : public RNNOpModel { public: - HybridRNNOpModel(int batches, int units, int size, TensorType tensor_type) - : RNNOpModel(batches, units, size, tensor_type, tensor_type) { + HybridRNNOpModel(int batches, int units, int size, TensorType tensor_type, + bool asymmetric_quantize_inputs) + : RNNOpModel(batches, units, size, tensor_type, tensor_type, + asymmetric_quantize_inputs) { tensor_type_ = tensor_type; } @@ -282,8 +286,10 @@ TEST(RnnOpTest, BlackBoxTest) { } } -TEST(HybridRnnOpTest, BlackBoxTestUint8) { - HybridRNNOpModel rnn(2, 16, 8, TensorType_UINT8); +class HybridRnnOpTest : public ::testing::TestWithParam {}; + +TEST_P(HybridRnnOpTest, BlackBoxTestUint8) { + HybridRNNOpModel rnn(2, 16, 8, TensorType_UINT8, GetParam()); rnn.SetWeights(rnn_weights); rnn.SetBias(rnn_bias); rnn.SetRecurrentWeights(rnn_recurrent_weights); @@ -310,8 +316,8 @@ TEST(HybridRnnOpTest, BlackBoxTestUint8) { } } -TEST(HybridRnnOpTest, BlackBoxTestInt8) { - HybridRNNOpModel rnn(2, 16, 8, TensorType_INT8); +TEST_P(HybridRnnOpTest, BlackBoxTestInt8) { + HybridRNNOpModel rnn(2, 16, 8, TensorType_INT8, GetParam()); rnn.SetWeights(rnn_weights); rnn.SetBias(rnn_bias); rnn.SetRecurrentWeights(rnn_recurrent_weights); @@ -338,5 +344,8 @@ TEST(HybridRnnOpTest, BlackBoxTestInt8) { } } +INSTANTIATE_TEST_SUITE_P(HybridRnnOpTest, HybridRnnOpTest, + ::testing::ValuesIn({false, true})); + } // namespace } // namespace tflite diff --git a/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc index 33c43aacbc7..3a780eed0a0 100644 --- a/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc +++ b/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc @@ -139,18 +139,28 @@ enum TemporaryTensor { kProductScalingFactors = 8, kRecoveredCellWeights = 9, kAccumScratchBuffer = 10, - kAuxInputQuantized = 11, // Optional, quantized tensor for auxiliary input. - kNumTemporaryTensors + kZeroPoints = 11, + kFwRowSums = 12, + kBwRowSums = 13, + kAuxInputQuantized = 14, // Optional, quantized tensor for auxiliary input. + kNumTemporaryTensors = 15 +}; + +struct OpData { + int scratch_tensor_index; + bool compute_fw_row_sums = false; + bool compute_bw_row_sums = false; }; void* Init(TfLiteContext* context, const char* buffer, size_t length) { - auto* scratch_tensor_index = new int; - context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index); - return scratch_tensor_index; + auto* op_data = new OpData(); + context->AddTensors(context, kNumTemporaryTensors, + &op_data->scratch_tensor_index); + return op_data; } void Free(TfLiteContext* context, void* buffer) { - delete reinterpret_cast(buffer); + delete reinterpret_cast(buffer); } // Check that input tensor dimensions matches with each other. @@ -385,7 +395,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, // Resize the output and scratch tensors based on the sizes of the input // tensors. Also check that the size of the input tensors match each other. TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - int* scratch_tensor_index = reinterpret_cast(node->user_data); + auto* op_data = reinterpret_cast(node->user_data); const auto* params = reinterpret_cast( node->builtin_data); @@ -522,7 +532,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { node->temporaries = TfLiteIntArrayCreate(2); // the two scratch buffers. } // Create a scratch buffer tensor. - node->temporaries->data[kFwScratchBuffer] = *scratch_tensor_index; + node->temporaries->data[kFwScratchBuffer] = op_data->scratch_tensor_index; TfLiteTensor* fw_scratch_buffer = GetTemporary(context, node, kFwScratchBuffer); fw_scratch_buffer->type = input->type; @@ -581,7 +591,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Create a scratch buffer tensor. node->temporaries->data[kBwScratchBuffer] = - *(scratch_tensor_index) + kBwScratchBuffer; + op_data->scratch_tensor_index + kBwScratchBuffer; TfLiteTensor* bw_scratch_buffer = GetTemporary(context, node, kBwScratchBuffer); bw_scratch_buffer->type = input->type; @@ -606,10 +616,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_scratch_buffer, bw_scratch_buffer_size)); if (is_hybrid_op) { + // Compute the row sums for cached zero_point offset calculation. + op_data->compute_fw_row_sums = true; + op_data->compute_bw_row_sums = true; // Allocate temporary tensors to store quantized values of input, aux_input // (if present), activation_state and cell_state tensors. node->temporaries->data[kInputQuantized] = - *scratch_tensor_index + kInputQuantized; + op_data->scratch_tensor_index + kInputQuantized; TfLiteTensor* input_quantized = GetTemporary(context, node, kInputQuantized); input_quantized->type = fw_input_to_output_weights->type; @@ -621,7 +634,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } node->temporaries->data[kFwActivationStateQuantized] = - *scratch_tensor_index + kFwActivationStateQuantized; + op_data->scratch_tensor_index + kFwActivationStateQuantized; TfLiteTensor* fw_activation_state_quantized = GetTemporary(context, node, kFwActivationStateQuantized); fw_activation_state_quantized->type = fw_input_to_output_weights->type; @@ -635,7 +648,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { fw_activation_state_quantized_size)); } node->temporaries->data[kBwActivationStateQuantized] = - *scratch_tensor_index + kBwActivationStateQuantized; + op_data->scratch_tensor_index + kBwActivationStateQuantized; TfLiteTensor* bw_activation_state_quantized = GetTemporary(context, node, kBwActivationStateQuantized); bw_activation_state_quantized->type = fw_input_to_output_weights->type; @@ -649,7 +662,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { bw_activation_state_quantized_size)); } node->temporaries->data[kFwCellStateQuantized] = - *scratch_tensor_index + kFwCellStateQuantized; + op_data->scratch_tensor_index + kFwCellStateQuantized; TfLiteTensor* fw_cell_state_quantized = GetTemporary(context, node, kFwCellStateQuantized); fw_cell_state_quantized->type = fw_input_to_output_weights->type; @@ -663,7 +676,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { fw_cell_state_quantized_size)); } node->temporaries->data[kBwCellStateQuantized] = - *scratch_tensor_index + kBwCellStateQuantized; + op_data->scratch_tensor_index + kBwCellStateQuantized; TfLiteTensor* bw_cell_state_quantized = GetTemporary(context, node, kBwCellStateQuantized); bw_cell_state_quantized->type = fw_input_to_output_weights->type; @@ -683,7 +696,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // different matrices (which requires multiplying the scaling factors with // the scaling factor of the matrix). node->temporaries->data[kScalingFactors] = - *scratch_tensor_index + kScalingFactors; + op_data->scratch_tensor_index + kScalingFactors; TfLiteTensor* scaling_factors = GetTemporary(context, node, kScalingFactors); scaling_factors->type = kTfLiteFloat32; @@ -696,7 +709,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { scaling_factors_size)); } node->temporaries->data[kProductScalingFactors] = - *scratch_tensor_index + kProductScalingFactors; + op_data->scratch_tensor_index + kProductScalingFactors; TfLiteTensor* prod_scaling_factors = GetTemporary(context, node, kProductScalingFactors); prod_scaling_factors->type = kTfLiteFloat32; @@ -713,7 +726,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Allocate a temporary tensor to store the recovered cell weights. Since // this is used for diagonal matrices, only need to store n_cell values. node->temporaries->data[kRecoveredCellWeights] = - *scratch_tensor_index + kRecoveredCellWeights; + op_data->scratch_tensor_index + kRecoveredCellWeights; TfLiteTensor* recovered_cell_weights = GetTemporary(context, node, kRecoveredCellWeights); recovered_cell_weights->type = kTfLiteFloat32; @@ -730,7 +743,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Allocate a temporary tensor to store the accumulated int32 values. node->temporaries->data[kAccumScratchBuffer] = - *scratch_tensor_index + kAccumScratchBuffer; + op_data->scratch_tensor_index + kAccumScratchBuffer; TfLiteTensor* accum_scratch = GetTemporary(context, node, kAccumScratchBuffer); accum_scratch->type = kTfLiteInt32; @@ -750,11 +763,72 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { context, context->ResizeTensor(context, accum_scratch, accum_size)); } + // Allocate temporary tensors for storing zero-points. + node->temporaries->data[kZeroPoints] = + op_data->scratch_tensor_index + kZeroPoints; + TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints); + zero_points->type = kTfLiteFloat32; + zero_points->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, scaling_dims)) { + TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1); + zero_points_size->data[0] = n_batch; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points, + zero_points_size)); + } + + // Allocate temporary tensors for caching row sums for hybrid zero-point + // calculations. + int fw_row_sums_rows = fw_use_cifg ? 6 : 8; + if (has_aux_input) { + fw_row_sums_rows += fw_use_cifg ? 3 : 4; + } + const TfLiteTensor* fw_projection_weights = + GetOptionalInputTensor(context, node, kFwProjectionWeightsTensor); + if (fw_projection_weights != nullptr) { + fw_row_sums_rows += ceil(n_fw_output / n_fw_cell); + } + node->temporaries->data[kFwRowSums] = + op_data->scratch_tensor_index + kFwRowSums; + TfLiteTensor* fw_row_sums = GetTemporary(context, node, kFwRowSums); + fw_row_sums->type = kTfLiteInt32; + fw_row_sums->allocation_type = kTfLiteArenaRwPersistent; + int fw_row_sums_dims[2] = {fw_row_sums_rows, n_fw_cell}; + if (!TfLiteIntArrayEqualsArray(fw_row_sums->dims, 2, fw_row_sums_dims)) { + TfLiteIntArray* fw_hybrid_scratch_size = TfLiteIntArrayCreate(2); + fw_hybrid_scratch_size->data[0] = fw_row_sums_dims[0]; + fw_hybrid_scratch_size->data[1] = fw_row_sums_dims[1]; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_row_sums, + fw_hybrid_scratch_size)); + } + + int bw_row_sums_rows = bw_use_cifg ? 6 : 8; + if (has_aux_input) { + bw_row_sums_rows += bw_use_cifg ? 3 : 4; + } + const TfLiteTensor* bw_projection_weights = + GetOptionalInputTensor(context, node, kBwProjectionWeightsTensor); + if (bw_projection_weights != nullptr) { + bw_row_sums_rows += ceil(n_bw_output / n_bw_cell); + } + node->temporaries->data[kBwRowSums] = + op_data->scratch_tensor_index + kBwRowSums; + TfLiteTensor* bw_row_sums = GetTemporary(context, node, kBwRowSums); + bw_row_sums->type = kTfLiteInt32; + bw_row_sums->allocation_type = kTfLiteArenaRwPersistent; + int bw_row_sums_dims[2] = {bw_row_sums_rows, n_bw_cell}; + if (!TfLiteIntArrayEqualsArray(bw_row_sums->dims, 2, bw_row_sums_dims)) { + TfLiteIntArray* bw_row_sums_size = TfLiteIntArrayCreate(2); + bw_row_sums_size->data[0] = bw_row_sums_dims[0]; + bw_row_sums_size->data[1] = bw_row_sums_dims[1]; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_row_sums, + bw_row_sums_size)); + } + // Only allocate a temporary tensor for quantized auxiliary input if we are // actually going to use it. if (has_aux_input) { node->temporaries->data[kAuxInputQuantized] = - *scratch_tensor_index + kAuxInputQuantized; + op_data->scratch_tensor_index + kAuxInputQuantized; TfLiteTensor* aux_input_quantized = GetTemporary(context, node, kAuxInputQuantized); aux_input_quantized->type = fw_input_to_output_weights->type; @@ -775,7 +849,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const auto* params = reinterpret_cast( node->builtin_data); - + auto* op_data = reinterpret_cast(node->user_data); // Input tensor. const TfLiteTensor* input = GetInput(context, node, kInputTensor); @@ -909,7 +983,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // Populate a TfLiteLSTMParams struct for the evaluation functions. TfLiteLSTMParams lstm_params = {params->activation, params->cell_clip, - params->proj_clip, kTfLiteLSTMFullKernel}; + params->proj_clip, kTfLiteLSTMFullKernel, + params->asymmetric_quantize_inputs}; const int bw_output_offset = params->merge_outputs ? fw_recurrent_to_output_weights->dims->data[1] : 0; @@ -1003,7 +1078,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { : nullptr; TfLiteTensor* accum_scratch = GetTemporary(context, node, kAccumScratchBuffer); - + TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints); + TfLiteTensor* fw_row_sums = GetTemporary(context, node, kFwRowSums); + TfLiteTensor* bw_row_sums = GetTemporary(context, node, kBwRowSums); + const int fw_row_sums_size = fw_row_sums->dims->data[0]; + const int bw_row_sums_size = bw_row_sums->dims->data[0]; TfLiteStatus fw_pass_status = lstm_eval::EvalHybrid( input, fw_input_to_input_weights, fw_input_to_forget_weights, fw_input_to_cell_weights, fw_input_to_output_weights, @@ -1025,6 +1104,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { recovered_cell_weights, input_quantized, aux_input_quantized, fw_activation_state_quantized, fw_cell_state_quantized, fw_activation_state, fw_cell_state, accum_scratch, fw_output, + zero_points, fw_row_sums, fw_row_sums_size, + &op_data->compute_fw_row_sums, CpuBackendContext::GetFromContext(context)); TF_LITE_ENSURE_OK(context, fw_pass_status); @@ -1049,6 +1130,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { recovered_cell_weights, input_quantized, aux_input_quantized, bw_activation_state_quantized, bw_cell_state_quantized, bw_activation_state, bw_cell_state, accum_scratch, actual_bw_output, + zero_points, bw_row_sums, bw_row_sums_size, + &op_data->compute_bw_row_sums, CpuBackendContext::GetFromContext(context)); TF_LITE_ENSURE_OK(context, bw_pass_status); return kTfLiteOk; diff --git a/tensorflow/lite/kernels/bidirectional_sequence_lstm_test.cc b/tensorflow/lite/kernels/bidirectional_sequence_lstm_test.cc index 12b33c9661d..c468c4c09fb 100644 --- a/tensorflow/lite/kernels/bidirectional_sequence_lstm_test.cc +++ b/tensorflow/lite/kernels/bidirectional_sequence_lstm_test.cc @@ -40,7 +40,8 @@ class BidirectionalLSTMOpModel : public SingleOpModel { bool use_projection_bias, bool merge_outputs, bool use_aux_input, float cell_clip, float proj_clip, bool quantize_weights, bool time_major, - const std::vector>& input_shapes) + const std::vector>& input_shapes, + bool asymmetric_quantize_inputs = false) : n_batch_(n_batch), n_input_(n_input), n_fw_cell_(n_cell), @@ -207,12 +208,13 @@ class BidirectionalLSTMOpModel : public SingleOpModel { bw_aux_input_to_output_weights_ = AddNullInput(); } - SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM, - BuiltinOptions_BidirectionalSequenceLSTMOptions, - CreateBidirectionalSequenceLSTMOptions( - builder_, ActivationFunctionType_TANH, cell_clip, - proj_clip, merge_outputs, time_major) - .Union()); + SetBuiltinOp( + BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM, + BuiltinOptions_BidirectionalSequenceLSTMOptions, + CreateBidirectionalSequenceLSTMOptions( + builder_, ActivationFunctionType_TANH, cell_clip, proj_clip, + merge_outputs, time_major, asymmetric_quantize_inputs) + .Union()); BuildInterpreter(input_shapes); } @@ -424,11 +426,14 @@ class BidirectionalLSTMOpModel : public SingleOpModel { bool quantize_weights_; }; -// Declare LSTMOpTest as a parameterized test, where the parameter is a boolean -// indicating whether to use quantization or not. -class LSTMOpTest : public ::testing::TestWithParam {}; +// Declare LSTMOpTest as a parameterized test. +class LSTMOpTest + : public ::testing::TestWithParam<::testing::tuple> {}; -INSTANTIATE_TEST_SUITE_P(QuantizationOrNot, LSTMOpTest, ::testing::Bool()); +INSTANTIATE_TEST_SUITE_P(QuantizationOrNot, LSTMOpTest, + ::testing::Combine( + /*quantize_weights*/ ::testing::Bool(), + /*asymmetric_quantize_inputs*/ ::testing::Bool())); TEST_P(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { const int n_batch = 1; @@ -437,7 +442,9 @@ TEST_P(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { const int n_cell = 4; const int n_output = 4; const int sequence_length = 3; - const bool quantize_weights = GetParam(); + auto params = GetParam(); + const bool quantize_weights = std::get<0>(params); + const bool asymmetric_quantize_inputs = std::get<1>(params); BidirectionalLSTMOpModel lstm( n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, @@ -509,7 +516,8 @@ TEST_P(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { {0}, // aux_bw_input_to_forget tensor {0}, // aux_bw_input_to_cell tensor {0}, // aux_bw_input_to_output tensor - }); + }, + asymmetric_quantize_inputs); lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, -0.34550029, 0.04266912, -0.15680569, @@ -600,7 +608,9 @@ TEST_P(LSTMOpTest, BlackBoxTestMergedOutput) { const int n_cell = 4; const int n_output = 4; const int sequence_length = 3; - const bool quantize_weights = GetParam(); + auto params = GetParam(); + const bool quantize_weights = std::get<0>(params); + const bool asymmetric_quantize_inputs = std::get<1>(params); BidirectionalLSTMOpModel lstm( n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, @@ -672,7 +682,8 @@ TEST_P(LSTMOpTest, BlackBoxTestMergedOutput) { {0}, // aux_bw_input_to_forget tensor {0}, // aux_bw_input_to_cell tensor {0}, // aux_bw_input_to_output tensor - }); + }, + asymmetric_quantize_inputs); lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, -0.34550029, 0.04266912, -0.15680569, @@ -2631,7 +2642,9 @@ TEST_P(LSTMOpTest, BlackBoxTestWithAuxInputZeroAuxWeight) { const int n_cell = 4; const int n_output = 4; const int sequence_length = 3; - const bool quantize_weights = GetParam(); + auto params = GetParam(); + const bool quantize_weights = std::get<0>(params); + const bool asymmetric_quantize_inputs = std::get<1>(params); BidirectionalLSTMOpModel lstm( n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, @@ -2703,7 +2716,8 @@ TEST_P(LSTMOpTest, BlackBoxTestWithAuxInputZeroAuxWeight) { {n_cell, n_input}, // aux_bw_input_to_forget tensor {n_cell, n_input}, // aux_bw_input_to_cell tensor {n_cell, n_input}, // aux_bw_input_to_output tensor - }); + }, + asymmetric_quantize_inputs); lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, -0.34550029, 0.04266912, -0.15680569, @@ -2802,7 +2816,9 @@ TEST_P(LSTMOpTest, BlackBoxTestWithAuxInput) { const int n_cell = 4; const int n_output = 4; const int sequence_length = 3; - const bool quantize_weights = GetParam(); + auto params = GetParam(); + const bool quantize_weights = std::get<0>(params); + const bool asymmetric_quantize_inputs = std::get<1>(params); BidirectionalLSTMOpModel lstm( n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, @@ -2874,7 +2890,8 @@ TEST_P(LSTMOpTest, BlackBoxTestWithAuxInput) { {n_cell, n_input}, // aux_bw_input_to_forget tensor {n_cell, n_input}, // aux_bw_input_to_cell tensor {n_cell, n_input}, // aux_bw_input_to_output tensor - }); + }, + asymmetric_quantize_inputs); lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, -0.34550029, 0.04266912, -0.15680569, diff --git a/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc b/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc index db456d539b9..58a2ef9c1ea 100644 --- a/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc +++ b/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc @@ -27,6 +27,16 @@ namespace ops { namespace builtin { namespace bidirectional_sequence_rnn { +namespace { + +struct OpData { + int scratch_tensor_index; + bool fw_compute_row_sums = false; + bool bw_compute_row_sums = false; +}; + +} // namespace + // LINT.IfChange constexpr int kInputTensor = 0; @@ -58,18 +68,23 @@ enum TemporaryTensor { kFwHiddenStateQuantized = 1, kBwHiddenStateQuantized = 2, kScalingFactors = 3, - kAuxInputQuantized = 4, - kNumTemporaryTensors = 5 + kAccumScratch = 4, + kZeroPoints = 5, + kFwRowSums = 6, + kBwRowSums = 7, + kAuxInputQuantized = 8, + kNumTemporaryTensors = 9 }; void* Init(TfLiteContext* context, const char* buffer, size_t length) { - auto* scratch_tensor_index = new int; - context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index); - return scratch_tensor_index; + auto* op_data = new OpData(); + context->AddTensors(context, kNumTemporaryTensors, + &op_data->scratch_tensor_index); + return op_data; } void Free(TfLiteContext* context, void* buffer) { - delete reinterpret_cast(buffer); + delete reinterpret_cast(buffer); } TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { @@ -157,8 +172,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } if (IsHybridOp(input, fw_input_weights)) { - int* scratch_tensor_index = reinterpret_cast(node->user_data); - + OpData* op_data = reinterpret_cast(node->user_data); + op_data->fw_compute_row_sums = true; + op_data->bw_compute_row_sums = true; TfLiteIntArrayFree(node->temporaries); if (has_aux_input) { node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors); @@ -168,7 +184,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } node->temporaries->data[kInputQuantized] = - *scratch_tensor_index + kInputQuantized; + op_data->scratch_tensor_index + kInputQuantized; TfLiteTensor* input_quantized = GetTemporary(context, node, kInputQuantized); input_quantized->type = fw_input_weights->type; @@ -180,7 +196,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } node->temporaries->data[kFwHiddenStateQuantized] = - *scratch_tensor_index + kFwHiddenStateQuantized; + op_data->scratch_tensor_index + kFwHiddenStateQuantized; TfLiteTensor* fw_hidden_state_quantized = GetTemporary(context, node, kFwHiddenStateQuantized); fw_hidden_state_quantized->type = fw_input_weights->type; @@ -195,7 +211,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } node->temporaries->data[kBwHiddenStateQuantized] = - *scratch_tensor_index + kBwHiddenStateQuantized; + op_data->scratch_tensor_index + kBwHiddenStateQuantized; TfLiteTensor* bw_hidden_state_quantized = GetTemporary(context, node, kBwHiddenStateQuantized); bw_hidden_state_quantized->type = fw_input_weights->type; @@ -211,7 +227,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Allocate temporary tensors to store scaling factors of quantization. node->temporaries->data[kScalingFactors] = - *scratch_tensor_index + kScalingFactors; + op_data->scratch_tensor_index + kScalingFactors; TfLiteTensor* scaling_factors = GetTemporary(context, node, kScalingFactors); scaling_factors->type = kTfLiteFloat32; @@ -223,10 +239,66 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, scaling_factors_size)); } - + node->temporaries->data[kAccumScratch] = + op_data->scratch_tensor_index + kAccumScratch; + TfLiteTensor* accum_scratch = GetTemporary(context, node, kAccumScratch); + accum_scratch->type = kTfLiteInt32; + accum_scratch->allocation_type = kTfLiteArenaRw; + int accum_scratch_dims[2] = {std::max(fw_num_units, bw_num_units), + batch_size}; + if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2, + accum_scratch_dims)) { + TfLiteIntArray* accum_scratch_size = TfLiteIntArrayCreate(2); + accum_scratch_size->data[0] = accum_scratch_dims[0]; + accum_scratch_size->data[1] = accum_scratch_dims[1]; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, accum_scratch, + accum_scratch_size)); + } + node->temporaries->data[kZeroPoints] = + op_data->scratch_tensor_index + kZeroPoints; + TfLiteTensor* zero_points = + GetTemporary(context, node, /*index=*/kZeroPoints); + zero_points->type = kTfLiteInt32; + zero_points->allocation_type = kTfLiteArenaRw; + int zero_points_dims[1] = {batch_size}; + if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) { + TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1); + zero_points_size->data[0] = batch_size; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points, + zero_points_size)); + } + const int num_row_sums = has_aux_input ? 3 : 2; + node->temporaries->data[kFwRowSums] = + op_data->scratch_tensor_index + kFwRowSums; + TfLiteTensor* fw_row_sums = + GetTemporary(context, node, /*index=*/kFwRowSums); + fw_row_sums->type = kTfLiteInt32; + fw_row_sums->allocation_type = kTfLiteArenaRwPersistent; + int fw_row_sums_dims[2] = {num_row_sums, fw_num_units}; + if (!TfLiteIntArrayEqualsArray(fw_row_sums->dims, 2, fw_row_sums_dims)) { + TfLiteIntArray* fw_row_sums_size = TfLiteIntArrayCreate(2); + fw_row_sums_size->data[0] = fw_row_sums_dims[0]; + fw_row_sums_size->data[1] = fw_row_sums_dims[1]; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_row_sums, + fw_row_sums_size)); + } + node->temporaries->data[kBwRowSums] = + op_data->scratch_tensor_index + kBwRowSums; + TfLiteTensor* bw_row_sums = GetTemporary(context, node, + /*index=*/kBwRowSums); + bw_row_sums->type = kTfLiteInt32; + bw_row_sums->allocation_type = kTfLiteArenaRwPersistent; + int bw_row_sums_dims[2] = {num_row_sums, bw_num_units}; + if (!TfLiteIntArrayEqualsArray(bw_row_sums->dims, 2, bw_row_sums_dims)) { + TfLiteIntArray* bw_row_sums_size = TfLiteIntArrayCreate(2); + bw_row_sums_size->data[0] = bw_row_sums_dims[0]; + bw_row_sums_size->data[1] = bw_row_sums_dims[1]; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_row_sums, + bw_row_sums_size)); + } if (has_aux_input) { node->temporaries->data[kAuxInputQuantized] = - *scratch_tensor_index + kAuxInputQuantized; + op_data->scratch_tensor_index + kAuxInputQuantized; TfLiteTensor* aux_input_quantized = GetTemporary(context, node, kAuxInputQuantized); aux_input_quantized->type = fw_input_weights->type; @@ -418,7 +490,10 @@ TfLiteStatus EvalHybrid( TfLiteTensor* aux_input_quantized, TfLiteTensor* fw_hidden_state_quantized, TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output, TfLiteTensor* bw_hidden_state_quantized, TfLiteTensor* bw_hidden_state, - TfLiteTensor* bw_output) { + TfLiteTensor* bw_output, TfLiteTensor* zero_points, + TfLiteTensor* accum_scratch, TfLiteTensor* fw_row_sums, + TfLiteTensor* bw_row_sums, bool* fw_compute_row_sums, + bool* bw_compute_row_sums) { const bool time_major = params->time_major; const int batch_size = (time_major) ? input->dims->data[1] : input->dims->data[0]; @@ -464,11 +539,20 @@ TfLiteStatus EvalHybrid( int8_t* bw_quantized_hidden_state_ptr = GetTensorData(bw_hidden_state_quantized); float* scaling_factors_ptr = GetTensorData(scaling_factors); - + int32_t* accum_scratch_ptr = GetTensorData(accum_scratch); + int32_t* zero_points_ptr = nullptr; + int32_t* fw_row_sums_ptr = nullptr; + int32_t* bw_row_sums_ptr = nullptr; + if (params->asymmetric_quantize_inputs) { + zero_points_ptr = GetTensorData(zero_points); + fw_row_sums_ptr = GetTensorData(fw_row_sums); + bw_row_sums_ptr = GetTensorData(bw_row_sums); + } const int fw_output_step = params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units; const int bw_output_step = params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units; + if (time_major) { for (int t = 0; t < max_time; t++) { // Forward cell. @@ -491,7 +575,9 @@ TfLiteStatus EvalHybrid( fw_num_units, batch_size, fw_output_step, params->activation, quantized_input_ptr, aux_quantized_input_ptr, fw_quantized_hidden_state_ptr, scaling_factors_ptr, - fw_hidden_state_ptr_batch, output_ptr_batch); + fw_hidden_state_ptr_batch, output_ptr_batch, + params->asymmetric_quantize_inputs, zero_points_ptr, + accum_scratch_ptr, fw_row_sums_ptr, fw_compute_row_sums); } // Backward cell. float* bw_hidden_state_ptr_batch = GetTensorData(bw_hidden_state); @@ -516,7 +602,9 @@ TfLiteStatus EvalHybrid( bw_num_units, batch_size, bw_output_step, params->activation, quantized_input_ptr, aux_quantized_input_ptr, bw_quantized_hidden_state_ptr, scaling_factors_ptr, - bw_hidden_state_ptr_batch, output_ptr_batch); + bw_hidden_state_ptr_batch, output_ptr_batch, + params->asymmetric_quantize_inputs, zero_points_ptr, + accum_scratch_ptr, bw_row_sums_ptr, bw_compute_row_sums); } } } else { @@ -545,7 +633,9 @@ TfLiteStatus EvalHybrid( fw_num_units, /*batch_size=*/1, fw_output_step, params->activation, quantized_input_ptr, aux_quantized_input_ptr, fw_quantized_hidden_state_ptr, scaling_factors_ptr, - fw_hidden_state_ptr_batch, output_ptr_batch); + fw_hidden_state_ptr_batch, output_ptr_batch, + params->asymmetric_quantize_inputs, zero_points_ptr, + accum_scratch_ptr, fw_row_sums_ptr, fw_compute_row_sums); } // Backward cell. float* bw_hidden_state_ptr_batch = @@ -574,7 +664,9 @@ TfLiteStatus EvalHybrid( bw_num_units, /*batch_size=*/1, bw_output_step, params->activation, quantized_input_ptr, aux_quantized_input_ptr, bw_quantized_hidden_state_ptr, scaling_factors_ptr, - bw_hidden_state_ptr_batch, output_ptr_batch); + bw_hidden_state_ptr_batch, output_ptr_batch, + params->asymmetric_quantize_inputs, zero_points_ptr, + accum_scratch_ptr, bw_row_sums_ptr, bw_compute_row_sums); } } } @@ -656,17 +748,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetTemporary(context, node, kBwHiddenStateQuantized); TfLiteTensor* scaling_factors = GetTemporary(context, node, kScalingFactors); + TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints); + TfLiteTensor* accum_scratch = GetTemporary(context, node, kAccumScratch); + TfLiteTensor* fw_row_sums = GetTemporary(context, node, kFwRowSums); + TfLiteTensor* bw_row_sums = GetTemporary(context, node, kBwRowSums); TfLiteTensor* aux_input_quantized = use_aux_input ? GetTemporary(context, node, kAuxInputQuantized) : nullptr; - - return EvalHybrid(input, bw_input, fw_input_weights, fw_recurrent_weights, - fw_bias, bw_input_weights, bw_recurrent_weights, - bw_bias, real_aux_input, fw_aux_input_weights, - bw_aux_input_weights, params, scaling_factors, - input_quantized, aux_input_quantized, - fw_hidden_state_quantized, fw_hidden_state, fw_output, - bw_hidden_state_quantized, bw_hidden_state, bw_output); + auto* op_data = reinterpret_cast(node->user_data); + return EvalHybrid( + input, bw_input, fw_input_weights, fw_recurrent_weights, fw_bias, + bw_input_weights, bw_recurrent_weights, bw_bias, real_aux_input, + fw_aux_input_weights, bw_aux_input_weights, params, scaling_factors, + input_quantized, aux_input_quantized, fw_hidden_state_quantized, + fw_hidden_state, fw_output, bw_hidden_state_quantized, + bw_hidden_state, bw_output, zero_points, accum_scratch, fw_row_sums, + bw_row_sums, &op_data->fw_compute_row_sums, + &op_data->bw_compute_row_sums); } default: context->ReportError(context, "Type not currently supported."); diff --git a/tensorflow/lite/kernels/bidirectional_sequence_rnn_test.cc b/tensorflow/lite/kernels/bidirectional_sequence_rnn_test.cc index 34441e2b300..4a7cc9a016d 100644 --- a/tensorflow/lite/kernels/bidirectional_sequence_rnn_test.cc +++ b/tensorflow/lite/kernels/bidirectional_sequence_rnn_test.cc @@ -662,20 +662,24 @@ class BidirectionalRNNOpModel : public SingleOpModel { BidirectionalRNNOpModel(int batches, int sequence_len, int fw_units, int bw_units, int input_size, int aux_input_size, AuxInputMode aux_input_mode, bool time_major, - bool merge_outputs) + bool merge_outputs, bool quantize_weights = false, + bool asymmetric_quantize_weights = false) : batches_(batches), sequence_len_(sequence_len), fw_units_(fw_units), bw_units_(bw_units), input_size_(input_size), - aux_input_size_(aux_input_size) { + aux_input_size_(aux_input_size), + quantize_weights_(quantize_weights) { + const TensorType tensor_type = + quantize_weights ? TensorType_UINT8 : TensorType_FLOAT32; input_ = AddInput(TensorType_FLOAT32); - fw_weights_ = AddInput(TensorType_FLOAT32); - fw_recurrent_weights_ = AddInput(TensorType_FLOAT32); + fw_weights_ = AddInput(tensor_type); + fw_recurrent_weights_ = AddInput(tensor_type); fw_bias_ = AddInput(TensorType_FLOAT32); fw_hidden_state_ = AddInput(TensorType_FLOAT32, true); - bw_weights_ = AddInput(TensorType_FLOAT32); - bw_recurrent_weights_ = AddInput(TensorType_FLOAT32); + bw_weights_ = AddInput(tensor_type); + bw_recurrent_weights_ = AddInput(tensor_type); bw_bias_ = AddInput(TensorType_FLOAT32); bw_hidden_state_ = AddInput(TensorType_FLOAT32, true); @@ -697,8 +701,8 @@ class BidirectionalRNNOpModel : public SingleOpModel { } if (aux_input_mode == AuxInputMode::kCrossLinking) { - aux_fw_weights_ = AddInput(TensorType_FLOAT32); - aux_bw_weights_ = AddInput(TensorType_FLOAT32); + aux_fw_weights_ = AddInput(tensor_type); + aux_bw_weights_ = AddInput(tensor_type); aux_fw_weights_shape = {fw_units, aux_input_size_}; aux_bw_weights_shape = {bw_units, aux_input_size_}; @@ -712,12 +716,12 @@ class BidirectionalRNNOpModel : public SingleOpModel { bw_output_ = AddOutput(TensorType_FLOAT32); } - SetBuiltinOp( - BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN, - BuiltinOptions_BidirectionalSequenceRNNOptions, - CreateBidirectionalSequenceRNNOptions( - builder_, time_major, ActivationFunctionType_RELU, merge_outputs) - .Union()); + SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN, + BuiltinOptions_BidirectionalSequenceRNNOptions, + CreateBidirectionalSequenceRNNOptions( + builder_, time_major, ActivationFunctionType_RELU, + merge_outputs, asymmetric_quantize_weights) + .Union()); BuildInterpreter({ input_shape, // input @@ -744,19 +748,35 @@ class BidirectionalRNNOpModel : public SingleOpModel { } void SetFwWeights(const std::vector& f) { - PopulateTensor(fw_weights_, f); + if (quantize_weights_) { + SymmetricQuantizeAndPopulate(fw_weights_, f); + } else { + PopulateTensor(fw_weights_, f); + } } void SetBwWeights(const std::vector& f) { - PopulateTensor(bw_weights_, f); + if (quantize_weights_) { + SymmetricQuantizeAndPopulate(bw_weights_, f); + } else { + PopulateTensor(bw_weights_, f); + } } void SetFwRecurrentWeights(const std::vector& f) { - PopulateTensor(fw_recurrent_weights_, f); + if (quantize_weights_) { + SymmetricQuantizeAndPopulate(fw_recurrent_weights_, f); + } else { + PopulateTensor(fw_recurrent_weights_, f); + } } void SetBwRecurrentWeights(const std::vector& f) { - PopulateTensor(bw_recurrent_weights_, f); + if (quantize_weights_) { + SymmetricQuantizeAndPopulate(bw_recurrent_weights_, f); + } else { + PopulateTensor(bw_recurrent_weights_, f); + } } void SetInput(std::initializer_list data) { @@ -772,11 +792,19 @@ class BidirectionalRNNOpModel : public SingleOpModel { } void SetAuxFwWeights(const std::vector& f) { - PopulateTensor(aux_fw_weights_, f); + if (quantize_weights_) { + SymmetricQuantizeAndPopulate(aux_fw_weights_, f); + } else { + PopulateTensor(aux_fw_weights_, f); + } } void SetAuxBwWeights(const std::vector& f) { - PopulateTensor(aux_bw_weights_, f); + if (quantize_weights_) { + SymmetricQuantizeAndPopulate(aux_bw_weights_, f); + } else { + PopulateTensor(aux_bw_weights_, f); + } } std::vector GetFwOutput() { return ExtractVector(fw_output_); } @@ -811,17 +839,31 @@ class BidirectionalRNNOpModel : public SingleOpModel { int bw_units_; int input_size_; int aux_input_size_; + bool quantize_weights_; }; +// Declare LSTMOpTest as a parameterized test. +class BidirectionalRNNOpTest + : public ::testing::TestWithParam<::testing::tuple> {}; + +INSTANTIATE_TEST_SUITE_P(QuantizationOrNot, BidirectionalRNNOpTest, + ::testing::Combine( + /*quantize_weights*/ ::testing::Bool(), + /*asymmetric_quantize_inputs*/ ::testing::Bool())); + // TODO(mirkov): add another test which directly compares to TF once TOCO // supports the conversion from dynamic_rnn with BasicRNNCell. -TEST(BidirectionalRNNOpTest, BlackBoxTest) { +TEST_P(BidirectionalRNNOpTest, BlackBoxTest) { + auto params = GetParam(); + const bool quantize_weights = std::get<0>(params); + const bool asymmetric_quantize_inputs = std::get<1>(params); BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, /*fw_units=*/16, /*bw_units=*/16, /*input_size=*/8, /*aux_input_size=*/0, /*aux_input_mode=*/AuxInputMode::kNoAuxInput, /*time_major=*/false, - /*merge_outputs=*/false); + /*merge_outputs=*/false, quantize_weights, + asymmetric_quantize_inputs); rnn.SetFwWeights(weights); rnn.SetBwWeights(weights); rnn.SetFwBias(biases); @@ -843,7 +885,9 @@ TEST(BidirectionalRNNOpTest, BlackBoxTest) { std::vector fw_expected; fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end); fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end); - EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected))); + EXPECT_THAT(rnn.GetFwOutput(), + ElementsAreArray(ArrayFloatNear( + fw_expected, quantize_weights ? 1.42e-2 : 1e-5))); float* golden_bw_start = rnn_golden_bw_output; float* golden_bw_end = @@ -851,17 +895,23 @@ TEST(BidirectionalRNNOpTest, BlackBoxTest) { std::vector bw_expected; bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end); bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end); - EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected))); + EXPECT_THAT(rnn.GetBwOutput(), + ElementsAreArray(ArrayFloatNear( + bw_expected, quantize_weights ? 1.42e-2 : 1e-5))); } // Same as BlackBox test, but input is reshuffled to time_major format. -TEST(BidirectionalRNNOpTest, BlackBoxTestTimeMajor) { +TEST_P(BidirectionalRNNOpTest, BlackBoxTestTimeMajor) { + auto params = GetParam(); + const bool quantize_weights = std::get<0>(params); + const bool asymmetric_quantize_inputs = std::get<1>(params); BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, /*fw_units=*/16, /*bw_units=*/16, /*input_size=*/8, /*aux_input_size=*/0, /*aux_input_mode=*/AuxInputMode::kNoAuxInput, /*time_major=*/true, - /*merge_outputs=*/false); + /*merge_outputs=*/false, quantize_weights, + asymmetric_quantize_inputs); rnn.SetFwWeights(weights); rnn.SetBwWeights(weights); rnn.SetFwBias(biases); @@ -889,17 +939,26 @@ TEST(BidirectionalRNNOpTest, BlackBoxTestTimeMajor) { fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end); fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end); } - EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected))); + constexpr float kHybridTolerance = 3.57e-1; + constexpr float kFloatTolerance = 1e-5; + EXPECT_THAT( + rnn.GetFwOutput(), + ElementsAreArray(ArrayFloatNear( + fw_expected, quantize_weights ? kHybridTolerance : kFloatTolerance))); } // Same as BlackBox test, yet with merged outputs. -TEST(BidirectionalRNNOpTest, BlackBoxTestMergeOutputs) { +TEST_P(BidirectionalRNNOpTest, BlackBoxTestMergeOutputs) { + auto params = GetParam(); + const bool quantize_weights = std::get<0>(params); + const bool asymmetric_quantize_inputs = std::get<1>(params); BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, /*fw_units=*/16, /*bw_units=*/16, /*input_size=*/8, /*aux_input_size=*/0, /*aux_input_mode=*/AuxInputMode::kNoAuxInput, /*time_major=*/false, - /*merge_outputs=*/true); + /*merge_outputs=*/true, quantize_weights, + asymmetric_quantize_inputs); rnn.SetFwWeights(weights); rnn.SetBwWeights(weights); rnn.SetFwBias(biases); @@ -929,7 +988,8 @@ TEST(BidirectionalRNNOpTest, BlackBoxTestMergeOutputs) { } } EXPECT_THAT(rnn.GetFwOutput(), - ElementsAreArray(ArrayFloatNear(merged_expected))); + ElementsAreArray(ArrayFloatNear( + merged_expected, quantize_weights ? 1.42e-2 : 1e-5))); } // Same as BlackBox test, but input is reshuffled to time_major format. diff --git a/tensorflow/lite/kernels/conv_test.cc b/tensorflow/lite/kernels/conv_test.cc index 1f609685dd9..10e014d0e21 100644 --- a/tensorflow/lite/kernels/conv_test.cc +++ b/tensorflow/lite/kernels/conv_test.cc @@ -1344,11 +1344,6 @@ class PerChannelQuantizedConvolutionOpModel : public BaseConvolutionOpModel { }; TEST_P(ConvolutionOpTest, SimplePerTensorTest) { - // TODO(b/138722124): Enable these tests on NNAPI. - if (SingleOpModel::GetForceUseNnapi()) { - return; - } - PerChannelQuantizedConvolutionOpModel m( GetRegistration(), {TensorType_INT8, {1, 2, 3, 2}, -63.5, 64, 0.5, -1}, {TensorType_INT8, diff --git a/tensorflow/lite/kernels/cpu_backend_context.cc b/tensorflow/lite/kernels/cpu_backend_context.cc index dfeea5d0a64..51284214ee4 100644 --- a/tensorflow/lite/kernels/cpu_backend_context.cc +++ b/tensorflow/lite/kernels/cpu_backend_context.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "public/gemmlowp.h" -#include "tensorflow/lite/experimental/ruy/ruy/context.h" +#include "ruy/context.h" // from @ruy #include "tensorflow/lite/kernels/op_macros.h" namespace { diff --git a/tensorflow/lite/kernels/cpu_backend_context.h b/tensorflow/lite/kernels/cpu_backend_context.h index eafae75fc47..46abcd5e90f 100644 --- a/tensorflow/lite/kernels/cpu_backend_context.h +++ b/tensorflow/lite/kernels/cpu_backend_context.h @@ -19,7 +19,7 @@ limitations under the License. #include #include "public/gemmlowp.h" -#include "tensorflow/lite/experimental/ruy/ruy/context.h" +#include "ruy/context.h" // from @ruy #include "tensorflow/lite/external_cpu_backend_context.h" namespace tflite { diff --git a/tensorflow/lite/kernels/cpu_backend_gemm_custom_gemv.h b/tensorflow/lite/kernels/cpu_backend_gemm_custom_gemv.h index 6fde100a4bf..f85a1715af2 100644 --- a/tensorflow/lite/kernels/cpu_backend_gemm_custom_gemv.h +++ b/tensorflow/lite/kernels/cpu_backend_gemm_custom_gemv.h @@ -35,7 +35,7 @@ limitations under the License. #include #include -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h" #include "tensorflow/lite/kernels/cpu_backend_threadpool.h" diff --git a/tensorflow/lite/kernels/cpu_backend_gemm_gemmlowp.h b/tensorflow/lite/kernels/cpu_backend_gemm_gemmlowp.h index 253c035688f..ad9bbb75ae5 100644 --- a/tensorflow/lite/kernels/cpu_backend_gemm_gemmlowp.h +++ b/tensorflow/lite/kernels/cpu_backend_gemm_gemmlowp.h @@ -22,7 +22,7 @@ limitations under the License. #include #include "public/gemmlowp.h" -#include "tensorflow/lite/experimental/ruy/ruy/ruy.h" +#include "ruy/ruy.h" // from @ruy #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h" #include "tensorflow/lite/kernels/cpu_backend_gemm_ruy.h" diff --git a/tensorflow/lite/kernels/cpu_backend_gemm_ruy.h b/tensorflow/lite/kernels/cpu_backend_gemm_ruy.h index c02dce2b773..d038c03ac04 100644 --- a/tensorflow/lite/kernels/cpu_backend_gemm_ruy.h +++ b/tensorflow/lite/kernels/cpu_backend_gemm_ruy.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_ #define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_ -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/ruy.h" +#include "ruy/path.h" // from @ruy +#include "ruy/ruy.h" // from @ruy #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h" diff --git a/tensorflow/lite/kernels/cpu_backend_gemm_test.cc b/tensorflow/lite/kernels/cpu_backend_gemm_test.cc index d26df809c97..75181a979eb 100644 --- a/tensorflow/lite/kernels/cpu_backend_gemm_test.cc +++ b/tensorflow/lite/kernels/cpu_backend_gemm_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include #include -#include "tensorflow/lite/experimental/ruy/ruy/ruy.h" +#include "ruy/ruy.h" // from @ruy #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h" diff --git a/tensorflow/lite/kernels/cpu_backend_threadpool.h b/tensorflow/lite/kernels/cpu_backend_threadpool.h index b924826a07c..ff03d372d5e 100644 --- a/tensorflow/lite/kernels/cpu_backend_threadpool.h +++ b/tensorflow/lite/kernels/cpu_backend_threadpool.h @@ -20,8 +20,8 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/compatibility.h" #ifdef TFLITE_WITH_RUY -#include "tensorflow/lite/experimental/ruy/ruy/context.h" -#include "tensorflow/lite/experimental/ruy/ruy/thread_pool.h" +#include "ruy/context.h" // from @ruy +#include "ruy/thread_pool.h" // from @ruy #else #include "public/gemmlowp.h" #endif diff --git a/tensorflow/lite/kernels/depthwise_conv_test.cc b/tensorflow/lite/kernels/depthwise_conv_test.cc index 344d156545d..5d85eac4aa9 100644 --- a/tensorflow/lite/kernels/depthwise_conv_test.cc +++ b/tensorflow/lite/kernels/depthwise_conv_test.cc @@ -1621,10 +1621,6 @@ class PerChannelQuantizedDepthwiseConvolutionOpTest : public SingleOpTest { }; TEST_P(PerChannelQuantizedDepthwiseConvolutionOpTest, SimplePerTensorTest) { - // TODO(b/138722124): Enable these tests on NNAPI. - if (SingleOpModel::GetForceUseNnapi()) { - return; - } PerChannelQuantizedDepthwiseConvolutionOpModel m( GetRegistration(), {TensorType_INT8, {1, 2, 3, 2}, -63.5, 64, 0.5, -1}, {TensorType_INT8, diff --git a/tensorflow/lite/kernels/div.cc b/tensorflow/lite/kernels/div.cc index 21480884e94..731fb3c2fe2 100644 --- a/tensorflow/lite/kernels/div.cc +++ b/tensorflow/lite/kernels/div.cc @@ -115,13 +115,13 @@ void EvalDiv(TfLiteContext* context, TfLiteNode* node, TfLiteDivParams* params, if (output->type == kTfLiteInt32) { if (kernel_type == kReference) { if (data->requires_broadcast) { - TF_LITE_DIV(reference_ops, BroadcastDiv4DSlow, int32_t); + TF_LITE_DIV(reference_ops, BroadcastDivSlow, int32_t); } else { TF_LITE_DIV(reference_ops, Div, int32_t); } } else { if (data->requires_broadcast) { - TF_LITE_DIV(optimized_ops, BroadcastDiv4DSlow, int32_t); + TF_LITE_DIV(optimized_ops, BroadcastDivSlow, int32_t); } else { TF_LITE_DIV(optimized_ops, Div, int32_t); } @@ -129,13 +129,13 @@ void EvalDiv(TfLiteContext* context, TfLiteNode* node, TfLiteDivParams* params, } else if (output->type == kTfLiteFloat32) { if (kernel_type == kReference) { if (data->requires_broadcast) { - TF_LITE_DIV(reference_ops, BroadcastDiv4DSlow, float); + TF_LITE_DIV(reference_ops, BroadcastDivSlow, float); } else { TF_LITE_DIV(reference_ops, Div, float); } } else { if (data->requires_broadcast) { - TF_LITE_DIV(optimized_ops, BroadcastDiv4DSlow, float); + TF_LITE_DIV(optimized_ops, BroadcastDivSlow, float); } else { TF_LITE_DIV(optimized_ops, Div, float); } @@ -168,13 +168,13 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, GetTensorData(output)) if (kernel_type == kReference) { if (need_broadcast) { - TF_LITE_DIV(reference_ops, BroadcastDiv4DSlow, uint8_t); + TF_LITE_DIV(reference_ops, BroadcastDivSlow, uint8_t); } else { TF_LITE_DIV(reference_ops, Div, uint8_t); } } else { if (need_broadcast) { - TF_LITE_DIV(optimized_ops, BroadcastDiv4DSlow, uint8_t); + TF_LITE_DIV(optimized_ops, BroadcastDivSlow, uint8_t); } else { TF_LITE_DIV(optimized_ops, Div, uint8_t); } diff --git a/tensorflow/lite/kernels/div_test.cc b/tensorflow/lite/kernels/div_test.cc index c3ba35f61c7..e72565f84a0 100644 --- a/tensorflow/lite/kernels/div_test.cc +++ b/tensorflow/lite/kernels/div_test.cc @@ -119,17 +119,35 @@ TEST(FloatDivOpTest, VariousInputShapes) { TEST(FloatDivOpTest, WithBroadcast) { std::vector> test_shapes = { - {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + {8}, {2, 4}, {2, 1, 4}, {1, 2, 2, 2}}; for (int i = 0; i < test_shapes.size(); ++i) { FloatDivOpModel m({TensorType_FLOAT32, test_shapes[i]}, {TensorType_FLOAT32, {}}, // always a scalar {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); - m.PopulateTensor(m.input1(), {-0.2, 0.2, 0.07, 0.08, 0.11, -0.123}); + m.PopulateTensor(m.input1(), + {-0.2, 0.2, 0.07, 0.08, 0.11, -0.123, -0.32, 0.54}); m.PopulateTensor(m.input2(), {0.1}); m.Invoke(); - EXPECT_THAT( - m.GetOutput(), - ElementsAreArray(ArrayFloatNear({-2.0, 2.0, 0.7, 0.8, 1.1, -1.23}))) + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {-2.0, 2.0, 0.7, 0.8, 1.1, -1.23, -3.2, 5.4}))) + << "With shape number " << i; + } +} + +TEST(FloatDivOpTest, WithBroadcast5D) { + std::vector> test_shapes = {{1, 2, 1, 2, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + FloatDivOpModel m({TensorType_FLOAT32, test_shapes[i]}, + {TensorType_FLOAT32, {}}, // always a scalar + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), + {-0.2, 0.2, 0.07, 0.08, 0.11, -0.123, -0.32, 0.54}); + m.PopulateTensor(m.input2(), {0.1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {-2.0, 2.0, 0.7, 0.8, 1.1, -1.23, -3.2, 5.4}))) << "With shape number " << i; } } @@ -171,15 +189,16 @@ TEST(IntegerDivOpTest, VariousInputShapes) { TEST(IntegerDivOpTest, WithBroadcast) { std::vector> test_shapes = { - {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + {8}, {2, 4}, {2, 1, 4}, {1, 4, 1, 2}, {1, 2, 1, 2, 2}}; for (int i = 0; i < test_shapes.size(); ++i) { IntegerDivOpModel m({TensorType_INT32, test_shapes[i]}, {TensorType_INT32, {}}, // always a scalar {TensorType_INT32, {}}, ActivationFunctionType_NONE); - m.PopulateTensor(m.input1(), {-20, 21, 7, 8, 11, -123}); + m.PopulateTensor(m.input1(), {-20, 21, 7, 8, 11, -123, -42, -48}); m.PopulateTensor(m.input2(), {3}); m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({-6, 7, 2, 2, 3, -41})) + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({-6, 7, 2, 2, 3, -41, -14, -16})) << "With shape number " << i; } } @@ -262,19 +281,19 @@ template void QuantizedWithBroadcast() { const float kQuantizedTolerance = GetTolerance(-3.0, 3.0); const std::vector> test_shapes = { - {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + {8}, {2, 4}, {2, 1, 4}, {1, 4, 1, 2}, {1, 2, 1, 2, 2}}; for (int i = 0; i < test_shapes.size(); ++i) { QuantizedDivOpModel m( {tensor_type, test_shapes[i], -3.0, 3.0}, {tensor_type, {}, -3.0, 3.0}, {tensor_type, {}, -3.0, 3.0}, ActivationFunctionType_NONE); - m.QuantizeAndPopulate(m.input1(), - {-2.0, 0.2, 0.7, 0.8, -0.5, 1.1}); + m.QuantizeAndPopulate( + m.input1(), {-2.0, 0.2, 0.7, 0.8, -0.5, 1.1, -1.3, 1.2}); m.QuantizeAndPopulate(m.input2(), {0.7}); m.Invoke(); - EXPECT_THAT( - m.GetDequantizedOutput(), - ElementsAreArray(ArrayFloatNear( - {-2.857, 0.286, 1.0, 1.143, -0.714, 1.571}, kQuantizedTolerance))) + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + {-2.857, 0.286, 1.0, 1.143, -0.714, 1.571, -1.857, 1.714}, + kQuantizedTolerance))) << "With shape number " << i; } } diff --git a/tensorflow/lite/kernels/fully_connected.cc b/tensorflow/lite/kernels/fully_connected.cc index fc6f1991fd3..5faf13303d8 100644 --- a/tensorflow/lite/kernels/fully_connected.cc +++ b/tensorflow/lite/kernels/fully_connected.cc @@ -71,6 +71,7 @@ struct OpData { int32_t output_activation_max; // The index of the temporary tensor where the quantized inputs are cached. int scratch_tensor_index; + bool compute_row_sums = false; }; constexpr int kInputTensor = 0; @@ -131,7 +132,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) { // Instead, we allocate a new object to carry information from Prepare() to // Eval(). auto* op_data = new OpData(); - context->AddTensors(context, /*tensors_to_add=*/3, + context->AddTensors(context, /*tensors_to_add=*/5, &op_data->scratch_tensor_index); return op_data; } @@ -144,7 +145,6 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); OpData* data = reinterpret_cast(node->user_data); - // Check we have all the inputs and outputs we need. TF_LITE_ENSURE(context, node->inputs->size == 2 || node->inputs->size == 3); // Shuffled formats need a workspace to store the shuffled input activations. @@ -208,7 +208,8 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) { if (input->type == kTfLiteFloat32 && (filter->type == kTfLiteUInt8 || filter->type == kTfLiteInt8)) { TfLiteIntArrayFree(node->temporaries); - node->temporaries = TfLiteIntArrayCreate(3); + data->compute_row_sums = true; + node->temporaries = TfLiteIntArrayCreate(5); node->temporaries->data[0] = data->scratch_tensor_index; TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0); @@ -245,6 +246,28 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK( context, context->ResizeTensor(context, accum_scratch, accum_size)); } + + node->temporaries->data[3] = data->scratch_tensor_index + 3; + TfLiteTensor* input_offsets = GetTemporary(context, node, /*index=*/3); + input_offsets->type = kTfLiteInt32; + input_offsets->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqualsArray(input_offsets->dims, 1, scaling_dims)) { + TfLiteIntArray* input_offsets_size = TfLiteIntArrayCreate(1); + input_offsets_size->data[0] = batch_size; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_offsets, + input_offsets_size)); + } + node->temporaries->data[4] = data->scratch_tensor_index + 4; + TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/4); + row_sums->type = kTfLiteInt32; + row_sums->allocation_type = kTfLiteArenaRwPersistent; + int row_sums_dims[1] = {num_units}; + if (!TfLiteIntArrayEqualsArray(row_sums->dims, 1, row_sums_dims)) { + TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(1); + row_sums_size->data[0] = row_sums_dims[0]; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, row_sums, row_sums_size)); + } } // Resize output. @@ -332,7 +355,9 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, TfLiteFullyConnectedParams* params, OpData* data, const TfLiteTensor* input, const TfLiteTensor* filter, const TfLiteTensor* bias, TfLiteTensor* input_quantized, - TfLiteTensor* scaling_factors, TfLiteTensor* output) { + TfLiteTensor* scaling_factors, + TfLiteTensor* accum_scratch, TfLiteTensor* row_sums, + TfLiteTensor* input_offsets, TfLiteTensor* output) { int total_input_size = 1; for (int i = 0; i < input->dims->size; i++) { total_input_size *= input->dims->data[i]; @@ -363,32 +388,39 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, // Quantize input from float to uint8 + quantization params (scaling factor). float unused_min, unused_max; float* scaling_factors_ptr = GetTensorData(scaling_factors); + int32_t* input_offset_ptr = nullptr; + int32_t* row_sums_ptr = nullptr; + if (params->asymmetric_quantize_inputs) { + input_offset_ptr = GetTensorData(input_offsets); + row_sums_ptr = GetTensorData(row_sums); + } int8_t* quant_data = GetTensorData(input_quantized); const int8_t* filter_data = GetTensorData(filter); - + const float* input_ptr = GetTensorData(input); // Quantize each batch independently. for (int b = 0; b < batch_size; ++b) { const int offset = b * input_size; - tensor_utils::SymmetricQuantizeFloats( - GetTensorData(input) + offset, input_size, quant_data + offset, - &unused_min, &unused_max, &scaling_factors_ptr[b]); + if (params->asymmetric_quantize_inputs) { + tensor_utils::AsymmetricQuantizeFloats( + input_ptr + offset, input_size, quant_data + offset, + &scaling_factors_ptr[b], &input_offset_ptr[b]); + } else { + tensor_utils::SymmetricQuantizeFloats( + input_ptr + offset, input_size, quant_data + offset, &unused_min, + &unused_max, &scaling_factors_ptr[b]); + } // Incorporate scaling of the filter. scaling_factors_ptr[b] *= filter->params.scale; } // Compute output += weight * quantized_input -#ifdef TFLITE_WITH_RUY_GEMV - TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/2); int32_t* scratch = GetTensorData(accum_scratch); tensor_utils::MatrixBatchVectorMultiplyAccumulate( filter_data, num_units, input_size, quant_data, scaling_factors_ptr, - batch_size, scratch, GetTensorData(output), + batch_size, GetTensorData(output), /*per_channel_scale=*/nullptr, + input_offset_ptr, scratch, row_sums_ptr, &data->compute_row_sums, CpuBackendContext::GetFromContext(context)); -#else - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - filter_data, num_units, input_size, quant_data, scaling_factors_ptr, - batch_size, GetTensorData(output)); -#endif + // Apply activation function to floats. tensor_utils::ApplyActivationToVector( GetTensorData(output), batch_size * num_units, params->activation, @@ -461,8 +493,12 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, if (input->type == kTfLiteFloat32) { TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0); TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/1); + TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/2); + TfLiteTensor* input_offsets = GetTemporary(context, node, /*index=*/3); + TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/4); return EvalHybrid(context, node, params, data, input, filter, bias, - input_quantized, scaling_factors, output); + input_quantized, scaling_factors, accum_scratch, row_sums, + input_offsets, output); } else { FullyConnectedParams op_params; op_params.input_offset = input_offset; @@ -590,7 +626,6 @@ TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node, FullyConnectedParams op_params; op_params.float_activation_min = output_activation_min; op_params.float_activation_max = output_activation_max; - reference_ops::FullyConnected( op_params, GetTensorShape(input), GetTensorData(input), GetTensorShape(filter), GetTensorData(filter), diff --git a/tensorflow/lite/kernels/fully_connected_test.cc b/tensorflow/lite/kernels/fully_connected_test.cc index 1f671cae0fc..fbc02dd741d 100644 --- a/tensorflow/lite/kernels/fully_connected_test.cc +++ b/tensorflow/lite/kernels/fully_connected_test.cc @@ -286,7 +286,8 @@ class HybridFullyConnectedOpModel : public SingleOpModel { public: HybridFullyConnectedOpModel(int units, int batches, const TensorData& input, const TensorData& weights, - const TensorData& output = {TensorType_FLOAT32}) + const TensorData& output = {TensorType_FLOAT32}, + bool asymmetric_inputs = false) : batches_(batches), units_(units) { int total_input_size = 1; for (size_t i = 0; i < input.shape.size(); ++i) { @@ -302,10 +303,13 @@ class HybridFullyConnectedOpModel : public SingleOpModel { output_ = AddOutput(output); - SetBuiltinOp( - BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions, - CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU) - .Union()); + auto options = CreateFullyConnectedOptions( + builder_, ActivationFunctionType_RELU, + tflite::FullyConnectedOptionsWeightsFormat_DEFAULT, + false, asymmetric_inputs) + .Union(); + SetBuiltinOp(BuiltinOperator_FULLY_CONNECTED, + BuiltinOptions_FullyConnectedOptions, options); resolver_ = absl::make_unique( BuiltinOperator_FULLY_CONNECTED, ops::builtin::Register_FULLY_CONNECTED_PIE()); @@ -867,6 +871,66 @@ TEST(HybridFullyConnectedOpTest, SimpleTestQuantizedInt8) { /*max_abs_error=*/1.3f))); } +TEST(HybridAsymmetricInputFullyConnectedOpTest, SimpleTestQuantizedUint8) { + HybridFullyConnectedOpModel m( + /*units=*/3, /*batches=*/2, + /*input=*/{TensorType_FLOAT32, {2, 10}}, + /*weights=*/ + {TensorType_UINT8, {3, 10}, 0, 0, 10.0 / 127.0, 0}, {TensorType_FLOAT32}, + /*asymmetric_quantize_input*/ true); // Hybrid asymmetric + + m.SetWeights({ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2 + }); + m.SetBias({1, 2, 3}); + + m.SetInput({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( + { + 24, 25, 26, // + 58, 59, 60, // + }, + /*max_abs_error=*/0.64f))); +} + +TEST(HybridAsymmetricInputFullyConnectedOpTest, SimpleTestQuantizedInt8) { + HybridFullyConnectedOpModel m( + /*units=*/3, /*batches=*/2, + /*input=*/{TensorType_FLOAT32, {2, 10}}, + /*weights=*/{TensorType_INT8, {3, 10}, 0, 0, 10.0 / 127.0, 0}, + {TensorType_FLOAT32}, + /*asymmetric_quantize_input*/ true); + + m.SetSignedWeights({ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2 + }); + m.SetBias({1, 2, 3}); + + m.SetInput({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( + { + 24, 25, 26, // + 58, 59, 60, // + }, + /*max_abs_error=*/1.3f))); +} + TEST_P(FloatFullyConnectedOpTest, SimpleTest4DInput) { // Note that it is not required that the first dimension be the number of // batches. All we care is that the input can be evenly distributed in diff --git a/tensorflow/lite/kernels/gather_nd.cc b/tensorflow/lite/kernels/gather_nd.cc index 7332d6dfd47..4ca0864b94f 100644 --- a/tensorflow/lite/kernels/gather_nd.cc +++ b/tensorflow/lite/kernels/gather_nd.cc @@ -41,6 +41,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt8: case kTfLiteInt64: case kTfLiteInt32: + case kTfLiteString: break; default: context->ReportError( @@ -103,6 +104,15 @@ TfLiteStatus GatherNd(const TfLiteTensor* params, const TfLiteTensor* indices, return kTfLiteOk; } +template +TfLiteStatus GatherNdString(const TfLiteTensor* params, + const TfLiteTensor* indices, TfLiteTensor* output) { + reference_ops::GatherNdString( + GetTensorShape(params), params, GetTensorShape(indices), + GetTensorData(indices), GetTensorShape(output), output); + return kTfLiteOk; +} + template TfLiteStatus EvalGatherNd(TfLiteContext* context, const TfLiteTensor* params, const TfLiteTensor* indices, TfLiteTensor* output) { @@ -117,6 +127,8 @@ TfLiteStatus EvalGatherNd(TfLiteContext* context, const TfLiteTensor* params, return GatherNd(params, indices, output); case kTfLiteInt64: return GatherNd(params, indices, output); + case kTfLiteString: + return GatherNdString(params, indices, output); default: context->ReportError(context, "Params type '%s' are not supported by gather_nd.", diff --git a/tensorflow/lite/kernels/gather_nd_test.cc b/tensorflow/lite/kernels/gather_nd_test.cc index f90f7b64735..7e2714dac5e 100644 --- a/tensorflow/lite/kernels/gather_nd_test.cc +++ b/tensorflow/lite/kernels/gather_nd_test.cc @@ -313,5 +313,28 @@ TEST(GatherNdOpTest, Int64Int64) { ElementsAreArray({-2LL, 2LL, 2LL, 3LL, 3LL, -3LL})); } +TEST(GatherNdOpTest, StringInt32) { + GatherNdOpModel m({TensorType_STRING, {3, 2, 3}}, {TensorType_INT32, {2, 2}}); + m.SetInput({"A", "B", "C", "D", "E", "F", // + "G", "H", "I", "J", "K", "L", // + "M", "N", "O", "P", "Q", "R"}); + m.SetPositions({0, 1, 1, 0}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({"D", "E", "F", "G", "H", "I"})); +} + +TEST(GatherNdOpTest, StringInt64) { + GatherNdOpModel m({TensorType_STRING, {3, 2, 3}}, {TensorType_INT64, {2, 2}}); + m.SetInput({"A", "B", "C", "D", "E", "F", // + "G", "H", "I", "J", "K", "L", // + "M", "N", "O", "P", "Q", "R"}); + m.SetPositions({0LL, 1LL, 1LL, 0LL}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({"D", "E", "F", "G", "H", "I"})); +} } // namespace } // namespace tflite diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD index 952073ef02a..373fffd8c24 100644 --- a/tensorflow/lite/kernels/internal/BUILD +++ b/tensorflow/lite/kernels/internal/BUILD @@ -249,7 +249,7 @@ cc_library( ":transpose_utils", "//third_party/eigen3", "@gemmlowp//:fixedpoint", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", + "@ruy//ruy/profiler:instrumentation", "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:cpu_backend_context", "//tensorflow/lite/kernels:cpu_backend_threadpool", @@ -301,7 +301,7 @@ cc_library( "//tensorflow/lite/kernels:cpu_backend_context", "//tensorflow/lite/kernels:cpu_backend_threadpool", "//tensorflow/lite/kernels:cpu_backend_gemm", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", + "@ruy//ruy/profiler:instrumentation", ] + select({ ":haswell": tflite_deps_intel, ":ios_x86_64": tflite_deps_intel, @@ -477,7 +477,7 @@ cc_library( "//third_party/eigen3", "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:op_macros", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", + "@ruy//ruy/profiler:instrumentation", "//tensorflow/lite/tools/optimize/sparsity:format_converter", ] + select({ ":haswell": tflite_deps_intel, @@ -542,7 +542,7 @@ cc_library( "@gemmlowp", "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:op_macros", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", + "@ruy//ruy/profiler:instrumentation", "//tensorflow/lite/tools/optimize/sparsity:format_converter", ] + select({ ":haswell": tflite_deps_intel, @@ -626,10 +626,10 @@ cc_library( ":cpu_check", ":portable_tensor_utils", "//tensorflow/lite/c:common", - "//tensorflow/lite/experimental/ruy/ruy", - "//tensorflow/lite/experimental/ruy/ruy:detect_arm", "//tensorflow/lite/kernels:cpu_backend_context", "//tensorflow/lite/kernels:cpu_backend_gemm", + "@ruy//ruy", + "@ruy//ruy:detect_arm", ], ) @@ -822,10 +822,10 @@ cc_test( ":reference_base", ":test_util", ":types", - "//tensorflow/lite/experimental/ruy/ruy:context", "//tensorflow/lite/kernels:cpu_backend_context", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", + "@ruy//ruy:context", ], ) diff --git a/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc b/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc index 4f8ceb33595..1f2f6d57b9a 100644 --- a/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc +++ b/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include #include -#include "tensorflow/lite/experimental/ruy/ruy/context.h" +#include "ruy/context.h" // from @ruy #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" #include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_3x3_filter_common.h" diff --git a/tensorflow/lite/kernels/internal/kernel_utils.cc b/tensorflow/lite/kernels/internal/kernel_utils.cc index 21c058c394b..f34cee02f4d 100644 --- a/tensorflow/lite/kernels/internal/kernel_utils.cc +++ b/tensorflow/lite/kernels/internal/kernel_utils.cc @@ -123,7 +123,9 @@ void RnnBatchStep( int num_units, int batch_size, int output_batch_leading_dim, TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch, int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors, - float* hidden_state_ptr_batch, float* output_ptr_batch) { + float* hidden_state_ptr_batch, float* output_ptr_batch, + bool asymmetric_quantize_inputs, int32_t* zero_points, + int32_t* accum_scratch, int32_t* row_sums, bool* compute_row_sums) { RnnBatchStep(input_ptr_batch, input_weights_ptr, input_weights_scale, /*aux_input_ptr_batch=*/nullptr, /*aux_input_weights_ptr=*/nullptr, @@ -133,7 +135,29 @@ void RnnBatchStep( output_batch_leading_dim, activation, quantized_input_ptr_batch, /*aux_quantized_input_ptr_batch=*/nullptr, quantized_hidden_state_ptr_batch, scaling_factors, - hidden_state_ptr_batch, output_ptr_batch); + hidden_state_ptr_batch, output_ptr_batch, + asymmetric_quantize_inputs, zero_points, accum_scratch, row_sums, + compute_row_sums); +} + +void ComputeMatrixSums(int32_t* input_row_sums, int32_t* aux_input_row_sums, + int32_t* recurrent_row_sums, int32_t* row_sums, + const float* aux_input_ptr_batch, int num_units, + int input_size, int aux_input_size, + const int8_t* input_weights_ptr, + const int8_t* aux_input_weights_ptr, + const int8_t* recurrent_weights_ptr) { + memset(input_row_sums, 0, sizeof(int32_t) * num_units); + tensor_utils::ReductionSumVector(input_weights_ptr, input_row_sums, num_units, + input_size); + if (aux_input_ptr_batch) { + memset(aux_input_row_sums, 0, sizeof(int32_t) * num_units); + tensor_utils::ReductionSumVector(aux_input_weights_ptr, aux_input_row_sums, + num_units, aux_input_size); + } + memset(recurrent_row_sums, 0, sizeof(int32_t) * num_units); + tensor_utils::ReductionSumVector(recurrent_weights_ptr, recurrent_row_sums, + num_units, num_units); } void RnnBatchStep( @@ -146,9 +170,31 @@ void RnnBatchStep( TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch, int8_t* aux_quantized_input_ptr_batch, int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors, - float* hidden_state_ptr_batch, float* output_ptr_batch) { + float* hidden_state_ptr_batch, float* output_ptr_batch, + bool asymmetric_quantize_inputs, int32_t* zero_points, + int32_t* accum_scratch, int32_t* row_sums, bool* compute_row_sums) { // Since the output batch rows may not be contiguous (output_batch_leading_dim // != n_output), we unroll the batched operations where this is the case. + + int32_t* input_row_sums = nullptr; + int32_t* aux_input_row_sums = nullptr; + int32_t* recurrent_row_sums = nullptr; + if (asymmetric_quantize_inputs) { + input_row_sums = row_sums; + aux_input_row_sums = row_sums; + if (aux_input_ptr_batch) { + aux_input_row_sums += num_units; + } + recurrent_row_sums = aux_input_row_sums + num_units; + if (*compute_row_sums) { + ComputeMatrixSums(input_row_sums, aux_input_row_sums, recurrent_row_sums, + row_sums, aux_input_ptr_batch, num_units, input_size, + aux_input_size, input_weights_ptr, + aux_input_weights_ptr, recurrent_weights_ptr); + *compute_row_sums = false; + } + } + if (output_batch_leading_dim == num_units) { // Output = bias tensor_utils::VectorBatchVectorAssign(bias_ptr, num_units, batch_size, @@ -163,17 +209,25 @@ void RnnBatchStep( // whichever is faster. for (int b = 0; b < batch_size; ++b) { const int offset = b * input_size; - tensor_utils::SymmetricQuantizeFloats( - input_ptr_batch + offset, input_size, - quantized_input_ptr_batch + offset, &unused_min, &unused_max, - &scaling_factors[b]); + if (asymmetric_quantize_inputs) { + tensor_utils::AsymmetricQuantizeFloats( + input_ptr_batch + offset, input_size, + quantized_input_ptr_batch + offset, &scaling_factors[b], + &zero_points[b]); + } else { + tensor_utils::SymmetricQuantizeFloats( + input_ptr_batch + offset, input_size, + quantized_input_ptr_batch + offset, &unused_min, &unused_max, + &scaling_factors[b]); + } scaling_factors[b] *= input_weights_scale; } - // Output += input * input_weights tensor_utils::MatrixBatchVectorMultiplyAccumulate( input_weights_ptr, num_units, input_size, quantized_input_ptr_batch, - scaling_factors, batch_size, output_ptr_batch); + scaling_factors, batch_size, output_ptr_batch, + /*per_channel_scale=*/nullptr, zero_points, accum_scratch, + input_row_sums, compute_row_sums, /*context=*/nullptr); } if (aux_input_ptr_batch && @@ -182,10 +236,17 @@ void RnnBatchStep( float unused_min, unused_max; for (int b = 0; b < batch_size; ++b) { const int offset = b * aux_input_size; - tensor_utils::SymmetricQuantizeFloats( - aux_input_ptr_batch + offset, aux_input_size, - aux_quantized_input_ptr_batch + offset, &unused_min, &unused_max, - &scaling_factors[b]); + if (asymmetric_quantize_inputs) { + tensor_utils::AsymmetricQuantizeFloats( + aux_input_ptr_batch + offset, aux_input_size, + aux_quantized_input_ptr_batch + offset, &scaling_factors[b], + &zero_points[b]); + } else { + tensor_utils::SymmetricQuantizeFloats( + aux_input_ptr_batch + offset, aux_input_size, + aux_quantized_input_ptr_batch + offset, &unused_min, &unused_max, + &scaling_factors[b]); + } scaling_factors[b] *= aux_input_weights_scale; } @@ -193,7 +254,9 @@ void RnnBatchStep( tensor_utils::MatrixBatchVectorMultiplyAccumulate( aux_input_weights_ptr, num_units, aux_input_size, aux_quantized_input_ptr_batch, scaling_factors, batch_size, - output_ptr_batch); + output_ptr_batch, /*per_channel_scale=*/nullptr, zero_points, + accum_scratch, aux_input_row_sums, compute_row_sums, + /*context=*/nullptr); } // Save quantization and matmul computation for all zero input. @@ -203,10 +266,17 @@ void RnnBatchStep( float unused_min, unused_max; for (int b = 0; b < batch_size; ++b) { const int offset = b * num_units; - tensor_utils::SymmetricQuantizeFloats( - hidden_state_ptr_batch + offset, num_units, - quantized_hidden_state_ptr_batch + offset, &unused_min, &unused_max, - &scaling_factors[b]); + if (asymmetric_quantize_inputs) { + tensor_utils::AsymmetricQuantizeFloats( + hidden_state_ptr_batch + offset, num_units, + quantized_hidden_state_ptr_batch + offset, &scaling_factors[b], + &zero_points[b]); + } else { + tensor_utils::SymmetricQuantizeFloats( + hidden_state_ptr_batch + offset, num_units, + quantized_hidden_state_ptr_batch + offset, &unused_min, + &unused_max, &scaling_factors[b]); + } scaling_factors[b] *= recurrent_weights_scale; } @@ -214,7 +284,9 @@ void RnnBatchStep( tensor_utils::MatrixBatchVectorMultiplyAccumulate( recurrent_weights_ptr, num_units, num_units, quantized_hidden_state_ptr_batch, scaling_factors, batch_size, - output_ptr_batch); + output_ptr_batch, /*per_channel_scale=*/nullptr, zero_points, + accum_scratch, recurrent_row_sums, compute_row_sums, + /*context=*/nullptr); } // Output = activation(Output) and update hidden_state @@ -238,10 +310,17 @@ void RnnBatchStep( // whichever is faster. for (int b = 0; b < batch_size; ++b) { const int offset = b * input_size; - tensor_utils::SymmetricQuantizeFloats( - input_ptr_batch + offset, input_size, - quantized_input_ptr_batch + offset, &unused_min, &unused_max, - &scaling_factors[b]); + if (asymmetric_quantize_inputs) { + tensor_utils::AsymmetricQuantizeFloats( + input_ptr_batch + offset, input_size, + quantized_input_ptr_batch + offset, &scaling_factors[b], + &zero_points[b]); + } else { + tensor_utils::SymmetricQuantizeFloats( + input_ptr_batch + offset, input_size, + quantized_input_ptr_batch + offset, &unused_min, &unused_max, + &scaling_factors[b]); + } scaling_factors[b] *= input_weights_scale; } @@ -250,7 +329,9 @@ void RnnBatchStep( tensor_utils::MatrixBatchVectorMultiplyAccumulate( input_weights_ptr, num_units, input_size, quantized_input_ptr_batch + k * input_size, &scaling_factors[k], - /*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim); + /*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim, + /*per_channel_scale=*/nullptr, zero_points + k, accum_scratch, + input_row_sums, compute_row_sums, /*context=*/nullptr); } } @@ -260,10 +341,17 @@ void RnnBatchStep( float unused_min, unused_max; for (int b = 0; b < batch_size; ++b) { const int offset = b * aux_input_size; - tensor_utils::SymmetricQuantizeFloats( - aux_input_ptr_batch + offset, aux_input_size, - aux_quantized_input_ptr_batch + offset, &unused_min, &unused_max, - &scaling_factors[b]); + if (asymmetric_quantize_inputs) { + tensor_utils::AsymmetricQuantizeFloats( + aux_input_ptr_batch + offset, aux_input_size, + aux_quantized_input_ptr_batch + offset, &scaling_factors[b], + &zero_points[b]); + } else { + tensor_utils::SymmetricQuantizeFloats( + aux_input_ptr_batch + offset, aux_input_size, + aux_quantized_input_ptr_batch + offset, &unused_min, &unused_max, + &scaling_factors[b]); + } scaling_factors[b] *= aux_input_weights_scale; } @@ -273,7 +361,9 @@ void RnnBatchStep( aux_input_weights_ptr, num_units, aux_input_size, aux_quantized_input_ptr_batch + k * aux_input_size, &scaling_factors[k], - /*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim); + /*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim, + /*per_channel_scale=*/nullptr, zero_points + k, accum_scratch, + aux_input_row_sums, compute_row_sums, /*context=*/nullptr); } } @@ -284,10 +374,17 @@ void RnnBatchStep( float unused_min, unused_max; for (int b = 0; b < batch_size; ++b) { const int offset = b * num_units; - tensor_utils::SymmetricQuantizeFloats( - hidden_state_ptr_batch + offset, num_units, - quantized_hidden_state_ptr_batch + offset, &unused_min, &unused_max, - &scaling_factors[b]); + if (asymmetric_quantize_inputs) { + tensor_utils::AsymmetricQuantizeFloats( + hidden_state_ptr_batch + offset, num_units, + quantized_hidden_state_ptr_batch + offset, &scaling_factors[b], + &zero_points[b]); + } else { + tensor_utils::SymmetricQuantizeFloats( + hidden_state_ptr_batch + offset, num_units, + quantized_hidden_state_ptr_batch + offset, &unused_min, + &unused_max, &scaling_factors[b]); + } scaling_factors[b] *= recurrent_weights_scale; } @@ -296,8 +393,10 @@ void RnnBatchStep( tensor_utils::MatrixBatchVectorMultiplyAccumulate( recurrent_weights_ptr, num_units, num_units, quantized_hidden_state_ptr_batch + k * num_units, - &scaling_factors[k], - /*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim); + &scaling_factors[k], /*n_batch=*/1, + output_ptr_batch + k * output_batch_leading_dim, + /*per_channel_scale=*/nullptr, zero_points + k, accum_scratch, + recurrent_row_sums, compute_row_sums, /*context=*/nullptr); } } diff --git a/tensorflow/lite/kernels/internal/kernel_utils.h b/tensorflow/lite/kernels/internal/kernel_utils.h index ebb91678fec..2f551570e17 100644 --- a/tensorflow/lite/kernels/internal/kernel_utils.h +++ b/tensorflow/lite/kernels/internal/kernel_utils.h @@ -70,7 +70,9 @@ void RnnBatchStep( int num_units, int batch_size, int output_batch_leading_dim, TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch, int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors, - float* hidden_state_ptr_batch, float* output_ptr_batch); + float* hidden_state_ptr_batch, float* output_ptr_batch, + bool asymmetric_quantize_inputs, int32_t* zero_points, + int32_t* accum_scratch, int32_t* row_sums, bool* compute_row_sums); void RnnBatchStep( const float* input_ptr_batch, const int8_t* input_weights_ptr, @@ -82,7 +84,9 @@ void RnnBatchStep( TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch, int8_t* aux_quantized_input_ptr_batch, int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors, - float* hidden_state_ptr_batch, float* output_ptr_batch); + float* hidden_state_ptr_batch, float* output_ptr_batch, + bool asymmetric_quantize_inputs, int32_t* zero_points, + int32_t* accum_scratch, int32_t* row_sums, bool* compute_row_sums); } // namespace kernel_utils } // namespace tflite diff --git a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_3x3_filter_common.h b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_3x3_filter_common.h index 2768344696d..916edd561ff 100644 --- a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_3x3_filter_common.h +++ b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_3x3_filter_common.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_3X3_FILTER_COMMON_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_3X3_FILTER_COMMON_H_ -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" #include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h" #include "tensorflow/lite/kernels/internal/types.h" diff --git a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h index af763377763..a8f41d5a108 100644 --- a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h +++ b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_ -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" #include "tensorflow/lite/kernels/internal/types.h" diff --git a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h index 1b86d91fb42..3f93a491862 100644 --- a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h +++ b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" #include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h" #include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h" diff --git a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h index 293fd4248f2..73acbcf707b 100644 --- a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h +++ b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" #include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_3x3_filter_common.h" #include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h" diff --git a/tensorflow/lite/kernels/internal/optimized/im2col_utils.h b/tensorflow/lite/kernels/internal/optimized/im2col_utils.h index e3a9b9acdc6..42aa4825771 100644 --- a/tensorflow/lite/kernels/internal/optimized/im2col_utils.h +++ b/tensorflow/lite/kernels/internal/optimized/im2col_utils.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_IM2COL_UTILS_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_IM2COL_UTILS_H_ -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/internal/types.h" namespace tflite { diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/add.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/add.h index 8db98cf1bdc..a9dae4feac5 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/add.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/add.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_ADD_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_ADD_H_ -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" #include "tensorflow/lite/kernels/internal/reference/integer_ops/add.h" diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/conv.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/conv.h index 6c1abaeff82..61f848c888e 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/conv.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/conv.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_CONV_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_CONV_H_ -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/cpu_backend_gemm.h" #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h" diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h index d44cfabe3c3..ffc7ea84340 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_DEPTHWISE_CONV_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_DEPTHWISE_CONV_H_ -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/cpu_backend_threadpool.h" #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_3x3_filter.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_3x3_filter.h index 97039e2e462..0cb1a23e556 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_3x3_filter.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_3x3_filter.h @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" #include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_3x3_filter_common.h" #include "tensorflow/lite/kernels/internal/types.h" diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_hybrid.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_hybrid.h index 153a2252f39..37e9261b04a 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_hybrid.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_hybrid.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_DEPTHWISE_CONV_HYBRID_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_DEPTHWISE_CONV_HYBRID_H_ -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/cpu_backend_threadpool.h" #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_hybrid_3x3_filter.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_hybrid_3x3_filter.h index fa96ce94a6e..51f3d2559db 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_hybrid_3x3_filter.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_hybrid_3x3_filter.h @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" #include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_3x3_filter_common.h" #include "tensorflow/lite/kernels/internal/types.h" diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h index fdd3135097b..8de99c1a564 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_FULLY_CONNECTED_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_FULLY_CONNECTED_H_ -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/cpu_backend_gemm.h" #include "tensorflow/lite/kernels/cpu_backend_threadpool.h" diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h index 952415593a5..18aeef4c8b5 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_MUL_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_MUL_H_ -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" #include "tensorflow/lite/kernels/internal/reference/integer_ops/mul.h" diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/pooling.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/pooling.h index fb4642e7f0d..060845f4a10 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/pooling.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/pooling.h @@ -28,7 +28,7 @@ limitations under the License. #include #include "fixedpoint/fixedpoint.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/internal/cppmath.h" #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" #include "tensorflow/lite/kernels/internal/optimized/im2col_utils.h" diff --git a/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h index f87738c34ff..bc8b9b2d3ac 100644 --- a/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h @@ -3429,9 +3429,9 @@ void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims, tflite::ArithmeticParams op_params; SetActivationParams(output_activation_min, output_activation_max, &op_params); - BroadcastDiv4DSlow(op_params, DimsToShape(input1_dims), input1_data, - DimsToShape(input2_dims), input2_data, - DimsToShape(output_dims), output_data); + BroadcastDivSlow(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, + DimsToShape(output_dims), output_data); } template diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc index 86e2f9fa96a..dc2204e3a60 100644 --- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc +++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc @@ -23,8 +23,8 @@ limitations under the License. #include #include -#include "tensorflow/lite/experimental/ruy/ruy/detect_arm.h" -#include "tensorflow/lite/experimental/ruy/ruy/ruy.h" +#include "ruy/detect_arm.h" // from @ruy +#include "ruy/ruy.h" // from @ruy #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/cpu_backend_gemm.h" #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h" @@ -1310,6 +1310,13 @@ void NeonMatrixBatchVectorMultiplyAccumulateImpl( const int postamble_half_start = m_cols & ~(kWeightsPerNeonLane - 1); const int postamble_start = m_cols & ~((kWeightsPerNeonLane >> 1) - 1); + int32_t* row_sums_ptr = row_sums; + if (row_sums == nullptr) { + row_sums_ptr = static_cast(malloc(sizeof(int32_t) * m_rows)); + memset(row_sums_ptr, 0, sizeof(int32_t) * m_rows); + NeonReductionSumVector(matrix, row_sums_ptr, m_rows, m_cols); + } + for (int batch = 0; batch < n_batch; ++batch) { const float batch_scaling_factor = scaling_factors[batch]; const int batch_input_offset = input_offset[batch]; @@ -1327,10 +1334,6 @@ void NeonMatrixBatchVectorMultiplyAccumulateImpl( // Initialize the dot product sum for the row to 0. int32x4_t dotprod_32x4 = vmovq_n_s32(0); - int32x4_t row_sum_32x4; - if (row_sums == nullptr) { - row_sum_32x4 = vmovq_n_s32(0); - } // Prefetch the row to cache. __builtin_prefetch(row_ptr, 0 /* prefetch for read */, 3 /* temporal locality */); @@ -1358,10 +1361,6 @@ void NeonMatrixBatchVectorMultiplyAccumulateImpl( prod_16x8 = vmlal_s8(prod_16x8, vget_high_s8(s1_8x16), vget_high_s8(s2_8x16)); dotprod_32x4 = vpadalq_s16(dotprod_32x4, prod_16x8); - if (row_sums == nullptr) { - const int16x8_t row_sum_16x8 = vpaddlq_s8(s2_8x16); - row_sum_32x4 = vpadalq_s16(row_sum_32x4, row_sum_16x8); - } } // for col // Half iteration dealing only 8 elements @@ -1375,29 +1374,24 @@ void NeonMatrixBatchVectorMultiplyAccumulateImpl( const int8x8_t s2_8x8 = vld1_s8((const int8_t*)(row_ptr + col)); const int16x8_t prod_16x8 = vmull_s8(s1_8x8, s2_8x8); dotprod_32x4 = vpadalq_s16(dotprod_32x4, prod_16x8); - if (row_sums == nullptr) { - const int16x8_t row_sum_16x8 = vmovl_s8(s2_8x8); - row_sum_32x4 = vpadalq_s16(row_sum_32x4, row_sum_16x8); - } col += (kWeightsPerNeonLane >> 1); } int32_t dotprod = AccumulateNeonLane(dotprod_32x4); - int32_t row_sum = row_sums == nullptr ? AccumulateNeonLane(row_sum_32x4) - : row_sums[row]; // Postamble loop. for (; col < m_cols; ++col) { dotprod += row_ptr[col] * aligned_vec[col]; - if (row_sums == nullptr) { - row_sum += row_ptr[col]; - } } // for col - dotprod -= row_sum * batch_input_offset; + dotprod -= row_sums_ptr[row] * batch_input_offset; *result += dotprod * scale; ++result; } // for row } // for batch + + if (row_sums == nullptr) { + free(row_sums_ptr); + } if (unaligned) { free(aligned_row_free); } @@ -1410,6 +1404,20 @@ void NeonMatrixBatchVectorMultiplyAccumulate( int n_batch, float* __restrict__ result, const float* per_channel_scale, const int32_t* input_offset, int32_t* scratch, int32_t* row_sums, bool* compute_row_sums, CpuBackendContext* context) { + if (input_offset == nullptr) { +#ifdef TFLITE_WITH_RUY_GEMV + if (context) { + NeonMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors, + scaling_factors, n_batch, scratch, + result, context); + return; + } +#endif + NeonMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors, + scaling_factors, n_batch, result); + return; + } + if (compute_row_sums == nullptr || *compute_row_sums) { memset(row_sums, 0, sizeof(int32_t) * m_rows); NeonReductionSumVector(matrix, row_sums, m_rows, m_cols); @@ -1419,7 +1427,7 @@ void NeonMatrixBatchVectorMultiplyAccumulate( } #ifdef TFLITE_WITH_RUY_GEMV - if (m_rows % 4 == 0) { + if (context != nullptr && m_rows % 4 == 0) { const int32_t* bias = static_cast(nullptr); NeonCpuBackendGemm(vectors, bias, matrix, n_batch, m_cols, m_rows, 0, scratch, context); @@ -1463,9 +1471,9 @@ void NeonMatrixBatchVectorMultiplyAccumulate( for (; i < total_size; i++) { const float batch_scaling_factor = scaling_factors[i / m_rows]; const int32_t zero_point = input_offset[i / m_rows]; - int32_t x = *(scratch_ptr++); - x -= row_sums[i % m_rows] * zero_point; - *result += x * batch_scaling_factor; + int32_t dotprod = *(scratch_ptr++); + dotprod -= row_sums[i % m_rows] * zero_point; + *result += dotprod * batch_scaling_factor; ++result; } return; diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h index d98c51d1a2f..ce9073773a5 100644 --- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h @@ -38,8 +38,8 @@ limitations under the License. #include "third_party/eigen3/Eigen/Core" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "fixedpoint/fixedpoint.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/cpu_backend_gemm.h" #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h" @@ -2803,29 +2803,30 @@ inline void BroadcastMulDispatch( // reference_ops.h. Once an optimized version is implemented and NdArrayDesc // is no longer referenced in this file, move NdArrayDesc from types.h to // reference_ops.h. -template -void BroadcastDiv4DSlow(const ArithmeticParams& params, - const RuntimeShape& unextended_input1_shape, - const T* input1_data, - const RuntimeShape& unextended_input2_shape, - const T* input2_data, - const RuntimeShape& unextended_output_shape, - T* output_data) { - ruy::profiler::ScopeLabel label("BroadcastDiv4DSlow"); +template +void BroadcastDivSlow(const ArithmeticParams& params, + const RuntimeShape& unextended_input1_shape, + const T* input1_data, + const RuntimeShape& unextended_input2_shape, + const T* input2_data, + const RuntimeShape& unextended_output_shape, + T* output_data) { + ruy::profiler::ScopeLabel label("BroadcastDivSlow"); T output_activation_min; T output_activation_max; GetActivationParams(params, &output_activation_min, &output_activation_max); - TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4); - TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4); - TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); - const RuntimeShape output_shape = - RuntimeShape::ExtendedShape(4, unextended_output_shape); + TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), N); + TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), N); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), N); - NdArrayDesc<4> desc1; - NdArrayDesc<4> desc2; + NdArrayDesc desc1; + NdArrayDesc desc2; + NdArrayDesc output_desc; NdArrayDescsForElementwiseBroadcast(unextended_input1_shape, unextended_input2_shape, &desc1, &desc2); + CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_output_shape), + &output_desc); // In Tensorflow, the dimensions are canonically named (batch_number, row, // col, channel), with extents (batches, height, width, depth), with the @@ -2838,41 +2839,38 @@ void BroadcastDiv4DSlow(const ArithmeticParams& params, // We name our variables by their Tensorflow convention, but generate C code // nesting loops such that the innermost loop has the smallest stride for the // best cache behavior. - for (int b = 0; b < output_shape.Dims(0); ++b) { - for (int y = 0; y < output_shape.Dims(1); ++y) { - for (int x = 0; x < output_shape.Dims(2); ++x) { - for (int c = 0; c < output_shape.Dims(3); ++c) { - output_data[Offset(output_shape, b, y, x, c)] = - ActivationFunctionWithMinMax( - input1_data[SubscriptToIndex(desc1, b, y, x, c)] / - input2_data[SubscriptToIndex(desc2, b, y, x, c)], - output_activation_min, output_activation_max); - } - } - } - } + auto div_func = [&](int indexes[N]) { + output_data[SubscriptToIndex(output_desc, indexes)] = + ActivationFunctionWithMinMax( + input1_data[SubscriptToIndex(desc1, indexes)] / + input2_data[SubscriptToIndex(desc2, indexes)], + output_activation_min, output_activation_max); + }; + NDOpsHelper(output_desc, div_func); } // TODO: BroadcastDiv is intentionally duplicated from reference_ops.h. // For more details see the comment above the generic version of -// BroadcastDiv4DSlow. -inline void BroadcastDiv4DSlow(const ArithmeticParams& params, - const RuntimeShape& unextended_input1_shape, - const uint8* input1_data, - const RuntimeShape& unextended_input2_shape, - const uint8* input2_data, - const RuntimeShape& unextended_output_shape, - uint8* output_data) { - TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4); - TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4); - TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); - const RuntimeShape output_shape = - RuntimeShape::ExtendedShape(4, unextended_output_shape); +// BroadcastDivSlow. +template +inline void BroadcastDivSlow(const ArithmeticParams& params, + const RuntimeShape& unextended_input1_shape, + const uint8* input1_data, + const RuntimeShape& unextended_input2_shape, + const uint8* input2_data, + const RuntimeShape& unextended_output_shape, + uint8* output_data) { + TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), N); + TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), N); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), N); - NdArrayDesc<4> desc1; - NdArrayDesc<4> desc2; + NdArrayDesc desc1; + NdArrayDesc desc2; + NdArrayDesc output_desc; NdArrayDescsForElementwiseBroadcast(unextended_input1_shape, unextended_input2_shape, &desc1, &desc2); + CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_output_shape), + &output_desc); TFLITE_DCHECK_GT(params.input1_offset, -256); TFLITE_DCHECK_LT(params.input1_offset, 256); @@ -2881,39 +2879,31 @@ inline void BroadcastDiv4DSlow(const ArithmeticParams& params, TFLITE_DCHECK_GT(params.output_offset, -256); TFLITE_DCHECK_LT(params.output_offset, 256); - for (int b = 0; b < output_shape.Dims(0); ++b) { - for (int y = 0; y < output_shape.Dims(1); ++y) { - for (int x = 0; x < output_shape.Dims(2); ++x) { - for (int c = 0; c < output_shape.Dims(3); ++c) { - const int32 input1_val = - params.input1_offset + - input1_data[SubscriptToIndex(desc1, b, y, x, c)]; - const int32 input2_val = - params.input2_offset + - input2_data[SubscriptToIndex(desc2, b, y, x, c)]; - TFLITE_DCHECK_NE(input2_val, 0); - int recip_shift; - const int32 input2_inv = - (input2_val > 0) ? GetReciprocal(input2_val, 31, &recip_shift) - : -GetReciprocal(-input2_val, 31, &recip_shift); - const int headroom = CountLeadingSignBits(input1_val); - const int32 unscaled_quotient = - MultiplyByQuantizedMultiplierGreaterThanOne(input1_val, - input2_inv, headroom); - const int total_shift = params.output_shift - recip_shift - headroom; - const int32 unclamped_result = - params.output_offset + - MultiplyByQuantizedMultiplierSmallerThanOneExp( - unscaled_quotient, params.output_multiplier, total_shift); - const int32 clamped_output = std::min( - params.quantized_activation_max, - std::max(params.quantized_activation_min, unclamped_result)); - output_data[Offset(output_shape, b, y, x, c)] = - static_cast(clamped_output); - } - } - } - } + auto div_func = [&](int indexes[N]) { + const int32 input1_val = + params.input1_offset + input1_data[SubscriptToIndex(desc1, indexes)]; + const int32 input2_val = + params.input2_offset + input2_data[SubscriptToIndex(desc2, indexes)]; + TFLITE_DCHECK_NE(input2_val, 0); + int recip_shift; + const int32 input2_inv = + (input2_val > 0) ? GetReciprocal(input2_val, 31, &recip_shift) + : -GetReciprocal(-input2_val, 31, &recip_shift); + const int headroom = CountLeadingSignBits(input1_val); + const int32 unscaled_quotient = MultiplyByQuantizedMultiplierGreaterThanOne( + input1_val, input2_inv, headroom); + const int total_shift = params.output_shift - recip_shift - headroom; + const int32 unclamped_result = + params.output_offset + + MultiplyByQuantizedMultiplierSmallerThanOneExp( + unscaled_quotient, params.output_multiplier, total_shift); + const int32 clamped_output = + std::min(params.quantized_activation_max, + std::max(params.quantized_activation_min, unclamped_result)); + output_data[SubscriptToIndex(output_desc, indexes)] = + static_cast(clamped_output); + }; + NDOpsHelper(output_desc, div_func); } // TODO(aselle): This is not actually optimized yet. diff --git a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.cc b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.cc index fe970dd8b39..7fb69e7b4f4 100644 --- a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.cc +++ b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.cc @@ -167,6 +167,11 @@ void SseMatrixBatchVectorMultiplyAccumulate( const float* __restrict__ scaling_factors, int n_batch, float* __restrict__ result, const float* __restrict__ per_channel_scale, const int32_t* __restrict__ input_offset) { + if (input_offset == nullptr) { + SseMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors, + scaling_factors, n_batch, result); + return; + } static constexpr std::intptr_t kBlockSize = 16; for (std::intptr_t batch = 0; batch < n_batch; ++batch) { const float batch_scaling_factor = scaling_factors[batch]; diff --git a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h index fa6f2c7a8db..1d0d2273e93 100644 --- a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h +++ b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h @@ -59,9 +59,10 @@ void MatrixBatchVectorMultiplyAccumulate( int n_batch, float* __restrict__ result, const float* per_channel_scale, const int32_t* input_offset, int32_t* scratch, int32_t* row_sums, bool* compute_row_sums, CpuBackendContext* context) { - NEON_OR_PORTABLE(MatrixBatchVectorMultiplyAccumulate, matrix, m_rows, m_cols, - vectors, scaling_factors, n_batch, result, per_channel_scale, - input_offset, scratch, row_sums, compute_row_sums, context); + PortableMatrixBatchVectorMultiplyAccumulate( + matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result, + per_channel_scale, input_offset, scratch, row_sums, compute_row_sums, + context); } void MatrixBatchVectorMultiplyAccumulate( diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/mul.h b/tensorflow/lite/kernels/internal/reference/integer_ops/mul.h index 20571110005..a815c3f5252 100644 --- a/tensorflow/lite/kernels/internal/reference/integer_ops/mul.h +++ b/tensorflow/lite/kernels/internal/reference/integer_ops/mul.h @@ -16,7 +16,7 @@ limitations under the License. #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_MUL_H_ #include "fixedpoint/fixedpoint.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/internal/common.h" namespace tflite { diff --git a/tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h b/tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h index 5e0cf7224ce..2148be45590 100644 --- a/tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h +++ b/tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h @@ -613,9 +613,9 @@ void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims, tflite::ArithmeticParams op_params; SetActivationParams(output_activation_min, output_activation_max, &op_params); - BroadcastDiv4DSlow(op_params, DimsToShape(input1_dims), input1_data, - DimsToShape(input2_dims), input2_data, - DimsToShape(output_dims), output_data); + BroadcastDivSlow(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, + DimsToShape(output_dims), output_data); } template diff --git a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc index 9c58415d6dc..19c74973aeb 100644 --- a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc +++ b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc @@ -196,6 +196,11 @@ void PortableMatrixBatchVectorMultiplyAccumulate( int n_batch, float* __restrict__ result, const float* per_channel_scale, const int32_t* input_offset, int32_t* scratch, int32_t* row_sums, bool* compute_row_sums, CpuBackendContext* context) { + if (input_offset == nullptr) { + PortableMatrixBatchVectorMultiplyAccumulate( + matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result); + return; + } if (!compute_row_sums || *compute_row_sums) { memset(row_sums, 0, sizeof(int32_t) * m_rows); PortableReductionSumVector(matrix, row_sums, m_rows, m_cols); diff --git a/tensorflow/lite/kernels/internal/reference/reduce.h b/tensorflow/lite/kernels/internal/reference/reduce.h index 46448b2a646..17dfd8557ae 100644 --- a/tensorflow/lite/kernels/internal/reference/reduce.h +++ b/tensorflow/lite/kernels/internal/reference/reduce.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REDUCE_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REDUCE_H_ -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/cppmath.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h index 56443bb2139..a872bc4d56a 100644 --- a/tensorflow/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h @@ -28,8 +28,8 @@ limitations under the License. #include "third_party/eigen3/Eigen/Core" #include "fixedpoint/fixedpoint.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" #include "tensorflow/lite/kernels/internal/reference/add.h" @@ -477,28 +477,29 @@ inline void Mul(const ArithmeticParams& params, // dimensionality if the runtime code does a single loop over one dimension // that handles broadcasting as the base case. The code generator would then // generate max(D1, D2) nested for loops. -template -void BroadcastDiv4DSlow(const ArithmeticParams& params, - const RuntimeShape& unextended_input1_shape, - const T* input1_data, - const RuntimeShape& unextended_input2_shape, - const T* input2_data, - const RuntimeShape& unextended_output_shape, - T* output_data) { +template +void BroadcastDivSlow(const ArithmeticParams& params, + const RuntimeShape& unextended_input1_shape, + const T* input1_data, + const RuntimeShape& unextended_input2_shape, + const T* input2_data, + const RuntimeShape& unextended_output_shape, + T* output_data) { T output_activation_min; T output_activation_max; GetActivationParams(params, &output_activation_min, &output_activation_max); - TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4); - TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4); - TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); - const RuntimeShape output_shape = - RuntimeShape::ExtendedShape(4, unextended_output_shape); + TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), N); + TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), N); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), N); - NdArrayDesc<4> desc1; - NdArrayDesc<4> desc2; + NdArrayDesc desc1; + NdArrayDesc desc2; + NdArrayDesc output_desc; NdArrayDescsForElementwiseBroadcast(unextended_input1_shape, unextended_input2_shape, &desc1, &desc2); + CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_output_shape), + &output_desc); // In Tensorflow, the dimensions are canonically named (batch_number, row, // col, channel), with extents (batches, height, width, depth), with the @@ -507,23 +508,15 @@ void BroadcastDiv4DSlow(const ArithmeticParams& params, // // In generated C code, we store arrays with the dimensions reversed. The // first dimension has smallest stride. - // - // We name our variables by their Tensorflow convention, but generate C code - // nesting loops such that the innermost loop has the smallest stride for - // the best cache behavior. - for (int b = 0; b < output_shape.Dims(0); ++b) { - for (int y = 0; y < output_shape.Dims(1); ++y) { - for (int x = 0; x < output_shape.Dims(2); ++x) { - for (int c = 0; c < output_shape.Dims(3); ++c) { - output_data[Offset(output_shape, b, y, x, c)] = - ActivationFunctionWithMinMax( - input1_data[SubscriptToIndex(desc1, b, y, x, c)] / - input2_data[SubscriptToIndex(desc2, b, y, x, c)], - output_activation_min, output_activation_max); - } - } - } - } + + auto div_func = [&](int indexes[N]) { + output_data[SubscriptToIndex(output_desc, indexes)] = + ActivationFunctionWithMinMax( + input1_data[SubscriptToIndex(desc1, indexes)] / + input2_data[SubscriptToIndex(desc2, indexes)], + output_activation_min, output_activation_max); + }; + NDOpsHelper(output_desc, div_func); } template @@ -592,23 +585,25 @@ inline void Div(const ArithmeticParams& params, DivElementwise(flat_size, params, input1_data, input2_data, output_data); } -inline void BroadcastDiv4DSlow(const ArithmeticParams& params, - const RuntimeShape& unextended_input1_shape, - const uint8* input1_data, - const RuntimeShape& unextended_input2_shape, - const uint8* input2_data, - const RuntimeShape& unextended_output_shape, - uint8* output_data) { - TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4); - TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4); - TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); - const RuntimeShape output_shape = - RuntimeShape::ExtendedShape(4, unextended_output_shape); +template +inline void BroadcastDivSlow(const ArithmeticParams& params, + const RuntimeShape& unextended_input1_shape, + const uint8* input1_data, + const RuntimeShape& unextended_input2_shape, + const uint8* input2_data, + const RuntimeShape& unextended_output_shape, + uint8* output_data) { + TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), N); + TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), N); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), N); - NdArrayDesc<4> desc1; - NdArrayDesc<4> desc2; + NdArrayDesc desc1; + NdArrayDesc desc2; + NdArrayDesc output_desc; NdArrayDescsForElementwiseBroadcast(unextended_input1_shape, unextended_input2_shape, &desc1, &desc2); + CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_output_shape), + &output_desc); TFLITE_DCHECK_GT(params.input1_offset, -256); TFLITE_DCHECK_LT(params.input1_offset, 256); @@ -617,39 +612,31 @@ inline void BroadcastDiv4DSlow(const ArithmeticParams& params, TFLITE_DCHECK_GT(params.output_offset, -256); TFLITE_DCHECK_LT(params.output_offset, 256); - for (int b = 0; b < output_shape.Dims(0); ++b) { - for (int y = 0; y < output_shape.Dims(1); ++y) { - for (int x = 0; x < output_shape.Dims(2); ++x) { - for (int c = 0; c < output_shape.Dims(3); ++c) { - const int32 input1_val = - params.input1_offset + - input1_data[SubscriptToIndex(desc1, b, y, x, c)]; - const int32 input2_val = - params.input2_offset + - input2_data[SubscriptToIndex(desc2, b, y, x, c)]; - TFLITE_DCHECK_NE(input2_val, 0); - int recip_shift; - const int32 input2_inv = - (input2_val > 0) ? GetReciprocal(input2_val, 31, &recip_shift) - : -GetReciprocal(-input2_val, 31, &recip_shift); - const int headroom = CountLeadingSignBits(input1_val); - const int32 unscaled_quotient = - MultiplyByQuantizedMultiplierGreaterThanOne(input1_val, - input2_inv, headroom); - const int total_shift = params.output_shift - recip_shift - headroom; - const int32 unclamped_result = - params.output_offset + - MultiplyByQuantizedMultiplierSmallerThanOneExp( - unscaled_quotient, params.output_multiplier, total_shift); - const int32 clamped_output = std::min( - params.quantized_activation_max, - std::max(params.quantized_activation_min, unclamped_result)); - output_data[Offset(output_shape, b, y, x, c)] = - static_cast(clamped_output); - } - } - } - } + auto div_func = [&](int indexes[N]) { + const int32 input1_val = + params.input1_offset + input1_data[SubscriptToIndex(desc1, indexes)]; + const int32 input2_val = + params.input2_offset + input2_data[SubscriptToIndex(desc2, indexes)]; + TFLITE_DCHECK_NE(input2_val, 0); + int recip_shift; + const int32 input2_inv = + (input2_val > 0) ? GetReciprocal(input2_val, 31, &recip_shift) + : -GetReciprocal(-input2_val, 31, &recip_shift); + const int headroom = CountLeadingSignBits(input1_val); + const int32 unscaled_quotient = MultiplyByQuantizedMultiplierGreaterThanOne( + input1_val, input2_inv, headroom); + const int total_shift = params.output_shift - recip_shift - headroom; + const int32 unclamped_result = + params.output_offset + + MultiplyByQuantizedMultiplierSmallerThanOneExp( + unscaled_quotient, params.output_multiplier, total_shift); + const int32 clamped_output = + std::min(params.quantized_activation_max, + std::max(params.quantized_activation_min, unclamped_result)); + output_data[SubscriptToIndex(output_desc, indexes)] = + static_cast(clamped_output); + }; + NDOpsHelper(output_desc, div_func); } inline void Sub16(const ArithmeticParams& params, @@ -1561,6 +1548,40 @@ inline void Gather(const tflite::GatherParams& op_params, } } +// Common subroutine for both `GatherNd` and `GatherNdString`. +struct GatherNdHelperResult { + int n_slices; + int slice_size; + int indices_nd; + std::vector dims_to_count; +}; + +// Returns common values being used on both `GatherNd` and `GatherNdString`. +inline GatherNdHelperResult GatherNdHelper(const RuntimeShape& params_shape, + const RuntimeShape& indices_shape) { + GatherNdHelperResult ret; + ret.n_slices = 1; + ret.slice_size = 1; + const int indices_dims = indices_shape.DimensionsCount(); + ret.indices_nd = indices_shape.Dims(indices_dims - 1); + const int params_dims = params_shape.DimensionsCount(); + for (int i = 0; i < indices_dims - 1; ++i) { + ret.n_slices *= indices_shape.Dims(i); + } + for (int i = ret.indices_nd; i < params_dims; ++i) { + ret.slice_size *= params_shape.Dims(i); + } + + int remain_flat_size = params_shape.FlatSize(); + ret.dims_to_count = std::vector(ret.indices_nd, 0); + for (int i = 0; i < ret.indices_nd; ++i) { + ret.dims_to_count[i] = remain_flat_size / params_shape.Dims(i); + remain_flat_size = ret.dims_to_count[i]; + } + + return ret; +} + template inline void GatherNd(const RuntimeShape& params_shape, const ParamsT* params_data, @@ -1569,35 +1590,40 @@ inline void GatherNd(const RuntimeShape& params_shape, const RuntimeShape& output_shape, ParamsT* output_data) { ruy::profiler::ScopeLabel label("GatherNd"); - int n_slices = 1; - int slice_size = 1; - const int indices_dims = indices_shape.DimensionsCount(); - const int indices_nd = indices_shape.Dims(indices_dims - 1); - const int params_dims = params_shape.DimensionsCount(); - for (int i = 0; i < indices_dims - 1; ++i) { - n_slices *= indices_shape.Dims(i); - } - for (int i = indices_nd; i < params_dims; ++i) { - slice_size *= params_shape.Dims(i); - } - - int remain_flat_size = params_shape.FlatSize(); - std::vector dims_to_count(indices_nd, 0); - for (int i = 0; i < indices_nd; ++i) { - dims_to_count[i] = remain_flat_size / params_shape.Dims(i); - remain_flat_size = dims_to_count[i]; - } - - for (int i = 0; i < n_slices; ++i) { + const GatherNdHelperResult res = GatherNdHelper(params_shape, indices_shape); + for (int i = 0; i < res.n_slices; ++i) { int from_pos = 0; - for (int j = 0; j < indices_nd; ++j) { - from_pos += indices_data[i * indices_nd + j] * dims_to_count[j]; + for (int j = 0; j < res.indices_nd; ++j) { + from_pos += indices_data[i * res.indices_nd + j] * res.dims_to_count[j]; } - std::memcpy(output_data + i * slice_size, params_data + from_pos, - sizeof(ParamsT) * slice_size); + std::memcpy(output_data + i * res.slice_size, params_data + from_pos, + sizeof(ParamsT) * res.slice_size); } } +template +inline void GatherNdString(const RuntimeShape& params_shape, + const TfLiteTensor* params_data, + const RuntimeShape& indices_shape, + const IndicesT* indices_data, + const RuntimeShape& output_shape, + TfLiteTensor* output_data) { + ruy::profiler::ScopeLabel label("GatherNdString"); + + const GatherNdHelperResult res = GatherNdHelper(params_shape, indices_shape); + DynamicBuffer buffer; + for (int i = 0; i < res.n_slices; ++i) { + int from_pos = 0; + for (int j = 0; j < res.indices_nd; ++j) { + from_pos += indices_data[i * res.indices_nd + j] * res.dims_to_count[j]; + } + for (int j = 0; j < res.slice_size; ++j) { + buffer.AddString(GetString(params_data, from_pos + j)); + } + } + buffer.WriteToTensor(output_data, /*new_shape=*/nullptr); +} + template inline void ScatterNd(const RuntimeShape& indices_shape, const IndicesT* indices_data, diff --git a/tensorflow/lite/kernels/internal/reference/requantize.h b/tensorflow/lite/kernels/internal/reference/requantize.h index 8233be9ebae..32e32ed0d5b 100644 --- a/tensorflow/lite/kernels/internal/reference/requantize.h +++ b/tensorflow/lite/kernels/internal/reference/requantize.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REQUANTIZE_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REQUANTIZE_H_ -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/types.h" diff --git a/tensorflow/lite/kernels/internal/reference/sub.h b/tensorflow/lite/kernels/internal/reference/sub.h index a9ed3a675fd..48d03de02ee 100644 --- a/tensorflow/lite/kernels/internal/reference/sub.h +++ b/tensorflow/lite/kernels/internal/reference/sub.h @@ -16,7 +16,7 @@ limitations under the License. #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SUB_H_ #include "fixedpoint/fixedpoint.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/internal/common.h" namespace tflite { diff --git a/tensorflow/lite/kernels/internal/reference/svdf.h b/tensorflow/lite/kernels/internal/reference/svdf.h index 10c2e2cd849..18e4e079293 100644 --- a/tensorflow/lite/kernels/internal/reference/svdf.h +++ b/tensorflow/lite/kernels/internal/reference/svdf.h @@ -223,7 +223,8 @@ inline void EvalHybridSVDF( const TfLiteTensor* weights_feature, const TfLiteTensor* weights_time, const TfLiteTensor* bias, const TfLiteSVDFParams* params, TfLiteTensor* scratch, TfLiteTensor* scaling_factors, - TfLiteTensor* input_quantized, TfLiteTensor* state, TfLiteTensor* output) { + TfLiteTensor* input_quantized, TfLiteTensor* state, TfLiteTensor* output, + TfLiteTensor* zero_points, TfLiteTensor* row_sums, bool* compute_row_sums) { const int rank = params->rank; const int batch_size = input->dims->data[0]; const int input_size = input->dims->data[1]; @@ -244,6 +245,13 @@ inline void EvalHybridSVDF( float* output_ptr = GetTensorData(output); + int32_t* zero_points_ptr = nullptr; + int32_t* row_sums_ptr = nullptr; + if (params->asymmetric_quantize_inputs && row_sums != nullptr) { + zero_points_ptr = GetTensorData(zero_points); + row_sums_ptr = GetTensorData(row_sums); + } + // Initialize the weights scale. const float weights_feature_scale = weights_feature->params.scale; @@ -258,21 +266,30 @@ inline void EvalHybridSVDF( if (!tensor_utils::IsZeroVector(input_ptr, batch_size * input_size)) { // Quantize input from float to int8. - float unused_min, unused_max; for (int b = 0; b < batch_size; ++b) { const int offset = b * input_size; - tensor_utils::SymmetricQuantizeFloats( - input_ptr + offset, input_size, quantized_input_ptr + offset, - &unused_min, &unused_max, &scaling_factors_ptr[b]); + if (params->asymmetric_quantize_inputs) { + tensor_utils::AsymmetricQuantizeFloats( + input_ptr + offset, input_size, quantized_input_ptr + offset, + &scaling_factors_ptr[b], &zero_points_ptr[b]); + } else { + // Quantize input from float to int8. + float unused_min, unused_max; + tensor_utils::SymmetricQuantizeFloats( + input_ptr + offset, input_size, quantized_input_ptr + offset, + &unused_min, &unused_max, &scaling_factors_ptr[b]); + } scaling_factors_ptr[b] *= weights_feature_scale; } // Compute conv1d(inputs, weights_feature). tensor_utils::MatrixBatchVectorMultiplyAccumulate( weights_feature_ptr, num_filters, input_size, quantized_input_ptr, - scaling_factors_ptr, batch_size, scratch_ptr); + scaling_factors_ptr, batch_size, scratch_ptr, + /*per_channel_scale=*/nullptr, zero_points_ptr, + reinterpret_cast(scratch_ptr), row_sums_ptr, compute_row_sums, + /*context=*/nullptr); } - // Copy the latest activation from scratch into activation_state: // The last, i.e. (memory_size-1)th entry for each batch, and filter. for (int i = 0; i < batch_size * num_filters; ++i) { diff --git a/tensorflow/lite/kernels/internal/strided_slice_logic.h b/tensorflow/lite/kernels/internal/strided_slice_logic.h index 12dd33d3296..d9b5acbbbb4 100644 --- a/tensorflow/lite/kernels/internal/strided_slice_logic.h +++ b/tensorflow/lite/kernels/internal/strided_slice_logic.h @@ -76,6 +76,10 @@ inline int StartForAxis(const tflite::StridedSliceParams& params, const auto begin_mask = params.begin_mask; const auto* start_indices = params.start_indices; const auto* strides = params.strides; + const int axis_size = input_shape.Dims(axis); + if (axis_size == 0) { + return 0; + } // Begin with the specified index. int start = start_indices[axis]; @@ -93,7 +97,6 @@ inline int StartForAxis(const tflite::StridedSliceParams& params, } // Handle negative indices - int axis_size = input_shape.Dims(axis); if (start < 0) { start += axis_size; } @@ -116,6 +119,10 @@ inline int StopForAxis(const tflite::StridedSliceParams& params, const auto shrink_axis_mask = params.shrink_axis_mask; const auto* stop_indices = params.stop_indices; const auto* strides = params.strides; + const int axis_size = input_shape.Dims(axis); + if (axis_size == 0) { + return 0; + } // Begin with the specified index const bool shrink_axis = shrink_axis_mask & (1 << axis); @@ -142,7 +149,6 @@ inline int StopForAxis(const tflite::StridedSliceParams& params, } // Handle negative indices - const int axis_size = input_shape.Dims(axis); if (stop < 0) { stop += axis_size; } diff --git a/tensorflow/lite/kernels/lstm.cc b/tensorflow/lite/kernels/lstm.cc index bbda9257651..4eafc215b6f 100644 --- a/tensorflow/lite/kernels/lstm.cc +++ b/tensorflow/lite/kernels/lstm.cc @@ -55,6 +55,7 @@ struct OpData { // These fields are only used by full kernel. int scratch_tensor_index; lstm_eval::IntegerLstmParameter integer_lstm_param; + bool compute_row_sums; }; namespace full { @@ -727,7 +728,7 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_8( void* Init(TfLiteContext* context, const char* buffer, size_t length) { auto* op_data = new OpData(); op_data->kernel_type = kTfLiteLSTMFullKernel; - context->AddTensors(context, /*tensors_to_add=*/8, + context->AddTensors(context, /*tensors_to_add=*/10, &op_data->scratch_tensor_index); return op_data; } @@ -1236,7 +1237,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteIntArrayFree(node->temporaries); if (is_hybrid_op) { - node->temporaries = TfLiteIntArrayCreate(8); + node->temporaries = TfLiteIntArrayCreate(10); } else if (is_integer) { if (is_8x8_16) { node->temporaries = TfLiteIntArrayCreate(6); @@ -1273,6 +1274,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } if (is_hybrid_op) { + op_data->compute_row_sums = true; // Allocate temporary tensors to store quantized values of input, // activation_state and cell_state tensors. node->temporaries->data[1] = op_data->scratch_tensor_index + 1; @@ -1370,6 +1372,41 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK( context, context->ResizeTensor(context, accum_scratch, accum_size)); } + + node->temporaries->data[8] = op_data->scratch_tensor_index + 8; + TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/8); + zero_points->type = kTfLiteFloat32; + zero_points->allocation_type = kTfLiteArenaRw; + int zero_points_dims[1] = {n_batch}; + if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) { + TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1); + zero_points_size->data[0] = n_batch; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points, + zero_points_size)); + } + + node->temporaries->data[9] = op_data->scratch_tensor_index + 9; + const TfLiteTensor* input_to_input_weights = + GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); + const bool use_cifg = (input_to_input_weights == nullptr); + int row_sums_rows = use_cifg ? 6 : 8; + const TfLiteTensor* projection_weights = + GetOptionalInputTensor(context, node, kProjectionWeightsTensor); + if (projection_weights != nullptr) { + row_sums_rows += ceil(n_output / n_cell); + } + + TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/9); + row_sums->type = kTfLiteInt32; + row_sums->allocation_type = kTfLiteArenaRwPersistent; + const int row_sums_dims[2] = {row_sums_rows, n_cell}; + if (!TfLiteIntArrayEqualsArray(row_sums->dims, 2, row_sums_dims)) { + TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(2); + row_sums_size->data[0] = row_sums_dims[0]; + row_sums_size->data[1] = row_sums_dims[1]; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, row_sums, row_sums_size)); + } } if (is_integer) { @@ -1556,6 +1593,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetTemporary(context, node, /*index=*/6); TfLiteTensor* output_scratch_buffer = GetTemporary(context, node, /*index=*/7); + TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/8); + TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/9); + const int row_sums_size = row_sums->dims->data[0]; return lstm_eval::EvalHybrid( input, input_to_input_weights, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, @@ -1577,7 +1617,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { input_quantized, /*aux_input_quantized=*/nullptr, activation_state_quantized, cell_state_quantized, activation_state, cell_state, - output_scratch_buffer, output, + output_scratch_buffer, output, zero_points, row_sums, row_sums_size, + &op_data->compute_row_sums, CpuBackendContext::GetFromContext(context)); } else { const int num_intermediate_tensors = node->intermediates->size; diff --git a/tensorflow/lite/kernels/lstm_eval.cc b/tensorflow/lite/kernels/lstm_eval.cc index 9cc146ae8bd..9895c9183ec 100644 --- a/tensorflow/lite/kernels/lstm_eval.cc +++ b/tensorflow/lite/kernels/lstm_eval.cc @@ -17,9 +17,9 @@ limitations under the License. #include #include +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/kernels/internal/kernel_utils.h" @@ -33,24 +33,93 @@ namespace builtin { namespace lstm_eval { namespace { -inline float GetTensorScale(const TfLiteTensor* tensor) { - return tensor == nullptr ? 1.0f : tensor->params.scale; +void ComputeRowSums( + int32_t* input_to_input_row_sums, int32_t* input_to_forget_row_sums, + int32_t* input_to_cell_row_sums, int32_t* input_to_output_row_sums, + int32_t* aux_input_to_input_row_sums, int32_t* aux_input_to_forget_row_sums, + int32_t* aux_input_to_cell_row_sums, int32_t* aux_input_to_output_row_sums, + int32_t* recurrent_to_input_row_sums, int32_t* recurrent_to_forget_row_sums, + int32_t* recurrent_to_cell_row_sums, int32_t* recurrent_to_output_row_sums, + int32_t* projection_weights_row_sums, int32_t* row_sums, int n_cell, + int n_input, int n_aux_input, int n_output, + const int8_t* input_to_input_weights_ptr, + const int8_t* input_to_forget_weights_ptr, + const int8_t* input_to_cell_weights_ptr, + const int8_t* input_to_output_weights_ptr, + const int8_t* aux_input_to_input_weights_ptr, + const int8_t* aux_input_to_forget_weights_ptr, + const int8_t* aux_input_to_cell_weights_ptr, + const int8_t* aux_input_to_output_weights_ptr, + const int8_t* recurrent_to_input_weights_ptr, + const int8_t* recurrent_to_forget_weights_ptr, + const int8_t* recurrent_to_cell_weights_ptr, + const int8_t* recurrent_to_output_weights_ptr, + const int8_t* projection_weights_ptr, bool use_cifg, + const float* aux_input_ptr) { + // Compute the row sums for dequantization + if (!use_cifg) { + memset(input_to_input_row_sums, 0, sizeof(int32_t) * n_cell); + tensor_utils::ReductionSumVector(input_to_input_weights_ptr, + input_to_input_row_sums, n_cell, n_input); + } + memset(input_to_forget_row_sums, 0, sizeof(int32_t) * n_cell); + tensor_utils::ReductionSumVector(input_to_forget_weights_ptr, + input_to_forget_row_sums, n_cell, n_input); + memset(input_to_cell_row_sums, 0, sizeof(int32_t) * n_cell); + tensor_utils::ReductionSumVector(input_to_cell_weights_ptr, + input_to_cell_row_sums, n_cell, n_input); + memset(input_to_output_row_sums, 0, sizeof(int32_t) * n_cell); + tensor_utils::ReductionSumVector(input_to_output_weights_ptr, + input_to_output_row_sums, n_cell, n_input); + + if (aux_input_ptr) { + if (!use_cifg) { + memset(aux_input_to_input_row_sums, 0, sizeof(int32_t) * n_cell); + tensor_utils::ReductionSumVector(aux_input_to_input_weights_ptr, + aux_input_to_input_row_sums, n_cell, + n_aux_input); + } + memset(aux_input_to_forget_row_sums, 0, sizeof(int32_t) * n_cell); + tensor_utils::ReductionSumVector(aux_input_to_forget_weights_ptr, + aux_input_to_forget_row_sums, n_cell, + n_aux_input); + memset(aux_input_to_cell_row_sums, 0, sizeof(int32_t) * n_cell); + tensor_utils::ReductionSumVector(aux_input_to_cell_weights_ptr, + aux_input_to_cell_row_sums, n_cell, + n_aux_input); + memset(aux_input_to_output_row_sums, 0, sizeof(int32_t) * n_cell); + tensor_utils::ReductionSumVector(aux_input_to_output_weights_ptr, + aux_input_to_output_row_sums, n_cell, + n_aux_input); + } + if (!use_cifg) { + memset(recurrent_to_input_row_sums, 0, sizeof(int32_t) * n_cell); + tensor_utils::ReductionSumVector(recurrent_to_input_weights_ptr, + recurrent_to_input_row_sums, n_cell, + n_output); + } + memset(recurrent_to_forget_row_sums, 0, sizeof(int32_t) * n_cell); + tensor_utils::ReductionSumVector(recurrent_to_forget_weights_ptr, + recurrent_to_forget_row_sums, n_cell, + n_output); + memset(recurrent_to_cell_row_sums, 0, sizeof(int32_t) * n_cell); + tensor_utils::ReductionSumVector(recurrent_to_cell_weights_ptr, + recurrent_to_cell_row_sums, n_cell, + n_output); + memset(recurrent_to_output_row_sums, 0, sizeof(int32_t) * n_cell); + tensor_utils::ReductionSumVector(recurrent_to_output_weights_ptr, + recurrent_to_output_row_sums, n_cell, + n_output); + + if (projection_weights_ptr != nullptr) { + memset(projection_weights_row_sums, 0, sizeof(int32_t) * n_output); + tensor_utils::ReductionSumVector( + projection_weights_ptr, projection_weights_row_sums, n_output, n_cell); + } } -inline void MatrixBatchVectorMultiplyAccumulate( - const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, - const int8_t* __restrict__ vectors, const float* scaling_factors, - int n_batch, int32_t* scratch, float* __restrict__ result, - CpuBackendContext* context) { -// TODO(b/148289189) Remove when Ruy GEMV is the default. -#ifdef TFLITE_WITH_RUY_GEMV - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, scratch, - result, context); -#else - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result); -#endif +inline float GetTensorScale(const TfLiteTensor* tensor) { + return tensor == nullptr ? 1.0f : tensor->params.scale; } // Performs an LSTM batch inference step for input specified by input_ptr. @@ -473,6 +542,8 @@ inline void LstmStepHybrid( int8_t* quantized_aux_input_ptr, int8_t* quantized_output_state_ptr, int8_t* quantized_cell_state_ptr, float* output_state_ptr, float* cell_state_ptr, int32_t* accum_scratch_ptr, float* output_ptr, + int32_t* zero_points, int32_t* row_sums, int row_sums_size, + bool* compute_row_sums, bool asymmetric_quantize_inputs, CpuBackendContext* context) { ruy::profiler::ScopeLabel label("LstmStepHybrid"); // Since we have already checked that weights are all there or none, we @@ -503,53 +574,131 @@ inline void LstmStepHybrid( output_gate_scratch); } - // For each batch and cell: compute input_weight * input. - // Skip if input is all zeros. + int32_t* input_to_input_row_sums = nullptr; + int32_t* input_to_forget_row_sums = nullptr; + int32_t* input_to_cell_row_sums = nullptr; + int32_t* input_to_output_row_sums = nullptr; + int32_t* aux_input_to_input_row_sums = nullptr; + int32_t* aux_input_to_forget_row_sums = nullptr; + int32_t* aux_input_to_cell_row_sums = nullptr; + int32_t* aux_input_to_output_row_sums = nullptr; + int32_t* recurrent_to_input_row_sums = nullptr; + int32_t* recurrent_to_forget_row_sums = nullptr; + int32_t* recurrent_to_cell_row_sums = nullptr; + int32_t* recurrent_to_output_row_sums = nullptr; + int32_t* projection_weights_row_sums = nullptr; + + if (asymmetric_quantize_inputs) { + int num_row_sums = use_cifg ? 6 : 8; + if (aux_input_ptr != nullptr) { + num_row_sums += use_cifg ? 3 : 4; + } + if (projection_weights_ptr != nullptr) { + num_row_sums += ceil(n_output / n_cell); + } + TF_LITE_ASSERT(row_sums_size == num_row_sums); + input_to_input_row_sums = row_sums; + input_to_forget_row_sums = + use_cifg ? input_to_input_row_sums : input_to_input_row_sums + n_cell; + input_to_cell_row_sums = input_to_forget_row_sums + n_cell; + input_to_output_row_sums = input_to_cell_row_sums + n_cell; + if (aux_input_ptr != nullptr) { + aux_input_to_input_row_sums = input_to_output_row_sums + n_cell; + aux_input_to_forget_row_sums = use_cifg + ? aux_input_to_input_row_sums + : aux_input_to_input_row_sums + n_cell; + aux_input_to_cell_row_sums = aux_input_to_forget_row_sums + n_cell; + aux_input_to_output_row_sums = aux_input_to_cell_row_sums + n_cell; + } + recurrent_to_input_row_sums = aux_input_ptr + ? aux_input_to_output_row_sums + n_cell + : input_to_output_row_sums + n_cell; + recurrent_to_forget_row_sums = use_cifg + ? recurrent_to_input_row_sums + : recurrent_to_input_row_sums + n_cell; + recurrent_to_cell_row_sums = recurrent_to_forget_row_sums + n_cell; + recurrent_to_output_row_sums = recurrent_to_cell_row_sums + n_cell; + if (projection_weights_ptr != nullptr) { + projection_weights_row_sums = recurrent_to_output_row_sums + n_cell; + } + if (*compute_row_sums) { + ComputeRowSums( + input_to_input_row_sums, input_to_forget_row_sums, + input_to_cell_row_sums, input_to_output_row_sums, + aux_input_to_input_row_sums, aux_input_to_forget_row_sums, + aux_input_to_cell_row_sums, aux_input_to_output_row_sums, + recurrent_to_input_row_sums, recurrent_to_forget_row_sums, + recurrent_to_cell_row_sums, recurrent_to_output_row_sums, + projection_weights_row_sums, row_sums, n_cell, n_input, n_aux_input, + n_output, input_to_input_weights_ptr, input_to_forget_weights_ptr, + input_to_cell_weights_ptr, input_to_output_weights_ptr, + aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr, + aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr, + recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr, + recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr, + projection_weights_ptr, use_cifg, aux_input_ptr); + *compute_row_sums = false; + } + } + if (!tensor_utils::IsZeroVector(input_ptr, n_batch * n_input)) { for (int b = 0; b < n_batch; ++b) { const int offset = b * n_input; - float unused_min, unused_max; - tensor_utils::SymmetricQuantizeFloats( - input_ptr + offset, n_input, quantized_input_ptr + offset, - &unused_min, &unused_max, &scaling_factors[b]); + if (asymmetric_quantize_inputs) { + tensor_utils::AsymmetricQuantizeFloats( + input_ptr + offset, n_input, quantized_input_ptr + offset, + &scaling_factors[b], &zero_points[b]); + } else { + float unused_min, unused_max; + tensor_utils::SymmetricQuantizeFloats( + input_ptr + offset, n_input, quantized_input_ptr + offset, + &unused_min, &unused_max, &scaling_factors[b]); + } } if (!use_cifg) { for (int b = 0; b < n_batch; ++b) { product_scaling_factors[b] = scaling_factors[b] * input_to_input_weights_scale; } - MatrixBatchVectorMultiplyAccumulate( + tensor_utils::MatrixBatchVectorMultiplyAccumulate( input_to_input_weights_ptr, n_cell, n_input, quantized_input_ptr, - product_scaling_factors, n_batch, accum_scratch_ptr, - input_gate_scratch, context); + product_scaling_factors, n_batch, input_gate_scratch, + /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, + input_to_input_row_sums, compute_row_sums, context); } for (int b = 0; b < n_batch; ++b) { product_scaling_factors[b] = scaling_factors[b] * input_to_forget_weights_scale; } - MatrixBatchVectorMultiplyAccumulate( + + tensor_utils::MatrixBatchVectorMultiplyAccumulate( input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr, - product_scaling_factors, n_batch, accum_scratch_ptr, - forget_gate_scratch, context); + product_scaling_factors, n_batch, forget_gate_scratch, + /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, + input_to_forget_row_sums, compute_row_sums, context); for (int b = 0; b < n_batch; ++b) { product_scaling_factors[b] = scaling_factors[b] * input_to_cell_weights_scale; } - MatrixBatchVectorMultiplyAccumulate( + + tensor_utils::MatrixBatchVectorMultiplyAccumulate( input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr, - product_scaling_factors, n_batch, accum_scratch_ptr, cell_scratch, - context); + product_scaling_factors, n_batch, cell_scratch, + /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, + input_to_cell_row_sums, compute_row_sums, context); for (int b = 0; b < n_batch; ++b) { product_scaling_factors[b] = scaling_factors[b] * input_to_output_weights_scale; } - MatrixBatchVectorMultiplyAccumulate( + + tensor_utils::MatrixBatchVectorMultiplyAccumulate( input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr, - product_scaling_factors, n_batch, accum_scratch_ptr, - output_gate_scratch, context); + product_scaling_factors, n_batch, output_gate_scratch, + /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, + input_to_output_row_sums, compute_row_sums, context); } // For each batch and cell: compute aux_input_weight * aux_input. @@ -558,59 +707,84 @@ inline void LstmStepHybrid( !tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input)) { for (int b = 0; b < n_batch; ++b) { const int offset = b * n_aux_input; - float unused_min, unused_max; - tensor_utils::SymmetricQuantizeFloats( - aux_input_ptr + offset, n_aux_input, quantized_aux_input_ptr + offset, - &unused_min, &unused_max, &scaling_factors[b]); + if (asymmetric_quantize_inputs) { + tensor_utils::AsymmetricQuantizeFloats( + aux_input_ptr + offset, n_aux_input, + quantized_aux_input_ptr + offset, &scaling_factors[b], + &zero_points[b]); + } else { + float unused_min, unused_max; + tensor_utils::SymmetricQuantizeFloats( + aux_input_ptr + offset, n_aux_input, + quantized_aux_input_ptr + offset, &unused_min, &unused_max, + &scaling_factors[b]); + } } + if (!use_cifg) { for (int b = 0; b < n_batch; ++b) { product_scaling_factors[b] = scaling_factors[b] * aux_input_to_input_weights_scale; } - MatrixBatchVectorMultiplyAccumulate( + tensor_utils::MatrixBatchVectorMultiplyAccumulate( aux_input_to_input_weights_ptr, n_cell, n_aux_input, quantized_aux_input_ptr, product_scaling_factors, n_batch, - accum_scratch_ptr, input_gate_scratch, context); + input_gate_scratch, /*per_channel_scale=*/nullptr, zero_points, + accum_scratch_ptr, aux_input_to_input_row_sums, compute_row_sums, + context); } for (int b = 0; b < n_batch; ++b) { product_scaling_factors[b] = scaling_factors[b] * aux_input_to_forget_weights_scale; } - MatrixBatchVectorMultiplyAccumulate( + tensor_utils::MatrixBatchVectorMultiplyAccumulate( aux_input_to_forget_weights_ptr, n_cell, n_aux_input, quantized_aux_input_ptr, product_scaling_factors, n_batch, - accum_scratch_ptr, forget_gate_scratch, context); + forget_gate_scratch, /*per_channel_scale=*/nullptr, zero_points, + accum_scratch_ptr, aux_input_to_forget_row_sums, compute_row_sums, + context); + row_sums += n_cell; for (int b = 0; b < n_batch; ++b) { product_scaling_factors[b] = scaling_factors[b] * aux_input_to_cell_weights_scale; } - MatrixBatchVectorMultiplyAccumulate( + tensor_utils::MatrixBatchVectorMultiplyAccumulate( aux_input_to_cell_weights_ptr, n_cell, n_aux_input, - quantized_aux_input_ptr, product_scaling_factors, n_batch, - accum_scratch_ptr, cell_scratch, context); + quantized_aux_input_ptr, product_scaling_factors, n_batch, cell_scratch, + /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, + aux_input_to_cell_row_sums, compute_row_sums, context); for (int b = 0; b < n_batch; ++b) { product_scaling_factors[b] = scaling_factors[b] * aux_input_to_output_weights_scale; } - MatrixBatchVectorMultiplyAccumulate( + + tensor_utils::MatrixBatchVectorMultiplyAccumulate( aux_input_to_output_weights_ptr, n_cell, n_aux_input, quantized_aux_input_ptr, product_scaling_factors, n_batch, - accum_scratch_ptr, output_gate_scratch, context); + output_gate_scratch, /*per_channel_scale=*/nullptr, zero_points, + accum_scratch_ptr, aux_input_to_output_row_sums, compute_row_sums, + context); } if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) { // Save quantization and matmul computation for all zero input. for (int b = 0; b < n_batch; ++b) { const int offset = b * n_output; - float unused_min, unused_max; - tensor_utils::SymmetricQuantizeFloats(output_state_ptr + offset, n_output, - quantized_output_state_ptr + offset, - &unused_min, &unused_max, - &scaling_factors[b]); + if (asymmetric_quantize_inputs) { + tensor_utils::AsymmetricQuantizeFloats( + output_state_ptr + offset, n_output, + quantized_output_state_ptr + offset, &scaling_factors[b], + &zero_points[b]); + } else { + float unused_min, unused_max; + tensor_utils::SymmetricQuantizeFloats( + output_state_ptr + offset, n_output, + quantized_output_state_ptr + offset, &unused_min, &unused_max, + &scaling_factors[b]); + } } // For each batch and cell: compute recurrent_weight * output_state. if (!use_cifg) { @@ -618,38 +792,46 @@ inline void LstmStepHybrid( product_scaling_factors[b] = scaling_factors[b] * recurrent_to_input_weights_scale; } - MatrixBatchVectorMultiplyAccumulate( + tensor_utils::MatrixBatchVectorMultiplyAccumulate( recurrent_to_input_weights_ptr, n_cell, n_output, quantized_output_state_ptr, product_scaling_factors, n_batch, - accum_scratch_ptr, input_gate_scratch, context); + input_gate_scratch, /*per_channel_scale=*/nullptr, zero_points, + accum_scratch_ptr, recurrent_to_input_row_sums, compute_row_sums, + context); } for (int b = 0; b < n_batch; ++b) { product_scaling_factors[b] = scaling_factors[b] * recurrent_to_forget_weights_scale; } - MatrixBatchVectorMultiplyAccumulate( + tensor_utils::MatrixBatchVectorMultiplyAccumulate( recurrent_to_forget_weights_ptr, n_cell, n_output, quantized_output_state_ptr, product_scaling_factors, n_batch, - accum_scratch_ptr, forget_gate_scratch, context); + forget_gate_scratch, /*per_channel_scale=*/nullptr, zero_points, + accum_scratch_ptr, recurrent_to_forget_row_sums, compute_row_sums, + context); for (int b = 0; b < n_batch; ++b) { product_scaling_factors[b] = scaling_factors[b] * recurrent_to_cell_weights_scale; } - MatrixBatchVectorMultiplyAccumulate( + tensor_utils::MatrixBatchVectorMultiplyAccumulate( recurrent_to_cell_weights_ptr, n_cell, n_output, quantized_output_state_ptr, product_scaling_factors, n_batch, - accum_scratch_ptr, cell_scratch, context); + cell_scratch, /*per_channel_scale=*/nullptr, zero_points, + accum_scratch_ptr, recurrent_to_cell_row_sums, compute_row_sums, + context); for (int b = 0; b < n_batch; ++b) { product_scaling_factors[b] = scaling_factors[b] * recurrent_to_output_weights_scale; } - MatrixBatchVectorMultiplyAccumulate( + tensor_utils::MatrixBatchVectorMultiplyAccumulate( recurrent_to_output_weights_ptr, n_cell, n_output, quantized_output_state_ptr, product_scaling_factors, n_batch, - accum_scratch_ptr, output_gate_scratch, context); + output_gate_scratch, /*per_channel_scale=*/nullptr, zero_points, + accum_scratch_ptr, recurrent_to_output_row_sums, compute_row_sums, + context); } // For each batch and cell: update input gate. @@ -770,22 +952,32 @@ inline void LstmStepHybrid( // Save quantization and matmul computation for all zero input. for (int b = 0; b < n_batch; ++b) { const int offset = b * n_cell; - float unused_min, unused_max; - tensor_utils::SymmetricQuantizeFloats( - output_gate_scratch + offset, n_cell, - quantized_cell_state_ptr + offset, &unused_min, &unused_max, - &scaling_factors[b]); + if (asymmetric_quantize_inputs) { + tensor_utils::AsymmetricQuantizeFloats( + output_gate_scratch + offset, n_cell, + quantized_cell_state_ptr + offset, &scaling_factors[b], + &zero_points[b]); + } else { + float unused_min, unused_max; + tensor_utils::SymmetricQuantizeFloats( + output_gate_scratch + offset, n_cell, + quantized_cell_state_ptr + offset, &unused_min, &unused_max, + &scaling_factors[b]); + } } for (int b = 0; b < n_batch; ++b) { product_scaling_factors[b] = scaling_factors[b] * projection_weights_scale; } for (int b = 0; b < n_batch; b++) { - MatrixBatchVectorMultiplyAccumulate( + tensor_utils::MatrixBatchVectorMultiplyAccumulate( projection_weights_ptr, n_output, n_cell, quantized_cell_state_ptr + b * n_cell, &product_scaling_factors[b], - /*n_batch=*/1, accum_scratch_ptr, - output_ptr + b * output_batch_leading_dim, context); + /*n_batch=*/1, output_ptr + b * output_batch_leading_dim, + /*per_channel_scale=*/nullptr, + asymmetric_quantize_inputs ? &zero_points[b] : nullptr, + accum_scratch_ptr, projection_weights_row_sums, compute_row_sums, + context); } } if (params->proj_clip > 0.0) { @@ -1615,7 +1807,8 @@ TfLiteStatus EvalHybrid( TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state, TfLiteTensor* cell_state, TfLiteTensor* output_scratch_buffer, - TfLiteTensor* output, CpuBackendContext* context) { + TfLiteTensor* output, TfLiteTensor* zero_points, TfLiteTensor* row_sums, + int row_sums_size, bool* compute_row_sums, CpuBackendContext* context) { TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3); const int n_input = input->dims->data[input->dims->size - 1]; int max_time, n_batch; @@ -1654,6 +1847,14 @@ TfLiteStatus EvalHybrid( const int output_batch_leading_dim = output->dims->data[output->dims->size - 1]; + + int32_t* zero_points_ptr = nullptr; + int32_t* row_sums_ptr = nullptr; + if (params->asymmetric_quantize_inputs) { + zero_points_ptr = GetTensorData(zero_points); + row_sums_ptr = GetTensorData(row_sums); + } + if (time_major) { // Feed the sequence into the LSTM step-by-step. const int input_step = n_batch * n_input; @@ -1721,7 +1922,9 @@ TfLiteStatus EvalHybrid( GetTensorData(output_state_quantized), GetTensorData(cell_state_quantized), GetTensorData(output_state), GetTensorData(cell_state), - GetTensorData(output_scratch_buffer), output_ptr, context); + GetTensorData(output_scratch_buffer), output_ptr, + zero_points_ptr, row_sums_ptr, row_sums_size, compute_row_sums, + params->asymmetric_quantize_inputs, context); } } else { for (int b = 0; b < n_batch; b++) { @@ -1806,7 +2009,8 @@ TfLiteStatus EvalHybrid( GetTensorData(output_state_quantized), GetTensorData(cell_state_quantized), output_state_ptr, cell_state_ptr, GetTensorData(output_scratch_buffer), - output_ptr, context); + output_ptr, zero_points_ptr, row_sums_ptr, row_sums_size, + compute_row_sums, params->asymmetric_quantize_inputs, context); } } } diff --git a/tensorflow/lite/kernels/lstm_eval.h b/tensorflow/lite/kernels/lstm_eval.h index ca3f96391aa..877cfd70a89 100644 --- a/tensorflow/lite/kernels/lstm_eval.h +++ b/tensorflow/lite/kernels/lstm_eval.h @@ -156,7 +156,8 @@ TfLiteStatus EvalHybrid( TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state, TfLiteTensor* cell_state, TfLiteTensor* output_scratch_buffer, - TfLiteTensor* output, CpuBackendContext* context); + TfLiteTensor* output, TfLiteTensor* zero_points, TfLiteTensor* row_sums, + int row_sums_size, bool* compute_row_sums, CpuBackendContext* context); TfLiteStatus EvalInteger8x8_16( const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, diff --git a/tensorflow/lite/kernels/lstm_test.cc b/tensorflow/lite/kernels/lstm_test.cc index f426ffae0e0..2bd31eae8db 100644 --- a/tensorflow/lite/kernels/lstm_test.cc +++ b/tensorflow/lite/kernels/lstm_test.cc @@ -38,7 +38,8 @@ class LSTMOpModel : public SingleOpModel { bool use_peephole, bool use_projection_weights, bool use_projection_bias, float cell_clip, float proj_clip, const std::vector>& input_shapes, - const TensorType weight_type, bool is_layer_norm) + const TensorType weight_type, bool is_layer_norm, + bool asymmetric_quantize_inputs = false) : n_batch_(n_batch), n_input_(n_input), n_cell_(n_cell), @@ -129,10 +130,12 @@ class LSTMOpModel : public SingleOpModel { output_ = AddOutput(TensorType_FLOAT32); - SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions, - CreateLSTMOptions(builder_, ActivationFunctionType_TANH, - cell_clip, proj_clip) - .Union()); + SetBuiltinOp( + BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions, + CreateLSTMOptions(builder_, ActivationFunctionType_TANH, cell_clip, + proj_clip, ::tflite::LSTMKernelType_FULL, + asymmetric_quantize_inputs) + .Union()); // Do not apply delegate yet since tensor values are not known (and more // specifically scales in quantized tensors are not known). @@ -315,7 +318,7 @@ class LSTMOpModel : public SingleOpModel { const TensorType weight_type_; }; -class BaseLstmTest : public ::testing::Test { +class BaseLstmTest : public ::testing::TestWithParam { protected: // Weights of the LSTM model. Some are optional. std::vector input_to_input_weights_; @@ -565,8 +568,11 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingOmittedLayerNormLstmTest, VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); } -TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, +TEST_P(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTestUint8) { + if (SingleOpModel::GetForceUseNnapi() && GetParam()) { + return; + } const int n_batch = 1; const int n_input = 2; // n_cell and n_output have the same size when there is no projection. @@ -604,14 +610,20 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, {0}, // projection_bias tensor }, /*weight_type=*/TensorType_UINT8, - /*is_layer_norm=*/false); + /*is_layer_norm=*/false, GetParam()); VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.0157651); } -TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, +class NoCifgNoPeepholeNoProjectionNoClippingLstmInt8Test + : public NoCifgNoPeepholeNoProjectionNoClippingLstmTest {}; + +TEST_P(NoCifgNoPeepholeNoProjectionNoClippingLstmInt8Test, HybridLstmBlackBoxTestInt8) { + if (SingleOpModel::GetForceUseNnapi() && GetParam()) { + return; + } const int n_batch = 1; const int n_input = 2; // n_cell and n_output have the same size when there is no projection. @@ -649,7 +661,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, {0}, // projection_bias tensor }, /*weight_type=*/TensorType_INT8, - /*is_layer_norm=*/false); + /*is_layer_norm=*/false, GetParam()); VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.0157651); @@ -745,8 +757,11 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); } -TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, +TEST_P(CifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTestUint8) { + if (SingleOpModel::GetForceUseNnapi() && GetParam()) { + return; + } const int n_batch = 1; const int n_input = 2; // n_cell and n_output have the same size when there is no projection. @@ -784,13 +799,18 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, {0}, // projection_bias tensor }, /*weight_type=*/TensorType_UINT8, - /*is_layer_norm=*/false); + /*is_layer_norm=*/false, GetParam()); VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573); } +class CifgNoPeepholeNoProjectionNoClippingLstmInt8Test + : public CifgNoPeepholeNoProjectionNoClippingLstmTest {}; -TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, +TEST_P(CifgNoPeepholeNoProjectionNoClippingLstmInt8Test, HybridLstmBlackBoxTestInt8) { + if (SingleOpModel::GetForceUseNnapi() && GetParam()) { + return; + } const int n_batch = 1; const int n_input = 2; // n_cell and n_output have the same size when there is no projection. @@ -828,7 +848,7 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, {0}, // projection_bias tensor }, /*weight_type=*/TensorType_INT8, - /*is_layer_norm=*/false); + /*is_layer_norm=*/false, GetParam()); VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573); } @@ -1474,50 +1494,11 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, LstmBlackBoxTest) { VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); } -TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, HybridLstmBlackBoxTestInt8) { - const int n_batch = 2; - const int n_input = 5; - const int n_cell = 20; - const int n_output = 16; - - LSTMOpModel lstm(n_batch, n_input, n_cell, n_output, - /*use_cifg=*/false, /*use_peephole=*/true, - /*use_projection_weights=*/true, - /*use_projection_bias=*/false, - /*cell_clip=*/0.0, /*proj_clip=*/0.0, - { - {n_batch, n_input}, // input tensor - - {n_cell, n_input}, // input_to_input_weight tensor - {n_cell, n_input}, // input_to_forget_weight tensor - {n_cell, n_input}, // input_to_cell_weight tensor - {n_cell, n_input}, // input_to_output_weight tensor - - {n_cell, n_output}, // recurrent_to_input_weight tensor - {n_cell, n_output}, // recurrent_to_forget_weight tensor - {n_cell, n_output}, // recurrent_to_cell_weight tensor - {n_cell, n_output}, // recurrent_to_output_weight tensor - - {n_cell}, // cell_to_input_weight tensor - {n_cell}, // cell_to_forget_weight tensor - {n_cell}, // cell_to_output_weight tensor - - {n_cell}, // input_gate_bias tensor - {n_cell}, // forget_gate_bias tensor - {n_cell}, // cell_bias tensor - {n_cell}, // output_gate_bias tensor - - {n_output, n_cell}, // projection_weight tensor - {0}, // projection_bias tensor - }, - /*weight_type=*/TensorType_INT8, - /*is_layer_norm=*/false); - - VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467); -} - -TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, +TEST_P(NoCifgPeepholeProjectionNoClippingLstmTest, HybridLstmBlackBoxTestUint8) { + if (SingleOpModel::GetForceUseNnapi() && GetParam()) { + return; + } const int n_batch = 2; const int n_input = 5; const int n_cell = 20; @@ -1554,11 +1535,60 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, {0}, // projection_bias tensor }, /*weight_type=*/TensorType_UINT8, - /*is_layer_norm=*/false); + /*is_layer_norm=*/false, GetParam()); VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467); } +class NoCifgPeepholeProjectionNoClippingLstmInt8Test + : public NoCifgPeepholeProjectionNoClippingLstmTest {}; + +TEST_P(NoCifgPeepholeProjectionNoClippingLstmInt8Test, + HybridLstmBlackBoxTestInt8) { + if (SingleOpModel::GetForceUseNnapi() && GetParam()) { + return; + } + const int n_batch = 2; + const int n_input = 5; + const int n_cell = 20; + const int n_output = 16; + + LSTMOpModel lstm(n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/true, + /*use_projection_weights=*/true, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {n_cell}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {n_output, n_cell}, // projection_weight tensor + {0}, // projection_bias tensor + }, + /*weight_type=*/TensorType_INT8, + /*is_layer_norm=*/false, GetParam()); + + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.0015); +} + class NoCifgPeepholeProjectionNoClippingLayerNormLstmTest : public BaseLstmTest { void SetUp() override { @@ -1693,8 +1723,11 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest, VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm); } -TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest, +TEST_P(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest, HybridLayerNormLstmBlackBoxTestUint8) { + if (SingleOpModel::GetForceUseNnapi() && GetParam()) { + return; + } const int n_batch = 2; const int n_input = 5; const int n_cell = 4; @@ -1741,7 +1774,7 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest, {n_cell}, // output_layer_norm_coefficient tensor }, /*weight_type=*/TensorType_UINT8, - /*is_layer_norm=*/true); + /*is_layer_norm=*/true, GetParam()); lstm_golden_output_ = {{ // Batch0: 3 (input_sequence_size) * 3 (n_output) @@ -1760,8 +1793,14 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest, /*tolerance=*/0.0010907); } -TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest, +class NoCifgPeepholeProjectionNoClippingLayerNormLstmInt8Test + : public NoCifgPeepholeProjectionNoClippingLayerNormLstmTest {}; + +TEST_P(NoCifgPeepholeProjectionNoClippingLayerNormLstmInt8Test, HybridLayerNormLstmBlackBoxTestInt8) { + if (SingleOpModel::GetForceUseNnapi() && GetParam()) { + return; + } const int n_batch = 2; const int n_input = 5; const int n_cell = 4; @@ -1808,22 +1847,24 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest, {n_cell}, // output_layer_norm_coefficient tensor }, /*weight_type=*/TensorType_INT8, - /*is_layer_norm=*/true); + /*is_layer_norm=*/true, GetParam()); + // Goldens are calculated from weight_type=TensorType_FLOAT32. lstm_golden_output_ = {{ // Batch0: 3 (input_sequence_size) * 3 (n_output) - 0.0244576, 0.127847, -0.00181765, // seq 0 - 0.0137518, 0.140892, 0.0402234, // seq 1 - -0.0048839, 0.155096, 0.0840309, // seq 2 + 0.0244077, 0.128027, -0.00170918, // seq 0 + 0.0137642, 0.140751, 0.0395835, // seq 1 + -0.00459233, 0.155278, 0.0837378, // seq 2 }, { // Batch1: 3 (input_sequence_size) * 3 (n_output) - -0.00728636, 0.0843957, 0.0634786, // seq 0 - -0.00448382, 0.139278, 0.0737372, // seq 1 - 0.00734616, 0.161793, 0.0560238, // seq 2 + -0.00692428, 0.0848741, 0.063445, // seq 0 + -0.00403911, 0.139963, 0.072681, // seq 1 + 0.00752708, 0.161903, 0.0561371, // seq 2 }}; - VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm); + VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm, + /*tolerance=*/1.06e-3); } class CifgPeepholeProjectionNoClippingLayerNormLstmTest : public BaseLstmTest { @@ -1940,8 +1981,11 @@ TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest, VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm); } -TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest, +TEST_P(CifgPeepholeProjectionNoClippingLayerNormLstmTest, HybridLayerNormLstmBlackBoxTestUint8) { + if (SingleOpModel::GetForceUseNnapi() && GetParam()) { + return; + } const int n_batch = 2; const int n_input = 5; const int n_cell = 4; @@ -1988,7 +2032,7 @@ TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest, {n_cell}, // output_layer_norm_coefficient tensor }, /*weight_type=*/TensorType_UINT8, - /*is_layer_norm=*/true); + /*is_layer_norm=*/true, GetParam()); // Verify the final output. lstm_golden_output_ = { @@ -2009,7 +2053,10 @@ TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest, /*tolerance=*/0.000902065); } -TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest, +class CifgPeepholeProjectionNoClippingLayerNormLstmInt8Test + : public CifgPeepholeProjectionNoClippingLayerNormLstmTest {}; + +TEST_P(CifgPeepholeProjectionNoClippingLayerNormLstmInt8Test, HybridLayerNormLstmBlackBoxTestInt8) { const int n_batch = 2; const int n_input = 5; @@ -2057,24 +2104,24 @@ TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest, {n_cell}, // output_layer_norm_coefficient tensor }, /*weight_type=*/TensorType_INT8, - /*is_layer_norm=*/true); + /*is_layer_norm=*/true, GetParam()); - // Verify the final output. - lstm_golden_output_ = { - { - // Batch0: 3 (input_sequence_size) * 3 (n_output) - 0.0212250091, 0.140474007, 0.0115012666, // seq 0 - 0.0130806509, 0.152660668, 0.0347516984, // seq 1 - -0.0124010444, 0.166042402, 0.0898982584, // seq 2 - }, - { - // Batch1: 3 (input_sequence_size) * 3 (n_output) - -0.0228835996, 0.0917588323, 0.0778886303, // seq 0 - -0.0275101066, 0.148769245, 0.0938384682, // seq 1 - -0.0103605557, 0.172605693, 0.0728750974, // seq 2 - }}; + // Goldens are results using FLOAT32 inference. + lstm_golden_output_ = {{ + // Batch0: 3 (input_sequence_size) * 3 (n_output) + 0.0212971, 0.140816, 0.0112733, // seq 0 + 0.0132302, 0.152308, 0.0346313, // seq 1 + -0.0123688, 0.16579, 0.0893078, // seq 2 + }, + { + // Batch1: 3 (input_sequence_size) * 3 (n_output) + -0.0226351, 0.0916948, 0.0769176, // seq 0 + -0.0269967, 0.149708, 0.0941492, // seq 1 + -0.0103429, 0.173016, 0.0720509, // seq 2 + }}; - VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm); + VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm, + /*tolerance=*/1e-3); } class LSTMIntegerOpModel : public SingleOpModel { @@ -3311,5 +3358,22 @@ TEST(LSTMOpModel, InvalidTypeTest) { ""); } #endif + +#define QUANTIZE_PARAMETER_TEST(test) \ + INSTANTIATE_TEST_SUITE_P(test, test, ::testing::Bool()) + +QUANTIZE_PARAMETER_TEST(NoCifgNoPeepholeNoProjectionNoClippingLstmTest); +QUANTIZE_PARAMETER_TEST(NoCifgNoPeepholeNoProjectionNoClippingLstmInt8Test); +QUANTIZE_PARAMETER_TEST(CifgNoPeepholeNoProjectionNoClippingLstmTest); +QUANTIZE_PARAMETER_TEST(CifgNoPeepholeNoProjectionNoClippingLstmInt8Test); +QUANTIZE_PARAMETER_TEST(NoCifgPeepholeProjectionNoClippingLstmTest); +QUANTIZE_PARAMETER_TEST(NoCifgPeepholeProjectionNoClippingLstmInt8Test); +QUANTIZE_PARAMETER_TEST(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest); +QUANTIZE_PARAMETER_TEST( + NoCifgPeepholeProjectionNoClippingLayerNormLstmInt8Test); +QUANTIZE_PARAMETER_TEST(CifgPeepholeProjectionNoClippingLayerNormLstmTest); +QUANTIZE_PARAMETER_TEST(CifgPeepholeProjectionNoClippingLayerNormLstmInt8Test); +#undef QUANTIZE_PARAMETER_TEST + } // namespace } // namespace tflite diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc index 41cc3aa4675..8c1f6b4a9e7 100644 --- a/tensorflow/lite/kernels/register.cc +++ b/tensorflow/lite/kernels/register.cc @@ -33,199 +33,201 @@ namespace builtin { BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_ABS, Register_ABS()); AddBuiltin(BuiltinOperator_HARD_SWISH, Register_HARD_SWISH()); - AddBuiltin(BuiltinOperator_RELU, Register_RELU(), /* min_version */ 1, - /* max_version */ 2); + AddBuiltin(BuiltinOperator_RELU, Register_RELU(), /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_RELU_N1_TO_1, Register_RELU_N1_TO_1()); - AddBuiltin(BuiltinOperator_RELU6, Register_RELU6(), /* min_version */ 1, - /* max_version */ 2); - AddBuiltin(BuiltinOperator_TANH, Register_TANH(), /* min_version */ 1, - /* max_version */ 2); + AddBuiltin(BuiltinOperator_RELU6, Register_RELU6(), /* min_version = */ 1, + /* max_version = */ 2); + AddBuiltin(BuiltinOperator_TANH, Register_TANH(), /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_LOGISTIC, Register_LOGISTIC(), - /* min_version */ 1, - /* max_version */ 2); + /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D, Register_AVERAGE_POOL_2D(), - /* min_version */ 1, - /* max_version */ 2); + /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_MAX_POOL_2D, Register_MAX_POOL_2D(), - /* min_version */ 1, - /* max_version */ 2); + /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_L2_POOL_2D, Register_L2_POOL_2D()); AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D(), - /* min_version */ 1, - /* max_version */ 3); + /* min_version = */ 1, + /* max_version = */ 3); AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, Register_DEPTHWISE_CONV_2D(), - /* min_version */ 1, - /* max_version */ 3); + /* min_version = */ 1, + /* max_version = */ 3); AddBuiltin(BuiltinOperator_SVDF, Register_SVDF(), - /* min_version */ 1, - /* max_version */ 3); + /* min_version = */ 1, + /* max_version = */ 3); AddBuiltin(BuiltinOperator_RNN, Register_RNN(), - /* min_version */ 1, - /* max_version */ 2); + /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN, Register_BIDIRECTIONAL_SEQUENCE_RNN(), - /* min_version */ 1, - /* max_version */ 2); + /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN, Register_UNIDIRECTIONAL_SEQUENCE_RNN(), - /* min_version */ 1, - /* max_version */ 2); + /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP, Register_EMBEDDING_LOOKUP(), - /* min_version */ 1, - /* max_version */ 3); + /* min_version = */ 1, + /* max_version = */ 3); AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP_SPARSE, Register_EMBEDDING_LOOKUP_SPARSE()); AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Register_FULLY_CONNECTED(), - /* min_version */ 1, - /* max_version */ 6); + /* min_version = */ 1, + /* max_version = */ 6); AddBuiltin(BuiltinOperator_LSH_PROJECTION, Register_LSH_PROJECTION()); AddBuiltin(BuiltinOperator_HASHTABLE_LOOKUP, Register_HASHTABLE_LOOKUP()); AddBuiltin(BuiltinOperator_SOFTMAX, Register_SOFTMAX(), - /* min_version */ 1, - /* max_version */ 2); + /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_CONCATENATION, Register_CONCATENATION(), - /* min_version */ 1, - /* max_version */ 2); + /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_ADD, Register_ADD(), - /* min_version */ 1, - /* max_version */ 2); + /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_SPACE_TO_BATCH_ND, Register_SPACE_TO_BATCH_ND(), - /* min_version */ 1, - /* max_version */ 3); + /* min_version = */ 1, + /* max_version = */ 3); AddBuiltin(BuiltinOperator_BATCH_TO_SPACE_ND, Register_BATCH_TO_SPACE_ND(), - /* min_version */ 1, - /* max_version */ 3); - AddBuiltin(BuiltinOperator_MUL, Register_MUL(), /* min_version */ 1, - /* max_version */ 3); + /* min_version = */ 1, + /* max_version = */ 3); + AddBuiltin(BuiltinOperator_MUL, Register_MUL(), /* min_version = */ 1, + /* max_version = */ 3); AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2_NORMALIZATION(), - /* min_version */ 1, - /* max_version */ 2); + /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, Register_LOCAL_RESPONSE_NORMALIZATION()); - AddBuiltin(BuiltinOperator_LSTM, Register_LSTM(), /* min_version */ 1, - /* max_version */ 3); + AddBuiltin(BuiltinOperator_LSTM, Register_LSTM(), /* min_version = */ 1, + /* max_version = */ 3); AddBuiltin(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM, - Register_BIDIRECTIONAL_SEQUENCE_LSTM(), /* min_version */ 1, - /* max_version */ 3); + Register_BIDIRECTIONAL_SEQUENCE_LSTM(), /* min_version = */ 1, + /* max_version = */ 3); AddBuiltin(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, - Register_UNIDIRECTIONAL_SEQUENCE_LSTM(), /* min_version */ 1, - /* max_version */ 2); - AddBuiltin(BuiltinOperator_PAD, Register_PAD(), /* min_version */ 1, - /* max_version */ 2); - AddBuiltin(BuiltinOperator_PADV2, Register_PADV2(), /* min_version */ 1, - /* max_version */ 2); + Register_UNIDIRECTIONAL_SEQUENCE_LSTM(), /* min_version = */ 1, + /* max_version = */ 2); + AddBuiltin(BuiltinOperator_PAD, Register_PAD(), /* min_version = */ 1, + /* max_version = */ 2); + AddBuiltin(BuiltinOperator_PADV2, Register_PADV2(), /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_RESHAPE, Register_RESHAPE()); AddBuiltin(BuiltinOperator_RESIZE_BILINEAR, Register_RESIZE_BILINEAR(), - /* min_version */ 1, - /* max_version */ 3); + /* min_version = */ 1, + /* max_version = */ 3); AddBuiltin(BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, Register_RESIZE_NEAREST_NEIGHBOR(), - /* min_version */ 1, - /* max_version */ 2); + /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_SKIP_GRAM, Register_SKIP_GRAM()); AddBuiltin(BuiltinOperator_SPACE_TO_DEPTH, Register_SPACE_TO_DEPTH(), - /* min_version */ 1, - /* max_version */ 2); + /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_DEPTH_TO_SPACE, Register_DEPTH_TO_SPACE()); AddBuiltin(BuiltinOperator_GATHER, Register_GATHER(), - /* min_version */ 1, - /* max_version */ 3); + /* min_version = */ 1, + /* max_version = */ 3); AddBuiltin(BuiltinOperator_TRANSPOSE, Register_TRANSPOSE(), - /* min_version */ 1, - /* max_version */ 4); + /* min_version = */ 1, + /* max_version = */ 4); AddBuiltin(BuiltinOperator_MEAN, Register_MEAN(), + /* min_version = */ 1, + /* max_version = */ 2); + AddBuiltin(BuiltinOperator_DIV, Register_DIV(), /* min_version */ 1, /* max_version */ 2); - AddBuiltin(BuiltinOperator_DIV, Register_DIV()); AddBuiltin(BuiltinOperator_SUB, Register_SUB(), - /* min_version */ 1, - /* max_version */ 3); - AddBuiltin(BuiltinOperator_SPLIT, Register_SPLIT(), /* min_version */ 1, - /* max_version */ 3); + /* min_version = */ 1, + /* max_version = */ 3); + AddBuiltin(BuiltinOperator_SPLIT, Register_SPLIT(), /* min_version = */ 1, + /* max_version = */ 3); AddBuiltin(BuiltinOperator_SPLIT_V, Register_SPLIT_V(), - /* min_version */ 1, - /* max_version */ 2); + /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE()); AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE(), - /* min_version */ 1, - /* max_version */ 4); + /* min_version = */ 1, + /* max_version = */ 4); AddBuiltin(BuiltinOperator_EXP, Register_EXP()); AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2(), - /* min_version */ 1, - /* max_version */ 2); + /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_LOG, Register_LOG()); AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX(), - /* min_version */ 1, - /* max_version */ 2); + /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_CAST, Register_CAST()); AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE(), - /* min_version */ 1, - /* max_version */ 4); + /* min_version = */ 1, + /* max_version = */ 4); AddBuiltin(BuiltinOperator_PRELU, Register_PRELU()); AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM(), - /* min_version */ 1, - /* max_version */ 4); + /* min_version = */ 1, + /* max_version = */ 4); AddBuiltin(BuiltinOperator_MINIMUM, Register_MINIMUM(), - /* min_version */ 1, - /* max_version */ 4); + /* min_version = */ 1, + /* max_version = */ 4); AddBuiltin(BuiltinOperator_ARG_MAX, Register_ARG_MAX(), - /* min_version */ 1, - /* max_version */ 2); + /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_ARG_MIN, Register_ARG_MIN(), - /* min_version */ 1, - /* max_version */ 2); + /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_GREATER, Register_GREATER(), - /* min_version */ 1, - /* max_version */ 2); + /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_GREATER_EQUAL, Register_GREATER_EQUAL(), - /* min_version */ 1, - /* max_version */ 2); + /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_LESS, Register_LESS(), - /* min_version */ 1, - /* max_version */ 2); + /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_LESS_EQUAL, Register_LESS_EQUAL(), - /* min_version */ 1, - /* max_version */ 2); + /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_FLOOR, Register_FLOOR()); AddBuiltin(BuiltinOperator_CEIL, Register_CEIL()); AddBuiltin(BuiltinOperator_ROUND, Register_ROUND()); AddBuiltin(BuiltinOperator_NEG, Register_NEG()); AddBuiltin(BuiltinOperator_SELECT, Register_SELECT(), - /* min_version */ 1, - /* max_version */ 2); + /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_SELECT_V2, Register_SELECT_V2()); AddBuiltin(BuiltinOperator_SLICE, Register_SLICE(), - /* min_version */ 1, - /* max_version */ 3); + /* min_version = */ 1, + /* max_version = */ 3); AddBuiltin(BuiltinOperator_SIN, Register_SIN()); AddBuiltin(BuiltinOperator_COS, Register_COS()); AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV(), - /* min_version */ 1, - /* max_version */ 2); + /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_TILE, Register_TILE(), - /* min_version */ 1, - /* max_version */ 2); + /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_SUM, Register_SUM(), - /* min_version */ 1, - /* max_version */ 2); + /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_REDUCE_PROD, Register_REDUCE_PROD()); AddBuiltin(BuiltinOperator_REDUCE_MAX, Register_REDUCE_MAX(), - /* min_version */ 1, - /* max_version */ 2); + /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_REDUCE_MIN, Register_REDUCE_MIN(), - /* min_version */ 1, - /* max_version */ 2); + /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_REDUCE_ANY, Register_REDUCE_ANY()); AddBuiltin(BuiltinOperator_EXPAND_DIMS, Register_EXPAND_DIMS()); AddBuiltin(BuiltinOperator_SPARSE_TO_DENSE, Register_SPARSE_TO_DENSE(), - /* min_version */ 1, - /* max_version */ 3); + /* min_version = */ 1, + /* max_version = */ 3); AddBuiltin(BuiltinOperator_EQUAL, Register_EQUAL(), - /* min_version */ 1, - /* max_version */ 2); + /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_NOT_EQUAL, Register_NOT_EQUAL(), - /* min_version */ 1, - /* max_version */ 2); + /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_SQRT, Register_SQRT()); AddBuiltin(BuiltinOperator_RSQRT, Register_RSQRT()); AddBuiltin(BuiltinOperator_SHAPE, Register_SHAPE()); @@ -233,41 +235,43 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_POW, Register_POW()); AddBuiltin(BuiltinOperator_FAKE_QUANT, Register_FAKE_QUANT(), 1, 2); AddBuiltin(BuiltinOperator_PACK, Register_PACK(), - /* min_version */ 1, - /* max_version */ 3); + /* min_version = */ 1, + /* max_version = */ 3); AddBuiltin(BuiltinOperator_ONE_HOT, Register_ONE_HOT()); AddBuiltin(BuiltinOperator_LOGICAL_OR, Register_LOGICAL_OR()); AddBuiltin(BuiltinOperator_LOGICAL_AND, Register_LOGICAL_AND()); AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT()); AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK(), - /* min_version */ 1, - /* max_version */ 4); + /* min_version = */ 1, + /* max_version = */ 4); AddBuiltin(BuiltinOperator_FLOOR_DIV, Register_FLOOR_DIV(), - /* min_version */ 1, - /* max_version */ 2); + /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_SQUARE, Register_SQUARE()); AddBuiltin(BuiltinOperator_ZEROS_LIKE, Register_ZEROS_LIKE()); AddBuiltin(BuiltinOperator_FLOOR_MOD, Register_FLOOR_MOD()); AddBuiltin(BuiltinOperator_RANGE, Register_RANGE()); AddBuiltin(BuiltinOperator_LEAKY_RELU, Register_LEAKY_RELU(), - /* min_version */ 1, - /* max_version */ 2); + /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_SQUARED_DIFFERENCE, Register_SQUARED_DIFFERENCE()); AddBuiltin(BuiltinOperator_FILL, Register_FILL()); AddBuiltin(BuiltinOperator_MIRROR_PAD, Register_MIRROR_PAD()); AddBuiltin(BuiltinOperator_UNIQUE, Register_UNIQUE()); AddBuiltin(BuiltinOperator_REVERSE_V2, Register_REVERSE_V2(), - /* min_version */ 1, - /* max_version */ 2); + /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_ADD_N, Register_ADD_N()); - AddBuiltin(BuiltinOperator_GATHER_ND, Register_GATHER_ND()); + AddBuiltin(BuiltinOperator_GATHER_ND, Register_GATHER_ND(), + /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_WHERE, Register_WHERE()); AddBuiltin(BuiltinOperator_ELU, Register_ELU()); AddBuiltin(BuiltinOperator_REVERSE_SEQUENCE, Register_REVERSE_SEQUENCE()); AddBuiltin(BuiltinOperator_MATRIX_DIAG, Register_MATRIX_DIAG()); AddBuiltin(BuiltinOperator_QUANTIZE, Register_QUANTIZE(), - /* min_version */ 1, - /* max_version */ 2); + /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_MATRIX_SET_DIAG, Register_MATRIX_SET_DIAG()); AddBuiltin(BuiltinOperator_IF, tflite::ops::builtin::Register_IF()); AddBuiltin(BuiltinOperator_WHILE, tflite::ops::builtin::Register_WHILE()); diff --git a/tensorflow/lite/kernels/rfft2d.cc b/tensorflow/lite/kernels/rfft2d.cc index c0554c5e39b..fa201153daf 100644 --- a/tensorflow/lite/kernels/rfft2d.cc +++ b/tensorflow/lite/kernels/rfft2d.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include "third_party/fft2d/fft2d.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" #include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/op_macros.h" diff --git a/tensorflow/lite/kernels/strided_slice_test.cc b/tensorflow/lite/kernels/strided_slice_test.cc index c687d0761fc..e97eab5b7c4 100644 --- a/tensorflow/lite/kernels/strided_slice_test.cc +++ b/tensorflow/lite/kernels/strided_slice_test.cc @@ -97,6 +97,15 @@ TYPED_TEST(StridedSliceOpTest, UnsupportedArgs) { } #endif +TYPED_TEST(StridedSliceOpTest, In1DEmpty) { + StridedSliceOpModel m({0}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetBegin({1}); + m.SetEnd({3}); + m.SetStrides({1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({0})); +} + TYPED_TEST(StridedSliceOpTest, In1D) { StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4}); diff --git a/tensorflow/lite/kernels/svdf.cc b/tensorflow/lite/kernels/svdf.cc index bcbd06e8a67..82b7b7e4ee5 100644 --- a/tensorflow/lite/kernels/svdf.cc +++ b/tensorflow/lite/kernels/svdf.cc @@ -43,6 +43,7 @@ struct OpData { int effective_scale_1_b; int32 effective_scale_2_a; int effective_scale_2_b; + bool compute_row_sums = false; }; } // namespace @@ -61,8 +62,8 @@ constexpr int kOutputTensor = 0; void* Init(TfLiteContext* context, const char* buffer, size_t length) { auto* op_data = new OpData(); op_data->float_weights_time_initialized = false; - // Note: only needs 4 scratch tensors when is_hybrid_op, only 1 otherwise. - context->AddTensors(context, /*tensors_to_add=*/4, + // Note: only needs 6 scratch tensors when is_hybrid_op, only 1 otherwise. + context->AddTensors(context, /*tensors_to_add=*/6, &op_data->scratch_tensor_index); return op_data; } @@ -130,7 +131,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Resize scratch. TfLiteIntArrayFree(node->temporaries); if (is_hybrid_op) { - node->temporaries = TfLiteIntArrayCreate(4); + node->temporaries = TfLiteIntArrayCreate(6); } else if (is_full_integer) { node->temporaries = TfLiteIntArrayCreate(2); } else { @@ -156,6 +157,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { scratch_size_array)); if (is_hybrid_op) { + op_data->compute_row_sums = true; // Tell interpreter to allocate temporary tensors to store quantized values // of input tensors. node->temporaries->data[1] = scratch_tensor_index + 1; @@ -195,6 +197,30 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { context->ResizeTensor(context, float_weights_time, float_weights_time_size)); } + + node->temporaries->data[4] = scratch_tensor_index + 4; + TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/4); + zero_points->type = kTfLiteFloat32; + zero_points->allocation_type = kTfLiteArenaRw; + int zero_points_dims[1] = {batch_size}; + if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) { + TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1); + zero_points_size->data[0] = zero_points_dims[0]; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points, + zero_points_size)); + } + + node->temporaries->data[5] = scratch_tensor_index + 5; + TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/5); + row_sums->type = kTfLiteFloat32; + row_sums->allocation_type = kTfLiteArenaRwPersistent; + int row_sums_dims[1] = {num_filters}; + if (!TfLiteIntArrayEqualsArray(row_sums->dims, 1, row_sums_dims)) { + TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(1); + row_sums_size->data[0] = row_sums_dims[0]; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, row_sums, row_sums_size)); + } } if (is_full_integer) { // Allocated one extra tensor. @@ -267,7 +293,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetTemporary(context, node, /*index=*/2); TfLiteTensor* float_weights_time = GetTemporary(context, node, /*index=*/3); - + TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/4); + TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/5); // Dequantize weights time. // TODO(alanchiao): this dequantization initialization only needs to // happen once per model and should theoretically be placed in either @@ -285,10 +312,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } op_data->float_weights_time_initialized = true; } - reference_ops::EvalHybridSVDF(context, node, input, weights_feature, - float_weights_time, bias, params, scratch, - scaling_factors, input_quantized, - activation_state, output); + + reference_ops::EvalHybridSVDF( + context, node, input, weights_feature, float_weights_time, bias, + params, scratch, scaling_factors, input_quantized, activation_state, + output, zero_points, row_sums, &op_data->compute_row_sums); return kTfLiteOk; } else { auto* input_params = reinterpret_cast( diff --git a/tensorflow/lite/kernels/svdf_test.cc b/tensorflow/lite/kernels/svdf_test.cc index 1f5cfb040e7..68963b784f4 100644 --- a/tensorflow/lite/kernels/svdf_test.cc +++ b/tensorflow/lite/kernels/svdf_test.cc @@ -131,7 +131,8 @@ class BaseSVDFOpModel : public SingleOpModel { BaseSVDFOpModel(int batches, int units, int input_size, int memory_size, int rank, TensorType weights_feature_type = TensorType_FLOAT32, - TensorType weights_time_type = TensorType_FLOAT32) + TensorType weights_time_type = TensorType_FLOAT32, + bool asymmetric_quantize_inputs = false) : batches_(batches), units_(units), input_size_(input_size), @@ -146,9 +147,10 @@ class BaseSVDFOpModel : public SingleOpModel { TensorData{TensorType_FLOAT32, {batches, memory_size * num_filters}}, /*is_variable=*/true); output_ = AddOutput(TensorType_FLOAT32); - SetBuiltinOp( - BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions, - CreateSVDFOptions(builder_, rank, ActivationFunctionType_NONE).Union()); + SetBuiltinOp(BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions, + CreateSVDFOptions(builder_, rank, ActivationFunctionType_NONE, + asymmetric_quantize_inputs) + .Union()); BuildInterpreter({ {batches_, input_size_}, // input tensor {units_ * rank, input_size_}, // weights_feature tensor @@ -203,9 +205,10 @@ class SVDFOpModel : public BaseSVDFOpModel { class HybridSVDFOpModel : public BaseSVDFOpModel { public: HybridSVDFOpModel(int batches, int units, int input_size, int memory_size, - int rank, TensorType tensor_type) + int rank, TensorType tensor_type, + bool asymmetric_quantize_inputs) : BaseSVDFOpModel(batches, units, input_size, memory_size, rank, - tensor_type, tensor_type) { + tensor_type, tensor_type, asymmetric_quantize_inputs) { tensor_type_ = tensor_type; } @@ -229,7 +232,7 @@ class HybridSVDFOpModel : public BaseSVDFOpModel { TensorType tensor_type_; }; -class SVDFOpTest : public ::testing::Test { +class SVDFOpTest : public ::testing::TestWithParam { protected: void VerifyGoldens(float golden_input[], float golden_output[], int golden_size, BaseSVDFOpModel* svdf, @@ -262,6 +265,9 @@ class SVDFOpTest : public ::testing::Test { } }; +INSTANTIATE_TEST_SUITE_P(SVDFOpTest, SVDFOpTest, + ::testing::ValuesIn({false, true})); + TEST_F(SVDFOpTest, BlackBoxTestRank1) { SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, /*memory_size=*/10, /*rank=*/1); @@ -325,9 +331,10 @@ TEST_F(SVDFOpTest, BlackBoxTestRank2) { &svdf); } -TEST_F(SVDFOpTest, BlackBoxTestHybridRank1Uint8) { +TEST_P(SVDFOpTest, BlackBoxTestHybridRank1Uint8) { HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, - /*memory_size=*/10, /*rank=*/1, TensorType_UINT8); + /*memory_size=*/10, /*rank=*/1, TensorType_UINT8, + GetParam()); svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347, 0.22197971, 0.12416199, 0.27901134, 0.27557442, 0.3905206, -0.36137494, -0.06634006, -0.10640851}); @@ -347,12 +354,13 @@ TEST_F(SVDFOpTest, BlackBoxTestHybridRank1Uint8) { VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input), &svdf, - /*tolerance=*/0.002945); + /*tolerance=*/0.004285); } -TEST_F(SVDFOpTest, BlackBoxTestHybridRank2Uint8) { +TEST_P(SVDFOpTest, BlackBoxTestHybridRank2Uint8) { HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, - /*memory_size=*/10, /*rank=*/2, TensorType_UINT8); + /*memory_size=*/10, /*rank=*/2, TensorType_UINT8, + GetParam()); svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347, 0.12416199, 0.15785322, 0.27901134, 0.3905206, 0.21931258, -0.36137494, -0.10640851, 0.31053296, @@ -387,12 +395,13 @@ TEST_F(SVDFOpTest, BlackBoxTestHybridRank2Uint8) { VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input), &svdf, - /*tolerance=*/0.00625109); + /*tolerance=*/0.007175); } -TEST_F(SVDFOpTest, BlackBoxTestHybridRank1Int8) { +TEST_P(SVDFOpTest, BlackBoxTestHybridRank1Int8) { HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, - /*memory_size=*/10, /*rank=*/1, TensorType_INT8); + /*memory_size=*/10, /*rank=*/1, TensorType_INT8, + GetParam()); svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347, 0.22197971, 0.12416199, 0.27901134, 0.27557442, 0.3905206, -0.36137494, -0.06634006, -0.10640851}); @@ -412,12 +421,13 @@ TEST_F(SVDFOpTest, BlackBoxTestHybridRank1Int8) { VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input), &svdf, - /*tolerance=*/0.002945); + /*tolerance=*/0.004285); } -TEST_F(SVDFOpTest, BlackBoxTestHybridRank2Int8) { +TEST_P(SVDFOpTest, BlackBoxTestHybridRank2Int8) { HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, - /*memory_size=*/10, /*rank=*/2, TensorType_INT8); + /*memory_size=*/10, /*rank=*/2, TensorType_INT8, + GetParam()); svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347, 0.12416199, 0.15785322, 0.27901134, 0.3905206, 0.21931258, -0.36137494, -0.10640851, 0.31053296, @@ -452,7 +462,7 @@ TEST_F(SVDFOpTest, BlackBoxTestHybridRank2Int8) { VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input), &svdf, - /*tolerance=*/0.00625109); + /*tolerance=*/0.007175); } // Test case for full integer quantization of SVDF. diff --git a/tensorflow/lite/kernels/test_util.h b/tensorflow/lite/kernels/test_util.h index 0a37690a689..7b504e42371 100644 --- a/tensorflow/lite/kernels/test_util.h +++ b/tensorflow/lite/kernels/test_util.h @@ -790,6 +790,7 @@ template TensorType GetTensorType() { if (std::is_same::value) return TensorType_FLOAT32; if (std::is_same::value) return TensorType_FLOAT16; + if (std::is_same::value) return TensorType_FLOAT64; if (std::is_same::value) return TensorType_INT8; if (std::is_same::value) return TensorType_INT16; if (std::is_same::value) return TensorType_INT32; diff --git a/tensorflow/lite/kernels/transpose_conv_test.cc b/tensorflow/lite/kernels/transpose_conv_test.cc index 9a1a950fe0f..1851c01bb59 100644 --- a/tensorflow/lite/kernels/transpose_conv_test.cc +++ b/tensorflow/lite/kernels/transpose_conv_test.cc @@ -335,11 +335,6 @@ class PerChannelQuantizedTransposeConvOpModel }; TEST_P(TransposeConvOpTest, SimpleTestQuantizedPerChannelSingleChannel) { - // TODO(b/138722124): Enable these tests on NNAPI. - if (SingleOpModel::GetForceUseNnapi()) { - return; - } - const std::initializer_list filter_data = {1, 2, 3, 4, 5, 6, 7, 8, 9}; PerChannelQuantizedTransposeConvOpModel model( GetRegistration(), {1, 4, 4, 1}, @@ -363,11 +358,6 @@ TEST_P(TransposeConvOpTest, SimpleTestQuantizedPerChannelSingleChannel) { // Test data copied from the float multi-channel test above. TEST_P(TransposeConvOpTest, TestQuantizedPerChannelMultiChannel) { - // TODO(b/138722124): Enable these tests on NNAPI. - if (SingleOpModel::GetForceUseNnapi()) { - return; - } - const std::initializer_list filter_data = { 1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6, 8, 10, 12, 14, 16, 18}; PerChannelQuantizedTransposeConvOpModel model( diff --git a/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc index b49974da2e0..73b0535fc46 100644 --- a/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc +++ b/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc @@ -33,6 +33,7 @@ struct OpData { bool is_layer_norm_lstm; // The scratch tensor index. int scratch_tensor_index; + bool compute_row_sums = false; }; // Input Tensors of size {max_time, n_batch, n_input} @@ -92,7 +93,9 @@ enum TemporaryTensor { kProductScalingFactors = 5, kRecoveredCellWeights = 6, kAccumScratch = 7, - kNumTemporaryTensors + kZeroPoints = 8, + kRowSums = 9, + kNumTemporaryTensors = 10 }; void* Init(TfLiteContext* context, const char* buffer, size_t length) { @@ -408,6 +411,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { scratch_buffer_size)); if (IsHybridOp(input, input_to_output_weights)) { + op_data->compute_row_sums = true; // Allocate temporary tensors to store quantized values of input, // activation_state and cell_state tensors. node->temporaries->data[kInputQuantized] = @@ -515,6 +519,34 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK( context, context->ResizeTensor(context, accum_scratch, accum_size)); } + node->temporaries->data[kZeroPoints] = scratch_tensor_index + kZeroPoints; + TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints); + zero_points->type = kTfLiteFloat32; + zero_points->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, scaling_dims)) { + TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1); + zero_points_size->data[0] = n_batch; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points, + zero_points_size)); + } + node->temporaries->data[kRowSums] = scratch_tensor_index + kRowSums; + TfLiteTensor* row_sums = GetTemporary(context, node, kRowSums); + row_sums->type = kTfLiteInt32; + row_sums->allocation_type = kTfLiteArenaRwPersistent; + int row_sums_rows = use_cifg ? 6 : 8; + const TfLiteTensor* projection_weights = + GetOptionalInputTensor(context, node, kProjectionWeightsTensor); + if (projection_weights != nullptr) { + row_sums_rows += ceil(n_output / n_cell); + } + int row_sums_dims[2] = {row_sums_rows, n_cell}; + if (!TfLiteIntArrayEqualsArray(row_sums->dims, 2, row_sums_dims)) { + TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(2); + row_sums_size->data[0] = row_sums_dims[0]; + row_sums_size->data[1] = row_sums_dims[1]; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, row_sums, row_sums_size)); + } } return kTfLiteOk; } @@ -600,6 +632,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { lstm_params.activation = params->activation; lstm_params.cell_clip = params->cell_clip; lstm_params.proj_clip = params->proj_clip; + lstm_params.asymmetric_quantize_inputs = params->asymmetric_quantize_inputs; switch (input_to_output_weights->type) { case kTfLiteFloat32: { @@ -623,6 +656,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } case kTfLiteUInt8: case kTfLiteInt8: { + OpData* op_data = reinterpret_cast(node->user_data); TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1); TfLiteTensor* activation_state_quantized = GetTemporary(context, node, /*index=*/2); @@ -635,6 +669,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetTemporary(context, node, /*index=*/6); TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/kAccumScratch); + TfLiteTensor* zero_points = + GetTemporary(context, node, /*index=*/kZeroPoints); + TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/kRowSums); + const int row_sums_size = row_sums->dims->data[0]; return lstm_eval::EvalHybrid( input, input_to_input_weights, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, @@ -654,7 +692,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { prod_scaling_factors, recovered_cell_weights, input_quantized, /*aux_input_quantized=*/nullptr, activation_state_quantized, cell_state_quantized, activation_state, cell_state, accum_scratch, - output, CpuBackendContext::GetFromContext(context)); + output, zero_points, row_sums, row_sums_size, + &op_data->compute_row_sums, + CpuBackendContext::GetFromContext(context)); } default: context->ReportError(context, "Type %d is not currently supported.", diff --git a/tensorflow/lite/kernels/unidirectional_sequence_lstm_test.cc b/tensorflow/lite/kernels/unidirectional_sequence_lstm_test.cc index e89949e279e..4ea018c0cab 100644 --- a/tensorflow/lite/kernels/unidirectional_sequence_lstm_test.cc +++ b/tensorflow/lite/kernels/unidirectional_sequence_lstm_test.cc @@ -38,7 +38,8 @@ class UnidirectionalLSTMOpModel : public SingleOpModel { float proj_clip, const std::vector>& input_shapes, const TensorType& weights_type = TensorType_FLOAT32, - bool is_layer_norm = false) + bool is_layer_norm = false, + bool asymmetric_quantize_inputs = false) : n_batch_(n_batch), n_input_(n_input), n_cell_(n_cell), @@ -131,7 +132,7 @@ class UnidirectionalLSTMOpModel : public SingleOpModel { BuiltinOptions_UnidirectionalSequenceLSTMOptions, CreateUnidirectionalSequenceLSTMOptions( builder_, ActivationFunctionType_TANH, cell_clip, - proj_clip, time_major) + proj_clip, time_major, asymmetric_quantize_inputs) .Union()); BuildInterpreter(input_shapes); } @@ -292,11 +293,12 @@ class HybridUnidirectionalLSTMOpModel : public UnidirectionalLSTMOpModel { bool time_major, bool use_cifg, bool use_peephole, bool use_projection_weights, bool use_projection_bias, float cell_clip, float proj_clip, const std::vector>& input_shapes, - TensorType tensor_type) + TensorType tensor_type, bool asymmetric_quantize_inputs) : UnidirectionalLSTMOpModel( n_batch, n_input, n_cell, n_output, sequence_length, time_major, use_cifg, use_peephole, use_projection_weights, use_projection_bias, - cell_clip, proj_clip, input_shapes, tensor_type) { + cell_clip, proj_clip, input_shapes, tensor_type, false, + asymmetric_quantize_inputs) { tensor_type_ = tensor_type; } @@ -360,7 +362,7 @@ class HybridUnidirectionalLSTMOpModel : public UnidirectionalLSTMOpModel { TensorType tensor_type_; }; -class BaseUnidirectionalLstmTest : public ::testing::Test { +class BaseUnidirectionalLstmTest : public ::testing::TestWithParam { protected: // Weights of the LSTM model. Some are optional. std::vector input_to_input_weights_; @@ -626,7 +628,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest, /*time_major=*/false); } -TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest, +TEST_P(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest, HybridLstmBlackBoxTestUint8) { const int n_batch = 1; const int n_input = 2; @@ -668,7 +670,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest, {n_batch, n_output}, // activation_state tensor {n_batch, n_cell}, // cell_state tensor }, - TensorType_UINT8); + TensorType_UINT8, GetParam()); lstm.SetInputToInputWeights(input_to_input_weights_); lstm.SetInputToCellWeights(input_to_cell_weights_); @@ -689,7 +691,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest, /*tolerance=*/0.0157651); } -TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest, +TEST_P(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest, HybridLstmBlackBoxTestInt8) { const int n_batch = 1; const int n_input = 2; @@ -731,7 +733,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest, {n_batch, n_output}, // activation_state tensor {n_batch, n_cell}, // cell_state tensor }, - TensorType_INT8); + TensorType_INT8, GetParam()); lstm.SetInputToInputWeights(input_to_input_weights_); lstm.SetInputToCellWeights(input_to_cell_weights_); @@ -862,7 +864,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest, VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); } -TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest, +TEST_P(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest, HybridLstmBlackBoxTestUint8) { const int n_batch = 1; const int n_input = 2; @@ -880,11 +882,10 @@ TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest, { {sequence_length, n_batch, n_input}, // input tensor - {0, 0}, // input_to_input_weight tensor - {n_cell, n_input}, // input_to_forget_weight tensor - {n_cell, n_input}, // input_to_cell_weight tensor - {n_cell, n_input}, // input_to_output_weight tensor - + {0, 0}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor {0, 0}, // recurrent_to_input_weight tensor {n_cell, n_output}, // recurrent_to_forget_weight tensor {n_cell, n_output}, // recurrent_to_cell_weight tensor @@ -905,7 +906,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest, {n_batch, n_output}, // activation_state tensor {n_batch, n_cell}, // cell_state tensor }, - TensorType_UINT8); + TensorType_UINT8, GetParam()); lstm.SetInputToCellWeights(input_to_cell_weights_); lstm.SetInputToForgetWeights(input_to_forget_weights_); @@ -925,7 +926,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest, VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573); } -TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest, +TEST_P(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest, HybridLstmBlackBoxTestInt8) { const int n_batch = 1; const int n_input = 2; @@ -968,7 +969,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest, {n_batch, n_output}, // activation_state tensor {n_batch, n_cell}, // cell_state tensor }, - TensorType_INT8); + TensorType_INT8, GetParam()); lstm.SetInputToCellWeights(input_to_cell_weights_); lstm.SetInputToForgetWeights(input_to_forget_weights_); @@ -1655,14 +1656,16 @@ TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest, VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); } -TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest, +TEST_P(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest, HybridLstmBlackBoxTestUint8) { const int n_batch = 2; const int n_input = 5; const int n_cell = 20; const int n_output = 16; const int sequence_length = 4; - + if (GetParam()) { + return; + } HybridUnidirectionalLSTMOpModel lstm( n_batch, n_input, n_cell, n_output, sequence_length, /*time_major=*/true, /*use_cifg=*/false, /*use_peephole=*/true, @@ -1697,7 +1700,7 @@ TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest, {n_batch, n_output}, // activation_state tensor {n_batch, n_cell}, // cell_state tensor }, - TensorType_UINT8); + TensorType_UINT8, GetParam()); lstm.SetInputToInputWeights(input_to_input_weights_); lstm.SetInputToCellWeights(input_to_cell_weights_); @@ -1723,8 +1726,11 @@ TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest, VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467); } -TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest, +TEST_P(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest, HybridLstmBlackBoxTestInt8) { + if (GetParam()) { + return; + } const int n_batch = 2; const int n_input = 5; const int n_cell = 20; @@ -1765,7 +1771,7 @@ TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest, {n_batch, n_output}, // activation_state tensor {n_batch, n_cell}, // cell_state tensor }, - TensorType_INT8); + TensorType_INT8, GetParam()); lstm.SetInputToInputWeights(input_to_input_weights_); lstm.SetInputToCellWeights(input_to_cell_weights_); @@ -2737,5 +2743,14 @@ TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest, VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); } +#define QUANTIZE_PARAMETER_TEST(test) \ + INSTANTIATE_TEST_SUITE_P(test, test, ::testing::ValuesIn({false, true})); + +QUANTIZE_PARAMETER_TEST( + CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest); +QUANTIZE_PARAMETER_TEST( + NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest); +QUANTIZE_PARAMETER_TEST(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest); +#undef QUANTIZE_PARAMETER_TEST } // namespace } // namespace tflite diff --git a/tensorflow/lite/kernels/unidirectional_sequence_rnn.cc b/tensorflow/lite/kernels/unidirectional_sequence_rnn.cc index 47c778185d4..7ed67c1614d 100644 --- a/tensorflow/lite/kernels/unidirectional_sequence_rnn.cc +++ b/tensorflow/lite/kernels/unidirectional_sequence_rnn.cc @@ -26,6 +26,15 @@ namespace ops { namespace builtin { namespace unidirectional_sequence_rnn { +namespace { + +struct OpData { + int scratch_tensor_index; + bool compute_row_sums = false; +}; + +} // namespace + // Input tensors. constexpr int kInputTensor = 0; constexpr int kWeightsTensor = 1; @@ -37,13 +46,14 @@ constexpr int kHiddenStateTensor = 4; constexpr int kOutputTensor = 0; void* Init(TfLiteContext* context, const char* buffer, size_t length) { - auto* scratch_tensor_index = new int; - context->AddTensors(context, /*tensors_to_add=*/3, scratch_tensor_index); - return scratch_tensor_index; + auto* op_data = new OpData(); + context->AddTensors(context, /*tensors_to_add=*/6, + &op_data->scratch_tensor_index); + return op_data; } void Free(TfLiteContext* context, void* buffer) { - delete reinterpret_cast(buffer); + delete reinterpret_cast(buffer); } TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { @@ -96,10 +106,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Allocate temporary tensors to store quantized values of input and // hidden_state tensors. if (is_hybrid) { - int* scratch_tensor_index = reinterpret_cast(node->user_data); + auto* op_data = reinterpret_cast(node->user_data); + op_data->compute_row_sums = true; TfLiteIntArrayFree(node->temporaries); - node->temporaries = TfLiteIntArrayCreate(3); - node->temporaries->data[0] = *scratch_tensor_index; + node->temporaries = TfLiteIntArrayCreate(6); + node->temporaries->data[0] = op_data->scratch_tensor_index; TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0); input_quantized->type = input_weights->type; input_quantized->allocation_type = kTfLiteArenaRw; @@ -108,7 +119,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized, input_quantized_size)); } - node->temporaries->data[1] = *scratch_tensor_index + 1; + node->temporaries->data[1] = op_data->scratch_tensor_index + 1; TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, /*index=*/1); hidden_state_quantized->type = input_weights->type; @@ -121,7 +132,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { context->ResizeTensor(context, hidden_state_quantized, hidden_state_quantized_size)); } - node->temporaries->data[2] = *scratch_tensor_index + 2; + node->temporaries->data[2] = op_data->scratch_tensor_index + 2; TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2); scaling_factors->type = kTfLiteFloat32; scaling_factors->allocation_type = kTfLiteArenaRw; @@ -132,6 +143,42 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, scaling_factors_size)); } + node->temporaries->data[3] = op_data->scratch_tensor_index + 3; + TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/3); + accum_scratch->type = kTfLiteInt32; + accum_scratch->allocation_type = kTfLiteArenaRw; + int accum_scratch_dims[2] = {num_units, batch_size}; + if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2, + accum_scratch_dims)) { + TfLiteIntArray* accum_scratch_size = TfLiteIntArrayCreate(2); + accum_scratch_size->data[0] = accum_scratch_dims[0]; + accum_scratch_size->data[1] = accum_scratch_dims[1]; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, accum_scratch, + accum_scratch_size)); + } + node->temporaries->data[4] = op_data->scratch_tensor_index + 4; + TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/4); + zero_points->type = kTfLiteInt32; + zero_points->allocation_type = kTfLiteArenaRw; + int zero_points_dims[1] = {batch_size}; + if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) { + TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1); + zero_points_size->data[0] = batch_size; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points, + zero_points_size)); + } + node->temporaries->data[5] = op_data->scratch_tensor_index + 5; + TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/5); + row_sums->type = kTfLiteInt32; + row_sums->allocation_type = kTfLiteArenaRwPersistent; + int row_sums_dims[2] = {2, num_units}; + if (!TfLiteIntArrayEqualsArray(row_sums->dims, 2, row_sums_dims)) { + TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(2); + row_sums_size->data[0] = row_sums_dims[0]; + row_sums_size->data[1] = row_sums_dims[1]; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, row_sums, row_sums_size)); + } } return kTfLiteOk; } @@ -202,7 +249,9 @@ TfLiteStatus EvalHybrid( const TfLiteTensor* recurrent_weights, const TfLiteTensor* bias, const TfLiteSequenceRNNParams* params, TfLiteTensor* input_scratch, TfLiteTensor* hidden_state_scratch, TfLiteTensor* scaling_factors, - TfLiteTensor* hidden_state, TfLiteTensor* output) { + TfLiteTensor* hidden_state, TfLiteTensor* output, TfLiteTensor* zero_points, + TfLiteTensor* accum_scratch, TfLiteTensor* row_sums, + bool* compute_row_sums) { const bool time_major = params->time_major; const int batch_size = (time_major) ? input->dims->data[1] : input->dims->data[0]; @@ -227,6 +276,14 @@ TfLiteStatus EvalHybrid( float input_weights_scale = input_weights->params.scale; float recurrent_weights_scale = recurrent_weights->params.scale; float* scaling_factors_ptr = GetTensorData(scaling_factors); + int32_t* accum_scratch_ptr = GetTensorData(accum_scratch); + int32_t* zero_points_ptr = nullptr; + int32_t* row_sums_ptr = nullptr; + + if (params->asymmetric_quantize_inputs) { + zero_points_ptr = GetTensorData(zero_points); + row_sums_ptr = GetTensorData(row_sums); + } if (time_major) { // Initialize the pointer to hidden state. @@ -244,7 +301,9 @@ TfLiteStatus EvalHybrid( recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size, num_units, batch_size, num_units, params->activation, quantized_input_ptr, quantized_hidden_state_ptr, scaling_factors_ptr, - hidden_state_ptr_batch, output_ptr_batch); + hidden_state_ptr_batch, output_ptr_batch, + params->asymmetric_quantize_inputs, zero_points_ptr, + accum_scratch_ptr, row_sums_ptr, compute_row_sums); } } else { // For each batch @@ -259,13 +318,14 @@ TfLiteStatus EvalHybrid( s * input_size; float* output_ptr_batch = GetTensorData(output) + b * num_units * max_time + s * num_units; - kernel_utils::RnnBatchStep( input_ptr_batch, input_weights_ptr, input_weights_scale, recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size, num_units, /*batch_size=*/1, num_units, params->activation, quantized_input_ptr, quantized_hidden_state_ptr, - scaling_factors_ptr, hidden_state_ptr_batch, output_ptr_batch); + scaling_factors_ptr, hidden_state_ptr_batch, output_ptr_batch, + params->asymmetric_quantize_inputs, zero_points_ptr, + accum_scratch_ptr, row_sums_ptr, compute_row_sums); } } } @@ -274,7 +334,6 @@ TfLiteStatus EvalHybrid( TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); - const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor); const TfLiteTensor* recurrent_weights = @@ -292,12 +351,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteUInt8: case kTfLiteInt8: { // TODO(mirkov): implement eval with quantized inputs as well. + auto* op_data = reinterpret_cast(node->user_data); TfLiteTensor* input_quantized = GetTemporary(context, node, 0); TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1); TfLiteTensor* scaling_factors = GetTemporary(context, node, 2); + TfLiteTensor* accum_scratch = GetTemporary(context, node, 3); + TfLiteTensor* zero_points = GetTemporary(context, node, 4); + TfLiteTensor* row_sums = GetTemporary(context, node, 5); return EvalHybrid(input, input_weights, recurrent_weights, bias, params, input_quantized, hidden_state_quantized, - scaling_factors, hidden_state, output); + scaling_factors, hidden_state, output, zero_points, + accum_scratch, row_sums, &op_data->compute_row_sums); } default: context->ReportError(context, "Type %d not currently supported.", diff --git a/tensorflow/lite/kernels/unidirectional_sequence_rnn_test.cc b/tensorflow/lite/kernels/unidirectional_sequence_rnn_test.cc index 7e520ee9739..8b6f102acdb 100644 --- a/tensorflow/lite/kernels/unidirectional_sequence_rnn_test.cc +++ b/tensorflow/lite/kernels/unidirectional_sequence_rnn_test.cc @@ -174,7 +174,8 @@ class UnidirectionalRNNOpModel : public SingleOpModel { UnidirectionalRNNOpModel( int batches, int sequence_len, int units, int size, bool time_major, const TensorType& weights = TensorType_FLOAT32, - const TensorType& recurrent_weights = TensorType_FLOAT32) + const TensorType& recurrent_weights = TensorType_FLOAT32, + bool asymmetric_quantize_inputs = false) : batches_(batches), sequence_len_(sequence_len), units_(units), @@ -188,7 +189,8 @@ class UnidirectionalRNNOpModel : public SingleOpModel { SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN, BuiltinOptions_SequenceRNNOptions, CreateSequenceRNNOptions(builder_, time_major, - ActivationFunctionType_RELU) + ActivationFunctionType_RELU, + asymmetric_quantize_inputs) .Union()); if (time_major) { BuildInterpreter({{sequence_len_, batches_, input_size_}, @@ -249,9 +251,11 @@ class HybridUnidirectionalRNNOpModel : public UnidirectionalRNNOpModel { public: HybridUnidirectionalRNNOpModel(int batches, int sequence_len, int units, int size, bool time_major, - TensorType tensor_type) + TensorType tensor_type, + bool asymmetric_quantize_inputs) : UnidirectionalRNNOpModel(batches, sequence_len, units, size, time_major, - tensor_type, tensor_type) { + tensor_type, tensor_type, + asymmetric_quantize_inputs) { tensor_type_ = tensor_type; } @@ -297,10 +301,14 @@ TEST(UnidirectionalRNNOpTest, BlackBoxTest) { EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); } -TEST(HybridUnidirectionalRNNOpModelOpTest, BlackBoxTestUint8) { +class HybridUnidirectionalRNNOpModelOpTest + : public ::testing::TestWithParam {}; + +TEST_P(HybridUnidirectionalRNNOpModelOpTest, BlackBoxTestUint8) { HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, /*units=*/16, /*size=*/8, - /*time_major=*/false, TensorType_UINT8); + /*time_major=*/false, TensorType_UINT8, + GetParam()); rnn.SetWeights(rnn_weights); rnn.SetBias(rnn_bias); rnn.SetRecurrentWeights(rnn_recurrent_weights); @@ -323,10 +331,11 @@ TEST(HybridUnidirectionalRNNOpModelOpTest, BlackBoxTestUint8) { expected, /*max_abs_error=*/0.013))); } -TEST(HybridUnidirectionalRNNOpModelOpTest, BlackBoxTestInt8) { +TEST_P(HybridUnidirectionalRNNOpModelOpTest, BlackBoxTestInt8) { HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, /*units=*/16, /*size=*/8, - /*time_major=*/false, TensorType_INT8); + /*time_major=*/false, TensorType_INT8, + GetParam()); rnn.SetWeights(rnn_weights); rnn.SetBias(rnn_bias); rnn.SetRecurrentWeights(rnn_recurrent_weights); @@ -378,10 +387,11 @@ TEST(UnidirectionalRNNOpTest, TimeMajorBlackBoxTest) { EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); } -TEST(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTestUint8) { +TEST_P(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTestUint8) { HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, /*units=*/16, /*size=*/8, - /*time_major=*/true, TensorType_UINT8); + /*time_major=*/true, TensorType_UINT8, + GetParam()); rnn.SetWeights(rnn_weights); rnn.SetBias(rnn_bias); rnn.SetRecurrentWeights(rnn_recurrent_weights); @@ -408,10 +418,11 @@ TEST(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTestUint8) { expected, /*max_abs_error=*/0.013))); } -TEST(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTestInt8) { +TEST_P(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTestInt8) { HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, /*units=*/16, /*size=*/8, - /*time_major=*/true, TensorType_INT8); + /*time_major=*/true, TensorType_INT8, + GetParam()); rnn.SetWeights(rnn_weights); rnn.SetBias(rnn_bias); rnn.SetRecurrentWeights(rnn_recurrent_weights); @@ -438,5 +449,9 @@ TEST(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTestInt8) { expected, /*max_abs_error=*/0.013))); } +INSTANTIATE_TEST_SUITE_P(HybridUnidirectionalRNNOpModelOpTest, + HybridUnidirectionalRNNOpModelOpTest, + ::testing::ValuesIn({true, false})); + } // namespace } // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/BUILD b/tensorflow/lite/micro/benchmarks/BUILD similarity index 85% rename from tensorflow/lite/micro/kernels/xtensa_hifimini/BUILD rename to tensorflow/lite/micro/benchmarks/BUILD index a289a27aa7a..695ea92a2d9 100644 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini/BUILD +++ b/tensorflow/lite/micro/benchmarks/BUILD @@ -12,7 +12,6 @@ cc_binary( deps = [ "//tensorflow/lite/c:common", "//tensorflow/lite/micro/kernels:micro_ops", - "//tensorflow/lite/micro/kernels:micro_utils", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -25,7 +24,6 @@ cc_binary( deps = [ "//tensorflow/lite/c:common", "//tensorflow/lite/micro/kernels:micro_ops", - "//tensorflow/lite/micro/kernels:micro_utils", "//tensorflow/lite/micro/testing:micro_test", ], ) diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/conv_benchmark.cc b/tensorflow/lite/micro/benchmarks/conv_benchmark.cc similarity index 100% rename from tensorflow/lite/micro/kernels/xtensa_hifimini/conv_benchmark.cc rename to tensorflow/lite/micro/benchmarks/conv_benchmark.cc diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/depthwise_conv_benchmark.cc b/tensorflow/lite/micro/benchmarks/depthwise_conv_benchmark.cc similarity index 100% rename from tensorflow/lite/micro/kernels/xtensa_hifimini/depthwise_conv_benchmark.cc rename to tensorflow/lite/micro/benchmarks/depthwise_conv_benchmark.cc diff --git a/tensorflow/lite/micro/examples/magic_wand/arduino/Makefile.inc b/tensorflow/lite/micro/examples/magic_wand/arduino/Makefile.inc new file mode 100644 index 00000000000..4ec1b387c5c --- /dev/null +++ b/tensorflow/lite/micro/examples/magic_wand/arduino/Makefile.inc @@ -0,0 +1,7 @@ +ifeq ($(TARGET),$(filter $(TARGET),arduino)) + +magic_wand_SRCS += \ + tensorflow/lite/micro/examples/magic_wand/sparkfun_edge/accelerometer_handler.cc \ + tensorflow/lite/micro/examples/magic_wand/sparkfun_edge/output_handler.cc + +endif diff --git a/tensorflow/lite/micro/examples/magic_wand/arduino/accelerometer_handler.cc b/tensorflow/lite/micro/examples/magic_wand/arduino/accelerometer_handler.cc index 148a6f29c5c..866b8d6fd79 100644 --- a/tensorflow/lite/micro/examples/magic_wand/arduino/accelerometer_handler.cc +++ b/tensorflow/lite/micro/examples/magic_wand/arduino/accelerometer_handler.cc @@ -13,6 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#if defined(ARDUINO) && !defined(ARDUINO_ARDUINO_NANO33BLE) +#define ARDUINO_EXCLUDE_CODE +#endif // defined(ARDUINO) && !defined(ARDUINO_ARDUINO_NANO33BLE) + +#ifndef ARDUINO_EXCLUDE_CODE + #include "tensorflow/lite/micro/examples/magic_wand/accelerometer_handler.h" #include @@ -131,3 +137,5 @@ bool ReadAccelerometer(tflite::ErrorReporter* error_reporter, float* input, return true; } + +#endif // ARDUINO_EXCLUDE_CODE diff --git a/tensorflow/lite/micro/examples/magic_wand/arduino/output_handler.cc b/tensorflow/lite/micro/examples/magic_wand/arduino/output_handler.cc index ae2f570ea42..a01869e0058 100644 --- a/tensorflow/lite/micro/examples/magic_wand/arduino/output_handler.cc +++ b/tensorflow/lite/micro/examples/magic_wand/arduino/output_handler.cc @@ -13,6 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#if defined(ARDUINO) && !defined(ARDUINO_ARDUINO_NANO33BLE) +#define ARDUINO_EXCLUDE_CODE +#endif // defined(ARDUINO) && !defined(ARDUINO_ARDUINO_NANO33BLE) + +#ifndef ARDUINO_EXCLUDE_CODE + #include "tensorflow/lite/micro/examples/magic_wand/output_handler.h" #include "Arduino.h" @@ -47,3 +53,5 @@ void HandleOutput(tflite::ErrorReporter* error_reporter, int kind) { "*\n\r *\n\r *\n\r * * * * * * * *\n\r"); } } + +#endif // ARDUINO_EXCLUDE_CODE diff --git a/tensorflow/lite/micro/examples/magic_wand/sparkfun_edge/accelerometer_handler.cc b/tensorflow/lite/micro/examples/magic_wand/sparkfun_edge/accelerometer_handler.cc index ff527c78d46..0b35b69c298 100644 --- a/tensorflow/lite/micro/examples/magic_wand/sparkfun_edge/accelerometer_handler.cc +++ b/tensorflow/lite/micro/examples/magic_wand/sparkfun_edge/accelerometer_handler.cc @@ -13,6 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#if defined(ARDUINO) && !defined(ARDUINO_SFE_EDGE) +#define ARDUINO_EXCLUDE_CODE +#endif // defined(ARDUINO) && !defined(ARDUINO_SFE_EDGE) + +#ifndef ARDUINO_EXCLUDE_CODE + #include "tensorflow/lite/micro/examples/magic_wand/accelerometer_handler.h" // These are headers from Ambiq's Apollo3 SDK. @@ -23,8 +29,8 @@ limitations under the License. #include "am_util.h" // NOLINT #include "lis2dh12_platform_apollo3.h" -lis2dh12_platform_apollo3_if_t dev_if = {0}; // accelerometer device interface -lis2dh12_ctx_t dev_ctx = {0}; // accelerometer device control +lis2dh12_platform_apollo3_if_t dev_if; // accelerometer device interface +lis2dh12_ctx_t dev_ctx; // accelerometer device control // A union representing either int16_t[3] or uint8_t[6], // storing the most recent data @@ -40,7 +46,8 @@ int initAccelerometer(void) { uint32_t retVal32 = 0; static uint8_t whoamI = 0; - am_hal_iom_config_t i2cConfig = {0}; + am_hal_iom_config_t i2cConfig; + memset((void*)(&i2cConfig), 0x00, sizeof(am_hal_iom_config_t)); i2cConfig.eInterfaceMode = AM_HAL_IOM_I2C_MODE; i2cConfig.ui32ClockFreq = AM_HAL_IOM_100KHZ; @@ -133,12 +140,12 @@ TfLiteStatus SetupAccelerometer(tflite::ErrorReporter* error_reporter) { if (lis2dh12_fifo_mode_set(&dev_ctx, LIS2DH12_BYPASS_MODE)) { TF_LITE_REPORT_ERROR(error_reporter, "Failed to clear FIFO buffer."); - return 0; + return kTfLiteError; } if (lis2dh12_fifo_mode_set(&dev_ctx, LIS2DH12_DYNAMIC_STREAM_MODE)) { TF_LITE_REPORT_ERROR(error_reporter, "Failed to set streaming mode."); - return 0; + return kTfLiteError; } TF_LITE_REPORT_ERROR(error_reporter, "Magic starts!"); @@ -208,3 +215,5 @@ bool ReadAccelerometer(tflite::ErrorReporter* error_reporter, float* input, } return true; } + +#endif // ARDUINO_EXCLUDE_CODE diff --git a/tensorflow/lite/micro/examples/magic_wand/sparkfun_edge/output_handler.cc b/tensorflow/lite/micro/examples/magic_wand/sparkfun_edge/output_handler.cc index 80d798548e6..4c3cb42631b 100644 --- a/tensorflow/lite/micro/examples/magic_wand/sparkfun_edge/output_handler.cc +++ b/tensorflow/lite/micro/examples/magic_wand/sparkfun_edge/output_handler.cc @@ -13,6 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#if defined(ARDUINO) && !defined(ARDUINO_SFE_EDGE) +#define ARDUINO_EXCLUDE_CODE +#endif // defined(ARDUINO) && !defined(ARDUINO_SFE_EDGE) + +#ifndef ARDUINO_EXCLUDE_CODE + #include "tensorflow/lite/micro/examples/magic_wand/output_handler.h" #include "am_bsp.h" // NOLINT @@ -63,3 +69,5 @@ void HandleOutput(tflite::ErrorReporter* error_reporter, int kind) { am_devices_led_on(am_bsp_psLEDs, AM_BSP_LED_GREEN); } } + +#endif // ARDUINO_EXCLUDE_CODE diff --git a/tensorflow/lite/micro/examples/magic_wand/train/train_magic_wand_model.ipynb b/tensorflow/lite/micro/examples/magic_wand/train/train_magic_wand_model.ipynb index 65f439f9090..0f33efb6c94 100644 --- a/tensorflow/lite/micro/examples/magic_wand/train/train_magic_wand_model.ipynb +++ b/tensorflow/lite/micro/examples/magic_wand/train/train_magic_wand_model.ipynb @@ -1,25 +1,10 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "Train a gesture recognition model for microcontroller use", - "provenance": [], - "collapsed_sections": [], - "toc_visible": true - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "accelerator": "GPU" - }, "cells": [ { "cell_type": "markdown", "metadata": { - "id": "1BtkMGSYQOTQ", - "colab_type": "text" + "colab_type": "text", + "id": "1BtkMGSYQOTQ" }, "source": [ "# Train a gesture recognition model for microcontroller use" @@ -28,39 +13,39 @@ { "cell_type": "markdown", "metadata": { - "id": "BaFfr7DHRmGF", - "colab_type": "text" + "colab_type": "text", + "id": "BaFfr7DHRmGF" }, "source": [ "This notebook demonstrates how to train a 20kb gesture recognition model for [TensorFlow Lite for Microcontrollers](https://tensorflow.org/lite/microcontrollers/overview). It will produce the same model used in the [magic_wand](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/micro/examples/magic_wand) example application.\n", "\n", "The model is designed to be used with [Google Colaboratory](https://colab.research.google.com).\n", "\n", - "\n", - " \n", - " \n", - "
\n", - " Run in Google Colab\n", - " \n", - " View source on GitHub\n", - "
\n" + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/examples/magic_wand/train/train_magic_wand_model.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/examples/magic_wand/train/train_magic_wand_model.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n", + " \u003c/td\u003e\n", + "\u003c/table\u003e\n" ] }, { "cell_type": "markdown", "metadata": { - "id": "xXgS6rxyT7Qk", - "colab_type": "text" + "colab_type": "text", + "id": "xXgS6rxyT7Qk" }, "source": [ - "Training is much faster using GPU acceleration. Before you proceed, ensure you are using a GPU runtime by going to **Runtime -> Change runtime type** and selecting **GPU**. Training will take around 5 minutes on a GPU runtime." + "Training is much faster using GPU acceleration. Before you proceed, ensure you are using a GPU runtime by going to **Runtime -\u003e Change runtime type** and selecting **GPU**. Training will take around 5 minutes on a GPU runtime." ] }, { "cell_type": "markdown", "metadata": { - "id": "LG6ErX5FRIaV", - "colab_type": "text" + "colab_type": "text", + "id": "LG6ErX5FRIaV" }, "source": [ "## Configure dependencies\n", @@ -68,24 +53,11 @@ "Run the following cell to ensure the correct version of TensorFlow is used." ] }, - { - "cell_type": "code", - "metadata": { - "id": "h3sE3keZZnMX", - "colab_type": "code", - "colab": {} - }, - "source": [ - "%tensorflow_version 2.x\n" - ], - "execution_count": 0, - "outputs": [] - }, { "cell_type": "markdown", "metadata": { - "id": "STNft9TrfoVh", - "colab_type": "text" + "colab_type": "text", + "id": "STNft9TrfoVh" }, "source": [ "We'll also clone the TensorFlow repository, which contains the training scripts, and copy them into our workspace." @@ -93,25 +65,25 @@ }, { "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "ygkWw73dRNda", + "colab": {}, "colab_type": "code", - "colab": {} + "id": "ygkWw73dRNda" }, + "outputs": [], "source": [ "# Clone the repository from GitHub\n", "!git clone --depth 1 -q https://github.com/tensorflow/tensorflow\n", "# Copy the training scripts into our workspace\n", "!cp -r tensorflow/tensorflow/lite/micro/examples/magic_wand/train train" - ], - "execution_count": 0, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "pXI7R4RehFdU", - "colab_type": "text" + "colab_type": "text", + "id": "pXI7R4RehFdU" }, "source": [ "## Prepare the data\n", @@ -121,25 +93,25 @@ }, { "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "W2Sg2AKzVr2L", + "colab": {}, "colab_type": "code", - "colab": {} + "id": "W2Sg2AKzVr2L" }, + "outputs": [], "source": [ "# Download the data we will use to train the model\n", "!wget http://download.tensorflow.org/models/tflite/magic_wand/data.tar.gz\n", "# Extract the data into the train directory\n", - "!tar xvzf data.tar.gz -C train 1>/dev/null" - ], - "execution_count": 0, - "outputs": [] + "!tar xvzf data.tar.gz -C train 1\u003e/dev/null" + ] }, { "cell_type": "markdown", "metadata": { - "id": "DNjukI1Sgl2C", - "colab_type": "text" + "colab_type": "text", + "id": "DNjukI1Sgl2C" }, "source": [ "We'll then run the scripts that split the data into training, validation, and test sets." @@ -147,11 +119,13 @@ }, { "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "XBqSVpi6Vxss", + "colab": {}, "colab_type": "code", - "colab": {} + "id": "XBqSVpi6Vxss" }, + "outputs": [], "source": [ "# The scripts must be run from within the train directory\n", "%cd train\n", @@ -159,15 +133,13 @@ "!python data_prepare.py\n", "# Split the data by person\n", "!python data_split_person.py" - ], - "execution_count": 0, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "5-cmVbFvhTvy", - "colab_type": "text" + "colab_type": "text", + "id": "5-cmVbFvhTvy" }, "source": [ "## Load TensorBoard\n", @@ -177,24 +149,24 @@ }, { "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "CCx6SN9NWRPw", + "colab": {}, "colab_type": "code", - "colab": {} + "id": "CCx6SN9NWRPw" }, + "outputs": [], "source": [ "# Load TensorBoard\n", "%load_ext tensorboard\n", "%tensorboard --logdir logs/scalars" - ], - "execution_count": 0, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "ERC2Cr4PhaOl", - "colab_type": "text" + "colab_type": "text", + "id": "ERC2Cr4PhaOl" }, "source": [ "## Begin training\n", @@ -204,22 +176,22 @@ }, { "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "DXmQZgbuWQFO", + "colab": {}, "colab_type": "code", - "colab": {} + "id": "DXmQZgbuWQFO" }, + "outputs": [], "source": [ "!python train.py --model CNN --person true" - ], - "execution_count": 0, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "4gXbVzcXhvGD", - "colab_type": "text" + "colab_type": "text", + "id": "4gXbVzcXhvGD" }, "source": [ "## Create a C source file\n", @@ -231,21 +203,36 @@ }, { "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "8wgei4OGe3Nz", + "colab": {}, "colab_type": "code", - "colab": {} + "id": "8wgei4OGe3Nz" }, + "outputs": [], "source": [ "# Install xxd if it is not available\n", "!apt-get -qq install xxd\n", "# Save the file as a C source file\n", - "!xxd -i model.tflite > /content/model.cc\n", + "!xxd -i model.tflite \u003e /content/model.cc\n", "# Print the source file\n", "!cat /content/model.cc" - ], - "execution_count": 0, - "outputs": [] + ] } - ] -} \ No newline at end of file + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "Train a gesture recognition model for microcontroller use", + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/lite/micro/examples/micro_speech/arduino/Makefile.inc b/tensorflow/lite/micro/examples/micro_speech/arduino/Makefile.inc new file mode 100644 index 00000000000..d24e3c092ee --- /dev/null +++ b/tensorflow/lite/micro/examples/micro_speech/arduino/Makefile.inc @@ -0,0 +1,7 @@ +ifeq ($(TARGET),$(filter $(TARGET),arduino)) + +MICRO_SPEECH_SRCS += \ + tensorflow/lite/micro/examples/micro_speech/sparkfun_edge/audio_provider.cc \ + tensorflow/lite/micro/examples/micro_speech/sparkfun_edge/command_responder.cc + +endif diff --git a/tensorflow/lite/micro/examples/micro_speech/arduino/audio_provider.cc b/tensorflow/lite/micro/examples/micro_speech/arduino/audio_provider.cc index c783aea034e..efbe5011187 100644 --- a/tensorflow/lite/micro/examples/micro_speech/arduino/audio_provider.cc +++ b/tensorflow/lite/micro/examples/micro_speech/arduino/audio_provider.cc @@ -28,6 +28,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#if defined(ARDUINO) && !defined(ARDUINO_ARDUINO_NANO33BLE) +#define ARDUINO_EXCLUDE_CODE +#endif // defined(ARDUINO) && !defined(ARDUINO_ARDUINO_NANO33BLE) + +#ifndef ARDUINO_EXCLUDE_CODE + #include "tensorflow/lite/micro/examples/micro_speech/audio_provider.h" #include "PDM.h" @@ -116,3 +122,5 @@ TfLiteStatus GetAudioSamples(tflite::ErrorReporter* error_reporter, } int32_t LatestAudioTimestamp() { return g_latest_audio_timestamp; } + +#endif // ARDUINO_EXCLUDE_CODE diff --git a/tensorflow/lite/micro/examples/micro_speech/arduino/command_responder.cc b/tensorflow/lite/micro/examples/micro_speech/arduino/command_responder.cc index 9b67e78c772..467742406de 100644 --- a/tensorflow/lite/micro/examples/micro_speech/arduino/command_responder.cc +++ b/tensorflow/lite/micro/examples/micro_speech/arduino/command_responder.cc @@ -13,6 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#if defined(ARDUINO) && !defined(ARDUINO_ARDUINO_NANO33BLE) +#define ARDUINO_EXCLUDE_CODE +#endif // defined(ARDUINO) && !defined(ARDUINO_ARDUINO_NANO33BLE) + +#ifndef ARDUINO_EXCLUDE_CODE + #include "tensorflow/lite/micro/examples/micro_speech/command_responder.h" #include "Arduino.h" @@ -83,3 +89,5 @@ void RespondToCommand(tflite::ErrorReporter* error_reporter, digitalWrite(LED_BUILTIN, LOW); } } + +#endif // ARDUINO_EXCLUDE_CODE diff --git a/tensorflow/lite/micro/examples/micro_speech/sparkfun_edge/audio_provider.cc b/tensorflow/lite/micro/examples/micro_speech/sparkfun_edge/audio_provider.cc index 179705e5647..82aa99e5a55 100644 --- a/tensorflow/lite/micro/examples/micro_speech/sparkfun_edge/audio_provider.cc +++ b/tensorflow/lite/micro/examples/micro_speech/sparkfun_edge/audio_provider.cc @@ -13,6 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#if defined(ARDUINO) && !defined(ARDUINO_SFE_EDGE) +#define ARDUINO_EXCLUDE_CODE +#endif // defined(ARDUINO) && !defined(ARDUINO_SFE_EDGE) + +#ifndef ARDUINO_EXCLUDE_CODE + #include "tensorflow/lite/micro/examples/micro_speech/audio_provider.h" #include @@ -367,3 +373,5 @@ TfLiteStatus GetAudioSamples(tflite::ErrorReporter* error_reporter, } int32_t LatestAudioTimestamp() { return g_latest_audio_timestamp; } + +#endif // ARDUINO_EXCLUDE_CODE diff --git a/tensorflow/lite/micro/examples/micro_speech/sparkfun_edge/command_responder.cc b/tensorflow/lite/micro/examples/micro_speech/sparkfun_edge/command_responder.cc index c3333f42ef8..bdb14829fe5 100644 --- a/tensorflow/lite/micro/examples/micro_speech/sparkfun_edge/command_responder.cc +++ b/tensorflow/lite/micro/examples/micro_speech/sparkfun_edge/command_responder.cc @@ -13,6 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#if defined(ARDUINO) && !defined(ARDUINO_SFE_EDGE) +#define ARDUINO_EXCLUDE_CODE +#endif // defined(ARDUINO) && !defined(ARDUINO_SFE_EDGE) + +#ifndef ARDUINO_EXCLUDE_CODE + #include "tensorflow/lite/micro/examples/micro_speech/command_responder.h" #include "am_bsp.h" // NOLINT @@ -53,3 +59,5 @@ void RespondToCommand(tflite::ErrorReporter* error_reporter, } } } + +#endif // ARDUINO_EXCLUDE_CODE diff --git a/tensorflow/lite/micro/examples/person_detection/Makefile.inc b/tensorflow/lite/micro/examples/person_detection/Makefile.inc index ca95f736cd4..a295bb83f71 100644 --- a/tensorflow/lite/micro/examples/person_detection/Makefile.inc +++ b/tensorflow/lite/micro/examples/person_detection/Makefile.inc @@ -1,4 +1,5 @@ $(eval $(call add_third_party_download,$(PERSON_MODEL_URL),$(PERSON_MODEL_MD5),person_model_grayscale,)) +$(eval $(call add_third_party_download,$(RUY_URL),$(RUY_MD5),ruy,)) person_detection_MODEL_SRCS := \ tensorflow/lite/micro/examples/person_detection/model_settings.cc \ diff --git a/tensorflow/lite/micro/examples/person_detection/arduino/HM01B0_platform.h b/tensorflow/lite/micro/examples/person_detection/arduino/HM01B0_platform.h new file mode 100644 index 00000000000..50835f9f1cb --- /dev/null +++ b/tensorflow/lite/micro/examples/person_detection/arduino/HM01B0_platform.h @@ -0,0 +1,25 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_ARDUINO_HM01B0_PLATFORM_H_ +#define TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_ARDUINO_HM01B0_PLATFORM_H_ + +#if defined(ARDUINO) && defined(ARDUINO_SFE_EDGE) +#include "hm01b0_platform_edge.h" +#define HM01B0_PIN_TRIG 0 // unused +#define HM01B0_PIN_INT 0 // unused +#endif // defined(ARDUINO) && defined(ARDUINO_SFE_EDGE) + +#endif // TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_ARDUINO_HM01B0_PLATFORM_H_ \ No newline at end of file diff --git a/tensorflow/lite/micro/examples/person_detection/arduino/Makefile.inc b/tensorflow/lite/micro/examples/person_detection/arduino/Makefile.inc new file mode 100644 index 00000000000..3181b36a268 --- /dev/null +++ b/tensorflow/lite/micro/examples/person_detection/arduino/Makefile.inc @@ -0,0 +1,18 @@ +ifeq ($(TARGET),$(filter $(TARGET),arduino)) + +person_detection_SRCS += \ + tensorflow/lite/micro/examples/person_detection/sparkfun_edge/image_provider.cc \ + tensorflow/lite/micro/examples/person_detection/sparkfun_edge/detection_responder.cc \ + tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0_debug.c \ + tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0_optimized.c \ + tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0.c + +person_detection_HDRS += \ + tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0_debug.h \ + tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0_optimized.h \ + tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0_RAW8_QVGA_8bits_lsb_5fps.h \ + tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0_Walking1s_01.h \ + tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0.h \ + tensorflow/lite/micro/examples/person_detection/arduino/HM01B0_platform.h + +endif diff --git a/tensorflow/lite/micro/examples/person_detection/arduino/detection_responder.cc b/tensorflow/lite/micro/examples/person_detection/arduino/detection_responder.cc index 790f1753f76..622972092cb 100644 --- a/tensorflow/lite/micro/examples/person_detection/arduino/detection_responder.cc +++ b/tensorflow/lite/micro/examples/person_detection/arduino/detection_responder.cc @@ -13,6 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#if defined(ARDUINO) && !defined(ARDUINO_ARDUINO_NANO33BLE) +#define ARDUINO_EXCLUDE_CODE +#endif // defined(ARDUINO) && !defined(ARDUINO_ARDUINO_NANO33BLE) + +#ifndef ARDUINO_EXCLUDE_CODE + #include "tensorflow/lite/micro/examples/person_detection/detection_responder.h" #include "Arduino.h" @@ -54,3 +60,5 @@ void RespondToDetection(tflite::ErrorReporter* error_reporter, TF_LITE_REPORT_ERROR(error_reporter, "Person score: %d No person score: %d", person_score, no_person_score); } + +#endif // ARDUINO_EXCLUDE_CODE diff --git a/tensorflow/lite/micro/examples/person_detection/arduino/image_provider.cc b/tensorflow/lite/micro/examples/person_detection/arduino/image_provider.cc index f652490e1ce..cd79e6b2d44 100644 --- a/tensorflow/lite/micro/examples/person_detection/arduino/image_provider.cc +++ b/tensorflow/lite/micro/examples/person_detection/arduino/image_provider.cc @@ -35,6 +35,12 @@ limitations under the License. * "#define LOAD_SD_LIBRARY" and "#define LOAD_SDFAT_LIBRARY". */ +#if defined(ARDUINO) && !defined(ARDUINO_ARDUINO_NANO33BLE) +#define ARDUINO_EXCLUDE_CODE +#endif // defined(ARDUINO) && !defined(ARDUINO_ARDUINO_NANO33BLE) + +#ifndef ARDUINO_EXCLUDE_CODE + // Required by Arducam library #include #include @@ -261,3 +267,5 @@ TfLiteStatus GetImage(tflite::ErrorReporter* error_reporter, int image_width, return kTfLiteOk; } + +#endif // ARDUINO_EXCLUDE_CODE diff --git a/tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0.c b/tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0.c index 1cb9d45ea2e..8e457ec4ca8 100644 --- a/tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0.c +++ b/tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0.c @@ -13,6 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#if defined(ARDUINO) && !defined(ARDUINO_SFE_EDGE) +#define ARDUINO_EXCLUDE_CODE +#endif // defined(ARDUINO) && !defined(ARDUINO_SFE_EDGE) + +#ifndef ARDUINO_EXCLUDE_CODE + #include "HM01B0.h" #include "HM01B0_Walking1s_01.h" @@ -756,3 +762,5 @@ uint32_t hm01b0_single_frame_capture(hm01b0_cfg_t* psCfg) { HM01B0_REG_MODE_SELECT_STREAMING_NFRAMES, 1); hm01b0_write_reg(psCfg, HM01B0_REG_GRP_PARAM_HOLD, 0x01, 1); } + +#endif // ARDUINO_EXCLUDE_CODE diff --git a/tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0.h b/tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0.h index 49f01ddca58..e2561da6d10 100644 --- a/tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0.h +++ b/tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0.h @@ -16,12 +16,23 @@ limitations under the License. #ifndef TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_HIMAX_DRIVER_HM01B0_H_ #define TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_HIMAX_DRIVER_HM01B0_H_ +#if defined(ARDUINO) && !defined(ARDUINO_SFE_EDGE) +#define ARDUINO_EXCLUDE_CODE +#endif // defined(ARDUINO) && !defined(ARDUINO_SFE_EDGE) + #ifdef __cplusplus extern "C" { #endif + +#ifndef ARDUINO_EXCLUDE_CODE #include "am_bsp.h" // NOLINT #include "am_mcu_apollo.h" // NOLINT #include "am_util.h" // NOLINT +#endif // ARDUINO_EXCLUDE_CODE + +#if defined(ARDUINO) +#include "tensorflow/lite/micro/examples/person_detection/arduino/HM01B0_platform.h" +#endif // defined(ARDUINO) #define HM01B0_DRV_VERSION (0) #define HM01B0_DRV_SUBVERSION (3) diff --git a/tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0_debug.c b/tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0_debug.c index bf897850ec3..3a64b701a04 100644 --- a/tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0_debug.c +++ b/tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0_debug.c @@ -13,6 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#if defined(ARDUINO) && !defined(ARDUINO_SFE_EDGE) +#define ARDUINO_EXCLUDE_CODE +#endif // defined(ARDUINO) && !defined(ARDUINO_SFE_EDGE) + +#ifndef ARDUINO_EXCLUDE_CODE + #include "HM01B0_debug.h" #include "am_util.h" // NOLINT @@ -33,3 +39,4 @@ void hm01b0_framebuffer_dump(uint8_t* frame, uint32_t length) { am_util_delay_ms(1); } +#endif // ARDUINO_EXCLUDE_CODE diff --git a/tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0_optimized.c b/tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0_optimized.c index ec5a2c1c47b..0547ba82cdb 100644 --- a/tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0_optimized.c +++ b/tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0_optimized.c @@ -13,6 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#if defined(ARDUINO) && !defined(ARDUINO_SFE_EDGE) +#define ARDUINO_EXCLUDE_CODE +#endif // defined(ARDUINO) && !defined(ARDUINO_SFE_EDGE) + +#ifndef ARDUINO_EXCLUDE_CODE + #include "HM01B0.h" #include "am_bsp.h" //NOLINT #include "am_mcu_apollo.h" //NOLINT @@ -82,3 +88,5 @@ uint32_t hm01b0_blocking_read_oneframe_scaled( } return HM01B0_ERR_OK; } + +#endif // ARDUINO_EXCLUDE_CODE diff --git a/tensorflow/lite/micro/examples/person_detection/sparkfun_edge/detection_responder.cc b/tensorflow/lite/micro/examples/person_detection/sparkfun_edge/detection_responder.cc index 40625ef56bc..9025232b215 100644 --- a/tensorflow/lite/micro/examples/person_detection/sparkfun_edge/detection_responder.cc +++ b/tensorflow/lite/micro/examples/person_detection/sparkfun_edge/detection_responder.cc @@ -13,6 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#if defined(ARDUINO) && !defined(ARDUINO_SFE_EDGE) +#define ARDUINO_EXCLUDE_CODE +#endif // defined(ARDUINO) && !defined(ARDUINO_SFE_EDGE) + +#ifndef ARDUINO_EXCLUDE_CODE + #include "tensorflow/lite/micro/examples/person_detection/detection_responder.h" #include "am_bsp.h" // NOLINT @@ -47,3 +53,5 @@ void RespondToDetection(tflite::ErrorReporter* error_reporter, TF_LITE_REPORT_ERROR(error_reporter, "Person score: %d No person score: %d", person_score, no_person_score); } + +#endif // ARDUINO_EXCLUDE_CODE diff --git a/tensorflow/lite/micro/examples/person_detection/sparkfun_edge/image_provider.cc b/tensorflow/lite/micro/examples/person_detection/sparkfun_edge/image_provider.cc index 7226e59cd65..22c52651e7c 100644 --- a/tensorflow/lite/micro/examples/person_detection/sparkfun_edge/image_provider.cc +++ b/tensorflow/lite/micro/examples/person_detection/sparkfun_edge/image_provider.cc @@ -13,6 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#if defined(ARDUINO) +#include "tensorflow/lite/micro/examples/person_detection/arduino/HM01B0_platform.h" +#endif // defined(ARDUINO) + +#if defined(ARDUINO) && !defined(ARDUINO_SFE_EDGE) +#define ARDUINO_EXCLUDE_CODE +#endif // defined(ARDUINO) && !defined(ARDUINO_SFE_EDGE) + +#ifndef ARDUINO_EXCLUDE_CODE + #include "tensorflow/lite/micro/examples/person_detection/image_provider.h" #include "tensorflow/lite/micro/examples/person_detection/himax_driver/HM01B0.h" @@ -204,3 +214,5 @@ TfLiteStatus GetImage(tflite::ErrorReporter* error_reporter, int frame_width, return kTfLiteOk; } + +#endif // ARDUINO_EXCLUDE_CODE diff --git a/tensorflow/lite/micro/kernels/cmsis-nn/softmax.cc b/tensorflow/lite/micro/kernels/cmsis-nn/softmax.cc index ce45c1ae9b1..e9b76b3891f 100644 --- a/tensorflow/lite/micro/kernels/cmsis-nn/softmax.cc +++ b/tensorflow/lite/micro/kernels/cmsis-nn/softmax.cc @@ -81,10 +81,13 @@ void SoftmaxFloat(const TfLiteTensor* input, TfLiteTensor* output, void SoftmaxQuantized(const TfLiteTensor* input, TfLiteTensor* output, const SoftmaxParams& op_data) { + const auto input_shape = GetTensorShape(input); + const auto output_shape = GetTensorShape(output); + if (input->type == kTfLiteUInt8) { - tflite::reference_ops::Softmax( - op_data, GetTensorShape(input), GetTensorData(input), - GetTensorShape(output), GetTensorData(output)); + tflite::reference_ops::Softmax(op_data, input_shape, + GetTensorData(input), output_shape, + GetTensorData(output)); } else { const unsigned int num_dims = NumDimensions(input); @@ -117,7 +120,7 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { } case kTfLiteInt8: case kTfLiteUInt8: { - SoftmaxQuantized(input, output, params, op_data); + SoftmaxQuantized(input, output, op_data); return kTfLiteOk; } default: diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/conv.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini/conv.cc index e3bec9ddcb4..58c0c9fdd22 100644 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini/conv.cc +++ b/tensorflow/lite/micro/kernels/xtensa_hifimini/conv.cc @@ -183,8 +183,6 @@ constexpr int kMaxChannels = 256; // https://www.tensorflow.org/lite/performance/quantization_spec constexpr int kConvQuantizedDimension = 0; -const int kTensorNotAllocated = -1; - struct OpData { TfLitePaddingValues padding; // The scaling factor from input to output (aka the 'real multiplier') can @@ -253,7 +251,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = GetOutput(context, node, kOutputTensor); const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* filter = GetInput(context, node, kFilterTensor); - const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); // TODO(b/132070898): Use statically slotted OpData structures until a // scratch memory API is ready. diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/depthwise_conv.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini/depthwise_conv.cc index 03fd3969c97..32eaf72c68d 100644 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini/depthwise_conv.cc +++ b/tensorflow/lite/micro/kernels/xtensa_hifimini/depthwise_conv.cc @@ -253,11 +253,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); - TfLiteTensor* output = GetOutput(context, node, kOutputTensor); const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* filter = GetInput(context, node, kFilterTensor); - const TfLiteTensor* bias = - (NumInputs(node) == 3) ? GetInput(context, node, kBiasTensor) : nullptr; // TODO(b/132070898): Use statically slotted OpData structures until a // scratch memory API is ready. diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h b/tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h index bacea7c2eb4..f35ffaa741e 100644 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h +++ b/tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h @@ -16,9 +16,13 @@ limitations under the License. #ifndef TENSORFLOW_LITE_MICRO_KERNELS_XTENSA_HIFIMINI_FIXEDPOINT_UTILS_H_ #define TENSORFLOW_LITE_MICRO_KERNELS_XTENSA_HIFIMINI_FIXEDPOINT_UTILS_H_ -#include #include +#include +#include +#include + +#include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/micro/kernels/xtensa_hifimini/utils.h" namespace tflite { @@ -191,17 +195,17 @@ inline ae_q56s MultiplyByQuantizedMultiplier(ae_p24x2s x_24x2, // // Calculate quantization params for 24bit runtimes. // -inline void QuantizeMultiplier(double double_multiplier, - int32_t* quantized_multiplier, int* shift) { - if (double_multiplier == 0.) { +inline void QuantizeMultiplier(float multiplier, int32_t* quantized_multiplier, + int* shift) { + if (multiplier == 0.0f) { *quantized_multiplier = 0; *shift = 0; return; } // Special cased to 24bit: - const double q = std::frexp(double_multiplier, shift); - auto q_fixed = static_cast(TfLiteRound(q * (1 << 23))); + const float q = std::frexp(multiplier, shift); + auto q_fixed = static_cast(std::round(q * (1 << 23))); TFLITE_CHECK(q_fixed <= (1 << 23)); if (q_fixed == (1 << 23)) { @@ -221,11 +225,11 @@ inline void QuantizeMultiplier(double double_multiplier, // Convert a floating point number to a Q representation for 24 bit integers. // inline int CreateQConstantForInt24(int integer_bits, float f) { - const double min_bounds = static_cast(INT24_MIN); - const double max_bounds = static_cast(INT24_MAX); + const float min_bounds = static_cast(INT24_MIN); + const float max_bounds = static_cast(INT24_MAX); int fractional_bits = 23 - integer_bits; - double raw = std::round(f * static_cast(1 << fractional_bits)); + float raw = std::round(f * static_cast(1 << fractional_bits)); raw = std::max(raw, min_bounds); raw = std::min(raw, max_bounds); return static_cast(raw); diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/softmax.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini/softmax.cc index 47a2077fec1..e619d025dc1 100644 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini/softmax.cc +++ b/tensorflow/lite/micro/kernels/xtensa_hifimini/softmax.cc @@ -26,97 +26,14 @@ limitations under the License. namespace tflite { namespace ops { namespace micro { - -namespace xtensa { -namespace hifimini { - -// Quantized softmax with int8 input and int8/int16 output. -template -inline void Softmax(const SoftmaxParams& params, - const RuntimeShape& input_shape, const int8* input_data, - const RuntimeShape& output_shape, OutputT* output_data) { - const int32_t input_beta_multiplier = params.input_multiplier; - const int32_t input_beta_left_shift = params.input_left_shift; - const int diff_min = params.diff_min; - // The representation chosen for the input to the exp() function is Q5.26. - // We need to leave extra space since values that we skip might be as large as - // -32 before multiplying by input_beta_multiplier, and therefore as large as - // -16 afterwards. Note that exp(-8) is definitely not insignificant to - // accumulation, but exp(-16) definitely is. - static const int kScaledDiffIntegerBits = 5; - static const int kAccumulationIntegerBits = 12; - using FixedPointScaledDiff = - gemmlowp::FixedPoint; - using FixedPointAccum = - gemmlowp::FixedPoint; - using FixedPoint0 = gemmlowp::FixedPoint; - - const int trailing_dim = input_shape.DimensionsCount() - 1; - const int outer_size = - MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); - const int depth = - MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); - - for (int i = 0; i < outer_size; ++i) { - int8 max_in_row = -128; - for (int c = 0; c < depth; ++c) { - max_in_row = std::max(max_in_row, input_data[i * depth + c]); - } - - FixedPointAccum sum_of_exps = FixedPointAccum::Zero(); - for (int c = 0; c < depth; ++c) { - int32_t input_diff = - static_cast(input_data[i * depth + c]) - max_in_row; - if (input_diff >= diff_min) { - const int32_t input_diff_rescaled = - MultiplyByQuantizedMultiplierGreaterThanOne( - input_diff, input_beta_multiplier, input_beta_left_shift); - const FixedPointScaledDiff scaled_diff_f8 = - FixedPointScaledDiff::FromRaw(input_diff_rescaled); - sum_of_exps = sum_of_exps + gemmlowp::Rescale( - exp_on_negative_values(scaled_diff_f8)); - } - } - - int num_bits_over_unit; - FixedPoint0 shifted_scale = FixedPoint0::FromRaw(GetReciprocal( - sum_of_exps.raw(), kAccumulationIntegerBits, &num_bits_over_unit)); - - for (int c = 0; c < depth; ++c) { - int32_t input_diff = - static_cast(input_data[i * depth + c]) - max_in_row; - if (input_diff >= diff_min) { - const int32_t input_diff_rescaled = - MultiplyByQuantizedMultiplierGreaterThanOne( - input_diff, input_beta_multiplier, input_beta_left_shift); - const FixedPointScaledDiff scaled_diff_f8 = - FixedPointScaledDiff::FromRaw(input_diff_rescaled); - - FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8); - const int32_t unsat_output = gemmlowp::RoundingDivideByPOT( - (shifted_scale * exp_in_0).raw(), - num_bits_over_unit + 31 - (sizeof(OutputT) * 8)); - // TODO(b/148494470): Handle int32 shifts properly: - const int32_t shifted_output = - unsat_output - - (static_cast(std::numeric_limits::max()) + 1); - output_data[i * depth + c] = static_cast(std::max( - std::min(shifted_output, - static_cast(std::numeric_limits::max())), - static_cast(std::numeric_limits::min()))); - } else { - output_data[i * depth + c] = std::numeric_limits::min(); - } - } - } -} - -} // namespace hifimini -} // namespace xtensa - namespace activations { namespace { +// TODO(b/141176180): This code is currently a strict subset of the portable +// implementation (softmax.cc one directory up). When TFLM implements +// registrations for selective types (e.g. compile without float support), this +// can be removed. Otherwise, any HiFi specific optimizations should land here. + // This size will work for both the hotword (1) and ambient music (0): static SoftmaxParams kStaticOpData; @@ -143,7 +60,8 @@ TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context, int input_left_shift; tflite::PreprocessSoftmaxScaling( - params->beta, input->params.scale, kScaledDiffIntegerBits, + static_cast(params->beta), + static_cast(input->params.scale), kScaledDiffIntegerBits, &op_data->input_multiplier, &input_left_shift); op_data->input_left_shift = input_left_shift; op_data->diff_min = @@ -158,12 +76,11 @@ TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context, void SoftmaxQuantized(const TfLiteTensor* input, TfLiteTensor* output, const SoftmaxParams& op_params) { if (output->type == kTfLiteInt16) { - xtensa::hifimini::Softmax( + tflite::reference_ops::Softmax( op_params, GetTensorShape(input), GetTensorData(input), GetTensorShape(output), GetTensorData(output)); - } else { - xtensa::hifimini::Softmax( + tflite::reference_ops::Softmax( op_params, GetTensorShape(input), GetTensorData(input), GetTensorShape(output), GetTensorData(output)); } diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/svdf.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini/svdf.cc index 7f1ade86d35..6833b5fbd7d 100644 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini/svdf.cc +++ b/tensorflow/lite/micro/kernels/xtensa_hifimini/svdf.cc @@ -167,9 +167,10 @@ void EvalIntegerSVDF( const int16_t* vector1_ptr = GetTensorData(weights_time_tensor); const int16_t* vector2_ptr = state_ptr + b * n_memory * n_filter; - int num_iters = n_filter / 2; - const ae_p16x2s* offset_vector1 = (const ae_p16x2s*)(vector1_ptr - 2); - const ae_p16x2s* offset_vector2 = (const ae_p16x2s*)(vector2_ptr - 2); + const ae_p16x2s* offset_vector1 = + reinterpret_cast(vector1_ptr - 2); + const ae_p16x2s* offset_vector2 = + reinterpret_cast(vector2_ptr - 2); for (int i = 0; i < n_filter; i++) { *scratch_ptr_batch = 0; @@ -238,7 +239,6 @@ void EvalIntegerSVDF( // Cap min/max and convert to int32 (already aligned to 32bit): x_56 = AE_MAXQ56S(x_56, output_int8_min_56); x_56 = AE_MINQ56S(x_56, output_int8_max_56); - int32_t x_32 = AE_TRUNCA32Q48(x_56); GetTensorData(output_tensor)[i] = static_cast(AE_TRUNCA32Q48(x_56)); } @@ -361,12 +361,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { weights_time->quantization.params); auto* output_params = reinterpret_cast(output->quantization.params); - const double effective_scale_1 = input_params->scale->data[0] * - weights_feature_params->scale->data[0] / - state_params->scale->data[0]; - const double effective_scale_2 = state_params->scale->data[0] * - weight_time_params->scale->data[0] / - output_params->scale->data[0]; + const float effective_scale_1 = input_params->scale->data[0] * + weights_feature_params->scale->data[0] / + state_params->scale->data[0]; + const float effective_scale_2 = state_params->scale->data[0] * + weight_time_params->scale->data[0] / + output_params->scale->data[0]; xtensa::hifimini::QuantizeMultiplier(effective_scale_1, &op_data->effective_scale_1_a, &op_data->effective_scale_1_b); diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/utils.h b/tensorflow/lite/micro/kernels/xtensa_hifimini/utils.h index 47170085b9f..59caf4bbf2f 100644 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini/utils.h +++ b/tensorflow/lite/micro/kernels/xtensa_hifimini/utils.h @@ -18,6 +18,8 @@ limitations under the License. #include +#include + // INT24 MIN/MAX #define INT24_MIN -8388608 #define INT24_MAX 8388607 @@ -28,9 +30,9 @@ limitations under the License. // the "signed" or upper 8bits are discarded. inline ae_p24x2s AE_CONVERT_INT32_24x2(int32_t v) { if (v > INT24_MIN && v < INT24_MAX) { - return *((ae_p24s*)&v); + return *reinterpret_cast(&v); } else { - return (ae_p24s) * ((ae_p24f*)&v); + return static_cast(*reinterpret_cast(&v)); } } diff --git a/tensorflow/lite/micro/micro_interpreter.cc b/tensorflow/lite/micro/micro_interpreter.cc index 4c6b21c99a0..09aa1dc1d08 100644 --- a/tensorflow/lite/micro/micro_interpreter.cc +++ b/tensorflow/lite/micro/micro_interpreter.cc @@ -257,32 +257,30 @@ TfLiteStatus MicroInterpreter::Invoke() { } TfLiteTensor* MicroInterpreter::input(size_t index) { - const flatbuffers::Vector* inputs = subgraph_->inputs(); - const size_t length = inputs->size(); + const size_t length = inputs_size(); if ((index < 0) || (index >= length)) { TF_LITE_REPORT_ERROR(error_reporter_, "Input index %d out of range (length is %d)", index, length); return nullptr; } - return &(context_.tensors[inputs->Get(index)]); + return &(context_.tensors[inputs().Get(index)]); } TfLiteTensor* MicroInterpreter::output(size_t index) { - const flatbuffers::Vector* outputs = subgraph_->outputs(); - const size_t length = outputs->size(); - if ((index < 0) || (index >= outputs->size())) { + const size_t length = outputs_size(); + if ((index < 0) || (index >= length)) { TF_LITE_REPORT_ERROR(error_reporter_, "Output index %d out of range (length is %d)", index, length); return nullptr; } - return &(context_.tensors[outputs->Get(index)]); + return &(context_.tensors[outputs().Get(index)]); } TfLiteTensor* MicroInterpreter::tensor(size_t index) { const size_t length = tensors_size(); - if ((index < 0) || (index >= tensors_size())) { + if ((index < 0) || (index >= length)) { TF_LITE_REPORT_ERROR(error_reporter_, "Tensor index %d out of range (length is %d)", index, length); diff --git a/tensorflow/lite/micro/micro_optional_debug_tools.cc b/tensorflow/lite/micro/micro_optional_debug_tools.cc index bc69eb55315..70f16c78d79 100644 --- a/tensorflow/lite/micro/micro_optional_debug_tools.cc +++ b/tensorflow/lite/micro/micro_optional_debug_tools.cc @@ -77,6 +77,8 @@ const char* TensorTypeName(TfLiteType type) { return "kTfLiteComplex64"; case kTfLiteFloat16: return "kTfLiteFloat16"; + case kTfLiteFloat64: + return "kTfLiteFloat64"; } return "(invalid)"; } diff --git a/tensorflow/lite/micro/tools/make/Makefile b/tensorflow/lite/micro/tools/make/Makefile index 4e8c2ae5758..e8b44bcbea6 100644 --- a/tensorflow/lite/micro/tools/make/Makefile +++ b/tensorflow/lite/micro/tools/make/Makefile @@ -48,20 +48,23 @@ INCLUDES := \ -I. \ -I$(MAKEFILE_DIR)/downloads/ \ -I$(MAKEFILE_DIR)/downloads/gemmlowp \ --I$(MAKEFILE_DIR)/downloads/flatbuffers/include +-I$(MAKEFILE_DIR)/downloads/flatbuffers/include \ +-I$(MAKEFILE_DIR)/downloads/ruy # Same list of paths, but now relative to the generated project files. GENERATED_PROJECT_INCLUDES := \ -I. \ -I./third_party/gemmlowp \ --I./third_party/flatbuffers/include +-I./third_party/flatbuffers/include \ +-I./third_party/ruy # Same list of paths, but now in the format the generate_keil_project.py # script expects them. PROJECT_INCLUDES := \ . \ third_party/gemmlowp \ -third_party/flatbuffers/include +third_party/flatbuffers/include \ +third_party/ruy TEST_SCRIPT := tensorflow/lite/micro/testing/test_linux_binary.sh @@ -99,7 +102,7 @@ $(wildcard tensorflow/lite/micro/kernels/*test.cc) \ $(wildcard tensorflow/lite/micro/memory_planner/*test.cc) MICROLITE_BENCHMARK_SRCS := \ -$(wildcard tensorflow/lite/micro/kernels/xtensa-hifimini/*benchmark.cc) +$(wildcard tensorflow/lite/micro/kernels/benchmarks/*.cc) MICROLITE_TEST_HDRS := \ $(wildcard tensorflow/lite/micro/testing/*.h) @@ -132,7 +135,6 @@ tensorflow/lite/core/api/error_reporter.h \ tensorflow/lite/core/api/flatbuffer_conversions.h \ tensorflow/lite/core/api/op_resolver.h \ tensorflow/lite/core/api/tensor_utils.h \ -tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h \ tensorflow/lite/kernels/internal/common.h \ tensorflow/lite/kernels/internal/compatibility.h \ tensorflow/lite/kernels/internal/optimized/neon_check.h \ @@ -192,7 +194,9 @@ third_party/gemmlowp/LICENSE \ third_party/flatbuffers/include/flatbuffers/base.h \ third_party/flatbuffers/include/flatbuffers/stl_emulation.h \ third_party/flatbuffers/include/flatbuffers/flatbuffers.h \ -third_party/flatbuffers/LICENSE.txt +third_party/flatbuffers/LICENSE.txt \ +third_party/ruy/ruy/profiler/instrumentation.h + MAKE_PROJECT_FILES := \ Makefile \ diff --git a/tensorflow/lite/micro/tools/make/fix_arduino_subfolders.py b/tensorflow/lite/micro/tools/make/fix_arduino_subfolders.py index fce809cd65c..8676794d3c5 100755 --- a/tensorflow/lite/micro/tools/make/fix_arduino_subfolders.py +++ b/tensorflow/lite/micro/tools/make/fix_arduino_subfolders.py @@ -28,7 +28,7 @@ import six def rename_example_subfolder_files(library_dir): """Moves source files in example subfolders to equivalents at root.""" - patterns = ['*.h', '*.cpp'] + patterns = ['*.h', '*.cpp', '*.c'] for pattern in patterns: search_path = os.path.join(library_dir, 'examples/*/*', pattern) for source_file_path in glob.glob(search_path): diff --git a/tensorflow/lite/micro/tools/make/helper_functions.inc b/tensorflow/lite/micro/tools/make/helper_functions.inc index 184a0293ad7..aee04c63256 100644 --- a/tensorflow/lite/micro/tools/make/helper_functions.inc +++ b/tensorflow/lite/micro/tools/make/helper_functions.inc @@ -163,6 +163,14 @@ endef # can invoke to create the standalone project. define generate_arduino_project +$(PRJDIR)$(2)/arduino/examples/%.c: tensorflow/lite/micro/examples/%.c + @mkdir -p $$(dir $$@) + @python tensorflow/lite/micro/tools/make/transform_source.py \ + --platform=arduino \ + --is_example_source \ + --source_path="$$<" \ + --third_party_headers="$(4)" < $$< > $$@ + $(PRJDIR)$(2)/arduino/examples/%.cpp: tensorflow/lite/micro/examples/%.cc @mkdir -p $$(dir $$@) @python tensorflow/lite/micro/tools/make/transform_source.py \ diff --git a/tensorflow/lite/micro/tools/make/third_party_downloads.inc b/tensorflow/lite/micro/tools/make/third_party_downloads.inc index c4ff652a0ff..189d758eb96 100644 --- a/tensorflow/lite/micro/tools/make/third_party_downloads.inc +++ b/tensorflow/lite/micro/tools/make/third_party_downloads.inc @@ -48,6 +48,9 @@ SIFIVE_FE310_LIB_MD5 := "06ee24c4956f8e21670ab3395861fe64" KISSFFT_URL="https://github.com/mborgerding/kissfft/archive/v130.zip" KISSFFT_MD5="438ba1fef5783cc5f5f201395cc477ca" +RUY_URL="https://github.com/google/ruy/archive/91d62808498cea7ccb48aa59181e218b4ad05701.zip" +RUY_MD5="5e653ae8863408ede2a0ca104fea5b1e" + PERSON_MODEL_URL := "https://storage.googleapis.com/download.tensorflow.org/data/tf_lite_micro_person_data_grayscale_2019_11_21.zip" PERSON_MODEL_MD5 := "fe2934bd0788f1dcc7af3f0a954542ab" diff --git a/tensorflow/lite/nnapi/nnapi_handler.cc b/tensorflow/lite/nnapi/nnapi_handler.cc index c26b18d4ee7..dbbe7d3046c 100644 --- a/tensorflow/lite/nnapi/nnapi_handler.cc +++ b/tensorflow/lite/nnapi/nnapi_handler.cc @@ -50,8 +50,41 @@ void NnApiHandler::Reset() { *nnapi_ = *NnApiPassthroughInstance(); } -void NnApiHandler::SetAndroidSdkVersion(int version) { +void NnApiHandler::SetAndroidSdkVersion(int version, + bool set_unsupported_ops_to_null) { nnapi_->android_sdk_version = version; + + if (!set_unsupported_ops_to_null) { + return; + } + + if (version < 29) { + nnapi_->ANeuralNetworks_getDeviceCount = nullptr; + nnapi_->ANeuralNetworks_getDevice = nullptr; + nnapi_->ANeuralNetworksDevice_getName = nullptr; + nnapi_->ANeuralNetworksDevice_getVersion = nullptr; + nnapi_->ANeuralNetworksDevice_getFeatureLevel = nullptr; + nnapi_->ANeuralNetworksDevice_getType = nullptr; + nnapi_->ANeuralNetworksModel_getSupportedOperationsForDevices = nullptr; + nnapi_->ANeuralNetworksCompilation_createForDevices = nullptr; + nnapi_->ANeuralNetworksCompilation_setCaching = nullptr; + nnapi_->ANeuralNetworksExecution_compute = nullptr; + nnapi_->ANeuralNetworksExecution_getOutputOperandRank = nullptr; + nnapi_->ANeuralNetworksExecution_getOutputOperandDimensions = nullptr; + nnapi_->ANeuralNetworksBurst_create = nullptr; + nnapi_->ANeuralNetworksBurst_free = nullptr; + nnapi_->ANeuralNetworksExecution_burstCompute = nullptr; + nnapi_->ANeuralNetworksMemory_createFromAHardwareBuffer = nullptr; + nnapi_->ANeuralNetworksExecution_setMeasureTiming = nullptr; + nnapi_->ANeuralNetworksExecution_getDuration = nullptr; + nnapi_->ANeuralNetworksDevice_getExtensionSupport = nullptr; + nnapi_->ANeuralNetworksModel_getExtensionOperandType = nullptr; + nnapi_->ANeuralNetworksModel_getExtensionOperationType = nullptr; + nnapi_->ANeuralNetworksModel_setOperandExtensionData = nullptr; + } + if (version < 28) { + nnapi_->ANeuralNetworksModel_relaxComputationFloat32toFloat16 = nullptr; + } } void NnApiHandler::SetDeviceName(const std::string& name) { diff --git a/tensorflow/lite/nnapi/nnapi_handler.h b/tensorflow/lite/nnapi/nnapi_handler.h index 1ccdae5a214..00c0b23e3cf 100644 --- a/tensorflow/lite/nnapi/nnapi_handler.h +++ b/tensorflow/lite/nnapi/nnapi_handler.h @@ -252,7 +252,29 @@ class NnApiHandler { nnapi_->ANeuralNetworksModel_getSupportedOperationsForDevices = stub; } - void SetAndroidSdkVersion(int version); + template + void ExecutionStartComputeReturns() { + nnapi_->ANeuralNetworksExecution_startCompute = + [](ANeuralNetworksExecution* execution, ANeuralNetworksEvent** event) { + *event = reinterpret_cast(1); + return Value; + }; + } + + template + void EventWaitReturns() { + nnapi_->ANeuralNetworksEvent_wait = [](ANeuralNetworksEvent* event) { + return Value; + }; + } + + /* + * Sets the SDK Version in the nnapi structure. + * If set_unsupported_ops_to_null is set to true, all the functions not + * available at the given sdk level will be set to null too. + */ + void SetAndroidSdkVersion(int version, + bool set_unsupported_ops_to_null = false); const NnApi* GetNnApi() { return nnapi_; } diff --git a/tensorflow/lite/nnapi/nnapi_handler_test.cc b/tensorflow/lite/nnapi/nnapi_handler_test.cc index aea766ef036..e6fb410bb94 100644 --- a/tensorflow/lite/nnapi/nnapi_handler_test.cc +++ b/tensorflow/lite/nnapi/nnapi_handler_test.cc @@ -85,6 +85,93 @@ TEST_F(NnApiHandlerTest, ShouldSupportPassthroughCalls) { EXPECT_THAT(nnapi->ANeuralNetworks_getDeviceCount(&device_count), Eq(1)); } +TEST_F(NnApiHandlerTest, ShouldSetNnApiMembersToNullAsPerSdkVersion_NNAPI11) { + auto* handler = NnApiHandler::Instance(); + + // Setting non null values for nnapi functions + handler->SetNnapiSupportedDevice("devvice", 1000); + handler->GetSupportedOperationsForDevicesReturns<1>(); + handler->CompilationCreateForDevicesReturns<1>(); + handler->ExecutionComputeReturns<1>(); + handler->MemoryCreateFromFdReturns<1>(); + + handler->SetAndroidSdkVersion(28, /*set_unsupported_ops_to_null=*/true); + + const NnApi* nnapi = NnApiImplementation(); + + using ::testing::IsNull; + + EXPECT_THAT(nnapi->ANeuralNetworks_getDeviceCount, IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworks_getDevice, IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksDevice_getName, IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksDevice_getVersion, IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksDevice_getFeatureLevel, IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksDevice_getType, IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksModel_getSupportedOperationsForDevices, + IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksCompilation_createForDevices, IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksCompilation_setCaching, IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksExecution_compute, IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksExecution_getOutputOperandRank, IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksExecution_getOutputOperandDimensions, + IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksBurst_create, IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksBurst_free, IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksExecution_burstCompute, IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksMemory_createFromAHardwareBuffer, IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksExecution_setMeasureTiming, IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksExecution_getDuration, IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksDevice_getExtensionSupport, IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksModel_getExtensionOperandType, IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksModel_getExtensionOperationType, IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksModel_setOperandExtensionData, IsNull()); +} + +TEST_F(NnApiHandlerTest, ShouldSetNnApiMembersToNullAsPerSdkVersion_NNAPI10) { + auto* handler = NnApiHandler::Instance(); + + // Setting non null values for nnapi functions + handler->SetNnapiSupportedDevice("devvice", 1000); + handler->GetSupportedOperationsForDevicesReturns<1>(); + handler->CompilationCreateForDevicesReturns<1>(); + handler->ExecutionComputeReturns<1>(); + handler->MemoryCreateFromFdReturns<1>(); + + handler->SetAndroidSdkVersion(27, /*set_unsupported_ops_to_null=*/true); + + const NnApi* nnapi = NnApiImplementation(); + + using ::testing::IsNull; + + EXPECT_THAT(nnapi->ANeuralNetworks_getDeviceCount, IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworks_getDevice, IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksDevice_getName, IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksDevice_getVersion, IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksDevice_getFeatureLevel, IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksDevice_getType, IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksModel_getSupportedOperationsForDevices, + IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksCompilation_createForDevices, IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksCompilation_setCaching, IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksExecution_compute, IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksExecution_getOutputOperandRank, IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksExecution_getOutputOperandDimensions, + IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksBurst_create, IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksBurst_free, IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksExecution_burstCompute, IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksMemory_createFromAHardwareBuffer, IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksExecution_setMeasureTiming, IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksExecution_getDuration, IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksDevice_getExtensionSupport, IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksModel_getExtensionOperandType, IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksModel_getExtensionOperationType, IsNull()); + EXPECT_THAT(nnapi->ANeuralNetworksModel_setOperandExtensionData, IsNull()); + + EXPECT_THAT(nnapi->ANeuralNetworksModel_relaxComputationFloat32toFloat16, + IsNull()); +} + void ExpectEquals(const NnApi& left, const NnApi& right) { #define EXPECT_NNAPI_MEMBER_EQ(name) EXPECT_EQ(left.name, right.name) diff --git a/tensorflow/lite/optional_debug_tools.cc b/tensorflow/lite/optional_debug_tools.cc index 4e9b7d4e0a4..c5ccdb98390 100644 --- a/tensorflow/lite/optional_debug_tools.cc +++ b/tensorflow/lite/optional_debug_tools.cc @@ -59,6 +59,8 @@ const char* TensorTypeName(TfLiteType type) { return "kTfLiteComplex64"; case kTfLiteFloat16: return "kTfLiteFloat16"; + case kTfLiteFloat64: + return "kTfLiteFloat64"; } return "(invalid)"; } diff --git a/tensorflow/lite/python/convert.py b/tensorflow/lite/python/convert.py index 1744defea94..89b0a91f665 100644 --- a/tensorflow/lite/python/convert.py +++ b/tensorflow/lite/python/convert.py @@ -257,7 +257,11 @@ def build_toco_convert_protos(input_tensors, target_ops=None, allow_nonexistent_arrays=False, debug_info=None, - conversion_summary_dir=None): + conversion_summary_dir=None, + saved_model_dir=None, + saved_model_version=0, + saved_model_tags=None, + saved_model_exported_names=None): """Builds protocol buffers describing a conversion of a model using TOCO. Typically this is to convert from TensorFlow GraphDef to TFLite, in which @@ -323,6 +327,18 @@ def build_toco_convert_protos(input_tensors, debug_info: `GraphDebugInfo` proto containing the stack traces for the original nodes referred by the converted graph. conversion_summary_dir: A string, the path to the generated conversion logs. + saved_model_dir: Filepath of the saved model to be converted. This value + will be non-empty only when the saved model import path will be used. + Otherwises, the graph def-based conversion will be processed. + saved_model_version: SavedModel file format version of The saved model file + to be converted. This value will be set only when the SavedModel import + path will be used. + saved_model_tags: Set of string saved model tags, formatted in the + comma-separated value. This value will be set only when the SavedModel + import path will be used. + saved_model_exported_names: Names to be exported (default: export all) when + the saved model import path is on. This value will be set only when the + SavedModel import path will be used. Returns: model_flags, toco_flags, debug_info: three protocol buffers describing the @@ -397,6 +413,14 @@ def build_toco_convert_protos(input_tensors, model.allow_nonexistent_arrays = allow_nonexistent_arrays + if saved_model_dir: + model.saved_model_dir = saved_model_dir + model.saved_model_version = saved_model_version + if saved_model_tags: + model.saved_model_tags.extend(saved_model_tags) + if saved_model_exported_names: + model.saved_model_exported_names.extend(saved_model_exported_names) + return model, toco, debug_info diff --git a/tensorflow/lite/python/interpreter_wrapper/numpy.cc b/tensorflow/lite/python/interpreter_wrapper/numpy.cc index b8c6555c285..00e5064e620 100644 --- a/tensorflow/lite/python/interpreter_wrapper/numpy.cc +++ b/tensorflow/lite/python/interpreter_wrapper/numpy.cc @@ -38,6 +38,8 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) { return NPY_FLOAT32; case kTfLiteFloat16: return NPY_FLOAT16; + case kTfLiteFloat64: + return NPY_FLOAT64; case kTfLiteInt32: return NPY_INT32; case kTfLiteInt16: diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index ba9e6e0bd39..97d3f2a1ec6 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -74,6 +74,7 @@ from tensorflow.python.lib.io import file_io as _file_io from tensorflow.python.saved_model import signature_constants as _signature_constants from tensorflow.python.saved_model import tag_constants as _tag_constants from tensorflow.python.saved_model.load import load as _load +from tensorflow.python.saved_model.loader_impl import parse_saved_model_with_debug_info as _parse_saved_model_with_debug_info from tensorflow.python.util import deprecation as _deprecation from tensorflow.python.util.tf_export import tf_export as _tf_export @@ -285,6 +286,10 @@ class TFLiteConverterBase(object): # The 'GraphDebugInfo' contains the stack traces of all the original nodes # in the `GraphDef` to the converter. self._debug_info = None + self._saved_model_dir = None + self._saved_model_tags = None + self._saved_model_version = None + self._saved_model_exported_names = [] def _grappler_config(self, optimizers=None): """Creates a tf.compat.v1.ConfigProto for configuring Grappler. @@ -346,8 +351,46 @@ class TFLiteConverterBase(object): "target_ops": self.target_spec.supported_ops, "enable_mlir_converter": self.experimental_new_converter, } + + if self._saved_model_dir: + args.update({ + "saved_model_dir": self._saved_model_dir, + "saved_model_version": self._saved_model_version, + "saved_model_tags": self._saved_model_tags, + "saved_model_exported_names": self._saved_model_exported_names, + }) + return args + def _contains_function_with_implements_attr(self, saved_model_proto): + meta_graph = saved_model_proto.meta_graphs[0] + for function in meta_graph.graph_def.library.function: + if function.attr.get("_implements", None) or function.attr.get( + "api_implements", None): + return True + return False + + def _parse_saved_model_args(self): + """Parses SavedModel arguments from the given Keras/RNN SavedModel.""" + if self._saved_model_dir: + try: + saved_model_proto, _ = ( + _parse_saved_model_with_debug_info(self._saved_model_dir)) + except OSError: + # If it fails to read the given saved model, it will fall back to the + # frozen graph def path. + self._saved_model_dir = None + return + if not self._contains_function_with_implements_attr(saved_model_proto): + self._saved_model_dir = None + else: + self._saved_model_exported_names = [] + self._saved_model_version = saved_model_proto.saved_model_schema_version + if self._saved_model_version not in [1, 2]: + raise ValueError( + "SavedModel file format({0}) is not supported".format( + self._saved_model_version)) + @_tf_export("lite.TFLiteConverter", v1=[]) class TFLiteConverterV2(TFLiteConverterBase): @@ -387,7 +430,11 @@ class TFLiteConverterV2(TFLiteConverterBase): ``` """ - def __init__(self, funcs, trackable_obj=None): + def __init__(self, + funcs, + trackable_obj=None, + saved_model_dir=None, + saved_model_tags=None): """Constructor for TFLiteConverter. Args: @@ -398,10 +445,19 @@ class TFLiteConverterV2(TFLiteConverterBase): get garbage collected since functions have a weak reference to Variables. This is only required when the tf.AutoTrackable object is not maintained by the user (e.g. `from_saved_model`). + saved_model_dir: Directory of the SavedModel. This argument can be null + when it creates via the from_keras_model and from_concrete_function + methods. + saved_model_tags: Set of tags identifying the MetaGraphDef within the + SavedModel to analyze. All tags in the tag set must be present. (default + set(SERVING)). This argument will be available when the saved model dir + argument is set. """ super(TFLiteConverterV2, self).__init__() self._funcs = funcs self._trackable_obj = trackable_obj + self._saved_model_dir = saved_model_dir + self._saved_model_tags = saved_model_tags @classmethod def from_concrete_functions(cls, funcs): @@ -463,6 +519,9 @@ class TFLiteConverterV2(TFLiteConverterBase): # Ensures any graphs created in Eager mode are able to run. This is required # in order to create a tf.estimator.Exporter that exports a TFLite model. + if tags is None: + tags = set([_tag_constants.SERVING]) + with context.eager_mode(): saved_model = _load(saved_model_dir, tags) if not signature_keys: @@ -475,7 +534,7 @@ class TFLiteConverterV2(TFLiteConverterBase): "'{}'.".format(key, ",".join(saved_model.signatures))) funcs.append(saved_model.signatures[key]) - return cls(funcs, saved_model) + return cls(funcs, saved_model, saved_model_dir, tags) @classmethod def from_keras_model(cls, model): @@ -521,6 +580,9 @@ class TFLiteConverterV2(TFLiteConverterBase): "ConcreteFunction. Converting multiple functions is " "under development.") + # Parses SavedModel argument. + self._parse_saved_model_args() + # graph_def is used here to preserve the node bug information frozen_func, graph_def = ( _convert_to_constants.convert_variables_to_constants_v2_as_graph( @@ -693,6 +755,7 @@ class TFLiteConverter(TFLiteConverterBase): the dataset to evaluate different optimizations. experimental_new_converter: Experimental flag, subject to change. Enables MLIR-based conversion instead of TOCO conversion. + Example usage: ```python @@ -725,7 +788,9 @@ class TFLiteConverter(TFLiteConverterBase): output_tensors, input_arrays_with_shape=None, output_arrays=None, - experimental_debug_info_func=None): + experimental_debug_info_func=None, + saved_model_dir=None, + saved_model_tags=None): """Constructor for TFLiteConverter. Args: @@ -743,6 +808,13 @@ class TFLiteConverter(TFLiteConverterBase): `output_tensors` are None. (default None) experimental_debug_info_func: An experimental function to retrieve the graph debug info for a set of nodes from the `graph_def`. + saved_model_dir: Directory of the SavedModel. This argument can be null + when it creates via the from_keras_model and from_concrete_function + methods. + saved_model_tags: Set of tags identifying the MetaGraphDef within the + SavedModel to analyze. All tags in the tag set must be present. (default + set(SERVING)). This argument will be available when the saved model dir + argument is set. Raises: ValueError: Invalid arguments. @@ -766,6 +838,8 @@ class TFLiteConverter(TFLiteConverterBase): self.conversion_summary_dir = None self._debug_info_func = experimental_debug_info_func self._custom_opdefs = None + self._saved_model_dir = saved_model_dir + self._saved_model_tags = saved_model_tags # Attributes are used by models that cannot be loaded into TensorFlow. if not self._has_valid_tensors(): @@ -928,7 +1002,9 @@ class TFLiteConverter(TFLiteConverterBase): graph_def=result[0], input_tensors=result[1], output_tensors=result[2], - experimental_debug_info_func=_build_debug_info_func(result[3])) + experimental_debug_info_func=_build_debug_info_func(result[3]), + saved_model_dir=saved_model_dir, + saved_model_tags=tag_set) @classmethod def from_keras_model_file(cls, @@ -1059,6 +1135,9 @@ class TFLiteConverter(TFLiteConverterBase): Input shape is not specified. None value for dimension in input_tensor. """ + # Parses SavedModel argument. + self._parse_saved_model_args() + quant_mode = QuantizationMode(self.optimizations, self.target_spec, self.representative_dataset, self._graph_def) diff --git a/tensorflow/lite/python/optimize/calibration_wrapper.cc b/tensorflow/lite/python/optimize/calibration_wrapper.cc index cadd5538f5b..5a7a3ae2aa5 100644 --- a/tensorflow/lite/python/optimize/calibration_wrapper.cc +++ b/tensorflow/lite/python/optimize/calibration_wrapper.cc @@ -68,6 +68,8 @@ inline TensorType TfLiteTypeToSchemaType(TfLiteType type) { return TensorType_FLOAT32; case kTfLiteFloat16: return TensorType_FLOAT16; + case kTfLiteFloat64: + return TensorType_FLOAT64; case kTfLiteInt32: return TensorType_INT32; case kTfLiteUInt8: diff --git a/tensorflow/lite/python/util.py b/tensorflow/lite/python/util.py index c22970a2ab4..faf9ba611ed 100644 --- a/tensorflow/lite/python/util.py +++ b/tensorflow/lite/python/util.py @@ -43,6 +43,7 @@ from tensorflow.python.training.saver import export_meta_graph as _export_meta_g _MAP_TF_TO_TFLITE_TYPES = { dtypes.float32: _types_pb2.FLOAT, dtypes.float16: _types_pb2.FLOAT16, + dtypes.float64: _types_pb2.FLOAT64, dtypes.int32: _types_pb2.INT32, dtypes.int64: _types_pb2.INT64, dtypes.string: _types_pb2.STRING, diff --git a/tensorflow/lite/schema/schema.fbs b/tensorflow/lite/schema/schema.fbs index 24cd73eef7a..32ccbe8cbee 100644 --- a/tensorflow/lite/schema/schema.fbs +++ b/tensorflow/lite/schema/schema.fbs @@ -40,6 +40,7 @@ enum TensorType : byte { INT16 = 7, COMPLEX64 = 8, INT8 = 9, + FLOAT64 = 10, } // Custom quantization parameters for experimenting with new quantization @@ -519,17 +520,22 @@ table LSHProjectionOptions { table SVDFOptions { rank:int; fused_activation_function:ActivationFunctionType; + // For weights-only quantization, use asymmetric quantization for non + // constant inputs at evaluation time. + asymmetric_quantize_inputs:bool; } // An implementation of TensorFlow RNNCell. table RNNOptions { fused_activation_function:ActivationFunctionType; + asymmetric_quantize_inputs:bool; } // An implementation of TensorFlow dynamic_rnn with RNNCell. table SequenceRNNOptions { time_major:bool; fused_activation_function:ActivationFunctionType; + asymmetric_quantize_inputs:bool; } // An implementation of TensorFlow bidrectional_dynamic_rnn with RNNCell. @@ -537,6 +543,7 @@ table BidirectionalSequenceRNNOptions { time_major:bool; fused_activation_function:ActivationFunctionType; merge_outputs: bool; + asymmetric_quantize_inputs:bool; } enum FullyConnectedOptionsWeightsFormat: byte { @@ -556,6 +563,11 @@ table FullyConnectedOptions { // If set to true, then the number of dimension is preserved. Furthermore, // all but the last dimension of the input and output shapes will be equal. keep_num_dims: bool; + + // Parameters for FullyConnected version 7 or above. + // If set to true, then weights-only op will use asymmetric quantization for + // inputs. + asymmetric_quantize_inputs: bool; } table SoftmaxOptions { @@ -604,6 +616,9 @@ table LSTMOptions { // Parameters for LSTM version 2 or above. // Basic kernel is only supported in version 2 or above. kernel_type: LSTMKernelType = FULL; + + // Parameters for LSTM version 4 or above. + asymmetric_quantize_inputs: bool; } // An implementation of TensorFlow dynamic_rnn with LSTMCell. @@ -614,6 +629,9 @@ table UnidirectionalSequenceLSTMOptions { // If true then first dimension is sequence, otherwise batch. time_major:bool; + + // Parameter for Unidirectional Sequence LSTM version 4. + asymmetric_quantize_inputs:bool; } table BidirectionalSequenceLSTMOptions { @@ -630,6 +648,9 @@ table BidirectionalSequenceLSTMOptions { // Version 1 implementations assumed time_major to be true, so this default // value should never change. time_major: bool = true; + + // Parameters for version 3 or above. + asymmetric_quantize_inputs:bool; } table ResizeBilinearOptions { diff --git a/tensorflow/lite/schema/schema_generated.h b/tensorflow/lite/schema/schema_generated.h index 609eac198fb..34b36bf7354 100755 --- a/tensorflow/lite/schema/schema_generated.h +++ b/tensorflow/lite/schema/schema_generated.h @@ -378,11 +378,12 @@ enum TensorType { TensorType_INT16 = 7, TensorType_COMPLEX64 = 8, TensorType_INT8 = 9, + TensorType_FLOAT64 = 10, TensorType_MIN = TensorType_FLOAT32, - TensorType_MAX = TensorType_INT8 + TensorType_MAX = TensorType_FLOAT64 }; -inline const TensorType (&EnumValuesTensorType())[10] { +inline const TensorType (&EnumValuesTensorType())[11] { static const TensorType values[] = { TensorType_FLOAT32, TensorType_FLOAT16, @@ -393,13 +394,14 @@ inline const TensorType (&EnumValuesTensorType())[10] { TensorType_BOOL, TensorType_INT16, TensorType_COMPLEX64, - TensorType_INT8 + TensorType_INT8, + TensorType_FLOAT64 }; return values; } inline const char * const *EnumNamesTensorType() { - static const char * const names[11] = { + static const char * const names[12] = { "FLOAT32", "FLOAT16", "INT32", @@ -410,13 +412,14 @@ inline const char * const *EnumNamesTensorType() { "INT16", "COMPLEX64", "INT8", + "FLOAT64", nullptr }; return names; } inline const char *EnumNameTensorType(TensorType e) { - if (flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_INT8)) return ""; + if (flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_FLOAT64)) return ""; const size_t index = static_cast(e); return EnumNamesTensorType()[index]; } @@ -4216,9 +4219,11 @@ struct SVDFOptionsT : public flatbuffers::NativeTable { typedef SVDFOptions TableType; int32_t rank; tflite::ActivationFunctionType fused_activation_function; + bool asymmetric_quantize_inputs; SVDFOptionsT() : rank(0), - fused_activation_function(tflite::ActivationFunctionType_NONE) { + fused_activation_function(tflite::ActivationFunctionType_NONE), + asymmetric_quantize_inputs(false) { } }; @@ -4226,7 +4231,8 @@ struct SVDFOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef SVDFOptionsT NativeTableType; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_RANK = 4, - VT_FUSED_ACTIVATION_FUNCTION = 6 + VT_FUSED_ACTIVATION_FUNCTION = 6, + VT_ASYMMETRIC_QUANTIZE_INPUTS = 8 }; int32_t rank() const { return GetField(VT_RANK, 0); @@ -4234,10 +4240,14 @@ struct SVDFOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { tflite::ActivationFunctionType fused_activation_function() const { return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } + bool asymmetric_quantize_inputs() const { + return GetField(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0; + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_RANK) && VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && + VerifyField(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) && verifier.EndTable(); } SVDFOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -4254,6 +4264,9 @@ struct SVDFOptionsBuilder { void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { fbb_.AddElement(SVDFOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } + void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) { + fbb_.AddElement(SVDFOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast(asymmetric_quantize_inputs), 0); + } explicit SVDFOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -4269,9 +4282,11 @@ struct SVDFOptionsBuilder { inline flatbuffers::Offset CreateSVDFOptions( flatbuffers::FlatBufferBuilder &_fbb, int32_t rank = 0, - tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE) { + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, + bool asymmetric_quantize_inputs = false) { SVDFOptionsBuilder builder_(_fbb); builder_.add_rank(rank); + builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs); builder_.add_fused_activation_function(fused_activation_function); return builder_.Finish(); } @@ -4281,22 +4296,29 @@ flatbuffers::Offset CreateSVDFOptions(flatbuffers::FlatBufferBuilde struct RNNOptionsT : public flatbuffers::NativeTable { typedef RNNOptions TableType; tflite::ActivationFunctionType fused_activation_function; + bool asymmetric_quantize_inputs; RNNOptionsT() - : fused_activation_function(tflite::ActivationFunctionType_NONE) { + : fused_activation_function(tflite::ActivationFunctionType_NONE), + asymmetric_quantize_inputs(false) { } }; struct RNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef RNNOptionsT NativeTableType; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { - VT_FUSED_ACTIVATION_FUNCTION = 4 + VT_FUSED_ACTIVATION_FUNCTION = 4, + VT_ASYMMETRIC_QUANTIZE_INPUTS = 6 }; tflite::ActivationFunctionType fused_activation_function() const { return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } + bool asymmetric_quantize_inputs() const { + return GetField(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0; + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && + VerifyField(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) && verifier.EndTable(); } RNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -4310,6 +4332,9 @@ struct RNNOptionsBuilder { void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { fbb_.AddElement(RNNOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } + void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) { + fbb_.AddElement(RNNOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast(asymmetric_quantize_inputs), 0); + } explicit RNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -4324,8 +4349,10 @@ struct RNNOptionsBuilder { inline flatbuffers::Offset CreateRNNOptions( flatbuffers::FlatBufferBuilder &_fbb, - tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE) { + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, + bool asymmetric_quantize_inputs = false) { RNNOptionsBuilder builder_(_fbb); + builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs); builder_.add_fused_activation_function(fused_activation_function); return builder_.Finish(); } @@ -4336,9 +4363,11 @@ struct SequenceRNNOptionsT : public flatbuffers::NativeTable { typedef SequenceRNNOptions TableType; bool time_major; tflite::ActivationFunctionType fused_activation_function; + bool asymmetric_quantize_inputs; SequenceRNNOptionsT() : time_major(false), - fused_activation_function(tflite::ActivationFunctionType_NONE) { + fused_activation_function(tflite::ActivationFunctionType_NONE), + asymmetric_quantize_inputs(false) { } }; @@ -4346,7 +4375,8 @@ struct SequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef SequenceRNNOptionsT NativeTableType; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_TIME_MAJOR = 4, - VT_FUSED_ACTIVATION_FUNCTION = 6 + VT_FUSED_ACTIVATION_FUNCTION = 6, + VT_ASYMMETRIC_QUANTIZE_INPUTS = 8 }; bool time_major() const { return GetField(VT_TIME_MAJOR, 0) != 0; @@ -4354,10 +4384,14 @@ struct SequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { tflite::ActivationFunctionType fused_activation_function() const { return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } + bool asymmetric_quantize_inputs() const { + return GetField(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0; + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_TIME_MAJOR) && VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && + VerifyField(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) && verifier.EndTable(); } SequenceRNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -4374,6 +4408,9 @@ struct SequenceRNNOptionsBuilder { void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { fbb_.AddElement(SequenceRNNOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } + void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) { + fbb_.AddElement(SequenceRNNOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast(asymmetric_quantize_inputs), 0); + } explicit SequenceRNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -4389,8 +4426,10 @@ struct SequenceRNNOptionsBuilder { inline flatbuffers::Offset CreateSequenceRNNOptions( flatbuffers::FlatBufferBuilder &_fbb, bool time_major = false, - tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE) { + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, + bool asymmetric_quantize_inputs = false) { SequenceRNNOptionsBuilder builder_(_fbb); + builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs); builder_.add_fused_activation_function(fused_activation_function); builder_.add_time_major(time_major); return builder_.Finish(); @@ -4403,10 +4442,12 @@ struct BidirectionalSequenceRNNOptionsT : public flatbuffers::NativeTable { bool time_major; tflite::ActivationFunctionType fused_activation_function; bool merge_outputs; + bool asymmetric_quantize_inputs; BidirectionalSequenceRNNOptionsT() : time_major(false), fused_activation_function(tflite::ActivationFunctionType_NONE), - merge_outputs(false) { + merge_outputs(false), + asymmetric_quantize_inputs(false) { } }; @@ -4415,7 +4456,8 @@ struct BidirectionalSequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuf enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_TIME_MAJOR = 4, VT_FUSED_ACTIVATION_FUNCTION = 6, - VT_MERGE_OUTPUTS = 8 + VT_MERGE_OUTPUTS = 8, + VT_ASYMMETRIC_QUANTIZE_INPUTS = 10 }; bool time_major() const { return GetField(VT_TIME_MAJOR, 0) != 0; @@ -4426,11 +4468,15 @@ struct BidirectionalSequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuf bool merge_outputs() const { return GetField(VT_MERGE_OUTPUTS, 0) != 0; } + bool asymmetric_quantize_inputs() const { + return GetField(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0; + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_TIME_MAJOR) && VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && VerifyField(verifier, VT_MERGE_OUTPUTS) && + VerifyField(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) && verifier.EndTable(); } BidirectionalSequenceRNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -4450,6 +4496,9 @@ struct BidirectionalSequenceRNNOptionsBuilder { void add_merge_outputs(bool merge_outputs) { fbb_.AddElement(BidirectionalSequenceRNNOptions::VT_MERGE_OUTPUTS, static_cast(merge_outputs), 0); } + void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) { + fbb_.AddElement(BidirectionalSequenceRNNOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast(asymmetric_quantize_inputs), 0); + } explicit BidirectionalSequenceRNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -4466,8 +4515,10 @@ inline flatbuffers::Offset CreateBidirectionalS flatbuffers::FlatBufferBuilder &_fbb, bool time_major = false, tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, - bool merge_outputs = false) { + bool merge_outputs = false, + bool asymmetric_quantize_inputs = false) { BidirectionalSequenceRNNOptionsBuilder builder_(_fbb); + builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs); builder_.add_merge_outputs(merge_outputs); builder_.add_fused_activation_function(fused_activation_function); builder_.add_time_major(time_major); @@ -4481,10 +4532,12 @@ struct FullyConnectedOptionsT : public flatbuffers::NativeTable { tflite::ActivationFunctionType fused_activation_function; tflite::FullyConnectedOptionsWeightsFormat weights_format; bool keep_num_dims; + bool asymmetric_quantize_inputs; FullyConnectedOptionsT() : fused_activation_function(tflite::ActivationFunctionType_NONE), weights_format(tflite::FullyConnectedOptionsWeightsFormat_DEFAULT), - keep_num_dims(false) { + keep_num_dims(false), + asymmetric_quantize_inputs(false) { } }; @@ -4493,7 +4546,8 @@ struct FullyConnectedOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tabl enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_FUSED_ACTIVATION_FUNCTION = 4, VT_WEIGHTS_FORMAT = 6, - VT_KEEP_NUM_DIMS = 8 + VT_KEEP_NUM_DIMS = 8, + VT_ASYMMETRIC_QUANTIZE_INPUTS = 10 }; tflite::ActivationFunctionType fused_activation_function() const { return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); @@ -4504,11 +4558,15 @@ struct FullyConnectedOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tabl bool keep_num_dims() const { return GetField(VT_KEEP_NUM_DIMS, 0) != 0; } + bool asymmetric_quantize_inputs() const { + return GetField(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0; + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && VerifyField(verifier, VT_WEIGHTS_FORMAT) && VerifyField(verifier, VT_KEEP_NUM_DIMS) && + VerifyField(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) && verifier.EndTable(); } FullyConnectedOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -4528,6 +4586,9 @@ struct FullyConnectedOptionsBuilder { void add_keep_num_dims(bool keep_num_dims) { fbb_.AddElement(FullyConnectedOptions::VT_KEEP_NUM_DIMS, static_cast(keep_num_dims), 0); } + void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) { + fbb_.AddElement(FullyConnectedOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast(asymmetric_quantize_inputs), 0); + } explicit FullyConnectedOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -4544,8 +4605,10 @@ inline flatbuffers::Offset CreateFullyConnectedOptions( flatbuffers::FlatBufferBuilder &_fbb, tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, tflite::FullyConnectedOptionsWeightsFormat weights_format = tflite::FullyConnectedOptionsWeightsFormat_DEFAULT, - bool keep_num_dims = false) { + bool keep_num_dims = false, + bool asymmetric_quantize_inputs = false) { FullyConnectedOptionsBuilder builder_(_fbb); + builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs); builder_.add_keep_num_dims(keep_num_dims); builder_.add_weights_format(weights_format); builder_.add_fused_activation_function(fused_activation_function); @@ -4932,11 +4995,13 @@ struct LSTMOptionsT : public flatbuffers::NativeTable { float cell_clip; float proj_clip; tflite::LSTMKernelType kernel_type; + bool asymmetric_quantize_inputs; LSTMOptionsT() : fused_activation_function(tflite::ActivationFunctionType_NONE), cell_clip(0.0f), proj_clip(0.0f), - kernel_type(tflite::LSTMKernelType_FULL) { + kernel_type(tflite::LSTMKernelType_FULL), + asymmetric_quantize_inputs(false) { } }; @@ -4946,7 +5011,8 @@ struct LSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_FUSED_ACTIVATION_FUNCTION = 4, VT_CELL_CLIP = 6, VT_PROJ_CLIP = 8, - VT_KERNEL_TYPE = 10 + VT_KERNEL_TYPE = 10, + VT_ASYMMETRIC_QUANTIZE_INPUTS = 12 }; tflite::ActivationFunctionType fused_activation_function() const { return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); @@ -4960,12 +5026,16 @@ struct LSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { tflite::LSTMKernelType kernel_type() const { return static_cast(GetField(VT_KERNEL_TYPE, 0)); } + bool asymmetric_quantize_inputs() const { + return GetField(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0; + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && VerifyField(verifier, VT_CELL_CLIP) && VerifyField(verifier, VT_PROJ_CLIP) && VerifyField(verifier, VT_KERNEL_TYPE) && + VerifyField(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) && verifier.EndTable(); } LSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -4988,6 +5058,9 @@ struct LSTMOptionsBuilder { void add_kernel_type(tflite::LSTMKernelType kernel_type) { fbb_.AddElement(LSTMOptions::VT_KERNEL_TYPE, static_cast(kernel_type), 0); } + void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) { + fbb_.AddElement(LSTMOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast(asymmetric_quantize_inputs), 0); + } explicit LSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -5005,10 +5078,12 @@ inline flatbuffers::Offset CreateLSTMOptions( tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, float cell_clip = 0.0f, float proj_clip = 0.0f, - tflite::LSTMKernelType kernel_type = tflite::LSTMKernelType_FULL) { + tflite::LSTMKernelType kernel_type = tflite::LSTMKernelType_FULL, + bool asymmetric_quantize_inputs = false) { LSTMOptionsBuilder builder_(_fbb); builder_.add_proj_clip(proj_clip); builder_.add_cell_clip(cell_clip); + builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs); builder_.add_kernel_type(kernel_type); builder_.add_fused_activation_function(fused_activation_function); return builder_.Finish(); @@ -5022,11 +5097,13 @@ struct UnidirectionalSequenceLSTMOptionsT : public flatbuffers::NativeTable { float cell_clip; float proj_clip; bool time_major; + bool asymmetric_quantize_inputs; UnidirectionalSequenceLSTMOptionsT() : fused_activation_function(tflite::ActivationFunctionType_NONE), cell_clip(0.0f), proj_clip(0.0f), - time_major(false) { + time_major(false), + asymmetric_quantize_inputs(false) { } }; @@ -5036,7 +5113,8 @@ struct UnidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatb VT_FUSED_ACTIVATION_FUNCTION = 4, VT_CELL_CLIP = 6, VT_PROJ_CLIP = 8, - VT_TIME_MAJOR = 10 + VT_TIME_MAJOR = 10, + VT_ASYMMETRIC_QUANTIZE_INPUTS = 12 }; tflite::ActivationFunctionType fused_activation_function() const { return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); @@ -5050,12 +5128,16 @@ struct UnidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatb bool time_major() const { return GetField(VT_TIME_MAJOR, 0) != 0; } + bool asymmetric_quantize_inputs() const { + return GetField(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0; + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && VerifyField(verifier, VT_CELL_CLIP) && VerifyField(verifier, VT_PROJ_CLIP) && VerifyField(verifier, VT_TIME_MAJOR) && + VerifyField(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) && verifier.EndTable(); } UnidirectionalSequenceLSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -5078,6 +5160,9 @@ struct UnidirectionalSequenceLSTMOptionsBuilder { void add_time_major(bool time_major) { fbb_.AddElement(UnidirectionalSequenceLSTMOptions::VT_TIME_MAJOR, static_cast(time_major), 0); } + void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) { + fbb_.AddElement(UnidirectionalSequenceLSTMOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast(asymmetric_quantize_inputs), 0); + } explicit UnidirectionalSequenceLSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -5095,10 +5180,12 @@ inline flatbuffers::Offset CreateUnidirection tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, float cell_clip = 0.0f, float proj_clip = 0.0f, - bool time_major = false) { + bool time_major = false, + bool asymmetric_quantize_inputs = false) { UnidirectionalSequenceLSTMOptionsBuilder builder_(_fbb); builder_.add_proj_clip(proj_clip); builder_.add_cell_clip(cell_clip); + builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs); builder_.add_time_major(time_major); builder_.add_fused_activation_function(fused_activation_function); return builder_.Finish(); @@ -5113,12 +5200,14 @@ struct BidirectionalSequenceLSTMOptionsT : public flatbuffers::NativeTable { float proj_clip; bool merge_outputs; bool time_major; + bool asymmetric_quantize_inputs; BidirectionalSequenceLSTMOptionsT() : fused_activation_function(tflite::ActivationFunctionType_NONE), cell_clip(0.0f), proj_clip(0.0f), merge_outputs(false), - time_major(true) { + time_major(true), + asymmetric_quantize_inputs(false) { } }; @@ -5129,7 +5218,8 @@ struct BidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbu VT_CELL_CLIP = 6, VT_PROJ_CLIP = 8, VT_MERGE_OUTPUTS = 10, - VT_TIME_MAJOR = 12 + VT_TIME_MAJOR = 12, + VT_ASYMMETRIC_QUANTIZE_INPUTS = 14 }; tflite::ActivationFunctionType fused_activation_function() const { return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); @@ -5146,6 +5236,9 @@ struct BidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbu bool time_major() const { return GetField(VT_TIME_MAJOR, 1) != 0; } + bool asymmetric_quantize_inputs() const { + return GetField(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0; + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && @@ -5153,6 +5246,7 @@ struct BidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbu VerifyField(verifier, VT_PROJ_CLIP) && VerifyField(verifier, VT_MERGE_OUTPUTS) && VerifyField(verifier, VT_TIME_MAJOR) && + VerifyField(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) && verifier.EndTable(); } BidirectionalSequenceLSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -5178,6 +5272,9 @@ struct BidirectionalSequenceLSTMOptionsBuilder { void add_time_major(bool time_major) { fbb_.AddElement(BidirectionalSequenceLSTMOptions::VT_TIME_MAJOR, static_cast(time_major), 1); } + void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) { + fbb_.AddElement(BidirectionalSequenceLSTMOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast(asymmetric_quantize_inputs), 0); + } explicit BidirectionalSequenceLSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -5196,10 +5293,12 @@ inline flatbuffers::Offset CreateBidirectional float cell_clip = 0.0f, float proj_clip = 0.0f, bool merge_outputs = false, - bool time_major = true) { + bool time_major = true, + bool asymmetric_quantize_inputs = false) { BidirectionalSequenceLSTMOptionsBuilder builder_(_fbb); builder_.add_proj_clip(proj_clip); builder_.add_cell_clip(cell_clip); + builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs); builder_.add_time_major(time_major); builder_.add_merge_outputs(merge_outputs); builder_.add_fused_activation_function(fused_activation_function); @@ -11034,6 +11133,7 @@ inline void SVDFOptions::UnPackTo(SVDFOptionsT *_o, const flatbuffers::resolver_ (void)_resolver; { auto _e = rank(); _o->rank = _e; } { auto _e = fused_activation_function(); _o->fused_activation_function = _e; } + { auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; } } inline flatbuffers::Offset SVDFOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SVDFOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -11046,10 +11146,12 @@ inline flatbuffers::Offset CreateSVDFOptions(flatbuffers::FlatBuffe struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SVDFOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _rank = _o->rank; auto _fused_activation_function = _o->fused_activation_function; + auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs; return tflite::CreateSVDFOptions( _fbb, _rank, - _fused_activation_function); + _fused_activation_function, + _asymmetric_quantize_inputs); } inline RNNOptionsT *RNNOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { @@ -11062,6 +11164,7 @@ inline void RNNOptions::UnPackTo(RNNOptionsT *_o, const flatbuffers::resolver_fu (void)_o; (void)_resolver; { auto _e = fused_activation_function(); _o->fused_activation_function = _e; } + { auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; } } inline flatbuffers::Offset RNNOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -11073,9 +11176,11 @@ inline flatbuffers::Offset CreateRNNOptions(flatbuffers::FlatBufferB (void)_o; struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const RNNOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _fused_activation_function = _o->fused_activation_function; + auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs; return tflite::CreateRNNOptions( _fbb, - _fused_activation_function); + _fused_activation_function, + _asymmetric_quantize_inputs); } inline SequenceRNNOptionsT *SequenceRNNOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { @@ -11089,6 +11194,7 @@ inline void SequenceRNNOptions::UnPackTo(SequenceRNNOptionsT *_o, const flatbuff (void)_resolver; { auto _e = time_major(); _o->time_major = _e; } { auto _e = fused_activation_function(); _o->fused_activation_function = _e; } + { auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; } } inline flatbuffers::Offset SequenceRNNOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -11101,10 +11207,12 @@ inline flatbuffers::Offset CreateSequenceRNNOptions(flatbuff struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SequenceRNNOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _time_major = _o->time_major; auto _fused_activation_function = _o->fused_activation_function; + auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs; return tflite::CreateSequenceRNNOptions( _fbb, _time_major, - _fused_activation_function); + _fused_activation_function, + _asymmetric_quantize_inputs); } inline BidirectionalSequenceRNNOptionsT *BidirectionalSequenceRNNOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { @@ -11119,6 +11227,7 @@ inline void BidirectionalSequenceRNNOptions::UnPackTo(BidirectionalSequenceRNNOp { auto _e = time_major(); _o->time_major = _e; } { auto _e = fused_activation_function(); _o->fused_activation_function = _e; } { auto _e = merge_outputs(); _o->merge_outputs = _e; } + { auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; } } inline flatbuffers::Offset BidirectionalSequenceRNNOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceRNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -11132,11 +11241,13 @@ inline flatbuffers::Offset CreateBidirectionalS auto _time_major = _o->time_major; auto _fused_activation_function = _o->fused_activation_function; auto _merge_outputs = _o->merge_outputs; + auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs; return tflite::CreateBidirectionalSequenceRNNOptions( _fbb, _time_major, _fused_activation_function, - _merge_outputs); + _merge_outputs, + _asymmetric_quantize_inputs); } inline FullyConnectedOptionsT *FullyConnectedOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { @@ -11151,6 +11262,7 @@ inline void FullyConnectedOptions::UnPackTo(FullyConnectedOptionsT *_o, const fl { auto _e = fused_activation_function(); _o->fused_activation_function = _e; } { auto _e = weights_format(); _o->weights_format = _e; } { auto _e = keep_num_dims(); _o->keep_num_dims = _e; } + { auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; } } inline flatbuffers::Offset FullyConnectedOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -11164,11 +11276,13 @@ inline flatbuffers::Offset CreateFullyConnectedOptions(fl auto _fused_activation_function = _o->fused_activation_function; auto _weights_format = _o->weights_format; auto _keep_num_dims = _o->keep_num_dims; + auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs; return tflite::CreateFullyConnectedOptions( _fbb, _fused_activation_function, _weights_format, - _keep_num_dims); + _keep_num_dims, + _asymmetric_quantize_inputs); } inline SoftmaxOptionsT *SoftmaxOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { @@ -11352,6 +11466,7 @@ inline void LSTMOptions::UnPackTo(LSTMOptionsT *_o, const flatbuffers::resolver_ { auto _e = cell_clip(); _o->cell_clip = _e; } { auto _e = proj_clip(); _o->proj_clip = _e; } { auto _e = kernel_type(); _o->kernel_type = _e; } + { auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; } } inline flatbuffers::Offset LSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -11366,12 +11481,14 @@ inline flatbuffers::Offset CreateLSTMOptions(flatbuffers::FlatBuffe auto _cell_clip = _o->cell_clip; auto _proj_clip = _o->proj_clip; auto _kernel_type = _o->kernel_type; + auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs; return tflite::CreateLSTMOptions( _fbb, _fused_activation_function, _cell_clip, _proj_clip, - _kernel_type); + _kernel_type, + _asymmetric_quantize_inputs); } inline UnidirectionalSequenceLSTMOptionsT *UnidirectionalSequenceLSTMOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { @@ -11387,6 +11504,7 @@ inline void UnidirectionalSequenceLSTMOptions::UnPackTo(UnidirectionalSequenceLS { auto _e = cell_clip(); _o->cell_clip = _e; } { auto _e = proj_clip(); _o->proj_clip = _e; } { auto _e = time_major(); _o->time_major = _e; } + { auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; } } inline flatbuffers::Offset UnidirectionalSequenceLSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const UnidirectionalSequenceLSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -11401,12 +11519,14 @@ inline flatbuffers::Offset CreateUnidirection auto _cell_clip = _o->cell_clip; auto _proj_clip = _o->proj_clip; auto _time_major = _o->time_major; + auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs; return tflite::CreateUnidirectionalSequenceLSTMOptions( _fbb, _fused_activation_function, _cell_clip, _proj_clip, - _time_major); + _time_major, + _asymmetric_quantize_inputs); } inline BidirectionalSequenceLSTMOptionsT *BidirectionalSequenceLSTMOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { @@ -11423,6 +11543,7 @@ inline void BidirectionalSequenceLSTMOptions::UnPackTo(BidirectionalSequenceLSTM { auto _e = proj_clip(); _o->proj_clip = _e; } { auto _e = merge_outputs(); _o->merge_outputs = _e; } { auto _e = time_major(); _o->time_major = _e; } + { auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; } } inline flatbuffers::Offset BidirectionalSequenceLSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceLSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -11438,13 +11559,15 @@ inline flatbuffers::Offset CreateBidirectional auto _proj_clip = _o->proj_clip; auto _merge_outputs = _o->merge_outputs; auto _time_major = _o->time_major; + auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs; return tflite::CreateBidirectionalSequenceLSTMOptions( _fbb, _fused_activation_function, _cell_clip, _proj_clip, _merge_outputs, - _time_major); + _time_major, + _asymmetric_quantize_inputs); } inline ResizeBilinearOptionsT *ResizeBilinearOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { diff --git a/tensorflow/lite/testing/op_tests/binary_op.py b/tensorflow/lite/testing/op_tests/binary_op.py index 118c95dc777..9d0c85e35aa 100644 --- a/tensorflow/lite/testing/op_tests/binary_op.py +++ b/tensorflow/lite/testing/op_tests/binary_op.py @@ -114,6 +114,18 @@ def make_binary_op_tests(options, }, ] + # float64 types are supported via flex only. + if options.run_with_flex and options.use_experimental_converter: + test_parameters = test_parameters + [ + { + "dtype": [tf.float64], + "input_shape_1": [[7]], + "input_shape_2": [[7]], + "activation": [False], + "fully_quantize": [False], + }, + ] + # test_parameters include fully_quantize option only when # allow_fully_quantize is True. if not allow_fully_quantize: @@ -184,7 +196,18 @@ def make_add_tests(options): @register_make_test_function() def make_div_tests(options): - make_binary_op_tests(options, tf.compat.v1.div) + """Make zip tests for div op with 5D case.""" + test_parameters = [ + { + "dtype": [tf.float32], + "input_shape_1": [[1, 3, 3, 3, 3]], + "input_shape_2": [[3]], + "activation": [False], + "fully_quantize": [False], + }, + ] + make_binary_op_tests( + options, tf.compat.v1.div, test_parameters=test_parameters) @register_make_test_function() diff --git a/tensorflow/lite/testing/op_tests/gather_nd.py b/tensorflow/lite/testing/op_tests/gather_nd.py index 1137488469e..13f317c25a4 100644 --- a/tensorflow/lite/testing/op_tests/gather_nd.py +++ b/tensorflow/lite/testing/op_tests/gather_nd.py @@ -29,19 +29,19 @@ def make_gather_nd_tests(options): test_parameters = [ { - "params_dtype": [tf.float32, tf.int32, tf.int64], + "params_dtype": [tf.float32, tf.int32, tf.int64, tf.string], "params_shape": [[5, 1]], "indices_dtype": [tf.int32, tf.int64], "indices_shape": [[1, 1]], }, { - "params_dtype": [tf.float32, tf.int32, tf.int64], + "params_dtype": [tf.float32, tf.int32, tf.int64, tf.string], "params_shape": [[5, 5]], "indices_dtype": [tf.int32, tf.int64], "indices_shape": [[2, 1], [2, 2]], }, { - "params_dtype": [tf.float32, tf.int32, tf.int64], + "params_dtype": [tf.float32, tf.int32, tf.int64, tf.string], "params_shape": [[5, 5, 10]], "indices_dtype": [tf.int32, tf.int64], "indices_shape": [[3, 1], [2, 2], [2, 3], [2, 1, 3]], diff --git a/tensorflow/lite/testing/zip_test_utils.py b/tensorflow/lite/testing/zip_test_utils.py index be837a8a669..c5de72ef9f7 100644 --- a/tensorflow/lite/testing/zip_test_utils.py +++ b/tensorflow/lite/testing/zip_test_utils.py @@ -75,6 +75,7 @@ RANDOM_SEED = 342 TF_TYPE_INFO = { tf.float32: (np.float32, "FLOAT"), tf.float16: (np.float16, "FLOAT"), + tf.float64: (np.double, "FLOAT64"), tf.int32: (np.int32, "INT32"), tf.uint8: (np.uint8, "QUANTIZED_UINT8"), tf.int16: (np.int16, "QUANTIZED_INT16"), @@ -108,7 +109,7 @@ def create_tensor_data(dtype, shape, min_value=-100, max_value=100): if dtype in TF_TYPE_INFO: dtype = TF_TYPE_INFO[dtype][0] - if dtype in (tf.float32, tf.float16): + if dtype in (tf.float32, tf.float16, tf.float64): value = (max_value - min_value) * np.random.random_sample(shape) + min_value elif dtype in (tf.int32, tf.uint8, tf.int64, tf.int16): value = np.random.randint(min_value, max_value + 1, shape) @@ -128,7 +129,7 @@ def create_scalar_data(dtype, min_value=-100, max_value=100): if dtype in TF_TYPE_INFO: dtype = TF_TYPE_INFO[dtype][0] - if dtype in (tf.float32, tf.float16): + if dtype in (tf.float32, tf.float16, tf.float64): value = (max_value - min_value) * np.random.random() + min_value elif dtype in (tf.int32, tf.uint8, tf.int64, tf.int16): value = np.random.randint(min_value, max_value + 1) diff --git a/tensorflow/lite/toco/model.h b/tensorflow/lite/toco/model.h index 9c669c2760f..11a400318d1 100644 --- a/tensorflow/lite/toco/model.h +++ b/tensorflow/lite/toco/model.h @@ -234,6 +234,7 @@ enum class ArrayDataType : uint8 { kString, kComplex64, kFloat16, + kFloat64, }; // Compile-time logic to map ArrayDataType to the corresponding C++ scalar type diff --git a/tensorflow/lite/toco/model_flags.proto b/tensorflow/lite/toco/model_flags.proto index dfc425073f5..7fd42e4afd8 100644 --- a/tensorflow/lite/toco/model_flags.proto +++ b/tensorflow/lite/toco/model_flags.proto @@ -12,10 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. syntax = "proto2"; -import "tensorflow/lite/toco/types.proto"; package toco; +import "tensorflow/lite/toco/types.proto"; + message InputArrayShape { repeated int32 dims = 2; } @@ -130,7 +131,7 @@ message ArraysExtraInfo { // optional int32 input_dims = 11 [ default = 4]; // repeated int32 input_shape = 13; // -// Next ID to USE: 20. +// Next ID to USE: 24. message ModelFlags { // Information about the input arrays, i.e. the arrays from which input // activations will be read. @@ -181,4 +182,22 @@ message ModelFlags { // When set to false, toco will not change the input ranges and the output // ranges of concat operator to the overlap of all input ranges. optional bool change_concat_input_ranges = 19 [default = true]; + + // Filepath of the saved model to be converted. This value will be non-empty + // only when the saved model import path will be used. Otherwise, the graph + // def-based conversion will be processed. + optional string saved_model_dir = 20; + + // SavedModel file format version of The saved model file to be converted. + // This value will be set only when the SavedModel import path will be used. + optional int32 saved_model_version = 21; + + // Set of string saved model tags, formatted in the comma-separated value. + // This value will be set only when the SavedModel import path will be used. + repeated string saved_model_tags = 22; + + // Names to be exported (default: export all) when the saved model import path + // is on. This value will be set only when the SavedModel import path will be + // used. + repeated string saved_model_exported_names = 23; } diff --git a/tensorflow/lite/toco/python/BUILD b/tensorflow/lite/toco/python/BUILD index b8a00b90a06..236913c9678 100644 --- a/tensorflow/lite/toco/python/BUILD +++ b/tensorflow/lite/toco/python/BUILD @@ -49,6 +49,7 @@ cc_library( "//tensorflow/lite/toco:tooling_util", "//tensorflow/core:protos_all_cc", "//tensorflow/compiler/mlir/lite/python:graphdef_to_tfl_flatbuffer", + "//tensorflow/compiler/mlir/lite/python:saved_model_to_tfl_flatbuffer", ] + select({ # This is required when running `tflite_convert` from `bazel`. # It requires to link with TensorFlow Ops to get the op definitions. diff --git a/tensorflow/lite/toco/python/toco_python_api.cc b/tensorflow/lite/toco/python/toco_python_api.cc index 31de4cfc726..667754e956f 100644 --- a/tensorflow/lite/toco/python/toco_python_api.cc +++ b/tensorflow/lite/toco/python/toco_python_api.cc @@ -21,6 +21,7 @@ limitations under the License. #include "google/protobuf/text_format.h" #include "tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h" +#include "tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/python/interpreter_wrapper/python_utils.h" #include "tensorflow/lite/toco/import_tensorflow.h" @@ -144,13 +145,6 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, } } - tensorflow::GraphDef graph_def; - if (!graph_def.ParseFromString(input_contents_txt)) { - PyErr_SetString(PyExc_ValueError, - "Failed to convert GraphDef to Python String."); - return nullptr; - } - auto& dump_options = *GraphVizDumpOptions::singleton(); if (toco_flags.has_dump_graphviz_dir()) { dump_options.dump_graphviz = toco_flags.dump_graphviz_dir(); @@ -165,13 +159,25 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, // Convert model. if (enable_mlir_converter) { - status = tensorflow::ConvertGraphDefToTFLiteFlatBuffer( - model_flags, toco_flags, debug_info, graph_def, - &output_file_contents_txt); - if (!toco_flags.conversion_summary_dir().empty()) { - PopulateConversionLogHelper(model_flags, &toco_flags, input_contents_txt, - output_file_contents_txt, - status.error_message(), &dump_options); + if (!model_flags.saved_model_dir().empty()) { + status = tensorflow::ConvertSavedModelToTFLiteFlatBuffer( + model_flags, toco_flags, &output_file_contents_txt); + } else { + tensorflow::GraphDef graph_def; + if (!graph_def.ParseFromString(input_contents_txt)) { + PyErr_SetString(PyExc_ValueError, + "Failed to convert GraphDef to Python String."); + return nullptr; + } + + status = tensorflow::ConvertGraphDefToTFLiteFlatBuffer( + model_flags, toco_flags, debug_info, graph_def, + &output_file_contents_txt); + if (!toco_flags.conversion_summary_dir().empty()) { + PopulateConversionLogHelper( + model_flags, &toco_flags, input_contents_txt, + output_file_contents_txt, status.error_message(), &dump_options); + } } } else { status = Convert(input_contents_txt, toco_flags, model_flags, diff --git a/tensorflow/lite/toco/tflite/op_version.cc b/tensorflow/lite/toco/tflite/op_version.cc index bbec4f91646..1c699410a3e 100644 --- a/tensorflow/lite/toco/tflite/op_version.cc +++ b/tensorflow/lite/toco/tflite/op_version.cc @@ -73,7 +73,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) { {{OperatorType::kGather, 1}, "1.6.0"}, {{OperatorType::kGather, 2}, "1.14.0"}, {{OperatorType::kGather, 3}, "1.15.0"}, - {{OperatorType::kGatherNd, 1}, "1.14.0"}, + {{OperatorType::kGatherNd, 2}, kPendingReleaseOpVersion}, {{OperatorType::kSvdf, 1}, "1.5.0"}, {{OperatorType::kSvdf, 2}, "1.14.0"}, {{OperatorType::kSvdf, 3}, "2.2.0"}, diff --git a/tensorflow/lite/toco/tflite/operator.cc b/tensorflow/lite/toco/tflite/operator.cc index ce7b95377aa..4c041e82b3b 100644 --- a/tensorflow/lite/toco/tflite/operator.cc +++ b/tensorflow/lite/toco/tflite/operator.cc @@ -51,7 +51,8 @@ namespace tflite { {ArrayDataType::kInt64, ::tflite::TensorType_INT64}, {ArrayDataType::kString, ::tflite::TensorType_STRING}, {ArrayDataType::kComplex64, ::tflite::TensorType_COMPLEX64}, - {ArrayDataType::kFloat16, ::tflite::TensorType_FLOAT16}}; + {ArrayDataType::kFloat16, ::tflite::TensorType_FLOAT16}, + {ArrayDataType::kFloat64, ::tflite::TensorType_FLOAT64}}; auto it = tensor_type_map.find(type); if (it != tensor_type_map.end()) { @@ -302,6 +303,23 @@ class Div : public BuiltinOperatorfused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); } + + int GetVersion(const OperatorSignature& op_signature) const override { + const string& input1_name = op_signature.op->inputs[0]; + const string& input2_name = op_signature.op->inputs[1]; + const Array& input1_array = op_signature.model->GetArray(input1_name); + const Array& input2_array = op_signature.model->GetArray(input2_name); + ::tflite::OpSignature op_sig = + GetVersioningOpSig(builtin_op(), op_signature); + if (input1_array.has_shape() && input2_array.has_shape()) { + op_sig.options.broadcast.num_dims = + std::max(input1_array.shape().dimensions_count(), + input2_array.shape().dimensions_count()); + op_sig.options.broadcast.need_broadcast = + (input1_array.shape() != input2_array.shape()); + } + return ::tflite::GetBuiltinOperatorVersion(op_sig); + } }; class BatchToSpaceND diff --git a/tensorflow/lite/toco/tflite/operator_test.cc b/tensorflow/lite/toco/tflite/operator_test.cc index a81900cb038..dd0c2946795 100644 --- a/tensorflow/lite/toco/tflite/operator_test.cc +++ b/tensorflow/lite/toco/tflite/operator_test.cc @@ -1040,9 +1040,10 @@ TEST_F(OperatorTest, VersioningMulTest) { SimpleMulVersioningTest(ArrayDataType::kInt8, 2.0f, 3); } -void SimpleSubVersioningTest(ArrayDataType data_type, Shape shape1, - Shape shape2, int version) { - SubOperator op; +template +void SimpleTwoInputsVersioningTest(ArrayDataType data_type, Shape shape1, + Shape shape2, int version) { + OpType op; op.inputs = {"input1", "input2"}; op.outputs = {"output"}; auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/); @@ -1064,16 +1065,33 @@ void SimpleSubVersioningTest(ArrayDataType data_type, Shape shape1, } TEST_F(OperatorTest, VersioningSubTest) { - SimpleSubVersioningTest(ArrayDataType::kUint8, {1, 2, 2, 2}, {1, 2, 2, 2}, 1); - SimpleSubVersioningTest(ArrayDataType::kInt8, {1, 2, 2, 2}, {1, 2, 2, 2}, 2); - SimpleSubVersioningTest(ArrayDataType::kUint8, {1, 2, 2}, {1, 2, 2}, 1); - SimpleSubVersioningTest(ArrayDataType::kInt8, {1, 2, 2}, {1, 2, 2}, 2); - SimpleSubVersioningTest(ArrayDataType::kUint8, {1, 2, 2, 2}, {1, 2, 2, 1}, 1); - SimpleSubVersioningTest(ArrayDataType::kInt8, {1, 2, 2, 2}, {1, 2, 2, 1}, 2); - SimpleSubVersioningTest(ArrayDataType::kUint8, {1, 2, 2, 2, 2}, - {1, 2, 2, 2, 1}, 3); - SimpleSubVersioningTest(ArrayDataType::kInt8, {1, 2, 2, 2, 2}, - {1, 2, 2, 2, 1}, 3); + SimpleTwoInputsVersioningTest(ArrayDataType::kUint8, + {1, 2, 2, 2}, {1, 2, 2, 2}, 1); + SimpleTwoInputsVersioningTest(ArrayDataType::kInt8, {1, 2, 2, 2}, + {1, 2, 2, 2}, 2); + SimpleTwoInputsVersioningTest(ArrayDataType::kUint8, {1, 2, 2}, + {1, 2, 2}, 1); + SimpleTwoInputsVersioningTest(ArrayDataType::kInt8, {1, 2, 2}, + {1, 2, 2}, 2); + SimpleTwoInputsVersioningTest(ArrayDataType::kUint8, + {1, 2, 2, 2}, {1, 2, 2, 1}, 1); + SimpleTwoInputsVersioningTest(ArrayDataType::kInt8, {1, 2, 2, 2}, + {1, 2, 2, 1}, 2); + SimpleTwoInputsVersioningTest( + ArrayDataType::kUint8, {1, 2, 2, 2, 2}, {1, 2, 2, 2, 1}, 3); + SimpleTwoInputsVersioningTest( + ArrayDataType::kInt8, {1, 2, 2, 2, 2}, {1, 2, 2, 2, 1}, 3); +} + +TEST_F(OperatorTest, VersioningDivTest) { + SimpleTwoInputsVersioningTest(ArrayDataType::kUint8, + {1, 2, 2, 2}, {1, 2, 2, 2}, 1); + SimpleTwoInputsVersioningTest(ArrayDataType::kInt8, {1, 2, 2}, + {1, 2, 2}, 1); + SimpleTwoInputsVersioningTest(ArrayDataType::kUint8, + {1, 2, 2, 2}, {1, 2, 2, 1}, 1); + SimpleTwoInputsVersioningTest( + ArrayDataType::kInt8, {1, 2, 2, 2, 2}, {1, 2, 2, 2, 1}, 2); } TEST_F(OperatorTest, VersioningPadTest) { SimpleVersioningTest(); } diff --git a/tensorflow/lite/toco/tooling_util.cc b/tensorflow/lite/toco/tooling_util.cc index 55b98972da6..d0a3b146bda 100644 --- a/tensorflow/lite/toco/tooling_util.cc +++ b/tensorflow/lite/toco/tooling_util.cc @@ -1766,6 +1766,8 @@ int ElementSize(ArrayDataType data_type) { return 8; case ArrayDataType::kComplex64: return 8; + case ArrayDataType::kFloat64: + return 8; // Usually not critical limitation because strings are only input and/or // output. @@ -2307,6 +2309,10 @@ ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type) { return ArrayDataType::kString; case COMPLEX64: return ArrayDataType::kComplex64; + case FLOAT16: + return ArrayDataType::kFloat16; + case FLOAT64: + return ArrayDataType::kFloat64; default: return ArrayDataType::kNone; } diff --git a/tensorflow/lite/toco/types.proto b/tensorflow/lite/toco/types.proto index 2c655517431..029a159321e 100644 --- a/tensorflow/lite/toco/types.proto +++ b/tensorflow/lite/toco/types.proto @@ -49,4 +49,7 @@ enum IODataType { // Half precision float, not quantized. FLOAT16 = 10; + + // Double precision float, not quantized. + FLOAT64 = 11; } diff --git a/tensorflow/lite/tools/benchmark/BUILD b/tensorflow/lite/tools/benchmark/BUILD index 1dd7e928c20..d10c1acb95d 100644 --- a/tensorflow/lite/tools/benchmark/BUILD +++ b/tensorflow/lite/tools/benchmark/BUILD @@ -154,7 +154,7 @@ cc_library( "@com_google_absl//absl/strings", "//tensorflow/lite:framework", "//tensorflow/lite:string_util", - "//tensorflow/lite/experimental/ruy/ruy/profiler", + "@ruy//ruy/profiler", "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/profiling:platform_profiler", "//tensorflow/lite/profiling:profiler", diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc index 617976991e1..a451eab5448 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc @@ -28,7 +28,7 @@ limitations under the License. #include "absl/base/attributes.h" #include "absl/strings/numbers.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/profiler.h" +#include "ruy/profiler/profiler.h" // from @ruy #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/op_resolver.h" @@ -497,6 +497,10 @@ BenchmarkTfLiteModel::CreateRandomTensorData(const TfLiteTensor& t, #endif // TFLITE_ENABLE_FP16_CPU_BENCHMARKS break; } + case kTfLiteFloat64: { + return CreateInputTensorData( + num_elements, std::uniform_real_distribution(-0.5, 0.5)); + } case kTfLiteInt64: { int low = has_value_range ? low_range : 0; int high = has_value_range ? high_range : 99; diff --git a/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h b/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h index 10280df05b3..39ec547198e 100644 --- a/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h +++ b/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h @@ -236,6 +236,7 @@ typedef enum { kTfLiteComplex64 = 8, kTfLiteInt8 = 9, kTfLiteFloat16 = 10, + kTfLiteFloat64 = 11, } TfLiteType; // Return the name of a given type, for error reporting purposes. diff --git a/tensorflow/lite/tools/make/Makefile b/tensorflow/lite/tools/make/Makefile index 9043d494235..7a77cf2b3f5 100644 --- a/tensorflow/lite/tools/make/Makefile +++ b/tensorflow/lite/tools/make/Makefile @@ -36,6 +36,7 @@ INCLUDES := \ -I$(MAKEFILE_DIR)/downloads/eigen \ -I$(MAKEFILE_DIR)/downloads/absl \ -I$(MAKEFILE_DIR)/downloads/gemmlowp \ +-I$(MAKEFILE_DIR)/downloads/ruy \ -I$(MAKEFILE_DIR)/downloads/neon_2_sse \ -I$(MAKEFILE_DIR)/downloads/farmhash/src \ -I$(MAKEFILE_DIR)/downloads/flatbuffers/include \ @@ -119,7 +120,7 @@ $(wildcard tensorflow/lite/c/*.c) \ $(wildcard tensorflow/lite/core/*.cc) \ $(wildcard tensorflow/lite/core/api/*.cc) \ $(wildcard tensorflow/lite/experimental/resource/*.cc) \ -$(wildcard tensorflow/lite/experimental/ruy/ruy/*.cc) +$(wildcard tensorflow/lite/tools/make/downloads/ruy/ruy/*.cc) ifneq ($(BUILD_TYPE),micro) CORE_CC_ALL_SRCS += \ $(wildcard tensorflow/lite/kernels/*.cc) \ @@ -151,6 +152,11 @@ $(wildcard tensorflow/lite/*/*/*/example*.cc) \ $(wildcard tensorflow/lite/*/*/*/test*.cc) \ $(wildcard tensorflow/lite/*/*/*/*test.cc) \ $(wildcard tensorflow/lite/*/*/*/*tool.cc) \ +$(wildcard tensorflow/lite/*/*/*/*/*/benchmark.cc) \ +$(wildcard tensorflow/lite/*/*/*/*/*/example*.cc) \ +$(wildcard tensorflow/lite/*/*/*/*/*/test*.cc) \ +$(wildcard tensorflow/lite/*/*/*/*/*/*test.cc) \ +$(wildcard tensorflow/lite/*/*/*/*/*/*tool.cc) \ $(wildcard tensorflow/lite/kernels/*test_main.cc) \ $(wildcard tensorflow/lite/kernels/*test_util*.cc) \ tensorflow/lite/tflite_with_xnnpack.cc \ diff --git a/tensorflow/lite/tools/make/download_dependencies.sh b/tensorflow/lite/tools/make/download_dependencies.sh index 2156feafef0..314f4fe6177 100755 --- a/tensorflow/lite/tools/make/download_dependencies.sh +++ b/tensorflow/lite/tools/make/download_dependencies.sh @@ -37,6 +37,8 @@ EIGEN_URL="$(grep -o 'https.*gitlab.com/libeigen/eigen/-/archive/.*tar\.gz' "${B EIGEN_SHA="$(eval echo $(grep '# SHARED_EIGEN_SHA' "${BZL_FILE_PATH}" | grep -o '\".*\"'))" GEMMLOWP_URL="$(grep -o 'https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" GEMMLOWP_SHA="$(eval echo $(grep '# SHARED_GEMMLOWP_SHA' "${BZL_FILE_PATH}" | grep -o '\".*\"'))" +RUY_URL="https://github.com/google/ruy/archive/91d62808498cea7ccb48aa59181e218b4ad05701.zip" +RUY_SHA="ac6d71df496a20043252f451d82a01636bb8bba9c3d6b5dc9fadadaffa392751" GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz" GOOGLETEST_SHA="58a6f4277ca2bc8565222b3bbd58a177609e9c488e8a72649359ba51450db7d8" ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_PATH}" | head -n1)" @@ -105,6 +107,7 @@ download_and_extract() { download_and_extract "${EIGEN_URL}" "${DOWNLOADS_DIR}/eigen" "${EIGEN_SHA}" download_and_extract "${GEMMLOWP_URL}" "${DOWNLOADS_DIR}/gemmlowp" "${GEMMLOWP_SHA}" +download_and_extract "${RUY_URL}" "${DOWNLOADS_DIR}/ruy" "${RUY_SHA}" download_and_extract "${GOOGLETEST_URL}" "${DOWNLOADS_DIR}/googletest" "${GOOGLETEST_SHA}" download_and_extract "${ABSL_URL}" "${DOWNLOADS_DIR}/absl" "${ABSL_SHA}" download_and_extract "${NEON_2_SSE_URL}" "${DOWNLOADS_DIR}/neon_2_sse" diff --git a/tensorflow/lite/tools/verifier.cc b/tensorflow/lite/tools/verifier.cc index b0025015743..1d0b813a2c4 100644 --- a/tensorflow/lite/tools/verifier.cc +++ b/tensorflow/lite/tools/verifier.cc @@ -350,6 +350,9 @@ bool VerifyNumericTensorBuffer(const Tensor& tensor, const Buffer& buffer, case TensorType_FLOAT16: bytes_required *= sizeof(uint16_t); break; + case TensorType_FLOAT64: + bytes_required *= sizeof(double); + break; case TensorType_INT32: bytes_required *= sizeof(int32_t); break; diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc index ec4d2d708bb..7b68e9d698c 100644 --- a/tensorflow/lite/tools/versioning/op_version.cc +++ b/tensorflow/lite/tools/versioning/op_version.cc @@ -373,6 +373,20 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { } return 1; + case BuiltinOperator_GATHER_ND: + if (!op_sig.input_types.empty() && + op_sig.input_types.at(0) == TensorType_STRING) { + return 2; + } + return 1; + + case BuiltinOperator_DIV: + if (op_sig.options.broadcast.need_broadcast && + op_sig.options.broadcast.num_dims > 4) { + return 2; + } + return 1; + case BuiltinOperator_AVERAGE_POOL_2D: case BuiltinOperator_ADD: case BuiltinOperator_CONCATENATION: @@ -515,6 +529,7 @@ OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op, } break; case BuiltinOperator_SUB: + case BuiltinOperator_DIV: case BuiltinOperator_MAXIMUM: case BuiltinOperator_MINIMUM: { op_sig.options.broadcast.need_broadcast = diff --git a/tensorflow/lite/tools/versioning/op_version_test.cc b/tensorflow/lite/tools/versioning/op_version_test.cc index ecb01bfe954..ae4efce2544 100644 --- a/tensorflow/lite/tools/versioning/op_version_test.cc +++ b/tensorflow/lite/tools/versioning/op_version_test.cc @@ -525,4 +525,33 @@ TEST(OpVersionTest, VersioningTransposeTest) { }; EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); } +TEST(OpVersionTest, VersioningGatherNdOperatorTest) { + OpSignature fake_op_sig = { + .op = BuiltinOperator_GATHER_ND, + .input_types = + std::vector{TensorType_INT32, TensorType_INT32}, + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); + + fake_op_sig = { + .op = BuiltinOperator_GATHER_ND, + .input_types = + std::vector{TensorType_STRING, TensorType_INT32}, + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); +} +TEST(OpVersionTest, VersioningDivTest) { + OpSignature fake_op_sig = { + .op = BuiltinOperator_DIV, + .input_types = std::vector{TensorType_UINT8}, + }; + + fake_op_sig.options.broadcast.need_broadcast = true; + fake_op_sig.options.broadcast.num_dims = 5; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + fake_op_sig.options.broadcast.need_broadcast = false; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); + fake_op_sig.options.broadcast.num_dims = 4; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); +} } // namespace tflite diff --git a/tensorflow/lite/tools/versioning/runtime_version.cc b/tensorflow/lite/tools/versioning/runtime_version.cc index fda074fabdf..702f9fc7f85 100644 --- a/tensorflow/lite/tools/versioning/runtime_version.cc +++ b/tensorflow/lite/tools/versioning/runtime_version.cc @@ -67,6 +67,7 @@ void UpdateMinimumRuntimeVersionForModel(uint8_t* model_buffer_pointer) { {{BuiltinOperator_SUB, 1}, "1.6.0"}, {{BuiltinOperator_SUB, 2}, "1.14.0"}, {{BuiltinOperator_DIV, 1}, "1.6.0"}, + {{BuiltinOperator_DIV, 2}, kPendingReleaseOpVersion}, {{BuiltinOperator_BATCH_TO_SPACE_ND, 1}, "1.6.0"}, {{BuiltinOperator_BATCH_TO_SPACE_ND, 2}, "1.14.0"}, {{BuiltinOperator_CAST, 1}, "1.5.0"}, diff --git a/tensorflow/lite/type_to_tflitetype.h b/tensorflow/lite/type_to_tflitetype.h index 28efb96f89d..84cd54b5718 100644 --- a/tensorflow/lite/type_to_tflitetype.h +++ b/tensorflow/lite/type_to_tflitetype.h @@ -74,5 +74,9 @@ template <> constexpr TfLiteType typeToTfLiteType() { return kTfLiteFloat16; } +template <> +constexpr TfLiteType typeToTfLiteType() { + return kTfLiteFloat64; +} } // namespace tflite #endif // TENSORFLOW_LITE_TYPE_TO_TFLITETYPE_H_ diff --git a/tensorflow/lite/util.cc b/tensorflow/lite/util.cc index 335c6773039..c91e50b1845 100644 --- a/tensorflow/lite/util.cc +++ b/tensorflow/lite/util.cc @@ -103,6 +103,9 @@ TfLiteStatus GetSizeOfType(TfLiteContext* context, const TfLiteType type, case kTfLiteFloat16: *bytes = sizeof(TfLiteFloat16); break; + case kTfLiteFloat64: + *bytes = sizeof(double); + break; default: if (context) { context->ReportError( diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files index fdb1b66943b..2ebfdbee26c 100644 --- a/tensorflow/opensource_only.files +++ b/tensorflow/opensource_only.files @@ -32,6 +32,7 @@ tensorflow/third_party/clang_toolchain/download_clang.bzl tensorflow/third_party/codegen.BUILD tensorflow/third_party/com_google_absl.BUILD tensorflow/third_party/common.bzl +tensorflow/third_party/coremltools.BUILD tensorflow/third_party/cub.BUILD tensorflow/third_party/curl.BUILD tensorflow/third_party/cython.BUILD diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index de3d86afa48..b48a2050205 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -403,11 +403,11 @@ tf_python_pybind_extension( deps = [ ":cost_analyzer_headers", ":pybind11_status", - "//tensorflow/core:core_cpu_headers_lib", "//tensorflow/core:framework_headers_lib", - "//tensorflow/core:gpu_id", "//tensorflow/core:lib_headers_for_pybind", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/common_runtime:core_cpu_headers_lib", + "//tensorflow/core/common_runtime/gpu:gpu_id", "@pybind11", ], ) @@ -566,8 +566,8 @@ cc_library( visibility = tf_external_workspace_visible(visibility), deps = [ "//tensorflow/c:tf_status_headers", - "//tensorflow/core:core_cpu_headers_lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/common_runtime:core_cpu_headers_lib", "//third_party/python_runtime:headers", "@pybind11", ], @@ -646,7 +646,7 @@ tf_python_pybind_extension( "//third_party/python_runtime:headers", "//tensorflow/core:protos_all_cc", "//tensorflow/core:framework_headers_lib", - "//tensorflow/core:core_cpu_headers_lib", + "//tensorflow/core/common_runtime:core_cpu_headers_lib", "//tensorflow/core:lib_headers_for_pybind", "@com_google_absl//absl/types:optional", ] + if_static( @@ -717,10 +717,10 @@ tf_python_pybind_extension( ":pybind11_lib", ":pybind11_proto", ":pybind11_status", - "//tensorflow/core:core_cpu_headers_lib", "//tensorflow/core:framework_headers_lib", "//tensorflow/core:lib_headers_for_pybind", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/common_runtime:core_cpu_headers_lib", "//third_party/python_runtime:headers", "@com_google_absl//absl/strings", "@pybind11", @@ -968,8 +968,8 @@ cc_library( deps = [ ":numpy_lib", "//tensorflow/c:tf_status_headers", - "//tensorflow/core:core_cpu_headers_lib", "//tensorflow/core:framework_internal_headers_lib", + "//tensorflow/core/common_runtime:core_cpu_headers_lib", "//third_party/py/numpy:headers", "//third_party/python_runtime:headers", ], @@ -5758,9 +5758,9 @@ tf_python_pybind_extension( deps = [ ":pybind11_proto", ":pybind11_status", - "//tensorflow/core:core_cpu_headers_lib", "//tensorflow/core:framework_internal_headers_lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/common_runtime:core_cpu_headers_lib", "//third_party/python_runtime:headers", "@pybind11", ], @@ -5926,7 +5926,7 @@ filegroup( "//tensorflow/c:tf_status_helper", # tfe "//tensorflow/compiler/jit:flags", #tfe "//tensorflow/compiler/mlir/python:mlir", # mlir - "//tensorflow/core:core_cpu_base_no_ops", # tf_session + "//tensorflow/core/common_runtime:core_cpu_base_no_ops", # tf_session "//tensorflow/core:core_cpu_impl", # device_lib "//tensorflow/core/data/service:server_lib", # server_lib "//tensorflow/core:framework_internal_impl", # op_def_registry @@ -7468,9 +7468,9 @@ tf_python_pybind_extension( deps = [ ":pybind11_status", "@pybind11", - "//tensorflow/core:core_cpu_headers_lib", + "//tensorflow/core/common_runtime:core_cpu_headers_lib", "//tensorflow/core:framework_headers_lib", - "//tensorflow/core:gpu_id", + "//tensorflow/core/common_runtime/gpu:gpu_id", "//tensorflow/core:protos_all_cc", ] + if_not_windows(["//tensorflow/core/grappler/costs:graph_properties"]), # b/148556093, ) @@ -7543,11 +7543,11 @@ tf_python_pybind_extension( module_name = "_pywrap_tf_cluster", deps = [ ":pybind11_status", - "//tensorflow/core:core_cpu_headers_lib", "//tensorflow/core:framework_headers_lib", - "//tensorflow/core:gpu_id", "//tensorflow/core:lib_headers_for_pybind", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/common_runtime:core_cpu_headers_lib", + "//tensorflow/core/common_runtime/gpu:gpu_id", "@com_google_absl//absl/types:span", "@pybind11", ], @@ -7604,11 +7604,11 @@ tf_python_pybind_extension( module_name = "_pywrap_tf_optimizer", deps = [ ":pybind11_status", - "//tensorflow/core:core_cpu_headers_lib", "//tensorflow/core:framework_headers_lib", - "//tensorflow/core:gpu_id", "//tensorflow/core:lib_headers_for_pybind", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/common_runtime:core_cpu_headers_lib", + "//tensorflow/core/common_runtime/gpu:gpu_id", "@pybind11", ], ) @@ -7994,7 +7994,7 @@ tf_python_pybind_extension( "@pybind11", "//third_party/python_runtime:headers", "//tensorflow/compiler/jit:flags_headers_only", - "//tensorflow/core:core_cpu_headers_lib", + "//tensorflow/core/common_runtime:core_cpu_headers_lib", "//tensorflow/core:framework_headers_lib", "//tensorflow/core:lib_headers_for_pybind", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/python/autograph/converters/BUILD b/tensorflow/python/autograph/converters/BUILD index 3bd51ce2eb4..8e29fa1961e 100644 --- a/tensorflow/python/autograph/converters/BUILD +++ b/tensorflow/python/autograph/converters/BUILD @@ -28,7 +28,7 @@ py_library( "control_flow.py", "control_flow_deprecated_py2.py", "directives.py", - "function_scopes.py", + "functions.py", "list_comprehensions.py", "lists.py", "logical_expressions.py", @@ -154,8 +154,8 @@ py_test( ) py_test( - name = "function_scopes_test", - srcs = ["function_scopes_test.py"], + name = "functions_test", + srcs = ["functions_test.py"], python_version = "PY3", deps = [ ":converters", diff --git a/tensorflow/python/autograph/converters/asserts_test.py b/tensorflow/python/autograph/converters/asserts_test.py index 0302964c6b1..fd31cd15a0e 100644 --- a/tensorflow/python/autograph/converters/asserts_test.py +++ b/tensorflow/python/autograph/converters/asserts_test.py @@ -19,7 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.autograph.converters import asserts -from tensorflow.python.autograph.converters import function_scopes +from tensorflow.python.autograph.converters import functions from tensorflow.python.autograph.core import converter_testing from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors_impl @@ -36,7 +36,7 @@ class AssertsTest(converter_testing.TestCase): return a with ops.Graph().as_default(): - with self.converted(test_fn, (function_scopes, asserts), {}) as result: + with self.converted(test_fn, (functions, asserts), {}) as result: op = result.test_fn(constant_op.constant(False)) with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 'testmsg'): diff --git a/tensorflow/python/autograph/converters/break_statements.py b/tensorflow/python/autograph/converters/break_statements.py index dc5511824c1..60e65e9a1db 100644 --- a/tensorflow/python/autograph/converters/break_statements.py +++ b/tensorflow/python/autograph/converters/break_statements.py @@ -80,28 +80,40 @@ class BreakTransformer(converter.Base): # A break in the else clause applies to the containing scope. node.orelse = self.visit_block(node.orelse) - if break_used: - # Python's else clause only triggers if the loop exited cleanly (e.g. - # break did not trigger). - guarded_orelse = self._guard_if_present(node.orelse, break_var) - + if not break_used: template = """ - var_name = False - while ag__.and_(lambda: test, lambda: ag__.not_(var_name)): + while test: body - else: - orelse + orelse """ node = templates.replace( - template, - var_name=break_var, - test=node.test, - body=node.body, - orelse=guarded_orelse) + template, test=node.test, body=node.body, orelse=node.orelse) - new_while_node = node[1] + new_while_node = node[0] anno.copyanno(original_node, new_while_node, anno.Basic.DIRECTIVES) + return node + + # Python's else clause only triggers if the loop exited cleanly (e.g. + # break did not trigger). + guarded_orelse = self._guard_if_present(node.orelse, break_var) + + template = """ + var_name = False + while ag__.and_(lambda: test, lambda: ag__.not_(var_name)): + body + orelse + """ + node = templates.replace( + template, + var_name=break_var, + test=node.test, + body=node.body, + orelse=guarded_orelse) + + new_while_node = node[1] + anno.copyanno(original_node, new_while_node, anno.Basic.DIRECTIVES) + return node def visit_For(self, node): @@ -115,37 +127,54 @@ class BreakTransformer(converter.Base): # A break in the else clause applies to the containing scope. node.orelse = self.visit_block(node.orelse) - if break_used: - # Python's else clause only triggers if the loop exited cleanly (e.g. - # break did not trigger). - guarded_orelse = self._guard_if_present(node.orelse, break_var) - extra_test = templates.replace_as_expression( - 'ag__.not_(var_name)', var_name=break_var) - - # The extra test is hidden in the AST, which will confuse the static - # analysis. To mitigate that, we insert a no-op statement that ensures - # the control variable is marked as used. - # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name) + if not break_used: template = """ - var_name = False for target in iter_: - (var_name,) body - else: - orelse + orelse """ node = templates.replace( template, - var_name=break_var, iter_=node.iter, target=node.target, body=node.body, - orelse=guarded_orelse) + orelse=node.orelse) - new_for_node = node[1] - anno.setanno(new_for_node, anno.Basic.EXTRA_LOOP_TEST, extra_test) + new_for_node = node[0] + anno.copyanno(original_node, new_for_node, anno.Basic.EXTRA_LOOP_TEST) anno.copyanno(original_node, new_for_node, anno.Basic.DIRECTIVES) + return node + + # Python's else clause only triggers if the loop exited cleanly (e.g. + # break did not trigger). + guarded_orelse = self._guard_if_present(node.orelse, break_var) + extra_test = templates.replace_as_expression( + 'ag__.not_(var_name)', var_name=break_var) + + # The extra test is hidden in the AST, which will confuse the static + # analysis. To mitigate that, we insert a no-op statement that ensures + # the control variable is marked as used. + # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name) + template = """ + var_name = False + for target in iter_: + (var_name,) + body + orelse + """ + node = templates.replace( + template, + var_name=break_var, + iter_=node.iter, + target=node.target, + body=node.body, + orelse=guarded_orelse) + + new_for_node = node[1] + anno.setanno(new_for_node, anno.Basic.EXTRA_LOOP_TEST, extra_test) + anno.copyanno(original_node, new_for_node, anno.Basic.DIRECTIVES) + return node diff --git a/tensorflow/python/autograph/converters/call_trees.py b/tensorflow/python/autograph/converters/call_trees.py index 810f19b692b..54804fcef3d 100644 --- a/tensorflow/python/autograph/converters/call_trees.py +++ b/tensorflow/python/autograph/converters/call_trees.py @@ -96,47 +96,31 @@ class CallTreeTransformer(converter.Base): """Transforms the call tree by renaming transformed symbols.""" def visit_Lambda(self, node): - if anno.hasanno(node, 'function_context_name'): + if not anno.hasanno(node, 'function_context_name'): # Lambda functions created during the conversion process have no # context manager. - self.state[_Function].enter() - self.state[_Function].context_name = anno.getanno( - node, 'function_context_name') - node = self.generic_visit(node) - self.state[_Function].exit() - else: - node = self.generic_visit(node) - return node + return self.generic_visit(node) + with self.state[_Function] as fn_scope: + fn_scope.context_name = anno.getanno(node, 'function_context_name') + return self.generic_visit(node) def visit_FunctionDef(self, node): - self.state[_Function].enter() - # Note: if the conversion process ever creates helper functions, this - # assumption will no longer hold. - assert anno.hasanno(node, 'function_context_name'), ( - 'The function_scopes converter always creates a scope for functions.') - self.state[_Function].context_name = anno.getanno( - node, 'function_context_name') - node.args = self.visit(node.args) - node.body = self.visit_block(node.body) - - if self.state[_Function].level < 2: - # Top-level functions lose their decorator because the conversion is - # always just-in-time and by the time it happens the decorators are - # already set to be applied. - node.decorator_list = [] - else: - # TODO(mdan): Fix the tests so that we can always add this decorator. - # Inner functions are converted already, so we insert a decorator to - # prevent double conversion. Double conversion would work too, but this - # saves the overhead. - node.decorator_list.append( - parser.parse_expression('ag__.autograph_artifact')) - - if node.returns: - node.returns = self.visit(node.returns) - - self.state[_Function].exit() - return node + # Decorators and arg defaults are part of the outer scope. + node.decorator_list = self.visit_block(node.decorator_list) + node.args.defaults = self.visit_block(node.args.defaults) + for i, d in enumerate(node.args.kw_defaults): + if d is not None: + node.args.kw_defaults[i] = self.visit(d) + with self.state[_Function] as fn_scope: + # Note: if the conversion process ever creates helper functions, this + # assumption will no longer hold. + assert anno.hasanno(node, 'function_context_name'), ( + 'The function_scopes converter always creates a scope for functions.') + fn_scope.context_name = anno.getanno(node, 'function_context_name') + node.body = self.visit_block(node.body) + if node.returns: + node.returns = self.visit(node.returns) + return node def visit_With(self, node): # Context manager calls (in node.items) are not converted. diff --git a/tensorflow/python/autograph/converters/call_trees_test.py b/tensorflow/python/autograph/converters/call_trees_test.py index a4346141232..86ca2dc9c24 100644 --- a/tensorflow/python/autograph/converters/call_trees_test.py +++ b/tensorflow/python/autograph/converters/call_trees_test.py @@ -22,7 +22,7 @@ from __future__ import print_function import imp from tensorflow.python.autograph.converters import call_trees -from tensorflow.python.autograph.converters import function_scopes +from tensorflow.python.autograph.converters import functions from tensorflow.python.autograph.core import converter_testing from tensorflow.python.platform import test @@ -34,7 +34,7 @@ class CallTreesTest(converter_testing.TestCase): def test_fn(f): return f() + 20 - with self.converted(test_fn, (function_scopes, call_trees), {}) as result: + with self.converted(test_fn, (functions, call_trees), {}) as result: self.assertEqual(result.test_fn(lambda: 1), 21) self.assertListEqual(self.dynamic_calls, [((), None)]) @@ -43,7 +43,7 @@ class CallTreesTest(converter_testing.TestCase): def test_fn(f, g): return f(g() + 20) + 4000 - with self.converted(test_fn, (function_scopes, call_trees), {}) as result: + with self.converted(test_fn, (functions, call_trees), {}) as result: self.assertEqual(result.test_fn(lambda x: x + 300, lambda: 1), 4321) self.assertListEqual(self.dynamic_calls, [ ((), None), @@ -55,7 +55,7 @@ class CallTreesTest(converter_testing.TestCase): def test_fn(f, g): return f(g()) + 300 - with self.converted(test_fn, (function_scopes, call_trees), {}) as result: + with self.converted(test_fn, (functions, call_trees), {}) as result: self.assertEqual(result.test_fn(lambda x: x + 20, lambda: 1), 321) self.assertListEqual(self.dynamic_calls, [ ((), None), @@ -70,7 +70,7 @@ class CallTreesTest(converter_testing.TestCase): def test_fn(): return get_one().__add__(20) - with self.converted(test_fn, (function_scopes, call_trees), + with self.converted(test_fn, (functions, call_trees), {'get_one': get_one}, ()) as result: self.assertEqual(result.test_fn(), 21) @@ -85,7 +85,7 @@ class CallTreesTest(converter_testing.TestCase): def test_fn(f, a): return f(a) + 20 - with self.converted(test_fn, (function_scopes, call_trees), {}) as result: + with self.converted(test_fn, (functions, call_trees), {}) as result: self.assertEqual(result.test_fn(lambda a: a, 1), 21) self.assertListEqual(self.dynamic_calls, [((1,), None)]) @@ -94,7 +94,7 @@ class CallTreesTest(converter_testing.TestCase): def test_fn(f, a, b): return f(a, b) + 300 - with self.converted(test_fn, (function_scopes, call_trees), {}) as result: + with self.converted(test_fn, (functions, call_trees), {}) as result: self.assertEqual(result.test_fn(lambda a, b: a + b, 1, 20), 321) self.assertListEqual(self.dynamic_calls, [((1, 20), None)]) @@ -103,7 +103,7 @@ class CallTreesTest(converter_testing.TestCase): def test_fn(f, a, b): return f(a, c=b) + 300 - with self.converted(test_fn, (function_scopes, call_trees), {}) as result: + with self.converted(test_fn, (functions, call_trees), {}) as result: self.assertEqual(result.test_fn(lambda a, c: a + c, 1, 20), 321) self.assertListEqual(self.dynamic_calls, [((1,), {'c': 20})]) @@ -112,7 +112,7 @@ class CallTreesTest(converter_testing.TestCase): def test_fn(f, a, *args, **kwargs): return f(a, *args, **kwargs) + 5 - with self.converted(test_fn, (function_scopes, call_trees), {}) as result: + with self.converted(test_fn, (functions, call_trees), {}) as result: self.assertEqual( result.test_fn(lambda *args, **kwargs: 7, 1, *[2, 3], **{ 'b': 4, @@ -129,7 +129,7 @@ class CallTreesTest(converter_testing.TestCase): args = [1, 20, 300] return f(*args) + 4000 - with self.converted(test_fn, (function_scopes, call_trees), + with self.converted(test_fn, (functions, call_trees), {'f': f}) as result: self.assertEqual(result.test_fn(), 4321) self.assertListEqual(self.dynamic_calls, [((1, 20, 300), None)]) @@ -145,7 +145,7 @@ class CallTreesTest(converter_testing.TestCase): # args2 = [3] # return f(*args1, 2, *args2, 4) # - # with self.converted(test_fn, (function_scopes, call_trees), + # with self.converted(test_fn, (functions, call_trees), # {'f': f}) as result: # self.assertEqual(result.test_fn(), 1234) # self.assertListEqual(self.dynamic_calls, [((1, 2, 3, 4), None)]) @@ -155,7 +155,7 @@ class CallTreesTest(converter_testing.TestCase): def test_fn(f, a, b, **kwargs): return f(a, b=b, **kwargs) + 5 - with self.converted(test_fn, (function_scopes, call_trees), {}) as result: + with self.converted(test_fn, (functions, call_trees), {}) as result: self.assertEqual( result.test_fn(lambda *args, **kwargs: 7, 1, 2, **{'c': 3}), 12) self.assertListEqual(self.dynamic_calls, [((1,), {'b': 2, 'c': 3})]) @@ -166,7 +166,7 @@ class CallTreesTest(converter_testing.TestCase): # def test_fn(f, a, b, c, kwargs1, kwargs2): # return f(a, b=b, **kwargs1, c=c, **kwargs2) + 5 # - # with self.converted(test_fn, (function_scopes, call_trees), {}) as result: + # with self.converted(test_fn, (functions, call_trees), {}) as result: # self.assertEqual( # result.test_fn(lambda *args, **kwargs: 7, 1, 2, 3, {'d': 4}, # {'e': 5}), 12) @@ -188,7 +188,7 @@ class CallTreesTest(converter_testing.TestCase): def test_fn(f, g, a, *args): return f(lambda x: g(x, *args), a) - with self.converted(test_fn, (function_scopes, call_trees), {}) as result: + with self.converted(test_fn, (functions, call_trees), {}) as result: self.assertEqual(result.test_fn(f, g, 1, *(20, 300)), 4321) def test_debugger_set_trace(self): @@ -201,7 +201,7 @@ class CallTreesTest(converter_testing.TestCase): def test_fn(): return pdb.set_trace() - with self.converted(test_fn, (function_scopes, call_trees), + with self.converted(test_fn, (functions, call_trees), {'pdb': pdb}) as result: result.test_fn() self.assertListEqual(tracking_list, [1]) @@ -217,7 +217,7 @@ class CallTreesTest(converter_testing.TestCase): return self.other_method(a) + 300 tc = TestClass() - with self.converted(TestClass.test_method, (function_scopes, call_trees), + with self.converted(TestClass.test_method, (functions, call_trees), {}) as result: self.assertEqual(321, result.test_method(tc, 1)) self.assertListEqual(self.dynamic_calls, [((1,), None)]) @@ -233,7 +233,7 @@ class CallTreesTest(converter_testing.TestCase): return self.other_method(a) + 300 tc = TestClass() - with self.converted(tc.test_method, (function_scopes, call_trees), + with self.converted(tc.test_method, (functions, call_trees), {}) as result: self.assertEqual(321, result.test_method(tc, 1)) self.assertListEqual(self.dynamic_calls, [((1,), None)]) diff --git a/tensorflow/python/autograph/converters/function_scopes.py b/tensorflow/python/autograph/converters/function_scopes.py deleted file mode 100644 index 100a14e4494..00000000000 --- a/tensorflow/python/autograph/converters/function_scopes.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Wraps the body of a converted function with auxiliary constructs.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import gast - -from tensorflow.python.autograph.core import converter -from tensorflow.python.autograph.pyct import anno -from tensorflow.python.autograph.pyct import templates -from tensorflow.python.autograph.pyct.static_analysis import annos - - -class _Function(object): - - def __init__(self): - self.context_name = None - - -class FunctionBodyTransformer(converter.Base): - """Wraps function bodies around autograph-specific boilerplate.""" - - def visit_Return(self, node): - if node.value is None: - return node - return templates.replace( - 'return function_context_name.mark_return_value(value)', - function_context_name=self.state[_Function].context_name, - value=node.value) - - def _function_scope_options(self): - """Returns the options with which to create function scopes.""" - # Top-level function receive the options that were directly requested. - # All others receive the options corresponding to a recursive conversion. - # Note: this mainly controls the user_requested flag, which is important - # primarily because the FunctionScope context also creates a - # ControlStatusCtx(autograph=ENABLED) when user_requested is True. See - # function_wrappers.py. - if self.state[_Function].level == 2: - return self.ctx.program.options - return self.ctx.program.options.call_options() - - def visit_Lambda(self, node): - self.state[_Function].enter() - node = self.generic_visit(node) - - # Only wrap the top-level function. Theoretically, we can and should wrap - # everything, but that can lead to excessive boilerplate when lambdas are - # nested. - # TODO(mdan): Looks more closely for use cases that actually require this. - if self.state[_Function].level > 2: - self.state[_Function].exit() - return node - - scope = anno.getanno(node, anno.Static.SCOPE) - function_context_name = self.ctx.namer.new_symbol('lscope', - scope.referenced) - self.state[_Function].context_name = function_context_name - anno.setanno(node, 'function_context_name', function_context_name) - - template = """ - ag__.with_function_scope( - lambda function_context: body, function_context_name, options) - """ - node.body = templates.replace_as_expression( - template, - options=self._function_scope_options().to_ast(), - function_context=function_context_name, - function_context_name=gast.Constant(function_context_name, kind=None), - body=node.body) - - self.state[_Function].exit() - return node - - def visit_FunctionDef(self, node): - self.state[_Function].enter() - scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) - - function_context_name = self.ctx.namer.new_symbol('fscope', - scope.referenced) - self.state[_Function].context_name = function_context_name - anno.setanno(node, 'function_context_name', function_context_name) - - node = self.generic_visit(node) - - docstring_node = None - if node.body: - first_statement = node.body[0] - if (isinstance(first_statement, gast.Expr) and - isinstance(first_statement.value, gast.Constant)): - docstring_node = first_statement - node.body = node.body[1:] - - template = """ - with ag__.FunctionScope( - function_name, context_name, options) as function_context: - body - """ - wrapped_body = templates.replace( - template, - function_name=gast.Constant(node.name, kind=None), - context_name=gast.Constant(function_context_name, kind=None), - options=self._function_scope_options().to_ast(), - function_context=function_context_name, - body=node.body) - - if docstring_node is not None: - wrapped_body = [docstring_node] + wrapped_body - - node.body = wrapped_body - - self.state[_Function].exit() - return node - - -def transform(node, ctx): - return FunctionBodyTransformer(ctx).visit(node) diff --git a/tensorflow/python/autograph/converters/functions.py b/tensorflow/python/autograph/converters/functions.py new file mode 100644 index 00000000000..c1003badc1d --- /dev/null +++ b/tensorflow/python/autograph/converters/functions.py @@ -0,0 +1,142 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Converts function definitions and lambdas by adding necessary boilerplate.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.python.autograph.core import converter +from tensorflow.python.autograph.pyct import anno +from tensorflow.python.autograph.pyct import parser +from tensorflow.python.autograph.pyct import templates +from tensorflow.python.autograph.pyct.static_analysis import annos + + +class _Function(object): + + def __init__(self): + self.context_name = None + + +class FunctionTransformer(converter.Base): + """Wraps function bodies around autograph-specific boilerplate.""" + + def visit_Return(self, node): + if node.value is None: + return node + node = self.generic_visit(node) + return templates.replace( + 'return function_context_name.mark_return_value(value)', + function_context_name=self.state[_Function].context_name, + value=node.value) + + def _function_scope_options(self, fn_scope): + """Returns the options with which to create function scopes.""" + # Top-level function receive the options that were directly requested. + # All others receive the options corresponding to a recursive conversion. + # Note: this mainly controls the user_requested flag, which is important + # primarily because the FunctionScope context also creates a + # ControlStatusCtx(autograph=ENABLED) when user_requested is True. See + # function_wrappers.py. + if fn_scope.level == 2: + return self.ctx.program.options + return self.ctx.program.options.call_options() + + def visit_Lambda(self, node): + with self.state[_Function] as fn_scope: + node = self.generic_visit(node) + + # TODO(mdan): Fix the tests so that we can always add this decorator. + if fn_scope.level > 2: + return templates.replace_as_expression( + 'ag__.autograph_artifact(l)', l=node) + + scope = anno.getanno(node, anno.Static.SCOPE) + function_context_name = self.ctx.namer.new_symbol('lscope', + scope.referenced) + fn_scope.context_name = function_context_name + anno.setanno(node, 'function_context_name', function_context_name) + + template = """ + ag__.with_function_scope( + lambda function_context: body, function_context_name, options) + """ + node.body = templates.replace_as_expression( + template, + options=self._function_scope_options(fn_scope).to_ast(), + function_context=function_context_name, + function_context_name=gast.Constant(function_context_name, kind=None), + body=node.body) + + return node + + def visit_FunctionDef(self, node): + with self.state[_Function] as fn_scope: + scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) + + function_context_name = self.ctx.namer.new_symbol('fscope', + scope.referenced) + fn_scope.context_name = function_context_name + anno.setanno(node, 'function_context_name', function_context_name) + + node = self.generic_visit(node) + + if fn_scope.level <= 2: + # Top-level functions lose their decorator because the conversion is + # always just-in-time and by the time it happens the decorators are + # already set to be applied. + node.decorator_list = [] + else: + # TODO(mdan): Fix the tests so that we can always add this decorator. + # Inner functions are converted already, so we insert a decorator to + # prevent double conversion. Double conversion would work too, but this + # saves the overhead. + node.decorator_list.append( + parser.parse_expression('ag__.autograph_artifact')) + + docstring_node = None + if node.body: + first_statement = node.body[0] + if (isinstance(first_statement, gast.Expr) and + isinstance(first_statement.value, gast.Constant)): + docstring_node = first_statement + node.body = node.body[1:] + + template = """ + with ag__.FunctionScope( + function_name, context_name, options) as function_context: + body + """ + wrapped_body = templates.replace( + template, + function_name=gast.Constant(node.name, kind=None), + context_name=gast.Constant(function_context_name, kind=None), + options=self._function_scope_options(fn_scope).to_ast(), + function_context=function_context_name, + body=node.body) + + if docstring_node is not None: + wrapped_body = [docstring_node] + wrapped_body + + node.body = wrapped_body + + return node + + +def transform(node, ctx): + return FunctionTransformer(ctx).visit(node) diff --git a/tensorflow/python/autograph/converters/function_scopes_test.py b/tensorflow/python/autograph/converters/functions_test.py similarity index 84% rename from tensorflow/python/autograph/converters/function_scopes_test.py rename to tensorflow/python/autograph/converters/functions_test.py index 9c8939a6132..aad455e67d7 100644 --- a/tensorflow/python/autograph/converters/function_scopes_test.py +++ b/tensorflow/python/autograph/converters/functions_test.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for function_scopes module.""" +"""Tests for functions module.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.autograph.converters import function_scopes +from tensorflow.python.autograph.converters import functions from tensorflow.python.autograph.core import ag_ctx from tensorflow.python.autograph.core import converter from tensorflow.python.autograph.core import converter_testing @@ -28,7 +28,7 @@ from tensorflow.python.framework import test_util from tensorflow.python.platform import test -class FunctionBodyTransformerTest(converter_testing.TestCase): +class FunctionTransformer(converter_testing.TestCase): @test_util.run_deprecated_v1 def test_basic(self): @@ -39,7 +39,7 @@ class FunctionBodyTransformerTest(converter_testing.TestCase): l += a return l - with self.converted(test_fn, function_scopes, {}) as result: + with self.converted(test_fn, functions, {}) as result: result_op = result.test_fn(constant_op.constant(1)) self.assertIn('test_fn/', result_op.op.name) self.assertEqual('Docstring.', result.test_fn.__doc__) @@ -56,7 +56,7 @@ class FunctionBodyTransformerTest(converter_testing.TestCase): """ return tf.constant(1) - with self.converted(test_fn, function_scopes, {}, + with self.converted(test_fn, functions, {}, (constant_op.constant,)) as result: result_op = result.test_fn() self.assertIn('test_fn/', result_op.op.name) @@ -74,7 +74,7 @@ class FunctionBodyTransformerTest(converter_testing.TestCase): l += 1 return l, inner_fn(l) - with self.converted(test_fn, function_scopes, {}, + with self.converted(test_fn, functions, {}, (ops.name_scope,)) as result: first, second = result.test_fn(constant_op.constant(1)) self.assertIn('test_fn/', first.op.name) @@ -100,7 +100,7 @@ class FunctionBodyTransformerTest(converter_testing.TestCase): 'ag_ctx': ag_ctx, 'converter': converter } - with self.converted(test_fn, function_scopes, ns) as result: + with self.converted(test_fn, functions, ns) as result: result.test_fn() @test_util.run_deprecated_v1 @@ -118,7 +118,7 @@ class FunctionBodyTransformerTest(converter_testing.TestCase): ns = {'TestClass': TestClass} node, ctx = self.prepare(TestClass, ns) - node = function_scopes.transform(node, ctx) + node = functions.transform(node, ctx) with self.compiled(node, {}, (ops.name_scope,)) as result: first, second = result.TestClass().test_fn(constant_op.constant(1)) @@ -126,6 +126,15 @@ class FunctionBodyTransformerTest(converter_testing.TestCase): self.assertNotIn('inner_fn', first.op.name) self.assertIn('test_fn/inner_fn/', second.op.inputs[0].name) + def test_lambda_in_return_value(self): + + def test_fn(): + return lambda x: x + 1 + + with self.converted(test_fn, functions, {}) as result: + result_l = result.test_fn() + self.assertTrue(result_l.fake_autograph_artifact) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/autograph/converters/loop_integration_test.py b/tensorflow/python/autograph/converters/loop_integration_test.py new file mode 100644 index 00000000000..351eb7b92cf --- /dev/null +++ b/tensorflow/python/autograph/converters/loop_integration_test.py @@ -0,0 +1,95 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Integration Tests for loop.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.autograph.converters import break_statements +from tensorflow.python.autograph.converters import continue_statements +from tensorflow.python.autograph.converters import control_flow +from tensorflow.python.autograph.core import converter_testing +from tensorflow.python.framework import constant_op +from tensorflow.python.platform import test + + +class LoopIntegrationTest(converter_testing.TestCase): + + def assertTransformedEquivalent(self, test_fn, *inputs): + with self.converted(test_fn, + [break_statements, continue_statements, control_flow], + {}, (constant_op.constant,)) as result: + self.assertEqual(test_fn(*inputs), result.test_fn(*inputs)) + + def test_while_loop_with_else(self): + + def test_fn(x): + while x > 2: + x /= 2 + else: + x += 1 + return x + + self.assertTransformedEquivalent(test_fn, 4) + self.assertTransformedEquivalent(test_fn, 2) + + def test_while_loop_with_else_and_break(self): + + def test_fn(cond1): + x = 8 + while x > 2: + x /= 2 + if cond1: + break + else: + x += 1 + return x + + self.assertTransformedEquivalent(test_fn, True) + self.assertTransformedEquivalent(test_fn, False) + + def test_for_loop_with_else(self): + + def test_fn(l): + res = 0 + for x in l: + res += x + else: + res += 1 + return res + + self.assertTransformedEquivalent(test_fn, []) + self.assertTransformedEquivalent(test_fn, [1, 2]) + + def test_for_loop_with_else_and_break(self): + + def test_fn(flag): + l = [1, 2, 3] + res = 0 + for x in l: + res += x + if flag: + break + else: + res += 1 + return res + + self.assertTransformedEquivalent(test_fn, True) + self.assertTransformedEquivalent(test_fn, False) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/autograph/core/converter_testing.py b/tensorflow/python/autograph/core/converter_testing.py index 4b170159b8b..8afcbdfb6bd 100644 --- a/tensorflow/python/autograph/core/converter_testing.py +++ b/tensorflow/python/autograph/core/converter_testing.py @@ -96,6 +96,10 @@ class TestCase(test.TestCase): kwargs = {} return f(*args, **kwargs) + def fake_autograph_artifact(f): + setattr(f, 'fake_autograph_artifact', True) + return f + try: result, source, source_map = loader.load_ast( node, include_source_map=True) @@ -111,6 +115,7 @@ class TestCase(test.TestCase): fake_ag.Feature = converter.Feature fake_ag.utils = utils fake_ag.FunctionScope = function_wrappers.FunctionScope + fake_ag.autograph_artifact = fake_autograph_artifact result.ag__ = fake_ag result.ag_source_map__ = source_map for k, v in namespace.items(): diff --git a/tensorflow/python/autograph/impl/conversion.py b/tensorflow/python/autograph/impl/conversion.py index 7134c2c0b69..d4706879b0a 100644 --- a/tensorflow/python/autograph/impl/conversion.py +++ b/tensorflow/python/autograph/impl/conversion.py @@ -40,7 +40,7 @@ from tensorflow.python.autograph.converters import conditional_expressions from tensorflow.python.autograph.converters import continue_statements from tensorflow.python.autograph.converters import control_flow from tensorflow.python.autograph.converters import directives -from tensorflow.python.autograph.converters import function_scopes +from tensorflow.python.autograph.converters import functions from tensorflow.python.autograph.converters import lists from tensorflow.python.autograph.converters import logical_expressions from tensorflow.python.autograph.converters import return_statements @@ -616,7 +616,7 @@ def node_to_graph(node, context): unsupported_features_checker.verify(node) node = converter.standard_analysis(node, context, is_initial=True) - node = converter.apply_(node, context, function_scopes) + node = converter.apply_(node, context, functions) node = converter.apply_(node, context, arg_defaults) node = converter.apply_(node, context, directives) node = converter.apply_(node, context, break_statements) diff --git a/tensorflow/python/autograph/pyct/BUILD b/tensorflow/python/autograph/pyct/BUILD index 7881b17f88b..735d504f18f 100644 --- a/tensorflow/python/autograph/pyct/BUILD +++ b/tensorflow/python/autograph/pyct/BUILD @@ -24,6 +24,7 @@ py_library( "__init__.py", "anno.py", "ast_util.py", + "cache.py", "cfg.py", "error_utils.py", "errors.py", @@ -76,6 +77,21 @@ py_test( ], ) +py_test( + name = "cache_test", + srcs = ["cache_test.py"], + python_version = "PY3", + srcs_version = "PY2AND3", + tags = [ + "no_oss_py2", + ], + deps = [ + ":pyct", + "//tensorflow/python:client_testlib", + "@gast_archive//:gast", + ], +) + py_test( name = "cfg_test", srcs = ["cfg_test.py"], diff --git a/tensorflow/python/autograph/pyct/cache.py b/tensorflow/python/autograph/pyct/cache.py new file mode 100644 index 00000000000..d9af6e6156a --- /dev/null +++ b/tensorflow/python/autograph/pyct/cache.py @@ -0,0 +1,97 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Caching utilities.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import inspect +import weakref + + +# TODO(mdan): Add a garbage collection hook for cleaning up modules. +class _TransformedFnCache(object): + """Generic hierarchical cache for transformed functions. + + The keys are soft references (i.e. they are discarded when the key is + destroyed) created from the source function by `_get_key`. The subkeys are + strong references and can be any value. Typically they identify different + kinds of transformation. + """ + + __slots__ = ('_cache',) + + def __init__(self): + self._cache = weakref.WeakKeyDictionary() + + def _get_key(self, entity): + raise NotImplementedError('subclasses must override') + + def has(self, entity, subkey): + key = self._get_key(entity) + parent = self._cache.get(key, None) + if parent is None: + return False + return subkey in parent + + def __getitem__(self, entity): + key = self._get_key(entity) + parent = self._cache.get(key, None) + if parent is None: + # The bucket is initialized to support this usage: + # cache[key][subkey] = value + self._cache[key] = parent = {} + return parent + + def __len__(self): + return len(self._cache) + + +class CodeObjectCache(_TransformedFnCache): + """A function cache based on code objects. + + Code objects are good proxies for the source code of a function. + + This cache efficiently handles functions that share code objects, such as + functions defined in a loop, bound methods, etc. + + The cache falls back to the function object, if it doesn't have a code object. + """ + + def _get_key(self, entity): + if hasattr(entity, '__code__'): + return entity.__code__ + else: + return entity + + +class UnboundInstanceCache(_TransformedFnCache): + """A function cache based on unbound function objects. + + Using the function for the cache key allows efficient handling of object + methods. + + Unlike the _CodeObjectCache, this discriminates between different functions + even if they have the same code. This is needed for decorators that may + masquerade as another function. + """ + + def _get_key(self, entity): + if inspect.ismethod(entity): + return entity.__func__ + return entity + + diff --git a/tensorflow/python/autograph/pyct/cache_test.py b/tensorflow/python/autograph/pyct/cache_test.py new file mode 100644 index 00000000000..6c40954be56 --- /dev/null +++ b/tensorflow/python/autograph/pyct/cache_test.py @@ -0,0 +1,79 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for cache module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.autograph.pyct import cache +from tensorflow.python.platform import test + + +class CacheTest(test.TestCase): + + def test_code_object_cache(self): + + def factory(x): + def test_fn(): + return x + 1 + return test_fn + + c = cache.CodeObjectCache() + + f1 = factory(1) + dummy = object() + + c[f1][1] = dummy + + self.assertTrue(c.has(f1, 1)) + self.assertFalse(c.has(f1, 2)) + self.assertIs(c[f1][1], dummy) + self.assertEqual(len(c), 1) + + f2 = factory(2) + + self.assertTrue(c.has(f2, 1)) + self.assertIs(c[f2][1], dummy) + self.assertEqual(len(c), 1) + + def test_unbound_instance_cache(self): + + class TestClass(object): + + def method(self): + pass + + c = cache.UnboundInstanceCache() + + o1 = TestClass() + dummy = object() + + c[o1.method][1] = dummy + + self.assertTrue(c.has(o1.method, 1)) + self.assertFalse(c.has(o1.method, 2)) + self.assertIs(c[o1.method][1], dummy) + self.assertEqual(len(c), 1) + + o2 = TestClass() + + self.assertTrue(c.has(o2.method, 1)) + self.assertIs(c[o2.method][1], dummy) + self.assertEqual(len(c), 1) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/autograph/pyct/common_transformers/anf.py b/tensorflow/python/autograph/pyct/common_transformers/anf.py index 009ae2b4417..15ceefcbdd4 100644 --- a/tensorflow/python/autograph/pyct/common_transformers/anf.py +++ b/tensorflow/python/autograph/pyct/common_transformers/anf.py @@ -36,11 +36,11 @@ from tensorflow.python.autograph.pyct import templates from tensorflow.python.autograph.pyct import transformer +# TODO(mdan): Replace with naming.Namer. class DummyGensym(object): """A dumb gensym that suffixes a stem by sequential numbers from 1000.""" - def __init__(self, ctx): - del ctx + def __init__(self): # A proper implementation needs to account for: # * ctx.info.namespace # * all the symbols defined in the AST @@ -105,14 +105,12 @@ class AnfTransformer(transformer.Base): # processing the `body` and the `orelse` need to be kept together with them, # and not accidentally lifted out of the `if`. - def __init__(self, ctx, config, gensym_source=None): + def __init__(self, ctx, config): """Creates an ANF transformer. Args: ctx: transformer.Context config: Configuration - gensym_source: An optional object with the same interface as `DummyGensym` - for generating unique names """ super(AnfTransformer, self).__init__(ctx) if config is None: @@ -137,10 +135,7 @@ class AnfTransformer(transformer.Base): (ASTEdgePattern(ANY, ANY, gast.expr), REPLACE)] else: self._overrides = config - if gensym_source is None: - self._gensym = DummyGensym(ctx) - else: - self._gensym = gensym_source(ctx) + self._gensym = DummyGensym() self._pending_statements = [] def _consume_pending_statements(self): @@ -529,7 +524,7 @@ def _is_trivial(node): return False -def transform(node, ctx, config=None, gensym_source=None): +def transform(node, ctx, config=None): """Converts the given node to A-normal form (ANF). The general idea of A-normal form: https://en.wikipedia.org/wiki/A-normal_form @@ -605,7 +600,5 @@ def transform(node, ctx, config=None, gensym_source=None): argument provide? config: Optional ANF configuration. If omitted, ANF replaces all expression expect literal constants. - gensym_source: An optional object with the same interface as `DummyGensym` - for generating unique names. """ - return AnfTransformer(ctx, config, gensym_source=gensym_source).visit(node) + return AnfTransformer(ctx, config).visit(node) diff --git a/tensorflow/python/autograph/pyct/common_transformers/anf_test.py b/tensorflow/python/autograph/pyct/common_transformers/anf_test.py index 80715f115be..ced2ee3a975 100644 --- a/tensorflow/python/autograph/pyct/common_transformers/anf_test.py +++ b/tensorflow/python/autograph/pyct/common_transformers/anf_test.py @@ -29,30 +29,8 @@ from tensorflow.python.autograph.pyct.common_transformers import anf from tensorflow.python.platform import test -class DummyGensym(object): - """A dumb gensym that suffixes a stem by sequential numbers from 1000.""" - - def __init__(self, ctx): - del ctx - # A proper implementation needs to account for: - # * ctx.info.namespace - # * all the symbols defined in the AST - # * the symbols generated so far - self._idx = 0 - - def new_name(self, stem='tmp'): - self._idx += 1 - return stem + '_' + str(1000 + self._idx) - - -# These two test functions have to be top-level, not nested, for compatibility -# with some unknown version of Python 2.7 preceding 2.7.15. Why? Because -# `exec` and nested function definitions _incompatibly_ change the -# representation of local variables, such that `exec` inside a nested function -# definition is a syntax error in that version. The tuple form of `exec` fixes -# this problem, but apparently that was introduced in some unknown version of -# Python that's more recent than at least one version that we wish to be -# compatible with. +# TODO(mdan): These two functions no longer need to be at the top level. +# TODO(mdan): Don't use exec. def exec_test_function(): # The point is to test A-normal form conversion of exec # pylint: disable=exec-used @@ -88,9 +66,7 @@ class AnfTestBase(test.TestCase): # statements. exp_node, _ = parser.parse_entity(expected_fn, future_features=()) node, _ = parser.parse_entity(test_fn, future_features=()) - node = anf.transform( - node, self._simple_context(), - config=config, gensym_source=DummyGensym) + node = anf.transform(node, self._simple_context(), config=config) exp_name = exp_node.name # Ignoring the function names in the result because they can't be # the same (because both functions have to exist in the same scope @@ -98,8 +74,7 @@ class AnfTestBase(test.TestCase): node.name = exp_name self.assert_same_ast(exp_node, node) # Check that ANF is idempotent - node_repeated = anf.transform( - node, self._simple_context(), gensym_source=DummyGensym) + node_repeated = anf.transform(node, self._simple_context()) self.assert_same_ast(node_repeated, node) @@ -466,9 +441,7 @@ class AnfNonTransformationTest(AnfTransformerTest): orig_source = parser.unparse(node, indentation=' ') orig_str = textwrap.dedent(orig_source).strip() config = [(anf.ANY, anf.LEAVE)] # Configuration to transform nothing - node = anf.transform( - node, self._simple_context(), - config=config, gensym_source=DummyGensym) + node = anf.transform(node, self._simple_context(), config=config) new_source = parser.unparse(node, indentation=' ') new_str = textwrap.dedent(new_source).strip() self.assertEqual(orig_str, new_str) diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index f69618245f3..bcd27fb6318 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -862,7 +862,7 @@ class BaseSession(SessionInterface): * A `tf.Tensor`. The corresponding fetched value will be a numpy ndarray containing the value of that tensor. - * A `tf.SparseTensor`. + * A `tf.sparse.SparseTensor`. The corresponding fetched value will be a `tf.compat.v1.SparseTensorValue` containing the value of that sparse tensor. @@ -907,7 +907,7 @@ class BaseSession(SessionInterface): `tf.compat.v1.placeholder`, the shape of the value will be checked for compatibility with the placeholder. * If the key is a - `tf.SparseTensor`, + `tf.sparse.SparseTensor`, the value should be a `tf.compat.v1.SparseTensorValue`. * If the key is a nested tuple of `Tensor`s or `SparseTensor`s, the value diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index e1423f312b5..0b65b446e2b 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -31,7 +31,7 @@ from tensorflow.python.util.tf_export import tf_export # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 4, 1) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 4, 6) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None diff --git a/tensorflow/python/compiler/tensorrt/BUILD b/tensorflow/python/compiler/tensorrt/BUILD index 87a89b66d9e..1e4c215994f 100644 --- a/tensorflow/python/compiler/tensorrt/BUILD +++ b/tensorflow/python/compiler/tensorrt/BUILD @@ -200,6 +200,5 @@ cuda_py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:framework_test_lib", "//tensorflow/python/estimator", - "//tensorflow/python/keras", ], ) diff --git a/tensorflow/python/compiler/tensorrt/test/quantization_mnist_test.py b/tensorflow/python/compiler/tensorrt/test/quantization_mnist_test.py index 5c761aae5b1..92e44aa68a8 100644 --- a/tensorflow/python/compiler/tensorrt/test/quantization_mnist_test.py +++ b/tensorflow/python/compiler/tensorrt/test/quantization_mnist_test.py @@ -18,10 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function + +import tensorflow_datasets as tfds + from tensorflow.compiler.tf2tensorrt._pywrap_py_utils import get_linked_tensorrt_version from tensorflow.compiler.tf2tensorrt._pywrap_py_utils import is_tensorrt_enabled from tensorflow.core.protobuf import config_pb2 -from tensorflow.python import keras from tensorflow.python.compiler.tensorrt import trt_convert from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator.estimator import Estimator @@ -33,10 +35,10 @@ from tensorflow.python.framework import graph_util from tensorflow.python.framework import importer from tensorflow.python.framework import ops from tensorflow.python.framework import test_util -from tensorflow.python.keras.datasets import mnist from tensorflow.python.layers import layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics from tensorflow.python.ops import nn @@ -81,12 +83,12 @@ class QuantizationAwareTrainingMNISTTest(test_util.TensorFlowTestCase): 'kernel', shape=[num_inputs, num_outputs], dtype=dtypes.float32, - initializer=keras.initializers.glorot_uniform()) + initializer=init_ops.GlorotUniform()) bias = variable_scope.get_variable( 'bias', shape=[num_outputs], dtype=dtypes.float32, - initializer=keras.initializers.zeros()) + initializer=init_ops.Zeros()) x = math_ops.matmul(x, kernel) x = _Quantize(x, quantization_range) x = nn.bias_add(x, bias) @@ -179,19 +181,15 @@ class QuantizationAwareTrainingMNISTTest(test_util.TensorFlowTestCase): Returns: The Estimator evaluation result. """ - # Get dataset - train_data, test_data = mnist.load_data() - - def _PreprocessFn(x, y): + def _PreprocessFn(entry): + x, y = entry['image'], entry['label'] x = math_ops.cast(x, dtypes.float32) - x = array_ops.expand_dims(x, axis=2) x = 2.0 * (x / 255.0) - 1.0 y = math_ops.cast(y, dtypes.int32) return x, y def _EvalInputFn(): - mnist_x, mnist_y = test_data - dataset = dataset_ops.Dataset.from_tensor_slices((mnist_x, mnist_y)) + dataset = tfds.load('mnist', split='test') dataset = dataset.map( map_func=_PreprocessFn, num_parallel_calls=8).batch(batch_size=batch_size) @@ -201,9 +199,8 @@ class QuantizationAwareTrainingMNISTTest(test_util.TensorFlowTestCase): return features, labels def _TrainInputFn(): - mnist_x, mnist_y = train_data - dataset = dataset_ops.Dataset.from_tensor_slices((mnist_x, mnist_y)) - dataset = dataset.shuffle(2 * len(mnist_x)) + dataset = tfds.load('mnist', split='train') + dataset = dataset.shuffle(60000) dataset = dataset.map( map_func=_PreprocessFn, num_parallel_calls=8).batch(batch_size=batch_size) diff --git a/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py index 3245a100265..773061d57a7 100644 --- a/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py +++ b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py @@ -522,6 +522,25 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): logging.info("Writing graph to %s/%s", temp_dir, graph_name) graph_io.write_graph(gdef, temp_dir, graph_name) + # Remove the graph sequence number prefix from the name only if the name has + # a prefix TRTEngineOp_n_. When expecting_prefix is true, assert such a + # prefix exists. + def _RemoveGraphSequenceNumberImpl(self, name, expecting_prefix): + match = re.search(r"TRTEngineOp_\d+_", name) + has_prefix = match and name.startswith(match.group(0)) + assert (not expecting_prefix) or has_prefix + if has_prefix: + parts = name.split("_", maxsplit=2) + assert len(parts) == 3 + return parts[0] + "_" + parts[2] + return name + + def _RemoveGraphSequenceNumber(self, name): + return self._RemoveGraphSequenceNumberImpl(name, True) + + def _MayRemoveGraphSequenceNumber(self, name): + return self._RemoveGraphSequenceNumberImpl(name, False) + def _VerifyConnections(self, expected_engines, original_gdef, converted_gdef): old_to_new_node_map = { self._ToString(node.name): self._ToString(node.name) @@ -579,11 +598,14 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): # Compute the actual mapping from each node to its input nodes. actual_input_map = {} for node in converted_gdef.node: - name_str = self._ToString(node.name) + name_str = node.name + if node.op == "TRTEngineOp": + name_str = self._RemoveGraphSequenceNumber(name_str) actual_input_map[name_str] = set() input_set = actual_input_map[name_str] for inp in node.input: (prefix, node_name) = _InputName(inp) + node_name = self._MayRemoveGraphSequenceNumber(node_name) input_set.add(prefix + node_name) self.assertEqual( @@ -628,7 +650,8 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): self.assertIn(function_name, functions) if not IsQuantizationWithCalibration and not is_dynamic_engine: self.assertTrue(len(node.attr["serialized_segment"].s), node.name) - self.assertIn(node.name, expected_engines) + self.assertIn( + self._RemoveGraphSequenceNumber(node.name), expected_engines) self.assertEqual( self._ToBytes(run_params.precision_mode), node.attr["precision_mode"].s, node.name) @@ -662,7 +685,8 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): node.name for node in gdef_to_verify.node if node.op == "TRTEngineOp" ] for func in gdef_to_verify.library.function: - if not re.search(r"TRTEngineOp_\d+_native_segment", func.signature.name): + if not re.search(r"TRTEngineOp_\d+_\d+_native_segment", + func.signature.name): for node in func.node_def: all_op_names.append(node.name) if node.op == "TRTEngineOp": @@ -670,9 +694,12 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): # Remove the function name prefix. def _Canonicalize(names): return set(self._ToString(name.split("/")[-1]) for name in names) + # Remove the graph sequence number prefix from all the names. + def _RemoveGraphSequenceNumber(names): + return set(self._RemoveGraphSequenceNumber(name) for name in names) all_op_names = _Canonicalize(all_op_names) - trt_op_names = _Canonicalize(trt_op_names) + trt_op_names = _RemoveGraphSequenceNumber(_Canonicalize(trt_op_names)) if isinstance(expected_engines, dict): # For simplicity we don't verify the connections inside the engine in diff --git a/tensorflow/python/compiler/tensorrt/trt_convert_test.py b/tensorflow/python/compiler/tensorrt/trt_convert_test.py index fbe312fc4d6..df21e93f836 100644 --- a/tensorflow/python/compiler/tensorrt/trt_convert_test.py +++ b/tensorflow/python/compiler/tensorrt/trt_convert_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import gc import os +import re import tempfile from absl.testing import parameterized @@ -310,6 +311,24 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase): converter.save(output_saved_model_dir=output_saved_model_dir) return output_graph_def + # Remove the graph sequence number prefix from the name only if the name has + # a prefix TRTEngineOp_n_. + def _MayRemoveGraphSequenceNumber(self, name): + prefix = re.search(r"TRTEngineOp_\d+_", name) + if prefix and name.startswith(prefix.group(0)): + parts = name.split("_", maxsplit=2) + assert len(parts) == 3 + return parts[0] + "_" + parts[2] + return name + + # Return the unique TRTEngineOp in the given graph def. + def _GetUniqueTRTEngineOp(self, graph_def): + trt_engine_nodes = [ + node for node in graph_def.node if node.op == "TRTEngineOp" + ] + assert len(trt_engine_nodes) == 1 + return trt_engine_nodes[0] + def _TestTrtGraphConverter(self, device, output_saved_model_dir=None, @@ -330,7 +349,10 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase): graph_defs_to_verify.append(saved_model_graph_def) for graph_def in graph_defs_to_verify: - node_name_to_op = {node.name: node.op for node in graph_def.node} + node_name_to_op = { + self._MayRemoveGraphSequenceNumber(node.name): node.op + for node in graph_def.node + } self.assertEqual( { "input1": "Placeholder", @@ -434,13 +456,13 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase): trt_op_names = [] for node in graph_def.node: if node.op == "TRTEngineOp": - trt_op_names.append(node.name) + trt_op_names.append(self._MayRemoveGraphSequenceNumber(node.name)) if check_fn: check_fn(node) for func in graph_def.library.function: for node in func.node_def: if node.op == "TRTEngineOp": - trt_op_names.append(node.name) + trt_op_names.append(self._MayRemoveGraphSequenceNumber(node.name)) if check_fn: check_fn(node) self.assertEqual(1, len(trt_op_names)) @@ -473,11 +495,15 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase): # Verify the converted GraphDef and ConcreteFunction. self._CheckTrtOps(converter._converted_func) # pylint: disable=protected-access + trt_engine_name = self._GetUniqueTRTEngineOp( + converter._converted_graph_def).name + # Save the converted model without any TRT engine cache. output_saved_model_dir = self.mkdtemp() converter.save(output_saved_model_dir) unexpected_asset_file = os.path.join( - output_saved_model_dir, "assets/trt-serialized-engine.TRTEngineOp_0") + output_saved_model_dir, + "assets/trt-serialized-engine." + trt_engine_name) self.assertFalse(os.path.exists(unexpected_asset_file)) # Run the converted function to populate the engine cache. @@ -490,7 +516,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase): output_saved_model_dir = self.mkdtemp() converter.save(output_saved_model_dir) expected_asset_file = os.path.join( - output_saved_model_dir, "assets/trt-serialized-engine.TRTEngineOp_0") + output_saved_model_dir, + "assets/trt-serialized-engine." + trt_engine_name) self.assertTrue(os.path.exists(expected_asset_file)) self.assertTrue(os.path.getsize(expected_asset_file)) @@ -566,6 +593,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase): converter.convert(calibration_input_fn=_CalibrationInputFn) + trt_engine_name = self._GetUniqueTRTEngineOp( + converter._converted_graph_def).name + def _CheckFn(node): self.assertTrue(len(node.attr["calibration_data"].s), node.name) @@ -583,7 +613,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase): output_saved_model_dir = self.mkdtemp() converter.save(output_saved_model_dir) expected_asset_file = os.path.join( - output_saved_model_dir, "assets/trt-serialized-engine.TRTEngineOp_0") + output_saved_model_dir, + "assets/trt-serialized-engine." + trt_engine_name) self.assertTrue(os.path.exists(expected_asset_file)) self.assertTrue(os.path.getsize(expected_asset_file)) @@ -635,6 +666,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase): converter = self._CreateConverterV2(input_saved_model_dir) converter.convert() + trt_engine_name = self._GetUniqueTRTEngineOp( + converter._converted_graph_def).name + def _InputFn(): yield np_input1, np_input2 @@ -645,7 +679,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase): def _DestroyCache(): with ops.device("GPU:0"): handle = gen_trt_ops.create_trt_resource_handle( - resource_name="TRTEngineOp_0") + resource_name=trt_engine_name) gen_resource_variable_ops.destroy_resource_op( handle, ignore_lookup_error=False) diff --git a/tensorflow/python/data/experimental/ops/batching.py b/tensorflow/python/data/experimental/ops/batching.py index 398ec98a7cb..4d2bfbb8dc9 100644 --- a/tensorflow/python/data/experimental/ops/batching.py +++ b/tensorflow/python/data/experimental/ops/batching.py @@ -100,15 +100,15 @@ def dense_to_ragged_batch(batch_size, @tf_export("data.experimental.dense_to_sparse_batch") def dense_to_sparse_batch(batch_size, row_shape): - """A transformation that batches ragged elements into `tf.SparseTensor`s. + """A transformation that batches ragged elements into `tf.sparse.SparseTensor`s. Like `Dataset.padded_batch()`, this transformation combines multiple consecutive elements of the dataset, which might have different shapes, into a single element. The resulting element has three components (`indices`, `values`, and `dense_shape`), which - comprise a `tf.SparseTensor` that represents the same data. The + comprise a `tf.sparse.SparseTensor` that represents the same data. The `row_shape` represents the dense shape of each row in the - resulting `tf.SparseTensor`, to which the effective batch size is + resulting `tf.sparse.SparseTensor`, to which the effective batch size is prepended. For example: ```python @@ -133,9 +133,9 @@ def dense_to_sparse_batch(batch_size, row_shape): consecutive elements of this dataset to combine in a single batch. row_shape: A `tf.TensorShape` or `tf.int64` vector tensor-like object representing the equivalent dense shape of a row in the resulting - `tf.SparseTensor`. Each element of this dataset must have the same rank as - `row_shape`, and must have size less than or equal to `row_shape` in each - dimension. + `tf.sparse.SparseTensor`. Each element of this dataset must have the same + rank as `row_shape`, and must have size less than or equal to `row_shape` + in each dimension. Returns: A `Dataset` transformation function, which can be passed to @@ -295,7 +295,7 @@ def unbatch(): class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset): - """A `Dataset` that batches ragged dense elements into `tf.SparseTensor`s.""" + """A `Dataset` that batches ragged dense elements into `tf.sparse.SparseTensor`s.""" def __init__(self, input_dataset, batch_size, row_shape): """See `Dataset.dense_to_sparse_batch()` for more details.""" diff --git a/tensorflow/python/data/experimental/ops/grouping.py b/tensorflow/python/data/experimental/ops/grouping.py index e48ffbc2d46..e3bfb7950ff 100644 --- a/tensorflow/python/data/experimental/ops/grouping.py +++ b/tensorflow/python/data/experimental/ops/grouping.py @@ -161,7 +161,7 @@ def bucket_by_sequence_length(element_length_func, bucket), and caller must ensure that the source `Dataset` does not contain any elements with length longer than `max(bucket_boundaries)`. no_padding: `bool`, indicates whether to pad the batch features (features - need to be either of type `tf.SparseTensor` or of same shape). + need to be either of type `tf.sparse.SparseTensor` or of same shape). drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing whether the last batch should be dropped in the case it has fewer than `batch_size` elements; the default behavior is not to drop the smaller diff --git a/tensorflow/python/data/kernel_tests/from_sparse_tensor_slices_test.py b/tensorflow/python/data/kernel_tests/from_sparse_tensor_slices_test.py index 147d31366bb..d7a2c158de9 100644 --- a/tensorflow/python/data/kernel_tests/from_sparse_tensor_slices_test.py +++ b/tensorflow/python/data/kernel_tests/from_sparse_tensor_slices_test.py @@ -37,7 +37,7 @@ class FromSparseTensorSlicesTest(test_base.DatasetTestBase, @combinations.generate( combinations.combine(tf_api_version=1, mode=["graph"])) def testFromSparseTensorSlices(self): - """Test a dataset based on slices of a `tf.SparseTensor`.""" + """Test a dataset based on slices of a `tf.sparse.SparseTensor`.""" st = array_ops.sparse_placeholder(dtypes.float64) iterator = dataset_ops.make_initializable_iterator( dataset_ops.Dataset.from_sparse_tensor_slices(st)) diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index d2c247678a2..b23df3672c9 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -158,7 +158,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): Elements can be nested structures of tuples, named tuples, and dictionaries. Element components can be of any type representable by `tf.TypeSpec`, - including `tf.Tensor`, `tf.data.Dataset`, `tf.SparseTensor`, + including `tf.Tensor`, `tf.data.Dataset`, `tf.sparse.SparseTensor`, `tf.RaggedTensor`, and `tf.TensorArray`. >>> a = 1 # Integer element @@ -1456,7 +1456,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): array([[ 10, 100], [ 11, 12]], dtype=int32))] See also `tf.data.experimental.dense_to_sparse_batch`, which combines - elements that may have different shapes into a `tf.SparseTensor`. + elements that may have different shapes into a `tf.sparse.SparseTensor`. Args: batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of @@ -1589,6 +1589,23 @@ name=None)) >>> list(d.as_numpy_iterator()) [b'HELLO', b'WORLD'] + 3) Use `tf.numpy_function`, which also allows you to write arbitrary + Python code. Note that `tf.py_function` accepts `tf.Tensor` whereas + `tf.numpy_function` accepts numpy arrays and returns only numpy arrays. + For example: + + >>> d = tf.data.Dataset.from_tensor_slices(['hello', 'world']) + >>> def upper_case_fn(t: np.ndarray): + ... return t.decode('utf-8').upper() + >>> d = d.map(lambda x: tf.numpy_function(func=upper_case_fn, + ... inp=[x], Tout=tf.string)) + >>> list(d.as_numpy_iterator()) + [b'HELLO', b'WORLD'] + + Note that the use of `tf.numpy_function` and `tf.py_function` + in general precludes the possibility of executing user-defined + transformations in parallel (because of Python GIL). + Performance can often be improved by setting `num_parallel_calls` so that `map` will use multiple threads to process elements. If deterministic order isn't required, it can also improve performance to set @@ -2086,15 +2103,40 @@ class DatasetV1(DatasetV2): raise NotImplementedError("Dataset._as_variant_tensor") @deprecation.deprecated( - None, "Use `for ... in dataset:` to iterate over a dataset. If using " - "`tf.estimator`, return the `Dataset` object directly from your input " - "function. As a last resort, you can use " - "`tf.compat.v1.data.make_one_shot_iterator(dataset)`.") + None, "This is a deprecated API that should only be used in TF 1 graph " + "mode and legacy TF 2 graph mode available through `tf.compat.v1`. In " + "all other situations -- namely, eager mode and inside `tf.function` -- " + "you can consume dataset elements using `for elem in dataset: ...` or " + "by explicitly creating iterator via `iterator = iter(dataset)` and " + "fetching its elements via `values = next(iterator)`. Furthermore, " + "this API is not available in TF 2. During the transition from TF 1 " + "to TF 2 you can use `tf.compat.v1.data.make_one_shot_iterator(dataset)` " + "to create a TF 1 graph mode style iterator for a dataset created " + "through TF 2 APIs. Note that this should be a transient state of your " + "code base as there are in general no guarantees about the " + "interoperability of TF 1 and TF 2 code.") def make_one_shot_iterator(self): """Creates an `Iterator` for enumerating the elements of this dataset. Note: The returned iterator will be initialized automatically. - A "one-shot" iterator does not currently support re-initialization. + A "one-shot" iterator does not currently support re-initialization. For + that see `make_initializable_iterator`. + + Example: + + ```python + # Building graph ... + dataset = ... + next_value = dataset.make_one_shot_iterator().get_next() + + # ... from within a session ... + try: + while True: + value = sess.run(next_value) + ... + except tf.errors.OutOfRangeError: + pass + ``` Returns: An `Iterator` over the elements of this dataset. @@ -2153,10 +2195,19 @@ class DatasetV1(DatasetV2): get_legacy_output_classes(self)) @deprecation.deprecated( - None, "Use `for ... in dataset:` to iterate over a dataset. If using " - "`tf.estimator`, return the `Dataset` object directly from your input " - "function. As a last resort, you can use " - "`tf.compat.v1.data.make_initializable_iterator(dataset)`.") + None, "This is a deprecated API that should only be used in TF 1 graph " + "mode and legacy TF 2 graph mode available through `tf.compat.v1`. " + "In all other situations -- namely, eager mode and inside `tf.function` " + "-- you can consume dataset elements using `for elem in dataset: ...` " + "or by explicitly creating iterator via `iterator = iter(dataset)` " + "and fetching its elements via `values = next(iterator)`. " + "Furthermore, this API is not available in TF 2. During the transition " + "from TF 1 to TF 2 you can use " + "`tf.compat.v1.data.make_initializable_iterator(dataset)` to create a TF " + "1 graph mode style iterator for a dataset created through TF 2 APIs. " + "Note that this should be a transient state of your code base as there " + "are in general no guarantees about the interoperability of TF 1 and TF " + "2 code.") def make_initializable_iterator(self, shared_name=None): """Creates an `Iterator` for enumerating the elements of this dataset. @@ -2164,10 +2215,19 @@ class DatasetV1(DatasetV2): and you must run the `iterator.initializer` operation before using it: ```python + # Building graph ... dataset = ... iterator = dataset.make_initializable_iterator() - # ... + next_value = iterator.get_next() # This is a Tensor. + + # ... from within a session ... sess.run(iterator.initializer) + try: + while True: + value = sess.run(next_value) + ... + except tf.errors.OutOfRangeError: + pass ``` Args: @@ -2181,7 +2241,6 @@ class DatasetV1(DatasetV2): Raises: RuntimeError: If eager execution is enabled. """ - return self._make_initializable_iterator(shared_name) def _make_initializable_iterator(self, shared_name=None): # pylint: disable=missing-docstring @@ -2266,10 +2325,10 @@ class DatasetV1(DatasetV2): @staticmethod @deprecation.deprecated(None, "Use `tf.data.Dataset.from_tensor_slices()`.") def from_sparse_tensor_slices(sparse_tensor): - """Splits each rank-N `tf.SparseTensor` in this dataset row-wise. + """Splits each rank-N `tf.sparse.SparseTensor` in this dataset row-wise. Args: - sparse_tensor: A `tf.SparseTensor`. + sparse_tensor: A `tf.sparse.SparseTensor`. Returns: Dataset: A `Dataset` of rank-(N-1) sparse tensors. @@ -2874,14 +2933,14 @@ class TensorSliceDataset(DatasetSource): class SparseTensorSliceDataset(DatasetSource): - """A `Dataset` that splits a rank-N `tf.SparseTensor` into its rows.""" + """A `Dataset` that splits a rank-N `tf.sparse.SparseTensor` into its rows.""" def __init__(self, sparse_tensor): """See `Dataset.from_sparse_tensor_slices()` for details.""" if not isinstance(sparse_tensor, sparse_tensor_lib.SparseTensor): raise TypeError( - "`sparse_tensor` must be a `tf.SparseTensor` object. Was {}.".format( - sparse_tensor)) + "`sparse_tensor` must be a `tf.sparse.SparseTensor` object." + "Was {}.".format(sparse_tensor)) self._sparse_tensor = sparse_tensor indices_shape = self._sparse_tensor.indices.get_shape() diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py index 09187705c16..5bd2824839f 100644 --- a/tensorflow/python/data/ops/iterator_ops.py +++ b/tensorflow/python/data/ops/iterator_ops.py @@ -448,7 +448,7 @@ class Iterator(trackable.Trackable): def output_classes(self): """Returns the class of each component of an element of this iterator. - The expected values are `tf.Tensor` and `tf.SparseTensor`. + The expected values are `tf.Tensor` and `tf.sparse.SparseTensor`. Returns: A nested structure of Python `type` objects corresponding to each @@ -677,7 +677,7 @@ class OwnedIterator(trackable.Trackable, composite_tensor.CompositeTensor): def output_classes(self): """Returns the class of each component of an element of this iterator. - The expected values are `tf.Tensor` and `tf.SparseTensor`. + The expected values are `tf.Tensor` and `tf.sparse.SparseTensor`. Returns: A nested structure of Python `type` objects corresponding to each diff --git a/tensorflow/core/data/service/python/BUILD b/tensorflow/python/data/service/BUILD similarity index 100% rename from tensorflow/core/data/service/python/BUILD rename to tensorflow/python/data/service/BUILD diff --git a/tensorflow/python/data/service/__init__.py b/tensorflow/python/data/service/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tensorflow/core/data/service/python/server_lib.py b/tensorflow/python/data/service/server_lib.py similarity index 98% rename from tensorflow/core/data/service/python/server_lib.py rename to tensorflow/python/data/service/server_lib.py index d3636123e0f..45b1924b812 100644 --- a/tensorflow/core/data/service/python/server_lib.py +++ b/tensorflow/python/data/service/server_lib.py @@ -20,7 +20,7 @@ from __future__ import print_function # pylint: disable=invalid-import-order,g-bad-import-order, unused-import from tensorflow.python import pywrap_tensorflow -from tensorflow.core.data.service.python import _pywrap_server_lib +from tensorflow.python.data.service import _pywrap_server_lib class MasterServer(object): diff --git a/tensorflow/core/data/service/python/server_lib_test.py b/tensorflow/python/data/service/server_lib_test.py similarity index 95% rename from tensorflow/core/data/service/python/server_lib_test.py rename to tensorflow/python/data/service/server_lib_test.py index 6e9d6b9c043..b18262bf52b 100644 --- a/tensorflow/core/data/service/python/server_lib_test.py +++ b/tensorflow/python/data/service/server_lib_test.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.core.data.service.python import server_lib +from tensorflow.python.data.service import server_lib from tensorflow.python.platform import test diff --git a/tensorflow/core/data/service/python/server_lib_wrapper.cc b/tensorflow/python/data/service/server_lib_wrapper.cc similarity index 100% rename from tensorflow/core/data/service/python/server_lib_wrapper.cc rename to tensorflow/python/data/service/server_lib_wrapper.cc diff --git a/tensorflow/python/data/util/sparse.py b/tensorflow/python/data/util/sparse.py index d7e516e24f9..fc1f9dcbf90 100644 --- a/tensorflow/python/data/util/sparse.py +++ b/tensorflow/python/data/util/sparse.py @@ -47,7 +47,7 @@ def as_dense_shapes(shapes, classes): Returns: a structure matching the nested structure of `shapes`, containing `tensor_shape.unknown_shape()` at positions where `classes` contains - `tf.SparseTensor` and matching contents of `shapes` otherwise + `tf.sparse.SparseTensor` and matching contents of `shapes` otherwise """ ret = nest.pack_sequence_as(shapes, [ tensor_shape.unknown_shape() if c is sparse_tensor.SparseTensor else shape @@ -65,8 +65,8 @@ def as_dense_types(types, classes): Returns: a structure matching the nested structure of `types`, containing - `dtypes.variant` at positions where `classes` contains `tf.SparseTensor` and - matching contents of `types` otherwise + `dtypes.variant` at positions where `classes` contains + `tf.sparse.SparseTensor` and matching contents of `types` otherwise """ ret = nest.pack_sequence_as(types, [ dtypes.variant if c is sparse_tensor.SparseTensor else ty @@ -106,8 +106,8 @@ def get_classes(tensors): Returns: a structure matching the nested structure of `tensors`, containing - `tf.SparseTensor` at positions where `tensors` contains a sparse tensor and - `tf.Tensor` otherwise + `tf.sparse.SparseTensor` at positions where `tensors` contains a sparse + tensor and `tf.Tensor` otherwise. """ return nest.pack_sequence_as(tensors, [ sparse_tensor.SparseTensor diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index b0f982e3d5c..d0a70f18294 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -255,22 +255,17 @@ py_library( ) py_library( - name = "mirrored_strategy", - srcs = ["mirrored_strategy.py"], + name = "mirrored_run", + srcs = ["mirrored_run.py"], deps = [ - ":cross_device_ops", ":device_util", ":distribute_lib", - ":input_lib", - ":multi_worker_util", - ":numpy_dataset", ":reduce_util", ":shared_variable_creator", ":values", "//tensorflow/python:array_ops", "//tensorflow/python:config", "//tensorflow/python:constant_op", - "//tensorflow/python:control_flow_ops", "//tensorflow/python:device", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", @@ -284,9 +279,33 @@ py_library( "//tensorflow/python:variable_scope", "//tensorflow/python/autograph/core", "//tensorflow/python/autograph/impl", - "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib", "//tensorflow/python/eager:context", "//tensorflow/python/eager:def_function", + ], +) + +py_library( + name = "mirrored_strategy", + srcs = ["mirrored_strategy.py"], + deps = [ + ":cross_device_ops", + ":device_util", + ":distribute_lib", + ":input_lib", + ":mirrored_run", + ":multi_worker_util", + ":numpy_dataset", + ":reduce_util", + ":values", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:device", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:util", + "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib", + "//tensorflow/python/eager:context", "//tensorflow/python/eager:tape", ], ) @@ -297,7 +316,7 @@ py_library( visibility = ["//tensorflow:internal"], deps = [ ":input_lib", - ":mirrored_strategy", + ":mirrored_run", ":numpy_dataset", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", diff --git a/tensorflow/python/distribute/cluster_resolver/BUILD b/tensorflow/python/distribute/cluster_resolver/BUILD index 1a9d0202837..8577f1978b9 100644 --- a/tensorflow/python/distribute/cluster_resolver/BUILD +++ b/tensorflow/python/distribute/cluster_resolver/BUILD @@ -152,6 +152,7 @@ tf_py_test( "//tensorflow/python:platform_test", "//tensorflow/python:training_server_lib", "//tensorflow/python/tpu/client", + "@absl_py//absl/testing:flagsaver", ], ) diff --git a/tensorflow/python/distribute/mirrored_run.py b/tensorflow/python/distribute/mirrored_run.py new file mode 100644 index 00000000000..2cd139c387f --- /dev/null +++ b/tensorflow/python/distribute/mirrored_run.py @@ -0,0 +1,454 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Class MirroredStrategy implementing tf.distribute.Strategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib +import functools +import threading +import weakref + +from tensorflow.python import pywrap_tfe +from tensorflow.python.autograph.core import ag_ctx as autograph_ctx +from tensorflow.python.autograph.impl import api as autograph +from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import shared_variable_creator +from tensorflow.python.distribute import values +from tensorflow.python.eager import context +from tensorflow.python.eager import def_function +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import device as tf_device +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import summary_ops_v2 +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import coordinator + + +def call_for_each_replica(strategy, fn, args=None, kwargs=None): + """Call `fn` on each worker devices(replica). + + It's highly recommended to wrap the call to this function inside a + `tf.function`, otherwise the performance is poor. + + Args: + strategy: `tf.distribute.Strategy`. + fn: function to call on each worker devices. + args: positional arguments to `fn`. + kwargs: keyword arguments to `fn`. + + Returns: + Wrapped returned value of `fn` from all replicas. + """ + if args is None: + args = () + if kwargs is None: + kwargs = {} + + if isinstance(fn, def_function.Function): + if strategy not in _cfer_fn_cache: + _cfer_fn_cache[strategy] = weakref.WeakKeyDictionary() + wrapped = _cfer_fn_cache[strategy].get(fn) + if wrapped is None: + # We need to wrap fn such that it triggers _call_for_each_replica inside + # the tf.function. We use _clone() instead of @tf.function wrapped + # call_for_each_replica() because we would like to retain the arguments to + # the @tf.function decorator of fn. + wrapped = fn._clone( # pylint: disable=protected-access + python_function=functools.partial(call_for_each_replica, strategy, + fn.python_function)) + _cfer_fn_cache[strategy][fn] = wrapped + return wrapped(args, kwargs) + + if context.executing_eagerly(): + logging.log_first_n( + logging.WARN, "Using %s eagerly has significant " + "overhead currently. We will be working on improving " + "this in the future, but for now please wrap " + "`call_for_each_replica` or `experimental_run` or " + "`experimental_run_v2` inside a tf.function to get " + "the best performance." % strategy.__class__.__name__, 5) + else: + # When a tf.function is wrapped to trigger _call_for_each_replica (see + # the other branch above), AutoGraph stops conversion at + # _call_for_each_replica itself (TF library functions are whitelisted). + # This makes sure that the Python function that originally passed to + # the tf.function is still converted. + fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx()) + + return _call_for_each_replica(strategy, fn, args, kwargs) + + +# Per strategy cache for call_for_each_replica def_function.Function objects. +_cfer_fn_cache = weakref.WeakKeyDictionary() + + +@contextlib.contextmanager +def _enter_graph(g, eager, creator_stack=None): + """Context manager for selecting a graph and maybe eager mode.""" + if eager: + with g.as_default(), context.eager_mode(): + if creator_stack is not None: + g._variable_creator_stack = creator_stack # pylint: disable=protected-access + yield + else: + with g.as_default(): + if creator_stack is not None: + g._variable_creator_stack = creator_stack # pylint: disable=protected-access + yield + + +def _cpu_device(device): + cpu_device = tf_device.DeviceSpec.from_string(device) + cpu_device = cpu_device.replace(device_type="CPU", device_index=0) + return cpu_device.to_string() + + +class _RequestedStop(Exception): # pylint: disable=g-bad-exception-name + pass + + +def _call_for_each_replica(distribution, fn, args, kwargs): + """Run `fn` in separate threads, once per replica/worker device. + + Args: + distribution: the DistributionStrategy object. + fn: function to run (will be run once per replica, each in its own thread). + args: positional arguments for `fn` + kwargs: keyword arguments for `fn`. + + Returns: + Merged return value of `fn` across all replicas. + + Raises: + RuntimeError: If fn() calls get_replica_context().merge_call() a different + number of times from the available devices. + """ + # TODO(josh11b): Add this option once we add synchronization to variable + # creation. Until then, this is pretty unsafe to use. + run_concurrently = False + if not context.executing_eagerly(): + # Needed for per-thread device, etc. contexts in graph mode. + ops.get_default_graph().switch_to_thread_local() + + coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,)) + + shared_variable_store = {} + devices = distribution.extended.worker_devices + + # TODO(isaprykin): Create these threads once instead of during every call. + threads = [] + for index in range(len(devices)): + variable_creator_fn = shared_variable_creator.make_fn( + shared_variable_store, index) + t = _MirroredReplicaThread( + distribution, coord, index, devices, variable_creator_fn, fn, + values.select_replica(index, args), + values.select_replica(index, kwargs)) + threads.append(t) + + for t in threads: + t.start() + + # When `fn` starts `should_run` event is set on _MirroredReplicaThread + # (`MRT`) threads. The execution waits until + # `MRT.has_paused` is set, which indicates that either `fn` is + # complete or a `get_replica_context().merge_call()` is called. If `fn` is + # complete, then `MRT.done` is set to True. Otherwise, arguments + # of `get_replica_context().merge_call` from all paused threads are grouped + # and the `merge_fn` is performed. Results of the + # `get_replica_context().merge_call` are then set to `MRT.merge_result`. + # Each such `get_replica_context().merge_call` call returns the + # `MRT.merge_result` for that thread when `MRT.should_run` event + # is reset again. Execution of `fn` resumes. + + try: + with coord.stop_on_exception(): + all_done = False + while not all_done and not coord.should_stop(): + done = [] + if run_concurrently: + for t in threads: + t.should_run.set() + for t in threads: + t.has_paused.wait() + t.has_paused.clear() + if coord.should_stop(): + return None + done.append(t.done) + else: + for t in threads: + t.should_run.set() + t.has_paused.wait() + t.has_paused.clear() + if coord.should_stop(): + return None + done.append(t.done) + if coord.should_stop(): + return None + all_done = all(done) + if not all_done: + if any(done): + raise RuntimeError("Some replicas made a different number of " + "replica_context().merge_call() calls.") + # get_replica_context().merge_call() case + merge_args = values.regroup(tuple(t.merge_args for t in threads)) + merge_kwargs = values.regroup(tuple(t.merge_kwargs for t in threads)) + # We capture the name_scope of the MRT when we call merge_fn + # to ensure that if we have opened a name scope in the MRT, + # it will be respected when executing the merge function. We only + # capture the name_scope from the first MRT and assume it is + # the same for all other MRTs. + mtt_captured_name_scope = threads[0].captured_name_scope + mtt_captured_var_scope = threads[0].captured_var_scope + # Capture and merge the control dependencies from all the threads. + mtt_captured_control_deps = set() + for t in threads: + mtt_captured_control_deps.update(t.captured_control_deps) + with ops.name_scope(mtt_captured_name_scope),\ + ops.control_dependencies(mtt_captured_control_deps), \ + variable_scope.variable_scope(mtt_captured_var_scope): + merge_result = threads[0].merge_fn(distribution, *merge_args, + **merge_kwargs) + for r, t in enumerate(threads): + t.merge_result = values.select_replica(r, merge_result) + finally: + for t in threads: + t.should_run.set() + coord.join(threads) + + return values.regroup(tuple(t.main_result for t in threads)) + + +class _MirroredReplicaThread(threading.Thread): + """A thread that runs() a function on a device.""" + + def __init__(self, dist, coord, replica_id, devices, variable_creator_fn, + fn, args, kwargs): + super(_MirroredReplicaThread, self).__init__() + self.coord = coord + self.distribution = dist + self.devices = devices + self.replica_id = replica_id + self.variable_creator_fn = variable_creator_fn + # State needed to run and return the results of `fn`. + self.main_fn = fn + self.main_args = args + self.main_kwargs = kwargs + self.main_result = None + self.done = False + # State needed to run the next merge_call() (if any) requested via + # ReplicaContext. + self.merge_fn = None + self.merge_args = None + self.merge_kwargs = None + self.merge_result = None + self.captured_name_scope = None + self.captured_var_scope = None + # We use a thread.Event for the main thread to signal when this + # thread should start running (`should_run`), and another for + # this thread to transfer control back to the main thread + # (`has_paused`, either when it gets to a + # `get_replica_context().merge_call` or when `fn` returns). In + # either case the event starts cleared, is signaled by calling + # set(). The receiving thread waits for the signal by calling + # wait() and then immediately clearing the event using clear(). + self.should_run = threading.Event() + self.has_paused = threading.Event() + # These fields have to do with inheriting various contexts from the + # parent thread: + context.ensure_initialized() + ctx = context.context() + self.in_eager = ctx.executing_eagerly() + self.record_thread_local_summary_state() + self.record_thread_local_eager_context_state() + self.context_device_policy = ( + pywrap_tfe.TFE_ContextGetDevicePlacementPolicy( + ctx._context_handle)) # pylint: disable=protected-access + self.graph = ops.get_default_graph() + with ops.init_scope(): + self._init_in_eager = context.executing_eagerly() + self._init_graph = ops.get_default_graph() + self._variable_creator_stack = self.graph._variable_creator_stack[:] # pylint: disable=protected-access + self._var_scope = variable_scope.get_variable_scope() + # Adding a "/" at end lets us re-enter this scope later. + self._name_scope = self.graph.get_name_scope() + if self._name_scope: + self._name_scope += "/" + if self.replica_id > 0: + if not self._name_scope: + self._name_scope = "" + self._name_scope += "replica_%d/" % self.replica_id + + def run(self): + self.should_run.wait() + self.should_run.clear() + try: + if self.coord.should_stop(): + return + self.restore_thread_local_summary_state() + self.restore_thread_local_eager_context_state() + # TODO(josh11b): Use current logical device instead of 0 here. + with self.coord.stop_on_exception(), \ + _enter_graph(self._init_graph, self._init_in_eager), \ + _enter_graph(self.graph, self.in_eager, + self._variable_creator_stack), \ + context.device_policy(self.context_device_policy), \ + _MirroredReplicaContext(self.distribution, constant_op.constant( + self.replica_id, dtypes.int32)), \ + ops.device(self.devices[self.replica_id]), \ + ops.name_scope(self._name_scope), \ + variable_scope.variable_scope( + self._var_scope, reuse=self.replica_id > 0), \ + variable_scope.variable_creator_scope(self.variable_creator_fn): + self.main_result = self.main_fn(*self.main_args, **self.main_kwargs) + self.done = True + finally: + self.has_paused.set() + + def record_thread_local_summary_state(self): + """Record the thread local summary state in self.""" + # TODO(slebedev): is this still relevant? the referenced bug is closed. + summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access + self._summary_step = summary_state.step + self._summary_writer = summary_state.writer + self._summary_recording = summary_state.is_recording + self._summary_recording_distribution_strategy = ( + summary_state.is_recording_distribution_strategy) + + def restore_thread_local_summary_state(self): + """Restore thread local summary state from self.""" + # TODO(slebedev): is this still relevant? the referenced bug is closed. + summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access + summary_state.step = self._summary_step + summary_state.writer = self._summary_writer + summary_state.is_recording = self._summary_recording + summary_state.is_recording_distribution_strategy = ( + self._summary_recording_distribution_strategy) + + def record_thread_local_eager_context_state(self): + ctx = context.context() + eager_context_state = ctx._thread_local_data # pylint: disable=protected-access + self._eager_context_op_callbacks = eager_context_state.op_callbacks + # TODO(b/125892694): record other fields in EagerContext. + + def restore_thread_local_eager_context_state(self): + ctx = context.context() + eager_context_state = ctx._thread_local_data # pylint: disable=protected-access + eager_context_state.op_callbacks = self._eager_context_op_callbacks + # TODO(b/125892694): record other fields in EagerContext. + + +class _MirroredReplicaContext(distribute_lib.ReplicaContext): + """ReplicaContext for synchronized replica.""" + + def _merge_call(self, fn, args, kwargs): + """`merge_call()` implementation for synchronized replica. + + This pauses the current replica thread and passes `fn` and its arguments to + the main thread. The main thread will wait until all replicas pause, then + invoke `fn` with grouped arugments. The current replica thread will continue + after `fn` completes. + + See `_call_for_each_replica` for the logic in the main thread. + + Args: + fn: a function that is called in cross replica context with grouped + arguments from each replica. `fn` should returns grouped values. + args: positional arguments to `fn`. + kwargs: keyward arguments to `fn`. + + Returns: + Return value of `fn` for the current replica. + + Raises: + RuntimeError: when merge_call happens in a different graph, e.g. in a + different tf.function, which is not supported now. + _RequestedStop: when stop is requested. + + """ + t = threading.current_thread() + assert isinstance(t, _MirroredReplicaThread) + t.merge_fn = fn + t.merge_args = args + t.merge_kwargs = kwargs + t.captured_name_scope = t.graph.get_name_scope() + # Adding a "/" at end lets us re-enter this scope later. + if t.captured_name_scope: + t.captured_name_scope += "/" + + t.captured_var_scope = variable_scope.get_variable_scope() + t.captured_control_deps = t.graph._current_control_dependencies() # pylint: disable=protected-access + + # It is problematic if `merge_call` is called under a different graph other + # than the one that `_call_for_each_replica` is called under, there are + # 3 cases this can happen: + # + # 1. The `fn` passed to `_call_for_each_replica` is decorated with + # `tf.function` and there is a `merge_call` in `fn`. Since + # MirroredStrategy traces a separate function per thread (per device), + # and each trace takes a shared lock, the lock is never released by the + # first thread and subsequent replica threads cannot proceed to trace + # their own functions. This issue is addressed by always converting + # `_call_for_each_replica(tf.function(f))` to + # ``tf.function(_call_for_each_replica(f))`.` in + # `MirroredStrategy._call_for_each_replica`. + # + # 2. The `fn` passed to `_call_for_each_replica` contains a nested + # `tf.function`, and there is a `merge_call` in the nested `tf.function`. + # In this case each thread can successfully trace its own function, but + # since the `merge_fn` passed to `merge_call` is executed in the main + # thread (where `_call_for_each_replica` is executed), it can't access + # the tensors that come from different graphs. + # + # 3. The `fn` passed to `_call_for_each_replica` contains a control-flow + # statement, and there is a `merge_call` inside the control-flow body, + # `fn` or `_call_for_each_replica` is decorated with `tf.function`. + # Control flow statement creates a separate graph for its body, similar + # to #2, `merge_fn` executed in the main thread can't access the + # tensors that come from different graphs. + # + # We raise an error for #2 and #3. + if ops.get_default_graph() != t.graph: + raise RuntimeError( + "`merge_call` called while defining a new graph or a tf.function." + " This can often happen if the function `fn` passed to" + " `strategy.run()` contains a nested `@tf.function`, and the nested " + "`@tf.function` contains a synchronization point, such as aggregating" + " gradients (e.g, optimizer.apply_gradients), or if the function `fn`" + " uses a control flow statement which contains a synchronization" + " point in the body. Such behaviors are not yet supported. Instead," + " please avoid nested `tf.function`s or control flow statements that" + " may potentially cross a synchronization boundary, for example," + " wrap the `fn` passed to `strategy.run` or the entire `strategy.run`" + " inside a `tf.function` or move the control flow out of `fn`") + + t.has_paused.set() + t.should_run.wait() + t.should_run.clear() + if t.coord.should_stop(): + raise _RequestedStop() + return t.merge_result + + @property + def devices(self): + distribute_lib.require_replica_context(self) + replica_id = tensor_util.constant_value(self._replica_id_in_sync_group) + return [self._strategy.extended.worker_devices_by_replica[replica_id]] diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py index 698bf2c2ce6..de66128ce37 100644 --- a/tensorflow/python/distribute/mirrored_strategy.py +++ b/tensorflow/python/distribute/mirrored_strategy.py @@ -18,191 +18,33 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import contextlib import copy -import functools -import threading -import weakref -from tensorflow.python import pywrap_tfe -from tensorflow.python.autograph.core import ag_ctx as autograph_ctx -from tensorflow.python.autograph.impl import api as autograph from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import input_lib +from tensorflow.python.distribute import mirrored_run from tensorflow.python.distribute import multi_worker_util from tensorflow.python.distribute import numpy_dataset from tensorflow.python.distribute import reduce_util -from tensorflow.python.distribute import shared_variable_creator from tensorflow.python.distribute import values from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver from tensorflow.python.eager import context -from tensorflow.python.eager import def_function from tensorflow.python.eager import tape from tensorflow.python.framework import config from tensorflow.python.framework import constant_op from tensorflow.python.framework import device as tf_device -from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import summary_ops_v2 -from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import coordinator from tensorflow.python.util import nest from tensorflow.python.util.tf_export import tf_export - # TODO(josh11b): Replace asserts in this file with if ...: raise ... -@contextlib.contextmanager -def _enter_graph(g, eager, creator_stack=None): - """Context manager for selecting a graph and maybe eager mode.""" - if eager: - with g.as_default(), context.eager_mode(): - if creator_stack is not None: - g._variable_creator_stack = creator_stack # pylint: disable=protected-access - yield - else: - with g.as_default(): - if creator_stack is not None: - g._variable_creator_stack = creator_stack # pylint: disable=protected-access - yield - - -def _cpu_device(device): - cpu_device = tf_device.DeviceSpec.from_string(device) - cpu_device = cpu_device.replace(device_type="CPU", device_index=0) - return cpu_device.to_string() - - -class _RequestedStop(Exception): # pylint: disable=g-bad-exception-name - pass - - -# _call_for_each_replica is not a member of MirroredStrategy so that it is -# not allowed to use anything specific to MirroredStrategy and thus -# can be shared with other distribution strategies. - - -# TODO(yuefengz): maybe create a common class for those who need to call this -# _call_for_each_replica. -def _call_for_each_replica(distribution, devices, fn, args, kwargs): - """Run `fn` in separate threads, once per replica/worker device. - - Args: - distribution: the DistributionStrategy object. - devices: the devices to run `fn` on (logical device 0 for each replica). - fn: function to run (will be run once per replica, each in its own thread). - args: positional arguments for `fn` - kwargs: keyword arguments for `fn`. - - Returns: - Merged return value of `fn` across all replicas. - - Raises: - RuntimeError: If fn() calls get_replica_context().merge_call() a different - number of times from the available devices. - """ - # TODO(josh11b): Add this option once we add synchronization to variable - # creation. Until then, this is pretty unsafe to use. - run_concurrently = False - if not context.executing_eagerly(): - # Needed for per-thread device, etc. contexts in graph mode. - ops.get_default_graph().switch_to_thread_local() - - coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,)) - - shared_variable_store = {} - - # TODO(isaprykin): Create these threads once instead of during every call. - threads = [] - for index in range(len(devices)): - variable_creator_fn = shared_variable_creator.make_fn( - shared_variable_store, index) - t = _MirroredReplicaThread( - distribution, coord, index, devices, variable_creator_fn, fn, - values.select_replica(index, args), - values.select_replica(index, kwargs)) - threads.append(t) - - for t in threads: - t.start() - - # When `fn` starts `should_run` event is set on _MirroredReplicaThread - # (`MRT`) threads. The execution waits until - # `MRT.has_paused` is set, which indicates that either `fn` is - # complete or a `get_replica_context().merge_call()` is called. If `fn` is - # complete, then `MRT.done` is set to True. Otherwise, arguments - # of `get_replica_context().merge_call` from all paused threads are grouped - # and the `merge_fn` is performed. Results of the - # `get_replica_context().merge_call` are then set to `MRT.merge_result`. - # Each such `get_replica_context().merge_call` call returns the - # `MRT.merge_result` for that thread when `MRT.should_run` event - # is reset again. Execution of `fn` resumes. - - try: - with coord.stop_on_exception(): - all_done = False - while not all_done and not coord.should_stop(): - done = [] - if run_concurrently: - for t in threads: - t.should_run.set() - for t in threads: - t.has_paused.wait() - t.has_paused.clear() - if coord.should_stop(): - return None - done.append(t.done) - else: - for t in threads: - t.should_run.set() - t.has_paused.wait() - t.has_paused.clear() - if coord.should_stop(): - return None - done.append(t.done) - if coord.should_stop(): - return None - all_done = all(done) - if not all_done: - if any(done): - raise RuntimeError("Some replicas made a different number of " - "replica_context().merge_call() calls.") - # get_replica_context().merge_call() case - merge_args = values.regroup(tuple(t.merge_args for t in threads)) - merge_kwargs = values.regroup(tuple(t.merge_kwargs for t in threads)) - # We capture the name_scope of the MRT when we call merge_fn - # to ensure that if we have opened a name scope in the MRT, - # it will be respected when executing the merge function. We only - # capture the name_scope from the first MRT and assume it is - # the same for all other MRTs. - mtt_captured_name_scope = threads[0].captured_name_scope - mtt_captured_var_scope = threads[0].captured_var_scope - # Capture and merge the control dependencies from all the threads. - mtt_captured_control_deps = set() - for t in threads: - mtt_captured_control_deps.update(t.captured_control_deps) - with ops.name_scope(mtt_captured_name_scope),\ - ops.control_dependencies(mtt_captured_control_deps), \ - variable_scope.variable_scope(mtt_captured_var_scope): - merge_result = threads[0].merge_fn(distribution, *merge_args, - **merge_kwargs) - for r, t in enumerate(threads): - t.merge_result = values.select_replica(r, merge_result) - finally: - for t in threads: - t.should_run.set() - coord.join(threads) - - return values.regroup(tuple(t.main_result for t in threads)) - - def _is_device_list_single_worker(devices): """Checks whether the devices list is for single or multi-worker. @@ -469,7 +311,6 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1): "any local devices.") self._cross_device_ops = cross_device_ops self._initialize_strategy(devices) - self._cfer_fn_cache = weakref.WeakKeyDictionary() # TODO(b/128995245): Enable last partial batch support in graph mode. if ops.executing_eagerly_outside_functions(): @@ -739,35 +580,8 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1): return self._get_cross_device_ops().broadcast(tensor, destinations) def _call_for_each_replica(self, fn, args, kwargs): - if isinstance(fn, def_function.Function): - wrapped = self._cfer_fn_cache.get(fn) - if wrapped is None: - # We need to wrap fn such that it triggers _call_for_each_replica inside - # the tf.function. - wrapped = fn._clone( # pylint: disable=protected-access - python_function=functools.partial(self._call_for_each_replica, - fn.python_function)) - self._cfer_fn_cache[fn] = wrapped - return wrapped(args, kwargs) - - if context.executing_eagerly(): - logging.log_first_n( - logging.WARN, "Using %s eagerly has significant " - "overhead currently. We will be working on improving " - "this in the future, but for now please wrap " - "`call_for_each_replica` or `experimental_run` or " - "`run` inside a tf.function to get the best performance." % - self._container_strategy().__class__.__name__, 5) - else: - # When a tf.function is wrapped to trigger _call_for_each_replica (see - # the other branch above), AutoGraph stops conversion at - # _call_for_each_replica itself (TF library functions are whitelisted). - # This makes sure that the Python function that originally passed to - # the tf.function is still converted. - fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx()) - - return _call_for_each_replica(self._container_strategy(), self._devices, - fn, args, kwargs) + return mirrored_run.call_for_each_replica(self._container_strategy(), fn, + args, kwargs) def _configure(self, session_config=None, @@ -912,203 +726,3 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1): def _in_multi_worker_mode(self): """Whether this strategy indicates working in multi-worker settings.""" return False - - -class _MirroredReplicaThread(threading.Thread): - """A thread that runs() a function on a device.""" - - def __init__(self, dist, coord, replica_id, devices, variable_creator_fn, - fn, args, kwargs): - super(_MirroredReplicaThread, self).__init__() - self.coord = coord - self.distribution = dist - self.devices = devices - self.replica_id = replica_id - self.variable_creator_fn = variable_creator_fn - # State needed to run and return the results of `fn`. - self.main_fn = fn - self.main_args = args - self.main_kwargs = kwargs - self.main_result = None - self.done = False - # State needed to run the next merge_call() (if any) requested via - # ReplicaContext. - self.merge_fn = None - self.merge_args = None - self.merge_kwargs = None - self.merge_result = None - self.captured_name_scope = None - self.captured_var_scope = None - # We use a thread.Event for the main thread to signal when this - # thread should start running (`should_run`), and another for - # this thread to transfer control back to the main thread - # (`has_paused`, either when it gets to a - # `get_replica_context().merge_call` or when `fn` returns). In - # either case the event starts cleared, is signaled by calling - # set(). The receiving thread waits for the signal by calling - # wait() and then immediately clearing the event using clear(). - self.should_run = threading.Event() - self.has_paused = threading.Event() - # These fields have to do with inheriting various contexts from the - # parent thread: - context.ensure_initialized() - ctx = context.context() - self.in_eager = ctx.executing_eagerly() - self.record_thread_local_summary_state() - self.record_thread_local_eager_context_state() - self.context_device_policy = ( - pywrap_tfe.TFE_ContextGetDevicePlacementPolicy( - ctx._context_handle)) # pylint: disable=protected-access - self.graph = ops.get_default_graph() - with ops.init_scope(): - self._init_in_eager = context.executing_eagerly() - self._init_graph = ops.get_default_graph() - self._variable_creator_stack = self.graph._variable_creator_stack[:] # pylint: disable=protected-access - self._var_scope = variable_scope.get_variable_scope() - # Adding a "/" at end lets us re-enter this scope later. - self._name_scope = self.graph.get_name_scope() - if self._name_scope: - self._name_scope += "/" - if self.replica_id > 0: - if not self._name_scope: - self._name_scope = "" - self._name_scope += "replica_%d/" % self.replica_id - - def run(self): - self.should_run.wait() - self.should_run.clear() - try: - if self.coord.should_stop(): - return - self.restore_thread_local_summary_state() - self.restore_thread_local_eager_context_state() - # TODO(josh11b): Use current logical device instead of 0 here. - with self.coord.stop_on_exception(), \ - _enter_graph(self._init_graph, self._init_in_eager), \ - _enter_graph(self.graph, self.in_eager, - self._variable_creator_stack), \ - context.device_policy(self.context_device_policy), \ - MirroredReplicaContext(self.distribution, constant_op.constant( - self.replica_id, dtypes.int32)), \ - ops.device(self.devices[self.replica_id]), \ - ops.name_scope(self._name_scope), \ - variable_scope.variable_scope( - self._var_scope, reuse=self.replica_id > 0), \ - variable_scope.variable_creator_scope(self.variable_creator_fn): - self.main_result = self.main_fn(*self.main_args, **self.main_kwargs) - self.done = True - finally: - self.has_paused.set() - - def record_thread_local_summary_state(self): - """Record the thread local summary state in self.""" - # TODO(slebedev): is this still relevant? the referenced bug is closed. - summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access - self._summary_step = summary_state.step - self._summary_writer = summary_state.writer - self._summary_recording = summary_state.is_recording - self._summary_recording_distribution_strategy = ( - summary_state.is_recording_distribution_strategy) - - def restore_thread_local_summary_state(self): - """Restore thread local summary state from self.""" - # TODO(slebedev): is this still relevant? the referenced bug is closed. - summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access - summary_state.step = self._summary_step - summary_state.writer = self._summary_writer - summary_state.is_recording = self._summary_recording - summary_state.is_recording_distribution_strategy = ( - self._summary_recording_distribution_strategy) - - def record_thread_local_eager_context_state(self): - ctx = context.context() - eager_context_state = ctx._thread_local_data # pylint: disable=protected-access - self._eager_context_op_callbacks = eager_context_state.op_callbacks - # TODO(b/125892694): record other fields in EagerContext. - - def restore_thread_local_eager_context_state(self): - ctx = context.context() - eager_context_state = ctx._thread_local_data # pylint: disable=protected-access - eager_context_state.op_callbacks = self._eager_context_op_callbacks - # TODO(b/125892694): record other fields in EagerContext. - - -class MirroredReplicaContext(distribute_lib.ReplicaContext): - """ReplicaContext used in MirroredStrategy.extended.call_for_each_replica(). - - Opened in `_MirroredReplicaThread`, to allow the user to invoke - `MirroredStrategy`'s specific implementation of `merge_call()`, - which works by delegating the function and its arguments to - the main thread (the one that invoked - `MirroredStrategy.extended.call_for_each_replica()`). - """ - - def _merge_call(self, fn, args, kwargs): - """Delegate to the main thread to actually perform merge_call().""" - t = threading.current_thread() # a _MirroredReplicaThread - t.merge_fn = fn - t.merge_args = args - t.merge_kwargs = kwargs - t.captured_name_scope = t.graph.get_name_scope() - # Adding a "/" at end lets us re-enter this scope later. - if t.captured_name_scope: - t.captured_name_scope += "/" - - t.captured_var_scope = variable_scope.get_variable_scope() - t.captured_control_deps = t.graph._current_control_dependencies() # pylint: disable=protected-access - - # It is problematic if `merge_call` is called under a different graph other - # than the one that `_call_for_each_replica` is called under, there are - # 3 cases this can happen: - # - # 1. The `fn` passed to `_call_for_each_replica` is decorated with - # `tf.function` and there is a `merge_call` in `fn`. Since - # MirroredStrategy traces a separate function per thread (per device), - # and each trace takes a shared lock, the lock is never released by the - # first thread and subsequent replica threads cannot proceed to trace - # their own functions. This issue is addressed by always converting - # `_call_for_each_replica(tf.function(f))` to - # ``tf.function(_call_for_each_replica(f))`.` in - # `MirroredStrategy._call_for_each_replica`. - # - # 2. The `fn` passed to `_call_for_each_replica` contains a nested - # `tf.function`, and there is a `merge_call` in the nested `tf.function`. - # In this case each thread can successfully trace its own function, but - # since the `merge_fn` passed to `merge_call` is executed in the main - # thread (where `_call_for_each_replica` is executed), it can't access - # the tensors that come from different graphs. - # - # 3. The `fn` passed to `_call_for_each_replica` contains a control-flow - # statement, and there is a `merge_call` inside the control-flow body, - # `fn` or `_call_for_each_replica` is decorated with `tf.function`. - # Control flow statement creates a separate graph for its body, similar - # to #2, `merge_fn` executed in the main thread can't access the - # tensors that come from different graphs. - # - # We raise an error for #2 and #3. - if ops.get_default_graph() != t.graph: - raise RuntimeError( - "`merge_call` called while defining a new graph or a tf.function." - " This can often happen if the function `fn` passed to" - " `strategy.run()` contains a nested `@tf.function`, and the nested " - "`@tf.function` contains a synchronization point, such as aggregating" - " gradients (e.g, optimizer.apply_gradients), or if the function `fn`" - " uses a control flow statement which contains a synchronization" - " point in the body. Such behaviors are not yet supported. Instead," - " please avoid nested `tf.function`s or control flow statements that" - " may potentially cross a synchronization boundary, for example," - " wrap the `fn` passed to `strategy.run` or the entire `strategy.run`" - " inside a `tf.function` or move the control flow out of `fn`") - - t.has_paused.set() - t.should_run.wait() - t.should_run.clear() - if t.coord.should_stop(): - raise _RequestedStop() - return t.merge_result - - @property - def devices(self): - distribute_lib.require_replica_context(self) - replica_id = tensor_util.constant_value(self._replica_id_in_sync_group) - return [self._strategy.extended.worker_devices_by_replica[replica_id]] diff --git a/tensorflow/python/distribute/parameter_server_strategy.py b/tensorflow/python/distribute/parameter_server_strategy.py index 7099e4a6390..e1f0f41b393 100644 --- a/tensorflow/python/distribute/parameter_server_strategy.py +++ b/tensorflow/python/distribute/parameter_server_strategy.py @@ -25,7 +25,7 @@ from tensorflow.python.distribute import cross_device_ops as cross_device_ops_li from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import input_lib -from tensorflow.python.distribute import mirrored_strategy +from tensorflow.python.distribute import mirrored_run from tensorflow.python.distribute import multi_worker_util from tensorflow.python.distribute import numpy_dataset from tensorflow.python.distribute import values @@ -456,9 +456,8 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1): return var_creator(**kwargs) def _call_for_each_replica(self, fn, args, kwargs): - # pylint: disable=protected-access - return mirrored_strategy._call_for_each_replica( - self._container_strategy(), self._compute_devices, fn, args, kwargs) + return mirrored_run.call_for_each_replica(self._container_strategy(), fn, + args, kwargs) def _verify_destinations_not_different_worker(self, destinations): if not self._cluster_spec: diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py index 3c095469927..6e51b84a1d1 100644 --- a/tensorflow/python/distribute/tpu_strategy.py +++ b/tensorflow/python/distribute/tpu_strategy.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import atexit import collections import contextlib import copy @@ -327,6 +328,11 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): self._logical_device_stack = [0] + if context.executing_eagerly(): + # In async remote eager, we want to sync the exectors before exiting the + # program. + atexit.register(context.async_wait) + # TODO(bfontain): Remove once a proper dataset API exists for prefetching # a dataset to multiple devices exists. # If value is true, this forces prefetch of data to the host's memeory rather diff --git a/tensorflow/python/distribute/values_test.py b/tensorflow/python/distribute/values_test.py index c3c3e0d5286..96089a67b00 100644 --- a/tensorflow/python/distribute/values_test.py +++ b/tensorflow/python/distribute/values_test.py @@ -1841,7 +1841,6 @@ class AggregatingVariableTest(test.TestCase, parameterized.TestCase): self.assertEqual(self.evaluate(aggregating._v.read_value()), 3.) def testAssignAdd(self, distribution): - self.skipTest("b/151250566") with distribution.scope(): v = variable_scope.variable( 1, aggregation=variables_lib.VariableAggregation.MEAN) diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py index 1d4b5a59d01..31e821068d3 100644 --- a/tensorflow/python/eager/benchmarks_test.py +++ b/tensorflow/python/eager/benchmarks_test.py @@ -241,7 +241,7 @@ class MicroBenchmarks(test.Benchmark): def _benchmark_add(self, a, b): def func(): - return memoryview(math_ops.add(a, b)) + return memoryview(math_ops.add_v2(a, b)) with ops.device("GPU:0" if context.num_gpus() else "CPU:0"): for _ in range(1000): diff --git a/tensorflow/python/eager/def_function_test.py b/tensorflow/python/eager/def_function_test.py index 998350f695d..f2b8ad8be7d 100644 --- a/tensorflow/python/eager/def_function_test.py +++ b/tensorflow/python/eager/def_function_test.py @@ -26,7 +26,6 @@ from absl.testing import parameterized from six.moves import range from tensorflow.python.autograph.core import converter -from tensorflow.python.eager import backprop from tensorflow.python.eager import def_function from tensorflow.python.eager import lift_to_graph from tensorflow.python.framework import constant_op @@ -35,8 +34,6 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util -from tensorflow.python.keras.engine import training -from tensorflow.python.keras.layers import core from tensorflow.python.module import module from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -46,26 +43,6 @@ from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test -from tensorflow.python.training import adam - - -class _ModelWithOptimizer(training.Model): - - def __init__(self): - super(_ModelWithOptimizer, self).__init__() - self.dense = core.Dense(1) - self.optimizer = adam.AdamOptimizer(0.01) - - @def_function.function( - input_signature=(tensor_spec.TensorSpec([None, 2], dtypes.float32), - tensor_spec.TensorSpec([None], dtypes.float32))) - def call(self, x, y): - with backprop.GradientTape() as tape: - loss = math_ops.reduce_mean((self.dense(x) - y) ** 2.) - trainable_variables = self.trainable_variables - gradients = tape.gradient(loss, trainable_variables) - self.optimizer.apply_gradients(zip(gradients, trainable_variables)) - return {'loss': loss} class _HasDecoratedMethod(object): @@ -74,6 +51,7 @@ class _HasDecoratedMethod(object): def f(self, x): return x * 3. + class DefFunctionTest(test.TestCase, parameterized.TestCase): def testNoVariables(self): @@ -311,12 +289,6 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): input_signature=[tensor_spec.TensorSpec((), dtypes.int32)]) self.assertEqual(3, wrapped(constant_op.constant(1)).numpy()) - def test_optimizer(self): - x = constant_op.constant([[3., 4.]]) - y = constant_op.constant([2.]) - model = _ModelWithOptimizer() - model(x, y) - def test_concrete_function_from_signature(self): @def_function.function( diff --git a/tensorflow/python/eager/forwardprop_test.py b/tensorflow/python/eager/forwardprop_test.py index 1dc580549ce..aad179ffb6b 100644 --- a/tensorflow/python/eager/forwardprop_test.py +++ b/tensorflow/python/eager/forwardprop_test.py @@ -35,9 +35,6 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util -from tensorflow.python.keras.layers import convolutional -from tensorflow.python.keras.layers import core -from tensorflow.python.keras.layers import normalization_v2 from tensorflow.python.module import module from tensorflow.python.ops import array_ops from tensorflow.python.ops import custom_gradient @@ -384,96 +381,6 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase): def testElementwiseNNOps(self, value, op_fn): _test_gradients(self, op_fn, [constant_op.constant(value)], order=3) - @parameterized.named_parameters( - [("Dense", [[0.1]], functools.partial(core.Dense, 5)), - ("Conv2D", - np.reshape(np.arange(start=-1., stop=1., step=2. / (1 * 2 * 4 * 4)), - [1, 2, 4, 4]), - functools.partial(convolutional.Conv2D, 2, 2), 1e-3)]) - def testKerasLayers(self, value, op_fn, atol=1e-6): - layer = op_fn() - input_value = constant_op.constant(value, dtype=dtypes.float32) - layer.build(input_value.shape) - # Make sure the test is deterministic by avoiding random variable - # initialization. - for v in layer.trainable_variables: - v.assign(array_ops.reshape( - math_ops.range( - -1., 1., 2. / array_ops.size(v, out_type=dtypes.float32), - dtype=dtypes.float32), - v.shape)) - _test_gradients( - self, layer, [input_value], atol=atol, - # These are linear, so second-order is pretty boring. - order=2) - - @parameterized.named_parameters( - [("NonFused", [[0.1], [0.2], [-0.3]], - functools.partial(normalization_v2.BatchNormalization, fused=False)), - ("Fused", [[[[0.1, 2.]]], [[[0.2, -3.]]], [[[-0.3, 4.]]]], - functools.partial(normalization_v2.BatchNormalization, fused=True))]) - def testBatchNorm(self, value, op_fn): - for training in [True, False]: - layer = op_fn() - input_value = constant_op.constant(value, dtype=dtypes.float32) - layer.build(input_value.shape) - _test_gradients( - self, functools.partial(layer, training=training), [input_value], - order=2, atol=1e-3) - - @parameterized.named_parameters( - [("NonFused", [[0.1], [0.2], [-0.3]], - functools.partial(normalization_v2.BatchNormalization, fused=False)), - ("Fused", [[[[0.1, 2.]]], [[[0.2, -3.]]], [[[-0.3, 4.]]]], - functools.partial(normalization_v2.BatchNormalization, fused=True))]) - def testBatchNormLayerParamGrads(self, value, op_fn): - for training in [True, False]: - layer = op_fn() - with backprop.GradientTape() as tape: - input_value = constant_op.constant(value, dtype=dtypes.float32) - tape.watch(input_value) - output = layer(input_value, training=training) - jac_back = tape.jacobian( - output, [input_value] + layer.trainable_variables) - jac_forward = _jacfwd( - lambda *args: layer(args[0], training=training), # pylint:disable=cell-var-from-loop - [input_value] + layer.trainable_variables) - for backward, forward in zip(jac_back, jac_forward): - forward = array_ops.reshape(forward, array_ops.shape(backward)) - self.assertAllClose(backward, forward) - - @parameterized.named_parameters( - [("NCHW", "NCHW", [4, 3, 2, 2], 3, False), - ("NHWC", "NHWC", [4, 2, 2, 3], 3, False), - ("NCHWForward", "NCHW", [2, 2, 1, 1], 2, True), - ("NHWCForward", "NHWC", [2, 1, 1, 2], 2, True),]) - def testFusedBatchNormGradsTraining(self, data_format, x_shape, channels, - test_back_over_forward): - increment = 3. / math_ops.reduce_prod( - constant_op.constant(x_shape, dtype=dtypes.float32)) - x = array_ops.reshape(math_ops.range(-2., 1., increment), x_shape) - scale = constant_op.constant([1., 1.1, 0.9])[:channels] - offset = constant_op.constant([-0.5, -0.6, -0.7])[:channels] - epsilon = 0.001 - - def _bn_fused(x_arg, scale_arg, offset_arg): - return nn_impl.fused_batch_norm(x_arg, scale_arg, offset_arg, - epsilon=epsilon, is_training=True, - data_format=data_format)[0] - _test_gradients(self, _bn_fused, [x, scale, offset], order=2, atol=1e-3) - if test_back_over_forward: - # Note that this uses a loop over parameters, and so is quite slow. Thus - # it's skipped for the larger test cases. - gradfwd_x = _gradfwd(_bn_fused, 0) - _test_gradients(self, gradfwd_x, [x, scale, offset], order=1, - atol=1e-3) - gradfwd_scale = _gradfwd(_bn_fused, 1) - _test_gradients(self, gradfwd_scale, [x, scale, offset], order=1, - atol=1e-3) - gradfwd_offset = _gradfwd(_bn_fused, 2) - _test_gradients(self, gradfwd_offset, [x, scale, offset], order=1, - atol=1e-3) - def testFusedBatchNormGradsInference(self): if test.is_built_with_rocm(): @@ -499,55 +406,6 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase): _test_gradients(self, _bn_fused, [x, scale, offset], order=2, atol=1e-2) - @parameterized.named_parameters( - [("Function", def_function.function), - ("NoFunction", lambda f: f)]) - def testVariablesHVP(self, decorator): - - if test.is_built_with_rocm(): - # TODO(rocm) - # This test was recently added and has never passed on the - # ROCm platform. Remove this skip once the test is passing again - self.skipTest("NoFunction decorator test fails on the ROCm platform") - - class _Model(module.Module): - - def __init__(self): - self._first_dense = core.Dense(18) - self._conv = convolutional.Conv2D(2, 2) - self._norm = normalization_v2.BatchNormalization() - self._second_dense = core.Dense(1) - - def __call__(self, x): - x = self._first_dense(x) - x = nn_ops.relu(x) - x = self._norm(x) - x = nn_ops.relu(self._conv(array_ops.reshape(x, [-1, 2, 3, 3]))) - return self._second_dense(x) - - model = _Model() - def _loss(): - input_value = constant_op.constant([[-0.5, 1.], [0.5, -1.]]) - target = constant_op.constant([[-1.], [2.]]) - return math_ops.reduce_sum((model(input_value) - target) ** 2.) - - @decorator - def _compute_hvps(): - with backprop.GradientTape() as tape: - loss = _loss() - vector = tape.gradient(loss, model.trainable_variables) - variable_input_fn = lambda unused_variables: _loss() - forward_over_back_hvp, = _hvp( - variable_input_fn, [model.trainable_variables], [vector]) - with backprop.GradientTape(persistent=True) as tape: - tape.watch(model.trainable_variables) - loss = _loss() - first_grads = tape.gradient(loss, model.trainable_variables) - back_over_back_hvp = tape.gradient( - first_grads, model.trainable_variables, output_gradients=vector) - return forward_over_back_hvp, back_over_back_hvp - self.assertAllClose(*_compute_hvps(), rtol=1e-5, atol=1e-5) - def testPushPopAccumulatorState(self): # Note that this example is somewhat contrived. push_forwardprop_state is # probably only useful in practice for building functions that compute jvps @@ -1069,29 +927,6 @@ class HessianTests(test.TestCase, parameterized.TestCase): use_pfor=True, dtype=[dtypes.float32]) self.assertAllClose(hess_value, hessian_pfor) - @parameterized.named_parameters( - [("PFor", True), - ("MapFn", False)]) - def testHessianOfVariables(self, use_pfor): - model = core.Dense(1) - model.build([None, 2]) - - def _loss(*unused_args): - input_value = constant_op.constant([[-0.5, 1.], [0.5, -1.]]) - target = constant_op.constant([[-1.], [2.]]) - return math_ops.reduce_sum((model(input_value) - target) ** 2.) - - kernel_hess, bias_hess = _forward_over_back_hessian( - _loss, [model.kernel, model.bias], use_pfor=use_pfor, - dtype=[dtypes.float32, dtypes.float32]) - # 3 total parameters, the whole hessian is the 3x3 concatenation - self.assertEqual([3, 2, 1], kernel_hess.shape) - self.assertEqual([3, 1], bias_hess.shape) - full_hessian = array_ops.concat( - [array_ops.reshape(kernel_hess, [3, 2]), bias_hess], axis=1) - # The full Hessian should be symmetric. - self.assertAllClose(full_hessian, array_ops.transpose(full_hessian)) - if __name__ == "__main__": # TODO(allenl): Also test with 1.x-style graph mode. diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index c16060422b8..d7dd53d038a 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -2494,9 +2494,6 @@ class Function(object): args, kwargs = None, None with self._lock: graph_function, args, kwargs = self._maybe_define_function(args, kwargs) - if self.input_signature: - args = self.input_signature - kwargs = {} seen_names = set() captured = object_identity.ObjectIdentitySet( graph_function.graph.internal_captures) diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index d98c83665b9..edfb6b0a347 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -30,9 +30,7 @@ import numpy from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 -from tensorflow.python import keras from tensorflow.python.autograph.core import ag_ctx -from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import backprop from tensorflow.python.eager import cancellation from tensorflow.python.eager import context @@ -52,9 +50,6 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_ops from tensorflow.python.framework import test_util -from tensorflow.python.keras.engine import training as keras_training -from tensorflow.python.keras.layers import core -from tensorflow.python.keras.optimizer_v2 import adam from tensorflow.python.layers import convolutional from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops @@ -68,7 +63,6 @@ from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops from tensorflow.python.ops import list_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope @@ -94,28 +88,6 @@ def total_function_cache(defined): # pylint: enable=protected-access -class MiniModel(keras_training.Model): - """Minimal model for mnist. - - Useful for testing and debugging on slow TPU simulators. - """ - - def __init__(self): - super(MiniModel, self).__init__(name='') - self.fc = keras.layers.Dense(1, name='fc', kernel_initializer='ones', - bias_initializer='ones') - - def call(self, inputs, training=True): - return self.fc(inputs) - - -class DefunnedMiniModel(MiniModel): - - @function.defun - def call(self, inputs, training=True): - return super(DefunnedMiniModel, self).call(inputs, training=training) - - def _example_indexed_slices_with_dense_shape(): return indexed_slices.IndexedSlices( constant_op.constant([1, 2]), constant_op.constant([0, 1]), @@ -439,26 +411,6 @@ class FunctionTest(test.TestCase, parameterized.TestCase): self.assertTrue(unknown_dim[0]) self.assertLen(total_function_cache(func), 3) - def testFunctionRelaxationLosesInnerDimWithKerasLayer(self): - layer = keras.layers.Dense(1) - fn = def_function.function(experimental_relax_shapes=True)(layer) - - with self.captureWritesToStream(sys.stderr) as printed: - fn(array_ops.ones((3, 2))) - self.assertNotIn('ValueError', printed.contents()) - with self.captureWritesToStream(sys.stderr) as printed: - # Use batch size 2 to trigger a second cache miss on the shape. - fn(array_ops.ones((2, 2))) - self.assertNotIn('ValueError', printed.contents()) - - # Shape relaxation passes TensorShape([None, None]), which causes layer - # matmul to fail, due to incompatible dims. What would have been a graph - # build time error (layer would complain about the inner dim being 4). - with self.captureWritesToStream(sys.stderr) as printed: - with self.assertRaisesRegexp(errors.InvalidArgumentError, - r'Matrix size-incompatible'): - fn(array_ops.ones((3, 4))) - def testNestedShapeFunctionRelaxation(self): got_shape = [None] @@ -1513,24 +1465,6 @@ class FunctionTest(test.TestCase, parameterized.TestCase): has_device.f() self.assertIn('CPU', has_device.v.device) - @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) - def testDefunKerasModelCall(self): - model = MiniModel() - model.call = function.defun(model.call) - - x = array_ops.ones([1, 2]) - y = model(x) - - if not context.executing_eagerly(): - self.evaluate(variables.global_variables_initializer()) - - self.assertAllEqual([[3.0]], self.evaluate(y)) - - # Break the reference cycle between the MiniModel and the defun: - # `MiniModel` --(through its `call` method)--> `Function` - # `Function` --(instancemethod on `MiniModel`)--> `MiniModel` - del model.call - @test_util.run_in_graph_and_eager_modes def testDeviceAnnotationsRespected(self): @@ -2712,54 +2646,18 @@ class FunctionTest(test.TestCase, parameterized.TestCase): self.assertLen(total_function_cache(defined), 3 if ops.Tensor._USE_EQUALITY else 5) - def testDecoratedMethod(self): - m = DefunnedMiniModel() - instance_call_one = m.call(array_ops.ones([1, 2]), training=True) - instance_call_two = m.call( - inputs=array_ops.ones([1, 2]), training=True) - class_call = DefunnedMiniModel.call(m, array_ops.ones([1, 2]), - training=True) - self.assertAllEqual(instance_call_one, instance_call_two) - self.assertAllEqual(instance_call_one, class_call) - - def testDecoratedMethodUniqueFunctionPerInstance(self): - m = DefunnedMiniModel() - n = DefunnedMiniModel() - - class_method_one = DefunnedMiniModel.call - class_method_two = DefunnedMiniModel.call - - m_method_one = m.call - m_method_two = m.call - - n_method_one = n.call - n_method_two = n.call - - self.assertEqual(class_method_one, class_method_two) - self.assertEqual(m_method_one, m_method_two) - self.assertEqual(n_method_one, n_method_two) - self.assertNotEqual(m.call, n.call) - def testDecoratedMethodInspect(self): + + class DefunnedMiniModel(object): + + @function.defun + def call(self, inputs, training=True): + pass + m = DefunnedMiniModel() fullargspec = tf_inspect.getfullargspec(m.call) self.assertIn('training', fullargspec.args) - def testDecoratedMethodGetConcreteFunction(self): - m = DefunnedMiniModel() - instance_call_one = m.call.get_concrete_function( - array_ops.ones([1, 2]), training=False) - instance_call_two = m.call.get_concrete_function( - inputs=array_ops.ones([1, 2]), training=False) - self.assertAllEqual(instance_call_one(array_ops.ones([1, 2])), - instance_call_two(array_ops.ones([1, 2]))) - - # Also make sure get_concrete_function works on the class method - DefunnedMiniModel.call.get_concrete_function( - m, array_ops.ones([1, 2]), training=False) - DefunnedMiniModel.call.get_concrete_function( - m, inputs=array_ops.ones([1, 2]), training=True) - def testFunctionModifiesInputList(self): # Tests on `list` methods that do in place modification, except `list.sort` # since it cannot even be "defunned" in the first place @@ -2915,21 +2813,6 @@ class FunctionTest(test.TestCase, parameterized.TestCase): modify_same_flat(nested_input) - def testDecoratedMethodVariableCleanup(self): - m = DefunnedMiniModel() - m(array_ops.ones([1, 2])) - variable_refs = list({v.ref() for v in m.variables}) - self.assertLen(variable_refs, 2) - del m - - # Verifying if the variables are only referenced from variable_refs. - # We expect the reference counter to be 1, but `sys.getrefcount` reports - # one higher reference counter because a temporary is created when we call - # sys.getrefcount(). Hence check if the number returned is 2. - # https://docs.python.org/3/library/sys.html#sys.getrefcount - self.assertEqual(sys.getrefcount(variable_refs[0].deref()), 2) - self.assertEqual(sys.getrefcount(variable_refs[1].deref()), 2) - def testExecutorType(self): @function.defun def add_five(x): @@ -3592,56 +3475,6 @@ class MultiDeviceTest(test.TestCase, parameterized.TestCase): self.assertEqual((v,), tape.watched_variables()) - def testStandardTrainingLoopInFunction(self): - layer = core.Dense(2) - dataset = ( - dataset_ops.DatasetV2.from_tensors( - (array_ops.ones([784]), array_ops.ones([], dtypes.int32))) - .map(lambda x, y: (x, y)) - .repeat(10) - .batch(32)) - optimizer = adam.Adam() - - @def_function.function - def train(): - for x, y in dataset: - with backprop.GradientTape() as tape: - out = layer(x) - loss = math_ops.reduce_mean( - nn_ops.sparse_softmax_cross_entropy_with_logits( - logits=out, labels=y)) - layer_variables = layer.trainable_variables - gradients = tape.gradient(loss, layer_variables) - optimizer.apply_gradients(zip(gradients, layer_variables)) - - train() - - def testEarlyStoppingTrainingLoopInFunction(self): - layer = core.Dense(2) - dataset = ( - dataset_ops.DatasetV2.from_tensors( - (array_ops.ones([784]), array_ops.ones([], dtypes.int32))) - .map(lambda x, y: (x, y)) - .repeat(10) - .batch(32)) - optimizer = adam.Adam() - - @def_function.function - def train(): - for x, y in dataset: - with backprop.GradientTape() as tape: - out = layer(x) - loss = math_ops.reduce_mean( - nn_ops.sparse_softmax_cross_entropy_with_logits( - logits=out, labels=y)) - layer_variables = layer.trainable_variables - gradients = tape.gradient(loss, layer_variables) - optimizer.apply_gradients(zip(gradients, layer_variables)) - if optimizer.iterations > 3: - break - - train() - def testDeferredCapture(self): value = 1.0 diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc index f8e1fb568ac..7dd7eb53fb1 100644 --- a/tensorflow/python/eager/pywrap_tensor.cc +++ b/tensorflow/python/eager/pywrap_tensor.cc @@ -292,15 +292,6 @@ TFE_TensorHandle* ConvertToEagerTensorUncached(TFE_Context* ctx, } } - // We always enable implicit mirroring for constants. Without this, code - // written previously under the assumption that - // - // with tf.device('GPU:0'): x = tf.constant(1.0) - // - // will be placed in the GPU will suffer a non-trivial performance regression - // (measured at ~20% for certain benchmarks). - handle->handle->EnableImplicitMirroring(); - return handle.release(); } diff --git a/tensorflow/python/eager/pywrap_tfe_test.py b/tensorflow/python/eager/pywrap_tfe_test.py index c2389025a25..9539b952617 100644 --- a/tensorflow/python/eager/pywrap_tfe_test.py +++ b/tensorflow/python/eager/pywrap_tfe_test.py @@ -18,6 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import sys +import traceback + import numpy as np from tensorflow.python import pywrap_tfe @@ -28,6 +31,7 @@ from tensorflow.python.eager import def_function from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_ops from tensorflow.python.framework import test_util @@ -332,6 +336,18 @@ class Tests(test.TestCase): # TODO(b/147828820): Converting with tensors should work. # _ = ops.EagerTensor([[t]], device=ctx.device_name, dtype=None) + def testFallbackErrorNotVisibleWhenFallbackMethodRaises(self): + ctx = context.context() + ctx.ensure_initialized() + + try: + math_ops.mat_mul([[1., 1.] * 2], [[1., 1.] * 3]) + except errors.InvalidArgumentError: + etype, value, tb = sys.exc_info() + full_exception_text = " ".join( + traceback.format_exception(etype, value, tb)) + + self.assertNotRegex(full_exception_text, "_FallbackException") if __name__ == "__main__": test.main() diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py index 93a13f4e3ce..f53825381d9 100644 --- a/tensorflow/python/feature_column/feature_column.py +++ b/tensorflow/python/feature_column/feature_column.py @@ -1969,7 +1969,8 @@ class _CategoricalColumn(_FeatureColumn): WARNING: Do not subclass this layer unless you know what you are doing: the API is subject to future changes. - A categorical feature typically handled with a `tf.SparseTensor` of IDs. + A categorical feature typically handled with a `tf.sparse.SparseTensor` of + IDs. """ IdWeightPair = collections.namedtuple( # pylint: disable=invalid-name diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py index 4003b8e1093..f981909aef1 100644 --- a/tensorflow/python/feature_column/feature_column_v2.py +++ b/tensorflow/python/feature_column/feature_column_v2.py @@ -152,7 +152,6 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import embedding_ops -from tensorflow.python.ops import init_ops from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops @@ -543,7 +542,7 @@ class _LinearModelLayer(Layer): name='weights', dtype=dtypes.float32, shape=(first_dim, self._units), - initializer=init_ops.zeros_initializer(), + initializer=initializers.zeros(), trainable=self.trainable) # Create a bias variable. @@ -551,7 +550,7 @@ class _LinearModelLayer(Layer): name='bias_weights', dtype=dtypes.float32, shape=[self._units], - initializer=init_ops.zeros_initializer(), + initializer=initializers.zeros(), trainable=self.trainable, use_resource=True, # TODO(rohanj): Get rid of this hack once we have a mechanism for @@ -962,7 +961,7 @@ def embedding_column(categorical_column, 'Embedding of column_name: {}'.format( categorical_column.name)) if initializer is None: - initializer = init_ops.truncated_normal_initializer( + initializer = initializers.truncated_normal( mean=0.0, stddev=1 / math.sqrt(dimension)) return EmbeddingColumn( @@ -1104,7 +1103,7 @@ def shared_embedding_columns(categorical_columns, if (initializer is not None) and (not callable(initializer)): raise ValueError('initializer must be callable if specified.') if initializer is None: - initializer = init_ops.truncated_normal_initializer( + initializer = initializers.truncated_normal( mean=0.0, stddev=1. / math.sqrt(dimension)) # Sort the columns so the default collection name is deterministic even if the @@ -1287,7 +1286,7 @@ def shared_embedding_columns_v2(categorical_columns, if (initializer is not None) and (not callable(initializer)): raise ValueError('initializer must be callable if specified.') if initializer is None: - initializer = init_ops.truncated_normal_initializer( + initializer = initializers.truncated_normal( mean=0.0, stddev=1. / math.sqrt(dimension)) # Sort the columns so the default collection name is deterministic even if the @@ -2515,7 +2514,8 @@ def _create_dense_column_weighted_sum(column, transformation_cache, class CategoricalColumn(FeatureColumn): """Represents a categorical feature. - A categorical feature typically handled with a `tf.SparseTensor` of IDs. + A categorical feature typically handled with a `tf.sparse.SparseTensor` of + IDs. """ IdWeightPair = collections.namedtuple( # pylint: disable=invalid-name diff --git a/tensorflow/python/feature_column/feature_column_v2_test.py b/tensorflow/python/feature_column/feature_column_v2_test.py index 8c9aee722a6..fe769850fb0 100644 --- a/tensorflow/python/feature_column/feature_column_v2_test.py +++ b/tensorflow/python/feature_column/feature_column_v2_test.py @@ -41,8 +41,8 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util +from tensorflow.python.keras import initializers from tensorflow.python.ops import array_ops -from tensorflow.python.ops import init_ops from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import partitioned_variables @@ -6662,6 +6662,9 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase): self.assertEqual([categorical_column], embedding_column.parents) config = embedding_column.get_config() + # initializer config contains `dtype` in v1. + initializer_config = initializers.serialize(initializers.truncated_normal( + mean=0.0, stddev=1 / np.sqrt(2))) self.assertEqual( { 'categorical_column': { @@ -6675,24 +6678,15 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase): 'ckpt_to_load_from': None, 'combiner': 'mean', 'dimension': 2, - 'initializer': { - 'class_name': 'TruncatedNormal', - 'config': { - 'dtype': 'float32', - 'stddev': 0.7071067811865475, - 'seed': None, - 'mean': 0.0 - } - }, + 'initializer': initializer_config, 'max_norm': None, 'tensor_name_in_ckpt': None, 'trainable': True, 'use_safe_embedding_lookup': True }, config) - custom_objects = {'TruncatedNormal': init_ops.TruncatedNormal} new_embedding_column = fc.EmbeddingColumn.from_config( - config, custom_objects=custom_objects) + config, custom_objects=None) self.assertEqual(embedding_column.get_config(), new_embedding_column.get_config()) self.assertIsNot(categorical_column, @@ -6700,7 +6694,7 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase): new_embedding_column = fc.EmbeddingColumn.from_config( config, - custom_objects=custom_objects, + custom_objects=None, columns_by_name={ serialization._column_name_with_class_name(categorical_column): categorical_column diff --git a/tensorflow/python/framework/indexed_slices.py b/tensorflow/python/framework/indexed_slices.py index c1b3a1775ec..8e9c6f63dca 100644 --- a/tensorflow/python/framework/indexed_slices.py +++ b/tensorflow/python/framework/indexed_slices.py @@ -80,7 +80,7 @@ class IndexedSlices(tensor_like.TensorLike, composite_tensor.CompositeTensor): (e.g. `tf.gather`). Contrast this representation with - `tf.SparseTensor`, + `tf.sparse.SparseTensor`, which uses multi-dimensional indices and scalar values. """ diff --git a/tensorflow/python/framework/op_callbacks.py b/tensorflow/python/framework/op_callbacks.py index bfd41f0465a..0f2515b6fd1 100644 --- a/tensorflow/python/framework/op_callbacks.py +++ b/tensorflow/python/framework/op_callbacks.py @@ -79,7 +79,7 @@ def add_op_callback(callback_fn): # "MatMul_2". # graph: The graph that the op belongs to (if any). # - In eager execution of an op or FuncGraph, this is `None`. - # - In graph construction, this is the op's containing graph + # - In graph construction, this is the op's enclosing graph # as a `tf.Graph` object. # # Return values: @@ -89,7 +89,7 @@ def add_op_callback(callback_fn): # `outputs` argument. # If the return value is `None`, downstream execution or graph # construction will be unaffected. - # Howevevr, if the return value is a `list` or `tuple` of `Tensor`s, + # However, if the return value is a `list` or `tuple` of `Tensor`s, # - In eager execution, these returned `Tensor`s should be # `EagerTensor`s. Their values will replace the original values of # `outputs` for downstream eager execution. (*Not implemented yet*). diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 600ba7414ce..51d2e272703 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -92,6 +92,13 @@ _api_usage_gauge = monitoring.BoolGauge( "/tensorflow/api/ops_eager_execution", "Whether ops.enable_eager_execution() is called.") +_tensor_equality_api_usage_gauge = monitoring.BoolGauge( + "/tensorflow/api/enable_tensor_equality", + "Whether ops.enable_tensor_equality() is called.") + +_control_flow_api_gauge = monitoring.BoolGauge( + "/tensorflow/api/enable_control_flow_v2", + "Whether enable_control_flow_v2() is called.") # pylint: disable=protected-access _DTYPES_INTERN_TABLE = dtypes._INTERN_TABLE @@ -277,14 +284,17 @@ def enable_tensor_equality(): unhashable. Thus tensors can no longer be directly used in sets or as a key in a dictionary. """ + _tensor_equality_api_usage_gauge.get_cell().set(True) Tensor._USE_EQUALITY = True # pylint: disable=protected-access + @tf_export(v1=["disable_tensor_equality"]) def disable_tensor_equality(): """Compare Tensors by their id and be hashable. This is a legacy behaviour of TensorFlow and is highly discouraged. """ + _tensor_equality_api_usage_gauge.get_cell().set(False) Tensor._USE_EQUALITY = False # pylint: disable=protected-access @@ -340,7 +350,7 @@ class Tensor(tensor_like.TensorLike): shape of a tensor at execution time. A number of specialized tensors are available: see `tf.Variable`, - `tf.constant`, `tf.placeholder`, `tf.SparseTensor`, and + `tf.constant`, `tf.placeholder`, `tf.sparse.SparseTensor`, and `tf.RaggedTensor`. For more on Tensors, see the [guide](https://tensorflow.org/guide/tensor). @@ -876,8 +886,10 @@ class Tensor(tensor_like.TensorLike): __array_priority__ = 100 def __array__(self): - raise NotImplementedError("Cannot convert a symbolic Tensor ({}) to a numpy" - " array.".format(self.name)) + raise NotImplementedError( + "Cannot convert a symbolic Tensor ({}) to a numpy array." + " This error may indicate that you're trying to pass a Tensor to" + " a NumPy call, which is not supported".format(self.name)) def __len__(self): raise TypeError("len is not well defined for symbolic Tensors. ({}) " diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc index 5912c26a5a0..857cc7b6638 100644 --- a/tensorflow/python/framework/python_op_gen.cc +++ b/tensorflow/python/framework/python_op_gen.cc @@ -806,18 +806,6 @@ void GenEagerPythonOp::AddEagerFastPathExecute() { // Handle fallback. if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", "); strings::StrAppend(&fallback_params, "ctx=_ctx"); - strings::StrAppend(&result_, " ", "except _core._FallbackException:\n"); - strings::StrAppend(&result_, " try:\n"); - strings::StrAppend( - &result_, " ", "return ", function_name_, kEagerFallbackSuffix, - "(\n", - WordWrap(strings::StrCat(" "), - strings::StrCat(fallback_params, ")"), kRightMargin), - "\n"); - strings::StrAppend(&result_, " except _core._SymbolicException:\n"); - strings::StrAppend(&result_, - " pass # Add nodes to the TensorFlow graph.\n"); - AddDispatch(" "); // Any errors thrown from execute need to be unwrapped from // _NotOkStatusException. @@ -825,6 +813,20 @@ void GenEagerPythonOp::AddEagerFastPathExecute() { "except _core._NotOkStatusException as e:\n"); strings::StrAppend(&result_, " ", "_ops.raise_from_not_ok_status(e, name)\n"); + + strings::StrAppend(&result_, " ", "except _core._FallbackException:\n"); + strings::StrAppend(&result_, " pass\n"); + strings::StrAppend(&result_, " try:\n"); + strings::StrAppend( + &result_, " ", "return ", function_name_, kEagerFallbackSuffix, + "(\n", + WordWrap(strings::StrCat(" "), + strings::StrCat(fallback_params, ")"), kRightMargin), + "\n"); + strings::StrAppend(&result_, " except _core._SymbolicException:\n"); + strings::StrAppend(&result_, + " pass # Add nodes to the TensorFlow graph.\n"); + AddDispatch(" "); } void GenEagerPythonOp::AddEagerInferredAttrs(const string& indentation) { diff --git a/tensorflow/python/framework/sparse_tensor.py b/tensorflow/python/framework/sparse_tensor.py index a175f80a6c3..d085dfdab0d 100644 --- a/tensorflow/python/framework/sparse_tensor.py +++ b/tensorflow/python/framework/sparse_tensor.py @@ -297,18 +297,18 @@ _pywrap_utils.RegisterType("SparseTensorValue", SparseTensorValue) @tf_export("SparseTensorSpec") class SparseTensorSpec(type_spec.BatchableTypeSpec): - """Type specification for a `tf.SparseTensor`.""" + """Type specification for a `tf.sparse.SparseTensor`.""" __slots__ = ["_shape", "_dtype"] value_type = property(lambda self: SparseTensor) def __init__(self, shape=None, dtype=dtypes.float32): - """Constructs a type specification for a `tf.SparseTensor`. + """Constructs a type specification for a `tf.sparse.SparseTensor`. Args: - shape: The dense shape of the `SparseTensor`, or `None` to allow - any dense shape. + shape: The dense shape of the `SparseTensor`, or `None` to allow any dense + shape. dtype: `tf.DType` of values in the `SparseTensor`. """ self._shape = tensor_shape.as_shape(shape) @@ -472,13 +472,14 @@ def convert_to_tensor_or_sparse_tensor(value, dtype=None, name=None): def is_sparse(x): """Check whether `x` is sparse. - Check whether an object is a `tf.SparseTensor` or + Check whether an object is a `tf.sparse.SparseTensor` or `tf.compat.v1.SparseTensorValue`. Args: x: A python object to check. Returns: - `True` iff `x` is a `tf.SparseTensor` or `tf.compat.v1.SparseTensorValue`. + `True` iff `x` is a `tf.sparse.SparseTensor` or + `tf.compat.v1.SparseTensorValue`. """ return isinstance(x, (SparseTensor, SparseTensorValue)) diff --git a/tensorflow/python/framework/sparse_tensor_test.py b/tensorflow/python/framework/sparse_tensor_test.py index 0202a83ef9f..f7ecf00f29b 100644 --- a/tensorflow/python/framework/sparse_tensor_test.py +++ b/tensorflow/python/framework/sparse_tensor_test.py @@ -22,6 +22,7 @@ from absl.testing import parameterized import numpy as np from tensorflow.python.eager import context +from tensorflow.python.eager import def_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor @@ -59,6 +60,18 @@ class SparseTensorTest(test_util.TensorFlowTestCase): self.assertAllEqual(sess_run_value.values, value.values) self.assertAllEqual(sess_run_value.dense_shape, value.dense_shape) + def testShape(self): + + @def_function.function + def test_fn(tensor): + tensor = sparse_ops.sparse_transpose(tensor) + self.assertEqual(tensor.shape.rank, 2) + return tensor + + tensor = sparse_tensor.SparseTensor( + indices=[[0, 0], [1, 2]], values=[1., 2], dense_shape=[3, 4]) + test_fn(tensor) + def testIsSparse(self): self.assertFalse(sparse_tensor.is_sparse(3)) self.assertFalse(sparse_tensor.is_sparse("foo")) diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py index b2ab779386b..35b5c0a3b1e 100644 --- a/tensorflow/python/framework/tensor_util_test.py +++ b/tensorflow/python/framework/tensor_util_test.py @@ -26,6 +26,8 @@ import numpy as np from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import func_graph +from tensorflow.python.framework import indexed_slices +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_util @@ -33,6 +35,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_state_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables +from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.platform import test @@ -759,15 +762,39 @@ class TensorUtilTest(test.TestCase): self.assertFalse(tensor_util.ShapeEquals(t, [4])) +@test_util.run_all_in_graph_and_eager_modes class IsTensorTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes def testConstantTensor(self): np_val = np.random.rand(3).astype(np.int32) tf_val = constant_op.constant(np_val) self.assertFalse(tensor_util.is_tensor(np_val)) self.assertTrue(tensor_util.is_tensor(tf_val)) + def testRaggedTensor(self): + rt = ragged_factory_ops.constant([[1, 2], [3]]) + rt_value = self.evaluate(rt) + self.assertTrue(tensor_util.is_tensor(rt)) + self.assertFalse(tensor_util.is_tensor(rt_value)) + + def testSparseTensor(self): + st = sparse_tensor.SparseTensor([[1, 2]], [3], [10, 10]) + st_value = self.evaluate(st) + self.assertTrue(tensor_util.is_tensor(st)) + self.assertFalse(tensor_util.is_tensor(st_value)) + + def testIndexedSlices(self): + x = indexed_slices.IndexedSlices( + constant_op.constant([1, 2, 3]), constant_op.constant([10, 20, 30])) + x_value = indexed_slices.IndexedSlicesValue( + np.array([1, 2, 3]), np.array([10, 20, 30]), np.array([100])) + self.assertTrue(tensor_util.is_tensor(x)) + self.assertFalse(tensor_util.is_tensor(x_value)) + + def testVariable(self): + v = variables.Variable([1, 2, 3]) + self.assertTrue(tensor_util.is_tensor(v)) + class ConstantValueTest(test.TestCase): diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index c20b55ed829..809c527fb84 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== # pylint: disable=g-import-not-at-top +# pylint: disable=g-classes-have-attributes """Callbacks: utilities called at certain points during model training. """ from __future__ import absolute_import @@ -336,8 +337,8 @@ class CallbackList(object): This function should only be called during TRAIN mode. Arguments: - epoch: integer, index of epoch. - logs: dict. Currently no data is passed to this argument for this method + epoch: Integer, index of epoch. + logs: Dict. Currently no data is passed to this argument for this method but that may change in the future. """ logs = logs or {} @@ -357,8 +358,8 @@ class CallbackList(object): This function should only be called during TRAIN mode. Arguments: - epoch: integer, index of epoch. - logs: dict, metric results for this training epoch, and for the + epoch: Integer, index of epoch. + logs: Dict, metric results for this training epoch, and for the validation epoch if validation is performed. Validation result keys are prefixed with `val_`. """ @@ -376,8 +377,8 @@ class CallbackList(object): """Calls the `on_train_batch_begin` methods of its callbacks. Arguments: - batch: integer, index of batch within the current epoch. - logs: dict. Has keys `batch` and `size` representing the current batch + batch: Integer, index of batch within the current epoch. + logs: Dict. Has keys `batch` and `size` representing the current batch number and the size of the batch. """ # TODO(b/150629188): Make ProgBarLogger callback not use batch hooks @@ -389,8 +390,8 @@ class CallbackList(object): """Calls the `on_train_batch_end` methods of its callbacks. Arguments: - batch: integer, index of batch within the current epoch. - logs: dict. Aggregated metric results up until this batch. + batch: Integer, index of batch within the current epoch. + logs: Dict. Aggregated metric results up until this batch. """ if self._should_call_train_batch_hooks: self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs) @@ -399,8 +400,8 @@ class CallbackList(object): """Calls the `on_test_batch_begin` methods of its callbacks. Arguments: - batch: integer, index of batch within the current epoch. - logs: dict. Has keys `batch` and `size` representing the current batch + batch: Integer, index of batch within the current epoch. + logs: Dict. Has keys `batch` and `size` representing the current batch number and the size of the batch. """ if self._should_call_test_batch_hooks: @@ -410,8 +411,8 @@ class CallbackList(object): """Calls the `on_test_batch_end` methods of its callbacks. Arguments: - batch: integer, index of batch within the current epoch. - logs: dict. Aggregated metric results up until this batch. + batch: Integer, index of batch within the current epoch. + logs: Dict. Aggregated metric results up until this batch. """ if self._should_call_test_batch_hooks: self._call_batch_hook(ModeKeys.TEST, 'end', batch, logs=logs) @@ -420,8 +421,8 @@ class CallbackList(object): """Calls the `on_predict_batch_begin` methods of its callbacks. Arguments: - batch: integer, index of batch within the current epoch. - logs: dict. Has keys `batch` and `size` representing the current batch + batch: Integer, index of batch within the current epoch. + logs: Dict. Has keys `batch` and `size` representing the current batch number and the size of the batch. """ if self._should_call_predict_batch_hooks: @@ -431,8 +432,8 @@ class CallbackList(object): """Calls the `on_predict_batch_end` methods of its callbacks. Arguments: - batch: integer, index of batch within the current epoch. - logs: dict. Aggregated metric results up until this batch. + batch: Integer, index of batch within the current epoch. + logs: Dict. Aggregated metric results up until this batch. """ if self._should_call_predict_batch_hooks: self._call_batch_hook(ModeKeys.PREDICT, 'end', batch, logs=logs) @@ -441,7 +442,7 @@ class CallbackList(object): """Calls the `on_train_begin` methods of its callbacks. Arguments: - logs: dict. Currently no data is passed to this argument for this method + logs: Dict. Currently no data is passed to this argument for this method but that may change in the future. """ logs = logs or {} @@ -458,7 +459,7 @@ class CallbackList(object): """Calls the `on_train_end` methods of its callbacks. Arguments: - logs: dict. Currently no data is passed to this argument for this method + logs: Dict. Currently no data is passed to this argument for this method but that may change in the future. """ logs = logs or {} @@ -475,7 +476,7 @@ class CallbackList(object): """Calls the `on_test_begin` methods of its callbacks. Arguments: - logs: dict. Currently no data is passed to this argument for this method + logs: Dict. Currently no data is passed to this argument for this method but that may change in the future. """ logs = logs or {} @@ -492,7 +493,7 @@ class CallbackList(object): """Calls the `on_test_end` methods of its callbacks. Arguments: - logs: dict. Currently no data is passed to this argument for this method + logs: Dict. Currently no data is passed to this argument for this method but that may change in the future. """ logs = logs or {} @@ -509,7 +510,7 @@ class CallbackList(object): """Calls the 'on_predict_begin` methods of its callbacks. Arguments: - logs: dict. Currently no data is passed to this argument for this method + logs: Dict. Currently no data is passed to this argument for this method but that may change in the future. """ logs = logs or {} @@ -526,7 +527,7 @@ class CallbackList(object): """Calls the `on_predict_end` methods of its callbacks. Arguments: - logs: dict. Currently no data is passed to this argument for this method + logs: Dict. Currently no data is passed to this argument for this method but that may change in the future. """ logs = logs or {} @@ -548,9 +549,9 @@ class Callback(object): """Abstract base class used to build new callbacks. Attributes: - params: dict. Training parameters + params: Dict. Training parameters (eg. verbosity, batch size, number of epochs...). - model: instance of `keras.models.Model`. + model: Instance of `keras.models.Model`. Reference of the model being trained. The `logs` dictionary that callback methods @@ -591,8 +592,8 @@ class Callback(object): be called during TRAIN mode. Arguments: - epoch: integer, index of epoch. - logs: dict. Currently no data is passed to this argument for this method + epoch: Integer, index of epoch. + logs: Dict. Currently no data is passed to this argument for this method but that may change in the future. """ @@ -604,8 +605,8 @@ class Callback(object): be called during TRAIN mode. Arguments: - epoch: integer, index of epoch. - logs: dict, metric results for this training epoch, and for the + epoch: Integer, index of epoch. + logs: Dict, metric results for this training epoch, and for the validation epoch if validation is performed. Validation result keys are prefixed with `val_`. """ @@ -618,8 +619,8 @@ class Callback(object): Subclasses should override for any actions to run. Arguments: - batch: integer, index of batch within the current epoch. - logs: dict. Has keys `batch` and `size` representing the current batch + batch: Integer, index of batch within the current epoch. + logs: Dict. Has keys `batch` and `size` representing the current batch number and the size of the batch. """ # For backwards compatibility. @@ -633,8 +634,8 @@ class Callback(object): Subclasses should override for any actions to run. Arguments: - batch: integer, index of batch within the current epoch. - logs: dict. Aggregated metric results up until this batch. + batch: Integer, index of batch within the current epoch. + logs: Dict. Aggregated metric results up until this batch. """ # For backwards compatibility. self.on_batch_end(batch, logs=logs) @@ -650,8 +651,8 @@ class Callback(object): Subclasses should override for any actions to run. Arguments: - batch: integer, index of batch within the current epoch. - logs: dict. Has keys `batch` and `size` representing the current batch + batch: Integer, index of batch within the current epoch. + logs: Dict. Has keys `batch` and `size` representing the current batch number and the size of the batch. """ @@ -666,8 +667,8 @@ class Callback(object): Subclasses should override for any actions to run. Arguments: - batch: integer, index of batch within the current epoch. - logs: dict. Aggregated metric results up until this batch. + batch: Integer, index of batch within the current epoch. + logs: Dict. Aggregated metric results up until this batch. """ @doc_controls.for_subclass_implementers @@ -678,8 +679,8 @@ class Callback(object): Subclasses should override for any actions to run. Arguments: - batch: integer, index of batch within the current epoch. - logs: dict. Has keys `batch` and `size` representing the current batch + batch: Integer, index of batch within the current epoch. + logs: Dict. Has keys `batch` and `size` representing the current batch number and the size of the batch. """ @@ -691,8 +692,8 @@ class Callback(object): Subclasses should override for any actions to run. Arguments: - batch: integer, index of batch within the current epoch. - logs: dict. Aggregated metric results up until this batch. + batch: Integer, index of batch within the current epoch. + logs: Dict. Aggregated metric results up until this batch. """ @doc_controls.for_subclass_implementers @@ -702,7 +703,7 @@ class Callback(object): Subclasses should override for any actions to run. Arguments: - logs: dict. Currently no data is passed to this argument for this method + logs: Dict. Currently no data is passed to this argument for this method but that may change in the future. """ @@ -713,7 +714,7 @@ class Callback(object): Subclasses should override for any actions to run. Arguments: - logs: dict. Currently no data is passed to this argument for this method + logs: Dict. Currently no data is passed to this argument for this method but that may change in the future. """ @@ -724,7 +725,7 @@ class Callback(object): Subclasses should override for any actions to run. Arguments: - logs: dict. Currently no data is passed to this argument for this method + logs: Dict. Currently no data is passed to this argument for this method but that may change in the future. """ @@ -735,7 +736,7 @@ class Callback(object): Subclasses should override for any actions to run. Arguments: - logs: dict. Currently no data is passed to this argument for this method + logs: Dict. Currently no data is passed to this argument for this method but that may change in the future. """ @@ -746,7 +747,7 @@ class Callback(object): Subclasses should override for any actions to run. Arguments: - logs: dict. Currently no data is passed to this argument for this method + logs: Dict. Currently no data is passed to this argument for this method but that may change in the future. """ @@ -757,7 +758,7 @@ class Callback(object): Subclasses should override for any actions to run. Arguments: - logs: dict. Currently no data is passed to this argument for this method + logs: Dict. Currently no data is passed to this argument for this method but that may change in the future. """ @@ -847,7 +848,7 @@ class ProgbarLogger(Callback): """Callback that prints metrics to stdout. Arguments: - count_mode: One of "steps" or "samples". + count_mode: One of `"steps"` or `"samples"`. Whether the progress bar should count samples seen or steps (batches) seen. stateful_metrics: Iterable of string names of metrics that @@ -1411,7 +1412,7 @@ class EarlyStopping(Callback): """Stop training when a monitored metric has stopped improving. Assuming the goal of a training is to minimize the loss. With this, the - metric to be monitored would be 'loss', and mode would be 'min'. A + metric to be monitored would be `'loss'`, and mode would be `'min'`. A `model.fit()` training loop will check at end of every epoch whether the loss is no longer decreasing, considering the `min_delta` and `patience` if applicable. Once it's found no longer decreasing, @@ -1420,6 +1421,30 @@ class EarlyStopping(Callback): The quantity to be monitored needs to be available in `logs` dict. To make it so, pass the loss or metrics at `model.compile()`. + Arguments: + monitor: Quantity to be monitored. + min_delta: Minimum change in the monitored quantity + to qualify as an improvement, i.e. an absolute + change of less than min_delta, will count as no + improvement. + patience: Number of epochs with no improvement + after which training will be stopped. + verbose: verbosity mode. + mode: One of `{"auto", "min", "max"}`. In `min` mode, + training will stop when the quantity + monitored has stopped decreasing; in `"max"` + mode it will stop when the quantity + monitored has stopped increasing; in `"auto"` + mode, the direction is automatically inferred + from the name of the monitored quantity. + baseline: Baseline value for the monitored quantity. + Training will stop if the model doesn't show improvement over the + baseline. + restore_best_weights: Whether to restore model weights from + the epoch with the best value of the monitored quantity. + If False, the model weights obtained at the last step of + training are used. + Example: >>> callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3) @@ -1442,32 +1467,6 @@ class EarlyStopping(Callback): mode='auto', baseline=None, restore_best_weights=False): - """Initialize an EarlyStopping callback. - - Arguments: - monitor: Quantity to be monitored. - min_delta: Minimum change in the monitored quantity - to qualify as an improvement, i.e. an absolute - change of less than min_delta, will count as no - improvement. - patience: Number of epochs with no improvement - after which training will be stopped. - verbose: verbosity mode. - mode: One of `{"auto", "min", "max"}`. In `min` mode, - training will stop when the quantity - monitored has stopped decreasing; in `max` - mode it will stop when the quantity - monitored has stopped increasing; in `auto` - mode, the direction is automatically inferred - from the name of the monitored quantity. - baseline: Baseline value for the monitored quantity. - Training will stop if the model doesn't show improvement over the - baseline. - restore_best_weights: Whether to restore model weights from - the epoch with the best value of the monitored quantity. - If False, the model weights obtained at the last step of - training are used. - """ super(EarlyStopping, self).__init__() self.monitor = monitor @@ -1550,18 +1549,19 @@ class RemoteMonitor(Callback): Events are sent to `root + '/publish/epoch/end/'` by default. Calls are HTTP POST, with a `data` argument which is a JSON-encoded dictionary of event data. - If send_as_json is set to True, the content type of the request will be - application/json. Otherwise the serialized JSON will be sent within a form. + If `send_as_json=True`, the content type of the request will be + `"application/json"`. + Otherwise the serialized JSON will be sent within a form. Arguments: - root: String; root url of the target server. - path: String; path relative to `root` to which the events will be sent. - field: String; JSON field under which the data will be stored. - The field is used only if the payload is sent within a form - (i.e. send_as_json is set to False). - headers: Dictionary; optional custom HTTP headers. - send_as_json: Boolean; whether the request should be - sent as application/json. + root: String; root url of the target server. + path: String; path relative to `root` to which the events will be sent. + field: String; JSON field under which the data will be stored. + The field is used only if the payload is sent within a form + (i.e. send_as_json is set to False). + headers: Dictionary; optional custom HTTP headers. + send_as_json: Boolean; whether the request should be + sent as `"application/json"`. """ def __init__(self, @@ -1608,19 +1608,26 @@ class RemoteMonitor(Callback): class LearningRateScheduler(Callback): """Learning rate scheduler. - At the beginning of every epoch, this callback gets the learning rate - value from `schedule` function provided at `__init__`, with the current epoch, - and applies that learning rate on the optimizer. + At the beginning of every epoch, this callback gets the updated learning rate + value from `schedule` function provided at `__init__`, with the current epoch + and current learning rate, and applies the updated learning rate + on the optimizer. + + Arguments: + schedule: a function that takes an epoch index (integer, indexed from 0) + and current learning rate (float) as inputs and returns a new + learning rate as output (float). + verbose: int. 0: quiet, 1: update messages. Example: - >>> # This function keeps the learning rate at 0.001 for the first ten epochs + >>> # This function keeps the initial learning rate for the first ten epochs >>> # and decreases it exponentially after that. - >>> def scheduler(epoch): + >>> def scheduler(epoch, lr): ... if epoch < 10: - ... return 0.001 + ... return lr ... else: - ... return 0.001 * tf.math.exp(0.1 * (10 - epoch)) + ... return lr * tf.math.exp(-0.1) >>> >>> model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)]) >>> model.compile(tf.keras.optimizers.SGD(), loss='mse') @@ -1629,21 +1636,13 @@ class LearningRateScheduler(Callback): >>> callback = tf.keras.callbacks.LearningRateScheduler(scheduler) >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5), - ... epochs=2, callbacks=[callback], verbose=0) + ... epochs=15, callbacks=[callback], verbose=0) >>> round(model.optimizer.lr.numpy(), 5) - 0.001 + 0.00607 """ def __init__(self, schedule, verbose=0): - """Initialize a `keras.callbacks.LearningRateScheduler` callback. - - Arguments: - schedule: a function that takes an epoch index as input - (integer, indexed from 0) and returns a new - learning rate as output (float). - verbose: int. 0: quiet, 1: update messages. - """ super(LearningRateScheduler, self).__init__() self.schedule = schedule self.verbose = verbose @@ -1688,7 +1687,7 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector): If you have installed TensorFlow with pip, you should be able to launch TensorBoard from the command line: - ```sh + ``` tensorboard --logdir=path_to_your_logs ``` @@ -1696,24 +1695,27 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector): [here](https://www.tensorflow.org/get_started/summaries_and_tensorboard). Example (Basic): + ```python tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="./logs") model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback]) # run the tensorboard command to view the visualizations. ``` + Example (Profile): + ```python # profile a single batch, e.g. the 5th batch. - tensorboard_callback = - tf.keras.callbacks.TensorBoard(log_dir='./logs', profile_batch=5) + tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='./logs', + profile_batch=5) model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback]) - # run the tensorboard command to view the visualizations in profile plugin. + # Now run the tensorboard command to view the visualizations (profile plugin). # profile a range of batches, e.g. from 10 to 20. - tensorboard_callback = - tf.keras.callbacks.TensorBoard(log_dir='./logs', profile_batch='10,20') + tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='./logs', + profile_batch='10,20') model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback]) - # run the tensorboard command to view the visualizations in profile plugin. + # Now run the tensorboard command to view the visualizations (profile plugin). ``` Arguments: @@ -2142,14 +2144,15 @@ class ReduceLROnPlateau(Callback): Arguments: monitor: quantity to be monitored. - factor: factor by which the learning rate will be reduced. new_lr = lr * - factor + factor: factor by which the learning rate will be reduced. + `new_lr = lr * factor`. patience: number of epochs with no improvement after which learning rate will be reduced. verbose: int. 0: quiet, 1: update messages. - mode: one of {auto, min, max}. In `min` mode, lr will be reduced when the - quantity monitored has stopped decreasing; in `max` mode it will be - reduced when the quantity monitored has stopped increasing; in `auto` + mode: one of `{'auto', 'min', 'max'}`. In `'min'` mode, + the learning rate will be reduced when the + quantity monitored has stopped decreasing; in `'max'` mode it will be + reduced when the quantity monitored has stopped increasing; in `'auto'` mode, the direction is automatically inferred from the name of the monitored quantity. min_delta: threshold for measuring the new optimum, to only focus on @@ -2248,10 +2251,10 @@ class ReduceLROnPlateau(Callback): @keras_export('keras.callbacks.CSVLogger') class CSVLogger(Callback): - """Callback that streams epoch results to a csv file. + """Callback that streams epoch results to a CSV file. Supports all values that can be represented as a string, - including 1D iterables such as np.ndarray. + including 1D iterables such as `np.ndarray`. Example: @@ -2261,10 +2264,10 @@ class CSVLogger(Callback): ``` Arguments: - filename: filename of the csv file, e.g. 'run/log.csv'. - separator: string used to separate elements in the csv file. - append: True: append if file exists (useful for continuing - training). False: overwrite existing file, + filename: Filename of the CSV file, e.g. `'run/log.csv'`. + separator: String used to separate elements in the CSV file. + append: Boolean. True: append if file exists (useful for continuing + training). False: overwrite existing file. """ def __init__(self, filename, separator=',', append=False): @@ -2347,12 +2350,12 @@ class LambdaCallback(Callback): at the appropriate time. Note that the callbacks expects positional arguments, as: - - `on_epoch_begin` and `on_epoch_end` expect two positional arguments: - `epoch`, `logs` - - `on_batch_begin` and `on_batch_end` expect two positional arguments: - `batch`, `logs` - - `on_train_begin` and `on_train_end` expect one positional argument: - `logs` + - `on_epoch_begin` and `on_epoch_end` expect two positional arguments: + `epoch`, `logs` + - `on_batch_begin` and `on_batch_end` expect two positional arguments: + `batch`, `logs` + - `on_train_begin` and `on_train_end` expect one positional argument: + `logs` Arguments: on_epoch_begin: called at the beginning of every epoch. diff --git a/tensorflow/python/keras/datasets/boston_housing.py b/tensorflow/python/keras/datasets/boston_housing.py index f3900cc075a..2c0badfefba 100644 --- a/tensorflow/python/keras/datasets/boston_housing.py +++ b/tensorflow/python/keras/datasets/boston_housing.py @@ -40,7 +40,7 @@ def load_data(path='boston_housing.npz', test_split=0.2, seed=113): Arguments: path: path where to cache the dataset locally - (relative to ~/.keras/datasets). + (relative to `~/.keras/datasets`). test_split: fraction of the data to reserve as test set. seed: Random seed for shuffling the data before computing the test split. @@ -48,10 +48,11 @@ def load_data(path='boston_housing.npz', test_split=0.2, seed=113): Returns: Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`. - **x_train, x_test**: numpy arrays with shape (num_samples, 13) containing - either the training samples (for x_train), or test samples (for y_train) + **x_train, x_test**: numpy arrays with shape `(num_samples, 13)` + containing either the training samples (for x_train), + or test samples (for y_train). - **y_train, y_test**: numpy arrays of shape (num_samples, ) containing the + **y_train, y_test**: numpy arrays of shape `(num_samples,)` containing the target scalars. The targets are float scalars typically between 10 and 50 that represent the home prices in k$. """ diff --git a/tensorflow/python/keras/datasets/cifar10.py b/tensorflow/python/keras/datasets/cifar10.py index 60afd2c5b78..7b74feb4726 100644 --- a/tensorflow/python/keras/datasets/cifar10.py +++ b/tensorflow/python/keras/datasets/cifar10.py @@ -40,9 +40,9 @@ def load_data(): Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`. **x_train, x_test**: uint8 arrays of RGB image data with shape - (num_samples, 3, 32, 32) if the `tf.keras.backend.image_data_format` is - 'channels_first', or (num_samples, 32, 32, 3) if the data format - is 'channels_last'. + `(num_samples, 3, 32, 32)` if `tf.keras.backend.image_data_format()` is + `'channels_first'`, or `(num_samples, 32, 32, 3)` if the data format + is `'channels_last'`. **y_train, y_test**: uint8 arrays of category labels (integers in range 0-9) each with shape (num_samples, 1). diff --git a/tensorflow/python/keras/datasets/cifar100.py b/tensorflow/python/keras/datasets/cifar100.py index 0c835b40d5d..5596f6ebb9b 100644 --- a/tensorflow/python/keras/datasets/cifar100.py +++ b/tensorflow/python/keras/datasets/cifar100.py @@ -46,9 +46,9 @@ def load_data(label_mode='fine'): Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`. **x_train, x_test**: uint8 arrays of RGB image data with shape - (num_samples, 3, 32, 32) if the `tf.keras.backend.image_data_format` is - 'channels_first', or (num_samples, 32, 32, 3) if the data format - is 'channels_last'. + `(num_samples, 3, 32, 32)` if `tf.keras.backend.image_data_format()` is + `'channels_first'`, or `(num_samples, 32, 32, 3)` if the data format + is `'channels_last'`. **y_train, y_test**: uint8 arrays of category labels with shape (num_samples, 1). diff --git a/tensorflow/python/keras/datasets/imdb.py b/tensorflow/python/keras/datasets/imdb.py index d6f7cf6ae3d..61fbf92eaef 100644 --- a/tensorflow/python/keras/datasets/imdb.py +++ b/tensorflow/python/keras/datasets/imdb.py @@ -80,7 +80,7 @@ def load_data(path='imdb.npz', **x_train, x_test**: lists of sequences, which are lists of indexes (integers). If the num_words argument was specific, the maximum - possible index value is num_words-1. If the `maxlen` argument was + possible index value is `num_words - 1`. If the `maxlen` argument was specified, the largest possible sequence length is `maxlen`. **y_train, y_test**: lists of integer labels (1 or 0). diff --git a/tensorflow/python/keras/datasets/mnist.py b/tensorflow/python/keras/datasets/mnist.py index 1d41de197b3..f371ad4ece5 100644 --- a/tensorflow/python/keras/datasets/mnist.py +++ b/tensorflow/python/keras/datasets/mnist.py @@ -31,12 +31,12 @@ def load_data(path='mnist.npz'): This is a dataset of 60,000 28x28 grayscale images of the 10 digits, along with a test set of 10,000 images. More info can be found at the - (MNIST homepage)[http://yann.lecun.com/exdb/mnist/]. + [MNIST homepage](http://yann.lecun.com/exdb/mnist/). Arguments: path: path where to cache the dataset locally - (relative to ~/.keras/datasets). + (relative to `~/.keras/datasets`). Returns: Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`. diff --git a/tensorflow/python/keras/datasets/reuters.py b/tensorflow/python/keras/datasets/reuters.py index ceec8d3bfd5..46ac9249637 100644 --- a/tensorflow/python/keras/datasets/reuters.py +++ b/tensorflow/python/keras/datasets/reuters.py @@ -42,11 +42,11 @@ def load_data(path='reuters.npz', """Loads the Reuters newswire classification dataset. This is a dataset of 11,228 newswires from Reuters, labeled over 46 topics. + This was originally generated by parsing and preprocessing the classic Reuters-21578 dataset, but the preprocessing code is no longer packaged - with Keras. - - See this [github discussion](https://github.com/keras-team/keras/issues/12072) + with Keras. See this + [github discussion](https://github.com/keras-team/keras/issues/12072) for more info. Each newswire is encoded as a list of word indexes (integers). @@ -91,7 +91,7 @@ def load_data(path='reuters.npz', **x_train, x_test**: lists of sequences, which are lists of indexes (integers). If the num_words argument was specific, the maximum - possible index value is num_words-1. If the `maxlen` argument was + possible index value is `num_words - 1`. If the `maxlen` argument was specified, the largest possible sequence length is `maxlen`. **y_train, y_test**: lists of integer labels (1 or 0). diff --git a/tensorflow/python/keras/engine/base_layer_utils.py b/tensorflow/python/keras/engine/base_layer_utils.py index c0c0d9d04be..a0e1e9edc2f 100644 --- a/tensorflow/python/keras/engine/base_layer_utils.py +++ b/tensorflow/python/keras/engine/base_layer_utils.py @@ -658,8 +658,6 @@ def mark_as_return(outputs, acd): V2_DTYPE_BEHAVIOR = None -# These two functions are not exported because we plan on removing them in the -# future. def enable_v2_dtype_behavior(): """Enable the V2 dtype behavior for Keras layers. diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py index 425cb5c9127..747e51fc4e2 100644 --- a/tensorflow/python/keras/engine/network.py +++ b/tensorflow/python/keras/engine/network.py @@ -547,6 +547,9 @@ class Network(base_layer.Layer): """ # TODO(fchollet): We could build a dictionary based on layer names # since they are constant, but we have not done that yet. + if index is not None and name is not None: + raise ValueError('Provide only a layer name or a layer index.') + if index is not None: if len(self.layers) <= index: raise ValueError('Was asked to retrieve layer at index ' + str(index) + @@ -554,13 +557,13 @@ class Network(base_layer.Layer): ' layers.') else: return self.layers[index] - else: - if not name: - raise ValueError('Provide either a layer name or layer index.') - for layer in self.layers: - if layer.name == name: - return layer - raise ValueError('No such layer: ' + name) + + if name is not None: + for layer in self.layers: + if layer.name == name: + return layer + raise ValueError('No such layer: ' + name + '.') + raise ValueError('Provide either a layer name or layer index.') @property def trainable_weights(self): diff --git a/tensorflow/python/keras/engine/network_test.py b/tensorflow/python/keras/engine/network_test.py index e227d08f595..ad620713b63 100644 --- a/tensorflow/python/keras/engine/network_test.py +++ b/tensorflow/python/keras/engine/network_test.py @@ -128,6 +128,40 @@ class NetworkConstructionTest(keras_parameterized.TestCase): self.assertEqual(len(layer.get_updates_for(x1)), 2) self.assertEqual(len(layer.get_updates_for(None)), 0) + def test_get_layer(self): + # create a simple network + x = input_layer_lib.Input(shape=(32,)) + dense_a = layers.Dense(4, name='dense_a') + dense_b = layers.Dense(2, name='dense_b') + y = dense_b(dense_a(x)) + network = network_lib.Network(x, y, name='dense_network') + + # test various get_layer by index + self.assertEqual(network.get_layer(index=1), dense_a) + + # test invalid get_layer by index + with self.assertRaisesRegexp( + ValueError, 'Was asked to retrieve layer at index ' + str(3) + + ' but model only has ' + str(len(network.layers)) + ' layers.'): + network.get_layer(index=3) + + # test that only one between name and index is requested + with self.assertRaisesRegexp(ValueError, + 'Provide only a layer name or a layer index'): + network.get_layer(index=1, name='dense_b') + + # test that a name or an index must be provided + with self.assertRaisesRegexp(ValueError, + 'Provide either a layer name or layer index.'): + network.get_layer() + + # test various get_layer by name + self.assertEqual(network.get_layer(name='dense_a'), dense_a) + + # test invalid get_layer by name + with self.assertRaisesRegexp(ValueError, 'No such layer: dense_c.'): + network.get_layer(name='dense_c') + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) def testTopologicalAttributes(self): # test layer attributes / methods related to cross-layer connectivity. diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 64b5ff16f21..14c4dad34e5 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -467,7 +467,7 @@ class Model(network.Network, version_utils.ModelVersionSelector): >>> x = np.random.random((2, 3)) >>> y = np.random.randint(0, 2, (2, 2)) - >>> _ = model.fit(x, y, verbose=0) + >>> model.fit(x, y) >>> model.metrics_names ['loss', 'mae'] @@ -478,7 +478,7 @@ class Model(network.Network, version_utils.ModelVersionSelector): >>> model = tf.keras.models.Model( ... inputs=inputs, outputs=[output_1, output_2]) >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae", "acc"]) - >>> _ = model.fit(x, (y, y), verbose=0) + >>> model.fit(x, (y, y)) >>> model.metrics_names ['loss', 'out_loss', 'out_1_loss', 'out_mae', 'out_acc', 'out_1_mae', 'out_1_acc'] @@ -1059,7 +1059,7 @@ class Model(network.Network, version_utils.ModelVersionSelector): return_dict=False): """Returns the loss value & metrics values for the model in test mode. - Computation is done in batches. + Computation is done in batches (see the `batch_size` arg.) Arguments: x: Input data. It could be: - A Numpy array (or array-like), or a list @@ -1077,10 +1077,11 @@ class Model(network.Network, version_utils.ModelVersionSelector): `x` is a dataset, generator or `keras.utils.Sequence` instance, `y` should not be specified (since targets will be obtained from the iterator/dataset). - batch_size: Integer or `None`. Number of samples per gradient update. If - unspecified, `batch_size` will default to 32. Do not specify the - `batch_size` if your data is in the form of a dataset, generators, - or `keras.utils.Sequence` instances (since they generate batches). + batch_size: Integer or `None`. Number of samples per batch of + computation. If unspecified, `batch_size` will default to 32. Do not + specify the `batch_size` if your data is in the form of a dataset, + generators, or `keras.utils.Sequence` instances (since they generate + batches). verbose: 0 or 1. Verbosity mode. 0 = silent, 1 = progress bar. sample_weight: Optional Numpy array of weights for the test samples, used for weighting the loss function. You can either pass a flat (1D) diff --git a/tensorflow/python/keras/engine/training_v1.py b/tensorflow/python/keras/engine/training_v1.py index 710f9bf3497..1edf364b3ff 100644 --- a/tensorflow/python/keras/engine/training_v1.py +++ b/tensorflow/python/keras/engine/training_v1.py @@ -800,7 +800,7 @@ class Model(training_lib.Model): use_multiprocessing=False): """Returns the loss value & metrics values for the model in test mode. - Computation is done in batches. + Computation is done in batches (see the `batch_size` arg.) Arguments: x: Input data. It could be: @@ -820,7 +820,7 @@ class Model(training_lib.Model): `keras.utils.Sequence` instance, `y` should not be specified (since targets will be obtained from the iterator/dataset). batch_size: Integer or `None`. - Number of samples per gradient update. + Number of samples per batch of computation. If unspecified, `batch_size` will default to 32. Do not specify the `batch_size` if your data is in the form of symbolic tensors, dataset, @@ -903,7 +903,7 @@ class Model(training_lib.Model): use_multiprocessing=False): """Generates output predictions for the input samples. - Computation is done in batches. + Computation is done in batches (see the `batch_size` arg.) Arguments: x: Input samples. It could be: @@ -914,7 +914,7 @@ class Model(training_lib.Model): - A `tf.data` dataset. - A generator or `keras.utils.Sequence` instance. batch_size: Integer or `None`. - Number of samples per gradient update. + Number of samples per batch of computation. If unspecified, `batch_size` will default to 32. Do not specify the `batch_size` if your data is in the form of symbolic tensors, dataset, diff --git a/tensorflow/python/keras/integration_test/BUILD b/tensorflow/python/keras/integration_test/BUILD new file mode 100644 index 00000000000..f9b3b168721 --- /dev/null +++ b/tensorflow/python/keras/integration_test/BUILD @@ -0,0 +1,33 @@ +# Description: +# Contains Keras integration tests that verify with other TF high level APIs. + +load("//tensorflow:tensorflow.bzl", "tf_py_test") + +package( + default_visibility = [ + "//tensorflow/tools/pip_package:__pkg__", + ], + licenses = ["notice"], # Apache 2.0 +) + +exports_files(["LICENSE"]) + +tf_py_test( + name = "forwardprop_test", + srcs = ["forwardprop_test.py"], + python_version = "PY3", + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/python:extra_py_tests_deps", + ], +) + +tf_py_test( + name = "function_test", + srcs = ["function_test.py"], + python_version = "PY3", + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/python:extra_py_tests_deps", + ], +) diff --git a/tensorflow/python/keras/integration_test/README.md b/tensorflow/python/keras/integration_test/README.md new file mode 100644 index 00000000000..4d40893f686 --- /dev/null +++ b/tensorflow/python/keras/integration_test/README.md @@ -0,0 +1,12 @@ +# Keras Integration Test + +This package contains integration tests that ensure the correct interaction +between Keras and other Tensorflow high level APIs, like dataset, TF function +and distribution strategy, etc. + +There are a few guidelines for the tests under this package. + +*. Only use the public TF API. +*. Test should focus on the end to end use case between Keras and other TF high + level API. Unit test will be a better place for behavior testing for the + individual APIs. diff --git a/tensorflow/python/keras/integration_test/forwardprop_test.py b/tensorflow/python/keras/integration_test/forwardprop_test.py new file mode 100644 index 00000000000..0418a9db48b --- /dev/null +++ b/tensorflow/python/keras/integration_test/forwardprop_test.py @@ -0,0 +1,294 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools + +from absl.testing import parameterized +import numpy as np +import tensorflow as tf + + +def _jvp(f, primals, tangents): + """Compute the jacobian of `f` at `primals` multiplied by `tangents`.""" + with tf.autodiff.ForwardAccumulator(primals, tangents) as acc: + primals_out = f(*primals) + return primals_out, acc.jvp( + primals_out, unconnected_gradients=tf.UnconnectedGradients.ZERO) + + +def _jacfwd(f, primals): + """Compute the jacobian of `f` at `primals` using forward-mode autodiff.""" + jac_flat = [] + flat_primals = tf.nest.flatten(primals) + tangent_mask = [tf.zeros_like(primal) for primal in flat_primals] + for primal_index, primal in enumerate(flat_primals): + primal_vector = tf.reshape(primal, [-1]) + primal_vector_length = tf.size(primal_vector) + jac_columns = [] + for element_index in tf.range(primal_vector_length): + mask = tf.one_hot(element_index, primal_vector_length) + tangent_mask[primal_index] = tf.reshape(mask, tf.shape(primal)) + jac_columns.append( + tf.nest.map_structure( + functools.partial(tf.reshape, shape=[-1]), + _jvp(f, primals, tf.nest.pack_sequence_as(primals, + tangent_mask))[1])) + jac_flat.append(tf.stack(jac_columns, axis=1)) + tangent_mask[primal_index] = tf.zeros_like(primal) + return tf.nest.pack_sequence_as(primals, jac_flat) + + +def _grad(f, argnums=0): + """Return a function which computes the gradient of `f`.""" + + def _f(*params): + with tf.GradientTape() as tape: + tape.watch(params) + primals_out = f(*params) + return tape.gradient( + primals_out, + params[argnums], + unconnected_gradients=tf.UnconnectedGradients.ZERO) + + return _f + + +def _hvp(f, primals, tangents): + """Compute a forward-over-back Hessian-vector product.""" + with tf.autodiff.ForwardAccumulator(primals, tangents) as acc: + with tf.GradientTape() as tape: + tape.watch(primals) + f_out = f(*primals) + f_out.shape.assert_is_compatible_with([]) + return acc.jvp(tape.gradient(f_out, primals)) + + +def _vectorize_parameters(f, params, use_pfor, dtype): + """Loop over `params`, providing a one-hot mask to `f` for each.""" + parameter_sizes = [tf.size(param) for param in params] + total_size = tf.math.add_n(parameter_sizes) + + def _wrapper(index): + full_onehot = tf.one_hot(index, total_size) + split_onehot = tf.split(full_onehot, parameter_sizes) + tangents = [ + tf.reshape(v, tf.shape(param)) + for param, v in zip(params, split_onehot) + ] + return f(tangents) + + if use_pfor: + return tf.vectorized_map(_wrapper, tf.range(total_size)) + else: + return tf.map_fn(_wrapper, tf.range(total_size), dtype) + + +def _forward_over_back_hessian(f, params, use_pfor, dtype=None): + """Computes the full Hessian matrix for the scalar-valued f(*params). + + Args: + f: A function taking `params` and returning a scalar. + params: A possibly nested structure of tensors. + use_pfor: If true, uses `tf.vectorized_map` calls instead of looping. + dtype: Required if `use_pfor=False`. A possibly nested structure of dtypes + (e.g. `tf.float32`) matching the structure of `f`'s returns. + + Returns: + A possibly nested structure of matrix slices corresponding to `params`. Each + slice has shape [P, p_s] where `p_s` is the number of parameters (`tf.size`) + in the corresponding element of `params` and `P` is the total number of + parameters (`sum_s(p_s)`). The full matrix can be obtained by concatenating + along the second axis. + """ + return _vectorize_parameters( + functools.partial(_hvp, f, params), + params, use_pfor=use_pfor, dtype=dtype) + + +def _test_gradients(testcase, + f, + primals, + order, + delta=1e-3, + rtol=1e-2, + atol=1e-6): + """Tests forward/backward jacobians of `f`'s [0, `order`)-order gradients.""" + if order < 1: + raise ValueError( + "`order` should be a positive integer, got '{}'.".format(order)) + if order > 1: + _test_gradients( + testcase=testcase, + f=_grad(f), + primals=primals, + order=order - 1, + delta=delta, + rtol=rtol, + atol=atol) + sym_jac_back, num_jac = tf.test.compute_gradient(f, primals, delta=delta) + testcase.assertAllClose(num_jac, sym_jac_back, rtol=rtol, atol=atol) + sym_jac_fwd = _jacfwd(f, primals) + testcase.assertAllClose(num_jac, sym_jac_fwd, rtol=rtol, atol=atol) + # And the symbolic computations should be much closer. + testcase.assertAllClose(sym_jac_back, sym_jac_fwd) + + +class ForwardpropTest(tf.test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters([ + ("Dense", [[0.1]], functools.partial(tf.keras.layers.Dense, 5)), + ("Conv2D", + np.reshape( + np.arange(start=-1., stop=1., step=2. / (1 * 2 * 4 * 4)), + [1, 2, 4, 4]), functools.partial(tf.keras.layers.Conv2D, 2, 2), 1e-3) + ]) + def testKerasLayers(self, value, op_fn, atol=1e-6): + layer = op_fn() + input_value = tf.constant(value, dtype=tf.float32) + layer.build(input_value.shape) + # Make sure the test is deterministic by avoiding random variable + # initialization. + for v in layer.trainable_variables: + v.assign( + tf.reshape( + tf.range( + -1., + 1., + 2. / tf.size(v, out_type=tf.float32), + dtype=tf.float32), v.shape)) + _test_gradients( + self, layer, [input_value], atol=atol, + # These are linear, so second-order is pretty boring. + order=2) + + @parameterized.named_parameters([ + ("NonFused", [[0.1], [0.2], [-0.3]], + functools.partial(tf.keras.layers.BatchNormalization, fused=False)), + ("Fused", [[[[0.1, 2.]]], [[[0.2, -3.]]], [[[-0.3, 4.]]]], + functools.partial(tf.keras.layers.BatchNormalization, fused=True)) + ]) + def testBatchNorm(self, value, op_fn): + for training in [True, False]: + layer = op_fn() + input_value = tf.constant(value, dtype=tf.float32) + layer.build(input_value.shape) + _test_gradients( + self, functools.partial(layer, training=training), [input_value], + order=2, atol=1e-3) + + @parameterized.named_parameters([ + ("NonFused", [[0.1], [0.2], [-0.3]], + functools.partial(tf.keras.layers.BatchNormalization, fused=False)), + ("Fused", [[[[0.1, 2.]]], [[[0.2, -3.]]], [[[-0.3, 4.]]]], + functools.partial(tf.keras.layers.BatchNormalization, fused=True)) + ]) + def testBatchNormLayerParamGrads(self, value, op_fn): + for training in [True, False]: + layer = op_fn() + with tf.GradientTape() as tape: + input_value = tf.constant(value, dtype=tf.float32) + tape.watch(input_value) + output = layer(input_value, training=training) + jac_back = tape.jacobian( + output, [input_value] + layer.trainable_variables) + jac_forward = _jacfwd( + lambda *args: layer(args[0], training=training), # pylint:disable=cell-var-from-loop + [input_value] + layer.trainable_variables) + for backward, forward in zip(jac_back, jac_forward): + forward = tf.reshape(forward, tf.shape(backward)) + self.assertAllClose(backward, forward) + + @parameterized.named_parameters([("Function", tf.function), + ("NoFunction", lambda f: f)]) + def testVariablesHVP(self, decorator): + + if tf.test.is_built_with_rocm(): + # TODO(rocm) + # This test was recently added and has never passed on the + # ROCm platform. Remove this skip once the test is passing again + self.skipTest("NoFunction decorator test fails on the ROCm platform") + + class _Model(tf.Module): + + def __init__(self): + self._first_dense = tf.keras.layers.Dense(18) + self._conv = tf.keras.layers.Conv2D(2, 2) + self._norm = tf.keras.layers.BatchNormalization() + self._second_dense = tf.keras.layers.Dense(1) + + def __call__(self, x): + x = self._first_dense(x) + x = tf.nn.relu(x) + x = self._norm(x) + x = tf.nn.relu(self._conv(tf.reshape(x, [-1, 2, 3, 3]))) + return self._second_dense(x) + + model = _Model() + def _loss(): + input_value = tf.constant([[-0.5, 1.], [0.5, -1.]]) + target = tf.constant([[-1.], [2.]]) + return tf.math.reduce_sum((model(input_value) - target)**2.) + + @decorator + def _compute_hvps(): + with tf.GradientTape() as tape: + loss = _loss() + vector = tape.gradient(loss, model.trainable_variables) + variable_input_fn = lambda unused_variables: _loss() + forward_over_back_hvp, = _hvp( + variable_input_fn, [model.trainable_variables], [vector]) + with tf.GradientTape(persistent=True) as tape: + tape.watch(model.trainable_variables) + loss = _loss() + first_grads = tape.gradient(loss, model.trainable_variables) + back_over_back_hvp = tape.gradient( + first_grads, model.trainable_variables, output_gradients=vector) + return forward_over_back_hvp, back_over_back_hvp + self.assertAllClose(*_compute_hvps(), rtol=1e-5, atol=1e-5) + + +class HessianTests(tf.test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + [("PFor", True), + ("MapFn", False)]) + def testHessianOfVariables(self, use_pfor): + model = tf.keras.layers.Dense(1) + model.build([None, 2]) + + def _loss(*unused_args): + input_value = tf.constant([[-0.5, 1.], [0.5, -1.]]) + target = tf.constant([[-1.], [2.]]) + return tf.math.reduce_sum((model(input_value) - target)**2.) + + kernel_hess, bias_hess = _forward_over_back_hessian( + _loss, [model.kernel, model.bias], + use_pfor=use_pfor, + dtype=[tf.float32, tf.float32]) + # 3 total parameters, the whole hessian is the 3x3 concatenation + self.assertEqual([3, 2, 1], kernel_hess.shape) + self.assertEqual([3, 1], bias_hess.shape) + full_hessian = tf.concat([tf.reshape(kernel_hess, [3, 2]), bias_hess], + axis=1) + # The full Hessian should be symmetric. + self.assertAllClose(full_hessian, tf.transpose(full_hessian)) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/python/keras/integration_test/function_test.py b/tensorflow/python/keras/integration_test/function_test.py new file mode 100644 index 00000000000..b0823724927 --- /dev/null +++ b/tensorflow/python/keras/integration_test/function_test.py @@ -0,0 +1,213 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +import tensorflow as tf + + +class MiniModel(tf.keras.Model): + """Minimal model for mnist. + + Useful for testing and debugging on slow TPU simulators. + """ + + def __init__(self): + super(MiniModel, self).__init__(name='') + self.fc = tf.keras.layers.Dense(1, name='fc', kernel_initializer='ones', + bias_initializer='ones') + + def call(self, inputs, training=True): + return self.fc(inputs) + + +class DefunnedMiniModel(MiniModel): + + @tf.function + def call(self, inputs, training=True): + return super(DefunnedMiniModel, self).call(inputs, training=training) + + +class ModelWithOptimizer(tf.keras.Model): + + def __init__(self): + super(ModelWithOptimizer, self).__init__() + self.dense = tf.keras.layers.Dense(1) + self.optimizer = tf.keras.optimizers.Adam(0.01) + + @tf.function( + input_signature=(tf.TensorSpec([None, 2], tf.float32), + tf.TensorSpec([None], tf.float32))) + def call(self, x, y): + with tf.GradientTape() as tape: + loss = tf.math.reduce_mean((self.dense(x) - y) ** 2.) + trainable_variables = self.trainable_variables + gradients = tape.gradient(loss, trainable_variables) + self.optimizer.apply_gradients(zip(gradients, trainable_variables)) + return {'loss': loss} + + +class FunctionTest(tf.test.TestCase): + + def testFunctionRelaxationLosesInnerDimWithKerasLayer(self): + layer = tf.keras.layers.Dense(1) + fn = tf.function(experimental_relax_shapes=True)(layer) + + with self.captureWritesToStream(sys.stderr) as printed: + fn(tf.ones((3, 2))) + self.assertNotIn('ValueError', printed.contents()) + with self.captureWritesToStream(sys.stderr) as printed: + # Use batch size 2 to trigger a second cache miss on the shape. + fn(tf.ones((2, 2))) + self.assertNotIn('ValueError', printed.contents()) + + # Shape relaxation passes TensorShape([None, None]), which causes layer + # matmul to fail, due to incompatible dims. What would have been a graph + # build time error (layer would complain about the inner dim being 4). + with self.captureWritesToStream(sys.stderr) as printed: + with self.assertRaisesRegexp(tf.errors.InvalidArgumentError, + r'Matrix size-incompatible'): + fn(tf.ones((3, 4))) + + def testDefunKerasModelCall(self): + model = MiniModel() + model.call = tf.function(model.call) + + x = tf.ones([1, 2]) + y = model(x) # pylint:disable=not-callable + + self.assertAllEqual([[3.0]], self.evaluate(y)) + + # Break the reference cycle between the MiniModel and the defun: + # `MiniModel` --(through its `call` method)--> `Function` + # `Function` --(instancemethod on `MiniModel`)--> `MiniModel` + del model.call + + def testDecoratedMethod(self): + m = DefunnedMiniModel() + instance_call_one = m.call(tf.ones([1, 2]), training=True) + instance_call_two = m.call( + inputs=tf.ones([1, 2]), training=True) + class_call = DefunnedMiniModel.call(m, tf.ones([1, 2]), training=True) + self.assertAllEqual(instance_call_one, instance_call_two) + self.assertAllEqual(instance_call_one, class_call) + + def testDecoratedMethodUniqueFunctionPerInstance(self): + m = DefunnedMiniModel() + n = DefunnedMiniModel() + + class_method_one = DefunnedMiniModel.call + class_method_two = DefunnedMiniModel.call + + m_method_one = m.call + m_method_two = m.call + + n_method_one = n.call + n_method_two = n.call + + self.assertEqual(class_method_one, class_method_two) + self.assertEqual(m_method_one, m_method_two) + self.assertEqual(n_method_one, n_method_two) + self.assertNotEqual(m.call, n.call) + + def testDecoratedMethodGetConcreteFunction(self): + m = DefunnedMiniModel() + instance_call_one = m.call.get_concrete_function( + tf.ones([1, 2]), training=False) + instance_call_two = m.call.get_concrete_function( + inputs=tf.ones([1, 2]), training=False) + self.assertAllEqual(instance_call_one(tf.ones([1, 2])), + instance_call_two(tf.ones([1, 2]))) + + # Also make sure get_concrete_function works on the class method + DefunnedMiniModel.call.get_concrete_function( + m, tf.ones([1, 2]), training=False) + DefunnedMiniModel.call.get_concrete_function( + m, inputs=tf.ones([1, 2]), training=True) + + def testDecoratedMethodVariableCleanup(self): + m = DefunnedMiniModel() + m(tf.ones([1, 2])) # pylint:disable=not-callable + variable_refs = list({v.ref() for v in m.variables}) + self.assertLen(variable_refs, 2) + del m + + # Verifying if the variables are only referenced from variable_refs. + # We expect the reference counter to be 1, but `sys.getrefcount` reports + # one higher reference counter because a temporary is created when we call + # sys.getrefcount(). Hence check if the number returned is 2. + # https://docs.python.org/3/library/sys.html#sys.getrefcount + self.assertEqual(sys.getrefcount(variable_refs[0].deref()), 2) + self.assertEqual(sys.getrefcount(variable_refs[1].deref()), 2) + + def testStandardTrainingLoopInFunction(self): + layer = tf.keras.layers.Dense(2) + dataset = ( + tf.data.Dataset.from_tensors((tf.ones([784]), tf.ones([], tf.int32))) + .map(lambda x, y: (x, y)) + .repeat(10) + .batch(32)) + optimizer = tf.keras.optimizers.Adam() + + @tf.function + def train(): + for x, y in dataset: + with tf.GradientTape() as tape: + out = layer(x) + loss = tf.reduce_mean( + tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=out, labels=y)) + layer_variables = layer.trainable_variables + gradients = tape.gradient(loss, layer_variables) + optimizer.apply_gradients(zip(gradients, layer_variables)) + + train() + + def testEarlyStoppingTrainingLoopInFunction(self): + layer = tf.keras.layers.Dense(2) + dataset = ( + tf.data.Dataset.from_tensors((tf.ones([784]), tf.ones([], tf.int32))) + .map(lambda x, y: (x, y)) + .repeat(10) + .batch(32)) + optimizer = tf.keras.optimizers.Adam() + + @tf.function + def train(): + for x, y in dataset: + with tf.GradientTape() as tape: + out = layer(x) + loss = tf.math.reduce_mean( + tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=out, labels=y)) + layer_variables = layer.trainable_variables + gradients = tape.gradient(loss, layer_variables) + optimizer.apply_gradients(zip(gradients, layer_variables)) + if optimizer.iterations > 3: + break + + train() + + def test_optimizer(self): + x = tf.constant([[3., 4.]]) + y = tf.constant([2.]) + model = ModelWithOptimizer() + model(x, y) # pylint:disable=not-callable + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/python/keras/layers/local.py b/tensorflow/python/keras/layers/local.py index ec7392e754e..3e9c0f9c0a3 100644 --- a/tensorflow/python/keras/layers/local.py +++ b/tensorflow/python/keras/layers/local.py @@ -782,7 +782,7 @@ def local_conv_sparse_matmul(inputs, kernel, kernel_idxs, kernel_shape, output_shape): """Apply N-D convolution with un-shared weights using a single sparse matmul. - This method outputs `inputs . tf.SparseTensor(indices=kernel_idxs, + This method outputs `inputs . tf.sparse.SparseTensor(indices=kernel_idxs, values=kernel, dense_shape=kernel_shape)`, with `.` standing for matrix-multiply. It also reshapes `inputs` to 2-D and `output` to (N+2)-D. diff --git a/tensorflow/python/keras/layers/preprocessing/index_lookup.py b/tensorflow/python/keras/layers/preprocessing/index_lookup.py index bc5fab1604d..6c40d1618bc 100644 --- a/tensorflow/python/keras/layers/preprocessing/index_lookup.py +++ b/tensorflow/python/keras/layers/preprocessing/index_lookup.py @@ -49,7 +49,7 @@ _ACCUMULATOR_COUNTS_NAME = "counts" class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer): """Maps strings (or integers) from a vocabulary to integer indices. - This layer translates a set of arbitray strings or integers into an integer + This layer translates a set of arbitrary strings or integers into an integer output via a table-based lookup, with optional out-of-vocabulary handling. If desired, the user can call this layer's `adapt()` method on a data set, diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py index 628ecc332c5..5010cbc4370 100644 --- a/tensorflow/python/keras/layers/recurrent.py +++ b/tensorflow/python/keras/layers/recurrent.py @@ -20,6 +20,7 @@ from __future__ import division from __future__ import print_function import collections +import warnings import numpy as np @@ -266,6 +267,12 @@ class RNN(Layer): RNN calculation. However, most TensorFlow data is batch-major, so by default this function accepts input and emits output in batch-major form. + zero_output_for_mask: Boolean (default `False`). + Whether the output should use zeros for the masked timesteps. Note that + this field is only used when `return_sequences` is True and mask is + provided. It can useful if you want to reuse the raw output sequence of + the RNN without interference from the masked timesteps, eg, merging + bidirectional RNNs. Call arguments: inputs: Input tensor. @@ -979,6 +986,15 @@ class RNN(Layer): def _trackable_saved_model_saver(self): return layer_serialization.RNNSavedModelSaver(self) + @property + def weights(self): + if self.stateful: + warnings.warn( + 'The internal states of stateful RNN layers are not included in ' + '`layer.weights`. Please use `layer.states()` if you want to ' + 'retrieve the internal states of the layer.') + return super(RNN, self).weights + @keras_export('keras.layers.AbstractRNNCell') class AbstractRNNCell(Layer): diff --git a/tensorflow/python/keras/losses.py b/tensorflow/python/keras/losses.py index 9db3435fa50..d97f8d94d50 100644 --- a/tensorflow/python/keras/losses.py +++ b/tensorflow/python/keras/losses.py @@ -48,12 +48,14 @@ class Loss(object): * `call()`: Contains the logic for loss calculation using `y_true`, `y_pred`. Example subclass implementation: + ```python class MeanSquaredError(Loss): + def call(self, y_true, y_pred): - y_pred = ops.convert_to_tensor_v2(y_pred) - y_true = math_ops.cast(y_true, y_pred.dtype) - return K.mean(math_ops.square(y_pred - y_true), axis=-1) + y_pred = tf.convert_to_tensor_v2(y_pred) + y_true = tf.cast(y_true, y_pred.dtype) + return tf.reduce_mean(math_ops.square(y_pred - y_true), axis=-1) ``` When used with `tf.distribute.Strategy`, outside of built-in training loops @@ -259,7 +261,7 @@ class MeanSquaredError(LossFunctionWrapper): `loss = square(y_true - y_pred)` - Usage: + Standalone usage: >>> y_true = [[0., 1.], [0., 0.]] >>> y_pred = [[1., 1.], [1., 0.]] @@ -284,11 +286,10 @@ class MeanSquaredError(LossFunctionWrapper): >>> mse(y_true, y_pred).numpy() array([0.5, 0.5], dtype=float32) - Usage with the `compile` API: + Usage with the `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) - model.compile('sgd', loss=tf.keras.losses.MeanSquaredError()) + model.compile(optimizer='sgd', loss=tf.keras.losses.MeanSquaredError()) ``` """ @@ -319,7 +320,7 @@ class MeanAbsoluteError(LossFunctionWrapper): `loss = abs(y_true - y_pred)` - Usage: + Standalone usage: >>> y_true = [[0., 1.], [0., 0.]] >>> y_pred = [[1., 1.], [1., 0.]] @@ -344,11 +345,10 @@ class MeanAbsoluteError(LossFunctionWrapper): >>> mae(y_true, y_pred).numpy() array([0.5, 0.5], dtype=float32) - Usage with the `compile` API: + Usage with the `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) - model.compile('sgd', loss=tf.keras.losses.MeanAbsoluteError()) + model.compile(optimizer='sgd', loss=tf.keras.losses.MeanAbsoluteError()) ``` """ @@ -379,7 +379,7 @@ class MeanAbsolutePercentageError(LossFunctionWrapper): `loss = 100 * abs(y_true - y_pred) / y_true` - Usage: + Standalone usage: >>> y_true = [[2., 1.], [2., 3.]] >>> y_pred = [[1., 1.], [1., 0.]] @@ -404,11 +404,11 @@ class MeanAbsolutePercentageError(LossFunctionWrapper): >>> mape(y_true, y_pred).numpy() array([25., 75.], dtype=float32) - Usage with the `compile` API: + Usage with the `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) - model.compile('sgd', loss=tf.keras.losses.MeanAbsolutePercentageError()) + model.compile(optimizer='sgd', + loss=tf.keras.losses.MeanAbsolutePercentageError()) ``` """ @@ -440,7 +440,7 @@ class MeanSquaredLogarithmicError(LossFunctionWrapper): `loss = square(log(y_true + 1.) - log(y_pred + 1.))` - Usage: + Standalone usage: >>> y_true = [[0., 1.], [0., 0.]] >>> y_pred = [[1., 1.], [1., 0.]] @@ -465,11 +465,11 @@ class MeanSquaredLogarithmicError(LossFunctionWrapper): >>> msle(y_true, y_pred).numpy() array([0.240, 0.240], dtype=float32) - Usage with the `compile` API: + Usage with the `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) - model.compile('sgd', loss=tf.keras.losses.MeanSquaredLogarithmicError()) + model.compile(optimizer='sgd', + loss=tf.keras.losses.MeanSquaredLogarithmicError()) ``` """ @@ -507,7 +507,7 @@ class BinaryCrossentropy(LossFunctionWrapper): floating-pointing value, and both `y_pred` and `y_true` have the shape `[batch_size]`. - Usage: + Standalone usage: >>> y_true = [[0., 1.], [0., 0.]] >>> y_pred = [[0.6, 0.4], [0.4, 0.6]] @@ -535,8 +535,7 @@ class BinaryCrossentropy(LossFunctionWrapper): Usage with the `tf.keras` API: ```python - model = tf.keras.Model(inputs, outputs) - model.compile('sgd', loss=tf.keras.losses.BinaryCrossentropy()) + model.compile(optimizer='sgd', loss=tf.keras.losses.BinaryCrossentropy()) ``` """ @@ -589,7 +588,7 @@ class CategoricalCrossentropy(LossFunctionWrapper): example. The shape of both `y_pred` and `y_true` are `[batch_size, num_classes]`. - Usage: + Standalone usage: >>> y_true = [[0, 1, 0], [0, 0, 1]] >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]] @@ -614,11 +613,10 @@ class CategoricalCrossentropy(LossFunctionWrapper): >>> cce(y_true, y_pred).numpy() array([0.0513, 2.303], dtype=float32) - Usage with the `compile` API: + Usage with the `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) - model.compile('sgd', loss=tf.keras.losses.CategoricalCrossentropy()) + model.compile(optimizer='sgd', loss=tf.keras.losses.CategoricalCrossentropy()) ``` """ @@ -671,7 +669,7 @@ class SparseCategoricalCrossentropy(LossFunctionWrapper): The shape of `y_true` is `[batch_size]` and the shape of `y_pred` is `[batch_size, num_classes]`. - Usage: + Standalone usage: >>> y_true = [1, 2] >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]] @@ -696,11 +694,11 @@ class SparseCategoricalCrossentropy(LossFunctionWrapper): >>> scce(y_true, y_pred).numpy() array([0.0513, 2.303], dtype=float32) - Usage with the `compile` API: + Usage with the `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) - model.compile('sgd', loss=tf.keras.losses.SparseCategoricalCrossentropy()) + model.compile(optimizer='sgd', + loss=tf.keras.losses.SparseCategoricalCrossentropy()) ``` """ @@ -742,7 +740,7 @@ class Hinge(LossFunctionWrapper): `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are provided we will convert them to -1 or 1. - Usage: + Standalone usage: >>> y_true = [[0., 1.], [0., 0.]] >>> y_pred = [[0.6, 0.4], [0.4, 0.6]] @@ -767,11 +765,10 @@ class Hinge(LossFunctionWrapper): >>> h(y_true, y_pred).numpy() array([1.1, 1.5], dtype=float32) - Usage with the `compile` API: + Usage with the `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) - model.compile('sgd', loss=tf.keras.losses.Hinge()) + model.compile(optimizer='sgd', loss=tf.keras.losses.Hinge()) ``` """ @@ -802,7 +799,7 @@ class SquaredHinge(LossFunctionWrapper): `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are provided we will convert them to -1 or 1. - Usage: + Standalone usage: >>> y_true = [[0., 1.], [0., 0.]] >>> y_pred = [[0.6, 0.4], [0.4, 0.6]] @@ -827,11 +824,10 @@ class SquaredHinge(LossFunctionWrapper): >>> h(y_true, y_pred).numpy() array([1.46, 2.26], dtype=float32) - Usage with the `compile` API: + Usage with the `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) - model.compile('sgd', loss=tf.keras.losses.SquaredHinge()) + model.compile(optimizer='sgd', loss=tf.keras.losses.SquaredHinge()) ``` """ @@ -863,7 +859,7 @@ class CategoricalHinge(LossFunctionWrapper): `loss = maximum(neg - pos + 1, 0)` where `neg=maximum((1-y_true)*y_pred) and pos=sum(y_true*y_pred)` - Usage: + Standalone usage: >>> y_true = [[0, 1], [0, 0]] >>> y_pred = [[0.6, 0.4], [0.4, 0.6]] @@ -888,11 +884,10 @@ class CategoricalHinge(LossFunctionWrapper): >>> h(y_true, y_pred).numpy() array([1.2, 1.6], dtype=float32) - Usage with the `compile` API: + Usage with the `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) - model.compile('sgd', loss=tf.keras.losses.CategoricalHinge()) + model.compile(optimizer='sgd', loss=tf.keras.losses.CategoricalHinge()) ``` """ @@ -923,7 +918,7 @@ class Poisson(LossFunctionWrapper): `loss = y_pred - y_true * log(y_pred)` - Usage: + Standalone usage: >>> y_true = [[0., 1.], [0., 0.]] >>> y_pred = [[1., 1.], [0., 0.]] @@ -948,11 +943,10 @@ class Poisson(LossFunctionWrapper): >>> p(y_true, y_pred).numpy() array([0.999, 0.], dtype=float32) - Usage with the `compile` API: + Usage with the `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) - model.compile('sgd', loss=tf.keras.losses.Poisson()) + model.compile(optimizer='sgd', loss=tf.keras.losses.Poisson()) ``` """ @@ -981,7 +975,7 @@ class LogCosh(LossFunctionWrapper): `logcosh = log((exp(x) + exp(-x))/2)`, where x is the error `y_pred - y_true`. - Usage: + Standalone usage: >>> y_true = [[0., 1.], [0., 0.]] >>> y_pred = [[1., 1.], [0., 0.]] @@ -1006,11 +1000,10 @@ class LogCosh(LossFunctionWrapper): >>> l(y_true, y_pred).numpy() array([0.217, 0.], dtype=float32) - Usage with the `compile` API: + Usage with the `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) - model.compile('sgd', loss=tf.keras.losses.LogCosh()) + model.compile(optimizer='sgd', loss=tf.keras.losses.LogCosh()) ``` """ @@ -1040,7 +1033,7 @@ class KLDivergence(LossFunctionWrapper): See: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence - Usage: + Standalone usage: >>> y_true = [[0, 1], [0, 0]] >>> y_pred = [[0.6, 0.4], [0.4, 0.6]] @@ -1065,11 +1058,10 @@ class KLDivergence(LossFunctionWrapper): >>> kl(y_true, y_pred).numpy() array([0.916, -3.08e-06], dtype=float32) - Usage with the `compile` API: + Usage with the `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) - model.compile('sgd', loss=tf.keras.losses.KLDivergence()) + model.compile(optimizer='sgd', loss=tf.keras.losses.KLDivergence()) ``` """ @@ -1106,7 +1098,7 @@ class Huber(LossFunctionWrapper): ``` where d is `delta`. See: https://en.wikipedia.org/wiki/Huber_loss - Usage: + Standalone usage: >>> y_true = [[0, 1], [0, 0]] >>> y_pred = [[0.6, 0.4], [0.4, 0.6]] @@ -1131,11 +1123,10 @@ class Huber(LossFunctionWrapper): >>> h(y_true, y_pred).numpy() array([0.18, 0.13], dtype=float32) - Usage with the `compile` API: + Usage with the `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) - model.compile('sgd', loss=tf.keras.losses.Huber()) + model.compile(optimizer='sgd', loss=tf.keras.losses.Huber()) ``` """ @@ -1177,7 +1168,7 @@ def mean_squared_error(y_true, y_pred): `loss = mean(square(y_true - y_pred), axis=-1)` - Usage: + Standalone usage: >>> y_true = np.random.randint(0, 2, size=(2, 3)) >>> y_pred = np.random.random(size=(2, 3)) @@ -1209,7 +1200,7 @@ def mean_absolute_error(y_true, y_pred): `loss = mean(abs(y_true - y_pred), axis=-1)` - Usage: + Standalone usage: >>> y_true = np.random.randint(0, 2, size=(2, 3)) >>> y_pred = np.random.random(size=(2, 3)) @@ -1241,7 +1232,7 @@ def mean_absolute_percentage_error(y_true, y_pred): `loss = 100 * mean(abs(y_true - y_pred) / y_true, axis=-1)` - Usage: + Standalone usage: >>> y_true = np.random.random(size=(2, 3)) >>> y_true = np.maximum(y_true, 1e-7) # Prevent division by zero @@ -1277,7 +1268,7 @@ def mean_squared_logarithmic_error(y_true, y_pred): `loss = mean(square(log(y_true + 1) - log(y_pred + 1)), axis=-1)` - Usage: + Standalone usage: >>> y_true = np.random.randint(0, 2, size=(2, 3)) >>> y_pred = np.random.random(size=(2, 3)) @@ -1325,7 +1316,7 @@ def squared_hinge(y_true, y_pred): `loss = mean(square(maximum(1 - y_true * y_pred, 0)), axis=-1)` - Usage: + Standalone usage: >>> y_true = np.random.choice([-1, 1], size=(2, 3)) >>> y_pred = np.random.random(size=(2, 3)) @@ -1357,7 +1348,7 @@ def hinge(y_true, y_pred): `loss = mean(maximum(1 - y_true * y_pred, 0), axis=-1)` - Usage: + Standalone usage: >>> y_true = np.random.choice([-1, 1], size=(2, 3)) >>> y_pred = np.random.random(size=(2, 3)) @@ -1389,7 +1380,7 @@ def categorical_hinge(y_true, y_pred): `loss = maximum(neg - pos + 1, 0)` where `neg=maximum((1-y_true)*y_pred) and pos=sum(y_true*y_pred)` - Usage: + Standalone usage: >>> y_true = np.random.randint(0, 3, size=(2,)) >>> y_true = tf.keras.utils.to_categorical(y_true, num_classes=3) @@ -1459,7 +1450,7 @@ def log_cosh(y_true, y_pred): like the mean squared error, but will not be so strongly affected by the occasional wildly incorrect prediction. - Usage: + Standalone usage: >>> y_true = np.random.random(size=(2, 3)) >>> y_pred = np.random.random(size=(2, 3)) @@ -1495,7 +1486,7 @@ def categorical_crossentropy(y_true, label_smoothing=0): """Computes the categorical crossentropy loss. - Usage: + Standalone usage: >>> y_true = [[0, 1, 0], [0, 0, 1]] >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]] @@ -1532,7 +1523,7 @@ def categorical_crossentropy(y_true, def sparse_categorical_crossentropy(y_true, y_pred, from_logits=False, axis=-1): """Computes the sparse categorical crossentropy loss. - Usage: + Standalone usage: >>> y_true = [1, 2] >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]] @@ -1563,7 +1554,7 @@ def sparse_categorical_crossentropy(y_true, y_pred, from_logits=False, axis=-1): def binary_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0): """Computes the binary crossentropy loss. - Usage: + Standalone usage: >>> y_true = [[0, 1], [0, 0]] >>> y_pred = [[0.6, 0.4], [0.4, 0.6]] @@ -1610,7 +1601,7 @@ def kl_divergence(y_true, y_pred): See: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence - Usage: + Standalone usage: >>> y_true = np.random.randint(0, 2, size=(2, 3)).astype(np.float64) >>> y_pred = np.random.random(size=(2, 3)) @@ -1645,7 +1636,7 @@ def poisson(y_true, y_pred): The Poisson loss is the mean of the elements of the `Tensor` `y_pred - y_true * log(y_pred)`. - Usage: + Standalone usage: >>> y_true = np.random.randint(0, 2, size=(2, 3)) >>> y_pred = np.random.random(size=(2, 3)) @@ -1683,27 +1674,24 @@ def poisson(y_true, y_pred): def cosine_similarity(y_true, y_pred, axis=-1): """Computes the cosine similarity between labels and predictions. - Note that it is a negative quantity between -1 and 0, where 0 indicates - orthogonality and values closer to -1 indicate greater similarity. This makes - it usable as a loss function in a setting where you try to maximize the - proximity between predictions and targets. If either `y_true` or `y_pred` - is a zero vector, cosine similarity will be 0 regardless of the proximity - between predictions and targets. + Note that it is a number between -1 and 1. When it is a negative number + between -1 and 0, 0 indicates orthogonality and values closer to -1 + indicate greater similarity. The values closer to 1 indicate greater + dissimilarity. This makes it usable as a loss function in a setting + where you try to maximize the proximity between predictions and + targets. If either `y_true` or `y_pred` is a zero vector, cosine + similarity will be 0 regardless of the proximity between predictions + and targets. `loss = -sum(l2_norm(y_true) * l2_norm(y_pred))` - Usage: + Standalone usage: - >>> y_true = [[0., 1.], [1., 1.]] - >>> y_pred =[[1., 0.], [1., 1.]] + >>> y_true = [[0., 1.], [1., 1.], [1., 1.]] + >>> y_pred = [[1., 0.], [1., 1.], [-1., -1.]] >>> loss = tf.keras.losses.cosine_similarity(y_true, y_pred, axis=1) - >>> # l2_norm(y_true) = [[0., 1.], [1./1.414], 1./1.414]]] - >>> # l2_norm(y_pred) = [[1., 0.], [1./1.414], 1./1.414]]] - >>> # l2_norm(y_true) . l2_norm(y_pred) = [[0., 0.], [0.5, 0.5]] - >>> # loss = -sum(l2_norm(y_true) . l2_norm(y_pred), axis=1) - >>> # = -[0. + 0., 0.5 + 0.5] >>> loss.numpy() - array([-0., -0.999], dtype=float32) + array([-0., -0.999, 0.999], dtype=float32) Args: y_true: Tensor of true targets. @@ -1731,7 +1719,7 @@ class CosineSimilarity(LossFunctionWrapper): `loss = -sum(l2_norm(y_true) * l2_norm(y_pred))` - Usage: + Standalone usage: >>> y_true = [[0., 1.], [1., 1.]] >>> y_pred = [[1., 0.], [1., 1.]] @@ -1761,11 +1749,10 @@ class CosineSimilarity(LossFunctionWrapper): >>> cosine_loss(y_true, y_pred).numpy() array([-0., -0.999], dtype=float32) - Usage with the `compile` API: + Usage with the `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) - model.compile('sgd', loss=tf.keras.losses.CosineSimilarity(axis=1)) + model.compile(optimizer='sgd', loss=tf.keras.losses.CosineSimilarity(axis=1)) ``` Args: diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py index 5cbd59c49cf..4323c52209f 100644 --- a/tensorflow/python/keras/metrics.py +++ b/tensorflow/python/keras/metrics.py @@ -83,7 +83,7 @@ class Metric(base_layer.Layer): dtype: (Optional) data type of the metric result. **kwargs: Additional layer keywords arguments. - Usage: + Standalone usage: ```python m = SomeMetric(...) @@ -92,7 +92,7 @@ class Metric(base_layer.Layer): print('Final result: ', m.result().numpy()) ``` - Usage with tf.keras API: + Usage with `compile()` API: ```python model = tf.keras.Sequential() @@ -404,19 +404,18 @@ class Sum(Reduce): name: (Optional) string name of the metric instance. dtype: (Optional) data type of the metric result. - Usage: + Standalone usage: >>> m = tf.keras.metrics.Sum() - >>> _ = m.update_state([1, 3, 5, 7]) + >>> m.update_state([1, 3, 5, 7]) >>> m.result().numpy() 16.0 - Usage with tf.keras API: + Usage with `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) model.add_metric(tf.keras.metrics.Sum(name='sum_1')(outputs)) - model.compile('sgd', loss='mse') + model.compile(optimizer='sgd', loss='mse') ``` """ @@ -443,23 +442,22 @@ class Mean(Reduce): name: (Optional) string name of the metric instance. dtype: (Optional) data type of the metric result. - Usage: + Standalone usage: >>> m = tf.keras.metrics.Mean() - >>> _ = m.update_state([1, 3, 5, 7]) + >>> m.update_state([1, 3, 5, 7]) >>> m.result().numpy() 4.0 >>> m.reset_states() - >>> _ = m.update_state([1, 3, 5, 7], sample_weight=[1, 1, 0, 0]) + >>> m.update_state([1, 3, 5, 7], sample_weight=[1, 1, 0, 0]) >>> m.result().numpy() 2.0 - Usage with tf.keras API: + Usage with `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) model.add_metric(tf.keras.metrics.Mean(name='mean_1')(outputs)) - model.compile('sgd', loss='mse') + model.compile(optimizer='sgd', loss='mse') ``` """ @@ -485,10 +483,10 @@ class MeanRelativeError(Mean): name: (Optional) string name of the metric instance. dtype: (Optional) data type of the metric result. - Usage: + Standalone usage: >>> m = tf.keras.metrics.MeanRelativeError(normalizer=[1, 3, 2, 3]) - >>> _ = m.update_state([1, 3, 2, 3], [2, 4, 6, 8]) + >>> m.update_state([1, 3, 2, 3], [2, 4, 6, 8]) >>> # metric = mean(|y_pred - y_true| / normalizer) >>> # = mean([1, 1, 4, 5] / [1, 3, 2, 3]) = mean([1, 1/3, 2, 5/3]) @@ -496,12 +494,11 @@ class MeanRelativeError(Mean): >>> m.result().numpy() 1.25 - Usage with tf.keras API: + Usage with `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) model.compile( - 'sgd', + optimizer='sgd', loss='mse', metrics=[tf.keras.metrics.MeanRelativeError(normalizer=[1, 3])]) ``` @@ -638,24 +635,25 @@ class Accuracy(MeanMetricWrapper): name: (Optional) string name of the metric instance. dtype: (Optional) data type of the metric result. - Usage: + Standalone usage: >>> m = tf.keras.metrics.Accuracy() - >>> _ = m.update_state([[1], [2], [3], [4]], [[0], [2], [3], [4]]) + >>> m.update_state([[1], [2], [3], [4]], [[0], [2], [3], [4]]) >>> m.result().numpy() 0.75 >>> m.reset_states() - >>> _ = m.update_state([[1], [2], [3], [4]], [[0], [2], [3], [4]], - ... sample_weight=[1, 1, 0, 0]) + >>> m.update_state([[1], [2], [3], [4]], [[0], [2], [3], [4]], + ... sample_weight=[1, 1, 0, 0]) >>> m.result().numpy() 0.5 - Usage with tf.keras API: + Usage with `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) - model.compile('sgd', loss='mse', metrics=[tf.keras.metrics.Accuracy()]) + model.compile(optimizer='sgd', + loss='mse', + metrics=[tf.keras.metrics.Accuracy()]) ``` """ @@ -681,24 +679,25 @@ class BinaryAccuracy(MeanMetricWrapper): threshold: (Optional) Float representing the threshold for deciding whether prediction values are 1 or 0. - Usage: + Standalone usage: >>> m = tf.keras.metrics.BinaryAccuracy() - >>> _ = m.update_state([[1], [1], [0], [0]], [[0.98], [1], [0], [0.6]]) + >>> m.update_state([[1], [1], [0], [0]], [[0.98], [1], [0], [0.6]]) >>> m.result().numpy() 0.75 >>> m.reset_states() - >>> _ = m.update_state([[1], [1], [0], [0]], [[0.98], [1], [0], [0.6]], - ... sample_weight=[1, 0, 0, 1]) + >>> m.update_state([[1], [1], [0], [0]], [[0.98], [1], [0], [0.6]], + ... sample_weight=[1, 0, 0, 1]) >>> m.result().numpy() 0.5 - Usage with tf.keras API: + Usage with `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) - model.compile('sgd', loss='mse', metrics=[tf.keras.metrics.BinaryAccuracy()]) + model.compile(optimizer='sgd', + loss='mse', + metrics=[tf.keras.metrics.BinaryAccuracy()]) ``` """ @@ -729,27 +728,26 @@ class CategoricalAccuracy(MeanMetricWrapper): name: (Optional) string name of the metric instance. dtype: (Optional) data type of the metric result. - Usage: + Standalone usage: >>> m = tf.keras.metrics.CategoricalAccuracy() - >>> _ = m.update_state([[0, 0, 1], [0, 1, 0]], [[0.1, 0.9, 0.8], - ... [0.05, 0.95, 0]]) + >>> m.update_state([[0, 0, 1], [0, 1, 0]], [[0.1, 0.9, 0.8], + ... [0.05, 0.95, 0]]) >>> m.result().numpy() 0.5 >>> m.reset_states() - >>> _ = m.update_state([[0, 0, 1], [0, 1, 0]], [[0.1, 0.9, 0.8], - ... [0.05, 0.95, 0]], - ... sample_weight=[0.7, 0.3]) + >>> m.update_state([[0, 0, 1], [0, 1, 0]], [[0.1, 0.9, 0.8], + ... [0.05, 0.95, 0]], + ... sample_weight=[0.7, 0.3]) >>> m.result().numpy() 0.3 - Usage with tf.keras API: + Usage with `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) model.compile( - 'sgd', + optimizer='sgd', loss='mse', metrics=[tf.keras.metrics.CategoricalAccuracy()]) ``` @@ -783,25 +781,24 @@ class SparseCategoricalAccuracy(MeanMetricWrapper): name: (Optional) string name of the metric instance. dtype: (Optional) data type of the metric result. - Usage: + Standalone usage: >>> m = tf.keras.metrics.SparseCategoricalAccuracy() - >>> _ = m.update_state([[2], [1]], [[0.1, 0.6, 0.3], [0.05, 0.95, 0]]) + >>> m.update_state([[2], [1]], [[0.1, 0.6, 0.3], [0.05, 0.95, 0]]) >>> m.result().numpy() 0.5 >>> m.reset_states() - >>> _ = m.update_state([[2], [1]], [[0.1, 0.6, 0.3], [0.05, 0.95, 0]], - ... sample_weight=[0.7, 0.3]) + >>> m.update_state([[2], [1]], [[0.1, 0.6, 0.3], [0.05, 0.95, 0]], + ... sample_weight=[0.7, 0.3]) >>> m.result().numpy() 0.3 - Usage with tf.keras API: + Usage with `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) model.compile( - 'sgd', + optimizer='sgd', loss='mse', metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]) ``` @@ -822,26 +819,27 @@ class TopKCategoricalAccuracy(MeanMetricWrapper): name: (Optional) string name of the metric instance. dtype: (Optional) data type of the metric result. - Usage: + Standalone usage: >>> m = tf.keras.metrics.TopKCategoricalAccuracy(k=1) - >>> _ = m.update_state([[0, 0, 1], [0, 1, 0]], - ... [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]) + >>> m.update_state([[0, 0, 1], [0, 1, 0]], + ... [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]) >>> m.result().numpy() 0.5 >>> m.reset_states() - >>> _ = m.update_state([[0, 0, 1], [0, 1, 0]], - ... [[0.1, 0.9, 0.8], [0.05, 0.95, 0]], - ... sample_weight=[0.7, 0.3]) + >>> m.update_state([[0, 0, 1], [0, 1, 0]], + ... [[0.1, 0.9, 0.8], [0.05, 0.95, 0]], + ... sample_weight=[0.7, 0.3]) >>> m.result().numpy() 0.3 - Usage with tf.keras API: + Usage with `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) - model.compile('sgd', metrics=[tf.keras.metrics.TopKCategoricalAccuracy()]) + model.compile(optimizer='sgd', + loss='mse', + metrics=[tf.keras.metrics.TopKCategoricalAccuracy()]) ``` """ @@ -860,25 +858,25 @@ class SparseTopKCategoricalAccuracy(MeanMetricWrapper): name: (Optional) string name of the metric instance. dtype: (Optional) data type of the metric result. - Usage: + Standalone usage: >>> m = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1) - >>> _ = m.update_state([2, 1], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]) + >>> m.update_state([2, 1], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]) >>> m.result().numpy() 0.5 >>> m.reset_states() - >>> _ = m.update_state([2, 1], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]], - ... sample_weight=[0.7, 0.3]) + >>> m.update_state([2, 1], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]], + ... sample_weight=[0.7, 0.3]) >>> m.result().numpy() 0.3 - Usage with tf.keras API: + Usage with `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) model.compile( - 'sgd', + optimizer='sgd', + loss='mse', metrics=[tf.keras.metrics.SparseTopKCategoricalAccuracy()]) ``` """ @@ -975,23 +973,24 @@ class FalsePositives(_ConfusionMatrixConditionCount): name: (Optional) string name of the metric instance. dtype: (Optional) data type of the metric result. - Usage: + Standalone usage: >>> m = tf.keras.metrics.FalsePositives() - >>> _ = m.update_state([0, 1, 0, 0], [0, 0, 1, 1]) + >>> m.update_state([0, 1, 0, 0], [0, 0, 1, 1]) >>> m.result().numpy() 2.0 >>> m.reset_states() - >>> _ = m.update_state([0, 1, 0, 0], [0, 0, 1, 1], sample_weight=[0, 0, 1, 0]) + >>> m.update_state([0, 1, 0, 0], [0, 0, 1, 1], sample_weight=[0, 0, 1, 0]) >>> m.result().numpy() 1.0 - Usage with tf.keras API: + Usage with `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) - model.compile('sgd', loss='mse', metrics=[tf.keras.metrics.FalsePositives()]) + model.compile(optimizer='sgd', + loss='mse', + metrics=[tf.keras.metrics.FalsePositives()]) ``` """ @@ -1023,23 +1022,24 @@ class FalseNegatives(_ConfusionMatrixConditionCount): name: (Optional) string name of the metric instance. dtype: (Optional) data type of the metric result. - Usage: + Standalone usage: >>> m = tf.keras.metrics.FalseNegatives() - >>> _ = m.update_state([0, 1, 1, 1], [0, 1, 0, 0]) + >>> m.update_state([0, 1, 1, 1], [0, 1, 0, 0]) >>> m.result().numpy() 2.0 >>> m.reset_states() - >>> _ = m.update_state([0, 1, 1, 1], [0, 1, 0, 0], sample_weight=[0, 0, 1, 0]) + >>> m.update_state([0, 1, 1, 1], [0, 1, 0, 0], sample_weight=[0, 0, 1, 0]) >>> m.result().numpy() 1.0 - Usage with tf.keras API: + Usage with `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) - model.compile('sgd', loss='mse', metrics=[tf.keras.metrics.FalseNegatives()]) + model.compile(optimizer='sgd', + loss='mse', + metrics=[tf.keras.metrics.FalseNegatives()]) ``` """ @@ -1071,23 +1071,24 @@ class TrueNegatives(_ConfusionMatrixConditionCount): name: (Optional) string name of the metric instance. dtype: (Optional) data type of the metric result. - Usage: + Standalone usage: >>> m = tf.keras.metrics.TrueNegatives() - >>> _ = m.update_state([0, 1, 0, 0], [1, 1, 0, 0]) + >>> m.update_state([0, 1, 0, 0], [1, 1, 0, 0]) >>> m.result().numpy() 2.0 >>> m.reset_states() - >>> _ = m.update_state([0, 1, 0, 0], [1, 1, 0, 0], sample_weight=[0, 0, 1, 0]) + >>> m.update_state([0, 1, 0, 0], [1, 1, 0, 0], sample_weight=[0, 0, 1, 0]) >>> m.result().numpy() 1.0 - Usage with tf.keras API: + Usage with `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) - model.compile('sgd', loss='mse', metrics=[tf.keras.metrics.TrueNegatives()]) + model.compile(optimizer='sgd', + loss='mse', + metrics=[tf.keras.metrics.TrueNegatives()]) ``` """ @@ -1119,23 +1120,24 @@ class TruePositives(_ConfusionMatrixConditionCount): name: (Optional) string name of the metric instance. dtype: (Optional) data type of the metric result. - Usage: + Standalone usage: >>> m = tf.keras.metrics.TruePositives() - >>> _ = m.update_state([0, 1, 1, 1], [1, 0, 1, 1]) + >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1]) >>> m.result().numpy() 2.0 >>> m.reset_states() - >>> _ = m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0]) + >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0]) >>> m.result().numpy() 1.0 - Usage with tf.keras API: + Usage with `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) - model.compile('sgd', loss='mse', metrics=[tf.keras.metrics.TruePositives()]) + model.compile(optimizer='sgd', + loss='mse', + metrics=[tf.keras.metrics.TruePositives()]) ``` """ @@ -1183,35 +1185,36 @@ class Precision(Metric): name: (Optional) string name of the metric instance. dtype: (Optional) data type of the metric result. - Usage: + Standalone usage: >>> m = tf.keras.metrics.Precision() - >>> _ = m.update_state([0, 1, 1, 1], [1, 0, 1, 1]) + >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1]) >>> m.result().numpy() 0.6666667 >>> m.reset_states() - >>> _ = m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0]) + >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0]) >>> m.result().numpy() 1.0 >>> # With top_k=2, it will calculate precision over y_true[:2] and y_pred[:2] >>> m = tf.keras.metrics.Precision(top_k=2) - >>> _ = m.update_state([0, 0, 1, 1], [1, 1, 1, 1]) + >>> m.update_state([0, 0, 1, 1], [1, 1, 1, 1]) >>> m.result().numpy() 0.0 >>> # With top_k=4, it will calculate precision over y_true[:4] and y_pred[:4] >>> m = tf.keras.metrics.Precision(top_k=4) - >>> _ = m.update_state([0, 0, 1, 1], [1, 1, 1, 1]) + >>> m.update_state([0, 0, 1, 1], [1, 1, 1, 1]) >>> m.result().numpy() 0.5 - Usage with tf.keras API: + Usage with `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) - model.compile('sgd', loss='mse', metrics=[tf.keras.metrics.Precision()]) + model.compile(optimizer='sgd', + loss='mse', + metrics=[tf.keras.metrics.Precision()]) ``` """ @@ -1319,23 +1322,24 @@ class Recall(Metric): name: (Optional) string name of the metric instance. dtype: (Optional) data type of the metric result. - Usage: + Standalone usage: >>> m = tf.keras.metrics.Recall() - >>> _ = m.update_state([0, 1, 1, 1], [1, 0, 1, 1]) + >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1]) >>> m.result().numpy() 0.6666667 >>> m.reset_states() - >>> _ = m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0]) + >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0]) >>> m.result().numpy() 1.0 - Usage with tf.keras API: + Usage with `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) - model.compile('sgd', loss='mse', metrics=[tf.keras.metrics.Recall()]) + model.compile(optimizer='sgd', + loss='mse', + metrics=[tf.keras.metrics.Recall()]) ``` """ @@ -1529,25 +1533,24 @@ class SensitivityAtSpecificity(SensitivitySpecificityBase): name: (Optional) string name of the metric instance. dtype: (Optional) data type of the metric result. - Usage: + Standalone usage: >>> m = tf.keras.metrics.SensitivityAtSpecificity(0.5) - >>> _ = m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8]) + >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8]) >>> m.result().numpy() 0.5 >>> m.reset_states() - >>> _ = m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8], - ... sample_weight=[1, 1, 2, 2, 1]) + >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8], + ... sample_weight=[1, 1, 2, 2, 1]) >>> m.result().numpy() 0.333333 - Usage with tf.keras API: + Usage with `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) model.compile( - 'sgd', + optimizer='sgd', loss='mse', metrics=[tf.keras.metrics.SensitivityAtSpecificity()]) ``` @@ -1605,25 +1608,24 @@ class SpecificityAtSensitivity(SensitivitySpecificityBase): name: (Optional) string name of the metric instance. dtype: (Optional) data type of the metric result. - Usage: + Standalone usage: >>> m = tf.keras.metrics.SpecificityAtSensitivity(0.5) - >>> _ = m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8]) + >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8]) >>> m.result().numpy() 0.66666667 >>> m.reset_states() - >>> _ = m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8], - ... sample_weight=[1, 1, 2, 2, 2]) + >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8], + ... sample_weight=[1, 1, 2, 2, 2]) >>> m.result().numpy() 0.5 - Usage with tf.keras API: + Usage with `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) model.compile( - 'sgd', + optimizer='sgd', loss='mse', metrics=[tf.keras.metrics.SpecificityAtSensitivity()]) ``` @@ -1673,25 +1675,24 @@ class PrecisionAtRecall(SensitivitySpecificityBase): name: (Optional) string name of the metric instance. dtype: (Optional) data type of the metric result. - Usage: + Standalone usage: >>> m = tf.keras.metrics.PrecisionAtRecall(0.5) - >>> _ = m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8]) + >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8]) >>> m.result().numpy() 0.5 >>> m.reset_states() - >>> _ = m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8], - ... sample_weight=[2, 2, 2, 1, 1]) + >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8], + ... sample_weight=[2, 2, 2, 1, 1]) >>> m.result().numpy() 0.33333333 - Usage with tf.keras API: + Usage with `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) model.compile( - 'sgd', + optimizer='sgd', loss='mse', metrics=[tf.keras.metrics.PrecisionAtRecall(recall=0.8)]) ``` @@ -1744,25 +1745,24 @@ class RecallAtPrecision(SensitivitySpecificityBase): name: (Optional) string name of the metric instance. dtype: (Optional) data type of the metric result. - Usage: + Standalone usage: >>> m = tf.keras.metrics.RecallAtPrecision(0.8) - >>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9]) + >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9]) >>> m.result().numpy() 0.5 >>> m.reset_states() - >>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9], - ... sample_weight=[1, 0, 0, 1]) + >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9], + ... sample_weight=[1, 0, 0, 1]) >>> m.result().numpy() 1.0 - Usage with tf.keras API: + Usage with `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) model.compile( - 'sgd', + optimizer='sgd', loss='mse', metrics=[tf.keras.metrics.RecallAtPrecision(precision=0.8)]) ``` @@ -1861,10 +1861,10 @@ class AUC(Metric): before flattening; therefore `label_weights` should not be used for multi-class data. - Usage: + Standalone usage: >>> m = tf.keras.metrics.AUC(num_thresholds=3) - >>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9]) + >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9]) >>> # threshold values are [0 - 1e-7, 0.5, 1 + 1e-7] >>> # tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2] >>> # recall = [1, 0.5, 0], fp_rate = [1, 0, 0] @@ -1873,16 +1873,15 @@ class AUC(Metric): 0.75 >>> m.reset_states() - >>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9], - ... sample_weight=[1, 0, 0, 1]) + >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9], + ... sample_weight=[1, 0, 0, 1]) >>> m.result().numpy() 1.0 - Usage with tf.keras API: + Usage with `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) - model.compile('sgd', loss='mse', metrics=[tf.keras.metrics.AUC()]) + model.compile(optimizer='sgd', loss='mse', metrics=[tf.keras.metrics.AUC()]) ``` """ @@ -2239,7 +2238,7 @@ class CosineSimilarity(MeanMetricWrapper): axis: (Optional) Defaults to -1. The dimension along which the cosine similarity is computed. - Usage: + Standalone usage: >>> # l2_norm(y_true) = [[0., 1.], [1./1.414], 1./1.414]]] >>> # l2_norm(y_pred) = [[1., 0.], [1./1.414], 1./1.414]]] @@ -2247,22 +2246,21 @@ class CosineSimilarity(MeanMetricWrapper): >>> # result = mean(sum(l2_norm(y_true) . l2_norm(y_pred), axis=1)) >>> # = ((0. + 0.) + (0.5 + 0.5)) / 2 >>> m = tf.keras.metrics.CosineSimilarity(axis=1) - >>> _ = m.update_state([[0., 1.], [1., 1.]], [[1., 0.], [1., 1.]]) + >>> m.update_state([[0., 1.], [1., 1.]], [[1., 0.], [1., 1.]]) >>> m.result().numpy() 0.49999997 >>> m.reset_states() - >>> _ = m.update_state([[0., 1.], [1., 1.]], [[1., 0.], [1., 1.]], - ... sample_weight=[0.3, 0.7]) + >>> m.update_state([[0., 1.], [1., 1.]], [[1., 0.], [1., 1.]], + ... sample_weight=[0.3, 0.7]) >>> m.result().numpy() 0.6999999 - Usage with tf.keras API: + Usage with `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) model.compile( - 'sgd', + optimizer='sgd', loss='mse', metrics=[tf.keras.metrics.CosineSimilarity(axis=1)]) ``` @@ -2281,25 +2279,26 @@ class MeanAbsoluteError(MeanMetricWrapper): name: (Optional) string name of the metric instance. dtype: (Optional) data type of the metric result. - Usage: + Standalone usage: >>> m = tf.keras.metrics.MeanAbsoluteError() - >>> _ = m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) + >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) >>> m.result().numpy() 0.25 >>> m.reset_states() - >>> _ = m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], - ... sample_weight=[1, 0]) + >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], + ... sample_weight=[1, 0]) >>> m.result().numpy() 0.5 - Usage with tf.keras API: + Usage with `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) model.compile( - 'sgd', loss='mse', metrics=[tf.keras.metrics.MeanAbsoluteError()]) + optimizer='sgd', + loss='mse', + metrics=[tf.keras.metrics.MeanAbsoluteError()]) ``` """ @@ -2316,25 +2315,24 @@ class MeanAbsolutePercentageError(MeanMetricWrapper): name: (Optional) string name of the metric instance. dtype: (Optional) data type of the metric result. - Usage: + Standalone usage: >>> m = tf.keras.metrics.MeanAbsolutePercentageError() - >>> _ = m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) + >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) >>> m.result().numpy() 250000000.0 >>> m.reset_states() - >>> _ = m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], - ... sample_weight=[1, 0]) + >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], + ... sample_weight=[1, 0]) >>> m.result().numpy() 500000000.0 - Usage with tf.keras API: + Usage with `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) model.compile( - 'sgd', + optimizer='sgd', loss='mse', metrics=[tf.keras.metrics.MeanAbsolutePercentageError()]) ``` @@ -2353,25 +2351,26 @@ class MeanSquaredError(MeanMetricWrapper): name: (Optional) string name of the metric instance. dtype: (Optional) data type of the metric result. - Usage: + Standalone usage: >>> m = tf.keras.metrics.MeanSquaredError() - >>> _ = m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) + >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) >>> m.result().numpy() 0.25 >>> m.reset_states() - >>> _ = m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], - ... sample_weight=[1, 0]) + >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], + ... sample_weight=[1, 0]) >>> m.result().numpy() 0.5 - Usage with tf.keras API: + Usage with `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) model.compile( - 'sgd', loss='mse', metrics=[tf.keras.metrics.MeanSquaredError()]) + optimizer='sgd', + loss='mse', + metrics=[tf.keras.metrics.MeanSquaredError()]) ``` """ @@ -2388,25 +2387,24 @@ class MeanSquaredLogarithmicError(MeanMetricWrapper): name: (Optional) string name of the metric instance. dtype: (Optional) data type of the metric result. - Usage: + Standalone usage: >>> m = tf.keras.metrics.MeanSquaredLogarithmicError() - >>> _ = m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) + >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) >>> m.result().numpy() 0.12011322 >>> m.reset_states() - >>> _ = m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], - ... sample_weight=[1, 0]) + >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], + ... sample_weight=[1, 0]) >>> m.result().numpy() 0.24022643 - Usage with tf.keras API: + Usage with `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) model.compile( - 'sgd', + optimizer='sgd', loss='mse', metrics=[tf.keras.metrics.MeanSquaredLogarithmicError()]) ``` @@ -2428,24 +2426,23 @@ class Hinge(MeanMetricWrapper): name: (Optional) string name of the metric instance. dtype: (Optional) data type of the metric result. - Usage: + Standalone usage: >>> m = tf.keras.metrics.Hinge() - >>> _ = m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]]) + >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]]) >>> m.result().numpy() 1.3 >>> m.reset_states() - >>> _ = m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]], - ... sample_weight=[1, 0]) + >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]], + ... sample_weight=[1, 0]) >>> m.result().numpy() 1.1 - Usage with tf.keras API: + Usage with `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) - model.compile('sgd', loss='mse', metrics=[tf.keras.metrics.Hinge()]) + model.compile(optimizer='sgd', loss='mse', metrics=[tf.keras.metrics.Hinge()]) ``` """ @@ -2464,25 +2461,24 @@ class SquaredHinge(MeanMetricWrapper): name: (Optional) string name of the metric instance. dtype: (Optional) data type of the metric result. - Usage: + Standalone usage: >>> m = tf.keras.metrics.SquaredHinge() - >>> _ = m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]]) + >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]]) >>> m.result().numpy() 1.86 >>> m.reset_states() - >>> _ = m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]], - ... sample_weight=[1, 0]) + >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]], + ... sample_weight=[1, 0]) >>> m.result().numpy() 1.46 - Usage with tf.keras API: + Usage with `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) model.compile( - 'sgd', + optimizer='sgd', loss='mse', metrics=[tf.keras.metrics.SquaredHinge()]) ``` @@ -2500,25 +2496,24 @@ class CategoricalHinge(MeanMetricWrapper): name: (Optional) string name of the metric instance. dtype: (Optional) data type of the metric result. - Usage: + Standalone usage: >>> m = tf.keras.metrics.CategoricalHinge() - >>> _ = m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]]) + >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]]) >>> m.result().numpy() 1.4000001 >>> m.reset_states() - >>> _ = m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]], - ... sample_weight=[1, 0]) + >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]], + ... sample_weight=[1, 0]) >>> m.result().numpy() 1.2 - Usage with tf.keras API: + Usage with `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) model.compile( - 'sgd', + optimizer='sgd', loss='mse', metrics=[tf.keras.metrics.CategoricalHinge()]) ``` @@ -2532,25 +2527,24 @@ class CategoricalHinge(MeanMetricWrapper): class RootMeanSquaredError(Mean): """Computes root mean squared error metric between `y_true` and `y_pred`. - Usage: + Standalone usage: >>> m = tf.keras.metrics.RootMeanSquaredError() - >>> _ = m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) + >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) >>> m.result().numpy() 0.5 >>> m.reset_states() - >>> _ = m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], - ... sample_weight=[1, 0]) + >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], + ... sample_weight=[1, 0]) >>> m.result().numpy() 0.70710677 - Usage with tf.keras API: + Usage with `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) model.compile( - 'sgd', + optimizer='sgd', loss='mse', metrics=[tf.keras.metrics.RootMeanSquaredError()]) ``` @@ -2594,24 +2588,25 @@ class LogCoshError(MeanMetricWrapper): name: (Optional) string name of the metric instance. dtype: (Optional) data type of the metric result. - Usage: + Standalone usage: >>> m = tf.keras.metrics.LogCoshError() - >>> _ = m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) + >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) >>> m.result().numpy() 0.10844523 >>> m.reset_states() - >>> _ = m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], - ... sample_weight=[1, 0]) + >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], + ... sample_weight=[1, 0]) >>> m.result().numpy() 0.21689045 - Usage with tf.keras API: + Usage with `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) - model.compile('sgd', loss='mse', metrics=[tf.keras.metrics.LogCoshError()]) + model.compile(optimizer='sgd', + loss='mse', + metrics=[tf.keras.metrics.LogCoshError()]) ``` """ @@ -2629,24 +2624,25 @@ class Poisson(MeanMetricWrapper): name: (Optional) string name of the metric instance. dtype: (Optional) data type of the metric result. - Usage: + Standalone usage: >>> m = tf.keras.metrics.Poisson() - >>> _ = m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) + >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) >>> m.result().numpy() 0.49999997 >>> m.reset_states() - >>> _ = m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], - ... sample_weight=[1, 0]) + >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], + ... sample_weight=[1, 0]) >>> m.result().numpy() 0.99999994 - Usage with tf.keras API: + Usage with `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) - model.compile('sgd', loss='mse', metrics=[tf.keras.metrics.Poisson()]) + model.compile(optimizer='sgd', + loss='mse', + metrics=[tf.keras.metrics.Poisson()]) ``` """ @@ -2664,24 +2660,25 @@ class KLDivergence(MeanMetricWrapper): name: (Optional) string name of the metric instance. dtype: (Optional) data type of the metric result. - Usage: + Standalone usage: >>> m = tf.keras.metrics.KLDivergence() - >>> _ = m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]]) + >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]]) >>> m.result().numpy() 0.45814306 >>> m.reset_states() - >>> _ = m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]], - ... sample_weight=[1, 0]) + >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]], + ... sample_weight=[1, 0]) >>> m.result().numpy() 0.9162892 - Usage with tf.keras API: + Usage with `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) - model.compile('sgd', loss='mse', metrics=[tf.keras.metrics.KLDivergence()]) + model.compile(optimizer='sgd', + loss='mse', + metrics=[tf.keras.metrics.KLDivergence()]) ``` """ @@ -2711,7 +2708,7 @@ class MeanIoU(Metric): name: (Optional) string name of the metric instance. dtype: (Optional) data type of the metric result. - Usage: + Standalone usage: >>> # cm = [[1, 1], >>> # [1, 1]] @@ -2719,22 +2716,21 @@ class MeanIoU(Metric): >>> # iou = true_positives / (sum_row + sum_col - true_positives)) >>> # result = (1 / (2 + 2 - 1) + 1 / (2 + 2 - 1)) / 2 = 0.33 >>> m = tf.keras.metrics.MeanIoU(num_classes=2) - >>> _ = m.update_state([0, 0, 1, 1], [0, 1, 0, 1]) + >>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1]) >>> m.result().numpy() 0.33333334 >>> m.reset_states() - >>> _ = m.update_state([0, 0, 1, 1], [0, 1, 0, 1], - ... sample_weight=[0.3, 0.3, 0.3, 0.1]) + >>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1], + ... sample_weight=[0.3, 0.3, 0.3, 0.1]) >>> m.result().numpy() 0.23809525 - Usage with tf.keras API: + Usage with `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) model.compile( - 'sgd', + optimizer='sgd', loss='mse', metrics=[tf.keras.metrics.MeanIoU(num_classes=2)]) ``` @@ -2836,15 +2832,15 @@ class MeanTensor(Metric): name: (Optional) string name of the metric instance. dtype: (Optional) data type of the metric result. - Usage: + Standalone usage: >>> m = tf.keras.metrics.MeanTensor() - >>> _ = m.update_state([0, 1, 2, 3]) - >>> _ = m.update_state([4, 5, 6, 7]) + >>> m.update_state([0, 1, 2, 3]) + >>> m.update_state([4, 5, 6, 7]) >>> m.result().numpy() array([2., 3., 4., 5.], dtype=float32) - >>> _ = m.update_state([12, 10, 8, 6], sample_weight= [0, 0.2, 0.5, 1]) + >>> m.update_state([12, 10, 8, 6], sample_weight= [0, 0.2, 0.5, 1]) >>> m.result().numpy() array([2. , 3.6363635, 4.8 , 5.3333335], dtype=float32) """ @@ -2951,25 +2947,24 @@ class BinaryCrossentropy(MeanMetricWrapper): e.g. `label_smoothing=0.2` means that we will use a value of `0.1` for label `0` and `0.9` for label `1`". - Usage: + Standalone usage: >>> m = tf.keras.metrics.BinaryCrossentropy() - >>> _ = m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]]) + >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]]) >>> m.result().numpy() 0.81492424 >>> m.reset_states() - >>> _ = m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]], - ... sample_weight=[1, 0]) + >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]], + ... sample_weight=[1, 0]) >>> m.result().numpy() 0.9162905 - Usage with tf.keras API: + Usage with `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) model.compile( - 'sgd', + optimizer='sgd', loss='mse', metrics=[tf.keras.metrics.BinaryCrossentropy()]) ``` @@ -3007,7 +3002,7 @@ class CategoricalCrossentropy(MeanMetricWrapper): `label_smoothing=0.2` means that we will use a value of `0.1` for label `0` and `0.9` for label `1`" - Usage: + Standalone usage: >>> # EPSILON = 1e-7, y = y_true, y` = y_pred >>> # y` = clip_ops.clip_by_value(output, EPSILON, 1. - EPSILON) @@ -3017,24 +3012,23 @@ class CategoricalCrossentropy(MeanMetricWrapper): >>> # = [0.051, 2.302] >>> # Reduced xent = (0.051 + 2.302) / 2 >>> m = tf.keras.metrics.CategoricalCrossentropy() - >>> _ = m.update_state([[0, 1, 0], [0, 0, 1]], - ... [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]) + >>> m.update_state([[0, 1, 0], [0, 0, 1]], + ... [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]) >>> m.result().numpy() 1.1769392 >>> m.reset_states() - >>> _ = m.update_state([[0, 1, 0], [0, 0, 1]], - ... [[0.05, 0.95, 0], [0.1, 0.8, 0.1]], - ... sample_weight=tf.constant([0.3, 0.7])) + >>> m.update_state([[0, 1, 0], [0, 0, 1]], + ... [[0.05, 0.95, 0], [0.1, 0.8, 0.1]], + ... sample_weight=tf.constant([0.3, 0.7])) >>> m.result().numpy() 1.6271976 - Usage with tf.keras API: + Usage with `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) model.compile( - 'sgd', + optimizer='sgd', loss='mse', metrics=[tf.keras.metrics.CategoricalCrossentropy()]) ``` @@ -3076,7 +3070,7 @@ class SparseCategoricalCrossentropy(MeanMetricWrapper): axis: (Optional) Defaults to -1. The dimension along which the metric is computed. - Usage: + Standalone usage: >>> # y_true = one_hot(y_true) = [[0, 1, 0], [0, 0, 1]] >>> # logits = log(y_pred) @@ -3089,24 +3083,23 @@ class SparseCategoricalCrossentropy(MeanMetricWrapper): >>> # xent = [0.0513, 2.3026] >>> # Reduced xent = (0.0513 + 2.3026) / 2 >>> m = tf.keras.metrics.SparseCategoricalCrossentropy() - >>> _ = m.update_state([1, 2], - ... [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]) + >>> m.update_state([1, 2], + ... [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]) >>> m.result().numpy() 1.1769392 >>> m.reset_states() - >>> _ = m.update_state([1, 2], - ... [[0.05, 0.95, 0], [0.1, 0.8, 0.1]], - ... sample_weight=tf.constant([0.3, 0.7])) + >>> m.update_state([1, 2], + ... [[0.05, 0.95, 0], [0.1, 0.8, 0.1]], + ... sample_weight=tf.constant([0.3, 0.7])) >>> m.result().numpy() 1.6271976 - Usage with tf.keras API: + Usage with `compile()` API: ```python - model = tf.keras.Model(inputs, outputs) model.compile( - 'sgd', + optimizer='sgd', loss='mse', metrics=[tf.keras.metrics.SparseCategoricalCrossentropy()]) ``` @@ -3196,7 +3189,7 @@ def accuracy(y_true, y_pred): def binary_accuracy(y_true, y_pred, threshold=0.5): """Calculates how often predictions matches binary labels. - Usage: + Standalone usage: >>> y_true = [[1], [1], [0], [0]] >>> y_pred = [[1], [1], [0], [0]] >>> m = tf.keras.metrics.binary_accuracy(y_true, y_pred) @@ -3223,7 +3216,7 @@ def binary_accuracy(y_true, y_pred, threshold=0.5): def categorical_accuracy(y_true, y_pred): """Calculates how often predictions matches one-hot labels. - Usage: + Standalone usage: >>> y_true = [[0, 0, 1], [0, 1, 0]] >>> y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]] >>> m = tf.keras.metrics.categorical_accuracy(y_true, y_pred) @@ -3251,7 +3244,7 @@ def categorical_accuracy(y_true, y_pred): def sparse_categorical_accuracy(y_true, y_pred): """Calculates how often predictions matches integer labels. - Usage: + Standalone usage: >>> y_true = [2, 1] >>> y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]] >>> m = tf.keras.metrics.sparse_categorical_accuracy(y_true, y_pred) @@ -3291,7 +3284,7 @@ def sparse_categorical_accuracy(y_true, y_pred): def top_k_categorical_accuracy(y_true, y_pred, k=5): """Computes how often targets are in the top `K` predictions. - Usage: + Standalone usage: >>> y_true = [[0, 0, 1], [0, 1, 0]] >>> y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]] >>> m = tf.keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k=3) @@ -3316,7 +3309,7 @@ def top_k_categorical_accuracy(y_true, y_pred, k=5): def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5): """Computes how often integer targets are in the top `K` predictions. - Usage: + Standalone usage: >>> y_true = [2, 1] >>> y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]] >>> m = tf.keras.metrics.sparse_top_k_categorical_accuracy( @@ -3467,3 +3460,4 @@ def get(identifier): def is_built_in(cls): return cls.__module__ == Metric.__module__ + diff --git a/tensorflow/python/keras/optimizer_v2/adadelta.py b/tensorflow/python/keras/optimizer_v2/adadelta.py index 9d67ed25c66..99bd2f8e8bf 100644 --- a/tensorflow/python/keras/optimizer_v2/adadelta.py +++ b/tensorflow/python/keras/optimizer_v2/adadelta.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - -"""Adadelta for TensorFlow.""" +"""Adadelta optimizer implementation.""" +# pylint: disable=g-classes-have-attributes from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -34,23 +34,9 @@ class Adadelta(optimizer_v2.OptimizerV2): Adadelta optimization is a stochastic gradient descent method that is based on adaptive learning rate per dimension to address two drawbacks: - 1) the continual decay of learning rates throughout training - 2) the need for a manually selected global learning rate - Two accumulation steps are required: - 1) the accumulation of gradients squared, - 2) the accumulation of updates squared. - - Initialization: - - $$E[g^2]_0 := 0 \text{(Initialize gradient 2nd order moment vector)}$$ - $$E[\Delta x^2]_0 := 0 \text{(Initialize 2nd order variable update)}$$ - - $$t := t + 1$$ - $$E[g^2]_t := \rho * E[g^2]_{t-1} + (1 - \rho) * g^2$$ - $$\Delta x_t = -RMS[\Delta x]_{t-1} * g_t / RMS[g]_t$$ - $$E[\Delta x^2]_t := \rho * E[\Delta x^2]_{t-1} + (1 - \rho) * \Delta x_t^2$$ - $$x_t := x_{t-1} + \Delta x_{t}$$ + - The continual decay of learning rates throughout training + - The need for a manually selected global learning rate Adadelta is a more robust extension of Adagrad that adapts learning rates based on a moving window of gradient updates, instead of accumulating all @@ -59,16 +45,22 @@ class Adadelta(optimizer_v2.OptimizerV2): don't have to set an initial learning rate. In this version, initial learning rate can be set, as in most other Keras optimizers. - @compatibility(eager) - When eager execution is enabled, `learning_rate`, `rho`, and `epsilon` can - each be a callable that takes no arguments and returns the actual value to - use. This can be useful for changing these values across different - invocations of optimizer functions. - @end_compatibility + Args: + learning_rate: A `Tensor`, floating point value, or a schedule that is a + `tf.keras.optimizers.schedules.LearningRateSchedule`. The learning rate. + To match the exact form in the original paper use 1.0. + rho: A `Tensor` or a floating point value. The decay rate. + epsilon: A `Tensor` or a floating point value. A constant epsilon used + to better conditioning the grad update. + name: Optional name prefix for the operations created when applying + gradients. Defaults to `"Adadelta"`. + **kwargs: Keyword arguments. Allowed to be one of + `"clipnorm"` or `"clipvalue"`. + `"clipnorm"` (float) clips gradients by norm; `"clipvalue"` (float) clips + gradients by value. - References - See [M. D. Zeiler](http://arxiv.org/abs/1212.5701) - ([pdf](http://arxiv.org/pdf/1212.5701v1.pdf)) + Reference: + - [Zeiler, 2012](http://arxiv.org/abs/1212.5701) """ _HAS_AGGREGATE_GRAD = True @@ -79,23 +71,6 @@ class Adadelta(optimizer_v2.OptimizerV2): epsilon=1e-7, name='Adadelta', **kwargs): - """Construct a new Adadelta optimizer. - - Args: - learning_rate: A `Tensor`, floating point value, or a schedule that is a - `tf.keras.optimizers.schedules.LearningRateSchedule`. The learning rate. - To match the exact form in the original paper use 1.0. - rho: A `Tensor` or a floating point value. The decay rate. - epsilon: A `Tensor` or a floating point value. A constant epsilon used - to better conditioning the grad update. - name: Optional name prefix for the operations created when applying - gradients. Defaults to "Adadelta". - **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`, - `decay`}. `clipnorm` is clip gradients by norm; `clipvalue` is clip - gradients by value, `decay` is included for backward compatibility to - allow time inverse decay of learning rate. `lr` is included for backward - compatibility, recommended to use `learning_rate` instead. - """ super(Adadelta, self).__init__(name, **kwargs) self._set_hyper('learning_rate', kwargs.get('lr', learning_rate)) self._set_hyper('decay', self._initial_decay) diff --git a/tensorflow/python/keras/optimizer_v2/adagrad.py b/tensorflow/python/keras/optimizer_v2/adagrad.py index 4e4ffd8e856..dbed9de92c6 100644 --- a/tensorflow/python/keras/optimizer_v2/adagrad.py +++ b/tensorflow/python/keras/optimizer_v2/adagrad.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - -"""Adagrad for TensorFlow.""" +"""Adagrad optimizer implementation.""" +# pylint: disable=g-classes-have-attributes from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -39,26 +39,22 @@ class Adagrad(optimizer_v2.OptimizerV2): updated during training. The more updates a parameter receives, the smaller the updates. - Initialization: - $$accum_{g_0} := \text{initial_accumulator_value}$$ + Args: + learning_rate: A `Tensor`, floating point value, or a schedule that is a + `tf.keras.optimizers.schedules.LearningRateSchedule`. The learning rate. + initial_accumulator_value: A floating point value. + Starting value for the accumulators, must be non-negative. + epsilon: A small floating point value to avoid zero denominator. + name: Optional name prefix for the operations created when applying + gradients. Defaults to `"Adagrad"`. + **kwargs: Keyword arguments. Allowed to be one of + `"clipnorm"` or `"clipvalue"`. + `"clipnorm"` (float) clips gradients by norm; `"clipvalue"` (float) clips + gradients by value. - Update step: - $$t := t + 1$$ - $$accum_{g_t} := accum_{g_{t-1}} + g^2$$ - $$\theta_t := \theta_{t-1} - lr * g / (\sqrt{accum_{g_t}} + \epsilon)$$ - - @compatibility(eager) - When eager execution is enabled, `learning_rate` can be a callable that - takes no arguments and returns the actual value to use. This can be useful - for changing these values across different invocations of optimizer - functions. - @end_compatibility - - References: - - * [Paper](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf). - * [Introduction] - (https://ppasupat.github.io/a9online/uploads/proximal_notes.pdf). + Reference: + - [Duchi et al., 2011]( + http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf). """ _HAS_AGGREGATE_GRAD = True @@ -69,25 +65,6 @@ class Adagrad(optimizer_v2.OptimizerV2): epsilon=1e-7, name='Adagrad', **kwargs): - """Construct a new Adagrad optimizer. - - Args: - learning_rate: A `Tensor`, floating point value, or a schedule that is a - `tf.keras.optimizers.schedules.LearningRateSchedule`. The learning rate. - initial_accumulator_value: A floating point value. - Starting value for the accumulators, must be non-negative. - epsilon: A small floating point value to avoid zero denominator. - name: Optional name prefix for the operations created when applying - gradients. Defaults to "Adagrad". - **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`, - `decay`}. `clipnorm` is clip gradients by norm; `clipvalue` is clip - gradients by value, `decay` is included for backward compatibility to - allow time inverse decay of learning rate. `lr` is included for backward - compatibility, recommended to use `learning_rate` instead. - - Raises: - ValueError: If the `initial_accumulator_value` or `epsilon` is invalid. - """ if initial_accumulator_value < 0.0: raise ValueError('initial_accumulator_value must be non-negative: %s' % initial_accumulator_value) @@ -141,7 +118,7 @@ class Adagrad(optimizer_v2.OptimizerV2): An optimizer instance. """ if 'initial_accumulator_value' not in config: - config['initial_accumulator_value'] = 0. + config['initial_accumulator_value'] = 0.1 if 'lr' in config: config['learning_rate'] = config.pop('lr') return cls(**config) diff --git a/tensorflow/python/keras/optimizer_v2/adam.py b/tensorflow/python/keras/optimizer_v2/adam.py index 67152e4b537..df41201e14b 100644 --- a/tensorflow/python/keras/optimizer_v2/adam.py +++ b/tensorflow/python/keras/optimizer_v2/adam.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Adam for TensorFlow.""" +"""Adam optimizer implementation.""" +# pylint: disable=g-classes-have-attributes from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -35,50 +36,58 @@ class Adam(optimizer_v2.OptimizerV2): Adam optimization is a stochastic gradient descent method that is based on adaptive estimation of first-order and second-order moments. - According to the paper - [Adam: A Method for Stochastic Optimization. Kingma et al., - 2014](http://arxiv.org/abs/1412.6980), the method is "*computationally + + According to + [Kingma et al., 2014](http://arxiv.org/abs/1412.6980), + the method is "*computationally efficient, has little memory requirement, invariant to diagonal rescaling of gradients, and is well suited for problems that are large in terms of data/parameters*". - For AMSGrad see [On The Convergence Of Adam And Beyond. - Reddi et al., 5-8](https://openreview.net/pdf?id=ryQu7f-RZ). + Args: + learning_rate: A `Tensor`, floating point value, or a schedule that is a + `tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable + that takes no arguments and returns the actual value to use, The + learning rate. Defaults to 0.001. + beta_1: A float value or a constant float tensor, or a callable + that takes no arguments and returns the actual value to use. The + exponential decay rate for the 1st moment estimates. Defaults to 0.9. + beta_2: A float value or a constant float tensor, or a callable + that takes no arguments and returns the actual value to use, The + exponential decay rate for the 2nd moment estimates. Defaults to 0.999. + epsilon: A small constant for numerical stability. This epsilon is + "epsilon hat" in the Kingma and Ba paper (in the formula just before + Section 2.1), not the epsilon in Algorithm 1 of the paper. Defaults to + 1e-7. + amsgrad: Boolean. Whether to apply AMSGrad variant of this algorithm from + the paper "On the Convergence of Adam and beyond". Defaults to `False`. + name: Optional name for the operations created when applying gradients. + Defaults to `"Adam"`. + **kwargs: Keyword arguments. Allowed to be one of + `"clipnorm"` or `"clipvalue"`. + `"clipnorm"` (float) clips gradients by norm; `"clipvalue"` (float) clips + gradients by value. - **If amsgrad = False**: + Usage: - initialize $m_0$ as 1st moment vector - initialize $v_0$ as 2nd moment vector + >>> opt = tf.keras.optimizers.Adam(learning_rate=0.1) + >>> var1 = tf.Variable(10.0) + >>> loss = lambda: (var1 ** 2)/2.0 # d(loss)/d(var1) == var1 + >>> step_count = opt.minimize(loss, [var1]).numpy() + >>> # The first step is `-learning_rate*sign(grad)` + >>> var1.numpy() + 9.9 - The update rule for $\theta$ with gradient $g$ uses an optimization - described at the end of section 2 of the paper: + Reference: + - [Kingma et al., 2014](http://arxiv.org/abs/1412.6980) + - [Reddi et al., 2018]( + https://openreview.net/pdf?id=ryQu7f-RZ) for `amsgrad`. - $$lr_t = \mathrm{learning\_rate} * - \sqrt{1 - \beta_2^t} / (1 - \beta_1^t)$$ - $$m_t = \beta_1 * m_{t-1} + (1 - \beta_1) * g$$ - $$v_t = \beta_2 * v_{t-1} + (1 - \beta_2) * g^2$$ - $$\theta_t = \theta_{t-1} - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$ - - **If amsgrad = True**: - - initialize $m_0$ as 1st moment vector - initialize $v_0$ as 2nd moment vector - initialize $\hat{v}_0$ as 2nd moment vector - - The update rule for $\theta$ with gradient $g$ uses an optimization - described at the end of section 2 of the paper: - - $$lr_t = \mathrm{learning\_rate} * - \sqrt{1 - \beta_2^t} / (1 - \beta_1^t)$$ - - $$m_t = \beta_1 * m_{t-1} + (1 - \beta_1) * g$$ - $$v_t = \beta_2 * v_{t-1} + (1 - \beta_2) * g^2$$ - $$\hat{v}_t = \max(\hat{v}_{t-1}, v_t)$$ - $$\theta_t = \theta_{t-1} - lr_t * m_t / (\sqrt{\hat{v}_t} + \epsilon)$$ + Notes: The default value of 1e-7 for epsilon might not be a good default in general. For example, when training an Inception network on ImageNet a - current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the + current good choice is 1.0 or 0.1. Note that since Adam uses the formulation just before Section 2.1 of the Kingma and Ba paper rather than the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon hat" in the paper. @@ -91,16 +100,6 @@ class Adam(optimizer_v2.OptimizerV2): accumulator. This means that the sparse behavior is equivalent to the dense behavior (in contrast to some momentum implementations which ignore momentum unless a variable slice was actually used). - - Usage: - - >>> opt = tf.keras.optimizers.Adam(learning_rate=0.1) - >>> var1 = tf.Variable(10.0) - >>> loss = lambda: (var1 ** 2)/2.0 # d(loss)/d(var1) == var1 - >>> step_count = opt.minimize(loss, [var1]).numpy() - >>> # The first step is `-learning_rate*sign(grad)` - >>> var1.numpy() - 9.9 """ _HAS_AGGREGATE_GRAD = True @@ -113,34 +112,6 @@ class Adam(optimizer_v2.OptimizerV2): amsgrad=False, name='Adam', **kwargs): - """Construct a new Adam optimizer. - - Args: - learning_rate: A `Tensor`, floating point value, or a schedule that is a - `tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable - that takes no arguments and returns the actual value to use, The - learning rate. Defaults to 0.001. - beta_1: A float value or a constant float tensor, or a callable - that takes no arguments and returns the actual value to use. The - exponential decay rate for the 1st moment estimates. Defaults to 0.9. - beta_2: A float value or a constant float tensor, or a callable - that takes no arguments and returns the actual value to use, The - exponential decay rate for the 2nd moment estimates. Defaults to 0.999. - epsilon: A small constant for numerical stability. This epsilon is - "epsilon hat" in the Kingma and Ba paper (in the formula just before - Section 2.1), not the epsilon in Algorithm 1 of the paper. Defaults to - 1e-7. - amsgrad: Boolean. Whether to apply AMSGrad variant of this algorithm from - the paper "On the Convergence of Adam and beyond". Defaults to `False`. - name: Optional name for the operations created when applying gradients. - Defaults to "Adam". - **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`, - `decay`}. `clipnorm` is clip gradients by norm; `clipvalue` is clip - gradients by value, `decay` is included for backward compatibility to - allow time inverse decay of learning rate. `lr` is included for backward - compatibility, recommended to use `learning_rate` instead. - """ - super(Adam, self).__init__(name, **kwargs) self._set_hyper('learning_rate', kwargs.get('lr', learning_rate)) self._set_hyper('decay', self._initial_decay) @@ -329,7 +300,7 @@ class NonFusedAdam(optimizer_v2.OptimizerV2): The default value of 1e-7 for epsilon might not be a good default in general. For example, when training an Inception network on ImageNet a - current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the + current good choice is 1.0 or 0.1. Note that since Adam uses the formulation just before Section 2.1 of the Kingma and Ba paper rather than the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon hat" in the paper. diff --git a/tensorflow/python/keras/optimizer_v2/adamax.py b/tensorflow/python/keras/optimizer_v2/adamax.py index 9a7e1e28a89..5ac4734c6a2 100644 --- a/tensorflow/python/keras/optimizer_v2/adamax.py +++ b/tensorflow/python/keras/optimizer_v2/adamax.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - -"""Adamax for TensorFlow.""" +"""Adamax optimizer implementation.""" +# pylint: disable=g-classes-have-attributes from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -39,27 +39,27 @@ class Adamax(optimizer_v2.OptimizerV2): Initialization: - ``` - m_0 <- 0 (Initialize initial 1st moment vector) - v_0 <- 0 (Initialize the exponentially weighted infinity norm) - t <- 0 (Initialize timestep) + ```python + m = 0 # Initialize initial 1st moment vector + v = 0 # Initialize the exponentially weighted infinity norm + t = 0 # Initialize timestep ``` - The update rule for `variable` with gradient `g` uses an optimization + The update rule for parameter `w` with gradient `g` is described at the end of section 7.1 of the paper: - ``` - t <- t + 1 - - m_t <- beta1 * m_{t-1} + (1 - beta1) * g - v_t <- max(beta2 * v_{t-1}, abs(g)) - variable <- variable - learning_rate / (1 - beta1^t) * m_t / (v_t + epsilon) + ```python + t += 1 + m = beta1 * m + (1 - beta) * g + v = max(beta2 * v, abs(g)) + current_lr = learning_rate / (1 - beta1 ** t) + w = w - current_lr * m / (v + epsilon) ``` - Similar to AdamOptimizer, the epsilon is added for numerical stability - (especially to get rid of division by zero when v_t = 0). + Similarly to `Adam`, the epsilon is added for numerical stability + (especially to get rid of division by zero when `v_t == 0`). - Contrast to AdamOptimizer, the sparse implementation of this algorithm + In contrast to `Adam`, the sparse implementation of this algorithm (used when the gradient is an IndexedSlices object, typically because of `tf.gather` or an embedding lookup in the forward pass) only updates variable slices and corresponding `m_t`, `v_t` terms when that part of @@ -68,9 +68,23 @@ class Adamax(optimizer_v2.OptimizerV2): implementations which ignore momentum unless a variable slice was actually used). - References - see Section 7 of [Kingma et al., 2014](http://arxiv.org/abs/1412.6980) - ([pdf](http://arxiv.org/pdf/1412.6980.pdf)). + Args: + learning_rate: A `Tensor`, floating point value, or a schedule that is a + `tf.keras.optimizers.schedules.LearningRateSchedule`. The learning rate. + beta_1: A float value or a constant float tensor. The exponential decay + rate for the 1st moment estimates. + beta_2: A float value or a constant float tensor. The exponential decay + rate for the exponentially weighted infinity norm. + epsilon: A small constant for numerical stability. + name: Optional name for the operations created when applying gradients. + Defaults to `"Adamax"`. + **kwargs: Keyword arguments. Allowed to be one of + `"clipnorm"` or `"clipvalue"`. + `"clipnorm"` (float) clips gradients by norm; `"clipvalue"` (float) clips + gradients by value. + + Reference: + - [Kingma et al., 2014](http://arxiv.org/abs/1412.6980) """ _HAS_AGGREGATE_GRAD = True @@ -82,24 +96,6 @@ class Adamax(optimizer_v2.OptimizerV2): epsilon=1e-7, name='Adamax', **kwargs): - """Construct a new Adamax optimizer. - - Args: - learning_rate: A `Tensor`, floating point value, or a schedule that is a - `tf.keras.optimizers.schedules.LearningRateSchedule`. The learning rate. - beta_1: A float value or a constant float tensor. The exponential decay - rate for the 1st moment estimates. - beta_2: A float value or a constant float tensor. The exponential decay - rate for the exponentially weighted infinity norm. - epsilon: A small constant for numerical stability. - name: Optional name for the operations created when applying gradients. - Defaults to "Adamax". - **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`, - `decay`}. `clipnorm` is clip gradients by norm; `clipvalue` is clip - gradients by value, `decay` is included for backward compatibility to - allow time inverse decay of learning rate. `lr` is included for backward - compatibility, recommended to use `learning_rate` instead. - """ super(Adamax, self).__init__(name, **kwargs) self._set_hyper('learning_rate', kwargs.get('lr', learning_rate)) self._set_hyper('decay', self._initial_decay) diff --git a/tensorflow/python/keras/optimizer_v2/ftrl.py b/tensorflow/python/keras/optimizer_v2/ftrl.py index 17484395044..419f0f70125 100644 --- a/tensorflow/python/keras/optimizer_v2/ftrl.py +++ b/tensorflow/python/keras/optimizer_v2/ftrl.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Ftrl-proximal for TensorFlow.""" +"""Ftrl-proximal optimizer implementation.""" +# pylint: disable=g-classes-have-attributes from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -35,26 +36,32 @@ class Ftrl(optimizer_v2.OptimizerV2): above) and shrinkage-type L2 (which is the addition of an L2 penalty to the loss function). - Initialization: - $$t = 0$$ - $$n_{0} = 0$$ - $$\sigma_{0} = 0$$ - $$z_{0} = 0$$ + Args: + learning_rate: A `Tensor`, floating point value, or a schedule that is a + `tf.keras.optimizers.schedules.LearningRateSchedule`. The learning rate. + learning_rate_power: A float value, must be less or equal to zero. + Controls how the learning rate decreases during training. Use zero for + a fixed learning rate. + initial_accumulator_value: The starting value for accumulators. + Only zero or positive values are allowed. + l1_regularization_strength: A float value, must be greater than or + equal to zero. + l2_regularization_strength: A float value, must be greater than or + equal to zero. + name: Optional name prefix for the operations created when applying + gradients. Defaults to `"Ftrl"`. + l2_shrinkage_regularization_strength: A float value, must be greater than + or equal to zero. This differs from L2 above in that the L2 above is a + stabilization penalty, whereas this L2 shrinkage is a magnitude penalty. + When input is sparse shrinkage will only happen on the active weights. + **kwargs: Keyword arguments. Allowed to be one of + `"clipnorm"` or `"clipvalue"`. + `"clipnorm"` (float) clips gradients by norm; `"clipvalue"` (float) clips + gradients by value. - Update ($$i$$ is variable index): - $$t = t + 1$$ - $$n_{t,i} = n_{t-1,i} + g_{t,i}^{2}$$ - $$\sigma_{t,i} = (\sqrt{n_{t,i}} - \sqrt{n_{t-1,i}}) / \alpha$$ - $$z_{t,i} = z_{t-1,i} + g_{t,i} - \sigma_{t,i} * w_{t,i}$$ - $$w_{t,i} = - ((\beta+\sqrt{n+{t}}) / \alpha + \lambda_{2})^{-1} * (z_{i} - - sgn(z_{i}) * \lambda_{1}) if \abs{z_{i}} > \lambda_{i} else 0$$ - - Check the documentation for the l2_shrinkage_regularization_strength - parameter for more details when shrinkage is enabled, where gradient is - replaced with gradient_with_shrinkage. - - References: See - [paper](https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf) + Reference: + - [paper]( + https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf) """ def __init__(self, @@ -66,44 +73,6 @@ class Ftrl(optimizer_v2.OptimizerV2): name='Ftrl', l2_shrinkage_regularization_strength=0.0, **kwargs): - r"""Construct a new FTRL optimizer. - - Args: - learning_rate: A `Tensor`, floating point value, or a schedule that is a - `tf.keras.optimizers.schedules.LearningRateSchedule`. The learning rate. - learning_rate_power: A float value, must be less or equal to zero. - Controls how the learning rate decreases during training. Use zero for - a fixed learning rate. - initial_accumulator_value: The starting value for accumulators. - Only zero or positive values are allowed. - l1_regularization_strength: A float value, must be greater than or - equal to zero. - l2_regularization_strength: A float value, must be greater than or - equal to zero. - name: Optional name prefix for the operations created when applying - gradients. Defaults to "Ftrl". - l2_shrinkage_regularization_strength: A float value, must be greater than - or equal to zero. This differs from L2 above in that the L2 above is a - stabilization penalty, whereas this L2 shrinkage is a magnitude penalty. - The FTRL formulation can be written as: - w_{t+1} = argmin_w(\hat{g}_{1:t}w + L1*||w||_1 + L2*||w||_2^2), where - \hat{g} = g + (2*L2_shrinkage*w), and g is the gradient of the loss - function w.r.t. the weights w. - Specifically, in the absence of L1 regularization, it is equivalent to - the following update rule: - w_{t+1} = w_t - lr_t / (1 + 2*L2*lr_t) * g_t - - 2*L2_shrinkage*lr_t / (1 + 2*L2*lr_t) * w_t - where lr_t is the learning rate at t. - When input is sparse shrinkage will only happen on the active weights.\ - **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`, - `decay`}. `clipnorm` is clip gradients by norm; `clipvalue` is clip - gradients by value, `decay` is included for backward compatibility to - allow time inverse decay of learning rate. `lr` is included for backward - compatibility, recommended to use `learning_rate` instead. - - Raises: - ValueError: If one of the arguments is invalid. - """ super(Ftrl, self).__init__(name, **kwargs) if initial_accumulator_value < 0.0: diff --git a/tensorflow/python/keras/optimizer_v2/gradient_descent.py b/tensorflow/python/keras/optimizer_v2/gradient_descent.py index 32547b95a52..856cc692431 100644 --- a/tensorflow/python/keras/optimizer_v2/gradient_descent.py +++ b/tensorflow/python/keras/optimizer_v2/gradient_descent.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Momentum for TensorFlow.""" +"""SGD optimizer implementation.""" +# pylint: disable=g-classes-have-attributes from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -27,17 +28,45 @@ from tensorflow.python.util.tf_export import keras_export @keras_export("keras.optimizers.SGD") class SGD(optimizer_v2.OptimizerV2): - r"""Stochastic gradient descent and momentum optimizer. + r"""Gradient descent (with momentum) optimizer. - The update rule for $\theta$ with gradient $g$ when `momentum` is 0.0: - $$\theta_t = \theta_{t-1} - \mathrm{learning\_rate} * g_t$$ + Update rule for parameter `w` with gradient `g` when `momentum` is 0: - The update rule when `momentum` is larger than 0.0: - $$v_t = \mathrm{momentum} * v_{t-1} - \mathrm{learning\_rate} * g_t$$ - $$\theta_t = \theta_{t-1} + v_t$$ - if `nesterov` is False, gradient is evaluated at $\theta_t$. - if `nesterov` is True, gradient is evaluated at $\theta_t + momentum * v_t$, - and the variables always store $\theta + m v$ instead of $theta$ + ```python + w = w - learning_rate * g + ``` + + Update rule when `momentum` is larger than 0: + + ```python + velocity = momentum * velocity - learning_rate * g + w = w * velocity + ``` + + When `nesterov=False`, this rule becomes: + + ```python + velocity = momentum * velocity - learning_rate * g + w = w + momentum * velocity - learning_rate * g + ``` + + Args: + learning_rate: A `Tensor`, floating point value, or a schedule that is a + `tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable + that takes no arguments and returns the actual value to use. The + learning rate. Defaults to 0.01. + momentum: float hyperparameter >= 0 that accelerates gradient descent + in the relevant + direction and dampens oscillations. Defaults to 0, i.e., vanilla gradient + descent. + nesterov: boolean. Whether to apply Nesterov momentum. + Defaults to `False`. + name: Optional name prefix for the operations created when applying + gradients. Defaults to `"SGD"`. + **kwargs: Keyword arguments. Allowed to be one of + `"clipnorm"` or `"clipvalue"`. + `"clipnorm"` (float) clips gradients by norm; `"clipvalue"` (float) clips + gradients by value. Usage: @@ -45,7 +74,7 @@ class SGD(optimizer_v2.OptimizerV2): >>> var = tf.Variable(1.0) >>> loss = lambda: (var ** 2)/2.0 # d(loss)/d(var1) = var1 >>> step_count = opt.minimize(loss, [var]).numpy() - >>> # Step is `-learning_rate*grad` + >>> # Step is `- learning_rate * grad` >>> var.numpy() 0.9 @@ -53,7 +82,7 @@ class SGD(optimizer_v2.OptimizerV2): >>> var = tf.Variable(1.0) >>> val0 = var.value() >>> loss = lambda: (var ** 2)/2.0 # d(loss)/d(var1) = var1 - >>> # First step is `-learning_rate*grad` + >>> # First step is `- learning_rate * grad` >>> step_count = opt.minimize(loss, [var]).numpy() >>> val1 = var.value() >>> (val0 - val1).numpy() @@ -64,13 +93,8 @@ class SGD(optimizer_v2.OptimizerV2): >>> (val1 - val2).numpy() 0.18 - Some of the args below are hyperparameters, where a hyperparameter is - defined as a scalar Tensor, a regular Python value, or a callable (which - will be evaluated when `apply_gradients` is called) returning a scalar - Tensor or a Python value. - - # References - nesterov = True, See [Sutskever et al., 2013]( + Reference: + - For `nesterov=True`, See [Sutskever et al., 2013]( http://jmlr.org/proceedings/papers/v28/sutskever13.pdf). """ @@ -82,25 +106,6 @@ class SGD(optimizer_v2.OptimizerV2): nesterov=False, name="SGD", **kwargs): - """Construct a new Stochastic Gradient Descent or Momentum optimizer. - - Arguments: - learning_rate: A `Tensor`, floating point value, or a schedule that is a - `tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable - that takes no arguments and returns the actual value to use. The - learning rate. Defaults to 0.01. - momentum: float hyperparameter >= 0 that accelerates SGD in the relevant - direction and dampens oscillations. Defaults to 0.0, i.e., SGD. - nesterov: boolean. Whether to apply Nesterov momentum. - Defaults to `False`. - name: Optional name prefix for the operations created when applying - gradients. Defaults to 'SGD'. - **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`, - `decay`}. `clipnorm` is clip gradients by norm; `clipvalue` is clip - gradients by value, `decay` is included for backward compatibility to - allow time inverse decay of learning rate. `lr` is included for backward - compatibility, recommended to use `learning_rate` instead. - """ super(SGD, self).__init__(name, **kwargs) self._set_hyper("learning_rate", kwargs.get("lr", learning_rate)) self._set_hyper("decay", self._initial_decay) diff --git a/tensorflow/python/keras/optimizer_v2/nadam.py b/tensorflow/python/keras/optimizer_v2/nadam.py index f22fbaaae3c..090eabacf1e 100644 --- a/tensorflow/python/keras/optimizer_v2/nadam.py +++ b/tensorflow/python/keras/optimizer_v2/nadam.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Nadam for TensorFlow.""" +"""Nadam optimizer implementation.""" +# pylint: disable=g-classes-have-attributes from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -36,29 +37,22 @@ class Nadam(optimizer_v2.OptimizerV2): Much like Adam is essentially RMSprop with momentum, Nadam is Adam with Nesterov momentum. - Initialization: + Args: + learning_rate: A Tensor or a floating point value. The learning rate. + beta_1: A float value or a constant float tensor. The exponential decay + rate for the 1st moment estimates. + beta_2: A float value or a constant float tensor. The exponential decay + rate for the exponentially weighted infinity norm. + epsilon: A small constant for numerical stability. + name: Optional name for the operations created when applying gradients. + Defaults to `"Nadam"`. + **kwargs: Keyword arguments. Allowed to be one of + `"clipnorm"` or `"clipvalue"`. + `"clipnorm"` (float) clips gradients by norm; `"clipvalue"` (float) clips + gradients by value. - $$m_0 := 0 \text{(Initialize 1st moment vector)}$$ - $$v_0 := 0 \text{(Initialize 2nd moment vector)}$$ - $$mu_0 := 1$$ - $$t := 0 \text{(Initialize timestep)}$$ - - Computes: - $$t := t + 1$$ - $$\mu_t := \beta_1 * (1 - 0.5 * 0.96^{0.004 * t})$$ - $$g' := g / (1 - \prod_{i=1}^{t}{\mu_i})$$ - $$m_t := \beta_1 * m_{t-1} + (1 - \beta_1) * g$$ - $$m' := m_t / (1 - \prod_{i=1}^{t+1}{\mu_i})$$ - $$v_t := \beta_2 * v_{t-1} + (1 - \beta_2) * g * g$$ - $$v' := v_t / (1 - \beta_2^t)$$ - $$\bar{m} := (1 - \mu_t) * g' + \mu_{t+1} * m'$$ - $$\theta_t := \theta_{t-1} - lr * \bar{m} / (\sqrt{v'} + \epsilon)$$ - - gradient is evaluated at theta(t) + momentum * v(t), and the variables always - store theta + beta_1 * m / sqrt(v) instead of theta. - - References - See [Dozat, T., 2015](http://cs229.stanford.edu/proj2015/054_report.pdf). + Reference: + - [Dozat, 2015](http://cs229.stanford.edu/proj2015/054_report.pdf). """ _HAS_AGGREGATE_GRAD = True @@ -70,24 +64,6 @@ class Nadam(optimizer_v2.OptimizerV2): epsilon=1e-7, name='Nadam', **kwargs): - """Construct a new Nadam optimizer. - - Args: - learning_rate: A Tensor or a floating point value. The learning rate. - beta_1: A float value or a constant float tensor. The exponential decay - rate for the 1st moment estimates. - beta_2: A float value or a constant float tensor. The exponential decay - rate for the exponentially weighted infinity norm. - epsilon: A small constant for numerical stability. - name: Optional name for the operations created when applying gradients. - Defaults to "Nadam". - **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`, - `decay`}. `clipnorm` is clip gradients by norm; `clipvalue` is clip - gradients by value, `decay` is included for backward compatibility to - allow time inverse decay of learning rate. `lr` is included for backward - compatibility, recommended to use `learning_rate` instead. - """ - # Backwards compatibility with keras NAdam optimizer. kwargs['decay'] = kwargs.pop('schedule_decay', 0.004) learning_rate = kwargs.get('lr', learning_rate) diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py index 37ec1e933ff..7b2e336678e 100644 --- a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py +++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Version 2 of class Optimizer.""" # pylint: disable=g-bad-name @@ -79,11 +78,10 @@ def _deduplicate_indexed_slices(values, indices): @six.add_metaclass(abc.ABCMeta) @keras_export("keras.optimizers.Optimizer") class OptimizerV2(trackable.Trackable): - """Updated base class for optimizers. + """Base class for Keras optimizers. - This class defines the API to add Ops to train a model. You never use this - class directly, but instead instantiate one of its subclasses such as - `tf.keras.optimizers.SGD`, `tf.keras.optimizers.Adam`. + You should not use this class directly, but instead instantiate one of its + subclasses such as `tf.keras.optimizers.SGD`, `tf.keras.optimizers.Adam`, etc. ### Usage @@ -101,7 +99,7 @@ class OptimizerV2(trackable.Trackable): opt.minimize(loss, var_list=[var1, var2]) ``` - ### Custom training loop with Keras models + ### Usage in custom training loops In Keras models, sometimes variables are created when the model is first called, instead of construction time. Examples include 1) sequential models @@ -109,6 +107,7 @@ class OptimizerV2(trackable.Trackable): callable in these cases. Example: + ```python opt = tf.keras.optimizers.SGD(learning_rate=0.1) model = tf.keras.Sequential() @@ -120,7 +119,7 @@ class OptimizerV2(trackable.Trackable): opt.minimize(loss_fn, var_list_fn) ``` - ### Processing gradients before applying them. + ### Processing gradients before applying them Calling `minimize()` takes care of both computing the gradients and applying them to the variables. If you want to process the gradients @@ -150,7 +149,7 @@ class OptimizerV2(trackable.Trackable): opt.apply_gradients(zip(processed_grads, var_list)) ``` - ### Use with `tf.distribute.Strategy`. + ### Use with `tf.distribute.Strategy` This optimizer class is `tf.distribute.Strategy` aware, which means it automatically sums gradients across all replicas. To average gradients, @@ -172,7 +171,7 @@ class OptimizerV2(trackable.Trackable): step. As a result, using `tf.math.reduce_mean` will give the wrong answer, resulting in gradients that can be many times too big. - ### Variable Constraint + ### Variable Constraints All Keras optimizers respect variable constraints. If constraint function is passed to any variable, the constraint will be applied to the variable after @@ -195,7 +194,7 @@ class OptimizerV2(trackable.Trackable): This can be useful if you want to log debug a training algorithm, report stats about the slots, etc. - ### Hyper parameters + ### Hyperparameters These are arguments passed to the optimizer subclass constructor (the `__init__` method), and then passed to `self._set_hyper()`. @@ -203,7 +202,7 @@ class OptimizerV2(trackable.Trackable): callables. If they are callable, the callable will be called during `apply_gradients()` to get the value for the hyper parameter. - Hyper parameters can be overwritten through user code: + Hyperparameters can be overwritten through user code: Example: @@ -220,7 +219,8 @@ class OptimizerV2(trackable.Trackable): opt.minimize(loss, var_list=[var1, var2]) ``` - ### Callable learning rate. + ### Callable learning rate + Optimizer accepts a callable learning rate in two ways. The first way is through built-in or customized `tf.keras.optimizers.schedules.LearningRateSchedule`. The schedule will be @@ -250,14 +250,17 @@ class OptimizerV2(trackable.Trackable): >>> opt.minimize(loss, var_list=[var]) >> opt = tf.keras.optimizers.RMSprop(learning_rate=0.1) >>> var1 = tf.Variable(10.0) - >>> loss = lambda: (var1 ** 2)/2.0 # d(loss)/d(var1) = var1 + >>> loss = lambda: (var1 ** 2) / 2.0 # d(loss) / d(var1) = var1 >>> step_count = opt.minimize(loss, [var1]).numpy() >>> var1.numpy() 9.683772 - References - See ([pdf] - http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf). + Reference: + - [Hinton, 2012]( + http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) """ _HAS_AGGREGATE_GRAD = True diff --git a/tensorflow/python/keras/preprocessing/image.py b/tensorflow/python/keras/preprocessing/image.py index 0b7e2967dea..d58a4fcd2b2 100644 --- a/tensorflow/python/keras/preprocessing/image.py +++ b/tensorflow/python/keras/preprocessing/image.py @@ -452,8 +452,8 @@ class ImageDataGenerator(image.ImageDataGenerator): # (std, mean, and principal components if ZCA whitening is applied) datagen.fit(x_train) # fits the model on batches with real-time data augmentation: - model.fit_generator(datagen.flow(x_train, y_train, batch_size=32), - steps_per_epoch=len(x_train) / 32, epochs=epochs) + model.fit(datagen.flow(x_train, y_train, batch_size=32), + steps_per_epoch=len(x_train) / 32, epochs=epochs) # here's a more "manual" example for e in range(epochs): print('Epoch', e) @@ -486,7 +486,7 @@ class ImageDataGenerator(image.ImageDataGenerator): target_size=(150, 150), batch_size=32, class_mode='binary') - model.fit_generator( + model.fit( train_generator, steps_per_epoch=2000, epochs=50, diff --git a/tensorflow/python/keras/utils/data_utils.py b/tensorflow/python/keras/utils/data_utils.py index 179d9b6020c..00df6b59739 100644 --- a/tensorflow/python/keras/utils/data_utils.py +++ b/tensorflow/python/keras/utils/data_utils.py @@ -195,9 +195,9 @@ def get_file(fname, fname: Name of the file. If an absolute path `/path/to/file.txt` is specified the file will be saved at that location. origin: Original URL of the file. - untar: Deprecated in favor of 'extract'. + untar: Deprecated in favor of `extract` argument. boolean, whether the file should be decompressed - md5_hash: Deprecated in favor of 'file_hash'. + md5_hash: Deprecated in favor of `file_hash` argument. md5 hash of the file for verification file_hash: The expected hash string of the file after download. The sha256 and md5 hash algorithms are both supported. @@ -205,17 +205,16 @@ def get_file(fname, saved. If an absolute path `/path/to/folder` is specified the file will be saved at that location. hash_algorithm: Select the hash algorithm to verify the file. - options are 'md5', 'sha256', and 'auto'. + options are `'md5'`, `'sha256'`, and `'auto'`. The default 'auto' detects the hash algorithm in use. extract: True tries extracting the file as an Archive, like tar or zip. archive_format: Archive format to try for extracting the file. - Options are 'auto', 'tar', 'zip', and None. - 'tar' includes tar, tar.gz, and tar.bz files. - The default 'auto' is ['tar', 'zip']. + Options are `'auto'`, `'tar'`, `'zip'`, and `None`. + `'tar'` includes tar, tar.gz, and tar.bz files. + The default `'auto'` corresponds to `['tar', 'zip']`. None or an empty list will return no matches found. cache_dir: Location to store cached files, when None it - defaults to the [Keras - Directory](/faq/#where-is-the-keras-configuration-filed-stored). + defaults to the default directory `~/.keras/`. Returns: Path to the downloaded file @@ -315,8 +314,8 @@ def _hash_file(fpath, algorithm='sha256', chunk_size=65535): Arguments: fpath: path to the file being validated - algorithm: hash algorithm, one of 'auto', 'sha256', or 'md5'. - The default 'auto' detects the hash algorithm in use. + algorithm: hash algorithm, one of `'auto'`, `'sha256'`, or `'md5'`. + The default `'auto'` detects the hash algorithm in use. chunk_size: Bytes to read at a time, important for large files. Returns: @@ -420,32 +419,32 @@ class Sequence(object): Examples: ```python - from skimage.io import imread - from skimage.transform import resize - import numpy as np - import math + from skimage.io import imread + from skimage.transform import resize + import numpy as np + import math - # Here, `x_set` is list of path to the images - # and `y_set` are the associated classes. + # Here, `x_set` is list of path to the images + # and `y_set` are the associated classes. - class CIFAR10Sequence(Sequence): + class CIFAR10Sequence(Sequence): - def __init__(self, x_set, y_set, batch_size): - self.x, self.y = x_set, y_set - self.batch_size = batch_size + def __init__(self, x_set, y_set, batch_size): + self.x, self.y = x_set, y_set + self.batch_size = batch_size - def __len__(self): - return math.ceil(len(self.x) / self.batch_size) + def __len__(self): + return math.ceil(len(self.x) / self.batch_size) - def __getitem__(self, idx): - batch_x = self.x[idx * self.batch_size:(idx + 1) * - self.batch_size] - batch_y = self.y[idx * self.batch_size:(idx + 1) * - self.batch_size] + def __getitem__(self, idx): + batch_x = self.x[idx * self.batch_size:(idx + 1) * + self.batch_size] + batch_y = self.y[idx * self.batch_size:(idx + 1) * + self.batch_size] - return np.array([ - resize(imread(file_name), (200, 200)) - for file_name in batch_x]), np.array(batch_y) + return np.array([ + resize(imread(file_name), (200, 200)) + for file_name in batch_x]), np.array(batch_y) ``` """ @@ -485,10 +484,10 @@ def iter_sequence_infinite(seq): """Iterates indefinitely over a Sequence. Arguments: - seq: Sequence instance. + seq: `Sequence` instance. Yields: - Batches of data from the Sequence. + Batches of data from the `Sequence`. """ while True: for item in seq: diff --git a/tensorflow/python/keras/utils/generic_utils.py b/tensorflow/python/keras/utils/generic_utils.py index 897f97c793b..27015cbc8f2 100644 --- a/tensorflow/python/keras/utils/generic_utils.py +++ b/tensorflow/python/keras/utils/generic_utils.py @@ -234,7 +234,7 @@ def get_registered_object(name, custom_objects=None, module_objects=None): @keras_export('keras.utils.serialize_keras_object') def serialize_keras_object(instance): - """Serialize Keras object into JSON.""" + """Serialize a Keras object into a JSON-compatible representation.""" _, instance = tf_decorator.unwrap(instance) if instance is None: return None @@ -327,6 +327,7 @@ def deserialize_keras_object(identifier, module_objects=None, custom_objects=None, printable_module_name='object'): + """Turns the serialized form of a Keras object back into an actual object.""" if identifier is None: return None diff --git a/tensorflow/python/keras/utils/np_utils.py b/tensorflow/python/keras/utils/np_utils.py index f2ca5d213e8..1e8fcf34693 100644 --- a/tensorflow/python/keras/utils/np_utils.py +++ b/tensorflow/python/keras/utils/np_utils.py @@ -27,7 +27,18 @@ def to_categorical(y, num_classes=None, dtype='float32'): E.g. for use with categorical_crossentropy. - Usage Example: + Arguments: + y: class vector to be converted into a matrix + (integers from 0 to num_classes). + num_classes: total number of classes. If `None`, this would be inferred + as the (largest number in `y`) + 1. + dtype: The data type expected by the input. Default: `'float32'`. + + Returns: + A binary matrix representation of the input. The classes axis is placed + last. + + Example: >>> a = tf.keras.utils.to_categorical([0, 1, 2, 3], num_classes=4) >>> a = tf.constant(a, shape=[4, 4]) @@ -51,29 +62,6 @@ def to_categorical(y, num_classes=None, dtype='float32'): >>> print(np.around(loss, 5)) [0. 0. 0. 0.] - Arguments: - y: class vector to be converted into a matrix - (integers from 0 to num_classes). - num_classes: total number of classes. If `None`, this would be inferred - as the (largest number in `y`) + 1. - dtype: The data type expected by the input. Default: `'float32'`. - - Returns: - A binary matrix representation of the input. The classes axis is placed - last. - - Usage example: - - >>> y = [0, 1, 2, 3, 3, 1, 0] - >>> tf.keras.utils.to_categorical(y, 4) - array([[1., 0., 0., 0.], - [0., 1., 0., 0.], - [0., 0., 1., 0.], - [0., 0., 0., 1.], - [0., 0., 0., 1.], - [0., 1., 0., 0.], - [1., 0., 0., 0.]], dtype=float32) - Raises: Value Error: If input contains string value @@ -100,7 +88,7 @@ def normalize(x, axis=-1, order=2): Arguments: x: Numpy array to normalize. axis: axis along which to normalize. - order: Normalization order (e.g. 2 for L2 norm). + order: Normalization order (e.g. `order=2` for L2 norm). Returns: A normalized copy of the array. diff --git a/tensorflow/python/kernel_tests/ctc_loss_op_test.py b/tensorflow/python/kernel_tests/ctc_loss_op_test.py index 19918496fbd..058437cc04e 100644 --- a/tensorflow/python/kernel_tests/ctc_loss_op_test.py +++ b/tensorflow/python/kernel_tests/ctc_loss_op_test.py @@ -31,7 +31,6 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util -from tensorflow.python.keras import keras_parameterized from tensorflow.python.ops import array_ops from tensorflow.python.ops import ctc_ops from tensorflow.python.ops import gradients_impl @@ -942,8 +941,8 @@ class CTCLossTestV2(test.TestCase): [[1.0, 2.0], [5.0, 8.0], [14.0, 20.0]], out) -@keras_parameterized.run_all_keras_modes -class CTCLossTestV3(keras_parameterized.TestCase): +@test_util.run_all_in_graph_and_eager_modes +class CTCLossTestV3(test.TestCase, parameterized.TestCase): @parameterized.parameters([False, True]) @test_util.run_v2_only @@ -955,6 +954,8 @@ class CTCLossTestV3(keras_parameterized.TestCase): """ if not test.is_gpu_available(): self.skipTest("Need GPU for testing.") + if not context.executing_eagerly(): + self.skipTest("Need eager execution for testing.") random_seed.set_random_seed(5) batch_size = 8 diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc index 6629a58577b..6403ca3a0ea 100644 --- a/tensorflow/python/lib/core/py_seq_tensor.cc +++ b/tensorflow/python/lib/core/py_seq_tensor.cc @@ -288,8 +288,14 @@ struct Converter { Py_DECREF(scalar); if (*error != nullptr) return errors::InvalidArgument(*error); t = ConverterTraits::CreateScalar(ctx, value); + if (t == nullptr) { + return errors::Internal("Cannot create tensor."); + } } else { t = ConverterTraits::CreateTensor(ctx, state->inferred_shape); + if (t == nullptr) { + return errors::Internal("Cannot create tensor."); + } if (t->NumElements() > 0) { T* buf = static_cast(t->Data()); *error = Helper(obj, 0, state, &buf); @@ -674,8 +680,8 @@ typedef Converter BoolConverter; // The two may share underlying storage so changes to one may reflect in the // other. TFE_TensorHandle* NumpyToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj) { - tensorflow::Tensor t; - tensorflow::Status status = tensorflow::NdarrayToTensor(obj, &t); + tensorflow::Tensor tensor; + tensorflow::Status status = tensorflow::NdarrayToTensor(obj, &tensor); if (!status.ok()) { PyErr_SetString(PyExc_ValueError, tensorflow::strings::StrCat( @@ -685,8 +691,8 @@ TFE_TensorHandle* NumpyToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj) { return nullptr; } - return new TFE_TensorHandle{ - ctx->context->CreateLocalHandle(new TensorInterface(std::move(t)))}; + TensorInterface t(std::move(tensor)); + return new TFE_TensorHandle{ctx->context->CreateLocalHandle(&t)}; } } // namespace @@ -868,10 +874,10 @@ TFE_TensorHandle* PySeqToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj, case DT_INVALID: // Only occurs for empty tensors. { - Tensor t(requested_dtype == DT_INVALID ? DT_FLOAT : requested_dtype, - TensorShape(state.inferred_shape)); - return new TFE_TensorHandle{ - ctx->context->CreateLocalHandle(new TensorInterface(std::move(t)))}; + Tensor tensor(requested_dtype == DT_INVALID ? DT_FLOAT : requested_dtype, + TensorShape(state.inferred_shape)); + TensorInterface t(std::move(tensor)); + return new TFE_TensorHandle{ctx->context->CreateLocalHandle(&t)}; } default: diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 15758e22182..9367374717e 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -3107,7 +3107,8 @@ def sparse_placeholder(dtype, shape=None, name=None): print(sess.run(y, feed_dict={ x: (indices, values, shape)})) # Will succeed. - sp = tf.SparseTensor(indices=indices, values=values, dense_shape=shape) + sp = tf.sparse.SparseTensor(indices=indices, values=values, + dense_shape=shape) sp_value = sp.eval(session=sess) print(sess.run(y, feed_dict={x: sp_value})) # Will succeed. ``` @@ -3536,20 +3537,27 @@ def edit_distance(hypothesis, truth, normalize=True, name="edit_distance"): For the following inputs, ```python - # 'hypothesis' is a tensor of shape `(2, 1)` with variable-length values: - hypothesis = tf.SparseTensor( - [[0, 0], - [1,0]], - ["a", "b"], - (2, 1)) + # 'hypothesis' is a tensor of shape `[2, 1]` with variable-length values: + # (0,0) = ["a"] + # (1,0) = ["b"] + hypothesis = tf.sparse.SparseTensor( + [[0, 0, 0], + [1, 0, 0]], + ["a", "b"], + (2, 1, 1)) - # 'truth' is a tensor of shape `(2, 2)` with variable-length values: - truth = tf.SparseTensor( - [[0, 1], - [1, 0], - [1, 1]], - ["a", ["b", "c"], "a"], - (2, 2)) + # 'truth' is a tensor of shape `[2, 2]` with variable-length values: + # (0,0) = [] + # (0,1) = ["a"] + # (1,0) = ["b", "c"] + # (1,1) = ["a"] + truth = tf.sparse.SparseTensor( + [[0, 1, 0], + [1, 0, 0], + [1, 0, 1], + [1, 1, 0]], + ["a", "b", "c", "a"], + (2, 2, 2)) normalize = True diff --git a/tensorflow/python/ops/control_flow_v2_toggles.py b/tensorflow/python/ops/control_flow_v2_toggles.py index 5b9291f7768..15db5287c45 100644 --- a/tensorflow/python/ops/control_flow_v2_toggles.py +++ b/tensorflow/python/ops/control_flow_v2_toggles.py @@ -42,6 +42,8 @@ def enable_control_flow_v2(): # pylint: disable=invalid-name Note: v2 control flow is always enabled inside of tf.function. Calling this function is not required. """ + # pylint: disable=protected-access + ops._control_flow_api_gauge.get_cell().set(True) control_flow_util.ENABLE_CONTROL_FLOW_V2 = True @@ -55,6 +57,8 @@ def disable_control_flow_v2(): # pylint: disable=invalid-name If your code needs tf.disable_control_flow_v2() to be called to work properly please file a bug. """ + # pylint: disable=protected-access + ops._control_flow_api_gauge.get_cell().set(False) control_flow_util.ENABLE_CONTROL_FLOW_V2 = False diff --git a/tensorflow/python/ops/ctc_ops.py b/tensorflow/python/ops/ctc_ops.py index d18799c5224..d989bc0be44 100644 --- a/tensorflow/python/ops/ctc_ops.py +++ b/tensorflow/python/ops/ctc_ops.py @@ -1126,7 +1126,7 @@ def dense_labels_to_sparse(dense, length): length: int tensor of shape [batch] The length of each sequence in dense. Returns: - tf.SparseTensor with values only for the valid elements of sequences. + tf.sparse.SparseTensor with values only for the valid elements of sequences. """ flat_values = array_ops.reshape(dense, [-1]) diff --git a/tensorflow/python/ops/map_fn.py b/tensorflow/python/ops/map_fn.py index 7438e584227..632bfbc21e7 100644 --- a/tensorflow/python/ops/map_fn.py +++ b/tensorflow/python/ops/map_fn.py @@ -106,7 +106,7 @@ def map_fn(fn, * A `tf.DType` or `tf.TensorSpec` (to describe a `tf.Tensor`) * A `tf.RaggedTensorSpec` (to describe a `tf.RaggedTensor`) - * A `tf.SparseTensorSpec` (to describe a `tf.SparseTensor`) + * A `tf.SparseTensorSpec` (to describe a `tf.sparse.SparseTensor`) * A (possibly nested) tuple, list, or dict containing the above types. #### RaggedTensors @@ -159,11 +159,11 @@ def map_fn(fn, #### SparseTensors - `map_fn` supports `tf.SparseTensor` inputs and outputs. In particular: + `map_fn` supports `tf.sparse.SparseTensor` inputs and outputs. In particular: * If `elems` is a `SparseTensor`, then `fn` will be called with each row of that sparse tensor. In particular, the value passed to `fn` will be a - `tf.SparseTensor` with one fewer dimension than `elems`. + `tf.sparse.SparseTensor` with one fewer dimension than `elems`. * If the result of `map_fn` should be a `SparseTensor`, then use a `tf.SparseTensorSpec` to specify `fn_output_signature`. The individual @@ -171,7 +171,7 @@ def map_fn(fn, `SparseTensor` with one more dimension. >>> # Example: SparseTensor input - >>> st = tf.SparseTensor([[0, 0], [2, 0], [2, 1]], [2, 3, 4], [4, 4]) + >>> st = tf.sparse.SparseTensor([[0, 0], [2, 0], [2, 1]], [2, 3, 4], [4, 4]) >>> tf.map_fn(tf.sparse.reduce_sum, st, fn_output_signature=tf.int32) @@ -191,10 +191,15 @@ def map_fn(fn, *rows* of a `SparseTensor`. If you wish to map a function over the nonzero values, then you should use: - * `tf.SparseTensor(st.indices, fn(st.values), st.dense_shape)` - (if the function is expressible as TensorFlow ops) - * `tf.SparseTensor(st.indices, tf.map_fn(fn, st.values), st.dense_shape)` - (otherwise). + * If the function is expressible as TensorFlow ops, use: + ```python + tf.sparse.SparseTensor(st.indices, fn(st.values), st.dense_shape) + ``` + * Otherwise, use: + ```python + tf.sparse.SparseTensor(st.indices, tf.map_fn(fn, st.values), + st.dense_shape) + ``` #### `map_fn` vs. vectorized operations @@ -276,7 +281,7 @@ def map_fn(fn, * A `tf.DType` or `tf.TensorSpec` (to describe a `tf.Tensor`) * A `tf.RaggedTensorSpec` (to describe a `tf.RaggedTensor`) - * A `tf.SparseTensorSpec` (to describe a `tf.SparseTensor`) + * A `tf.SparseTensorSpec` (to describe a `tf.sparse.SparseTensor`) * A (possibly nested) tuple, list, or dict containing the above types. Returns: diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 062b571ff4e..f062047cec2 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -1541,8 +1541,8 @@ def equal(x, y, name=None): Args: - x: A `tf.Tensor` or `tf.SparseTensor` or `tf.IndexedSlices`. - y: A `tf.Tensor` or `tf.SparseTensor` or `tf.IndexedSlices`. + x: A `tf.Tensor` or `tf.sparse.SparseTensor` or `tf.IndexedSlices`. + y: A `tf.Tensor` or `tf.sparse.SparseTensor` or `tf.IndexedSlices`. name: A name for the operation (optional). Returns: @@ -1577,8 +1577,8 @@ def not_equal(x, y, name=None): Args: - x: A `tf.Tensor` or `tf.SparseTensor` or `tf.IndexedSlices`. - y: A `tf.Tensor` or `tf.SparseTensor` or `tf.IndexedSlices`. + x: A `tf.Tensor` or `tf.sparse.SparseTensor` or `tf.IndexedSlices`. + y: A `tf.Tensor` or `tf.sparse.SparseTensor` or `tf.IndexedSlices`. name: A name for the operation (optional). Returns: @@ -3016,12 +3016,12 @@ def matmul(a, **does not support `tf.sparse.SparseTensor`**, it just makes optimizations that assume most values in `a` are zero. See `tf.sparse.sparse_dense_matmul` - for some support for `tf.SparseTensor` multiplication. + for some support for `tf.sparse.SparseTensor` multiplication. b_is_sparse: If `True`, `b` is treated as a sparse matrix. Notice, this **does not support `tf.sparse.SparseTensor`**, it just makes optimizations that assume most values in `a` are zero. See `tf.sparse.sparse_dense_matmul` - for some support for `tf.SparseTensor` multiplication. + for some support for `tf.sparse.SparseTensor` multiplication. name: Name for the operation (optional). Returns: diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py index 35d5e64334e..2a13bc2300f 100644 --- a/tensorflow/python/ops/parallel_for/pfor.py +++ b/tensorflow/python/ops/parallel_for/pfor.py @@ -1217,10 +1217,10 @@ class PFor(object): the new dense shape will be (N, max_i(x_i), max_i(y_i), max_i(z_i)). Args: - y: A tf.SparseTensor. + y: A tf.sparse.SparseTensor. Returns: - A tf.SparseTensor that is the converted value corresponding to y. + A tf.sparse.SparseTensor that is the converted value corresponding to y. """ outputs = [ self._convert_helper(t) for t in (y.indices, y.values, y.dense_shape) diff --git a/tensorflow/python/ops/ragged/ragged_tensor.py b/tensorflow/python/ops/ragged/ragged_tensor.py index afb631ed0f2..32e388e480d 100644 --- a/tensorflow/python/ops/ragged/ragged_tensor.py +++ b/tensorflow/python/ops/ragged/ragged_tensor.py @@ -1707,17 +1707,17 @@ class RaggedTensor(composite_tensor.CompositeTensor, tensor_like.TensorLike): @classmethod def from_sparse(cls, st_input, name=None, row_splits_dtype=dtypes.int64): - """Converts a 2D `tf.SparseTensor` to a `RaggedTensor`. + """Converts a 2D `tf.sparse.SparseTensor` to a `RaggedTensor`. Each row of the `output` `RaggedTensor` will contain the explicit values from the same row in `st_input`. `st_input` must be ragged-right. If not it is not ragged-right, then an error will be generated. Example: - - >>> st = tf.SparseTensor(indices=[[0, 0], [0, 1], [0, 2], [1, 0], [3, 0]], - ... values=[1, 2, 3, 4, 5], - ... dense_shape=[4, 3]) + >>> indices = [[0, 0], [0, 1], [0, 2], [1, 0], [3, 0]] + >>> st = tf.sparse.SparseTensor(indices=indices, + ... values=[1, 2, 3, 4, 5], + ... dense_shape=[4, 3]) >>> tf.RaggedTensor.from_sparse(st).to_list() [[1, 2, 3], [4], [], [5]] @@ -1768,7 +1768,7 @@ class RaggedTensor(composite_tensor.CompositeTensor, tensor_like.TensorLike): st_input.values, segment_ids, num_segments, validate=False) def to_sparse(self, name=None): - """Converts this `RaggedTensor` into a `tf.SparseTensor`. + """Converts this `RaggedTensor` into a `tf.sparse.SparseTensor`. Example: diff --git a/tensorflow/python/ops/ragged/row_partition.py b/tensorflow/python/ops/ragged/row_partition.py index 3e0269f692e..133b55a53bf 100644 --- a/tensorflow/python/ops/ragged/row_partition.py +++ b/tensorflow/python/ops/ragged/row_partition.py @@ -1209,7 +1209,7 @@ def _merge_tensors(t1, t2, name, validate): elif t1 is t2: return t1, True else: - err_msg = ("RowPartition.merge_precomuted_encodings: partitons " + err_msg = ("RowPartition.merge_precomuted_encodings: partitions " "have incompatible %s" % name) if not t1.shape.is_compatible_with(t2.shape): raise ValueError(err_msg) diff --git a/tensorflow/python/ops/sets_impl.py b/tensorflow/python/ops/sets_impl.py index 195810d104a..988d437bae8 100644 --- a/tensorflow/python/ops/sets_impl.py +++ b/tensorflow/python/ops/sets_impl.py @@ -156,7 +156,8 @@ def set_intersection(a, b, validate_indices=True): ((1, 1, 0), 5), ((1, 1, 1), 6), ]) - a = tf.SparseTensor(list(a.keys()), list(a.values()), dense_shape=[2,2,2]) + a = tf.sparse.SparseTensor(list(a.keys()), list(a.values()), + dense_shape=[2,2,2]) # b = np.array([[{1}, {}], [{4}, {5, 6, 7, 8}]]) b = collections.OrderedDict([ @@ -167,7 +168,8 @@ def set_intersection(a, b, validate_indices=True): ((1, 1, 2), 7), ((1, 1, 3), 8), ]) - b = tf.SparseTensor(list(b.keys()), list(b.values()), dense_shape=[2, 2, 4]) + b = tf.sparse.SparseTensor(list(b.keys()), list(b.values()), + dense_shape=[2, 2, 4]) # `tf.sets.intersection` is applied to each aligned pair of sets. tf.sets.intersection(a, b) @@ -224,7 +226,8 @@ def set_difference(a, b, aminusb=True, validate_indices=True): ((1, 1, 0), 5), ((1, 1, 1), 6), ]) - a = tf.SparseTensor(list(a.keys()), list(a.values()), dense_shape=[2, 2, 2]) + a = tf.sparse.SparseTensor(list(a.keys()), list(a.values()), + dense_shape=[2, 2, 2]) # np.array([[{1, 3}, {2}], [{4, 5}, {5, 6, 7, 8}]]) b = collections.OrderedDict([ @@ -238,7 +241,8 @@ def set_difference(a, b, aminusb=True, validate_indices=True): ((1, 1, 2), 7), ((1, 1, 3), 8), ]) - b = tf.SparseTensor(list(b.keys()), list(b.values()), dense_shape=[2, 2, 4]) + b = tf.sparse.SparseTensor(list(b.keys()), list(b.values()), + dense_shape=[2, 2, 4]) # `set_difference` is applied to each aligned pair of sets. tf.sets.difference(a, b) @@ -302,7 +306,8 @@ def set_union(a, b, validate_indices=True): ((1, 1, 0), 5), ((1, 1, 1), 6), ]) - a = tf.SparseTensor(list(a.keys()), list(a.values()), dense_shape=[2, 2, 2]) + a = tf.sparse.SparseTensor(list(a.keys()), list(a.values()), + dense_shape=[2, 2, 2]) # [[{1, 3}, {2}], [{4, 5}, {5, 6, 7, 8}]] b = collections.OrderedDict([ @@ -316,7 +321,8 @@ def set_union(a, b, validate_indices=True): ((1, 1, 2), 7), ((1, 1, 3), 8), ]) - b = tf.SparseTensor(list(b.keys()), list(b.values()), dense_shape=[2, 2, 4]) + b = tf.sparse.SparseTensor(list(b.keys()), list(b.values()), + dense_shape=[2, 2, 4]) # `set_union` is applied to each aligned pair of sets. tf.sets.union(a, b) diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py index 2a6db8dd432..5096b332364 100644 --- a/tensorflow/python/ops/sparse_ops.py +++ b/tensorflow/python/ops/sparse_ops.py @@ -16,7 +16,7 @@ # pylint: disable=g-short-docstring-punctuation """Sparse Tensor Representation. -See also `tf.SparseTensor`. +See also `tf.sparse.SparseTensor`. """ from __future__ import absolute_import @@ -2510,7 +2510,7 @@ def sparse_softmax(sp_input, name=None): values = np.asarray([[[0., np.e], [1., 0.]], [[np.e, 0.], [np.e, np.e]]]) indices = np.vstack(np.where(values)).astype(np.int64).T - result = tf.sparse.softmax(tf.SparseTensor(indices, values, shape)) + result = tf.sparse.softmax(tf.sparse.SparseTensor(indices, values, shape)) # ...returning a 3-D SparseTensor, equivalent to: # [? 1.] [1 ?] # [1. ? ] and [.5 .5] @@ -2644,8 +2644,8 @@ def sparse_transpose(sp_input, perm=None, name=None): """ with ops.name_scope(name, "SparseTranspose", [sp_input]) as name: if perm is None: - if sp_input.shape.is_fully_defined(): - rank = len(sp_input.shape) + if sp_input.shape.rank is not None: + rank = sp_input.shape.rank perm = (rank - 1) - np.arange(0, rank, 1) else: rank = array_ops.rank(sp_input) diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc index ee43e6e1d43..6891c0411df 100644 --- a/tensorflow/python/tfe_wrapper.cc +++ b/tensorflow/python/tfe_wrapper.cc @@ -503,7 +503,10 @@ PYBIND11_MODULE(_pywrap_tfe, m) { m.def("TFE_ExecutorWaitForAllPendingNodes", [](TFE_Executor& exc) { tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus()); + // NOTE: release Python GIL for pending PyFunc ops to be executed properly. + Py_BEGIN_ALLOW_THREADS; TFE_ExecutorWaitForAllPendingNodes(&exc, status.get()); + Py_END_ALLOW_THREADS; tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); }); m.def("TFE_ExecutorClearError", &TFE_ExecutorClearError); diff --git a/tensorflow/python/tools/BUILD b/tensorflow/python/tools/BUILD index 7c8a4760081..049594ead90 100644 --- a/tensorflow/python/tools/BUILD +++ b/tensorflow/python/tools/BUILD @@ -26,6 +26,7 @@ py_library( "//tensorflow/tools/compatibility:tf_upgrade_v2", ], deps = [ + ":saved_model_aot_compile", ":saved_model_utils", # The following py_library are needed because # py_binary may not depend on them when --define=no_tensorflow_py_deps=true @@ -135,6 +136,7 @@ py_test( "//tensorflow/python:math_ops", "//tensorflow/python:training", "//tensorflow/python:variables", + "@absl_py//absl/testing:parameterized", ], ) @@ -322,6 +324,10 @@ py_library( srcs = ["saved_model_cli.py"], srcs_version = "PY2AND3", deps = [ + # Note: if you make any changes here, make corresponding changes to the + # deps of the "tools_pip" target in this file. Otherwise release builds + # (built with --define=no_tensorflow_py_deps=true) may end up with a + # broken saved_model_cli. ":saved_model_aot_compile", ":saved_model_utils", "//tensorflow/python", diff --git a/tensorflow/python/tools/freeze_graph.py b/tensorflow/python/tools/freeze_graph.py index f1372f612c0..561e998f6c3 100644 --- a/tensorflow/python/tools/freeze_graph.py +++ b/tensorflow/python/tools/freeze_graph.py @@ -357,7 +357,7 @@ def freeze_graph(input_graph, variable_names_blacklist, input_meta_graph_def, input_saved_model_dir, - saved_model_tags.replace(" ", "").split(","), + [tag for tag in saved_model_tags.replace(" ", "").split(",") if tag], checkpoint_version=checkpoint_version) diff --git a/tensorflow/python/tools/freeze_graph_test.py b/tensorflow/python/tools/freeze_graph_test.py index a27058655ad..7c9ecc3b2e2 100644 --- a/tensorflow/python/tools/freeze_graph_test.py +++ b/tensorflow/python/tools/freeze_graph_test.py @@ -21,6 +21,8 @@ from __future__ import print_function import os import re +from absl.testing import parameterized + from tensorflow.core.example import example_pb2 from tensorflow.core.framework import graph_pb2 from tensorflow.core.protobuf import saver_pb2 @@ -46,7 +48,7 @@ from tensorflow.python.tools import freeze_graph from tensorflow.python.training import saver as saver_lib -class FreezeGraphTest(test_util.TensorFlowTestCase): +class FreezeGraphTest(test_util.TensorFlowTestCase, parameterized.TestCase): def _testFreezeGraph(self, saver_write_version): @@ -124,7 +126,7 @@ class FreezeGraphTest(test_util.TensorFlowTestCase): feature_value]) return example.SerializeToString() - def _writeDummySavedModel(self, path, feature_name): + def _writeDummySavedModel(self, path, feature_name, tags): """Writes a classifier with two input features to the given path.""" with ops.Graph().as_default(): examples = array_ops.placeholder(dtypes.string, name="input_node") @@ -151,11 +153,12 @@ class FreezeGraphTest(test_util.TensorFlowTestCase): builder = saved_model_builder.SavedModelBuilder(path) builder.add_meta_graph_and_variables( sess, - [tag_constants.SERVING], + tags, signature_def_map={ signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature, - },) + }, + ) builder.save(as_text=True) @test_util.run_v1_only("b/120545219") @@ -218,11 +221,14 @@ class FreezeGraphTest(test_util.TensorFlowTestCase): output = sess.run(output_node) self.assertNear(2.0, output, 0.00001) - def testFreezeSavedModel(self): + @parameterized.named_parameters( + ("empty_tags_set", "", []), + ("default_tags_set", tag_constants.SERVING, [tag_constants.SERVING])) + def testFreezeSavedModel(self, tags_string, tags_list): tmp_dir = self.get_temp_dir() saved_model_dir = os.path.join(tmp_dir, "saved_model_dir") feature_name = "feature" - self._writeDummySavedModel(saved_model_dir, feature_name) + self._writeDummySavedModel(saved_model_dir, feature_name, tags_list) output_graph_filename = os.path.join(tmp_dir, "output_graph.pb") input_saved_model_dir = saved_model_dir @@ -235,7 +241,7 @@ class FreezeGraphTest(test_util.TensorFlowTestCase): input_meta_graph = False checkpoint_path = None input_graph_filename = None - saved_model_tags = tag_constants.SERVING + saved_model_tags = tags_string freeze_graph.freeze_graph(input_graph_filename, input_saver_def_path, input_binary, checkpoint_path, output_node_names, diff --git a/tensorflow/python/tpu/tpu_embedding.py b/tensorflow/python/tpu/tpu_embedding.py index e24188eaf16..9ced7734ece 100644 --- a/tensorflow/python/tpu/tpu_embedding.py +++ b/tensorflow/python/tpu/tpu_embedding.py @@ -1392,16 +1392,20 @@ class TPUEmbedding(object): 'table_ids': [], 'max_sequence_lengths': [], } + int_zeros = array_ops.zeros((0,), dtype=dtypes.int64) + float_zeros = array_ops.zeros((0,), dtype=dtypes.float32) for table_id, table in enumerate(self._table_to_features_dict): features = self._table_to_features_dict[table] for feature in features: enqueue_data = enqueue_datas[feature] - kwargs['sample_splits'].append(enqueue_data.sample_splits) + kwargs['sample_splits'].append( + enqueue_data.sample_splits + if enqueue_data.sample_splits is not None else int_zeros) kwargs['aggregation_weights'].append( - enqueue_data.aggregation_weights if enqueue_data.aggregation_weights - is not None else array_ops.zeros((0,), dtype=dtypes.float32)) + enqueue_data.aggregation_weights + if enqueue_data.aggregation_weights is not None else float_zeros) kwargs['embedding_indices'].append(enqueue_data.embedding_indices) diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py index d215fb632b3..afefa502593 100644 --- a/tensorflow/python/util/nest.py +++ b/tensorflow/python/util/nest.py @@ -310,8 +310,9 @@ def flatten(structure, expand_composites=False): Args: structure: an arbitrarily nested structure. Note, numpy arrays are considered atoms and are not flattened. - expand_composites: If true, then composite tensors such as tf.SparseTensor - and tf.RaggedTensor are expanded into their component tensors. + expand_composites: If true, then composite tensors such as + `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their + component tensors. Returns: A Python list, the flattened version of the input. @@ -357,15 +358,16 @@ def assert_same_structure(nest1, nest2, check_types=True, nest1: an arbitrarily nested structure. nest2: an arbitrarily nested structure. check_types: if `True` (default) types of sequences are checked as well, - including the keys of dictionaries. If set to `False`, for example a - list and a tuple of objects will look the same if they have the same - size. Note that namedtuples with identical name and fields are always - considered to have the same shallow structure. Two types will also be - considered the same if they are both list subtypes (which allows "list" - and "_ListWrapper" from trackable dependency tracking to compare - equal). - expand_composites: If true, then composite tensors such as `tf.SparseTensor` - and `tf.RaggedTensor` are expanded into their component tensors. + including the keys of dictionaries. If set to `False`, for example a + list and a tuple of objects will look the same if they have the same + size. Note that namedtuples with identical name and fields are always + considered to have the same shallow structure. Two types will also be + considered the same if they are both list subtypes (which allows "list" + and "_ListWrapper" from trackable dependency tracking to compare + equal). + expand_composites: If true, then composite tensors such as + `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their + component tensors. Raises: ValueError: If the two structures do not have the same number of elements or @@ -534,11 +536,12 @@ def pack_sequence_as(structure, flat_sequence, expand_composites=False): Args: structure: Nested structure, whose structure is given by nested lists, - tuples, and dicts. Note: numpy arrays and strings are considered - scalars. + tuples, and dicts. Note: numpy arrays and strings are considered + scalars. flat_sequence: flat sequence to pack. - expand_composites: If true, then composite tensors such as `tf.SparseTensor` - and `tf.RaggedTensor` are expanded into their component tensors. + expand_composites: If true, then composite tensors such as + `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their + component tensors. Returns: packed: `flat_sequence` converted to have the same recursive structure as @@ -574,9 +577,9 @@ def map_structure(func, *structure, **kwargs): Note that namedtuples with identical name and fields are always considered to have the same shallow structure. * `expand_composites`: If set to `True`, then composite tensors such - as `tf.SparseTensor` and `tf.RaggedTensor` are expanded into their - component tensors. If `False` (the default), then composite tensors - are not expanded. + as `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into + their component tensors. If `False` (the default), then composite + tensors are not expanded. Returns: A new structure with the same arity as `structure`, whose values correspond @@ -762,8 +765,9 @@ def assert_shallow_structure(shallow_tree, `input_tree` have to be the same. Note that even with check_types==True, this function will consider two different namedtuple classes with the same name and _fields attribute to be the same class. - expand_composites: If true, then composite tensors such as tf.SparseTensor - and tf.RaggedTensor are expanded into their component tensors. + expand_composites: If true, then composite tensors such as + `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their + component tensors. Raises: TypeError: If `shallow_tree` is a sequence but `input_tree` is not. TypeError: If the sequence types of `shallow_tree` are different from @@ -911,8 +915,9 @@ def flatten_up_to(shallow_tree, input_tree, check_types=True, Note, numpy arrays are considered scalars. check_types: bool. If True, check that each node in shallow_tree has the same type as the corresponding node in input_tree. - expand_composites: If true, then composite tensors such as tf.SparseTensor - and tf.RaggedTensor are expanded into their component tensors. + expand_composites: If true, then composite tensors such as + `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their + component tensors. Returns: A Python list, the partially flattened version of `input_tree` according to @@ -1015,8 +1020,9 @@ def flatten_with_tuple_paths_up_to(shallow_tree, Note, numpy arrays are considered scalars. check_types: bool. If True, check that each node in shallow_tree has the same type as the corresponding node in input_tree. - expand_composites: If true, then composite tensors such as tf.SparseTensor - and tf.RaggedTensor are expanded into their component tensors. + expand_composites: If true, then composite tensors such as + `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their + component tensors. Returns: A Python list, the partially flattened version of `input_tree` according to @@ -1233,8 +1239,9 @@ def get_traverse_shallow_structure(traverse_fn, structure, shallow structure of the same type, describing which parts of the substructure to traverse. structure: The structure to traverse. - expand_composites: If true, then composite tensors such as tf.SparseTensor - and tf.RaggedTensor are expanded into their component tensors. + expand_composites: If true, then composite tensors such as + `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their + component tensors. Returns: A shallow structure containing python bools, which can be passed to @@ -1313,12 +1320,13 @@ def yield_flat_paths(nest, expand_composites=False): Args: nest: the value to produce a flattened paths list for. - expand_composites: If true, then composite tensors such as tf.SparseTensor - and tf.RaggedTensor are expanded into their component tensors. + expand_composites: If true, then composite tensors such as + `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their + component tensors. Yields: Tuples containing index or key values which form the path to a specific - leaf value in the nested structure. + leaf value in the nested structure. """ is_seq = is_sequence_or_composite if expand_composites else is_sequence for k, _ in _yield_flat_up_to(nest, nest, is_seq): @@ -1338,8 +1346,9 @@ def flatten_with_joined_string_paths(structure, separator="/", structure: the nested structure to flatten. separator: string to separate levels of hierarchy in the results, defaults to '/'. - expand_composites: If true, then composite tensors such as tf.SparseTensor - and tf.RaggedTensor are expanded into their component tensors. + expand_composites: If true, then composite tensors such as + `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their + component tensors. Returns: A list of (string, data element) tuples. @@ -1362,8 +1371,9 @@ def flatten_with_tuple_paths(structure, expand_composites=False): Args: structure: the nested structure to flatten. - expand_composites: If true, then composite tensors such as tf.SparseTensor - and tf.RaggedTensor are expanded into their component tensors. + expand_composites: If true, then composite tensors such as + `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their + component tensors. Returns: A list of `(tuple_path, leaf_element)` tuples. Each `tuple_path` is a tuple diff --git a/tensorflow/python/util/util.h b/tensorflow/python/util/util.h index 405927dcc81..23438b43c53 100644 --- a/tensorflow/python/util/util.h +++ b/tensorflow/python/util/util.h @@ -234,7 +234,7 @@ PyObject* AssertSameStructure(PyObject* o1, PyObject* o2, bool check_types, // nest: an arbitrarily nested structure or a scalar object. Note, numpy // arrays are considered scalars. // expand_composites: If true, then composite tensors (such as -// `tf.SparseTensor` and `tf.RaggedTensor` are flattened into their +// `tf.sparse.SparseTensor` and `tf.RaggedTensor` are flattened into their // component tensors. // // Returns: diff --git a/tensorflow/python/util/util_wrapper.cc b/tensorflow/python/util/util_wrapper.cc index 6df78f5db44..50ea922ef52 100644 --- a/tensorflow/python/util/util_wrapper.cc +++ b/tensorflow/python/util/util_wrapper.cc @@ -244,7 +244,7 @@ PYBIND11_MODULE(_pywrap_utils, m) { Args: nest: an arbitrarily nested structure or a scalar object. Note, numpy arrays are considered scalars. - expand_composites: If true, then composite tensors such as `tf.SparseTensor` + expand_composites: If true, then composite tensors such as `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their component tensors. Returns: diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 29b02069209..5e18203844e 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -2803,5 +2803,8 @@ def filegroup_as_file(name, dep, visibility = []): visibility = visibility, ) +def tf_grpc_dependency(): + return "//tensorflow:grpc" + def tf_grpc_cc_dependency(): return "//tensorflow:grpc++" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.lite.-t-f-lite-converter.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.lite.-t-f-lite-converter.pbtxt index db76bb3f4b3..0c43fc556aa 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.lite.-t-f-lite-converter.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.lite.-t-f-lite-converter.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'graph_def\', \'input_tensors\', \'output_tensors\', \'input_arrays_with_shape\', \'output_arrays\', \'experimental_debug_info_func\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'graph_def\', \'input_tensors\', \'output_tensors\', \'input_arrays_with_shape\', \'output_arrays\', \'experimental_debug_info_func\', \'saved_model_dir\', \'saved_model_tags\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "convert" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.lite.-t-f-lite-converter.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.lite.-t-f-lite-converter.pbtxt index 63a6667c0b2..c575283b74d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.lite.-t-f-lite-converter.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.lite.-t-f-lite-converter.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'funcs\', \'trackable_obj\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'funcs\', \'trackable_obj\', \'saved_model_dir\', \'saved_model_tags\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " } member_method { name: "convert" diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh index 4f7ce00eb53..c49ba608fc0 100755 --- a/tensorflow/tools/ci_build/ci_sanity.sh +++ b/tensorflow/tools/ci_build/ci_sanity.sh @@ -366,6 +366,7 @@ do_external_licenses_check(){ -e "@com_github_googlecloudplatform_google_cloud_cpp//google" \ -e "@com_github_grpc_grpc//src/compiler" \ -e "@platforms//os" \ + -e "@ruy//" \ -v ${MISSING_LICENSES_FILE} > temp.txt mv temp.txt ${MISSING_LICENSES_FILE} @@ -383,6 +384,7 @@ do_external_licenses_check(){ -e "@com_github_googlecloudplatform_google_cloud_cpp//" \ -e "@embedded_jdk//" \ -e "^//$" \ + -e "@ruy//" \ -v ${EXTRA_LICENSES_FILE} > temp.txt mv temp.txt ${EXTRA_LICENSES_FILE} diff --git a/tensorflow/tools/compatibility/ast_edits.py b/tensorflow/tools/compatibility/ast_edits.py index fa1e8def53d..b5d51d2742f 100644 --- a/tensorflow/tools/compatibility/ast_edits.py +++ b/tensorflow/tools/compatibility/ast_edits.py @@ -213,8 +213,8 @@ class APIChangeSpec(object): """ def preprocess(self, root_node): # pylint: disable=unused-argument - """Preprocess a parse tree. Return any produced logs and errors.""" - return [], [] + """Preprocess a parse tree. Return a preprocessed node, logs and errors.""" + return root_node, [], [] def clear_preprocessing(self): """Restore this APIChangeSpec to before it preprocessed a file. @@ -942,7 +942,7 @@ class ASTCodeUpgrader(object): log = ["ERROR: Failed to parse.\n" + traceback.format_exc()] return 0, "", log, [] - preprocess_logs, preprocess_errors = self._api_change_spec.preprocess(t) + t, preprocess_logs, preprocess_errors = self._api_change_spec.preprocess(t) visitor = _PastaEditVisitor(self._api_change_spec) visitor.visit(t) diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2.py b/tensorflow/tools/compatibility/tf_upgrade_v2.py index 3b3feff1b58..d27c75fb44e 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_v2.py +++ b/tensorflow/tools/compatibility/tf_upgrade_v2.py @@ -54,21 +54,47 @@ class VersionedTFImport(ast_edits.AnalysisResult): "` was directly imported as `tf`.") +compat_v1_import = VersionedTFImport("compat.v1") +compat_v2_import = VersionedTFImport("compat.v2") + + class TFAPIImportAnalysisSpec(ast_edits.APIAnalysisSpec): def __init__(self): self.symbols_to_detect = {} self.imports_to_detect = { ("tensorflow", None): UnaliasedTFImport(), - ("tensorflow.compat.v1", "tf"): VersionedTFImport("compat.v1"), - ("tensorflow.compat.v2", "tf"): VersionedTFImport("compat.v2"), + ("tensorflow.compat.v1", "tf"): compat_v1_import, + ("tensorflow.compat.v2", "tf"): compat_v2_import, } +class CompatV1ImportReplacer(ast.NodeVisitor): + """AST Visitor that replaces `import tensorflow.compat.v1 as tf`. + + Converts `import tensorflow.compat.v1 as tf` to `import tensorflow as tf` + """ + + def visit_Import(self, node): # pylint: disable=invalid-name + """Handle visiting an import node in the AST. + + Args: + node: Current Node + """ + for import_alias in node.names: + # Detect based on full import name and alias + if (import_alias.name == "tensorflow.compat.v1" and + import_alias.asname == "tf"): + import_alias.name = "tensorflow" + self.generic_visit(node) + + class TFAPIChangeSpec(ast_edits.NoUpdateSpec): """List of maps that describe what changed in the API.""" - def __init__(self, import_rename=False): + def __init__(self, import_rename=False, upgrade_compat_v1_import=False): + self.upgrade_compat_v1_import = upgrade_compat_v1_import + # Maps from a function name to a dictionary that describes how to # map from an old argument keyword to the new argument keyword. # If the new argument is None, it will be removed. @@ -1612,10 +1638,21 @@ class TFAPIChangeSpec(ast_edits.NoUpdateSpec): self.module_deprecations = module_deprecations_v2.MODULE_DEPRECATIONS - def preprocess(self, root_node): + def preprocess(self, root_node, after_compat_v1_upgrade=False): visitor = ast_edits.PastaAnalyzeVisitor(TFAPIImportAnalysisSpec()) visitor.visit(root_node) detections = set(visitor.results) + + # Upgrade explicit compat v1 imports if `upgrade_compat_v1_import` is + # enabled. Then preprocess the updated root node. + # We only do this upgrading once, because some forms of the import may + # still cause errors but aren't trivially upgradeable, and we don't want + # to enter an infinite loop. E.g. `from tensorflow.compat import v1, v2`. + if (compat_v1_import in detections and self.upgrade_compat_v1_import and + not after_compat_v1_upgrade): + CompatV1ImportReplacer().visit(root_node) + return self.preprocess(root_node, after_compat_v1_upgrade=True) + # If we have detected the presence of imports of specific TF versions, # We want to modify the update spec to check only module deprecations # and skip all other conversions. @@ -1629,7 +1666,7 @@ class TFAPIChangeSpec(ast_edits.NoUpdateSpec): self.module_deprecations = module_deprecations_v2.MODULE_DEPRECATIONS self.function_transformers = {} self.import_renames = {} - return visitor.log, visitor.warnings_and_errors + return root_node, visitor.log, visitor.warnings_and_errors def clear_preprocessing(self): self.__init__() diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2_main.py b/tensorflow/tools/compatibility/tf_upgrade_v2_main.py index 7dcbfe19c39..7c7461c19da 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_v2_main.py +++ b/tensorflow/tools/compatibility/tf_upgrade_v2_main.py @@ -101,7 +101,15 @@ Simple usage: parser.add_argument( "--no_import_rename", dest="no_import_rename", - help=("Not to rename import to compact.v2 explicitly."), + help=("Not to rename import to compat.v2 explicitly."), + action="store_true") + parser.add_argument( + "--no_upgrade_compat_v1_import", + dest="no_upgrade_compat_v1_import", + help=("If specified, don't upgrade explicit imports of " + "`tensorflow.compat.v1 as tf` to the v2 apis. Otherwise, " + "explicit imports of the form `tensorflow.compat.v1 as tf` will " + "be upgraded."), action="store_true") parser.add_argument( "--reportfile", @@ -132,10 +140,13 @@ Simple usage: change_spec = tf_upgrade_v2_safety.TFAPIChangeSpec() else: if args.no_import_rename: - change_spec = tf_upgrade_v2.TFAPIChangeSpec(import_rename=False) + change_spec = tf_upgrade_v2.TFAPIChangeSpec( + import_rename=False, + upgrade_compat_v1_import=not args.no_upgrade_compat_v1_import) else: change_spec = tf_upgrade_v2.TFAPIChangeSpec( - import_rename=_IMPORT_RENAME_DEFAULT) + import_rename=_IMPORT_RENAME_DEFAULT, + upgrade_compat_v1_import=not args.no_upgrade_compat_v1_import) upgrade = ast_edits.ASTCodeUpgrader(change_spec) report_text = None diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py index 059ed26889f..47b9899a6b7 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py +++ b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py @@ -117,11 +117,15 @@ class TestUpgrade(test_util.TensorFlowTestCase, parameterized.TestCase): visitor.private_map["tf.compat"] = ["v1", "v2"] traverse.traverse(tf.compat.v1, visitor) - def _upgrade(self, old_file_text, import_rename=False): + def _upgrade(self, + old_file_text, + import_rename=False, + upgrade_compat_v1_import=False): in_file = six.StringIO(old_file_text) out_file = six.StringIO() upgrader = ast_edits.ASTCodeUpgrader( - tf_upgrade_v2.TFAPIChangeSpec(import_rename)) + tf_upgrade_v2.TFAPIChangeSpec( + import_rename, upgrade_compat_v1_import=upgrade_compat_v1_import)) count, report, errors = ( upgrader.process_opened_file("test.py", in_file, "test_out.py", out_file)) @@ -2215,6 +2219,30 @@ def _log_prob(self, x): _, _, _, new_text = self._upgrade(text, import_rename=True) self.assertEqual(new_text, expected_text) + import_header = ("import tensorflow.compat.v1 as tf\n" + "import tensorflow.compat.v1 as tf_v1\n" + "import tensorflow.compat.v2 as tf_v2\n") + text = import_header + old_symbol + expected_header = ("import tensorflow.compat.v2 as tf\n" + "import tensorflow.compat.v1 as tf_v1\n" + "import tensorflow.compat.v2 as tf_v2\n") + expected_text = expected_header + new_symbol + _, _, _, new_text = self._upgrade( + text, import_rename=True, upgrade_compat_v1_import=True) + self.assertEqual(new_text, expected_text) + + import_header = ("import tensorflow.compat.v1 as tf\n" + "import tensorflow.compat.v1 as tf_v1\n" + "import tensorflow.compat.v2 as tf_v2\n") + text = import_header + old_symbol + expected_header = ("import tensorflow as tf\n" + "import tensorflow.compat.v1 as tf_v1\n" + "import tensorflow.compat.v2 as tf_v2\n") + expected_text = expected_header + new_symbol + _, _, _, new_text = self._upgrade( + text, import_rename=False, upgrade_compat_v1_import=True) + self.assertEqual(new_text, expected_text) + import_header = "from tensorflow import foo\n" text = import_header + old_symbol expected_text = "from tensorflow.compat.v2 import foo\n" + new_symbol diff --git a/tensorflow/tools/dockerfiles/dockerfiles/cpu-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/cpu-jupyter.Dockerfile index d07d701b9b7..49b635e537b 100644 --- a/tensorflow/tools/dockerfiles/dockerfiles/cpu-jupyter.Dockerfile +++ b/tensorflow/tools/dockerfiles/dockerfiles/cpu-jupyter.Dockerfile @@ -25,25 +25,19 @@ FROM ubuntu:${UBUNTU_VERSION} as base RUN apt-get update && apt-get install -y curl -ARG USE_PYTHON_3_NOT_2 -# TODO(angerson) Completely remove Python 2 support -ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} -ARG PYTHON=python${_PY_SUFFIX} -ARG PIP=pip${_PY_SUFFIX} - # See http://bugs.python.org/issue19846 ENV LANG C.UTF-8 RUN apt-get update && apt-get install -y \ - ${PYTHON} \ - ${PYTHON}-pip + python3 + python3-pip -RUN ${PIP} --no-cache-dir install --upgrade \ +RUN python3 -m pip --no-cache-dir install --upgrade \ pip \ setuptools # Some TF tools expect a "python" binary -RUN ln -s $(which ${PYTHON}) /usr/local/bin/python +RUN ln -s $(which python3) /usr/local/bin/python # Options: # tensorflow @@ -59,10 +53,9 @@ RUN ${PIP} install --no-cache-dir ${TF_PACKAGE}${TF_PACKAGE_VERSION:+==${TF_PACK COPY bashrc /etc/bash.bashrc RUN chmod a+rwx /etc/bash.bashrc -RUN ${PIP} install jupyter matplotlib +RUN python3 -m pip install jupyter matplotlib # Pin ipykernel and nbformat; see https://github.com/ipython/ipykernel/issues/422 -RUN if [[ "${USE_PYTHON_3_NOT_2}" == "1" ]]; then ${PIP} install ipykernel==5.1.1 nbformat==4.4.0; fi -RUN ${PIP} install jupyter_http_over_ws +RUN python3 -m pip install jupyter_http_over_ws ipykernel==5.1.1 nbformat==4.4.0 RUN jupyter serverextension enable --py jupyter_http_over_ws RUN mkdir -p /tf/tensorflow-tutorials && chmod -R a+rwx /tf/ @@ -82,6 +75,6 @@ RUN apt-get autoremove -y && apt-get remove -y wget WORKDIR /tf EXPOSE 8888 -RUN ${PYTHON} -m ipykernel.kernelspec +RUN python3 -m ipykernel.kernelspec CMD ["bash", "-c", "source /etc/bash.bashrc && jupyter notebook --notebook-dir=/tf --ip 0.0.0.0 --no-browser --allow-root"] diff --git a/tensorflow/tools/dockerfiles/dockerfiles/cpu.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/cpu.Dockerfile index 6e8b09de45a..ec02e8cab6d 100644 --- a/tensorflow/tools/dockerfiles/dockerfiles/cpu.Dockerfile +++ b/tensorflow/tools/dockerfiles/dockerfiles/cpu.Dockerfile @@ -25,25 +25,19 @@ FROM ubuntu:${UBUNTU_VERSION} as base RUN apt-get update && apt-get install -y curl -ARG USE_PYTHON_3_NOT_2 -# TODO(angerson) Completely remove Python 2 support -ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} -ARG PYTHON=python${_PY_SUFFIX} -ARG PIP=pip${_PY_SUFFIX} - # See http://bugs.python.org/issue19846 ENV LANG C.UTF-8 RUN apt-get update && apt-get install -y \ - ${PYTHON} \ - ${PYTHON}-pip + python3 + python3-pip -RUN ${PIP} --no-cache-dir install --upgrade \ +RUN python3 -m pip --no-cache-dir install --upgrade \ pip \ setuptools # Some TF tools expect a "python" binary -RUN ln -s $(which ${PYTHON}) /usr/local/bin/python +RUN ln -s $(which python3) /usr/local/bin/python # Options: # tensorflow diff --git a/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu-jupyter.Dockerfile index ea38e538945..b61ad3af8c2 100644 --- a/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu-jupyter.Dockerfile +++ b/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu-jupyter.Dockerfile @@ -54,25 +54,19 @@ ARG CHECKOUT_TF_SRC=0 RUN chmod a+w /etc/passwd /etc/group RUN test "${CHECKOUT_TF_SRC}" -eq 1 && git clone https://github.com/tensorflow/tensorflow.git /tensorflow_src || true -ARG USE_PYTHON_3_NOT_2 -# TODO(angerson) Completely remove Python 2 support -ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} -ARG PYTHON=python${_PY_SUFFIX} -ARG PIP=pip${_PY_SUFFIX} - # See http://bugs.python.org/issue19846 ENV LANG C.UTF-8 RUN apt-get update && apt-get install -y \ - ${PYTHON} \ - ${PYTHON}-pip + python3 + python3-pip -RUN ${PIP} --no-cache-dir install --upgrade \ +RUN python3 -m pip --no-cache-dir install --upgrade \ pip \ setuptools # Some TF tools expect a "python" binary -RUN ln -s $(which ${PYTHON}) /usr/local/bin/python +RUN ln -s $(which python3) /usr/local/bin/python RUN apt-get update && apt-get install -y \ build-essential \ @@ -80,7 +74,7 @@ RUN apt-get update && apt-get install -y \ git \ wget \ openjdk-8-jdk \ - ${PYTHON}-dev \ + python3-dev \ virtualenv \ swig @@ -96,7 +90,6 @@ RUN ${PIP} --no-cache-dir install \ pandas \ future \ portpicker \ - && test "${USE_PYTHON_3_NOT_2}" -eq 1 && true || ${PIP} --no-cache-dir install \ enum34 # Install bazel @@ -111,10 +104,9 @@ RUN mkdir /bazel && \ COPY bashrc /etc/bash.bashrc RUN chmod a+rwx /etc/bash.bashrc -RUN ${PIP} install jupyter matplotlib +RUN python3 -m pip install jupyter matplotlib # Pin ipykernel and nbformat; see https://github.com/ipython/ipykernel/issues/422 -RUN if [[ "${USE_PYTHON_3_NOT_2}" == "1" ]]; then ${PIP} install ipykernel==5.1.1 nbformat==4.4.0; fi -RUN ${PIP} install jupyter_http_over_ws +RUN python3 -m pip install jupyter_http_over_ws ipykernel==5.1.1 nbformat==4.4.0 RUN jupyter serverextension enable --py jupyter_http_over_ws RUN mkdir -p /tf/tensorflow-tutorials && chmod -R a+rwx /tf/ @@ -134,6 +126,6 @@ RUN apt-get autoremove -y && apt-get remove -y wget WORKDIR /tf EXPOSE 8888 -RUN ${PYTHON} -m ipykernel.kernelspec +RUN python3 -m ipykernel.kernelspec CMD ["bash", "-c", "source /etc/bash.bashrc && jupyter notebook --notebook-dir=/tf --ip 0.0.0.0 --no-browser --allow-root"] diff --git a/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu.Dockerfile index 6c731f24b8d..beed6081721 100644 --- a/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu.Dockerfile +++ b/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu.Dockerfile @@ -54,25 +54,19 @@ ARG CHECKOUT_TF_SRC=0 RUN chmod a+w /etc/passwd /etc/group RUN test "${CHECKOUT_TF_SRC}" -eq 1 && git clone https://github.com/tensorflow/tensorflow.git /tensorflow_src || true -ARG USE_PYTHON_3_NOT_2 -# TODO(angerson) Completely remove Python 2 support -ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} -ARG PYTHON=python${_PY_SUFFIX} -ARG PIP=pip${_PY_SUFFIX} - # See http://bugs.python.org/issue19846 ENV LANG C.UTF-8 RUN apt-get update && apt-get install -y \ - ${PYTHON} \ - ${PYTHON}-pip + python3 + python3-pip -RUN ${PIP} --no-cache-dir install --upgrade \ +RUN python3 -m pip --no-cache-dir install --upgrade \ pip \ setuptools # Some TF tools expect a "python" binary -RUN ln -s $(which ${PYTHON}) /usr/local/bin/python +RUN ln -s $(which python3) /usr/local/bin/python RUN apt-get update && apt-get install -y \ build-essential \ @@ -80,7 +74,7 @@ RUN apt-get update && apt-get install -y \ git \ wget \ openjdk-8-jdk \ - ${PYTHON}-dev \ + python3-dev \ virtualenv \ swig @@ -96,7 +90,6 @@ RUN ${PIP} --no-cache-dir install \ pandas \ future \ portpicker \ - && test "${USE_PYTHON_3_NOT_2}" -eq 1 && true || ${PIP} --no-cache-dir install \ enum34 # Install bazel diff --git a/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu-jupyter.Dockerfile index 5644d536751..58f45fb22b8 100644 --- a/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu-jupyter.Dockerfile +++ b/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu-jupyter.Dockerfile @@ -96,25 +96,19 @@ RUN ln -s /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/lib && echo "/usr/local/cuda/lib64/stubs" > /etc/ld.so.conf.d/z-cuda-stubs.conf \ && ldconfig -ARG USE_PYTHON_3_NOT_2 -# TODO(angerson) Completely remove Python 2 support -ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} -ARG PYTHON=python${_PY_SUFFIX} -ARG PIP=pip${_PY_SUFFIX} - # See http://bugs.python.org/issue19846 ENV LANG C.UTF-8 RUN apt-get update && apt-get install -y \ - ${PYTHON} \ - ${PYTHON}-pip + python3 + python3-pip -RUN ${PIP} --no-cache-dir install --upgrade \ +RUN python3 -m pip --no-cache-dir install --upgrade \ pip \ setuptools # Some TF tools expect a "python" binary -RUN ln -s $(which ${PYTHON}) /usr/local/bin/python +RUN ln -s $(which python3) /usr/local/bin/python RUN apt-get update && apt-get install -y \ build-essential \ @@ -122,7 +116,7 @@ RUN apt-get update && apt-get install -y \ git \ wget \ openjdk-8-jdk \ - ${PYTHON}-dev \ + python3-dev \ virtualenv \ swig @@ -138,7 +132,6 @@ RUN ${PIP} --no-cache-dir install \ pandas \ future \ portpicker \ - && test "${USE_PYTHON_3_NOT_2}" -eq 1 && true || ${PIP} --no-cache-dir install \ enum34 # Install bazel @@ -153,10 +146,9 @@ RUN mkdir /bazel && \ COPY bashrc /etc/bash.bashrc RUN chmod a+rwx /etc/bash.bashrc -RUN ${PIP} install jupyter matplotlib +RUN python3 -m pip install jupyter matplotlib # Pin ipykernel and nbformat; see https://github.com/ipython/ipykernel/issues/422 -RUN if [[ "${USE_PYTHON_3_NOT_2}" == "1" ]]; then ${PIP} install ipykernel==5.1.1 nbformat==4.4.0; fi -RUN ${PIP} install jupyter_http_over_ws +RUN python3 -m pip install jupyter_http_over_ws ipykernel==5.1.1 nbformat==4.4.0 RUN jupyter serverextension enable --py jupyter_http_over_ws RUN mkdir -p /tf/tensorflow-tutorials && chmod -R a+rwx /tf/ @@ -176,6 +168,6 @@ RUN apt-get autoremove -y && apt-get remove -y wget WORKDIR /tf EXPOSE 8888 -RUN ${PYTHON} -m ipykernel.kernelspec +RUN python3 -m ipykernel.kernelspec CMD ["bash", "-c", "source /etc/bash.bashrc && jupyter notebook --notebook-dir=/tf --ip 0.0.0.0 --no-browser --allow-root"] diff --git a/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu.Dockerfile index eb75d905462..9dd3a00b729 100644 --- a/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu.Dockerfile +++ b/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu.Dockerfile @@ -96,25 +96,19 @@ RUN ln -s /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/lib && echo "/usr/local/cuda/lib64/stubs" > /etc/ld.so.conf.d/z-cuda-stubs.conf \ && ldconfig -ARG USE_PYTHON_3_NOT_2 -# TODO(angerson) Completely remove Python 2 support -ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} -ARG PYTHON=python${_PY_SUFFIX} -ARG PIP=pip${_PY_SUFFIX} - # See http://bugs.python.org/issue19846 ENV LANG C.UTF-8 RUN apt-get update && apt-get install -y \ - ${PYTHON} \ - ${PYTHON}-pip + python3 + python3-pip -RUN ${PIP} --no-cache-dir install --upgrade \ +RUN python3 -m pip --no-cache-dir install --upgrade \ pip \ setuptools # Some TF tools expect a "python" binary -RUN ln -s $(which ${PYTHON}) /usr/local/bin/python +RUN ln -s $(which python3) /usr/local/bin/python RUN apt-get update && apt-get install -y \ build-essential \ @@ -122,7 +116,7 @@ RUN apt-get update && apt-get install -y \ git \ wget \ openjdk-8-jdk \ - ${PYTHON}-dev \ + python3-dev \ virtualenv \ swig @@ -138,7 +132,6 @@ RUN ${PIP} --no-cache-dir install \ pandas \ future \ portpicker \ - && test "${USE_PYTHON_3_NOT_2}" -eq 1 && true || ${PIP} --no-cache-dir install \ enum34 # Install bazel diff --git a/tensorflow/tools/dockerfiles/dockerfiles/gpu-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/gpu-jupyter.Dockerfile index f9ec2b603cf..f644254ff57 100644 --- a/tensorflow/tools/dockerfiles/dockerfiles/gpu-jupyter.Dockerfile +++ b/tensorflow/tools/dockerfiles/dockerfiles/gpu-jupyter.Dockerfile @@ -74,25 +74,19 @@ RUN ln -s /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/lib && echo "/usr/local/cuda/lib64/stubs" > /etc/ld.so.conf.d/z-cuda-stubs.conf \ && ldconfig -ARG USE_PYTHON_3_NOT_2 -# TODO(angerson) Completely remove Python 2 support -ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} -ARG PYTHON=python${_PY_SUFFIX} -ARG PIP=pip${_PY_SUFFIX} - # See http://bugs.python.org/issue19846 ENV LANG C.UTF-8 RUN apt-get update && apt-get install -y \ - ${PYTHON} \ - ${PYTHON}-pip + python3 + python3-pip -RUN ${PIP} --no-cache-dir install --upgrade \ +RUN python3 -m pip --no-cache-dir install --upgrade \ pip \ setuptools # Some TF tools expect a "python" binary -RUN ln -s $(which ${PYTHON}) /usr/local/bin/python +RUN ln -s $(which python3) /usr/local/bin/python # Options: # tensorflow @@ -108,10 +102,9 @@ RUN ${PIP} install --no-cache-dir ${TF_PACKAGE}${TF_PACKAGE_VERSION:+==${TF_PACK COPY bashrc /etc/bash.bashrc RUN chmod a+rwx /etc/bash.bashrc -RUN ${PIP} install jupyter matplotlib +RUN python3 -m pip install jupyter matplotlib # Pin ipykernel and nbformat; see https://github.com/ipython/ipykernel/issues/422 -RUN if [[ "${USE_PYTHON_3_NOT_2}" == "1" ]]; then ${PIP} install ipykernel==5.1.1 nbformat==4.4.0; fi -RUN ${PIP} install jupyter_http_over_ws +RUN python3 -m pip install jupyter_http_over_ws ipykernel==5.1.1 nbformat==4.4.0 RUN jupyter serverextension enable --py jupyter_http_over_ws RUN mkdir -p /tf/tensorflow-tutorials && chmod -R a+rwx /tf/ @@ -131,6 +124,6 @@ RUN apt-get autoremove -y && apt-get remove -y wget WORKDIR /tf EXPOSE 8888 -RUN ${PYTHON} -m ipykernel.kernelspec +RUN python3 -m ipykernel.kernelspec CMD ["bash", "-c", "source /etc/bash.bashrc && jupyter notebook --notebook-dir=/tf --ip 0.0.0.0 --no-browser --allow-root"] diff --git a/tensorflow/tools/dockerfiles/dockerfiles/gpu.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/gpu.Dockerfile index 33fdc44626e..2f20dcd6104 100644 --- a/tensorflow/tools/dockerfiles/dockerfiles/gpu.Dockerfile +++ b/tensorflow/tools/dockerfiles/dockerfiles/gpu.Dockerfile @@ -74,25 +74,19 @@ RUN ln -s /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/lib && echo "/usr/local/cuda/lib64/stubs" > /etc/ld.so.conf.d/z-cuda-stubs.conf \ && ldconfig -ARG USE_PYTHON_3_NOT_2 -# TODO(angerson) Completely remove Python 2 support -ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} -ARG PYTHON=python${_PY_SUFFIX} -ARG PIP=pip${_PY_SUFFIX} - # See http://bugs.python.org/issue19846 ENV LANG C.UTF-8 RUN apt-get update && apt-get install -y \ - ${PYTHON} \ - ${PYTHON}-pip + python3 + python3-pip -RUN ${PIP} --no-cache-dir install --upgrade \ +RUN python3 -m pip --no-cache-dir install --upgrade \ pip \ setuptools # Some TF tools expect a "python" binary -RUN ln -s $(which ${PYTHON}) /usr/local/bin/python +RUN ln -s $(which python3) /usr/local/bin/python # Options: # tensorflow diff --git a/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/devel-horovod-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/devel-horovod-jupyter.Dockerfile index 129fc78db54..c22f391bc51 100644 --- a/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/devel-horovod-jupyter.Dockerfile +++ b/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/devel-horovod-jupyter.Dockerfile @@ -54,25 +54,19 @@ ARG CHECKOUT_TF_SRC=0 RUN chmod a+w /etc/passwd /etc/group RUN test "${CHECKOUT_TF_SRC}" -eq 1 && git clone https://github.com/tensorflow/tensorflow.git /tensorflow_src || true -ARG USE_PYTHON_3_NOT_2 -# TODO(angerson) Completely remove Python 2 support -ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} -ARG PYTHON=python${_PY_SUFFIX} -ARG PIP=pip${_PY_SUFFIX} - # See http://bugs.python.org/issue19846 ENV LANG C.UTF-8 RUN apt-get update && apt-get install -y \ - ${PYTHON} \ - ${PYTHON}-pip + python3 + python3-pip -RUN ${PIP} --no-cache-dir install --upgrade \ +RUN python3 -m pip --no-cache-dir install --upgrade \ pip \ setuptools # Some TF tools expect a "python" binary -RUN ln -s $(which ${PYTHON}) /usr/local/bin/python +RUN ln -s $(which python3) /usr/local/bin/python RUN apt-get update && apt-get install -y \ build-essential \ @@ -80,7 +74,7 @@ RUN apt-get update && apt-get install -y \ git \ wget \ openjdk-8-jdk \ - ${PYTHON}-dev \ + python3-dev \ virtualenv \ swig @@ -96,11 +90,10 @@ RUN ${PIP} --no-cache-dir install \ pandas \ future \ portpicker \ - && test "${USE_PYTHON_3_NOT_2}" -eq 1 && true || ${PIP} --no-cache-dir install \ enum34 # Install bazel -ARG BAZEL_VERSION=1.2.1 +ARG BAZEL_VERSION=2.0.0 RUN mkdir /bazel && \ wget -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \ wget -O /bazel/LICENSE.txt "https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE" && \ @@ -163,10 +156,9 @@ RUN test "${CHECKOUT_HOROVOD_SRC}" -eq 1 && git clone --recursive https://github COPY bashrc /etc/bash.bashrc RUN chmod a+rwx /etc/bash.bashrc -RUN ${PIP} install jupyter matplotlib +RUN python3 -m pip install jupyter matplotlib # Pin ipykernel and nbformat; see https://github.com/ipython/ipykernel/issues/422 -RUN if [[ "${USE_PYTHON_3_NOT_2}" == "1" ]]; then ${PIP} install ipykernel==5.1.1 nbformat==4.4.0; fi -RUN ${PIP} install jupyter_http_over_ws +RUN python3 -m pip install jupyter_http_over_ws ipykernel==5.1.1 nbformat==4.4.0 RUN jupyter serverextension enable --py jupyter_http_over_ws RUN mkdir -p /tf/tensorflow-tutorials && chmod -R a+rwx /tf/ @@ -186,6 +178,6 @@ RUN apt-get autoremove -y && apt-get remove -y wget WORKDIR /tf EXPOSE 8888 -RUN ${PYTHON} -m ipykernel.kernelspec +RUN python3 -m ipykernel.kernelspec CMD ["bash", "-c", "source /etc/bash.bashrc && jupyter notebook --notebook-dir=/tf --ip 0.0.0.0 --no-browser --allow-root"] diff --git a/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/devel-horovod.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/devel-horovod.Dockerfile index 245b8f6ee68..08525fa4132 100644 --- a/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/devel-horovod.Dockerfile +++ b/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/devel-horovod.Dockerfile @@ -54,25 +54,19 @@ ARG CHECKOUT_TF_SRC=0 RUN chmod a+w /etc/passwd /etc/group RUN test "${CHECKOUT_TF_SRC}" -eq 1 && git clone https://github.com/tensorflow/tensorflow.git /tensorflow_src || true -ARG USE_PYTHON_3_NOT_2 -# TODO(angerson) Completely remove Python 2 support -ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} -ARG PYTHON=python${_PY_SUFFIX} -ARG PIP=pip${_PY_SUFFIX} - # See http://bugs.python.org/issue19846 ENV LANG C.UTF-8 RUN apt-get update && apt-get install -y \ - ${PYTHON} \ - ${PYTHON}-pip + python3 + python3-pip -RUN ${PIP} --no-cache-dir install --upgrade \ +RUN python3 -m pip --no-cache-dir install --upgrade \ pip \ setuptools # Some TF tools expect a "python" binary -RUN ln -s $(which ${PYTHON}) /usr/local/bin/python +RUN ln -s $(which python3) /usr/local/bin/python RUN apt-get update && apt-get install -y \ build-essential \ @@ -80,7 +74,7 @@ RUN apt-get update && apt-get install -y \ git \ wget \ openjdk-8-jdk \ - ${PYTHON}-dev \ + python3-dev \ virtualenv \ swig @@ -96,11 +90,10 @@ RUN ${PIP} --no-cache-dir install \ pandas \ future \ portpicker \ - && test "${USE_PYTHON_3_NOT_2}" -eq 1 && true || ${PIP} --no-cache-dir install \ enum34 # Install bazel -ARG BAZEL_VERSION=1.2.1 +ARG BAZEL_VERSION=2.0.0 RUN mkdir /bazel && \ wget -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \ wget -O /bazel/LICENSE.txt "https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE" && \ diff --git a/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/horovod-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/horovod-jupyter.Dockerfile index a00bf8365e0..82973e3e813 100644 --- a/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/horovod-jupyter.Dockerfile +++ b/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/horovod-jupyter.Dockerfile @@ -25,25 +25,19 @@ FROM ubuntu:${UBUNTU_VERSION} as base RUN apt-get update && apt-get install -y curl -ARG USE_PYTHON_3_NOT_2 -# TODO(angerson) Completely remove Python 2 support -ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} -ARG PYTHON=python${_PY_SUFFIX} -ARG PIP=pip${_PY_SUFFIX} - # See http://bugs.python.org/issue19846 ENV LANG C.UTF-8 RUN apt-get update && apt-get install -y \ - ${PYTHON} \ - ${PYTHON}-pip + python3 + python3-pip -RUN ${PIP} --no-cache-dir install --upgrade \ +RUN python3 -m pip --no-cache-dir install --upgrade \ pip \ setuptools # Some TF tools expect a "python" binary -RUN ln -s $(which ${PYTHON}) /usr/local/bin/python +RUN ln -s $(which python3) /usr/local/bin/python # Options: # tensorflow @@ -111,10 +105,9 @@ RUN ${PIP} install --no-cache-dir horovod==${HOROVOD_VERSION} COPY bashrc /etc/bash.bashrc RUN chmod a+rwx /etc/bash.bashrc -RUN ${PIP} install jupyter matplotlib +RUN python3 -m pip install jupyter matplotlib # Pin ipykernel and nbformat; see https://github.com/ipython/ipykernel/issues/422 -RUN if [[ "${USE_PYTHON_3_NOT_2}" == "1" ]]; then ${PIP} install ipykernel==5.1.1 nbformat==4.4.0; fi -RUN ${PIP} install jupyter_http_over_ws +RUN python3 -m pip install jupyter_http_over_ws ipykernel==5.1.1 nbformat==4.4.0 RUN jupyter serverextension enable --py jupyter_http_over_ws RUN mkdir -p /tf/tensorflow-tutorials && chmod -R a+rwx /tf/ @@ -134,6 +127,6 @@ RUN apt-get autoremove -y && apt-get remove -y wget WORKDIR /tf EXPOSE 8888 -RUN ${PYTHON} -m ipykernel.kernelspec +RUN python3 -m ipykernel.kernelspec CMD ["bash", "-c", "source /etc/bash.bashrc && jupyter notebook --notebook-dir=/tf --ip 0.0.0.0 --no-browser --allow-root"] diff --git a/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/horovod.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/horovod.Dockerfile index 4208ea8ffac..9ca4c15f630 100644 --- a/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/horovod.Dockerfile +++ b/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/horovod.Dockerfile @@ -25,25 +25,19 @@ FROM ubuntu:${UBUNTU_VERSION} as base RUN apt-get update && apt-get install -y curl -ARG USE_PYTHON_3_NOT_2 -# TODO(angerson) Completely remove Python 2 support -ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} -ARG PYTHON=python${_PY_SUFFIX} -ARG PIP=pip${_PY_SUFFIX} - # See http://bugs.python.org/issue19846 ENV LANG C.UTF-8 RUN apt-get update && apt-get install -y \ - ${PYTHON} \ - ${PYTHON}-pip + python3 + python3-pip -RUN ${PIP} --no-cache-dir install --upgrade \ +RUN python3 -m pip --no-cache-dir install --upgrade \ pip \ setuptools # Some TF tools expect a "python" binary -RUN ln -s $(which ${PYTHON}) /usr/local/bin/python +RUN ln -s $(which python3) /usr/local/bin/python # Options: # tensorflow diff --git a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/cpu-ppc64le-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/cpu-ppc64le-jupyter.Dockerfile index 4dae9f50c4b..db630ec3c63 100644 --- a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/cpu-ppc64le-jupyter.Dockerfile +++ b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/cpu-ppc64le-jupyter.Dockerfile @@ -25,25 +25,19 @@ FROM ubuntu:${UBUNTU_VERSION} as base RUN apt-get update && apt-get install -y curl -ARG USE_PYTHON_3_NOT_2 -# TODO(angerson) Completely remove Python 2 support -ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} -ARG PYTHON=python${_PY_SUFFIX} -ARG PIP=pip${_PY_SUFFIX} - # See http://bugs.python.org/issue19846 ENV LANG C.UTF-8 RUN apt-get update && apt-get install -y \ - ${PYTHON} \ - ${PYTHON}-pip + python3 + python3-pip -RUN ${PIP} --no-cache-dir install --upgrade \ +RUN python3 -m pip --no-cache-dir install --upgrade \ pip \ setuptools # Some TF tools expect a "python" binary -RUN ln -s $(which ${PYTHON}) /usr/local/bin/python +RUN ln -s $(which python3) /usr/local/bin/python # Options: # tensorflow @@ -68,8 +62,8 @@ RUN if [ ${TF_PACKAGE} = tensorflow-gpu ]; then \ elif [ ${TF_PACKAGE} = tf-nightly ]; then \ BASE=https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/lastSuccessfulBuild/; \ fi; \ - MAJOR=`${PYTHON} -c 'import sys; print(sys.version_info[0])'`; \ - MINOR=`${PYTHON} -c 'import sys; print(sys.version_info[1])'`; \ + MAJOR=`python3 -c 'import sys; print(sys.version_info[0])'`; \ + MINOR=`python3 -c 'import sys; print(sys.version_info[1])'`; \ PACKAGE=$(wget -qO- ${BASE}"api/xml?xpath=//fileName&wrapper=artifacts" | grep -o "[^<>]*cp${MAJOR}${MINOR}[^<>]*.whl"); \ wget ${BASE}"artifact/tensorflow_pkg/"${PACKAGE}; \ ${PIP} install ${PACKAGE} @@ -77,10 +71,9 @@ RUN if [ ${TF_PACKAGE} = tensorflow-gpu ]; then \ COPY bashrc /etc/bash.bashrc RUN chmod a+rwx /etc/bash.bashrc -RUN ${PIP} install jupyter matplotlib +RUN python3 -m pip install jupyter matplotlib # Pin ipykernel and nbformat; see https://github.com/ipython/ipykernel/issues/422 -RUN if [[ "${USE_PYTHON_3_NOT_2}" == "1" ]]; then ${PIP} install ipykernel==5.1.1 nbformat==4.4.0; fi -RUN ${PIP} install jupyter_http_over_ws +RUN python3 -m pip install jupyter_http_over_ws ipykernel==5.1.1 nbformat==4.4.0 RUN jupyter serverextension enable --py jupyter_http_over_ws RUN mkdir -p /tf/tensorflow-tutorials && chmod -R a+rwx /tf/ @@ -100,6 +93,6 @@ RUN apt-get autoremove -y && apt-get remove -y wget WORKDIR /tf EXPOSE 8888 -RUN ${PYTHON} -m ipykernel.kernelspec +RUN python3 -m ipykernel.kernelspec CMD ["bash", "-c", "source /etc/bash.bashrc && jupyter notebook --notebook-dir=/tf --ip 0.0.0.0 --no-browser --allow-root"] diff --git a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/cpu-ppc64le.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/cpu-ppc64le.Dockerfile index e6a5184c8e7..63f39312225 100644 --- a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/cpu-ppc64le.Dockerfile +++ b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/cpu-ppc64le.Dockerfile @@ -25,25 +25,19 @@ FROM ubuntu:${UBUNTU_VERSION} as base RUN apt-get update && apt-get install -y curl -ARG USE_PYTHON_3_NOT_2 -# TODO(angerson) Completely remove Python 2 support -ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} -ARG PYTHON=python${_PY_SUFFIX} -ARG PIP=pip${_PY_SUFFIX} - # See http://bugs.python.org/issue19846 ENV LANG C.UTF-8 RUN apt-get update && apt-get install -y \ - ${PYTHON} \ - ${PYTHON}-pip + python3 + python3-pip -RUN ${PIP} --no-cache-dir install --upgrade \ +RUN python3 -m pip --no-cache-dir install --upgrade \ pip \ setuptools # Some TF tools expect a "python" binary -RUN ln -s $(which ${PYTHON}) /usr/local/bin/python +RUN ln -s $(which python3) /usr/local/bin/python # Options: # tensorflow @@ -68,8 +62,8 @@ RUN if [ ${TF_PACKAGE} = tensorflow-gpu ]; then \ elif [ ${TF_PACKAGE} = tf-nightly ]; then \ BASE=https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/lastSuccessfulBuild/; \ fi; \ - MAJOR=`${PYTHON} -c 'import sys; print(sys.version_info[0])'`; \ - MINOR=`${PYTHON} -c 'import sys; print(sys.version_info[1])'`; \ + MAJOR=`python3 -c 'import sys; print(sys.version_info[0])'`; \ + MINOR=`python3 -c 'import sys; print(sys.version_info[1])'`; \ PACKAGE=$(wget -qO- ${BASE}"api/xml?xpath=//fileName&wrapper=artifacts" | grep -o "[^<>]*cp${MAJOR}${MINOR}[^<>]*.whl"); \ wget ${BASE}"artifact/tensorflow_pkg/"${PACKAGE}; \ ${PIP} install ${PACKAGE} diff --git a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-cpu-ppc64le-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-cpu-ppc64le-jupyter.Dockerfile index 0870cfc83f9..662f28ffc7a 100644 --- a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-cpu-ppc64le-jupyter.Dockerfile +++ b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-cpu-ppc64le-jupyter.Dockerfile @@ -54,32 +54,26 @@ ARG CHECKOUT_TF_SRC=0 RUN chmod a+w /etc/passwd /etc/group RUN test "${CHECKOUT_TF_SRC}" -eq 1 && git clone https://github.com/tensorflow/tensorflow.git /tensorflow_src || true -ARG USE_PYTHON_3_NOT_2 -# TODO(angerson) Completely remove Python 2 support -ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} -ARG PYTHON=python${_PY_SUFFIX} -ARG PIP=pip${_PY_SUFFIX} - # See http://bugs.python.org/issue19846 ENV LANG C.UTF-8 RUN apt-get update && apt-get install -y \ - ${PYTHON} \ - ${PYTHON}-pip + python3 + python3-pip -RUN ${PIP} --no-cache-dir install --upgrade \ +RUN python3 -m pip --no-cache-dir install --upgrade \ pip \ setuptools # Some TF tools expect a "python" binary -RUN ln -s $(which ${PYTHON}) /usr/local/bin/python +RUN ln -s $(which python3) /usr/local/bin/python RUN apt-get update && apt-get install -y \ build-essential \ curl \ git \ openjdk-8-jdk \ - ${PYTHON}-dev \ + python3-dev \ virtualenv \ swig @@ -94,7 +88,6 @@ RUN ${PIP} --no-cache-dir install \ sklearn \ pandas \ portpicker \ - && test "${USE_PYTHON_3_NOT_2}" -eq 1 && true || ${PIP} --no-cache-dir install \ enum34 # Build and install bazel @@ -112,10 +105,9 @@ RUN mkdir /bazel && \ COPY bashrc /etc/bash.bashrc RUN chmod a+rwx /etc/bash.bashrc -RUN ${PIP} install jupyter matplotlib +RUN python3 -m pip install jupyter matplotlib # Pin ipykernel and nbformat; see https://github.com/ipython/ipykernel/issues/422 -RUN if [[ "${USE_PYTHON_3_NOT_2}" == "1" ]]; then ${PIP} install ipykernel==5.1.1 nbformat==4.4.0; fi -RUN ${PIP} install jupyter_http_over_ws +RUN python3 -m pip install jupyter_http_over_ws ipykernel==5.1.1 nbformat==4.4.0 RUN jupyter serverextension enable --py jupyter_http_over_ws RUN mkdir -p /tf/tensorflow-tutorials && chmod -R a+rwx /tf/ @@ -135,6 +127,6 @@ RUN apt-get autoremove -y && apt-get remove -y wget WORKDIR /tf EXPOSE 8888 -RUN ${PYTHON} -m ipykernel.kernelspec +RUN python3 -m ipykernel.kernelspec CMD ["bash", "-c", "source /etc/bash.bashrc && jupyter notebook --notebook-dir=/tf --ip 0.0.0.0 --no-browser --allow-root"] diff --git a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-cpu-ppc64le.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-cpu-ppc64le.Dockerfile index 560eca22508..aac72a12640 100644 --- a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-cpu-ppc64le.Dockerfile +++ b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-cpu-ppc64le.Dockerfile @@ -54,32 +54,26 @@ ARG CHECKOUT_TF_SRC=0 RUN chmod a+w /etc/passwd /etc/group RUN test "${CHECKOUT_TF_SRC}" -eq 1 && git clone https://github.com/tensorflow/tensorflow.git /tensorflow_src || true -ARG USE_PYTHON_3_NOT_2 -# TODO(angerson) Completely remove Python 2 support -ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} -ARG PYTHON=python${_PY_SUFFIX} -ARG PIP=pip${_PY_SUFFIX} - # See http://bugs.python.org/issue19846 ENV LANG C.UTF-8 RUN apt-get update && apt-get install -y \ - ${PYTHON} \ - ${PYTHON}-pip + python3 + python3-pip -RUN ${PIP} --no-cache-dir install --upgrade \ +RUN python3 -m pip --no-cache-dir install --upgrade \ pip \ setuptools # Some TF tools expect a "python" binary -RUN ln -s $(which ${PYTHON}) /usr/local/bin/python +RUN ln -s $(which python3) /usr/local/bin/python RUN apt-get update && apt-get install -y \ build-essential \ curl \ git \ openjdk-8-jdk \ - ${PYTHON}-dev \ + python3-dev \ virtualenv \ swig @@ -94,7 +88,6 @@ RUN ${PIP} --no-cache-dir install \ sklearn \ pandas \ portpicker \ - && test "${USE_PYTHON_3_NOT_2}" -eq 1 && true || ${PIP} --no-cache-dir install \ enum34 # Build and install bazel diff --git a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-gpu-ppc64le-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-gpu-ppc64le-jupyter.Dockerfile index cec115afea9..9c43f1128f0 100644 --- a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-gpu-ppc64le-jupyter.Dockerfile +++ b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-gpu-ppc64le-jupyter.Dockerfile @@ -96,32 +96,26 @@ RUN ln -s /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/lib && echo "/usr/local/cuda/lib64/stubs" > /etc/ld.so.conf.d/z-cuda-stubs.conf \ && ldconfig -ARG USE_PYTHON_3_NOT_2 -# TODO(angerson) Completely remove Python 2 support -ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} -ARG PYTHON=python${_PY_SUFFIX} -ARG PIP=pip${_PY_SUFFIX} - # See http://bugs.python.org/issue19846 ENV LANG C.UTF-8 RUN apt-get update && apt-get install -y \ - ${PYTHON} \ - ${PYTHON}-pip + python3 + python3-pip -RUN ${PIP} --no-cache-dir install --upgrade \ +RUN python3 -m pip --no-cache-dir install --upgrade \ pip \ setuptools # Some TF tools expect a "python" binary -RUN ln -s $(which ${PYTHON}) /usr/local/bin/python +RUN ln -s $(which python3) /usr/local/bin/python RUN apt-get update && apt-get install -y \ build-essential \ curl \ git \ openjdk-8-jdk \ - ${PYTHON}-dev \ + python3-dev \ virtualenv \ swig @@ -136,7 +130,6 @@ RUN ${PIP} --no-cache-dir install \ sklearn \ pandas \ portpicker \ - && test "${USE_PYTHON_3_NOT_2}" -eq 1 && true || ${PIP} --no-cache-dir install \ enum34 # Build and install bazel @@ -154,10 +147,9 @@ RUN mkdir /bazel && \ COPY bashrc /etc/bash.bashrc RUN chmod a+rwx /etc/bash.bashrc -RUN ${PIP} install jupyter matplotlib +RUN python3 -m pip install jupyter matplotlib # Pin ipykernel and nbformat; see https://github.com/ipython/ipykernel/issues/422 -RUN if [[ "${USE_PYTHON_3_NOT_2}" == "1" ]]; then ${PIP} install ipykernel==5.1.1 nbformat==4.4.0; fi -RUN ${PIP} install jupyter_http_over_ws +RUN python3 -m pip install jupyter_http_over_ws ipykernel==5.1.1 nbformat==4.4.0 RUN jupyter serverextension enable --py jupyter_http_over_ws RUN mkdir -p /tf/tensorflow-tutorials && chmod -R a+rwx /tf/ @@ -177,6 +169,6 @@ RUN apt-get autoremove -y && apt-get remove -y wget WORKDIR /tf EXPOSE 8888 -RUN ${PYTHON} -m ipykernel.kernelspec +RUN python3 -m ipykernel.kernelspec CMD ["bash", "-c", "source /etc/bash.bashrc && jupyter notebook --notebook-dir=/tf --ip 0.0.0.0 --no-browser --allow-root"] diff --git a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-gpu-ppc64le.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-gpu-ppc64le.Dockerfile index 7b450773091..937e78ce1e9 100644 --- a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-gpu-ppc64le.Dockerfile +++ b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-gpu-ppc64le.Dockerfile @@ -96,32 +96,26 @@ RUN ln -s /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/lib && echo "/usr/local/cuda/lib64/stubs" > /etc/ld.so.conf.d/z-cuda-stubs.conf \ && ldconfig -ARG USE_PYTHON_3_NOT_2 -# TODO(angerson) Completely remove Python 2 support -ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} -ARG PYTHON=python${_PY_SUFFIX} -ARG PIP=pip${_PY_SUFFIX} - # See http://bugs.python.org/issue19846 ENV LANG C.UTF-8 RUN apt-get update && apt-get install -y \ - ${PYTHON} \ - ${PYTHON}-pip + python3 + python3-pip -RUN ${PIP} --no-cache-dir install --upgrade \ +RUN python3 -m pip --no-cache-dir install --upgrade \ pip \ setuptools # Some TF tools expect a "python" binary -RUN ln -s $(which ${PYTHON}) /usr/local/bin/python +RUN ln -s $(which python3) /usr/local/bin/python RUN apt-get update && apt-get install -y \ build-essential \ curl \ git \ openjdk-8-jdk \ - ${PYTHON}-dev \ + python3-dev \ virtualenv \ swig @@ -136,7 +130,6 @@ RUN ${PIP} --no-cache-dir install \ sklearn \ pandas \ portpicker \ - && test "${USE_PYTHON_3_NOT_2}" -eq 1 && true || ${PIP} --no-cache-dir install \ enum34 # Build and install bazel diff --git a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/gpu-ppc64le-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/gpu-ppc64le-jupyter.Dockerfile index 93ac1fc1bbc..692212cf4c5 100644 --- a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/gpu-ppc64le-jupyter.Dockerfile +++ b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/gpu-ppc64le-jupyter.Dockerfile @@ -74,25 +74,19 @@ RUN ln -s /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/lib && echo "/usr/local/cuda/lib64/stubs" > /etc/ld.so.conf.d/z-cuda-stubs.conf \ && ldconfig -ARG USE_PYTHON_3_NOT_2 -# TODO(angerson) Completely remove Python 2 support -ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} -ARG PYTHON=python${_PY_SUFFIX} -ARG PIP=pip${_PY_SUFFIX} - # See http://bugs.python.org/issue19846 ENV LANG C.UTF-8 RUN apt-get update && apt-get install -y \ - ${PYTHON} \ - ${PYTHON}-pip + python3 + python3-pip -RUN ${PIP} --no-cache-dir install --upgrade \ +RUN python3 -m pip --no-cache-dir install --upgrade \ pip \ setuptools # Some TF tools expect a "python" binary -RUN ln -s $(which ${PYTHON}) /usr/local/bin/python +RUN ln -s $(which python3) /usr/local/bin/python # Options: # tensorflow @@ -117,8 +111,8 @@ RUN if [ ${TF_PACKAGE} = tensorflow-gpu ]; then \ elif [ ${TF_PACKAGE} = tf-nightly ]; then \ BASE=https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/lastSuccessfulBuild/; \ fi; \ - MAJOR=`${PYTHON} -c 'import sys; print(sys.version_info[0])'`; \ - MINOR=`${PYTHON} -c 'import sys; print(sys.version_info[1])'`; \ + MAJOR=`python3 -c 'import sys; print(sys.version_info[0])'`; \ + MINOR=`python3 -c 'import sys; print(sys.version_info[1])'`; \ PACKAGE=$(wget -qO- ${BASE}"api/xml?xpath=//fileName&wrapper=artifacts" | grep -o "[^<>]*cp${MAJOR}${MINOR}[^<>]*.whl"); \ wget ${BASE}"artifact/tensorflow_pkg/"${PACKAGE}; \ ${PIP} install ${PACKAGE} @@ -126,10 +120,9 @@ RUN if [ ${TF_PACKAGE} = tensorflow-gpu ]; then \ COPY bashrc /etc/bash.bashrc RUN chmod a+rwx /etc/bash.bashrc -RUN ${PIP} install jupyter matplotlib +RUN python3 -m pip install jupyter matplotlib # Pin ipykernel and nbformat; see https://github.com/ipython/ipykernel/issues/422 -RUN if [[ "${USE_PYTHON_3_NOT_2}" == "1" ]]; then ${PIP} install ipykernel==5.1.1 nbformat==4.4.0; fi -RUN ${PIP} install jupyter_http_over_ws +RUN python3 -m pip install jupyter_http_over_ws ipykernel==5.1.1 nbformat==4.4.0 RUN jupyter serverextension enable --py jupyter_http_over_ws RUN mkdir -p /tf/tensorflow-tutorials && chmod -R a+rwx /tf/ @@ -149,6 +142,6 @@ RUN apt-get autoremove -y && apt-get remove -y wget WORKDIR /tf EXPOSE 8888 -RUN ${PYTHON} -m ipykernel.kernelspec +RUN python3 -m ipykernel.kernelspec CMD ["bash", "-c", "source /etc/bash.bashrc && jupyter notebook --notebook-dir=/tf --ip 0.0.0.0 --no-browser --allow-root"] diff --git a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/gpu-ppc64le.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/gpu-ppc64le.Dockerfile index d174615487f..26673df61c7 100644 --- a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/gpu-ppc64le.Dockerfile +++ b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/gpu-ppc64le.Dockerfile @@ -74,25 +74,19 @@ RUN ln -s /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/lib && echo "/usr/local/cuda/lib64/stubs" > /etc/ld.so.conf.d/z-cuda-stubs.conf \ && ldconfig -ARG USE_PYTHON_3_NOT_2 -# TODO(angerson) Completely remove Python 2 support -ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} -ARG PYTHON=python${_PY_SUFFIX} -ARG PIP=pip${_PY_SUFFIX} - # See http://bugs.python.org/issue19846 ENV LANG C.UTF-8 RUN apt-get update && apt-get install -y \ - ${PYTHON} \ - ${PYTHON}-pip + python3 + python3-pip -RUN ${PIP} --no-cache-dir install --upgrade \ +RUN python3 -m pip --no-cache-dir install --upgrade \ pip \ setuptools # Some TF tools expect a "python" binary -RUN ln -s $(which ${PYTHON}) /usr/local/bin/python +RUN ln -s $(which python3) /usr/local/bin/python # Options: # tensorflow @@ -117,8 +111,8 @@ RUN if [ ${TF_PACKAGE} = tensorflow-gpu ]; then \ elif [ ${TF_PACKAGE} = tf-nightly ]; then \ BASE=https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/lastSuccessfulBuild/; \ fi; \ - MAJOR=`${PYTHON} -c 'import sys; print(sys.version_info[0])'`; \ - MINOR=`${PYTHON} -c 'import sys; print(sys.version_info[1])'`; \ + MAJOR=`python3 -c 'import sys; print(sys.version_info[0])'`; \ + MINOR=`python3 -c 'import sys; print(sys.version_info[1])'`; \ PACKAGE=$(wget -qO- ${BASE}"api/xml?xpath=//fileName&wrapper=artifacts" | grep -o "[^<>]*cp${MAJOR}${MINOR}[^<>]*.whl"); \ wget ${BASE}"artifact/tensorflow_pkg/"${PACKAGE}; \ ${PIP} install ${PACKAGE} diff --git a/tensorflow/tools/dockerfiles/partials/jupyter.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/jupyter.partial.Dockerfile index 3ffc295f09b..bbe58b7b17d 100644 --- a/tensorflow/tools/dockerfiles/partials/jupyter.partial.Dockerfile +++ b/tensorflow/tools/dockerfiles/partials/jupyter.partial.Dockerfile @@ -1,7 +1,6 @@ -RUN ${PIP} install jupyter matplotlib +RUN python3 -m pip install jupyter matplotlib # Pin ipykernel and nbformat; see https://github.com/ipython/ipykernel/issues/422 -RUN if [[ "${USE_PYTHON_3_NOT_2}" == "1" ]]; then ${PIP} install ipykernel==5.1.1 nbformat==4.4.0; fi -RUN ${PIP} install jupyter_http_over_ws +RUN python3 -m pip install jupyter_http_over_ws ipykernel==5.1.1 nbformat==4.4.0 RUN jupyter serverextension enable --py jupyter_http_over_ws RUN mkdir -p /tf/tensorflow-tutorials && chmod -R a+rwx /tf/ @@ -21,6 +20,6 @@ RUN apt-get autoremove -y && apt-get remove -y wget WORKDIR /tf EXPOSE 8888 -RUN ${PYTHON} -m ipykernel.kernelspec +RUN python3 -m ipykernel.kernelspec CMD ["bash", "-c", "source /etc/bash.bashrc && jupyter notebook --notebook-dir=/tf --ip 0.0.0.0 --no-browser --allow-root"] diff --git a/tensorflow/tools/dockerfiles/partials/tensorflow-ppc64le.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/tensorflow-ppc64le.partial.Dockerfile index fbeb7f994ff..faf7f31d5a7 100644 --- a/tensorflow/tools/dockerfiles/partials/tensorflow-ppc64le.partial.Dockerfile +++ b/tensorflow/tools/dockerfiles/partials/tensorflow-ppc64le.partial.Dockerfile @@ -21,8 +21,8 @@ RUN if [ ${TF_PACKAGE} = tensorflow-gpu ]; then \ elif [ ${TF_PACKAGE} = tf-nightly ]; then \ BASE=https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/lastSuccessfulBuild/; \ fi; \ - MAJOR=`${PYTHON} -c 'import sys; print(sys.version_info[0])'`; \ - MINOR=`${PYTHON} -c 'import sys; print(sys.version_info[1])'`; \ + MAJOR=`python3 -c 'import sys; print(sys.version_info[0])'`; \ + MINOR=`python3 -c 'import sys; print(sys.version_info[1])'`; \ PACKAGE=$(wget -qO- ${BASE}"api/xml?xpath=//fileName&wrapper=artifacts" | grep -o "[^<>]*cp${MAJOR}${MINOR}[^<>]*.whl"); \ wget ${BASE}"artifact/tensorflow_pkg/"${PACKAGE}; \ ${PIP} install ${PACKAGE} diff --git a/tensorflow/tools/dockerfiles/partials/ubuntu/bazel.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/ubuntu/bazel.partial.Dockerfile index bcd4d882c92..4135d0538f2 100644 --- a/tensorflow/tools/dockerfiles/partials/ubuntu/bazel.partial.Dockerfile +++ b/tensorflow/tools/dockerfiles/partials/ubuntu/bazel.partial.Dockerfile @@ -4,7 +4,7 @@ RUN apt-get update && apt-get install -y \ git \ wget \ openjdk-8-jdk \ - ${PYTHON}-dev \ + python3-dev \ virtualenv \ swig @@ -20,7 +20,6 @@ RUN ${PIP} --no-cache-dir install \ pandas \ future \ portpicker \ - && test "${USE_PYTHON_3_NOT_2}" -eq 1 && true || ${PIP} --no-cache-dir install \ enum34 # Install bazel diff --git a/tensorflow/tools/dockerfiles/partials/ubuntu/bazelbuild.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/ubuntu/bazelbuild.partial.Dockerfile index 27d49a18337..180a6745861 100644 --- a/tensorflow/tools/dockerfiles/partials/ubuntu/bazelbuild.partial.Dockerfile +++ b/tensorflow/tools/dockerfiles/partials/ubuntu/bazelbuild.partial.Dockerfile @@ -3,7 +3,7 @@ RUN apt-get update && apt-get install -y \ curl \ git \ openjdk-8-jdk \ - ${PYTHON}-dev \ + python3-dev \ virtualenv \ swig @@ -18,7 +18,6 @@ RUN ${PIP} --no-cache-dir install \ sklearn \ pandas \ portpicker \ - && test "${USE_PYTHON_3_NOT_2}" -eq 1 && true || ${PIP} --no-cache-dir install \ enum34 # Build and install bazel diff --git a/tensorflow/tools/dockerfiles/partials/ubuntu/python.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/ubuntu/python.partial.Dockerfile index 804f8102e52..edbb18bb47c 100644 --- a/tensorflow/tools/dockerfiles/partials/ubuntu/python.partial.Dockerfile +++ b/tensorflow/tools/dockerfiles/partials/ubuntu/python.partial.Dockerfile @@ -1,19 +1,13 @@ -ARG USE_PYTHON_3_NOT_2 -# TODO(angerson) Completely remove Python 2 support -ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} -ARG PYTHON=python${_PY_SUFFIX} -ARG PIP=pip${_PY_SUFFIX} - # See http://bugs.python.org/issue19846 ENV LANG C.UTF-8 RUN apt-get update && apt-get install -y \ - ${PYTHON} \ - ${PYTHON}-pip + python3 + python3-pip -RUN ${PIP} --no-cache-dir install --upgrade \ +RUN python3 -m pip --no-cache-dir install --upgrade \ pip \ setuptools # Some TF tools expect a "python" binary -RUN ln -s $(which ${PYTHON}) /usr/local/bin/python +RUN ln -s $(which python3) /usr/local/bin/python diff --git a/tensorflow/tools/dockerfiles/spec.yml b/tensorflow/tools/dockerfiles/spec.yml index b8187b15dfb..436ef41c15a 100644 --- a/tensorflow/tools/dockerfiles/spec.yml +++ b/tensorflow/tools/dockerfiles/spec.yml @@ -25,28 +25,19 @@ header: | # of functionality ("slices") by listing all of the "slice sets" to use when # building. # -# For example, a release that uses {nightly}{py} would create 4 Dockerfiles -# (which could become images or concrete Dockerfiles), because the "nightly" -# and "py" slice sets both have two entries: -# -# - nightly (no -py2 because the Python 2 slice set has add_to_name: "" -# - nightly-py3 -# - nightly-gpu (similar) -# - nightly-gpu-py3 -# # Releases are all treated differently by TensorFlow's CI systems. releases: # Built Nightly and pushed to tensorflow/tensorflow nightly: tag_specs: - - "{nightly}{py}{jupyter}" - - "{_TAG_PREFIX}{ubuntu-devel}{py-devel}" + - "{nightly}{jupyter}" + - "{_TAG_PREFIX}{ubuntu-devel}" # Built per-release and pushed to tensorflow/tensorflow # --arg _TAG_PREFIX= should be set to "1.11" (for example) or "latest". versioned: tag_specs: - - "{_TAG_PREFIX}{ubuntu}{py}{jupyter}" + - "{_TAG_PREFIX}{ubuntu}{jupyter}" # Dockerfiles stored in the TF repo; not pushed anywhere dockerfiles: @@ -62,22 +53,6 @@ releases: slice_sets: - py: - - add_to_name: "" - args: - - USE_PYTHON_3_NOT_2=1 - - add_to_name: "-py3" - args: - - USE_PYTHON_3_NOT_2=1 - - py-devel: - - add_to_name: "" - args: - - USE_PYTHON_3_NOT_2=1 - - add_to_name: "-py3" - args: - - USE_PYTHON_3_NOT_2=1 - jupyter: - add_to_name: "" - add_to_name: "-jupyter" diff --git a/tensorflow/tools/docs/tf_doctest_lib.py b/tensorflow/tools/docs/tf_doctest_lib.py index 2ba368e6fa2..96f2bc7341b 100644 --- a/tensorflow/tools/docs/tf_doctest_lib.py +++ b/tensorflow/tools/docs/tf_doctest_lib.py @@ -146,6 +146,12 @@ class TfDoctestOutputChecker(doctest.OutputChecker, object): A bool, indicating if the check was successful or not. """ + # If the docstring's output is empty and there is some output generated + # after running the snippet, return True. This is because if the user + # doesn't want to display output, respect that over what the doctest wants. + if not want and got: + return True + # Replace python's addresses with ellipsis (`...`) since it can change on # each execution. want = self._ADDRESS_RE.sub('at ...>', want) diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index 991a5742579..e152f9b6a22 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -95,7 +95,6 @@ COMMON_PIP_DEPS = [ "//tensorflow:tensorflow_py", "//tensorflow/compiler/tf2xla:xla_compiled_cpu_runtime_hdrs", "//tensorflow/compiler/tf2xla:xla_compiled_cpu_runtime_srcs", - "//tensorflow/core/data/service/python:server_lib", "//tensorflow/core:protos_all_proto_srcs", "//tensorflow/examples/saved_model/integration_tests:mnist_util", "//tensorflow/lite/python/testdata:interpreter_test_data", @@ -107,6 +106,7 @@ COMMON_PIP_DEPS = [ "//tensorflow/python/autograph/pyct/common_transformers:common_transformers", "//tensorflow/python/compiler:compiler", "//tensorflow/python:cond_v2", + "//tensorflow/python/data/service:server_lib", "//tensorflow/python:distributed_framework_test_lib", "//tensorflow/python/distribute:distribute_test_lib_pip", "//tensorflow/python:loss_scale", @@ -171,6 +171,7 @@ filegroup( "//third_party/fft2d:LICENSE", "//third_party/hadoop:LICENSE.txt", "//third_party/icu/data:LICENSE", + "@ruy//:LICENSE", "@arm_neon_2_x86_sse//:LICENSE", "@astunparse_archive//:LICENSE", "@astor_archive//:LICENSE", diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 142652f3e23..ead41e83c37 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -39,6 +39,7 @@ load("//third_party/kissfft:workspace.bzl", kissfft = "repo") load("//third_party/pasta:workspace.bzl", pasta = "repo") load("//third_party/psimd:workspace.bzl", psimd = "repo") load("//third_party/pthreadpool:workspace.bzl", pthreadpool = "repo") +load("//third_party/ruy:workspace.bzl", ruy = "repo") load("//third_party/sobol_data:workspace.bzl", sobol_data = "repo") load("//third_party/vulkan_headers:workspace.bzl", vulkan_headers = "repo") load("//third_party/toolchains/remote_config:configs.bzl", "initialize_rbe_configs") @@ -65,6 +66,7 @@ def initialize_third_party(): pthreadpool() sobol_data() vulkan_headers() + ruy() # Sanitize a dependency so that it works correctly from code that includes # TensorFlow as a submodule. @@ -201,11 +203,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): name = "eigen_archive", build_file = clean_dep("//third_party:eigen.BUILD"), patch_file = clean_dep("//third_party/eigen3:gpu_packet_math.patch"), - sha256 = "2f046557f4093becf51b44c6339873c18e2f1ea55c4b3f3a08b7d15a1d9c6e5b", # SHARED_EIGEN_SHA - strip_prefix = "eigen-4fd5d1477b221fc7daf2b7f1c7e4ee4f04ceaced", + sha256 = "62d1581a740caa74f1bf9db8552abebcd772bf12be035e9422bd59bfb0a2ba8e", # SHARED_EIGEN_SHA + strip_prefix = "eigen-deb93ed1bf359ac99923e3a2b90a2920b1101290", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/gitlab.com/libeigen/eigen/-/archive/4fd5d1477b221fc7daf2b7f1c7e4ee4f04ceaced/eigen-4fd5d1477b221fc7daf2b7f1c7e4ee4f04ceaced.tar.gz", - "https://gitlab.com/libeigen/eigen/-/archive/4fd5d1477b221fc7daf2b7f1c7e4ee4f04ceaced/eigen-4fd5d1477b221fc7daf2b7f1c7e4ee4f04ceaced.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/gitlab.com/libeigen/eigen/-/archive/deb93ed1bf359ac99923e3a2b90a2920b1101290/eigen-deb93ed1bf359ac99923e3a2b90a2920b1101290.tar.gz", + "https://gitlab.com/libeigen/eigen/-/archive/deb93ed1bf359ac99923e3a2b90a2920b1101290/eigen-deb93ed1bf359ac99923e3a2b90a2920b1101290.tar.gz", ], ) @@ -545,12 +547,12 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): tf_http_archive( name = "curl", build_file = clean_dep("//third_party:curl.BUILD"), - sha256 = "d0393da38ac74ffac67313072d7fe75b1fa1010eb5987f63f349b024a36b7ffb", - strip_prefix = "curl-7.66.0", + sha256 = "01ae0c123dee45b01bbaef94c0bc00ed2aec89cb2ee0fd598e0d302a6b5e0a98", + strip_prefix = "curl-7.69.1", system_build_file = clean_dep("//third_party/systemlibs:curl.BUILD"), urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/curl.haxx.se/download/curl-7.66.0.tar.gz", - "https://curl.haxx.se/download/curl-7.66.0.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/curl.haxx.se/download/curl-7.69.1.tar.gz", + "https://curl.haxx.se/download/curl-7.69.1.tar.gz", ], ) @@ -589,8 +591,8 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): ) # Check out LLVM and MLIR from llvm-project. - LLVM_COMMIT = "7a4a98a9c4f39d9c395f5ce587dbbcb5450a9655" - LLVM_SHA256 = "d11b4b7e4522e86d9525f1ad1f840f2f871164ab0b0f848e9a1f314af63cf3d7" + LLVM_COMMIT = "71305033d11564fe9c19e8e40200680daae41e89" + LLVM_SHA256 = "b3deb570e34de7850793e17a15e9576fbc6bc0229150a7cf0ff26a2cdb010ad0" LLVM_URLS = [ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), @@ -622,12 +624,12 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): tf_http_archive( name = "jsoncpp_git", build_file = clean_dep("//third_party:jsoncpp.BUILD"), - sha256 = "c49deac9e0933bcb7044f08516861a2d560988540b23de2ac1ad443b219afdb6", - strip_prefix = "jsoncpp-1.8.4", + sha256 = "77a402fb577b2e0e5d0bdc1cf9c65278915cdb25171e3452c68b6da8a561f8f0", + strip_prefix = "jsoncpp-1.9.2", system_build_file = clean_dep("//third_party/systemlibs:jsoncpp.BUILD"), urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/open-source-parsers/jsoncpp/archive/1.8.4.tar.gz", - "https://github.com/open-source-parsers/jsoncpp/archive/1.8.4.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/open-source-parsers/jsoncpp/archive/1.9.2.tar.gz", + "https://github.com/open-source-parsers/jsoncpp/archive/1.9.2.tar.gz", ], ) @@ -1045,6 +1047,16 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): "https://github.com/GrahamDumpleton/wrapt/archive/1.11.1.tar.gz", ], ) + tf_http_archive( + name = "coremltools", + sha256 = "0d594a714e8a5fd5bd740ad112ef59155c0482e25fdc8f8efa5758f90abdcf1e", + strip_prefix = "coremltools-3.3", + build_file = clean_dep("//third_party:coremltools.BUILD"), + urls = [ + "http://mirror.tensorflow.org/github.com/apple/coremltools/archive/3.3.zip", + "https://github.com/apple/coremltools/archive/3.3.zip", + ], + ) def tf_bind(): """Bind targets for some external repositories""" diff --git a/third_party/coremltools.BUILD b/third_party/coremltools.BUILD new file mode 100644 index 00000000000..2c50359de5a --- /dev/null +++ b/third_party/coremltools.BUILD @@ -0,0 +1,15 @@ +load("@com_google_protobuf//:protobuf.bzl", "cc_proto_library") + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # BSD + +exports_files(["LICENSE.txt"]) + +cc_proto_library( + name = "mlmodel_cc_proto", + srcs = glob(["mlmodel/format/*.proto"]), + include = "mlmodel/format", + default_runtime = "@com_google_protobuf//:protobuf_lite", + protoc = "@com_google_protobuf//:protoc", +) diff --git a/third_party/cpuinfo/BUILD.bazel b/third_party/cpuinfo/BUILD.bazel index d0d44a4663b..1a8557a89be 100644 --- a/third_party/cpuinfo/BUILD.bazel +++ b/third_party/cpuinfo/BUILD.bazel @@ -98,6 +98,7 @@ cc_library( srcs = select({ ":linux_x86_64": COMMON_SRCS + X86_SRCS + LINUX_SRCS + LINUX_X86_SRCS, ":linux_aarch64": COMMON_SRCS + ARM_SRCS + LINUX_SRCS + LINUX_ARM64_SRCS, + ":linux_arm": COMMON_SRCS + ARM_SRCS + LINUX_SRCS + LINUX_ARM32_SRCS, ":macos_x86_64": COMMON_SRCS + X86_SRCS + MACH_SRCS + MACH_X86_SRCS, ":android_armv7": COMMON_SRCS + ARM_SRCS + LINUX_SRCS + LINUX_ARM32_SRCS + ANDROID_ARM_SRCS, ":android_arm64": COMMON_SRCS + ARM_SRCS + LINUX_SRCS + LINUX_ARM64_SRCS + ANDROID_ARM_SRCS, @@ -167,6 +168,12 @@ config_setting( values = {"cpu": "aarch64"}, ) +config_setting( + name = "linux_arm", + values = {"cpu": "arm"}, + visibility = ["//visibility:public"], +) + config_setting( name = "macos_x86_64", values = { diff --git a/third_party/curl.BUILD b/third_party/curl.BUILD index f3a7e3f59e7..bd9709e4383 100644 --- a/third_party/curl.BUILD +++ b/third_party/curl.BUILD @@ -163,6 +163,8 @@ cc_library( "lib/quic.h", "lib/rand.c", "lib/rand.h", + "lib/rename.h", + "lib/rename.c", "lib/rtsp.c", "lib/rtsp.h", "lib/security.c", @@ -183,13 +185,13 @@ cc_library( "lib/smb.h", "lib/smtp.h", "lib/sockaddr.h", + "lib/socketpair.h", "lib/socks.c", "lib/socks.h", "lib/speedcheck.c", "lib/speedcheck.h", "lib/splay.c", "lib/splay.h", - "lib/ssh.h", "lib/strcase.c", "lib/strcase.h", "lib/strdup.c", @@ -219,13 +221,13 @@ cc_library( "lib/vauth/vauth.c", "lib/vauth/vauth.h", "lib/version.c", + "lib/vssh/ssh.h", + "lib/vtls/bearssl.h", "lib/vtls/gskit.h", "lib/vtls/gtls.h", "lib/vtls/mbedtls.h", "lib/vtls/nssg.h", "lib/vtls/openssl.h", - "lib/vtls/polarssl.h", - "lib/vtls/polarssl_threadlock.h", "lib/vtls/schannel.h", "lib/vtls/vtls.c", "lib/vtls/vtls.h", diff --git a/third_party/flatbuffers/BUILD.bazel b/third_party/flatbuffers/BUILD.bazel index 3b21a73154a..1ee46f05235 100644 --- a/third_party/flatbuffers/BUILD.bazel +++ b/third_party/flatbuffers/BUILD.bazel @@ -112,6 +112,7 @@ filegroup( "python/flatbuffers/number_types.py", "python/flatbuffers/packer.py", "python/flatbuffers/table.py", + "python/flatbuffers/util.py", ], ) diff --git a/third_party/flatbuffers/workspace.bzl b/third_party/flatbuffers/workspace.bzl index dffc100bc22..d1d19a46134 100644 --- a/third_party/flatbuffers/workspace.bzl +++ b/third_party/flatbuffers/workspace.bzl @@ -5,11 +5,11 @@ load("//third_party:repo.bzl", "third_party_http_archive") def repo(): third_party_http_archive( name = "flatbuffers", - strip_prefix = "flatbuffers-a4b2884e4ed6116335d534af8f58a84678b74a17", - sha256 = "6ff041dcaf873acbf0a93886e6b4f7704b68af1457e8b675cae88fbefe2de330", + strip_prefix = "flatbuffers-1.12.0", + sha256 = "62f2223fb9181d1d6338451375628975775f7522185266cd5296571ac152bc45", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/https://github.com/google/flatbuffers/archive/a4b2884e4ed6116335d534af8f58a84678b74a17.zip", - "https://github.com/google/flatbuffers/archive/a4b2884e4ed6116335d534af8f58a84678b74a17.zip", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/flatbuffers/archive/v1.12.0.tar.gz", + "https://github.com/google/flatbuffers/archive/v1.12.0.tar.gz", ], build_file = "//third_party/flatbuffers:BUILD.bazel", system_build_file = "//third_party/flatbuffers:BUILD.system", diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl index 203630802e4..8fa64f264dc 100644 --- a/third_party/gpus/cuda_configure.bzl +++ b/third_party/gpus/cuda_configure.bzl @@ -1035,7 +1035,18 @@ def _create_local_cuda_repository(repository_ctx): cuda_defines["%{host_compiler_prefix}"] = host_compiler_prefix - cuda_defines["%{linker_bin_path}"] = "" + # Bazel sets '-B/usr/bin' flag to workaround build errors on RHEL (see + # https://github.com/bazelbuild/bazel/issues/760). + # However, this stops our custom clang toolchain from picking the provided + # LLD linker, so we're only adding '-B/usr/bin' when using non-downloaded + # toolchain. + # TODO: when bazel stops adding '-B/usr/bin' by default, remove this + # flag from the CROSSTOOL completely (see + # https://github.com/bazelbuild/bazel/issues/5634) + if should_download_clang: + cuda_defines["%{linker_bin_path}"] = "" + else: + cuda_defines["%{linker_bin_path}"] = host_compiler_prefix cuda_defines["%{extra_no_canonical_prefixes_flags}"] = "" cuda_defines["%{unfiltered_compile_flags}"] = "" diff --git a/third_party/jsoncpp.BUILD b/third_party/jsoncpp.BUILD index cf3cba05556..7bc466c664f 100644 --- a/third_party/jsoncpp.BUILD +++ b/third_party/jsoncpp.BUILD @@ -12,11 +12,12 @@ cc_library( "src/lib_json/json_writer.cpp", ], hdrs = [ + "include/json/allocator.h", "include/json/autolink.h", "include/json/config.h", - "include/json/features.h", "include/json/forwards.h", "include/json/json.h", + "include/json/json_features.h", "include/json/reader.h", "include/json/value.h", "include/json/version.h", diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD index f9b7b4bfa52..078bb39eadc 100644 --- a/third_party/mlir/BUILD +++ b/third_party/mlir/BUILD @@ -143,6 +143,13 @@ filegroup( # Affine dialect. ##---------------------------------------------------------------------------## +filegroup( + name = "PassBaseTdFiles", + srcs = [ + "include/mlir/Pass/PassBase.td", + ], +) + filegroup( name = "AffineOpsTdFiles", srcs = [ @@ -250,6 +257,7 @@ cc_library( includes = ["include"], deps = [ ":AVX512", + ":ConversionPassIncGen", ":EDSC", ":IR", ":LLVMAVX512", @@ -300,6 +308,22 @@ gentbl( ], ) +gentbl( + name = "LoopPassIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + "-gen-pass-decls", + "include/mlir/Dialect/LoopOps/Passes.h.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/LoopOps/Passes.td", + td_srcs = [ + ":PassBaseTdFiles", + ], +) + cc_library( name = "LoopOpsTransforms", srcs = glob(["lib/Dialect/LoopOps/Transforms/*.cpp"]), @@ -309,6 +333,7 @@ cc_library( ":Affine", ":IR", ":LoopOps", + ":LoopPassIncGen", ":Pass", ":StandardOps", ":Transforms", @@ -422,6 +447,22 @@ cc_library( ], ) +gentbl( + name = "AffinePassIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + "-gen-pass-decls", + "include/mlir/Dialect/Affine/Passes.h.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/Affine/Passes.td", + td_srcs = [ + ":PassBaseTdFiles", + ], +) + cc_library( name = "AffineTransforms", srcs = glob([ @@ -433,6 +474,7 @@ cc_library( includes = ["include"], deps = [ ":Affine", + ":AffinePassIncGen", ":Analysis", ":IR", ":LoopOps", @@ -445,6 +487,22 @@ cc_library( ], ) +gentbl( + name = "ConversionPassIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + "-gen-pass-decls", + "include/mlir/Conversion/Passes.h.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Conversion/Passes.td", + td_srcs = [ + ":PassBaseTdFiles", + ], +) + cc_library( name = "AffineToStandardTransforms", srcs = glob([ @@ -455,6 +513,7 @@ cc_library( includes = ["include"], deps = [ ":Affine", + ":ConversionPassIncGen", ":IR", ":LoopOps", ":Pass", @@ -833,6 +892,22 @@ cc_library( ], ) +gentbl( + name = "LLVMPassIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + "-gen-pass-decls", + "include/mlir/Dialect/LLVMIR/Transforms/Passes.h.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/LLVMIR/Transforms/Passes.td", + td_srcs = [ + ":PassBaseTdFiles", + ], +) + cc_library( name = "LLVMIRTransforms", srcs = glob(["lib/Dialect/LLVMIR/Transforms/*.cpp"]), @@ -841,6 +916,7 @@ cc_library( deps = [ ":IR", ":LLVMDialect", + ":LLVMPassIncGen", ":Pass", ], ) @@ -931,6 +1007,22 @@ cc_library( ], ) +gentbl( + name = "GPUPassIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + "-gen-pass-decls", + "include/mlir/Dialect/GPU/Passes.h.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/GPU/Passes.td", + td_srcs = [ + ":PassBaseTdFiles", + ], +) + cc_library( name = "GPUTransforms", srcs = glob( @@ -949,6 +1041,7 @@ cc_library( deps = [ ":EDSC", ":GPUDialect", + ":GPUPassIncGen", ":IR", ":LoopOps", ":ParallelLoopMapperAttrGen", @@ -1015,6 +1108,7 @@ cc_library( ]), includes = ["include"], deps = [ + ":ConversionPassIncGen", ":GPUCommonTransforms", ":GPUDialect", ":GPUToNVVMGen", @@ -1036,6 +1130,7 @@ cc_library( ], includes = ["include"], deps = [ + ":ConversionPassIncGen", ":GPUCommonTransforms", ":GPUDialect", ":LLVMTransforms", @@ -1054,6 +1149,7 @@ cc_library( hdrs = ["include/mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"], includes = ["include"], deps = [ + ":ConversionPassIncGen", ":GPUDialect", ":IR", ":LLVMDialect", @@ -1075,6 +1171,7 @@ cc_library( hdrs = ["include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h"], includes = ["include"], deps = [ + ":ConversionPassIncGen", ":GPUDialect", ":IR", ":LLVMDialect", @@ -1120,6 +1217,7 @@ cc_library( "lib/Conversions/GPUToSPIRV", ], deps = [ + ":ConversionPassIncGen", ":GPUDialect", ":GPUToSPIRVIncGen", ":IR", @@ -1528,6 +1626,22 @@ cc_library( ], ) +gentbl( + name = "SPIRVPassIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + "-gen-pass-decls", + "include/mlir/Dialect/SPIRV/Passes.h.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/SPIRV/Passes.td", + td_srcs = [ + ":PassBaseTdFiles", + ], +) + cc_library( name = "SPIRVLowering", srcs = [ @@ -1548,6 +1662,7 @@ cc_library( ":IR", ":Pass", ":SPIRVDialect", + ":SPIRVPassIncGen", ":SPIRVTargetAndABIStructGen", ":StandardOps", ":Support", @@ -1570,6 +1685,7 @@ cc_library( "lib/Conversion/StandardToSPIRV", ], deps = [ + ":ConversionPassIncGen", ":IR", ":Pass", ":SPIRVDialect", @@ -1595,6 +1711,7 @@ cc_library( "lib/Conversion/StandardToStandard", ], deps = [ + ":ConversionPassIncGen", ":IR", ":Pass", ":StandardOps", @@ -1728,6 +1845,22 @@ gentbl( ], ) +gentbl( + name = "TransformsPassIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + "-gen-pass-decls", + "include/mlir/Transforms/Passes.h.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Transforms/Passes.td", + td_srcs = [ + ":PassBaseTdFiles", + ], +) + cc_library( name = "Transforms", srcs = glob([ @@ -1749,6 +1882,7 @@ cc_library( ":StandardOps", ":Support", ":TransformUtils", + ":TransformsPassIncGen", ":VectorOps", "@llvm-project//llvm:support", ], @@ -1776,6 +1910,7 @@ cc_library( deps = [ ":Affine", ":AffineToStandardTransforms", + ":ConversionPassIncGen", ":GPUDialect", ":GPUTransforms", ":IR", @@ -1800,6 +1935,7 @@ cc_library( includes = ["include"], deps = [ ":Affine", + ":ConversionPassIncGen", ":GPUDialect", ":LoopOps", ":LoopsToGPU", @@ -1821,6 +1957,7 @@ cc_library( ], includes = ["include"], deps = [ + ":ConversionPassIncGen", ":IR", ":LLVMDialect", ":LoopOps", @@ -1843,6 +1980,7 @@ cc_library( ], includes = ["include"], deps = [ + ":ConversionPassIncGen", ":IR", ":LLVMDialect", ":Pass", @@ -2290,10 +2428,14 @@ cc_library( ":AVX512", ":AVX512ToLLVM", ":Affine", + ":AffinePassIncGen", ":AffineTransforms", - ":Analysis", + ":CFGTransforms", + ":ConversionPassIncGen", ":FxpMathOps", + ":FxpMathPassIncGen", ":GPUDialect", + ":GPUPassIncGen", ":GPUToCUDATransforms", ":GPUToNVVMTransforms", ":GPUToROCDLTransforms", @@ -2304,27 +2446,36 @@ cc_library( ":LLVMAVX512", ":LLVMDialect", ":LLVMIRTransforms", + ":LLVMPassIncGen", + ":LLVMTransforms", ":LinalgOps", + ":LinalgPassIncGen", ":LinalgToLLVM", ":LinalgToSPIRV", ":LinalgTransforms", ":LoopOps", ":LoopOpsTransforms", + ":LoopPassIncGen", ":LoopsToGPUPass", ":NVVMDialect", ":OpenMPDialect", ":QuantOps", + ":QuantPassIncGen", + ":QuantizerPassIncGen", ":QuantizerTransforms", ":ROCDLDialect", ":SDBM", ":SPIRVDialect", ":SPIRVLowering", + ":SPIRVPassIncGen", ":Shape", ":StandardOps", ":StandardToSPIRVConversions", ":StandardToStandard", ":Transforms", + ":TransformsPassIncGen", ":VectorOps", + ":VectorToLLVM", ], ) @@ -2610,6 +2761,22 @@ gentbl( ], ) +gentbl( + name = "QuantPassIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + "-gen-pass-decls", + "include/mlir/Dialect/Quant/Passes.h.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/Quant/Passes.td", + td_srcs = [ + ":PassBaseTdFiles", + ], +) + cc_library( name = "QuantOps", srcs = [ @@ -2636,6 +2803,7 @@ cc_library( ":IR", ":Pass", ":QuantOpsIncGen", + ":QuantPassIncGen", ":SideEffects", ":StandardOps", "@llvm-project//llvm:support", @@ -2681,6 +2849,22 @@ gentbl( ], ) +gentbl( + name = "FxpMathPassIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + "-gen-pass-decls", + "include/mlir/Dialect/FxpMathOps/Passes.h.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/FxpMathOps/Passes.td", + td_srcs = [ + ":PassBaseTdFiles", + ], +) + cc_library( name = "FxpMathOps", srcs = [ @@ -2695,6 +2879,7 @@ cc_library( includes = ["include"], deps = [ ":FxpMathOpsIncGen", + ":FxpMathPassIncGen", ":IR", ":Pass", ":QuantOps", @@ -2842,6 +3027,7 @@ cc_library( ":AffineToStandardTransforms", ":Analysis", ":CFGTransforms", + ":ConversionPassIncGen", ":EDSC", ":IR", ":LLVMDialect", @@ -2869,6 +3055,7 @@ cc_library( ]), includes = ["include"], deps = [ + ":ConversionPassIncGen", ":DialectUtils", ":IR", ":LinalgOps", @@ -2905,6 +3092,22 @@ cc_library( ], ) +gentbl( + name = "LinalgPassIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + "-gen-pass-decls", + "include/mlir/Dialect/Linalg/Passes.h.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/Linalg/Passes.td", + td_srcs = [ + ":PassBaseTdFiles", + ], +) + cc_library( name = "LinalgTransforms", srcs = [ @@ -2937,16 +3140,15 @@ cc_library( ":LLVMDialect", ":LLVMTransforms", ":LinalgOps", - ":LinalgOpsIncGen", + ":LinalgPassIncGen", ":LinalgStructuredOpsIncGen", - ":LinalgTransformPatternsIncGen", ":LoopOps", - ":Parser", ":Pass", ":StandardOps", ":Support", ":TransformUtils", ":Transforms", + ":TransformsPassIncGen", ":VectorOps", "@llvm-project//llvm:core", "@llvm-project//llvm:support", @@ -2976,6 +3178,22 @@ cc_library( ], ) +gentbl( + name = "QuantizerPassIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + "-gen-pass-decls", + "include/mlir/Quantizer/Transforms/Passes.h.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Quantizer/Transforms/Passes.td", + td_srcs = [ + ":PassBaseTdFiles", + ], +) + cc_library( name = "QuantizerTransforms", srcs = glob([ @@ -2990,6 +3208,7 @@ cc_library( ":IR", ":Pass", ":QuantOps", + ":QuantizerPassIncGen", ":QuantizerSupportLib", ":Support", "@llvm-project//llvm:support", @@ -3072,6 +3291,7 @@ cc_library( ]), includes = ["include"], deps = [ + ":ConversionPassIncGen", ":DialectUtils", ":EDSC", ":IR", @@ -3099,6 +3319,7 @@ cc_library( includes = ["include"], deps = [ ":Affine", + ":ConversionPassIncGen", ":EDSC", ":IR", ":LLVMDialect", diff --git a/third_party/ngraph/ngraph_tf.BUILD b/third_party/ngraph/ngraph_tf.BUILD index b4b2511e6b1..3ce31feec27 100644 --- a/third_party/ngraph/ngraph_tf.BUILD +++ b/third_party/ngraph/ngraph_tf.BUILD @@ -59,8 +59,8 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/types:variant", "@ngraph//:ngraph_core", - "@org_tensorflow//tensorflow/core:core_cpu_headers_lib", "@org_tensorflow//tensorflow/core:framework_headers_lib", + "@org_tensorflow//tensorflow/core/common_runtime:core_cpu_headers_lib", ], alwayslink = 1, ) diff --git a/third_party/ruy/BUILD b/third_party/ruy/BUILD new file mode 100644 index 00000000000..3ded6314938 --- /dev/null +++ b/third_party/ruy/BUILD @@ -0,0 +1,8 @@ +# Ruy is not BLAS + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +exports_files(["LICENSE"]) diff --git a/third_party/ruy/workspace.bzl b/third_party/ruy/workspace.bzl new file mode 100644 index 00000000000..203b89aa7e9 --- /dev/null +++ b/third_party/ruy/workspace.bzl @@ -0,0 +1,15 @@ +"""Loads the ruy library, used by TensorFlow Lite.""" + +load("//third_party:repo.bzl", "third_party_http_archive") + +def repo(): + third_party_http_archive( + name = "ruy", + sha256 = "ac6d71df496a20043252f451d82a01636bb8bba9c3d6b5dc9fadadaffa392751", + strip_prefix = "ruy-91d62808498cea7ccb48aa59181e218b4ad05701", + urls = [ + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/ruy/archive/91d62808498cea7ccb48aa59181e218b4ad05701.zip", + "https://github.com/google/ruy/archive/91d62808498cea7ccb48aa59181e218b4ad05701.zip", + ], + build_file = "//third_party/ruy:BUILD", + )