merge from master

commit 2381ee56d9

Changed files:
  RELEASE.md
  tf_mlir_opt_main.cc
  tensorflow/
    c/eager
    compiler/
      jit
      mlir/
        hlo
        lite
        tensorflow/ (ir, tests, transforms)
        tfr
        tools/kernel_gen
        xla
      tf2xla
      xla
    core/
      common_runtime
      distributed_runtime/rpc
      framework/
        allocator_test.cc, bfloat16_test.cc, function_test.cc, function_testlib.cc,
        op_kernel_test.cc, rendezvous_test.cc, tensor_shape_test.cc, tensor_test.cc
      grappler/optimizers/
        generic_layout_optimizer_transposer.cc, generic_layout_optimizer_transposer.h,
        generic_layout_optimizer_transposer_factory.cc
      kernels/
        BUILD, batch_matmul_op_test.cc, data, dynamic_partition_op_test.cc,
        eigen_spatial_convolutions_test.cc, example_parsing_ops_test.cc,
        gather_nd_op_test.cc, gather_op_test.cc, in_topk_op_test.cc, linalg,
        lrn_op_test.cc, matmul_op.cc, matmul_op_complex.cc, matmul_op_impl.h,
        matmul_op_real.cc, matmul_op_test.cc, mkl, mlir_generated/op_definitions,
        multinomial_op_test.cc, nn_ops_test.cc, one_hot_op_test.cc,
        parameterized_truncated_normal_op_test.cc, quantize_and_dequantize_op_test.cc,
        quantized_concat_op_test.cc, random_binomial_op_test.cc, random_op_test.cc,
        reduction_ops_test.cc, regex_replace_op_test.cc, requantization_range_op_test.cc,
        reverse_op_test.cc, roll_op_test.cc, save_op_test.cc, scan_ops_test.cc,
        scatter_nd_op_test.cc, scatter_op_test.cc

RELEASE.md
@@ -42,6 +42,10 @@
* Removed deprecated `Interpreter::UseNNAPI(bool)` C++ API.
  * Use `NnApiDelegate()` and related delegate configuration methods
    directly.
* TF Core:
  * Corrected higher-order gradients of control flow constructs (`tf.cond`,
    `tf.while_loop`, and compositions like `tf.foldl`) computed with
    `tf.GradientTape` inside a `tf.function`.

## Thanks to our Contributors
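As a quick illustration of the gradient fix noted above, here is a minimal sketch (hypothetical values, not taken from the release notes) of the pattern it covers: second-order gradients of a `tf.cond` taken with nested `tf.GradientTape`s inside a `tf.function`.

```python
import tensorflow as tf

@tf.function
def second_order_grad(x):
    # Nested tapes inside a tf.function: the inner tape differentiates a
    # tf.cond, the outer tape differentiates that gradient again.
    with tf.GradientTape() as outer:
        outer.watch(x)
        with tf.GradientTape() as inner:
            inner.watch(x)
            y = tf.cond(x > 0.0, lambda: x * x * x, lambda: -x * x * x)
        dy_dx = inner.gradient(y, x)  # 3 * x**2 on the positive branch
    return outer.gradient(dy_dx, x)   # 6 * x    on the positive branch

print(second_order_grad(tf.constant(2.0)).numpy())  # expected: 12.0
```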
@@ -769,7 +769,7 @@ void Execute_MatMul_CPU_Runtime_Error(bool async) {
  TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
  EXPECT_NE(TF_OK, TF_GetCode(status));
  EXPECT_EQ(nullptr, t);
  const char* msg = "Matrix size-incompatible: In[0]: [2,2], In[1]: [3,2]";
  const char* msg = "In[0] mismatch In[1] shape: 2 vs. 3: [2,2] [3,2]";
  EXPECT_TRUE(strstr(TF_Message(status), msg) != nullptr)
      << TF_Message(status);
  // Since error is not cleared, the following copy with correct device will
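The hunk above only updates the expected error text for a mismatched MatMul in the C API test. For reference, the same shape mismatch surfaces at the Python level as an `InvalidArgumentError`; a small sketch (the exact message wording depends on the TensorFlow version, which is precisely what this test change tracks):

```python
import tensorflow as tf

a = tf.ones([2, 2])
b = tf.ones([3, 2])  # inner dimensions 2 vs. 3 are incompatible

try:
    tf.matmul(a, b)
except tf.errors.InvalidArgumentError as e:
    # The C API test above asserts on this message via TF_Message(status).
    print(e.message)
```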
@@ -583,7 +583,11 @@ XlaComputationLaunchContext::BuildXlaCompilerArguments(
    XlaCompiler::Argument& arg = out[input_num];
    if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
      // Handles compile-time constants.
      TF_RET_CHECK(input->dtype() != DT_RESOURCE);

      // TODO(b/157241314): Support constants located in resource variables.
      TF_RET_CHECK(input->dtype() != DT_RESOURCE)
          << "tf2xla bridge does not support must-be-constants located in "
             "resource variables; try moving them to a tensor";
      arg.kind = XlaCompiler::Argument::kConstant;
      arg.type = input->dtype();
      arg.shape = input->shape();
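For context, a small Python-level sketch (hypothetical, not part of this diff) of the situation the new check message is about: an XLA-compiled function whose must-be-constant input (here, a shape) should be fed as a regular tensor rather than read straight out of a resource variable. `experimental_compile=True` is the XLA flag available in this snapshot.

```python
import tensorflow as tf

@tf.function(experimental_compile=True)  # compile through the tf2xla bridge
def make_zeros(shape):
    # tf.zeros needs `shape` at compile time, so this argument becomes a
    # must-be-constant input for the bridge.
    return tf.zeros(shape)

# Feeding a plain constant tensor works: its value is known when compiling.
print(make_zeros(tf.constant([2, 3])).shape)

# If the shape lives in a tf.Variable, read it into an ordinary tensor
# first; "try moving them to a tensor" in the error above refers to this.
v = tf.Variable([2, 3])
print(make_zeros(tf.convert_to_tensor(v)).shape)
```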
@@ -517,6 +517,15 @@ cc_library(
    ],
)

cc_library(
    name = "map_chlo_to_hlo_op",
    hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"],
    deps = [
        ":hlo",
        "@llvm-project//mlir:IR",
    ],
)

cc_library(
    name = "map_hlo_to_lhlo_op",
    hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"],
@@ -606,9 +615,11 @@ cc_library(
    ],
    deps = [
        ":hlo",
        ":map_chlo_to_hlo_op",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:SCFDialect",
        "@llvm-project//mlir:Shape",
        "@llvm-project//mlir:StandardOps",
        "@llvm-project//mlir:Transforms",
@@ -893,6 +904,7 @@ cc_library(
    deps = [
        ":chlo_legalize_to_hlo_inc_gen",
        ":hlo",
        ":map_chlo_to_hlo_op",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:SCFDialect",
        "@llvm-project//mlir:Shape",
@@ -0,0 +1,97 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_CHLO_TO_MHLO_OP_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_CHLO_TO_MHLO_OP_H_

#include <type_traits>

#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/IR/PatternMatch.h"

namespace mlir {
namespace chlo {

struct HloComplexAdaptor {
  static mhlo::ComplexOp CreateOp(BroadcastComplexOp from_op, Type result_type,
                                  Value broadcasted_lhs, Value broadcasted_rhs,
                                  OpBuilder &builder) {
    return builder.create<mhlo::ComplexOp>(from_op.getLoc(), result_type,
                                           broadcasted_lhs, broadcasted_rhs);
  }
};
template <typename FromOpTy, typename ToOpTy>
struct HloBinaryElementwiseAdaptor {
  static ToOpTy CreateOp(FromOpTy from_op, Type result_type,
                         Value broadcasted_lhs, Value broadcasted_rhs,
                         OpBuilder &builder) {
    return builder.create<ToOpTy>(from_op.getLoc(), result_type,
                                  broadcasted_lhs, broadcasted_rhs);
  }
};
struct HloCompareAdaptor {
  static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type,
                                  Value broadcasted_lhs, Value broadcasted_rhs,
                                  OpBuilder &builder) {
    return builder.create<mhlo::CompareOp>(
        from_op.getLoc(), result_type, broadcasted_lhs, broadcasted_rhs,
        from_op.comparison_direction(), from_op.compare_typeAttr());
  }
};

// Populate a pattern for each broadcasting CHLO op. This requires the pattern
// to take a ChloOpTy, MhloOpTy, and an Adaptor as templated values.
template <template <typename, typename, typename> class Pattern,
          typename... ConstructorArgs>
void PopulateForBroadcastingBinaryOp(MLIRContext *context,
                                     OwningRewritePatternList *patterns,
                                     ConstructorArgs &&...args) {
#define POPULATE_BCAST(ChloOp, HloOp)                                      \
  patterns->insert<                                                        \
      Pattern<ChloOp, HloOp, HloBinaryElementwiseAdaptor<ChloOp, HloOp>>>( \
      context, args...);

  POPULATE_BCAST(BroadcastAddOp, mhlo::AddOp);
  POPULATE_BCAST(BroadcastAndOp, mhlo::AndOp);
  POPULATE_BCAST(BroadcastAtan2Op, mhlo::Atan2Op);
  POPULATE_BCAST(BroadcastDivOp, mhlo::DivOp);
  POPULATE_BCAST(BroadcastMaxOp, mhlo::MaxOp);
  POPULATE_BCAST(BroadcastMinOp, mhlo::MinOp);
  POPULATE_BCAST(BroadcastMulOp, mhlo::MulOp);
  POPULATE_BCAST(BroadcastOrOp, mhlo::OrOp);
  POPULATE_BCAST(BroadcastPowOp, mhlo::PowOp);
  POPULATE_BCAST(BroadcastRemOp, mhlo::RemOp);
  POPULATE_BCAST(BroadcastShiftLeftOp, mhlo::ShiftLeftOp);
  POPULATE_BCAST(BroadcastShiftRightArithmeticOp, mhlo::ShiftRightArithmeticOp);
  POPULATE_BCAST(BroadcastShiftRightLogicalOp, mhlo::ShiftRightLogicalOp);
  POPULATE_BCAST(BroadcastSubOp, mhlo::SubOp);
  POPULATE_BCAST(BroadcastXorOp, mhlo::XorOp);

  // Broadcasting ops requiring special construction.
  patterns
      ->insert<Pattern<BroadcastComplexOp, mhlo::ComplexOp, HloComplexAdaptor>>(
          context, args...);
  patterns
      ->insert<Pattern<BroadcastCompareOp, mhlo::CompareOp, HloCompareAdaptor>>(
          context, args...);

#undef POPULATE_BCAST
}

}  // namespace chlo
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_CHLO_TO_MHLO_OP_H_
@@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "mlir-hlo/utils/broadcast_utils.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
@@ -69,13 +70,18 @@ struct ConvertConstantLikeOp : public OpConversionPattern<ConstantLikeOp> {
|
||||
// Converts binary ops that statically are determined to not broadcast directly
|
||||
// to the corresponding mhlo non-broadcasting op.
|
||||
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
|
||||
struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> {
|
||||
using OpRewritePattern<ChloOpTy>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(ChloOpTy op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
struct ConvertTrivialNonBroadcastBinaryOp
|
||||
: public OpConversionPattern<ChloOpTy> {
|
||||
using OpConversionPattern<ChloOpTy>::OpConversionPattern;
|
||||
LogicalResult matchAndRewrite(
|
||||
ChloOpTy op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Only rewrite for statically determinable non-broadcasting cases.
|
||||
auto lhs_type = op.lhs().getType().template dyn_cast<RankedTensorType>();
|
||||
auto rhs_type = op.rhs().getType().template dyn_cast<RankedTensorType>();
|
||||
typename ChloOpTy::Adaptor transformed(operands);
|
||||
auto lhs_type =
|
||||
transformed.lhs().getType().template dyn_cast<RankedTensorType>();
|
||||
auto rhs_type =
|
||||
transformed.rhs().getType().template dyn_cast<RankedTensorType>();
|
||||
if (!lhs_type || !rhs_type) return failure();
|
||||
|
||||
// Requires rank broadcast.
|
||||
@@ -93,8 +99,9 @@ struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> {
|
||||
}
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, {Adaptor::CreateOp(op, op.getResult().getType(),
|
||||
op.lhs(), op.rhs(), rewriter)});
|
||||
rewriter.replaceOp(
|
||||
op, {Adaptor::CreateOp(op, op.getResult().getType(), operands[0],
|
||||
operands[1], rewriter)});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -113,13 +120,15 @@ struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> {
|
||||
// `shape.broadcast` op, which only supports prefix-padding.
|
||||
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
|
||||
struct ConvertRankedDynamicBroadcastBinaryOp
|
||||
: public OpRewritePattern<ChloOpTy> {
|
||||
using OpRewritePattern<ChloOpTy>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(ChloOpTy op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
: public OpConversionPattern<ChloOpTy> {
|
||||
using OpConversionPattern<ChloOpTy>::OpConversionPattern;
|
||||
LogicalResult matchAndRewrite(
|
||||
ChloOpTy op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Only support ranked operands.
|
||||
Value lhs = op.lhs();
|
||||
Value rhs = op.rhs();
|
||||
typename ChloOpTy::Adaptor transformed(operands);
|
||||
Value lhs = transformed.lhs();
|
||||
Value rhs = transformed.rhs();
|
||||
auto lhs_type = lhs.getType().dyn_cast<RankedTensorType>();
|
||||
auto rhs_type = rhs.getType().dyn_cast<RankedTensorType>();
|
||||
auto result_type =
|
||||
@@ -193,324 +202,6 @@ struct ConvertRankedDynamicBroadcastBinaryOp
|
||||
}
|
||||
};
|
||||
|
||||
// Converts a broadcasting binary operation with a scalar operand and an
|
||||
// unranked operand to a ranked broadcasting operation by dynamically reshaping
|
||||
// the unranked operand to a 1D tensor. This will always be safe because
|
||||
// broadcasting from a scalar to another shape always works.
|
||||
template <typename ChloOpTy, typename HloOpTy>
|
||||
struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
|
||||
: public OpRewritePattern<ChloOpTy> {
|
||||
using OpRewritePattern<ChloOpTy>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(ChloOpTy op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
Value lhs = op.lhs();
|
||||
Value rhs = op.rhs();
|
||||
|
||||
auto lhs_ranked_type = lhs.getType().dyn_cast<RankedTensorType>();
|
||||
auto lhs_unranked_type = lhs.getType().dyn_cast<UnrankedTensorType>();
|
||||
|
||||
auto rhs_ranked_type = rhs.getType().dyn_cast<RankedTensorType>();
|
||||
auto rhs_unranked_type = rhs.getType().dyn_cast<UnrankedTensorType>();
|
||||
|
||||
bool lhs_is_scalar = lhs_ranked_type &&
|
||||
lhs_ranked_type.getShape().empty() &&
|
||||
rhs_unranked_type;
|
||||
bool rhs_is_scalar = rhs_ranked_type &&
|
||||
rhs_ranked_type.getShape().empty() &&
|
||||
lhs_unranked_type;
|
||||
|
||||
// Only support the case where exactly one operand is scalar and the other
|
||||
// is unranked. Other patterns in this file will create more efficient
|
||||
// lowerings for cases where both ranks are known or will handle the more
|
||||
// generic case of both inputs being unranked.
|
||||
if (!(lhs_is_scalar ^ rhs_is_scalar)) return failure();
|
||||
|
||||
auto result_type = op.getResult().getType().template dyn_cast<TensorType>();
|
||||
|
||||
// Reshape the non-scalar value into a dynamically sized, rank-1 tensor
|
||||
Value shape =
|
||||
rewriter.create<shape::ShapeOfOp>(loc, lhs_is_scalar ? rhs : lhs);
|
||||
Value num_elements = rewriter.create<shape::NumElementsOp>(loc, shape);
|
||||
Value size_tensor =
|
||||
rewriter.create<TensorFromElementsOp>(loc, num_elements);
|
||||
Value reshaped = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
loc, RankedTensorType::get({-1}, result_type.getElementType()),
|
||||
lhs_is_scalar ? rhs : lhs, size_tensor);
|
||||
|
||||
// Create a new ranked Chlo op that will be further lowered by other
|
||||
// patterns into Mhlo.
|
||||
SmallVector<Value, 2> operands{lhs_is_scalar ? lhs : reshaped,
|
||||
rhs_is_scalar ? rhs : reshaped};
|
||||
Value computed = rewriter.create<ChloOpTy>(
|
||||
loc, SmallVector<Type, 1>{reshaped.getType()}, operands, op.getAttrs());
|
||||
|
||||
// Reshape the result back into an unranked tensor.
|
||||
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_type,
|
||||
computed, shape);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Handles lowering of the following pattern to patterns that will be further
|
||||
// matched by other patterns until they result in LHLO:
|
||||
// %result = "chlo.op"(%lhs, %rhs) : (<*xTy>, <*xTy>) -> <*xTy>
|
||||
//
|
||||
// The sequence of specializations this handles is:
|
||||
// - Either operand being scalar
|
||||
// - Operands having equal shapes
|
||||
// - The resulting value being any of ranks [2,6]
|
||||
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
|
||||
struct ConvertUnrankedDynamicBroadcastBinaryOp
|
||||
: public OpRewritePattern<ChloOpTy> {
|
||||
using OpRewritePattern<ChloOpTy>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(ChloOpTy op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
Value lhs = op.lhs();
|
||||
Value rhs = op.rhs();
|
||||
auto lhs_type = lhs.getType().dyn_cast<UnrankedTensorType>();
|
||||
auto rhs_type = rhs.getType().dyn_cast<UnrankedTensorType>();
|
||||
auto result_type = op.getResult().getType().template dyn_cast<TensorType>();
|
||||
|
||||
// Only support unranked operands. If either operand is ranked, another
|
||||
// pattern will handle the lowering.
|
||||
if (!lhs_type || !rhs_type) return failure();
|
||||
|
||||
// If lhs is scalar
|
||||
auto if_op = rewriter.create<scf::IfOp>(
|
||||
loc, result_type, IsScalarTensor(rewriter, op, lhs), true);
|
||||
OpBuilder if_lhs_scalar_builder = if_op.getThenBodyBuilder();
|
||||
Value reshaped_lhs = if_lhs_scalar_builder.create<TensorCastOp>(
|
||||
loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs);
|
||||
Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>(
|
||||
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs},
|
||||
op.getAttrs());
|
||||
if_lhs_scalar_builder.create<scf::YieldOp>(loc, if_lhs_scalar_result);
|
||||
|
||||
// If lhs is NOT scalar
|
||||
//
|
||||
// See if rhs is scalar
|
||||
OpBuilder else_lhs_scalar_builder = if_op.getElseBodyBuilder();
|
||||
auto if_rhs_scalar_op = else_lhs_scalar_builder.create<scf::IfOp>(
|
||||
loc, result_type, IsScalarTensor(else_lhs_scalar_builder, op, rhs),
|
||||
true);
|
||||
else_lhs_scalar_builder.create<scf::YieldOp>(loc,
|
||||
if_rhs_scalar_op.getResult(0));
|
||||
OpBuilder if_rhs_scalar_builder = if_rhs_scalar_op.getThenBodyBuilder();
|
||||
Value reshaped_rhs = if_rhs_scalar_builder.create<TensorCastOp>(
|
||||
loc, RankedTensorType::get({}, lhs_type.getElementType()), rhs);
|
||||
Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>(
|
||||
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs},
|
||||
op.getAttrs());
|
||||
if_rhs_scalar_builder.create<scf::YieldOp>(loc, if_rhs_scalar_result);
|
||||
|
||||
// If NEITHER shape is scalar
|
||||
//
|
||||
// See if shapes are equal.
|
||||
OpBuilder else_no_scalars_builder = if_rhs_scalar_op.getElseBodyBuilder();
|
||||
Value shape_of_lhs =
|
||||
else_no_scalars_builder.create<shape::ShapeOfOp>(loc, lhs);
|
||||
Value shape_of_rhs =
|
||||
else_no_scalars_builder.create<shape::ShapeOfOp>(loc, rhs);
|
||||
Value equal_shapes = else_no_scalars_builder.create<shape::ShapeEqOp>(
|
||||
loc, shape_of_lhs, shape_of_rhs);
|
||||
|
||||
auto if_eq_shapes_op = else_no_scalars_builder.create<scf::IfOp>(
|
||||
loc, result_type, equal_shapes, true);
|
||||
else_no_scalars_builder.create<scf::YieldOp>(loc,
|
||||
if_eq_shapes_op.getResult(0));
|
||||
|
||||
OpBuilder if_eq_shapes_builder = if_eq_shapes_op.getThenBodyBuilder();
|
||||
Value non_broadcast_op =
|
||||
Adaptor::CreateOp(op, result_type, lhs, rhs, if_eq_shapes_builder);
|
||||
if_eq_shapes_builder.create<scf::YieldOp>(loc, non_broadcast_op);
|
||||
|
||||
// If shapes are not scalar, nor equal
|
||||
//
|
||||
// See if values are of a rank that we support.
|
||||
OpBuilder if_neq_shapes_builder = if_eq_shapes_op.getElseBodyBuilder();
|
||||
if_neq_shapes_builder.create<scf::YieldOp>(
|
||||
loc, HandleBroadcastAndOp(if_neq_shapes_builder, op, lhs, rhs));
|
||||
|
||||
rewriter.replaceOp(op, {if_op.getResult(0)});
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
// Returns the dynamic result of checking whether the given value is a scalar
|
||||
// tensor.
|
||||
Value IsScalarTensor(OpBuilder &rewriter, ChloOpTy op, Value tensor) const {
|
||||
auto loc = op.getLoc();
|
||||
|
||||
Value shape_of_tensor = rewriter.create<shape::ShapeOfOp>(loc, tensor);
|
||||
Value rank_tensor = rewriter.create<shape::RankOp>(
|
||||
loc, rewriter.getIndexType(), shape_of_tensor);
|
||||
return rewriter.create<CmpIOp>(loc, rewriter.getI1Type(), CmpIPredicate::eq,
|
||||
rank_tensor,
|
||||
rewriter.create<ConstantIndexOp>(loc, 0));
|
||||
}
|
||||
|
||||
// Create the if statement and code for a broadcasting op with a result of a
|
||||
// given rank.
|
||||
scf::IfOp createRankSpecializedBroadcastAndOp(OpBuilder &builder, ChloOpTy op,
|
||||
Value lhs, Value rhs,
|
||||
Value actual_rank,
|
||||
int targeted_rank) const {
|
||||
auto loc = op.getLoc();
|
||||
|
||||
// Create the if block to place the current specialized logic in.
|
||||
Value greater_rank_is_n = builder.create<CmpIOp>(
|
||||
loc, CmpIPredicate::eq, actual_rank,
|
||||
builder.create<ConstantIndexOp>(loc, targeted_rank));
|
||||
auto if_op =
|
||||
builder.create<scf::IfOp>(loc, lhs.getType(), greater_rank_is_n, true);
|
||||
OpBuilder if_builder = if_op.getThenBodyBuilder();
|
||||
|
||||
// Handle shape broadcasting and inference.
|
||||
Value lhs_shape = if_builder.create<shape::ShapeOfOp>(loc, lhs);
|
||||
Value rhs_shape = if_builder.create<shape::ShapeOfOp>(loc, rhs);
|
||||
SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
|
||||
auto unknown_rank_extent_tensor_type = RankedTensorType::get(
|
||||
{RankedTensorType::kDynamicSize}, builder.getIndexType());
|
||||
auto known_rank_extent_tensor_type =
|
||||
RankedTensorType::get({targeted_rank}, builder.getIndexType());
|
||||
auto reshaped_type = RankedTensorType::get(
|
||||
llvm::SmallVector<int64_t, 6>(targeted_rank,
|
||||
RankedTensorType::kDynamicSize),
|
||||
lhs.getType().template dyn_cast<TensorType>().getElementType());
|
||||
Value ranked_shape_val = if_builder.create<shape::ConstShapeOp>(
|
||||
loc, known_rank_extent_tensor_type,
|
||||
mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type,
|
||||
ranked_shape));
|
||||
Value extended_lhs = if_builder.create<shape::BroadcastOp>(
|
||||
loc, unknown_rank_extent_tensor_type, lhs_shape, ranked_shape_val,
|
||||
nullptr);
|
||||
Value extended_lhs_casted = if_builder.create<TensorCastOp>(
|
||||
loc, known_rank_extent_tensor_type, extended_lhs);
|
||||
Value extended_rhs = if_builder.create<shape::BroadcastOp>(
|
||||
loc, unknown_rank_extent_tensor_type, rhs_shape, ranked_shape_val,
|
||||
nullptr);
|
||||
Value extended_rhs_casted = if_builder.create<TensorCastOp>(
|
||||
loc, known_rank_extent_tensor_type, extended_rhs);
|
||||
|
||||
// 1. Reshape operands to the given rank (with the same number of elements)
|
||||
// 2. Compute the ranked-broadcasted ChloOp (which will assert that the ops
|
||||
// can be broadcasted and do the actual broadcasting)
|
||||
// 3. Type erase the output back to unranked
|
||||
Value reshaped_lhs = if_builder.create<mhlo::DynamicReshapeOp>(
|
||||
loc, reshaped_type, lhs, extended_lhs_casted);
|
||||
Value reshaped_rhs = if_builder.create<mhlo::DynamicReshapeOp>(
|
||||
loc, reshaped_type, rhs, extended_rhs_casted);
|
||||
Value result = if_builder.create<ChloOpTy>(
|
||||
loc, ArrayRef<Type>{reshaped_type},
|
||||
ArrayRef<Value>{reshaped_lhs, reshaped_rhs}, op.getAttrs());
|
||||
Value reshaped_result = if_builder.create<TensorCastOp>(
|
||||
loc, UnrankedTensorType::get(reshaped_type.getElementType()), result);
|
||||
if_builder.create<scf::YieldOp>(loc, reshaped_result);
|
||||
|
||||
// Return the if_op, so the result can be used and the else block can be
|
||||
// used for the next rank specialized step.
|
||||
return if_op;
|
||||
}
|
||||
|
||||
// Iterates over the desired ranks to be specialized and generates the code
|
||||
// snippet for each case.
|
||||
Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op, Value lhs,
|
||||
Value rhs) const {
|
||||
constexpr int max_rank_specialization = 7;
|
||||
auto loc = op.getLoc();
|
||||
|
||||
// Find the larger rank of the 2 operands.
|
||||
auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
|
||||
rewriter.getIndexType());
|
||||
Value lhs_shape =
|
||||
rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, lhs);
|
||||
Value rhs_shape =
|
||||
rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, rhs);
|
||||
Value lhs_rank =
|
||||
rewriter.create<RankOp>(loc, rewriter.getIndexType(), lhs_shape);
|
||||
Value rhs_rank =
|
||||
rewriter.create<RankOp>(loc, rewriter.getIndexType(), rhs_shape);
|
||||
Value greater_rank_lhs =
|
||||
rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs_rank, rhs_rank);
|
||||
Value greater_rank =
|
||||
rewriter.create<SelectOp>(loc, greater_rank_lhs, lhs_rank, rhs_rank);
|
||||
|
||||
// Generate a list of nested if/else statements to handle rank
|
||||
// specializations from 2-6.
|
||||
scf::IfOp if_op = createRankSpecializedBroadcastAndOp(rewriter, op, lhs,
|
||||
rhs, greater_rank, 2);
|
||||
|
||||
// Put each subsequent rank specialization inside the else statement of the
|
||||
// previous one.
|
||||
OpBuilder else_builder = if_op.getElseBodyBuilder();
|
||||
for (int i = 3; i < max_rank_specialization; i++) {
|
||||
auto inner_if = createRankSpecializedBroadcastAndOp(else_builder, op, lhs,
|
||||
rhs, greater_rank, i);
|
||||
|
||||
else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
|
||||
else_builder = inner_if.getElseBodyBuilder();
|
||||
}
|
||||
|
||||
// Fire an assertion if none of the rank specializations applied (one of the
|
||||
// ranks was greater than 6).
|
||||
else_builder.create<AssertOp>(
|
||||
loc, else_builder.create<ConstantIntOp>(loc, 0, 1),
|
||||
"Input for dynamic binary op lowering was of a rank greater than 6");
|
||||
else_builder.create<scf::YieldOp>(loc, lhs);
|
||||
|
||||
// Return the result of the outermost if statement.
|
||||
return if_op.getResult(0);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
|
||||
void PopulateForBinaryOp(MLIRContext *context,
|
||||
OwningRewritePatternList *patterns) {
|
||||
patterns
|
||||
->insert<ConvertTrivialNonBroadcastBinaryOp<ChloOpTy, HloOpTy, Adaptor>>(
|
||||
context, 10);
|
||||
patterns->insert<
|
||||
ConvertRankedDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy, Adaptor>>(
|
||||
context, 5);
|
||||
patterns->insert<
|
||||
ConvertUnrankedScalarDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy>,
|
||||
ConvertUnrankedDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy, Adaptor>>(
|
||||
context);
|
||||
}
|
||||
|
||||
template <typename FromOpTy, typename ToOpTy>
|
||||
struct HloBinaryElementwiseAdaptor {
|
||||
static ToOpTy CreateOp(FromOpTy from_op, Type result_type,
|
||||
Value broadcasted_lhs, Value broadcasted_rhs,
|
||||
OpBuilder &builder) {
|
||||
return builder.create<ToOpTy>(from_op.getLoc(), result_type,
|
||||
broadcasted_lhs, broadcasted_rhs);
|
||||
}
|
||||
};
|
||||
|
||||
struct HloComplexAdaptor {
|
||||
static mhlo::ComplexOp CreateOp(BroadcastComplexOp from_op, Type result_type,
|
||||
Value broadcasted_lhs, Value broadcasted_rhs,
|
||||
OpBuilder &builder) {
|
||||
return builder.create<mhlo::ComplexOp>(from_op.getLoc(), result_type,
|
||||
broadcasted_lhs, broadcasted_rhs);
|
||||
}
|
||||
};
|
||||
|
||||
struct HloCompareAdaptor {
|
||||
static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type,
|
||||
Value broadcasted_lhs, Value broadcasted_rhs,
|
||||
OpBuilder &builder) {
|
||||
return builder.create<mhlo::CompareOp>(
|
||||
from_op.getLoc(), result_type, broadcasted_lhs, broadcasted_rhs,
|
||||
from_op.comparison_direction(), from_op.compare_typeAttr());
|
||||
}
|
||||
};
|
||||
|
||||
#include "generated_chlo_legalize_to_hlo.inc"
|
||||
} // namespace
|
||||
|
||||
@@ -521,32 +212,10 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
|
||||
// Instantiate conversion templates for conforming binary elementwise ops
|
||||
// that do not have different dtypes between operands and results and do
|
||||
// not have special attributes that need to be preserved.
|
||||
#define POPULATE_BCAST(ChloOp, HloOp) \
|
||||
PopulateForBinaryOp<ChloOp, HloOp, \
|
||||
HloBinaryElementwiseAdaptor<ChloOp, HloOp>>(context, \
|
||||
patterns);
|
||||
|
||||
POPULATE_BCAST(BroadcastAddOp, mhlo::AddOp);
|
||||
POPULATE_BCAST(BroadcastAndOp, mhlo::AndOp);
|
||||
POPULATE_BCAST(BroadcastAtan2Op, mhlo::Atan2Op);
|
||||
POPULATE_BCAST(BroadcastDivOp, mhlo::DivOp);
|
||||
POPULATE_BCAST(BroadcastMaxOp, mhlo::MaxOp);
|
||||
POPULATE_BCAST(BroadcastMinOp, mhlo::MinOp);
|
||||
POPULATE_BCAST(BroadcastMulOp, mhlo::MulOp);
|
||||
POPULATE_BCAST(BroadcastOrOp, mhlo::OrOp);
|
||||
POPULATE_BCAST(BroadcastPowOp, mhlo::PowOp);
|
||||
POPULATE_BCAST(BroadcastRemOp, mhlo::RemOp);
|
||||
POPULATE_BCAST(BroadcastShiftLeftOp, mhlo::ShiftLeftOp);
|
||||
POPULATE_BCAST(BroadcastShiftRightArithmeticOp, mhlo::ShiftRightArithmeticOp);
|
||||
POPULATE_BCAST(BroadcastShiftRightLogicalOp, mhlo::ShiftRightLogicalOp);
|
||||
POPULATE_BCAST(BroadcastSubOp, mhlo::SubOp);
|
||||
POPULATE_BCAST(BroadcastXorOp, mhlo::XorOp);
|
||||
|
||||
// Broadcasting ops requiring special construction.
|
||||
PopulateForBinaryOp<BroadcastComplexOp, mhlo::ComplexOp, HloComplexAdaptor>(
|
||||
context, patterns);
|
||||
PopulateForBinaryOp<BroadcastCompareOp, mhlo::CompareOp, HloCompareAdaptor>(
|
||||
context, patterns);
|
||||
PopulateForBroadcastingBinaryOp<ConvertTrivialNonBroadcastBinaryOp>(
|
||||
context, patterns, 10);
|
||||
PopulateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>(
|
||||
context, patterns, 5);
|
||||
|
||||
// Other patterns.
|
||||
patterns->insert<ConvertConstantLikeOp>(context);
|
||||
|
@@ -16,7 +16,9 @@ limitations under the License.
|
||||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
@@ -126,6 +128,291 @@ struct ElementwiseOpConversion : public OpRewritePattern<OpTy> {
|
||||
}
|
||||
};
|
||||
|
||||
// Converts a broadcasting binary operation with a scalar operand and an
|
||||
// unranked operand to a ranked broadcasting operation by dynamically reshaping
|
||||
// the unranked operand to a 1D tensor. This will always be safe because
|
||||
// broadcasting from a scalar to another shape always works.
|
||||
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
|
||||
struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
|
||||
: public OpConversionPattern<ChloOpTy> {
|
||||
using OpConversionPattern<ChloOpTy>::OpConversionPattern;
|
||||
LogicalResult matchAndRewrite(
|
||||
ChloOpTy op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
typename ChloOpTy::Adaptor transformed(operands);
|
||||
Value lhs = transformed.lhs();
|
||||
Value rhs = transformed.rhs();
|
||||
|
||||
auto lhs_ranked_type = lhs.getType().dyn_cast<RankedTensorType>();
|
||||
auto lhs_unranked_type = lhs.getType().dyn_cast<UnrankedTensorType>();
|
||||
|
||||
auto rhs_ranked_type = rhs.getType().dyn_cast<RankedTensorType>();
|
||||
auto rhs_unranked_type = rhs.getType().dyn_cast<UnrankedTensorType>();
|
||||
|
||||
bool lhs_is_scalar = lhs_ranked_type &&
|
||||
lhs_ranked_type.getShape().empty() &&
|
||||
rhs_unranked_type;
|
||||
bool rhs_is_scalar = rhs_ranked_type &&
|
||||
rhs_ranked_type.getShape().empty() &&
|
||||
lhs_unranked_type;
|
||||
|
||||
// Only support the case where exactly one operand is scalar and the other
|
||||
// is unranked. Other patterns in chlo-to-hlo legalization will create more
|
||||
// efficient lowerings for cases where both ranks are known or will handle
|
||||
// the more generic case of both inputs being unranked.
|
||||
if (!(lhs_is_scalar ^ rhs_is_scalar)) return failure();
|
||||
|
||||
auto result_type = op.getResult().getType().template dyn_cast<TensorType>();
|
||||
|
||||
// Reshape the non-scalar value into a dynamically sized, rank-1 tensor
|
||||
Value shape =
|
||||
rewriter.create<shape::ShapeOfOp>(loc, lhs_is_scalar ? rhs : lhs);
|
||||
Value num_elements = rewriter.create<shape::NumElementsOp>(loc, shape);
|
||||
Value size_tensor =
|
||||
rewriter.create<TensorFromElementsOp>(loc, num_elements);
|
||||
Value reshaped = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
loc, RankedTensorType::get({-1}, result_type.getElementType()),
|
||||
lhs_is_scalar ? rhs : lhs, size_tensor);
|
||||
|
||||
// Create a new ranked Chlo op that will be further lowered by other
|
||||
// patterns into Mhlo.
|
||||
SmallVector<Value, 2> new_operands{lhs_is_scalar ? lhs : reshaped,
|
||||
rhs_is_scalar ? rhs : reshaped};
|
||||
Value computed =
|
||||
rewriter.create<ChloOpTy>(loc, SmallVector<Type, 1>{reshaped.getType()},
|
||||
new_operands, op.getAttrs());
|
||||
|
||||
// Reshape the result back into an unranked tensor.
|
||||
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_type,
|
||||
computed, shape);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Handles lowering of the following pattern to patterns that will be further
|
||||
// matched by other patterns until they result in LHLO:
|
||||
// %result = "chlo.op"(%lhs, %rhs) : (<*xTy>, <*xTy>) -> <*xTy>
|
||||
//
|
||||
// The sequence of specializations this handles is:
|
||||
// - Either operand being scalar
|
||||
// - Operands having equal shapes
|
||||
// - The resulting value being any of ranks [2,6]
|
||||
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
|
||||
struct ConvertUnrankedDynamicBroadcastBinaryOp
|
||||
: public OpConversionPattern<ChloOpTy> {
|
||||
using OpConversionPattern<ChloOpTy>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
ChloOpTy op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
typename ChloOpTy::Adaptor transformed(operands);
|
||||
Value lhs = transformed.lhs();
|
||||
Value rhs = transformed.rhs();
|
||||
auto lhs_type = lhs.getType().dyn_cast<UnrankedTensorType>();
|
||||
auto rhs_type = rhs.getType().dyn_cast<UnrankedTensorType>();
|
||||
auto result_type = op.getResult().getType().template dyn_cast<TensorType>();
|
||||
|
||||
// Only support unranked operands. If either operand is ranked, another
|
||||
// pattern will handle the lowering.
|
||||
if (!lhs_type || !rhs_type) return failure();
|
||||
|
||||
// If lhs is scalar
|
||||
auto if_op = rewriter.create<scf::IfOp>(
|
||||
loc, result_type, IsScalarTensor(rewriter, op, lhs), true);
|
||||
OpBuilder if_lhs_scalar_builder =
|
||||
if_op.getThenBodyBuilder(rewriter.getListener());
|
||||
Value reshaped_lhs = if_lhs_scalar_builder.create<TensorCastOp>(
|
||||
loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs);
|
||||
Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>(
|
||||
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs},
|
||||
op.getAttrs());
|
||||
if_lhs_scalar_builder.create<scf::YieldOp>(loc, if_lhs_scalar_result);
|
||||
|
||||
// If lhs is NOT scalar
|
||||
//
|
||||
// See if rhs is scalar
|
||||
OpBuilder else_lhs_scalar_builder =
|
||||
if_op.getElseBodyBuilder(rewriter.getListener());
|
||||
auto if_rhs_scalar_op = else_lhs_scalar_builder.create<scf::IfOp>(
|
||||
loc, result_type, IsScalarTensor(else_lhs_scalar_builder, op, rhs),
|
||||
true);
|
||||
else_lhs_scalar_builder.create<scf::YieldOp>(loc,
|
||||
if_rhs_scalar_op.getResult(0));
|
||||
OpBuilder if_rhs_scalar_builder =
|
||||
if_rhs_scalar_op.getThenBodyBuilder(rewriter.getListener());
|
||||
Value reshaped_rhs = if_rhs_scalar_builder.create<TensorCastOp>(
|
||||
loc, RankedTensorType::get({}, lhs_type.getElementType()), rhs);
|
||||
Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>(
|
||||
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs},
|
||||
op.getAttrs());
|
||||
if_rhs_scalar_builder.create<scf::YieldOp>(loc, if_rhs_scalar_result);
|
||||
|
||||
// If NEITHER shape is scalar
|
||||
//
|
||||
// See if shapes are equal.
|
||||
OpBuilder else_no_scalars_builder =
|
||||
if_rhs_scalar_op.getElseBodyBuilder(rewriter.getListener());
|
||||
Value shape_of_lhs =
|
||||
else_no_scalars_builder.create<shape::ShapeOfOp>(loc, lhs);
|
||||
Value shape_of_rhs =
|
||||
else_no_scalars_builder.create<shape::ShapeOfOp>(loc, rhs);
|
||||
Value equal_shapes = else_no_scalars_builder.create<shape::ShapeEqOp>(
|
||||
loc, shape_of_lhs, shape_of_rhs);
|
||||
|
||||
auto if_eq_shapes_op = else_no_scalars_builder.create<scf::IfOp>(
|
||||
loc, result_type, equal_shapes, true);
|
||||
else_no_scalars_builder.create<scf::YieldOp>(loc,
|
||||
if_eq_shapes_op.getResult(0));
|
||||
|
||||
OpBuilder if_eq_shapes_builder =
|
||||
if_eq_shapes_op.getThenBodyBuilder(rewriter.getListener());
|
||||
Value non_broadcast_op =
|
||||
Adaptor::CreateOp(op, result_type, lhs, rhs, if_eq_shapes_builder);
|
||||
if_eq_shapes_builder.create<scf::YieldOp>(loc, non_broadcast_op);
|
||||
|
||||
// If shapes are not scalar, nor equal
|
||||
//
|
||||
// See if values are of a rank that we support.
|
||||
OpBuilder if_neq_shapes_builder =
|
||||
if_eq_shapes_op.getElseBodyBuilder(rewriter.getListener());
|
||||
if_neq_shapes_builder.create<scf::YieldOp>(
|
||||
loc, HandleBroadcastAndOp(if_neq_shapes_builder, op, lhs, rhs));
|
||||
|
||||
rewriter.replaceOp(op, {if_op.getResult(0)});
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
// Returns the dynamic result of checking whether the given value is a scalar
|
||||
// tensor.
|
||||
Value IsScalarTensor(OpBuilder &rewriter, ChloOpTy op, Value tensor) const {
|
||||
auto loc = op.getLoc();
|
||||
|
||||
Value shape_of_tensor = rewriter.create<shape::ShapeOfOp>(loc, tensor);
|
||||
Value rank_tensor = rewriter.create<shape::RankOp>(
|
||||
loc, rewriter.getIndexType(), shape_of_tensor);
|
||||
return rewriter.create<CmpIOp>(loc, rewriter.getI1Type(), CmpIPredicate::eq,
|
||||
rank_tensor,
|
||||
rewriter.create<ConstantIndexOp>(loc, 0));
|
||||
}
|
||||
|
||||
// Create the if statement and code for a broadcasting op with a result of a
|
||||
// given rank.
|
||||
scf::IfOp createRankSpecializedBroadcastAndOp(OpBuilder &builder, ChloOpTy op,
|
||||
Value lhs, Value rhs,
|
||||
Value actual_rank,
|
||||
int targeted_rank) const {
|
||||
auto loc = op.getLoc();
|
||||
|
||||
// Create the if block to place the current specialized logic in.
|
||||
Value greater_rank_is_n = builder.create<CmpIOp>(
|
||||
loc, CmpIPredicate::eq, actual_rank,
|
||||
builder.create<ConstantIndexOp>(loc, targeted_rank));
|
||||
auto if_op =
|
||||
builder.create<scf::IfOp>(loc, lhs.getType(), greater_rank_is_n, true);
|
||||
OpBuilder if_builder = if_op.getThenBodyBuilder(builder.getListener());
|
||||
|
||||
// Handle shape broadcasting and inference.
|
||||
Value lhs_shape = if_builder.create<shape::ShapeOfOp>(loc, lhs);
|
||||
Value rhs_shape = if_builder.create<shape::ShapeOfOp>(loc, rhs);
|
||||
SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
|
||||
auto unknown_rank_extent_tensor_type = RankedTensorType::get(
|
||||
{RankedTensorType::kDynamicSize}, builder.getIndexType());
|
||||
auto known_rank_extent_tensor_type =
|
||||
RankedTensorType::get({targeted_rank}, builder.getIndexType());
|
||||
auto reshaped_type = RankedTensorType::get(
|
||||
llvm::SmallVector<int64_t, 6>(targeted_rank,
|
||||
RankedTensorType::kDynamicSize),
|
||||
lhs.getType().template dyn_cast<TensorType>().getElementType());
|
||||
Value ranked_shape_val = if_builder.create<shape::ConstShapeOp>(
|
||||
loc, known_rank_extent_tensor_type,
|
||||
mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type,
|
||||
ranked_shape));
|
||||
Value extended_lhs = if_builder.create<shape::BroadcastOp>(
|
||||
loc, unknown_rank_extent_tensor_type, lhs_shape, ranked_shape_val,
|
||||
nullptr);
|
||||
Value extended_lhs_casted = if_builder.create<TensorCastOp>(
|
||||
loc, known_rank_extent_tensor_type, extended_lhs);
|
||||
Value extended_rhs = if_builder.create<shape::BroadcastOp>(
|
||||
loc, unknown_rank_extent_tensor_type, rhs_shape, ranked_shape_val,
|
||||
nullptr);
|
||||
Value extended_rhs_casted = if_builder.create<TensorCastOp>(
|
||||
loc, known_rank_extent_tensor_type, extended_rhs);
|
||||
|
||||
// 1. Reshape operands to the given rank (with the same number of elements)
|
||||
// 2. Compute the ranked-broadcasted ChloOp (which will assert that the ops
|
||||
// can be broadcasted and do the actual broadcasting)
|
||||
// 3. Type erase the output back to unranked
|
||||
Value reshaped_lhs = if_builder.create<mhlo::DynamicReshapeOp>(
|
||||
loc, reshaped_type, lhs, extended_lhs_casted);
|
||||
Value reshaped_rhs = if_builder.create<mhlo::DynamicReshapeOp>(
|
||||
loc, reshaped_type, rhs, extended_rhs_casted);
|
||||
Value result = if_builder.create<ChloOpTy>(
|
||||
loc, ArrayRef<Type>{reshaped_type},
|
||||
ArrayRef<Value>{reshaped_lhs, reshaped_rhs}, op.getAttrs());
|
||||
Value reshaped_result = if_builder.create<TensorCastOp>(
|
||||
loc, UnrankedTensorType::get(reshaped_type.getElementType()), result);
|
||||
if_builder.create<scf::YieldOp>(loc, reshaped_result);
|
||||
|
||||
// Return the if_op, so the result can be used and the else block can be
|
||||
// used for the next rank specialized step.
|
||||
return if_op;
|
||||
}
|
||||
|
||||
// Iterates over the desired ranks to be specialized and generates the code
|
||||
// snippet for each case.
|
||||
Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op, Value lhs,
|
||||
Value rhs) const {
|
||||
constexpr int max_rank_specialization = 7;
|
||||
auto loc = op.getLoc();
|
||||
|
||||
// Find the larger rank of the 2 operands.
|
||||
auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
|
||||
rewriter.getIndexType());
|
||||
Value lhs_shape =
|
||||
rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, lhs);
|
||||
Value rhs_shape =
|
||||
rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, rhs);
|
||||
Value lhs_rank =
|
||||
rewriter.create<RankOp>(loc, rewriter.getIndexType(), lhs_shape);
|
||||
Value rhs_rank =
|
||||
rewriter.create<RankOp>(loc, rewriter.getIndexType(), rhs_shape);
|
||||
Value greater_rank_lhs =
|
||||
rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs_rank, rhs_rank);
|
||||
Value greater_rank =
|
||||
rewriter.create<SelectOp>(loc, greater_rank_lhs, lhs_rank, rhs_rank);
|
||||
|
||||
// Generate a list of nested if/else statements to handle rank
|
||||
// specializations from 2-6.
|
||||
scf::IfOp if_op = createRankSpecializedBroadcastAndOp(rewriter, op, lhs,
|
||||
rhs, greater_rank, 2);
|
||||
|
||||
// Put each subsequent rank specialization inside the else statement of the
|
||||
// previous one.
|
||||
OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
|
||||
for (int i = 3; i < max_rank_specialization; i++) {
|
||||
auto inner_if = createRankSpecializedBroadcastAndOp(else_builder, op, lhs,
|
||||
rhs, greater_rank, i);
|
||||
|
||||
else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
|
||||
else_builder = inner_if.getElseBodyBuilder(rewriter.getListener());
|
||||
}
|
||||
|
||||
// Fire an assertion if none of the rank specializations applied (one of the
|
||||
// ranks was greater than 6).
|
||||
else_builder.create<AssertOp>(
|
||||
loc, else_builder.create<ConstantIntOp>(loc, 0, 1),
|
||||
"Input for dynamic binary op lowering was of a rank greater than 6");
|
||||
else_builder.create<scf::YieldOp>(loc, lhs);
|
||||
|
||||
// Return the result of the outermost if statement.
|
||||
return if_op.getResult(0);
|
||||
}
|
||||
};
|
||||
|
||||
struct TransformUnrankedHloPass
|
||||
: public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
@@ -137,7 +424,7 @@ struct TransformUnrankedHloPass
|
||||
MLIRContext &ctx = getContext();
|
||||
ConversionTarget target(ctx);
|
||||
target.addLegalDialect<mhlo::MhloDialect, StandardOpsDialect,
|
||||
shape::ShapeDialect>();
|
||||
shape::ShapeDialect, scf::SCFDialect>();
|
||||
target.addLegalOp<FuncOp>();
|
||||
#define ADD_LEGAL_MHLO(op) AddLegalOpOnRankedTensor<mhlo::op>(&target)
|
||||
#define ADD_LEGAL_CHLO(op) AddLegalOpOnRankedTensor<chlo::op>(&target)
|
||||
@@ -148,6 +435,12 @@ struct TransformUnrankedHloPass
|
||||
#undef ADD_LEGAL_CHLO
|
||||
AddLegalOpOnRankedTensor<mhlo::CompareOp>(&target);
|
||||
AddLegalOpOnRankedTensor<mhlo::SelectOp>(&target);
|
||||
target.addDynamicallyLegalDialect<chlo::HloClientDialect>(
|
||||
[](Operation *op) {
|
||||
return !llvm::any_of(op->getOperandTypes(), [](Type type) {
|
||||
return type.isa<UnrankedTensorType>();
|
||||
});
|
||||
});
|
||||
|
||||
// Populate rewrite patterns.
|
||||
OwningRewritePatternList patterns;
|
||||
@@ -180,6 +473,10 @@ void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
|
||||
#undef MAP_BINARY
|
||||
#undef MAP_CHLO_UNARY
|
||||
#undef COMMA
|
||||
chlo::PopulateForBroadcastingBinaryOp<
|
||||
ConvertUnrankedDynamicBroadcastBinaryOp>(context, patterns);
|
||||
chlo::PopulateForBroadcastingBinaryOp<
|
||||
ConvertUnrankedScalarDynamicBroadcastBinaryOp>(context, patterns);
|
||||
}
|
||||
|
||||
std::unique_ptr<FunctionPass> createTransformUnrankedHloPass() {
|
||||
|
@@ -237,209 +237,3 @@ func @xorWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4x
|
||||
%0 = chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
|
||||
return %0 : tensor<4xi1>
|
||||
}
|
||||
|
||||
// -----
|
||||
func @addScalarUnranked(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<f32>, tensor<*xf32>)
|
||||
-> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @addScalarUnranked(
|
||||
// CHECK-SAME: %[[ARG_0:.*]]: tensor<f32>,
|
||||
// CHECK-SAME: %[[ARG_1:.*]]: tensor<*xf32>
|
||||
// CHECK-SAME: ) -> tensor<*xf32> {
|
||||
// First handle the dynamic reshaping of the unranked operand
|
||||
// to a 1D tensor.
|
||||
// CHECK: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<*xf32>
|
||||
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_1]] : tensor<?xindex> -> index
|
||||
// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
|
||||
// CHECK: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_1]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// The assuming region is part of the second stage of lowering
|
||||
// with ranked broadcasting logic.
|
||||
// CHECK: %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<f32>
|
||||
// CHECK: %[[SHAPE_RESHAPED:.*]] = shape.shape_of %[[RESHAPED]] : tensor<?xf32>
|
||||
// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE_0]], %[[SHAPE_RESHAPED]]
|
||||
// CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?xf32>) {
|
||||
// CHECK: %[[SCALAR_SHAPE:.*]] = shape.const_shape []
|
||||
// CHECK: %[[BROADCASTED_SHAPE:.*]] = shape.broadcast %[[SCALAR_SHAPE]], %[[SHAPE_RESHAPED]]
|
||||
// CHECK: %[[SHAPE_TENSOR:.*]] = tensor_cast %[[BROADCASTED_SHAPE]] : tensor<?xindex> to tensor<1xindex>
|
||||
// CHECK: %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_0]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK: %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK: %[[BROADCASTED_RESULT:.*]] = mhlo.add %[[BROADCASTED_LHS]], %[[BROADCASTED_RHS]] : tensor<?xf32>
|
||||
// CHECK: shape.assuming_yield %[[BROADCASTED_RESULT]] : tensor<?xf32>
|
||||
// CHECK: }
|
||||
// As part of the unranked logic, the result is reshaped back
|
||||
// to an unranked tensor.
|
||||
// CHECK: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[ASSUMING_RESULT:.*]], %[[SHAPE_1]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK: return %[[RESHAPED_RESULT]] : tensor<*xf32>
|
||||
// CHECK: }
|
||||
|
||||
// -----
|
||||
func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf32> {
|
||||
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<f32>)
|
||||
-> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
// CHECK-LABEL: func @addUnrankedScalar(
|
||||
// CHECK-SAME: %[[ARG_0:.*]]: tensor<*xf32>,
|
||||
// CHECK-SAME: %[[ARG_1:.*]]: tensor<f32>) -> tensor<*xf32> {
|
||||
// First handle the dynamic reshaping of the unranked operand
|
||||
// to a 1D tensor.
|
||||
// CHECK: %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32>
|
||||
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_0]] : tensor<?xindex> -> index
|
||||
// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
|
||||
// CHECK: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_0]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// The assuming region is part of the second stage of lowering
|
||||
// with ranked broadcasting logic.
|
||||
// CHECK: %[[SHAPE_RESHAPED:.*]] = shape.shape_of %[[RESHAPED]] : tensor<?xf32>
|
||||
// CHECK: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<f32>
|
||||
// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE_RESHAPED]], %[[SHAPE_1]]
|
||||
// CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?xf32>) {
|
||||
// CHECK: %[[ASTENSOR:.*]] = tensor_cast %[[SHAPE_RESHAPED]]
|
||||
// CHECK: %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[ASTENSOR]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK: %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_1]], %[[ASTENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK: %[[BROADCASTED_RESULT:.*]] = mhlo.add %[[BROADCASTED_LHS]], %[[BROADCASTED_RHS]] : tensor<?xf32>
|
||||
// CHECK: shape.assuming_yield %[[BROADCASTED_RESULT]] : tensor<?xf32>
|
||||
// CHECK: }
|
||||
// As part of the unranked logic, the result is reshaped back
|
||||
// to an unranked tensor.
|
||||
// CHECK: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[ASSUMING_RESULT:.*]], %[[SHAPE_0]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK: return %[[RESHAPED_RESULT]] : tensor<*xf32>
|
||||
// CHECK: }
|
||||
|
||||
// -----
|
||||
func @addUnrankedUnranked(
|
||||
%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<*xf32>)
|
||||
-> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @addUnrankedUnranked(
|
||||
// CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>,
|
||||
// CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
|
||||
// CHECK: %[[RANK_LHS:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
|
||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[LHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_LHS]], %[[C0]] : index
|
||||
// Handle scalar LHS case
|
||||
// CHECK: %[[VAL_8:.*]] = scf.if %[[LHS_IS_SCALAR]] -> (tensor<*xf32>) {
|
||||
// CHECK: %[[SCALAR_LHS:.*]] = tensor_cast %[[LHS]] : tensor<*xf32> to tensor<f32>
|
||||
// CHECK: %[[VAL_10:.*]] = chlo.broadcast_add %[[SCALAR_LHS]], %[[RHS]] : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK: scf.yield %[[VAL_10]] : tensor<*xf32>
|
||||
// CHECK: } else {
|
||||
// CHECK: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
|
||||
// CHECK: %[[RANK_RHS:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
|
||||
// CHECK: %[[RHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_RHS]], %[[C0]] : index
|
||||
// Handle scalar RHS case
|
||||
// CHECK: %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) {
|
||||
// CHECK: %[[SCALAR_RHS:.*]] = tensor_cast %[[RHS]] : tensor<*xf32> to tensor<f32>
|
||||
// CHECK: %[[VAL_16:.*]] = chlo.broadcast_add %[[LHS]], %[[SCALAR_RHS]] : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
|
||||
// CHECK: scf.yield %[[VAL_16]] : tensor<*xf32>
|
||||
// CHECK: } else {
|
||||
// CHECK: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>
|
||||
// Handle scalar RHS case
|
||||
// CHECK: %[[VAL_18:.*]] = scf.if %[[SHAPES_EQ]] -> (tensor<*xf32>) {
|
||||
// CHECK: %[[VAL_19:.*]] = mhlo.add %[[LHS]], %[[RHS]] : tensor<*xf32>
|
||||
// CHECK: scf.yield %[[VAL_19]] : tensor<*xf32>
|
||||
// CHECK: } else {
|
||||
// CHECK: %[[LHS_RANK:.*]] = rank %[[LHS_SHAPE]] : tensor<?xindex>
|
||||
// CHECK: %[[RHS_RANK:.*]] = rank %[[RHS_SHAPE]] : tensor<?xindex>
|
||||
// CHECK: %[[LHS_RANK_GREATER:.*]] = cmpi "sgt", %[[LHS_RANK]], %[[RHS_RANK]] : index
|
||||
// CHECK: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
|
||||
// CHECK: %[[C2:.*]] = constant 2 : index
|
||||
// CHECK: %[[GREATEST_RANK_IS_2:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C2]] : index
|
||||
// Handle rank 2 specialization
|
||||
// CHECK: %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) {
|
||||
// CHECK: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1]
|
||||
// CHECK: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
|
||||
// CHECK: %[[CASTED_LHS_2:.*]] = tensor_cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex>
|
||||
// CHECK: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
|
||||
// CHECK: %[[CASTED_RHS_2:.*]] = tensor_cast %[[BROADCASTED_RHS_2]] : tensor<?xindex> to tensor<2xindex>
|
||||
// CHECK: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[RESULT_2:.*]] = tensor_cast %[[RESULT_RANK_2]] : tensor<?x?xf32> to tensor<*xf32>
|
||||
// CHECK: scf.yield %[[RESULT_2]] : tensor<*xf32>
|
||||
// CHECK: } else {
|
||||
// CHECK: %[[C3:.*]] = constant 3 : index
|
||||
// CHECK: %[[GREATEST_RANK_IS_3:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C3]] : index
|
||||
// Handle rank 3 specialization
|
||||
// CHECK: %[[VAL_34:.*]] = scf.if %[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) {
|
||||
// CHECK: %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
|
||||
// CHECK: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
|
||||
// CHECK: %[[CASTED_LHS_3:.*]] = tensor_cast %[[BROADCASTED_LHS_3]] : tensor<?xindex> to tensor<3xindex>
|
||||
// CHECK: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
|
||||
// CHECK: %[[CASTED_RHS_3:.*]] = tensor_cast %[[BROADCASTED_RHS_3]] : tensor<?xindex> to tensor<3xindex>
|
||||
// CHECK: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
|
||||
// CHECK: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
|
||||
// CHECK: %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
// CHECK: %[[RESULT_3:.*]] = tensor_cast %[[RESULT_RANK_3]] : tensor<?x?x?xf32> to tensor<*xf32>
|
||||
// CHECK: scf.yield %[[RESULT_3]] : tensor<*xf32>
|
||||
// CHECK: } else {
|
||||
// CHECK: %[[C4:.*]] = constant 4 : index
|
||||
// CHECK: %[[GREATEST_RANK_IS_4:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C4]] : index
|
||||
// Handle rank 4 specialization
|
||||
// CHECK: %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) {
|
||||
// CHECK: %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
|
||||
// CHECK: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
|
||||
// CHECK: %[[CASTED_LHS_4:.*]] = tensor_cast %[[BROADCASTED_LHS_4]] : tensor<?xindex> to tensor<4xindex>
|
||||
// CHECK: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
|
||||
// CHECK: %[[CASTED_RHS_4:.*]] = tensor_cast %[[BROADCASTED_RHS_4]] : tensor<?xindex> to tensor<4xindex>
|
||||
// CHECK: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[RESULT_4:.*]] = tensor_cast %[[RESULT_RANK_4]] : tensor<?x?x?x?xf32> to tensor<*xf32>
|
||||
// CHECK: scf.yield %[[RESULT_4]] : tensor<*xf32>
|
||||
// CHECK: } else {
|
||||
// CHECK: %[[C5:.*]] = constant 5 : index
|
||||
// CHECK: %[[GREATEST_RANK_IS_5:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C5]] : index
|
||||
// Handle rank 5 specialization
|
||||
// CHECK: %[[VAL_50:.*]] = scf.if %[[GREATEST_RANK_IS_5]] -> (tensor<*xf32>) {
|
||||
// CHECK: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
|
||||
// CHECK: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
|
||||
// CHECK: %[[CASTED_LHS_5:.*]] = tensor_cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
|
||||
// CHECK: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
|
||||
// CHECK: %[[CASTED_RHS_5:.*]] = tensor_cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
|
||||
// CHECK: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
|
||||
// CHECK: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
|
||||
// CHECK: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
|
||||
// CHECK: %[[RESULT_5:.*]] = tensor_cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
|
||||
// CHECK: scf.yield %[[RESULT_5]] : tensor<*xf32>
|
||||
// CHECK: } else {
|
||||
// CHECK: %[[C6:.*]] = constant 6 : index
|
||||
// CHECK: %[[GREATEST_RANK_IS_6:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C6]] : index
|
||||
// Handle rank 6 specialization
|
||||
// CHECK: %[[VAL_58:.*]] = scf.if %[[GREATEST_RANK_IS_6]] -> (tensor<*xf32>) {
|
||||
// CHECK: %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1]
|
||||
// CHECK: %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
|
||||
// CHECK: %[[CASTED_LHS_6:.*]] = tensor_cast %[[BROADCASTED_LHS_6]] : tensor<?xindex> to tensor<6xindex>
|
||||
// CHECK: %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
|
||||
// CHECK: %[[CASTED_RHS_6:.*]] = tensor_cast %[[BROADCASTED_RHS_6]] : tensor<?xindex> to tensor<6xindex>
|
||||
// CHECK: %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
|
||||
// CHECK: %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
|
||||
// CHECK: %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
|
||||
// CHECK: %[[RESULT_6:.*]] = tensor_cast %[[RESULT_RANK_6]] : tensor<?x?x?x?x?x?xf32> to tensor<*xf32>
|
||||
// CHECK: scf.yield %[[RESULT_6]] : tensor<*xf32>
|
||||
// CHECK: } else {
|
||||
// CHECK: %false = constant false
|
||||
// CHECK: assert %false
|
||||
// CHECK: scf.yield %[[LHS]] : tensor<*xf32>
|
||||
// CHECK: }
|
||||
// CHECK: scf.yield %[[VAL_64:.*]] : tensor<*xf32>
|
||||
// CHECK: }
|
||||
// CHECK: scf.yield %[[VAL_65:.*]] : tensor<*xf32>
|
||||
// CHECK: }
|
||||
// CHECK: scf.yield %[[VAL_66:.*]] : tensor<*xf32>
|
||||
// CHECK: }
|
||||
// CHECK: scf.yield %[[VAL_67:.*]] : tensor<*xf32>
|
||||
// CHECK: }
|
||||
// CHECK: scf.yield %[[VAL_68:.*]] : tensor<*xf32>
|
||||
// CHECK: }
|
||||
// CHECK: scf.yield %[[VAL_69:.*]] : tensor<*xf32>
|
||||
// CHECK: }
|
||||
// CHECK: scf.yield %[[VAL_70:.*]] : tensor<*xf32>
|
||||
// CHECK: }
|
||||
// CHECK: return %[[VAL_71:.*]] : tensor<*xf32>
|
||||
// CHECK: }
|
||||
|
@ -1,4 +1,4 @@
|
||||
// RUN: mlir-hlo-opt --transform-unranked-hlo --split-input-file %s | FileCheck %s
|
||||
// RUN: mlir-hlo-opt --transform-unranked-hlo --cse --split-input-file %s | FileCheck %s
|
||||
|
||||
// Check the validity of expected IR.
|
||||
// CHECK-LABEL: @sqr_transform_result
|
||||
@ -96,3 +96,203 @@ func @tan(%a : tensor<*xf32>) -> tensor<*xf32> {
|
||||
%result = chlo.tan %a : tensor<*xf32>
|
||||
return %result : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @addScalarUnranked(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<f32>, tensor<*xf32>)
|
||||
-> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @addScalarUnranked(
|
||||
// CHECK-SAME: %[[ARG_0:.*]]: tensor<f32>,
|
||||
// CHECK-SAME: %[[ARG_1:.*]]: tensor<*xf32>
|
||||
// CHECK-SAME: ) -> tensor<*xf32> {
|
||||
// First handle the dynamic reshaping of the unranked operand
|
||||
// to a 1D tensor.
|
||||
// CHECK-NEXT: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<*xf32>
|
||||
// CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_1]] : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[SIZE_TENSOR:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_1]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RESULT:.*]] = chlo.broadcast_add %[[ARG_0]], %[[RESHAPED]] : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
// As part of the unranked logic, the result is reshaped back
|
||||
// to an unranked tensor.
|
||||
// CHECK-NEXT: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[BROADCASTED_RESULT:.*]], %[[SHAPE_1]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: return %[[RESHAPED_RESULT]] : tensor<*xf32>
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// -----
|
||||
func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf32> {
|
||||
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<f32>)
|
||||
-> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
// CHECK-LABEL: func @addUnrankedScalar(
|
||||
// CHECK-SAME: %[[ARG_0:.*]]: tensor<*xf32>,
|
||||
// CHECK-SAME: %[[ARG_1:.*]]: tensor<f32>) -> tensor<*xf32> {
|
||||
// First handle the dynamic reshaping of the unranked operand
|
||||
// to a 1D tensor.
|
||||
// CHECK-NEXT: %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32>
|
||||
// CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_0]] : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[SIZE_TENSOR:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_0]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// The assuming region is part of the second stage of lowering
|
||||
// with ranked broadcasting logic.
|
||||
// CHECK-NEXT: %[[BROADCASTED_RESULT:.*]] = chlo.broadcast_add %[[RESHAPED]], %[[ARG_1]] : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
|
||||
// As part of the unranked logic, the result is reshaped back
|
||||
// to an unranked tensor.
|
||||
// CHECK-NEXT: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[BROADCASTED_RESULT:.*]], %[[SHAPE_0]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: return %[[RESHAPED_RESULT]] : tensor<*xf32>
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// -----
|
||||
func @addUnrankedUnranked(
|
||||
%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<*xf32>)
|
||||
-> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @addUnrankedUnranked(
|
||||
// CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>,
|
||||
// CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK-NEXT: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[RANK_LHS:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK-NEXT: %[[LHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_LHS]], %[[C0]] : index
|
||||
// Handle scalar LHS case
|
||||
// CHECK-NEXT: %[[VAL_8:.*]] = scf.if %[[LHS_IS_SCALAR]] -> (tensor<*xf32>) {
|
||||
// CHECK-NEXT: %[[SCALAR_LHS:.*]] = tensor_cast %[[LHS]] : tensor<*xf32> to tensor<f32>
|
||||
// CHECK-NEXT: %[[RHS_SHAPE_1:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[NUM_RHS:.*]] = shape.num_elements %[[RHS_SHAPE_1]] : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[NUM_TENS_RHS:.*]] = tensor_from_elements %[[NUM_RHS]] : tensor<1xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[NUM_TENS_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[LHS_SCALAR_RESULT:.*]] = chlo.broadcast_add %[[SCALAR_LHS]], %[[RESHAPED_RHS]] : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS_SCALAR_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[LHS_SCALAR_RESULT]], %[[RHS_SHAPE_1]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESHAPED_LHS_SCALAR_RESULT]] : tensor<*xf32>
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[RANK_RHS:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[RHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_RHS]], %[[C0]] : index
|
||||
// Handle scalar RHS case
|
||||
// CHECK-NEXT: %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) {
|
||||
// CHECK-NEXT: %[[SCALAR_RHS:.*]] = tensor_cast %[[RHS]] : tensor<*xf32> to tensor<f32>
|
||||
// CHECK-NEXT: %[[NUM_LHS:.*]] = shape.num_elements %[[LHS_SHAPE]] : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[NUM_TENS_LHS:.*]] = tensor_from_elements %[[NUM_LHS]] : tensor<1xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[NUM_TENS_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[RHS_SCALAR_RESULT:.*]] = chlo.broadcast_add %[[RESHAPED_LHS]], %[[SCALAR_RHS]] : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS_SCALAR_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RHS_SCALAR_RESULT:.*]], %[[LHS_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESHAPED_RHS_SCALAR_RESULT]] : tensor<*xf32>
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>
|
||||
// Handle equal shapes case
|
||||
// CHECK-NEXT: %[[VAL_18:.*]] = scf.if %[[SHAPES_EQ]] -> (tensor<*xf32>) {
|
||||
// CHECK-NEXT: %[[ANY_SHAPE:.*]] = shape.any %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[ANY_NUM:.*]] = shape.num_elements %[[ANY_SHAPE]] : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[ANY_TENSOR:.*]] = tensor_from_elements %[[ANY_NUM]] : tensor<1xindex>
|
||||
// CHECK-NEXT: %[[FLATTENED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[FLATTENED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[FLATTENED_RESULT:.*]] = mhlo.add %[[FLATTENED_LHS]], %[[FLATTENED_RHS]] : tensor<?xf32>
|
||||
// CHECK-NEXT: %[[RESHAPED_SAME_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLATTENED_RESULT]], %[[ANY_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESHAPED_SAME_RESULT]] : tensor<*xf32>
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: %[[LHS_RANK:.*]] = rank %[[LHS_SHAPE]] : tensor<?xindex>
|
||||
// CHECK-NEXT: %[[RHS_RANK:.*]] = rank %[[RHS_SHAPE]] : tensor<?xindex>
|
||||
// CHECK-NEXT: %[[LHS_RANK_GREATER:.*]] = cmpi "sgt", %[[LHS_RANK]], %[[RHS_RANK]] : index
|
||||
// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
|
||||
// CHECK-NEXT: %[[C2:.*]] = constant 2 : index
|
||||
// CHECK-NEXT: %[[GREATEST_RANK_IS_2:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C2]] : index
|
||||
// Handle rank 2 specialization
|
||||
// CHECK-NEXT: %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) {
|
||||
// CHECK-NEXT: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1]
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_LHS_2:.*]] = tensor_cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_RHS_2:.*]] = tensor_cast %[[BROADCASTED_RHS_2]] : tensor<?xindex> to tensor<2xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_2:.*]] = tensor_cast %[[RESULT_RANK_2]] : tensor<?x?xf32> to tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESULT_2]] : tensor<*xf32>
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: %[[C3:.*]] = constant 3 : index
|
||||
// CHECK-NEXT: %[[GREATEST_RANK_IS_3:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C3]] : index
|
||||
// Handle rank 3 specialization
|
||||
// CHECK-NEXT: %[[VAL_34:.*]] = scf.if %[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) {
|
||||
// CHECK-NEXT: %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_LHS_3:.*]] = tensor_cast %[[BROADCASTED_LHS_3]] : tensor<?xindex> to tensor<3xindex>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_RHS_3:.*]] = tensor_cast %[[BROADCASTED_RHS_3]] : tensor<?xindex> to tensor<3xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_3:.*]] = tensor_cast %[[RESULT_RANK_3]] : tensor<?x?x?xf32> to tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESULT_3]] : tensor<*xf32>
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: %[[C4:.*]] = constant 4 : index
|
||||
// CHECK-NEXT: %[[GREATEST_RANK_IS_4:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C4]] : index
|
||||
// Handle rank 4 specialization
|
||||
// CHECK-NEXT: %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) {
|
||||
// CHECK-NEXT: %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_LHS_4:.*]] = tensor_cast %[[BROADCASTED_LHS_4]] : tensor<?xindex> to tensor<4xindex>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_RHS_4:.*]] = tensor_cast %[[BROADCASTED_RHS_4]] : tensor<?xindex> to tensor<4xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_4:.*]] = tensor_cast %[[RESULT_RANK_4]] : tensor<?x?x?x?xf32> to tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESULT_4]] : tensor<*xf32>
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: %[[C5:.*]] = constant 5 : index
|
||||
// CHECK-NEXT: %[[GREATEST_RANK_IS_5:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C5]] : index
|
||||
// Handle rank 5 specialization
|
||||
// CHECK-NEXT: %[[VAL_50:.*]] = scf.if %[[GREATEST_RANK_IS_5]] -> (tensor<*xf32>) {
|
||||
// CHECK-NEXT: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_LHS_5:.*]] = tensor_cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_RHS_5:.*]] = tensor_cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_5:.*]] = tensor_cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESULT_5]] : tensor<*xf32>
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: %[[C6:.*]] = constant 6 : index
|
||||
// CHECK-NEXT: %[[GREATEST_RANK_IS_6:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C6]] : index
|
||||
// Handle rank 6 specialization
|
||||
// CHECK-NEXT: %[[VAL_58:.*]] = scf.if %[[GREATEST_RANK_IS_6]] -> (tensor<*xf32>) {
|
||||
// CHECK-NEXT: %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1]
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_LHS_6:.*]] = tensor_cast %[[BROADCASTED_LHS_6]] : tensor<?xindex> to tensor<6xindex>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_RHS_6:.*]] = tensor_cast %[[BROADCASTED_RHS_6]] : tensor<?xindex> to tensor<6xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_6:.*]] = tensor_cast %[[RESULT_RANK_6]] : tensor<?x?x?x?x?x?xf32> to tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESULT_6]] : tensor<*xf32>
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: %false = constant false
|
||||
// CHECK-NEXT: assert %false
|
||||
// CHECK-NEXT: scf.yield %[[LHS]] : tensor<*xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %[[VAL_64:.*]] : tensor<*xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %[[VAL_65:.*]] : tensor<*xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %[[VAL_66:.*]] : tensor<*xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %[[VAL_67:.*]] : tensor<*xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %[[VAL_68:.*]] : tensor<*xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %[[VAL_69:.*]] : tensor<*xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %[[VAL_70:.*]] : tensor<*xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %[[VAL_71:.*]] : tensor<*xf32>
|
||||
// CHECK-NEXT: }
|
||||
|
@ -306,6 +306,14 @@ inline bool IsF32ShapedType(Type t) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Returns true if it is a shaped type of bf16 elements.
|
||||
inline bool IsBF16ShapedType(Type t) {
|
||||
if (auto shaped_type = t.dyn_cast_or_null<ShapedType>()) {
|
||||
return shaped_type.getElementType().isBF16();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Performs const folding `calculate` with broadcast behavior on the two
|
||||
// attributes `operand1` and `operand2` and returns the result if possible.
|
||||
// The two operands are expected to both be scalar values.
|
||||
@ -498,7 +506,7 @@ Attribute ConstFoldBinaryOp(
|
||||
/// "tfl.logical_not".
|
||||
Attribute ConstFoldUnaryOp(Type result_type, Attribute operand,
|
||||
llvm::function_ref<APFloat(APFloat)> calculate) {
|
||||
assert(IsF32ShapedType(result_type));
|
||||
assert(IsF32ShapedType(result_type) || IsBF16ShapedType(result_type));
|
||||
auto result_shape_type = result_type.cast<ShapedType>();
|
||||
|
||||
if (auto dense_elements = operand.dyn_cast_or_null<DenseElementsAttr>()) {
|
||||
@ -1911,13 +1919,20 @@ OpFoldResult SqrtOp::fold(ArrayRef<Attribute> operands) {
|
||||
|
||||
OpFoldResult RsqrtOp::fold(ArrayRef<Attribute> operands) {
|
||||
Type result_type = getType();
|
||||
// Only constant fold for tensor of f32 is implemented.
|
||||
if (!IsF32ShapedType(result_type)) return nullptr;
|
||||
// Only constant fold for tensor of f32/bf16 is implemented.
|
||||
if (!IsF32ShapedType(result_type) && !IsBF16ShapedType(result_type))
|
||||
return nullptr;
|
||||
|
||||
auto compute = [](APFloat value) -> APFloat {
|
||||
bool loseInfo;
|
||||
const llvm::fltSemantics &original_float_semantics = value.getSemantics();
|
||||
value.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven,
|
||||
&loseInfo);
|
||||
float f = value.convertToFloat();
|
||||
float result = 1.f / std::sqrt(f);
|
||||
return APFloat(result);
|
||||
APFloat result(1.f / std::sqrt(f));
|
||||
result.convert(original_float_semantics, APFloat::rmNearestTiesToEven,
|
||||
&loseInfo);
|
||||
return result;
|
||||
};
|
||||
return ConstFoldUnaryOp(result_type, operands[0], compute);
|
||||
}
|
||||
|
@ -577,3 +577,13 @@ func @div_dense_different_rank() -> tensor<1x2x2xf32> {
|
||||
// CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}5.000000e-01, 0.333333343], [1.000000e+00, 0.666666686]]]> : tensor<1x2x2xf32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @rsqrt_bf16
|
||||
func @rsqrt_bf16() -> tensor<bf16> {
|
||||
%cst = constant dense<4.0> : tensor<bf16>
|
||||
%0 = "tfl.rsqrt"(%cst) : (tensor<bf16>) -> tensor<bf16>
|
||||
return %0 : tensor<bf16>
|
||||
|
||||
// CHECK: %[[CST:.*]] = constant dense<5.000000e-01> : tensor<bf16>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
@ -1358,3 +1358,27 @@ func @fuseScalarAddIntoConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3
|
||||
// CHECK: %cst = constant dense<[2.500000e+00, 3.500000e+00, 4.500000e+00, 5.500000e+00, 6.500000e+00, 7.500000e+00, 8.500000e+00, 9.500000e+00, 1.050000e+01, 1.150000e+01, 1.250000e+01, 1.350000e+01, 1.450000e+01, 1.550000e+01, 1.650000e+01, 1.750000e+01]> : tensor<16xf32>
|
||||
// CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %cst)
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fuseScalarAddIntoConv2dBf16
|
||||
func @fuseScalarAddIntoConv2dBf16(%arg0: tensor<256x32x32x3xbf16>, %arg1: tensor<16x3x3x3xbf16>) -> tensor<256x30x30x16xbf16> {
|
||||
%cst = constant dense<1.5> : tensor<bf16>
|
||||
%cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xbf16>
|
||||
%0 = "tfl.conv_2d"(%arg0, %arg1, %cst_0) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xbf16>, tensor<16x3x3x3xbf16>, tensor<16xbf16>) -> tensor<256x30x30x16xbf16>
|
||||
%1 = "tfl.add"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xbf16>, tensor<bf16>) -> tensor<256x30x30x16xbf16>
|
||||
return %1 : tensor<256x30x30x16xbf16>
|
||||
|
||||
// CHECK: %cst = constant dense<[2.500000e+00, 3.500000e+00, 4.500000e+00, 5.500000e+00, 6.500000e+00, 7.500000e+00, 8.500000e+00, 9.500000e+00, 1.050000e+01, 1.150000e+01, 1.250000e+01, 1.350000e+01, 1.450000e+01, 1.550000e+01, 1.650000e+01, 1.750000e+01]> : tensor<16xbf16>
|
||||
// CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %cst)
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fuseScalarAddIntoConv2dHalf
|
||||
func @fuseScalarAddIntoConv2dHalf(%arg0: tensor<256x32x32x3xf16>, %arg1: tensor<16x3x3x3xf16>) -> tensor<256x30x30x16xf16> {
|
||||
%cst = constant dense<1.5> : tensor<f16>
|
||||
%cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf16>
|
||||
%0 = "tfl.conv_2d"(%arg0, %arg1, %cst_0) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf16>, tensor<16x3x3x3xf16>, tensor<16xf16>) -> tensor<256x30x30x16xf16>
|
||||
%1 = "tfl.add"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xf16>, tensor<f16>) -> tensor<256x30x30x16xf16>
|
||||
return %1 : tensor<256x30x30x16xf16>
|
||||
|
||||
// CHECK: %cst = constant dense<[2.500000e+00, 3.500000e+00, 4.500000e+00, 5.500000e+00, 6.500000e+00, 7.500000e+00, 8.500000e+00, 9.500000e+00, 1.050000e+01, 1.150000e+01, 1.250000e+01, 1.350000e+01, 1.450000e+01, 1.550000e+01, 1.650000e+01, 1.750000e+01]> : tensor<16xf16>
|
||||
// CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %cst)
|
||||
}
|
||||
|
@ -211,20 +211,21 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||
pass_manager->addPass(mlir::createSymbolDCEPass());
|
||||
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
|
||||
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
|
||||
// This pass should be always at the end of the floating point model
|
||||
// conversion. Some TFL ops like unidirectional
|
||||
// sequence lstm will have stateful operands and some optimization passes
|
||||
// will merge those operands if they have identical values & types. However,
|
||||
// it's not desired by TFL. This pass serves as a "fix" pass to split the
|
||||
// merged inputs until we have 1st class variable support or reuse
|
||||
// tf.variable to model this.
|
||||
pass_manager->addPass(mlir::TFL::CreateSplitMergedOperandsPass());
|
||||
|
||||
// Run quantization after all the floating point model conversion is
|
||||
// completed.
|
||||
if (pass_config.quant_specs.RunPropagationAndRewriteQuantizationPasses()) {
|
||||
AddQuantizationPasses(pass_config.quant_specs, pass_manager);
|
||||
}
|
||||
|
||||
// This pass should be always at the end of the model
|
||||
// conversion (even after quantization). Some TFL ops like unidirectional
|
||||
// sequence lstm will have stateful operands and some optimization passes
|
||||
// will merge those operands if they have identical values & types. However,
|
||||
// it's not desired by TFL. This pass serves as a "fix" pass to split the
|
||||
// merged inputs until we have 1st class variable support or reuse
|
||||
// tf.variable to model this.
|
||||
pass_manager->addPass(mlir::TFL::CreateSplitMergedOperandsPass());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -24,6 +24,11 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
||||
// Checks if the param passed is a F32 ElementsAttr.
|
||||
def F32ElementsAttr : ElementsAttrBase<
|
||||
CPred<"$_self.isa<ElementsAttr>() && $_self.cast<ElementsAttr>().getType().getElementType().isF32()">,
|
||||
"32 bit float constant tensor">;
|
||||
|
||||
// Checks if the param passed is a float ElementsAttr.
|
||||
def FloatElementsAttr : ElementsAttrBase<
|
||||
CPred<"$_self.isa<ElementsAttr>() && $_self.cast<ElementsAttr>().getType().getElementType().isa<FloatType>()">,
|
||||
"float constant tensor">;
|
||||
|
||||
// Checks if the param passed is of NoneType.
|
||||
@ -93,9 +98,9 @@ class CanFuseConvOrDepthwiseConv<string is_depthwise> : Constraint<
|
||||
multiclass FuseBinaryOpToPrecedingAffine<dag binaryOp> {
|
||||
def FuseBinaryOpWithConv#binaryOp : Pat<
|
||||
(binaryOp (TFL_Conv2DOp:$output $input, $filter,
|
||||
(ConstantOp F32ElementsAttr:$bias), $h_factor, $w_factor,
|
||||
(ConstantOp FloatElementsAttr:$bias), $h_factor, $w_factor,
|
||||
TFL_AF_None, $padding, $stride_h, $stride_w),
|
||||
(ConstantOp F32ElementsAttr:$value), $act_fn),
|
||||
(ConstantOp FloatElementsAttr:$value), $act_fn),
|
||||
(TFL_Conv2DOp $input, $filter,
|
||||
(binaryOp (ConstantOp $bias),
|
||||
(ConstantOp $value), TFL_AF_None),
|
||||
@ -104,10 +109,10 @@ multiclass FuseBinaryOpToPrecedingAffine<dag binaryOp> {
|
||||
(HasOneUse $output)]>;
|
||||
def FuseBinaryOpWithDepthwiseConv#binaryOp : Pat<
|
||||
(binaryOp (TFL_DepthwiseConv2DOp:$output $input, $filter,
|
||||
(ConstantOp F32ElementsAttr:$bias),
|
||||
(ConstantOp FloatElementsAttr:$bias),
|
||||
$h_factor, $w_factor, TFL_AF_None, $padding, $stride_h,
|
||||
$stride_w, $multiplier),
|
||||
(ConstantOp F32ElementsAttr:$value), $act_fn),
|
||||
(ConstantOp FloatElementsAttr:$value), $act_fn),
|
||||
(TFL_DepthwiseConv2DOp $input, $filter,
|
||||
(binaryOp (ConstantOp $bias), (ConstantOp $value), TFL_AF_None),
|
||||
$h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w,
|
||||
@ -116,9 +121,9 @@ multiclass FuseBinaryOpToPrecedingAffine<dag binaryOp> {
|
||||
(HasOneUse $output)]>;
|
||||
def FuseBinaryOpWithTransposeConv#binaryOp : Pat<
|
||||
(binaryOp (TFL_TransposeConvOp:$output $output_shape, $weights, $inputs,
|
||||
(ConstantOp F32ElementsAttr:$bias), $padding,
|
||||
(ConstantOp FloatElementsAttr:$bias), $padding,
|
||||
$stride_h, $stride_w),
|
||||
(ConstantOp F32ElementsAttr:$value), TFL_AF_None),
|
||||
(ConstantOp FloatElementsAttr:$value), TFL_AF_None),
|
||||
(TFL_TransposeConvOp $output_shape, $weights, $inputs,
|
||||
(binaryOp (ConstantOp $bias),
|
||||
(ConstantOp $value), TFL_AF_None),
|
||||
@ -130,7 +135,7 @@ multiclass FuseBinaryOpToPrecedingAffine<dag binaryOp> {
|
||||
(binaryOp (TFL_TransposeConvOp:$output $output_shape, $weights, $inputs,
|
||||
(ConstantOp $bias), $padding,
|
||||
$stride_h, $stride_w),
|
||||
(ConstantOp F32ElementsAttr:$value), TFL_AF_None),
|
||||
(ConstantOp FloatElementsAttr:$value), TFL_AF_None),
|
||||
(TFL_TransposeConvOp $output_shape, $weights, $inputs,
|
||||
(ConstantOp $value),
|
||||
$padding, $stride_h, $stride_w),
|
||||
@ -155,11 +160,11 @@ def ExpandTo4DForDepthwiseConv: NativeCodeCall<
|
||||
multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d<dag BinaryOp> {
|
||||
def FuseMulOrDivWithDepthwiseConv#BinaryOp : Pat<
|
||||
(BinaryOp (TFL_DepthwiseConv2DOp:$output $input,
|
||||
(ConstantOp F32ElementsAttr:$filter),
|
||||
(ConstantOp F32ElementsAttr:$bias),
|
||||
(ConstantOp FloatElementsAttr:$filter),
|
||||
(ConstantOp FloatElementsAttr:$bias),
|
||||
$h_factor, $w_factor, TFL_AF_None, $padding, $stride_h,
|
||||
$stride_w, $multiplier),
|
||||
(ConstantOp F32ElementsAttr:$value), $act_fn),
|
||||
(ConstantOp FloatElementsAttr:$value), $act_fn),
|
||||
(TFL_DepthwiseConv2DOp $input,
|
||||
(BinaryOp
|
||||
(ConstantOp $filter),
|
||||
@ -175,11 +180,11 @@ multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d<dag BinaryOp> {
|
||||
(HasOneUse $output)]>;
|
||||
def FuseMulOrDivWithConv#BinaryOp : Pat<
|
||||
(BinaryOp (TFL_Conv2DOp:$conv_output $input,
|
||||
(ConstantOp F32ElementsAttr:$filter),
|
||||
(ConstantOp F32ElementsAttr:$bias),
|
||||
(ConstantOp FloatElementsAttr:$filter),
|
||||
(ConstantOp FloatElementsAttr:$bias),
|
||||
$h_factor, $w_factor, TFL_AF_None,
|
||||
$padding, $stride_h, $stride_w),
|
||||
(ConstantOp F32ElementsAttr:$value), $act_fn),
|
||||
(ConstantOp FloatElementsAttr:$value), $act_fn),
|
||||
(TFL_Conv2DOp $input,
|
||||
(BinaryOp (ConstantOp $filter),
|
||||
(ConstantOp (ExpandTo4DForConv $value)),
|
||||
@ -192,8 +197,8 @@ multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d<dag BinaryOp> {
|
||||
(HasOneUse $conv_output)]>;
|
||||
def FuseMulOrDivWithTransposeConv#BinaryOp : Pat<
|
||||
(BinaryOp (TFL_TransposeConvOp:$output $output_shape,
|
||||
(ConstantOp F32ElementsAttr:$weights), $input,
|
||||
(ConstantOp F32ElementsAttr:$bias),
|
||||
(ConstantOp FloatElementsAttr:$weights), $input,
|
||||
(ConstantOp FloatElementsAttr:$bias),
|
||||
$padding, $stride_h, $stride_w),
|
||||
(ConstantOp $value), TFL_AF_None),
|
||||
(TFL_TransposeConvOp $output_shape,
|
||||
@ -209,7 +214,7 @@ multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d<dag BinaryOp> {
|
||||
(HasOneUse $output)]>;
|
||||
def FuseMulOrDivWithTransposeConvWithNoneBias#BinaryOp : Pat<
|
||||
(BinaryOp (TFL_TransposeConvOp:$output $output_shape,
|
||||
(ConstantOp F32ElementsAttr:$weights), $input,
|
||||
(ConstantOp FloatElementsAttr:$weights), $input,
|
||||
(ConstantOp $bias),
|
||||
$padding, $stride_h, $stride_w),
|
||||
(ConstantOp $value), TFL_AF_None),
|
||||
|
@ -1651,10 +1651,12 @@ Mutually reduces multiple tensors of identical type and shape.
|
||||
TF_Int32Tensor:$group_size,
|
||||
TF_Int32Tensor:$group_key,
|
||||
TF_Int32Tensor:$instance_key,
|
||||
Variadic<TF_ResourceTensor>:$ordering_token,
|
||||
|
||||
TF_AnyStrAttrOf<["Min", "Max", "Mul", "Add"]>:$merge_op,
|
||||
TF_AnyStrAttrOf<["Id", "Div"]>:$final_op,
|
||||
DefaultValuedAttr<StrAttr, "auto">:$communication_hint
|
||||
DefaultValuedAttr<StrAttr, "auto">:$communication_hint,
|
||||
DefaultValuedAttr<F32Attr, "0.0f">:$timeout_seconds
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
@ -1662,6 +1664,7 @@ Mutually reduces multiple tensors of identical type and shape.
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandSizeAttr Nordering_token = TF_DerivedOperandSizeAttr<4>;
|
||||
}
|
||||
|
||||
def TF_ComplexOp : TF_Op<"Complex", [NoSideEffect, ResultsBroadcastableShape]> {
|
||||
|
@ -77,66 +77,6 @@ namespace TF {
|
||||
|
||||
namespace {
|
||||
|
||||
// Returns true of the given function has a single uses (within the scope
|
||||
// of the module containing it and all parent modules).
|
||||
bool HasSingleUse(FuncOp func) {
|
||||
// Public function can have any number of external uses.
|
||||
if (func.isPublic()) return false;
|
||||
|
||||
// Return false if unexpected IR structure seen.
|
||||
ModuleOp module = func.getParentOfType<ModuleOp>();
|
||||
if (!module) return false;
|
||||
|
||||
// Inspect function uses in the containing module and all parent
|
||||
// modules.
|
||||
bool use_seen = false;
|
||||
for (; module; module = func.isPrivate()
|
||||
? nullptr
|
||||
: module.getParentOfType<ModuleOp>()) {
|
||||
auto func_uses_optional =
|
||||
SymbolTable::getSymbolUses(func, &module.getBodyRegion());
|
||||
// Found an unknown use.
|
||||
if (!func_uses_optional) return false;
|
||||
|
||||
// If no uses in this scope, continue looking in parent module
|
||||
SymbolTable::UseRange func_uses = func_uses_optional.getValue();
|
||||
if (func_uses.empty()) continue;
|
||||
|
||||
// Check if multiple uses at this scope or another use already seen.
|
||||
if (!llvm::hasSingleElement(func_uses) || use_seen) return false;
|
||||
|
||||
// This is the first use seen.
|
||||
use_seen = true;
|
||||
}
|
||||
|
||||
// No multiple uses seen.
|
||||
return true;
|
||||
}
|
||||
|
||||
// Returns true if the caller ops can be inlined.
|
||||
bool HasInlinableUsers(FuncOp func) {
|
||||
// Return false if unexpected IR structure seen.
|
||||
ModuleOp module = func.getParentOfType<ModuleOp>();
|
||||
if (!module) return false;
|
||||
|
||||
// Inspect function uses in the containing module and all parent
|
||||
// modules.
|
||||
for (; module; module = func.isPrivate()
|
||||
? nullptr
|
||||
: module.getParentOfType<ModuleOp>()) {
|
||||
auto func_uses_optional =
|
||||
SymbolTable::getSymbolUses(func, &module.getBodyRegion());
|
||||
// Found an unknown use.
|
||||
if (!func_uses_optional) return false;
|
||||
|
||||
for (auto &use : func_uses_optional.getValue())
|
||||
if (isa<TPUPartitionedCallOp>(use.getUser())) return false;
|
||||
}
|
||||
|
||||
// All caller ops that can be inlined.
|
||||
return true;
|
||||
}
|
||||
|
||||
struct TFConstantFoldInterface : public DialectFoldInterface {
|
||||
TFConstantFoldInterface(Dialect *dialect) : DialectFoldInterface(dialect) {}
|
||||
LogicalResult fold(Operation *op, ArrayRef<Attribute> operands,
|
||||
@ -160,10 +100,12 @@ struct TFInlinerInterface : public DialectInlinerInterface {
|
||||
// Analysis Hooks
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
// Allow all call operations to be inlined.
|
||||
// Returns if it's legal to inline 'callable' into the 'call', where 'call' is
|
||||
// a TF operation.
|
||||
bool isLegalToInline(Operation *call, Operation *callable,
|
||||
bool wouldBeCloned) const final {
|
||||
return true;
|
||||
// Check that the TF call operation is one that is legal to inline.
|
||||
return !isa<TPUPartitionedCallOp>(call);
|
||||
}
|
||||
|
||||
// Returns if its legal to inline 'src' region into the 'dest' region
|
||||
@ -186,10 +128,7 @@ struct TFInlinerInterface : public DialectInlinerInterface {
|
||||
// post inlining, the function will be dead and eliminated from the IR.
|
||||
// So there won't be any code duplication.
|
||||
// plus the function caller op can be replaced by inlined ops.
|
||||
FuncOp func = op->getParentOfType<FuncOp>();
|
||||
if (!func) return true;
|
||||
if (!HasInlinableUsers(func)) return false;
|
||||
return TensorFlowDialect::CanDuplicate(op) || HasSingleUse(func);
|
||||
return !wouldBeCloned || TensorFlowDialect::CanDuplicate(op);
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
@ -361,8 +361,7 @@ func @send_recv(%arg0: tensor<2x!tf.string>) {
|
||||
// -----
|
||||
|
||||
// Tests functional control flow functions with replica variant ops reachable
|
||||
// from a replicate region is cloned and remapped. Only the branches reachable
|
||||
// with replica variant ops are cloned.
|
||||
// from a replicate region is cloned and remapped.
|
||||
|
||||
// CHECK-LABEL: func @control_flow_with_replicate_variant_ops
|
||||
func @control_flow_with_replicate_variant_ops(%arg0: tensor<i1>, %arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<2x!tf.string>) {
|
||||
@ -380,30 +379,32 @@ func @control_flow_with_replicate_variant_ops(%arg0: tensor<i1>, %arg1: tensor<f
|
||||
}
|
||||
|
||||
// CHECK: "tf.If"
|
||||
// CHECK-SAME: else_branch = @cond_false
|
||||
// CHECK-SAME: else_branch = [[COND_FALSE_REPLICA_0:@[a-z0-9_]+]]
|
||||
// CHECK-SAME: then_branch = [[COND_TRUE_REPLICA_0:@[a-z0-9_]+]]
|
||||
// CHECK: "tf.If"
|
||||
// CHECK-SAME: else_branch = @cond_false
|
||||
// CHECK-SAME: else_branch = [[COND_FALSE_REPLICA_1:@[a-z0-9_]+]]
|
||||
// CHECK-SAME: then_branch = [[COND_TRUE_REPLICA_1:@[a-z0-9_]+]]
|
||||
|
||||
func @cond_false(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<2x!tf.string>) -> tensor<f32> {
|
||||
return %arg0 : tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK-NOT: func @cond_false.+(
|
||||
|
||||
func @cond_true(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<2x!tf.string>) -> tensor<f32> {
|
||||
"tf._XlaSendFromHost"(%arg1, %arg2) {_xla_has_host_transfer = true, device_ordinal = 0 : i64, key = "host_compute_channel_recv_0"} : (tensor<f32>, tensor<2x!tf.string>) -> ()
|
||||
%0 = "tf._XlaRecvAtHost"(%arg2) {_xla_has_host_transfer = true, device_ordinal = 0 : i64, key = "host_compute_channel_send_0"} : (tensor<2x!tf.string>) -> tensor<f32>
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK: func [[COND_FALSE_REPLICA_0]]
|
||||
|
||||
// CHECK: func [[COND_TRUE_REPLICA_0]]
|
||||
// CHECK: "tf._XlaSendFromHost"
|
||||
// CHECK-SAME: device_ordinal = 1
|
||||
// CHECK: "tf._XlaRecvAtHost"
|
||||
// CHECK-SAME: device_ordinal = 1
|
||||
|
||||
// CHECK: func [[COND_FALSE_REPLICA_1]]
|
||||
|
||||
// CHECK: func [[COND_TRUE_REPLICA_1]]
|
||||
// CHECK: "tf._XlaSendFromHost"
|
||||
// CHECK-SAME: device_ordinal = 2
|
||||
@ -413,7 +414,7 @@ func @cond_true(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<2x!tf.stri
|
||||
// -----
|
||||
|
||||
// Tests function with no replica variant ops reachable from a replicate region
|
||||
// is not cloned.
|
||||
// is cloned.
|
||||
|
||||
// CHECK-LABEL: func @no_replicate_variant_ops
|
||||
func @no_replicate_variant_ops(%arg0: tensor<f32>, %arg1: tensor<2x!tf.string>) {
|
||||
@ -431,11 +432,17 @@ func @no_replicate_variant_ops(%arg0: tensor<f32>, %arg1: tensor<2x!tf.string>)
|
||||
}
|
||||
|
||||
// CHECK: "tf.StatefulPartitionedCall"
|
||||
// CHECK-SAME: f = @send_recv
|
||||
// CHECK-SAME: f = [[CALLEE_REPLICA_0:@[a-z0-9_]+]]
|
||||
// CHECK: "tf.StatefulPartitionedCall"
|
||||
// CHECK-SAME: f = [[CALLEE_REPLICA_1:@[a-z0-9_]+]]
|
||||
|
||||
func @send_recv(%arg0: tensor<2x!tf.string>) {
|
||||
"tf.NoOp"() : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-NOT: @send_recv.+(
|
||||
// CHECK: func [[CALLEE_REPLICA_0]]
|
||||
// CHECK: "tf.NoOp"
|
||||
|
||||
// CHECK: func [[CALLEE_REPLICA_1]]
|
||||
// CHECK: "tf.NoOp"
|
||||
|
@ -53,7 +53,6 @@ void AddGraphExportLoweringPasses(OpPassManager &pm) {
|
||||
add_pass(TFDevice::CreateParallelExecuteToIslandsPass());
|
||||
pm.addPass(TFDevice::CreateLaunchToDeviceAttributePass());
|
||||
pm.addPass(CreateBreakUpIslandsPass());
|
||||
pm.addNestedPass<FuncOp>(CreateTPUDevicePropagationPass());
|
||||
pm.addPass(createSymbolDCEPass());
|
||||
}
|
||||
|
||||
|
@ -120,46 +120,6 @@ llvm::SmallPtrSet<FuncOp, 4> GetReachableFunctionsFromRegion(ModuleOp module,
|
||||
return visited_functions;
|
||||
}
|
||||
|
||||
// Collects all functions and transitive functions reachable from region that
|
||||
// contain replicate variant ops.
|
||||
llvm::SmallDenseMap<llvm::StringRef, FuncOp> GetReachableFunctionsToClone(
|
||||
ModuleOp module, Region& region,
|
||||
const llvm::Optional<DictionaryAttr>& devices) {
|
||||
llvm::SmallPtrSet<FuncOp, 4> reachable_functions =
|
||||
GetReachableFunctionsFromRegion(module, region);
|
||||
|
||||
llvm::SmallDenseMap<llvm::StringRef, FuncOp> functions_to_clone;
|
||||
llvm::SmallVector<FuncOp, 4> functions_to_visit;
|
||||
for (FuncOp func : reachable_functions) {
|
||||
if (!func.getCallableRegion()) continue;
|
||||
if (HasReplicaVariantOps(*func.getCallableRegion(), devices)) {
|
||||
functions_to_clone.insert({func.getName(), func});
|
||||
functions_to_visit.push_back(func);
|
||||
}
|
||||
}
|
||||
|
||||
while (!functions_to_visit.empty()) {
|
||||
llvm::SmallVector<FuncOp, 4> new_functions_to_visit;
|
||||
|
||||
for (FuncOp func_to_visit : functions_to_visit) {
|
||||
auto func_uses = func_to_visit.getSymbolUses(module);
|
||||
if (!func_uses) continue;
|
||||
for (auto use : *func_uses) {
|
||||
auto parent_func = use.getUser()->getParentOfType<FuncOp>();
|
||||
if (!parent_func || !reachable_functions.contains(parent_func) ||
|
||||
!functions_to_clone.insert({parent_func.getName(), parent_func})
|
||||
.second)
|
||||
continue;
|
||||
new_functions_to_visit.push_back(parent_func);
|
||||
}
|
||||
}
|
||||
|
||||
functions_to_visit.swap(new_functions_to_visit);
|
||||
}
|
||||
|
||||
return functions_to_clone;
|
||||
}
|
||||
|
||||
struct FuncOldNameAndClone {
|
||||
StringRef old_name;
|
||||
FuncOp clone;
|
||||
@ -276,20 +236,19 @@ LogicalResult ExpandReplicateIntoReplicas(
|
||||
terminator.erase();
|
||||
|
||||
auto funcs_to_clone =
|
||||
GetReachableFunctionsToClone(module, replicate_op.body(), devices);
|
||||
GetReachableFunctionsFromRegion(module, replicate_op.body());
|
||||
SymbolTable symbol_table(module);
|
||||
|
||||
builder.setInsertionPoint(island_op);
|
||||
BlockAndValueMapping mapping;
|
||||
for (int i : llvm::seq<int>(0, num_replicas)) {
|
||||
// Clone reachable functions with replica variant ops.
|
||||
// Clone reachable functions from region.
|
||||
llvm::SmallVector<FuncOldNameAndClone, 4> cloned_functions;
|
||||
cloned_functions.reserve(funcs_to_clone.size());
|
||||
for (auto& func_to_clone : funcs_to_clone) {
|
||||
auto cloned_function = func_to_clone.getSecond().clone();
|
||||
for (FuncOp func_to_clone : funcs_to_clone) {
|
||||
auto cloned_function = func_to_clone.clone();
|
||||
symbol_table.insert(cloned_function, module.end());
|
||||
cloned_functions.push_back(
|
||||
{func_to_clone.getSecond().getName(), cloned_function});
|
||||
cloned_functions.push_back({func_to_clone.getName(), cloned_function});
|
||||
}
|
||||
|
||||
// Create new island for replica.
|
||||
|
@ -31,7 +31,6 @@ int main(int argc, char **argv) {
|
||||
mlir::registerAllPasses();
|
||||
mlir::mhlo::registerAllMhloPasses();
|
||||
mlir::lmhlo::registerAllLmhloPasses();
|
||||
mlir::mhlo::registerAllMhloPasses();
|
||||
|
||||
mlir::DialectRegistry registry;
|
||||
mlir::registerAllDialects(registry);
|
||||
|
158
tensorflow/compiler/mlir/tfr/README.md
Normal file
@ -0,0 +1,158 @@
# Composable TensorFlow

## Composable TensorFlow

Composable TensorFlow (TF) is the framework for defining portable TF ops with
composition in the authoring language.

The set of standard TF ops is currently open. New ops are defined for special
purposes, but it is hard to make them work end-to-end: the op needs to be
handled separately by several backends (tf2xla bridge, tflite converter, CPU
kernels, etc.). Writing shape functions and gradients for these ops is
extremely difficult. `tf.function` makes some parts of the implementation
simpler, but it introduces runtime overhead and it cannot easily be used to
apply dedicated optimizations to op kernels.

The composable TF framework allows the user to define portable TF ops as
compositions of other TF ops. It translates a Python function used to define
the composition directly into a portable IR at build time, and uses it to
expand the composite op in the TF program during compilation / execution. By
using this expansion mechanism, new ops are readily available on different
platforms without extra work. Moreover, since the expansion is optional, the
backend can easily treat it as a monolithic op when needed, for instance to
apply optimizations or associate it with a custom kernel.

### Benefits

Using the Composable TF API to define a new op and its composition can bring
the following benefits:

* *Automatic backend support*: As long as it is composed of ops supported by
  the backend, the new op is automatically supported (as a `tf.function`
  alternative);
* *Reduced tracing overhead*: Unlike `tf.function`, the composition function is
  compiled at build time, hence TF only needs to trace a single op to build the
  `graph`;
* *Easy fused op/kernel optimization*: Even if it has complex semantics, the
  new op is presented as a single node in the graph, thus optimization passes
  and kernels can easily be specialized to this op for better performance.
* *Automatic shape/type inference support*: No shape functions are required for
  the new op;
* *Automatic gradient support (WIP)*: The user doesn't need to author a
  gradient function of the op for training.

### Use Cases

* (Portability) User wants to add a new op and run this op on different
  platforms (CPU, TPU, TFLite, etc.) to be portable.
  * *Solution*: The user should define the new op as a composition. The ops
    used inside the composition should have support for these platforms. These
    ops can also be composite ops.

* (Performance) User defines a custom kernel for a regular structure
  (e.g. LSTM), but it is hard to add the logic to fuse the individual ops to
  target this kernel in the inference graph.
  * *Solution*: The user should define a new TF op, which corresponds to the
    fused kernel, with composition, and use this op to build the model for
    both training and inference. For the platforms where a fused kernel is not
    available, the execution will use the composition instead.

## Gradient

(TODO)

## Authoring Op Composition in Python

The composable TF provides a single API to define a new op with its
composition at the same time. For example, the following code defines a new
`FusedFullyConnected` op, which has `MatMul`, `Add` and some
`activation function` (specified by an op attribute) fused.

```python
import tensorflow as tf

@Composite(
    'FusedFullyConnected',
    inputs=['input_: T', 'filter_: T', 'bias: T'],
    attrs=['act: {"", "RELU", "RELU6", "TANH"} = ""'],
    derived_attrs=['T: {float, int8}'],
    outputs=['o: T'])
def _composite_fully_connected(input_, filter_, bias, act):
  res = tf.raw_ops.MatMul(
      a=input_, b=filter_, transpose_a=False, transpose_b=True)
  res = tf.raw_ops.Add(x=res, y=bias)
  if act == 'RELU':
    return tf.raw_ops.Relu(features=res)
  elif act == 'RELU6':
    return tf.raw_ops.Relu6(features=res)
  elif act == 'TANH':
    return tf.raw_ops.Tanh(x=res)
  else:
    return res
```
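
For intuition, the composite above is just ordinary TF math. The eager-mode
sketch below reproduces the same `MatMul` -> `Add` -> `Relu` computation with
standard ops; the tensor shapes and the choice of `act == 'RELU'` are
illustrative assumptions, not part of the definition above.

```python
import tensorflow as tf

# Illustrative shapes: batch of 8, 16 input features, 32 output units.
x = tf.random.normal([8, 16])    # input_
w = tf.random.normal([32, 16])   # filter_ (note transpose_b=True above)
b = tf.zeros([32])               # bias

# MatMul -> Add -> Relu, i.e. the act == 'RELU' branch of the composition.
y = tf.nn.relu(tf.matmul(x, w, transpose_b=True) + b)
print(y.shape)  # (8, 32)
```
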
Besides defining new ops, composition can be specified for an existing op for
portability. The following code defines the semantics of `AddNOp`:

```python
@Composite('AddNOp')
def _my_op_c(ins):
  N = len(ins)
  if N == 1:
    return ins[0]
  sum = ins[0]
  for i in range(1, N):
    sum += ins[i]
  return sum
```

Utilities have been built to compile the Python composition functions down to
the backend IR. The project also provides a set of graph optimization passes
to expand the composite ops in the graph by using the input backend IR. These
passes have been added to the TF
[common runtime](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/common_runtime)
for graph execution and
[eager_runtime](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/common_runtime/eager)
for eager execution.

## Compiling Op Composition

### Ahead-Of-Time (AOT) mode

Like the op kernels, the op composition can be pre-compiled to the backend IR
so the decomposition can be invoked at runtime. A Python
[define_op_template.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tfr/define_op_template.py)
file is provided as an example to build composite ops in the user's project
directory. All the targets required to build the new ops are created by the
following target:

```BUILD
load("//tensorflow/compiler/mlir/tfr:build_defs.bzl", "gen_op_libraries")

gen_op_libraries(
    name = "test_ops",
    src = "define_op_template.py",
    deps = [
        "//third_party/py/tensorflow",
    ],
)
```

More composite op definitions and usages are included in the
[examples](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/mlir/tfr/examples)
directory.
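
As a rough illustration of how an op built this way might be invoked, the
sketch below loads the generated op library and calls the composite op from
Python. The shared-object name `test_ops.so` and the `fused_fully_connected`
wrapper name are assumptions inferred from the BUILD target above, not
guaranteed outputs of `gen_op_libraries`.

```python
import tensorflow as tf

# Assumed artifact of the :test_ops target; adjust to the actual build output.
ops = tf.load_op_library('test_ops.so')

x = tf.random.normal([8, 16])
w = tf.random.normal([32, 16])
b = tf.zeros([32])

# Ops registered as 'FusedFullyConnected' are exposed via snake_case wrappers.
y = ops.fused_fully_connected(input_=x, filter_=w, bias=b, act='RELU')
```
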

### Just-In-Time (JIT) mode

(TODO)

## Known Limitations

* `while` statement
* condition of `if` statement cannot be a tensor

## Team

* Feng Liu
* Dan Moldovan
@ -44,6 +44,10 @@ extern "C" CUmodule mgpuModuleLoad(void *data) {
|
||||
return module;
|
||||
}
|
||||
|
||||
extern "C" void mgpuModuleUnload(CUmodule module) {
|
||||
CUDA_REPORT_IF_ERROR(cuModuleUnload(module));
|
||||
}
|
||||
|
||||
extern "C" CUfunction mgpuModuleGetFunction(CUmodule module, const char *name) {
|
||||
CUfunction function = nullptr;
|
||||
CUDA_REPORT_IF_ERROR(cuModuleGetFunction(&function, module, name));
|
||||
@ -64,14 +68,13 @@ extern "C" void mgpuLaunchKernel(CUfunction function, intptr_t gridX,
|
||||
}
|
||||
|
||||
extern "C" CUstream mgpuStreamCreate() {
|
||||
static CUstream stream = []() {
|
||||
// TODO(b/170649852): This is neither thread-safe nor handles
|
||||
// creation/descruction of one stream per context.
|
||||
CUstream stream = nullptr;
|
||||
CUDA_REPORT_IF_ERROR(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING));
|
||||
return stream;
|
||||
}();
|
||||
return stream;
|
||||
}
|
||||
|
||||
extern "C" void mgpuStreamDestroy(CUstream stream) {
|
||||
CUDA_REPORT_IF_ERROR(cuStreamDestroy(stream));
|
||||
}
|
||||
|
||||
extern "C" void mgpuStreamSynchronize(CUstream stream) {
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <memory>
|
||||
|
||||
#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
|
||||
#include "mlir/Dialect/SCF/Transforms.h" // from @llvm-project
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
|
||||
#include "mlir/Dialect/Shape/Transforms/Passes.h" // from @llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
@ -107,6 +108,8 @@ struct BufferizePass : public BufferizePassBase<BufferizePass> {
|
||||
populateStandardBufferizePattern(&context, &converter, &patterns);
|
||||
populateShapeStructuralTypeConversionsAndLegality(&context, converter,
|
||||
patterns, target);
|
||||
scf::populateSCFStructuralTypeConversionsAndLegality(&context, converter,
|
||||
patterns, target);
|
||||
patterns.insert<UnrankedTensorStoreTestOnlyPattern>(&context);
|
||||
|
||||
auto module = getOperation();
|
||||
|
@ -218,7 +218,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
|
||||
auto attr = CreateDenseElementsAttrFromLiteral(literal, *builder_);
|
||||
if (!attr.ok()) return attr.status();
|
||||
mlir::Operation* new_operation =
|
||||
func_builder->create<mlir::ConstantOp>(loc, attr.ValueOrDie());
|
||||
func_builder->create<mlir::mhlo::ConstOp>(loc, attr.ValueOrDie());
|
||||
for (auto attr : attributes) {
|
||||
new_operation->setAttr(attr.first, attr.second);
|
||||
}
|
||||
|
@ -12,7 +12,11 @@ glob_lit_tests(
|
||||
data = [":test_utilities"],
|
||||
driver = "@llvm-project//mlir:run_lit.sh",
|
||||
tags_override = {
|
||||
"hlo_to_lhlo_with_xla/gpu_ops.mlir": tf_cuda_tests_tags() + ["noasan"], # TODO(b/171751580)
|
||||
"hlo_to_lhlo_with_xla/gpu_ops.mlir": tf_cuda_tests_tags() + [
|
||||
"noasan",
|
||||
"nomsan",
|
||||
"noubsan",
|
||||
], # b/171751580
|
||||
},
|
||||
test_file_exts = [
|
||||
"mlir",
|
||||
|
@ -26,10 +26,10 @@ ENTRY %indexed_conditional () -> f32[] {
}

// CHECK-LABEL: func @main() -> tensor<f32>
// CHECK: %[[INDEX:.*]] = constant dense<1> : tensor<i32>
// CHECK: %[[OPERAND_1:.*]] = constant dense<5.600000e+01> : tensor<f32>
// CHECK: %[[OPERAND_2:.*]] = constant dense<1.200000e+01> : tensor<f32>
// CHECK: %[[OPERAND_3:.*]] = constant dense<1.300000e+01> : tensor<f32>
// CHECK: %[[INDEX:.*]] = mhlo.constant dense<1> : tensor<i32>
// CHECK: %[[OPERAND_1:.*]] = mhlo.constant dense<5.600000e+01> : tensor<f32>
// CHECK: %[[OPERAND_2:.*]] = mhlo.constant dense<1.200000e+01> : tensor<f32>
// CHECK: %[[OPERAND_3:.*]] = mhlo.constant dense<1.300000e+01> : tensor<f32>
// CHECK: %[[RESULT:.*]] = "mhlo.case"(%[[INDEX]], %[[OPERAND_1]], %[[OPERAND_2]], %[[OPERAND_3]]) ( {
// CHECK: ^bb0(%[[ARG_1:.*]]: tensor<f32>):
// CHECK: %[[RES_1:.*]] = "mhlo.negate"(%[[ARG_1]]) : (tensor<f32>) -> tensor<f32>
@ -4,100 +4,102 @@
|
||||
|
||||
HloModule tfcompile.48
|
||||
|
||||
// CHECK-LABEL: func @main(%arg0: tensor<1x300xf32>, %arg1: tensor<1x300x3x1xf32>) -> tuple<tensor<300x1x5xf32>> {
|
||||
// CHECK-LABEL: func @main(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x300xf32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x300x3x1xf32>) -> tuple<tensor<300x1x5xf32>> {
|
||||
ENTRY %tfcompile.48 {
|
||||
%arg0.1 = f32[1,300] parameter(0)
|
||||
%arg1.2 = f32[1,300,3,1] parameter(1)
|
||||
|
||||
// CHECK-NEXT: %0 = "mhlo.reshape"(%arg0) : (tensor<1x300xf32>) -> tensor<1x300xf32>
|
||||
// CHECK-NEXT: %[[VAL_2:.*]] = "mhlo.reshape"(%[[VAL_0]]) : (tensor<1x300xf32>) -> tensor<1x300xf32>
|
||||
%reshape.3 = f32[1,300] reshape(%arg0.1)
|
||||
|
||||
// CHECK-NEXT: %1 = "mhlo.transpose"(%0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x300xf32>) -> tensor<300x1xf32>
|
||||
// CHECK-NEXT: %[[VAL_3:.*]] = "mhlo.transpose"(%[[VAL_2]]) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x300xf32>) -> tensor<300x1xf32>
|
||||
%transpose.27 = f32[300,1] transpose(%reshape.3), dimensions={1,0}
|
||||
|
||||
// CHECK-NEXT: %2 = "mhlo.reshape"(%1) : (tensor<300x1xf32>) -> tensor<300x1x1xf32>
|
||||
// CHECK-NEXT: %[[VAL_4:.*]] = "mhlo.reshape"(%[[VAL_3]]) : (tensor<300x1xf32>) -> tensor<300x1x1xf32>
|
||||
%reshape.28 = f32[300,1,1] reshape(%transpose.27)
|
||||
|
||||
// CHECK-NEXT: %3 = "mhlo.reshape"(%2) : (tensor<300x1x1xf32>) -> tensor<300x1xf32>
|
||||
// CHECK-NEXT: %[[VAL_5:.*]] = "mhlo.reshape"(%[[VAL_4]]) : (tensor<300x1x1xf32>) -> tensor<300x1xf32>
|
||||
%reshape.29 = f32[300,1] reshape(%reshape.28)
|
||||
|
||||
// CHECK-NEXT: %4 = "mhlo.broadcast_in_dim"(%3) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<300x1xf32>) -> tensor<300x1x5xf32>
|
||||
// CHECK-NEXT: %[[VAL_6:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_5]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<300x1xf32>) -> tensor<300x1x5xf32>
|
||||
%broadcast.30 = f32[300,1,5] broadcast(%reshape.29), dimensions={0,1}
|
||||
|
||||
// CHECK-NEXT: %cst = constant dense<1.000000e+00> : tensor<f32>
|
||||
// CHECK-NEXT: %[[VAL_7:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
|
||||
%constant.8 = f32[] constant(1)
|
||||
|
||||
// CHECK-NEXT: %5 = "mhlo.broadcast_in_dim"(%cst) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<300x1x5xf32>
|
||||
// CHECK-NEXT: %[[VAL_8:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_7]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<300x1x5xf32>
|
||||
%broadcast.9 = f32[300,1,5] broadcast(%constant.8), dimensions={}
|
||||
|
||||
// CHECK-NEXT: %6 = mhlo.multiply %4, %5 : tensor<300x1x5xf32>
|
||||
// CHECK-NEXT: %[[VAL_9:.*]] = mhlo.multiply %[[VAL_6]], %[[VAL_8]] : tensor<300x1x5xf32>
|
||||
%multiply.31 = f32[300,1,5] multiply(%broadcast.30, %broadcast.9)
|
||||
|
||||
// CHECK-NEXT: %cst_0 = constant dense<0.000000e+00> : tensor<f32>
|
||||
// CHECK-NEXT: %[[VAL_10:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||
%constant.32 = f32[] constant(0)
|
||||
|
||||
// CHECK-NEXT: %7 = "mhlo.broadcast_in_dim"(%cst_0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<300x1x5xf32>
|
||||
// CHECK-NEXT: %[[VAL_11:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_10]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<300x1x5xf32>
|
||||
%broadcast.33 = f32[300,1,5] broadcast(%constant.32), dimensions={}
|
||||
|
||||
// CHECK-NEXT: %8 = "mhlo.compare"(%6, %7) {comparison_direction = "GT"} : (tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xi1>
|
||||
// CHECK-NEXT: %[[VAL_12:.*]] = "mhlo.compare"(%[[VAL_9]], %[[VAL_11]]) {comparison_direction = "GT"} : (tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xi1>
|
||||
%compare.34 = pred[300,1,5] compare(%multiply.31, %broadcast.33), direction=GT
|
||||
|
||||
// CHECK-NEXT: %cst_1 = constant dense<0.000000e+00> : tensor<f32>
|
||||
// CHECK-NEXT: %[[VAL_13:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||
%constant.10 = f32[] constant(0)
|
||||
|
||||
// CHECK-NEXT: %9 = "mhlo.broadcast_in_dim"(%cst_1) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<300x1x5xf32>
|
||||
// CHECK-NEXT: %[[VAL_14:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_13]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<300x1x5xf32>
|
||||
%broadcast.11 = f32[300,1,5] broadcast(%constant.10), dimensions={}
|
||||
|
||||
// CHECK-NEXT: %cst_2 = constant dense<0.000000e+00> : tensor<f32>
|
||||
// CHECK-NEXT: %[[VAL_15:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||
%constant.40 = f32[] constant(0)
|
||||
|
||||
// CHECK-NEXT: %10 = "mhlo.broadcast_in_dim"(%cst_2) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<300x5xf32>
|
||||
// CHECK-NEXT: %[[VAL_16:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_15]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<300x5xf32>
|
||||
%broadcast.41 = f32[300,5] broadcast(%constant.40), dimensions={}
|
||||
|
||||
// CHECK-NEXT: %11 = "mhlo.copy"(%arg1) : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32>
|
||||
// CHECK-NEXT: %[[VAL_17:.*]] = "mhlo.copy"(%[[VAL_1]]) : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32>
|
||||
%copy.1 = f32[1,300,3,1] copy(%arg1.2)
|
||||
|
||||
// CHECK-NEXT: %12 = "mhlo.reshape"(%11) : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32>
|
||||
// CHECK-NEXT: %[[VAL_18:.*]] = "mhlo.reshape"(%[[VAL_17]]) : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32>
|
||||
%reshape.4 = f32[1,300,3,1] reshape(%copy.1)
|
||||
|
||||
// CHECK-NEXT: %13 = "mhlo.reshape"(%12) : (tensor<1x300x3x1xf32>) -> tensor<1x300x3xf32>
|
||||
// CHECK-NEXT: %[[VAL_19:.*]] = "mhlo.reshape"(%[[VAL_18]]) : (tensor<1x300x3x1xf32>) -> tensor<1x300x3xf32>
|
||||
%reshape.24 = f32[1,300,3] reshape(%reshape.4)
|
||||
|
||||
// CHECK-NEXT: %14 = "mhlo.transpose"(%13) {permutation = dense<[1, 0, 2]> : tensor<3xi64>} : (tensor<1x300x3xf32>) -> tensor<300x1x3xf32>
|
||||
// CHECK-NEXT: %[[VAL_20:.*]] = "mhlo.transpose"(%[[VAL_19]]) {permutation = dense<[1, 0, 2]> : tensor<3xi64>} : (tensor<1x300x3xf32>) -> tensor<300x1x3xf32>
|
||||
%transpose.25 = f32[300,1,3] transpose(%reshape.24), dimensions={1,0,2}
|
||||
|
||||
// CHECK-NEXT: %15 = "mhlo.reshape"(%14) : (tensor<300x1x3xf32>) -> tensor<300x3xf32>
|
||||
// CHECK-NEXT: %[[VAL_21:.*]] = "mhlo.reshape"(%[[VAL_20]]) : (tensor<300x1x3xf32>) -> tensor<300x3xf32>
|
||||
%reshape.26 = f32[300,3] reshape(%transpose.25)
|
||||
|
||||
// CHECK-NEXT: %cst_3 = constant dense<{{\[\[}}-1.060230e-01, 1.215050e-01, 8.002390e-01, -7.688850e-01, 0.0966112986], [6.890140e-01, -4.070560e-01, -0.797852993, 3.789250e-03, -2.088810e-01], [-6.085290e-01, 2.766170e-02, 2.685570e-01, 5.774010e-01, -4.284370e-01]]> : tensor<3x5xf32>
|
||||
// CHECK-NEXT: %[[VAL_22:.*]] = mhlo.constant dense<{{\[\[}}-1.060230e-01, 1.215050e-01, 8.002390e-01, -7.688850e-01, 0.0966112986], [6.890140e-01, -4.070560e-01, -0.797852993, 3.789250e-03, -2.088810e-01], [-6.085290e-01, 2.766170e-02, 2.685570e-01, 5.774010e-01, -4.284370e-01]]> : tensor<3x5xf32>
|
||||
%constant.35 = f32[3,5] constant({ { -0.106023, 0.121505, 0.800239, -0.768885, 0.0966113 }, { 0.689014, -0.407056, -0.797853, 0.00378925, -0.208881 }, { -0.608529, 0.0276617, 0.268557, 0.577401, -0.428437 } })
|
||||
|
||||
// TODO(b/129709049) consider making this default precision config implied.
|
||||
// CHECK-NEXT: %16 = "mhlo.dot"(%15, %cst_3) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32>
|
||||
// CHECK-NEXT: %[[VAL_23:.*]] = "mhlo.dot"(%[[VAL_21]], %[[VAL_22]]) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32>
|
||||
%dot.36 = f32[300,5] dot(%reshape.26, %constant.35), lhs_contracting_dims={1}, rhs_contracting_dims={0}
|
||||
|
||||
// CHECK-NEXT: %cst_4 = constant dense<0.000000e+00> : tensor<5xf32>
|
||||
// CHECK-NEXT: %[[VAL_24:.*]] = mhlo.constant dense<0.000000e+00> : tensor<5xf32>
|
||||
%constant.37 = f32[5]{0} constant({0, 0, 0, 0, 0})
|
||||
|
||||
// CHECK-NEXT: %17 = "mhlo.broadcast_in_dim"(%cst_4) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<5xf32>) -> tensor<300x5xf32>
|
||||
// CHECK-NEXT: %[[VAL_25:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_24]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<5xf32>) -> tensor<300x5xf32>
|
||||
%broadcast.38 = f32[300,5] broadcast(%constant.37), dimensions={1}
|
||||
|
||||
// CHECK-NEXT: %18 = mhlo.add %16, %17 : tensor<300x5xf32>
|
||||
// CHECK-NEXT: %[[VAL_26:.*]] = mhlo.add %[[VAL_23]], %[[VAL_25]] : tensor<300x5xf32>
|
||||
%add.39 = f32[300,5] add(%dot.36, %broadcast.38)
|
||||
|
||||
// CHECK-NEXT: %19 = mhlo.maximum %10, %18 : tensor<300x5xf32>
|
||||
// CHECK-NEXT: %[[VAL_27:.*]] = mhlo.maximum %[[VAL_16]], %[[VAL_26]] : tensor<300x5xf32>
|
||||
%maximum.42 = f32[300,5] maximum(%broadcast.41, %add.39)
|
||||
|
||||
// CHECK-NEXT: %20 = "mhlo.reshape"(%19) : (tensor<300x5xf32>) -> tensor<300x1x5xf32>
|
||||
// CHECK-NEXT: %[[VAL_28:.*]] = "mhlo.reshape"(%[[VAL_27]]) : (tensor<300x5xf32>) -> tensor<300x1x5xf32>
|
||||
%reshape.44 = f32[300,1,5] reshape(%maximum.42)
|
||||
|
||||
// CHECK-NEXT: %21 = "mhlo.select"(%8, %9, %20) : (tensor<300x1x5xi1>, tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xf32>
|
||||
// CHECK-NEXT: %[[VAL_29:.*]] = "mhlo.select"(%[[VAL_12]], %[[VAL_14]], %[[VAL_28]]) : (tensor<300x1x5xi1>, tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xf32>
|
||||
%select.45 = f32[300,1,5] select(%compare.34, %broadcast.11, %reshape.44)
|
||||
|
||||
// CHECK-NEXT: %22 = "mhlo.reshape"(%21) : (tensor<300x1x5xf32>) -> tensor<300x1x5xf32>
|
||||
// CHECK-NEXT: %[[VAL_30:.*]] = "mhlo.reshape"(%[[VAL_29]]) : (tensor<300x1x5xf32>) -> tensor<300x1x5xf32>
|
||||
%reshape.46 = f32[300,1,5] reshape(%select.45)
|
||||
|
||||
// CHECK-NEXT: %23 = "mhlo.tuple"(%22) : (tensor<300x1x5xf32>) -> tuple<tensor<300x1x5xf32>>
|
||||
// CHECK-NEXT: return %23 : tuple<tensor<300x1x5xf32>>
|
||||
// CHECK-NEXT: %[[VAL_31:.*]] = "mhlo.tuple"(%[[VAL_30]]) : (tensor<300x1x5xf32>) -> tuple<tensor<300x1x5xf32>>
|
||||
// CHECK-NEXT: return %[[VAL_31]] : tuple<tensor<300x1x5xf32>>
|
||||
ROOT %tuple.47 = (f32[300,1,5]) tuple(%reshape.46)
|
||||
}
|
||||
|
@ -20,7 +20,7 @@ HloModule tfcompile.20
ENTRY %tfcompile.20 {
%arg0.1 = f32[] parameter(0), parameter_replication={false}, metadata={op_name="XLA_Args"}

// CHECK: [[C0:%.+]] = constant
// CHECK: [[C0:%.+]] = mhlo.constant
%constant.3 = f32[] constant(10), metadata={op_type="Less" op_name="Less"}

// CHECK: [[R1:%.+]] = "mhlo.compare"([[A0]], [[C0]])
@ -176,48 +176,49 @@ add {
|
||||
%test_constant {
|
||||
|
||||
// Scalar/0D tensor constant
|
||||
// CHECK-NEXT: %cst = constant dense<1> : tensor<i64>
|
||||
// CHECK-NEXT: %[[VAL_0:.*]] = mhlo.constant dense<1> : tensor<i64>
|
||||
%constant.0 = s64[] constant(1)
|
||||
|
||||
// Note that double brackets "[[" have to be escaped as they denote variables
|
||||
// in FileCheck. The only way to do so is to drop into regex with "{{"
|
||||
// CHECK-NEXT: constant dense<{{\[\[\[\[}}1.000000e+00]], {{\[\[}}2.000000e+00]]], {{\[\[\[}}3.000000e+00]], {{\[\[}}4.000000e+00]]]]> : tensor<2x2x1x1xf32>
|
||||
// CHECK-NEXT: %[[VAL_1:.*]] = mhlo.constant dense<{{\[\[}}{{\[\[}}1.000000e+00]], {{\[\[}}2.000000e+00]]], {{\[\[}}[3.000000e+00]], {{\[\[}}4.000000e+00]]]]> : tensor<2x2x1x1xf32>
|
||||
%constant.1 = f32[2,2,1,1]{3,2,1,0} constant({{{{1.0}},{{2.0}}},{{{3.0}},{{4.0}}}}), metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"}
|
||||
|
||||
// CHECK: dense<[1, 2, 4, 8]> : tensor<4xui64>
|
||||
// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<[1, 2, 4, 8]> : tensor<4xui64>
|
||||
%constant.2 = u64[4] constant({ 1, 2, 4, 8 })
|
||||
|
||||
// CHECK: dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xbf16>
|
||||
// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xbf16>
|
||||
%constant.3 = bf16[4] constant({1, 2, 3, 4})
|
||||
|
||||
// CHECK: dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f32>>
|
||||
// CHECK: %[[VAL_4:.*]] = mhlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f32>>
|
||||
%constant.4 = c64[] constant((1, 0))
|
||||
|
||||
// CHECK: dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f64>>
|
||||
// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f64>>
|
||||
%constant.5 = c128[] constant((1, 0))
|
||||
|
||||
// CHECK: dense<[1.000000e+00, -4.000000e+00, -6.550400e+04, 1.562500e-02]> : tensor<4xf16>
|
||||
// CHECK: %[[VAL_6:.*]] = mhlo.constant dense<[1.000000e+00, -4.000000e+00, -6.550400e+04, 1.562500e-02]> : tensor<4xf16>
|
||||
ROOT %constant.6 = f16[4] constant({1, -4, -65504, 0.015625})
|
||||
}
|
||||
|
||||
// TODO(b/129422361) Potentially update when copy, reshape, and conv have actual
|
||||
// implementations with attributes, etc.
|
||||
// CHECK-LABEL: func @test_conv(%arg0: tensor<256x32x32x6xf32>) -> tuple<tensor<256x30x30x16xf32>>
|
||||
// CHECK-LABEL: func @test_conv(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<256x32x32x6xf32>) -> tuple<tensor<256x30x30x16xf32>> attributes {sym_visibility = "private"} {
|
||||
%test_conv {
|
||||
%arg0.1 = f32[256,32,32,6]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"}
|
||||
|
||||
// CHECK-NEXT: %0 = "mhlo.copy"(%arg0) {minor_to_major = dense<[2, 1, 3, 0]> : tensor<4xindex>} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32>
|
||||
// CHECK-NEXT: %[[VAL_1:.*]] = "mhlo.copy"(%[[VAL_0]]) {minor_to_major = dense<[2, 1, 3, 0]> : tensor<4xindex>} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32>
|
||||
%copy.1 = f32[256,32,32,6]{2,1,3,0} copy(%arg0.1), metadata={op_name="HLO_Args"}
|
||||
|
||||
// CHECK-NEXT: %1 = "mhlo.reshape"(%0) {minor_to_major = dense<[2, 1, 3, 0]> : tensor<4xindex>} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32>
|
||||
// CHECK-NEXT: %[[VAL_2:.*]] = "mhlo.reshape"(%[[VAL_1]]) {minor_to_major = dense<[2, 1, 3, 0]> : tensor<4xindex>} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32>
|
||||
%reshape.2 = f32[256,32,32,6]{2,1,3,0} reshape(%copy.1)
|
||||
|
||||
// Note that double brackets "[[" have to be escaped as they denote variables
|
||||
// in FileCheck. The only way to do so is to drop into regex with "{{"
|
||||
// CHECK-NEXT: %cst = constant dense<{{\[\[\[\[}}5.000000e-01]], {{\[\[}}-6.000000e-01]]], {{\[\[\[}}3.000000e-01]], {{\[\[}}-1.000000e-01]]]]> : tensor<2x2x1x1xf32>
|
||||
// CHECK-NEXT: %[[VAL_3:.*]] = mhlo.constant dense<{{\[\[}}{{\[\[}}5.000000e-01]], {{\[\[}}-6.000000e-01]]], {{\[\[}}[3.000000e-01]], {{\[\[}}-1.000000e-01]]]]> : tensor<2x2x1x1xf32>
|
||||
%constant.3 = f32[2,2,1,1]{3,2,1,0} constant({{{{0.5}}, {{-0.6}}}, {{{0.3}}, {{-0.1}}}}), metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"}
|
||||
|
||||
// CHECK-NEXT: %2 = "mhlo.convolution"(%1, %cst) {
|
||||
// CHECK-NEXT: %[[VAL_4:.*]] = "mhlo.convolution"(%[[VAL_2]], %[[VAL_3]]) {
|
||||
// CHECK-SAME: batch_group_count = 1 : i64
|
||||
// CHECK-SAME: dimension_numbers = {
|
||||
// CHECK-SAME: input_batch_dimension = 0 : i64
|
||||
@ -241,10 +242,10 @@ add {
|
||||
|
||||
%convolution.4 = f32[16,30,30,256]{2,1,3,0} convolution(%reshape.2, %constant.3), window={size=3x3 stride=4x5 pad=44_45x60_60 rhs_dilate=2x3}, dim_labels=b01f_01io->f01b, metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"}
|
||||
|
||||
// CHECK-NEXT: %3 = "mhlo.reshape"(%2) : (tensor<16x30x30x256xf32>) -> tensor<256x30x30x16xf32>
|
||||
// CHECK-NEXT: %[[VAL_5:.*]] = "mhlo.reshape"(%[[VAL_4]]) : (tensor<16x30x30x256xf32>) -> tensor<256x30x30x16xf32>
|
||||
%reshape.5 = f32[256,30,30,16]{3,2,1,0} reshape(%convolution.4), metadata={op_name="HLO_Retvals"}
|
||||
|
||||
// CHECK-NEXT: "mhlo.tuple"(%3) : (tensor<256x30x30x16xf32>) -> tuple<tensor<256x30x30x16xf32>>
|
||||
// CHECK-NEXT: %[[VAL_6:.*]] = "mhlo.tuple"(%[[VAL_5]]) : (tensor<256x30x30x16xf32>) -> tuple<tensor<256x30x30x16xf32>>
|
||||
ROOT %tuple.6 = (f32[256,30,30,16]{3,2,1,0}) tuple(%reshape.5), metadata={op_name="HLO_Retvals"}
|
||||
}
|
||||
|
||||
|
@ -4,25 +4,25 @@ HloModule tfcompile.1
|
||||
|
||||
// CHECK-LABEL: func @main() -> tensor<i1> {
|
||||
ENTRY %tfcompile.1 {
|
||||
// CHECK-NEXT: %cst = constant dense<1.000000e+00> : tensor<f32>
|
||||
// CHECK-NEXT: %[[VAL_0:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
|
||||
%constant.0 = f32[] constant(1)
|
||||
|
||||
// CHECK-NEXT: %cst_0 = constant dense<1.000000e+00> : tensor<f64>
|
||||
// CHECK-NEXT: %[[VAL_1:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f64>
|
||||
%constant.1 = f64[] constant(1)
|
||||
|
||||
// CHECK-NEXT: %cst_1 = constant dense<1> : tensor<i8>
|
||||
// CHECK-NEXT: %[[VAL_2:.*]] = mhlo.constant dense<1> : tensor<i8>
|
||||
%constant.2 = s8[] constant(1)
|
||||
|
||||
// CHECK-NEXT: %cst_2 = constant dense<1> : tensor<i16>
|
||||
// CHECK-NEXT: %[[VAL_3:.*]] = mhlo.constant dense<1> : tensor<i16>
|
||||
%constant.3 = s16[] constant(1)
|
||||
|
||||
// CHECK-NEXT: %cst_3 = constant dense<1> : tensor<i32>
|
||||
// CHECK-NEXT: %[[VAL_4:.*]] = mhlo.constant dense<1> : tensor<i32>
|
||||
%constant.4 = s32[] constant(1)
|
||||
|
||||
// CHECK-NEXT: %cst_4 = constant dense<1> : tensor<i64>
|
||||
// CHECK-NEXT: %[[VAL_5:.*]] = mhlo.constant dense<1> : tensor<i64>
|
||||
%constant.5 = s64[] constant(1)
|
||||
|
||||
// CHECK-NEXT: %cst_5 = constant dense<true> : tensor<i1>
|
||||
// CHECK-NEXT: return %cst_5 : tensor<i1>
|
||||
// CHECK-NEXT: %[[VAL_6:.*]] = mhlo.constant dense<true> : tensor<i1>
|
||||
// CHECK-NEXT: return %[[VAL_6]] : tensor<i1>
|
||||
ROOT %constant.6 = pred[] constant(1)
|
||||
}
|
||||
|
@ -351,6 +351,7 @@ cc_library(
":xla_op_registry",
":xla_resource",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:span",
"@com_google_absl//absl/types:variant",
@ -91,15 +91,15 @@ class DataFormatVecPermuteOp : public XlaOpKernel {
: XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("src_format", &src_format_));
OP_REQUIRES(
ctx, src_format_.size() == 4,
errors::InvalidArgument("Data format should have 4 characters"));
ctx, src_format_.size() == 4 || src_format_.size() == 5,
errors::InvalidArgument("Data format should have 4 or 5 characters"));
TensorFormat data_format;
OP_REQUIRES(ctx, FormatFromString(src_format_, &data_format),
errors::InvalidArgument("Invalid data format"));
OP_REQUIRES_OK(ctx, ctx->GetAttr("dst_format", &dst_format_));
OP_REQUIRES(
ctx, dst_format_.size() == 4,
errors::InvalidArgument("Data format should have 4 characters"));
ctx, dst_format_.size() == 4 || dst_format_.size() == 5,
errors::InvalidArgument("Data format should have 4 or 5 characters"));
OP_REQUIRES(ctx, FormatFromString(dst_format_, &data_format),
errors::InvalidArgument("Invalid data format"));
}
@ -113,9 +113,10 @@ class DataFormatVecPermuteOp : public XlaOpKernel {
input_tensor_shape.DebugString()));
const int dim0 = input_tensor_shape.dim_size(0);
OP_REQUIRES(
ctx, dim0 == 2 || dim0 == 4,
ctx, dim0 == 2 || dim0 == 4 || dim0 == 5,
errors::InvalidArgument(
"First dimension of input must be of size 4, but got shape ",
"First dimension of input must be of size 2, 4 or 5, but got "
"shape ",
input_tensor_shape.DebugString()));
if (input_rank == 2) {
OP_REQUIRES(
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include <numeric>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/types/variant.h"
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
@ -675,6 +676,38 @@ std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
|
||||
return graph;
|
||||
}
|
||||
|
||||
// Collects all control rets from `orig_control_ret_nodes` that are still valid,
|
||||
// keeping the same order.
|
||||
std::vector<std::string> GetValidControlRets(
|
||||
absl::Span<Node* const> orig_control_ret_nodes, const Graph& graph) {
|
||||
// Build map from control ret node to index.
|
||||
absl::flat_hash_map<const Node*, int> control_ret_nodes_map;
|
||||
for (int i = 0; i < orig_control_ret_nodes.size(); ++i) {
|
||||
const Node* n = orig_control_ret_nodes[i];
|
||||
control_ret_nodes_map[n] = i;
|
||||
}
|
||||
// Check which control rets are still valid.
|
||||
std::vector<bool> is_valid_control_ret(orig_control_ret_nodes.size(), false);
|
||||
int num_valid_control_rets = 0;
|
||||
for (const Node* n : graph.nodes()) {
|
||||
auto iter = control_ret_nodes_map.find(n);
|
||||
if (iter != control_ret_nodes_map.end()) {
|
||||
++num_valid_control_rets;
|
||||
is_valid_control_ret[iter->second] = true;
|
||||
}
|
||||
}
|
||||
// Return valid control rets in same order as they appear in
|
||||
// `orig_control_ret_nodes`.
|
||||
std::vector<std::string> valid_control_rets;
|
||||
valid_control_rets.reserve(num_valid_control_rets);
|
||||
for (int i = 0; i < orig_control_ret_nodes.size(); ++i) {
|
||||
if (is_valid_control_ret[i]) {
|
||||
valid_control_rets.push_back(orig_control_ret_nodes[i]->name());
|
||||
}
|
||||
}
|
||||
return valid_control_rets;
|
||||
}
|
||||
|
||||
Status XlaCompiler::CompileFunction(
|
||||
const XlaCompiler::CompileOptions& options,
|
||||
const NameAttrList& fn_name_attrs,
|
||||
@ -765,15 +798,15 @@ Status XlaCompiler::CompileFunction(
|
||||
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) {
|
||||
VLOG(1) << "Using MLIR bridge";
|
||||
GraphDebugInfo debug_info;
|
||||
std::vector<std::string> control_rets;
|
||||
for (const auto* control_ret_node : fbody->control_ret_nodes) {
|
||||
control_rets.push_back(control_ret_node->name());
|
||||
}
|
||||
|
||||
std::vector<std::string> valid_control_rets =
|
||||
GetValidControlRets(fbody->control_ret_nodes, *graph);
|
||||
|
||||
TF_RETURN_IF_ERROR(CompileGraphToXlaHlo(
|
||||
std::move(*graph), mlir::SpanToArrayRef<XlaCompiler::Argument>(args),
|
||||
control_rets, options_.device_type.type_string(), options.use_tuple_arg,
|
||||
*options_.flib_def, debug_info, options_.shape_representation_fn,
|
||||
result));
|
||||
valid_control_rets, options_.device_type.type_string(),
|
||||
options.use_tuple_arg, *options_.flib_def, debug_info,
|
||||
options_.shape_representation_fn, result));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(
|
||||
CompileGraph(options, function_id, std::move(graph), args, result));
|
||||
|
@ -5220,13 +5220,31 @@ StatusOr<bool> AlgebraicSimplifierVisitor::SwapConvOperands(
|
||||
for (int64 spatial_dim = 0;
|
||||
spatial_dim < dnums.input_spatial_dimensions_size(); ++spatial_dim) {
|
||||
const int64 kernel_size = window_dims[spatial_dim].size();
|
||||
const int64 dilated_kernel_size =
|
||||
1 + (kernel_size - 1) * window_dims[spatial_dim].window_dilation();
|
||||
|
||||
const bool can_be_group_or_contraction =
|
||||
!window_dims[spatial_dim].window_reversal() &&
|
||||
window_dims[spatial_dim].padding_low() == 0 &&
|
||||
window_dims[spatial_dim].padding_high() == 0 &&
|
||||
window_dims[spatial_dim].window_dilation() == 1;
|
||||
const bool is_group_dim =
|
||||
can_be_group_or_contraction &&
|
||||
window_dims[spatial_dim].base_dilation() == kernel_size &&
|
||||
window_dims[spatial_dim].stride() == kernel_size - 1;
|
||||
const int64 input_size =
|
||||
input->shape().dimensions(dnums.input_spatial_dimensions(spatial_dim));
|
||||
const bool is_pure_contraction_dim =
|
||||
kernel_size == input_size && can_be_group_or_contraction &&
|
||||
window_dims[spatial_dim].base_dilation() == 1 &&
|
||||
window_dims[spatial_dim].stride() == 1;
|
||||
if (is_group_dim || is_pure_contraction_dim) {
|
||||
*(swapped_window.add_dimensions()) = window_dims[spatial_dim];
|
||||
continue;
|
||||
}
|
||||
|
||||
const int64 dilated_kernel_size =
|
||||
1 + (kernel_size - 1) * window_dims[spatial_dim].window_dilation();
|
||||
const int64 dilated_input_size =
|
||||
1 + (input_size - 1) * window_dims[spatial_dim].base_dilation();
|
||||
|
||||
// Don't decide to swap if the input size is one, since many convolution
|
||||
// implementations can easily handle that special case efficiently.
|
||||
kernel_product *= kernel_size;
|
||||
|
@ -6654,6 +6654,32 @@ TEST_F(AlgebraicSimplifierTest, AbsEliminationMultiply) {
|
||||
GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0))));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, BroadcastCompareSimplification) {
|
||||
std::string module_string = R"(
|
||||
HloModule m
|
||||
test {
|
||||
a = s32[] parameter(0)
|
||||
b = s32[] parameter(1)
|
||||
x = s32[10]{0} parameter(2)
|
||||
broadcast_a = s32[10]{0} broadcast(a), dimensions={}
|
||||
broadcast_b = s32[10]{0} broadcast(b), dimensions={}
|
||||
add = s32[10]{0} add(broadcast_a, x)
|
||||
ROOT cmp = pred[10]{0} compare(add, broadcast_b), direction=EQ
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_string));
|
||||
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||
EXPECT_THAT(m->entry_computation()->root_instruction(),
|
||||
GmockMatch(m::Compare(m::Parameter(2),
|
||||
m::Broadcast(m::Subtract(
|
||||
m::Parameter(1), m::Parameter(0))))));
|
||||
|
||||
// Numerically unstable transformation shouldn't be applied to floating types.
|
||||
std::string module_string_f32 =
|
||||
absl::StrReplaceAll(module_string, {{"s32", "f32"}});
|
||||
ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, AbsEliminationPower2) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
|
@ -63,7 +63,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
|
||||
@ -2200,20 +2199,21 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) {
|
||||
if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, assignment_)) {
|
||||
VLOG(3) << "HandleFusion FusedDynamicUpdateSliceInPlace";
|
||||
CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
|
||||
FusedIrEmitter fused_emitter(&elemental_emitter);
|
||||
BindFusionArguments(fusion, &fused_emitter);
|
||||
|
||||
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion));
|
||||
// Delegate to common implementation of fused in-place dynamic-update-slice.
|
||||
return llvm_ir::EmitFusedDynamicUpdateSliceInPlace(
|
||||
fusion, GetGeneratorForOperandIrArrays(fusion), GetIrArrayFor(fusion),
|
||||
&elemental_emitter, &b_);
|
||||
fusion, GetIrArrayFor(fusion), &fused_emitter, &b_);
|
||||
} else if (fusion->IsLoopFusion()) {
|
||||
VLOG(3) << "HandleFusion kLoop";
|
||||
CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
|
||||
auto operands = GetIrArraysForOperandsOf(fusion);
|
||||
FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(fusion),
|
||||
&elemental_emitter);
|
||||
TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter));
|
||||
|
||||
return EmitTargetElementLoop(fusion, fused_emitter.GetRootGenerator());
|
||||
FusedIrEmitter fused_emitter(&elemental_emitter);
|
||||
BindFusionArguments(fusion, &fused_emitter);
|
||||
TF_ASSIGN_OR_RETURN(auto generator, fused_emitter.GetGenerator(
|
||||
fusion->fused_expression_root()));
|
||||
return EmitTargetElementLoop(fusion, generator);
|
||||
} else if (fusion->IsOutputFusion()) {
|
||||
VLOG(3) << "HandleFusion kOutput";
|
||||
int64 dot_op_index = root->operand(0)->opcode() == HloOpcode::kDot ? 0 : 1;
|
||||
@ -3451,5 +3451,17 @@ llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue(
|
||||
return EmitBufferPointer(root_buffer, root_inst->shape());
|
||||
}
|
||||
|
||||
void IrEmitter::BindFusionArguments(const HloInstruction* fusion,
|
||||
FusedIrEmitter* fused_emitter) {
|
||||
for (int i = 0; i < fusion->operand_count(); i++) {
|
||||
const HloInstruction* operand = fusion->operand(i);
|
||||
fused_emitter->BindGenerator(
|
||||
fusion->fused_parameter(i),
|
||||
[this, operand](llvm_ir::IrArray::Index index) {
|
||||
return GetIrArrayFor(operand).EmitReadArrayElement(index, &b_);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cpu
|
||||
} // namespace xla
|
||||
|
@ -43,6 +43,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
|
||||
@ -234,10 +235,9 @@ class IrEmitter : public DfsHloVisitorWithDefault,
|
||||
std::vector<llvm_ir::IrArray> GetIrArraysForOperandsOf(
|
||||
const HloInstruction* hlo);
|
||||
|
||||
GeneratorForOperandIrArrays GetGeneratorForOperandIrArrays(
|
||||
HloInstruction* unnested_hlo) {
|
||||
return [=]() { return GetIrArraysForOperandsOf(unnested_hlo); };
|
||||
}
|
||||
// Bind all argument IrArrays of `fusion` to `fused_emitter`.
|
||||
void BindFusionArguments(const HloInstruction* fusion,
|
||||
FusedIrEmitter* fused_emitter);
|
||||
|
||||
// Augments IrArray with aliasing information.
|
||||
void AddAliasingInformationToIrArray(const HloInstruction& hlo,
|
||||
|
@ -38,14 +38,60 @@ FusionNodeIndexingEvaluation::FusionNodeIndexingEvaluation(
|
||||
// a tradeoff between compilation time and runtime here.
|
||||
const int64 FusionNodeIndexingEvaluation::kAllowedCodeDuplication = 15;
|
||||
|
||||
namespace {
|
||||
|
||||
// Returns which ops invalidate the cache of emitted instructions by creating a
|
||||
// new BasicBlock and setting the insertion point to the newly created
|
||||
// BasicBlock. We can only reuse cached values if they were emitted in the same
|
||||
// BasicBlock as the current BasicBlock.
|
||||
bool OpInvalidatesCache(const HloInstruction* hlo) {
|
||||
switch (hlo->opcode()) {
|
||||
// This list of ops was created by inspecting the code. There is no
|
||||
// guarantee that it is complete.
|
||||
case HloOpcode::kConcatenate:
|
||||
case HloOpcode::kDot:
|
||||
case HloOpcode::kDynamicUpdateSlice:
|
||||
case HloOpcode::kPad:
|
||||
case HloOpcode::kReduce:
|
||||
case HloOpcode::kReduceWindow:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Counts the number of "real" users of 'hlo'. When 'hlo' has a fusion node as
|
||||
// user, we consider the users of the fusion parameter corresponding to 'hlo' as
|
||||
// the real users.
|
||||
int64 UserCount(const HloInstruction* hlo) {
|
||||
int64 cnt = 0;
|
||||
for (HloInstruction* user : hlo->users()) {
|
||||
if (user->opcode() == HloOpcode::kFusion) {
|
||||
// Count the number of users of the parameter corresponding to the fusion
|
||||
// operand.
|
||||
int64 operand_index = user->operand_index(hlo);
|
||||
cnt += user->fused_parameter(operand_index)->user_count();
|
||||
} else {
|
||||
++cnt;
|
||||
}
|
||||
}
|
||||
return cnt;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool FusionNodeIndexingEvaluation::CodeDuplicationTooHigh(
|
||||
const HloInstruction* producer) const {
|
||||
return EvaluateEmittedInstructions(producer) > kAllowedCodeDuplication;
|
||||
int64 emitted_instructions = EvaluateEmittedInstructions(producer);
|
||||
return emitted_instructions > kAllowedCodeDuplication ||
|
||||
(OpInvalidatesCache(producer) &&
|
||||
(emitted_instructions > 1 || UserCount(producer) > 1));
|
||||
}
|
||||
|
||||
bool FusionNodeIndexingEvaluation::MaxCodeDuplicationTooHigh() const {
|
||||
for (const auto& entry : index_usage_count_) {
|
||||
if (entry.second > kAllowedCodeDuplication) {
|
||||
if (entry.second > kAllowedCodeDuplication ||
|
||||
(OpInvalidatesCache(entry.first) &&
|
||||
(entry.second > 1 || UserCount(entry.first) > 1))) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
@ -773,11 +773,11 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) {
|
||||
CHECK_EQ(HloInstruction::FusionKind::kLoop, fusion->fusion_kind());
|
||||
GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_,
|
||||
GetNestedComputer());
|
||||
FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(fusion),
|
||||
&elemental_emitter);
|
||||
TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter));
|
||||
|
||||
return EmitTargetElementLoop(*fusion, fused_emitter.GetRootGenerator());
|
||||
FusedIrEmitter fused_emitter(&elemental_emitter);
|
||||
BindFusionArguments(fusion, &fused_emitter);
|
||||
TF_ASSIGN_OR_RETURN(auto generator, fused_emitter.GetGenerator(
|
||||
fusion->fused_expression_root()));
|
||||
return EmitTargetElementLoop(*fusion, generator);
|
||||
}
|
||||
|
||||
Status IrEmitter::HandleCall(HloInstruction* call) {
|
||||
@ -876,5 +876,17 @@ std::vector<llvm_ir::IrArray> IrEmitter::ConstructIrArrayForOutputs(
|
||||
return output_arrays;
|
||||
}
|
||||
|
||||
void IrEmitter::BindFusionArguments(const HloInstruction* fusion,
|
||||
FusedIrEmitter* fused_emitter) {
|
||||
for (int i = 0; i < fusion->operand_count(); i++) {
|
||||
const HloInstruction* operand = fusion->operand(i);
|
||||
fused_emitter->BindGenerator(
|
||||
fusion->fused_parameter(i),
|
||||
[this, operand, fusion](llvm_ir::IrArray::Index index) {
|
||||
return GetIrArray(*operand, *fusion).EmitReadArrayElement(index, &b_);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
@ -36,6 +36,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
|
||||
@ -182,18 +183,9 @@ class IrEmitter : public DfsHloVisitorWithDefault,
|
||||
const HloModuleConfig& hlo_module_config_;
|
||||
|
||||
protected:
|
||||
GeneratorForOperandIrArrays GetGeneratorForOperandIrArrays(
|
||||
const HloInstruction* fusion) {
|
||||
return [=]() {
|
||||
std::vector<llvm_ir::IrArray> ir_arrays;
|
||||
ir_arrays.reserve(fusion->operand_count());
|
||||
absl::c_transform(fusion->operands(), std::back_inserter(ir_arrays),
|
||||
[&](const HloInstruction* operand) {
|
||||
return GetIrArray(*operand, *fusion);
|
||||
});
|
||||
return ir_arrays;
|
||||
};
|
||||
}
|
||||
// Bind all argument IrArrays of `fusion` to `fused_emitter`.
|
||||
void BindFusionArguments(const HloInstruction* fusion,
|
||||
FusedIrEmitter* fused_emitter);
|
||||
|
||||
private:
|
||||
// A helper method for EmitAtomicOperationForNestedComputation. Certain
|
||||
|
@ -960,19 +960,24 @@ Status IrEmitterUnnested::EmitLoopFusionFromMlir(MlirEmitterInput input,
|
||||
|
||||
GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_,
|
||||
GetNestedComputer());
|
||||
FusedIrEmitter fused_emitter(&elemental_emitter);
|
||||
|
||||
FusedIrEmitter fused_emitter(
|
||||
[&] {
|
||||
for (int i = 0; i < fusion_operands.size(); i++) {
|
||||
auto operand_ir_arrays =
|
||||
absl::MakeSpan(ir_arrays).subspan(0, fusion_operands.size());
|
||||
return std::vector<llvm_ir::IrArray>(operand_ir_arrays.begin(),
|
||||
operand_ir_arrays.end());
|
||||
},
|
||||
&elemental_emitter);
|
||||
TF_RETURN_IF_ERROR(
|
||||
fused_computation->root_instruction()->Accept(&fused_emitter));
|
||||
|
||||
auto element_generator = fused_emitter.GetRootGenerator();
|
||||
auto* builder = &b_;
|
||||
auto ir_array = operand_ir_arrays[i];
|
||||
fused_emitter.BindGenerator(
|
||||
fused_computation->parameter_instruction(i),
|
||||
[builder, ir_array](llvm_ir::IrArray::Index index) {
|
||||
return ir_array.EmitReadArrayElement(index, builder);
|
||||
});
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto element_generator,
|
||||
fused_emitter.GetGenerator(fused_computation->root_instruction()));
|
||||
|
||||
Shape element_shape = TypeToShape(fusion_outputs[0].getType());
|
||||
LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
|
||||
element_shape, ir_emitter_context_->gpu_device_info(), unroll_factor);
|
||||
@ -1022,14 +1027,14 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
|
||||
GpuElementalIrEmitter operand_elemental_emitter(
|
||||
hlo_module_config_, ir_emitter_context_->llvm_module(), &b_,
|
||||
GetNestedComputer());
|
||||
FusedIrEmitter operand_fused_emitter(
|
||||
GetGeneratorForOperandIrArrays(fusion),
|
||||
&operand_elemental_emitter);
|
||||
TF_RETURN_IF_ERROR(
|
||||
root->mutable_operand(0)->Accept(&operand_fused_emitter));
|
||||
FusedIrEmitter operand_fused_emitter(&operand_elemental_emitter);
|
||||
BindFusionArguments(fusion, &operand_fused_emitter);
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto generator,
|
||||
operand_fused_emitter.GetGenerator(root->operand(0)));
|
||||
|
||||
TF_RETURN_IF_ERROR(EmitTargetElementLoopInThunk(
|
||||
*fusion, operand_fused_emitter.GetGenerator(root->operand(0)),
|
||||
*fusion, generator,
|
||||
static_cast<KernelThunk*>(thunks.back().get()),
|
||||
ComputeMaxUnrollFactor(fusion)));
|
||||
}
|
||||
@ -1044,10 +1049,8 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
|
||||
GpuElementalIrEmitter scatter_elemental_emitter(
|
||||
hlo_module_config_, ir_emitter_context_->llvm_module(), &b_,
|
||||
GetNestedComputer());
|
||||
FusedIrEmitter scatter_fused_emitter(
|
||||
GetGeneratorForOperandIrArrays(fusion),
|
||||
&scatter_elemental_emitter);
|
||||
TF_RETURN_IF_ERROR(root->Accept(&scatter_fused_emitter));
|
||||
FusedIrEmitter scatter_fused_emitter(&scatter_elemental_emitter);
|
||||
BindFusionArguments(fusion, &scatter_fused_emitter);
|
||||
CHECK_EQ(root->parent()->FusionInstruction(), fusion);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
@ -1063,10 +1066,12 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
|
||||
desc.unique_indices = root->unique_indices();
|
||||
desc.update_computation = root->called_computations()[0];
|
||||
desc.output = GetIrArray(*fusion, *fusion);
|
||||
desc.scatter_indices_gen =
|
||||
scatter_fused_emitter.GetGenerator(root->operand(1));
|
||||
desc.updates_gen =
|
||||
scatter_fused_emitter.GetGenerator(root->operand(2));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
desc.scatter_indices_gen,
|
||||
scatter_fused_emitter.GetGenerator(root->operand(1)));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
desc.updates_gen,
|
||||
scatter_fused_emitter.GetGenerator(root->operand(2)));
|
||||
desc.get_index_type = [&](int64 launch_size) {
|
||||
return GetIndexTypeForKernel(root, launch_size, &b_);
|
||||
};
|
||||
@ -1133,9 +1138,11 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
|
||||
ir_emitter_context_->llvm_module());
|
||||
AddThunkToThunkSequence(std::move(fusion_thunk));
|
||||
|
||||
FusedIrEmitter fused_emitter(&elemental_emitter);
|
||||
BindFusionArguments(fusion, &fused_emitter);
|
||||
|
||||
return llvm_ir::EmitParallelFusedDynamicUpdateSliceInPlace(
|
||||
fusion, GetGeneratorForOperandIrArrays(fusion), output_array,
|
||||
&elemental_emitter, launch_dimensions, &b_);
|
||||
fusion, output_array, &fused_emitter, launch_dimensions, &b_);
|
||||
}
|
||||
|
||||
CHECK_EQ(fusion->fusion_kind(), HloInstruction::FusionKind::kLoop)
|
||||
@ -2596,13 +2603,13 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
|
||||
ir_emitter_context_->llvm_module(),
|
||||
&b_, GetNestedComputer());
|
||||
|
||||
FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(hlo),
|
||||
&elemental_emitter);
|
||||
TF_RETURN_IF_ERROR(init_value_operand->Accept(&fused_emitter));
|
||||
TF_RETURN_IF_ERROR(
|
||||
ParallelLoopEmitter(fused_emitter.GetGenerator(init_value_operand),
|
||||
GetIrArray(*hlo, *hlo, index), launch_dimensions,
|
||||
&b_)
|
||||
FusedIrEmitter fused_emitter(&elemental_emitter);
|
||||
BindFusionArguments(hlo, &fused_emitter);
|
||||
TF_ASSIGN_OR_RETURN(auto generator,
|
||||
fused_emitter.GetGenerator(init_value_operand));
|
||||
TF_RETURN_IF_ERROR(ParallelLoopEmitter(generator,
|
||||
GetIrArray(*hlo, *hlo, index),
|
||||
launch_dimensions, &b_)
|
||||
.EmitLoop(IrName(hlo)));
|
||||
} else {
|
||||
// In the unfused case the element is already there, just read from it.
|
||||
@ -3097,15 +3104,35 @@ void IrEmitterUnnested::EmitTileElementForFusion(
|
||||
std::vector<IrArray> output_arrays = ConstructIrArrayForOutputs(*hlo);
|
||||
GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
|
||||
GetNestedComputer());
|
||||
FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(hlo),
|
||||
&elem_emitter, x_loc, y_loc,
|
||||
param_shmem_buffers);
|
||||
|
||||
TF_CHECK_OK(hlo->fused_expression_root()->Accept(&fused_emitter));
|
||||
FusedIrEmitter fused_emitter(&elem_emitter);
|
||||
for (int i = 0; i < hlo->operand_count(); i++) {
|
||||
llvm_ir::ElementGenerator gen;
|
||||
if (llvm::Value* param_tile_buffer = param_shmem_buffers[i]) {
|
||||
gen = [this, param_tile_buffer, x_loc,
|
||||
y_loc](llvm_ir::IrArray::Index index) {
|
||||
// TODO(jlebar): Add AA metadata to this load. Tile buffers are
|
||||
// global variables, so LLVM's points-to analysis doesn't help us
|
||||
// much. And we want the AA info to be present before address
|
||||
// spaces are inferred (which is pretty late in the pipeline), so
|
||||
// even if we had address-space-based AA in LLVM, it wouldn't help
|
||||
// us much here.
|
||||
return b_.CreateLoad(
|
||||
b_.CreateGEP(param_tile_buffer,
|
||||
{index.GetConstantWithIndexType(0), x_loc, y_loc}),
|
||||
"tiled_buffer");
|
||||
};
|
||||
} else {
|
||||
const HloInstruction* operand = hlo->operand(i);
|
||||
gen = [this, operand, hlo](llvm_ir::IrArray::Index index) {
|
||||
return GetIrArray(*operand, *hlo).EmitReadArrayElement(index, &b_);
|
||||
};
|
||||
}
|
||||
fused_emitter.BindGenerator(hlo->fused_parameter(i), std::move(gen));
|
||||
}
|
||||
IrArray::Index untiled_index = GetUnnormalizedIndex(
|
||||
index, output_arrays[0].GetShape(), &b_, mapping_scheme);
|
||||
const llvm_ir::ElementGenerator& output_generator =
|
||||
fused_emitter.GetRootGenerator();
|
||||
llvm_ir::ElementGenerator output_generator =
|
||||
*fused_emitter.GetGenerator(hlo->fused_expression_root());
|
||||
llvm::Value* output_value = output_generator(untiled_index).ValueOrDie();
|
||||
if (hlo->IsMultiOutputFusion()) {
|
||||
DCHECK(output_value->getType()->isStructTy());
|
||||
@ -3161,13 +3188,11 @@ void IrEmitterUnnested::EmitPrologueForReduction(
|
||||
llvm::Value* init_ir_value;
|
||||
const HloInstruction* init_value = reduce_inst->operand(1);
|
||||
if (unnested_hlo->opcode() == HloOpcode::kFusion) {
|
||||
FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(unnested_hlo),
|
||||
&elemental_emitter);
|
||||
FusedIrEmitter fused_emitter(&elemental_emitter);
|
||||
BindFusionArguments(unnested_hlo, &fused_emitter);
|
||||
|
||||
TF_CHECK_OK(init_value->Accept(&fused_emitter));
|
||||
init_ir_value =
|
||||
fused_emitter
|
||||
.GetGenerator(init_value)(IrArray::Index(b_.getInt32Ty()))
|
||||
init_ir_value = (*fused_emitter.GetGenerator(init_value))(
|
||||
IrArray::Index(b_.getInt32Ty()))
|
||||
.ValueOrDie();
|
||||
} else {
|
||||
init_ir_value =
|
||||
@ -3507,21 +3532,21 @@ void IrEmitterUnnested::EmitTileElementForReduction(
|
||||
extra_output_gens;
|
||||
GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
|
||||
GetNestedComputer());
|
||||
FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(unnested_hlo),
|
||||
&elem_emitter);
|
||||
FusedIrEmitter fused_emitter(&elem_emitter);
|
||||
|
||||
// Construct the ElementGenerator for each reduction and extra output in the
|
||||
// the group of output instructions.
|
||||
if (unnested_hlo->opcode() == HloOpcode::kFusion) {
|
||||
TF_CHECK_OK(unnested_hlo->fused_expression_root()->Accept(&fused_emitter));
|
||||
BindFusionArguments(unnested_hlo, &fused_emitter);
|
||||
|
||||
for (int i = 0, e = output_instructions.size(); i != e; ++i) {
|
||||
const HloInstruction* inst = output_instructions[i];
|
||||
ShapeIndex idx =
|
||||
CreateShapeIndexForOutputInstruction(*unnested_hlo, *inst);
|
||||
if (IsReductionFromOrToContiguousDimensions(*inst)) {
|
||||
input_gens.push_back(fused_emitter.GetGenerator(inst->operand(0)));
|
||||
input_gens.push_back(*fused_emitter.GetGenerator(inst->operand(0)));
|
||||
} else {
|
||||
extra_output_gens.emplace_back(fused_emitter.GetGenerator(inst),
|
||||
extra_output_gens.emplace_back(*fused_emitter.GetGenerator(inst),
|
||||
std::move(idx));
|
||||
}
|
||||
}
|
||||
@ -4506,11 +4531,10 @@ void IrEmitterUnnested::EmitElementForInputFusibleSlices(
|
||||
std::vector<llvm::Value*> input_ir_values;
|
||||
GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
|
||||
GetNestedComputer());
|
||||
FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(unnested_hlo),
|
||||
&elem_emitter);
|
||||
TF_CHECK_OK(unnested_hlo->fused_expression_root()->Accept(&fused_emitter));
|
||||
FusedIrEmitter fused_emitter(&elem_emitter);
|
||||
BindFusionArguments(unnested_hlo, &fused_emitter);
|
||||
for (const HloInstruction* slice : slice_instructions) {
|
||||
auto input_generator = fused_emitter.GetGenerator(slice->operand(0));
|
||||
auto input_generator = *fused_emitter.GetGenerator(slice->operand(0));
|
||||
input_ir_values.push_back(input_generator(index).ValueOrDie());
|
||||
}
|
||||
|
||||
|
@ -17,7 +17,6 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
|
||||
|
||||
@ -190,9 +189,8 @@ Status EmitDynamicUpdateSliceInPlace(absl::Span<const IrArray> operand_arrays,
|
||||
//
|
||||
// Emits a sequential loop if launch_dimensions is null.
|
||||
static Status EmitFusedDynamicUpdateSliceInPlaceImpl(
|
||||
HloInstruction* fusion,
|
||||
GeneratorForOperandIrArrays operand_arrays_generator,
|
||||
const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
|
||||
HloInstruction* fusion, const IrArray& fusion_output_array,
|
||||
FusedIrEmitter* fused_emitter,
|
||||
const gpu::LaunchDimensions* launch_dimensions, llvm::IRBuilder<>* b) {
|
||||
CHECK_EQ(fusion->opcode(), HloOpcode::kFusion);
|
||||
VLOG(2) << "EmitFusedDynamicUpdateSliceInPlace for "
|
||||
@ -221,14 +219,14 @@ static Status EmitFusedDynamicUpdateSliceInPlaceImpl(
|
||||
LayoutUtil::CopyLayoutBetweenShapes(fusion->shape(), &update_shape));
|
||||
|
||||
// Create element generators for update and start_indices.
|
||||
FusedIrEmitter fused_emitter(std::move(operand_arrays_generator),
|
||||
elemental_emitter);
|
||||
TF_RETURN_IF_ERROR(dynamic_update_slice->Accept(&fused_emitter));
|
||||
ElementGenerator update_array_generator = fused_emitter.GetGenerator(update);
|
||||
TF_ASSIGN_OR_RETURN(ElementGenerator update_array_generator,
|
||||
fused_emitter->GetGenerator(update));
|
||||
|
||||
IndexGenerator start_indices_generator = [&](int64 index) {
|
||||
ElementGenerator element_generator =
|
||||
fused_emitter.GetGenerator(dynamic_update_slice->operand(2 + index));
|
||||
IndexGenerator start_indices_generator =
|
||||
[&](int64 index) -> StatusOr<llvm::Value*> {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
ElementGenerator element_generator,
|
||||
fused_emitter->GetGenerator(dynamic_update_slice->operand(2 + index)));
|
||||
return element_generator(IrArray::Index(b->getInt64Ty()));
|
||||
};
|
||||
bool is_signed = ShapeUtil::ElementIsSigned(start_indices->shape());
|
||||
@ -237,25 +235,21 @@ static Status EmitFusedDynamicUpdateSliceInPlaceImpl(
|
||||
fusion_output_array, launch_dimensions, IrName(fusion), b);
|
||||
}
|
||||
|
||||
Status EmitFusedDynamicUpdateSliceInPlace(
|
||||
HloInstruction* fusion,
|
||||
GeneratorForOperandIrArrays operand_arrays_generator,
|
||||
const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
|
||||
Status EmitFusedDynamicUpdateSliceInPlace(HloInstruction* fusion,
|
||||
const IrArray& fusion_output_array,
|
||||
FusedIrEmitter* fused_emitter,
|
||||
llvm::IRBuilder<>* b) {
|
||||
return EmitFusedDynamicUpdateSliceInPlaceImpl(
|
||||
fusion, std::move(operand_arrays_generator), fusion_output_array,
|
||||
elemental_emitter,
|
||||
fusion, fusion_output_array, fused_emitter,
|
||||
/*launch_dimensions=*/nullptr, b);
|
||||
}
|
||||
|
||||
Status EmitParallelFusedDynamicUpdateSliceInPlace(
|
||||
HloInstruction* fusion,
|
||||
GeneratorForOperandIrArrays operand_arrays_generator,
|
||||
const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
|
||||
HloInstruction* fusion, const IrArray& fusion_output_array,
|
||||
FusedIrEmitter* fused_emitter,
|
||||
const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b) {
|
||||
return EmitFusedDynamicUpdateSliceInPlaceImpl(
|
||||
fusion, std::move(operand_arrays_generator), fusion_output_array,
|
||||
elemental_emitter, &launch_dimensions, b);
|
||||
fusion, fusion_output_array, fused_emitter, &launch_dimensions, b);
|
||||
}
|
||||
|
||||
} // namespace llvm_ir
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
|
||||
|
||||
// Utilities related to emitting LLVM IR for various HLO ops.
|
||||
@ -71,18 +72,16 @@ Status EmitDynamicUpdateSliceInPlace(absl::Span<const IrArray> operand_arrays,
|
||||
// array-to-be-updated and output share the same buffer slice, emits
|
||||
// (sequential) code for a fusion node that does the dynamic-update-slice in
|
||||
// place.
|
||||
Status EmitFusedDynamicUpdateSliceInPlace(
|
||||
HloInstruction* fusion,
|
||||
GeneratorForOperandIrArrays operand_arrays_generator,
|
||||
const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
|
||||
Status EmitFusedDynamicUpdateSliceInPlace(HloInstruction* fusion,
|
||||
const IrArray& fusion_output_array,
|
||||
FusedIrEmitter* fused_emitter,
|
||||
llvm::IRBuilder<>* b);
|
||||
|
||||
// Same as EmitFusedDynamicUpdateSliceInPlace, except emits a parallel loop with
|
||||
// the given launch dimensions.
|
||||
Status EmitParallelFusedDynamicUpdateSliceInPlace(
|
||||
HloInstruction* fusion,
|
||||
GeneratorForOperandIrArrays operand_arrays_generator,
|
||||
const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
|
||||
HloInstruction* fusion, const IrArray& fusion_output_array,
|
||||
FusedIrEmitter* fused_emitter,
|
||||
const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b);
|
||||
|
||||
} // namespace llvm_ir
|
||||
|
@ -114,25 +114,9 @@ Status FusedIrEmitter::HandleGetTupleElement(
|
||||
}
|
||||
|
||||
Status FusedIrEmitter::HandleParameter(const HloInstruction* parameter) {
|
||||
indexed_generators_[parameter] =
|
||||
[=](const IrArray::Index& index) -> llvm::Value* {
|
||||
int64 param_num = parameter->parameter_number();
|
||||
if (param_shmem_buffers_.size() > param_num) {
|
||||
if (llvm::Value* param_tile_buffer = param_shmem_buffers_[param_num]) {
|
||||
// TODO(jlebar): Add AA metadata to this load. Tile buffers are global
|
||||
// variables, so LLVM's points-to analysis doesn't help us much. And we
|
||||
// want the AA info to be present before address spaces are inferred
|
||||
// (which is pretty late in the pipeline), so even if we had
|
||||
// address-space-based AA in LLVM, it wouldn't help us much here.
|
||||
return b_->CreateLoad(
|
||||
b_->CreateGEP(param_tile_buffer, {index.GetConstantWithIndexType(0),
|
||||
thread_id_x_, thread_id_y_}),
|
||||
"tiled_buffer");
|
||||
if (indexed_generators_.find(parameter) == indexed_generators_.end()) {
|
||||
return InvalidArgument("Unbound parameter: %s", parameter->ToString());
|
||||
}
|
||||
}
|
||||
return GetIrArrayForFusedParameter(param_num).EmitReadArrayElement(index,
|
||||
b_);
|
||||
};
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -157,22 +141,6 @@ Status FusedIrEmitter::HandleTuple(const HloInstruction* tuple) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FusedIrEmitter::FinishVisit(const HloInstruction* root) {
|
||||
fused_root_ = root;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
FusedIrEmitter::IndexedGenerator FusedIrEmitter::GetRootGenerator() const {
|
||||
CHECK_NE(nullptr, fused_root_)
|
||||
<< "GetRootGenerator should be called after Accept.";
|
||||
return indexed_generators_.at(fused_root_);
|
||||
}
|
||||
|
||||
FusedIrEmitter::IndexedGenerator FusedIrEmitter::GetGenerator(
|
||||
const HloInstruction* instruction) const {
|
||||
return indexed_generators_.at(instruction);
|
||||
}
|
||||
|
||||
bool FusedIrEmitter::IsFusedIrEmitterInefficient(
|
||||
const HloInstruction* consumer, const HloInstruction* producer) {
|
||||
if (consumer->opcode() != HloOpcode::kFusion) {
|
||||
@ -189,4 +157,39 @@ bool FusedIrEmitter::IsFusedIrEmitterInefficient(
|
||||
return eval_producer.MaxCodeDuplicationTooHigh();
|
||||
}
|
||||
|
||||
StatusOr<FusedIrEmitter::IndexedGenerator> FusedIrEmitter::GetGenerator(
|
||||
const HloInstruction* instruction) {
|
||||
std::vector<const HloInstruction*> stack;
|
||||
stack.push_back(instruction);
|
||||
while (!stack.empty()) {
|
||||
const HloInstruction* instr = stack.back();
|
||||
stack.pop_back();
|
||||
if (indexed_generators_.count(instr)) {
|
||||
continue;
|
||||
}
|
||||
for (const HloInstruction* operand : instr->operands()) {
|
||||
stack.push_back(operand);
|
||||
}
|
||||
switch (instr->opcode()) {
|
||||
case HloOpcode::kConstant:
|
||||
TF_RETURN_IF_ERROR(HandleConstant(instr));
|
||||
break;
|
||||
case HloOpcode::kGetTupleElement:
|
||||
TF_RETURN_IF_ERROR(HandleGetTupleElement(instr));
|
||||
break;
|
||||
case HloOpcode::kParameter:
|
||||
TF_RETURN_IF_ERROR(HandleParameter(instr));
|
||||
break;
|
||||
case HloOpcode::kTuple:
|
||||
TF_RETURN_IF_ERROR(HandleTuple(instr));
|
||||
break;
|
||||
default:
|
||||
TF_RETURN_IF_ERROR(DefaultAction(instr));
|
||||
break;
|
||||
}
|
||||
CHECK(indexed_generators_.count(instr));
|
||||
}
|
||||
return indexed_generators_.at(instruction);
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -24,7 +24,6 @@ limitations under the License.
|
||||
#include "absl/types/span.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "llvm/IR/Value.h"
|
||||
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
|
||||
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
|
||||
@ -51,47 +50,22 @@ namespace xla {
|
||||
// created produces an LLVM struct with N elements, one for each element of the
|
||||
// arrays in the tuple. It follows that the arrays in the tuple must have the
|
||||
// same length.
|
||||
class FusedIrEmitter : public ConstDfsHloVisitorWithDefault {
|
||||
class FusedIrEmitter {
|
||||
public:
|
||||
using IndexedGenerator = llvm_ir::ElementGenerator;
|
||||
using NonIndexedGenerator = std::function<StatusOr<llvm::Value*>()>;
|
||||
using GeneratorForOperandIrArrays =
|
||||
std::function<std::vector<llvm_ir::IrArray>()>;
|
||||
|
||||
FusedIrEmitter(GeneratorForOperandIrArrays operand_arrays_generator,
|
||||
ElementalIrEmitter* elemental_emitter,
|
||||
llvm::Value* thread_id_x = nullptr,
|
||||
llvm::Value* thread_id_y = nullptr,
|
||||
absl::Span<llvm::Value* const> param_shmem_buffers = {})
|
||||
: operand_arrays_(),
|
||||
operand_arrays_generator_(std::move(operand_arrays_generator)),
|
||||
thread_id_x_(thread_id_x),
|
||||
thread_id_y_(thread_id_y),
|
||||
param_shmem_buffers_(param_shmem_buffers.begin(),
|
||||
param_shmem_buffers.end()),
|
||||
elemental_emitter_(elemental_emitter),
|
||||
explicit FusedIrEmitter(ElementalIrEmitter* elemental_emitter)
|
||||
: elemental_emitter_(elemental_emitter),
|
||||
b_(elemental_emitter->b()),
|
||||
module_(elemental_emitter->module()) {}
|
||||
|
||||
Status DefaultAction(const HloInstruction* hlo) override;
|
||||
|
||||
Status HandleConstant(const HloInstruction* constant) override;
|
||||
|
||||
Status HandleGetTupleElement(
|
||||
const HloInstruction* get_tuple_element) override;
|
||||
|
||||
Status HandleParameter(const HloInstruction* parameter) override;
|
||||
|
||||
// Emits the ir value for each element in the tuple.
|
||||
Status HandleTuple(const HloInstruction* tuple) override;
|
||||
|
||||
Status FinishVisit(const HloInstruction* root) override;
|
||||
|
||||
// Returns the generator function for the root of the fused computation.
|
||||
IndexedGenerator GetRootGenerator() const;
|
||||
void BindGenerator(const HloInstruction* hlo,
|
||||
llvm_ir::ElementGenerator generator) {
|
||||
indexed_generators_[hlo] = std::move(generator);
|
||||
}
|
||||
|
||||
// Returns the generator function for the given instruction.
|
||||
IndexedGenerator GetGenerator(const HloInstruction* instruction) const;
|
||||
StatusOr<IndexedGenerator> GetGenerator(const HloInstruction* instruction);
|
||||
|
||||
// Evaluates whether fusing 'producer' into 'consumer' might cause exponential
|
||||
// behavior in FusedIrEmitter. We currently can have exponential time/memory
|
||||
@ -101,40 +75,20 @@ class FusedIrEmitter : public ConstDfsHloVisitorWithDefault {
|
||||
static bool IsFusedIrEmitterInefficient(const HloInstruction* consumer,
|
||||
const HloInstruction* producer);
|
||||
|
||||
protected:
|
||||
// Returns the IrArrays for the fusion instruction operands.
|
||||
llvm_ir::IrArray& GetIrArrayForFusedParameter(int64 parameter_number) {
|
||||
if (!operand_arrays_.has_value()) {
|
||||
operand_arrays_ = operand_arrays_generator_();
|
||||
}
|
||||
return operand_arrays_.value()[parameter_number];
|
||||
}
|
||||
|
||||
llvm::Value* GetBasePointerForFusedParameter(int64 parameter_number) {
|
||||
return GetIrArrayForFusedParameter(parameter_number).GetBasePointer();
|
||||
}
|
||||
|
||||
private:
|
||||
// IrArrays for the fusion instruction operands, whose base addresses are the
|
||||
// base address of the corresponding parameters in the fused computation.
|
||||
absl::optional<std::vector<llvm_ir::IrArray>> operand_arrays_;
|
||||
GeneratorForOperandIrArrays operand_arrays_generator_;
|
||||
Status DefaultAction(const HloInstruction* hlo);
|
||||
|
||||
// The x coordinate within a tile.
|
||||
llvm::Value* thread_id_x_;
|
||||
Status HandleConstant(const HloInstruction* constant);
|
||||
|
||||
// The y coordinate within a tile.
|
||||
llvm::Value* thread_id_y_;
|
||||
Status HandleGetTupleElement(const HloInstruction* get_tuple_element);
|
||||
|
||||
// Param_buffers_[i] stores the tile buffer for the ith parameter or nullptr
|
||||
// if the parameter is not tiled.
|
||||
std::vector<llvm::Value*> param_shmem_buffers_;
|
||||
Status HandleParameter(const HloInstruction* parameter);
|
||||
|
||||
// Emits the ir value for each element in the tuple.
|
||||
Status HandleTuple(const HloInstruction* tuple);
|
||||
|
||||
ElementalIrEmitter* elemental_emitter_;
|
||||
|
||||
// This member will be set by FinishVisit and used in GetRootGenerator.
|
||||
const HloInstruction* fused_root_ = nullptr;
|
||||
|
||||
// Borrowed
|
||||
llvm::IRBuilder<>* b_;
|
||||
llvm::Module* module_;
|
||||
@ -145,12 +99,6 @@ class FusedIrEmitter : public ConstDfsHloVisitorWithDefault {
|
||||
std::unordered_map<const HloInstruction*, IndexedGenerator>
|
||||
indexed_generators_;
|
||||
|
||||
// Map from tuple-result-producing GetTupleELement instructions to functions
|
||||
// that generate the base pointers for the output elements. This is used to
|
||||
// support the translation of nested GetTupleElement instructions.
|
||||
std::unordered_map<const HloInstruction*, NonIndexedGenerator>
|
||||
non_indexed_generators_;
|
||||
|
||||
// Cache of generated values, lest we regenerate an element of a node with
|
||||
// multiple outgoing edges
|
||||
absl::flat_hash_map<
|
||||
|
@ -521,8 +521,7 @@ XLA_TEST_F(ConcatTest, ConcatDeeplyNested) {
ComputeAndCompareR1<float>(&builder, expected, {a_data.get()});
}

// TODO(b/169314478): Enable the test when the slow compilation is fixed.
XLA_TEST_F(ConcatTestHlo, DISABLED_ConcatWithBitcast) {
XLA_TEST_F(ConcatTestHlo, ConcatWithBitcast) {
auto module = ParseAndReturnVerifiedModule(R"(
HloModule jit_broken.874

@ -762,7 +761,7 @@ ENTRY jit_broken.874 {
auto input_array = absl::make_unique<Array2D<float>>(4, 2);
input_array->FillUnique(1.0f);
auto input = LiteralUtil::CreateR2FromArray2D<float>(*input_array);
EXPECT_TRUE(RunAndCompare(std::move(module), {&input}, absl::nullopt));
EXPECT_TRUE(RunAndCompare(std::move(module), {&input}, error_spec_));
}

// Describes a binary rank-2 concatenation test.

@ -354,8 +354,11 @@ void BaseCollectiveExecutor::CompleteParamsAsync(
// timeout callback executes, done_safe will become a no-op and the timeout
// callback is responsible for invoking done() at the end.
const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
auto done_safe = [this, is_callback_called, cancel_mgr,
auto trace_id =
profiler::TraceMe::ActivityStart("CollectiveExecutor::CompleteParams");
auto done_safe = [this, is_callback_called, cancel_mgr, trace_id,
done](const Status& s) {
profiler::TraceMe::ActivityEnd(trace_id);
bool called = is_callback_called->exchange(true);
if (!called) {
if (!s.ok() && !IsCancelled(cancel_mgr)) {

@ -2587,11 +2587,9 @@ TEST(DirectSessionTest,
|
||||
|
||||
// A simple benchmark for the overhead of `DirectSession::Run()` calls
|
||||
// with varying numbers of feeds/fetches.
|
||||
void FeedFetchBenchmarkHelper(int iters, int num_feeds, bool use_make_callable,
|
||||
int inter_op_threads,
|
||||
void FeedFetchBenchmarkHelper(::testing::benchmark::State& state, int num_feeds,
|
||||
bool use_make_callable, int inter_op_threads,
|
||||
bool use_single_threaded_executor) {
|
||||
testing::StopTiming();
|
||||
|
||||
Tensor value(DT_FLOAT, TensorShape());
|
||||
value.flat<float>()(0) = 37.0;
|
||||
|
||||
@ -2643,13 +2641,11 @@ void FeedFetchBenchmarkHelper(int iters, int num_feeds, bool use_make_callable,
|
||||
}
|
||||
TF_CHECK_OK(session->MakeCallable(callable_options, &handle));
|
||||
|
||||
testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
for (auto s : state) {
|
||||
std::vector<Tensor> output_values;
|
||||
TF_CHECK_OK(
|
||||
session->RunCallable(handle, input_tensors, &output_values, nullptr));
|
||||
}
|
||||
testing::StopTiming();
|
||||
} else {
|
||||
{
|
||||
// NOTE(mrry): Ignore the first run, which will incur the graph
|
||||
@ -2661,32 +2657,40 @@ void FeedFetchBenchmarkHelper(int iters, int num_feeds, bool use_make_callable,
|
||||
std::vector<Tensor> output_values;
|
||||
TF_CHECK_OK(session->Run(inputs, outputs, {}, &output_values));
|
||||
}
|
||||
testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
|
||||
for (auto s : state) {
|
||||
std::vector<Tensor> output_values;
|
||||
TF_CHECK_OK(session->Run(inputs, outputs, {}, &output_values));
|
||||
}
|
||||
testing::StopTiming();
|
||||
}
|
||||
}
|
||||
|
||||
void BM_FeedFetch(int iters, int num_feeds) {
|
||||
FeedFetchBenchmarkHelper(iters, num_feeds, /* use_make_callable */ false,
|
||||
void BM_FeedFetch(::testing::benchmark::State& state) {
|
||||
const int num_feeds = state.range(0);
|
||||
|
||||
FeedFetchBenchmarkHelper(state, num_feeds, /* use_make_callable */ false,
|
||||
/* inter_op_threads */ 0,
|
||||
/* use_single_threaded_executor */ false);
|
||||
}
|
||||
void BM_FeedFetchCallable(int iters, int num_feeds) {
|
||||
FeedFetchBenchmarkHelper(iters, num_feeds, /* use_make_callable */ true,
|
||||
void BM_FeedFetchCallable(::testing::benchmark::State& state) {
|
||||
const int num_feeds = state.range(0);
|
||||
|
||||
FeedFetchBenchmarkHelper(state, num_feeds, /* use_make_callable */ true,
|
||||
/* inter_op_threads */ 0,
|
||||
/* use_single_threaded_executor */ false);
|
||||
}
|
||||
void BM_FeedFetchCallableSingleThread(int iters, int num_feeds) {
|
||||
FeedFetchBenchmarkHelper(iters, num_feeds, /* use_make_callable */ true,
|
||||
void BM_FeedFetchCallableSingleThread(::testing::benchmark::State& state) {
|
||||
const int num_feeds = state.range(0);
|
||||
|
||||
FeedFetchBenchmarkHelper(state, num_feeds, /* use_make_callable */ true,
|
||||
/* inter_op_threads */ -1,
|
||||
/* use_single_threaded_executor */ false);
|
||||
}
|
||||
void BM_FeedFetchCallableSingleThreadExecutor(int iters, int num_feeds) {
|
||||
FeedFetchBenchmarkHelper(iters, num_feeds, /* use_make_callable */ true,
|
||||
void BM_FeedFetchCallableSingleThreadExecutor(
|
||||
::testing::benchmark::State& state) {
|
||||
const int num_feeds = state.range(0);
|
||||
|
||||
FeedFetchBenchmarkHelper(state, num_feeds, /* use_make_callable */ true,
|
||||
/* inter_op_threads */ -1,
|
||||
/* use_single_threaded_executor */ true);
|
||||
}
|
||||
|
@ -378,6 +378,12 @@ Status EagerOperation::InferInputListAttrs(int num_inputs) {
} else if (!input_def.type_attr().empty() &&
!input_def.number_attr().empty()) {
InferSingleTypeInputListAttrs(input_def, inputs_[start]->dtype, num_inputs);
} else if (!input_def.number_attr().empty()) {
if (inference_attrs_.find(input_def.number_attr()) ==
inference_attrs_.end()) {
MutableAttrs()->Set(input_def.number_attr(), num_inputs);
inference_attrs_.insert(input_def.number_attr());
}
} else {
return errors::InvalidArgument("Invalid input list definition");
}

@ -69,8 +69,8 @@ class TestEnv {
|
||||
Device* cpu_device_;
|
||||
};
|
||||
|
||||
void BM_CreateGraph(int iters) {
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
void BM_CreateGraph(::testing::benchmark::State& state) {
|
||||
for (auto s : state) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
auto C = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}});
|
||||
auto M = ops::MatMul(root, C, C);
|
||||
@ -79,8 +79,7 @@ void BM_CreateGraph(int iters) {
|
||||
}
|
||||
BENCHMARK(BM_CreateGraph);
|
||||
|
||||
void BM_RunGraph(int iters) {
|
||||
tensorflow::testing::StopTiming();
|
||||
void BM_RunGraph(::testing::benchmark::State& state) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
auto C = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}});
|
||||
auto M = ops::MatMul(root, C, C);
|
||||
@ -89,28 +88,24 @@ void BM_RunGraph(int iters) {
|
||||
opts.config.set_intra_op_parallelism_threads(1);
|
||||
ClientSession sess(root, opts);
|
||||
std::vector<Tensor> outputs;
|
||||
tensorflow::testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
for (auto s : state) {
|
||||
outputs.clear();
|
||||
TF_CHECK_OK(sess.Run({M}, &outputs));
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_RunGraph);
|
||||
|
||||
void BM_CreateAndDestroySession(int iters) {
|
||||
tensorflow::testing::StopTiming();
|
||||
void BM_CreateAndDestroySession(::testing::benchmark::State& state) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
auto C = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}});
|
||||
auto M = ops::MatMul(root, C, C);
|
||||
tensorflow::testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
for (auto s : state) {
|
||||
ClientSession sess(root);
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_CreateAndDestroySession);
|
||||
|
||||
void BM_KernelAndDeviceInit(int iters) {
|
||||
tensorflow::testing::StopTiming();
|
||||
void BM_KernelAndDeviceInit(::testing::benchmark::State& state) {
|
||||
NodeDef ndef(AttrBuilder("MatMul")
|
||||
.Set("T", DT_FLOAT)
|
||||
.Set("transpose_a", false)
|
||||
@ -120,15 +115,13 @@ void BM_KernelAndDeviceInit(int iters) {
|
||||
TestEnv env;
|
||||
KernelAndDeviceOp k(nullptr, false, env.function_library_runtime(), nullptr,
|
||||
nullptr, env.cpu_device());
|
||||
tensorflow::testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
for (auto s : state) {
|
||||
TF_CHECK_OK(k.Init({}, ndef, nullptr));
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_KernelAndDeviceInit);
|
||||
|
||||
void BM_KernelAndDeviceRun(int iters) {
|
||||
tensorflow::testing::StopTiming();
|
||||
void BM_KernelAndDeviceRun(::testing::benchmark::State& state) {
|
||||
Tensor t(Input({{1.0f, 2.0f}, {3.0f, 4.0f}}).tensor());
|
||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||
inputs.push_back(TensorValue(&t));
|
||||
@ -145,8 +138,7 @@ void BM_KernelAndDeviceRun(int iters) {
|
||||
nullptr, env.cpu_device());
|
||||
TF_CHECK_OK(k.Init({}, ndef, nullptr));
|
||||
const EagerKernelArgs args(std::move(inputs));
|
||||
tensorflow::testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
for (auto s : state) {
|
||||
TF_CHECK_OK(k.Run(nullptr, args, &outputs, nullptr, absl::nullopt));
|
||||
}
|
||||
}
|
||||
|
@ -433,11 +433,10 @@ TEST_F(ExecutorTest, NoInputTensors) {
|
||||
// Create a graph that is 'depth' deep. At each level, fan-in and fan-out a
|
||||
// maximum of 'width' nodes. All nodes are no-ops and all dependencies are
|
||||
// control dependencies.
|
||||
static void BM_executor(int iters, int width, int depth) {
|
||||
testing::StopTiming();
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
BenchmarkUseRealTime();
|
||||
#endif // PLATFORM_GOOGLE
|
||||
static void BM_executor(::testing::benchmark::State& state) {
|
||||
const int width = state.range(0);
|
||||
const int depth = state.range(1);
|
||||
|
||||
Graph* g = new Graph(OpRegistry::Global());
|
||||
random::PhiloxRandom philox(1729, 17);
|
||||
random::SimplePhilox rand(&philox);
|
||||
@ -466,30 +465,29 @@ static void BM_executor(int iters, int width, int depth) {
|
||||
++cur;
|
||||
}
|
||||
}
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
SetBenchmarkLabel(strings::StrCat("Nodes = ", cur));
|
||||
SetBenchmarkItemsProcessed(cur * static_cast<int64>(iters));
|
||||
#endif // PLATFORM_GOOGLE
|
||||
|
||||
FixupSourceAndSinkEdges(g);
|
||||
testing::StartTiming();
|
||||
test::Benchmark("cpu", g).Run(iters);
|
||||
test::Benchmark("cpu", g, /*old_benchmark_api=*/false).Run(state);
|
||||
|
||||
state.SetLabel(strings::StrCat("Nodes = ", cur));
|
||||
state.SetItemsProcessed(cur * static_cast<int64>(state.iterations()));
|
||||
}
|
||||
|
||||
// Tall skinny graphs
|
||||
BENCHMARK(BM_executor)->ArgPair(16, 1024);
|
||||
BENCHMARK(BM_executor)->ArgPair(32, 8192);
|
||||
BENCHMARK(BM_executor)->UseRealTime()->ArgPair(16, 1024);
|
||||
BENCHMARK(BM_executor)->UseRealTime()->ArgPair(32, 8192);
|
||||
|
||||
// Short fat graphs
|
||||
BENCHMARK(BM_executor)->ArgPair(1024, 16);
|
||||
BENCHMARK(BM_executor)->ArgPair(8192, 32);
|
||||
BENCHMARK(BM_executor)->UseRealTime()->ArgPair(1024, 16);
|
||||
BENCHMARK(BM_executor)->UseRealTime()->ArgPair(8192, 32);
|
||||
|
||||
// Tall fat graph
|
||||
BENCHMARK(BM_executor)->ArgPair(1024, 1024);
|
||||
BENCHMARK(BM_executor)->UseRealTime()->ArgPair(1024, 1024);
|
||||
|
||||
static void BM_const_identity(::testing::benchmark::State& state) {
|
||||
const int width = state.range(0);
|
||||
const int outputs_per_const = state.range(1);
|
||||
|
||||
static void BM_const_identity(int iters, int width, int outputs_per_const) {
|
||||
#ifdef PLATFORM_GOOGL
|
||||
BenchmarkUseRealTime();
|
||||
#endif // PLATFORM_GOOGLE
|
||||
Graph* g = new Graph(OpRegistry::Global());
|
||||
for (int i = 0; i < width; ++i) {
|
||||
Tensor i_t(i);
|
||||
@ -499,23 +497,21 @@ static void BM_const_identity(int iters, int width, int outputs_per_const) {
|
||||
}
|
||||
}
|
||||
FixupSourceAndSinkEdges(g);
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
SetBenchmarkLabel(
|
||||
strings::StrCat("Nodes = ", (1 + outputs_per_const) * width));
|
||||
SetBenchmarkItemsProcessed((1 + outputs_per_const) * width *
|
||||
static_cast<int64>(iters));
|
||||
#endif // PLATFORM_GOOGLE
|
||||
test::Benchmark("cpu", g).Run(iters);
|
||||
test::Benchmark("cpu", g, /*old_benchmark_api=*/false).Run(state);
|
||||
state.SetLabel(strings::StrCat("Nodes = ", (1 + outputs_per_const) * width));
|
||||
state.SetItemsProcessed((1 + outputs_per_const) * width *
|
||||
static_cast<int64>(state.iterations()));
|
||||
}
|
||||
|
||||
// Graph with actual op execution.
|
||||
BENCHMARK(BM_const_identity)->ArgPair(1, 1);
|
||||
BENCHMARK(BM_const_identity)->ArgPair(1, 100);
|
||||
BENCHMARK(BM_const_identity)->ArgPair(100, 1);
|
||||
BENCHMARK(BM_const_identity)->ArgPair(100, 100);
|
||||
BENCHMARK(BM_const_identity)
|
||||
->UseRealTime()
|
||||
->ArgPair(1, 1)
|
||||
->ArgPair(1, 100)
|
||||
->ArgPair(100, 1)
|
||||
->ArgPair(100, 100);
|
||||
|
||||
static void BM_FeedInputFetchOutput(int iters) {
|
||||
testing::StopTiming();
|
||||
static void BM_FeedInputFetchOutput(::testing::benchmark::State& state) {
|
||||
Graph* g = new Graph(OpRegistry::Global());
|
||||
// z = x + y: x and y are provided as benchmark inputs. z is the
|
||||
// output of the benchmark. Conceptually, the caller is ALICE, the
|
||||
@ -531,13 +527,10 @@ static void BM_FeedInputFetchOutput(int iters) {
|
||||
|
||||
Tensor val(DT_FLOAT, TensorShape({}));
|
||||
val.scalar<float>()() = 3.14;
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
SetBenchmarkItemsProcessed(static_cast<int64>(iters));
|
||||
#endif // PLATFORM_GOOGLE
|
||||
FixupSourceAndSinkEdges(g);
|
||||
testing::StartTiming();
|
||||
test::Benchmark("cpu", g).RunWithRendezvousArgs({{x_key, val}, {y_key, val}},
|
||||
{z_key}, iters);
|
||||
test::Benchmark("cpu", g, /*old_benchmark_api=*/false)
|
||||
.RunWithRendezvousArgs({{x_key, val}, {y_key, val}}, {z_key}, state);
|
||||
state.SetItemsProcessed(static_cast<int64>(state.iterations()));
|
||||
}
|
||||
BENCHMARK(BM_FeedInputFetchOutput);
|
||||
|
||||
@ -549,9 +542,8 @@ BENCHMARK(BM_FeedInputFetchOutput);
|
||||
//
|
||||
// ...using the functional `WhileOp` (if `lower` is false) or the
|
||||
// `Switch`/`Merge`-style of control flow (if `lower` is true).
|
||||
static void BM_WhileLoopHelper(int iters, int loop_iters, int loop_vars,
|
||||
bool lower) {
|
||||
testing::StopTiming();
|
||||
static void BM_WhileLoopHelper(::testing::benchmark::State& state,
|
||||
int loop_iters, int loop_vars, bool lower) {
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
|
||||
// Add test functions for cond and body.
|
||||
@ -661,12 +653,15 @@ static void BM_WhileLoopHelper(int iters, int loop_iters, int loop_vars,
|
||||
}
|
||||
|
||||
FixupSourceAndSinkEdges(graph.get());
|
||||
testing::StartTiming();
|
||||
test::Benchmark("cpu", graph.release()).Run(iters);
|
||||
test::Benchmark("cpu", graph.release(), /*old_benchmark_api=*/false)
|
||||
.Run(state);
|
||||
}
|
||||
|
||||
static void BM_LoweredWhileLoop(int iters, int loop_iters, int loop_vars) {
|
||||
BM_WhileLoopHelper(iters, loop_iters, loop_vars, /* lower= */ true);
|
||||
static void BM_LoweredWhileLoop(::testing::benchmark::State& state) {
|
||||
const int loop_iters = state.range(0);
|
||||
const int loop_vars = state.range(1);
|
||||
|
||||
BM_WhileLoopHelper(state, loop_iters, loop_vars, /* lower= */ true);
|
||||
}
|
||||
BENCHMARK(BM_LoweredWhileLoop)
|
||||
->ArgPair(0, 1)
|
||||
@ -680,8 +675,11 @@ BENCHMARK(BM_LoweredWhileLoop)
|
||||
->ArgPair(100, 100)
|
||||
->ArgPair(1000, 100);
|
||||
|
||||
static void BM_FunctionalWhileLoop(int iters, int loop_iters, int loop_vars) {
|
||||
BM_WhileLoopHelper(iters, loop_iters, loop_vars, /* lower= */ false);
|
||||
static void BM_FunctionalWhileLoop(::testing::benchmark::State& state) {
|
||||
const int loop_iters = state.range(0);
|
||||
const int loop_vars = state.range(1);
|
||||
|
||||
BM_WhileLoopHelper(state, loop_iters, loop_vars, /* lower= */ false);
|
||||
}
|
||||
BENCHMARK(BM_FunctionalWhileLoop)
|
||||
->ArgPair(0, 1)
|
||||
|
@ -931,7 +931,6 @@ TEST(SessionTest, InvalidOpInputName) {
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'transpose_a' value { b: false } }
attr { key: 'transpose_b' value { b: false } }
attr { key: '_kernel' value { s: 'eigen' } }
}
)",
"Illegal op input name");
@ -950,7 +949,6 @@ TEST(SessionTest, InvalidOpInputName) {
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'transpose_a' value { b: false } }
attr { key: 'transpose_b' value { b: false } }
attr { key: '_kernel' value { s: 'eigen' } }
}
)",
"Illegal op input name");
@ -969,7 +967,6 @@ TEST(SessionTest, InvalidOpInputName) {
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'transpose_a' value { b: false } }
attr { key: 'transpose_b' value { b: false } }
attr { key: '_kernel' value { s: 'eigen' } }
}
)",
"Illegal op input name");
@ -988,7 +985,6 @@ TEST(SessionTest, InvalidOpInputName) {
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'transpose_a' value { b: false } }
attr { key: 'transpose_b' value { b: false } }
attr { key: '_kernel' value { s: 'eigen' } }
}
)",
"Illegal op input name");
@ -1026,7 +1022,6 @@ TEST(SessionTest, ExtendValidation) {
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'transpose_a' value { b: false } }
attr { key: 'transpose_b' value { b: false } }
attr { key: '_kernel' value { s: 'eigen' } }
}
)",
&extension);
@ -1043,7 +1038,6 @@ TEST(SessionTest, ExtendValidation) {
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'transpose_a' value { b: false } }
attr { key: 'transpose_b' value { b: false } }
attr { key: '_kernel' value { s: 'eigen' } }
}
)",
&extension);
@ -1057,7 +1051,6 @@ TEST(SessionTest, ExtendValidation) {
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'transpose_a' value { b: false } }
attr { key: 'transpose_b' value { b: false } }
attr { key: '_kernel' value { s: 'eigen' } }
}
)",
&extension);

@ -221,14 +221,16 @@ TEST(CustomAllocatorAttributes, TestSetterAndGetter) {
EXPECT_FALSE(HasDeviceAllocatorAttribute(AllocatorAttributes()));
}

static void BM_Allocation(int iters, int arg) {
static void BM_Allocation(::testing::benchmark::State& state) {
const int arg = state.range(0);

Allocator* a = cpu_allocator();
// Exercise a few different allocation sizes
std::vector<int> sizes = {256, 4096, 16384, 524288, 512, 1048576};
int size_index = 0;

if (arg) EnableCPUAllocatorStats();
while (--iters > 0) {
for (auto s : state) {
int bytes = sizes[size_index++ % sizes.size()];
void* p = a->AllocateRaw(1, bytes);
a->DeallocateRaw(p);

@ -39,60 +39,60 @@ TEST(Bfloat16Test, Conversion) {
}
}

static void BM_FloatToBFloat16(int iters) {
testing::StopTiming();
void BM_FloatToBFloat16(::testing::benchmark::State& state) {
static const int N = 32 << 20;
const int64 tot = static_cast<int64>(iters) * N;
testing::ItemsProcessed(tot);
testing::BytesProcessed(tot * (sizeof(float) + sizeof(bfloat16)));

float* inp = new float[N];
bfloat16* out = new bfloat16[N];

testing::StartTiming();
while (iters--) {
for (auto s : state) {
FloatToBFloat16(inp, out, N);
}

const int64 tot = static_cast<int64>(state.iterations()) * N;
state.SetItemsProcessed(tot);
state.SetBytesProcessed(tot * (sizeof(float) + sizeof(bfloat16)));

delete[] inp;
delete[] out;
}
BENCHMARK(BM_FloatToBFloat16);

static void BM_RoundFloatToBFloat16(int iters) {
testing::StopTiming();
void BM_RoundFloatToBFloat16(::testing::benchmark::State& state) {
static const int N = 32 << 20;
const int64 tot = static_cast<int64>(iters) * N;
testing::ItemsProcessed(tot);
testing::BytesProcessed(tot * (sizeof(float) + sizeof(bfloat16)));

float* inp = new float[N];
bfloat16* out = new bfloat16[N];

testing::StartTiming();
while (iters--) {
for (auto s : state) {
RoundFloatToBFloat16(inp, out, N);
tensorflow::testing::DoNotOptimize(inp);
tensorflow::testing::DoNotOptimize(out);
}

const int64 tot = static_cast<int64>(state.iterations()) * N;
state.SetItemsProcessed(tot);
state.SetBytesProcessed(tot * (sizeof(float) + sizeof(bfloat16)));

delete[] inp;
delete[] out;
}
BENCHMARK(BM_RoundFloatToBFloat16);

static void BM_BFloat16ToFloat(int iters) {
testing::StopTiming();
void BM_BFloat16ToFloat(::testing::benchmark::State& state) {
static const int N = 32 << 20;
const int64 tot = static_cast<int64>(iters) * N;
testing::ItemsProcessed(tot);
testing::BytesProcessed(tot * (sizeof(float) + sizeof(bfloat16)));

bfloat16* inp = new bfloat16[N];
float* out = new float[N];

testing::StartTiming();
while (iters--) {
for (auto s : state) {
BFloat16ToFloat(inp, out, N);
}

const int64 tot = static_cast<int64>(state.iterations()) * N;
state.SetItemsProcessed(tot);
state.SetBytesProcessed(tot * (sizeof(float) + sizeof(bfloat16)));

delete[] inp;
delete[] out;
}

@ -406,7 +406,7 @@ XTimesTwo[T:{float, double, int32, int64}](x:T) -> (y:T) {
TEST(TFunc, WXPlusB) {
auto expect = R"P(
WXPlusB[T:{float, double}](w:T, x:T, b:T) -> (y:T) {
mm = MatMul[T=$T, _kernel="eigen", transpose_a=false, transpose_b=false](w, x)
mm = MatMul[T=$T, transpose_a=false, transpose_b=false](w, x)
y = Add[T=$T](mm:product:0, b)
return y = y:z:0
}

@ -346,10 +346,7 @@ FunctionDef WXPlusB() {
{{{"mm"},
"MatMul",
{"w", "x"},
{{"T", "$T"},
{"transpose_a", false},
{"transpose_b", false},
{"_kernel", "eigen"}}},
{{"T", "$T"}, {"transpose_a", false}, {"transpose_b", false}}},
{{"y"}, "Add", {"mm", "b"}, {{"T", "$T"}}}});
}

@ -1002,9 +1002,9 @@ TEST_F(LabelTest, Duplicate) {
|
||||
error::INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
void BM_InputRangeHelper(int iters, const NodeDef& node_def,
|
||||
const char* input_name, int expected_start,
|
||||
int expected_stop) {
|
||||
void BM_InputRangeHelper(::testing::benchmark::State& state,
|
||||
const NodeDef& node_def, const char* input_name,
|
||||
int expected_start, int expected_stop) {
|
||||
Status status;
|
||||
auto device = absl::make_unique<DummyDevice>(Env::Default());
|
||||
|
||||
@ -1013,24 +1013,20 @@ void BM_InputRangeHelper(int iters, const NodeDef& node_def,
|
||||
TF_GRAPH_DEF_VERSION, &status));
|
||||
TF_CHECK_OK(status);
|
||||
|
||||
testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
for (auto s : state) {
|
||||
int start;
|
||||
int stop;
|
||||
TF_CHECK_OK(op->InputRange(input_name, &start, &stop));
|
||||
EXPECT_EQ(expected_start, start);
|
||||
EXPECT_EQ(expected_stop, stop);
|
||||
}
|
||||
testing::StopTiming();
|
||||
}
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("ConcatV2").Device(DEVICE_CPU), DummyKernel);
|
||||
REGISTER_KERNEL_BUILDER(Name("Select").Device(DEVICE_CPU), DummyKernel);
|
||||
REGISTER_KERNEL_BUILDER(Name("MatMul").Device(DEVICE_CPU), DummyKernel);
|
||||
|
||||
void BM_ConcatInputRange(int iters) {
|
||||
testing::StopTiming();
|
||||
|
||||
void BM_ConcatInputRange(::testing::benchmark::State& state) {
|
||||
// Create a ConcatV2 NodeDef with 4 inputs (plus the axis).
|
||||
NodeDef node_def;
|
||||
node_def.set_name("concat-op");
|
||||
@ -1048,12 +1044,10 @@ void BM_ConcatInputRange(int iters) {
|
||||
node_def.add_input(strings::StrCat("a:", i));
|
||||
}
|
||||
|
||||
BM_InputRangeHelper(iters, node_def, "values", 0, 4);
|
||||
BM_InputRangeHelper(state, node_def, "values", 0, 4);
|
||||
}
|
||||
|
||||
void BM_SelectInputRange(int iters) {
|
||||
testing::StopTiming();
|
||||
|
||||
void BM_SelectInputRange(::testing::benchmark::State& state) {
|
||||
// Create a Select NodeDef with 3 inputs.
|
||||
NodeDef node_def;
|
||||
node_def.set_name("select-op");
|
||||
@ -1065,11 +1059,11 @@ void BM_SelectInputRange(int iters) {
|
||||
node_def.add_input(strings::StrCat("a:", i));
|
||||
}
|
||||
|
||||
BM_InputRangeHelper(iters, node_def, "condition", 0, 1);
|
||||
BM_InputRangeHelper(state, node_def, "condition", 0, 1);
|
||||
}
|
||||
|
||||
void BM_TraceString(const int iters, const int verbose) {
|
||||
testing::StopTiming();
|
||||
void BM_TraceString(::testing::benchmark::State& state) {
|
||||
const int verbose = state.range(0);
|
||||
|
||||
// Create a MatMul NodeDef with 2 inputs.
|
||||
NodeDef node_def;
|
||||
@ -1103,11 +1097,9 @@ void BM_TraceString(const int iters, const int verbose) {
|
||||
params.inputs = &inputs;
|
||||
auto ctx = absl::make_unique<OpKernelContext>(¶ms);
|
||||
|
||||
testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
for (auto s : state) {
|
||||
auto trace = op->TraceString(*ctx, verbose);
|
||||
}
|
||||
testing::StopTiming();
|
||||
}
|
||||
|
||||
BENCHMARK(BM_ConcatInputRange);
|
||||
|
@ -434,37 +434,38 @@ TEST_F(LocalRendezvousTest, TransferDummyDeviceContext) {
|
||||
args1.device_context->Unref();
|
||||
}
|
||||
|
||||
void BM_SendRecv(int iters) {
|
||||
void BM_SendRecv(::testing::benchmark::State& state) {
|
||||
Rendezvous* rendez = NewLocalRendezvous();
|
||||
Tensor orig = V("val");
|
||||
Tensor val(DT_STRING, TensorShape({}));
|
||||
bool is_dead = false;
|
||||
Rendezvous::Args args;
|
||||
if (iters > 0) {
|
||||
while (iters--) {
|
||||
|
||||
for (auto s : state) {
|
||||
TF_CHECK_OK(rendez->Send(KeyFoo(), args, orig, is_dead));
|
||||
TF_CHECK_OK(rendez->Recv(KeyFoo(), args, &val, &is_dead));
|
||||
}
|
||||
CHECK_EQ(V(val), V(orig));
|
||||
}
|
||||
|
||||
rendez->Unref();
|
||||
}
|
||||
BENCHMARK(BM_SendRecv);
|
||||
|
||||
void BM_RecvSend(int iters) {
|
||||
void BM_RecvSend(::testing::benchmark::State& state) {
|
||||
Rendezvous* rendez = NewLocalRendezvous();
|
||||
Tensor orig = V("val");
|
||||
Tensor val(DT_STRING, TensorShape({}));
|
||||
bool is_dead = false;
|
||||
Rendezvous::Args args;
|
||||
if (iters > 0) {
|
||||
while (iters--) {
|
||||
|
||||
for (auto s : state) {
|
||||
bool received = false;
|
||||
rendez->RecvAsync(
|
||||
KeyFoo(), args,
|
||||
[&val, &received](const Status& s, const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args,
|
||||
const Tensor& tensor, bool is_dead) {
|
||||
[&val, &received](const Status& /*s*/,
|
||||
const Rendezvous::Args& /*send_args*/,
|
||||
const Rendezvous::Args& /*recv_args*/,
|
||||
const Tensor& tensor, bool /*is_dead*/) {
|
||||
val = tensor;
|
||||
received = true;
|
||||
});
|
||||
@ -472,26 +473,29 @@ void BM_RecvSend(int iters) {
|
||||
CHECK(received);
|
||||
}
|
||||
CHECK_EQ(V(val), V(orig));
|
||||
}
|
||||
|
||||
rendez->Unref();
|
||||
}
|
||||
BENCHMARK(BM_RecvSend);
|
||||
|
||||
void BM_PingPong(int iters) {
|
||||
CHECK_GT(iters, 0);
|
||||
void BM_PingPong(::testing::benchmark::State& state) {
|
||||
const int messages_count = state.range(0);
|
||||
auto* cm = new CancellationManager();
|
||||
thread::ThreadPool* pool = new thread::ThreadPool(Env::Default(), "test", 1);
|
||||
|
||||
// The main thread sends "foo" for iters times and receives "bar"
|
||||
// for iters times. The other thread sends "bar" for iters times
|
||||
// and receives "foo" for iters times.
|
||||
// Benchmark loop
|
||||
// In each iteration:
|
||||
// The main thread sends "foo" for messages_count times and receives "bar"
|
||||
// for messages_count times. The other thread sends "bar" for
|
||||
// messages_count times and receives "foo" for messages_count times.
|
||||
for (auto s : state) {
|
||||
Rendezvous* rendez = NewLocalRendezvous();
|
||||
pool->Schedule([rendez, iters]() {
|
||||
pool->Schedule([rendez, messages_count]() {
|
||||
Tensor bar = V("bar");
|
||||
Tensor foo(DT_STRING, TensorShape({}));
|
||||
bool is_dead = false;
|
||||
Rendezvous::Args args;
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
for (int i = 0; i < messages_count; ++i) {
|
||||
TF_CHECK_OK(rendez->Recv(KeyFoo(), args, &foo, &is_dead));
|
||||
TF_CHECK_OK(rendez->Send(KeyBar(), args, bar, is_dead));
|
||||
}
|
||||
@ -502,15 +506,17 @@ void BM_PingPong(int iters) {
|
||||
bool is_dead = false;
|
||||
Rendezvous::Args args;
|
||||
args.cancellation_manager = cm;
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
for (int i = 0; i < messages_count; ++i) {
|
||||
TF_CHECK_OK(rendez->Send(KeyFoo(), args, foo, is_dead));
|
||||
TF_CHECK_OK(rendez->Recv(KeyBar(), args, &bar, &is_dead));
|
||||
}
|
||||
CHECK_EQ("bar", V(bar));
|
||||
}
|
||||
state.SetItemsProcessed(messages_count * state.iterations());
|
||||
delete pool;
|
||||
delete cm;
|
||||
}
|
||||
BENCHMARK(BM_PingPong);
|
||||
BENCHMARK(BM_PingPong)->Arg(100)->Arg(200)->Arg(300);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -684,19 +684,24 @@ static std::vector<int64> MakeSizes(int arg) {
return sizes;
}

static void BM_TensorShape_Init(int iters, int arg) {
void BM_TensorShape_Init(::testing::benchmark::State& state) {
const int arg = state.range(0);

auto sizes = MakeSizes(arg);
while (--iters > 0) {
for (auto s : state) {
TensorShape shape(sizes);
tensorflow::testing::DoNotOptimize(shape.num_elements());
}
}
BENCHMARK(BM_TensorShape_Init)->Arg(0)->Arg(1)->Arg(2)->Arg(3)->Arg(4);

static void BM_TensorShape_Assign(int iters, int arg) {
TensorShape s(MakeSizes(arg));
while (--iters > 0) {
TensorShape s2 = s;
void BM_TensorShape_Assign(::testing::benchmark::State& state) {
const int arg = state.range(0);

TensorShape shape(MakeSizes(arg));
for (auto s : state) {
const TensorShape s2 = shape;
tensorflow::testing::DoNotOptimize(s2);
}
}
BENCHMARK(BM_TensorShape_Assign)->Arg(0)->Arg(1)->Arg(2)->Arg(3)->Arg(4);

@ -1468,19 +1468,19 @@ TEST(SummarizeValue, STRING_PRINT_V2) {
|
||||
x.SummarizeValue(16, true));
|
||||
}
|
||||
|
||||
void BM_CreateAndDestroy(int iters) {
|
||||
void BM_CreateAndDestroy(::testing::benchmark::State& state) {
|
||||
TensorShape shape({10, 20});
|
||||
while (--iters) {
|
||||
for (auto s : state) {
|
||||
Tensor t(DT_FLOAT, shape);
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_CreateAndDestroy);
|
||||
|
||||
void BM_Assign(int iters) {
|
||||
void BM_Assign(::testing::benchmark::State& state) {
|
||||
Tensor a(DT_FLOAT, TensorShape({10, 20}));
|
||||
Tensor b(DT_FLOAT, TensorShape({10, 20}));
|
||||
bool a_to_b = true;
|
||||
while (--iters) {
|
||||
for (auto s : state) {
|
||||
if (a_to_b) {
|
||||
b = a;
|
||||
} else {
|
||||
@ -1498,20 +1498,20 @@ TEST(Tensor, EmptyTensorData) {
|
||||
}
|
||||
|
||||
// Benchmark create and destroy a tensor, with an allocated buffer.
|
||||
void BM_CreateAndDestroyWithBuf(int iters) {
|
||||
void BM_CreateAndDestroyWithBuf(::testing::benchmark::State& state) {
|
||||
TensorShape shape({10, 20});
|
||||
Allocator* allocator = cpu_allocator();
|
||||
while (--iters) {
|
||||
for (auto s : state) {
|
||||
Tensor a(allocator, DT_FLOAT, shape);
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_CreateAndDestroyWithBuf);
|
||||
|
||||
// Benchmark create+copy a tensor, with an allocated buffer.
|
||||
void BM_CreateAndCopyCtrWithBuf(int iters) {
|
||||
void BM_CreateAndCopyCtrWithBuf(::testing::benchmark::State& state) {
|
||||
TensorShape shape({10, 20});
|
||||
Allocator* allocator = cpu_allocator();
|
||||
while (--iters) {
|
||||
for (auto s : state) {
|
||||
Tensor a(allocator, DT_FLOAT, shape);
|
||||
Tensor b(a);
|
||||
}
|
||||
@ -1519,10 +1519,10 @@ void BM_CreateAndCopyCtrWithBuf(int iters) {
|
||||
BENCHMARK(BM_CreateAndCopyCtrWithBuf);
|
||||
|
||||
// Benchmark create+move a tensor, with an allocated buffer.
|
||||
void BM_CreateAndMoveCtrWithBuf(int iters) {
|
||||
void BM_CreateAndMoveCtrWithBuf(::testing::benchmark::State& state) {
|
||||
TensorShape shape({10, 20});
|
||||
Allocator* allocator = cpu_allocator();
|
||||
while (--iters) {
|
||||
for (auto s : state) {
|
||||
Tensor a(allocator, DT_FLOAT, shape);
|
||||
Tensor b(std::move(a));
|
||||
}
|
||||
@ -1531,10 +1531,11 @@ BENCHMARK(BM_CreateAndMoveCtrWithBuf);
|
||||
|
||||
// Benchmark creating and destroy a host-scalar tensor, using the allocator
|
||||
// interface.
|
||||
void BM_CreateAndDestroyHostScalarNonOptimized(int iters) {
|
||||
void BM_CreateAndDestroyHostScalarNonOptimized(
|
||||
::testing::benchmark::State& state) {
|
||||
TensorShape shape({});
|
||||
Allocator* allocator = cpu_allocator();
|
||||
while (--iters) {
|
||||
for (auto s : state) {
|
||||
Tensor a(allocator, DT_FLOAT, shape);
|
||||
a.scalar<float>()() = 37.0;
|
||||
}
|
||||
@ -1543,32 +1544,33 @@ BENCHMARK(BM_CreateAndDestroyHostScalarNonOptimized);
|
||||
|
||||
// Benchmark creating and destroy a host-scalar tensor, using the specialized
|
||||
// constructor.
|
||||
void BM_CreateAndDestroyHostScalarOptimized(int iters) {
|
||||
while (--iters) {
|
||||
void BM_CreateAndDestroyHostScalarOptimized(
|
||||
::testing::benchmark::State& state) {
|
||||
for (auto s : state) {
|
||||
Tensor a(37.0);
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_CreateAndDestroyHostScalarOptimized);
|
||||
|
||||
static void BM_FromProto(int iters, int size) {
|
||||
testing::StopTiming();
|
||||
void BM_FromProto(::testing::benchmark::State& state) {
|
||||
const int size = state.range(0);
|
||||
|
||||
TensorShape shape({size});
|
||||
Allocator* allocator = cpu_allocator();
|
||||
Tensor a(allocator, DT_FLOAT, shape);
|
||||
std::fill_n(a.flat<float>().data(), size, 42.0);
|
||||
TensorProto p;
|
||||
a.AsProtoField(&p);
|
||||
testing::StartTiming();
|
||||
while (--iters) {
|
||||
for (auto s : state) {
|
||||
Tensor b;
|
||||
ASSERT_TRUE(b.FromProto(p));
|
||||
}
|
||||
testing::StopTiming();
|
||||
}
|
||||
BENCHMARK(BM_FromProto)->Range(1, 1 << 20);
|
||||
|
||||
static void BM_FromProtoCompressed(int iters, int size) {
|
||||
testing::StopTiming();
|
||||
void BM_FromProtoCompressed(::testing::benchmark::State& state) {
|
||||
const int size = state.range(0);
|
||||
|
||||
TensorShape shape({size});
|
||||
Allocator* allocator = cpu_allocator();
|
||||
Tensor a(allocator, DT_FLOAT, shape);
|
||||
@ -1576,17 +1578,16 @@ static void BM_FromProtoCompressed(int iters, int size) {
|
||||
TensorProto p;
|
||||
a.AsProtoField(&p);
|
||||
tensor::CompressTensorProtoInPlace(&p);
|
||||
testing::StartTiming();
|
||||
while (--iters) {
|
||||
for (auto s : state) {
|
||||
Tensor b;
|
||||
ASSERT_TRUE(b.FromProto(p));
|
||||
}
|
||||
testing::StopTiming();
|
||||
}
|
||||
BENCHMARK(BM_FromProtoCompressed)->Range(1, 1 << 20);
|
||||
|
||||
static void BM_FromProtoCompressedZero(int iters, int size) {
|
||||
testing::StopTiming();
|
||||
void BM_FromProtoCompressedZero(::testing::benchmark::State& state) {
|
||||
const int size = state.range(0);
|
||||
|
||||
TensorShape shape({size});
|
||||
Allocator* allocator = cpu_allocator();
|
||||
Tensor a(allocator, DT_FLOAT, shape);
|
||||
@ -1595,12 +1596,10 @@ static void BM_FromProtoCompressedZero(int iters, int size) {
|
||||
TensorProto p;
|
||||
a.AsProtoField(&p);
|
||||
tensor::CompressTensorProtoInPlace(&p);
|
||||
testing::StartTiming();
|
||||
while (--iters) {
|
||||
for (auto s : state) {
|
||||
Tensor b;
|
||||
ASSERT_TRUE(b.FromProto(p));
|
||||
}
|
||||
testing::StopTiming();
|
||||
}
|
||||
BENCHMARK(BM_FromProtoCompressedZero)->Range(1, 1 << 20);
|
||||
|
||||
|
@ -767,16 +767,49 @@ Status AvgPoolGradTransposer::TransposeNode(TransposeContext* context,
|
||||
return context->graph_view->GetMutationBuilder()->Apply();
|
||||
}
|
||||
|
||||
Status BiasAddGradTransposer::TransposeNode(TransposeContext* context,
|
||||
Status BiasAddTransposer::TransposeNode(TransposeContext* context,
|
||||
utils::MutableNodeView* node) {
|
||||
DCHECK(IsBiasAddGrad(*node->node()));
|
||||
if (!ShouldProcess(*context, *node) || !IsFaninPortRankN(*node, 0, 4)) {
|
||||
DCHECK(IsBiasAdd(*node->node()));
|
||||
const int rank = GetFanoutPortRank(*node, 0);
|
||||
if (rank != 4 && rank != 5) {
|
||||
return Status::OK();
|
||||
}
|
||||
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, rank)) {
|
||||
return Status::OK();
|
||||
}
|
||||
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
|
||||
<< "' with op '" << node->GetOp() << "' from data format '"
|
||||
<< context->src_format << "' to '" << context->dst_format << "'";
|
||||
// BiasAdd itself only needs NCHW/NHWC to determine whether C dim is the
|
||||
// second or the last dim. Therefore, we use the original 4D data format in
|
||||
// the context to update the node. For the input/output tensor, the
|
||||
// corresponding 4D or 5D data format is needed.
|
||||
TF_RETURN_IF_ERROR(UpdateNode(context, node));
|
||||
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
|
||||
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
|
||||
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
|
||||
return context->graph_view->GetMutationBuilder()->Apply();
|
||||
}
|
||||
|
||||
Status BiasAddGradTransposer::TransposeNode(TransposeContext* context,
|
||||
utils::MutableNodeView* node) {
|
||||
DCHECK(IsBiasAddGrad(*node->node()));
|
||||
const int rank = GetFaninPortRank(*node, 0);
|
||||
if (rank != 4 && rank != 5) {
|
||||
return Status::OK();
|
||||
}
|
||||
if (!ShouldProcess(*context, *node)) {
|
||||
return Status::OK();
|
||||
}
|
||||
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
|
||||
<< "' with op '" << node->GetOp() << "' from data format '"
|
||||
<< context->src_format << "' to '" << context->dst_format << "'";
|
||||
// BiasAddGrad itself only needs NCHW/NHWC to determine whether C dim is the
|
||||
// second or the last dim. Therefore, we use the original 4D data format in
|
||||
// the context to update the node. For the input tensor, the corresponding 4D
|
||||
// or 5D data format is needed.
|
||||
TF_RETURN_IF_ERROR(UpdateNode(context, node));
|
||||
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
|
||||
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
|
||||
// No need to update output shape, as it is always of shape 1-D with size the
|
||||
// feature dimension of `out_backprop`, regardless of whether NCHW or NHWC is
|
||||
@ -839,7 +872,12 @@ Status Conv2DBackpropInputTransposer::TransposeNode(
|
||||
Status Conv3DTransposer::TransposeNode(TransposeContext* context,
|
||||
utils::MutableNodeView* node) {
|
||||
DCHECK(IsConv3D(*node->node()));
|
||||
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 5)) {
|
||||
const int rank = GetFanoutPortRank(*node, 0);
|
||||
if (rank != 5) {
|
||||
return Status::OK();
|
||||
}
|
||||
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
|
||||
if (!ShouldProcess(*context, *node)) {
|
||||
return Status::OK();
|
||||
}
|
||||
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
|
||||
@ -854,7 +892,12 @@ Status Conv3DTransposer::TransposeNode(TransposeContext* context,
|
||||
Status Conv3DBackpropFilterTransposer::TransposeNode(
|
||||
TransposeContext* context, utils::MutableNodeView* node) {
|
||||
DCHECK(IsConv3DBackpropFilterV2(*node->node()));
|
||||
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 5)) {
|
||||
const int rank = GetFanoutPortRank(*node, 0);
|
||||
if (rank != 5) {
|
||||
return Status::OK();
|
||||
}
|
||||
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
|
||||
if (!ShouldProcess(*context, *node)) {
|
||||
return Status::OK();
|
||||
}
|
||||
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
|
||||
@ -872,7 +915,12 @@ Status Conv3DBackpropFilterTransposer::TransposeNode(
|
||||
Status Conv3DBackpropInputTransposer::TransposeNode(
|
||||
TransposeContext* context, utils::MutableNodeView* node) {
|
||||
DCHECK(IsConv3DBackpropInputV2(*node->node()));
|
||||
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 5)) {
|
||||
const int rank = GetFanoutPortRank(*node, 0);
|
||||
if (rank != 5) {
|
||||
return Status::OK();
|
||||
}
|
||||
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
|
||||
if (!ShouldProcess(*context, *node)) {
|
||||
return Status::OK();
|
||||
}
|
||||
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
|
||||
@ -1081,8 +1129,9 @@ bool LayoutAgnosticOpTransposer::IsAfterDstToSrcTransform(
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<int> LayoutAgnosticOpTransposer::GetVariadic4DFaninPorts(
|
||||
const TransposeContext& context, const utils::MutableNodeView& node) const {
|
||||
std::vector<int> LayoutAgnosticOpTransposer::GetVariadicNDFaninPorts(
|
||||
const TransposeContext& context, const utils::MutableNodeView& node,
|
||||
int rank) const {
|
||||
std::vector<int> ports;
|
||||
const int num_regular_fanins = node.NumRegularFanins();
|
||||
ports.reserve(num_regular_fanins);
|
||||
@ -1090,7 +1139,7 @@ std::vector<int> LayoutAgnosticOpTransposer::GetVariadic4DFaninPorts(
|
||||
const auto& regular_fanin = node.GetRegularFanin(i);
|
||||
auto* regular_fanin_node = regular_fanin.node_view();
|
||||
int regular_fanin_port = regular_fanin.index();
|
||||
if (IsFanoutPortRankN(*regular_fanin_node, regular_fanin_port, 4) &&
|
||||
if ((IsFanoutPortRankN(*regular_fanin_node, regular_fanin_port, rank)) &&
|
||||
((IsAfterDstToSrcTransform(context, *regular_fanin_node) &&
|
||||
IsLayoutAgnosticOp(*regular_fanin_node->node())) ||
|
||||
IsLayoutOptimizerAddedDstToSrcTranspose(context,
|
||||
@ -1124,10 +1173,18 @@ Status DefaultLayoutAgnosticOpTransposer::TransposeNode(
|
||||
Status AddNTransposer::TransposeNode(TransposeContext* context,
|
||||
utils::MutableNodeView* node) {
|
||||
DCHECK(IsAddN(*node->node()));
|
||||
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
|
||||
const int rank = GetFanoutPortRank(*node, 0);
|
||||
if (rank != 4 && rank != 5) {
|
||||
return Status::OK();
|
||||
}
|
||||
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
|
||||
if (!ShouldProcess(*context, *node) ||
|
||||
!IsAfterDstToSrcTransform(*context, *node)) {
|
||||
return Status::OK();
|
||||
}
|
||||
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
|
||||
<< "' with op '" << node->GetOp() << "' from data format '"
|
||||
<< context->src_format << "' to '" << context->dst_format << "'";
|
||||
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, GetDataFaninPorts(*node),
|
||||
node, kOpTranspose));
|
||||
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
|
||||
@ -1284,7 +1341,12 @@ Status BinaryOpTransposer::TransposeNode(TransposeContext* context,
|
||||
Status ConcatOpTransposer::TransposeNode(TransposeContext* context,
|
||||
utils::MutableNodeView* node) {
|
||||
DCHECK(IsConcat(*node->node()));
|
||||
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
|
||||
const int rank = GetFanoutPortRank(*node, 0);
|
||||
if (rank != 4 && rank != 5) {
|
||||
return Status::OK();
|
||||
}
|
||||
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
|
||||
if (!ShouldProcess(*context, *node) ||
|
||||
!IsAfterDstToSrcTransform(*context, *node)) {
|
||||
return Status::OK();
|
||||
}
|
||||
@ -1297,6 +1359,9 @@ Status ConcatOpTransposer::TransposeNode(TransposeContext* context,
|
||||
axis_node = n_attr->i();
|
||||
}
|
||||
}
|
||||
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
|
||||
<< "' with op '" << node->GetOp() << "' from data format '"
|
||||
<< context->src_format << "' to '" << context->dst_format << "'";
|
||||
TF_RETURN_IF_ERROR(
|
||||
UpdateFaninEdgesWithOp(context, {axis_node}, node, kOpDataFormatDimMap));
|
||||
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
|
||||
@ -1320,14 +1385,33 @@ Status FillOpTransposer::TransposeNode(TransposeContext* context,
|
||||
Status IdentityNTransposer::TransposeNode(TransposeContext* context,
|
||||
utils::MutableNodeView* node) {
|
||||
DCHECK(IsIdentityN(*node->node()));
|
||||
const auto ports = GetVariadic4DFaninPorts(*context, *node);
|
||||
if (!ShouldProcess(*context, *node) || ports.empty()) {
|
||||
const auto ports_4d = GetVariadicNDFaninPorts(*context, *node, 4);
|
||||
|
||||
// Temporarily upgrade the context to obtain the number of 5D fanin ports.
|
||||
std::vector<int> ports_5d;
|
||||
{
|
||||
ScopedDataFormatUpgrader data_format_upgrader(context, 5);
|
||||
ports_5d = GetVariadicNDFaninPorts(*context, *node, 5);
|
||||
}
|
||||
|
||||
if (!ShouldProcess(*context, *node)) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (!ports_4d.empty()) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
UpdateFaninEdgesWithOp(context, ports, node, kOpTranspose));
|
||||
UpdateFaninEdgesWithOp(context, ports_4d, node, kOpTranspose));
|
||||
TF_RETURN_IF_ERROR(
|
||||
UpdateFanoutEdgesWithOp(context, ports, node, kOpTranspose));
|
||||
UpdateFanoutEdgesWithOp(context, ports_4d, node, kOpTranspose));
|
||||
}
|
||||
|
||||
if (!ports_5d.empty()) {
|
||||
ScopedDataFormatUpgrader data_format_upgrader(context, 5);
|
||||
TF_RETURN_IF_ERROR(
|
||||
UpdateFaninEdgesWithOp(context, ports_5d, node, kOpTranspose));
|
||||
TF_RETURN_IF_ERROR(
|
||||
UpdateFanoutEdgesWithOp(context, ports_5d, node, kOpTranspose));
|
||||
}
|
||||
return context->graph_view->GetMutationBuilder()->Apply();
|
||||
}
|
||||
|
||||
@ -1519,10 +1603,18 @@ Status SelectTransposer::TransposeNode(TransposeContext* context,
|
||||
Status ShapeTransposer::TransposeNode(TransposeContext* context,
|
||||
utils::MutableNodeView* node) {
|
||||
DCHECK(IsShape(*node->node()));
|
||||
if (!ShouldProcess(*context, *node) || !IsFaninPortRankN(*node, 0, 4) ||
|
||||
const int rank = GetFaninPortRank(*node, 0);
|
||||
if (rank != 4 && rank != 5) {
|
||||
return Status::OK();
|
||||
}
|
||||
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
|
||||
if (!ShouldProcess(*context, *node) ||
|
||||
!IsAfterDstToSrcTransform(*context, *node)) {
|
||||
return Status::OK();
|
||||
}
|
||||
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
|
||||
<< "' with op '" << node->GetOp() << "' from data format '"
|
||||
<< context->src_format << "' to '" << context->dst_format << "'";
|
||||
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
|
||||
TF_RETURN_IF_ERROR(
|
||||
UpdateFanoutEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
|
||||
@ -1532,10 +1624,20 @@ Status ShapeTransposer::TransposeNode(TransposeContext* context,
|
||||
Status ShapeNTransposer::TransposeNode(TransposeContext* context,
|
||||
utils::MutableNodeView* node) {
|
||||
DCHECK(IsShapeN(*node->node()));
|
||||
const auto ports = GetVariadic4DFaninPorts(*context, *node);
|
||||
// ShapeN requires all input tensors to have the same dimensions. Therefore,
|
||||
// we simply use the 0th fanin port.
|
||||
const int rank = GetFaninPortRank(*node, 0);
|
||||
if (rank != 4 && rank != 5) {
|
||||
return Status::OK();
|
||||
}
|
||||
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
|
||||
const auto ports = GetVariadicNDFaninPorts(*context, *node, rank);
|
||||
if (!ShouldProcess(*context, *node) || ports.empty()) {
|
||||
return Status::OK();
|
||||
}
|
||||
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
|
||||
<< "' with op '" << node->GetOp() << "' from data format '"
|
||||
<< context->src_format << "' to '" << context->dst_format << "'";
|
||||
TF_RETURN_IF_ERROR(
|
||||
UpdateFaninEdgesWithOp(context, ports, node, kOpTranspose));
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -1546,11 +1648,19 @@ Status ShapeNTransposer::TransposeNode(TransposeContext* context,
|
||||
Status SliceTransposer::TransposeNode(TransposeContext* context,
|
||||
utils::MutableNodeView* node) {
|
||||
DCHECK(IsSlice(*node->node()));
|
||||
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
|
||||
!IsFaninPortsDimsNIfConst(*node, {1, 2}, {4}) ||
|
||||
const int rank = GetFanoutPortRank(*node, 0);
|
||||
if (rank != 4 && rank != 5) {
|
||||
return Status::OK();
|
||||
}
|
||||
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
|
||||
if (!ShouldProcess(*context, *node) ||
|
||||
!IsFaninPortsDimsNIfConst(*node, {1, 2}, {rank}) ||
|
||||
!IsAfterDstToSrcTransform(*context, *node)) {
|
||||
return Status::OK();
|
||||
}
|
||||
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
|
||||
<< "' with op '" << node->GetOp() << "' from data format '"
|
||||
<< context->src_format << "' to '" << context->dst_format << "'";
|
||||
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
|
||||
TF_RETURN_IF_ERROR(
|
||||
UpdateFaninEdgesWithOp(context, {1, 2}, node, kOpDataFormatVecPermute));
|
||||
@ -1839,18 +1949,17 @@ string GetDeviceName(const VirtualPlacer* virtual_placer, const NodeDef& node) {
|
||||
bool IsDefaultLayoutSensitiveOp(const NodeDef& node) {
|
||||
static absl::flat_hash_set<string>* default_layout_sensitive_ops =
|
||||
new absl::flat_hash_set<std::string>(
|
||||
{"AvgPool", "BiasAdd", "Conv2D", "DepthwiseConv2dNative",
|
||||
"DepthToSpace", "FusedBatchNorm", "FusedBatchNormV2",
|
||||
"FusedBatchNormV3", "FusedConv2DBiasActivation", "MaxPool",
|
||||
"SpaceToDepth"});
|
||||
{"AvgPool", "Conv2D", "DepthwiseConv2dNative", "DepthToSpace",
|
||||
"FusedBatchNorm", "FusedBatchNormV2", "FusedBatchNormV3",
|
||||
"FusedConv2DBiasActivation", "MaxPool", "SpaceToDepth"});
|
||||
return default_layout_sensitive_ops->find(node.op()) !=
|
||||
default_layout_sensitive_ops->end();
|
||||
}
|
||||
|
||||
bool IsLayoutSensitiveOp(const NodeDef& node) {
|
||||
return IsDefaultLayoutSensitiveOp(node) || IsAvgPoolGrad(node) ||
|
||||
IsBiasAddGrad(node) || IsConv2DBackpropFilter(node) ||
|
||||
IsConv2DBackpropInput(node) ||
|
||||
IsBiasAdd(node) || IsBiasAddGrad(node) ||
|
||||
IsConv2DBackpropFilter(node) || IsConv2DBackpropInput(node) ||
|
||||
IsDepthwiseConv2dNativeBackpropFilter(node) ||
|
||||
IsDepthwiseConv2dNativeBackpropInput(node) ||
|
||||
IsFusedBatchNormEx(node) || IsFusedBatchNormGrad(node) ||
|
||||
|
@ -210,6 +210,14 @@ class DefaultLayoutSensitiveOpTransposer : public LayoutSensitiveOpTransposer {
utils::MutableNodeView* node) override;
};

class BiasAddTransposer : public LayoutSensitiveOpTransposer {
public:
explicit BiasAddTransposer() : LayoutSensitiveOpTransposer() {}

Status TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) override;
};

class AvgPoolGradTransposer : public LayoutSensitiveOpTransposer {
public:
explicit AvgPoolGradTransposer() : LayoutSensitiveOpTransposer() {}
@ -319,9 +327,9 @@ class LayoutAgnosticOpTransposer : public Transposer {
bool IsAfterDstToSrcTransform(const TransposeContext& context,
const utils::MutableNodeView& node) const;

std::vector<int> GetVariadic4DFaninPorts(
const TransposeContext& context,
const utils::MutableNodeView& node) const;
std::vector<int> GetVariadicNDFaninPorts(const TransposeContext& context,
const utils::MutableNodeView& node,
int rank) const;
};

class DefaultLayoutAgnosticOpTransposer : public LayoutAgnosticOpTransposer {
@ -27,6 +27,9 @@ std::shared_ptr<Transposer> TransposerFactory::GetTransposer(
return GetOrCreateIfNotFound<DefaultLayoutSensitiveOpTransposer>(
"DefaultLayoutSensitiveOp");
}
if (IsBiasAdd(node)) {
return GetOrCreateIfNotFound<BiasAddTransposer>("BiasAdd");
}
if (IsAvgPoolGrad(node)) {
return GetOrCreateIfNotFound<AvgPoolGradTransposer>("AvgPoolGrad");
}
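The factory hunk above hands out one shared transposer instance per key via GetOrCreateIfNotFound. A self-contained sketch of that memoization pattern, with illustrative names only (the real member layout is not shown in this diff):

    // Generic "get or create, then cache" pattern a helper like
    // GetOrCreateIfNotFound implies; names here are illustrative.
    #include <map>
    #include <memory>
    #include <string>

    struct Transposer { virtual ~Transposer() = default; };

    class Factory {
     public:
      template <typename T>
      std::shared_ptr<Transposer> GetOrCreate(const std::string& key) {
        auto& slot = cache_[key];          // default-constructs an empty slot
        if (slot == nullptr) slot = std::make_shared<T>();
        return slot;                       // same instance on every later call
      }

     private:
      std::map<std::string, std::shared_ptr<Transposer>> cache_;
    };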
@ -48,7 +48,6 @@ load(
)
load(
"//third_party/mkl:build_defs.bzl",
"if_mkl_ml",
"mkl_deps",
)
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
@ -3241,7 +3240,6 @@ cc_library(
deps = [
":aggregate_ops",
":argmax_op",
":batch_matmul_op",
":betainc_op",
":bincount_op",
":bucketize_op",
@ -3337,14 +3335,27 @@ tf_kernel_library(

tf_kernel_library(
name = "batch_matmul_op",
deps = [":matmul_op"],
)

tf_kernel_library(
name = "matmul_op",
# <prefix>*impl.h are excluded by default from the CPU build, add explicitly.
hdrs = ["batch_matmul_op_impl.h"],
prefix = "batch_matmul_op",
deps = MATH_DEPS + [":eigen_contraction_kernel"] + if_mkl_ml([
"//third_party/mkl:intel_binary_blob",
]) + if_cuda_or_rocm([
"//tensorflow/core/kernels:gpu_utils",
]),
hdrs = ["matmul_op_impl.h"],
defines = select({
":xsmm": ["TENSORFLOW_USE_LIBXSMM"],
"//conditions:default": [],
}),
prefix = "matmul_op",
deps = MATH_DEPS + [
":eigen_contraction_kernel",
":fused_eigen_output_kernels",
] + select({
":xsmm": ["@libxsmm_archive//:xsmm_avx"],
"//conditions:default": [],
}) + mkl_deps() + if_cuda([
"//tensorflow/core/platform/default/build_config:cublas_plugin",
]) + if_cuda_or_rocm([":gpu_utils"]),
)

tf_kernel_library(
@ -3406,28 +3417,6 @@ tf_kernel_library(
]),
)

tf_kernel_library(
name = "matmul_op",
srcs = [
"matmul_op.cc",
"matmul_op_fused.cc",
],
hdrs = ["matmul_op.h"],
defines = select({
":xsmm": ["TENSORFLOW_USE_LIBXSMM"],
"//conditions:default": [],
}),
deps = MATH_DEPS + [
":eigen_contraction_kernel",
":fused_eigen_output_kernels",
] + select({
":xsmm": ["@libxsmm_archive//:xsmm_avx"],
"//conditions:default": [],
}) + mkl_deps() + if_cuda([
"//tensorflow/core/platform/default/build_config:cublas_plugin",
]) + if_cuda_or_rocm([":gpu_utils"]),
)

tf_kernel_library(
name = "reduction_ops",
gpu_srcs = ["reduction_gpu_kernels.cu.h"],
@ -3620,25 +3609,6 @@ tf_cuda_cc_test(
],
)

tf_cuda_cc_test(
name = "batch_matmul_op_test",
size = "small",
srcs = ["batch_matmul_op_test.cc"],
deps = [
":batch_matmul_op",
":broadcast_to_op",
":ops_testutil",
":ops_util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)

tf_cuda_cc_test(
name = "scan_ops_test",
size = "small",
@ -5868,8 +5838,8 @@ filegroup(
"identity_op.h",
"immutable_constant_op.cc",
"immutable_constant_op.h",
"matmul_op.cc",
"matmul_op.h",
"matmul_op_impl.h",
"matmul_op_real.cc",
"no_op.cc",
"no_op.h",
"one_hot_op.cc",
@ -5948,7 +5918,6 @@ filegroup(
srcs = [
"argmax_op.h",
"avgpooling_op.h",
"batch_matmul_op_impl.h",
"batch_norm_op.h",
"bincount_op.h",
"broadcast_to_op.h",
@ -6039,7 +6008,6 @@ filegroup(
":android_extended_ops_headers",
"argmax_op.cc",
"avgpooling_op.cc",
"batch_matmul_op_real.cc",
"batch_norm_op.cc",
"bcast_ops.cc",
"check_numerics_op.cc",
@ -6480,6 +6448,7 @@ cc_library(
"//tensorflow/core/platform:strong_hash",
"//third_party/eigen3",
"//third_party/fft2d:fft2d_headers",
"//third_party/icu/data:conversion_data",
"@com_google_absl//absl/base",
"@com_google_protobuf//:protobuf",
"@fft2d",
@ -7431,7 +7400,6 @@ test_suite(
"manual", # Avoid redundancy when using wildcard test patterns.
],
tests = [
":batch_matmul_op_test",
":batch_norm_op_test",
":broadcast_to_op_test",
":cast_op_test",
@ -1,257 +0,0 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/graph/testlib.h"
|
||||
#include "tensorflow/core/kernels/broadcast_to_op.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
Node* BroadcastTo(Graph* g, Node* input, Node* shape) {
|
||||
Node* ret;
|
||||
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BroadcastTo")
|
||||
.Input(input)
|
||||
.Input(shape)
|
||||
.Finalize(g, &ret));
|
||||
return ret;
|
||||
}
|
||||
|
||||
Node* BatchMatmulV2(Graph* g, Node* in0, Node* in1, bool adj_x, bool adj_y) {
|
||||
Node* ret;
|
||||
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BatchMatMulV2")
|
||||
.Input(in0)
|
||||
.Input(in1)
|
||||
.Attr("adj_x", adj_x)
|
||||
.Attr("adj_y", adj_y)
|
||||
.Finalize(g, &ret));
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static Graph* BatchMatmul(int b, int m, int k, int n, bool adjoint_a,
|
||||
bool adjoint_b, DataType type) {
|
||||
Graph* g = new Graph(OpRegistry::Global());
|
||||
Tensor in0(type, adjoint_a ? TensorShape({b, k, m}) : TensorShape({b, m, k}));
|
||||
in0.flat<T>().setRandom();
|
||||
Tensor in1(type, adjoint_b ? TensorShape({b, n, k}) : TensorShape({b, k, n}));
|
||||
in1.flat<T>().setRandom();
|
||||
test::graph::BatchMatmul(g, test::graph::Constant(g, in0),
|
||||
test::graph::Constant(g, in1), adjoint_a, adjoint_b);
|
||||
return g;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static Graph* BatchMatmulWithBroadcast(int b0, int b1, int m, int k, int n,
|
||||
bool manual_broadcast, DataType type) {
|
||||
Graph* g = new Graph(OpRegistry::Global());
|
||||
Tensor in0(type, TensorShape({b0, m, k}));
|
||||
in0.flat<T>().setRandom();
|
||||
Tensor in1(type, TensorShape({b1, k, n}));
|
||||
in1.flat<T>().setRandom();
|
||||
|
||||
Tensor broadcasted_in0_shape(DT_INT64, TensorShape({3}));
|
||||
Tensor broadcasted_in1_shape(DT_INT64, TensorShape({3}));
|
||||
|
||||
Node* in0_node = nullptr;
|
||||
Node* in1_node = nullptr;
|
||||
if (manual_broadcast) {
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
auto vec0 = broadcasted_in0_shape.vec<int64>();
|
||||
auto vec1 = broadcasted_in1_shape.vec<int64>();
|
||||
vec0(i) = (i == 0 ? std::max(b0, b1) : in0.shape().dim_size(i));
|
||||
vec1(i) = (i == 0 ? std::max(b0, b1) : in1.shape().dim_size(i));
|
||||
}
|
||||
in0_node = BroadcastTo(g, test::graph::Constant(g, in0),
|
||||
test::graph::Constant(g, broadcasted_in0_shape));
|
||||
in1_node = BroadcastTo(g, test::graph::Constant(g, in1),
|
||||
test::graph::Constant(g, broadcasted_in1_shape));
|
||||
} else {
|
||||
in0_node = test::graph::Constant(g, in0);
|
||||
in1_node = test::graph::Constant(g, in1);
|
||||
}
|
||||
|
||||
BatchMatmulV2(g, in0_node, in1_node, false, false);
|
||||
return g;
|
||||
}
|
||||
|
||||
#define BM_BatchMatmulDev(B, M, K, N, TA, TB, T, TFTYPE, DEVICE) \
|
||||
static void \
|
||||
BM_BatchMatmul##_##B##_##M##_##K##_##N##_##TA##_##TB##_##TFTYPE##_##DEVICE( \
|
||||
int iters) { \
|
||||
testing::UseRealTime(); \
|
||||
testing::ItemsProcessed(static_cast<int64>(iters) * B * M * K * N * 2); \
|
||||
test::Benchmark(#DEVICE, BatchMatmul<T>(B, M, K, N, TA, TB, TFTYPE)) \
|
||||
.Run(iters); \
|
||||
} \
|
||||
BENCHMARK( \
|
||||
BM_BatchMatmul##_##B##_##M##_##K##_##N##_##TA##_##TB##_##TFTYPE##_##DEVICE);
|
||||
|
||||
#define BM_BatchMatmul(B, M, K, N, TA, TB) \
|
||||
BM_BatchMatmulDev(B, M, K, N, TA, TB, float, DT_FLOAT, cpu);
|
||||
// BM_BatchMatmulDev(B, M, K, N, TA, TB, std::complex<float>, DT_COMPLEX64,
|
||||
// cpu);
|
||||
// BM_BatchMatmulDev(B, M, K, N, TA, TB, float, DT_FLOAT, gpu);
|
||||
/* Uncomment to enable benchmarks for double & complex types: */
|
||||
// BM_BatchMatmulDev(B, M, K, N, TA, TB, std::complex<float>, DT_COMPLEX64,
|
||||
// gpu);
|
||||
// BM_BatchMatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, cpu); \
|
||||
// BM_BatchMatmulDev(M, K, N, TA, TB, std::complex<double>, DT_COMPLEX128, cpu);
|
||||
// \
|
||||
// BM_BatchMatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, gpu); \
|
||||
// BM_BatchMatmulDev(M, K, N, TA, TB, std::complex<double>, DT_COMPLEX128, gpu);
|
||||
|
||||
// Macro arguments names: --------------------------------------------------- //
|
||||
// B1: batch size of LHS
|
||||
// B2: batch size of RHS
|
||||
// M: outer dimension of LHS
|
||||
// K: inner dimensions of LHS and RHS
|
||||
// N: outer dimension of RHS
|
||||
// MB: boolean indicating whether to use manual broadcasting
|
||||
// T: C++ type of scalars (e.g. float, std::complex)
|
||||
// TT: TensorFlow type of scalars (e.g. DT_FLOAT, DT_COMPLEX128
|
||||
// D: Device (e.g. cpu, gpu)
|
||||
#define BM_BatchMatmulBCastDev(B1, B2, M, K, N, MB, T, TT, D) \
|
||||
static void \
|
||||
BM_BatchMatmulBCast##_##B1##_##B2##_##M##_##K##_##N##_##MB##_##TT##_##D( \
|
||||
int iters) { \
|
||||
testing::UseRealTime(); \
|
||||
testing::ItemsProcessed(static_cast<int64>(iters) * std::max(B1, B2) * M * \
|
||||
K * N * 2); \
|
||||
test::Benchmark(#D, BatchMatmulWithBroadcast<T>(B1, B2, M, K, N, MB, TT)) \
|
||||
.Run(iters); \
|
||||
} \
|
||||
BENCHMARK( \
|
||||
BM_BatchMatmulBCast##_##B1##_##B2##_##M##_##K##_##N##_##MB##_##TT##_##D);
|
||||
|
||||
#define BM_BatchMatmulBCast(B1, B2, M, K, N, MB) \
|
||||
BM_BatchMatmulBCastDev(B1, B2, M, K, N, MB, float, DT_FLOAT, cpu);
|
||||
|
||||
// Typical fully connected layers
|
||||
BM_BatchMatmulBCast(1, 128, 1, 1024, 1024, true);
|
||||
BM_BatchMatmulBCast(1, 128, 1, 1024, 1024, false);
|
||||
BM_BatchMatmulBCast(128, 1, 1, 1024, 1024, true);
|
||||
BM_BatchMatmulBCast(128, 1, 1, 1024, 1024, false);
|
||||
BM_BatchMatmulBCast(1, 128, 128, 1024, 1024, true);
|
||||
BM_BatchMatmulBCast(1, 128, 128, 1024, 1024, false);
|
||||
BM_BatchMatmulBCast(128, 1, 128, 1024, 1024, true);
|
||||
BM_BatchMatmulBCast(128, 1, 128, 1024, 1024, false);
|
||||
|
||||
// Square matmul.
|
||||
BM_BatchMatmulBCast(1, 128, 512, 512, 512, true);
|
||||
BM_BatchMatmulBCast(1, 128, 512, 512, 512, false);
|
||||
BM_BatchMatmulBCast(128, 1, 512, 512, 512, true);
|
||||
BM_BatchMatmulBCast(128, 1, 512, 512, 512, false);
|
||||
BM_BatchMatmulBCast(1, 128, 1024, 1024, 1024, true);
|
||||
BM_BatchMatmulBCast(1, 128, 1024, 1024, 1024, false);
|
||||
BM_BatchMatmulBCast(128, 1, 1024, 1024, 1024, true);
|
||||
BM_BatchMatmulBCast(128, 1, 1024, 1024, 1024, false);
|
||||
|
||||
// Matrix-vector multiplies.
|
||||
BM_BatchMatmulBCast(1, 128, 10000, 200, 1, true);
|
||||
BM_BatchMatmulBCast(1, 128, 10000, 200, 1, false);
|
||||
BM_BatchMatmulBCast(128, 1, 10000, 200, 1, true);
|
||||
BM_BatchMatmulBCast(128, 1, 10000, 200, 1, false);
|
||||
|
||||
// Vector-matrix multiplies.
|
||||
BM_BatchMatmulBCast(1, 128, 1, 200, 10000, true);
|
||||
BM_BatchMatmulBCast(1, 128, 1, 200, 10000, false);
|
||||
BM_BatchMatmulBCast(128, 1, 1, 200, 10000, true);
|
||||
BM_BatchMatmulBCast(128, 1, 1, 200, 10000, false);
|
||||
|
||||
// Typical fully connected layers
|
||||
BM_BatchMatmul(1, 1, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(1, 8, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(1, 16, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(1, 128, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(2, 1, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(2, 8, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(2, 16, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(2, 128, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(8, 1, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(8, 8, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(8, 16, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(8, 128, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(32, 1, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(32, 8, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(32, 16, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(32, 128, 1024, 1024, false, false);
|
||||
|
||||
// Square matmul.
|
||||
BM_BatchMatmul(1, 32, 32, 32, false, false);
|
||||
BM_BatchMatmul(1, 128, 128, 128, false, false);
|
||||
BM_BatchMatmul(1, 256, 256, 256, false, false);
|
||||
BM_BatchMatmul(1, 1024, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(1, 2048, 2048, 2048, false, false);
|
||||
BM_BatchMatmul(2, 32, 32, 32, false, false);
|
||||
BM_BatchMatmul(2, 128, 128, 128, false, false);
|
||||
BM_BatchMatmul(2, 256, 256, 256, false, false);
|
||||
BM_BatchMatmul(2, 1024, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(2, 2048, 2048, 2048, false, false);
|
||||
BM_BatchMatmul(4, 32, 32, 32, false, false);
|
||||
BM_BatchMatmul(4, 128, 128, 128, false, false);
|
||||
BM_BatchMatmul(4, 256, 256, 256, false, false);
|
||||
BM_BatchMatmul(4, 1024, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(4, 2048, 2048, 2048, false, false);
|
||||
BM_BatchMatmul(8, 32, 32, 32, false, false);
|
||||
BM_BatchMatmul(8, 128, 128, 128, false, false);
|
||||
BM_BatchMatmul(8, 256, 256, 256, false, false);
|
||||
BM_BatchMatmul(8, 1024, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(8, 2048, 2048, 2048, false, false);
|
||||
BM_BatchMatmul(32, 32, 32, 32, false, false);
|
||||
BM_BatchMatmul(32, 128, 128, 128, false, false);
|
||||
BM_BatchMatmul(32, 256, 256, 256, false, false);
|
||||
BM_BatchMatmul(32, 1024, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(32, 2048, 2048, 2048, false, false);
|
||||
|
||||
// Matrix-vector multiplies.
|
||||
BM_BatchMatmul(1, 10000, 200, 1, false, false);
|
||||
BM_BatchMatmul(8, 10000, 200, 1, false, false);
|
||||
BM_BatchMatmul(32, 10000, 200, 1, false, false);
|
||||
BM_BatchMatmul(1, 10000, 200, 1, true, false);
|
||||
BM_BatchMatmul(8, 10000, 200, 1, true, false);
|
||||
BM_BatchMatmul(32, 10000, 200, 1, true, false);
|
||||
BM_BatchMatmul(1, 10000, 200, 1, false, true);
|
||||
BM_BatchMatmul(8, 10000, 200, 1, false, true);
|
||||
BM_BatchMatmul(32, 10000, 200, 1, false, true);
|
||||
BM_BatchMatmul(1, 10000, 200, 1, true, true);
|
||||
BM_BatchMatmul(8, 10000, 200, 1, true, true);
|
||||
BM_BatchMatmul(32, 10000, 200, 1, true, true);
|
||||
|
||||
// Vector-matrix multiplies.
|
||||
BM_BatchMatmul(1, 1, 200, 10000, false, false);
|
||||
BM_BatchMatmul(8, 1, 200, 10000, false, false);
|
||||
BM_BatchMatmul(32, 1, 200, 10000, false, false);
|
||||
BM_BatchMatmul(1, 1, 200, 10000, true, false);
|
||||
BM_BatchMatmul(8, 1, 200, 10000, true, false);
|
||||
BM_BatchMatmul(32, 1, 200, 10000, true, false);
|
||||
BM_BatchMatmul(1, 1, 200, 10000, false, true);
|
||||
BM_BatchMatmul(8, 1, 200, 10000, false, true);
|
||||
BM_BatchMatmul(32, 1, 200, 10000, false, true);
|
||||
BM_BatchMatmul(1, 1, 200, 10000, true, true);
|
||||
BM_BatchMatmul(8, 1, 200, 10000, true, true);
|
||||
BM_BatchMatmul(32, 1, 200, 10000, true, true);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -38,6 +38,8 @@ namespace data {
/* static */ constexpr const char* const CacheDatasetOp::kOutputTypes;
/* static */ constexpr const char* const CacheDatasetOp::kOutputShapes;

namespace {

constexpr char kKeyStrFormat[] = "%%%zuzu_%%%zuzu";
constexpr char kPaddingSizeStrFormat[] = "%zu";
constexpr char kFileDatasetPrefix[] = "File";
@ -57,6 +59,14 @@ constexpr char kCacheCompleted[] = "cache_completed";
constexpr char kIndex[] = "index";
constexpr char kImpl[] = "Impl";
constexpr char kCacheDataset[] = "CacheDataset";
constexpr char kIncompleteCacheErrorMessage[] =
"The calling iterator did not fully read the dataset being cached. In "
"order to avoid unexpected truncation of the dataset, the partially cached "
"contents of the dataset will be discarded. This can happen if you have "
"an input pipeline similar to `dataset.cache().take(k).repeat()`. You "
"should use `dataset.take(k).cache().repeat()` instead.";

} // namespace

class CacheDatasetOp::FileDatasetBase : public DatasetBase {
public:
@ -220,6 +230,7 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase {

~FileWriterIterator() override {
if (!dataset()->env_->FileExists(MetaFilename(filename_)).ok()) {
LOG(WARNING) << kIncompleteCacheErrorMessage;
std::vector<string> cache_files;
Status s = dataset()->env_->GetMatchingPaths(
strings::StrCat(filename_, "*"), &cache_files);
@ -754,13 +765,7 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
~MemoryWriterIterator() override {
mutex_lock l(mu_);
if (!temp_cache_.empty() && !cache_->IsCompleted()) {
LOG(WARNING)
<< "The calling iterator did not fully read the dataset being "
"cached. In order to avoid unexpected truncation of the "
"dataset, the partially cached contents of the dataset "
"will be discarded. This can happen if you have an input "
"pipeline similar to `dataset.cache().take(k).repeat()`. "
"You should use `dataset.take(k).cache().repeat()` instead.";
LOG(WARNING) << kIncompleteCacheErrorMessage;
cache_->Reset();
}
}
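A note on the kKeyStrFormat constant added near the top of this hunk: the doubled percent signs make it a format string that produces another format string. Minimal illustration, assuming strings::Printf-style expansion; the surrounding key-building code is not reproduced here:

    // Illustration only: expanding the two-level format string.
    // "%%%zuzu_%%%zuzu" formatted with the widths 4 and 4 yields the
    // literal string "%4zu_%4zu", which can then format fixed-width
    // two-part keys.
    const string key_format = strings::Printf(kKeyStrFormat, 4, 4);  // "%4zu_%4zu"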
@ -482,7 +482,10 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
VLOG(1) << "Failed to get element from worker "
<< task_to_process->address << ": " << s;
task_to_process->in_use = false;
status_ = s;
status_ = Status(
s.code(),
absl::StrCat("Failed to get element from worker ",
task_to_process->address, ": ", s.error_message()));
get_next_cv_.notify_all();
return;
}
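The hunk above stops storing the raw worker error and instead builds a new Status that keeps the original error code while prepending which worker failed. A small sketch of that pattern as a standalone helper, assuming the usual TensorFlow Status and Abseil string headers (the helper name is ours, not from the patch):

    #include "absl/strings/str_cat.h"
    #include "tensorflow/core/lib/core/status.h"

    // Hypothetical helper: preserve the error code, prepend context.
    tensorflow::Status AnnotateStatus(const tensorflow::Status& s,
                                      absl::string_view context) {
      if (s.ok()) return s;
      return tensorflow::Status(s.code(),
                                absl::StrCat(context, ": ", s.error_message()));
    }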
@ -189,14 +189,18 @@ static Graph* DynamicPartition(int num_partitions, int dim) {
}

#define BM_DYNAMIC_PARTITION(DEVICE, T, num) \
static void BM_##DEVICE##_dynpart_##T##_##num(int iters, int dim) { \
static void BM_##DEVICE##_dynpart_##T##_##num( \
::testing::benchmark::State& state) { \
const int dim = state.range(0); \
\
const int64 items = ((128 << 20) / sizeof(T)); \
const int64 tot = static_cast<int64>(iters) * items; \
testing::ItemsProcessed(tot); \
testing::UseRealTime(); \
test::Benchmark(#DEVICE, DynamicPartition<T>(num, dim)).Run(iters); \
test::Benchmark(#DEVICE, DynamicPartition<T>(num, dim), \
/*old_benchmark_api=*/false) \
.Run(state); \
const int64 tot = static_cast<int64>(state.iterations()) * items; \
state.SetItemsProcessed(tot); \
} \
BENCHMARK(BM_##DEVICE##_dynpart_##T##_##num)->Arg(1)->Arg(256)
BENCHMARK(BM_##DEVICE##_dynpart_##T##_##num)->UseRealTime()->Arg(1)->Arg(256)

BM_DYNAMIC_PARTITION(cpu, float, 2);
BM_DYNAMIC_PARTITION(cpu, float, 100);
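This hunk, like several kernel benchmark changes in this commit, migrates from the old int-iters benchmark signature to ::testing::benchmark::State, passes /*old_benchmark_api=*/false to test::Benchmark, and reports work through the state object. A minimal sketch of the new shape, with a hypothetical graph factory ExampleGraph and illustrative sizes:

    // Sketch of the new-style benchmark body used throughout these hunks.
    static void BM_ExampleOp(::testing::benchmark::State& state) {
      const int dim = state.range(0);  // benchmark argument, e.g. ->Arg(256)
      test::Benchmark("cpu", ExampleGraph(dim), /*old_benchmark_api=*/false)
          .Run(state);
      state.SetItemsProcessed(static_cast<int64>(state.iterations()) * dim);
    }
    BENCHMARK(BM_ExampleOp)->UseRealTime()->Arg(1)->Arg(256);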
@ -1376,7 +1376,7 @@ TEST(EigenSpatialConvolutionsTest, SpatialConvContractionMapper) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void PackRhsHelper(int iters,
|
||||
static void PackRhsHelper(::testing::benchmark::State& state,
|
||||
/* Input dimensions: */
|
||||
int input_batches, int input_cols, int input_rows,
|
||||
int input_depth,
|
||||
@ -1393,9 +1393,6 @@ static void PackRhsHelper(int iters,
|
||||
// Set random seed for benchmark repeatability.
|
||||
srand(12345);
|
||||
|
||||
tensorflow::testing::UseRealTime();
|
||||
tensorflow::testing::StopTiming();
|
||||
|
||||
using Dimensions = Eigen::DSizes<Eigen::Index, 4>;
|
||||
|
||||
// Default Eigen::Tensor layout is column major, so we configure dimensions
|
||||
@ -1547,8 +1544,7 @@ static void PackRhsHelper(int iters,
|
||||
return (idx / packet_size) * packet_size;
|
||||
};
|
||||
|
||||
tensorflow::testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
for (auto s : state) {
|
||||
int input_idx =
|
||||
num_inputs == 1 ? 1 : internal::random<int>(0, num_inputs - 1);
|
||||
|
||||
@ -1571,15 +1567,15 @@ static void PackRhsHelper(int iters,
|
||||
input_mappers[input_idx].getSubMapper(depth_offset, col_offset);
|
||||
pack_rhs(packed.data() + packed_offset, sub_mapper, depth, cols);
|
||||
}
|
||||
tensorflow::testing::StopTiming();
|
||||
tensorflow::testing::SetLabel(
|
||||
|
||||
state.SetLabel(
|
||||
absl::StrCat("patch: ", patch_rows, "x", patch_cols, " D", patch_depth,
|
||||
"; num_patches=", num_patches, " patch_size=", patch_size,
|
||||
" num_inputs=", num_inputs, " padding=", padding));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void PackLhsHelper(int iters,
|
||||
static void PackLhsHelper(::testing::benchmark::State& state,
|
||||
/* Input dimensions: */
|
||||
int input_depth,
|
||||
/* Filter (kernel) dimensions: */
|
||||
@ -1592,9 +1588,6 @@ static void PackLhsHelper(int iters,
|
||||
eigen_assert(block_rows <= filter_count);
|
||||
eigen_assert(block_cols <= input_depth * filter_rows * filter_cols);
|
||||
|
||||
tensorflow::testing::UseRealTime();
|
||||
tensorflow::testing::StopTiming();
|
||||
|
||||
using Dimensions = Eigen::DSizes<Eigen::Index, 4>;
|
||||
|
||||
// Default Eigen::Tensor layout is column major, so we configure dimensions
|
||||
@ -1716,8 +1709,7 @@ static void PackLhsHelper(int iters,
|
||||
const Index max_row = filter_count;
|
||||
const Index max_col = filter_rows * filter_cols * input_depth;
|
||||
|
||||
tensorflow::testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
for (auto s : state) {
|
||||
int filter_idx =
|
||||
num_filters == 1 ? 1 : internal::random<int>(0, num_filters - 1);
|
||||
|
||||
@ -1743,8 +1735,7 @@ static void PackLhsHelper(int iters,
|
||||
pack_lhs(packed.data() + packed_offset, sub_mapper, cols, rows);
|
||||
#endif
|
||||
}
|
||||
tensorflow::testing::StopTiming();
|
||||
tensorflow::testing::SetLabel(absl::StrCat(
|
||||
state.SetLabel(absl::StrCat(
|
||||
"filter: count=", filter_count, " dims=", filter_rows, "x", filter_cols,
|
||||
"; input: depth=", input_depth, "; num_filers=", num_filters));
|
||||
}
|
||||
@ -1777,12 +1768,14 @@ static void PackLhsHelper(int iters,
|
||||
|
||||
#define BM_PackRhs(T, N, H, W, C, FC, FH, FW, PAD, SH, SW, ISH, ISW, BR, BC) \
|
||||
static void BM_RHS_NAME(PackRhs, T, N, H, W, C, FC, FH, FW, PAD, SH, SW, \
|
||||
ISH, ISW, BR, BC)(int iters) { \
|
||||
PackRhsHelper<T>(iters, N, H, W, C, FC, FH, FW, PADDING_##PAD, SH, SW, \
|
||||
ISH, ISW, BR, \
|
||||
BC)(::testing::benchmark::State & state) { \
|
||||
PackRhsHelper<T>(state, N, H, W, C, FC, FH, FW, PADDING_##PAD, SH, SW, \
|
||||
ISH, ISW, BR, BC); \
|
||||
} \
|
||||
BENCHMARK(BM_RHS_NAME(PackRhs, T, N, H, W, C, FC, FH, FW, PAD, SH, SW, ISH, \
|
||||
ISW, BR, BC))
|
||||
ISW, BR, BC)) \
|
||||
->UseRealTime()
|
||||
|
||||
// Number of input channel (input depth) it equal to the number of patch
|
||||
// channels (patch depth).
|
||||
@ -2020,10 +2013,11 @@ BM_PackRhs(/*type*/ qint8, //
|
||||
BM_CONCAT(BM_##prefix##_##T##_##C##_FC##FC##_##FH##x##FW, _B##BR##x##BC)
|
||||
|
||||
#define BM_PackLhs(T, C, FC, FH, FW, BR, BC) \
|
||||
static void BM_LHS_NAME(PackLhs, T, C, FC, FH, FW, BR, BC)(int iters) { \
|
||||
PackLhsHelper<T>(iters, C, FC, FH, FW, BR, BC); \
|
||||
static void BM_LHS_NAME(PackLhs, T, C, FC, FH, FW, BR, \
|
||||
BC)(::testing::benchmark::State & state) { \
|
||||
PackLhsHelper<T>(state, C, FC, FH, FW, BR, BC); \
|
||||
} \
|
||||
BENCHMARK(BM_LHS_NAME(PackLhs, T, C, FC, FH, FW, BR, BC))
|
||||
BENCHMARK(BM_LHS_NAME(PackLhs, T, C, FC, FH, FW, BR, BC))->UseRealTime()
|
||||
|
||||
// Number of input channel (input depth) it equal to the number of patch
|
||||
// channels (patch depth).
|
||||
|
@ -349,15 +349,16 @@ typedef BenchmarkOptions<ExampleStore<FloatFiller>, kRagged> RaggedFloat;
|
||||
// B == batch_size, K == num_keys. F == feature_size.
|
||||
// K must be one of 10, 100, 1000
|
||||
#define BM_ParseExample(TYPE, B, K, F) \
|
||||
static void BM_ParseExample##_##TYPE##_##B##_##K##_##F(int iters) { \
|
||||
static void BM_ParseExample##_##TYPE##_##B##_##K##_##F( \
|
||||
::testing::benchmark::State& state) { \
|
||||
int64 items_per_iter = static_cast<int64>(B) * K * F; \
|
||||
testing::UseRealTime(); \
|
||||
testing::ItemsProcessed(static_cast<int64>(iters) * items_per_iter); \
|
||||
test::Benchmark("cpu", ParseExample<TYPE>(B, K, F), nullptr, nullptr, \
|
||||
nullptr, "SINGLE_THREADED_EXECUTOR") \
|
||||
.Run(iters); \
|
||||
nullptr, "SINGLE_THREADED_EXECUTOR", false) \
|
||||
.Run(state); \
|
||||
state.SetItemsProcessed(static_cast<int64>(state.iterations()) * \
|
||||
items_per_iter); \
|
||||
} \
|
||||
BENCHMARK(BM_ParseExample##_##TYPE##_##B##_##K##_##F);
|
||||
BENCHMARK(BM_ParseExample##_##TYPE##_##B##_##K##_##F)->UseRealTime();
|
||||
|
||||
#define BM_AllParseExample(Type) \
|
||||
BM_ParseExample(Type, 1, 10, 1); \
|
||||
@ -385,15 +386,17 @@ BM_AllParseExample(VarLenDenseFloat);
|
||||
// K must be one of 10, 100, 1000
|
||||
// B=0 indicates that a scalar input should be used (instead of a vector).
|
||||
#define BM_ParseExampleV2(TYPE, B, K, F) \
|
||||
static void BM_ParseExampleV2##_##TYPE##_##B##_##K##_##F(int iters) { \
|
||||
static void BM_ParseExampleV2##_##TYPE##_##B##_##K##_##F( \
|
||||
::testing::benchmark::State& state) { \
|
||||
int64 items_per_iter = static_cast<int64>(std::max(B, 1)) * K * F; \
|
||||
testing::UseRealTime(); \
|
||||
testing::ItemsProcessed(static_cast<int64>(iters) * items_per_iter); \
|
||||
test::Benchmark("cpu", ParseExampleV2<TYPE>(B, K, F), nullptr, nullptr, \
|
||||
nullptr, "SINGLE_THREADED_EXECUTOR") \
|
||||
.Run(iters); \
|
||||
nullptr, "SINGLE_THREADED_EXECUTOR", \
|
||||
/*old_benchmark_api=*/false) \
|
||||
.Run(state); \
|
||||
state.SetItemsProcessed(static_cast<int64>(state.iterations()) * \
|
||||
items_per_iter); \
|
||||
} \
|
||||
BENCHMARK(BM_ParseExampleV2##_##TYPE##_##B##_##K##_##F);
|
||||
BENCHMARK(BM_ParseExampleV2##_##TYPE##_##B##_##K##_##F)->UseRealTime();
|
||||
|
||||
#define BM_AllParseExampleV2(Type) \
|
||||
/* Vector Inputs */ \
|
||||
@ -437,15 +440,17 @@ BM_AllParseExampleV2(RaggedFloat);
|
||||
// K == num_keys. F == feature_size.
|
||||
// K must be one of 10, 100, 1000
|
||||
#define BM_ParseSingleExample(TYPE, K, F) \
|
||||
static void BM_ParseSingleExample##_##TYPE##_1_##K##_##F(int iters) { \
|
||||
void BM_ParseSingleExample##_##TYPE##_1_##K##_##F( \
|
||||
::testing::benchmark::State& state) { \
|
||||
int64 items_per_iter = K * F; \
|
||||
testing::UseRealTime(); \
|
||||
testing::ItemsProcessed(static_cast<int64>(iters) * items_per_iter); \
|
||||
test::Benchmark("cpu", ParseSingleExample<TYPE>(K, F), nullptr, nullptr, \
|
||||
nullptr, "SINGLE_THREADED_EXECUTOR") \
|
||||
.Run(iters); \
|
||||
nullptr, "SINGLE_THREADED_EXECUTOR", \
|
||||
/*old_benchmark_api=*/false) \
|
||||
.Run(state); \
|
||||
state.SetItemsProcessed(static_cast<int64>(state.iterations()) * \
|
||||
items_per_iter); \
|
||||
} \
|
||||
BENCHMARK(BM_ParseSingleExample##_##TYPE##_1_##K##_##F);
|
||||
BENCHMARK(BM_ParseSingleExample##_##TYPE##_1_##K##_##F)->UseRealTime();
|
||||
|
||||
#define BM_AllParseSingleExample(Type) \
|
||||
BM_ParseSingleExample(Type, 10, 1); \
|
||||
|
@ -133,14 +133,18 @@ static Graph* GatherNd(int dim) {
}

#define BM_GATHER_ND(DEVICE, INDEX) \
static void BM_##DEVICE##_gather_nd_##INDEX(int iters, int dim) { \
const int64 tot = static_cast<int64>(iters) * kLookups * 4; \
testing::ItemsProcessed(tot); \
testing::BytesProcessed(tot * sizeof(float)); \
testing::UseRealTime(); \
test::Benchmark(#DEVICE, GatherNd<INDEX>(dim)).Run(iters); \
static void BM_##DEVICE##_gather_nd_##INDEX( \
::testing::benchmark::State& state) { \
const int dim = state.range(0); \
test::Benchmark(#DEVICE, GatherNd<INDEX>(dim), \
/*old_benchmark_api=*/false) \
.Run(state); \
const int64 tot = static_cast<int64>(state.iterations()) * kLookups * 4; \
state.SetItemsProcessed(tot); \
state.SetBytesProcessed(tot * sizeof(float)); \
} \
BENCHMARK(BM_##DEVICE##_gather_nd_##INDEX) \
->UseRealTime() \
->Arg(10) \
->Arg(100) \
->Arg(1000) \
@ -223,14 +223,17 @@ static Graph* Gather(int dim) {
}

#define BM_GATHER(DEVICE, INDEX) \
static void BM_##DEVICE##_gather_##INDEX(int iters, int dim) { \
const int64 tot = static_cast<int64>(iters) * kLookups * dim; \
testing::ItemsProcessed(tot); \
testing::BytesProcessed(tot * sizeof(float)); \
testing::UseRealTime(); \
test::Benchmark(#DEVICE, Gather<INDEX>(dim)).Run(iters); \
static void BM_##DEVICE##_gather_##INDEX( \
::testing::benchmark::State& state) { \
const int dim = state.range(0); \
test::Benchmark(#DEVICE, Gather<INDEX>(dim), /*old_benchmark_api=*/false) \
.Run(state); \
const int64 tot = static_cast<int64>(state.iterations()) * kLookups * dim; \
state.SetItemsProcessed(tot); \
state.SetBytesProcessed(tot * sizeof(float)); \
} \
BENCHMARK(BM_##DEVICE##_gather_##INDEX) \
->UseRealTime() \
->Arg(1) \
->Arg(10) \
->Arg(20) \
@ -66,12 +66,15 @@ static Graph* InTopK(int num_targets, int num_classes, T top_k) {
BM_InTopK##_##T##_##TARGETS##_##CLASSES##_##K##_##DEVICE

#define BM_InTopK(T, TARGETS, CLASSES, K, DEVICE) \
static void BM_NAME(T, TARGETS, CLASSES, K, DEVICE)(int iters) { \
testing::UseRealTime(); \
testing::ItemsProcessed(static_cast<int64>(iters) * TARGETS * CLASSES); \
test::Benchmark(#DEVICE, InTopK<T>(TARGETS, CLASSES, K)).Run(iters); \
static void BM_NAME(T, TARGETS, CLASSES, K, \
DEVICE)(::testing::benchmark::State & state) { \
test::Benchmark(#DEVICE, InTopK<T>(TARGETS, CLASSES, K), \
/*old_benchmark_api=*/false) \
.Run(state); \
state.SetItemsProcessed(static_cast<int64>(state.iterations()) * TARGETS * \
CLASSES); \
} \
BENCHMARK(BM_NAME(T, TARGETS, CLASSES, K, DEVICE));
BENCHMARK(BM_NAME(T, TARGETS, CLASSES, K, DEVICE))->UseRealTime();

BM_InTopK(int64, 64, 1000, 10, cpu);
BM_InTopK(int64, 64, 10000, 10, cpu);
@ -30,9 +30,9 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/batch_matmul_op_impl.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/linalg/einsum_op.h"
#include "tensorflow/core/kernels/matmul_op_impl.h"
#include "tensorflow/core/kernels/reduction_ops_common.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/lib/core/errors.h"
@ -199,7 +199,7 @@ TCASE(T3, 128, 4, 3, 2.0f, 1.0f, 1.0f)

#undef TCASE

static Graph* BM_LRNGrad(int batches, int rows, int cols, int depth,
static Graph* MakeRNGrad(int batches, int rows, int cols, int depth,
int depth_radius) {
Graph* g = new Graph(OpRegistry::Global());
Tensor grads(DT_FLOAT, TensorShape({batches, rows, cols, depth}));
@ -224,10 +224,13 @@ static Graph* BM_LRNGrad(int batches, int rows, int cols, int depth,
}

#define BM_LRNGradDev(DEVICE, B, R, C, D, DR) \
static void BM_LRNGrad_##DEVICE##_##B##_##R##_##C##_##D##_##DR(int iters) { \
testing::ItemsProcessed(static_cast<int64>(iters) * B * R * C * D * DR * \
4); \
test::Benchmark(#DEVICE, BM_LRNGrad(B, R, C, D, DR)).Run(iters); \
static void BM_LRNGrad_##DEVICE##_##B##_##R##_##C##_##D##_##DR( \
::testing::benchmark::State& state) { \
test::Benchmark(#DEVICE, MakeRNGrad(B, R, C, D, DR), \
/*old_benchmark_api*/ false) \
.Run(state); \
state.SetItemsProcessed(static_cast<int64>(state.iterations()) * B * R * \
C * D * DR * 4); \
} \
BENCHMARK(BM_LRNGrad_##DEVICE##_##B##_##R##_##C##_##D##_##DR)

@ -1,567 +0,0 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// See docs in ../ops/math_ops.cc.
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include "tensorflow/core/kernels/matmul_op.h"
|
||||
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/kernels/fill_functor.h"
|
||||
#include "tensorflow/core/util/matmul_autotune.h"
|
||||
#if GOOGLE_CUDA
|
||||
#include "third_party/gpus/cuda/include/cuda.h"
|
||||
#endif
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#include "tensorflow/core/kernels/gpu_utils.h"
|
||||
#include "tensorflow/core/platform/stream_executor.h"
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
|
||||
template <typename Device, typename T, bool USE_CUBLAS>
|
||||
struct LaunchMatMul;
|
||||
|
||||
namespace {
|
||||
// Converts a TensorFlow Tensor to an Eigen Matrix.
|
||||
template <typename T>
|
||||
Eigen::Map<
|
||||
const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
|
||||
ToEigenMatrix(const Tensor& tensor) {
|
||||
auto matrix = tensor.matrix<T>();
|
||||
return Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>::Map(
|
||||
matrix.data(), matrix.dimension(0), matrix.dimension(1));
|
||||
}
|
||||
|
||||
// Converts a TensorFlow Tensor to an Eigen Vector.
|
||||
template <typename T>
|
||||
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, 1>> ToEigenVector(Tensor* tensor) {
|
||||
auto v = tensor->flat<T>();
|
||||
return Eigen::Matrix<T, Eigen::Dynamic, 1>::Map(v.data(), v.dimension(0));
|
||||
}
|
||||
template <typename T>
|
||||
Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, 1>> ToEigenVector(
|
||||
const Tensor& tensor) {
|
||||
auto v = tensor.flat<T>();
|
||||
return Eigen::Matrix<T, Eigen::Dynamic, 1>::Map(v.data(), v.dimension(0));
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// If either side can be represented as a vector, do an explicit vector
|
||||
// matrix multiply and return true; else return false.
|
||||
//
|
||||
// Note: this uses plain Eigen and not Eigen Tensor because it is more
|
||||
// efficient.
|
||||
template <typename T>
|
||||
bool ExplicitVectorMatrixOptimization(
|
||||
const Tensor& a, const Tensor& b,
|
||||
const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
|
||||
Tensor* out) {
|
||||
if (out->dim_size(0) == 1) {
|
||||
if (dim_pair[0].second == 0) {
|
||||
// Note: this case is optimized in Eigen Tensors.
|
||||
return false;
|
||||
} else {
|
||||
auto out_v = ToEigenVector<T>(out);
|
||||
auto a_v = ToEigenVector<T>(a);
|
||||
auto b_m = ToEigenMatrix<T>(b);
|
||||
out_v.noalias() = b_m * a_v;
|
||||
}
|
||||
return true;
|
||||
} else if (out->dim_size(1) == 1) {
|
||||
auto out_v = ToEigenVector<T>(out);
|
||||
auto a_m = ToEigenMatrix<T>(a);
|
||||
auto b_v = ToEigenVector<T>(b);
|
||||
if (dim_pair[0].first == 0) {
|
||||
out_v.noalias() = a_m.transpose() * b_v;
|
||||
} else {
|
||||
out_v.noalias() = a_m * b_v;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
// Half is not supported.
|
||||
template <>
|
||||
bool ExplicitVectorMatrixOptimization<Eigen::half>(
|
||||
const Tensor& a, const Tensor& b,
|
||||
const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
|
||||
Tensor* out) {
|
||||
return false;
|
||||
}
|
||||
|
||||
template <typename Device, typename T>
|
||||
struct LaunchMatMulBase {
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
typedef se::blas::AlgorithmType AlgorithmType;
|
||||
#else
|
||||
typedef int64 AlgorithmType;
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
static void launch(
|
||||
OpKernelContext* ctx, const Tensor& a, const Tensor& b,
|
||||
const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
|
||||
std::vector<AlgorithmType>* algorithms, bool use_autotune, Tensor* out) {
|
||||
// An explicit vector-matrix multiply is much better optimized than an
|
||||
// implicit one and this is a bottleneck during non-batched inference.
|
||||
bool was_vector = ExplicitVectorMatrixOptimization<T>(a, b, dim_pair, out);
|
||||
if (!was_vector) {
|
||||
functor::MatMulFunctor<Device, T>()(ctx->eigen_device<Device>(),
|
||||
out->matrix<T>(), a.matrix<T>(),
|
||||
b.matrix<T>(), dim_pair);
|
||||
}
|
||||
}
|
||||
|
||||
static void GetBlasGemmAlgorithm(OpKernelConstruction* ctx,
|
||||
std::vector<int64>* algorithms,
|
||||
bool* algorithm_set_flag) {}
|
||||
};
|
||||
// On CPUs, we ignore USE_CUBLAS
|
||||
template <typename T>
|
||||
struct LaunchMatMulCPU : LaunchMatMulBase<CPUDevice, T> {};
|
||||
|
||||
template <typename T, bool USE_CUBLAS>
|
||||
struct LaunchMatMul<CPUDevice, T, USE_CUBLAS> : public LaunchMatMulCPU<T> {};
|
||||
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
struct LaunchBlasGemv {
|
||||
static void Compute(OpKernelContext* ctx, se::Stream* stream, bool trans,
|
||||
uint64 m, uint64 n, const se::DeviceMemory<T>& a,
|
||||
const se::DeviceMemory<T>& b, se::DeviceMemory<T>* c,
|
||||
se::blas::ProfileResult* output_profile) {
|
||||
const auto blas_trans = trans ? se::blas::Transpose::kTranspose
|
||||
: se::blas::Transpose::kNoTranspose;
|
||||
if (output_profile == nullptr) {
|
||||
bool blas_launch_status =
|
||||
stream
|
||||
->ThenBlasGemv(blas_trans, m, n, static_cast<T>(1.0), a, m, b, 1,
|
||||
static_cast<T>(0.0), c, 1)
|
||||
.ok();
|
||||
if (!blas_launch_status) {
|
||||
ctx->SetStatus(
|
||||
errors::Internal("Blas GEMV launch failed: m=", m, ", n=", n));
|
||||
}
|
||||
} else {
|
||||
bool blas_launch_status =
|
||||
stream
|
||||
->ThenBlasGemvWithProfiling(blas_trans, m, n, static_cast<T>(1.0),
|
||||
a, m, b, 1, static_cast<T>(0.0), c, 1,
|
||||
output_profile)
|
||||
.ok();
|
||||
if (!blas_launch_status) {
|
||||
ctx->SetStatus(errors::Internal(
|
||||
"Blas GEMV with profiling launch failed: m=", m, ", n=", n));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static bool IsSupported() { return true; }
|
||||
};
|
||||
|
||||
template <>
|
||||
void LaunchBlasGemv<Eigen::half>::Compute(
|
||||
OpKernelContext* ctx, se::Stream* stream, bool trans, uint64 m, uint64 n,
|
||||
const se::DeviceMemory<Eigen::half>& a,
|
||||
const se::DeviceMemory<Eigen::half>& b, se::DeviceMemory<Eigen::half>* c,
|
||||
se::blas::ProfileResult* output_profile) {
|
||||
ctx->SetStatus(errors::Internal(
|
||||
"Blas GEMV launch failed: GEMV is not implemented for float16."));
|
||||
}
|
||||
|
||||
template <>
|
||||
bool LaunchBlasGemv<Eigen::half>::IsSupported() {
|
||||
return false;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool ShouldUseGemv(uint64 n) {
|
||||
return (LaunchBlasGemv<T>::IsSupported() && n == 1);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool GetCublasAutotuneComputationType(const DataType& dtype,
|
||||
se::blas::ComputationType* compute_type) {
|
||||
using se::blas::ComputationType;
|
||||
switch (dtype) {
|
||||
case DT_HALF:
|
||||
case DT_BFLOAT16:
|
||||
static bool use_f32_for_f16_computation =
|
||||
MatmulDoFP32ComputationFP16Input();
|
||||
if (use_f32_for_f16_computation) {
|
||||
*compute_type = ComputationType::kF32;
|
||||
} else {
|
||||
*compute_type = ComputationType::kF16;
|
||||
}
|
||||
return false;
|
||||
case DT_FLOAT:
|
||||
*compute_type = ComputationType::kF32;
|
||||
return true;
|
||||
case DT_DOUBLE:
|
||||
*compute_type = ComputationType::kF64;
|
||||
return true;
|
||||
default:
|
||||
// Unsupported compute_type, return false.
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// A dummy type to group matmul autotune results together.
|
||||
struct MatmulAutoTuneGroup {
|
||||
static string name() { return "Matmul"; }
|
||||
};
|
||||
typedef AutoTuneSingleton<MatmulAutoTuneGroup, MatmulParameters,
|
||||
se::blas::AlgorithmConfig>
|
||||
AutoTuneMatmul;
|
||||
|
||||
template <typename T>
|
||||
struct LaunchMatMul<GPUDevice, T, true /* USE_CUBLAS */> {
|
||||
static void launch(
|
||||
OpKernelContext* ctx, const Tensor& a, const Tensor& b,
|
||||
const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
|
||||
std::vector<int64>* algorithms, bool use_autotune, Tensor* out) {
|
||||
using se::blas::AlgorithmConfig;
|
||||
using se::blas::ComputationType;
|
||||
using se::blas::kDefaultAlgorithm;
|
||||
using se::blas::kDefaultBlasGemm;
|
||||
using se::blas::kDefaultBlasGemv;
|
||||
using se::blas::kNoAlgorithm;
|
||||
using se::blas::ProfileResult;
|
||||
using se::blas::Transpose;
|
||||
Transpose trans[] = {Transpose::kNoTranspose, Transpose::kTranspose};
|
||||
const uint64 m = a.dim_size(1 - dim_pair[0].first);
|
||||
const uint64 k = a.dim_size(dim_pair[0].first);
|
||||
const uint64 n = b.dim_size(1 - dim_pair[0].second);
|
||||
bool transpose_a = dim_pair[0].first == 0;
|
||||
bool transpose_b = dim_pair[0].second == 1;
|
||||
auto blas_transpose_a = trans[transpose_a];
|
||||
auto blas_transpose_b = trans[transpose_b];
|
||||
|
||||
auto* stream = ctx->op_device_context()->stream();
|
||||
OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));
|
||||
|
||||
auto a_ptr = AsDeviceMemory(a.template flat<T>().data(),
|
||||
a.template flat<T>().size());
|
||||
auto b_ptr = AsDeviceMemory(b.template flat<T>().data(),
|
||||
b.template flat<T>().size());
|
||||
auto c_ptr = AsDeviceMemory(out->template flat<T>().data(),
|
||||
out->template flat<T>().size());
|
||||
auto alpha = static_cast<T>(1.0);
|
||||
auto beta = static_cast<T>(0.0);
|
||||
|
||||
int device_id = stream->parent()->device_ordinal();
|
||||
DataType dtype = a.dtype();
|
||||
MatmulParameters matmul_parameters = {
|
||||
transpose_a, transpose_b, m, n, k, dtype, device_id,
|
||||
};
|
||||
AlgorithmConfig algorithm_config(kNoAlgorithm);
|
||||
|
||||
ComputationType computation_type;
|
||||
bool compute_type_supported =
|
||||
GetCublasAutotuneComputationType(dtype, &computation_type);
|
||||
if (use_autotune && compute_type_supported && !algorithms->empty()) {
|
||||
ProfileResult best_result;
|
||||
// TODO(yangzihao): Unify this code with conv autotuning.
|
||||
if (!AutoTuneMatmul::GetInstance()->Find(matmul_parameters,
|
||||
&algorithm_config)) {
|
||||
ProfileResult profile_result;
|
||||
for (auto profile_algorithm : (*algorithms)) {
|
||||
// Cublas does
|
||||
// C = A x B
|
||||
// where A, B and C are assumed to be in column major.
|
||||
// We want the output to be in row-major, so we can compute
|
||||
// C' = B' x A' (' stands for transpose)
|
||||
bool cublas_launch_status =
|
||||
stream
|
||||
->ThenBlasGemmWithAlgorithm(
|
||||
blas_transpose_b, blas_transpose_a, n, m, k, alpha, b_ptr,
|
||||
transpose_b ? k : n, a_ptr, transpose_a ? m : k, beta,
|
||||
&c_ptr, n, computation_type, profile_algorithm,
|
||||
&profile_result)
|
||||
.ok();
|
||||
if (cublas_launch_status) {
|
||||
if (profile_result.is_valid()) {
|
||||
if (profile_result.elapsed_time_in_ms() <
|
||||
best_result.elapsed_time_in_ms()) {
|
||||
best_result = profile_result;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Try BlasGemmWithProfiling
|
||||
bool cublas_launch_status =
|
||||
stream
|
||||
->ThenBlasGemmWithProfiling(
|
||||
blas_transpose_b, blas_transpose_a, n, m, k, 1.0, b_ptr,
|
||||
transpose_b ? k : n, a_ptr, transpose_a ? m : k, 0.0,
|
||||
&c_ptr, n, &profile_result)
|
||||
.ok();
|
||||
if (cublas_launch_status) {
|
||||
if (profile_result.is_valid()) {
|
||||
if (profile_result.elapsed_time_in_ms() <
|
||||
best_result.elapsed_time_in_ms()) {
|
||||
best_result = profile_result;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Try BlasGemvWithProfiling
|
||||
if (ShouldUseGemv<T>(n)) {
|
||||
LaunchBlasGemv<T>::Compute(ctx, stream, !transpose_a,
|
||||
transpose_a ? m : k, transpose_a ? k : m,
|
||||
a_ptr, b_ptr, &c_ptr, &profile_result);
|
||||
if (profile_result.is_valid()) {
|
||||
if (profile_result.elapsed_time_in_ms() <
|
||||
best_result.elapsed_time_in_ms()) {
|
||||
best_result = profile_result;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// We make sure that each matmul parameter set only gets one pass of
|
||||
// autotune. If the best result is found, assign it to algorithm_type
|
||||
// and insert it to autotune map. If all internal kernels of
|
||||
// cublasGemmEx() returns invalid results, we add kNoAlgorithm to the
|
||||
// autotune map.
|
||||
if (best_result.is_valid()) {
|
||||
algorithm_config.set_algorithm(best_result.algorithm());
|
||||
}
|
||||
AutoTuneMatmul::GetInstance()->Insert(matmul_parameters,
|
||||
algorithm_config);
|
||||
if (algorithm_config.algorithm() != kNoAlgorithm &&
|
||||
algorithm_config.algorithm() != kDefaultBlasGemm &&
|
||||
algorithm_config.algorithm() != kDefaultBlasGemv) {
|
||||
bool cublas_launch_status =
|
||||
stream
|
||||
->ThenBlasGemmWithAlgorithm(
|
||||
blas_transpose_b, blas_transpose_a, n, m, k, alpha, b_ptr,
|
||||
transpose_b ? k : n, a_ptr, transpose_a ? m : k, beta,
|
||||
&c_ptr, n, computation_type, algorithm_config.algorithm(),
|
||||
nullptr)
|
||||
.ok();
|
||||
if (!cublas_launch_status) {
|
||||
ctx->SetStatus(errors::Internal(
|
||||
"Blas GEMM with algorithm launch failed : a.shape=(",
|
||||
a.dim_size(0), ", ", a.dim_size(1), "), b.shape=(", b.dim_size(0),
|
||||
", ", b.dim_size(1), "), m=", m, ", n=", n, ", k=", k));
|
||||
}
|
||||
}
|
||||
}
|
||||
// For the following case, we use normal BlasGemm():
|
||||
// 1) We didn't set the use_autotune flag;
|
||||
// 2) compute type does not support autotune;
|
||||
// 3) no algorithm is found;
|
||||
// 4) all internal kernels in autotune return invalid results.
|
||||
// For the following case, we use normal BlasGemv():
|
||||
// 1) We didn't set the use_autotune flag but LaunchBlasGemv is supported
|
||||
// and n == 1.
|
||||
// 2) We set the use_autotune flag and it picked up BlasGemv() and set the
|
||||
// algorithm_config.algorithm() to be kDefaultBlasGemv.
|
||||
if (!use_autotune || !compute_type_supported || algorithms->empty() ||
|
||||
algorithm_config.algorithm() == kNoAlgorithm ||
|
||||
algorithm_config.algorithm() == kDefaultBlasGemm ||
|
||||
algorithm_config.algorithm() == kDefaultBlasGemv) {
|
||||
if (algorithm_config.algorithm() == kDefaultBlasGemv ||
|
||||
ShouldUseGemv<T>(n)) {
|
||||
// This is a matrix*vector multiply so use GEMV to compute A * b.
|
||||
// Here we are multiplying in the natural order, so we have to flip
|
||||
// the transposition flag to compensate for the tensor being stored
|
||||
// row-major.
|
||||
// TODO(yangzihao): Add Gemv as an autotuning option too.
|
||||
LaunchBlasGemv<T>::Compute(ctx, stream, !transpose_a,
|
||||
transpose_a ? m : k, transpose_a ? k : m,
|
||||
a_ptr, b_ptr, &c_ptr, nullptr);
|
||||
} else {
|
||||
// Use C' = B' x A' (' stands for transpose)
|
||||
bool blas_launch_status =
|
||||
stream
|
||||
->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
|
||||
1.0f, b_ptr, transpose_b ? k : n, a_ptr,
|
||||
transpose_a ? m : k, 0.0f, &c_ptr, n)
|
||||
.ok();
|
||||
if (!blas_launch_status) {
|
||||
ctx->SetStatus(errors::Internal(
|
||||
"Blas GEMM launch failed : a.shape=(", a.dim_size(0), ", ",
|
||||
a.dim_size(1), "), b.shape=(", b.dim_size(0), ", ", b.dim_size(1),
|
||||
"), m=", m, ", n=", n, ", k=", k));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void GetBlasGemmAlgorithm(OpKernelConstruction* ctx,
|
||||
std::vector<int64>* algorithms,
|
||||
bool* algorithm_set_flag) {
|
||||
if (*algorithm_set_flag == false) {
|
||||
auto* stream = ctx->device()->tensorflow_gpu_device_info()->stream;
|
||||
stream->parent()->GetBlasGemmAlgorithms(algorithms);
|
||||
*algorithm_set_flag = true;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
template <typename Device, typename T, bool USE_CUBLAS>
|
||||
class MatMulOp : public OpKernel {
|
||||
public:
|
||||
explicit MatMulOp(OpKernelConstruction* ctx)
|
||||
: OpKernel(ctx), algorithms_set_already_(false) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));
|
||||
|
||||
LaunchMatMul<Device, T, USE_CUBLAS>::GetBlasGemmAlgorithm(
|
||||
ctx, &algorithms_, &algorithms_set_already_);
|
||||
use_autotune_ = MatmulAutotuneEnable();
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
const Tensor& a = ctx->input(0);
|
||||
const Tensor& b = ctx->input(1);
|
||||
|
||||
// Check that the dimensions of the two matrices are valid.
|
||||
OP_REQUIRES(
|
||||
ctx, TensorShapeUtils::IsMatrix(a.shape()),
|
||||
errors::InvalidArgument("In[0] is not a matrix. Instead it has shape ",
|
||||
a.shape().DebugString()));
|
||||
OP_REQUIRES(
|
||||
ctx, TensorShapeUtils::IsMatrix(b.shape()),
|
||||
errors::InvalidArgument("In[1] is not a matrix. Instead it has shape ",
|
||||
b.shape().DebugString()));
|
||||
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
|
||||
dim_pair[0].first = transpose_a_ ? 0 : 1;
|
||||
dim_pair[0].second = transpose_b_ ? 1 : 0;
|
||||
|
||||
OP_REQUIRES(
|
||||
ctx, a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
|
||||
errors::InvalidArgument(
|
||||
"Matrix size-incompatible: In[0]: ", a.shape().DebugString(),
|
||||
", In[1]: ", b.shape().DebugString()));
|
||||
int a_dim_remaining = 1 - dim_pair[0].first;
|
||||
int b_dim_remaining = 1 - dim_pair[0].second;
|
||||
TensorShape out_shape(
|
||||
{a.dim_size(a_dim_remaining), b.dim_size(b_dim_remaining)});
|
||||
Tensor* out = nullptr;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
|
||||
|
||||
if (out->NumElements() == 0) {
|
||||
// If a has shape [0, x] or b has shape [x, 0], the output shape
|
||||
// is a 0-element matrix, so there is nothing to do.
|
||||
return;
|
||||
}
|
||||
|
||||
if (a.NumElements() == 0 && b.NumElements() == 0) {
|
||||
// If a has shape [x, 0] and b has shape [0, y], the
|
||||
// output shape is [x, y] where x and y are non-zero, so we fill
|
||||
// the output with zeros.
|
||||
functor::SetZeroFunctor<Device, T> f;
|
||||
f(ctx->eigen_device<Device>(), out->flat<T>());
|
||||
return;
|
||||
}
|
||||
|
||||
if (std::is_same<T, bfloat16>::value) {
|
||||
bool is_cpu = std::is_same<Device, CPUDevice>::value;
|
||||
OP_REQUIRES(ctx, is_cpu,
|
||||
errors::Internal("bfloat16 matmul is not supported by GPU"));
|
||||
Tensor a_float, b_float, out_float;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, a.shape(), &a_float));
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, b.shape(), &b_float));
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ctx->allocate_temp(DT_FLOAT, out->shape(), &out_float));
|
||||
|
||||
// TODO: Avoid extra copy to make bfloat16 matmul efficient on CPU.
|
||||
BFloat16ToFloat(a.flat<bfloat16>().data(), a_float.flat<float>().data(),
|
||||
a.NumElements());
|
||||
BFloat16ToFloat(b.flat<bfloat16>().data(), b_float.flat<float>().data(),
|
||||
b.NumElements());
|
||||
|
||||
LaunchMatMul<Device, float, USE_CUBLAS>::launch(
|
||||
ctx, a_float, b_float, dim_pair, &algorithms_, use_autotune_,
|
||||
&out_float);
|
||||
FloatToBFloat16(out_float.flat<float>().data(),
|
||||
out->flat<bfloat16>().data(), out->NumElements());
|
||||
} else {
|
||||
LaunchMatMul<Device, T, USE_CUBLAS>::launch(
|
||||
ctx, a, b, dim_pair, &algorithms_, use_autotune_, out);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<int64> algorithms_;
|
||||
bool algorithms_set_already_;
|
||||
bool use_autotune_;
|
||||
bool transpose_a_;
|
||||
bool transpose_b_;
|
||||
};
|
||||
|
||||
namespace functor {
|
||||
|
||||
// Partial specialization MatMulFunctor<Device=CPUDevice, T>.
|
||||
template <typename T>
|
||||
struct MatMulFunctor<CPUDevice, T> {
|
||||
void operator()(
|
||||
const CPUDevice& d, typename MatMulTypes<T>::out_type out,
|
||||
typename MatMulTypes<T>::in_type in0,
|
||||
typename MatMulTypes<T>::in_type in1,
|
||||
const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair) {
|
||||
MatMul<CPUDevice>(d, out, in0, in1, dim_pair);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
} // end namespace functor
|
||||
|
||||
#define REGISTER_CPU_EIGEN(T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T").Label("eigen"), \
|
||||
MatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>);
|
||||
|
||||
#define REGISTER_CPU(T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
|
||||
MatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>); \
|
||||
REGISTER_CPU_EIGEN(T);
|
||||
|
||||
#define REGISTER_GPU(T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("MatMul").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
|
||||
MatMulOp<GPUDevice, T, true /* cublas, true by default */>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("MatMul") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.Label("cublas"), \
|
||||
MatMulOp<GPUDevice, T, true /* cublas */>)
|
||||
|
||||
TF_CALL_int32(REGISTER_CPU);
|
||||
TF_CALL_int64(REGISTER_CPU);
|
||||
TF_CALL_FLOAT_TYPES(REGISTER_CPU);
|
||||
TF_CALL_COMPLEX_TYPES(REGISTER_CPU);
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
|
||||
TF_CALL_COMPLEX_TYPES(REGISTER_GPU);
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
} // namespace tensorflow
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/kernels/batch_matmul_op_impl.h"
|
||||
#include "tensorflow/core/kernels/matmul_op_impl.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
@ -15,8 +15,8 @@ limitations under the License.
|
||||
|
||||
// See docs in ../ops/math_ops.cc.
|
||||
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
@ -633,10 +633,21 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
|
||||
template <typename Device, typename Scalar>
|
||||
class BaseBatchMatMulOp : public OpKernel {
|
||||
public:
|
||||
explicit BaseBatchMatMulOp(OpKernelConstruction* context)
|
||||
explicit BaseBatchMatMulOp(OpKernelConstruction* context,
|
||||
bool is_legacy_matmul)
|
||||
: OpKernel(context) {
|
||||
if (is_legacy_matmul) {
|
||||
// The old MatMul kernel has "transpose_a/transpose_b" attributes.
|
||||
OP_REQUIRES_OK(context, context->GetAttr("transpose_a", &trans_x_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("transpose_b", &trans_y_));
|
||||
adj_x_ = false;
|
||||
adj_y_ = false;
|
||||
} else {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_));
|
||||
trans_x_ = false;
|
||||
trans_y_ = false;
|
||||
}
|
||||
}
|
||||
|
||||
~BaseBatchMatMulOp() override {}
|
||||
@ -672,8 +683,8 @@ class BaseBatchMatMulOp : public OpKernel {
|
||||
in1_reshaped.CopyFrom(in1, TensorShape({bcast.y_batch_size(), d2, d3})),
|
||||
errors::Internal("Failed to reshape In[1] from ",
|
||||
in1.shape().DebugString()));
|
||||
if (adj_x_) std::swap(d0, d1);
|
||||
if (adj_y_) std::swap(d2, d3);
|
||||
if (adj_x_ || trans_x_) std::swap(d0, d1);
|
||||
if (adj_y_ || trans_y_) std::swap(d2, d3);
|
||||
OP_REQUIRES(ctx, d1 == d2,
|
||||
errors::InvalidArgument(
|
||||
"In[0] mismatch In[1] shape: ", d1, " vs. ", d2, ": ",
|
||||
@ -696,9 +707,36 @@ class BaseBatchMatMulOp : public OpKernel {
|
||||
out_reshaped.CopyFrom(*out, TensorShape({batch_size, d0, d3})),
|
||||
errors::Internal("Failed to reshape output from ",
|
||||
out->shape().DebugString()));
|
||||
LaunchBatchMatMul<Device, Scalar>::Launch(
|
||||
ctx, in0_reshaped, in1_reshaped, adj_x_, adj_y_, /*trans_x=*/false,
|
||||
/*trans_y=*/false, bcast, &out_reshaped);
|
||||
if (std::is_same<Scalar, bfloat16>::value) {
|
||||
bool is_cpu = std::is_same<Device, CPUDevice>::value;
|
||||
OP_REQUIRES(ctx, is_cpu,
|
||||
errors::Internal("bfloat16 matmul is not supported by GPU"));
|
||||
Tensor in0_reshaped_float, in1_reshaped_float, out_reshaped_float;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, in0_reshaped.shape(),
|
||||
&in0_reshaped_float));
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, in1_reshaped.shape(),
|
||||
&in1_reshaped_float));
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, out_reshaped.shape(),
|
||||
&out_reshaped_float));
|
||||
|
||||
// TODO: Avoid extra copy to make bfloat16 matmul efficient on CPU.
|
||||
BFloat16ToFloat(in0_reshaped.flat<bfloat16>().data(),
|
||||
in0_reshaped_float.flat<float>().data(),
|
||||
in0_reshaped.NumElements());
|
||||
BFloat16ToFloat(in1_reshaped.flat<bfloat16>().data(),
|
||||
in1_reshaped_float.flat<float>().data(),
|
||||
in1_reshaped.NumElements());
|
||||
|
||||
LaunchBatchMatMul<Device, float>::Launch(
|
||||
ctx, in0_reshaped_float, in1_reshaped_float, adj_x_, adj_y_, trans_x_,
|
||||
trans_y_, bcast, &out_reshaped_float);
|
||||
FloatToBFloat16(out_reshaped_float.flat<float>().data(),
|
||||
out_reshaped.flat<bfloat16>().data(), out->NumElements());
|
||||
} else {
|
||||
LaunchBatchMatMul<Device, Scalar>::Launch(ctx, in0_reshaped, in1_reshaped,
|
||||
adj_x_, adj_y_, trans_x_,
|
||||
trans_y_, bcast, &out_reshaped);
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
@ -706,16 +744,19 @@ class BaseBatchMatMulOp : public OpKernel {
|
||||
const Tensor& in1) = 0;
|
||||
|
||||
private:
|
||||
// TODO(171979567) Make the ops take both adj and transpose attributes.
|
||||
bool adj_x_;
|
||||
bool adj_y_;
|
||||
bool trans_x_;
|
||||
bool trans_y_;
|
||||
};
|
||||
|
||||
// BatchMatMul Op implementation which disallows broadcasting.
|
||||
template <typename Device, typename Scalar>
|
||||
template <typename Device, typename Scalar, bool is_legacy_matmul = false>
|
||||
class BatchMatMulOp : public BaseBatchMatMulOp<Device, Scalar> {
|
||||
public:
|
||||
explicit BatchMatMulOp(OpKernelConstruction* context)
|
||||
: BaseBatchMatMulOp<Device, Scalar>(context) {}
|
||||
: BaseBatchMatMulOp<Device, Scalar>(context, is_legacy_matmul) {}
|
||||
|
||||
~BatchMatMulOp() override {}
|
||||
|
||||
@ -729,15 +770,21 @@ class BatchMatMulOp : public BaseBatchMatMulOp<Device, Scalar> {
|
||||
in0.shape().DebugString(), " vs. ",
|
||||
in1.shape().DebugString()));
|
||||
const int ndims = in0.dims();
|
||||
OP_REQUIRES(
|
||||
ctx, ndims >= 2,
|
||||
errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ", ndims));
|
||||
if (is_legacy_matmul) {
|
||||
OP_REQUIRES(ctx, ndims == 2,
|
||||
errors::InvalidArgument(
|
||||
"In[0] and In[1] ndims must be == 2: ", ndims));
|
||||
} else {
|
||||
OP_REQUIRES(ctx, ndims >= 2,
|
||||
errors::InvalidArgument(
|
||||
"In[0] and In[1] ndims must be >= 2: ", ndims));
|
||||
for (int i = 0; i < ndims - 2; ++i) {
|
||||
OP_REQUIRES(ctx, in0.dim_size(i) == in1.dim_size(i),
|
||||
errors::InvalidArgument(
|
||||
"In[0].dim(", i, ") and In[1].dim(", i,
|
||||
") must be the same: ", in0.shape().DebugString(), " vs ",
|
||||
in1.shape().DebugString()));
|
||||
") must be the same: ", in0.shape().DebugString(),
|
||||
" vs ", in1.shape().DebugString()));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
@ -747,7 +794,8 @@ template <typename Device, typename Scalar>
|
||||
class BatchMatMulV2Op : public BaseBatchMatMulOp<Device, Scalar> {
|
||||
public:
|
||||
explicit BatchMatMulV2Op(OpKernelConstruction* context)
|
||||
: BaseBatchMatMulOp<Device, Scalar>(context) {}
|
||||
: BaseBatchMatMulOp<Device, Scalar>(context,
|
||||
/* is_legacy_matmul= */ false) {}
|
||||
|
||||
~BatchMatMulV2Op() override {}
|
||||
|
||||
@ -771,7 +819,10 @@ class BatchMatMulV2Op : public BaseBatchMatMulOp<Device, Scalar> {
|
||||
BatchMatMulOp<CPUDevice, TYPE>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("BatchMatMulV2").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
|
||||
BatchMatMulV2Op<CPUDevice, TYPE>)
|
||||
BatchMatMulV2Op<CPUDevice, TYPE>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("MatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
|
||||
BatchMatMulOp<CPUDevice, TYPE, /* is_legacy_matmul=*/true>)
|
||||
|
||||
#define REGISTER_BATCH_MATMUL_GPU(TYPE) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
@ -779,8 +830,11 @@ class BatchMatMulV2Op : public BaseBatchMatMulOp<Device, Scalar> {
|
||||
BatchMatMulOp<GPUDevice, TYPE>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("BatchMatMulV2").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
|
||||
BatchMatMulV2Op<GPUDevice, TYPE>)
|
||||
BatchMatMulV2Op<GPUDevice, TYPE>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("MatMul").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
|
||||
BatchMatMulOp<GPUDevice, TYPE, /* is_legacy_matmul=*/true>)
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_
|
||||
#endif // TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/kernels/batch_matmul_op_impl.h"
|
||||
#include "tensorflow/core/kernels/matmul_op_impl.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#include "third_party/gpus/cuda/include/cuda.h"
|
||||
@ -21,17 +21,13 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
TF_CALL_float(REGISTER_BATCH_MATMUL_CPU);
|
||||
TF_CALL_double(REGISTER_BATCH_MATMUL_CPU);
|
||||
TF_CALL_half(REGISTER_BATCH_MATMUL_CPU);
|
||||
TF_CALL_FLOAT_TYPES(REGISTER_BATCH_MATMUL_CPU);
|
||||
TF_CALL_int16(REGISTER_BATCH_MATMUL_CPU);
|
||||
TF_CALL_int32(REGISTER_BATCH_MATMUL_CPU);
|
||||
TF_CALL_int64(REGISTER_BATCH_MATMUL_CPU);
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
TF_CALL_float(REGISTER_BATCH_MATMUL_GPU);
|
||||
TF_CALL_double(REGISTER_BATCH_MATMUL_GPU);
|
||||
TF_CALL_half(REGISTER_BATCH_MATMUL_GPU);
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_BATCH_MATMUL_GPU);
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
} // namespace tensorflow
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
class FusedMatMulOpTest : public OpsTestBase {
|
||||
@ -459,4 +460,230 @@ BM_Matmul(2000, 1, 2000, true, false);
|
||||
BM_Matmul(2000, 1, 2000, false, true);
|
||||
BM_Matmul(2000, 1, 2000, true, true);
|
||||
|
||||
} // end namespace tensorflow
|
||||
// Benchmarks for batched matmul with broadcasting.
|
||||
Node* BroadcastTo(Graph* g, Node* input, Node* shape) {
|
||||
Node* ret;
|
||||
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BroadcastTo")
|
||||
.Input(input)
|
||||
.Input(shape)
|
||||
.Finalize(g, &ret));
|
||||
return ret;
|
||||
}
|
||||
|
||||
Node* BatchMatmulV2(Graph* g, Node* in0, Node* in1, bool adj_x, bool adj_y) {
|
||||
Node* ret;
|
||||
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BatchMatMulV2")
|
||||
.Input(in0)
|
||||
.Input(in1)
|
||||
.Attr("adj_x", adj_x)
|
||||
.Attr("adj_y", adj_y)
|
||||
.Finalize(g, &ret));
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static Graph* BatchMatmul(int b, int m, int k, int n, bool adjoint_a,
|
||||
bool adjoint_b, DataType type) {
|
||||
Graph* g = new Graph(OpRegistry::Global());
|
||||
Tensor in0(type, adjoint_a ? TensorShape({b, k, m}) : TensorShape({b, m, k}));
|
||||
in0.flat<T>().setRandom();
|
||||
Tensor in1(type, adjoint_b ? TensorShape({b, n, k}) : TensorShape({b, k, n}));
|
||||
in1.flat<T>().setRandom();
|
||||
test::graph::BatchMatmul(g, test::graph::Constant(g, in0),
|
||||
test::graph::Constant(g, in1), adjoint_a, adjoint_b);
|
||||
return g;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static Graph* BatchMatmulWithBroadcast(int b0, int b1, int m, int k, int n,
|
||||
bool manual_broadcast, DataType type) {
|
||||
Graph* g = new Graph(OpRegistry::Global());
|
||||
Tensor in0(type, TensorShape({b0, m, k}));
|
||||
in0.flat<T>().setRandom();
|
||||
Tensor in1(type, TensorShape({b1, k, n}));
|
||||
in1.flat<T>().setRandom();
|
||||
|
||||
Tensor broadcasted_in0_shape(DT_INT64, TensorShape({3}));
|
||||
Tensor broadcasted_in1_shape(DT_INT64, TensorShape({3}));
|
||||
|
||||
Node* in0_node = nullptr;
|
||||
Node* in1_node = nullptr;
|
||||
if (manual_broadcast) {
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
auto vec0 = broadcasted_in0_shape.vec<int64>();
|
||||
auto vec1 = broadcasted_in1_shape.vec<int64>();
|
||||
vec0(i) = (i == 0 ? std::max(b0, b1) : in0.shape().dim_size(i));
|
||||
vec1(i) = (i == 0 ? std::max(b0, b1) : in1.shape().dim_size(i));
|
||||
}
|
||||
in0_node = BroadcastTo(g, test::graph::Constant(g, in0),
|
||||
test::graph::Constant(g, broadcasted_in0_shape));
|
||||
in1_node = BroadcastTo(g, test::graph::Constant(g, in1),
|
||||
test::graph::Constant(g, broadcasted_in1_shape));
|
||||
} else {
|
||||
in0_node = test::graph::Constant(g, in0);
|
||||
in1_node = test::graph::Constant(g, in1);
|
||||
}
|
||||
|
||||
BatchMatmulV2(g, in0_node, in1_node, false, false);
|
||||
return g;
|
||||
}
|
||||
|
||||
#define BM_BatchMatmulDev(B, M, K, N, TA, TB, T, TFTYPE, DEVICE) \
|
||||
static void \
|
||||
BM_BatchMatmul##_##B##_##M##_##K##_##N##_##TA##_##TB##_##TFTYPE##_##DEVICE( \
|
||||
int iters) { \
|
||||
testing::UseRealTime(); \
|
||||
testing::ItemsProcessed(static_cast<int64>(iters) * B * M * K * N * 2); \
|
||||
test::Benchmark(#DEVICE, BatchMatmul<T>(B, M, K, N, TA, TB, TFTYPE)) \
|
||||
.Run(iters); \
|
||||
} \
|
||||
BENCHMARK( \
|
||||
BM_BatchMatmul##_##B##_##M##_##K##_##N##_##TA##_##TB##_##TFTYPE##_##DEVICE);
|
||||
|
||||
#define BM_BatchMatmul(B, M, K, N, TA, TB) \
|
||||
BM_BatchMatmulDev(B, M, K, N, TA, TB, float, DT_FLOAT, cpu);
|
||||
// BM_BatchMatmulDev(B, M, K, N, TA, TB, std::complex<float>, DT_COMPLEX64,
|
||||
// cpu);
|
||||
// BM_BatchMatmulDev(B, M, K, N, TA, TB, float, DT_FLOAT, gpu);
|
||||
/* Uncomment to enable benchmarks for double & complex types: */
|
||||
// BM_BatchMatmulDev(B, M, K, N, TA, TB, std::complex<float>, DT_COMPLEX64,
|
||||
// gpu);
|
||||
// BM_BatchMatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, cpu); \
|
||||
// BM_BatchMatmulDev(M, K, N, TA, TB, std::complex<double>, DT_COMPLEX128, cpu);
|
||||
// \
|
||||
// BM_BatchMatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, gpu); \
|
||||
// BM_BatchMatmulDev(M, K, N, TA, TB, std::complex<double>, DT_COMPLEX128, gpu);
|
||||
|
||||
// Macro arguments names: --------------------------------------------------- //
|
||||
// B1: batch size of LHS
|
||||
// B2: batch size of RHS
|
||||
// M: outer dimension of LHS
|
||||
// K: inner dimensions of LHS and RHS
|
||||
// N: outer dimension of RHS
|
||||
// MB: boolean indicating whether to use manual broadcasting
|
||||
// T: C++ type of scalars (e.g. float, std::complex)
|
||||
// TT: TensorFlow type of scalars (e.g. DT_FLOAT, DT_COMPLEX128
|
||||
// D: Device (e.g. cpu, gpu)
|
||||
#define BM_BatchMatmulBCastDev(B1, B2, M, K, N, MB, T, TT, D) \
|
||||
static void \
|
||||
BM_BatchMatmulBCast##_##B1##_##B2##_##M##_##K##_##N##_##MB##_##TT##_##D( \
|
||||
int iters) { \
|
||||
testing::UseRealTime(); \
|
||||
testing::ItemsProcessed(static_cast<int64>(iters) * std::max(B1, B2) * M * \
|
||||
K * N * 2); \
|
||||
test::Benchmark(#D, BatchMatmulWithBroadcast<T>(B1, B2, M, K, N, MB, TT)) \
|
||||
.Run(iters); \
|
||||
} \
|
||||
BENCHMARK( \
|
||||
BM_BatchMatmulBCast##_##B1##_##B2##_##M##_##K##_##N##_##MB##_##TT##_##D);
|
||||
|
||||
#define BM_BatchMatmulBCast(B1, B2, M, K, N, MB) \
|
||||
BM_BatchMatmulBCastDev(B1, B2, M, K, N, MB, float, DT_FLOAT, cpu);
|
||||
|
||||
// Typical fully connected layers
|
||||
BM_BatchMatmulBCast(1, 128, 1, 1024, 1024, true);
|
||||
BM_BatchMatmulBCast(1, 128, 1, 1024, 1024, false);
|
||||
BM_BatchMatmulBCast(128, 1, 1, 1024, 1024, true);
|
||||
BM_BatchMatmulBCast(128, 1, 1, 1024, 1024, false);
|
||||
BM_BatchMatmulBCast(1, 128, 128, 1024, 1024, true);
|
||||
BM_BatchMatmulBCast(1, 128, 128, 1024, 1024, false);
|
||||
BM_BatchMatmulBCast(128, 1, 128, 1024, 1024, true);
|
||||
BM_BatchMatmulBCast(128, 1, 128, 1024, 1024, false);
|
||||
|
||||
// Square matmul.
|
||||
BM_BatchMatmulBCast(1, 128, 512, 512, 512, true);
|
||||
BM_BatchMatmulBCast(1, 128, 512, 512, 512, false);
|
||||
BM_BatchMatmulBCast(128, 1, 512, 512, 512, true);
|
||||
BM_BatchMatmulBCast(128, 1, 512, 512, 512, false);
|
||||
BM_BatchMatmulBCast(1, 128, 1024, 1024, 1024, true);
|
||||
BM_BatchMatmulBCast(1, 128, 1024, 1024, 1024, false);
|
||||
BM_BatchMatmulBCast(128, 1, 1024, 1024, 1024, true);
|
||||
BM_BatchMatmulBCast(128, 1, 1024, 1024, 1024, false);
|
||||
|
||||
// Matrix-vector multiplies.
|
||||
BM_BatchMatmulBCast(1, 128, 10000, 200, 1, true);
|
||||
BM_BatchMatmulBCast(1, 128, 10000, 200, 1, false);
|
||||
BM_BatchMatmulBCast(128, 1, 10000, 200, 1, true);
|
||||
BM_BatchMatmulBCast(128, 1, 10000, 200, 1, false);
|
||||
|
||||
// Vector-matrix multiplies.
|
||||
BM_BatchMatmulBCast(1, 128, 1, 200, 10000, true);
|
||||
BM_BatchMatmulBCast(1, 128, 1, 200, 10000, false);
|
||||
BM_BatchMatmulBCast(128, 1, 1, 200, 10000, true);
|
||||
BM_BatchMatmulBCast(128, 1, 1, 200, 10000, false);
|
||||
|
||||
// Typical fully connected layers
|
||||
BM_BatchMatmul(1, 1, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(1, 8, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(1, 16, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(1, 128, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(2, 1, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(2, 8, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(2, 16, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(2, 128, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(8, 1, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(8, 8, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(8, 16, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(8, 128, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(32, 1, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(32, 8, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(32, 16, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(32, 128, 1024, 1024, false, false);
|
||||
|
||||
// Square matmul.
|
||||
BM_BatchMatmul(1, 32, 32, 32, false, false);
|
||||
BM_BatchMatmul(1, 128, 128, 128, false, false);
|
||||
BM_BatchMatmul(1, 256, 256, 256, false, false);
|
||||
BM_BatchMatmul(1, 1024, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(1, 2048, 2048, 2048, false, false);
|
||||
BM_BatchMatmul(2, 32, 32, 32, false, false);
|
||||
BM_BatchMatmul(2, 128, 128, 128, false, false);
|
||||
BM_BatchMatmul(2, 256, 256, 256, false, false);
|
||||
BM_BatchMatmul(2, 1024, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(2, 2048, 2048, 2048, false, false);
|
||||
BM_BatchMatmul(4, 32, 32, 32, false, false);
|
||||
BM_BatchMatmul(4, 128, 128, 128, false, false);
|
||||
BM_BatchMatmul(4, 256, 256, 256, false, false);
|
||||
BM_BatchMatmul(4, 1024, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(4, 2048, 2048, 2048, false, false);
|
||||
BM_BatchMatmul(8, 32, 32, 32, false, false);
|
||||
BM_BatchMatmul(8, 128, 128, 128, false, false);
|
||||
BM_BatchMatmul(8, 256, 256, 256, false, false);
|
||||
BM_BatchMatmul(8, 1024, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(8, 2048, 2048, 2048, false, false);
|
||||
BM_BatchMatmul(32, 32, 32, 32, false, false);
|
||||
BM_BatchMatmul(32, 128, 128, 128, false, false);
|
||||
BM_BatchMatmul(32, 256, 256, 256, false, false);
|
||||
BM_BatchMatmul(32, 1024, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(32, 2048, 2048, 2048, false, false);
|
||||
|
||||
// Matrix-vector multiplies.
|
||||
BM_BatchMatmul(1, 10000, 200, 1, false, false);
|
||||
BM_BatchMatmul(8, 10000, 200, 1, false, false);
|
||||
BM_BatchMatmul(32, 10000, 200, 1, false, false);
|
||||
BM_BatchMatmul(1, 10000, 200, 1, true, false);
|
||||
BM_BatchMatmul(8, 10000, 200, 1, true, false);
|
||||
BM_BatchMatmul(32, 10000, 200, 1, true, false);
|
||||
BM_BatchMatmul(1, 10000, 200, 1, false, true);
|
||||
BM_BatchMatmul(8, 10000, 200, 1, false, true);
|
||||
BM_BatchMatmul(32, 10000, 200, 1, false, true);
|
||||
BM_BatchMatmul(1, 10000, 200, 1, true, true);
|
||||
BM_BatchMatmul(8, 10000, 200, 1, true, true);
|
||||
BM_BatchMatmul(32, 10000, 200, 1, true, true);
|
||||
|
||||
// Vector-matrix multiplies.
|
||||
BM_BatchMatmul(1, 1, 200, 10000, false, false);
|
||||
BM_BatchMatmul(8, 1, 200, 10000, false, false);
|
||||
BM_BatchMatmul(32, 1, 200, 10000, false, false);
|
||||
BM_BatchMatmul(1, 1, 200, 10000, true, false);
|
||||
BM_BatchMatmul(8, 1, 200, 10000, true, false);
|
||||
BM_BatchMatmul(32, 1, 200, 10000, true, false);
|
||||
BM_BatchMatmul(1, 1, 200, 10000, false, true);
|
||||
BM_BatchMatmul(8, 1, 200, 10000, false, true);
|
||||
BM_BatchMatmul(32, 1, 200, 10000, false, true);
|
||||
BM_BatchMatmul(1, 1, 200, 10000, true, true);
|
||||
BM_BatchMatmul(8, 1, 200, 10000, true, true);
|
||||
BM_BatchMatmul(32, 1, 200, 10000, true, true);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -33,8 +33,8 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/type_traits.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/kernels/batch_matmul_op_impl.h"
|
||||
#include "tensorflow/core/kernels/fill_functor.h"
|
||||
#include "tensorflow/core/kernels/matmul_op_impl.h"
|
||||
#include "tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
@ -1,5 +1,5 @@
|
||||
func @Isinf_elem_type(%arg0: tensor<*xelem_type>)
|
||||
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
|
||||
-> tensor<*xi1> attributes {tf_entry, llvm.emit_c_interface} {
|
||||
%0 = "tf.IsInf"(%arg0) : (tensor<*xelem_type>) -> tensor<*xi1>
|
||||
return %0 : tensor<*xi1>
|
||||
}
|
||||
|
@ -1,5 +1,5 @@
|
||||
func @Isnan_elem_type(%arg0: tensor<*xelem_type>)
|
||||
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
|
||||
-> tensor<*xi1> attributes {tf_entry, llvm.emit_c_interface} {
|
||||
%0 = "tf.IsNan"(%arg0) : (tensor<*xelem_type>) -> tensor<*xi1>
|
||||
return %0 : tensor<*xi1>
|
||||
}
|
||||
|
@ -41,9 +41,13 @@ static Graph* Multinomial(int batch_size, int num_classes, int num_samples) {
|
||||
}
|
||||
|
||||
#define BM_MultinomialDev(DEVICE, B, C, S) \
|
||||
static void BM_Multinomial_##DEVICE##_##B##_##C##_##S(int iters) { \
|
||||
test::Benchmark(#DEVICE, Multinomial(B, C, S)).Run(iters); \
|
||||
testing::ItemsProcessed(static_cast<int64>(B) * C * S * iters); \
|
||||
static void BM_Multinomial_##DEVICE##_##B##_##C##_##S( \
|
||||
::testing::benchmark::State& state) { \
|
||||
test::Benchmark(#DEVICE, Multinomial(B, C, S), \
|
||||
/*old_benchmark_api*/ false) \
|
||||
.Run(state); \
|
||||
state.SetItemsProcessed(static_cast<int64>(B) * C * S * \
|
||||
state.iterations()); \
|
||||
} \
|
||||
BENCHMARK(BM_Multinomial_##DEVICE##_##B##_##C##_##S);
|
||||
|
||||
|
@ -103,18 +103,18 @@ enum CONV_OP {
|
||||
|
||||
} // namespace
|
||||
|
||||
static void BM_ConvFloat(int iters, int batch, int rows, int cols, int in_depth,
|
||||
int out_depth, int filter_rows, int filter_cols,
|
||||
CONV_OP op, int num_threads, int stride,
|
||||
Padding padding, bool use_gpu, DataType data_type,
|
||||
static void BM_ConvFloat(::testing::benchmark::State& state, int batch,
|
||||
int rows, int cols, int in_depth, int out_depth,
|
||||
int filter_rows, int filter_cols, CONV_OP op,
|
||||
int num_threads, int stride, Padding padding,
|
||||
bool use_gpu, DataType data_type,
|
||||
const string& label) {
|
||||
testing::StopTiming();
|
||||
if (!IsGoogleCudaEnabled() && use_gpu) {
|
||||
testing::SetLabel(
|
||||
state.SetLabel(
|
||||
strings::StrCat("Skipping GPU test (no --config=cuda): ", label));
|
||||
return;
|
||||
}
|
||||
testing::SetLabel(label);
|
||||
state.SetLabel(label);
|
||||
|
||||
// Set the number of threads
|
||||
SessionOptions options;
|
||||
@ -221,10 +221,10 @@ static void BM_ConvFloat(int iters, int batch, int rows, int cols, int in_depth,
|
||||
TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph, g));
|
||||
|
||||
string device = use_gpu ? "gpu" : "cpu";
|
||||
testing::UseRealTime();
|
||||
testing::StartTiming();
|
||||
test::Benchmark(device, g, &options).Run(iters);
|
||||
testing::ItemsProcessed(num_ops * iters);
|
||||
test::Benchmark(device, g, &options, nullptr, nullptr, "",
|
||||
/*old_benchmark_api*/ false)
|
||||
.Run(state);
|
||||
state.SetItemsProcessed(num_ops * state.iterations());
|
||||
}
|
||||
|
||||
// BS: batch_size
|
||||
@ -235,48 +235,52 @@ static void BM_ConvFloat(int iters, int batch, int rows, int cols, int in_depth,
|
||||
// KR: kernel_rows
|
||||
// KC: kernel_cols
|
||||
#define BM_ConvFloatFwd(BS, R, C, ID, OD, KR, KC, STR, PAD, LABEL) \
|
||||
static void BM_ConvFloatFwdCPU1_##LABEL(int iters) { \
|
||||
BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_FORWARD, 1, STR, \
|
||||
static void BM_ConvFloatFwdCPU1_##LABEL( \
|
||||
::testing::benchmark::State& state) { \
|
||||
BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_FORWARD, 1, STR, \
|
||||
PAD, false, DT_FLOAT, \
|
||||
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \
|
||||
KR, "_", KC, "_", STR, "_", PAD, "_f_cpu1")); \
|
||||
} \
|
||||
static void BM_ConvFloatFwdCPU4_##LABEL(int iters) { \
|
||||
BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_FORWARD, 4, STR, \
|
||||
static void BM_ConvFloatFwdCPU4_##LABEL( \
|
||||
::testing::benchmark::State& state) { \
|
||||
BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_FORWARD, 4, STR, \
|
||||
PAD, false, DT_FLOAT, \
|
||||
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \
|
||||
KR, "_", KC, "_", STR, "_", PAD, "_f_cpu4")); \
|
||||
} \
|
||||
static void BM_ConvFloatFusedCPU1_##LABEL(int iters) { \
|
||||
BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_FUSED, 1, STR, PAD, \
|
||||
static void BM_ConvFloatFusedCPU1_##LABEL( \
|
||||
::testing::benchmark::State& state) { \
|
||||
BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_FUSED, 1, STR, PAD, \
|
||||
false, DT_FLOAT, \
|
||||
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \
|
||||
KR, "_", KC, "_", STR, "_", PAD, "_f_cpu1")); \
|
||||
} \
|
||||
static void BM_ConvFloatFusedCPU4_##LABEL(int iters) { \
|
||||
BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_FUSED, 4, STR, PAD, \
|
||||
static void BM_ConvFloatFusedCPU4_##LABEL( \
|
||||
::testing::benchmark::State& state) { \
|
||||
BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_FUSED, 4, STR, PAD, \
|
||||
false, DT_FLOAT, \
|
||||
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \
|
||||
KR, "_", KC, "_", STR, "_", PAD, "_f_cpu4")); \
|
||||
} \
|
||||
static void BM_ConvFloatFwdGPU_##LABEL(int iters) { \
|
||||
BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_FORWARD, 1, STR, \
|
||||
static void BM_ConvFloatFwdGPU_##LABEL(::testing::benchmark::State& state) { \
|
||||
BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_FORWARD, 1, STR, \
|
||||
PAD, true, DT_FLOAT, \
|
||||
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \
|
||||
KR, "_", KC, "_", STR, "_", PAD, "_f_gpu")); \
|
||||
} \
|
||||
static void BM_ConvHalfFwdGPU_##LABEL(int iters) { \
|
||||
BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_FORWARD, 1, STR, \
|
||||
static void BM_ConvHalfFwdGPU_##LABEL(::testing::benchmark::State& state) { \
|
||||
BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_FORWARD, 1, STR, \
|
||||
PAD, true, DT_HALF, \
|
||||
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \
|
||||
KR, "_", KC, "_", STR, "_", PAD, "_h_gpu")); \
|
||||
} \
|
||||
BENCHMARK(BM_ConvFloatFwdCPU1_##LABEL); \
|
||||
BENCHMARK(BM_ConvFloatFwdCPU4_##LABEL); \
|
||||
BENCHMARK(BM_ConvFloatFusedCPU1_##LABEL); \
|
||||
BENCHMARK(BM_ConvFloatFusedCPU4_##LABEL); \
|
||||
BENCHMARK(BM_ConvFloatFwdGPU_##LABEL); \
|
||||
BENCHMARK(BM_ConvHalfFwdGPU_##LABEL)
|
||||
BENCHMARK(BM_ConvFloatFwdCPU1_##LABEL)->UseRealTime(); \
|
||||
BENCHMARK(BM_ConvFloatFwdCPU4_##LABEL)->UseRealTime(); \
|
||||
BENCHMARK(BM_ConvFloatFusedCPU1_##LABEL)->UseRealTime(); \
|
||||
BENCHMARK(BM_ConvFloatFusedCPU4_##LABEL)->UseRealTime(); \
|
||||
BENCHMARK(BM_ConvFloatFwdGPU_##LABEL)->UseRealTime(); \
|
||||
BENCHMARK(BM_ConvHalfFwdGPU_##LABEL)->UseRealTime()
|
||||
|
||||
BM_ConvFloatFwd(32, 5, 5, 1248, 128, 1, 1, 1, SAME, conv0);
|
||||
BM_ConvFloatFwd(32, 8, 8, 384, 384, 1, 3, 1, SAME, conv1);
|
||||
@ -335,62 +339,69 @@ BM_ConvFloatFwd(32, 73, 73, 64, 64, 1, 1, 1, VALID, conv53);
|
||||
BM_ConvFloatFwd(32, 147, 147, 24, 64, 1, 1, 1, VALID, conv54);
|
||||
|
||||
#define BM_ConvFloatBkInAndFilter(BS, R, C, ID, OD, KR, KC, STR, PAD, LABEL) \
|
||||
static void BM_ConvFloatBkInCPU1_##LABEL(int iters) { \
|
||||
BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_INPUT, 1, \
|
||||
static void BM_ConvFloatBkInCPU1_##LABEL( \
|
||||
::testing::benchmark::State& state) { \
|
||||
BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_INPUT, 1, \
|
||||
STR, PAD, false, DT_FLOAT, \
|
||||
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \
|
||||
KR, "_", KC, "_", STR, "_", PAD, "_cpu1")); \
|
||||
} \
|
||||
static void BM_ConvFloatBkInCPU4_##LABEL(int iters) { \
|
||||
BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_INPUT, 4, \
|
||||
static void BM_ConvFloatBkInCPU4_##LABEL( \
|
||||
::testing::benchmark::State& state) { \
|
||||
BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_INPUT, 4, \
|
||||
STR, PAD, false, DT_FLOAT, \
|
||||
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \
|
||||
KR, "_", KC, "_", STR, "_", PAD, "_cpu4")); \
|
||||
} \
|
||||
static void BM_ConvFloatBkInGPU_##LABEL(int iters) { \
|
||||
BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_INPUT, 1, \
|
||||
static void BM_ConvFloatBkInGPU_##LABEL( \
|
||||
::testing::benchmark::State& state) { \
|
||||
BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_INPUT, 1, \
|
||||
STR, PAD, true, DT_FLOAT, \
|
||||
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \
|
||||
KR, "_", KC, "_", STR, "_", PAD, "_gpu")); \
|
||||
} \
|
||||
static void BM_ConvFloatBkFilterCPU1_##LABEL(int iters) { \
|
||||
BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1, \
|
||||
static void BM_ConvFloatBkFilterCPU1_##LABEL( \
|
||||
::testing::benchmark::State& state) { \
|
||||
BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1, \
|
||||
STR, PAD, false, DT_FLOAT, \
|
||||
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \
|
||||
KR, "_", KC, "_", STR, "_", PAD, "_cpu1")); \
|
||||
} \
|
||||
static void BM_ConvFloatBkFilterCPU4_##LABEL(int iters) { \
|
||||
BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 4, \
|
||||
static void BM_ConvFloatBkFilterCPU4_##LABEL( \
|
||||
::testing::benchmark::State& state) { \
|
||||
BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 4, \
|
||||
STR, PAD, false, DT_FLOAT, \
|
||||
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \
|
||||
KR, "_", KC, "_", STR, "_", PAD, "_cpu4")); \
|
||||
} \
|
||||
static void BM_ConvFloatBkFilterGPU_##LABEL(int iters) { \
|
||||
BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1, \
|
||||
static void BM_ConvFloatBkFilterGPU_##LABEL( \
|
||||
::testing::benchmark::State& state) { \
|
||||
BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1, \
|
||||
STR, PAD, true, DT_FLOAT, \
|
||||
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \
|
||||
KR, "_", KC, "_", STR, "_", PAD, "_gpu")); \
|
||||
} \
|
||||
static void BM_ConvHalfBkInGPU_##LABEL(int iters) { \
|
||||
BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_INPUT, 1, \
|
||||
static void BM_ConvHalfBkInGPU_##LABEL(::testing::benchmark::State& state) { \
|
||||
BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_INPUT, 1, \
|
||||
STR, PAD, true, DT_HALF, \
|
||||
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \
|
||||
KR, "_", KC, "_", STR, "_", PAD, "_gpu")); \
|
||||
} \
|
||||
static void BM_ConvHalfBkFilterGPU_##LABEL(int iters) { \
|
||||
BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1, \
|
||||
static void BM_ConvHalfBkFilterGPU_##LABEL( \
|
||||
::testing::benchmark::State& state) { \
|
||||
BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1, \
|
||||
STR, PAD, true, DT_HALF, \
|
||||
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \
|
||||
KR, "_", KC, "_", STR, "_", PAD, "_gpu")); \
|
||||
} \
|
||||
BENCHMARK(BM_ConvFloatBkInCPU1_##LABEL); \
|
||||
BENCHMARK(BM_ConvFloatBkInCPU4_##LABEL); \
|
||||
BENCHMARK(BM_ConvFloatBkInGPU_##LABEL); \
|
||||
BENCHMARK(BM_ConvFloatBkFilterCPU1_##LABEL); \
|
||||
BENCHMARK(BM_ConvFloatBkFilterCPU4_##LABEL); \
|
||||
BENCHMARK(BM_ConvFloatBkFilterGPU_##LABEL); \
|
||||
BENCHMARK(BM_ConvHalfBkInGPU_##LABEL); \
|
||||
BENCHMARK(BM_ConvHalfBkFilterGPU_##LABEL)
|
||||
BENCHMARK(BM_ConvFloatBkInCPU1_##LABEL)->UseRealTime(); \
|
||||
BENCHMARK(BM_ConvFloatBkInCPU4_##LABEL)->UseRealTime(); \
|
||||
BENCHMARK(BM_ConvFloatBkInGPU_##LABEL)->UseRealTime(); \
|
||||
BENCHMARK(BM_ConvFloatBkFilterCPU1_##LABEL)->UseRealTime(); \
|
||||
BENCHMARK(BM_ConvFloatBkFilterCPU4_##LABEL)->UseRealTime(); \
|
||||
BENCHMARK(BM_ConvFloatBkFilterGPU_##LABEL)->UseRealTime(); \
|
||||
BENCHMARK(BM_ConvHalfBkInGPU_##LABEL)->UseRealTime(); \
|
||||
BENCHMARK(BM_ConvHalfBkFilterGPU_##LABEL)->UseRealTime()
|
||||
|
||||
// Benchmarks from the inception model
|
||||
|
||||
@ -453,8 +464,8 @@ BM_ConvFloatBkInAndFilter(32, 147, 147, 24, 64, 1, 1, 1, VALID, conv54);
|
||||
#define BM_ConvFloatBkFCPU(BS, R, C, ID, OD, KR, KC, TH, LABEL) \
|
||||
static void \
|
||||
BM_ConvFloatBkFCPU_##BS##_##R##_##C##_##ID##_##OD##_##KR##_##KC##_##TH( \
|
||||
int iters) { \
|
||||
BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, TH, \
|
||||
::testing::benchmark::State& state) { \
|
||||
BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, TH, \
|
||||
1, VALID, false, DT_FLOAT, LABEL); \
|
||||
} \
|
||||
BENCHMARK( \
|
||||
@ -469,17 +480,19 @@ BM_ConvFloatBkFCPU(128, 13, 13, 384, 384, 3, 3, 4, "convnet-layer5");
|
||||
|
||||
#define BM_ConvFloatBkFGPU(BS, R, C, ID, OD, KR, KC, LABEL) \
|
||||
static void BM_ConvFloatBkFGPU_##BS##_##R##_##C##_##ID##_##OD##_##KR##_##KC( \
|
||||
int iters) { \
|
||||
BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1, \
|
||||
::testing::benchmark::State& state) { \
|
||||
BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1, \
|
||||
1, VALID, true, DT_FLOAT, LABEL); \
|
||||
} \
|
||||
static void BM_ConvHalfBkFGPU_##BS##_##R##_##C##_##ID##_##OD##_##KR##_##KC( \
|
||||
int iters) { \
|
||||
BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1, \
|
||||
::testing::benchmark::State& state) { \
|
||||
BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1, \
|
||||
1, VALID, true, DT_HALF, LABEL); \
|
||||
} \
|
||||
BENCHMARK(BM_ConvFloatBkFGPU_##BS##_##R##_##C##_##ID##_##OD##_##KR##_##KC); \
|
||||
BENCHMARK(BM_ConvHalfBkFGPU_##BS##_##R##_##C##_##ID##_##OD##_##KR##_##KC)
|
||||
BENCHMARK(BM_ConvFloatBkFGPU_##BS##_##R##_##C##_##ID##_##OD##_##KR##_##KC) \
|
||||
->UseRealTime(); \
|
||||
BENCHMARK(BM_ConvHalfBkFGPU_##BS##_##R##_##C##_##ID##_##OD##_##KR##_##KC) \
|
||||
->UseRealTime()
|
||||
|
||||
// Benchmarks from https://github.com/soumith/convnet-benchmarks
|
||||
BM_ConvFloatBkFGPU(128, 128, 128, 3, 96, 11, 11, "convnet-layer1");
|
||||
@ -498,19 +511,19 @@ enum DEPTHWISE_CONV_OP {
|
||||
|
||||
} // namespace
|
||||
|
||||
static void BM_ConvFloatDepthwise(int iters, int batch, int rows, int cols,
|
||||
int in_depth, int depth_multiplier,
|
||||
int out_depth, int filter_rows,
|
||||
int filter_cols, DEPTHWISE_CONV_OP op,
|
||||
int num_threads, int stride, Padding padding,
|
||||
bool use_gpu, const string& label) {
|
||||
testing::StopTiming();
|
||||
static void BM_ConvFloatDepthwise(::testing::benchmark::State& state, int batch,
|
||||
int rows, int cols, int in_depth,
|
||||
int depth_multiplier, int out_depth,
|
||||
int filter_rows, int filter_cols,
|
||||
DEPTHWISE_CONV_OP op, int num_threads,
|
||||
int stride, Padding padding, bool use_gpu,
|
||||
const string& label) {
|
||||
if (!IsGoogleCudaEnabled() && use_gpu) {
|
||||
testing::SetLabel(
|
||||
state.SetLabel(
|
||||
strings::StrCat("Skipping GPU test (no --config=cuda): ", label));
|
||||
return;
|
||||
}
|
||||
testing::SetLabel(label);
|
||||
state.SetLabel(label);
|
||||
|
||||
// Set the number of threads
|
||||
SessionOptions options;
|
||||
@ -603,10 +616,10 @@ static void BM_ConvFloatDepthwise(int iters, int batch, int rows, int cols,
|
||||
TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph, g));
|
||||
|
||||
string device = use_gpu ? "gpu" : "cpu";
|
||||
testing::UseRealTime();
|
||||
testing::StartTiming();
|
||||
test::Benchmark(device, g, &options).Run(iters);
|
||||
testing::ItemsProcessed(num_ops * iters);
|
||||
test::Benchmark(device, g, &options, nullptr, nullptr, "",
|
||||
/*old_benchmark_api=*/false)
|
||||
.Run(state);
|
||||
state.SetItemsProcessed(num_ops * state.iterations());
|
||||
}
|
||||
|
||||
// BS: batch_size
|
||||
@ -622,30 +635,33 @@ static void BM_ConvFloatDepthwise(int iters, int batch, int rows, int cols,
|
||||
|
||||
#define BM_ConvFloatDepthwiseFwd(BS, R, C, ID, DM, OD, KR, KC, STR, PAD, \
|
||||
LABEL) \
|
||||
static void BM_ConvFloatDepthwiseFwdCPU1_##LABEL(int iters) { \
|
||||
static void BM_ConvFloatDepthwiseFwdCPU1_##LABEL( \
|
||||
::testing::benchmark::State& state) { \
|
||||
BM_ConvFloatDepthwise( \
|
||||
iters, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_FWD, 1, STR, \
|
||||
state, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_FWD, 1, STR, \
|
||||
PAD, false, \
|
||||
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", DM, "_", OD, "_", \
|
||||
KR, "_", KC, "_", STR, "_", PAD, "_cpu1")); \
|
||||
} \
|
||||
static void BM_ConvFloatDepthwiseFwdCPU4_##LABEL(int iters) { \
|
||||
static void BM_ConvFloatDepthwiseFwdCPU4_##LABEL( \
|
||||
::testing::benchmark::State& state) { \
|
||||
BM_ConvFloatDepthwise( \
|
||||
iters, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_FWD, 4, STR, \
|
||||
state, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_FWD, 4, STR, \
|
||||
PAD, false, \
|
||||
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", DM, "_", OD, "_", \
|
||||
KR, "_", KC, "_", STR, "_", PAD, "_cpu4")); \
|
||||
} \
|
||||
static void BM_ConvFloatDepthwiseFwdGPU_##LABEL(int iters) { \
|
||||
static void BM_ConvFloatDepthwiseFwdGPU_##LABEL( \
|
||||
::testing::benchmark::State& state) { \
|
||||
BM_ConvFloatDepthwise( \
|
||||
iters, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_FWD, 1, STR, \
|
||||
state, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_FWD, 1, STR, \
|
||||
PAD, true, \
|
||||
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", DM, "_", OD, "_", \
|
||||
KR, "_", KC, "_", STR, "_", PAD, "_gpu")); \
|
||||
} \
|
||||
BENCHMARK(BM_ConvFloatDepthwiseFwdCPU1_##LABEL); \
|
||||
BENCHMARK(BM_ConvFloatDepthwiseFwdCPU4_##LABEL); \
|
||||
BENCHMARK(BM_ConvFloatDepthwiseFwdGPU_##LABEL);
|
||||
BENCHMARK(BM_ConvFloatDepthwiseFwdCPU1_##LABEL)->UseRealTime(); \
|
||||
BENCHMARK(BM_ConvFloatDepthwiseFwdCPU4_##LABEL)->UseRealTime(); \
|
||||
BENCHMARK(BM_ConvFloatDepthwiseFwdGPU_##LABEL)->UseRealTime();
|
||||
|
||||
// The configurations below are mostly from mobilenet models.
|
||||
BM_ConvFloatDepthwiseFwd(32, 112, 112, 3, 8, 24, 3, 3, 1, SAME, conv0);
|
||||
@ -662,53 +678,59 @@ BM_ConvFloatDepthwiseFwd(1, 100, 100, 72, 1, 72, 3, 3, 1, SAME, conv9);
|
||||
BM_ConvFloatDepthwiseFwd(1, 100, 100, 72, 1, 72, 5, 5, 1, SAME, conv10);
|
||||
|
||||
#define BM_ConvFloatDepthwiseBk(BS, R, C, ID, DM, OD, KR, KC, STR, PAD, LABEL) \
|
||||
static void BM_ConvFloatDepthwiseBkInCPU1_##LABEL(int iters) { \
|
||||
static void BM_ConvFloatDepthwiseBkInCPU1_##LABEL( \
|
||||
::testing::benchmark::State& state) { \
|
||||
BM_ConvFloatDepthwise( \
|
||||
iters, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_BACKPROP_INPUT, \
|
||||
state, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_BACKPROP_INPUT, \
|
||||
1, STR, PAD, false, \
|
||||
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", DM, "_", OD, "_", \
|
||||
KR, "_", KC, "_", STR, "_", PAD, "_cpu1")); \
|
||||
} \
|
||||
static void BM_ConvFloatDepthwiseBkInCPU4_##LABEL(int iters) { \
|
||||
static void BM_ConvFloatDepthwiseBkInCPU4_##LABEL( \
|
||||
::testing::benchmark::State& state) { \
|
||||
BM_ConvFloatDepthwise( \
|
||||
iters, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_BACKPROP_INPUT, \
|
||||
state, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_BACKPROP_INPUT, \
|
||||
4, STR, PAD, false, \
|
||||
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", DM, "_", OD, "_", \
|
||||
KR, "_", KC, "_", STR, "_", PAD, "_cpu4")); \
|
||||
} \
|
||||
static void BM_ConvFloatDepthwiseBkInGPU_##LABEL(int iters) { \
|
||||
static void BM_ConvFloatDepthwiseBkInGPU_##LABEL( \
|
||||
::testing::benchmark::State& state) { \
|
||||
BM_ConvFloatDepthwise( \
|
||||
iters, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_BACKPROP_INPUT, \
|
||||
state, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_BACKPROP_INPUT, \
|
||||
4, STR, PAD, true, \
|
||||
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", DM, "_", OD, "_", \
|
||||
KR, "_", KC, "_", STR, "_", PAD, "_gpu")); \
|
||||
} \
|
||||
static void BM_ConvFloatDepthwiseBkFilterCPU1_##LABEL(int iters) { \
|
||||
static void BM_ConvFloatDepthwiseBkFilterCPU1_##LABEL( \
|
||||
::testing::benchmark::State& state) { \
|
||||
BM_ConvFloatDepthwise( \
|
||||
iters, BS, R, C, ID, DM, OD, KR, KC, \
|
||||
state, BS, R, C, ID, DM, OD, KR, KC, \
|
||||
DEPTHWISE_CONV_OP_BACKPROP_FILTER, 1, STR, PAD, false, \
|
||||
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", DM, "_", OD, "_", \
|
||||
KR, "_", KC, "_", STR, "_", PAD, "_cpu1")); \
|
||||
} \
|
||||
static void BM_ConvFloatDepthwiseBkFilterCPU4_##LABEL(int iters) { \
|
||||
static void BM_ConvFloatDepthwiseBkFilterCPU4_##LABEL( \
|
||||
::testing::benchmark::State& state) { \
|
||||
BM_ConvFloatDepthwise( \
|
||||
iters, BS, R, C, ID, DM, OD, KR, KC, \
|
||||
state, BS, R, C, ID, DM, OD, KR, KC, \
|
||||
DEPTHWISE_CONV_OP_BACKPROP_FILTER, 4, STR, PAD, false, \
|
||||
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", DM, "_", OD, "_", \
|
||||
KR, "_", KC, "_", STR, "_", PAD, "_cpu4")); \
|
||||
} \
|
||||
static void BM_ConvFloatDepthwiseBkFilterGPU_##LABEL(int iters) { \
|
||||
static void BM_ConvFloatDepthwiseBkFilterGPU_##LABEL( \
|
||||
::testing::benchmark::State& state) { \
|
||||
BM_ConvFloatDepthwise( \
|
||||
iters, BS, R, C, ID, DM, OD, KR, KC, \
|
||||
state, BS, R, C, ID, DM, OD, KR, KC, \
|
||||
DEPTHWISE_CONV_OP_BACKPROP_FILTER, 4, STR, PAD, true, \
|
||||
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", DM, "_", OD, "_", \
|
||||
KR, "_", KC, "_", STR, "_", PAD, "_gpu")); \
|
||||
} \
|
||||
BENCHMARK(BM_ConvFloatDepthwiseBkInCPU1_##LABEL); \
|
||||
BENCHMARK(BM_ConvFloatDepthwiseBkInCPU4_##LABEL); \
|
||||
BENCHMARK(BM_ConvFloatDepthwiseBkInGPU_##LABEL); \
|
||||
BENCHMARK(BM_ConvFloatDepthwiseBkFilterCPU1_##LABEL); \
|
||||
BENCHMARK(BM_ConvFloatDepthwiseBkFilterCPU4_##LABEL); \
|
||||
BENCHMARK(BM_ConvFloatDepthwiseBkInCPU1_##LABEL)->UseRealTime(); \
|
||||
BENCHMARK(BM_ConvFloatDepthwiseBkInCPU4_##LABEL)->UseRealTime(); \
|
||||
BENCHMARK(BM_ConvFloatDepthwiseBkInGPU_##LABEL)->UseRealTime(); \
|
||||
BENCHMARK(BM_ConvFloatDepthwiseBkFilterCPU1_##LABEL)->UseRealTime(); \
|
||||
BENCHMARK(BM_ConvFloatDepthwiseBkFilterCPU4_##LABEL)->UseRealTime(); \
|
||||
BENCHMARK(BM_ConvFloatDepthwiseBkFilterGPU_##LABEL)
|
||||
|
||||
// The configurations below are mostly from mobilenet models.
|
||||
@ -732,10 +754,9 @@ BM_ConvFloatDepthwiseBk(32, 112, 112, 8, 3, 24, 3, 3, 1, SAME, conv12);
|
||||
BM_ConvFloatDepthwiseBk(32, 112, 112, 12, 2, 24, 3, 3, 1, SAME, conv13);
|
||||
BM_ConvFloatDepthwiseBk(32, 112, 112, 24, 1, 24, 3, 3, 1, SAME, conv14);
|
||||
|
||||
static void BM_LRNFloat(int iters, int depth, int cols, int rows,
|
||||
int batch_size, int range, int num_threads,
|
||||
static void BM_LRNFloat(::testing::benchmark::State& state, int depth, int cols,
|
||||
int rows, int batch_size, int range, int num_threads,
|
||||
const string& label) {
|
||||
tensorflow::testing::StopTiming();
|
||||
std::unique_ptr<Device> device(
|
||||
DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
|
||||
|
||||
@ -778,26 +799,24 @@ static void BM_LRNFloat(int iters, int depth, int cols, int rows,
|
||||
std::unique_ptr<OpKernelContext> context(new OpKernelContext(¶ms));
|
||||
|
||||
op->Compute(context.get());
|
||||
testing::UseRealTime();
|
||||
tensorflow::testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
for (auto s : state) {
|
||||
delete context->release_output(0).tensor;
|
||||
op->Compute(context.get());
|
||||
}
|
||||
tensorflow::testing::StopTiming();
|
||||
testing::ItemsProcessed(context->mutable_output(0)->NumElements() * iters *
|
||||
(2 * range + 1) * 2);
|
||||
testing::SetLabel(label);
|
||||
state.SetItemsProcessed(context->mutable_output(0)->NumElements() *
|
||||
state.iterations() * (2 * range + 1) * 2);
|
||||
state.SetLabel(label);
|
||||
}
|
||||
|
||||
#define BM_LRNFloatFwdCPU(DEPTH, COLS, ROWS, BATCH, RANGE, THREADS, LABEL) \
|
||||
static void \
|
||||
BM_LRNFloat_##DEPTH##_##COLS##_##ROWS##_##BATCH##_##RANGE##_##THREADS( \
|
||||
int iters) { \
|
||||
BM_LRNFloat(iters, DEPTH, COLS, ROWS, BATCH, RANGE, THREADS, LABEL); \
|
||||
::testing::benchmark::State& state) { \
|
||||
BM_LRNFloat(state, DEPTH, COLS, ROWS, BATCH, RANGE, THREADS, LABEL); \
|
||||
} \
|
||||
BENCHMARK( \
|
||||
BM_LRNFloat_##DEPTH##_##COLS##_##ROWS##_##BATCH##_##RANGE##_##THREADS)
|
||||
BM_LRNFloat_##DEPTH##_##COLS##_##ROWS##_##BATCH##_##RANGE##_##THREADS) \
|
||||
->UseRealTime()
|
||||
|
||||
// clang-format off
|
||||
// DEPTH, COLS, ROWS, BATCH, RANGE, THREADS, LABEL
|
||||
@ -815,10 +834,10 @@ BM_LRNFloatFwdCPU(192, 56, 56, 32, 5, 8, "lrn 8 threads");
|
||||
/*
|
||||
AvgPooling Op
|
||||
*/
|
||||
static void BM_AvgPool(int iters, int batch_size, int rows, int cols, int depth,
|
||||
int kernel_rows, int kernel_cols, int stride,
|
||||
Padding padding, int num_threads, const string& label) {
|
||||
tensorflow::testing::StopTiming();
|
||||
static void BM_AvgPool(::testing::benchmark::State& state, int batch_size,
|
||||
int rows, int cols, int depth, int kernel_rows,
|
||||
int kernel_cols, int stride, Padding padding,
|
||||
int num_threads, const string& label) {
|
||||
std::unique_ptr<Device> device(
|
||||
DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
|
||||
|
||||
@ -860,16 +879,13 @@ static void BM_AvgPool(int iters, int batch_size, int rows, int cols, int depth,
|
||||
new OpKernelContext(¶ms));
|
||||
|
||||
op->Compute(avgpool_context.get());
|
||||
testing::UseRealTime();
|
||||
tensorflow::testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
for (auto s : state) {
|
||||
delete avgpool_context->release_output(0).tensor;
|
||||
op->Compute(avgpool_context.get());
|
||||
}
|
||||
tensorflow::testing::StopTiming();
|
||||
testing::ItemsProcessed(avgpool_context->mutable_output(0)->NumElements() *
|
||||
iters);
|
||||
testing::SetLabel(label);
|
||||
state.SetItemsProcessed(avgpool_context->mutable_output(0)->NumElements() *
|
||||
state.iterations());
|
||||
state.SetLabel(label);
|
||||
}
|
||||
|
||||
// BS: batch_size
|
||||
@ -883,11 +899,12 @@ static void BM_AvgPool(int iters, int batch_size, int rows, int cols, int depth,
|
||||
#define BM_AvgPoolFwdCPU(BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL) \
|
||||
static void \
|
||||
BM_AvgPool_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH( \
|
||||
int iters) { \
|
||||
BM_AvgPool(iters, BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL); \
|
||||
::testing::benchmark::State& state) { \
|
||||
BM_AvgPool(state, BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL); \
|
||||
} \
|
||||
BENCHMARK( \
|
||||
BM_AvgPool_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH)
|
||||
BM_AvgPool_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH) \
|
||||
->UseRealTime()
|
||||
|
||||
// Labels are taken from the 2014-July-24 version of imagenet
|
||||
BM_AvgPoolFwdCPU(32, 112, 112, 64, 3, 3, 2, VALID, 1, "avgpool0_VALID");
|
||||
@ -907,11 +924,10 @@ BM_AvgPoolFwdCPU(32, 56, 56, 192, 3, 3, 2, SAME, 4, "avgpool1_SAME");
|
||||
BM_AvgPoolFwdCPU(32, 28, 28, 352, 3, 3, 2, SAME, 4, "avgpool4_SAME");
|
||||
BM_AvgPoolFwdCPU(32, 14, 14, 576, 3, 3, 2, SAME, 4, "avgpool10_SAME");
|
||||
|
||||
static void BM_AvgPoolBk(int iters, int batch_size, int rows, int cols,
|
||||
int depth, int kernel_rows, int kernel_cols,
|
||||
int stride, Padding padding, int num_threads,
|
||||
const string& label) {
|
||||
tensorflow::testing::StopTiming();
|
||||
static void BM_AvgPoolBk(::testing::benchmark::State& state, int batch_size,
|
||||
int rows, int cols, int depth, int kernel_rows,
|
||||
int kernel_cols, int stride, Padding padding,
|
||||
int num_threads, const string& label) {
|
||||
std::unique_ptr<Device> device(
|
||||
DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
|
||||
|
||||
@ -966,16 +982,13 @@ static void BM_AvgPoolBk(int iters, int batch_size, int rows, int cols,
|
||||
new OpKernelContext(¶ms));
|
||||
|
||||
op->Compute(avgpool_context.get());
|
||||
testing::UseRealTime();
|
||||
tensorflow::testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
for (auto s : state) {
|
||||
delete avgpool_context->release_output(0).tensor;
|
||||
op->Compute(avgpool_context.get());
|
||||
}
|
||||
tensorflow::testing::StopTiming();
|
||||
testing::ItemsProcessed(avgpool_context->mutable_output(0)->NumElements() *
|
||||
iters);
|
||||
testing::SetLabel(label);
|
||||
state.SetItemsProcessed(avgpool_context->mutable_output(0)->NumElements() *
|
||||
state.iterations());
|
||||
state.SetLabel(label);
|
||||
}
|
||||
|
||||
// BS: batch_size
|
||||
@ -987,14 +1000,17 @@ static void BM_AvgPoolBk(int iters, int batch_size, int rows, int cols,
|
||||
// ST: stride. We use the same stride for both directions.
|
||||
// PT: padding
|
||||
// The resulted symbol is too long. Need to use two macros to fit in 80-chars
|
||||
// NOLINTBEGIN
|
||||
#define BM_AvgPoolBkCPU(BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL) \
|
||||
static void \
|
||||
BM_AvgPoolBk_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH( \
|
||||
int iters) { \
|
||||
BM_AvgPoolBk(iters, BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL); \
|
||||
::testing::benchmark::State& state) { \
|
||||
BM_AvgPoolBk(state, BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL); \
|
||||
} \
|
||||
BENCHMARK( \
|
||||
BM_AvgPoolBk_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH)
|
||||
BM_AvgPoolBk_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH) \
|
||||
->UseRealTime()
|
||||
// NOLINTEND
|
||||
|
||||
// Shapes taken from the 2015/05/16 inception model
|
||||
BM_AvgPoolBkCPU(32, 35, 35, 192, 3, 3, 1, SAME, 1, "avgpool_grad0_SAME");
|
||||
@ -1010,10 +1026,10 @@ BM_AvgPoolBkCPU(32, 8, 8, 2048, 8, 8, 1, VALID, 1, "avgpool_grad8_VALID");
|
||||
/*
|
||||
MaxPooling Op
|
||||
*/
|
||||
static void BM_MaxPool(int iters, int batch_size, int rows, int cols, int depth,
|
||||
int kernel_rows, int kernel_cols, int stride,
|
||||
Padding padding, int num_threads, const string& label) {
|
||||
tensorflow::testing::StopTiming();
|
||||
static void BM_MaxPool(::testing::benchmark::State& state, int batch_size,
|
||||
int rows, int cols, int depth, int kernel_rows,
|
||||
int kernel_cols, int stride, Padding padding,
|
||||
int num_threads, const string& label) {
|
||||
SessionOptions options;
|
||||
options.config.set_intra_op_parallelism_threads(num_threads);
|
||||
|
||||
@ -1057,16 +1073,13 @@ static void BM_MaxPool(int iters, int batch_size, int rows, int cols, int depth,
|
||||
new OpKernelContext(¶ms));
|
||||
|
||||
op->Compute(maxpool_context.get());
|
||||
testing::UseRealTime();
|
||||
tensorflow::testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
for (auto s : state) {
|
||||
delete maxpool_context->release_output(0).tensor;
|
||||
op->Compute(maxpool_context.get());
|
||||
}
|
||||
tensorflow::testing::StopTiming();
|
||||
testing::ItemsProcessed(maxpool_context->mutable_output(0)->NumElements() *
|
||||
iters);
|
||||
testing::SetLabel(label);
|
||||
state.SetItemsProcessed(maxpool_context->mutable_output(0)->NumElements() *
|
||||
state.iterations());
|
||||
state.SetLabel(label);
|
||||
}
|
||||
|
||||
// BS: batch_size
|
||||
@ -1080,11 +1093,12 @@ static void BM_MaxPool(int iters, int batch_size, int rows, int cols, int depth,
|
||||
#define BM_MaxPoolFwdCPU(BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL) \
|
||||
static void \
|
||||
BM_MaxPool_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH( \
|
||||
int iters) { \
|
||||
BM_MaxPool(iters, BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL); \
|
||||
::testing::benchmark::State& state) { \
|
||||
BM_MaxPool(state, BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL); \
|
||||
} \
|
||||
BENCHMARK( \
|
||||
BM_MaxPool_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH)
|
||||
BM_MaxPool_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH) \
|
||||
->UseRealTime()
|
||||
|
||||
// Labels are taken from the 2014-July-24 version of imagenet
|
||||
/* TODO XXX
|
||||
@ -1106,10 +1120,10 @@ BM_MaxPoolFwdCPU(32, 56, 56, 192, 3, 3, 2, SAME, 4, "maxpool1_SAME");
|
||||
BM_MaxPoolFwdCPU(32, 28, 28, 352, 3, 3, 2, SAME, 4, "maxpool4_SAME");
|
||||
BM_MaxPoolFwdCPU(32, 14, 14, 576, 3, 3, 2, SAME, 4, "maxpool10_SAME");
|
||||
|
||||
static void BM_MaxPoolBk(int iters, int batch_size, int rows, int cols,
|
||||
int depth, int kernel_rows, int kernel_cols,
|
||||
int stride, Padding padding, int num_threads,
|
||||
bool use_gpu, const string& label) {
|
||||
static void BM_MaxPoolBk(::testing::benchmark::State& state, int batch_size,
|
||||
int rows, int cols, int depth, int kernel_rows,
|
||||
int kernel_cols, int stride, Padding padding,
|
||||
int num_threads, bool use_gpu, const string& label) {
|
||||
auto root = Scope::NewRootScope().ExitOnError();
|
||||
|
||||
int64 out_height, out_width, pad_rows, pad_cols;
|
||||
@ -1138,11 +1152,11 @@ static void BM_MaxPoolBk(int iters, int batch_size, int rows, int cols,
|
||||
Graph* g = new Graph(OpRegistry::Global());
|
||||
TF_CHECK_OK(root.ToGraph(g));
|
||||
string device = use_gpu ? "gpu" : "cpu";
|
||||
testing::UseRealTime();
|
||||
test::Benchmark(device, g).Run(iters);
|
||||
test::Benchmark(device, g, /*old_benchmark_api*/ false).Run(state);
|
||||
|
||||
testing::ItemsProcessed(batch_size * rows * cols * depth * iters);
|
||||
testing::SetLabel(label);
|
||||
state.SetItemsProcessed(batch_size * rows * cols * depth *
|
||||
state.iterations());
|
||||
state.SetLabel(label);
|
||||
}
|
||||
|
||||
// BS: batch_size
|
||||
@ -1159,23 +1173,23 @@ static void BM_MaxPoolBk(int iters, int batch_size, int rows, int cols,
|
||||
static void \
|
||||
BM_MaxPoolBk_GPU_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_ \
|
||||
##PT##_##TH( \
|
||||
int iters) { \
|
||||
BM_MaxPoolBk(iters, BS, IR, IC, ND, KR, KC, ST, PT, TH, true, LABEL); \
|
||||
::testing::benchmark::State& state) { \
|
||||
BM_MaxPoolBk(state, BS, IR, IC, ND, KR, KC, ST, PT, TH, true, LABEL); \
|
||||
} \
|
||||
BENCHMARK( \
|
||||
BM_MaxPoolBk_GPU_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_ \
|
||||
##PT##_##TH) \
|
||||
##PT##_##TH)->UseRealTime()
|
||||
|
||||
#define BM_MaxPoolBkCPU(BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL) \
|
||||
static void \
|
||||
BM_MaxPoolBk_CPU_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_ \
|
||||
##PT##_##TH( \
|
||||
int iters) { \
|
||||
BM_MaxPoolBk(iters, BS, IR, IC, ND, KR, KC, ST, PT, TH, false, LABEL); \
|
||||
::testing::benchmark::State& state) { \
|
||||
BM_MaxPoolBk(state, BS, IR, IC, ND, KR, KC, ST, PT, TH, false, LABEL); \
|
||||
} \
|
||||
BENCHMARK( \
|
||||
BM_MaxPoolBk_CPU_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_ \
|
||||
##PT##_##TH)
|
||||
##PT##_##TH)->UseRealTime()
|
||||
// clang-format on
|
||||
|
||||
// Shapes taken from the 2015/05/16 inception model
|
||||
@ -1195,9 +1209,9 @@ BM_MaxPoolBkCPU(32, 8, 8, 2048, 3, 3, 2, VALID, 1, "maxpool_grad4_VALID");
|
||||
Relu Op
|
||||
Run benchmark with:
|
||||
*/
|
||||
static void BM_ReluFloat(int iters, int batch_size, int rows, int cols,
|
||||
int depth, int num_threads, const string& label) {
|
||||
tensorflow::testing::StopTiming();
|
||||
static void BM_ReluFloat(::testing::benchmark::State& state, int batch_size,
|
||||
int rows, int cols, int depth, int num_threads,
|
||||
const string& label) {
|
||||
std::unique_ptr<Device> device(
|
||||
DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
|
||||
|
||||
@ -1233,16 +1247,13 @@ static void BM_ReluFloat(int iters, int batch_size, int rows, int cols,
|
||||
std::unique_ptr<OpKernelContext> relu_context(new OpKernelContext(¶ms));
|
||||
|
||||
op->Compute(relu_context.get());
|
||||
testing::UseRealTime();
|
||||
tensorflow::testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
for (auto s : state) {
|
||||
delete relu_context->release_output(0).tensor;
|
||||
op->Compute(relu_context.get());
|
||||
}
|
||||
tensorflow::testing::StopTiming();
|
||||
testing::ItemsProcessed(relu_context->mutable_output(0)->NumElements() *
|
||||
iters);
|
||||
testing::SetLabel(label);
|
||||
state.SetItemsProcessed(relu_context->mutable_output(0)->NumElements() *
|
||||
state.iterations());
|
||||
state.SetLabel(label);
|
||||
}
|
||||
|
||||
// BS: batch_size
|
||||
@ -1250,10 +1261,11 @@ static void BM_ReluFloat(int iters, int batch_size, int rows, int cols,
|
||||
// IC: input_cols
|
||||
// ND: node_depth
|
||||
#define BM_Relu(BS, IR, IC, ND, TH, LABEL) \
|
||||
static void BM_ReluFloat_##BS##_##IR##_##IC##_##ND##_##TH(int iters) { \
|
||||
BM_ReluFloat(iters, BS, IR, IC, ND, TH, LABEL); \
|
||||
static void BM_ReluFloat_##BS##_##IR##_##IC##_##ND##_##TH( \
|
||||
::testing::benchmark::State& state) { \
|
||||
BM_ReluFloat(state, BS, IR, IC, ND, TH, LABEL); \
|
||||
} \
|
||||
BENCHMARK(BM_ReluFloat_##BS##_##IR##_##IC##_##ND##_##TH)
|
||||
BENCHMARK(BM_ReluFloat_##BS##_##IR##_##IC##_##ND##_##TH)->UseRealTime()
|
||||
|
||||
BM_Relu(32, 112, 112, 64, 1, "relu0");
|
||||
BM_Relu(32, 56, 56, 192, 1, "relu1");
|
||||
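Note: the Relu hunk above drops the explicit StopTiming/StartTiming calls and the `for (int i = 0; i < iters; ++i)` loop; iterating the state scopes the timer to exactly the loop body. A condensed sketch of that loop shape, with placeholder helpers standing in for the OpKernelContext setup and `op->Compute(...)` call shown above:

static void BM_ComputeLoop(::testing::benchmark::State& state) {
  SetUpKernelAndContext();  // placeholder for the untimed setup above
  for (auto s : state) {    // the timed region is exactly this loop body
    RunKernelOnce();        // placeholder for op->Compute(relu_context.get())
  }
  state.SetItemsProcessed(state.iterations());
  state.SetLabel("example");
}
BENCHMARK(BM_ComputeLoop);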
@ -1268,9 +1280,9 @@ BM_Relu(32, 14, 14, 576, 4, "relu10");
|
||||
Softplus Op
|
||||
Run benchmark with:
|
||||
*/
|
||||
static void BM_SoftplusFloat(int iters, int batch_size, int rows, int cols,
|
||||
int depth, int num_threads, const string& label) {
|
||||
tensorflow::testing::StopTiming();
|
||||
static void BM_SoftplusFloat(::testing::benchmark::State& state, int batch_size,
|
||||
int rows, int cols, int depth, int num_threads,
|
||||
const string& label) {
|
||||
std::unique_ptr<Device> device(
|
||||
DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
|
||||
|
||||
@ -1307,16 +1319,13 @@ static void BM_SoftplusFloat(int iters, int batch_size, int rows, int cols,
|
||||
new OpKernelContext(¶ms));
|
||||
|
||||
op->Compute(softplus_context.get());
|
||||
testing::UseRealTime();
|
||||
tensorflow::testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
for (auto s : state) {
|
||||
delete softplus_context->release_output(0).tensor;
|
||||
op->Compute(softplus_context.get());
|
||||
}
|
||||
tensorflow::testing::StopTiming();
|
||||
testing::ItemsProcessed(softplus_context->mutable_output(0)->NumElements() *
|
||||
iters);
|
||||
testing::SetLabel(label);
|
||||
state.SetItemsProcessed(softplus_context->mutable_output(0)->NumElements() *
|
||||
state.iterations());
|
||||
state.SetLabel(label);
|
||||
}
|
||||
|
||||
// BS: batch_size
|
||||
@ -1324,10 +1333,11 @@ static void BM_SoftplusFloat(int iters, int batch_size, int rows, int cols,
|
||||
// IC: input_cols
|
||||
// ND: node_depth
|
||||
#define BM_Softplus(BS, IR, IC, ND, TH, LABEL) \
|
||||
static void BM_SoftplusFloat_##BS##_##IR##_##IC##_##ND##_##TH(int iters) { \
|
||||
BM_SoftplusFloat(iters, BS, IR, IC, ND, TH, LABEL); \
|
||||
static void BM_SoftplusFloat_##BS##_##IR##_##IC##_##ND##_##TH( \
|
||||
::testing::benchmark::State& state) { \
|
||||
BM_SoftplusFloat(state, BS, IR, IC, ND, TH, LABEL); \
|
||||
} \
|
||||
BENCHMARK(BM_SoftplusFloat_##BS##_##IR##_##IC##_##ND##_##TH)
|
||||
BENCHMARK(BM_SoftplusFloat_##BS##_##IR##_##IC##_##ND##_##TH)->UseRealTime()
|
||||
|
||||
BM_Softplus(32, 112, 112, 64, 1, "softplus0");
|
||||
BM_Softplus(32, 56, 56, 192, 1, "softplus1");
|
||||
@ -1338,7 +1348,8 @@ BM_Softplus(32, 56, 56, 192, 4, "softplus1");
|
||||
BM_Softplus(32, 28, 28, 352, 4, "softplus4");
|
||||
BM_Softplus(32, 14, 14, 576, 4, "softplus10");
|
||||
|
||||
static void BM_ImageNetSoftmaxFwd(int iters, int batch_size, int node_depth,
|
||||
static void BM_ImageNetSoftmaxFwd(::testing::benchmark::State& state,
|
||||
int batch_size, int node_depth,
|
||||
int num_threads, bool use_gpu,
|
||||
const string& label) {
|
||||
auto root = Scope::NewRootScope().ExitOnError();
|
||||
@ -1359,19 +1370,21 @@ static void BM_ImageNetSoftmaxFwd(int iters, int batch_size, int node_depth,
|
||||
opts.config.mutable_graph_options()
|
||||
->mutable_optimizer_options()
|
||||
->set_opt_level(OptimizerOptions::L0);
|
||||
testing::UseRealTime();
|
||||
test::Benchmark(device, g, &opts).Run(iters);
|
||||
testing::ItemsProcessed(batch_size * node_depth * iters);
|
||||
testing::SetLabel(label);
|
||||
test::Benchmark(device, g, &opts, nullptr, nullptr, "",
|
||||
/*old_benchmark_api*/ false)
|
||||
.Run(state);
|
||||
state.SetItemsProcessed(batch_size * node_depth * state.iterations());
|
||||
state.SetLabel(label);
|
||||
}
|
||||
|
||||
#define BM_ImageNetSoftmaxFwd(BATCH_SIZE, NODE_DEPTH, TH, GPU, LABEL) \
|
||||
static void \
|
||||
BM_ImageNetSoftmaxFwd_##BATCH_SIZE##_##NODE_DEPTH##_##TH##_##GPU( \
|
||||
int iters) { \
|
||||
BM_ImageNetSoftmaxFwd(iters, BATCH_SIZE, NODE_DEPTH, TH, GPU, LABEL); \
|
||||
::testing::benchmark::State& state) { \
|
||||
BM_ImageNetSoftmaxFwd(state, BATCH_SIZE, NODE_DEPTH, TH, GPU, LABEL); \
|
||||
} \
|
||||
BENCHMARK(BM_ImageNetSoftmaxFwd_##BATCH_SIZE##_##NODE_DEPTH##_##TH##_##GPU)
|
||||
BENCHMARK(BM_ImageNetSoftmaxFwd_##BATCH_SIZE##_##NODE_DEPTH##_##TH##_##GPU) \
|
||||
->UseRealTime()
|
||||
|
||||
// Labels are taken from the 2014-July-24 version of imagenet
|
||||
BM_ImageNetSoftmaxFwd(32, 1008, 1, false, "softmax32");
|
||||
@ -1383,9 +1396,8 @@ BM_ImageNetSoftmaxFwd(128, 1008, 1, true, "softmax128");
|
||||
BM_ImageNetSoftmaxFwd(8192, 1024, 1, true, "softmax32");
|
||||
BM_ImageNetSoftmaxFwd(8192, 32768, 1, true, "softmax128");
|
||||
|
||||
static void BM_TopK(int iters, int rows, int cols, int k, int num_threads,
|
||||
bool use_gpu, const string& label) {
|
||||
testing::StopTiming();
|
||||
static void BM_TopK(::testing::benchmark::State& state, int rows, int cols,
|
||||
int k, int num_threads, bool use_gpu, const string& label) {
|
||||
auto root = Scope::NewRootScope().ExitOnError();
|
||||
|
||||
Tensor input(DT_FLOAT, TensorShape({rows, cols}));
|
||||
@ -1407,11 +1419,11 @@ static void BM_TopK(int iters, int rows, int cols, int k, int num_threads,
|
||||
opts.config.mutable_graph_options()
|
||||
->mutable_optimizer_options()
|
||||
->set_opt_level(OptimizerOptions::L0);
|
||||
testing::UseRealTime();
|
||||
testing::StartTiming();
|
||||
test::Benchmark(device, g, &opts).Run(iters);
|
||||
testing::ItemsProcessed(rows * cols * iters);
|
||||
testing::SetLabel(label);
|
||||
test::Benchmark(device, g, &opts, nullptr, nullptr, "",
|
||||
/*old_benchmark_api=*/false)
|
||||
.Run(state);
|
||||
state.SetItemsProcessed(rows * cols * state.iterations());
|
||||
state.SetLabel(label);
|
||||
}
|
||||
|
||||
// IR: input_rows
|
||||
@ -1419,16 +1431,18 @@ static void BM_TopK(int iters, int rows, int cols, int k, int num_threads,
|
||||
// IK: k
|
||||
// TH: number of threads
|
||||
#define BM_TopKGPU(IR, IC, IK, TH, LABEL) \
|
||||
static void BM_TopK_GPU_##IR##_##IC##_##IK##_##TH(int iters) { \
|
||||
BM_TopK(iters, IR, IC, IK, TH, true, LABEL); \
|
||||
static void BM_TopK_GPU_##IR##_##IC##_##IK##_##TH( \
|
||||
::testing::benchmark::State& state) { \
|
||||
BM_TopK(state, IR, IC, IK, TH, true, LABEL); \
|
||||
} \
|
||||
BENCHMARK(BM_TopK_GPU_##IR##_##IC##_##IK##_##TH)
|
||||
BENCHMARK(BM_TopK_GPU_##IR##_##IC##_##IK##_##TH)->UseRealTime()
|
||||
|
||||
#define BM_TopKCPU(IR, IC, IK, TH, LABEL) \
|
||||
static void BM_TopK_CPU_##IR##_##IC##_##IK##_##TH(int iters) { \
|
||||
BM_TopK(iters, IR, IC, IK, TH, false, LABEL); \
|
||||
static void BM_TopK_CPU_##IR##_##IC##_##IK##_##TH( \
|
||||
::testing::benchmark::State& state) { \
|
||||
BM_TopK(state, IR, IC, IK, TH, false, LABEL); \
|
||||
} \
|
||||
BENCHMARK(BM_TopK_CPU_##IR##_##IC##_##IK##_##TH)
|
||||
BENCHMARK(BM_TopK_CPU_##IR##_##IC##_##IK##_##TH)->UseRealTime()
|
||||
|
||||
// clang-format on
|
||||
|
||||
|
@@ -56,9 +56,13 @@ static Graph* OneHot(int batch_size, int num_classes, int axis) {
}

#define BM_OneHot(BATCH, CLASS, AXIS, DEVICE) \
static void BM_OneHot##_##BATCH##_##CLASS##_##AXIS##_##DEVICE(int iters) { \
testing::ItemsProcessed(static_cast<int64>(iters) * BATCH * CLASS); \
test::Benchmark(#DEVICE, OneHot(BATCH, CLASS, AXIS)).Run(iters); \
static void BM_OneHot##_##BATCH##_##CLASS##_##AXIS##_##DEVICE( \
::testing::benchmark::State& state) { \
test::Benchmark(#DEVICE, OneHot(BATCH, CLASS, AXIS), \
/*old_benchmark_api*/ false) \
.Run(state); \
state.SetItemsProcessed(static_cast<int64>(state.iterations()) * BATCH * \
CLASS); \
} \
BENCHMARK(BM_OneHot##_##BATCH##_##CLASS##_##AXIS##_##DEVICE);

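Note: in the old API the item count could be reported before `.Run(iters)` because `iters` was a parameter; with the state-based API the count depends on `state.iterations()`, so it is set after `.Run(state)` returns. A minimal sketch of that ordering, reusing the `OneHot(...)` graph builder from this file with illustrative argument values:

static void BM_GraphRun(::testing::benchmark::State& state) {
  test::Benchmark("cpu", OneHot(32, 512, 1), /*old_benchmark_api=*/false)
      .Run(state);
  state.SetItemsProcessed(static_cast<int64>(state.iterations()) * 32 * 512);
}
BENCHMARK(BM_GraphRun);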
@ -108,23 +108,32 @@ static Graph* PTruncatedNormalOneTail(int num_batches, int samples_per_batch) {
|
||||
}
|
||||
|
||||
#define BM_PTruncatedNormalDev(DEVICE, B, S) \
|
||||
static void BM_PTruncatedNormal_##DEVICE##_##B##_##S(int iters) { \
|
||||
test::Benchmark(#DEVICE, PTruncatedNormal(B, S)).Run(iters); \
|
||||
testing::ItemsProcessed(static_cast<int64>(B) * S * iters); \
|
||||
static void BM_PTruncatedNormal_##DEVICE##_##B##_##S( \
|
||||
::testing::benchmark::State& state) { \
|
||||
test::Benchmark(#DEVICE, PTruncatedNormal(B, S), \
|
||||
/*old_benchmark_api*/ false) \
|
||||
.Run(state); \
|
||||
state.SetItemsProcessed(static_cast<int64>(B) * S * state.iterations()); \
|
||||
} \
|
||||
BENCHMARK(BM_PTruncatedNormal_##DEVICE##_##B##_##S);
|
||||
|
||||
#define BM_PTruncatedNormalDev_2SD(DEVICE, B, S) \
|
||||
static void BM_PTruncatedNormal_2SD_##DEVICE##_##B##_##S(int iters) { \
|
||||
test::Benchmark(#DEVICE, PTruncatedNormal2SD(B, S)).Run(iters); \
|
||||
testing::ItemsProcessed(static_cast<int64>(B) * S * iters); \
|
||||
static void BM_PTruncatedNormal_2SD_##DEVICE##_##B##_##S( \
|
||||
::testing::benchmark::State& state) { \
|
||||
test::Benchmark(#DEVICE, PTruncatedNormal2SD(B, S), \
|
||||
/*old_benchmark_api*/ false) \
|
||||
.Run(state); \
|
||||
state.SetItemsProcessed(static_cast<int64>(B) * S * state.iterations()); \
|
||||
} \
|
||||
BENCHMARK(BM_PTruncatedNormal_2SD_##DEVICE##_##B##_##S);
|
||||
|
||||
#define BM_PTruncatedNormalDev_OneTail(DEVICE, B, S) \
|
||||
static void BM_PTruncatedNormal_OneTail_##DEVICE##_##B##_##S(int iters) { \
|
||||
test::Benchmark(#DEVICE, PTruncatedNormalOneTail(B, S)).Run(iters); \
|
||||
testing::ItemsProcessed(static_cast<int64>(B) * S * iters); \
|
||||
static void BM_PTruncatedNormal_OneTail_##DEVICE##_##B##_##S( \
|
||||
::testing::benchmark::State& state) { \
|
||||
test::Benchmark(#DEVICE, PTruncatedNormalOneTail(B, S), \
|
||||
/*old_benchmark_api*/ false) \
|
||||
.Run(state); \
|
||||
state.SetItemsProcessed(static_cast<int64>(B) * S * state.iterations()); \
|
||||
} \
|
||||
BENCHMARK(BM_PTruncatedNormal_OneTail_##DEVICE##_##B##_##S);
|
||||
|
||||
|
@ -760,13 +760,14 @@ TEST_F(QuantizeAndDequantizeTest, Invalid_range_given_V3) {
|
||||
}
|
||||
|
||||
#define BM_SIMPLE_QUAN_DEQUAN(DEVICE) \
|
||||
static void BM_SIMPLE_QUAN_DEQUAN_##DEVICE(int iters) { \
|
||||
static void BM_SIMPLE_QUAN_DEQUAN_##DEVICE( \
|
||||
::testing::benchmark::State& state) { \
|
||||
auto root = Scope::NewRootScope().ExitOnError(); \
|
||||
ops::QuantizeAndDequantizeV2(root, -3.5, -3.5, -3.5); \
|
||||
TF_CHECK_OK(root.status()); \
|
||||
Graph* g = new Graph(OpRegistry::Global()); \
|
||||
TF_CHECK_OK(root.ToGraph(g)); \
|
||||
test::Benchmark(#DEVICE, g).Run(iters); \
|
||||
test::Benchmark(#DEVICE, g, /*old_benchmark_api*/ false).Run(state); \
|
||||
} \
|
||||
BENCHMARK(BM_SIMPLE_QUAN_DEQUAN_##DEVICE);
|
||||
|
||||
|
@ -248,9 +248,8 @@ void QuantizedConcatTest::TestSecondDim8Bit(float first_min, float first_max,
|
||||
// If <same_limits> is true, then both concatenated dimensions have the same
|
||||
// quantized range; otherwise, they are set to different values.
|
||||
template <typename T>
|
||||
static void ConcatHelper(int iters, int concat_dimension, bool same_limits,
|
||||
int dim2) {
|
||||
testing::StopTiming();
|
||||
static void ConcatHelper(::testing::benchmark::State& state,
|
||||
int concat_dimension, bool same_limits, int dim2) {
|
||||
Graph* g = new Graph(OpRegistry::Global());
|
||||
|
||||
DataType dt = DataTypeToEnum<T>::v();
|
||||
@ -278,61 +277,111 @@ static void ConcatHelper(int iters, int concat_dimension, bool same_limits,
|
||||
.Attr("T", dt)
|
||||
.Finalize(g, &node));
|
||||
|
||||
testing::BytesProcessed(static_cast<int64>(iters) *
|
||||
test::Benchmark("cpu", g, /*old_benchmark_api*/ false).Run(state);
|
||||
state.SetBytesProcessed(static_cast<int64>(state.iterations()) *
|
||||
((kDim1 * dim2) + (kDim1 * dim2)) * sizeof(T));
|
||||
testing::StartTiming();
|
||||
test::Benchmark("cpu", g).Run(iters);
|
||||
testing::UseRealTime();
|
||||
}
|
||||
|
||||
static void BM_QConcatDim0SameLimitQInt32(int iters, int dim2) {
|
||||
ConcatHelper<qint32>(iters, 0 /* concat_dimension */, true /* same_limits */,
|
||||
static void BM_QConcatDim0SameLimitQInt32(::testing::benchmark::State& state) {
|
||||
const int dim2 = state.range(0);
|
||||
|
||||
ConcatHelper<qint32>(state, 0 /* concat_dimension */, true /* same_limits */,
|
||||
dim2);
|
||||
}
|
||||
|
||||
static void BM_QConcatDim1SameLimitQInt32(int iters, int dim2) {
|
||||
ConcatHelper<qint32>(iters, 1 /* concat_dimension */, true /* same_limits */,
|
||||
static void BM_QConcatDim1SameLimitQInt32(::testing::benchmark::State& state) {
|
||||
const int dim2 = state.range(0);
|
||||
|
||||
ConcatHelper<qint32>(state, 1 /* concat_dimension */, true /* same_limits */,
|
||||
dim2);
|
||||
}
|
||||
|
||||
static void BM_QConcatDim0DifferLimitQInt32(int iters, int dim2) {
|
||||
ConcatHelper<qint32>(iters, 0 /* concat_dimension */, false /* same_limits */,
|
||||
static void BM_QConcatDim0DifferLimitQInt32(
|
||||
::testing::benchmark::State& state) {
|
||||
const int dim2 = state.range(0);
|
||||
|
||||
ConcatHelper<qint32>(state, 0 /* concat_dimension */, false /* same_limits */,
|
||||
dim2);
|
||||
}
|
||||
|
||||
static void BM_QConcatDim1DifferLimitQInt32(int iters, int dim2) {
|
||||
ConcatHelper<qint32>(iters, 1 /* concat_dimension */, false /* same_limits */,
|
||||
static void BM_QConcatDim1DifferLimitQInt32(
|
||||
::testing::benchmark::State& state) {
|
||||
const int dim2 = state.range(0);
|
||||
|
||||
ConcatHelper<qint32>(state, 1 /* concat_dimension */, false /* same_limits */,
|
||||
dim2);
|
||||
}
|
||||
|
||||
BENCHMARK(BM_QConcatDim0SameLimitQInt32)->Arg(1000)->Arg(20000)->Arg(100000);
|
||||
BENCHMARK(BM_QConcatDim1SameLimitQInt32)->Arg(1000)->Arg(20000)->Arg(100000);
|
||||
BENCHMARK(BM_QConcatDim0DifferLimitQInt32)->Arg(1000)->Arg(20000)->Arg(100000);
|
||||
BENCHMARK(BM_QConcatDim1DifferLimitQInt32)->Arg(1000)->Arg(20000)->Arg(100000);
|
||||
BENCHMARK(BM_QConcatDim0SameLimitQInt32)
|
||||
->UseRealTime()
|
||||
->Arg(1000)
|
||||
->Arg(20000)
|
||||
->Arg(100000);
|
||||
BENCHMARK(BM_QConcatDim1SameLimitQInt32)
|
||||
->UseRealTime()
|
||||
->Arg(1000)
|
||||
->Arg(20000)
|
||||
->Arg(100000);
|
||||
BENCHMARK(BM_QConcatDim0DifferLimitQInt32)
|
||||
->UseRealTime()
|
||||
->Arg(1000)
|
||||
->Arg(20000)
|
||||
->Arg(100000);
|
||||
BENCHMARK(BM_QConcatDim1DifferLimitQInt32)
|
||||
->UseRealTime()
|
||||
->Arg(1000)
|
||||
->Arg(20000)
|
||||
->Arg(100000);
|
||||
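Note: the ConcatHelper hunks above drop the extra `int dim2` benchmark parameter; the value is now supplied by the registration's `->Arg(...)` calls and read back with `state.range(0)`. A sketch of that pairing; `BM_SizedConcat` is an illustrative name for the wrappers shown above:

static void BM_SizedConcat(::testing::benchmark::State& state) {
  const int dim2 = state.range(0);  // supplied by ->Arg(...) below
  ConcatHelper<qint32>(state, 0 /* concat_dimension */, true /* same_limits */,
                       dim2);
}
BENCHMARK(BM_SizedConcat)->UseRealTime()->Arg(1000)->Arg(20000)->Arg(100000);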
|
||||
static void BM_QConcatDim0SameLimitQUint8(int iters, int dim2) {
|
||||
ConcatHelper<qint32>(iters, 0 /* concat_dimension */, true /* same_limits */,
|
||||
static void BM_QConcatDim0SameLimitQUint8(::testing::benchmark::State& state) {
|
||||
const int dim2 = state.range(0);
|
||||
|
||||
ConcatHelper<qint32>(state, 0 /* concat_dimension */, true /* same_limits */,
|
||||
dim2);
|
||||
}
|
||||
|
||||
static void BM_QConcatDim1SameLimitQUint8(int iters, int dim2) {
|
||||
ConcatHelper<qint32>(iters, 1 /* concat_dimension */, true /* same_limits */,
|
||||
static void BM_QConcatDim1SameLimitQUint8(::testing::benchmark::State& state) {
|
||||
const int dim2 = state.range(0);
|
||||
|
||||
ConcatHelper<qint32>(state, 1 /* concat_dimension */, true /* same_limits */,
|
||||
dim2);
|
||||
}
|
||||
|
||||
static void BM_QConcatDim0DifferLimitQUint8(int iters, int dim2) {
|
||||
ConcatHelper<qint32>(iters, 0 /* concat_dimension */, false /* same_limits */,
|
||||
static void BM_QConcatDim0DifferLimitQUint8(
|
||||
::testing::benchmark::State& state) {
|
||||
const int dim2 = state.range(0);
|
||||
|
||||
ConcatHelper<qint32>(state, 0 /* concat_dimension */, false /* same_limits */,
|
||||
dim2);
|
||||
}
|
||||
|
||||
static void BM_QConcatDim1DifferLimitQUint8(int iters, int dim2) {
|
||||
ConcatHelper<qint32>(iters, 1 /* concat_dimension */, false /* same_limits */,
|
||||
static void BM_QConcatDim1DifferLimitQUint8(
|
||||
::testing::benchmark::State& state) {
|
||||
const int dim2 = state.range(0);
|
||||
|
||||
ConcatHelper<qint32>(state, 1 /* concat_dimension */, false /* same_limits */,
|
||||
dim2);
|
||||
}
|
||||
|
||||
BENCHMARK(BM_QConcatDim0SameLimitQUint8)->Arg(1000)->Arg(20000)->Arg(100000);
|
||||
BENCHMARK(BM_QConcatDim1SameLimitQUint8)->Arg(1000)->Arg(20000)->Arg(100000);
|
||||
BENCHMARK(BM_QConcatDim0DifferLimitQUint8)->Arg(1000)->Arg(20000)->Arg(100000);
|
||||
BENCHMARK(BM_QConcatDim1DifferLimitQUint8)->Arg(1000)->Arg(20000)->Arg(100000);
|
||||
BENCHMARK(BM_QConcatDim0SameLimitQUint8)
|
||||
->UseRealTime()
|
||||
->Arg(1000)
|
||||
->Arg(20000)
|
||||
->Arg(100000);
|
||||
BENCHMARK(BM_QConcatDim1SameLimitQUint8)
|
||||
->UseRealTime()
|
||||
->Arg(1000)
|
||||
->Arg(20000)
|
||||
->Arg(100000);
|
||||
BENCHMARK(BM_QConcatDim0DifferLimitQUint8)
|
||||
->UseRealTime()
|
||||
->Arg(1000)
|
||||
->Arg(20000)
|
||||
->Arg(100000);
|
||||
BENCHMARK(BM_QConcatDim1DifferLimitQUint8)
|
||||
->UseRealTime()
|
||||
->Arg(1000)
|
||||
->Arg(20000)
|
||||
->Arg(100000);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -68,30 +68,42 @@ static Graph* RandomBinomialRejComplement(int num_batches,
|
||||
}
|
||||
|
||||
#define BM_RandomBinomialInv(DEVICE, B, S) \
|
||||
static void BM_RandomBinomialInv_##DEVICE##_##B##_##S(int iters) { \
|
||||
test::Benchmark(#DEVICE, RandomBinomialInv(B, S)).Run(iters); \
|
||||
testing::ItemsProcessed(static_cast<int64>(B) * S * iters); \
|
||||
static void BM_RandomBinomialInv_##DEVICE##_##B##_##S( \
|
||||
::testing::benchmark::State& state) { \
|
||||
test::Benchmark(#DEVICE, RandomBinomialInv(B, S), \
|
||||
/*old_benchmark_api*/ false) \
|
||||
.Run(state); \
|
||||
state.SetItemsProcessed(static_cast<int64>(B) * S * state.iterations()); \
|
||||
} \
|
||||
BENCHMARK(BM_RandomBinomialInv_##DEVICE##_##B##_##S);
|
||||
|
||||
#define BM_RandomBinomialRej(DEVICE, B, S) \
|
||||
static void BM_RandomBinomialRej_##DEVICE##_##B##_##S(int iters) { \
|
||||
test::Benchmark(#DEVICE, RandomBinomialRej(B, S)).Run(iters); \
|
||||
testing::ItemsProcessed(static_cast<int64>(B) * S * iters); \
|
||||
static void BM_RandomBinomialRej_##DEVICE##_##B##_##S( \
|
||||
::testing::benchmark::State& state) { \
|
||||
test::Benchmark(#DEVICE, RandomBinomialRej(B, S), \
|
||||
/*old_benchmark_api*/ false) \
|
||||
.Run(state); \
|
||||
state.SetItemsProcessed(static_cast<int64>(B) * S * state.iterations()); \
|
||||
} \
|
||||
BENCHMARK(BM_RandomBinomialRej_##DEVICE##_##B##_##S);
|
||||
|
||||
#define BM_RandomBinomialInvComplement(DEVICE, B, S) \
|
||||
static void BM_RandomBinomialInvComplement_##DEVICE##_##B##_##S(int iters) { \
|
||||
test::Benchmark(#DEVICE, RandomBinomialInvComplement(B, S)).Run(iters); \
|
||||
testing::ItemsProcessed(static_cast<int64>(B) * S * iters); \
|
||||
static void BM_RandomBinomialInvComplement_##DEVICE##_##B##_##S( \
|
||||
::testing::benchmark::State& state) { \
|
||||
test::Benchmark(#DEVICE, RandomBinomialInvComplement(B, S), \
|
||||
/*old_benchmark_api*/ false) \
|
||||
.Run(state); \
|
||||
state.SetItemsProcessed(static_cast<int64>(B) * S * state.iterations()); \
|
||||
} \
|
||||
BENCHMARK(BM_RandomBinomialInvComplement_##DEVICE##_##B##_##S);
|
||||
|
||||
#define BM_RandomBinomialRejComplement(DEVICE, B, S) \
|
||||
static void BM_RandomBinomialRejComplement_##DEVICE##_##B##_##S(int iters) { \
|
||||
test::Benchmark(#DEVICE, RandomBinomialRejComplement(B, S)).Run(iters); \
|
||||
testing::ItemsProcessed(static_cast<int64>(B) * S * iters); \
|
||||
static void BM_RandomBinomialRejComplement_##DEVICE##_##B##_##S( \
|
||||
::testing::benchmark::State& state) { \
|
||||
test::Benchmark(#DEVICE, RandomBinomialRejComplement(B, S), \
|
||||
/*old_benchmark_api*/ false) \
|
||||
.Run(state); \
|
||||
state.SetItemsProcessed(static_cast<int64>(B) * S * state.iterations()); \
|
||||
} \
|
||||
BENCHMARK(BM_RandomBinomialRejComplement_##DEVICE##_##B##_##S);
|
||||
|
||||
|
@ -59,9 +59,12 @@ Graph* TruncatedNormal(int64 n) {
|
||||
}
|
||||
|
||||
#define BM_RNG(DEVICE, RNG) \
|
||||
void BM_##DEVICE##_##RNG(int iters, int arg) { \
|
||||
testing::ItemsProcessed(static_cast<int64>(iters) * arg); \
|
||||
test::Benchmark(#DEVICE, RNG(arg)).Run(iters); \
|
||||
void BM_##DEVICE##_##RNG(::testing::benchmark::State& state) { \
|
||||
const int arg = state.range(0); \
|
||||
\
|
||||
test::Benchmark(#DEVICE, RNG(arg), /*old_benchmark_api*/ false) \
|
||||
.Run(state); \
|
||||
state.SetItemsProcessed(static_cast<int64>(state.iterations()) * arg); \
|
||||
} \
|
||||
BENCHMARK(BM_##DEVICE##_##RNG)->Range(1 << 20, 8 << 20);
|
||||
|
||||
@ -84,60 +87,48 @@ Tensor VecAlphas(int64 n) {
|
||||
return alphas;
|
||||
}
|
||||
|
||||
void BM_cpu_RandomGamma(int iters, int nsamp, int nalpha) {
|
||||
testing::ItemsProcessed(static_cast<int64>(iters) * nsamp * nalpha);
|
||||
void BM_cpu_RandomGamma(::testing::benchmark::State& state) {
|
||||
const int nsamp = state.range(0);
|
||||
const int nalpha = state.range(1);
|
||||
|
||||
Graph* g = new Graph(OpRegistry::Global());
|
||||
test::graph::RandomGamma(g, test::graph::Constant(g, VecShape(nsamp)),
|
||||
test::graph::Constant(g, VecAlphas(nalpha)));
|
||||
test::Benchmark("cpu", g).Run(iters);
|
||||
test::Benchmark("cpu", g, /*old_benchmark_api*/ false).Run(state);
|
||||
state.SetItemsProcessed(static_cast<int64>(state.iterations()) * nsamp *
|
||||
nalpha);
|
||||
}
|
||||
BENCHMARK(BM_cpu_RandomGamma)->RangePair(1 << 14, 4 << 15, 2, 50);
|
||||
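Note: benchmarks that previously took two int parameters, such as `nsamp`/`nalpha` above, now read both from the state, and `->RangePair(...)` registers the sweep over both. A sketch following the RandomGamma hunk above; the benchmark name is illustrative:

void BM_TwoRangeParams(::testing::benchmark::State& state) {
  const int nsamp = state.range(0);
  const int nalpha = state.range(1);
  Graph* g = new Graph(OpRegistry::Global());
  test::graph::RandomGamma(g, test::graph::Constant(g, VecShape(nsamp)),
                           test::graph::Constant(g, VecAlphas(nalpha)));
  test::Benchmark("cpu", g, /*old_benchmark_api=*/false).Run(state);
  state.SetItemsProcessed(static_cast<int64>(state.iterations()) * nsamp *
                          nalpha);
}
BENCHMARK(BM_TwoRangeParams)->RangePair(1 << 14, 4 << 15, 2, 50);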
|
||||
void BM_PhiloxRandom(int iters) {
|
||||
void BM_PhiloxRandom(::testing::benchmark::State& state) {
|
||||
// Fill 2M random numbers
|
||||
int count = 2 << 20;
|
||||
|
||||
testing::ItemsProcessed(static_cast<int64>(iters) * count);
|
||||
|
||||
random::PhiloxRandom gen(0x12345);
|
||||
|
||||
int val = 1;
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
for (auto s : state) {
|
||||
for (int j = 0; j < count; j += 4) {
|
||||
/// each invocation of gen() returns 128-bit samples
|
||||
auto samples = gen();
|
||||
|
||||
// use the result trivially so the compiler does not optimize it away
|
||||
val ^= samples[0] ^ samples[1] ^ samples[2] ^ samples[3];
|
||||
tensorflow::testing::DoNotOptimize(samples);
|
||||
}
|
||||
}
|
||||
|
||||
// An anchor point to make sure the compiler does not cut corners
|
||||
CHECK(val) << val;
|
||||
state.SetItemsProcessed(static_cast<int64>(state.iterations()) * count);
|
||||
}
|
||||
BENCHMARK(BM_PhiloxRandom);
|
||||
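Note: the PhiloxRandom hunk above swaps the manual `val ^= ...` / `CHECK(val)` trick for `tensorflow::testing::DoNotOptimize`, which keeps the compiler from eliding the generated samples. Condensed sketch under that pattern; the name is illustrative:

void BM_PhiloxSketch(::testing::benchmark::State& state) {
  random::PhiloxRandom gen(0x12345);
  for (auto s : state) {
    auto samples = gen();                         // one 128-bit sample per call
    tensorflow::testing::DoNotOptimize(samples);  // keeps the work observable
  }
  state.SetItemsProcessed(state.iterations() * 4);  // 4 numbers per sample
}
BENCHMARK(BM_PhiloxSketch);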
|
||||
void BM_StdMTRandom(int iters) {
|
||||
void BM_StdMTRandom(::testing::benchmark::State& state) {
|
||||
// Fill 2M random numbers
|
||||
int count = 2 << 20;
|
||||
|
||||
testing::ItemsProcessed(static_cast<int64>(iters) * count);
|
||||
|
||||
std::mt19937 gen(0x12345);
|
||||
|
||||
uint_fast32_t val = 1;
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
for (auto s : state) {
|
||||
for (int j = 0; j < count; ++j) {
|
||||
/// each invocation of gen() returns 32-bit sample
|
||||
uint_fast32_t sample = gen();
|
||||
|
||||
// use the result trivially so the compiler does not optimize it away
|
||||
val ^= sample;
|
||||
tensorflow::testing::DoNotOptimize(sample);
|
||||
}
|
||||
}
|
||||
|
||||
// An anchor point to make sure the compiler does not cut corners
|
||||
CHECK(val) << val;
|
||||
state.SetItemsProcessed(static_cast<int64>(state.iterations()) * count);
|
||||
}
|
||||
BENCHMARK(BM_StdMTRandom);
|
||||
|
||||
|
@ -84,108 +84,167 @@ static Graph* ThreeDXZReduce(const string& reduce, int num_y, int num_z) {
|
||||
// Creates a bench which reduces a 3D tensor with total "num" floats
|
||||
// into a scalar on a "device". Runs the bench for "iters" times.
|
||||
template <typename T>
|
||||
static void ReduceToScalar(int iters, const string& device,
|
||||
const string& reduce, int num_x, int num_y) {
|
||||
testing::ItemsProcessed(static_cast<int64>(iters) * num_x * num_y);
|
||||
testing::BytesProcessed(static_cast<int64>(iters) * num_x * num_y *
|
||||
sizeof(T));
|
||||
test::Benchmark(device, ToScalar<T>(reduce, num_x, num_y)).Run(iters);
|
||||
}
|
||||
|
||||
static void DoRowReduce(int iters, const string& device, const string& reduce,
|
||||
static void ReduceToScalar(::testing::benchmark::State& state,
|
||||
const string& device, const string& reduce,
|
||||
int num_x, int num_y) {
|
||||
testing::ItemsProcessed(static_cast<int64>(iters) * num_x * num_y);
|
||||
testing::BytesProcessed(static_cast<int64>(iters) * num_x * num_y *
|
||||
sizeof(float));
|
||||
test::Benchmark(device, RowReduce(reduce, num_x, num_y)).Run(iters);
|
||||
test::Benchmark(device, ToScalar<T>(reduce, num_x, num_y),
|
||||
/*old_benchmark_api*/ false)
|
||||
.Run(state);
|
||||
state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_x *
|
||||
num_y);
|
||||
state.SetBytesProcessed(static_cast<int64>(state.iterations()) * num_x *
|
||||
num_y * sizeof(T));
|
||||
}
|
||||
|
||||
static void DoColReduce(int iters, const string& device, const string& reduce,
|
||||
int num_x, int num_y) {
|
||||
testing::ItemsProcessed(static_cast<int64>(iters) * num_x * num_y);
|
||||
testing::BytesProcessed(static_cast<int64>(iters) * num_x * num_y *
|
||||
sizeof(float));
|
||||
test::Benchmark(device, ColReduce(reduce, num_x, num_y)).Run(iters);
|
||||
static void DoRowReduce(::testing::benchmark::State& state,
|
||||
const string& device, const string& reduce, int num_x,
|
||||
int num_y) {
|
||||
test::Benchmark(device, RowReduce(reduce, num_x, num_y),
|
||||
/*old_benchmark_api*/ false)
|
||||
.Run(state);
|
||||
state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_x *
|
||||
num_y);
|
||||
state.SetBytesProcessed(static_cast<int64>(state.iterations()) * num_x *
|
||||
num_y * sizeof(float));
|
||||
}
|
||||
|
||||
static void Do3DYReduce(int iters, const string& device, const string& reduce,
|
||||
int num_x, int num_y) {
|
||||
testing::ItemsProcessed(static_cast<int64>(iters) * num_x * num_y);
|
||||
testing::BytesProcessed(static_cast<int64>(iters) * num_x * num_y *
|
||||
sizeof(float));
|
||||
test::Benchmark(device, ThreeDYReduce(reduce, num_x, num_y)).Run(iters);
|
||||
static void DoColReduce(::testing::benchmark::State& state,
|
||||
const string& device, const string& reduce, int num_x,
|
||||
int num_y) {
|
||||
test::Benchmark(device, ColReduce(reduce, num_x, num_y),
|
||||
/*old_benchmark_api*/ false)
|
||||
.Run(state);
|
||||
state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_x *
|
||||
num_y);
|
||||
state.SetBytesProcessed(static_cast<int64>(state.iterations()) * num_x *
|
||||
num_y * sizeof(float));
|
||||
}
|
||||
|
||||
static void Do3DXZReduce(int iters, const string& device, const string& reduce,
|
||||
int num_x, int num_y) {
|
||||
testing::ItemsProcessed(static_cast<int64>(iters) * num_x * num_y);
|
||||
testing::BytesProcessed(static_cast<int64>(iters) * num_x * num_y *
|
||||
sizeof(float));
|
||||
test::Benchmark(device, ThreeDXZReduce(reduce, num_x, num_y)).Run(iters);
|
||||
static void Do3DYReduce(::testing::benchmark::State& state,
|
||||
const string& device, const string& reduce, int num_x,
|
||||
int num_y) {
|
||||
test::Benchmark(device, ThreeDYReduce(reduce, num_x, num_y),
|
||||
/*old_benchmark_api*/ false)
|
||||
.Run(state);
|
||||
state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_x *
|
||||
num_y);
|
||||
state.SetBytesProcessed(static_cast<int64>(state.iterations()) * num_x *
|
||||
num_y * sizeof(float));
|
||||
}
|
||||
|
||||
static void BM_Sum2DToScalarGPU(int iters, int num_x, int num_y) {
|
||||
ReduceToScalar<float>(iters, "gpu", "Sum", num_x, num_y);
|
||||
static void Do3DXZReduce(::testing::benchmark::State& state,
|
||||
const string& device, const string& reduce, int num_x,
|
||||
int num_y) {
|
||||
test::Benchmark(device, ThreeDXZReduce(reduce, num_x, num_y),
|
||||
/*old_benchmark_api*/ false)
|
||||
.Run(state);
|
||||
state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_x *
|
||||
num_y);
|
||||
state.SetBytesProcessed(static_cast<int64>(state.iterations()) * num_x *
|
||||
num_y * sizeof(float));
|
||||
}
|
||||
|
||||
static void BM_Sum2DToScalarGPU(::testing::benchmark::State& state) {
|
||||
const int num_x = state.range(0);
|
||||
const int num_y = state.range(1);
|
||||
|
||||
ReduceToScalar<float>(state, "gpu", "Sum", num_x, num_y);
|
||||
}
|
||||
BENCHMARK(BM_Sum2DToScalarGPU)->RangePair(1, 8192, 1, 8192);
|
||||
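Note: `ReduceToScalar<T>` above reports both items and bytes from the state, with bytes scaled by `sizeof(T)`, so the same helper serves the float, half, complex, and bool reductions below. A minimal sketch of just that reporting step; the helper name is illustrative:

template <typename T>
static void ReportReductionCounters(::testing::benchmark::State& state,
                                    int num_x, int num_y) {
  const int64 items = static_cast<int64>(state.iterations()) * num_x * num_y;
  state.SetItemsProcessed(items);
  state.SetBytesProcessed(items * sizeof(T));
}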
|
||||
static void BM_Sum2DToScalarGPUComplex(int iters, int num_x, int num_y) {
|
||||
ReduceToScalar<std::complex<float>>(iters, "gpu", "Sum", num_x, num_y);
|
||||
static void BM_Sum2DToScalarGPUComplex(::testing::benchmark::State& state) {
|
||||
const int num_x = state.range(0);
|
||||
const int num_y = state.range(1);
|
||||
|
||||
ReduceToScalar<std::complex<float>>(state, "gpu", "Sum", num_x, num_y);
|
||||
}
|
||||
BENCHMARK(BM_Sum2DToScalarGPUComplex)->RangePair(1, 8192, 1, 8192);
|
||||
|
||||
static void BM_Sum2DToScalarGPUHalf(int iters, int num_x, int num_y) {
|
||||
ReduceToScalar<Eigen::half>(iters, "gpu", "Sum", num_x, num_y);
|
||||
static void BM_Sum2DToScalarGPUHalf(::testing::benchmark::State& state) {
|
||||
const int num_x = state.range(0);
|
||||
const int num_y = state.range(1);
|
||||
|
||||
ReduceToScalar<Eigen::half>(state, "gpu", "Sum", num_x, num_y);
|
||||
}
|
||||
BENCHMARK(BM_Sum2DToScalarGPUHalf)->RangePair(1, 8192, 1, 8192);
|
||||
|
||||
static void BM_Sum2DRowReduceGPU(int iters, int num_x, int num_y) {
|
||||
DoRowReduce(iters, "gpu", "Sum", num_x, num_y);
|
||||
static void BM_Sum2DRowReduceGPU(::testing::benchmark::State& state) {
|
||||
const int num_x = state.range(0);
|
||||
const int num_y = state.range(1);
|
||||
|
||||
DoRowReduce(state, "gpu", "Sum", num_x, num_y);
|
||||
}
|
||||
BENCHMARK(BM_Sum2DRowReduceGPU)->RangePair(1, 8192, 1, 8192);
|
||||
|
||||
static void BM_Sum2DColumnReduceGPU(int iters, int num_x, int num_y) {
|
||||
DoColReduce(iters, "gpu", "Sum", num_x, num_y);
|
||||
static void BM_Sum2DColumnReduceGPU(::testing::benchmark::State& state) {
|
||||
const int num_x = state.range(0);
|
||||
const int num_y = state.range(1);
|
||||
|
||||
DoColReduce(state, "gpu", "Sum", num_x, num_y);
|
||||
}
|
||||
BENCHMARK(BM_Sum2DColumnReduceGPU)->RangePair(1, 8192, 1, 8192);
|
||||
|
||||
static void BM_Sum3DYReduceGPU(int iters, int num_x, int num_y) {
|
||||
Do3DYReduce(iters, "gpu", "Sum", num_x, num_y);
|
||||
static void BM_Sum3DYReduceGPU(::testing::benchmark::State& state) {
|
||||
const int num_x = state.range(0);
|
||||
const int num_y = state.range(1);
|
||||
|
||||
Do3DYReduce(state, "gpu", "Sum", num_x, num_y);
|
||||
}
|
||||
BENCHMARK(BM_Sum3DYReduceGPU)->RangePair(64, 4096, 64, 4096);
|
||||
|
||||
static void BM_Sum3DXZReduceGPU(int iters, int num_x, int num_y) {
|
||||
Do3DXZReduce(iters, "gpu", "Sum", num_x, num_y);
|
||||
static void BM_Sum3DXZReduceGPU(::testing::benchmark::State& state) {
|
||||
const int num_x = state.range(0);
|
||||
const int num_y = state.range(1);
|
||||
|
||||
Do3DXZReduce(state, "gpu", "Sum", num_x, num_y);
|
||||
}
|
||||
BENCHMARK(BM_Sum3DXZReduceGPU)->RangePair(64, 4096, 64, 4096);
|
||||
|
||||
static void BM_Mean2DToScalarGPU(int iters, int num_x, int num_y) {
|
||||
ReduceToScalar<float>(iters, "gpu", "Mean", num_x, num_y);
|
||||
static void BM_Mean2DToScalarGPU(::testing::benchmark::State& state) {
|
||||
const int num_x = state.range(0);
|
||||
const int num_y = state.range(1);
|
||||
|
||||
ReduceToScalar<float>(state, "gpu", "Mean", num_x, num_y);
|
||||
}
|
||||
BENCHMARK(BM_Mean2DToScalarGPU)->RangePair(2048, 8192, 2048, 8192);
|
||||
|
||||
static void BM_EuclideanNorm2DToScalarGPU(int iters, int num_x, int num_y) {
|
||||
ReduceToScalar<float>(iters, "gpu", "EuclideanNorm", num_x, num_y);
|
||||
static void BM_EuclideanNorm2DToScalarGPU(::testing::benchmark::State& state) {
|
||||
const int num_x = state.range(0);
|
||||
const int num_y = state.range(1);
|
||||
|
||||
ReduceToScalar<float>(state, "gpu", "EuclideanNorm", num_x, num_y);
|
||||
}
|
||||
BENCHMARK(BM_EuclideanNorm2DToScalarGPU)->RangePair(2048, 8192, 2048, 8192);
|
||||
|
||||
static void BM_Max2DToScalarGPU(int iters, int num_x, int num_y) {
|
||||
ReduceToScalar<float>(iters, "gpu", "Max", num_x, num_y);
|
||||
static void BM_Max2DToScalarGPU(::testing::benchmark::State& state) {
|
||||
const int num_x = state.range(0);
|
||||
const int num_y = state.range(1);
|
||||
|
||||
ReduceToScalar<float>(state, "gpu", "Max", num_x, num_y);
|
||||
}
|
||||
BENCHMARK(BM_Max2DToScalarGPU)->RangePair(2048, 8192, 2048, 8192);
|
||||
|
||||
static void BM_Min2DToScalarGPU(int iters, int num_x, int num_y) {
|
||||
ReduceToScalar<float>(iters, "gpu", "Min", num_x, num_y);
|
||||
static void BM_Min2DToScalarGPU(::testing::benchmark::State& state) {
|
||||
const int num_x = state.range(0);
|
||||
const int num_y = state.range(1);
|
||||
|
||||
ReduceToScalar<float>(state, "gpu", "Min", num_x, num_y);
|
||||
}
|
||||
BENCHMARK(BM_Min2DToScalarGPU)->RangePair(2048, 8192, 2048, 8192);
|
||||
|
||||
static void BM_Min2DToScalarGPUHalf(int iters, int num_x, int num_y) {
|
||||
ReduceToScalar<Eigen::half>(iters, "gpu", "Min", num_x, num_y);
|
||||
static void BM_Min2DToScalarGPUHalf(::testing::benchmark::State& state) {
|
||||
const int num_x = state.range(0);
|
||||
const int num_y = state.range(1);
|
||||
|
||||
ReduceToScalar<Eigen::half>(state, "gpu", "Min", num_x, num_y);
|
||||
}
|
||||
BENCHMARK(BM_Min2DToScalarGPUHalf)->RangePair(2048, 8192, 2048, 8192);
|
||||
|
||||
static void BM_Bool2DToScalarGPU(int iters, int num_x, int num_y) {
|
||||
ReduceToScalar<bool>(iters, "gpu", "All", num_x, num_y);
|
||||
static void BM_Bool2DToScalarGPU(::testing::benchmark::State& state) {
|
||||
const int num_x = state.range(0);
|
||||
const int num_y = state.range(1);
|
||||
|
||||
ReduceToScalar<bool>(state, "gpu", "All", num_x, num_y);
|
||||
}
|
||||
BENCHMARK(BM_Bool2DToScalarGPU)->RangePair(2048, 8192, 2048, 8192);
|
||||
|
||||
|
@ -84,17 +84,17 @@ Graph* SetupRegexReplaceGraph(const Tensor& input, const string& input_pattern,
|
||||
return g;
|
||||
}
|
||||
|
||||
void BM_RegexReplace(int iters, int batch_size) {
|
||||
testing::StopTiming();
|
||||
testing::ItemsProcessed(static_cast<int64>(iters));
|
||||
testing::UseRealTime();
|
||||
static void BM_RegexReplace(::testing::benchmark::State& state) {
|
||||
const int batch_size = state.range(0);
|
||||
|
||||
Tensor input = GetTestTensor(batch_size);
|
||||
Graph* g = SetupRegexReplaceGraph(input, kRegExPattern, kRewrite);
|
||||
testing::StartTiming();
|
||||
test::Benchmark("cpu", g).Run(iters);
|
||||
test::Benchmark("cpu", g, /*old_benchmark_api*/ false).Run(state);
|
||||
state.SetItemsProcessed(static_cast<int64>(state.iterations()));
|
||||
}
|
||||
|
||||
BENCHMARK(BM_RegexReplace)
|
||||
->UseRealTime()
|
||||
->Arg(1)
|
||||
->Arg(8)
|
||||
->Arg(16)
|
||||
@ -115,17 +115,17 @@ Graph* SetupStaticGraph(const Tensor& input, const string& input_pattern,
|
||||
.Finalize(g, nullptr /* node */));
|
||||
return g;
|
||||
}
|
||||
void BM_StaticRegexReplace(int iters, int batch_size) {
|
||||
testing::StopTiming();
|
||||
testing::ItemsProcessed(static_cast<int64>(iters));
|
||||
testing::UseRealTime();
|
||||
static void BM_StaticRegexReplace(::testing::benchmark::State& state) {
|
||||
const int batch_size = state.range(0);
|
||||
|
||||
Tensor input = GetTestTensor(batch_size);
|
||||
Graph* g = SetupStaticGraph(input, kRegExPattern, kRewrite);
|
||||
testing::StartTiming();
|
||||
test::Benchmark("cpu", g).Run(iters);
|
||||
test::Benchmark("cpu", g, /*old_benchmark_api*/ false).Run(state);
|
||||
state.SetItemsProcessed(static_cast<int64>(state.iterations()));
|
||||
}
|
||||
|
||||
BENCHMARK(BM_StaticRegexReplace)
|
||||
->UseRealTime()
|
||||
->Arg(1)
|
||||
->Arg(8)
|
||||
->Arg(16)
|
||||
|
@ -67,56 +67,29 @@ TEST_F(RequantizationRangeTest, HandCrafted) {
|
||||
test::ExpectTensorEqual<float>(expected_max, *GetOutput(1));
|
||||
}
|
||||
|
||||
static void BM_RequantizationRange(int iters, int size) {
|
||||
testing::StopTiming();
|
||||
testing::UseRealTime();
|
||||
testing::ItemsProcessed(static_cast<int64>(iters) * size);
|
||||
testing::ItemsProcessed(static_cast<int64>(iters) * size * 4);
|
||||
static void BM_RequantizationRange(::testing::benchmark::State& state) {
|
||||
const int size = state.range(0);
|
||||
|
||||
Tensor quantized_tensor(DT_QINT32, TensorShape({1, size}));
|
||||
test::FillFn<qint32>(&quantized_tensor, [](int n) { return qint32(n); });
|
||||
|
||||
qint32 actual_min;
|
||||
qint32 actual_max;
|
||||
testing::StartTiming();
|
||||
for (int iter = 0; iter < iters; ++iter) {
|
||||
for (auto s : state) {
|
||||
CalculateUsedRange(quantized_tensor, &actual_min, &actual_max);
|
||||
}
|
||||
state.SetItemsProcessed(static_cast<int64>(state.iterations()) * size);
|
||||
state.SetBytesProcessed(static_cast<int64>(state.iterations()) * size * 4);
|
||||
}
|
||||
|
||||
static void BM_RequantizationRange100(int iters) {
|
||||
BM_RequantizationRange(100, iters);
|
||||
}
|
||||
BENCHMARK(BM_RequantizationRange100);
|
||||
|
||||
static void BM_RequantizationRange1000(int iters) {
|
||||
BM_RequantizationRange(1000, iters);
|
||||
}
|
||||
BENCHMARK(BM_RequantizationRange1000);
|
||||
|
||||
static void BM_RequantizationRange10000(int iters) {
|
||||
BM_RequantizationRange(10000, iters);
|
||||
}
|
||||
BENCHMARK(BM_RequantizationRange10000);
|
||||
|
||||
static void BM_RequantizationRange100000(int iters) {
|
||||
BM_RequantizationRange(100000, iters);
|
||||
}
|
||||
BENCHMARK(BM_RequantizationRange100000);
|
||||
|
||||
static void BM_RequantizationRange1000000(int iters) {
|
||||
BM_RequantizationRange(1000000, iters);
|
||||
}
|
||||
BENCHMARK(BM_RequantizationRange1000000);
|
||||
|
||||
static void BM_RequantizationRange10000000(int iters) {
|
||||
BM_RequantizationRange(10000000, iters);
|
||||
}
|
||||
BENCHMARK(BM_RequantizationRange10000000);
|
||||
|
||||
static void BM_RequantizationRange100000000(int iters) {
|
||||
BM_RequantizationRange(100000000, iters);
|
||||
}
|
||||
BENCHMARK(BM_RequantizationRange100000000);
|
||||
BENCHMARK(BM_RequantizationRange)
|
||||
->UseRealTime()
|
||||
->Arg(100)
|
||||
->Arg(1000)
|
||||
->Arg(10000)
|
||||
->Arg(100000)
|
||||
->Arg(1000000)
|
||||
->Arg(10000000)
|
||||
->Arg(100000000);
|
||||
|
||||
} // end namespace tensorflow
|
||||
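Note: the requantization_range hunk above collapses eight fixed-size wrappers (BM_RequantizationRange100 through ...100000000) into one benchmark parameterized through `->Arg(...)`. A sketch of that consolidation, reusing the setup calls shown above; the benchmark name and the reduced Arg list are illustrative:

static void BM_SweepSizes(::testing::benchmark::State& state) {
  const int size = state.range(0);
  Tensor quantized(DT_QINT32, TensorShape({1, size}));
  test::FillFn<qint32>(&quantized, [](int n) { return qint32(n); });
  qint32 lo, hi;
  for (auto s : state) {
    CalculateUsedRange(quantized, &lo, &hi);
  }
  state.SetItemsProcessed(static_cast<int64>(state.iterations()) * size);
}
BENCHMARK(BM_SweepSizes)->UseRealTime()->Arg(100)->Arg(10000)->Arg(1000000);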
|
@ -197,148 +197,187 @@ static Graph* Reverse(const TensorShape& shape, int reverse_axis) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void RunReverseRowsBenchmark(int iters, int outer_dim, int middle_dim,
|
||||
static void RunReverseRowsBenchmark(::testing::benchmark::State& state,
|
||||
int outer_dim, int middle_dim,
|
||||
int intra_threads, int channels) {
|
||||
SessionOptions opts = GetOptions(intra_threads);
|
||||
TensorShape shape{outer_dim, middle_dim, channels};
|
||||
const int64 num_items = static_cast<int64>(iters) * shape.num_elements();
|
||||
testing::ItemsProcessed(num_items);
|
||||
testing::BytesProcessed(num_items * sizeof(T));
|
||||
testing::UseRealTime();
|
||||
test::Benchmark("cpu", Reverse<T>(shape, 1), &opts).Run(iters);
|
||||
test::Benchmark("cpu", Reverse<T>(shape, 1), &opts, nullptr, nullptr, "",
|
||||
/*old_benchmark_api*/ false)
|
||||
.Run(state);
|
||||
const int64 num_items =
|
||||
static_cast<int64>(state.iterations()) * shape.num_elements();
|
||||
state.SetItemsProcessed(num_items);
|
||||
state.SetBytesProcessed(num_items * sizeof(T));
|
||||
}
|
||||
|
||||
static void BM_ReverseRowsOf1Channel_1T_float(int iters, int outer_dim,
|
||||
int middle_dim) {
|
||||
RunReverseRowsBenchmark<float>(iters, outer_dim, middle_dim,
|
||||
void BM_ReverseRowsOf1Channel_1T_float(::testing::benchmark::State& state) {
|
||||
const int outer_dim = state.range(0);
|
||||
const int middle_dim = state.range(1);
|
||||
|
||||
RunReverseRowsBenchmark<float>(state, outer_dim, middle_dim,
|
||||
1 /* intra_threads */, 1 /* channels */);
|
||||
}
|
||||
|
||||
BENCHMARK(BM_ReverseRowsOf1Channel_1T_float)
|
||||
->UseRealTime()
|
||||
->ArgPair(288, 288)
|
||||
->ArgPair(1024, 1024)
|
||||
->ArgPair(10 * 1024, 1024);
|
||||
|
||||
static void BM_ReverseRowsOf1Channel_1T_uint8(int iters, int outer_dim,
|
||||
int middle_dim) {
|
||||
RunReverseRowsBenchmark<uint8>(iters, outer_dim, middle_dim,
|
||||
void BM_ReverseRowsOf1Channel_1T_uint8(::testing::benchmark::State& state) {
|
||||
const int outer_dim = state.range(0);
|
||||
const int middle_dim = state.range(1);
|
||||
|
||||
RunReverseRowsBenchmark<uint8>(state, outer_dim, middle_dim,
|
||||
1 /* intra_threads */, 1 /* channels */);
|
||||
}
|
||||
|
||||
BENCHMARK(BM_ReverseRowsOf1Channel_1T_uint8)
|
||||
->UseRealTime()
|
||||
->ArgPair(288, 288)
|
||||
->ArgPair(1024, 1024)
|
||||
->ArgPair(10 * 1024, 1024);
|
||||
|
||||
static void BM_ReverseRowsOf1Channel_4T_float(int iters, int outer_dim,
|
||||
int middle_dim) {
|
||||
RunReverseRowsBenchmark<float>(iters, outer_dim, middle_dim,
|
||||
void BM_ReverseRowsOf1Channel_4T_float(::testing::benchmark::State& state) {
|
||||
const int outer_dim = state.range(0);
|
||||
const int middle_dim = state.range(1);
|
||||
|
||||
RunReverseRowsBenchmark<float>(state, outer_dim, middle_dim,
|
||||
4 /* intra_threads */, 1 /* channels */);
|
||||
}
|
||||
|
||||
BENCHMARK(BM_ReverseRowsOf1Channel_4T_float)
|
||||
->UseRealTime()
|
||||
->ArgPair(288, 288)
|
||||
->ArgPair(1024, 1024)
|
||||
->ArgPair(10 * 1024, 1024);
|
||||
|
||||
static void BM_ReverseRowsOf1Channel_4T_uint8(int iters, int outer_dim,
|
||||
int middle_dim) {
|
||||
RunReverseRowsBenchmark<uint8>(iters, outer_dim, middle_dim,
|
||||
void BM_ReverseRowsOf1Channel_4T_uint8(::testing::benchmark::State& state) {
|
||||
const int outer_dim = state.range(0);
|
||||
const int middle_dim = state.range(1);
|
||||
|
||||
RunReverseRowsBenchmark<uint8>(state, outer_dim, middle_dim,
|
||||
4 /* intra_threads */, 1 /* channels */);
|
||||
}
|
||||
|
||||
BENCHMARK(BM_ReverseRowsOf1Channel_4T_uint8)
|
||||
->UseRealTime()
|
||||
->ArgPair(288, 288)
|
||||
->ArgPair(1024, 1024)
|
||||
->ArgPair(10 * 1024, 1024);
|
||||
|
||||
static void BM_ReverseRowsOf3Channels_1T_float(int iters, int outer_dim,
|
||||
int middle_dim) {
|
||||
RunReverseRowsBenchmark<float>(iters, outer_dim, middle_dim,
|
||||
void BM_ReverseRowsOf3Channels_1T_float(::testing::benchmark::State& state) {
|
||||
const int outer_dim = state.range(0);
|
||||
const int middle_dim = state.range(1);
|
||||
|
||||
RunReverseRowsBenchmark<float>(state, outer_dim, middle_dim,
|
||||
1 /* intra_threads */, 3 /* channels */);
|
||||
}
|
||||
|
||||
BENCHMARK(BM_ReverseRowsOf3Channels_1T_float)
|
||||
->UseRealTime()
|
||||
->ArgPair(288, 288)
|
||||
->ArgPair(30, 30)
|
||||
->ArgPair(1024, 1024)
|
||||
->ArgPair(10 * 1024, 1024);
|
||||
|
||||
static void BM_ReverseRowsOf3Channels_1T_uint8(int iters, int outer_dim,
|
||||
int middle_dim) {
|
||||
RunReverseRowsBenchmark<uint8>(iters, outer_dim, middle_dim,
|
||||
void BM_ReverseRowsOf3Channels_1T_uint8(::testing::benchmark::State& state) {
|
||||
const int outer_dim = state.range(0);
|
||||
const int middle_dim = state.range(1);
|
||||
|
||||
RunReverseRowsBenchmark<uint8>(state, outer_dim, middle_dim,
|
||||
1 /* intra_threads */, 3 /* channels */);
|
||||
}
|
||||
|
||||
BENCHMARK(BM_ReverseRowsOf3Channels_1T_uint8)
|
||||
->UseRealTime()
|
||||
->ArgPair(288, 288)
|
||||
->ArgPair(30, 30)
|
||||
->ArgPair(1024, 1024)
|
||||
->ArgPair(10 * 1024, 1024);
|
||||
|
||||
static void BM_ReverseRowsOf3Channels_4T_float(int iters, int outer_dim,
|
||||
int middle_dim) {
|
||||
RunReverseRowsBenchmark<float>(iters, outer_dim, middle_dim,
|
||||
void BM_ReverseRowsOf3Channels_4T_float(::testing::benchmark::State& state) {
|
||||
const int outer_dim = state.range(0);
|
||||
const int middle_dim = state.range(1);
|
||||
|
||||
RunReverseRowsBenchmark<float>(state, outer_dim, middle_dim,
|
||||
4 /* intra_threads */, 3 /* channels */);
|
||||
}
|
||||
|
||||
BENCHMARK(BM_ReverseRowsOf3Channels_4T_float)
|
||||
->UseRealTime()
|
||||
->ArgPair(288, 288)
|
||||
->ArgPair(30, 30)
|
||||
->ArgPair(1024, 1024)
|
||||
->ArgPair(10 * 1024, 1024);
|
||||
|
||||
static void BM_ReverseRowsOf3Channels_4T_uint8(int iters, int outer_dim,
|
||||
int middle_dim) {
|
||||
RunReverseRowsBenchmark<uint8>(iters, outer_dim, middle_dim,
|
||||
void BM_ReverseRowsOf3Channels_4T_uint8(::testing::benchmark::State& state) {
|
||||
const int outer_dim = state.range(0);
|
||||
const int middle_dim = state.range(1);
|
||||
|
||||
RunReverseRowsBenchmark<uint8>(state, outer_dim, middle_dim,
|
||||
4 /* intra_threads */, 3 /* channels */);
|
||||
}
|
||||
BENCHMARK(BM_ReverseRowsOf3Channels_4T_uint8)
|
||||
->UseRealTime()
|
||||
->ArgPair(288, 288)
|
||||
->ArgPair(30, 30)
|
||||
->ArgPair(1024, 1024)
|
||||
->ArgPair(10 * 1024, 1024);
|
||||
|
||||
static void BM_ReverseRowsOf4Channels_1T_float(int iters, int outer_dim,
|
||||
int middle_dim) {
|
||||
RunReverseRowsBenchmark<float>(iters, outer_dim, middle_dim,
|
||||
void BM_ReverseRowsOf4Channels_1T_float(::testing::benchmark::State& state) {
|
||||
const int outer_dim = state.range(0);
|
||||
const int middle_dim = state.range(1);
|
||||
|
||||
RunReverseRowsBenchmark<float>(state, outer_dim, middle_dim,
|
||||
1 /* intra_threads */, 4 /* channels */);
|
||||
}
|
||||
|
||||
BENCHMARK(BM_ReverseRowsOf4Channels_1T_float)
|
||||
->UseRealTime()
|
||||
->ArgPair(288, 288)
|
||||
->ArgPair(1024, 1024)
|
||||
->ArgPair(10 * 1024, 1024);
|
||||
|
||||
static void BM_ReverseRowsOf4Channels_1T_uint8(int iters, int outer_dim,
|
||||
int middle_dim) {
|
||||
RunReverseRowsBenchmark<uint8>(iters, outer_dim, middle_dim,
|
||||
void BM_ReverseRowsOf4Channels_1T_uint8(::testing::benchmark::State& state) {
|
||||
const int outer_dim = state.range(0);
|
||||
const int middle_dim = state.range(1);
|
||||
|
||||
RunReverseRowsBenchmark<uint8>(state, outer_dim, middle_dim,
|
||||
1 /* intra_threads */, 4 /* channels */);
|
||||
}
|
||||
|
||||
BENCHMARK(BM_ReverseRowsOf4Channels_1T_uint8)
|
||||
->UseRealTime()
|
||||
->ArgPair(288, 288)
|
||||
->ArgPair(1024, 1024)
|
||||
->ArgPair(10 * 1024, 1024);
|
||||
|
||||
static void BM_ReverseRowsOf4Channels_4T_float(int iters, int outer_dim,
|
||||
int middle_dim) {
|
||||
RunReverseRowsBenchmark<float>(iters, outer_dim, middle_dim,
|
||||
void BM_ReverseRowsOf4Channels_4T_float(::testing::benchmark::State& state) {
|
||||
const int outer_dim = state.range(0);
|
||||
const int middle_dim = state.range(1);
|
||||
|
||||
RunReverseRowsBenchmark<float>(state, outer_dim, middle_dim,
|
||||
4 /* intra_threads */, 4 /* channels */);
|
||||
}
|
||||
|
||||
BENCHMARK(BM_ReverseRowsOf4Channels_4T_float)
|
||||
->UseRealTime()
|
||||
->ArgPair(288, 288)
|
||||
->ArgPair(1024, 1024)
|
||||
->ArgPair(10 * 1024, 1024);
|
||||
|
||||
static void BM_ReverseRowsOf4Channels_4T_uint8(int iters, int outer_dim,
|
||||
int middle_dim) {
|
||||
RunReverseRowsBenchmark<uint8>(iters, outer_dim, middle_dim,
|
||||
void BM_ReverseRowsOf4Channels_4T_uint8(::testing::benchmark::State& state) {
|
||||
const int outer_dim = state.range(0);
|
||||
const int middle_dim = state.range(1);
|
||||
|
||||
RunReverseRowsBenchmark<uint8>(state, outer_dim, middle_dim,
|
||||
4 /* intra_threads */, 4 /* channels */);
|
||||
}
|
||||
|
||||
BENCHMARK(BM_ReverseRowsOf4Channels_4T_uint8)
|
||||
->UseRealTime()
|
||||
->ArgPair(288, 288)
|
||||
->ArgPair(1024, 1024)
|
||||
->ArgPair(10 * 1024, 1024);
|
||||
|
@ -451,30 +451,40 @@ static Graph* RollGraph(const TensorShape& shape, int isd) {
|
||||
}
|
||||
|
||||
#define BM_ROLL_OUTER(DEVICE) \
|
||||
static void BM_##DEVICE##_roll_outer(int iters, int rows, int columns) { \
|
||||
static void BM_##DEVICE##_roll_outer(::testing::benchmark::State& state) { \
|
||||
const int rows = state.range(0); \
|
||||
const int columns = state.range(1); \
|
||||
\
|
||||
TensorShape shape{rows, columns}; \
|
||||
const int64 num_items = static_cast<int64>(iters) * shape.num_elements(); \
|
||||
testing::ItemsProcessed(num_items); \
|
||||
testing::BytesProcessed(num_items * sizeof(float)); \
|
||||
testing::UseRealTime(); \
|
||||
test::Benchmark(#DEVICE, RollGraph(shape, 0)).Run(iters); \
|
||||
test::Benchmark(#DEVICE, RollGraph(shape, 0), /*old_benchmark_api*/ false) \
|
||||
.Run(state); \
|
||||
const int64 num_items = \
|
||||
static_cast<int64>(state.iterations()) * shape.num_elements(); \
|
||||
state.SetItemsProcessed(num_items); \
|
||||
state.SetBytesProcessed(num_items * sizeof(float)); \
|
||||
} \
|
||||
BENCHMARK(BM_##DEVICE##_roll_outer) \
|
||||
->UseRealTime() \
|
||||
->ArgPair(256, 256) \
|
||||
->ArgPair(512, 512) \
|
||||
->ArgPair(1024, 1024) \
|
||||
->ArgPair(2048, 2048)
|
||||
|
||||
#define BM_ROLL_ALL(DEVICE) \
|
||||
static void BM_##DEVICE##_roll_all(int iters, int rows, int columns) { \
|
||||
static void BM_##DEVICE##_roll_all(::testing::benchmark::State& state) { \
|
||||
const int rows = state.range(0); \
|
||||
const int columns = state.range(1); \
|
||||
\
|
||||
TensorShape shape{rows, columns}; \
|
||||
const int64 num_items = static_cast<int64>(iters) * shape.num_elements(); \
|
||||
testing::ItemsProcessed(num_items); \
|
||||
testing::BytesProcessed(num_items * sizeof(float)); \
|
||||
testing::UseRealTime(); \
|
||||
test::Benchmark(#DEVICE, RollGraph(shape, 1)).Run(iters); \
|
||||
test::Benchmark(#DEVICE, RollGraph(shape, 1), /*old_benchmark_api*/ false) \
|
||||
.Run(state); \
|
||||
const int64 num_items = \
|
||||
static_cast<int64>(state.iterations()) * shape.num_elements(); \
|
||||
state.SetItemsProcessed(num_items); \
|
||||
state.SetBytesProcessed(num_items * sizeof(float)); \
|
||||
} \
|
||||
BENCHMARK(BM_##DEVICE##_roll_all) \
|
||||
->UseRealTime() \
|
||||
->ArgPair(256, 256) \
|
||||
->ArgPair(512, 512) \
|
||||
->ArgPair(1024, 1024) \
|
||||
|
@ -663,8 +663,8 @@ TEST_F(SaveOpSlices2Test, TwoSlices) {
|
||||
|
||||
// Benchmark-related code below.
|
||||
|
||||
static void BM_LargeTensorWrite(int iters, int num_elements) {
|
||||
testing::StopTiming();
|
||||
void BM_LargeTensorWrite(::testing::benchmark::State& state) {
|
||||
const int num_elements = state.range(0);
|
||||
|
||||
// 4 * num_elements bytes total, since sizeof(float) == 4.
|
||||
Tensor tensor(DT_FLOAT, TensorShape({num_elements}));
|
||||
@ -689,8 +689,9 @@ static void BM_LargeTensorWrite(int iters, int num_elements) {
|
||||
VLOG(1) << "Save op's output path: " << temp_filename;
|
||||
VLOG(1) << "# nodes in Graph: " << g->num_nodes();
|
||||
|
||||
testing::StartTiming();
|
||||
test::Benchmark("cpu", g, &session_options).Run(iters);
|
||||
test::Benchmark("cpu", g, &session_options, nullptr, nullptr, "",
|
||||
/*old_benchmark_api*/ false)
|
||||
.Run(state);
|
||||
}
|
||||
BENCHMARK(BM_LargeTensorWrite)->Arg((1 << 30) / 4 /* 1GB float tensor */);
|
||||
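Note: when a benchmark needs SessionOptions, as BM_LargeTensorWrite above does, the migrated call threads them through the longer test::Benchmark constructor while still passing `false` for the old-API flag. Sketch with an assumed `MakeGraph()` builder:

static void BM_WithSessionOptions(::testing::benchmark::State& state) {
  SessionOptions opts;
  opts.config.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_opt_level(OptimizerOptions::L0);
  test::Benchmark("cpu", MakeGraph(), &opts, nullptr, nullptr, "",
                  /*old_benchmark_api=*/false)
      .Run(state);
}
BENCHMARK(BM_WithSessionOptions);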
|
||||
|
@ -67,79 +67,120 @@ static Graph* ThreeDYCumsum(int num_y, int num_z, bool reverse = false) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void LargeOneDimensional(int iters, const string& device, int num_x,
|
||||
static void LargeOneDimensional(::testing::benchmark::State& state,
|
||||
const string& device, int num_x,
|
||||
bool reverse = false) {
|
||||
testing::ItemsProcessed(static_cast<int64>(iters) * num_x);
|
||||
testing::BytesProcessed(static_cast<int64>(iters) * num_x * sizeof(T));
|
||||
test::Benchmark(device, LargeOneDCumsum<T>(num_x, reverse)).Run(iters);
|
||||
test::Benchmark(device, LargeOneDCumsum<T>(num_x, reverse),
|
||||
/*old_benchmark_api*/ false)
|
||||
.Run(state);
|
||||
state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_x);
|
||||
state.SetBytesProcessed(static_cast<int64>(state.iterations()) * num_x *
|
||||
sizeof(T));
|
||||
}
|
||||
|
||||
static void DoRowCumsum(int iters, const string& device, int num_x, int num_y,
|
||||
static void DoRowCumsum(::testing::benchmark::State& state,
|
||||
const string& device, int num_x, int num_y,
|
||||
bool reverse = false) {
|
||||
testing::ItemsProcessed(static_cast<int64>(iters) * num_x * num_y);
|
||||
  testing::BytesProcessed(static_cast<int64>(iters) * num_x * num_y *
                          sizeof(float));
  test::Benchmark(device, RowCumsum(num_x, num_y, reverse)).Run(iters);
  test::Benchmark(device, RowCumsum(num_x, num_y, reverse),
                  /*old_benchmark_api*/ false)
      .Run(state);
  state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_x *
                          num_y);
  state.SetBytesProcessed(static_cast<int64>(state.iterations()) * num_x *
                          num_y * sizeof(float));
}

static void DoColCumsum(int iters, const string& device, int num_x, int num_y,
static void DoColCumsum(::testing::benchmark::State& state,
                        const string& device, int num_x, int num_y,
                        bool reverse = false) {
  testing::ItemsProcessed(static_cast<int64>(iters) * num_x * num_y);
  testing::BytesProcessed(static_cast<int64>(iters) * num_x * num_y *
                          sizeof(float));
  test::Benchmark(device, ColCumsum(num_x, num_y, reverse)).Run(iters);
  test::Benchmark(device, ColCumsum(num_x, num_y, reverse),
                  /*old_benchmark_api*/ false)
      .Run(state);
  state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_x *
                          num_y);
  state.SetBytesProcessed(static_cast<int64>(state.iterations()) * num_x *
                          num_y * sizeof(float));
}

static void Do3DYCumsum(int iters, const string& device, int num_x, int num_y,
static void Do3DYCumsum(::testing::benchmark::State& state,
                        const string& device, int num_x, int num_y,
                        bool reverse = false) {
  testing::ItemsProcessed(static_cast<int64>(iters) * num_x * num_y);
  testing::BytesProcessed(static_cast<int64>(iters) * num_x * num_y *
                          sizeof(float));
  test::Benchmark(device, ThreeDYCumsum(num_x, num_y, reverse)).Run(iters);
  test::Benchmark(device, ThreeDYCumsum(num_x, num_y, reverse),
                  /*old_benchmark_api*/ false)
      .Run(state);
  state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_x *
                          num_y);
  state.SetBytesProcessed(static_cast<int64>(state.iterations()) * num_x *
                          num_y * sizeof(float));
}

static void BM_OneDCumsumGPU(int iters, int num_x) {
  LargeOneDimensional<float>(iters, "gpu", num_x);
static void BM_OneDCumsumGPU(::testing::benchmark::State& state) {
  const int num_x = state.range(0);

  LargeOneDimensional<float>(state, "gpu", num_x);
}
BENCHMARK(BM_OneDCumsumGPU)->Range(1, 1 << 21);

static void BM_OneDCumsumGPUHalf(int iters, int num_x) {
  LargeOneDimensional<Eigen::half>(iters, "gpu", num_x);
static void BM_OneDCumsumGPUHalf(::testing::benchmark::State& state) {
  const int num_x = state.range(0);

  LargeOneDimensional<Eigen::half>(state, "gpu", num_x);
}
BENCHMARK(BM_OneDCumsumGPUHalf)->Range(1, 1 << 21);

static void BM_Sum2DRowCumsumGPU(int iters, int num_x, int num_y) {
  DoRowCumsum(iters, "gpu", num_x, num_y);
static void BM_Sum2DRowCumsumGPU(::testing::benchmark::State& state) {
  const int num_x = state.range(0);
  const int num_y = state.range(1);

  DoRowCumsum(state, "gpu", num_x, num_y);
}
BENCHMARK(BM_Sum2DRowCumsumGPU)->RangePair(1, 8192, 1, 8192);

static void BM_Sum2DColumnCumsumGPU(int iters, int num_x, int num_y) {
  DoColCumsum(iters, "gpu", num_x, num_y);
static void BM_Sum2DColumnCumsumGPU(::testing::benchmark::State& state) {
  const int num_x = state.range(0);
  const int num_y = state.range(1);

  DoColCumsum(state, "gpu", num_x, num_y);
}
BENCHMARK(BM_Sum2DColumnCumsumGPU)->RangePair(1, 8192, 1, 8192);

static void BM_Sum3DYCumsumGPU(int iters, int num_x, int num_y) {
  Do3DYCumsum(iters, "gpu", num_x, num_y);
static void BM_Sum3DYCumsumGPU(::testing::benchmark::State& state) {
  const int num_x = state.range(0);
  const int num_y = state.range(1);

  Do3DYCumsum(state, "gpu", num_x, num_y);
}
BENCHMARK(BM_Sum3DYCumsumGPU)->RangePair(64, 4096, 64, 4096);

static void BM_OneDCumsumGPU_reverse(int iters, int num_x) {
  LargeOneDimensional<float>(iters, "gpu", num_x, true);
static void BM_OneDCumsumGPU_reverse(::testing::benchmark::State& state) {
  const int num_x = state.range(0);

  LargeOneDimensional<float>(state, "gpu", num_x, true);
}
BENCHMARK(BM_OneDCumsumGPU_reverse)->Range(1, 1 << 21);

static void BM_Sum2DRowCumsumGPU_reverse(int iters, int num_x, int num_y) {
  DoRowCumsum(iters, "gpu", num_x, num_y, true);
static void BM_Sum2DRowCumsumGPU_reverse(::testing::benchmark::State& state) {
  const int num_x = state.range(0);
  const int num_y = state.range(1);

  DoRowCumsum(state, "gpu", num_x, num_y, true);
}
BENCHMARK(BM_Sum2DRowCumsumGPU_reverse)->RangePair(1, 8192, 1, 8192);

static void BM_Sum2DColumnCumsumGPU_reverse(int iters, int num_x, int num_y) {
  DoColCumsum(iters, "gpu", num_x, num_y, true);
static void BM_Sum2DColumnCumsumGPU_reverse(
    ::testing::benchmark::State& state) {
  const int num_x = state.range(0);
  const int num_y = state.range(1);

  DoColCumsum(state, "gpu", num_x, num_y, true);
}
BENCHMARK(BM_Sum2DColumnCumsumGPU_reverse)->RangePair(1, 8192, 1, 8192);

static void BM_Sum3DYCumsumGPU_reverse(int iters, int num_x, int num_y) {
  Do3DYCumsum(iters, "gpu", num_x, num_y, true);
static void BM_Sum3DYCumsumGPU_reverse(::testing::benchmark::State& state) {
  const int num_x = state.range(0);
  const int num_y = state.range(1);

  Do3DYCumsum(state, "gpu", num_x, num_y, true);
}
BENCHMARK(BM_Sum3DYCumsumGPU_reverse)->RangePair(32, 2048, 32, 2048);

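Every benchmark in the hunk above follows the same conversion: the `int iters`/`int num_x` parameters are replaced by a single `::testing::benchmark::State& state`, the dimensions are read back via `state.range(0)`/`state.range(1)` from the existing `Range`/`RangePair` registrations, the graph is driven with `test::Benchmark(..., /*old_benchmark_api*/ false).Run(state)`, and throughput is reported through `state.SetItemsProcessed`/`state.SetBytesProcessed` using `state.iterations()`. A minimal, self-contained sketch of that shape against plain Google Benchmark (the `BM_VectorSum` function and its body are illustrative only, not part of this change):

#include <cstdint>
#include <numeric>
#include <vector>

#include "benchmark/benchmark.h"

// Illustrative stand-in for a ranged, State-based benchmark.
static void BM_VectorSum(benchmark::State& state) {
  const int64_t num_x = state.range(0);  // supplied by ->Range(...) below
  std::vector<float> data(num_x, 1.0f);  // setup before the loop is untimed
  for (auto s : state) {                 // each pass is one measured iteration
    float sum = std::accumulate(data.begin(), data.end(), 0.0f);
    benchmark::DoNotOptimize(sum);
  }
  state.SetItemsProcessed(static_cast<int64_t>(state.iterations()) * num_x);
  state.SetBytesProcessed(static_cast<int64_t>(state.iterations()) * num_x *
                          sizeof(float));
}
BENCHMARK(BM_VectorSum)->Range(1, 1 << 21);
BENCHMARK_MAIN();
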
@ -254,8 +254,8 @@ class ScatterNdUpdateBM : public ScatterNdUpdateOpTest {
};

template <typename Index>
static void BM_ScatterNdHelper(int iters, int embedding_size, const char* op) {
  testing::StopTiming();
void BM_ScatterNdHelper(::testing::benchmark::State& state, int embedding_size,
                        const char* op) {
  const int kRows = 10000000 / embedding_size;
  std::vector<float> values;
  values.reserve(kRows);
@ -280,27 +280,33 @@ static void BM_ScatterNdHelper(int iters, int embedding_size, const char* op) {
  bm.AddInputFromArray<Index>(TensorShape({kNumUpdates}), indices);
  bm.AddInputFromArray<float>(TensorShape({kNumUpdates, embedding_size}),
                              updates);
  testing::ItemsProcessed((static_cast<int64>(kNumUpdates) * embedding_size) *
                          iters);
  testing::StartTiming();
  while (iters-- > 0) {
  for (auto i : state) {
    Status s = bm.RunOpKernel();
  }
  testing::StopTiming();
  state.SetItemsProcessed((static_cast<int64>(kNumUpdates) * embedding_size) *
                          state.iterations());
}

static void BM_ScatterNdUpdateInt32(int iters, int embedding_size) {
  BM_ScatterNdHelper<int32>(iters, embedding_size, "ScatterNdUpdate");
void BM_ScatterNdUpdateInt32(::testing::benchmark::State& state) {
  const int embedding_size = state.range(0);

  BM_ScatterNdHelper<int32>(state, embedding_size, "ScatterNdUpdate");
}
static void BM_ScatterNdUpdateInt64(int iters, int embedding_size) {
  BM_ScatterNdHelper<int64>(iters, embedding_size, "ScatterNdUpdate");
void BM_ScatterNdUpdateInt64(::testing::benchmark::State& state) {
  const int embedding_size = state.range(0);

  BM_ScatterNdHelper<int64>(state, embedding_size, "ScatterNdUpdate");
}

static void BM_ScatterNdAddInt32(int iters, int embedding_size) {
  BM_ScatterNdHelper<int32>(iters, embedding_size, "ScatterNdAdd");
void BM_ScatterNdAddInt32(::testing::benchmark::State& state) {
  const int embedding_size = state.range(0);

  BM_ScatterNdHelper<int32>(state, embedding_size, "ScatterNdAdd");
}
static void BM_ScatterNdAddInt64(int iters, int embedding_size) {
  BM_ScatterNdHelper<int64>(iters, embedding_size, "ScatterNdAdd");
void BM_ScatterNdAddInt64(::testing::benchmark::State& state) {
  const int embedding_size = state.range(0);

  BM_ScatterNdHelper<int64>(state, embedding_size, "ScatterNdAdd");
}

BENCHMARK(BM_ScatterNdUpdateInt32)

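The helper rewrite above also shows the timing-control change: the old `testing::StopTiming()` / `testing::StartTiming()` bracketing and the manual `while (iters-- > 0)` loop collapse into a single `for (auto i : state)` loop, where anything before the loop is excluded from the measurement and item counts are derived from `state.iterations()` afterwards. A hedged sketch of the equivalent pattern in plain Google Benchmark, including the `state.PauseTiming()` / `state.ResumeTiming()` calls one would reach for if per-iteration setup had to stay untimed (the `BM_SortCopy` example is illustrative, not part of this diff):

#include <algorithm>
#include <cstdint>
#include <random>
#include <vector>

#include "benchmark/benchmark.h"

static void BM_SortCopy(benchmark::State& state) {
  const int64_t n = state.range(0);
  std::mt19937 gen(42);
  std::uniform_real_distribution<float> dist(0.0f, 1.0f);
  std::vector<float> input(n);
  for (auto& v : input) v = dist(gen);  // one-time setup, never timed

  for (auto i : state) {
    state.PauseTiming();                // exclude the per-iteration copy
    std::vector<float> work = input;
    state.ResumeTiming();
    std::sort(work.begin(), work.end());
    benchmark::DoNotOptimize(work.data());
  }
  state.SetItemsProcessed(static_cast<int64_t>(state.iterations()) * n);
}
BENCHMARK(BM_SortCopy)->Arg(1 << 10)->Arg(1 << 16);
BENCHMARK_MAIN();
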
@ -280,9 +280,8 @@ class ScatterUpdateBM : public ScatterUpdateOpTest {
};

template <typename Index>
static void BM_ScatterHelper(int iters, int embedding_size, const char* op,
                             bool big_num_updates = false) {
  testing::StopTiming();
void BM_ScatterHelper(::testing::benchmark::State& state, int embedding_size,
                      const char* op, bool big_num_updates = false) {
  const int kRows = 10000000 / embedding_size;
  std::vector<float> values;
  values.reserve(kRows);
@ -307,59 +306,83 @@ static void BM_ScatterHelper(int iters, int embedding_size, const char* op,
  bm.AddInputFromArray<Index>(TensorShape({kNumUpdates}), indices);
  bm.AddInputFromArray<float>(TensorShape({kNumUpdates, embedding_size}),
                              updates);
  testing::ItemsProcessed((static_cast<int64>(kNumUpdates) * embedding_size) *
                          iters);
  testing::StartTiming();
  while (iters-- > 0) {
  for (auto i : state) {
    Status s = bm.RunOpKernel();
  }
  testing::StopTiming();
  state.SetItemsProcessed((static_cast<int64>(kNumUpdates) * embedding_size) *
                          state.iterations());
}

static void BM_ScatterUpdateInt32(int iters, int embedding_size) {
  BM_ScatterHelper<int32>(iters, embedding_size, "ScatterUpdate");
void BM_ScatterUpdateInt32(::testing::benchmark::State& state) {
  const int embedding_size = state.range(0);

  BM_ScatterHelper<int32>(state, embedding_size, "ScatterUpdate");
}
static void BM_ScatterUpdateInt64(int iters, int embedding_size) {
  BM_ScatterHelper<int64>(iters, embedding_size, "ScatterUpdate");
void BM_ScatterUpdateInt64(::testing::benchmark::State& state) {
  const int embedding_size = state.range(0);

  BM_ScatterHelper<int64>(state, embedding_size, "ScatterUpdate");
}

static void BM_ScatterAddInt32(int iters, int embedding_size) {
  BM_ScatterHelper<int32>(iters, embedding_size, "ScatterAdd");
void BM_ScatterAddInt32(::testing::benchmark::State& state) {
  const int embedding_size = state.range(0);

  BM_ScatterHelper<int32>(state, embedding_size, "ScatterAdd");
}

static void BM_ScatterAddInt32Large(int iters, int embedding_size) {
  BM_ScatterHelper<int32>(iters, embedding_size, "ScatterAdd", true);
void BM_ScatterAddInt32Large(::testing::benchmark::State& state) {
  const int embedding_size = state.range(0);

  BM_ScatterHelper<int32>(state, embedding_size, "ScatterAdd", true);
}
static void BM_ScatterAddInt64(int iters, int embedding_size) {
  BM_ScatterHelper<int64>(iters, embedding_size, "ScatterAdd");
void BM_ScatterAddInt64(::testing::benchmark::State& state) {
  const int embedding_size = state.range(0);

  BM_ScatterHelper<int64>(state, embedding_size, "ScatterAdd");
}

static void BM_ScatterMulInt32(int iters, int embedding_size) {
  BM_ScatterHelper<int32>(iters, embedding_size, "ScatterMul");
void BM_ScatterMulInt32(::testing::benchmark::State& state) {
  const int embedding_size = state.range(0);

  BM_ScatterHelper<int32>(state, embedding_size, "ScatterMul");
}
static void BM_ScatterMulInt64(int iters, int embedding_size) {
  BM_ScatterHelper<int64>(iters, embedding_size, "ScatterMul");
void BM_ScatterMulInt64(::testing::benchmark::State& state) {
  const int embedding_size = state.range(0);

  BM_ScatterHelper<int64>(state, embedding_size, "ScatterMul");
}

static void BM_ScatterDivInt32(int iters, int embedding_size) {
  BM_ScatterHelper<int32>(iters, embedding_size, "ScatterDiv");
void BM_ScatterDivInt32(::testing::benchmark::State& state) {
  const int embedding_size = state.range(0);

  BM_ScatterHelper<int32>(state, embedding_size, "ScatterDiv");
}
static void BM_ScatterDivInt64(int iters, int embedding_size) {
  BM_ScatterHelper<int64>(iters, embedding_size, "ScatterDiv");
void BM_ScatterDivInt64(::testing::benchmark::State& state) {
  const int embedding_size = state.range(0);

  BM_ScatterHelper<int64>(state, embedding_size, "ScatterDiv");
}

static void BM_ScatterMinInt32(int iters, int embedding_size) {
  BM_ScatterHelper<int32>(iters, embedding_size, "ScatterMin");
void BM_ScatterMinInt32(::testing::benchmark::State& state) {
  const int embedding_size = state.range(0);

  BM_ScatterHelper<int32>(state, embedding_size, "ScatterMin");
}
static void BM_ScatterMinInt64(int iters, int embedding_size) {
  BM_ScatterHelper<int64>(iters, embedding_size, "ScatterMin");
void BM_ScatterMinInt64(::testing::benchmark::State& state) {
  const int embedding_size = state.range(0);

  BM_ScatterHelper<int64>(state, embedding_size, "ScatterMin");
}

static void BM_ScatterMaxInt32(int iters, int embedding_size) {
  BM_ScatterHelper<int32>(iters, embedding_size, "ScatterMax");
void BM_ScatterMaxInt32(::testing::benchmark::State& state) {
  const int embedding_size = state.range(0);

  BM_ScatterHelper<int32>(state, embedding_size, "ScatterMax");
}
static void BM_ScatterMaxInt64(int iters, int embedding_size) {
  BM_ScatterHelper<int64>(iters, embedding_size, "ScatterMax");
void BM_ScatterMaxInt64(::testing::benchmark::State& state) {
  const int embedding_size = state.range(0);

  BM_ScatterHelper<int64>(state, embedding_size, "ScatterMax");
}

BENCHMARK(BM_ScatterUpdateInt32)

Some files were not shown because too many files have changed in this diff.