Merged commit includes the following changes:
319411158 by A. Unique TensorFlower <gardener@tensorflow.org>:

    Integrate LLVM at https://github.com/llvm/llvm-project/commit/d6343e607ac8

--
319410296 by A. Unique TensorFlower <gardener@tensorflow.org>:

    [XLA] Implement an extra prefetch limit for while uses.

    Outstanding-prefetch limits can prevent prefetches from being scheduled for
    the duration of while loops. Since using alternate memory for while loops
    can be more beneficial, allow specifying an additional prefetch limit when
    the use is a while HLO.

--
319406145 by A. Unique TensorFlower <gardener@tensorflow.org>:

    [XLA:CPU] Teach dot_op_emitter how to tile and vectorize linalg matmuls,
    and turn them on by default.

    This is on par with the existing emitter, sometimes better, and unlocks
    more potential. The strategy classes are duplicated right now, but I expect
    them to graduate to MLIR core soon. I'm planning to remove the custom LLVM
    IR emitters if this turns out to be stable enough.

--
319402982 by A. Unique TensorFlower <gardener@tensorflow.org>:

    PR #40327: [ROCm] Enabling optimized FusedBatchNormInferenceMetaKernel for half

    Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/40327

    This PR enables the optimized FusedBatchNormInferenceMetaKernel for half on ROCm.

    Copybara import of the project:

    --
    5f658e2bc1b20794239658bffe0d7bf9cb89c81f by Eugene Kuznetsov <eugene.kuznetsov@amd.com>:

        Enabling optimized FusedBatchNormInferenceMetaKernel for half

--
319393611 by A. Unique TensorFlower <gardener@tensorflow.org>:

    Integrate LLVM at https://github.com/llvm/llvm-project/commit/68498ce8af37

--
319374663 by A. Unique TensorFlower <gardener@tensorflow.org>:

    compat: Update forward compatibility horizon to 2020-07-02

--
319374662 by A. Unique TensorFlower <gardener@tensorflow.org>:

    Update GraphDef version to 450.

--
319371388 by A. Unique TensorFlower <gardener@tensorflow.org>:

    Update framework_build_test targets.

--
319363982 by A. Unique TensorFlower <gardener@tensorflow.org>:

    Resolve the permission-denied error on Python 3.7 pip install.

--
319361498 by A. Unique TensorFlower <gardener@tensorflow.org>:

    Add an option to log only the parameters whose values were parsed from
    command-line flags in the benchmark tool.

--
319356677 by A. Unique TensorFlower <gardener@tensorflow.org>:

    Fix a bug in ReadNonConstantTensor: assigning a new value to the reference
    doesn't update the reference, so use a pointer instead.

--
319350974 by A. Unique TensorFlower <gardener@tensorflow.org>:

    Fix the header-inclusion path issue with TensorFlowLiteC.

--
319342653 by A. Unique TensorFlower <gardener@tensorflow.org>:

    Fix the relationship between the tpu_executor and tpu_executor_base build
    targets.

--
319342578 by A. Unique TensorFlower <gardener@tensorflow.org>:

    Internal change

--
319340968 by A. Unique TensorFlower <gardener@tensorflow.org>:

    Internal change

PiperOrigin-RevId: 319411158
This commit is contained in:
parent 62b6c316d2
commit e25d3e084b
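
As a hedged, self-contained sketch of the while-use prefetch rule described in change 319410296 above (a paraphrase of the memory_space_assignment hunks further down, not the literal XLA code; all names here are illustrative):

    #include <cstdint>

    // A while-loop use may exceed the regular outstanding-prefetch cap by
    // `extra_limit`, so prefetches are not starved for the entire loop body.
    // A negative cap means "unlimited", matching max_outstanding_prefetches = -1.
    bool ViolatesPrefetchLimit(int64_t num_outstanding, int64_t max_outstanding,
                               bool use_is_while, int64_t extra_limit) {
      if (max_outstanding < 0) return false;
      const int64_t limit = max_outstanding + (use_is_while ? extra_limit : 0);
      return num_outstanding >= limit;
    }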
@@ -471,6 +471,7 @@ cc_library(
        ":cpu_runtime",
        ":ir_emission_utils",
        ":mlir_emitter",
        ":mlir_matmul_codegen_strategy",
        ":target_machine_features",
        ":tiled_dot_emitter",
        ":vector_support_library",
@@ -1102,12 +1103,33 @@ cc_library(
        "@llvm-project//llvm:Core",
        "@llvm-project//llvm:IPO",
        "@llvm-project//llvm:Linker",
        "@llvm-project//mlir:CFGTransforms",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:LLVMTransforms",
        "@llvm-project//mlir:LinalgToLLVM",
        "@llvm-project//mlir:LinalgTransforms",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:TargetLLVMIR",
        "@llvm-project//mlir:Transforms",
        "@llvm-project//mlir:VectorToLLVM",
    ],
)

cc_library(
    name = "mlir_matmul_codegen_strategy",
    srcs = ["mlir_matmul_codegen_strategy.cc"],
    hdrs = ["mlir_matmul_codegen_strategy.h"],
    deps = [
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:Affine",
        "@llvm-project//mlir:Analysis",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:LinalgOps",
        "@llvm-project//mlir:LinalgTransforms",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:SCFDialect",
        "@llvm-project//mlir:StandardOps",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:Transforms",
        "@llvm-project//mlir:VectorOps",
        "@llvm-project//mlir:VectorToSCF",
    ],
)
@@ -25,7 +25,6 @@ const char* const kXlaOptimizeForSizeCpuOption = "xla_cpu_optimize_for_size";
const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor";
const char* const kXlaForceEnableExperimentalLlvmIrGemm =
    "xla_force_enable_experimental_llvm_ir_gemm";
const char* const kXlaUseLinalgForDot = "xla_use_linalg_for_dot";
const char* const kLlvmIrGemmTileSize = "xla_llvm_ir_gemm_tile_size";

} // namespace
@@ -64,12 +63,6 @@ bool ForceEnableExperimentalLlvmIrGemm(const HloModuleConfig& config) {
  return extra_options_map.count(kXlaForceEnableExperimentalLlvmIrGemm) > 0;
}

bool UseLinalgForDot(const HloModuleConfig& config) {
  const auto& extra_options_map =
      config.debug_options().xla_backend_extra_options();
  return extra_options_map.count(kXlaUseLinalgForDot) > 0;
}

static absl::string_view RemoveSuffix(absl::string_view str,
                                      absl::string_view suffix) {
  CHECK_GE(str.size(), suffix.size());
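
These option keys are looked up in the module's xla_backend_extra_options map, as UseLinalgForDot above shows. As a hedged illustration (not part of this commit; assumes XLA's DebugOptions proto API), a client could set one of them roughly like so:

    // Sketch: request the experimental LLVM IR GeMM for a module's config.
    xla::DebugOptions debug_options = xla::GetDebugOptionsFromFlags();
    (*debug_options.mutable_xla_backend_extra_options())
        ["xla_force_enable_experimental_llvm_ir_gemm"] = "";
    xla::HloModuleConfig config;
    config.set_debug_options(debug_options);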
@@ -31,10 +31,12 @@ limitations under the License.
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/cpu/mlir_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
#include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
@@ -202,6 +204,20 @@ class DotOpEmitter {
        .value_or(kDefaultTileSize);
  }

  std::array<int64_t, 3> GetMlirGemmTileSize() const {
    // Tile by 4 x registers x register size. This was picked by running
    // small matmuls on Haswell and Skylake. There's a lot of room for
    // improvement here.
    constexpr int64_t kDefaultTileSizeForM = 4;
    int64_t elements_per_register =
        target_machine_features_.vector_register_num_elements(
            *b_->GetInsertBlock()->getParent(),
            dot_info_.result_shape.element_type());
    int64_t num_registers = target_machine_features_.vector_register_count(
        *b_->GetInsertBlock()->getParent());
    return {{kDefaultTileSizeForM, num_registers, elements_per_register}};
  }

  DotInfo dot_info_;
  string dot_hlo_name_;
  const llvm_ir::IrArray& target_array_;
@@ -250,6 +266,7 @@ Status DotOpEmitter::EmitLinalgMatmul() {
      absl::StrCat("linalgMatMul_", dot_info_.result_shape.ToString(true), "_",
                   dot_info_.lhs_shape.ToString(true), "_",
                   dot_info_.rhs_shape.ToString(true));

  return EmitMlirFuncAndCall(
      mlir_context_, b_, dot_info_.result_shape, operand_shapes, target_ptr,
      operand_ptrs, name, [&](mlir::OpBuilder* builder, mlir::FuncOp function) {
@@ -259,6 +276,27 @@ Status DotOpEmitter::EmitLinalgMatmul() {
        mlir::edsc::intrinsics::linalg_matmul(mlir::TypeRange{},
                                              mlir::ValueRange{b, c, a});
        mlir::edsc::intrinsics::std_ret();

        mlir::linalg::LinalgTilingOptions tilingOptions;
        tilingOptions = tilingOptions.setTileSizes(GetMlirGemmTileSize());
        int64 alignment =
            target_machine_features_.minimum_alignment_for_allocation(
                ShapeUtil::ByteSizeOf(dot_info_.result_shape));
        mlir_strategy::MatmulCodegenStrategy strategy;
        strategy.tile<mlir::linalg::MatmulOp>(tilingOptions)
            .promote<mlir::linalg::MatmulOp>(
                mlir::linalg::LinalgPromotionOptions()
                    .setAlignment(alignment)
                    .setUseFullTileBuffersByDefault(true)
                    .setUseAlloca(true))
            .vectorize<mlir::linalg::MatmulOp>()
            .setVectorTransformsOptions(
                mlir::vector::VectorTransformsOptions()
                    .setVectorTransformsOptions(
                        mlir::vector::VectorContractLowering::OuterProduct))
            .setVectorTransferToSCFOptions(
                mlir::VectorTransferToSCFOptions().setUnroll(true));
        strategy.transform(function);
      });
}

@@ -948,7 +986,8 @@ DotImplementationStrategy GetDotImplementationStrategy(

  if (IsAlignedGemm(dot_info, target_machine_features)) {
    if (CanEmitTiledLlvmIrGemm(config, dot_info, target_machine_features)) {
      return options::UseLinalgForDot(config)
      return primitive_util::IsFloatingPointType(
                 dot_info.result_shape.element_type())
                 ? DotImplementationStrategy::kLinalgMatmul
                 : DotImplementationStrategy::kTiledLlvmIrGemm;
    }
@@ -17,14 +17,14 @@ limitations under the License.

#include "llvm/Linker/Linker.h"
#include "llvm/Transforms/IPO/Internalize.h"
#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"  // from @llvm-project
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"  // from @llvm-project
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"  // from @llvm-project
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"  // from @llvm-project
#include "mlir/Dialect/Linalg/Passes.h"  // from @llvm-project
#include "mlir/IR/Module.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Target/LLVMIR.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"

namespace xla {
@@ -35,9 +35,9 @@ namespace {
std::unique_ptr<llvm::Module> MakeLLVMModule(mlir::OwningModuleRef module) {
  mlir::PassManager manager(module->getContext());
  manager.addPass(mlir::createConvertLinalgToLoopsPass());
  manager.addPass(mlir::createConvertLinalgToLLVMPass());
  manager.addPass(mlir::createLowerAffinePass());
  manager.addPass(mlir::createLowerToCFGPass());
  manager.addPass(mlir::createConvertVectorToLLVMPass());
  manager.addPass(mlir::createLowerToLLVMPass());
  CHECK(succeeded(manager.run(*module)));
  return mlir::translateModuleToLLVMIR(*module);
}
@@ -0,0 +1,269 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/SliceAnalysis.h"  // from @llvm-project
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"  // from @llvm-project
#include "mlir/Dialect/Affine/IR/AffineOps.h"  // from @llvm-project
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"  // from @llvm-project
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"  // from @llvm-project
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"  // from @llvm-project
#include "mlir/Dialect/Linalg/Utils/Utils.h"  // from @llvm-project
#include "mlir/Dialect/SCF/SCF.h"  // from @llvm-project
#include "mlir/Dialect/SCF/Utils.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"  // from @llvm-project
#include "mlir/Dialect/Vector/VectorOps.h"  // from @llvm-project
#include "mlir/Dialect/Vector/VectorTransforms.h"  // from @llvm-project
#include "mlir/IR/AffineExpr.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/Dominance.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/StandardTypes.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/LoopUtils.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project

// TODO(kramerb): Remove this once strategy is in mlir core.

using namespace mlir;          // NOLINT
using namespace mlir::linalg;  // NOLINT

#define DEBUG_TYPE "matmul-codegen-strategy"

namespace xla {
namespace cpu {
namespace mlir_strategy {

//===----------------------------------------------------------------------===//
// TODO: Cleanup and upstream these to go into core. Please ignore for now !
//===----------------------------------------------------------------------===//
static void hoistRedundantCopies(FuncOp func) {
  bool changed = true;
  while (changed) {
    changed = false;
    func.walk([&](linalg::FillOp op) {
      auto loop = op.getParentOfType<scf::ForOp>();
      if (!loop) return;

      for (auto operand : op.getOperands())
        if (!loop.isDefinedOutsideOfLoop(operand)) return;

      // Hoist fill before.
      op.getOperation()->moveBefore(loop);
      changed = true;
    });

    func.walk([&](linalg::CopyOp op) {
      auto loop = op.getParentOfType<scf::ForOp>();
      if (!loop) return;

      for (auto operand : op.getOperands())
        if (!loop.isDefinedOutsideOfLoop(operand)) return;

      Value sourceView = op.getInput(0);
      while (auto subViewOp = sourceView.getDefiningOp<SubViewOp>())
        sourceView = subViewOp.getViewSource();

      // Source traces back to a block argument.
      if (sourceView.isa<BlockArgument>()) {
        op.getOperation()->moveBefore(loop);
      } else {
        assert(sourceView.getDefiningOp<ViewOp>() ||
               sourceView.getDefiningOp<AllocOp>() ||
               sourceView.getDefiningOp<AllocaOp>());
        op.getOperation()->moveAfter(loop);
      }
      changed = true;
    });
  }
}

/// Substitute scf.for = %lb to %ub step %step by an AffineExpr expressing:
///   `%lb + %step * new_dim` where
/// 1. the AffineExpr for %lb is either an AffineConstantExpr or an
///    AffineDimExpr depending on whether the value is constant or not.
/// 2. the AffineExpr for %step is either an AffineConstantExpr or an
///    AffineSymbolExpr depending on whether the value is constant or not.
///
static void substitute(scf::ForOp forOp, SmallVectorImpl<AffineExpr> &exprs,
                       SmallVectorImpl<Value> &dims,
                       SmallVectorImpl<Value> &symbols) {
  MLIRContext *ctx = forOp.getContext();
  auto lbConstant = forOp.lowerBound().getDefiningOp<ConstantIndexOp>();
  AffineExpr lb = lbConstant ? getAffineConstantExpr(lbConstant.getValue(), ctx)
                             : getAffineDimExpr(dims.size(), ctx);

  auto stepConstant = forOp.step().getDefiningOp<ConstantIndexOp>();
  AffineExpr step = stepConstant
                        ? getAffineConstantExpr(stepConstant.getValue(), ctx)
                        : getAffineSymbolExpr(symbols.size(), ctx);

  if (!lbConstant) dims.push_back(forOp.lowerBound());
  if (!stepConstant) symbols.push_back(forOp.step());
  exprs.push_back(lb + step * getAffineDimExpr(dims.size(), ctx));

  auto ubConstant = forOp.upperBound().getDefiningOp<ConstantIndexOp>();
  AffineExpr ub = ubConstant ? getAffineConstantExpr(ubConstant.getValue(), ctx)
                             : getAffineDimExpr(dims.size(), ctx);
  if (!ubConstant) dims.push_back(forOp.upperBound());
  exprs.push_back(ub);

  dims.push_back(forOp.getInductionVar());
}

/// Traverse the dim operands of `minOp` and substitute an AffineExpr for each
/// one: for-loop induction variables and nested affine.min results are
/// expanded recursively; any other value becomes a new dim expression.
static void substitute(AffineMinOp minOp, SmallVectorImpl<AffineExpr> &exprs,
                       SmallVectorImpl<Value> &dims,
                       SmallVectorImpl<Value> &symbols) {
  MLIRContext *ctx = minOp.getContext();
  for (Value v : minOp.getDimOperands()) {
    if (auto forOp = scf::getForInductionVarOwner(v)) {
      substitute(forOp, exprs, dims, symbols);
      continue;
    }
    if (auto parentMinOp = v.getDefiningOp<AffineMinOp>()) {
      substitute(parentMinOp, exprs, dims, symbols);
      continue;
    }
    exprs.push_back(getAffineDimExpr(dims.size(), ctx));
    dims.push_back(v);
  }
}

/// Perform folding of chains of AffineMinOp.
struct AffineMinCanonicalizationPattern : public OpRewritePattern<AffineMinOp> {
  using OpRewritePattern<AffineMinOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineMinOp minOp,
                                PatternRewriter &rewriter) const override;
};

LogicalResult AffineMinCanonicalizationPattern::matchAndRewrite(
    AffineMinOp minOp, PatternRewriter &rewriter) const {
  LLVM_DEBUG(llvm::dbgs() << "\nCanonicalize AffineMin: "
                          << *minOp.getOperation() << "\n");

  int64_t min = std::numeric_limits<int64_t>::max();
  for (auto e : minOp.map().getResults())
    if (auto cstExpr = e.dyn_cast<AffineConstantExpr>())
      min = std::min(min, cstExpr.getValue());
  if (min == std::numeric_limits<int64_t>::max()) return failure();

  SmallVector<AffineExpr, 4> exprs;
  SmallVector<Value, 4> dims, symbols;
  substitute(minOp, exprs, dims, symbols);

  SmallVector<Value, 4> operands = dims;
  operands.append(symbols.begin(), symbols.end());

  MLIRContext *ctx = minOp.getContext();
  auto map = AffineMap::get(dims.size(), symbols.size(), exprs, ctx);
  LLVM_DEBUG(llvm::dbgs() << "Substitution map: " << map << "\n");

  SmallVector<AffineExpr, 4> modExprs;
  for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx)
    modExprs.push_back(getAffineDimExpr(idx, ctx) % min);
  map = AffineMap::get(map.getNumResults(), 0, modExprs, ctx).compose(map);
  canonicalizeMapAndOperands(&map, &operands);
  map = simplifyAffineMap(map);

  LLVM_DEBUG(llvm::dbgs() << "Post mod: " << map << "\n";
             llvm::interleaveComma(operands, llvm::dbgs()));

  if (!llvm::all_of(map.getResults(), [](AffineExpr e) {
        if (auto cst = e.dyn_cast<AffineConstantExpr>())
          return cst.getValue() == 0;
        return false;
      }))
    return failure();

  rewriter.replaceOpWithNewOp<ConstantIndexOp>(minOp, min);
  return success();
}
//===----------------------------------------------------------------------===//
// END TODO
//===----------------------------------------------------------------------===//

void MatmulCodegenStrategy::transform(FuncOp func) const {
  MLIRContext *context = func.getContext();
  // Emplace patterns one at a time while also maintaining a simple chained
  // state transition.
  unsigned stepCount = 0;
  SmallVector<OwningRewritePatternList, 4> stage1Patterns;
  auto zeroState = Identifier::get(std::to_string(stepCount), context);
  auto currentState = zeroState;
  for (auto &t : transformation_sequence) {
    auto nextState = Identifier::get(std::to_string(++stepCount), context);
    auto marker = (currentState == zeroState)
                      ? linalg::LinalgMarker({}, nextState)
                      : linalg::LinalgMarker(currentState, nextState);
    stage1Patterns.emplace_back(t->buildRewritePatterns(context, marker));
    currentState = nextState;
  }

  OwningRewritePatternList stage2Patterns =
      linalg::getLinalgTilingCanonicalizationPatterns(context);
  stage2Patterns.insert<AffineMinCanonicalizationPattern>(context);

  auto stage3Transforms = [](Operation *op) {
    // Some of these may be too aggressive as a stage 3 that is applied on each
    // stage 1 application and may have to be split out to post staged patterns
    // application (in which case they could just be passes, TBD).
    PassManager pm(op->getContext());
    pm.addPass(createLoopInvariantCodeMotionPass());
    if (failed(pm.run(op->getParentOfType<ModuleOp>())))
      llvm_unreachable("Unexpected failure in cleanup pass pipeline.");
    promoteSingleIterationLoops(cast<FuncOp>(op));
    hoistViewAllocOps(cast<FuncOp>(op));
    hoistRedundantVectorTransfers(cast<FuncOp>(op));
    hoistRedundantCopies(cast<FuncOp>(op));
    return success();
  };
  linalg::applyStagedPatterns(func, stage1Patterns, stage2Patterns,
                              stage3Transforms);

  //===--------------------------------------------------------------------===//
  // Post staged patterns transforms
  //===--------------------------------------------------------------------===//
  // Programmatic controlled lowering of vector.contract only.
  OwningRewritePatternList vectorContractLoweringPatterns;
  vectorContractLoweringPatterns
      .insert<ContractionOpToOuterProductOpLowering,
              ContractionOpToMatmulOpLowering, ContractionOpLowering>(
          vector_transforms_options, context);
  applyPatternsAndFoldGreedily(func, vectorContractLoweringPatterns);

  // Programmatic controlled lowering of vector.transfer only.
  OwningRewritePatternList vectorToLoopsPatterns;
  populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context,
                                        vector_to_scf_options);
  applyPatternsAndFoldGreedily(func, vectorToLoopsPatterns);
}

} // namespace mlir_strategy
} // namespace cpu
} // namespace xla
@@ -0,0 +1,188 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MLIR_EDGE_BENCHMARKS_STRATEGIES_MATMULCODEGENSTRATEGIES_H_
#define MLIR_EDGE_BENCHMARKS_STRATEGIES_MATMULCODEGENSTRATEGIES_H_

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSwitch.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"  // from @llvm-project
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"  // from @llvm-project
#include "mlir/Dialect/Vector/VectorOps.h"  // from @llvm-project
#include "mlir/Dialect/Vector/VectorTransforms.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project

// TODO(kramerb): Remove this once strategy is in mlir core.

namespace xla {
namespace cpu {
namespace mlir_strategy {

/// Abstract Transformation class applied in a sequence that also handles state
/// through markers.
struct Transformation {
  virtual ~Transformation() = default;
  virtual mlir::OwningRewritePatternList buildRewritePatterns(
      mlir::MLIRContext *context, mlir::linalg::LinalgMarker m) = 0;
  mlir::linalg::LinalgMarker marker;
};

/// Tiling transformation enqueues a particular stage-1 pattern for
/// `Tile<LinalgOpType>` with the appropriate `options`.
// TODO: variadic LinalgOpTypes.
template <typename LinalgOpType>
struct Tile : public Transformation {
  explicit Tile(mlir::linalg::LinalgTilingOptions options) : options(options) {}

  mlir::OwningRewritePatternList buildRewritePatterns(
      mlir::MLIRContext *context, mlir::linalg::LinalgMarker m) override {
    mlir::OwningRewritePatternList tiling_patterns;
    tiling_patterns.insert<mlir::linalg::LinalgTilingPattern<LinalgOpType>>(
        context, options, m);
    return tiling_patterns;
  }

 private:
  mlir::linalg::LinalgTilingOptions options;
};

/// Promotion transformation enqueues a particular stage-1 pattern for
/// `Promote<LinalgOpType>` with the appropriate `options`.
// TODO: variadic LinalgOpTypes.
template <typename LinalgOpType>
struct Promote : public Transformation {
  explicit Promote(mlir::linalg::LinalgPromotionOptions options)
      : options(options) {}

  mlir::OwningRewritePatternList buildRewritePatterns(
      mlir::MLIRContext *context, mlir::linalg::LinalgMarker m) override {
    mlir::OwningRewritePatternList promotion_patterns;
    promotion_patterns
        .insert<mlir::linalg::LinalgPromotionPattern<LinalgOpType>>(context,
                                                                    options, m);
    return promotion_patterns;
  }

 private:
  mlir::linalg::LinalgPromotionOptions options;
};

/// Vectorization transformation enqueues a particular stage-1 pattern for
/// `LinalgVectorizationPattern<LinalgOpType>` as well as copy to vector
/// transfer rewrite forwarding patterns.
// TODO: variadic LinalgOpTypes.
template <typename LinalgOpType>
struct Vectorize : public Transformation {
  mlir::OwningRewritePatternList buildRewritePatterns(
      mlir::MLIRContext *context, mlir::linalg::LinalgMarker m) override {
    mlir::OwningRewritePatternList vectorization_patterns;
    // FillOp may interfere with forwarding patterns atm, so we bump up the
    // priority of LinalgCopyVTRForwardingPattern /
    // LinalgCopyVTWForwardingPattern.
    vectorization_patterns
        .insert<mlir::linalg::LinalgVectorizationPattern<LinalgOpType>>(context,
                                                                        m);
    vectorization_patterns.insert<mlir::linalg::LinalgCopyVTRForwardingPattern,
                                  mlir::linalg::LinalgCopyVTWForwardingPattern>(
        context,
        /*benefit=*/2);
    return vectorization_patterns;
  }
};

/// Matmul-specific strategy object controls how a linalg.matmul is
/// progressively lowered.
/// The strategy uses a 3-level staged patterns strategy which allows ordering
/// transformations by using the Linalg `applyStagedPatterns` function, where:
/// 1. The first stage consists of the successive `tile`, `promote` and
///    `vectorize` patterns, applied sequentially.
/// 2. The second stage consists of common local canonicalization patterns
///    that are applied eagerly after each stage-1 pattern.
/// 3. The third stage consists of more global transformations, also applied
///    eagerly after all stage-2 patterns. Such more global transformations
///    currently amount to the cleanup transforms run by `transform` (e.g.
///    loop-invariant code motion and redundant copy/transfer hoisting).
struct MatmulCodegenStrategy {
  /// Append a pattern to add a level of tiling for `LinalgOpType` with tiling
  /// `options`.
  template <typename LinalgOpType>
  MatmulCodegenStrategy &tile(mlir::linalg::LinalgTilingOptions options) {
    transformation_sequence.emplace_back(new Tile<LinalgOpType>(options));
    return *this;
  }
  /// Conditionally append a pattern to add a level of tiling for `LinalgOpType`
  /// with tiling `options`.
  template <typename LinalgOpType>
  MatmulCodegenStrategy &tileIf(bool b,
                                mlir::linalg::LinalgTilingOptions options) {
    return b ? tile<LinalgOpType>(options) : *this;
  }
  /// Append a pattern to add a level of promotion for `LinalgOpType` with
  /// promotion `options`.
  template <typename LinalgOpType>
  MatmulCodegenStrategy &promote(mlir::linalg::LinalgPromotionOptions options) {
    transformation_sequence.emplace_back(new Promote<LinalgOpType>(options));
    return *this;
  }
  /// Conditionally append a pattern to add a level of promotion for
  /// `LinalgOpType` with promotion `options`.
  template <typename LinalgOpType>
  MatmulCodegenStrategy &promoteIf(
      bool b, mlir::linalg::LinalgPromotionOptions options) {
    return b ? promote<LinalgOpType>(options) : *this;
  }
  /// Append a pattern to rewrite `LinalgOpType` as a vector operation.
  template <typename LinalgOpType>
  MatmulCodegenStrategy &vectorize() {
    transformation_sequence.emplace_back(new Vectorize<LinalgOpType>());
    return *this;
  }
  /// Conditionally append a pattern to rewrite `LinalgOpType` as a vector
  /// operation.
  template <typename LinalgOpType>
  MatmulCodegenStrategy &vectorizeIf(bool b) {
    return b ? vectorize<LinalgOpType>() : *this;
  }
  /// Configure the post staged-patterns late vector transformations.
  MatmulCodegenStrategy &setVectorTransformsOptions(
      mlir::vector::VectorTransformsOptions options) {
    vector_transforms_options = options;
    return *this;
  }
  /// Configure the post staged-patterns late vector.transfer to scf conversion.
  MatmulCodegenStrategy &setVectorTransferToSCFOptions(
      mlir::VectorTransferToSCFOptions options) {
    vector_to_scf_options = options;
    return *this;
  }

  /// Apply the transformation patterns in sequence with cleanup transformations
  /// interleaved.
  void transform(mlir::FuncOp func) const;

 private:
  mlir::LogicalResult postPatternTransforms(mlir::Operation *func) const;

  mlir::vector::VectorTransformsOptions vector_transforms_options;
  mlir::VectorTransferToSCFOptions vector_to_scf_options;
  llvm::SmallVector<std::unique_ptr<Transformation>, 4> transformation_sequence;
};

} // namespace mlir_strategy
} // namespace cpu
} // namespace xla

#endif  // MLIR_EDGE_BENCHMARKS_STRATEGIES_MATMULCODEGENSTRATEGIES_H_
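
For reference, a hedged usage sketch of the fluent API above, mirroring the call site added to dot_op_emitter.cc earlier in this diff; the tile sizes and alignment here are made-up placeholders:

    // Tile a linalg.matmul, promote the tiles into aligned alloca buffers,
    // then vectorize with an outer-product lowering for vector.contract.
    xla::cpu::mlir_strategy::MatmulCodegenStrategy strategy;
    strategy
        .tile<mlir::linalg::MatmulOp>(
            mlir::linalg::LinalgTilingOptions().setTileSizes({4, 16, 8}))
        .promote<mlir::linalg::MatmulOp>(mlir::linalg::LinalgPromotionOptions()
                                             .setAlignment(64)
                                             .setUseFullTileBuffersByDefault(true)
                                             .setUseAlloca(true))
        .vectorize<mlir::linalg::MatmulOp>()
        .setVectorTransformsOptions(
            mlir::vector::VectorTransformsOptions().setVectorTransformsOptions(
                mlir::vector::VectorContractLowering::OuterProduct))
        .setVectorTransferToSCFOptions(
            mlir::VectorTransferToSCFOptions().setUnroll(true));
    strategy.transform(func);  // func is the mlir::FuncOp wrapping the matmul.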
@@ -52,6 +52,12 @@ class TargetMachineFeatures {
  virtual int vector_register_num_elements(const llvm::Function& function,
                                           PrimitiveType type) const = 0;

  // Return the number of vector registers. We need to pass in
  // "function" since llvm functions can contain annotations for specializing
  // them to specific micro-architectures (though currently XLA does not use
  // this functionality).
  virtual int vector_register_count(const llvm::Function& function) const = 0;

  // Returns the minimum alignment for a buffer of size size_bytes.
  virtual int64 minimum_alignment_for_allocation(int64 size_bytes) const = 0;

@@ -84,6 +90,12 @@ class LLVMTargetMachineFeatures : public TargetMachineFeatures {
           (primitive_util::BitWidth(type) / 8);
  }

  int vector_register_count(const llvm::Function& function) const override {
    llvm::TargetTransformInfo* tti = GetTargetTransformInfoFor(function);
    return static_cast<int>(tti->getNumberOfRegisters(
        tti->getRegisterClassForType(/*Vector=*/true)));
  }

  int64 minimum_alignment_for_allocation(int64 size_bytes) const override;

 private:
@@ -44,6 +44,10 @@ class TargetMachineFeaturesWithFakeAlignmentLogic
    LOG(FATAL) << "Unexpected call to " << __func__;
  }

  int vector_register_count(const llvm::Function& function) const override {
    LOG(FATAL) << "Unexpected call to " << __func__;
  }

  int64 minimum_alignment_for_allocation(int64 size_bytes) const override {
    return fake_alignment_logic_(size_bytes);
  }
@@ -1642,7 +1642,8 @@ void AlternateMemoryBestFitHeap::AddAsyncCopy(
}

bool AlternateMemoryBestFitHeap::ViolatesMaximumOutstandingAsyncCopies(
    int64 start_time, int64 end_time, bool is_prefetch) const {
    int64 start_time, int64 end_time, bool is_prefetch,
    int64 extra_async_copy_limit) const {
  if (options_.max_outstanding_prefetches < 0 && is_prefetch) {
    return false;
  }
@@ -1655,12 +1656,14 @@ bool AlternateMemoryBestFitHeap::ViolatesMaximumOutstandingAsyncCopies(
    int64 num_prefetches =
        prefetch_interval_tree_.ChunksOverlappingInTime(start_time, end_time)
            .size();
    return num_prefetches >= options_.max_outstanding_prefetches;
    return num_prefetches >=
           options_.max_outstanding_prefetches + extra_async_copy_limit;
  } else {
    int64 num_evictions =
        eviction_interval_tree_.ChunksOverlappingInTime(start_time, end_time)
            .size();
    return num_evictions >= options_.max_outstanding_evictions;
    return num_evictions >=
           options_.max_outstanding_evictions + extra_async_copy_limit;
  }
}

@@ -1911,6 +1914,11 @@ bool AlternateMemoryBestFitHeap::Prefetch(
  // outstanding async copy limit or async copy ordering, set
  // prefetch_failed_due_to_async_copy_.
  prefetch_failed_due_to_async_copy_ = false;
  // While uses might be allowed to have additional outstanding prefetches.
  int64 extra_async_copy_limit =
      request.use->hlo_use.instruction->opcode() == HloOpcode::kWhile
          ? options_.while_use_extra_outstanding_prefetch_limit
          : 0;
  while (!options_.prefetch_interval_picker->Done()) {
    alternate_mem_interval.start = options_.prefetch_interval_picker->Next();
    CHECK_LT(alternate_mem_interval.start, request.latest_prefetch_time);
@@ -1924,9 +1932,9 @@ bool AlternateMemoryBestFitHeap::Prefetch(
      prefetch_failed_due_to_async_copy_ = true;
      continue;
    }
    if (ViolatesMaximumOutstandingAsyncCopies(alternate_mem_interval.start,
                                              request.latest_prefetch_time,
                                              /*is_prefetch=*/true)) {
    if (ViolatesMaximumOutstandingAsyncCopies(
            alternate_mem_interval.start, request.latest_prefetch_time,
            /*is_prefetch=*/true, extra_async_copy_limit)) {
      VLOG(4) << "This would violate the outstanding async copy limit.";
      prefetch_failed_due_to_async_copy_ = true;
      continue;
@@ -379,6 +379,10 @@ class MemorySpaceAssignment {
  int64 max_outstanding_prefetches = -1;
  int64 max_outstanding_evictions = -1;

  // Extra outstanding prefetch limit for while uses (in addition to
  // max_outstanding_prefetches).
  int64 while_use_extra_outstanding_prefetch_limit = 0;

  // Specifies the maximum number of retries that will be performed for each
  // value in case prefetching failed due to running out of asynchronous
  // copies or asynchronous copy ordering.
@@ -1019,9 +1023,12 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
  void AddToChunkMap(const HloValue* buffer, Chunk chunk) override {}

  // Returns true if the addition of an asynchronous copy in the given time
  // interval would violate the maximum number of asynchronous copies.
  bool ViolatesMaximumOutstandingAsyncCopies(int64 start_time, int64 end_time,
                                             bool is_prefetch) const;
  // interval would violate the maximum number of asynchronous copies. An extra
  // async copy limit can be provided to increase the limit of asynchronous
  // copies for this instance.
  bool ViolatesMaximumOutstandingAsyncCopies(
      int64 start_time, int64 end_time, bool is_prefetch,
      int64 extra_async_copy_limit = 0) const;

  // Return true if the asynchronous copy would violate the pipelining order.
  bool ViolatesAsyncCopyOrdering(int64 start_time, int64 end_time) const;
@@ -18,6 +18,10 @@ limitations under the License.
#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#endif
#if TENSORFLOW_USE_ROCM
#include "rocm/include/hip/hip_fp16.h"
typedef __half2 half2;
#endif
#include "tensorflow/core/kernels/fused_batch_norm_op.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

@@ -174,6 +178,11 @@ template <TensorFormat tensor_format, bool add_side_input,
struct FusedBatchNormInferenceKernel<Eigen::half, float, tensor_format,
                                     add_side_input, activation_mode,
                                     /*is_generic_kernel=*/false> {
#if TENSORFLOW_USE_ROCM
  using IT = __half;
#else
  using IT = Eigen::half;
#endif
  using T = Eigen::half;
  using U = float;

@@ -185,15 +194,19 @@ struct FusedBatchNormInferenceKernel<Eigen::half, float, tensor_format,
                                    /*is_generic_kernel=*/true>;

  __device__ static void run(int32 count, int32 channels_size,
                             int32 inner_dim_size, const T* __restrict__ in,
                             int32 inner_dim_size, const T* __restrict__ _in,
                             const U* __restrict__ scale,
                             const U* __restrict__ offset,
                             const U* __restrict__ mean,
                             const U* __restrict__ var,
                             const T* __restrict__ side_input, float epsilon,
                             T* __restrict__ out) {
                             const T* __restrict__ _side_input, float epsilon,
                             T* __restrict__ _out) {
// Old GPUs do not have (or have very slow) fp16 arithmetic.
#if __CUDA_ARCH__ >= 610
#if (__CUDA_ARCH__ >= 610) || TENSORFLOW_USE_ROCM
    const IT* in = reinterpret_cast<const IT*>(_in);
    const IT* side_input = reinterpret_cast<const IT*>(_side_input);
    IT* out = reinterpret_cast<IT*>(_out);

    int32 index = blockIdx.x * blockDim.x + threadIdx.x;
    const int32 total_device_threads = gridDim.x * blockDim.x;

@@ -274,8 +287,8 @@ struct FusedBatchNormInferenceKernel<Eigen::half, float, tensor_format,
    }

#else
    GenericKernel::run(count, channels_size, inner_dim_size, in, scale, offset,
                       mean, var, side_input, epsilon, out);
    GenericKernel::run(count, channels_size, inner_dim_size, _in, scale, offset,
                       mean, var, _side_input, epsilon, _out);
#endif  // __CUDA_ARCH__ >= 610
  }
};
@@ -287,10 +300,16 @@ __global__ void FusedBatchNormInferenceMetaKernel(
    const U* scale, const U* offset, const U* mean, const U* var,
    const T* side_input, float epsilon, T* out) {
  // We prefer to run non-generic specialization, for the given types T and U.
  // TODO(b/135435976): Temporary disable non-generic kernel implementation.
  FusedBatchNormInferenceKernel<
      T, U, tensor_format, add_side_input, activation_mode,
      /*is_generic_kernel=*/true>::run(count, channels_size, inner_dim_size, in,
  FusedBatchNormInferenceKernel<T, U, tensor_format, add_side_input,
                                activation_mode,
#if TENSORFLOW_USE_ROCM
                                false
#else
                                // TODO(b/135435976): Temporary disable
                                // non-generic kernel implementation.
                                /*is_generic_kernel=*/true
#endif
                                >::run(count, channels_size, inner_dim_size, in,
                                       scale, offset, mean, var, side_input,
                                       epsilon, out);
}
@@ -312,7 +331,11 @@ struct FusedBatchNormInferenceFunctor<GPUDevice, T, U> {
    if (count == 0) return;

    bool launched = false;
#if TENSORFLOW_USE_ROCM
    constexpr int32 kThreadInBlock = 1024;
#else
    constexpr int32 kThreadInBlock = 512;
#endif

#define LAUNCH(DATA_FORMAT, ADD_SIDE_INPUT, ACTIVATION, CHANNEL_SIZE, \
               INNER_DIM_SIZE)                                        \
@@ -108,7 +108,7 @@ limitations under the License.

#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
#define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
#define TF_GRAPH_DEF_VERSION 449  // Updated: 2020/7/1
#define TF_GRAPH_DEF_VERSION 450  // Updated: 2020/7/2

// Checkpoint compatibility versions (the versions field in SavedSliceMeta).
//
@@ -37,14 +37,14 @@ absl::Status ObjectReader::ReadNonConstantTensor(
  }

  if (tensor_to_value->find(tensor_idx) == tensor_to_value->end()) {
    TfLiteTensor& tflite_tensor = context->tensors[tensor_idx];
    if (tflite::IsConstantTensor(&tflite_tensor)) {
    TfLiteTensor* tflite_tensor = &context->tensors[tensor_idx];
    if (tflite::IsConstantTensor(tflite_tensor)) {
      return absl::InvalidArgumentError(absl::StrCat(
          "ReadNonConstantTensor: value is a constant tensor: ", tensor_idx));
    }

    if ((tflite_tensor.type == kTfLiteInt8 ||
         tflite_tensor.type == kTfLiteUInt8) &&
    if ((tflite_tensor->type == kTfLiteInt8 ||
         tflite_tensor->type == kTfLiteUInt8) &&
        quant_conversion_map) {
      // Quantized case
      if (quant_conversion_map->find(tensor_idx) ==
@@ -70,9 +70,9 @@ absl::Status ObjectReader::ReadNonConstantTensor(
        value->quant_params.emplace();
        // tflite_tensor from the outer scope is invalidated due to calling
        // CreateNewTensorWithDifferentType
        tflite_tensor = context->tensors[tensor_idx];
        tflite_tensor = &context->tensors[tensor_idx];
        RETURN_IF_ERROR(
            PopulateQuantParams(tflite_tensor, &value->quant_params.value()));
            PopulateQuantParams(*tflite_tensor, &value->quant_params.value()));
        (*tensor_to_value)[fp_tensor_index] = value;
      }
      // We do not use the original tensor index as reference for the GPU
@@ -82,7 +82,7 @@ absl::Status ObjectReader::ReadNonConstantTensor(
      // Floating-point case.
      Value* value = graph->NewValue();
      RETURN_IF_ERROR(
          ConvertTfLiteTensorToTensorRef(tflite_tensor, &value->tensor));
          ConvertTfLiteTensorToTensorRef(*tflite_tensor, &value->tensor));
      value->tensor.ref = tensor_idx;
      (*tensor_to_value)[tensor_idx] = value;
    }
@@ -21,7 +21,19 @@ sh_binary(
# When the static framework is built with bazel, all header files are moved
# to the "Headers" directory with no header path prefixes. This auxiliary rule
|
||||
# is used for stripping the path prefix to the "common.h" file included by the
|
||||
# "xnnpack_delegate.h" header.
|
||||
# "c_api.h" header.
|
||||
genrule(
|
||||
name = "strip_c_api_include_hdr",
|
||||
srcs = ["//tensorflow/lite/c:c_api.h"],
|
||||
outs = ["c_api.h"],
|
||||
cmd = """
|
||||
sed 's|#include ".*common.h"|#include "common.h"|'\
|
||||
"$(location //tensorflow/lite/c:c_api.h)"\
|
||||
> "$@"
|
||||
""",
|
||||
)
|
||||
|
||||
# Similar rule as above, but for the "xnnpack_delegate.h" header.
|
||||
genrule(
|
||||
name = "strip_xnnpack_include_hdr",
|
||||
srcs = ["//tensorflow/lite/delegates/xnnpack:xnnpack_delegate.h"],
|
||||
@ -37,8 +49,8 @@ genrule(
|
||||
tflite_ios_static_framework(
|
||||
name = "TensorFlowLiteC_framework",
|
||||
hdrs = [
|
||||
":c_api.h",
|
||||
":xnnpack_delegate.h",
|
||||
"//tensorflow/lite/c:c_api.h",
|
||||
"//tensorflow/lite/c:common.h",
|
||||
],
|
||||
bundle_name = "TensorFlowLiteC",
|
||||
@ -136,12 +148,15 @@ cc_library(
|
||||
# Used for building TensorFlowLiteC framework.
|
||||
build_test(
|
||||
name = "framework_build_test",
|
||||
# build_test targets are not meant to be run with sanitizers.
|
||||
tags = [
|
||||
"noasan", # b/147230742
|
||||
"nomsan", # b/145205324
|
||||
"notsan", # b/145205324
|
||||
"noasan",
|
||||
"nomsan",
|
||||
"notsan",
|
||||
],
|
||||
targets = [
|
||||
":TensorFlowLiteCCoreML_framework",
|
||||
":TensorFlowLiteCMetal_framework",
|
||||
":TensorFlowLiteC_framework",
|
||||
":TensorFlowLiteSelectTfOps_framework",
|
||||
],
|
||||
|
@@ -39,6 +39,7 @@ BenchmarkParams BenchmarkModel::DefaultParams() {
  params.AddParam("output_prefix", BenchmarkParam::Create<std::string>(""));
  params.AddParam("warmup_runs", BenchmarkParam::Create<int32_t>(1));
  params.AddParam("warmup_min_secs", BenchmarkParam::Create<float>(0.5f));
  params.AddParam("verbose", BenchmarkParam::Create<bool>(false));
  return params;
}

@@ -100,31 +101,31 @@ std::vector<Flag> BenchmarkModel::GetFlags() {
          "warmup_min_secs", &params_,
          "minimum number of seconds to rerun for, potentially making the "
          "actual number of warm-up runs to be greater than warmup_runs"),
      CreateFlag<bool>("verbose", &params_,
                       "Whether to log parameters whose values are not set. "
                       "By default, only log those parameters that are set by "
"parsing their values from the commandline flag.."),
|
||||
  };
}

#define LOG_PARAM(type, name, prefix, suffix) \
  LOG_BENCHMARK_PARAM(params_, type, name, prefix, suffix, verbose)
void BenchmarkModel::LogParams() {
  TFLITE_LOG(INFO) << "Min num runs: [" << params_.Get<int32_t>("num_runs")
                   << "]";
  TFLITE_LOG(INFO) << "Min runs duration (seconds): ["
                   << params_.Get<float>("min_secs") << "]";
  TFLITE_LOG(INFO) << "Max runs duration (seconds): ["
                   << params_.Get<float>("max_secs") << "]";
  TFLITE_LOG(INFO) << "Inter-run delay (seconds): ["
                   << params_.Get<float>("run_delay") << "]";
  TFLITE_LOG(INFO) << "Num threads: [" << params_.Get<int32_t>("num_threads")
                   << "]";
  TFLITE_LOG(INFO) << "Use caching: [" << params_.Get<bool>("use_caching")
                   << "]";
  TFLITE_LOG(INFO) << "Benchmark name: ["
                   << params_.Get<std::string>("benchmark_name") << "]";
  TFLITE_LOG(INFO) << "Output prefix: ["
                   << params_.Get<std::string>("output_prefix") << "]";
  TFLITE_LOG(INFO) << "Min warmup runs: ["
                   << params_.Get<int32_t>("warmup_runs") << "]";
  TFLITE_LOG(INFO) << "Min warmup runs duration (seconds): ["
                   << params_.Get<float>("warmup_min_secs") << "]";
  const bool verbose = params_.Get<bool>("verbose");
LOG_PARAM(int32_t, "num_runs", "Min num runs: [", "]");
|
||||
LOG_PARAM(int32_t, "num_runs", "Min num runs: [", "]");
|
||||
LOG_PARAM(float, "min_secs", "Min runs duration (seconds): [", "]");
|
||||
LOG_PARAM(float, "max_secs", "Max runs duration (seconds): [", "]");
|
||||
LOG_PARAM(float, "run_delay", "Inter-run delay (seconds): [", "]");
|
||||
LOG_PARAM(int32_t, "num_threads", "Num threads: [", "]");
|
||||
LOG_PARAM(bool, "use_caching", "Use caching: [", "]");
|
||||
LOG_PARAM(std::string, "benchmark_name", "Benchmark name: [", "]");
|
||||
LOG_PARAM(std::string, "output_prefix", "Output prefix: [", "]");
|
||||
LOG_PARAM(int32_t, "warmup_runs", "Min warmup runs: [", "]");
|
||||
LOG_PARAM(float, "warmup_min_secs", "Min warmup runs duration (seconds): [",
|
||||
"]");
|
||||
}
|
||||
#undef LOG_PARAM
|
||||
|
||||
TfLiteStatus BenchmarkModel::PrepareInputData() { return kTfLiteOk; }
|
||||
|
||||
|
@@ -21,6 +21,12 @@ namespace tflite {
namespace benchmark {
using BenchmarkParam = tflite::tools::ToolParam;
using BenchmarkParams = tflite::tools::ToolParams;

#define LOG_BENCHMARK_PARAM(params, type, name, prefix, suffix, verbose) \
  do {                                                                   \
    TFLITE_MAY_LOG(INFO, verbose || params.HasValueSet<type>(name))      \
        << prefix << params.Get<type>(name) << suffix;                   \
  } while (0)
} // namespace benchmark
} // namespace tflite
#endif  // TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_PARAMS_H_
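
A hedged illustration of what LOG_BENCHMARK_PARAM buys (hypothetical helper, not in this commit): unless verbose is true, the line is emitted only when the parameter value was explicitly Set(), i.e. parsed from a command-line flag, courtesy of ToolParams::HasValueSet<T>():

    // Logs "Num threads: [N]" only if --num_threads was given or verbose is set.
    void MaybeLogNumThreads(const tflite::benchmark::BenchmarkParams& params,
                            bool verbose) {
      LOG_BENCHMARK_PARAM(params, int32_t, "num_threads", "Num threads: [", "]");
    }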
@@ -47,51 +47,18 @@ enum class ModelGraphType { FP32, INT8, STRING };

BenchmarkParams CreateParams(int32_t num_runs, float min_secs, float max_secs,
                             ModelGraphType graph_type = ModelGraphType::FP32) {
  BenchmarkParams params;
  params.AddParam("num_runs", BenchmarkParam::Create<int32_t>(num_runs));
  params.AddParam("min_secs", BenchmarkParam::Create<float>(min_secs));
  params.AddParam("max_secs", BenchmarkParam::Create<float>(max_secs));
  params.AddParam("run_delay", BenchmarkParam::Create<float>(-1.0f));
  params.AddParam("num_threads", BenchmarkParam::Create<int32_t>(1));
  params.AddParam("use_caching", BenchmarkParam::Create<bool>(false));
  params.AddParam("benchmark_name", BenchmarkParam::Create<std::string>(""));
  params.AddParam("output_prefix", BenchmarkParam::Create<std::string>(""));
  params.AddParam("warmup_runs", BenchmarkParam::Create<int32_t>(1));
  BenchmarkParams params = BenchmarkTfLiteModel::DefaultParams();
  params.Set<int32_t>("num_runs", num_runs);
  params.Set<float>("min_secs", min_secs);
  params.Set<float>("max_secs", max_secs);

  if (graph_type == ModelGraphType::INT8) {
    params.AddParam("graph",
                    BenchmarkParam::Create<std::string>(*g_int8_model_path));
    params.Set<std::string>("graph", *g_int8_model_path);
  } else if (graph_type == ModelGraphType::STRING) {
    params.AddParam("graph",
                    BenchmarkParam::Create<std::string>(*g_string_model_path));
    params.Set<std::string>("graph", *g_string_model_path);
  } else {
    // by default, simply use the fp32 one.
    params.AddParam("graph",
                    BenchmarkParam::Create<std::string>(*g_fp32_model_path));
  }

  params.AddParam("input_layer", BenchmarkParam::Create<std::string>(""));
  params.AddParam("input_layer_shape", BenchmarkParam::Create<std::string>(""));
  params.AddParam("input_layer_value_range",
                  BenchmarkParam::Create<std::string>(""));
  params.AddParam("input_layer_value_files",
                  BenchmarkParam::Create<std::string>(""));
  params.AddParam("allow_fp16", BenchmarkParam::Create<bool>(false));
  params.AddParam("require_full_delegation",
                  BenchmarkParam::Create<bool>(false));
  params.AddParam("warmup_min_secs", BenchmarkParam::Create<float>(0.5f));
  params.AddParam("use_legacy_nnapi", BenchmarkParam::Create<bool>(false));
  params.AddParam("enable_op_profiling", BenchmarkParam::Create<bool>(false));
  params.AddParam("max_profiling_buffer_entries",
                  BenchmarkParam::Create<int32_t>(1024));
  params.AddParam("profiling_output_csv_file",
                  BenchmarkParam::Create<std::string>(""));
  params.AddParam("enable_platform_tracing",
                  BenchmarkParam::Create<bool>(false));

  for (const auto& delegate_provider :
       tools::GetRegisteredDelegateProviders()) {
    params.Merge(delegate_provider->DefaultParams());
    params.Set<std::string>("graph", *g_fp32_model_path);
  }
  return params;
}
@@ -347,34 +347,33 @@ std::vector<Flag> BenchmarkTfLiteModel::GetFlags() {
void BenchmarkTfLiteModel::LogParams() {
  BenchmarkModel::LogParams();
  TFLITE_LOG(INFO) << "Graph: [" << params_.Get<std::string>("graph") << "]";
  TFLITE_LOG(INFO) << "Input layers: ["
                   << params_.Get<std::string>("input_layer") << "]";
  TFLITE_LOG(INFO) << "Input shapes: ["
                   << params_.Get<std::string>("input_layer_shape") << "]";
  TFLITE_LOG(INFO) << "Input value ranges: ["
                   << params_.Get<std::string>("input_layer_value_range")
                   << "]";
  TFLITE_LOG(INFO) << "Input layer values files: ["
                   << params_.Get<std::string>("input_layer_value_files")
                   << "]";

  const bool verbose = params_.Get<bool>("verbose");

#define LOG_PARAM(type, name, prefix, suffix) \
  LOG_BENCHMARK_PARAM(params_, type, name, prefix, suffix, verbose)

  LOG_PARAM(std::string, "input_layer", "Input layers: [", "]");
  LOG_PARAM(std::string, "input_layer_shape", "Input shapes: [", "]");
  LOG_PARAM(std::string, "input_layer_value_range", "Input value ranges: [",
            "]");
  LOG_PARAM(std::string, "input_layer_value_files", "Input value files: [",
            "]");

#if defined(__ANDROID__)
  TFLITE_LOG(INFO) << "Use legacy nnapi : ["
                   << params_.Get<bool>("use_legacy_nnapi") << "]";
  LOG_PARAM(bool, "use_legacy_nnapi", "Use legacy nnapi: [", "]");
#endif
  TFLITE_LOG(INFO) << "Allow fp16 : [" << params_.Get<bool>("allow_fp16")
                   << "]";
  TFLITE_LOG(INFO) << "Require full delegation : ["
                   << params_.Get<bool>("require_full_delegation") << "]";
  TFLITE_LOG(INFO) << "Enable op profiling: ["
                   << params_.Get<bool>("enable_op_profiling") << "]";
  TFLITE_LOG(INFO) << "Max profiling buffer entries: ["
                   << params_.Get<int32_t>("max_profiling_buffer_entries")
                   << "]";
  TFLITE_LOG(INFO) << "CSV File to export profiling data to: ["
                   << params_.Get<std::string>("profiling_output_csv_file")
                   << "]";
  TFLITE_LOG(INFO) << "Enable platform-wide tracing: ["
                   << params_.Get<bool>("enable_platform_tracing") << "]";
  LOG_PARAM(bool, "allow_fp16", "Allow fp16: [", "]");
  LOG_PARAM(bool, "require_full_delegation", "Require full delegation: [", "]");
  LOG_PARAM(bool, "enable_op_profiling", "Enable op profiling: [", "]");
  LOG_PARAM(int32_t, "max_profiling_buffer_entries",
            "Max profiling buffer entries: [", "]");
  LOG_PARAM(std::string, "profiling_output_csv_file",
            "CSV File to export profiling data to: [", "]");
  LOG_PARAM(bool, "enable_platform_tracing", "Enable platform-wide tracing: [",
            "]");

#undef LOG_PARAM

  for (const auto& delegate_provider :
       tools::GetRegisteredDelegateProviders()) {
@@ -76,12 +76,13 @@ class LoggingWrapper {
      tflite::logging::LoggingWrapper::LogSeverity::severity)                \
      .Stream()

#define TFLITE_TOOLS_CHECK(condition)                      \
  tflite::logging::LoggingWrapper(                         \
      tflite::logging::LoggingWrapper::LogSeverity::FATAL, \
      (condition) ? false : true)                          \
#define TFLITE_MAY_LOG(severity, should_log)                                 \
  tflite::logging::LoggingWrapper(                                           \
      tflite::logging::LoggingWrapper::LogSeverity::severity, (should_log))  \
      .Stream()

#define TFLITE_TOOLS_CHECK(condition) TFLITE_MAY_LOG(FATAL, !(condition))

#define TFLITE_TOOLS_CHECK_EQ(a, b) TFLITE_TOOLS_CHECK((a) == (b))

#endif  // TENSORFLOW_LITE_TOOLS_LOGGING_H_
@ -52,12 +52,17 @@ class ToolParam {
|
||||
}
|
||||
|
||||
virtual ~ToolParam() {}
|
||||
explicit ToolParam(ParamType type) : type_(type) {}
|
||||
explicit ToolParam(ParamType type) : has_value_set_(false), type_(type) {}
|
||||
|
||||
bool HasValueSet() const { return has_value_set_; }
|
||||
|
||||
virtual void Set(const ToolParam&) {}
|
||||
|
||||
virtual std::unique_ptr<ToolParam> Clone() const = 0;
|
||||
|
||||
protected:
|
||||
bool has_value_set_;
|
||||
|
||||
private:
|
||||
static void AssertHasSameType(ParamType a, ParamType b);
|
||||
|
||||
@@ -70,7 +75,10 @@ class TypedToolParam : public ToolParam {
   explicit TypedToolParam(const T& value)
       : ToolParam(GetValueType<T>()), value_(value) {}
 
-  void Set(const T& value) { value_ = value; }
+  void Set(const T& value) {
+    value_ = value;
+    has_value_set_ = true;
+  }
 
   T Get() const { return value_; }
@@ -111,10 +119,16 @@ class ToolParams {
     params_.at(name)->AsTyped<T>()->Set(value);
   }
 
+  template <typename T>
+  bool HasValueSet(const std::string& name) const {
+    AssertParamExists(name);
+    return params_.at(name)->AsConstTyped<T>()->HasValueSet();
+  }
+
   template <typename T>
   T Get(const std::string& name) const {
     AssertParamExists(name);
-    return params_.at(name)->AsTyped<T>()->Get();
+    return params_.at(name)->AsConstTyped<T>()->Get();
   }
 
   // Set the value of all same parameters from 'other'.
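Together, has_value_set_ and the HasValueSet() accessors let callers tell a parameter that still carries its construction-time default apart from one that was explicitly assigned; the updated test below exercises exactly this. A minimal sketch using only types from this header (values are illustrative):

  TypedToolParam<int> param(13);
  // param.HasValueSet() == false: 13 is only the default supplied at
  // construction; the constructor does not mark the value as set.
  param.Set(42);
  // param.HasValueSet() == true: Set() records the explicit assignment.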
@@ -33,7 +33,11 @@ TEST(ToolParams, SetTest) {
 
   params.Set(others);
   EXPECT_EQ(19, params.Get<int>("some-int1"));
+  EXPECT_TRUE(params.HasValueSet<int>("some-int1"));
+
   EXPECT_EQ(17, params.Get<int>("some-int2"));
+  EXPECT_FALSE(params.HasValueSet<int>("some-int2"));
+
   EXPECT_FALSE(params.HasParam("some-bool"));
 }
@@ -33,7 +33,7 @@ from tensorflow.python.util.tf_export import tf_export
 # This value changes every day with an automatic CL. It can be modified in code
 # via `forward_compatibility_horizon()` or with the environment variable
 # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 7, 1)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 7, 2)
 _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
 _FORWARD_COMPATIBILITY_DATE_NUMBER = None
@@ -70,6 +70,36 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "tpu_executor",
+    srcs = [
+        "tpu_platform_registration.cc",
+    ],
+    hdrs = [
+        "tpu_executor.h",
+        "tpu_platform.h",
+        "tpu_stream.h",
+        "tpu_timer.h",
+    ],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":device_memory_base_helper",
+        ":status_helper",
+        ":tpu_executor_base",
+        ":tpu_executor_c_api_hdrs",
+        ":tpu_executor_interface",
+        ":tpu_platform_interface",
+        ":tpu_stream_interface",
+        "//tensorflow/core/platform:types",
+        "//tensorflow/core/tpu:tpu_api",
+        "//tensorflow/stream_executor",
+        "//tensorflow/stream_executor/lib",
+        "//tensorflow/stream_executor/platform",
+        "@com_google_absl//absl/container:flat_hash_map",
+    ],
+    alwayslink = True,
+)
+
 cc_library(
     name = "tpu_executor_base",
     srcs = [
@@ -82,6 +112,7 @@ cc_library(
         "tpu_stream.h",
         "tpu_timer.h",
     ],
+    visibility = ["//tensorflow/core/tpu:__pkg__"],
     deps = [
         ":device_memory_base_helper",
         ":status_helper",
@@ -98,17 +129,6 @@ cc_library(
         "//tensorflow/stream_executor/lib",
         "@com_google_absl//absl/container:flat_hash_map",
     ],
     alwayslink = True,
 )
-
-cc_library(
-    name = "tpu_executor",
-    srcs = ["tpu_platform_registration.cc"],
-    deps = [
-        ":tpu_executor_base",
-        "//tensorflow/stream_executor/platform",
-    ],
-    alwayslink = True,
-)
 
 cc_library(
@@ -139,7 +159,6 @@ cc_library(
     srcs = ["tpu_transfer_manager_registration.cc"],
     deps = [
         ":tpu_executor",
-        ":tpu_executor_base",
         ":tpu_transfer_manager_base",
         "//tensorflow/compiler/xla/service:transfer_manager",
     ],
@@ -153,7 +172,7 @@ cc_library(
         ":c_api_conversions",
         ":proto_helper",
         ":status_helper",
-        ":tpu_executor_base",
+        ":tpu_executor",
         ":tpu_executor_c_api_hdrs",
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:shape_util",
@@ -171,7 +190,6 @@ cc_library(
     hdrs = ["tpu_computation_placer.h"],
     deps = [
         ":tpu_executor",
-        ":tpu_executor_base",
         ":tpu_executor_c_api_hdrs",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla/service:computation_placer",
@@ -18,11 +18,8 @@ limitations under the License.
 #include "tensorflow/c/tf_status.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/tpu/tpu_api.h"
 #include "tensorflow/stream_executor/device_memory.h"
 #include "tensorflow/stream_executor/lib/status.h"
 #include "tensorflow/stream_executor/tpu/device_memory_base_helper.h"
 #include "tensorflow/stream_executor/tpu/status_helper.h"
 #include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
 #include "tensorflow/stream_executor/tpu/tpu_stream.h"
 #include "tensorflow/stream_executor/tpu/tpu_timer.h"
 
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/device_memory.h"
 #include "tensorflow/stream_executor/device_options.h"
 #include "tensorflow/stream_executor/event.h"
 #include "tensorflow/stream_executor/lib/status.h"
 #include "tensorflow/stream_executor/lib/statusor.h"
 #include "tensorflow/stream_executor/stream.h"
 #include "tensorflow/stream_executor/stream_executor.h"
@@ -18,10 +18,8 @@ limitations under the License.
 #include "tensorflow/c/tf_status.h"
 #include "tensorflow/c/tf_status_helper.h"
 #include "tensorflow/core/tpu/tpu_api.h"
 #include "tensorflow/stream_executor/platform.h"
 #include "tensorflow/stream_executor/tpu/status_helper.h"
 #include "tensorflow/stream_executor/tpu/tpu_executor.h"
 #include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
 
 namespace tensorflow {
 
@@ -32,7 +30,7 @@ using Status = ::stream_executor::port::Status;
 template <typename T>
 using StatusOr = ::stream_executor::port::StatusOr<T>;
 
-TpuPlatform::TpuPlatform() {
+TpuPlatform::TpuPlatform() : name_("TPU") {
   platform_ = tpu::ExecutorApiFn()->TpuPlatform_NewFn();
 }
@@ -109,10 +107,7 @@ TpuPlatform::GetUncachedExecutor(
   return TpuPlatform::kId;
 }
 
-const std::string& TpuPlatform::Name() const {
-  static std::string* name = new std::string("TPU");
-  return *name;
-}
+const std::string& TpuPlatform::Name() const { return name_; }
 
 int64 TpuPlatform::TpuMemoryLimit() {
   return tpu::ExecutorApiFn()->TpuPlatform_TpuMemoryLimitFn(platform_);
@@ -121,7 +121,7 @@ class TpuPlatform : public ::tensorflow::tpu::TpuPlatformInterface {
 
  private:
  SE_Platform* platform_;
-
+  std::string name_;
  stream_executor::ExecutorCache executor_cache_;
  StreamMap stream_map_;
  EventMap event_map_;
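Keeping the name in a data member replaces the intentionally leaked function-local std::string while still returning a reference that stays valid for the platform's lifetime. The pattern in isolation (a simplified sketch, not the TPU code):

  #include <string>

  class Platform {
   public:
    Platform() : name_("TPU") {}
    // Stable reference without a static local or a deliberate heap leak;
    // valid for as long as the Platform object itself.
    const std::string& Name() const { return name_; }

   private:
    std::string name_;
  };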
@@ -60,17 +60,17 @@ sudo ldconfig
 # Install Horovod.
 cd ..
 HOROVOD_WITH_TENSORFLOW=1
-pip3.7 install horovod[tensorflow]
+pip3.7 install horovod[tensorflow] --user
 
 # Install tests.
 git clone https://github.com/DEKHTIARJonathan/TF_HVD_Stability_Test.git
 
 # Install pytest.
-pip3.7 install -U pytest
+pip3.7 install -U pytest --user
 
 # Install requirements.
 cd TF_HVD_Stability_Test
-pip3.7 install -r requirements.txt
+pip3.7 install -r requirements.txt --user
 
 # Run the tests.
 python3.7 -m pytest
@@ -710,8 +710,8 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
     )
 
     # Check out LLVM and MLIR from llvm-project.
-    LLVM_COMMIT = "0f9d623b63e87b4ba30c30fd884ecc333eb32b4a"
-    LLVM_SHA256 = "58dee49dd9e79eea829fd6ca2d57cd1bf927445771bb296985061bd7644d676d"
+    LLVM_COMMIT = "d6343e607ac8fa71fa6d99f9c86369ae9e66e671"
+    LLVM_SHA256 = "0824d59e80c99e64cafe6e8051c9861e534dee60f056dcd528d5fe00ebeb542f"
     LLVM_URLS = [
         "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
         "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
third_party/mlir/BUILD (vendored, 23 changes)
@@ -1478,6 +1478,7 @@ cc_library(
         ":IR",
         ":Pass",
         ":SCFDialect",
+        ":SCFToSPIRV",
         ":SPIRVDialect",
         ":SPIRVLowering",
         ":StandardToSPIRVTransforms",
@@ -2233,6 +2234,28 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "SCFToSPIRV",
+    srcs = ["lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp"],
+    hdrs = ["include/mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"],
+    includes = ["include"],
+    deps = [
+        ":Affine",
+        ":AffineToStandard",
+        ":ConversionPassIncGen",
+        ":IR",
+        ":Pass",
+        ":SCFDialect",
+        ":SPIRVDialect",
+        ":SPIRVLowering",
+        ":StandardOps",
+        ":Support",
+        ":TransformUtils",
+        ":Transforms",
+        "@llvm-project//llvm:Support",
+    ],
+)
+
 cc_library(
     name = "SCFToStandard",
     srcs = [
third_party/mlir/tblgen.bzl (vendored, 4 changes)
@@ -23,7 +23,7 @@ def gentbl(name, tblgen, td_file, tbl_outs, td_srcs = [], td_includes = [], td_r
 
     td_includes_cmd = [
         "-I external/llvm-project/mlir/include -I external/org_tensorflow",
-        "-I $(GENDIR)/external/llvm-project/mlir/include",
+        "-I $(GENDIR)/external/llvm-project/mlir/include -I $(GENDIR)/external/org_tensorflow",
    ]
     for td_include in td_includes:
         td_includes_cmd += [
@@ -32,7 +32,7 @@ def gentbl(name, tblgen, td_file, tbl_outs, td_srcs = [], td_includes = [], td_r
         ]
     for td_include in td_relative_includes:
         td_includes_cmd += [
-            "-I%s/%s" % (native.package_name(), td_include),
+            "-I%s/%s -Iexternal/org_tensorflow/%s/%s" % (native.package_name(), td_include, native.package_name(), td_include),
             "-I$(GENDIR)/%s/%s" % (native.package_name(), td_include),
         ]