Merged commit includes the following changes:
319411158 by A. Unique TensorFlower<gardener@tensorflow.org>: Integrate LLVM at https://github.com/llvm/llvm-project/commit/d6343e607ac8 -- 319410296 by A. Unique TensorFlower<gardener@tensorflow.org>: [XLA] Implement extra prefetch limit for while uses. Outstanding prefetch limits can prevent prefetches from being scheduled for the duration of while loops. Since using alternate memory for the while loops can be more beneficial, allow specifying additional prefetch limit when the use is a while HLO. -- 319406145 by A. Unique TensorFlower<gardener@tensorflow.org>: [XLA:CPU] Teach dot_op_emitter how to tile&vectorize linalg matmuls And turn them on by default. This is on-par with the existing emitter, sometimes better and unlocks more potential. The strategy classes are duplicated right now, but I expect them to graduate to mlir core soon. I'm planning to remove the custom LLVM IR emitters if this turns out to be stable enough. -- 319402982 by A. Unique TensorFlower<gardener@tensorflow.org>: PR #40327: [ROCm] Enabling optimized FusedBatchNormInferenceMetaKernel for half Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/40327 This PR enables optimized FusedBatchNormInferenceMetaKernel for half on ROCm. Copybara import of the project: -- 5f658e2bc1b20794239658bffe0d7bf9cb89c81f by Eugene Kuznetsov <eugene.kuznetsov@amd.com>: Enabling optimized FusedBatchNormInferenceMetaKernel for half -- 319393611 by A. Unique TensorFlower<gardener@tensorflow.org>: Integrate LLVM at https://github.com/llvm/llvm-project/commit/68498ce8af37 -- 319374663 by A. Unique TensorFlower<gardener@tensorflow.org>: compat: Update forward compatibility horizon to 2020-07-02 -- 319374662 by A. Unique TensorFlower<gardener@tensorflow.org>: Update GraphDef version to 450. -- 319371388 by A. Unique TensorFlower<gardener@tensorflow.org>: Update framework_build_test targets -- 319363982 by A. Unique TensorFlower<gardener@tensorflow.org>: Resolve the permission denied error on Python 3.7 pip install. -- 319361498 by A. Unique TensorFlower<gardener@tensorflow.org>: Add an option to only log parameters whose values are parsed from cmdline flags in the benchmark tool. -- 319356677 by A. Unique TensorFlower<gardener@tensorflow.org>: Fix bug in ReadNonConstantTensor assigning new value to the reference don't update the reference, so use pointer instead. -- 319350974 by A. Unique TensorFlower<gardener@tensorflow.org>: Fix the header inclusion path issue with TensorFlowLiteC -- 319342653 by A. Unique TensorFlower<gardener@tensorflow.org>: Fix the relationship between tpu_executor and tpu_executor_base build targets. -- 319342578 by A. Unique TensorFlower<gardener@tensorflow.org>: Internal change 319340968 by A. Unique TensorFlower<gardener@tensorflow.org>: Internal change PiperOrigin-RevId: 319411158
This commit is contained in:
parent
62b6c316d2
commit
e25d3e084b
@ -471,6 +471,7 @@ cc_library(
|
|||||||
":cpu_runtime",
|
":cpu_runtime",
|
||||||
":ir_emission_utils",
|
":ir_emission_utils",
|
||||||
":mlir_emitter",
|
":mlir_emitter",
|
||||||
|
":mlir_matmul_codegen_strategy",
|
||||||
":target_machine_features",
|
":target_machine_features",
|
||||||
":tiled_dot_emitter",
|
":tiled_dot_emitter",
|
||||||
":vector_support_library",
|
":vector_support_library",
|
||||||
@ -1102,12 +1103,33 @@ cc_library(
|
|||||||
"@llvm-project//llvm:Core",
|
"@llvm-project//llvm:Core",
|
||||||
"@llvm-project//llvm:IPO",
|
"@llvm-project//llvm:IPO",
|
||||||
"@llvm-project//llvm:Linker",
|
"@llvm-project//llvm:Linker",
|
||||||
|
"@llvm-project//mlir:CFGTransforms",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
"@llvm-project//mlir:LLVMTransforms",
|
|
||||||
"@llvm-project//mlir:LinalgToLLVM",
|
|
||||||
"@llvm-project//mlir:LinalgTransforms",
|
"@llvm-project//mlir:LinalgTransforms",
|
||||||
"@llvm-project//mlir:Pass",
|
"@llvm-project//mlir:Pass",
|
||||||
"@llvm-project//mlir:TargetLLVMIR",
|
"@llvm-project//mlir:TargetLLVMIR",
|
||||||
|
"@llvm-project//mlir:Transforms",
|
||||||
"@llvm-project//mlir:VectorToLLVM",
|
"@llvm-project//mlir:VectorToLLVM",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "mlir_matmul_codegen_strategy",
|
||||||
|
srcs = ["mlir_matmul_codegen_strategy.cc"],
|
||||||
|
hdrs = ["mlir_matmul_codegen_strategy.h"],
|
||||||
|
deps = [
|
||||||
|
"@llvm-project//llvm:Support",
|
||||||
|
"@llvm-project//mlir:Affine",
|
||||||
|
"@llvm-project//mlir:Analysis",
|
||||||
|
"@llvm-project//mlir:IR",
|
||||||
|
"@llvm-project//mlir:LinalgOps",
|
||||||
|
"@llvm-project//mlir:LinalgTransforms",
|
||||||
|
"@llvm-project//mlir:Pass",
|
||||||
|
"@llvm-project//mlir:SCFDialect",
|
||||||
|
"@llvm-project//mlir:StandardOps",
|
||||||
|
"@llvm-project//mlir:Support",
|
||||||
|
"@llvm-project//mlir:Transforms",
|
||||||
|
"@llvm-project//mlir:VectorOps",
|
||||||
|
"@llvm-project//mlir:VectorToSCF",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@ -25,7 +25,6 @@ const char* const kXlaOptimizeForSizeCpuOption = "xla_cpu_optimize_for_size";
|
|||||||
const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor";
|
const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor";
|
||||||
const char* const kXlaForceEnableExperimentalLlvmIrGemm =
|
const char* const kXlaForceEnableExperimentalLlvmIrGemm =
|
||||||
"xla_force_enable_experimental_llvm_ir_gemm";
|
"xla_force_enable_experimental_llvm_ir_gemm";
|
||||||
const char* const kXlaUseLinalgForDot = "xla_use_linalg_for_dot";
|
|
||||||
const char* const kLlvmIrGemmTileSize = "xla_llvm_ir_gemm_tile_size";
|
const char* const kLlvmIrGemmTileSize = "xla_llvm_ir_gemm_tile_size";
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
@ -64,12 +63,6 @@ bool ForceEnableExperimentalLlvmIrGemm(const HloModuleConfig& config) {
|
|||||||
return extra_options_map.count(kXlaForceEnableExperimentalLlvmIrGemm) > 0;
|
return extra_options_map.count(kXlaForceEnableExperimentalLlvmIrGemm) > 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool UseLinalgForDot(const HloModuleConfig& config) {
|
|
||||||
const auto& extra_options_map =
|
|
||||||
config.debug_options().xla_backend_extra_options();
|
|
||||||
return extra_options_map.count(kXlaUseLinalgForDot) > 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
static absl::string_view RemoveSuffix(absl::string_view str,
|
static absl::string_view RemoveSuffix(absl::string_view str,
|
||||||
absl::string_view suffix) {
|
absl::string_view suffix) {
|
||||||
CHECK_GE(str.size(), suffix.size());
|
CHECK_GE(str.size(), suffix.size());
|
||||||
|
@ -31,10 +31,12 @@ limitations under the License.
|
|||||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||||
#include "mlir/IR/OperationSupport.h" // from @llvm-project
|
#include "mlir/IR/OperationSupport.h" // from @llvm-project
|
||||||
#include "mlir/IR/Value.h" // from @llvm-project
|
#include "mlir/IR/Value.h" // from @llvm-project
|
||||||
|
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||||
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
|
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
|
||||||
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
|
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
|
||||||
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
|
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
|
||||||
#include "tensorflow/compiler/xla/service/cpu/mlir_emitter.h"
|
#include "tensorflow/compiler/xla/service/cpu/mlir_emitter.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.h"
|
||||||
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
|
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
|
||||||
#include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h"
|
#include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h"
|
||||||
#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
|
#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
|
||||||
@ -202,6 +204,20 @@ class DotOpEmitter {
|
|||||||
.value_or(kDefaultTileSize);
|
.value_or(kDefaultTileSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::array<int64_t, 3> GetMlirGemmTileSize() const {
|
||||||
|
// Tile by 4 x registers x register size. This was picked by running
|
||||||
|
// small matmuls on Haswell and Skylake. There's a lot of room for
|
||||||
|
// improvement here.
|
||||||
|
constexpr int64_t kDefaultTileSizeForM = 4;
|
||||||
|
int64_t elements_per_register =
|
||||||
|
target_machine_features_.vector_register_num_elements(
|
||||||
|
*b_->GetInsertBlock()->getParent(),
|
||||||
|
dot_info_.result_shape.element_type());
|
||||||
|
int64_t num_registers = target_machine_features_.vector_register_count(
|
||||||
|
*b_->GetInsertBlock()->getParent());
|
||||||
|
return {{kDefaultTileSizeForM, num_registers, elements_per_register}};
|
||||||
|
}
|
||||||
|
|
||||||
DotInfo dot_info_;
|
DotInfo dot_info_;
|
||||||
string dot_hlo_name_;
|
string dot_hlo_name_;
|
||||||
const llvm_ir::IrArray& target_array_;
|
const llvm_ir::IrArray& target_array_;
|
||||||
@ -250,6 +266,7 @@ Status DotOpEmitter::EmitLinalgMatmul() {
|
|||||||
absl::StrCat("linalgMatMul_", dot_info_.result_shape.ToString(true), "_",
|
absl::StrCat("linalgMatMul_", dot_info_.result_shape.ToString(true), "_",
|
||||||
dot_info_.lhs_shape.ToString(true), "_",
|
dot_info_.lhs_shape.ToString(true), "_",
|
||||||
dot_info_.rhs_shape.ToString(true));
|
dot_info_.rhs_shape.ToString(true));
|
||||||
|
|
||||||
return EmitMlirFuncAndCall(
|
return EmitMlirFuncAndCall(
|
||||||
mlir_context_, b_, dot_info_.result_shape, operand_shapes, target_ptr,
|
mlir_context_, b_, dot_info_.result_shape, operand_shapes, target_ptr,
|
||||||
operand_ptrs, name, [&](mlir::OpBuilder* builder, mlir::FuncOp function) {
|
operand_ptrs, name, [&](mlir::OpBuilder* builder, mlir::FuncOp function) {
|
||||||
@ -259,6 +276,27 @@ Status DotOpEmitter::EmitLinalgMatmul() {
|
|||||||
mlir::edsc::intrinsics::linalg_matmul(mlir::TypeRange{},
|
mlir::edsc::intrinsics::linalg_matmul(mlir::TypeRange{},
|
||||||
mlir::ValueRange{b, c, a});
|
mlir::ValueRange{b, c, a});
|
||||||
mlir::edsc::intrinsics::std_ret();
|
mlir::edsc::intrinsics::std_ret();
|
||||||
|
|
||||||
|
mlir::linalg::LinalgTilingOptions tilingOptions;
|
||||||
|
tilingOptions = tilingOptions.setTileSizes(GetMlirGemmTileSize());
|
||||||
|
int64 alignment =
|
||||||
|
target_machine_features_.minimum_alignment_for_allocation(
|
||||||
|
ShapeUtil::ByteSizeOf(dot_info_.result_shape));
|
||||||
|
mlir_strategy::MatmulCodegenStrategy strategy;
|
||||||
|
strategy.tile<mlir::linalg::MatmulOp>(tilingOptions)
|
||||||
|
.promote<mlir::linalg::MatmulOp>(
|
||||||
|
mlir::linalg::LinalgPromotionOptions()
|
||||||
|
.setAlignment(alignment)
|
||||||
|
.setUseFullTileBuffersByDefault(true)
|
||||||
|
.setUseAlloca(true))
|
||||||
|
.vectorize<mlir::linalg::MatmulOp>()
|
||||||
|
.setVectorTransformsOptions(
|
||||||
|
mlir::vector::VectorTransformsOptions()
|
||||||
|
.setVectorTransformsOptions(
|
||||||
|
mlir::vector::VectorContractLowering::OuterProduct))
|
||||||
|
.setVectorTransferToSCFOptions(
|
||||||
|
mlir::VectorTransferToSCFOptions().setUnroll(true));
|
||||||
|
strategy.transform(function);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -948,7 +986,8 @@ DotImplementationStrategy GetDotImplementationStrategy(
|
|||||||
|
|
||||||
if (IsAlignedGemm(dot_info, target_machine_features)) {
|
if (IsAlignedGemm(dot_info, target_machine_features)) {
|
||||||
if (CanEmitTiledLlvmIrGemm(config, dot_info, target_machine_features)) {
|
if (CanEmitTiledLlvmIrGemm(config, dot_info, target_machine_features)) {
|
||||||
return options::UseLinalgForDot(config)
|
return primitive_util::IsFloatingPointType(
|
||||||
|
dot_info.result_shape.element_type())
|
||||||
? DotImplementationStrategy::kLinalgMatmul
|
? DotImplementationStrategy::kLinalgMatmul
|
||||||
: DotImplementationStrategy::kTiledLlvmIrGemm;
|
: DotImplementationStrategy::kTiledLlvmIrGemm;
|
||||||
}
|
}
|
||||||
|
@ -17,14 +17,14 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "llvm/Linker/Linker.h"
|
#include "llvm/Linker/Linker.h"
|
||||||
#include "llvm/Transforms/IPO/Internalize.h"
|
#include "llvm/Transforms/IPO/Internalize.h"
|
||||||
#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" // from @llvm-project
|
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" // from @llvm-project
|
||||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project
|
|
||||||
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" // from @llvm-project
|
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" // from @llvm-project
|
||||||
#include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project
|
#include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project
|
||||||
#include "mlir/IR/Module.h" // from @llvm-project
|
#include "mlir/IR/Module.h" // from @llvm-project
|
||||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||||
#include "mlir/Target/LLVMIR.h" // from @llvm-project
|
#include "mlir/Target/LLVMIR.h" // from @llvm-project
|
||||||
|
#include "mlir/Transforms/Passes.h" // from @llvm-project
|
||||||
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
|
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
@ -35,9 +35,9 @@ namespace {
|
|||||||
std::unique_ptr<llvm::Module> MakeLLVMModule(mlir::OwningModuleRef module) {
|
std::unique_ptr<llvm::Module> MakeLLVMModule(mlir::OwningModuleRef module) {
|
||||||
mlir::PassManager manager(module->getContext());
|
mlir::PassManager manager(module->getContext());
|
||||||
manager.addPass(mlir::createConvertLinalgToLoopsPass());
|
manager.addPass(mlir::createConvertLinalgToLoopsPass());
|
||||||
manager.addPass(mlir::createConvertLinalgToLLVMPass());
|
manager.addPass(mlir::createLowerAffinePass());
|
||||||
|
manager.addPass(mlir::createLowerToCFGPass());
|
||||||
manager.addPass(mlir::createConvertVectorToLLVMPass());
|
manager.addPass(mlir::createConvertVectorToLLVMPass());
|
||||||
manager.addPass(mlir::createLowerToLLVMPass());
|
|
||||||
CHECK(succeeded(manager.run(*module)));
|
CHECK(succeeded(manager.run(*module)));
|
||||||
return mlir::translateModuleToLLVMIR(*module);
|
return mlir::translateModuleToLLVMIR(*module);
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,269 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
#include "llvm/Support/Debug.h"
|
||||||
|
#include "mlir/Analysis/SliceAnalysis.h" // from @llvm-project
|
||||||
|
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h" // from @llvm-project
|
||||||
|
#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project
|
||||||
|
#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project
|
||||||
|
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h" // from @llvm-project
|
||||||
|
#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // from @llvm-project
|
||||||
|
#include "mlir/Dialect/Linalg/Utils/Utils.h" // from @llvm-project
|
||||||
|
#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
|
||||||
|
#include "mlir/Dialect/SCF/Utils.h" // from @llvm-project
|
||||||
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||||
|
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h" // from @llvm-project
|
||||||
|
#include "mlir/Dialect/Vector/VectorOps.h" // from @llvm-project
|
||||||
|
#include "mlir/Dialect/Vector/VectorTransforms.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/AffineExpr.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/Dominance.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/OperationSupport.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/Value.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/Visitors.h" // from @llvm-project
|
||||||
|
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||||
|
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||||
|
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||||
|
#include "mlir/Transforms/LoopUtils.h" // from @llvm-project
|
||||||
|
#include "mlir/Transforms/Passes.h" // from @llvm-project
|
||||||
|
|
||||||
|
// TODO(kramerb): Remove this once strategy is in mlir core.
|
||||||
|
|
||||||
|
using namespace mlir; // NOLINT
|
||||||
|
using namespace mlir::linalg; // NOLINT
|
||||||
|
|
||||||
|
#define DEBUG_TYPE "matmul-codegen-strategy"
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
namespace cpu {
|
||||||
|
namespace mlir_strategy {
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// TODO: Cleanup and upstream these to go into core. Please ignore for now !
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
static void hoistRedundantCopies(FuncOp func) {
|
||||||
|
bool changed = true;
|
||||||
|
while (changed) {
|
||||||
|
changed = false;
|
||||||
|
func.walk([&](linalg::FillOp op) {
|
||||||
|
auto loop = op.getParentOfType<scf::ForOp>();
|
||||||
|
if (!loop) return;
|
||||||
|
|
||||||
|
for (auto operand : op.getOperands())
|
||||||
|
if (!loop.isDefinedOutsideOfLoop(operand)) return;
|
||||||
|
|
||||||
|
// Hoist fill before.
|
||||||
|
op.getOperation()->moveBefore(loop);
|
||||||
|
changed = true;
|
||||||
|
});
|
||||||
|
|
||||||
|
func.walk([&](linalg::CopyOp op) {
|
||||||
|
auto loop = op.getParentOfType<scf::ForOp>();
|
||||||
|
if (!loop) return;
|
||||||
|
|
||||||
|
for (auto operand : op.getOperands())
|
||||||
|
if (!loop.isDefinedOutsideOfLoop(operand)) return;
|
||||||
|
|
||||||
|
Value sourceView = op.getInput(0);
|
||||||
|
while (auto subViewOp = sourceView.getDefiningOp<SubViewOp>())
|
||||||
|
sourceView = subViewOp.getViewSource();
|
||||||
|
|
||||||
|
// Source traces back to a block argument.
|
||||||
|
if (sourceView.isa<BlockArgument>()) {
|
||||||
|
op.getOperation()->moveBefore(loop);
|
||||||
|
} else {
|
||||||
|
assert(sourceView.getDefiningOp<ViewOp>() ||
|
||||||
|
sourceView.getDefiningOp<AllocOp>() ||
|
||||||
|
sourceView.getDefiningOp<AllocaOp>());
|
||||||
|
op.getOperation()->moveAfter(loop);
|
||||||
|
}
|
||||||
|
changed = true;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Substitute scf.for = %lb to %ub step %step by an AffineExpr expressing:
|
||||||
|
/// `%lb + %step * new_dim` where
|
||||||
|
/// 1. the AffineExpr for %lb is either an AffineConstantExpr or an
|
||||||
|
/// AffineDimExpr depending on whether the value is constant or not.
|
||||||
|
/// 2. the AffineExpr for %step is either an AffineConstantExpr or an
|
||||||
|
/// AffineSymbolExpr depending on whether the value is constant or not.
|
||||||
|
///
|
||||||
|
static void substitute(scf::ForOp forOp, SmallVectorImpl<AffineExpr> &exprs,
|
||||||
|
SmallVectorImpl<Value> &dims,
|
||||||
|
SmallVectorImpl<Value> &symbols) {
|
||||||
|
MLIRContext *ctx = forOp.getContext();
|
||||||
|
auto lbConstant = forOp.lowerBound().getDefiningOp<ConstantIndexOp>();
|
||||||
|
AffineExpr lb = lbConstant ? getAffineConstantExpr(lbConstant.getValue(), ctx)
|
||||||
|
: getAffineDimExpr(dims.size(), ctx);
|
||||||
|
|
||||||
|
auto stepConstant = forOp.step().getDefiningOp<ConstantIndexOp>();
|
||||||
|
AffineExpr step = stepConstant
|
||||||
|
? getAffineConstantExpr(stepConstant.getValue(), ctx)
|
||||||
|
: getAffineSymbolExpr(symbols.size(), ctx);
|
||||||
|
|
||||||
|
if (!lbConstant) dims.push_back(forOp.lowerBound());
|
||||||
|
if (!stepConstant) symbols.push_back(forOp.step());
|
||||||
|
exprs.push_back(lb + step * getAffineDimExpr(dims.size(), ctx));
|
||||||
|
|
||||||
|
auto ubConstant = forOp.upperBound().getDefiningOp<ConstantIndexOp>();
|
||||||
|
AffineExpr ub = ubConstant ? getAffineConstantExpr(ubConstant.getValue(), ctx)
|
||||||
|
: getAffineDimExpr(dims.size(), ctx);
|
||||||
|
if (!ubConstant) dims.push_back(forOp.upperBound());
|
||||||
|
exprs.push_back(ub);
|
||||||
|
|
||||||
|
dims.push_back(forOp.getInductionVar());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Traverse the .
|
||||||
|
static void substitute(AffineMinOp minOp, SmallVectorImpl<AffineExpr> &exprs,
|
||||||
|
SmallVectorImpl<Value> &dims,
|
||||||
|
SmallVectorImpl<Value> &symbols) {
|
||||||
|
MLIRContext *ctx = minOp.getContext();
|
||||||
|
for (Value v : minOp.getDimOperands()) {
|
||||||
|
if (auto forOp = scf::getForInductionVarOwner(v)) {
|
||||||
|
substitute(forOp, exprs, dims, symbols);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto parentMinOp = v.getDefiningOp<AffineMinOp>()) {
|
||||||
|
substitute(parentMinOp, exprs, dims, symbols);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
exprs.push_back(getAffineDimExpr(dims.size(), ctx));
|
||||||
|
dims.push_back(v);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Perform folding of chains of AffineMinOp.
|
||||||
|
struct AffineMinCanonicalizationPattern : public OpRewritePattern<AffineMinOp> {
|
||||||
|
using OpRewritePattern<AffineMinOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(AffineMinOp minOp,
|
||||||
|
PatternRewriter &rewriter) const override;
|
||||||
|
};
|
||||||
|
|
||||||
|
LogicalResult AffineMinCanonicalizationPattern::matchAndRewrite(
|
||||||
|
AffineMinOp minOp, PatternRewriter &rewriter) const {
|
||||||
|
LLVM_DEBUG(llvm::dbgs() << "\nCanonicalize AffineMin: "
|
||||||
|
<< *minOp.getOperation() << "\n");
|
||||||
|
|
||||||
|
int64_t min = std::numeric_limits<int64_t>::max();
|
||||||
|
for (auto e : minOp.map().getResults())
|
||||||
|
if (auto cstExpr = e.dyn_cast<AffineConstantExpr>())
|
||||||
|
min = std::min(min, cstExpr.getValue());
|
||||||
|
if (min == std::numeric_limits<int64_t>::max()) return failure();
|
||||||
|
|
||||||
|
SmallVector<AffineExpr, 4> exprs;
|
||||||
|
SmallVector<Value, 4> dims, symbols;
|
||||||
|
substitute(minOp, exprs, dims, symbols);
|
||||||
|
|
||||||
|
SmallVector<Value, 4> operands = dims;
|
||||||
|
operands.append(symbols.begin(), symbols.end());
|
||||||
|
|
||||||
|
MLIRContext *ctx = minOp.getContext();
|
||||||
|
auto map = AffineMap::get(dims.size(), symbols.size(), exprs, ctx);
|
||||||
|
LLVM_DEBUG(llvm::dbgs() << "Substitution map: " << map << "\n");
|
||||||
|
|
||||||
|
SmallVector<AffineExpr, 4> modExprs;
|
||||||
|
for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx)
|
||||||
|
modExprs.push_back(getAffineDimExpr(idx, ctx) % min);
|
||||||
|
map = AffineMap::get(map.getNumResults(), 0, modExprs, ctx).compose(map);
|
||||||
|
canonicalizeMapAndOperands(&map, &operands);
|
||||||
|
map = simplifyAffineMap(map);
|
||||||
|
|
||||||
|
LLVM_DEBUG(llvm::dbgs() << "Post mod: " << map << "\n";
|
||||||
|
llvm::interleaveComma(operands, llvm::dbgs()));
|
||||||
|
|
||||||
|
if (!llvm::all_of(map.getResults(), [](AffineExpr e) {
|
||||||
|
if (auto cst = e.dyn_cast<AffineConstantExpr>())
|
||||||
|
return cst.getValue() == 0;
|
||||||
|
return false;
|
||||||
|
}))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<ConstantIndexOp>(minOp, min);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// END TODO
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void MatmulCodegenStrategy::transform(FuncOp func) const {
|
||||||
|
MLIRContext *context = func.getContext();
|
||||||
|
// Emplace patterns one at a time while also maintaining a simple chained
|
||||||
|
// state transition.
|
||||||
|
unsigned stepCount = 0;
|
||||||
|
SmallVector<OwningRewritePatternList, 4> stage1Patterns;
|
||||||
|
auto zeroState = Identifier::get(std::to_string(stepCount), context);
|
||||||
|
auto currentState = zeroState;
|
||||||
|
for (auto &t : transformation_sequence) {
|
||||||
|
auto nextState = Identifier::get(std::to_string(++stepCount), context);
|
||||||
|
auto marker = (currentState == zeroState)
|
||||||
|
? linalg::LinalgMarker({}, nextState)
|
||||||
|
: linalg::LinalgMarker(currentState, nextState);
|
||||||
|
stage1Patterns.emplace_back(t->buildRewritePatterns(context, marker));
|
||||||
|
currentState = nextState;
|
||||||
|
}
|
||||||
|
|
||||||
|
OwningRewritePatternList stage2Patterns =
|
||||||
|
linalg::getLinalgTilingCanonicalizationPatterns(context);
|
||||||
|
stage2Patterns.insert<AffineMinCanonicalizationPattern>(context);
|
||||||
|
|
||||||
|
auto stage3Transforms = [](Operation *op) {
|
||||||
|
// Some of these may be too aggressive as a stage 3 that is applied on each
|
||||||
|
// stage 1 application and may have to be split out to post staged patterns
|
||||||
|
// application (in which case they could just be passes, TBD).
|
||||||
|
PassManager pm(op->getContext());
|
||||||
|
pm.addPass(createLoopInvariantCodeMotionPass());
|
||||||
|
if (failed(pm.run(op->getParentOfType<ModuleOp>())))
|
||||||
|
llvm_unreachable("Unexpected failure in cleanup pass pipeline.");
|
||||||
|
promoteSingleIterationLoops(cast<FuncOp>(op));
|
||||||
|
hoistViewAllocOps(cast<FuncOp>(op));
|
||||||
|
hoistRedundantVectorTransfers(cast<FuncOp>(op));
|
||||||
|
hoistRedundantCopies(cast<FuncOp>(op));
|
||||||
|
return success();
|
||||||
|
};
|
||||||
|
linalg::applyStagedPatterns(func, stage1Patterns, stage2Patterns,
|
||||||
|
stage3Transforms);
|
||||||
|
|
||||||
|
//===--------------------------------------------------------------------===//
|
||||||
|
// Post staged patterns transforms
|
||||||
|
//===--------------------------------------------------------------------===//
|
||||||
|
// Programmatic controlled lowering of vector.contract only.
|
||||||
|
OwningRewritePatternList vectorContractLoweringPatterns;
|
||||||
|
vectorContractLoweringPatterns
|
||||||
|
.insert<ContractionOpToOuterProductOpLowering,
|
||||||
|
ContractionOpToMatmulOpLowering, ContractionOpLowering>(
|
||||||
|
vector_transforms_options, context);
|
||||||
|
applyPatternsAndFoldGreedily(func, vectorContractLoweringPatterns);
|
||||||
|
|
||||||
|
// Programmatic controlled lowering of vector.transfer only.
|
||||||
|
OwningRewritePatternList vectorToLoopsPatterns;
|
||||||
|
populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context,
|
||||||
|
vector_to_scf_options);
|
||||||
|
applyPatternsAndFoldGreedily(func, vectorToLoopsPatterns);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlir_strategy
|
||||||
|
} // namespace cpu
|
||||||
|
} // namespace xla
|
@ -0,0 +1,188 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef MLIR_EDGE_BENCHMARKS_STRATEGIES_MATMULCODEGENSTRATEGIES_H_
|
||||||
|
#define MLIR_EDGE_BENCHMARKS_STRATEGIES_MATMULCODEGENSTRATEGIES_H_
|
||||||
|
|
||||||
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
#include "llvm/ADT/StringSwitch.h"
|
||||||
|
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h" // from @llvm-project
|
||||||
|
#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // from @llvm-project
|
||||||
|
#include "mlir/Dialect/Vector/VectorOps.h" // from @llvm-project
|
||||||
|
#include "mlir/Dialect/Vector/VectorTransforms.h" // from @llvm-project
|
||||||
|
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||||
|
|
||||||
|
// TODO(kramerb): Remove this once strategy is in mlir core.
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
namespace cpu {
|
||||||
|
namespace mlir_strategy {
|
||||||
|
|
||||||
|
/// Abstract Transformation class applied in a sequence that also handles state
|
||||||
|
/// through markers.
|
||||||
|
struct Transformation {
|
||||||
|
virtual ~Transformation() = default;
|
||||||
|
virtual mlir::OwningRewritePatternList buildRewritePatterns(
|
||||||
|
mlir::MLIRContext *context, mlir::linalg::LinalgMarker m) = 0;
|
||||||
|
mlir::linalg::LinalgMarker marker;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Promotion transformation enqueues a particular stage-1 pattern for
|
||||||
|
/// `Tile<LinalgOpType>`with the appropriate `options`.
|
||||||
|
// TODO: variadic LinalgOpTypes.
|
||||||
|
template <typename LinalgOpType>
|
||||||
|
struct Tile : public Transformation {
|
||||||
|
explicit Tile(mlir::linalg::LinalgTilingOptions options) : options(options) {}
|
||||||
|
|
||||||
|
mlir::OwningRewritePatternList buildRewritePatterns(
|
||||||
|
mlir::MLIRContext *context, mlir::linalg::LinalgMarker m) override {
|
||||||
|
mlir::OwningRewritePatternList tiling_patterns;
|
||||||
|
tiling_patterns.insert<mlir::linalg::LinalgTilingPattern<LinalgOpType>>(
|
||||||
|
context, options, m);
|
||||||
|
return tiling_patterns;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
mlir::linalg::LinalgTilingOptions options;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Promotion transformation enqueues a particular stage-1 pattern for
|
||||||
|
/// `Promote<LinalgOpType>`with the appropriate `options`.
|
||||||
|
// TODO: variadic LinalgOpTypes.
|
||||||
|
template <typename LinalgOpType>
|
||||||
|
struct Promote : public Transformation {
|
||||||
|
explicit Promote(mlir::linalg::LinalgPromotionOptions options)
|
||||||
|
: options(options) {}
|
||||||
|
|
||||||
|
mlir::OwningRewritePatternList buildRewritePatterns(
|
||||||
|
mlir::MLIRContext *context, mlir::linalg::LinalgMarker m) override {
|
||||||
|
mlir::OwningRewritePatternList promotion_patterns;
|
||||||
|
promotion_patterns
|
||||||
|
.insert<mlir::linalg::LinalgPromotionPattern<LinalgOpType>>(context,
|
||||||
|
options, m);
|
||||||
|
return promotion_patterns;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
mlir::linalg::LinalgPromotionOptions options;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Vectorization transformation enqueues a particular stage-1 pattern for
|
||||||
|
/// `LinalgVectorizationPattern<LinalgOpType>` as well as copy to vector
|
||||||
|
/// transfer rewrite forwarding patterns.
|
||||||
|
// TODO: variadic LinalgOpTypes.
|
||||||
|
template <typename LinalgOpType>
|
||||||
|
struct Vectorize : public Transformation {
|
||||||
|
mlir::OwningRewritePatternList buildRewritePatterns(
|
||||||
|
mlir::MLIRContext *context, mlir::linalg::LinalgMarker m) override {
|
||||||
|
mlir::OwningRewritePatternList vectorization_patterns;
|
||||||
|
// FillOp may interfere with forwarding patterns atm, so we bump up the
|
||||||
|
// priority of LinalgCopyVTRForwardingPattern /
|
||||||
|
// LinalgCopyVTWForwardingPattern.
|
||||||
|
vectorization_patterns
|
||||||
|
.insert<mlir::linalg::LinalgVectorizationPattern<LinalgOpType>>(context,
|
||||||
|
m);
|
||||||
|
vectorization_patterns.insert<mlir::linalg::LinalgCopyVTRForwardingPattern,
|
||||||
|
mlir::linalg::LinalgCopyVTWForwardingPattern>(
|
||||||
|
context,
|
||||||
|
/*benefit=*/2);
|
||||||
|
return vectorization_patterns;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Matmul-specific strategy object controls how a linalg.matmul is
|
||||||
|
/// progressively lowered.
|
||||||
|
/// The strategy uses a 3-level staged patterns strategy which allows ordering
|
||||||
|
/// transformations by using the Linalg `applyStagedPatterns` function, where:
|
||||||
|
/// 1. The first stage consists of the successive `tile`, `promote` and
|
||||||
|
/// `vectorize` patterns, applied sequentially.
|
||||||
|
/// 2. The second stage consists of common local canonicalization patterns
|
||||||
|
/// that are applied eagerly after each stage-1 pattern.
|
||||||
|
/// 3. the third stage consists of more global transformation, also applied
|
||||||
|
/// eagerly, after all stage-2 patterns. Such more global transformations
|
||||||
|
struct MatmulCodegenStrategy {
|
||||||
|
/// Append a pattern to add a level of tiling for `LinalgOpType` with tiling
|
||||||
|
/// `options`.
|
||||||
|
template <typename LinalgOpType>
|
||||||
|
MatmulCodegenStrategy &tile(mlir::linalg::LinalgTilingOptions options) {
|
||||||
|
transformation_sequence.emplace_back(new Tile<LinalgOpType>(options));
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
/// Conditionally append a pattern to add a level of tiling for `LinalgOpType`
|
||||||
|
/// with tiling `options`.
|
||||||
|
template <typename LinalgOpType>
|
||||||
|
MatmulCodegenStrategy &tileIf(bool b,
|
||||||
|
mlir::linalg::LinalgTilingOptions options) {
|
||||||
|
return b ? tile<LinalgOpType>(options) : *this;
|
||||||
|
}
|
||||||
|
/// Append a pattern to add a level of promotion for `LinalgOpType` with
|
||||||
|
/// promotion `options`.
|
||||||
|
template <typename LinalgOpType>
|
||||||
|
MatmulCodegenStrategy &promote(mlir::linalg::LinalgPromotionOptions options) {
|
||||||
|
transformation_sequence.emplace_back(new Promote<LinalgOpType>(options));
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
/// Conditionally append a pattern to add a level of promotion for
|
||||||
|
/// `LinalgOpType` with promotion `options`.
|
||||||
|
template <typename LinalgOpType>
|
||||||
|
MatmulCodegenStrategy &promoteIf(
|
||||||
|
bool b, mlir::linalg::LinalgPromotionOptions options) {
|
||||||
|
return b ? promote<LinalgOpType>(options) : *this;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
/// Append a pattern to rewrite `LinalgOpType` as a vector operation.
|
||||||
|
template <typename LinalgOpType>
|
||||||
|
MatmulCodegenStrategy &vectorize() {
|
||||||
|
transformation_sequence.emplace_back(new Vectorize<LinalgOpType>());
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
/// Conditionally append a pattern to rewrite `LinalgOpType` as a vector
|
||||||
|
/// operation.
|
||||||
|
template <typename LinalgOpType>
|
||||||
|
MatmulCodegenStrategy &vectorizeIf(bool b) {
|
||||||
|
return b ? vectorize<LinalgOpType>() : *this;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
/// Configure the post staged-patterns late vector transformations.
|
||||||
|
MatmulCodegenStrategy &setVectorTransformsOptions(
|
||||||
|
mlir::vector::VectorTransformsOptions options) {
|
||||||
|
vector_transforms_options = options;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
/// Configure the post staged-patterns late vector.transfer to scf conversion.
|
||||||
|
MatmulCodegenStrategy &setVectorTransferToSCFOptions(
|
||||||
|
mlir::VectorTransferToSCFOptions options) {
|
||||||
|
vector_to_scf_options = options;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Apply the transformation patterns in sequence with cleanup transformations
|
||||||
|
/// interleaved.
|
||||||
|
void transform(mlir::FuncOp func) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
mlir::LogicalResult postPatternTransforms(mlir::Operation *func) const;
|
||||||
|
|
||||||
|
mlir::vector::VectorTransformsOptions vector_transforms_options;
|
||||||
|
mlir::VectorTransferToSCFOptions vector_to_scf_options;
|
||||||
|
llvm::SmallVector<std::unique_ptr<Transformation>, 4> transformation_sequence;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mlir_strategy
|
||||||
|
} // namespace cpu
|
||||||
|
} // namespace xla
|
||||||
|
|
||||||
|
#endif // MLIR_EDGE_BENCHMARKS_STRATEGIES_MATMULCODEGENSTRATEGIES_H_
|
@ -52,6 +52,12 @@ class TargetMachineFeatures {
|
|||||||
virtual int vector_register_num_elements(const llvm::Function& function,
|
virtual int vector_register_num_elements(const llvm::Function& function,
|
||||||
PrimitiveType type) const = 0;
|
PrimitiveType type) const = 0;
|
||||||
|
|
||||||
|
// Return the number of vector registers. We need to pass in
|
||||||
|
// "function" since llvm functions can contain annotations for specializing
|
||||||
|
// them to specific micro-architectures (though currently XLA does not use
|
||||||
|
// this functionality).
|
||||||
|
virtual int vector_register_count(const llvm::Function& function) const = 0;
|
||||||
|
|
||||||
// Returns the minimum alignment for a buffer of size size_bytes.
|
// Returns the minimum alignment for a buffer of size size_bytes.
|
||||||
virtual int64 minimum_alignment_for_allocation(int64 size_bytes) const = 0;
|
virtual int64 minimum_alignment_for_allocation(int64 size_bytes) const = 0;
|
||||||
|
|
||||||
@ -84,6 +90,12 @@ class LLVMTargetMachineFeatures : public TargetMachineFeatures {
|
|||||||
(primitive_util::BitWidth(type) / 8);
|
(primitive_util::BitWidth(type) / 8);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int vector_register_count(const llvm::Function& function) const override {
|
||||||
|
llvm::TargetTransformInfo* tti = GetTargetTransformInfoFor(function);
|
||||||
|
return static_cast<int>(tti->getNumberOfRegisters(
|
||||||
|
tti->getRegisterClassForType(/*Vector=*/true)));
|
||||||
|
}
|
||||||
|
|
||||||
int64 minimum_alignment_for_allocation(int64 size_bytes) const override;
|
int64 minimum_alignment_for_allocation(int64 size_bytes) const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -44,6 +44,10 @@ class TargetMachineFeaturesWithFakeAlignmentLogic
|
|||||||
LOG(FATAL) << "Unexpected call to " << __func__;
|
LOG(FATAL) << "Unexpected call to " << __func__;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int vector_register_count(const llvm::Function& function) const override {
|
||||||
|
LOG(FATAL) << "Unexpected call to " << __func__;
|
||||||
|
}
|
||||||
|
|
||||||
int64 minimum_alignment_for_allocation(int64 size_bytes) const override {
|
int64 minimum_alignment_for_allocation(int64 size_bytes) const override {
|
||||||
return fake_alignment_logic_(size_bytes);
|
return fake_alignment_logic_(size_bytes);
|
||||||
}
|
}
|
||||||
|
@ -1642,7 +1642,8 @@ void AlternateMemoryBestFitHeap::AddAsyncCopy(
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool AlternateMemoryBestFitHeap::ViolatesMaximumOutstandingAsyncCopies(
|
bool AlternateMemoryBestFitHeap::ViolatesMaximumOutstandingAsyncCopies(
|
||||||
int64 start_time, int64 end_time, bool is_prefetch) const {
|
int64 start_time, int64 end_time, bool is_prefetch,
|
||||||
|
int64 extra_async_copy_limit) const {
|
||||||
if (options_.max_outstanding_prefetches < 0 && is_prefetch) {
|
if (options_.max_outstanding_prefetches < 0 && is_prefetch) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -1655,12 +1656,14 @@ bool AlternateMemoryBestFitHeap::ViolatesMaximumOutstandingAsyncCopies(
|
|||||||
int64 num_prefetches =
|
int64 num_prefetches =
|
||||||
prefetch_interval_tree_.ChunksOverlappingInTime(start_time, end_time)
|
prefetch_interval_tree_.ChunksOverlappingInTime(start_time, end_time)
|
||||||
.size();
|
.size();
|
||||||
return num_prefetches >= options_.max_outstanding_prefetches;
|
return num_prefetches >=
|
||||||
|
options_.max_outstanding_prefetches + extra_async_copy_limit;
|
||||||
} else {
|
} else {
|
||||||
int64 num_evictions =
|
int64 num_evictions =
|
||||||
eviction_interval_tree_.ChunksOverlappingInTime(start_time, end_time)
|
eviction_interval_tree_.ChunksOverlappingInTime(start_time, end_time)
|
||||||
.size();
|
.size();
|
||||||
return num_evictions >= options_.max_outstanding_evictions;
|
return num_evictions >=
|
||||||
|
options_.max_outstanding_evictions + extra_async_copy_limit;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1911,6 +1914,11 @@ bool AlternateMemoryBestFitHeap::Prefetch(
|
|||||||
// outstanding async copy limit or async copy ordering, set
|
// outstanding async copy limit or async copy ordering, set
|
||||||
// prefetch_failed_due_to_async_copy_.
|
// prefetch_failed_due_to_async_copy_.
|
||||||
prefetch_failed_due_to_async_copy_ = false;
|
prefetch_failed_due_to_async_copy_ = false;
|
||||||
|
// While uses might be allowed to have additional outstanding prefetches.
|
||||||
|
int64 extra_async_copy_limit =
|
||||||
|
request.use->hlo_use.instruction->opcode() == HloOpcode::kWhile
|
||||||
|
? options_.while_use_extra_outstanding_prefetch_limit
|
||||||
|
: 0;
|
||||||
while (!options_.prefetch_interval_picker->Done()) {
|
while (!options_.prefetch_interval_picker->Done()) {
|
||||||
alternate_mem_interval.start = options_.prefetch_interval_picker->Next();
|
alternate_mem_interval.start = options_.prefetch_interval_picker->Next();
|
||||||
CHECK_LT(alternate_mem_interval.start, request.latest_prefetch_time);
|
CHECK_LT(alternate_mem_interval.start, request.latest_prefetch_time);
|
||||||
@ -1924,9 +1932,9 @@ bool AlternateMemoryBestFitHeap::Prefetch(
|
|||||||
prefetch_failed_due_to_async_copy_ = true;
|
prefetch_failed_due_to_async_copy_ = true;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (ViolatesMaximumOutstandingAsyncCopies(alternate_mem_interval.start,
|
if (ViolatesMaximumOutstandingAsyncCopies(
|
||||||
request.latest_prefetch_time,
|
alternate_mem_interval.start, request.latest_prefetch_time,
|
||||||
/*is_prefetch=*/true)) {
|
/*is_prefetch=*/true, extra_async_copy_limit)) {
|
||||||
VLOG(4) << "This would violate the outstanding async copy limit.";
|
VLOG(4) << "This would violate the outstanding async copy limit.";
|
||||||
prefetch_failed_due_to_async_copy_ = true;
|
prefetch_failed_due_to_async_copy_ = true;
|
||||||
continue;
|
continue;
|
||||||
|
@ -379,6 +379,10 @@ class MemorySpaceAssignment {
|
|||||||
int64 max_outstanding_prefetches = -1;
|
int64 max_outstanding_prefetches = -1;
|
||||||
int64 max_outstanding_evictions = -1;
|
int64 max_outstanding_evictions = -1;
|
||||||
|
|
||||||
|
// Extra outstanding prefetch limit for while uses (in addition to
|
||||||
|
// max_outstanding_prefetches).
|
||||||
|
int64 while_use_extra_outstanding_prefetch_limit = 0;
|
||||||
|
|
||||||
// Specifies the maximum number of retries that will be performed for each
|
// Specifies the maximum number of retries that will be performed for each
|
||||||
// value in case prefetching failed due to running out of asynchronous
|
// value in case prefetching failed due to running out of asynchronous
|
||||||
// copies or asynchronous copy ordering.
|
// copies or asynchronous copy ordering.
|
||||||
@ -1019,9 +1023,12 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
|
|||||||
void AddToChunkMap(const HloValue* buffer, Chunk chunk) override {}
|
void AddToChunkMap(const HloValue* buffer, Chunk chunk) override {}
|
||||||
|
|
||||||
// Returns true if the addition of an asynchronous copy in the given time
|
// Returns true if the addition of an asynchronous copy in the given time
|
||||||
// interval would violate the maximum number of asynchronous copies.
|
// interval would violate the maximum number of asynchronous copies. An extra
|
||||||
bool ViolatesMaximumOutstandingAsyncCopies(int64 start_time, int64 end_time,
|
// async copy limit can be provided to increase the limit of asynchronous
|
||||||
bool is_prefetch) const;
|
// copies for this instance.
|
||||||
|
bool ViolatesMaximumOutstandingAsyncCopies(
|
||||||
|
int64 start_time, int64 end_time, bool is_prefetch,
|
||||||
|
int64 extra_async_copy_limit = 0) const;
|
||||||
|
|
||||||
// Return true if the asynchronous copy would violate the pipelining order.
|
// Return true if the asynchronous copy would violate the pipelining order.
|
||||||
bool ViolatesAsyncCopyOrdering(int64 start_time, int64 end_time) const;
|
bool ViolatesAsyncCopyOrdering(int64 start_time, int64 end_time) const;
|
||||||
|
@ -18,6 +18,10 @@ limitations under the License.
|
|||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
#include "third_party/gpus/cuda/include/cuda.h"
|
#include "third_party/gpus/cuda/include/cuda.h"
|
||||||
#endif
|
#endif
|
||||||
|
#if TENSORFLOW_USE_ROCM
|
||||||
|
#include "rocm/include/hip/hip_fp16.h"
|
||||||
|
typedef __half2 half2;
|
||||||
|
#endif
|
||||||
#include "tensorflow/core/kernels/fused_batch_norm_op.h"
|
#include "tensorflow/core/kernels/fused_batch_norm_op.h"
|
||||||
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
||||||
|
|
||||||
@ -174,6 +178,11 @@ template <TensorFormat tensor_format, bool add_side_input,
|
|||||||
struct FusedBatchNormInferenceKernel<Eigen::half, float, tensor_format,
|
struct FusedBatchNormInferenceKernel<Eigen::half, float, tensor_format,
|
||||||
add_side_input, activation_mode,
|
add_side_input, activation_mode,
|
||||||
/*is_generic_kernel=*/false> {
|
/*is_generic_kernel=*/false> {
|
||||||
|
#if TENSORFLOW_USE_ROCM
|
||||||
|
using IT = __half;
|
||||||
|
#else
|
||||||
|
using IT = Eigen::half;
|
||||||
|
#endif
|
||||||
using T = Eigen::half;
|
using T = Eigen::half;
|
||||||
using U = float;
|
using U = float;
|
||||||
|
|
||||||
@ -185,15 +194,19 @@ struct FusedBatchNormInferenceKernel<Eigen::half, float, tensor_format,
|
|||||||
/*is_generic_kernel=*/true>;
|
/*is_generic_kernel=*/true>;
|
||||||
|
|
||||||
__device__ static void run(int32 count, int32 channels_size,
|
__device__ static void run(int32 count, int32 channels_size,
|
||||||
int32 inner_dim_size, const T* __restrict__ in,
|
int32 inner_dim_size, const T* __restrict__ _in,
|
||||||
const U* __restrict__ scale,
|
const U* __restrict__ scale,
|
||||||
const U* __restrict__ offset,
|
const U* __restrict__ offset,
|
||||||
const U* __restrict__ mean,
|
const U* __restrict__ mean,
|
||||||
const U* __restrict__ var,
|
const U* __restrict__ var,
|
||||||
const T* __restrict__ side_input, float epsilon,
|
const T* __restrict__ _side_input, float epsilon,
|
||||||
T* __restrict__ out) {
|
T* __restrict__ _out) {
|
||||||
// Old GPUs do not have (or have very slow) fp16 arithmetic.
|
// Old GPUs do not have (or have very slow) fp16 arithmetic.
|
||||||
#if __CUDA_ARCH__ >= 610
|
#if (__CUDA_ARCH__ >= 610) || TENSORFLOW_USE_ROCM
|
||||||
|
const IT* in = reinterpret_cast<const IT*>(_in);
|
||||||
|
const IT* side_input = reinterpret_cast<const IT*>(_side_input);
|
||||||
|
IT* out = reinterpret_cast<IT*>(_out);
|
||||||
|
|
||||||
int32 index = blockIdx.x * blockDim.x + threadIdx.x;
|
int32 index = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
const int32 total_device_threads = gridDim.x * blockDim.x;
|
const int32 total_device_threads = gridDim.x * blockDim.x;
|
||||||
|
|
||||||
@ -274,8 +287,8 @@ struct FusedBatchNormInferenceKernel<Eigen::half, float, tensor_format,
|
|||||||
}
|
}
|
||||||
|
|
||||||
#else
|
#else
|
||||||
GenericKernel::run(count, channels_size, inner_dim_size, in, scale, offset,
|
GenericKernel::run(count, channels_size, inner_dim_size, _in, scale, offset,
|
||||||
mean, var, side_input, epsilon, out);
|
mean, var, _side_input, epsilon, _out);
|
||||||
#endif // __CUDA_ARCH__ >= 610
|
#endif // __CUDA_ARCH__ >= 610
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -287,10 +300,16 @@ __global__ void FusedBatchNormInferenceMetaKernel(
|
|||||||
const U* scale, const U* offset, const U* mean, const U* var,
|
const U* scale, const U* offset, const U* mean, const U* var,
|
||||||
const T* side_input, float epsilon, T* out) {
|
const T* side_input, float epsilon, T* out) {
|
||||||
// We prefer to run non-generic specialization, for the given types T and U.
|
// We prefer to run non-generic specialization, for the given types T and U.
|
||||||
// TODO(b/135435976): Temporary disable non-generic kernel implementation.
|
FusedBatchNormInferenceKernel<T, U, tensor_format, add_side_input,
|
||||||
FusedBatchNormInferenceKernel<
|
activation_mode,
|
||||||
T, U, tensor_format, add_side_input, activation_mode,
|
#if TENSORFLOW_USE_ROCM
|
||||||
/*is_generic_kernel=*/true>::run(count, channels_size, inner_dim_size, in,
|
false
|
||||||
|
#else
|
||||||
|
// TODO(b/135435976): Temporary disable
|
||||||
|
// non-generic kernel implementation.
|
||||||
|
/*is_generic_kernel=*/true
|
||||||
|
#endif
|
||||||
|
>::run(count, channels_size, inner_dim_size, in,
|
||||||
scale, offset, mean, var, side_input,
|
scale, offset, mean, var, side_input,
|
||||||
epsilon, out);
|
epsilon, out);
|
||||||
}
|
}
|
||||||
@ -312,7 +331,11 @@ struct FusedBatchNormInferenceFunctor<GPUDevice, T, U> {
|
|||||||
if (count == 0) return;
|
if (count == 0) return;
|
||||||
|
|
||||||
bool launched = false;
|
bool launched = false;
|
||||||
|
#if TENSORFLOW_USE_ROCM
|
||||||
|
constexpr int32 kThreadInBlock = 1024;
|
||||||
|
#else
|
||||||
constexpr int32 kThreadInBlock = 512;
|
constexpr int32 kThreadInBlock = 512;
|
||||||
|
#endif
|
||||||
|
|
||||||
#define LAUNCH(DATA_FORMAT, ADD_SIDE_INPUT, ACTIVATION, CHANNEL_SIZE, \
|
#define LAUNCH(DATA_FORMAT, ADD_SIDE_INPUT, ACTIVATION, CHANNEL_SIZE, \
|
||||||
INNER_DIM_SIZE) \
|
INNER_DIM_SIZE) \
|
||||||
|
@ -108,7 +108,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
|
#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
|
||||||
#define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
|
#define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
|
||||||
#define TF_GRAPH_DEF_VERSION 449 // Updated: 2020/7/1
|
#define TF_GRAPH_DEF_VERSION 450 // Updated: 2020/7/2
|
||||||
|
|
||||||
// Checkpoint compatibility versions (the versions field in SavedSliceMeta).
|
// Checkpoint compatibility versions (the versions field in SavedSliceMeta).
|
||||||
//
|
//
|
||||||
|
@ -37,14 +37,14 @@ absl::Status ObjectReader::ReadNonConstantTensor(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (tensor_to_value->find(tensor_idx) == tensor_to_value->end()) {
|
if (tensor_to_value->find(tensor_idx) == tensor_to_value->end()) {
|
||||||
TfLiteTensor& tflite_tensor = context->tensors[tensor_idx];
|
TfLiteTensor* tflite_tensor = &context->tensors[tensor_idx];
|
||||||
if (tflite::IsConstantTensor(&tflite_tensor)) {
|
if (tflite::IsConstantTensor(tflite_tensor)) {
|
||||||
return absl::InvalidArgumentError(absl::StrCat(
|
return absl::InvalidArgumentError(absl::StrCat(
|
||||||
"ReadNonConstantTensor: value is a constant tensor: ", tensor_idx));
|
"ReadNonConstantTensor: value is a constant tensor: ", tensor_idx));
|
||||||
}
|
}
|
||||||
|
|
||||||
if ((tflite_tensor.type == kTfLiteInt8 ||
|
if ((tflite_tensor->type == kTfLiteInt8 ||
|
||||||
tflite_tensor.type == kTfLiteUInt8) &&
|
tflite_tensor->type == kTfLiteUInt8) &&
|
||||||
quant_conversion_map) {
|
quant_conversion_map) {
|
||||||
// Quantized case
|
// Quantized case
|
||||||
if (quant_conversion_map->find(tensor_idx) ==
|
if (quant_conversion_map->find(tensor_idx) ==
|
||||||
@ -70,9 +70,9 @@ absl::Status ObjectReader::ReadNonConstantTensor(
|
|||||||
value->quant_params.emplace();
|
value->quant_params.emplace();
|
||||||
// tflite_tensor from the outer scope is invalidated due to calling
|
// tflite_tensor from the outer scope is invalidated due to calling
|
||||||
// CreateNewTensorWithDifferentType
|
// CreateNewTensorWithDifferentType
|
||||||
tflite_tensor = context->tensors[tensor_idx];
|
tflite_tensor = &context->tensors[tensor_idx];
|
||||||
RETURN_IF_ERROR(
|
RETURN_IF_ERROR(
|
||||||
PopulateQuantParams(tflite_tensor, &value->quant_params.value()));
|
PopulateQuantParams(*tflite_tensor, &value->quant_params.value()));
|
||||||
(*tensor_to_value)[fp_tensor_index] = value;
|
(*tensor_to_value)[fp_tensor_index] = value;
|
||||||
}
|
}
|
||||||
// We do not use the original tensor index as reference for the GPU
|
// We do not use the original tensor index as reference for the GPU
|
||||||
@ -82,7 +82,7 @@ absl::Status ObjectReader::ReadNonConstantTensor(
|
|||||||
// Floating-point case.
|
// Floating-point case.
|
||||||
Value* value = graph->NewValue();
|
Value* value = graph->NewValue();
|
||||||
RETURN_IF_ERROR(
|
RETURN_IF_ERROR(
|
||||||
ConvertTfLiteTensorToTensorRef(tflite_tensor, &value->tensor));
|
ConvertTfLiteTensorToTensorRef(*tflite_tensor, &value->tensor));
|
||||||
value->tensor.ref = tensor_idx;
|
value->tensor.ref = tensor_idx;
|
||||||
(*tensor_to_value)[tensor_idx] = value;
|
(*tensor_to_value)[tensor_idx] = value;
|
||||||
}
|
}
|
||||||
|
@ -21,7 +21,19 @@ sh_binary(
|
|||||||
# When the static framework is built with bazel, the all header files are moved
|
# When the static framework is built with bazel, the all header files are moved
|
||||||
# to the "Headers" directory with no header path prefixes. This auxiliary rule
|
# to the "Headers" directory with no header path prefixes. This auxiliary rule
|
||||||
# is used for stripping the path prefix to the "common.h" file included by the
|
# is used for stripping the path prefix to the "common.h" file included by the
|
||||||
# "xnnpack_delegate.h" header.
|
# "c_api.h" header.
|
||||||
|
genrule(
|
||||||
|
name = "strip_c_api_include_hdr",
|
||||||
|
srcs = ["//tensorflow/lite/c:c_api.h"],
|
||||||
|
outs = ["c_api.h"],
|
||||||
|
cmd = """
|
||||||
|
sed 's|#include ".*common.h"|#include "common.h"|'\
|
||||||
|
"$(location //tensorflow/lite/c:c_api.h)"\
|
||||||
|
> "$@"
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Similar rule as above, but for the "xnnpack_delegate.h" header.
|
||||||
genrule(
|
genrule(
|
||||||
name = "strip_xnnpack_include_hdr",
|
name = "strip_xnnpack_include_hdr",
|
||||||
srcs = ["//tensorflow/lite/delegates/xnnpack:xnnpack_delegate.h"],
|
srcs = ["//tensorflow/lite/delegates/xnnpack:xnnpack_delegate.h"],
|
||||||
@ -37,8 +49,8 @@ genrule(
|
|||||||
tflite_ios_static_framework(
|
tflite_ios_static_framework(
|
||||||
name = "TensorFlowLiteC_framework",
|
name = "TensorFlowLiteC_framework",
|
||||||
hdrs = [
|
hdrs = [
|
||||||
|
":c_api.h",
|
||||||
":xnnpack_delegate.h",
|
":xnnpack_delegate.h",
|
||||||
"//tensorflow/lite/c:c_api.h",
|
|
||||||
"//tensorflow/lite/c:common.h",
|
"//tensorflow/lite/c:common.h",
|
||||||
],
|
],
|
||||||
bundle_name = "TensorFlowLiteC",
|
bundle_name = "TensorFlowLiteC",
|
||||||
@ -136,12 +148,15 @@ cc_library(
|
|||||||
# Used for building TensorFlowLiteC framework.
|
# Used for building TensorFlowLiteC framework.
|
||||||
build_test(
|
build_test(
|
||||||
name = "framework_build_test",
|
name = "framework_build_test",
|
||||||
|
# build_test targets are not meant to be run with sanitizers.
|
||||||
tags = [
|
tags = [
|
||||||
"noasan", # b/147230742
|
"noasan",
|
||||||
"nomsan", # b/145205324
|
"nomsan",
|
||||||
"notsan", # b/145205324
|
"notsan",
|
||||||
],
|
],
|
||||||
targets = [
|
targets = [
|
||||||
|
":TensorFlowLiteCCoreML_framework",
|
||||||
|
":TensorFlowLiteCMetal_framework",
|
||||||
":TensorFlowLiteC_framework",
|
":TensorFlowLiteC_framework",
|
||||||
":TensorFlowLiteSelectTfOps_framework",
|
":TensorFlowLiteSelectTfOps_framework",
|
||||||
],
|
],
|
||||||
|
@ -39,6 +39,7 @@ BenchmarkParams BenchmarkModel::DefaultParams() {
|
|||||||
params.AddParam("output_prefix", BenchmarkParam::Create<std::string>(""));
|
params.AddParam("output_prefix", BenchmarkParam::Create<std::string>(""));
|
||||||
params.AddParam("warmup_runs", BenchmarkParam::Create<int32_t>(1));
|
params.AddParam("warmup_runs", BenchmarkParam::Create<int32_t>(1));
|
||||||
params.AddParam("warmup_min_secs", BenchmarkParam::Create<float>(0.5f));
|
params.AddParam("warmup_min_secs", BenchmarkParam::Create<float>(0.5f));
|
||||||
|
params.AddParam("verbose", BenchmarkParam::Create<bool>(false));
|
||||||
return params;
|
return params;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -100,31 +101,31 @@ std::vector<Flag> BenchmarkModel::GetFlags() {
|
|||||||
"warmup_min_secs", ¶ms_,
|
"warmup_min_secs", ¶ms_,
|
||||||
"minimum number of seconds to rerun for, potentially making the "
|
"minimum number of seconds to rerun for, potentially making the "
|
||||||
"actual number of warm-up runs to be greater than warmup_runs"),
|
"actual number of warm-up runs to be greater than warmup_runs"),
|
||||||
|
CreateFlag<bool>("verbose", ¶ms_,
|
||||||
|
"Whether to log parameters whose values are not set. "
|
||||||
|
"By default, only log those parameters that are set by "
|
||||||
|
"parsing their values from the commandline flag.."),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define LOG_PARAM(type, name, prefix, suffix) \
|
||||||
|
LOG_BENCHMARK_PARAM(params_, type, name, prefix, suffix, verbose)
|
||||||
void BenchmarkModel::LogParams() {
|
void BenchmarkModel::LogParams() {
|
||||||
TFLITE_LOG(INFO) << "Min num runs: [" << params_.Get<int32_t>("num_runs")
|
const bool verbose = params_.Get<bool>("verbose");
|
||||||
<< "]";
|
LOG_PARAM(int32_t, "num_runs", "Min num runs: [", "]");
|
||||||
TFLITE_LOG(INFO) << "Min runs duration (seconds): ["
|
LOG_PARAM(int32_t, "num_runs", "Min num runs: [", "]");
|
||||||
<< params_.Get<float>("min_secs") << "]";
|
LOG_PARAM(float, "min_secs", "Min runs duration (seconds): [", "]");
|
||||||
TFLITE_LOG(INFO) << "Max runs duration (seconds): ["
|
LOG_PARAM(float, "max_secs", "Max runs duration (seconds): [", "]");
|
||||||
<< params_.Get<float>("max_secs") << "]";
|
LOG_PARAM(float, "run_delay", "Inter-run delay (seconds): [", "]");
|
||||||
TFLITE_LOG(INFO) << "Inter-run delay (seconds): ["
|
LOG_PARAM(int32_t, "num_threads", "Num threads: [", "]");
|
||||||
<< params_.Get<float>("run_delay") << "]";
|
LOG_PARAM(bool, "use_caching", "Use caching: [", "]");
|
||||||
TFLITE_LOG(INFO) << "Num threads: [" << params_.Get<int32_t>("num_threads")
|
LOG_PARAM(std::string, "benchmark_name", "Benchmark name: [", "]");
|
||||||
<< "]";
|
LOG_PARAM(std::string, "output_prefix", "Output prefix: [", "]");
|
||||||
TFLITE_LOG(INFO) << "Use caching: [" << params_.Get<bool>("use_caching")
|
LOG_PARAM(int32_t, "warmup_runs", "Min warmup runs: [", "]");
|
||||||
<< "]";
|
LOG_PARAM(float, "warmup_min_secs", "Min warmup runs duration (seconds): [",
|
||||||
TFLITE_LOG(INFO) << "Benchmark name: ["
|
"]");
|
||||||
<< params_.Get<std::string>("benchmark_name") << "]";
|
|
||||||
TFLITE_LOG(INFO) << "Output prefix: ["
|
|
||||||
<< params_.Get<std::string>("output_prefix") << "]";
|
|
||||||
TFLITE_LOG(INFO) << "Min warmup runs: ["
|
|
||||||
<< params_.Get<int32_t>("warmup_runs") << "]";
|
|
||||||
TFLITE_LOG(INFO) << "Min warmup runs duration (seconds): ["
|
|
||||||
<< params_.Get<float>("warmup_min_secs") << "]";
|
|
||||||
}
|
}
|
||||||
|
#undef LOG_PARAM
|
||||||
|
|
||||||
TfLiteStatus BenchmarkModel::PrepareInputData() { return kTfLiteOk; }
|
TfLiteStatus BenchmarkModel::PrepareInputData() { return kTfLiteOk; }
|
||||||
|
|
||||||
|
@ -21,6 +21,12 @@ namespace tflite {
|
|||||||
namespace benchmark {
|
namespace benchmark {
|
||||||
using BenchmarkParam = tflite::tools::ToolParam;
|
using BenchmarkParam = tflite::tools::ToolParam;
|
||||||
using BenchmarkParams = tflite::tools::ToolParams;
|
using BenchmarkParams = tflite::tools::ToolParams;
|
||||||
|
|
||||||
|
#define LOG_BENCHMARK_PARAM(params, type, name, prefix, suffix, verbose) \
|
||||||
|
do { \
|
||||||
|
TFLITE_MAY_LOG(INFO, verbose || params.HasValueSet<type>(name)) \
|
||||||
|
<< prefix << params.Get<type>(name) << suffix; \
|
||||||
|
} while (0)
|
||||||
} // namespace benchmark
|
} // namespace benchmark
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
#endif // TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_PARAMS_H_
|
#endif // TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_PARAMS_H_
|
||||||
|
@ -47,51 +47,18 @@ enum class ModelGraphType { FP32, INT8, STRING };
|
|||||||
|
|
||||||
BenchmarkParams CreateParams(int32_t num_runs, float min_secs, float max_secs,
|
BenchmarkParams CreateParams(int32_t num_runs, float min_secs, float max_secs,
|
||||||
ModelGraphType graph_type = ModelGraphType::FP32) {
|
ModelGraphType graph_type = ModelGraphType::FP32) {
|
||||||
BenchmarkParams params;
|
BenchmarkParams params = BenchmarkTfLiteModel::DefaultParams();
|
||||||
params.AddParam("num_runs", BenchmarkParam::Create<int32_t>(num_runs));
|
params.Set<int32_t>("num_runs", num_runs);
|
||||||
params.AddParam("min_secs", BenchmarkParam::Create<float>(min_secs));
|
params.Set<float>("min_secs", min_secs);
|
||||||
params.AddParam("max_secs", BenchmarkParam::Create<float>(max_secs));
|
params.Set<float>("max_secs", max_secs);
|
||||||
params.AddParam("run_delay", BenchmarkParam::Create<float>(-1.0f));
|
|
||||||
params.AddParam("num_threads", BenchmarkParam::Create<int32_t>(1));
|
|
||||||
params.AddParam("use_caching", BenchmarkParam::Create<bool>(false));
|
|
||||||
params.AddParam("benchmark_name", BenchmarkParam::Create<std::string>(""));
|
|
||||||
params.AddParam("output_prefix", BenchmarkParam::Create<std::string>(""));
|
|
||||||
params.AddParam("warmup_runs", BenchmarkParam::Create<int32_t>(1));
|
|
||||||
|
|
||||||
if (graph_type == ModelGraphType::INT8) {
|
if (graph_type == ModelGraphType::INT8) {
|
||||||
params.AddParam("graph",
|
params.Set<std::string>("graph", *g_int8_model_path);
|
||||||
BenchmarkParam::Create<std::string>(*g_int8_model_path));
|
|
||||||
} else if (graph_type == ModelGraphType::STRING) {
|
} else if (graph_type == ModelGraphType::STRING) {
|
||||||
params.AddParam("graph",
|
params.Set<std::string>("graph", *g_string_model_path);
|
||||||
BenchmarkParam::Create<std::string>(*g_string_model_path));
|
|
||||||
} else {
|
} else {
|
||||||
// by default, simply use the fp32 one.
|
// by default, simply use the fp32 one.
|
||||||
params.AddParam("graph",
|
params.Set<std::string>("graph", *g_fp32_model_path);
|
||||||
BenchmarkParam::Create<std::string>(*g_fp32_model_path));
|
|
||||||
}
|
|
||||||
|
|
||||||
params.AddParam("input_layer", BenchmarkParam::Create<std::string>(""));
|
|
||||||
params.AddParam("input_layer_shape", BenchmarkParam::Create<std::string>(""));
|
|
||||||
params.AddParam("input_layer_value_range",
|
|
||||||
BenchmarkParam::Create<std::string>(""));
|
|
||||||
params.AddParam("input_layer_value_files",
|
|
||||||
BenchmarkParam::Create<std::string>(""));
|
|
||||||
params.AddParam("allow_fp16", BenchmarkParam::Create<bool>(false));
|
|
||||||
params.AddParam("require_full_delegation",
|
|
||||||
BenchmarkParam::Create<bool>(false));
|
|
||||||
params.AddParam("warmup_min_secs", BenchmarkParam::Create<float>(0.5f));
|
|
||||||
params.AddParam("use_legacy_nnapi", BenchmarkParam::Create<bool>(false));
|
|
||||||
params.AddParam("enable_op_profiling", BenchmarkParam::Create<bool>(false));
|
|
||||||
params.AddParam("max_profiling_buffer_entries",
|
|
||||||
BenchmarkParam::Create<int32_t>(1024));
|
|
||||||
params.AddParam("profiling_output_csv_file",
|
|
||||||
BenchmarkParam::Create<std::string>(""));
|
|
||||||
params.AddParam("enable_platform_tracing",
|
|
||||||
BenchmarkParam::Create<bool>(false));
|
|
||||||
|
|
||||||
for (const auto& delegate_provider :
|
|
||||||
tools::GetRegisteredDelegateProviders()) {
|
|
||||||
params.Merge(delegate_provider->DefaultParams());
|
|
||||||
}
|
}
|
||||||
return params;
|
return params;
|
||||||
}
|
}
|
||||||
|
@ -347,34 +347,33 @@ std::vector<Flag> BenchmarkTfLiteModel::GetFlags() {
|
|||||||
void BenchmarkTfLiteModel::LogParams() {
|
void BenchmarkTfLiteModel::LogParams() {
|
||||||
BenchmarkModel::LogParams();
|
BenchmarkModel::LogParams();
|
||||||
TFLITE_LOG(INFO) << "Graph: [" << params_.Get<std::string>("graph") << "]";
|
TFLITE_LOG(INFO) << "Graph: [" << params_.Get<std::string>("graph") << "]";
|
||||||
TFLITE_LOG(INFO) << "Input layers: ["
|
|
||||||
<< params_.Get<std::string>("input_layer") << "]";
|
const bool verbose = params_.Get<bool>("verbose");
|
||||||
TFLITE_LOG(INFO) << "Input shapes: ["
|
|
||||||
<< params_.Get<std::string>("input_layer_shape") << "]";
|
#define LOG_PARAM(type, name, prefix, suffix) \
|
||||||
TFLITE_LOG(INFO) << "Input value ranges: ["
|
LOG_BENCHMARK_PARAM(params_, type, name, prefix, suffix, verbose)
|
||||||
<< params_.Get<std::string>("input_layer_value_range")
|
|
||||||
<< "]";
|
LOG_PARAM(std::string, "input_layer", "Input layers: [", "]");
|
||||||
TFLITE_LOG(INFO) << "Input layer values files: ["
|
LOG_PARAM(std::string, "input_layer_shape", "Input shapes: [", "]");
|
||||||
<< params_.Get<std::string>("input_layer_value_files")
|
LOG_PARAM(std::string, "input_layer_value_range", "Input value ranges: [",
|
||||||
<< "]";
|
"]");
|
||||||
|
LOG_PARAM(std::string, "input_layer_value_files", "Input value files: [",
|
||||||
|
"]");
|
||||||
|
|
||||||
#if defined(__ANDROID__)
|
#if defined(__ANDROID__)
|
||||||
TFLITE_LOG(INFO) << "Use legacy nnapi : ["
|
LOG_PARAM(bool, "use_legacy_nnapi", "Use legacy nnapi: [", "]");
|
||||||
<< params_.Get<bool>("use_legacy_nnapi") << "]";
|
|
||||||
#endif
|
#endif
|
||||||
TFLITE_LOG(INFO) << "Allow fp16 : [" << params_.Get<bool>("allow_fp16")
|
LOG_PARAM(bool, "allow_fp16", "Allow fp16: [", "]");
|
||||||
<< "]";
|
LOG_PARAM(bool, "require_full_delegation", "Require full delegation: [", "]");
|
||||||
TFLITE_LOG(INFO) << "Require full delegation : ["
|
LOG_PARAM(bool, "enable_op_profiling", "Enable op profiling: [", "]");
|
||||||
<< params_.Get<bool>("require_full_delegation") << "]";
|
LOG_PARAM(int32_t, "max_profiling_buffer_entries",
|
||||||
TFLITE_LOG(INFO) << "Enable op profiling: ["
|
"Max profiling buffer entries: [", "]");
|
||||||
<< params_.Get<bool>("enable_op_profiling") << "]";
|
LOG_PARAM(std::string, "profiling_output_csv_file",
|
||||||
TFLITE_LOG(INFO) << "Max profiling buffer entries: ["
|
"CSV File to export profiling data to: [", "]");
|
||||||
<< params_.Get<int32_t>("max_profiling_buffer_entries")
|
LOG_PARAM(bool, "enable_platform_tracing", "Enable platform-wide tracing: [",
|
||||||
<< "]";
|
"]");
|
||||||
TFLITE_LOG(INFO) << "CSV File to export profiling data to: ["
|
|
||||||
<< params_.Get<std::string>("profiling_output_csv_file")
|
#undef LOG_PARAM
|
||||||
<< "]";
|
|
||||||
TFLITE_LOG(INFO) << "Enable platform-wide tracing: ["
|
|
||||||
<< params_.Get<bool>("enable_platform_tracing") << "]";
|
|
||||||
|
|
||||||
for (const auto& delegate_provider :
|
for (const auto& delegate_provider :
|
||||||
tools::GetRegisteredDelegateProviders()) {
|
tools::GetRegisteredDelegateProviders()) {
|
||||||
|
@ -76,12 +76,13 @@ class LoggingWrapper {
|
|||||||
tflite::logging::LoggingWrapper::LogSeverity::severity) \
|
tflite::logging::LoggingWrapper::LogSeverity::severity) \
|
||||||
.Stream()
|
.Stream()
|
||||||
|
|
||||||
#define TFLITE_TOOLS_CHECK(condition) \
|
#define TFLITE_MAY_LOG(severity, should_log) \
|
||||||
tflite::logging::LoggingWrapper( \
|
tflite::logging::LoggingWrapper( \
|
||||||
tflite::logging::LoggingWrapper::LogSeverity::FATAL, \
|
tflite::logging::LoggingWrapper::LogSeverity::severity, (should_log)) \
|
||||||
(condition) ? false : true) \
|
|
||||||
.Stream()
|
.Stream()
|
||||||
|
|
||||||
|
#define TFLITE_TOOLS_CHECK(condition) TFLITE_MAY_LOG(FATAL, !(condition))
|
||||||
|
|
||||||
#define TFLITE_TOOLS_CHECK_EQ(a, b) TFLITE_TOOLS_CHECK((a) == (b))
|
#define TFLITE_TOOLS_CHECK_EQ(a, b) TFLITE_TOOLS_CHECK((a) == (b))
|
||||||
|
|
||||||
#endif // TENSORFLOW_LITE_TOOLS_LOGGING_H_
|
#endif // TENSORFLOW_LITE_TOOLS_LOGGING_H_
|
||||||
|
@ -52,12 +52,17 @@ class ToolParam {
|
|||||||
}
|
}
|
||||||
|
|
||||||
virtual ~ToolParam() {}
|
virtual ~ToolParam() {}
|
||||||
explicit ToolParam(ParamType type) : type_(type) {}
|
explicit ToolParam(ParamType type) : has_value_set_(false), type_(type) {}
|
||||||
|
|
||||||
|
bool HasValueSet() const { return has_value_set_; }
|
||||||
|
|
||||||
virtual void Set(const ToolParam&) {}
|
virtual void Set(const ToolParam&) {}
|
||||||
|
|
||||||
virtual std::unique_ptr<ToolParam> Clone() const = 0;
|
virtual std::unique_ptr<ToolParam> Clone() const = 0;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
bool has_value_set_;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static void AssertHasSameType(ParamType a, ParamType b);
|
static void AssertHasSameType(ParamType a, ParamType b);
|
||||||
|
|
||||||
@ -70,7 +75,10 @@ class TypedToolParam : public ToolParam {
|
|||||||
explicit TypedToolParam(const T& value)
|
explicit TypedToolParam(const T& value)
|
||||||
: ToolParam(GetValueType<T>()), value_(value) {}
|
: ToolParam(GetValueType<T>()), value_(value) {}
|
||||||
|
|
||||||
void Set(const T& value) { value_ = value; }
|
void Set(const T& value) {
|
||||||
|
value_ = value;
|
||||||
|
has_value_set_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
T Get() const { return value_; }
|
T Get() const { return value_; }
|
||||||
|
|
||||||
@ -111,10 +119,16 @@ class ToolParams {
|
|||||||
params_.at(name)->AsTyped<T>()->Set(value);
|
params_.at(name)->AsTyped<T>()->Set(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
bool HasValueSet(const std::string& name) const {
|
||||||
|
AssertParamExists(name);
|
||||||
|
return params_.at(name)->AsConstTyped<T>()->HasValueSet();
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T Get(const std::string& name) const {
|
T Get(const std::string& name) const {
|
||||||
AssertParamExists(name);
|
AssertParamExists(name);
|
||||||
return params_.at(name)->AsTyped<T>()->Get();
|
return params_.at(name)->AsConstTyped<T>()->Get();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set the value of all same parameters from 'other'.
|
// Set the value of all same parameters from 'other'.
|
||||||
|
@ -33,7 +33,11 @@ TEST(ToolParams, SetTest) {
|
|||||||
|
|
||||||
params.Set(others);
|
params.Set(others);
|
||||||
EXPECT_EQ(19, params.Get<int>("some-int1"));
|
EXPECT_EQ(19, params.Get<int>("some-int1"));
|
||||||
|
EXPECT_TRUE(params.HasValueSet<int>("some-int1"));
|
||||||
|
|
||||||
EXPECT_EQ(17, params.Get<int>("some-int2"));
|
EXPECT_EQ(17, params.Get<int>("some-int2"));
|
||||||
|
EXPECT_FALSE(params.HasValueSet<int>("some-int2"));
|
||||||
|
|
||||||
EXPECT_FALSE(params.HasParam("some-bool"));
|
EXPECT_FALSE(params.HasParam("some-bool"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -33,7 +33,7 @@ from tensorflow.python.util.tf_export import tf_export
|
|||||||
# This value changes every day with an automatic CL. It can be modified in code
|
# This value changes every day with an automatic CL. It can be modified in code
|
||||||
# via `forward_compatibility_horizon()` or with the environment variable
|
# via `forward_compatibility_horizon()` or with the environment variable
|
||||||
# TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
|
# TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
|
||||||
_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 7, 1)
|
_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 7, 2)
|
||||||
_FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
|
_FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
|
||||||
_FORWARD_COMPATIBILITY_DATE_NUMBER = None
|
_FORWARD_COMPATIBILITY_DATE_NUMBER = None
|
||||||
|
|
||||||
|
@ -70,6 +70,36 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "tpu_executor",
|
||||||
|
srcs = [
|
||||||
|
"tpu_platform_registration.cc",
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"tpu_executor.h",
|
||||||
|
"tpu_platform.h",
|
||||||
|
"tpu_stream.h",
|
||||||
|
"tpu_timer.h",
|
||||||
|
],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
":device_memory_base_helper",
|
||||||
|
":status_helper",
|
||||||
|
":tpu_executor_base",
|
||||||
|
":tpu_executor_c_api_hdrs",
|
||||||
|
":tpu_executor_interface",
|
||||||
|
":tpu_platform_interface",
|
||||||
|
":tpu_stream_interface",
|
||||||
|
"//tensorflow/core/platform:types",
|
||||||
|
"//tensorflow/core/tpu:tpu_api",
|
||||||
|
"//tensorflow/stream_executor",
|
||||||
|
"//tensorflow/stream_executor/lib",
|
||||||
|
"//tensorflow/stream_executor/platform",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
|
],
|
||||||
|
alwayslink = True,
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "tpu_executor_base",
|
name = "tpu_executor_base",
|
||||||
srcs = [
|
srcs = [
|
||||||
@ -82,6 +112,7 @@ cc_library(
|
|||||||
"tpu_stream.h",
|
"tpu_stream.h",
|
||||||
"tpu_timer.h",
|
"tpu_timer.h",
|
||||||
],
|
],
|
||||||
|
visibility = ["//tensorflow/core/tpu:__pkg__"],
|
||||||
deps = [
|
deps = [
|
||||||
":device_memory_base_helper",
|
":device_memory_base_helper",
|
||||||
":status_helper",
|
":status_helper",
|
||||||
@ -98,17 +129,6 @@ cc_library(
|
|||||||
"//tensorflow/stream_executor/lib",
|
"//tensorflow/stream_executor/lib",
|
||||||
"@com_google_absl//absl/container:flat_hash_map",
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
],
|
],
|
||||||
alwayslink = True,
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "tpu_executor",
|
|
||||||
srcs = ["tpu_platform_registration.cc"],
|
|
||||||
deps = [
|
|
||||||
":tpu_executor_base",
|
|
||||||
"//tensorflow/stream_executor/platform",
|
|
||||||
],
|
|
||||||
alwayslink = True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
@ -139,7 +159,6 @@ cc_library(
|
|||||||
srcs = ["tpu_transfer_manager_registration.cc"],
|
srcs = ["tpu_transfer_manager_registration.cc"],
|
||||||
deps = [
|
deps = [
|
||||||
":tpu_executor",
|
":tpu_executor",
|
||||||
":tpu_executor_base",
|
|
||||||
":tpu_transfer_manager_base",
|
":tpu_transfer_manager_base",
|
||||||
"//tensorflow/compiler/xla/service:transfer_manager",
|
"//tensorflow/compiler/xla/service:transfer_manager",
|
||||||
],
|
],
|
||||||
@ -153,7 +172,7 @@ cc_library(
|
|||||||
":c_api_conversions",
|
":c_api_conversions",
|
||||||
":proto_helper",
|
":proto_helper",
|
||||||
":status_helper",
|
":status_helper",
|
||||||
":tpu_executor_base",
|
":tpu_executor",
|
||||||
":tpu_executor_c_api_hdrs",
|
":tpu_executor_c_api_hdrs",
|
||||||
"//tensorflow/compiler/xla:literal",
|
"//tensorflow/compiler/xla:literal",
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
@ -171,7 +190,6 @@ cc_library(
|
|||||||
hdrs = ["tpu_computation_placer.h"],
|
hdrs = ["tpu_computation_placer.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":tpu_executor",
|
":tpu_executor",
|
||||||
":tpu_executor_base",
|
|
||||||
":tpu_executor_c_api_hdrs",
|
":tpu_executor_c_api_hdrs",
|
||||||
"//tensorflow/compiler/xla:statusor",
|
"//tensorflow/compiler/xla:statusor",
|
||||||
"//tensorflow/compiler/xla/service:computation_placer",
|
"//tensorflow/compiler/xla/service:computation_placer",
|
||||||
|
@ -18,11 +18,8 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||||
#include "tensorflow/core/tpu/tpu_api.h"
|
#include "tensorflow/core/tpu/tpu_api.h"
|
||||||
#include "tensorflow/stream_executor/device_memory.h"
|
|
||||||
#include "tensorflow/stream_executor/lib/status.h"
|
|
||||||
#include "tensorflow/stream_executor/tpu/device_memory_base_helper.h"
|
#include "tensorflow/stream_executor/tpu/device_memory_base_helper.h"
|
||||||
#include "tensorflow/stream_executor/tpu/status_helper.h"
|
#include "tensorflow/stream_executor/tpu/status_helper.h"
|
||||||
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
|
|
||||||
#include "tensorflow/stream_executor/tpu/tpu_stream.h"
|
#include "tensorflow/stream_executor/tpu/tpu_stream.h"
|
||||||
#include "tensorflow/stream_executor/tpu/tpu_timer.h"
|
#include "tensorflow/stream_executor/tpu/tpu_timer.h"
|
||||||
|
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/stream_executor/device_memory.h"
|
#include "tensorflow/stream_executor/device_memory.h"
|
||||||
#include "tensorflow/stream_executor/device_options.h"
|
#include "tensorflow/stream_executor/device_options.h"
|
||||||
#include "tensorflow/stream_executor/event.h"
|
#include "tensorflow/stream_executor/event.h"
|
||||||
|
#include "tensorflow/stream_executor/lib/status.h"
|
||||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||||
#include "tensorflow/stream_executor/stream.h"
|
#include "tensorflow/stream_executor/stream.h"
|
||||||
#include "tensorflow/stream_executor/stream_executor.h"
|
#include "tensorflow/stream_executor/stream_executor.h"
|
||||||
|
@ -18,10 +18,8 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
#include "tensorflow/c/tf_status_helper.h"
|
#include "tensorflow/c/tf_status_helper.h"
|
||||||
#include "tensorflow/core/tpu/tpu_api.h"
|
#include "tensorflow/core/tpu/tpu_api.h"
|
||||||
#include "tensorflow/stream_executor/platform.h"
|
|
||||||
#include "tensorflow/stream_executor/tpu/status_helper.h"
|
#include "tensorflow/stream_executor/tpu/status_helper.h"
|
||||||
#include "tensorflow/stream_executor/tpu/tpu_executor.h"
|
#include "tensorflow/stream_executor/tpu/tpu_executor.h"
|
||||||
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
@ -32,7 +30,7 @@ using Status = ::stream_executor::port::Status;
|
|||||||
template <typename T>
|
template <typename T>
|
||||||
using StatusOr = ::stream_executor::port::StatusOr<T>;
|
using StatusOr = ::stream_executor::port::StatusOr<T>;
|
||||||
|
|
||||||
TpuPlatform::TpuPlatform() {
|
TpuPlatform::TpuPlatform() : name_("TPU") {
|
||||||
platform_ = tpu::ExecutorApiFn()->TpuPlatform_NewFn();
|
platform_ = tpu::ExecutorApiFn()->TpuPlatform_NewFn();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -109,10 +107,7 @@ TpuPlatform::GetUncachedExecutor(
|
|||||||
return TpuPlatform::kId;
|
return TpuPlatform::kId;
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::string& TpuPlatform::Name() const {
|
const std::string& TpuPlatform::Name() const { return name_; }
|
||||||
static std::string* name = new std::string("TPU");
|
|
||||||
return *name;
|
|
||||||
}
|
|
||||||
|
|
||||||
int64 TpuPlatform::TpuMemoryLimit() {
|
int64 TpuPlatform::TpuMemoryLimit() {
|
||||||
return tpu::ExecutorApiFn()->TpuPlatform_TpuMemoryLimitFn(platform_);
|
return tpu::ExecutorApiFn()->TpuPlatform_TpuMemoryLimitFn(platform_);
|
||||||
|
@ -121,7 +121,7 @@ class TpuPlatform : public ::tensorflow::tpu::TpuPlatformInterface {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
SE_Platform* platform_;
|
SE_Platform* platform_;
|
||||||
|
std::string name_;
|
||||||
stream_executor::ExecutorCache executor_cache_;
|
stream_executor::ExecutorCache executor_cache_;
|
||||||
StreamMap stream_map_;
|
StreamMap stream_map_;
|
||||||
EventMap event_map_;
|
EventMap event_map_;
|
||||||
|
@ -60,17 +60,17 @@ sudo ldconfig
|
|||||||
# Install Horovod.
|
# Install Horovod.
|
||||||
cd ..
|
cd ..
|
||||||
HOROVOD_WITH_TENSORFLOW=1
|
HOROVOD_WITH_TENSORFLOW=1
|
||||||
pip3.7 install horovod[tensorflow]
|
pip3.7 install horovod[tensorflow] --user
|
||||||
|
|
||||||
# Install tests.
|
# Install tests.
|
||||||
git clone https://github.com/DEKHTIARJonathan/TF_HVD_Stability_Test.git
|
git clone https://github.com/DEKHTIARJonathan/TF_HVD_Stability_Test.git
|
||||||
|
|
||||||
# Install pytest.
|
# Install pytest.
|
||||||
pip3.7 install -U pytest
|
pip3.7 install -U pytest --user
|
||||||
|
|
||||||
# Install requirements.
|
# Install requirements.
|
||||||
cd TF_HVD_Stability_Test
|
cd TF_HVD_Stability_Test
|
||||||
pip3.7 install -r requirements.txt
|
pip3.7 install -r requirements.txt --user
|
||||||
|
|
||||||
# Run the tests.
|
# Run the tests.
|
||||||
python3.7 -m pytest
|
python3.7 -m pytest
|
||||||
|
@ -710,8 +710,8 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Check out LLVM and MLIR from llvm-project.
|
# Check out LLVM and MLIR from llvm-project.
|
||||||
LLVM_COMMIT = "0f9d623b63e87b4ba30c30fd884ecc333eb32b4a"
|
LLVM_COMMIT = "d6343e607ac8fa71fa6d99f9c86369ae9e66e671"
|
||||||
LLVM_SHA256 = "58dee49dd9e79eea829fd6ca2d57cd1bf927445771bb296985061bd7644d676d"
|
LLVM_SHA256 = "0824d59e80c99e64cafe6e8051c9861e534dee60f056dcd528d5fe00ebeb542f"
|
||||||
LLVM_URLS = [
|
LLVM_URLS = [
|
||||||
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
|
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
|
||||||
"https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
|
"https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
|
||||||
|
23
third_party/mlir/BUILD
vendored
23
third_party/mlir/BUILD
vendored
@ -1478,6 +1478,7 @@ cc_library(
|
|||||||
":IR",
|
":IR",
|
||||||
":Pass",
|
":Pass",
|
||||||
":SCFDialect",
|
":SCFDialect",
|
||||||
|
":SCFToSPIRV",
|
||||||
":SPIRVDialect",
|
":SPIRVDialect",
|
||||||
":SPIRVLowering",
|
":SPIRVLowering",
|
||||||
":StandardToSPIRVTransforms",
|
":StandardToSPIRVTransforms",
|
||||||
@ -2233,6 +2234,28 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "SCFToSPIRV",
|
||||||
|
srcs = ["lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp"],
|
||||||
|
hdrs = ["include/mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"],
|
||||||
|
includes = ["include"],
|
||||||
|
deps = [
|
||||||
|
":Affine",
|
||||||
|
":AffineToStandard",
|
||||||
|
":ConversionPassIncGen",
|
||||||
|
":IR",
|
||||||
|
":Pass",
|
||||||
|
":SCFDialect",
|
||||||
|
":SPIRVDialect",
|
||||||
|
":SPIRVLowering",
|
||||||
|
":StandardOps",
|
||||||
|
":Support",
|
||||||
|
":TransformUtils",
|
||||||
|
":Transforms",
|
||||||
|
"@llvm-project//llvm:Support",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "SCFToStandard",
|
name = "SCFToStandard",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
4
third_party/mlir/tblgen.bzl
vendored
4
third_party/mlir/tblgen.bzl
vendored
@ -23,7 +23,7 @@ def gentbl(name, tblgen, td_file, tbl_outs, td_srcs = [], td_includes = [], td_r
|
|||||||
|
|
||||||
td_includes_cmd = [
|
td_includes_cmd = [
|
||||||
"-I external/llvm-project/mlir/include -I external/org_tensorflow",
|
"-I external/llvm-project/mlir/include -I external/org_tensorflow",
|
||||||
"-I $(GENDIR)/external/llvm-project/mlir/include",
|
"-I $(GENDIR)/external/llvm-project/mlir/include -I $(GENDIR)/external/org_tensorflow",
|
||||||
]
|
]
|
||||||
for td_include in td_includes:
|
for td_include in td_includes:
|
||||||
td_includes_cmd += [
|
td_includes_cmd += [
|
||||||
@ -32,7 +32,7 @@ def gentbl(name, tblgen, td_file, tbl_outs, td_srcs = [], td_includes = [], td_r
|
|||||||
]
|
]
|
||||||
for td_include in td_relative_includes:
|
for td_include in td_relative_includes:
|
||||||
td_includes_cmd += [
|
td_includes_cmd += [
|
||||||
"-I%s/%s" % (native.package_name(), td_include),
|
"-I%s/%s -Iexternal/org_tensorflow/%s/%s" % (native.package_name(), td_include, native.package_name(), td_include),
|
||||||
"-I$(GENDIR)/%s/%s" % (native.package_name(), td_include),
|
"-I$(GENDIR)/%s/%s" % (native.package_name(), td_include),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user