merge from master

Dmitry Volodin 2020-10-30 15:04:50 +03:00
commit 2381ee56d9
222 changed files with 6021 additions and 3991 deletions
RELEASE.md
tensorflow
c/eager
compiler
core

View File

@@ -42,6 +42,10 @@
* Removed deprecated `Interpreter::UseNNAPI(bool)` C++ API.
* Use `NnApiDelegate()` and related delegate configuration methods
directly.
* TF Core:
* Corrected higher-order gradients of control flow constructs (`tf.cond`,
`tf.while_loop`, and compositions like `tf.foldl`) computed with
`tf.GradientTape` inside a `tf.function`.
## Thanks to our Contributors

View File

@@ -769,7 +769,7 @@ void Execute_MatMul_CPU_Runtime_Error(bool async) {
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
EXPECT_NE(TF_OK, TF_GetCode(status));
EXPECT_EQ(nullptr, t);
const char* msg = "Matrix size-incompatible: In[0]: [2,2], In[1]: [3,2]";
const char* msg = "In[0] mismatch In[1] shape: 2 vs. 3: [2,2] [3,2]";
EXPECT_TRUE(strstr(TF_Message(status), msg) != nullptr)
<< TF_Message(status);
// Since error is not cleared, the following copy with correct device will

View File

@@ -583,7 +583,11 @@ XlaComputationLaunchContext::BuildXlaCompilerArguments(
XlaCompiler::Argument& arg = out[input_num];
if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
// Handles compile-time constants.
TF_RET_CHECK(input->dtype() != DT_RESOURCE);
// TODO(b/157241314): Support constants located in resource variables.
TF_RET_CHECK(input->dtype() != DT_RESOURCE)
<< "tf2xla bridge does not support must-be-constants located in "
"resource variables; try moving them to a tensor";
arg.kind = XlaCompiler::Argument::kConstant;
arg.type = input->dtype();
arg.shape = input->shape();

View File

@@ -517,6 +517,15 @@ cc_library(
],
)
cc_library(
name = "map_chlo_to_hlo_op",
hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"],
deps = [
":hlo",
"@llvm-project//mlir:IR",
],
)
cc_library(
name = "map_hlo_to_lhlo_op",
hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"],
@@ -606,9 +615,11 @@ cc_library(
],
deps = [
":hlo",
":map_chlo_to_hlo_op",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Shape",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Transforms",
@@ -893,6 +904,7 @@ cc_library(
deps = [
":chlo_legalize_to_hlo_inc_gen",
":hlo",
":map_chlo_to_hlo_op",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Shape",

View File

@@ -0,0 +1,97 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_CHLO_TO_HLO_OP_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_CHLO_TO_HLO_OP_H_
#include <type_traits>
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/IR/PatternMatch.h"
namespace mlir {
namespace chlo {
struct HloComplexAdaptor {
static mhlo::ComplexOp CreateOp(BroadcastComplexOp from_op, Type result_type,
Value broadcasted_lhs, Value broadcasted_rhs,
OpBuilder &builder) {
return builder.create<mhlo::ComplexOp>(from_op.getLoc(), result_type,
broadcasted_lhs, broadcasted_rhs);
}
};
template <typename FromOpTy, typename ToOpTy>
struct HloBinaryElementwiseAdaptor {
static ToOpTy CreateOp(FromOpTy from_op, Type result_type,
Value broadcasted_lhs, Value broadcasted_rhs,
OpBuilder &builder) {
return builder.create<ToOpTy>(from_op.getLoc(), result_type,
broadcasted_lhs, broadcasted_rhs);
}
};
struct HloCompareAdaptor {
static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type,
Value broadcasted_lhs, Value broadcasted_rhs,
OpBuilder &builder) {
return builder.create<mhlo::CompareOp>(
from_op.getLoc(), result_type, broadcasted_lhs, broadcasted_rhs,
from_op.comparison_direction(), from_op.compare_typeAttr());
}
};
// Populates a pattern for each broadcasting CHLO op. This requires the pattern
// to take a ChloOpTy, MhloOpTy, and an Adaptor as template parameters.
template <template <typename, typename, typename> class Pattern,
typename... ConstructorArgs>
void PopulateForBroadcastingBinaryOp(MLIRContext *context,
OwningRewritePatternList *patterns,
ConstructorArgs &&...args) {
#define POPULATE_BCAST(ChloOp, HloOp) \
patterns->insert< \
Pattern<ChloOp, HloOp, HloBinaryElementwiseAdaptor<ChloOp, HloOp>>>( \
context, args...);
POPULATE_BCAST(BroadcastAddOp, mhlo::AddOp);
POPULATE_BCAST(BroadcastAndOp, mhlo::AndOp);
POPULATE_BCAST(BroadcastAtan2Op, mhlo::Atan2Op);
POPULATE_BCAST(BroadcastDivOp, mhlo::DivOp);
POPULATE_BCAST(BroadcastMaxOp, mhlo::MaxOp);
POPULATE_BCAST(BroadcastMinOp, mhlo::MinOp);
POPULATE_BCAST(BroadcastMulOp, mhlo::MulOp);
POPULATE_BCAST(BroadcastOrOp, mhlo::OrOp);
POPULATE_BCAST(BroadcastPowOp, mhlo::PowOp);
POPULATE_BCAST(BroadcastRemOp, mhlo::RemOp);
POPULATE_BCAST(BroadcastShiftLeftOp, mhlo::ShiftLeftOp);
POPULATE_BCAST(BroadcastShiftRightArithmeticOp, mhlo::ShiftRightArithmeticOp);
POPULATE_BCAST(BroadcastShiftRightLogicalOp, mhlo::ShiftRightLogicalOp);
POPULATE_BCAST(BroadcastSubOp, mhlo::SubOp);
POPULATE_BCAST(BroadcastXorOp, mhlo::XorOp);
// Broadcasting ops requiring special construction.
patterns
->insert<Pattern<BroadcastComplexOp, mhlo::ComplexOp, HloComplexAdaptor>>(
context, args...);
patterns
->insert<Pattern<BroadcastCompareOp, mhlo::CompareOp, HloCompareAdaptor>>(
context, args...);
#undef POPULATE_BCAST
}
} // namespace chlo
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_CHLO_TO_HLO_OP_H_
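
A minimal usage sketch of this new helper (not part of the commit): MyBroadcastingPattern and PopulateMyPatterns are hypothetical names, and the trailing constructor argument is assumed to be forwarded as the pattern benefit, mirroring how the legalization passes further down in this diff call PopulateForBroadcastingBinaryOp.

#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
#include "mlir/Transforms/DialectConversion.h"

// Hypothetical pattern matching the <ChloOpTy, HloOpTy, Adaptor> template
// shape that PopulateForBroadcastingBinaryOp expects.
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct MyBroadcastingPattern : public mlir::OpConversionPattern<ChloOpTy> {
  using mlir::OpConversionPattern<ChloOpTy>::OpConversionPattern;
  mlir::LogicalResult matchAndRewrite(
      ChloOpTy op, llvm::ArrayRef<mlir::Value> operands,
      mlir::ConversionPatternRewriter &rewriter) const override {
    // Delegate construction of the target mhlo op to the adaptor, using the
    // already-converted operands.
    rewriter.replaceOp(op, {Adaptor::CreateOp(op, op.getResult().getType(),
                                              operands[0], operands[1],
                                              rewriter)});
    return mlir::success();
  }
};

// Hypothetical registration helper: instantiates MyBroadcastingPattern once
// per broadcasting CHLO/MHLO op pair; 10 is forwarded to each pattern
// constructor (assumed to act as the benefit).
void PopulateMyPatterns(mlir::MLIRContext *context,
                        mlir::OwningRewritePatternList *patterns) {
  mlir::chlo::PopulateForBroadcastingBinaryOp<MyBroadcastingPattern>(
      context, patterns, 10);
}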

View File

@@ -17,6 +17,7 @@ limitations under the License.
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir-hlo/utils/broadcast_utils.h"
#include "mlir/Dialect/SCF/SCF.h"
@@ -69,13 +70,18 @@ struct ConvertConstantLikeOp : public OpConversionPattern<ConstantLikeOp> {
// Converts binary ops that statically are determined to not broadcast directly
// to the corresponding mhlo non-broadcasting op.
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> {
using OpRewritePattern<ChloOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(ChloOpTy op,
PatternRewriter &rewriter) const override {
struct ConvertTrivialNonBroadcastBinaryOp
: public OpConversionPattern<ChloOpTy> {
using OpConversionPattern<ChloOpTy>::OpConversionPattern;
LogicalResult matchAndRewrite(
ChloOpTy op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
// Only rewrite for statically determinable non-broadcasting cases.
auto lhs_type = op.lhs().getType().template dyn_cast<RankedTensorType>();
auto rhs_type = op.rhs().getType().template dyn_cast<RankedTensorType>();
typename ChloOpTy::Adaptor transformed(operands);
auto lhs_type =
transformed.lhs().getType().template dyn_cast<RankedTensorType>();
auto rhs_type =
transformed.rhs().getType().template dyn_cast<RankedTensorType>();
if (!lhs_type || !rhs_type) return failure();
// Requires rank broadcast.
@@ -93,8 +99,9 @@ struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> {
}
}
rewriter.replaceOp(op, {Adaptor::CreateOp(op, op.getResult().getType(),
op.lhs(), op.rhs(), rewriter)});
rewriter.replaceOp(
op, {Adaptor::CreateOp(op, op.getResult().getType(), operands[0],
operands[1], rewriter)});
return success();
}
};
@@ -113,13 +120,15 @@ struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> {
// `shape.broadcast` op, which only supports prefix-padding.
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertRankedDynamicBroadcastBinaryOp
: public OpRewritePattern<ChloOpTy> {
using OpRewritePattern<ChloOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(ChloOpTy op,
PatternRewriter &rewriter) const override {
: public OpConversionPattern<ChloOpTy> {
using OpConversionPattern<ChloOpTy>::OpConversionPattern;
LogicalResult matchAndRewrite(
ChloOpTy op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
// Only support ranked operands.
Value lhs = op.lhs();
Value rhs = op.rhs();
typename ChloOpTy::Adaptor transformed(operands);
Value lhs = transformed.lhs();
Value rhs = transformed.rhs();
auto lhs_type = lhs.getType().dyn_cast<RankedTensorType>();
auto rhs_type = rhs.getType().dyn_cast<RankedTensorType>();
auto result_type =
@@ -193,324 +202,6 @@ struct ConvertRankedDynamicBroadcastBinaryOp
}
};
// Converts a broadcasting binary operation with a scalar operand and an
// unranked operand to a ranked broadcasting operation by dynamically reshaping
// the unranked operand to a 1D tensor. This will always be safe because
// broadcasting from a scalar to another shape always works.
template <typename ChloOpTy, typename HloOpTy>
struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
: public OpRewritePattern<ChloOpTy> {
using OpRewritePattern<ChloOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(ChloOpTy op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
Value lhs = op.lhs();
Value rhs = op.rhs();
auto lhs_ranked_type = lhs.getType().dyn_cast<RankedTensorType>();
auto lhs_unranked_type = lhs.getType().dyn_cast<UnrankedTensorType>();
auto rhs_ranked_type = rhs.getType().dyn_cast<RankedTensorType>();
auto rhs_unranked_type = rhs.getType().dyn_cast<UnrankedTensorType>();
bool lhs_is_scalar = lhs_ranked_type &&
lhs_ranked_type.getShape().empty() &&
rhs_unranked_type;
bool rhs_is_scalar = rhs_ranked_type &&
rhs_ranked_type.getShape().empty() &&
lhs_unranked_type;
// Only support the case where exactly one operand is scalar and the other
// is unranked. Other patterns in this file will create more efficient
// lowerings for cases where both ranks are known or will handle the more
// generic case of both inputs being unranked.
if (!(lhs_is_scalar ^ rhs_is_scalar)) return failure();
auto result_type = op.getResult().getType().template dyn_cast<TensorType>();
// Reshape the non-scalar value into a dynamically sized, rank-1 tensor
Value shape =
rewriter.create<shape::ShapeOfOp>(loc, lhs_is_scalar ? rhs : lhs);
Value num_elements = rewriter.create<shape::NumElementsOp>(loc, shape);
Value size_tensor =
rewriter.create<TensorFromElementsOp>(loc, num_elements);
Value reshaped = rewriter.create<mhlo::DynamicReshapeOp>(
loc, RankedTensorType::get({-1}, result_type.getElementType()),
lhs_is_scalar ? rhs : lhs, size_tensor);
// Create a new ranked Chlo op that will be further lowered by other
// patterns into Mhlo.
SmallVector<Value, 2> operands{lhs_is_scalar ? lhs : reshaped,
rhs_is_scalar ? rhs : reshaped};
Value computed = rewriter.create<ChloOpTy>(
loc, SmallVector<Type, 1>{reshaped.getType()}, operands, op.getAttrs());
// Reshape the result back into an unranked tensor.
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_type,
computed, shape);
return success();
}
};
// Handles lowering of the following pattern to patterns that will be further
// matched by other patterns until they result in LHLO:
// %result = "chlo.op"(%lhs, %rhs) : (<*xTy>, <*xTy>) -> <*xTy>
//
// The sequence of specializations this handles is:
// - Either operand being scalar
// - Operands having equal shapes
// - The resulting value being any of ranks [2,6]
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertUnrankedDynamicBroadcastBinaryOp
: public OpRewritePattern<ChloOpTy> {
using OpRewritePattern<ChloOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(ChloOpTy op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
Value lhs = op.lhs();
Value rhs = op.rhs();
auto lhs_type = lhs.getType().dyn_cast<UnrankedTensorType>();
auto rhs_type = rhs.getType().dyn_cast<UnrankedTensorType>();
auto result_type = op.getResult().getType().template dyn_cast<TensorType>();
// Only support unranked operands. If either operand is ranked, another
// pattern will handle the lowering.
if (!lhs_type || !rhs_type) return failure();
// If lhs is scalar
auto if_op = rewriter.create<scf::IfOp>(
loc, result_type, IsScalarTensor(rewriter, op, lhs), true);
OpBuilder if_lhs_scalar_builder = if_op.getThenBodyBuilder();
Value reshaped_lhs = if_lhs_scalar_builder.create<TensorCastOp>(
loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs);
Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>(
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs},
op.getAttrs());
if_lhs_scalar_builder.create<scf::YieldOp>(loc, if_lhs_scalar_result);
// If lhs is NOT scalar
//
// See if rhs is scalar
OpBuilder else_lhs_scalar_builder = if_op.getElseBodyBuilder();
auto if_rhs_scalar_op = else_lhs_scalar_builder.create<scf::IfOp>(
loc, result_type, IsScalarTensor(else_lhs_scalar_builder, op, rhs),
true);
else_lhs_scalar_builder.create<scf::YieldOp>(loc,
if_rhs_scalar_op.getResult(0));
OpBuilder if_rhs_scalar_builder = if_rhs_scalar_op.getThenBodyBuilder();
Value reshaped_rhs = if_rhs_scalar_builder.create<TensorCastOp>(
loc, RankedTensorType::get({}, lhs_type.getElementType()), rhs);
Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>(
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs},
op.getAttrs());
if_rhs_scalar_builder.create<scf::YieldOp>(loc, if_rhs_scalar_result);
// If NEITHER shape is scalar
//
// See if shapes are equal.
OpBuilder else_no_scalars_builder = if_rhs_scalar_op.getElseBodyBuilder();
Value shape_of_lhs =
else_no_scalars_builder.create<shape::ShapeOfOp>(loc, lhs);
Value shape_of_rhs =
else_no_scalars_builder.create<shape::ShapeOfOp>(loc, rhs);
Value equal_shapes = else_no_scalars_builder.create<shape::ShapeEqOp>(
loc, shape_of_lhs, shape_of_rhs);
auto if_eq_shapes_op = else_no_scalars_builder.create<scf::IfOp>(
loc, result_type, equal_shapes, true);
else_no_scalars_builder.create<scf::YieldOp>(loc,
if_eq_shapes_op.getResult(0));
OpBuilder if_eq_shapes_builder = if_eq_shapes_op.getThenBodyBuilder();
Value non_broadcast_op =
Adaptor::CreateOp(op, result_type, lhs, rhs, if_eq_shapes_builder);
if_eq_shapes_builder.create<scf::YieldOp>(loc, non_broadcast_op);
// If shapes are neither scalar nor equal
//
// See if values are of a rank that we support.
OpBuilder if_neq_shapes_builder = if_eq_shapes_op.getElseBodyBuilder();
if_neq_shapes_builder.create<scf::YieldOp>(
loc, HandleBroadcastAndOp(if_neq_shapes_builder, op, lhs, rhs));
rewriter.replaceOp(op, {if_op.getResult(0)});
return success();
}
private:
// Returns the dynamic result of checking whether the given value is a scalar
// tensor.
Value IsScalarTensor(OpBuilder &rewriter, ChloOpTy op, Value tensor) const {
auto loc = op.getLoc();
Value shape_of_tensor = rewriter.create<shape::ShapeOfOp>(loc, tensor);
Value rank_tensor = rewriter.create<shape::RankOp>(
loc, rewriter.getIndexType(), shape_of_tensor);
return rewriter.create<CmpIOp>(loc, rewriter.getI1Type(), CmpIPredicate::eq,
rank_tensor,
rewriter.create<ConstantIndexOp>(loc, 0));
}
// Create the if statement and code for a broadcasting op with a result of a
// given rank.
scf::IfOp createRankSpecializedBroadcastAndOp(OpBuilder &builder, ChloOpTy op,
Value lhs, Value rhs,
Value actual_rank,
int targeted_rank) const {
auto loc = op.getLoc();
// Create the if block to place the current specialized logic in.
Value greater_rank_is_n = builder.create<CmpIOp>(
loc, CmpIPredicate::eq, actual_rank,
builder.create<ConstantIndexOp>(loc, targeted_rank));
auto if_op =
builder.create<scf::IfOp>(loc, lhs.getType(), greater_rank_is_n, true);
OpBuilder if_builder = if_op.getThenBodyBuilder();
// Handle shape broadcasting and inference.
Value lhs_shape = if_builder.create<shape::ShapeOfOp>(loc, lhs);
Value rhs_shape = if_builder.create<shape::ShapeOfOp>(loc, rhs);
SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
auto unknown_rank_extent_tensor_type = RankedTensorType::get(
{RankedTensorType::kDynamicSize}, builder.getIndexType());
auto known_rank_extent_tensor_type =
RankedTensorType::get({targeted_rank}, builder.getIndexType());
auto reshaped_type = RankedTensorType::get(
llvm::SmallVector<int64_t, 6>(targeted_rank,
RankedTensorType::kDynamicSize),
lhs.getType().template dyn_cast<TensorType>().getElementType());
Value ranked_shape_val = if_builder.create<shape::ConstShapeOp>(
loc, known_rank_extent_tensor_type,
mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type,
ranked_shape));
Value extended_lhs = if_builder.create<shape::BroadcastOp>(
loc, unknown_rank_extent_tensor_type, lhs_shape, ranked_shape_val,
nullptr);
Value extended_lhs_casted = if_builder.create<TensorCastOp>(
loc, known_rank_extent_tensor_type, extended_lhs);
Value extended_rhs = if_builder.create<shape::BroadcastOp>(
loc, unknown_rank_extent_tensor_type, rhs_shape, ranked_shape_val,
nullptr);
Value extended_rhs_casted = if_builder.create<TensorCastOp>(
loc, known_rank_extent_tensor_type, extended_rhs);
// 1. Reshape operands to the given rank (with the same number of elements)
// 2. Compute the ranked-broadcasted ChloOp (which will assert that the ops
// can be broadcasted and do the actual broadcasting)
// 3. Type erase the output back to unranked
Value reshaped_lhs = if_builder.create<mhlo::DynamicReshapeOp>(
loc, reshaped_type, lhs, extended_lhs_casted);
Value reshaped_rhs = if_builder.create<mhlo::DynamicReshapeOp>(
loc, reshaped_type, rhs, extended_rhs_casted);
Value result = if_builder.create<ChloOpTy>(
loc, ArrayRef<Type>{reshaped_type},
ArrayRef<Value>{reshaped_lhs, reshaped_rhs}, op.getAttrs());
Value reshaped_result = if_builder.create<TensorCastOp>(
loc, UnrankedTensorType::get(reshaped_type.getElementType()), result);
if_builder.create<scf::YieldOp>(loc, reshaped_result);
// Return the if_op, so the result can be used and the else block can be
// used for the next rank specialized step.
return if_op;
}
// Iterates over the desired ranks to be specialized and generates the code
// snippet for each case.
Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op, Value lhs,
Value rhs) const {
constexpr int max_rank_specialization = 7;
auto loc = op.getLoc();
// Find the larger rank of the 2 operands.
auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
rewriter.getIndexType());
Value lhs_shape =
rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, lhs);
Value rhs_shape =
rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, rhs);
Value lhs_rank =
rewriter.create<RankOp>(loc, rewriter.getIndexType(), lhs_shape);
Value rhs_rank =
rewriter.create<RankOp>(loc, rewriter.getIndexType(), rhs_shape);
Value greater_rank_lhs =
rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs_rank, rhs_rank);
Value greater_rank =
rewriter.create<SelectOp>(loc, greater_rank_lhs, lhs_rank, rhs_rank);
// Generate a list of nested if/else statements to handle rank
// specializations from 2-6.
scf::IfOp if_op = createRankSpecializedBroadcastAndOp(rewriter, op, lhs,
rhs, greater_rank, 2);
// Put each subsequent rank specialization inside the else statement of the
// previous one.
OpBuilder else_builder = if_op.getElseBodyBuilder();
for (int i = 3; i < max_rank_specialization; i++) {
auto inner_if = createRankSpecializedBroadcastAndOp(else_builder, op, lhs,
rhs, greater_rank, i);
else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
else_builder = inner_if.getElseBodyBuilder();
}
// Fire an assertion if none of the rank specializations applied (one of the
// ranks was greater than 6).
else_builder.create<AssertOp>(
loc, else_builder.create<ConstantIntOp>(loc, 0, 1),
"Input for dynamic binary op lowering was of a rank greater than 6");
else_builder.create<scf::YieldOp>(loc, lhs);
// Return the result of the outermost if statement.
return if_op.getResult(0);
}
};
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
void PopulateForBinaryOp(MLIRContext *context,
OwningRewritePatternList *patterns) {
patterns
->insert<ConvertTrivialNonBroadcastBinaryOp<ChloOpTy, HloOpTy, Adaptor>>(
context, 10);
patterns->insert<
ConvertRankedDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy, Adaptor>>(
context, 5);
patterns->insert<
ConvertUnrankedScalarDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy>,
ConvertUnrankedDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy, Adaptor>>(
context);
}
template <typename FromOpTy, typename ToOpTy>
struct HloBinaryElementwiseAdaptor {
static ToOpTy CreateOp(FromOpTy from_op, Type result_type,
Value broadcasted_lhs, Value broadcasted_rhs,
OpBuilder &builder) {
return builder.create<ToOpTy>(from_op.getLoc(), result_type,
broadcasted_lhs, broadcasted_rhs);
}
};
struct HloComplexAdaptor {
static mhlo::ComplexOp CreateOp(BroadcastComplexOp from_op, Type result_type,
Value broadcasted_lhs, Value broadcasted_rhs,
OpBuilder &builder) {
return builder.create<mhlo::ComplexOp>(from_op.getLoc(), result_type,
broadcasted_lhs, broadcasted_rhs);
}
};
struct HloCompareAdaptor {
static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type,
Value broadcasted_lhs, Value broadcasted_rhs,
OpBuilder &builder) {
return builder.create<mhlo::CompareOp>(
from_op.getLoc(), result_type, broadcasted_lhs, broadcasted_rhs,
from_op.comparison_direction(), from_op.compare_typeAttr());
}
};
#include "generated_chlo_legalize_to_hlo.inc"
} // namespace
@@ -521,32 +212,10 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
// Instantiate conversion templates for conforming binary elementwise ops
// that do not have different dtypes between operands and results and do
// not have special attributes that need to be preserved.
#define POPULATE_BCAST(ChloOp, HloOp) \
PopulateForBinaryOp<ChloOp, HloOp, \
HloBinaryElementwiseAdaptor<ChloOp, HloOp>>(context, \
patterns);
POPULATE_BCAST(BroadcastAddOp, mhlo::AddOp);
POPULATE_BCAST(BroadcastAndOp, mhlo::AndOp);
POPULATE_BCAST(BroadcastAtan2Op, mhlo::Atan2Op);
POPULATE_BCAST(BroadcastDivOp, mhlo::DivOp);
POPULATE_BCAST(BroadcastMaxOp, mhlo::MaxOp);
POPULATE_BCAST(BroadcastMinOp, mhlo::MinOp);
POPULATE_BCAST(BroadcastMulOp, mhlo::MulOp);
POPULATE_BCAST(BroadcastOrOp, mhlo::OrOp);
POPULATE_BCAST(BroadcastPowOp, mhlo::PowOp);
POPULATE_BCAST(BroadcastRemOp, mhlo::RemOp);
POPULATE_BCAST(BroadcastShiftLeftOp, mhlo::ShiftLeftOp);
POPULATE_BCAST(BroadcastShiftRightArithmeticOp, mhlo::ShiftRightArithmeticOp);
POPULATE_BCAST(BroadcastShiftRightLogicalOp, mhlo::ShiftRightLogicalOp);
POPULATE_BCAST(BroadcastSubOp, mhlo::SubOp);
POPULATE_BCAST(BroadcastXorOp, mhlo::XorOp);
// Broadcasting ops requiring special construction.
PopulateForBinaryOp<BroadcastComplexOp, mhlo::ComplexOp, HloComplexAdaptor>(
context, patterns);
PopulateForBinaryOp<BroadcastCompareOp, mhlo::CompareOp, HloCompareAdaptor>(
context, patterns);
PopulateForBroadcastingBinaryOp<ConvertTrivialNonBroadcastBinaryOp>(
context, patterns, 10);
PopulateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>(
context, patterns, 5);
// Other patterns.
patterns->insert<ConvertConstantLikeOp>(context);

View File

@@ -16,7 +16,9 @@ limitations under the License.
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Function.h"
@@ -126,6 +128,291 @@ struct ElementwiseOpConversion : public OpRewritePattern<OpTy> {
}
};
// Converts a broadcasting binary operation with a scalar operand and an
// unranked operand to a ranked broadcasting operation by dynamically reshaping
// the unranked operand to a 1D tensor. This will always be safe because
// broadcasting from a scalar to another shape always works.
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
: public OpConversionPattern<ChloOpTy> {
using OpConversionPattern<ChloOpTy>::OpConversionPattern;
LogicalResult matchAndRewrite(
ChloOpTy op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
typename ChloOpTy::Adaptor transformed(operands);
Value lhs = transformed.lhs();
Value rhs = transformed.rhs();
auto lhs_ranked_type = lhs.getType().dyn_cast<RankedTensorType>();
auto lhs_unranked_type = lhs.getType().dyn_cast<UnrankedTensorType>();
auto rhs_ranked_type = rhs.getType().dyn_cast<RankedTensorType>();
auto rhs_unranked_type = rhs.getType().dyn_cast<UnrankedTensorType>();
bool lhs_is_scalar = lhs_ranked_type &&
lhs_ranked_type.getShape().empty() &&
rhs_unranked_type;
bool rhs_is_scalar = rhs_ranked_type &&
rhs_ranked_type.getShape().empty() &&
lhs_unranked_type;
// Only support the case where exactly one operand is scalar and the other
// is unranked. Other patterns in chlo-to-hlo legalization will create more
// efficient lowerings for cases where both ranks are known or will handle
// the more generic case of both inputs being unranked.
if (!(lhs_is_scalar ^ rhs_is_scalar)) return failure();
auto result_type = op.getResult().getType().template dyn_cast<TensorType>();
// Reshape the non-scalar value into a dynamically sized, rank-1 tensor
Value shape =
rewriter.create<shape::ShapeOfOp>(loc, lhs_is_scalar ? rhs : lhs);
Value num_elements = rewriter.create<shape::NumElementsOp>(loc, shape);
Value size_tensor =
rewriter.create<TensorFromElementsOp>(loc, num_elements);
Value reshaped = rewriter.create<mhlo::DynamicReshapeOp>(
loc, RankedTensorType::get({-1}, result_type.getElementType()),
lhs_is_scalar ? rhs : lhs, size_tensor);
// Create a new ranked Chlo op that will be further lowered by other
// patterns into Mhlo.
SmallVector<Value, 2> new_operands{lhs_is_scalar ? lhs : reshaped,
rhs_is_scalar ? rhs : reshaped};
Value computed =
rewriter.create<ChloOpTy>(loc, SmallVector<Type, 1>{reshaped.getType()},
new_operands, op.getAttrs());
// Reshape the result back into an unranked tensor.
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_type,
computed, shape);
return success();
}
};
// Handles lowering of the following pattern to patterns that will be further
// matched by other patterns until they result in LHLO:
// %result = "chlo.op"(%lhs, %rhs) : (<*xTy>, <*xTy>) -> <*xTy>
//
// The sequence of specializations this handles is:
// - Either operand being scalar
// - Operands having equal shapes
// - The resulting value being any of ranks [2,6]
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertUnrankedDynamicBroadcastBinaryOp
: public OpConversionPattern<ChloOpTy> {
using OpConversionPattern<ChloOpTy>::OpConversionPattern;
LogicalResult matchAndRewrite(
ChloOpTy op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
typename ChloOpTy::Adaptor transformed(operands);
Value lhs = transformed.lhs();
Value rhs = transformed.rhs();
auto lhs_type = lhs.getType().dyn_cast<UnrankedTensorType>();
auto rhs_type = rhs.getType().dyn_cast<UnrankedTensorType>();
auto result_type = op.getResult().getType().template dyn_cast<TensorType>();
// Only support unranked operands. If either operand is ranked, another
// pattern will handle the lowering.
if (!lhs_type || !rhs_type) return failure();
// If lhs is scalar
auto if_op = rewriter.create<scf::IfOp>(
loc, result_type, IsScalarTensor(rewriter, op, lhs), true);
OpBuilder if_lhs_scalar_builder =
if_op.getThenBodyBuilder(rewriter.getListener());
Value reshaped_lhs = if_lhs_scalar_builder.create<TensorCastOp>(
loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs);
Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>(
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs},
op.getAttrs());
if_lhs_scalar_builder.create<scf::YieldOp>(loc, if_lhs_scalar_result);
// If lhs is NOT scalar
//
// See if rhs is scalar
OpBuilder else_lhs_scalar_builder =
if_op.getElseBodyBuilder(rewriter.getListener());
auto if_rhs_scalar_op = else_lhs_scalar_builder.create<scf::IfOp>(
loc, result_type, IsScalarTensor(else_lhs_scalar_builder, op, rhs),
true);
else_lhs_scalar_builder.create<scf::YieldOp>(loc,
if_rhs_scalar_op.getResult(0));
OpBuilder if_rhs_scalar_builder =
if_rhs_scalar_op.getThenBodyBuilder(rewriter.getListener());
Value reshaped_rhs = if_rhs_scalar_builder.create<TensorCastOp>(
loc, RankedTensorType::get({}, lhs_type.getElementType()), rhs);
Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>(
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs},
op.getAttrs());
if_rhs_scalar_builder.create<scf::YieldOp>(loc, if_rhs_scalar_result);
// If NEITHER shape is scalar
//
// See if shapes are equal.
OpBuilder else_no_scalars_builder =
if_rhs_scalar_op.getElseBodyBuilder(rewriter.getListener());
Value shape_of_lhs =
else_no_scalars_builder.create<shape::ShapeOfOp>(loc, lhs);
Value shape_of_rhs =
else_no_scalars_builder.create<shape::ShapeOfOp>(loc, rhs);
Value equal_shapes = else_no_scalars_builder.create<shape::ShapeEqOp>(
loc, shape_of_lhs, shape_of_rhs);
auto if_eq_shapes_op = else_no_scalars_builder.create<scf::IfOp>(
loc, result_type, equal_shapes, true);
else_no_scalars_builder.create<scf::YieldOp>(loc,
if_eq_shapes_op.getResult(0));
OpBuilder if_eq_shapes_builder =
if_eq_shapes_op.getThenBodyBuilder(rewriter.getListener());
Value non_broadcast_op =
Adaptor::CreateOp(op, result_type, lhs, rhs, if_eq_shapes_builder);
if_eq_shapes_builder.create<scf::YieldOp>(loc, non_broadcast_op);
// If shapes are neither scalar nor equal
//
// See if values are of a rank that we support.
OpBuilder if_neq_shapes_builder =
if_eq_shapes_op.getElseBodyBuilder(rewriter.getListener());
if_neq_shapes_builder.create<scf::YieldOp>(
loc, HandleBroadcastAndOp(if_neq_shapes_builder, op, lhs, rhs));
rewriter.replaceOp(op, {if_op.getResult(0)});
return success();
}
private:
// Returns the dynamic result of checking whether the given value is a scalar
// tensor.
Value IsScalarTensor(OpBuilder &rewriter, ChloOpTy op, Value tensor) const {
auto loc = op.getLoc();
Value shape_of_tensor = rewriter.create<shape::ShapeOfOp>(loc, tensor);
Value rank_tensor = rewriter.create<shape::RankOp>(
loc, rewriter.getIndexType(), shape_of_tensor);
return rewriter.create<CmpIOp>(loc, rewriter.getI1Type(), CmpIPredicate::eq,
rank_tensor,
rewriter.create<ConstantIndexOp>(loc, 0));
}
// Create the if statement and code for a broadcasting op with a result of a
// given rank.
scf::IfOp createRankSpecializedBroadcastAndOp(OpBuilder &builder, ChloOpTy op,
Value lhs, Value rhs,
Value actual_rank,
int targeted_rank) const {
auto loc = op.getLoc();
// Create the if block to place the current specialized logic in.
Value greater_rank_is_n = builder.create<CmpIOp>(
loc, CmpIPredicate::eq, actual_rank,
builder.create<ConstantIndexOp>(loc, targeted_rank));
auto if_op =
builder.create<scf::IfOp>(loc, lhs.getType(), greater_rank_is_n, true);
OpBuilder if_builder = if_op.getThenBodyBuilder(builder.getListener());
// Handle shape broadcasting and inference.
Value lhs_shape = if_builder.create<shape::ShapeOfOp>(loc, lhs);
Value rhs_shape = if_builder.create<shape::ShapeOfOp>(loc, rhs);
SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
auto unknown_rank_extent_tensor_type = RankedTensorType::get(
{RankedTensorType::kDynamicSize}, builder.getIndexType());
auto known_rank_extent_tensor_type =
RankedTensorType::get({targeted_rank}, builder.getIndexType());
auto reshaped_type = RankedTensorType::get(
llvm::SmallVector<int64_t, 6>(targeted_rank,
RankedTensorType::kDynamicSize),
lhs.getType().template dyn_cast<TensorType>().getElementType());
Value ranked_shape_val = if_builder.create<shape::ConstShapeOp>(
loc, known_rank_extent_tensor_type,
mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type,
ranked_shape));
Value extended_lhs = if_builder.create<shape::BroadcastOp>(
loc, unknown_rank_extent_tensor_type, lhs_shape, ranked_shape_val,
nullptr);
Value extended_lhs_casted = if_builder.create<TensorCastOp>(
loc, known_rank_extent_tensor_type, extended_lhs);
Value extended_rhs = if_builder.create<shape::BroadcastOp>(
loc, unknown_rank_extent_tensor_type, rhs_shape, ranked_shape_val,
nullptr);
Value extended_rhs_casted = if_builder.create<TensorCastOp>(
loc, known_rank_extent_tensor_type, extended_rhs);
// 1. Reshape operands to the given rank (with the same number of elements)
// 2. Compute the ranked-broadcasted ChloOp (which will assert that the ops
// can be broadcasted and do the actual broadcasting)
// 3. Type erase the output back to unranked
Value reshaped_lhs = if_builder.create<mhlo::DynamicReshapeOp>(
loc, reshaped_type, lhs, extended_lhs_casted);
Value reshaped_rhs = if_builder.create<mhlo::DynamicReshapeOp>(
loc, reshaped_type, rhs, extended_rhs_casted);
Value result = if_builder.create<ChloOpTy>(
loc, ArrayRef<Type>{reshaped_type},
ArrayRef<Value>{reshaped_lhs, reshaped_rhs}, op.getAttrs());
Value reshaped_result = if_builder.create<TensorCastOp>(
loc, UnrankedTensorType::get(reshaped_type.getElementType()), result);
if_builder.create<scf::YieldOp>(loc, reshaped_result);
// Return the if_op, so the result can be used and the else block can be
// used for the next rank specialized step.
return if_op;
}
// Iterates over the desired ranks to be specialized and generates the code
// snippet for each case.
Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op, Value lhs,
Value rhs) const {
constexpr int max_rank_specialization = 7;
auto loc = op.getLoc();
// Find the larger rank of the 2 operands.
auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
rewriter.getIndexType());
Value lhs_shape =
rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, lhs);
Value rhs_shape =
rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, rhs);
Value lhs_rank =
rewriter.create<RankOp>(loc, rewriter.getIndexType(), lhs_shape);
Value rhs_rank =
rewriter.create<RankOp>(loc, rewriter.getIndexType(), rhs_shape);
Value greater_rank_lhs =
rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs_rank, rhs_rank);
Value greater_rank =
rewriter.create<SelectOp>(loc, greater_rank_lhs, lhs_rank, rhs_rank);
// Generate a list of nested if/else statements to handle rank
// specializations from 2-6.
scf::IfOp if_op = createRankSpecializedBroadcastAndOp(rewriter, op, lhs,
rhs, greater_rank, 2);
// Put each subsequent rank specialization inside the else statement of the
// previous one.
OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
for (int i = 3; i < max_rank_specialization; i++) {
auto inner_if = createRankSpecializedBroadcastAndOp(else_builder, op, lhs,
rhs, greater_rank, i);
else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
else_builder = inner_if.getElseBodyBuilder(rewriter.getListener());
}
// Fire an assertion if none of the rank specializations applied (one of the
// ranks was greater than 6).
else_builder.create<AssertOp>(
loc, else_builder.create<ConstantIntOp>(loc, 0, 1),
"Input for dynamic binary op lowering was of a rank greater than 6");
else_builder.create<scf::YieldOp>(loc, lhs);
// Return the result of the outermost if statement.
return if_op.getResult(0);
}
};
struct TransformUnrankedHloPass
: public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override {
@@ -137,7 +424,7 @@ struct TransformUnrankedHloPass
MLIRContext &ctx = getContext();
ConversionTarget target(ctx);
target.addLegalDialect<mhlo::MhloDialect, StandardOpsDialect,
shape::ShapeDialect>();
shape::ShapeDialect, scf::SCFDialect>();
target.addLegalOp<FuncOp>();
#define ADD_LEGAL_MHLO(op) AddLegalOpOnRankedTensor<mhlo::op>(&target)
#define ADD_LEGAL_CHLO(op) AddLegalOpOnRankedTensor<chlo::op>(&target)
@@ -148,6 +435,12 @@ struct TransformUnrankedHloPass
#undef ADD_LEGAL_CHLO
AddLegalOpOnRankedTensor<mhlo::CompareOp>(&target);
AddLegalOpOnRankedTensor<mhlo::SelectOp>(&target);
target.addDynamicallyLegalDialect<chlo::HloClientDialect>(
[](Operation *op) {
return !llvm::any_of(op->getOperandTypes(), [](Type type) {
return type.isa<UnrankedTensorType>();
});
});
// Populate rewrite patterns.
OwningRewritePatternList patterns;
@@ -180,6 +473,10 @@ void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
#undef MAP_BINARY
#undef MAP_CHLO_UNARY
#undef COMMA
chlo::PopulateForBroadcastingBinaryOp<
ConvertUnrankedDynamicBroadcastBinaryOp>(context, patterns);
chlo::PopulateForBroadcastingBinaryOp<
ConvertUnrankedScalarDynamicBroadcastBinaryOp>(context, patterns);
}
std::unique_ptr<FunctionPass> createTransformUnrankedHloPass() {

View File

@@ -237,209 +237,3 @@ func @xorWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4x
%0 = chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
return %0 : tensor<4xi1>
}
// -----
func @addScalarUnranked(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<f32>, tensor<*xf32>)
-> tensor<*xf32>
return %0 : tensor<*xf32>
}
// CHECK-LABEL: func @addScalarUnranked(
// CHECK-SAME: %[[ARG_0:.*]]: tensor<f32>,
// CHECK-SAME: %[[ARG_1:.*]]: tensor<*xf32>
// CHECK-SAME: ) -> tensor<*xf32> {
// First handle the dynamic reshaping of the unranked operand
// to a 1D tensor.
// CHECK: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<*xf32>
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_1]] : tensor<?xindex> -> index
// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
// CHECK: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_1]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// The assuming region is part of the second stage of lowering
// with ranked broadcasting logic.
// CHECK: %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<f32>
// CHECK: %[[SHAPE_RESHAPED:.*]] = shape.shape_of %[[RESHAPED]] : tensor<?xf32>
// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE_0]], %[[SHAPE_RESHAPED]]
// CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?xf32>) {
// CHECK: %[[SCALAR_SHAPE:.*]] = shape.const_shape []
// CHECK: %[[BROADCASTED_SHAPE:.*]] = shape.broadcast %[[SCALAR_SHAPE]], %[[SHAPE_RESHAPED]]
// CHECK: %[[SHAPE_TENSOR:.*]] = tensor_cast %[[BROADCASTED_SHAPE]] : tensor<?xindex> to tensor<1xindex>
// CHECK: %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_0]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[BROADCASTED_RESULT:.*]] = mhlo.add %[[BROADCASTED_LHS]], %[[BROADCASTED_RHS]] : tensor<?xf32>
// CHECK: shape.assuming_yield %[[BROADCASTED_RESULT]] : tensor<?xf32>
// CHECK: }
// As part of the unranked logic, the result is reshaped back
// to an unranked tensor.
// CHECK: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[ASSUMING_RESULT:.*]], %[[SHAPE_1]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK: return %[[RESHAPED_RESULT]] : tensor<*xf32>
// CHECK: }
// -----
func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf32> {
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<f32>)
-> tensor<*xf32>
return %0 : tensor<*xf32>
}
// CHECK-LABEL: func @addUnrankedScalar(
// CHECK-SAME: %[[ARG_0:.*]]: tensor<*xf32>,
// CHECK-SAME: %[[ARG_1:.*]]: tensor<f32>) -> tensor<*xf32> {
// First handle the dynamic reshaping of the unranked operand
// to a 1D tensor.
// CHECK: %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32>
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_0]] : tensor<?xindex> -> index
// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
// CHECK: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_0]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// The assuming region is part of the second stage of lowering
// with ranked broadcasting logic.
// CHECK: %[[SHAPE_RESHAPED:.*]] = shape.shape_of %[[RESHAPED]] : tensor<?xf32>
// CHECK: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<f32>
// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE_RESHAPED]], %[[SHAPE_1]]
// CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?xf32>) {
// CHECK: %[[ASTENSOR:.*]] = tensor_cast %[[SHAPE_RESHAPED]]
// CHECK: %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[ASTENSOR]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_1]], %[[ASTENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[BROADCASTED_RESULT:.*]] = mhlo.add %[[BROADCASTED_LHS]], %[[BROADCASTED_RHS]] : tensor<?xf32>
// CHECK: shape.assuming_yield %[[BROADCASTED_RESULT]] : tensor<?xf32>
// CHECK: }
// As part of the unranked logic, the result is reshaped back
// to an unranked tensor.
// CHECK: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[ASSUMING_RESULT:.*]], %[[SHAPE_0]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK: return %[[RESHAPED_RESULT]] : tensor<*xf32>
// CHECK: }
// -----
func @addUnrankedUnranked(
%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<*xf32>)
-> tensor<*xf32>
return %0 : tensor<*xf32>
}
// CHECK-LABEL: func @addUnrankedUnranked(
// CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>,
// CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
// CHECK: %[[RANK_LHS:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[LHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_LHS]], %[[C0]] : index
// Handle scalar LHS case
// CHECK: %[[VAL_8:.*]] = scf.if %[[LHS_IS_SCALAR]] -> (tensor<*xf32>) {
// CHECK: %[[SCALAR_LHS:.*]] = tensor_cast %[[LHS]] : tensor<*xf32> to tensor<f32>
// CHECK: %[[VAL_10:.*]] = chlo.broadcast_add %[[SCALAR_LHS]], %[[RHS]] : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: scf.yield %[[VAL_10]] : tensor<*xf32>
// CHECK: } else {
// CHECK: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
// CHECK: %[[RANK_RHS:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
// CHECK: %[[RHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_RHS]], %[[C0]] : index
// Handle scalar RHS case
// CHECK: %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) {
// CHECK: %[[SCALAR_RHS:.*]] = tensor_cast %[[RHS]] : tensor<*xf32> to tensor<f32>
// CHECK: %[[VAL_16:.*]] = chlo.broadcast_add %[[LHS]], %[[SCALAR_RHS]] : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
// CHECK: scf.yield %[[VAL_16]] : tensor<*xf32>
// CHECK: } else {
// CHECK: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>
// Handle equal shapes case
// CHECK: %[[VAL_18:.*]] = scf.if %[[SHAPES_EQ]] -> (tensor<*xf32>) {
// CHECK: %[[VAL_19:.*]] = mhlo.add %[[LHS]], %[[RHS]] : tensor<*xf32>
// CHECK: scf.yield %[[VAL_19]] : tensor<*xf32>
// CHECK: } else {
// CHECK: %[[LHS_RANK:.*]] = rank %[[LHS_SHAPE]] : tensor<?xindex>
// CHECK: %[[RHS_RANK:.*]] = rank %[[RHS_SHAPE]] : tensor<?xindex>
// CHECK: %[[LHS_RANK_GREATER:.*]] = cmpi "sgt", %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[GREATEST_RANK_IS_2:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C2]] : index
// Handle rank 2 specialization
// CHECK: %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) {
// CHECK: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1]
// CHECK: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
// CHECK: %[[CASTED_LHS_2:.*]] = tensor_cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex>
// CHECK: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
// CHECK: %[[CASTED_RHS_2:.*]] = tensor_cast %[[BROADCASTED_RHS_2]] : tensor<?xindex> to tensor<2xindex>
// CHECK: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK: %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RESULT_2:.*]] = tensor_cast %[[RESULT_RANK_2]] : tensor<?x?xf32> to tensor<*xf32>
// CHECK: scf.yield %[[RESULT_2]] : tensor<*xf32>
// CHECK: } else {
// CHECK: %[[C3:.*]] = constant 3 : index
// CHECK: %[[GREATEST_RANK_IS_3:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C3]] : index
// Handle rank 3 specialization
// CHECK: %[[VAL_34:.*]] = scf.if %[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) {
// CHECK: %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
// CHECK: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
// CHECK: %[[CASTED_LHS_3:.*]] = tensor_cast %[[BROADCASTED_LHS_3]] : tensor<?xindex> to tensor<3xindex>
// CHECK: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
// CHECK: %[[CASTED_RHS_3:.*]] = tensor_cast %[[BROADCASTED_RHS_3]] : tensor<?xindex> to tensor<3xindex>
// CHECK: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
// CHECK: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
// CHECK: %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: %[[RESULT_3:.*]] = tensor_cast %[[RESULT_RANK_3]] : tensor<?x?x?xf32> to tensor<*xf32>
// CHECK: scf.yield %[[RESULT_3]] : tensor<*xf32>
// CHECK: } else {
// CHECK: %[[C4:.*]] = constant 4 : index
// CHECK: %[[GREATEST_RANK_IS_4:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C4]] : index
// Handle rank 4 specialization
// CHECK: %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) {
// CHECK: %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
// CHECK: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
// CHECK: %[[CASTED_LHS_4:.*]] = tensor_cast %[[BROADCASTED_LHS_4]] : tensor<?xindex> to tensor<4xindex>
// CHECK: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
// CHECK: %[[CASTED_RHS_4:.*]] = tensor_cast %[[BROADCASTED_RHS_4]] : tensor<?xindex> to tensor<4xindex>
// CHECK: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK: %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[RESULT_4:.*]] = tensor_cast %[[RESULT_RANK_4]] : tensor<?x?x?x?xf32> to tensor<*xf32>
// CHECK: scf.yield %[[RESULT_4]] : tensor<*xf32>
// CHECK: } else {
// CHECK: %[[C5:.*]] = constant 5 : index
// CHECK: %[[GREATEST_RANK_IS_5:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C5]] : index
// Handle rank 5 specialization
// CHECK: %[[VAL_50:.*]] = scf.if %[[GREATEST_RANK_IS_5]] -> (tensor<*xf32>) {
// CHECK: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
// CHECK: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
// CHECK: %[[CASTED_LHS_5:.*]] = tensor_cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
// CHECK: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
// CHECK: %[[CASTED_RHS_5:.*]] = tensor_cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
// CHECK: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
// CHECK: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
// CHECK: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
// CHECK: %[[RESULT_5:.*]] = tensor_cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
// CHECK: scf.yield %[[RESULT_5]] : tensor<*xf32>
// CHECK: } else {
// CHECK: %[[C6:.*]] = constant 6 : index
// CHECK: %[[GREATEST_RANK_IS_6:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C6]] : index
// Handle rank 6 specialization
// CHECK: %[[VAL_58:.*]] = scf.if %[[GREATEST_RANK_IS_6]] -> (tensor<*xf32>) {
// CHECK: %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1]
// CHECK: %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
// CHECK: %[[CASTED_LHS_6:.*]] = tensor_cast %[[BROADCASTED_LHS_6]] : tensor<?xindex> to tensor<6xindex>
// CHECK: %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
// CHECK: %[[CASTED_RHS_6:.*]] = tensor_cast %[[BROADCASTED_RHS_6]] : tensor<?xindex> to tensor<6xindex>
// CHECK: %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
// CHECK: %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
// CHECK: %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
// CHECK: %[[RESULT_6:.*]] = tensor_cast %[[RESULT_RANK_6]] : tensor<?x?x?x?x?x?xf32> to tensor<*xf32>
// CHECK: scf.yield %[[RESULT_6]] : tensor<*xf32>
// CHECK: } else {
// CHECK: %false = constant false
// CHECK: assert %false
// CHECK: scf.yield %[[LHS]] : tensor<*xf32>
// CHECK: }
// CHECK: scf.yield %[[VAL_64:.*]] : tensor<*xf32>
// CHECK: }
// CHECK: scf.yield %[[VAL_65:.*]] : tensor<*xf32>
// CHECK: }
// CHECK: scf.yield %[[VAL_66:.*]] : tensor<*xf32>
// CHECK: }
// CHECK: scf.yield %[[VAL_67:.*]] : tensor<*xf32>
// CHECK: }
// CHECK: scf.yield %[[VAL_68:.*]] : tensor<*xf32>
// CHECK: }
// CHECK: scf.yield %[[VAL_69:.*]] : tensor<*xf32>
// CHECK: }
// CHECK: scf.yield %[[VAL_70:.*]] : tensor<*xf32>
// CHECK: }
// CHECK: return %[[VAL_71:.*]] : tensor<*xf32>
// CHECK: }

View File

@@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt --transform-unranked-hlo --split-input-file %s | FileCheck %s
// RUN: mlir-hlo-opt --transform-unranked-hlo --cse --split-input-file %s | FileCheck %s
// Check the validity of expected IR.
// CHECK-LABEL: @sqr_transform_result
@@ -96,3 +96,203 @@ func @tan(%a : tensor<*xf32>) -> tensor<*xf32> {
%result = chlo.tan %a : tensor<*xf32>
return %result : tensor<*xf32>
}
// -----
func @addScalarUnranked(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<f32>, tensor<*xf32>)
-> tensor<*xf32>
return %0 : tensor<*xf32>
}
// CHECK-LABEL: func @addScalarUnranked(
// CHECK-SAME: %[[ARG_0:.*]]: tensor<f32>,
// CHECK-SAME: %[[ARG_1:.*]]: tensor<*xf32>
// CHECK-SAME: ) -> tensor<*xf32> {
// First handle the dynamic reshaping of the unranked operand
// to a 1D tensor.
// CHECK-NEXT: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<*xf32>
// CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_1]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[SIZE_TENSOR:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_1]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[BROADCASTED_RESULT:.*]] = chlo.broadcast_add %[[ARG_0]], %[[RESHAPED]] : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32>
// As part of the unranked logic, the result is reshaped back
// to an unranked tensor.
// CHECK-NEXT: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[BROADCASTED_RESULT:.*]], %[[SHAPE_1]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: return %[[RESHAPED_RESULT]] : tensor<*xf32>
// CHECK-NEXT: }
// -----
func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf32> {
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<f32>)
-> tensor<*xf32>
return %0 : tensor<*xf32>
}
// CHECK-LABEL: func @addUnrankedScalar(
// CHECK-SAME: %[[ARG_0:.*]]: tensor<*xf32>,
// CHECK-SAME: %[[ARG_1:.*]]: tensor<f32>) -> tensor<*xf32> {
// First handle the dynamic reshaping of the unranked operand
// to a 1D tensor.
// CHECK-NEXT: %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32>
// CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_0]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[SIZE_TENSOR:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_0]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// The assuming region is part of the second stage of lowering
// with ranked broadcasting logic.
// CHECK-NEXT: %[[BROADCASTED_RESULT:.*]] = chlo.broadcast_add %[[RESHAPED]], %[[ARG_1]] : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
// As part of the unranked logic, the result is reshaped back
// to an unranked tensor.
// CHECK-NEXT: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[BROADCASTED_RESULT:.*]], %[[SHAPE_0]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: return %[[RESHAPED_RESULT]] : tensor<*xf32>
// CHECK-NEXT: }
// -----
func @addUnrankedUnranked(
%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<*xf32>)
-> tensor<*xf32>
return %0 : tensor<*xf32>
}
// CHECK-LABEL: func @addUnrankedUnranked(
// CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>,
// CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-NEXT: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
// CHECK-NEXT: %[[RANK_LHS:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
// CHECK-NEXT: %[[LHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_LHS]], %[[C0]] : index
// Handle scalar LHS case
// CHECK-NEXT: %[[VAL_8:.*]] = scf.if %[[LHS_IS_SCALAR]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[SCALAR_LHS:.*]] = tensor_cast %[[LHS]] : tensor<*xf32> to tensor<f32>
// CHECK-NEXT: %[[RHS_SHAPE_1:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
// CHECK-NEXT: %[[NUM_RHS:.*]] = shape.num_elements %[[RHS_SHAPE_1]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[NUM_TENS_RHS:.*]] = tensor_from_elements %[[NUM_RHS]] : tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[NUM_TENS_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[LHS_SCALAR_RESULT:.*]] = chlo.broadcast_add %[[SCALAR_LHS]], %[[RESHAPED_RHS]] : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESHAPED_LHS_SCALAR_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[LHS_SCALAR_RESULT]], %[[RHS_SHAPE_1]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESHAPED_LHS_SCALAR_RESULT]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
// CHECK-NEXT: %[[RANK_RHS:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[RHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_RHS]], %[[C0]] : index
// Handle scalar RHS case
// CHECK-NEXT: %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[SCALAR_RHS:.*]] = tensor_cast %[[RHS]] : tensor<*xf32> to tensor<f32>
// CHECK-NEXT: %[[NUM_LHS:.*]] = shape.num_elements %[[LHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[NUM_TENS_LHS:.*]] = tensor_from_elements %[[NUM_LHS]] : tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[NUM_TENS_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[RHS_SCALAR_RESULT:.*]] = chlo.broadcast_add %[[RESHAPED_LHS]], %[[SCALAR_RHS]] : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS_SCALAR_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RHS_SCALAR_RESULT:.*]], %[[LHS_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESHAPED_RHS_SCALAR_RESULT]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>
// Handle equal shapes case
// CHECK-NEXT: %[[VAL_18:.*]] = scf.if %[[SHAPES_EQ]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[ANY_SHAPE:.*]] = shape.any %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[ANY_NUM:.*]] = shape.num_elements %[[ANY_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[ANY_TENSOR:.*]] = tensor_from_elements %[[ANY_NUM]] : tensor<1xindex>
// CHECK-NEXT: %[[FLATTENED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[FLATTENED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[FLATTENED_RESULT:.*]] = mhlo.add %[[FLATTENED_LHS]], %[[FLATTENED_RHS]] : tensor<?xf32>
// CHECK-NEXT: %[[RESHAPED_SAME_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLATTENED_RESULT]], %[[ANY_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESHAPED_SAME_RESULT]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[LHS_RANK:.*]] = rank %[[LHS_SHAPE]] : tensor<?xindex>
// CHECK-NEXT: %[[RHS_RANK:.*]] = rank %[[RHS_SHAPE]] : tensor<?xindex>
// CHECK-NEXT: %[[LHS_RANK_GREATER:.*]] = cmpi "sgt", %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK-NEXT: %[[C2:.*]] = constant 2 : index
// CHECK-NEXT: %[[GREATEST_RANK_IS_2:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C2]] : index
// Handle rank 2 specialization
// CHECK-NEXT: %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_2:.*]] = tensor_cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex>
// CHECK-NEXT: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_2:.*]] = tensor_cast %[[BROADCASTED_RHS_2]] : tensor<?xindex> to tensor<2xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[RESULT_2:.*]] = tensor_cast %[[RESULT_RANK_2]] : tensor<?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_2]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[C3:.*]] = constant 3 : index
// CHECK-NEXT: %[[GREATEST_RANK_IS_3:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C3]] : index
// Handle rank 3 specialization
// CHECK-NEXT: %[[VAL_34:.*]] = scf.if %[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_3:.*]] = tensor_cast %[[BROADCASTED_LHS_3]] : tensor<?xindex> to tensor<3xindex>
// CHECK-NEXT: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_3:.*]] = tensor_cast %[[BROADCASTED_RHS_3]] : tensor<?xindex> to tensor<3xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK-NEXT: %[[RESULT_3:.*]] = tensor_cast %[[RESULT_RANK_3]] : tensor<?x?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_3]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[C4:.*]] = constant 4 : index
// CHECK-NEXT: %[[GREATEST_RANK_IS_4:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C4]] : index
// Handle rank 4 specialization
// CHECK-NEXT: %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_4:.*]] = tensor_cast %[[BROADCASTED_LHS_4]] : tensor<?xindex> to tensor<4xindex>
// CHECK-NEXT: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_4:.*]] = tensor_cast %[[BROADCASTED_RHS_4]] : tensor<?xindex> to tensor<4xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_4:.*]] = tensor_cast %[[RESULT_RANK_4]] : tensor<?x?x?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_4]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[C5:.*]] = constant 5 : index
// CHECK-NEXT: %[[GREATEST_RANK_IS_5:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C5]] : index
// Handle rank 5 specialization
// CHECK-NEXT: %[[VAL_50:.*]] = scf.if %[[GREATEST_RANK_IS_5]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_5:.*]] = tensor_cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
// CHECK-NEXT: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_5:.*]] = tensor_cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_5:.*]] = tensor_cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_5]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[C6:.*]] = constant 6 : index
// CHECK-NEXT: %[[GREATEST_RANK_IS_6:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C6]] : index
// Handle rank 6 specialization
// CHECK-NEXT: %[[VAL_58:.*]] = scf.if %[[GREATEST_RANK_IS_6]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_6:.*]] = tensor_cast %[[BROADCASTED_LHS_6]] : tensor<?xindex> to tensor<6xindex>
// CHECK-NEXT: %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_6:.*]] = tensor_cast %[[BROADCASTED_RHS_6]] : tensor<?xindex> to tensor<6xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_6:.*]] = tensor_cast %[[RESULT_RANK_6]] : tensor<?x?x?x?x?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_6]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %false = constant false
// CHECK-NEXT: assert %false
// CHECK-NEXT: scf.yield %[[LHS]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_64:.*]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_65:.*]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_66:.*]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_67:.*]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_68:.*]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_69:.*]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_70:.*]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: return %[[VAL_71:.*]] : tensor<*xf32>
// CHECK-NEXT: }

View File

@ -306,6 +306,14 @@ inline bool IsF32ShapedType(Type t) {
return false;
}
// Returns true if it is a shaped type of bf16 elements.
inline bool IsBF16ShapedType(Type t) {
if (auto shaped_type = t.dyn_cast_or_null<ShapedType>()) {
return shaped_type.getElementType().isBF16();
}
return false;
}
// Performs const folding `calculate` with broadcast behavior on the two
// attributes `operand1` and `operand2` and returns the result if possible.
// The two operands are expected to both be scalar values.
@ -498,7 +506,7 @@ Attribute ConstFoldBinaryOp(
/// "tfl.logical_not".
Attribute ConstFoldUnaryOp(Type result_type, Attribute operand,
llvm::function_ref<APFloat(APFloat)> calculate) {
assert(IsF32ShapedType(result_type));
assert(IsF32ShapedType(result_type) || IsBF16ShapedType(result_type));
auto result_shape_type = result_type.cast<ShapedType>();
if (auto dense_elements = operand.dyn_cast_or_null<DenseElementsAttr>()) {
@ -1911,13 +1919,20 @@ OpFoldResult SqrtOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult RsqrtOp::fold(ArrayRef<Attribute> operands) {
Type result_type = getType();
// Only constant fold for tensor of f32 is implemented.
if (!IsF32ShapedType(result_type)) return nullptr;
// Only constant fold for tensor of f32/bf16 is implemented.
if (!IsF32ShapedType(result_type) && !IsBF16ShapedType(result_type))
return nullptr;
auto compute = [](APFloat value) -> APFloat {
bool loseInfo;
const llvm::fltSemantics &original_float_semantics = value.getSemantics();
value.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven,
&loseInfo);
float f = value.convertToFloat();
float result = 1.f / std::sqrt(f);
return APFloat(result);
APFloat result(1.f / std::sqrt(f));
result.convert(original_float_semantics, APFloat::rmNearestTiesToEven,
&loseInfo);
return result;
};
return ConstFoldUnaryOp(result_type, operands[0], compute);
}

View File

@ -577,3 +577,13 @@ func @div_dense_different_rank() -> tensor<1x2x2xf32> {
// CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}5.000000e-01, 0.333333343], [1.000000e+00, 0.666666686]]]> : tensor<1x2x2xf32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @rsqrt_bf16
func @rsqrt_bf16() -> tensor<bf16> {
%cst = constant dense<4.0> : tensor<bf16>
%0 = "tfl.rsqrt"(%cst) : (tensor<bf16>) -> tensor<bf16>
return %0 : tensor<bf16>
// CHECK: %[[CST:.*]] = constant dense<5.000000e-01> : tensor<bf16>
// CHECK: return %[[CST]]
}

View File

@ -1358,3 +1358,27 @@ func @fuseScalarAddIntoConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3
// CHECK: %cst = constant dense<[2.500000e+00, 3.500000e+00, 4.500000e+00, 5.500000e+00, 6.500000e+00, 7.500000e+00, 8.500000e+00, 9.500000e+00, 1.050000e+01, 1.150000e+01, 1.250000e+01, 1.350000e+01, 1.450000e+01, 1.550000e+01, 1.650000e+01, 1.750000e+01]> : tensor<16xf32>
// CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %cst)
}
// CHECK-LABEL: fuseScalarAddIntoConv2dBf16
func @fuseScalarAddIntoConv2dBf16(%arg0: tensor<256x32x32x3xbf16>, %arg1: tensor<16x3x3x3xbf16>) -> tensor<256x30x30x16xbf16> {
%cst = constant dense<1.5> : tensor<bf16>
%cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xbf16>
%0 = "tfl.conv_2d"(%arg0, %arg1, %cst_0) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xbf16>, tensor<16x3x3x3xbf16>, tensor<16xbf16>) -> tensor<256x30x30x16xbf16>
%1 = "tfl.add"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xbf16>, tensor<bf16>) -> tensor<256x30x30x16xbf16>
return %1 : tensor<256x30x30x16xbf16>
// CHECK: %cst = constant dense<[2.500000e+00, 3.500000e+00, 4.500000e+00, 5.500000e+00, 6.500000e+00, 7.500000e+00, 8.500000e+00, 9.500000e+00, 1.050000e+01, 1.150000e+01, 1.250000e+01, 1.350000e+01, 1.450000e+01, 1.550000e+01, 1.650000e+01, 1.750000e+01]> : tensor<16xbf16>
// CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %cst)
}
// CHECK-LABEL: fuseScalarAddIntoConv2dHalf
func @fuseScalarAddIntoConv2dHalf(%arg0: tensor<256x32x32x3xf16>, %arg1: tensor<16x3x3x3xf16>) -> tensor<256x30x30x16xf16> {
%cst = constant dense<1.5> : tensor<f16>
%cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf16>
%0 = "tfl.conv_2d"(%arg0, %arg1, %cst_0) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf16>, tensor<16x3x3x3xf16>, tensor<16xf16>) -> tensor<256x30x30x16xf16>
%1 = "tfl.add"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xf16>, tensor<f16>) -> tensor<256x30x30x16xf16>
return %1 : tensor<256x30x30x16xf16>
// CHECK: %cst = constant dense<[2.500000e+00, 3.500000e+00, 4.500000e+00, 5.500000e+00, 6.500000e+00, 7.500000e+00, 8.500000e+00, 9.500000e+00, 1.050000e+01, 1.150000e+01, 1.250000e+01, 1.350000e+01, 1.450000e+01, 1.550000e+01, 1.650000e+01, 1.750000e+01]> : tensor<16xf16>
// CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %cst)
}

View File

@ -211,20 +211,21 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
pass_manager->addPass(mlir::createSymbolDCEPass());
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
// This pass should be always at the end of the floating point model
// conversion. Some TFL ops like unidirectional
// sequence lstm will have stateful operands and some optimization passes
// will merge those operands if they have identical values & types. However,
// it's not desired by TFL. This pass serves as a "fix" pass to split the
// merged inputs until we have 1st class variable support or reuse
// tf.variable to model this.
pass_manager->addPass(mlir::TFL::CreateSplitMergedOperandsPass());
// Run quantization after all the floating point model conversion is
// completed.
if (pass_config.quant_specs.RunPropagationAndRewriteQuantizationPasses()) {
AddQuantizationPasses(pass_config.quant_specs, pass_manager);
}
// This pass should be always at the end of the model
// conversion (even after quantization). Some TFL ops like unidirectional
// sequence lstm will have stateful operands and some optimization passes
// will merge those operands if they have identical values & types. However,
// it's not desired by TFL. This pass serves as a "fix" pass to split the
// merged inputs until we have 1st class variable support or reuse
// tf.variable to model this.
pass_manager->addPass(mlir::TFL::CreateSplitMergedOperandsPass());
}
}

View File

@ -24,6 +24,11 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
// Checks if the param passed is a F32 ElementsAttr.
def F32ElementsAttr : ElementsAttrBase<
CPred<"$_self.isa<ElementsAttr>() && $_self.cast<ElementsAttr>().getType().getElementType().isF32()">,
"32 bit float constant tensor">;
// Checks if the param passed is a float ElementsAttr.
def FloatElementsAttr : ElementsAttrBase<
CPred<"$_self.isa<ElementsAttr>() && $_self.cast<ElementsAttr>().getType().getElementType().isa<FloatType>()">,
"float constant tensor">;
// Checks if the param passed is of NoneType.
@ -93,9 +98,9 @@ class CanFuseConvOrDepthwiseConv<string is_depthwise> : Constraint<
multiclass FuseBinaryOpToPrecedingAffine<dag binaryOp> {
def FuseBinaryOpWithConv#binaryOp : Pat<
(binaryOp (TFL_Conv2DOp:$output $input, $filter,
(ConstantOp F32ElementsAttr:$bias), $h_factor, $w_factor,
(ConstantOp FloatElementsAttr:$bias), $h_factor, $w_factor,
TFL_AF_None, $padding, $stride_h, $stride_w),
(ConstantOp F32ElementsAttr:$value), $act_fn),
(ConstantOp FloatElementsAttr:$value), $act_fn),
(TFL_Conv2DOp $input, $filter,
(binaryOp (ConstantOp $bias),
(ConstantOp $value), TFL_AF_None),
@ -104,10 +109,10 @@ multiclass FuseBinaryOpToPrecedingAffine<dag binaryOp> {
(HasOneUse $output)]>;
def FuseBinaryOpWithDepthwiseConv#binaryOp : Pat<
(binaryOp (TFL_DepthwiseConv2DOp:$output $input, $filter,
(ConstantOp F32ElementsAttr:$bias),
(ConstantOp FloatElementsAttr:$bias),
$h_factor, $w_factor, TFL_AF_None, $padding, $stride_h,
$stride_w, $multiplier),
(ConstantOp F32ElementsAttr:$value), $act_fn),
(ConstantOp FloatElementsAttr:$value), $act_fn),
(TFL_DepthwiseConv2DOp $input, $filter,
(binaryOp (ConstantOp $bias), (ConstantOp $value), TFL_AF_None),
$h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w,
@ -116,9 +121,9 @@ multiclass FuseBinaryOpToPrecedingAffine<dag binaryOp> {
(HasOneUse $output)]>;
def FuseBinaryOpWithTransposeConv#binaryOp : Pat<
(binaryOp (TFL_TransposeConvOp:$output $output_shape, $weights, $inputs,
(ConstantOp F32ElementsAttr:$bias), $padding,
(ConstantOp FloatElementsAttr:$bias), $padding,
$stride_h, $stride_w),
(ConstantOp F32ElementsAttr:$value), TFL_AF_None),
(ConstantOp FloatElementsAttr:$value), TFL_AF_None),
(TFL_TransposeConvOp $output_shape, $weights, $inputs,
(binaryOp (ConstantOp $bias),
(ConstantOp $value), TFL_AF_None),
@ -130,7 +135,7 @@ multiclass FuseBinaryOpToPrecedingAffine<dag binaryOp> {
(binaryOp (TFL_TransposeConvOp:$output $output_shape, $weights, $inputs,
(ConstantOp $bias), $padding,
$stride_h, $stride_w),
(ConstantOp F32ElementsAttr:$value), TFL_AF_None),
(ConstantOp FloatElementsAttr:$value), TFL_AF_None),
(TFL_TransposeConvOp $output_shape, $weights, $inputs,
(ConstantOp $value),
$padding, $stride_h, $stride_w),
@ -155,11 +160,11 @@ def ExpandTo4DForDepthwiseConv: NativeCodeCall<
multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d<dag BinaryOp> {
def FuseMulOrDivWithDepthwiseConv#BinaryOp : Pat<
(BinaryOp (TFL_DepthwiseConv2DOp:$output $input,
(ConstantOp F32ElementsAttr:$filter),
(ConstantOp F32ElementsAttr:$bias),
(ConstantOp FloatElementsAttr:$filter),
(ConstantOp FloatElementsAttr:$bias),
$h_factor, $w_factor, TFL_AF_None, $padding, $stride_h,
$stride_w, $multiplier),
(ConstantOp F32ElementsAttr:$value), $act_fn),
(ConstantOp FloatElementsAttr:$value), $act_fn),
(TFL_DepthwiseConv2DOp $input,
(BinaryOp
(ConstantOp $filter),
@ -175,11 +180,11 @@ multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d<dag BinaryOp> {
(HasOneUse $output)]>;
def FuseMulOrDivWithConv#BinaryOp : Pat<
(BinaryOp (TFL_Conv2DOp:$conv_output $input,
(ConstantOp F32ElementsAttr:$filter),
(ConstantOp F32ElementsAttr:$bias),
(ConstantOp FloatElementsAttr:$filter),
(ConstantOp FloatElementsAttr:$bias),
$h_factor, $w_factor, TFL_AF_None,
$padding, $stride_h, $stride_w),
(ConstantOp F32ElementsAttr:$value), $act_fn),
(ConstantOp FloatElementsAttr:$value), $act_fn),
(TFL_Conv2DOp $input,
(BinaryOp (ConstantOp $filter),
(ConstantOp (ExpandTo4DForConv $value)),
@ -192,8 +197,8 @@ multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d<dag BinaryOp> {
(HasOneUse $conv_output)]>;
def FuseMulOrDivWithTransposeConv#BinaryOp : Pat<
(BinaryOp (TFL_TransposeConvOp:$output $output_shape,
(ConstantOp F32ElementsAttr:$weights), $input,
(ConstantOp F32ElementsAttr:$bias),
(ConstantOp FloatElementsAttr:$weights), $input,
(ConstantOp FloatElementsAttr:$bias),
$padding, $stride_h, $stride_w),
(ConstantOp $value), TFL_AF_None),
(TFL_TransposeConvOp $output_shape,
@ -209,7 +214,7 @@ multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d<dag BinaryOp> {
(HasOneUse $output)]>;
def FuseMulOrDivWithTransposeConvWithNoneBias#BinaryOp : Pat<
(BinaryOp (TFL_TransposeConvOp:$output $output_shape,
(ConstantOp F32ElementsAttr:$weights), $input,
(ConstantOp FloatElementsAttr:$weights), $input,
(ConstantOp $bias),
$padding, $stride_h, $stride_w),
(ConstantOp $value), TFL_AF_None),

View File

@ -1651,10 +1651,12 @@ Mutually reduces multiple tensors of identical type and shape.
TF_Int32Tensor:$group_size,
TF_Int32Tensor:$group_key,
TF_Int32Tensor:$instance_key,
Variadic<TF_ResourceTensor>:$ordering_token,
TF_AnyStrAttrOf<["Min", "Max", "Mul", "Add"]>:$merge_op,
TF_AnyStrAttrOf<["Id", "Div"]>:$final_op,
DefaultValuedAttr<StrAttr, "auto">:$communication_hint
DefaultValuedAttr<StrAttr, "auto">:$communication_hint,
DefaultValuedAttr<F32Attr, "0.0f">:$timeout_seconds
);
let results = (outs
@ -1662,6 +1664,7 @@ Mutually reduces multiple tensors of identical type and shape.
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandSizeAttr Nordering_token = TF_DerivedOperandSizeAttr<4>;
}
def TF_ComplexOp : TF_Op<"Complex", [NoSideEffect, ResultsBroadcastableShape]> {

View File

@ -77,66 +77,6 @@ namespace TF {
namespace {
// Returns true if the given function has a single use (within the scope
// of the module containing it and all parent modules).
bool HasSingleUse(FuncOp func) {
// Public function can have any number of external uses.
if (func.isPublic()) return false;
// Return false if unexpected IR structure seen.
ModuleOp module = func.getParentOfType<ModuleOp>();
if (!module) return false;
// Inspect function uses in the containing module and all parent
// modules.
bool use_seen = false;
for (; module; module = func.isPrivate()
? nullptr
: module.getParentOfType<ModuleOp>()) {
auto func_uses_optional =
SymbolTable::getSymbolUses(func, &module.getBodyRegion());
// Found an unknown use.
if (!func_uses_optional) return false;
// If no uses in this scope, continue looking in parent module
SymbolTable::UseRange func_uses = func_uses_optional.getValue();
if (func_uses.empty()) continue;
// Check if multiple uses at this scope or another use already seen.
if (!llvm::hasSingleElement(func_uses) || use_seen) return false;
// This is the first use seen.
use_seen = true;
}
// No multiple uses seen.
return true;
}
// Returns true if the caller ops can be inlined.
bool HasInlinableUsers(FuncOp func) {
// Return false if unexpected IR structure seen.
ModuleOp module = func.getParentOfType<ModuleOp>();
if (!module) return false;
// Inspect function uses in the containing module and all parent
// modules.
for (; module; module = func.isPrivate()
? nullptr
: module.getParentOfType<ModuleOp>()) {
auto func_uses_optional =
SymbolTable::getSymbolUses(func, &module.getBodyRegion());
// Found an unknown use.
if (!func_uses_optional) return false;
for (auto &use : func_uses_optional.getValue())
if (isa<TPUPartitionedCallOp>(use.getUser())) return false;
}
// All caller ops can be inlined.
return true;
}
struct TFConstantFoldInterface : public DialectFoldInterface {
TFConstantFoldInterface(Dialect *dialect) : DialectFoldInterface(dialect) {}
LogicalResult fold(Operation *op, ArrayRef<Attribute> operands,
@ -160,10 +100,12 @@ struct TFInlinerInterface : public DialectInlinerInterface {
// Analysis Hooks
//===--------------------------------------------------------------------===//
// Allow all call operations to be inlined.
// Returns if it's legal to inline 'callable' into the 'call', where 'call' is
// a TF operation.
bool isLegalToInline(Operation *call, Operation *callable,
bool wouldBeCloned) const final {
return true;
// Check that the TF call operation is one that is legal to inline.
return !isa<TPUPartitionedCallOp>(call);
}
// Returns if it's legal to inline 'src' region into the 'dest' region
@ -186,10 +128,7 @@ struct TFInlinerInterface : public DialectInlinerInterface {
// post inlining, the function will be dead and eliminated from the IR.
// So there won't be any code duplication.
// plus the function caller op can be replaced by inlined ops.
FuncOp func = op->getParentOfType<FuncOp>();
if (!func) return true;
if (!HasInlinableUsers(func)) return false;
return TensorFlowDialect::CanDuplicate(op) || HasSingleUse(func);
return !wouldBeCloned || TensorFlowDialect::CanDuplicate(op);
}
//===--------------------------------------------------------------------===//

View File

@ -361,8 +361,7 @@ func @send_recv(%arg0: tensor<2x!tf.string>) {
// -----
// Tests functional control flow functions with replica variant ops reachable
// from a replicate region is cloned and remapped. Only the branches reachable
// with replica variant ops are cloned.
// from a replicate region is cloned and remapped.
// CHECK-LABEL: func @control_flow_with_replicate_variant_ops
func @control_flow_with_replicate_variant_ops(%arg0: tensor<i1>, %arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<2x!tf.string>) {
@ -380,30 +379,32 @@ func @control_flow_with_replicate_variant_ops(%arg0: tensor<i1>, %arg1: tensor<f
}
// CHECK: "tf.If"
// CHECK-SAME: else_branch = @cond_false
// CHECK-SAME: else_branch = [[COND_FALSE_REPLICA_0:@[a-z0-9_]+]]
// CHECK-SAME: then_branch = [[COND_TRUE_REPLICA_0:@[a-z0-9_]+]]
// CHECK: "tf.If"
// CHECK-SAME: else_branch = @cond_false
// CHECK-SAME: else_branch = [[COND_FALSE_REPLICA_1:@[a-z0-9_]+]]
// CHECK-SAME: then_branch = [[COND_TRUE_REPLICA_1:@[a-z0-9_]+]]
func @cond_false(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<2x!tf.string>) -> tensor<f32> {
return %arg0 : tensor<f32>
}
// CHECK-NOT: func @cond_false.+(
func @cond_true(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<2x!tf.string>) -> tensor<f32> {
"tf._XlaSendFromHost"(%arg1, %arg2) {_xla_has_host_transfer = true, device_ordinal = 0 : i64, key = "host_compute_channel_recv_0"} : (tensor<f32>, tensor<2x!tf.string>) -> ()
%0 = "tf._XlaRecvAtHost"(%arg2) {_xla_has_host_transfer = true, device_ordinal = 0 : i64, key = "host_compute_channel_send_0"} : (tensor<2x!tf.string>) -> tensor<f32>
return %0 : tensor<f32>
}
// CHECK: func [[COND_FALSE_REPLICA_0]]
// CHECK: func [[COND_TRUE_REPLICA_0]]
// CHECK: "tf._XlaSendFromHost"
// CHECK-SAME: device_ordinal = 1
// CHECK: "tf._XlaRecvAtHost"
// CHECK-SAME: device_ordinal = 1
// CHECK: func [[COND_FALSE_REPLICA_1]]
// CHECK: func [[COND_TRUE_REPLICA_1]]
// CHECK: "tf._XlaSendFromHost"
// CHECK-SAME: device_ordinal = 2
@ -413,7 +414,7 @@ func @cond_true(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<2x!tf.stri
// -----
// Tests function with no replica variant ops reachable from a replicate region
// is not cloned.
// is cloned.
// CHECK-LABEL: func @no_replicate_variant_ops
func @no_replicate_variant_ops(%arg0: tensor<f32>, %arg1: tensor<2x!tf.string>) {
@ -431,11 +432,17 @@ func @no_replicate_variant_ops(%arg0: tensor<f32>, %arg1: tensor<2x!tf.string>)
}
// CHECK: "tf.StatefulPartitionedCall"
// CHECK-SAME: f = @send_recv
// CHECK-SAME: f = [[CALLEE_REPLICA_0:@[a-z0-9_]+]]
// CHECK: "tf.StatefulPartitionedCall"
// CHECK-SAME: f = [[CALLEE_REPLICA_1:@[a-z0-9_]+]]
func @send_recv(%arg0: tensor<2x!tf.string>) {
"tf.NoOp"() : () -> ()
return
}
// CHECK-NOT: @send_recv.+(
// CHECK: func [[CALLEE_REPLICA_0]]
// CHECK: "tf.NoOp"
// CHECK: func [[CALLEE_REPLICA_1]]
// CHECK: "tf.NoOp"

View File

@ -53,7 +53,6 @@ void AddGraphExportLoweringPasses(OpPassManager &pm) {
add_pass(TFDevice::CreateParallelExecuteToIslandsPass());
pm.addPass(TFDevice::CreateLaunchToDeviceAttributePass());
pm.addPass(CreateBreakUpIslandsPass());
pm.addNestedPass<FuncOp>(CreateTPUDevicePropagationPass());
pm.addPass(createSymbolDCEPass());
}

View File

@ -120,46 +120,6 @@ llvm::SmallPtrSet<FuncOp, 4> GetReachableFunctionsFromRegion(ModuleOp module,
return visited_functions;
}
// Collects all functions and transitive functions reachable from region that
// contain replicate variant ops.
llvm::SmallDenseMap<llvm::StringRef, FuncOp> GetReachableFunctionsToClone(
ModuleOp module, Region& region,
const llvm::Optional<DictionaryAttr>& devices) {
llvm::SmallPtrSet<FuncOp, 4> reachable_functions =
GetReachableFunctionsFromRegion(module, region);
llvm::SmallDenseMap<llvm::StringRef, FuncOp> functions_to_clone;
llvm::SmallVector<FuncOp, 4> functions_to_visit;
for (FuncOp func : reachable_functions) {
if (!func.getCallableRegion()) continue;
if (HasReplicaVariantOps(*func.getCallableRegion(), devices)) {
functions_to_clone.insert({func.getName(), func});
functions_to_visit.push_back(func);
}
}
while (!functions_to_visit.empty()) {
llvm::SmallVector<FuncOp, 4> new_functions_to_visit;
for (FuncOp func_to_visit : functions_to_visit) {
auto func_uses = func_to_visit.getSymbolUses(module);
if (!func_uses) continue;
for (auto use : *func_uses) {
auto parent_func = use.getUser()->getParentOfType<FuncOp>();
if (!parent_func || !reachable_functions.contains(parent_func) ||
!functions_to_clone.insert({parent_func.getName(), parent_func})
.second)
continue;
new_functions_to_visit.push_back(parent_func);
}
}
functions_to_visit.swap(new_functions_to_visit);
}
return functions_to_clone;
}
struct FuncOldNameAndClone {
StringRef old_name;
FuncOp clone;
@ -276,20 +236,19 @@ LogicalResult ExpandReplicateIntoReplicas(
terminator.erase();
auto funcs_to_clone =
GetReachableFunctionsToClone(module, replicate_op.body(), devices);
GetReachableFunctionsFromRegion(module, replicate_op.body());
SymbolTable symbol_table(module);
builder.setInsertionPoint(island_op);
BlockAndValueMapping mapping;
for (int i : llvm::seq<int>(0, num_replicas)) {
// Clone reachable functions with replica variant ops.
// Clone reachable functions from region.
llvm::SmallVector<FuncOldNameAndClone, 4> cloned_functions;
cloned_functions.reserve(funcs_to_clone.size());
for (auto& func_to_clone : funcs_to_clone) {
auto cloned_function = func_to_clone.getSecond().clone();
for (FuncOp func_to_clone : funcs_to_clone) {
auto cloned_function = func_to_clone.clone();
symbol_table.insert(cloned_function, module.end());
cloned_functions.push_back(
{func_to_clone.getSecond().getName(), cloned_function});
cloned_functions.push_back({func_to_clone.getName(), cloned_function});
}
// Create new island for replica.

View File

@ -31,7 +31,6 @@ int main(int argc, char **argv) {
mlir::registerAllPasses();
mlir::mhlo::registerAllMhloPasses();
mlir::lmhlo::registerAllLmhloPasses();
mlir::mhlo::registerAllMhloPasses();
mlir::DialectRegistry registry;
mlir::registerAllDialects(registry);

View File

@ -0,0 +1,158 @@
# Composable Tensorflow
## Composable Tensorflow
Composable TensorFlow (TF) is the framework for defining portable TF ops with
composition in the authoring language.
The set of standard TF ops is currently open. New ops are defined for special
purposes, but it is hard to make them work end-to-end: each op
needs to be handled separately by several backends (tf2xla bridge, tflite
converter, CPU kernels, etc.). Writing shape functions and gradients for these
ops is extremely difficult. `tf.function` makes some parts of the implementation
simpler, but it introduces runtime overhead and it cannot easily be used to
apply dedicated optimizations to op kernels.
The composable TF framework allows the user to define portable TF ops as
compositions of other TF ops. It translates a Python function used to define the
composition directly into a portable IR at build time, and uses it to expand the
composite op in the TF program during compilation / execution. By using this
expansion mechanism, new ops are readily available on different platforms without
extra work. Moreover, since the expansion is optional, the backend can easily
treat it as a monolithic op when needed, for instance to apply optimizations or
associate it with a custom kernel.
### Benefits
Using the Composable TF API to define a new op and its composition can bring the
following benefits:
* *Automatic backend support*: As long as it is composed of ops supported by the
backend, the new op is automatically supported (as a `tf.function` alternative);
* *Reduced tracing overhead*: Unlike `tf.function`, the composition function is
compiled at build time, hence TF only needs to trace a single op to build the
`graph`;
* *Easy fused op/kernel optimization*: Even if it has complex
semantics, the new op is presented as a single node in the graph, thus
optimization passes and kernels can easily be specialized to this op for better
performance.
* *Automatic shape/type inference support*: No shape functions are required for
the new op;
* *Automatic gradient support (WIP)*: The user doesn't need to author a
gradient function for the op to enable training.
### Use Cases
* (Portability) User wants to add a new op and run this op on different
platforms (CPU, TPU, TFLite, etc.) to be portable.
* *Solution*: The user should define the new op as a composition. The ops used
inside the composition should have support for these platforms. These ops can
also be composite ops.
* (Performance) User defines a custom kernel for a regular structure
(e.g., LSTM), but it is hard to add the logic to fuse the individual ops to
target this kernel in the inference graph.
* *Solution*: The user should define a new TF op, which corresponds to the
fused kernel, with composition, and use this op to build the model for both
training and inference. For the platforms where a fused kernel is not
available, the execution will use the composition instead.
## Gradient
(TODO)
## Authoring Op Composition in Python
The composable TF provides a single API to define a new op with its composition
at the same time. For example, the following code defines a new
`FusedFullyConnected` op, which has `MatMul`, `Add`, and an
`activation function` (specified by an op attribute) fused.
```python
import tensorflow as tf
@Composite(
'FusedFullyConnected',
inputs=['input_: T', 'filter_: T', 'bias: T'],
attrs=['act: {"", "RELU", "RELU6", "TANH"} = ""'],
derived_attrs=['T: {float, int8}'],
outputs=['o: T'])
def _composite_fully_connected(input_, filter_, bias, act):
res = tf.raw_ops.MatMul(
a=input_, b=filter_, transpose_a=False, transpose_b=True)
res = tf.raw_ops.Add(x=res, y=bias)
if act == 'RELU':
return tf.raw_ops.Relu(features=res)
elif act == 'RELU6':
return tf.raw_ops.Relu6(features=res)
elif act == 'TANH':
return tf.raw_ops.Tanh(x=res)
else:
return res
```
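For reference, the composition body above can be reproduced with ordinary eager-mode raw ops. The snippet below is a minimal sketch, not part of the framework; the tensor shapes and the `RELU` choice are illustrative assumptions.

```python
import tensorflow as tf

# Hypothetical shapes: batch of 2, 3 input features, 4 output units.
input_ = tf.random.uniform([2, 3])
filter_ = tf.random.uniform([4, 3])  # transposed inside MatMul, as in the composition
bias = tf.random.uniform([4])

# Same op sequence as the composition body with act='RELU'.
res = tf.raw_ops.MatMul(a=input_, b=filter_, transpose_a=False, transpose_b=True)
res = tf.raw_ops.Add(x=res, y=bias)
res = tf.raw_ops.Relu(features=res)
print(res.shape)  # (2, 4)
```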
Besides defining new ops, composition can be specified for an existing op
for portability. The following code defines the semantics of `AddNOp`:
```python
@Composite('AddNOp')
def _my_op_c(ins):
N = len(ins)
if N == 1:
return ins[0]
sum = ins[0]
for i in range(1, N):
sum += ins[i]
return sum
```
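As a quick sanity check (a sketch under eager mode, not part of the project), the composition above should agree numerically with the built-in `AddN` kernel for any list of same-shaped tensors:

```python
import tensorflow as tf

ins = [tf.constant([1.0, 2.0]), tf.constant([3.0, 4.0]), tf.constant([5.0, 6.0])]

# Built-in kernel.
expected = tf.raw_ops.AddN(inputs=ins)

# The composition's logic, written out directly.
total = ins[0]
for t in ins[1:]:
    total += t

assert bool(tf.reduce_all(tf.equal(expected, total)))  # both are [9.0, 12.0]
```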
Utilities have been built to compile the Python composition functions down to
the backend IR. The project also provides a set of graph optimization passes to
expand the composite ops in the graph by using the input backend IR. These
passes have been added to the TF
[common runtime](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/common_runtime)
for graph execution and
[eager_runtime](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/common_runtime/eager)
for eager execution.
## Compiling Op Composition
### Ahead-Of-Time (AOT) mode
Like the op kernels, the op composition can be pre-compiled to the backend IR
so the decomposition can be invoked at runtime. A Python [define_op_template.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tfr/define_op_template.py)
file is provided as an example to build composite ops in the user's project
directory. All the targets required to build the new ops are created by the
following target:
```BUILD
load("//tensorflow/compiler/mlir/tfr:build_defs.bzl", "gen_op_libraries")
gen_op_libraries(
name = "test_ops",
src = "define_op_template.py",
deps = [
"//third_party/py/tensorflow",
],
)
```
More composite op definitions and usages are included in the
[examples](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/mlir/tfr/examples)
directory.
### Just-In-Time (JIT) mode
(TODO)
## Known Limitations
* `while` statement
* the condition of an `if` statement cannot be a tensor
## Team
* Feng Liu
* Dan Moldovan

View File

@ -44,6 +44,10 @@ extern "C" CUmodule mgpuModuleLoad(void *data) {
return module;
}
extern "C" void mgpuModuleUnload(CUmodule module) {
CUDA_REPORT_IF_ERROR(cuModuleUnload(module));
}
extern "C" CUfunction mgpuModuleGetFunction(CUmodule module, const char *name) {
CUfunction function = nullptr;
CUDA_REPORT_IF_ERROR(cuModuleGetFunction(&function, module, name));
@ -64,14 +68,13 @@ extern "C" void mgpuLaunchKernel(CUfunction function, intptr_t gridX,
}
extern "C" CUstream mgpuStreamCreate() {
static CUstream stream = []() {
// TODO(b/170649852): This is neither thread-safe nor does it handle
// creation/destruction of one stream per context.
CUstream stream = nullptr;
CUDA_REPORT_IF_ERROR(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING));
return stream;
}();
return stream;
}
extern "C" void mgpuStreamDestroy(CUstream stream) {
CUDA_REPORT_IF_ERROR(cuStreamDestroy(stream));
}
extern "C" void mgpuStreamSynchronize(CUstream stream) {

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <memory>
#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
#include "mlir/Dialect/SCF/Transforms.h" // from @llvm-project
#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
#include "mlir/Dialect/Shape/Transforms/Passes.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
@ -107,6 +108,8 @@ struct BufferizePass : public BufferizePassBase<BufferizePass> {
populateStandardBufferizePattern(&context, &converter, &patterns);
populateShapeStructuralTypeConversionsAndLegality(&context, converter,
patterns, target);
scf::populateSCFStructuralTypeConversionsAndLegality(&context, converter,
patterns, target);
patterns.insert<UnrankedTensorStoreTestOnlyPattern>(&context);
auto module = getOperation();

View File

@ -218,7 +218,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
auto attr = CreateDenseElementsAttrFromLiteral(literal, *builder_);
if (!attr.ok()) return attr.status();
mlir::Operation* new_operation =
func_builder->create<mlir::ConstantOp>(loc, attr.ValueOrDie());
func_builder->create<mlir::mhlo::ConstOp>(loc, attr.ValueOrDie());
for (auto attr : attributes) {
new_operation->setAttr(attr.first, attr.second);
}

View File

@ -12,7 +12,11 @@ glob_lit_tests(
data = [":test_utilities"],
driver = "@llvm-project//mlir:run_lit.sh",
tags_override = {
"hlo_to_lhlo_with_xla/gpu_ops.mlir": tf_cuda_tests_tags() + ["noasan"], # TODO(b/171751580)
"hlo_to_lhlo_with_xla/gpu_ops.mlir": tf_cuda_tests_tags() + [
"noasan",
"nomsan",
"noubsan",
], # b/171751580
},
test_file_exts = [
"mlir",

View File

@ -26,10 +26,10 @@ ENTRY %indexed_conditional () -> f32[] {
}
// CHECK-LABEL: func @main() -> tensor<f32>
// CHECK: %[[INDEX:.*]] = constant dense<1> : tensor<i32>
// CHECK: %[[OPERAND_1:.*]] = constant dense<5.600000e+01> : tensor<f32>
// CHECK: %[[OPERAND_2:.*]] = constant dense<1.200000e+01> : tensor<f32>
// CHECK: %[[OPERAND_3:.*]] = constant dense<1.300000e+01> : tensor<f32>
// CHECK: %[[INDEX:.*]] = mhlo.constant dense<1> : tensor<i32>
// CHECK: %[[OPERAND_1:.*]] = mhlo.constant dense<5.600000e+01> : tensor<f32>
// CHECK: %[[OPERAND_2:.*]] = mhlo.constant dense<1.200000e+01> : tensor<f32>
// CHECK: %[[OPERAND_3:.*]] = mhlo.constant dense<1.300000e+01> : tensor<f32>
// CHECK: %[[RESULT:.*]] = "mhlo.case"(%[[INDEX]], %[[OPERAND_1]], %[[OPERAND_2]], %[[OPERAND_3]]) ( {
// CHECK: ^bb0(%[[ARG_1:.*]]: tensor<f32>):
// CHECK: %[[RES_1:.*]] = "mhlo.negate"(%[[ARG_1]]) : (tensor<f32>) -> tensor<f32>

View File

@ -4,100 +4,102 @@
HloModule tfcompile.48
// CHECK-LABEL: func @main(%arg0: tensor<1x300xf32>, %arg1: tensor<1x300x3x1xf32>) -> tuple<tensor<300x1x5xf32>> {
// CHECK-LABEL: func @main(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x300xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x300x3x1xf32>) -> tuple<tensor<300x1x5xf32>> {
ENTRY %tfcompile.48 {
%arg0.1 = f32[1,300] parameter(0)
%arg1.2 = f32[1,300,3,1] parameter(1)
// CHECK-NEXT: %0 = "mhlo.reshape"(%arg0) : (tensor<1x300xf32>) -> tensor<1x300xf32>
// CHECK-NEXT: %[[VAL_2:.*]] = "mhlo.reshape"(%[[VAL_0]]) : (tensor<1x300xf32>) -> tensor<1x300xf32>
%reshape.3 = f32[1,300] reshape(%arg0.1)
// CHECK-NEXT: %1 = "mhlo.transpose"(%0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x300xf32>) -> tensor<300x1xf32>
// CHECK-NEXT: %[[VAL_3:.*]] = "mhlo.transpose"(%[[VAL_2]]) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x300xf32>) -> tensor<300x1xf32>
%transpose.27 = f32[300,1] transpose(%reshape.3), dimensions={1,0}
// CHECK-NEXT: %2 = "mhlo.reshape"(%1) : (tensor<300x1xf32>) -> tensor<300x1x1xf32>
// CHECK-NEXT: %[[VAL_4:.*]] = "mhlo.reshape"(%[[VAL_3]]) : (tensor<300x1xf32>) -> tensor<300x1x1xf32>
%reshape.28 = f32[300,1,1] reshape(%transpose.27)
// CHECK-NEXT: %3 = "mhlo.reshape"(%2) : (tensor<300x1x1xf32>) -> tensor<300x1xf32>
// CHECK-NEXT: %[[VAL_5:.*]] = "mhlo.reshape"(%[[VAL_4]]) : (tensor<300x1x1xf32>) -> tensor<300x1xf32>
%reshape.29 = f32[300,1] reshape(%reshape.28)
// CHECK-NEXT: %4 = "mhlo.broadcast_in_dim"(%3) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<300x1xf32>) -> tensor<300x1x5xf32>
// CHECK-NEXT: %[[VAL_6:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_5]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<300x1xf32>) -> tensor<300x1x5xf32>
%broadcast.30 = f32[300,1,5] broadcast(%reshape.29), dimensions={0,1}
// CHECK-NEXT: %cst = constant dense<1.000000e+00> : tensor<f32>
// CHECK-NEXT: %[[VAL_7:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
%constant.8 = f32[] constant(1)
// CHECK-NEXT: %5 = "mhlo.broadcast_in_dim"(%cst) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<300x1x5xf32>
// CHECK-NEXT: %[[VAL_8:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_7]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<300x1x5xf32>
%broadcast.9 = f32[300,1,5] broadcast(%constant.8), dimensions={}
// CHECK-NEXT: %6 = mhlo.multiply %4, %5 : tensor<300x1x5xf32>
// CHECK-NEXT: %[[VAL_9:.*]] = mhlo.multiply %[[VAL_6]], %[[VAL_8]] : tensor<300x1x5xf32>
%multiply.31 = f32[300,1,5] multiply(%broadcast.30, %broadcast.9)
// CHECK-NEXT: %cst_0 = constant dense<0.000000e+00> : tensor<f32>
// CHECK-NEXT: %[[VAL_10:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
%constant.32 = f32[] constant(0)
// CHECK-NEXT: %7 = "mhlo.broadcast_in_dim"(%cst_0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<300x1x5xf32>
// CHECK-NEXT: %[[VAL_11:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_10]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<300x1x5xf32>
%broadcast.33 = f32[300,1,5] broadcast(%constant.32), dimensions={}
// CHECK-NEXT: %8 = "mhlo.compare"(%6, %7) {comparison_direction = "GT"} : (tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xi1>
// CHECK-NEXT: %[[VAL_12:.*]] = "mhlo.compare"(%[[VAL_9]], %[[VAL_11]]) {comparison_direction = "GT"} : (tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xi1>
%compare.34 = pred[300,1,5] compare(%multiply.31, %broadcast.33), direction=GT
// CHECK-NEXT: %cst_1 = constant dense<0.000000e+00> : tensor<f32>
// CHECK-NEXT: %[[VAL_13:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
%constant.10 = f32[] constant(0)
// CHECK-NEXT: %9 = "mhlo.broadcast_in_dim"(%cst_1) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<300x1x5xf32>
// CHECK-NEXT: %[[VAL_14:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_13]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<300x1x5xf32>
%broadcast.11 = f32[300,1,5] broadcast(%constant.10), dimensions={}
// CHECK-NEXT: %cst_2 = constant dense<0.000000e+00> : tensor<f32>
// CHECK-NEXT: %[[VAL_15:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
%constant.40 = f32[] constant(0)
// CHECK-NEXT: %10 = "mhlo.broadcast_in_dim"(%cst_2) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<300x5xf32>
// CHECK-NEXT: %[[VAL_16:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_15]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<300x5xf32>
%broadcast.41 = f32[300,5] broadcast(%constant.40), dimensions={}
// CHECK-NEXT: %11 = "mhlo.copy"(%arg1) : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32>
// CHECK-NEXT: %[[VAL_17:.*]] = "mhlo.copy"(%[[VAL_1]]) : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32>
%copy.1 = f32[1,300,3,1] copy(%arg1.2)
// CHECK-NEXT: %12 = "mhlo.reshape"(%11) : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32>
// CHECK-NEXT: %[[VAL_18:.*]] = "mhlo.reshape"(%[[VAL_17]]) : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32>
%reshape.4 = f32[1,300,3,1] reshape(%copy.1)
// CHECK-NEXT: %13 = "mhlo.reshape"(%12) : (tensor<1x300x3x1xf32>) -> tensor<1x300x3xf32>
// CHECK-NEXT: %[[VAL_19:.*]] = "mhlo.reshape"(%[[VAL_18]]) : (tensor<1x300x3x1xf32>) -> tensor<1x300x3xf32>
%reshape.24 = f32[1,300,3] reshape(%reshape.4)
// CHECK-NEXT: %14 = "mhlo.transpose"(%13) {permutation = dense<[1, 0, 2]> : tensor<3xi64>} : (tensor<1x300x3xf32>) -> tensor<300x1x3xf32>
// CHECK-NEXT: %[[VAL_20:.*]] = "mhlo.transpose"(%[[VAL_19]]) {permutation = dense<[1, 0, 2]> : tensor<3xi64>} : (tensor<1x300x3xf32>) -> tensor<300x1x3xf32>
%transpose.25 = f32[300,1,3] transpose(%reshape.24), dimensions={1,0,2}
// CHECK-NEXT: %15 = "mhlo.reshape"(%14) : (tensor<300x1x3xf32>) -> tensor<300x3xf32>
// CHECK-NEXT: %[[VAL_21:.*]] = "mhlo.reshape"(%[[VAL_20]]) : (tensor<300x1x3xf32>) -> tensor<300x3xf32>
%reshape.26 = f32[300,3] reshape(%transpose.25)
// CHECK-NEXT: %cst_3 = constant dense<{{\[\[}}-1.060230e-01, 1.215050e-01, 8.002390e-01, -7.688850e-01, 0.0966112986], [6.890140e-01, -4.070560e-01, -0.797852993, 3.789250e-03, -2.088810e-01], [-6.085290e-01, 2.766170e-02, 2.685570e-01, 5.774010e-01, -4.284370e-01]]> : tensor<3x5xf32>
// CHECK-NEXT: %[[VAL_22:.*]] = mhlo.constant dense<{{\[\[}}-1.060230e-01, 1.215050e-01, 8.002390e-01, -7.688850e-01, 0.0966112986], [6.890140e-01, -4.070560e-01, -0.797852993, 3.789250e-03, -2.088810e-01], [-6.085290e-01, 2.766170e-02, 2.685570e-01, 5.774010e-01, -4.284370e-01]]> : tensor<3x5xf32>
%constant.35 = f32[3,5] constant({ { -0.106023, 0.121505, 0.800239, -0.768885, 0.0966113 }, { 0.689014, -0.407056, -0.797853, 0.00378925, -0.208881 }, { -0.608529, 0.0276617, 0.268557, 0.577401, -0.428437 } })
// TODO(b/129709049) consider making this default precision config implied.
// CHECK-NEXT: %16 = "mhlo.dot"(%15, %cst_3) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32>
// CHECK-NEXT: %[[VAL_23:.*]] = "mhlo.dot"(%[[VAL_21]], %[[VAL_22]]) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32>
%dot.36 = f32[300,5] dot(%reshape.26, %constant.35), lhs_contracting_dims={1}, rhs_contracting_dims={0}
// CHECK-NEXT: %cst_4 = constant dense<0.000000e+00> : tensor<5xf32>
// CHECK-NEXT: %[[VAL_24:.*]] = mhlo.constant dense<0.000000e+00> : tensor<5xf32>
%constant.37 = f32[5]{0} constant({0, 0, 0, 0, 0})
// CHECK-NEXT: %17 = "mhlo.broadcast_in_dim"(%cst_4) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<5xf32>) -> tensor<300x5xf32>
// CHECK-NEXT: %[[VAL_25:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_24]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<5xf32>) -> tensor<300x5xf32>
%broadcast.38 = f32[300,5] broadcast(%constant.37), dimensions={1}
// CHECK-NEXT: %18 = mhlo.add %16, %17 : tensor<300x5xf32>
// CHECK-NEXT: %[[VAL_26:.*]] = mhlo.add %[[VAL_23]], %[[VAL_25]] : tensor<300x5xf32>
%add.39 = f32[300,5] add(%dot.36, %broadcast.38)
// CHECK-NEXT: %19 = mhlo.maximum %10, %18 : tensor<300x5xf32>
// CHECK-NEXT: %[[VAL_27:.*]] = mhlo.maximum %[[VAL_16]], %[[VAL_26]] : tensor<300x5xf32>
%maximum.42 = f32[300,5] maximum(%broadcast.41, %add.39)
// CHECK-NEXT: %20 = "mhlo.reshape"(%19) : (tensor<300x5xf32>) -> tensor<300x1x5xf32>
// CHECK-NEXT: %[[VAL_28:.*]] = "mhlo.reshape"(%[[VAL_27]]) : (tensor<300x5xf32>) -> tensor<300x1x5xf32>
%reshape.44 = f32[300,1,5] reshape(%maximum.42)
// CHECK-NEXT: %21 = "mhlo.select"(%8, %9, %20) : (tensor<300x1x5xi1>, tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xf32>
// CHECK-NEXT: %[[VAL_29:.*]] = "mhlo.select"(%[[VAL_12]], %[[VAL_14]], %[[VAL_28]]) : (tensor<300x1x5xi1>, tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xf32>
%select.45 = f32[300,1,5] select(%compare.34, %broadcast.11, %reshape.44)
// CHECK-NEXT: %22 = "mhlo.reshape"(%21) : (tensor<300x1x5xf32>) -> tensor<300x1x5xf32>
// CHECK-NEXT: %[[VAL_30:.*]] = "mhlo.reshape"(%[[VAL_29]]) : (tensor<300x1x5xf32>) -> tensor<300x1x5xf32>
%reshape.46 = f32[300,1,5] reshape(%select.45)
// CHECK-NEXT: %23 = "mhlo.tuple"(%22) : (tensor<300x1x5xf32>) -> tuple<tensor<300x1x5xf32>>
// CHECK-NEXT: return %23 : tuple<tensor<300x1x5xf32>>
// CHECK-NEXT: %[[VAL_31:.*]] = "mhlo.tuple"(%[[VAL_30]]) : (tensor<300x1x5xf32>) -> tuple<tensor<300x1x5xf32>>
// CHECK-NEXT: return %[[VAL_31]] : tuple<tensor<300x1x5xf32>>
ROOT %tuple.47 = (f32[300,1,5]) tuple(%reshape.46)
}

View File

@ -20,7 +20,7 @@ HloModule tfcompile.20
ENTRY %tfcompile.20 {
%arg0.1 = f32[] parameter(0), parameter_replication={false}, metadata={op_name="XLA_Args"}
// CHECK: [[C0:%.+]] = constant
// CHECK: [[C0:%.+]] = mhlo.constant
%constant.3 = f32[] constant(10), metadata={op_type="Less" op_name="Less"}
// CHECK: [[R1:%.+]] = "mhlo.compare"([[A0]], [[C0]])

View File

@ -176,48 +176,49 @@ add {
%test_constant {
// Scalar/0D tensor constant
// CHECK-NEXT: %cst = constant dense<1> : tensor<i64>
// CHECK-NEXT: %[[VAL_0:.*]] = mhlo.constant dense<1> : tensor<i64>
%constant.0 = s64[] constant(1)
// Note that double brackets "[[" have to be escaped as they denote variables
// in FileCheck. The only way to do so is to drop into regex with "{{"
// CHECK-NEXT: constant dense<{{\[\[\[\[}}1.000000e+00]], {{\[\[}}2.000000e+00]]], {{\[\[\[}}3.000000e+00]], {{\[\[}}4.000000e+00]]]]> : tensor<2x2x1x1xf32>
// CHECK-NEXT: %[[VAL_1:.*]] = mhlo.constant dense<{{\[\[}}{{\[\[}}1.000000e+00]], {{\[\[}}2.000000e+00]]], {{\[\[}}[3.000000e+00]], {{\[\[}}4.000000e+00]]]]> : tensor<2x2x1x1xf32>
%constant.1 = f32[2,2,1,1]{3,2,1,0} constant({{{{1.0}},{{2.0}}},{{{3.0}},{{4.0}}}}), metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"}
// CHECK: dense<[1, 2, 4, 8]> : tensor<4xui64>
// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<[1, 2, 4, 8]> : tensor<4xui64>
%constant.2 = u64[4] constant({ 1, 2, 4, 8 })
// CHECK: dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xbf16>
// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xbf16>
%constant.3 = bf16[4] constant({1, 2, 3, 4})
// CHECK: dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f32>>
// CHECK: %[[VAL_4:.*]] = mhlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f32>>
%constant.4 = c64[] constant((1, 0))
// CHECK: dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f64>>
// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f64>>
%constant.5 = c128[] constant((1, 0))
// CHECK: dense<[1.000000e+00, -4.000000e+00, -6.550400e+04, 1.562500e-02]> : tensor<4xf16>
// CHECK: %[[VAL_6:.*]] = mhlo.constant dense<[1.000000e+00, -4.000000e+00, -6.550400e+04, 1.562500e-02]> : tensor<4xf16>
ROOT %constant.6 = f16[4] constant({1, -4, -65504, 0.015625})
}
// TODO(b/129422361) Potentially update when copy, reshape, and conv have actual
// implementations with attributes, etc.
// CHECK-LABEL: func @test_conv(%arg0: tensor<256x32x32x6xf32>) -> tuple<tensor<256x30x30x16xf32>>
// CHECK-LABEL: func @test_conv(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<256x32x32x6xf32>) -> tuple<tensor<256x30x30x16xf32>> attributes {sym_visibility = "private"} {
%test_conv {
%arg0.1 = f32[256,32,32,6]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"}
// CHECK-NEXT: %0 = "mhlo.copy"(%arg0) {minor_to_major = dense<[2, 1, 3, 0]> : tensor<4xindex>} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32>
// CHECK-NEXT: %[[VAL_1:.*]] = "mhlo.copy"(%[[VAL_0]]) {minor_to_major = dense<[2, 1, 3, 0]> : tensor<4xindex>} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32>
%copy.1 = f32[256,32,32,6]{2,1,3,0} copy(%arg0.1), metadata={op_name="HLO_Args"}
// CHECK-NEXT: %1 = "mhlo.reshape"(%0) {minor_to_major = dense<[2, 1, 3, 0]> : tensor<4xindex>} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32>
// CHECK-NEXT: %[[VAL_2:.*]] = "mhlo.reshape"(%[[VAL_1]]) {minor_to_major = dense<[2, 1, 3, 0]> : tensor<4xindex>} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32>
%reshape.2 = f32[256,32,32,6]{2,1,3,0} reshape(%copy.1)
// Note that double brackets "[[" have to be escaped as they denote variables
// in FileCheck. The only way to do so is to drop into regex with "{{"
// CHECK-NEXT: %cst = constant dense<{{\[\[\[\[}}5.000000e-01]], {{\[\[}}-6.000000e-01]]], {{\[\[\[}}3.000000e-01]], {{\[\[}}-1.000000e-01]]]]> : tensor<2x2x1x1xf32>
// CHECK-NEXT: %[[VAL_3:.*]] = mhlo.constant dense<{{\[\[}}{{\[\[}}5.000000e-01]], {{\[\[}}-6.000000e-01]]], {{\[\[}}[3.000000e-01]], {{\[\[}}-1.000000e-01]]]]> : tensor<2x2x1x1xf32>
%constant.3 = f32[2,2,1,1]{3,2,1,0} constant({{{{0.5}}, {{-0.6}}}, {{{0.3}}, {{-0.1}}}}), metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"}
// CHECK-NEXT: %2 = "mhlo.convolution"(%1, %cst) {
// CHECK-NEXT: %[[VAL_4:.*]] = "mhlo.convolution"(%[[VAL_2]], %[[VAL_3]]) {
// CHECK-SAME: batch_group_count = 1 : i64
// CHECK-SAME: dimension_numbers = {
// CHECK-SAME: input_batch_dimension = 0 : i64
@ -241,10 +242,10 @@ add {
%convolution.4 = f32[16,30,30,256]{2,1,3,0} convolution(%reshape.2, %constant.3), window={size=3x3 stride=4x5 pad=44_45x60_60 rhs_dilate=2x3}, dim_labels=b01f_01io->f01b, metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"}
// CHECK-NEXT: %3 = "mhlo.reshape"(%2) : (tensor<16x30x30x256xf32>) -> tensor<256x30x30x16xf32>
// CHECK-NEXT: %[[VAL_5:.*]] = "mhlo.reshape"(%[[VAL_4]]) : (tensor<16x30x30x256xf32>) -> tensor<256x30x30x16xf32>
%reshape.5 = f32[256,30,30,16]{3,2,1,0} reshape(%convolution.4), metadata={op_name="HLO_Retvals"}
// CHECK-NEXT: "mhlo.tuple"(%3) : (tensor<256x30x30x16xf32>) -> tuple<tensor<256x30x30x16xf32>>
// CHECK-NEXT: %[[VAL_6:.*]] = "mhlo.tuple"(%[[VAL_5]]) : (tensor<256x30x30x16xf32>) -> tuple<tensor<256x30x30x16xf32>>
ROOT %tuple.6 = (f32[256,30,30,16]{3,2,1,0}) tuple(%reshape.5), metadata={op_name="HLO_Retvals"}
}

View File

@ -4,25 +4,25 @@ HloModule tfcompile.1
// CHECK-LABEL: func @main() -> tensor<i1> {
ENTRY %tfcompile.1 {
// CHECK-NEXT: %cst = constant dense<1.000000e+00> : tensor<f32>
// CHECK-NEXT: %[[VAL_0:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
%constant.0 = f32[] constant(1)
// CHECK-NEXT: %cst_0 = constant dense<1.000000e+00> : tensor<f64>
// CHECK-NEXT: %[[VAL_1:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f64>
%constant.1 = f64[] constant(1)
// CHECK-NEXT: %cst_1 = constant dense<1> : tensor<i8>
// CHECK-NEXT: %[[VAL_2:.*]] = mhlo.constant dense<1> : tensor<i8>
%constant.2 = s8[] constant(1)
// CHECK-NEXT: %cst_2 = constant dense<1> : tensor<i16>
// CHECK-NEXT: %[[VAL_3:.*]] = mhlo.constant dense<1> : tensor<i16>
%constant.3 = s16[] constant(1)
// CHECK-NEXT: %cst_3 = constant dense<1> : tensor<i32>
// CHECK-NEXT: %[[VAL_4:.*]] = mhlo.constant dense<1> : tensor<i32>
%constant.4 = s32[] constant(1)
// CHECK-NEXT: %cst_4 = constant dense<1> : tensor<i64>
// CHECK-NEXT: %[[VAL_5:.*]] = mhlo.constant dense<1> : tensor<i64>
%constant.5 = s64[] constant(1)
// CHECK-NEXT: %cst_5 = constant dense<true> : tensor<i1>
// CHECK-NEXT: return %cst_5 : tensor<i1>
// CHECK-NEXT: %[[VAL_6:.*]] = mhlo.constant dense<true> : tensor<i1>
// CHECK-NEXT: return %[[VAL_6]] : tensor<i1>
ROOT %constant.6 = pred[] constant(1)
}

View File

@ -351,6 +351,7 @@ cc_library(
":xla_op_registry",
":xla_resource",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:span",
"@com_google_absl//absl/types:variant",

View File

@ -91,15 +91,15 @@ class DataFormatVecPermuteOp : public XlaOpKernel {
: XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("src_format", &src_format_));
OP_REQUIRES(
ctx, src_format_.size() == 4,
errors::InvalidArgument("Data format should have 4 characters"));
ctx, src_format_.size() == 4 || src_format_.size() == 5,
errors::InvalidArgument("Data format should have 4 or 5 characters"));
TensorFormat data_format;
OP_REQUIRES(ctx, FormatFromString(src_format_, &data_format),
errors::InvalidArgument("Invalid data format"));
OP_REQUIRES_OK(ctx, ctx->GetAttr("dst_format", &dst_format_));
OP_REQUIRES(
ctx, dst_format_.size() == 4,
errors::InvalidArgument("Data format should have 4 characters"));
ctx, dst_format_.size() == 4 || dst_format_.size() == 5,
errors::InvalidArgument("Data format should have 4 or 5 characters"));
OP_REQUIRES(ctx, FormatFromString(dst_format_, &data_format),
errors::InvalidArgument("Invalid data format"));
}
@ -113,9 +113,10 @@ class DataFormatVecPermuteOp : public XlaOpKernel {
input_tensor_shape.DebugString()));
const int dim0 = input_tensor_shape.dim_size(0);
OP_REQUIRES(
ctx, dim0 == 2 || dim0 == 4,
ctx, dim0 == 2 || dim0 == 4 || dim0 == 5,
errors::InvalidArgument(
"First dimension of input must be of size 4, but got shape ",
"First dimension of input must be of size 2, 4 or 5, but got "
"shape ",
input_tensor_shape.DebugString()));
if (input_rank == 2) {
OP_REQUIRES(
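For reference, a minimal standalone sketch (not from this commit; the helper name is hypothetical) of the permutation this kernel performs, which is why only the format length (4, e.g. NHWC, or 5, e.g. NDHWC) and the leading input dimension (2, 4, or 5) need validating:
#include <string>
#include <vector>
// Reorders `values`, laid out per `src` (e.g. "NDHWC"), into the order given
// by `dst` (e.g. "NCDHW"). The real kernel expresses the same mapping with
// XLA ops; this is only the index arithmetic, assuming both formats are valid.
std::vector<int> PermuteByFormat(const std::string& src, const std::string& dst,
                                 const std::vector<int>& values) {
  std::vector<int> out(dst.size());
  for (size_t d = 0; d < dst.size(); ++d) {
    out[d] = values[src.find(dst[d])];  // position of dst's d-th axis in src
  }
  return out;
}
// Example: PermuteByFormat("NDHWC", "NCDHW", {8, 4, 16, 16, 3})
//          returns {8, 3, 4, 16, 16}.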

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <numeric>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/types/variant.h"
#include "tensorflow/compiler/jit/defs.h"
@ -675,6 +676,38 @@ std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
return graph;
}
// Collects all control rets from `orig_control_ret_nodes` that are still valid,
// keeping the same order.
std::vector<std::string> GetValidControlRets(
absl::Span<Node* const> orig_control_ret_nodes, const Graph& graph) {
// Build map from control ret node to index.
absl::flat_hash_map<const Node*, int> control_ret_nodes_map;
for (int i = 0; i < orig_control_ret_nodes.size(); ++i) {
const Node* n = orig_control_ret_nodes[i];
control_ret_nodes_map[n] = i;
}
// Check which control rets are still valid.
std::vector<bool> is_valid_control_ret(orig_control_ret_nodes.size(), false);
int num_valid_control_rets = 0;
for (const Node* n : graph.nodes()) {
auto iter = control_ret_nodes_map.find(n);
if (iter != control_ret_nodes_map.end()) {
++num_valid_control_rets;
is_valid_control_ret[iter->second] = true;
}
}
// Return valid control rets in same order as they appear in
// `orig_control_ret_nodes`.
std::vector<std::string> valid_control_rets;
valid_control_rets.reserve(num_valid_control_rets);
for (int i = 0; i < orig_control_ret_nodes.size(); ++i) {
if (is_valid_control_ret[i]) {
valid_control_rets.push_back(orig_control_ret_nodes[i]->name());
}
}
return valid_control_rets;
}
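A toy illustration of the contract above (plain std containers, hypothetical names; not part of the commit): control-ret names whose nodes are no longer present in the graph are dropped, and the survivors keep their original relative order.
#include <string>
#include <unordered_set>
#include <vector>
std::vector<std::string> ValidControlRets(
    const std::vector<std::string>& orig_order,
    const std::unordered_set<std::string>& still_in_graph) {
  std::vector<std::string> out;
  for (const std::string& name : orig_order) {
    if (still_in_graph.count(name)) out.push_back(name);  // order preserved
  }
  return out;
}
// ValidControlRets({"a", "b", "c"}, {"c", "a"}) returns {"a", "c"}.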
Status XlaCompiler::CompileFunction(
const XlaCompiler::CompileOptions& options,
const NameAttrList& fn_name_attrs,
@ -765,15 +798,15 @@ Status XlaCompiler::CompileFunction(
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) {
VLOG(1) << "Using MLIR bridge";
GraphDebugInfo debug_info;
std::vector<std::string> control_rets;
for (const auto* control_ret_node : fbody->control_ret_nodes) {
control_rets.push_back(control_ret_node->name());
}
std::vector<std::string> valid_control_rets =
GetValidControlRets(fbody->control_ret_nodes, *graph);
TF_RETURN_IF_ERROR(CompileGraphToXlaHlo(
std::move(*graph), mlir::SpanToArrayRef<XlaCompiler::Argument>(args),
control_rets, options_.device_type.type_string(), options.use_tuple_arg,
*options_.flib_def, debug_info, options_.shape_representation_fn,
result));
valid_control_rets, options_.device_type.type_string(),
options.use_tuple_arg, *options_.flib_def, debug_info,
options_.shape_representation_fn, result));
} else {
TF_RETURN_IF_ERROR(
CompileGraph(options, function_id, std::move(graph), args, result));

View File

@ -5220,13 +5220,31 @@ StatusOr<bool> AlgebraicSimplifierVisitor::SwapConvOperands(
for (int64 spatial_dim = 0;
spatial_dim < dnums.input_spatial_dimensions_size(); ++spatial_dim) {
const int64 kernel_size = window_dims[spatial_dim].size();
const int64 dilated_kernel_size =
1 + (kernel_size - 1) * window_dims[spatial_dim].window_dilation();
const bool can_be_group_or_contraction =
!window_dims[spatial_dim].window_reversal() &&
window_dims[spatial_dim].padding_low() == 0 &&
window_dims[spatial_dim].padding_high() == 0 &&
window_dims[spatial_dim].window_dilation() == 1;
const bool is_group_dim =
can_be_group_or_contraction &&
window_dims[spatial_dim].base_dilation() == kernel_size &&
window_dims[spatial_dim].stride() == kernel_size - 1;
const int64 input_size =
input->shape().dimensions(dnums.input_spatial_dimensions(spatial_dim));
const bool is_pure_contraction_dim =
kernel_size == input_size && can_be_group_or_contraction &&
window_dims[spatial_dim].base_dilation() == 1 &&
window_dims[spatial_dim].stride() == 1;
if (is_group_dim || is_pure_contraction_dim) {
*(swapped_window.add_dimensions()) = window_dims[spatial_dim];
continue;
}
const int64 dilated_kernel_size =
1 + (kernel_size - 1) * window_dims[spatial_dim].window_dilation();
const int64 dilated_input_size =
1 + (input_size - 1) * window_dims[spatial_dim].base_dilation();
// Don't decide to swap if the input size is one, since many convolution
// implementations can easily handle that special case efficiently.
kernel_product *= kernel_size;
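A small arithmetic sketch of the dilation formula used above (standalone example, not commit code): n elements spaced d apart span 1 + (n - 1) * d positions, which is what dilated_kernel_size and dilated_input_size compute.
#include <cstdint>
#include <iostream>
int64_t DilatedSize(int64_t n, int64_t d) { return 1 + (n - 1) * d; }
int main() {
  std::cout << DilatedSize(3, 2) << "\n";  // kernel_size 3, window_dilation 2 -> 5
  std::cout << DilatedSize(4, 3) << "\n";  // input_size 4, base_dilation 3    -> 10
  return 0;
}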

View File

@ -6654,6 +6654,32 @@ TEST_F(AlgebraicSimplifierTest, AbsEliminationMultiply) {
GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0))));
}
TEST_F(AlgebraicSimplifierTest, BroadcastCompareSimplification) {
std::string module_string = R"(
HloModule m
test {
a = s32[] parameter(0)
b = s32[] parameter(1)
x = s32[10]{0} parameter(2)
broadcast_a = s32[10]{0} broadcast(a), dimensions={}
broadcast_b = s32[10]{0} broadcast(b), dimensions={}
add = s32[10]{0} add(broadcast_a, x)
ROOT cmp = pred[10]{0} compare(add, broadcast_b), direction=EQ
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_string));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::Compare(m::Parameter(2),
m::Broadcast(m::Subtract(
m::Parameter(1), m::Parameter(0))))));
// Numerically unstable transformation shouldn't be applied to floating types.
std::string module_string_f32 =
absl::StrReplaceAll(module_string, {{"s32", "f32"}});
TF_ASSERT_OK_AND_ASSIGN(m, ParseAndReturnVerifiedModule(module_string_f32));
ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
}
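A standalone check (not commit code) of why the rewrite compare(add(broadcast(a), x), broadcast(b)) => compare(x, broadcast(b - a)) is restricted to integer types: floating-point rounding can make the two forms disagree.
#include <iostream>
int main() {
  // Integers: the two forms agree for these values (and HLO integer
  // arithmetic wraps consistently on both sides).
  int ia = 7, ix = 5, ib = 12;
  std::cout << ((ia + ix == ib) == (ix == ib - ia)) << "\n";  // prints 1

  // Floats: 1e8f + 1.0f rounds back to 1e8f, so the original compare is true
  // while the rewritten compare (1.0f == 0.0f) is false.
  float a = 1e8f, x = 1.0f, b = 1e8f;
  std::cout << (a + x == b) << " " << (x == b - a) << "\n";  // prints 1 0
  return 0;
}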
TEST_F(AlgebraicSimplifierTest, AbsEliminationPower2) {
const char* kModuleStr = R"(
HloModule m

View File

@ -63,7 +63,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
@ -2200,20 +2199,21 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) {
if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, assignment_)) {
VLOG(3) << "HandleFusion FusedDynamicUpdateSliceInPlace";
CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
FusedIrEmitter fused_emitter(&elemental_emitter);
BindFusionArguments(fusion, &fused_emitter);
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion));
// Delegate to common implementation of fused in-place dynamic-update-slice.
return llvm_ir::EmitFusedDynamicUpdateSliceInPlace(
fusion, GetGeneratorForOperandIrArrays(fusion), GetIrArrayFor(fusion),
&elemental_emitter, &b_);
fusion, GetIrArrayFor(fusion), &fused_emitter, &b_);
} else if (fusion->IsLoopFusion()) {
VLOG(3) << "HandleFusion kLoop";
CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
auto operands = GetIrArraysForOperandsOf(fusion);
FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(fusion),
&elemental_emitter);
TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter));
return EmitTargetElementLoop(fusion, fused_emitter.GetRootGenerator());
FusedIrEmitter fused_emitter(&elemental_emitter);
BindFusionArguments(fusion, &fused_emitter);
TF_ASSIGN_OR_RETURN(auto generator, fused_emitter.GetGenerator(
fusion->fused_expression_root()));
return EmitTargetElementLoop(fusion, generator);
} else if (fusion->IsOutputFusion()) {
VLOG(3) << "HandleFusion kOutput";
int64 dot_op_index = root->operand(0)->opcode() == HloOpcode::kDot ? 0 : 1;
@ -3451,5 +3451,17 @@ llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue(
return EmitBufferPointer(root_buffer, root_inst->shape());
}
void IrEmitter::BindFusionArguments(const HloInstruction* fusion,
FusedIrEmitter* fused_emitter) {
for (int i = 0; i < fusion->operand_count(); i++) {
const HloInstruction* operand = fusion->operand(i);
fused_emitter->BindGenerator(
fusion->fused_parameter(i),
[this, operand](llvm_ir::IrArray::Index index) {
return GetIrArrayFor(operand).EmitReadArrayElement(index, &b_);
});
}
}
} // namespace cpu
} // namespace xla

View File

@ -43,6 +43,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
@ -234,10 +235,9 @@ class IrEmitter : public DfsHloVisitorWithDefault,
std::vector<llvm_ir::IrArray> GetIrArraysForOperandsOf(
const HloInstruction* hlo);
GeneratorForOperandIrArrays GetGeneratorForOperandIrArrays(
HloInstruction* unnested_hlo) {
return [=]() { return GetIrArraysForOperandsOf(unnested_hlo); };
}
// Bind all argument IrArrays of `fusion` to `fused_emitter`.
void BindFusionArguments(const HloInstruction* fusion,
FusedIrEmitter* fused_emitter);
// Augments IrArray with aliasing information.
void AddAliasingInformationToIrArray(const HloInstruction& hlo,

View File

@ -38,14 +38,60 @@ FusionNodeIndexingEvaluation::FusionNodeIndexingEvaluation(
// a tradeoff between compilation time and runtime here.
const int64 FusionNodeIndexingEvaluation::kAllowedCodeDuplication = 15;
namespace {
// Returns true if the op invalidates the cache of emitted instructions by
// creating a new BasicBlock and setting the insertion point to the newly
// created BasicBlock. We can only reuse cached values if they were emitted in
// the same BasicBlock as the current BasicBlock.
bool OpInvalidatesCache(const HloInstruction* hlo) {
switch (hlo->opcode()) {
// This list of ops was created by inspecting the code. There is no
// guarantee that it is complete.
case HloOpcode::kConcatenate:
case HloOpcode::kDot:
case HloOpcode::kDynamicUpdateSlice:
case HloOpcode::kPad:
case HloOpcode::kReduce:
case HloOpcode::kReduceWindow:
return true;
default:
return false;
}
}
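A minimal sketch of the caching rule described in the comment above (hypothetical types, not XLA code): a memoized value may only be reused while emission is still in the basic block it was produced in, so ops that open a new block act as cache barriers.
#include <string>
#include <unordered_map>
class EmissionCache {
 public:
  // Called when the emitter starts a new BasicBlock (e.g. for reduce, pad,
  // concatenate); entries recorded in older blocks become unusable.
  void EnterNewBlock() { ++current_block_; }

  bool TryReuse(const std::string& key, int* value) const {
    auto it = cache_.find(key);
    if (it == cache_.end() || it->second.block != current_block_) return false;
    *value = it->second.value;
    return true;
  }

  void Record(const std::string& key, int value) {
    cache_[key] = {current_block_, value};
  }

 private:
  struct Entry { int block; int value; };
  int current_block_ = 0;
  std::unordered_map<std::string, Entry> cache_;
};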
// Counts the number of "real" users of 'hlo'. When 'hlo' has a fusion node as
// user, we consider the users of the fusion parameter corresponding to 'hlo' as
// the real users.
int64 UserCount(const HloInstruction* hlo) {
int64 cnt = 0;
for (HloInstruction* user : hlo->users()) {
if (user->opcode() == HloOpcode::kFusion) {
// Count the number of users of the parameter corresponding to the fusion
// operand.
int64 operand_index = user->operand_index(hlo);
cnt += user->fused_parameter(operand_index)->user_count();
} else {
++cnt;
}
}
return cnt;
}
} // namespace
bool FusionNodeIndexingEvaluation::CodeDuplicationTooHigh(
const HloInstruction* producer) const {
return EvaluateEmittedInstructions(producer) > kAllowedCodeDuplication;
int64 emitted_instructions = EvaluateEmittedInstructions(producer);
return emitted_instructions > kAllowedCodeDuplication ||
(OpInvalidatesCache(producer) &&
(emitted_instructions > 1 || UserCount(producer) > 1));
}
bool FusionNodeIndexingEvaluation::MaxCodeDuplicationTooHigh() const {
for (const auto& entry : index_usage_count_) {
if (entry.second > kAllowedCodeDuplication) {
if (entry.second > kAllowedCodeDuplication ||
(OpInvalidatesCache(entry.first) &&
(entry.second > 1 || UserCount(entry.first) > 1))) {
return true;
}
}

View File

@ -773,11 +773,11 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) {
CHECK_EQ(HloInstruction::FusionKind::kLoop, fusion->fusion_kind());
GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_,
GetNestedComputer());
FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(fusion),
&elemental_emitter);
TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter));
return EmitTargetElementLoop(*fusion, fused_emitter.GetRootGenerator());
FusedIrEmitter fused_emitter(&elemental_emitter);
BindFusionArguments(fusion, &fused_emitter);
TF_ASSIGN_OR_RETURN(auto generator, fused_emitter.GetGenerator(
fusion->fused_expression_root()));
return EmitTargetElementLoop(*fusion, generator);
}
Status IrEmitter::HandleCall(HloInstruction* call) {
@ -876,5 +876,17 @@ std::vector<llvm_ir::IrArray> IrEmitter::ConstructIrArrayForOutputs(
return output_arrays;
}
void IrEmitter::BindFusionArguments(const HloInstruction* fusion,
FusedIrEmitter* fused_emitter) {
for (int i = 0; i < fusion->operand_count(); i++) {
const HloInstruction* operand = fusion->operand(i);
fused_emitter->BindGenerator(
fusion->fused_parameter(i),
[this, operand, fusion](llvm_ir::IrArray::Index index) {
return GetIrArray(*operand, *fusion).EmitReadArrayElement(index, &b_);
});
}
}
} // namespace gpu
} // namespace xla

View File

@ -36,6 +36,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
@ -182,18 +183,9 @@ class IrEmitter : public DfsHloVisitorWithDefault,
const HloModuleConfig& hlo_module_config_;
protected:
GeneratorForOperandIrArrays GetGeneratorForOperandIrArrays(
const HloInstruction* fusion) {
return [=]() {
std::vector<llvm_ir::IrArray> ir_arrays;
ir_arrays.reserve(fusion->operand_count());
absl::c_transform(fusion->operands(), std::back_inserter(ir_arrays),
[&](const HloInstruction* operand) {
return GetIrArray(*operand, *fusion);
});
return ir_arrays;
};
}
// Bind all argument IrArrays of `fusion` to `fused_emitter`.
void BindFusionArguments(const HloInstruction* fusion,
FusedIrEmitter* fused_emitter);
private:
// A helper method for EmitAtomicOperationForNestedComputation. Certain

View File

@ -960,19 +960,24 @@ Status IrEmitterUnnested::EmitLoopFusionFromMlir(MlirEmitterInput input,
GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_,
GetNestedComputer());
FusedIrEmitter fused_emitter(&elemental_emitter);
FusedIrEmitter fused_emitter(
[&] {
for (int i = 0; i < fusion_operands.size(); i++) {
auto operand_ir_arrays =
absl::MakeSpan(ir_arrays).subspan(0, fusion_operands.size());
return std::vector<llvm_ir::IrArray>(operand_ir_arrays.begin(),
operand_ir_arrays.end());
},
&elemental_emitter);
TF_RETURN_IF_ERROR(
fused_computation->root_instruction()->Accept(&fused_emitter));
auto element_generator = fused_emitter.GetRootGenerator();
auto* builder = &b_;
auto ir_array = operand_ir_arrays[i];
fused_emitter.BindGenerator(
fused_computation->parameter_instruction(i),
[builder, ir_array](llvm_ir::IrArray::Index index) {
return ir_array.EmitReadArrayElement(index, builder);
});
}
TF_ASSIGN_OR_RETURN(
auto element_generator,
fused_emitter.GetGenerator(fused_computation->root_instruction()));
Shape element_shape = TypeToShape(fusion_outputs[0].getType());
LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
element_shape, ir_emitter_context_->gpu_device_info(), unroll_factor);
@ -1022,14 +1027,14 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
GpuElementalIrEmitter operand_elemental_emitter(
hlo_module_config_, ir_emitter_context_->llvm_module(), &b_,
GetNestedComputer());
FusedIrEmitter operand_fused_emitter(
GetGeneratorForOperandIrArrays(fusion),
&operand_elemental_emitter);
TF_RETURN_IF_ERROR(
root->mutable_operand(0)->Accept(&operand_fused_emitter));
FusedIrEmitter operand_fused_emitter(&operand_elemental_emitter);
BindFusionArguments(fusion, &operand_fused_emitter);
TF_ASSIGN_OR_RETURN(
auto generator,
operand_fused_emitter.GetGenerator(root->operand(0)));
TF_RETURN_IF_ERROR(EmitTargetElementLoopInThunk(
*fusion, operand_fused_emitter.GetGenerator(root->operand(0)),
*fusion, generator,
static_cast<KernelThunk*>(thunks.back().get()),
ComputeMaxUnrollFactor(fusion)));
}
@ -1044,10 +1049,8 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
GpuElementalIrEmitter scatter_elemental_emitter(
hlo_module_config_, ir_emitter_context_->llvm_module(), &b_,
GetNestedComputer());
FusedIrEmitter scatter_fused_emitter(
GetGeneratorForOperandIrArrays(fusion),
&scatter_elemental_emitter);
TF_RETURN_IF_ERROR(root->Accept(&scatter_fused_emitter));
FusedIrEmitter scatter_fused_emitter(&scatter_elemental_emitter);
BindFusionArguments(fusion, &scatter_fused_emitter);
CHECK_EQ(root->parent()->FusionInstruction(), fusion);
TF_ASSIGN_OR_RETURN(
@ -1063,10 +1066,12 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
desc.unique_indices = root->unique_indices();
desc.update_computation = root->called_computations()[0];
desc.output = GetIrArray(*fusion, *fusion);
desc.scatter_indices_gen =
scatter_fused_emitter.GetGenerator(root->operand(1));
desc.updates_gen =
scatter_fused_emitter.GetGenerator(root->operand(2));
TF_ASSIGN_OR_RETURN(
desc.scatter_indices_gen,
scatter_fused_emitter.GetGenerator(root->operand(1)));
TF_ASSIGN_OR_RETURN(
desc.updates_gen,
scatter_fused_emitter.GetGenerator(root->operand(2)));
desc.get_index_type = [&](int64 launch_size) {
return GetIndexTypeForKernel(root, launch_size, &b_);
};
@ -1133,9 +1138,11 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
ir_emitter_context_->llvm_module());
AddThunkToThunkSequence(std::move(fusion_thunk));
FusedIrEmitter fused_emitter(&elemental_emitter);
BindFusionArguments(fusion, &fused_emitter);
return llvm_ir::EmitParallelFusedDynamicUpdateSliceInPlace(
fusion, GetGeneratorForOperandIrArrays(fusion), output_array,
&elemental_emitter, launch_dimensions, &b_);
fusion, output_array, &fused_emitter, launch_dimensions, &b_);
}
CHECK_EQ(fusion->fusion_kind(), HloInstruction::FusionKind::kLoop)
@ -2596,13 +2603,13 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
ir_emitter_context_->llvm_module(),
&b_, GetNestedComputer());
FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(hlo),
&elemental_emitter);
TF_RETURN_IF_ERROR(init_value_operand->Accept(&fused_emitter));
TF_RETURN_IF_ERROR(
ParallelLoopEmitter(fused_emitter.GetGenerator(init_value_operand),
GetIrArray(*hlo, *hlo, index), launch_dimensions,
&b_)
FusedIrEmitter fused_emitter(&elemental_emitter);
BindFusionArguments(hlo, &fused_emitter);
TF_ASSIGN_OR_RETURN(auto generator,
fused_emitter.GetGenerator(init_value_operand));
TF_RETURN_IF_ERROR(ParallelLoopEmitter(generator,
GetIrArray(*hlo, *hlo, index),
launch_dimensions, &b_)
.EmitLoop(IrName(hlo)));
} else {
// In the unfused case the element is already there, just read from it.
@ -3097,15 +3104,35 @@ void IrEmitterUnnested::EmitTileElementForFusion(
std::vector<IrArray> output_arrays = ConstructIrArrayForOutputs(*hlo);
GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
GetNestedComputer());
FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(hlo),
&elem_emitter, x_loc, y_loc,
param_shmem_buffers);
TF_CHECK_OK(hlo->fused_expression_root()->Accept(&fused_emitter));
FusedIrEmitter fused_emitter(&elem_emitter);
for (int i = 0; i < hlo->operand_count(); i++) {
llvm_ir::ElementGenerator gen;
if (llvm::Value* param_tile_buffer = param_shmem_buffers[i]) {
gen = [this, param_tile_buffer, x_loc,
y_loc](llvm_ir::IrArray::Index index) {
// TODO(jlebar): Add AA metadata to this load. Tile buffers are
// global variables, so LLVM's points-to analysis doesn't help us
// much. And we want the AA info to be present before address
// spaces are inferred (which is pretty late in the pipeline), so
// even if we had address-space-based AA in LLVM, it wouldn't help
// us much here.
return b_.CreateLoad(
b_.CreateGEP(param_tile_buffer,
{index.GetConstantWithIndexType(0), x_loc, y_loc}),
"tiled_buffer");
};
} else {
const HloInstruction* operand = hlo->operand(i);
gen = [this, operand, hlo](llvm_ir::IrArray::Index index) {
return GetIrArray(*operand, *hlo).EmitReadArrayElement(index, &b_);
};
}
fused_emitter.BindGenerator(hlo->fused_parameter(i), std::move(gen));
}
IrArray::Index untiled_index = GetUnnormalizedIndex(
index, output_arrays[0].GetShape(), &b_, mapping_scheme);
const llvm_ir::ElementGenerator& output_generator =
fused_emitter.GetRootGenerator();
llvm_ir::ElementGenerator output_generator =
*fused_emitter.GetGenerator(hlo->fused_expression_root());
llvm::Value* output_value = output_generator(untiled_index).ValueOrDie();
if (hlo->IsMultiOutputFusion()) {
DCHECK(output_value->getType()->isStructTy());
@ -3161,13 +3188,11 @@ void IrEmitterUnnested::EmitPrologueForReduction(
llvm::Value* init_ir_value;
const HloInstruction* init_value = reduce_inst->operand(1);
if (unnested_hlo->opcode() == HloOpcode::kFusion) {
FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(unnested_hlo),
&elemental_emitter);
FusedIrEmitter fused_emitter(&elemental_emitter);
BindFusionArguments(unnested_hlo, &fused_emitter);
TF_CHECK_OK(init_value->Accept(&fused_emitter));
init_ir_value =
fused_emitter
.GetGenerator(init_value)(IrArray::Index(b_.getInt32Ty()))
init_ir_value = (*fused_emitter.GetGenerator(init_value))(
IrArray::Index(b_.getInt32Ty()))
.ValueOrDie();
} else {
init_ir_value =
@ -3507,21 +3532,21 @@ void IrEmitterUnnested::EmitTileElementForReduction(
extra_output_gens;
GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
GetNestedComputer());
FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(unnested_hlo),
&elem_emitter);
FusedIrEmitter fused_emitter(&elem_emitter);
// Construct the ElementGenerator for each reduction and extra output in the
// the group of output instructions.
if (unnested_hlo->opcode() == HloOpcode::kFusion) {
TF_CHECK_OK(unnested_hlo->fused_expression_root()->Accept(&fused_emitter));
BindFusionArguments(unnested_hlo, &fused_emitter);
for (int i = 0, e = output_instructions.size(); i != e; ++i) {
const HloInstruction* inst = output_instructions[i];
ShapeIndex idx =
CreateShapeIndexForOutputInstruction(*unnested_hlo, *inst);
if (IsReductionFromOrToContiguousDimensions(*inst)) {
input_gens.push_back(fused_emitter.GetGenerator(inst->operand(0)));
input_gens.push_back(*fused_emitter.GetGenerator(inst->operand(0)));
} else {
extra_output_gens.emplace_back(fused_emitter.GetGenerator(inst),
extra_output_gens.emplace_back(*fused_emitter.GetGenerator(inst),
std::move(idx));
}
}
@ -4506,11 +4531,10 @@ void IrEmitterUnnested::EmitElementForInputFusibleSlices(
std::vector<llvm::Value*> input_ir_values;
GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
GetNestedComputer());
FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(unnested_hlo),
&elem_emitter);
TF_CHECK_OK(unnested_hlo->fused_expression_root()->Accept(&fused_emitter));
FusedIrEmitter fused_emitter(&elem_emitter);
BindFusionArguments(unnested_hlo, &fused_emitter);
for (const HloInstruction* slice : slice_instructions) {
auto input_generator = fused_emitter.GetGenerator(slice->operand(0));
auto input_generator = *fused_emitter.GetGenerator(slice->operand(0));
input_ir_values.push_back(input_generator(index).ValueOrDie());
}

View File

@ -17,7 +17,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
@ -190,9 +189,8 @@ Status EmitDynamicUpdateSliceInPlace(absl::Span<const IrArray> operand_arrays,
//
// Emits a sequential loop if launch_dimensions is null.
static Status EmitFusedDynamicUpdateSliceInPlaceImpl(
HloInstruction* fusion,
GeneratorForOperandIrArrays operand_arrays_generator,
const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
HloInstruction* fusion, const IrArray& fusion_output_array,
FusedIrEmitter* fused_emitter,
const gpu::LaunchDimensions* launch_dimensions, llvm::IRBuilder<>* b) {
CHECK_EQ(fusion->opcode(), HloOpcode::kFusion);
VLOG(2) << "EmitFusedDynamicUpdateSliceInPlace for "
@ -221,14 +219,14 @@ static Status EmitFusedDynamicUpdateSliceInPlaceImpl(
LayoutUtil::CopyLayoutBetweenShapes(fusion->shape(), &update_shape));
// Create element generators for update and start_indices.
FusedIrEmitter fused_emitter(std::move(operand_arrays_generator),
elemental_emitter);
TF_RETURN_IF_ERROR(dynamic_update_slice->Accept(&fused_emitter));
ElementGenerator update_array_generator = fused_emitter.GetGenerator(update);
TF_ASSIGN_OR_RETURN(ElementGenerator update_array_generator,
fused_emitter->GetGenerator(update));
IndexGenerator start_indices_generator = [&](int64 index) {
ElementGenerator element_generator =
fused_emitter.GetGenerator(dynamic_update_slice->operand(2 + index));
IndexGenerator start_indices_generator =
[&](int64 index) -> StatusOr<llvm::Value*> {
TF_ASSIGN_OR_RETURN(
ElementGenerator element_generator,
fused_emitter->GetGenerator(dynamic_update_slice->operand(2 + index)));
return element_generator(IrArray::Index(b->getInt64Ty()));
};
bool is_signed = ShapeUtil::ElementIsSigned(start_indices->shape());
@ -237,25 +235,21 @@ static Status EmitFusedDynamicUpdateSliceInPlaceImpl(
fusion_output_array, launch_dimensions, IrName(fusion), b);
}
Status EmitFusedDynamicUpdateSliceInPlace(
HloInstruction* fusion,
GeneratorForOperandIrArrays operand_arrays_generator,
const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
Status EmitFusedDynamicUpdateSliceInPlace(HloInstruction* fusion,
const IrArray& fusion_output_array,
FusedIrEmitter* fused_emitter,
llvm::IRBuilder<>* b) {
return EmitFusedDynamicUpdateSliceInPlaceImpl(
fusion, std::move(operand_arrays_generator), fusion_output_array,
elemental_emitter,
fusion, fusion_output_array, fused_emitter,
/*launch_dimensions=*/nullptr, b);
}
Status EmitParallelFusedDynamicUpdateSliceInPlace(
HloInstruction* fusion,
GeneratorForOperandIrArrays operand_arrays_generator,
const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
HloInstruction* fusion, const IrArray& fusion_output_array,
FusedIrEmitter* fused_emitter,
const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b) {
return EmitFusedDynamicUpdateSliceInPlaceImpl(
fusion, std::move(operand_arrays_generator), fusion_output_array,
elemental_emitter, &launch_dimensions, b);
fusion, fusion_output_array, fused_emitter, &launch_dimensions, b);
}
} // namespace llvm_ir

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
// Utilities related to emitting LLVM IR for various HLO ops.
@ -71,18 +72,16 @@ Status EmitDynamicUpdateSliceInPlace(absl::Span<const IrArray> operand_arrays,
// array-to-be-updated and output share the same buffer slice, emits
// (sequential) code for a fusion node that does the dynamic-update-slice in
// place.
Status EmitFusedDynamicUpdateSliceInPlace(
HloInstruction* fusion,
GeneratorForOperandIrArrays operand_arrays_generator,
const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
Status EmitFusedDynamicUpdateSliceInPlace(HloInstruction* fusion,
const IrArray& fusion_output_array,
FusedIrEmitter* fused_emitter,
llvm::IRBuilder<>* b);
// Same as EmitFusedDynamicUpdateSliceInPlace, except emits a parallel loop with
// the given launch dimensions.
Status EmitParallelFusedDynamicUpdateSliceInPlace(
HloInstruction* fusion,
GeneratorForOperandIrArrays operand_arrays_generator,
const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
HloInstruction* fusion, const IrArray& fusion_output_array,
FusedIrEmitter* fused_emitter,
const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b);
} // namespace llvm_ir

View File

@ -114,25 +114,9 @@ Status FusedIrEmitter::HandleGetTupleElement(
}
Status FusedIrEmitter::HandleParameter(const HloInstruction* parameter) {
indexed_generators_[parameter] =
[=](const IrArray::Index& index) -> llvm::Value* {
int64 param_num = parameter->parameter_number();
if (param_shmem_buffers_.size() > param_num) {
if (llvm::Value* param_tile_buffer = param_shmem_buffers_[param_num]) {
// TODO(jlebar): Add AA metadata to this load. Tile buffers are global
// variables, so LLVM's points-to analysis doesn't help us much. And we
// want the AA info to be present before address spaces are inferred
// (which is pretty late in the pipeline), so even if we had
// address-space-based AA in LLVM, it wouldn't help us much here.
return b_->CreateLoad(
b_->CreateGEP(param_tile_buffer, {index.GetConstantWithIndexType(0),
thread_id_x_, thread_id_y_}),
"tiled_buffer");
if (indexed_generators_.find(parameter) == indexed_generators_.end()) {
return InvalidArgument("Unbound parameter: %s", parameter->ToString());
}
}
return GetIrArrayForFusedParameter(param_num).EmitReadArrayElement(index,
b_);
};
return Status::OK();
}
@ -157,22 +141,6 @@ Status FusedIrEmitter::HandleTuple(const HloInstruction* tuple) {
return Status::OK();
}
Status FusedIrEmitter::FinishVisit(const HloInstruction* root) {
fused_root_ = root;
return Status::OK();
}
FusedIrEmitter::IndexedGenerator FusedIrEmitter::GetRootGenerator() const {
CHECK_NE(nullptr, fused_root_)
<< "GetRootGenerator should be called after Accept.";
return indexed_generators_.at(fused_root_);
}
FusedIrEmitter::IndexedGenerator FusedIrEmitter::GetGenerator(
const HloInstruction* instruction) const {
return indexed_generators_.at(instruction);
}
bool FusedIrEmitter::IsFusedIrEmitterInefficient(
const HloInstruction* consumer, const HloInstruction* producer) {
if (consumer->opcode() != HloOpcode::kFusion) {
@ -189,4 +157,39 @@ bool FusedIrEmitter::IsFusedIrEmitterInefficient(
return eval_producer.MaxCodeDuplicationTooHigh();
}
StatusOr<FusedIrEmitter::IndexedGenerator> FusedIrEmitter::GetGenerator(
const HloInstruction* instruction) {
std::vector<const HloInstruction*> stack;
stack.push_back(instruction);
while (!stack.empty()) {
const HloInstruction* instr = stack.back();
stack.pop_back();
if (indexed_generators_.count(instr)) {
continue;
}
for (const HloInstruction* operand : instr->operands()) {
stack.push_back(operand);
}
switch (instr->opcode()) {
case HloOpcode::kConstant:
TF_RETURN_IF_ERROR(HandleConstant(instr));
break;
case HloOpcode::kGetTupleElement:
TF_RETURN_IF_ERROR(HandleGetTupleElement(instr));
break;
case HloOpcode::kParameter:
TF_RETURN_IF_ERROR(HandleParameter(instr));
break;
case HloOpcode::kTuple:
TF_RETURN_IF_ERROR(HandleTuple(instr));
break;
default:
TF_RETURN_IF_ERROR(DefaultAction(instr));
break;
}
CHECK(indexed_generators_.count(instr));
}
return indexed_generators_.at(instruction);
}
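A simplified standalone analogue of the traversal above (toy types; not the XLA classes): callers pre-bind generators for parameters, and the GetGenerator-style lookup walks the operand DAG with an explicit stack, memoizing a closure per node. Because each closure resolves its operands from the table at call time, the order in which nodes come off the stack does not matter.
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>
struct Node {                        // stands in for an HloInstruction
  std::string op;                    // "param", "add" or "mul" in this toy
  std::vector<const Node*> operands;
};
using Generator = std::function<double()>;
class ToyFusedEmitter {
 public:
  // Analogue of BindGenerator: leaves ("param" nodes) must be bound first.
  void Bind(const Node* n, Generator g) { generators_[n] = std::move(g); }

  // Analogue of GetGenerator: build and memoize generators on demand.
  Generator Get(const Node* root) {
    std::vector<const Node*> stack{root};
    while (!stack.empty()) {
      const Node* n = stack.back();
      stack.pop_back();
      if (generators_.count(n)) continue;          // already bound or built
      for (const Node* operand : n->operands) stack.push_back(operand);
      generators_[n] = [this, n]() {               // binary interior node (toy)
        double lhs = generators_.at(n->operands[0])();
        double rhs = generators_.at(n->operands[1])();
        return n->op == "add" ? lhs + rhs : lhs * rhs;
      };
    }
    return generators_.at(root);
  }

 private:
  std::unordered_map<const Node*, Generator> generators_;
};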
} // namespace xla

View File

@ -24,7 +24,6 @@ limitations under the License.
#include "absl/types/span.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
@ -51,47 +50,22 @@ namespace xla {
// created produces an LLVM struct with N elements, one for each element of the
// arrays in the tuple. It follows that the arrays in the tuple must have the
// same length.
class FusedIrEmitter : public ConstDfsHloVisitorWithDefault {
class FusedIrEmitter {
public:
using IndexedGenerator = llvm_ir::ElementGenerator;
using NonIndexedGenerator = std::function<StatusOr<llvm::Value*>()>;
using GeneratorForOperandIrArrays =
std::function<std::vector<llvm_ir::IrArray>()>;
FusedIrEmitter(GeneratorForOperandIrArrays operand_arrays_generator,
ElementalIrEmitter* elemental_emitter,
llvm::Value* thread_id_x = nullptr,
llvm::Value* thread_id_y = nullptr,
absl::Span<llvm::Value* const> param_shmem_buffers = {})
: operand_arrays_(),
operand_arrays_generator_(std::move(operand_arrays_generator)),
thread_id_x_(thread_id_x),
thread_id_y_(thread_id_y),
param_shmem_buffers_(param_shmem_buffers.begin(),
param_shmem_buffers.end()),
elemental_emitter_(elemental_emitter),
explicit FusedIrEmitter(ElementalIrEmitter* elemental_emitter)
: elemental_emitter_(elemental_emitter),
b_(elemental_emitter->b()),
module_(elemental_emitter->module()) {}
Status DefaultAction(const HloInstruction* hlo) override;
Status HandleConstant(const HloInstruction* constant) override;
Status HandleGetTupleElement(
const HloInstruction* get_tuple_element) override;
Status HandleParameter(const HloInstruction* parameter) override;
// Emits the ir value for each element in the tuple.
Status HandleTuple(const HloInstruction* tuple) override;
Status FinishVisit(const HloInstruction* root) override;
// Returns the generator function for the root of the fused computation.
IndexedGenerator GetRootGenerator() const;
void BindGenerator(const HloInstruction* hlo,
llvm_ir::ElementGenerator generator) {
indexed_generators_[hlo] = std::move(generator);
}
// Returns the generator function for the given instruction.
IndexedGenerator GetGenerator(const HloInstruction* instruction) const;
StatusOr<IndexedGenerator> GetGenerator(const HloInstruction* instruction);
// Evaluates whether fusing 'producer' into 'consumer' might cause exponential
// behavior in FusedIrEmitter. We currently can have exponential time/memory
@ -101,40 +75,20 @@ class FusedIrEmitter : public ConstDfsHloVisitorWithDefault {
static bool IsFusedIrEmitterInefficient(const HloInstruction* consumer,
const HloInstruction* producer);
protected:
// Returns the IrArrays for the fusion instruction operands.
llvm_ir::IrArray& GetIrArrayForFusedParameter(int64 parameter_number) {
if (!operand_arrays_.has_value()) {
operand_arrays_ = operand_arrays_generator_();
}
return operand_arrays_.value()[parameter_number];
}
llvm::Value* GetBasePointerForFusedParameter(int64 parameter_number) {
return GetIrArrayForFusedParameter(parameter_number).GetBasePointer();
}
private:
// IrArrays for the fusion instruction operands, whose base addresses are the
// base address of the corresponding parameters in the fused computation.
absl::optional<std::vector<llvm_ir::IrArray>> operand_arrays_;
GeneratorForOperandIrArrays operand_arrays_generator_;
Status DefaultAction(const HloInstruction* hlo);
// The x coordinate within a tile.
llvm::Value* thread_id_x_;
Status HandleConstant(const HloInstruction* constant);
// The y coordinate within a tile.
llvm::Value* thread_id_y_;
Status HandleGetTupleElement(const HloInstruction* get_tuple_element);
// Param_buffers_[i] stores the tile buffer for the ith parameter or nullptr
// if the parameter is not tiled.
std::vector<llvm::Value*> param_shmem_buffers_;
Status HandleParameter(const HloInstruction* parameter);
// Emits the ir value for each element in the tuple.
Status HandleTuple(const HloInstruction* tuple);
ElementalIrEmitter* elemental_emitter_;
// This member will be set by FinishVisit and used in GetRootGenerator.
const HloInstruction* fused_root_ = nullptr;
// Borrowed
llvm::IRBuilder<>* b_;
llvm::Module* module_;
@ -145,12 +99,6 @@ class FusedIrEmitter : public ConstDfsHloVisitorWithDefault {
std::unordered_map<const HloInstruction*, IndexedGenerator>
indexed_generators_;
// Map from tuple-result-producing GetTupleElement instructions to functions
// that generate the base pointers for the output elements. This is used to
// support the translation of nested GetTupleElement instructions.
std::unordered_map<const HloInstruction*, NonIndexedGenerator>
non_indexed_generators_;
// Cache of generated values, lest we regenerate an element of a node with
// multiple outgoing edges
absl::flat_hash_map<

View File

@ -521,8 +521,7 @@ XLA_TEST_F(ConcatTest, ConcatDeeplyNested) {
ComputeAndCompareR1<float>(&builder, expected, {a_data.get()});
}
// TODO(b/169314478): Enable the test when the slow compilation is fixed.
XLA_TEST_F(ConcatTestHlo, DISABLED_ConcatWithBitcast) {
XLA_TEST_F(ConcatTestHlo, ConcatWithBitcast) {
auto module = ParseAndReturnVerifiedModule(R"(
HloModule jit_broken.874
@ -762,7 +761,7 @@ ENTRY jit_broken.874 {
auto input_array = absl::make_unique<Array2D<float>>(4, 2);
input_array->FillUnique(1.0f);
auto input = LiteralUtil::CreateR2FromArray2D<float>(*input_array);
EXPECT_TRUE(RunAndCompare(std::move(module), {&input}, absl::nullopt));
EXPECT_TRUE(RunAndCompare(std::move(module), {&input}, error_spec_));
}
// Describes a binary rank-2 concatenation test.

View File

@ -354,8 +354,11 @@ void BaseCollectiveExecutor::CompleteParamsAsync(
// timeout callback executes, done_safe will become a no-op and the timeout
// callback is responsible for invoking done() at the end.
const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
auto done_safe = [this, is_callback_called, cancel_mgr,
auto trace_id =
profiler::TraceMe::ActivityStart("CollectiveExecutor::CompleteParams");
auto done_safe = [this, is_callback_called, cancel_mgr, trace_id,
done](const Status& s) {
profiler::TraceMe::ActivityEnd(trace_id);
bool called = is_callback_called->exchange(true);
if (!called) {
if (!s.ok() && !IsCancelled(cancel_mgr)) {
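A generic sketch of the pattern introduced above (hypothetical TraceStart/TraceEnd helpers; only ActivityStart/ActivityEnd from the diff are real): the work completes in an asynchronously invoked callback, so a scoped RAII trace object cannot cover it; instead the activity id is captured by the callback and explicitly ended there.
#include <cstdint>
#include <functional>
#include <iostream>
uint64_t TraceStart(const char* name) {        // cf. TraceMe::ActivityStart
  std::cout << "start " << name << "\n";
  return 1;
}
void TraceEnd(uint64_t id) { std::cout << "end " << id << "\n"; }
void CompleteParamsAsync(std::function<void(int)> done) {
  const uint64_t trace_id = TraceStart("CompleteParams");
  // The callback may run later on another thread, so the span is closed
  // inside it rather than by a scoped object in this frame.
  auto done_safe = [trace_id, done](int status) {
    TraceEnd(trace_id);
    done(status);
  };
  done_safe(/*status=*/0);  // invoked asynchronously in the real code
}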

View File

@ -2587,11 +2587,9 @@ TEST(DirectSessionTest,
// A simple benchmark for the overhead of `DirectSession::Run()` calls
// with varying numbers of feeds/fetches.
void FeedFetchBenchmarkHelper(int iters, int num_feeds, bool use_make_callable,
int inter_op_threads,
void FeedFetchBenchmarkHelper(::testing::benchmark::State& state, int num_feeds,
bool use_make_callable, int inter_op_threads,
bool use_single_threaded_executor) {
testing::StopTiming();
Tensor value(DT_FLOAT, TensorShape());
value.flat<float>()(0) = 37.0;
@ -2643,13 +2641,11 @@ void FeedFetchBenchmarkHelper(int iters, int num_feeds, bool use_make_callable,
}
TF_CHECK_OK(session->MakeCallable(callable_options, &handle));
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto s : state) {
std::vector<Tensor> output_values;
TF_CHECK_OK(
session->RunCallable(handle, input_tensors, &output_values, nullptr));
}
testing::StopTiming();
} else {
{
// NOTE(mrry): Ignore the first run, which will incur the graph
@ -2661,32 +2657,40 @@ void FeedFetchBenchmarkHelper(int iters, int num_feeds, bool use_make_callable,
std::vector<Tensor> output_values;
TF_CHECK_OK(session->Run(inputs, outputs, {}, &output_values));
}
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto s : state) {
std::vector<Tensor> output_values;
TF_CHECK_OK(session->Run(inputs, outputs, {}, &output_values));
}
testing::StopTiming();
}
}
void BM_FeedFetch(int iters, int num_feeds) {
FeedFetchBenchmarkHelper(iters, num_feeds, /* use_make_callable */ false,
void BM_FeedFetch(::testing::benchmark::State& state) {
const int num_feeds = state.range(0);
FeedFetchBenchmarkHelper(state, num_feeds, /* use_make_callable */ false,
/* inter_op_threads */ 0,
/* use_single_threaded_executor */ false);
}
void BM_FeedFetchCallable(int iters, int num_feeds) {
FeedFetchBenchmarkHelper(iters, num_feeds, /* use_make_callable */ true,
void BM_FeedFetchCallable(::testing::benchmark::State& state) {
const int num_feeds = state.range(0);
FeedFetchBenchmarkHelper(state, num_feeds, /* use_make_callable */ true,
/* inter_op_threads */ 0,
/* use_single_threaded_executor */ false);
}
void BM_FeedFetchCallableSingleThread(int iters, int num_feeds) {
FeedFetchBenchmarkHelper(iters, num_feeds, /* use_make_callable */ true,
void BM_FeedFetchCallableSingleThread(::testing::benchmark::State& state) {
const int num_feeds = state.range(0);
FeedFetchBenchmarkHelper(state, num_feeds, /* use_make_callable */ true,
/* inter_op_threads */ -1,
/* use_single_threaded_executor */ false);
}
void BM_FeedFetchCallableSingleThreadExecutor(int iters, int num_feeds) {
FeedFetchBenchmarkHelper(iters, num_feeds, /* use_make_callable */ true,
void BM_FeedFetchCallableSingleThreadExecutor(
::testing::benchmark::State& state) {
const int num_feeds = state.range(0);
FeedFetchBenchmarkHelper(state, num_feeds, /* use_make_callable */ true,
/* inter_op_threads */ -1,
/* use_single_threaded_executor */ true);
}
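The benchmark changes in this and the following files all follow one migration; here is a condensed sketch of the new shape, built from the calls visible in these diffs (the include path and the ->Arg registration are assumptions): the argument comes from state.range(n), the timed region is the range-for over state, and labels/throughput are reported through the State object instead of the old testing::Start/StopTiming helpers.
#include <cstdint>
#include "tensorflow/core/platform/test_benchmark.h"  // assumed header for BENCHMARK/State
void BM_Example(::testing::benchmark::State& state) {
  const int num_feeds = state.range(0);        // was: an `int num_feeds` parameter
  // One-time setup runs before the loop; no StopTiming()/StartTiming() needed.
  for (auto s : state) {                       // was: for (int i = 0; i < iters; ++i)
    // ... the work being measured, using num_feeds ...
  }
  state.SetItemsProcessed(static_cast<int64_t>(state.iterations()));
  state.SetLabel("num_feeds = ...");
}
BENCHMARK(BM_Example)->Arg(1)->Arg(8);         // ArgPair(x, y) when two ranges are needed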

View File

@ -378,6 +378,12 @@ Status EagerOperation::InferInputListAttrs(int num_inputs) {
} else if (!input_def.type_attr().empty() &&
!input_def.number_attr().empty()) {
InferSingleTypeInputListAttrs(input_def, inputs_[start]->dtype, num_inputs);
} else if (!input_def.number_attr().empty()) {
if (inference_attrs_.find(input_def.number_attr()) ==
inference_attrs_.end()) {
MutableAttrs()->Set(input_def.number_attr(), num_inputs);
inference_attrs_.insert(input_def.number_attr());
}
} else {
return errors::InvalidArgument("Invalid input list definition");
}

View File

@ -69,8 +69,8 @@ class TestEnv {
Device* cpu_device_;
};
void BM_CreateGraph(int iters) {
for (int i = 0; i < iters; ++i) {
void BM_CreateGraph(::testing::benchmark::State& state) {
for (auto s : state) {
Scope root = Scope::NewRootScope();
auto C = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}});
auto M = ops::MatMul(root, C, C);
@ -79,8 +79,7 @@ void BM_CreateGraph(int iters) {
}
BENCHMARK(BM_CreateGraph);
void BM_RunGraph(int iters) {
tensorflow::testing::StopTiming();
void BM_RunGraph(::testing::benchmark::State& state) {
Scope root = Scope::NewRootScope();
auto C = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}});
auto M = ops::MatMul(root, C, C);
@ -89,28 +88,24 @@ void BM_RunGraph(int iters) {
opts.config.set_intra_op_parallelism_threads(1);
ClientSession sess(root, opts);
std::vector<Tensor> outputs;
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto s : state) {
outputs.clear();
TF_CHECK_OK(sess.Run({M}, &outputs));
}
}
BENCHMARK(BM_RunGraph);
void BM_CreateAndDestroySession(int iters) {
tensorflow::testing::StopTiming();
void BM_CreateAndDestroySession(::testing::benchmark::State& state) {
Scope root = Scope::NewRootScope();
auto C = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}});
auto M = ops::MatMul(root, C, C);
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto s : state) {
ClientSession sess(root);
}
}
BENCHMARK(BM_CreateAndDestroySession);
void BM_KernelAndDeviceInit(int iters) {
tensorflow::testing::StopTiming();
void BM_KernelAndDeviceInit(::testing::benchmark::State& state) {
NodeDef ndef(AttrBuilder("MatMul")
.Set("T", DT_FLOAT)
.Set("transpose_a", false)
@ -120,15 +115,13 @@ void BM_KernelAndDeviceInit(int iters) {
TestEnv env;
KernelAndDeviceOp k(nullptr, false, env.function_library_runtime(), nullptr,
nullptr, env.cpu_device());
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto s : state) {
TF_CHECK_OK(k.Init({}, ndef, nullptr));
}
}
BENCHMARK(BM_KernelAndDeviceInit);
void BM_KernelAndDeviceRun(int iters) {
tensorflow::testing::StopTiming();
void BM_KernelAndDeviceRun(::testing::benchmark::State& state) {
Tensor t(Input({{1.0f, 2.0f}, {3.0f, 4.0f}}).tensor());
gtl::InlinedVector<TensorValue, 4> inputs;
inputs.push_back(TensorValue(&t));
@ -145,8 +138,7 @@ void BM_KernelAndDeviceRun(int iters) {
nullptr, env.cpu_device());
TF_CHECK_OK(k.Init({}, ndef, nullptr));
const EagerKernelArgs args(std::move(inputs));
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto s : state) {
TF_CHECK_OK(k.Run(nullptr, args, &outputs, nullptr, absl::nullopt));
}
}

View File

@ -433,11 +433,10 @@ TEST_F(ExecutorTest, NoInputTensors) {
// Create a graph that is 'depth' deep. At each level, fan-in and fan-out a
// maximum of 'width' nodes. All nodes are no-ops and all dependencies are
// control dependencies.
static void BM_executor(int iters, int width, int depth) {
testing::StopTiming();
#ifdef PLATFORM_GOOGLE
BenchmarkUseRealTime();
#endif // PLATFORM_GOOGLE
static void BM_executor(::testing::benchmark::State& state) {
const int width = state.range(0);
const int depth = state.range(1);
Graph* g = new Graph(OpRegistry::Global());
random::PhiloxRandom philox(1729, 17);
random::SimplePhilox rand(&philox);
@ -466,30 +465,29 @@ static void BM_executor(int iters, int width, int depth) {
++cur;
}
}
#ifdef PLATFORM_GOOGLE
SetBenchmarkLabel(strings::StrCat("Nodes = ", cur));
SetBenchmarkItemsProcessed(cur * static_cast<int64>(iters));
#endif // PLATFORM_GOOGLE
FixupSourceAndSinkEdges(g);
testing::StartTiming();
test::Benchmark("cpu", g).Run(iters);
test::Benchmark("cpu", g, /*old_benchmark_api=*/false).Run(state);
state.SetLabel(strings::StrCat("Nodes = ", cur));
state.SetItemsProcessed(cur * static_cast<int64>(state.iterations()));
}
// Tall skinny graphs
BENCHMARK(BM_executor)->ArgPair(16, 1024);
BENCHMARK(BM_executor)->ArgPair(32, 8192);
BENCHMARK(BM_executor)->UseRealTime()->ArgPair(16, 1024);
BENCHMARK(BM_executor)->UseRealTime()->ArgPair(32, 8192);
// Short fat graphs
BENCHMARK(BM_executor)->ArgPair(1024, 16);
BENCHMARK(BM_executor)->ArgPair(8192, 32);
BENCHMARK(BM_executor)->UseRealTime()->ArgPair(1024, 16);
BENCHMARK(BM_executor)->UseRealTime()->ArgPair(8192, 32);
// Tall fat graph
BENCHMARK(BM_executor)->ArgPair(1024, 1024);
BENCHMARK(BM_executor)->UseRealTime()->ArgPair(1024, 1024);
static void BM_const_identity(::testing::benchmark::State& state) {
const int width = state.range(0);
const int outputs_per_const = state.range(1);
static void BM_const_identity(int iters, int width, int outputs_per_const) {
#ifdef PLATFORM_GOOGLE
BenchmarkUseRealTime();
#endif // PLATFORM_GOOGLE
Graph* g = new Graph(OpRegistry::Global());
for (int i = 0; i < width; ++i) {
Tensor i_t(i);
@ -499,23 +497,21 @@ static void BM_const_identity(int iters, int width, int outputs_per_const) {
}
}
FixupSourceAndSinkEdges(g);
#ifdef PLATFORM_GOOGLE
SetBenchmarkLabel(
strings::StrCat("Nodes = ", (1 + outputs_per_const) * width));
SetBenchmarkItemsProcessed((1 + outputs_per_const) * width *
static_cast<int64>(iters));
#endif // PLATFORM_GOOGLE
test::Benchmark("cpu", g).Run(iters);
test::Benchmark("cpu", g, /*old_benchmark_api=*/false).Run(state);
state.SetLabel(strings::StrCat("Nodes = ", (1 + outputs_per_const) * width));
state.SetItemsProcessed((1 + outputs_per_const) * width *
static_cast<int64>(state.iterations()));
}
// Graph with actual op execution.
BENCHMARK(BM_const_identity)->ArgPair(1, 1);
BENCHMARK(BM_const_identity)->ArgPair(1, 100);
BENCHMARK(BM_const_identity)->ArgPair(100, 1);
BENCHMARK(BM_const_identity)->ArgPair(100, 100);
BENCHMARK(BM_const_identity)
->UseRealTime()
->ArgPair(1, 1)
->ArgPair(1, 100)
->ArgPair(100, 1)
->ArgPair(100, 100);
static void BM_FeedInputFetchOutput(int iters) {
testing::StopTiming();
static void BM_FeedInputFetchOutput(::testing::benchmark::State& state) {
Graph* g = new Graph(OpRegistry::Global());
// z = x + y: x and y are provided as benchmark inputs. z is the
// output of the benchmark. Conceptually, the caller is ALICE, the
@ -531,13 +527,10 @@ static void BM_FeedInputFetchOutput(int iters) {
Tensor val(DT_FLOAT, TensorShape({}));
val.scalar<float>()() = 3.14;
#ifdef PLATFORM_GOOGLE
SetBenchmarkItemsProcessed(static_cast<int64>(iters));
#endif // PLATFORM_GOOGLE
FixupSourceAndSinkEdges(g);
testing::StartTiming();
test::Benchmark("cpu", g).RunWithRendezvousArgs({{x_key, val}, {y_key, val}},
{z_key}, iters);
test::Benchmark("cpu", g, /*old_benchmark_api=*/false)
.RunWithRendezvousArgs({{x_key, val}, {y_key, val}}, {z_key}, state);
state.SetItemsProcessed(static_cast<int64>(state.iterations()));
}
BENCHMARK(BM_FeedInputFetchOutput);
@ -549,9 +542,8 @@ BENCHMARK(BM_FeedInputFetchOutput);
//
// ...using the functional `WhileOp` (if `lower` is false) or the
// `Switch`/`Merge`-style of control flow (if `lower` is true).
static void BM_WhileLoopHelper(int iters, int loop_iters, int loop_vars,
bool lower) {
testing::StopTiming();
static void BM_WhileLoopHelper(::testing::benchmark::State& state,
int loop_iters, int loop_vars, bool lower) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
// Add test functions for cond and body.
@ -661,12 +653,15 @@ static void BM_WhileLoopHelper(int iters, int loop_iters, int loop_vars,
}
FixupSourceAndSinkEdges(graph.get());
testing::StartTiming();
test::Benchmark("cpu", graph.release()).Run(iters);
test::Benchmark("cpu", graph.release(), /*old_benchmark_api=*/false)
.Run(state);
}
static void BM_LoweredWhileLoop(int iters, int loop_iters, int loop_vars) {
BM_WhileLoopHelper(iters, loop_iters, loop_vars, /* lower= */ true);
static void BM_LoweredWhileLoop(::testing::benchmark::State& state) {
const int loop_iters = state.range(0);
const int loop_vars = state.range(1);
BM_WhileLoopHelper(state, loop_iters, loop_vars, /* lower= */ true);
}
BENCHMARK(BM_LoweredWhileLoop)
->ArgPair(0, 1)
@ -680,8 +675,11 @@ BENCHMARK(BM_LoweredWhileLoop)
->ArgPair(100, 100)
->ArgPair(1000, 100);
static void BM_FunctionalWhileLoop(int iters, int loop_iters, int loop_vars) {
BM_WhileLoopHelper(iters, loop_iters, loop_vars, /* lower= */ false);
static void BM_FunctionalWhileLoop(::testing::benchmark::State& state) {
const int loop_iters = state.range(0);
const int loop_vars = state.range(1);
BM_WhileLoopHelper(state, loop_iters, loop_vars, /* lower= */ false);
}
BENCHMARK(BM_FunctionalWhileLoop)
->ArgPair(0, 1)

View File

@ -931,7 +931,6 @@ TEST(SessionTest, InvalidOpInputName) {
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'transpose_a' value { b: false } }
attr { key: 'transpose_b' value { b: false } }
attr { key: '_kernel' value { s: 'eigen' } }
}
)",
"Illegal op input name");
@ -950,7 +949,6 @@ TEST(SessionTest, InvalidOpInputName) {
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'transpose_a' value { b: false } }
attr { key: 'transpose_b' value { b: false } }
attr { key: '_kernel' value { s: 'eigen' } }
}
)",
"Illegal op input name");
@ -969,7 +967,6 @@ TEST(SessionTest, InvalidOpInputName) {
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'transpose_a' value { b: false } }
attr { key: 'transpose_b' value { b: false } }
attr { key: '_kernel' value { s: 'eigen' } }
}
)",
"Illegal op input name");
@ -988,7 +985,6 @@ TEST(SessionTest, InvalidOpInputName) {
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'transpose_a' value { b: false } }
attr { key: 'transpose_b' value { b: false } }
attr { key: '_kernel' value { s: 'eigen' } }
}
)",
"Illegal op input name");
@ -1026,7 +1022,6 @@ TEST(SessionTest, ExtendValidation) {
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'transpose_a' value { b: false } }
attr { key: 'transpose_b' value { b: false } }
attr { key: '_kernel' value { s: 'eigen' } }
}
)",
&extension);
@ -1043,7 +1038,6 @@ TEST(SessionTest, ExtendValidation) {
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'transpose_a' value { b: false } }
attr { key: 'transpose_b' value { b: false } }
attr { key: '_kernel' value { s: 'eigen' } }
}
)",
&extension);
@ -1057,7 +1051,6 @@ TEST(SessionTest, ExtendValidation) {
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'transpose_a' value { b: false } }
attr { key: 'transpose_b' value { b: false } }
attr { key: '_kernel' value { s: 'eigen' } }
}
)",
&extension);

View File

@ -221,14 +221,16 @@ TEST(CustomAllocatorAttributes, TestSetterAndGetter) {
EXPECT_FALSE(HasDeviceAllocatorAttribute(AllocatorAttributes()));
}
static void BM_Allocation(int iters, int arg) {
static void BM_Allocation(::testing::benchmark::State& state) {
const int arg = state.range(0);
Allocator* a = cpu_allocator();
// Exercise a few different allocation sizes
std::vector<int> sizes = {256, 4096, 16384, 524288, 512, 1048576};
int size_index = 0;
if (arg) EnableCPUAllocatorStats();
while (--iters > 0) {
for (auto s : state) {
int bytes = sizes[size_index++ % sizes.size()];
void* p = a->AllocateRaw(1, bytes);
a->DeallocateRaw(p);

View File

@ -39,60 +39,60 @@ TEST(Bfloat16Test, Conversion) {
}
}
static void BM_FloatToBFloat16(int iters) {
testing::StopTiming();
void BM_FloatToBFloat16(::testing::benchmark::State& state) {
static const int N = 32 << 20;
const int64 tot = static_cast<int64>(iters) * N;
testing::ItemsProcessed(tot);
testing::BytesProcessed(tot * (sizeof(float) + sizeof(bfloat16)));
float* inp = new float[N];
bfloat16* out = new bfloat16[N];
testing::StartTiming();
while (iters--) {
for (auto s : state) {
FloatToBFloat16(inp, out, N);
}
const int64 tot = static_cast<int64>(state.iterations()) * N;
state.SetItemsProcessed(tot);
state.SetBytesProcessed(tot * (sizeof(float) + sizeof(bfloat16)));
delete[] inp;
delete[] out;
}
BENCHMARK(BM_FloatToBFloat16);
static void BM_RoundFloatToBFloat16(int iters) {
testing::StopTiming();
void BM_RoundFloatToBFloat16(::testing::benchmark::State& state) {
static const int N = 32 << 20;
const int64 tot = static_cast<int64>(iters) * N;
testing::ItemsProcessed(tot);
testing::BytesProcessed(tot * (sizeof(float) + sizeof(bfloat16)));
float* inp = new float[N];
bfloat16* out = new bfloat16[N];
testing::StartTiming();
while (iters--) {
for (auto s : state) {
RoundFloatToBFloat16(inp, out, N);
tensorflow::testing::DoNotOptimize(inp);
tensorflow::testing::DoNotOptimize(out);
}
const int64 tot = static_cast<int64>(state.iterations()) * N;
state.SetItemsProcessed(tot);
state.SetBytesProcessed(tot * (sizeof(float) + sizeof(bfloat16)));
delete[] inp;
delete[] out;
}
BENCHMARK(BM_RoundFloatToBFloat16);
static void BM_BFloat16ToFloat(int iters) {
testing::StopTiming();
void BM_BFloat16ToFloat(::testing::benchmark::State& state) {
static const int N = 32 << 20;
const int64 tot = static_cast<int64>(iters) * N;
testing::ItemsProcessed(tot);
testing::BytesProcessed(tot * (sizeof(float) + sizeof(bfloat16)));
bfloat16* inp = new bfloat16[N];
float* out = new float[N];
testing::StartTiming();
while (iters--) {
for (auto s : state) {
BFloat16ToFloat(inp, out, N);
}
const int64 tot = static_cast<int64>(state.iterations()) * N;
state.SetItemsProcessed(tot);
state.SetBytesProcessed(tot * (sizeof(float) + sizeof(bfloat16)));
delete[] inp;
delete[] out;
}
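The hunks above (and most of the benchmark changes that follow) apply one migration: the old int iters parameter plus manual StartTiming/StopTiming/ItemsProcessed calls give way to a ::testing::benchmark::State& parameter, a range-for loop over state, and counters computed from state.iterations() after the loop. A minimal sketch of the resulting shape, with a placeholder Convert routine standing in for FloatToBFloat16 and friends (not part of this change):

#include <cstdint>
#include <vector>

#include "tensorflow/core/platform/test_benchmark.h"

namespace tensorflow {
namespace {

// Placeholder for the routine under test (FloatToBFloat16 etc. above).
void Convert(const float* in, float* out, int n) {
  for (int i = 0; i < n; ++i) out[i] = in[i];
}

void BM_ConvertSketch(::testing::benchmark::State& state) {
  static const int kN = 1 << 20;
  std::vector<float> in(kN), out(kN);
  // The framework decides how many iterations to run via the loop over state.
  for (auto s : state) {
    Convert(in.data(), out.data(), kN);
    testing::DoNotOptimize(out);
  }
  // Throughput counters are set once, after the loop, from state.iterations().
  const int64_t tot = static_cast<int64_t>(state.iterations()) * kN;
  state.SetItemsProcessed(tot);
  state.SetBytesProcessed(tot * sizeof(float) * 2);
}
BENCHMARK(BM_ConvertSketch);

}  // namespace
}  // namespace tensorflow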

View File

@ -406,7 +406,7 @@ XTimesTwo[T:{float, double, int32, int64}](x:T) -> (y:T) {
TEST(TFunc, WXPlusB) {
auto expect = R"P(
WXPlusB[T:{float, double}](w:T, x:T, b:T) -> (y:T) {
mm = MatMul[T=$T, _kernel="eigen", transpose_a=false, transpose_b=false](w, x)
mm = MatMul[T=$T, transpose_a=false, transpose_b=false](w, x)
y = Add[T=$T](mm:product:0, b)
return y = y:z:0
}

View File

@ -346,10 +346,7 @@ FunctionDef WXPlusB() {
{{{"mm"},
"MatMul",
{"w", "x"},
{{"T", "$T"},
{"transpose_a", false},
{"transpose_b", false},
{"_kernel", "eigen"}}},
{{"T", "$T"}, {"transpose_a", false}, {"transpose_b", false}}},
{{"y"}, "Add", {"mm", "b"}, {{"T", "$T"}}}});
}

View File

@ -1002,9 +1002,9 @@ TEST_F(LabelTest, Duplicate) {
error::INVALID_ARGUMENT);
}
void BM_InputRangeHelper(int iters, const NodeDef& node_def,
const char* input_name, int expected_start,
int expected_stop) {
void BM_InputRangeHelper(::testing::benchmark::State& state,
const NodeDef& node_def, const char* input_name,
int expected_start, int expected_stop) {
Status status;
auto device = absl::make_unique<DummyDevice>(Env::Default());
@ -1013,24 +1013,20 @@ void BM_InputRangeHelper(int iters, const NodeDef& node_def,
TF_GRAPH_DEF_VERSION, &status));
TF_CHECK_OK(status);
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto s : state) {
int start;
int stop;
TF_CHECK_OK(op->InputRange(input_name, &start, &stop));
EXPECT_EQ(expected_start, start);
EXPECT_EQ(expected_stop, stop);
}
testing::StopTiming();
}
REGISTER_KERNEL_BUILDER(Name("ConcatV2").Device(DEVICE_CPU), DummyKernel);
REGISTER_KERNEL_BUILDER(Name("Select").Device(DEVICE_CPU), DummyKernel);
REGISTER_KERNEL_BUILDER(Name("MatMul").Device(DEVICE_CPU), DummyKernel);
void BM_ConcatInputRange(int iters) {
testing::StopTiming();
void BM_ConcatInputRange(::testing::benchmark::State& state) {
// Create a ConcatV2 NodeDef with 4 inputs (plus the axis).
NodeDef node_def;
node_def.set_name("concat-op");
@ -1048,12 +1044,10 @@ void BM_ConcatInputRange(int iters) {
node_def.add_input(strings::StrCat("a:", i));
}
BM_InputRangeHelper(iters, node_def, "values", 0, 4);
BM_InputRangeHelper(state, node_def, "values", 0, 4);
}
void BM_SelectInputRange(int iters) {
testing::StopTiming();
void BM_SelectInputRange(::testing::benchmark::State& state) {
// Create a Select NodeDef with 3 inputs.
NodeDef node_def;
node_def.set_name("select-op");
@ -1065,11 +1059,11 @@ void BM_SelectInputRange(int iters) {
node_def.add_input(strings::StrCat("a:", i));
}
BM_InputRangeHelper(iters, node_def, "condition", 0, 1);
BM_InputRangeHelper(state, node_def, "condition", 0, 1);
}
void BM_TraceString(const int iters, const int verbose) {
testing::StopTiming();
void BM_TraceString(::testing::benchmark::State& state) {
const int verbose = state.range(0);
// Create a MatMul NodeDef with 2 inputs.
NodeDef node_def;
@ -1103,11 +1097,9 @@ void BM_TraceString(const int iters, const int verbose) {
params.inputs = &inputs;
auto ctx = absl::make_unique<OpKernelContext>(&params);
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto s : state) {
auto trace = op->TraceString(*ctx, verbose);
}
testing::StopTiming();
}
BENCHMARK(BM_ConcatInputRange);

View File

@ -434,37 +434,38 @@ TEST_F(LocalRendezvousTest, TransferDummyDeviceContext) {
args1.device_context->Unref();
}
void BM_SendRecv(int iters) {
void BM_SendRecv(::testing::benchmark::State& state) {
Rendezvous* rendez = NewLocalRendezvous();
Tensor orig = V("val");
Tensor val(DT_STRING, TensorShape({}));
bool is_dead = false;
Rendezvous::Args args;
if (iters > 0) {
while (iters--) {
for (auto s : state) {
TF_CHECK_OK(rendez->Send(KeyFoo(), args, orig, is_dead));
TF_CHECK_OK(rendez->Recv(KeyFoo(), args, &val, &is_dead));
}
CHECK_EQ(V(val), V(orig));
}
rendez->Unref();
}
BENCHMARK(BM_SendRecv);
void BM_RecvSend(int iters) {
void BM_RecvSend(::testing::benchmark::State& state) {
Rendezvous* rendez = NewLocalRendezvous();
Tensor orig = V("val");
Tensor val(DT_STRING, TensorShape({}));
bool is_dead = false;
Rendezvous::Args args;
if (iters > 0) {
while (iters--) {
for (auto s : state) {
bool received = false;
rendez->RecvAsync(
KeyFoo(), args,
[&val, &received](const Status& s, const Rendezvous::Args& send_args,
const Rendezvous::Args& recv_args,
const Tensor& tensor, bool is_dead) {
[&val, &received](const Status& /*s*/,
const Rendezvous::Args& /*send_args*/,
const Rendezvous::Args& /*recv_args*/,
const Tensor& tensor, bool /*is_dead*/) {
val = tensor;
received = true;
});
@ -472,26 +473,29 @@ void BM_RecvSend(int iters) {
CHECK(received);
}
CHECK_EQ(V(val), V(orig));
}
rendez->Unref();
}
BENCHMARK(BM_RecvSend);
void BM_PingPong(int iters) {
CHECK_GT(iters, 0);
void BM_PingPong(::testing::benchmark::State& state) {
const int messages_count = state.range(0);
auto* cm = new CancellationManager();
thread::ThreadPool* pool = new thread::ThreadPool(Env::Default(), "test", 1);
// The main thread sends "foo" for iters times and receives "bar"
// for iters times. The other thread sends "bar" for iters times
// and receives "foo" for iters times.
// Benchmark loop
// In each iteration:
// The main thread sends "foo" for messages_count times and receives "bar"
// for messages_count times. The other thread sends "bar" for
// messages_count times and receives "foo" for messages_count times.
for (auto s : state) {
Rendezvous* rendez = NewLocalRendezvous();
pool->Schedule([rendez, iters]() {
pool->Schedule([rendez, messages_count]() {
Tensor bar = V("bar");
Tensor foo(DT_STRING, TensorShape({}));
bool is_dead = false;
Rendezvous::Args args;
for (int i = 0; i < iters; ++i) {
for (int i = 0; i < messages_count; ++i) {
TF_CHECK_OK(rendez->Recv(KeyFoo(), args, &foo, &is_dead));
TF_CHECK_OK(rendez->Send(KeyBar(), args, bar, is_dead));
}
@ -502,15 +506,17 @@ void BM_PingPong(int iters) {
bool is_dead = false;
Rendezvous::Args args;
args.cancellation_manager = cm;
for (int i = 0; i < iters; ++i) {
for (int i = 0; i < messages_count; ++i) {
TF_CHECK_OK(rendez->Send(KeyFoo(), args, foo, is_dead));
TF_CHECK_OK(rendez->Recv(KeyBar(), args, &bar, &is_dead));
}
CHECK_EQ("bar", V(bar));
}
state.SetItemsProcessed(messages_count * state.iterations());
delete pool;
delete cm;
}
BENCHMARK(BM_PingPong);
BENCHMARK(BM_PingPong)->Arg(100)->Arg(200)->Arg(300);
} // namespace
} // namespace tensorflow
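One more aspect of the migration shows up in BM_PingPong above: per-benchmark parameters no longer ride on the old iters argument but are registered with ->Arg(...) and read back through state.range(0). A stripped-down, illustrative sketch of that flow (names and body invented):

#include <cstdint>

#include "tensorflow/core/platform/test_benchmark.h"

namespace tensorflow {
namespace {

void BM_ArgSketch(::testing::benchmark::State& state) {
  const int messages_count = state.range(0);  // 100, 200, or 300 per the registration
  int64_t sent = 0;
  for (auto s : state) {
    // Stand-in for exchanging messages_count messages per benchmark iteration.
    sent += messages_count;
    testing::DoNotOptimize(sent);
  }
  state.SetItemsProcessed(messages_count * state.iterations());
}
BENCHMARK(BM_ArgSketch)->Arg(100)->Arg(200)->Arg(300);

}  // namespace
}  // namespace tensorflow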

View File

@ -684,19 +684,24 @@ static std::vector<int64> MakeSizes(int arg) {
return sizes;
}
static void BM_TensorShape_Init(int iters, int arg) {
void BM_TensorShape_Init(::testing::benchmark::State& state) {
const int arg = state.range(0);
auto sizes = MakeSizes(arg);
while (--iters > 0) {
for (auto s : state) {
TensorShape shape(sizes);
tensorflow::testing::DoNotOptimize(shape.num_elements());
}
}
BENCHMARK(BM_TensorShape_Init)->Arg(0)->Arg(1)->Arg(2)->Arg(3)->Arg(4);
static void BM_TensorShape_Assign(int iters, int arg) {
TensorShape s(MakeSizes(arg));
while (--iters > 0) {
TensorShape s2 = s;
void BM_TensorShape_Assign(::testing::benchmark::State& state) {
const int arg = state.range(0);
TensorShape shape(MakeSizes(arg));
for (auto s : state) {
const TensorShape s2 = shape;
tensorflow::testing::DoNotOptimize(s2);
}
}
BENCHMARK(BM_TensorShape_Assign)->Arg(0)->Arg(1)->Arg(2)->Arg(3)->Arg(4);

View File

@ -1468,19 +1468,19 @@ TEST(SummarizeValue, STRING_PRINT_V2) {
x.SummarizeValue(16, true));
}
void BM_CreateAndDestroy(int iters) {
void BM_CreateAndDestroy(::testing::benchmark::State& state) {
TensorShape shape({10, 20});
while (--iters) {
for (auto s : state) {
Tensor t(DT_FLOAT, shape);
}
}
BENCHMARK(BM_CreateAndDestroy);
void BM_Assign(int iters) {
void BM_Assign(::testing::benchmark::State& state) {
Tensor a(DT_FLOAT, TensorShape({10, 20}));
Tensor b(DT_FLOAT, TensorShape({10, 20}));
bool a_to_b = true;
while (--iters) {
for (auto s : state) {
if (a_to_b) {
b = a;
} else {
@ -1498,20 +1498,20 @@ TEST(Tensor, EmptyTensorData) {
}
// Benchmark create and destroy a tensor, with an allocated buffer.
void BM_CreateAndDestroyWithBuf(int iters) {
void BM_CreateAndDestroyWithBuf(::testing::benchmark::State& state) {
TensorShape shape({10, 20});
Allocator* allocator = cpu_allocator();
while (--iters) {
for (auto s : state) {
Tensor a(allocator, DT_FLOAT, shape);
}
}
BENCHMARK(BM_CreateAndDestroyWithBuf);
// Benchmark create+copy a tensor, with an allocated buffer.
void BM_CreateAndCopyCtrWithBuf(int iters) {
void BM_CreateAndCopyCtrWithBuf(::testing::benchmark::State& state) {
TensorShape shape({10, 20});
Allocator* allocator = cpu_allocator();
while (--iters) {
for (auto s : state) {
Tensor a(allocator, DT_FLOAT, shape);
Tensor b(a);
}
@ -1519,10 +1519,10 @@ void BM_CreateAndCopyCtrWithBuf(int iters) {
BENCHMARK(BM_CreateAndCopyCtrWithBuf);
// Benchmark create+move a tensor, with an allocated buffer.
void BM_CreateAndMoveCtrWithBuf(int iters) {
void BM_CreateAndMoveCtrWithBuf(::testing::benchmark::State& state) {
TensorShape shape({10, 20});
Allocator* allocator = cpu_allocator();
while (--iters) {
for (auto s : state) {
Tensor a(allocator, DT_FLOAT, shape);
Tensor b(std::move(a));
}
@ -1531,10 +1531,11 @@ BENCHMARK(BM_CreateAndMoveCtrWithBuf);
// Benchmark creating and destroy a host-scalar tensor, using the allocator
// interface.
void BM_CreateAndDestroyHostScalarNonOptimized(int iters) {
void BM_CreateAndDestroyHostScalarNonOptimized(
::testing::benchmark::State& state) {
TensorShape shape({});
Allocator* allocator = cpu_allocator();
while (--iters) {
for (auto s : state) {
Tensor a(allocator, DT_FLOAT, shape);
a.scalar<float>()() = 37.0;
}
@ -1543,32 +1544,33 @@ BENCHMARK(BM_CreateAndDestroyHostScalarNonOptimized);
// Benchmark creating and destroy a host-scalar tensor, using the specialized
// constructor.
void BM_CreateAndDestroyHostScalarOptimized(int iters) {
while (--iters) {
void BM_CreateAndDestroyHostScalarOptimized(
::testing::benchmark::State& state) {
for (auto s : state) {
Tensor a(37.0);
}
}
BENCHMARK(BM_CreateAndDestroyHostScalarOptimized);
static void BM_FromProto(int iters, int size) {
testing::StopTiming();
void BM_FromProto(::testing::benchmark::State& state) {
const int size = state.range(0);
TensorShape shape({size});
Allocator* allocator = cpu_allocator();
Tensor a(allocator, DT_FLOAT, shape);
std::fill_n(a.flat<float>().data(), size, 42.0);
TensorProto p;
a.AsProtoField(&p);
testing::StartTiming();
while (--iters) {
for (auto s : state) {
Tensor b;
ASSERT_TRUE(b.FromProto(p));
}
testing::StopTiming();
}
BENCHMARK(BM_FromProto)->Range(1, 1 << 20);
static void BM_FromProtoCompressed(int iters, int size) {
testing::StopTiming();
void BM_FromProtoCompressed(::testing::benchmark::State& state) {
const int size = state.range(0);
TensorShape shape({size});
Allocator* allocator = cpu_allocator();
Tensor a(allocator, DT_FLOAT, shape);
@ -1576,17 +1578,16 @@ static void BM_FromProtoCompressed(int iters, int size) {
TensorProto p;
a.AsProtoField(&p);
tensor::CompressTensorProtoInPlace(&p);
testing::StartTiming();
while (--iters) {
for (auto s : state) {
Tensor b;
ASSERT_TRUE(b.FromProto(p));
}
testing::StopTiming();
}
BENCHMARK(BM_FromProtoCompressed)->Range(1, 1 << 20);
static void BM_FromProtoCompressedZero(int iters, int size) {
testing::StopTiming();
void BM_FromProtoCompressedZero(::testing::benchmark::State& state) {
const int size = state.range(0);
TensorShape shape({size});
Allocator* allocator = cpu_allocator();
Tensor a(allocator, DT_FLOAT, shape);
@ -1595,12 +1596,10 @@ static void BM_FromProtoCompressedZero(int iters, int size) {
TensorProto p;
a.AsProtoField(&p);
tensor::CompressTensorProtoInPlace(&p);
testing::StartTiming();
while (--iters) {
for (auto s : state) {
Tensor b;
ASSERT_TRUE(b.FromProto(p));
}
testing::StopTiming();
}
BENCHMARK(BM_FromProtoCompressedZero)->Range(1, 1 << 20);

View File

@ -767,16 +767,49 @@ Status AvgPoolGradTransposer::TransposeNode(TransposeContext* context,
return context->graph_view->GetMutationBuilder()->Apply();
}
Status BiasAddGradTransposer::TransposeNode(TransposeContext* context,
Status BiasAddTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsBiasAddGrad(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFaninPortRankN(*node, 0, 4)) {
DCHECK(IsBiasAdd(*node->node()));
const int rank = GetFanoutPortRank(*node, 0);
if (rank != 4 && rank != 5) {
return Status::OK();
}
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, rank)) {
return Status::OK();
}
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
<< "' with op '" << node->GetOp() << "' from data format '"
<< context->src_format << "' to '" << context->dst_format << "'";
// BiasAdd itself only needs NCHW/NHWC to determine whether C dim is the
// second or the last dim. Therefore, we use the original 4D data format in
// the context to update the node. For the input/output tensor, the
// corresponding 4D or 5D data format is needed.
TF_RETURN_IF_ERROR(UpdateNode(context, node));
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
return context->graph_view->GetMutationBuilder()->Apply();
}
Status BiasAddGradTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsBiasAddGrad(*node->node()));
const int rank = GetFaninPortRank(*node, 0);
if (rank != 4 && rank != 5) {
return Status::OK();
}
if (!ShouldProcess(*context, *node)) {
return Status::OK();
}
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
<< "' with op '" << node->GetOp() << "' from data format '"
<< context->src_format << "' to '" << context->dst_format << "'";
// BiasAddGrad itself only needs NCHW/NHWC to determine whether C dim is the
// second or the last dim. Therefore, we use the original 4D data format in
// the context to update the node. For the input tensor, the corresponding 4D
// or 5D data format is needed.
TF_RETURN_IF_ERROR(UpdateNode(context, node));
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
// No need to update output shape, as it is always of shape 1-D with size the
// feature dimension of `out_backprop`, regardless of whether NCHW or NHWC is
@ -839,7 +872,12 @@ Status Conv2DBackpropInputTransposer::TransposeNode(
Status Conv3DTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsConv3D(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 5)) {
const int rank = GetFanoutPortRank(*node, 0);
if (rank != 5) {
return Status::OK();
}
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
if (!ShouldProcess(*context, *node)) {
return Status::OK();
}
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
@ -854,7 +892,12 @@ Status Conv3DTransposer::TransposeNode(TransposeContext* context,
Status Conv3DBackpropFilterTransposer::TransposeNode(
TransposeContext* context, utils::MutableNodeView* node) {
DCHECK(IsConv3DBackpropFilterV2(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 5)) {
const int rank = GetFanoutPortRank(*node, 0);
if (rank != 5) {
return Status::OK();
}
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
if (!ShouldProcess(*context, *node)) {
return Status::OK();
}
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
@ -872,7 +915,12 @@ Status Conv3DBackpropFilterTransposer::TransposeNode(
Status Conv3DBackpropInputTransposer::TransposeNode(
TransposeContext* context, utils::MutableNodeView* node) {
DCHECK(IsConv3DBackpropInputV2(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 5)) {
const int rank = GetFanoutPortRank(*node, 0);
if (rank != 5) {
return Status::OK();
}
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
if (!ShouldProcess(*context, *node)) {
return Status::OK();
}
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
@ -1081,8 +1129,9 @@ bool LayoutAgnosticOpTransposer::IsAfterDstToSrcTransform(
return false;
}
std::vector<int> LayoutAgnosticOpTransposer::GetVariadic4DFaninPorts(
const TransposeContext& context, const utils::MutableNodeView& node) const {
std::vector<int> LayoutAgnosticOpTransposer::GetVariadicNDFaninPorts(
const TransposeContext& context, const utils::MutableNodeView& node,
int rank) const {
std::vector<int> ports;
const int num_regular_fanins = node.NumRegularFanins();
ports.reserve(num_regular_fanins);
@ -1090,7 +1139,7 @@ std::vector<int> LayoutAgnosticOpTransposer::GetVariadic4DFaninPorts(
const auto& regular_fanin = node.GetRegularFanin(i);
auto* regular_fanin_node = regular_fanin.node_view();
int regular_fanin_port = regular_fanin.index();
if (IsFanoutPortRankN(*regular_fanin_node, regular_fanin_port, 4) &&
if ((IsFanoutPortRankN(*regular_fanin_node, regular_fanin_port, rank)) &&
((IsAfterDstToSrcTransform(context, *regular_fanin_node) &&
IsLayoutAgnosticOp(*regular_fanin_node->node())) ||
IsLayoutOptimizerAddedDstToSrcTranspose(context,
@ -1124,10 +1173,18 @@ Status DefaultLayoutAgnosticOpTransposer::TransposeNode(
Status AddNTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsAddN(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
const int rank = GetFanoutPortRank(*node, 0);
if (rank != 4 && rank != 5) {
return Status::OK();
}
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
if (!ShouldProcess(*context, *node) ||
!IsAfterDstToSrcTransform(*context, *node)) {
return Status::OK();
}
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
<< "' with op '" << node->GetOp() << "' from data format '"
<< context->src_format << "' to '" << context->dst_format << "'";
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, GetDataFaninPorts(*node),
node, kOpTranspose));
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
@ -1284,7 +1341,12 @@ Status BinaryOpTransposer::TransposeNode(TransposeContext* context,
Status ConcatOpTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsConcat(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
const int rank = GetFanoutPortRank(*node, 0);
if (rank != 4 && rank != 5) {
return Status::OK();
}
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
if (!ShouldProcess(*context, *node) ||
!IsAfterDstToSrcTransform(*context, *node)) {
return Status::OK();
}
@ -1297,6 +1359,9 @@ Status ConcatOpTransposer::TransposeNode(TransposeContext* context,
axis_node = n_attr->i();
}
}
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
<< "' with op '" << node->GetOp() << "' from data format '"
<< context->src_format << "' to '" << context->dst_format << "'";
TF_RETURN_IF_ERROR(
UpdateFaninEdgesWithOp(context, {axis_node}, node, kOpDataFormatDimMap));
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
@ -1320,14 +1385,33 @@ Status FillOpTransposer::TransposeNode(TransposeContext* context,
Status IdentityNTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsIdentityN(*node->node()));
const auto ports = GetVariadic4DFaninPorts(*context, *node);
if (!ShouldProcess(*context, *node) || ports.empty()) {
const auto ports_4d = GetVariadicNDFaninPorts(*context, *node, 4);
// Temporarily upgrade the context to obtain the number of 5D fanin ports.
std::vector<int> ports_5d;
{
ScopedDataFormatUpgrader data_format_upgrader(context, 5);
ports_5d = GetVariadicNDFaninPorts(*context, *node, 5);
}
if (!ShouldProcess(*context, *node)) {
return Status::OK();
}
if (!ports_4d.empty()) {
TF_RETURN_IF_ERROR(
UpdateFaninEdgesWithOp(context, ports, node, kOpTranspose));
UpdateFaninEdgesWithOp(context, ports_4d, node, kOpTranspose));
TF_RETURN_IF_ERROR(
UpdateFanoutEdgesWithOp(context, ports, node, kOpTranspose));
UpdateFanoutEdgesWithOp(context, ports_4d, node, kOpTranspose));
}
if (!ports_5d.empty()) {
ScopedDataFormatUpgrader data_format_upgrader(context, 5);
TF_RETURN_IF_ERROR(
UpdateFaninEdgesWithOp(context, ports_5d, node, kOpTranspose));
TF_RETURN_IF_ERROR(
UpdateFanoutEdgesWithOp(context, ports_5d, node, kOpTranspose));
}
return context->graph_view->GetMutationBuilder()->Apply();
}
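The recurring pattern in these hunks is to gate each transposer on rank 4 or 5 and, for rank 5, temporarily upgrade the context's data format via ScopedDataFormatUpgrader (e.g. NHWC/NCHW to their 5-D counterparts NDHWC/NCDHW). The reason the same code can serve both ranks is that the layout permutation is fully determined by the pair of format strings; a small illustrative helper (not part of this change) makes that concrete:

#include <string>
#include <vector>

// Illustrative only: derive the transpose permutation from two format
// strings. "NHWC" -> "NCHW" yields {0, 3, 1, 2}; the 5-D upgrade
// "NDHWC" -> "NCDHW" yields {0, 4, 1, 2, 3}, so the channel dimension is
// handled uniformly for rank-4 and rank-5 tensors.
std::vector<int> GetLayoutPermutation(const std::string& src,
                                      const std::string& dst) {
  std::vector<int> perm(dst.size());
  for (size_t i = 0; i < dst.size(); ++i) {
    // Position in `src` of the dimension that `dst` wants at index i.
    perm[i] = static_cast<int>(src.find(dst[i]));
  }
  return perm;
}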
@ -1519,10 +1603,18 @@ Status SelectTransposer::TransposeNode(TransposeContext* context,
Status ShapeTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsShape(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFaninPortRankN(*node, 0, 4) ||
const int rank = GetFaninPortRank(*node, 0);
if (rank != 4 && rank != 5) {
return Status::OK();
}
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
if (!ShouldProcess(*context, *node) ||
!IsAfterDstToSrcTransform(*context, *node)) {
return Status::OK();
}
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
<< "' with op '" << node->GetOp() << "' from data format '"
<< context->src_format << "' to '" << context->dst_format << "'";
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
TF_RETURN_IF_ERROR(
UpdateFanoutEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
@ -1532,10 +1624,20 @@ Status ShapeTransposer::TransposeNode(TransposeContext* context,
Status ShapeNTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsShapeN(*node->node()));
const auto ports = GetVariadic4DFaninPorts(*context, *node);
// ShapeN requires all input tensors to have the same dimensions. Therefore,
// we simply use the 0th fanin port.
const int rank = GetFaninPortRank(*node, 0);
if (rank != 4 && rank != 5) {
return Status::OK();
}
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
const auto ports = GetVariadicNDFaninPorts(*context, *node, rank);
if (!ShouldProcess(*context, *node) || ports.empty()) {
return Status::OK();
}
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
<< "' with op '" << node->GetOp() << "' from data format '"
<< context->src_format << "' to '" << context->dst_format << "'";
TF_RETURN_IF_ERROR(
UpdateFaninEdgesWithOp(context, ports, node, kOpTranspose));
TF_RETURN_IF_ERROR(
@ -1546,11 +1648,19 @@ Status ShapeNTransposer::TransposeNode(TransposeContext* context,
Status SliceTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsSlice(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
!IsFaninPortsDimsNIfConst(*node, {1, 2}, {4}) ||
const int rank = GetFanoutPortRank(*node, 0);
if (rank != 4 && rank != 5) {
return Status::OK();
}
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
if (!ShouldProcess(*context, *node) ||
!IsFaninPortsDimsNIfConst(*node, {1, 2}, {rank}) ||
!IsAfterDstToSrcTransform(*context, *node)) {
return Status::OK();
}
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
<< "' with op '" << node->GetOp() << "' from data format '"
<< context->src_format << "' to '" << context->dst_format << "'";
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
TF_RETURN_IF_ERROR(
UpdateFaninEdgesWithOp(context, {1, 2}, node, kOpDataFormatVecPermute));
@ -1839,18 +1949,17 @@ string GetDeviceName(const VirtualPlacer* virtual_placer, const NodeDef& node) {
bool IsDefaultLayoutSensitiveOp(const NodeDef& node) {
static absl::flat_hash_set<string>* default_layout_sensitive_ops =
new absl::flat_hash_set<std::string>(
{"AvgPool", "BiasAdd", "Conv2D", "DepthwiseConv2dNative",
"DepthToSpace", "FusedBatchNorm", "FusedBatchNormV2",
"FusedBatchNormV3", "FusedConv2DBiasActivation", "MaxPool",
"SpaceToDepth"});
{"AvgPool", "Conv2D", "DepthwiseConv2dNative", "DepthToSpace",
"FusedBatchNorm", "FusedBatchNormV2", "FusedBatchNormV3",
"FusedConv2DBiasActivation", "MaxPool", "SpaceToDepth"});
return default_layout_sensitive_ops->find(node.op()) !=
default_layout_sensitive_ops->end();
}
bool IsLayoutSensitiveOp(const NodeDef& node) {
return IsDefaultLayoutSensitiveOp(node) || IsAvgPoolGrad(node) ||
IsBiasAddGrad(node) || IsConv2DBackpropFilter(node) ||
IsConv2DBackpropInput(node) ||
IsBiasAdd(node) || IsBiasAddGrad(node) ||
IsConv2DBackpropFilter(node) || IsConv2DBackpropInput(node) ||
IsDepthwiseConv2dNativeBackpropFilter(node) ||
IsDepthwiseConv2dNativeBackpropInput(node) ||
IsFusedBatchNormEx(node) || IsFusedBatchNormGrad(node) ||

View File

@ -210,6 +210,14 @@ class DefaultLayoutSensitiveOpTransposer : public LayoutSensitiveOpTransposer {
utils::MutableNodeView* node) override;
};
class BiasAddTransposer : public LayoutSensitiveOpTransposer {
public:
explicit BiasAddTransposer() : LayoutSensitiveOpTransposer() {}
Status TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) override;
};
class AvgPoolGradTransposer : public LayoutSensitiveOpTransposer {
public:
explicit AvgPoolGradTransposer() : LayoutSensitiveOpTransposer() {}
@ -319,9 +327,9 @@ class LayoutAgnosticOpTransposer : public Transposer {
bool IsAfterDstToSrcTransform(const TransposeContext& context,
const utils::MutableNodeView& node) const;
std::vector<int> GetVariadic4DFaninPorts(
const TransposeContext& context,
const utils::MutableNodeView& node) const;
std::vector<int> GetVariadicNDFaninPorts(const TransposeContext& context,
const utils::MutableNodeView& node,
int rank) const;
};
class DefaultLayoutAgnosticOpTransposer : public LayoutAgnosticOpTransposer {

View File

@ -27,6 +27,9 @@ std::shared_ptr<Transposer> TransposerFactory::GetTransposer(
return GetOrCreateIfNotFound<DefaultLayoutSensitiveOpTransposer>(
"DefaultLayoutSensitiveOp");
}
if (IsBiasAdd(node)) {
return GetOrCreateIfNotFound<BiasAddTransposer>("BiasAdd");
}
if (IsAvgPoolGrad(node)) {
return GetOrCreateIfNotFound<AvgPoolGradTransposer>("AvgPoolGrad");
}

View File

@ -48,7 +48,6 @@ load(
)
load(
"//third_party/mkl:build_defs.bzl",
"if_mkl_ml",
"mkl_deps",
)
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
@ -3241,7 +3240,6 @@ cc_library(
deps = [
":aggregate_ops",
":argmax_op",
":batch_matmul_op",
":betainc_op",
":bincount_op",
":bucketize_op",
@ -3337,14 +3335,27 @@ tf_kernel_library(
tf_kernel_library(
name = "batch_matmul_op",
deps = [":matmul_op"],
)
tf_kernel_library(
name = "matmul_op",
# <prefix>*impl.h are excluded by default from the CPU build, add explicitly.
hdrs = ["batch_matmul_op_impl.h"],
prefix = "batch_matmul_op",
deps = MATH_DEPS + [":eigen_contraction_kernel"] + if_mkl_ml([
"//third_party/mkl:intel_binary_blob",
]) + if_cuda_or_rocm([
"//tensorflow/core/kernels:gpu_utils",
]),
hdrs = ["matmul_op_impl.h"],
defines = select({
":xsmm": ["TENSORFLOW_USE_LIBXSMM"],
"//conditions:default": [],
}),
prefix = "matmul_op",
deps = MATH_DEPS + [
":eigen_contraction_kernel",
":fused_eigen_output_kernels",
] + select({
":xsmm": ["@libxsmm_archive//:xsmm_avx"],
"//conditions:default": [],
}) + mkl_deps() + if_cuda([
"//tensorflow/core/platform/default/build_config:cublas_plugin",
]) + if_cuda_or_rocm([":gpu_utils"]),
)
tf_kernel_library(
@ -3406,28 +3417,6 @@ tf_kernel_library(
]),
)
tf_kernel_library(
name = "matmul_op",
srcs = [
"matmul_op.cc",
"matmul_op_fused.cc",
],
hdrs = ["matmul_op.h"],
defines = select({
":xsmm": ["TENSORFLOW_USE_LIBXSMM"],
"//conditions:default": [],
}),
deps = MATH_DEPS + [
":eigen_contraction_kernel",
":fused_eigen_output_kernels",
] + select({
":xsmm": ["@libxsmm_archive//:xsmm_avx"],
"//conditions:default": [],
}) + mkl_deps() + if_cuda([
"//tensorflow/core/platform/default/build_config:cublas_plugin",
]) + if_cuda_or_rocm([":gpu_utils"]),
)
tf_kernel_library(
name = "reduction_ops",
gpu_srcs = ["reduction_gpu_kernels.cu.h"],
@ -3620,25 +3609,6 @@ tf_cuda_cc_test(
],
)
tf_cuda_cc_test(
name = "batch_matmul_op_test",
size = "small",
srcs = ["batch_matmul_op_test.cc"],
deps = [
":batch_matmul_op",
":broadcast_to_op",
":ops_testutil",
":ops_util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
tf_cuda_cc_test(
name = "scan_ops_test",
size = "small",
@ -5868,8 +5838,8 @@ filegroup(
"identity_op.h",
"immutable_constant_op.cc",
"immutable_constant_op.h",
"matmul_op.cc",
"matmul_op.h",
"matmul_op_impl.h",
"matmul_op_real.cc",
"no_op.cc",
"no_op.h",
"one_hot_op.cc",
@ -5948,7 +5918,6 @@ filegroup(
srcs = [
"argmax_op.h",
"avgpooling_op.h",
"batch_matmul_op_impl.h",
"batch_norm_op.h",
"bincount_op.h",
"broadcast_to_op.h",
@ -6039,7 +6008,6 @@ filegroup(
":android_extended_ops_headers",
"argmax_op.cc",
"avgpooling_op.cc",
"batch_matmul_op_real.cc",
"batch_norm_op.cc",
"bcast_ops.cc",
"check_numerics_op.cc",
@ -6480,6 +6448,7 @@ cc_library(
"//tensorflow/core/platform:strong_hash",
"//third_party/eigen3",
"//third_party/fft2d:fft2d_headers",
"//third_party/icu/data:conversion_data",
"@com_google_absl//absl/base",
"@com_google_protobuf//:protobuf",
"@fft2d",
@ -7431,7 +7400,6 @@ test_suite(
"manual", # Avoid redundancy when using wildcard test patterns.
],
tests = [
":batch_matmul_op_test",
":batch_norm_op_test",
":broadcast_to_op_test",
":cast_op_test",

View File

@ -1,257 +0,0 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/kernels/broadcast_to_op.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
namespace tensorflow {
namespace {
Node* BroadcastTo(Graph* g, Node* input, Node* shape) {
Node* ret;
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BroadcastTo")
.Input(input)
.Input(shape)
.Finalize(g, &ret));
return ret;
}
Node* BatchMatmulV2(Graph* g, Node* in0, Node* in1, bool adj_x, bool adj_y) {
Node* ret;
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BatchMatMulV2")
.Input(in0)
.Input(in1)
.Attr("adj_x", adj_x)
.Attr("adj_y", adj_y)
.Finalize(g, &ret));
return ret;
}
template <typename T>
static Graph* BatchMatmul(int b, int m, int k, int n, bool adjoint_a,
bool adjoint_b, DataType type) {
Graph* g = new Graph(OpRegistry::Global());
Tensor in0(type, adjoint_a ? TensorShape({b, k, m}) : TensorShape({b, m, k}));
in0.flat<T>().setRandom();
Tensor in1(type, adjoint_b ? TensorShape({b, n, k}) : TensorShape({b, k, n}));
in1.flat<T>().setRandom();
test::graph::BatchMatmul(g, test::graph::Constant(g, in0),
test::graph::Constant(g, in1), adjoint_a, adjoint_b);
return g;
}
template <typename T>
static Graph* BatchMatmulWithBroadcast(int b0, int b1, int m, int k, int n,
bool manual_broadcast, DataType type) {
Graph* g = new Graph(OpRegistry::Global());
Tensor in0(type, TensorShape({b0, m, k}));
in0.flat<T>().setRandom();
Tensor in1(type, TensorShape({b1, k, n}));
in1.flat<T>().setRandom();
Tensor broadcasted_in0_shape(DT_INT64, TensorShape({3}));
Tensor broadcasted_in1_shape(DT_INT64, TensorShape({3}));
Node* in0_node = nullptr;
Node* in1_node = nullptr;
if (manual_broadcast) {
for (int i = 0; i < 3; ++i) {
auto vec0 = broadcasted_in0_shape.vec<int64>();
auto vec1 = broadcasted_in1_shape.vec<int64>();
vec0(i) = (i == 0 ? std::max(b0, b1) : in0.shape().dim_size(i));
vec1(i) = (i == 0 ? std::max(b0, b1) : in1.shape().dim_size(i));
}
in0_node = BroadcastTo(g, test::graph::Constant(g, in0),
test::graph::Constant(g, broadcasted_in0_shape));
in1_node = BroadcastTo(g, test::graph::Constant(g, in1),
test::graph::Constant(g, broadcasted_in1_shape));
} else {
in0_node = test::graph::Constant(g, in0);
in1_node = test::graph::Constant(g, in1);
}
BatchMatmulV2(g, in0_node, in1_node, false, false);
return g;
}
#define BM_BatchMatmulDev(B, M, K, N, TA, TB, T, TFTYPE, DEVICE) \
static void \
BM_BatchMatmul##_##B##_##M##_##K##_##N##_##TA##_##TB##_##TFTYPE##_##DEVICE( \
int iters) { \
testing::UseRealTime(); \
testing::ItemsProcessed(static_cast<int64>(iters) * B * M * K * N * 2); \
test::Benchmark(#DEVICE, BatchMatmul<T>(B, M, K, N, TA, TB, TFTYPE)) \
.Run(iters); \
} \
BENCHMARK( \
BM_BatchMatmul##_##B##_##M##_##K##_##N##_##TA##_##TB##_##TFTYPE##_##DEVICE);
#define BM_BatchMatmul(B, M, K, N, TA, TB) \
BM_BatchMatmulDev(B, M, K, N, TA, TB, float, DT_FLOAT, cpu);
// BM_BatchMatmulDev(B, M, K, N, TA, TB, std::complex<float>, DT_COMPLEX64,
// cpu);
// BM_BatchMatmulDev(B, M, K, N, TA, TB, float, DT_FLOAT, gpu);
/* Uncomment to enable benchmarks for double & complex types: */
// BM_BatchMatmulDev(B, M, K, N, TA, TB, std::complex<float>, DT_COMPLEX64,
// gpu);
// BM_BatchMatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, cpu); \
// BM_BatchMatmulDev(M, K, N, TA, TB, std::complex<double>, DT_COMPLEX128, cpu);
// \
// BM_BatchMatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, gpu); \
// BM_BatchMatmulDev(M, K, N, TA, TB, std::complex<double>, DT_COMPLEX128, gpu);
// Macro arguments names: --------------------------------------------------- //
// B1: batch size of LHS
// B2: batch size of RHS
// M: outer dimension of LHS
// K: inner dimensions of LHS and RHS
// N: outer dimension of RHS
// MB: boolean indicating whether to use manual broadcasting
// T: C++ type of scalars (e.g. float, std::complex)
// TT: TensorFlow type of scalars (e.g. DT_FLOAT, DT_COMPLEX128)
// D: Device (e.g. cpu, gpu)
#define BM_BatchMatmulBCastDev(B1, B2, M, K, N, MB, T, TT, D) \
static void \
BM_BatchMatmulBCast##_##B1##_##B2##_##M##_##K##_##N##_##MB##_##TT##_##D( \
int iters) { \
testing::UseRealTime(); \
testing::ItemsProcessed(static_cast<int64>(iters) * std::max(B1, B2) * M * \
K * N * 2); \
test::Benchmark(#D, BatchMatmulWithBroadcast<T>(B1, B2, M, K, N, MB, TT)) \
.Run(iters); \
} \
BENCHMARK( \
BM_BatchMatmulBCast##_##B1##_##B2##_##M##_##K##_##N##_##MB##_##TT##_##D);
#define BM_BatchMatmulBCast(B1, B2, M, K, N, MB) \
BM_BatchMatmulBCastDev(B1, B2, M, K, N, MB, float, DT_FLOAT, cpu);
// Typical fully connected layers
BM_BatchMatmulBCast(1, 128, 1, 1024, 1024, true);
BM_BatchMatmulBCast(1, 128, 1, 1024, 1024, false);
BM_BatchMatmulBCast(128, 1, 1, 1024, 1024, true);
BM_BatchMatmulBCast(128, 1, 1, 1024, 1024, false);
BM_BatchMatmulBCast(1, 128, 128, 1024, 1024, true);
BM_BatchMatmulBCast(1, 128, 128, 1024, 1024, false);
BM_BatchMatmulBCast(128, 1, 128, 1024, 1024, true);
BM_BatchMatmulBCast(128, 1, 128, 1024, 1024, false);
// Square matmul.
BM_BatchMatmulBCast(1, 128, 512, 512, 512, true);
BM_BatchMatmulBCast(1, 128, 512, 512, 512, false);
BM_BatchMatmulBCast(128, 1, 512, 512, 512, true);
BM_BatchMatmulBCast(128, 1, 512, 512, 512, false);
BM_BatchMatmulBCast(1, 128, 1024, 1024, 1024, true);
BM_BatchMatmulBCast(1, 128, 1024, 1024, 1024, false);
BM_BatchMatmulBCast(128, 1, 1024, 1024, 1024, true);
BM_BatchMatmulBCast(128, 1, 1024, 1024, 1024, false);
// Matrix-vector multiplies.
BM_BatchMatmulBCast(1, 128, 10000, 200, 1, true);
BM_BatchMatmulBCast(1, 128, 10000, 200, 1, false);
BM_BatchMatmulBCast(128, 1, 10000, 200, 1, true);
BM_BatchMatmulBCast(128, 1, 10000, 200, 1, false);
// Vector-matrix multiplies.
BM_BatchMatmulBCast(1, 128, 1, 200, 10000, true);
BM_BatchMatmulBCast(1, 128, 1, 200, 10000, false);
BM_BatchMatmulBCast(128, 1, 1, 200, 10000, true);
BM_BatchMatmulBCast(128, 1, 1, 200, 10000, false);
// Typical fully connected layers
BM_BatchMatmul(1, 1, 1024, 1024, false, false);
BM_BatchMatmul(1, 8, 1024, 1024, false, false);
BM_BatchMatmul(1, 16, 1024, 1024, false, false);
BM_BatchMatmul(1, 128, 1024, 1024, false, false);
BM_BatchMatmul(2, 1, 1024, 1024, false, false);
BM_BatchMatmul(2, 8, 1024, 1024, false, false);
BM_BatchMatmul(2, 16, 1024, 1024, false, false);
BM_BatchMatmul(2, 128, 1024, 1024, false, false);
BM_BatchMatmul(8, 1, 1024, 1024, false, false);
BM_BatchMatmul(8, 8, 1024, 1024, false, false);
BM_BatchMatmul(8, 16, 1024, 1024, false, false);
BM_BatchMatmul(8, 128, 1024, 1024, false, false);
BM_BatchMatmul(32, 1, 1024, 1024, false, false);
BM_BatchMatmul(32, 8, 1024, 1024, false, false);
BM_BatchMatmul(32, 16, 1024, 1024, false, false);
BM_BatchMatmul(32, 128, 1024, 1024, false, false);
// Square matmul.
BM_BatchMatmul(1, 32, 32, 32, false, false);
BM_BatchMatmul(1, 128, 128, 128, false, false);
BM_BatchMatmul(1, 256, 256, 256, false, false);
BM_BatchMatmul(1, 1024, 1024, 1024, false, false);
BM_BatchMatmul(1, 2048, 2048, 2048, false, false);
BM_BatchMatmul(2, 32, 32, 32, false, false);
BM_BatchMatmul(2, 128, 128, 128, false, false);
BM_BatchMatmul(2, 256, 256, 256, false, false);
BM_BatchMatmul(2, 1024, 1024, 1024, false, false);
BM_BatchMatmul(2, 2048, 2048, 2048, false, false);
BM_BatchMatmul(4, 32, 32, 32, false, false);
BM_BatchMatmul(4, 128, 128, 128, false, false);
BM_BatchMatmul(4, 256, 256, 256, false, false);
BM_BatchMatmul(4, 1024, 1024, 1024, false, false);
BM_BatchMatmul(4, 2048, 2048, 2048, false, false);
BM_BatchMatmul(8, 32, 32, 32, false, false);
BM_BatchMatmul(8, 128, 128, 128, false, false);
BM_BatchMatmul(8, 256, 256, 256, false, false);
BM_BatchMatmul(8, 1024, 1024, 1024, false, false);
BM_BatchMatmul(8, 2048, 2048, 2048, false, false);
BM_BatchMatmul(32, 32, 32, 32, false, false);
BM_BatchMatmul(32, 128, 128, 128, false, false);
BM_BatchMatmul(32, 256, 256, 256, false, false);
BM_BatchMatmul(32, 1024, 1024, 1024, false, false);
BM_BatchMatmul(32, 2048, 2048, 2048, false, false);
// Matrix-vector multiplies.
BM_BatchMatmul(1, 10000, 200, 1, false, false);
BM_BatchMatmul(8, 10000, 200, 1, false, false);
BM_BatchMatmul(32, 10000, 200, 1, false, false);
BM_BatchMatmul(1, 10000, 200, 1, true, false);
BM_BatchMatmul(8, 10000, 200, 1, true, false);
BM_BatchMatmul(32, 10000, 200, 1, true, false);
BM_BatchMatmul(1, 10000, 200, 1, false, true);
BM_BatchMatmul(8, 10000, 200, 1, false, true);
BM_BatchMatmul(32, 10000, 200, 1, false, true);
BM_BatchMatmul(1, 10000, 200, 1, true, true);
BM_BatchMatmul(8, 10000, 200, 1, true, true);
BM_BatchMatmul(32, 10000, 200, 1, true, true);
// Vector-matrix multiplies.
BM_BatchMatmul(1, 1, 200, 10000, false, false);
BM_BatchMatmul(8, 1, 200, 10000, false, false);
BM_BatchMatmul(32, 1, 200, 10000, false, false);
BM_BatchMatmul(1, 1, 200, 10000, true, false);
BM_BatchMatmul(8, 1, 200, 10000, true, false);
BM_BatchMatmul(32, 1, 200, 10000, true, false);
BM_BatchMatmul(1, 1, 200, 10000, false, true);
BM_BatchMatmul(8, 1, 200, 10000, false, true);
BM_BatchMatmul(32, 1, 200, 10000, false, true);
BM_BatchMatmul(1, 1, 200, 10000, true, true);
BM_BatchMatmul(8, 1, 200, 10000, true, true);
BM_BatchMatmul(32, 1, 200, 10000, true, true);
} // namespace
} // namespace tensorflow

View File

@ -38,6 +38,8 @@ namespace data {
/* static */ constexpr const char* const CacheDatasetOp::kOutputTypes;
/* static */ constexpr const char* const CacheDatasetOp::kOutputShapes;
namespace {
constexpr char kKeyStrFormat[] = "%%%zuzu_%%%zuzu";
constexpr char kPaddingSizeStrFormat[] = "%zu";
constexpr char kFileDatasetPrefix[] = "File";
@ -57,6 +59,14 @@ constexpr char kCacheCompleted[] = "cache_completed";
constexpr char kIndex[] = "index";
constexpr char kImpl[] = "Impl";
constexpr char kCacheDataset[] = "CacheDataset";
constexpr char kIncompleteCacheErrorMessage[] =
"The calling iterator did not fully read the dataset being cached. In "
"order to avoid unexpected truncation of the dataset, the partially cached "
"contents of the dataset will be discarded. This can happen if you have "
"an input pipeline similar to `dataset.cache().take(k).repeat()`. You "
"should use `dataset.take(k).cache().repeat()` instead.";
} // namespace
class CacheDatasetOp::FileDatasetBase : public DatasetBase {
public:
@ -220,6 +230,7 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase {
~FileWriterIterator() override {
if (!dataset()->env_->FileExists(MetaFilename(filename_)).ok()) {
LOG(WARNING) << kIncompleteCacheErrorMessage;
std::vector<string> cache_files;
Status s = dataset()->env_->GetMatchingPaths(
strings::StrCat(filename_, "*"), &cache_files);
@ -754,13 +765,7 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
~MemoryWriterIterator() override {
mutex_lock l(mu_);
if (!temp_cache_.empty() && !cache_->IsCompleted()) {
LOG(WARNING)
<< "The calling iterator did not fully read the dataset being "
"cached. In order to avoid unexpected truncation of the "
"dataset, the partially cached contents of the dataset "
"will be discarded. This can happen if you have an input "
"pipeline similar to `dataset.cache().take(k).repeat()`. "
"You should use `dataset.take(k).cache().repeat()` instead.";
LOG(WARNING) << kIncompleteCacheErrorMessage;
cache_->Reset();
}
}

View File

@ -482,7 +482,10 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
VLOG(1) << "Failed to get element from worker "
<< task_to_process->address << ": " << s;
task_to_process->in_use = false;
status_ = s;
status_ = Status(
s.code(),
absl::StrCat("Failed to get element from worker ",
task_to_process->address, ": ", s.error_message()));
get_next_cv_.notify_all();
return;
}
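The change above keeps the worker error's status code but prepends which worker the element request failed against. As a standalone sketch, the pattern amounts to something like the following (helper name invented for illustration):

#include <string>

#include "absl/strings/str_cat.h"
#include "tensorflow/core/platform/status.h"

namespace tensorflow {

// Hypothetical helper: wrap a failed Status with the worker address while
// preserving the original error code.
Status AddWorkerContext(const Status& s, const std::string& worker_address) {
  if (s.ok()) return s;
  return Status(s.code(), absl::StrCat("Failed to get element from worker ",
                                       worker_address, ": ", s.error_message()));
}

}  // namespace tensorflow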

View File

@ -189,14 +189,18 @@ static Graph* DynamicPartition(int num_partitions, int dim) {
}
#define BM_DYNAMIC_PARTITION(DEVICE, T, num) \
static void BM_##DEVICE##_dynpart_##T##_##num(int iters, int dim) { \
static void BM_##DEVICE##_dynpart_##T##_##num( \
::testing::benchmark::State& state) { \
const int dim = state.range(0); \
\
const int64 items = ((128 << 20) / sizeof(T)); \
const int64 tot = static_cast<int64>(iters) * items; \
testing::ItemsProcessed(tot); \
testing::UseRealTime(); \
test::Benchmark(#DEVICE, DynamicPartition<T>(num, dim)).Run(iters); \
test::Benchmark(#DEVICE, DynamicPartition<T>(num, dim), \
/*old_benchmark_api=*/false) \
.Run(state); \
const int64 tot = static_cast<int64>(state.iterations()) * items; \
state.SetItemsProcessed(tot); \
} \
BENCHMARK(BM_##DEVICE##_dynpart_##T##_##num)->Arg(1)->Arg(256)
BENCHMARK(BM_##DEVICE##_dynpart_##T##_##num)->UseRealTime()->Arg(1)->Arg(256)
BM_DYNAMIC_PARTITION(cpu, float, 2);
BM_DYNAMIC_PARTITION(cpu, float, 100);
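For graph-driven benchmarks, the same migration threads state through test::Benchmark(...).Run(state) and tells the helper explicitly that the new API is in use, while real-time measurement moves from a call inside the body to the registration chain. Expanded out of the macro above, the shape is roughly as follows (the graph builder here is a trivial stand-in for something like DynamicPartition<T>(num, dim)):

#include <cstdint>

#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/platform/test_benchmark.h"

namespace tensorflow {
namespace {

// Stand-in graph builder: a graph holding a single constant of `dim` floats.
Graph* MakeBenchmarkGraph(int dim) {
  Graph* g = new Graph(OpRegistry::Global());
  Tensor t(DT_FLOAT, TensorShape({dim}));
  t.flat<float>().setConstant(1.0f);
  test::graph::Constant(g, t);
  return g;
}

void BM_GraphSketch(::testing::benchmark::State& state) {
  const int dim = state.range(0);
  test::Benchmark("cpu", MakeBenchmarkGraph(dim), /*old_benchmark_api=*/false)
      .Run(state);
  state.SetItemsProcessed(static_cast<int64_t>(state.iterations()) * dim);
}
BENCHMARK(BM_GraphSketch)->UseRealTime()->Arg(1)->Arg(256);

}  // namespace
}  // namespace tensorflow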

View File

@ -1376,7 +1376,7 @@ TEST(EigenSpatialConvolutionsTest, SpatialConvContractionMapper) {
}
template <typename T>
static void PackRhsHelper(int iters,
static void PackRhsHelper(::testing::benchmark::State& state,
/* Input dimensions: */
int input_batches, int input_cols, int input_rows,
int input_depth,
@ -1393,9 +1393,6 @@ static void PackRhsHelper(int iters,
// Set random seed for benchmark repeatability.
srand(12345);
tensorflow::testing::UseRealTime();
tensorflow::testing::StopTiming();
using Dimensions = Eigen::DSizes<Eigen::Index, 4>;
// Default Eigen::Tensor layout is column major, so we configure dimensions
@ -1547,8 +1544,7 @@ static void PackRhsHelper(int iters,
return (idx / packet_size) * packet_size;
};
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto s : state) {
int input_idx =
num_inputs == 1 ? 1 : internal::random<int>(0, num_inputs - 1);
@ -1571,15 +1567,15 @@ static void PackRhsHelper(int iters,
input_mappers[input_idx].getSubMapper(depth_offset, col_offset);
pack_rhs(packed.data() + packed_offset, sub_mapper, depth, cols);
}
tensorflow::testing::StopTiming();
tensorflow::testing::SetLabel(
state.SetLabel(
absl::StrCat("patch: ", patch_rows, "x", patch_cols, " D", patch_depth,
"; num_patches=", num_patches, " patch_size=", patch_size,
" num_inputs=", num_inputs, " padding=", padding));
}
template <typename T>
static void PackLhsHelper(int iters,
static void PackLhsHelper(::testing::benchmark::State& state,
/* Input dimensions: */
int input_depth,
/* Filter (kernel) dimensions: */
@ -1592,9 +1588,6 @@ static void PackLhsHelper(int iters,
eigen_assert(block_rows <= filter_count);
eigen_assert(block_cols <= input_depth * filter_rows * filter_cols);
tensorflow::testing::UseRealTime();
tensorflow::testing::StopTiming();
using Dimensions = Eigen::DSizes<Eigen::Index, 4>;
// Default Eigen::Tensor layout is column major, so we configure dimensions
@ -1716,8 +1709,7 @@ static void PackLhsHelper(int iters,
const Index max_row = filter_count;
const Index max_col = filter_rows * filter_cols * input_depth;
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto s : state) {
int filter_idx =
num_filters == 1 ? 1 : internal::random<int>(0, num_filters - 1);
@ -1743,8 +1735,7 @@ static void PackLhsHelper(int iters,
pack_lhs(packed.data() + packed_offset, sub_mapper, cols, rows);
#endif
}
tensorflow::testing::StopTiming();
tensorflow::testing::SetLabel(absl::StrCat(
state.SetLabel(absl::StrCat(
"filter: count=", filter_count, " dims=", filter_rows, "x", filter_cols,
"; input: depth=", input_depth, "; num_filers=", num_filters));
}
@ -1777,12 +1768,14 @@ static void PackLhsHelper(int iters,
#define BM_PackRhs(T, N, H, W, C, FC, FH, FW, PAD, SH, SW, ISH, ISW, BR, BC) \
static void BM_RHS_NAME(PackRhs, T, N, H, W, C, FC, FH, FW, PAD, SH, SW, \
ISH, ISW, BR, BC)(int iters) { \
PackRhsHelper<T>(iters, N, H, W, C, FC, FH, FW, PADDING_##PAD, SH, SW, \
ISH, ISW, BR, \
BC)(::testing::benchmark::State & state) { \
PackRhsHelper<T>(state, N, H, W, C, FC, FH, FW, PADDING_##PAD, SH, SW, \
ISH, ISW, BR, BC); \
} \
BENCHMARK(BM_RHS_NAME(PackRhs, T, N, H, W, C, FC, FH, FW, PAD, SH, SW, ISH, \
ISW, BR, BC))
ISW, BR, BC)) \
->UseRealTime()
// Number of input channels (input depth) is equal to the number of patch
// channels (patch depth).
@ -2020,10 +2013,11 @@ BM_PackRhs(/*type*/ qint8, //
BM_CONCAT(BM_##prefix##_##T##_##C##_FC##FC##_##FH##x##FW, _B##BR##x##BC)
#define BM_PackLhs(T, C, FC, FH, FW, BR, BC) \
static void BM_LHS_NAME(PackLhs, T, C, FC, FH, FW, BR, BC)(int iters) { \
PackLhsHelper<T>(iters, C, FC, FH, FW, BR, BC); \
static void BM_LHS_NAME(PackLhs, T, C, FC, FH, FW, BR, \
BC)(::testing::benchmark::State & state) { \
PackLhsHelper<T>(state, C, FC, FH, FW, BR, BC); \
} \
BENCHMARK(BM_LHS_NAME(PackLhs, T, C, FC, FH, FW, BR, BC))
BENCHMARK(BM_LHS_NAME(PackLhs, T, C, FC, FH, FW, BR, BC))->UseRealTime()
// Number of input channels (input depth) is equal to the number of patch
// channels (patch depth).

View File

@ -349,15 +349,16 @@ typedef BenchmarkOptions<ExampleStore<FloatFiller>, kRagged> RaggedFloat;
// B == batch_size, K == num_keys. F == feature_size.
// K must be one of 10, 100, 1000
#define BM_ParseExample(TYPE, B, K, F) \
static void BM_ParseExample##_##TYPE##_##B##_##K##_##F(int iters) { \
static void BM_ParseExample##_##TYPE##_##B##_##K##_##F( \
::testing::benchmark::State& state) { \
int64 items_per_iter = static_cast<int64>(B) * K * F; \
testing::UseRealTime(); \
testing::ItemsProcessed(static_cast<int64>(iters) * items_per_iter); \
test::Benchmark("cpu", ParseExample<TYPE>(B, K, F), nullptr, nullptr, \
nullptr, "SINGLE_THREADED_EXECUTOR") \
.Run(iters); \
nullptr, "SINGLE_THREADED_EXECUTOR", false) \
.Run(state); \
state.SetItemsProcessed(static_cast<int64>(state.iterations()) * \
items_per_iter); \
} \
BENCHMARK(BM_ParseExample##_##TYPE##_##B##_##K##_##F);
BENCHMARK(BM_ParseExample##_##TYPE##_##B##_##K##_##F)->UseRealTime();
#define BM_AllParseExample(Type) \
BM_ParseExample(Type, 1, 10, 1); \
@ -385,15 +386,17 @@ BM_AllParseExample(VarLenDenseFloat);
// K must be one of 10, 100, 1000
// B=0 indicates that a scalar input should be used (instead of a vector).
#define BM_ParseExampleV2(TYPE, B, K, F) \
static void BM_ParseExampleV2##_##TYPE##_##B##_##K##_##F(int iters) { \
static void BM_ParseExampleV2##_##TYPE##_##B##_##K##_##F( \
::testing::benchmark::State& state) { \
int64 items_per_iter = static_cast<int64>(std::max(B, 1)) * K * F; \
testing::UseRealTime(); \
testing::ItemsProcessed(static_cast<int64>(iters) * items_per_iter); \
test::Benchmark("cpu", ParseExampleV2<TYPE>(B, K, F), nullptr, nullptr, \
nullptr, "SINGLE_THREADED_EXECUTOR") \
.Run(iters); \
nullptr, "SINGLE_THREADED_EXECUTOR", \
/*old_benchmark_api=*/false) \
.Run(state); \
state.SetItemsProcessed(static_cast<int64>(state.iterations()) * \
items_per_iter); \
} \
BENCHMARK(BM_ParseExampleV2##_##TYPE##_##B##_##K##_##F);
BENCHMARK(BM_ParseExampleV2##_##TYPE##_##B##_##K##_##F)->UseRealTime();
#define BM_AllParseExampleV2(Type) \
/* Vector Inputs */ \
@ -437,15 +440,17 @@ BM_AllParseExampleV2(RaggedFloat);
// K == num_keys. F == feature_size.
// K must be one of 10, 100, 1000
#define BM_ParseSingleExample(TYPE, K, F) \
static void BM_ParseSingleExample##_##TYPE##_1_##K##_##F(int iters) { \
void BM_ParseSingleExample##_##TYPE##_1_##K##_##F( \
::testing::benchmark::State& state) { \
int64 items_per_iter = K * F; \
testing::UseRealTime(); \
testing::ItemsProcessed(static_cast<int64>(iters) * items_per_iter); \
test::Benchmark("cpu", ParseSingleExample<TYPE>(K, F), nullptr, nullptr, \
nullptr, "SINGLE_THREADED_EXECUTOR") \
.Run(iters); \
nullptr, "SINGLE_THREADED_EXECUTOR", \
/*old_benchmark_api=*/false) \
.Run(state); \
state.SetItemsProcessed(static_cast<int64>(state.iterations()) * \
items_per_iter); \
} \
BENCHMARK(BM_ParseSingleExample##_##TYPE##_1_##K##_##F);
BENCHMARK(BM_ParseSingleExample##_##TYPE##_1_##K##_##F)->UseRealTime();
#define BM_AllParseSingleExample(Type) \
BM_ParseSingleExample(Type, 10, 1); \

View File

@ -133,14 +133,18 @@ static Graph* GatherNd(int dim) {
}
#define BM_GATHER_ND(DEVICE, INDEX) \
static void BM_##DEVICE##_gather_nd_##INDEX(int iters, int dim) { \
const int64 tot = static_cast<int64>(iters) * kLookups * 4; \
testing::ItemsProcessed(tot); \
testing::BytesProcessed(tot * sizeof(float)); \
testing::UseRealTime(); \
test::Benchmark(#DEVICE, GatherNd<INDEX>(dim)).Run(iters); \
static void BM_##DEVICE##_gather_nd_##INDEX( \
::testing::benchmark::State& state) { \
const int dim = state.range(0); \
test::Benchmark(#DEVICE, GatherNd<INDEX>(dim), \
/*old_benchmark_api=*/false) \
.Run(state); \
const int64 tot = static_cast<int64>(state.iterations()) * kLookups * 4; \
state.SetItemsProcessed(tot); \
state.SetBytesProcessed(tot * sizeof(float)); \
} \
BENCHMARK(BM_##DEVICE##_gather_nd_##INDEX) \
->UseRealTime() \
->Arg(10) \
->Arg(100) \
->Arg(1000) \

View File

@ -223,14 +223,17 @@ static Graph* Gather(int dim) {
}
#define BM_GATHER(DEVICE, INDEX) \
static void BM_##DEVICE##_gather_##INDEX(int iters, int dim) { \
const int64 tot = static_cast<int64>(iters) * kLookups * dim; \
testing::ItemsProcessed(tot); \
testing::BytesProcessed(tot * sizeof(float)); \
testing::UseRealTime(); \
test::Benchmark(#DEVICE, Gather<INDEX>(dim)).Run(iters); \
static void BM_##DEVICE##_gather_##INDEX( \
::testing::benchmark::State& state) { \
const int dim = state.range(0); \
test::Benchmark(#DEVICE, Gather<INDEX>(dim), /*old_benchmark_api=*/false) \
.Run(state); \
const int64 tot = static_cast<int64>(state.iterations()) * kLookups * dim; \
state.SetItemsProcessed(tot); \
state.SetBytesProcessed(tot * sizeof(float)); \
} \
BENCHMARK(BM_##DEVICE##_gather_##INDEX) \
->UseRealTime() \
->Arg(1) \
->Arg(10) \
->Arg(20) \

View File

@ -66,12 +66,15 @@ static Graph* InTopK(int num_targets, int num_classes, T top_k) {
BM_InTopK##_##T##_##TARGETS##_##CLASSES##_##K##_##DEVICE
#define BM_InTopK(T, TARGETS, CLASSES, K, DEVICE) \
static void BM_NAME(T, TARGETS, CLASSES, K, DEVICE)(int iters) { \
testing::UseRealTime(); \
testing::ItemsProcessed(static_cast<int64>(iters) * TARGETS * CLASSES); \
test::Benchmark(#DEVICE, InTopK<T>(TARGETS, CLASSES, K)).Run(iters); \
static void BM_NAME(T, TARGETS, CLASSES, K, \
DEVICE)(::testing::benchmark::State & state) { \
test::Benchmark(#DEVICE, InTopK<T>(TARGETS, CLASSES, K), \
/*old_benchmark_api=*/false) \
.Run(state); \
state.SetItemsProcessed(static_cast<int64>(state.iterations()) * TARGETS * \
CLASSES); \
} \
BENCHMARK(BM_NAME(T, TARGETS, CLASSES, K, DEVICE));
BENCHMARK(BM_NAME(T, TARGETS, CLASSES, K, DEVICE))->UseRealTime();
BM_InTopK(int64, 64, 1000, 10, cpu);
BM_InTopK(int64, 64, 10000, 10, cpu);

View File

@ -30,9 +30,9 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/batch_matmul_op_impl.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/linalg/einsum_op.h"
#include "tensorflow/core/kernels/matmul_op_impl.h"
#include "tensorflow/core/kernels/reduction_ops_common.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/lib/core/errors.h"

View File

@ -199,7 +199,7 @@ TCASE(T3, 128, 4, 3, 2.0f, 1.0f, 1.0f)
#undef TCASE
static Graph* BM_LRNGrad(int batches, int rows, int cols, int depth,
static Graph* MakeRNGrad(int batches, int rows, int cols, int depth,
int depth_radius) {
Graph* g = new Graph(OpRegistry::Global());
Tensor grads(DT_FLOAT, TensorShape({batches, rows, cols, depth}));
@ -224,10 +224,13 @@ static Graph* BM_LRNGrad(int batches, int rows, int cols, int depth,
}
#define BM_LRNGradDev(DEVICE, B, R, C, D, DR) \
static void BM_LRNGrad_##DEVICE##_##B##_##R##_##C##_##D##_##DR(int iters) { \
testing::ItemsProcessed(static_cast<int64>(iters) * B * R * C * D * DR * \
4); \
test::Benchmark(#DEVICE, BM_LRNGrad(B, R, C, D, DR)).Run(iters); \
static void BM_LRNGrad_##DEVICE##_##B##_##R##_##C##_##D##_##DR( \
::testing::benchmark::State& state) { \
test::Benchmark(#DEVICE, MakeRNGrad(B, R, C, D, DR), \
/*old_benchmark_api*/ false) \
.Run(state); \
state.SetItemsProcessed(static_cast<int64>(state.iterations()) * B * R * \
C * D * DR * 4); \
} \
BENCHMARK(BM_LRNGrad_##DEVICE##_##B##_##R##_##C##_##D##_##DR)

View File

@ -1,567 +0,0 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/math_ops.cc.
#define EIGEN_USE_THREADS
#include "tensorflow/core/kernels/matmul_op.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/util/matmul_autotune.h"
#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#endif
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/gpu_utils.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
template <typename Device, typename T, bool USE_CUBLAS>
struct LaunchMatMul;
namespace {
// Converts a TensorFlow Tensor to an Eigen Matrix.
template <typename T>
Eigen::Map<
const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
ToEigenMatrix(const Tensor& tensor) {
auto matrix = tensor.matrix<T>();
return Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>::Map(
matrix.data(), matrix.dimension(0), matrix.dimension(1));
}
// Converts a TensorFlow Tensor to an Eigen Vector.
template <typename T>
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, 1>> ToEigenVector(Tensor* tensor) {
auto v = tensor->flat<T>();
return Eigen::Matrix<T, Eigen::Dynamic, 1>::Map(v.data(), v.dimension(0));
}
template <typename T>
Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, 1>> ToEigenVector(
const Tensor& tensor) {
auto v = tensor.flat<T>();
return Eigen::Matrix<T, Eigen::Dynamic, 1>::Map(v.data(), v.dimension(0));
}
} // namespace
// If either side can be represented as a vector, do an explicit vector
// matrix multiply and return true; else return false.
//
// Note: this uses plain Eigen and not Eigen Tensor because it is more
// efficient.
template <typename T>
bool ExplicitVectorMatrixOptimization(
const Tensor& a, const Tensor& b,
const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
Tensor* out) {
if (out->dim_size(0) == 1) {
if (dim_pair[0].second == 0) {
// Note: this case is optimized in Eigen Tensors.
return false;
} else {
auto out_v = ToEigenVector<T>(out);
auto a_v = ToEigenVector<T>(a);
auto b_m = ToEigenMatrix<T>(b);
out_v.noalias() = b_m * a_v;
}
return true;
} else if (out->dim_size(1) == 1) {
auto out_v = ToEigenVector<T>(out);
auto a_m = ToEigenMatrix<T>(a);
auto b_v = ToEigenVector<T>(b);
if (dim_pair[0].first == 0) {
out_v.noalias() = a_m.transpose() * b_v;
} else {
out_v.noalias() = a_m * b_v;
}
return true;
}
return false;
}
// Half is not supported.
template <>
bool ExplicitVectorMatrixOptimization<Eigen::half>(
const Tensor& a, const Tensor& b,
const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
Tensor* out) {
return false;
}
template <typename Device, typename T>
struct LaunchMatMulBase {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
typedef se::blas::AlgorithmType AlgorithmType;
#else
typedef int64 AlgorithmType;
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
static void launch(
OpKernelContext* ctx, const Tensor& a, const Tensor& b,
const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
std::vector<AlgorithmType>* algorithms, bool use_autotune, Tensor* out) {
// An explicit vector-matrix multiply is much better optimized than an
// implicit one and this is a bottleneck during non-batched inference.
bool was_vector = ExplicitVectorMatrixOptimization<T>(a, b, dim_pair, out);
if (!was_vector) {
functor::MatMulFunctor<Device, T>()(ctx->eigen_device<Device>(),
out->matrix<T>(), a.matrix<T>(),
b.matrix<T>(), dim_pair);
}
}
static void GetBlasGemmAlgorithm(OpKernelConstruction* ctx,
std::vector<int64>* algorithms,
bool* algorithm_set_flag) {}
};
// On CPUs, we ignore USE_CUBLAS
template <typename T>
struct LaunchMatMulCPU : LaunchMatMulBase<CPUDevice, T> {};
template <typename T, bool USE_CUBLAS>
struct LaunchMatMul<CPUDevice, T, USE_CUBLAS> : public LaunchMatMulCPU<T> {};
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace {
template <typename T>
struct LaunchBlasGemv {
static void Compute(OpKernelContext* ctx, se::Stream* stream, bool trans,
uint64 m, uint64 n, const se::DeviceMemory<T>& a,
const se::DeviceMemory<T>& b, se::DeviceMemory<T>* c,
se::blas::ProfileResult* output_profile) {
const auto blas_trans = trans ? se::blas::Transpose::kTranspose
: se::blas::Transpose::kNoTranspose;
if (output_profile == nullptr) {
bool blas_launch_status =
stream
->ThenBlasGemv(blas_trans, m, n, static_cast<T>(1.0), a, m, b, 1,
static_cast<T>(0.0), c, 1)
.ok();
if (!blas_launch_status) {
ctx->SetStatus(
errors::Internal("Blas GEMV launch failed: m=", m, ", n=", n));
}
} else {
bool blas_launch_status =
stream
->ThenBlasGemvWithProfiling(blas_trans, m, n, static_cast<T>(1.0),
a, m, b, 1, static_cast<T>(0.0), c, 1,
output_profile)
.ok();
if (!blas_launch_status) {
ctx->SetStatus(errors::Internal(
"Blas GEMV with profiling launch failed: m=", m, ", n=", n));
}
}
}
static bool IsSupported() { return true; }
};
template <>
void LaunchBlasGemv<Eigen::half>::Compute(
OpKernelContext* ctx, se::Stream* stream, bool trans, uint64 m, uint64 n,
const se::DeviceMemory<Eigen::half>& a,
const se::DeviceMemory<Eigen::half>& b, se::DeviceMemory<Eigen::half>* c,
se::blas::ProfileResult* output_profile) {
ctx->SetStatus(errors::Internal(
"Blas GEMV launch failed: GEMV is not implemented for float16."));
}
template <>
bool LaunchBlasGemv<Eigen::half>::IsSupported() {
return false;
}
template <typename T>
bool ShouldUseGemv(uint64 n) {
return (LaunchBlasGemv<T>::IsSupported() && n == 1);
}
} // namespace
bool GetCublasAutotuneComputationType(const DataType& dtype,
se::blas::ComputationType* compute_type) {
using se::blas::ComputationType;
switch (dtype) {
case DT_HALF:
case DT_BFLOAT16:
static bool use_f32_for_f16_computation =
MatmulDoFP32ComputationFP16Input();
if (use_f32_for_f16_computation) {
*compute_type = ComputationType::kF32;
} else {
*compute_type = ComputationType::kF16;
}
return false;
case DT_FLOAT:
*compute_type = ComputationType::kF32;
return true;
case DT_DOUBLE:
*compute_type = ComputationType::kF64;
return true;
default:
// Unsupported compute_type, return false.
return false;
}
}
// A dummy type to group matmul autotune results together.
struct MatmulAutoTuneGroup {
static string name() { return "Matmul"; }
};
typedef AutoTuneSingleton<MatmulAutoTuneGroup, MatmulParameters,
se::blas::AlgorithmConfig>
AutoTuneMatmul;
template <typename T>
struct LaunchMatMul<GPUDevice, T, true /* USE_CUBLAS */> {
static void launch(
OpKernelContext* ctx, const Tensor& a, const Tensor& b,
const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
std::vector<int64>* algorithms, bool use_autotune, Tensor* out) {
using se::blas::AlgorithmConfig;
using se::blas::ComputationType;
using se::blas::kDefaultAlgorithm;
using se::blas::kDefaultBlasGemm;
using se::blas::kDefaultBlasGemv;
using se::blas::kNoAlgorithm;
using se::blas::ProfileResult;
using se::blas::Transpose;
Transpose trans[] = {Transpose::kNoTranspose, Transpose::kTranspose};
const uint64 m = a.dim_size(1 - dim_pair[0].first);
const uint64 k = a.dim_size(dim_pair[0].first);
const uint64 n = b.dim_size(1 - dim_pair[0].second);
bool transpose_a = dim_pair[0].first == 0;
bool transpose_b = dim_pair[0].second == 1;
auto blas_transpose_a = trans[transpose_a];
auto blas_transpose_b = trans[transpose_b];
auto* stream = ctx->op_device_context()->stream();
OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));
auto a_ptr = AsDeviceMemory(a.template flat<T>().data(),
a.template flat<T>().size());
auto b_ptr = AsDeviceMemory(b.template flat<T>().data(),
b.template flat<T>().size());
auto c_ptr = AsDeviceMemory(out->template flat<T>().data(),
out->template flat<T>().size());
auto alpha = static_cast<T>(1.0);
auto beta = static_cast<T>(0.0);
int device_id = stream->parent()->device_ordinal();
DataType dtype = a.dtype();
MatmulParameters matmul_parameters = {
transpose_a, transpose_b, m, n, k, dtype, device_id,
};
AlgorithmConfig algorithm_config(kNoAlgorithm);
ComputationType computation_type;
bool compute_type_supported =
GetCublasAutotuneComputationType(dtype, &computation_type);
if (use_autotune && compute_type_supported && !algorithms->empty()) {
ProfileResult best_result;
// TODO(yangzihao): Unify this code with conv autotuning.
if (!AutoTuneMatmul::GetInstance()->Find(matmul_parameters,
&algorithm_config)) {
ProfileResult profile_result;
for (auto profile_algorithm : (*algorithms)) {
// Cublas does
// C = A x B
// where A, B and C are assumed to be in column major.
// We want the output to be in row-major, so we can compute
// C' = B' x A' (' stands for transpose)
bool cublas_launch_status =
stream
->ThenBlasGemmWithAlgorithm(
blas_transpose_b, blas_transpose_a, n, m, k, alpha, b_ptr,
transpose_b ? k : n, a_ptr, transpose_a ? m : k, beta,
&c_ptr, n, computation_type, profile_algorithm,
&profile_result)
.ok();
if (cublas_launch_status) {
if (profile_result.is_valid()) {
if (profile_result.elapsed_time_in_ms() <
best_result.elapsed_time_in_ms()) {
best_result = profile_result;
}
}
}
}
// Try BlasGemmWithProfiling
bool cublas_launch_status =
stream
->ThenBlasGemmWithProfiling(
blas_transpose_b, blas_transpose_a, n, m, k, 1.0, b_ptr,
transpose_b ? k : n, a_ptr, transpose_a ? m : k, 0.0,
&c_ptr, n, &profile_result)
.ok();
if (cublas_launch_status) {
if (profile_result.is_valid()) {
if (profile_result.elapsed_time_in_ms() <
best_result.elapsed_time_in_ms()) {
best_result = profile_result;
}
}
}
// Try BlasGemvWithProfiling
if (ShouldUseGemv<T>(n)) {
LaunchBlasGemv<T>::Compute(ctx, stream, !transpose_a,
transpose_a ? m : k, transpose_a ? k : m,
a_ptr, b_ptr, &c_ptr, &profile_result);
if (profile_result.is_valid()) {
if (profile_result.elapsed_time_in_ms() <
best_result.elapsed_time_in_ms()) {
best_result = profile_result;
}
}
}
}
// We make sure that each matmul parameter set only gets one pass of
// autotune. If the best result is found, assign it to algorithm_type
// and insert it to autotune map. If all internal kernels of
// cublasGemmEx() returns invalid results, we add kNoAlgorithm to the
// autotune map.
if (best_result.is_valid()) {
algorithm_config.set_algorithm(best_result.algorithm());
}
AutoTuneMatmul::GetInstance()->Insert(matmul_parameters,
algorithm_config);
if (algorithm_config.algorithm() != kNoAlgorithm &&
algorithm_config.algorithm() != kDefaultBlasGemm &&
algorithm_config.algorithm() != kDefaultBlasGemv) {
bool cublas_launch_status =
stream
->ThenBlasGemmWithAlgorithm(
blas_transpose_b, blas_transpose_a, n, m, k, alpha, b_ptr,
transpose_b ? k : n, a_ptr, transpose_a ? m : k, beta,
&c_ptr, n, computation_type, algorithm_config.algorithm(),
nullptr)
.ok();
if (!cublas_launch_status) {
ctx->SetStatus(errors::Internal(
"Blas GEMM with algorithm launch failed : a.shape=(",
a.dim_size(0), ", ", a.dim_size(1), "), b.shape=(", b.dim_size(0),
", ", b.dim_size(1), "), m=", m, ", n=", n, ", k=", k));
}
}
}
// For the following case, we use normal BlasGemm():
// 1) We didn't set the use_autotune flag;
// 2) compute type does not support autotune;
// 3) no algorithm is found;
// 4) all internal kernels in autotune return invalid results.
// For the following case, we use normal BlasGemv():
// 1) We didn't set the use_autotune flag but LaunchBlasGemv is supported
// and n == 1.
// 2) We set the use_autotune flag and it picked up BlasGemv() and set the
// algorithm_config.algorithm() to be kDefaultBlasGemv.
if (!use_autotune || !compute_type_supported || algorithms->empty() ||
algorithm_config.algorithm() == kNoAlgorithm ||
algorithm_config.algorithm() == kDefaultBlasGemm ||
algorithm_config.algorithm() == kDefaultBlasGemv) {
if (algorithm_config.algorithm() == kDefaultBlasGemv ||
ShouldUseGemv<T>(n)) {
// This is a matrix*vector multiply so use GEMV to compute A * b.
// Here we are multiplying in the natural order, so we have to flip
// the transposition flag to compensate for the tensor being stored
// row-major.
// TODO(yangzihao): Add Gemv as an autotuning option too.
LaunchBlasGemv<T>::Compute(ctx, stream, !transpose_a,
transpose_a ? m : k, transpose_a ? k : m,
a_ptr, b_ptr, &c_ptr, nullptr);
} else {
// Use C' = B' x A' (' stands for transpose)
bool blas_launch_status =
stream
->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
1.0f, b_ptr, transpose_b ? k : n, a_ptr,
transpose_a ? m : k, 0.0f, &c_ptr, n)
.ok();
if (!blas_launch_status) {
ctx->SetStatus(errors::Internal(
"Blas GEMM launch failed : a.shape=(", a.dim_size(0), ", ",
a.dim_size(1), "), b.shape=(", b.dim_size(0), ", ", b.dim_size(1),
"), m=", m, ", n=", n, ", k=", k));
}
}
}
}
static void GetBlasGemmAlgorithm(OpKernelConstruction* ctx,
std::vector<int64>* algorithms,
bool* algorithm_set_flag) {
if (*algorithm_set_flag == false) {
auto* stream = ctx->device()->tensorflow_gpu_device_info()->stream;
stream->parent()->GetBlasGemmAlgorithms(algorithms);
*algorithm_set_flag = true;
}
}
};
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <typename Device, typename T, bool USE_CUBLAS>
class MatMulOp : public OpKernel {
public:
explicit MatMulOp(OpKernelConstruction* ctx)
: OpKernel(ctx), algorithms_set_already_(false) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));
LaunchMatMul<Device, T, USE_CUBLAS>::GetBlasGemmAlgorithm(
ctx, &algorithms_, &algorithms_set_already_);
use_autotune_ = MatmulAutotuneEnable();
}
void Compute(OpKernelContext* ctx) override {
const Tensor& a = ctx->input(0);
const Tensor& b = ctx->input(1);
// Check that the dimensions of the two matrices are valid.
OP_REQUIRES(
ctx, TensorShapeUtils::IsMatrix(a.shape()),
errors::InvalidArgument("In[0] is not a matrix. Instead it has shape ",
a.shape().DebugString()));
OP_REQUIRES(
ctx, TensorShapeUtils::IsMatrix(b.shape()),
errors::InvalidArgument("In[1] is not a matrix. Instead it has shape ",
b.shape().DebugString()));
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
dim_pair[0].first = transpose_a_ ? 0 : 1;
dim_pair[0].second = transpose_b_ ? 1 : 0;
OP_REQUIRES(
ctx, a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
errors::InvalidArgument(
"Matrix size-incompatible: In[0]: ", a.shape().DebugString(),
", In[1]: ", b.shape().DebugString()));
int a_dim_remaining = 1 - dim_pair[0].first;
int b_dim_remaining = 1 - dim_pair[0].second;
TensorShape out_shape(
{a.dim_size(a_dim_remaining), b.dim_size(b_dim_remaining)});
Tensor* out = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
if (out->NumElements() == 0) {
// If a has shape [0, x] or b has shape [x, 0], the output shape
// is a 0-element matrix, so there is nothing to do.
return;
}
if (a.NumElements() == 0 && b.NumElements() == 0) {
// If a has shape [x, 0] and b has shape [0, y], the
// output shape is [x, y] where x and y are non-zero, so we fill
// the output with zeros.
functor::SetZeroFunctor<Device, T> f;
f(ctx->eigen_device<Device>(), out->flat<T>());
return;
}
if (std::is_same<T, bfloat16>::value) {
bool is_cpu = std::is_same<Device, CPUDevice>::value;
OP_REQUIRES(ctx, is_cpu,
errors::Internal("bfloat16 matmul is not supported by GPU"));
Tensor a_float, b_float, out_float;
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, a.shape(), &a_float));
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, b.shape(), &b_float));
OP_REQUIRES_OK(ctx,
ctx->allocate_temp(DT_FLOAT, out->shape(), &out_float));
// TODO: Avoid extra copy to make bfloat16 matmul efficient on CPU.
BFloat16ToFloat(a.flat<bfloat16>().data(), a_float.flat<float>().data(),
a.NumElements());
BFloat16ToFloat(b.flat<bfloat16>().data(), b_float.flat<float>().data(),
b.NumElements());
LaunchMatMul<Device, float, USE_CUBLAS>::launch(
ctx, a_float, b_float, dim_pair, &algorithms_, use_autotune_,
&out_float);
FloatToBFloat16(out_float.flat<float>().data(),
out->flat<bfloat16>().data(), out->NumElements());
} else {
LaunchMatMul<Device, T, USE_CUBLAS>::launch(
ctx, a, b, dim_pair, &algorithms_, use_autotune_, out);
}
}
private:
std::vector<int64> algorithms_;
bool algorithms_set_already_;
bool use_autotune_;
bool transpose_a_;
bool transpose_b_;
};
namespace functor {
// Partial specialization MatMulFunctor<Device=CPUDevice, T>.
template <typename T>
struct MatMulFunctor<CPUDevice, T> {
void operator()(
const CPUDevice& d, typename MatMulTypes<T>::out_type out,
typename MatMulTypes<T>::in_type in0,
typename MatMulTypes<T>::in_type in1,
const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair) {
MatMul<CPUDevice>(d, out, in0, in1, dim_pair);
}
};
} // end namespace functor
#define REGISTER_CPU_EIGEN(T) \
REGISTER_KERNEL_BUILDER( \
Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T").Label("eigen"), \
MatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>);
#define REGISTER_CPU(T) \
REGISTER_KERNEL_BUILDER( \
Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
MatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>); \
REGISTER_CPU_EIGEN(T);
#define REGISTER_GPU(T) \
REGISTER_KERNEL_BUILDER( \
Name("MatMul").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
MatMulOp<GPUDevice, T, true /* cublas, true by default */>); \
REGISTER_KERNEL_BUILDER(Name("MatMul") \
.Device(DEVICE_GPU) \
.TypeConstraint<T>("T") \
.Label("cublas"), \
MatMulOp<GPUDevice, T, true /* cublas */>)
TF_CALL_int32(REGISTER_CPU);
TF_CALL_int64(REGISTER_CPU);
TF_CALL_FLOAT_TYPES(REGISTER_CPU);
TF_CALL_COMPLEX_TYPES(REGISTER_CPU);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_COMPLEX_TYPES(REGISTER_GPU);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow
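The deleted kernel above leans on the layout identity called out in its comments ("We want the output to be in row-major, so we can compute C' = B' x A'"): the bytes of a row-major m x n matrix are exactly the bytes of its column-major n x m transpose, so asking a column-major GEMM for B' x A' writes row-major C = A x B in place. A small self-contained Eigen check of that identity, included only as an illustration (shapes and types are arbitrary):

#include <Eigen/Dense>

#include <iostream>

int main() {
  using RowMajor =
      Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using ColMajor = Eigen::MatrixXf;  // Eigen's default storage is column-major

  RowMajor A = RowMajor::Random(2, 3);
  RowMajor B = RowMajor::Random(3, 4);
  RowMajor C = A * B;  // reference result, row-major

  // Reinterpreting the row-major buffers as column-major yields the transposes.
  Eigen::Map<ColMajor> At(A.data(), 3, 2);  // views A's storage as A' (3 x 2)
  Eigen::Map<ColMajor> Bt(B.data(), 4, 3);  // views B's storage as B' (4 x 3)
  ColMajor Ct = Bt * At;                    // B' x A' = (A x B)' = C'

  // C's row-major storage and Ct's column-major storage match element for element.
  Eigen::Map<RowMajor> CFromCt(Ct.data(), 2, 4);
  std::cout << (C - CFromCt).norm() << "\n";  // prints ~0
  return 0;
}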

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/batch_matmul_op_impl.h"
#include "tensorflow/core/kernels/matmul_op_impl.h"
namespace tensorflow {

View File

@ -15,8 +15,8 @@ limitations under the License.
// See docs in ../ops/math_ops.cc.
#ifndef TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_
#ifndef TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_
#define EIGEN_USE_THREADS
@ -633,10 +633,21 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
template <typename Device, typename Scalar>
class BaseBatchMatMulOp : public OpKernel {
public:
explicit BaseBatchMatMulOp(OpKernelConstruction* context)
explicit BaseBatchMatMulOp(OpKernelConstruction* context,
bool is_legacy_matmul)
: OpKernel(context) {
if (is_legacy_matmul) {
// The old MatMul kernel has "transpose_a/transpose_b" attributes.
OP_REQUIRES_OK(context, context->GetAttr("transpose_a", &trans_x_));
OP_REQUIRES_OK(context, context->GetAttr("transpose_b", &trans_y_));
adj_x_ = false;
adj_y_ = false;
} else {
OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_));
trans_x_ = false;
trans_y_ = false;
}
}
~BaseBatchMatMulOp() override {}
@ -672,8 +683,8 @@ class BaseBatchMatMulOp : public OpKernel {
in1_reshaped.CopyFrom(in1, TensorShape({bcast.y_batch_size(), d2, d3})),
errors::Internal("Failed to reshape In[1] from ",
in1.shape().DebugString()));
if (adj_x_) std::swap(d0, d1);
if (adj_y_) std::swap(d2, d3);
if (adj_x_ || trans_x_) std::swap(d0, d1);
if (adj_y_ || trans_y_) std::swap(d2, d3);
OP_REQUIRES(ctx, d1 == d2,
errors::InvalidArgument(
"In[0] mismatch In[1] shape: ", d1, " vs. ", d2, ": ",
@ -696,9 +707,36 @@ class BaseBatchMatMulOp : public OpKernel {
out_reshaped.CopyFrom(*out, TensorShape({batch_size, d0, d3})),
errors::Internal("Failed to reshape output from ",
out->shape().DebugString()));
LaunchBatchMatMul<Device, Scalar>::Launch(
ctx, in0_reshaped, in1_reshaped, adj_x_, adj_y_, /*trans_x=*/false,
/*trans_y=*/false, bcast, &out_reshaped);
if (std::is_same<Scalar, bfloat16>::value) {
bool is_cpu = std::is_same<Device, CPUDevice>::value;
OP_REQUIRES(ctx, is_cpu,
errors::Internal("bfloat16 matmul is not supported by GPU"));
Tensor in0_reshaped_float, in1_reshaped_float, out_reshaped_float;
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, in0_reshaped.shape(),
&in0_reshaped_float));
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, in1_reshaped.shape(),
&in1_reshaped_float));
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, out_reshaped.shape(),
&out_reshaped_float));
// TODO: Avoid extra copy to make bfloat16 matmul efficient on CPU.
BFloat16ToFloat(in0_reshaped.flat<bfloat16>().data(),
in0_reshaped_float.flat<float>().data(),
in0_reshaped.NumElements());
BFloat16ToFloat(in1_reshaped.flat<bfloat16>().data(),
in1_reshaped_float.flat<float>().data(),
in1_reshaped.NumElements());
LaunchBatchMatMul<Device, float>::Launch(
ctx, in0_reshaped_float, in1_reshaped_float, adj_x_, adj_y_, trans_x_,
trans_y_, bcast, &out_reshaped_float);
FloatToBFloat16(out_reshaped_float.flat<float>().data(),
out_reshaped.flat<bfloat16>().data(), out->NumElements());
} else {
LaunchBatchMatMul<Device, Scalar>::Launch(ctx, in0_reshaped, in1_reshaped,
adj_x_, adj_y_, trans_x_,
trans_y_, bcast, &out_reshaped);
}
}
protected:
@ -706,16 +744,19 @@ class BaseBatchMatMulOp : public OpKernel {
const Tensor& in1) = 0;
private:
// TODO(171979567) Make the ops take both adj and transpose attributes.
bool adj_x_;
bool adj_y_;
bool trans_x_;
bool trans_y_;
};
// BatchMatMul Op implementation which disallows broadcasting.
template <typename Device, typename Scalar>
template <typename Device, typename Scalar, bool is_legacy_matmul = false>
class BatchMatMulOp : public BaseBatchMatMulOp<Device, Scalar> {
public:
explicit BatchMatMulOp(OpKernelConstruction* context)
: BaseBatchMatMulOp<Device, Scalar>(context) {}
: BaseBatchMatMulOp<Device, Scalar>(context, is_legacy_matmul) {}
~BatchMatMulOp() override {}
@ -729,15 +770,21 @@ class BatchMatMulOp : public BaseBatchMatMulOp<Device, Scalar> {
in0.shape().DebugString(), " vs. ",
in1.shape().DebugString()));
const int ndims = in0.dims();
OP_REQUIRES(
ctx, ndims >= 2,
errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ", ndims));
if (is_legacy_matmul) {
OP_REQUIRES(ctx, ndims == 2,
errors::InvalidArgument(
"In[0] and In[1] ndims must be == 2: ", ndims));
} else {
OP_REQUIRES(ctx, ndims >= 2,
errors::InvalidArgument(
"In[0] and In[1] ndims must be >= 2: ", ndims));
for (int i = 0; i < ndims - 2; ++i) {
OP_REQUIRES(ctx, in0.dim_size(i) == in1.dim_size(i),
errors::InvalidArgument(
"In[0].dim(", i, ") and In[1].dim(", i,
") must be the same: ", in0.shape().DebugString(), " vs ",
in1.shape().DebugString()));
") must be the same: ", in0.shape().DebugString(),
" vs ", in1.shape().DebugString()));
}
}
}
};
@ -747,7 +794,8 @@ template <typename Device, typename Scalar>
class BatchMatMulV2Op : public BaseBatchMatMulOp<Device, Scalar> {
public:
explicit BatchMatMulV2Op(OpKernelConstruction* context)
: BaseBatchMatMulOp<Device, Scalar>(context) {}
: BaseBatchMatMulOp<Device, Scalar>(context,
/* is_legacy_matmul= */ false) {}
~BatchMatMulV2Op() override {}
@ -771,7 +819,10 @@ class BatchMatMulV2Op : public BaseBatchMatMulOp<Device, Scalar> {
BatchMatMulOp<CPUDevice, TYPE>); \
REGISTER_KERNEL_BUILDER( \
Name("BatchMatMulV2").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
BatchMatMulV2Op<CPUDevice, TYPE>)
BatchMatMulV2Op<CPUDevice, TYPE>); \
REGISTER_KERNEL_BUILDER( \
Name("MatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
BatchMatMulOp<CPUDevice, TYPE, /* is_legacy_matmul=*/true>)
#define REGISTER_BATCH_MATMUL_GPU(TYPE) \
REGISTER_KERNEL_BUILDER( \
@ -779,8 +830,11 @@ class BatchMatMulV2Op : public BaseBatchMatMulOp<Device, Scalar> {
BatchMatMulOp<GPUDevice, TYPE>); \
REGISTER_KERNEL_BUILDER( \
Name("BatchMatMulV2").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
BatchMatMulV2Op<GPUDevice, TYPE>)
BatchMatMulV2Op<GPUDevice, TYPE>); \
REGISTER_KERNEL_BUILDER( \
Name("MatMul").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
BatchMatMulOp<GPUDevice, TYPE, /* is_legacy_matmul=*/true>)
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_
#endif // TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_
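The bfloat16 branch added to BaseBatchMatMulOp::Compute above follows the same widen-compute-narrow pattern as the deleted MatMulOp: copy the inputs into float temporaries, run the existing float kernel, then cast the result back. A minimal sketch of that pattern (assuming an Eigen version that provides Eigen::bfloat16; the helper name is illustrative and not part of the TF API):

#include <Eigen/Dense>

using BF16Matrix = Eigen::Matrix<Eigen::bfloat16, Eigen::Dynamic,
                                 Eigen::Dynamic, Eigen::RowMajor>;
using FloatMatrix =
    Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;

// Widen to float, multiply, narrow back. The extra copies are the cost flagged
// by the TODO in the hunk above; doing the GEMM in float reuses the existing
// float kernels and keeps accumulation in higher precision.
BF16Matrix Bf16MatMulViaFloat(const BF16Matrix& a, const BF16Matrix& b) {
  FloatMatrix af = a.cast<float>();
  FloatMatrix bf = b.cast<float>();
  FloatMatrix cf = af * bf;
  return cf.cast<Eigen::bfloat16>();
}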

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/batch_matmul_op_impl.h"
#include "tensorflow/core/kernels/matmul_op_impl.h"
#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
@ -21,17 +21,13 @@ limitations under the License.
namespace tensorflow {
TF_CALL_float(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_double(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_half(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_FLOAT_TYPES(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_int16(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_int32(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_int64(REGISTER_BATCH_MATMUL_CPU);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
TF_CALL_float(REGISTER_BATCH_MATMUL_GPU);
TF_CALL_double(REGISTER_BATCH_MATMUL_GPU);
TF_CALL_half(REGISTER_BATCH_MATMUL_GPU);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_BATCH_MATMUL_GPU);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/public/session.h"
namespace tensorflow {
namespace {
template <typename T>
class FusedMatMulOpTest : public OpsTestBase {
@ -459,4 +460,230 @@ BM_Matmul(2000, 1, 2000, true, false);
BM_Matmul(2000, 1, 2000, false, true);
BM_Matmul(2000, 1, 2000, true, true);
} // end namespace tensorflow
// Benchmarks for batched matmul with broadcasting.
Node* BroadcastTo(Graph* g, Node* input, Node* shape) {
Node* ret;
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BroadcastTo")
.Input(input)
.Input(shape)
.Finalize(g, &ret));
return ret;
}
Node* BatchMatmulV2(Graph* g, Node* in0, Node* in1, bool adj_x, bool adj_y) {
Node* ret;
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BatchMatMulV2")
.Input(in0)
.Input(in1)
.Attr("adj_x", adj_x)
.Attr("adj_y", adj_y)
.Finalize(g, &ret));
return ret;
}
template <typename T>
static Graph* BatchMatmul(int b, int m, int k, int n, bool adjoint_a,
bool adjoint_b, DataType type) {
Graph* g = new Graph(OpRegistry::Global());
Tensor in0(type, adjoint_a ? TensorShape({b, k, m}) : TensorShape({b, m, k}));
in0.flat<T>().setRandom();
Tensor in1(type, adjoint_b ? TensorShape({b, n, k}) : TensorShape({b, k, n}));
in1.flat<T>().setRandom();
test::graph::BatchMatmul(g, test::graph::Constant(g, in0),
test::graph::Constant(g, in1), adjoint_a, adjoint_b);
return g;
}
template <typename T>
static Graph* BatchMatmulWithBroadcast(int b0, int b1, int m, int k, int n,
bool manual_broadcast, DataType type) {
Graph* g = new Graph(OpRegistry::Global());
Tensor in0(type, TensorShape({b0, m, k}));
in0.flat<T>().setRandom();
Tensor in1(type, TensorShape({b1, k, n}));
in1.flat<T>().setRandom();
Tensor broadcasted_in0_shape(DT_INT64, TensorShape({3}));
Tensor broadcasted_in1_shape(DT_INT64, TensorShape({3}));
Node* in0_node = nullptr;
Node* in1_node = nullptr;
if (manual_broadcast) {
for (int i = 0; i < 3; ++i) {
auto vec0 = broadcasted_in0_shape.vec<int64>();
auto vec1 = broadcasted_in1_shape.vec<int64>();
vec0(i) = (i == 0 ? std::max(b0, b1) : in0.shape().dim_size(i));
vec1(i) = (i == 0 ? std::max(b0, b1) : in1.shape().dim_size(i));
}
in0_node = BroadcastTo(g, test::graph::Constant(g, in0),
test::graph::Constant(g, broadcasted_in0_shape));
in1_node = BroadcastTo(g, test::graph::Constant(g, in1),
test::graph::Constant(g, broadcasted_in1_shape));
} else {
in0_node = test::graph::Constant(g, in0);
in1_node = test::graph::Constant(g, in1);
}
BatchMatmulV2(g, in0_node, in1_node, false, false);
return g;
}
#define BM_BatchMatmulDev(B, M, K, N, TA, TB, T, TFTYPE, DEVICE) \
static void \
BM_BatchMatmul##_##B##_##M##_##K##_##N##_##TA##_##TB##_##TFTYPE##_##DEVICE( \
int iters) { \
testing::UseRealTime(); \
testing::ItemsProcessed(static_cast<int64>(iters) * B * M * K * N * 2); \
test::Benchmark(#DEVICE, BatchMatmul<T>(B, M, K, N, TA, TB, TFTYPE)) \
.Run(iters); \
} \
BENCHMARK( \
BM_BatchMatmul##_##B##_##M##_##K##_##N##_##TA##_##TB##_##TFTYPE##_##DEVICE);
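The ItemsProcessed expression in the macro above counts roughly two floating-point operations (a multiply and an add) per inner-product step, i.e. B * M * K * N * 2 items per iteration. As a worked instance for the BM_BatchMatmul(32, 1024, 1024, 1024, ...) case registered below:

// 32 * 1024 * 1024 * 1024 * 2 = 68,719,476,736 items per iteration, so a
// reported rate of ~68.7 G items/s corresponds to one iteration per second
// on this shape.
constexpr long long kItemsPerIter = 32LL * 1024 * 1024 * 1024 * 2;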
#define BM_BatchMatmul(B, M, K, N, TA, TB) \
BM_BatchMatmulDev(B, M, K, N, TA, TB, float, DT_FLOAT, cpu);
// BM_BatchMatmulDev(B, M, K, N, TA, TB, std::complex<float>, DT_COMPLEX64,
// cpu);
// BM_BatchMatmulDev(B, M, K, N, TA, TB, float, DT_FLOAT, gpu);
/* Uncomment to enable benchmarks for double & complex types: */
// BM_BatchMatmulDev(B, M, K, N, TA, TB, std::complex<float>, DT_COMPLEX64,
// gpu);
// BM_BatchMatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, cpu); \
// BM_BatchMatmulDev(M, K, N, TA, TB, std::complex<double>, DT_COMPLEX128, cpu);
// \
// BM_BatchMatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, gpu); \
// BM_BatchMatmulDev(M, K, N, TA, TB, std::complex<double>, DT_COMPLEX128, gpu);
// Macro argument names: ---------------------------------------------------- //
// B1: batch size of LHS
// B2: batch size of RHS
// M: outer dimension of LHS
// K: inner dimensions of LHS and RHS
// N: outer dimension of RHS
// MB: boolean indicating whether to use manual broadcasting
// T: C++ type of scalars (e.g. float, std::complex)
// TT: TensorFlow type of scalars (e.g. DT_FLOAT, DT_COMPLEX128)
// D: Device (e.g. cpu, gpu)
#define BM_BatchMatmulBCastDev(B1, B2, M, K, N, MB, T, TT, D) \
static void \
BM_BatchMatmulBCast##_##B1##_##B2##_##M##_##K##_##N##_##MB##_##TT##_##D( \
int iters) { \
testing::UseRealTime(); \
testing::ItemsProcessed(static_cast<int64>(iters) * std::max(B1, B2) * M * \
K * N * 2); \
test::Benchmark(#D, BatchMatmulWithBroadcast<T>(B1, B2, M, K, N, MB, TT)) \
.Run(iters); \
} \
BENCHMARK( \
BM_BatchMatmulBCast##_##B1##_##B2##_##M##_##K##_##N##_##MB##_##TT##_##D);
#define BM_BatchMatmulBCast(B1, B2, M, K, N, MB) \
BM_BatchMatmulBCastDev(B1, B2, M, K, N, MB, float, DT_FLOAT, cpu);
// Typical fully connected layers
BM_BatchMatmulBCast(1, 128, 1, 1024, 1024, true);
BM_BatchMatmulBCast(1, 128, 1, 1024, 1024, false);
BM_BatchMatmulBCast(128, 1, 1, 1024, 1024, true);
BM_BatchMatmulBCast(128, 1, 1, 1024, 1024, false);
BM_BatchMatmulBCast(1, 128, 128, 1024, 1024, true);
BM_BatchMatmulBCast(1, 128, 128, 1024, 1024, false);
BM_BatchMatmulBCast(128, 1, 128, 1024, 1024, true);
BM_BatchMatmulBCast(128, 1, 128, 1024, 1024, false);
// Square matmul.
BM_BatchMatmulBCast(1, 128, 512, 512, 512, true);
BM_BatchMatmulBCast(1, 128, 512, 512, 512, false);
BM_BatchMatmulBCast(128, 1, 512, 512, 512, true);
BM_BatchMatmulBCast(128, 1, 512, 512, 512, false);
BM_BatchMatmulBCast(1, 128, 1024, 1024, 1024, true);
BM_BatchMatmulBCast(1, 128, 1024, 1024, 1024, false);
BM_BatchMatmulBCast(128, 1, 1024, 1024, 1024, true);
BM_BatchMatmulBCast(128, 1, 1024, 1024, 1024, false);
// Matrix-vector multiplies.
BM_BatchMatmulBCast(1, 128, 10000, 200, 1, true);
BM_BatchMatmulBCast(1, 128, 10000, 200, 1, false);
BM_BatchMatmulBCast(128, 1, 10000, 200, 1, true);
BM_BatchMatmulBCast(128, 1, 10000, 200, 1, false);
// Vector-matrix multiplies.
BM_BatchMatmulBCast(1, 128, 1, 200, 10000, true);
BM_BatchMatmulBCast(1, 128, 1, 200, 10000, false);
BM_BatchMatmulBCast(128, 1, 1, 200, 10000, true);
BM_BatchMatmulBCast(128, 1, 1, 200, 10000, false);
// Typical fully connected layers
BM_BatchMatmul(1, 1, 1024, 1024, false, false);
BM_BatchMatmul(1, 8, 1024, 1024, false, false);
BM_BatchMatmul(1, 16, 1024, 1024, false, false);
BM_BatchMatmul(1, 128, 1024, 1024, false, false);
BM_BatchMatmul(2, 1, 1024, 1024, false, false);
BM_BatchMatmul(2, 8, 1024, 1024, false, false);
BM_BatchMatmul(2, 16, 1024, 1024, false, false);
BM_BatchMatmul(2, 128, 1024, 1024, false, false);
BM_BatchMatmul(8, 1, 1024, 1024, false, false);
BM_BatchMatmul(8, 8, 1024, 1024, false, false);
BM_BatchMatmul(8, 16, 1024, 1024, false, false);
BM_BatchMatmul(8, 128, 1024, 1024, false, false);
BM_BatchMatmul(32, 1, 1024, 1024, false, false);
BM_BatchMatmul(32, 8, 1024, 1024, false, false);
BM_BatchMatmul(32, 16, 1024, 1024, false, false);
BM_BatchMatmul(32, 128, 1024, 1024, false, false);
// Square matmul.
BM_BatchMatmul(1, 32, 32, 32, false, false);
BM_BatchMatmul(1, 128, 128, 128, false, false);
BM_BatchMatmul(1, 256, 256, 256, false, false);
BM_BatchMatmul(1, 1024, 1024, 1024, false, false);
BM_BatchMatmul(1, 2048, 2048, 2048, false, false);
BM_BatchMatmul(2, 32, 32, 32, false, false);
BM_BatchMatmul(2, 128, 128, 128, false, false);
BM_BatchMatmul(2, 256, 256, 256, false, false);
BM_BatchMatmul(2, 1024, 1024, 1024, false, false);
BM_BatchMatmul(2, 2048, 2048, 2048, false, false);
BM_BatchMatmul(4, 32, 32, 32, false, false);
BM_BatchMatmul(4, 128, 128, 128, false, false);
BM_BatchMatmul(4, 256, 256, 256, false, false);
BM_BatchMatmul(4, 1024, 1024, 1024, false, false);
BM_BatchMatmul(4, 2048, 2048, 2048, false, false);
BM_BatchMatmul(8, 32, 32, 32, false, false);
BM_BatchMatmul(8, 128, 128, 128, false, false);
BM_BatchMatmul(8, 256, 256, 256, false, false);
BM_BatchMatmul(8, 1024, 1024, 1024, false, false);
BM_BatchMatmul(8, 2048, 2048, 2048, false, false);
BM_BatchMatmul(32, 32, 32, 32, false, false);
BM_BatchMatmul(32, 128, 128, 128, false, false);
BM_BatchMatmul(32, 256, 256, 256, false, false);
BM_BatchMatmul(32, 1024, 1024, 1024, false, false);
BM_BatchMatmul(32, 2048, 2048, 2048, false, false);
// Matrix-vector multiplies.
BM_BatchMatmul(1, 10000, 200, 1, false, false);
BM_BatchMatmul(8, 10000, 200, 1, false, false);
BM_BatchMatmul(32, 10000, 200, 1, false, false);
BM_BatchMatmul(1, 10000, 200, 1, true, false);
BM_BatchMatmul(8, 10000, 200, 1, true, false);
BM_BatchMatmul(32, 10000, 200, 1, true, false);
BM_BatchMatmul(1, 10000, 200, 1, false, true);
BM_BatchMatmul(8, 10000, 200, 1, false, true);
BM_BatchMatmul(32, 10000, 200, 1, false, true);
BM_BatchMatmul(1, 10000, 200, 1, true, true);
BM_BatchMatmul(8, 10000, 200, 1, true, true);
BM_BatchMatmul(32, 10000, 200, 1, true, true);
// Vector-matrix multiplies.
BM_BatchMatmul(1, 1, 200, 10000, false, false);
BM_BatchMatmul(8, 1, 200, 10000, false, false);
BM_BatchMatmul(32, 1, 200, 10000, false, false);
BM_BatchMatmul(1, 1, 200, 10000, true, false);
BM_BatchMatmul(8, 1, 200, 10000, true, false);
BM_BatchMatmul(32, 1, 200, 10000, true, false);
BM_BatchMatmul(1, 1, 200, 10000, false, true);
BM_BatchMatmul(8, 1, 200, 10000, false, true);
BM_BatchMatmul(32, 1, 200, 10000, false, true);
BM_BatchMatmul(1, 1, 200, 10000, true, true);
BM_BatchMatmul(8, 1, 200, 10000, true, true);
BM_BatchMatmul(32, 1, 200, 10000, true, true);
} // namespace
} // namespace tensorflow

View File

@ -33,8 +33,8 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/batch_matmul_op_impl.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/matmul_op_impl.h"
#include "tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

View File

@ -1,5 +1,5 @@
func @Isinf_elem_type(%arg0: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
-> tensor<*xi1> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.IsInf"(%arg0) : (tensor<*xelem_type>) -> tensor<*xi1>
return %0 : tensor<*xi1>
}

View File

@ -1,5 +1,5 @@
func @Isnan_elem_type(%arg0: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
-> tensor<*xi1> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.IsNan"(%arg0) : (tensor<*xelem_type>) -> tensor<*xi1>
return %0 : tensor<*xi1>
}

View File

@ -41,9 +41,13 @@ static Graph* Multinomial(int batch_size, int num_classes, int num_samples) {
}
#define BM_MultinomialDev(DEVICE, B, C, S) \
static void BM_Multinomial_##DEVICE##_##B##_##C##_##S(int iters) { \
test::Benchmark(#DEVICE, Multinomial(B, C, S)).Run(iters); \
testing::ItemsProcessed(static_cast<int64>(B) * C * S * iters); \
static void BM_Multinomial_##DEVICE##_##B##_##C##_##S( \
::testing::benchmark::State& state) { \
test::Benchmark(#DEVICE, Multinomial(B, C, S), \
/*old_benchmark_api=*/false)                          \
.Run(state); \
state.SetItemsProcessed(static_cast<int64>(B) * C * S * \
state.iterations()); \
} \
BENCHMARK(BM_Multinomial_##DEVICE##_##B##_##C##_##S);

View File

@ -103,18 +103,18 @@ enum CONV_OP {
} // namespace
static void BM_ConvFloat(int iters, int batch, int rows, int cols, int in_depth,
int out_depth, int filter_rows, int filter_cols,
CONV_OP op, int num_threads, int stride,
Padding padding, bool use_gpu, DataType data_type,
static void BM_ConvFloat(::testing::benchmark::State& state, int batch,
int rows, int cols, int in_depth, int out_depth,
int filter_rows, int filter_cols, CONV_OP op,
int num_threads, int stride, Padding padding,
bool use_gpu, DataType data_type,
const string& label) {
testing::StopTiming();
if (!IsGoogleCudaEnabled() && use_gpu) {
testing::SetLabel(
state.SetLabel(
strings::StrCat("Skipping GPU test (no --config=cuda): ", label));
return;
}
testing::SetLabel(label);
state.SetLabel(label);
// Set the number of threads
SessionOptions options;
@ -221,10 +221,10 @@ static void BM_ConvFloat(int iters, int batch, int rows, int cols, int in_depth,
TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph, g));
string device = use_gpu ? "gpu" : "cpu";
testing::UseRealTime();
testing::StartTiming();
test::Benchmark(device, g, &options).Run(iters);
testing::ItemsProcessed(num_ops * iters);
test::Benchmark(device, g, &options, nullptr, nullptr, "",
/*old_benchmark_api=*/false)
.Run(state);
state.SetItemsProcessed(num_ops * state.iterations());
}
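When CUDA is not configured, BM_ConvFloat above labels the run as skipped and returns before building the graph. A rough equivalent of that shape against the open-source google/benchmark API (the availability check is a hypothetical placeholder; the internal State in this file simply uses the SetLabel-and-return form shown above):

#include <benchmark/benchmark.h>

static bool GpuAvailable() { return false; }  // hypothetical stand-in check

static void BM_GpuOnlyConv(benchmark::State& state) {
  if (!GpuAvailable()) {
    state.SkipWithError("Skipping GPU test (no --config=cuda)");
    for (auto _ : state) {}  // consume the state so the run ends cleanly
    return;
  }
  for (auto _ : state) {
    // ... GPU convolution work would go here ...
  }
}
BENCHMARK(BM_GpuOnlyConv)->UseRealTime();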
// BS: batch_size
@ -235,48 +235,52 @@ static void BM_ConvFloat(int iters, int batch, int rows, int cols, int in_depth,
// KR: kernel_rows
// KC: kernel_cols
#define BM_ConvFloatFwd(BS, R, C, ID, OD, KR, KC, STR, PAD, LABEL) \
static void BM_ConvFloatFwdCPU1_##LABEL(int iters) { \
BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_FORWARD, 1, STR, \
static void BM_ConvFloatFwdCPU1_##LABEL( \
::testing::benchmark::State& state) { \
BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_FORWARD, 1, STR, \
PAD, false, DT_FLOAT, \
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \
KR, "_", KC, "_", STR, "_", PAD, "_f_cpu1")); \
} \
static void BM_ConvFloatFwdCPU4_##LABEL(int iters) { \
BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_FORWARD, 4, STR, \
static void BM_ConvFloatFwdCPU4_##LABEL( \
::testing::benchmark::State& state) { \
BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_FORWARD, 4, STR, \
PAD, false, DT_FLOAT, \
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \
KR, "_", KC, "_", STR, "_", PAD, "_f_cpu4")); \
} \
static void BM_ConvFloatFusedCPU1_##LABEL(int iters) { \
BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_FUSED, 1, STR, PAD, \
static void BM_ConvFloatFusedCPU1_##LABEL( \
::testing::benchmark::State& state) { \
BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_FUSED, 1, STR, PAD, \
false, DT_FLOAT, \
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \
KR, "_", KC, "_", STR, "_", PAD, "_f_cpu1")); \
} \
static void BM_ConvFloatFusedCPU4_##LABEL(int iters) { \
BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_FUSED, 4, STR, PAD, \
static void BM_ConvFloatFusedCPU4_##LABEL( \
::testing::benchmark::State& state) { \
BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_FUSED, 4, STR, PAD, \
false, DT_FLOAT, \
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \
KR, "_", KC, "_", STR, "_", PAD, "_f_cpu4")); \
} \
static void BM_ConvFloatFwdGPU_##LABEL(int iters) { \
BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_FORWARD, 1, STR, \
static void BM_ConvFloatFwdGPU_##LABEL(::testing::benchmark::State& state) { \
BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_FORWARD, 1, STR, \
PAD, true, DT_FLOAT, \
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \
KR, "_", KC, "_", STR, "_", PAD, "_f_gpu")); \
} \
static void BM_ConvHalfFwdGPU_##LABEL(int iters) { \
BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_FORWARD, 1, STR, \
static void BM_ConvHalfFwdGPU_##LABEL(::testing::benchmark::State& state) { \
BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_FORWARD, 1, STR, \
PAD, true, DT_HALF, \
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \
KR, "_", KC, "_", STR, "_", PAD, "_h_gpu")); \
} \
BENCHMARK(BM_ConvFloatFwdCPU1_##LABEL); \
BENCHMARK(BM_ConvFloatFwdCPU4_##LABEL); \
BENCHMARK(BM_ConvFloatFusedCPU1_##LABEL); \
BENCHMARK(BM_ConvFloatFusedCPU4_##LABEL); \
BENCHMARK(BM_ConvFloatFwdGPU_##LABEL); \
BENCHMARK(BM_ConvHalfFwdGPU_##LABEL)
BENCHMARK(BM_ConvFloatFwdCPU1_##LABEL)->UseRealTime(); \
BENCHMARK(BM_ConvFloatFwdCPU4_##LABEL)->UseRealTime(); \
BENCHMARK(BM_ConvFloatFusedCPU1_##LABEL)->UseRealTime(); \
BENCHMARK(BM_ConvFloatFusedCPU4_##LABEL)->UseRealTime(); \
BENCHMARK(BM_ConvFloatFwdGPU_##LABEL)->UseRealTime(); \
BENCHMARK(BM_ConvHalfFwdGPU_##LABEL)->UseRealTime()
BM_ConvFloatFwd(32, 5, 5, 1248, 128, 1, 1, 1, SAME, conv0);
BM_ConvFloatFwd(32, 8, 8, 384, 384, 1, 3, 1, SAME, conv1);
@ -335,62 +339,69 @@ BM_ConvFloatFwd(32, 73, 73, 64, 64, 1, 1, 1, VALID, conv53);
BM_ConvFloatFwd(32, 147, 147, 24, 64, 1, 1, 1, VALID, conv54);
#define BM_ConvFloatBkInAndFilter(BS, R, C, ID, OD, KR, KC, STR, PAD, LABEL) \
static void BM_ConvFloatBkInCPU1_##LABEL(int iters) { \
BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_INPUT, 1, \
static void BM_ConvFloatBkInCPU1_##LABEL( \
::testing::benchmark::State& state) { \
BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_INPUT, 1, \
STR, PAD, false, DT_FLOAT, \
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \
KR, "_", KC, "_", STR, "_", PAD, "_cpu1")); \
} \
static void BM_ConvFloatBkInCPU4_##LABEL(int iters) { \
BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_INPUT, 4, \
static void BM_ConvFloatBkInCPU4_##LABEL( \
::testing::benchmark::State& state) { \
BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_INPUT, 4, \
STR, PAD, false, DT_FLOAT, \
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \
KR, "_", KC, "_", STR, "_", PAD, "_cpu4")); \
} \
static void BM_ConvFloatBkInGPU_##LABEL(int iters) { \
BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_INPUT, 1, \
static void BM_ConvFloatBkInGPU_##LABEL( \
::testing::benchmark::State& state) { \
BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_INPUT, 1, \
STR, PAD, true, DT_FLOAT, \
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \
KR, "_", KC, "_", STR, "_", PAD, "_gpu")); \
} \
static void BM_ConvFloatBkFilterCPU1_##LABEL(int iters) { \
BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1, \
static void BM_ConvFloatBkFilterCPU1_##LABEL( \
::testing::benchmark::State& state) { \
BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1, \
STR, PAD, false, DT_FLOAT, \
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \
KR, "_", KC, "_", STR, "_", PAD, "_cpu1")); \
} \
static void BM_ConvFloatBkFilterCPU4_##LABEL(int iters) { \
BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 4, \
static void BM_ConvFloatBkFilterCPU4_##LABEL( \
::testing::benchmark::State& state) { \
BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 4, \
STR, PAD, false, DT_FLOAT, \
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \
KR, "_", KC, "_", STR, "_", PAD, "_cpu4")); \
} \
static void BM_ConvFloatBkFilterGPU_##LABEL(int iters) { \
BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1, \
static void BM_ConvFloatBkFilterGPU_##LABEL( \
::testing::benchmark::State& state) { \
BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1, \
STR, PAD, true, DT_FLOAT, \
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \
KR, "_", KC, "_", STR, "_", PAD, "_gpu")); \
} \
static void BM_ConvHalfBkInGPU_##LABEL(int iters) { \
BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_INPUT, 1, \
static void BM_ConvHalfBkInGPU_##LABEL(::testing::benchmark::State& state) { \
BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_INPUT, 1, \
STR, PAD, true, DT_HALF, \
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \
KR, "_", KC, "_", STR, "_", PAD, "_gpu")); \
} \
static void BM_ConvHalfBkFilterGPU_##LABEL(int iters) { \
BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1, \
static void BM_ConvHalfBkFilterGPU_##LABEL( \
::testing::benchmark::State& state) { \
BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1, \
STR, PAD, true, DT_HALF, \
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \
KR, "_", KC, "_", STR, "_", PAD, "_gpu")); \
} \
BENCHMARK(BM_ConvFloatBkInCPU1_##LABEL); \
BENCHMARK(BM_ConvFloatBkInCPU4_##LABEL); \
BENCHMARK(BM_ConvFloatBkInGPU_##LABEL); \
BENCHMARK(BM_ConvFloatBkFilterCPU1_##LABEL); \
BENCHMARK(BM_ConvFloatBkFilterCPU4_##LABEL); \
BENCHMARK(BM_ConvFloatBkFilterGPU_##LABEL); \
BENCHMARK(BM_ConvHalfBkInGPU_##LABEL); \
BENCHMARK(BM_ConvHalfBkFilterGPU_##LABEL)
BENCHMARK(BM_ConvFloatBkInCPU1_##LABEL)->UseRealTime(); \
BENCHMARK(BM_ConvFloatBkInCPU4_##LABEL)->UseRealTime(); \
BENCHMARK(BM_ConvFloatBkInGPU_##LABEL)->UseRealTime(); \
BENCHMARK(BM_ConvFloatBkFilterCPU1_##LABEL)->UseRealTime(); \
BENCHMARK(BM_ConvFloatBkFilterCPU4_##LABEL)->UseRealTime(); \
BENCHMARK(BM_ConvFloatBkFilterGPU_##LABEL)->UseRealTime(); \
BENCHMARK(BM_ConvHalfBkInGPU_##LABEL)->UseRealTime(); \
BENCHMARK(BM_ConvHalfBkFilterGPU_##LABEL)->UseRealTime()
// Benchmarks from the inception model
@ -453,8 +464,8 @@ BM_ConvFloatBkInAndFilter(32, 147, 147, 24, 64, 1, 1, 1, VALID, conv54);
#define BM_ConvFloatBkFCPU(BS, R, C, ID, OD, KR, KC, TH, LABEL) \
static void \
BM_ConvFloatBkFCPU_##BS##_##R##_##C##_##ID##_##OD##_##KR##_##KC##_##TH( \
int iters) { \
BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, TH, \
::testing::benchmark::State& state) { \
BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, TH, \
1, VALID, false, DT_FLOAT, LABEL); \
} \
BENCHMARK( \
@ -469,17 +480,19 @@ BM_ConvFloatBkFCPU(128, 13, 13, 384, 384, 3, 3, 4, "convnet-layer5");
#define BM_ConvFloatBkFGPU(BS, R, C, ID, OD, KR, KC, LABEL) \
static void BM_ConvFloatBkFGPU_##BS##_##R##_##C##_##ID##_##OD##_##KR##_##KC( \
int iters) { \
BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1, \
::testing::benchmark::State& state) { \
BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1, \
1, VALID, true, DT_FLOAT, LABEL); \
} \
static void BM_ConvHalfBkFGPU_##BS##_##R##_##C##_##ID##_##OD##_##KR##_##KC( \
int iters) { \
BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1, \
::testing::benchmark::State& state) { \
BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1, \
1, VALID, true, DT_HALF, LABEL); \
} \
BENCHMARK(BM_ConvFloatBkFGPU_##BS##_##R##_##C##_##ID##_##OD##_##KR##_##KC); \
BENCHMARK(BM_ConvHalfBkFGPU_##BS##_##R##_##C##_##ID##_##OD##_##KR##_##KC)
BENCHMARK(BM_ConvFloatBkFGPU_##BS##_##R##_##C##_##ID##_##OD##_##KR##_##KC) \
->UseRealTime(); \
BENCHMARK(BM_ConvHalfBkFGPU_##BS##_##R##_##C##_##ID##_##OD##_##KR##_##KC) \
->UseRealTime()
// Benchmarks from https://github.com/soumith/convnet-benchmarks
BM_ConvFloatBkFGPU(128, 128, 128, 3, 96, 11, 11, "convnet-layer1");
@ -498,19 +511,19 @@ enum DEPTHWISE_CONV_OP {
} // namespace
static void BM_ConvFloatDepthwise(int iters, int batch, int rows, int cols,
int in_depth, int depth_multiplier,
int out_depth, int filter_rows,
int filter_cols, DEPTHWISE_CONV_OP op,
int num_threads, int stride, Padding padding,
bool use_gpu, const string& label) {
testing::StopTiming();
static void BM_ConvFloatDepthwise(::testing::benchmark::State& state, int batch,
int rows, int cols, int in_depth,
int depth_multiplier, int out_depth,
int filter_rows, int filter_cols,
DEPTHWISE_CONV_OP op, int num_threads,
int stride, Padding padding, bool use_gpu,
const string& label) {
if (!IsGoogleCudaEnabled() && use_gpu) {
testing::SetLabel(
state.SetLabel(
strings::StrCat("Skipping GPU test (no --config=cuda): ", label));
return;
}
testing::SetLabel(label);
state.SetLabel(label);
// Set the number of threads
SessionOptions options;
@ -603,10 +616,10 @@ static void BM_ConvFloatDepthwise(int iters, int batch, int rows, int cols,
TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph, g));
string device = use_gpu ? "gpu" : "cpu";
testing::UseRealTime();
testing::StartTiming();
test::Benchmark(device, g, &options).Run(iters);
testing::ItemsProcessed(num_ops * iters);
test::Benchmark(device, g, &options, nullptr, nullptr, "",
/*old_benchmark_api=*/false)
.Run(state);
state.SetItemsProcessed(num_ops * state.iterations());
}
// BS: batch_size
@ -622,30 +635,33 @@ static void BM_ConvFloatDepthwise(int iters, int batch, int rows, int cols,
#define BM_ConvFloatDepthwiseFwd(BS, R, C, ID, DM, OD, KR, KC, STR, PAD, \
LABEL) \
static void BM_ConvFloatDepthwiseFwdCPU1_##LABEL(int iters) { \
static void BM_ConvFloatDepthwiseFwdCPU1_##LABEL( \
::testing::benchmark::State& state) { \
BM_ConvFloatDepthwise( \
iters, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_FWD, 1, STR, \
state, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_FWD, 1, STR, \
PAD, false, \
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", DM, "_", OD, "_", \
KR, "_", KC, "_", STR, "_", PAD, "_cpu1")); \
} \
static void BM_ConvFloatDepthwiseFwdCPU4_##LABEL(int iters) { \
static void BM_ConvFloatDepthwiseFwdCPU4_##LABEL( \
::testing::benchmark::State& state) { \
BM_ConvFloatDepthwise( \
iters, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_FWD, 4, STR, \
state, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_FWD, 4, STR, \
PAD, false, \
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", DM, "_", OD, "_", \
KR, "_", KC, "_", STR, "_", PAD, "_cpu4")); \
} \
static void BM_ConvFloatDepthwiseFwdGPU_##LABEL(int iters) { \
static void BM_ConvFloatDepthwiseFwdGPU_##LABEL( \
::testing::benchmark::State& state) { \
BM_ConvFloatDepthwise( \
iters, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_FWD, 1, STR, \
state, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_FWD, 1, STR, \
PAD, true, \
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", DM, "_", OD, "_", \
KR, "_", KC, "_", STR, "_", PAD, "_gpu")); \
} \
BENCHMARK(BM_ConvFloatDepthwiseFwdCPU1_##LABEL); \
BENCHMARK(BM_ConvFloatDepthwiseFwdCPU4_##LABEL); \
BENCHMARK(BM_ConvFloatDepthwiseFwdGPU_##LABEL);
BENCHMARK(BM_ConvFloatDepthwiseFwdCPU1_##LABEL)->UseRealTime(); \
BENCHMARK(BM_ConvFloatDepthwiseFwdCPU4_##LABEL)->UseRealTime(); \
BENCHMARK(BM_ConvFloatDepthwiseFwdGPU_##LABEL)->UseRealTime();
// The configurations below are mostly from mobilenet models.
BM_ConvFloatDepthwiseFwd(32, 112, 112, 3, 8, 24, 3, 3, 1, SAME, conv0);
@ -662,53 +678,59 @@ BM_ConvFloatDepthwiseFwd(1, 100, 100, 72, 1, 72, 3, 3, 1, SAME, conv9);
BM_ConvFloatDepthwiseFwd(1, 100, 100, 72, 1, 72, 5, 5, 1, SAME, conv10);
#define BM_ConvFloatDepthwiseBk(BS, R, C, ID, DM, OD, KR, KC, STR, PAD, LABEL) \
static void BM_ConvFloatDepthwiseBkInCPU1_##LABEL(int iters) { \
static void BM_ConvFloatDepthwiseBkInCPU1_##LABEL( \
::testing::benchmark::State& state) { \
BM_ConvFloatDepthwise( \
iters, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_BACKPROP_INPUT, \
state, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_BACKPROP_INPUT, \
1, STR, PAD, false, \
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", DM, "_", OD, "_", \
KR, "_", KC, "_", STR, "_", PAD, "_cpu1")); \
} \
static void BM_ConvFloatDepthwiseBkInCPU4_##LABEL(int iters) { \
static void BM_ConvFloatDepthwiseBkInCPU4_##LABEL( \
::testing::benchmark::State& state) { \
BM_ConvFloatDepthwise( \
iters, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_BACKPROP_INPUT, \
state, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_BACKPROP_INPUT, \
4, STR, PAD, false, \
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", DM, "_", OD, "_", \
KR, "_", KC, "_", STR, "_", PAD, "_cpu4")); \
} \
static void BM_ConvFloatDepthwiseBkInGPU_##LABEL(int iters) { \
static void BM_ConvFloatDepthwiseBkInGPU_##LABEL( \
::testing::benchmark::State& state) { \
BM_ConvFloatDepthwise( \
iters, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_BACKPROP_INPUT, \
state, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_BACKPROP_INPUT, \
4, STR, PAD, true, \
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", DM, "_", OD, "_", \
KR, "_", KC, "_", STR, "_", PAD, "_gpu")); \
} \
static void BM_ConvFloatDepthwiseBkFilterCPU1_##LABEL(int iters) { \
static void BM_ConvFloatDepthwiseBkFilterCPU1_##LABEL( \
::testing::benchmark::State& state) { \
BM_ConvFloatDepthwise( \
iters, BS, R, C, ID, DM, OD, KR, KC, \
state, BS, R, C, ID, DM, OD, KR, KC, \
DEPTHWISE_CONV_OP_BACKPROP_FILTER, 1, STR, PAD, false, \
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", DM, "_", OD, "_", \
KR, "_", KC, "_", STR, "_", PAD, "_cpu1")); \
} \
static void BM_ConvFloatDepthwiseBkFilterCPU4_##LABEL(int iters) { \
static void BM_ConvFloatDepthwiseBkFilterCPU4_##LABEL( \
::testing::benchmark::State& state) { \
BM_ConvFloatDepthwise( \
iters, BS, R, C, ID, DM, OD, KR, KC, \
state, BS, R, C, ID, DM, OD, KR, KC, \
DEPTHWISE_CONV_OP_BACKPROP_FILTER, 4, STR, PAD, false, \
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", DM, "_", OD, "_", \
KR, "_", KC, "_", STR, "_", PAD, "_cpu4")); \
} \
static void BM_ConvFloatDepthwiseBkFilterGPU_##LABEL(int iters) { \
static void BM_ConvFloatDepthwiseBkFilterGPU_##LABEL( \
::testing::benchmark::State& state) { \
BM_ConvFloatDepthwise( \
iters, BS, R, C, ID, DM, OD, KR, KC, \
state, BS, R, C, ID, DM, OD, KR, KC, \
DEPTHWISE_CONV_OP_BACKPROP_FILTER, 4, STR, PAD, true, \
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", DM, "_", OD, "_", \
KR, "_", KC, "_", STR, "_", PAD, "_gpu")); \
} \
BENCHMARK(BM_ConvFloatDepthwiseBkInCPU1_##LABEL); \
BENCHMARK(BM_ConvFloatDepthwiseBkInCPU4_##LABEL); \
BENCHMARK(BM_ConvFloatDepthwiseBkInGPU_##LABEL); \
BENCHMARK(BM_ConvFloatDepthwiseBkFilterCPU1_##LABEL); \
BENCHMARK(BM_ConvFloatDepthwiseBkFilterCPU4_##LABEL); \
BENCHMARK(BM_ConvFloatDepthwiseBkInCPU1_##LABEL)->UseRealTime(); \
BENCHMARK(BM_ConvFloatDepthwiseBkInCPU4_##LABEL)->UseRealTime(); \
BENCHMARK(BM_ConvFloatDepthwiseBkInGPU_##LABEL)->UseRealTime(); \
BENCHMARK(BM_ConvFloatDepthwiseBkFilterCPU1_##LABEL)->UseRealTime(); \
BENCHMARK(BM_ConvFloatDepthwiseBkFilterCPU4_##LABEL)->UseRealTime(); \
BENCHMARK(BM_ConvFloatDepthwiseBkFilterGPU_##LABEL)
// The configurations below are mostly from mobilenet models.
@ -732,10 +754,9 @@ BM_ConvFloatDepthwiseBk(32, 112, 112, 8, 3, 24, 3, 3, 1, SAME, conv12);
BM_ConvFloatDepthwiseBk(32, 112, 112, 12, 2, 24, 3, 3, 1, SAME, conv13);
BM_ConvFloatDepthwiseBk(32, 112, 112, 24, 1, 24, 3, 3, 1, SAME, conv14);
static void BM_LRNFloat(int iters, int depth, int cols, int rows,
int batch_size, int range, int num_threads,
static void BM_LRNFloat(::testing::benchmark::State& state, int depth, int cols,
int rows, int batch_size, int range, int num_threads,
const string& label) {
tensorflow::testing::StopTiming();
std::unique_ptr<Device> device(
DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
@ -778,26 +799,24 @@ static void BM_LRNFloat(int iters, int depth, int cols, int rows,
std::unique_ptr<OpKernelContext> context(new OpKernelContext(&params));
op->Compute(context.get());
testing::UseRealTime();
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto s : state) {
delete context->release_output(0).tensor;
op->Compute(context.get());
}
tensorflow::testing::StopTiming();
testing::ItemsProcessed(context->mutable_output(0)->NumElements() * iters *
(2 * range + 1) * 2);
testing::SetLabel(label);
state.SetItemsProcessed(context->mutable_output(0)->NumElements() *
state.iterations() * (2 * range + 1) * 2);
state.SetLabel(label);
}
#define BM_LRNFloatFwdCPU(DEPTH, COLS, ROWS, BATCH, RANGE, THREADS, LABEL) \
static void \
BM_LRNFloat_##DEPTH##_##COLS##_##ROWS##_##BATCH##_##RANGE##_##THREADS( \
int iters) { \
BM_LRNFloat(iters, DEPTH, COLS, ROWS, BATCH, RANGE, THREADS, LABEL); \
::testing::benchmark::State& state) { \
BM_LRNFloat(state, DEPTH, COLS, ROWS, BATCH, RANGE, THREADS, LABEL); \
} \
BENCHMARK( \
BM_LRNFloat_##DEPTH##_##COLS##_##ROWS##_##BATCH##_##RANGE##_##THREADS)
BM_LRNFloat_##DEPTH##_##COLS##_##ROWS##_##BATCH##_##RANGE##_##THREADS) \
->UseRealTime()
// clang-format off
// DEPTH, COLS, ROWS, BATCH, RANGE, THREADS, LABEL
@ -815,10 +834,10 @@ BM_LRNFloatFwdCPU(192, 56, 56, 32, 5, 8, "lrn 8 threads");
/*
AvgPooling Op
*/
static void BM_AvgPool(int iters, int batch_size, int rows, int cols, int depth,
int kernel_rows, int kernel_cols, int stride,
Padding padding, int num_threads, const string& label) {
tensorflow::testing::StopTiming();
static void BM_AvgPool(::testing::benchmark::State& state, int batch_size,
int rows, int cols, int depth, int kernel_rows,
int kernel_cols, int stride, Padding padding,
int num_threads, const string& label) {
std::unique_ptr<Device> device(
DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
@ -860,16 +879,13 @@ static void BM_AvgPool(int iters, int batch_size, int rows, int cols, int depth,
new OpKernelContext(&params));
op->Compute(avgpool_context.get());
testing::UseRealTime();
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto s : state) {
delete avgpool_context->release_output(0).tensor;
op->Compute(avgpool_context.get());
}
tensorflow::testing::StopTiming();
testing::ItemsProcessed(avgpool_context->mutable_output(0)->NumElements() *
iters);
testing::SetLabel(label);
state.SetItemsProcessed(avgpool_context->mutable_output(0)->NumElements() *
state.iterations());
state.SetLabel(label);
}
// BS: batch_size
@ -883,11 +899,12 @@ static void BM_AvgPool(int iters, int batch_size, int rows, int cols, int depth,
#define BM_AvgPoolFwdCPU(BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL) \
static void \
BM_AvgPool_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH( \
int iters) { \
BM_AvgPool(iters, BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL); \
::testing::benchmark::State& state) { \
BM_AvgPool(state, BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL); \
} \
BENCHMARK( \
BM_AvgPool_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH)
BM_AvgPool_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH) \
->UseRealTime()
// Labels are taken from the 2014-July-24 version of imagenet
BM_AvgPoolFwdCPU(32, 112, 112, 64, 3, 3, 2, VALID, 1, "avgpool0_VALID");
@ -907,11 +924,10 @@ BM_AvgPoolFwdCPU(32, 56, 56, 192, 3, 3, 2, SAME, 4, "avgpool1_SAME");
BM_AvgPoolFwdCPU(32, 28, 28, 352, 3, 3, 2, SAME, 4, "avgpool4_SAME");
BM_AvgPoolFwdCPU(32, 14, 14, 576, 3, 3, 2, SAME, 4, "avgpool10_SAME");
static void BM_AvgPoolBk(int iters, int batch_size, int rows, int cols,
int depth, int kernel_rows, int kernel_cols,
int stride, Padding padding, int num_threads,
const string& label) {
tensorflow::testing::StopTiming();
static void BM_AvgPoolBk(::testing::benchmark::State& state, int batch_size,
int rows, int cols, int depth, int kernel_rows,
int kernel_cols, int stride, Padding padding,
int num_threads, const string& label) {
std::unique_ptr<Device> device(
DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
@ -966,16 +982,13 @@ static void BM_AvgPoolBk(int iters, int batch_size, int rows, int cols,
new OpKernelContext(&params));
op->Compute(avgpool_context.get());
testing::UseRealTime();
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto s : state) {
delete avgpool_context->release_output(0).tensor;
op->Compute(avgpool_context.get());
}
tensorflow::testing::StopTiming();
testing::ItemsProcessed(avgpool_context->mutable_output(0)->NumElements() *
iters);
testing::SetLabel(label);
state.SetItemsProcessed(avgpool_context->mutable_output(0)->NumElements() *
state.iterations());
state.SetLabel(label);
}
// BS: batch_size
@ -987,14 +1000,17 @@ static void BM_AvgPoolBk(int iters, int batch_size, int rows, int cols,
// ST: stride. We use the same stride for both directions.
// PT: padding
// The resulting symbol is too long. Need to use two macros to fit in 80 chars.
// NOLINTBEGIN
#define BM_AvgPoolBkCPU(BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL) \
static void \
BM_AvgPoolBk_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH( \
int iters) { \
BM_AvgPoolBk(iters, BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL); \
::testing::benchmark::State& state) { \
BM_AvgPoolBk(state, BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL); \
} \
BENCHMARK( \
BM_AvgPoolBk_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH)
BM_AvgPoolBk_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH) \
->UseRealTime()
// NOLINTEND
// Shapes taken from the 2015/05/16 inception model
BM_AvgPoolBkCPU(32, 35, 35, 192, 3, 3, 1, SAME, 1, "avgpool_grad0_SAME");
@ -1010,10 +1026,10 @@ BM_AvgPoolBkCPU(32, 8, 8, 2048, 8, 8, 1, VALID, 1, "avgpool_grad8_VALID");
/*
MaxPooling Op
*/
static void BM_MaxPool(int iters, int batch_size, int rows, int cols, int depth,
int kernel_rows, int kernel_cols, int stride,
Padding padding, int num_threads, const string& label) {
tensorflow::testing::StopTiming();
static void BM_MaxPool(::testing::benchmark::State& state, int batch_size,
int rows, int cols, int depth, int kernel_rows,
int kernel_cols, int stride, Padding padding,
int num_threads, const string& label) {
SessionOptions options;
options.config.set_intra_op_parallelism_threads(num_threads);
@ -1057,16 +1073,13 @@ static void BM_MaxPool(int iters, int batch_size, int rows, int cols, int depth,
new OpKernelContext(&params));
op->Compute(maxpool_context.get());
testing::UseRealTime();
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto s : state) {
delete maxpool_context->release_output(0).tensor;
op->Compute(maxpool_context.get());
}
tensorflow::testing::StopTiming();
testing::ItemsProcessed(maxpool_context->mutable_output(0)->NumElements() *
iters);
testing::SetLabel(label);
state.SetItemsProcessed(maxpool_context->mutable_output(0)->NumElements() *
state.iterations());
state.SetLabel(label);
}
// BS: batch_size
@ -1080,11 +1093,12 @@ static void BM_MaxPool(int iters, int batch_size, int rows, int cols, int depth,
#define BM_MaxPoolFwdCPU(BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL) \
static void \
BM_MaxPool_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH( \
int iters) { \
BM_MaxPool(iters, BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL); \
::testing::benchmark::State& state) { \
BM_MaxPool(state, BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL); \
} \
BENCHMARK( \
BM_MaxPool_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH)
BM_MaxPool_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH) \
->UseRealTime()
// Labels are taken from the 2014-July-24 version of imagenet
/* TODO XXX
@ -1106,10 +1120,10 @@ BM_MaxPoolFwdCPU(32, 56, 56, 192, 3, 3, 2, SAME, 4, "maxpool1_SAME");
BM_MaxPoolFwdCPU(32, 28, 28, 352, 3, 3, 2, SAME, 4, "maxpool4_SAME");
BM_MaxPoolFwdCPU(32, 14, 14, 576, 3, 3, 2, SAME, 4, "maxpool10_SAME");
static void BM_MaxPoolBk(int iters, int batch_size, int rows, int cols,
int depth, int kernel_rows, int kernel_cols,
int stride, Padding padding, int num_threads,
bool use_gpu, const string& label) {
static void BM_MaxPoolBk(::testing::benchmark::State& state, int batch_size,
int rows, int cols, int depth, int kernel_rows,
int kernel_cols, int stride, Padding padding,
int num_threads, bool use_gpu, const string& label) {
auto root = Scope::NewRootScope().ExitOnError();
int64 out_height, out_width, pad_rows, pad_cols;
@ -1138,11 +1152,11 @@ static void BM_MaxPoolBk(int iters, int batch_size, int rows, int cols,
Graph* g = new Graph(OpRegistry::Global());
TF_CHECK_OK(root.ToGraph(g));
string device = use_gpu ? "gpu" : "cpu";
testing::UseRealTime();
test::Benchmark(device, g).Run(iters);
test::Benchmark(device, g, /*old_benchmark_api*/ false).Run(state);
testing::ItemsProcessed(batch_size * rows * cols * depth * iters);
testing::SetLabel(label);
state.SetItemsProcessed(batch_size * rows * cols * depth *
state.iterations());
state.SetLabel(label);
}
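
For graph-driven benchmarks such as BM_MaxPoolBk above, the commit leaves graph construction untimed and hands the state object straight to test::Benchmark::Run, passing /*old_benchmark_api*/ false. The two call shapes that recur in these hunks, condensed into one hedged sketch (MakeTestGraph is a hypothetical stand-in for the real graph builders, and the includes of the surrounding file are assumed):

// Hedged sketch only; MakeTestGraph() is hypothetical and the usual
// kernel_benchmark_testlib.h / test_benchmark.h includes are assumed.
static void BM_GraphOnCpu(::testing::benchmark::State& state) {
  const int n = state.range(0);
  Graph* g = MakeTestGraph(n);  // built before timing starts
  // Short form, as used by BM_MaxPoolBk above:
  test::Benchmark("cpu", g, /*old_benchmark_api=*/false).Run(state);
  state.SetItemsProcessed(static_cast<int64>(state.iterations()) * n);
}

static void BM_GraphWithOptions(::testing::benchmark::State& state) {
  const int n = state.range(0);
  SessionOptions opts;  // e.g. thread counts or optimizer level
  Graph* g = MakeTestGraph(n);
  // Long form, as used by BM_ImageNetSoftmaxFwd and BM_TopK below:
  test::Benchmark("cpu", g, &opts, nullptr, nullptr, "",
                  /*old_benchmark_api=*/false)
      .Run(state);
  state.SetItemsProcessed(static_cast<int64>(state.iterations()) * n);
}

BENCHMARK(BM_GraphOnCpu)->UseRealTime()->Arg(1024);
BENCHMARK(BM_GraphWithOptions)->UseRealTime()->Arg(1024);
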
// BS: batch_size
@ -1159,23 +1173,23 @@ static void BM_MaxPoolBk(int iters, int batch_size, int rows, int cols,
static void \
BM_MaxPoolBk_GPU_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_ \
##PT##_##TH( \
int iters) { \
BM_MaxPoolBk(iters, BS, IR, IC, ND, KR, KC, ST, PT, TH, true, LABEL); \
::testing::benchmark::State& state) { \
BM_MaxPoolBk(state, BS, IR, IC, ND, KR, KC, ST, PT, TH, true, LABEL); \
} \
BENCHMARK( \
BM_MaxPoolBk_GPU_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_ \
##PT##_##TH) \
##PT##_##TH)->UseRealTime()
#define BM_MaxPoolBkCPU(BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL) \
static void \
BM_MaxPoolBk_CPU_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_ \
##PT##_##TH( \
int iters) { \
BM_MaxPoolBk(iters, BS, IR, IC, ND, KR, KC, ST, PT, TH, false, LABEL); \
::testing::benchmark::State& state) { \
BM_MaxPoolBk(state, BS, IR, IC, ND, KR, KC, ST, PT, TH, false, LABEL); \
} \
BENCHMARK( \
BM_MaxPoolBk_CPU_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_ \
##PT##_##TH)
##PT##_##TH)->UseRealTime()
// clang-format on
// Shapes taken from the 2015/05/16 inception model
@ -1195,9 +1209,9 @@ BM_MaxPoolBkCPU(32, 8, 8, 2048, 3, 3, 2, VALID, 1, "maxpool_grad4_VALID");
Relu Op
Run benchmark with:
*/
static void BM_ReluFloat(int iters, int batch_size, int rows, int cols,
int depth, int num_threads, const string& label) {
tensorflow::testing::StopTiming();
static void BM_ReluFloat(::testing::benchmark::State& state, int batch_size,
int rows, int cols, int depth, int num_threads,
const string& label) {
std::unique_ptr<Device> device(
DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
@ -1233,16 +1247,13 @@ static void BM_ReluFloat(int iters, int batch_size, int rows, int cols,
std::unique_ptr<OpKernelContext> relu_context(new OpKernelContext(&params));
op->Compute(relu_context.get());
testing::UseRealTime();
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto s : state) {
delete relu_context->release_output(0).tensor;
op->Compute(relu_context.get());
}
tensorflow::testing::StopTiming();
testing::ItemsProcessed(relu_context->mutable_output(0)->NumElements() *
iters);
testing::SetLabel(label);
state.SetItemsProcessed(relu_context->mutable_output(0)->NumElements() *
state.iterations());
state.SetLabel(label);
}
// BS: batch_size
@ -1250,10 +1261,11 @@ static void BM_ReluFloat(int iters, int batch_size, int rows, int cols,
// IC: input_cols
// ND: node_depth
#define BM_Relu(BS, IR, IC, ND, TH, LABEL) \
static void BM_ReluFloat_##BS##_##IR##_##IC##_##ND##_##TH(int iters) { \
BM_ReluFloat(iters, BS, IR, IC, ND, TH, LABEL); \
static void BM_ReluFloat_##BS##_##IR##_##IC##_##ND##_##TH( \
::testing::benchmark::State& state) { \
BM_ReluFloat(state, BS, IR, IC, ND, TH, LABEL); \
} \
BENCHMARK(BM_ReluFloat_##BS##_##IR##_##IC##_##ND##_##TH)
BENCHMARK(BM_ReluFloat_##BS##_##IR##_##IC##_##ND##_##TH)->UseRealTime()
BM_Relu(32, 112, 112, 64, 1, "relu0");
BM_Relu(32, 56, 56, 192, 1, "relu1");
@ -1268,9 +1280,9 @@ BM_Relu(32, 14, 14, 576, 4, "relu10");
Softplus Op
Run benchmark with:
*/
static void BM_SoftplusFloat(int iters, int batch_size, int rows, int cols,
int depth, int num_threads, const string& label) {
tensorflow::testing::StopTiming();
static void BM_SoftplusFloat(::testing::benchmark::State& state, int batch_size,
int rows, int cols, int depth, int num_threads,
const string& label) {
std::unique_ptr<Device> device(
DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
@ -1307,16 +1319,13 @@ static void BM_SoftplusFloat(int iters, int batch_size, int rows, int cols,
new OpKernelContext(&params));
op->Compute(softplus_context.get());
testing::UseRealTime();
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto s : state) {
delete softplus_context->release_output(0).tensor;
op->Compute(softplus_context.get());
}
tensorflow::testing::StopTiming();
testing::ItemsProcessed(softplus_context->mutable_output(0)->NumElements() *
iters);
testing::SetLabel(label);
state.SetItemsProcessed(softplus_context->mutable_output(0)->NumElements() *
state.iterations());
state.SetLabel(label);
}
// BS: batch_size
@ -1324,10 +1333,11 @@ static void BM_SoftplusFloat(int iters, int batch_size, int rows, int cols,
// IC: input_cols
// ND: node_depth
#define BM_Softplus(BS, IR, IC, ND, TH, LABEL) \
static void BM_SoftplusFloat_##BS##_##IR##_##IC##_##ND##_##TH(int iters) { \
BM_SoftplusFloat(iters, BS, IR, IC, ND, TH, LABEL); \
static void BM_SoftplusFloat_##BS##_##IR##_##IC##_##ND##_##TH( \
::testing::benchmark::State& state) { \
BM_SoftplusFloat(state, BS, IR, IC, ND, TH, LABEL); \
} \
BENCHMARK(BM_SoftplusFloat_##BS##_##IR##_##IC##_##ND##_##TH)
BENCHMARK(BM_SoftplusFloat_##BS##_##IR##_##IC##_##ND##_##TH)->UseRealTime()
BM_Softplus(32, 112, 112, 64, 1, "softplus0");
BM_Softplus(32, 56, 56, 192, 1, "softplus1");
@ -1338,7 +1348,8 @@ BM_Softplus(32, 56, 56, 192, 4, "softplus1");
BM_Softplus(32, 28, 28, 352, 4, "softplus4");
BM_Softplus(32, 14, 14, 576, 4, "softplus10");
static void BM_ImageNetSoftmaxFwd(int iters, int batch_size, int node_depth,
static void BM_ImageNetSoftmaxFwd(::testing::benchmark::State& state,
int batch_size, int node_depth,
int num_threads, bool use_gpu,
const string& label) {
auto root = Scope::NewRootScope().ExitOnError();
@ -1359,19 +1370,21 @@ static void BM_ImageNetSoftmaxFwd(int iters, int batch_size, int node_depth,
opts.config.mutable_graph_options()
->mutable_optimizer_options()
->set_opt_level(OptimizerOptions::L0);
testing::UseRealTime();
test::Benchmark(device, g, &opts).Run(iters);
testing::ItemsProcessed(batch_size * node_depth * iters);
testing::SetLabel(label);
test::Benchmark(device, g, &opts, nullptr, nullptr, "",
/*old_benchmark_api*/ false)
.Run(state);
state.SetItemsProcessed(batch_size * node_depth * state.iterations());
state.SetLabel(label);
}
#define BM_ImageNetSoftmaxFwd(BATCH_SIZE, NODE_DEPTH, TH, GPU, LABEL) \
static void \
BM_ImageNetSoftmaxFwd_##BATCH_SIZE##_##NODE_DEPTH##_##TH##_##GPU( \
int iters) { \
BM_ImageNetSoftmaxFwd(iters, BATCH_SIZE, NODE_DEPTH, TH, GPU, LABEL); \
::testing::benchmark::State& state) { \
BM_ImageNetSoftmaxFwd(state, BATCH_SIZE, NODE_DEPTH, TH, GPU, LABEL); \
} \
BENCHMARK(BM_ImageNetSoftmaxFwd_##BATCH_SIZE##_##NODE_DEPTH##_##TH##_##GPU)
BENCHMARK(BM_ImageNetSoftmaxFwd_##BATCH_SIZE##_##NODE_DEPTH##_##TH##_##GPU) \
->UseRealTime()
// Labels are taken from the 2014-July-24 version of imagenet
BM_ImageNetSoftmaxFwd(32, 1008, 1, false, "softmax32");
@ -1383,9 +1396,8 @@ BM_ImageNetSoftmaxFwd(128, 1008, 1, true, "softmax128");
BM_ImageNetSoftmaxFwd(8192, 1024, 1, true, "softmax32");
BM_ImageNetSoftmaxFwd(8192, 32768, 1, true, "softmax128");
static void BM_TopK(int iters, int rows, int cols, int k, int num_threads,
bool use_gpu, const string& label) {
testing::StopTiming();
static void BM_TopK(::testing::benchmark::State& state, int rows, int cols,
int k, int num_threads, bool use_gpu, const string& label) {
auto root = Scope::NewRootScope().ExitOnError();
Tensor input(DT_FLOAT, TensorShape({rows, cols}));
@ -1407,11 +1419,11 @@ static void BM_TopK(int iters, int rows, int cols, int k, int num_threads,
opts.config.mutable_graph_options()
->mutable_optimizer_options()
->set_opt_level(OptimizerOptions::L0);
testing::UseRealTime();
testing::StartTiming();
test::Benchmark(device, g, &opts).Run(iters);
testing::ItemsProcessed(rows * cols * iters);
testing::SetLabel(label);
test::Benchmark(device, g, &opts, nullptr, nullptr, "",
/*old_benchmark_api=*/false)
.Run(state);
state.SetItemsProcessed(rows * cols * state.iterations());
state.SetLabel(label);
}
// IR: input_rows
@ -1419,16 +1431,18 @@ static void BM_TopK(int iters, int rows, int cols, int k, int num_threads,
// IK: k
// TH: number of threads
#define BM_TopKGPU(IR, IC, IK, TH, LABEL) \
static void BM_TopK_GPU_##IR##_##IC##_##IK##_##TH(int iters) { \
BM_TopK(iters, IR, IC, IK, TH, true, LABEL); \
static void BM_TopK_GPU_##IR##_##IC##_##IK##_##TH( \
::testing::benchmark::State& state) { \
BM_TopK(state, IR, IC, IK, TH, true, LABEL); \
} \
BENCHMARK(BM_TopK_GPU_##IR##_##IC##_##IK##_##TH)
BENCHMARK(BM_TopK_GPU_##IR##_##IC##_##IK##_##TH)->UseRealTime()
#define BM_TopKCPU(IR, IC, IK, TH, LABEL) \
static void BM_TopK_CPU_##IR##_##IC##_##IK##_##TH(int iters) { \
BM_TopK(iters, IR, IC, IK, TH, false, LABEL); \
static void BM_TopK_CPU_##IR##_##IC##_##IK##_##TH( \
::testing::benchmark::State& state) { \
BM_TopK(state, IR, IC, IK, TH, false, LABEL); \
} \
BENCHMARK(BM_TopK_CPU_##IR##_##IC##_##IK##_##TH)
BENCHMARK(BM_TopK_CPU_##IR##_##IC##_##IK##_##TH)->UseRealTime()
// clang-format on

View File

@ -56,9 +56,13 @@ static Graph* OneHot(int batch_size, int num_classes, int axis) {
}
#define BM_OneHot(BATCH, CLASS, AXIS, DEVICE) \
static void BM_OneHot##_##BATCH##_##CLASS##_##AXIS##_##DEVICE(int iters) { \
testing::ItemsProcessed(static_cast<int64>(iters) * BATCH * CLASS); \
test::Benchmark(#DEVICE, OneHot(BATCH, CLASS, AXIS)).Run(iters); \
static void BM_OneHot##_##BATCH##_##CLASS##_##AXIS##_##DEVICE( \
::testing::benchmark::State& state) { \
test::Benchmark(#DEVICE, OneHot(BATCH, CLASS, AXIS), \
/*old_benchmark_api*/ false) \
.Run(state); \
state.SetItemsProcessed(static_cast<int64>(state.iterations()) * BATCH * \
CLASS); \
} \
BENCHMARK(BM_OneHot##_##BATCH##_##CLASS##_##AXIS##_##DEVICE);

View File

@ -108,23 +108,32 @@ static Graph* PTruncatedNormalOneTail(int num_batches, int samples_per_batch) {
}
#define BM_PTruncatedNormalDev(DEVICE, B, S) \
static void BM_PTruncatedNormal_##DEVICE##_##B##_##S(int iters) { \
test::Benchmark(#DEVICE, PTruncatedNormal(B, S)).Run(iters); \
testing::ItemsProcessed(static_cast<int64>(B) * S * iters); \
static void BM_PTruncatedNormal_##DEVICE##_##B##_##S( \
::testing::benchmark::State& state) { \
test::Benchmark(#DEVICE, PTruncatedNormal(B, S), \
/*old_benchmark_api*/ false) \
.Run(state); \
state.SetItemsProcessed(static_cast<int64>(B) * S * state.iterations()); \
} \
BENCHMARK(BM_PTruncatedNormal_##DEVICE##_##B##_##S);
#define BM_PTruncatedNormalDev_2SD(DEVICE, B, S) \
static void BM_PTruncatedNormal_2SD_##DEVICE##_##B##_##S(int iters) { \
test::Benchmark(#DEVICE, PTruncatedNormal2SD(B, S)).Run(iters); \
testing::ItemsProcessed(static_cast<int64>(B) * S * iters); \
static void BM_PTruncatedNormal_2SD_##DEVICE##_##B##_##S( \
::testing::benchmark::State& state) { \
test::Benchmark(#DEVICE, PTruncatedNormal2SD(B, S), \
/*old_benchmark_api*/ false) \
.Run(state); \
state.SetItemsProcessed(static_cast<int64>(B) * S * state.iterations()); \
} \
BENCHMARK(BM_PTruncatedNormal_2SD_##DEVICE##_##B##_##S);
#define BM_PTruncatedNormalDev_OneTail(DEVICE, B, S) \
static void BM_PTruncatedNormal_OneTail_##DEVICE##_##B##_##S(int iters) { \
test::Benchmark(#DEVICE, PTruncatedNormalOneTail(B, S)).Run(iters); \
testing::ItemsProcessed(static_cast<int64>(B) * S * iters); \
static void BM_PTruncatedNormal_OneTail_##DEVICE##_##B##_##S( \
::testing::benchmark::State& state) { \
test::Benchmark(#DEVICE, PTruncatedNormalOneTail(B, S), \
/*old_benchmark_api*/ false) \
.Run(state); \
state.SetItemsProcessed(static_cast<int64>(B) * S * state.iterations()); \
} \
BENCHMARK(BM_PTruncatedNormal_OneTail_##DEVICE##_##B##_##S);

View File

@ -760,13 +760,14 @@ TEST_F(QuantizeAndDequantizeTest, Invalid_range_given_V3) {
}
#define BM_SIMPLE_QUAN_DEQUAN(DEVICE) \
static void BM_SIMPLE_QUAN_DEQUAN_##DEVICE(int iters) { \
static void BM_SIMPLE_QUAN_DEQUAN_##DEVICE( \
::testing::benchmark::State& state) { \
auto root = Scope::NewRootScope().ExitOnError(); \
ops::QuantizeAndDequantizeV2(root, -3.5, -3.5, -3.5); \
TF_CHECK_OK(root.status()); \
Graph* g = new Graph(OpRegistry::Global()); \
TF_CHECK_OK(root.ToGraph(g)); \
test::Benchmark(#DEVICE, g).Run(iters); \
test::Benchmark(#DEVICE, g, /*old_benchmark_api*/ false).Run(state); \
} \
BENCHMARK(BM_SIMPLE_QUAN_DEQUAN_##DEVICE);

View File

@ -248,9 +248,8 @@ void QuantizedConcatTest::TestSecondDim8Bit(float first_min, float first_max,
// If <same_limits> is true, then both concatenated dimensions have the same
// quantized range; otherwise, they are set to different values.
template <typename T>
static void ConcatHelper(int iters, int concat_dimension, bool same_limits,
int dim2) {
testing::StopTiming();
static void ConcatHelper(::testing::benchmark::State& state,
int concat_dimension, bool same_limits, int dim2) {
Graph* g = new Graph(OpRegistry::Global());
DataType dt = DataTypeToEnum<T>::v();
@ -278,61 +277,111 @@ static void ConcatHelper(int iters, int concat_dimension, bool same_limits,
.Attr("T", dt)
.Finalize(g, &node));
testing::BytesProcessed(static_cast<int64>(iters) *
test::Benchmark("cpu", g, /*old_benchmark_api*/ false).Run(state);
state.SetBytesProcessed(static_cast<int64>(state.iterations()) *
((kDim1 * dim2) + (kDim1 * dim2)) * sizeof(T));
testing::StartTiming();
test::Benchmark("cpu", g).Run(iters);
testing::UseRealTime();
}
static void BM_QConcatDim0SameLimitQInt32(int iters, int dim2) {
ConcatHelper<qint32>(iters, 0 /* concat_dimension */, true /* same_limits */,
static void BM_QConcatDim0SameLimitQInt32(::testing::benchmark::State& state) {
const int dim2 = state.range(0);
ConcatHelper<qint32>(state, 0 /* concat_dimension */, true /* same_limits */,
dim2);
}
static void BM_QConcatDim1SameLimitQInt32(int iters, int dim2) {
ConcatHelper<qint32>(iters, 1 /* concat_dimension */, true /* same_limits */,
static void BM_QConcatDim1SameLimitQInt32(::testing::benchmark::State& state) {
const int dim2 = state.range(0);
ConcatHelper<qint32>(state, 1 /* concat_dimension */, true /* same_limits */,
dim2);
}
static void BM_QConcatDim0DifferLimitQInt32(int iters, int dim2) {
ConcatHelper<qint32>(iters, 0 /* concat_dimension */, false /* same_limits */,
static void BM_QConcatDim0DifferLimitQInt32(
::testing::benchmark::State& state) {
const int dim2 = state.range(0);
ConcatHelper<qint32>(state, 0 /* concat_dimension */, false /* same_limits */,
dim2);
}
static void BM_QConcatDim1DifferLimitQInt32(int iters, int dim2) {
ConcatHelper<qint32>(iters, 1 /* concat_dimension */, false /* same_limits */,
static void BM_QConcatDim1DifferLimitQInt32(
::testing::benchmark::State& state) {
const int dim2 = state.range(0);
ConcatHelper<qint32>(state, 1 /* concat_dimension */, false /* same_limits */,
dim2);
}
BENCHMARK(BM_QConcatDim0SameLimitQInt32)->Arg(1000)->Arg(20000)->Arg(100000);
BENCHMARK(BM_QConcatDim1SameLimitQInt32)->Arg(1000)->Arg(20000)->Arg(100000);
BENCHMARK(BM_QConcatDim0DifferLimitQInt32)->Arg(1000)->Arg(20000)->Arg(100000);
BENCHMARK(BM_QConcatDim1DifferLimitQInt32)->Arg(1000)->Arg(20000)->Arg(100000);
BENCHMARK(BM_QConcatDim0SameLimitQInt32)
->UseRealTime()
->Arg(1000)
->Arg(20000)
->Arg(100000);
BENCHMARK(BM_QConcatDim1SameLimitQInt32)
->UseRealTime()
->Arg(1000)
->Arg(20000)
->Arg(100000);
BENCHMARK(BM_QConcatDim0DifferLimitQInt32)
->UseRealTime()
->Arg(1000)
->Arg(20000)
->Arg(100000);
BENCHMARK(BM_QConcatDim1DifferLimitQInt32)
->UseRealTime()
->Arg(1000)
->Arg(20000)
->Arg(100000);
static void BM_QConcatDim0SameLimitQUint8(int iters, int dim2) {
ConcatHelper<qint32>(iters, 0 /* concat_dimension */, true /* same_limits */,
static void BM_QConcatDim0SameLimitQUint8(::testing::benchmark::State& state) {
const int dim2 = state.range(0);
ConcatHelper<qint32>(state, 0 /* concat_dimension */, true /* same_limits */,
dim2);
}
static void BM_QConcatDim1SameLimitQUint8(int iters, int dim2) {
ConcatHelper<qint32>(iters, 1 /* concat_dimension */, true /* same_limits */,
static void BM_QConcatDim1SameLimitQUint8(::testing::benchmark::State& state) {
const int dim2 = state.range(0);
ConcatHelper<qint32>(state, 1 /* concat_dimension */, true /* same_limits */,
dim2);
}
static void BM_QConcatDim0DifferLimitQUint8(int iters, int dim2) {
ConcatHelper<qint32>(iters, 0 /* concat_dimension */, false /* same_limits */,
static void BM_QConcatDim0DifferLimitQUint8(
::testing::benchmark::State& state) {
const int dim2 = state.range(0);
ConcatHelper<qint32>(state, 0 /* concat_dimension */, false /* same_limits */,
dim2);
}
static void BM_QConcatDim1DifferLimitQUint8(int iters, int dim2) {
ConcatHelper<qint32>(iters, 1 /* concat_dimension */, false /* same_limits */,
static void BM_QConcatDim1DifferLimitQUint8(
::testing::benchmark::State& state) {
const int dim2 = state.range(0);
ConcatHelper<qint32>(state, 1 /* concat_dimension */, false /* same_limits */,
dim2);
}
BENCHMARK(BM_QConcatDim0SameLimitQUint8)->Arg(1000)->Arg(20000)->Arg(100000);
BENCHMARK(BM_QConcatDim1SameLimitQUint8)->Arg(1000)->Arg(20000)->Arg(100000);
BENCHMARK(BM_QConcatDim0DifferLimitQUint8)->Arg(1000)->Arg(20000)->Arg(100000);
BENCHMARK(BM_QConcatDim1DifferLimitQUint8)->Arg(1000)->Arg(20000)->Arg(100000);
BENCHMARK(BM_QConcatDim0SameLimitQUint8)
->UseRealTime()
->Arg(1000)
->Arg(20000)
->Arg(100000);
BENCHMARK(BM_QConcatDim1SameLimitQUint8)
->UseRealTime()
->Arg(1000)
->Arg(20000)
->Arg(100000);
BENCHMARK(BM_QConcatDim0DifferLimitQUint8)
->UseRealTime()
->Arg(1000)
->Arg(20000)
->Arg(100000);
BENCHMARK(BM_QConcatDim1DifferLimitQUint8)
->UseRealTime()
->Arg(1000)
->Arg(20000)
->Arg(100000);
} // namespace tensorflow
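
The pattern in the quantized-concat file above replaces the old two-argument signature (int iters, int dim2) with a single State& parameter: the extra dimension now arrives through state.range(0), supplied by the ->Arg(...) values on the BENCHMARK chain. A standalone illustration against stock Google Benchmark (hypothetical BM_FillBuffer, not part of this commit):

// Hypothetical example of the Arg()/range(0) parameterization pattern.
#include <algorithm>
#include <cstdint>
#include <vector>

#include "benchmark/benchmark.h"

static void BM_FillBuffer(benchmark::State& state) {
  const int dim2 = state.range(0);  // was the `int dim2` parameter
  std::vector<float> buffer(dim2);
  for (auto s : state) {
    std::fill(buffer.begin(), buffer.end(), 1.0f);
    benchmark::DoNotOptimize(buffer.data());
  }
  state.SetBytesProcessed(static_cast<int64_t>(state.iterations()) * dim2 *
                          sizeof(float));
}
// Link against benchmark_main (or add BENCHMARK_MAIN()) to run.
BENCHMARK(BM_FillBuffer)->UseRealTime()->Arg(1000)->Arg(20000)->Arg(100000);
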

View File

@ -68,30 +68,42 @@ static Graph* RandomBinomialRejComplement(int num_batches,
}
#define BM_RandomBinomialInv(DEVICE, B, S) \
static void BM_RandomBinomialInv_##DEVICE##_##B##_##S(int iters) { \
test::Benchmark(#DEVICE, RandomBinomialInv(B, S)).Run(iters); \
testing::ItemsProcessed(static_cast<int64>(B) * S * iters); \
static void BM_RandomBinomialInv_##DEVICE##_##B##_##S( \
::testing::benchmark::State& state) { \
test::Benchmark(#DEVICE, RandomBinomialInv(B, S), \
/*old_benchmark_api*/ false) \
.Run(state); \
state.SetItemsProcessed(static_cast<int64>(B) * S * state.iterations()); \
} \
BENCHMARK(BM_RandomBinomialInv_##DEVICE##_##B##_##S);
#define BM_RandomBinomialRej(DEVICE, B, S) \
static void BM_RandomBinomialRej_##DEVICE##_##B##_##S(int iters) { \
test::Benchmark(#DEVICE, RandomBinomialRej(B, S)).Run(iters); \
testing::ItemsProcessed(static_cast<int64>(B) * S * iters); \
static void BM_RandomBinomialRej_##DEVICE##_##B##_##S( \
::testing::benchmark::State& state) { \
test::Benchmark(#DEVICE, RandomBinomialRej(B, S), \
/*old_benchmark_api*/ false) \
.Run(state); \
state.SetItemsProcessed(static_cast<int64>(B) * S * state.iterations()); \
} \
BENCHMARK(BM_RandomBinomialRej_##DEVICE##_##B##_##S);
#define BM_RandomBinomialInvComplement(DEVICE, B, S) \
static void BM_RandomBinomialInvComplement_##DEVICE##_##B##_##S(int iters) { \
test::Benchmark(#DEVICE, RandomBinomialInvComplement(B, S)).Run(iters); \
testing::ItemsProcessed(static_cast<int64>(B) * S * iters); \
static void BM_RandomBinomialInvComplement_##DEVICE##_##B##_##S( \
::testing::benchmark::State& state) { \
test::Benchmark(#DEVICE, RandomBinomialInvComplement(B, S), \
/*old_benchmark_api*/ false) \
.Run(state); \
state.SetItemsProcessed(static_cast<int64>(B) * S * state.iterations()); \
} \
BENCHMARK(BM_RandomBinomialInvComplement_##DEVICE##_##B##_##S);
#define BM_RandomBinomialRejComplement(DEVICE, B, S) \
static void BM_RandomBinomialRejComplement_##DEVICE##_##B##_##S(int iters) { \
test::Benchmark(#DEVICE, RandomBinomialRejComplement(B, S)).Run(iters); \
testing::ItemsProcessed(static_cast<int64>(B) * S * iters); \
static void BM_RandomBinomialRejComplement_##DEVICE##_##B##_##S( \
::testing::benchmark::State& state) { \
test::Benchmark(#DEVICE, RandomBinomialRejComplement(B, S), \
/*old_benchmark_api*/ false) \
.Run(state); \
state.SetItemsProcessed(static_cast<int64>(B) * S * state.iterations()); \
} \
BENCHMARK(BM_RandomBinomialRejComplement_##DEVICE##_##B##_##S);

View File

@ -59,9 +59,12 @@ Graph* TruncatedNormal(int64 n) {
}
#define BM_RNG(DEVICE, RNG) \
void BM_##DEVICE##_##RNG(int iters, int arg) { \
testing::ItemsProcessed(static_cast<int64>(iters) * arg); \
test::Benchmark(#DEVICE, RNG(arg)).Run(iters); \
void BM_##DEVICE##_##RNG(::testing::benchmark::State& state) { \
const int arg = state.range(0); \
\
test::Benchmark(#DEVICE, RNG(arg), /*old_benchmark_api*/ false) \
.Run(state); \
state.SetItemsProcessed(static_cast<int64>(state.iterations()) * arg); \
} \
BENCHMARK(BM_##DEVICE##_##RNG)->Range(1 << 20, 8 << 20);
@ -84,60 +87,48 @@ Tensor VecAlphas(int64 n) {
return alphas;
}
void BM_cpu_RandomGamma(int iters, int nsamp, int nalpha) {
testing::ItemsProcessed(static_cast<int64>(iters) * nsamp * nalpha);
void BM_cpu_RandomGamma(::testing::benchmark::State& state) {
const int nsamp = state.range(0);
const int nalpha = state.range(1);
Graph* g = new Graph(OpRegistry::Global());
test::graph::RandomGamma(g, test::graph::Constant(g, VecShape(nsamp)),
test::graph::Constant(g, VecAlphas(nalpha)));
test::Benchmark("cpu", g).Run(iters);
test::Benchmark("cpu", g, /*old_benchmark_api*/ false).Run(state);
state.SetItemsProcessed(static_cast<int64>(state.iterations()) * nsamp *
nalpha);
}
BENCHMARK(BM_cpu_RandomGamma)->RangePair(1 << 14, 4 << 15, 2, 50);
void BM_PhiloxRandom(int iters) {
void BM_PhiloxRandom(::testing::benchmark::State& state) {
// Fill 2M random numbers
int count = 2 << 20;
testing::ItemsProcessed(static_cast<int64>(iters) * count);
random::PhiloxRandom gen(0x12345);
int val = 1;
for (int i = 0; i < iters; ++i) {
for (auto s : state) {
for (int j = 0; j < count; j += 4) {
/// each invocation of gen() returns 128-bit samples
auto samples = gen();
// use the result trivially so the compiler does not optimize it away
val ^= samples[0] ^ samples[1] ^ samples[2] ^ samples[3];
tensorflow::testing::DoNotOptimize(samples);
}
}
// A anchor point to make sure the compiler does not cut corners
CHECK(val) << val;
state.SetItemsProcessed(static_cast<int64>(state.iterations()) * count);
}
BENCHMARK(BM_PhiloxRandom);
void BM_StdMTRandom(int iters) {
void BM_StdMTRandom(::testing::benchmark::State& state) {
// Fill 2M random numbers
int count = 2 << 20;
testing::ItemsProcessed(static_cast<int64>(iters) * count);
std::mt19937 gen(0x12345);
uint_fast32_t val = 1;
for (int i = 0; i < iters; ++i) {
for (auto s : state) {
for (int j = 0; j < count; ++j) {
/// each invocation of gen() returns 32-bit sample
uint_fast32_t sample = gen();
// use the result trivially so the compiler does not optimize it away
val ^= sample;
tensorflow::testing::DoNotOptimize(sample);
}
}
// A anchor point to make sure the compiler does not cut corners
CHECK(val) << val;
state.SetItemsProcessed(static_cast<int64>(state.iterations()) * count);
}
BENCHMARK(BM_StdMTRandom);
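
In the two generator benchmarks above, the old defence against dead-code elimination (XOR-ing every sample into val and CHECK-ing it afterwards) is replaced by tensorflow::testing::DoNotOptimize, which marks the value as observed without adding work to the timed loop. A standalone sketch of the stock Google Benchmark equivalent (hypothetical BM_Mt19937, not part of this commit):

// Hypothetical sketch: DoNotOptimize plays the role the manual
// `val ^= sample; CHECK(val)` pattern used to play.
#include <cstdint>
#include <random>

#include "benchmark/benchmark.h"

static void BM_Mt19937(benchmark::State& state) {
  std::mt19937 gen(0x12345);
  for (auto s : state) {
    std::uint_fast32_t sample = gen();
    benchmark::DoNotOptimize(sample);  // keeps the call from being elided
  }
  state.SetItemsProcessed(state.iterations());
}
// Link against benchmark_main (or add BENCHMARK_MAIN()) to run.
BENCHMARK(BM_Mt19937);
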

View File

@ -84,108 +84,167 @@ static Graph* ThreeDXZReduce(const string& reduce, int num_y, int num_z) {
// Creates a bench which reduces a 3D tensor with total "num" floats
// into a scalar on a "device". Runs the bench for "iters" times.
template <typename T>
static void ReduceToScalar(int iters, const string& device,
const string& reduce, int num_x, int num_y) {
testing::ItemsProcessed(static_cast<int64>(iters) * num_x * num_y);
testing::BytesProcessed(static_cast<int64>(iters) * num_x * num_y *
sizeof(T));
test::Benchmark(device, ToScalar<T>(reduce, num_x, num_y)).Run(iters);
}
static void DoRowReduce(int iters, const string& device, const string& reduce,
static void ReduceToScalar(::testing::benchmark::State& state,
const string& device, const string& reduce,
int num_x, int num_y) {
testing::ItemsProcessed(static_cast<int64>(iters) * num_x * num_y);
testing::BytesProcessed(static_cast<int64>(iters) * num_x * num_y *
sizeof(float));
test::Benchmark(device, RowReduce(reduce, num_x, num_y)).Run(iters);
test::Benchmark(device, ToScalar<T>(reduce, num_x, num_y),
/*old_benchmark_api*/ false)
.Run(state);
state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_x *
num_y);
state.SetBytesProcessed(static_cast<int64>(state.iterations()) * num_x *
num_y * sizeof(T));
}
static void DoColReduce(int iters, const string& device, const string& reduce,
int num_x, int num_y) {
testing::ItemsProcessed(static_cast<int64>(iters) * num_x * num_y);
testing::BytesProcessed(static_cast<int64>(iters) * num_x * num_y *
sizeof(float));
test::Benchmark(device, ColReduce(reduce, num_x, num_y)).Run(iters);
static void DoRowReduce(::testing::benchmark::State& state,
const string& device, const string& reduce, int num_x,
int num_y) {
test::Benchmark(device, RowReduce(reduce, num_x, num_y),
/*old_benchmark_api*/ false)
.Run(state);
state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_x *
num_y);
state.SetBytesProcessed(static_cast<int64>(state.iterations()) * num_x *
num_y * sizeof(float));
}
static void Do3DYReduce(int iters, const string& device, const string& reduce,
int num_x, int num_y) {
testing::ItemsProcessed(static_cast<int64>(iters) * num_x * num_y);
testing::BytesProcessed(static_cast<int64>(iters) * num_x * num_y *
sizeof(float));
test::Benchmark(device, ThreeDYReduce(reduce, num_x, num_y)).Run(iters);
static void DoColReduce(::testing::benchmark::State& state,
const string& device, const string& reduce, int num_x,
int num_y) {
test::Benchmark(device, ColReduce(reduce, num_x, num_y),
/*old_benchmark_api*/ false)
.Run(state);
state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_x *
num_y);
state.SetBytesProcessed(static_cast<int64>(state.iterations()) * num_x *
num_y * sizeof(float));
}
static void Do3DXZReduce(int iters, const string& device, const string& reduce,
int num_x, int num_y) {
testing::ItemsProcessed(static_cast<int64>(iters) * num_x * num_y);
testing::BytesProcessed(static_cast<int64>(iters) * num_x * num_y *
sizeof(float));
test::Benchmark(device, ThreeDXZReduce(reduce, num_x, num_y)).Run(iters);
static void Do3DYReduce(::testing::benchmark::State& state,
const string& device, const string& reduce, int num_x,
int num_y) {
test::Benchmark(device, ThreeDYReduce(reduce, num_x, num_y),
/*old_benchmark_api*/ false)
.Run(state);
state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_x *
num_y);
state.SetBytesProcessed(static_cast<int64>(state.iterations()) * num_x *
num_y * sizeof(float));
}
static void BM_Sum2DToScalarGPU(int iters, int num_x, int num_y) {
ReduceToScalar<float>(iters, "gpu", "Sum", num_x, num_y);
static void Do3DXZReduce(::testing::benchmark::State& state,
const string& device, const string& reduce, int num_x,
int num_y) {
test::Benchmark(device, ThreeDXZReduce(reduce, num_x, num_y),
/*old_benchmark_api*/ false)
.Run(state);
state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_x *
num_y);
state.SetBytesProcessed(static_cast<int64>(state.iterations()) * num_x *
num_y * sizeof(float));
}
static void BM_Sum2DToScalarGPU(::testing::benchmark::State& state) {
const int num_x = state.range(0);
const int num_y = state.range(1);
ReduceToScalar<float>(state, "gpu", "Sum", num_x, num_y);
}
BENCHMARK(BM_Sum2DToScalarGPU)->RangePair(1, 8192, 1, 8192);
static void BM_Sum2DToScalarGPUComplex(int iters, int num_x, int num_y) {
ReduceToScalar<std::complex<float>>(iters, "gpu", "Sum", num_x, num_y);
static void BM_Sum2DToScalarGPUComplex(::testing::benchmark::State& state) {
const int num_x = state.range(0);
const int num_y = state.range(1);
ReduceToScalar<std::complex<float>>(state, "gpu", "Sum", num_x, num_y);
}
BENCHMARK(BM_Sum2DToScalarGPUComplex)->RangePair(1, 8192, 1, 8192);
static void BM_Sum2DToScalarGPUHalf(int iters, int num_x, int num_y) {
ReduceToScalar<Eigen::half>(iters, "gpu", "Sum", num_x, num_y);
static void BM_Sum2DToScalarGPUHalf(::testing::benchmark::State& state) {
const int num_x = state.range(0);
const int num_y = state.range(1);
ReduceToScalar<Eigen::half>(state, "gpu", "Sum", num_x, num_y);
}
BENCHMARK(BM_Sum2DToScalarGPUHalf)->RangePair(1, 8192, 1, 8192);
static void BM_Sum2DRowReduceGPU(int iters, int num_x, int num_y) {
DoRowReduce(iters, "gpu", "Sum", num_x, num_y);
static void BM_Sum2DRowReduceGPU(::testing::benchmark::State& state) {
const int num_x = state.range(0);
const int num_y = state.range(1);
DoRowReduce(state, "gpu", "Sum", num_x, num_y);
}
BENCHMARK(BM_Sum2DRowReduceGPU)->RangePair(1, 8192, 1, 8192);
static void BM_Sum2DColumnReduceGPU(int iters, int num_x, int num_y) {
DoColReduce(iters, "gpu", "Sum", num_x, num_y);
static void BM_Sum2DColumnReduceGPU(::testing::benchmark::State& state) {
const int num_x = state.range(0);
const int num_y = state.range(1);
DoColReduce(state, "gpu", "Sum", num_x, num_y);
}
BENCHMARK(BM_Sum2DColumnReduceGPU)->RangePair(1, 8192, 1, 8192);
static void BM_Sum3DYReduceGPU(int iters, int num_x, int num_y) {
Do3DYReduce(iters, "gpu", "Sum", num_x, num_y);
static void BM_Sum3DYReduceGPU(::testing::benchmark::State& state) {
const int num_x = state.range(0);
const int num_y = state.range(1);
Do3DYReduce(state, "gpu", "Sum", num_x, num_y);
}
BENCHMARK(BM_Sum3DYReduceGPU)->RangePair(64, 4096, 64, 4096);
static void BM_Sum3DXZReduceGPU(int iters, int num_x, int num_y) {
Do3DXZReduce(iters, "gpu", "Sum", num_x, num_y);
static void BM_Sum3DXZReduceGPU(::testing::benchmark::State& state) {
const int num_x = state.range(0);
const int num_y = state.range(1);
Do3DXZReduce(state, "gpu", "Sum", num_x, num_y);
}
BENCHMARK(BM_Sum3DXZReduceGPU)->RangePair(64, 4096, 64, 4096);
static void BM_Mean2DToScalarGPU(int iters, int num_x, int num_y) {
ReduceToScalar<float>(iters, "gpu", "Mean", num_x, num_y);
static void BM_Mean2DToScalarGPU(::testing::benchmark::State& state) {
const int num_x = state.range(0);
const int num_y = state.range(1);
ReduceToScalar<float>(state, "gpu", "Mean", num_x, num_y);
}
BENCHMARK(BM_Mean2DToScalarGPU)->RangePair(2048, 8192, 2048, 8192);
static void BM_EuclideanNorm2DToScalarGPU(int iters, int num_x, int num_y) {
ReduceToScalar<float>(iters, "gpu", "EuclideanNorm", num_x, num_y);
static void BM_EuclideanNorm2DToScalarGPU(::testing::benchmark::State& state) {
const int num_x = state.range(0);
const int num_y = state.range(1);
ReduceToScalar<float>(state, "gpu", "EuclideanNorm", num_x, num_y);
}
BENCHMARK(BM_EuclideanNorm2DToScalarGPU)->RangePair(2048, 8192, 2048, 8192);
static void BM_Max2DToScalarGPU(int iters, int num_x, int num_y) {
ReduceToScalar<float>(iters, "gpu", "Max", num_x, num_y);
static void BM_Max2DToScalarGPU(::testing::benchmark::State& state) {
const int num_x = state.range(0);
const int num_y = state.range(1);
ReduceToScalar<float>(state, "gpu", "Max", num_x, num_y);
}
BENCHMARK(BM_Max2DToScalarGPU)->RangePair(2048, 8192, 2048, 8192);
static void BM_Min2DToScalarGPU(int iters, int num_x, int num_y) {
ReduceToScalar<float>(iters, "gpu", "Min", num_x, num_y);
static void BM_Min2DToScalarGPU(::testing::benchmark::State& state) {
const int num_x = state.range(0);
const int num_y = state.range(1);
ReduceToScalar<float>(state, "gpu", "Min", num_x, num_y);
}
BENCHMARK(BM_Min2DToScalarGPU)->RangePair(2048, 8192, 2048, 8192);
static void BM_Min2DToScalarGPUHalf(int iters, int num_x, int num_y) {
ReduceToScalar<Eigen::half>(iters, "gpu", "Min", num_x, num_y);
static void BM_Min2DToScalarGPUHalf(::testing::benchmark::State& state) {
const int num_x = state.range(0);
const int num_y = state.range(1);
ReduceToScalar<Eigen::half>(state, "gpu", "Min", num_x, num_y);
}
BENCHMARK(BM_Min2DToScalarGPUHalf)->RangePair(2048, 8192, 2048, 8192);
static void BM_Bool2DToScalarGPU(int iters, int num_x, int num_y) {
ReduceToScalar<bool>(iters, "gpu", "All", num_x, num_y);
static void BM_Bool2DToScalarGPU(::testing::benchmark::State& state) {
const int num_x = state.range(0);
const int num_y = state.range(1);
ReduceToScalar<bool>(state, "gpu", "All", num_x, num_y);
}
BENCHMARK(BM_Bool2DToScalarGPU)->RangePair(2048, 8192, 2048, 8192);
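
The reduction benchmarks above sweep two dimensions, so after the migration they read state.range(0) and state.range(1) while keeping the ->RangePair(...) registrations shown in these hunks. With stock Google Benchmark the same two-dimensional sweep is usually written with Ranges; a hypothetical standalone sketch (BM_SumMatrix, not part of this commit, plain benchmark::State rather than the TF alias):

// Hypothetical two-parameter sweep; Ranges() stands in for RangePair().
#include <cstdint>
#include <numeric>
#include <vector>

#include "benchmark/benchmark.h"

static void BM_SumMatrix(benchmark::State& state) {
  const int num_x = state.range(0);
  const int num_y = state.range(1);
  std::vector<float> mat(static_cast<int64_t>(num_x) * num_y, 1.0f);
  for (auto s : state) {
    float sum = std::accumulate(mat.begin(), mat.end(), 0.0f);
    benchmark::DoNotOptimize(sum);
  }
  state.SetItemsProcessed(static_cast<int64_t>(state.iterations()) * num_x *
                          num_y);
  state.SetBytesProcessed(static_cast<int64_t>(state.iterations()) * num_x *
                          num_y * sizeof(float));
}
// Link against benchmark_main (or add BENCHMARK_MAIN()) to run.
BENCHMARK(BM_SumMatrix)->Ranges({{64, 4096}, {64, 4096}});
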

View File

@ -84,17 +84,17 @@ Graph* SetupRegexReplaceGraph(const Tensor& input, const string& input_pattern,
return g;
}
void BM_RegexReplace(int iters, int batch_size) {
testing::StopTiming();
testing::ItemsProcessed(static_cast<int64>(iters));
testing::UseRealTime();
static void BM_RegexReplace(::testing::benchmark::State& state) {
const int batch_size = state.range(0);
Tensor input = GetTestTensor(batch_size);
Graph* g = SetupRegexReplaceGraph(input, kRegExPattern, kRewrite);
testing::StartTiming();
test::Benchmark("cpu", g).Run(iters);
test::Benchmark("cpu", g, /*old_benchmark_api*/ false).Run(state);
state.SetItemsProcessed(static_cast<int64>(state.iterations()));
}
BENCHMARK(BM_RegexReplace)
->UseRealTime()
->Arg(1)
->Arg(8)
->Arg(16)
@ -115,17 +115,17 @@ Graph* SetupStaticGraph(const Tensor& input, const string& input_pattern,
.Finalize(g, nullptr /* node */));
return g;
}
void BM_StaticRegexReplace(int iters, int batch_size) {
testing::StopTiming();
testing::ItemsProcessed(static_cast<int64>(iters));
testing::UseRealTime();
static void BM_StaticRegexReplace(::testing::benchmark::State& state) {
const int batch_size = state.range(0);
Tensor input = GetTestTensor(batch_size);
Graph* g = SetupStaticGraph(input, kRegExPattern, kRewrite);
testing::StartTiming();
test::Benchmark("cpu", g).Run(iters);
test::Benchmark("cpu", g, /*old_benchmark_api*/ false).Run(state);
state.SetItemsProcessed(static_cast<int64>(state.iterations()));
}
BENCHMARK(BM_StaticRegexReplace)
->UseRealTime()
->Arg(1)
->Arg(8)
->Arg(16)

View File

@ -67,56 +67,29 @@ TEST_F(RequantizationRangeTest, HandCrafted) {
test::ExpectTensorEqual<float>(expected_max, *GetOutput(1));
}
static void BM_RequantizationRange(int iters, int size) {
testing::StopTiming();
testing::UseRealTime();
testing::ItemsProcessed(static_cast<int64>(iters) * size);
testing::ItemsProcessed(static_cast<int64>(iters) * size * 4);
static void BM_RequantizationRange(::testing::benchmark::State& state) {
const int size = state.range(0);
Tensor quantized_tensor(DT_QINT32, TensorShape({1, size}));
test::FillFn<qint32>(&quantized_tensor, [](int n) { return qint32(n); });
qint32 actual_min;
qint32 actual_max;
testing::StartTiming();
for (int iter = 0; iter < iters; ++iter) {
for (auto s : state) {
CalculateUsedRange(quantized_tensor, &actual_min, &actual_max);
}
state.SetItemsProcessed(static_cast<int64>(state.iterations()) * size);
state.SetBytesProcessed(static_cast<int64>(state.iterations()) * size * 4);
}
static void BM_RequantizationRange100(int iters) {
BM_RequantizationRange(100, iters);
}
BENCHMARK(BM_RequantizationRange100);
static void BM_RequantizationRange1000(int iters) {
BM_RequantizationRange(1000, iters);
}
BENCHMARK(BM_RequantizationRange1000);
static void BM_RequantizationRange10000(int iters) {
BM_RequantizationRange(10000, iters);
}
BENCHMARK(BM_RequantizationRange10000);
static void BM_RequantizationRange100000(int iters) {
BM_RequantizationRange(100000, iters);
}
BENCHMARK(BM_RequantizationRange100000);
static void BM_RequantizationRange1000000(int iters) {
BM_RequantizationRange(1000000, iters);
}
BENCHMARK(BM_RequantizationRange1000000);
static void BM_RequantizationRange10000000(int iters) {
BM_RequantizationRange(10000000, iters);
}
BENCHMARK(BM_RequantizationRange10000000);
static void BM_RequantizationRange100000000(int iters) {
BM_RequantizationRange(100000000, iters);
}
BENCHMARK(BM_RequantizationRange100000000);
BENCHMARK(BM_RequantizationRange)
->UseRealTime()
->Arg(100)
->Arg(1000)
->Arg(10000)
->Arg(100000)
->Arg(1000000)
->Arg(10000000)
->Arg(100000000);
} // end namespace tensorflow

View File

@ -197,148 +197,187 @@ static Graph* Reverse(const TensorShape& shape, int reverse_axis) {
}
template <typename T>
static void RunReverseRowsBenchmark(int iters, int outer_dim, int middle_dim,
static void RunReverseRowsBenchmark(::testing::benchmark::State& state,
int outer_dim, int middle_dim,
int intra_threads, int channels) {
SessionOptions opts = GetOptions(intra_threads);
TensorShape shape{outer_dim, middle_dim, channels};
const int64 num_items = static_cast<int64>(iters) * shape.num_elements();
testing::ItemsProcessed(num_items);
testing::BytesProcessed(num_items * sizeof(T));
testing::UseRealTime();
test::Benchmark("cpu", Reverse<T>(shape, 1), &opts).Run(iters);
test::Benchmark("cpu", Reverse<T>(shape, 1), &opts, nullptr, nullptr, "",
/*old_benchmark_api*/ false)
.Run(state);
const int64 num_items =
static_cast<int64>(state.iterations()) * shape.num_elements();
state.SetItemsProcessed(num_items);
state.SetBytesProcessed(num_items * sizeof(T));
}
static void BM_ReverseRowsOf1Channel_1T_float(int iters, int outer_dim,
int middle_dim) {
RunReverseRowsBenchmark<float>(iters, outer_dim, middle_dim,
void BM_ReverseRowsOf1Channel_1T_float(::testing::benchmark::State& state) {
const int outer_dim = state.range(0);
const int middle_dim = state.range(1);
RunReverseRowsBenchmark<float>(state, outer_dim, middle_dim,
1 /* intra_threads */, 1 /* channels */);
}
BENCHMARK(BM_ReverseRowsOf1Channel_1T_float)
->UseRealTime()
->ArgPair(288, 288)
->ArgPair(1024, 1024)
->ArgPair(10 * 1024, 1024);
static void BM_ReverseRowsOf1Channel_1T_uint8(int iters, int outer_dim,
int middle_dim) {
RunReverseRowsBenchmark<uint8>(iters, outer_dim, middle_dim,
void BM_ReverseRowsOf1Channel_1T_uint8(::testing::benchmark::State& state) {
const int outer_dim = state.range(0);
const int middle_dim = state.range(1);
RunReverseRowsBenchmark<uint8>(state, outer_dim, middle_dim,
1 /* intra_threads */, 1 /* channels */);
}
BENCHMARK(BM_ReverseRowsOf1Channel_1T_uint8)
->UseRealTime()
->ArgPair(288, 288)
->ArgPair(1024, 1024)
->ArgPair(10 * 1024, 1024);
static void BM_ReverseRowsOf1Channel_4T_float(int iters, int outer_dim,
int middle_dim) {
RunReverseRowsBenchmark<float>(iters, outer_dim, middle_dim,
void BM_ReverseRowsOf1Channel_4T_float(::testing::benchmark::State& state) {
const int outer_dim = state.range(0);
const int middle_dim = state.range(1);
RunReverseRowsBenchmark<float>(state, outer_dim, middle_dim,
4 /* intra_threads */, 1 /* channels */);
}
BENCHMARK(BM_ReverseRowsOf1Channel_4T_float)
->UseRealTime()
->ArgPair(288, 288)
->ArgPair(1024, 1024)
->ArgPair(10 * 1024, 1024);
static void BM_ReverseRowsOf1Channel_4T_uint8(int iters, int outer_dim,
int middle_dim) {
RunReverseRowsBenchmark<uint8>(iters, outer_dim, middle_dim,
void BM_ReverseRowsOf1Channel_4T_uint8(::testing::benchmark::State& state) {
const int outer_dim = state.range(0);
const int middle_dim = state.range(1);
RunReverseRowsBenchmark<uint8>(state, outer_dim, middle_dim,
4 /* intra_threads */, 1 /* channels */);
}
BENCHMARK(BM_ReverseRowsOf1Channel_4T_uint8)
->UseRealTime()
->ArgPair(288, 288)
->ArgPair(1024, 1024)
->ArgPair(10 * 1024, 1024);
static void BM_ReverseRowsOf3Channels_1T_float(int iters, int outer_dim,
int middle_dim) {
RunReverseRowsBenchmark<float>(iters, outer_dim, middle_dim,
void BM_ReverseRowsOf3Channels_1T_float(::testing::benchmark::State& state) {
const int outer_dim = state.range(0);
const int middle_dim = state.range(1);
RunReverseRowsBenchmark<float>(state, outer_dim, middle_dim,
1 /* intra_threads */, 3 /* channels */);
}
BENCHMARK(BM_ReverseRowsOf3Channels_1T_float)
->UseRealTime()
->ArgPair(288, 288)
->ArgPair(30, 30)
->ArgPair(1024, 1024)
->ArgPair(10 * 1024, 1024);
static void BM_ReverseRowsOf3Channels_1T_uint8(int iters, int outer_dim,
int middle_dim) {
RunReverseRowsBenchmark<uint8>(iters, outer_dim, middle_dim,
void BM_ReverseRowsOf3Channels_1T_uint8(::testing::benchmark::State& state) {
const int outer_dim = state.range(0);
const int middle_dim = state.range(1);
RunReverseRowsBenchmark<uint8>(state, outer_dim, middle_dim,
1 /* intra_threads */, 3 /* channels */);
}
BENCHMARK(BM_ReverseRowsOf3Channels_1T_uint8)
->UseRealTime()
->ArgPair(288, 288)
->ArgPair(30, 30)
->ArgPair(1024, 1024)
->ArgPair(10 * 1024, 1024);
static void BM_ReverseRowsOf3Channels_4T_float(int iters, int outer_dim,
int middle_dim) {
RunReverseRowsBenchmark<float>(iters, outer_dim, middle_dim,
void BM_ReverseRowsOf3Channels_4T_float(::testing::benchmark::State& state) {
const int outer_dim = state.range(0);
const int middle_dim = state.range(1);
RunReverseRowsBenchmark<float>(state, outer_dim, middle_dim,
4 /* intra_threads */, 3 /* channels */);
}
BENCHMARK(BM_ReverseRowsOf3Channels_4T_float)
->UseRealTime()
->ArgPair(288, 288)
->ArgPair(30, 30)
->ArgPair(1024, 1024)
->ArgPair(10 * 1024, 1024);
static void BM_ReverseRowsOf3Channels_4T_uint8(int iters, int outer_dim,
int middle_dim) {
RunReverseRowsBenchmark<uint8>(iters, outer_dim, middle_dim,
void BM_ReverseRowsOf3Channels_4T_uint8(::testing::benchmark::State& state) {
const int outer_dim = state.range(0);
const int middle_dim = state.range(1);
RunReverseRowsBenchmark<uint8>(state, outer_dim, middle_dim,
4 /* intra_threads */, 3 /* channels */);
}
BENCHMARK(BM_ReverseRowsOf3Channels_4T_uint8)
->UseRealTime()
->ArgPair(288, 288)
->ArgPair(30, 30)
->ArgPair(1024, 1024)
->ArgPair(10 * 1024, 1024);
static void BM_ReverseRowsOf4Channels_1T_float(int iters, int outer_dim,
int middle_dim) {
RunReverseRowsBenchmark<float>(iters, outer_dim, middle_dim,
void BM_ReverseRowsOf4Channels_1T_float(::testing::benchmark::State& state) {
const int outer_dim = state.range(0);
const int middle_dim = state.range(1);
RunReverseRowsBenchmark<float>(state, outer_dim, middle_dim,
1 /* intra_threads */, 4 /* channels */);
}
BENCHMARK(BM_ReverseRowsOf4Channels_1T_float)
->UseRealTime()
->ArgPair(288, 288)
->ArgPair(1024, 1024)
->ArgPair(10 * 1024, 1024);
static void BM_ReverseRowsOf4Channels_1T_uint8(int iters, int outer_dim,
int middle_dim) {
RunReverseRowsBenchmark<uint8>(iters, outer_dim, middle_dim,
void BM_ReverseRowsOf4Channels_1T_uint8(::testing::benchmark::State& state) {
const int outer_dim = state.range(0);
const int middle_dim = state.range(1);
RunReverseRowsBenchmark<uint8>(state, outer_dim, middle_dim,
1 /* intra_threads */, 4 /* channels */);
}
BENCHMARK(BM_ReverseRowsOf4Channels_1T_uint8)
->UseRealTime()
->ArgPair(288, 288)
->ArgPair(1024, 1024)
->ArgPair(10 * 1024, 1024);
static void BM_ReverseRowsOf4Channels_4T_float(int iters, int outer_dim,
int middle_dim) {
RunReverseRowsBenchmark<float>(iters, outer_dim, middle_dim,
void BM_ReverseRowsOf4Channels_4T_float(::testing::benchmark::State& state) {
const int outer_dim = state.range(0);
const int middle_dim = state.range(1);
RunReverseRowsBenchmark<float>(state, outer_dim, middle_dim,
4 /* intra_threads */, 4 /* channels */);
}
BENCHMARK(BM_ReverseRowsOf4Channels_4T_float)
->UseRealTime()
->ArgPair(288, 288)
->ArgPair(1024, 1024)
->ArgPair(10 * 1024, 1024);
static void BM_ReverseRowsOf4Channels_4T_uint8(int iters, int outer_dim,
int middle_dim) {
RunReverseRowsBenchmark<uint8>(iters, outer_dim, middle_dim,
void BM_ReverseRowsOf4Channels_4T_uint8(::testing::benchmark::State& state) {
const int outer_dim = state.range(0);
const int middle_dim = state.range(1);
RunReverseRowsBenchmark<uint8>(state, outer_dim, middle_dim,
4 /* intra_threads */, 4 /* channels */);
}
BENCHMARK(BM_ReverseRowsOf4Channels_4T_uint8)
->UseRealTime()
->ArgPair(288, 288)
->ArgPair(1024, 1024)
->ArgPair(10 * 1024, 1024);

View File

@ -451,30 +451,40 @@ static Graph* RollGraph(const TensorShape& shape, int isd) {
}
#define BM_ROLL_OUTER(DEVICE) \
static void BM_##DEVICE##_roll_outer(int iters, int rows, int columns) { \
static void BM_##DEVICE##_roll_outer(::testing::benchmark::State& state) { \
const int rows = state.range(0); \
const int columns = state.range(1); \
\
TensorShape shape{rows, columns}; \
const int64 num_items = static_cast<int64>(iters) * shape.num_elements(); \
testing::ItemsProcessed(num_items); \
testing::BytesProcessed(num_items * sizeof(float)); \
testing::UseRealTime(); \
test::Benchmark(#DEVICE, RollGraph(shape, 0)).Run(iters); \
test::Benchmark(#DEVICE, RollGraph(shape, 0), /*old_benchmark_api*/ false) \
.Run(state); \
const int64 num_items = \
static_cast<int64>(state.iterations()) * shape.num_elements(); \
state.SetItemsProcessed(num_items); \
state.SetBytesProcessed(num_items * sizeof(float)); \
} \
BENCHMARK(BM_##DEVICE##_roll_outer) \
->UseRealTime() \
->ArgPair(256, 256) \
->ArgPair(512, 512) \
->ArgPair(1024, 1024) \
->ArgPair(2048, 2048)
#define BM_ROLL_ALL(DEVICE) \
static void BM_##DEVICE##_roll_all(int iters, int rows, int columns) { \
static void BM_##DEVICE##_roll_all(::testing::benchmark::State& state) { \
const int rows = state.range(0); \
const int columns = state.range(1); \
\
TensorShape shape{rows, columns}; \
const int64 num_items = static_cast<int64>(iters) * shape.num_elements(); \
testing::ItemsProcessed(num_items); \
testing::BytesProcessed(num_items * sizeof(float)); \
testing::UseRealTime(); \
test::Benchmark(#DEVICE, RollGraph(shape, 1)).Run(iters); \
test::Benchmark(#DEVICE, RollGraph(shape, 1), /*old_benchmark_api*/ false) \
.Run(state); \
const int64 num_items = \
static_cast<int64>(state.iterations()) * shape.num_elements(); \
state.SetItemsProcessed(num_items); \
state.SetBytesProcessed(num_items * sizeof(float)); \
} \
BENCHMARK(BM_##DEVICE##_roll_all) \
->UseRealTime() \
->ArgPair(256, 256) \
->ArgPair(512, 512) \
->ArgPair(1024, 1024) \

View File

@ -663,8 +663,8 @@ TEST_F(SaveOpSlices2Test, TwoSlices) {
// Benchmark-related code below.
static void BM_LargeTensorWrite(int iters, int num_elements) {
testing::StopTiming();
void BM_LargeTensorWrite(::testing::benchmark::State& state) {
const int num_elements = state.range(0);
// 4 * num_elements bytes total, since sizeof(float) == 4.
Tensor tensor(DT_FLOAT, TensorShape({num_elements}));
@ -689,8 +689,9 @@ static void BM_LargeTensorWrite(int iters, int num_elements) {
VLOG(1) << "Save op's output path: " << temp_filename;
VLOG(1) << "# nodes in Graph: " << g->num_nodes();
testing::StartTiming();
test::Benchmark("cpu", g, &session_options).Run(iters);
test::Benchmark("cpu", g, &session_options, nullptr, nullptr, "",
/*old_benchmark_api*/ false)
.Run(state);
}
BENCHMARK(BM_LargeTensorWrite)->Arg((1 << 30) / 4 /* 1GB float tensor */);

View File

@ -67,79 +67,120 @@ static Graph* ThreeDYCumsum(int num_y, int num_z, bool reverse = false) {
}
template <typename T>
static void LargeOneDimensional(int iters, const string& device, int num_x,
static void LargeOneDimensional(::testing::benchmark::State& state,
const string& device, int num_x,
bool reverse = false) {
testing::ItemsProcessed(static_cast<int64>(iters) * num_x);
testing::BytesProcessed(static_cast<int64>(iters) * num_x * sizeof(T));
test::Benchmark(device, LargeOneDCumsum<T>(num_x, reverse)).Run(iters);
test::Benchmark(device, LargeOneDCumsum<T>(num_x, reverse),
/*old_benchmark_api*/ false)
.Run(state);
state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_x);
state.SetBytesProcessed(static_cast<int64>(state.iterations()) * num_x *
sizeof(T));
}
static void DoRowCumsum(int iters, const string& device, int num_x, int num_y,
static void DoRowCumsum(::testing::benchmark::State& state,
const string& device, int num_x, int num_y,
bool reverse = false) {
testing::ItemsProcessed(static_cast<int64>(iters) * num_x * num_y);
testing::BytesProcessed(static_cast<int64>(iters) * num_x * num_y *
sizeof(float));
test::Benchmark(device, RowCumsum(num_x, num_y, reverse)).Run(iters);
test::Benchmark(device, RowCumsum(num_x, num_y, reverse),
/*old_benchmark_api*/ false)
.Run(state);
state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_x *
num_y);
state.SetBytesProcessed(static_cast<int64>(state.iterations()) * num_x *
num_y * sizeof(float));
}
static void DoColCumsum(int iters, const string& device, int num_x, int num_y,
static void DoColCumsum(::testing::benchmark::State& state,
const string& device, int num_x, int num_y,
bool reverse = false) {
testing::ItemsProcessed(static_cast<int64>(iters) * num_x * num_y);
testing::BytesProcessed(static_cast<int64>(iters) * num_x * num_y *
sizeof(float));
test::Benchmark(device, ColCumsum(num_x, num_y, reverse)).Run(iters);
test::Benchmark(device, ColCumsum(num_x, num_y, reverse),
/*old_benchmark_api*/ false)
.Run(state);
state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_x *
num_y);
state.SetBytesProcessed(static_cast<int64>(state.iterations()) * num_x *
num_y * sizeof(float));
}
static void Do3DYCumsum(int iters, const string& device, int num_x, int num_y,
static void Do3DYCumsum(::testing::benchmark::State& state,
const string& device, int num_x, int num_y,
bool reverse = false) {
testing::ItemsProcessed(static_cast<int64>(iters) * num_x * num_y);
testing::BytesProcessed(static_cast<int64>(iters) * num_x * num_y *
sizeof(float));
test::Benchmark(device, ThreeDYCumsum(num_x, num_y, reverse)).Run(iters);
test::Benchmark(device, ThreeDYCumsum(num_x, num_y, reverse),
/*old_benchmark_api*/ false)
.Run(state);
state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_x *
num_y);
state.SetBytesProcessed(static_cast<int64>(state.iterations()) * num_x *
num_y * sizeof(float));
}
static void BM_OneDCumsumGPU(int iters, int num_x) {
LargeOneDimensional<float>(iters, "gpu", num_x);
static void BM_OneDCumsumGPU(::testing::benchmark::State& state) {
const int num_x = state.range(0);
LargeOneDimensional<float>(state, "gpu", num_x);
}
BENCHMARK(BM_OneDCumsumGPU)->Range(1, 1 << 21);
static void BM_OneDCumsumGPUHalf(int iters, int num_x) {
LargeOneDimensional<Eigen::half>(iters, "gpu", num_x);
static void BM_OneDCumsumGPUHalf(::testing::benchmark::State& state) {
const int num_x = state.range(0);
LargeOneDimensional<Eigen::half>(state, "gpu", num_x);
}
BENCHMARK(BM_OneDCumsumGPUHalf)->Range(1, 1 << 21);
static void BM_Sum2DRowCumsumGPU(int iters, int num_x, int num_y) {
DoRowCumsum(iters, "gpu", num_x, num_y);
static void BM_Sum2DRowCumsumGPU(::testing::benchmark::State& state) {
const int num_x = state.range(0);
const int num_y = state.range(1);
DoRowCumsum(state, "gpu", num_x, num_y);
}
BENCHMARK(BM_Sum2DRowCumsumGPU)->RangePair(1, 8192, 1, 8192);
static void BM_Sum2DColumnCumsumGPU(int iters, int num_x, int num_y) {
DoColCumsum(iters, "gpu", num_x, num_y);
static void BM_Sum2DColumnCumsumGPU(::testing::benchmark::State& state) {
const int num_x = state.range(0);
const int num_y = state.range(1);
DoColCumsum(state, "gpu", num_x, num_y);
}
BENCHMARK(BM_Sum2DColumnCumsumGPU)->RangePair(1, 8192, 1, 8192);
static void BM_Sum3DYCumsumGPU(int iters, int num_x, int num_y) {
Do3DYCumsum(iters, "gpu", num_x, num_y);
static void BM_Sum3DYCumsumGPU(::testing::benchmark::State& state) {
const int num_x = state.range(0);
const int num_y = state.range(1);
Do3DYCumsum(state, "gpu", num_x, num_y);
}
BENCHMARK(BM_Sum3DYCumsumGPU)->RangePair(64, 4096, 64, 4096);
static void BM_OneDCumsumGPU_reverse(int iters, int num_x) {
LargeOneDimensional<float>(iters, "gpu", num_x, true);
static void BM_OneDCumsumGPU_reverse(::testing::benchmark::State& state) {
const int num_x = state.range(0);
LargeOneDimensional<float>(state, "gpu", num_x, true);
}
BENCHMARK(BM_OneDCumsumGPU_reverse)->Range(1, 1 << 21);
static void BM_Sum2DRowCumsumGPU_reverse(int iters, int num_x, int num_y) {
DoRowCumsum(iters, "gpu", num_x, num_y, true);
static void BM_Sum2DRowCumsumGPU_reverse(::testing::benchmark::State& state) {
const int num_x = state.range(0);
const int num_y = state.range(1);
DoRowCumsum(state, "gpu", num_x, num_y, true);
}
BENCHMARK(BM_Sum2DRowCumsumGPU_reverse)->RangePair(1, 8192, 1, 8192);
static void BM_Sum2DColumnCumsumGPU_reverse(int iters, int num_x, int num_y) {
DoColCumsum(iters, "gpu", num_x, num_y, true);
static void BM_Sum2DColumnCumsumGPU_reverse(
::testing::benchmark::State& state) {
const int num_x = state.range(0);
const int num_y = state.range(1);
DoColCumsum(state, "gpu", num_x, num_y, true);
}
BENCHMARK(BM_Sum2DColumnCumsumGPU_reverse)->RangePair(1, 8192, 1, 8192);
static void BM_Sum3DYCumsumGPU_reverse(int iters, int num_x, int num_y) {
Do3DYCumsum(iters, "gpu", num_x, num_y, true);
static void BM_Sum3DYCumsumGPU_reverse(::testing::benchmark::State& state) {
const int num_x = state.range(0);
const int num_y = state.range(1);
Do3DYCumsum(state, "gpu", num_x, num_y, true);
}
BENCHMARK(BM_Sum3DYCumsumGPU_reverse)->RangePair(32, 2048, 32, 2048);
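
Because this rendered diff shows the removed and added lines of each function back to back, the migrated form of these cumsum benchmarks can be hard to read in one piece. The sketch below reassembles the new-style pattern used throughout the hunks above: sizes arrive through state.range(n), the graph is run with .Run(state) (passing /*old_benchmark_api*/ false), and throughput is reported via state.SetItemsProcessed / state.SetBytesProcessed. The include paths, the tensorflow namespace, and the ThreeDYCumsum graph builder are assumed from the surrounding test file, so treat this as an illustrative sketch rather than the exact committed code.

// Sketch: consolidated new-style cumsum benchmark. Header paths and the
// ThreeDYCumsum() graph builder are assumed from this test file.
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/platform/test_benchmark.h"

namespace tensorflow {

static void Do3DYCumsumSketch(::testing::benchmark::State& state,
                              const string& device, int num_x, int num_y,
                              bool reverse = false) {
  // Run the graph under the new benchmark API; timing is driven by `state`.
  test::Benchmark(device, ThreeDYCumsum(num_x, num_y, reverse),
                  /*old_benchmark_api*/ false)
      .Run(state);
  // Throughput is reported on the state object after the run.
  state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_x *
                          num_y);
  state.SetBytesProcessed(static_cast<int64>(state.iterations()) * num_x *
                          num_y * sizeof(float));
}

// Thin wrapper: benchmark arguments arrive through state.range().
static void BM_Sum3DYCumsumGPUSketch(::testing::benchmark::State& state) {
  const int num_x = state.range(0);
  const int num_y = state.range(1);
  Do3DYCumsumSketch(state, "gpu", num_x, num_y);
}
BENCHMARK(BM_Sum3DYCumsumGPUSketch)->RangePair(64, 4096, 64, 4096);

}  // namespace tensorflow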

View File

@@ -254,8 +254,8 @@ class ScatterNdUpdateBM : public ScatterNdUpdateOpTest {
};
template <typename Index>
static void BM_ScatterNdHelper(int iters, int embedding_size, const char* op) {
testing::StopTiming();
void BM_ScatterNdHelper(::testing::benchmark::State& state, int embedding_size,
const char* op) {
const int kRows = 10000000 / embedding_size;
std::vector<float> values;
values.reserve(kRows);
@@ -280,27 +280,33 @@ static void BM_ScatterNdHelper(int iters, int embedding_size, const char* op) {
bm.AddInputFromArray<Index>(TensorShape({kNumUpdates}), indices);
bm.AddInputFromArray<float>(TensorShape({kNumUpdates, embedding_size}),
updates);
testing::ItemsProcessed((static_cast<int64>(kNumUpdates) * embedding_size) *
iters);
testing::StartTiming();
while (iters-- > 0) {
for (auto i : state) {
Status s = bm.RunOpKernel();
}
testing::StopTiming();
state.SetItemsProcessed((static_cast<int64>(kNumUpdates) * embedding_size) *
state.iterations());
}
static void BM_ScatterNdUpdateInt32(int iters, int embedding_size) {
BM_ScatterNdHelper<int32>(iters, embedding_size, "ScatterNdUpdate");
void BM_ScatterNdUpdateInt32(::testing::benchmark::State& state) {
const int embedding_size = state.range(0);
BM_ScatterNdHelper<int32>(state, embedding_size, "ScatterNdUpdate");
}
static void BM_ScatterNdUpdateInt64(int iters, int embedding_size) {
BM_ScatterNdHelper<int64>(iters, embedding_size, "ScatterNdUpdate");
void BM_ScatterNdUpdateInt64(::testing::benchmark::State& state) {
const int embedding_size = state.range(0);
BM_ScatterNdHelper<int64>(state, embedding_size, "ScatterNdUpdate");
}
static void BM_ScatterNdAddInt32(int iters, int embedding_size) {
BM_ScatterNdHelper<int32>(iters, embedding_size, "ScatterNdAdd");
void BM_ScatterNdAddInt32(::testing::benchmark::State& state) {
const int embedding_size = state.range(0);
BM_ScatterNdHelper<int32>(state, embedding_size, "ScatterNdAdd");
}
static void BM_ScatterNdAddInt64(int iters, int embedding_size) {
BM_ScatterNdHelper<int64>(iters, embedding_size, "ScatterNdAdd");
void BM_ScatterNdAddInt64(::testing::benchmark::State& state) {
const int embedding_size = state.range(0);
BM_ScatterNdHelper<int64>(state, embedding_size, "ScatterNdAdd");
}
BENCHMARK(BM_ScatterNdUpdateInt32)
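
The ScatterNd hunks above also show the other half of the migration: the explicit testing::StopTiming() / testing::StartTiming() calls and the while (iters-- > 0) loop are replaced by iterating over the State object, which times only the iterations of the range-for loop, and the items-processed count is reported once after the loop using state.iterations(). A minimal, self-contained illustration of that loop structure (using a trivial stand-in workload rather than the real OpsTestBase fixture, and an assumed header path) would be:

#include <cstdint>
#include <vector>

#include "tensorflow/core/platform/test_benchmark.h"  // assumed header path

// Setup that runs before the range-for over `state` is not timed, so the old
// explicit StopTiming()/StartTiming() bracketing is unnecessary.
static void BM_StateLoopSketch(::testing::benchmark::State& state) {
  const int embedding_size = state.range(0);
  std::vector<float> values(embedding_size, 1.0f);  // untimed setup
  float sum = 0.0f;
  for (auto i : state) {
    // Timed region: each pass through the body is one benchmark iteration.
    for (float v : values) sum += v;
  }
  // (A real benchmark would keep `sum` observable so the loop is not
  // optimized away; the ScatterNd helper's RunOpKernel() call has side
  // effects, so it does not need that.)
  state.SetItemsProcessed(static_cast<int64_t>(embedding_size) *
                          state.iterations());
}
BENCHMARK(BM_StateLoopSketch)->Range(1, 1 << 10);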

View File

@@ -280,9 +280,8 @@ class ScatterUpdateBM : public ScatterUpdateOpTest {
};
template <typename Index>
static void BM_ScatterHelper(int iters, int embedding_size, const char* op,
bool big_num_updates = false) {
testing::StopTiming();
void BM_ScatterHelper(::testing::benchmark::State& state, int embedding_size,
const char* op, bool big_num_updates = false) {
const int kRows = 10000000 / embedding_size;
std::vector<float> values;
values.reserve(kRows);
@@ -307,59 +306,83 @@ static void BM_ScatterHelper(int iters, int embedding_size, const char* op,
bm.AddInputFromArray<Index>(TensorShape({kNumUpdates}), indices);
bm.AddInputFromArray<float>(TensorShape({kNumUpdates, embedding_size}),
updates);
testing::ItemsProcessed((static_cast<int64>(kNumUpdates) * embedding_size) *
iters);
testing::StartTiming();
while (iters-- > 0) {
for (auto i : state) {
Status s = bm.RunOpKernel();
}
testing::StopTiming();
state.SetItemsProcessed((static_cast<int64>(kNumUpdates) * embedding_size) *
state.iterations());
}
static void BM_ScatterUpdateInt32(int iters, int embedding_size) {
BM_ScatterHelper<int32>(iters, embedding_size, "ScatterUpdate");
void BM_ScatterUpdateInt32(::testing::benchmark::State& state) {
const int embedding_size = state.range(0);
BM_ScatterHelper<int32>(state, embedding_size, "ScatterUpdate");
}
static void BM_ScatterUpdateInt64(int iters, int embedding_size) {
BM_ScatterHelper<int64>(iters, embedding_size, "ScatterUpdate");
void BM_ScatterUpdateInt64(::testing::benchmark::State& state) {
const int embedding_size = state.range(0);
BM_ScatterHelper<int64>(state, embedding_size, "ScatterUpdate");
}
static void BM_ScatterAddInt32(int iters, int embedding_size) {
BM_ScatterHelper<int32>(iters, embedding_size, "ScatterAdd");
void BM_ScatterAddInt32(::testing::benchmark::State& state) {
const int embedding_size = state.range(0);
BM_ScatterHelper<int32>(state, embedding_size, "ScatterAdd");
}
static void BM_ScatterAddInt32Large(int iters, int embedding_size) {
BM_ScatterHelper<int32>(iters, embedding_size, "ScatterAdd", true);
void BM_ScatterAddInt32Large(::testing::benchmark::State& state) {
const int embedding_size = state.range(0);
BM_ScatterHelper<int32>(state, embedding_size, "ScatterAdd", true);
}
static void BM_ScatterAddInt64(int iters, int embedding_size) {
BM_ScatterHelper<int64>(iters, embedding_size, "ScatterAdd");
void BM_ScatterAddInt64(::testing::benchmark::State& state) {
const int embedding_size = state.range(0);
BM_ScatterHelper<int64>(state, embedding_size, "ScatterAdd");
}
static void BM_ScatterMulInt32(int iters, int embedding_size) {
BM_ScatterHelper<int32>(iters, embedding_size, "ScatterMul");
void BM_ScatterMulInt32(::testing::benchmark::State& state) {
const int embedding_size = state.range(0);
BM_ScatterHelper<int32>(state, embedding_size, "ScatterMul");
}
static void BM_ScatterMulInt64(int iters, int embedding_size) {
BM_ScatterHelper<int64>(iters, embedding_size, "ScatterMul");
void BM_ScatterMulInt64(::testing::benchmark::State& state) {
const int embedding_size = state.range(0);
BM_ScatterHelper<int64>(state, embedding_size, "ScatterMul");
}
static void BM_ScatterDivInt32(int iters, int embedding_size) {
BM_ScatterHelper<int32>(iters, embedding_size, "ScatterDiv");
void BM_ScatterDivInt32(::testing::benchmark::State& state) {
const int embedding_size = state.range(0);
BM_ScatterHelper<int32>(state, embedding_size, "ScatterDiv");
}
static void BM_ScatterDivInt64(int iters, int embedding_size) {
BM_ScatterHelper<int64>(iters, embedding_size, "ScatterDiv");
void BM_ScatterDivInt64(::testing::benchmark::State& state) {
const int embedding_size = state.range(0);
BM_ScatterHelper<int64>(state, embedding_size, "ScatterDiv");
}
static void BM_ScatterMinInt32(int iters, int embedding_size) {
BM_ScatterHelper<int32>(iters, embedding_size, "ScatterMin");
void BM_ScatterMinInt32(::testing::benchmark::State& state) {
const int embedding_size = state.range(0);
BM_ScatterHelper<int32>(state, embedding_size, "ScatterMin");
}
static void BM_ScatterMinInt64(int iters, int embedding_size) {
BM_ScatterHelper<int64>(iters, embedding_size, "ScatterMin");
void BM_ScatterMinInt64(::testing::benchmark::State& state) {
const int embedding_size = state.range(0);
BM_ScatterHelper<int64>(state, embedding_size, "ScatterMin");
}
static void BM_ScatterMaxInt32(int iters, int embedding_size) {
BM_ScatterHelper<int32>(iters, embedding_size, "ScatterMax");
void BM_ScatterMaxInt32(::testing::benchmark::State& state) {
const int embedding_size = state.range(0);
BM_ScatterHelper<int32>(state, embedding_size, "ScatterMax");
}
static void BM_ScatterMaxInt64(int iters, int embedding_size) {
BM_ScatterHelper<int64>(iters, embedding_size, "ScatterMax");
void BM_ScatterMaxInt64(::testing::benchmark::State& state) {
const int embedding_size = state.range(0);
BM_ScatterHelper<int64>(state, embedding_size, "ScatterMax");
}
BENCHMARK(BM_ScatterUpdateInt32)
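
As with the ScatterNd benchmarks, each per-op wrapper above now pulls embedding_size out of state.range(0) and forwards it to the templated BM_ScatterHelper, while the values themselves come from the (truncated) BENCHMARK(...) registration lines. A sketch of how such a wrapper is wired up, with illustrative ->Arg(...) values rather than the ones in the real file, might look like:

// Sketch: thin wrapper plus registration for a parameterized new-style
// benchmark. BM_ScatterHelper<Index> is the templated helper shown above;
// the Arg values here are illustrative only.
void BM_ScatterUpdateInt32Sketch(::testing::benchmark::State& state) {
  const int embedding_size = state.range(0);  // supplied by ->Arg(...)
  BM_ScatterHelper<int32>(state, embedding_size, "ScatterUpdate");
}
BENCHMARK(BM_ScatterUpdateInt32Sketch)->Arg(1)->Arg(64)->Arg(1024);

Keeping the plain int embedding_size parameter on the helper means only these thin wrappers need to know about the benchmark State, which keeps the per-op variants short.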

Some files were not shown because too many files have changed in this diff